# Federated Averaging (FedAvg) Training of a Classifier Head on Transfer-Learned CNN Features

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import copy
import torch
import pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from joblib import Parallel, delayed
from keras.datasets import fashion_mnist

## Hyperparameter Definition

In [None]:
# Hyperparameter configuration (H)
epochs = 10                                    # SGD steps each worker takes on every mini-batch
lrs = [0.1, 0.05, 0.01]                        # learning rates swept by the experiments
n_workers_list = [2 ** k for k in range(2, 8)]  # 4, 8, 16, 32, 64, 128 workers
buffer_len = 10                                # NOTE(review): not referenced in this notebook section
mini_batch_size = 10                           # samples per local mini-batch

# Independent repetitions for every (n_workers, lr) pair
n_runs = 100

# Training images handed to each worker
size = 100

## Datasets

In [None]:
# Only the labels come from Fashion-MNIST; the images are replaced by the
# features precomputed with the ConvNet and stored under res/features/.
(_, train_labels), (_, test_labels) = fashion_mnist.load_data()
# nn.CrossEntropyLoss takes class indices as int64 (torch.long); convert once here
train_labels = train_labels.astype(np.int64)
test_labels = test_labels.astype(np.int64)

# Load the precomputed features; cast to float32 so they match the default
# dtype of the nn.Linear layers (no copy if they are already float32)
train_features = np.load("res/features/train_features_fp.npy").astype(np.float32, copy=False)
test_features = np.load("res/features/test_features_fp.npy").astype(np.float32, copy=False)

# Features and labels are paired purely by index, so their lengths must agree
if len(train_features) != len(train_labels) or len(test_features) != len(test_labels):
    raise ValueError(
        f"feature/label count mismatch: train {len(train_features)} vs {len(train_labels)}, "
        f"test {len(test_features)} vs {len(test_labels)}"
    )

## Training Algorithms

In [None]:
def tl_train(model, data, target, optimizer):
    """Run a single optimisation step of ``model`` on one mini-batch.

    Args:
        model: ``nn.Sequential`` whose layers with a ``weight`` attribute also
            have a ``bias``.
        data: input tensor for the mini-batch.
        target: class-index targets for ``nn.CrossEntropyLoss``.
        optimizer: optimizer bound to ``model``'s parameters.

    Returns:
        Flat list ``[w0, b0, w1, b1, ...]`` of independent numpy copies of the
        weight and bias of every layer that has a ``weight``, in layer order.
    """
    model.train()
    optimizer.zero_grad()
    criterion = nn.CrossEntropyLoss()
    criterion(model(data), target).backward()
    optimizer.step()

    snapshot = []
    for layer in filter(lambda module: hasattr(module, "weight"), model):
        # detach().numpy() shares memory with the parameter; copy it so later
        # optimizer steps cannot change the returned arrays
        snapshot += [
            np.copy(layer.weight.detach().numpy()),
            np.copy(layer.bias.detach().numpy()),
        ]
    return snapshot

def tl_test(model, test_features, test_labels):
    """Return the classification accuracy of ``model`` on a test set, in percent.

    The whole test set is scored with one batched forward pass (rather than a
    forward pass per sample); the predicted class is the arg-max of the
    model's output scores.

    Args:
        model: classifier mapping a batch of feature vectors to class scores.
        test_features: array-like of feature vectors, index-aligned with
            ``test_labels``.
        test_labels: array-like of integer class labels.

    Returns:
        ``100 * correct / len(test_labels)`` as a float.

    Raises:
        ValueError: if ``test_labels`` is empty.
    """
    n_samples = len(test_labels)
    if n_samples == 0:
        raise ValueError("tl_test: the test set is empty")

    model.eval()
    with torch.no_grad():
        data = torch.tensor(np.asarray(test_features))
        target = torch.tensor(np.asarray(test_labels))
        pred = model(data).argmax(dim=1)
        correct = pred.eq(target.view_as(pred)).sum().item()

    return 100. * correct / n_samples

