### FedAvg algorithm, currently working for both Logistic Regression and 1 hidden layer NN.
##### See bottom for specific calls
https://towardsdatascience.com/logistic-regression-with-pytorch-3c8bbea594be


In [61]:
%run -i DataCorruption.ipynb

In [94]:
# Define neural network model
class NeuralNetwork(nn.Module):
    def __init__(self, id, **kwargs):
        super(NeuralNetwork, self).__init__()
        self.id = id
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out = self.linear1(x)
        out = self.relu(out)
        out = self.linear2(out)
        return out

In [114]:
# Define logistic regression model
class LogisticRegression(nn.Module):
    def __init__(self, id, **kwargs):
        super(LogisticRegression, self).__init__()
        self.id = id
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        out = self.linear(x)
        return out


In [118]:

# Currently only allows for scrambling of labels as the corruption method

def FedAvg(model, data, model_params, training_params, corrupted=False, cp=0.5):
    model_dict = {1: NeuralNetwork, 2: LogisticRegression}
    
    for (k,v) in training_params.items():
          exec(f'{k} = {v}')
    
    
    if data == "MNIST":
        # Load MNIST dataset - Normalized (MEAN STD)
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    
    # Split data among clients
    central_datasets = torch.utils.data.random_split(train_dataset, [len(train_dataset) // num_clients] * num_clients)
    
    # Initialize global model
    central_model = model_dict[model](model_params)

    # Train global model using federated averaging
    central_optimizer = optim.SGD(central_model.parameters(), lr=lr)
    central_criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        central_model.train()

        # Train local models on each client
        local_models = []
        for client_dataset in client_datasets:
            local_model = model_dict[model](model_params)
            local_model.load_state_dict(central_model.state_dict())
            local_optimizer = optim.SGD(local_model.parameters(), lr=lr)
            local_criterion = nn.CrossEntropyLoss()

            for local_epoch in range(epochs):
                local_model.train()

                for local_data, local_target in torch.utils.data.DataLoader(client_dataset, batch_size=batch_size, shuffle=True):
                    if corrupt:
                        local_target = torch.as_tensor(data_corruption(3, local_data, local_target.tolist(), cp)).type(torch.LongTensor)
                    local_optimizer.zero_grad()
                    local_output = local_model(local_data.view(local_data.shape[0], -1))
                    local_loss = local_criterion(local_output, local_target)
                    local_loss.backward()
                    local_optimizer.step()

            local_models.append(local_model)

        # Update global model using federated averaging
        for name, param in central_model.named_parameters():
            if name.endswith('.bias'):
                continue

            local_params = torch.stack([local_model.state_dict()[name] for local_model in local_models])
            central_mean = local_params.mean(0)
            param.data = central_mean.data

        central_loss = 0
        central_accuracy = 0
        central_optimizer.zero_grad()

        for central_data, central_target in torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True):
            central_output = central_model(central_data.view(central_data.shape[0], -1))
            central_loss += central_criterion(central_output, central_target)
            central_accuracy += (central_output.argmax(1) == central_target).float().sum()

        central_loss /= len(train_dataset)
        central_accuracy /= len(train_dataset)

        central_loss.backward()
        central_optimizer.step()

        print(f'Epoch {epoch+1} - Global Loss: {central_loss:.4f}, Global Accuracy: {central_accuracy:.4f}')


### NN

In [119]:
training_params = {"epochs": 2, "lr": 0.01, "batch_size": 32, "num_clients": 4}
model_params = {"input_dim": 728, "hidden_dim": 128, "output_dim": 10}
FedAvg(1, "MNIST", model_params, training_params, corrupted=True, cp=0.9)

Epoch 1 - Global Loss: 0.0697, Global Accuracy: 0.3515
Epoch 2 - Global Loss: 0.0695, Global Accuracy: 0.4629


### Logistic Regression

In [120]:
training_params = {"epochs": 2, "lr": 0.01, "batch_size": 32, "num_clients": 4}
model_params = {"input_dim": 728, "output_dim": 10}
FedAvg(2, "MNIST", model_params, training_params, corrupted=True, cp=0.9)

Epoch 1 - Global Loss: 0.0694, Global Accuracy: 0.2085
Epoch 2 - Global Loss: 0.0683, Global Accuracy: 0.2913
