In [None]:
import sys
sys.path.append('../')
import util
import numpy as np
import nibabel as nib
from qsm_feats import MLP, train_model
import torch
from torch import nn
import torch.nn.functional as F
from IPython.display import HTML
import matplotlib.pyplot as plt
import gc
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import RawScoresOutputTarget, BinaryClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from util import pyvis
import scipy
%matplotlib widget

In [None]:
# Dark theme for the ipympl (%matplotlib widget) canvas, its header/labels and buttons.
# The CSS lives in a named constant; the HTML object is the cell's displayed result.
WIDGET_DARK_CSS = '''
<style>
.jupyter-matplotlib {
    background-color: #000;
}

.widget-label, .jupyter-matplotlib-header{
    color: #fff;
}

.jupyter-button {
    background-color: #333;
    color: #fff;
}
</style>
'''
HTML(WIDGET_DARK_CSS)


In [None]:
# TODO: check that model.eval() is called before running inference.
# TODO: pass masked volumes through instead, to reduce the dataset's GPU memory footprint?

In [None]:
# Load the cropped susceptibility (QSM) volumes, their cropped segmentation masks and the
# subject IDs. Set reload = 1 to rebuild the ../txt caches from the NIfTI files; otherwise
# the cached .npy arrays are loaded.
reload = 0
QSM_DIR = '/home/ali/RadDBS-QSM/data/nii/qsm'
SEG_DIR = '/home/ali/RadDBS-QSM/data/nii/seg'
PAD_SHAPE = (152, 152, 105)  # common grid every cropped volume is padded to
qsms = util.full_path(QSM_DIR)
chi = []
masks = []
im_subs = []
if reload == 1:
    for qsm_path in qsms:
        data = nib.load(qsm_path)
        # Subject ID: presumably the two characters before the 7-char '.nii.gz' suffix.
        qsm_sub = qsm_path[-9:-7]
        try:
            mask = nib.load(SEG_DIR + '/labels_2iMag' + qsm_sub + '.nii.gz').get_fdata()
            # `mask` is already an ndarray (no .get_fdata()), so crop it directly, and pad
            # it like the image so all masks share one shape and can be stacked.
            mask_cube = util.pad_to(util.mask_crop(mask, mask), *PAD_SHAPE)
            img = util.pad_to(util.mask_crop(data.get_fdata(), mask), *PAD_SHAPE)
        except Exception as exc:
            # Best effort: subjects without a usable segmentation are skipped, but say why
            # instead of silently hiding programming errors.
            print('Skipping', qsm_path, '-', exc)
            continue
        chi.append(img)
        masks.append(mask_cube)
        im_subs.append(qsm_sub)
    # Row k of chi / cube_mask / im_subs all refer to the same subject.
    chi = np.asarray(chi)
    cube_mask = np.asarray(masks)
    im_subs = np.asarray(im_subs)
    np.save('../txt/cube_masks.npy', cube_mask)
    np.save('../txt/chi.npy', chi)
    np.save('../txt/im_subs.npy', im_subs)
else:
    print('Using whole susceptibility')
    chi = np.load('../txt/chi.npy')
    im_subs = np.load('../txt/im_subs.npy')
    # The mask cache only exists once a reload run has completed successfully.
    try:
        cube_mask = np.load('../txt/cube_masks.npy')
    except FileNotFoundError:
        cube_mask = None

In [None]:
# Side-by-side preview of the first subject's volume and its mask. np.hstack takes ONE
# sequence of arrays; passing them as two positional arguments raises TypeError.
# Assumes cube_mask stacks per-subject masks in the same order as chi — TODO confirm.
util.pyvis(np.hstack((chi[0].T, cube_mask[0].T)), 5, 5)

In [None]:
# Get case IDs
case_list = open('/home/ali/RadDBS-QSM/data/docs/cases_90','r')
lines = case_list.read()
lists = np.loadtxt(case_list.name,comments="#", delimiter=",",unpack=False,dtype=str)
case_id = []
for lines in lists:     
    case_id.append(lines[-9:-7])

# Load scores
file_dir = '/home/ali/RadDBS-QSM/data/docs/QSM anonymus- 6.22.2023-1528.csv'
motor_df = util.filter_scores(file_dir,'pre-dbs updrs','stim','CORNELL ID')
# Find cases with all required scores
subs,pre_imp,post_imp,pre_updrs_off = util.get_full_cases(motor_df,
                                                          'CORNELL ID',
                                                          'OFF (pre-dbs updrs)',
                                                          'ON (pre-dbs updrs)',
                                                          'OFF meds ON stim 6mo')
ID_all = im_subs
ids = np.asarray(ID_all).astype(int)
# Find overlap between scored subjects and feature extraction cases
c_cases = np.intersect1d(np.asarray(case_id).astype(int),np.asarray(subs).astype(int))
# Complete case indices with respect to feature matrix
c_cases_idx = np.in1d(ids,c_cases)
# Re-index the scored subjects with respect to complete cases
s_cases_idx = np.in1d(subs,ids[c_cases_idx])
subsc = subs[s_cases_idx]
# Re-index the scored subjects with respect to complete cases
s_cases_idx = np.in1d(subs,ids[c_cases_idx])
subsc = subs[s_cases_idx]
pre_imp = pre_imp[s_cases_idx]
post_imp = post_imp[s_cases_idx]
pre_updrs_off = pre_updrs_off[s_cases_idx]
per_change = post_imp
X_all_c = np.asarray(chi)[c_cases_idx,:,:,:]

In [None]:
# Leave-one-out training configuration. Seed NumPy and torch (torch.manual_seed also seeds
# CUDA) so the random validation split and weight initialisation are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
num_epochs = 5
num_neighbors = 3
# Subjects minus the held-out one minus num_neighbors (presumably reserved by train_model
# for validation — TODO confirm).
batch_size = X_all_c.shape[0] - num_neighbors - 1
assert batch_size > 0, f'need more than {num_neighbors + 1} subjects, got {X_all_c.shape[0]}'
print(batch_size)
# Float buffer for predictions: zeros_like(per_change) alone would inherit an integer dtype
# from integer scores and silently truncate every prediction written into it.
results = np.zeros_like(per_change, dtype=float)

In [None]:
# Run a full garbage-collection pass before training; the count of unreachable
# objects found is shown as the cell output.
unreachable_found = gc.collect()
unreachable_found

In [None]:
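# NOTE: disabled Grad-CAM sketch adapted from the pytorch-grad-cam usage example. It cannot
# run as-is: `X_test_ss` and `rgb_img` are never defined in this notebook, and `encoder`
# only exists after the leave-one-out training cell below has run.
# target_layers = [encoder.layers[-1]]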
# input_tensor = X_test_ss # Create an input tensor image for your model..
# # Note: input_tensor can be a batch tensor with several images!

# # Construct the CAM object once, and then re-use it on many images:
# cam = GradCAM(model=encoder, target_layers=target_layers)

# # You can also use it within a with statement, to make sure it is freed,
# # In case you need to re-create it inside an outer loop:
# # with GradCAM(model=model, target_layers=target_layers) as cam:
# #   ...

# # We have to specify the target we want to generate
# # the Class Activation Maps for.
# # If targets is None, the highest scoring category
# # will be used for every image in the batch.
# # Here we use RawScoresOutputTarget (the raw model output), but you can define your own custom targets
# # That are, for example, combinations of categories, or specific outputs in a non standard model.

# targets = [RawScoresOutputTarget()]

# # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
# grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

# # In this example grayscale_cam has only one image in the batch:
# grayscale_cam = grayscale_cam[0, :]
# visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

# # You can also get the model outputs without having to re-inference
# model_outputs = cam.outputs

In [None]:
# Leave-one-subject-out CV: train a fresh CNN on every other subject and predict the
# held-out one; the prediction for subject subsc[j] is stored in results[j].
for j, test_id in enumerate(subsc):
    test_index = subsc == test_id
    train_index = subsc != test_id
    X_train = X_all_c[train_index]
    X_test = X_all_c[test_index]
    y_train = per_change[train_index]
    y_test = per_change[test_index]
    # MLP (add early stopping?)
    encoder = MLP(in_size=(batch_size, 152, 152, 105),
                  kernel_size=(30, 30, 20),
                  cnn_layers=5,
                  n_channels=2,
                  fc_layers=1)

    yt, encoder, X_trained, y_trained, X_val, y_val, train_curve, val_curve = train_model(X_all=X_train,
                     y_all=y_train,
                     model=encoder,
                     X_test=X_test,
                     solver='adam',
                     lr=1e-1,
                     lr_decay=None,
                     alpha=1e0,
                     reg_type='latent_dist',
                     num_epochs=num_epochs,
                     batch_size=batch_size,
                     case_id=str(int(test_id)),
                     num_neighbors=num_neighbors,
                     random_val=True,
                     early_stopping=False,
                     verbose=True,
                     save_state=False)
    # target_layers = [encoder.layers[:]]
    # input_tensor = torch.unsqueeze(torch.Tensor(X_test),axis=0).cuda()# Create an input tensor image for your model..
    # # Construct the CAM object once, and then re-use it on many images:
    # cam = GradCAM(model=encoder, target_layers=target_layers)
    # targets = [BinaryClassifierOutputTarget(0)]

    # # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
    # grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    # print(grayscale_cam.shape)
    # # In this example grayscale_cam has only one image in the batch:
    # grayscale_cam_in = grayscale_cam[0,:,0,:]
    # visualization = show_cam_on_image(X_test[:,:,:,0], grayscale_cam_in, use_rgb=False)

    # # You can also get the model outputs without having to re-inference
    # model_outputs = cam.outputs

    # .item() turns the single-element prediction into a Python float regardless of device
    # or autograd state; assigning the raw tensor into a NumPy array goes through a NumPy
    # conversion that can fail for CUDA tensors or tensors that require grad.
    y_pred = yt.item()
    results[j] = y_pred
    print('Predicted', str(np.round(y_pred, 2)), 'for', str(np.round(per_change[j], 2)))

In [None]:
# pyvis(np.squeeze(X_test.T),10,10)
# pyvis(np.squeeze(grayscale_cam.T),10,10)  

In [None]:
%matplotlib inline
util.eval_prediction(np.vstack((pre_imp,
                               results,
                               )),
                               per_change,
                               ['LCT',
                                'CNN regressor',
                                ],(30,5))
plt.ylim([0,2])
plt.xlim([0,2])
plt.style.use('default')
plt.show()
