In [1]:
import sys

sys.path.append('..')

from LoadData2 import loadNbackEmoidAgesScansAndGenders, loadMeta

pncDir = '../../PNC_Good'

# NOTE(review): the metadata file is a .pkl; if loadMeta unpickles it, only point
# this at trusted data (unpickling can execute arbitrary code).
meta = loadMeta(f'{pncDir}/MegaMeta3.pkl')
keys, nbackTs, emoidTs, ages = loadNbackEmoidAgesScansAndGenders(meta)

# Quick look at what was loaded: the three array shapes, then the first ten ages.
for loaded in (nbackTs, emoidTs, ages):
    print(loaded.shape)
print(ages[:10])

(650, 264, 231)
(650, 264, 210)
(650,)
[ 9.66666667 15.58333333 15.83333333 20.16666667 11.41666667 11.58333333
  9.41666667 17.5        15.25        9.58333333]


In [2]:
# Get FC and convert to torch

from LoadData2 import getFC
import torch

nbackP = getFC(nbackTs)
emoidP = getFC(emoidTs)

# Flatten each subject's FC matrix into a single row. The row count comes from the data
# instead of being hard-coded (650 / 264*264), so this cell keeps working if the cohort
# size or the number of regions changes.
nbackP_t = torch.from_numpy(nbackP).reshape(nbackP.shape[0], -1).float().cuda()
emoidP_t = torch.from_numpy(emoidP).reshape(emoidP.shape[0], -1).float().cuda()
feat_t = torch.cat([nbackP_t, emoidP_t], dim=1)
ages_t = torch.from_numpy(ages).float().cuda()

# Both task feature matrices and the age vector must have one entry per subject;
# torch.cat above only checks the feature width, so check the row counts explicitly.
assert nbackP_t.shape[0] == emoidP_t.shape[0] == ages_t.shape[0], 'subject count mismatch'

print(nbackP_t.shape)
print(emoidP_t.shape)
print(ages_t.shape)

torch.Size([650, 69696])
torch.Size([650, 69696])
torch.Size([650])


In [3]:
def normalize(A):
    """Symmetrically normalize a square matrix: D^{-1/2} A D^{-1/2}.

    ``D`` is the diagonal matrix of the row sums of ``A``, so entry (i, j)
    becomes ``A[i, j] / sqrt(deg_i * deg_j)``.

    Rows whose sum is exactly zero (isolated nodes) get a scale factor of 0
    instead of inf, so they stay all-zero rather than filling the matrix with
    NaN. Rows with a negative sum still yield NaN (square root of a negative
    number), as before.

    Args:
        A: square 2-D tensor of shape (n, n).

    Returns:
        Tensor of shape (n, n) with the normalized values.

    Raises:
        ValueError: if ``A`` is not a square 2-D matrix (``ValueError`` is a
            subclass of ``Exception``, so existing ``except Exception`` handlers
            still catch it).
    """
    if A.dim() != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("Bad A shape")
    deg = torch.sum(A, dim=1)
    invSqrt = deg ** -0.5
    # A zero degree gives inf here; zero it so isolated nodes don't produce 0*inf = NaN.
    invSqrt = torch.where(torch.isinf(invSqrt), torch.zeros_like(invSqrt), invSqrt)
    # Scale row i by invSqrt[i] and column j by invSqrt[j].
    return invSqrt.unsqueeze(1) * A * invSqrt.unsqueeze(0)

# Cosine similarity

def cosineSim(a, b):
    """Build a normalized graph from the pairwise cosine similarity of row vectors.

    Computes cos(a_i, b_j) for every pair of rows and subtracts the identity,
    which removes the unit self-similarity on the diagonal when ``b`` is ``a``.
    The result is passed through ``normalize``.

    Args:
        a: 2-D tensor (n, f), one feature vector per row.
        b: 2-D tensor (m, f). ``normalize`` requires a square matrix, so in
           practice ``m == n`` (this notebook calls it with ``a`` and ``b`` the
           same tensor).

    Returns:
        (n, n) tensor, on the same device and with the same dtype as the
        similarity matrix computed from the inputs.

    Note:
        A row with zero norm yields NaN similarities (0/0); no epsilon is applied.
    """
    sim = torch.einsum('ai,bi->ab', a, b)
    normA = torch.einsum('ai,ai->a', a, a) ** 0.5
    normB = torch.einsum('bi,bi->b', b, b) ** 0.5
    # Divide rows by |a_i| and columns by |b_j| (out-of-place, autograd-safe).
    sim = sim / normA.unsqueeze(1) / normB.unsqueeze(0)
    # Build the identity on the similarity's own device/dtype instead of forcing
    # .cuda()/float32, so this also works for CPU tensors or a non-default GPU.
    eye = torch.eye(sim.shape[0], device=sim.device, dtype=sim.dtype)
    return normalize(sim - eye)

