In [1]:
# Learn systems mask for N subsystems and apply a GCN
# This may require less than the full number of edges: 466 training subjects means each
# training epoch must find (466x465/2)*Nsys edges, and that is just for a single window!
# First do for resting state only
# Later, want to split each subject scan into windows with the same model
# Also implement multi-task GCN

# Load split

import pickle

badIDs = [605515760919, 601983541597]

# NOTE: pickle.load can execute arbitrary code; only open split files you trust.
with open('../../Splits/RegressionAllTasks/split1.bin', 'rb') as f:
    d = pickle.load(f)

trainDirty = d['train']
testDirty = d['test']

# Remove bad subjects (the file handle is no longer needed for this)
train = [subj for subj in trainDirty if subj['ID'] not in badIDs]
test = [subj for subj in testDirty if subj['ID'] not in badIDs]

print(len(train))
print(len(test))
print(train[0])

466
156
{'meta': {'AgeInMonths': 146, 'Gender': 'F', 'Ethnicity': 'AFRICAN', 'AgeGroupID': 2, 'AgeGroupEdge1': 144, 'AgeGroupEdge2': 180}, 'rest': '30', 'nback': '31', 'emoid': '31', 'ID': 600262185931}


In [2]:
# Load data

def loadTimeseries(_id, _dir):
    """Unpickle and return the object stored in '<_dir>/<_id>.bin'.

    _dir must be a string and _id an integer; the format specifiers used to
    build the path reject anything else.

    NOTE: pickle can execute arbitrary code while loading, so only read
    trusted files.
    """
    path = '{:s}/{:d}.bin'.format(_dir, _id)
    with open(path, 'rb') as f:
        return pickle.load(f)

# Only the 'rest' scans are loaded for now; the nback/emoid loads are kept
# commented out so they can be re-enabled later.
train_rest_ts = []
for subj in train:
    train_rest_ts.append(loadTimeseries(int(subj['rest']), '../../rest_fmri_power264/timeseries'))
# train_nback_ts = [loadTimeseries(int(subj['nback']), '../../nback_fmri_power264/timeseries') for subj in train]
# train_emoid_ts = [loadTimeseries(int(subj['emoid']), '../../emoid_fmri_power264/timeseries') for subj in train]

test_rest_ts = []
for subj in test:
    test_rest_ts.append(loadTimeseries(int(subj['rest']), '../../rest_fmri_power264/timeseries'))
# test_nback_ts = [loadTimeseries(int(subj['nback']), '../../nback_fmri_power264/timeseries') for subj in test]
# test_emoid_ts = [loadTimeseries(int(subj['emoid']), '../../emoid_fmri_power264/timeseries') for subj in test]

print('Complete')

Complete


In [3]:
import numpy as np

def normalizeSubjects(subjects):
    """Z-score every row of every subject array, in place.

    Each element of `subjects` is a 2-D floating-point numpy array. Every
    row has its mean (over axis 1) subtracted and is then divided by its
    standard deviation (over axis 1). The arrays are modified in place.

    A row with zero standard deviation divides 0 by 0 and becomes NaN
    (numpy emits a RuntimeWarning). The index of every subject that ends up
    with a NaN or infinite value is printed once and collected.

    Returns:
        list of int: indices into `subjects` whose array contains at least
        one non-finite value after normalisation (empty if all are clean).
    """
    bad = []
    for i, subj in enumerate(subjects):
        # keepdims=True yields a (rows, 1) column that broadcasts across the
        # row, so no full-size matrix of ones has to be built.
        subj -= np.mean(subj, axis=1, keepdims=True)
        subj /= np.std(subj, axis=1, keepdims=True)
        # isfinite is False for both NaN and +/-inf, so one check covers both
        if not np.all(np.isfinite(subj)):
            print(i)
            bad.append(i)
    return bad

# Normalise in place, training set first; only the rest scans are loaded right now
for subjects in (train_rest_ts, test_rest_ts):
    normalizeSubjects(subjects)
# normalizeSubjects(train_nback_ts)
# normalizeSubjects(train_emoid_ts)
# normalizeSubjects(test_nback_ts)
# normalizeSubjects(test_emoid_ts)

