In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
import torchvision
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7d91d49fb910>

In [6]:
#simple cnn for mnist
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 1600)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
#simple training function
def train_epoch(model, dataloader, optimizer, loss_function, device):
    model.train()
    total_loss = 0
    train_correct = 0
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        train_correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = train_correct / len(dataloader.dataset)
    return model, total_loss, accuracy

#simple testing function
def test_model(model, dataloader, loss_function, device):
    model.eval()
    total_loss = 0
    test_correct = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += loss_function(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            test_correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = test_correct / len(dataloader.dataset)
    return total_loss, accuracy

def train(model, train_loader, test_loader, optimizer, loss_function, device, num_epochs):
    for epoch in range(num_epochs):
        model, train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, loss_function, device)
        if test_loader:
            test_loss, test_accuracy = test_model(model, test_loader, loss_function, device)
            # print(f'Epoch {epoch + 1}/{num_epochs} - Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.4f} - Test Loss: {test_loss:.4f} - Test Accuracy: {test_accuracy:.4f}')
        else:
            # print(f'Epoch {epoch + 1}/{num_epochs} - Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.4f}')
            test_accuracy = 'nan'
            test_loss = 'nan'
    return model, train_loss, train_accuracy, test_loss, test_accuracy

In [13]:
#load CIFAR10 dataset
model = torchvision.models.resnet18(weights=None)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100.0%


Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified


In [17]:
num_epochs = 1
# model = SimpleCNN().to("cuda")
model = torchvision.models.resnet18(weights=None).to("cuda")
model.fc = nn.Linear(512, 10).to("cuda")
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_function = nn.CrossEntropyLoss()

model, train_loss, train_acc, test_loss, test_acc = train(model, train_loader, test_loader, optimizer, loss_function, "cuda", num_epochs)

In [22]:
def get_model(config):
    if config["model"] == "SimpleCNN":
        return SimpleCNN()
    elif config["model"] == "ResNet18":
        model = torchvision.models.resnet18(weights=None)
        model.fc = nn.Linear(512, 10)
        return model
    else:
        raise ValueError("Model not supported")
    
