In [6]:
# Code cell intended for imports and global settings.

# Imports
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from SupportClasses import ModelSupport as ms, EnvironmentSetup as env
from torch.utils.data import ConcatDataset

# Global values
NUMBER_OF_NODES = 3     # We only support values between 2 and 4 at present.
NUMBER_OF_EPOCHS = 5
NUMBER_OF_CLASSES = NUMBER_OF_NODES     # Generally we will have the same number of classes as nodes.
PATH = 'baseline-model.pt'     # Checkpoint file of the baseline (non-federated) model.
DATASET = 'cifar10'     # You can either pick: mnist, fashion_mnist, or cifar10.
SUBSET_MAX_SIZE = 100   # This value controls the maximum size of a training subset. This can be modified.

# Model specific values
criterion = nn.CrossEntropyLoss()

# Remove the baseline and best-model checkpoints when we restart.
# On a first run (or after a manual clean-up) these files do not exist yet, and an
# unconditional os.remove would raise FileNotFoundError and abort this cell, so each
# file is only deleted when it is actually present.
for stale_checkpoint in (PATH, "best-model.pt"):
    if os.path.exists(stale_checkpoint):
        os.remove(stale_checkpoint)

In [2]:
# Build one deliberately unbalanced training subset per data distribution.

# The original dataset is implicitly assumed to be well balanced.
# NOTE(review): the DATASET setting is not consulted here; the CIFAR-10 download
# helper is always called -- confirm whether that is intended.
# The data distributions are derived from the number of nodes.
data_distribution_list = env.data_distribution(NUMBER_OF_NODES)
train_set, validation_set, test_set, classes = env.download_cifar10(NUMBER_OF_CLASSES)

# Every distribution yields its own training subset; SUBSET_MAX_SIZE is passed
# through as the size limit for each of them.
unbalanced_training_sets = [
    env.unbalance_training_set(
        train_set=train_set,
        classes=classes,
        data_distribution=node_distribution,
        subset_max_size=SUBSET_MAX_SIZE,
    )
    for node_distribution in data_distribution_list
]

print("Done importing data.")

Files already downloaded and verified
Files already downloaded and verified
The classes in the training and testing set are ['airplane', 'automobile', 'bird']
Done importing data.


In [None]:
# This code cell is likely where you will want to do the GAN work on the given datasets.

In [3]:
# Build the data loaders and instantiate every model used in this notebook.

# A single loader over the concatenation of all unbalanced training sets.
global_train_loader = env.create_single_loader(ConcatDataset(unbalanced_training_sets))
# Loaders for the individual training sets, plus the validation and test loaders.
training_loaders, validation_loader, test_loader = env.create_data_loaders(
    training_sets=unbalanced_training_sets,
    validation_set=validation_set,
    test_set=test_set,
)

# Network class for the configured dataset: the CIFAR network for 'cifar10',
# the MNIST network for anything else.
_network_cls = ms.ConvNetCifar if DATASET == 'cifar10' else ms.ConvNetMnist

# One federated wrapper per training loader, each holding its own freshly
# constructed network together with the shared validation loader.
fed_models = {
    f"Federated_Model_{i+1}": ms.FederatedModel(
        train_loader,
        validation_loader,
        _network_cls(NUMBER_OF_CLASSES),
    )
    for i, train_loader in enumerate(training_loaders)
}

# The baseline (non-federated) model first, then the federated model.
baseline_model = _network_cls(NUMBER_OF_CLASSES)
federated_model = _network_cls(NUMBER_OF_CLASSES)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
baseline_model.to(device=device)
# Last expression of the cell, so its value is what the cell displays.
federated_model.to(device=device)

ConvNetCifar(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=3, bias=True)
)

In [7]:
# Baseline: one non-federated model trained through global_train_loader.
optimizer = optim.Adam(baseline_model.parameters())

# Training only runs when no checkpoint file exists on disk at PATH yet.
if not os.path.exists(PATH):
    for epoch in range(NUMBER_OF_EPOCHS):
        # The measured time covers one training pass plus one validation pass.
        start_time = time.time()
        train_loss, train_acc = ms.train(
            baseline_model, global_train_loader, optimizer, criterion, device=device)
        valid_loss, valid_acc = ms.test(
            baseline_model, validation_loader, criterion, device=device)
        end_time = time.time()
        epoch_mins, epoch_secs = ms.epoch_time(start_time, end_time)

        # Per-epoch report: header line, then training and validation metrics.
        print(
            f'Epoch: {epoch+1:02} | Model name: Baseline Model | Epoch time (Baseline Training): {epoch_mins}m {epoch_secs}s',
            f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%',
            f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%',
            sep='\n',
        )
    # Save the weights once, after the final epoch.
    torch.save(baseline_model.state_dict(), PATH)
# This message is printed whether or not the training loop above ran.
print("Baseline model training complete.\n\n")

Epoch: 01 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.03163 | Train Acc: 66.00%
	 Val. Loss: 0.00866 |  Val. Acc: 63.60%
