<a href="https://colab.research.google.com/github/Diishasing/Optical-coherence-tomography-Image-Classification-via-Federated-learning/blob/main/Federated_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os 
import sys
import numpy as np
import random
import cv2
import tensorflow as tf
from imutils import paths
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Activation, Flatten, Dense
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras import backend as K


In [None]:
!pip install opendatasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.22


In [None]:
import opendatasets as od
import pandas

# Kaggle page of the balanced OCT image set. opendatasets prompts for a Kaggle
# username/API key and downloads into ./oct-images-balanced-version, which is
# the directory the image-loading cell reads from.
KAGGLE_DATASET_URL = (
    "https://www.kaggle.com/datasets/mohamedberrimi/oct-images-balanced-version"
)
od.download(KAGGLE_DATASET_URL)

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: diishasiing
Your Kaggle Key: ··········
Downloading oct-images-balanced-version.zip to ./oct-images-balanced-version


100%|██████████| 4.15G/4.15G [03:04<00:00, 24.2MB/s]





In [None]:
def load(path, verbose = 1, size = (128, 128)):
    """Read images as grayscale, resize, flatten and scale pixel values to [0, 1].

    Parameters
    ----------
    path : sequence of str
        Image file paths. Each image's label is the name of its parent directory.
    verbose : int
        Print a progress line after every ``verbose`` images; 0 or a negative
        value disables progress output.
    size : tuple of int
        ``(width, height)`` passed to ``cv2.resize``; defaults to 128x128.

    Returns
    -------
    (data, labels) : tuple of lists
        ``data`` holds one flat float array of ``width * height`` values per
        image; ``labels`` holds the matching parent-directory names.

    Raises
    ------
    ValueError
        If OpenCV cannot read one of the files.
    """
    data = []
    labels = []
    total = len(path)

    for i, imgpath in enumerate(path):
        im_gray = cv2.imread(imgpath, cv2.IMREAD_GRAYSCALE)
        if im_gray is None:
            # cv2.imread reports failure by returning None rather than raising;
            # fail here with the offending path instead of inside cv2.resize.
            raise ValueError('could not read image: {}'.format(imgpath))
        im_gray = cv2.resize(im_gray, size)

        data.append(im_gray.flatten() / 255.0)
        labels.append(imgpath.split(os.path.sep)[-2])

        if verbose > 0 and (i + 1) % verbose == 0:
            print('processed the image {}/{}'.format(i + 1, total))

    return data, labels

In [None]:
# Load every image under the dataset folder as a flat grayscale vector, with the
# parent-directory name as its label, then hold out 10% as the global test set.
img_path = '/content/oct-images-balanced-version/test/test'
image_paths = list(paths.list_images(img_path))
image_list, label_list = load(image_paths, verbose=1000)

# Binarize the directory-name labels (one column per class when there are more
# than two classes); lb.classes_ keeps the label order.
lb = LabelBinarizer()
label_list = lb.fit_transform(label_list)

(X_train, X_test, Y_train, Y_test) = train_test_split(
    image_list,
    label_list,
    test_size=0.1,
    random_state=23,
)


In [None]:
def create_client(image_list, label_list, num_clients = 10, initial = 'clients', seed = None):
    """Randomly partition a labelled dataset into ``num_clients`` disjoint shards.

    Parameters
    ----------
    image_list, label_list : sequences
        Samples and their labels, paired element-wise.
    num_clients : int
        Number of clients (shards) to create.
    initial : str
        Prefix of the client names: ``'<initial>_1'`` ... ``'<initial>_<num_clients>'``.
    seed : int or None
        Seed for a private RNG, for a reproducible split. ``None`` shuffles with
        the global ``random`` module.

    Returns
    -------
    dict
        Client name -> list of ``(sample, label)`` tuples. Every sample goes to
        exactly one client, and shard sizes differ by at most one.

    Raises
    ------
    ValueError
        If ``num_clients`` is < 1 or there are fewer samples than clients.
    """
    if num_clients < 1:
        raise ValueError('num_clients must be at least 1, got {}'.format(num_clients))

    client_names = ['{}_{}'.format(initial, i + 1) for i in range(num_clients)]

    data = list(zip(image_list, label_list))
    if len(data) < num_clients:
        raise ValueError('cannot split {} samples across {} clients'.format(len(data), num_clients))

    rng = random if seed is None else random.Random(seed)
    rng.shuffle(data)

    # Give the first `extra` clients one additional sample instead of silently
    # dropping the len(data) % num_clients leftover samples.
    size, extra = divmod(len(data), num_clients)
    shards = []
    start = 0
    for k in range(num_clients):
        stop = start + size + (1 if k < extra else 0)
        shards.append(data[start:stop])
        start = stop

    return dict(zip(client_names, shards))


