<a href="https://colab.research.google.com/github/avocadopelvis/BTP/blob/main/brain-tumor-segmentation/fl-bts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [35]:
# !pip install nilearn
# !pip install keras_unet_collection

In [4]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import seaborn as sns
import matplotlib.pyplot as plt

import random
import os
import cv2
import glob 

# PIL adds image processing capabilities to your Python interpreter.
import PIL
from PIL import Image, ImageOps

# The shutil module offers high-level file operations such as copying, moving, and removing files.
import shutil

# skimage is a collection of algorithms for image processing and computer vision.
from skimage import data
from skimage.util import montage
import skimage.transform as skTrans
from skimage.transform import rotate
from skimage.transform import resize


# NEURAL IMAGING
import nilearn as nl
import nibabel as nib # access a multitude of neuroimaging data formats
# import nilearn.plotting as nlplt
# import gif_your_nifti.core as gif2nif


# ML Libraries
import keras
import keras.backend as K
from keras.callbacks import CSVLogger
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import plot_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.layers.experimental import preprocessing

# make numpy printouts easier to read
np.set_printoptions(precision = 3, suppress = True)

import warnings
warnings.filterwarnings('ignore')

In [5]:
# dataset locations under /content/drive
# (the trailing slash is kept: case paths are later built by plain string concatenation)
train_data = (
    "/content/drive/MyDrive/BTP/BRAIN TUMOR SEGMENTATION/DATASET/BraTS 2020/"
    "BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"
)
valid_data = (
    "/content/drive/MyDrive/BTP/BRAIN TUMOR SEGMENTATION/DATASET/BraTS 2020/"
    "BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/"
)

In [6]:
# list of case directories, sorted so that the seeded train/val/test split below is
# reproducible (os.scandir yields entries in arbitrary order).
# BraTS20_Training_355 is skipped since it has an ill formatted name for its seg.nii file;
# filtering by name (instead of list.remove) does not fail if that case is absent.
train_val_directories = sorted(
    f.path
    for f in os.scandir(train_data)
    if f.is_dir() and f.name != 'BraTS20_Training_355'
)

# function to convert list of paths into IDs
def pathListIntoIDs(dirList):
  """Return the last '/'-separated component of every path in ``dirList``.

  A path without any '/' is returned unchanged; a path ending in '/' maps to ''.
  """
  return [path.rsplit('/', 1)[-1] for path in dirList]

# case IDs derived from the directory paths
ids = pathListIntoIDs(train_val_directories)

# first hold out 20% of all cases for validation ...
train_test_ids, val_ids = train_test_split(
    ids,
    test_size = 0.2,
    random_state = 42,
)
# ... then hold out 15% of the remainder for testing; what is left is the training set
train_ids, test_ids = train_test_split(
    train_test_ids,
    test_size = 0.15,
    random_state = 42,
)

# train_ids size is 249
# break train_ids into 3 parts (one per client)
# size = 249/3 = 83

In [7]:
# segmentation label -> human-readable name
SEGMENT_CLASSES = dict(enumerate((
    'NOT TUMOR',
    'NECROTIC/CORE',  # or NON-ENHANCING TUMOR CORE
    'EDEMA',
    'ENHANCING',      # stored as 4 in the raw masks -> converted into 3 later
)))

# Slice window used per volume (there are 155 slices per volume):
# slices VOLUME_START_AT .. VOLUME_START_AT + VOLUME_SLICES - 1, i.e. 22..121, are used,
# so the first 22 and the last 33 slices are skipped.
VOLUME_START_AT = 22  # first slice of volume that we will include
VOLUME_SLICES = 100   # number of consecutive slices taken from each volume
IMG_SIZE = 128        # side length (pixels) that slices are resized to

