# NOTE: This is NOT the latest version of this notebook; the latest version is on Google Colab.

In [1]:
# Using newly preprocessed subjects

import pickle

# Folder holding the preprocessed PNC pickles (machine-specific absolute path).
PNC_DIR = '/home/anton/Documents/Tulane/Research/PNC_Good'
metadictname = PNC_DIR + '/PNC_agesexwrat.pkl'
alltsname = PNC_DIR + '/PNC_PowerTS_float2.pkl'


def _unpickle(path):
    """Open `path` in binary mode and return the object pickled inside it.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# metadict: subject metadata dict; allts: dict keyed by paradigm (keys printed below).
metadict = _unpickle(metadictname)
allts = _unpickle(alltsname)

for loaded in (metadict, allts):
    print(list(loaded.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [2]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    """Return the set of subject IDs usable for the given tasks and paradigms.

    A subject is kept when it appears in ``allts[para]`` for EVERY paradigm in
    ``paras`` and in ``metadict[task]`` for EVERY task in ``tasks``, and is not
    listed in ``metadict['failedqc']``.

    Parameters
    ----------
    allts : dict
        ``allts[para]`` is keyed by subject strings of the form ``'sub-<id>'``.
    metadict : dict
        ``metadict[task]`` is keyed by subject IDs that must match the integers
        parsed from ``'sub-<id>'``; ``metadict['failedqc']`` is an iterable of
        ``'sub-<id>'`` strings to exclude.
    tasks, paras : sequence of str
        Metadata variables and paradigms every subject must have. An empty
        sequence imposes no constraint of that kind.

    Returns
    -------
    set of int
        Subject IDs satisfying all requirements (a fresh set).

    Raises
    ------
    ValueError
        If both ``tasks`` and ``paras`` are empty (no subject universe).
    """
    # One membership set per requirement; a subject must be in all of them.
    # (The previous loop intersected each new task with `paraset`, so only the
    # LAST task actually constrained the result.)
    requirement_sets = [{int(key[4:]) for key in allts[para]} for para in paras]
    requirement_sets += [set(metadict[task]) for task in tasks]
    if not requirement_sets:
        raise ValueError('get_subs needs at least one task or paradigm')
    # set.intersection always returns a new set, so mutating it below is safe.
    allsubs = set.intersection(*requirement_sets)

    # Remove QC failures. Entries that are not 'sub-<int>' are skipped
    # deliberately; IDs not present are ignored via discard().
    for badsub in metadict['failedqc']:
        try:
            bad_id = int(badsub[4:])
        except (TypeError, ValueError):
            continue
        allsubs.discard(bad_id)
    return allsubs

def get_X(allts, paras, subs):
    """Stack every subject's data into one array per paradigm.

    For each paradigm in ``paras``, the entries ``allts[para]['sub-<id>']``
    for each id in ``subs`` (in iteration order) are stacked along a new
    leading axis with ``np.stack``.

    Returns
    -------
    list of numpy.ndarray
        One stacked array per paradigm, in the order of ``paras``.
    """
    return [
        np.stack([allts[para][f'sub-{sub}'] for sub in subs])
        for para in paras
    ]

def get_y(metadict, tasks, subs):
    """Build one response array per task, aligned with the order of ``subs``.

    - ``'age'`` / ``'wrat'``: 1-D array of ``metadict[task][sub]``.
    - ``'sex'``: (n_subs, 2) one-hot array; column 0 is 1 where the value is
      ``'M'``, column 1 is its complement.

    Returns
    -------
    list of numpy.ndarray
        Exactly one entry per task, in the order of ``tasks``.

    Raises
    ------
    ValueError
        For any task other than 'age', 'wrat' or 'sex'. Such tasks used to be
        skipped silently, which made the returned list shorter than ``tasks``
        and misaligned every later entry.
    """
    y = []
    for task in tasks:
        if task in ('age', 'wrat'):
            var = np.array([metadict[task][sub] for sub in subs])
            y.append(var)
        elif task == 'sex':
            maleness = np.array([metadict[task][sub] == 'M' for sub in subs])
            # `1 - maleness` gives the (integer) complement column.
            sex = np.stack([maleness, 1 - maleness], axis=1)
            y.append(sex)
        else:
            raise ValueError(f'get_y: unsupported task {task!r}')
    return y

# Subjects that have an age value and an emoid recording, QC failures removed.
subs = get_subs(allts, metadict, tasks=['age'], paras=['emoid'])
print(len(subs))

# X[0]: emoid data of every subject in `subs`, stacked on axis 0.
X = get_X(allts, paras=['emoid'], subs=subs)
print(X[0].shape)  # saved run printed (922, 264, 210)

922
(922, 264, 210)


In [3]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    Parameters
    ----------
    cutoff : sequence
        ``(low, high)`` band edges, in the same units as ``fs``.
    fs : float
        Sampling frequency; the edges are divided by the Nyquist rate fs/2.
    order : int
        Filter order forwarded to ``scipy.signal.butter``.

    Returns
    -------
    (b, a)
        Numerator and denominator coefficients from ``scipy.signal.butter``.
    """
    nyquist = fs * 0.5
    low, high = cutoff[0], cutoff[1]
    return signal.butter(order, [low / nyquist, high / nyquist], btype='band', analog=False)

