In [11]:
# PNC

import pickle
import re
import numpy as np

basis_file = '/home/anton/Documents/Tulane/Research/PNC_Good/AngleBasisNoJit1.pkl'
demodir = '/home/anton/Documents/Tulane/Research/ImageNomer/data/anton/cohorts/PNC/'

basis = pickle.load(open(basis_file, 'rb'))

thetas = []
jitter = []
age = []
sex = []
race = []
fc = []

demo = pickle.load(open(f'{demodir}/demographics.pkl', 'rb'))

for subtask in basis:
    m = re.search('([^-]+)-(.+)', subtask)
    sub = m.group(1)
    task = m.group(2)
    if sub not in demo['Race'] or demo['Race'][sub] not in ['AA', 'EA']:
        continue
    a = demo['age_at_cnb'][sub]
    s = demo['Sex'][sub] == 'M'
    r = demo['Race'][sub] == 'AA'
    age.append(a)
    sex.append(s)
    race.append(r)
    thetas.append(basis[subtask]['thetas'])
    jitter.append(basis[subtask]['jitter'])
    p = np.load(f'{demodir}/fc/{sub}_task-{task}_fc.npy')
    fc.append(p)
    
thetas = np.stack(thetas)
jitter = np.stack(jitter)
age = np.array(age).astype('int')
sex = np.array(sex).astype('int')
race = np.array(race).astype('int')
fc = np.stack(fc)

print([a.shape for a in [thetas, jitter, sex, race, age, fc]])
print(np.mean(1-sex))
print(np.mean(1-race))
print(np.mean(age))

[(3849, 1, 264), (3849, 1, 264), (3849,), (3849,), (3849,), (3849, 34716)]
0.5263704858404781
0.5188360613146272
14.398285268901013


In [2]:
# Jitter Only

import matplotlib.pyplot as plt

def mat2vec(mat):
    a,b = np.triu_indices(264,1)
    return mat[:,a,b]

a,b = np.triu_indices(264,1)
X = np.ones((fc.shape[0],264,264))
X[:,a,b] = fc
X[:,b,a] = fc
w, v = np.linalg.eig(X)
print('Done eigs')

w[:,20:] = 0
aps20 = np.real(np.einsum('nab,nb,ncb->nac',v,w,v))
aps20 = mat2vec(aps20)
print('Done 20')

w[:,15:] = 0
aps15 = np.real(np.einsum('nab,nb,ncb->nac',v,w,v))
aps15 = mat2vec(aps15)
print('Done 15')

w[:,10:] = 0
aps10 = np.real(np.einsum('nab,nb,ncb->nac',v,w,v))
aps10 = mat2vec(aps10)
print('Done 10')

w[:,5:] = 0
aps5 = np.real(np.einsum('nab,nb,ncb->nac',v,w,v))
aps5 = mat2vec(aps5)
print('Done 5')

w[:,3:] = 0
aps3 = np.real(np.einsum('nab,nb,ncb->nac',v,w,v))
aps3 = mat2vec(aps3)
print('Done 3')

w[:,1:] = 0
aps1 = np.real(np.einsum('nab,nb,ncb->nac',v,w,v))
aps1 = mat2vec(aps1)
print('Done 1')

w = None
v = None
X = None

Done eigs
Done 20
Done 15
Done 10
Done 5
Done 3
Done 1


In [12]:
# Angle only or regular AngleBasis

def rmse(yhat, y):
    if isinstance(yhat, np.ndarray) or isinstance(yhat, int):
        f = np.mean
    else:
        f = torch.mean
    return f((y-yhat)**2)**0.5

def tops(thetas, jitter):
    t0 = np.expand_dims(thetas, 2)
    t1 = np.expand_dims(thetas, 3)
    j0 = np.expand_dims(jitter, 2)
    j1 = np.expand_dims(jitter, 3)
    ps = np.cos(t0-t1)*(j0*j1)
    a,b = np.triu_indices(264, 1)
    ps = ps[:,:,a,b]
    return ps

