In [1]:
# Get FC all tasks in ImageNomer directory

from pathlib import Path
import numpy as np
import re

snpsdir = '../../ImageNomer/data/anton/cohorts/test/fc/'

fcs = dict(rest=dict(), nback=dict(), emoid=dict())

for f in Path(snpsdir).iterdir():
    mobj = re.match('([0-9]+)_task-([a-z]+)_fc.npy', f.name)
    if not mobj:
        continue
    fc = np.load(f'{snpsdir}/{f.name}')
    sub = mobj.group(1)
    mod = mobj.group(2) 
    fcs[mod][sub] = fc

for mod in fcs:
    print(len(list(fcs[mod].keys())))

830
830
830


In [12]:
from sklearn.decomposition import PCA
from natsort import natsorted

subs = natsorted(fcs['rest'].keys())

task = 'emoid'

x = []
for sub in subs:
    fc = fcs[task][sub]
    x.append(fc)
x = np.stack(x)
pca = PCA()
xt = pca.fit_transform(x)
print(xt.shape)
print(pca.components_.shape)

dcomp = f'../../ImageNomer/data/anton/cohorts/test/decomp/{task}pca-comps'
dws = f'../../ImageNomer/data/anton/cohorts/test/decomp/{task}pca-weights'

for i,sub in enumerate(subs):
    c = pca.components_[i]
    w = xt[i]
    np.save(f'{dcomp}/{task}pca_comp-{i}.npy', c)
    np.save(f'{dws}/{sub}_comp-{task}pca_weights.npy', w)
    
print('Done')

(830, 830)
(830, 34716)
Done


In [4]:
print(pca.components_.shape)

(830, 34716)


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import itertools

mseLoss = nn.MSELoss()

# Use make_dict() followed by fit_weights()

'''
LowRankCodes - dictionary components (rank-N matrices)
For now PSD (symmetric) components only
'''
class LowRankCodes(nn.Module):
    '''
    ranks: array of rank for each codebook matrix
    '''
    def __init__(self, ranks):
        super(LowRankCodes, self).__init__()
        self.As = []
        for rank in ranks:
            A = nn.Parameter(1e-2*torch.randn(rank,264).float().cuda())
            self.As.append(A)
        self.As = nn.ParameterList(self.As)

    '''
    Generate codebook
    '''
    def forward(self):
        book = []
        for A in self.As:
            AA = A.T@A
            book.append(AA)
        return torch.stack(book)
    
'''
LowRankWeights - weights for the LowRankCodes codebook entries
'''
class LowRankWeights(nn.Module):
    '''
    ncodes: number of pages in the codebook
    Xs: list of inputs to LowRankWeights of size [(nsubs, nrois, nt)...]
    subids: id of each subject in modlist (optional)
    '''
    def __init__(self, ncodes, Xs, subids=None):
        super(LowRankWeights, self).__init__()
        self.ncodes = ncodes
        self.params = []
        for i in range(len(Xs)):
            nsubs = Xs[i].shape[0]
            nt = Xs[i].shape[-1]
            params = nn.Parameter(1e-2*torch.rand(nsubs, self.ncodes, nt).float().cuda())
            self.params.append(params)
        self.params = nn.ParameterList(self.params)
        self.subids = subids

    '''
    Get estimated instantaneous FC from book
    '''
    def forward(self, sub, book, mod):
        w = self.params[mod][sub]
        return torch.einsum('pt,pab->abt', w, book) # PUT BACK leaky relu

def get_recon_loss(x, xhat):
    return mseLoss(xhat, x)

def get_smooth_loss_fc(xhat):
    before = xhat[:,:,:-1]
    after = xhat[:,:,1:]
    return torch.mean((before-after)**2)

def get_mag_loss(lrc):
    loss = [torch.mean((A-0.01)**2) for A in lrc.As]
    return sum(loss)/len(loss)

def get_sub_fc(subts):
    return torch.einsum('at,bt->abt',subts,subts)

def default_or_custom(kwargs, field, default):
    if field not in kwargs:
        kwargs[field] = default

