In [124]:
import pickle
import numpy as np

bsnipdir = '/home/anton/Documents/Tulane/Research/ImageNomer/data/anton/cohorts/BSNIP/'
bsnipdemo = pickle.load(open(f'{bsnipdir}/demographics.pkl', 'rb'))
bsniptopdir = '/home/anton/Documents/Tulane/Research/Work/ContrastiveLearning/BSNIP/'

fc = []
aps20 = []
aps15 = []
aps10 = []
aps5 = []
aps3 = []
aps1 = []
age = []
sex = []
race = []
sz = []

for sub in bsnipdemo['Age_cal']:
    try:
        a = bsnipdemo['Age_cal'][sub]
        s = bsnipdemo['sex'][sub]
        r = bsnipdemo['Race'][sub]
        d = bsnipdemo['DXGROUP_1'][sub]
        if d not in ['NC', 'SZP']:
            continue
        if r not in ['AA', 'CA']:
            continue
        s = s == 's1.0'
        r = r == 'AA'
        d = d == 'SZP'
        age.append(a)
        sex.append(s)
        race.append(r)
        sz.append(d)
        fc.append(np.load(f'{bsnipdir}/fc/{sub}_task-unk_fc.npy'))
        aps20.append(np.load(f'{bsniptopdir}/Top20/{sub}_task-unktop20_fc.npy'))
        aps3.append(np.load(f'{bsniptopdir}/Top3/{sub}_task-unktop3_fc.npy'))
        aps1.append(np.load(f'{bsniptopdir}/Top1/{sub}_task-unktop1_fc.npy'))
    except:
        pass

age = np.stack(age)
sex = np.stack(sex).astype('int')
race = np.stack(race).astype('int')
sz = np.stack(sz).astype('int')
fc = np.stack(fc)
aps20 = np.stack(aps20)
aps1 = np.stack(aps1)

print([x.shape for x in [age, sex, race, sz, fc, aps20, aps1]])

[(405,), (405,), (405,), (405,), (405, 34716), (405, 34716), (405, 34716)]


In [131]:
import torch
import torch.nn as nn
import torch.nn.functional as Func

def decim(v):
    return '{:.2f}'.format(float(v))

# Assume xb has already had first component removed
def make_pos_samples(xb, aps1):
    x = []
    xpos = []
    for n in range(5):
        idcs = np.random.permutation(len(xb))
        xpart = xb+aps1[idcs]
        x.append(xb)
        xpos.append(xpart)
    x = torch.cat(x)
    xpos = torch.cat(xpos)
    return x, xpos

# First component has not been removed
def make_neg_samples(x):
    xneg = []
    for n in range(5):
        idcs = np.random.permutation(len(x))
        bad = np.arange(len(x))
        idcs[idcs == bad] -= 1
        idcs[idcs < 0] += 2
        xpart = x[idcs]
        xneg.append(xpart)
    xneg = torch.cat(xneg)
    return xneg

def make_samples(xaps20, xaps1):
    x = [f(xaps20)]
    xpos = xaps20-xaps1
    idcs = np.random.permutation(len(xaps1))
    xsamp = [f(xpos+xaps1[idcs])]
    for n in range(5):
        idcs = np.random.permutation(len(xaps20))
        bad = np.arange(len(xaps20))
        idcs[idcs == bad] -= 1
        idcs[idcs < 0] += 2
        xsamp.append(f(xaps20[idcs]))
        x.append(f(xaps20))
    return x,xsamp
    
xaps20 = torch.from_numpy(aps20).float().cuda()
xaps1 = torch.from_numpy(aps1).float().cuda()
y = torch.zeros(6).float().cuda()
y[0] = 1

class F(nn.Module):
    def __init__(self):
        super(F, self).__init__()
        self.fc1 = nn.Linear(34716,20).float().cuda()

    def forward(self, x):
        return self.fc1(x).squeeze()

class Same(nn.Module):
    def __init__(self):
        super(Same, self).__init__()
        self.fc1 = nn.Linear(20*2,50).float().cuda()
        self.fc2 = nn.Linear(50,2).float().cuda()

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], axis=1)
        x = Func.relu(self.fc1(x))
        x = self.fc2(x)
        return x.squeeze()

f = F()
same = Same()
optim = torch.optim.Adam(list(f.parameters()) + list(same.parameters()), lr=1e-4, weight_decay=0)
ce = nn.CrossEntropyLoss()

nepochs = 1000
pperiod = 20

