In [None]:
import torch
from torch import nn, optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchmetrics import Precision, Recall, AveragePrecision
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Training hyperparameters
BATCH_SIZE = 32
LR = 1e-3
EPOCHS = 15

# Output size of the classifier: the number of classes ImageFolder finds
# under the training directory.
NUM_CLASSES = len(ImageFolder("cropped_animals/train").classes)

# Preprocessing pipelines
# Per-channel mean / std handed to Normalize in both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: random crop, flip and colour jitter before tensor conversion.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation: deterministic resize + centre crop, no augmentation.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
val_transform = transforms.Compose(_val_steps)

# Datasets: each split directory is read with its own transform pipeline.
train_dataset = ImageFolder("cropped_animals/train", transform=train_transform)
val_dataset = ImageFolder("cropped_animals/val", transform=val_transform)

# Class weights for imbalance: inverse class frequency scaled by
# (number of classes / 2).
_train_labels = torch.tensor(train_dataset.targets)
# minlength keeps one entry per class even if the highest-indexed classes
# have no samples, so the weight vector always matches the class count.
class_counts = torch.bincount(_train_labels, minlength=len(train_dataset.classes))
# clamp(min=1) avoids a division by zero (inf weight) for an empty class.
class_weights = (1. / class_counts.clamp(min=1).float()) * len(class_counts) / 2.0
# One sampling weight per training image, looked up from its class.
weights = class_weights[_train_labels]
sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights))

# Data loaders. These were missing although the training loop iterates over
# `train_loader` and `val_loader`. The weighted sampler drives the training
# order (a sampler and shuffle=True are mutually exclusive); validation is
# read sequentially.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Metrics: multiclass torchmetrics objects, moved to the GPU.
_multiclass = {"task": "multiclass", "num_classes": NUM_CLASSES}
precision = Precision(average="macro", **_multiclass).cuda()
recall = Recall(average="macro", **_multiclass).cuda()
map_metric = AveragePrecision(**_multiclass).cuda()

def log_metrics(prefix, preds, targets):
    """Print precision, recall and mAP for one set of predictions.

    Args:
        prefix: Label printed in the header line (e.g. "Train").
        preds: Either predicted class indices with the same shape as
            ``targets``, or floating-point per-class scores with one extra
            (class) dimension.
        targets: Ground-truth class indices.

    mAP is computed from per-class scores, so it is only reported when
    ``preds`` carries scores; for plain class indices it is printed as n/a
    instead of raising inside the metric.
    """
    print(f"\n{prefix} Metrics:")
    print(f"• Precision: {precision(preds, targets):.4f}")
    print(f"• Recall: {recall(preds, targets):.4f}")
    if preds.is_floating_point() and preds.ndim == targets.ndim + 1:
        print(f"• mAP: {map_metric(preds, targets):.4f}")
    else:
        print("• mAP: n/a (requires per-class scores, got class indices)")
    # Calling a metric also adds the batch to its running state. Clear it so
    # train/validation calls stay independent and stored predictions do not
    # pile up across epochs.
    for metric in (precision, recall, map_metric):
        metric.reset()

# Model: pretrained ResNet-50 whose final layer is replaced by a small MLP head.
# `pretrained=True` is deprecated in torchvision; the explicit weights enum
# below selects the same checkpoint the legacy flag resolved to.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, NUM_CLASSES),  # one logit per class
)
model = model.cuda()

# Loss, optimizer and LR schedule
# NOTE(review): the training data is also drawn through a
# WeightedRandomSampler built from the same class weights, so imbalance is
# compensated both in sampling and in the loss — confirm this is intended.
criterion = nn.CrossEntropyLoss(
    weight=class_weights.cuda(),
)
optimizer = optim.AdamW(
    model.parameters(),
    lr=LR,
)
# mode="max": the monitored quantity is expected to increase; the LR is
# reduced once it has stopped improving for more than `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    patience=2,
)

# Training loop
# Each epoch: one optimisation pass over train_loader, one evaluation pass
# over val_loader, metric logging, then LR scheduling on validation accuracy.
for epoch in range(EPOCHS):
    # ---- Train ----
    model.train()
    train_loss = 0.0
    train_probs, train_targets = [], []

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Keep detached per-class probabilities rather than argmax indices:
        # the mAP metric needs scores, and precision/recall accept them too.
        train_probs.append(outputs.detach().softmax(dim=1))
        train_targets.append(labels)

    # Train metrics over the whole epoch
    log_metrics("Train", torch.cat(train_probs), torch.cat(train_targets))

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    val_probs, val_targets = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            val_probs.append(outputs.softmax(dim=1))
            val_targets.append(labels)

    # Val metrics
    val_probs = torch.cat(val_probs)
    val_targets = torch.cat(val_targets)
    # Hard predictions stay available as `val_preds`: the confusion-matrix
    # code after the loop reads this name together with `val_targets`.
    val_preds = val_probs.argmax(dim=1)
    log_metrics("Validation", val_probs, val_targets)

    # LR scheduling on validation accuracy (plain Python float)
    val_acc = (val_preds == val_targets).float().mean().item()
    scheduler.step(val_acc)

    # Epoch summary (losses are averaged per batch)
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"• Train Loss: {train_loss/len(train_loader):.4f}")
    print(f"• Val Loss: {val_loss/len(val_loader):.4f}")
    print(f"• Val Accuracy: {val_acc:.4f}")
    print(f"• LR: {optimizer.param_groups[0]['lr']:.2e}\n")

# Confusion matrix for the validation predictions of the last epoch
from sklearn.metrics import confusion_matrix
import seaborn as sns

class_names = train_dataset.classes
# Pin the label set so the matrix always has one row/column per class.
# Without `labels`, a class that appears in neither targets nor predictions
# is dropped and the tick labels no longer line up with the cells.
cm = confusion_matrix(
    val_targets.cpu().numpy(),
    val_preds.cpu().numpy(),
    labels=list(range(len(class_names))),
)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
# confusion_matrix puts true classes on rows and predicted classes on columns.
plt.xlabel("Predicted")
plt.ylabel("True")
plt.savefig("confusion_matrix.png")