In [1]:
from src.collectors import DataCollector

# Module-level collector shared with train() / one_epoch() below, which write the
# current epoch and batch index onto it. The experiment loop further down rebinds
# this name to a fresh DataCollector before every run.
# NOTE(review): {'metric': 'minkowski', 'p': 1} presumably selects an L1 (Manhattan)
# distance inside DataCollector -- confirm against src.collectors.
COLLECTOR = DataCollector(metric_kwargs={'metric': 'minkowski', 'p': 1})
# Number of epochs handed to train() for every experiment.
NUM_EPOCHS = 5

In [2]:
# After https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
# and https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#finetuning-the-convnet

from tqdm import tqdm

import torch


def train(model, loaders, criterion, optimizer, scheduler, num_epochs=5, device=torch.device('cpu')):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        model: Network to optimise; moved to ``device`` before the first epoch.
        loaders: Mapping with ``'train'`` and ``'valid'`` entries, each an
            iterable of batches accepted by ``one_epoch``.
        criterion: Loss callable forwarded to ``one_epoch``.
        optimizer: Optimizer forwarded to ``one_epoch``.
        scheduler: Learning-rate scheduler forwarded to ``one_epoch`` (may be ``None``).
        num_epochs: Number of train + validation passes to run.
        device: Device used for the model and every batch.

    Returns:
        ``(losses, accs)``: two dicts keyed by ``'train'`` / ``'valid'``, each
        holding one value per epoch in chronological order.

    Side effects:
        Writes the zero-based epoch index to the module-level
        ``COLLECTOR.current_epoch`` and prints one summary line per phase.
    """
    model.to(device)

    phases = ('train', 'valid')
    losses = {phase: [] for phase in phases}
    accs = {phase: [] for phase in phases}

    for epoch in range(num_epochs):
        # Tell the global collector which epoch subsequent batches belong to.
        COLLECTOR.current_epoch = epoch

        # Run both phases first so the two summary lines are printed together.
        epoch_stats = []
        for phase in phases:
            phase_loss, phase_acc = one_epoch(
                phase, model, loaders[phase], criterion, scheduler, optimizer, epoch, device)
            epoch_stats.append((phase, phase_loss, phase_acc))

        for phase, phase_loss, phase_acc in epoch_stats:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {phase_loss:.4f}, Accuracy: {phase_acc}")
            losses[phase].append(phase_loss)
            accs[phase].append(phase_acc)

    return losses, accs

def one_epoch(phase, model, loader, criterion, scheduler, optimizer, epoch_idx, device):
    """Run one full pass over ``loader`` in training or validation mode.

    Args:
        phase: ``'train'`` (gradients enabled, weights updated) or ``'valid'``
            (forward passes only).
        model: Network to run; switched to ``train()`` / ``eval()`` mode to
            match ``phase``.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        scheduler: Optional learning-rate scheduler, or ``None``. It is stepped
            once per *training epoch* (after the batch loop), which is what
            epoch-based schedulers such as ``StepLR`` expect.
        optimizer: Optimizer stepped after every training batch.
        epoch_idx: Zero-based epoch index. Not used inside this function; kept
            so the call signature stays stable.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)`` as plain Python floats: the sample-weighted
        mean loss and the fraction of correctly classified samples.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'valid'``.
        ZeroDivisionError: If ``loader`` yields no samples.

    Side effects:
        Writes the index of the batch being processed to the module-level
        ``COLLECTOR.current_batch``.
    """
    if phase == 'train':
        model.train()
    elif phase == 'valid':
        model.eval()
    else:
        raise ValueError(
            f"Phase {phase} is not a proper learning phase (use 'train' or 'valid')!")

    is_training = phase == 'train'

    running_loss = 0.0
    running_corrects = 0
    dataset_size = 0

    for batch, (inputs, labels) in enumerate(tqdm(loader)):
        COLLECTOR.current_batch = batch
        inputs, labels = inputs.to(device), labels.to(device)

        # Gradients are only produced while training, so only reset them there.
        if is_training:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_training):
            outputs = model(inputs)
            _, predictions = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

        batch_size = inputs.size(0)
        # criterion is assumed to return a per-batch mean, hence the re-weighting
        # by batch size so a smaller final batch does not skew the epoch average.
        running_loss += loss.item() * batch_size
        # .item() keeps the counter a Python int (no NumPy scalar leaking out).
        running_corrects += (predictions == labels).sum().item()
        dataset_size += batch_size

    # Step the scheduler once per epoch, after all optimizer updates. Stepping it
    # inside the batch loop would make e.g. StepLR(step_size=7) decay the learning
    # rate every 7 batches instead of every 7 epochs.
    if is_training and scheduler is not None:
        scheduler.step()

    epoch_loss = running_loss / dataset_size
    epoch_acc = running_corrects / dataset_size

    return epoch_loss, epoch_acc

In [3]:
from src.modules.blocks import BasicBlock, Bottleneck, PlainBasicBlock, PlainBottleneck
from src.modules.models import ResNet

def resnet(version, in_channels, num_classes, plain=False):
    """Build a ResNet of the requested size together with a display name.

    Args:
        version: Network size; one of 18, 26, 34, 50, 68 or 101.
        in_channels: Number of input channels forwarded to ``ResNet``.
        num_classes: Number of output classes forwarded to ``ResNet``.
        plain: When truthy, use the ``Plain*`` block variants and tag the
            name with ``[Plain]``.

    Returns:
        ``(name, model)`` where ``name`` is ``"ResNet-<version>"`` (optionally
        suffixed with ``" [Plain]"``) and ``model`` is the ``ResNet`` instance.

    Raises:
        ValueError: If ``version`` is not one of the supported sizes.
    """
    # (version, blocks per stage, built from bottleneck blocks?)
    catalogue = (
        (18, [2, 2, 2, 2], False),
        (26, [2, 2, 2, 2], True),
        (34, [3, 4, 6, 3], False),
        (50, [3, 4, 6, 3], True),
        (68, [3, 4, 23, 3], False),
        (101, [3, 4, 23, 3], True),
    )

    spec = next((entry for entry in catalogue if version == entry[0]), None)
    if spec is None:
        raise ValueError(f'Version {version} is not a valid ResNet size.')

    _, layers, uses_bottleneck = spec
    if uses_bottleneck:
        block = PlainBottleneck if plain else Bottleneck
    else:
        block = PlainBasicBlock if plain else BasicBlock

    name = f"ResNet-{version} [Plain]" if plain else f"ResNet-{version}"

    model = ResNet(block=block, in_channels=in_channels, layers=layers, num_classes=num_classes)

    return name, model

In [4]:
import os
import matplotlib.pyplot as plt
import matplotlib.colorbar as colorbar
import numpy as np

import torch.nn as nn
import torch.optim as optim

from src.datasets import cifar100, cifar10, fashion_mnist, mnist

plt.style.use('ggplot')

# Prefer Apple's MPS backend (the original target) but fall back to CUDA or CPU so
# the notebook does not crash on machines where "mps" is unavailable.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

loaders = {}

# Full sweep: every dataset x every ResNet size x {plain, residual}.
for dataset in [cifar10, cifar100, fashion_mnist, mnist]:
    train_loader, valid_loader = dataset()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader

    # Dataset properties are derived from the loader function's name:
    # only cifar100 gets 100 classes, only the CIFAR sets get 3 input channels.
    num_classes = 100 if dataset.__name__ in ['cifar100'] else 10
    in_channels = 3 if dataset.__name__ in ['cifar10', 'cifar100'] else 1

    for version in [18, 26, 34, 50, 68, 101]:
        for plain in [True, False]:

            model_name, model = resnet(version, in_channels, num_classes, plain)
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(model.parameters())
            scheduler = None # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

            # Start every run from a clean collector; train()/one_epoch() read this
            # module-level name, so rebinding it here is what they will see.
            DataCollector.reset_instances()
            COLLECTOR = DataCollector(metric_kwargs={'metric': 'minkowski', 'p': 1})

            losses, accs = train(model, loaders, criterion, optimizer, scheduler, NUM_EPOCHS, device)

            experiment_path = os.path.join('experiments', dataset.__name__, model_name)
            COLLECTOR.save_state(experiment_path)

            fig, axes = plt.subplots(2, 2, figsize=(16, 16))
            fig.suptitle(f"Dataset: {dataset.__name__}, Model: {model_name}")
            (ax_init, ax_next), (ax_loss, ax_acc) = axes

            # One normalisation shared by the line colours and the colourbar ticks:
            # depth 1 -> 0.0 and depth max_depth -> 1.0 (guarded for max_depth == 1).
            # NOTE(review): assumes collector depths are the integers 1..max_depth,
            # as the colourbar tick layout implies -- confirm against DataCollector.
            max_depth = model.max_depth
            depth_span = max(max_depth - 1, 1)

            ax_init.set_title("Init Differences")
            for depth, differences in COLLECTOR.init_differences.items():
                ax_init.plot(differences, label=f"depth = {depth}",
                             color=plt.cm.coolwarm((depth - 1) / depth_span))
            ax_init.set_xlabel('Batch number')
            ax_init.set_ylabel('Distance to initial state')

            ax_next.set_title("Next Differences")
            for depth, differences in COLLECTOR.next_differences.items():
                ax_next.plot(differences, label=f"depth = {depth}",
                             color=plt.cm.coolwarm((depth - 1) / depth_span))
            ax_next.set_yscale('log')
            ax_next.set_xlabel('Batch number')
            ax_next.set_ylabel('Distance to previous state')

            ax_loss.set_title("Losses")
            ax_loss.plot(range(1, len(losses['train']) + 1), losses['train'], label='train')
            ax_loss.plot(range(1, len(losses['valid']) + 1), losses['valid'], label='valid')
            ax_loss.set_xlabel('Epoch number')
            ax_loss.set_ylabel('Error')
            ax_loss.legend()

            ax_acc.set_title("Accuracies")
            ax_acc.plot(range(1, len(accs['train']) + 1), accs['train'], label='train')
            ax_acc.plot(range(1, len(accs['valid']) + 1), accs['valid'], label='valid')
            # Label the accuracy panel itself (previously these two calls targeted
            # the loss panel and overwrote its 'Error' label).
            ax_acc.set_xlabel('Epoch number')
            ax_acc.set_ylabel('Accuracy')
            ax_acc.legend()

            # Depth colourbar placed next to the top row of panels.
            cbar_ax = fig.add_axes([0.95, 0.53, 0.02, 0.35])  # [left, bottom, width, height]
            cbar = colorbar.ColorbarBase(cbar_ax, cmap=plt.cm.coolwarm, orientation='vertical')
            depth_ticks = np.arange(1, max_depth + 1)
            cbar.set_ticks((depth_ticks - 1) / depth_span)
            cbar.set_ticklabels([str(d) for d in depth_ticks])
            cbar.set_label('Depth')

            # Make sure the target directory exists even if save_state() did not create it.
            os.makedirs(experiment_path, exist_ok=True)
            fig.savefig(os.path.join(experiment_path, 'charts.png'))
            plt.show()
            # Release the figure explicitly; 48 runs would otherwise pile up open
            # figures on backends that do not close them on show().
            plt.close(fig)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 49/49 [00:21<00:00,  2.31it/s]
100%|██████████| 10/10 [00:01<00:00,  7.90it/s]


Epoch [1/5], Loss: 1.6368, Accuracy: 0.4045
Epoch [1/5], Loss: 1.6503, Accuracy: 0.4399


 33%|███▎      | 16/49 [00:06<00:14,  2.34it/s]