In [None]:
# Clone the project repository on its `clustering` branch. The local packages
# imported further down (data/, models/, train, eval) are imported after `%cd`-ing
# into this checkout.
# Note: re-running this cell once the directory exists makes `git clone` fail
# because the destination path already exists.
!git clone --branch clustering https://github.com/AlessandroMaini/federated-learning-project.git

In [None]:
# Move into the cloned repository so its modules are importable and the
# relative ./checkpoints directory is created inside it.
%cd federated-learning-project

In [None]:
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import random
from data.cifar100_loader import get_federated_cifar100_dataloaders, get_clustered_cifar100_datasets
from eval import evaluate
from train import train, train_steps
from models.prepare_model import get_dino_vits16_model, freeze_backbone, unfreeze_backbone, freeze_head, unfreeze_head
from models.model_editing import mask_calculator, freeze_and_clean_client_masks
from models.federated_averaging import train_on_client, average_metrics, average_models, get_trainable_keys
from tqdm import tqdm
import copy
from collections import defaultdict

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Every checkpoint / mask file produced by this notebook goes in this directory
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# --- Experiment configuration ---

# Seed Python's `random` (used for client sampling below) and PyTorch so that a
# fresh "Restart & Run All" reproduces the same client selection and training.
# NOTE(review): if the data loaders draw from numpy's RNG, seed numpy as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Total number of clients
K = 100

# Number of clients per cluster (5 when K = 100)
N = K // 20

# Fraction of clients sampled in each round
C = 0.1

# Number of local steps each sampled client performs per round
J = 4

In [None]:
# Build the clustered CIFAR-100 split with N clients per cluster: per-client
# training datasets, a shared test loader, the client/class map and the
# cluster -> client-ids map used by the cluster-training cells.
(
    train_datasets,
    test_loader,
    client_class_map,
    cluster_to_clients,
) = get_clustered_cifar100_datasets(N)

# Loss shared by every training and evaluation call in this notebook
criterion = nn.CrossEntropyLoss()

In [None]:
# Build the model (via get_dino_vits16_model, placed on `device`) and freeze
# its backbone for the warm-up rounds
model = get_dino_vits16_model(device)
freeze_backbone(model)

In [None]:
# Warm-up schedule: number of FedAvg rounds, and local steps per sampled client in each round
warmup_rounds, warmup_steps = 10, 8

In [None]:
# Round range of the warm-up loop
start_round = 0
num_rounds = warmup_rounds
best_test_acc = 0.0  # NOTE(review): not read by any visible cell

# Per-round metric histories, one entry appended per warm-up round
warmup_train_loss, warmup_train_acc = [], []
warmup_test_loss, warmup_test_acc = [], []

In [None]:
# FedAvg warm-up loop. Each round samples int(C * K) clients, trains each of
# them locally for `warmup_steps` steps starting from the current `model`, and
# loads the sample-size-weighted average of their trainable parameters back
# into `model`. Train/test metrics are appended to the warmup_* lists.
# The loop variable is `round_idx` (not `round`) so the builtin round() is not
# shadowed in the kernel namespace.
end_round = start_round + num_rounds
for round_idx in range(start_round, end_round):
    print(f"\n--- Round {round_idx + 1}/{end_round} ---")

    # Select clients uniformly without replacement
    selected_clients = random.sample(range(K), int(C * K))

    # Local training
    local_models, train_losses, train_accs = [], [], []
    for client_id in selected_clients:
        model_state, loss, acc = train_on_client(
            client_id,
            model,
            train_datasets[client_id],
            warmup_steps,
            criterion,
            lr=0.01,
            device=device,
        )
        local_models.append(model_state)
        train_losses.append(loss)
        train_accs.append(acc)

    # Weight each client by the size of its training set
    client_sample_counts = [len(train_datasets[c]) for c in selected_clients]
    total_samples = sum(client_sample_counts)
    client_weights = [count / total_samples for count in client_sample_counts]

    # Federated averaging of the trainable parameters only
    trainable_keys = get_trainable_keys(model)
    averaged_state = average_models(local_models, client_weights, trainable_keys)
    new_state = model.state_dict()
    for key in averaged_state:
        new_state[key] = averaged_state[key]
    model.load_state_dict(new_state)

    # Log average training metrics
    avg_train_loss = average_metrics(train_losses, client_weights)
    avg_train_acc = average_metrics(train_accs, client_weights)
    print(f"Avg Train Loss: {avg_train_loss:.4f}, Avg Train Accuracy: {avg_train_acc:.4f}")
    warmup_train_loss.append(avg_train_loss)
    warmup_train_acc.append(avg_train_acc)

    # Evaluate the averaged model on the shared test set
    avg_test_loss, avg_test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Avg Test Loss: {avg_test_loss:.4f}, Avg Test Accuracy: {avg_test_acc:.4f}")
    warmup_test_loss.append(avg_test_loss)
    warmup_test_acc.append(avg_test_acc)

