In [None]:
import sys

# Which Python interpreter backs this kernel (useful when several environments are installed).
sys.executable

In [None]:
import seaborn as sns

from scipy.signal import convolve2d
from scipy.ndimage import median_filter
from scipy.stats import spearmanr, pearsonr
from statsmodels.regression.linear_model import OLS

from scipy.ndimage import convolve
from tqdm.notebook import tqdm
import os
import sys
import pydicom
import SimpleITK as sitk
import PySimpleGUI as sg
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import nilearn
from nilearn.image import resample_img
import time
from tqdm import tqdm
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
%matplotlib inline

from IPython.display import clear_output
import time
from scipy import stats

# External viewer used for quick visual checks.
# The ITK-SNAP executable can be overridden with the ITK_SNAP_PATH environment variable;
# the default is the Windows path this notebook was written against.
image_viewer = sitk.ImageViewer()
image_viewer.SetApplication(os.environ.get('ITK_SNAP_PATH', 'C:/Program Files/ITK-SNAP 3.8/bin/ITK-SNAP'))

# DICOM series reader that also loads private tags.
reader = sitk.ImageSeriesReader()
reader.LoadPrivateTagsOn()

# `asist` is used later as the `cmap` of the perfusion-map plots.
# SECURITY NOTE: pickle.load can execute arbitrary code - only load an asist.pickle you trust.
import pickle
with open('asist.pickle', 'rb') as file:
    asist = pickle.load(file)

os.getcwd()

In [None]:
def start_plot():
    """Registration StartEvent observer: reset the global metric history."""
    global metric_values, multires_iterations
    metric_values, multires_iterations = [], []


def end_plot():
    """Registration EndEvent observer: drop the metric history and close the figure."""
    global metric_values, multires_iterations
    del metric_values, multires_iterations
    plt.close()


def plot_values(registration_method):
    """Registration IterationEvent observer: record the current metric value and redraw.

    The red line is the metric per iteration. The blue stars mark the iteration at
    which each resolution level started (indices recorded by update_multires_iterations).
    """
    global metric_values, multires_iterations
    metric_values.append(registration_method.GetMetricValue())
    clear_output(wait=True)
    level_start_values = [metric_values[i] for i in multires_iterations]
    plt.plot(metric_values, 'r')
    plt.plot(multires_iterations, level_start_values, 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.show()


def update_multires_iterations():
    """MultiResolutionIterationEvent observer: remember where the new level starts."""
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))

## [1]. Sort DICOM files into folders and calculate perfusion maps with PMA

## 2. Convert DICOM to NIfTI

In [None]:
# Convert each DICOM series folder to an uncompressed NIfTI (-z n) without a BIDS JSON
# sidecar (-b n). --terse drops dcm2niix's automatic filename post-fixes.
# No -o is given, so dcm2niix writes relative to the input folder. The '/../' prefix in the
# -f name presumably places e.g. d_swi0.nii in the current directory, which is where
# step 3 reads './d_swi0.nii' from.
!dcm2niix --terse -f /../d_swi0 -b n -z n _swi0/
!dcm2niix --terse -f /../d_swi1 -b n -z n _swi1/
!dcm2niix --terse -f /../d_t2 -b n -z n _t2/
!dcm2niix --terse -f /../d_pwi0 -b n -z n _pwi0/
# PMA perfusion maps: %d names each output after its series description, inside d_pwi/.
# Step 4 expects files such as ./d_pwi/CBF-sSVD.nii. NOTE(review): confirm d_pwi/ gets created.
!dcm2niix --terse -f /../d_pwi/%d -b n -z n _pwi1-pma/

## 3. reg pwi0 to swi0, save transformation

In [None]:
# Step 3 inputs: SWI0 is the fixed (reference) image, the PWI baseline is the moving one.
fixed_img, moving_img_pwi = './d_swi0.nii', './d_pwi0.nii'
fixed, moving = (sitk.ReadImage(path, sitk.sitkFloat32) for path in (fixed_img, moving_img_pwi))

