In [1]:
import pickle

# Load meta dict

with open('../../PNC/AllSubjectsMeta.bin', 'rb') as f:
    meta = pickle.load(f)
    
# Load rest subject ids and splits

with open('../../Work/Abstract/PaperBin/AllThreeSplit.bin', 'rb') as f:
    splits = pickle.load(f)
    subids = splits['allThreeYesWrat']
    groups = splits['groups']
    
print(len(subids))

593


In [2]:
import numpy as np

subidsNp = np.array(subids)

# Load timeseries

def loadSeries(prefix, para, idx):
    with open('{:}/{:}_fmri_power264/timeseries/{:}.bin'.format(prefix, para, idx), 'rb') as f:
        return pickle.load(f)

rest_ts = [loadSeries('../../PNC', 'rest', meta[subid]['rest']) for subid in subidsNp]
nback_ts = [loadSeries('../../PNC', 'nback', meta[subid]['nback']) for subid in subidsNp]
emoid_ts = [loadSeries('../../PNC', 'emoid', meta[subid]['emoid']) for subid in subidsNp]

print('Loading complete')

Loading complete


In [3]:
import numpy as np

def normalizeSubjects(subjects):
    for i in range(len(subjects)):
        subj = subjects[i]
        subj -= np.mean(subj, axis=1, keepdims=True)@np.ones([1,subj.shape[1]])
        subj /= np.std(subj, axis=1, keepdims=True)@np.ones([1,subj.shape[1]])
        if np.sum(np.isnan(subj)) > 0:
            print(i)
        if np.sum(np.isinf(subj)) > 0:
            print(i)

normalizeSubjects(rest_ts)
normalizeSubjects(nback_ts)
normalizeSubjects(emoid_ts)

print('Complete')

Complete


In [4]:
# Calculate pearson matrices

rest_p = np.stack([np.corrcoef(sub) for sub in rest_ts])
nback_p = np.stack([np.corrcoef(sub) for sub in nback_ts])
emoid_p = np.stack([np.corrcoef(sub) for sub in emoid_ts])

print(rest_p.shape)
print('Complete')

(593, 264, 264)
Complete


In [5]:
# Create feature vectors (right now just ages, maleness, and femaless)

males = 0
females = 0

X_all = []
for subid in subidsNp:
    subj = meta[subid]
    maleness = 1 if subj['meta']['Gender'] == 'M' else 0
    femaleness = 1 if maleness == 0 else 0
    feat = np.array([subj['meta']['AgeInMonths'], maleness, femaleness])
    X_all.append(feat)
    if maleness == 1:
        males += 1
    if femaleness == 1:
        females += 1
X_all = np.vstack(X_all)

print(f'{males} {females}')
print(X_all[10:20])
print('Complete')

271 322
[[223   1   0]
 [190   0   1]
 [197   0   1]
 [145   1   0]
 [148   0   1]
 [142   0   1]
 [123   1   0]
 [176   1   0]
 [129   0   1]
 [173   1   0]]
Complete


In [26]:
import torch

def convertTorch(p):
    t = torch.from_numpy(p).float()
    u = []
    for i in range(t.shape[0]):
        u.append(t[i][torch.triu_indices(264,264,offset=1).unbind()])
    return torch.stack(u).cuda()

def normalizeP(p):
    return p - torch.mean(p, dim=1, keepdim=True)

rest_p_t = convertTorch(rest_p)
nback_p_t = convertTorch(nback_p)
emoid_p_t = convertTorch(emoid_p)

# rest_p_t = normalizeP(rest_p_t)
# nback_p_t = normalizeP(nback_p_t)
# emoid_p_t = normalizeP(emoid_p_t)

print(rest_p_t.shape)
print(nback_p_t.shape)
print(emoid_p_t.shape)
print('Complete')

torch.Size([593, 34716])
torch.Size([593, 34716])
torch.Size([593, 34716])
Complete


In [7]:
wratDict = dict()

with open('../../PNC/wrat.csv', 'r') as f:
    lines = f.readlines()[1:]
    for line in lines:
        line = line.strip().split(',')
        wratDict[line[0]] = {'raw': line[2], 'std': line[3]}

wrat = []

for key in subids:
    wrat.append(float(wratDict[str(key)]['std']))
    
wrat = np.array(wrat)
wrat_t = torch.from_numpy(wrat).float().cuda()

print('Complete')

Complete


In [21]:
import torch.nn as nn
import torch.nn.functional as F
from itertools import permutations

