In [1]:
import pickle

# Per-subject metadata dictionary; later cells index it as meta[subid][...].
with open('../../PNC/AllSubjectsMeta.bin', 'rb') as f:
    meta = pickle.load(f)

# Subject ids used throughout the notebook, plus the precomputed folds:
# groups[k][0] / groups[k][1] are used as the train / test indices of fold k.
# NOTE: pickle.load can execute arbitrary code — these files must be trusted.
with open('../../Work/Abstract/PaperBin/AllThreeSplit.bin', 'rb') as f:
    splits = pickle.load(f)
subids = splits['allThreeYesWrat']
groups = splits['groups']

print(len(subids))

593


In [2]:
import numpy as np

# numpy copy of the subject id list, iterated by the loading/feature cells below.
subidsNp = np.array(subids, copy=True)

# Load timeseries

def loadSeries(prefix, para, idx):
    """Unpickle and return one subject's timeseries for paradigm ``para``.

    The file read is ``<prefix>/<para>_fmri_power264/timeseries/<idx>.bin``.
    NOTE: pickle.load can execute arbitrary code — only use on trusted data.
    """
    path = f'{prefix}/{para}_fmri_power264/timeseries/{idx}.bin'
    with open(path, 'rb') as fh:
        series = pickle.load(fh)
    return series

# For each paradigm, load every subject's timeseries in subidsNp order;
# meta[subid][paradigm] is the file index that loadSeries reads.
rest_ts, nback_ts, emoid_ts = (
    [loadSeries('../../PNC', paradigm, meta[subid][paradigm]) for subid in subidsNp]
    for paradigm in ('rest', 'nback', 'emoid')
)

print('Loading complete')

Loading complete


In [3]:
import numpy as np

def normalizeSubjects(subjects):
    """Z-score every row of every subject's array, in place.

    Each element of ``subjects`` is centred and scaled along axis 1, so every
    row ends up with zero mean and unit standard deviation. The arrays are
    modified in place and nothing is returned.

    Rows with zero variance cannot be scaled. They are left at zero after
    centring instead of becoming 0/0 = NaN, and the subject's index is printed
    so the problem stays visible. The index is also printed (once) if the
    result contains any NaN or infinite value, e.g. from NaN/inf input.

    Args:
        subjects: sequence of 2-D floating-point numpy arrays. Presumably
            (nRoi, nTimepoints) given how they are stacked later — TODO confirm.
    """
    for i, subj in enumerate(subjects):
        # keepdims yields (rows, 1) arrays that broadcast across the columns,
        # so no explicit outer product with a ones-vector is needed.
        subj -= np.mean(subj, axis=1, keepdims=True)
        std = np.std(subj, axis=1, keepdims=True)
        # Only divide rows with positive spread; zero-variance rows stay 0.
        np.divide(subj, std, out=subj, where=std > 0)
        if np.any(std == 0) or not np.all(np.isfinite(subj)):
            print(i)

# Z-score every paradigm's timeseries in place (normalizeSubjects mutates each array).
for _paradigm_ts in (rest_ts, nback_ts, emoid_ts):
    normalizeSubjects(_paradigm_ts)

print('Complete')

Complete


In [5]:
# Convert raw timeseries to torch

import torch

# Stack each paradigm's per-subject arrays into a single float32 CUDA tensor
# (the printed shape below is (subjects, ROIs, timepoints) for emoid).
rest_t, nback_t, emoid_t = (
    torch.from_numpy(np.stack(series)).float().cuda()
    for series in (rest_ts, nback_ts, emoid_ts)
)

print(emoid_t.shape)
print('Complete')

torch.Size([593, 264, 210])
Complete


In [6]:
# Create feature vectors (right now just age, maleness, and femaleness).
# X_all has one row per subject in subidsNp order with columns
# [AgeInMonths, maleness, femaleness]; the last two form a one-hot sex code,
# and any Gender other than 'M' is counted as female.
genders = [meta[subid]['meta']['Gender'] for subid in subidsNp]
males = sum(1 for g in genders if g == 'M')
females = len(genders) - males

X_all = np.vstack([
    np.array([meta[subid]['meta']['AgeInMonths'], int(g == 'M'), int(g != 'M')])
    for subid, g in zip(subidsNp, genders)
])

print(f'{males} {females}')
print(X_all[10:20])
print('Complete')

271 322
[[223   1   0]
 [190   0   1]
 [197   0   1]
 [145   1   0]
 [148   0   1]
 [142   0   1]
 [123   1   0]
 [176   1   0]
 [129   0   1]
 [173   1   0]]
