In [1]:
import pickle
from natsort import natsorted

basedir = '../../ImageNomer/data/anton/cohorts/BSNIP'
demoname = f'{basedir}/demographics.pkl'

with open(demoname, 'rb') as f:
    demo = pickle.load(f)
    
subs = natsorted(list(demo['Age_cal'].keys()))
print(len(subs))

1244


In [2]:
import numpy as np

task = 'unk'
x = []
y = []
sex = []
race = []
lookup = dict(SZP=0, SZR=1, SADP=2, SADR=3, BPP=4, BPR=5, NC=6)
lookup_race = dict(AA=-1, CA=1)
lookup_sex = {'s1.0': -1, 's2.0': 1}

for sub in subs:
    if demo['DXGROUP_2'][sub] not in lookup.keys():
        continue
    p = np.load(f'{basedir}/fc/{sub}_task-{task}_fc.npy')
    x.append(p)
    y.append(lookup[demo['DXGROUP_2'][sub]])
    r = demo['Race'][sub]
    r = lookup_race[r] if r in lookup_race else 0
    race.append(r)
    s = demo['sex'][sub]
    s = lookup_sex[s] if s in lookup_sex else 0
    sex.append(s)
    
x = np.stack(x)
y = np.array(y).astype('int')
race = np.array(race).astype('int')
sex = np.array(sex).astype('int')

print(x.shape)
print(y.shape)
print(y)

(1244, 34716)
(1244,)
[0 0 6 ... 1 1 3]


In [50]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

def cat(lst):
    return np.concatenate(lst)

def rand_idcs(n):
    return np.random.permutation(n)   

def get_samples(x, y, r, s, group):
    return x[y == group], r[y == group], s[y == group]

def split(x0, r0, s0, x1, r1, s1, i0, i1, ntr, nt):
    res_tr = []
    res_t = []
    for a,b in zip([x0, r0, s0], [x1, r1, s1]):
        tr = cat([a[i0][:ntr], b[i1][:ntr]])
        t = cat([a[i0][ntr:], b[i1][ntr:]])
        res_tr.append(tr)
        res_t.append(t)
    res_tr.append(cat([np.zeros(ntr), np.ones(ntr)]))
    res_t.append(cat([np.zeros(nt), np.ones(nt)]))
    return res_tr, res_t

def get_logits(y, ysel, grpsel):
    y = y.detach().cpu().numpy()
    return y[np.all([ysel, grpsel], axis=0)]

accs = []

class Logistic(nn.Module):
    def __init__(self, dim):
        super(Logistic, self).__init__()
        self.fc = nn.Linear(dim,2).float().cuda()
        
    def reg_loss(self, C):
        return rmse(self.fc.weight, 0)
        
    def forward(self, x):
        return F.softmax(self.fc(x).squeeze(), dim=-1)

x0, r0, s0 = get_samples(x, y, race, sex, 0)
x6, r6, s6 = get_samples(x, y, race, sex, 6)

n = 150
ntr = 120
aa = []
bb = []

for _ in range(50):
    i0 = rand_idcs(n)
    i1 = rand_idcs(n)

    tr, t = split(x0, r0, s0, x6, r6, s6, i0, i1, ntr, n-ntr)
    
    xtr, xt, ytr = [torch.from_numpy(a).float().cuda() for a in [tr[0], t[0], tr[-1]]]
    yt = t[-1]
    ytr = F.one_hot(ytr.long()).float()
    
    clf = Logistic(xtr.shape[-1])
    optim = torch.optim.Adam(clf.parameters(), lr=1e-4, weight_decay=1e-3)

    ce = nn.CrossEntropyLoss()

    for epoch in range(1000):
        optim.zero_grad()
        yhat = clf(xtr)
        loss = ce(yhat, ytr)
        loss.backward()
        optim.step()
#         if epoch % 200 == 0:
#             print(f'{float(loss)}')

    # Training
    yhat = clf(xt)
    yhat2 = torch.argmax(yhat, axis=-1)
    yhat2 = yhat2.detach().cpu().numpy()
    acc = np.mean(yhat2 == yt)
    log = get_logits(yhat, t[-1] == 1, t[1] == 1)
    a = np.mean(log[:,0])
    log = get_logits(yhat, t[-1] == 1, t[1] == -1)
    b = np.mean(log[:,0])
    print(a,b)
    print(acc)
    accs.append(acc)
    aa.append(a)
    bb.append(b)
    
print('---')
print(np.mean(accs))
print(np.std(accs))
print(np.mean(aa), np.mean(bb))
print(np.std(aa), np.std(bb))

0.27258912 0.39437982
0.6166666666666667
0.27656633 0.34406817
0.65
0.28730083 0.47188306
0.6666666666666666
0.19901669 0.45719406
0.75
0.23907362 0.1647722
0.7
0.18535806 0.3099959
0.7
0.33231685 0.43909588
0.7
0.37259758 0.14935902
0.6
0.27438286 0.22343542
0.75
0.20783554 0.34143847
0.7833333333333333
0.31959328 0.5498816
0.7333333333333333
0.3633684 0.32986626
0.7166666666666667
0.20636982 0.065543145
0.8833333333333333
0.28587162 0.33762532
0.6
0.52371746 0.4200277
0.7
0.49378258 0.38847408
0.65
0.3026185 0.6254502
0.6833333333333333
0.27507275 0.30769762
0.75
0.27232447 0.07331016
0.6666666666666666
0.18075135 0.34506792
0.7333333333333333
0.38937417 0.25230354
0.65
0.38300297 0.2743231
0.75
0.3073355 0.37855646
0.7333333333333333
0.27272272 0.29405078
0.7333333333333333
0.1201901 0.42178914
0.8
0.20501196 0.4042573
0.7333333333333333
0.16785328 0.2425141
0.7666666666666667
0.20344013 0.35566637
0.7833333333333333
0.25004265 0.27412733
0.8
0.25630873 0.21490826
0.7333333333333333