# Description

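In this notebook, I train a federated learning (FL) system with a **ResNet architecture** on the FEMNIST dataset (federated EMNIST with digits and letters, 62 classes, partitioned across clients). In each round, a random subset of clients trains the current global model locally. The server then averages their weights with FedAvg, weighting each client by its number of training samples.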

In [None]:
import os
# '2' hides TensorFlow's C++ INFO and WARNING logs; it only takes effect if set before tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
tf.random.set_seed(42)

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Memory growth must be enabled on every visible GPU (not only the first one), and before
# the GPUs are initialised -- otherwise a RuntimeError is raised (e.g. when this cell is re-run).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

import gc
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from skimage.transform import resize

from sklearn.model_selection import train_test_split

from config_celeb import *
from utils.read_data_utils import *
from utils.model_utils import *
from utils.pruning_utils import *

In [None]:
# =========================================================
# Reproducibility
# Seed Python's and NumPy's global RNGs as well: client selection in the FL loop uses
# random.sample and the sample-preview cell uses np.random.randint.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)


# =========================================================
# Dataset Hyper-parameter
DATASET_NAME = 'mnist'  # key passed to Create_Clients_Data (presumably selects MNIST-style preprocessing); also used in the plot filename

IMAGE_DIMENSION = 28
INPUT_SHAPE = (IMAGE_DIMENSION, IMAGE_DIMENSION, 1)  # (height, width, channels)

# Number of classes: 62 = 10 digits + 26 upper-case + 26 lower-case letters
# (federated EMNIST loaded with only_digits=False).
OUTPUT_SHAPE = 62
OUPUT_SHAPE = OUTPUT_SHAPE  # original (misspelled) name, kept so existing references keep working


# =========================================================
# Model Hyper-parameter
OPTIMIZER = 'adam'  # Keras optimizer identifier
LOSS = 'categorical_crossentropy'
METRICS = ['accuracy']

LIST_NUMBER_FILTERS = [16, 32, 64]  # stem conv filters, then one residual block per remaining entry
FILTER_SIZE = 5  # NOTE(review): not passed to Define_ResNet_Model, whose stem conv hard-codes kernel_size=3

MODEL_TYPE = "resnet" # ['vanilla_conv', 'resnet', 'xception']
PATH_GLOBAL_MODEL = os.path.join("models", "global_model_resnet_femnist.h5")


# =========================================================
# Training Hyper-parameter
NUM_ROUNDS = 500          # number of federated rounds
NUM_SELECTED_CLIENT = 10  # clients sampled per round

LOCAL_EPOCHS = 5          # local training epochs per selected client per round
LOCAL_BATCH_SIZE = 32

# 1. Dataset

## 1.1. Load dataset

In [None]:
# Load TensorFlow Federated's federated EMNIST (FEMNIST) train/test client data.
# only_digits=False keeps letters as well as digits (62 classes, matching OUPUT_SHAPE).
# NOTE(review): `tff` is never imported explicitly in this notebook; presumably one of the
# star imports provides it (tensorflow_federated as tff) -- confirm, or import it explicitly.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=False)

In [None]:
# Count how many training samples each federated client holds in the raw TFF split.
num_clients = len(emnist_train.client_ids)
print(f"Number of clients: {num_clients}")

# Stream through each client's tf.data.Dataset to count it, rather than materialising it
# with list(...): the count is the same, but no client's examples are held in memory at once.
list_num_samples = np.array(
    [
        sum(1 for _ in emnist_train.create_tf_dataset_for_client(client_id))
        for client_id in emnist_train.client_ids
    ]
)

print(f"Total number of samples in training set: {list_num_samples.sum()}")
print(f"Average number of samples per client: {list_num_samples.mean()}")

## 1.2. Prepare training dataset

In [None]:
# Turn the TFF training split into one record per client; each record is a dict
# with (at least) 'client_name', 'list_X' and 'list_y', as used by the cells below.
list_clients_data = Create_Clients_Data(emnist_train, dataset_name=DATASET_NAME)
print(f"Number of user: {len(list_clients_data)}")

In [None]:
# Preview one random training example from one random client.
# The client index and the sample index are drawn independently, each bounded by what
# actually exists, so the preview can never index past the data of a small client.
idx_client = np.random.randint(0, len(list_clients_data))

client_data = list_clients_data[idx_client]

client_name = client_data['client_name']
list_X = client_data['list_X']
list_y = client_data['list_y']

idx_sample = np.random.randint(0, len(list_X))

X = list_X[idx_sample]
print(f"Shape of image: {X.shape}")
y = list_y[idx_sample]

print(f"Client name= {client_name}")
print(f"Label = {y}")

# np.squeeze drops a possible trailing channel axis, e.g. (28, 28, 1) -> (28, 28), so imshow
# always receives a 2-D grayscale array; a plain (28, 28) image passes through unchanged.
fig, ax = plt.subplots()
ax.imshow(np.squeeze(X), cmap='gray')
ax.set_title(f"Client {client_name}, sample {idx_sample}")
ax.axis('off')
plt.show()

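## 1.3. Prepare validation and test datasets

The federated test split is pooled across all clients and then split 50/50 into a validation set and a held-out test set. The validation set is monitored during training; the test set is used only for the final evaluation.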

In [None]:
# Convert the TFF test split into per-client records, the same way as the training split.
list_data_test = Create_Clients_Data(emnist_test, dataset_name=DATASET_NAME)

# Pool every test client's samples into single arrays, then split the pool 50/50
# (fixed random_state) into a validation half and a held-out test half.
X_val, X_test, y_val, y_test = train_test_split(
    np.concatenate([record['list_X'] for record in list_data_test]),
    np.concatenate([record['list_y'] for record in list_data_test]),
    test_size=0.5,
    random_state=42,
)

print(f"Shape of X val: {X_val.shape}")
print(f"Shape of y val: {y_val.shape}")

print(f"Shape of X test: {X_test.shape}")
print(f"Shape of y test: {y_test.shape}")

In [None]:
# Sanity-check the evaluation arrays: their shapes, and the pixel value range over the
# whole validation set (not a single sample), e.g. to see whether pixels lie in [0, 1] or [0, 255].
print(f"Shape of X val: {X_val.shape}")
print(f"Shape of y val: {y_val.shape}")
print(f"Max value of X_val: {X_val.max()}")
print(f"Min value of X_val: {X_val.min()}")
print()
print(f"Shape of X test: {X_test.shape}")
print(f"Shape of y test: {y_test.shape}")

# 2. FL Training

## 2.1. Define components

In [None]:
def Define_ResNet_Model(input_shape, output_shape, list_number_filters, max_pooling_step=2, model_name=None):
    """
    Build a small Residual Network classifier with the Keras functional API.

    Layout: a 3x3, stride-2 "stem" convolution with BatchNorm + ReLU and max pooling, then one
    `residual_block` (called with strides=(2, 2)) per entry of `list_number_filters` after the
    first, then Flatten and a softmax Dense layer.

    Args:
        input_shape: shape of one input image, e.g. (28, 28, 1).
        output_shape: number of classes (units of the final softmax layer).
        list_number_filters: filters of the stem conv (first entry) followed by the filters of
            each residual block (remaining entries); must contain at least one entry.
        max_pooling_step: pool size and stride of the max-pooling layer after the stem.
        model_name: optional name of the returned model; None lets Keras choose a default.

    Returns:
        An uncompiled tf.keras.Model.
    """
    inputs = layers.Input(shape=input_shape)

    # Initial (stem) convolution. NOTE(review): the 'prunable_' name prefix is presumably what
    # the project's pruning utilities match on -- confirm before renaming.
    x = layers.Conv2D(list_number_filters[0], kernel_size=3, strides=2, padding='same', name='prunable_conv_0')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D(pool_size=max_pooling_step, strides=max_pooling_step, padding='same')(x)

    # Residual blocks: one per remaining filter count.
    for idx_residual_block, number_filters in enumerate(list_number_filters[1:]):
        x = residual_block(x, num_filters_1=number_filters, num_filters_2=number_filters, strides=(2, 2), idx_residual_block=idx_residual_block)

    # Classification head: Flatten (rather than global average pooling) + softmax.
    x = layers.Flatten()(x)
    x = layers.Dense(output_shape, activation='softmax')(x)

    # Forward model_name so the parameter actually names the model (None -> Keras default).
    model = tf.keras.Model(inputs=inputs, outputs=x, name=model_name)
    return model

In [None]:
# Build and compile the initial global model, record its architecture, and save it to
# PATH_GLOBAL_MODEL -- the FL loop below reloads the global model from that file every round.
keras.backend.clear_session()
global_model = Define_ResNet_Model(INPUT_SHAPE, OUPUT_SHAPE, LIST_NUMBER_FILTERS, max_pooling_step=2, model_name=None)
global_model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)
print(f"Number of params: {global_model.count_params()}")

num_conv_layers = Count_Conv2d_Layers(global_model)
print(f"Number of Conv2D layer: {num_conv_layers}")

# Create the output directories if needed; writing the diagram or the model file into a
# missing directory fails on a fresh checkout. ('.' covers a bare filename in PATH_GLOBAL_MODEL.)
os.makedirs('images', exist_ok=True)
os.makedirs(os.path.dirname(PATH_GLOBAL_MODEL) or '.', exist_ok=True)

plot_model(global_model, to_file=os.path.join('images', f'ResNet_{num_conv_layers}_{DATASET_NAME}.png'), show_shapes=True, show_layer_names=True);
global_model.save(PATH_GLOBAL_MODEL)

global_model.summary()

## 2.2. FL training

In [None]:
NUM_CLIENTS = num_clients
# Evaluate on the validation set every EVAL_EVERY_N_ROUNDS rounds, and always after the last round.
EVAL_EVERY_N_ROUNDS = 20

list_val_acc = []
list_val_loss = []
list_val_rounds = []  # round index of each entry in list_val_acc / list_val_loss
list_model_params = []


for idx_round in range(NUM_ROUNDS):
    print("\n [INFO] Round {}".format(idx_round))
    # The global model round-trips through PATH_GLOBAL_MODEL, presumably so the Keras session
    # can be cleared at the end of every round without losing the aggregated weights.
    global_model = tf.keras.models.load_model(PATH_GLOBAL_MODEL)
    global_model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

    # One architecture clone is reused as the local model of every selected client this round.
    client_model = keras.models.clone_model(global_model)

    selected_clients_data = random.sample(list_clients_data, NUM_SELECTED_CLIENT)

    # Local training on each selected client.
    list_client_model_weight = []
    list_client_scales = []
    for client in selected_clients_data:
        # Every client starts from the current global weights ...
        client_model.set_weights(global_model.get_weights())
        # ... and from a fresh optimizer: compiling with the OPTIMIZER string builds a new
        # optimizer instance, so Adam's moment/step state from the previous client does not
        # leak into this client's local training (clients must be independent in FedAvg).
        client_model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

        list_X = client['list_X']
        list_y = client['list_y']
        client_model.fit(list_X, list_y, epochs=LOCAL_EPOCHS, batch_size=LOCAL_BATCH_SIZE, verbose=0)

        list_client_model_weight.append(client_model.get_weights())  # local weights, aggregated below
        list_client_scales.append(len(list_X))                       # FedAvg weight = local sample count

    # Normalise the sample counts into aggregation weights that sum to 1.
    list_client_scales = np.array(list_client_scales)
    list_client_scales = list_client_scales / list_client_scales.sum()

    # Aggregate the client weights into the new global model.
    avg_weights = FedAvg(global_model, list_client_model_weight, list_client_scales)
    global_model.set_weights(avg_weights)

    # Periodic validation, plus the final round so the trained model is always evaluated.
    is_last_round = idx_round == NUM_ROUNDS - 1
    if idx_round % EVAL_EVERY_N_ROUNDS == 0 or is_last_round:
        val_loss, val_acc = global_model.evaluate(X_val, y_val, verbose=0)
        print(f'Val loss: {val_loss}, Val accuracy: {val_acc}')
        list_val_acc.append(val_acc)
        list_val_loss.append(val_loss)
        list_val_rounds.append(idx_round)

    global_model.save(PATH_GLOBAL_MODEL)
    # Drop per-round references and reset Keras' global state, presumably to keep memory
    # from growing across rounds.
    selected_clients_data = None
    list_client_model_weight = list_client_scales = None
    tf.keras.backend.clear_session()
    gc.collect()

# 3. Evaluation

In [None]:
# Shapes of the held-out test half used for the final evaluation below.
print(f"Shape of X test: {X_test.shape}", f"Shape of y test: {y_test.shape}", sep="\n")

In [None]:
# Evaluate the final global model on the held-out test half (never used during training).
# Stored as test_* so the last validation metrics in val_loss / val_acc are not overwritten.
test_loss, test_acc = global_model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {test_loss}, Test accuracy: {test_acc}')