In [None]:
# ── A. Imports & Reproducibility ────────────────────────────────────────────────
import os, copy
import csv                                                  # For result logging :contentReference[oaicite:0]{index=0}
import random                                               # For seeding :contentReference[oaicite:1]{index=1}
import numpy as np                                          # For numeric ops :contentReference[oaicite:2]{index=2}
import torch                                               # Core PyTorch :contentReference[oaicite:3]{index=3}
import torch.nn as nn                                       # Neural-net modules :contentReference[oaicite:4]{index=4}
import torch.nn.functional as F                             # Functional API :contentReference[oaicite:5]{index=5}
import torch.optim as optim                                 # Optimizers :contentReference[oaicite:6]{index=6}
from torch.optim.lr_scheduler import CosineAnnealingLR      # Scheduler :contentReference[oaicite:7]{index=7}
from torch.utils.data import DataLoader, random_split       # Data loaders & splits :contentReference[oaicite:8]{index=8}
import torchvision                                          # Datasets & transforms :contentReference[oaicite:9]{index=9}
import torchvision.transforms as T                          # Transforms :contentReference[oaicite:10]{index=10}
from torch.utils.tensorboard import SummaryWriter           # TensorBoard logging :contentReference[oaicite:11]{index=11}
import matplotlib.pyplot as plt                             # Plotting :contentReference[oaicite:12]{index=12}

In [None]:
# Seed every RNG this notebook draws from (Python `random`, NumPy, PyTorch CPU
# and all CUDA devices). Later cells reuse `seed` for split/shard generators.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


# ── B. Device ───────────────────────────────────────────────────────────────────
# `device` is read as a global by train_one_epoch / eval_model and the training loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")                             # Confirm GPU vs CPU



# ── C. Data Preparation ─────────────────────────────────────────────────────────
# Transforms
# Training pipeline: random 32x32 crop after 4-pixel padding, random horizontal
# flip, tensor conversion, then per-channel normalisation.
# NOTE(review): the mean/std constants are presumably the CIFAR-100 channel
# statistics — confirm against the dataset.
transform_train = T.Compose([
    T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761)),
])
# Evaluation pipeline: no augmentation, same normalisation constants as above.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761)),
])


Using device: cpu


In [None]:
import glob, torch, os

def latest_ckpt(dirpath, pattern=None):
    """
    Return the path in `dirpath` whose filename carries the highest round number.

    Files are selected with the glob `pattern`; when it is None the default
    "*_last_ckpt_round_*.pth" is used. The round number is the integer found
    between the last underscore and the first following dot of the filename
    (e.g. shard_J{J}_last_ckpt_round_{rnd}.pth). Returns None when nothing
    matches.
    """
    glob_pattern = "*_last_ckpt_round_*.pth" if pattern is None else pattern
    candidates = glob.glob(os.path.join(dirpath, glob_pattern))
    if len(candidates) == 0:
        return None

    def _round_of(path):
        # "<prefix>_<round>.pth" -> "<round>.pth" -> "<round>"
        tail = os.path.basename(path).rsplit("_", 1)[1]
        return int(tail.partition(".")[0])

    # Stable ascending sort; the last element has the largest round number.
    return sorted(candidates, key=_round_of)[-1]


def load_checkpoint(model, optimizer, ckpt_dir, shard_key, J, resume=True):
    """
    Restore the newest "<shard_key>_J<J>_last_ckpt_round_*.pth" checkpoint.

    On success the model weights, optimizer state and the torch CPU RNG state
    stored in the file are loaded in place, and the round to continue from
    (saved round + 1) is returned. Returns 1 when `resume` is false or no
    matching checkpoint exists in `ckpt_dir`.
    """
    if resume:
        newest = latest_ckpt(ckpt_dir, f"{shard_key}_J{J}_last_ckpt_round_*.pth")
        if newest is not None:
            saved = torch.load(newest, map_location='cpu')
            model.load_state_dict(saved['model_state'])
            optimizer.load_state_dict(saved['optimizer_state'])
            # The RNG state is moved to the CPU before being handed to
            # torch.set_rng_state.
            cpu_rng = saved['rng_state']
            if cpu_rng.device.type != 'cpu':
                cpu_rng = cpu_rng.cpu()
            torch.set_rng_state(cpu_rng)
            return saved['round'] + 1
    return 1



