In [69]:
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed every RNG the notebook touches and pin cuDNN to deterministic kernels.

    Seeds Python's `random`, NumPy's global RNG, and PyTorch's CPU and (all) CUDA
    generators with the same value. It then disables cuDNN autotuning and requests
    deterministic cuDNN algorithms, so repeated runs produce the same numbers.

    Args:
        seed: integer seed applied to all generators (default 42).
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Trade cuDNN's autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)


In [70]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def load_cifar10(batch_size: int = 64, root: str = './data'):
    """Load the CIFAR-10 training split and a shuffling DataLoader over it.

    Images go through ToTensor and are then normalized per channel with mean 0.5 and
    std 0.5, mapping pixel values from [0, 1] to [-1, 1]. This is the same
    preprocessing the test-split loader applies.

    Args:
        batch_size: mini-batch size of the returned loader (default 64, the value
            that used to be hard-coded).
        root: directory the dataset is downloaded to / read from.

    Returns:
        (trainset, trainloader): the torchvision dataset and a DataLoader with
        shuffle=True.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    return trainset, trainloader

trainset, trainloader = load_cifar10()

In [71]:
def load_cifar10_test(batch_size: int = 64):
    """Return the CIFAR-10 test split together with a non-shuffled DataLoader.

    Preprocessing is ToTensor followed by per-channel Normalize(mean=0.5, std=0.5).
    The split is downloaded into ./data if it is not already present.

    Args:
        batch_size: mini-batch size of the returned loader.

    Returns:
        (testset, testloader) tuple; the loader iterates in dataset order.
    """
    normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
    return testset, DataLoader(testset, batch_size=batch_size, shuffle=False)

In [72]:
import numpy as np
from torch.utils.data import Subset

def Dirichlet_partition(dataset, num_clients=5, alpha=0.5):
    """
    Partition dataset indices into `num_clients` non-IID shards via a per-class Dirichlet draw.

    For every class label, that class's sample indices are shuffled and then split
    among the clients. The split proportions are drawn from
    Dirichlet([alpha] * num_clients); smaller `alpha` produces more skewed label
    distributions per client. NumPy's global RNG (`np.random`) is used, so results
    are reproducible after `np.random.seed(...)`.

    Args:
        dataset: indexable dataset yielding `(sample, label)` pairs with integer labels.
            If it exposes a `targets` sequence and no `target_transform` (as
            torchvision's CIFAR10 does by default), labels are read from `targets`
            directly instead of loading and transforming every sample.
        num_clients: number of shards to produce (must be >= 1).
        alpha: Dirichlet concentration parameter (must be > 0).

    Returns:
        A list of `num_clients` lists of Python-int indices into `dataset`. Every
        index appears in exactly one list; a list may be empty when `alpha` is small.

    Raises:
        ValueError: if `num_clients < 1` or `alpha <= 0`.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if alpha <= 0:
        raise ValueError(f"alpha must be > 0, got {alpha}")

    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        # Fast path: no image decoding / transforms needed to learn the labels.
        labels = [int(y) for y in targets]
    else:
        # Fallback: a single pass over the dataset (loads every sample once).
        labels = [int(label) for _, label in dataset]

    # Group indices by label. A dict also copes with labels that are not 0..K-1.
    class_indices = {}
    for idx, label in enumerate(labels):
        class_indices.setdefault(label, []).append(idx)

    client_indices = [[] for _ in range(num_clients)]

    # Sorted class order keeps the RNG draw sequence stable across runs.
    for label in sorted(class_indices):
        indices = class_indices[label]
        np.random.shuffle(indices)
        proportions = np.random.dirichlet([alpha] * num_clients)
        # Floored cumulative cut points; the final one (== len) is dropped.
        cuts = (np.cumsum(proportions) * len(indices)).astype(int)[:-1]
        for client_id, split in enumerate(np.split(indices, cuts)):
            client_indices[client_id].extend(split.tolist())

    # Mix classes within each shard so a client's samples are not class-ordered.
    for shard in client_indices:
        np.random.shuffle(shard)

    return client_indices

# Partition settings: number of federated clients and Dirichlet concentration
# (lower alpha -> more skewed, i.e. more non-IID, label mix per client).
num_clients = 3
alpha = 0.5
client_idx_lists = Dirichlet_partition(trainset, num_clients=num_clients, alpha=alpha)

def create_client_datasets(trainset: datasets.CIFAR10, client_indices):
    """Wrap each client's index list in a `Subset` view of `trainset`.

    Returns one Subset per entry of `client_indices`, in the same order. The
    Subsets reference `trainset` rather than copying its samples.
    """
    subsets = []
    for indices in client_indices:
        subsets.append(Subset(trainset, indices))
    return subsets

client_datasets = create_client_datasets(trainset, client_idx_lists)

In [73]:
import torch.nn as nn

class model(nn.Module):
    """Small CNN classifier for 3-channel 32x32 images (e.g. CIFAR-10).

    Three stages of Conv3x3(padding=1) -> BatchNorm2d -> ReLU -> MaxPool(2) take
    the input from 3 to 32, 64 and 128 channels, halving H and W each time
    (32 -> 16 -> 8 -> 4). The MLP head is Flatten -> Linear(128*4*4, 256) ->
    BatchNorm1d -> ReLU -> Dropout(0.5) -> Linear(256, output_dim).

    Args:
        output_dim: number of output logits (default 10).
    """

    def __init__(self, output_dim=10):
        super().__init__()
        conv_stack = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            # Each stage halves the spatial size: 32x32 -> 16x16 -> 8x8 -> 4x4.
            conv_stack.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            in_channels = out_channels
        # Same layer order as a flat listing, so Sequential indices / state_dict keys
        # (conv_layers.0, .1, .4, .5, .8, .9) are unchanged.
        self.conv_layers = nn.Sequential(*conv_stack)

        flat_features = 128 * 4 * 4
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, output_dim),
        )

    def forward(self, x):
        """Return raw (unnormalized) logits of shape (batch, output_dim)."""
        return self.fc_layers(self.conv_layers(x))