Complete


In [7]:
# Map subject id (column 0 of wrat.csv, as a string) to its raw (column 2)
# and standardised (column 3) WRAT scores; the header row is skipped.
wratDict = {}
with open('../../PNC/wrat.csv', 'r') as f:
    body = f.readlines()[1:]
wratDict.update(
    (cols[0], {'raw': cols[2], 'std': cols[3]})
    for cols in (row.strip().split(',') for row in body)
)

# Standardised WRAT score per subject, in the same order as subids.
wrat = np.array([float(wratDict[str(key)]['std']) for key in subids])
wrat_t = torch.from_numpy(wrat).float().cuda()

print('Complete')

Complete


In [61]:
import torch.nn as nn
import torch.nn.functional as F

def arith(n):
    """Return the n-th triangular number 1 + 2 + ... + n = n*(n+1)/2 as an int.

    n*(n+1) is always even for integer n, so floor division is exact. True
    division would round through a float and go wrong once the result
    exceeds 2**53.
    """
    return int(n*(n+1)//2)

def mask(e):
    """Return a copy of the square matrix ``e`` with its main diagonal zeroed.

    The diagonal part is subtracted rather than overwritten, so ``e`` itself
    is not modified and the op stays differentiable. A non-finite diagonal
    entry becomes NaN (x - x).
    """
    diagonal_part = torch.diag(torch.diag(e))
    return e - diagonal_part

class MiniPgi(nn.Module):
    """Per-target learned ROI projections that turn subject similarity into label predictions.

    For each target t, a learnable mask of shape (nRoi, w[t]) projects every
    subject's input along the ROI axis. The Gram matrix (over the last input
    axis) of the projected series is passed through a LeakyReLU and flattened
    into a latent vector per subject. The subject-by-subject similarities
    ``e = y @ y.T`` then weight the other subjects' labels to form the
    prediction (see ``forward``).

    Args:
        w: projection width per target; one int for all targets or a sequence
            of ``nTgts`` ints.
        nRoi: size of axis 1 of the input ``x`` (the axis the masks contract).
        nTgts: number of targets / masks. ``forward`` supplies at most three
            label tensors (age, gender, wrat), so values above 3 fail there.
        dp: dropout probability applied to the input in ``forward``.
        relu: LeakyReLU negative slope; one number for all targets or a
            sequence of ``nTgts`` numbers.
        device: device the masks are created on (default ``'cuda'``, as before).
    """
    def __init__(self, w, nRoi, nTgts, dp=0.5, relu=0.1, device='cuda'):
        super(MiniPgi, self).__init__()
        if isinstance(w, int):
            w = nTgts*[w]
        if isinstance(relu, (int, float)):
            relu = nTgts*[relu]
        # ParameterList / ModuleList (instead of plain Python lists) register the
        # masks and activations with the module, so they appear in .parameters(),
        # state_dict() and follow .to()/.cuda(). Indexing and iteration are unchanged.
        # Masks start near 1e-4 with small noise; drawn on CPU then moved, as before.
        self.masks = nn.ParameterList([
            nn.Parameter(
                (0.0001*torch.ones(nRoi, w[t]) + 0.00001*torch.randn(nRoi, w[t])).float().to(device)
            )
            for t in range(nTgts)
        ])
        self.relu = nn.ModuleList([nn.LeakyReLU(negative_slope=relu[t]) for t in range(nTgts)])
        self.dp = nn.Dropout(p=dp)

    def getLatentsAndEdges(self, x, idx):
        """Compute latent vectors and subject-similarity edges for target ``idx``.

        Args:
            x: (batch, nRoi, T) input.
            idx: which target's mask / activation to use.
        Returns:
            (y, e): y is (batch, w[idx]**2), the flattened, activated Gram matrix
            of the projected series; e = y @ y.T is the (batch, batch) similarity.
        """
        # Use the mask selected by idx (previously this read an undefined
        # local `i`, silently falling back to a global loop counter).
        # Project the ROI axis onto w[idx] components: (batch, T, w).
        y = torch.einsum('abc,bd->acd', x, self.masks[idx])
        # Gram matrix over T of the projected series: (batch, w, w).
        y = torch.einsum('abc,abd->acd', y, y)
        y = self.relu[idx](y)
        y = y.reshape(y.shape[0], -1)
        e = y@y.T
        return y, e

    def forward(self, x, age=None, gender=None, wrat=None):
        """Predict each target from the labels of similar subjects.

        Args:
            x: (nSubjects, nRoi, T) input; dropout is applied to it first.
            age, gender, wrat: label matrices with one row per subject, used in
                that order for masks 0, 1, 2 (each must be given for every mask
                that exists). Rows that are entirely zero are treated as
                unlabeled: their column of the similarity matrix is zeroed so
                they contribute nothing to any prediction.
        Returns:
            list with one (nSubjects, nLabelCols) tensor per mask: the
            similarity-weighted sum of the other subjects' labels. The diagonal
            is removed by ``mask``, and rows are not normalised.
        """
        x = self.dp(x)
        lbls = [age, gender, wrat]
        res = []
        for i in range(len(self.masks)):
            _, e = self.getLatentsAndEdges(x, i)
            idcs = torch.logical_not(torch.any(lbls[i], dim=1))
            e[:,idcs] = 0
            e = mask(e)
            # NOTE: edge weights are deliberately left unnormalised; dividing each
            # row by its sum (e / e.sum(1, keepdim=True)) is the alternative that
            # was tried and disabled.
            res.append(e@lbls[i])
        return res
        
# Completion marker for the model-definition cell.
print('Complete')

Complete


In [63]:
# Cross-validated fit of MiniPgi on the emoid timeseries. For each of the 10
# precomputed folds: train a fresh model on the training subjects, then predict
# the held-out subjects from the full subject graph with their labels hidden.
ceLoss = torch.nn.CrossEntropyLoss()
mseLoss = torch.nn.MSELoss()
nEpochs = 10000
pPeriod = 200  # print the training losses every pPeriod epochs
# Stop early once (age MSE, gender cross-entropy, WRAT MSE) are all below these.
thresh = torch.tensor((40, 3.2e-1, 20), dtype=torch.float32).cuda()

# One entry per fold: (age RMSE, gender accuracy, WRAT RMSE) on the test subjects.
rmse = []

for i in range(10):
    pgigcn = MiniPgi((3, 3, 3), 264, 3, 0.5, (0, 0, 0))
    optim = torch.optim.Adam(pgigcn.masks, lr=2e-5, weight_decay=2e-5)

    trainIdcs = groups[i][0]
    testIdcs = groups[i][1]

    # ---- training: only the training subjects are in the graph ----
    X = emoid_t[trainIdcs]
    Y = torch.from_numpy(X_all[trainIdcs]).float().cuda()

    gen = Y[:,1:]                          # one-hot (maleness, femaleness)
    wrt = wrat_t[trainIdcs].unsqueeze(1)
    age = Y[:,0].unsqueeze(1)

    for epoch in range(nEpochs):
        optim.zero_grad()
        res = pgigcn(X, age=age, gender=gen, wrat=wrt)
        loss0 = mseLoss(res[0], age)
        loss1 = ceLoss(res[1], gen)        # gen acts as class-probability targets
        loss2 = mseLoss(res[2], wrt)
        loss = torch.stack([loss0, loss1, loss2])
        torch.sum(loss).backward()
        optim.step()
        if (epoch % pPeriod == 0 or epoch == nEpochs-1):
            print(f'epoch {epoch} loss={(float(loss0), float(loss1), float(loss2))}')
        if torch.all(loss < thresh):
            print('Early stopping')
            break

    print('Finished training')

    # ---- transductive evaluation on all subjects ----
    pgigcn.eval()  # disables the input dropout

    X = emoid_t
    Y = torch.from_numpy(X_all).float().cuda()

    gen = Y[:,1:]
    wrt = wrat_t.unsqueeze(1)
    age = Y[:,0].unsqueeze(1)

    # Zero the test subjects' labels: MiniPgi.forward treats all-zero label rows
    # as unlabeled, so test subjects are predicted from training subjects only.
    gen0 = gen.clone()
    gen0[testIdcs] = 0
    wrt0 = wrt.clone()
    wrt0[testIdcs] = 0
    age0 = age.clone()
    age0[testIdcs] = 0

    # Inference only: no_grad avoids building an autograd graph over every
    # subject pair for all three targets.
    with torch.no_grad():
        res = pgigcn(X, age=age0, gender=gen0, wrat=wrt0)
    loss0 = mseLoss(res[0][testIdcs], age[testIdcs]).cpu().numpy()**0.5
    frac1 = torch.sum(torch.argmax(res[1], dim=1)[testIdcs]
                      == torch.argmax(gen[testIdcs], dim=1))/testIdcs.shape[0]
    loss2 = mseLoss(res[2][testIdcs], wrt[testIdcs]).cpu().numpy()**0.5

    rmse.append((float(loss0), float(frac1), float(loss2)))
    print(i, end=' ')
    print(rmse[-1])

epoch 0 loss=(1862026.75, 0.7026510238647461, 629329.875)
epoch 200 loss=(7189.4296875, 0.6940171122550964, 1631.5980224609375)
epoch 400 loss=(6163.46826171875, 0.6938968896865845, 1364.99560546875)
epoch 600 loss=(3810.91796875, 0.6932421922683716, 807.3846435546875)
epoch 800 loss=(2073.426513671875, 0.6917603015899658, 504.9845275878906)
epoch 1000 loss=(1933.379638671875, 0.6910960078239441, 516.8662109375)
epoch 1200 loss=(1842.461669921875, 0.6909302473068237, 515.888671875)
epoch 1400 loss=(1715.385498046875, 0.6912573575973511, 513.0823364257812)
epoch 1600 loss=(1623.427734375, 0.691146969795227, 514.0126342773438)
epoch 1800 loss=(1539.9017333984375, 0.6911959052085876, 523.5892333984375)
epoch 2000 loss=(1495.183837890625, 0.6914143562316895, 524.4559326171875)
epoch 2200 loss=(1421.0830078125, 0.6914901733398438, 517.8359375)
epoch 2400 loss=(1383.4007568359375, 0.6915210485458374, 494.3544921875)
epoch 2600 loss=(1302.918212890625, 0.6916968822479248, 515.7525634765625)
e

epoch 1400 loss=(1747.8564453125, 0.685832679271698, 520.3629150390625)
epoch 1600 loss=(1700.1900634765625, 0.6860663294792175, 504.24285888671875)
epoch 1800 loss=(1581.690673828125, 0.6864323019981384, 499.7524719238281)
epoch 2000 loss=(1479.15087890625, 0.6862672567367554, 489.9795837402344)
epoch 2200 loss=(1409.0084228515625, 0.6867790222167969, 505.6539611816406)
epoch 2400 loss=(1336.0916748046875, 0.6863257884979248, 526.20654296875)
epoch 2600 loss=(1290.2879638671875, 0.686801016330719, 494.7808532714844)
epoch 2800 loss=(1253.623779296875, 0.6869903802871704, 487.48382568359375)
epoch 3000 loss=(1179.2041015625, 0.6867012977600098, 506.86468505859375)
epoch 3200 loss=(1134.618896484375, 0.686865508556366, 506.1454772949219)
epoch 3400 loss=(1192.4969482421875, 0.6875489354133606, 502.30029296875)
epoch 3600 loss=(1159.9266357421875, 0.6878929138183594, 501.081787109375)
epoch 3800 loss=(1066.6346435546875, 0.6877288818359375, 516.8304443359375)
epoch 4000 loss=(1095.240112

KeyboardInterrupt: 

In [49]:
# Inspect the latents and subject-similarity edges of target 0 for the current model.
# NOTE(review): this cell's execution count (49) is lower than the training
# cell's (63), so the output below came from an earlier model/state — re-run
# after training to refresh it.
y, e = pgigcn.getLatentsAndEdges(X, 0)
print(y, e, sep='\n')

tensor([[ 0.0063,  0.0046,  0.0036,  ..., -0.0013, -0.0013,  0.0039],
        [ 0.0040,  0.0033,  0.0033,  ...,  0.0008,  0.0007,  0.0035],
        [ 0.0085,  0.0049,  0.0029,  ..., -0.0046, -0.0045,  0.0043],
        ...,
        [ 0.0046,  0.0044,  0.0043,  ...,  0.0016,  0.0017,  0.0056],
        [ 0.0064,  0.0050,  0.0040,  ..., -0.0008, -0.0006,  0.0068],
        [ 0.0094,  0.0060,  0.0040,  ..., -0.0033, -0.0034,  0.0041]],
       device='cuda:0', grad_fn=<ReshapeAliasBackward0>)
tensor([[0.0010, 0.0008, 0.0011,  ..., 0.0011, 0.0013, 0.0012],
        [0.0008, 0.0009, 0.0007,  ..., 0.0013, 0.0013, 0.0008],
        [0.0011, 0.0007, 0.0016,  ..., 0.0008, 0.0013, 0.0016],
        ...,
        [0.0011, 0.0013, 0.0008,  ..., 0.0018, 0.0017, 0.0010],
        [0.0013, 0.0013, 0.0013,  ..., 0.0017, 0.0019, 0.0014],
        [0.0012, 0.0008, 0.0016,  ..., 0.0010, 0.0014, 0.0016]],
       device='cuda:0', grad_fn=<MmBackward0>)