In [None]:
# Split the training set across 10 simulated clients named client_1 .. client_10.
clients = create_client(image_list=X_train, label_list=Y_train, num_clients=10, initial='client')

In [None]:
def batch_fn(data_shard, bs = 32):
    """Wrap one client's shard in a shuffled, batched ``tf.data.Dataset``.

    data_shard -- list of (features, label) pairs.
    bs         -- batch size.
    """
    features, targets = map(list, zip(*data_shard))
    pipeline = tf.data.Dataset.from_tensor_slices((features, targets))
    # The shuffle buffer is as large as the shard, so each epoch sees a full permutation.
    pipeline = pipeline.shuffle(len(targets))
    return pipeline.batch(bs)

In [None]:
# One shuffled, batched tf.data pipeline per client.
clients_batched = {name: batch_fn(shard) for name, shard in clients.items()}

# The whole held-out set is evaluated as a single batch.
test_batched = tf.data.Dataset.from_tensor_slices((X_test, Y_test)).batch(len(Y_test))

In [None]:
class Model_():
    """Factory for a small fully connected classifier (MLP)."""

    @staticmethod
    def build(shape, classes):
        """Return an uncompiled MLP with two 200-unit ReLU layers and a softmax head.

        shape   -- length of each flat input vector.
        classes -- number of output units.
        """
        stack = [
            Dense(200, input_shape = (shape, )),
            Activation('relu'),
            Dense(200),
            Activation('relu'),
            Dense(classes),
            Activation('softmax'),
        ]
        return Sequential(stack)

In [None]:
class Model_2():
    """Factory that loads the MobileViT-XXS Keras model cloned into the working directory."""

    @staticmethod
    def build(shape, classes, model_path = 'mobile-vit-xxs'):
        """Load and return the saved Keras model stored at ``model_path``.

        ``shape`` and ``classes`` are accepted only so the call signature matches
        ``Model_.build``; they are NOT applied to the loaded model.

        NOTE(review): the training loop calls build(16384, 4) and trains on
        flattened 128*128 grayscale vectors with 4-column labels. A pretrained
        vision model presumably expects (H, W, C) image input and has its own
        head size. Confirm this, because such a mismatch would explain the
        ValueError raised by the training loop.
        """
        # Imported locally so this factory does not depend on a later cell
        # having already run `from tensorflow import keras`.
        from tensorflow import keras
        return keras.models.load_model(model_path)

In [None]:
!git lfs install
!git clone https://huggingface.co/keras-io/mobile-vit-xxs

Error: Failed to call git rev-parse --git-dir: exit status 128 
Git LFS initialized.
fatal: destination path 'mobile-vit-xxs' already exists and is not an empty directory.


In [None]:
from tensorflow import keras

# Load the cloned MobileViT-XXS model for inspection. This module-level `model`
# is not used by the training loop, which builds its own copies via Model_2.
model = keras.models.load_model(filepath='mobile-vit-xxs')

KeyboardInterrupt: ignored

In [None]:
# Federated-training hyperparameters.
lr = 0.01            # client learning rate
comms_round = 100    # number of global communication rounds
loss = 'categorical_crossentropy'
metrics = ['accuracy']

# Client optimizer used by the training loop. The legacy SGD keeps the `decay`
# argument (time-based learning-rate decay per optimizer step).
optimizer_1 = tf.keras.optimizers.legacy.SGD(learning_rate = lr,
                                             decay = lr / comms_round,
                                             momentum = 0.9)
# Alternative optimizer, not used below. `lr=` is a deprecated alias that newer
# Keras versions warn about (leaving the default rate) or reject outright, so
# the rate is passed as `learning_rate=`.
optimizer_2 = Adam(learning_rate = lr,
                   amsgrad = False)



In [None]:
def weight_scaling_factor(clients_trn_data, client_name):
    """Return ``client_name``'s share of all training batches (its FedAvg weight).

    ``clients_trn_data`` maps client name -> batched ``tf.data.Dataset`` whose
    cardinality is known. Scaling both the local and global counts by one batch
    size would cancel out, so batches are counted directly from the dataset
    cardinality without reading any data. A final partial batch counts as a
    full one.

    Raises
    ------
    ValueError
        If a dataset's cardinality is unknown or infinite (negative), or if the
        clients hold no batches at all.
    """
    def _num_batches(name, dataset):
        # Dataset.cardinality() is -1 (infinite) or -2 (unknown) when it cannot
        # be determined statically.
        count = int(dataset.cardinality().numpy())
        if count < 0:
            raise ValueError('dataset for {} has no finite known cardinality ({})'.format(name, count))
        return count

    local_count = _num_batches(client_name, clients_trn_data[client_name])
    global_count = sum(_num_batches(name, ds) for name, ds in clients_trn_data.items())
    if global_count == 0:
        raise ValueError('clients hold no training batches')
    return local_count / global_count

