<a href="https://colab.research.google.com/github/AtulyaMS/CYML/blob/main/SMOTE_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
# Keep printed arrays compact: two decimals, no scientific notation.
np.set_printoptions(suppress=True, precision=2)

In [None]:
import shutil
import tensorflow as tf
import os
import gzip
import tarfile

import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter
from tensorflow.keras.callbacks import EarlyStopping
from imblearn.over_sampling import SMOTE
import seaborn as sns

In [None]:
!pip install scikit-plot

In [None]:
from sklearn.model_selection import cross_val_predict, train_test_split
import scikitplot as skplt
from sklearn.metrics import classification_report

In [None]:
IMG_SHAPE = (78, 110, 86)
# The 2-D mosaic is a 4x4 grid of slices, each IMG_SHAPE[1] x IMG_SHAPE[2].
IMG_2D_SHAPE = tuple(4 * dim for dim in IMG_SHAPE[1:])
N_CLASSES = 3

In [None]:
from google.colab import drive

# Mount Google Drive at /content/gdrive, remounting if it is already mounted.
drive.mount(mountpoint='/content/gdrive', force_remount=True)

In [None]:
root_dir = "/content/gdrive/MyDrive/"
base_dir = os.path.join(root_dir, 'Final_model_training')
# Every relative path used below is resolved against the project folder.
os.chdir(base_dir)

In [None]:
def resample_img(itk_image, out_spacing=(2.0, 2.0, 2.0)):
    ''' Resample an image to the given voxel spacing (2-mm isotropic by default).

        The output keeps the input's origin and direction. Its size is
        chosen so it spans the same physical extent at the new spacing.
        Intensities are interpolated with a B-spline.

        Parameters:
            itk_image -- Image in SimpleITK format, not a numpy array
            out_spacing -- Target spacing of each voxel, one value per
                           image axis (list or tuple)

        Returns:
            Resulting image in SimpleITK format, not a numpy array

        Raises:
            ValueError -- if out_spacing does not have one entry per axis
    '''
    original_spacing = itk_image.GetSpacing()
    original_size = itk_image.GetSize()

    if len(out_spacing) != len(original_size):
        raise ValueError(
            f"out_spacing has {len(out_spacing)} values but the image has "
            f"{len(original_size)} dimensions")

    # New size per axis = old size * old spacing / new spacing, which keeps
    # the physical extent (size * spacing) unchanged.
    out_size = [
        int(np.round(size * (spacing / new_spacing)))
        for size, spacing, new_spacing in zip(original_size, original_spacing, out_spacing)
    ]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    # Fill value for output voxels that map outside the input. This must be
    # an intensity; itk_image.GetPixelIDValue() is the pixel *type* enum
    # (e.g. sitkFloat32), so it is not a meaningful background value.
    resample.SetDefaultPixelValue(0)

    resample.SetInterpolator(sitk.sitkBSpline)

    return resample.Execute(itk_image)

In [None]:
def registrate(sitk_fixed, sitk_moving, bspline=False):
    ''' Register a moving image onto a fixed reference with SimpleElastix.

        An affine registration is always performed. When requested, a
        non-rigid B-spline stage is appended after it.

        Parameters:
            sitk_fixed -- Reference atlas (SimpleITK image)
            sitk_moving -- Image to be registered (SimpleITK image)
            bspline -- Also run non-rigid (B-spline) registration. Note: it
                       usually deforms the images and increases run time

        Returns:
            The registered moving image (SimpleITK image)
    '''
    stage_names = ["affine"] + (["bspline"] if bspline else [])

    parameter_maps = sitk.VectorOfParameterMap()
    for stage in stage_names:
        parameter_maps.append(sitk.GetDefaultParameterMap(stage))

    elastix = sitk.ElastixImageFilter()
    elastix.SetFixedImage(sitk_fixed)
    elastix.SetMovingImage(sitk_moving)
    elastix.SetParameterMap(parameter_maps)
    elastix.Execute()
    return elastix.GetResultImage()

