In [None]:
import flwr as fl
import tensorflow as tf
from tensorflow import keras
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def preprocess_data(train_samples_per_class=100, test_samples_per_class=20, seed=None):
    """Build small, class-balanced training and held-out subsets of MNIST.

    For every label present in the MNIST training labels, the indices of that
    class are shuffled; the first ``train_samples_per_class`` of them become
    training samples and the following ``test_samples_per_class`` become the
    held-out ("test") samples. Both subsets are therefore drawn from the MNIST
    *training* split and never overlap; the official MNIST test split is not
    used.

    Args:
        train_samples_per_class: Number of training samples to take per class.
        test_samples_per_class: Number of held-out samples to take per class.
        seed: Optional seed for the per-class shuffle. When ``None`` (the
            default) NumPy's global random state is used via
            ``np.random.shuffle``; otherwise a dedicated
            ``np.random.default_rng(seed)`` generator is used so that the
            selection is reproducible.

    Returns:
        Tuple ``(x_train_subset, y_train_subset, x_test_subset, y_test_subset)``.
        The image arrays get a trailing axis of length 1 and are divided by
        255.0; the label arrays are returned unchanged. Samples are ordered by
        class label.

    Note:
        If a class holds fewer samples than requested, slicing silently yields
        fewer samples for that class.
    """
    # Only the training split is needed; the test split is discarded.
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

    # Global NumPy RNG by default, private generator when a seed is supplied.
    rng = np.random if seed is None else np.random.default_rng(seed)

    train_end = train_samples_per_class
    test_end = train_samples_per_class + test_samples_per_class

    # Collect index arrays per class and gather the samples once at the end,
    # instead of extending Python lists one image at a time.
    train_index_parts = []
    test_index_parts = []
    for class_label in np.unique(y_train):
        class_indices = np.where(y_train == class_label)[0]
        rng.shuffle(class_indices)
        train_index_parts.append(class_indices[:train_end])
        test_index_parts.append(class_indices[train_end:test_end])

    train_indices = np.concatenate(train_index_parts)
    test_indices = np.concatenate(test_index_parts)

    # Add the trailing channel axis and scale pixel values by 1/255.
    x_train_subset = x_train[train_indices][..., np.newaxis] / 255.0
    x_test_subset = x_train[test_indices][..., np.newaxis] / 255.0
    y_train_subset = y_train[train_indices]
    y_test_subset = y_train[test_indices]

    return x_train_subset, y_train_subset, x_test_subset, y_test_subset

# Notebook-level subsets: 100 training and 20 held-out images per class.
x_train_subset, y_train_subset, x_test_subset, y_test_subset = preprocess_data(
    train_samples_per_class=100,
    test_samples_per_class=20,
)


In [None]:
# Show the shapes of the training and held-out image arrays.
(x_train_subset.shape, x_test_subset.shape)

In [None]:
# Client-side classifier
def create_client_model():
    """Build and compile the dense classifier used by each client.

    The network flattens a 28x28 input, applies two ReLU dense layers
    (128 and 256 units) and ends in a 10-unit softmax layer. It is compiled
    with the "adam" optimizer, sparse categorical cross-entropy loss and an
    accuracy metric.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    layers = tf.keras.layers
    layer_stack = [
        layers.Flatten(input_shape=(28, 28)),
        layers.Dense(128, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]
    net = tf.keras.Sequential(layer_stack)
    net.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return net

In [None]:
# Imports used only by the architecture visualisations in this cell.
from tensorflow.keras.utils import plot_model
import visualkeras
# Build one model instance to inspect.
model=create_client_model()
# Write an architecture diagram (with layer shapes) to model_architecture.png.
plot_model(model, to_file='model_architecture.png', show_shapes=True)
# Last expression of the cell: the value returned by visualkeras is displayed.
visualkeras.layered_view(model)

In [None]:
# Display the current weights of `model`, the instance built for the visualisation above.
model.get_weights()

In [None]:
# Port used to build the 'localhost:<PORT>' server address when the client is started.
PORT=5010

In [None]:
# Flower client wrapping the local model and data
class MnistClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates the local MNIST classifier.

    Each instance builds its own model with ``create_client_model`` and draws
    its own training / held-out subsets with ``preprocess_data`` (default
    arguments).
    """

    def __init__(self, learning_rate=0.001, batch_size=32):
        """Create the local model and data and store the training settings.

        Args:
            learning_rate: Learning rate for the Adam optimizer built in ``fit``.
            batch_size: Mini-batch size used by ``fit``.
        """
        self.model = create_client_model()
        local_data = preprocess_data()
        self.x_train, self.y_train, self.x_test, self.y_test = local_data
        self.learning_rate = learning_rate
        self.batch_size = batch_size

    def get_parameters(self, config=None):
        """Return the model's current weights; ``config`` is not used."""
        return self.model.get_weights()

    def fit(self, parameters, config=None):
        """Load ``parameters`` into the model and train one epoch on local data.

        Args:
            parameters: Weights to load with ``set_weights`` before training.
            config: Not used.

        Returns:
            Tuple of (updated weights, number of training samples, empty
            metrics dict).
        """
        self.model.set_weights(parameters)
        # Recompile with a fresh Adam optimizer so this client's learning rate applies.
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        self.model.fit(
            self.x_train,
            self.y_train,
            batch_size=self.batch_size,
            epochs=1,
            verbose=1,
        )
        updated_weights = self.model.get_weights()
        return updated_weights, len(self.x_train), {}

    def evaluate(self, parameters, config=None):
        """Load ``parameters`` into the model and evaluate on the held-out data.

        Args:
            parameters: Weights to load with ``set_weights`` before evaluating.
            config: Not used.

        Returns:
            Tuple of (loss, number of held-out samples, ``{"accuracy": ...}``).
        """
        self.model.set_weights(parameters)
        results = self.model.evaluate(self.x_test, self.y_test, verbose=0)
        loss, accuracy = results
        return loss, len(self.x_test), {"accuracy": accuracy}

# Connect a freshly built client to the Flower server on localhost:PORT
fl.client.start_numpy_client(
    server_address=f"localhost:{PORT}",
    client=MnistClient(),
    grpc_max_message_length=1024 * 1024 * 1024,  # 1024**3
)


In [None]:
# Build a standalone client (not handed to Flower) and display its model's current weights.
cl=MnistClient()
cl.get_parameters()