In [1]:
# Using newly preprocessed subjects

import pickle

metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [2]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    # Get subs for all paras
    for i,para in enumerate(paras):
        tmpset = set([int(sub[4:]) for sub in allts[para].keys()])
        if i == 0:
            paraset = tmpset
        else:
            paraset = paraset.intersection(tmpset)
    # Get subs for all tasks
    for i,task in enumerate(tasks):
        tmpset = set([sub for sub in metadict[task].keys()])
        if i == 0:
            taskset = tmpset
        else:
            taskset = paraset.intersection(tmpset)
    # Remove QC failures
    allsubs = taskset.intersection(paraset)
    for badsub in metadict['failedqc']:
        try:
            allsubs.remove(int(badsub[4:]))
        except:
            pass
    return allsubs

def get_X(allts, paras, subs):
    X = []
    for para in paras:
        pX = [allts[para][f'sub-{sub}'] for sub in subs]
        pX = np.stack(pX)
        X.append(pX)
    return X

def get_y(metadict, tasks, subs):
    y = []
    for task in tasks:
        if task == 'age' or task == 'wrat':
            var = [metadict[task][sub] for sub in subs]
            var = np.array(var)
            y.append(var)
        if task == 'sex':
            maleness = [metadict[task][sub] == 'M' for sub in subs]
            maleness = np.array(maleness)
            sex = np.stack([maleness, 1-maleness], axis=1)
            y.append(sex)
    return y

subs = get_subs(allts, metadict, ['age'], ['rest', 'nback', 'emoid'])
print(len(subs))

X = get_X(allts, ['rest'], subs)
print(X[0].shape)

847
(847, 264, 124)


In [3]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = [cutoff[0] / nyq, cutoff[1] / nyq]
    b, a = signal.butter(order, normal_cutoff, btype='band', analog=False)
    return b, a

def butter_bandpass_filter(data, cutoff, fs, order=5):
    b, a = butter_bandpass(cutoff, fs, order=order)
    y = signal.filtfilt(b, a, data)
    return y

tr = 1.83
N = X[0].shape[0]

def filter_design_ts(X):
    Xs = []
    for i in range(X.shape[0]):
        nX = butter_bandpass_filter(X[i], [tr/20*N, 0.8*N], 2*N)
        Xs.append(nX)
    return np.stack(Xs)

def ts_to_flat_fc(X):
    p = np.corrcoef(X)
    a,b = np.triu_indices(p[0].shape[0], 1)
    p = p[a,b]
    return p

pflat = [ts_to_flat_fc(ts) for ts in filter_design_ts(X[0])]
pflat = np.stack(pflat)
print(pflat.shape)

(847, 34716)


In [6]:
# Standardize pflat

mu = np.mean(pflat, axis=0, keepdims=True)
sigma = np.std(pflat, axis=0, keepdims=True)
pflat = (pflat-mu)/sigma

print(sigma.shape)
print(pflat.shape)

(1, 34716)
(847, 34716)


In [10]:
from sklearn.decomposition import PCA

pca = PCA(n_components=100)
psmall = pca.fit_transform(pflat)

print(psmall.shape)

(847, 100)


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

mseLoss = nn.MSELoss()

x = torch.from_numpy(psmall).float().cuda()
xtr = x[:600]
xt = x[600:]

y = get_y(metadict, ['age'], subs)[0]
y_t = torch.from_numpy(y).float().cuda()
ytr = y_t[:600]
yt = y_t[600:]

print(x.shape)
print(xtr.shape)
print(xt.shape)
print(y.shape)

torch.Size([847, 100])
torch.Size([600, 100])
torch.Size([247, 100])
(847,)


In [15]:
class MLP(nn.Module):
    def __init__(self, dim):
        super(MLP, self).__init__()
        self.l0 = nn.Linear(dim,40).float().cuda()
        self.l1 = nn.Linear(40,1).float().cuda()
        
    def forward(self, x):
        y = F.relu(self.l0(x))
        return self.l1(y).squeeze()
    
