In [1]:
# imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import confusion_matrix
import seaborn as sns
from google.colab import drive, files
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import tarfile
import warnings

# Silence every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

# --- Run configuration -------------------------------------------------------
BATCH_SIZE = 32
NUM_CLASSES = 257  # 256 Categories + 1 Background Class
EPOCHS = 10
LEARNING_RATE = 1e-4
DATA_DIR = "/content/local_caltech256"
SAVE_PATH = "/content/drive/MyDrive/caltech256_resnet_optimal.pth"

# Google Drive is where the trained weights are persisted (see SAVE_PATH);
# mount it only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(DEVICE)

Mounted at /content/drive
cuda


In [2]:
# This function handles the entire data pipeline: downloading, extracting, transforming, splitting, and building the DataLoaders.
def setup_data():
    """Prepare the Caltech-256 data and return train/val DataLoaders.

    Downloads and extracts the archive into ``DATA_DIR`` when that directory
    does not exist yet, then builds an 80/20 train/validation split (seeded
    with 42 so the split is the same on every run).

    Returns:
        tuple: ``(dataloaders_dict, class_names)`` where ``dataloaders_dict``
        has the keys ``'train'`` and ``'val'`` and ``class_names`` is the
        ``ImageFolder.classes`` list.

    Raises:
        RuntimeError: if the ``wget`` download exits with a non-zero status.
    """
    archive_path = 'caltech256.tar'

    # if dataset not in local device, download it
    if not os.path.exists(DATA_DIR):
        print("Downloading the dataset(caltech 256)")
        exit_status = os.system("wget -q https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar -O caltech256.tar")
        if exit_status != 0:
            # wget -O leaves an empty/partial file behind on failure; remove it
            # so the failure is reported here instead of as a confusing tar error.
            if os.path.exists(archive_path):
                os.remove(archive_path)
            raise RuntimeError(f"Dataset download failed (wget exit status {exit_status}).")

        print("Extracting, please wait!")
        try:
            with tarfile.open(archive_path) as tar_file:
                tar_file.extractall(path='/content')
        finally:
            # Always clean up the archive, even when extraction fails.
            if os.path.exists(archive_path):
                os.remove(archive_path)

        if os.path.exists("/content/256_ObjectCategories"):
            os.rename("/content/256_ObjectCategories", DATA_DIR)

    # ImageNet channel statistics, shared by both pipelines.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_transform = transforms.Compose([transforms.Resize((384, 384)),
                                          transforms.TrivialAugmentWide(),
                                          transforms.ToTensor(),
                                          normalize])

    val_transform = transforms.Compose([transforms.Resize((384, 384)),
                                        transforms.ToTensor(),
                                        normalize])

    # Two ImageFolder views over the same directory, one per transform.
    # A single shared dataset cannot carry two transforms: both Subsets point at
    # the same underlying object, so assigning `.dataset.transform` on the
    # validation subset would silently overwrite the training augmentation.
    train_dataset = datasets.ImageFolder(root=DATA_DIR, transform=train_transform)
    val_dataset = datasets.ImageFolder(root=DATA_DIR, transform=val_transform)

    torch.manual_seed(42) # train and val split same
    train_length = int(0.8 * len(train_dataset))
    val_length = len(train_dataset) - train_length
    train_set, val_split = random_split(train_dataset, [train_length, val_length])

    # Re-point the validation indices at the dataset that uses val_transform.
    val_set = torch.utils.data.Subset(val_dataset, val_split.indices)

    dataloaders_dict = {'train': DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True),
                        'val': DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)}

    return dataloaders_dict, train_dataset.classes

In [3]:
# loading the ResNet50 model pre-trained on ImageNet.
def get_resnet_model():
    """Build a ResNet-50 (ImageNet V2 weights) with a ``NUM_CLASSES``-way head.

    The stem (``conv1``/``bn1``) and ``layer1`` are frozen; every other
    parameter, including the new ``fc`` layer, stays trainable.

    Returns:
        The model moved to ``DEVICE``.
    """
    # Loading V2 weights
    weights = models.ResNet50_Weights.IMAGENET1K_V2
    model = models.resnet50(weights=weights)

    # Freeze only the stem and the first residual stage.
    # Prefix matching is required here: a substring test such as
    # `"conv1" in name` also matches `layer2.0.conv1.weight`, `layer3.1.bn1.bias`
    # etc., which would freeze part of every bottleneck block in layers 2-4.
    frozen_prefixes = ("conv1.", "bn1.", "layer1.")
    for parameter_name, parameter in model.named_parameters():
        if parameter_name.startswith(frozen_prefixes):
            parameter.requires_grad = False

    # Replace the final classification layer
    input_features = model.fc.in_features
    model.fc = nn.Linear(input_features, NUM_CLASSES)

    return model.to(DEVICE)