# Checkpoint reloaded by the mask-computation and cluster-training cells
torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'pre_trained_clustered_model.pth'))

In [None]:
# Plot the per-round warm-up training and test loss. The x axis uses 1-based
# round numbers to match the "Round i/N" lines printed by the warm-up loop.
fig, ax = plt.subplots()
ax.plot(range(1, len(warmup_train_loss) + 1), warmup_train_loss, label='Train Loss')
ax.plot(range(1, len(warmup_test_loss) + 1), warmup_test_loss, label='Test Loss')
ax.set(title='Warm-up FedAvg: loss per round', xlabel='Rounds', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Plot the per-round warm-up training and test accuracy. The x axis uses
# 1-based round numbers to match the "Round i/N" lines printed by the loop.
fig, ax = plt.subplots()
ax.plot(range(1, len(warmup_train_acc) + 1), warmup_train_acc, label='Train Accuracy')
ax.plot(range(1, len(warmup_test_acc) + 1), warmup_test_acc, label='Test Accuracy')
ax.set(title='Warm-up FedAvg: accuracy per round', xlabel='Rounds', ylabel='Accuracy')
ax.legend()
plt.show()

In [None]:
# Passed as `samples_per_class` to mask_calculator when computing client masks
samples_per_class = 5

In [None]:
# Reload the warm-up checkpoint. map_location=device lets a checkpoint written
# on a GPU be loaded on a CPU-only kernel (and vice versa).
model.load_state_dict(
    torch.load(os.path.join(CHECKPOINT_DIR, 'pre_trained_clustered_model.pth'), map_location=device)
)
# Unfreeze the backbone and freeze the head before computing the masks
unfreeze_backbone(model)
freeze_head(model)

# Sparsity passed to mask_calculator; it is also encoded in the frozen-state file name
sparsity = 0.10

# Compute the mask for each of the K clients
# NOTE(review): rounds=4 and num_classes=5 are hard-coded; num_classes presumably is the
# number of classes held by each client — confirm against client_class_map.
client_masks = {}
for client_id in tqdm(range(K)):
    client_masks[client_id] = mask_calculator(model, train_datasets[client_id], device, rounds=4, sparsity=sparsity,
                                              samples_per_class=samples_per_class, num_classes=5, verbose=False)

In [None]:
# Post-process the raw client masks (threshold 0.01). The helper also returns
# `frozen_state`, whose parameter names are frozen in the following cells.
client_masks, frozen_state = freeze_and_clean_client_masks(model, client_masks, threshold=0.01, K=K)

masks_path = os.path.join(CHECKPOINT_DIR, 'client_masks_cluster.pth')
frozen_state_path = os.path.join(CHECKPOINT_DIR, f'frozen_state_cluster_{int(sparsity * 100)}.pth')
# NOTE(review): the mask file name does not encode `sparsity` while the frozen-state
# file name does, so masks computed at a different sparsity overwrite this file.
torch.save(client_masks, masks_path)
torch.save(frozen_state, frozen_state_path)

In [None]:
# Reload the client masks and the frozen state from the checkpoint directory
masks_file = os.path.join(CHECKPOINT_DIR, 'client_masks_cluster.pth')
frozen_state_file = os.path.join(CHECKPOINT_DIR, f'frozen_state_cluster_{int(sparsity * 100)}.pth')
client_masks = torch.load(masks_file)
frozen_state = torch.load(frozen_state_file)

In [None]:
# Freeze every model parameter whose name appears in `frozen_state`;
# all other parameters keep their current requires_grad flag.
for param_name, parameter in model.named_parameters():
    if param_name not in frozen_state:
        continue
    parameter.requires_grad = False

In [None]:
def compute_cluster_majority_masks(cluster_to_clients, client_masks):
    """Merge the client masks of each cluster into one majority-vote mask.

    For every cluster, a mask entry is kept (``True``) when a strict majority
    of the cluster's clients keep it, i.e. at least ``num_clients // 2 + 1``
    of them. With an even number of clients a tie is therefore NOT kept.

    Args:
        cluster_to_clients: mapping ``cluster_id -> sequence of client ids``.
        client_masks: mapping ``client_id -> {param_name: mask tensor}``. The
            parameter names of the first client of a cluster are used as the
            reference for that cluster. Mask values are cast with ``.int()``,
            so boolean (or 0/1) masks are assumed.

    Returns:
        dict mapping ``cluster_id -> {param_name: torch.bool tensor}``, each
        tensor shaped like the corresponding client mask.

    Raises:
        ValueError: if a cluster has no clients (nothing to vote on).
        KeyError: if a client id or a reference parameter name is missing
            from ``client_masks``.
    """
    cluster_masks = {}

    for cluster_id, client_ids in cluster_to_clients.items():
        num_clients = len(client_ids)
        if num_clients == 0:
            raise ValueError(
                f"Cluster {cluster_id!r} has no clients; cannot compute a majority mask"
            )

        # Use the parameter names of the first client in the cluster as reference
        reference_mask = client_masks[client_ids[0]]
        vote_counts = {
            name: torch.zeros_like(reference_mask[name], dtype=torch.int32)
            for name in reference_mask
        }

        # Count how many clients keep each mask entry
        for client_id in client_ids:
            client_mask = client_masks[client_id]
            for name in vote_counts:
                vote_counts[name] += client_mask[name].int()

        # Strict majority: kept by more than half of the clients
        threshold = num_clients // 2 + 1

        # The comparison already yields torch.bool tensors
        cluster_masks[cluster_id] = {
            name: count >= threshold for name, count in vote_counts.items()
        }

    return cluster_masks


In [None]:
# One majority-vote mask per cluster, built from that cluster's client masks
cluster_masks = compute_cluster_majority_masks(cluster_to_clients, client_masks)

# Persist the cluster masks so the training cells can reload them
cluster_masks_path = os.path.join(CHECKPOINT_DIR, 'cluster_masks.pth')
torch.save(cluster_masks, cluster_masks_path)

In [None]:
# Must match the sparsity the frozen state was saved with (it is part of the file name)
sparsity = 0.1

# weights_only=False unpickles arbitrary objects: only acceptable because this file
# is written by this notebook — never load untrusted checkpoints this way.
cluster_masks_file = os.path.join(CHECKPOINT_DIR, 'cluster_masks.pth')
cluster_masks = torch.load(cluster_masks_file, weights_only=False)

frozen_state_file = os.path.join(CHECKPOINT_DIR, f'frozen_state_cluster_{int(sparsity * 100)}.pth')
frozen_state = torch.load(frozen_state_file)

In [None]:
# Common starting point for every cluster: a fresh model loaded with the
# warm-up weights (map_location=device so a GPU-written checkpoint also loads
# on a CPU kernel), with the head frozen and every parameter listed in
# `frozen_state` frozen.
base_model = get_dino_vits16_model(device)
base_model.load_state_dict(
    torch.load(os.path.join(CHECKPOINT_DIR, 'pre_trained_clustered_model.pth'), map_location=device)
)
freeze_head(base_model)
# Apply frozen state
for name, param in base_model.named_parameters():
    if name in frozen_state:
        param.requires_grad = False

# Cluster-training hyper-parameters.
# NOTE(review): this overrides the global client fraction C (0.1 during warm-up).
C = 0.2   # fraction of a cluster's clients sampled per round
R = 100   # total FedAvg rounds run by each cluster
P = 10    # number of merge phases in the incremental schedule

In [None]:
# --- Phase 1: FedAvg per cluster ---
# Each cluster starts from a copy of `base_model` and runs R FedAvg rounds on
# its own clients, passing the cluster's majority mask to local training.
# Local names (cluster_model, cluster_size, round_idx) are used so the global
# `model`, the global client count `K` and the builtin round() are not clobbered.
cluster_models = {}  # cluster_id -> model after R rounds
for cluster_id, client_ids in cluster_to_clients.items():
    print(f"\n=== Federated Training in Cluster {cluster_id} ===")

    # Initialize cluster model
    cluster_model = copy.deepcopy(base_model).to(device)
    cluster_size = len(client_ids)
    clients_per_round = max(1, int(C * cluster_size))
    cluster_mask = cluster_masks[cluster_id]

    for round_idx in range(R):
        print(f"\n  [Cluster {cluster_id}] Round {round_idx + 1}/{R}")

        # Select clients from this cluster
        selected_clients = random.sample(client_ids, clients_per_round)

        # Local training
        local_models, train_losses, train_accs = [], [], []
        for client_id in selected_clients:
            state, loss, acc = train_on_client(
                client_id,
                cluster_model,
                train_datasets[client_id],
                J,
                criterion,
                lr=.01,
                device=device,
                mask=cluster_mask
            )
            local_models.append(state)
            train_losses.append(loss)
            train_accs.append(acc)

        # Weighted average (by client training-set size)
        sample_counts = [len(train_datasets[c]) for c in selected_clients]
        total_samples = sum(sample_counts)
        weights = [n / total_samples for n in sample_counts]

        keys = get_trainable_keys(cluster_model)
        averaged_state = average_models(local_models, weights, keys)

        # Update cluster model (only the averaged trainable entries change)
        cluster_model.load_state_dict({**cluster_model.state_dict(), **averaged_state})

        # Logging
        avg_loss = average_metrics(train_losses, weights)
        avg_acc = average_metrics(train_accs, weights)
        print(f"    Avg Train Loss: {avg_loss:.4f}, Avg Train Acc: {avg_acc:.4f}")

    cluster_models[cluster_id] = cluster_model

# --- Phase 2: Merge cluster models into global model ---
print("\n=== Merging Cluster Models into Final Global Model ===")

global_model = copy.deepcopy(base_model)
global_state = global_model.state_dict()

# Weight each cluster model by the total number of training samples in its cluster
all_states, all_weights = [], []
for cluster_id, trained_model in cluster_models.items():
    all_states.append(trained_model.state_dict())
    all_weights.append(sum(len(train_datasets[c]) for c in cluster_to_clients[cluster_id]))
total = sum(all_weights)
weights = [n / total for n in all_weights]

keys = get_trainable_keys(global_model)
merged_state = average_models(all_states, weights, keys)

# Load into global model
for key in merged_state:
    global_state[key] = merged_state[key]
global_model.load_state_dict(global_state)

# Save global model
torch.save(global_model.state_dict(), os.path.join(CHECKPOINT_DIR, 'global_clustered_model.pth'))

In [None]:
# Evaluate the merged global model on the shared test set
test_metrics = evaluate(global_model, test_loader, criterion, device)
avg_test_loss, avg_test_acc = test_metrics
print(f"Avg Test Loss: {avg_test_loss:.4f}, Avg Test Accuracy: {avg_test_acc:.4f}")

In [None]:
# Test loss / accuracy of the merged global model, one entry per incremental phase
clust_losses, clust_accs = [], []

In [None]:
# --- Incremental clustered FedAvg ---
# The R rounds are split into P phases. In every phase each cluster runs its
# share of FedAvg rounds on its own clients (with its cluster mask); the
# cluster models are then merged into a global model that is checkpointed and
# evaluated, and every cluster restarts the next phase from that merged model.
# Local names (cluster_model, cluster_size) keep the global `model` and the
# global client count `K` intact.
if P <= 0:
    raise ValueError(f"P (number of phases) must be positive, got {P}")

# This cell restarts training from `base_model`, so restart the metric history
# too; otherwise re-running it appends to the previous run's results.
clust_losses = []
clust_accs = []

# Split R rounds over P phases. When R is not divisible by P, the first R % P
# phases get one extra round so that exactly R rounds are run in total.
rounds_per_phase = R // P
phase_lengths = [rounds_per_phase + (1 if p < R % P else 0) for p in range(P)]

cluster_models = {cluster_id: copy.deepcopy(base_model).to(device) for cluster_id in cluster_to_clients}
cluster_sample_counts = {cid: sum(len(train_datasets[c]) for c in client_ids)
                         for cid, client_ids in cluster_to_clients.items()}

for phase in range(P):
    print(f"\n=== Incremental Phase {phase + 1}/{P} ===")
    phase_start = sum(phase_lengths[:phase])

    # --- Phase: cluster-local training for this phase's rounds ---
    for cluster_id, client_ids in cluster_to_clients.items():
        print(f"\n--- Federated Training in Cluster {cluster_id} (Phase {phase+1}) ---")

        cluster_model = cluster_models[cluster_id]  # updated in place
        cluster_size = len(client_ids)
        clients_per_round = max(1, int(C * cluster_size))
        cluster_mask = cluster_masks[cluster_id]

        for r in range(phase_lengths[phase]):
            round_idx = phase_start + r
            print(f"\n  [Cluster {cluster_id}] Round {round_idx + 1}/{R}")

            # Select clients from this cluster
            selected_clients = random.sample(client_ids, clients_per_round)

            # Local training
            local_models, train_losses, train_accs = [], [], []
            for client_id in selected_clients:
                state, loss, acc = train_on_client(
                    client_id,
                    cluster_model,
                    train_datasets[client_id],
                    J,
                    criterion,
                    lr=0.01,
                    device=device,
                    mask=cluster_mask
                )
                local_models.append(state)
                train_losses.append(loss)
                train_accs.append(acc)

            # Weighted average (by client training-set size)
            sample_counts = [len(train_datasets[c]) for c in selected_clients]
            total_samples = sum(sample_counts)
            weights = [n / total_samples for n in sample_counts]

            keys = get_trainable_keys(cluster_model)
            averaged_state = average_models(local_models, weights, keys)

            # Update cluster model (only the averaged trainable entries change)
            cluster_model.load_state_dict({**cluster_model.state_dict(), **averaged_state})

            # Logging
            avg_loss = average_metrics(train_losses, weights)
            avg_acc = average_metrics(train_accs, weights)
            print(f"    Avg Train Loss: {avg_loss:.4f}, Avg Train Acc: {avg_acc:.4f}")

    # --- Merge all cluster models into new global model ---
    print(f"\n=== Merging Cluster Models After Phase {phase + 1} ===")

    global_model = copy.deepcopy(base_model)
    global_state = global_model.state_dict()

    # Collect cluster states, weighting each cluster by its total sample count
    all_states = [cluster_models[cid].state_dict() for cid in cluster_models]
    all_weights = [cluster_sample_counts[cid] for cid in cluster_models]
    total = sum(all_weights)
    weights = [w / total for w in all_weights]

    keys = get_trainable_keys(global_model)
    merged_state = average_models(all_states, weights, keys)

    for key in merged_state:
        global_state[key] = merged_state[key]
    global_model.load_state_dict(global_state)

    # --- Re-initialize all cluster models with merged global model ---
    merged_full_state = global_model.state_dict()
    for cid in cluster_models:
        cluster_models[cid].load_state_dict(merged_full_state)

    # Save intermediate global model
    torch.save(global_model.state_dict(),
               os.path.join(CHECKPOINT_DIR, f'global_model_phase{phase + 1}.pth'))

    # Test the global model on the test dataset
    avg_test_loss, avg_test_acc = evaluate(global_model, test_loader, criterion, device)
    print(f"Avg Test Loss: {avg_test_loss:.4f}, Avg Test Accuracy: {avg_test_acc:.4f}")
    clust_losses.append(avg_test_loss)
    clust_accs.append(avg_test_acc)