### Env building and Importing libiraries

In [17]:
# ! pip install -q tensorflow numpy matplotlib seaborn scikit-learn
! pip install -q tensorflow-model-optimization

  pid, fd = os.forkpty()


In [37]:
# !rm -rf /kaggle/working/*
# !zip -r file.zip /kaggle/working
# !rm /kaggle/working/file.zip

  pid, fd = os.forkpty()


In [19]:
# import tensorflow as tf
# from tensorflow.keras.models import Sequential
# from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# from tensorflow.keras.datasets import mnist
# from sklearn.metrics import classification_report, confusion_matrix
# from sklearn.model_selection import train_test_split
# from tensorflow.keras.optimizers import Adam
# from tensorflow import keras
# from tensorflow.keras import layers
# import tensorflow_model_optimization as tfmot
# from tensorflow_model_optimization.python.core.keras.compat import keras

import time
import numpy as np# Import the 'models' module from Keras
import pandas as pd
import tensorflow as tf
# from tensorflow.keras import models
# from tensorflow.keras import layers
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow_model_optimization.python.core.keras.compat import keras

VERBOSE = 2
EPOCHS = 5

### Helper functions

In [4]:
def compile_model(model, optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']):
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return model

def create_model(precision: str = 'float32'):
    if precision == 'float32':
        precision = tf.float32
        model = keras.Sequential([
            keras.layers.InputLayer(input_shape=(28, 28, 1)),
            keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Flatten(),
            keras.layers.Dense(64, activation='relu'),
            keras.layers.Dense(10, activation='softmax')
        ])
    elif precision == 'float16':
        precision = tf.float16
        model = keras.Sequential([
            keras.layers.InputLayer(input_shape=(28, 28)),
            keras.layers.Reshape(target_shape=(28, 28, 1)),
            keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu', dtype=precision),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Flatten(),
            keras.layers.Dense(64, dtype=precision),
            keras.layers.Dense(10, dtype=precision)
        ])
    else:
        raise ValueError("Unsupported precision type (use 'float32', 'float16').")

    return compile_model(model)

In [5]:
def save_tflite_model(tflite_model, model_path = None, VERBOSE:int = 0):
    """
    model_path : Given tflite model(string.tflite), this function will save the tflite model on the model path
    """
    if model_path is not None:
        with open(model_path, 'wb') as f:
            f.write(tflite_model)
            if VERBOSE : print(f"Model saved as {model_path}")
    else:
        raise ValueError("Model path must be provided.")

In [6]:
def custom_evaluate(interpreter, x_test, y_test, quantization_type:str = 'DEFAULT'):
    if quantization_type == 'int8':
        x_test = x_test.astype(np.int8)
    elif quantization_type == 'uint8':
        x_test = x_test.astype(np.uint8)
    elif quantization_type == 'int16':
        x_test = x_test.astype(np.int16)
    elif quantization_type == 'float16':
        x_test = x_test.astype(np.float16)
    elif quantization_type == 'DEFAULT' or quantization_type == 'float32':
        x_test = x_test.astype(np.float32)
    else:
        raise ValueError("Unsupported quantization type.")

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Run predictions on every image in the "test" dataset.
    y_pred = []
    for i, test_image in enumerate(x_test):
        # Pre-processing: add batch dimension and convert to float32
        # to match with the model's input data format.
        test_image = np.expand_dims(test_image, axis=0)
        # test_image = test_image.astype(np.float32)
        interpreter.set_tensor(input_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: remove batch dimension and find the digit with highest probability.
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        y_pred.append(digit)

    # Compare prediction results with ground truth labels to calculate accuracy.
    y_pred = np.array(y_pred)
    accuracy = (y_pred == y_test).mean()
    return accuracy , y_pred

In [7]:
def test_tflite(tflite_model = None, model_path = None, x_test = None, y_test = None, VERBOSE:int = 0):
    if tflite_model is None and model_path is None:
        raise ValueError("Either tf_model or model_path must be provided.")
    if tflite_model is not None and model_path is not None:
        raise ValueError("Only one of tf_model or model_path must be provided.")
    if x_test is None or y_test is None:
        raise ValueError("x_test and y_test must be provided for testing.")

    if tflite_model is not None:
        tflite_model = tflite_model
    if model_path is not None:
        tflite_model = tf.keras.models.load_model(model_path)

    tflite_interpreter = tf.lite.Interpreter(model_content = tflite_model)
    tflite_interpreter.allocate_tensors()
    test_accuracy, y_pred = custom_evaluate(tflite_interpreter, x_test, y_test)

    if VERBOSE : print(f'TFLite test accuracy: {100*test_accuracy:.7f}%\n')

    return test_accuracy, y_pred, y_test

In [8]:
def tf_to_tflite(model, model_path = None, quantization_type:str = 'DEFAULT', test = False, x_test = None, y_test = None, VERBOSE:int = 0):
    """
    model_path : Given tflite model(string.tflite), this function will save the tflite model on the model path
    test       : if test is true; give x_test and y_test and the function returns tflite_model, test_acc, y_pred, y_test
    """

    # Convert the model to TF Lite format.
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if quantization_type == 'float16':
        converter.target_spec.supported_types = [tf.float16]
    tflite_model = converter.convert()

    if model_path != None: save_tflite_model(tflite_model, model_path = model_path, VERBOSE = VERBOSE)

    # return tflite_model, (test_acc, y_pred, y_test)
    if test:
        return tflite_model, test_tflite(tflite_model = tflite_model, x_test = x_test, y_test = y_test, VERBOSE = VERBOSE)
    # returns only tflite model
    else: return tflite_model

In [9]:
def get_tflite_model_precision(tflite_model):
    # Load the TFLite model from path or model object
    if isinstance(tflite_model, str):
        interpreter = tf.lite.Interpreter(model_path=tflite_model)
    elif isinstance(tflite_model, bytes):
        interpreter = tf.lite.Interpreter(model_content=tflite_model)
    else:
        interpreter = tflite_model

    interpreter.allocate_tensors()

    # Get model details
    tensor_details = interpreter.get_tensor_details()
    operator_details = interpreter._get_ops_details()

    # Dictionary to map tensor indices to their names and data types
    tensor_info = {tensor['index']: {'name': tensor['name'], 'dtype': tensor['dtype']} for tensor in tensor_details}

    # Collect precision information
    precision_info = []

    for op in operator_details:
        op_name = op['op_name']
        input_tensors = [tensor_info[tensor_idx] for tensor_idx in op['inputs'] if tensor_idx != -1]
        output_tensors = [tensor_info[tensor_idx] for tensor_idx in op['outputs'] if tensor_idx != -1]
        precision_info.append({
            'op_name': op_name,
            'inputs': input_tensors,
            'outputs': output_tensors
        })

    return precision_info

def print_precision_info(precision_info):
    for op_info in precision_info:
        print(f"Operator: {op_info['op_name']}")
        print("  Inputs:")
        for tensor in op_info['inputs']:
            print(f"    - Name: {tensor['name']}, DataType: {tensor['dtype']}")
        print("  Outputs:")
        for tensor in op_info['outputs']:
            print(f"    - Name: {tensor['name']}, DataType: {tensor['dtype']}\n")

In [10]:
def check_quantization_type(model = None, model_path = None):
    # Load the model
    if model_path is not None:
        model = tf.keras.models.load_model(model_path)
    if model is None and model_path is None:
        raise ValueError("Either model or model_path must be provided.")

    # Get the model's input and output tensors
    input_tensor = model.input
    output_tensor = model.output

    # Iterate through the layers and check their weights
    for layer in model.layers:
        weights = layer.get_weights()
        for weight in weights:
            weight_dtype = weight.dtype
            if weight_dtype == tf.float32:
                print(f"Layer {layer.name} has float32 weights.")
            elif weight_dtype == tf.float16:
                print(f"Layer {layer.name} has float16 weights.")
            elif weight_dtype == tf.int8:
                print(f"Layer {layer.name} has int8 weights.")
            else:
                print(f"Layer {layer.name} has weights of type {weight_dtype}.")

    return True

# trdML
Using traditional ML i.e. with whole dataset to a single to be trained on MNIST dataset

#### Loading MNIST

In [13]:
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# # Normalize the input image so that each pixel value is between 0 to 1.
# x_train, x_test = x_train.astype(np.float64), x_test.astype(np.float64)
x_train = x_train / 255.0
x_test = x_test / 255.0

# # Reshape the data to include channel dimension
# # This is important for testing on tflite so channel is required
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

print(f'Training data shape: {x_train.shape},  {type(x_train[0][0][0])}')
print(f'Test     data shape: {x_test.shape},  {type(x_test[0][0][0])}')

print(f'Label type(trn,tst): {type(y_train[0])}, {type(y_test[0])}')

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Training data shape: (60000, 28, 28, 1),  <class 'numpy.ndarray'>
Test     data shape: (10000, 28, 28, 1),  <class 'numpy.ndarray'>
Label type(trn,tst): <class 'numpy.uint8'>, <class 'numpy.uint8'>


#### trd Without quantization

In [28]:
EPOCHS, VERBOSE = 3, 1

raw_model = create_model()
raw_model = compile_model(raw_model)

if(VERBOSE > 1): raw_model.summary()
if(VERBOSE > 1): check_quantization_type(raw_model)

# Train the no-quant model
stime = time.time()
raw_history = raw_model.fit(x_train[:], y_train[:], epochs=EPOCHS, batch_size=256, validation_split=0.2, verbose = VERBOSE)
_, raw_test_acc = raw_model.evaluate(x_test, y_test, verbose=2)
etime = time.time()

trdTrainTime = etime - stime
print(f"For training model for {EPOCHS} epochs had taken is {(etime-stime):.4f}s in  Traditional setup on MNIST")
print(f"Testing Inference on trained model; accuracy {100*raw_test_acc:.4f}% in  Traditional setup on MNIST")

Epoch 1/3
Epoch 2/3
Epoch 3/3
313/313 - 1s - loss: 0.0674 - accuracy: 0.9789 - 539ms/epoch - 2ms/step
For training model for 3 epochs had taken is 5.4185s in  Traditional setup on MNIST
Testing Inference on trained model; accuracy 97.8900% in  Traditional setup on MNIST


##### Convert to tflite and test tflit model

In [29]:
# # raw_model convert this tf model to tflite
raw_tflite, report = tf_to_tflite(raw_model, model_path = "raw_model.tflite",
                                quantization_type = 'DEFAULT',
                                test = True, x_test = x_test, y_test = y_test,
                                VERBOSE = 2)

if VERBOSE > 1: print_precision_info(get_tflite_model_precision(raw_tflite))

Summary on the non-converted ops:
---------------------------------
 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 6, Total Ops 16, % non-converted = 37.50 %
 * 6 ARITH ops

- arith.constant:    6 occurrences  (f32: 5, i32: 1)



  (f32: 1)
  (f32: 2)
  (f32: 1)
  (uq_8: 1)
  (f32: 1)
  (f32: 1)


Model saved as raw_model.tflite
TFLite test accuracy: 97.8900000%



### Train then quantization

With and without finetuning for one epoch and 10000 datapoints

In [30]:
import tensorflow_model_optimization as tfmot
q_aware_model = tfmot.quantization.keras.quantize_model(raw_model)
q_aware_model = compile_model(q_aware_model)


if VERBOSE > 1: q_aware_model.summary()
if VERBOSE > 1: check_quantization_type(q_aware_model)
_, test_acc_qaware = q_aware_model.evaluate(x_test, y_test, verbose=2)
print(f"Testing Inference on quant aware trained model; accuracy {100*test_acc_qaware:.4f}% in  Traditional setup on MNIST. \n",
        "Its sanity is argued, so it with grain of salt.\n")
    

    
    
FINETUNE = False
train_quant_tflite, report = tf_to_tflite(q_aware_model,
                                      model_path = f'train_quant_ftune{FINETUNE}.tflite',
                                      quantization_type = 'float16',
                                      test = True, x_test = x_test, y_test = y_test,
                                      VERBOSE = 2)
    
    
    
FINETUNE = True
if FINETUNE:
    dataset_size_for_fine_tuning = 10
    train_images_subset, train_labels_subset = x_train[0:dataset_size_for_fine_tuning], y_train[0:dataset_size_for_fine_tuning]
    q_aware_model.fit(train_images_subset, train_labels_subset, batch_size=500, epochs=1, validation_split=0.1)
print()
train_quant_finetuned_tflite, report = tf_to_tflite(q_aware_model,
                                                    model_path = f'train_quant_ftune{FINETUNE}x{dataset_size_for_fine_tuning/1000}k.tflite',
                                                    quantization_type = 'float16',
                                                    test = True, x_test = x_test, y_test = y_test,
                                                    VERBOSE = 2)


    
FINETUNE = True
if FINETUNE:
    dataset_size_for_fine_tuning = 10000
    train_images_subset, train_labels_subset = x_train[0:dataset_size_for_fine_tuning], y_train[0:dataset_size_for_fine_tuning]
    q_aware_model.fit(train_images_subset, train_labels_subset, batch_size=500, epochs=1, validation_split=0.1)
print()
train_quant_finetuned_tflite, report = tf_to_tflite(q_aware_model,
                                                    model_path = f'train_quant_ftune{FINETUNE}x{dataset_size_for_fine_tuning/1000}k.tflite',
                                                    quantization_type = 'float16',
                                                    test = True, x_test = x_test, y_test = y_test,
                                                    VERBOSE = 2)

# print_precision_info(get_tflite_model_precision(train_quant_finetuned_tflite)) 

313/313 - 2s - loss: 3.4472 - accuracy: 0.1135 - 2s/epoch - 6ms/step
Testing Inference on quant aware trained model; accuracy 11.3500% in  Traditional setup on MNIST. 
 Its sanity is argued, so it with grain of salt.



Summary on the non-converted ops:
---------------------------------
 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 1, Total Ops 18, % non-converted = 5.56 %
 * 1 ARITH ops

- arith.constant:    1 occurrences  (i32: 1)



  (uq_8: 1)
  (f32: 1)
  (uq_8: 2)
  (uq_8: 1)
  (uq_8: 3, uq_32: 3)
  (uq_8: 1)
  (uq_8: 1)
  (uq_8: 1)


Model saved as train_quant_ftuneFalse.tflite
TFLite test accuracy: 11.3500000%




Summary on the non-converted ops:
---------------------------------
 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 1, Total Ops 18, % non-converted = 5.56 %
 * 1 ARITH ops

- arith.constant:    1 occurrences  (i32: 1)



  (uq_8: 1)
  (f32: 1)
  (uq_8: 2)
  (uq_8: 1)
  (uq_8: 3, uq_32: 3)
  (uq_8: 1)
  (uq_8: 1)
  (uq_8: 1)


Model saved as train_quant_ftuneTruex0.01k.tflite
TFLite test accuracy: 95.5800000%




Summary on the non-converted ops:
---------------------------------
 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 1, Total Ops 18, % non-converted = 5.56 %
 * 1 ARITH ops

- arith.constant:    1 occurrences  (i32: 1)



  (uq_8: 1)
  (f32: 1)
  (uq_8: 2)
  (uq_8: 1)
  (uq_8: 3, uq_32: 3)
  (uq_8: 1)
  (uq_8: 1)
  (uq_8: 1)


Model saved as train_quant_ftuneTruex10.0k.tflite
TFLite test accuracy: 97.6600000%



### Quantization then train

In [31]:
EPOCHS, VERBOSE = 3, 1
PRECISION = 'float16'

# On TPUs and CPUs, use 'mixed_bfloat16' instead
from tensorflow.keras import mixed_precision
if PRECISION == 'float16': mixed_precision.set_global_policy('mixed_float16')

 # Define the model
quant_train_model = create_model(precision = PRECISION)
quant_train_model = create_model()
quant_train_model = compile_model(quant_train_model)

q_aware_model = quant_train_model
import tensorflow_model_optimization as tfmot
q_aware_model = tfmot.quantization.keras.quantize_model(quant_train_model)
q_aware_model = compile_model(q_aware_model)

if(VERBOSE > 1): q_aware_model.summary()
if(VERBOSE > 1): check_quantization_type(q_aware_model)

# Train the quant model
stime = time.time()
raw_history = q_aware_model.fit(x_train[:], y_train[:], epochs=EPOCHS, batch_size=128, validation_split=0.2, verbose = VERBOSE)
_, quant_train_test_acc = q_aware_model.evaluate(x_test, y_test, verbose=2)
etime = time.time()

trdTrainTime = etime - stime
print(f"For training model for {EPOCHS} epochs had taken is {(etime-stime):.4f}s in  Traditional setup on MNIST")
print(f"Testing Inference on trained model; accuracy {100*quant_train_test_acc:.4f}% in  Traditional setup on MNIST")

# tf.keras.mixed_precision.set_global_policy(None)

Epoch 1/3
Epoch 2/3
Epoch 3/3
313/313 - 1s - loss: 0.0570 - accuracy: 0.9811 - 694ms/epoch - 2ms/step
For training model for 3 epochs had taken is 9.7164s in  Traditional setup on MNIST
Testing Inference on trained model; accuracy 98.1100% in  Traditional setup on MNIST


In [32]:
# if(VERBOSE > 1): q_aware_model.summary()
# if(VERBOSE > 1): check_quantization_type(q_aware_model)

FINETUNE = False
quant_train_tflite, report = tf_to_tflite(model = q_aware_model,
                                      model_path = f'quant_train_ftune{FINETUNE}.tflite',
                                      quantization_type = 'float16',
                                      test = True, x_test = x_test, y_test = y_test,
                                      VERBOSE = 2)

FINETUNE = True
if FINETUNE:
    dataset_size_for_fine_tuning = 10000
    train_images_subset, train_labels_subset = x_train[0:dataset_size_for_fine_tuning], y_train[0:dataset_size_for_fine_tuning]
    q_aware_model.fit(train_images_subset, train_labels_subset, batch_size=512, epochs=1, validation_split=0.1)

quant_train_finetuned_tflite, report = tf_to_tflite(model = q_aware_model,
                                      model_path = f'quant_train_ftune{FINETUNE}x{dataset_size_for_fine_tuning/1000}k.tflite',
                                      quantization_type = 'float16',
                                      test = True, x_test = x_test, y_test = y_test,
                                      VERBOSE = 2)

Summary on the non-converted ops:
---------------------------------
 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 1, Total Ops 18, % non-converted = 5.56 %
 * 1 ARITH ops

- arith.constant:    1 occurrences  (i32: 1)



  (uq_8: 1)
  (f32: 1)
  (uq_8: 2)
  (uq_8: 1)
  (uq_8: 3, uq_32: 3)
  (uq_8: 1)
  (uq_8: 1)
  (uq_8: 1)


Model saved as quant_train_ftuneFalse.tflite
TFLite test accuracy: 98.1100000%



Summary on the non-converted ops:
---------------------------------
 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 1, Total Ops 18, % non-converted = 5.56 %
 * 1 ARITH ops

- arith.constant:    1 occurrences  (i32: 1)



  (uq_8: 1)
  (f32: 1)
  (uq_8: 2)
  (uq_8: 1)
  (uq_8: 3, uq_32: 3)
  (uq_8: 1)
  (uq_8: 1)
  (uq_8: 1)


Model saved as quant_train_ftuneTruex10.0k.tflite
TFLite test accuracy: 98.3300000%



##### Checking Quantization of tflite model

In [35]:
tflite_model_paths = ['raw_model.tflite','train_quant_ftuneFalse.tflite', f'train_quant_ftuneTruex{dataset_size_for_fine_tuning/1000}k.tflite',
                      'quant_train_ftuneFalse.tflite', f'quant_train_ftuneTruex{dataset_size_for_fine_tuning/1000}k.tflite']

import os
for tflite_model_path in tflite_model_paths[:]:
    print(f'\n\n{tflite_model_path}')
    print_precision_info(get_tflite_model_precision(tflite_model_path))
    
    # Get the size of the model file
    model_size_in_bytes = os.path.getsize(tflite_model_path)
    model_size_in_mb = model_size_in_bytes / float(2**20)

    print(f"Model = {tflite_model_path} size in MB: {model_size_in_mb:.2f}")
    
# import tempfile
# import os

# # Create float TFLite model.
# float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
# float_tflite_model = float_converter.convert()

# # Measure sizes of models.
# _, float_file = tempfile.mkstemp('.tflite')
# _, quant_file = tempfile.mkstemp('.tflite')

# with open(quant_file, 'wb') as f:
#   f.write(quantized_tflite_model)

# with open(float_file, 'wb') as f:
#   f.write(float_tflite_model)

# print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
# print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))



raw_model.tflite
Model = raw_model.tflite size in MB: 0.34


train_quant_ftuneFalse.tflite
Model = train_quant_ftuneFalse.tflite size in MB: 0.34


train_quant_ftuneTruex10.0k.tflite
Model = train_quant_ftuneTruex10.0k.tflite size in MB: 0.34


quant_train_ftuneFalse.tflite
Model = quant_train_ftuneFalse.tflite size in MB: 0.34


quant_train_ftuneTruex10.0k.tflite
Model = quant_train_ftuneTruex10.0k.tflite size in MB: 0.34


### Quant Ablation

Baseline test accuracy: 98.170% <br>
Quant without finetune test accuracy: 98.390% <br>
Quant with    finetune test accuracy: 98.310% 		[on 1000 datapoints] <br>
Quant with    finetune TFLite test accuracy: 98.310% <br>

In [None]:
# import warnings
# warnings.filterwarnings("ignore")

# _, baseline_model_accuracy = model.evaluate(
#     x_test, y_test, verbose=0)

# fine_tune_datapoints = 10000
# quantization_types = ['DEFAULT', 'float32', 'uint8', 'int8',] # 'int16'
# for quantization_type in quantization_types[1:2]:
#     print(quantization_type)
#     # Quantize aware model; Quantize the model;;
#     q_aware_model = apply_quantization_to_model(model)
#     _, q_aware_model_without_finetune_accuracy = q_aware_model.evaluate(
#        x_test, y_test, verbose=0)

#     x_train_quant, y_train_quant = x_train[:fine_tune_datapoints].astype(np.float32), y_train[0:fine_tune_datapoints]
#     q_aware_history = q_aware_model.fit(x_train_quant, y_train_quant, epochs=1,
#                                         batch_size=64,
#                                         validation_split=0.2,
#                                         verbose = 0)

#     _, q_aware_model_accuracy = q_aware_model.evaluate(
#         x_test, y_test, verbose=0)

#     # Example: Convert to different quantization types
#     quantization_type = quantization_type
#     quantized_tflite_model = convert_to_tflite(model = q_aware_model, dataset=x_train, quantization_type = quantization_type)
#     quantized_tflite_interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
#     quantized_tflite_interpreter.allocate_tensors()
#     quantized_tflite_test_accuracy = custom_evaluate(quantized_tflite_interpreter, x_test, y_test,quantization_type)

#     print(f'Baseline test accuracy: {100*baseline_model_accuracy:.7f}%')
#     print(f'{quantization_type} Quant without finetune test accuracy: {100*q_aware_model_without_finetune_accuracy:.7f}%')
#     print(f'{quantization_type} Quant with    finetune test accuracy: {100*q_aware_model_accuracy:.7f}% \t\t[on {fine_tune_datapoints} datapoints]')
#     print(f'{quantization_type} Quant with    finetune TFLite test accuracy: {100*quantized_tflite_test_accuracy:.7f}%')

## Pruning

### Aware

In [14]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot
import time

# Define constants
EPOCHS = 1
VERBOSE = 2

# Define the model architecture
model_noQuant  = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28, 1)),
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
# Apply pruning
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,
    final_sparsity=0.5,
    begin_step=0,
    end_step=np.ceil(x_train.shape[0] / 64).astype(np.int32) * EPOCHS
)

pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model_noQuant, pruning_schedule=pruning_schedule)

# Compile the model
pruned_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

pruned_model.summary()

# Define callbacks for pruning
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='./pruning_logs')
]

# Train the pruned model
stime = time.time()
history_pruned = pruned_model.fit(x_train[:], y_train[:], epochs=EPOCHS, batch_size=64, validation_split=0.2, verbose=VERBOSE, callbacks=callbacks)
_, test_acc_pruned = pruned_model.evaluate(x_test, y_test, verbose=2)
etime = time.time()

prunedTrainTime = etime - stime
print(f"For training pruned model for {EPOCHS} epochs had taken is {(etime-stime):.4f}s on MNIST")
print(f"Testing Inference on pruned model; accuracy {100*test_acc_pruned:.4f}% on MNIST")


Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 26, 26, 32)        610       
 _2 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 32)        1         
 oling2d_2 (PruneLowMagnitu                                      
 de)                                                             
                                                                 
 prune_low_magnitude_flatte  (None, 5408)              1         
 n_2 (PruneLowMagnitude)                                         
                                                                 
 prune_low_magnitude_dense_  (None, 64)                692290    
 4 (PruneLowMagnitude)                                           
                                                      

KeyboardInterrupt: 

In [22]:
import time
import numpy as np# Import the 'models' module from Keras
import pandas as pd
import tensorflow as tf
from tensorflow import keras
try:
    import tensorflow_model_optimization as tfmot
