In [1]:
import pickle

# Per-subject metadata dictionary, indexed by subject id in later cells.
# NOTE: pickle can execute arbitrary code on load -- only use trusted files.
with open('../../PNC/AllSubjectsMeta.bin', 'rb') as f:
    meta = pickle.load(f)

# Selected subject ids (key 'allThreeYesWrat' -- presumably subjects with all
# three paradigms and a WRAT score) and the precomputed train/test groups.
with open('../../Work/Abstract/PaperBin/AllThreeSplit.bin', 'rb') as f:
    splits = pickle.load(f)
subids, groups = splits['allThreeYesWrat'], splits['groups']

print(len(subids))

593


In [2]:
import numpy as np

subidsNp = np.array(subids, copy=True)  # subject ids as an array; iterated when loading series

# Load timeseries

def loadSeries(prefix, para, idx):
    with open('{:}/{:}_fmri_power264/timeseries/{:}.bin'.format(prefix, para, idx), 'rb') as f:
        return pickle.load(f)

# One list per paradigm of per-subject time series, in subidsNp order.
# The paradigm name is both the directory prefix and the key into meta[subid].
rest_ts, nback_ts, emoid_ts = (
    [loadSeries('../../PNC', para, meta[subid][para]) for subid in subidsNp]
    for para in ('rest', 'nback', 'emoid')
)

print('Loading complete')

Loading complete


In [3]:
import numpy as np

def normalizeSubjects(subjects):
    """Z-score every row of each subject's time-series matrix, in place.

    Each array in ``subjects`` is modified in place; later cells rely on this
    because the list is not rebuilt. Each row is shifted to mean 0 and scaled
    to standard deviation 1 along axis 1. Rows are presumably ROIs and columns
    time points (TODO confirm against the data files).

    The index of any subject whose result contains NaN is printed, and so is
    the index of any subject whose result contains inf. For example, a
    constant (zero-variance) row produces 0/0 = NaN.

    Args:
        subjects: list of 2-D float numpy arrays.
    """
    for i, subj in enumerate(subjects):
        # keepdims=True yields (rows, 1) statistics that broadcast across the
        # columns; there is no need to materialise a full-size matrix with
        # `@ np.ones([1, ncols])` (the result is numerically identical).
        subj -= np.mean(subj, axis=1, keepdims=True)
        subj /= np.std(subj, axis=1, keepdims=True)
        if np.isnan(subj).any():
            print(i)
        if np.isinf(subj).any():
            print(i)

# Z-score each paradigm's series in place (normalizeSubjects mutates the arrays).
for series in (rest_ts, nback_ts, emoid_ts):
    normalizeSubjects(series)

print('Complete')

Complete


In [4]:
# Per-subject Pearson correlation matrices. np.corrcoef correlates rows, so
# each subject contributes a (nRow, nRow) matrix; these are stacked per paradigm.
rest_p, nback_p, emoid_p = (
    np.stack([np.corrcoef(ts) for ts in series])
    for series in (rest_ts, nback_ts, emoid_ts)
)

print(rest_p.shape)
print('Complete')

(593, 264, 264)
Complete


In [5]:
# Build one feature row per subject: [age in months, is-male, is-female].
# Gender is one-hot encoded; any value other than 'M' is counted as female.
X_all = []
males = 0
for subid in subidsNp:
    info = meta[subid]['meta']
    isMale = 1 if info['Gender'] == 'M' else 0
    males += isMale
    X_all.append(np.array([info['AgeInMonths'], isMale, 1 - isMale]))
X_all = np.vstack(X_all)
females = X_all.shape[0] - males

print(f'{males} {females}')
print(X_all[10:20])
print('Complete')

271 322
[[223   1   0]
 [190   0   1]
 [197   0   1]
 [145   1   0]
 [148   0   1]
 [142   0   1]
 [123   1   0]
 [176   1   0]
 [129   0   1]
 [173   1   0]]
Complete


In [8]:
import torch

