In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# Per-image preprocessing: convert to a tensor, then normalize the single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std -- confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits are downloaded into ./data if missing and share the same transform.
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
valid_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    transform=transform,
    download=True,
)

# Only the training split is shuffled; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=256,
    shuffle=False,
)

# Phase name -> loader, the mapping that train() indexes with 'train' / 'valid'.
loaders = {
    'train': train_loader,
    'valid': valid_loader,
}

In [None]:
# After https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
# and https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#finetuning-the-convnet

from tqdm import tqdm

from src.collectors import DataCollector

def train(model, loaders, criterion, optimizer, scheduler, num_epochs=5, device=torch.device('cpu')):
    """Train ``model`` for ``num_epochs`` epochs, validating after every epoch.

    Each epoch runs ``one_epoch`` once on ``loaders['train']`` and once on
    ``loaders['valid']`` and prints one labelled metrics line per phase.

    Args:
        model: Module to train; moved to ``device`` before the first epoch.
        loaders: Mapping with ``'train'`` and ``'valid'`` keys, each an
            iterable of batches accepted by ``one_epoch``.
        criterion: Loss callable forwarded to ``one_epoch``.
        optimizer: Optimizer forwarded to ``one_epoch``.
        scheduler: Learning-rate scheduler forwarded to ``one_epoch``; may be ``None``.
        num_epochs: Number of train + valid passes to run.
        device: Device the model (and, inside ``one_epoch``, the batches) is moved to.

    Returns:
        None. Metrics are only printed.
    """
    model.to(device)

    # The current epoch index is published on a DataCollector instance.
    # NOTE(review): this assumes DataCollector shares state between instances
    # (other cells create fresh instances and read ``.data``) -- confirm.
    dc = DataCollector()

    for epoch in range(num_epochs):
        dc.current_epoch = epoch

        for phase in ('train', 'valid'):
            loss, acc = one_epoch(phase, model, loaders[phase], criterion, scheduler, optimizer, epoch, device)
            # Label the line with its phase so train and valid metrics are
            # distinguishable; float() accepts both a 0-dim tensor and a float.
            print(f"Epoch [{epoch+1}/{num_epochs}] {phase} - Loss: {loss:.4f}, Accuracy: {float(acc):.4f}")
        
