# Image Classification with PyTorch & TIMM

In [None]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,models,transforms
from torch.utils.data import DataLoader,Dataset
import copy 
import time
import pandas as pd
import json,os
from PIL import Image
import webdataset as wds
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import amp
from collections import defaultdict
import plotly.express as px
import cv2
import matplotlib.pyplot as plt
import numpy as np
import timm
from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score
import random
from timm.loss import SoftTargetCrossEntropy

## 📂 Load Training Data from JSON 📝

In [None]:
# Load the training-split metadata; its "images", "categories" and
# "annotations" entries are turned into DataFrames further down.
# JSON text is UTF-8, so the encoding is pinned explicitly instead of relying
# on the platform default codec (which is locale-dependent on Windows).
with open(r"C:\Users\thaku\jupyter_notebook _datasets\Wildlife_dataset\Dataset\train_mini.json", 'r', encoding="utf-8") as f:
    json_train_data = json.load(f)

## 🔍 Inspect Loaded JSON Data 📊

In [None]:
# Display the whole parsed JSON object as the cell output.
json_train_data

In [None]:
# Inspect the "images" entry.
json_train_data["images"]

In [None]:
# List the top-level keys of the JSON object.
json_train_data.keys()

In [None]:
# Inspect the "categories" entry.
json_train_data["categories"]

In [None]:
# Inspect the "annotations" entry.
json_train_data["annotations"]

In [None]:
# Tabulate the category records so they can be sliced and merged below.
categories_df = pd.DataFrame(json_train_data["categories"])

In [None]:
categories_df

In [None]:
# Keep only the first 301 category rows (all columns). The model and the
# MixUp/CutMix helper later in this notebook are configured with 301 classes.
categories_df = categories_df.iloc[:301,:]

## 🗂️ Convert JSON Data to DataFrames 📊

In [None]:
# Tabulate the "images" and "annotations" entries of the training JSON.
images_df = pd.DataFrame(json_train_data["images"])
annotations_df = pd.DataFrame(json_train_data["annotations"])


In [None]:
# Display the image table.
images_df

In [None]:
# Display the annotation table.
annotations_df

## 🔗 Merge Categories, Annotations, and Images into Final DataFrame 🖼️

In [None]:
# Join each annotation to its category (category "id" == annotation
# "category_id"). merge() defaults to an inner join, so annotations whose
# category is not among the retained category rows are dropped. Overlapping
# column names receive the "_cat" / "_ann" suffixes.
category_annotation_df = categories_df.merge(annotations_df,left_on="id",right_on="category_id",suffixes=["_cat","_ann"])

In [None]:
category_annotation_df.head()

In [None]:
# Attach the image record to every annotation (annotation "image_id" == image "id").
final_df =category_annotation_df.merge(images_df,left_on="image_id",right_on="id",suffixes=("_cat_ann","_img"))

In [None]:
# Spot-check five random rows of the merged table.
final_df.sample(5)

In [None]:
final_df.columns

In [None]:
# Show every column when a DataFrame is displayed.
pd.set_option("display.max_columns", None)


In [None]:
# Drop the join keys and the metadata columns that are not used afterwards.
final_df = final_df.drop(columns=["id_cat","date","rights_holder","license","id_ann","id"])

In [None]:
final_df.head(5)

In [None]:
# Median of the "width" column.
final_df["width"].median()

In [None]:
# Distinct values of the "common_name" column.
final_df["common_name"].unique()

In [None]:
# Median of the "height" column.
final_df["height"].median()

In [None]:
final_df.head()

In [None]:
# Look at the "file_name" value of the row labelled 0.
final_df["file_name"][0]

In [None]:
# Build the compact training table (image path, class id, GPS coordinates).
# This definition was missing: new_df_train was displayed, renamed and exported
# without ever being created. It mirrors the column selection used for
# new_df_val in the validation section.
new_df_train = final_df[["file_name","category_id","latitude","longitude"]]
new_df_train

