<a href="https://colab.research.google.com/github/ChrisW2420/FedDistill/blob/main/Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Prototype

## Import packages

In [1]:
!pip install -q tensorflow-model-optimization

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━[0m [32m143.4/242.5 kB[0m [31m4.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.sparsity import keras as sparsity
import tf_keras as keras
import tempfile
from keras.callbacks import EarlyStopping, Callback

TODO: add wandb to all experiments

In [None]:
!pip install wandb
import wandb
wandb.login()
from wandb.keras import WandbMetricsLogger

## Prepare Dataset

In [4]:
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
validation_split = 0.1

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


## Models

In [5]:
def smallCNN():
  model = keras.Sequential(
      [
          keras.Input(shape=(28, 28, 1)),
          keras.layers.Conv2D(8, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(8, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.Flatten(),
          keras.layers.Dense(10),
      ],
      name="smallcnn",
  )
  return model

def mediumCNN():
  model = keras.Sequential(
      [
          keras.Input(shape=(28, 28, 1)),
          keras.layers.Conv2D(8, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.Flatten(),
          keras.layers.Dense(10),
      ],
      name="mediumcnn",
  )
  return model

def bigCNN():
  model = keras.Sequential(
      [
          keras.Input(shape=(28, 28, 1)),
          keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.Flatten(),
          keras.layers.Dense(10),
      ],
      name="bigcnn",
  )
  return model

In [11]:
# early stopping when training converges on validation loss
early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,  # only consider as improvement significant changes
    patience=2,      # number of epochs with no improvement after which training will be stopped
    verbose=1,
    mode='min'        # 'min' because we want to minimize the loss
)

def trainCNN(model, _epoch, x_train = x_train, y_train = y_train):
  model.compile(
      optimizer='adam',
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()]
  )

  model.fit(x_train, y_train, batch_size=batch_size, epochs=_epoch,validation_split=validation_split, callbacks=[early_stopping])
  model.evaluate(x_test, y_test)

  return model

  # _, pretrained_weights = tempfile.mkstemp('.tf')

  # model.save_weights(pretrained_weights)

  # return pretrained_weights

#Pruning

### Basic implementation

In [15]:
# functions
def prune_finetrain(base_model, _epochs, target_sparsity = 0.5, x_train = x_train, y_train = y_train):
  callbacks = [
      sparsity.UpdatePruningStep(),
      early_stopping
  ]
  steps_per_epoch = len(x_train)*(1-validation_split) // batch_size
  pruning_schedule = sparsity.PolynomialDecay(initial_sparsity=0, final_sparsity=target_sparsity,
                                              begin_step=0, end_step=int(steps_per_epoch*_epochs)) # increase sparsity

  model_for_pruning = sparsity.prune_low_magnitude(base_model) #default constant sparsity of 50%
  model_for_pruning.summary()

  model_for_pruning.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()]
  )

  model_for_pruning.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      validation_split=validation_split,
      callbacks=callbacks,
      epochs=_epochs,
  )

  return model_for_pruning

def get_model_sparsity(model):
    total_weights = 0
    zero_weights = 0
    for weight in model.get_weights():
        total_weights += weight.size
        zero_weights += np.count_nonzero(weight == 0)
    return zero_weights / total_weights

def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5')
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)

In [None]:
# training base model

model = mediumCNN()
model = trainCNN(model, 2)
_, pretrained_weights = tempfile.mkstemp('.tf')
model.save_weights(pretrained_weights)

# pruning
base_model = mediumCNN()
base_model.load_weights(pretrained_weights) # optional but recommended.
pruned_model = prune_finetrain(base_model, 5)

# continue training base model for performance comparison
base_model_copy = mediumCNN()
base_model_copy.load_weights(pretrained_weights)
base_model_copy = trainCNN(base_model_copy, 5)

Epoch 1/2
Epoch 2/2
Model: "mediumcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 14, 14, 8)         154       
 _3 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_re_lu_  (None, 14, 14, 8)         1         
 2 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_max_po  (None, 14, 14, 8)         1         
 oling2d_2 (PruneLowMagnitu                                      
 de)                                                             
                                                                 
 prune_low_magnitude_conv2d  (None, 7, 7, 16)          2322      
 _4 (PruneLowMagnitude)                                          
                                     

In [None]:
pruned_model_stripped = sparsity.strip_pruning(pruned_model)
print("final model")
pruned_model.summary()

final model
Model: "mediumcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 14, 14, 8)         154       
 _3 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_re_lu_  (None, 14, 14, 8)         1         
 2 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_max_po  (None, 14, 14, 8)         1         
 oling2d_2 (PruneLowMagnitu                                      
 de)                                                             
                                                                 
 prune_low_magnitude_conv2d  (None, 7, 7, 16)          2322      
 _4 (PruneLowMagnitude)                                          
                                             

In [None]:
# compare accuracy, sparsity and file size
print(f"Base Model test accuracy: {base_model_copy.evaluate(x_test, y_test)[1]:.2f}%")
print(f"Pruned Model test accuracy: {pruned_model.evaluate(x_test, y_test)[1]:.2f}%")

sparsity_percentage = get_model_sparsity(base_model_copy) * 100
print(f"Base Model sparsity: {sparsity_percentage:.2f}%")

sparsity_percentage = get_model_sparsity(pruned_model) * 100
print(f"Pruned Model sparsity: {sparsity_percentage:.2f}%")

print('\n')
print("Size of gzipped base model: %.2f bytes" % (get_gzipped_model_size(base_model_copy)))
print("Size of gzipped pruned model: %.2f bytes" % (get_gzipped_model_size(pruned_model_stripped)))

Base Model test accuracy: 0.98%
Pruned Model test accuracy: 0.98%
Base Model sparsity: 0.00%
Pruned Model sparsity: 49.74%




  saving_api.save_model(


Size of gzipped base model: 25706.00 bytes
Size of gzipped pruned model: 16939.00 bytes


## Pruning to as small as possible with > 99.5% accuracy as best

TODO: make this feature work:

either: target sparsity or target accuracy with max sparsity

early ending on convergence

In [None]:
class CustomPolynomialDecay(sparsity.PolynomialDecay):
    def __init__(self, initial_sparsity, final_sparsity, begin_step, end_step, power=3, frequency=100):
        super().__init__(initial_sparsity, final_sparsity, begin_step, end_step, power=power, frequency=frequency)
        self.freeze_sparsity = False
        self.current_sparsity = 0

    # def _should_prune_in_step(self, step, begin_step, end_step, frequency):
    #     print('testing')
    #     if self.freeze_sparsity:
    #         print('freezed')
    #         return False
    #     else:
    #         return sparsity.PruningSchedule._should_prune_in_step(self, step, begin_step, end_step, frequency)

    def __call__(self, step):
        print('testing')
        if self.freeze_sparsity:
          return (False, self.current_sparsity)
        else:
          _, self.current_sparsity = super().__call__(step)
          return (_, self.current_sparsity)

    def freeze(self):
        self.freeze_sparsity = True

class FreezePruningOnAccuracyDrop(Callback):
    def __init__(self, pruning_schedule, threshold=0.96):
        super().__init__()
        self.pruning_schedule = pruning_schedule
        self.threshold = threshold

    def on_batch_end(self, batch, logs=None):
        current_accuracy = logs.get('sparse_categorical_accuracy')
        if current_accuracy < self.threshold:
            print(f"\nAccuracy has dropped below {self.threshold*100:.2f}%, freezing further pruning.")
            self.pruning_schedule.freeze()
            print(self.pruning_schedule.freeze_sparsity)

In [None]:
# function
# save good performance ones
# class CustomEarlyStopping(Callback):
#     def on_batch_end(self, batch, logs=None):
#         current_accuracy = logs.get('sparse_categorical_accuracy')
#         if current_accuracy <= 0.975:
#             print(f"\nStopping training as accuracy has dropped to {current_accuracy*100:.2f}%")
#             self.model.stop_training = True

def prune_finetrain(base_model, _epochs):
  steps_per_epoch = len(x_train)*(1-validation_split) // batch_size
  print('steps_per_epoch: ', steps_per_epoch)
  # pruning_schedule = sparsity.PolynomialDecay(initial_sparsity=0.20, final_sparsity=0.70,
  #                                             begin_step=0, end_step=int(steps_per_epoch*_epochs)) # increase sparsity

  total_steps = int(_epochs * steps_per_epoch)
  pruning_schedule = CustomPolynomialDecay(
    initial_sparsity=0.0,
    final_sparsity=0.9,
    begin_step=batch_size,
    end_step=total_steps,
    power=3
  )

  model_for_pruning = sparsity.prune_low_magnitude(base_model, pruning_schedule=pruning_schedule)
  model_for_pruning.summary()


  callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir='/path/to/logs'),
    FreezePruningOnAccuracyDrop(pruning_schedule)
  ]


  model_for_pruning.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()]
  )

  model_for_pruning.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      validation_split=validation_split,
      callbacks=callbacks,
      epochs=_epochs,
  )

  return model_for_pruning

