In [1]:
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms.v2 as transforms 
from torch.utils.data import random_split
import random 

from modules import data_loader
from modules.networks import VarResNet
from modules.networks import Net


## Question 14
see data_loader.py for how to load the data as three tensors

In [2]:
root = "../data/mnist-varres/train"
buckets = data_loader.load_sorted_data(root)

In [3]:
#split into training and validation 
train_buckets = []
val_buckets = []
torch.manual_seed(42)

for inputs, labels in buckets:

    n_total = len(inputs)
    n_val = int(n_total * 0.2)
    n_train = n_total - n_val
    

    indices = torch.randperm(n_total)
    # Slice using indices
    train_x = inputs[indices[:n_train]]
    train_y = labels[indices[:n_train]]
    
    val_x = inputs[indices[n_train:]]
    val_y = labels[indices[n_train:]]
    
    train_buckets.append((train_x, train_y))
    val_buckets.append((val_x, val_y))

Training loop with inner loop over the three dimensions

In [None]:
model = VarResNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

epochs = 2
batch_size = 16

# Calculate how many full batches fit in the smallest bucket
min_samples = min(len(x) for x, y in train_buckets)
n_batches = min_samples // batch_size 

print(f"Training on {n_batches} batches per resolution..")

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    
    # Generate shuffled indices for all buckets
    bucket_shuffles = [torch.randperm(len(x)) for x, y in train_buckets]
    
    for i in range(n_batches):
        
        # Calculate slice range for this step
        start = i * batch_size
        end = start + batch_size
        
        # loop over the three resolutions (32->48->64)
        for res_idx, (inputs_full, labels_full) in enumerate(train_buckets):
            
            batch_indices = bucket_shuffles[res_idx][start:end]
            inputs = inputs_full[batch_indices]
            labels = labels_full[batch_indices]
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

    # Average loss: total loss / (batches * 3 resolutions)
    avg_loss = running_loss / (n_batches * 3)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

print("Finished Training")

Training on 995 batches per resolution...
Epoch 1, Loss: 0.2627
Epoch 2, Loss: 0.0825
Finished Training


## Question 15
find the value of N for which both networks have roughly the same number of parameters

In [5]:
def count_params(model):
    total_params = 0

    for param in model.parameters():
        num_elements = param.numel()
        total_params += num_elements
        
    return total_params

# target_params = (slope * n) + intercept
# Rearrange:
# n = (target_params - intercept) / slope

target_params = count_params(Net())
params_at_1 = count_params(VarResNet(n_channels=1))
params_at_2 = count_params(VarResNet(n_channels=2))
slope = params_at_2 - params_at_1
intercept = params_at_1 - slope

n = (target_params - intercept) / slope
print(n)

# So, rounded:
optimal_n=round(n)

# check
print(f"Fixed Model: {count_params(Net())} params")
print(f"VarRes Model: {count_params(VarResNet())} params")


81.123745819398
Fixed Model: 29066 params
VarRes Model: 29029 params


## Question 16
Compare the validation performance of global max pooling to that of global mean pooling. Report your findings, and choose a global pooling variant.

In [None]:
def run_training(pooling_type, epochs=3, lr=0.001):
    print(f"\nTraining with Global {pooling_type.upper()} Pooling")
    
    model = VarResNet(n_channels=81, pooling=pooling_type)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    batch_size = 16
    
    # Calculate batch count based on the smallest resolution bucket
    min_samples = min(len(b[0]) for b in train_buckets)
    n_batches = min_samples // batch_size 

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        
        # Shuffle indices for each bucket at the start of every epoch
        bucket_shuffles = [torch.randperm(len(b[0])) for b in train_buckets]
        
        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            
            for bucket_idx, (inputs_full, labels_full) in enumerate(train_buckets):
                indices = bucket_shuffles[bucket_idx][start:end]
                inputs = inputs_full[indices]
                labels = labels_full[indices]
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
        
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_buckets:
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        # Calculate stats
        val_acc = 100 * correct / total
        avg_loss = running_loss / (n_batches * 3)
        
        print(f"Epoch {epoch+1}: Train Loss {avg_loss:.4f}, Val Acc {val_acc:.2f}%")
        
    return val_acc

In [7]:
acc_max = run_training('max')
acc_mean = run_training('mean')


Training with Global MAX Pooling
Epoch 1: Train Loss 0.2846, Val Acc 97.16%
Epoch 2: Train Loss 0.0810, Val Acc 98.20%
Epoch 3: Train Loss 0.0580, Val Acc 97.67%

Training with Global MEAN Pooling
Epoch 1: Train Loss 0.8135, Val Acc 92.92%
Epoch 2: Train Loss 0.2767, Val Acc 94.65%
Epoch 3: Train Loss 0.1996, Val Acc 94.16%


## Question 17
comparing fixed resolution network with variable resolution

In [12]:
### Fixed model 

root = "../data/mnist-varres"
transform_fixed = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28,28)),
    transforms.ToTensor()
])

full_train_fixed = torchvision.datasets.ImageFolder(root=root + "/train", transform=transform_fixed)

# Split train/val
train_size = int(0.8 * len(full_train_fixed))
val_size = len(full_train_fixed) - train_size
train_set_fixed, val_set_fixed = random_split(full_train_fixed, [train_size, val_size], generator=torch.Generator().manual_seed(42))