def makePoly(ps, nPoly):
    pps = []
    for i in range(ps.shape[0]):
        p = ps[i].flatten()
        pp = nPoly*[None]
        for j in range(nPoly):
            pp[j] = p**(j+1)
        pps.append(torch.stack(pp))
    return torch.stack(pps)

def arith(n):
    return int(n*(n+1)/2)

def mask(e):
    return e - torch.diag(torch.diag(e.detach()))

class PgiDiff(nn.Module):
    def __init__(self, w, nPara, nTgts, dp=0.5, dp2=0.1):
        super(PgiDiff, self).__init__()
        self.nPara = nPara
        self.nTgts = nTgts
        self.masks = []
        if type(w) == int:
            w = (nTgts+1)*[w]
        for i in range(nTgts+1):
            self.masks.append(nn.Parameter(
                1e-4*torch.randn(nPara,arith(263),w[i]).float().cuda()
            ))
        self.dp = nn.Dropout(p=dp)
        self.dp2 = nn.Dropout(p=dp2)
    
    def getLatentsAndEdges(self, x, i, univ):
        if univ:
            y = torch.einsum('abc,bce->ae', x, self.masks[0])
            z = torch.einsum('abc,bce->ae', x, self.masks[i+1])
            e = y@z.T+z@y.T
            return y, z, e
        else:
            y = torch.einsum('abc,bce->ae', x, self.masks[i+1])
            e = y@y.T
            return y, y, e
        
    def forward(self, x, y, testIdcs=None, univ=True):
        x = self.dp(x)
        res = []
        for i in range(self.nTgts):
            _, _, e = self.getLatentsAndEdges(x, i, univ)
            if testIdcs is not None:
                e[:,testIdcs] = 0
            e = self.dp2(e)
            e = mask(e)
            e[e == 0] = float('-inf')
            e = 1.1*F.softmax(e, dim=1)
            e = e*y[i].unsqueeze(0)
            res.append(torch.sum(e, dim=1))
        return res
        
print('Complete')

Complete


In [32]:
ceLoss = torch.nn.CrossEntropyLoss()
mseLoss = torch.nn.MSELoss()
nEpochs = 5000
pPeriod = 200
thresh = torch.Tensor((40,12,20)).float().cuda()

para = [nback_p_t, emoid_p_t]
    
rmse = []

def xform(data, stats=None, fwd=True):
    if stats is None:
        mu = torch.mean(data, dim=0, keepdim=True)
        sd = torch.std(data, dim=0, keepdim=True)
        return (mu, sd)
    elif fwd:
        return (data - stats[0])/stats[1]
    else:
        return data*stats[1] + stats[0]

for i in range(10):
    pgigcn = PgiDiff(4, len(para), 4, 0.5, 0.2)
    optim = torch.optim.Adam(pgigcn.masks, lr=2e-5, weight_decay=2e-5)

    trainIdcs = groups[i][0]
    testIdcs = groups[i][1]
    
    X = torch.stack(para, dim=1)
    X = X[trainIdcs]
    Y = torch.from_numpy(X_all[trainIdcs]).float().cuda()
    
    gen = Y[:,1:]
    wrt = wrat_t[trainIdcs]
    age = Y[:,0]
    
    # Normalize dataset
    statsGen = xform(gen)
    statsWrt = xform(wrt)
    statsAge = xform(age)
    
    # Transformed
    genT = xform(gen, statsGen)
    wrtT = xform(wrt, statsWrt)
    ageT = xform(age, statsAge)
    
    y = torch.cat([ageT.unsqueeze(1), genT, wrtT.unsqueeze(1)], dim=1).T
    
    for epoch in range(nEpochs):
        optim.zero_grad()
        res = pgigcn(X, y, univ=True)
        loss0 = mseLoss(xform(res[0], statsAge, fwd=False), age)
        loss1 = 100*ceLoss(torch.stack([res[1], res[2]], dim=1), gen)
        loss2 = mseLoss(xform(res[3], statsWrt, fwd=False), wrt)
        loss = torch.stack([loss0, loss1, loss2])
        torch.sum(loss).backward()
        optim.step()
        if (epoch % pPeriod == 0 or epoch == nEpochs-1):
            print(f'epoch {epoch} loss={(float(loss0), float(loss1), float(loss2))}')
        if torch.all(loss[0:3] < thresh):
            print('Early stopping')
            break
            
    print('Finished training')
    
    pgigcn.eval()
    
    X = torch.stack(para, dim=1)
    Y = torch.from_numpy(X_all).float().cuda()
        
    gen = Y[:,1:]
    wrt = wrat_t
    age = Y[:,0]
    
    # Transformed
    genT = xform(gen, statsGen)
    wrtT = xform(wrt, statsWrt)
    ageT = xform(age, statsAge)
    
    y = torch.cat([ageT.unsqueeze(1), genT, wrtT.unsqueeze(1)], dim=1).T
    
    with torch.no_grad():
        res = pgigcn(X, y, testIdcs, univ=True)
        loss0 = mseLoss(xform(res[0][testIdcs].detach(), statsAge, fwd=False), age[testIdcs]).cpu().numpy()**0.5
        frac1 = torch.sum(torch.argmax(torch.stack([res[1], res[2]], dim=1).detach(), dim=1)[testIdcs] 
                         == torch.argmax(gen[testIdcs], dim=1))/testIdcs.shape[0]
        loss2 = mseLoss(xform(res[3][testIdcs].detach(), statsWrt, fwd=False), wrt[testIdcs]).cpu().numpy()**0.5

        rmse.append((float(loss0), float(frac1), float(loss2)))
        
    print(i, end=' ')
    print(rmse[-1])