In [None]:
# Rename "file_name" -> "img_path" and "category_id" -> "class".
# rename() returns a new DataFrame, which is re-bound to the same name.
new_df_train=new_df_train.rename(columns={"file_name": "img_path", "category_id": "class"})

In [None]:
# Do not truncate long cell values (e.g. file paths) when displaying DataFrames.
pd.set_option("display.max_colwidth",None)

In [None]:
new_df_train

In [None]:
# Summary statistics of the merged training table.
final_df.describe()

## 🌍 Visualize Species Distribution on World Map 🐛🕷️

In [None]:
# Plot the recorded coordinates of a hand-picked set of species on a world
# map, one colour per common name.
species_name = [
    'Common Earthworm',
    'Mediterranean Fanworm',
    'Serpula columbiana',
    'Blue Tube Worm',
    'Giant House Spider',
    'California Turret Spider',
    'Oak Spider',
    'Gorse Orbweaver',
]

# Rows of the merged table whose common name is in the list above.
subset = final_df.loc[final_df["common_name"].isin(species_name)]

fig = px.scatter_geo(
    subset,
    lat="latitude",
    lon="longitude",
    color="common_name",
    opacity=0.7,
    scope="world",
    title=f"Locations of {species_name}",
)
fig.show()

In [None]:
import plotly.graph_objects as go

# Interactive world map of the training records: one marker trace per species
# and a dropdown that shows exactly one species at a time.
species_list = final_df["common_name"].unique()

fig = go.Figure()

# Traces are added in species_list order, all hidden to begin with.
for sp in species_list:
    subset = final_df[final_df["common_name"] == sp]
    trace = go.Scattergeo(
        lon=subset["longitude"],
        lat=subset["latitude"],
        text=subset["common_name"],
        mode="markers",
        marker={"size": 6},
        name=sp,
        visible=False,
    )
    fig.add_trace(trace)

# Show the first species until the user picks another one.
fig.data[0].visible = True

# Button i switches on trace i only and updates the title accordingly.
buttons = []
for i, sp in enumerate(species_list):
    visible = [j == i for j in range(len(species_list))]
    buttons.append(
        {
            "label": sp,
            "method": "update",
            "args": [{"visible": visible}, {"title": f"Locations of {sp}"}],
        }
    )

dropdown = {"active": 0, "buttons": buttons, "x": 1.05, "y": 1.15}
fig.update_layout(
    updatemenus=[dropdown],
    title="Species Distribution Map",
    geo={"scope": "world"},
)

fig.show()


## 📂 Load & Visualize Validation Data 🌍🐛🕷️

In [None]:
# Load the validation-split metadata.
# JSON text is UTF-8, so the encoding is pinned explicitly instead of relying
# on the platform default codec (which is locale-dependent on Windows).
with open(r"C:\Users\thaku\jupyter_notebook _datasets\Wildlife_dataset\Dataset\val.json", 'r', encoding="utf-8") as f:
    json_val_data = json.load(f)

In [None]:
# Category table used for the validation merge.
# NOTE(review): this reads the categories from json_train_data, not
# json_val_data. That keeps the class subset identical to training, but
# confirm it is intentional rather than a copy-paste slip.
val_categories_df = pd.DataFrame(json_train_data["categories"])

In [None]:
val_categories_df

In [None]:
# Keep only the first 301 category rows (all columns), as done for training.
val_categories_df = val_categories_df.iloc[:301,:]

In [None]:
# Tabulate the "images" and "annotations" entries of the validation JSON.
val_images_df = pd.DataFrame(json_val_data["images"])
val_annotations_df = pd.DataFrame(json_val_data["annotations"])


In [None]:
# Join annotations to categories (default inner join), suffixing overlapping
# column names with "_cat" / "_ann".
val_category_annotation_df = val_categories_df.merge(val_annotations_df,left_on="id",right_on="category_id",suffixes=["_cat","_ann"])