# Loaders
batch_size = 16
train_loader_fixed = torch.utils.data.DataLoader(train_set_fixed, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader_fixed   = torch.utils.data.DataLoader(val_set_fixed,   batch_size=batch_size, shuffle=False, num_workers=2)

# Load test set (for final comparison)
test_set_fixed = torchvision.datasets.ImageFolder(root=root + "/test", transform=transform_fixed)
test_loader_fixed = torch.utils.data.DataLoader(test_set_fixed, batch_size=batch_size, shuffle=False, num_workers=2)

model_fixed = Net() 



In [13]:
## var model

all_buckets = data_loader.load_sorted_data(data_root=root + "/train")

train_buckets = []
val_buckets = []
torch.manual_seed(42)

for inputs, labels in all_buckets:
    n_total = len(inputs)
    n_val = int(n_total * 0.2)
    n_train = n_total - n_val
    
    indices = torch.randperm(n_total)
    
    train_buckets.append((inputs[indices[:n_train]], labels[indices[:n_train]]))
    val_buckets.append((inputs[indices[n_train:]], labels[indices[n_train:]]))

# Load test buckets (for final comparison)
test_buckets_var = data_loader.load_sorted_data(data_root=root + "/test")

model_var = VarResNet(n_channels=81, pooling='max')

In [24]:
def train_fixed_res(model, train_loader, val_loader, epochs=3, lr=0.001):
    print(f"\nTraining FixedResNet")
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr) 

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        total_epoch_loss = 0.0
        
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            total_epoch_loss += loss.item() 

            # print statistics
            if i % 200 == 199:    
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')
                running_loss = 0.0
            
        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        val_acc = 100 * correct / total

        avg_loss = total_epoch_loss / len(train_loader)
        print(f"Epoch {epoch + 1}: Train Loss {avg_loss:.4f}, Val Acc {val_acc:.2f}%")

In [None]:
def train_var_res(model, train_data, val_data, epochs=3, lr=0.001, batch_size=16):
    print(f"\nTraining VarResNet")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    # Find smallest bucket size to prevent index errors
    min_samples = min(len(b[0]) for b in train_data)
    n_batches = min_samples // batch_size
    num_resolutions = len(train_data) # will be 3 (32x32, 48x48, 64x64)
    
    # for printing
    log_interval = 100

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0      # For printing (resets)
        total_epoch_loss = 0.0  # For final epoch average (never resets)
        
        # Shuffle indices
        bucket_shuffles = [torch.randperm(len(b[0])) for b in train_data]
        
        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            
            # loop through resolutions 
            for bucket_idx, (inputs_full, labels_full) in enumerate(train_data):
                
                indices = bucket_shuffles[bucket_idx][start:end]
                inputs = inputs_full[indices]
                labels = labels_full[indices]
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
                total_epoch_loss += loss.item()

            # for print statements
            if (i + 1) % log_interval == 0:
                # Average over: log_interval * number of resolutions
                avg_running = running_loss / (log_interval * num_resolutions)
                print(f'[{epoch + 1}, {i + 1:5d} bucket-sets] loss: {avg_running:.3f}')
                running_loss = 0.0

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_data:
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        
        total_updates = n_batches * num_resolutions
        avg_loss = total_epoch_loss / total_updates
        
        print(f"Epoch {epoch+1}: Train Loss {avg_loss:.4f}, Val Acc {val_acc:.2f}%")

In [28]:
train_fixed_res(model_fixed, train_loader_fixed, val_loader_fixed, epochs=5, lr=0.0003)
train_var_res(model_var, train_buckets, val_buckets, epochs=5, lr=0.0003, batch_size=16)


Training FixedResNet
[1,   200] loss: 0.145
[1,   400] loss: 0.142
[1,   600] loss: 0.136
[1,   800] loss: 0.141
[1,  1000] loss: 0.142
[1,  1200] loss: 0.123
[1,  1400] loss: 0.145
[1,  1600] loss: 0.149
[1,  1800] loss: 0.127
[1,  2000] loss: 0.138
[1,  2200] loss: 0.141
[1,  2400] loss: 0.156
[1,  2600] loss: 0.139
[1,  2800] loss: 0.135
[1,  3000] loss: 0.161
Epoch 1: Train Loss 0.1413, Val Acc 95.28%
[2,   200] loss: 0.112
[2,   400] loss: 0.127
[2,   600] loss: 0.138
[2,   800] loss: 0.119
[2,  1000] loss: 0.116
[2,  1200] loss: 0.130
[2,  1400] loss: 0.142
[2,  1600] loss: 0.109
[2,  1800] loss: 0.114
[2,  2000] loss: 0.150
[2,  2200] loss: 0.135
[2,  2400] loss: 0.121
[2,  2600] loss: 0.121
[2,  2800] loss: 0.145
[2,  3000] loss: 0.111
Epoch 2: Train Loss 0.1261, Val Acc 95.41%
[3,   200] loss: 0.101
[3,   400] loss: 0.112
[3,   600] loss: 0.106
[3,   800] loss: 0.104
[3,  1000] loss: 0.119
[3,  1200] loss: 0.118
[3,  1400] loss: 0.099
[3,  1600] loss: 0.127
[3,  1800] loss: 0

In [29]:
# Test fixed
correct = 0
total = 0
model_fixed.eval()
with torch.no_grad():
    for inputs, labels in test_loader_fixed:
        outputs = model_fixed(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Fixed Resolution Test Accuracy: {100 * correct / total:.2f}%")

# Test var
correct = 0
total = 0
model_var.eval()
with torch.no_grad():
    for inputs, labels in test_buckets_var:
        outputs = model_var(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Variable Resolution Test Accuracy: {100 * correct / total:.2f}%")

Fixed Resolution Test Accuracy: 96.03%
Variable Resolution Test Accuracy: 98.53%
