# Load CalTech101 dataset

In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import wandb
import datetime
from tqdm.auto import tqdm

In [2]:
from typing import Any

_batch_sizes = {
    "resnet18": 64,
    "resnet34": 64,
    "resnet50": 32,
    "resnet101": 16,
}

class Config:
    seed = 42
    lr = 0.001
    epochs = 10
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = "resnet101" # resnet18, resnet34, resnet50, resnet101, resnet152
    batch_size = _batch_sizes[model]
    suffix = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    save_dir = f"checkpoints/{model}-{suffix}"


In [3]:
# Grayscale to RGB transform
class GrayscaleToRGB(object):
    """From https://www.kaggle.com/code/cafalena/caltech101-pytorch-deep-learning"""
    def __call__(self, img):
        if img.mode == 'L':
            img = img.convert("RGB")
        return img

transform = transforms.Compose(
    [transforms.Resize((224, 224)),
    GrayscaleToRGB(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    )


# Load Caltech101 dataset
def load_data():
    dataset = datasets.Caltech101(root='caltech_data', download=False, transform=transform)

    return dataset

# Split dataset into training and testing
def split_data(dataset):
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(Config.seed))
    return train_dataset, test_dataset


# Train with early stopping and adamw optimizer with learning rate scheduler and log to wandb
def train(model: Any, train_loader: Any, test_loader: Any, config: Config, logging: bool = True):
    model.to(config.device)
    optimizer = optim.AdamW(model.parameters(), lr=config.lr)
    if config.model != "resnet101":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    else:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    criterion = nn.CrossEntropyLoss()

    # Main training loop
    best_acc = 0
    pbar = tqdm(total=config.epochs, desc="Epochs")
    for epoch in range(config.epochs):
        model.train()
        train_correct = 0
        for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader), desc="Training", leave=False):
            data, target = data.to(config.device), target.to(config.device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            # print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
            #         f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

            # Log loss
            if logging:
                wandb.log({"train_loss": loss.item(), "epoch": epoch})
            with torch.no_grad():
                pred = output.argmax(dim=1, keepdim=True)
                train_correct += pred.eq(target.view_as(pred)).sum().item()
            
            # Update progress bar
            pbar.set_postfix({"train_loss": loss.item()})
        
        # Update progress bar
        pbar.update(1)
        
        # Log training accuracy and test accuracy
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in tqdm(test_loader, desc="Testing", leave=False):
                data, target = data.to(config.device), target.to(config.device)
                output = model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        test_acc = 100. * correct / len(test_loader.dataset)
        train_acc = 100. * train_correct / len(train_loader.dataset)
        print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
                f"({test_acc:.0f}%)\n")
        print(f"Train set: Accuracy: {train_correct}/{len(train_loader.dataset)} ({train_acc:.0f}%)\n")
        if logging:
            wandb.log({"test_loss": test_loss, "test_acc": test_acc, "train_acc": train_acc, "epoch": epoch, "lr": scheduler.get_last_lr()[0]})

        # Update progress bar
        pbar.set_postfix({"test_loss": test_loss, "test_acc": test_acc, "train_acc": train_acc})
        
        # Save model if test accuracy is better
        if test_acc > best_acc and logging:
            best_acc = test_acc
            torch.save(model.state_dict(), os.path.join(config.save_dir, "model.pth"))
            print(f"Model saved at {config.save_dir}")

        scheduler.step()

# Load pretrained ResNet model
def load_model(model_name: str):
    model = getattr(models, model_name)(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 101)
    return model

In [4]:
dataset = load_data()
train_dataset, test_dataset = split_data(dataset)
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False)

In [5]:
len(test_dataset), len(train_dataset)

(1736, 6941)

In [5]:
model = load_model(Config.model)

# Warmup just the last layer
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

num_epochs = Config.epochs
Config.epochs = 2 if Config.model == "resnet101" else 1
print("Warming up the last layer")
train(model, train_loader, test_loader, Config, logging=False)
print("Warmup done")

# Unfreeze all layers
for param in model.parameters():
    param.requires_grad = True

Config.epochs = num_epochs
print("Training the whole model")
save_dir = f"checkpoints/{Config.model}-{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}"
os.makedirs(save_dir, exist_ok=True)
Config.save_dir = save_dir

# Wandb logging
os.environ['WANDB_NOTEBOOK_NAME'] = 'interpret-resnet.ipynb'
wandb.init(
    project="interpret-mito",
    config={k: v for k, v in vars(Config).items() if not k.startswith("__")},
    tags=[Config.model, "caltech101"],
    resume=False
)
wandb.watch(model)
train(model, train_loader, test_loader, Config, logging=True)




Warming up the last layer


Epochs:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0245, Accuracy: 1555/1736 (90%)

Train set: Accuracy: 5270/6941 (76%)



Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0203, Accuracy: 1596/1736 (92%)

Train set: Accuracy: 6423/6941 (93%)

Warmup done
Training the whole model


[34m[1mwandb[0m: Currently logged in as: [33miamsuyogjadhav[0m ([33mpersonal-suyog[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.1388, Accuracy: 845/1736 (49%)

Train set: Accuracy: 2782/6941 (40%)

Model saved at checkpoints/resnet101-20240917-182215


Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0856, Accuracy: 1153/1736 (66%)

Train set: Accuracy: 4298/6941 (62%)

Model saved at checkpoints/resnet101-20240917-182215


Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0677, Accuracy: 1247/1736 (72%)

Train set: Accuracy: 5172/6941 (75%)

Model saved at checkpoints/resnet101-20240917-182215


Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0674, Accuracy: 1279/1736 (74%)

Train set: Accuracy: 5706/6941 (82%)

Model saved at checkpoints/resnet101-20240917-182215


Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0509, Accuracy: 1384/1736 (80%)

Train set: Accuracy: 6038/6941 (87%)

Model saved at checkpoints/resnet101-20240917-182215


Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0265, Accuracy: 1542/1736 (89%)

Train set: Accuracy: 6698/6941 (96%)

Model saved at checkpoints/resnet101-20240917-182215


Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0256, Accuracy: 1545/1736 (89%)

Train set: Accuracy: 6853/6941 (99%)

Model saved at checkpoints/resnet101-20240917-182215


Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0258, Accuracy: 1545/1736 (89%)

Train set: Accuracy: 6884/6941 (99%)



Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0262, Accuracy: 1555/1736 (90%)

Train set: Accuracy: 6919/6941 (100%)

Model saved at checkpoints/resnet101-20240917-182215


Training:   0%|          | 0/434 [00:00<?, ?it/s]

Testing:   0%|          | 0/109 [00:00<?, ?it/s]


Test set: Average loss: 0.0263, Accuracy: 1558/1736 (90%)

Train set: Accuracy: 6919/6941 (100%)

Model saved at checkpoints/resnet101-20240917-182215