def make_dict(Xs, ranks=400*[1], **kwargs):
    default_or_custom(kwargs, 'nbatch', 20)
    default_or_custom(kwargs, 'smooth_mult', 0.1)
    default_or_custom(kwargs, 'nepochs', 50)
    default_or_custom(kwargs, 'pperiod', 5)
    default_or_custom(kwargs, 'subids', None)
    default_or_custom(kwargs, 'lr', 1e-2)
    default_or_custom(kwargs, 'l2', 0)
    default_or_custom(kwargs, 'patience', 20)
    default_or_custom(kwargs, 'factor', 0.75)
    default_or_custom(kwargs, 'eps', 1e-7)
    default_or_custom(kwargs, 'verbose', False)

    pperiod = kwargs['pperiod']
    nepochs = kwargs['nepochs']
    nbatch = kwargs['nbatch']
    smooth_mult = kwargs['smooth_mult']
    ncodes = len(ranks)
    modlist = [dict(nsubs=X.shape[0], nt=X.shape[-1]) for X in Xs]

    lrc = LowRankCodes(ranks)
    lrw = LowRankWeights(ncodes, Xs, kwargs['subids'])

    optim = torch.optim.Adam(
        itertools.chain(lrc.parameters(), lrw.parameters()), 
        lr=kwargs['lr'], 
        weight_decay=kwargs['l2'])
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, 
        patience=kwargs['patience'], 
        factor=kwargs['factor'], 
        eps=kwargs['eps'])

    for epoch in range(nepochs):
        for modidx in range(len(Xs)):
            ntrain = Xs[modidx].shape[0]
            for bstart in range(0,ntrain,nbatch):
                bend = bstart+nbatch
                if bend > ntrain:
                    bend = ntrain
                optim.zero_grad()
                book = lrc()
                recon_loss = 0
                smooth_loss_fc = 0
                for subidx in range(bstart, bend):
                    xsub = get_sub_fc(Xs[modidx][subidx])
                    xhat = lrw(subidx, book, modidx)   
                    recon_loss += get_recon_loss(xsub, xhat)
                    smooth_loss_fc += smooth_mult*get_smooth_loss_fc(xhat)
                recon_loss /= (bend-bstart)
                smooth_loss_fc /= (bend-bstart)
                loss = recon_loss+smooth_loss_fc
                loss.backward()
                optim.step()
                sched.step(loss)

        if not kwargs['verbose']:
            continue
        if epoch % pperiod == 0 or epoch == nepochs-1:
            print(f'{epoch} {bstart} recon: {[float(ls)**0.5 for ls in [recon_loss, smooth_loss_fc]]} '
                f'lr: {sched._last_lr}')

    optim.zero_grad()
    if not kwargs['verbose']:
        print('Complete')

    return lrc, lrw

def fit_weights(low_rank_codes, Xs, **kwargs):
    default_or_custom(kwargs, 'nepochs', 500)
    default_or_custom(kwargs, 'pperiod', 50)
    default_or_custom(kwargs, 'lr', 1e-1)
    default_or_custom(kwargs, 'l1', 0)
    default_or_custom(kwargs, 'l2', 1e-5)
    default_or_custom(kwargs, 'patience', 10)
    default_or_custom(kwargs, 'factor', 0.75)
    default_or_custom(kwargs, 'eps', 1e-7)
    default_or_custom(kwargs, 'verbose', False)
    
    nepochs = kwargs['nepochs']
    pperiod = kwargs['pperiod']

    book = low_rank_codes()
    A = book.reshape(book.shape[0], -1).permute(1,0).detach()
    AA = A.T@A
    ws = []

    for X in Xs:
        AB = []
        for sub in range(X.shape[0]):
            B = get_sub_fc(X[sub]).reshape(-1, X.shape[-1])
            AB.append(A.T@B)
        AB = torch.stack(AB)

        w = nn.Parameter(torch.rand(AB.shape[0],AA.shape[1],AB.shape[-1]).float().cuda())
        ws.append(w)

        optim = torch.optim.Adam(
            [w], 
            lr=kwargs['lr'], 
            weight_decay=kwargs['l2'])
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, 
            patience=kwargs['patience'], 
            factor=kwargs['factor'], 
            eps=kwargs['eps'])

        for epoch in range(nepochs):
            optim.zero_grad()
            ABhat = torch.matmul(AA.detach(),w) # PUT BACK LEAKY RELU
            pred_loss = mseLoss(ABhat, AB.detach())**0.5
            l1_loss = kwargs['l1']*torch.mean(torch.abs(w))
            loss = pred_loss+l1_loss
            loss.backward()
            optim.step()
            sched.step(loss)
            if not kwargs['verbose']:
                continue
            if epoch % pperiod == 0 or epoch == nepochs-1:
                print(f'{epoch} {[float(ls) for ls in [pred_loss, l1_loss]]} {sched._last_lr}')

        optim.zero_grad()
        if not kwargs['verbose']:
            print('Complete')

    return ws

