In [None]:
%pip install tensorflow
%pip install keras
%pip install pandas
%pip install -U scikit-learn
%pip install torch
%pip install einops
%pip install torchvision
%pip install torch
%pip install icecream
%pip install tqdm
%pip install nibabel
%pip install numpy
%pip install NiLearn
%pip install matplotlib
%pip install scikit-image
%pip install jupyter_capture_output

In [1]:
import glob
import os
import sys
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, LeakyReLU, Dense, Flatten, Reshape, UpSampling2D,BatchNormalization, Dropout, Conv3D, Conv3DTranspose, Conv2D, Conv2DTranspose, ConvLSTM2D
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras import Sequential, callbacks
import keras
from keras import layers
from keras.utils import pad_sequences
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from torch.utils.data import DataLoader
from torchvision import transforms
import icecream as ic
import tqdm
import nibabel as nib
from nibabel import processing
from nilearn import plotting
import matplotlib.pyplot as plt
import skimage.transform as skTrans
import pickle

In [2]:
# All the model definitions

# Log file (append mode) that the training functions write their progress lines to.
# It is opened once here and stays open; the training functions flush after each write.
training_output = open("training_output.txt", "a")

# Number of images per input sequence (first axis of the generator input)
step_size = 3
# Number of sequences per training batch; also the length of the `fake`/`real` label arrays
batch_size=1
# Height and width of every 2D slice; also reused as the ConvLSTM filter count
# and as the width of the first discriminator Dense layer
dim_size = 128
# Metric histories appended to by the training functions and plotted afterwards
d_losses=[]
g_losses=[]
d_accuracies=[]
g_accuracies=[]
iteration_checkpoints=[]

# Fake Image Label: All 1
fake = np.ones((batch_size, 1))

# Real Image Labels: All 0
real = np.zeros((batch_size, 1))

# 2D slices from 3D image
# Positions are given on a 256-plane scale; get_slice() rescales them by image.shape[0] / 256.
slices = [96,128]

def get_slice(image, slice_index):
	"""Return one 2D plane of `image`, taken along its first axis.

	`slices[slice_index]` is a position on a 256-plane scale; it is rescaled
	by `image.shape[0] / 256` and truncated to an int before indexing.
	"""
	scale = image.shape[0] / 256
	plane = int(scale * slices[slice_index])
	return image[plane]

def get_slice_list(imagelist, slice_index):
	"""Extract the same 2D plane from every volume in `imagelist`.

	Args:
		imagelist: iterable of arrays; each is indexed along its first axis.
		slice_index: index into the module-level `slices` list.

	Returns:
		A new numpy array stacking one plane per input image, in input order.
	"""
	# Delegate to get_slice so the index arithmetic lives in one place, and
	# avoid naming a local `list`, which shadowed the builtin.
	return np.array([get_slice(image, slice_index) for image in imagelist])

def create_generator():
	"""Build, compile and summarise the ConvLSTM generator.

	The model takes a (step_size, dim_size, dim_size) sequence of slices and
	outputs a single (dim_size, dim_size) slice through a sigmoid. It is
	compiled with Adam(1e-4), MSE loss and the 'accuracy' metric.
	"""
	generator = Sequential(name='generator')
	generator.add(Input(shape=(step_size, dim_size, dim_size)))
	# Add a trailing channel axis so ConvLSTM2D sees (time, rows, cols, channels)
	generator.add(Reshape((step_size, dim_size, dim_size, 1)))

	# Four sequence-returning ConvLSTM stages with shrinking kernels, each
	# followed by batch normalisation
	for kernel in (7, 5, 3, 1):
		generator.add(ConvLSTM2D(dim_size, (kernel, kernel), strides=(1, 1), padding='same', return_sequences=True, activation='relu'))
		generator.add(BatchNormalization())

	# Final stage returns only the last step, with a single sigmoid channel
	generator.add(ConvLSTM2D(1, (1, 1), strides=(1, 1), padding='same', activation='sigmoid'))

	# Drop the channel axis again
	generator.add(Reshape((dim_size, dim_size)))

	generator.compile(optimizer=tf.optimizers.Adam(1e-4), loss='mse', metrics=['accuracy'])
	generator.summary()
	return generator

