In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive; the dataset
# paths used by the loading cells below live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import cv2
import numpy as np
import math
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import img_to_array
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input
from keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam, RMSprop, SGD

from PIL import Image

#Load data for U-net training
import os

#Color Channel Experimentation
import skimage
import skimage.color as scc

In [None]:
#SIZE = 1024

In [None]:
#Image_Dataset: MRI/Anatomical Images - Greyscale
# Reads every file in the anatomical folder (in sorted name order), converts it
# to single-channel greyscale and stacks the results into one array.
anatomical_dir = "/content/drive/MyDrive/VCU_Lab/TBI_Sorted/Anatomical/"
MRI_dataset = []
image_directory = sorted(os.listdir(anatomical_dir))
print(image_directory)

for file_name in image_directory:
  print(file_name)
  img = cv2.imread(os.path.join(anatomical_dir, file_name))
  if img is None:
    # cv2.imread signals failure by returning None instead of raising. Fail here
    # with the offending name rather than with an opaque cvtColor error. The file
    # is not skipped, because that would silently shift the MRI/E-field pairing,
    # which relies on matching sorted order.
    raise ValueError("Could not read image file: " + os.path.join(anatomical_dir, file_name))
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  MRI_dataset.append(img_to_array(img))

MRI_dataset = np.asarray(MRI_dataset)
print(MRI_dataset.shape)

In [None]:
#Label Data Pre-processing
#Image_Dataset: E-field Images - RGB
# Reads every file in the E-field folder (in sorted name order), converts it
# from OpenCV's BGR channel order to RGB and stacks the results into one array.
efield_dir = "/content/drive/MyDrive/VCU_Lab/TBI_Sorted/E_Field/"
mask_directory = sorted(os.listdir(efield_dir))
Efield_dataset_raw = []

for file_name in mask_directory:
  print(file_name)
  img = cv2.imread(os.path.join(efield_dir, file_name))
  if img is None:
    # cv2.imread signals failure by returning None instead of raising. Fail here
    # with the offending name rather than with an opaque cvtColor error. The file
    # is not skipped, because that would silently shift the MRI/E-field pairing,
    # which relies on matching sorted order.
    raise ValueError("Could not read image file: " + os.path.join(efield_dir, file_name))
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  Efield_dataset_raw.append(img_to_array(img))

Efield_dataset_raw = np.asarray(Efield_dataset_raw)
print(Efield_dataset_raw.shape)

In [None]:
#Data Normalization
# Both stacks are min-max scaled with their own global extremes, so the
# smallest value maps to 0 and the largest to 1.

# MRI Normalization remains the same for all three sets of anatomical images
min_image, max_image = MRI_dataset.min(), MRI_dataset.max()
MRI_dataset = (MRI_dataset - min_image) / (max_image - min_image)

# E Field Normalization for RGB E-Field Images
# The raw stack is left untouched; the scaled copy is what training uses.
min_mask, max_mask = Efield_dataset_raw.min(), Efield_dataset_raw.max()
Efield_dataset = (Efield_dataset_raw - min_mask) / (max_mask - min_mask)


In [None]:
# Dataset Split
# Hold out 10% of the (MRI, E-field) pairs for validation. The fixed
# random_state makes the shuffled split repeatable across runs.
split_parts = train_test_split(
    MRI_dataset,
    Efield_dataset,
    test_size=0.1,
    random_state=1,
    shuffle=True,
)
train_data, validation_data, train_mask, validation_mask = split_parts
for subset in (train_data, validation_data):
  print(subset.shape)

In [None]:
#Sanity check of image-mask set
# Show one randomly chosen training input next to its target to confirm the
# pairs line up after the split.
image_number = np.random.randint(0, len(train_data))
fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(20, 10))
ax_image.imshow(np.squeeze(train_data[image_number]), cmap='gray')
# NOTE(review): the targets are loaded as RGB, and matplotlib ignores cmap for
# RGB(A) input, so 'gray' presumably has no effect on this panel.
ax_mask.imshow(np.squeeze(train_mask[image_number]), cmap='gray')
plt.show()

In [None]:
# Training Parameters
# Epoch: number of epochs handed to Unet.fit(epochs=...); also used as the
# x-axis length in the loss plots further down.
Epoch = 200 
# batch_size: samples per step handed to Unet.fit(batch_size=...).
batch_size = 1 

