In [14]:
%matplotlib inline
import json

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset

from typing import Any, Tuple, List

from cv2 import Mat
from numpy import dtype, floating, integer, ndarray

from tqdm.autonotebook import tqdm

plt.rcParams["figure.figsize"] = (16, 10)  # (w, h)

In [2]:
with open("../data/iwildcam2020_train_annotations.json") as f:
	data = json.load(f)


annotations = pd.DataFrame.from_dict(data["annotations"])
images_metadata = pd.DataFrame.from_dict(data["images"])
categories = pd.DataFrame.from_dict(data["categories"])

In [3]:
images_metadata.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 217959 entries, 0 to 217958
Data columns (total 9 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   seq_num_frames  217959 non-null  int64 
 1   location        217959 non-null  int64 
 2   datetime        217959 non-null  object
 3   id              217959 non-null  object
 4   frame_num       217959 non-null  int64 
 5   seq_id          217959 non-null  object
 6   width           217959 non-null  int64 
 7   height          217959 non-null  int64 
 8   file_name       217959 non-null  object
dtypes: int64(5), object(4)
memory usage: 15.0+ MB


In [4]:
# convert datetime type and split into day/night time
def split_day_night_time(
	data: pd.DataFrame, day_start: str = "06:00:00", day_end: str = "18:00:00"
) -> pd.DataFrame:
	data = data.copy()
	data["datetime"] = pd.to_datetime(data["datetime"])
	data["is_day"] = data["datetime"].apply(
		lambda x: True
		if pd.Timestamp(day_start).time() <= x.time() < pd.Timestamp(day_end).time()
		else False
	)
	return data

In [5]:
def preprocess_dark_images(
	image: np.ndarray,
) -> Mat | ndarray[Any, dtype[integer[Any] | floating[Any]]]:
	img = cv2.cvtColor(image, cv2.COLOR_RGB2LUV)
	img_eq = img.copy()
	img_eq[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
	final_img = cv2.cvtColor(img_eq, cv2.COLOR_LUV2RGB)
	return final_img

In [6]:
from pathlib import Path

class iWildCam2020Dataset(IterableDataset):
    def __init__(
        self,
        dataset: str,
        metadata: pd.DataFrame,
        batch_size: int = 16,
        resize_dim: Tuple[int, int] | None = None,
        num_samples: int = 1000,
        mean: np.ndarray | None = None,
        std: np.ndarray | None = None,
        save_dir: str | None = None,
        overwrite: bool = False,
        split: str = "train",
        val_ratio: float = 0.2,
    ):
        super().__init__()
        self.metadata = metadata

        self.split = split
        self.val_ratio = val_ratio
        self.train_size = int((1 - val_ratio) * num_samples)
        self.val_size = num_samples - self.train_size

        self.dataset = dataset
        self.batch_size = batch_size
        self.resize_dim = resize_dim

        self.num_samples = num_samples
        if self.split == "train":
            self.num_batches = (self.train_size + batch_size - 1) // batch_size
        else:
            self.num_batches= (self.val_size + batch_size - 1) // batch_size

        self.mean = torch.tensor(mean if mean is not None else [0.0, 0.0, 0.0]).view(
            3, 1, 1
        )
        self.std = torch.tensor(std if std is not None else [1.0, 1.0, 1.0]).view(
            3, 1, 1
        )

        self.save_dir = Path(save_dir) if save_dir else None
        if self.save_dir:
            self.save_dir.mkdir(parents=True, exist_ok=True)
        self.overwrite = overwrite

    def save_image(self, img_tensor: torch.Tensor, idx: int):
        if self.save_dir:
            save_path = self.save_dir / f"image_{idx}.pt"
            torch.save(img_tensor, save_path)
    
    def load_image(self, idx: int) -> torch.Tensor | None:
        if self.save_dir:
            save_path = self.save_dir / f"image_{idx}.pt"
            if save_path.exists():
                return torch.load(save_path, weights_only=True)
        return None
    
    def __len__(self):
        return self.num_batches

    def __iter__(self):
        if self.split == "train":
            start_idx, end_idx = 0, self.train_size
        else:
            start_idx, end_idx = self.train_size, self.num_samples
        
        for idx, image_batch in enumerate(self.dataset.iter(self.batch_size)):
            # to get consistent part of dataset + val / train split
            batch_start = idx * self.batch_size
            if batch_start >= end_idx:
                break
            if batch_start < start_idx:
                continue
            
            is_day = self.metadata[idx * self.batch_size : (idx + 1) * self.batch_size][
                "is_day"
            ].values
            image_batch = image_batch["image"]
            imgs_ = []

            dark_idx = set(np.where(~is_day)[0].tolist())
            for i in range(len(image_batch)):
                img_tensor = self.load_image(idx * self.batch_size + i)
                if img_tensor is None:
                    img = np.transpose(image_batch[i].numpy())
                    if i in dark_idx:
                        img = preprocess_dark_images(img)
                    img = cv2.resize(img, self.resize_dim, interpolation=cv2.INTER_AREA)
                    img_tensor = (
                        torch.tensor(np.transpose(img, (2, 0, 1)), dtype=torch.float32)
                        / 255.0
                    )

                    if self.save_dir:
                        self.save_image(img_tensor, idx * self.batch_size + i)

                imgs_.append(img_tensor)
            yield torch.stack(imgs_)

In [7]:
def calculate_mean_std(dataset, batch_size=32, resize_dim=(224, 224), num_samples=1000):
    means = []
    stds = []
    for idx, image_batch in tqdm(enumerate(dataset.iter(batch_size)), total = ((num_samples + batch_size - 1) // batch_size)):
        if idx * batch_size >= num_samples:
            break

        imgs_ = []
        for image in image_batch["image"]:
            img = np.transpose(image.numpy(), (1, 2, 0))
            img = cv2.resize(img, resize_dim, interpolation=cv2.INTER_AREA)
            img = img / 255.0
            imgs_.append(img)

        imgs_array = np.stack(imgs_)
        means.append(imgs_array.mean(axis=(0, 1, 2)))
        stds.append(imgs_array.std(axis=(0, 1, 2)))

    mean = np.mean(means, axis=0)
    std = np.mean(stds, axis=0)
    return mean, std

In [9]:
images_metadata = split_day_night_time(images_metadata)

In [10]:
dataset = load_dataset(
	"anngrosha/iWildCam2020", split="train", streaming=True
).with_format("torch")

Resolving data files:   0%|          | 0/190 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/190 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

In [11]:
batch_size = 4
img_size = 528
resize_dim = (img_size, img_size)

num_samples = 200
val_ratio = 0.2

In [15]:
mean, std = calculate_mean_std(
    dataset, batch_size=batch_size, resize_dim=resize_dim, num_samples=num_samples
)
mean, std

  0%|          | 0/50 [00:00<?, ?it/s]

(array([0.39645247, 0.40370792, 0.3785379 ]),
 array([0.25464274, 0.25719384, 0.26200439]))

In [16]:
train_dataset = iWildCam2020Dataset(
    dataset=dataset,
    metadata=images_metadata,
    batch_size=batch_size,
    resize_dim=resize_dim,
    num_samples=num_samples,
    mean=mean,
    std=std,
    save_dir="./data/train",
    split="train",
    val_ratio=val_ratio
)

val_dataset = iWildCam2020Dataset(
    dataset=dataset,
    metadata=images_metadata,
    batch_size=batch_size,
    resize_dim=resize_dim,
    num_samples=num_samples,
    mean=mean,
    std=std,
    save_dir="./data/val",
    split="val",
    val_ratio=val_ratio
)


train_loader = DataLoader(train_dataset, batch_size=None)
val_loader = DataLoader(val_dataset, batch_size=None)

for batch in train_loader:
	print(batch.shape)

for batch in val_loader:
	print(batch.shape)

torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4, 3, 528, 528])
torch.Size([4,