In [1]:
# Environment setup: load the cluster's PyTorch module and install torchvision
# into the user site-packages.
# NOTE(review): IPython runs each `!` line in its own subshell, so the
# `module load` below presumably does not alter the environment of the
# already-running kernel — confirm, and load the module before launching
# Jupyter if it is actually required.
!module load ngc pytorch
# NOTE(review): the install is unpinned and goes through `!pip`; consider
# `%pip install` with a pinned version so it targets the kernel's environment.
!pip install --user torchvision



In [2]:
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict

# ============================================================
# 1. Output directory (created under the user's home folder)
# ============================================================
# Every CSV and figure written by this notebook is placed below `save_dir`.
save_dir = os.path.expanduser(os.path.join("~", "FL_results_alpha_5avged"))
os.makedirs(save_dir, exist_ok=True)


# ============================================================
# 2. CIFAR-10 Dataset (Non-IID Dirichlet Split, instructor’s alpha)
# ============================================================
def dirichlet_split_noniid(dataset, num_clients, alpha, seed=42):
    """Partition the sample indices of `dataset` across `num_clients` clients.

    The `alpha` convention used here is inverted relative to the Dirichlet
    concentration parameter: `alpha == 0` yields an IID split (a random
    permutation cut into near-equal parts), while for `alpha > 0` every class
    is distributed with client proportions drawn from
    Dirichlet(1/alpha, ..., 1/alpha), so a larger `alpha` produces a more
    skewed split. The concentration is clipped to [1e-3, 1e3].

    Args:
        dataset: Object supporting `len()`; labels are read from `.targets`,
            then `.labels`, and otherwise from `dataset[i][1]`.
        num_clients: Number of partitions to create (must be >= 1).
        alpha: Non-negative heterogeneity parameter described above.
        seed: Seed of the private random generator used for the split.

    Returns:
        dict mapping each client id in `range(num_clients)` to a NumPy array
        of sample indices.

    Raises:
        ValueError: If `num_clients < 1` or `alpha < 0`.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if alpha < 0:
        raise ValueError(f"alpha must be >= 0, got {alpha}")

    # A private legacy generator yields the same draws as `np.random.seed(seed)`
    # followed by the module-level functions, without reseeding the global
    # NumPy RNG that other code may depend on.
    rng = np.random.RandomState(seed)
    n_samples = len(dataset)

    # ----- Special case: alpha = 0 => IID -----
    if alpha == 0:
        all_indices = rng.permutation(n_samples)
        splits = np.array_split(all_indices, num_clients)
        return {cid: splits[cid] for cid in range(num_clients)}

    # ----- alpha > 0 => Dirichlet-based non-IID -----
    dirichlet_conc = float(np.clip(1.0 / alpha, 1e-3, 1e3))

    if hasattr(dataset, "targets"):
        labels = np.array(dataset.targets)
    elif hasattr(dataset, "labels"):
        labels = np.array(dataset.labels)
    else:
        labels = np.array([dataset[i][1] for i in range(n_samples)])

    # Nothing to distribute: hand every client an empty index array instead of
    # failing on `labels.max()`.
    if labels.size == 0:
        return {cid: np.array([], dtype=int) for cid in range(num_clients)}

    num_classes = int(labels.max()) + 1
    client_indices = defaultdict(list)

    for c in range(num_classes):
        class_idx = np.where(labels == c)[0]
        if len(class_idx) == 0:
            continue
        rng.shuffle(class_idx)
        proportions = rng.dirichlet(dirichlet_conc * np.ones(num_clients))
        # Cumulative proportions become cut positions inside this class; the
        # last cut is dropped so the final client receives the remainder.
        split_points = (np.cumsum(proportions) * len(class_idx)).astype(int)
        class_split = np.split(class_idx, split_points[:-1])

        for client_id, idx in enumerate(class_split):
            client_indices[client_id].extend(idx.tolist())

    # Shuffle within each client so samples are not grouped by class, and
    # return a plain dict so both branches have the same return type.
    result = {}
    for client_id in range(num_clients):
        idx = np.array(client_indices[client_id], dtype=int)
        rng.shuffle(idx)
        result[client_id] = idx

    return result

def get_cifar10_dirichlet_clients(data_root="./data", num_clients=10, alpha=1.0, batch_size=64, seed=42):
    """Create the CIFAR-10 training set and one shuffled DataLoader per client.

    Args:
        data_root: Directory where CIFAR-10 is stored (downloaded if missing).
        num_clients: Number of clients to split the training set across.
        alpha: Heterogeneity parameter forwarded to `dirichlet_split_noniid`.
        batch_size: Batch size of every client loader.
        seed: Seed forwarded to `dirichlet_split_noniid`.

    Returns:
        Tuple `(trainset, client_idcs, client_loaders)`: the augmented training
        dataset, the client-id -> indices mapping produced by the split, and a
        dict mapping each client id to its DataLoader.
    """
    # Values handed to Normalize as per-channel mean / std.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    augmentation_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    transform_train = transforms.Compose(augmentation_steps)

    trainset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
    client_idcs = dirichlet_split_noniid(trainset, num_clients, alpha, seed)

    # One loader per client, each viewing only that client's indices.
    client_loaders = {
        cid: DataLoader(
            Subset(trainset, idxs),
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
        )
        for cid, idxs in client_idcs.items()
    }

    return trainset, client_idcs, client_loaders

# ============================================================
# 3. Extract the 18 weight layers of ResNet-18
# ============================================================
def get_resnet18_layers(model):
    """Return up to 18 `(name, module)` pairs for the weight layers of `model`.

    Only `nn.Conv2d` and `nn.Linear` modules are kept. Taking the first 18 leaf
    modules instead would also pick up parameter-free modules (activations,
    pooling), whose parameter distance is trivially zero, and would stop before
    reaching the deeper layers.

    Modules whose qualified name contains "downsample" are skipped.
    NOTE(review): this assumes torchvision's ResNet naming, where the shortcut
    projection convolutions live under a `downsample` submodule; excluding
    them should leave 17 convolutions plus the final `fc` on ResNet-18 —
    confirm against the model in use.

    Args:
        model: The network to inspect.

    Returns:
        List of at most 18 `(qualified_name, module)` tuples in
        `named_modules()` order.
    """
    layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, (nn.Conv2d, nn.Linear)) and "downsample" not in name
    ]
    return layers[:18]

# ============================================================
# 4. Euclidean distance between layers
# ============================================================
def layer_distance(layer_global, layer_client):
    """Euclidean (L2) distance between the parameters of two matching layers.

    All parameter tensors of the layer (e.g. weight and bias) are treated as
    one flattened vector: squared differences are accumulated over every
    tensor and the square root is taken once at the end. Summing the
    per-tensor norms instead would overestimate the distance whenever a layer
    has more than one parameter tensor.

    Args:
        layer_global: Module from the global model.
        layer_client: The corresponding module from a client model.

    Returns:
        The distance as a Python float; 0.0 for modules without parameters.
    """
    squared_sum = 0.0
    with torch.no_grad():
        for p_global, p_client in zip(layer_global.parameters(), layer_client.parameters()):
            # Compare on the global tensor's device, in float64 so the
            # accumulation does not lose precision.
            diff = p_global.detach().double() - p_client.detach().to(p_global.device).double()
            squared_sum += torch.sum(diff * diff).item()
    return squared_sum ** 0.5


# ============================================================
# 5. Local training with BatchNorm-safe handling
# ============================================================
def train_local(model, loader, epochs=2, lr=0.01, min_batch_for_bn=2, device=None):
    """Train a copy of `model` on one client's data with plain SGD.

    The input model is never modified; a deep copy is moved to `device`,
    trained with cross-entropy loss and returned.

    Args:
        model: PyTorch model to start from.
        loader: Iterable of `(inputs, targets)` batches for the local client.
        epochs: Number of passes over `loader`.
        lr: SGD learning rate.
        min_batch_for_bn: Batches smaller than this are forwarded in eval mode
            so that BatchNorm layers use running statistics instead of batch
            statistics.
        device: Device to train on. Defaults to CUDA when available and CPU
            otherwise, so the function also works on machines without a GPU.

    Returns:
        The trained copy of the model, left in training mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = copy.deepcopy(model).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()

    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()

            # Batches too small for BatchNorm's batch statistics are forwarded
            # in eval mode; gradients are still computed for them.
            too_small_for_bn = x.size(0) < min_batch_for_bn
            if too_small_for_bn:
                model.eval()
            try:
                # Enable autograd for every batch, not only the small ones, so
                # both paths behave the same under an outer no_grad().
                with torch.enable_grad():
                    outputs = model(x)
                    loss = criterion(outputs, y)
            finally:
                if too_small_for_bn:
                    model.train()

            loss.backward()
            optimizer.step()

    return model