def butter_bandpass_filter(data, cutoff, fs, order=5):
    """Band-pass filter ``data`` forward and backward with ``scipy.signal.filtfilt``.

    Coefficients come from ``butter_bandpass(cutoff, fs, order)``. ``filtfilt``
    is called with its default axis, i.e. the last axis of ``data``.

    Returns
    -------
    numpy.ndarray
        The filtered data.
    """
    coeffs = butter_bandpass(cutoff, fs, order=order)
    return signal.filtfilt(*coeffs, data)

tr = 1.83        # presumably the scan repetition time (TR) in seconds — TODO confirm
N = len(X[0])    # number of subjects (axis 0 of the stacked emoid array), not timepoints

def filter_design_ts(X):
    """Band-pass filter each per-subject array in ``X`` and restack.

    Parameters
    ----------
    X : numpy.ndarray
        Stacked per-subject arrays (axis 0 = subject). Each ``X[i]`` goes
        through ``butter_bandpass_filter``, i.e. is filtered along its last axis.

    Returns
    -------
    numpy.ndarray
        Filtered arrays stacked in the original subject order.

    Notes
    -----
    Reads the module-level ``tr`` and ``N``. With cutoff ``[tr/20*N, 0.8*N]``
    and ``fs = 2*N`` the Nyquist rate equals ``N``, so ``N`` cancels and the
    band edges are ``tr/20`` and ``0.8`` as fractions of Nyquist. ``N`` is set
    from the subject count in this notebook, so it has no physical meaning here.
    NOTE(review): verify that this pass band is the intended one.
    """
    filtered = [
        butter_bandpass_filter(subject_ts, [tr/20*N, 0.8*N], 2*N)
        for subject_ts in X
    ]
    return np.stack(filtered)

def ts_to_flat_fc(X):
    """Return the strictly upper-triangular entries of ``np.corrcoef(X)``.

    Rows of ``X`` are treated as variables (``np.corrcoef`` default), so an
    (R, T) input yields the R*(R-1)/2 pairwise correlations, taken in
    row-major order of the upper triangle with the diagonal excluded.
    """
    corr = np.corrcoef(X)
    upper_rows, upper_cols = np.triu_indices(len(corr), k=1)
    return corr[upper_rows, upper_cols]

# One full (square) correlation matrix per subject, computed on the band-passed series.
p = np.stack([np.corrcoef(subject_ts) for subject_ts in filter_design_ts(X[0])])
print(p.shape)  # saved run printed (922, 264, 264)

(922, 264, 264)


In [5]:
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

import sys

sys.path.append('/home/anton/Documents/Tulane/Research/LatentSimilarity')

from latsim import LatSim

# Row indices of `p` to drop; they are positions in the iteration order of `subs`.
# A previously tried set was [167, 405, 58, 129, 602].
outliers = np.array([141, 429])
admit = np.setdiff1d(np.arange(len(p)), outliers)
x = torch.from_numpy(p[admit]).float().cuda()
# First 600 admitted subjects form the training set, the rest the test set (no shuffling).
xtr, xt = x[:600], x[600:]

def mask(e):
    """Return ``e`` with its diagonal set to zero.

    The diagonal is subtracted through a detached copy, so the subtracted
    term is a constant and gradients flow to every entry of ``e`` unchanged.

    Accepts a single square matrix (original code path, unchanged) or a batch
    of square matrices with shape (..., R, R). Batched input is masked per
    matrix, so callers do not need to loop and ``torch.stack``.
    """
    if e.dim() <= 2:
        return e - torch.diag(torch.diag(e.detach()))
    diagonal = torch.diagonal(e.detach(), dim1=-2, dim2=-1)
    return e - torch.diag_embed(diagonal)

xtr = torch.stack([mask(matrix) for matrix in xtr])  # zero the diagonal of every training matrix

