# Calculate low-rank dynamic EC
This notebook uses a codebook of low-rank matrices.
There are two options:
1. Time-delayed dynamic correlation $\mathbf{E}_{12} = \mathbf{t}^T_1\mathbf{t}_2$
2. <span style='color: red;'>Transfer function (what I was doing before) $\mathbf{E}_{12}\mathbf{t}_1 = \mathbf{t}_2$ &lt;- This notebook covers this option</span>

The second option probably requires a very small codebook; otherwise the problem is heavily underdetermined.

In [4]:
# Using newly preprocessed subjects

import pickle

# Absolute, machine-specific locations of the two input pickles.
# NOTE(review): these paths only exist on the author's machine; adjust before re-running.
metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
# Print the top-level keys of both dictionaries as a quick sanity check.
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [5]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    '''
    Return the set of integer subject ids that have data for every requested
    paradigm and every requested task, excluding QC failures.

    allts: dict paradigm -> dict keyed by subject string; the id is parsed
           from everything after the first four characters (e.g. 'sub-123')
    metadict: dict task -> dict keyed by subject id, plus a 'failedqc'
              iterable of subject strings in the same format as allts keys
    tasks: task names in metadict that every subject must have
    paras: paradigm names in allts that every subject must have

    An empty tasks (or paras) list places no constraint from that side; if
    both are empty the result is an empty set.
    '''
    # Subjects present in every paradigm
    paraset = None
    for para in paras:
        ids = {int(sub[4:]) for sub in allts[para].keys()}
        paraset = ids if paraset is None else paraset & ids
    # Subjects present in every task. Intersect with the running task set;
    # intersecting with paraset here would drop the earlier tasks' constraints.
    taskset = None
    for task in tasks:
        ids = set(metadict[task].keys())
        taskset = ids if taskset is None else taskset & ids
    if paraset is None and taskset is None:
        allsubs = set()
    elif taskset is None:
        allsubs = set(paraset)
    elif paraset is None:
        allsubs = set(taskset)
    else:
        allsubs = taskset & paraset
    # Remove QC failures; entries that cannot be parsed or are absent are skipped
    for badsub in metadict['failedqc']:
        try:
            badid = int(badsub[4:])
        except (TypeError, ValueError):
            continue
        allsubs.discard(badid)
    return allsubs

def get_X(allts, paras, subs):
    '''
    Build one stacked array per paradigm.

    For every paradigm in paras, look up allts[paradigm]['sub-<id>'] for each
    id in subs (in iteration order) and stack the entries along a new leading
    axis. Returns a list with one array per paradigm.
    '''
    return [
        np.stack([allts[para][f'sub-{sub}'] for sub in subs])
        for para in paras
    ]

def get_y(metadict, tasks, subs):
    '''
    Build one response array per recognised task.

    'age' and 'wrat' give a 1-D array of metadict[task][sub] values.
    'sex' gives a two-column array: column 0 is 1 where the entry equals 'M',
    column 1 is its complement. Any other task name is skipped.
    '''
    responses = []
    for task in tasks:
        if task in ('age', 'wrat'):
            values = np.array([metadict[task][sub] for sub in subs])
            responses.append(values)
        elif task == 'sex':
            is_male = np.array([metadict[task][sub] == 'M' for sub in subs])
            one_hot = np.stack([is_male, 1-is_male], axis=1)
            responses.append(one_hot)
    return responses

# Subjects that have an 'age' entry and time series for all three paradigms.
# NOTE(review): subs is a set; get_X here and get_y later both iterate it, so the
# row order of X and y only matches as long as subs is not modified in between.
subs = get_subs(allts, metadict, ['age'], ['rest', 'nback', 'emoid'])
print(len(subs))

# One stacked array per paradigm, in the order listed.
X = get_X(allts, ['rest', 'nback', 'emoid'], subs)
print(X[0].shape)

847
(847, 264, 124)


In [132]:
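# Band-pass filter the time series and scale each ROI to unit norm (the condensed-FC helper ts_to_flat_fc is defined here but not called in this cell)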

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    '''
    Design a digital Butterworth band-pass filter.

    cutoff: (low, high) band edges, in the same units as fs
    fs: sampling rate
    order: filter order passed to scipy.signal.butter

    Returns the (b, a) coefficient arrays.
    '''
    nyquist = 0.5 * fs
    # scipy expects band edges as fractions of the Nyquist frequency
    band = [cutoff[0] / nyquist, cutoff[1] / nyquist]
    return signal.butter(order, band, btype='band', analog=False)

def butter_bandpass_filter(data, cutoff, fs, order=5):
    '''
    Band-pass filter data with a Butterworth filter.

    The coefficients come from butter_bandpass(cutoff, fs, order) and are
    applied with scipy.signal.filtfilt. Returns the filtered array.
    '''
    num, den = butter_bandpass(cutoff, fs, order=order)
    return signal.filtfilt(num, den, data)

# Repetition time; 1/tr is passed as the sampling rate to the band-pass filter
# (presumably in seconds — TODO confirm against the acquisition protocol).
tr = 1.83

def filter_design_ts(X):
    '''
    Band-pass filter every entry along the first axis of X.

    Each X[i] is filtered between 0.01 and 0.2 at sampling rate 1/tr (tr is
    the module-level constant) and the results are stacked into one array.
    '''
    filtered = [
        butter_bandpass_filter(X[i], [0.01, 0.2], 1/tr)
        for i in range(X.shape[0])
    ]
    return np.stack(filtered)

def ts_to_flat_fc(X):
    '''
    Return the strictly-upper-triangular entries of the correlation matrix
    of X (rows are variables, as in np.corrcoef) as a flat 1-D array.
    '''
    corr = np.corrcoef(X)
    n = corr[0].shape[0]
    rows, cols = np.triu_indices(n, 1)
    return corr[rows, cols]

# Band-pass filter each paradigm's array (filter_design_ts already returns a stacked array)
ts = [filter_design_ts(Xp) for Xp in X]
# Scale every time series to unit L2 norm along the last (time) axis
ts = [filtered/np.linalg.norm(filtered, axis=-1, keepdims=True) for filtered in ts]
print(ts[0].shape)

(847, 264, 124)


In [133]:
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

class LowRankCodes(nn.Module):
    '''
    Codebook of low-rank square matrices ("pages").

    Page k is A_k.T @ B_k, where A_k and B_k are learnable (ranks[k], n_rois)
    matrices, so page k is n_rois x n_rois.

    ranks: array of rank for each codebook matrix
    n_rois: side length of every page (defaults to 264, the original
            hard-coded value)
    device: device the parameters are created on (defaults to 'cuda', the
            original behaviour)
    '''
    def __init__(self, ranks, n_rois=264, device='cuda'):
        super(LowRankCodes, self).__init__()
        self.As = nn.ParameterList()
        self.Bs = nn.ParameterList()
        for rank in ranks:
            # Small random initialisation; A is drawn before B for each page
            A = nn.Parameter(1e-2*torch.randn(rank, n_rois).float().to(device))
            B = nn.Parameter(1e-2*torch.randn(rank, n_rois).float().to(device))
            self.As.append(A)
            self.Bs.append(B)

    def forward(self):
        '''Return the codebook as one tensor of shape (len(ranks), n_rois, n_rois).'''
        book = [A.T@B for A, B in zip(self.As, self.Bs)]
        return torch.stack(book)
    
class LowRankWeights(nn.Module):
    '''
    Per-subject, per-timepoint mixing weights over the codebook pages.

    For a single modality!

    nsubs: number of subjects
    ncodes: number of pages in the codebook
    nt: number of timepoints (weights are stored for nt-1 transitions)
    device: device the weights are created on (defaults to 'cuda', the
            original behaviour)
    '''
    def __init__(self, nsubs, ncodes, nt, device='cuda'):
        super(LowRankWeights, self).__init__()
        # Uniform initialisation in [0, 1e-2)
        self.w = nn.Parameter(1e-2*torch.rand(nsubs, ncodes, nt-1).float().to(device))

    # Returns transfer function at each time point 0:-1
    def forward(self, sub, book):
        '''
        sub: index into the first (subject) axis of the weights
        book: codebook tensor indexed (page, a, b)

        Returns the weighted sum of pages at each timepoint, indexed (a, b, t).
        '''
        w = self.w[sub]
        return torch.einsum('pt,pab->abt', w, book)
    
def get_recon_loss(x, xhat):
    '''Reconstruction loss: the module-level mseLoss applied to (xhat, x).'''
    loss = mseLoss(xhat, x)
    return loss

def get_smooth_loss(tf):
    '''
    Temporal smoothness penalty: mean squared difference between consecutive
    slices of tf along its third axis.
    '''
    step = tf[:,:,:-1] - tf[:,:,1:]
    return step.pow(2).mean()

def get_sub_xhat(tf, subts):
    '''
    Apply the transfer function at each timepoint to the signal at that
    timepoint: xhat[a, t] = sum_b tf[a, b, t] * subts[b, t], using every
    column of subts except the last.
    '''
    inputs = subts[:,:-1]
    xhat = torch.einsum('abt,bt->at', tf, inputs)
    return xhat
    
# Timeseries
x = torch.from_numpy(ts[0]).float().cuda()
    
# Parameters
ntrain = 400
nbatch = 30
smooth_mult = 1
nEpochs = 50
pPeriod = 40

mseLoss = nn.MSELoss()
    
# Codebook and weights
lrc = LowRankCodes(300*[1])
ncodes = len(lrc.As)

lrw = LowRankWeights(ntrain, ncodes, x.shape[-1])

# Optimizers
optim = torch.optim.Adam(itertools.chain(lrc.parameters(), lrw.parameters()), lr=1e-2, weight_decay=0)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=int(ntrain/nbatch)+5, factor=0.75, eps=1e-7)
    
for epoch in range(nEpochs):
    suborder = np.arange(ntrain)
#     np.random.shuffle(suborder)
    for bstart in range(0,ntrain,nbatch):
        bend = bstart+nbatch
        if bend > ntrain:
            bend = ntrain
        optim.zero_grad()
        book = lrc()
        recon_loss = 0
        smooth_loss_tf = 0
        for subidx in range(bstart, bend):
            sub = suborder[subidx]
            tf = lrw(subidx, book)
            xhat = get_sub_xhat(tf, x[sub])
            recon_loss += get_recon_loss(x[sub,:,1:], xhat)
            smooth_loss_tf += smooth_mult*get_smooth_loss(tf)
        recon_loss /= (bend-bstart)
        smooth_loss_tf /= (bend-bstart)
        totloss = recon_loss+smooth_loss_tf
        totloss.backward()
        optim.step()
        sched.step(totloss)
        if bstart % nbatch == 0:
            print(f'{epoch} {bstart} recon: {[float(ls)**0.5 for ls in [recon_loss, smooth_loss_tf]]} '
                  f'lr: {sched._last_lr}')

print('Complete')

0 0 recon: [7.994211725496176, 7.113235511657167e-06] lr: [0.01]
0 30 recon: [7.920311286280485, 1.4169723791368899e-05] lr: [0.01]
0 60 recon: [9.092452857981138, 2.156844189848287e-05] lr: [0.01]
0 90 recon: [8.258734962815744, 3.863084008517925e-05] lr: [0.01]
0 120 recon: [9.372863119096342, 6.51267587939882e-05] lr: [0.01]
0 150 recon: [9.029240896227622, 9.873394712863727e-05] lr: [0.01]
0 180 recon: [12.515852984424768, 0.00013878519484619177] lr: [0.01]
0 210 recon: [7.509750894003758, 0.0001809486873011394] lr: [0.01]
0 240 recon: [7.321159777874272, 0.00020823579393075247] lr: [0.01]
0 270 recon: [10.628739719845743, 0.00023167886090707328] lr: [0.01]
0 300 recon: [8.31541724391685, 0.00021473349863935847] lr: [0.01]
0 330 recon: [8.422318427420748, 0.00019753111473094675] lr: [0.01]
0 360 recon: [6.920855651611255, 0.0001795265482047033] lr: [0.01]
0 390 recon: [7.423430912157017, 0.0001639160242625242] lr: [0.01]
1 0 recon: [8.705892071619962, 0.0031776481172716736] lr: [0.

8 300 recon: [5.809420518455684, 0.012582267605122805] lr: [0.0075]
8 330 recon: [6.1125788958553855, 0.012233439133563395] lr: [0.0075]
8 360 recon: [5.087413380328332, 0.012746582445891] lr: [0.0075]
8 390 recon: [5.354789173239573, 0.013027266003685548] lr: [0.0075]
9 0 recon: [5.275251785705389, 0.012984356475422655] lr: [0.0075]
9 30 recon: [5.414686378984814, 0.012716433946430867] lr: [0.0075]
9 60 recon: [5.510936787029033, 0.012227573395831595] lr: [0.0075]
9 90 recon: [5.4475329563637285, 0.012155942103404961] lr: [0.0075]
9 120 recon: [5.8617901923416875, 0.012352092240622836] lr: [0.0075]
9 150 recon: [5.925257539584283, 0.011858183257979492] lr: [0.0075]
9 180 recon: [7.994446257213357, 0.011611992990012417] lr: [0.0075]
9 210 recon: [5.41186707916896, 0.011495486882911429] lr: [0.0075]
9 240 recon: [5.310055158942855, 0.011686626876836446] lr: [0.0075]
9 270 recon: [6.947442156265026, 0.011463318885624645] lr: [0.0075]
9 300 recon: [5.652479400799692, 0.011685849860758565]

17 90 recon: [4.121021779955095, 0.016025977706558135] lr: [0.005625]
17 120 recon: [4.4054227011747, 0.016431629469965334] lr: [0.005625]
17 150 recon: [4.500541230608626, 0.016147528363604138] lr: [0.005625]
17 180 recon: [6.009690406683167, 0.016175629004713366] lr: [0.005625]
17 210 recon: [4.128647087863413, 0.016834165243081595] lr: [0.005625]
17 240 recon: [4.024767964942931, 0.01720362659895899] lr: [0.005625]
17 270 recon: [5.061412741620667, 0.01681830849699078] lr: [0.005625]
17 300 recon: [4.047986676914168, 0.017442552973630215] lr: [0.005625]
17 330 recon: [4.3125725339235235, 0.016416100152040317] lr: [0.005625]
17 360 recon: [3.807822813600043, 0.017492451436746444] lr: [0.005625]
17 390 recon: [3.9350102136811413, 0.018005056701760214] lr: [0.005625]
18 0 recon: [3.9989188640554874, 0.017441562659095203] lr: [0.005625]
18 30 recon: [4.2057289651057745, 0.017607774002364598] lr: [0.005625]
18 60 recon: [4.1756972044862986, 0.01697386316753597] lr: [0.005625]
18 90 recon

25 210 recon: [3.353358074524187, 0.02132578862322146] lr: [0.00421875]
25 240 recon: [3.2243526548323787, 0.021300205160399464] lr: [0.00421875]
25 270 recon: [4.222648797041873, 0.020088429781039707] lr: [0.00421875]
25 300 recon: [3.4520227390481892, 0.020978382314029893] lr: [0.00421875]
25 330 recon: [3.7049482661662445, 0.01929510153191483] lr: [0.00421875]
25 360 recon: [3.25657311884649, 0.02059172523953369] lr: [0.00421875]
25 390 recon: [3.392160797365465, 0.020825528397082127] lr: [0.00421875]
26 0 recon: [3.436160711000539, 0.020067978452747423] lr: [0.00421875]
26 30 recon: [3.624597066650147, 0.02026258388149261] lr: [0.00421875]
26 60 recon: [3.619999429059906, 0.019710051654337955] lr: [0.00421875]
26 90 recon: [3.407634348806159, 0.019937795437995728] lr: [0.00421875]
26 120 recon: [3.662528621835155, 0.020498881211731743] lr: [0.00421875]
26 150 recon: [3.6848988814953527, 0.020279098578935437] lr: [0.00421875]
26 180 recon: [4.543036554671882, 0.02031161771747184] lr

33 240 recon: [2.5103653603108325, 0.024228379331064708] lr: [0.00421875]
33 270 recon: [3.2418748084340905, 0.022807228607333533] lr: [0.00421875]
33 300 recon: [2.5622882290304365, 0.02415059140288132] lr: [0.00421875]
33 330 recon: [2.772235190676126, 0.022385834478254533] lr: [0.00421875]
33 360 recon: [2.4486103671164083, 0.024204631408685267] lr: [0.00421875]
33 390 recon: [2.486370700612493, 0.02430118334196403] lr: [0.00421875]
34 0 recon: [2.650549122759404, 0.02376309377406522] lr: [0.00421875]
34 30 recon: [2.753974556583064, 0.02428476667053346] lr: [0.00421875]
34 60 recon: [2.7625895938757665, 0.02363663260268155] lr: [0.00421875]
34 90 recon: [2.4562067003777326, 0.024013670480750603] lr: [0.00421875]
34 120 recon: [2.5675577631528497, 0.024543085109678325] lr: [0.00421875]
34 150 recon: [2.768206151589685, 0.024235813797581007] lr: [0.00421875]
34 180 recon: [3.444482958656934, 0.023952131455582195] lr: [0.00421875]
34 210 recon: [2.5412971874233654, 0.02506288966597768

41 210 recon: [2.055775865427579, 0.027078390405903874] lr: [0.0031640625]
41 240 recon: [2.016726880186094, 0.027156906727969856] lr: [0.0031640625]
41 270 recon: [2.598749237720989, 0.025666562296850753] lr: [0.0031640625]
41 300 recon: [2.0513135237449625, 0.027268813685948936] lr: [0.0031640625]
41 330 recon: [2.192952689889035, 0.025287776353292286] lr: [0.0031640625]
41 360 recon: [1.9522951728361346, 0.027350106735095865] lr: [0.0031640625]
41 390 recon: [1.9914141780894496, 0.027097370080961424] lr: [0.0031640625]
42 0 recon: [2.159967065842649, 0.02658691292404579] lr: [0.0031640625]
42 30 recon: [2.235404464342339, 0.027157476861948356] lr: [0.0031640625]
42 60 recon: [2.2174275715212497, 0.026319751410746925] lr: [0.0031640625]
42 90 recon: [1.933540975447728, 0.026593123342300653] lr: [0.0031640625]
42 120 recon: [1.9959525041567214, 0.02678920199473884] lr: [0.0031640625]
42 150 recon: [2.1952341954937715, 0.02637584199863655] lr: [0.0031640625]
42 180 recon: [2.6129173857

49 150 recon: [2.125556648499165, 0.028579361603357268] lr: [0.002373046875]
49 180 recon: [2.4518205612638604, 0.027880070870601778] lr: [0.002373046875]
49 210 recon: [1.7935708225624896, 0.029298463420722693] lr: [0.002373046875]
49 240 recon: [1.730180221166941, 0.029118464549183418] lr: [0.002373046875]
49 270 recon: [2.2597051693174466, 0.027233635410942977] lr: [0.002373046875]
49 300 recon: [1.7976905630485, 0.028896254989471612] lr: [0.002373046875]
49 330 recon: [1.9171269389367382, 0.0267246822187377] lr: [0.002373046875]
49 360 recon: [1.704264967989345, 0.029023732006419396] lr: [0.002373046875]
49 390 recon: [1.836352621993125, 0.028541962621590453] lr: [0.002373046875]
Complete


In [147]:
# Fast weight estimation for all subjects

# With the trained codebook held fixed, estimate the per-timepoint code weights of
# every subject in closed form instead of by gradient descent.
with torch.no_grad():
    book = lrc()

    codes = []

    for sub in range(x.shape[0]):
        subcodes = []  # NOTE(review): never used
        # A[c,a,t]: output of page c alone applied to the signal at timepoint t
        A = torch.einsum('cab,bt->cat',book,x[sub,:,:-1])
        # B[a,t]: the next-timepoint signal that the weighted pages should reproduce
        B = x[sub,:,1:]
        # AA[t]: (ncodes x ncodes) Gram matrix of the page outputs at timepoint t
        AA = torch.einsum('cat,dat->cdt',A,A).permute(2,0,1)
        # AB[t]: inner products of each page output with the target at timepoint t
        AB = torch.einsum('cat,at->ct',A,B).permute(1,0)
        # Normal equations with 10*I added to the Gram matrix, solved per timepoint
        C,_,_,_ = torch.linalg.lstsq(AA+10*torch.eye(ncodes).float().cuda(),AB)
        codes.append(C)
        # Every 100th subject: rebuild the transfer function from the estimated weights
        # and report the reconstruction loss
        if sub % 100 == 0:
            tf = torch.einsum('cab,tc->abt',book,C)
            xhat = get_sub_xhat(tf, x[sub])
            loss = get_recon_loss(x[sub,:,1:], xhat)
            print(f'Finished {sub} {loss}')
    
# Stack to (subject, timepoint, code), then reorder to (subject, code, timepoint)
codes = torch.stack(codes)
codes = codes.permute(0,2,1)
print(codes.shape)

Finished 0 0.44291189312934875
Finished 100 0.6687223315238953
Finished 200 0.40013283491134644
Finished 300 0.6420438885688782
Finished 400 0.5725203156471252
Finished 500 0.6807302236557007
Finished 600 0.7467097043991089
Finished 700 0.6160550117492676
Finished 800 0.5327571034431458
torch.Size([847, 300, 123])


In [153]:
# Ridge regression of age on the time-averaged code weights.
# NOTE(review): this rebinds the global ntrain that the training cell set to 400.
ntrain = 500

codescuda = codes.float().cuda()
# Features: mean weight of each code over time, plus a constant column for the intercept
xps = torch.mean(codescuda, dim=-1)
xps = torch.cat([xps, torch.ones(xps.shape[0], 1).float().cuda()], dim=1)
xtr = xps[:ntrain]
xt = xps[ntrain:]

y = get_y(metadict, ['age'], subs)[0]
y_t = torch.from_numpy(y).float().cuda()
ytr = y_t[:ntrain]
yt = y_t[ntrain:]

# Regularised normal equations (penalty 1, intercept column included). The identity
# is sized from the feature matrix rather than hard-coded to 301 so it follows the
# codebook size.
w, _, _, _ = torch.linalg.lstsq(xtr.T@xtr + 1*torch.eye(xps.shape[1]).float().cuda(), xtr.T@ytr)

# Root-mean-squared error on the training and held-out subjects
print(torch.mean((ytr-xtr@w)**2)**0.5)
print(torch.mean((yt-xt@w)**2)**0.5)

tensor(30.3571, device='cuda:0')
tensor(32.8531, device='cuda:0')
