In [1]:
# Code cell intended for imports and global settings.

# Imports
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from SupportClasses import ModelSupport as ms, EnvironmentSetup as env

# Global values
SEED = 42               # Seed for torch's global RNG so repeated runs start from the same state.
NUMBER_OF_NODES = 3     # We only support values between 2 and 4 at present.
NUMBER_OF_EPOCHS = 5
NUMBER_OF_CLASSES = NUMBER_OF_NODES     # Generally we will have the same number of classes as nodes.
PATH = 'baseline-model.pt'  # Checkpoint file for the non-federated baseline model.

# Fail fast on an unsupported configuration instead of failing somewhere inside the helpers.
if not 2 <= NUMBER_OF_NODES <= 4:
    raise ValueError(f"NUMBER_OF_NODES must be between 2 and 4, got {NUMBER_OF_NODES}.")

# Seed torch before any model or loader is created.
# NOTE(review): this only covers torch's RNG; the SupportClasses helpers may draw from
# `random` / `numpy` as well — TODO confirm and seed those too if so.
torch.manual_seed(SEED)

# Model specific values
criterion = nn.CrossEntropyLoss()  # Shared loss for baseline and federated training/testing.

In [2]:
# Set up the unbalanced training sets used by the individual nodes.

# Note: we implicitly assume the original dataset itself is well balanced.
# The list of data distributions is derived from the number of nodes.
data_distribution_list = env.data_distribution(NUMBER_OF_NODES)
train_set, validation_set, test_set, classes = env.download_mnist(NUMBER_OF_CLASSES)

# Build one unbalanced training set per distribution, keeping the order of the list.
unbalanced_training_sets = [
    env.unbalance_training_set(train_set=train_set, classes=classes, data_distribution=node_distribution)
    for node_distribution in data_distribution_list
]

The classes in the training and testing set are [0, 1, 2]


In [None]:
# This code cell is likely where you will want to do the GAN work on the given datasets.

In [3]:
# This code cell is to be used for importing data and setting up the models.

# Loader consumed by the non-federated baseline training.
# NOTE(review): `train_set.dataset` is passed here rather than `train_set`; if `train_set` is a
# subset carved out of a larger dataset, the baseline may also see validation samples — TODO confirm.
global_train_loader = env.create_single_loader(train_set.dataset)

# Loaders for the unbalanced training sets, the validation set and the test set.
training_loaders, validation_loader, test_loader = env.create_data_loaders(
    training_sets=unbalanced_training_sets,
    validation_set=validation_set,
    test_set=test_set,
)

# One FederatedModel per training loader, each wrapping its own freshly created ConvNet.
# The federated training loop overwrites these weights with the global model's at the start of every round.
fed_models = {
    f"Federated_Model_{node_number}": ms.FederatedModel(node_loader, validation_loader,
                                                        ms.ConvNet(NUMBER_OF_CLASSES))
    for node_number, node_loader in enumerate(training_loaders, start=1)
}

# The baseline (non-federated) network.
baseline_model = ms.ConvNet(NUMBER_OF_CLASSES)
# The global network that receives the federated average.
federated_model = ms.ConvNet(NUMBER_OF_CLASSES)

# Prefer the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
baseline_model.to(device=device)
federated_model.to(device=device)

ConvNet(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1): Linear(in_features=1568, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=3, bias=True)
)

In [6]:
# Here we train a baseline model on all data. No federation as our baseline.
optimizer = optim.Adam(baseline_model.parameters())

# Training is skipped when a checkpoint already exists on disk at PATH (the check is on the
# file, not on anything held in memory).
# NOTE: the checkpoint name does not encode the configuration; delete PATH after changing
# NUMBER_OF_CLASSES, otherwise a stale checkpoint is reused (or fails to load).
if not os.path.exists(PATH):
    for epoch in range(NUMBER_OF_EPOCHS):
        start_time = time.time()
        train_loss, train_acc = ms.train(baseline_model, global_train_loader, optimizer, criterion, device=device)
        valid_loss, valid_acc = ms.test(baseline_model, validation_loader, criterion, device=device)
        end_time = time.time()
        # Get the time to perform non-federated learning
        epoch_mins, epoch_secs = ms.epoch_time(start_time, end_time)

        print(f'Epoch: {epoch+1:02} | Model name: Baseline Model | Epoch time (Baseline Training): {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%')
    # Persist the weights reached after the final epoch.
    torch.save(baseline_model.state_dict(), PATH)
    print("Baseline model training complete.\n\n")
else:
    # Bring the in-memory model in line with the checkpoint so it is never left randomly
    # initialised; map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    baseline_model.load_state_dict(torch.load(PATH, map_location=device))
    print(f"Found existing baseline checkpoint at '{PATH}'; skipping training.\n\n")

18625 5923 6742 5958
Epoch: 01 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 9s
	Train Loss: 0.00158 | Train Acc: 98.81%
	 Val. Loss: 0.00011 |  Val. Acc: 99.57%
18625 5923 6742 5958
Epoch: 02 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 9s
	Train Loss: 0.00042 | Train Acc: 99.61%
	 Val. Loss: 0.00002 |  Val. Acc: 99.89%