class LowRankCodes(nn.Module):
    """Learnable dictionary of low-rank, zero-diagonal (nrois x nrois) matrices.

    Each entry of ``ranks`` creates one parameter "page" P of shape
    (rank, nrois). ``forward()`` returns ``mask(P.T @ P)`` for every page,
    stacked into a tensor of shape (len(ranks), nrois, nrois).

    Parameters
    ----------
    ranks : sequence of int
        Rank of each dictionary entry.
    nrois : int, default 264
        Matrix size (was hard-coded to 264).
    init_scale : float, default 1e-5
        Multiplier applied to the standard-normal initialisation.
    device : str or torch.device, default 'cuda'
        Where the parameters live (was hard-coded ``.cuda()``).
    """

    def __init__(self, ranks, nrois=264, init_scale=1e-5, device='cuda'):
        super().__init__()
        # Draw from the CPU generator and then move, exactly as before, so a
        # given torch seed produces the same initialisation on any device.
        self.book = nn.ParameterList([
            nn.Parameter(init_scale * torch.randn(rank, nrois).float().to(device))
            for rank in ranks
        ])

    def forward(self):
        return torch.stack([mask(page.T @ page) for page in self.book])
    
class LowRankWeights(nn.Module):
    """Per-subject non-negative mixing weights over a dictionary of matrices.

    Holds a parameter ``w`` of shape (nsubs, nranks), initialised to ones.
    ``forward(book)`` computes ``einsum('nr,rab->nab', relu(w), book)``. For each
    subject n this is the weighted sum of the ``nranks`` matrices in ``book``,
    where ``book`` has shape (nranks, a, b).

    Parameters
    ----------
    nsubs, nranks : int
        Number of rows (subjects) and columns (dictionary entries) of ``w``.
    device : str or torch.device, default 'cuda'
        Where ``w`` lives (was hard-coded ``.cuda()``).
    """

    def __init__(self, nsubs, nranks, device='cuda'):
        super().__init__()
        self.w = nn.Parameter(torch.ones(nsubs, nranks).float().to(device))

    def forward(self, book):
        # ReLU keeps the effective weights non-negative. Negative raw entries
        # contribute nothing and receive no gradient.
        w = F.relu(self.w)
        return torch.einsum('nr,rab->nab', w, book)
    
nEpochs = 3000
pPeriod = 400

mseLoss = nn.MSELoss()

# 400 rank-1 dictionary entries; one weight per (training subject, entry).
# NOTE(review): no torch seed is set in this section, so the random
# initialisation of the dictionary differs from run to run.
lrc = LowRankCodes(400*[1])
lrw = LowRankWeights(xtr.shape[0], 400)

optim = torch.optim.Adam(
    [
        {'params': itertools.chain(lrc.parameters(), lrw.parameters())},
    ],
    lr=1e-2,
    weight_decay=0
)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.95, eps=1e-7)

# Fit dictionary and weights jointly to reconstruct the masked training matrices.
for epoch in range(nEpochs):
    optim.zero_grad()
    book = lrc()
    xhat = lrw(book)
    xloss = mseLoss(xhat, xtr)
    xloss.backward()
    optim.step()
    sched.step(xloss)
    if epoch % pPeriod == 0 or epoch == nEpochs-1:
        # Read the LR from the optimizer's public param_groups instead of the
        # scheduler's private `_last_lr`; after sched.step() they hold the same values.
        lrs = [group['lr'] for group in optim.param_groups]
        print(f'{epoch} {float(xloss)} lr: {lrs}')

print('Complete')

0 0.14906658232212067 lr: [0.01]
400 0.019359642639756203 lr: [0.01]
800 0.01606043055653572 lr: [0.01]


KeyboardInterrupt: 

In [None]:
# Estimate test set: keep the trained dictionary fixed and fit only new
# per-subject weights for the held-out matrices.

# Re-running this cell is harmless: masking an already-zero diagonal is a no-op.
xt = torch.stack([mask(xt[i]) for i in range(xt.shape[0])])

lrw2 = LowRankWeights(xt.shape[0], 400)

optim = torch.optim.Adam(
    lrw2.parameters(),
    lr=1e-2,
    weight_decay=0
)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.95, eps=1e-7)

# Built without autograd, so no gradient reaches the dictionary `lrc`.
with torch.no_grad():
    book = lrc()

nEpochs = 1000

# pPeriod is reused from the training cell.
for epoch in range(nEpochs):
    optim.zero_grad()
    xhat2 = lrw2(book)
    loss = mseLoss(xhat2, xt)
    loss.backward()
    optim.step()
    sched.step(loss)
    if epoch % pPeriod == 0 or epoch == nEpochs-1:
        # Public equivalent of the scheduler's private `_last_lr`.
        lrs = [group['lr'] for group in optim.param_groups]
        print(f'{epoch} {float(loss)} lr: {lrs}')

