In [1]:
import pickle

# Load meta dict: subject id -> metadata record.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../../PNC/AllSubjectsMeta.bin', 'rb') as metaFile:
    meta = pickle.load(metaFile)

# Load rest subject ids and splits
with open('../../Work/Abstract/PaperBin/AllThreeSplit.bin', 'rb') as splitFile:
    splits = pickle.load(splitFile)

subids = splits['allThreeYesWrat']   # subject ids used throughout the notebook
groups = splits['groups']            # per fold: groups[i][0] train idx, groups[i][1] test idx

print(len(subids))

593


In [2]:
import numpy as np

# Subject ids as a (copied) NumPy array.
subidsNp = np.array(subids, copy=True)

# Load timeseries

def loadSeries(prefix, para, idx):
    """Unpickle one timeseries file.

    The file read is ``<prefix>/<para>_fmri_power264/timeseries/<idx>.bin``.

    NOTE(review): pickle.load can execute arbitrary code; only use on trusted data.
    """
    path = f'{prefix}/{para}_fmri_power264/timeseries/{idx}.bin'
    with open(path, 'rb') as handle:
        series = pickle.load(handle)
    return series

# Load every subject's timeseries for each paradigm (rest, then nback, then
# emoid); meta[subid][para] supplies the file index of that subject's recording.
_seriesByPara = {
    para: [loadSeries('../../PNC', para, meta[subid][para]) for subid in subidsNp]
    for para in ('rest', 'nback', 'emoid')
}
rest_ts = _seriesByPara['rest']
nback_ts = _seriesByPara['nback']
emoid_ts = _seriesByPara['emoid']
del _seriesByPara

print('Loading complete')

Loading complete


In [3]:
import numpy as np

def normalizeSubjects(subjects):
    """Z-score every row of each subject's array, in place.

    Each row is centred to zero mean, then divided by its (population, ddof=0)
    standard deviation along axis 1. The arrays inside ``subjects`` are
    modified directly; nothing is returned. The index of any subject whose
    result contains NaN is printed, and likewise for Inf (so an index can be
    printed twice).
    """
    for idx, series in enumerate(subjects):
        # keepdims=True lets the per-row statistics broadcast across columns.
        series -= series.mean(axis=1, keepdims=True)
        series /= series.std(axis=1, keepdims=True)
        if np.isnan(series).any():
            print(idx)
        if np.isinf(series).any():
            print(idx)

# Z-score each paradigm's timeseries in place.
for paradigm_ts in (rest_ts, nback_ts, emoid_ts):
    normalizeSubjects(paradigm_ts)

print('Complete')

Complete


In [4]:
# Calculate pearson matrices: one np.corrcoef matrix per subject (rows of the
# timeseries are the variables), stacked along a new first axis.
def _pearsonStack(seriesList):
    return np.stack([np.corrcoef(series) for series in seriesList])

rest_p = _pearsonStack(rest_ts)
nback_p = _pearsonStack(nback_ts)
emoid_p = _pearsonStack(emoid_ts)

print(rest_p.shape)
print('Complete')

(593, 264, 264)
Complete


In [5]:
# Create feature vectors (right now just age in months, maleness, and
# femaleness) and count subjects of each gender.
# NOTE(review): any Gender value other than 'M' (including a missing one) is
# coded as female here — confirm that is intended.
males = 0
females = 0

rows = []
for subid in subidsNp:
    record = meta[subid]['meta']
    maleness = int(record['Gender'] == 'M')
    femaleness = 1 - maleness
    rows.append(np.array([record['AgeInMonths'], maleness, femaleness]))
    males += maleness
    females += femaleness
X_all = np.vstack(rows)

print(f'{males} {females}')
print(X_all[10:20])
print('Complete')

271 322
[[223   1   0]
 [190   0   1]
 [197   0   1]
 [145   1   0]
 [148   0   1]
 [142   0   1]
 [123   1   0]
 [176   1   0]
 [129   0   1]
 [173   1   0]]
Complete


In [6]:
import torch