18625 5923 6742 5958
Epoch: 03 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 9s
	Train Loss: 0.00026 | Train Acc: 99.83%
	 Val. Loss: 0.00002 |  Val. Acc: 99.89%
18625 5923 6742 5958
Epoch: 04 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 9s
	Train Loss: 0.00026 | Train Acc: 99.80%
	 Val. Loss: 0.00001 |  Val. Acc: 100.00%
18625 5923 6742 5958
Epoch: 05 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 11s
	Train Loss: 0.00012 | Train Acc: 99.90%
	 Val. Loss: 0.00000 |  Val. Acc: 100.00%
Baseline model training complete.




In [4]:
# Federated training: every round, each node model starts from the current global weights,
# trains on its own loader, and the node models are then averaged back into the global model.
best_valid_loss = float('inf')

for epoch in range(NUMBER_OF_EPOCHS):
    # Time only the per-node training/validation; averaging and the global validation are excluded.
    start_time = time.time()
    for model_name, node in fed_models.items():
        # Reset this node's network to the global weights before it trains again.
        node.model.load_state_dict(federated_model.state_dict())
        node.model.to(device=device)

        # A new Adam optimizer is built for every node in every round.
        optimizer = optim.Adam(node.model.parameters())
        train_loss, train_acc = ms.train(node.model, node.train_loader, optimizer, criterion, device=device)
        valid_loss, valid_acc = ms.test(node.model, node.validation_loader, criterion, device=device)
        print(f'Epoch: {epoch+1:02} | Model name: {model_name}')
        print(f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%')
    # Elapsed time of the node-level work for this round.
    epoch_mins, epoch_secs = ms.epoch_time(start_time, time.time())

    # Fold the node models into the global model via federated averaging.
    averaged_state = ms.federated_averaging(fed_models)
    federated_model.load_state_dict(averaged_state)
    # Score the averaged model on the shared validation loader.
    valid_loss, valid_acc = ms.test(federated_model, validation_loader, criterion, device=device)

    # Keep a checkpoint of the best averaged model seen so far, in case later rounds degrade.
    improved = valid_loss < best_valid_loss
    if improved:
        best_valid_loss = valid_loss
        torch.save(federated_model.state_dict(), 'best-model.pt')

    print(f'Epoch: {epoch+1:02} | Model name: Federated Average | Epoch time (Federated Training): {epoch_mins}m {epoch_secs}s')
    print(f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%')

print("Federated Model training complete.\n\n")

5400 1264 2846 1271
Epoch: 01 | Model name: Federated_Model_1
	Train Loss: 0.00334 | Train Acc: 96.88%
	 Val. Loss: 0.00027 |  Val. Acc: 99.10%
5250 1264 1423 2542
Epoch: 01 | Model name: Federated_Model_2
	Train Loss: 0.00331 | Train Acc: 97.19%
	 Val. Loss: 0.00024 |  Val. Acc: 99.36%
5225 2528 1423 1271
Epoch: 01 | Model name: Federated_Model_3
	Train Loss: 0.00362 | Train Acc: 96.71%
	 Val. Loss: 0.00035 |  Val. Acc: 98.68%
Epoch: 01 | Model name: Federated Average | Epoch time (Federated Training): 1m 33s
	 Val. Loss: 0.00042 |  Val. Acc: 98.64%
5400 1264 2846 1271
Epoch: 02 | Model name: Federated_Model_1
	Train Loss: 0.00115 | Train Acc: 99.09%
	 Val. Loss: 0.00021 |  Val. Acc: 99.43%
5250 1264 1423 2542
Epoch: 02 | Model name: Federated_Model_2
	Train Loss: 0.00104 | Train Acc: 99.08%
	 Val. Loss: 0.00016 |  Val. Acc: 99.53%
5225 2528 1423 1271
Epoch: 02 | Model name: Federated_Model_3
	Train Loss: 0.00087 | Train Acc: 99.31%
	 Val. Loss: 0.00021 |  Val. Acc: 99.32%
Epoch: 02 |

In [7]:
# The main testing loop
# Load the saved weights: the baseline checkpoint and the best federated-average checkpoint.
# map_location=device remaps the stored tensors, so checkpoints written on a GPU still load
# when this cell runs on a CPU-only machine (and vice versa).
baseline_model.load_state_dict(torch.load(PATH, map_location=device))
federated_model.load_state_dict(torch.load('best-model.pt', map_location=device))

# Evaluate both models on the same held-out test loader.
baseline_test_loss, baseline_test_acc = ms.test(baseline_model, test_loader, criterion, device=device)
fed_avg_test_loss, fed_avg_test_acc = ms.test(federated_model, test_loader, criterion, device=device)

# Losses use five decimals, matching the training cells; three decimals rendered them as 0.000.
print(f'Model name: Baseline | Test Loss: {baseline_test_loss:.5f} | Test Acc: {baseline_test_acc*100:.2f}%')
print(f'Model name: Federated Average | Test Loss: {fed_avg_test_loss:.5f} | Test Acc: {fed_avg_test_acc*100:.2f}%')

Model name: Baseline | Test Loss: 0.000 | Test Acc: 99.90%
Model name: Federated Average | Test Loss: 0.000 | Test Acc: 99.94%
