In [5]:
# PNC

import pickle
import re
import numpy as np

basis_file = '/home/anton/Documents/Tulane/Research/PNC_Good/AngleBasis1.pkl'
demodir = '/home/anton/Documents/Tulane/Research/ImageNomer/data/anton/cohorts/PNC/'

basis = pickle.load(open(basis_file, 'rb'))

thetas = []
jitter = []
age = []
sex = []
race = []
fc = []

demo = pickle.load(open(f'{demodir}/demographics.pkl', 'rb'))

for subtask in basis:
    m = re.search('([^-]+)-(.+)', subtask)
    sub = m.group(1)
    task = m.group(2)
    if sub not in demo['Race'] or demo['Race'][sub] not in ['AA', 'EA']:
        continue
    a = demo['age_at_cnb'][sub]
    s = demo['Sex'][sub] == 'M'
    r = demo['Race'][sub] == 'AA'
    age.append(a)
    sex.append(s)
    race.append(r)
    thetas.append(basis[subtask]['thetas'])
    jitter.append(basis[subtask]['jitter'])
    p = np.load(f'{demodir}/fc/{sub}_task-{task}_fc.npy')
    fc.append(p)
    
thetas = np.stack(thetas)
jitter = np.stack(jitter)
age = np.array(age).astype('int')
sex = np.array(sex).astype('int')
race = np.array(race).astype('int')
fc = np.stack(fc)

print([a.shape for a in [thetas, jitter, sex, race, age, fc]])
print(np.mean(1-sex))
print(np.mean(1-race))
print(np.mean(age))

[(3849, 1, 264), (3849, 1, 264), (3849,), (3849,), (3849,), (3849, 34716)]
0.5263704858404781
0.5188360613146272
14.398285268901013


In [6]:
def rmse(yhat, y):
    if isinstance(yhat, np.ndarray) or isinstance(yhat, int):
        f = np.mean
    else:
        f = torch.mean
    return f((y-yhat)**2)**0.5

def tops(thetas, jitter):
    t0 = np.expand_dims(thetas, 2)
    t1 = np.expand_dims(thetas, 3)
    j0 = np.expand_dims(jitter, 2)
    j1 = np.expand_dims(jitter, 3)
    ps = np.cos(t0-t1)*(j0*j1)
    a,b = np.triu_indices(264, 1)
    ps = ps[:,:,a,b]
    return ps

paps = []
pres = []

for i in range(0,3849,500):
    ps = tops(thetas[i:i+500], jitter[i:i+500])
    aps = np.mean(ps, axis=1)
    res = fc[i:i+500] - aps
    paps.append(aps)
    pres.append(res)
    print(f'Done {i}')
    
aps = np.concatenate(paps)
res = np.concatenate(pres)

print(ps.shape)
print(aps.shape)
print(res.shape)

Done 0
Done 500
Done 1000
Done 1500
Done 2000
Done 2500
Done 3000
Done 3500
(349, 1, 34716)
(3849, 34716)
(3849, 34716)


In [7]:
aps1 = aps

print('Done')

Done


In [11]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

def tonp(x):
    return x.detach().cpu().numpy()

def totorch(x):
    return torch.from_numpy(x).float().cuda()

def totorchidcs(x):
    return torch.from_numpy(x).long().cuda() #F.one_hot(torch.from_numpy(x)).float().cuda()

def rmse(yt, yhat):
    return torch.mean((yt-yhat)**2)**0.5

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(34716,100).float().cuda()
        self.l2 = nn.Linear(100,2).float().cuda()
        
    def forward(self, x):
        x = F.relu(self.l1(x))
        x = self.l2(x)
        return x
    
def fit_mlp(xtr, ytr):
    xtr = totorch(xtr)
    ytr = totorchidcs(ytr)
    
    ce = nn.CrossEntropyLoss()
    mlp = MLP()
    optim = torch.optim.Adam(mlp.parameters(), lr=1e-3, weight_decay=1e-3)

    nepochs = 1000
    pperiod = 100

    for epoch in range(nepochs):
        optim.zero_grad()
        yhat = mlp(xtr)
        loss = ce(yhat, ytr)
        loss.backward()
        optim.step()
#         if epoch % pperiod == 0 or epoch == nepochs-1:
#             print(f'{epoch} {float(loss)}')

#     print('Complete')
    
    return mlp

