In [2]:
from google.colab import drive

# Attach Google Drive to the Colab filesystem so later cells can read files stored on it.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import torch
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import transforms, models


class CUB200Dataset(Dataset):
    """Image-classification dataset driven by a ``<split>.txt`` index file.

    The index file ``<root_dir>/<split>.txt`` is parsed as space-separated
    ``filename label`` rows; each image is then looked up at
    ``<root_dir>/<split>/<filename>``.

    Args:
        root_dir: Directory holding the split index files and image sub-folders.
        split: Split name; selects both the index file and the image sub-folder.
        transform: Callable applied to each loaded PIL image. When ``None``,
            ``default_transform()`` is used.
        apply_bg_removal: Stored on the instance; no method of this class reads it.
    """

    def __init__(self, root_dir, split='train', transform=None, apply_bg_removal=False):
        self.root_dir = root_dir
        self.split = split
        # Explicit None check so a user-supplied transform is never silently replaced.
        self.transform = transform if transform is not None else self.default_transform()
        self.apply_bg_removal = apply_bg_removal

        # Load metadata
        self.data = self.load_metadata()

    def load_metadata(self):
        """Read the split index file and return a DataFrame with filename, label and filepath."""
        split_file = os.path.join(self.root_dir, f'{self.split}.txt')
        data = pd.read_csv(split_file, sep=' ', names=['filename', 'label'])
        data['filepath'] = data['filename'].apply(lambda x: os.path.join(self.root_dir, self.split, x))
        return data

    def __len__(self):
        """Number of samples listed in the split index file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the sample at position ``idx``.

        Raises:
            RuntimeError: If the image file could not be loaded.
        """
        # Positional access: samplers hand out positions, not index labels.
        img_name = self.data['filepath'].iloc[idx]
        image = self.load_image(img_name)
        if image is None:
            # Fail with a clear message instead of an obscure error inside the transform.
            raise RuntimeError(f"Could not load image for sample {idx}: {img_name}")

        # Apply transformations to convert the image to a tensor
        image = self.transform(image)

        label = self.data['label'].iloc[idx]
        return image, label

    @staticmethod
    def load_image(image_path):
        """Open ``image_path`` and return it as an RGB PIL image, or ``None`` on failure.

        A failure is reported via ``print``; the caller decides how to react.
        """
        try:
            # The context manager closes the underlying file; convert() returns
            # an independent, fully loaded image.
            with Image.open(image_path) as img:
                return img.convert('RGB')
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None

    def default_transform(self):
        """Fallback pipeline: resize to 224x224, convert to tensor, normalise per channel."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])



In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import transforms, models
import torch.optim.lr_scheduler as lr_scheduler

# Preprocessing pipelines. Both resize to 224x224, convert to a tensor and apply the
# same per-channel normalisation; only the training pipeline adds a random horizontal flip.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _build_transforms(augment):
    """Assemble the preprocessing pipeline, optionally with the training-time flip."""
    steps = [transforms.Resize((224, 224))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)))
    return transforms.Compose(steps)


train_transforms = _build_transforms(augment=True)
test_transforms = _build_transforms(augment=False)



In [13]:
# Dataset root on the mounted Drive. The path is relative, so it is resolved against
# the notebook's current working directory.
DATA_ROOT = 'drive/MyDrive/COS30082_preprocessed'

train_dataset = CUB200Dataset(root_dir=DATA_ROOT, split='train', transform=train_transforms)
test_dataset = CUB200Dataset(root_dir=DATA_ROOT, split='test', transform=test_transforms)

# Settings shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=16, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

class BirdClassifier(nn.Module):
    """Pretrained torchvision backbone whose final layer is replaced by Dropout + Linear.

    Args:
        model_name: One of 'resnet50', 'resnext50_32x4d', 'efficientnet_b0',
            'efficientnet_b5'.
        num_classes: Output size of the new linear layer.
        dropout: Dropout probability of the new head. Defaults to 0.5, the value
            that used to be hard-coded.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    # model_name -> (constructor name, weights enum name, attribute holding the head),
    # all looked up on torchvision's `models` lazily so only the chosen entry is touched.
    _BACKBONES = {
        'resnet50': ('resnet50', 'ResNet50_Weights', 'fc'),
        'resnext50_32x4d': ('resnext50_32x4d', 'ResNeXt50_32X4D_Weights', 'fc'),
        'efficientnet_b0': ('efficientnet_b0', 'EfficientNet_B0_Weights', 'classifier'),
        'efficientnet_b5': ('efficientnet_b5', 'EfficientNet_B5_Weights', 'classifier'),
    }

    def __init__(self, model_name='efficientnet_b5', num_classes=200, dropout=0.5):
        super(BirdClassifier, self).__init__()

        # Choose the model based on the input argument 'model_name'
        if model_name not in self._BACKBONES:
            raise ValueError(f"Unsupported model_name: {model_name}. Choose from 'resnet50', 'resnext50_32x4d', 'efficientnet_b0', 'efficientnet_b5'.")

        builder_name, weights_name, head_attr = self._BACKBONES[model_name]
        weights = getattr(models, weights_name).DEFAULT
        self.model = getattr(models, builder_name)(weights=weights)

        # The input width is read from `fc` directly, or from element 1 of `classifier`.
        old_head = getattr(self.model, head_attr)
        in_features = old_head.in_features if head_attr == 'fc' else old_head[1].in_features

        # Swap in the new head under the same attribute name.
        new_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, num_classes)
        )
        setattr(self.model, head_attr, new_head)

    def forward(self, x):
        """Forward ``x`` through the wrapped backbone and return its output."""
        return self.model(x)
    