def create_discriminator():
	"""Build, compile and summarise the dense discriminator.

	Input is a (dim_size, dim_size) slice; output is a single sigmoid unit.
	The model is compiled with Adam(1e-4) and binary cross-entropy, and its
	`trainable` flag is set to False after compilation.
	"""
	image_in = Input(shape=(dim_size, dim_size))

	# Two Dense + LeakyReLU stages applied before flattening
	hidden = image_in
	for units in (dim_size, 64):
		hidden = Dense(units)(hidden)
		hidden = LeakyReLU(alpha=0.05)(hidden)

	hidden = Flatten()(hidden)
	hidden = Dense(32)(hidden)
	hidden = LeakyReLU(alpha=0.05)(hidden)
	verdict = Dense(1, activation='sigmoid')(hidden)

	discriminator = Model(image_in, verdict)
	#discriminator_optimizer = RMSprop(learning_rate=8e-4, clipvalue=1.0, decay=1e-8)
	discriminator.compile(optimizer=tf.optimizers.Adam(1e-4), loss='binary_crossentropy', metrics=['accuracy'])
	# Flag is flipped only after compile()
	discriminator.trainable = False
	discriminator.summary()
	return discriminator

def create_gan(generator, discriminator):
	"""Chain `generator` into `discriminator` and compile the combined model.

	The combined model maps a (step_size, dim_size, dim_size) input sequence
	to the discriminator's output and is compiled with Adam(1e-4), binary
	cross-entropy and the 'accuracy' metric.
	"""
	sequence_in = Input(shape=(step_size, dim_size, dim_size))
	generated = generator(sequence_in)
	verdict = discriminator(generated)

	gan = Model(sequence_in, verdict)
	#gan_optimizer = RMSprop(learning_rate=4e-4, clipvalue=1.0, decay=1e-8)
	gan.compile(optimizer=tf.optimizers.Adam(1e-4), loss='binary_crossentropy', metrics=['accuracy'])
	gan.summary()
	return gan

def train_lstm(generator, iteration, X, Y, index, filename):
	"""Run one generator training step on a single image sequence.

	Args:
		generator: compiled model returning (loss, accuracy) from train_on_batch.
		iteration: running counter; recorded in `iteration_checkpoints` and
			used to decide when to plot and log (every 20th call).
		X: sequence of input volumes; one plane is taken from each.
		Y: target volume; one plane is taken from it.
		index: index into the module-level `slices` list.
		filename: where the predicted image is saved on logging iterations.

	Side effects: appends to `g_losses`, `g_accuracies` and
	`iteration_checkpoints`; writes to `training_output` every 20th iteration.
	"""
	# Nothing below writes to X or Y, and get_slice_list already builds a new
	# array, so the volumes are sliced directly instead of being deep-copied.
	temp_X = get_slice_list(X, index).reshape(batch_size, step_size, dim_size, dim_size)
	temp_Y = get_slice(Y, index).reshape(batch_size, dim_size, dim_size)

	g_loss, g_accuracy = generator.train_on_batch(temp_X, temp_Y)

	# Save loss and accuracy to plot graphs after training
	g_losses.append(g_loss)
	g_accuracies.append(100.0 * g_accuracy)
	iteration_checkpoints.append(iteration)

	if iteration % 20 == 0:
		plot_image(generator.predict(temp_X)[0], filename)
		training_output.write("Epoch: %d [Generator loss: %f, accuracy: %.2f%%]\n" % (iteration, g_loss, 100 * g_accuracy))
		training_output.flush()


