In [None]:
import util
import numpy as np
import nibabel as nib
from qsmrad_feats import MLP, train_model
import torch
from torch import nn
import torch.nn.functional as F
from IPython.display import HTML
import matplotlib.pyplot as plt
import sklearn.preprocessing as skp
import sklearn.metrics as sme
import sklearn.feature_selection as skf
import gc

In [None]:
# TODO: ensure model.eval() disables all regularization, so that the model in eval mode computes exactly f(X, θ).
# TODO: add channels.

In [None]:
# Dark theme for the ipympl (jupyter-matplotlib) widgets: black canvas background,
# white header/label text and dark-grey buttons. Rendered as the cell's output.
widget_css = '''
<style>
.jupyter-matplotlib {
    background-color: #000;
}

.widget-label, .jupyter-matplotlib-header{
    color: #fff;
}

.jupyter-button {
    background-color: #333;
    color: #fff;
}
</style>
'''
HTML(widget_css)


In [None]:
# Build (reload == 1) or load from the .npy cache (reload == 0) the per-subject QSM volumes.
# Each volume is cropped with its segmentation via util.mask_crop and padded with
# util.pad_to(..., 152, 152, 105).
reload = 0
qsm_dir = '/home/ali/RadDBS-QSM/data/nii/qsm'
seg_dir = '/home/ali/RadDBS-QSM/data/nii/seg'
qsms = util.full_path(qsm_dir)
segs = util.full_path(seg_dir)
chi = []
im_subs = []
if reload == 1:
    for qsm_path in qsms:
        data = nib.load(qsm_path)
        # Subject ID is the two characters at [-9:-7] of the QSM filename
        qsm_sub = qsm_path[-9:-7]
        try:
            mask = nib.load(seg_dir + '/labels_2iMag' + qsm_sub + '.nii.gz').get_fdata()
            img = util.mask_crop(data.get_fdata(), mask)
            img = util.pad_to(img, 152, 152, 105)
        except Exception as err:
            # Best effort: report why a subject is skipped instead of silently swallowing
            # every error (a bare except would also swallow KeyboardInterrupt).
            print('Skipping', qsm_path, '-', err)
            continue
        chi.append(img)
        im_subs.append(qsm_sub)

    np.save('chi.npy', np.asarray(chi))
    np.save('im_subs.npy', np.asarray(im_subs))
else:
    chi = np.load('chi.npy')
    im_subs = np.load('im_subs.npy')

In [None]:
# Get case IDs: the two-character subject ID sits at [-9:-7] of each listed entry.
# np.loadtxt opens and closes the file itself, so no handle is leaked.
case_file = '/home/ali/RadDBS-QSM/data/docs/cases_90'
lists = np.atleast_1d(np.loadtxt(case_file, comments="#", delimiter=",", unpack=False, dtype=str))
case_id = [line[-9:-7] for line in lists]

# Load scores
file_dir = '/home/ali/RadDBS-QSM/data/docs/QSM anonymus- 6.22.2023-1528.csv'
motor_df = util.filter_scores(file_dir, 'pre-dbs updrs', 'stim', 'CORNELL ID')
# Find cases with all required scores
subs, pre_imp, post_imp, pre_updrs_off = util.get_full_cases(motor_df,
                                                             'CORNELL ID',
                                                             'OFF (pre-dbs updrs)',
                                                             'ON (pre-dbs updrs)',
                                                             'OFF meds ON stim 6mo')
ID_all = im_subs
ids = np.asarray(ID_all).astype(int)
# Subjects that are listed in the case file AND have complete scores
c_cases = np.intersect1d(np.asarray(case_id).astype(int), np.asarray(subs).astype(int))
# Complete case indices with respect to the image volumes
c_cases_idx = np.isin(ids, c_cases)
# Re-index the scored subjects with respect to complete cases (keeps score order)
s_cases_idx = np.isin(subs, ids[c_cases_idx])
subsc = subs[s_cases_idx]
pre_imp = pre_imp[s_cases_idx]
post_imp = post_imp[s_cases_idx]
pre_updrs_off = pre_updrs_off[s_cases_idx]
per_change = post_imp
# A boolean mask keeps the volumes in image-file order, which need not match the order
# of the scored subjects. Pick each subject's volume explicitly so that row k of X_all_c
# belongs to subsc[k].
img_rows = [int(np.flatnonzero(ids == int(sub))[0]) for sub in subsc]
X_all_c = np.asarray(chi)[img_rows, :, :, :]
print(X_all_c.shape)


In [None]:
npy_dir = '/home/ali/RadDBS-QSM/data/npy/'
phi_dir = '/home/ali/RadDBS-QSM/data/phi/phi/'
roi_path = '/data/Ali/atlas/mcgill_pd_atlas/PD25-subcortical-labels.csv'
n_rois = 6
all_rois = False
Phi_all, X_all, R_all, K_all, ID_all = util.load_featstruct(phi_dir, npy_dir+'X/', npy_dir+'R/', npy_dir+'K/', n_rois, 1595, all_rois)

idsr = np.asarray(ID_all).astype(int)
# Find overlap between scored subjects and feature extraction cases
c_cases = np.intersect1d(idsr, np.asarray(subsc).astype(int))
# Complete case indices with respect to the feature arrays
c_cases_idx = np.isin(idsr, c_cases)

