#Master Thesis
#Sinth1 _ Trial 1


# 1. Dataset selection, loading and preprocessing

In [None]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

In [None]:
# %cd 'drive/My Drive'
# %cd 'Thesis/Code'

In [None]:
!pip install torchsummary

In [None]:
!pip3 install https://nic.udg.edu/niclib/wheels/niclib-1.0b0-py3-none-any.whl

In [None]:
import os
import numpy as np
import nibabel as nib
import torch
from torch.nn import functional as F
import niclib as nl
import time
import torch.nn as nn
from torchsummary import summary
from matplotlib import pyplot as plt



In [None]:
day = nl.get_timestamp(time_format='%d_%m_%Y')

# Toggle between the small debugging dataset and the full WMH 2017 3D dataset.
whm_small = False
if whm_small:
    challenge = 'Wmh_small'
    data_path = '/home/pierpaolo/Datasets/WHM_small'
    n_test_cases = 3  # Set aside 3 images for testing
else:
    challenge = 'Wmh_big_trial2'
    data_path = '/home/pierpaolo/Datasets/WMH20173D'
    n_test_cases = 5  # Set aside 5 images for testing

# Output folders for this run, grouped under "<date>_<challenge>".
mainfolder = nl.make_dir(day + '_' + challenge)
checkpoints_path = nl.make_dir(os.path.join(mainfolder, 'checkpoints/'))
results_path = nl.make_dir(os.path.join(mainfolder, 'results/'))
metrics_path = nl.make_dir(os.path.join(mainfolder, 'metrics/'))
log_path = nl.make_dir(os.path.join(mainfolder, 'log/'))

# One sub-folder per case. os.scandir yields entries in arbitrary,
# filesystem-dependent order, so sort them to make the train/test split
# reproducible across machines and runs. The `with` closes the directory handle.
with os.scandir(data_path) as entries:
    case_paths = sorted(entry.path for entry in entries if entry.is_dir())
case_paths, test_case_paths = case_paths[:-n_test_cases], case_paths[-n_test_cases:]

print("Loading training dataset with {} images...".format(len(case_paths)))

# Per-dataset file names. The small dataset also provides a WMH volume that is
# stacked with T1 as a second input channel; the big dataset uses T1 only
# (set wmh_filename to '3DWMH.nii.gz' to add that channel there as well).
if whm_small:
    t1_filename, wmh_filename, flair_filename = 'T1_brain.nii.gz', 'wmh.nii.gz', 'FLAIR_brain.nii.gz'
else:
    t1_filename, wmh_filename, flair_filename = '3DT1_brain.nii.gz', None, '3DFLAIR_brain.nii.gz'


def _load_normalized(nifti_path):
    """Load a NIfTI volume, clip it to the [0, 99.99] percentile range and rescale it to [0, 255].

    Returns:
        (nifti, img): the nibabel image (kept as header/affine reference) and
        the preprocessed voxel array.
    """
    nifti = nib.load(nifti_path)
    # Same array as the deprecated get_data(), which was removed in nibabel 5.0.
    img = np.asanyarray(nifti.dataobj)
    img = nl.data.clip_percentile(img, [0.0, 99.99])  # Clip to ignore bright extrema
    img = nl.data.adjust_range(img, [0.0, 255.0])  # Rescale intensities to [0, 255]
    return nifti, img


def load_case(case_path):
    """Load one case folder into the dict consumed by the patch generators.

    Keys:
        'id':     case folder name.
        'nifiti': T1 nibabel image, used as reference when saving predictions.
        't1':     network input, channel-first (T1, plus WMH on the small dataset).
        'flair':  synthesis target with a leading single-channel axis.
    """
    t1_nifti, t1_img = _load_normalized(os.path.join(case_path, t1_filename))
    channels = [t1_img]
    if wmh_filename is not None:
        channels.append(_load_normalized(os.path.join(case_path, wmh_filename))[1])
    # Stack modalities on a new leading channel axis (a single modality just
    # gains a size-1 channel axis). Uses the *normalized* T1, not the raw data.
    ob = np.stack(channels, axis=0)

    _, gt_img = _load_normalized(os.path.join(case_path, flair_filename))
    gt_img = np.expand_dims(gt_img, axis=0)  # Add single channel axis
    return {'id': os.path.basename(case_path), 'nifiti': t1_nifti, 't1': ob, 'flair': gt_img}


