In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from pathlib import Path

import flwr
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from fl_g13.architectures import BaseDino
from fl_g13.editing import SparseSGDM
from fl_g13.fl_pytorch.editing.centralized_mask import get_centralized_mask
from fl_g13.modeling import load_or_create

[32m2025-06-03 21:26:44.371[0m | [1mINFO    [0m | [36mfl_g13.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: C:\Users\ADMIN\Desktop\BACKUP\study\Italy\polito\classes\20242\deep learning\project\source_code\fl-g13[0m


# Log in to Weights & Biases (wandb)

In [None]:
!pip install wandb --quiet

In [6]:
# Load the variables defined in a local .env file; the call's boolean result
# is displayed as the cell output.
import dotenv

dotenv.load_dotenv()


True

In [7]:
import os

import wandb

# Read the API key from the .env file; fall back to the process environment so
# the cell also works when the key is exported in the shell instead.
WANDB_API_KEY = dotenv.dotenv_values().get("WANDB_API_KEY") or os.environ.get("WANDB_API_KEY")
if not WANDB_API_KEY:
    # Fail early with an actionable message instead of an opaque KeyError.
    raise RuntimeError(
        "WANDB_API_KEY is not set: add it to the .env file or export it before running this cell."
    )
wandb.login(key=WANDB_API_KEY)

  return LooseVersion(v) >= LooseVersion(check)
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\ADMIN\_netrc
wandb: Currently logged in as: thanhnv-it23 (stefano-gamba-social-politecnico-di-torino) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

## Build the local module

Install the project package locally (editable mode) so that the ClientApp can import it.

In [8]:
!pip install -e .. --quiet

Obtaining file:///C:/Users/ADMIN/Desktop/BACKUP/study/Italy/polito/classes/20242/deep%20learning/project/source_code/fl-g13









  Installing build dependencies: started




  Installing build dependencies: finished with status 'done'









  Checking if build backend supports build_editable: started








[notice] A new release of pip is available: 25.0.1 -> 25.1.1


  Checking if build backend supports build_editable: finished with status 'done'

[notice] To update, run: python.exe -m pip install --upgrade pip


  Getting requirements to build editable: started





  Getting requirements to build editable: finished with status 'done'


### Download missing modules for clients

The DINO model, which the server serializes and sends to the clients, requires some modules that have to be downloaded from the DINO source code.


In [9]:
import os
import urllib.request


def download_if_not_exists(file_path: str, file_url: str):
    """
    Download a file from ``file_url`` to ``file_path`` unless it already exists.

    The payload is first written to a temporary sibling file
    (``<file_path>.part``) and only moved into place once the transfer has
    finished. An interrupted download therefore never leaves a truncated file
    at ``file_path`` that a later call would mistake for a complete one.

    Failures are reported on stdout instead of being raised (best effort).

    Parameters:
    - file_path (str): The local path to check and save the file.
    - file_url (str): The URL from which to download the file.
    """
    if os.path.exists(file_path):
        print(f"'{file_path}' already exists.")
        return

    print(f"'{file_path}' not found. Downloading from {file_url}...")
    partial_path = f"{file_path}.part"
    try:
        urllib.request.urlretrieve(file_url, partial_path)
        # Atomic move: the destination only ever holds a complete file.
        os.replace(partial_path, file_path)
        print("Download complete.")
    except Exception as e:
        print(f"Failed to download file: {e}")
    finally:
        # Remove leftovers of a failed transfer; after a successful
        # os.replace the temporary file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)

In [10]:
# Fetch the two DINO source files next to this notebook (relative paths resolve
# against the current working directory).
# NOTE(review): both URLs track the `main` branch, so the downloaded code can
# change over time — consider pinning a commit.
download_if_not_exists("vision_transformer.py",
                       "https://raw.githubusercontent.com/facebookresearch/dino/refs/heads/main/vision_transformer.py")
download_if_not_exists("utils.py",
                       "https://raw.githubusercontent.com/facebookresearch/dino/refs/heads/main/utils.py")


'vision_transformer.py' already exists.
'utils.py' already exists.


# FL

## Configs

## Model config

In [15]:
# -------------------------
# Debug Mode Toggle
# -------------------------
# When True, the overrides at the bottom of this cell disable wandb and
# shorten the run.
DEBUG = False

# -------------------------
# Device Configuration
# -------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

# -------------------------
# Paths and Checkpoints
# -------------------------
current_path = Path.cwd()
model_save_path = current_path / "../models/fl_dino_baseline/iid"
checkpoint_dir = model_save_path.resolve()
# Path.mkdir keeps this cell independent of the `os` import made in another cell.
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# -------------------------
# Model Hyper-parameters
# -------------------------
head_layers = 3
head_hidden_size = 512
dropout_rate = 0.0
unfreeze_blocks = 12  # passed to model.unfreeze_blocks(...) after loading

# -------------------------
# Training Hyper-parameters
# -------------------------
batch_size = 128
lr = 1e-3
momentum = 0.9
weight_decay = 1e-5
T_max = 8        # CosineAnnealingLR period
eta_min = 1e-5   # CosineAnnealingLR minimum learning rate

# -------------------------
# Federated Learning Config
# -------------------------
# NOTE(review): K / C / J presumably follow the usual FedAvg notation
# (total clients / fraction sampled per round / local steps) — confirm.
K = 100
C = 0.1
J = 4
num_rounds = 30
partition_type = 'iid'
num_shards_per_partition = 10  # Only for shard partitioning

# -------------------------
# Server App Config
# -------------------------
save_every = 1
fraction_fit = C
fraction_evaluate = 0.1
min_fit_clients = 10
min_evaluate_clients = 5
min_available_clients = 10
# Reuse the detected device instead of hard-coding 'cuda', so loading the
# model also works on machines without a GPU.
device = DEVICE

# -------------------------
# Wandb Configuration
# -------------------------
use_wandb = True
wandb_config = {
    'name': 'FL_Dino_Baseline_iid',
    'project_name': "FL_test_chart",
    'fraction_fit': fraction_fit,
    'lr': lr,
    'momentum': momentum,
    'partition_type': partition_type,
    'K': K,
    'C': C,
    'J': J,
}

# -------------------------
# Model Editing Config
# -------------------------
model_editing = True
mask_type = 'global'
sparsity = 0.2
mask = None

# -------------------------
# Simulation Config
# -------------------------
NUM_CLIENTS = 100
MAX_PARALLEL_CLIENTS = 10

# -------------------------
# Debug Mode Overrides
# -------------------------
# Applied last so they win over the defaults above. wandb_config was built
# before this point and is not updated, but wandb is disabled in debug mode.
if DEBUG:
    use_wandb = False
    num_rounds = 2
    J = 4  # same value as the default above, so currently a no-op

## Define the model, optimizer, and loss function

In [16]:

# -------------------------
# Load/Create Model
# -------------------------
# load_or_create returns the model together with the epoch to resume from.
# No optimizer or scheduler is handed in here (both None); they are built
# further down in this cell.
model, start_epoch = load_or_create(
    path=checkpoint_dir,
    model_class=BaseDino,
    model_config=None,
    optimizer=None,
    scheduler=None,
    device=device,
    verbose=True,
)

# NOTE(review): the checkpoint loaded in the recorded run reports
# 'unfreeze_blocks': 0, so this call changes the configuration to
# `unfreeze_blocks` blocks — confirm the exact semantics in BaseDino.
model.unfreeze_blocks(unfreeze_blocks)
# NOTE(review): load_or_create received `device`, this move uses `DEVICE`;
# the two are configured separately in the config cell.
model.to(DEVICE)

# -------------------------
# Optimizer, Scheduler, Loss
# -------------------------

# Create dummy mask for SparseSGDM (required after model is on device)
# One all-ones tensor per parameter, created on that parameter's device.
init_mask = [torch.ones_like(p, device=p.device) for p in model.parameters()]

optimizer = SparseSGDM(
    model.parameters(),
    mask=init_mask,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay
)

criterion = torch.nn.CrossEntropyLoss()

# Cosine learning-rate schedule driven by T_max / eta_min from the config cell.
scheduler = CosineAnnealingLR(
    optimizer=optimizer,
    T_max=T_max,
    eta_min=eta_min
)

🔍 Loading checkpoint from C:\Users\ADMIN\Desktop\BACKUP\study\Italy\polito\classes\20242\deep learning\project\source_code\fl-g13\models\fl_dino_baseline\iid\fl_fl_baseline_BaseDino_epoch_200_iid.pth
📦 Model class in checkpoint: BaseDino
🔧 Model configuration: {'variant': 'dino_vits16', 'dropout_rate': 0.0, 'head_hidden_size': 512, 'head_layers': 3, 'num_classes': 100, 'unfreeze_blocks': 0, 'activation_fn': 'GELU', 'pretrained': True}


Using cache found in C:\Users\ADMIN/.cache\torch\hub\facebookresearch_dino_main
Using cache found in C:\Users\ADMIN/.cache\torch\hub\facebookresearch_dino_main


➡️ Moved model to device: cuda
✅ Loaded checkpoint from C:\Users\ADMIN\Desktop\BACKUP\study\Italy\polito\classes\20242\deep learning\project\source_code\fl-g13\models\fl_dino_baseline\iid\fl_fl_baseline_BaseDino_epoch_200_iid.pth, resuming at epoch 201


## Compute stats for mask

In [17]:
from typing import Dict, Any


def compute_mask_stats(mask_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    """
    Summarise how many entries of each mask tensor are kept versus masked.

    An entry counts as "kept" when it equals 1; every other entry counts as
    "masked".

    Args:
        mask_dict: Mapping from parameter name (str) to its mask tensor
                   (torch.Tensor).

    Returns:
        A dictionary with one entry per parameter name, followed by an
        'overall' entry. Per-parameter entries hold 'num_elements',
        'kept_elements', 'masked_elements', 'density' and 'sparsity'; the
        'overall' entry holds 'total_elements' instead of 'num_elements'.
        Both ratios are 0.0 when there are no elements.
    """

    def _ratio(part, whole):
        # Guard against division by zero for empty tensors / empty dicts.
        return part / whole if whole > 0 else 0.0

    stats = {}
    grand_total = 0
    grand_kept = 0

    # --- Layer-wise statistics, accumulating the totals on the way ---
    for param_name, mask in mask_dict.items():
        size = mask.numel()
        kept = (mask == 1).sum().item()
        masked = size - kept

        grand_total += size
        grand_kept += kept

        stats[param_name] = {
            'num_elements': size,
            'kept_elements': kept,
            'masked_elements': masked,
            'density': _ratio(kept, size),
            'sparsity': _ratio(masked, size),
        }

    # --- Overall statistics (added last, under the reserved key 'overall') ---
    grand_masked = grand_total - grand_kept
    stats['overall'] = {
        'total_elements': grand_total,
        'kept_elements': grand_kept,
        'masked_elements': grand_masked,
        'density': _ratio(grand_kept, grand_total),
        'sparsity': _ratio(grand_masked, grand_total),
    }

    return stats


def print_mask_stats(stats: Dict[str, Any], layer=False):
    """
    Print mask statistics in a human-readable layout.

    Args:
        stats: The dictionary returned by compute_mask_stats.
        layer: When true, the overall summary is followed by a per-layer
               breakdown, with layers listed in sorted name order.
    """
    if 'overall' not in stats:
        print("Invalid stats dictionary format.")
        return

    # (label, key, format spec) for each line of the overall summary.
    summary_rows = (
        ("Total Elements", 'total_elements', ""),
        ("Kept Elements (1s)", 'kept_elements', ""),
        ("Masked Elements (0s)", 'masked_elements', ""),
        ("Overall Density", 'density', ".4f"),
        ("Overall Sparsity", 'sparsity', ".4f"),
    )
    overall = stats['overall']
    print("--- Overall Mask Statistics ---")
    for label, key, spec in summary_rows:
        print(f"{label}: {format(overall[key], spec)}")
    print("-" * 30)

    if layer:
        # (label, key, format spec) for each indented line of a layer entry.
        layer_rows = (
            ("Num Elements", 'num_elements', ""),
            ("Kept Elements", 'kept_elements', ""),
            ("Masked Elements", 'masked_elements', ""),
            ("Density", 'density', ".4f"),
            ("Sparsity", 'sparsity', ".4f"),
        )
        print("--- Layer-wise Mask Statistics ---")
        # Sorted so that the output order is stable.
        for layer_name in sorted(key for key in stats if key != 'overall'):
            entry = stats[layer_name]
            print(f"Layer: {layer_name}")
            for label, key, spec in layer_rows:
                print(f"  {label}: {format(entry[key], spec)}")
            print("-" * 20)

## Calculate the centralized mask

In [18]:
# ----------------------------------------
# Client Dataset Configuration
# ----------------------------------------
# These values are forwarded to get_centralized_mask in the cells below.
client_partition_type = 'iid'              # 'iid' or 'shard' for non-iid
client_num_partitions = 100                # Number of clients
client_num_shards_per_partition = 10       # Used only for 'shard'
client_batch_size = 16
# NOTE(review): client_train_test_split_ratio is defined (and overridden in
# debug mode) but is not passed to the get_centralized_mask calls visible in
# this notebook — confirm whether it is still needed.
client_train_test_split_ratio = 0.2
client_dataset = "cifar100"
client_seed = 42
client_return_dataset = False

# ----------------------------------------
# Mask Generation Configuration
# ----------------------------------------
mask_model = model                         # Model to use for mask generation
# NOTE(review): the earlier config cell sets a separate `sparsity = 0.2`;
# confirm which of the two values is meant to drive the mask.
mask_sparsity = 0.8                        # Proportion of weights to prune
mask_type = 'global'                       # 'global' or 'local' (re-assigns the value from the config cell)
mask_rounds = 1                            # When to compute the mask (exact meaning defined by get_centralized_mask — TODO confirm)
mask_func = None                           # Custom mask function (if any)

# ----------------------------------------
# Aggregation Configuration
# ----------------------------------------
agg_strategy = 'union'                     # 'union', 'intersection'.
agg_func = None                            # Custom aggregation logic (if applicable)

# ----------------------------------------
# Debug Mode Overrides
# ----------------------------------------
# Smaller client count and larger batches when DEBUG is enabled.
if DEBUG:
    client_num_partitions = 10
    client_batch_size = 128
    client_train_test_split_ratio = 0.9

In [None]:
# Compute the centralized mask with the aggregation strategy configured above
# ('union' in the configuration cell). The result is indexed with [1] further
# down when the statistics are computed.
centralized_mask = get_centralized_mask(
    client_partition_type=client_partition_type,
    client_num_partitions=client_num_partitions,
    client_num_shards_per_partition=client_num_shards_per_partition,
    client_batch_size=client_batch_size,
    client_dataset=client_dataset,
    client_seed=client_seed,
    client_return_dataset=client_return_dataset,
    mask_model=mask_model,
    mask_sparsity=mask_sparsity,
    mask_type=mask_type,
    mask_rounds=mask_rounds,
    mask_func=mask_func,
    agg_strategy=agg_strategy,
    agg_func=agg_func
)


In [None]:

# Same mask computation as the 'union' run, but aggregated by intersection.
# The strategy is passed as a literal instead of rebinding the shared
# `agg_strategy` setting, so re-running the 'union' cell afterwards still
# produces a union mask (no hidden state between cells).
centralized_mask_intersection = get_centralized_mask(
    client_partition_type=client_partition_type,
    client_num_partitions=client_num_partitions,
    client_num_shards_per_partition=client_num_shards_per_partition,
    client_batch_size=client_batch_size,
    client_dataset=client_dataset,
    client_seed=client_seed,
    client_return_dataset=client_return_dataset,
    mask_model=mask_model,
    mask_sparsity=mask_sparsity,
    mask_type=mask_type,
    mask_rounds=mask_rounds,
    mask_func=mask_func,
    agg_strategy='intersection',
    agg_func=agg_func
)

In [21]:
# Per-parameter and overall statistics of the mask from the 'union' run.
# NOTE(review): element [1] is assumed to be the name -> mask-tensor dict — confirm
# against get_centralized_mask. print_mask_stats can render a compact summary instead
# of the full dictionary.
compute_mask_stats(centralized_mask[1])

{'backbone.blocks.0.norm1.weight': {'num_elements': 384,
  'kept_elements': 301,
  'masked_elements': 83,
  'density': 0.7838541666666666,
  'sparsity': 0.21614583333333334},
 'backbone.blocks.0.norm1.bias': {'num_elements': 384,
  'kept_elements': 48,
  'masked_elements': 336,
  'density': 0.125,
  'sparsity': 0.875},
 'backbone.blocks.0.attn.qkv.weight': {'num_elements': 442368,
  'kept_elements': 413109,
  'masked_elements': 29259,
  'density': 0.9338582356770834,
  'sparsity': 0.06614176432291667},
 'backbone.blocks.0.attn.qkv.bias': {'num_elements': 1152,
  'kept_elements': 738,
  'masked_elements': 414,
  'density': 0.640625,
  'sparsity': 0.359375},
 'backbone.blocks.0.attn.proj.weight': {'num_elements': 147456,
  'kept_elements': 92663,
  'masked_elements': 54793,
  'density': 0.6284111870659722,
  'sparsity': 0.3715888129340278},
 'backbone.blocks.0.attn.proj.bias': {'num_elements': 384,
  'kept_elements': 0,
  'masked_elements': 384,
  'density': 0.0,
  'sparsity': 1.0},
 'ba

In [22]:
# Per-parameter and overall statistics of the mask from the 'intersection' run.
# NOTE(review): element [1] is assumed to be the name -> mask-tensor dict — confirm
# against get_centralized_mask.
compute_mask_stats(centralized_mask_intersection[1])

{'backbone.blocks.0.norm1.weight': {'num_elements': 384,
  'kept_elements': 215,
  'masked_elements': 169,
  'density': 0.5598958333333334,
  'sparsity': 0.4401041666666667},
 'backbone.blocks.0.norm1.bias': {'num_elements': 384,
  'kept_elements': 43,
  'masked_elements': 341,
  'density': 0.11197916666666667,
  'sparsity': 0.8880208333333334},
 'backbone.blocks.0.attn.qkv.weight': {'num_elements': 442368,
  'kept_elements': 357640,
  'masked_elements': 84728,
  'density': 0.8084671585648148,
  'sparsity': 0.19153284143518517},
 'backbone.blocks.0.attn.qkv.bias': {'num_elements': 1152,
  'kept_elements': 688,
  'masked_elements': 464,
  'density': 0.5972222222222222,
  'sparsity': 0.4027777777777778},
 'backbone.blocks.0.attn.proj.weight': {'num_elements': 147456,
  'kept_elements': 55718,
  'masked_elements': 91738,
  'density': 0.3778618706597222,
  'sparsity': 0.6221381293402778},
 'backbone.blocks.0.attn.proj.bias': {'num_elements': 384,
  'kept_elements': 0,
  'masked_elements': 