def save_checkpoint(model, optimizer, round_num, ckpt_dir,
                    shard_key, J, is_best=False):
    """
    Write a resumable checkpoint for one (shard_key, J) configuration.

    Saves:
      <shard_key>_J<J>_last_ckpt_round_<round_num>.pth
        dict with 'round', 'model_state', 'optimizer_state' and 'rng_state'
        (the torch CPU RNG state at save time)
    and if is_best=True, also
      <shard_key>_J<J>_best_ckpt.pth
        the model state_dict only.

    `ckpt_dir` is created if it does not exist yet, so the function can be
    called before any directory setup has run.
    """
    # Saving into a missing directory would otherwise fail.
    os.makedirs(ckpt_dir, exist_ok=True)

    prefix = f"{shard_key}_J{J}"
    state = {
        'round': round_num,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'rng_state': torch.get_rng_state(),
    }
    # unique filename per config and round
    last_name = f"{prefix}_last_ckpt_round_{round_num}.pth"
    torch.save(state, os.path.join(ckpt_dir, last_name))
    if is_best:
        best_name = f"{prefix}_best_ckpt.pth"
        torch.save(model.state_dict(), os.path.join(ckpt_dir, best_name))
    print(f"[Checkpoint] Saved {last_name}")


In [None]:
# ── C. Data Preparation ─────────────────────────────────────────────────────────
# Transforms (as before)…
from torch.utils.data import Subset  # needed below, before the sharding helper's import runs

# Download full CIFAR‑100 training set (augmented view, used for client training)
full_train = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transform_train
)
# Same training images without augmentation; used only to serve the validation
# samples so that validation metrics are not perturbed by random crops/flips.
full_train_eval = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=False, transform=transform_test
)

# 1) Centralized validation split
val_size   = 5000
train_size = len(full_train) - val_size
train_dataset, val_split = random_split(
    full_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(seed)
)
# random_split returns Subsets whose `.indices` are positions in `full_train`;
# re-index them into the un-augmented copy (same image order, different transform).
val_dataset = Subset(full_train_eval, val_split.indices)

# ── C.1 Build validation loader ───────────────────────────────
BS_VAL = 256
val_loader = DataLoader(
    val_dataset,
    batch_size=BS_VAL,
    shuffle=False,
    num_workers=2
)

# ── C.2 Non-IID Sharding Helper ────────────────────────────────────────────────

from collections import defaultdict
from torch.utils.data import Subset