def load_data(config):
    if config["dataset"] == "MNIST":
        data = datasets.MNIST('../data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]))
    elif config["dataset"] == "CIFAR10":
        data = datasets.CIFAR10('../data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]))
    else:
        raise ValueError("Dataset not supported")
    
    if config["iid"]:
        client_data = random_split(data, [len(data) // config['num_clients'] for _ in range(config['num_clients'])])
    else:
        client_data = create_noniid_dataset_mnist(data, config['num_clients'], config['num_labels_per_client'])
    return client_data

def get_testloader(config):
    if config["dataset"] == "MNIST":
        data = datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]))
    elif config["dataset"] == "CIFAR10":
        data = datasets.CIFAR10('../data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]))
    else:
        raise ValueError("Dataset not supported")
    return torch.utils.data.DataLoader(data, batch_size=config["batch_size"], shuffle=False)


def create_noniid_dataset_mnist(data, num_clients, num_labels_per_client):
    """
    Splits the dataset into non-IID subsets, where each client gets data for a fixed number of labels.
    
    Args:
        data: The dataset to be split.
        num_clients: The number of clients.
        num_labels_per_client: Number of unique labels per client.
    
    Returns:
        A list of Subset objects, one for each client.
    """
    # Group indices by labels
    label_indices = {i: [] for i in range(10)} 
    for idx, (_, label) in enumerate(data):
        label_indices[label].append(idx)
    
    # Shuffle indices within each label
    for label in label_indices:
        torch.manual_seed(42)  # For reproducibility
        torch.randperm(len(label_indices[label])).tolist()
    
    # Assign labels to clients
    clients = [[] for _ in range(num_clients)]
    label_list = list(label_indices.keys())
    
    # Divide labels across clients
    labels_per_client = len(label_list) // num_clients
    for client_id in range(num_clients):
        start_label = (client_id * labels_per_client) % len(label_list)
        client_labels = label_list[start_label:start_label + num_labels_per_client]
        
        # Add data for these labels to the client
        for label in client_labels:
            clients[client_id].extend(label_indices[label])
    
    # Create Subset objects for each client
    client_data = [Subset(data, client_indices) for client_indices in clients]
    return client_data


class FedAVG():
    def __init__(self, config) -> None:
        self.config = config
        self.global_model = get_model(config).to(config["device"])
        self.client_data = load_data(config)
        self.testloader = get_testloader(config)

    def client_train(self, client_id):
        data = self.client_data[client_id]
        model = get_model(self.config).to(self.config["device"])
        model.load_state_dict(self.global_model.state_dict())
        optimizer = getattr(optim, self.config["optimizer"])(model.parameters(), lr=self.config["lr"])
        loss_function = getattr(nn, self.config["loss_function"])()

        data_loader = torch.utils.data.DataLoader(data, batch_size=self.config["batch_size"], shuffle=True)

        model, local_train_loss, local_train_accuracy, _, _ = train(model, data_loader, None, optimizer, loss_function, self.config["device"], self.config["num_epochs"])

        return model.state_dict()
    
    def aggregate(self, weights, client_sample_sizes):
        total_samples = np.sum(client_sample_sizes)
        avg_weights = {}
        for key in weights[0].keys():
            avg_weights[key] = sum(
                weights[i][key] * (client_sample_sizes[i] / total_samples)
                for i in range(len(weights)))
        return avg_weights
    
    def run(self):
        for round in range(self.config["num_rounds"]):
            client_weights = []
            client_sample_sizes = []
            selected_clients = np.random.choice(len(self.client_data), self.config["num_client_selection"], replace=False)
            for client_id in selected_clients:
                client_weights.append(self.client_train(client_id))
                client_sample_sizes.append(len(self.client_data[client_id]))
            aggregated_weights = self.aggregate(client_weights, client_sample_sizes)
            self.global_model.load_state_dict(aggregated_weights)
            print(f"Round {round + 1} completed")

            test_loss, test_accuracy = test_model(self.global_model,self.testloader, getattr(nn, self.config["loss_function"])(), self.config["device"])
            print(f"Test Loss: {test_loss:.4f} - Test Accuracy: {test_accuracy:.4f}")
            print("-"*50)
        return self.global_model    

In [23]:
config = {
    "model": "ResNet18",
    "dataset": "CIFAR10",
    "iid": True,
    "num_clients": 10,
    "num_epochs": 2,
    "lr": 0.01,
    "num_rounds": 10,
    "device": 'cuda',
    "batch_size": 32,
    "optimizer": "Adam",
    "loss_function": "CrossEntropyLoss",
    "num_client_selection": 4
}

In [24]:
fl_avg = FedAVG(config)

Files already downloaded and verified
Files already downloaded and verified


In [25]:
global_model = fl_avg.run()

Round 1 completed
Test Loss: 721.2015 - Test Accuracy: 0.1000
--------------------------------------------------
Round 2 completed
Test Loss: 720.0008 - Test Accuracy: 0.1000
--------------------------------------------------
Round 3 completed
Test Loss: 606.1033 - Test Accuracy: 0.2617
--------------------------------------------------
Round 4 completed
Test Loss: 460.6566 - Test Accuracy: 0.4750
--------------------------------------------------
Round 5 completed
Test Loss: 406.6692 - Test Accuracy: 0.5335
--------------------------------------------------
Round 6 completed
Test Loss: 371.5809 - Test Accuracy: 0.5701
--------------------------------------------------
Round 7 completed
Test Loss: 347.8602 - Test Accuracy: 0.5971
--------------------------------------------------
Round 8 completed
Test Loss: 320.6075 - Test Accuracy: 0.6375
--------------------------------------------------
Round 9 completed
Test Loss: 298.4674 - Test Accuracy: 0.6664
----------------------------------

In [45]:
config = {
    "model": "SimpleCNN",
    "dataset": "MNIST",
    "iid": False,
    "num_clients": 10,
    "num_epochs": 2,
    "lr": 0.01,
    "num_rounds": 10,
    "device": 'cuda',
    "batch_size": 32,
    "optimizer": "Adam",
    "loss_function": "CrossEntropyLoss",
    "num_client_selection": 6,
    "num_labels_per_client": 3
}

In [46]:
fl_avg = FedAVG(config)

In [47]:
global_model = fl_avg.run()

Round 1 completed
Test Loss: 718.0779 - Test Accuracy: 0.1698
--------------------------------------------------
Round 2 completed
Test Loss: 631.9236 - Test Accuracy: 0.2781
--------------------------------------------------
Round 3 completed
Test Loss: 672.9356 - Test Accuracy: 0.3972
--------------------------------------------------
Round 4 completed
Test Loss: 555.6201 - Test Accuracy: 0.4403
--------------------------------------------------
Round 5 completed
Test Loss: 565.7080 - Test Accuracy: 0.5105
--------------------------------------------------
Round 6 completed
Test Loss: 546.7292 - Test Accuracy: 0.4844
--------------------------------------------------
Round 7 completed
Test Loss: 437.6212 - Test Accuracy: 0.6330
--------------------------------------------------
Round 8 completed
Test Loss: 401.4468 - Test Accuracy: 0.6732
--------------------------------------------------
Round 9 completed
Test Loss: 311.9458 - Test Accuracy: 0.7171
----------------------------------