In [4]:
%matplotlib inline
import json

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from datasets import load_dataset

import torch.optim as optim
from torch.utils.data import DataLoader, IterableDataset
# from torchvision import transforms
# from torchvision.models import resnet18
from sklearn.metrics import accuracy_score

from typing import Any, Tuple, List

from cv2 import Mat
from numpy import dtype, floating, integer, ndarray

plt.rcParams["figure.figsize"] = (16, 10)  # (w, h)

In [5]:
with open("../data/iwildcam2020_train_annotations.json") as f:
	data = json.load(f)


annotations = pd.DataFrame.from_dict(data["annotations"])
images_metadata = pd.DataFrame.from_dict(data["images"])
categories = pd.DataFrame.from_dict(data["categories"])

In [6]:
images_metadata.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 217959 entries, 0 to 217958
Data columns (total 9 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   seq_num_frames  217959 non-null  int64 
 1   location        217959 non-null  int64 
 2   datetime        217959 non-null  object
 3   id              217959 non-null  object
 4   frame_num       217959 non-null  int64 
 5   seq_id          217959 non-null  object
 6   width           217959 non-null  int64 
 7   height          217959 non-null  int64 
 8   file_name       217959 non-null  object
dtypes: int64(5), object(4)
memory usage: 15.0+ MB


In [7]:
# convert datetime type and split into day/night time
def split_day_night_time(
	data: pd.DataFrame, day_start: str = "06:00:00", day_end: str = "18:00:00"
) -> pd.DataFrame:
	data = data.copy()
	data["datetime"] = pd.to_datetime(data["datetime"])
	data["is_day"] = data["datetime"].apply(
		lambda x: True
		if pd.Timestamp(day_start).time() <= x.time() < pd.Timestamp(day_end).time()
		else False
	)
	return data

In [8]:
def preprocess_dark_images(
	image: np.ndarray,
) -> Mat | ndarray[Any, dtype[integer[Any] | floating[Any]]]:
	img = cv2.cvtColor(image, cv2.COLOR_RGB2LUV)
	img_eq = img.copy()
	img_eq[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
	final_img = cv2.cvtColor(img_eq, cv2.COLOR_LUV2RGB)
	return final_img

In [9]:
from PIL import Image


class iWildCam2020Dataset(IterableDataset):
	def __init__(
		self,
		dataset: str,
		metadata: pd.DataFrame,
		batch_size: int = 16,
		resize_dim: Tuple[int, int] | None = None,
	):
		super().__init__()
		self.metadata = metadata

		self.dataset = dataset
		self.batch_size = batch_size
		self.resize_dim = resize_dim

	def __iter__(self):  # -> Generator[Any, Any, None]:
		for idx, image_batch in enumerate(self.dataset.iter(self.batch_size)):
			is_day = self.metadata[idx : idx + self.batch_size]["is_day"].values
			image_batch = image_batch["image"]
			imgs_ = []
			
			dark_idx = set(np.where(~is_day)[0].tolist())
			for i in range(len(image_batch)):
				img = np.transpose(image_batch[i].numpy())
				if i in dark_idx:
					img = preprocess_dark_images(img)
				img = cv2.resize(img, self.resize_dim, interpolation=cv2.INTER_AREA)
				# imgs_.append(torch.tensor(np.transpose(img)))
				imgs_.append(torch.tensor(np.transpose(img, (2, 0, 1)), dtype=torch.float32) / 255.0)
			yield torch.stack(imgs_)

In [10]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [11]:
from tqdm.autonotebook import tqdm

def train(
    model, 
    criterion, 
    optimizer, 
    dataset_iterable, 
    batch_size,
    num_epochs=1,
    train_batches=5,
    val_batches=2,
    ckpt_path="best.pt",
):
    best = 0.0
    epoch_items = (train_batches + val_batches) * batch_size

    for epoch in range(num_epochs):

        train_loop = tqdm(
            enumerate(dataset_iterable, 0),
            total=train_batches,
            desc=f"Epoch {epoch}: train",
        )

        model.train()
        train_loss = 0.0

        for i, batch in train_loop:
            images = batch
            labels = torch.tensor(annotations['category_id'][epoch * epoch_items + batch_size * i : min(epoch * epoch_items + batch_size * (i + 1), len(annotations['category_id']))].values)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_loop.set_postfix({"loss": loss.item()})

            if i + 1 > train_batches:
                break

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss / (batch_size * train_batches):.4f}')


        correct = 0
        total = 0
        with torch.no_grad():
            model.eval()

            val_loop = tqdm(
                enumerate(dataset_iterable, 0),
                total=val_batches,
                desc=f"Val",
            )
            
            for i, batch in val_loop:
                images = batch
                labels = torch.tensor(annotations['category_id'][epoch * epoch_items + batch_size * train_batches + batch_size * i : min(epoch * epoch_items + batch_size * train_batches + batch_size * (i + 1), len(annotations['category_id']))].values)

                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_loop.set_postfix({"acc": correct / total})

                if i + 1 > val_batches:
                    break
                

            if correct / total > best:
                torch.save(model.state_dict(), ckpt_path)
                best = correct / total


In [12]:
dataset = load_dataset(
	"anngrosha/iWildCam2020", split="train", streaming=True
).with_format("torch")


Resolving data files:   0%|          | 0/190 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/190 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

In [13]:
images_metadata = split_day_night_time(images_metadata)

In [24]:
batch_size = 16
img_size = 224
num_classes = max(annotations['category_id'])

In [25]:
dataset_iterable = iWildCam2020Dataset(
	dataset=dataset,
	metadata=images_metadata,
	batch_size=batch_size,
	resize_dim=(img_size, img_size),
)

In [26]:
model = SimpleCNN(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [27]:
train(
    model, 
    criterion, 
    optimizer, 
    dataset_iterable, 
    batch_size,
    num_epochs=2,
    train_batches=5,
    val_batches=2,
)

Epoch 0: train:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch [1/2], Loss: 0.4057


Val:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1: train:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch [2/2], Loss: 0.5088


Val:   0%|          | 0/2 [00:00<?, ?it/s]