paps = []
pres = []

for i in range(0,fc.shape[0],500):
    ps = tops(thetas[i:i+500], jitter[i:i+500])
    aps = np.mean(ps, axis=1)
    res = fc[i:i+500] - aps
    paps.append(aps)
    pres.append(res)
    print(f'Done {i}')
    
aps = np.concatenate(paps)
res = np.concatenate(pres)

print(ps.shape)
print(aps.shape)
print(res.shape)

aps1 = aps

Done 0
Done 500
Done 1000
Done 1500
Done 2000
Done 2500
Done 3000
Done 3500
(349, 1, 34716)
(3849, 34716)
(3849, 34716)


In [13]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

def tonp(x):
    return x.detach().cpu().numpy()

def totorch(x):
    return torch.from_numpy(x).float().cuda()

def totorchidcs(x):
    return torch.from_numpy(x).long().cuda() #F.one_hot(torch.from_numpy(x)).float().cuda()

def rmse(yt, yhat):
    return torch.mean((yt-yhat)**2)**0.5

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(34716,100).float().cuda()
        self.l2 = nn.Linear(100,2).float().cuda()
        
    def forward(self, x):
        x = F.relu(self.l1(x))
        x = self.l2(x)
        return x
    
def fit_mlp(xtr, ytr, verbose=False):
    xtr = totorch(xtr)
    ytr = totorchidcs(ytr)
    
    ce = nn.CrossEntropyLoss()
    mlp = MLP()
    optim = torch.optim.Adam(mlp.parameters(), lr=5e-4, weight_decay=5e-4)

    nepochs = 1000
    pperiod = 100

    for epoch in range(nepochs):
        optim.zero_grad()
        yhat = mlp(xtr)
        loss = ce(yhat, ytr)
        loss.backward()
        optim.step()
        if verbose:
            if epoch % pperiod == 0 or epoch == nepochs-1:
                print(f'{epoch} {float(loss)}')

    if verbose:
        print('Complete')
    
    return mlp
    
def predict_help(model, xt):
    xt = totorch(xt)
    with torch.no_grad():
        yhat = model(xt)
        return tonp(yhat)
    
print('Done')

Done


In [15]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.metrics import roc_auc_score

def cat(x, **kwargs):
    return np.concatenate(x, **kwargs)

def rmse(yhat, y):
    if isinstance(yhat, np.ndarray) or isinstance(yhat, int):
        f = np.mean
    else:
        f = torch.mean
    return f((y-yhat)**2)**0.5

def predict(xtr, xt, ytr, yt, lst):
    reg = fit_mlp(xtr, ytr)
    p = predict_help(reg, xt)
    acc = roc_auc_score(yt, p[:,1])
    print(acc)
    lst.append(acc)
    return p
    
def get_res(fctr, fct, abtr, abt):
    return fctr-abtr, fct-abt

def combine(yt, p0, p1, lst):
    p = (p0+p1)/2
    acc = roc_auc_score(yt, p[:,1])
    print(acc)
    lst.append(acc)

rfc = []

rab1 = []
rab3 = []
rab5 = []
rab10 = []
rab15 = []
rab20 = []

rres1 = []
rres3 = []
rres5 = []
rres10 = []
rres15 = []
rres20 = []

rens1 = []
rens3 = []
rens5 = []
rens10 = []
rens15 = []
rens20 = []

rbest = []