In [None]:
# Attach the image record to every annotation (annotation "image_id" == image "id").
val_final_df =val_category_annotation_df.merge(val_images_df,left_on="image_id",right_on="id",suffixes=("_cat_ann","_img"))

In [None]:
# Drop the join keys and the metadata columns that are not used afterwards.
val_final_df = val_final_df.drop(columns=["id_cat","date","rights_holder","license","id_ann","id"])

In [None]:
val_final_df

In [None]:
# Compact validation table: image path, class id and GPS coordinates.
new_df_val =  val_final_df[["file_name","category_id","latitude","longitude"]]

In [None]:
new_df_val

In [None]:
# Rename "file_name" -> "img_path" and "category_id" -> "class".
new_df_val=new_df_val.rename(columns={"file_name": "img_path", "category_id": "class"})
new_df_val

In [None]:
# Write both compact tables to CSV files in the working directory, without the index column.
new_df_val.to_csv("new_df_val.csv", index=False)
new_df_train.to_csv("new_df_train.csv",index=False)

In [None]:
import plotly.graph_objects as go

# Same interactive map as for the training records, built from the validation
# table: one hidden marker trace per species plus a dropdown selector.
species_list = val_final_df["common_name"].unique()

fig = go.Figure()

for sp in species_list:
    # Records belonging to the current species only.
    subset = val_final_df[val_final_df["common_name"] == sp]
    fig.add_trace(
        go.Scattergeo(
            name=sp,
            mode="markers",
            marker={"size": 6},
            lat=subset["latitude"],
            lon=subset["longitude"],
            text=subset["common_name"],
            visible=False,
        )
    )

# The first species is shown by default.
fig.data[0].visible = True

# One dropdown entry per species; selecting it reveals only that species' trace.
buttons = []
for i, sp in enumerate(species_list):
    visible = [False for _unused in species_list]
    visible[i] = True
    button = {
        "label": sp,
        "method": "update",
        "args": [{"visible": visible}, {"title": f"Locations of {sp}"}],
    }
    buttons.append(button)

fig.update_layout(
    title="Species Distribution Map",
    geo={"scope": "world"},
    updatemenus=[{"active": 0, "buttons": buttons, "x": 1.05, "y": 1.15}],
)

fig.show()


## 🎨 Probabilistic Background Blur Data Augmentation 🖼️


In [None]:
class ProbBackgroundBlur:
    """Randomly blur the estimated background of an image.

    With probability ``prob`` the image is Gaussian-blurred everywhere except
    in the region selected by a foreground mask. The mask comes from OpenCV's
    spectral-residual saliency; if that step fails for any reason, a simple
    intensity threshold on the grayscale image is used instead.

    Args:
        prob: Probability of applying the augmentation to a given image.
        min_kernel: Smallest Gaussian kernel size to sample.
        max_kernel: Largest Gaussian kernel size to sample.

    Calling the instance returns the input object unchanged when the
    augmentation is skipped, and a new ``PIL.Image`` otherwise.
    """

    def __init__(self,prob=0.3,min_kernel=7,max_kernel=31):
        self.prob=prob
        self.min_kernel=min_kernel
        self.max_kernel=max_kernel

    def __call__(self,img):
        if random.random()>self.prob:
            return img
        if isinstance(img,Image.Image):
            # The mask is broadcast to three channels below, so force RGB
            # (grayscale / RGBA / palette inputs would otherwise fail).
            img=np.array(img.convert("RGB"))

        # cv2.GaussianBlur only accepts odd kernel sizes: start from the first
        # odd value >= min_kernel and step by 2. Odd min_kernel values (such
        # as the default) produce exactly the same candidates as before.
        first_kernel = self.min_kernel | 1
        last_kernel = max(self.max_kernel, first_kernel)
        kernel_size = random.choice(range(first_kernel, last_kernel + 1, 2))
        blurred = cv2.GaussianBlur(img,(kernel_size,kernel_size),0)
        try:
            saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
            _success, saliency_map = saliency.computeSaliency(img)
            threshold = np.mean(saliency_map)
            mask = (saliency_map > threshold).astype(np.uint8)
            # If more than half the pixels were selected, flip the mask so the
            # kept (sharp) region is the smaller one.
            if np.mean(mask) > 0.5:
                mask = 1 - mask
        except Exception:
            # Best-effort fallback (e.g. the saliency module is unavailable in
            # this OpenCV build). `Exception` rather than a bare `except` so
            # KeyboardInterrupt / SystemExit are not swallowed.
            gray = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
            _, mask = cv2.threshold(gray,120,1,cv2.THRESH_BINARY_INV)
        mask_3d = np.repeat(mask[:,:,np.newaxis],3,axis=2)
        # Keep original pixels where the mask is 1, blurred pixels elsewhere.
        result = np.where(mask_3d==1,img,blurred)
        return Image.fromarray(result.astype(np.uint8))