In [74]:
import copy
def client_update(model, dataloader, optimizer, loss_fn, epochs=1):
    """Train `model` locally for `epochs` passes over `dataloader`; return its state_dict.

    Puts the model in train mode and moves it (and every batch) to the module-level
    `device`, which is defined in a later notebook cell. For each batch it zeroes the
    gradients, computes `loss_fn(model(x), y)`, backpropagates and takes one
    `optimizer` step.

    Args:
        model: the nn.Module to train (modified in place).
        dataloader: iterable of `(inputs, targets)` batches.
        optimizer: optimizer already bound to `model`'s parameters.
        loss_fn: callable `(outputs, targets) -> scalar loss tensor`.
        epochs: number of full passes over `dataloader`.

    Returns:
        `model.state_dict()`. Its tensors alias the live model parameters, not copies.
    """
    model.train().to(device)
    for _epoch in range(epochs):
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
    return model.state_dict()

def aggregate_models(client_models, weights=None):
    """Average client state_dicts into one state_dict (FedAvg aggregation step).

    Every entry is combined as sum_i(w_i * state_i[key]) / sum_i(w_i). With the
    default `weights=None` all clients count equally, which reproduces the plain mean.
    Passing per-client sample counts gives the sample-size-weighted average of the
    original FedAvg formulation. That matters when clients hold very different
    amounts of data, as after a Dirichlet split.

    Integer entries (e.g. BatchNorm's `num_batches_tracked`) are averaged in floating
    point and cast back (truncating) to their original dtype, so the returned
    state_dict has the same dtypes as the inputs.

    Args:
        client_models: non-empty list of state_dicts sharing the same keys.
        weights: optional per-client non-negative weights, one per state_dict.

    Returns:
        A new state_dict. The input dicts and their tensors are not modified.

    Raises:
        ValueError: on an empty list, a weights/length mismatch, or weights that do
            not sum to a positive value.
    """
    if not client_models:
        raise ValueError("aggregate_models() needs at least one client state_dict")
    if weights is None:
        weights = [1.0] * len(client_models)
    weights = [float(w) for w in weights]
    if len(weights) != len(client_models):
        raise ValueError(
            f"got {len(weights)} weights for {len(client_models)} client state_dicts"
        )
    total = sum(weights)
    if total <= 0:
        raise ValueError("aggregation weights must sum to a positive value")

    # Shallow copy keeps the dict type and any attributes state_dict() attached
    # (e.g. `_metadata`). Every value is replaced below, so no input tensor is
    # mutated and no tensor is needlessly deep-copied.
    avg_model = copy.copy(client_models[0])
    for key, reference in client_models[0].items():
        weighted_sum = sum(w * state[key] for w, state in zip(weights, client_models))
        # Cast back so integer buffers stay integer; a no-op for float tensors.
        avg_model[key] = (weighted_sum / total).to(reference.dtype)

    return avg_model

In [75]:
def evaluate(model, dataloader):
    """Return the top-1 accuracy of `model` over `dataloader`, as a percentage in [0, 100].

    Switches the model to eval mode, moves it and each batch to the module-level
    `device`, and runs without gradient tracking. A prediction is the argmax over
    dim 1 of the model output.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        dataloader: iterable of `(inputs, targets)` batches.

    Returns:
        100 * correct / total as a float.

    Raises:
        ValueError: if the dataloader yields no samples (accuracy is undefined).
    """
    model.eval()
    model.to(device)
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    if total == 0:
        # Previously a bare ZeroDivisionError; make the cause explicit.
        raise ValueError("evaluate() got an empty dataloader; accuracy is undefined")
    return 100.0 * correct / total