In [None]:
def parallel_tl_train(run, lr, features_batches, n_workers):
    """Simulate one FedAvg run of ``n_workers`` workers on precomputed features.

    Every worker starts from the same randomly initialised
    ``Linear(200, 50) -> ReLU -> Linear(50, 10)`` network. Each worker's data is
    consumed in consecutive mini-batches of the module-level
    ``mini_batch_size``. For every mini-batch (one round):

    1. Each worker takes ``epochs`` (module-level) SGD steps on its own slice.
    2. The weights and biases are averaged across workers and copied back into
       every worker.
    3. The averaged model is evaluated with ``tl_test`` on the module-level
       ``test_features`` / ``test_labels``.

    Args:
        run: repetition index; only used to build the result key.
        lr: SGD learning rate used by every worker.
        features_batches: per-worker ``[features, labels]`` pairs, indexed as
            ``features_batches[n][0]`` and ``features_batches[n][1]``.
        n_workers: number of simulated workers.

    Returns:
        ``{f"{lr}-{run}": [test_accs, average_weights]}``.

        - ``test_accs`` lists the test accuracy (percent) after each round.
        - ``average_weights`` holds one ``[weight, bias]`` pair of numpy arrays
          per trainable layer from the last round, or the shared initial
          weights if no round was run.
    """
    # Net structure: every worker gets an identical copy of the same init
    net = nn.Sequential(
        nn.Linear(200, 50),
        nn.ReLU(),
        nn.Linear(50, 10)
    )
    nets = [copy.deepcopy(net) for _ in range(n_workers)]

    # Indices (inside the Sequential) of the layers holding weight and bias
    train_layers = [layer_idx for layer_idx in range(len(net)) if hasattr(net[layer_idx], "weight")]

    # Start from the shared initialisation so the return value is defined even
    # when the data is too short for a single round
    average_weights = [
        [np.copy(net[layer_idx].weight.detach().numpy()),
         np.copy(net[layer_idx].bias.detach().numpy())]
        for layer_idx in train_layers
    ]

    test_accs = []
    n_rounds = len(features_batches[0][0]) // mini_batch_size
    for mini_batch in range(n_rounds):
        idx_start = mini_batch * mini_batch_size
        idx_end = idx_start + mini_batch_size

        # Local training; weights[n] keeps the flat [w0, b0, w1, b1, ...] list
        # returned by tl_train after worker n's last step
        weights = [[] for _ in range(n_workers)]
        for _ in range(epochs):
            for n in range(n_workers):
                data = torch.tensor(features_batches[n][0][idx_start:idx_end])
                target = torch.tensor(features_batches[n][1][idx_start:idx_end])
                # NOTE(review): a fresh optimizer per step discards the momentum
                # buffer, so momentum=0.5 currently has no effect (plain SGD)
                optimizer = optim.SGD(nets[n].parameters(), lr=lr, momentum=0.5)
                weights[n] = tl_train(nets[n], data, target, optimizer)

        # Average each trainable layer's weight and bias over the workers
        average_weights = [
            [np.mean([weights[n][2 * idx] for n in range(n_workers)], axis=0),
             np.mean([weights[n][2 * idx + 1] for n in range(n_workers)], axis=0)]
            for idx in range(len(train_layers))
        ]

        # Copy the averages into every worker's own parameters. Wrapping a
        # shared torch.from_numpy() array in nn.Parameter would make all
        # workers alias one buffer, so their later SGD steps would overwrite
        # each other; copy_() keeps each worker's storage independent.
        with torch.no_grad():
            for n in range(n_workers):
                for idx, layer_idx in enumerate(train_layers):
                    nets[n][layer_idx].weight.copy_(torch.from_numpy(average_weights[idx][0]))
                    nets[n][layer_idx].bias.copy_(torch.from_numpy(average_weights[idx][1]))

        # All workers now hold the same model, so evaluating worker 0 suffices
        test_accs.append(tl_test(nets[0], test_features, test_labels))
    return {f"{lr}-{run}": [test_accs, average_weights]}

## Experiments

In [None]:
# Create the output directory up front so a missing folder cannot make the
# final pickle.dump fail after all the training has already run
os.makedirs("out/fedavg", exist_ok=True)

for n_workers in n_workers_list:

    # Worker n receives the n-th consecutive slice of `size` training samples;
    # refuse configurations that would hand some workers short or empty slices
    if n_workers * size > len(train_features):
        raise ValueError(
            f"{n_workers} workers x {size} samples exceeds the "
            f"{len(train_features)} available training samples"
        )
    features_batches = [
        [train_features[n*size:(n+1)*size], train_labels[n*size:(n+1)*size]]
        for n in range(n_workers)
    ]

    # One joblib job per (lr, run) pair; Parallel returns results in
    # submission order, i.e. grouped by learning rate, then by run
    fedavg_res = Parallel(n_jobs=-1)(
        delayed(parallel_tl_train)(run, lr, features_batches, n_workers)
        for lr in lrs for run in range(n_runs)
    )

    # Regroup as {lr: [accs, W]}: accs holds each run's accuracy curve as
    # fractions (percent / 100), W each run's final averaged weights
    fedavg_nets = {}
    for lr_idx, lr in enumerate(lrs):
        accs, W = [], []
        for run in range(n_runs):
            index, dict_key = n_runs*lr_idx + run, f"{lr}-{run}"
            accs.append(np.divide(fedavg_res[index][dict_key][0], 100))
            W.append(fedavg_res[index][dict_key][1])
        fedavg_nets[lr] = [accs, W]

    with open(f"out/fedavg/fmnist_{n_workers}.pkl", "wb") as f:
        pickle.dump(fedavg_nets, f)