In [127]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

In [128]:
# Define the neural network architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64*8*8, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2, 2)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.maxpool(x)
        x = self.relu(self.conv2(x))
        x = self.maxpool(x)
        x = x.view(-1, 64*8*8)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [129]:
# Function to create heterogenous data loaders
def create_heterogenous_data_loaders(dataset, num_users):
    heterogenous_loaders = []
    num_samples_per_user = len(dataset) // num_users
    for i in range(num_users):
        user_indices = list(range(i * num_samples_per_user, (i + 1) * num_samples_per_user))
        user_loader = DataLoader(Subset(dataset, user_indices), batch_size=64, shuffle=True)
        heterogenous_loaders.append(user_loader)
    return heterogenous_loaders


In [130]:
class FedLearning:
    def __init__(self, num_users, num_epochs, algorithm='FedAvg', device='cuda'):
        self.num_users = num_users
        self.num_epochs = num_epochs
        self.algorithm = algorithm
        self.local_criteria = [nn.CrossEntropyLoss() for _ in range(num_users)]
        self.criterion = nn.CrossEntropyLoss()
        self.local_optimizers = [None for _ in range(num_users)]
        self.optimizer = None
        self.device = device

    def train_local_model_avg(self, model, train_loader, optimizer, criterion):
        model.train()
        model.to(self.device)
        for _ in range(2):
            for data, target in train_loader:
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

    def train_local_model_adam(self, model, train_loader, optimizer, criterion):
        model.train()
        model.to(self.device)
        avg_grad = {}
        total_loss = 0
        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
        avg_grad = {name: param.grad for name, param in model.named_parameters()}
        optimizer.zero_grad()
        optimizer.step()        
        return avg_grad
    
    def train_local_model_prox(self, model, train_loader, optimizer, criterion, global_model, mu=0.01):
        model.train()
        model.to(self.device)
        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss += mu * sum([torch.norm(param - global_param) for param, global_param in zip(model.parameters(), global_model.parameters())])
            loss.backward()
            optimizer.step()
    
    def fed_avg(self, global_model, local_models):
        global_model.train()
        global_model.to(self.device)
        global_dict = global_model.state_dict()
        for k in global_dict.keys():
            global_dict[k] = torch.stack([local_models[i].state_dict()[k].float() for i in range(self.num_users)], 0).mean(0)
        global_model.load_state_dict(global_dict)
        for local_model in local_models:
            local_model.load_state_dict(global_dict)
        return global_model

    def fed_adam(self, global_model, local_models, avg_grads):
        global_model.train()
        global_model.to(self.device)
        for name, param in global_model.named_parameters():
            param.grad = torch.stack([avg_grads[i][name] for i in range(self.num_users)], 0).mean(0)
        
        self.optimizer.step()
        for local_model in local_models:
            local_model.load_state_dict(global_model.state_dict())
        return global_model
    
    def fed_prox(self, global_model, local_models, mu=0.01):
        global_model.train()
        global_model.to(self.device)
        for k in global_model.state_dict().keys():
            global_model.state_dict()[k] = torch.stack([local_models[i].state_dict()[k].float() for i in range(self.num_users)], 0).mean(0) - mu * sum([torch.norm(local_models[i].state_dict()[k] - global_model.state_dict()[k]) for i in range(self.num_users)])
        for local_model in local_models:
            local_model.load_state_dict(global_model.state_dict())
        return global_model
        
    def evaluate(self, model, val_loader):
        model.eval()
        running_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = model(data)
                loss = self.criterion(output, target)
                running_loss += loss.item() * data.size(0)
        model.train()
        return running_loss / len(val_loader.dataset)                    
            
    def train(self, local_models, global_model, train_loaders, val_loader=None):
        
        global_model.to(self.device)
        for local_model in local_models:
            local_model.to(self.device)

        if self.algorithm == 'FedAvg':
            self.optimizer = optim.SGD(global_model.parameters(), lr=0.1)
            for i, local_model in enumerate(local_models):
                self.local_optimizers[i] = optim.SGD(local_model.parameters(), lr=0.1)
        elif self.algorithm == 'FedAdam':
            self.optimizer = optim.Adam(global_model.parameters(), lr=0.1)
            for i, local_model in enumerate(local_models):
                self.local_optimizers[i] = optim.Adam(local_model.parameters(), lr=0.1)
        elif self.algorithm == 'FedProx':
            # self.optimizer = optim.SGD(global_model.parameters(), lr=0.1)
            for i, local_model in enumerate(local_models):
                self.local_optimizers[i] = optim.SGD(local_model.parameters(), lr=0.1)
        else:
            raise ValueError("Invalid algorithm")
        
        for epoch in range(self.num_epochs):
            avg_grad = []
            for i, train_loader in enumerate(train_loaders):
                local_model = local_models[i]                
                if self.algorithm == 'FedAvg':
                    self.train_local_model_avg(local_model, train_loader, self.local_optimizers[i], self.local_criteria[i])    
                elif self.algorithm == 'FedAdam':
                    avg_grad.append(self.train_local_model_adam(local_model, train_loader, self.local_optimizers[i], self.local_criteria[i]))
                else:
                    self.train_local_model_prox(local_model, train_loader, self.local_optimizers[i], self.local_criteria[i], global_model)

            if self.algorithm == 'FedAvg':
                global_model = self.fed_avg(global_model, local_models)
            elif self.algorithm == 'FedAdam':
                global_model = self.fed_adam(global_model, local_models, avg_grad)
            else:
                global_model = self.fed_prox(global_model, local_models)
                                
            if val_loader is not None:
                val_accuracy = self.evaluate(global_model, val_loader)
                print(f"Epoch {epoch+1}, Validation Loss: {val_accuracy}")

In [131]:
# Load CIFAR-10 dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create heterogenous data loaders
num_users = 25
heterogenous_loaders = create_heterogenous_data_loaders(train_dataset, num_users)
val_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [134]:
algorithms = ['FedAvg', 'FedAdam', 'FedProx']
for algorithm in algorithms:
    print(f"Algorithm: {algorithm}")
    global_model = CNN()
    local_models = [CNN() for _ in range(num_users)]
    fed_learning = FedLearning(num_users, 10, algorithm)
    fed_learning.train(local_models, global_model, heterogenous_loaders, val_loader)
    print("")

Algorithm: FedAvg
Epoch 1, Validation Loss: 2.3029593204498293
Epoch 2, Validation Loss: 2.302195877456665
Epoch 3, Validation Loss: 2.300949782562256
Epoch 4, Validation Loss: 2.2901844814300536
Epoch 5, Validation Loss: 2.108607790374756
Epoch 6, Validation Loss: 1.9677748048782349
Epoch 7, Validation Loss: 1.8831800437927246
Epoch 8, Validation Loss: 1.7995524208068847
Epoch 9, Validation Loss: 1.728608934211731
Epoch 10, Validation Loss: 1.6675622169494628

Algorithm: FedAdam
Epoch 1, Validation Loss: 1185.5081931640625
Epoch 2, Validation Loss: 9.489222327423096
Epoch 3, Validation Loss: 2.582181950378418
Epoch 4, Validation Loss: 2.307996007156372
Epoch 5, Validation Loss: 2.3119687629699706
Epoch 6, Validation Loss: 2.315098455429077
Epoch 7, Validation Loss: 2.3183091468811035
Epoch 8, Validation Loss: 2.312909606552124
Epoch 9, Validation Loss: 2.3110937377929686
Epoch 10, Validation Loss: 2.311985194778442

Algorithm: FedProx
Epoch 1, Validation Loss: 2.3026687561035155
Epoch