for e in range(nepochs):
    optim.zero_grad()
    x, xsamp = make_samples(xaps20, xaps1)
    yhat = []
    for i in range(6):
        yhat.append(torch.sum(x[i]*xsamp[i]))
    yhat = torch.stack(yhat)
    yhat = Func.softmax(yhat, dim=0)
    loss = ce(yhat, y)
    loss.backward()
    optim.step()
    if e % pperiod == 0 or e == nepochs-1:
        print(f'{e} {decim(loss)}')

print('Done')

0 1.04
20 1.04
40 1.04
60 1.04
80 1.04
100 1.04
120 1.04
140 1.04
160 1.04
180 1.04
200 1.04
220 1.04
240 1.04
260 1.04
280 1.04


KeyboardInterrupt: 

In [123]:
# Train using features

from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge, LogisticRegression

def rmse(yhat, y):
    return np.mean((yhat-y)**2)**0.5

def make_pos_samples(xb, aps1, yb):
    x = []
    xpos = []
    y = []
    for n in range(1):
        idcs = np.random.permutation(len(aps1))
        xpart = xb+aps1[idcs][:len(xb)]
        x.append(xb)
        xpos.append(xpart)
        y.append(yb)
    x = np.concatenate(x)
    xpos = np.concatenate(xpos)
    y = np.concatenate(y)
    return x, xpos, y

losses = []

for _ in range(20):
    x1tr, x1t, x20tr, x20t, ytr, yt = train_test_split(aps1, fc, age, train_size=0.8)

    _, xtr, ytr = make_pos_samples(x20tr-x1tr, x1tr, ytr)
    # _, xt, yt = make_pos_samples(x20t-x1t, x1tr, yt)
    xt, yt = x20t, yt
    
    # with torch.no_grad():
    #     x1tr = torch.from_numpy(x1tr).float().cuda()
    #     x1t = torch.from_numpy(x1t).float().cuda()
    #     x20tr = torch.from_numpy(x20tr).float().cuda()
    #     x20t = torch.from_numpy(x20t).float().cuda()

    #     aa, bb = make_pos_samples(x20tr-x1tr, x1tr)
    #     bb = bb.detach().cpu().numpy()
        
    reg = Ridge(alpha=1).fit(xtr, ytr)
    yhat = reg.predict(xt)
    loss = rmse(yhat, yt)
    null = rmse(np.mean(yt), yt)
    print(loss, null)
    losses.append(loss)

    # reg = LogisticRegression(C=1, max_iter=300).fit(xtr, ytr)
    # yhat = reg.predict(xt)
    # loss = np.mean(yhat == yt)
    # null = np.mean(yt)
    # if null < 0.5:
    #     null = 1-null
    # print(loss, null)
    # losses.append(loss)

print(np.mean(losses))
    

10.583522000308596 12.472301272811706
10.250557171784562 12.36864530491759
8.830976716847966 11.708404015264717
10.703622099723177 12.612076430123894
10.766776331953832 11.96929975247291
10.639522709654642 11.100488749605114
11.211795314966448 12.40845766699171
11.118532887069366 11.939261005638839
10.647001176635744 11.329432610951196
11.08030136403556 12.420612099984107
10.156131023177883 11.903603575498023
10.918270747328005 12.535941317740809
10.191922407513177 12.27606298038708
10.533554709972426 13.216716816748162
11.200853042776817 12.157178215661418
10.513380404240316 11.240169543659828
11.911739727165163 13.669928347108817
10.25840283107512 12.16918283293998
9.799373465674545 12.840693202488678
10.61628461764953 12.718446376023431
10.596626037477645


In [11]:
ours2orig = [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 254, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 85,
86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103,
104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 136, 138, 132,
133, 134, 135, 220, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152,
153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 185, 186,
187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201,
202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216,
217, 218, 219, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
233, 137, 234, 235, 236, 237, 238, 239, 240, 241, 250, 251, 255, 256, 257,
258, 259, 260, 261, 262, 263, 242, 243, 244, 245, 0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 83, 84, 131, 139, 140, 141, 181, 182, 183, 184, 246, 247, 248,
249, 252, 253]

def vec2mat(v):
    a,b = np.triu_indices(264,1)
    m = np.zeros((264,264))
    m[a,b] = v
    return m+m.T

def remap(fc, roimap=ours2orig):
    fc = fc[roimap,:]
    fc = fc[:,roimap]
    return fc

def numpy(x):
    return x.detach().cpu().numpy()

print('Complete')

Complete
