In [None]:
## Config

import os
import random

train_root = "CUB_formatted/train"
val_root = "CUB_formatted/val"
roots = [train_root, val_root]
num_classes = 200

if num_classes != 200:
    all_classes = sorted(os.listdir(train_root))
    selected_classes = all_classes[:num_classes]
    #selected_classes = random.sample(all_classes, num_classes)
    print(selected_classes)

In [None]:
## Torchvision data formatting -- Thins training and validation datasets down to the selected classes and reindexes them

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

# ImageNet transforms--expected by ResNet18
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

filtered_datasets = []

for root in roots:
    full_dataset = datasets.ImageFolder(root, transform=transform)

    # Skip filtering if using whole dataset
    if num_classes == 200:
        filtered_datasets.append(full_dataset)
    else:
        # Get indices of selected classes
        selected_classes_idx = []
        for c in selected_classes:
            class_idx = full_dataset.class_to_idx[c]
            selected_classes_idx.append(class_idx)

        # Create a map of new labels from 0 to num_classes-1
        label_remap = {}
        for new, orig in enumerate(sorted(selected_classes_idx)):
            label_remap[orig] = new

        filtered_data = []

        # Filter out images from the selected classes and relabel them
        for img, label in full_dataset:
            if label in selected_classes_idx:
                new_label = label_remap[label]
                filtered_data.append((img, new_label))

        filtered_datasets.append(filtered_data) 

train_data, val_data = filtered_datasets
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=True)

In [None]:
## Torchvision model construction -- WARNING: WILL RESET CURRENT MODEL

import torchvision.models as models
from torchvision.models import ResNet18_Weights
import torch.nn as nn
import torch.optim as optim

model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
total_epochs = 0

model.fc = nn.Linear(model.fc.in_features, num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\johnw/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100.0%


In [14]:
## Torchvision training -- Can be run multiple times on the same model
import time

start = time.time()
num_epochs = 1

for epoch in range(num_epochs):
    model.train()
    total_epochs += 1
    
    total_loss = 0
    
    # Train model and calculate loss
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    ## Get accuracy of epoch
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    epoch_acc = 100 * correct / total

    # Print epoch metrics
    elapsed = time.time() - start
    mins, secs = divmod(int(elapsed), 60)
    print(f"Epoch {total_epochs}, Loss: {total_loss:.4f}, Accuracy: {epoch_acc:.2f}, time: {mins}m {secs:.2f}s")

Epoch 7, Loss: 62.0881, Accuracy: 51.56, time: 9m 58.00s


In [16]:
## Validation -- Just evaluates the current state of the model

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
epoch_acc = 100 * correct / total

print(f"Epoch {total_epochs}, Loss: {total_loss:.4f}, Accuracy: {epoch_acc:.2f}")

Epoch 7, Loss: 62.0881, Accuracy: 51.56


In [None]:
## TODO: Compare with Tensorflow implementation