except:
    !pip -q install tensorflow-model-optimization
    import tensorflow_model_optimization as tfmot

import time

# Define constants
EPOCHS = 3
VERBOSE = 2

# Apply pruning
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.5,
    final_sparsity=0.8,
    begin_step=0,
    end_step=np.ceil(x_train.shape[0] / 64).astype(np.int32) * EPOCHS
)
# Define callbacks for pruning
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='./pruning_logs')
]


raw_model = create_model()
raw_model = compile_model(raw_model)
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(raw_model, pruning_schedule=pruning_schedule)
pruned_model = compile_model(pruned_model)

# if VERBOSE >= 2: raw_model.summary()
# if VERBOSE >= 2: pruned_model.summary()

# Train baseline
stime = time.time()
raw_history = raw_model.fit(x_train[:], y_train[:], epochs=EPOCHS, batch_size=64, validation_split=0.2, verbose=VERBOSE)
etime = time.time()
print(f"For training raw baseline model for {EPOCHS} epochs had taken is {(etime-stime):.4f}s on MNIST")


# Train the pruned model
stime = time.time()
pruned_history = pruned_model.fit(x_train[:], y_train[:], epochs=EPOCHS, batch_size=64, validation_split=0.2, verbose=VERBOSE, callbacks=callbacks)
etime = time.time()
print(f"For training prune baseline model for {EPOCHS} epochs had taken is {(etime-stime):.4f}s on MNIST")



_, raw_test_acc = raw_model.evaluate(x_test, y_test, verbose=2)
_, pruned_test_acc = pruned_model.evaluate(x_test, y_test, verbose=2)
print(f"Testing Inference on raw model; accuracy {100*raw_test_acc:.4f}% on MNIST")
print(f"Testing Inference on pruned model; accuracy {100*pruned_test_acc:.4f}% on MNIST")