print(float(loss))

In [92]:
# Regress age on the learned per-subject dictionary weights with LatSim.
y = get_y(metadict, ['age'], subs)[0]
# Same outlier removal and 600 / rest split as the connectivity matrices.
y_t = torch.from_numpy(y[admit]).float().cuda()
ytr = y_t[:600]
yt = y_t[600:]

# NOTE(review): these are the raw `w` parameters, but LowRankWeights.forward
# uses F.relu(w), so negative entries carry no reconstruction meaning.
# Consider F.relu(lrw.w) as features — confirm intent.
wtr = lrw.w.clone().detach().unsqueeze(1)
wt = lrw2.w.clone().detach().unsqueeze(1)
w = torch.cat([wtr, wt])

nEpochs = 1000
pPeriod = 100

# Template width follows the number of dictionary entries instead of a
# hard-coded 400, so it stays consistent if the rank count changes.
# NOTE(review): lrsim is never moved with .cuda() while its inputs are CUDA
# tensors; confirm that LatSim handles device placement.
lrsim = LatSim(1, torch.zeros(1, 1, wtr.shape[-1]), dp=0.5, edp=0, temp=1)
optim = torch.optim.Adam(
    lrsim.parameters(),
    lr=1e-4,
    weight_decay=1e-4
)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.95, eps=1e-7)

for epoch in range(nEpochs):
    optim.zero_grad()
    yhat = lrsim(wtr, [ytr])[0][0]
    loss = mseLoss(yhat, ytr)
    loss.backward()
    optim.step()
    sched.step(loss)
    if epoch % pPeriod == 0 or epoch == nEpochs-1:
        # Printed value is the training RMSE; LR read via the public optimizer API.
        lrs = [group['lr'] for group in optim.param_groups]
        print(f'{epoch} {float(loss)**0.5} lr: {lrs}')

print('Complete')

lrsim.eval()
# Inference only: no autograd graph is needed for the test-set RMSE.
# The third argument presumably marks the test indices (600 onward) — confirm against LatSim.
with torch.no_grad():
    yhat = lrsim(w, [y_t], torch.arange(600, y_t.shape[0]))[0][0]
    print(mseLoss(yhat[600:], yt)**0.5)

0 41.051036796533076 lr: [0.0001]
100 34.84720239342661 lr: [0.0001]
200 31.170241968502587 lr: [9.5e-05]
300 27.992952822359484 lr: [6.634204312890622e-05]
400 28.078829289998453 lr: [5.133420832795048e-05]
500 28.05702941179024 lr: [3.584859224085419e-05]
600 27.847305293499673 lr: [2.2593554099256555e-05]
700 26.654871557266507 lr: [1.4989025404881544e-05]
800 27.366408023084386 lr: [9.446824413773763e-06]
900 27.220272542806907 lr: [5.953855510552941e-06]
1000 27.523520658758848 lr: [3.7524139211116024e-06]
1100 26.647893528235727 lr: [2.3649566588229932e-06]
1200 27.30782796675193 lr: [1.9262719795904457e-06]
1300 26.644848226471826 lr: [1.9262719795904457e-06]
1400 25.54381432353658 lr: [1.9262719795904457e-06]
1500 26.71119684906266 lr: [1.9262719795904457e-06]
1600 27.357292405388115 lr: [1.9262719795904457e-06]
1700 26.51636273465464 lr: [1.9262719795904457e-06]
1800 25.44363096992239 lr: [1.9262719795904457e-06]
1900 26.773029432256426 lr: [1.9262719795904457e-06]
1999 26.909

In [178]:
# Linear baseline: least-squares fit of age on the per-subject weights (no
# intercept column), reported as train RMSE then test RMSE.
# NOTE: this rebinds `w`, which previously held the concatenated weights.
feats_tr = wtr.squeeze()
feats_t = wt.squeeze()
w = torch.linalg.lstsq(feats_tr, ytr).solution
for feats, target in ((feats_tr, ytr), (feats_t, yt)):
    print(mseLoss(feats @ w, target)**0.5)

tensor(13.7078, device='cuda:0')
tensor(16.7623, device='cuda:0')


In [172]:
wtr.squeeze().size()  # (n_train_subjects, n_ranks)

torch.Size([600, 50])

In [170]:
wt.size()  # (n_test_subjects, 1, n_ranks)

torch.Size([305, 1, 50])

In [171]:
ytr.size()

torch.Size([600])

In [173]:
yt.size()

torch.Size([305])