# In [None]:
import copy
import os
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import random_split, DataLoader
import torch.nn.functional as F
import torch.nn as nn
from sklearn.metrics import classification_report, confusion_matrix

# Tolerate duplicate OpenMP runtime libraries being loaded into one process
# (KMP_* variables are read by the Intel/LLVM OpenMP runtime); without this some
# setups abort at import time with a "duplicate library" error.
os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE"

# ---- Hyperparameters -------------------------------------------------------
n_clients = 10  # number of simulated clients (training-data partitions)
epochs = 10  # number of global federated rounds
local_epochs = 1  # passes over its own data each client makes per round
lr = 3e-3  # client learning rate (Adam)
cudaIdx = "cuda:0"  # GPU card index used when CUDA is available
device = torch.device(cudaIdx if torch.cuda.is_available() else "cpu")
# Worker processes per client DataLoader. The previous hard-coded 30 exceeds
# the core count of most machines (PyTorch warns and the spawn overhead
# dominates), so never ask for more workers than there are CPUs.
num_workers = min(30, os.cpu_count() or 1)


class EqualUserSampler(object):
    """Round-robin selector handing out ``n`` user indices per training round.

    Users are visited in ascending order ``0 .. num_users - 1``. When the end of
    the order is reached, the order is rebuilt and the cursor wraps back to the
    front, so over a full cycle every user is picked the same number of times.

    Args:
        n: how many user indices each call to ``get_useridx`` returns.
        num_users: total number of users to cycle through.
    """

    def __init__(self, n, num_users) -> None:
        self.selected = n  # users returned per get_useridx() call
        self.num_users = num_users
        self.i = 0  # cursor: position in self.users of the next user to hand out
        self.get_order()

    def get_order(self):
        """(Re)build the visiting order; here simply the identity order."""
        self.users = np.arange(self.num_users)

    def _advance(self):
        """Return the user under the cursor and step the cursor, wrapping at the end."""
        current = self.users[self.i]
        self.i += 1
        if self.i >= self.num_users:
            # Exhausted the order: rebuild it and restart from the first user.
            self.get_order()
            self.i = 0
        return current

    def get_useridx(self):
        """Return the next ``self.selected`` user indices as a list."""
        return [self._advance() for _ in range(self.selected)]


# CNN model definition
class ConvNet(nn.Module):
    """Two conv blocks (conv -> batch norm -> ReLU -> 2x2 max-pool) plus a 3-layer MLP.

    The flattened feature size ``64 * 4 * 4`` assumes 28x28 inputs such as MNIST:
    28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4. Inputs of any other spatial
    size are rejected by ``fc1`` with a shape error.

    Args:
        in_channels: number of input image channels (default 1, as for MNIST).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(ConvNet, self).__init__()
        # Submodules are registered in the same order as before, so state_dict
        # keys (and therefore saved checkpoints) stay compatible.
        self.conv1 = nn.Conv2d(in_channels, 32, 5)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout = nn.Dropout(0.5)  # regularization after fc2
        self.fc1 = nn.Linear(64 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalized class logits of shape ``(batch, num_classes)``."""
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        # Flatten per sample, keeping the batch dimension. The old
        # view(-1, 1024) silently regrouped features across samples when the
        # input size was wrong (e.g. 16 images of 32x32 became a "batch" of 25);
        # now such inputs fail loudly in fc1.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.dropout(self.relu(self.fc2(x)))
        return self.fc3(x)