## 🔀 Mixup & CutMix Data Augmentation for Images 🖼️


In [None]:
class MixupCutmix:
    """Batch-level MixUp / CutMix augmentation that returns soft targets.

    Args:
        mixup_alpha: Beta(alpha, alpha) parameter used to draw the MixUp ratio.
        cutmix_alpha: Beta(alpha, alpha) parameter used to draw the CutMix ratio.
        prob: Probability of augmenting a batch at all.
        switch_prob: Given that a batch is augmented, probability of using
            CutMix instead of MixUp.
        num_classes: Width of the one-hot target vectors.

    Calling the instance with ``(x, y)`` -- a 4-D image batch and a 1-D tensor
    of integer class indices -- returns ``(images, targets)`` where ``targets``
    is always a float tensor with ``num_classes`` columns. ``self.applied``
    records whether the last call actually mixed the batch. The input batch
    ``x`` is never modified in place.
    """

    def __init__(self,mixup_alpha=0.2,cutmix_alpha=1.2,prob=0.8,switch_prob=0.4,num_classes=301):
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.prob = prob
        self.switch_prob = switch_prob
        self.num_classes = num_classes
        self.applied = False

    def _one_hot(self,label):
        """Return float one-hot vectors for integer class indices."""
        return torch.nn.functional.one_hot(label,num_classes=self.num_classes).float()

    def _sample_beta(self,alpha):
        """Draw a mixing ratio from Beta(alpha, alpha); 1.0 when alpha <= 0."""
        # Cast to a plain float so tensor arithmetic never sees a NumPy scalar.
        return float(np.random.beta(alpha,alpha)) if alpha>0 else 1.0

    def __call__(self, x,y):
        self.applied = False
        if np.random.random()>self.prob:
            return x,self._one_hot(y)
        self.applied = True
        B,C,H,W = x.size()
        # Draw the permutation on CPU (same RNG stream as before), then move it
        # next to the data so indexing does not mix devices.
        shuffled_idx = torch.randperm(B).to(x.device)
        targets = self._one_hot(y)
        y_shuffled = targets[shuffled_idx]
        if np.random.rand()<self.switch_prob:
            # CutMix: paste a random box taken from the shuffled batch.
            lam = self._sample_beta(self.cutmix_alpha)
            rx = np.random.randint(W)
            ry = np.random.randint(H)
            rw = max(int(W*np.sqrt(1-lam)),1)
            rh = max(int(H*np.sqrt(1-lam)),1)
            x1 = int(np.clip(rx - rw // 2, 0, W))
            x2 = int(np.clip(rx + rw // 2, 0, W))
            y1 = int(np.clip(ry - rh // 2, 0, H))
            y2 = int(np.clip(ry + rh // 2, 0, H))
            # Work on a copy: the caller's tensor must not be overwritten.
            x = x.clone()
            x[:, :, y1:y2, x1:x2] = x[shuffled_idx, :, y1:y2, x1:x2]
            # Re-derive lambda from the box that was actually pasted (clipping
            # at the image border can shrink it).
            lam = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
        else:
            # MixUp: convex combination of each image with its shuffled partner.
            lam = self._sample_beta(self.mixup_alpha)
            x = lam * x + (1 - lam) * x[shuffled_idx]
        mixed_y = lam * targets + (1 - lam) * y_shuffled
        return x,mixed_y


## 🛠️ Define Training and Validation Transform Pipelines ✨


In [None]:
def get_transform_pipeline(blur_prob=0.2):
    """Build the training-time image pipeline.

    Steps, in order: resize to 224x224, random horizontal flip, random
    vertical flip, probabilistic background blur, tensor conversion and
    per-channel normalisation.

    Args:
        blur_prob: Probability handed to ``ProbBackgroundBlur``.

    Returns:
        A ``transforms.Compose`` applying the steps above.
    """
    # Per-channel statistics (the values commonly used with ImageNet-pretrained models).
    channel_mean = [0.485,0.456,0.406]
    channel_std = [0.229,0.224,0.225]

    steps = [
        transforms.Resize((224,224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        ProbBackgroundBlur(prob=blur_prob,min_kernel=7,max_kernel=31),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

def get_val_transform():
    """Build the deterministic evaluation pipeline: resize, to-tensor, normalise.

    Returns:
        A ``transforms.Compose`` with no random augmentation.
    """
    resize = transforms.Resize((224,224))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225],
    )
    return transforms.Compose([resize, to_tensor, normalize])

## 📦 Prepare WebDataset & DataLoaders for Training and Validation 🖼️


In [None]:
train_transform = get_transform_pipeline(blur_prob=0.2)
val_transform = get_val_transform()

def has_all_fields(sample):
    """Return True when the sample carries the image, class and both coordinates."""
    return all(k in sample for k in ["jpg", "cls", "lat", "lon"])

train_shards = "file:C:/Users/thaku/jupyter_notebook _datasets/Wildlife_dataset/shards_train_mini_300/shard-{00000..00003}.tar"
val_shards   = "file:C:/Users/thaku/jupyter_notebook _datasets/Wildlife_dataset/shards_val_mini_300/shard-00000.tar"

def decode_lat(x):
    """Convert a stored latitude value to a float32 scalar tensor."""
    return torch.tensor(float(x), dtype=torch.float32)

def decode_lon(x):
    """Convert a stored longitude value to a float32 scalar tensor."""
    return torch.tensor(float(x), dtype=torch.float32)

# Training pipeline. The shuffle buffer is placed BEFORE decoding so it holds
# 1000 raw (still-encoded) samples instead of 1000 decoded and augmented float
# tensors: far less memory and a much shorter wait for the first batch, with
# the same sample-level shuffling.
train_dataset = (
    wds.WebDataset(train_shards, handler=wds.warn_and_continue)
    .shuffle(1000)
    .decode("pil")                                  # decode jpg -> PIL
    .select(has_all_fields)                         # keep only valid samples
    .to_tuple("jpg", "cls", "lat", "lon")           # load all four
    .map_tuple(train_transform, int, decode_lat, decode_lon)  # apply transforms + type conversions
)

# Validation pipeline: same decoding, no shuffling and no random augmentation.
val_dataset = (
    wds.WebDataset(val_shards, handler=wds.warn_and_continue)
    .decode("pil")                                  # decode jpg -> PIL
    .select(has_all_fields)                         # keep only valid samples
    .to_tuple("jpg", "cls", "lat", "lon")           # load all four
    .map_tuple(val_transform, int, decode_lat, decode_lon)  # apply transforms + type conversions
)

# -------------------------------
# 🔹 DataLoaders
# -------------------------------
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, num_workers=0
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=128, num_workers=0, shuffle=False
)

## ⏱️ Benchmark DataLoader Performance ⚡


In [None]:
import time

# Warm-up: start iterating the training loader twice, pulling a single batch
# each time and discarding it.
for _ in range(2):
    next(iter(train_dataloader), None)

# Timed run: measure how long it takes to obtain the first batch from a fresh
# iterator (this includes whatever start-up work the pipeline performs).
start = time.time()
first_batch = next(iter(enumerate(train_dataloader)), None)
if first_batch is not None:
    i, (images, labels, lats, lons) = first_batch
    print(f"Batch {i} -> {images.shape}, {labels.shape}")
end = time.time()

print(f"Time to load 1 batch = {end - start:.3f} seconds")


## 🖼️🌍 Multi-Modal EfficientNet Model with Image and Location Features 🔗



In [None]:
class MultiModalEfficientNet(nn.Module):
    """Image + location classifier.

    A timm backbone (created with ``num_classes=0``, i.e. without its own
    classification head) produces an image feature vector; a small MLP embeds
    the ``(lat, lon)`` pair; both are concatenated and fed to one linear layer.

    Args:
        num_classes: Number of output logits.
        loc_feat_dim: Width of the location embedding.
        backbone: timm model name.
        pretrained: Load pretrained backbone weights (default ``True``). Set
            to ``False`` when the weights will come from a checkpoint anyway.
        freeze_backbone: Disable gradients for every backbone parameter
            (default ``True``).
        dropout: Dropout probability used on the image features and inside
            the location MLP (default ``0.1``).
    """

    def __init__(self,num_classes,loc_feat_dim=64,backbone="tf_efficientnetv2_s",
                 *,pretrained=True,freeze_backbone=True,dropout=0.1):
        super().__init__()
        self.backbone = timm.create_model(backbone,pretrained=pretrained,num_classes=0)
        feat_dim = self.backbone.num_features
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad=False

        # Plain float attribute (not a parameter/buffer), so state_dict keys
        # are unchanged and existing checkpoints keep loading.
        self.img_dropout_p = dropout

        self.loc_mlp=nn.Sequential(
            nn.Linear(2,loc_feat_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(loc_feat_dim,loc_feat_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        self.classifier = nn.Linear(feat_dim+loc_feat_dim,num_classes)

    def forward(self,img,lat,lon):
        """Return class logits for an image batch and per-sample coordinates.

        ``lat`` and ``lon`` are stacked along dim 1, so each is expected to be
        a 1-D tensor with one value per image.
        """
        img_feat = self.backbone(img)
        img_feat = nn.functional.dropout(img_feat,p=self.img_dropout_p,training=self.training)
        loc_input = torch.stack([lat,lon],dim=1)
        loc_feat = self.loc_mlp(loc_input)

        fused = torch.cat([img_feat,loc_feat],dim=1)
        return self.classifier(fused)

    def unfreeze_last_block(self):
        """Re-enable gradients for the last entry of ``backbone.blocks``.

        Raises:
            AttributeError: If the backbone does not expose a ``blocks`` attribute.
        """
        blocks = getattr(self.backbone, "blocks", None)
        if blocks is None:
            raise AttributeError(
                f"backbone {type(self.backbone).__name__} has no 'blocks' attribute; "
                "cannot unfreeze its last block"
            )
        for param in blocks[-1].parameters():
            param.requires_grad=True

## 🚀 Multi-Modal Model Training with Mixed Precision, MixUp/CutMix, and Early Stopping 🖼️🌍


In [None]:
# End-to-end training script for the multimodal model: optional warm start
# from a checkpoint, mixed-precision training with MixUp/CutMix soft targets,
# per-epoch validation with macro metrics, CSV logging, best-checkpoint saving,
# early stopping and a late unfreeze of the last backbone block.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Float16 autocast and gradient scaling are only meaningful on CUDA; on CPU
# both are switched off explicitly instead of relying on PyTorch disabling
# them with a warning.
use_amp = device.type == "cuda"

# ---------------- Model ----------------
model = MultiModalEfficientNet(num_classes=301, backbone="tf_efficientnetv2_s").to(device)

# ---------------- Checkpoint ----------------
# NOTE(review): the checkpoint is READ from this path but new checkpoints are
# WRITTEN to "best_checkpoint_multimodal_3.pth" below, so a run never resumes
# from its own output. Presumably a warm start from an earlier run -- confirm.
checkpoint_path = "best_checkpoint_multimodal.pth"

start_epoch = 0
best_val_loss = float("inf")
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'],strict=False)
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    print(f"✅ Loaded checkpoint from epoch {checkpoint['epoch']}, val_loss: {best_val_loss:.4f}")

# Start every run with a fresh log file.
csv_file = "training_log_3_multimodal.csv"
if os.path.exists(csv_file):
    os.remove(csv_file)

unfrozen = False
# 1-based epoch after which the last backbone block is unfrozen. The loop
# previously hard-coded 94 while this variable (set to 5) was never read;
# the effective value, 94, is kept.
unfreeze_epoch = 94

# Loss & optimizer (only parameters that currently require gradients)
criterion = SoftTargetCrossEntropy()
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-5, weight_decay=5e-5
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=3
)

# AMP scaler (a no-op when use_amp is False)
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# Early stopping
best_model_wts = copy.deepcopy(model.state_dict())
patience = 80
counter = 0
num_epochs = 300

# Dict for logging
epoc_data = defaultdict(list)
mixcut = MixupCutmix(num_classes=301)
try:
    for epoch in range(start_epoch,num_epochs):
        # ---------------- Training ----------------
        model.train()
        running_loss, correct_top1, correct_top5, total, num_batches = 0.0, 0, 0, 0, 0

        for batch_idx, (images, labels, lats, lons) in enumerate(train_dataloader):   # MULTIMODAL BATCH

            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            # mixcut always returns soft (one-hot or mixed) targets.
            images, labels = mixcut(images, labels)
            lats, lons = lats.to(device, non_blocking=True), lons.to(device, non_blocking=True)

            optimizer.zero_grad()

            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                outputs = model(images, lats, lons)           # MULTIMODAL FORWARD
                if labels.dim() == 1:
                    labels = torch.nn.functional.one_hot(labels, num_classes=outputs.size(1)).float()
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # ---------------- Debug logging ----------------
            # Every batch while epoch == 0, otherwise every 10th batch.
            if batch_idx % 10 == 0 or epoch == 0:
                print(f"\n🔎 Epoch {epoch+1}, Batch {batch_idx+1}")
                print("MixUp/CutMix applied:", mixcut.applied)
                sample_logits = outputs[0].detach().cpu()
                sample_probs = nn.functional.softmax(sample_logits, dim=0)
                top5_prob, top5_cls = sample_probs.topk(5)
                print("Top-5 predicted classes:", top5_cls.tolist())
                print("Top-5 probabilities:", top5_prob.tolist())
                if labels.dim() > 1:
                    print("Sample soft labels (first 10 classes):", labels[0][:10].tolist())
                else:
                    print("Sample hard label:", labels[0].item())

            # Metrics
            running_loss += loss.item()
            # Convert soft targets to hard labels for accuracy (for mixed
            # samples this is the class with the larger weight).
            if labels.dim() > 1:
                true_labels = labels.argmax(dim=1)
            else:
                true_labels = labels

            # Top-1
            _, preds = outputs.max(1)
            correct_top1 += (preds == true_labels).sum().item()

            # Top-5
            _, top5_preds = outputs.topk(5, dim=1)
            correct_top5 += (top5_preds == true_labels.view(-1, 1)).sum().item()

            total += true_labels.size(0)
            num_batches += 1

        train_loss = running_loss / num_batches
        train_acc = 100 * correct_top1 / total
        train_top5 = 100 * correct_top5 / total

        # ---------------- Validation ----------------
        model.eval()
        val_loss, correct_top1, correct_top5, total, num_batches = 0.0, 0, 0, 0, 0
        all_preds = []
        all_true = []

        with torch.inference_mode():
            for images, labels, lats, lons in val_dataloader:   # MULTIMODAL BATCH
                images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
                lats, lons = lats.to(device, non_blocking=True), lons.to(device, non_blocking=True)

                with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                    outputs = model(images, lats, lons)        # MULTIMODAL FORWARD
                    # The criterion expects soft targets, so one-hot encode the hard labels.
                    if labels.dim() == 1:
                        labels = torch.nn.functional.one_hot(labels, num_classes=outputs.size(1)).float()
                    loss = criterion(outputs, labels)

                val_loss += loss.item()
                if labels.dim() > 1:
                    true_labels = labels.argmax(dim=1)
                else:
                    true_labels = labels

                # Top-1
                _, preds = outputs.max(1)
                all_preds.extend(preds.cpu().numpy())
                all_true.extend(true_labels.cpu().numpy())
                correct_top1 += (preds == true_labels).sum().item()

                # Top-5
                _, top5_preds = outputs.topk(5, dim=1)
                correct_top5 += (top5_preds == true_labels.view(-1, 1)).sum().item()

                total += true_labels.size(0)
                num_batches += 1

        val_loss /= num_batches
        val_acc = 100 * correct_top1 / total
        val_top5 = 100 * correct_top5 / total

        # ---------------- Scheduler ----------------
        scheduler.step(val_loss)
        for param_group in optimizer.param_groups:
            print(f"Current LR: {param_group['lr']}")

        # ---------------- Logging ----------------
        f1 = f1_score(all_true, all_preds, average='macro',zero_division=0)
        precision = precision_score(all_true, all_preds, average='macro', zero_division=0)
        recall = recall_score(all_true, all_preds, average='macro', zero_division=0)

        # Balanced accuracy
        balanced_acc = balanced_accuracy_score(all_true, all_preds)

        print(f"Val F1-score: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, Balanced Acc: {balanced_acc:.4f}")

        print(f"Epoch [{epoch+1}/{num_epochs}] "
            f"Train Loss: {train_loss:.4f}, Top-1: {train_acc:.2f}%, Top-5: {train_top5:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Top-1: {val_acc:.2f}%, Top-5: {val_top5:.2f}%")

        epoc_data["Epoch"].append(epoch+1)
        epoc_data["Train Loss"].append(train_loss)
        epoc_data["Val Loss"].append(val_loss)
        epoc_data["Train Top-1 (%)"].append(train_acc)
        epoc_data["Val Top-1 (%)"].append(val_acc)
        epoc_data["Train Top-5 (%)"].append(train_top5)
        epoc_data["Val Top-5 (%)"].append(val_top5)
        epoc_data["Val F1-score"].append(f1)
        epoc_data["Val Precision"].append(precision)
        epoc_data["Val Recall"].append(recall)
        epoc_data["Val Balanced Acc"].append(balanced_acc)
        # Rewrite the full log every epoch so it survives an interrupted run.
        pd.DataFrame(epoc_data).to_csv(csv_file, index=False)

        # ---------------- Early stopping ----------------
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "best_val_loss": best_val_loss
            }, "best_checkpoint_multimodal_3.pth")
            print(f"✅ Model improved & saved at epoch {epoch+1}")
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("⏹️ Early stopping triggered")
                break

        # ---------------- Late unfreeze ----------------
        if not unfrozen and (epoch + 1) >= unfreeze_epoch:
            print(f"🔓 Unfreezing last MBConv block at epoch {epoch+1}...")
            model.unfreeze_last_block()
            optimizer = torch.optim.AdamW(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=1e-6, weight_decay=5e-5  # smaller LR for fine-tuning CNN
            )
            # The previous scheduler is bound to the optimizer that was just
            # discarded, so its LR reductions would no longer reach the
            # parameters being trained. Rebuild it on the new optimizer.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=0.1, patience=3
            )
            unfrozen = True
except KeyboardInterrupt:
    print("\n⏹️ Training interrupted by user!")
finally:
    # Always leave the model holding the best weights seen and export them,
    # whether training finished, stopped early, was interrupted or failed.
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), "best_model_only_multimodal_3.pth")