In [None]:
# Rigid (Euler 3D) starting pose that aligns the intensity centres of mass (MOMENTS).
initial_transform = sitk.CenteredTransformInitializer(
    fixed,
    moving,
    sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS,
)

In [None]:
%%time
registration_method = sitk.ImageRegistrationMethod()

#registration_method.SetMetricAsANTSNeighborhoodCorrelation(radius=5)
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.5)
registration_method.SetInterpolator(sitk.sitkLinear)

registration_method.SetOptimizerAsGradientDescent(learningRate=0.5, numberOfIterations=70, convergenceMinimumValue=1e-6, convergenceWindowSize=20)
registration_method.SetOptimizerScalesFromPhysicalShift()
#registration_method.SetOptimizerAsLBFGSB()

# Setup for the multi-resolution framework.            
registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,3,2])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[6,3,2])
#registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Don't optimize in-place, we would possibly like to run this cell multiple times.
registration_method.SetInitialTransform(initial_transform, inPlace=False)

# Connect all of the observers so that we can perform plotting during registration.
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method))

final_transform = registration_method.Execute(sitk.Cast(fixed, sitk.sitkFloat32), 
                                              sitk.Cast(moving, sitk.sitkFloat32))

In [None]:
%%time
sitk.WriteTransform(final_transform, 'transform_pwi0-swi0.tfm')
fixed = sitk.ReadImage(fixed_img, sitk.sitkFloat32)
moving = sitk.ReadImage(moving_img_pwi, sitk.sitkFloat32)
moving_resampled = sitk.Resample(moving, fixed, final_transform, sitk.sitkBSpline, 0.0, moving.GetPixelID())

In [None]:
# Colour overlay for a visual registration check: R = fixed, G = resampled moving,
# B = their average. Shows 3 slices along each of the three array axes.
simg1 = sitk.Cast(sitk.RescaleIntensity(fixed), sitk.sitkUInt8)
simg2 = sitk.Cast(sitk.RescaleIntensity(moving_resampled), sitk.sitkUInt8)
cimg = sitk.Compose(simg1, simg2, simg1 // 2. + simg2 // 2.)
array_cimg = sitk.GetArrayFromImage(cimg)

depth, rows, cols = array_cimg.shape[:3]
panels = [
    array_cimg[depth // 4], array_cimg[depth // 3], array_cimg[depth // 2],
    array_cimg[:, rows // 2 - 15], array_cimg[:, rows // 2], array_cimg[:, rows // 2 + 15],
    array_cimg[:, :, cols // 2 - 15], array_cimg[:, :, cols // 2], array_cimg[:, :, cols // 2 + 15],
]
plt.figure(figsize=(15,10))
for position, panel in enumerate(panels, start=1):
    plt.subplot(3, 3, position)
    plt.imshow(panel, origin='lower', interpolation='bilinear', aspect='auto')
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

## 4. reg all pwi1 maps to swi0 with transformation

In [None]:
# Step 4 inputs: the SWI0 reference grid and the PWI0->SWI0 transform saved in step 3.
fixed_img = './d_swi0.nii'
final_transform = sitk.ReadTransform('transform_pwi0-swi0.tfm')
fixed = sitk.ReadImage(fixed_img, sitk.sitkFloat32)

In [None]:
# PMA perfusion-map names; each map is read as ./d_pwi/<name>.nii in the loop below.
# NOTE(review): this shadows the builtin `list`; later cells rely on the name, so it is kept.
list = [
    'BAT', 'BET',
    'CBF-sSVD', 'CBV-sSVD', 'CBV-AUC', 'MTT-sSVD', 'Tmax-sSVD',
    'NEI', 'dSoverS', 'Cmax', 'FWHM', 'MS', 'TTP',
    'fMTT_with_DC', 'fMTT_without_DC',
]

In [None]:
# Output folder for the PWI maps resampled onto the SWI0 grid.
# exist_ok=True keeps the cell re-runnable; os.mkdir raised FileExistsError on a second run.
os.makedirs('r_pwi', exist_ok=True)

In [None]:
%%time
for perf_map in list:
    name = './d_pwi/' + perf_map + '.nii'
    name2 = './r_pwi/' + perf_map + '.nii'
    moving = sitk.ReadImage(name, sitk.sitkFloat32)
    moving_resampled = sitk.Resample(moving, fixed, final_transform, sitk.sitkBSpline, 0.0, moving.GetPixelID())
    sitk.WriteImage(moving_resampled, name2)

## 5. reg swi1 to swi0

In [None]:
# Step 5 inputs: SWI0 is the fixed image, SWI1 is registered onto it.
fixed_img = './d_swi0.nii'
moving_img_swi = './d_swi1.nii'
fixed, moving = (sitk.ReadImage(p, sitk.sitkFloat32) for p in (fixed_img, moving_img_swi))

In [None]:
# Rigid (Euler 3D) starting pose for SWI1 -> SWI0, aligning intensity centres of mass.
initial_transform = sitk.CenteredTransformInitializer(
    fixed,
    moving,
    sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS,
)

In [None]:
%%time
registration_method = sitk.ImageRegistrationMethod()

#registration_method.SetMetricAsANTSNeighborhoodCorrelation(radius=5)
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.9)
registration_method.SetInterpolator(sitk.sitkLinear)

registration_method.SetOptimizerAsGradientDescent(learningRate=0.5, numberOfIterations=70, convergenceMinimumValue=1e-6, convergenceWindowSize=20)
registration_method.SetOptimizerScalesFromPhysicalShift()
#registration_method.SetOptimizerAsLBFGSB()

# Setup for the multi-resolution framework.            
registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [2,2,2])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,2,1])
#registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Don't optimize in-place, we would possibly like to run this cell multiple times.
registration_method.SetInitialTransform(initial_transform, inPlace=False)