In [None]:
# adaptive pruning
base_model = mediumCNN()
base_model.load_weights(pretrained_weights) # optional but recommended.
adaptive_pruned_model = prune_finetrain(base_model, 3)

steps_per_epoch:  843.0
Model: "mediumcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 14, 14, 8)         154       
 _48 (PruneLowMagnitude)                                         
                                                                 
 prune_low_magnitude_re_lu_  (None, 14, 14, 8)         1         
 32 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_max_po  (None, 14, 14, 8)         1         
 oling2d_32 (PruneLowMagnit                                      
 ude)                                                            
                                                                 
 prune_low_magnitude_conv2d  (None, 7, 7, 16)          2322      
 _49 (PruneLowMagnitude)                                         
                                 



Epoch 2/3
  1/844 [..............................] - ETA: 9s - loss: 0.0804 - sparse_categorical_accuracy: 0.9688
Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True
Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True
Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has dropped below 96.00%, freezing further pruning.
True

Accuracy has drop

In [None]:
# compare performance between adaptive pruning and basic pruning
print(f"adaptive Model test accuracy: {adaptive_pruned_model.evaluate(x_test, y_test)[1]:.2f}%")
print(f"pruned Model test accuracy: {pruned_model.evaluate(x_test, y_test)[1]:.2f}%")