Epoch 1/3
750/750 - 4s - loss: 0.2342 - accuracy: 0.9338 - val_loss: 0.0953 - val_accuracy: 0.9725 - 4s/epoch - 6ms/step
Epoch 2/3
750/750 - 2s - loss: 0.0772 - accuracy: 0.9772 - val_loss: 0.0766 - val_accuracy: 0.9768 - 2s/epoch - 3ms/step
Epoch 3/3
750/750 - 2s - loss: 0.0528 - accuracy: 0.9837 - val_loss: 0.0645 - val_accuracy: 0.9811 - 2s/epoch - 3ms/step
For training raw baseline model for 3 epochs had taken is 8.7334s on MNIST
Epoch 1/3
750/750 - 6s - loss: 0.0471 - accuracy: 0.9861 - val_loss: 0.0641 - val_accuracy: 0.9813 - 6s/epoch - 8ms/step
Epoch 2/3
750/750 - 3s - loss: 0.0412 - accuracy: 0.9887 - val_loss: 0.0964 - val_accuracy: 0.9711 - 3s/epoch - 4ms/step
Epoch 3/3
750/750 - 3s - loss: 0.0349 - accuracy: 0.9904 - val_loss: 0.0564 - val_accuracy: 0.9849 - 3s/epoch - 4ms/step
For training prune baseline model for 3 epochs had taken is 12.4445s on MNIST
313/313 - 1s - loss: 0.0512 - accuracy: 0.9844 - 571ms/epoch - 2ms/step
313/313 - 1s - loss: 0.0512 - accuracy: 0.9844 - 

In [27]:
# import tensorflow as tf
# from tensorflow import keras
# import tensorflow_model_optimization as tfmot
# import time

# Define constants
EPOCHS = 3
VERBOSE = 2
SAVE_MODEL = 1

# Load or define the dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

# Define the original model architecture
model_noQuant = create_model()
model_noQuant = compile_model(model_noQuant)

# Train the original model
stime = time.time()
history_False = model_noQuant.fit(x_train, y_train, epochs=EPOCHS, batch_size=64, validation_split=0.2, verbose=VERBOSE)
etime = time.time()

_, test_acc_False = model_noQuant.evaluate(x_test, y_test, verbose=VERBOSE)
print(f"Training model for {EPOCHS} epochs took {(etime-stime):.4f}s in the traditional setup on MNIST")
print(f"Testing inference on trained model; accuracy {100*test_acc_False:.4f}% in the traditional setup on MNIST")



# Post-training pruning
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.5, begin_step=0, frequency=1)
}
pruned_model = prune_low_magnitude(model_noQuant, **pruning_params)
pruned_model.compile(optimizer='adam',
                     loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                     metrics=['accuracy'])




# Reporting without finetuning
_, test_acc_without_fine = pruned_model.evaluate(x_test, y_test, verbose=VERBOSE)
print(f"Testing inference on trained then pruned model; accuracy {100*test_acc_without_fine:.4f}% in the traditional setup on MNIST")






# Fine-tune the pruned model
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='./pruning_logs')
]

stime = time.time()
history_pruned = pruned_model.fit(x_train, y_train, epochs=1, batch_size=64, validation_split=0.2, verbose=VERBOSE, callbacks=callbacks)
etime = time.time()
_, test_acc_pruned = pruned_model.evaluate(x_test, y_test, verbose=VERBOSE)
print(f"Fine-tuning pruned model for {EPOCHS} epochs took {(etime-stime):.4f}s on MNIST")
print(f"Testing inference on pruned model with finetuning ; accuracy {100*test_acc_pruned:.4f}% on MNIST")

# Strip pruning wrappers from the model
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

if SAVE_MODEL:
    final_model.save('pruned_model.h5')
    model_noQuant.save('model_noQuant.h5')
    print("Pruned model saved to 'pruned_model.h5'")


Epoch 1/3
750/750 - 5s - loss: 0.2452 - accuracy: 0.9296 - val_loss: 0.0981 - val_accuracy: 0.9722 - 5s/epoch - 7ms/step
Epoch 2/3
750/750 - 2s - loss: 0.0784 - accuracy: 0.9767 - val_loss: 0.0729 - val_accuracy: 0.9790 - 2s/epoch - 3ms/step
Epoch 3/3
750/750 - 2s - loss: 0.0528 - accuracy: 0.9847 - val_loss: 0.0582 - val_accuracy: 0.9834 - 2s/epoch - 3ms/step
313/313 - 1s - loss: 0.0487 - accuracy: 0.9845 - 705ms/epoch - 2ms/step
Training model for 3 epochs took 10.0762s in the traditional setup on MNIST
Testing inference on trained model; accuracy 98.4500% in the traditional setup on MNIST
313/313 - 2s - loss: 0.0487 - accuracy: 0.9845 - 2s/epoch - 5ms/step
Testing inference on trained then pruned model; accuracy 98.4500% in the traditional setup on MNIST
750/750 - 6s - loss: 0.0421 - accuracy: 0.9878 - val_loss: 0.0629 - val_accuracy: 0.9806 - 6s/epoch - 8ms/step
313/313 - 1s - loss: 0.0560 - accuracy: 0.9823 - 571ms/epoch - 2ms/step
Fine-tuning pruned model for 3 epochs took 6.9222

  saving_api.save_model(


In [None]:
history = history_pruned
model = pruned_model

## Visualizations

In [None]:
# Plotting in a single figure with subplots
fig, ax = plt.subplots(1, 2, figsize=(14, 6))

# Plot training & validation loss values
ax[0].plot(history.history['loss'], label='Train Loss', marker="+")
ax[0].plot(history.history['val_loss'], label='Validation Loss')
ax[0].set_title('Trd Model Loss')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].legend(loc='upper right')

# Plot training & validation accuracy values
ax[1].plot(history.history['accuracy'], label='Train Accuracy', marker="+")
ax[1].plot(history.history['val_accuracy'], label='Validation Accuracy')
ax[1].set_title('Trd Model Accuracy')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Accuracy')
ax[1].legend(loc='lower right')

plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)

# Predict the labels
y_pred = model.predict(x_test)
y_pred_classes = y_pred.argmax(axis=1)

# Print the classification report
print(f'Test accuracy: {test_acc}')
print(classification_report(y_test, y_pred_classes))

print('Confusion Matrix:')
print(confusion_matrix(y_test, y_pred_classes))

# fedML

In [None]:
# Number of workers
N_WORKERS = 4

# Federated learning parameters
EPOCHS = 3
SAVE_MODEL = 0
EPOCHS_WITHIN = 1

VERBOSE = 2 # more the VERBOSE more the things about model is exposed
VISUALIZE_WEIGHT_AFTER = 2

In [None]:
from sklearn.model_selection import train_test_split

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

train_subset_length = 60000             # For demonstration purposes use less
x_train = x_train[:train_subset_length]
y_train = y_train[:train_subset_length]

# Normalize the data
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# # Reshape the data to include channel dimension
# x_train = x_train[..., tf.newaxis]
# x_test = x_test[..., tf.newaxis]

# Split the data into training and validation sets
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=1/6, random_state=42)

# Split the training data among workers
x_train_splits = np.array_split(x_train, N_WORKERS)
y_train_splits = np.array_split(y_train, N_WORKERS)

# Print the shapes to verify
for i in range(N_WORKERS):
    print(f'Worker {i+1} - Training data shape: {x_train_splits[i].shape}')
print('--------------------------------------------------')
print(f'Total Training data shape: {x_train.shape}')
print(f'Validation data shape: {x_val.shape}')
print(f'Test data shape: {x_test.shape}')

### Quantized training with pruning flows

In [None]:
# # Initialize models for each worker
# models = [create_model() for _ in range(N_WORKERS)]
# pruned_models = [prune_model(model) for model in models]

# # Initialize metrics
# worker_train_losses = [[] for _ in range(N_WORKERS)]
# worker_val_losses = [[] for _ in range(N_WORKERS)]
# worker_train_accuracies = [[] for _ in range(N_WORKERS)]
# worker_val_accuracies = [[] for _ in range(N_WORKERS)]
# worker_weights = [[] for _ in range(N_WORKERS)]

# # 0.0 Training the models (fedML (none))
# fedTrainTime = 0
# for epoch in range(EPOCHS):
#     fedTrainTime1 = 0
#     worker_histories = []
#     print(f'Epoch {epoch+1}/{EPOCHS}')

#     # Train each worker's model for {EPOCHS_WITHIN} epochs
#     for i in range(N_WORKERS):
#         stime = time.time()
#         history = models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE)
#         etime = time.time()
#         fedTrainTime1 += etime - stime

#         if VERBOSE: print(f"Worker {i+1} trained in {etime - stime:.8f}s")
#         if VERBOSE > 2: print(f'{i}th model check: {check_quantization_type(models[i])}')
#         if VERBOSE > 1: print(f'Precision of {i}th worker model: {models[i].get_weights()[0].dtype}')
#         if VERBOSE > 3: print(f'Weights of {i}th worker model are: \n {models[i].get_weights()}')

