In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import random

class VAE(nn.Module):
    def __init__(self, fcd, ld, matsz, rank):
        super(VAE, self).__init__()
        self.fcd = fcd
        self.ld = ld
        self.matsz = matsz
        self.rank = rank
        self.enc1 = nn.Linear(fcd, 1000).float().cuda()
        self.enc2 = nn.Linear(1000, ld).float().cuda()
        self.dec1 = nn.Linear(ld+6, 1000).float().cuda()
        self.dec2 = nn.Linear(1000, matsz*rank).float().cuda()

    def enc(self, x):
        x = F.relu(self.enc1(x))
        z = self.enc2(x)
        return z

    def gen(self, n):
        return torch.randn(n, self.ld).float().cuda()/(10**0.5)

    def dec(self, z, age, sex, race, rest, nback, emoid):
        z = torch.cat([z, age.unsqueeze(1), sex.unsqueeze(1), race.unsqueeze(1), 
                       rest.unsqueeze(1), nback.unsqueeze(1), emoid.unsqueeze(1)], dim=1)
        x = F.relu(self.dec1(z))
        x = self.dec2(x)
        x = x.reshape(len(z), self.matsz, self.rank)
        x = torch.einsum('abc,adc->abd', x, x)
        return x

    def vectorize(self, x):
        a,b = np.triu_indices(self.matsz, 1)
        return x[:,a,b]

def rmse(a, b, mean=torch.mean):
    return mean((a-b)**2)**0.5

def pretty(x):
    return f'{round(float(x), 4)}'

vae = VAE(34716, 30, 264, 5)
vae.load_state_dict(torch.load('/home/anton/Documents/Tulane/Research/ImageNomer/data/PNC/vae_1000_z30_cov6_264_rank5.torch'))
vae.eval()

print('Done')

Done


In [2]:
ours2orig = [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 254, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 85,
86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103,
104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 136, 138, 132,
133, 134, 135, 220, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152,
153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 185, 186,
187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201,
202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216,
217, 218, 219, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
233, 137, 234, 235, 236, 237, 238, 239, 240, 241, 250, 251, 255, 256, 257,
258, 259, 260, 261, 262, 263, 242, 243, 244, 245, 0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 83, 84, 131, 139, 140, 141, 181, 182, 183, 184, 246, 247, 248,
249, 252, 253]

def vec2mat(v):
    a,b = np.triu_indices(264,1)
    m = np.zeros((264,264))
    m[a,b] = v
    return m+m.T

def remap(fc, roimap=ours2orig):
    fc = fc[roimap,:]
    fc = fc[:,roimap]
    return fc

print('Complete')

Complete


In [3]:
# Load FC

import pickle
import numpy as np

pncdir = '/home/anton/Documents/Tulane/Research/ImageNomer/data/anton/cohorts/PNC/'
demo = pickle.load(open(f'{pncdir}/demographics.pkl', 'rb'))

rest = []
nback = []
emoid = []
race = []
sex = []
age = []
wrat = []
subids = []

a,b = np.triu_indices(264,1)

for sub in demo['age_at_cnb']:
    try:
        ra = demo['Race'][sub]
        ag = demo['age_at_cnb'][sub]
        se = demo['Sex'][sub]
        if ra not in ['AA', 'EA']:
            continue
        ra = ra == 'AA'
        se = se == 'M'
        r = np.load(f'{pncdir}/fc/{sub}_task-rest_fc.npy')
        n = np.load(f'{pncdir}/fc/{sub}_task-nback_fc.npy')
        e = np.load(f'{pncdir}/fc/{sub}_task-emoid_fc.npy')
        race.append(ra)
        sex.append(se)
        age.append(ag)
        rest.append(remap(vec2mat(r))[a,b])
        nback.append(remap(vec2mat(n))[a,b])
        emoid.append(remap(vec2mat(e))[a,b])
        subids.append(sub)
    except:
        pass

rest = np.stack(rest)
nback = np.stack(nback)
emoid = np.stack(emoid)
race = np.array(race).astype('int')
sex = np.array(sex).astype('int')
age = np.array(age)

# age = (age - np.mean(age)) / np.std(age)

print([a.shape for a in [rest, nback, emoid, race, sex, age]])

[(1193, 34716), (1193, 34716), (1193, 34716), (1193,), (1193,), (1193,)]


In [4]:
rest_t = torch.from_numpy(rest).float().cuda()
nback_t = torch.from_numpy(nback).float().cuda()
emoid_t = torch.from_numpy(emoid).float().cuda()

with torch.no_grad():
    zr = vae.enc(rest_t).detach().cpu().numpy()
    zn = vae.enc(nback_t).detach().cpu().numpy()
    ze = vae.enc(emoid_t).detach().cpu().numpy()

print('Done')

Done


In [23]:
import scipy.stats as stats