sparsity_percentage = get_model_sparsity(adaptive_pruned_model) * 100
print(f"adaptive Model sparsity: {sparsity_percentage:.2f}%")

sparsity_percentage = get_model_sparsity(pruned_model) * 100
print(f"Pruned Model sparsity: {sparsity_percentage:.2f}%")

adaptive_pruned_model = sparsity.strip_pruning(adaptive_pruned_model)
print('\n')
print("Size of gzipped adaptive model: %.2f bytes" % (get_gzipped_model_size(adaptive_pruned_model)))
print("Size of gzipped pruned model: %.2f bytes" % (get_gzipped_model_size(pruned_model_stripped)))

adaptive Model test accuracy: 0.92%
pruned Model test accuracy: 0.98%
adaptive Model sparsity: 89.55%




Pruned Model sparsity: 49.74%






Size of gzipped adaptive model: 7417.00 bytes
Size of gzipped pruned model: 16924.00 bytes


# Prune + KD

## Knowledge Distillation functions

In [18]:
class Distiller(keras.Model):
    def __init__(self, teacher, student, alpha=0.1, temperature=3, **kwargs):
        super(Distiller, self).__init__(**kwargs)
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha, temperature, **kwargs):
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics, **kwargs)
        self.student_loss_fn = student_loss_fn
        self.student.compile(optimizer=optimizer, metrics=metrics, loss=self.student_loss_fn)
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack the data
        x, y = data

        # Forward pass of teacher with no gradient tracking
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of the student
            student_predictions = self.student(x, training=True)

            # Calculate the task-specific loss
            task_loss = self.student_loss_fn(y, student_predictions)

            # Calculate the soft targets and the distillation loss
            soft_targets = tf.nn.softmax(teacher_predictions / self.temperature)
            student_soft = tf.nn.softmax(student_predictions / self.temperature)
            distillation_loss = self.distillation_loss_fn(soft_targets, student_soft)

            # Calculate the total loss
            total_loss = (1 - self.alpha) * task_loss + self.alpha * distillation_loss * (self.temperature ** 2)

        # Compute gradients and update weights
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"task_loss": task_loss, "distillation_loss": distillation_loss, "total_loss": total_loss})
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Forward pass of the student
        y_pred = self.student(x, training=False)

        # Calculate the task-specific loss
        task_loss = self.student_loss_fn(y, y_pred)

        # Update the metrics
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call_model(self):
      return self.student

def train_distill(_student, _teacher, _epoch, x_train = x_train, y_train=y_train, _alpha=0.1, _temp=3):
  distiller = Distiller(student=_student, teacher=_teacher)
  distiller.compile(
      optimizer=keras.optimizers.Adam(),
      metrics=[keras.metrics.SparseCategoricalAccuracy()],
      student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      distillation_loss_fn=keras.losses.KLDivergence(),
      alpha=_alpha,
      temperature=_temp,
  )

  # Distill teacher to student

  distiller.fit(x_train, y_train, epochs=_epoch, validation_split=validation_split)
  distiller.evaluate(x_test, y_test)

  return distiller

In [27]:
# simkd
class SimKDDistill(Distiller):
  def __init__(self, teacher, student,  alpha=0.1, temperature=3, **kwargs):
      super(SimKDDistill, self).__init__(teacher, student,**kwargs)

      # Assign weights and biases to the last layer, biases are always the same dimension as number of classes
      self.student.layers[-1].set_weights(self.teacher.layers[-1].get_weights())

      # Freeze the last layer (prevent it from updating)
      self.student.layers[-1].trainable = False