#         worker_histories.append(history.history)
#         worker_train_losses[i].append(history.history['loss'][0])
#         worker_val_losses[i].append(history.history['val_loss'][0])
#         worker_train_accuracies[i].append(history.history['accuracy'][0])
#         worker_val_accuracies[i].append(history.history['val_accuracy'][0])

#     # Collect and average the weights
#     stime = time.time()
#     new_weights = [model.get_weights() for model in models]
#     avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
#     for model in models:
#         model.set_weights(avg_weights)
#     etime = time.time()

#     fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

#     if VERBOSE > 1: print(f"Worker weights updated in {etime - stime:.8f}s")
#     if VERBOSE > 2: print(f'Precision of avg model: {avg_weights[0].dtype}')
#     if VERBOSE > 3: print(f'Weights of avg model are: \n {avg_weights}')

#     # Visualize weights after every VISUALIZE_WEIGHT_AFTER epochs
#     if (epoch + 1) % VISUALIZE_WEIGHT_AFTER == 0:
#         for i, model in enumerate(models):
#             worker_weights[i].append(model.get_weights())

#     # Print losses and accuracies for the epoch
#     epoch_train_loss = np.mean([history['loss'][0] for history in worker_histories])
#     epoch_val_loss = np.mean([history['val_loss'][0] for history in worker_histories])
#     epoch_train_acc = np.mean([history['accuracy'][0] for history in worker_histories])
#     epoch_val_acc = np.mean([history['val_accuracy'][0] for history in worker_histories])
#     print(f'Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}')
#     print(f'Train Accuracy: {epoch_train_acc:.4f}, Val Accuracy: {epoch_val_acc:.4f}')

# print("# 0.0 Training the models (fedML (none))")
# raw_fedML_model = global_model_and_test(models)  # Calculate the global model

# print(f'0.1 Post-training pruning without fine-tuning')
# post_pruned_models = [prune_model(model, ptype="ConstantSparsity") for model in models]
# postTrain_pruning_fedML_model = global_model_and_test(post_pruned_models)

# EPOCHS_finetune = 1
# fine_tune_callbacks = [
#     tfmot.sparsity.keras.UpdatePruningStep(),
#     tfmot.sparsity.keras.PruningSummaries(log_dir='./pruning_logs')
# ]
# fine_tuned_pruned_models = post_pruned_models

# # Fine-tune the pruned raw models
# fedTrainTime = 0
# for epoch in range(EPOCHS_finetune):
#     fedTrainTime1 = 0
#     worker_histories = []
#     print(f'Epoch {epoch+1}/{EPOCHS_finetune}')

#     for i in range(N_WORKERS):
#         stime = time.time()
#         history = fine_tuned_pruned_models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE, callbacks=fine_tune_callbacks)
#         etime = time.time()
#         fedTrainTime1 += etime - stime

#         if VERBOSE: print(f"Worker {i+1} trained in {etime - stime:.8f}s")
#         if VERBOSE > 2: print(f'{i}th model check: {check_quantization_type(fine_tuned_pruned_models[i])}')
#         if VERBOSE > 1: print(f'Precision of {i}th worker model: {fine_tuned_pruned_models[i].get_weights()[0].dtype}')
#         if VERBOSE > 3: print(f'Weights of {i}th worker model are: \n {fine_tuned_pruned_models[i].get_weights()}')

#         worker_histories.append(history.history)
#         worker_train_losses[i].append(history.history['loss'][0])
#         worker_val_losses[i].append(history.history['val_loss'][0])
#         worker_train_accuracies[i].append(history.history['accuracy'][0])
#         worker_val_accuracies[i].append(history.history['val_accuracy'][0])

#     # Collect and average the weights
#     stime = time.time()
#     new_weights = [model.get_weights() for model in fine_tuned_pruned_models]
#     avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
#     for model in fine_tuned_pruned_models:
#         model.set_weights(avg_weights)
#     etime = time.time()

#     fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

#     if VERBOSE > 1: print(f"Worker weights updated in {etime - stime:.8f}s")
#     if VERBOSE > 2: print(f'Precision of avg model: {avg_weights[0].dtype}')
#     if VERBOSE > 3: print(f'Weights of avg model are: \n {avg_weights}')

#     # Visualize weights after every VISUALIZE_WEIGHT_AFTER epochs
#     if (epoch + 1) % VISUALIZE_WEIGHT_AFTER == 0:
#         for i, model in enumerate(fine_tuned_pruned_models):
#             worker_weights[i].append(model.get_weights())

#     # Print losses and accuracies for the epoch
#     epoch_train_loss = np.mean([history['loss'][0] for history in worker_histories])
#     epoch_val_loss = np.mean([history['val_loss'][0] for history in worker_histories])
#     epoch_train_acc = np.mean([history['accuracy'][0] for history in worker_histories])
#     epoch_val_acc = np.mean([history['val_accuracy'][0] for history in worker_histories])
#     print(f'Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}')
#     print(f'Train Accuracy: {epoch_train_acc:.4f}, Val Accuracy: {epoch_val_acc:.4f}')

# # Evaluate fine-tuned pruned models
# print("# 0.2 Post-training pruning with fine-tuning")
# fineTuned_postTrain_pruning_fedML_model = global_model_and_test(post_pruned_models)

# # Training the models with pruning-aware training
# fedTrainTime = 0
# for epoch in range(EPOCHS):
#     fedTrainTime1 = 0
#     worker_histories = []
#     print(f'Epoch {epoch+1}/{EPOCHS}')

#     for i in range(N_WORKERS):
#         stime = time.time()
#         history = pruned_models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
#         etime = time.time()
#         fedTrainTime1 += etime - stime

#         if VERBOSE: print(f"Worker {i+1} trained in {etime - stime:.8f}s")
#         if VERBOSE > 2: print(f'{i}th model check: {check_quantization_type(pruned_models[i])}')
#         if VERBOSE > 1: print(f'Precision of {i}th worker model: {pruned_models[i].get_weights()[0].dtype}')
#         if VERBOSE > 3: print(f'Weights of {i}th worker model are: \n {pruned_models[i].get_weights()}')

#         worker_histories.append(history.history)
#         worker_train_losses[i].append(history.history['loss'][0])
#         worker_val_losses[i].append(history.history['val_loss'][0])
#         worker_train_accuracies[i].append(history.history['accuracy'][0])
#         worker_val_accuracies[i].append(history.history['val_accuracy'][0])

#     # Collect and average the weights
#     stime = time.time()
#     new_weights = [model.get_weights() for model in pruned_models]
#     avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
#     for model in pruned_models:
#         model.set_weights(avg_weights)
#     etime = time.time()

#     fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

#     if VERBOSE > 1: print(f"Worker weights updated in {etime - stime:.8f}s")
#     if VERBOSE > 2: print(f'Precision of avg model: {avg_weights[0].dtype}')
#     if VERBOSE > 3: print(f'Weights of avg model are: \n {avg_weights}')

#     # Visualize weights after every VISUALIZE_WEIGHT_AFTER epochs
#     if (epoch + 1) % VISUALIZE_WEIGHT_AFTER == 0:
#         for i, model in enumerate(pruned_models):
#             worker_weights[i].append(model.get_weights())

#     # Print losses and accuracies for the epoch
#     epoch_train_loss = np.mean([history['loss'][0] for history in worker_histories])
#     epoch_val_loss = np.mean([history['val_loss'][0] for history in worker_histories])
#     epoch_train_acc = np.mean([history['accuracy'][0] for history in worker_histories])
#     epoch_val_acc = np.mean([history['val_accuracy'][0] for history in worker_histories])
#     print(f'Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}')
#     print(f'Train Accuracy: {epoch_train_acc:.4f}, Val Accuracy: {epoch_val_acc:.4f}')

# # Calculate the global pruneTraining_fedML_model
# print("# 1.0 Training the models with pruning aware training")
# pruneTraining_fedML_model = global_model_and_test(pruned_models)

# pruned_fine_tune_callbacks = [
#     tfmot.sparsity.keras.UpdatePruningStep(),
#     tfmot.sparsity.keras.PruningSummaries(log_dir='./pruning_logs')
# ]
# fine_tuned_prune_aware_models = pruned_models

# # Fine-tune the pruned models
# fedTrainTime = 0
# for epoch in range(EPOCHS_finetune):
#     fedTrainTime1 = 0
#     worker_histories = []
#     print(f'Epoch {epoch+1}/{EPOCHS_finetune}')

#     for i in range(N_WORKERS):
#         stime = time.time()
#         history = fine_tuned_prune_aware_models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE, callbacks=pruned_fine_tune_callbacks)
#         etime = time.time()
#         fedTrainTime1 += etime - stime

#         if VERBOSE: print(f"Worker {i+1} trained in {etime - stime:.8f}s")
#         if VERBOSE > 2: print(f'{i}th model check: {check_quantization_type(fine_tuned_prune_aware_models[i])}')
#         if VERBOSE > 1: print(f'Precision of {i}th worker model: {fine_tuned_prune_aware_models[i].get_weights()[0].dtype}')
#         if VERBOSE > 3: print(f'Weights of {i}th worker model are: \n {fine_tuned_prune_aware_models[i].get_weights()}')