print('Complete')

Complete


In [4]:
# Calculate Pearson correlation matrices, one per subject
# (np.corrcoef treats each row of the timeseries array as one variable)

train_p = list(map(np.corrcoef, train_rest_ts))
test_p = list(map(np.corrcoef, test_rest_ts))

print(train_p[0].shape)
print('Complete')

(264, 264)
Complete


In [5]:
# Create feature vectors (right now just ages, maleness, and femaleness)

def _subjectFeatures(subjects):
    """Stack one [AgeInMonths, maleness, femaleness] row per subject.

    maleness is 1 when the subject's Gender is 'M' and 0 otherwise;
    femaleness is its complement, so any non-'M' value counts as female.
    """
    rows = []
    for subj in subjects:
        meta = subj['meta']
        maleness = 1 if meta['Gender'] == 'M' else 0
        femaleness = 1 - maleness
        rows.append(np.array([meta['AgeInMonths'], maleness, femaleness]))
    return np.vstack(rows)

X_train = _subjectFeatures(train)
X_test = _subjectFeatures(test)

print(X_train[10:20])
print('Complete')

[[193   1   0]
 [217   1   0]
 [233   1   0]
 [176   1   0]
 [116   0   1]
 [246   0   1]
 [164   1   0]
 [167   0   1]
 [202   0   1]
 [108   0   1]]
Complete


In [19]:
# Pearson matrices to pyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

# .float() gives float32 tensors regardless of the numpy dtype
train_p_torch = []
for p in train_p:
    train_p_torch.append(torch.from_numpy(p).float())

test_p_torch = []
for p in test_p:
    test_p_torch.append(torch.from_numpy(p).float())

print('Complete')

Complete


In [11]:
import torch

# Report memory usage of GPU 0. Guarded so the cell does not raise on a
# machine without a usable CUDA device.
if torch.cuda.is_available():
    # Release cached blocks first so the reserved figure reflects live usage
    torch.cuda.empty_cache()

    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    f = r-a  # free inside reserved

    print('Total: {:f} Reserved: {:f} Allocated: {:f} Free: {:f}'.format(t,r,a,f))
else:
    print('CUDA is not available; skipping GPU memory report')

Total: 4294967296.000000 Reserved: 1484783616.000000 Allocated: 1305116160.000000 Free: 179667456.000000


In [7]:
# Test reshape

# Flatten a 3x3 matrix into a single row, then repeat that row three times
a = torch.randn(3,3)
print(a)
a = a.reshape(1,-1)  # -1 lets torch infer the 9 columns
print(a)
a = a.expand(3,-1)  # -1 keeps the existing size of that dimension
print(a)

tensor([[-0.2910, -0.6650, -0.3234],
        [ 0.8243,  0.1680, -0.2487],
        [ 0.7814,  0.6550,  1.7756]])
tensor([[-0.2910, -0.6650, -0.3234,  0.8243,  0.1680, -0.2487,  0.7814,  0.6550,
          1.7756]])
tensor([[-0.2910, -0.6650, -0.3234,  0.8243,  0.1680, -0.2487,  0.7814,  0.6550,
          1.7756],
        [-0.2910, -0.6650, -0.3234,  0.8243,  0.1680, -0.2487,  0.7814,  0.6550,
          1.7756],
        [-0.2910, -0.6650, -0.3234,  0.8243,  0.1680, -0.2487,  0.7814,  0.6550,
          1.7756]])


In [15]:
# Test efficiency

import torch
import torch.nn as nn
import torch.nn.functional as F

# Synthetic problem sizes for a dense (all-pairs) version of the model
Nsub = 450   # subjects
Nroi = 264   # side length of each subject matrix
Nmask = 8    # mask columns
Nfeat = 4    # features per subject
Nhid = 40    # hidden width of fc1 / fc3
Nlat = 20    # latent width produced by fc2

