In [None]:
!nvidia-smi -L

: 

In [None]:

# Do this if using Colab....
!pip install matplotlib==3.5.2
!pip install numpy==1.22.0 --no-dependencies

In [44]:
import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
import copy
# import tensorflow.keras.backend as K
import random

from IPython import display


In [45]:

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Defined by the number of parameters
THRESHOLD = [1664,204928,823553]
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).batch(BATCH_SIZE * 10)


In [69]:
def split_dataset_noniid(train_dataset, distribution, num_discriminators):
  micro_batch = []
  discriminator_dataset = []
  current = 0 
  for k in range (num_discriminators):
    sample_size = int(distribution[k] * 2560)
    print(sample_size)
    batch = train_dataset[current:current+sample_size,:,:,:]
    discriminator_dataset.append(batch)
    current = current+ sample_size
  return discriminator_dataset


def split_dataset_iid(train_dataset, num_discriminators):
  discriminator_dataset = []
  current = 0
  for i in range (num_discriminators):
    sample_size = int(2560/num_discriminators)
    batch = train_dataset[current:current+sample_size,:,:,:]
    discriminator_dataset.append(batch)
    current = current+ sample_size
  return discriminator_dataset


In [70]:

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

generator = make_generator_model()
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model
  
# Num_of_params = 1664
def make_discriminator_part_A():
  model = tf.keras.Sequential()
  model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
  model.add(layers.LeakyReLU())
  model.add(layers.Dropout(0.3))
  return model

# Num_of_params = 204,928
def make_discriminator_part_B():
  model = tf.keras.Sequential()
  model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
  model.add(layers.LeakyReLU())
  model.add(layers.Dropout(0.3))

  
  return model
