In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from fl_g13.config import RAW_DATA_DIR
from torchvision import datasets, transforms

from fl_g13.modeling import train
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

[32m2025-04-17 11:11:34.164[0m | [1mINFO    [0m | [36mfl_g13.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/massimiliano/Projects/fl-g13[0m


### Load data

In [3]:
transform = transforms.Compose([
    transforms.ToTensor()
])
cifar100_train = datasets.CIFAR100(root=RAW_DATA_DIR, train=True, download=True, transform=transform)
cifar100_test = datasets.CIFAR100(root=RAW_DATA_DIR, train=False, download=True, transform=transform)

### Train and save model

In [4]:
class TinyCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(TinyCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))     # -> [B, 16, 32, 32]
        x = F.max_pool2d(x, 2)        # -> [B, 16, 16, 16]
        x = F.relu(self.conv2(x))     # -> [B, 32, 16, 16]
        x = F.max_pool2d(x, 2)        # -> [B, 32, 8, 8]
        x = x.view(x.size(0), -1)     # -> [B, 32*8*8]
        x = self.fc1(x)               # -> [B, 100]
        return x

In [5]:
checkpoint_dir = "/home/massimiliano/Projects/fl-g13/checkpoints"

# Parameters
batch_size  = 32
start_epoch = 1
num_epochs  = 2
save_every  = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataloader = torch.utils.data.DataLoader(cifar100_train, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(cifar100_test, batch_size=batch_size, shuffle=True)

model = TinyCNN(100)
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.04)
criterion = torch.nn.CrossEntropyLoss()

In [6]:
train(
    checkpoint_dir=checkpoint_dir,
    prefix="", # Will automatically generate a name for the model
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    criterion=criterion,
    start_epoch=start_epoch,
    num_epochs=num_epochs,
    save_every=save_every,
    model=model,
    optimizer=optimizer,
    scheduler=None,
    verbose=True,
)

No prefix/name for the model was provided, choosen prefix/name: frosty_metapod_93
  ↳ Batch 1/1563 | Loss: 4.5942 | Batch Acc: 0.00%
  ↳ Batch 2/1563 | Loss: 4.6098 | Batch Acc: 3.12%
  ↳ Batch 3/1563 | Loss: 4.5891 | Batch Acc: 0.00%
  ↳ Batch 4/1563 | Loss: 4.6041 | Batch Acc: 0.00%
  ↳ Batch 5/1563 | Loss: 4.6187 | Batch Acc: 0.00%
  ↳ Batch 6/1563 | Loss: 4.5993 | Batch Acc: 3.12%
  ↳ Batch 7/1563 | Loss: 4.6143 | Batch Acc: 0.00%
  ↳ Batch 8/1563 | Loss: 4.5925 | Batch Acc: 0.00%
  ↳ Batch 9/1563 | Loss: 4.5947 | Batch Acc: 3.12%
  ↳ Batch 10/1563 | Loss: 4.6255 | Batch Acc: 0.00%
  ↳ Batch 11/1563 | Loss: 4.6029 | Batch Acc: 3.12%
  ↳ Batch 12/1563 | Loss: 4.6241 | Batch Acc: 0.00%
  ↳ Batch 13/1563 | Loss: 4.6294 | Batch Acc: 0.00%
  ↳ Batch 14/1563 | Loss: 4.5799 | Batch Acc: 3.12%
  ↳ Batch 15/1563 | Loss: 4.5851 | Batch Acc: 0.00%
  ↳ Batch 16/1563 | Loss: 4.6413 | Batch Acc: 3.12%
  ↳ Batch 17/1563 | Loss: 4.5890 | Batch Acc: 0.00%
  ↳ Batch 18/1563 | Loss: 4.6231 | Batch Ac

In [7]:
train(
    checkpoint_dir=checkpoint_dir,
    prefix="TinyCNN", # Setting a name for the model
    start_epoch=start_epoch,
    num_epochs=num_epochs,
    save_every=save_every,
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    model=model, # Use the same model as before (partially pre-trained)
    criterion=criterion,
    optimizer=optimizer,
    scheduler=None,
    verbose=False,
)

🚀 Epoch [1/2] Completed (50.00)
	📊 Training Loss: 3.4657
	✅ Training Accuracy: 20.12%
🔍 Validation Results:
	📉 Validation Loss: 3.4276
	🎯 Validation Accuracy: 20.64%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_1.pth
🚀 Epoch [2/2] Completed (100.00)
	📊 Training Loss: 3.3195
	✅ Training Accuracy: 22.76%
🔍 Validation Results:
	📉 Validation Loss: 3.3286
	🎯 Validation Accuracy: 22.59%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_2.pth


**Resume training**

In [8]:
from fl_g13.modeling import load

# Generate untrained objects
model2 = TinyCNN(num_classes=100)
optimizer2 = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.04)
criterion2 = torch.nn.CrossEntropyLoss()

# Load the model from the latest checkpoint
path = checkpoint_dir + "/TinyCNN_epoch_2.pth"
start_epoch = load(path=path, model=model2, optimizer=optimizer2, scheduler=None)

✅ Loaded checkpoint from /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_2.pth, resuming at epoch 3


In [9]:
num_epochs = 4
save_every = 2

train(
    checkpoint_dir=checkpoint_dir,
    prefix="TinyCNN", # Use the same name as before to continue training!
    start_epoch=start_epoch, # Now start epoch is not 1 (will resume from where it was left)
    num_epochs=num_epochs, # This is not the number of epochs to reach, but how many to do starting from now!
    save_every=save_every,
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    model=model2,
    criterion=criterion2,
    optimizer=optimizer2,
    scheduler=None,
    verbose=False,
)

🚀 Epoch [1/4] Completed (25.00)
	📊 Training Loss: 3.2382
	✅ Training Accuracy: 24.34%
🔍 Validation Results:
	📉 Validation Loss: 3.3293
	🎯 Validation Accuracy: 22.59%
🚀 Epoch [2/4] Completed (50.00)
	📊 Training Loss: 3.2379
	✅ Training Accuracy: 24.34%
🔍 Validation Results:
	📉 Validation Loss: 3.3294
	🎯 Validation Accuracy: 22.59%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_4.pth
🚀 Epoch [3/4] Completed (75.00)
	📊 Training Loss: 3.2381
	✅ Training Accuracy: 24.34%
🔍 Validation Results:
	📉 Validation Loss: 3.3295
	🎯 Validation Accuracy: 22.59%
🚀 Epoch [4/4] Completed (100.00)
	📊 Training Loss: 3.2381
	✅ Training Accuracy: 24.34%
🔍 Validation Results:
	📉 Validation Loss: 3.3287
	🎯 Validation Accuracy: 22.59%
💾 Saved checkpoint at: /home/massimiliano/Projects/fl-g13/checkpoints/TinyCNN_epoch_6.pth