def convertTorch(p, device='cuda'):
    """Flatten a stack of matrices to their strict upper triangles.

    Args:
        p: numpy array of shape (nSubj, n, m). Here these are the (264, 264)
            correlation matrices.
        device: target device for the result. The default 'cuda' matches the
            original hard-coded ``.cuda()``.

    Returns:
        float32 tensor of shape (nSubj, nUpper) on ``device``. Row i holds the
        entries of p[i] above the main diagonal, in the row-major order
        produced by ``torch.triu_indices``. nUpper is 34716 for 264x264.
    """
    t = torch.from_numpy(p).float()
    # Indices of the strict upper triangle, derived from the input shape and
    # computed once rather than once per subject.
    rows, cols = torch.triu_indices(t.shape[-2], t.shape[-1], offset=1)
    # One advanced-indexing gather replaces the per-subject loop + stack.
    return t[:, rows, cols].to(device)

# Edge vectors (strict upper triangle of each correlation matrix) per paradigm.
rest_p_t, nback_p_t, emoid_p_t = (convertTorch(p) for p in (rest_p, nback_p, emoid_p))

print(*(t.shape for t in (rest_p_t, nback_p_t, emoid_p_t)), sep='\n')
print('Complete')

torch.Size([593, 34716])
torch.Size([593, 34716])
torch.Size([593, 34716])
Complete


In [9]:
import csv

# WRAT scores from the PNC csv, indexed by the subject-id string in column 0.
# Columns 2 and 3 are stored (as strings) under 'raw' and 'std'.
# csv.reader is used instead of a hand-rolled split(',') so that quoted
# fields are parsed correctly; whitespace around fields is stripped.
wratDict = dict()