In [None]:
def skull_strip_nii(original_img, destination_img, frac=0.2):
    ''' Skull-strip an image with FSL-BET and write the result to a new file.

        Uses FSL-BET
        (https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/BET/UserGuide#Main_bet2_options:)
        through the ``fsl.BET`` interface (presumably nipype's
        ``nipype.interfaces.fsl`` -- the import is not visible here).

        Parameters:
            original_img -- Path to the input .nii image
            destination_img -- Path requested for the skull-stripped image
            frac -- Fractional intensity threshold for BET (0-1); smaller
                    values give larger brain outline estimates

        Returns:
            The object returned by ``btr.run()``
    '''
    btr = fsl.BET()
    btr.inputs.in_file = original_img
    btr.inputs.frac = frac
    btr.inputs.out_file = destination_img
    return btr.run()

In [None]:
def slices_matrix_2D(img):
    ''' Transform a 3D MRI image into a 2D image, by taking 16 slices
        and placing them in a 4x4 two-dimensional grid.

        All 16 cuts are taken along the first axis (horizontal/axial view),
        from level 60 down to level 30 in steps of 2. Cut k is placed in
        grid row k % 4 and grid column k // 4, so the grid is filled
        column by column.

        Parameters:
          img -- np.ndarray with the 3D image; 4 cuts img[level] must fit
                 along each axis of an IMG_2D_SHAPE canvas

        Returns:
          np.ndarray -- The 2D mosaic repeated over 3 channels and then
                        fully transposed (``.T``), so its shape is
                        (IMG_2D_SHAPE[1], IMG_2D_SHAPE[0], 3)
    '''
    TOP = 60      # first (highest) level that is cut
    STEP = 2      # distance between consecutive cuts
    N_CUTS = 16   # levels 60, 58, ..., 30
    GRID = 4      # tiles per grid column

    # zeros rather than empty, so any area not covered by a cut is defined
    image_2D = np.zeros(IMG_2D_SHAPE)

    for cutting_time in range(N_CUTS):
        cut = img[TOP - STEP * cutting_time, :, :]
        rows, cols = cut.shape
        grid_col, grid_row = divmod(cutting_time, GRID)
        # copy the whole cut into its tile with one vectorized assignment
        image_2D[grid_row * rows:(grid_row + 1) * rows,
                 grid_col * cols:(grid_col + 1) * cols] = cut

    # return the final 2D image, with 3 channels
    # this is necessary for working with most pre-trained nets
    return np.repeat(image_2D[None, ...], 3, axis=0).T

In [None]:
def load_image_2D(abs_path):
    ''' Load a .nii image from its path and turn it into a 2D mosaic.

        The volume is read with SimpleITK, converted to a numpy array,
        whitened with ``preprocessing.whitening`` (presumably DLTK's
        zero-mean / unit-variance normalisation -- confirm), and passed to
        ``slices_matrix_2D``, which tiles 16 slices into a 4x4 grid.

        Parameters:
          abs_path -- Path to the .nii image, filename included

        Returns:
          np.ndarray -- The 3-channel 2D mosaic produced by slices_matrix_2D
    '''
    img = sitk.GetArrayFromImage(sitk.ReadImage(abs_path))
    img = preprocessing.whitening(img)
    return slices_matrix_2D(img)