# ============================================================
# 6. Load CIFAR-10 clients
# ============================================================
# Seed torch before building the loaders and the model so that weight
# initialisation starts from a known state; the same seed drives the split.
# NOTE(review): GPU kernels may still introduce run-to-run nondeterminism.
SEED = 42
torch.manual_seed(SEED)

# Heterogeneity of the client split as interpreted by dirichlet_split_noniid:
# 0 gives an IID split; otherwise the Dirichlet concentration is 1/alpha,
# so larger values give more skewed clients.
alpha = 5
num_clients = 10
train_dataset, client_idcs, client_loaders = get_cifar10_dirichlet_clients(
    data_root="./data", num_clients=num_clients, alpha=alpha, batch_size=64, seed=SEED
)

# ============================================================
# 7. Initialize global and client models
# ============================================================
global_model = models.resnet18(num_classes=10).cuda()
# Every client starts from an identical copy of the global weights.
client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]
# The (name, module) pairs refer to modules of `global_model` itself.
res_layers = get_resnet18_layers(global_model)
num_layers = len(res_layers)
print(f"Using {num_layers} layers from ResNet-18")


# ============================================================
# 8. Federated training loop
# ============================================================

# One row per (round, layer, client): [round, layer_id, client_id, distance]
records = []
global_rounds = 100
local_epochs = 1

