# Perturbation-based approach that uses the decoder to map ENDOs to different regions of the brain.



In [None]:
#imports 
#PyTorch
import torch
#mention the gpu number
device = torch.device("cuda:0")

#custom modules
from model128 import engine_AE

#Python
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.colors import colorConverter
from matplotlib import cm
import matplotlib
import numpy as np
from tqdm import tqdm
import pickle
from scipy.stats import ttest_ind
from scipy.ndimage import gaussian_filter
#imaging
import nibabel as nb

#Hyperparameters
# 16-entry "autumn" colormap. `cm.get_cmap` was deprecated in Matplotlib 3.7
# and removed in 3.9, so prefer the colormap registry; older releases that
# lack `matplotlib.colormaps` or `Colormap.resampled` raise AttributeError
# and fall back to the legacy call, which yields the same colormap.
try:
    autumn = matplotlib.colormaps["autumn"].resampled(16)
except AttributeError:
    autumn = cm.get_cmap("autumn", 16)
# Fully transparent yellow, used to pad the low end of the overlay colormap.
color1 = colorConverter.to_rgba('yellow', alpha=0.0)
# Number of opaque colours sampled evenly from "autumn".
color_len = 12
colors = list(autumn(np.linspace(0, 1, color_len)))
# Overlay colormap: (16 - color_len) transparent entries followed by the
# sampled autumn colours.
cmap = matplotlib.colors.ListedColormap([color1] * (16 - color_len) + colors)

import nibabel as nib
from scipy.stats import ttest_rel

In [None]:
# T1 ENDOs: build the model and restore its weights from the T1 checkpoint.
T1_ckpt = "T1.ckpt"
# Use GPU 0 when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on a machine without a GPU instead of failing
# in `.to(device)` / `torch.load`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_T1 = engine_AE()
model_T1 = model_T1.to(device)
# map_location remaps the stored tensors onto whichever device was selected.
checkpoint = torch.load(T1_ckpt, map_location=device)
model_T1.load_state_dict(checkpoint["state_dict"])


In [None]:
# T1_128_ENDOs.csv holds the ENDOs derived from T1.
features_T1 = pd.read_csv("T1_128_ENDOs.csv")
# Keep the first 500 rows, skip the first three columns and cast the rest to
# single-precision floats.
features = features_T1.values[:500, 3:].astype(np.float32)

In [None]:
# Modality label and the ENDO indices (0..127) to analyse.
modality = "T1"
dimensions_of_interest = list(range(128))

In [None]:
# Generate the brain mask from the linearly registered MNI152 T1 1mm brain
# (a masked MSE was used for training).
# Load the image once and reuse it for both the voxel data and the affine
# instead of reading the same file from disk twice.
mni_img = nb.load('MNI152lin_T1_1mm_brain.nii.gz')
b = mni_img.get_fdata()  # linearly registered T1 MNI152 volume
# Binary float32 mask: 1 where the template is non-zero, 0 elsewhere.
mask = (b != 0).astype('f')
aff = mni_img.affine

In [None]:
#just using the decoder aspect of the model.
def decode(model, lin1):
    """Run only the decoder half of ``model`` on a batch of latent vectors.

    The model is switched to eval mode and moved to the global ``device``;
    the whole pass runs with gradients disabled.

    Args:
        model: Network exposing ``decoding_mlp``, the four
            ``<ordinal>_decoder`` / ``<ordinal>_transconv`` stages and
            ``last_cnn``.
        lin1: Batch of latent vectors; moved to ``device`` before decoding.

    Returns:
        The output of ``model.last_cnn`` for the decoded batch.
    """
    model.eval()
    with torch.no_grad():
        model = model.to(device)
        hidden = model.decoding_mlp(lin1.to(device))
        # Reshape the MLP output to (batch, 256, 12, 14, 12) for the
        # decoder stages.
        hidden = hidden.view([hidden.shape[0], 256, 12, 14, 12])
        # Each stage is a decoder block followed by its transposed conv.
        for stage in ("first", "second", "third", "fourth"):
            hidden = getattr(model, f"{stage}_decoder")(hidden)
            hidden = getattr(model, f"{stage}_transconv")(hidden)
        return model.last_cnn(hidden)
  