In [None]:
def scale_model_weights(weight, scalar):
    """Return a new list with every per-layer weight in ``weight`` multiplied by ``scalar``."""
    return [scalar * layer for layer in weight]

In [None]:
def sum_scaled_weights(scaled_weight_list):
    """Sum several clients' pre-scaled weight lists layer by layer.

    ``scaled_weight_list`` holds one list of per-layer weights per client. When
    each client's weights were already multiplied by its share of the data, this
    sum is the FedAvg weighted average. Returns one summed tensor per layer.
    """
    # zip(*...) regroups the weights per layer: one tuple per layer, one entry per client.
    return [tf.math.reduce_sum(per_layer, axis = 0)
            for per_layer in zip(*scaled_weight_list)]

In [None]:
def test_model(X_test, Y_test, model, comm_round, from_logits = False):
    """Evaluate ``model`` on one batch, print its accuracy and return accuracy and loss.

    Parameters
    ----------
    X_test, Y_test :
        Inputs and binarized (one column per class) targets of one evaluation batch.
    model :
        Keras model whose ``predict`` output is scored against ``Y_test``.
    comm_round : int
        Communication round number; used only in the printed message.
    from_logits : bool
        Set True only if the model outputs raw logits. ``Model_`` ends in an
        explicit softmax layer, so by default predictions are treated as
        probabilities. NOTE(review): confirm the network loaded by ``Model_2``
        also outputs probabilities.

    Returns
    -------
    (acc, loss)
        Accuracy in [0, 1] and the categorical cross-entropy of the batch.
    """
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits = from_logits)
    preds = model.predict(X_test)
    loss = cce(Y_test, preds)
    # sklearn's argument order is accuracy_score(y_true, y_pred).
    acc = accuracy_score(tf.argmax(Y_test, axis = 1), tf.argmax(preds, axis = 1))
    print('global_accuracy: {:.3%} | comm_round: {}'.format(acc, comm_round))
    return acc, loss

In [None]:
# Federated averaging: each round, every client trains a copy of the current
# global model on its own shard for one epoch. The global model then becomes
# the sum of the client weights, each scaled by the client's share of batches.
#
# NOTE(review): Model_2.build ignores (16384, 4) and returns the loaded
# MobileViT model, while the client datasets hold flattened 16384-value vectors.
# This input mismatch is the likely cause of the ValueError this cell raised;
# Model_.build creates an MLP that accepts this input shape.
smlp_global = Model_2()
global_model = smlp_global.build(16384, 4)

# The scaling factors depend only on the fixed batched datasets, so compute them
# once instead of re-scanning every client dataset for every client each round.
scaling_factors = {name: weight_scaling_factor(clients_batched, name)
                   for name in clients_batched}

for comm_round in range(comms_round):
    global_weights = global_model.get_weights()

    scaled_local_weight_list = list()

    # Visit clients in a random order each round.
    client_names = list(clients_batched.keys())
    random.shuffle(client_names)

    for client in client_names:
        smlp_local = Model_2()
        local_model = smlp_local.build(16384, 4)
        # NOTE(review): the same optimizer_1 instance is shared by every client
        # model, so its step counter (and therefore its `decay` schedule)
        # presumably keeps running across clients and rounds.
        local_model.compile(loss = loss,
                            optimizer = optimizer_1,
                            metrics = metrics)

        # Every client starts from the current global weights.
        local_model.set_weights(global_weights)

        local_model.fit(clients_batched[client], epochs = 1, verbose = 0)

        scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factors[client])
        scaled_local_weight_list.append(scaled_weights)

        # Reset Keras global state before building the next client model.
        K.clear_session()

    # The scaling factors sum to 1, so summing the scaled weights averages them.
    average_weights = sum_scaled_weights(scaled_local_weight_list)

    # Update the global model.
    global_model.set_weights(average_weights)

    # Use fresh names so the module-level X_test / Y_test are not overwritten
    # by the evaluation batch tensors.
    for (x_eval, y_eval) in test_batched:
        global_acc, global_loss = test_model(x_eval, y_eval, global_model, comm_round)
        



ValueError: ignored