In [None]:
import numpy as np 
import pandas as pd 
import os
import warnings
import torch
from torchvision import datasets, transforms, models
from torchvision.models import resnet50
from torch.autograd import Variable
from matplotlib import pyplot as plt
from scipy.spatial import cKDTree
from sklearn.exceptions import ConvergenceWarning
import sklearn.linear_model as slm
import sklearn.feature_selection as skf
import sklearn.preprocessing as skp
import util
from adapt.feature_based import CORAL

In [None]:
# Get case IDs
# The case list is read once through a context manager so the handle is closed
# as soon as parsing finishes.
with open('/home/ali/RadDBS-QSM/data/docs/cases_90','r') as case_list:
    lists = np.loadtxt(case_list,comments="#", delimiter=",",unpack=False,dtype=str)
# Characters [-9:-7] of each entry are taken as the case ID
# (assumes every entry ends in a fixed-width suffix -- TODO confirm against the file).
case_id = [entry[-9:-7] for entry in lists]

# Load scores
file_dir = '/home/ali/RadDBS-QSM/data/docs/QSM anonymus- 6.22.2023-1528.csv'
motor_df = util.filter_scores(file_dir,'pre-dbs updrs','stim','CORNELL ID')
# Find cases with all required scores
subs_init,pre_imp_init,post_imp_init,pre_updrs_off_init = util.get_full_cases(motor_df,
                                                          'CORNELL ID',
                                                          'OFF (pre-dbs updrs)',
                                                          'ON (pre-dbs updrs)',
                                                          'OFF meds ON stim 6mo')
# Load extracted features
npy_dir = '/home/ali/RadDBS-QSM/data/npy/'
phi_dir = '/home/ali/RadDBS-QSM/data/phi/phi/'
roi_path = '/data/Ali/atlas/mcgill_pd_atlas/PD25-subcortical-labels.csv'
n_rois = 6
Phi_all, X_all, R_all, K_all, ID_all = util.load_featstruct(phi_dir,npy_dir+'X/',npy_dir+'R/',npy_dir+'K/',n_rois,1595,False)
# Only the subject IDs are used below; drop the other returned structures.
del Phi_all, X_all, R_all, K_all
ids = np.asarray(ID_all).astype(int)

# Find overlap between scored subjects and feature extraction cases
c_cases = np.intersect1d(np.asarray(case_id).astype(int),np.asarray(subs_init).astype(int))
# Complete case indices with respect to feature matrix
c_cases_idx = np.in1d(ids,c_cases)
# Re-index the scored subjects with respect to complete cases
s_cases_idx = np.in1d(subs_init,ids[c_cases_idx])
subs_init = subs_init[s_cases_idx]
pre_imp_init = pre_imp_init[s_cases_idx]
post_imp_init = post_imp_init[s_cases_idx]
pre_updrs_off_init = pre_updrs_off_init[s_cases_idx]
# NOTE(review): the "change" target is simply an alias of post_imp_init here.
per_change_init = post_imp_init
# Subject IDs in the order they appear in ID_all, restricted to scored subjects.
ids_float = np.asarray(ID_all,dtype=float)
subs = ids_float[np.in1d(ids_float,subs_init)]

# Re-order every score vector to follow `subs`; each result is a column vector.
n_subs = len(subs)
pre_imp = np.zeros((n_subs,1))
post_imp = np.zeros((n_subs,1))
pre_updrs_off = np.zeros((n_subs,1))
per_change = np.zeros((n_subs,1))
for j in np.arange(n_subs):
    # Boolean mask selecting this subject's entry in the score arrays
    # (assumes each subject appears exactly once in subs_init -- TODO confirm).
    row_mask = subs_init == subs[j]
    pre_imp[j] = pre_imp_init[row_mask]
    post_imp[j] = post_imp_init[row_mask]
    pre_updrs_off[j] = pre_updrs_off_init[row_mask]
    per_change[j] = per_change_init[row_mask]

subsc = subs_init
X_img = []
results_bls = np.zeros_like(per_change)

In [None]:
# Load ResNet-50 with torchvision's default pretrained weights.
pt_model = resnet50(weights='DEFAULT')
# Freeze every parameter so no gradients are tracked for the backbone.
# The attribute is `requires_grad`; the previous `required_grad` spelling only
# attached an unused attribute to each parameter and froze nothing.
for param in pt_model.parameters():
    param.requires_grad = False

In [None]:
# Arguments for util.slice_pick, gathered in one place so each option is easy to
# read and tweak. Key order matches the original call.
slice_pick_kwargs = dict(
    subsc=subs_init,
    per_change=per_change,
    pre_metric=pre_updrs_off,
    pre_comp=pre_imp,
    pshape=(100,100,64),
    roi_l=0,
    roi_u=4,
    mask_crop_output=True,
    mask_output=True,
    o_index=False,
    file_path='/home/ali/RadDBS-QSM/data/pt/X_img.pt',
    qsm_path='/home/ali/RadDBS-QSM/data/nii/qsm',
    seg_prefix='/home/ali/RadDBS-QSM/data/nii/seg/labels_2iMag0',
    save_image=False,
    img_directory='/home/ali/RadDBS-QSM/data/tif/',
    visualize=True,
    reload=True,
)
# NOTE: this rebinds subsc, per_change and pre_imp, so re-running the cell feeds
# the already-processed per_change / pre_imp back into slice_pick.
X_img, subsc, per_change, pre_metric, pre_imp = util.slice_pick(**slice_pick_kwargs)