print('Done')

Done


In [2]:
# Using newly preprocessed subjects

import pickle

metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [3]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np
from natsort import natsorted

def get_subs(allts, metadict, tasks, paras):
    # Get subs for all paras
    for i,para in enumerate(paras):
        tmpset = set([int(sub[4:]) for sub in allts[para].keys()])
        if i == 0:
            paraset = tmpset
        else:
            paraset = paraset.intersection(tmpset)
    # Get subs for all tasks
    for i,task in enumerate(tasks):
        tmpset = set([sub for sub in metadict[task].keys()])
        if i == 0:
            taskset = tmpset
        else:
            taskset = paraset.intersection(tmpset)
    # Remove QC failures
    allsubs = taskset.intersection(paraset)
    for badsub in metadict['failedqc']:
        try:
            allsubs.remove(int(badsub[4:]))
        except:
            pass
    return natsorted(list(allsubs))

def get_X(allts, paras, subs):
    X = []
    for para in paras:
        pX = [allts[para][f'sub-{sub}'] for sub in subs]
        pX = np.stack(pX)
        X.append(pX)
    return X

def get_y(metadict, tasks, subs):
    y = []
    for task in tasks:
        if task == 'age' or task == 'wrat':
            var = [metadict[task][sub] for sub in subs]
            var = np.array(var)
            y.append(var)
        if task == 'sex':
            maleness = [metadict[task][sub] == 'M' for sub in subs]
            maleness = np.array(maleness)
            sex = np.stack([maleness, 1-maleness], axis=1)
            y.append(sex)
    return y

subs = get_subs(allts, metadict, ['wrat'], ['rest', 'nback', 'emoid'])
print(len(subs))

X = get_X(allts, ['rest', 'nback', 'emoid'], subs)
print(X[0].shape)

830
(830, 264, 124)


In [4]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = [cutoff[0] / nyq, cutoff[1] / nyq]
    b, a = signal.butter(order, normal_cutoff, btype='band', analog=False)
    return b, a

def butter_bandpass_filter(data, cutoff, fs, order=5):
    b, a = butter_bandpass(cutoff, fs, order=order)
    y = signal.filtfilt(b, a, data)
    return y

tr = 3

def filter_design_ts(X):
    Xs = []
    for i in range(X.shape[0]):
        nX = butter_bandpass_filter(X[i], [0.01, 0.15], 1/tr)
        Xs.append(nX)
    return np.stack(Xs)

def ts_to_flat_fc(X):
    p = np.corrcoef(X)
    a,b = np.triu_indices(p[0].shape[0], 1)
    p = p[a,b]
    return p

# p = [np.stack([ts_to_flat_fc(ts) for ts in filter_design_ts(Xp)]) for Xp in X]
Xf = [filter_design_ts(Xp) for Xp in X]
Xs = [tsmod/np.linalg.norm(tsmod, axis=(-1), keepdims=True) for tsmod in Xf]
print(Xs[0].shape)
print(Xs[1].shape)

(830, 264, 124)
(830, 264, 231)


In [91]:
Xt = []
for Xp in Xs[2:3]:
    idcs = torch.randperm(830)
    Xp = torch.from_numpy(Xp).float().cuda()
    Xt.append(Xp[idcs][:400])
    
print('Starting')

lrc, lrw = make_dict(Xt, ranks=400*[1], nepochs=35, verbose=True)

Starting
0 380 recon: [0.00671566538138334, 2.8938422445752883e-06] lr: [0.01]
5 380 recon: [0.00589714166271204, 0.0008019399140955256] lr: [0.0075]
10 380 recon: [0.005512465897948504, 0.0011852861145529663] lr: [0.0075]
15 380 recon: [0.005258004446623459, 0.0011447641199750335] lr: [0.00421875]
20 380 recon: [0.005105440741235429, 0.0012158229252871432] lr: [0.00421875]
25 380 recon: [0.004986756161166922, 0.0012774127056460307] lr: [0.00421875]
30 380 recon: [0.005085004332884545, 0.0016409427823867795] lr: [0.0031640625]
34 380 recon: [0.004879976544251507, 0.001316699278656137] lr: [0.0017797851562500002]