def convertTorch(p, device='cuda'):
    """Extract the strict upper triangle of each matrix in a stack.

    Args:
        p: NumPy array of shape (n_subjects, R, C); each p[i] is one matrix
            (e.g. a correlation matrix).
        device: torch device for the result (default 'cuda', as before).

    Returns:
        float32 tensor of shape (n_subjects, n_edges) on ``device``, where
        row i holds p[i] entries above the main diagonal in row-major order
        (the order produced by ``torch.triu_indices``).
    """
    t = torch.from_numpy(p).float()
    # Index sizes are taken from the input instead of a hard-coded 264, and
    # computed once rather than once per subject.
    rows, cols = torch.triu_indices(t.shape[1], t.shape[2], offset=1)
    # Gather the same (row, col) pairs from every subject in one step.
    return t[:, rows, cols].to(device)

# Upper-triangle edge vectors for each paradigm (converted in rest, nback, emoid order).
rest_p_t, nback_p_t, emoid_p_t = map(convertTorch, (rest_p, nback_p, emoid_p))

for edges in (rest_p_t, nback_p_t, emoid_p_t):
    print(edges.shape)
print('Complete')

torch.Size([593, 34716])
torch.Size([593, 34716])
torch.Size([593, 34716])
Complete


In [7]:
import csv

# Map subject id (first CSV column, as a string) -> WRAT scores.
# Column index 2 is stored as 'raw' and column index 3 as 'std'; both stay strings.
wratDict = dict()

