In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import random
import copy
import matplotlib.pyplot as plt

# -------------------
# 1. Simple CNN Model
# -------------------
class SimpleCNN(nn.Module):
    """Small convolutional classifier: two conv+pool stages, then two linear layers.

    The first linear layer expects 20 * 4 * 4 features, which is what a
    single-channel 28x28 input produces after the two conv/pool stages.
    The network outputs 10 raw logits per sample (no softmax applied).
    """

    def __init__(self):
        super().__init__()
        # Spatial sizes for a 28x28 input:
        #   conv1 (5x5): 28 -> 24, pool: 24 -> 12
        #   conv2 (5x5): 12 -> 8,  pool: 8  -> 4
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(20 * 4 * 4, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage2 = F.relu(F.max_pool2d(self.conv2(stage1), 2))
        # Collapse channel and spatial dims into one feature vector per sample.
        flat = stage2.view(-1, 20 * 4 * 4)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# -----------------------------------
# 2. Data Partitioning among K clients
# -----------------------------------

def iid_partition(dataset, num_clients):
    """
    Split the dataset into IID partitions for each client.

    Every client receives ``len(dataset) // num_clients`` sample indices
    drawn uniformly at random without replacement; the partitions are
    pairwise disjoint. When the dataset size is not divisible by
    ``num_clients`` the leftover samples are not assigned to anyone.

    Args:
        dataset: Any object supporting ``len()``.
        num_clients: Number of partitions to create.

    Returns:
        dict mapping client id (``0 .. num_clients - 1``) to a set of
        dataset indices.
    """
    num_items = len(dataset) // num_clients
    # One shuffle of all indices, then consecutive slices: disjoint by
    # construction, and avoids rebuilding the pool of remaining indices
    # for every client.
    shuffled = np.random.permutation(len(dataset))
    return {
        i: set(shuffled[i * num_items:(i + 1) * num_items])
        for i in range(num_clients)
    }

def noniid_partition(dataset, num_clients, num_shards=200, shards_per_client=2):
    """
    Non-IID partitioning based on shard allocation.

    Sample indices are sorted by label, cut into ``num_shards`` contiguous
    shards of ``len(dataset) // num_shards`` samples each, and every client
    receives ``shards_per_client`` randomly chosen shards. Each shard is
    given to at most one client.

    Samples beyond ``num_shards * shard_size`` and shards that are not
    handed out (when ``num_clients * shards_per_client < num_shards``) are
    left unassigned.

    Args:
        dataset: Object supporting ``len()`` and exposing a ``targets``
            attribute convertible with ``np.array``.
        num_clients: Number of clients to create partitions for.
        num_shards: Number of label-sorted shards to cut the data into.
        shards_per_client: Number of shards given to each client.

    Returns:
        dict mapping client id (``0 .. num_clients - 1``) to a set of
        dataset indices.

    Raises:
        ValueError: If there are fewer shards than
            ``num_clients * shards_per_client``; without this check the
            surplus clients would silently end up with no data.
    """
    shards_needed = num_clients * shards_per_client
    if shards_needed > num_shards:
        raise ValueError(
            f"num_shards={num_shards} is too small: {num_clients} clients x "
            f"{shards_per_client} shards each need {shards_needed} shards"
        )

    labels = np.array(dataset.targets)
    # Stable sort keeps same-label samples in their original relative order,
    # so the shard contents do not depend on the sort implementation.
    sorted_idxs = labels.argsort(kind="stable")

    shard_size = len(dataset) // num_shards
    shard_order = np.random.permutation(num_shards)

    client_dict = {}
    for i in range(num_clients):
        assigned_shards = shard_order[i * shards_per_client:(i + 1) * shards_per_client]
        members = set()
        for shard_id in assigned_shards:
            members.update(sorted_idxs[shard_id * shard_size:(shard_id + 1) * shard_size])
        client_dict[i] = members
    return client_dict

# -------------------------------
# 3. Client update (local training)
# -------------------------------
def client_update(client_model, optimizer, train_loader, epochs, device):
    """
    Train ``client_model`` in place on one client's data.

    Args:
        client_model: Model to train; switched to training mode.
        optimizer: Optimizer stepped once per batch.
        train_loader: Iterable yielding ``(data, target)`` batches.
        epochs: Number of full passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The model's ``state_dict()`` after training.
    """
    loss_fn = nn.CrossEntropyLoss()
    client_model.train()
    for _ in range(epochs):
        for batch, labels in train_loader:
            batch = batch.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            batch_loss = loss_fn(client_model(batch), labels)
            batch_loss.backward()
            optimizer.step()
    return client_model.state_dict()

# -------------------------------
# 4. Server aggregation (FedAvg)
# -------------------------------
def fed_avg(weights, client_sizes):
    """
    Aggregate client models weighted by the number of samples per client.

    Each entry of the result is ``sum_k (n_k / n_total) * w_k[key]``.

    Args:
        weights: Non-empty list of state dicts (key -> tensor) sharing the
            same keys; the key set of the first entry is used.
        client_sizes: Sample count per client, aligned with ``weights``.

    Returns:
        A new mapping of the same type as ``weights[0]`` holding the
        averaged tensors. The inputs are not modified.

    Raises:
        ValueError: If ``weights`` is empty, the two lists differ in
            length, or the total sample count is not positive.
    """
    if not weights:
        raise ValueError("fed_avg needs at least one client state dict")
    if len(weights) != len(client_sizes):
        raise ValueError(
            f"got {len(weights)} state dicts but {len(client_sizes)} client sizes"
        )
    total_samples = sum(client_sizes)
    if total_samples <= 0:
        raise ValueError("total number of client samples must be positive")

    # Shallow copy keeps the container type (and any attributes attached to
    # it) without cloning every tensor; all values are replaced below.
    avg_weights = copy.copy(weights[0])
    for key, reference in weights[0].items():
        needs_cast = not (reference.is_floating_point() or reference.is_complex())
        # Integer/bool entries cannot accumulate fractional contributions in
        # place, so average them in float64 and cast back afterwards.
        acc_dtype = torch.float64 if needs_cast else reference.dtype
        accumulator = torch.zeros_like(reference, dtype=acc_dtype)
        for size, client_weight in zip(client_sizes, weights):
            accumulator += (size / total_samples) * client_weight[key].to(acc_dtype)
        if needs_cast:
            accumulator = accumulator.round().to(reference.dtype)
        avg_weights[key] = accumulator
    return avg_weights

# -------------------------------
# 5. Testing the model accuracy
# -------------------------------
def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()
    test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss = criterion(outputs, target)
            test_loss += loss.item() * data.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return test_loss / total, correct / total

# --------------------------------------
# 6. Main Federated Training Loop
# --------------------------------------
def federated_training(
    num_clients=10, local_epochs=5, batch_size=32, lr=0.01,
    rounds=30, fraction=1.0, iid=True, device='cpu'
):
    """
    Run FedAvg on MNIST and evaluate the global model after every round.

    Args:
        num_clients: Number of simulated clients the training set is split across.
        local_epochs: Epochs each selected client trains per round.
        batch_size: Batch size of the per-client training loaders.
        lr: Learning rate of each client's SGD optimizer.
        rounds: Number of communication rounds.
        fraction: Fraction of clients sampled each round (at least one
            client is always selected).
        iid: If True use ``iid_partition``; otherwise ``noniid_partition``
            with two shards per client.
        device: Device used for the models and batches.

    Returns:
        Tuple ``(accuracies, losses, global_model)`` where the two lists
        hold one test-set value per round.
    """
    # --- data ---------------------------------------------------------
    to_tensor = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=to_tensor)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    # --- split the training set across clients ------------------------
    if not iid:
        # Use exactly as many shards as will be handed out.
        shards_per_client = 2
        client_dict = noniid_partition(
            train_dataset,
            num_clients,
            num_shards=num_clients * shards_per_client,
            shards_per_client=shards_per_client,
        )
    else:
        client_dict = iid_partition(train_dataset, num_clients)

    client_loaders, client_sizes = [], []
    for cid in range(num_clients):
        sample_ids = list(client_dict[cid])
        client_sizes.append(len(sample_ids))
        client_loaders.append(
            DataLoader(Subset(train_dataset, sample_ids), batch_size=batch_size, shuffle=True)
        )

    # --- global model and metric history ------------------------------
    global_model = SimpleCNN().to(device)
    global_weights = global_model.state_dict()
    acc_history, loss_history = [], []

    # Number of participants is the same every round.
    clients_per_round = max(int(fraction * num_clients), 1)

    for r in range(rounds):
        print(f"Round {r+1}/{rounds}")
        participants = np.random.choice(range(num_clients), clients_per_round, replace=False)

        local_weights, local_sizes = [], []
        for cid in participants:
            # Each client starts the round from the current global weights.
            local_model = SimpleCNN().to(device)
            local_model.load_state_dict(global_weights)
            local_optimizer = optim.SGD(local_model.parameters(), lr=lr)

            trained = client_update(
                local_model, local_optimizer, client_loaders[cid], local_epochs, device
            )
            local_weights.append(trained)
            local_sizes.append(client_sizes[cid])

        # Server step: sample-weighted average of the client models.
        global_weights = fed_avg(local_weights, local_sizes)
        global_model.load_state_dict(global_weights)

        test_loss, test_accuracy = test(global_model, test_loader, device)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy*100:.2f}%")
        acc_history.append(test_accuracy)
        loss_history.append(test_loss)

    return acc_history, loss_history, global_model