mlp = MLP(x.shape[-1])
optim = torch.optim.Adam(mlp.parameters(), lr=1e-3, weight_decay=1e-1)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.95, eps=1e-7)

nEpochs = 10000
pPeriod = 500
lmbda = 1e-2

for epoch in range(nEpochs):
    optim.zero_grad()
    yhat = mlp(xtr)
    loss = mseLoss(yhat, ytr)
    sloss = lmbda*(torch.sum(torch.abs(mlp.l0.weight))+torch.sum(torch.abs(mlp.l0.bias)))
    (sloss+loss).backward()
    optim.step()
    sched.step(loss)
    if epoch % pPeriod == 0 or epoch == nEpochs-1:
        print(f'{epoch} {float(loss)**0.5} {float(sloss)} lr: {sched._last_lr}')
        
yhat = mlp(xt)
loss = mseLoss(yhat, yt)**0.5
print(loss)

0 182.35013110085774 2.02274227142334 lr: [0.001]
500 23.992272722495986 8.777979850769043 lr: [0.001]
1000 9.789697284012293 10.006196022033691 lr: [0.001]
1500 3.5190125823081906 10.65129280090332 lr: [0.001]
2000 1.188502892082653 10.759209632873535 lr: [0.001]
2500 0.5125196174030429 10.579184532165527 lr: [0.001]
3000 0.36477485349300703 10.277637481689453 lr: [0.001]
3500 0.3275233838782166 9.919389724731445 lr: [0.001]
4000 0.30239545247002164 9.529871940612793 lr: [0.001]
4500 0.2795042455219412 9.114214897155762 lr: [0.001]
5000 0.2545028865787283 8.678559303283691 lr: [0.001]
5500 0.23268585280625106 8.226531028747559 lr: [0.001]
6000 0.2110021015425293 7.752992153167725 lr: [0.001]
6500 0.18979868883889983 7.265472412109375 lr: [0.001]
7000 0.17334447505297956 6.818976402282715 lr: [0.0009025]
7500 0.1560455713485971 6.377130031585693 lr: [0.0009025]
8000 0.1434218057656591 5.986624240875244 lr: [0.0006983372960937497]
8500 0.1318490057225688 5.67603063583374 lr: [0.00054036

In [12]:
w, _, _, _ = torch.linalg.lstsq(xtr, ytr)
yhattr = xtr@w
yhatt = xt@w
print(mseLoss(yhattr, ytr)**0.5)
print(mseLoss(yhatt, yt)**0.5)

tensor(178.1264, device='cuda:0')
tensor(213.0996, device='cuda:0')


In [10]:
import sys 

sys.path.append('../../LatentSimilarity')

from latsim import LatSim

print('Complete')

Complete


In [124]:
sim = LatSim(1, x.unsqueeze(1), dp=0, edp=0, wInit=1e-4, dim=1, temp=1)
optim = torch.optim.Adam(sim.parameters(), lr=1e-4, weight_decay=1e-4)

nEpochs = 100
pPeriod = 100

sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.95, eps=1e-7)

for epoch in range(nEpochs):
    optim.zero_grad()
    yhat = sim(xtr.unsqueeze(1), [ytr])[0][0]
    loss = mseLoss(yhat, ytr)
    loss.backward()
    optim.step()
    sched.step(loss)
    if epoch % pPeriod == 0 or epoch == nEpochs-1:
        print(f'{epoch} {float(loss)**0.5} lr: {sched._last_lr}')
        
print('Complete')

print(mseLoss(sim(x.unsqueeze(1), [y_t], torch.arange(600,y_t.shape[0]))[0][0][600:], yt)**0.5)

0 40.08103821039866 lr: [0.0001]
99 31.769400368982218 lr: [0.0001]
Complete
tensor(36.6730, device='cuda:0', grad_fn=<PowBackward0>)