# Connect all of the observers so that we can perform plotting during registration.
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method))

final_transform = registration_method.Execute(sitk.Cast(fixed, sitk.sitkFloat32), 
                                              sitk.Cast(moving, sitk.sitkFloat32))

In [None]:
%%time
fixed = sitk.ReadImage(fixed_img, sitk.sitkFloat32)
moving = sitk.ReadImage(moving_img_swi, sitk.sitkFloat32)
moving_resampled = sitk.Resample(moving, fixed, final_transform, sitk.sitkBSpline, 0.0, moving.GetPixelID())

In [None]:
# Colour overlay for a visual registration check of SWI1 -> SWI0:
# R = fixed, G = resampled moving, B = their average; 3 slices along each array axis.
simg1 = sitk.Cast(sitk.RescaleIntensity(fixed), sitk.sitkUInt8)
simg2 = sitk.Cast(sitk.RescaleIntensity(moving_resampled), sitk.sitkUInt8)
cimg = sitk.Compose(simg1, simg2, simg1 // 2. + simg2 // 2.)
array_cimg = sitk.GetArrayFromImage(cimg)

depth, rows, cols = array_cimg.shape[:3]
panels = [
    array_cimg[depth // 4], array_cimg[depth // 3], array_cimg[depth // 2],
    array_cimg[:, rows // 2 - 15], array_cimg[:, rows // 2], array_cimg[:, rows // 2 + 15],
    array_cimg[:, :, cols // 2 - 15], array_cimg[:, :, cols // 2], array_cimg[:, :, cols // 2 + 15],
]
plt.figure(figsize=(15,10))
for position, panel in enumerate(panels, start=1):
    plt.subplot(3, 3, position)
    plt.imshow(panel, origin='lower', interpolation='bilinear', aspect='auto')
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

In [None]:
# Registered SWI1 (on the SWI0 grid); step 6 loads it back from this path.
registered_swi1_path = './r_swi1.nii'
sitk.WriteImage(moving_resampled, registered_swi1_path)

## 6. Calculate the cell distribution

In [None]:
# SWI0 and the SWI1 registered onto it in step 5, as float arrays.
swi0_img, swi1_img = nib.load('./d_swi0.nii'), nib.load('./r_swi1.nii')
swi0_array, swi1_array = swi0_img.get_fdata(), swi1_img.get_fdata()

In [None]:
# Scale each SWI by 10 / its median and take the difference SWI0 - SWI1.
# NOTE(review): if more than half of the voxels are background zeros the median is 0.
swi2_array = 10 * swi0_array / np.median(swi0_array) - 10 * swi1_array / np.median(swi1_array)

# Flag voxels whose difference exceeds half of the mean raw SWI0 intensity.
# NOTE(review): the threshold is in raw SWI0 units while swi2_array is median-normalised;
# confirm this mix of scales is intended.
# The comparison already yields a boolean (0/1) map, so no extra binarisation step is needed.
limit = np.mean(swi0_array)
swi2l_array = swi2_array > limit * .5

In [None]:
# Visual check of step 6 on one slice: SWI0 | registered SWI1 | detected cells over SWI1.
# The index is named n_slice (as in the later plotting cells) instead of `slice`,
# which shadowed Python's builtin slice() for the rest of the session.
n_slice = 200
img = swi0_array[:,:,n_slice]
img1 = swi1_array[:,:,n_slice]
img2 = swi2l_array[:,:,n_slice]
plt.figure(figsize=(15,10))
plt.subplot(131)
plt.imshow(img)
plt.title('SWI0')
plt.subplot(132)
plt.imshow(img1)
plt.title('SWI1 (registered)')
plt.subplot(133)
plt.imshow(img2)
plt.imshow(img1, alpha=0.5)
plt.title('detected cells over SWI1');

In [None]:
# Save the binary cell map on the SWI0 grid.
# Stored as uint8: `swi2l_array*1` gave the platform default int (int64 on Linux/macOS),
# which recent nibabel refuses to write without an explicit dtype. Later steps read the
# file with get_fdata(), so the on-disk dtype does not affect them.
r_cells = nib.Nifti1Image(swi2l_array.astype(np.uint8), swi0_img.affine)
nib.save(r_cells, 'r_cells.nii')

## [7]. Create a brain mask from t2iso with U-Net

## 8. brain mask resample to swi0

In [None]:
# Step 8 inputs: the SWI0 reference and md_t2 (presumably the step-7 U-Net brain mask).
swi0_img = nib.load('./d_swi0.nii')
swi0_array = swi0_img.get_fdata()
t2_img = nib.load('./md_t2.nii')
t2_array = t2_img.get_fdata()  # NOTE(review): not used later in this notebook

In [None]:
# Resample the T2-space image onto the SWI0 voxel grid (same affine and shape as SWI0).
# NOTE(review): resample_img uses continuous interpolation by default; for a binary mask,
# interpolation='nearest' would keep it binary - confirm which is intended.
t2_resampled = resample_img(
    t2_img,
    target_affine=swi0_img.affine,
    target_shape=swi0_img.shape,
)

In [None]:
# Write the resampled image with the SWI0 affine so it overlays the SWI data.
t2_resampled_array = t2_resampled.get_fdata()
img = nib.Nifti1Image(t2_resampled_array, affine=swi0_img.affine)
nib.save(img, filename='./rmd_t2.nii')

## [9]. clean mask with slicer

## 10. Apply the brain mask to the cell map

In [None]:
# Step 10 inputs: the cleaned brain mask (step 9) and the raw cell map written in step 6.
mask_img, cell_img = nib.load('mask.nii'), nib.load('r_cells.nii')
mask_array, cell_array = mask_img.get_fdata(), cell_img.get_fdata()

In [None]:
# Zero out cell voxels outside the brain mask (modifies cell_array in place).
outside_brain = mask_array < 1
cell_array[outside_brain] = 0

In [None]:
# Save the masked cell map. It is cleaned manually in ITK-SNAP next (step 11),
# and that step's output smr_cells.nii is what step 12 loads.
mr_cells = nib.Nifti1Image(cell_array, cell_img.affine)
# Explicit '.nii' extension, like every other save in this notebook, instead of relying on
# nibabel to infer the file type from an extension-less name.
nib.save(mr_cells, 'mr_cells.nii')

## [11]. Clean the cell map with ITK-SNAP

## 12. Load the cell map and PWI maps with the brain mask

In [None]:
# Step 12 inputs: brain mask, the manually cleaned cell map (step 11) and the PWI map names.
mask_img = nib.load('mask.nii')
mask_ar = mask_img.get_fdata()
cell_fin_img = nib.load('smr_cells.nii')
cell_ar = cell_fin_img.get_fdata()
list = ['BAT', 'BET', 'CBF-sSVD', 'CBV-sSVD', 'CBV-AUC', 'MTT-sSVD', 'Tmax-sSVD', 'NEI', 'dSoverS', 'Cmax', 'FWHM', 'MS', 'TTP', 'fMTT_with_DC', 'fMTT_without_DC']

In [None]:
# Load every registered PWI map. map_ar keeps the raw arrays; map_ar_m holds copies with
# everything outside the brain mask set to 0.
map_ar = {}
map_ar_m = {}
outside_mask = mask_ar < 1
for perf_map in list:
    name = f'./r_pwi/{perf_map}.nii'
    map_img = nib.load(name)
    map_ar[perf_map] = map_img.get_fdata()
    masked = map_ar[perf_map].copy()
    masked[outside_mask] = 0
    map_ar_m[perf_map] = masked

In [None]:
# One-slice check: raw map | brain mask | cell map | masked map with the cells overlaid.
nmap = 'CBF-sSVD'
n_slice = 200
plt.figure(figsize=(10,10))
for position, (layer, cmap) in enumerate([(map_ar[nmap], asist), (mask_ar, None), (cell_ar, None)], start=1):
    plt.subplot(1, 4, position)
    plt.imshow(layer[:, :, n_slice], cmap=cmap)
plt.subplot(1, 4, 4)
plt.imshow(map_ar_m[nmap][:, :, n_slice], cmap=asist)
plt.imshow(cell_ar[:, :, n_slice], alpha=0.5)

## 13. Cell-density map and median filtering of the PWI maps

In [None]:
# PMA perfusion-map names used in steps 13-14.
# NOTE(review): this shadows the builtin `list`; later cells rely on the name, so it is kept.
list = [
    'BAT', 'BET',
    'CBF-sSVD', 'CBV-sSVD', 'CBV-AUC', 'MTT-sSVD', 'Tmax-sSVD',
    'NEI', 'dSoverS', 'Cmax', 'FWHM', 'MS', 'TTP',
    'fMTT_with_DC', 'fMTT_without_DC',
]

In [None]:
# Cell-density map: sum of cell_ar over an 11x11x11 box around each voxel (zero padding at
# the borders; equals the cell-voxel count if cell_ar is binary), then limited to the mask.
kernel_size_cell = 11
kernel = np.ones((kernel_size_cell,) * 3)
cell_ar_map = convolve(cell_ar, kernel, mode='constant', cval=0)
cell_ar_map[mask_ar<1] = 0

In [None]:
# 5x5x5 median filter on each masked PWI map, then re-apply the brain mask.
# NOTE(review): the input maps are already 0 outside the mask, so voxels near the brain
# edge are pulled toward 0 by the filter.
kernel_size = 5
outside_mask = mask_ar < 1
perf_map_blurred = {}
for perf_map in tqdm(list):
    smoothed = median_filter(map_ar_m[perf_map], size=[kernel_size] * 3)
    smoothed[outside_mask] = 0
    perf_map_blurred[perf_map] = smoothed

In [None]:
# One-slice check for step 13: raw map | brain mask | masked map + density overlay |
# cell map + density overlay.
nmap = 'TTP'
n_slice = 177
density_slice = cell_ar_map[:, :, n_slice]
plt.figure(figsize=(10,10))
plt.subplot(1, 4, 1)
plt.imshow(map_ar[nmap][:, :, n_slice], cmap=asist)
plt.subplot(1, 4, 2)
plt.imshow(mask_ar[:, :, n_slice])
plt.subplot(1, 4, 3)
plt.imshow(map_ar_m[nmap][:, :, n_slice])
plt.imshow(density_slice, alpha=0.5)
plt.subplot(1, 4, 4)
# NOTE(review): the perfusion colormap is applied to the cell map here but not to the
# perfusion map in panel 3 - confirm this is intended.
plt.imshow(cell_ar[:, :, n_slice], cmap=asist)
plt.imshow(density_slice, alpha=0.5)

## 14. Statistical analysis

In [None]:
# pwi1 with mask limit
for perf_map in list:
    print (perf_map)
    cor, pvalue = spearmanr(cell_ar_map[mask_ar>0], perf_map_blurred[perf_map][mask_ar>0])
    print(cor, '	', pvalue)
    ols = OLS(endog=cell_ar_map[mask_ar>0], exog=perf_map_blurred[perf_map][mask_ar>0])
    res = ols.fit()
    print(res.rsquared_adj, '	', res.pvalues[0])

In [None]:
# pwi1 with mask limit - cor sp
for perf_map in list:
    cor, pvalue = spearmanr(cell_ar_map[mask_ar>0], perf_map_blurred[perf_map][mask_ar>0])
    print(cor)

In [None]:
# pwi1 with mask limit - OLS
for perf_map in list:
    ols = OLS(endog=cell_ar_map[mask_ar>0], exog=perf_map_blurred[perf_map][mask_ar>0])
    res = ols.fit()
    print(res.rsquared_adj)

In [None]:
# pwi1 on cell density map
for perf_map in list:
    print (perf_map)
    cor, pvalue = spearmanr(cell_ar_map[cell_ar_map>0], perf_map_blurred[perf_map][cell_ar_map>0])
    print(cor, '	', pvalue)
    ols = OLS(endog=cell_ar_map[cell_ar_map>0], exog=perf_map_blurred[perf_map][cell_ar_map>0])
    res = ols.fit()
    print(res.rsquared_adj, '	', res.pvalues[0])

In [None]:
# pwi1 on cell density map - cor sp
for perf_map in list:
    cor, pvalue = spearmanr(cell_ar_map[cell_ar_map>0], perf_map_blurred[perf_map][cell_ar_map>0])
    print(cor)

In [None]:
# pwi1 on cell density map - OLS
for perf_map in list:
    ols = OLS(endog=cell_ar_map[cell_ar_map>0], exog=perf_map_blurred[perf_map][cell_ar_map>0])
    res = ols.fit()
    print(res.rsquared_adj)

In [None]:
# stat between pwi & pwi_cell
for perf_map in list:
    print (perf_map)
    print('mask:', perf_map_blurred[perf_map][mask_ar>0].reshape(-1,1).mean(), '| mask-cell', perf_map_blurred[perf_map][cell_ar_map>0].reshape(-1,1).mean())
    print(stats.mannwhitneyu(perf_map_blurred[perf_map][mask_ar>0].reshape(-1,1), perf_map_blurred[perf_map][cell_ar_map>0].reshape(-1,1)))
    plt.figure(figsize=(15,3))
    plt.hist(perf_map_blurred[perf_map][mask_ar>0], alpha=1, bins=33)
    plt.hist(perf_map_blurred[perf_map][cell_ar_map>0], alpha=0.5, bins=33)
    plt.show()