#         worker_histories.append(history.history)
#         worker_train_losses[i].append(history.history['loss'][0])
#         worker_val_losses[i].append(history.history['val_loss'][0])
#         worker_train_accuracies[i].append(history.history['accuracy'][0])
#         worker_val_accuracies[i].append(history.history['val_accuracy'][0])

#     # Collect and average the weights
#     stime = time.time()
#     new_weights = [model.get_weights() for model in fine_tuned_prune_aware_models]
#     avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
#     for model in fine_tuned_prune_aware_models:
#         model.set_weights(avg_weights)
#     etime = time.time()

#     fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

#     if VERBOSE > 1: print(f"Worker weights updated in {etime - stime:.8f}s")
#     if VERBOSE > 2: print(f'Precision of avg model: {avg_weights[0].dtype}')
#     if VERBOSE > 3: print(f'Weights of avg model are: \n {avg_weights}')

#     # Visualize weights after every VISUALIZE_WEIGHT_AFTER epochs
#     if (epoch + 1) % VISUALIZE_WEIGHT_AFTER == 0:
#         for i, model in enumerate(fine_tuned_prune_aware_models):
#             worker_weights[i].append(model.get_weights())

#     # Print losses and accuracies for the epoch
#     epoch_train_loss = np.mean([history['loss'][0] for history in worker_histories])
#     epoch_val_loss = np.mean([history['val_loss'][0] for history in worker_histories])
#     epoch_train_acc = np.mean([history['accuracy'][0] for history in worker_histories])
#     epoch_val_acc = np.mean([history['val_accuracy'][0] for history in worker_histories])
#     print(f'Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}')
#     print(f'Train Accuracy: {epoch_train_acc:.4f}, Val Accuracy: {epoch_val_acc:.4f}')

# # Evaluate fine-tuned pruned models
# print("# 1.2 Training the models with pruning aware training with finetuning")
# fineTuned_postTrain_pruning_fedML_model = global_model_and_test(fine_tuned_prune_aware_models)

# # Restore the original policy
# tf.keras.mixed_precision.set_global_policy(None)


### gg

In [None]:
import warnings
warnings.filterwarnings("ignore")

import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam
import tensorflow_model_optimization as tfmot

# Enable mixed precision policy if needed
enable_mixed_precision_policy = True
if enable_mixed_precision_policy:
    policy = tf.keras.mixed_precision.Policy('mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)


def compile_model(model):
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
def create_model(precision: str = 'float32'):
    if precision == 'float32':
        precision = tf.float32
        model = keras.Sequential([
            keras.layers.InputLayer(input_shape=(28, 28, 1)),
            keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Flatten(),
            keras.layers.Dense(64, activation='relu'),
            keras.layers.Dense(10, activation='softmax')
        ])
    elif precision == 'float16':
        precision = tf.float16
        model = models.Sequential([
            layers.InputLayer(input_shape=(28, 28)),
            layers.Reshape(target_shape=(28, 28, 1)),
            layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu', dtype=precision),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dense(64, dtype=precision),
            layers.Dense(10, dtype=precision)
        ])
    else:
        pass

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

def prune_model(model, ptype:str = "PolynomialDecay",
                initial_sparsity = 0.0, final_sparsity = 0.5,
                begin_step = 0, end_step = -1):
    if ptype == "PolynomialDecay":
        # Define a pruning schedule using PolynomialDecay
        pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0,  # initial_sparsity: float, the initial level of sparsity (0.0 means no pruning at the start).
            final_sparsity=0.5,    # final_sparsity: float, the target level of sparsity (0.5 means 50% of weights will be pruned).
            begin_step=0,          # begin_step: int, the step at which pruning begins (0 means pruning starts at the beginning).
            end_step=np.ceil(x_train.shape[0] / 64).astype(np.int32) * EPOCHS  # end_step: int, the step at which pruning ends, calculated based on total training steps.
            # Step at which to end pruning. -1 by default. -1 implies continuing to prune till the end of training.
        )

        # Apply pruning to the model using the defined pruning schedule
        model = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)

    elif ptype == "ConstantSparsity":
        # Define a pruning schedule using ConstantSparsity
        pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
            target_sparsity=0.5,    # target_sparsity: float, the desired level of sparsity in the model's weights.
                                    # 0.5 means that 50% of the weights will be pruned (set to zero).
                                    # For example, if set to 0.8, 80% of the weights will be pruned.

            begin_step=0,           # begin_step: int, the step at which pruning begins.
                                    # 0 means pruning starts from the very beginning of training.
                                    # For example, if set to 1000, pruning will start after 1000 training steps.

            frequency=1             # frequency: int, the frequency (in number of steps) with which pruning is applied.
                                    # 100 means the pruning function will be applied every 100 steps during training.
                                    # This allows for gradual pruning, helping the model adapt to the reduced weights.
        )
        }
        model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

    else:
        pass

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

def global_model_and_test(models):
    new_weights = [model.get_weights() for model in models]
    avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
    # for model in models:
    #     model.set_weights(avg_weights)

    model = models[0]
    model.compile(optimizer='adam',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    model.set_weights(avg_weights)

    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
    print(f'Test accuracy: {test_acc}')

    return model


# # 0.0 Training then pruning    (Post - Train Pruning)
# # 0.1 without fine-tuning     :: post-training with pruning without fine-tuning
# # 0.2 with fine-tuning        :: post-training with pruning with fine-tuning

# # 1. Pruning aware training  (Pruning aware Training)
# # 1.0 without fine-tuning     :: pre-training with pruning without fine-tuning
### 1.2 with fine-tuning        :: pre-training with pruning with    fine-tuning


# Initialize models for each worker
models = [create_model() for _ in range(N_WORKERS)]
pruned_models = [prune_model(model) for model in models]

# 0.0 Training the models (fedML (none))
fedTrainTime = 0
for epoch in range(EPOCHS):
    fedTrainTime1 = 0
    worker_histories = []
    print(f'Epoch {epoch+1}/{EPOCHS}')

    for i in range(N_WORKERS):
        stime = time.time()
        history = models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE)
        etime = time.time()
        fedTrainTime1 += etime - stime

        if VERBOSE: print(f"Worker {i+1} trained in {fedTrainTime1:.8f}s")
        worker_histories.append(history.history)

    # Collect and average the weights
    stime = time.time()
    new_weights = [model.get_weights() for model in models]
    avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
    for model in models:
        model.set_weights(avg_weights)
    etime = time.time()

    fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

    if VERBOSE: print(f"Worker weights updated in {fedTrainTime:.8f}s")

print("# 0.0 Training the models (fedML (none))")
raw_fedML_model = global_model_and_test(models)        # :: Calculate the global model
history = models[i].fit(x_train_splits[i][:8], y_train_splits[i][:8], epochs=EPOCHS_WITHIN, batch_size=8, validation_split=0.0, verbose=VERBOSE)
# print(f"For training model with {N_WORKERS} workers for {EPOCHS} epochs, time taken is {etime-stime:.4f}s")


print(f'0.1 Post-training pruning without fine-tuning')
post_pruned_models = [prune_model(model,ptype= "ConstantSparsity") for model in models]
postTrain_pruning_fedML_model = global_model_and_test(post_pruned_models)


EPOCHS_finetune = 1
fine_tune_callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='./pruning_logs')
    ]
fine_tuned_pruned_models = post_pruned_models # [prune_model(model, ptype = "ConstantSparsity") for model in models]

# Fine-tune the pruned raw-models
fedTrainTime = 0
for epoch in range(EPOCHS_finetune):
    fedTrainTime1 = 0
    worker_histories = []
    print(f'Epoch {epoch+1}/{EPOCHS}')

    for i in range(N_WORKERS):
        stime = time.time()
        history = fine_tuned_pruned_models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE, callbacks=fine_tune_callbacks)
        etime = time.time()
        fedTrainTime1 += etime - stime

        if VERBOSE: print(f"Worker {i+1} trained in {fedTrainTime1:.8f}s")
        worker_histories.append(history.history)

    # Collect and average the weights
    stime = time.time()
    new_weights = [model.get_weights() for model in fine_tuned_pruned_models]
    avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
    for model in fine_tuned_pruned_models:
        model.set_weights(avg_weights)
    etime = time.time()

    fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

    if VERBOSE: print(f"Worker weights updated in {fedTrainTime:.8f}s")

# Evaluate fine-tuned pruned models
print("# 0.2 Post-training pruning with fine-tuning")
fineTuned_postTrain_pruning_fedML_model = global_model_and_test(post_pruned_models)





# Training the models with pruning aware training
fedTrainTime = 0
for epoch in range(EPOCHS):
    fedTrainTime1 = 0
    worker_histories = []
    print(f'Epoch {epoch+1}/{EPOCHS}')

    for i in range(N_WORKERS):
        stime = time.time()
        history = pruned_models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE,  callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
        etime = time.time()
        fedTrainTime1 += etime - stime

        if VERBOSE: print(f"Worker {i+1} trained in {fedTrainTime1:.8f}s")
        worker_histories.append(history.history)

    # Collect and average the weights
    stime = time.time()
    new_weights = [model.get_weights() for model in pruned_models]
    avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
    for model in pruned_models:
        model.set_weights(avg_weights)
    etime = time.time()

    fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

    if VERBOSE: print(f"Worker weights updated in {fedTrainTime:.8f}s")