# csv.reader (instead of str.split(',')) correctly handles quoted fields that
# contain commas, and newline='' is the mode the csv module expects.
with open('../../PNC/wrat.csv', 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row
    for row in reader:
        if not row:
            continue  # tolerate blank lines instead of raising IndexError
        wratDict[row[0]] = {'raw': row[2], 'std': row[3]}

# Standardised WRAT score for each subject, in the same order as subids.
# A subject missing from the CSV raises KeyError here.
wrat = np.array([float(wratDict[str(key)]['std']) for key in subids])
wrat_t = torch.from_numpy(wrat).float().cuda()

print('Complete')

Complete


In [24]:
import torch.nn as nn
import torch.nn.functional as F

def makePoly(ps, n=None):
    """Stack element-wise powers 1..n of each flattened sample.

    Args:
        ps: tensor whose first dimension indexes samples (subjects).
        n: number of powers. When omitted, falls back to the notebook-level
            global ``nPoly`` (the original, implicit behaviour), which must be
            defined before the call.

    Returns:
        Tensor of shape (ps.shape[0], n, prod(ps.shape[1:])); entry
        [:, j, :] holds the flattened inputs raised to the power j+1.
    """
    if n is None:
        n = nPoly
    flat = ps.reshape(ps.shape[0], -1)
    # One vectorised power per degree instead of a Python loop over samples.
    return torch.stack([flat ** (j + 1) for j in range(n)], dim=1)

def arith(n):
    """Return the n-th triangular number, n*(n+1)/2, as an exact int.

    n*(n+1) is always even, so integer floor division is exact. Unlike the
    previous float division, it does not lose precision for large n.
    """
    return n * (n + 1) // 2

def mask(e):
    """Return a copy of the square matrix ``e`` with its main diagonal zeroed.

    Done by subtracting the diagonal (not in place), so gradients flow through.
    """
    diagonalEntries = torch.diag(e)
    diagonalMatrix = torch.diag(diagonalEntries)
    return e - diagonalMatrix

class MiniPgi(nn.Module):
    """Label-propagation model with one learned edge generator per target.

    For target i, a learned mask projects each subject's input features to a
    w[i]-dimensional latent vector. The subject-by-subject Gram matrix of those
    latents is masked, passed through a LeakyReLU and row-normalised. It is
    then used as weights to combine the other subjects' labels.

    Args:
        w: latent width, either one int for every target or a sequence of
            per-target widths.
        nPara: number of paradigms (size of dim 1 of the input).
        nPoly: number of polynomial powers (size of dim 2 of the input).
        nTgts: number of targets; forward() uses targets 0, 1, 2.
        dp: dropout probability applied to the input.
        relu: LeakyReLU negative slope, one number or one per target.
        nRois: ROI count; each input row has nRois*(nRois-1)/2 edges
            (default 264 -> 34716, the previously hard-coded value).
        device: device the masks are created on (default 'cuda').
    """
    def __init__(self, w, nPara, nPoly, nTgts, dp=0.5, relu=0.1, nRois=264, device='cuda'):
        super(MiniPgi, self).__init__()
        if isinstance(w, int):
            w = nTgts*[w]
        # Size of the strict upper triangle of an nRois x nRois matrix.
        nEdges = nRois*(nRois-1)//2
        # nn.ParameterList / nn.ModuleList register the masks and activations
        # with the module. A plain Python list would hide them from
        # parameters(), state_dict() and .to().
        self.masks = nn.ParameterList([
            nn.Parameter(
                0.01*torch.ones(nPara,nPoly,nEdges,w[i]).float().to(device)
                +0.001*torch.randn(nPara,nPoly,nEdges,w[i]).float().to(device)
            )
            for i in range(nTgts)
        ])
        self.dp = nn.Dropout(p=dp)
        self.relu = nn.ModuleList([
            nn.LeakyReLU(negative_slope=relu if isinstance(relu, (int, float)) else relu[i])
            for i in range(nTgts)
        ])

    def getLatentsAndEdges(self, x, idx):
        """Latents and Gram-matrix edges for target ``idx``.

        Args:
            x: input of shape (n_subjects, nPara, nPoly, n_edges) — the shape
                the einsum below contracts against masks[idx].
            idx: target index.

        Returns:
            (y, e): y is (n_subjects, w[idx]); e = y @ y.T is (n_subjects, n_subjects).
        """
        y = torch.einsum('abcd,bcde->ae', x, self.masks[idx])
        e = y@y.T
        return y, e

    def forward(self, x, age=None, gender=None, wrat=None):
        """Predict each target from the other subjects' labels.

        All three labels are required (None would fail in torch.any). Each is a
        (n_subjects, k) tensor. Rows that are entirely zero are treated as
        hidden: those subjects contribute nothing as neighbours.

        Returns:
            List [age_pred, gender_pred, wrat_pred]; each is edge-weights @ label.
        """
        x = self.dp(x)
        lbls = [age, gender, wrat]
        res = []
        for i,lbl in enumerate(lbls):
            _, e = self.getLatentsAndEdges(x, i)
            # Columns of hidden (all-zero-label) subjects get zero weight.
            idcs = torch.logical_not(torch.any(lbl, dim=1))
            e[:,idcs] = 0
            # Drop self-edges so a subject never uses its own label.
            e = self.relu[i](mask(e))
            # Row-normalise the edge weights.
            s = torch.sum(e, dim=1)
            e = e/s.unsqueeze(1)
            res.append(e@lbl)
        return res

print('Complete')

Complete


In [26]:
ceLoss = torch.nn.CrossEntropyLoss()
mseLoss = torch.nn.MSELoss()
nEpochs = 5000
pPeriod = 200
# Early-stopping thresholds for (age MSE, gender CE, WRAT MSE): training stops
# once all three training losses are below their threshold.
thresh = torch.Tensor((40,3.2e-1,20)).float().cuda()

nPoly = 1
para = [makePoly(nback_p_t), makePoly(emoid_p_t)]

# Inputs and labels for all subjects are the same in every fold, so build them once.
X_full = torch.stack(para, dim=1)
Y_full = torch.from_numpy(X_all).float().cuda()

# Per fold: (age RMSE, gender accuracy, WRAT RMSE) on the held-out subjects.
# (The name is kept for later cells even though the middle entry is an accuracy.)
rmse = []

for i in range(10):
    pgigcn = MiniPgi((2, 5, 10), len(para), nPoly, 3, 0.5, (0.2, 0.01, 1))
    optim = torch.optim.Adam(pgigcn.masks, lr=2e-5, weight_decay=2e-5)

    trainIdcs = groups[i][0]
    testIdcs = groups[i][1]

    # ---- Training: only the training subjects are in the graph ----
    X = X_full[trainIdcs]
    Y = Y_full[trainIdcs]

    gen = Y[:,1:]                          # (maleness, femaleness) columns
    wrt = wrat_t[trainIdcs].unsqueeze(1)
    age = Y[:,0].unsqueeze(1)

    for epoch in range(nEpochs):
        optim.zero_grad()
        res = pgigcn(X, age=age, gender=gen, wrat=wrt)
        loss0 = mseLoss(res[0], age)
        # NOTE(review): res[1] is a row-normalised weighted combination of the
        # one-hot gender labels, not logits, yet CrossEntropyLoss applies
        # log_softmax to its input. Confirm this is the intended objective.
        loss1 = ceLoss(res[1], gen)
        loss2 = mseLoss(res[2], wrt)
        loss = torch.stack([loss0, loss1, loss2])
        torch.sum(loss).backward()
        optim.step()
        if (epoch % pPeriod == 0 or epoch == nEpochs-1):
            print(f'epoch {epoch} loss={(float(loss0), float(loss1), float(loss2))}')
        if torch.all(loss < thresh):
            print('Early stopping')
            break

    print('Finished training')

    # ---- Evaluation: all subjects in the graph, test labels hidden ----
    pgigcn.eval()

    X = X_full
    Y = Y_full

    gen = Y[:,1:]
    wrt = wrat_t.unsqueeze(1)
    age = Y[:,0].unsqueeze(1)

    # Zero the test subjects' labels. The model treats all-zero label rows as
    # hidden, so test labels cannot leak into the predictions.
    gen0 = gen.clone().detach()
    gen0[testIdcs] = 0
    wrt0 = wrt.clone().detach()
    wrt0[testIdcs] = 0
    age0 = age.clone().detach()
    age0[testIdcs] = 0

    # forward() returns a list of three predictions (age, gender, WRAT).
    # The previous `res, ss = ...` two-way unpacking raised ValueError.
    with torch.no_grad():
        res = pgigcn(X, age=age0, gender=gen0, wrat=wrt0)
    loss0 = mseLoss(res[0][testIdcs].detach(), age[testIdcs]).cpu().numpy()**0.5
    frac1 = torch.sum(torch.argmax(res[1].detach(), dim=1)[testIdcs]
                     == torch.argmax(gen[testIdcs], dim=1))/testIdcs.shape[0]
    loss2 = mseLoss(res[2][testIdcs].detach(), wrt[testIdcs]).cpu().numpy()**0.5

    rmse.append((float(loss0), float(frac1), float(loss2)))
    print(i, end=' ')
    print(rmse[-1])

epoch 0 loss=(1604.88671875, 0.6920773386955261, 257.2871398925781, 0.0)
epoch 200 loss=(1599.9012451171875, 0.4183705747127533, 257.28729248046875, 0.0)
epoch 400 loss=(1605.887451171875, 0.36124739050865173, 255.22109985351562, 0.0)
epoch 600 loss=(1605.2891845703125, 0.34499526023864746, 257.00469970703125, 0.0)
Early stopping
Finished training
0 (34.39707565307617, 0.8333333730697632, 14.658626556396484)
epoch 0 loss=(2043.017822265625, 0.6903658509254456, 267.3212890625, 0.0)
epoch 200 loss=(1547.0286865234375, 0.45039913058280945, 250.3934326171875, 0.0)
epoch 400 loss=(1544.4755859375, 0.35791757702827454, 250.06971740722656, 0.0)
epoch 600 loss=(1548.4580078125, 0.34072282910346985, 246.99774169921875, 0.0)
Early stopping
Finished training
1 (39.55796432495117, 0.8166667222976685, 16.635272979736328)
epoch 0 loss=(1574.7952880859375, 0.6911529898643494, 537.5813598632812, 0.0)
epoch 200 loss=(1573.1409912109375, 0.416223406791687, 246.692138671875, 0.0)
epoch 400 loss=(1558.109

In [27]:
# Held-out gender classification accuracy of each fold (middle entry of each tuple).
for _, accuracy, _ in rmse:
    print(accuracy)

0.8333333730697632
0.8166667222976685
0.8000000715255737
0.8305084705352783
0.7796609997749329
0.7966101765632629
0.8305084705352783
0.7796609997749329
0.7118644118309021
0.7457627058029175


In [29]:
import torch.linalg

# Edge matrices from pgigcn's three heads (0: age, 1: gender, 2: WRAT)
# evaluated on X, and their sum.
es = [pgigcn.getLatentsAndEdges(X, head)[1].detach() for head in range(3)]
E = sum(es)

nSub = E.shape[0]
w = 11

# Random initial factor U of shape (nSub, w).
U = torch.rand(nSub, w).float().cuda()

# Two rounds of alternating least squares:
#   solve U @ A ~ es[0], U @ B ~ es[1], U @ C ~ es[2] for A, B, C given U;
#   then re-solve U from U @ (A + B + C) ~ E.
# NOTE(review): res*[0] presumably reports the residual of the first target
# column only, not the total — confirm.
for i in range(2):
    (A, res0, _, _), (B, res1, _, _), (C, res2, rank, sigma) = [
        torch.linalg.lstsq(U, target) for target in es
    ]
    print(f'{(res0[0], res1[0], res2[0])}')
    U, res, rank, sigma = torch.linalg.lstsq((A + B + C).T, E.T)
    U = U.T
    print(res[0])

(tensor(6094822.5000, device='cuda:0'), tensor(699011.1250, device='cuda:0'), tensor(12366456., device='cuda:0'))
tensor(22.1522, device='cuda:0')
(tensor(0.0013, device='cuda:0'), tensor(0.1015, device='cuda:0'), tensor(0.3642, device='cuda:0'))
tensor(0.0053, device='cuda:0')


In [46]:
# Re-run the gender head's label propagation, but with the low-rank edges
# U @ B in place of the model's own edges, then score held-out accuracy.
UA = U @ B
# Subjects whose gender rows are all zero (hidden test labels) get zero weight.
hidden = ~torch.any(gen0, dim=1)
UA[:, hidden] = 0
UA = pgigcn.relu[1](mask(UA))
UA /= UA.sum(dim=1).unsqueeze(1)
res = UA @ gen0
predicted = torch.argmax(res.detach(), dim=1)[testIdcs]
actual = torch.argmax(gen[testIdcs], dim=1)
frac1 = torch.sum(predicted == actual) / testIdcs.shape[0]
print(frac1)

tensor(0.8667, device='cuda:0')


In [53]:
# Flatten each subject's input (nPara x nPoly x edges) into a single row.
Xf = X.flatten(start_dim=1)

# Least-squares linear maps from the raw features to U and to C.T.
# There are far more feature columns than subjects here, and the printed
# residuals came back empty.
MU, res, rank, sigma = torch.linalg.lstsq(Xf, U)
print(res)
MC, res, rank, sigma = torch.linalg.lstsq(Xf, C.T)
print(res)

tensor([], device='cuda:0')
tensor([], device='cuda:0')


In [57]:
# MC was fitted against C.T, so the first printout should match the
# transpose of the second.
reconstruction = torch.matmul(Xf, MC)
print(reconstruction)
print(C)

tensor([[ 35.7946,  52.9134,  30.6744,  ...,  49.7341,  63.4010,  45.6183],
        [  6.1951,  39.1063,   7.1359,  ...,  73.0908,  28.1651, 107.8551],
        [ 84.5367,  53.7470,  70.6902,  ...,  21.5250,  76.6547, -12.5471],
        ...,
        [ 96.5503,  52.9031,  59.2112,  ...,  34.4210,  96.1620,   1.3936],
        [ 38.5524,  69.2021,  31.6161,  ...,  75.1384,  77.7418,  81.0949],
        [ 80.8732,  42.9717,  57.8032,  ...,  11.5973,  76.9860, -27.9158]],
       device='cuda:0')
tensor([[ 35.7946,   6.1951,  84.5368,  ...,  96.5504,  38.5524,  80.8733],
        [ 52.9133,  39.1062,  53.7470,  ...,  52.9031,  69.2021,  42.9717],
        [ 30.6744,   7.1359,  70.6902,  ...,  59.2112,  31.6161,  57.8032],
        ...,
        [ 49.7340,  73.0908,  21.5250,  ...,  34.4210,  75.1383,  11.5973],
        [ 63.4010,  28.1651,  76.6547,  ...,  96.1620,  77.7418,  76.9860],
        [ 45.6183, 107.8551, -12.5472,  ...,   1.3936,  81.0949, -27.9158]],
       device='cuda:0')


In [17]:
# Euclidean norm of the first row of the summed edge matrix E.
firstRow = E[0]
print(torch.sum(firstRow ** 2) ** 0.5)

tensor(7656.2363, device='cuda:0')
