In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from fl_g13.config import RAW_DATA_DIR
from torchvision import datasets, transforms

from fl_g13.modeling import train, train_one_epoch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

[32m2025-04-16 15:34:19.456[0m | [1mINFO    [0m | [36mfl_g13.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/massimiliano/Projects/fl-g13[0m


### Load data

In [3]:
transform = transforms.Compose([
    transforms.ToTensor()
])
cifar100_train = datasets.CIFAR100(root=RAW_DATA_DIR, train=True, download=True, transform=transform)
cifar100_test = datasets.CIFAR100(root=RAW_DATA_DIR, train=False, download=True, transform=transform)

### Train and save model

In [4]:
class TinyCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(TinyCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))     # -> [B, 16, 32, 32]
        x = F.max_pool2d(x, 2)        # -> [B, 16, 16, 16]
        x = F.relu(self.conv2(x))     # -> [B, 32, 16, 16]
        x = F.max_pool2d(x, 2)        # -> [B, 32, 8, 8]
        x = x.view(x.size(0), -1)     # -> [B, 32*8*8]
        x = self.fc1(x)               # -> [B, 100]
        return x

In [5]:
checkpoint_dir = "/home/massimiliano/Projects/fl-g13/checkpoints"

# Parameters
batch_size  = 32
start_epoch = 1
num_epochs  = 2
save_every  = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataloader = torch.utils.data.DataLoader(cifar100_train, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(cifar100_test, batch_size=batch_size, shuffle=True)

model = TinyCNN(100)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.04)

loss_fn = torch.nn.CrossEntropyLoss()

In [6]:
train(checkpoint_dir, train_dataloader, loss_fn, start_epoch, num_epochs, save_every, model, optimizer, scheduler=None, prefix=None, verbose=False)

Training Loss: 4.0770, Training Accuracy: 9.37%
📘 Epoch [1/2] - Avg Loss: 4.0770, Accuracy: 9.37%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/zesty_koala_60_epoch_1.pth
Training Loss: 3.6234, Training Accuracy: 17.13%
📘 Epoch [2/2] - Avg Loss: 3.6234, Accuracy: 17.13%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/zesty_koala_60_epoch_2.pth


In [7]:
train(checkpoint_dir, train_dataloader, loss_fn, start_epoch, num_epochs, save_every, model, optimizer, scheduler=None, prefix="TinyCNN", verbose=False)

Training Loss: 3.4353, Training Accuracy: 20.79%
📘 Epoch [1/2] - Avg Loss: 3.4353, Accuracy: 20.79%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_1.pth
Training Loss: 3.2979, Training Accuracy: 23.13%
📘 Epoch [2/2] - Avg Loss: 3.2979, Accuracy: 23.13%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_2.pth


**Resume training**

In [9]:
from fl_g13.modeling import load

# Load the model from the latest checkpoint
model2 = TinyCNN(num_classes=100)
optimizer2 = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.04)
loss_fn2 = torch.nn.CrossEntropyLoss()

start_epoch = load(checkpoint_dir, model=model2, optimizer=optimizer2, filename="TinyCNN_epoch_2.pth")

✅ Loaded checkpoint from /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_2.pth, resuming at epoch 3


In [10]:
num_epochs = 4
save_every = 2

train(checkpoint_dir, train_dataloader, loss_fn2, start_epoch, num_epochs, save_every, model2, optimizer2, scheduler=None, prefix="TinyCNN", verbose=False)

Training Loss: 3.2033, Training Accuracy: 25.30%
📘 Epoch [1/4] - Avg Loss: 3.2033, Accuracy: 25.30%
Training Loss: 3.2034, Training Accuracy: 25.30%
📘 Epoch [2/4] - Avg Loss: 3.2034, Accuracy: 25.30%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_4.pth
Training Loss: 3.2034, Training Accuracy: 25.30%
📘 Epoch [3/4] - Avg Loss: 3.2034, Accuracy: 25.30%
Training Loss: 3.2033, Training Accuracy: 25.30%
📘 Epoch [4/4] - Avg Loss: 3.2033, Accuracy: 25.30%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_6.pth
