In [1]:
%matplotlib inline
import json
from typing import Any, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from cv2 import Mat
from datasets import load_dataset
from numpy import dtype, floating, integer, ndarray
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, Subset
from tqdm.autonotebook import tqdm
import torch.nn.functional as F

import pandas as pd

import torchvision
from torchvision import transforms

plt.rcParams["figure.figsize"] = (16, 10)  # (w, h)

In [2]:
from datasets import load_dataset
from transformers import AutoFeatureExtractor

# Load Hugging Face dataset
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "microsoft/swinv2-base-patch4-window16-256"
)



In [3]:
with open("../data/iwildcam2020_train_annotations.json") as f:
    data = json.load(f)


annotations = pd.DataFrame.from_dict(data["annotations"])
images_metadata = pd.DataFrame.from_dict(data["images"])
categories = pd.DataFrame.from_dict(data["categories"])

In [4]:
# convert datetime type and split into day/night time
def split_day_night_time(
    data: pd.DataFrame, day_start: str = "06:00:00", day_end: str = "18:00:00"
) -> pd.DataFrame:
    data = data.copy()
    data["datetime"] = pd.to_datetime(data["datetime"])
    data["is_day"] = data["datetime"].apply(
        lambda x: True
        if pd.Timestamp(day_start).time() <= x.time() < pd.Timestamp(day_end).time()
        else False
    )
    return data