epoch 0 loss=(1574.7066650390625, 69.49566650390625, 257.9539489746094)
epoch 200 loss=(333.02374267578125, 29.247385025024414, 165.604736328125)
epoch 400 loss=(178.16184997558594, 24.014848709106445, 144.730712890625)
epoch 600 loss=(131.04388427734375, 19.453920364379883, 124.00162506103516)
epoch 800 loss=(212.84759521484375, 21.297258377075195, 115.05181121826172)
epoch 1000 loss=(141.8602294921875, 15.439140319824219, 92.44693756103516)
epoch 1200 loss=(57.05058288574219, 15.440918922424316, 71.92483520507812)
epoch 1400 loss=(52.714599609375, 22.166305541992188, 64.61856079101562)
epoch 1600 loss=(68.63125610351562, 14.600007057189941, 56.491050720214844)
epoch 1800 loss=(45.799930572509766, 13.422881126403809, 68.13480377197266)
epoch 2000 loss=(46.37934494018555, 12.660698890686035, 33.74162292480469)
epoch 2200 loss=(43.656288146972656, 12.615516662597656, 49.67647933959961)
epoch 2400 loss=(36.7657585144043, 12.338371276855469, 28.70901107788086)
epoch 2600 loss=(35.93799209

epoch 200 loss=(139.32289123535156, 24.587682723999023, 131.6905975341797)
epoch 400 loss=(132.99002075195312, 18.785064697265625, 81.69926452636719)
epoch 600 loss=(48.246429443359375, 14.26829719543457, 48.420467376708984)
epoch 800 loss=(44.785308837890625, 13.148442268371582, 64.20350646972656)
epoch 1000 loss=(48.429359436035156, 12.28709602355957, 32.330482482910156)
epoch 1200 loss=(47.733154296875, 12.553670883178711, 45.141685485839844)
epoch 1400 loss=(45.571250915527344, 11.799363136291504, 25.61336898803711)
epoch 1600 loss=(36.192745208740234, 11.730323791503906, 17.793481826782227)
Early stopping
Finished training
9 (26.348876953125, 0.7457627058029175, 11.453042984008789)


In [36]:
for a,b,c in rmse:
    print(c)

14.362210273742676
15.599970817565918
16.325719833374023
14.82560920715332
13.630986213684082
14.559538841247559
16.16264533996582
13.279645919799805
14.814011573791504
11.453042984008789


In [61]:
lat,e = pgigcn.getLatentsAndEdges(X,0)
e = mask(e)
e = F.softmax(e)
print(e)

tensor([[0.0000, 0.0761, 0.3380,  ..., 0.1505, 0.6623, 2.2184],
        [0.0761, 0.0000, 0.1209,  ..., 0.4447, 0.1637, 0.0927],
        [0.3380, 0.1209, 0.0000,  ..., 0.5613, 0.2664, 0.3323],
        ...,
        [0.1505, 0.4447, 0.5613,  ..., 0.0000, 0.7420, 0.2554],
        [0.6623, 0.1637, 0.2664,  ..., 0.7420, 0.0000, 0.6312],
        [2.2184, 0.0927, 0.3323,  ..., 0.2554, 0.6312, 0.0000]],
       device='cuda:0', grad_fn=<SubBackward0>)