In [4]:
# This is the main training loop; it returns the model and the validation-accuracy history for plotting.
def train_model(model, loaders):
    """Fine-tune ``model`` for ``EPOCHS`` epochs and keep the best weights.

    Each epoch runs a training pass followed by a validation pass. The weights
    with the highest validation accuracy are checkpointed to ``temp_best.pth``
    and reloaded into ``model`` before returning.

    Args:
        model: network to train; only parameters with ``requires_grad=True``
            are handed to the optimizer.
        loaders: mapping with ``'train'`` and ``'val'`` DataLoaders yielding
            ``(images, labels)`` batches.

    Returns:
        tuple: ``(model, accuracy_history)`` where ``accuracy_history`` holds
        one validation accuracy (in percent) per epoch.
    """
    # Mixed precision only makes sense on CUDA; disable it cleanly elsewhere.
    use_amp = DEVICE.type == "cuda"

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    scaler = GradScaler(enabled=use_amp)

    # Start below any reachable accuracy so the first validation pass always
    # writes a checkpoint; otherwise a 0% epoch would leave no file (or a stale
    # file from an earlier run) to reload at the end.
    best_accuracy = -1.0
    checkpoint_written = False
    accuracy_history = []
    print(f"Training started for {EPOCHS} epochs.")

    for epoch_index in range(EPOCHS):
        start_time = time.time()
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            correct_predictions = 0
            samples_seen = 0

            for input_images, labels in loaders[phase]:
                input_images = input_images.to(DEVICE)
                labels = labels.to(DEVICE)
                if is_training:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    # autocast enables mixed precision (FP16) for the forward
                    # pass and loss only; backward/step run outside of it.
                    with autocast(enabled=use_amp):
                        outputs = model(input_images)
                        loss = criterion(outputs, labels)

                    if is_training:
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                predicted_classes = outputs.argmax(dim=1)
                batch_size = input_images.size(0)
                running_loss += loss.item() * batch_size
                correct_predictions += (predicted_classes == labels).sum().item()
                samples_seen += batch_size

            if is_training:
                scheduler.step()

            # Logging (guard against an empty loader to avoid dividing by zero)
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_accuracy = correct_predictions / denominator

            if phase == 'val':
                elapsed_time = (time.time() - start_time) / 60
                print(f"Epoch {epoch_index+1}/{EPOCHS} | Loss: {epoch_loss:.4f} | Acc: {epoch_accuracy*100:.2f}% | Time: {elapsed_time:.1f} min")

                # Store accuracy for plot
                accuracy_history.append(epoch_accuracy * 100)

                # Save the best model found so far
                if epoch_accuracy > best_accuracy:
                    best_accuracy = epoch_accuracy
                    torch.save(model.state_dict(), 'temp_best.pth')
                    checkpoint_written = True

    print(f"Training finished. Best Validation Accuracy: {max(best_accuracy, 0.0)*100:.2f}%")

    # Reload the absolute best weights before returning (only if this run
    # actually produced a checkpoint, e.g. not when EPOCHS == 0).
    if checkpoint_written:
        model.load_state_dict(torch.load('temp_best.pth', map_location=DEVICE))
    return model, accuracy_history

In [5]:
# Calculates accuracy on the dataset without training.
def evaluate_accuracy(model, loader):
    """Compute and print the top-1 accuracy of ``model`` over ``loader``.

    The model is put in eval mode and no gradients are tracked.

    Args:
        model: network to evaluate.
        loader: DataLoader yielding ``(images, labels)`` batches.

    Returns:
        float: accuracy in percent, or ``nan`` when ``loader`` yields no
        samples.
    """
    model.eval()
    correct_count = 0
    total_count = 0

    print("Calculating current accuracy...")
    with torch.no_grad():
        for input_images, labels in loader:
            input_images = input_images.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = model(input_images)
            predicted_classes = outputs.argmax(dim=1)
            total_count += labels.size(0)
            correct_count += (predicted_classes == labels).sum().item()

    if total_count == 0:
        # Nothing was evaluated; report that instead of dividing by zero.
        print("No samples found in the loader; accuracy is undefined.")
        return float("nan")

    accuracy = 100 * correct_count / total_count
    print(f"Current Model Accuracy: {accuracy:.2f}%")
    return accuracy