# Held-out CIFAR-10 test split, used by evaluate() after every federated round.
testset, testloader = load_cifar10_test(batch_size=64)


In [76]:
import random

def train_fedavg(num_rounds, fraction, lr, momentum, epochs, output_dim, num_clients,
                 batch_size=64, save_path="best_global_model.pth"):
    """Run FedAvg for `num_rounds` rounds and keep the best global model by test accuracy.

    Each round:
      1. samples `max(1, floor(fraction * num_clients))` distinct clients (capped at
         `num_clients`) with Python's `random`;
      2. trains a deep copy of the global model on each selected client's data
         (`client_datasets[client_id]`) for `epochs` local epochs using
         SGD(lr, momentum) and cross-entropy;
      3. replaces the global weights with `aggregate_models(...)` of the client
         state_dicts, called without weights (equal weight per client);
      4. evaluates on `testloader` and remembers the best state seen so far.
    The best state is saved to `save_path` with `torch.save` at the end.

    Relies on notebook globals: `model`, `client_datasets`, `testloader`, `device`,
    `client_update`, `aggregate_models` and `evaluate`.

    Args:
        num_rounds: number of communication rounds.
        fraction: share of clients sampled per round (C in the FedAvg paper).
        lr, momentum: local SGD hyper-parameters.
        epochs: local epochs per selected client per round.
        output_dim: number of classes for the global model.
        num_clients: total number of clients to sample from.
        batch_size: local mini-batch size (default 64, the previously hard-coded value).
        save_path: where the best global state_dict is written.

    Returns:
        (best_acc, best_model_state): best test accuracy in percent and its
        state_dict (None only if `num_rounds` is 0).
    """
    # round() guards against float artifacts such as 0.29 * 100 == 28.999999999999996;
    # clamping avoids sampling 0 clients (aggregation needs >= 1) or more than exist.
    clients_per_round = min(num_clients, max(1, int(round(fraction * num_clients, 9))))

    global_model = model(output_dim).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()  # stateless, so one instance serves all clients
    best_acc = 0
    best_model_state = None

    for rnd in range(num_rounds):
        print(f"\n--- Round {rnd + 1} ---")

        selected_clients = random.sample(range(num_clients), clients_per_round)

        client_models = []
        for client_id in selected_clients:
            local_model = copy.deepcopy(global_model)
            optimizer = torch.optim.SGD(local_model.parameters(), lr=lr, momentum=momentum)
            client_data = client_datasets[client_id]
            # BatchNorm1d raises on a training batch of size 1, so drop a lone
            # trailing sample when the client's size leaves a remainder of exactly 1.
            dataloader = DataLoader(
                client_data,
                batch_size=batch_size,
                shuffle=True,
                drop_last=len(client_data) % batch_size == 1,
            )
            client_models.append(
                client_update(local_model, dataloader, optimizer, loss_fn, epochs)
            )

        global_model.load_state_dict(aggregate_models(client_models))

        test_acc = evaluate(global_model, testloader)
        print(f"Test Accuracy: {test_acc:.2f}%")
        # The `is None` check guarantees a checkpoint even if accuracy never exceeds 0.
        if best_model_state is None or test_acc > best_acc:
            best_acc = test_acc
            best_model_state = copy.deepcopy(global_model.state_dict())

    if best_model_state is not None:
        torch.save(best_model_state, save_path)
        print(f"Saved the best accuracy model. Best Accuracy : {best_acc:.2f}%")
    return best_acc, best_model_state


In [77]:
# Run configuration. num_clients comes from the partitioning cell above.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
output_dim = 10
num_rounds = 10
fraction = 0.6      # share of clients sampled each round
lr = 0.03
momentum = 0.9
epochs = 7          # local epochs per selected client per round

best_acc, best_model_state = train_fedavg(
    num_rounds=num_rounds,
    fraction=fraction,
    lr=lr,
    momentum=momentum,
    epochs=epochs,
    output_dim=output_dim,
    num_clients=num_clients,
)


--- Round 1 ---
Test Accuracy: 56.12%

--- Round 2 ---
Test Accuracy: 56.18%

--- Round 3 ---
Test Accuracy: 54.92%

--- Round 4 ---
Test Accuracy: 60.86%

--- Round 5 ---
Test Accuracy: 74.59%

--- Round 6 ---
Test Accuracy: 63.49%

--- Round 7 ---
Test Accuracy: 64.53%

--- Round 8 ---
Test Accuracy: 64.81%

--- Round 9 ---
Test Accuracy: 69.23%

--- Round 10 ---
Test Accuracy: 66.64%
Saved the best accuracy model. Best Accuracy : 74.59%