In [None]:
# Collect one latent representation per image from the pretrained model.
Z = []
for image in X_img:
    latent = util.get_latent_rep(image,pt_model).detach().numpy()
    Z.append(latent)
    # Quick visual check; the reshape requires exactly 1000 values per representation.
    plt.imshow(latent.reshape((10,100)))

In [None]:
# Leave-one-subject-out evaluation: for each subject, scale the latent features,
# select features, fit a (CORAL-adapted) LassoLarsCV on the remaining subjects and
# predict the held-out case.
n_cases = len(per_change)
r = np.zeros(n_cases)          # prediction for the training point nearest to the test case
err_var = np.zeros(n_cases)    # mean absolute training error of the fitted model
rerror = np.zeros(n_cases)     # absolute error of the nearest-neighbour reconstruction
kappa = []                     # condition number of the scaled training matrix per fold
results_ls = np.zeros(n_cases)
Z = np.asarray(Z)
# Cross validation folds used by both LassoLarsCV and RFECV
cvn = 6
for j in np.arange(len(subsc)):
    test_id = subsc[j]
    test_index = subsc == test_id
    train_index = subsc != test_id
    X_train = Z[train_index]
    X_test = Z[test_index]
    y_train = per_change[train_index].ravel()
    y_test = per_change[test_index]

    X0_ss0,scaler_ss,X_test_ss0 = util.model_scale(skp.StandardScaler(),
                                                X_train,train_index,X_test,test_index,pre_metric,False,False,False)
    # NOTE(review): `normalize` appears to have been deprecated and later removed from
    # scikit-learn's LassoLarsCV; confirm the installed version still accepts it.
    lasso = slm.LassoLarsCV(max_iter=1000,cv=cvn,normalize=False,eps=0.1,n_jobs=1)

    # `a and b` in a with-statement only enters `b`, so the two context managers are
    # entered explicitly. catch_warnings() now restores the global warning filters
    # after each fold; it wraps the model fit as well so ConvergenceWarnings stay
    # silenced there, as they were when the filter leaked globally.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        with np.errstate(divide='ignore', invalid='ignore'):
            # Feature selection
            sel = skf.SelectKBest(skf.r_regression,k=100)
            X0_sst = sel.fit_transform(X0_ss0,y_train)
            X_test_sst = sel.transform(X_test_ss0)
            gel = skf.RFECV(lasso,verbose=0,cv=cvn,step=1,n_jobs=1)
            X0_ss = gel.fit_transform(X0_sst,y_train)
            kappa.append(np.linalg.cond(X0_ss0))
            X_test_ss = gel.transform(X_test_sst)
            # Nearest training neighbour of the test case; kept as the fallback
            # value of y_n if the CORAL fit below fails.
            dx, y_n = cKDTree(X0_ss).query(X_test_ss, k=1)

        # LASSO
        try:
            lassoc = CORAL(lasso, Xt=X_test_ss, random_state=0)
            est_ls = lassoc.fit(X0_ss,y_train)
            dx, y_n = cKDTree(X0_ss).query(lassoc.transform(X_test_ss), k=1)
        except Exception as err:
            # Catch Exception rather than everything, so Ctrl-C / kernel interrupt
            # still stops the loop; fall back to the plain LASSO fit.
            print('CORAL failed on test case with covariance:',np.cov(X_test_ss0),
                  'error:',repr(err))
            est_ls = lasso.fit(X0_ss,y_train)

    # Reconstruct nearest neighbor
    # np.squeeze turns the size-1 results into 0-d values before they are stored in a
    # scalar slot (implicit size-1 array -> scalar conversion is deprecated in NumPy).
    r[j] = np.squeeze(est_ls.predict(X0_ss[y_n,:]))
    err_var[j] = np.mean(abs(est_ls.predict(X0_ss)-y_train))
    rerror[j] = np.abs(r[j]-np.squeeze(y_train[y_n]))
    results_ls[j] = np.squeeze(est_ls.predict(X_test_ss))
    print('Lasso predicts',str(np.round(results_ls[j],2)),
              'for case with',str(np.round(per_change[j],2)))

In [None]:
# Stack the pre_imp values (labelled 'LCT') and the LASSO predictions row-wise and
# hand them to util.eval_prediction together with the observed per_change values.
stacked_predictions = np.vstack((np.squeeze(pre_imp), results_ls))
observed = np.squeeze(per_change)
method_labels = ['LCT', 'Lasso']
util.eval_prediction(stacked_predictions, observed, method_labels, (20,10))
# Use the same fixed range on both axes.
plt.ylim([0,1.5])
plt.xlim([0,1.5])
plt.style.use('default')