In [3]:
# Code cell intended for imports and global settings.

# Imports
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from SupportClasses import ModelSupport as ms, EnvironmentSetup as env

# Global values
NUMBER_OF_NODES = 2     # Only values between 2 and 4 are supported at present.
NUMBER_OF_EPOCHS = 5
SEED = 42               # Fixed seed so a Restart & Run All gives repeatable runs.

# Fail fast on an unsupported node count instead of failing somewhere
# downstream with a less obvious error.
if not 2 <= NUMBER_OF_NODES <= 4:
    raise ValueError(
        f"NUMBER_OF_NODES must be between 2 and 4, got {NUMBER_OF_NODES}."
    )

# Seed PyTorch's random number generator before any model is created.
# NOTE(review): whether the env/ms helpers draw all of their randomness from
# the PyTorch generator is not visible from this notebook - confirm.
torch.manual_seed(SEED)

# Model specific values
criterion = nn.CrossEntropyLoss()

# Data distributions based on the number of nodes.
data_distribution = env.data_distribution(NUMBER_OF_NODES)

(0.25, 0.75)
(0.75, 0.25)


In [6]:
# Data loading and model construction.
train_loader, validation_loader, test_loader = env.download_mnist()

# Every network gets the same output size: the number of classes reported by
# the dataset underneath the training loader.
num_classes = len(train_loader.dataset.dataset.classes)

# One local model per federated node, keyed by its display name.
models = {
    f"Federated_Model_{node}": ms.ConvNet(num_classes)
    for node in range(NUMBER_OF_NODES)
}

# The non-federated baseline, followed by the model that receives the
# federated weights.
baseline_model = ms.ConvNet(num_classes)
federated_model = ms.ConvNet(num_classes)

# Prefer the first CUDA device when one is available, otherwise use the CPU,
# and move every model there.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
for model in [*models.values(), federated_model, baseline_model]:
    model.to(device=device)

ConvNet(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1): Linear(in_features=1568, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=10, bias=True)
)
ConvNet(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1): Linear(in_features=1568, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=10, bias=True)
)
ConvNet(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1)

In [None]:
# Here we train a general model on all data. No federation as our baseline.
path = 'baseline-model.pt'
optimizer = optim.Adam(baseline_model.parameters())

# Train only when no checkpoint exists on disk yet. When the file is already
# there, training is skipped and the testing cell loads the saved weights.
if not os.path.exists(path):
    for epoch in range(NUMBER_OF_EPOCHS):
        start_time = time.time()
        train_loss, train_acc = ms.train(baseline_model, train_loader, optimizer, criterion, device=device)
        valid_loss, valid_acc = ms.test(baseline_model, validation_loader, criterion, device=device)
        end_time = time.time()
        # Get the time to perform non-federated learning
        epoch_mins, epoch_secs = ms.epoch_time(start_time, end_time)

        print(f'Epoch: {epoch+1:02} | Model name: Baseline Model | Epoch time (Baseline Training): {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%')
    # Save the baseline model explicitly. The global `model` is a loop
    # variable left over from whichever other cell ran last, so it may not be
    # the baseline.
    torch.save(baseline_model.state_dict(), path)

In [3]:
# Here we train our federated model.
# Each epoch: train every node model, average them into `federated_model`,
# validate the average, and checkpoint it whenever validation loss improves.
# NOTE(review): every node trains on the same `train_loader`; the
# `data_distribution` computed in the settings cell is not used here - confirm
# whether per-node loaders were intended.
best_valid_loss = float('inf')

for epoch in range(NUMBER_OF_EPOCHS):
    # Perform the computation steps on the individual models
    start_time = time.time()
    node_train_losses, node_train_accs = [], []
    for key, model in models.items():
        # A fresh Adam optimizer is created for each node model every epoch.
        optimizer = optim.Adam(model.parameters())
        train_loss, train_acc = ms.train(model, train_loader, optimizer, criterion, device=device)
        valid_loss, valid_acc = ms.test(model, validation_loader, criterion, device=device)
        node_train_losses.append(train_loss)
        node_train_accs.append(train_acc)
        print(f'Epoch: {epoch+1:02} | Model name: {key}')
        print(f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%')
    end_time = time.time()
    # Get the time to perform federated learning
    epoch_mins, epoch_secs = ms.epoch_time(start_time, end_time)

    # Average the federated models and combine their weights into the main model.
    # NOTE(review): the device of the returned model is not visible from this
    # notebook - confirm it matches `device`.
    federated_model = ms.federated_averaging(models)
    # Validate this model on a small, balanced validation set
    valid_loss, valid_acc = ms.test(federated_model, validation_loader, criterion, device=device)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # Checkpoint the averaged model in case later epochs get worse. Saving
        # `model` here would store only the last node's local weights.
        torch.save(federated_model.state_dict(), 'best-model.pt')

    # The averaged model is never trained directly, so report the mean of the
    # node training metrics rather than those of the last node in the loop.
    train_loss = sum(node_train_losses) / len(node_train_losses)
    train_acc = sum(node_train_accs) / len(node_train_accs)

    print(f'Epoch: {epoch+1:02} | Model name: Federated Average | Epoch time (Federated Training): {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01
	Train Loss: 0.00504 | Train Acc: 96.06%
	 Val. Loss: 0.00060 |  Val. Acc: 98.16%


KeyboardInterrupt: 

In [4]:
# Final evaluation of both models on the test set.
# Load the saved weights. `map_location=device` remaps the stored tensors to
# the current device, so a checkpoint written on a CUDA machine still loads
# on a CPU-only one.
baseline_model.load_state_dict(torch.load(path, map_location=device))
federated_model.load_state_dict(torch.load('best-model.pt', map_location=device))

baseline_test_loss, baseline_test_acc = ms.test(baseline_model, test_loader, criterion, device=device)
fed_avg_test_loss, fed_avg_test_acc = ms.test(federated_model, test_loader, criterion, device=device)

print(f'Model name: Baseline | Test Loss: {baseline_test_loss:.3f} | Test Acc: {baseline_test_acc*100:.2f}%')
print(f'Model name: Federated Average | Test Loss: {fed_avg_test_loss:.3f} | Test Acc: {fed_avg_test_acc*100:.2f}%')

Test Loss: 0.000 | Test Acc: 99.01%