def create_labelwise_shards(dataset, K, Nc, seed=42):
    """
    Split `dataset` into K disjoint, equally sized label-skewed shards.

    Each client receives len(dataset) // K samples drawn from Nc labels where
    the label pools allow it (any len(dataset) % K leftover samples are
    unused). Per client, the Nc labels with the most unassigned samples are
    chosen (ties broken randomly) and the client's quota is spread as evenly
    as possible over them. Only if those labels run dry is the shard topped
    up from further labels, so a shard can exceed Nc labels when label counts
    do not divide evenly. Every assigned sample is consumed, so no sample
    appears in two shards.

    Args:
        dataset: indexable dataset yielding (sample, label) pairs; labels must
            be hashable.
        K: number of shards (clients), >= 1.
        Nc: target number of distinct labels per shard, >= 1. Values larger
            than the number of labels simply use every label.
        seed: seed of the private `random.Random` used for all shuffling.

    Returns:
        List of K `Subset` objects over `dataset`.

    Raises:
        ValueError: if K or Nc is smaller than 1.
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K!r}")
    if Nc < 1:
        raise ValueError(f"Nc must be >= 1, got {Nc!r}")

    # 1) Group indices by label
    label2idx = defaultdict(list)
    for idx, (_, lbl) in enumerate(dataset):
        label2idx[lbl].append(idx)

    # 2) Shuffle each label's pool so consecutive slices are random samples
    rng = random.Random(seed)
    for lbl in label2idx:
        rng.shuffle(label2idx[lbl])

    # 3) One cursor per label: everything before it is already assigned
    pointers = {lbl: 0 for lbl in label2idx}

    def remaining(lbl):
        return len(label2idx[lbl]) - pointers[lbl]

    def take(lbl, count):
        # Hand out the next `count` indices of this label and advance its cursor.
        start = pointers[lbl]
        pointers[lbl] = start + count
        return label2idx[lbl][start:start + count]

    # 4) Build shards
    samples_per_client = len(dataset) // K
    shards_idx = []

    for _ in range(K):
        client_idxs = []

        # Labels that still have samples, largest pool first. Shuffling before
        # the (stable) sort randomises the order among equally sized pools.
        candidates = [lbl for lbl in label2idx if remaining(lbl) > 0]
        rng.shuffle(candidates)
        candidates.sort(key=remaining, reverse=True)
        chosen, spare = candidates[:Nc], candidates[Nc:]

        # Visit the chosen labels smallest-pool first: if a pool cannot cover
        # its share, the shortfall is re-spread over the larger pools after it.
        for slot, lbl in enumerate(reversed(chosen)):
            needed = samples_per_client - len(client_idxs)
            slots_left = len(chosen) - slot
            quota = -(-needed // slots_left)  # ceil(needed / slots_left)
            client_idxs.extend(take(lbl, min(quota, remaining(lbl))))

        # Top up from other labels only when the chosen ones are exhausted.
        for lbl in spare:
            needed = samples_per_client - len(client_idxs)
            if needed <= 0:
                break
            client_idxs.extend(take(lbl, min(needed, remaining(lbl))))

        shards_idx.append(client_idxs)

    return [Subset(dataset, idxs) for idxs in shards_idx]



# ── C.3 Build All Shardings ────────────────────────────────────────────────────
# K is the number of client shards; `base` is the per-client size for the IID split.
K       = 100
base  = train_size // K

# The last shard absorbs the remainder so the sizes sum to train_size.
sizes   = [base] * (K - 1) + [train_size - base * (K - 1)]
iid_shards = random_split(
    train_dataset, sizes,
    generator=torch.Generator().manual_seed(seed)
)

# Non-IID for Nc in {1,5,10,50}
Nc_list      = [1, 5, 10, 50]
shardings    = {'iid': iid_shards}
for Nc in Nc_list:
    shardings[f'non_iid_{Nc}'] = create_labelwise_shards(
        train_dataset, K=K, Nc=Nc, seed=seed
    )

# Now `shardings` is a dict mapping:
#   'iid'          → list of 100 Subsets (IID)
#   'non_iid_<Nc>' → list of 100 Subsets returned by create_labelwise_shards
#                    for Nc in Nc_list
# NOTE(review): the number of labels per shard is whatever
# create_labelwise_shards produces for the given Nc — confirm it honours Nc.

# ── C.4 Per-Client DataLoaders ────────────────────────────────────────────────
# (You can build these on-the-fly inside your training loops;
#  or precompute for each sharding if memory allows.)




100%|██████████| 169M/169M [00:06<00:00, 24.3MB/s]


In [None]:
# ── D. Model Definition ─────────────────────────────────────────────────────────
class LELeNetCIFAR(nn.Module):
    """
    LeNet-style CNN with 100 output logits.

    Two (5x5 conv → ReLU → 2x2 max-pool) stages with 64 channels each are
    followed by three fully connected layers (384 → 192 → 100). `fc1` is sized
    for a 64x8x8 feature map, i.e. 3x32x32 inputs after the two poolings.
    """

    def __init__(self):
        super().__init__()
        # Layers are declared in forward order; attribute names define the
        # state_dict keys.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=64 * 8 * 8, out_features=384)
        self.fc2 = nn.Linear(in_features=384, out_features=192)
        self.fc3 = nn.Linear(in_features=192, out_features=100)

    def forward(self, x):
        """Map a batch of images to a (batch, 100) tensor of logits."""
        stage1 = self.pool1(F.relu(self.conv1(x)))
        stage2 = self.pool2(F.relu(self.conv2(stage1)))
        flat = stage2.flatten(start_dim=1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# ── E. Utilities: Train/Eval & Checkpointing ────────────────────────────────────
def train_one_epoch(model, optimizer, criterion, loader):
    """
    Train `model` for one pass over `loader`.

    Batches are moved to the global `device`. Returns a tuple
    (sample-weighted mean loss, accuracy) computed over all samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the final average is
        # per-sample even when the last batch is smaller.
        loss_sum += batch_loss.item() * batch_x.size(0)
        n_correct += (logits.argmax(1) == batch_y).sum().item()
        n_seen += batch_y.size(0)
    return loss_sum / n_seen, n_correct / n_seen