print('Initing tensors...')
# Random stand-ins for the subject matrices, masks and features
ps = torch.randn(Nsub, Nroi, Nroi)
masks = torch.randn(Nroi, Nmask, requires_grad=True)
feat = torch.randn(Nsub, Nfeat)
# fc1 consumes one row of totalMsg: Nsub*Nmask edge values + Nsub*Nfeat features
fc1 = nn.Linear(Nsub*(Nmask+Nfeat),Nhid)
fc2 = nn.Linear(Nhid,Nlat)
# fc3 consumes Nsub*Nmask edge values + Nsub*Nlat latent values
fc3 = nn.Linear(Nsub*(Nmask+Nlat),Nhid)
fc4 = nn.Linear(Nhid,1)

print('Einsum...')
# NOTE(review): the commented-out one-step einsum is not the same quantity as
# the two-step version below. It applies masks[i,m]*masks[j,m] once per
# (a, b) pair, whereas the two-step version applies it to both operands, so
# the mask weights enter squared.
#edges = torch.einsum('aij, bij, im, jm -> abm', ps, ps, masks, masks)
# mp[a,i,j,m] = ps[a,i,j]*masks[i,m]*masks[j,m], shape (Nsub, Nroi, Nroi, Nmask).
# That is 450*264*264*8 = 250,905,600 values, ~957 MiB assuming the default
# float32 dtype, which is about the size of the failed allocation in the
# recorded error for this cell.
mp = torch.einsum('aij,im,jm->aijm',ps,masks,masks)
# edges[a,b,m] = sum over i,j of mp[a,i,j,m]*mp[b,i,j,m], shape (Nsub, Nsub, Nmask)
edges = torch.einsum('aijm,bijm->abm',mp,mp)

print('Features and message...')
# Every row of featRep is the same concatenation of all subjects' features
featRep = feat.reshape(1,Nsub*Nfeat).expand(Nsub,Nsub*Nfeat)
# Row a of msg holds subject a's edges to every subject, for every mask
msg = edges.reshape(Nsub,Nsub*Nmask)
totalMsg = torch.cat((msg, featRep), dim=1)

print('Layer 1')
x = F.relu(fc1(totalMsg))
h = fc2(x)

print('Layer 2')
# Same trick as featRep: each row sees the latent vectors of all subjects
h = h.reshape(1,Nsub*Nlat).expand(Nsub,Nsub*Nlat)
h = torch.cat((msg, h), dim=1)
h = F.relu(fc3(h))
age = fc4(h)

print('Loss...')
# Placeholder loss: there is no target here, only terms to backpropagate through
loss = 1-torch.sum(age)
loss += torch.sum(masks)
# Pairwise overlap of distinct mask columns, each unordered pair counted once
for i in range(Nmask):
    for j in range(i+1,Nmask):
        loss += torch.sum(masks[:,i]*masks[:,j])

print('Backward...')
loss.backward()

print('Complete')


Initing tensors...
Einsum...


RuntimeError: CUDA out of memory. Tried to allocate 958.00 MiB (GPU 0; 4.00 GiB total capacity; 2.93 GiB already allocated; 101.19 MiB free; 2.95 GiB reserved in total by PyTorch)

In [20]:
# GCN model

import random

def updateListDict(d, key, item):
    """Append item to the list stored at d[key], creating the list on first use."""
    d.setdefault(key, []).append(item)

Nbatch = 10  # subjects per training batch; also the length of the prediction vector
Nsys = 4  # number of learned system masks
Nroi = train_p_torch[0].shape[0]  # matrix side length, taken from the first training matrix
Ne = 2000  # random subject-subject edges sampled on every forward pass
NeS = 30  # random edges sampled from each batch subject to stored subjects
NmaxE = 50  # at most this many edges per node are packed into its message
Nh = 10  # output width of fc12 (size of a node's hidden state)

