In [22]:
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import torch
import torch.nn as nn
import torch.optim as optim

from modules import data_loader
from modules.networks import VarResNet
from modules.networks import Net

## Question 14
see data_loader.py for how to load the data as three tensors

In [19]:
root = "../data/mnist-varres/train"
buckets = data_loader.load_sorted_data(root)

Training loop with inner loop over the three dimensions

In [20]:
model = VarResNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch_size = 16
epochs = 2 

# To not slice past the end of the data
min_samples = min([t[0].size(0) for t in buckets])

print("Starting training..")

for epoch in range(epochs):
    running_loss = 0.0
    
    for i in range(0, min_samples, batch_size):
        
        # loop over the three resolutions 
        for inputs_full, labels_full in buckets:
            
            inputs = inputs_full[i : i + batch_size]
            labels = labels_full[i : i + batch_size]
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
    
    print(f"Epoch {epoch + 1}, Loss: {running_loss / (min_samples // batch_size * 3):.4f}")

print("Finished Training")

Starting training..
Epoch 1, Loss: 0.3631
Epoch 2, Loss: 0.4836
Finished Training


## Question 15
find the value of N for which both networks have roughly the same number of parameters

In [25]:
def count_params(model):
    total_params = 0

    for param in model.parameters():
        num_elements = param.numel()
        total_params += num_elements
        
    return total_params

# target_params = (slope * n) + intercept
# Rearrange:
# n = (target_params - intercept) / slope

target_params = count_params(Net())
params_at_1 = count_params(VarResNet(n_channels=1))
params_at_2 = count_params(VarResNet(n_channels=2))
slope = params_at_2 - params_at_1
intercept = params_at_1 - slope

n = (target_params - intercept) / slope
print(n)


81.123745819398
