<a href="https://colab.research.google.com/github/ChrisW2420/FedPKDG/blob/main/FedPKDG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FedPKDG -- Prune + KD + GAN + FL
This prototype implements the algorithm in a distributed setting
TODO:
1. implement a FedAvg aggregator/server
2. build a centralised FL system with n clients connected to the server
3. design experiments to assess accuracy, efficiency and generalisation on homogeneous data
4. repeat experiments on heterogeneous data, identical model sparsity
5. repeat experiments on heterogeneous data with different model sparsity, mimicking the different computational capabilities of clients

# Setup and Imports

In [1]:
# NB: package versions are very important
!pip install -q tensorflow-model-optimization # for pruning
!pip install -q git+https://github.com/tensorflow/docs # newest tf
!pip install --upgrade keras #newest keras

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.4/242.5 kB[0m [31m1.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for tensorflow-docs (setup.py) ... [?25l[?25hdone
Collecting keras
  Downloading keras-3.3.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
Collecting namex (from keras)
  Downloading namex-0.0.8-py3-none-any.whl (5.8 kB)
Collecting optree (from keras)
  Downloading optree-0.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (311 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.2/311.2 kB[0m [31m21.9 MB/s[

In [2]:
# 3 versions of keras are used for different functionalities, imported as different names
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.sparsity import keras as sparsity
import tf_keras as keras_model #only for pruning
from tf_keras import layers as model_layers
import keras
import tempfile
from tf_keras.callbacks import EarlyStopping, Callback
from keras import ops, layers
from tensorflow_docs.vis import embed # for GAN
import matplotlib.pyplot as plt

In [None]:
# Logging metrics with WandB
!pip install wandb
import wandb
wandb.login()
from wandb.keras import WandbMetricsLogger

## Loading Data

In [4]:
# MNIST: fetch the train/test splits and fix the global mini-batch size.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Cast to float32, divide by 255 and add a trailing channel axis so both
# image arrays end up with shape (N, 28, 28, 1).
x_train = (x_train.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype("float32") / 255.0).reshape(-1, 28, 28, 1)

# Components Implementation

## Model zoo

### CNN

In [34]:
def smallCNN():
  """Build the smallest classifier of the model zoo.

  Returns:
    An uncompiled tf_keras Sequential model named "smallcnn" that takes
    28x28x1 inputs, applies two strided 8-filter convolutions and ends in a
    10-unit Dense layer without activation.
  """
  conv_kwargs = dict(strides=(2, 2), padding="same")
  stack = [
      keras_model.Input(shape=(28, 28, 1)),
      model_layers.Conv2D(8, (3, 3), **conv_kwargs),
      model_layers.LeakyReLU(alpha=0.2),
      model_layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
      model_layers.Conv2D(8, (3, 3), **conv_kwargs),
      model_layers.Flatten(),
      model_layers.Dense(10),
  ]
  return keras_model.Sequential(stack, name="smallcnn")

def mediumCNN():
  """Build the mid-sized classifier of the model zoo.

  Returns:
    An uncompiled tf_keras Sequential model named "mediumcnn": two
    conv/LeakyReLU/max-pool stages (8 then 16 filters), a final 16-filter
    convolution, and a 10-unit Dense layer without activation.
  """
  stack = [keras_model.Input(shape=(28, 28, 1))]
  # Two downsampling stages, each conv -> LeakyReLU -> stride-1 max-pool.
  for filters in (8, 16):
    stack.append(model_layers.Conv2D(filters, (3, 3), strides=(2, 2), padding="same"))
    stack.append(model_layers.LeakyReLU(alpha=0.2))
    stack.append(model_layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"))
  # Classification head.
  stack.append(model_layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"))
  stack.append(model_layers.Flatten())
  stack.append(model_layers.Dense(10))
  return keras_model.Sequential(stack, name="mediumcnn")

def bigCNN():
  """Build the largest classifier of the model zoo.

  Returns:
    An uncompiled tf_keras Sequential model named "bigcnn": two
    conv/LeakyReLU/max-pool stages and a final convolution, all with 32
    filters, followed by a 10-unit Dense layer without activation.
  """
  width = 32
  stack = [keras_model.Input(shape=(28, 28, 1))]
  # Two downsampling stages, each conv -> LeakyReLU -> stride-1 max-pool.
  for _ in range(2):
    stack.append(model_layers.Conv2D(width, (3, 3), strides=(2, 2), padding="same"))
    stack.append(model_layers.LeakyReLU(alpha=0.2))
    stack.append(model_layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"))
  # Classification head.
  stack.append(model_layers.Conv2D(width, (3, 3), strides=(2, 2), padding="same"))
  stack.append(model_layers.Flatten())
  stack.append(model_layers.Dense(10))
  return keras_model.Sequential(stack, name="bigcnn")

### GAN

In [35]:
# Conditional-GAN hyperparameters.
num_channels = 1
num_classes = 10
image_size = 28
latent_dim = 128

# The one-hot class label is appended to both network inputs: as extra vector
# entries for the generator and as extra image channels for the discriminator.
generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes

# Discriminator: image-plus-label tensor -> a single unnormalised score.
discriminator = keras.Sequential(
    [
        layers.InputLayer(shape=(28, 28, discriminator_in_channels)),
        layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(filters=128, kernel_size=(3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(units=1),
    ],
    name="discriminator",
)

# Generator: noise-plus-label vector -> single-channel image in [0, 1].
generator = keras.Sequential(
    [
        layers.InputLayer(shape=(generator_in_channels,)),
        # Project to 7 * 7 * generator_in_channels coefficients so they can be
        # reshaped into a 7x7 feature map with generator_in_channels channels.
        layers.Dense(units=7 * 7 * generator_in_channels),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape(target_shape=(7, 7, generator_in_channels)),
        # Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28.
        layers.Conv2DTranspose(filters=128, kernel_size=(4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(filters=128, kernel_size=(4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(filters=1, kernel_size=(7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

## CNN

In [36]:
class CNN():
  """Training wrapper around one of the model-zoo CNN classifiers.

  Attributes are plain Python objects so they can be tweaked between calls:
  `optimizer`, `task_loss`, `metricslist`, `validation_split`, the
  `early_stopping` callback and the user-extensible `callbacks` list.
  """

  def __init__(self, config_name, **kwargs):
    """Create the wrapped model.

    Args:
      config_name: 'small', 'medium' or 'big'; any other value falls back to
        the medium CNN after printing a notice.
      **kwargs: accepted for forward compatibility; currently unused.
    """
    if config_name == 'small':
      self.model = smallCNN()
    elif config_name == 'medium':
      self.model = mediumCNN()
    elif config_name == 'big':
      self.model = bigCNN()
    else:
      print('default model of medium CNN')
      self.model = mediumCNN()
    self.optimizer = 'adam'
    # The zoo models end in a Dense layer without activation, hence from_logits.
    self.task_loss = keras_model.losses.SparseCategoricalCrossentropy(from_logits=True)
    self.metricslist = [keras_model.metrics.SparseCategoricalAccuracy()]
    self.validation_split = 0.1
    self.early_stopping = EarlyStopping(
      monitor='val_loss',
      min_delta=0.001,  # only consider as improvement significant changes
      patience=2,      # number of epochs with no improvement after which training will be stopped
      verbose=1,
      mode='min'        # 'min' because we want to minimize the loss
    )
    # Extra callbacks applied to every `train` call; early stopping is added
    # per call and is never stored here.
    self.callbacks = []

  def train(self, training_data, testing_data = None, epochs = 10, is_earlystop = True, **kwargs):
    """Compile and fit the wrapped model, optionally evaluating it afterwards.

    Args:
      training_data: `(x_train, y_train)` pair passed to `fit`.
      testing_data: optional `(x_test, y_test)` pair; evaluated when truthy.
      epochs: maximum number of training epochs.
      is_earlystop: whether to attach `self.early_stopping` for this call.
      **kwargs: accepted for forward compatibility; currently unused.

    Returns:
      The trained model (`self.model`).
    """
    self.model.compile(
      optimizer=self.optimizer,
      loss=self.task_loss,
      metrics=self.metricslist
    )
    # Build the callback list for this call only. Appending to self.callbacks
    # (as before) made early stopping permanent, so a later call with
    # is_earlystop=False would still stop early.
    callbacks = [cb for cb in self.callbacks if cb is not self.early_stopping]
    if is_earlystop:
      callbacks.append(self.early_stopping)
    x_train, y_train = training_data
    # `batch_size` is the notebook-level constant defined in the data-loading cell.
    self.model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs,
      validation_split=self.validation_split,
      callbacks=callbacks,
    )
    if testing_data:
      x_test, y_test = testing_data
      self.model.evaluate(x_test, y_test)
    return self.model


## GAN

In [37]:
class ConditionalGAN(keras.Model):
    """Keras model that trains a generator/discriminator pair conditioned on class labels.

    `train_step` reads the module-level `image_size` and `num_classes`
    constants when tiling the labels into image-shaped channels.
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Store the two networks and set up the seed generator and loss trackers.

        Args:
            discriminator: model scoring image-plus-label tensors.
            generator: model mapping noise-plus-label vectors to images.
            latent_dim: size of the noise vector sampled for each image.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Fixed seed for the noise sampling in train_step.
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        """Return the generator and discriminator loss trackers."""
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per network and the shared loss function.

        Args:
            d_optimizer: optimizer applied to the discriminator weights.
            g_optimizer: optimizer applied to the generator weights.
            loss_fn: callable `loss_fn(labels, predictions)` used for both networks.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: pair `(real_images, one_hot_labels)`. The labels are indexed
                as a rank-2 tensor; real_images is assumed to be
                (batch, image_size, image_size, channels) -- TODO confirm
                against the dataset passed to `fit`.

        Returns:
            Dict with the running means of the generator loss ("g_loss") and
            the discriminator loss ("d_loss").
        """
        # Unpack the data.
        real_images, one_hot_labels = data

        # Add dummy dimensions to the labels so that they can be concatenated with
        # the images. This is for the discriminator.
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = ops.repeat(
            image_one_hot_labels, repeats=[image_size * image_size]
        )
        image_one_hot_labels = ops.reshape(
            image_one_hot_labels, (-1, image_size, image_size, num_classes)
        )

        # Sample random points in the latent space and concatenate the labels.
        # This is for the generator.
        # Note: this local `batch_size` shadows the notebook-level constant of
        # the same name inside this method.
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Combine them with real images. Note that we are concatenating the labels
        # with these images here.
        fake_image_and_labels = ops.concatenate(
            [generated_images, image_one_hot_labels], -1
        )
        real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
        combined_images = ops.concatenate(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Assemble labels discriminating real from fake images.
        # The order matches combined_images: generated images get 1, real images get 0.
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Assemble labels that say "all real images".
        # (0 is the target the discriminator step above uses for real images.)
        misleading_labels = ops.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = ops.concatenate(
                [fake_images, image_one_hot_labels], -1
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        # Gradients are taken w.r.t. the generator weights only.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

In [38]:
# image generation functions
def generate_image(generator, target_class, latent_dim):
    """Generate one image of `target_class` with a conditional generator.

    Args:
        generator: model exposing `predict`, fed a (1, latent_dim + num_classes) input.
        target_class: integer class index to condition on.
        latent_dim: length of the random noise vector.

    Returns:
        Whatever `generator.predict` returns for the single noise/label row.
    """
    latent = keras.random.normal(shape=(1, latent_dim))
    # One-hot encode the requested class (width = module-level num_classes).
    one_hot = ops.cast(
        keras.utils.to_categorical([target_class], num_classes), "float32"
    )
    generator_input = ops.concatenate([latent, one_hot], 1)
    return generator.predict(generator_input, verbose=0)

def pseudoDataset(generator, total_num, latent_dim): # producing equal numbers of samples for each class
    """Build a class-balanced synthetic dataset from a conditional generator.

    Args:
        generator: model passed through to `generate_image`.
        total_num: requested number of samples; `int(total_num / 10)` images
            are produced for each of the 10 digit classes, so any remainder
            is dropped.
        latent_dim: length of the noise vector used for each image.

    Returns:
        `(pseudo_images, pseudo_labels)`: the generated images concatenated
        along axis 0 as uint8 in [0, 255], and the matching integer labels
        ordered class by class.

    Raises:
        ValueError: if `total_num` is too small to give every class a sample.
    """
    n_classes = 10
    per_class = int(total_num/10)
    if per_class < 1:
        raise ValueError(
            "total_num must be at least 10 to generate one image per class"
        )

    pseudo_images = []
    for num in range(n_classes):
        print('Generating', per_class, 'fake images of digit', num, '......')
        for _ in range(per_class):
            # The generator must be forwarded; it was previously omitted, which
            # made every call fail with a TypeError.
            generated_images = generate_image(generator, num, latent_dim)
            generated_images *= 255.0
            converted_images = generated_images.astype(np.uint8)
            converted_images = ops.image.resize(converted_images, (28, 28)).numpy().astype(np.uint8)
            pseudo_images.append(converted_images)
    pseudo_images = np.concatenate(pseudo_images, axis=0)
    pseudo_labels = np.repeat(np.arange(n_classes), per_class)
    return pseudo_images, pseudo_labels

## Pruning

## Knowledge Distillation

# Helper Function Implementation

TODO:
Dataset:
- Dataloader
- heterogeneous dataset partition
- data augmentation

visualisation:
- dataset example visualisation
- data distribution visualisation
- confusion matrix
-

In [39]:
def set_model_weights_to_zero(model):
    """Overwrite every weight array of every layer in `model` with zeros.

    The model is modified in place and also returned for convenience.
    """
    for lyr in model.layers:
        zeroed = []
        for weight in lyr.get_weights():
            zeroed.append(np.zeros_like(weight))
        lyr.set_weights(zeroed)
    return model

## Callback zoo

In [42]:
# Shared early-stopping callback: watches the validation loss and stops once it
# has failed to improve by at least min_delta for `patience` epochs.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',       # lower loss is better
    min_delta=0.001,  # smaller changes do not count as an improvement
    patience=2,       # epochs without improvement tolerated before stopping
    verbose=1,
)

# Client Implementation

In [40]:
class Client():
  """A federated-learning participant with a private dataset, a local CNN and a conditional GAN.

  NOTE(review): the default `generator` / `discriminator` arguments are the
  module-level model instances, so every client (and the server) created with
  the defaults shares the same objects. Pass separate models (e.g. clones) when
  per-client GAN weights are required.
  """

  def __init__(self, model_config_name, train_dataset, generator = generator, discriminator = discriminator):
    """Create the client.

    Args:
      model_config_name: 'small', 'medium' or 'big', selecting the local CNN
        from the model zoo; any other value keeps the previous behaviour of
        using the small CNN.
      train_dataset: `(x_train, y_train)` pair holding this client's private data.
      generator: conditional-GAN generator used by this client.
      discriminator: conditional-GAN discriminator used by this client.
    """
    # The config name used to be ignored (every client got the small CNN).
    if model_config_name == 'medium':
      self.cnn = mediumCNN()
    elif model_config_name == 'big':
      self.cnn = bigCNN()
    else:
      self.cnn = smallCNN()
    self.generator = generator
    self.discriminator = discriminator
    self.latent_dim = 128 # hyperparam, can tune
    self.private_dataset = train_dataset
    # Expose the two halves separately: train_gen and Server.get_client_datasize
    # read x_train / y_train, which were never set before.
    self.x_train, self.y_train = train_dataset
    self.batch_size = 64 # hyperparam, can tune
    self.validation_split=0.1

  def local_train(self, epochs = 5, is_prune = False, sparsity = 0.5, fine_tune_epochs = 2):
    """Run one round of local CNN training and return the CNN.

    Args:
      epochs: maximum number of local training epochs.
      is_prune: when True the pruning path is taken, which is not implemented
        yet and returns the CNN untouched.
      sparsity: target sparsity for pruning (unused until pruning exists).
      fine_tune_epochs: post-pruning fine-tuning epochs (unused until pruning exists).
    """
    if is_prune:
      #!!!TODO: prune
      return self.cnn
    self.train_cnn(epochs)
    return self.cnn

  def train_cnn(self, epochs):
    """Compile the local CNN and fit it on this client's private data."""
    self.cnn.compile(
        optimizer='adam',
        loss=keras_model.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras_model.metrics.SparseCategoricalAccuracy()]
    )
    # Train on the client's own data and batch size. The notebook-level
    # x_train / y_train / batch_size were used before, so every client trained
    # on the full global dataset.
    self.cnn.fit(
        self.x_train,
        self.y_train,
        batch_size=self.batch_size,
        epochs=epochs,
        validation_split=self.validation_split,
        callbacks=[early_stopping],
    )

  # ! TODO: cnn evaluation def eval_cnn

  def train_gen(self, epochs = 20, d_learning_rate = 0.0003, g_learning_rate = 0.0003):
    """Train the conditional GAN on this client's private data.

    Args:
      epochs: number of GAN training epochs.
      d_learning_rate: Adam learning rate for the discriminator.
      g_learning_rate: Adam learning rate for the generator.
    """
    cond_gan = ConditionalGAN(self.discriminator, self.generator, self.latent_dim)
    cond_gan.compile(
        d_optimizer=keras.optimizers.Adam(d_learning_rate),
        g_optimizer=keras.optimizers.Adam(g_learning_rate),
        loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
    )
    # produce GAN training dataset
    train_label = keras.utils.to_categorical(self.y_train, 10) # 1 hot encoding label
    dataset = tf.data.Dataset.from_tensor_slices((self.x_train, train_label))
    dataset = dataset.shuffle(buffer_size=1024).batch(self.batch_size)
    # epochs must be passed by keyword: the second positional argument of
    # Model.fit is `y`, not `epochs`.
    cond_gan.fit(dataset, epochs=epochs)

  # interface with the server
  def get_generator(self):
    """Return this client's generator."""
    return self.generator

  def produce_logits(self, dataset): #for KD
    """Return the local CNN's raw outputs for `dataset`.

    The zoo CNNs end in a Dense layer without activation, so the predictions
    are logits. Returning an array (instead of the previous empty-list stub)
    lets the server scale and sum the clients' outputs.
    NOTE(review): assumes `dataset` is a batch the CNN can consume directly
    (e.g. an array of shape (N, 28, 28, 1)) -- confirm once distillation exists.
    """
    return self.cnn.predict(dataset, batch_size=self.batch_size, verbose=0)

  def get_cnn(self): #only for FedAvg, disabled
    """Return this client's CNN."""
    return self.cnn

# Server Implementation

In [41]:
class Server():
  """Central coordinator: broadcasts weights, triggers local training and aggregates results."""

  def __init__(self, model_fn, client_list, generator = generator, comm_freq = 10):
    """Create the server.

    Args:
      model_fn: the global CNN. NOTE(review): despite the name it is used as a
        model instance (its get_weights / set_weights are called), not as a factory.
      client_list: clients participating in the federation.
      generator: global conditional-GAN generator.
      comm_freq: number of client local-training epochs before upload
        (stored; not read by any method of this class yet).
    """
    self.cnn = model_fn
    self.client_list = client_list # calling this param when "uploading" or "downloading"
    self.client_datasize = []
    self.generator = generator
    self.latent_dim = 128 # hyperparam, can tune
    self.public_dataset = None # to generate
    self.batch_size = 64 # hyperparam, can tune
    self.comm_freq = comm_freq # no. of client local training epochs before upload
    self.dataset = None

  @staticmethod
  def _client_num_samples(client):
    """Return the number of training samples a client holds."""
    # Prefer an explicit x_train attribute; otherwise fall back to the first
    # element of the (x, y) pair stored as private_dataset.
    features = getattr(client, 'x_train', None)
    if features is None:
      features = client.private_dataset[0]
    return len(features)

  def get_client_datasize(self):
    """Return the per-client sample counts, extending the cache for newly added clients."""
    for client in self.client_list[len(self.client_datasize):]:
      self.client_datasize.append(self._client_num_samples(client))
    return self.client_datasize

  def _aggregation_coefficients(self):
    """Return each client's share of the total sample count (FedAvg weights)."""
    sizes = self.get_client_datasize()
    total_size = sum(sizes)
    if total_size == 0:
      raise ValueError("cannot aggregate: no clients or no client data")
    return [size / total_size for size in sizes]

  def _weighted_average(self, client_weight_lists):
    """Average the clients' weight lists tensor by tensor using the FedAvg coefficients."""
    coeffs = self._aggregation_coefficients()
    return [
        sum(coeff * tensor for coeff, tensor in zip(coeffs, tensors))
        for tensors in zip(*client_weight_lists)
    ]

  def assign_weights_cnn(self, client):
    """Copy the global CNN weights into `client`'s CNN."""
    client.cnn.set_weights(self.cnn.get_weights())

  def assign_weights_gen(self, client):
    """Copy the global generator weights into `client`'s generator."""
    client.generator.set_weights(self.generator.get_weights())

  def broadcast(self):
    """Push the global CNN and generator weights to every client."""
    # TODO: improve: can use tff.federated_map and tff.federated_broadcast, can combine the two assign fns
    # An explicit loop is required: map() is lazy in Python 3, so the previous
    # map(...) calls never ran (and the generator assigner was never used).
    for client in self.client_list:
      self.assign_weights_cnn(client)
      self.assign_weights_gen(client)

  def local_training(self, epochs = 5):
    """Ask every client to train locally for `epochs` epochs."""
    for client in self.client_list:
      client.local_train(epochs)

  # def agg_classifier(self):
  #   # !!TODO: can use tff.federated_mean
  #   for i in range(len(self.client_list)):
  #     # !!TODO: aggregate the classifier layer and freeze it

  def agg_cnn(self): #for FedAvg only
    """Replace the global CNN weights with the data-size-weighted client average."""
    # All client weights are read before the global model is written, so this
    # is also correct when a client shares the server's model object.
    averaged = self._weighted_average(
        [client.cnn.get_weights() for client in self.client_list]
    )
    self.cnn.set_weights(averaged)

  def agg_gen(self):
    """Replace the global generator weights with the data-size-weighted client average."""
    # same can be implemented with tff
    averaged = self._weighted_average(
        [client.generator.get_weights() for client in self.client_list]
    )
    self.generator.set_weights(averaged)

  def produce_pseudo_dataset(self, total_num):
    """Generate a class-balanced synthetic dataset and store it in `self.dataset`."""
    # generate with gen, homogenous data: equal number of datapoints for each class
    self.dataset = pseudoDataset(self.generator, total_num, self.latent_dim)

  def agg_logits(self, datapoint):
    """Return the data-size-weighted sum of the clients' logits for `datapoint`.

    Mimics clients sending their logits to the server given the same input.
    Raises ValueError when there are no clients or no client data.
    """
    coeffs = self._aggregation_coefficients()
    logits = None
    for client, coeff in zip(self.client_list, coeffs):
      # np.asarray lets plain lists be scaled as well as arrays.
      contribution = coeff * np.asarray(client.produce_logits(datapoint))
      logits = contribution if logits is None else logits + contribution
    return logits

  def distill(self): ##HARD and IMPORTANT!
    """Placeholder for server-side distillation (not implemented yet)."""
    #!!!TODO: distillation based on self.dataset and agg_logits
    return


# Testing - FedAvg

In [None]:
# Build a small federation for the FedAvg test: each client gets an interleaved
# shard of the MNIST training set, and the server holds a CNN with the same
# architecture as the clients. The instance is bound to a lowercase name so the
# Server class itself is not overwritten.
num_clients = 2
clients = [
    Client('small', (x_train[i::num_clients], y_train[i::num_clients]))
    for i in range(num_clients)
]
server = Server(smallCNN(), clients)