def corr(z, y):
    z = z - np.mean(z, axis=0, keepdims=True)
    y = y - np.mean(y)
    xx = np.einsum('na,na->a', z, z)
    yy = np.einsum('n,n->', y, y)
    xy = np.einsum('na,n->a', z, y)
    rho = xy/((xx*yy)**0.5)
    n = z.shape[0]
    m = z.shape[1]
    df = n-2
    t = rho*(df/(1-rho**2))**0.5
    t[t < 0] = -t[t < 0]
    p = (1-stats.t.cdf(t, df))*2
    p *= m
    p[p > 1] = 1
    return rho, p

def to_cat(y):
    cats = set()
    for yy in y:
        cats.add(yy)
    cats = list(cats)
    yp = np.zeros(len(y))
    for j,yy in enumerate(y):
        for i,cat in enumerate(cats):
            if yy == cat:
                yp[j] = i
    return yp

n = 0
m = 0
for key in demo.keys():
    z = []
    y = []
    for i,sub in enumerate(subids):
        if sub in demo[key]:
            z.append(rest[i])
            y.append(demo[key][sub])
    if len(z) == 0:
        continue
    n += 1
    nn = len(z)
    z = np.stack(z)
    y = np.array(y)
    if isinstance(y[0], str):
        y = to_cat(y)
    rho, p = corr(z, y)
    if np.any(p < 0.1):
        m += 1
        print(key, 'yes', min(p), nn)
print(n, m)

Race yes 0.0 1193
Sex yes 0.0 1193
age_at_cnb yes 0.0 1193
battery_valid yes 0.09239512441931996 1193
PADT_GENUS yes 0.0014393818898970068 1193
PADT_A yes 1.736640370175735e-06 1181
PADT_SAME_CR yes 0.0002654517811242485 1181
PADT_PC yes 1.1487824131961588e-06 1181
PADT_SAME_PC yes 0.030696678943779254 1181


  rho = xy/((xx*yy)**0.5)


PFMT_TP yes 0.02984431503097884 1189
PFMT_TN yes 0.018339190071529288 1189
PFMT_FP yes 0.018339190071529288 1189
PFMT_FN yes 0.02984431503097884 1189
PFMT_TPRT yes 0.004802570203758627 1189
PFMT_TNRT yes 0.09696825683297572 1186
PFMT_IFAC_TOT yes 0.00010057988248668437 1189
PFMT_IFAC_RTC yes 0.03324954877499664 1189
PEIT_GENUS yes 0.028838423796607948 1193
PEIT_CRT yes 0.01909942636472728 1193
PEITSAD yes 0.04159262360175653 1193
PEITHAPRT yes 0.0011952481287966776 1193
PWMT_TPRT yes 3.450162176488192e-05 1190
PWMT_TNRT yes 0.00410338601311544 1187
PWMT_KIWRD_TOT yes 0.052002551504052974 1190
PWMT_KIWRD_RTC yes 3.900932931344414e-05 1190
PVRT_CR yes 0.0 1188
PVRT_RTCR yes 7.63755686188361e-05 1187
PEDT_GENUS yes 0.0014393818898970068 1193
PEDT_HAP_CR yes 0.00018457016228445866 1187
PEDT_SAD_CR yes 0.0011532987465567857 1187
PEDT_ANG_CR yes 8.116702607097181e-05 1187
PEDT_FEAR_CR yes 0.002935194189203294 1187
PEDT_SAME_CR yes 0.0015971001073538105 1187
PEDT_A yes 0.00021468126883394945 

In [46]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge, LogisticRegression

for alpha in [0.01, 0.1,1,10,100,1000]:
    errs = []
    nulls = []

    for _ in range(10):
        xtr, xt, ytr, yt = train_test_split(ze, race, stratify=race, train_size=0.8)
        
        # reg = Ridge(alpha=alpha).fit(xtr, ytr)
        reg = LogisticRegression(C=alpha, max_iter=100).fit(xtr, ytr)
        yhat = reg.predict(xt)
        err = np.mean(yt == yhat)
        null = np.mean(yt)
        if null < 0.5:
            null = 1-null
        # err = np.mean((yt-yhat)**2)**0.5
        # null = np.mean((np.mean(ytr)-yt)**2)**0.5
        errs.append(err)
        nulls.append(null)

    print(alpha, np.mean(errs), np.std(errs), np.mean(nulls), np.std(nulls))

0.01 0.5313807531380752 0.0018711865920500131 0.5313807531380752 1.1102230246251565e-16
0.1 0.4765690376569037 0.025958288484093893 0.5313807531380752 1.1102230246251565e-16
1 0.4422594142259414 0.027630907081337336 0.5313807531380752 1.1102230246251565e-16
10 0.4577405857740586 0.03471542847312393 0.5313807531380752 1.1102230246251565e-16
100 0.4543933054393305 0.02207683005568689 0.5313807531380752 1.1102230246251565e-16
1000 0.43347280334728033 0.020600056276303222 0.5313807531380752 1.1102230246251565e-16