In [6]:
Xt = []
for Xp in Xs[2:3]:
    Xp = torch.from_numpy(Xp).float().cuda()
    Xt.append(Xp)

w = fit_weights(lrc, Xt, verbose=True)

0 [198.72190856933594, 0.0] [0.1]
50 [0.4534512758255005, 0.0] [0.1]
100 [0.38452914357185364, 0.0] [0.1]


KeyboardInterrupt: 

In [92]:
def vstk2mat(vstk):
    a,b = torch.triu_indices(264,264,offset=1)
    return vstk[:,a,b]

def fit_w(lrc, x):
    book = lrc()
    A = vstk2mat(book)
    A = A.T.detach()
    AA = A.T@A
    ws = []
    rs = []
    
    # One subject at a time
    for sub in range(x.shape[0]):
        B = vstk2mat(get_sub_fc(x[sub]).permute(2,0,1))
        B = B.T
        BA = A.T@B
        I = 0.1*torch.eye(A.shape[1]).float().cuda()
        w,_,_,_ = torch.linalg.lstsq((AA+I).detach(),BA.detach())
        rmse = torch.mean((B-A@w)**2)**0.5
        rmse = float(rmse)
        rs.append(rmse)
        ws.append(w.detach())
        
    print(np.mean(rmse))
    return ws

Xt = []
for Xp in Xs[2:3]:
    Xp = torch.from_numpy(Xp).float().cuda()
    Xt.append(Xp)

ws = fit_w(lrc, Xt[0])
w = torch.mean(torch.stack(ws), axis=-1)
print(w.shape)

0.004749314859509468
torch.Size([830, 400])


In [98]:
import os
from pathlib import Path

task = 'emoid'
basedir = '../../ImageNomer/data/anton/cohorts/test/decomp'
wdir = f'{basedir}/{task}dd-weights'
cdir = f'{basedir}/{task}dd-comps'
    
book = vstk2mat(lrc())
b = torch.mean(torch.abs(w), dim=0).detach().cpu().numpy()
c = torch.mean(torch.abs(book), dim=1).detach().cpu().numpy()
idcs = np.argsort(b*c)[::-1]
# print((b*c)[idcs])

cc = book.detach().cpu().numpy()[idcs,:]
ww = w.detach().cpu().numpy()[:,idcs]

if not Path(wdir).exists():
    os.mkdir(wdir)
    
if not Path(cdir).exists():
    os.mkdir(cdir)

for i,sub in enumerate(subs):
    fname = f'{wdir}/{sub}_comp-{task}dd_weights.npy'
    np.save(fname, ww[i])

print('Done 1')
    
for i in range(book.shape[0]):
    fname = f'{cdir}/{task}dd_comp-{i}.npy'
    np.save(fname, cc[i])

print('Done 2')

Done 1
Done 2


In [97]:
ntrain = 700
rs = []

for i in range(100):
    idcs = torch.randperm(830)

    y = get_y(metadict, ['wrat'], subs)[0]
    y = y[idcs]
    y = torch.from_numpy(y).float().cuda()
    ytr = y[:ntrain]
    yt = y[ntrain:]

    mu = torch.mean(ytr)

    ytr = ytr - mu
    yt = yt - mu

    x = w[idcs]
#     x = torch.mean(x, axis=-1)
    xtr = x[:ntrain]
    xt = x[ntrain:]
    
    mux = torch.mean(xtr, axis=0)
    
#     xtr = xtr - mux
#     xt = xt - mux

    w2,_,_,_ = torch.linalg.lstsq(xtr.T@xtr+1e-2*torch.eye(xtr.shape[-1]).float().cuda(), xtr.T@ytr)
    yhat = xt@w2
#     print(torch.mean(yt**2)**0.5)
    rmse = torch.mean((yhat-yt)**2)**0.5
    rmse = float(rmse)
    rs.append(rmse)
    
print('---')
print(np.mean(rs))

---
15.307805767059326