def train_gan(generator, discriminator, gan, iteration, X, Y, index, filename):
	"""Train generator, discriminator and combined GAN on one image sequence.

	Steps, all on the same sample:
	  1. fit the generator alone for 10000 batches, plotting and logging every 10th;
	  2. train the discriminator 10 times on [generated, target] labelled [fake, real];
	  3. train the combined model 10 times against the `real` label;
	  4. record metrics, save and show the final prediction, write a summary line.

	Args:
		generator, discriminator, gan: compiled models returning
			(loss, accuracy) from train_on_batch.
		iteration: counter stored in `iteration_checkpoints` and printed in logs.
		X: sequence of input volumes; Y: target volume.
		index: index into the module-level `slices` list.
		filename: path the prediction image is written to.
	"""
	# Nothing below writes to X or Y, so slice them directly rather than
	# deep-copying whole volumes first.
	temp_X = get_slice_list(X, index).reshape(batch_size, step_size, dim_size, dim_size)
	temp_Y = get_slice(Y, index).reshape(batch_size, dim_size, dim_size)

	for i in range(10000):
		g_loss, g_accuracy = generator.train_on_batch(temp_X, temp_Y)
		if i % 10 == 0:
			plot_image(generator.predict(temp_X)[0], filename)
			training_output.write("Epoch: %d [Generator loss: %f, accuracy: %.2f%%]\n" % (iteration, g_loss, 100 * g_accuracy))
			training_output.flush()

	predictions = generator.predict(temp_X)

	# Generated sample first, real target second, matching the label order
	d_inputs = np.concatenate([predictions, temp_Y], 0)
	d_labels = np.concatenate([fake, real], 0)
	for i in range(10):
		d_loss, d_accuracy = discriminator.train_on_batch(d_inputs, d_labels)

	for i in range(10):
		gan_loss, gan_accuracy = gan.train_on_batch(temp_X, real)

	# Save loss and accuracy to plot graphs after training
	d_losses.append(d_loss)
	g_losses.append(g_loss)
	d_accuracies.append(100.0 * d_accuracy)
	g_accuracies.append(100.0 * g_accuracy)

	iteration_checkpoints.append(iteration)

	# Save BEFORE showing: once show() has run the figure may already be
	# closed, in which case a later savefig() writes an empty image.
	fig, ax = plt.subplots()
	ax.imshow(generator.predict(temp_X)[0], cmap='bone')
	ax.axis('off')
	fig.savefig(filename)
	plt.show()
	plt.close(fig)

	training_output.write("Epoch: %d [Disciminator loss: %f, accuracy: %.2f%%], [Generator loss: %f, accuracy: %.2f%%] [Gan loss: %f, accuracy: %.2f%%]\n" % (iteration, d_loss, 100.0 * d_accuracy, g_loss, 100 * g_accuracy,  gan_loss, 100 * gan_accuracy))
	training_output.flush()

def plot_learning_process(label, values, filename):
    """Plot a metric history against a 1-based step index, save it, then show it.

    Args:
        label: name of the metric; used as the y-axis label and legend entry.
        values: sequence of metric values, one per recorded step.
        filename: path the figure is saved to.
    """
    fig = plt.figure(figsize=(24, 10))
    plt.title('Visualization of the learning process', fontsize=16)
    # Label the line so the legend below has an entry to show
    plt.plot(np.arange(1, len(values) + 1), values, label=label)
    plt.xlabel('Epochs', fontsize=16)
    plt.ylabel(label, fontsize=16)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    # With 0 or 1 points the explicit limits would be inverted or identical;
    # leave the axis on autoscale in that case.
    if len(values) > 1:
        plt.xlim(1, len(values))
    plt.grid()
    plt.legend(fontsize=16)
    plt.savefig(filename)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