In [6]:
# finds the top 5 most confused pairs of classes.
def show_confusion_errors(model, loader):
    """Print the (up to) five most frequent misclassification pairs.

    Runs ``model`` over ``loader``, builds a confusion matrix, zeroes its
    diagonal and prints the largest off-diagonal entries in ascending order
    of count.

    Args:
        model: network to analyse.
        loader: DataLoader yielding ``(images, labels)`` batches. Its dataset
            (or the dataset it wraps via ``.dataset``) must expose ``classes``.
    """
    model.eval()
    all_predictions = []
    all_true_labels = []

    print("\nAnalyzing confusion matrix...")
    with torch.no_grad():
        for input_images, labels in loader:
            input_images = input_images.to(DEVICE)
            outputs = model(input_images)
            predicted_classes = outputs.argmax(dim=1)

            all_predictions.extend(predicted_classes.cpu().tolist())
            all_true_labels.extend(labels.tolist())

    # Class names live on the dataset itself, or one level down when the loader
    # iterates over a Subset.
    dataset = loader.dataset
    class_names = getattr(dataset, "classes", None)
    if class_names is None:
        class_names = dataset.dataset.classes

    # Pin the label set explicitly. Without `labels=`, the matrix only contains
    # classes that actually occur in the data, so its rows/columns would stop
    # lining up with `class_names` whenever a class is absent.
    confusion_mat = confusion_matrix(all_true_labels, all_predictions,
                                     labels=np.arange(len(class_names)))
    np.fill_diagonal(confusion_mat, 0) # Ignore correct predictions

    # Sort by error count
    top_error_indices = np.argsort(confusion_mat.flatten())[-5:]
    row_width = confusion_mat.shape[1]

    print("Top 5 Confused Pairs:")
    for flat_index in top_error_indices:
        # Decode the flat position using the matrix's real width.
        true_class_index, predicted_class_index = divmod(flat_index, row_width)
        error_count = confusion_mat[true_class_index, predicted_class_index]
        if error_count == 0:
            # Fewer than five distinct errors: do not report non-confusions.
            continue
        print(f"True: {class_names[true_class_index]} | Pred: {class_names[predicted_class_index]} | Count: {error_count}")

In [7]:
if __name__ == "__main__":
    # Data loaders plus a freshly initialised network.
    dataloaders_dict, all_class_names = setup_data()
    resnet_model = get_resnet_model()

    # Stays empty when a saved checkpoint is reused instead of training.
    training_history = []

    if not os.path.exists(SAVE_PATH):
        # No checkpoint on Drive yet: train, then persist the best weights.
        print("No saved model found. Training from scratch...")
        trained_model, training_history = train_model(resnet_model, dataloaders_dict)

        torch.save(trained_model.state_dict(), SAVE_PATH)
        print(f"Model saved to {SAVE_PATH}")

    else:
        # Reuse the stored weights and just report validation accuracy.
        print(f"Found saved model at {SAVE_PATH}")
        print("Loading weights!!!")
        resnet_model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE))

        evaluate_accuracy(resnet_model, dataloaders_dict['val'])
        trained_model = resnet_model

    # Either way, finish with the confusion analysis on the validation split.
    show_confusion_errors(trained_model, dataloaders_dict['val'])



Downloading the dataset(caltech 256)
Extracting, please wait!
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 188MB/s]


Found saved model at /content/drive/MyDrive/caltech256_resnet_optimal.pth
Loading weights!!!
Calculating current accuracy...
Current Model Accuracy: 90.33%

Analyzing confusion matrix...
Top 5 Confused Pairs:
True: 232.t-shirt | Pred: 159.people | Count: 4
True: 026.cake | Pred: 159.people | Count: 4
True: 191.sneaker | Pred: 255.tennis-shoes | Count: 4
True: 069.fighter-jet | Pred: 251.airplanes-101 | Count: 6
True: 255.tennis-shoes | Pred: 191.sneaker | Count: 8