for i in range(10):

    x0tr, x0t, x1tr, x1t, x2tr, x2t, x3tr, x3t, x4tr, x4t, x5tr, x5t, x6tr, x6t, ytr, yt = train_test_split(
        fc, aps1, aps3, aps5, aps10, aps15, aps20, sex, stratify=sex, train_size=0.8)
    
    x1atr, x1at = get_res(x0tr, x0t, x1tr, x1t)
    x2atr, x2at = get_res(x0tr, x0t, x2tr, x2t)
    x3atr, x3at = get_res(x0tr, x0t, x3tr, x3t)
    x4atr, x4at = get_res(x0tr, x0t, x4tr, x4t)
    x5atr, x5at = get_res(x0tr, x0t, x5tr, x5t)
    x6atr, x6at = get_res(x0tr, x0t, x6tr, x6t)

    predict(x0tr, x0t, ytr, yt, rfc)
    
    p1 = predict(x1tr, x1t, ytr, yt, rab1)
    p2 = predict(x2tr, x2t, ytr, yt, rab3)
    p3 = predict(x3tr, x3t, ytr, yt, rab5)
    p4 = predict(x4tr, x4t, ytr, yt, rab10)
    p5 = predict(x5tr, x5t, ytr, yt, rab15)
    p6 = predict(x6tr, x6t, ytr, yt, rab20)
    
    p1a = predict(x1atr, x1at, ytr, yt, rres1)
    p2a = predict(x2atr, x2at, ytr, yt, rres3)
    p3a = predict(x3atr, x3at, ytr, yt, rres5)
    p4a = predict(x4atr, x4at, ytr, yt, rres10)
    p5a = predict(x5atr, x5at, ytr, yt, rres15)
    p6a = predict(x6atr, x6at, ytr, yt, rres20)
    
    combine(yt, p1, p1a, rens1)
    combine(yt, p2, p2a, rens3)
    combine(yt, p3, p3a, rens5)
    combine(yt, p4, p4a, rens10)
    combine(yt, p5, p5a, rens15)
    combine(yt, p6, p6a, rens20)

    combine(yt, p6, p1a, rbest)
    
    print('---')
    
print(np.mean(rfc), np.std(rfc))
print(np.mean(rbest), np.std(rbest))

print(np.mean(rab1), np.std(rab1))
print(np.mean(rab3), np.std(rab3))
print(np.mean(rab5), np.std(rab5))
print(np.mean(rab10), np.std(rab10))
print(np.mean(rab15), np.std(rab15))
print(np.mean(rab20), np.std(rab20))

print(np.mean(rres1), np.std(rres1))
print(np.mean(rres3), np.std(rres3))
print(np.mean(rres5), np.std(rres5))
print(np.mean(rres10), np.std(rres10))
print(np.mean(rres15), np.std(rres15))
print(np.mean(rres20), np.std(rres20))

print(np.mean(rens1), np.std(rens1))
print(np.mean(rens3), np.std(rens3))
print(np.mean(rens5), np.std(rens5))
print(np.mean(rens10), np.std(rens10))
print(np.mean(rens15), np.std(rens15))
print(np.mean(rens20), np.std(rens20))

0.9420463385760189
0.7366886521224421
0.7154540842212074
0.8994892609504481
0.9452731270082868
0.947945205479452
0.9451243023845763
0.6974395399966177
0.7520987654320987
0.7483713850837139
0.7544529003889734
0.7334415694233046
0.7302350752579063
0.7315542026044308
0.7758836462032809
0.8596313208185354
0.9231117875866734
0.9304447826822256
0.9392186707255201
0.8880297649247421
---
0.9486622695755116
0.7357957043801793
0.8791002875021139
0.9140267207847117
0.9248165060037206
0.9473431422289871
0.9503263994588195
0.7053948926095045
0.7157179096905124
0.758836462032809
0.7543784880771182
0.7466260781329275
0.7209267715203789
0.7339556908506679
0.8174260104853712
0.8862912227295789
0.9073296127177406
0.9401792660240149
0.9430610519194995
0.8852697446304751
---
0.9439607644173853
0.7140402502959581
0.8544089294774226
0.8990630813461864
0.9264400473532893
0.9424792829359039
0.9463825469304922
0.6568104177236598
0.7136952477591747
0.7392727887705056
0.7899205141214272
0.7287468290208016
0.7225