In [1]:
# Using newly preprocessed subjects

import pickle

# Locations of the two pickled inputs.
# NOTE(review): these are absolute, machine-specific paths; the cell only runs
# on a machine with this exact directory layout. Consider a configurable data dir.
metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load files from a trusted source.
# metadict is indexed below by phenotype name and by 'failedqc'.
with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

# allts is indexed below by scan paradigm name, then by 'sub-<id>' key.
with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
# Show the top-level keys of both dictionaries as a quick sanity check.
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [2]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    """Return the subject IDs that are usable for every paradigm and task.

    A subject is kept when it has an entry under every paradigm in ``paras``,
    an entry under every task in ``tasks``, and is not listed in
    ``metadict['failedqc']``.

    Args:
        allts: Mapping ``paradigm -> {subject key: time series}``. Subject
            keys are parsed with ``int(key[4:])`` (i.e. a 4-character prefix
            followed by the numeric ID, such as ``'sub-123'``).
        metadict: Mapping ``task -> {subject ID: value}`` plus a
            ``'failedqc'`` iterable of prefixed subject keys. The per-task
            keys are compared directly against the parsed integer IDs, so
            they are assumed to be integers -- TODO confirm.
        tasks: Iterable of task names to require in ``metadict``.
        paras: Iterable of paradigm names to require in ``allts``.

    Returns:
        set: Integer subject IDs satisfying all constraints. An empty
        ``tasks`` or ``paras`` simply contributes no constraint; if both are
        empty an empty set is returned.
    """
    # One candidate set per paradigm (IDs parsed out of the prefixed keys) ...
    candidates = [{int(sub[4:]) for sub in allts[para].keys()} for para in paras]
    # ... and one per task. Every set must contain the subject. The previous
    # implementation intersected with the paradigm set instead of the running
    # task set, so only the LAST task actually constrained the result.
    candidates += [set(metadict[task].keys()) for task in tasks]
    if not candidates:
        return set()
    allsubs = set.intersection(*candidates)

    # Remove QC failures. Missing IDs are fine (discard is a no-op); entries
    # that cannot be parsed as a prefixed ID are skipped, as before.
    for badsub in metadict['failedqc']:
        try:
            allsubs.discard(int(badsub[4:]))
        except (TypeError, ValueError):
            continue
    return allsubs

def get_X(allts, paras, subs):
    """Collect the time series of the given subjects, one array per paradigm.

    Args:
        allts: Mapping ``paradigm -> {'sub-<id>': array}``.
        paras: Iterable of paradigm names; the output list follows this order.
        subs: Iterable of subject IDs; each is looked up as ``f'sub-{id}'``
            and the first axis of every output array follows this order.

    Returns:
        list: One ``np.stack``-ed array per paradigm.
    """
    return [
        np.stack([allts[para][f'sub-{sub}'] for sub in subs])
        for para in paras
    ]

def get_y(metadict, tasks, subs):
    """Build one response array per requested task.

    Args:
        metadict: Mapping ``task -> {subject ID: value}``.
        tasks: Iterable of task names. ``'age'`` and ``'wrat'`` yield a 1-D
            array of the stored values; ``'sex'`` yields a two-column array
            whose first column flags ``'M'`` and whose second column is its
            complement. Any other name is skipped without an entry.
        subs: Iterable of subject IDs; rows follow this order.

    Returns:
        list: Arrays in the order of the recognised entries of ``tasks``.
    """
    value_tasks = ('age', 'wrat')
    targets = []
    for task in tasks:
        if task in value_tasks:
            targets.append(np.array([metadict[task][sub] for sub in subs]))
        elif task == 'sex':
            is_male = np.array([metadict[task][sub] == 'M' for sub in subs])
            # Column 0: male indicator, column 1: its complement.
            targets.append(np.stack([is_male, 1 - is_male], axis=1))
    return targets

# Subjects that have all three phenotypes and all three scan paradigms,
# with QC failures removed (see get_subs).
subs = get_subs(allts, metadict, ['age', 'sex', 'wrat'], ['rest', 'nback', 'emoid'])
print(len(subs))

# X is a list with one stacked array per paradigm, in the order
# rest, nback, emoid; the first axis indexes subjects.
X = get_X(allts, ['rest', 'nback', 'emoid'], subs)
print(X[0].shape)

830
(830, 264, 124)


