<a href="https://colab.research.google.com/github/MahzabinC/Federated-Learning/blob/main/Federated_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import random
from copy import deepcopy

# Federated Learning Parameters
NUM_CLIENTS = 100
FRACTION_CLIENTS = 0.1  # Fraction of clients to sample each round
LOCAL_EPOCHS = 5
BATCH_SIZE = 50  # Local batch size
LR = 0.01
ROUNDS = 100  # Total communication rounds

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Dataset Preparation
def get_dataloader(dataset_name, iid=True):
    if dataset_name == "mnist":
        dataset = datasets.MNIST(
            root="./data",
            train=True,
            download=True,
            transform=transforms.ToTensor()
        )
    elif dataset_name == "cifar10":
        dataset = datasets.CIFAR10(
            root="./data",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        )
    else:
        raise ValueError("Dataset not supported")

    # IID partitioning
    if iid:
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        client_data = np.array_split(indices, NUM_CLIENTS)
    else:
        # Non-IID partitioning (two-class partitioning for MNIST)
        if dataset_name == "mnist":
            labels = dataset.targets.numpy()
        elif dataset_name == "cifar10":
            labels = np.array(dataset.targets)
        indices = np.argsort(labels)
        client_data = np.array_split(indices, NUM_CLIENTS)

    return dataset, client_data

In [5]:
# Model Architectures
class MNIST2NN(nn.Module):
    def __init__(self):
        super(MNIST2NN, self).__init__()
        self.fc1 = nn.Linear(28*28, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class CNN(nn.Module):
    def __init__(self, input_channels=1):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.fc1 = nn.Linear(64*8*8, 512) # Changed from 64*7*7 to 64*8*8
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 64*8*8)  # Changed from 64*7*7 to 64*8*8
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Federated Learning Process
def train_federated(model, clients_data, dataset, iid=True):
    global_model = deepcopy(model)
    criterion = nn.CrossEntropyLoss()

    for round_num in range(ROUNDS):
        sampled_clients = random.sample(range(NUM_CLIENTS), int(FRACTION_CLIENTS * NUM_CLIENTS))
        local_models = []

        for client in sampled_clients:
            local_model = deepcopy(global_model)
            optimizer = optim.SGD(local_model.parameters(), lr=LR)

            # Prepare client data
            client_indices = clients_data[client]
            client_data = torch.utils.data.Subset(dataset, client_indices)
            dataloader = torch.utils.data.DataLoader(client_data, batch_size=BATCH_SIZE, shuffle=True)

            # Local training
            local_model.train()
            for _ in range(LOCAL_EPOCHS):
                for data, target in dataloader:
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = local_model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    optimizer.step()

            local_models.append(deepcopy(local_model.state_dict()))

        # Aggregate global model
        global_dict = global_model.state_dict()
        for key in global_dict.keys():
            global_dict[key] = torch.stack([local_model[key] for local_model in local_models], dim=0).mean(dim=0)
        global_model.load_state_dict(global_dict)

        # Evaluate global model
        evaluate(global_model, dataset)

    return global_model

def evaluate(model, dataset):
    model.eval()
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=1000, shuffle=False)
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    print(f"Accuracy: {correct / total:.4f}")

In [7]:
# Experiments
# MNIST 2NN
mnist_dataset, mnist_clients = get_dataloader("mnist", iid=True)
mnist_model = MNIST2NN().to(device)
train_federated(mnist_model, mnist_clients, mnist_dataset)

Accuracy: 0.2177
Accuracy: 0.3122
Accuracy: 0.4393
Accuracy: 0.5088
Accuracy: 0.5280
Accuracy: 0.5191
Accuracy: 0.5619
Accuracy: 0.6322
Accuracy: 0.6747
Accuracy: 0.7033
Accuracy: 0.7412
Accuracy: 0.7697
Accuracy: 0.7862
Accuracy: 0.8047
Accuracy: 0.8146
Accuracy: 0.8234
Accuracy: 0.8297
Accuracy: 0.8376
Accuracy: 0.8422
Accuracy: 0.8499
Accuracy: 0.8532
Accuracy: 0.8577
Accuracy: 0.8623
Accuracy: 0.8645
Accuracy: 0.8695
Accuracy: 0.8712
Accuracy: 0.8731
Accuracy: 0.8762
Accuracy: 0.8786
Accuracy: 0.8810
Accuracy: 0.8838
Accuracy: 0.8856
Accuracy: 0.8868
Accuracy: 0.8886
Accuracy: 0.8893
Accuracy: 0.8906
Accuracy: 0.8923
Accuracy: 0.8938
Accuracy: 0.8948
Accuracy: 0.8951
Accuracy: 0.8961
Accuracy: 0.8967
Accuracy: 0.8974
Accuracy: 0.8981
Accuracy: 0.8995
Accuracy: 0.9001
Accuracy: 0.9002
Accuracy: 0.9014
Accuracy: 0.9014
Accuracy: 0.9021
Accuracy: 0.9027
Accuracy: 0.9036
Accuracy: 0.9046
Accuracy: 0.9040
Accuracy: 0.9058
Accuracy: 0.9050
Accuracy: 0.9066
Accuracy: 0.9061
Accuracy: 0.90

MNIST2NN(
  (fc1): Linear(in_features=784, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=10, bias=True)
)

In [8]:
# CIFAR-10 CNN
cifar10_dataset, cifar10_clients = get_dataloader("cifar10", iid=True)
cifar10_model = CNN(input_channels=3).to(device)
train_federated(cifar10_model, cifar10_clients, cifar10_dataset)

Files already downloaded and verified
Accuracy: 0.1706
Accuracy: 0.1830
Accuracy: 0.2260
Accuracy: 0.2466
Accuracy: 0.2649
Accuracy: 0.2709
Accuracy: 0.2901
Accuracy: 0.3085
Accuracy: 0.3193
Accuracy: 0.3401
Accuracy: 0.3513
Accuracy: 0.3601
Accuracy: 0.3733
Accuracy: 0.3791
Accuracy: 0.3896
Accuracy: 0.3957
Accuracy: 0.4060
Accuracy: 0.4174
Accuracy: 0.4258
Accuracy: 0.4343
Accuracy: 0.4421
Accuracy: 0.4480
Accuracy: 0.4566
Accuracy: 0.4628
Accuracy: 0.4639
Accuracy: 0.4726
Accuracy: 0.4775
Accuracy: 0.4817
Accuracy: 0.4859
Accuracy: 0.4906
Accuracy: 0.4946
Accuracy: 0.4968
Accuracy: 0.5006
Accuracy: 0.5061
Accuracy: 0.5102
Accuracy: 0.5162
Accuracy: 0.5163
Accuracy: 0.5231
Accuracy: 0.5226
Accuracy: 0.5252
Accuracy: 0.5302
Accuracy: 0.5349
Accuracy: 0.5390
Accuracy: 0.5403
Accuracy: 0.5436
Accuracy: 0.5469
Accuracy: 0.5474
Accuracy: 0.5528
Accuracy: 0.5544
Accuracy: 0.5574
Accuracy: 0.5590
Accuracy: 0.5609
Accuracy: 0.5635
Accuracy: 0.5674
Accuracy: 0.5679
Accuracy: 0.5724
Accuracy: 

CNN(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1): Linear(in_features=4096, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
)