class SysGcn(nn.Module):
    """Age predictor built on randomly sampled subject-subject edges.

    Every subject matrix p is masked as M @ p @ M with one of Nsys learnable
    Nroi x Nroi masks M. An edge between two subjects under one mask is the
    sum of the element-wise product of their masked matrices, divided by the
    squared sum of the mask. Each node collects up to NmaxE incident edges
    into a message, which a small MLP turns into a hidden state; the hidden
    state of the batch subject (node -1) is read out as the predicted age.

    Two attributes must be assigned from outside before calling the model:
      ps   - list of subject matrices (torch tensors) forming the stored graph
      feat - per-subject feature rows indexed as feat[subject, column], with
             column 0 = age, 1 = maleness, 2 = femaleness
    """
    def __init__(self):
        super(SysGcn, self).__init__()
        # One learnable mask per system, initialised to the identity
        self.sysM = []
        for i in range(Nsys):
            self.sysM.append(nn.Parameter(torch.eye(Nroi)))
        # Wrapping the list registers the masks so parameters() returns them
        self.sysMp = nn.ParameterList(self.sysM)
        # Node update: flattened (NmaxE x 7) message -> 20 -> Nh
        self.fc11 = nn.Linear(NmaxE*7,20)
        self.fc12 = nn.Linear(20,Nh)
        # Read-out: Nh -> 10 -> 1
        self.fc91 = nn.Linear(Nh,10)
        self.fc92 = nn.Linear(10,1)

    def forward(self, x):
        """Predict the age of every subject in the batch.

        Args:
            x: list of batch items [p, age, m, f] where p is the subject's
               matrix and m / f are its maleness / femaleness. The age entry
               is not read here (it is the prediction target).

        Returns:
            (pred, masks): pred is a tensor of length Nbatch whose first
            len(x) entries are the predictions; masks is the list of mask
            parameters.
        """
        # Pearson matrices and features are set outside the loop
        ps = self.ps
        feat = self.feat
        # Mask all stored matrices once per system
        mSums = []
        psM = []
        for iSys in range(Nsys):
            mask = self.sysM[iSys]
            # Squared sum, used to normalise the edge weights
            mSums.append(torch.sum(mask)**2)
            psM.append([mask@p@mask for p in ps])
        # Sample out of all possible edges between stored subjects.
        # An edge is stored on both endpoints as [weight, system, neighbour].
        baseEdges = {}
        for _ in range(Ne):
            iSys = random.randint(0,Nsys-1)
            i = random.randint(0,len(ps)-1)
            j = random.randint(0,len(ps)-1)
            # Covariance norm
            n = torch.sum(psM[iSys][i]*psM[iSys][j])/mSums[iSys]
            updateListDict(baseEdges, i, [n, iSys, j])
            updateListDict(baseEdges, j, [n, iSys, i])
        # Batching
        pred = torch.zeros(Nbatch)
        for a in range(len(x)):
            print('Done {0:d}'.format(a))
            p = x[a][0]
            m = x[a][2]
            f = x[a][3]
            # Each batch item gets its own copy of the edge lists. Node -1
            # means "the current batch subject"; sharing one dict across
            # items would leave earlier items' -1 edges in place and, once
            # NmaxE is reached, crowd out the current item's own edges.
            es = {node: list(edges) for node, edges in baseEdges.items()}
            # The masked batch matrix depends only on the system, so compute
            # it at most once per system instead of once per sampled edge.
            pMasked = {}
            # Calc edges to NeS other subjects
            for _ in range(NeS):
                iSys = random.randint(0,Nsys-1)
                i = random.randint(0,len(ps)-1)
                if iSys not in pMasked:
                    pMasked[iSys] = self.sysM[iSys]@p@self.sysM[iSys]
                # Covariance norm
                n = torch.sum(psM[iSys][i]*pMasked[iSys])/mSums[iSys]
                updateListDict(es, i, [n, iSys, -1])
                updateListDict(es, -1, [n, iSys, i])
            # Graph convolutions
            # Hidden state: one row per stored subject plus a last row that
            # index -1 addresses (the batch subject). Width follows Nh, the
            # output size of fc12.
            h = torch.zeros(feat.shape[0]+1,Nh)
            # Layer 1
            # NOTE(review): only h[-1] is read below, so the states computed
            # for the other nodes are currently unused.
            for ei in es.keys():
                # Aggregate messages: one row per edge, zero-padded to NmaxE.
                # Columns: weight, system, neighbour age / maleness /
                # femaleness, then this node's own maleness / femaleness.
                msg = torch.zeros(NmaxE, 7)
                for eii, e in enumerate(es[ei][:NmaxE]):
                    msg[eii,0] = e[0]
                    msg[eii,1] = e[1]
                    if e[2] == -1:
                        # The batch subject's age is the target, so it is left at 0
                        msg[eii,2] = 0
                        msg[eii,3] = m
                        msg[eii,4] = f
                    else:
                        msg[eii,2] = feat[e[2],0]
                        msg[eii,3] = feat[e[2],1]
                        msg[eii,4] = feat[e[2],2]
                    if ei == -1:
                        msg[eii,5] = m
                        msg[eii,6] = f
                    else:
                        msg[eii,5] = feat[ei,1]
                        msg[eii,6] = feat[ei,2]
                # Permute rows so the MLP cannot rely on edge order
                msg = msg[torch.randperm(msg.size()[0])]
                # Flatten message
                msg = msg.flatten()
                # Update hidden state
                y = F.relu(self.fc11(msg))
                y = F.relu(self.fc12(y))
                h[ei,:] = y
            # Predict age based on embedding
            y = F.relu(self.fc91(h[-1,:]))
            pred[a] = self.fc92(y)
        return pred, self.sysM
    
