In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

import numpy as np

import json

from fl_g13.config import RAW_DATA_DIR, PROJ_ROOT

from fl_g13.modeling import train, load, plot_metrics, get_preprocessing_pipeline

from fl_g13.architectures import BaseDino

from fl_g13.editing import SparseSGDM
from fl_g13.editing import per_class_accuracy
from fl_g13.editing import create_mask, mask_dict_to_list

[32m2025-05-20 08:55:43.044[0m | [1mINFO    [0m | [36mfl_g13.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /content/fl-g13[0m


In [4]:
train_dataset, val_dataset, test_dataset = get_preprocessing_pipeline(RAW_DATA_DIR)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

100%|██████████| 169M/169M [00:12<00:00, 13.3MB/s]


Train dataset size: 40000
Validation dataset size: 10000
Test dataset size: 10000


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints'
model_name = 'arcanine'
model_checkpoint_path = f'{CHECKPOINT_DIR}/Editing/{model_name}.pth'
model_metrics_path = f'{CHECKPOINT_DIR}/Editing/{model_name}.loss_acc.json'

# Hyper-parameters
# model
head_layers=3
head_hidden_size=512
dropout_rate=0.0
unfreeze_blocks=1

# Dataloaders
BATCH_SIZE = 128

# SparseSGDM optimizer
LR = 1e-3
momentum = .9
weight_decay = 1e-5

# scheduler
T_max = 8
eta_min = 1e-5

# Empty model
# Will be replaced with the already trained model from the checkpoint
model = BaseDino(
    head_layers=head_layers,
    head_hidden_size=head_hidden_size,
    dropout_rate=dropout_rate,
    unfreeze_blocks=unfreeze_blocks
)
model.to(device)

# Dataloaders
train_dataloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)
test_dataloader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False)

# Create a dummy mask for SparseSGDM
mask = [torch.ones_like(p, device = p.device) for p in model.parameters()] # Must be done AFTER the model is moved to the device
# Optimizer, scheduler, and loss function
optimizer = SparseSGDM(
    model.parameters(),
    mask = mask,
    lr = LR,
    momentum = momentum,
    weight_decay = weight_decay
)
scheduler = CosineAnnealingLR(
    optimizer = optimizer,
    T_max = T_max,
    eta_min = eta_min
)
criterion = CrossEntropyLoss()

# Load the model
model, _ = load(
    path = model_checkpoint_path,
    model_class = BaseDino,
    optimizer = optimizer,
    scheduler = scheduler,
    device = device
)
model.to(device) # manually move the model to the device

print(f'\nModel {model_name} loaded from checkpoint.')

Using device: cuda


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


✅ Loaded checkpoint from /content/drive/MyDrive/checkpoints/Editing/arcanine.pth, resuming at epoch 26

Model arcanine loaded from checkpoint.


In [7]:
# Compute test accuracy
# test_loss, test_accuracy, _ = eval(test_dataloader, model, criterion)
class_acc = per_class_accuracy(test_dataloader, model)
test_accuracy = np.mean(class_acc)

# print(f'Test loss: {test_loss:.3f}')
print(f'\nTest accuracy: {100*test_accuracy:.2f}%')

Per Class Accuracy: 100%|██████████| 79/79 [00:51<00:00,  1.53batch/s]

Test accuracy: 78.62%





In [8]:
fisher_dataloader = DataLoader(train_dataset, batch_size = 16, shuffle=True)

In [9]:
density = 0.2

global_calibr_mask = create_mask(fisher_dataloader, model, density = density, mask_type = 'global', rounds = 5)
local_calibr_mask = create_mask(fisher_dataloader, model, density = density, mask_type = 'local', rounds = 5)

Computing calibrated global mask for 5 rounds with initial target density (0.20).
Round 1/5.
	Target density 0.72%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:41<00:00, 15.50batch/s]


	Updating the mask
Round 2/5.
	Target density 0.53%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:36<00:00, 15.96batch/s]


	Updating the mask
Round 3/5.
	Target density 0.38%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:37<00:00, 15.91batch/s]


	Updating the mask
Round 4/5.
	Target density 0.28%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:36<00:00, 15.96batch/s]


	Updating the mask