Epoch: 02 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.02880 | Train Acc: 71.67%
	 Val. Loss: 0.00853 |  Val. Acc: 64.98%
Epoch: 03 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.02726 | Train Acc: 73.67%
	 Val. Loss: 0.00856 |  Val. Acc: 63.91%
Epoch: 04 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.02541 | Train Acc: 73.33%
	 Val. Loss: 0.00853 |  Val. Acc: 63.69%
Epoch: 05 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.02368 | Train Acc: 75.67%
	 Val. Loss: 0.00821 |  Val. Acc: 66.49%
Baseline model training complete.




In [8]:
# Federated training: each round trains every node model from the current
# global weights, then merges the node models back into the global model.
best_valid_loss = float('inf')

for epoch in range(NUMBER_OF_EPOCHS):
    # The timed section covers the per-node work only (training + validation).
    start_time = time.time()
    for key, fed_model in fed_models.items():
        # Start this node from the current global weights.
        fed_model.model.load_state_dict(federated_model.state_dict())
        fed_model.model.to(device=device)

        # A new Adam optimizer is created for every node in every round.
        optimizer = optim.Adam(fed_model.model.parameters())
        train_loss, train_acc = ms.train(
            fed_model.model, fed_model.train_loader, optimizer, criterion, device=device)
        valid_loss, valid_acc = ms.test(
            fed_model.model, fed_model.validation_loader, criterion, device=device)
        print(
            f'Epoch: {epoch+1:02} | Model name: {key}',
            f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%',
            f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%',
            sep='\n',
        )
    end_time = time.time()
    epoch_mins, epoch_secs = ms.epoch_time(start_time, end_time)

    # Combine the node models and load the result into the global model.
    federated_model.load_state_dict(ms.federated_averaging(fed_models))
    # Validate the combined model on the shared validation loader.
    valid_loss, valid_acc = ms.test(federated_model, validation_loader, criterion, device=device)

    # Keep a checkpoint of the lowest-validation-loss global model seen so far,
    # in case later rounds get worse.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(federated_model.state_dict(), 'best-model.pt')

    print(
        f'Epoch: {epoch+1:02} | Model name: Federated Average | Epoch time (Federated Training): {epoch_mins}m {epoch_secs}s',
        f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%',
        sep='\n',
    )

print("Federated Model training complete.\n\n")

Epoch: 01 | Model name: Federated_Model_1
	Train Loss: 0.03636 | Train Acc: 61.00%
	 Val. Loss: 0.01130 |  Val. Acc: 40.67%
Epoch: 01 | Model name: Federated_Model_2
	Train Loss: 0.03952 | Train Acc: 56.00%
	 Val. Loss: 0.01128 |  Val. Acc: 38.93%
Epoch: 01 | Model name: Federated_Model_3
	Train Loss: 0.03906 | Train Acc: 52.00%
	 Val. Loss: 0.01044 |  Val. Acc: 42.84%
Epoch: 01 | Model name: Federated Average | Epoch time (Federated Training): 0m 14s
	 Val. Loss: 0.00989 |  Val. Acc: 58.44%
Epoch: 02 | Model name: Federated_Model_1
	Train Loss: 0.03464 | Train Acc: 67.00%
	 Val. Loss: 0.01064 |  Val. Acc: 46.84%
Epoch: 02 | Model name: Federated_Model_2
	Train Loss: 0.03820 | Train Acc: 56.00%
	 Val. Loss: 0.01094 |  Val. Acc: 42.67%
Epoch: 02 | Model name: Federated_Model_3
	Train Loss: 0.03746 | Train Acc: 56.00%
	 Val. Loss: 0.01034 |  Val. Acc: 45.07%
Epoch: 02 | Model name: Federated Average | Epoch time (Federated Training): 0m 13s
	 Val. Loss: 0.00949 |  Val. Acc: 60.13%
Epoch:

In [9]:
# Final evaluation: score the baseline and the federated model on the test loader.

# Restore the saved weights: the baseline checkpoint at PATH and the best
# federated checkpoint. map_location remaps the stored tensors onto the current
# device, so a checkpoint written from CUDA tensors still loads when this
# notebook runs on a CPU-only machine (a plain torch.load would fail there).
baseline_model.load_state_dict(torch.load(PATH, map_location=device))
federated_model.load_state_dict(torch.load('best-model.pt', map_location=device))

# Both models are evaluated with the same loader and loss criterion.
baseline_test_loss, baseline_test_acc = ms.test(baseline_model, test_loader, criterion, device=device)
fed_avg_test_loss, fed_avg_test_acc = ms.test(federated_model, test_loader, criterion, device=device)

print(f'Model name: Baseline | Test Loss: {baseline_test_loss:.3f} | Test Acc: {baseline_test_acc*100:.2f}%')
print(f'Model name: Federated Average | Test Loss: {fed_avg_test_loss:.3f} | Test Acc: {fed_avg_test_acc*100:.2f}%')

Model name: Baseline | Test Loss: 0.008 | Test Acc: 67.03%
Model name: Federated Average | Test Loss: 0.009 | Test Acc: 59.33%
