## Import Libraries

In [None]:
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from collections import defaultdict
from copy import deepcopy

## Parameters

In [None]:
# Experiment hyperparameters, consumed by the __main__ cell at the bottom.
num_clients = 100        # number of simulated clients the training set is split across
client_fraction = 0.01   # client fraction C handed to federated_training
num_rounds = 20          # number of federated rounds handed to federated_training
local_epochs = 1         # local epochs per selected client, handed to federated_training
learning_rate = 0.001    # passed to federated_training as `lr`

## Load MNIST and Split the Training Set Across Clients

In [None]:
def load_dataset_and_loaders(num_users=None, batch_size=32, seed=10, root='./data'):
    """Download MNIST and split its training set evenly across simulated clients.

    The training indices are shuffled once and cut into ``num_users`` equal,
    contiguous slices, so each client gets a uniformly random sample of the
    training set. If ``len(train_data)`` is not divisible by ``num_users``,
    the leftover samples are not assigned to any client.

    Args:
        num_users: Number of clients to create. Defaults to the module-level
            ``num_clients`` setting.
        batch_size: Mini-batch size of every client's training DataLoader.
        seed: Value passed to ``np.random.seed`` before shuffling. This reseeds
            NumPy's *global* RNG, so it also affects later ``np.random`` calls.
        root: Directory MNIST is downloaded to / read from.

    Returns:
        ``(train_loaders, test_loader, train_data, user_data)``. These are:

        - a list with one shuffling DataLoader per client;
        - a DataLoader over the MNIST test set (batch size 1000, not shuffled);
        - the full MNIST training dataset;
        - a dict mapping client id to a NumPy array of training-set indices.

    Raises:
        ValueError: If ``num_users`` is < 1, or larger than the training set
            (which would leave every client without data).
    """
    if num_users is None:
        num_users = num_clients
    # Validate before the (potentially slow) download.
    if num_users < 1:
        raise ValueError(f"num_users must be >= 1, got {num_users}")

    # ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    train_data = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root=root, train=False, download=True, transform=transform)

    samples_per_user = len(train_data) // num_users
    if samples_per_user == 0:
        raise ValueError(
            f"num_users={num_users} exceeds the {len(train_data)} training samples"
        )

    # Random split; reseeding the global RNG keeps the split reproducible.
    indices = np.arange(len(train_data))
    np.random.seed(seed)
    np.random.shuffle(indices)

    user_data = {
        user: indices[user * samples_per_user:(user + 1) * samples_per_user]
        for user in range(num_users)
    }

    # One shuffling DataLoader over each client's private subset.
    train_loaders = [
        DataLoader(Subset(train_data, user_data[user]), batch_size=batch_size, shuffle=True)
        for user in range(num_users)
    ]

    # Shared test DataLoader for evaluating the global model.
    test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)

    return train_loaders, test_loader, train_data, user_data

## Check Per-Client Class Distribution (IID vs. Non-IID)