def crop(im, w, h):
    """Centre-crop a 2-D array to ``w`` rows by ``h`` columns.

    When the surplus along an axis is odd, the extra element is removed
    from the end of that axis.

    Args:
        im: Two-dimensional array-like exposing ``shape`` and slicing.
        w: Target size along the first axis.
        h: Target size along the second axis.

    Returns:
        The cropped view ``im[r0:r1, c0:c1]``.
    """

    def _span(size, target):
        # Slice that removes `size - target` elements, split front/back.
        surplus = size - target
        lead = surplus // 2
        stop = size if surplus == 0 else lead - surplus
        return slice(lead, stop)

    n_rows, n_cols = im.shape
    return im[_span(n_rows, w), _span(n_cols, h)]

In [None]:
def save_for_viz(attribution, dimension, affine_matrix=aff):
    """Write ``attribution`` to ``{dimension}.nii.gz`` as a NIfTI image.

    Args:
        attribution: Array holding the image data to store.
        dimension: Name used as the output file stem.
        affine_matrix: Affine stored with the image; defaults to the
            module-level ``aff`` as bound when this function was defined.
    """
    nb.save(nb.Nifti1Image(attribution, affine_matrix), f"{dimension}.nii.gz")

In [None]:
def xai(dim, n_samples=None, batch_size=10):
    """Save a smoothed |t|-map of the decoder's response to perturbing one ENDO.

    Each selected row of ``features`` is decoded twice with ``model_T1``:
    once unchanged and once with column ``dim`` increased by one standard
    deviation of that column. The two sets of reconstructions are compared
    voxel-wise with a paired t-test; the absolute t-values are restricted to
    ``mask``, Gaussian-smoothed (sigma=3) and written to
    ``paired_ttest_T1_{dim}.nii.gz`` through ``save_for_viz``.

    Args:
        dim: Column index of the ENDO to perturb.
        n_samples: Number of leading rows of ``features`` to use. ``None``
            (default) uses every row of ``features``.
        batch_size: Number of samples decoded per forward pass.

    Raises:
        ValueError: If fewer than two samples are available (the paired
            t-test is undefined) or ``batch_size`` is not positive.
    """
    n_total = len(features) if n_samples is None else min(n_samples, len(features))
    if n_total < 2:
        raise ValueError("paired t-test needs at least two samples")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # One (population, NaN-skipping) standard deviation of the chosen ENDO.
    # It does not depend on the batch, so compute it once up front.
    sd = pd.Series(features[:n_total, dim]).std(ddof=0)

    original_sd = []
    perturb_sd = []
    for start in tqdm(range(0, n_total, batch_size)):
        stop = min(start + batch_size, n_total)
        batch = torch.from_numpy(features[start:stop].astype('f', copy=False))
        batch = batch.to(device)
        # Perturb a copy so the unperturbed batch (and `features`, which it
        # may share memory with) is never modified.
        shifted = batch.clone()
        shifted[:, dim] += sd
        # decode() already runs in eval mode under torch.no_grad().
        recon_orig = decode(model_T1, batch).detach().cpu().numpy()
        recon_pert = decode(model_T1, shifted).detach().cpu().numpy()
        original_sd.extend(np.squeeze(vol) for vol in recon_orig)
        perturb_sd.extend(np.squeeze(vol) for vol in recon_pert)

    # Paired t-test across samples, voxel by voxel.
    t_sd = ttest_rel(original_sd, perturb_sd, axis=0, nan_policy="omit")
    # Absolute value of the t-map.
    t_abs = np.abs(np.asarray(t_sd[0]))
    # Apply the training mask (masked MSE). Multiplying by the mask would
    # keep NaN/inf (0 * NaN is NaN), e.g. at voxels whose paired differences
    # have zero variance, and the Gaussian filter would then smear them into
    # neighbouring voxels. Zero every masked-out or non-finite voxel instead.
    keep = (mask != 0) & np.isfinite(t_abs)
    t_masked_sd = np.where(keep, t_abs, 0.0)
    # Gaussian smoothing.
    t_masked_sd = gaussian_filter(t_masked_sd, sigma=3)
    # Save the t-map as a NIfTI image.
    save_for_viz(t_masked_sd, f"paired_ttest_T1_{dim}")

In [None]:
# Produce one t-map per configured ENDO. Iterate the `dimensions_of_interest`
# list rather than a second hard-coded range so that editing the list is
# enough to change which dimensions are analysed.
for i in tqdm(dimensions_of_interest):
    xai(i)