# Aprendizaje Federado: Modelos Globales

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report

from TheModel2 import build

In [None]:
# Load MNIST. Only the test split is evaluated in this notebook; the training
# split is still exposed as ``x_train`` / ``y_train``.
train, test = tf.keras.datasets.mnist.load_data()
(train_images, y_train), (test_images, y_test) = train, test

# Divide pixel values by 255 and append a trailing channel axis of size 1.
x_train = (train_images / 255.0)[..., np.newaxis]
x_test = (test_images / 255.0)[..., np.newaxis]

In [None]:
## Load the locally trained (client) models.
import os

# Filename prefix of the global models saved by the aggregation cells of this
# notebook. Those files are aggregates, not client models, and must not be
# averaged back in when the notebook is re-run.
GLOBAL_MODEL_PREFIX = "global_model_"

# Every other ``.keras`` file under the working directory (recursively) is
# treated as a local model. Paths are sorted so the client order -- and thus
# the floating-point summation order during averaging -- is deterministic
# (``os.walk`` order depends on the filesystem).
local_model_paths = sorted(
    os.path.join(root, filename)
    for root, _dirs, files in os.walk("./")
    for filename in files
    if filename.endswith('.keras') and not filename.startswith(GLOBAL_MODEL_PREFIX)
)
if not local_model_paths:
    raise FileNotFoundError("No local '.keras' models found under './'")

loaded_local_models = [tf.keras.models.load_model(path) for path in local_model_paths]


def _architecture_signature(model):
    """Return the layer class names and per-layer weight shapes of ``model``.

    Used to verify that all local models share one architecture before their
    weights are averaged. ``Model.summary()`` cannot serve this purpose: it
    prints the summary and returns ``None``, so comparing two summaries is
    always ``None == None`` and never detects a mismatch.
    """
    return [
        (type(layer).__name__, [tuple(w.shape) for w in layer.get_weights()])
        for layer in model.layers
    ]


reference_signature = _architecture_signature(loaded_local_models[0])
for candidate in loaded_local_models[1:]:
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if _architecture_signature(candidate) != reference_signature:
        raise ValueError("Models have different architectures")

In [None]:
# Collect the weights of every local model: one list of arrays per client,
# in the same order as ``loaded_local_models``.
local_weights = []
for client_model in loaded_local_models:
    local_weights.append(client_model.get_weights())

### FedAvg

In [None]:
# FedAvg: element-wise average, across clients, of every array returned by
# ``get_weights()`` for each local model.
#
# Optional per-client weighting (e.g. the number of training samples each
# client used), in the same order as ``loaded_local_models``. FedAvg weights
# clients by sample count; ``None`` gives every client the same weight, i.e. a
# plain unweighted mean.
client_sample_counts = None

averaged_weights = [
    np.average(np.stack(client_arrays), axis=0, weights=client_sample_counts)
    for client_arrays in zip(*local_weights)
]

# Build the global model from the shared architecture and load the averaged weights.
global_model_Favg = build.build_it()
global_model_Favg.set_weights(averaged_weights)

# Predict on the test set; the predicted class is the arg-max along axis 1 of the outputs.
y_pred = global_model_Favg.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)

# Classification report (per-class precision / recall / F1).
print(classification_report(y_test, y_pred_classes))

global_model_Favg.save('global_model_Favg.keras')  # Save the global model

### FedBN

In [None]:
# FedBN: average every layer across clients EXCEPT BatchNormalization layers,
# which are excluded from aggregation.
#
# NOTE(review): in FedBN each client keeps its own BatchNormalization
# parameters/statistics. The BN layers of this single global model are left
# exactly as ``build.build_it()`` created them (presumably freshly initialized
# values -- confirm), so its test-set scores may not reflect what each client
# obtains with its own BN layers.
global_model_Fbn = build.build_it()

global_layers = global_model_Fbn.layers
local_layer_lists = [model.layers for model in loaded_local_models]

# Match layers by position rather than by name: unless ``build_it()`` names its
# layers explicitly (not visible here), Keras auto-generates names from a
# per-session counter, and ``build_it()`` may already have been called in this
# session, so the new model's names need not match those stored in the saved
# local models.
if any(len(layers) != len(global_layers) for layers in local_layer_lists):
    raise ValueError("Models have different architectures")

for global_layer, *client_layers in zip(global_layers, *local_layer_lists):
    if isinstance(global_layer, tf.keras.layers.BatchNormalization):
        continue  # FedBN: BatchNormalization layers are not aggregated

    # Weight arrays of this layer in every client model.
    client_layer_weights = [layer.get_weights() for layer in client_layers]

    # Element-wise mean across clients. A dedicated name is used so the FedAvg
    # result held in ``averaged_weights`` is not overwritten by this loop.
    layer_average = [np.mean(np.stack(arrays), axis=0) for arrays in zip(*client_layer_weights)]

    # Install the averaged weights into the corresponding global layer.
    global_layer.set_weights(layer_average)

# Predict on the test set; the predicted class is the arg-max along axis 1 of the outputs.
y_pred = global_model_Fbn.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)

# Classification report (per-class precision / recall / F1).
print(classification_report(y_test, y_pred_classes))

global_model_Fbn.save('global_model_Fbn.keras')  # Save the global model