<a href="https://colab.research.google.com/github/avocadopelvis/BTP/blob/main/lung-segmentation/fl-ls.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
!pip install keras_unet_collection

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting keras_unet_collection
  Downloading keras_unet_collection-0.1.13-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 1.8 MB/s 
[?25hInstalling collected packages: keras-unet-collection
Successfully installed keras-unet-collection-0.1.13


In [65]:
# load libraries
import os
import random
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
import cv2
from cv2 import imread, createCLAHE
from tqdm import tqdm
from glob import glob

from sklearn.model_selection import train_test_split

In [5]:
# dataset path
# One root for the dataset, so the image and mask folders cannot drift apart
# (os.path.join with a single argument does nothing).
dataset_dir = "/content/drive/MyDrive/BTP/LUNG SEGMENTATION/DATASET"
image_path = os.path.join(dataset_dir, "CXR_png")
mask_path = os.path.join(dataset_dir, "masks")

#### The dataset has 800 images but only 704 masks, so we start from the masks and look up the matching image for each one (a one-to-one correspondence from masks to images).

In [6]:
# os.listdir returns entries in arbitrary order; sort them so every run sees
# the files (and everything derived from their order) identically.
image = sorted(os.listdir(image_path))
mask = sorted(os.listdir(mask_path))
# file name of each mask with the ".png" part removed
masks = [fName.split(".png")[0] for fName in mask]
# name of the image each mask belongs to: the part before "_mask"
images = [fName.split("_mask")[0] for fName in masks]

In [7]:
# keep only the mask names that contain the "mask" marker
check = list(filter(lambda name: "mask" in name, masks))
print("Total mask that has modified name:", len(check))

Total mask that has modified name: 566


In [8]:
# Files that exist under the same name in both folders form the test subset.
# A set has no stable iteration order, so store the names as a sorted list to
# load them in the same order on every run.
testing_files = sorted(set(os.listdir(image_path)) & set(os.listdir(mask_path)))
# Mask names containing "mask" form the training subset.
training_files = check

# function to get data
def getData(X_shape, flag = "test"):
    """Load the image/mask pairs of one subset, resized to a square.

    Parameters
    ----------
    X_shape : int
        Side length in pixels; every image and mask is resized to
        (X_shape, X_shape).
    flag : str
        "test" reads the names in `testing_files` (same file name in the image
        and the mask folder). "train" reads the names in `training_files`
        (mask names without extension; the image name is the part before
        "_mask").

    Returns
    -------
    tuple of (list, list)
        Images and masks in matching order. Each entry is the first channel
        of the resized array OpenCV returned.

    Raises
    ------
    ValueError
        If `flag` is neither "test" nor "train".
    OSError
        If OpenCV cannot read one of the files.
    """
    def read_resized(path):
        # cv2.imread reports failure by returning None; fail here with the
        # offending path instead of with an obscure error inside cv2.resize.
        pixels = cv2.imread(path)
        if pixels is None:
            raise OSError(f"could not read image file: {path}")
        return cv2.resize(pixels,(X_shape,X_shape))[:,:,0]

    if flag == "test":
        pairs = [(os.path.join(image_path,i), os.path.join(mask_path,i))
                 for i in testing_files]
    elif flag == "train":
        pairs = [(os.path.join(image_path,i.split("_mask")[0]+".png"),
                  os.path.join(mask_path,i+".png"))
                 for i in training_files]
    else:
        # previously an unknown flag fell through and returned None
        raise ValueError(f'flag must be "test" or "train", got {flag!r}')

    im_array = []
    mask_array = []
    for im_file, mask_file in tqdm(pairs):
        im_array.append(read_resized(im_file))
        mask_array.append(read_resized(mask_file))

    return im_array,mask_array
    
    
# function to perform sanity check
def plotMask(X,y):
    """Sanity check: show up to six image/mask pairs on one 2x3 grid.

    Each panel is one image with its mask stacked to the right of it.

    Parameters
    ----------
    X : sequence of arrays
        Images.
    y : sequence of arrays
        Masks; y[i] is stacked horizontally next to X[i], so the two must be
        compatible for np.hstack.

    Returns
    -------
    None
        The figure is displayed with plt.show(). Nothing is drawn when there
        are no samples.
    """
    # Do not assume six samples are available.
    count = min(6, len(X), len(y))
    if count == 0:
        return

    # One figure for the whole grid. Creating a figure per row left the first
    # row of the second figure empty.
    plt.figure(figsize=(25,10))
    for i in range(count):
        combined = np.hstack((X[i], y[i]))
        plt.subplot(2,3,i+1)
        plt.imshow(combined)

    plt.show()

In [9]:
# load data
# side length, in pixels, that every image and mask is resized to
dim = 512
X_train, y_train = getData(dim, flag = "train")
X_test, y_test = getData(dim, flag = "test")

100%|██████████| 566/566 [13:19<00:00,  1.41s/it]
100%|██████████| 138/138 [03:31<00:00,  1.53s/it]


In [10]:
# combine datasets and use it as a unified dataset
# turn every list into one array of shape (samples, dim, dim, 1)
X_train, y_train, X_test, y_test = (
    np.array(stack).reshape(len(stack), dim, dim, 1)
    for stack in (X_train, y_train, X_test, y_test)
)

# every image needs a mask of the same shape
assert X_train.shape == y_train.shape
assert X_test.shape == y_test.shape

# merge both subsets; they are split again later
images = np.concatenate((X_train, X_test), axis = 0)
masks = np.concatenate((y_train, y_test), axis = 0)

In [42]:
# split the data into train, validation and test sets
# both splits hold out 20% and use the same seed
split_options = dict(test_size = 0.2, random_state = 42)

# images are shifted and scaled by 127; masks are thresholded at 127 into 0/1 floats
train_vol, validation_vol, train_seg, validation_seg = train_test_split(
    (images-127.0)/127.0,
    (masks>127).astype(np.float32),
    **split_options)

# the remaining training part gives up another 20% as the test set
train_vol, test_vol, train_seg, test_seg = train_test_split(
    train_vol, train_seg, **split_options)

In [12]:
# train = 450
# client = 450/3 = 150

In [75]:
from keras.models import *
from keras.layers import *
# from keras.optimizers import *
from tensorflow.keras.optimizers import *
from keras import backend as keras
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, LearningRateScheduler

def dice_coef(y_true, y_pred):
    """Dice coefficient of two tensors, computed on their flattened values.

    1 is added to both the numerator and the denominator, so the ratio stays
    defined when both tensors sum to zero.
    """
    truth = keras.flatten(y_true)
    prediction = keras.flatten(y_pred)
    overlap = keras.sum(truth * prediction)
    numerator = 2. * overlap + 1
    denominator = keras.sum(truth) + keras.sum(prediction) + 1
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: the negated Dice coefficient of the two tensors."""
    score = dice_coef(y_true, y_pred)
    return -score

def unet(input_size=(256,256,1)):
    """Build a U-Net as a Keras functional model.

    The encoder has four stages (32, 64, 128, 256 filters), each made of two
    3x3 ReLU convolutions followed by 2x2 max pooling. A 512-filter stage sits
    at the bottom. The decoder mirrors the encoder: a 2x2 stride-2 transposed
    convolution, concatenation with the matching encoder output on axis 3,
    then two 3x3 ReLU convolutions. A final 1x1 sigmoid convolution produces
    a single output channel.

    Parameters
    ----------
    input_size : tuple
        Shape passed to the Input layer.

    Returns
    -------
    Model
        The uncompiled model.
    """
    def double_conv(tensor, filters):
        # two consecutive 3x3 'same'-padded ReLU convolutions
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_size)

    # contracting path: remember each stage's output for the skip connections
    skips = []
    features = inputs
    for filters in (32, 64, 128, 256):
        features = double_conv(features, filters)
        skips.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # bottom of the "U"
    features = double_conv(features, 512)

    # expanding path: upsample, join with the matching encoder stage, convolve
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(features)
        features = concatenate([upsampled, skip], axis=3)
        features = double_conv(features, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(features)

    return Model(inputs=[inputs], outputs=[outputs])

In [66]:
# function to take in data and return a dictionary with client names as keys and values as data shards
def create_client(x, y, num_clients, initial = 'client'):
  """Split a dataset into one contiguous shard per client.

  Exactly `num_clients` shards are produced. When the number of samples is
  not divisible by `num_clients`, the first `len(x) % num_clients` clients
  receive one extra sample, so no sample is dropped. Slicing by a fixed size
  used to create an extra leftover shard and trip the shard-count assertion.

  Parameters
  ----------
  x, y : sliceable sequences of equal length
      Inputs and targets; both are cut at the same positions.
  num_clients : int
      Number of clients (at least 1, at most len(x)).
  initial : str
      Prefix of the client names, which are '<initial>_1', '<initial>_2', ...

  Returns
  -------
  dict
      Maps each client name to [x_shard, y_shard].

  Raises
  ------
  ValueError
      If `num_clients` is smaller than 1 or larger than the number of
      samples, or if `x` and `y` differ in length.
  """
  if num_clients < 1:
    raise ValueError("num_clients must be at least 1")
  if len(x) != len(y):
    raise ValueError("x and y must contain the same number of samples")
  if len(x) < num_clients:
    raise ValueError("need at least one sample per client")

  # create a list of client names
  client_names = ['{}_{}'.format(initial, i+1) for i in range(num_clients)]

  # base shard size, and how many shards get one sample more
  size, extra = divmod(len(x), num_clients)

  shards = {}
  start = 0
  for position, name in enumerate(client_names):
    stop = start + size + (1 if position < extra else 0)
    shards[name] = [x[start:stop], y[start:stop]]
    start = stop

  return shards


def weight_scaling_factor(data, train):
    """Return len(data) / len(train): the share of `train` that `data` accounts for."""
    shard_size = len(data)
    total_size = len(train)
    return shard_size/total_size


def scale_model_weights(weight, scalar):
    '''Return a new list holding every entry of `weight` multiplied by `scalar`.'''
    return [scalar * layer for layer in weight]



def sum_scaled_weights(scaled_weight_list):
    '''Sum the given weight lists position by position.

    `scaled_weight_list` holds one list of weights per client. The entries at
    the same position are summed with tf.math.reduce_sum over axis 0. If every
    list was scaled beforehand, the sums are the scaled average of the weights.
    '''
    return [
        tf.math.reduce_sum(same_position, axis=0)
        for same_position in zip(*scaled_weight_list)
    ]


# function to evaluate the model on test data and print the current round and metrics
def evaluate_model(data, model, round):
  """Evaluate `model` on `data` and print the round number with two metrics.

  `data` is wrapped in a DataGenerator and passed to model.evaluate together
  with the notebook-level `batch_size`. The first returned value is printed
  as the loss, the second (times 100) as the accuracy.
  """
  # NOTE(review): DataGenerator is not defined in the visible part of the notebook - confirm it exists.
  generator = DataGenerator(data)
  scores = model.evaluate(generator, batch_size = batch_size, verbose = 1)
  # NOTE(review): assumes the first compiled metric is an accuracy - verify against compile().
  loss = scores[0]
  accuracy = scores[1]*100
  print(f'round: {round} | loss: {loss} | accuracy: {accuracy:.2f}%')

In [46]:
# create clients
clients = create_client(train_vol, train_seg, num_clients = 3)

In [67]:
# one = clients['client_1']
# x = one[0]
# y = one[1]
# len(y)

In [68]:
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau

# file the checkpoint callback writes weights to
weight_path = "{}_weights.best.hdf5".format("cxr_reg")

# all three callbacks watch the validation loss, where lower is better

# save the weights (not the full model) whenever the monitored value improves
checkpoint = ModelCheckpoint(
    weight_path,
    monitor = "val_loss",
    mode = "min",
    save_best_only = True,
    save_weights_only = True,
    verbose = 1,
)

# halve the learning rate after 3 epochs without improvement, down to 1e-6
reduceLROnPlat = ReduceLROnPlateau(
    monitor = "val_loss",
    mode = "min",
    factor = 0.5,
    patience = 3,
    cooldown = 2,
    min_delta = 0.0001,
    min_lr = 1e-6,
    verbose = 1,
)

# stop after 15 epochs without improvement
# (probably needs to be more patient, but run time is limited)
early = EarlyStopping(monitor = "val_loss", mode = "min", patience = 15)

callbacks = [checkpoint, early, reduceLROnPlat]

In [70]:
# Hyperparameters
rounds = 5                                  # passes of the global training loop
epochs = 1                                  # epochs of each client fit
batch_size = 16
optimizer = Adam(learning_rate = 0.0002)
loss = [dice_coef_loss]
metrics = [dice_coef, "binary_accuracy"]

In [76]:
from keras_unet_collection import models

# model = models.unet_2d((512, 512, 1), [32, 64, 128, 256, 512], 2)
# build the network for the resolution the data was resized to (`dim`)
# instead of repeating the literal 512, so the two cannot get out of sync
model = unet(input_size = (dim, dim, 1))

In [1]:
# initialize global model
# NOTE: `global_model` and the per-client `local_model` below are the same
# Keras object as `model`. Every client is reset to the round's global weights
# with set_weights(), so the object only has to be compiled once.
global_model = model
global_model.compile(
        loss = loss,
        optimizer = optimizer,
        metrics = metrics
        )

print("Begin Training")
# commence global training loop
# (the counter is not called `round`, which would shadow the builtin for the
# rest of the notebook)
for round_num in range(1, rounds+1):
  print(f'\nRound: {round_num}')

  # snapshot of the global weights that every client starts this round from
  global_weights = global_model.get_weights()

  # local weights of each client, already scaled by the client's data share
  scaled_local_weight_list = list()

  # visit the clients in random order
  client_names = list(clients.keys())
  random.shuffle(client_names)

  for count, client in enumerate(client_names, start = 1):
    print(f'Client {count}')
    local_model = model

    # start the client from the weights of the global model
    local_model.set_weights(global_weights)

    # get client data
    x = clients[client][0]
    y = clients[client][1]

    # Fit on the client's data. batch_size comes from the hyperparameter cell,
    # and validation data is passed because all callbacks monitor 'val_loss'.
    local_model.fit(x, y,
                    batch_size = batch_size,
                    epochs = epochs,
                    verbose = 1,
                    callbacks = callbacks,
                    validation_data = (validation_vol, validation_seg))

    # Scale the weights by this client's share of the training data. The
    # scaled weights must be collected: summing unscaled weights would set the
    # global model to the SUM of the client models instead of their average.
    scaling_factor = weight_scaling_factor(x, train_vol)
    scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
    scaled_local_weight_list.append(scaled_weights)

  # the sum of the scaled weights is the weighted average over all clients
  average_weights = sum_scaled_weights(scaled_local_weight_list)

  # update global model
  global_model.set_weights(average_weights)

  # TODO: evaluate the global model on held-out data after each round

print('\nTraining Done!')

In [18]:
# history = model.fit(x = train_vol,
#                     y = train_seg,
#                     batch_size = batch_size,
#                     epochs = epochs,
#                     validation_data =(test_vol, test_seg) ,
#                     callbacks = callbacks)