In [1]:
# It might be faster to construct codebook page by page, and update the weights periodically,
# Rather than have the codebook and weights update in one big network

from google.colab import drive
drive.mount('/content/drive')

ModuleNotFoundError: No module named 'google.colab'

In [1]:
# Using newly preprocessed subjects

import pickle

# metadictname = '/content/drive/MyDrive/Tulane/Research/PNC/PNC_agesexwrat.pkl'
# alltsname = '/content/drive/MyDrive/Tulane/Research/PNC/PNC_PowerTS_float2.pkl'
metadictname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_agesexwrat.pkl'
alltsname = '/home/anton/Documents/Tulane/Research/PNC_Good/PNC_PowerTS_float2.pkl'

with open(metadictname, 'rb') as f:
    metadict = pickle.load(f)

with open(alltsname, 'rb') as f:
    allts = pickle.load(f)
    
print(list(metadict.keys()))
print(list(allts.keys()))
print('Complete')

['age', 'sex', 'wrat', 'missingage', 'missingsex', 'missingwrat', 'failedqc']
['emoid', 'nback', 'rest']
Complete


In [2]:
'''
Get subjects that have all tasks and paras specified
Functions for creating independent and response variables
'''

import numpy as np

def get_subs(allts, metadict, tasks, paras):
    # Get subs for all paras
    for i,para in enumerate(paras):
        tmpset = set([int(sub[4:]) for sub in allts[para].keys()])
        if i == 0:
            paraset = tmpset
        else:
            paraset = paraset.intersection(tmpset)
    # Get subs for all tasks
    for i,task in enumerate(tasks):
        tmpset = set([sub for sub in metadict[task].keys()])
        if i == 0:
            taskset = tmpset
        else:
            taskset = paraset.intersection(tmpset)
    # Remove QC failures
    allsubs = taskset.intersection(paraset)
    for badsub in metadict['failedqc']:
        try:
            allsubs.remove(int(badsub[4:]))
        except:
            pass
    return allsubs

def get_X(allts, paras, subs):
    X = []
    for para in paras:
        pX = [allts[para][f'sub-{sub}'] for sub in subs]
        pX = np.stack(pX)
        X.append(pX)
    return X

def get_y(metadict, tasks, subs):
    y = []
    for task in tasks:
        if task == 'age' or task == 'wrat':
            var = [metadict[task][sub] for sub in subs]
            var = np.array(var)
            y.append(var)
        if task == 'sex':
            maleness = [metadict[task][sub] == 'M' for sub in subs]
            maleness = np.array(maleness)
            sex = np.stack([maleness, 1-maleness], axis=1)
            y.append(sex)
    return y

subs = get_subs(allts, metadict, ['age'], ['rest', 'nback', 'emoid'])
print(len(subs))

X = get_X(allts, ['rest', 'nback', 'emoid'], subs)
print(X[0].shape)

847
(847, 264, 124)


In [3]:
# TS to condensed FC

from scipy import signal

def butter_bandpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = [cutoff[0] / nyq, cutoff[1] / nyq]
    b, a = signal.butter(order, normal_cutoff, btype='band', analog=False)
    return b, a

def butter_bandpass_filter(data, cutoff, fs, order=5):
    b, a = butter_bandpass(cutoff, fs, order=order)
    y = signal.filtfilt(b, a, data)
    return y

tr = 1.83
N = X[0].shape[0]

def filter_design_ts(X):
    Xs = []
    for i in range(X.shape[0]):
        nX = butter_bandpass_filter(X[i], [tr/20*N, 0.8*N], 2*N)
        Xs.append(nX)
    return np.stack(Xs)

def ts_to_flat_fc(X):
    p = np.corrcoef(X)
    a,b = np.triu_indices(p[0].shape[0], 1)
    p = p[a,b]
    return p

p = [[np.corrcoef(ts) for ts in filter_design_ts(Xp)] for Xp in X]
print(len(p))
print(p[0][0].shape)

3
(264, 264)


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def mask(e):
    return e - torch.diag(torch.diag(e.detach()))

x = [[mask(torch.from_numpy(pp).float().cuda()) for pp in para] for para in p]
x = torch.stack([torch.stack(para) for para in x], dim=1)
xtr = x[:600]
xt = x[600:]

print(x.shape)
print(xtr.shape)
print(xt.shape)

torch.Size([847, 3, 264, 264])
torch.Size([600, 3, 264, 264])
torch.Size([247, 3, 264, 264])


In [None]:
import itertools

class LowRankCodes(nn.Module):
    def __init__(self, ranks):
        super(LowRankCodes, self).__init__()
        self.book = []
        for rank in ranks:
            self.book.append(nn.Parameter(1e-5*torch.randn(rank,264).float().cuda()))
        self.book = nn.ParameterList(self.book)
        self.page = 0

    def turn_page(self):
        if self.page < len(self.book)-1:
            self.page += 1
        
    def is_finished(self):
        return self.page == len(self.book)-1

    def get_book(self):
        book = []
        for page in self.book:
            book.append(mask(page.T@page))
        return torch.stack(book)

    def forward(self):
        return mask(self.book[self.page].T@self.book[self.page])
    
class LowRankWeights(nn.Module):
    def __init__(self, nsubs, nranks):
        super(LowRankWeights, self).__init__()
        self.w = nn.Parameter(torch.ones(nsubs, nranks).float().cuda())
        
    def forward(self, book, pagenum):
        w = F.relu(self.w[:,:pagenum+1])
        return torch.einsum('nr,rab->nab',w,book[:pagenum+1])
    
nEpochs = 30000
pPeriod = 100
nRanks = 100
lr = 1e-1
lr_thresh = 1e-2

mseLoss = nn.MSELoss()
    
lrc = LowRankCodes(nRanks*[1])
lrw = LowRankWeights(xtr.shape[0], nRanks)

optim = torch.optim.Adam(
    [
        {'params': lrc.parameters(), 'lr': lr},
        {'params': lrw.parameters(), 'lr': lr} 
    ],
    weight_decay=0
)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.9, eps=1e-7)

def reset_lr(optim, lr):
    for i, param_group in enumerate(optim.param_groups):
        param_group['lr'] = lr

book = torch.zeros(nRanks,264,264).float().cuda()
    
for epoch in range(nEpochs):
    optim.zero_grad()
    book = book.detach()
    book[lrc.page] = lrc()
    loss = []
    for ip in range(1):
        scratch = lrw(book, lrc.page)
        xloss = mseLoss(scratch, xtr[:,ip])
        loss.append(xloss)
    sum(loss).backward()
    optim.step()
    sched.step(xloss)
    if epoch % pPeriod == 0 or epoch == nEpochs-1:
        print(f'{epoch} {[float(ploss) for ploss in loss]} page: {lrc.page} lr: {sched._last_lr}')
    if sched._last_lr[0] < lr_thresh:
        # print('Turned')
        lrc.turn_page()
        reset_lr(optim, lr)
    if lrc.is_finished():
        print('Early finish')
        break
        
print('Complete')

0 [0.1620311588048935] page: 0 lr: [0.1, 0.1]
100 [0.04182826727628708] page: 0 lr: [0.08100000000000002, 0.08100000000000002]
200 [0.04182606562972069] page: 0 lr: [0.03138105960900001, 0.03138105960900001]
300 [0.04182606562972069] page: 0 lr: [0.01215766545905694, 0.01215766545905694]


In [None]:
torch.sum(lrw.w[:,:100] > 0)

tensor(40181, device='cuda:0')