with open('../../PNC/wrat.csv', 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row
    for row in reader:
        if not row:
            continue  # tolerate blank lines instead of raising IndexError
        wratDict[row[0].strip()] = {'raw': row[2].strip(), 'std': row[3].strip()}

# The 'std' WRAT score for each subject, aligned with `subids`
# (raises KeyError if a subject is missing from the csv).
wrat = np.array([float(wratDict[str(key)]['std']) for key in subids])
wrat_t = torch.from_numpy(wrat).float().cuda()

print('Complete')

Complete


In [21]:
import torch.nn as nn
import torch.nn.functional as F

def makePoly(ps, degree=None):
    """Stack element-wise powers 1..degree of each subject's flattened features.

    Args:
        ps: tensor whose first dimension indexes subjects, e.g. (nSubj, nEdges).
        degree: number of powers to compute. If None (the default), the
            module-level ``nPoly`` is read at call time. This keeps existing
            ``makePoly(x)`` calls working, but passing it explicitly avoids the
            hidden dependency on a global defined in a later cell.

    Returns:
        Tensor of shape (nSubj, degree, nFeat) with
        ``out[i, j] == ps[i].flatten() ** (j + 1)``.
    """
    if degree is None:
        degree = nPoly  # backward-compatible fallback to the notebook global
    # Flatten everything after the subject axis; 1-D input gets a feature axis of size 1.
    flat = ps.flatten(start_dim=1) if ps.dim() > 1 else ps.unsqueeze(1)
    return torch.stack([flat ** (j + 1) for j in range(degree)], dim=1)

def arith(n):
    """Return the n-th triangular number, 1 + 2 + ... + n = n*(n+1)/2.

    Integer floor division keeps the result exact for arbitrarily large ints;
    the previous ``int(n*(n+1)/2)`` went through a float and lost precision
    above 2**53.
    """
    return int(n * (n + 1) // 2)

def mask(e):
    """Return a copy of square matrix ``e`` with its main diagonal set to zero.

    The input is not modified, and the result is built with out-of-place ops.
    """
    diagonal = e.diagonal()
    return e - torch.diag_embed(diagonal)

class MiniPgi(nn.Module):
    """Predict targets by averaging other subjects' labels over a learned subject graph.

    For each target, a learned projection (``masks[i]``) maps every subject's
    features to a latent vector ``y``. The subject-by-subject affinity
    ``y @ y.T`` is then processed in four steps:
    1. Columns of unlabelled subjects are zeroed.
    2. The diagonal is removed.
    3. A LeakyReLU is applied.
    4. Rows are normalised by their sum.
    The result is used to average the labels.

    Args:
        w: latent width, either a single int shared by all targets or a
            sequence with one width per target.
        nPara: number of paradigms (dim 1 of the input ``x``).
        nPoly: number of polynomial powers (dim 2 of ``x``).
        nTgts: number of targets. ``forward`` uses masks 0..2 for
            (age, gender, wrat), so this must be at least 3.
        dp: dropout probability applied to the input features.
        relu: LeakyReLU negative slope, either one number or one per target.
        device: device the projection weights are placed on. The default
            'cuda' matches the original hard-coded ``.cuda()``.
    """
    def __init__(self, w, nPara, nPoly, nTgts, dp=0.5, relu=0.1, device='cuda'):
        super().__init__()
        if isinstance(w, int):
            w = nTgts*[w]
        # nn.ParameterList (rather than a plain Python list) registers the
        # weights with the module, so parameters(), state_dict() and
        # .to()/.cuda() see them. Indexing and iteration behave as before.
        self.masks = nn.ParameterList()
        for i in range(nTgts):
            # arith(263) == 34716 == number of strict-upper-triangle entries
            # of a 264x264 matrix (the edge-vector length from convertTorch).
            shape = (nPara, nPoly, arith(263), w[i])
            # Near-constant 0.01 init plus small noise. Generated on the CPU and
            # then moved, like the original torch.randn(...).cuda(), so the
            # random stream is unchanged.
            init = 0.01*torch.ones(shape).float() + 0.001*torch.randn(shape).float()
            self.masks.append(nn.Parameter(init.to(device)))
        self.dp = nn.Dropout(p=dp)
        # ModuleList so the activations are registered submodules.
        self.relu = nn.ModuleList()
        for i in range(nTgts):
            rel = relu if isinstance(relu, (int, float)) else relu[i]
            self.relu.append(nn.LeakyReLU(negative_slope=rel))

    def getLatentsAndEdges(self, x, idx):
        """Project ``x`` with mask ``idx``.

        Returns ``(y, e)``, where ``y`` is the latent matrix of shape
        (nSubj, w[idx]) and ``e = y @ y.T`` is the affinity matrix of shape
        (nSubj, nSubj).
        """
        y = torch.einsum('abcd,bcde->ae', x, self.masks[idx])
        e = y@y.T
        return y, e

    def forward(self, x, age=None, gender=None, wrat=None):
        """Predict each target as an affinity-weighted average of labels.

        Args:
            x: features of shape (nSubj, nPara, nPoly, nEdges).
            age, gender, wrat: label matrices with one row per subject,
                e.g. (nSubj, 1) for age/wrat and (nSubj, 2) one-hot for gender.
                All three are required despite the ``None`` defaults.
                Rows that are entirely zero count as unlabelled: their columns
                in the affinity are zeroed, so those subjects contribute no
                label, but they still receive a prediction.

        Returns:
            ``(res, ss)``. ``res[i]`` is the prediction for target i and has
            the same shape as its label matrix. ``ss[i]`` holds the per-subject
            row sums of the activated affinity before normalisation, with
            shape (nSubj,).
        """
        x = self.dp(x)
        lbls = [age, gender, wrat]
        res = []
        ss = []
        for i,lbl in enumerate(lbls):
            _, e = self.getLatentsAndEdges(x, i)
            idcs = torch.logical_not(torch.any(lbl, dim=1))
            e[:,idcs] = 0
            e = self.relu[i](mask(e))
            s = torch.sum(e, dim=1)
            # Row-normalise. A zero row sum would give inf/NaN here; the
            # 1/sqrt(s) regulariser in the training loop pushes s away from 0.
            e = e/s.unsqueeze(1)
            res.append(e@lbl)
            ss.append(s)
        return res, ss

print('Complete')

Complete


In [23]:
ceLoss = torch.nn.CrossEntropyLoss()
mseLoss = torch.nn.MSELoss()
nEpochs = 10000
pPeriod = 200
# Early-stopping thresholds for (age MSE, gender CE, WRAT MSE). Training stops
# once all three training losses fall below them.
thresh = torch.Tensor((40,3.2e-1,20)).float().cuda()

nPoly = 1
para = [makePoly(nback_p_t), makePoly(emoid_p_t)]

rmse = []

for i in range(1):
    pgigcn = MiniPgi((10, 20, 20), len(para), nPoly, 3, 0.5, (0, 0, 0))
    optim = torch.optim.Adam(pgigcn.masks, lr=1e-4, weight_decay=1e-4)

    trainIdcs = groups[i][0]
    testIdcs = groups[i][1]

    # --- Training: only training subjects are given to the model ---
    X = torch.stack(para, dim=1)
    X = X[trainIdcs]
    Y = torch.from_numpy(X_all[trainIdcs]).float().cuda()

    gen = Y[:,1:]
    wrt = wrat_t[trainIdcs].unsqueeze(1)
    age = Y[:,0].unsqueeze(1)

    for epoch in range(nEpochs):
        optim.zero_grad()
        res, ss = pgigcn(X, age=age, gender=gen, wrat=wrt)
        loss0 = mseLoss(res[0], age)
        # NOTE(review): with relu=(0, 0, 0), each row of res[1] is a weighted
        # average of one-hot rows (probabilities). CrossEntropyLoss treats them
        # as logits, so loss1 cannot drop below log(1 + e**-1) ~= 0.313.
        # This is presumably why thresh[1] is 0.32.
        loss1 = ceLoss(res[1], gen)
        loss2 = mseLoss(res[2], wrt)
        loss = torch.stack([loss0, loss1, loss2])
        # Penalise small affinity row sums (the term grows like 1/sqrt(s) as s -> 0).
        lossR = 0
        for s in ss:
            lossR += 100*torch.sum((1/s)**0.5)
        # NOTE(review): lossR broadcasts over the 3 task losses, so it enters
        # the objective 3 times. This is kept to preserve training dynamics.
        torch.sum(loss + lossR).backward()
        optim.step()
        if (epoch % pPeriod == 0 or epoch == nEpochs-1):
            print(f'epoch {epoch} loss={(float(loss0), float(loss1), float(loss2), float(lossR))}')
        if torch.all(loss < thresh):
            print('Early stopping')
            break

    print('Finished training')

    pgigcn.eval()

    # --- Evaluation: every subject is a node, but test labels are zeroed so
    # their rows count as unlabelled and they contribute no label ---
    X = torch.stack(para, dim=1)
    Y = torch.from_numpy(X_all).float().cuda()

    gen = Y[:,1:]
    wrt = wrat_t.unsqueeze(1)
    age = Y[:,0].unsqueeze(1)

    gen0 = gen.clone().detach()
    gen0[testIdcs] = 0
    wrt0 = wrt.clone().detach()
    wrt0[testIdcs] = 0
    age0 = age.clone().detach()
    age0[testIdcs] = 0

    # Evaluation needs no gradients, so skip building the autograd graph.
    with torch.no_grad():
        res, ss = pgigcn(X, age=age0, gender=gen0, wrat=wrt0)
        loss0 = mseLoss(res[0][testIdcs].detach(), age[testIdcs]).cpu().numpy()**0.5
        frac1 = torch.sum(torch.argmax(res[1].detach(), dim=1)[testIdcs]
                         == torch.argmax(gen[testIdcs], dim=1))/testIdcs.shape[0]
        loss2 = mseLoss(res[2][testIdcs].detach(), wrt[testIdcs]).cpu().numpy()**0.5

    rmse.append((float(loss0), float(frac1), float(loss2)))
    print(i, end=' ')
    print(rmse[-1])

epoch 0 loss=(1591.887939453125, 0.6936230063438416, 257.26080322265625, 9.615461349487305)
epoch 200 loss=(1572.80224609375, 0.6936047077178955, 257.26019287109375, 5.83980131149292)
epoch 400 loss=(1569.08837890625, 0.6936105489730835, 257.260498046875, 4.937052249908447)
epoch 600 loss=(1060.85595703125, 0.6936168074607849, 257.26080322265625, 16.621685028076172)
epoch 800 loss=(436.553466796875, 0.6936041712760925, 257.25714111328125, 32.35264205932617)
epoch 1000 loss=(261.0439147949219, 0.6936041712760925, 257.2572326660156, 38.61881637573242)
epoch 1200 loss=(205.9494171142578, 0.693602442741394, 257.2554626464844, 40.3409309387207)
epoch 1400 loss=(171.29598999023438, 0.6935980916023254, 257.2557373046875, 40.86528396606445)
epoch 1600 loss=(146.37510681152344, 0.6935871839523315, 257.2554016113281, 41.766029357910156)
epoch 1800 loss=(129.70065307617188, 0.6935849189758301, 257.2554016113281, 41.91110610961914)
epoch 2000 loss=(115.4173355102539, 0.6935838460922241, 257.252655