In [None]:
## %pylab inline
#start a timer
import time
start_time = time.time()

#import operating system 
from os import listdir
from os.path import isfile, join
from pathlib import Path
import shutil

#To Show Images
import numpy as np
import nibabel as nib
import nilearn
from nilearn import plotting
import pickle
import matplotlib.pyplot as plt


# Import Pytorch
import torch
import torchvision
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch import autograd
from torch.autograd import Variable
from torch_similarity.modules import NormalizedCrossCorrelation
import pytorch_ssim
from torch.utils.data import Dataset, DataLoader


#Import Generator, Discriminator, Encoder and Code Discriminator 
from mmGANBased3 import *

#Monai (Import data and transform data)
from monai.transforms import \
    Compose,Flip, AddChannel,ResizeWithPadOrCrop, ScaleIntensity, ToTensor, Resize, RandRotate, RandFlip, RandScaleIntensity, RandZoom, RandGaussianNoise, RandAffine
from monai.data import CacheDataset, ImageDataset

## Choose architecture
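Set `architecture` in the cell below to choose which model definitions are imported:

- α-WGANSigmaRat1 => `"arch1"` (module `WGAN_SigmaRat1`)
- α-WGANSigmaRat2 => `"arch2"` (module `WGAN_SigmaRat2`)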

In [None]:
# Which α-WGAN variant to evaluate: "arch1" (SigmaRat1) or "arch2" (SigmaRat2).
architecture = "arch2"

In [None]:
# Pull Generator, Discriminator, Encoder and Code_Discriminator from the module
# that matches the selected architecture (any other value imports nothing here).
if architecture == "arch1":
    from WGAN_SigmaRat1 import *
elif architecture == "arch2":
    from WGAN_SigmaRat2 import *


# Configuration

In [None]:
#____________________________________ Constants ____________________________________
DEBUG = True
PATH_DATASET = 'path/to/the/dataset/MRI/'  # folder whose files are the MRI volumes
MODEL_NAME = "Model_Name"                  # checkpoints live in ./checkpoint/<MODEL_NAME>/

#____ Neural net ____
BATCH_SIZE = 2    # keep at 2: the MS-SSIM metric compares the two images of a batch
WORKERS = 2       # DataLoader worker processes
LATENT_DIM = 500  # length of the generator's latent vector

#____ Visualization switches ____
SEE_SEVERAL_MODELS = True
SEE_SLICE_SERIES = True
USE_MONAI = True

#____ Metrics to compute (set a flag to False to skip that metric) ____
CALC_MS_SSIM = True
CALC_MMD_SCORE = True
CALC_PSNR = True
CALC_RMSE = True
CALC_MAE = True
CALC_NCC = True

# Data Set Creator

In [None]:
def create_train_loader(image_files=None):
    """Build the shuffled MONAI/PyTorch DataLoader used by the rest of the notebook.

    Each volume gets a channel axis (``AddChannel``), is intensity-scaled to
    [0, 1] (``ScaleIntensity``) and padded/cropped to 64x64x64
    (``ResizeWithPadOrCrop``) before conversion to a tensor.

    Args:
        image_files: list of volume paths. Defaults to the notebook-global
            ``train_dataset`` (the previous, implicit behaviour).

    Returns:
        A ``DataLoader`` with ``BATCH_SIZE``/``WORKERS`` from the config cell,
        shuffling enabled and pinned memory when CUDA is available.
    """
    if image_files is None:
        image_files = train_dataset
    train_transforms = Compose([
        AddChannel(),
        ScaleIntensity(minv=0, maxv=1.0),
        ResizeWithPadOrCrop(spatial_size=(64, 64, 64), mode='constant'),
        ToTensor(),
    ])
    train_ds = ImageDataset(image_files=image_files, transform=train_transforms)
    return DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=WORKERS, pin_memory=torch.cuda.is_available())

def create_train_dataset_list(path_dataset):
    """Return the paths of all regular files directly inside `path_dataset`.

    Sub-directories are skipped. Paths are built with ``os.path.join`` so the
    folder may be given with or without a trailing separator (plain string
    concatenation produced broken paths without one). The result is sorted by
    file name so it does not depend on the filesystem's directory order.
    """
    return [join(path_dataset, f)
            for f in sorted(listdir(path_dataset))
            if isfile(join(path_dataset, f))]


In [None]:
# Gather every file under PATH_DATASET, then wrap the list in the shuffled training loader.
train_dataset = create_train_dataset_list(path_dataset=PATH_DATASET)
train_loader = create_train_loader()

----
# Visualization

In [None]:
def nets(only_G, number):
    """Build the networks and load their weights from a training checkpoint.

    Weights are read from ``./checkpoint/<MODEL_NAME>/<prefix>_iter<number>.pth``
    with prefix ``G``, ``CD``, ``D`` or ``E``. They are loaded with
    ``map_location='cpu'`` so GPU-saved checkpoints also load on CPU-only
    machines; callers move the returned modules to the GPU themselves
    (e.g. ``nets(...).cuda()``).

    Args:
        only_G: if True, build and return only the generator.
        number: training iteration encoded in the checkpoint file names.

    Returns:
        ``G`` when `only_G` is True, otherwise the tuple ``(G, CD, D, E)``.

    Raises:
        ValueError: if `only_G` is False and the global ``architecture`` is
            neither ``"arch1"`` nor ``"arch2"`` (no encoder can be built).
    """
    def _load(module, prefix):
        # Load the weights saved under `prefix` into `module` and return it.
        path = './checkpoint/' + MODEL_NAME + '/' + prefix + '_iter' + str(number) + '.pth'
        module.load_state_dict(torch.load(path, map_location='cpu'))
        return module

    G = _load(Generator(noise=LATENT_DIM), 'G')
    if only_G:
        return G

    # The encoder class depends on the architecture: arch1 reuses Discriminator.
    if architecture == "arch1":
        E = Discriminator(out_class=LATENT_DIM, is_dis=False)
    elif architecture == "arch2":
        E = Encoder(out_class=LATENT_DIM, is_dis=False)
    else:
        raise ValueError("architecture must be 'arch1' or 'arch2', got %r" % (architecture,))

    CD = _load(Code_Discriminator(code_size=LATENT_DIM, num_units=4096), 'CD')
    E = _load(E, 'E')
    D = _load(Discriminator(is_dis=True), 'D')
    return (G, CD, D, E)



In [None]:
def get_affine(origin):
    """Load the image file at path `origin` with nibabel and return its ``(affine, header)``."""
    reference = nib.load(origin)
    return reference.affine, reference.header

def saver(image, origin, name, model_name):
    """Save a generated volume as ``<model_name>/<name>.nii`` using a reference scan's geometry.

    Args:
        image: tensor holding one volume; it is squeezed and then cropped to
            indices 12:52 of its third remaining axis.
        origin: path of a real scan whose affine and header are reused.
        name: output file stem (``.nii`` is appended).
        model_name: output folder; created (with parents) if missing.

    Returns:
        A ``"File <name> saved"`` message.
    """
    affine, header = get_affine(origin)

    # mkdir(exist_ok=True) already tolerates an existing folder. The previous
    # os.path.exists guard depended on an `os` name that the visible import
    # cell never binds (it imports only listdir/isfile/join).
    Path(model_name).mkdir(parents=True, exist_ok=True)

    feat = np.squeeze(image.data.cpu().numpy())
    feat = feat[:, :, 12:52]
    nifti = nib.Nifti1Image(feat, affine=affine, header=header)
    nib.save(nifti, model_name + "/" + name + ".nii")

    return ("File " + name + " saved")

In [None]:
def visualization(image, reality):
    """Plot the first volume of batch `image` with nilearn, titled `reality`.

    The volume is squeezed, wrapped in a NIfTI image with an identity affine
    and drawn without cross-hairs or annotations on a black background.
    """
    first_volume = image[0].data.cpu().numpy()
    nifti = nib.Nifti1Image(np.squeeze(first_volume), affine=np.eye(4))
    plotting.plot_img(nifti, title=reality, annotate=False, draw_cross=False, black_bg=True)
    plotting.show()

In [None]:
# Take one shuffled batch of real volumes and show its first element.
real_images = next(iter(train_loader))
visualization(reality="Real", image=real_images)

In [None]:
def rescale_array(arr: "np.ndarray | torch.Tensor", minv: float = 0.0, maxv: float = 1.0):
    """Linearly rescale the values of `arr` so they span [`minv`, `maxv`].

    Adapted from MONAI's ``rescale_array``. Works on torch tensors (as used
    throughout this notebook) and on numpy arrays. Min and max are taken over
    the whole input, so for a batch all samples share one normalisation.

    Returns:
        An object of the same kind as `arr`. A constant input (min == max) maps
        to `minv` everywhere; the previous ``arr * minv`` fell outside the
        target range unless minv was 0.
    """
    mina = arr.min()
    maxa = arr.max()

    if mina == maxa:
        return arr * 0 + minv

    norm = (arr - mina) / (maxa - mina)  # first map to [0, 1]
    return norm * (maxv - minv) + minv   # then stretch/shift to [minv, maxv]

In [None]:
# Sample one batch from the trained generator and show the first fake volume.
G = nets(only_G=True, number=200000)
# Inference only: torch.no_grad() replaces Variable(volatile=True), which
# PyTorch ignores (so an autograd graph was still being built).
with torch.no_grad():
    z_rand = torch.randn((BATCH_SIZE, LATENT_DIM))  # random latent vectors
    x_rand = G(z_rand)                              # image generation
x_rand = rescale_array(x_rand)
visualization(x_rand, "Fake")


# Slice series visualization

In [None]:
def _slices_to_nifti(volume):
    """Convert one batch element to a NIfTI image for nilearn plotting.

    Singleton axes are squeezed away, the array is flipped with ``np.fliplr``
    (reverses its second axis) and wrapped with an identity affine.
    """
    feat = np.squeeze(volume.data.cpu().numpy())
    feat = np.fliplr(feat)
    return nib.Nifti1Image(feat, affine=np.eye(4))


def _print_slice_banner(title_line):
    """Print `title_line` between two full-width '#' rules."""
    print("###########################################################")
    print(title_line)
    print("###########################################################")


def _plot_slice_rows(feat, arr1, arr2, show_color, display_modes):
    """For every mode in `display_modes`, plot one row of cuts at `arr1`, then one at `arr2`.

    `show_color` selects ``plotting.plot_img`` (True) or ``plotting.plot_anat`` (False).
    """
    plot_fn = plotting.plot_img if show_color else plotting.plot_anat
    for mode in display_modes:
        for cut_coords in (arr1, arr2):
            plot_fn(feat, cut_coords=cut_coords, draw_cross=False, annotate=False,
                    black_bg=True, display_mode=mode)
            plotting.show()


def slice_series(display_mode='x', arr1=None, arr2=None, show_color=True, show_real=False, model_number=163000):
    """Plot rows of 2-D cuts through one generated (and optionally one real) volume.

    Args:
        display_mode: nilearn cut direction ('x' sagittal, 'y' coronal, 'z' axial).
        arr1, arr2: slice coordinates for the first and second row of each
            plot; ``None`` means an empty list. (``None`` replaces the former
            shared mutable ``list()`` defaults.)
        show_color: True -> ``plot_img``; False -> anatomical ``plot_anat``.
        show_real: also plot the first volume of one real training batch
            (in `display_mode` only).
        model_number: generator checkpoint iteration passed to ``nets``.

    The generated volume is drawn in `display_mode` and additionally in the
    'x' and 'y' planes.
    """
    arr1 = [] if arr1 is None else arr1
    arr2 = [] if arr2 is None else arr2

    G = nets(only_G=True, number=model_number)
    # Inference only; replaces the ignored Variable(volatile=True).
    with torch.no_grad():
        x_rand = G(torch.randn((BATCH_SIZE, LATENT_DIM)))
    x_rand = rescale_array(x_rand, 0, 1)

    if show_real:
        real_images = next(iter(train_loader))
        _print_slice_banner("#####################---REAL_Slices---#####################")
        _plot_slice_rows(_slices_to_nifti(real_images[0]), arr1, arr2, show_color, [display_mode])

    _print_slice_banner("#####################---FAKE_Slices---#####################")
    _plot_slice_rows(_slices_to_nifti(x_rand[0]), arr1, arr2, show_color, [display_mode, 'x', 'y'])

In [None]:

# ---- Slices to show (indices must lie between 1 and 64) ----
# A wider sweep also works: arr1 = list(range(4, 33, 2)); arr2 = list(range(34, 61, 2))
arr1 = list(range(22, 31, 2))   # first row:  22, 24, 26, 28, 30
arr2 = list(range(32, 41, 2))   # second row: 32, 34, 36, 38, 40

# ---- Plane: 'x' -> sagittal, 'y' -> coronal, 'z' -> axial ----
display_mode = 'z'

show_color = False      # False: anatomical plot; True: colored plot
show_real = False       # True also shows one real volume for comparison
model_number = 200000   # generator checkpoint iteration

# Draw 20 series; every call samples a fresh latent batch.
for i in range(20):
    slice_series(display_mode=display_mode, arr1=arr1, arr2=arr2,
                 show_color=show_color, show_real=show_real, model_number=model_number)


----
<a id="Metrics">  </a>
# Metrics - Objective

## Multiscale Structural Similarity Index Measure (MS-SSIM)
# $ SSIM(x,y)=\frac{(2\mu_{x}\mu_{y}+c_{1})(2\sigma_{xy}+c_{2})}{(\mu_{x}^{2} +\mu_{y}^{2}+c_{1})(\sigma_{x}^{2} +\sigma_{y}^{2}+c_{2})} $

In [None]:
def ms_ssim_real(number=20):
    """MS-SSIM between pairs of real training volumes.

    Runs `number` sweeps over ``train_loader``. In each sweep, the two volumes
    of every batch are rescaled to [0, 1], cropped to indices 12:52 of their
    last axis and compared with ``pytorch_ssim.msssim_3d``. A batch that does
    not hold exactly two volumes (e.g. a short final batch) ends the sweep.

    Args:
        number: how many sweeps over the loader to run.

    Returns:
        ``'Total_mean:<m> STD:<s>'`` over the per-sweep mean MS-SSIM values.

    Side effects:
        Plots the last batch seen and prints the total number of compared pairs.
    """
    meanarr = list()
    contar = 0   # total pairs compared across all sweeps
    dat = None   # last batch seen; plotted at the end
    for k in range(number):
        sum_ssim = 0
        n_pairs = 0
        for dat in train_loader:
            if len(dat) != 2:
                break
            img1 = rescale_array(dat[0], 0.0, 1.0)[:, :, :, 12:52]
            img2 = rescale_array(dat[1], 0.0, 1.0)[:, :, :, 12:52]
            sum_ssim = sum_ssim + pytorch_ssim.msssim_3d(img1, img2)
            n_pairs += 1
            contar += 1
        # Average over the pairs actually compared, not over loop iterations:
        # the short batch that triggers `break` contributes no score.
        if n_pairs:
            meanarr.append(float(sum_ssim / n_pairs))

    if dat is not None:
        visualization(image=dat[:, :, :, :, 12:52], reality="Real")
    print(contar)
    meanarr2 = torch.tensor(meanarr)
    return ('Total_mean:' + str(torch.mean(meanarr2).item()) + ' STD:' + str(torch.std(meanarr2).item()))

    
def ms_ssim_generated(number=1750):
    """MS-SSIM between pairs of generated volumes.

    Draws `number` latent batches of size 2 from N(0, 1) and decodes them with
    the global generator ``G`` on CUDA. Each volume is rescaled to [0, 1] and
    cropped to indices 12:52 of its last axis, and the pair is scored with
    ``pytorch_ssim.msssim_3d``.

    Args:
        number: how many generated pairs to score.

    Returns:
        ``'Total_mean:<m> STD:<s>'`` over the individual pair scores. Earlier
        versions averaged a list of running means, which biased the mean
        toward the first samples and did not give a per-pair std.

    Side effects:
        Prints the pair count and plots the last generated batch.
    """
    scores = list()
    contar = 0
    fake_image = None
    # Inference only. Without no_grad, `sum_ssim` kept every iteration's
    # autograd graph alive, so memory grew with `number`.
    with torch.no_grad():
        for i in range(number):
            # Use LATENT_DIM so the noise matches Generator(noise=LATENT_DIM)
            # as built in nets(); it was previously hard-coded to 1000.
            noise = torch.randn((2, LATENT_DIM)).cuda()
            fake_image = G(noise)

            img1 = rescale_array(fake_image[0], 0.0, 1.0)[:, :, :, 12:52]
            img2 = rescale_array(fake_image[1], 0.0, 1.0)[:, :, :, 12:52]

            scores.append(float(pytorch_ssim.msssim_3d(img1, img2)))
            contar += 1
    print(contar)
    if fake_image is not None:
        visualization(image=fake_image[:, :, :, :, 12:52], reality="Fake")

    meanarr2 = torch.tensor(scores)
    return ('Total_mean:' + str(torch.mean(meanarr2).item()) + ' STD:' + str(torch.std(meanarr2).item()))


In [None]:
if CALC_MS_SSIM:
    number = 200
    model_number = 200000  # generator checkpoint iteration to evaluate
    G = nets(only_G=True, number=model_number)
    G = G.cuda()
    # Real pass: number // 10 sweeps over the loader. Fake pass: int(number * 8.75)
    # generated pairs (1750 for number=200), intended to equal the real pair count.
    # That presumably assumes ~87 pairs per sweep -- TODO confirm against the dataset size.
    print("#####################---REAL_Images---#####################")
    print("MS_SSIM_real=", ms_ssim_real(number // 10))
    print("#####################---FAKE_Images---#####################")
    print("MS_SSIM_fake=", ms_ssim_generated(int(number * 8.75)))

----
## Maximum-Mean Discrepancy Score (MMD Score)
# $ MMD(P,Q)=||\mu _{P}-\mu _{Q}||_{H} $

In [None]:
def mmd_score(number=100):
    """Linear-kernel MMD between real training volumes and generator samples.

    For each of `number` sweeps over ``train_loader``, every real batch ``y``
    (size B) is paired with a freshly generated batch ``x`` of the same size.
    ``x`` comes from the global generator ``G`` on CUDA. Both batches are
    rescaled to [0, 1] (one min/max per batch), cropped to indices 12:52 of the
    last axis and flattened per sample. The batch distance is::

        (sum(x x^T) + sum(y y^T)) / B^2 - 2 * sum(x y^T) / B^2

    which equals the squared distance between the two batch means.

    Args:
        number: how many sweeps over the loader to run.

    Returns:
        ``'Total_mean:<m> STD:<s>'`` over the per-sweep average distances.
    """
    meanarr = list()
    with torch.no_grad():  # evaluation only: don't build autograd graphs through G
        for s in range(number):
            distmean = 0.0
            n_batches = 0
            for y in train_loader:
                y = rescale_array(y.cuda(), 0.0, 1.0)
                y = y[:, :, :, :, 12:52]
                B = y.size(0)

                # Generate exactly B samples: the 1/B^2 weights below assume both
                # batches have B rows. A short final real batch used to be
                # compared against a full BATCH_SIZE fake batch.
                z_rand = torch.randn((B, LATENT_DIM)).cuda()
                x = rescale_array(G(z_rand), 0.0, 1.0)
                x = x[:, :, :, :, 12:52]

                x = x.reshape(B, -1)
                y = y.reshape(B, -1)

                xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
                beta = 1.0 / (B * B)
                gamma = 2.0 / (B * B)

                dist = beta * (torch.sum(xx) + torch.sum(yy)) - gamma * torch.sum(zz)
                distmean += dist.item()
                n_batches += 1
            if n_batches:
                meanarr.append(distmean / n_batches)

    meanarr2 = torch.tensor(meanarr)
    return ('Total_mean:' + str(torch.mean(meanarr2).item()) + ' STD:' + str(torch.std(meanarr2).item()))

In [None]:
if CALC_MMD_SCORE:
    number = 100
    model_number = 200000  # generator checkpoint iteration to evaluate
    G = nets(only_G=True, number=model_number)
    G = G.cuda()
    print("#####################---MMD_Score---#####################")
    print("MMD=", mmd_score(number=number))

---
## Peak signal-to-noise ratio (PSNR) 

$ PSNR = 20\cdot log_{10}(MAXRange)-10\cdot log_{10}(MSE) $

If the values are normalized so that $ MAXRange = 1 $, then $ 20\cdot log_{10}(MAXRange) = 0 $, so:

# $ PSNR = -10\cdot log_{10}(MSE) $

In [None]:
criterion_mse = nn.MSELoss()
def PSNR(value_range=1, model_number=200000, number=100):
    """Peak signal-to-noise ratio between real volumes and generator samples.

    For each of `number` sweeps over ``train_loader``, every real batch is
    cropped to indices 12:52 of the last axis and rescaled to [0, 1]. It is
    then compared with a generated batch of the same size (global ``G``, on
    CUDA) that gets the same crop and rescale:
    ``PSNR = 20*log10(value_range) - 10*log10(MSE)``.

    Args:
        value_range: peak value of the data (1 for [0, 1] volumes, which makes
            the first term zero).
        model_number: unused; kept for call compatibility (the generator used
            is the global ``G``).
        number: how many sweeps over the loader to run.

    Returns:
        ``'Total_mean:<m> STD:<s>'`` over the per-sweep average PSNR values.
    """
    meanarr = list()
    max_term = 20.0 * torch.log10(torch.tensor(float(value_range))).item()
    with torch.no_grad():  # evaluation only
        for s in range(number):
            semi_psnr = 0.0
            n_batches = 0
            for y in train_loader:
                y = y.cuda()[:, :, :, :, 12:52]
                y = rescale_array(y, 0.0, 1.0)

                # Match the real batch size; a fixed size of 2 made MSELoss
                # broadcast against a short final batch.
                z_rand = torch.randn((y.size(0), LATENT_DIM)).cuda()
                x_rand = G(z_rand)[:, :, :, :, 12:52]
                x_rand = rescale_array(x_rand, 0.0, 1.0)

                psnr_value = max_term - 10.0 * torch.log10(criterion_mse(x_rand, y)).item()
                semi_psnr += psnr_value
                n_batches += 1
            if n_batches:
                meanarr.append(semi_psnr / n_batches)
    meanarr2 = torch.tensor(meanarr)

    return ('Total_mean:' + str(torch.mean(meanarr2).item()) + ' STD:' + str(torch.std(meanarr2).item()))

In [None]:
if CALC_PSNR:
    number = 100
    model_number = 200000  # generator checkpoint iteration to evaluate
    G = nets(only_G=True, number=model_number)
    G = G.cuda()
    print("#####################---PSNR_Value---#####################")
    print("PSNR=", PSNR(number=number))

---
## Root Mean Squared Error
$ MSE(P,Q)=\frac{1}{n}\sum_{i=1}^{n}(y_{i}-x_{i})^{2} $
# $ RMSE=\sqrt{MSE} $
#### NRMSE can be calculated by $ RMSE/(max-min) $ -> since max=1 and min=0, $ NRMSE=RMSE $

In [None]:
def RMSE(number=100):
    """Root-mean-squared error between real volumes and generator samples.

    For each of `number` sweeps over ``train_loader``, every real batch is
    cropped to indices 12:52 of the last axis and rescaled to [0, 1]. It is
    compared with a generated batch of the same size (global ``G``, on CUDA)
    that gets the same crop and rescale. Since both lie in [0, 1], this RMSE
    also equals the NRMSE.

    Args:
        number: how many sweeps over the loader to run.

    Returns:
        ``'Total_mean:<m> STD:<s>'`` over the per-sweep average RMSE values.
    """
    mse = nn.MSELoss()  # L2 loss

    meanarr = list()
    with torch.no_grad():  # evaluation only
        for s in range(number):
            semi_rmse = 0.0
            n_batches = 0
            for y in train_loader:
                y = rescale_array(y.cuda()[:, :, :, :, 12:52], 0.0, 1.0)

                # One fake per real volume. A single fake was previously
                # broadcast against the whole real batch, which MSELoss warns
                # gives incorrect results.
                z_rand = torch.randn((y.size(0), LATENT_DIM)).cuda()
                x_rand = rescale_array(G(z_rand)[:, :, :, :, 12:52], 0.0, 1.0)

                semi_rmse += torch.sqrt(mse(x_rand, y)).item()
                n_batches += 1
            if n_batches:
                meanarr.append(semi_rmse / n_batches)
    meanarr2 = torch.tensor(meanarr)

    return ('Total_mean:' + str(torch.mean(meanarr2).item()) + ' STD:' + str(torch.std(meanarr2).item()))

In [None]:
if CALC_RMSE:
    number = 100
    model_number = 200000  # generator checkpoint iteration to evaluate
    G = nets(only_G=True, number=model_number)
    G = G.cuda()
    print("#####################---RMSE/NRMSE_Value---#####################")
    print("RMSE=", RMSE(number=number))

----
## Mean Absolute Error (MAE)
# $ MAE=\frac{\sum_{i=1}^{n}|y_{i}-x_{i}| }{n} $

In [None]:
def MAE(number=100):
    """Mean absolute error between real volumes and generator samples.

    For each of `number` sweeps over ``train_loader``, every real batch is
    rescaled to [0, 1] and compared with a generated batch of the same size
    (global ``G``, on CUDA), also rescaled to [0, 1].
    NOTE(review): unlike the PSNR/RMSE metrics, no 12:52 crop is applied here;
    confirm this is intended.

    Args:
        number: how many sweeps over the loader to run.

    Returns:
        ``'Total_mean:<m> STD:<s>'`` over the per-sweep average MAE values.
    """
    mae = nn.L1Loss()  # L1 loss
    meanarr = list()

    with torch.no_grad():  # evaluation only
        for s in range(number):
            semi_mae = 0.0
            n_batches = 0
            for y in train_loader:
                y = rescale_array(y.cuda(), 0.0, 1.0)

                # One fake per real volume. A single fake was previously
                # broadcast against the whole real batch.
                z_rand = torch.randn((y.size(0), LATENT_DIM)).cuda()
                x_rand = rescale_array(G(z_rand), 0.0, 1.0)

                semi_mae += mae(x_rand, y).item()
                n_batches += 1
            if n_batches:
                meanarr.append(semi_mae / n_batches)
    meanarr2 = torch.tensor(meanarr)

    return ('Total_mean:' + str(torch.mean(meanarr2).item()) + ' STD:' + str(torch.std(meanarr2).item()))

In [None]:
if CALC_MAE:
    number = 100
    model_number = 200000  # generator checkpoint iteration to evaluate
    G = nets(only_G=True, number=model_number)
    G = G.cuda()

    print("#####################---MAE_Value---#####################")
    print("MAE=", MAE(number=number))

----
## Normalized Cross Correlation (NCC)
# $ NCC=\frac{1}{n}\sum_{x,y}\frac{1}{\sigma_{f}\sigma_{t}}\left(f(x,y)-\mu_{f}\right)\left(t(x,y)-\mu_{t}\right) $

In [None]:
#https://github.com/yuta-hi/pytorch_similarity
def NCC(number=1):
    """Normalized cross-correlation between real volumes and generator samples.

    Uses ``NormalizedCrossCorrelation`` from pytorch_similarity. For each of
    `number` sweeps over ``train_loader``, every real batch is rescaled to
    [0, 1] and compared with a generated batch of the same size (global ``G``,
    on CUDA), also rescaled to [0, 1]. No crop is applied.

    Args:
        number: how many sweeps over the loader to run.

    Returns:
        ``'Total_mean:<m> STD:<s>'`` over the per-sweep average NCC values.
        Each average counts only the batches that were actually scored.
    """
    model = NormalizedCrossCorrelation(return_map=True)

    meanarr = list()
    with torch.no_grad():  # evaluation only
        for s in range(number):
            semi_ncc = 0.0
            n_batches = 0
            for y in train_loader:
                y = rescale_array(y.cuda(), 0.0, 1.0)

                # Match the real batch size, so a short final batch is scored
                # instead of skipped (it used to be skipped yet still counted
                # in the denominator).
                z_rand = torch.randn((y.size(0), LATENT_DIM)).cuda()
                x_rand = rescale_array(G(z_rand), 0.0, 1.0)

                if y.size() == x_rand.size():
                    gc, gc_map = model(x_rand, y)
                    semi_ncc += gc.item()
                    n_batches += 1
            if n_batches:
                meanarr.append(semi_ncc / n_batches)
    meanarr2 = torch.tensor(meanarr)

    return ('Total_mean:' + str(torch.mean(meanarr2).item()) + ' STD:' + str(torch.std(meanarr2).item()))

In [None]:
if CALC_NCC:
    number = 100
    model_number = 200000  # generator checkpoint iteration to evaluate
    G = nets(only_G=True, number=model_number)
    G = G.cuda()
    print("#####################---NCC_Value---#####################")
    print("NCC=", NCC(number=number))