In [1]:
import pickle

para = 'Rest'

# Per-subject metadata dictionary; later cells look entries up as meta[subid].
# NOTE(review): pickle.load executes arbitrary code - only use trusted files.
with open('../../AllSubjectsMeta.bin', 'rb') as f:
    meta = pickle.load(f)

# Cross-validation split file for this paradigm: the candidate subject ids
# (presumably those with WRAT scores, per the key name) and the list of
# (train, test) index groups iterated over by the training loop.
with open(f'../../../BrainAgePredictWrat/{para}Only10FoldSplit.bin', 'rb') as f:
    splits = pickle.load(f)
subids = splits[f'cand{para}YesWrat']
groups = splits['groups']

print(len(subids))

804


In [2]:
import numpy as np

# Subject ids as a numpy array, plus the paradigm name used both in the data
# directory ('<paradigm>_fmri_power264') and as the per-subject key in meta.
subidsNp = np.array(subids)
paradigm = 'rest'

# Load timeseries

def loadSeries(prefix, para, idx):
    """Unpickle and return one stored timeseries object.

    The file read is ``<prefix>/<para>_fmri_power264/timeseries/<idx>.bin``.
    NOTE(review): pickle.load can execute arbitrary code; only load trusted data.
    """
    seriesPath = f'{prefix}/{para}_fmri_power264/timeseries/{idx}.bin'
    with open(seriesPath, 'rb') as handle:
        series = pickle.load(handle)
    return series

# Read every subject's timeseries; meta[subid][paradigm] supplies the id used
# as the timeseries file name.
all_ts = list(
    loadSeries('../..', paradigm, meta[subid][paradigm])
    for subid in subidsNp
)

print('Loading complete')

Loading complete


In [3]:
# Normalize timeseries

def normalizeSubjects(subjects):
    """Z-score every row of every subject's timeseries, in place.

    Each element of ``subjects`` is a 2-D array that is modified in place so
    that every row has mean 0 and standard deviation 1 along axis 1. Because
    the update is in place, the arrays must be floating point.

    Subjects whose normalised data contains NaN (e.g. a constant row, whose
    standard deviation is 0, giving 0/0) or inf are reported by printing their
    index; they are not repaired.

    Args:
        subjects: sequence of 2-D float numpy arrays; mutated in place.

    Returns:
        None.
    """
    for i, subj in enumerate(subjects):
        # (rows, 1) statistics broadcast across the time axis; no need to build
        # a full (rows, T) matrix via an outer product with a ones vector.
        subj -= subj.mean(axis=1, keepdims=True)
        subj /= subj.std(axis=1, keepdims=True)
        if np.isnan(subj).any():
            print('nan {:}'.format(i))
        if np.isinf(subj).any():
            print('inf {:}'.format(i))
            
# Z-score each subject's timeseries in place (rows normalised along axis 1).
normalizeSubjects(subjects=all_ts)
print('Complete')

Complete


In [4]:
# Create feature vectors (right now just ages, maleness, and femaleness):
# one row per subject = [AgeInMonths, is-male, is-not-male].
# NOTE(review): any Gender value other than 'M' (including missing/unknown
# codes, if present) is encoded as female.
X_all = np.vstack([
    np.array([info['AgeInMonths'],
              int(info['Gender'] == 'M'),
              int(info['Gender'] != 'M')])
    for info in (meta[subid]['meta'] for subid in subidsNp)
])

print(X_all[10:20])
print('Complete')

[[223   1   0]
 [190   0   1]
 [197   0   1]
 [145   1   0]
 [148   0   1]
 [142   0   1]
 [123   1   0]
 [176   1   0]
 [129   0   1]
 [173   1   0]]
Complete


In [5]:
# Calculate pearson matrices

# One row-by-row Pearson correlation matrix per subject, stacked on axis 0.
# NOTE(review): Pearson correlation is invariant to per-row z-scoring, so the
# earlier normalisation does not change these values.
all_p = np.stack(list(map(np.corrcoef, all_ts)))

print(all_p.shape)

(804, 264, 264)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