In [None]:
# Plot class distributions for each user
def plot_class_distribution(user_class_counts):
    num_users = len(user_class_counts)
    fig, axs = plt.subplots(num_users // 4, 4, figsize=(20, 10))  # Adjust subplot layout for readability

    for i, (user, counts) in enumerate(user_class_counts.items()):
        ax = axs[i // 4, i % 4]
        ax.bar(range(10), counts)
        ax.set_title(f'User {user}')
        ax.set_xlabel('Class')
        ax.set_ylabel('Number of Samples')
        ax.set_xticks(range(10))

    plt.tight_layout()
    plt.show()

In [None]:
# Count each user's samples per class, print the counts, then plot them.
def calculate_and_plot_class_distribution(data, user_data):
    """Tally, print and plot how many samples of each class every user holds.

    Args:
        data: Dataset exposing a ``targets`` sequence whose elements support
            ``.item()``; labels are assumed to lie in 0-9 (10 bins are used).
        user_data: Mapping of user id -> iterable of sample indices into
            ``data``.

    Users with an empty index list never get a histogram. A user's histogram
    is only created when that user's first sample is counted, so such users
    are neither printed nor plotted.
    """
    histograms = defaultdict(lambda: np.zeros(10, dtype=int))

    for uid, sample_ids in user_data.items():
        for sample_id in sample_ids:
            cls = data.targets[sample_id].item()
            histograms[uid][cls] += 1

    for uid in histograms:
        print(f"User {uid}: {histograms[uid]}")

    plot_class_distribution(histograms)

## Define Our Deep Neural Network

In [None]:
# Define the CNN Model with 2 convolutional layers and 2 fully-connected layers
class SimpleCNN(nn.Module):
    """Two conv -> ReLU -> 2x2 max-pool stages followed by two linear layers.

    Expects inputs of shape (N, in_channels, 28, 28). Each 3x3 convolution uses
    padding=1 and keeps the spatial size, and each 2x2 max-pool halves it
    (28 -> 14 -> 7). That is what ``fc1``'s input size of 64 * 7 * 7 is built
    for.

    Args:
        in_channels: Channels of the input images (1 for MNIST).
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (N, num_classes)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten per sample, keeping the batch dimension. view(-1, 64*7*7)
        # would silently fold extra spatial positions into the batch for
        # non-28x28 inputs; here fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Instantiate a standalone SimpleCNN at notebook top level.
# NOTE(review): no code visible in this file reads this global `model`
# (train_local and test_model take the model as an argument); confirm it is still needed.
model = SimpleCNN()


## Define Functions for Local Training and Model Evaluation

In [None]:
def train_local(model, train_loader, epochs=1, lr=0.01):
    """Train ``model`` in place with plain SGD and cross-entropy on one client's data.

    Args:
        model: The ``nn.Module`` to train; its parameters are updated in place
            and it is left in training mode.
        train_loader: Iterable of ``(inputs, targets)`` batches for this client.
        epochs: Number of full passes over ``train_loader``.
        lr: SGD learning rate.

    Returns:
        A deep copy of ``model``'s state dict after training. Later in-place
        updates to ``model`` do not alter the returned weights; this matters,
        for example, when the same module is reused for the next client.
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Feed batches to wherever the model lives so GPU models work too.
    device = next(model.parameters()).device

    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    # state_dict() holds references to the live parameter tensors; copy them so
    # the caller keeps this client's weights even if `model` is trained again.
    return deepcopy(model.state_dict())


In [None]:
# Function to test the global model on test data
def test_model(model, test_loader):
    """Return the classification accuracy of ``model`` on ``test_loader``, in percent.

    The model is switched to eval mode and left there. It is run without
    gradient tracking, and each batch is moved to the model's device.

    Args:
        model: Classifier returning per-class scores of shape (N, C).
        test_loader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: If ``test_loader`` yields no samples (accuracy is undefined).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total


# Despite the `test_` prefix this is a helper, not a pytest test; keep pytest
# from collecting it (and failing on missing `model`/`test_loader` fixtures).
test_model.__test__ = False

## Implement FedAvg

In [None]:
# Function to average weights from selected clients (the FedAvg aggregation step)
def average_weights(selected_models, weights=None):
    """Return the (weighted) element-wise average of several models' weights.

    Args:
        selected_models: Non-empty sequence of state dicts (e.g. as returned by
            ``train_local``) or ``nn.Module`` objects. All of them must share
            the same keys and tensor shapes.
        weights: Optional per-model weights, e.g. each client's number of
            training samples. Defaults to equal weights. They are normalized
            to sum to 1.

    Returns:
        A new state dict; the inputs are not modified. Floating-point entries
        are averaged (accumulated in at least float32) and cast back to their
        original dtype. Non-floating entries, such as integer counters in
        buffers, are copied from the first model.

    Raises:
        ValueError: If ``selected_models`` is empty, ``weights`` has the wrong
            length, or the weights do not sum to a positive value.
    """
    state_dicts = [
        m.state_dict() if isinstance(m, nn.Module) else m for m in selected_models
    ]
    if not state_dicts:
        raise ValueError("average_weights needs at least one model")

    weights = [1.0] * len(state_dicts) if weights is None else list(weights)
    if len(weights) != len(state_dicts):
        raise ValueError(
            f"got {len(weights)} weights for {len(state_dicts)} models"
        )
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    coeffs = [w / total for w in weights]

    avg_state_dict = {}
    for key, first in state_dicts[0].items():
        if not torch.is_floating_point(first):
            # Averaging integer buffers is not meaningful; keep the first copy.
            avg_state_dict[key] = first.clone()
            continue
        # Accumulate in >= float32 so half-precision weights don't lose accuracy.
        acc_dtype = torch.promote_types(first.dtype, torch.float32)
        acc = torch.zeros_like(first, dtype=acc_dtype)
        for coeff, sd in zip(coeffs, state_dicts):
            acc += coeff * sd[key].to(acc_dtype)
        avg_state_dict[key] = acc.to(first.dtype)

    return avg_state_dict

In [None]:
# Federated training function with client fraction C and test accuracy measurement
def federated_training(num_rounds, num_clients, client_fraction, local_epochs, train_loaders, test_loader, lr=0.001):
    """Train a fresh ``SimpleCNN`` with Federated Averaging (FedAvg) and return it.

    Each round does the following:

    1. Sample ``max(1, round(client_fraction * num_clients))`` distinct
       clients uniformly at random with Python's ``random`` module.
    2. Train a private copy of the current global model on each sampled
       client's loader for ``local_epochs`` epochs.
    3. Replace the global weights with the equal-weight average of the
       clients' weights. Equal weights match FedAvg's sample-count weighting
       when, as in ``load_dataset_and_loaders``, all clients hold the same
       number of samples.
    4. Print the global model's test accuracy.

    Args:
        num_rounds: Number of communication rounds.
        num_clients: Total number of clients K to sample from.
        client_fraction: Fraction C of clients sampled per round, in [0, 1].
            C = 0 means a single client per round.
        local_epochs: Local training epochs per selected client.
        train_loaders: Sequence of per-client training loaders, indexed by
            client id.
        test_loader: Loader used to measure test accuracy after every round.
        lr: Learning rate for local SGD.

    Returns:
        The trained global ``SimpleCNN``.

    Raises:
        ValueError: If ``client_fraction`` is outside [0, 1] or no clients are
            available.
    """
    if not 0 <= client_fraction <= 1:
        raise ValueError(f"client_fraction must be in [0, 1], got {client_fraction}")
    available = min(num_clients, len(train_loaders))
    if available < 1:
        raise ValueError("no clients available: num_clients and train_loaders must be non-empty")
    # FedAvg samples m = max(C * K, 1) clients per round.
    clients_per_round = min(available, max(1, round(client_fraction * num_clients)))

    global_model = SimpleCNN()
    for round_idx in range(1, num_rounds + 1):
        selected = random.sample(range(available), clients_per_round)

        local_states = []
        for client in selected:
            # Each client starts from the current global weights on its own copy.
            local_model = deepcopy(global_model)
            local_states.append(
                train_local(local_model, train_loaders[client], epochs=local_epochs, lr=lr)
            )

        global_model.load_state_dict(average_weights(local_states))
        accuracy = test_model(global_model, test_loader)
        print(f"Round {round_idx}/{num_rounds}: clients {sorted(selected)}, "
              f"test accuracy {accuracy:.2f}%")

    return global_model

In [None]:
if __name__ == "__main__":
    # Seed Python's and PyTorch's global RNGs so model initialisation, DataLoader
    # shuffling and any `random`-based client sampling are repeatable. The
    # client data split is seeded separately via NumPy in load_dataset_and_loaders.
    random.seed(0)
    torch.manual_seed(0)
    train_loaders, test_loader, train_data, user_data = load_dataset_and_loaders()
    calculate_and_plot_class_distribution(train_data, user_data)
    global_model = federated_training(num_rounds, num_clients, client_fraction, local_epochs, train_loaders, test_loader, lr=learning_rate)
    