In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR

from fl_g13.editing import SparseSGDM
from fl_g13.editing import create_gradiend_mask
from fl_g13.editing import fisher_scores
from fl_g13.modeling import eval

In [4]:
from torchvision import transforms
from fl_g13.fl_pytorch.datasets import load_datasets

# Data-loading settings; every one of these is forwarded to load_datasets() below.
partition_id = 1
num_partitions = 100
partition_type = 'iid'
batch_size = 128
# NOTE(review): presumably only relevant for shard-based (non-IID) partitioning — confirm.
num_shards_per_partition = 6
train_test_split_ratio = 0.2


def get_transforms():
    """Return a function that applies the standard transformations to a batch.

    The returned callable expects a mapping with an ``"img"`` sequence and a
    ``"fine_label"`` sequence. It overwrites both entries in place -- every
    image is passed through ``transforms.ToTensor()`` and every label is cast
    to ``int`` -- and returns that same mapping.
    """
    # Build the pipeline once per factory call rather than once per batch.
    to_tensor = transforms.Compose([
        transforms.ToTensor()
    ])

    def apply_transforms(batch):
        batch["img"] = [to_tensor(img) for img in batch["img"]]
        batch["fine_label"] = [int(lbl) for lbl in batch["fine_label"]]

        return batch

    return apply_transforms


# Build the train/validation loaders for the partition configured above.
# NOTE(review): `transform` receives the factory `get_transforms` itself, not
# the callable returned by `get_transforms()` — presumably load_datasets invokes
# it internally; confirm against its implementation.
trainloader, valloader = load_datasets(
    partition_id,
    num_partitions,
    partition_type=partition_type,
    batch_size=batch_size,
    num_shards_per_partition=num_shards_per_partition,
    train_test_split_ratio=train_test_split_ratio,
    transform=get_transforms
)

In [5]:
# Size of the training loader (presumably the number of batches — TODO confirm).
len(trainloader)

4

In [6]:
from pathlib import Path

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint directory, expressed relative to the notebook's working directory.
current_path = Path.cwd()
model_test_path = current_path / ".." / "models" / "model_test"
model_test_path.resolve()


WindowsPath('C:/Users/ADMIN/Desktop/BACKUP/study/Italy/polito/classes/20242/deep learning/project/source_code/fl-g13/models/model_test')

In [7]:

from fl_g13.fl_pytorch.model import TinyCNN
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim import SGD
from fl_g13.modeling import load_or_create

BATCH_SIZE = 128
LR = 1e-3
checkpoint_dir = model_test_path.resolve()
model_class = TinyCNN

# Placeholder model, optimizer and scheduler: they exist so that load_or_create()
# can be handed optimizer and scheduler objects.
# NOTE(review): this optimizer is bound to the parameters of the placeholder
# model, not to the model returned by load_or_create(); both the optimizer and
# the scheduler are replaced further down.
model = model_class()
optimizer = SGD(model.parameters(), lr=LR)
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=8,  # First restart after 8 epochs
    T_mult=2,  # Double the interval between restarts each time
    eta_min=1e-5  # Minimum learning rate after annealing
)

# Loss function, shared by the Fisher-score, evaluation and training cells.
criterion = CrossEntropyLoss()

# Load the model from a checkpoint found in checkpoint_dir (load_or_create
# presumably builds a fresh model when no checkpoint exists — confirm).
model, start_epoch = load_or_create(
    path=checkpoint_dir,
    model_class=model_class,
    device=device,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
)
model.to(device)

# Create a dummy (all-ones) mask for SparseSGDM
mask = [torch.ones_like(p, device=p.device) for p in
        model.parameters()]  # Must be done AFTER the model is moved to the device

# Optimizer and scheduler bound to the parameters of the loaded model.
optimizer = SparseSGDM(
    model.parameters(),
    mask=mask,
    lr=LR,
    momentum=0.9,
    weight_decay=1e-5
)
scheduler = CosineAnnealingLR(
    optimizer=optimizer,
    T_max=8,
    eta_min=1e-5
)



🔍 Loading checkpoint from C:\Users\ADMIN\Desktop\BACKUP\study\Italy\polito\classes\20242\deep learning\project\source_code\fl-g13\models\model_test\FL_TinyCNN_epoch_8.pth
📦 Model class in checkpoint: TinyCNN
✅ Loaded checkpoint from C:\Users\ADMIN\Desktop\BACKUP\study\Italy\polito\classes\20242\deep learning\project\source_code\fl-g13\models\model_test\FL_TinyCNN_epoch_8.pth, resuming at epoch 9


In [8]:
# Iterate once over the validation loader, moving each batch to `device`.
# Nothing is computed from the batches; after the loop X and y simply hold the
# last batch (on `device`).
for batch_idx, (X, y) in enumerate(valloader):
    X, y = X.to(device), y.to(device)