In [None]:
def gz_extract(zipfile):
    ''' Decompress a gzip file next to itself and delete the compressed file.

        Parameters:
            zipfile -- Path to a gzip-compressed file, e.g. "dir/img.nii.gz"

        Returns:
            Path of the decompressed file: same directory, last extension
            removed (e.g. "dir/img.nii")

        Raises:
            ValueError -- if the name has no extension, since the output
                          would then overwrite (truncate) the input
    '''
    directory = os.path.dirname(zipfile)
    # Drop only the final extension: "img.nii.gz" -> "img.nii".
    file_name = os.path.basename(zipfile).rsplit('.', 1)[0]
    out_path = os.path.join(directory, file_name)
    if os.path.abspath(out_path) == os.path.abspath(zipfile):
        raise ValueError(f"cannot derive an output name distinct from {zipfile!r}")

    with gzip.open(zipfile, "rb") as f_in, open(out_path, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.remove(zipfile)  # delete the compressed original
    return out_path

In [None]:
def _bytes_feature(value):
    """Wrap a bytes/str value (or an eager tensor holding one) in a tf.train.Feature."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # unwrap the eager tensor into its numpy/bytes payload
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def _float_feature(value):
    """Wrap a single float / double in a tf.train.Feature holding a float_list."""
    payload = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=payload)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a tf.train.Feature holding an int64_list."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)

def serialize_array(array):
    """Serialize an array/tensor into a scalar string tensor via tf.io.serialize_tensor."""
    return tf.io.serialize_tensor(array)

In [None]:
def write_tfrecords(x, y, filename):
    ''' Write (image, label) pairs to a TFRecord file as tf.train.Example records.

        Each image is stored as a serialized tensor under 'image' and each
        label as an int64 under 'label' -- the schema _parse_image_function
        reads back.

        Parameters:
            x -- Iterable of image arrays
            y -- Iterable of integer class labels, aligned with x
            filename -- Output .tfrecords path
    '''
    # The context manager closes (and flushes) the writer; without it the
    # records may not be fully on disk when the file is read back.
    with tf.io.TFRecordWriter(filename) as writer:
        for image, label in zip(x, y):
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'image': _bytes_feature(serialize_array(image)),
                    # int() also accepts numpy integer labels
                    'label': _int64_feature(int(label)),
                }))
            writer.write(example.SerializeToString())


In [None]:
def _parse_image_function(example_proto):
    ''' Decode one serialized tf.train.Example written by write_tfrecords.

        Parameters:
            example_proto -- scalar string tensor holding one Example

        Returns:
            (image, label) -- image as a float64 tensor reshaped to
                              [344, 440, 3]; label one-hot encoded over
                              N_CLASSES classes
    '''
    image_feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    features = tf.io.parse_single_example(example_proto, image_feature_description)
    # out_type must match the dtype of the arrays that were serialized
    # (assumed float64, as produced by slices_matrix_2D -- confirm)
    image = tf.io.parse_tensor(features['image'], out_type=tf.double)
    image = tf.reshape(image, [344, 440, 3])

    label = tf.one_hot(features['label'], N_CLASSES)

    return image, label

In [None]:
def read_dataset(epochs, batch_size, filename, *, shuffle=True, drop_remainder=True):
    ''' Build a batched tf.data pipeline of (image, one-hot label) pairs
        from a TFRecord file written by write_tfrecords.

        Parameters:
            epochs -- Number of times the records are repeated
            batch_size -- Examples per batch (also sizes the shuffle buffer)
            filename -- Path (or list of paths) of the TFRecord file(s)
            shuffle -- Shuffle examples with a buffer of 10 * batch_size
            drop_remainder -- Drop the final batch if it is smaller than
                              batch_size

        Returns:
            tf.data.Dataset yielding (images, labels) batches
    '''
    dataset = tf.data.TFRecordDataset(filename)

    # Decode each serialized Example; without this step the pipeline would
    # yield raw serialized strings that the model cannot consume.
    dataset = dataset.map(_parse_image_function, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        # shuffle before repeat so each pass is a full permutation
        dataset = dataset.shuffle(buffer_size=10 * batch_size)
    dataset = dataset.repeat(epochs)
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    # prefetch last so input preparation overlaps with training
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

In [None]:
categories = ['CN', 'MCI', 'AD']
category_dict = {'CN': 0, 'MCI': 1, 'AD': 2}
nifti_files = []
labels = []

newpath = "./Nifti_files/"  # not used by any visible cell
# Walk each class folder (./CN/, ./MCI/, ./AD/) and record every file whose
# name contains '.nii', together with that folder's integer class label.
for category in categories:
    class_label = category_dict[category]
    for root, _dirs, files in os.walk(f"./{category}/"):
        found = [os.path.join(root, name) for name in files if '.nii' in name]
        nifti_files.extend(found)
        labels.extend([class_label] * len(found))

print(nifti_files[:5])
print(labels[:5])

In [None]:
# Load the reference atlas and resample it to the same 2 mm spacing as the
# subject scans that are registered against it.
atlas = resample_img(sitk.ReadImage('average305_t1_tal_lin_mask.nii'))

In [None]:
# For every raw scan: resample to 2 mm voxels, register to the atlas, write
# the registered image, then skull-strip it and decompress FSL's output.
os.makedirs("Registrated", exist_ok=True)
os.makedirs("Skull_Stripped", exist_ok=True)

for image in nifti_files:
    base_name = os.path.basename(image)
    registrated_path = f"Registrated/{base_name}_registrated.nii"
    stripped_path = f"Skull_Stripped/{base_name}_stripped.nii"

    res_image = resample_img(sitk.ReadImage(image))
    registrated_image = registrate(atlas, res_image, bspline=False)
    sitk.WriteImage(registrated_image, registrated_path)

    skull_strip_nii(registrated_path, stripped_path, frac=0.2)
    # The skull-stripped output is expected as a gzip-compressed
    # "<stripped_path>.gz"; unpack it to the plain .nii path.
    gz_extract(f"{stripped_path}.gz")

In [None]:
ss_images = os.listdir('Skull_Stripped')

# Build the 2-D mosaic of every skull-stripped scan and save it as .npy
# (np.save appends the .npy extension).
for stripped_name in ss_images:
    image_2d = load_image_2D(f"Skull_Stripped/{stripped_name}")
    np.save(f"Image_2d/{stripped_name}_2d", image_2d)

In [None]:
image_array = []
label_array = []
train_array = ["CN_TRAIN_Image2D", "MCI_TRAIN_Image2D", "AD_TRAIN_Image2D"]

for folder in train_array:
    # the class is encoded in the folder name: CN -> 0, MCI -> 1, otherwise AD -> 2
    folder_label = 0 if 'CN' in folder else 1 if 'MCI' in folder else 2
    npy_names = [name for name in os.listdir(folder) if name.endswith('.npy')]
    for name in npy_names:
        image_array.append(np.load(f"{folder}/{name}"))
        label_array.append(folder_label)

image_array = np.array(image_array)

In [None]:
print(image_array.shape)
train_counts = Counter(label_array)
print(train_counts.keys())    # distinct labels, in first-seen order
print(train_counts.values())  # number of samples per label

In [None]:

# Over-sampling with SMOTE (Synthetic Minority Over-sampling TEchnique).
# SMOTE synthesizes new samples for a minority class: it picks a random
# minority-class point, finds its k nearest minority-class neighbours, and
# adds synthetic points on the segments between the point and those
# neighbours.
# sampling_strategy='minority' resamples only the single smallest class; with
# three classes the middle class stays as it is ('not majority' would
# balance every non-majority class).
# sampling_strategy must be passed by keyword: recent imbalanced-learn
# releases reject positional constructor arguments.
smote = SMOTE(sampling_strategy='minority')

# SMOTE expects 2-D (n_samples, n_features) input, so flatten each image.
image_array_sm, label_array_sm = smote.fit_resample(
    image_array.reshape((image_array.shape[0], -1)), label_array)

In [None]:
# fit_resample returns flattened rows; restore each row to the per-image
# shape of the original image_array instead of hard-coding it.
image_array_smarr = np.asarray(image_array_sm).reshape((-1,) + image_array.shape[1:])
label_array_smarr = np.asarray(label_array_sm)
print(image_array_smarr.shape, label_array_smarr.shape)

In [None]:
# Serialize the SMOTE-augmented training set.
write_tfrecords(x=image_array_smarr, y=label_array_smarr, filename="./train_smote.tfrecords")

In [None]:
Train = read_dataset(epochs=10, batch_size=50, filename='./train_smote.tfrecords')

In [None]:
# Display the training dataset object in the notebook output.
Train

In [None]:
# label_array = []
# train_array = ["CN_TRAIN_Image2D", "MCI_TRAIN_Image2D", "AD_TRAIN_Image2D"]

# for folder in train_array:
#   for filename in os.listdir(folder):
#     if filename.endswith('.npy'):
#       label_array.append(0 if 'CN' in folder else 1 if 'MCI' in folder else 2)

In [None]:
image_val_array = []
label_val_array = []
val_array = ["CN_VAL_Image2D", "MCI_VAL_Image2D", "AD_VAL_Image2D"]

for folder in val_array:
    # the class is encoded in the folder name: CN -> 0, MCI -> 1, otherwise AD -> 2
    folder_label = 0 if 'CN' in folder else 1 if 'MCI' in folder else 2
    npy_names = [name for name in os.listdir(folder) if name.endswith('.npy')]
    for name in npy_names:
        image_val_array.append(np.load(f"{folder}/{name}"))
        label_val_array.append(folder_label)

image_val_array = np.array(image_val_array)

In [None]:
print(image_val_array.shape)
val_counts = Counter(label_val_array)
print(val_counts.keys())    # distinct labels, in first-seen order
print(val_counts.values())  # number of samples per label

In [None]:
write_tfrecords(x=image_val_array, y=label_val_array, filename="./val.tfrecords")

In [None]:
# NOTE(review): read_dataset repeats these records 10 times and drops any
# final partial batch, so one validation pass is ~10x the validation set.
# Confirm this is intended.
Validation = read_dataset(epochs=10, batch_size=50, filename='./val.tfrecords')

In [None]:
# Display the validation dataset object in the notebook output.
Validation

In [None]:
image_test_array = []
label_test_array = []
test_array = ["CN_TEST_Image2D", "MCI_TEST_Image2D", "AD_TEST_Image2D"]

for folder in test_array:
    # the class is encoded in the folder name: CN -> 0, MCI -> 1, otherwise AD -> 2
    folder_label = 0 if 'CN' in folder else 1 if 'MCI' in folder else 2
    npy_names = [name for name in os.listdir(folder) if name.endswith('.npy')]
    for name in npy_names:
        image_test_array.append(np.load(f"{folder}/{name}"))
        label_test_array.append(folder_label)

image_test_array = np.array(image_test_array)

In [None]:
print(image_test_array.shape)
test_counts = Counter(label_test_array)
print(test_counts.keys())    # distinct labels, in first-seen order
print(test_counts.values())  # number of samples per label

In [None]:
write_tfrecords(x=image_test_array, y=label_test_array, filename="./test.tfrecords")

In [None]:
# NOTE(review): read_dataset repeats and shuffles the records and drops any
# final partial batch, so evaluating on this does not score each test image
# exactly once. Confirm this is intended.
Test = read_dataset(epochs=10, batch_size=50, filename='./test.tfrecords')

In [None]:
# Display the test dataset object in the notebook output.
Test

In [None]:
# Frozen ImageNet InceptionV3 backbone, globally average-pooled.
base_model = tf.keras.applications.inception_v3.InceptionV3(
    include_top=False,
    weights='imagenet',
    input_shape=(344, 440, 3),
    pooling='avg',
)
base_model.trainable = False

# Trainable head: one 512-unit ReLU layer followed by a softmax over classes.
base_output = base_model.output
hidden_layer = tf.keras.layers.Dense(512, activation='relu')(base_output)
output_layer = tf.keras.layers.Dense(N_CLASSES, activation='softmax')(hidden_layer)
inception_model = tf.keras.models.Model(inputs=base_model.input, outputs=output_layer)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, decay=1e-6)
METRICS = [
    tf.keras.metrics.AUC(name='auc'),
    tf.keras.metrics.AUC(name='prc', curve='PR'),  # area under precision-recall curve
    tf.keras.metrics.CategoricalAccuracy(name='categorical accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
]

inception_model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=METRICS,
)

In [None]:
# Print the model's layers and parameter counts.
inception_model.summary()

In [None]:
# inception_model.fit(Train, epochs=2, validation_data=Validation, verbose=1)

In [None]:
# Stop after 5 epochs without val_loss improvement (val_loss is
# EarlyStopping's default monitor) and restore the best weights seen.
es = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
history = inception_model.fit(
    Train,
    validation_data=Validation,
    epochs=50,
    callbacks=[es],
    verbose=1,
)

In [None]:
inception_model.save(os.path.join(base_dir, 'inception_model2.h5'))

In [None]:
# Evaluate on the Test dataset; Keras reports the loss followed by the
# compiled METRICS.
inception_model.evaluate(Test)

In [None]:
def plot_metrics(history):
    ''' Plot training vs. validation curves for loss, AUC and PR-AUC side by side.

        Parameters:
            history -- History object returned by model.fit; its ``history``
                       dict must contain 'loss', 'auc', 'prc' and the
                       matching 'val_loss', 'val_auc', 'val_prc' keys
    '''
    # 'seaborn-deep' was renamed 'seaborn-v0_8-deep' in Matplotlib 3.6 and
    # the old name was later removed; use whichever this install provides.
    style = next(
        (name for name in ('seaborn-v0_8-deep', 'seaborn-deep') if name in plt.style.available),
        'default',
    )
    # (history key, panel title / legend label)
    panels = [('loss', 'Loss'), ('auc', 'AUC'), ('prc', 'PRC')]

    with plt.style.context(style):
        _, axes = plt.subplots(1, len(panels), figsize=(15, 4))
        x_axis = np.arange(len(history.history['loss']))
        for ax, (key, title) in zip(axes, panels):
            ax.set_title(title)
            ax.plot(x_axis, history.history[key], color="blue", linestyle=":",
                    marker="X", label=f"Train {title}")
            ax.plot(x_axis, history.history[f"val_{key}"], color="orange", linestyle="-",
                    marker="X", label=f"Val {title}")
            ax.grid(axis="x", linewidth=0.5)
            ax.grid(axis="y", linewidth=0.5)
            ax.legend()
        plt.show()

In [None]:
# Plot the loss / AUC / PRC curves recorded during training.
plot_metrics(history)

In [None]:
import os

path = "./"
# Collect every file below the working directory whose name contains '.h5'.
Hdf5_files = [
    os.path.join(root, name)
    for root, _dirs, names in os.walk(path)
    for name in names
    if '.h5' in name
]

In [None]:
# Show which .h5 model files were found.
Hdf5_files

In [None]:
from tensorflow.keras.models import load_model

# os.walk order is filesystem-dependent, so Hdf5_files[-1] is not guaranteed
# to be the model just saved. Load the most recently modified .h5 instead.
my_model = load_model(max(Hdf5_files, key=os.path.getmtime))

In [None]:
predictions = my_model.predict(x=image_test_array)

In [None]:
# Presumably (n_test_images, N_CLASSES) -- one softmax row per image, if the
# loaded model is the one defined above.
predictions.shape

In [None]:
# Predicted class = index of the highest probability in each row.
ypred = [row.argmax() for row in predictions]

In [None]:
from sklearn.metrics import confusion_matrix

# Rows are true labels, columns are predicted labels.
conf_matrix = confusion_matrix(y_true=label_test_array, y_pred=ypred)

In [None]:
# Raw confusion matrix (rows: true class, columns: predicted class).
conf_matrix

In [None]:
# Label the confusion matrix rows/columns with class names.
cm_df = pd.DataFrame(data=conf_matrix, index=['CN', 'MCI', 'AD'], columns=['CN', 'MCI', 'AD'])

In [None]:
# Confusion matrix with class-name row/column labels.
cm_df

In [None]:
plt.figure(figsize=(5, 4))
# fmt='d' prints the integer counts as-is; seaborn's default '.2g' would
# switch to scientific notation for counts of 100 or more.
sns.heatmap(cm_df, annot=True, fmt='d', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.ylabel('Actual Values')
plt.xlabel('Predicted Values')
plt.show()

In [None]:
# ROC curves from the true test labels and the predicted class probabilities.
skplt.metrics.plot_roc(label_test_array, predictions, title = 'ROC Plot');

In [None]:
# Precision-recall curves from the true test labels and predicted probabilities.
skplt.metrics.plot_precision_recall(label_test_array, predictions, title = 'PR Curve');

In [None]:
target_names = ['CN', 'MCI', 'AD']
# Per-class precision / recall / F1 on the test set.
print(classification_report(y_true=label_test_array, y_pred=ypred, target_names=target_names))

In [None]:
import os
from google.cloud import storage

# Point Google's client libraries at the credentials file.
# NOTE(review): CREDENTIALS, PROJECT_NAME and BUCKET_NAME are not defined in
# any visible cell; they must be set before running this one.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = CREDENTIALS

bucket = storage.Client(PROJECT_NAME).get_bucket(BUCKET_NAME)
# Upload the local model file as object 'inception_trial.h5'.
# NOTE(review): the model saved earlier in this notebook is
# inception_model2.h5, but this uploads ./inception_model1.h5. Confirm the
# intended file.
bucket.blob('inception_trial.h5').upload_from_filename('./inception_model1.h5')