for r in range(global_rounds):
    print(f"Global Round {r+1}/{global_rounds}")
    updated_clients = []

    for i in range(num_clients):
        # Train local model (train_local works on a copy and returns it)
        updated_model = train_local(client_models[i], client_loaders[i], epochs=local_epochs)

        # Compute layer-wise distances BEFORE FedAvg.
        # `res_layers` already holds the global model's modules, and the
        # aggregation below updates them in place, so only the client-side
        # lookup table has to be rebuilt — once per client, not per layer.
        client_modules = dict(updated_model.named_modules())
        for layer_id, (layer_name, global_layer) in enumerate(res_layers):
            dist = layer_distance(global_layer, client_modules[layer_name])
            records.append([r, layer_id, i, dist])

        updated_clients.append(updated_model)

    # Perform FedAvg aggregation over the full state dict, so that buffers
    # (e.g. BatchNorm running statistics) are aggregated together with the
    # trainable parameters instead of staying at their initial values.
    # NOTE(review): clients are weighted equally regardless of how many samples
    # they hold; confirm that uniform averaging is intended.
    with torch.no_grad():
        client_states = [m.state_dict() for m in updated_clients]
        averaged_state = {}
        for key, global_tensor in global_model.state_dict().items():
            stacked = torch.stack([state[key] for state in client_states], dim=0)
            if global_tensor.is_floating_point():
                averaged_state[key] = stacked.mean(dim=0)
            else:
                # Integer buffers (counters) cannot be averaged meaningfully;
                # keep the largest value seen across clients.
                averaged_state[key] = stacked.max(dim=0).values
        # load_state_dict copies into the existing tensors, so module objects
        # referenced by `res_layers` stay valid.
        global_model.load_state_dict(averaged_state)

    # Update all client models to the new global model
    client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]


# ============================================================
# 9. Save CSV
# ============================================================
# Persist every recorded distance so the figures can be rebuilt without
# retraining.
column_names = ["round", "layer_id", "client_id", "distance"]
df = pd.DataFrame(records, columns=column_names)
csv_path = f"{save_dir}/raw_distances.csv"
df.to_csv(csv_path, index=False)
print("Saved all raw distances to:", csv_path)

# ============================================================
# 10. Generate one plot per layer (one curve per client)
# ============================================================
for layer_id in range(num_layers):
    fig, ax = plt.subplots(figsize=(7, 5))

    # Narrow down to this layer once, then slice per client.
    layer_rows = df[df["layer_id"] == layer_id]
    for c in range(num_clients):
        client_rows = layer_rows[layer_rows["client_id"] == c]
        ax.plot(client_rows["round"], client_rows["distance"], label=f"Client {c}")

    ax.set_title(f"Layer {layer_id} Euclidean Distance")
    ax.set_xlabel("Global Round")
    ax.set_ylabel("Distance")
    ax.grid(True)
    ax.legend()

    plot_path = f"{save_dir}/layer_{layer_id}.png"
    fig.savefig(plot_path, dpi=300)
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

print("All plots saved in:", save_dir)


