## Load the MNIST dataset

In [1]:
import tensorflow as tf
import numpy as np
import os

# Download (or read from the Keras cache) MNIST as two tuples:
# train = (train images, train labels), test = (test images, test labels).
train, test = tf.keras.datasets.mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1] and add a trailing channel axis -> (N, 28, 28, 1).
# Casting to float32 *before* dividing keeps the result in float32 (Keras' default float
# dtype) instead of float64, halving memory and the size of every .npz saved later.
x_train = np.expand_dims(train[0].astype(np.float32) / 255.0, -1)
x_test = np.expand_dims(test[0].astype(np.float32) / 255.0, -1)
y_train, y_test = train[1], test[1]

In [2]:
# Sanity-check the array shapes: images are (N, 28, 28, 1), labels are (N,).
for array in (x_train, x_test, y_train, y_test):
    print(array.shape)

(60000, 28, 28, 1)
(10000, 28, 28, 1)
(60000,)
(10000,)


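## Split the MNIST training set into 5 stratified parts

Each team member receives one part of the training set, with roughly the same class distribution as the full set, plus the full shared test set.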

In [3]:


from sklearn.model_selection import StratifiedKFold

# One fold per team member: each member's local data is the held-out index set of one fold.
n_team_members = 5

# Stratification keeps each fold's digit-class proportions close to the full training set;
# the fixed random_state makes the shuffled assignment reproducible across runs.
skf = StratifiedKFold(
    n_splits=n_team_members,
    shuffle=True,
    random_state=42,
)

# skf.split() yields (train_idx, val_idx) pairs lazily; materialize them so the folds
# can be iterated more than once in later cells.
folds = [*skf.split(x_train, y_train)]

In [5]:

# One team member per fold; the order of `names` decides who receives which fold.
names = ['eder', 'sofia', 'pame', 'lesli', 'fer']

# zip() below silently drops the extra members (or folds) if these counts diverge,
# e.g. after changing n_team_members without updating `names`.
assert len(names) == len(folds), (
    f"Expected one name per fold: got {len(names)} names for {len(folds)} folds"
)

# Each member's local training data is the held-out index set of one fold. StratifiedKFold's
# held-out folds partition the samples, so the local sets are disjoint and cover x_train.
local_data = [(x_train[val_idx], y_train[val_idx]) for _, val_idx in folds]

# Count every label present in the full training set, so a class missing from a local
# split is reported as 0 instead of silently shortening the bincount array.
n_classes = int(y_train.max()) + 1

# Save each member's split (plus the shared test set) to its own .npz file
for (x_local, y_local), name in zip(local_data, names):
    np.savez(f'mnist_{name}.npz',
             x_train_local=x_local,
             y_train_local=y_local,
             x_test=x_test,
             y_test=y_test)

    # Check shapes and class balance
    print(f"--- {name} ---")
    print("x_train_local:", x_local.shape)
    print("y_train_local:", y_local.shape)
    print("Distribución de clases:", np.bincount(y_local, minlength=n_classes))
    print("x_test:", x_test.shape)
    print("y_test:", y_test.shape)
    print()


--- eder ---
x_train_local: (12000, 28, 28, 1)
y_train_local: (12000,)
Distribución de clases: [1184 1348 1191 1227 1169 1085 1183 1253 1170 1190]
x_test: (10000, 28, 28, 1)
y_test: (10000,)

--- sofia ---
x_train_local: (12000, 28, 28, 1)
y_train_local: (12000,)
Distribución de clases: [1185 1349 1191 1226 1168 1084 1184 1253 1170 1190]
x_test: (10000, 28, 28, 1)
y_test: (10000,)

--- pame ---
x_train_local: (12000, 28, 28, 1)
y_train_local: (12000,)
Distribución de clases: [1185 1349 1192 1226 1168 1084 1184 1253 1170 1189]
x_test: (10000, 28, 28, 1)
y_test: (10000,)

--- lesli ---
x_train_local: (12000, 28, 28, 1)
y_train_local: (12000,)
Distribución de clases: [1185 1348 1192 1226 1168 1084 1184 1253 1170 1190]
x_test: (10000, 28, 28, 1)
y_test: (10000,)

--- fer ---
x_train_local: (12000, 28, 28, 1)
y_train_local: (12000,)
Distribución de clases: [1184 1348 1192 1226 1169 1084 1183 1253 1171 1190]
x_test: (10000, 28, 28, 1)
y_test: (10000,)

