In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
# Seed the global NumPy and PyTorch random number generators with a fixed value.
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7b6f5f5ab290>

In [3]:
#simple cnn for mnist
class SimpleCNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two 3x3 convolutions (1 -> 32 -> 64 channels), each followed by ReLU and
    2x2 max pooling, feed two fully connected layers (1600 -> 128 -> 10).
    With a 28x28 input the feature map after the second pooling step is
    64 x 5 x 5 = 1600 values per sample, which is what `fc1` expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, 10)`` for input `x`.

        Raises:
            RuntimeError: if the spatial size of `x` does not produce 1600
                features per sample (raised by `fc1`).
        """
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        # Flatten per sample instead of `view(-1, 1600)`: the latter silently
        # folds surplus features into the batch dimension when the input size
        # is wrong, producing logits for a batch size that does not exist.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [17]:
#simple training function
def train_epoch(model, dataloader, optimizer, loss_function, device):
    """Train `model` for one full pass over `dataloader`.

    Args:
        model: module to optimise; it is put in training mode.
        dataloader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer bound to the parameters of `model`.
        loss_function: callable ``loss_function(output, target)`` returning a
            scalar tensor.
        device: device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(model, total_loss, accuracy)``. `total_loss` is the sum of
        the per-batch loss values (it is not divided by the number of
        batches). `accuracy` is the fraction of the samples processed in this
        epoch whose arg-max prediction equals the target.
    """
    model.train()
    total_loss = 0.0
    train_correct = 0
    samples_seen = 0
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        train_correct += pred.eq(target.view_as(pred)).sum().item()
        samples_seen += pred.size(0)
    # Divide by the samples actually processed: `len(dataloader.dataset)` is
    # too large whenever the loader drops the last batch or uses a sampler
    # that visits only part of the dataset.
    accuracy = train_correct / samples_seen if samples_seen else 0.0
    return model, total_loss, accuracy

#simple testing function
def test_model(model, dataloader, loss_function, device):
    model.eval()
    total_loss = 0
    test_correct = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += loss_function(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            test_correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = test_correct / len(dataloader.dataset)
    return total_loss, accuracy

def train(model, train_loader, test_loader, optimizer, loss_function, device, num_epochs, verbose=False):
    """Train `model` for `num_epochs` epochs, optionally evaluating after each.

    Args:
        model: module to train.
        train_loader: loader passed to `train_epoch` every epoch.
        test_loader: loader passed to `test_model` after every epoch; pass
            `None` (or any falsy loader) to skip evaluation.
        optimizer: optimizer bound to the parameters of `model`.
        loss_function: loss callable used for both training and evaluation.
        device: device the batches are moved to.
        num_epochs: number of passes over `train_loader`.
        verbose: when true, print one metrics line per epoch.

    Returns:
        Tuple ``(model, train_loss, train_accuracy, test_loss, test_accuracy)``
        holding the metrics of the last epoch. Metrics that were never
        computed (no evaluation loader, or `num_epochs` of 0) are
        ``float('nan')``.
    """
    not_available = float('nan')
    # Pre-initialise so that `num_epochs == 0` returns NaNs instead of raising
    # UnboundLocalError at the return statement.
    train_loss = train_accuracy = not_available
    test_loss = test_accuracy = not_available
    for epoch in range(num_epochs):
        model, train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, loss_function, device)
        if test_loader:
            test_loss, test_accuracy = test_model(model, test_loader, loss_function, device)
            if verbose:
                print(f'Epoch {epoch + 1}/{num_epochs} - Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.4f} - Test Loss: {test_loss:.4f} - Test Accuracy: {test_accuracy:.4f}')
        else:
            # A real NaN (not the string 'nan') keeps the return types numeric,
            # so callers can format or compare the values safely.
            test_loss = test_accuracy = not_available
            if verbose:
                print(f'Epoch {epoch + 1}/{num_epochs} - Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.4f}')
    return model, train_loss, train_accuracy, test_loss, test_accuracy

In [None]:
# Load the MNIST train and test splits. ToTensor yields values in [0, 1];
# Normalize with mean 0.5 and std 0.5 then maps them to [-1, 1].
# Mean/std are given as one-element tuples (one entry per channel); a bare
# `(0.5)` is just a parenthesised float, not a sequence.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('../data', train=False, download=True, transform=transform)

# Shuffle only the training data; evaluation order does not matter.
trainloader = DataLoader(train_data, batch_size=32, shuffle=True)
testloader = DataLoader(test_data, batch_size=32, shuffle=False)

In [None]:
# Centralized (non-federated) baseline: train one SimpleCNN on the full
# training set and evaluate on the test set after every epoch.
num_epochs = 10
# Fall back to the CPU so the cell also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_function = nn.CrossEntropyLoss()

model, train_loss, train_acc, test_loss, test_acc = train(model, trainloader, testloader, optimizer, loss_function, device, num_epochs)

Epoch 1/10 - Train Loss: 356.7633 - Train Accuracy: 0.9430 - Test Loss: 46.9566 - Test Accuracy: 0.9571
Epoch 2/10 - Train Loss: 207.3113 - Train Accuracy: 0.9676 - Test Loss: 31.3497 - Test Accuracy: 0.9709
Epoch 3/10 - Train Loss: 186.7636 - Train Accuracy: 0.9703 - Test Loss: 27.7003 - Test Accuracy: 0.9765
Epoch 4/10 - Train Loss: 186.6661 - Train Accuracy: 0.9716 - Test Loss: 29.4968 - Test Accuracy: 0.9745
Epoch 5/10 - Train Loss: 174.2264 - Train Accuracy: 0.9732 - Test Loss: 38.4680 - Test Accuracy: 0.9676
Epoch 6/10 - Train Loss: 173.8800 - Train Accuracy: 0.9736 - Test Loss: 34.8001 - Test Accuracy: 0.9705
Epoch 7/10 - Train Loss: 171.1548 - Train Accuracy: 0.9736 - Test Loss: 30.7949 - Test Accuracy: 0.9729
Epoch 8/10 - Train Loss: 162.4146 - Train Accuracy: 0.9757 - Test Loss: 40.5097 - Test Accuracy: 0.9662
Epoch 9/10 - Train Loss: 166.5521 - Train Accuracy: 0.9747 - Test Loss: 27.5115 - Test Accuracy: 0.9758
Epoch 10/10 - Train Loss: 157.2287 - Train Accuracy: 0.9764 - Te

In [18]:
def get_model(config):
    """Instantiate the network selected by ``config["model"]``.

    Only ``"SimpleCNN"`` is recognised; any other value raises
    ``ValueError("Model not supported")``.
    """
    requested = config["model"]
    if requested != "SimpleCNN":
        raise ValueError("Model not supported")
    return SimpleCNN()
    
def load_data(config):
    """Load the training set named in `config` and split it across clients.

    Args:
        config: mapping with the keys ``"dataset"`` (only ``"MNIST"`` is
            supported), ``"iid"`` (only a true value is implemented) and
            ``"num_clients"`` (number of partitions to create).

    Returns:
        The result of `random_split`: one subset per client. Every sample is
        assigned to exactly one client; when the dataset size is not a
        multiple of ``num_clients`` the first ``len(data) % num_clients``
        clients receive one extra sample.

    Raises:
        ValueError: for an unsupported dataset, a non-IID request, or a
            non-positive ``num_clients``.
    """
    if config["dataset"] == "MNIST":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        data = datasets.MNIST('../data', train=True, download=True, transform=transform)
    else:
        raise ValueError("Dataset not supported")

    if not config["iid"]:
        raise ValueError("NOT IMPLEMENTED")

    num_clients = config['num_clients']
    if num_clients <= 0:
        raise ValueError("num_clients must be a positive integer")

    # `random_split` requires the lengths to sum to len(data). Plain floor
    # division loses the remainder and makes it raise, so spread the
    # remainder over the first clients instead.
    base_size, remainder = divmod(len(data), num_clients)
    split_sizes = [base_size + 1 if idx < remainder else base_size for idx in range(num_clients)]
    client_data = random_split(data, split_sizes)
    return client_data

def get_testloader(config):
    """Build an unshuffled loader over the held-out split of the dataset.

    ``config["dataset"]`` must be ``"MNIST"`` (anything else raises
    ``ValueError("Dataset not supported")``); ``config["batch_size"]`` sets
    the batch size of the returned loader.
    """
    if config["dataset"] != "MNIST":
        raise ValueError("Dataset not supported")

    preprocessing = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5), (0.5))])
    held_out = datasets.MNIST('../data', train=False, download=True, transform=preprocessing)
    loader = torch.utils.data.DataLoader(held_out, batch_size=config["batch_size"], shuffle=False)
    return loader

class FedAVG:
    """Single-process simulation of federated averaging.

    Each round a subset of clients is sampled, every selected client trains a
    copy of the global model on its own data partition, and the resulting
    parameters are averaged (weighted by client sample count) into the global
    model, which is then evaluated on the test loader.
    """

    def __init__(self, config) -> None:
        """Create the global model, the per-client partitions and the test loader.

        `config` is stored as-is and read by every other method.
        """
        self.config = config
        self.global_model = get_model(config).to(config["device"])
        self.client_data = load_data(config)
        self.testloader = get_testloader(config)

    def client_train(self, client_id):
        """Train a fresh copy of the global model on one client's partition.

        Args:
            client_id: index into `self.client_data`.

        Returns:
            The state dict of the locally trained model.
        """
        device = self.config["device"]
        model = get_model(self.config).to(device)
        # Start every client from the current global parameters.
        model.load_state_dict(self.global_model.state_dict())
        # Optimizer and loss are looked up by name on torch.optim / torch.nn.
        optimizer = getattr(optim, self.config["optimizer"])(model.parameters(), lr=self.config["lr"])
        loss_function = getattr(nn, self.config["loss_function"])()

        data_loader = torch.utils.data.DataLoader(
            self.client_data[client_id], batch_size=self.config["batch_size"], shuffle=True)

        # train() returns (model, train_loss, train_acc, test_loss, test_acc);
        # only the trained model is needed here.
        model = train(model, data_loader, None, optimizer, loss_function, device, self.config["num_epochs"])[0]

        return model.state_dict()

    def aggregate(self, weights, client_sample_sizes):
        """Return the sample-weighted average of several state dicts.

        Args:
            weights: list of state dicts sharing the same keys.
            client_sample_sizes: number of training samples behind each entry
                of `weights`, in the same order.

        Returns:
            Dict mapping every key of ``weights[0]`` to the weighted average
            tensor, cast back to the dtype of the corresponding input tensor.

        Raises:
            ValueError: if `weights` is empty, the two lists differ in
                length, or the total sample count is not positive.
        """
        if len(weights) == 0:
            raise ValueError("No client weights to aggregate")
        if len(weights) != len(client_sample_sizes):
            raise ValueError("weights and client_sample_sizes must have the same length")
        total_samples = float(np.sum(client_sample_sizes))
        if total_samples <= 0:
            raise ValueError("Total number of client samples must be positive")

        # Plain Python floats keep the tensor dtype in charge of the result.
        fractions = [float(size) / total_samples for size in client_sample_sizes]
        avg_weights = {}
        for key, reference in weights[0].items():
            weighted = sum(client[key] * fraction for client, fraction in zip(weights, fractions))
            # Integer entries are promoted to float by the multiplication;
            # restore the original dtype so the result mirrors the inputs.
            avg_weights[key] = weighted.to(reference.dtype)
        return avg_weights

    def run(self):
        """Execute ``config["num_rounds"]`` federated rounds.

        Prints the test loss and accuracy of the global model after every
        round and returns the global model at the end.
        """
        loss_function = getattr(nn, self.config["loss_function"])()
        # `round_idx` rather than `round`, which would shadow the builtin.
        for round_idx in range(self.config["num_rounds"]):
            client_weights = []
            client_sample_sizes = []
            selected_clients = np.random.choice(len(self.client_data), self.config["num_client_selection"], replace=False)
            for client_id in selected_clients:
                client_weights.append(self.client_train(client_id))
                client_sample_sizes.append(len(self.client_data[client_id]))
            aggregated_weights = self.aggregate(client_weights, client_sample_sizes)
            self.global_model.load_state_dict(aggregated_weights)
            print(f"Round {round_idx + 1} completed")

            test_loss, test_accuracy = test_model(self.global_model, self.testloader, loss_function, self.config["device"])
            print(f"Test Loss: {test_loss:.4f} - Test Accuracy: {test_accuracy:.4f}")
            print("-"*50)
        return self.global_model

In [22]:
# Experiment configuration consumed by get_model, load_data, get_testloader
# and FedAVG.
config = {
    "model": "SimpleCNN",                # architecture name resolved by get_model
    "dataset": "MNIST",                  # dataset name resolved by load_data / get_testloader
    "iid": True,                         # only the IID split is implemented
    "num_clients": 10,                   # number of partitions of the training set
    "num_epochs": 2,                     # local epochs per client per round
    "lr": 0.01,                          # learning rate passed to the client optimizer
    "num_rounds": 10,                    # number of federated rounds
    # Fall back to the CPU so the notebook also runs without a CUDA device.
    "device": 'cuda' if torch.cuda.is_available() else 'cpu',
    "batch_size": 32,                    # batch size for client loaders and the test loader
    "optimizer": "Adam",                 # attribute name looked up on torch.optim
    "loss_function": "CrossEntropyLoss",  # attribute name looked up on torch.nn
    "num_client_selection": 4            # clients sampled (without replacement) each round
}

In [23]:
# Build the simulation: creates the global model, the client partitions and the test loader.
fl_avg = FedAVG(config)

In [24]:
# Run all federated rounds (metrics are printed per round) and keep the final global model.
global_model = fl_avg.run()

Round 1 completed
Test Loss: 726.8008 - Test Accuracy: 0.1135
--------------------------------------------------
Round 2 completed
Test Loss: 40.4302 - Test Accuracy: 0.9622
--------------------------------------------------
Round 3 completed
Test Loss: 25.0442 - Test Accuracy: 0.9753
--------------------------------------------------
Round 4 completed
Test Loss: 20.0644 - Test Accuracy: 0.9806
--------------------------------------------------
Round 5 completed
Test Loss: 17.9355 - Test Accuracy: 0.9822
--------------------------------------------------
Round 6 completed
Test Loss: 18.8264 - Test Accuracy: 0.9831
--------------------------------------------------
Round 7 completed
Test Loss: 17.9822 - Test Accuracy: 0.9837
--------------------------------------------------
Round 8 completed
Test Loss: 18.2469 - Test Accuracy: 0.9816
--------------------------------------------------
Round 9 completed
Test Loss: 22.1747 - Test Accuracy: 0.9811
------------------------------------------