# Build the model and an Adam optimiser over everything it registers
# (the system masks and the four linear layers)
sysgcn = SysGcn()
optim = torch.optim.Adam(params=sysgcn.parameters(), lr=1e-3)

print('Complete')

Complete


In [21]:
# Train Gcn

import random

N = len(train_p_torch)  # number of training subjects that batches are drawn from
running = 0  # squared-error loss accumulated since the last report
runningM1 = 0  # first mask-loss term (sum of the masks) accumulated since the last report
runningM2 = 0  # second mask-loss term (pairwise mask products) accumulated since the last report
nEpoch = 1500  # number of optimiser steps; each "epoch" here is one random batch
pPeriod = 1  # print the averaged loss every pPeriod epochs

# Want both discreteness and non-overlapping
def maskLoss(masks):
    """Return the two mask penalty terms (disc, nono) used in the training loss.

    disc is the sum of all entries of all masks. nono is the sum of the
    element-wise products over every ordered pair of distinct masks, so each
    unordered pair contributes twice. Both are the integer 0 when there is
    nothing to add up.
    """
    disc = sum(torch.sum(mask) for mask in masks)
    nono = sum(
        torch.sum(mi*mj)
        for i, mi in enumerate(masks)
        for j, mj in enumerate(masks)
        if i != j
    )
    return disc, nono

# Hand the model the stored graph it samples edges from
sysgcn.ps = train_p_torch
sysgcn.feat = X_train

sinceReport = 0  # optimiser steps accumulated into the running sums
for epoch in range(nEpoch):
    # Draw Nbatch training subjects uniformly at random, with replacement
    picks = [random.randint(0,N-1) for _ in range(Nbatch)]
    batch = [[train_p_torch[idx], X_train[idx,0], X_train[idx,1], X_train[idx,2]] for idx in picks]
    # Column 0 of the features (age in months) is the regression target
    truth = torch.zeros(Nbatch)
    for i, idx in enumerate(picks):
        truth[i] = X_train[idx, 0]
    optim.zero_grad()
    pred, masks = sysgcn(batch)
    l1 = torch.sum((truth-pred)**2)
    l2, l3 = maskLoss(masks)
    running += l1.detach()
    runningM1 += l2.detach()
    runningM2 += l3.detach()
    sinceReport += 1
    loss = l1 + l2 + l3
    print('Going backwards')
    loss.backward()
    optim.step()
    # Report on every pPeriod-th epoch and on the very last one
    if epoch % pPeriod == 0 or epoch == nEpoch-1:
        # Average over the steps taken since the previous report
        running /= sinceReport
        runningM1 /= sinceReport
        runningM2 /= sinceReport
        print('epoch {:d} loss={:f}'.format(epoch, running))
        running = 0
        runningM1 = 0
        runningM2 = 0
        sinceReport = 0

print('Finished training')

Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 0 loss=371961.625000
Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 1 loss=295143.625000
Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 2 loss=314101.812500
Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 3 loss=244405.062500
Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 4 loss=289488.375000
Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 5 loss=389040.531250
Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 6 loss=393300.312500
Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 7 loss=323468.656250
Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7
Done 8
Done 9
Going backwards
epoch 8 lo

KeyboardInterrupt: 