def plot_image(image, filename):
	"""Render one 2D image with the 'bone' colormap, save it, then show it.

	Args:
		image: 2D array passed to imshow.
		filename: path the figure is saved to.
	"""
	fig, ax = plt.subplots()
	fig.set_figheight(2.15)
	fig.set_figwidth(2.15)

	ax.imshow(image, cmap='bone')
	ax.axis('off')
	fig.savefig(filename)
	plt.show()
	# This runs inside the training loop; close the figure explicitly so open
	# figures do not pile up when the backend does not close them on show().
	plt.close(fig)

def plot_images(image1, image2, image3, image4, index, filename):
	"""Show the same plane of four volumes side by side and save the figure.

	Args:
		image1..image4: volumes; the plane is picked with get_slice(volume, index).
		index: index into the module-level `slices` list.
		filename: path the figure is saved to.
	"""
	fig, axs = plt.subplots(1, 4)

	# One panel per volume, left to right in argument order
	for ax, volume in zip(axs, (image1, image2, image3, image4)):
		ax.imshow(get_slice(volume, index), cmap='bone')
		ax.axis('off')

	fig.set_figheight(2.5)
	fig.set_figwidth(10)

	fig.savefig(filename)
	plt.show()
	# Close explicitly so repeated calls during training do not leak figures
	plt.close(fig)

def save_object(obj, filename):
    """Pickle `obj` into `filename` using the highest available protocol.

    The file is opened in 'wb' mode, so any existing content is replaced.
    """
    with open(filename, 'wb') as sink:
        pickle.Pickler(sink, pickle.HIGHEST_PROTOCOL).dump(obj)

def load_object(filename):
	"""Read `filename` in binary mode and return the object unpickled from it.

	Note: unpickling executes whatever the file encodes, so only use this on
	files this notebook wrote itself.
	"""
	with open(filename, 'rb') as source:
		return pickle.load(source)

In [None]:
# Reading all the images from AD directory and training the AD model

data_dir = './dataset'
ad_generator_checkpoint = './checkpoints/ad_generator_checkpoint.weights.h5'
ad_generator_checkpoint_backup = './checkpoints/ad_generator_checkpoint_%d.weights.h5'
state_checkpoint = './checkpoints/state_checkpoint'
cell_output_dir = './cell_output'

# Index into `slices` of the 2D plane used for plotting and training
slice_index = 0

train_dir = os.path.join(data_dir, 'train')

# Make sure the output directories exist before anything is written to them
os.makedirs(os.path.dirname(ad_generator_checkpoint), exist_ok=True)
os.makedirs(cell_output_dir, exist_ok=True)

AD_generator = create_generator()
AD_discriminator = create_discriminator()
AD_gan = create_gan(AD_generator, AD_discriminator)
AD_image_count = 0
loop_index = 0

# Reload the model and intermediate state if checkpoint files exist
# This is done to recover the progress from process crashes
if os.path.exists(ad_generator_checkpoint):
    AD_generator.load_weights(ad_generator_checkpoint)

if os.path.exists(state_checkpoint):
    # Same order as the list passed to save_object() at the end of each pass
    objs = load_object(state_checkpoint)
    g_losses = objs[0]
    g_accuracies = objs[1]
    iteration_checkpoints = objs[2]
    AD_image_count = objs[3]
    loop_index = objs[4]