# 5. Setup Model, Criterion, Optimizer, and Scheduler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# To use EfficientNet B5
# The model must live on `device`: train() and evaluate() move every batch there,
# and a CPU model fed CUDA tensors fails with a device-mismatch error.
# It is moved before the optimizer is built from its parameters.
model = BirdClassifier(model_name='efficientnet_b5', num_classes=200).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Multiply the learning rate by 0.1 after every 5 calls to scheduler.step().
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)




Downloading: "https://download.pytorch.org/models/efficientnet_b5_lukemelas-1a07897c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b5_lukemelas-1a07897c.pth
100%|██████████| 117M/117M [00:00<00:00, 204MB/s] 


In [None]:
class EarlyStopping:
    """Signal that training should stop once validation loss stops improving.

    Call the instance once per epoch with the validation loss, then read
    ``early_stop``.

    Args:
        patience: Number of consecutive non-improving calls after which
            ``early_stop`` becomes True.
        delta: The loss must drop to ``best_loss - delta`` or lower to count as
            an improvement.
    """

    def __init__(self, patience=5, delta=0):
        self.patience = patience  # How many epochs to wait after last improvement
        self.delta = delta  # Minimum change to qualify as an improvement
        self.best_loss = None  # Best validation loss seen so far
        self.counter = 0  # Tracks how long since the last improvement
        self.early_stop = False  # Flag to indicate whether training should stop

    def __call__(self, val_loss):
        """Record one validation loss and update ``counter`` / ``early_stop``.

        A NaN loss always counts as "no improvement". Every comparison against
        NaN is False, so without this check a NaN would reset the counter, become
        ``best_loss``, and make every later loss look like an improvement.
        """
        import math

        loss_is_nan = math.isnan(val_loss)

        if self.best_loss is None and not loss_is_nan:
            # First usable loss becomes the baseline.
            self.best_loss = val_loss
            return

        # `loss_is_nan` is tested first, so `best_loss` is never None in the arithmetic.
        if loss_is_nan or val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0  # Reset counter if validation loss improves