# Training-time preprocessing for single-channel images:
#   1. random rotation within +/-10 degrees,
#   2. random translation of up to 10% of width/height (no rotation in this step),
#   3. PIL image -> float tensor in [0, 1],
#   4. normalization with mean 0.5 / std 0.5, mapping [0, 1] to [-1, 1].
# NOTE(review): steps 1-2 are random augmentation; apply this to training data only.
transform = transforms.Compose(
    [
        transforms.RandomRotation(10),
        transforms.RandomAffine(0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


# Load data (each client will load its own data in a real FL scenario)
def load_data(transform, datasets='MNIST', test_transform=None):
    """Download (if needed) and return the ``(train_dataset, test_dataset)`` pair.

    Args:
        transform: transform applied to training images.
        datasets: dataset name, case-insensitive. 'MNIST', or 'CIFAR10'
            (also accepted as 'CIFAR-10', 'CIFAR_10' or 'CIFAR').
        test_transform: transform applied to test images. Defaults to
            ``transform`` for backward compatibility. Pass a deterministic,
            augmentation-free transform here so evaluation is not randomized.

    Returns:
        Tuple of the torchvision training and test datasets.

    Raises:
        ValueError: for an unrecognized dataset name. Previously any non-MNIST
            name silently downloaded CIFAR-10.
    """
    if test_transform is None:
        test_transform = transform

    name = datasets.upper().replace('-', '').replace('_', '')
    if name == 'MNIST':
        dataset_cls, root = torchvision.datasets.MNIST, "./data/mnist"
    elif name in ('CIFAR10', 'CIFAR'):
        dataset_cls, root = torchvision.datasets.CIFAR10, "./data/cifar-10-python"
    else:
        raise ValueError(
            f"Unsupported dataset {datasets!r}; expected 'MNIST' or 'CIFAR10'")

    train_dataset = dataset_cls(
        root=root, train=True, download=True, transform=transform)
    test_dataset = dataset_cls(
        root=root, train=False, download=True, transform=test_transform)
    return train_dataset, test_dataset


# Partition the dataset into 'n_clients' partitions
def partition_dataset(dataset, n_clients, generator=None):
    """Randomly split ``dataset`` into ``n_clients`` disjoint, near-equal subsets.

    Partition sizes differ by at most one: the first ``len(dataset) % n_clients``
    partitions get one extra sample, so every sample is assigned. The old
    ``[len // n] * n`` lengths did not sum to ``len(dataset)`` when the length
    was not divisible by ``n_clients``, and ``random_split`` rejected them.

    Args:
        dataset: any sized, indexable dataset.
        n_clients: number of partitions; must be positive.
        generator: optional ``torch.Generator`` for a reproducible split.
            When None, torch's default generator is used, as before.

    Returns:
        List of ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if ``n_clients`` is not positive.
    """
    if n_clients <= 0:
        raise ValueError(f"n_clients must be positive, got {n_clients}")
    base, extra = divmod(len(dataset), n_clients)
    lengths = [base + (1 if i < extra else 0) for i in range(n_clients)]
    if generator is None:
        return random_split(dataset, lengths)
    return random_split(dataset, lengths, generator=generator)


# FedAvgServer class for managing global aggregation
class FedAvgServer:
    """Holds the global model state and aggregates client states with FedAvg.

    Args:
        global_parameters: mapping of parameter/buffer name -> tensor, e.g. a
            model ``state_dict()``.
    """

    def __init__(self, global_parameters):
        self.global_parameters = global_parameters

    def download(self, user_idx):
        """Return one independent deep copy of the global state per selected user."""
        return [copy.deepcopy(self.global_parameters) for _ in user_idx]

    def upload(self, local_parameters, weights=None):
        """Replace the global state with the (weighted) average of the client states.

        Args:
            local_parameters: sequence of state dicts that contain every key of
                the global state.
            weights: optional per-client weights (e.g. local sample counts, as in
                the original FedAvg). When None, all clients count equally.

        Raises:
            ValueError: if no client states are given, if the number of weights
                does not match the number of clients, or if the weights do not
                sum to a positive value.
        """
        n = len(local_parameters)
        if n == 0:
            # Averaging nothing would divide by zero and fill the model with NaN/inf.
            raise ValueError("upload() needs at least one client state to average")
        if weights is None:
            weights = [1.0] * n
        else:
            weights = [float(w) for w in weights]
            if len(weights) != n:
                raise ValueError(
                    f"got {len(weights)} weights for {n} client states")
        total = float(sum(weights))
        if total <= 0:
            raise ValueError("client weights must sum to a positive value")

        for k in list(self.global_parameters.keys()):
            ref = self.global_parameters[k]
            if torch.is_floating_point(ref):
                acc = torch.zeros_like(ref)
                for w, local in zip(weights, local_parameters):
                    acc += w * local[k]
                self.global_parameters[k] = acc / total  # FedAvg
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked): average
                # in float64 and round back, so the dtype is preserved instead
                # of silently turning into a float tensor.
                acc = torch.zeros_like(ref, dtype=torch.float64)
                for w, local in zip(weights, local_parameters):
                    acc += w * local[k].to(torch.float64)
                self.global_parameters[k] = torch.round(acc / total).to(ref.dtype)


# Client class for local training
class Client:
    """A simulated federated-learning participant that trains on its own data.

    Args:
        data_loader: iterable of ``(data, labels)`` batches for this client.
        user_idx: index of this client among all users.
    """

    def __init__(self, data_loader, user_idx):
        self.data_loader = data_loader
        self.user_idx = user_idx

    def train(self, model, learningRate, idx, global_model, num_epochs=None):
        """Train ``model`` in place on this client's data with Adam + cross-entropy.

        Args:
            model: module to train; its parameters are updated in place.
            learningRate: Adam learning rate.
            idx: position of this client in the current round (unused here;
                kept for API compatibility).
            global_model: current global model (unused here; kept for API
                compatibility).
            num_epochs: local passes over the data. Defaults to the module-level
                ``local_epochs``. Previously the global round count ``epochs``
                was used by mistake, so every round ran 10x the intended local
                training.
        """
        if num_epochs is None:
            num_epochs = local_epochs
        # Send batches to wherever the model's parameters live, so the client
        # works regardless of the module-level device setting.
        target_device = next(model.parameters()).device
        optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)
        # A fresh optimizer/scheduler is built every round, so the decay only
        # takes effect when num_epochs >= step_size.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
        for _ in range(num_epochs):
            for data, labels in self.data_loader:
                data, labels = data.to(target_device), labels.to(target_device)
                optimizer.zero_grad()
                output = model(data)
                loss = F.cross_entropy(output, labels)
                loss.backward()
                optimizer.step()
            scheduler.step()  # step the scheduler once per local epoch


def activateClient(train_dataloaders, user_idx, server):
    """Prepare the clients selected for this round.

    Args:
        train_dataloaders: per-user data loaders, indexable by user id.
        user_idx: ids of the users selected this round.
        server: object whose ``download(user_idx)`` returns the starting
            parameters for each selected user.

    Returns:
        ``(clients, local_parameters)``: one ``Client`` per id in ``user_idx``
        (wrapping that user's loader), and the per-user starting parameters
        from ``server.download``, in the same order.
    """
    starting_states = server.download(user_idx)
    selected = []
    for uid in user_idx:
        selected.append(Client(train_dataloaders[uid], uid))
    return selected, starting_states


def train(train_dataloaders, user_idx, server, global_model, learningRate):
    """Run one federated round and load the aggregated result into ``global_model``.

    Each selected client gets a fresh ``ConvNet`` initialized from its
    downloaded copy of the global state and trains it locally. The resulting
    state dicts are uploaded to ``server`` for averaging, and the averaged
    state is loaded back into ``global_model``.
    """
    clients, start_states = activateClient(train_dataloaders, user_idx, server)
    trained_states = []
    for slot, client in enumerate(clients):
        local_model = ConvNet().to(device)
        local_model.load_state_dict(start_states[slot])
        local_model.train()  # enable dropout / batch-norm batch statistics
        client.train(local_model, learningRate, slot, global_model)
        trained_states.append(local_model.state_dict())
    server.upload(trained_states)
    global_model.load_state_dict(server.global_parameters)


# Per-class evaluation helper. Previously this whole function sat inside a bare
# string literal (dead code); it is now a real function, still uncalled by default.
def evaluate_model(model, test_loader, device):
    """Print a per-class classification report and confusion matrix for ``model``.

    Puts ``model`` in eval mode (it is left that way) and runs without gradients.

    Args:
        model: classifier returning one logit per class in dimension 1.
        test_loader: iterable of ``(data, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(y_true, y_pred)`` lists of labels, so callers can compute further
        metrics without re-running inference.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    print("Classification Report:")
    print(classification_report(y_true, y_pred))
    print("Confusion Matrix:")
    print(confusion_matrix(y_true, y_pred))
    return y_true, y_pred

def test(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    Puts ``model`` in eval mode (it is left that way) and runs without gradients.

    Args:
        model: classifier returning one logit per class in dimension 1.
        test_loader: iterable of ``(data, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: if ``test_loader`` yields no samples. Previously this
            surfaced as a bare ZeroDivisionError.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            _, predicted = torch.max(model(data), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("test_loader produced no samples; cannot compute accuracy")
    return correct / total


# The name starts with "test", so pytest would otherwise collect this helper as a
# test case and error out looking for fixtures named model/test_loader/device.
test.__test__ = False

def train_main(n_clients=10):
    """Run the complete federated-learning simulation and save the final model.

    Loads MNIST and splits the training set into ``n_clients`` random partitions,
    one DataLoader per client. It then runs ``epochs`` global rounds. In each
    round the sampler is asked for ``n_clients`` of ``n_clients`` users (so every
    client participates); they train locally and are averaged into the global
    model, whose test accuracy is printed. The final global weights are written
    to 'federated_model.pth' in the working directory.

    Args:
        n_clients: number of simulated clients / data partitions.
    """
    global_model = ConvNet().to(device)
    global_parameters = global_model.state_dict()
    server = FedAvgServer(global_parameters)

    train_dataset, test_dataset = load_data(transform)
    # ``transform`` includes random rotation/translation intended for training.
    # Evaluating on randomly perturbed test images makes the reported accuracy
    # noisy and biased, so the test split keeps only the deterministic steps.
    # (torchvision MNIST/CIFAR datasets read ``self.transform`` in __getitem__.)
    test_dataset.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    client_datasets = partition_dataset(train_dataset, n_clients)
    client_loaders = [
        DataLoader(dataset, batch_size=50, shuffle=True, num_workers=num_workers)
        for dataset in client_datasets
    ]
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    sampler = EqualUserSampler(n_clients, n_clients)

    for epoch in range(1, epochs + 1):
        print(f'Global Epoch {epoch}/{epochs}')
        user_idx = sampler.get_useridx()

        train(client_loaders, user_idx, server, global_model, lr)

        # Evaluate the aggregated global model on the held-out test set.
        test_accuracy = test(global_model, test_loader, device)
        print(
            f'Global Model Test Accuracy after round {epoch}: {test_accuracy:.4f}')

        # Optional per-class metrics (requires a live evaluate_model function):
        # evaluate_model(global_model, test_loader, device)

    # Save the final global model
    torch.save(global_model.state_dict(), 'federated_model.pth')
    print("Federated learning process completed.")


if __name__ == "__main__":
    # Script entry point: run the full simulation with the default client count.
    train_main()