dataset = nl.parallel_load(load_func=load_case, arguments=case_paths, num_workers=12)
dataset_train, dataset_val = nl.split_list(dataset, fraction=0.8)  # Split images into train and validation
print('Training dataset with {} train and {} val images'.format(len(dataset_train), len(dataset_val)))
# print('Dim of input:', len(dataset_train))


### 2. Create training and validation patch generators
#-----Not Needed---------
# train_sampling = nl.generator.BalancedSampling(
#     labels=[np.argmax(case['flair'], axis=0) for case in dataset_train],
#     num_patches=1000 * len(dataset_train))
#------------------------------
step = 1  # Sampling stride (voxels) along each axis
samples = 500  # Patches sampled per image
patch_shape = (32, 32, 32)
batch_size = 32


def _make_patch_set(cases, key):
    """Build a niclib PatchSet of `patch_shape` patches from `case[key]` of every case.

    Patches are sampled uniformly with stride `step`, `samples` per image.
    NOTE: input and target sets are built with identical sampling settings so
    that the i-th input patch lines up with the i-th target patch once zipped.
    """
    return nl.generator.PatchSet(
        images=[case[key] for case in cases],
        patch_shape=patch_shape,
        normalize='none',
        sampling=nl.generator.UniformSampling((step, step, step), num_patches=samples * len(cases)))


# Each item is an (input patch, target FLAIR patch) pair.
train_patch_set = nl.generator.ZipSet([
    _make_patch_set(dataset_train, 't1'),
    _make_patch_set(dataset_train, 'flair')])
val_patch_set = nl.generator.ZipSet([
    _make_patch_set(dataset_val, 't1'),
    _make_patch_set(dataset_val, 'flair')])

# The trainer iterates over batches, not patch sets: wrap the sets in loaders.
# These define the `train_gen` / `val_gen` names that `trainer.train` uses.
from torch.utils.data import DataLoader