In [None]:
# 6. Training and Evaluation Functions
def train(model, loader, criterion, optimizer, device):
    """Run one training epoch over ``loader``.

    Args:
        model: Module to train; put into training mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` over the samples actually drawn
        from ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total_images = 0

    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a smaller last batch is averaged correctly
        # (assumes criterion returns the batch mean - TODO confirm if reduction changes).
        running_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_images += batch_size

    if total_images == 0:
        raise ValueError("train() received an empty loader; there are no samples to train on.")

    # Normalise by the samples actually seen, not len(loader.dataset): the two
    # differ when the loader drops the last batch or samples a subset.
    train_accuracy = (correct / total_images) * 100
    return running_loss / total_images, train_accuracy

def evaluate(model, loader, criterion, device, num_classes=200):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Module to evaluate; put into eval mode here.
        loader: Iterable yielding ``(images, labels)`` batches. Labels are used
            as class indices in ``[0, num_classes)``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.
        num_classes: Number of classes tracked for per-class accuracy.

    Returns:
        Tuple ``(mean_loss, overall_accuracy_percent, mean_per_class_accuracy_percent)``.
        The per-class mean covers only classes with at least one sample in ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct_top1 = 0
    total_images = 0

    # Track per-class accuracy
    class_correct = torch.zeros(num_classes, dtype=torch.long, device=device)
    class_total = torch.zeros(num_classes, dtype=torch.long, device=device)

    with torch.no_grad():
        for images, labels in tqdm(loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size

            predicted_top1 = outputs.argmax(dim=1)
            hits = predicted_top1 == labels
            correct_top1 += hits.sum().item()
            total_images += batch_size

            # Per-class counts in one vectorised step instead of a per-sample
            # loop with a host sync on every element.
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[hits], minlength=num_classes)

    if total_images == 0:
        raise ValueError("evaluate() received an empty loader; there are no samples to evaluate.")

    # Average only over classes that appeared: an absent class would give
    # 0/0 = NaN and turn the whole mean into NaN.
    seen = class_total > 0
    per_class_accuracy = class_correct[seen].float() / class_total[seen].float()
    avg_class_accuracy = per_class_accuracy.mean().item() * 100
    overall_accuracy = (correct_top1 / total_images) * 100

    # Normalise by the samples actually seen rather than len(loader.dataset).
    return running_loss / total_images, overall_accuracy, avg_class_accuracy


In [14]:



# 7. Training Loop with Early Stopping and Class Accuracy
num_epochs = 10
best_acc = 0.0
early_stopper = EarlyStopping(patience=5)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # One pass over the training data, then a pass over test_loader, whose
    # metrics are reported as "validation" below.
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, avg_class_acc = evaluate(
        model, test_loader, criterion, device, num_classes=200
    )

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    print(f"Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f} | Train Accuracy: {train_acc:.2f}% | Validation Accuracy: {val_acc:.2f}%")
    print(f"Average Accuracy per Class: {avg_class_acc:.2f}%")

    # Checkpoint whenever validation accuracy reaches a new high.
    is_new_best = val_acc > best_acc
    if is_new_best:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print("Model saved!")

    # Early stopping is driven by validation loss, not accuracy.
    early_stopper(val_loss)
    if not early_stopper.early_stop:
        continue
    print("Early stopping triggered!")
    break


Epoch 1/10


100%|██████████| 302/302 [07:47<00:00,  1.55s/it]
100%|██████████| 76/76 [00:20<00:00,  3.63it/s]


Train Loss: 5.0524 | Validation Loss: 4.2771 | Train Accuracy: 6.21% | Validation Accuracy: 25.75%
Model saved!
Epoch 2/10


100%|██████████| 302/302 [09:31<00:00,  1.89s/it]
100%|██████████| 76/76 [00:20<00:00,  3.69it/s]


Train Loss: 3.5944 | Validation Loss: 2.7058 | Train Accuracy: 32.62% | Validation Accuracy: 53.24%
Model saved!
Epoch 3/10


100%|██████████| 302/302 [09:27<00:00,  1.88s/it]
100%|██████████| 76/76 [00:20<00:00,  3.64it/s]


Train Loss: 2.2969 | Validation Loss: 1.8102 | Train Accuracy: 55.79% | Validation Accuracy: 64.70%
Model saved!
Epoch 4/10


100%|██████████| 302/302 [09:24<00:00,  1.87s/it]
100%|██████████| 76/76 [00:21<00:00,  3.47it/s]


Train Loss: 1.4428 | Validation Loss: 1.3504 | Train Accuracy: 72.71% | Validation Accuracy: 71.10%
Model saved!
Epoch 5/10


100%|██████████| 302/302 [09:22<00:00,  1.86s/it]
100%|██████████| 76/76 [00:20<00:00,  3.70it/s]


Train Loss: 0.9192 | Validation Loss: 1.0778 | Train Accuracy: 83.64% | Validation Accuracy: 74.42%
Model saved!
Epoch 6/10


100%|██████████| 302/302 [09:26<00:00,  1.88s/it]
100%|██████████| 76/76 [00:21<00:00,  3.46it/s]


Train Loss: 0.6180 | Validation Loss: 1.0221 | Train Accuracy: 90.83% | Validation Accuracy: 76.16%
Model saved!
Epoch 7/10


100%|██████████| 302/302 [09:27<00:00,  1.88s/it]
100%|██████████| 76/76 [00:22<00:00,  3.40it/s]


Train Loss: 0.5617 | Validation Loss: 1.0029 | Train Accuracy: 92.32% | Validation Accuracy: 76.08%
Epoch 8/10


100%|██████████| 302/302 [09:30<00:00,  1.89s/it]
100%|██████████| 76/76 [00:31<00:00,  2.40it/s]


Train Loss: 0.5149 | Validation Loss: 0.9810 | Train Accuracy: 93.08% | Validation Accuracy: 77.41%
Model saved!
Epoch 9/10


100%|██████████| 302/302 [09:29<00:00,  1.89s/it]
100%|██████████| 76/76 [00:51<00:00,  1.46it/s]


Train Loss: 0.4828 | Validation Loss: 0.9608 | Train Accuracy: 94.08% | Validation Accuracy: 78.07%
Model saved!
Epoch 10/10


100%|██████████| 302/302 [09:33<00:00,  1.90s/it]
100%|██████████| 76/76 [00:27<00:00,  2.81it/s]


Train Loss: 0.4553 | Validation Loss: 0.9441 | Train Accuracy: 94.49% | Validation Accuracy: 78.49%
Model saved!