def preprocess_dark_images(
    image: np.ndarray,
) -> Mat | ndarray[Any, dtype[integer[Any] | floating[Any]]]:
    img = cv2.cvtColor(image, cv2.COLOR_RGB2LUV)
    img_eq = img.copy()
    img_eq[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
    final_img = cv2.cvtColor(img_eq, cv2.COLOR_LUV2RGB)
    return final_img


def crop_black_lines(image: np.ndarray) -> np.ndarray:
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        x, y, w, h = cv2.boundingRect(contours[0])
        cropped_image = image[y : y + h, x : x + w]
        return cropped_image
    else:
        return image

In [5]:
from csv import Error
from pathlib import Path
from itertools import islice
from PIL import Image, UnidentifiedImageError


class iWildCam2020Preprocessor:
    def __init__(
        self,
        dataset: str,
        metadata: pd.DataFrame,
        annotations,
        batch_size: int = 1,
        resize_dim: Tuple[int, int] | None = None,
        num_samples: int = 1000,
        save_dir: str = "./processed_images",
        overwrite: bool = False,
    ):
        self.metadata = metadata

        self.dataset = dataset
        self.resize_dim = resize_dim

        self.num_samples = num_samples

        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.overwrite = overwrite

        self.batch_size = batch_size

        self.annotations = annotations

        unique_classes = self.annotations["category_id"].unique()
        category_to_index = {
            category_id: index for index, category_id in enumerate(unique_classes)
        }
        self.annotations["mapped_category_id"] = self.annotations["category_id"].map(
            category_to_index
        )

    @staticmethod
    def is_valid(image: np.ndarray) -> bool:
        if (
            image.ndim not in [3, 4]
            or image.shape[0] == 1
            or image.shape[1] == 1
            or image.shape[2] != 3
        ):
            print(f"Skipping image with invalid shape: {image.shape}")
            return False, None

        if image.ndim == 4:
            image = np.squeeze(image, axis=-1)

        if image.ndim == 3 and image.shape[2] == 3:
            try:
                img = Image.fromarray(image.astype(np.uint8), mode="RGB")
                img.verify()
                img.load()
                img = img.convert("RGB")
                return True, img
            except (UnidentifiedImageError, IOError) as e:
                print(f"Error while processing RGB image: {e}")
                return False, None

        return False, None

    def preprocess_dataset(self):
        existing_files = list(self.save_dir.glob("image_*.pt"))
        existing_files.sort(key=lambda x: int(x.stem.split("_")[1]))

        last_processed_index = (
            int(existing_files[-1].stem.split("_")[1]) if existing_files else 0
        )
        image_iterator = self.dataset.iter(batch_size=self.batch_size)
        if last_processed_index != 0:
            image_iterator = islice(
                image_iterator, last_processed_index // self.batch_size
            )

        saved_samples = last_processed_index + 1 if last_processed_index != 0 else 0
        idx = saved_samples
        with tqdm(
            total=self.num_samples,
            initial=(last_processed_index + 1 if last_processed_index != 0 else 0),
        ) as pbar:
            while saved_samples < self.num_samples:
                try:
                    batch = next(image_iterator)
                    for i, images in enumerate(batch["image"]):
                        save_path = self.save_dir / f"image_{saved_samples}.pt"

                        if save_path.exists() and not self.overwrite:
                            pbar.update(1)
                            saved_samples += 1
                            continue

                        img_np = np.transpose(images.numpy())

                        valid, img = self.is_valid(img_np)
                        if not valid:
                            print(
                                f"Skipping invalid or corrupt image at index {idx + i}, {img_np.shape}"
                            )
                            pbar.update(0)
                            continue

                        img_np = np.array(img)

                        is_day = self.metadata.iloc[idx + i]["is_day"]
                        if not is_day:
                            img_np = preprocess_dark_images(img_np)

                        img_np = crop_black_lines(img_np)
                        img_np = cv2.resize(
                            img_np, self.resize_dim, interpolation=cv2.INTER_AREA
                        )

                        img_tensor = (
                            torch.tensor(
                                np.transpose(img_np, (2, 0, 1)), dtype=torch.float32
                            )
                            / 255.0
                        )

                        label = self.annotations.iloc[idx]["mapped_category_id"]
                        data = {
                            "image": img_tensor,
                            "label": label,
                        }

                        if not save_path.exists() or self.overwrite:
                            torch.save(data, save_path)

                        saved_samples += 1
                        pbar.update(1)

                except Exception as e:
                    print(f"Skipping record at index {idx} due to error")
                    pbar.update(0)
                idx += self.batch_size

In [6]:
class iWildCam2020Dataset(Dataset):
    def __init__(
        self,
        transform: transforms.Compose | None = None,
        save_dir: str = "./data/processed_images",
    ):
        self.save_dir = Path(save_dir)

        self.transform = transform
        self.items = list(self.save_dir.glob("image_*.pt"))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        img_path = self.save_dir / f"image_{idx}.pt"
        data = torch.load(img_path)
        
        img_tensor = data["image"]
        label = data["label"]

        if self.transform:
            img_tensor = self.transform(img_tensor)

        return img_tensor, label

In [7]:
import os
from datetime import datetime


def get_unique_model_path(base_path):
    if not os.path.exists(base_path):
        return base_path

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_path = f"{base_path}_{timestamp}.pt"

    while os.path.exists(unique_path):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_path = f"{base_path}_{timestamp}.pt"

    return unique_path

In [8]:
from typing import Dict
import evaluate


def init_metrics() -> Dict[str, evaluate.Metric]:
    return {
        "accuracy": evaluate.load("accuracy"),
        "precision": evaluate.load("precision", zero_division=0, average="macro"),
        "recall": evaluate.load("recall", zero_division=0, average="macro"),
        "f1": evaluate.load("f1", average="macro"),
    }


def compute_batch_metrics(metrics: Dict[str, evaluate.Metric]) -> Dict[str, float]:
    computed_metrics = {}

    computed_metrics["accuracy"] = metrics["accuracy"].compute()["accuracy"]
    computed_metrics["precision"] = metrics["precision"].compute(
        zero_division=0, average="macro"
    )["precision"]
    computed_metrics["recall"] = metrics["recall"].compute(
        zero_division=0, average="macro"
    )["recall"]
    computed_metrics["f1"] = metrics["f1"].compute(average="macro")["f1"]

    return computed_metrics

In [None]:
from torch import autocast

# comment if cuda unavaliabe
#from torch import GradScaler


def train_with_lora_and_hard_negatives(
    model,
    criterion,
    optimizer,
    train_loader,
    val_loader,
    batch_size,
    num_samples,
    device,
    num_epochs=1,
    ckpt_path="models/best.pt",
    use_mlflow=False,
    use_wandb=False,
    grad_clip_norm=None,
    scheduler=None,
    hard_negative_ratio=0.1,
    hard_negative_update_freq=1,
    use_amp=False,
):
    ckpt_path = get_unique_model_path(ckpt_path)
    best_accuracy = 0.0
    metrics = init_metrics()

    hard_negatives = []

    if use_mlflow:
        import mlflow

        mlflow.start_run()
        mlflow.log_params(
            {
                "model": model.__class__.__name__,
                "criterion": criterion.__class__.__name__,
                "optimizer": optimizer.__class__.__name__,
                "num_epochs": num_epochs,
                "batch_size": batch_size,
                "num_samples": num_samples,
                "model_path": ckpt_path,
            }
        )
    
    # comment if cuda unavaliabe
    #scaler = GradScaler() if use_amp else None

    # Training
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        current_hard_negatives = []

        if len(hard_negatives) > 0:
            hard_negative_batch = generate_hard_negative_batch(hard_negatives, batch_size, device)
            train_loader_with_hard_negatives = concatenate_batches(train_loader, hard_negative_batch, device)
        else:
            train_loader_with_hard_negatives = train_loader

        for images, labels in tqdm(train_loader_with_hard_negatives, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            with autocast(device_type="cuda", enabled=use_amp):
                outputs = model(images)
                ind_loss = torch.nn.functional.cross_entropy(outputs.logits, labels, reduction='none')
                loss = ind_loss.mean()  # Mean loss for the batch

            if use_amp:
                scaler.scale(loss).backward()
                if grad_clip_norm:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
                scaler.step(optimizer)
                scaler.update() 
            else:
                loss.backward()
                if grad_clip_norm:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
                optimizer.step()


            train_loss += loss.item()

            _, preds = torch.max(outputs.logits, dim=1)
            misclassified = preds != labels
            hard_negative_losses = ind_loss[misclassified]
            current_hard_negatives.extend(
                [(images[i], labels[i]) for i, _ in enumerate(hard_negative_losses)]
            )


        if scheduler:
            scheduler.step()

        # Update hard negatives
        if epoch % hard_negative_update_freq == 0:
            hard_negatives.extend(current_hard_negatives)
            max_negatives = int(hard_negative_ratio * len(train_loader.dataset))
            hard_negatives = hard_negatives[-max_negatives:]

        # Validation
        val_loss = 0.0
        model.eval()
        computed_metrics = {}
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc="Validation"):
                images, labels = images.to(device), labels.to(device)
                with autocast(device_type="cuda", enabled=use_amp):
                    outputs = model(images)
                    loss = criterion(outputs.logits, labels)
                val_loss += loss.item()

                preds = outputs.logits.argmax(dim=1)
                metrics["accuracy"].add_batch(predictions=preds, references=labels)
                metrics["precision"].add_batch(predictions=preds, references=labels)
                metrics["recall"].add_batch(predictions=preds, references=labels)
                metrics["f1"].add_batch(predictions=preds, references=labels)

            computed_metrics = compute_batch_metrics(metrics=metrics)

        # Log and save
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        log_data = {
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            **computed_metrics,
        }

        if use_mlflow:
            mlflow.log_metrics(log_data, step=epoch)
        if use_wandb:
            wandb.log(log_data)

        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
        )
        print(f"Metrics: {computed_metrics}")

        if computed_metrics["accuracy"] > best_accuracy:
            best_accuracy = computed_metrics["accuracy"]
            torch.save(model.state_dict(), ckpt_path)
            if use_mlflow:
                mlflow.pytorch.log_model(model, ckpt_path)

    if use_mlflow:
        mlflow.end_run()