# Each pass walks <train_dir>/<stage>/<patient>/<scantype>/<date>/*/*.nii
# NOTE(review): every volume of every stage is reloaded and resized on each
# pass, although only "/AD/" sequences are trained on — consider caching.
while loop_index < 2000:
    loop_index = loop_index + 1
    for stage_dir in os.scandir(train_dir):
        if not stage_dir.is_dir():
            continue
        for patient_dir in os.scandir(stage_dir.path):
            if not patient_dir.is_dir():
                continue

            # Collect every <scantype>/<date> entry once, remembering which
            # scan type it belongs to, so each date directory is globbed
            # exactly once even if two scan types share a date name.
            dated_dirs = []
            for scantype in os.listdir(patient_dir.path):
                scantype_dir = os.path.join(patient_dir.path, scantype)
                if not os.path.isdir(scantype_dir):
                    continue
                for date in os.listdir(scantype_dir):
                    dated_dirs.append((date, os.path.join(scantype_dir, date)))

            dated_dirs.sort()

            #find all nii_files in sorted order by date
            nii_files = []
            for _, date_dir in dated_dirs:
                matches = glob.glob(os.path.join(date_dir, '**', '*.nii'))
                if matches:
                    nii_files.append(matches[0])

            if len(nii_files) != 4:
                print ("nii file missing! patient: " + patient_dir.path)
                continue

            normalized_resized_imglist = []

            #load the image data
            for nii_file in nii_files:
                nii_img = nib.load(nii_file).get_fdata()
                if nii_img.shape[0] != 256:
                    print ("skipping the image")
                    break

                # Downsample to a dim_size cube, then rescale intensities to [0, 1]
                resized_img = skTrans.resize(nii_img, (dim_size,dim_size,dim_size), order=1, preserve_range=True)
                normalized_resized_img = MinMaxScaler().fit_transform(np.reshape(resized_img, (-1,1))).reshape(dim_size,dim_size,dim_size)

                # Only the normalized volume is kept; the full-resolution and
                # resized copies were never read again.
                normalized_resized_imglist.append(normalized_resized_img)

            # A skipped volume leaves the sequence incomplete
            if len(normalized_resized_imglist) != 4:
                continue

            if("/CN/" in nii_files[0]):
                #CNfirst_transition.append(nii_files_tuple)
                print ("CN file loaded")
            elif("/MCI/" in nii_files[0]):
                #MCIfirst_transition.append(nii_files_tuple)
                print ("MCI file loaded")
            elif("/AD/" in nii_files[0]):
                # First three scans are the input sequence, the fourth is the target
                input_volumes = normalized_resized_imglist[:3]
                target_volume = normalized_resized_imglist[3]

                if AD_image_count % 20 == 0:
                    # Plotting the original sequence of images
                    filename=cell_output_dir + "/image_squence_" + str(loop_index) + "_" + str(AD_image_count) + ".png"
                    plot_images(input_volumes[0], input_volumes[1], input_volumes[2], target_volume, slice_index, filename)

                # Training the model
                filename=cell_output_dir + "/predicted_image_" + str(loop_index) + "_" + str(AD_image_count) + ".png"
                train_lstm(AD_generator, AD_image_count, input_volumes, target_volume, slice_index, filename)

                AD_image_count += 1

                #train_gan(AD_generator, AD_discriminator, AD_gan, AD_image_count, normalized_resized_imglist[:3], normalized_resized_imglist[3:][0], slice_index, filename)

    # Saving the state
    AD_generator.save_weights(ad_generator_checkpoint_backup % loop_index)
    AD_generator.save_weights(ad_generator_checkpoint)
    save_object([g_losses, g_accuracies, iteration_checkpoints, AD_image_count, loop_index], state_checkpoint)

#plot_learning_process("Discriminator Loss", d_losses, cell_output_dir + "/plot_learning_process_discriminator_loss.png")
#plot_learning_process("Discriminator Accuracy", d_accuracies, cell_output_dir + "/plot_learning_process_discriminator_accuracy.png")
#plot_learning_process("Gan Loss", g_losses, cell_output_dir + "/plot_learning_process_gan_loss.png")
#plot_learning_process("Gan Accuracy", g_accuracies, cell_output_dir + "/plot_learning_process_gan_accuracy.png")
plot_learning_process("Generator Loss", g_losses, cell_output_dir + "/plot_learning_process_generator_loss.png")
plot_learning_process("Generator Accuracy", g_accuracies, cell_output_dir + "/plot_learning_process_generator_accuracy.png")
print("Total number of image sequences processed: %d" % AD_image_count)
print("done processing train scans")