nRoi = 264  # number of ROIs; width of each token fed to the transformer


class Trans(nn.Module):
    """Transformer regressor mapping one ROI-by-ROI matrix to a single scalar.

    The input's columns are used as tokens of width ``nRoi``. The same sequence
    is fed as both the encoder source and the decoder target, and the decoder
    output at sequence position 0 goes through a two-layer MLP head.

    Args:
        device: where the parameters live. Defaults to ``'cuda'`` (the original
            behaviour); pass ``'cpu'`` to build the model without a GPU.
    """

    def __init__(self, device='cuda'):
        super(Trans, self).__init__()
        # nhead must divide d_model (264 / 8 = 33).
        self.trans = nn.Transformer(d_model=nRoi, nhead=8,
                                    num_encoder_layers=2, num_decoder_layers=2,
                                    dim_feedforward=100, dropout=0.1)
        self.fc1 = nn.Linear(nRoi, 50)
        self.fc2 = nn.Linear(50, 1)
        # Move all submodules in one call instead of .cuda() on each piece.
        self.to(device)

    def forward(self, x):
        """Map x of shape (batch, nRoi, seq_len) to predictions of shape (batch, 1).

        In this notebook x is a batch of (nRoi, nRoi) correlation matrices.
        """
        # (batch, rows, cols) -> (cols, batch, rows): nn.Transformer defaults to
        # sequence-first input. For symmetric matrices, column j equals row j.
        x = x.permute(2, 0, 1)
        x = self.trans(x, x)
        # Decoder output at the first sequence position summarises the input.
        x = x[0, :, :]
        x = F.relu(self.fc1(x))
        return self.fc2(x)

import random

# Fixed seed so weight init, dropout masks and batch sampling are repeatable.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

all_p_t = torch.from_numpy(all_p).float()
# One out-of-fold prediction per subject, in the same row order as all_p.
agesPredTrans = torch.zeros(all_p_t.shape[0])

nEpoch = 3500   # optimisation steps per split (one random mini-batch each)
pPeriod = 100   # report the averaged training loss every pPeriod steps
nB = 10         # mini-batch size, used for training and test-time inference

# Loop over splits
for splitIdx, (trainIdcs, testIdcs) in enumerate(groups):
    trans = Trans()
    optim = torch.optim.Adam(trans.parameters(), lr=2e-4)

    train_p = all_p_t[trainIdcs].cuda()
    test_p = all_p_t[testIdcs]

    # Regression target: column 0 of X_all (AgeInMonths).
    X_train = X_all[trainIdcs, 0]

    running = 0.0
    N = train_p.shape[0]

    trans.train()
    for epoch in range(nEpoch):
        # Sample nB training subjects with replacement.
        idxs = [random.randint(0, N - 1) for _ in range(nB)]
        batch = train_p[idxs]
        truth = torch.as_tensor(X_train[idxs], dtype=torch.float32, device='cuda')
        optim.zero_grad()
        pred = trans(batch).flatten()
        loss = torch.sum((pred - truth) ** 2)
        loss.backward()
        running += loss.item()
        optim.step()
        if epoch % pPeriod == 0 or epoch == nEpoch - 1:
            if epoch != 0:
                # Average over the steps accumulated since the last report: a
                # full period, or the shorter tail at the final step.
                if epoch % pPeriod != 0:
                    running /= epoch % pPeriod
                else:
                    running /= pPeriod
            # sqrt(per-sample squared error) is in months; /12 gives years.
            print('epoch {:d} loss={:f}'.format(epoch, ((running / nB) ** 0.5) / 12))
            running = 0.0

    print('Finished training')

    # Inference: eval() disables dropout, no_grad() skips autograd bookkeeping.
    trans.eval()
    with torch.no_grad():
        for st in range(0, test_p.shape[0], nB):
            end = min(st + nB, test_p.shape[0])
            test_p_b = test_p[st:end].cuda()
            agesPredTrans[testIdcs[st:end]] = trans(test_p_b).flatten().cpu()

    print('Completed split {:}'.format(splitIdx))

print('All complete')