# Calculate the global pruneTraining_fedML_model
print("# 1.0 Training the models with pruning aware training")
pruneTraining_fedML_model = global_model_and_test(pruned_models)
# print(f"For training model with {N_WORKERS} workers for {EPOCHS} epochs, time taken is {etime-stime:.4f}s")


pruned_fine_tune_callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='./pruning_logs')
    ]
fine_tuned_prune_aware_models = pruned_models # [prune_model(model, ptype = "ConstantSparsity") for model in models]

# Fine-tune the pruned raw-models
fedTrainTime = 0
for epoch in range(EPOCHS_finetune):
    fedTrainTime1 = 0
    worker_histories = []
    print(f'Epoch {epoch+1}/{EPOCHS}')

    for i in range(N_WORKERS):
        stime = time.time()
        history = fine_tuned_prune_aware_models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE, callbacks=pruned_fine_tune_callbacks)
        etime = time.time()
        fedTrainTime1 += etime - stime

        if VERBOSE: print(f"Worker {i+1} trained in {fedTrainTime1:.8f}s")
        worker_histories.append(history.history)

    # Collect and average the weights
    stime = time.time()
    new_weights = [model.get_weights() for model in fine_tuned_prune_aware_models]
    avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
    for model in fine_tuned_prune_aware_models:
        model.set_weights(avg_weights)
    etime = time.time()

    fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

    if VERBOSE: print(f"Worker weights updated in {fedTrainTime:.8f}s")

# Evaluate fine-tuned pruned models
print("# 1.2 Training the models with pruning aware training with finetuning")
fineTuned_postTrain_pruning_fedML_model = global_model_and_test(fine_tuned_prune_aware_models)






# Restore the original policy
tf.keras.mixed_precision.set_global_policy(None)

In [None]:
def compare_model_weights(model1, model2):
    layers1 = model1.layers
    layers2 = model2.layers

    for layer1, layer2 in zip(layers1, layers2):
        weights1 = layer1.get_weights()
        weights2 = layer2.get_weights()
        print(f"Layer {layer1.name} / {layer2.name}")
        for i, (w1, w2) in enumerate(zip(weights1, weights2)):
            are_weights_equal = np.array_equal(w1, w2)
            print(f"  Weight {i + 1}:")
            print(f"    Are weights equal? {are_weights_equal}")

# Compare weights of the two models
compare_model_weights(models[-1], pruned_models[-1])

# [ 0.          0.          0.         -0.30790305  0.19578068
#      0.23647237  0.19294487  0.          0.14917755  0.
#      0.         -0.48902744  0.1838266  -0.46426666  0.
#      0.          0.1348818   0.01700406  0.         -0.19746181
#      0.          0.23833129  0.1361009  -0.18042842  0.2017729
#      0.          0.         -0.42712042  0.3622789   0.
#     -0.23073876  0.        ]

# [ 0.          0.          0.         -0.30790305  0.19578068
#      0.23647237  0.19294487  0.          0.14917755  0.
#      0.         -0.48902744  0.1838266  -0.46426666  0.
#      0.          0.1348818   0.01700406  0.         -0.19746181
#      0.          0.23833129  0.1361009  -0.18042842  0.2017729
#      0.          0.         -0.42712042  0.3622789   0.
#     -0.23073876  0.        ]

In [None]:
import warnings
warnings.filterwarnings("ignore")

import time
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam
from tensorflow_model_optimization.python.core.keras.compat import keras

