# In [1]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import pickle

class NeuralNetwork(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        model_params: Mapping with integer entries ``"input_dim"``,
            ``"hidden_dim"`` and ``"output_dim"``.

    The layers act on the last dimension of ``x``, which must have size
    ``input_dim``; the output holds raw scores (no softmax).
    """

    def __init__(self, model_params):
        super().__init__()
        in_features = model_params["input_dim"]
        width = model_params["hidden_dim"]
        out_features = model_params["output_dim"]

        # Registration order (linear1, relu, linear2) fixes the state_dict keys.
        self.linear1 = nn.Linear(in_features, width)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(width, out_features)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

class Net_MNIST(nn.Module):
    """Small CNN for single-channel 28x28 images, producing 10 logits.

    Two conv(5x5) -> ReLU -> 2x2 max-pool stages turn a 1x28x28 image into a
    32x4x4 feature map (28 -> 24 -> 12 -> 8 -> 4), followed by three fully
    connected layers (512 -> 120 -> 84 -> 10).

    Args:
        model_params: Accepted for a constructor signature shared with the
            other models in this file; not read.
    """

    # Feature-map shape the conv stack produces for a 28x28 input.
    _feature_shape = (32, 4, 4)

    def __init__(self, model_params):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map ``(batch, 1, 28, 28)`` images to ``(batch, 10)`` logits.

        Raises:
            RuntimeError: if the input's spatial size does not yield the
                32x4x4 feature map that ``fc1`` expects.
        """
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        # Check the feature shape before the reshape. Without the check,
        # view(-1, 512) would silently regroup a mis-sized map into the wrong
        # number of rows (e.g. a 44x44 image would give 4 "samples").
        if tuple(x.shape[-3:]) != self._feature_shape:
            raise RuntimeError(
                f"Net_MNIST expects 28x28 inputs (32x4x4 features after pooling), "
                f"got feature map {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
        
class Net_CIFAR10(nn.Module):
    """Small CNN for 3-channel 32x32 images, producing 10 logits.

    Two conv(5x5) -> ReLU -> 2x2 max-pool stages turn a 3x32x32 image into a
    32x5x5 feature map (32 -> 28 -> 14 -> 10 -> 5), followed by three fully
    connected layers (800 -> 120 -> 84 -> 10).

    Args:
        model_params: Accepted for a constructor signature shared with the
            other models in this file; not read.
    """

    # Feature-map shape the conv stack produces for a 32x32 input.
    _feature_shape = (32, 5, 5)

    def __init__(self, model_params):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map ``(batch, 3, 32, 32)`` images to ``(batch, 10)`` logits.

        Raises:
            RuntimeError: if the input's spatial size does not yield the
                32x5x5 feature map that ``fc1`` expects.
        """
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        # Check the feature shape before the reshape. Without the check,
        # view(-1, 800) would silently regroup a mis-sized map into the wrong
        # number of rows (e.g. a 52x52 image would give 4 "samples").
        if tuple(x.shape[-3:]) != self._feature_shape:
            raise RuntimeError(
                f"Net_CIFAR10 expects 32x32 inputs (32x5x5 features after pooling), "
                f"got feature map {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
    
class LogisticRegression(nn.Module):
    """Multinomial logistic regression as a single linear layer.

    Args:
        model_params: Mapping with integer entries ``"input_dim"`` and
            ``"output_dim"``.

    ``forward`` returns raw logits; no softmax or sigmoid is applied.
    """

    def __init__(self, model_params):
        super().__init__()
        self.linear = nn.Linear(model_params["input_dim"], model_params["output_dim"])

    def forward(self, x):
        return self.linear(x)

    
# Currently only allows for scrambling of labels as the corruption method

def _load_train_dataset(data):
    """Return ``(train_dataset, cnn_cls)`` for the dataset named ``data``.

    ``cnn_cls`` is the convolutional model defined in this file for that
    dataset; FedAvg substitutes it for the ``"Net"`` placeholder.

    Raises:
        ValueError: if ``data`` is neither ``"MNIST"`` nor ``"CIFAR10"``.
    """
    if data == "MNIST":
        # Load MNIST dataset - normalized with its mean and std.
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        return datasets.MNIST('../data', train=True, download=True, transform=transform), Net_MNIST
    if data == "CIFAR10":
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        return datasets.CIFAR10('../data', train=True, download=True, transform=transform), Net_CIFAR10
    raise ValueError(f"unsupported dataset {data!r}; expected 'MNIST' or 'CIFAR10'")


def _split_among_clients(dataset, num_clients):
    """Randomly partition ``dataset`` into ``num_clients`` disjoint shards.

    Shard sizes differ by at most one, so the split also works when
    ``len(dataset)`` is not divisible by ``num_clients``. A plain
    ``[len // k] * k`` list of lengths does not sum to ``len(dataset)`` in
    that case, and ``random_split`` rejects it.

    Raises:
        ValueError: if ``num_clients < 1``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    base, extra = divmod(len(dataset), num_clients)
    lengths = [base + 1] * extra + [base] * (num_clients - extra)
    return torch.utils.data.random_split(dataset, lengths)


def _forward(net, inputs, flatten):
    """Run ``net`` on ``inputs``, flattening each sample first when ``flatten`` is set."""
    if flatten:
        inputs = inputs.view(inputs.shape[0], -1)
    return net(inputs)


def _train_client(model_cls, model_params, global_state, client_dataset,
                  local_epochs, batch_size, lr, flatten, corrupt, cp):
    """Train a fresh copy of the global model on one client's shard and return it."""
    local_model = model_cls(model_params)
    local_model.load_state_dict(global_state)
    optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    # shuffle=True reshuffles on every pass, so one loader serves all local epochs.
    loader = torch.utils.data.DataLoader(client_dataset, batch_size=batch_size, shuffle=True)
    local_model.train()
    for _ in range(local_epochs):
        for inputs, targets in loader:
            if corrupt:
                # NOTE(review): data_corruption is not defined in this part of the
                # file; assumed to return the batch's corrupted label list — confirm.
                targets = torch.as_tensor(data_corruption(3, inputs, targets.tolist(), cp)).type(torch.LongTensor)
            optimizer.zero_grad()
            loss = criterion(_forward(local_model, inputs, flatten), targets)
            loss.backward()
            optimizer.step()
    return local_model


def _federated_average(central_model, local_models, weights):
    """Overwrite every parameter of ``central_model`` with the clients' weighted mean.

    ``weights`` holds one non-negative weight per client (here: shard sizes).
    Biases are averaged along with weights. Skipping them would leave the
    global biases at their initial values, and every round would reset the
    clients' biases to those values.
    """
    total = float(sum(weights))
    client_params = [dict(m.named_parameters()) for m in local_models]
    with torch.no_grad():
        for name, param in central_model.named_parameters():
            param.copy_(sum(w * params[name] for w, params in zip(weights, client_params)) / total)


def _evaluate(net, dataset, batch_size, flatten):
    """Return ``(mean per-sample cross-entropy, accuracy)`` of ``net`` on ``dataset``."""
    criterion = nn.CrossEntropyLoss(reduction="sum")
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    loss_sum, correct = 0.0, 0
    net.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = _forward(net, inputs, flatten)
            loss_sum += criterion(outputs, targets).item()
            correct += (outputs.argmax(1) == targets).sum().item()
    n = len(dataset)
    return loss_sum / n, correct / n


def FedAvg(model, data, model_params, training_params, corrupt=False, cp=0.5):
    """Train a model with federated averaging and return per-round training accuracy.

    Each round, every client starts from the current global model and trains
    on its own shard. The global parameters are then replaced by the
    shard-size-weighted mean of the client parameters, and the global model is
    evaluated on the full training set.

    Args:
        model: Model id: 1 -> NeuralNetwork, 2 -> LogisticRegression,
            3 -> the dataset's CNN (Net_MNIST or Net_CIFAR10).
        data: ``"MNIST"`` or ``"CIFAR10"``; the training split is downloaded
            to ``../data``.
        model_params: Passed unchanged to the model constructor.
        training_params: Dict with ``"num_clients"``, ``"epochs"`` (number of
            rounds), ``"batch_size"``, ``"lr"``, and optionally
            ``"local_epochs"`` (local passes per round; defaults to ``"epochs"``).
        corrupt: If True, each client batch's labels are replaced by the
            output of ``data_corruption`` before the optimizer step.
        cp: Forwarded to ``data_corruption``; unused when ``corrupt`` is False.

    Returns:
        dict mapping the 1-based round number to the global model's accuracy
        on the full training set, stored as a 0-dim tensor.

    Raises:
        KeyError: unknown ``model`` id or a missing required training_params key.
        ValueError: unsupported ``data`` or ``num_clients < 1``.
    """
    acc_dict = dict()

    model_dict = {1: NeuralNetwork, 2: LogisticRegression, 3: "Net"}
    model_cls = model_dict[model]

    num_clients = training_params["num_clients"]
    epochs = training_params["epochs"]
    batch_size = training_params["batch_size"]
    lr = training_params["lr"]
    local_epochs = training_params.get("local_epochs", epochs)

    train_dataset, cnn_cls = _load_train_dataset(data)
    if model_cls == "Net":
        model_cls = cnn_cls
    # The fully connected models take flattened samples; the CNNs take images.
    flatten = model_cls not in (Net_MNIST, Net_CIFAR10)

    client_datasets = _split_among_clients(train_dataset, num_clients)
    client_sizes = [len(shard) for shard in client_datasets]

    central_model = model_cls(model_params)

    for epoch in range(epochs):
        global_state = central_model.state_dict()
        local_models = [
            _train_client(model_cls, model_params, global_state, shard,
                          local_epochs, batch_size, lr, flatten, corrupt, cp)
            for shard in client_datasets
        ]
        _federated_average(central_model, local_models, client_sizes)

        # Evaluation only. The server takes no gradient step: in FedAvg the
        # global model changes solely through client averaging, and no
        # autograd graph is held across the whole training set.
        # NOTE(review): for CIFAR10 this evaluates through the augmenting
        # training transform — confirm that is intended.
        central_loss, central_accuracy = _evaluate(central_model, train_dataset, batch_size, flatten)

        acc_dict[epoch+1] = torch.tensor(central_accuracy)
        print(f'Epoch {epoch+1} - Global Training Loss: {central_loss:.4f}, Global Training Accuracy: {central_accuracy:.4f}')
    return acc_dict



# Experiment: FedAvg with LogisticRegression (model id 2) on MNIST; the
# per-round accuracy dict returned by FedAvg is pickled to disk.
training_params = dict(epochs=10, lr=0.01, batch_size=32, num_clients=10)
# input_dim must equal the flattened sample size fed to the linear layer (784 for MNIST).
model_params = dict(input_dim=784, output_dim=10)
results = FedAvg(
    model=2,
    data="MNIST",
    model_params=model_params,
    training_params=training_params,
    corrupt=False,
    cp=0.9,  # only consulted when corrupt=True
)


with open('Logistic_regression_MNIST_C4', 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

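# Error output captured in the notebook:
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x784 and 728x10)
# NOTE(review): the 728x10 operand implies nn.Linear(728, 10), i.e. input_dim=728,
# apparently from a run with a different config; model_params above sets
# input_dim=784, which matches the flattened 32x784 MNIST batch.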