Using 18 layers from ResNet-18
Global Round 1/100
Global Round 2/100
Global Round 3/100
Global Round 4/100
Global Round 5/100
Global Round 6/100
Global Round 7/100
Global Round 8/100
Global Round 9/100
Global Round 10/100
Global Round 11/100
Global Round 12/100
Global Round 13/100
Global Round 14/100
Global Round 15/100
Global Round 16/100
Global Round 17/100
Global Round 18/100
Global Round 19/100
Global Round 20/100
Global Round 21/100
Global Round 22/100
Global Round 23/100
Global Round 24/100
Global Round 25/100
Global Round 26/100
Global Round 27/100
Global Round 28/100
Global Round 29/100
Global Round 30/100
Global Round 31/100
Global Round 32/100
Global Round 33/100
Global Round 34/100
Global Round 35/100
Global Round 36/100
Global Round 37/100
Global Round 38/100
Global Round 39/100
Global Round 40/100
Global Round 41/100
Global Round 42/100
Global Round 43/100
Global Round 44/100
Global Round 45/100
Global Round 46/100
Global Round 47/100
Global Round 48/100
Global Round 49/10

In [3]:
# labels = np.array(train_dataset.targets)

# for client_id, indices in client_idcs.items():
#     client_labels = labels[indices]
#     unique, counts = np.unique(client_labels, return_counts=True)
#     print(f"Client {client_id}: {dict(zip(unique, counts))}")

Client 0: {0: 1, 1: 2, 2: 12, 3: 381, 4: 87, 5: 282, 6: 925, 7: 700, 9: 1597}
Client 1: {0: 187, 1: 34, 2: 350, 3: 204, 4: 3949, 6: 155, 7: 11, 8: 626, 9: 1}
Client 2: {0: 1, 1: 1168, 2: 120, 3: 143, 4: 781, 6: 152, 8: 1}
Client 3: {0: 1540, 1: 414, 2: 1, 6: 21, 8: 1, 9: 293}
Client 4: {0: 3194, 2: 6, 5: 73, 6: 2250, 8: 27, 9: 173}
Client 5: {0: 10, 1: 170, 2: 2977, 3: 17, 4: 2, 5: 3466, 8: 3128}
Client 6: {0: 47, 2: 1223, 4: 6, 5: 37, 6: 1485, 7: 4234}
Client 7: {0: 6, 1: 3083, 2: 12, 3: 4225, 4: 10, 5: 913, 7: 51, 9: 40}
Client 8: {0: 7, 1: 111, 2: 296, 3: 29, 4: 164, 5: 166, 6: 7, 7: 3, 8: 993, 9: 2258}
Client 9: {0: 7, 1: 18, 2: 3, 3: 1, 4: 1, 5: 63, 6: 5, 7: 1, 8: 224, 9: 638}


Mean plots saved in: /home/mkarunar/FL_results_alpha_5avged/mean_plots


In [5]:
# ============================================================
# 10A. Plot ALL layers in ONE plot (one client excluded)
# ============================================================

# Client whose measurements are left out of the per-round mean.
# NOTE(review): the reason for excluding this client is not recorded here.
EXCLUDED_CLIENT = 8

fig, ax = plt.subplots(figsize=(10, 6))

# Drop the excluded client's rows once instead of re-filtering for every layer.
kept_rows = df[df.client_id != EXCLUDED_CLIENT]

for layer_id in range(num_layers):
    # Mean across the remaining clients, per global round
    mean_d = kept_rows[kept_rows.layer_id == layer_id].groupby("round")["distance"].mean()
    ax.plot(mean_d.index, mean_d.values, label=f"Layer {layer_id}")

# Title, file name and message are derived from `alpha` / EXCLUDED_CLIENT so
# they cannot drift out of sync when either value is changed.
ax.set_title(f"Mean Euclidean Distance Across Global Rounds (α = {alpha})")
ax.set_xlabel("Global Round")
ax.set_ylabel("Mean Euclidean Distance")
ax.grid(True)
ax.legend(ncol=2, fontsize=8)

combined_plot_path = f"{save_dir}/all_layers_one_plot_no_client{EXCLUDED_CLIENT}.png"
fig.savefig(combined_plot_path, dpi=300)
plt.close(fig)

print(f"Saved combined plot (client {EXCLUDED_CLIENT} excluded) to:", combined_plot_path)

Saved combined plot (client 8 excluded) to: /home/mkarunar/FL_results_alpha_5avged/all_layers_one_plot_no_client8.png


In [6]:
# Print, for every client, how many samples of each class it received.
labels = np.array(train_dataset.targets)

for client_id in client_idcs:
    # Labels of the samples assigned to this client
    owned_labels = labels[client_idcs[client_id]]
    classes, frequencies = np.unique(owned_labels, return_counts=True)
    histogram = dict(zip(classes, frequencies))
    print(f"Client {client_id}: {histogram}")