def generate_hard_negative_batch(hard_negatives, batch_size, device):
    if len(hard_negatives) < batch_size:
        return DataLoader(hard_negatives, batch_size=batch_size, shuffle=True)
    
    indices = torch.randint(0, len(hard_negatives), (batch_size,))
    batch = [hard_negatives[i] for i in indices]
    images, labels = zip(*batch)
    return torch.stack(images).to(device), torch.stack(labels).to(device)

def concatenate_batches(train_loader, hard_negative_batch, device):
    train_images, train_labels = next(iter(train_loader))
    train_images, train_labels = train_images.to(device), train_labels.to(device)

    images, labels = hard_negative_batch
    combined_images = torch.cat((train_images, images), dim=0)
    combined_labels = torch.cat((train_labels, labels), dim=0)
    combined_loader = torch.utils.data.DataLoader(
        list(zip(combined_images, combined_labels)), batch_size=train_loader.batch_size, shuffle=True
    )
    return combined_loader

In [14]:
dataset = load_dataset(
    "anngrosha/iWildCam2020", split="train", streaming=True
).with_format("torch")

Resolving data files:   0%|          | 0/190 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/190 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

In [15]:
images_metadata = split_day_night_time(images_metadata)

In [16]:
batch_size = 2
img_size = 224
resize_dim = (img_size, img_size)
num_classes = len(annotations["category_id"].unique())
print(num_classes)

num_samples = 20_000
val_ratio = 0.2

train_size = int(num_samples * (1 - val_ratio))
val_size = int(num_samples * val_ratio)

num_epochs = 5

mean_std_samples = num_samples - int(num_samples * val_ratio)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "mps"

save_dir = "./data/processed_images/"

216