train_gen = DataLoader(train_patch_set, batch_size=batch_size, shuffle=True)
val_gen = DataLoader(val_patch_set, batch_size=batch_size, shuffle=False)
        
        
class Unet(nn.Module):
    """Three-level 3D U-Net with additive (not concatenated) skip connections.

    Downsampling uses learnable 2x2x2 stride-2 convolutions rather than
    max-pooling; the ``pool*`` attribute names are kept for checkpoint
    compatibility. Upsampling uses stride-2 transposed convolutions. Each
    decoder stage adds the same-resolution encoder feature map, so every
    spatial dimension of the input must be divisible by 8.

    Args:
        input_size: number of input channels.
        output_size: number of output channels of the final 1x1x1
            convolution (no activation is applied to it).
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        def conv3(c_in, c_out):
            # 3x3x3 convolution that preserves spatial size.
            return nn.Conv3d(c_in, c_out, kernel_size=3, padding=1)

        def down(channels):
            # Learnable downsampling: halves every spatial dimension.
            return nn.Conv3d(channels, channels, kernel_size=2, stride=2)

        def up(c_in, c_out):
            # Learnable upsampling: doubles every spatial dimension.
            return nn.ConvTranspose3d(c_in, c_out, kernel_size=2, stride=2)

        # Attribute creation order is kept so that seeded initialisation and
        # state_dict keys/order are unchanged.
        # Encoder
        self.conv1 = conv3(input_size, 32)
        self.pool1 = down(32)
        self.conv2 = conv3(32, 64)
        self.pool2 = down(64)
        self.conv3 = conv3(64, 128)
        self.pool3 = down(128)
        # Bottleneck
        self.conv4 = conv3(128, 256)
        # Decoder: each upsampled map is summed with the matching encoder map.
        self.up1 = up(256, 128)
        self.conv5 = conv3(128, 128)
        self.up2 = up(128, 64)
        self.conv6 = conv3(64, 64)
        self.up3 = up(64, 32)
        self.conv7 = conv3(32, 32)
        # Per-voxel output projection.
        self.conv8 = nn.Conv3d(32, output_size, kernel_size=1)

    def forward(self, x):
        encoder = ((self.conv1, self.pool1), (self.conv2, self.pool2), (self.conv3, self.pool3))
        decoder = ((self.up1, self.conv5), (self.up2, self.conv6), (self.up3, self.conv7))

        skips = []
        h = x
        for conv, downsample in encoder:
            h = F.relu(conv(h))
            skips.append(h)  # saved before downsampling, for the skip connection
            h = downsample(h)

        h = F.relu(self.conv4(h))

        for upsample, conv in decoder:
            # Deepest skip first: pop() returns the most recently stored map.
            h = F.relu(conv(upsample(h) + skips.pop()))

        return self.conv8(h)
            
# Input channels produced by load_case: 1 (T1 only) on the big dataset,
# 2 (T1 + WMH) on the small one. Derived from the data so the model and the
# patch tester agree with whichever dataset is selected.
in_channels = dataset_train[0]['t1'].shape[0]

# NOTE(review): ResUNet is not defined in this notebook (only `Unet` above is);
# it must be defined/imported in an earlier cell, or replace this with
# `Unet(input_size=in_channels, output_size=1)`.
model = ResUNet(outputsize=1, inputsize=in_channels, k=16)

name = 'ResUNet50_2'
name2 = '_full_wmh_net.pt'
fullname = name + name2
checkpoint_file = os.path.join(checkpoints_path, fullname)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

trainer = nl.net.train.Trainer(
    max_epochs=50,
    loss_func=nn.L1Loss(),
    optimizer=torch.optim.Adadelta,
    optimizer_opts={'lr': 1.0},
    train_metrics={'l1': nn.L1Loss()},
    val_metrics={'l1': nn.L1Loss()},
    plugins=[
        nl.net.train.ProgressBar(),
        # Keep only the checkpoint with the lowest validation loss.
        nl.net.train.ModelCheckpoint(checkpoint_file, save='best', metric_name='loss', mode='min'),
        nl.net.train.EarlyStopping(metric_name='loss', mode='min', patience=10),
        nl.net.train.Logger(os.path.join(log_path, 'train_log.csv'))],
    device=device)

trainer.train(model, train_gen, val_gen)
# The loaded object is passed directly to predictor.predict, so it is expected
# to be a full pickled model. NOTE(review): torch>=2.6 defaults to
# weights_only=True, which rejects full models; pass weights_only=False there.
unet_trained = torch.load(checkpoint_file)

# Sliding-window inference: (channels, 32, 32, 32) input patches with 50 %
# overlap, reassembled into a single-channel output volume.
predictor = nl.net.test.PatchTester(
    patch_shape=(in_channels, 32, 32, 32),
    patch_out_shape=(1, 32, 32, 32),
    extraction_step=(16, 16, 16),
    normalize='none',
    activation=None)
    
dataset_test = nl.parallel_load(load_func=load_case, arguments=test_case_paths, num_workers=12)
synthetic_images = []
# Preferred slice indices (last axis) for the visual check; clamped per volume.
preview_slices = (48, 39)
for n, case in enumerate(dataset_test):
    # Ground-truth FLAIR without its channel axis, for visual comparison.
    target = np.squeeze(case['flair'])

    # Predict image with the predictor and store
    print("Synthetising image {}".format(n))
    synth_img = predictor.predict(unet_trained, case['t1'])
    synth = np.squeeze(synth_img)

    # 2x2 grid: each row shows target vs. synthetic for one slice. Indices are
    # clamped so shallower volumes do not raise IndexError.
    depth = target.shape[-1]
    plt.figure(figsize=(7, 7))
    for row, wanted in enumerate(preview_slices):
        z = min(wanted, depth - 1)
        plt.subplot(2, 2, 2 * row + 1)
        plt.imshow(target[:, :, z], cmap='gray')
        plt.title('FLAIR (target) z={}'.format(z))
        plt.subplot(2, 2, 2 * row + 2)
        plt.imshow(synth[:, :, z], cmap='gray')
        plt.title('Synthetic z={}'.format(z))
    plt.show()

    nl.save_nifti(
        filepath=os.path.join(results_path, name + '_img_{}_synthetic.nii.gz'.format(case['id'])),
        volume=np.squeeze(synth_img, axis=0),  # Remove single channel dimension for storage
        reference=case['nifiti'])

    # Keep the outputs for metrics computation
    synthetic_images.append(synth_img)