In [93]:
%matplotlib inline
import json

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset

from typing import Any, Tuple, List

from cv2 import Mat
from numpy import dtype, floating, integer, ndarray

plt.rcParams["figure.figsize"] = (16, 10)  # (w, h)

In [94]:
with open("../data/iwildcam2020_train_annotations.json") as f:
	data = json.load(f)


annotations = pd.DataFrame.from_dict(data["annotations"])
images_metadata = pd.DataFrame.from_dict(data["images"])
categories = pd.DataFrame.from_dict(data["categories"])

In [95]:
images_metadata.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 217959 entries, 0 to 217958
Data columns (total 9 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   seq_num_frames  217959 non-null  int64 
 1   location        217959 non-null  int64 
 2   datetime        217959 non-null  object
 3   id              217959 non-null  object
 4   frame_num       217959 non-null  int64 
 5   seq_id          217959 non-null  object
 6   width           217959 non-null  int64 
 7   height          217959 non-null  int64 
 8   file_name       217959 non-null  object
dtypes: int64(5), object(4)
memory usage: 15.0+ MB


In [96]:
# convert datetime type and split into day/night time
def split_day_night_time(
	data: pd.DataFrame, day_start: str = "06:00:00", day_end: str = "18:00:00"
) -> pd.DataFrame:
	data = data.copy()
	data["datetime"] = pd.to_datetime(data["datetime"])
	data["is_day"] = data["datetime"].apply(
		lambda x: True
		if pd.Timestamp(day_start).time() <= x.time() < pd.Timestamp(day_end).time()
		else False
	)
	return data

In [121]:
def preprocess_dark_images(
	image: np.ndarray,
) -> Mat | ndarray[Any, dtype[integer[Any] | floating[Any]]]:
	img = cv2.cvtColor(image, cv2.COLOR_RGB2LUV)
	img_eq = img.copy()
	img_eq[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
	final_img = cv2.cvtColor(img_eq, cv2.COLOR_LUV2RGB)
	return final_img

In [137]:
from PIL import Image


class iWildCam2020Dataset(IterableDataset):
	def __init__(
		self,
		dataset: str,
		metadata: pd.DataFrame,
		batch_size: int = 16,
		resize_dim: Tuple[int, int] | None = None,
	):
		super().__init__()
		self.metadata = metadata

		self.dataset = dataset
		self.batch_size = batch_size
		self.resize_dim = resize_dim

	def __iter__(self):  # -> Generator[Any, Any, None]:
		for idx, image_batch in enumerate(self.dataset.iter(self.batch_size)):
			is_day = self.metadata[idx : idx + self.batch_size]["is_day"].values
			image_batch = image_batch["image"]
			imgs_ = []

			dark_idx = set(np.where(~is_day)[0].tolist())
			for i in range(len(image_batch)):
				img = np.transpose(image_batch[i].numpy())
				#print(img.shape)
				if i in dark_idx:
					img = preprocess_dark_images(img)
				#img = np.array(Image.fromarray(img).resize(self.resize_dim))
				img = cv2.resize(img, self.resize_dim, interpolation=cv2.INTER_AREA)
				imgs_.append(torch.tensor(np.transpose(img)))
			yield torch.stack(imgs_)

In [99]:
batch_size = 4

In [100]:
images_metadata = split_day_night_time(images_metadata)

In [102]:
dataset = load_dataset(
	"anngrosha/iWildCam2020", split="train", streaming=True
).with_format("torch")

Resolving data files:   0%|          | 0/190 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/190 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

In [138]:
dataset_iteratable = iWildCam2020Dataset(
	dataset=dataset,
	metadata=images_metadata,
	batch_size=batch_size,
	resize_dim=(1000, 1000),
)

for batch in dataset_iteratable:
	print(batch.shape)

torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Size([4, 3, 1000, 1000])
torch.Si

KeyboardInterrupt: 