In [17]:
dataset_preprocessor = iWildCam2020Preprocessor(
    dataset=dataset,
    metadata=images_metadata,
    resize_dim=resize_dim,
    batch_size=100,
    num_samples=num_samples,
    save_dir=save_dir,
    overwrite=False,
    annotations=annotations,
)
dataset_preprocessor.preprocess_dataset()

100%|##########| 20000/20000 [00:00<?, ?it/s]

In [18]:
def calculate_mean_std(
    resize_dim=(224, 224),
    num_samples=1000,
    device="cpu",
    save_dir="./processed_images",
    save_files=True,
):
    total_pixels = 0
    sum_mean = torch.zeros(3, dtype=torch.float32, device=device)
    sum_std = torch.zeros(3, dtype=torch.float32, device=device)

    image_files = list(Path(save_dir).glob("image_*.pt"))
    image_files.sort(key=lambda x: int(x.stem.split("_")[1]))

    image_files = image_files[:num_samples]

    with tqdm(total=len(image_files)) as pbar:
        for idx, image_file in enumerate(image_files):
            img_tensor = torch.load(image_file)["image"].to(device)

            img_tensor = torch.nn.functional.interpolate(
                img_tensor.unsqueeze(0),
                size=resize_dim,
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)

            img_tensor = img_tensor / 255.0  # Normalize to [0, 1]
            sum_mean += img_tensor.mean(dim=(1, 2))
            sum_std += img_tensor.std(dim=(1, 2))
            total_pixels += img_tensor.numel()

            pbar.update(1)

    mean = sum_mean / total_pixels
    std = sum_std / total_pixels

    if save_files:
        mean_file = Path(save_dir) / f"mean_top_{num_samples}.pt"
        std_file = Path(save_dir) / f"std_top_{num_samples}.pt"
        torch.save(mean, mean_file)
        torch.save(std, std_file)

    return mean, std


def get_mean_std_from_files(
    save_dir="./processed_images", num_samples=1000, device="cpu"
):
    mean_file = Path(save_dir) / f"mean_top_{num_samples}.pt"
    std_file = Path(save_dir) / f"std_top_{num_samples}.pt"

    if mean_file.exists() and std_file.exists():
        mean = torch.load(mean_file)
        std = torch.load(std_file)
        return mean, std
    else:
        return calculate_mean_std(
            num_samples=num_samples, device=device, save_dir=save_dir, save_files=True
        )


mean, std = get_mean_std_from_files(save_dir, mean_std_samples, device=device)
mean, std = mean.to("cpu"), std.to("cpu")
mean, std

(tensor([2.9317e-09, 2.8943e-09, 2.8340e-09]),
 tensor([1.3166e-09, 1.3476e-09, 1.4205e-09]))

In [28]:
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop(size=(224, 224)),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomApply([transforms.GaussianBlur(3)], p=0.1),
        transforms.Normalize(mean=mean, std=std),
    ]
)

dataset = iWildCam2020Dataset(
    save_dir="./data/processed_images", transform=transform
)


train_idx = list(range(train_size))
val_idx = list(range(train_size, num_samples))

train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size
)

In [20]:
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    "microsoft/swinv2-tiny-patch4-window16-256"
)
model.classifier = nn.Linear(768, num_classes)

In [21]:
import peft
from peft import get_peft_model, LoraConfig


lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.2,
    target_modules=["query", "value", "key"],
    modules_to_save=["classifier"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 378,072 || all params: 28,122,330 || trainable%: 1.3444


In [22]:
model.to(device)


criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

In [33]:
train_with_lora_and_hard_negatives(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    batch_size=batch_size,
    num_samples=num_samples,
    device=device,
    num_epochs=10,
    ckpt_path="models/best-lora.pt",
    grad_clip_norm=1.0,
    scheduler=scheduler,
    hard_negative_ratio=0.1,
    hard_negative_update_freq=1,
    use_amp=False
)

Epoch 1/10:   0%|          | 0/8000 [00:00<?, ?it/s]

Validation:   0%|          | 0/2000 [00:00<?, ?it/s]

Epoch [1/10], Train Loss: 0.0003, Val Loss: 0.0016
Metrics: {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}


Epoch 2/10:   0%|          | 0/2 [00:00<?, ?it/s]

Validation:   0%|          | 0/2000 [00:00<?, ?it/s]

Epoch [2/10], Train Loss: 0.0003, Val Loss: 0.0016
Metrics: {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}


Epoch 3/10:   0%|          | 0/2 [00:00<?, ?it/s]

Validation:   0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 