# Shape of one network input: 1024 x 1024 pixels, one channel.
# NOTE(review): assumes every loaded MRI image is exactly 1024x1024 - confirm against the data.
input_shape = (1024, 1024, 1)
# Symbolic Keras input tensor that the U-Net graph is built on.
inputs = Input(shape = input_shape)

In [None]:
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda
from keras.layers import Activation, MaxPool2D, Concatenate

def conv_block(input, num_filters):
    """Basic convolutional unit shared by the U-Net encoder, bridge and decoder.

    Applies the same stage twice in sequence: a 3x3 "same"-padded Conv2D with
    `num_filters` filters, then BatchNormalization, then a ReLU activation.

    Args:
        input: tensor to convolve (the name shadows the builtin but is kept
            so existing callers are unaffected).
        num_filters: number of filters for both convolutions.

    Returns:
        The tensor produced by the second ReLU.
    """
    features = input
    for _ in range(2):
        features = Conv2D(num_filters, (3,3), padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

def encoder_block(input, num_filters):
    """One U-Net encoder level: a conv_block followed by 2x2 max pooling.

    Args:
        input: tensor entering this level.
        num_filters: filter count forwarded to conv_block.

    Returns:
        A pair (skip, pooled): the conv_block output before pooling, which the
        decoder concatenates later, and the same tensor after MaxPool2D((2, 2)).
    """
    skip = conv_block(input, num_filters)
    pooled = MaxPool2D((2, 2))(skip)
    return skip, pooled

def decoder_block(input, skip_features, num_filters):
    """One U-Net decoder level: upsample, merge with the skip tensor, convolve.

    Args:
        input: tensor coming from the level below.
        skip_features: encoder output of the matching level, concatenated onto
            the upsampled tensor.
        num_filters: filter count for the closing conv_block; the transposed
            convolution uses twice this many filters.

    Returns:
        Output tensor of the closing conv_block.
    """
    # 2x2 transposed convolution with stride 2.
    upsampled = Conv2DTranspose(2*num_filters, (2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

def build_unet_depthfive(inputs):
    """Assemble a five-level U-Net graph on top of `inputs`.

    Encoder levels use 8, 16, 32, 64 and 128 filters, the bridge uses 256, and
    the decoder mirrors the encoder (128 down to 8), each level consuming the
    skip tensor of the encoder level with the same filter count.

    Args:
        inputs: Keras input tensor.

    Returns:
        Output tensor of the final 3x3 "same"-padded Conv2D with 3 filters and
        sigmoid activation (one output channel per colour channel of the RGB
        targets this notebook trains on).
    """
    level_filters = (8, 16, 32, 64, 128)

    # Contracting path: remember each level's pre-pooling output for the skips.
    skips = []
    x = inputs
    for num_filters in level_filters:
        skip, x = encoder_block(x, num_filters)
        skips.append(skip)

    x = conv_block(x, 256) #Bridge

    # Expanding path: walk the levels deepest-first, pairing each with its skip.
    for num_filters, skip in zip(reversed(level_filters), reversed(skips)):
        x = decoder_block(x, skip, num_filters)

    return Conv2D(3, (3,3), padding="same", activation="sigmoid")(x)

# Wrap the graph in a Keras Model and train it as a regression: mean squared
# error against the normalised targets, Adam with its default settings.
# The builder defined in this notebook is build_unet_depthfive; the previous
# call to build_unet_depthfour raised NameError on a fresh kernel.
Unet = Model(inputs, build_unet_depthfive(inputs))
Unet.compile(loss='mse', optimizer = Adam())
# summary() prints the table itself and returns None, so it is not wrapped in print().
Unet.summary()

In [None]:
 #Train the model

# Fit on the training pairs and score the validation pairs after every epoch.
# The returned History object feeds the loss plots below.
Unet_model_history = Unet.fit(
    x=train_data,
    y=train_mask,
    batch_size=batch_size,
    epochs=Epoch,
    validation_data=(validation_data, validation_mask),
    verbose=1,
)

In [None]:
#Plot the training and validation loss at each epoch
history_dict = Unet_model_history.history
print(history_dict)
loss, val_loss = history_dict['loss'], history_dict['val_loss']
epochs = range(Epoch)

# Both curves share one axes so divergence between them is easy to spot.
fig, ax = plt.subplots(figsize=(40, 20))
ax.plot(epochs, loss, 'g', label='Training loss')
ax.plot(epochs, val_loss, 'r', label='Validation loss')
ax.set_title('Training and Validation loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
#Plot only the training loss at each epoch
history_dict = Unet_model_history.history
print(history_dict)
loss = history_dict['loss']
epochs = range(Epoch)

fig, ax = plt.subplots(figsize=(40, 20))
ax.plot(epochs, loss, 'g', label='Training loss')
ax.set_title('Training loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
#Plot only the validation loss at each epoch
history_dict = Unet_model_history.history
print(history_dict)
val_loss = history_dict['val_loss']
epochs = range(Epoch)

fig, ax = plt.subplots(figsize=(40, 20))
ax.plot(epochs, val_loss, 'r', label='Validation loss')
ax.set_title('Validation loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
#Evaluate the trained model
# NOTE(review): this scores the same validation split that was passed to
# fit(validation_data=...); no separate held-out test set is visible in this
# notebook, so the number is not an independent test score.
# The model is compiled with loss='mse' and no extra metrics, so the returned
# value is the mean squared error.

Unet_model_test = Unet.evaluate(validation_data, validation_mask, 
                                batch_size=1)

In [None]:
#Unet.save('Unet_12_24_22.h5')

In [None]:
# Run inference on both splits for the metrics and galleries below.
# Model.predict defaults to batches of 32, whereas training and evaluation in
# this notebook use batch_size (1) for the 1024x1024 inputs. Passing it here
# keeps peak GPU memory in line with training. Batch size only affects how the
# inputs are chunked during inference.
predict_train = Unet.predict(train_data, batch_size=batch_size)
predict_val = Unet.predict(validation_data, batch_size=batch_size)

In [None]:
# Reconstruction quality per split: mean squared error over all pixels and
# channels, and PSNR computed as 10*log10(1/MSE), i.e. with a peak signal
# value of 1 (the targets were min-max scaled to [0, 1]).

# Training split
mse_train = ((train_mask - predict_train) ** 2).mean()
psnr_train = 10 * math.log10(1 / mse_train)
print('MSE-Training:', mse_train)
print(f'Training Data PSNR: {np.round(psnr_train, 2)}dB')

# Validation split
mse_val = ((validation_mask - predict_val) ** 2).mean()
psnr_val = 10 * math.log10(1 / mse_val)
print('MSE-Validation:', mse_val)
print(f'Validation Data PSNR: {np.round(psnr_val, 2)}dB')

In [None]:
def _show_image_row(images, title, start, count):
  """Print `title` and show images[start:start+count] side by side in one figure.

  The row always has `count` columns so ground-truth and prediction rows line
  up. When fewer than `count` images are available from `start`, only the
  existing ones are drawn (the remaining columns stay empty) instead of raising
  IndexError as the previous hard-coded loops did.
  """
  stop = min(start + count, len(images))
  plt.figure(figsize=(20, 10))
  print(title)
  for column, index in enumerate(range(start, stop), start=1):
    plt.subplot(1, count, column)
    plt.imshow(images[index])
  plt.show()


# Training split: three consecutive groups of five samples (indices 0-4, 5-9,
# 10-14), each shown as a ground-truth row followed by the prediction row.
for first_index in (0, 5, 10):
  _show_image_row(train_mask, "Ground Truth:", first_index, 5)
  _show_image_row(predict_train, "Prediction:", first_index, 5)

# Validation split: the first seven samples, ground truth then prediction.
_show_image_row(validation_mask, "Ground Truth:", 0, 7)
_show_image_row(predict_val, "Prediction:", 0, 7)


In [None]:
import plotly.graph_objects as go

# Interactive version of the loss curves.
# Read the curves straight from the fit history so this cell does not depend on
# `loss` / `val_loss` left behind by the earlier plotting cells.
loss = Unet_model_history.history['loss']
val_loss = Unet_model_history.history['val_loss']

# 1-based epoch axis sized from the recorded history instead of a hard-coded
# 200, so it stays correct when the number of epochs is changed.
Epochs = list(range(1, len(loss) + 1))

fig = go.Figure()
fig.add_trace(go.Scatter(x=Epochs, y=loss,
                    mode='lines+markers',
                    name='Training Loss'))
fig.add_trace(go.Scatter(x=Epochs, y=val_loss,
                    mode='lines+markers',
                    name='Validation Loss'))
fig.update_xaxes(title_text="Epochs")
fig.update_yaxes(title_text="Loss (MSE)")
fig.show()