# ------------------------------
# 7. Centralized SGD training for comparison
# ------------------------------
def centralized_sgd_train(local_epochs=5, batch_size=32, lr=0.01, rounds=30, device='cpu'):
    """
    Train a single model on the full MNIST training set as a baseline.

    To make the curves comparable with the federated runs, one "round"
    here is ``local_epochs`` passes over the whole training set followed
    by one evaluation on the test set.

    Args:
        local_epochs: Training epochs per round.
        batch_size: Training batch size.
        lr: SGD learning rate.
        rounds: Number of train/evaluate rounds.
        device: Device used for the model and batches.

    Returns:
        Tuple ``(accuracies, losses, model)`` where the two lists hold one
        test-set value per round.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    model = SimpleCNN().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    acc = []
    loss_list = []

    for r in range(rounds):
        # Reuse the shared training and evaluation routines so the baseline
        # is trained and scored exactly like the federated clients.
        client_update(model, optimizer, train_loader, local_epochs, device)
        avg_loss, accuracy = test(model, test_loader, device)

        print(f"Round {r+1}, Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy*100:.2f}%")
        acc.append(accuracy)
        loss_list.append(avg_loss)

    return acc, loss_list, model


# ------------------------------
# 8. Run Experiments and Plot Results
# ------------------------------

if __name__ == '__main__':
    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Running Federated Learning experiments on device:", device)

    # Hyperparameters shared by all experiments.
    num_clients, local_epochs, batch_size = 10, 5, 32
    lr, rounds = 0.01, 30
    fraction = 1.0  # every client participates in every round

    fed_kwargs = dict(
        num_clients=num_clients, local_epochs=local_epochs, batch_size=batch_size,
        lr=lr, rounds=rounds, fraction=fraction, device=device,
    )

    # Experiment 1: FedAvg with IID client data.
    print("=== IID FedAvg Training ===")
    fedacc_iid, fedloss_iid, fed_model_iid = federated_training(iid=True, **fed_kwargs)

    # Experiment 2: FedAvg with label-sharded (non-IID) client data.
    print("=== Non-IID FedAvg Training ===")
    fedacc_noniid, fedloss_noniid, fed_model_noniid = federated_training(iid=False, **fed_kwargs)

    # Experiment 3: centralized baseline with the same epochs per round.
    print("=== Centralized SGD Training ===")
    sgd_acc, sgd_loss, sgd_model = centralized_sgd_train(
        local_epochs=local_epochs, batch_size=batch_size, lr=lr, rounds=rounds, device=device
    )

    # One figure per metric, each comparing the three runs.
    plot_specs = (
        ("Test Accuracy over Rounds", "Accuracy", (fedacc_iid, fedacc_noniid, sgd_acc)),
        ("Test Loss over Rounds", "Loss", (fedloss_iid, fedloss_noniid, sgd_loss)),
    )
    curve_labels = ("FedAvg IID", "FedAvg Non-IID", "Centralized SGD")
    for plot_title, y_label, curves in plot_specs:
        plt.figure(figsize=(10,5))
        for curve, curve_label in zip(curves, curve_labels):
            plt.plot(curve, label=curve_label)
        plt.title(plot_title)
        plt.xlabel("Rounds")
        plt.ylabel(y_label)
        plt.legend()
        plt.show()

ImportError: DLL load failed while importing _C: The specified module could not be found.

In [3]:
pip install torch torchvision

Collecting torch
  Downloading torch-2.8.0-cp311-cp311-win_amd64.whl.metadata (30 kB)
Collecting torchvision
  Downloading torchvision-0.23.0-cp311-cp311-win_amd64.whl.metadata (6.1 kB)
Downloading torch-2.8.0-cp311-cp311-win_amd64.whl (241.4 MB)
   ---------------------------------------- 0.0/241.4 MB ? eta -:--:--
   ---------------------------------------- 0.3/241.4 MB ? eta -:--:--
   ---------------------------------------- 1.0/241.4 MB 5.0 MB/s eta 0:00:48
   ---------------------------------------- 1.8/241.4 MB 4.6 MB/s eta 0:00:53
   ---------------------------------------- 2.4/241.4 MB 4.1 MB/s eta 0:00:59
   ---------------------------------------- 2.6/241.4 MB 3.2 MB/s eta 0:01:15
    --------------------------------------- 4.7/241.4 MB 4.4 MB/s eta 0:00:54
   - -------------------------------------- 6.6/241.4 MB 5.2 MB/s eta 0:00:46
   - -------------------------------------- 8.4/241.4 MB 5.7 MB/s eta 0:00:42
   - -------------------------------------- 10.0/241.4 MB 5.9 MB/

ERROR: Could not install packages due to an OSError: [WinError 5] Access is denied: 'C:\\Users\\Luv Mathur\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\lib\\asmjit.dll'
Consider using the `--user` option or check the permissions.



In [3]:
pip install --upgrade pip


Collecting pip
  Obtaining dependency information for pip from https://files.pythonhosted.org/packages/b7/3f/945ef7ab14dc4f9d7f40288d2df998d1837ee0888ec3659c813487572faa/pip-25.2-py3-none-any.whl.metadata
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
   ---------------------------------------- 0.0/1.8 MB ? eta -:--:--
    --------------------------------------- 0.0/1.8 MB 1.3 MB/s eta 0:00:02
   - -------------------------------------- 0.1/1.8 MB 812.7 kB/s eta 0:00:03
   -- ------------------------------------- 0.1/1.8 MB 1.1 MB/s eta 0:00:02
   ----- ---------------------------------- 0.2/1.8 MB 1.4 MB/s eta 0:00:02
   -------- ------------------------------- 0.4/1.8 MB 1.8 MB/s eta 0:00:01
   -------------- ------------------------- 0.6/1.8 MB 2.4 MB/s eta 0:00:01
   ---------------------- ----------------- 1.0/1.8 MB 3.3 MB/s eta 0:00:01
   ---------------------------- ----------- 1.2/1.8 MB 3.6 MB/s eta 0:00:01
   --------