# Normalized cosine-similarity graph over subjects (one node per row of feat_t).
A = cosineSim(a=feat_t, b=feat_t)
print(A.size())

torch.Size([650, 650])


In [4]:
# Define the GCN model and train it

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.dense import DenseGCNConv

class GCN(nn.Module):
    """Two-layer dense graph convolutional network producing one output per node.

    The layer sizes default to the values this notebook was trained with
    (2*264*264 input features -> 100 hidden -> 1 output). ``GCN()`` therefore
    builds exactly the same parameters (``gc1.*``, ``gc2.*``) as before, and
    saved ``state_dict`` checkpoints remain compatible.

    Args:
        in_features: per-node input feature size.
        hidden: width of the hidden layer.
        device: where the layers are placed. ``'cuda'`` matches the previous
            hard-coded ``.cuda()`` call.
    """

    def __init__(self, in_features=2*264*264, hidden=100, device='cuda'):
        super().__init__()
        self.gc1 = DenseGCNConv(in_features, hidden).float().to(device)
        self.gc2 = DenseGCNConv(hidden, 1).float().to(device)
        # Kept on the module so the training/evaluation cells can call gcn.loss(pred, target).
        self.loss = nn.MSELoss()

    def forward(self, x):
        """Run both graph-convolution layers.

        Args:
            x: two-item sequence ``[A, z]``, where ``A`` is the adjacency matrix
               and ``z`` is the node-feature matrix.

        Returns:
            The output of ``gc2``; callers in this notebook flatten it to one
            value per node. (Presumably of shape (1, n, 1) for 2-D inputs;
            confirm against the installed torch_geometric version.)
        """
        A, z = x[0], x[1]
        # NOTE(review): PyG's DenseGCNConv is documented to add self-loops and apply its own
        # symmetric degree normalization to the adjacency by default. Confirm whether the
        # separate normalization applied when building A is intended.
        hiddenAct = F.relu(self.gc1(z, A))
        return self.gc2(hiddenAct, A)
    
# Fixed seed so the weight initialization, and hence the loss curve, is reproducible.
torch.manual_seed(0)

gcn = GCN()
optim = torch.optim.Adam(gcn.parameters(), lr=1e-5, weight_decay=0.2)

nEpoch = 4000
pPrint = 200

# NOTE(review): every subject in feat_t/ages_t is used for training and no held-out split
# exists, so the logged loss measures fit quality, not generalization.
for epoch in range(nEpoch):
    optim.zero_grad()
    pred = gcn([A, feat_t]).flatten()
    loss = gcn.loss(pred, ages_t)
    loss.backward()
    optim.step()
    if epoch % pPrint == 0 or epoch == nEpoch - 1:
        print(f'epoch {epoch} loss={loss.item()}')

print('Completed GCN')

epoch 0 loss=236.02418518066406
epoch 200 loss=9.154472351074219
epoch 400 loss=5.707479000091553
epoch 600 loss=4.358497619628906
epoch 800 loss=3.577888011932373
epoch 1000 loss=3.0755817890167236
epoch 1200 loss=2.712596893310547
epoch 1400 loss=2.435807943344116
epoch 1600 loss=2.216533660888672
epoch 1800 loss=2.0332930088043213
epoch 2000 loss=1.8760297298431396
epoch 2200 loss=1.8268015384674072
epoch 2400 loss=1.556636095046997
epoch 2600 loss=1.4357376098632812
epoch 2800 loss=1.325613021850586
epoch 3000 loss=1.2247850894927979
epoch 3200 loss=1.1322771310806274
epoch 3400 loss=1.0468876361846924
epoch 3600 loss=0.9652971029281616
epoch 3800 loss=0.8902499675750732
epoch 3999 loss=0.8226085305213928
Completed GCN


In [5]:
# Evaluation pass only: no_grad avoids building an autograd graph.
# NOTE(review): this uses the same subjects the model was trained on, so it is a
# training-set MSE, not a held-out estimate.
with torch.no_grad():
    pred = gcn([A, feat_t]).flatten()
    loss = gcn.loss(pred, ages_t)

print(loss)

tensor(0.8223, device='cuda:0', grad_fn=<MseLossBackward0>)


In [6]:
# Save model for explainer

from pathlib import Path

modelPath = Path('../../Work/Explainer/GCN_NbackEmoid3.pyt')
# Create the output directory if it is missing; otherwise torch.save raises FileNotFoundError.
modelPath.parent.mkdir(parents=True, exist_ok=True)
torch.save(gcn.state_dict(), modelPath)

print('Complete')

Complete
