In [24]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn as nn
import torch.optim as optim

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw






In [5]:
test_dataset = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform)

In [4]:
def split_dataset(dataset, num_clients=10, iid=True):
    client_datasets = []
    
    if iid:
        # Randomly split dataset into num_clients parts of equal size
        client_datasets = random_split(dataset, [len(dataset) // num_clients for _ in range(num_clients)])
    else:
        # Non-IID case: each client will get a certain subset of the classes
        classes_per_client = 2  # Example: each client gets 2 classes of data
        labels = dataset.targets
        class_indices = {i: torch.where(labels == i)[0] for i in range(10)}
        
        # Assign classes to each client
        for i in range(num_clients):
            client_indices = []
            for j in range(classes_per_client):
                class_id = (i * classes_per_client + j) % 10
                client_indices += class_indices[class_id].tolist()
            
            client_subset = torch.utils.data.Subset(dataset, client_indices)
            client_datasets.append(client_subset)

    return client_datasets

In [8]:
iid_clients = split_dataset(train_dataset, num_clients=10, iid=True)
non_iid_clients = split_dataset(train_dataset, num_clients=10, iid=False)

In [9]:
#improve the code for better readability
from collections import Counter

# Function to count class distribution
def class_distribution(dataset):
    labels = dataset.dataset.targets[dataset.indices]
    return Counter(labels.numpy())

# Example: check distribution for IID and Non-IID clients
for i, client_data in enumerate(iid_clients):
    print(f"Client {i} IID distribution: {class_distribution(client_data)}")

for i, client_data in enumerate(non_iid_clients):
    print(f"Client {i} Non-IID distribution: {class_distribution(client_data)}")


Client 0 IID distribution: Counter({np.int64(1): 698, np.int64(7): 644, np.int64(6): 611, np.int64(8): 600, np.int64(3): 598, np.int64(9): 590, np.int64(0): 584, np.int64(2): 567, np.int64(5): 567, np.int64(4): 541})
Client 1 IID distribution: Counter({np.int64(1): 670, np.int64(2): 629, np.int64(7): 625, np.int64(9): 622, np.int64(3): 620, np.int64(8): 620, np.int64(0): 573, np.int64(6): 564, np.int64(4): 551, np.int64(5): 526})
Client 2 IID distribution: Counter({np.int64(1): 694, np.int64(9): 612, np.int64(2): 607, np.int64(6): 606, np.int64(3): 605, np.int64(4): 603, np.int64(7): 601, np.int64(0): 595, np.int64(8): 544, np.int64(5): 533})
Client 3 IID distribution: Counter({np.int64(7): 653, np.int64(1): 647, np.int64(8): 633, np.int64(4): 608, np.int64(3): 608, np.int64(0): 592, np.int64(9): 583, np.int64(6): 574, np.int64(5): 552, np.int64(2): 550})
Client 4 IID distribution: Counter({np.int64(1): 698, np.int64(7): 642, np.int64(2): 606, np.int64(9): 598, np.int64(0): 597, np.int

In [25]:
# Simple neural network for MNIST classification
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

##### Local Training

In [26]:
def local_training(client_data, model, epochs=5, lr=0.01):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    data_loader = DataLoader(client_data, batch_size=32, shuffle=True)
    
    model.train()
    for epoch in range(epochs):
        for inputs, labels in data_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    return model.state_dict()

##### Naive FedAvg

In [27]:
def federated_averaging(global_model, client_models):
    global_state_dict = global_model.state_dict()
    
    # Initialize weights as 0 for aggregation
    for key in global_state_dict.keys():
        global_state_dict[key] = torch.zeros_like(global_state_dict[key])
    
    # Sum the models' parameters
    for client_model in client_models:
        client_state_dict = client_model.state_dict()
        for key in global_state_dict.keys():
            global_state_dict[key] += client_state_dict[key]
    
    # Average the parameters
    for key in global_state_dict.keys():
        global_state_dict[key] = global_state_dict[key] / len(client_models)
    
    global_model.load_state_dict(global_state_dict)
    return global_model


#### IID Dataset

In [28]:
global_model = SimpleNN()

client_models = [SimpleNN() for _ in range(10)]
for i, client_data in enumerate(iid_clients):
    client_models[i].load_state_dict(local_training(client_data, client_models[i]))

global_model = federated_averaging(global_model, client_models)

In [30]:
test_dataset = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [31]:
def evaluate_model(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()
    test_loss = 0.0
    
    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    avg_loss = test_loss / len(data_loader)
    
    return accuracy, avg_loss

In [32]:
iid_accuracy, iid_loss = evaluate_model(global_model, test_loader)
print(f"Global Model - IID Data | Accuracy: {iid_accuracy:.2f}%, Loss: {iid_loss:.4f}")

Global Model - IID Data | Accuracy: 45.21%, Loss: 1.9212


##### Non-IID data

In [33]:
# Example: Train the global model using Non-IID clients
client_models_non_iid = [SimpleNN() for _ in range(10)]
for i, client_data in enumerate(non_iid_clients):
    client_models_non_iid[i].load_state_dict(local_training(client_data, client_models_non_iid[i]))

global_model_non_iid = SimpleNN()
global_model_non_iid = federated_averaging(global_model_non_iid, client_models_non_iid)

# Evaluate the global model trained on Non-IID data
non_iid_accuracy, non_iid_loss = evaluate_model(global_model_non_iid, test_loader)
print(f"Global Model - Non-IID Data | Accuracy: {non_iid_accuracy:.2f}%, Loss: {non_iid_loss:.4f}")


Global Model - Non-IID Data | Accuracy: 16.18%, Loss: 2.2326