def train_simKD(_student, _teacher, _epoch, x_train = x_train, y_train=y_train, _alpha=0.1, _temp=3):
  distiller = SimKDDistill(student=_student, teacher=_teacher)
  distiller.compile(
      optimizer=keras.optimizers.Adam(),
      metrics=[keras.metrics.SparseCategoricalAccuracy()],
      student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      distillation_loss_fn=keras.losses.KLDivergence(),
      alpha=_alpha,
      temperature=_temp
  )

  distiller.fit(x_train, y_train, epochs=_epoch, validation_split=validation_split)
  distiller.evaluate(x_test, y_test)
  return distiller

## Implementation

first of all, we have 2 homogenous dataset and 2 models of the same architecture, they share the same testing dataset

In [13]:
# Data preparation
mid_point = len(x_train) // 2
# D1
x_train_1 = x_train[:mid_point]
y_train_1 = y_train[:mid_point]
x_train_2 = x_train[mid_point:]
y_train_2 = y_train[mid_point:]

# TODO: visualise distribution
def print_dist(y_train, name):
  unique, counts = np.unique(y_train, return_counts=True)
  label_distribution = dict(zip(unique, counts))

  # Print the label distribution
  print("Label Distribution in Training Set ", name, ":")
  for label, count in label_distribution.items():
      print(f"Label {label}: {count} instances")

print_dist(y_train_1, '1')
print_dist(y_train_2, '2')

Label Distribution in Training Set  1 :
Label 0: 2961 instances
Label 1: 3423 instances
Label 2: 2948 instances
Label 3: 3073 instances
Label 4: 2926 instances
Label 5: 2709 instances
Label 6: 2975 instances
Label 7: 3107 instances
Label 8: 2875 instances
Label 9: 3003 instances
Label Distribution in Training Set  2 :
Label 0: 2962 instances
Label 1: 3319 instances
Label 2: 3010 instances
Label 3: 3058 instances
Label 4: 2916 instances
Label 5: 2712 instances
Label 6: 2943 instances
Label 7: 3158 instances
Label 8: 2976 instances
Label 9: 2946 instances


We train model 1 on dataset 1

In [21]:
model1 = mediumCNN()
model1 = trainCNN(model1, 5, x_train = x_train_1, y_train = y_train_1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


Prune Model1 to Model2 training on dataset 2. ie. model2 weights are initialised with model1 weights

In [22]:
_, weights1 = tempfile.mkstemp('.tf')
model1.save_weights(weights1)

# pruning
model2 = mediumCNN()
model2.load_weights(weights1)
model2 = prune_finetrain(model2, 5, 0.6, x_train = x_train_2, y_train = y_train_2)
model2 = sparsity.strip_pruning(model2)

Model: "mediumcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 14, 14, 8)         154       
 _12 (PruneLowMagnitude)                                         
                                                                 
 prune_low_magnitude_re_lu_  (None, 14, 14, 8)         1         
 8 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_max_po  (None, 14, 14, 8)         1         
 oling2d_8 (PruneLowMagnitu                                      
 de)                                                             
                                                                 
 prune_low_magnitude_conv2d  (None, 7, 7, 16)          2322      
 _13 (PruneLowMagnitude)                                         
                                                         

relay the knowledge learnt from dataset2 back to model1 with knowledge distillation: model2 is teacher, model1 is student

In [25]:
model1_kd = mediumCNN()
model1_kd.load_weights(weights1)
model1_kd = train_distill(model1_kd, model2, 5, x_train = x_train_1, y_train=y_train_1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [28]:
model1_simkd = mediumCNN()
model1_simkd.load_weights(weights1)
model1_simkd = train_simKD(model1_simkd, model2, 5, x_train = x_train_1, y_train=y_train_1)


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [30]:
model1_plain = mediumCNN()
model1_plain.load_weights(weights1)
model1_plain = trainCNN(model1_plain, 5, x_train = x_train_1, y_train = y_train_1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


## Experiments
compare the performance of model1 after kd, simkd, raw training
- use heterogeneous datasets
- use different sized models

# Prune + KD + GAN

This prototype replaces the "public dataset" D1 with a GAN generated dataset

TODO:
1. build a MNIST Generating GAN model
2. use model2 $($classifier$)$ and D2 as input through GAN to generate a public dataset PD
3. connect model1 to PD instead of D1, repeat the KD step from model2 to model1

# Prune + KD + GAN + FL
This prototype implements the algorithm in a distributed setting
TODO:
1. implement a FedAvg aggregator/server
2. build a centralised FL system with n clients connected to the server
3. design experiments to assess accuracy, efficiency, generalisation on homogenoeous data
4. repeat experiments on heterogeneous data, identical model sparsity
5. repeat experiments on heterogeneous data, different model sparsity, mimicing different computational capability of clients