In [3]:
# Time series (TS) -> band-pass filter -> condensed (upper-triangle) functional connectivity (FC) vector

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    Args:
        cutoff: Sequence whose first two entries are the low and high band
            edges, in the same units as ``fs``.
        fs: Sampling frequency; band edges are normalised by ``fs / 2``.
        order: Filter order handed to ``scipy.signal.butter``.

    Returns:
        tuple: ``(b, a)`` numerator and denominator coefficients.
    """
    nyquist = 0.5 * fs
    low_edge = cutoff[0] / nyquist
    high_edge = cutoff[1] / nyquist
    return signal.butter(order, [low_edge, high_edge], btype='band', analog=False)

def butter_bandpass_filter(data, cutoff, fs, order=5):
    """Band-pass filter ``data`` forwards and backwards (zero phase).

    Args:
        data: Array to filter; ``scipy.signal.filtfilt`` is applied with its
            default axis.
        cutoff: Low/high band edges, forwarded to ``butter_bandpass``.
        fs: Sampling frequency, forwarded to ``butter_bandpass``.
        order: Filter order, forwarded to ``butter_bandpass``.

    Returns:
        The filtered array as returned by ``scipy.signal.filtfilt``.
    """
    numer, denom = butter_bandpass(cutoff, fs, order=order)
    return signal.filtfilt(numer, denom, data)

# Its reciprocal is used as the sampling frequency below; presumably the scan
# repetition time in seconds -- TODO confirm.
tr = 1.83

def filter_design_ts(X):
    """Band-pass filter each entry along the first axis of ``X``.

    Every ``X[i]`` is passed through ``butter_bandpass_filter`` with band
    edges ``[0.01, 0.2]`` and sampling frequency ``1/tr``; the filtered
    entries are stacked back into a single array.
    """
    return np.stack([
        butter_bandpass_filter(series, [0.01, 0.2], 1/tr)
        for series in X
    ])

def ts_to_flat_fc(X):
    """Return the strict upper triangle of the row-correlation matrix of ``X``.

    Args:
        X: 2-D array; each row is treated as one variable by ``np.corrcoef``.

    Returns:
        1-D array of the correlations above the main diagonal, in row-major
        order of ``np.triu_indices`` (length ``n * (n - 1) / 2`` for ``n`` rows).
    """
    corr = np.corrcoef(X)
    size = corr.shape[1]
    # k=1 skips the diagonal, which only holds self-correlations.
    rows, cols = np.triu_indices(size, 1)
    return corr[rows, cols]

# For each paradigm in X: band-pass filter the stacked time series, reduce
# every subject to its flattened upper-triangle correlation vector, and stack
# the vectors into one (subjects x features) array. p keeps X's paradigm order.
p = [np.stack([ts_to_flat_fc(ts) for ts in filter_design_ts(Xp)]) for Xp in X]
print(p[0].shape)

(830, 34716)


In [4]:
import sys

# Make the LatentSimilarity directory importable. The path is relative, so it
# depends on the directory the kernel was started in.
sys.path.append('../../LatentSimilarity')

# LatSim is the model class used by the training cell.
from latsim import LatSim

print('Complete')

Complete


In [7]:
import torch
import torch.nn as nn

# Repeated random train/test splits: fit LatSim on the features in p[2]
# (third paradigm passed to get_X, i.e. 'emoid') to predict age and record
# the held-out RMSE per repetition and training size.
# NOTE(review): no random seed is set for numpy or torch, so the splits and
# the resulting numbers are not reproducible between runs.
mseLoss = nn.MSELoss()

nreps = 20           # number of random splits
trainsizes = [600]   # number of training subjects per split
res = np.zeros((nreps,len(trainsizes)))  # held-out RMSE, one row per repetition

nepochs = 500
pperiod = 100        # print training progress every `pperiod` epochs
verbose = True

# The response vector does not depend on the repetition or the training size,
# so build it once instead of once per inner iteration.
y = get_y(metadict, ['age'], subs)[0]

for rep in range(nreps):

    # Fresh random permutation of the subjects for this repetition.
    idcs = np.arange(p[2].shape[0])
    np.random.shuffle(idcs)

    losses = []

    for ntrain in trainsizes:
        xps = torch.from_numpy(p[2][idcs]).float().cuda().unsqueeze(1)

        # Standardise every feature with statistics of the training rows only.
        mu = torch.mean(xps[:ntrain], dim=0, keepdims=True)
        std = torch.std(xps[:ntrain], dim=0, keepdims=True)
        xps = (xps-mu)/std

        xtr = xps[:ntrain]
        xt = xps[ntrain:]

        y_t = torch.from_numpy(y[idcs]).float().cuda()
        ytr = y_t[:ntrain]
        yt = y_t[ntrain:]

        # dp=0.5 for wrat
        sim = LatSim(1, xps, dp=0.5, edp=0.1, wInit=1e-4, dim=2, temp=1)
        optim = torch.optim.Adam(sim.parameters(), lr=1e-4, weight_decay=1e-4)
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.75, eps=1e-7)

        for epoch in range(nepochs):
            optim.zero_grad()
            yhat = sim(xtr, [ytr])[0][0]
            loss = mseLoss(yhat, ytr)**0.5  # RMSE on the training subjects
            loss.backward()
            optim.step()
            sched.step(loss)
            if verbose:
                if epoch % pperiod == 0 or epoch == nepochs-1:
                    # Read the learning rate from the optimizer's public
                    # param_groups rather than the scheduler's private
                    # `_last_lr`; the printed list is the same.
                    current_lr = [group['lr'] for group in optim.param_groups]
                    print(f'{epoch} {float(loss)} {current_lr}')

        # Held-out evaluation: no gradients are needed here.
        # NOTE(review): the full label vector y_t (including test labels) is
        # passed together with the test indices; presumably LatSim hides the
        # labels at those indices -- confirm in latsim. The recorded runs show
        # test RMSE (~25) far above training RMSE (~5), which is worth checking.
        sim.eval()
        with torch.no_grad():
            yhat = sim(xps, [y_t], np.arange(ntrain,idcs.shape[0]))[0][0][ntrain:]
            loss = mseLoss(yhat, yt)**0.5
        losses.append(float(loss))

        # Alternative baseline (disabled): ordinary least squares on the same split.
#         xps = torch.from_numpy(p[2][idcs]).float().cuda()
#         xps = torch.cat([xps, torch.ones(xps.shape[0], 1).float().cuda()], dim=1)
#         xtr = xps[:ntrain]
#         xt = xps[ntrain:]

#         y = get_y(metadict, ['age'], subs)[0]
#         y_t = torch.from_numpy(y[idcs]).float().cuda()
#         ytr = y_t[:ntrain]
#         yt = y_t[ntrain:]

#         # REDUCE THIS TO GET GOOD RESULTS WITH SPARSITY 0.01->0.001 or 0.0001
#         w, _, _, _ = torch.linalg.lstsq(xtr, ytr)

#         print(torch.mean((yt-xt@w)**2)**0.5)
#         losses.append(float(torch.mean((yt-xt@w)**2)**0.5))

        print(losses[-1])

    res[rep,:] = losses
    print(f'Finished {rep}')

# Mean and standard deviation of the held-out RMSE across repetitions.
print(np.mean(res, axis=0))
print(np.std(res, axis=0))

0 39.38517379760742 [0.0001]
100 7.366809844970703 [7.500000000000001e-05]
200 5.412649631500244 [2.373046875e-05]
300 5.187360763549805 [5.631351470947265e-06]
400 5.135972499847412 [7.516946818213909e-07]
499 5.233081817626953 [3.171211938933993e-07]
24.85190773010254
Finished 0
0 39.76314163208008 [0.0001]
100 7.131792068481445 [5.6250000000000005e-05]
200 6.810090065002441 [1.0011291503906249e-05]
300 6.074042797088623 [7.516946818213909e-07]
400 6.756427764892578 [3.171211938933993e-07]
499 6.421941757202148 [3.171211938933993e-07]
25.419036865234375
Finished 1
0 39.467933654785156 [0.0001]
100 6.383162021636963 [7.500000000000001e-05]
200 5.777366638183594 [2.373046875e-05]
300 5.199838161468506 [2.3757264018058777e-06]
400 5.283851623535156 [3.171211938933993e-07]
499 5.483890533447266 [3.171211938933993e-07]
25.96707534790039
Finished 2
0 40.01612854003906 [0.0001]
100 7.066390514373779 [5.6250000000000005e-05]
200 6.758663177490234 [1.3348388671874999e-05]
300 6.65279912948608