In [1]:
# Using newly preprocessed subjects

import pickle

metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [2]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np
from natsort import natsorted

def get_subs(allts, metadict, tasks, paras):
    # Get subs for all paras
    for i,para in enumerate(paras):
        tmpset = set([int(sub[4:]) for sub in allts[para].keys()])
        if i == 0:
            paraset = tmpset
        else:
            paraset = paraset.intersection(tmpset)
    # Get subs for all tasks
    for i,task in enumerate(tasks):
        tmpset = set([sub for sub in metadict[task].keys()])
        if i == 0:
            taskset = tmpset
        else:
            taskset = paraset.intersection(tmpset)
    # Remove QC failures
    allsubs = taskset.intersection(paraset)
    for badsub in metadict['failedqc']:
        try:
            allsubs.remove(int(badsub[4:]))
        except:
            pass
    return natsorted(list(allsubs))

def get_X(allts, paras, subs):
    X = []
    for para in paras:
        pX = [allts[para][f'sub-{sub}'] for sub in subs]
        pX = np.stack(pX)
        X.append(pX)
    return X

def get_y(metadict, tasks, subs):
    y = []
    for task in tasks:
        if task == 'age' or task == 'wrat':
            var = [metadict[task][sub] for sub in subs]
            var = np.array(var)
            y.append(var)
        if task == 'sex':
            maleness = [metadict[task][sub] == 'M' for sub in subs]
            maleness = np.array(maleness)
            sex = np.stack([maleness, 1-maleness], axis=1)
            y.append(sex)
    return y

subs = get_subs(allts, metadict, ['wrat'], ['rest', 'nback', 'emoid'])
print(len(subs))

X = get_X(allts, ['rest', 'nback', 'emoid'], subs)
print(X[0].shape)

830
(830, 264, 124)


In [3]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = [cutoff[0] / nyq, cutoff[1] / nyq]
    b, a = signal.butter(order, normal_cutoff, btype='band', analog=False)
    return b, a

def butter_bandpass_filter(data, cutoff, fs, order=5):
    b, a = butter_bandpass(cutoff, fs, order=order)
    y = signal.filtfilt(b, a, data)
    return y

tr = 3

def filter_design_ts(X):
    Xs = []
    for i in range(X.shape[0]):
        nX = butter_bandpass_filter(X[i], [0.01, 0.15], 1/tr)
        Xs.append(nX)
    return np.stack(Xs)

def ts_to_flat_fc(X):
    p = np.corrcoef(X)
    a,b = np.triu_indices(p[0].shape[0], 1)
    p = p[a,b]
    return p

p = [np.stack([ts_to_flat_fc(ts) for ts in filter_design_ts(Xp)]) for Xp in X]
# Xfiltnorm = [tsmod/np.linalg.norm(tsmod, axis=(-1), keepdims=True) for tsmod in ts]
print(p[0].shape)

(830, 34716)


In [4]:
# Get all partial correlations

from nilearn.connectome import ConnectivityMeasure

cm = ConnectivityMeasure(kind='partial correlation')

a,b = np.triu_indices(264, 1)

allp = []
for taskidx in range(3):
    Xf = filter_design_ts(X[taskidx])
    partials = cm.fit_transform(Xf.transpose(0,2,1))
    partials = partials[:,a,b]
    allp.append(partials)
    print('Done partials')

allp = np.stack(allp)
print(allp.shape)

Done partials
Done partials
Done partials
(3, 830, 34716)


In [67]:
from sklearn.base import BaseEstimator

import sys

if '../../LatentSimilarity' not in sys.path:
    sys.path.append('../../LatentSimilarity/')

from latsim import LatSim, train_sim_mse, train_sim_ce

import torch
import torch.nn.functional as F

def to_torch(x):
    if not isinstance(x, torch.Tensor):
        return torch.from_numpy(x).float().cuda()
    else:
        return x

