In [1]:
# Code cell intended for imports and global settings.

# Imports
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from SupportClasses import ModelSupport as ms, EnvironmentSetup as env
from torch.utils.data import ConcatDataset

# Global values
# These are the values you can change.
NUMBER_OF_NODES = 3     # Only values between 2 and 4 are supported at present.
NUMBER_OF_EPOCHS = 5
# Table indexed by the number of nodes: entry i is an i-tuple, entries 0 and 1 are unsupported placeholders.
# NOTE(review): the numbers look like sample counts used to unbalance each node -- confirm in EnvironmentSetup.
# Keep these values under 3000.
DATA_DISTRIBUTION = [None, None, (25, 100), (25, 25, 100), (25, 25, 25, 100)]
NUMBER_OF_CLASSES = NUMBER_OF_NODES     # Generally we will have the same number of classes as nodes.
DATASET = env.Dataset.CIFAR_10     # You can either pick: MNIST, fashion_mnist, or cifar10.
MU, SIGMA = 0, 0        # These values are meant to be used for adding noise to the federated average model.
# End of values you can change.

# Fail fast on an unsupported node count instead of handing a None (or out-of-range) distribution
# to the data setup further down.
if (not isinstance(NUMBER_OF_NODES, int)
        or not 2 <= NUMBER_OF_NODES < len(DATA_DISTRIBUTION)
        or DATA_DISTRIBUTION[NUMBER_OF_NODES] is None):
    raise ValueError(
        f"NUMBER_OF_NODES must be an integer between 2 and {len(DATA_DISTRIBUTION) - 1}, got {NUMBER_OF_NODES!r}."
    )

BEST_FEDERATED_MODEL_PATH = "best-model.pt"
BASELINE_MODEL_PATH = 'baseline-model.pt'

# Model specific values
criterion = nn.CrossEntropyLoss()