enable_mixed_precision_policy = False
if enable_mixed_precision_policy:
    policy = tf.keras.mixed_precision.Policy('mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)

def create_model(precision:str = 'float32'):
    if precision == 'float16':
        model = models.Sequential([
            layers.InputLayer(input_shape=(28, 28)),
            layers.Reshape(target_shape=(28, 28, 1)),
            layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu', dtype = 'float16'),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dense(64, dtype='float16'),  # Ensure Dense layer uses float32 for variables
            layers.Dense(10, dtype='float16')   # Ensure Dense layer uses float32 for variables
        ])

    else:
        model = models.Sequential([
            layers.InputLayer(input_shape=(28, 28)),
            layers.Reshape(target_shape=(28, 28, 1)),
            layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dense(64),  # Ensure Dense layer uses float32 for variables
            layers.Dense(10)   # Ensure Dense layer uses float32 for variables
        ])

    model.compile(optimizer=Adam(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model

# Initialize lists to store weights, training and validation losses and accuracies for each worker
worker_weights, worker_train_losses, worker_val_losses, worker_train_accuracies, worker_val_accuracies = ([[] for _ in range(N_WORKERS)] for _ in range(5))


# Initialize and quantize models for each worker
models = [create_model() for _ in range(N_WORKERS)]
print(models[0].summary())

fedTrainTime = 0
for epoch in range(EPOCHS):
    fedTrainTime1 = 0
    worker_histories = []
    print(f'Epoch {epoch+1}/{EPOCHS}')

    # Train each worker's model for {EPOCHS_WITHIN} epoch
    for i in range(N_WORKERS):
        stime = time.time()
        history = models[i].fit(x_train_splits[i], y_train_splits[i], epochs=EPOCHS_WITHIN, batch_size=32, validation_split=0.1, verbose=VERBOSE)
        etime = time.time()
        fedTrainTime1 += etime - stime

        if (VERBOSE) : print(f"Worker {i+1} trained in {fedTrainTime1:.8f}s")
        if (VERBOSE > 2) : print(f'{i}th model check : {check_quantization_type(models[i])}')
        if (VERBOSE > 1) : print(f'Presicion of {i} th worker model {models[i].get_weights()[0].dtype}')
        if (VERBOSE > 3) : print(f'Weights of {i} th worker model are :: \n {models[i].get_weights()}')

        worker_histories.append(history.history)
        worker_train_losses[i].append(history.history['loss'][0])
        worker_val_losses[i].append(history.history['val_loss'][0])
        worker_train_accuracies[i].append(history.history['accuracy'][0])
        worker_val_accuracies[i].append(history.history['val_accuracy'][0])

    # Collect and average the weights
    stime = time.time()
    new_weights = [model.get_weights() for model in models]
    avg_weights = [np.mean([new_weights[j][k] for j in range(N_WORKERS)], axis=0) for k in range(len(new_weights[0]))]
    for model in models:
        model.set_weights(avg_weights)
    etime = time.time()

    fedTrainTime += (fedTrainTime1 / N_WORKERS) + etime - stime

    if VERBOSE > 1 : print(f"Worker weights updated in {fedTrainTime:.8f}s")
    if VERBOSE > 2 : print(f'Presicion of avg model {avg_weights[0].dtype}')
    if VERBOSE > 3 : print(f'Weights of avg model are :: \n {avg_weights}')

    # Visualize weights after every VISUALIZE_WEIGHT_AFTER epochs
    if (epoch + 1) % VISUALIZE_WEIGHT_AFTER == 0:
        for i, model in enumerate(models):
            worker_weights[i].append(model.get_weights())

    # Print losses and accuracies for the epoch
    epoch_train_loss = np.mean([history['loss'][0] for history in worker_histories])
    epoch_val_loss = np.mean([history['val_loss'][0] for history in worker_histories])
    epoch_train_acc = np.mean([history['accuracy'][0] for history in worker_histories])
    epoch_val_acc = np.mean([history['val_accuracy'][0] for history in worker_histories])
    print(f'Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}')
    print(f'Train Accuracy: {epoch_train_acc:.4f}, Val Accuracy: {epoch_val_acc:.4f}')

# # Restore the original policy
tf.keras.mixed_precision.set_global_policy(None)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

# Calculate the global model
final_model = create_model()
final_model.compile(optimizer=Adam(),
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])
final_model.set_weights(avg_weights)

test_loss, test_acc = final_model.evaluate(x_test, y_test, verbose=0)
print(f'Test accuracy: {test_acc}')

print(f"For training model with {N_WORKERS} workers for {EPOCHS} epochs, time taken is {etime-stime:.4f}s")


policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

In [None]:
check_quantization_type(final_model)

In [None]:
#  Calculate the global model
final_model = create_model()
final_model.compile(optimizer=Adam(),
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])
final_model.set_weights(avg_weights)
check_quantization_type(final_model)

In [None]:
import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

# Function to prepare data for plotting
def prepare_data_for_plotting(worker_metrics, metric_name, ttype="Training"):
    data = []
    for i in range(N_WORKERS):
        for epoch in range(EPOCHS):
            data.append([epoch + 1, worker_metrics[i][epoch], f'Worker {i+1}', f'{ttype} {metric_name}'])
    return pd.DataFrame(data, columns=['Epoch', metric_name, 'Worker', 'Type'])

# Prepare data for plotting
trn_loss_df = prepare_data_for_plotting(worker_train_losses, 'Loss')
val_loss_df = prepare_data_for_plotting(worker_val_losses, 'Loss', ttype="Validation")

trn_accuracy_df = prepare_data_for_plotting(worker_train_accuracies, 'Accuracy')
val_accuracy_df = prepare_data_for_plotting(worker_val_accuracies, 'Accuracy', ttype="Validation")

# Combine the data for losses and accuracies
loss_df = pd.concat([trn_loss_df, val_loss_df])
accuracy_df = pd.concat([trn_accuracy_df, val_accuracy_df])

# Plot losses
plt.figure(figsize=(14, 8))
sns.lineplot(data=loss_df, x='Epoch', y='Loss', hue='Worker', style='Type', markers=True, dashes=False)
plt.title('Training and Validation Losses for Each Worker')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.show()

# Plot accuracies
plt.figure(figsize=(14, 8))
sns.lineplot(data=accuracy_df, x='Epoch', y='Accuracy', hue='Worker', style='Type', markers=True, dashes=False)
plt.title('Training and Validation Accuracies for Each Worker')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Print metrics
print(f'Test accuracy: {test_acc}')
print('Confusion Matrix:')
print(confusion_matrix(y_test, y_pred_classes))
print('Classification Report:')
print(classification_report(y_test, y_pred_classes))


In [None]:
def mean_axis_0(lst):
    num_rows = len(lst)
    num_cols = len(lst[0]) if lst else 0
    if num_cols == 0:
        return []
    min_len = min(len(row) for row in lst)
    col_means = []
    for j in range(min_len):
        col_sum = sum(lst[i][j] for i in range(num_rows))
        col_mean = col_sum / num_rows
        col_means.append(col_mean)
    return col_means , min_len


def plot_weight_distributions_across_epochs(worker_weights, epochs, n_sigma_away=1, bucket_size=0.001):
    num_workers = len(worker_weights)
    cols = (num_workers)+1  # Number of columns for subplots

    for epoch in range(epochs):
        fig, axes = plt.subplots(1, cols, figsize=(cols * 5, 5))  # Adjust figsize for wider graphs
        fig.suptitle(f"PDF Distribution of Weights - Epoch {epoch + 1}", fontsize=32)

        store_hist = []
        for worker_id, weights_history in enumerate(worker_weights):
            # Check if there are any epochs to plot
            if epoch >= len(weights_history):
                print(f"No epochs to plot for Worker {worker_id + 1}")
                continue

            flattened_weights = np.concatenate([w.flatten() for w in weights_history[epoch]])

            # Remove weights that are n_sigma_away from the mean
            mean_weight = np.mean(flattened_weights)
            std_weight = np.std(flattened_weights)
            filtered_weights = flattened_weights[np.abs(flattened_weights - mean_weight) <= n_sigma_away * std_weight]

            min_weight = min(filtered_weights)
            max_weight = max(filtered_weights)

            bins = np.arange(min_weight, max_weight, bucket_size)
            hist, _ = np.histogram(filtered_weights, bins=bins, density=True)

            store_hist.append(np.array(hist))

            ax = axes[worker_id]
            sns.histplot(filtered_weights, bins=bins, kde=True, ax=ax)
            ax.set_title(f"Worker {worker_id + 1}")
            ax.set_xlabel("Weight Value (ignore values)")
            ax.set_ylabel("Density")

        avg_hist, min_len = mean_axis_0(store_hist)
        diff_hist = [hist[:min_len] - avg_hist for hist in store_hist]
#         diff_hist = [hist[:min_len] - store_hist[-1][:min_len] for hist in store_hist]


        plt.plot(avg_hist)
        plt.title(f"Avg Worker")
        plt.xlabel("Weight Value")
        plt.ylabel("Density")
        plt.show()

        for i in range(len(diff_hist)):
            diff_hist[i] = [2000*j for j in diff_hist[i]]

        plt.figure(figsize=(20, 10))  # Adjust the figure size as needed
        for i, sublist in enumerate(diff_hist):
            x = np.arange(len(sublist))
            y = sublist
            plt.scatter(x, y, label=f'Worker {i+1}')
            plt.title(f"Dot plot of Difference of Worker{i}")
            plt.xlabel("Weight Value (ignore values)")
            plt.ylabel("Density")
            plt.legend()


        for i in range(len(diff_hist)):
            plt.plot(diff_hist[i])
            plt.title(f"Difference of Worker to the average")
            plt.xlabel("Weight Value (ignore values)")
            plt.ylabel("Density")
            plt.legend()

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        plt.show()

#         break
#         return avg_hist, store_hist

# Call the function to plot weight distributions across epochs
plot_weight_distributions_across_epochs(worker_weights, EPOCHS)


In [None]:
print(models[0].summary())

# Iterate over each epoch
for epoch in range(len(worker_weights[0])):
    # Iterate over each layer
    for layer_id in range(len(worker_weights[0][0])):
        # Initialize an empty list to store flattened weights of the current layer for all workers
        layer_weights_all_workers = []

        # Collect weights of the current layer for all workers at the current epoch
        for worker_id, weights_history in enumerate(worker_weights):
            # Append the flattened weights of the current layer for the current worker and epoch
            layer_weights_all_workers.append(weights_history[epoch][layer_id].flatten())

        # Plot the histograms of weights of the current layer for all workers in one figure
        plt.figure(figsize=(10, 5))
        plt.hist(layer_weights_all_workers, bins=50, label=[f'Worker {i+1}' for i in range(N_WORKERS)], alpha=0.7)
        plt.title(f'Epoch {epoch + 1}, Layer {layer_id + 1} - Distribution of Weights')
        plt.xlabel('Weight Value')
        plt.ylabel('Frequency')
        plt.legend()
        plt.show()

In [None]:
def plot_weight_distribution(weights_history, worker_id, n_sigma_away=1, bucket_size=0.001):
    num_epochs = len(weights_history)

    # Check if there are any epochs to plot
    if num_epochs == 0:
        print(f"No epochs to plot for Worker {worker_id + 1}")
        return

    cols = N_WORKERS  # Number of columns for subplots
    rows = num_epochs  # Calculate the number of rows needed

    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5))  # Adjust figsize for wider graphs
    fig.suptitle(f"PDF Distribution of Weights for Worker {worker_id + 1}", fontsize=32)

    bin_values = []

    for epoch, weights in enumerate(weights_history):
        flattened_weights = np.concatenate([w.flatten() for w in weights])

        # Remove weights that are 3 sigma away from the mean
        mean_weight = np.mean(flattened_weights)
        std_weight = np.std(flattened_weights)
        filtered_weights = flattened_weights[np.abs(flattened_weights - mean_weight) <= n_sigma_away * std_weight]

        min_weight = min(filtered_weights)
        max_weight = max(filtered_weights)

        bins = np.arange(min_weight, max_weight, bucket_size)
        hist, _ = np.histogram(filtered_weights, bins=bins, density=True)

        bin_values.append(hist)

        ax = axes[epoch // cols, epoch % cols]
        sns.histplot(filtered_weights, bins=bins, kde=True, ax=ax)
        ax.set_title(f"Epoch {epoch + 1}")
        ax.set_xlabel("Weight Value")
        ax.set_ylabel("Density")

    # Remove any empty subplots
    for i in range(num_epochs, rows * cols):
        fig.delaxes(axes.flatten()[i])

    plt.tight_layout(rect=[0, 0, 1, 1])
    plt.show()

    return bin_values

histograms_array = [[] for i in range(EPOCHS)]
# Visualize weights for each worker
for worker_id, weights_history in enumerate(worker_weights):
    tmp2 = [i for i in plot_weight_distribution(weights_history, worker_id)]
    for i in range(EPOCHS):
        histograms_array[i].append((tmp2[i]))

In [None]:
# Full plot with sigma away ; Low visuality

# Define function to remove outliers and plot PDF distribution of weights
def plot_weight_distribution(weights_history, worker_id, bucket_size=0.001):
    num_epochs = len(weights_history)
    cols = N_WORKERS  # Number of columns for subplots
    rows = num_epochs  # Calculate the number of rows needed

    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5))  # Adjust figsize for wider graphs
    fig.suptitle(f"PDF Distribution of Weights for Worker {worker_id + 1}", fontsize=16)

    for epoch, weights in enumerate(weights_history):
        flattened_weights = np.concatenate([w.flatten() for w in weights])

        # Remove weights that are 3 sigma away from the mean
        mean_weight = np.mean(flattened_weights)
        std_weight = np.std(flattened_weights)
        filtered_weights = flattened_weights[np.abs(flattened_weights - mean_weight) <= 3 * std_weight]

        min_weight = min(filtered_weights)
        max_weight = max(filtered_weights)

        bins = np.arange(min_weight, max_weight, bucket_size)

        ax = axes[epoch // cols, epoch % cols]
        sns.histplot(filtered_weights, bins=bins, kde=True, ax=ax)
        ax.set_title(f"Epoch {epoch + 1}")
        ax.set_xlabel("Weight Value")
        ax.set_ylabel("Density")

    # Remove any empty subplots
    for i in range(num_epochs, rows * cols):
        fig.delaxes(axes.flatten()[i])

    plt.tight_layout(rect=[0, 0, 1, 1])
    plt.show()

# Visualize weights for each worker
for worker_id, weights_history in enumerate(worker_weights):
    plot_weight_distribution(weights_history, worker_id)
#     break  # Remove this line if you want to visualize for all workers