'''
One class for regression, one (sub)class for classification
'''
class LatSimReg(BaseEstimator):
    def __init__(self, **params):
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        return dict(ld=2, stop=1, lr=1e-4, nepochs=100)

    @staticmethod
    def get_default_distributions():
        return dict(
            ld=[1,2,10],
            stop=[0,1,10*10,100*100],
            lr=[1e-5,1e-4,1e-3],
            nepochs=[100,1000,2000],
        )

    def get_params(self, deep=False):
        return dict(ld=self.ld, stop=self.stop, lr=self.lr, nepochs=self.nepochs)

    def set_params(self, **params):
        dft = LatSimReg.get_default_params()
        for key in dft:
            if key in params:
                setattr(self, key, params[key])
            else:
                setattr(self, key, dft[key])
        return self

    def fit(self, x, y, **kwargs):
        x = to_torch(x)
        y = to_torch(y)
        self.x = 1*x
        self.y = 1*y
        params = LatSimReg.get_default_params()
        for arg in kwargs:
            if arg in params:
                params[arg] = kwargs[arg]
        self.sim = LatSim(x.shape[1], params['ld'])
        del params['ld']
        train_sim_mse(self.sim, self.x, self.y, **params)
        return self

    def predict(self, x):
        x = torch.from_numpy(x).float().cuda()
        with torch.no_grad():
            yhat = self.sim(self.x, self.y, x)
        return yhat.detach().cpu().numpy()
    
class LatSimClf(LatSimReg):
    @staticmethod
    def get_default_params():
        return dict(ld=2, stop=1, lr=1e-4, nepochs=100)

    @staticmethod
    def get_default_distributions():
        return dict(
            ld=[1,2,10],
            stop=[0,0.1,0.2,0.3],
            lr=[1e-5,1e-4,1e-3],
            nepochs=[100,1000,2000],
        )
    
    def set_params(self, **params):
        dft = LatSimClf.get_default_params()
        for key in dft:
            if key in params:
                setattr(self, key, params[key])
            else:
                setattr(self, key, dft[key])
        return self
    
    def fit(self, x, y, **kwargs):
        y = to_torch(y).long()
        y = F.one_hot(y).float()
        return super().fit(x, y, **kwargs)

    def predict(self, x):
        yhat = super().predict(x)
        return np.argmax(yhat, axis=1)

print('Done')

Done


In [80]:
ar = []

for i in range(20):
    idcs = np.random.permutation(830)
    ntrain = 500
    task ='sex'

    x = p[2]
    x = x[idcs]
    xtr = x[:ntrain]
    xt = x[ntrain:]

    mux = np.mean(xtr, axis=0, keepdims=True)
    sigx = np.std(xtr, axis=0, keepdims=True)
    xtr = xtr - mux
    xt = xt - mux

    y = get_y(metadict, [task], subs)[0]
    y = y[idcs]
    y = np.argmax(y, axis=1)
    ytr = y[:ntrain]
    yt = y[ntrain:]

#     mu = np.mean(ytr)
#     ytr = ytr - mu
#     yt = yt - mu

    reg = LatSimClf().fit(xtr, ytr, ld=2, nepochs=1000, lr=1e-4, stop=0)
    yhat = reg.predict(xt)
    acc = np.sum(yhat == yt)/len(yt)
    acc = float(acc)
    print(acc)
    ar.append(acc)
    # print(np.sum(yhat == yt)/len(yt))
#     rmse = np.mean((yhat-yt)**2)**0.5
#     rmse = float(rmse)
#     print(rmse)
#     ar.append(rmse)

np.mean(np.array(ar))

0.7696969696969697
0.7787878787878788
0.7727272727272727
0.7303030303030303
0.7454545454545455
0.7636363636363637
0.7424242424242424
0.7424242424242424
0.7393939393939394
0.7484848484848485
0.7636363636363637
0.706060606060606
0.7393939393939394
0.7727272727272727
0.7515151515151515
0.7727272727272727
0.7090909090909091
0.7666666666666667
0.7666666666666667
0.7606060606060606


0.7521212121212122

In [73]:
# Grid search for WRAT

from sklearn.model_selection import RandomizedSearchCV, GridSearchCV

x = p[2]
y = get_y(metadict, ['sex'], subs)[0]

x = x - np.mean(x, axis=0, keepdims=True)
y = np.argmax(y, axis=1)
# y = y - np.mean(y)

reg = LatSimClf()
params = LatSimClf.get_default_distributions()
n_iter = 100
reg = RandomizedSearchCV(reg, params, scoring='accuracy')
search = reg.fit(x, y)
search.best_params_

{'stop': 0.1, 'nepochs': 1000, 'lr': 0.0001, 'ld': 10}