# override keras sequence DataGenerator class
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that turns case IDs into batches of 2-D slices.

    For every case ID, the ``<id>_flair.nii``, ``<id>_t1ce.nii`` and ``<id>_seg.nii``
    volumes are read from ``train_data/<id>/``. From each volume, ``VOLUME_SLICES``
    slices along the third axis are taken, starting at ``VOLUME_START_AT``.

    One batch therefore holds ``batch_size * VOLUME_SLICES`` samples:
      * ``X``: ``(n_samples, *dim, n_channels)``; channel 0 is FLAIR, channel 1 is T1ce,
        scaled by the batch maximum.
      * ``Y``: ``(n_samples, *dim, 4)``; one-hot mask (raw label 4 is remapped to 3)
        resized to ``dim``.

    Args:
        list_IDs: case IDs (directory names below ``train_data``).
        dim: (height, width) every slice is resized to.
        batch_size: number of cases (volumes) per batch, not number of slices.
        n_channels: number of input channels allocated in ``X``; only channels
            0 and 1 are filled.
        shuffle: reshuffle the case order at construction and after each epoch.
    """

    def __init__(self, list_IDs, dim=(IMG_SIZE,IMG_SIZE), batch_size = 1, n_channels = 2, shuffle=True):
        # Let the base class set up whatever state it needs.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        # builds self.indexes (and shuffles it when requested)
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Return the ``(X, Y)`` pair for batch number ``index``."""
        # positions (into list_IDs) of the cases belonging to this batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        Batch_ids = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(Batch_ids)

    def on_epoch_end(self):
        """Rebuild the case order; shuffled when ``self.shuffle`` is true."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, Batch_ids):
        """Load, slice, resize and normalise the volumes of the given case IDs.

        Returns:
            ``(X, Y)`` as described in the class docstring.
        """
        n_samples = self.batch_size*VOLUME_SLICES
        X = np.zeros((n_samples, *self.dim, self.n_channels))
        # raw masks are buffered at their native in-plane size (240x240) and resized later
        y = np.zeros((n_samples, 240, 240))

        # cv2.resize expects (width, height) whereas self.dim is (height, width)
        cv_size = (self.dim[1], self.dim[0])

        for c, case_id in enumerate(Batch_ids):
            case_path = os.path.join(train_data, case_id)

            flair = nib.load(os.path.join(case_path, f'{case_id}_flair.nii')).get_fdata()
            ce = nib.load(os.path.join(case_path, f'{case_id}_t1ce.nii')).get_fdata()
            seg = nib.load(os.path.join(case_path, f'{case_id}_seg.nii')).get_fdata()

            for j in range(VOLUME_SLICES):
                row = j + VOLUME_SLICES*c      # position of this slice inside the batch
                z = j + VOLUME_START_AT        # slice index inside the volume
                X[row,:,:,0] = cv2.resize(flair[:,:,z], cv_size)
                X[row,:,:,1] = cv2.resize(ce[:,:,z], cv_size)
                y[row] = seg[:,:,z]

        # Generate masks: remap label 4 -> 3 so the labels are the contiguous range 0..3
        y[y==4] = 3
        # tf.one_hot expects integer indices; the label buffer is floating point
        mask = tf.one_hot(y.astype(np.int32), 4)
        Y = tf.image.resize(mask, self.dim)

        # scale inputs by the batch maximum; an all-zero batch is returned as-is
        # instead of being divided by zero (which would fill X with NaN)
        peak = np.max(X)
        if peak == 0:
            return X, Y
        return X/peak, Y

In [22]:
# function to take in data and return a dictionary with client names as keys and values as data shards
def create_client(data, num_clients, initial = 'client'):
  """Split ``data`` into ``num_clients`` contiguous shards keyed by client name.

  Client names are ``'<initial>_1'`` .. ``'<initial>_<num_clients>'``. Shard sizes
  differ by at most one element: when ``len(data)`` is not divisible by
  ``num_clients`` the first ``len(data) % num_clients`` clients receive one extra
  element. If there are fewer elements than clients, the trailing clients get an
  empty shard.

  Args:
      data: sliceable sequence (e.g. list of case IDs).
      num_clients: number of shards to create; must be positive.
      initial: prefix used for the client names.

  Returns:
      dict mapping client name -> shard, in client order.

  Raises:
      ValueError: if ``num_clients`` is not positive.
  """
  if num_clients <= 0:
    raise ValueError('num_clients must be a positive integer')

  # create a list of client names
  client_names = ['{}_{}'.format(initial, i+1) for i in range(num_clients)]

  # base shard size, plus how many shards must take one leftover element
  base, extra = divmod(len(data), num_clients)

  shards = []
  start = 0
  for i in range(num_clients):
    stop = start + base + (1 if i < extra else 0)
    shards.append(data[start:stop])
    start = stop

  return dict(zip(client_names, shards))

def weight_scaling_factor(data):
    """Return ``len(data)`` as a fraction of ``len(train_ids)`` (module-level list)."""
    local_count = len(data)
    total_count = len(train_ids)
    return local_count/total_count


def scale_model_weights(weight, scalar):
    '''Return a new list in which every entry of ``weight`` is multiplied by ``scalar``.'''
    return [scalar * layer for layer in weight]



def sum_scaled_weights(scaled_weight_list):
    '''Sum the clients' weight lists position by position.

    ``scaled_weight_list`` holds one list of weights per client. The result has one
    entry per position: the sum over all clients of the weight at that position.
    When the inputs were already scaled, this sum is the scaled average of the weights.
    '''
    return [
        tf.math.reduce_sum(same_layer_weights, axis=0)
        for same_layer_weights in zip(*scaled_weight_list)
    ]


# function to evaluate the model on test data and print the current round and metrics
def evaluate_model(data, model, round): 
  """Evaluate ``model`` on the given case IDs and print loss and accuracy.

  Args:
      data: list of case IDs; wrapped in a ``DataGenerator`` for evaluation.
      model: compiled Keras model whose second ``evaluate`` result is accuracy.
      round: identifier of the current communication round (only printed).
  """
  test_generator = DataGenerator(data)
  # The generator already yields complete batches, so no batch_size is passed:
  # Keras documents that it must not be specified for Sequence inputs.
  results = model.evaluate(test_generator, verbose = 1)
  # results[0] is the loss; results[1] is the first metric, reported as a percentage
  loss, accuracy = results[0], results[1]*100
  print(f'round: {round} | loss: {loss} | accuracy: {accuracy:.2f}%')

In [11]:
# partition the training IDs across three clients
clients = create_client(train_ids, num_clients = 3)
# generator over the validation cases
valid_generator = DataGenerator(list_IDs = val_ids)

In [25]:
# HYPERPARAMETERS
from keras_unet_collection import losses

# plain string identifier (a trailing comma here would turn it into a 1-tuple)
loss = "categorical_crossentropy"
learning_rate = 0.001
optimizer = keras.optimizers.Adam(learning_rate = learning_rate)
metrics = ['accuracy', losses.dice]

# add callback for training process
# NOTE: ReduceLROnPlateau monitors 'val_loss', so it only has an effect when
# validation data is passed to fit().
callbacks = [
#     keras.callbacks.EarlyStopping(monitor='loss', min_delta=0,
#                               patience=2, verbose=1, mode='auto'),
      keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, 
                              patience=2, min_lr=0.000001, verbose=1) 
#  keras.callbacks.ModelCheckpoint(filepath = 'model_.{epoch:02d}-{val_loss:.6f}.m5',
#                             verbose=1, save_best_only=True, save_weights_only = True)
    ]

# number of global epochs
rounds = 5
batch_size = 32

In [28]:
from keras_unet_collection import models

# Alternative architectures (same arguments):
#   U-Net:   models.unet_2d((IMG_SIZE, IMG_SIZE, 2), [32, 64, 128, 256, 512], 4)
#   U-Net++: models.unet_plus_2d((IMG_SIZE, IMG_SIZE, 2), [32, 64, 128, 256, 512], 4)

# Attention U-Net
model = models.att_unet_2d(
    (IMG_SIZE, IMG_SIZE, 2),    # input shape: IMG_SIZE x IMG_SIZE slices with 2 channels
    [32, 64, 128, 256, 512],    # presumably the filter count per level - confirm against the library docs
    4,                          # presumably the number of output labels (4 classes here) - confirm
)

In [31]:
# initialize global model
# global_model = build_unet(input_layer, 'he_normal', 0.2)
global_model = model
global_model.compile(
        loss = loss,
        optimizer = optimizer,
        metrics = metrics
        )
print("Begin Training")
# commence global training loop
# (the loop variable is not called `round`, to avoid shadowing the builtin)
for comm_round in range(1, rounds+1):
  print(f'\nRound: {comm_round}')

  # snapshot of the global weights; every client starts its local training from these
  global_weights = global_model.get_weights()

  # collects each client's weights after scaling
  scaled_local_weight_list = list()

  # visit the clients in random order
  client_names = list(clients.keys())
  random.shuffle(client_names)

  for count, client in enumerate(client_names, start = 1):
    print(f'Client {count}')
    # NOTE: local_model is the same object as global_model. This is safe only because
    # the global weights were snapshotted above and are restored before each fit.
    local_model = model
    local_model.compile(
        loss = loss,
        optimizer = optimizer,
        metrics = metrics
        )

    # set local model weight to the weight of the global model
    local_model.set_weights(global_weights)

    # get client data and pass it through a data generator
    client_ids = clients[client]
    data = DataGenerator(client_ids)

    # fit local model with client's data
    local_model.fit(data, epochs = 1, steps_per_epoch = len(data), verbose = 1) #callbacks = callbacks, validation_data = valid_generator)

    # Scale the client's weights by its share of the training cases. The share is
    # computed from the number of case IDs, not from the generator's batch count.
    scaling_factor = weight_scaling_factor(client_ids)
    scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)

    # The scaled weights must be collected: summing unscaled weights would make the
    # global weights the SUM over clients instead of their weighted average.
    scaled_local_weight_list.append(scaled_weights)

    # clear session after each client
    # NOTE(review): the model object is reused, so this likely frees little - confirm.
    K.clear_session()

  # the sum of the scaled weights is the weighted average over all local models
  average_weights = sum_scaled_weights(scaled_local_weight_list)

  # update global model
  global_model.set_weights(average_weights)

  evaluate_model(test_ids, global_model, comm_round)

print('\nTraining Done!')

In [15]:
# client_names= list(clients.keys())
# for client in client_names:
#   print(clients[client])

In [33]:
# client_data = clients['client_1']
# data = DataGenerator(client_data)
# valid_generator = DataGenerator(val_ids)

# local_model = model
# local_model.compile(
#       loss = loss,
#       optimizer = optimizer,
#       metrics = metrics
#       )

# local_model.fit(data, epochs=1, steps_per_epoch=len(data), verbose = 1) # callbacks = callbacks, validation_data = valid_generator)