# Re-index the scored subjects (and the QSM volumes aligned with them) to complete cases
s_cases_idx = np.isin(subsc, idsr[c_cases_idx])
subsc = subsc[s_cases_idx]
pre_imp = pre_imp[s_cases_idx]
post_imp = post_imp[s_cases_idx]
pre_updrs_off = pre_updrs_off[s_cases_idx]
per_change = post_imp
X_all_c = X_all_c[s_cases_idx, :, :, :]

# Select feature rows per scored subject so that row k matches subsc[k]. A boolean mask
# would keep the feature-file order, which need not match the score order.
# Only the first four ROIs (0:4) are used.
feat_rows = [int(np.flatnonzero(idsr == int(sub))[0]) for sub in subsc]
X_all_cc = X_all[feat_rows, 0:4, :]
K_all_c = K_all[feat_rows, 0:4, :]
R_all_c = R_all[feat_rows, 0:4, :]
print(R_all_c.shape)

# Flatten (subject, ROI, feature) to (subject, ROI*feature). Keys are taken from the first
# subject. NOTE(review): assumes every subject shares the same feature keys.
n_subj, n_roi, n_feat = K_all_c.shape
if all_rois:
    # One extra 'pre_updrs' key slot per ROI
    K_all_cu = np.empty((n_subj, n_roi, n_feat + 1), dtype=object)
    K_all_cu[:, :, :-1] = K_all_c
    K_all_cu[:, :, -1] = 'pre_updrs'
    K = K_all_cu.reshape((n_subj, n_roi * (n_feat + 1)))[0]
else:
    # A single trailing 'pre_updrs' key
    K = K_all_c.reshape((n_subj, n_roi * n_feat))[0]
    K = np.append(K, ['pre_updrs'], 0)
R = R_all_c.reshape((R_all_c.shape[0], R_all_c.shape[1] * R_all_c.shape[2]))


In [None]:
# Training configuration for the leave-one-out loop.
num_epochs = 10
# NOTE(review): the MLP input shape uses this value, but the train_model call in the
# loop passes batch_size=36 — confirm which is intended.
batch_size = 37
# Seed the numpy and torch RNGs so that repeated runs are reproducible
# (train_model is called with random_val=True).
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
# One prediction per subject. Force float so predictions are never truncated
# if per_change happens to be integer-typed.
results = np.zeros_like(per_change, dtype=float)

In [None]:
# Run a full garbage-collection pass before the memory-heavy training loop;
# the number of unreachable objects found is displayed as the cell output.
n_collected = gc.collect()
n_collected

In [None]:
# Leave-one-subject-out evaluation: each scored subject is held out once. The scaler,
# the feature selector and a freshly built network are fit on the remaining subjects.
for j in np.arange(len(subsc)):
    test_id = subsc[j]
    test_index = subsc == test_id
    train_index = subsc != test_id
    # QSM volumes (X_all_c) and radiomic feature arrays (X_all_cc) for this fold
    X_train = X_all_c[train_index,:,:]
    X_test = X_all_c[test_index,:,:]
    X_rad_train = X_all_cc[train_index,:,:]
    X_rad_test = X_all_cc[test_index,:,:]

    y_train = per_change[train_index]
    y_test = per_change[test_index]
    # NOTE(review): confirm util.model_scale fits the StandardScaler on the training rows only
    X0_ss0, scaler_ss, X_test_ss0 = util.model_scale(skp.StandardScaler(),
                                                     X_rad_train, train_index, X_rad_test, test_index,
                                                     pre_updrs_off, False, False)
    # Constant features make the Pearson correlation in r_regression divide by zero.
    # Suppress those numpy warnings while keeping the 10 highest-scoring features.
    with np.errstate(divide='ignore', invalid='ignore'):
        sel = skf.SelectKBest(skf.r_regression, k=10)
        X0_ss = sel.fit_transform(X0_ss0, y_train)
        X_test_ss = sel.transform(X_test_ss0)

    # Stack the image volumes with the selected radiomic features (util.nstack)
    X_train_ct = util.nstack(X_train, X0_ss, False)
    X_test_ct = util.nstack(X_test, X_test_ss, False)
    # Build a new network in every fold so no weights carry over between folds.
    # TODO: add early stopping?
    # NOTE(review): the MLP is sized with batch_size, while train_model below is passed
    # batch_size=36 — confirm which value is intended.
    encoder = MLP((batch_size,152,152,105),(30,30,10),10,3)
    yt, encoder, X_trained, y_trained, X_val, y_val, train_curve, val_curve = train_model(X_all=X_train_ct,
                     y_all=y_train,
                     model=encoder,
                     X_test=X_test_ct,
                     lr=1e-4,
                     lr_decay=None,
                     alpha=0,
                     reg_type=None,
                     num_epochs=num_epochs,
                     batch_size=36,
                     case_id=str(int(subsc[j])),
                     num_neighbors=8,
                     random_val=True,
                     early_stopping=False,
                     verbose=True,
                     save_state=False)
    # The prediction is a single-element torch tensor that may live on the GPU or carry
    # autograd history. Convert it to a Python float explicitly before storing it.
    pred = yt.detach().cpu().item()
    results[j] = pred
    print('Predicted', str(np.round(pred, 2)), 'for', str(np.round(per_change[j], 2)))

: 

In [None]:
%matplotlib inline
util.eval_prediction(np.vstack((pre_imp,
                               results,
                               )),
                               per_change,
                               ['LCT',
                                'CNN regressor',
                                ],(30,5))
plt.ylim([0,2])
plt.xlim([0,2])
plt.style.use('default')
plt.show()