In [9]:
# Compute Fisher scores on the validation loader, then derive the gradient mask.
# NOTE: this rebinds `mask` — previously the all-ones list — to the value
# returned by create_gradiend_mask (a dict keyed by parameter name, per the
# output displayed in the next cell).
scores = fisher_scores(
    dataloader=valloader,
    model=model,
    loss_fn=criterion,
    verbose=1,
)
mask = create_gradiend_mask(
    class_score=scores,
    sparsity=0.2,
    mask_type='global',
)

Fisher Score: 100%|██████████| 1/1 [00:00<00:00,  5.70batch/s]


In [10]:
# Inspect the generated mask.
# NOTE(review): this prints every mask tensor in full; a per-parameter summary
# (shape / number of non-zero entries) would keep the output readable.
mask

{'conv1.weight': tensor([[[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           

In [11]:
# Evaluate the model on the validation loader before running the editing step.
# NOTE(review): `eval` is the helper imported from fl_g13.modeling; it shadows
# the Python builtin of the same name for the rest of the notebook.
eval(dataloader=valloader, model=model, criterion=criterion)

Eval progress: 100%|██████████| 1/1 [00:00<00:00, 55.35batch/s]


(2.9908766746520996, 0.2, [2.9908766746520996])

In [12]:
from fl_g13.editing import mask_dict_to_list

# Convert the name-keyed mask into the list form passed to the optimizer
# (presumably ordered like model.parameters() — TODO confirm).
mask_list = mask_dict_to_list(model, mask)

# Rebuild the sparse optimizer so that it uses the Fisher-derived mask.
sparse_sgdm_options = dict(lr=LR, momentum=0.9, weight_decay=1e-5)
optimizer = SparseSGDM(model.parameters(), mask=mask_list, **sparse_sgdm_options)

In [13]:
# Inspect the list form of the mask that was handed to the optimizer.
mask_list

[tensor([[[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
  

In [14]:
from fl_g13.modeling import train
# Store the edited model's checkpoints in an "edit" sub-directory. Joining with
# pathlib avoids the mixed "/" and "\" separators that the previous f-string
# produced on Windows; the result is converted back to str, as before.
checkpoint_dir_edit = str(checkpoint_dir / "edit")
name = "model_editing"
epochs = 10

# Train with the masked optimizer. The four values returned by train() are not
# needed here, so they are discarded.
_, _, _, _ = train(
    checkpoint_dir=checkpoint_dir_edit,
    name=name,
    start_epoch=1,  # The editing run uses its own epoch count, starting at 1
    num_epochs=epochs,
    save_every=epochs,  # Same value as num_epochs
    backup_every=None,
    train_dataloader=trainloader,
    val_dataloader=None,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=None,  # No scheduler needed, too few epochs
    verbose=1
)

Prefix/name for the model was provided: model_editing



Training progress: 100%|██████████| 4/4 [00:00<00:00, 45.30batch/s]


🚀 Epoch 1/10 (10.00%) Completed
	📊 Training Loss: 3.3045
	✅ Training Accuracy: 26.00%
	⏳ Elapsed Time: 0.09s | ETA: 0.82s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 61.11batch/s]


🚀 Epoch 2/10 (20.00%) Completed
	📊 Training Loss: 3.0909
	✅ Training Accuracy: 26.00%
	⏳ Elapsed Time: 0.07s | ETA: 0.54s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 59.38batch/s]


🚀 Epoch 3/10 (30.00%) Completed
	📊 Training Loss: 3.1011
	✅ Training Accuracy: 26.00%
	⏳ Elapsed Time: 0.07s | ETA: 0.49s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 58.50batch/s]


🚀 Epoch 4/10 (40.00%) Completed
	📊 Training Loss: 3.1835
	✅ Training Accuracy: 26.00%
	⏳ Elapsed Time: 0.07s | ETA: 0.42s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 59.45batch/s]


🚀 Epoch 5/10 (50.00%) Completed
	📊 Training Loss: 3.1030
	✅ Training Accuracy: 26.00%
	⏳ Elapsed Time: 0.07s | ETA: 0.35s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 64.72batch/s]


🚀 Epoch 6/10 (60.00%) Completed
	📊 Training Loss: 3.2345
	✅ Training Accuracy: 25.75%
	⏳ Elapsed Time: 0.06s | ETA: 0.26s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 60.61batch/s]


🚀 Epoch 7/10 (70.00%) Completed
	📊 Training Loss: 3.1307
	✅ Training Accuracy: 25.75%
	⏳ Elapsed Time: 0.07s | ETA: 0.20s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 65.54batch/s]


🚀 Epoch 8/10 (80.00%) Completed
	📊 Training Loss: 3.1188
	✅ Training Accuracy: 25.75%
	⏳ Elapsed Time: 0.06s | ETA: 0.13s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 63.04batch/s]


🚀 Epoch 9/10 (90.00%) Completed
	📊 Training Loss: 3.0932
	✅ Training Accuracy: 25.75%
	⏳ Elapsed Time: 0.07s | ETA: 0.07s
	🕒 Completed At: 10:38



Training progress: 100%|██████████| 4/4 [00:00<00:00, 63.58batch/s]

🚀 Epoch 10/10 (100.00%) Completed
	📊 Training Loss: 3.2409
	✅ Training Accuracy: 25.75%
	⏳ Elapsed Time: 0.07s | ETA: 0.00s
	🕒 Completed At: 10:38

💾 Saved checkpoint at: C:\Users\ADMIN\Desktop\BACKUP\study\Italy\polito\classes\20242\deep learning\project\source_code\fl-g13\models\model_test/edit\TinyCNN\model_editing_TinyCNN_epoch_10.pth
💾 Saved losses and accuracies (training and validation) at: C:\Users\ADMIN\Desktop\BACKUP\study\Italy\polito\classes\20242\deep learning\project\source_code\fl-g13\models\model_test/edit\TinyCNN\model_editing_TinyCNN_epoch_10.loss_acc.json






In [15]:
# Re-evaluate on the validation loader after the editing run, for comparison
# with the result obtained before editing.
eval(dataloader=valloader, model=model, criterion=criterion)

Eval progress: 100%|██████████| 1/1 [00:00<00:00, 62.50batch/s]


(2.991311550140381, 0.21, [2.991311550140381])

In [16]:
# Display the mask list again, now that training has finished.
mask_list

[tensor([[[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
  

In [24]:
from fl_g13.editing.masking import compress_mask_sparse

# Encode the mask list as a bytes payload and display it.
# NOTE(review): the displayed payload starts with b'\x80\x04\x95', which looks
# like a pickle protocol-4 header — presumably the payload is pickled, so only
# decode payloads that come from a trusted source; confirm in masking.py.
compressed = compress_mask_sparse(mask_list)
compressed

b'\x80\x04\x95\x03\x00\x01\x00\x00\x00\x00\x00]\x94(]\x94(K\x10K\x03K\x03K\x03e]\x94M~\x01a\x86\x94]\x94K\x10a]\x94\x86\x94]\x94(K K\x10K\x03K\x03e]\x94(K\nK\x11K\x1bK\x1cK\x1dK<KPKeKiKrK\xdbM\x8c\x01M\x92\x01M\xf5\x01M\xf6\x01M\x0c\x02M\x0e\x02M\x18\x02M\x1b\x02M%\x02M(\x02M\x1e\x03M6\x03ME\x03MG\x03MH\x03Mr\x03M/\x05MX\x05M^\x05M`\x05Mu\x05M\xbe\x05M\xbf\x05M\xd9\x05M\xdc\x05M\x0b\x06M4\x06MO\x06Mt\x06M\x7f\x06M\x88\x06M\xe0\x06M\xf4\x06M\xfe\x06M\x11\x07M\x15\x07M\x17\x07M\x1e\x07M \x07M#\x07M$\x07M&\x07M\'\x07M(\x07M\xb7\x07M\xf7\x07MG\x08M\x9d\x08M\xd9\x08M\xdb\x08M\xf1\tM*\nM=\nMB\nMM\nMi\nMq\nMr\nMt\nMu\nMy\nM\x89\nM\x8a\nM\x8b\nM\x98\nM\x99\nM\x9a\nM\x9e\nM\xe2\x0bM3\x0cMC\x0cM\xec\x0cMo\rMu\rM\x95\rM\x9a\rM\xca\rM\xd0\rM\xf3\rM\xf4\rM\xfe\rM\x00\x0eM\x02\x0eM\x04\x0eM\x05\x0eM"\x0eMS\x0eMT\x0eM]\x0eM`\x0eM\xb9\x0eM\xfd\x0eM\xfe\x0eM\xff\x0eM\x03\x0fM\x04\x0fM\x0b\x0fMK\x0fMx\x0fM\x98\x0fM?\x10Mf\x10Mp\x10M\x82\x10M\x9a\x10M\xb6\x10M\xc3\x10M\x82\x11M\x84\x11M\x87\x11e\x86\x94]

In [25]:
from fl_g13.editing.masking import uncompress_mask_sparse
# Decode the payload back into a list of mask tensors, passing `device` through,
# and display the result for a visual round-trip check.
uncompress_mask_sparse(compressed, device=device)

[tensor([[[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],
 
          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 0., 0.],
  