# Federated Learning on CIFAR-10: Data Exploration and Results Analysis

This notebook provides a comprehensive overview of the data and results for federated learning experiments (FedAvg and FedProx) on CIFAR-10.

In [None]:
# Import Required Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torchvision import datasets, transforms
from collections import Counter
display_available = True
try:
    from IPython.display import display
except ImportError:
    display_available = False

## 1. Dataset Overview and Visualization

Explore the CIFAR-10 dataset: class distribution and sample images.

In [None]:
# Load the CIFAR-10 training split. The normalization only affects tensors
# returned by trainset[i]; the figures below read the raw uint8 images in
# trainset.data, so they are not affected by it.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)
trainset = datasets.CIFAR10(
    root="../data", train=True, download=True, transform=transform
)
classes = trainset.classes
labels = np.array(trainset.targets)

# Histogram of training labels, one bar per class.
plt.figure(figsize=(8, 4))
sns.countplot(x=labels)
plt.title("CIFAR-10 Class Distribution (Train)")
plt.xticks(ticks=range(10), labels=classes, rotation=45)
plt.show()

# One example per class: the first training image carrying each label,
# laid out row-major on a 2 x 5 grid.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for label_id, (class_name, ax) in enumerate(zip(classes, axes.flat)):
    first_idx = np.flatnonzero(labels == label_id)[0]
    ax.imshow(trainset.data[first_idx])
    ax.set_title(class_name)
    ax.axis('off')
plt.suptitle("Sample Images per Class")
plt.show()

## 2. Data Partitioning: IID vs Non-IID

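Visualize how the training data is split among clients and show the class distribution of each client. Only the non-IID setting is plotted here (each client draws from 2 randomly chosen classes); an IID split would instead give every client a roughly uniform class distribution.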

In [None]:
# Example: non-IID partitioning -- each client receives one shard from each of
# `classes_per_client` randomly chosen classes.
# NOTE(review): illustrative partition only; confirm it matches the scheme used
# by the actual FedAvg/FedProx training runs before drawing conclusions.
num_clients = 10
classes_per_client = 2
np.random.seed(42)
labels = np.array(trainset.targets)
num_classes = len(classes)
class_indices = [np.flatnonzero(labels == c) for c in range(num_classes)]
# Fix each class's shard size from its ORIGINAL size. Deriving it from the
# pool of still-unassigned indices (which shrinks after every draw) gave later
# clients that pick an already-used class ever smaller shards. With at most
# `num_clients` draws per class, size // num_clients never exhausts the pool.
shard_sizes = [len(idxs) // num_clients for idxs in class_indices]
client_indices = [[] for _ in range(num_clients)]
for client in range(num_clients):
    chosen_classes = np.random.choice(num_classes, classes_per_client, replace=False)
    for cls in chosen_classes:
        shard = np.random.choice(class_indices[cls], shard_sizes[cls], replace=False)
        client_indices[client].extend(shard.tolist())
        # Remove drawn samples so no index is ever given to two clients.
        class_indices[cls] = np.setdiff1d(class_indices[cls], shard)

# Plot class distribution per client (5 panels per row).
ncols = 5
nrows = -(-num_clients // ncols)  # ceiling division
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 3 * nrows), sharey=True, squeeze=False)
for client, ax in enumerate(axes.flat):
    if client >= num_clients:
        ax.axis('off')  # unused trailing panel
        continue
    counts = np.bincount(labels[client_indices[client]], minlength=num_classes)
    ax.bar(range(num_classes), counts)
    ax.set_xticks(range(num_classes))
    ax.set_xticklabels(classes, rotation=45, fontsize=8)
    ax.set_title(f"Client {client}")
plt.suptitle("Class Distribution per Client (Non-IID)")
plt.tight_layout()
plt.show()

## 3. Data Integrity Checks and Summary

Check that no training sample is assigned to more than one client (overlap between clients would mean data leakage), and print how many samples each client holds.

In [None]:
# Check for data leakage: the same training index must not appear on two clients.
# Cast each client's list explicitly so an empty client stays integer-typed.
all_indices = np.concatenate([np.asarray(idxs, dtype=np.int64) for idxs in client_indices])
unique_indices = np.unique(all_indices)
print(f"Total samples assigned: {len(all_indices)}")
print(f"Unique samples assigned: {len(unique_indices)}")
if len(all_indices) == len(unique_indices):
    print("No data leakage detected (no overlap between clients).")
else:
    print("Warning: Data leakage detected!")
# How much of the training set the partition actually uses.
print(f"Coverage: {len(unique_indices)} / {len(labels)} training samples assigned")

# Samples per client. Iterate over client_indices itself rather than the global
# `num_clients`, which a later cell may have reassigned.
for client, idxs in enumerate(client_indices):
    print(f"Client {client}: {len(idxs)} samples")

## 4. Training and Test Curves: FedAvg vs FedProx

Plot accuracy and loss vs. communication rounds for both algorithms.

In [None]:
# Load experiment logs (assume CSVs saved by each experiment).
# NOTE(review): assumes one row per evaluation with "accuracy" and "loss"
# columns, optionally a "round" column -- confirm against the training scripts.
fedavg_log = pd.read_csv("../results/fedavg_metrics.csv")
fedprox_log = pd.read_csv("../results/fedprox_metrics.csv")


def plot_metric_curves(logs, metric, ylabel, title):
    """Plot one metric column against communication round for several runs.

    Args:
        logs: mapping from legend label to that run's metrics DataFrame.
        metric: name of the column to plot.
        ylabel: y-axis label.
        title: figure title.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    for name, log in logs.items():
        # Prefer logged round numbers; otherwise fall back to the row index.
        x = log["round"] if "round" in log.columns else log.index
        ax.plot(x, log[metric], label=name)
    ax.set(title=title, xlabel="Round", ylabel=ylabel)
    ax.legend()
    ax.grid(True)
    plt.show()


runs = {"FedAvg": fedavg_log, "FedProx": fedprox_log}
plot_metric_curves(runs, "accuracy", "Test Accuracy", "Test Accuracy vs Communication Rounds")
plot_metric_curves(runs, "loss", "Test Loss", "Test Loss vs Communication Rounds")

## 5. Per-Client Performance

Visualize each client's accuracy across rounds (when the logs record it) to show the effect of data heterogeneity.

In [None]:
# Plot per-client accuracy (if available in each log).
import ast


def plot_client_accuracy(log, algo_name):
    """Plot every client's accuracy across rounds for one algorithm.

    Args:
        log: metrics DataFrame with a "client_accuracies" column. Each non-null
            cell is assumed to be a string literal of a list holding one
            accuracy per client (e.g. "[0.41, 0.38]") -- NOTE(review): format
            inferred from the previous eval-based parsing; confirm with writer.
        algo_name: label used in the title and legend.
    """
    rows = log["client_accuracies"].dropna()
    if rows.empty:
        print(f"{algo_name}: no per-client accuracy rows in logs.")
        return
    # literal_eval accepts only Python literals; eval would execute any code
    # that ended up in the CSV.
    per_round = [ast.literal_eval(cell) for cell in rows]
    x = log.loc[rows.index, "round"] if "round" in log.columns else rows.index
    # Sizes come from THIS log, so runs with different round/client counts plot
    # correctly; a local name avoids clobbering the notebook's `num_clients`.
    n_clients = len(per_round[0])
    plt.figure(figsize=(10, 6))
    for client in range(n_clients):
        plt.plot(x, [accs[client] for accs in per_round],
                 label=f"{algo_name} Client {client}", alpha=0.5)
    plt.title(f"{algo_name}: Per-Client Accuracy vs Rounds")
    plt.xlabel("Round")
    plt.ylabel("Accuracy")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    plt.show()


for algo_name, log in (("FedAvg", fedavg_log), ("FedProx", fedprox_log)):
    if "client_accuracies" in log.columns:
        plot_client_accuracy(log, algo_name)
    else:
        print(f"{algo_name}: per-client accuracy not available in logs.")

## 6. Confusion Matrix and Misclassified Images

Visualize the confusion matrix and show examples of misclassified images for the final FedAvg global model (only when the logs contain the required per-sample data).

In [None]:
# Example: Compute and plot confusion matrix for final FedAvg model
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(cm, classes, title='Confusion matrix'):
    """Draw a confusion matrix as an annotated heatmap.

    Args:
        cm: square integer count matrix; rows are true labels and columns are
            predicted labels (sklearn's `confusion_matrix` convention).
        classes: tick labels, one per row/column of `cm`.
        title: figure title.
    """
    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.title(title)
    plt.show()


# Assume y_true and y_pred are available from test set evaluation.
# NOTE(review): the metrics CSV appears to be per-round; per-sample integer
# labels may need to come from a separate file -- confirm.
y_true = fedavg_log.get('y_true', None)
y_pred = fedavg_log.get('y_pred', None)
if y_true is not None and y_pred is not None:
    # Pin the label set so the matrix is always len(classes) x len(classes);
    # without it a class absent from both columns shrinks the matrix and the
    # heatmap tick labels no longer line up.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    plot_confusion_matrix(cm, classes)
else:
    print("Confusion matrix data not available in logs.")

# Show misclassified images (if available).
# Assume misclassified = sequence of (img, true_label, pred_label).
# NOTE(review): values read back from a CSV are strings and would need parsing
# before they can be unpacked into an image and two labels.
misclassified = fedavg_log.get('misclassified', None)
if misclassified is not None and len(misclassified) > 0:
    n = min(10, len(misclassified))
    # squeeze=False keeps `axes` 2-D even for n == 1 (otherwise a bare Axes).
    fig, axes = plt.subplots(1, n, figsize=(15,3), squeeze=False)
    for ax, (img, true_label, pred_label) in zip(axes[0], list(misclassified)[:n]):
        ax.imshow(img)
        ax.set_title(f"T:{classes[true_label]}\nP:{classes[pred_label]}")
        ax.axis('off')
    plt.suptitle("Examples of Misclassified Images (FedAvg)")
    plt.show()
else:
    print("Misclassified images not available in logs.")

## 7. Experiment Configuration and Reproducibility

Display the configuration loaded from each run's YAML file (experiment parameters, including random seeds if they are recorded there).

In [None]:
# Show the YAML configuration each experiment was launched with.
import yaml


def _load_yaml(path):
    """Parse one YAML file with the safe loader and return the result."""
    with open(path) as handle:
        return yaml.safe_load(handle)


# Load both files before printing anything, so a missing file aborts cleanly.
fedavg_cfg = _load_yaml("../config/fedavg_config.yaml")
fedprox_cfg = _load_yaml("../config/fedprox_config.yaml")
for header, cfg in (("FedAvg Config:", fedavg_cfg), ("\nFedProx Config:", fedprox_cfg)):
    print(header)
    print(cfg)

## 8. Summary Table of Results

Tabulate key results for FedAvg and FedProx: final accuracy, best accuracy, and the first round at which test accuracy reaches 70%.

In [None]:
# Create summary table of results.
TARGET_ACCURACY = 0.7  # as a fraction; rescaled below if logs store percentages


def summarize_run(name, log, target=TARGET_ACCURACY):
    """Build one summary-table row (a dict) from a run's metrics log.

    Args:
        name: algorithm label for the "Algorithm" column.
        log: metrics DataFrame with an "accuracy" column (optionally "round").
        target: accuracy threshold as a fraction (0.7 == 70%).

    Returns:
        dict with final/best accuracy and the first round reaching `target`:
        taken from the "round" column when present, else the 0-based row
        position (as the curve plots use), or None if never reached.
    """
    acc = log["accuracy"]
    # Accuracies above 1 can only be percentages; compare on the same scale.
    threshold = target * 100 if acc.max() > 1 else target
    hit_positions = np.flatnonzero((acc >= threshold).to_numpy())
    if hit_positions.size == 0:
        rounds_to_target = None
    elif "round" in log.columns:
        rounds_to_target = log["round"].iloc[hit_positions[0]]
    else:
        rounds_to_target = int(hit_positions[0])
    return {
        "Algorithm": name,
        "Final Accuracy": acc.iloc[-1],
        "Best Accuracy": acc.max(),
        f"Rounds to {target:.0%} Acc": rounds_to_target,
    }


df_summary = pd.DataFrame([
    summarize_run("FedAvg", fedavg_log),
    summarize_run("FedProx", fedprox_log),
])
if display_available:
    display(df_summary)
else:
    print(df_summary)