def one_epoch(phase, model, loader, criterion, scheduler, optimizer, epoch_idx, device):
    """Run one full pass over ``loader`` in the given phase and return its metrics.

    Args:
        phase: ``'train'`` (train mode, gradients enabled, weights updated) or
            ``'valid'`` (eval mode, gradient tracking disabled).
        model: Module called on every input batch. Predictions are the argmax
            of its output over dim 1.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a
            single-element loss tensor.
        scheduler: Optional learning-rate scheduler, or ``None``. It is stepped
            once per call, after the training pass, never during validation.
        optimizer: Optimizer; only used in the ``'train'`` phase.
        epoch_idx: Index of the current epoch (not used inside this function).
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats. ``epoch_loss`` is the sum
        of ``loss * batch_size`` over all batches divided by the number of
        samples seen; ``epoch_acc`` is the fraction of correct predictions.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'valid'``.
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    if phase == 'train':
        model.train()
    elif phase == 'valid':
        model.eval()
    else:
        raise ValueError(f"Phase {phase} is not a proper learning phase (use 'train' or 'valid')!")

    is_training = phase == 'train'

    running_loss = 0.0
    running_corrects = 0
    dataset_size = 0

    # The current batch index is published on a DataCollector instance.
    # NOTE(review): this assumes DataCollector shares state between instances -- confirm.
    dc = DataCollector()

    for batch, (inputs, labels) in enumerate(tqdm(loader)):
        dc.current_batch = batch
        inputs, labels = inputs.to(device), labels.to(device)

        if is_training:
            # Gradients only exist in the training phase, so only clear them there.
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_training):
            outputs = model(inputs)
            _, predictions = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

        batch_size = inputs.size(0)
        # Weight by batch size so a smaller final batch does not skew the mean.
        running_loss += loss.item() * batch_size
        # .item() keeps the counter a plain int instead of a tensor on ``device``,
        # so the returned accuracy is a Python float.
        running_corrects += (predictions == labels).sum().item()
        dataset_size += batch_size

    # Step the scheduler once per epoch, after all optimizer steps of this pass.
    # Stepping per batch would make epoch-based schedules (e.g. StepLR with
    # step_size=7) decay every few batches instead of every few epochs.
    if is_training and scheduler is not None:
        scheduler.step()

    epoch_loss = running_loss / dataset_size
    epoch_acc = running_corrects / dataset_size

    return epoch_loss, epoch_acc

In [None]:
from torchinfo import summary

from src.modules.blocks import BasicBlock
from src.modules.models import ResNet

# Single input channel and ten output classes, built from BasicBlock stages.
# NOTE(review): layers=[3, 4, 6, 3] looks like the ResNet-34 stage layout --
# confirm against src.modules.models.ResNet.
model = ResNet(
    block=BasicBlock,
    in_channels=1,
    layers=[3, 4, 6, 3],
    num_classes=10,
)

# Left as the cell's last expression so the notebook displays the summary.
summary(model)

In [None]:
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
scheduler = None

# Prefer Apple's MPS backend (the originally hard-coded target), but fall back
# to CUDA and then CPU so the cell also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Single-epoch run.
train(model, loaders, criterion, optimizer, scheduler, 1, device)

In [None]:
# `dc` is only a local variable inside train()/one_epoch(), so it must be
# obtained here for this cell to work on a fresh kernel (Restart & Run All).
# NOTE(review): relies on DataCollector sharing its collected data between
# instances, as the following cells also assume -- confirm.
dc = DataCollector()

# Record stored under key (0, 0, 'BasicBlock', 1).
# NOTE(review): key fields are presumably (epoch, batch, block name, depth) -- confirm.
dc.data[0, 0, 'BasicBlock', 1]

In [None]:
dc = DataCollector()

# Drill into the record for key (0, 0, 'BasicBlock', 1) one level at a time;
# the final expression is what the notebook displays.
block_record = dc.data[0, 0, 'BasicBlock', 1]
init_params = block_record['init']
init_params['conv1.weight']

In [None]:
def parameters_distance(pre: dict, post: dict):
    """Return the Euclidean (L2) distance between two named parameter sets.

    Every entry of ``pre`` is matched by name with an entry of ``post``; the
    squared element-wise differences of all matched tensors are summed and the
    square root of that total is returned. Names present only in ``post`` are
    ignored.

    Args:
        pre: Mapping of parameter name to tensor.
        post: Mapping of parameter name to tensor; must contain every name in ``pre``.

    Returns:
        The distance as a Python float (``0.0`` when ``pre`` is empty).

    Raises:
        ValueError: If a name from ``pre`` is missing in ``post`` or the two
            tensors stored under a name have different shapes.
    """
    squared_total = 0.0
    for key in pre:
        if key not in post:
            raise ValueError(f"Parameter '{key}' is missing in the 'post' parameters.")
        before, after = pre[key], post[key]
        if before.shape != after.shape:
            raise ValueError(f"Parameter '{key}' has different shapes in 'pre' and 'post'.")
        # Accumulate this tensor's squared difference as a Python float.
        delta = before - after
        squared_total += torch.sum(delta ** 2).item()
    # Square root of the summed squares gives the overall distance.
    return squared_total ** 0.5

dc = DataCollector()

# Walk every collected record; keys are unpacked as (epoch, batch, block, depth).
for key in dc.data.keys():
    epoch, batch, block, depth = key
    record = dc.data[epoch, batch, block, depth]
    pre = record['pre']
    post = record['post']

    # Currently prints the raw 'pre' conv1 weights for each record.
    print(pre['conv1.weight'])

    # Disabled alternative: print the pre/post parameter distance per record.
    # print(epoch, batch, block, depth, parameters_distance(pre, post))