def eval_model(model, criterion, loader):
    """
    Evaluate `model` on `loader` without tracking gradients.

    Batches are moved to the global `device`. Returns a tuple
    (sample-weighted mean loss, accuracy) over all samples in the loader.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            # Weight by batch size so the result is a per-sample average.
            loss_sum += batch_loss.item() * batch_x.size(0)
            n_correct += (logits.argmax(1) == batch_y).sum().item()
            n_seen += batch_y.size(0)
    return loss_sum / n_seen, n_correct / n_seen


def sample_clients_dirichlet(K, m, gamma, rng):
    """
    Draw m distinct client indices out of range(K).

    The per-client selection probabilities are
      • uniform (1/K each) when gamma == 'uniform'
      • one draw from Dirichlet([gamma]*K) otherwise.
    `rng` is the generator used for both the Dirichlet draw and the selection.

    Returns:
      selected: list of m client indices (sampled without replacement)
      p:        length-K numpy array of the sampling probabilities used
    """
    if gamma != 'uniform':
        probs = rng.dirichlet(np.ones(K) * gamma)
    else:
        probs = np.ones(K) / K
    picks = rng.choice(K, size=m, replace=False, p=probs)
    return picks.tolist(), probs



In [None]:
# ── Configuration Summary & Utilities for FedAvg ──────────────────────────────

import os, sys, platform, time
import torch
import numpy as np
import random
from torch.utils.tensorboard import SummaryWriter



# 3) Hyperparameters
K, C = 100, 0.1            # K: number of client shards; C: fraction of clients sampled per round (m = int(C*K))
BS, BS_VAL = 128, 256      # batch sizes for client loaders / validation and test loaders
LR, WD = 0.01, 1e-4        # SGD learning rate and weight decay
J0, R0   = 4, 2000         # reference local epochs and rounds
budget   = J0 * R0         # the training loop runs budget // J rounds for each J
J_list   = [4, 8, 16]      # local-epoch settings iterated by the training loop



# 1) Define and instantiate your TensorBoard writer
# The log directory name encodes learning rate, weight decay and batch size.
log_dir = f"./logs/FedAvg_lr{LR}_wd{WD}_bs{BS}"
tb_writer = SummaryWriter(log_dir=log_dir)

# 2) Summary utility
def summarize_run(cfg, client_loaders, test_loader, writer=None):
    """
    Print a summary of a FedAvg run and optionally log it to TensorBoard.

    `cfg` must contain the keys 'lr', 'weight_decay', 'batch_size', 'K', 'C',
    'J' and 'ROUNDS'. The shard size reported is the dataset length of the
    first client loader. When `writer` is truthy, every hyperparameter is
    also written with `writer.add_text` under 'RunInfo/<key>' at step 0.
    """
    reported_keys = ('lr', 'weight_decay', 'batch_size', 'K', 'C', 'J', 'ROUNDS')

    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    print(f"\n========== FEDAVG RUN SUMMARY ({stamp}) ==========")

    # Hyperparameters
    for name in reported_keys:
        print(f"    • {name}: {cfg[name]}")

    # Data info
    n_clients = len(client_loaders)
    first_shard_len = len(client_loaders[0].dataset)
    n_test = len(test_loader.dataset)
    print(f"    • clients (K): {n_clients}, shard size: {first_shard_len}")
    print(f"    • test samples: {n_test}, batch size: {cfg['batch_size']}")

    # Log to TensorBoard
    if not writer:
        return
    for name in reported_keys:
        writer.add_text(f'RunInfo/{name}', str(cfg[name]), 0)

# 3) Checkpoint utility
# Local checkpoint folder, created if missing.
# NOTE(review): the training loop passes CKPT_DIR (the Drive folder), not this
# lowercase `ckpt_dir` — confirm which one is intended.
ckpt_dir = './checkpoints'
os.makedirs(ckpt_dir, exist_ok=True)


# ── Example Usage ────────────────────────────────────────────────────────────────

# Run configuration. This dict stores the reference value under 'J0' and has no
# 'J' key, while summarize_run reads cfg['J']; the training loop adds 'J' (and
# overwrites 'ROUNDS') via cfg.update before calling summarize_run.
cfg = {
    'lr':           LR,
    'weight_decay': WD,
    'batch_size':   BS,
    'K':            K,
    'C':            C,
    'J0':            J0,
    'ROUNDS':       R0

}



In [None]:

# ── A. Mount Google Drive ─────────────────────────────────────────────────────
from google.colab import drive
import os

# Mount Google Drive into the Colab filesystem (requires the google.colab runtime).
drive.mount('/content/drive')
# Point to a folder inside your Drive for persistent checkpoints
CKPT_DIR = '/content/drive/MyDrive/fl_checkpoints'
os.makedirs(CKPT_DIR, exist_ok=True)




# Set this to True to resume from the last checkpoint; False to start from scratch
# (passed as `resume` to load_checkpoint by the training loop).
RESUME = True



# 3) Hyperparameters (same values as in the configuration cell)
K, C = 100, 0.1            # K: number of client shards; C: fraction of clients sampled per round
BS, BS_VAL = 128, 256      # batch sizes for client loaders / validation and test loaders
LR, WD = 0.01, 1e-4        # SGD learning rate and weight decay
J0, R0   = 4, 2000         # reference local epochs and rounds
budget   = J0 * R0         # the training loop runs budget // J rounds for each J
J_list   = [4, 8, 16]      # local-epoch settings iterated by the training loop

# 4) Transforms & Data
from torch.utils.data import Subset  # used to re-index the validation split

# Training pipeline: padded random crop + horizontal flip, then normalisation.
transform_train = T.Compose([
    T.RandomCrop(32, 4), T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761)),
])
# Evaluation pipeline: no augmentation, same normalisation constants.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5071,0.4867,0.4408),(0.2675,0.2565,0.2761)),
])

# Augmented view of the training images (used for client training).
full_train = torchvision.datasets.CIFAR100(
    './data', train=True, download=True, transform=transform_train
)
# Un-augmented view of the same images; serves the validation samples so that
# validation metrics are not perturbed by random crops/flips.
full_train_eval = torchvision.datasets.CIFAR100(
    './data', train=True, download=False, transform=transform_test
)
val_size   = 5000
train_size = len(full_train) - val_size
train_dataset, val_split = random_split(
    full_train, [train_size, val_size],
    generator=torch.Generator().manual_seed(seed)
)
# random_split returns Subsets whose `.indices` are positions in `full_train`;
# re-index them into the un-augmented copy (same image order, different transform).
val_dataset = Subset(full_train_eval, val_split.indices)
val_loader = DataLoader(val_dataset, batch_size=BS_VAL, shuffle=False, num_workers=2)

test_dataset = torchvision.datasets.CIFAR100(
    './data', train=False, download=True, transform=transform_test
)
test_loader = DataLoader(test_dataset, batch_size=BS_VAL, shuffle=False, num_workers=2)



# ── Instantiate TensorBoard writer ──────────────────────────────────────────────
from torch.utils.tensorboard import SummaryWriter
# Rebinds `tb_writer` to a new SummaryWriter on the same log directory string
# as the writer created in the configuration cell.
log_dir   = f"./logs/FedAvg_lr{LR}_wd{WD}_bs{BS}"
tb_writer = SummaryWriter(log_dir=log_dir)

# ── B. CSV Logging Setup ────────────────────────────────────────────────────────
import csv
import os

csv_path = './fedavg_results.csv'

# Create the results file with a header only if it does not exist yet, so that
# resumed runs keep appending to the same file.
# The header lists one name per field of the rows the training loop appends:
# shard key, J, round, round time, local accuracy mean/std, validation
# loss/accuracy, test loss/accuracy.
if not os.path.exists(csv_path):
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            'shard', 'J', 'round',
            'round_time',
            'local_acc_mean', 'local_acc_std',
            'val_loss', 'val_acc',
            'test_loss', 'test_acc'
        ])

# Before your FedAvg loop: Instantiate the global model, loss, and client loaders once


# 6) Build all shardings
# IID: K random shards of `base` samples; the last shard absorbs the remainder.
base = train_size // K
sizes = [base]*(K-1) + [train_size-base*(K-1)]
iid_shards = random_split(train_dataset, sizes,
                          generator=torch.Generator().manual_seed(seed))

# Non-IID: one sharding per Nc value, keyed 'non_iid_<Nc>'.
# NOTE(review): the label skew of each sharding is whatever
# create_labelwise_shards produces for the given Nc — confirm it honours Nc.
Nc_list   = [1, 5, 10, 50]
shardings = {'iid': iid_shards}
for Nc in Nc_list:
    shardings[f'non_iid_{Nc}'] = create_labelwise_shards(
        train_dataset, K, Nc, seed
    )

# Base run configuration; the training loop adds 'J' and 'ROUNDS' via
# cfg.update before summarize_run reads them.
cfg = {'lr': LR, 'weight_decay': WD, 'batch_size': BS, 'K': K, 'C': C}





# NumPy generator passed to sample_clients_dirichlet for client selection.
rng = np.random.default_rng(seed)





# ── F. FedAvg Experiment Loop with Drive Checkpoints ──────────────────────────
# Assumes `shardings`, `J_list`, `budget`, `load_checkpoint`, `save_checkpoint`, etc. are defined
import time  # wall-clock timing of each round

all_results = {}

# Number of clients sampled per round (at least one); constant across rounds.
m = max(1, int(C * K))

for shard_key, shards in shardings.items():
    client_loaders = [
        DataLoader(shards[i], batch_size=BS, shuffle=True, num_workers=2)
        for i in range(K)
    ]

    for J in J_list:
        # Fixed budget of local epochs: more local epochs per round → fewer rounds.
        ROUNDS_scaled = budget // J
        cfg.update({'J': J, 'ROUNDS': ROUNDS_scaled})
        summarize_run(cfg, client_loaders, test_loader, writer=tb_writer)

        # Initialize model & optimizer, then resume if any.
        # `optimizer` is never stepped here (each client uses its own `opt_c`);
        # it exists so the checkpoint helpers have an optimizer state to
        # save and restore.
        model     = LELeNetCIFAR().to(device)
        optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=WD)
        criterion = nn.CrossEntropyLoss()
        start_round = load_checkpoint(
            model, optimizer,
            ckpt_dir=CKPT_DIR,
            shard_key=shard_key, J=J,
            resume=RESUME
        )
        # NOTE(review): checkpoints restore the torch RNG only; the NumPy `rng`
        # used for client sampling and `acc_hist` are not restored on resume.

        # Per-round validation accuracy of the global model for this config.
        acc_hist = []
        print(f"[{shard_key}|J={J}] Starting from round {start_round}/{ROUNDS_scaled}")

        for rnd in range(start_round, ROUNDS_scaled + 1):
            # 1) Start the round timer before any work is done
            round_start = time.time()

            # 2) Client sampling
            selected, _ = sample_clients_dirichlet(K, m, gamma='uniform', rng=rng)

            # 3) Local updates: every selected client trains a private copy of
            #    the global model for J epochs with a fresh optimizer.
            local_accuracies = []
            local_states, client_sizes = [], []
            for cid in selected:
                cm    = copy.deepcopy(model)
                opt_c = optim.SGD(cm.parameters(), lr=LR, momentum=0.9, weight_decay=WD)
                for _ in range(J):
                    train_loss, train_acc = train_one_epoch(cm, opt_c, criterion, client_loaders[cid])
                local_states.append(cm.state_dict())
                client_sizes.append(len(shards[cid]))
                # Training accuracy of the client's last local epoch
                local_accuracies.append(train_acc)

            # Compute local mean & std
            local_mean = np.mean(local_accuracies)
            local_std  = np.std(local_accuracies)

            # 4) Aggregate: average of client states weighted by shard size
            total = sum(client_sizes)
            weights = [n / total for n in client_sizes]
            new_st = {
                k: sum(w * state[k] for w, state in zip(weights, local_states))
                for k in model.state_dict()
            }
            model.load_state_dict(new_st)

            # 5) Evaluation (once per round)
            val_loss, val_acc   = eval_model(model, criterion, val_loader)
            test_loss, test_acc = eval_model(model, criterion, test_loader)
            acc_hist.append(val_acc)

            # 6) Timing
            round_time = time.time() - round_start

            # 7) CSV Logging
            with open(csv_path, 'a', newline='') as f:
                csv.writer(f).writerow([
                    shard_key, J, rnd,
                    f"{round_time:.2f}",
                    f"{local_mean:.4f}", f"{local_std:.4f}",
                    f"{val_loss:.4f}", f"{val_acc:.4f}",
                    f"{test_loss:.4f}", f"{test_acc:.4f}"
                ])

            # 8) TensorBoard Logging
            tb_writer.add_scalar(f"{shard_key}/J{J}_round_time", round_time, rnd)
            tb_writer.add_scalar(f"{shard_key}/J{J}_local_acc_mean", local_mean, rnd)
            tb_writer.add_scalar(f"{shard_key}/J{J}_local_acc_std",  local_std,  rnd)
            tb_writer.add_scalar(f"{shard_key}/J{J}_val_loss",       val_loss,    rnd)
            tb_writer.add_scalar(f"{shard_key}/J{J}_val_acc",        val_acc,     rnd)
            tb_writer.add_scalar(f"{shard_key}/J{J}_test_loss",      test_loss,   rnd)
            tb_writer.add_scalar(f"{shard_key}/J{J}_test_acc",       test_acc,    rnd)

            # 9) Print per-round summary
            print(
                f"[{shard_key} | J={J}] Round {rnd}/{ROUNDS_scaled} "
                f"Time={round_time:.1f}s | "
                f"Local Acc={local_mean:.3f}±{local_std:.3f} | "
                f"Val Acc={val_acc:.3f} | Test Acc={test_acc:.3f}"
            )

            # Checkpoint every 20 rounds (or at first)
            if rnd == start_round or rnd % 20 == 0:
                # Log a header with the current config and round
                print(f"[{shard_key} | J={J}] Checkpointing at round {rnd}/{ROUNDS_scaled}")

                # Print the latest metrics
                print(f"    → Last Val Acc: {val_acc:.4f} | Last Val Loss: {val_loss:.4f}")
                print(f"    → Last Test Acc: {test_acc:.4f} | Last Test Loss: {test_loss:.4f}")

                # Save the checkpoint
                save_checkpoint(
                    model, optimizer,
                    round_num=rnd,
                    ckpt_dir=CKPT_DIR,
                    shard_key=shard_key, J=J,
                    is_best=False
                )
                print(f"[{shard_key} | J={J}] Checkpoint saved.\n")

        all_results[(shard_key, J)] = np.array(acc_hist)

# Make sure all buffered TensorBoard events reach disk.
tb_writer.flush()



MessageError: Error: credential propagation was unsuccessful

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# 1) Load results
# Each row appended by the training loop has ten fields, in this order.
# `header=0` discards the first line of the file (its header row) and the
# explicit `names` label the columns.
csv_columns = [
    'shard', 'J', 'round',
    'round_time',
    'local_acc_mean', 'local_acc_std',
    'val_loss', 'val_acc',
    'test_loss', 'test_acc',
]
df = pd.read_csv('./fedavg_results.csv', names=csv_columns, header=0)

# 2) Pivot final validation accuracy per (shard,J)
# `.last()` takes the last logged row of every (shard, J) group.
final = df.groupby(['shard','J']).last().reset_index()
pivot = final.pivot(index='shard', columns='J', values='val_acc')
print("Final Validation Accuracy:")
display(pivot)

# 3) Plot convergence curves for each (shard, J)
plt.figure(figsize=(10,6))
for (shard, J), group in df.groupby(['shard','J']):
    plt.plot(group['round'], group['val_acc'], label=f"{shard}, J={J}")
plt.xlabel('Federated Round')
plt.ylabel('Validation Accuracy')
plt.title('FedAvg Convergence: IID vs Non-IID, varying J')
plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
plt.tight_layout()
plt.show()

# 4) (Optional) Rolling average smoothing over a 50-round window
plt.figure(figsize=(10,6))
for (shard, J), group in df.groupby(['shard','J']):
    sm = group['val_acc'].rolling(window=50, min_periods=1).mean()
    plt.plot(group['round'], sm, label=f"{shard}, J={J}")
plt.xlabel('Federated Round')
plt.ylabel('Smoothed Validation Acc')
plt.title('Smoothed Convergence Curves')
plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
plt.tight_layout()
plt.show()

# 5) Plot data-distribution boxplots (to verify your non-IID splits)
from collections import Counter

for key, shards in shardings.items():
    # Count labels per client. Iterating a Subset yields (image, label)
    # samples rather than indices, so the label is read from each sample.
    # This loads every image once per sharding, which is slow but correct.
    counts = [Counter(lbl for _, lbl in subset) for subset in shards]
    # One list per class holding that class's count on every client
    # (Counter returns 0 for classes a client does not have).
    per_class_counts = [[c[lbl] for c in counts] for lbl in range(100)]
    plt.figure(figsize=(8,4))
    plt.boxplot(per_class_counts, whis=(5,95), showfliers=False)
    plt.title(f"Label counts per client — {key}")
    plt.xlabel("Class label")
    plt.ylabel("Examples per client")
    plt.tight_layout()
    plt.show()


In [None]:
# WARNING!
# Be careful when running this cell.



# Utility function that deletes every file from the checkpoint folder

def clear_checkpoints(ckpt_dir):
    """
    Delete every regular file located directly inside `ckpt_dir`.

    Sub-directories and their contents are left untouched. Prints how many
    files were removed.
    """
    n_removed = 0
    for entry in os.listdir(ckpt_dir):
        target = os.path.join(ckpt_dir, entry)
        if not os.path.isfile(target):
            continue  # skip sub-directories
        os.remove(target)
        n_removed += 1
    print(f"[Checkpoint] Cleared {n_removed} files from {ckpt_dir}")


# Destructive: running this cell immediately deletes every file in CKPT_DIR.
clear_checkpoints(CKPT_DIR)