Client 0: {0: 1, 1: 2, 2: 12, 3: 381, 4: 87, 5: 282, 6: 925, 7: 700, 9: 1597}
Client 1: {0: 187, 1: 34, 2: 350, 3: 204, 4: 3949, 6: 155, 7: 11, 8: 626, 9: 1}
Client 2: {0: 1, 1: 1168, 2: 120, 3: 143, 4: 781, 6: 152, 8: 1}
Client 3: {0: 1540, 1: 414, 2: 1, 6: 21, 8: 1, 9: 293}
Client 4: {0: 3194, 2: 6, 5: 73, 6: 2250, 8: 27, 9: 173}
Client 5: {0: 10, 1: 170, 2: 2977, 3: 17, 4: 2, 5: 3466, 8: 3128}
Client 6: {0: 47, 2: 1223, 4: 6, 5: 37, 6: 1485, 7: 4234}
Client 7: {0: 6, 1: 3083, 2: 12, 3: 4225, 4: 10, 5: 913, 7: 51, 9: 40}
Client 8: {0: 7, 1: 111, 2: 296, 3: 29, 4: 164, 5: 166, 6: 7, 7: 3, 8: 993, 9: 2258}
Client 9: {0: 7, 1: 18, 2: 3, 3: 1, 4: 1, 5: 63, 6: 5, 7: 1, 8: 224, 9: 638}


In [15]:
# ============================================================
# 11. Generate averaged plots (ALL clients, ALL layers in ONE plot)
# ============================================================

mean_plot_dir = os.path.join(save_dir, "mean_plots_with all clients")
os.makedirs(mean_plot_dir, exist_ok=True)

fig, ax = plt.subplots(figsize=(9, 6))

for layer_id in range(num_layers):
    # Rows belonging to this layer only
    d_layer = df[df.layer_id == layer_id]

    # Mean distance over clients, per global round
    mean_curve = d_layer.groupby("round")["distance"].mean()

    ax.plot(mean_curve.index, mean_curve.values, linewidth=2, label=f"Layer {layer_id}")

# The title reads the configured `alpha` instead of repeating its value, so it
# stays correct when the experiment is re-run with another setting.
ax.set_title(f"Mean Euclidean Distance Across Global Rounds (α = {alpha})")
ax.set_xlabel("Global Round")
ax.set_ylabel("Mean Euclidean Distance")
ax.grid(True)
ax.legend()

plot_path = f"{mean_plot_dir}/mean_all_layers.png"
fig.savefig(plot_path, dpi=300)
plt.close(fig)

print("Plot saved at:", plot_path)
print("Mean plots saved in:", mean_plot_dir)



Plot saved at: /home/mkarunar/FL_results_alpha_5avged/mean_plots_with all clients/mean_all_layers.png
Mean plots saved in: /home/mkarunar/FL_results_alpha_5avged/mean_plots_with all clients


In [17]:
# Same mean-per-layer figure, redrawn with the legend placed outside the axes.
# It is written to the same path inside `mean_plot_dir`, which must already
# have been defined and created by an earlier cell.
fig, ax = plt.subplots(figsize=(9, 6))

for layer_id in range(num_layers):
    # Rows belonging to this layer only
    d_layer = df[df.layer_id == layer_id]

    # Mean distance over clients, per global round
    mean_curve = d_layer.groupby("round")["distance"].mean()

    ax.plot(mean_curve.index, mean_curve.values, linewidth=2, label=f"Layer {layer_id}")

# The title reads the configured `alpha` instead of repeating its value, so it
# stays correct when the experiment is re-run with another setting.
ax.set_title(f"Mean Euclidean Distance Across Global Rounds (α = {alpha})")
ax.set_xlabel("Global Round")
ax.set_ylabel("Mean Euclidean Distance")
ax.grid(True)

# Legend to the right of the axes, on an opaque white background
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), facecolor="white", framealpha=1)

plot_path = f"{mean_plot_dir}/mean_all_layers.png"
# bbox_inches='tight' keeps the outside legend from being cut off
fig.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.close(fig)

print("Plot saved at:", plot_path)
print("Mean plots saved in:", mean_plot_dir)


Plot saved at: /home/mkarunar/FL_results_alpha_5avged/mean_plots_with all clients/mean_all_layers.png
Mean plots saved in: /home/mkarunar/FL_results_alpha_5avged/mean_plots_with all clients