Round 5/5.
	Target density 0.20%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:36<00:00, 15.96batch/s]


	Updating the mask
Computing calibrated local mask for 5 rounds with initial target density (0.20).
Round 1/5.
	Target density 0.72%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:37<00:00, 15.87batch/s]


	Updating the mask
Round 2/5.
	Target density 0.53%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:37<00:00, 15.89batch/s]


	Updating the mask
Round 3/5.
	Target density 0.38%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:36<00:00, 15.96batch/s]


	Updating the mask
Round 4/5.
	Target density 0.28%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:36<00:00, 15.94batch/s]


	Updating the mask
Round 5/5.
	Target density 0.20%
	Computing the masked fisher score


Fisher Score: 100%|██████████| 2500/2500 [02:37<00:00, 15.91batch/s]

	Updating the mask





In [10]:
def fine_tune(name, train_dataloader, mask, optimizer, scheduler, criterion, epochs = 10, verbose = 1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model
    new_model, _ = load(
        path = model_checkpoint_path,
        model_class = BaseDino,
        optimizer = optimizer,
        scheduler = scheduler,
        device = device
    )
    new_model.to(device) # manually move the model to the device

    # Create a new SparseSGDM optimizer
    new_optimizer = SparseSGDM(
        new_model.parameters(),
        mask = mask,
        lr = LR,
        momentum = momentum,
        weight_decay = weight_decay
    )

    _, _, _, _ = train(
        checkpoint_dir = CHECKPOINT_DIR,
        name = name,
        start_epoch = 1,
        num_epochs = epochs,
        save_every = epochs,
        backup_every = None,
        train_dataloader = train_dataloader,
        val_dataloader = None,
        model = new_model,
        criterion = criterion,
        optimizer = new_optimizer,
        scheduler = scheduler,
        verbose = verbose
    )

    # Compute per-class accuracy
    class_acc = per_class_accuracy(test_dataloader, new_model)

    return class_acc

In [11]:
# GLOBAL MASK
global_acc = fine_tune(
    name = f'{model_name}_ft_global_calibr',
    mask = mask_dict_to_list(model, global_calibr_mask),
    optimizer = optimizer,
    scheduler = scheduler,
    criterion = criterion,
    train_dataloader = train_dataloader
)

new_test_accuracy = np.mean(global_acc)
print(f'\nTest accuracy: {100*new_test_accuracy:.2f}% (original: {100*test_accuracy:.2f}%)')

count = sum([1 for i in range(len(global_acc)) if global_acc[i] < class_acc[i]])
print(f'Fine-tuned model is worse in {count} classes, wrt the original model')
# Save to file the per-class accuracy difference
# Create a dictionary with new_class_acc, class_acc, and class_idx
accuracy_data = {
    "class_idx": list(range(100)),
    "new_class_acc": list(global_acc),
    "class_acc": list(class_acc)
}
output_file = f"{CHECKPOINT_DIR}/Editing/{model_name}/accuracy_comparison_global_calibr.json"

# Save the dictionary to a JSON file
with open(output_file, "w") as json_file:
    json.dump(accuracy_data, json_file, indent=4)
print(f"Accuracy data saved to {output_file}")

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


✅ Loaded checkpoint from /content/drive/MyDrive/checkpoints/Editing/arcanine.pth, resuming at epoch 26
Prefix/name for the model was provided: arcanine_ft_global_calibr



Training progress: 100%|██████████| 313/313 [03:59<00:00,  1.31batch/s]


🚀 Epoch 1/10 (10.00%) Completed
	📊 Training Loss: 0.2137
	✅ Training Accuracy: 95.33%
	⏳ Elapsed Time: 239.23s | ETA: 2153.03s
	🕒 Completed At: 09:28



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 2/10 (20.00%) Completed
	📊 Training Loss: 0.1981
	✅ Training Accuracy: 95.55%
	⏳ Elapsed Time: 237.61s | ETA: 1900.92s
	🕒 Completed At: 09:32



Training progress: 100%|██████████| 313/313 [03:59<00:00,  1.31batch/s]


🚀 Epoch 3/10 (30.00%) Completed
	📊 Training Loss: 0.1806
	✅ Training Accuracy: 95.81%
	⏳ Elapsed Time: 239.20s | ETA: 1674.42s
	🕒 Completed At: 09:36



Training progress: 100%|██████████| 313/313 [03:59<00:00,  1.31batch/s]


🚀 Epoch 4/10 (40.00%) Completed
	📊 Training Loss: 0.1743
	✅ Training Accuracy: 95.84%
	⏳ Elapsed Time: 239.08s | ETA: 1434.48s
	🕒 Completed At: 09:40



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 5/10 (50.00%) Completed
	📊 Training Loss: 0.1614
	✅ Training Accuracy: 96.11%
	⏳ Elapsed Time: 237.90s | ETA: 1189.50s
	🕒 Completed At: 09:44



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 6/10 (60.00%) Completed
	📊 Training Loss: 0.1501
	✅ Training Accuracy: 96.29%
	⏳ Elapsed Time: 237.03s | ETA: 948.11s
	🕒 Completed At: 09:48



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 7/10 (70.00%) Completed
	📊 Training Loss: 0.1400
	✅ Training Accuracy: 96.35%
	⏳ Elapsed Time: 237.84s | ETA: 713.51s
	🕒 Completed At: 09:52



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 8/10 (80.00%) Completed
	📊 Training Loss: 0.1293
	✅ Training Accuracy: 96.70%
	⏳ Elapsed Time: 237.39s | ETA: 474.78s
	🕒 Completed At: 09:56



Training progress: 100%|██████████| 313/313 [03:58<00:00,  1.31batch/s]


🚀 Epoch 9/10 (90.00%) Completed
	📊 Training Loss: 0.1224
	✅ Training Accuracy: 96.85%
	⏳ Elapsed Time: 238.52s | ETA: 238.52s
	🕒 Completed At: 10:00



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 10/10 (100.00%) Completed
	📊 Training Loss: 0.1171
	✅ Training Accuracy: 96.93%
	⏳ Elapsed Time: 237.90s | ETA: 0.00s
	🕒 Completed At: 10:04

💾 Saved checkpoint at: /content/drive/MyDrive/checkpoints/BaseDino/arcanine_ft_global_calibr_BaseDino_epoch_10.pth
💾 Saved losses and accuracies (training and validation) at: /content/drive/MyDrive/checkpoints/BaseDino/arcanine_ft_global_calibr_BaseDino_epoch_10.loss_acc.json



Per Class Accuracy: 100%|██████████| 79/79 [00:53<00:00,  1.47batch/s]


Test accuracy: 78.60% (original: 78.62%)
Fine-tuned model is worse in 42 classes, wrt the original model
Accuracy data saved to /content/drive/MyDrive/checkpoints/Editing/arcanine/accuracy_comparison_global_calibr.json





In [12]:
# LOCAL MASK
local_acc = fine_tune(
    name = f'{model_name}_ft_local_calibr',
    mask = mask_dict_to_list(model, local_calibr_mask),
    optimizer = optimizer,
    scheduler = scheduler,
    criterion = criterion,
    train_dataloader = train_dataloader
)

new_test_accuracy = np.mean(local_acc)
print(f'\nTest accuracy: {100*new_test_accuracy:.2f}% (original: {100*test_accuracy:.2f}%)')

count = sum([1 for i in range(len(local_acc)) if local_acc[i] < class_acc[i]])
print(f'Fine-tuned model is worse in {count} classes, wrt the original model')
# Save to file the per-class accuracy difference
# Create a dictionary with new_class_acc, class_acc, and class_idx
accuracy_data = {
    "class_idx": list(range(100)),
    "new_class_acc": list(local_acc),
    "class_acc": list(class_acc)
}
output_file = f"{CHECKPOINT_DIR}/Editing/{model_name}/accuracy_comparison_local_calibr.json"

# Save the dictionary to a JSON file
with open(output_file, "w") as json_file:
    json.dump(accuracy_data, json_file, indent=4)
print(f"Accuracy data saved to {output_file}")

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


✅ Loaded checkpoint from /content/drive/MyDrive/checkpoints/Editing/arcanine.pth, resuming at epoch 26
Prefix/name for the model was provided: arcanine_ft_local_calibr



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 1/10 (10.00%) Completed
	📊 Training Loss: 0.2166
	✅ Training Accuracy: 95.17%
	⏳ Elapsed Time: 237.24s | ETA: 2135.13s
	🕒 Completed At: 10:09



Training progress: 100%|██████████| 313/313 [04:00<00:00,  1.30batch/s]


🚀 Epoch 2/10 (20.00%) Completed
	📊 Training Loss: 0.1920
	✅ Training Accuracy: 95.49%
	⏳ Elapsed Time: 240.42s | ETA: 1923.37s
	🕒 Completed At: 10:13



Training progress: 100%|██████████| 313/313 [03:58<00:00,  1.31batch/s]


🚀 Epoch 3/10 (30.00%) Completed
	📊 Training Loss: 0.1710
	✅ Training Accuracy: 95.89%
	⏳ Elapsed Time: 238.74s | ETA: 1671.17s
	🕒 Completed At: 10:17



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 4/10 (40.00%) Completed
	📊 Training Loss: 0.1544
	✅ Training Accuracy: 96.10%
	⏳ Elapsed Time: 237.70s | ETA: 1426.20s
	🕒 Completed At: 10:20



Training progress: 100%|██████████| 313/313 [03:59<00:00,  1.30batch/s]


🚀 Epoch 5/10 (50.00%) Completed
	📊 Training Loss: 0.1406
	✅ Training Accuracy: 96.49%
	⏳ Elapsed Time: 239.98s | ETA: 1199.91s
	🕒 Completed At: 10:24



Training progress: 100%|██████████| 313/313 [03:59<00:00,  1.31batch/s]


🚀 Epoch 6/10 (60.00%) Completed
	📊 Training Loss: 0.1305
	✅ Training Accuracy: 96.66%
	⏳ Elapsed Time: 239.54s | ETA: 958.15s
	🕒 Completed At: 10:28



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 7/10 (70.00%) Completed
	📊 Training Loss: 0.1231
	✅ Training Accuracy: 96.83%
	⏳ Elapsed Time: 237.47s | ETA: 712.41s
	🕒 Completed At: 10:32



Training progress: 100%|██████████| 313/313 [03:59<00:00,  1.31batch/s]


🚀 Epoch 8/10 (80.00%) Completed
	📊 Training Loss: 0.1118
	✅ Training Accuracy: 97.09%
	⏳ Elapsed Time: 239.72s | ETA: 479.44s
	🕒 Completed At: 10:36



Training progress: 100%|██████████| 313/313 [03:59<00:00,  1.31batch/s]


🚀 Epoch 9/10 (90.00%) Completed
	📊 Training Loss: 0.1047
	✅ Training Accuracy: 97.31%
	⏳ Elapsed Time: 239.50s | ETA: 239.50s
	🕒 Completed At: 10:40



Training progress: 100%|██████████| 313/313 [03:57<00:00,  1.32batch/s]


🚀 Epoch 10/10 (100.00%) Completed
	📊 Training Loss: 0.0997
	✅ Training Accuracy: 97.38%
	⏳ Elapsed Time: 237.80s | ETA: 0.00s
	🕒 Completed At: 10:44

💾 Saved checkpoint at: /content/drive/MyDrive/checkpoints/BaseDino/arcanine_ft_local_calibr_BaseDino_epoch_10.pth
💾 Saved losses and accuracies (training and validation) at: /content/drive/MyDrive/checkpoints/BaseDino/arcanine_ft_local_calibr_BaseDino_epoch_10.loss_acc.json



Per Class Accuracy: 100%|██████████| 79/79 [00:53<00:00,  1.47batch/s]


Test accuracy: 78.52% (original: 78.62%)
Fine-tuned model is worse in 42 classes, wrt the original model
Accuracy data saved to /content/drive/MyDrive/checkpoints/Editing/arcanine/accuracy_comparison_local_calibr.json



