# Description

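This notebook trains a CNN on the federated CelebA dataset in a federated learning (FL) setting. In each round, a random subset of clients trains a copy of the global model locally, and their weights are aggregated with FedAvg. During training, the CNN's convolutional filters are pruned between rounds using a standard-deviation threshold (`STD_THRESHOLD_PRUNE`). Pruning stops once it fails to reduce the parameter count for `MAX_PRUNE_PATIENCE` consecutive attempts, or once round `MAX_PRUNED_ROUND` has passed.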

In [None]:
import os
# Must be set before TensorFlow is imported to hide its INFO and WARNING logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import random

import numpy as np
import tensorflow as tf

# Seed every RNG this notebook draws from: TensorFlow (model initialization),
# NumPy (sample preview) and Python's `random` (client selection in the FL loop).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Let TensorFlow grow memory on demand on *every* visible GPU, not only the first.
# Memory growth must be configured before a GPU is initialized, so this stays
# ahead of the project imports below in case they touch the GPU.
for gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

import matplotlib.pyplot as plt
import seaborn as sns
from skimage.transform import resize
from sklearn.model_selection import train_test_split

from utils.read_data_utils import *
from utils.model_utils import *
from utils.pruning_utils import *
from config_celeb import *

# 1. Dataset

## 1.1. Load dataset

In [None]:
# Load the federated CelebA train/test splits from TFF's simulation datasets.
# NOTE(review): `tff` is not imported explicitly in this notebook; presumably it is
# re-exported by one of the star imports above — TODO confirm.
celeb_train, celeb_test = tff.simulation.datasets.celeba.load_data()

# Number of client ids in the training split.
num_clients = len(celeb_train.client_ids)
print(f"Number of clients: {num_clients}")

## 1.2. Prepare training dataset

In [None]:
# Convert the TFF training split into a list of per-client dicts; later cells read
# each entry's 'client_name', 'list_X' and 'list_y' keys.
list_clients_data = Create_Clients_Data(celeb_train, DATASET_NAME)
# This counts prepared entries, which is distinct from the raw TFF client-id count printed above.
print(f"Number of prepared training clients: {len(list_clients_data)}")

In [None]:
# Preview one random training image. The index bounds come from the data itself,
# so this also works when there are fewer than 10 clients or a client has fewer
# than 10 samples.
idx_client = np.random.randint(len(list_clients_data))
client_data = list_clients_data[idx_client]

client_name = client_data['client_name']
list_X = client_data['list_X']
list_y = client_data['list_y']

idx_sample = np.random.randint(len(list_X))
X = list_X[idx_sample]
y = list_y[idx_sample]

print(f"Shape of image: {X.shape}")
print(f"Client = {client_name}")
print(f"Label = {y}")

fig, ax = plt.subplots()
# `cmap` only applies to single-channel images; matplotlib ignores it for RGB(A) arrays.
ax.imshow(X, cmap='gray')
ax.set_title(f"Client {client_name}, sample {idx_sample} (label = {y})")
ax.axis('off')
plt.show()

## 1.3. Prepare validation and test datasets

In [None]:
# Pool the samples of every test client into single arrays, then split them 50/50
# (fixed random_state) into a validation set and a test set.
list_data_test = Create_Clients_Data(celeb_test, dataset_name=DATASET_NAME)

X_test = np.concatenate([data_test['list_X'] for data_test in list_data_test])
y_test = np.concatenate([data_test['list_y'] for data_test in list_data_test])

X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)

# Resize both splits to the model's input resolution.
eval_image_size = (IMAGE_DIMENSION, IMAGE_DIMENSION)
X_val = np.array([resize(image, eval_image_size) for image in X_val])
X_test = np.array([resize(image, eval_image_size) for image in X_test])

print(f"Shape of X val: {X_val.shape}")
print(f"Shape of y val: {y_val.shape}")

print(f"Shape of X test: {X_test.shape}")
print(f"Shape of y test: {y_test.shape}")

# 2. Federated learning (FL) training

## 2.1. Define components

In [None]:
# Threshold passed as `std_threshold` to `prune_model` in the FL loop, which forwards
# it to `Apply_Pruning_Filter` (presumably defined in utils.pruning_utils). Its exact
# effect on which filters survive is defined there — presumably a multiple of the
# kernel-weight standard deviation; TODO confirm.
STD_THRESHOLD_PRUNE = 2.5

In [None]:
# Reset Keras' global state (e.g. auto-generated layer-name counters) so re-running
# this cell starts clean.
keras.backend.clear_session()

# Build and compile the initial (unpruned) global model from the config's architecture settings.
global_model = Get_Model(MODEL_TYPE, INPUT_SHAPE, OUPUT_SHAPE, LIST_NUMBER_FILTERS, model_name="global_model")
global_model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

print(f"Number of params: {global_model.count_params()}")

# The diagram is written into images/; create the folder first so a fresh checkout
# does not fail on a missing directory.
os.makedirs('images', exist_ok=True)
plot_model(global_model, to_file=os.path.join('images', f'model_architecture_{DATASET_NAME}.png'), show_shapes=True, show_layer_names=True);

In [None]:
# Template model for local client training. `clone_model` copies the architecture
# with newly created weights; the FL loop loads the global weights into it before
# each client's local update.
client_model = keras.models.clone_model(global_model)
client_model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

## 2.2. FL training

In [None]:
def prune_model(model, optimizer, loss_func, metrics, std_threshold=2.0):
    """Rebuild `model` with fewer filters in its prunable conv layers, if that shrinks it.

    For every Conv2D layer whose name contains 'prunable_conv' (and is not the
    'classifier' layer), `Apply_Pruning_Filter` is run on the layer's kernel with
    `std_threshold`. The last-axis size of its result becomes that layer's new
    filter count, clamped to at least 1. A candidate model with those counts is
    then built via `Get_Model(MODEL_TYPE, ...)`.

    Module-level state is updated as a side effect:
      * If the candidate has fewer parameters, PRUNE_PATIENCE is reset to 0. The
        candidate, compiled with `optimizer`/`loss_func`/`metrics`, is returned.
      * Otherwise PRUNE_PATIENCE is incremented and `model` is returned unchanged.
        Once PRUNE_PATIENCE reaches MAX_PRUNE_PATIENCE, IS_STILL_PRUNE is set to False.

    NOTE(review): only the filter *counts* are taken from the pruned kernels. No
    weights are copied from `model` into the candidate here, so the returned model
    presumably starts from `Get_Model`'s fresh initialization. Confirm this is
    intended rather than transferring the surviving filters.

    Args:
        model: Keras model to prune; its input/output shapes are reused.
        optimizer, loss_func, metrics: forwarded to `compile` on the pruned model.
        std_threshold: forwarded to `Apply_Pruning_Filter`.

    Returns:
        The pruned, compiled model, or the original `model` if pruning did not
        reduce the parameter count.
    """
    global IS_STILL_PRUNE
    global PRUNE_PATIENCE

    params_before = model.count_params()

    prunable_layers = [
        layer for layer in model.layers
        if isinstance(layer, Conv2D) and layer.name != 'classifier' and 'prunable_conv' in layer.name
    ]
    # Keep at least one filter per layer so the rebuilt architecture stays valid.
    list_number_filters = [
        max(Apply_Pruning_Filter(layer.get_weights()[0], std_threshold).shape[-1], 1)
        for layer in prunable_layers
    ]

    candidate = Get_Model(MODEL_TYPE, input_shape=model.input_shape[1:], output_shape=model.output_shape[1], list_number_filters=list_number_filters)

    if candidate.count_params() >= params_before:
        # No reduction: count towards the patience budget and keep the current model.
        PRUNE_PATIENCE += 1
        print(f"--- [INFO] This round NOT prune filter ---")
        if PRUNE_PATIENCE >= MAX_PRUNE_PATIENCE:
            IS_STILL_PRUNE = False
            print(f"===== [INFO] Stop prune here! =====")
            print(f"Final params: {params_before}")
        return model

    PRUNE_PATIENCE = 0
    print(f"--- [INFO] This round PRUNE filter ---")
    candidate.compile(optimizer=optimizer, loss=loss_func, metrics=metrics)
    return candidate

In [None]:
# Federated training loop. Each round:
#   1. optionally prune the global model,
#   2. train a copy of it on a random subset of clients,
#   3. aggregate the client weights with FedAvg, weighted by each client's sample count,
#   4. evaluate on the validation split.
#
# NOTE(review): NUM_ROUNDS, MAX_PRUNED_ROUND, SELECTED_PERCENT_CLIENT, LOCAL_EPOCHS,
# LOCAL_BATCH_SIZE, IS_STILL_PRUNE and PRUNE_PATIENCE are not defined in this
# notebook; presumably they come from `config_celeb`. IS_STILL_PRUNE and
# PRUNE_PATIENCE are mutated here and in `prune_model`, and `global_model` may be
# replaced by a pruned model. Re-running only this cell therefore continues from the
# previous run's state; re-run the imports and model-definition cells for a clean run.

# Sample from the clients that actually have prepared data.
NUM_CLIENTS = len(list_clients_data)
# Clients trained per round: at least one (avoids an empty FedAvg), at most all of them.
num_selected_clients = min(NUM_CLIENTS, max(1, int(NUM_CLIENTS * SELECTED_PERCENT_CLIENT)))

list_val_acc = []
list_val_loss = []
list_model_params = []  # global-model parameter count after each round

for idx_round in range(NUM_ROUNDS):
    print("\n [INFO] Round {}".format(idx_round))

    # Hard stop for pruning once the pruning window has passed.
    if idx_round > MAX_PRUNED_ROUND and IS_STILL_PRUNE:
        IS_STILL_PRUNE = False
        print(f"===== [INFO] Stop prune here! =====")
        print(f"Final params: {global_model.count_params()}")

    # Prune from round 1 onward (round 0 trains the model as built). Rebuild the
    # client template so its architecture matches the possibly-pruned global model.
    if idx_round > 0 and IS_STILL_PRUNE:
        global_model = prune_model(global_model, optimizer=OPTIMIZER, loss_func=LOSS, metrics=METRICS, std_threshold=STD_THRESHOLD_PRUNE)
        client_model = keras.models.clone_model(global_model)
        client_model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

    selected_clients_data = random.sample(list_clients_data, num_selected_clients)

    # Every client starts from the same global weights; fetch them once per round.
    global_weights = global_model.get_weights()

    list_client_model_weight = []
    list_client_scales = []
    for selected_client_data in selected_clients_data:
        # NOTE(review): the same compiled `client_model` (and its optimizer) is reused
        # for every client, so a stateful optimizer (e.g. with momentum) carries state
        # from one client to the next — confirm this is intended.
        client_model.set_weights(global_weights)

        # Resize to the model's input resolution (same preprocessing as val/test).
        X_client = np.array([resize(image, (IMAGE_DIMENSION, IMAGE_DIMENSION)) for image in selected_client_data['list_X']])
        y_client = selected_client_data['list_y']

        client_model.fit(X_client, y_client, epochs=LOCAL_EPOCHS, batch_size=LOCAL_BATCH_SIZE, verbose=0)

        # Keep the local weights and sample count for the FedAvg step below.
        list_client_model_weight.append(client_model.get_weights())
        list_client_scales.append(len(X_client))

    # FedAvg coefficients: each client's share of this round's training samples.
    list_client_scales = np.array(list_client_scales)
    list_client_scales = list_client_scales / list_client_scales.sum()

    avg_weights = FedAvg(global_model, list_client_model_weight, list_client_scales)
    global_model.set_weights(avg_weights)
    list_model_params.append(global_model.count_params())

    # Evaluate the aggregated model on the validation split.
    val_loss, val_acc = global_model.evaluate(X_val, y_val, verbose=0)
    print(f'Val loss: {val_loss}, Val accuracy: {val_acc}')
    list_val_acc.append(val_acc)
    list_val_loss.append(val_loss)

# 3. Evaluation

In [None]:
# X_test was already resized to (IMAGE_DIMENSION, IMAGE_DIMENSION) when the
# validation/test split was prepared, so it is used as-is here.
print(f"Shape of X test: {X_test.shape}")
print(f"Shape of y test: {y_test.shape}")

In [None]:
# Evaluate the final global model on the held-out test split. The results use their
# own names so the validation metrics from the training loop are not overwritten.
test_loss, test_acc = global_model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {test_loss}, Test accuracy: {test_acc}')