class BrainNetCNN(nn.Module):
    def __init__(self):
        super(BrainNetCNN, self).__init__()
        self.cnn1 = nn.Conv1d(264,10,5).float().cuda()
        self.cnn2 = nn.Conv1d(10,10,5).float().cuda()
        self.pool = nn.AvgPool1d(200,ceil_mode=True).float().cuda()
        self.l1 = nn.Linear(10,2).float().cuda()
        
    def forward(self, x):
        x = F.relu(self.cnn1(x))
        x = F.relu(self.cnn2(x))
        x = self.pool(x).squeeze()
        x = self.l1(x)
        return x
      
def fit_cnn(xtr, ytr):
    xtr = totorch(xtr)
    ytr = totorchidcs(ytr)
    
    ce = nn.CrossEntropyLoss()
    bnc = BrainNetCNN()
    optim = torch.optim.Adam(bnc.parameters(), lr=1e-3, weight_decay=1e-3)

    nepochs = 1000
    pperiod = 100

    for epoch in range(nepochs):
        optim.zero_grad()
        yhat = bnc(xtr)
        loss = ce(yhat, ytr)
        loss.backward()
        optim.step()
#         if epoch % pperiod == 0 or epoch == nepochs-1:
#             print(f'{epoch} {float(loss)}')

#     print('Complete')
    
    return bnc
    
def predict_help(model, xt):
    xt = totorch(xt)
    with torch.no_grad():
        yhat = model(xt)
        return tonp(yhat)
    
print('Done')

Done


In [13]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.metrics import roc_auc_score

def cat(x, **kwargs):
    return np.concatenate(x, **kwargs)

def rmse(yhat, y):
    if isinstance(yhat, np.ndarray) or isinstance(yhat, int):
        f = np.mean
    else:
        f = torch.mean
    return f((y-yhat)**2)**0.5

def predict(xtr, xt, ytr, yt, lst):
    reg = fit_mlp(xtr, ytr)
    p = predict_help(reg, xt)
    acc = roc_auc_score(yt, p[:,1])
    print(acc)
    lst.append(acc)
    return p
    
def get_res(fctr, fct, abtr, abt):
    return fctr-abtr, fct-abt

def combine(yt, p0, p1, lst):
    p = (p0+p1)/2
    acc = roc_auc_score(yt, p[:,1])
    print(acc)
    lst.append(acc)

rfc = []

rab1 = []
rab20 = []

rres1 = []
rres20 = []

rbest = []

for i in range(5):

    x0tr, x0t, x1tr, x1t, x2tr, x2t, ytr, yt = train_test_split(
        fc, aps1, aps20, sex, stratify=sex, train_size=0.8)
    
#     mu = np.mean(ytr)
#     ytr = ytr-mu
#     yt = yt-mu
    
    x1atr, x1at = get_res(x0tr, x0t, x1tr, x1t)
    x2atr, x2at = get_res(x0tr, x0t, x2tr, x2t)

    predict(x0tr, x0t, ytr, yt, rfc)
    
    p1 = predict(x1tr, x1t, ytr, yt, rab1)
    p2 = predict(x2tr, x2t, ytr, yt, rab20)
    
    p1a = predict(x1atr, x1at, ytr, yt, rres1)
    p2a = predict(x2atr, x2at, ytr, yt, rres20)

    combine(yt, p2, p1a, rbest)
    
    print('---')
    
print(np.mean(rfc), np.std(rfc))
print(np.mean(rbest), np.std(rbest))

print(np.mean(rab1), np.std(rab1))
print(np.mean(rab20), np.std(rab20))

print(np.mean(rres1), np.std(rres1))
print(np.mean(rres20), np.std(rres20))

0.9577608658887198
0.8117774395399966
0.9068560798241163
0.9569220361914426
0.9252359208523592
0.9658785726365635
---
0.8150076103500761
0.7846101809572129
0.9424792829359039
0.9686656519533232
0.9199594114662608
0.9766480635887028
---
0.9487096228648739
0.6993066125486216
0.9028310502283106
0.9513478775579233
0.9159884999154405
0.9633891425672247
---
0.947911381701336
0.8195569085066802
0.9402266193133774
0.9541011330965669
0.9178893962455605
0.9659868087265346
---
0.8222695755115845
0.792673769660071
0.9242279722645019
0.9520446473871131
0.9245594452900389
0.9612108912565533
---
0.898331811263318 0.06520162023390498
0.9666226957551158 0.0053143121074506776
0.7815849822425165 0.04302393476894367
0.923324200913242 0.01639874357946309
0.9566162692372739 0.006328707676963734
0.920726534753932 0.0036362775842176285