# Num_of_params = 823,553
def make_discriminator_part_C():
  model = tf.keras.Sequential()
  model.add(layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'))
  model.add(layers.LeakyReLU())
  model.add(layers.Dropout(0.3))
  model.add(layers.Flatten())
  model.add(layers.Dense(1))   
  return model



In [71]:
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss


In [72]:

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()


def aggregation(model_list):
    base = model_list[0]
    base_models = []
    for client in base:
      for client_part in client:
        base_models.append(copy.deepcopy(client_part.get_weights()))
    base_A = [base_models[0]]
    base_B = [base_models[1]]
    base_C = [base_models[2]]
    current_pointer = 1
    for model in model_list[1:]:
        for client in model:
          for client_part in client:
            if current_pointer == 1:
              base_A.append(copy.deepcopy(client_part.get_weights()))
              current_pointer += 1
            elif current_pointer == 2:
              base_B.append(copy.deepcopy(client_part.get_weights()))
              current_pointer += 1
            elif current_pointer == 3:
              base_C.append(copy.deepcopy(client_part.get_weights()))
              current_pointer = 1
    
    # FedAVG
    averaged_A = np.mean(base_A, axis = 0)
    averaged_B = np.mean(base_B, axis = 0)
    averaged_C = np.mean(base_C, axis = 0)
    current_pointer = 1
    for model in model_list:
      for client in model:
        for client_part in client:
          if current_pointer == 1:
            client_part.set_weights(averaged_A)
            current_pointer += 1
          elif current_pointer == 2:
            client_part.set_weights(averaged_B)
            current_pointer += 1
          elif current_pointer == 3:
            client_part.set_weights(averaged_C)
            current_pointer = 1

    return model_list

In [73]:

def split_method_preprocessing_time(client_capacities, client_time_factor):
    clients = zip(client_time_factor, client_capacities)
    clients = list(clients)

    # Implement the selection algorithm here
    # Default example: Sort by client time factor: i.e. faster is better
    clients.sort()
    sorted_capacities = []
    sorted_time_factor = []
    for client_sorted in clients:
        sorted_time_factor.append(client_sorted[0])
        sorted_capacities.append(client_sorted[1])
        
    return sorted_capacities[:3], sorted_time_factor[:3]




In [74]:
def randomly_select_clients_preprocess(client_capacities, client_time_factor):
    clients = zip(client_time_factor, client_capacities)
    clients = list(clients)
    random.shuffle(clients)
    sorted_capacities = []
    sorted_time_factor = []
    for client_sorted in clients:
        sorted_time_factor.append(client_sorted[0])
        sorted_capacities.append(client_sorted[1])
    
    return sorted_capacities[:3], sorted_time_factor[:3]


In [75]:
def split_method_preprocessing_capacity(client_capacities, client_time_factor):
    clients = zip(client_capacities, client_time_factor)
    clients = list(clients)

    # Implement the selection algorithm here
    # Default example: Sort by client time factor: i.e. faster is better
    clients.sort()
    sorted_capacities = []
    sorted_time_factor = []
    for client_sorted in clients:
        sorted_time_factor.append(client_sorted[1])
        sorted_capacities.append(client_sorted[0])
        
    return sorted_capacities[:3], sorted_time_factor[:3]

In [76]:
def split_method_for_D_by_capacity(device_capacities):
    discriminators = []
    discriminator = []
    discriminator_optimizer = []
    discriminators_optimizer = []
    current_pointer = 0
    for device in device_capacities:
        part = []
        part_optimizer = []
        capacity = device
        while capacity >= THRESHOLD[current_pointer]:
            if orders[current_pointer] == "A":
                # Do something
                discri_part = make_discriminator_part_A()
                part.append(discri_part)
                optimizer_part = tf.keras.optimizers.Adam(1e-4)
                part_optimizer.append(optimizer_part)
                current_pointer += 1
            elif orders[current_pointer] == "B":
                discri_part = make_discriminator_part_B()
                part.append(discri_part)
                optimizer_part = tf.keras.optimizers.Adam(1e-4)
                # Do something
                part_optimizer.append(optimizer_part)
                current_pointer += 1
            elif orders[current_pointer] == "C":
                discri_part = make_discriminator_part_C()
                part.append(discri_part)
                current_pointer = 0
                discriminator.append(part)
                discriminators.append(discriminator)
                optimizer_part = tf.keras.optimizers.Adam(1e-4)
                part_optimizer.append(optimizer_part)
                discriminator_optimizer.append(part_optimizer)
                discriminators_optimizer.append(discriminator_optimizer)
                part = []
                discriminator = []
                part_optimizer = []
                discriminator_optimizer = []
                return discriminators, discriminators_optimizer
            capacity -= THRESHOLD[current_pointer]
        if len(part) != 0:
            discriminator.append(part)
            discriminator_optimizer.append(part_optimizer)
        else:
            continue
    # Only reached when NO device has enough capacity
    return discriminators, discriminators_optimizer



In [77]:
def split_method_for_D_by_capacity_baseline(client_capacities):
    discriminators = []
    discriminator = []
    discriminator_optimizer = []
    discriminators_optimizer = []
    current_pointer = 0

    while client_capacities != []:
      for client in client_capacities:
          part = []
          part_optimizer = []
          capacity = client
          if capacity >= THRESHOLD[current_pointer]:
              if orders[current_pointer] == "A":
                  # Do something
                  discri_part = make_discriminator_part_A()
                  part.append(discri_part)
                  optimizer_part = tf.keras.optimizers.Adam(1e-4)
                  part_optimizer.append(optimizer_part)
                  discriminator.append(part)
                  discriminator_optimizer.append(part_optimizer)
                  current_pointer += 1
              elif orders[current_pointer] == "B":
                  discri_part = make_discriminator_part_B()
                  part.append(discri_part)
                  optimizer_part = tf.keras.optimizers.Adam(1e-4)
                  # Do something
                  part_optimizer.append(optimizer_part)
                  discriminator.append(part)
                  discriminator_optimizer.append(part_optimizer)
                  current_pointer += 1
              elif orders[current_pointer] == "C":
                  discri_part = make_discriminator_part_C()
                  part.append(discri_part)
                  discriminator.append(part)
                  optimizer_part = tf.keras.optimizers.Adam(1e-4)
                  part_optimizer.append(optimizer_part)
                  discriminator_optimizer.append(part_optimizer)

                  discriminators.append(discriminator)
                  discriminators_optimizer.append(discriminator_optimizer)
                  current_pointer = 0
                  discriminator = []
                  discriminator_optimizer = []
                  return discriminators, discriminators_optimizer
          else:
              client_capacities.remove(capacity)

            
            
    return discriminators, discriminators_optimizer    



In [78]:
global CLIENT_CAPACITIES
global CLIENT_TIME_FACTOR
global discriminator_list
global discriminator_optimizer_list
global devices_time_sorted

discriminator_list = []
discriminator_optimizer_list = []
# CLIENT_CAPACITIES = [1000000, 80000000, 90000000, 10000000, 5000000, 100000, 3000, 967000, 900000, 6000000, 13500000, 4000000, 60000]
# CLIENT_TIME_FACTOR = [1, 1.5, 2, 0.747, 1, 1.458, 7, 1.343, 4, 2.136, 2.747, 1.343, 3]

# SELECTION_METHOD = "RANDOM/TIME"
SELECTION_METHOD = "RANDOM"

# List of clients, each client contains the information of their devices
# NB: One discriminator refers to ONE client
CLIENTS_CAPACITIES = [ [100000, 20000000, 30000000, 90000000] , [100000, 20000000, 30000000, 90000000] , [600000, 10000000, 30000000, 90000000], [700000, 20000000, 30000000, 90000000], [900000, 20000000, 30000000, 90000000], [10, 10, 10, 10] ]
CLIENT_TIME_FACTOR = [ [3, 2, 2.5, 4], [7, 4, 3.5, 2.3] , [1.7, 1.5, 6, 0.7], [4, 1.5, 2.34, 1], [1.67, 2, 2.5, 4], [999,999,999,999]]
# CLIENT_CAPACITIES_ORI = [60000, 1000000, 5000000,7000000 , 90000000, 100000000, 9000000000]
# CLIENT_TIME_FACTOR_ORI = [1, 1, 3,  1, 1.5 , 1, 1]


# For model evaluation
NUM_OF_CLIENTS = 5
CLIENTS_CAPACITIES = [[823553, 823553, 823553] for i in range(NUM_OF_CLIENTS)]
CLIENT_TIME_FACTOR = [[1, 1, 1] for i in range(NUM_OF_CLIENTS)]


orders = ["A", "B", "C"]

false_clients = 0
devices_time_sorted = []
for client_index in range(len(CLIENTS_CAPACITIES)):
  client_capacity = CLIENTS_CAPACITIES[client_index]
  client_time = CLIENT_TIME_FACTOR[client_index]
  if SELECTION_METHOD == "RANDOM":
    client_cap_sorted, client_time_sorted = randomly_select_clients_preprocess(client_capacity, client_time)
  else:
    client_cap_sorted, client_time_sorted = split_method_preprocessing_time(client_capacity, client_time)
  client_discrimiantor, client_optimizer = split_method_for_D_by_capacity(client_cap_sorted)
  # client_discrimiantor, client_optimizer = split_method_for_D_by_capacity_baseline(client_cap_sorted)

  if len(client_discrimiantor) != 0:
    devices_time_sorted += client_time_sorted
    discriminator_list += client_discrimiantor
    discriminator_optimizer_list += client_optimizer
  else:
    false_clients += 1

print("Number of false clients: " + str(false_clients))


sample = random.sample(range(100), len(discriminator_list))
sample_sum = sum(sample)
sample_distribution = [s/sample_sum for s in sample]





Number of false clients: 0


In [79]:
# Some naming is confusing here, need to be reformed
def train_step_clients(discriminator_list, discriminator_optimizer_list, images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    # noise = tf.random.normal([25, noise_dim])
    gen_loss = None

    client_times = []
    # 50 ms for each round of LAN communication
    LAN_communication_time = 0.05

    # For iid setting
    discriminator_dataset = split_dataset_iid(images, len(discriminator_list))

    # For non iid setting
    discriminator_dataset = split_dataset_noniid(images, sample_distribution, len(discriminator_list))





    with tf.GradientTape() as gen_tape:
        generated_images = generator(noise, training=True)

        for index in range (len(discriminator_list)):
            discriminator = discriminator_list[index]
            gradients = []
            current_d_time = []

            with tf.GradientTape(persistent=True) as disc_tape:
                # current_real_input = random.choice(micro_batch)
                current_real_input = discriminator_dataset[index]
                # current_real_input = images
                current_fake_input = generated_images
                for device_index in range(len(discriminator)):
                    current_client_time_start = time.time()
                    current_device = discriminator[device_index]
                    for device_part in current_device:
                        current_real_input = device_part(current_real_input, training=True)
                        current_fake_input = device_part(current_fake_input, training=True)

                    current_client_time_end = time.time()
                    current_client_time_gap = current_client_time_end - current_client_time_start
                    current_d_time.append(current_client_time_gap)


                if gen_loss == None:
                    gen_loss = generator_loss(current_fake_input)
                else:
                    gen_loss += generator_loss(current_fake_input)
                disc_loss = discriminator_loss(current_real_input, current_fake_input)


            client_time_index = 1
            for device_index_r in range(len(discriminator)-1, -1, -1):
                client_parts = discriminator[device_index_r]
                client_gradients = []
                current_client_time_start = time.time()
                for device_part_index_r in range(len(client_parts)-1, -1, -1):
                    device_part = client_parts[device_part_index_r]
                    gradient_part = disc_tape.gradient(disc_loss, device_part.trainable_variables)
                    client_gradients.append(gradient_part)
                current_client_time_end = time.time()
                current_client_time_gap = current_client_time_end - current_client_time_start
                current_d_time[-client_time_index] += current_client_time_gap
                client_time_index += 1
                gradients.append(client_gradients)


            client_time_index = 1
            for client_index_r in range(len(discriminator)-1, -1, -1):
                client_parts = discriminator[client_index_r]
                current_client_time_start = time.time()
                for client_part_index_r in range(len(client_parts)-1, -1, -1):
                    client_part = client_parts[client_part_index_r]
                    part_optimizer = discriminator_optimizer_list[index][client_index_r][client_part_index_r]
                    gradient = gradients[-client_index_r-1][-client_part_index_r-1]
                    part_optimizer.apply_gradients(zip(gradient, client_part.trainable_variables))
                current_client_time_end = time.time()
                current_client_time_gap = current_client_time_end - current_client_time_start
                current_d_time[-client_time_index] += current_client_time_gap
                client_time_index += 1



            # concatenate the time cost of different clients
            client_times += current_d_time
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

    # Final processing on client times
    for time_index in range(len(client_times)):
        client_times[time_index] *= devices_time_sorted[time_index]
    client_time_index = 0
    discrimiantor_time_costs = [0 for i in range (len(discriminator_list))]


    for discriminator_clients_index in range(len(discriminator_list)):
        client_numbers_in_single_D = len(discriminator_list[discriminator_clients_index])
        for client_id in range(client_numbers_in_single_D):
            client_time = client_times[client_time_index]
            client_time_index += 1
            discrimiantor_time_costs[discriminator_clients_index] += client_time + LAN_communication_time


    return client_times, discrimiantor_time_costs, gen_loss


In [80]:

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)


EPOCHS = 500
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

In [81]:

def train(dataset, epochs):
  

  # CLIENT_CAPACITIES, CLIENT_TIME_FACTOR = split_method_preprocessing_time(CLIENT_CAPACITIES_ORI, CLIENT_TIME_FACTOR_ORI)

  # CLIENT_CAPACITIES, CLIENT_TIME_FACTOR = randomly_select_clients_preprocess(CLIENT_CAPACITIES_ORI, CLIENT_TIME_FACTOR_ORI)
  # discriminator_list,discriminator_optimizer_list = split_method_for_D_by_capacity(CLIENT_CAPACITIES)
  
  # discriminator_list,discriminator_optimizer_list = split_method_for_D_by_capacity_baseline(CLIENT_CAPACITIES)
  global discriminator_list
  global discriminator_optimizer_list
  print(len(discriminator_list))
  gloss_list= []
  time_list = []
  
  for epoch in range(epochs):
    gloss = []
    client_time_list = None
    d_time_list = None
    # print(epoch)
    start = time.time()
    for image_batch in dataset:
      
      # gloss_val = train_step(image_batch)
      client_times, d_times, gen_loss = train_step_clients(discriminator_list, discriminator_optimizer_list, image_batch)
      if client_time_list == None:
        client_time_list = client_times
        d_time_list = d_times
      else:
        client_time_list = [sum(x) for x in zip(client_time_list, client_times)]
        d_time_list = [sum(x) for x in zip(d_time_list, d_times)]
      gloss.append(gen_loss)
    gloss_val = sum(gloss)/ len(gloss)
    # If yo want to print the g_loss
    # print(gloss_val.numpy())
    gloss_list.append(gloss_val)
    discriminator_list = aggregation(discriminator_list)


    if (epoch+1) % 100 == 0:
      generate_and_save_images(generator,
                           epoch,
                           seed)
      
    


    # Time for each clients in this EPOCH    
    # print(client_time_list)
    # print(d_time_list)
    # print('-' * 30)
    limit = max(d_time_list)
    # print('The slowest discrminator takes {} sec to finish the current epoch'.format(limit))
    # print("-" * 50)
    time_list.append(limit)


    # Produce images for the GIF as you go
    # display.clear_output(wait=True)
    # generate_and_save_images(generator,
    #                          epoch + 1,
    #                          seed)

    # Save the model every 15 epochs
    # if (epoch + 1) % 15 == 0:
    #   checkpoint.save(file_prefix = checkpoint_prefix)

    # print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  # display.clear_output(wait=True)
  

  generate_and_save_images(generator,
                           epochs,
                           seed)
  return client_time_list,time_list, gloss_list

In [82]:
client_time_list, limit_time_list, gloss_list = train(train_dataset, EPOCHS)


5
620
475
329
174
960
620
475
329
174
960


KeyboardInterrupt: 

In [None]:
averaged_time = sum(limit_time_list)/len(limit_time_list)

In [None]:
print(averaged_time)

In [None]:
print(len(gloss_list))
for loss in gloss_list:
  print(loss.numpy())