def remove_stale_checkpoint(path):
    """Delete the checkpoint file at ``path`` if it exists.

    Args:
        path: File system path of the checkpoint to delete.

    Returns:
        True when a file was deleted, False when there was nothing to delete or the
        deletion failed. A failure other than "file not found" is reported with a
        printed warning instead of being raised, so the cleanup stays best-effort.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        # Nothing left over from a previous run; this is the normal first-run case.
        return False
    except OSError as error:
        print(f"Warning: could not remove stale checkpoint {path!r}: {error}")
        return False
    return True


# Remove baseline and best model when we restart.
# Each file is handled on its own: a missing best-model checkpoint must not prevent
# the baseline checkpoint from being removed (and vice versa).
for stale_checkpoint_path in (BEST_FEDERATED_MODEL_PATH, BASELINE_MODEL_PATH):
    remove_stale_checkpoint(stale_checkpoint_path)

In [2]:
# Build one deliberately unbalanced training set per node.

# The original dataset is implicitly assumed to be well balanced.
# Look up the data distributions that belong to the configured number of nodes.
data_distribution_list = env.data_distribution(NUMBER_OF_NODES, DATA_DISTRIBUTION)
# NOTE(review): the CIFAR-10 download is called directly here; the DATASET setting is not consulted.
train_set, validation_set, test_set, classes = env.download_cifar10(NUMBER_OF_CLASSES)

# Apply every distribution to the shared training set, preserving the order of data_distribution_list.
unbalanced_training_sets = [
    env.unbalance_training_set(train_set=train_set, classes=classes, data_distribution=node_distribution)
    for node_distribution in data_distribution_list
]

print("Done importing data.")

Files already downloaded and verified
Files already downloaded and verified
The classes in the training and testing set are ['airplane', 'automobile']
Done importing data.


In [3]:
# This code cell is likely where you will want to do the GAN work on the given datasets.

# This is roughly how the data will look. The numbers will be replaced with actual datasets of the size
# specified in DATA_DISTRIBUTION based on the number of nodes you are currently working with.
# unbalanced_training_sets = [[class1=15, class2=15, class3=600], [class1=600, class2=15, class3=15], [class1=15, class2=600, class3=15]]

# When leaving this code cell make sure to concatenate the datasets into the unbalanced_training_sets variable.
# You can likely leave the following statement as is, though you may have to adapt it to the changes you made.
#
# Entries that are already a ConcatDataset are kept untouched, so running this cell a second time does not
# wrap an already-concatenated node dataset again.
unbalanced_training_sets = [
    node_datasets if isinstance(node_datasets, ConcatDataset) else ConcatDataset(node_datasets)
    for node_datasets in unbalanced_training_sets
]

In [4]:
# This code cell sets up the data loaders and creates the models.

def build_model():
    """Return a fresh network for the configured DATASET.

    ms.ConvNetCifar is used when DATASET equals env.Dataset.CIFAR_10 and ms.ConvNetMnist otherwise;
    either one is constructed with NUMBER_OF_CLASSES.
    """
    if DATASET == env.Dataset.CIFAR_10:
        return ms.ConvNetCifar(NUMBER_OF_CLASSES)
    return ms.ConvNetMnist(NUMBER_OF_CLASSES)


# The baseline (non-federated) model will be trained on the mix of the unbalanced training sets.
global_train_loader = env.create_single_loader(ConcatDataset(unbalanced_training_sets))
training_loaders, validation_loader, test_loader = env.create_data_loaders(training_sets=unbalanced_training_sets,
                                                                       validation_set=validation_set, test_set=test_set)
# Create one federated participant per training loader. Each participant gets its own freshly built network
# and shares the same validation loader.
fed_models = {f"Federated_Model_{i+1}": ms.FederatedModel(train_loader, validation_loader, build_model())
              for i, train_loader in enumerate(training_loaders)}

# Create the baseline, non-federated model.
baseline_model = build_model()
# Create the federated model.
federated_model = build_model()

# Send the models to the CUDA device if it exists.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
baseline_model.to(device=device)
federated_model.to(device=device)

ConvNetCifar(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=2, bias=True)
)

In [5]:
# Baseline: a single, non-federated model trained on the combined data of all nodes.
optimizer = optim.Adam(baseline_model.parameters())

# Training only runs when no baseline checkpoint is present on disk.
if not os.path.exists(BASELINE_MODEL_PATH):
    for epoch in range(NUMBER_OF_EPOCHS):
        # Time one full pass: training on the global loader followed by validation.
        start_time = time.time()
        train_loss, train_acc = ms.train(baseline_model, global_train_loader, optimizer, criterion, device=device)
        valid_loss, valid_acc = ms.test(baseline_model, validation_loader, criterion, device=device)
        end_time = time.time()
        epoch_mins, epoch_secs = ms.epoch_time(start_time, end_time)

        # Emit the three report lines for this epoch in a single write.
        report_lines = [
            f'Epoch: {epoch+1:02} | Model name: Baseline Model | Epoch time (Baseline Training): {epoch_mins}m {epoch_secs}s',
            f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%',
            f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%',
        ]
        print("\n".join(report_lines))
    # Persist the weights reached after the last epoch.
    torch.save(baseline_model.state_dict(), BASELINE_MODEL_PATH)
print("Baseline model training complete.\n\n")

Epoch: 01 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.02711 | Train Acc: 51.60%
	 Val. Loss: 0.00638 |  Val. Acc: 73.80%
Epoch: 02 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.02233 | Train Acc: 76.00%
	 Val. Loss: 0.00606 |  Val. Acc: 72.33%
Epoch: 03 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.01998 | Train Acc: 76.80%
	 Val. Loss: 0.00530 |  Val. Acc: 74.87%
Epoch: 04 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.01833 | Train Acc: 79.60%
	 Val. Loss: 0.00495 |  Val. Acc: 77.27%
Epoch: 05 | Model name: Baseline Model | Epoch time (Baseline Training): 0m 4s
	Train Loss: 0.01716 | Train Acc: 80.00%
	 Val. Loss: 0.00459 |  Val. Acc: 79.07%
Baseline model training complete.




In [6]:
# Federated training: every round each node trains locally starting from the shared weights,
# then the node models are averaged back into the shared model.
best_valid_loss = float('inf')

for epoch in range(NUMBER_OF_EPOCHS):
    # Local computation on the individual node models (this is the part that gets timed).
    start_time = time.time()
    for node_name, node in fed_models.items():
        local_model = node.model
        # Start the round from the current shared weights.
        local_model.load_state_dict(federated_model.state_dict())
        local_model.to(device=device)

        # A new optimizer is created for every node in every round.
        optimizer = optim.Adam(local_model.parameters())
        train_loss, train_acc = ms.train(local_model, node.train_loader, optimizer, criterion, device=device)
        valid_loss, valid_acc = ms.test(local_model, node.validation_loader, criterion, device=device)
        node_report = [
            f'Epoch: {epoch+1:02} | Model name: {node_name}',
            f'\tTrain Loss: {train_loss:.5f} | Train Acc: {train_acc*100:.2f}%',
            f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%',
        ]
        print("\n".join(node_report))
    end_time = time.time()
    # The reported time covers the local training above; averaging below is not included.
    epoch_mins, epoch_secs = ms.epoch_time(start_time, end_time)

    # Combine the node models into the shared model.
    averaged_state = ms.federated_averaging(fed_models, MU, SIGMA)
    federated_model.load_state_dict(averaged_state)
    # Score the averaged model on the shared validation loader.
    valid_loss, valid_acc = ms.test(federated_model, validation_loader, criterion, device=device)

    if valid_loss < best_valid_loss:
        # Keep a checkpoint of the best round in case later rounds get worse.
        best_valid_loss = valid_loss
        torch.save(federated_model.state_dict(), BEST_FEDERATED_MODEL_PATH)

    average_report = [
        f'Epoch: {epoch+1:02} | Model name: Federated Average | Epoch time (Federated Training): {epoch_mins}m {epoch_secs}s',
        f'\t Val. Loss: {valid_loss:.5f} |  Val. Acc: {valid_acc*100:.2f}%',
    ]
    print("\n".join(average_report))

print("Federated Model training complete.\n\n")

Epoch: 01 | Model name: Federated_Model_1
	Train Loss: 0.02625 | Train Acc: 73.60%
	 Val. Loss: 0.00753 |  Val. Acc: 49.20%
Epoch: 01 | Model name: Federated_Model_2
	Train Loss: 0.02432 | Train Acc: 76.00%
	 Val. Loss: 0.00733 |  Val. Acc: 50.80%
Epoch: 01 | Model name: Federated Average | Epoch time (Federated Training): 0m 8s
	 Val. Loss: 0.00688 |  Val. Acc: 70.67%
Epoch: 02 | Model name: Federated_Model_1
	Train Loss: 0.02525 | Train Acc: 77.60%
	 Val. Loss: 0.00854 |  Val. Acc: 49.20%
Epoch: 02 | Model name: Federated_Model_2
	Train Loss: 0.02205 | Train Acc: 82.40%
	 Val. Loss: 0.00886 |  Val. Acc: 50.80%
Epoch: 02 | Model name: Federated Average | Epoch time (Federated Training): 0m 8s
	 Val. Loss: 0.00674 |  Val. Acc: 66.40%
Epoch: 03 | Model name: Federated_Model_1
	Train Loss: 0.02498 | Train Acc: 70.40%
	 Val. Loss: 0.00889 |  Val. Acc: 49.20%
Epoch: 03 | Model name: Federated_Model_2
	Train Loss: 0.01896 | Train Acc: 80.80%
	 Val. Loss: 0.01102 |  Val. Acc: 50.80%
Epoch: 0

In [7]:
# The main testing loop
# Load the saved weights. map_location=device remaps the stored tensors onto the device selected for this
# session, so a checkpoint written on a CUDA device can still be loaded when only the CPU is available.
baseline_model.load_state_dict(torch.load(BASELINE_MODEL_PATH, map_location=device))
federated_model.load_state_dict(torch.load(BEST_FEDERATED_MODEL_PATH, map_location=device))

# Evaluate both models on the same test loader with the same criterion.
baseline_test_loss, baseline_test_acc = ms.test(baseline_model, test_loader, criterion, device=device)
fed_avg_test_loss, fed_avg_test_acc = ms.test(federated_model, test_loader, criterion, device=device)

print(f'Model name: Baseline | Test Loss: {baseline_test_loss:.3f} | Test Acc: {baseline_test_acc*100:.2f}%')
print(f'Model name: Federated Average | Test Loss: {fed_avg_test_loss:.3f} | Test Acc: {fed_avg_test_acc*100:.2f}%')

Model name: Baseline | Test Loss: 0.004 | Test Acc: 80.95%
Model name: Federated Average | Test Loss: 0.006 | Test Acc: 74.20%
