In [1]:
import pickle
from natsort import natsorted

basedir = '../../ImageNomer/data/anton/cohorts/BSNIP'
demoname = f'{basedir}/demographics.pkl'

with open(demoname, 'rb') as f:
    demo = pickle.load(f)
    
subs = natsorted(list(demo['Age_cal'].keys()))
print(len(subs))

1244


In [2]:
import numpy as np

task = 'unk'
x = []
y = []

for sub in subs:
    if demo['DXGROUP_1'][sub] not in ['SZP']:
        continue
    p = np.load(f'{basedir}/fc/{sub}_task-{task}_fc.npy')
    x.append(p)
    y.append(demo['DXGROUP_1'][sub] == 'SZP')
#     y.append(int(demo['Age_cal'][sub]))
#     y.append(demo['sex'][sub] == 'M')
    
x = np.stack(x)
y = np.array(y).astype('int')

print(x.shape)
print(y.shape)
print(y[0:5])

(199, 34716)
(199,)
[1 1 1 1 1]


In [43]:
from sklearn.model_selection import train_test_split

import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

class Basis(nn.Module):
    def __init__(self, nsub):
        super(Basis, self).__init__()
        self.A = nn.Parameter(torch.randn(3,264,2).float().cuda()) 
        self.m = nn.Parameter(torch.zeros(nsub).float().cuda())
        self.w = nn.Parameter(torch.zeros(3,nsub).float().cuda())
        self.s = nn.Sigmoid()
        
    def get_mask_loss(self):
        m = self.s(self.m)
        loss = -torch.mean(m*torch.log(m)+(1-m)*torch.log(1-m))
        return loss
    
    def get_similarity_loss(self):
        A1 = self.compute(1)
        A2 = self.compute(2)
        return torch.abs(torch.sum(A1*A2))
        
    def compute(self, group):
        A = self.A[group]
        A = A@A.T
        A = A/torch.linalg.norm(A)
        a,b = torch.triu_indices(264,264,offset=1)
        return A[a,b]
    
    def recon(self, apply_mask):
        m = self.s(self.m)
        w = self.w
        A0 = self.compute(0)
        A1 = self.compute(1)
        A2 = self.compute(2)
        x = [
            A0.unsqueeze(0)*w[0].unsqueeze(1),
            A1.unsqueeze(0)*w[1].unsqueeze(1)*m.unsqueeze(1),
            A2.unsqueeze(0)*w[2].unsqueeze(1)*(1-m).unsqueeze(1)
        ]
        return sum(x)
    
# xtr, xt, ytr, yt = train_test_split(x, y, stratify=y, train_size=0.8)

# xtr = x[np.where(y == 1)[0]]
# print(xtr.shape)

res = []

for _ in range(20):
    xtr = x

    xtr = torch.from_numpy(xtr).float().cuda()
    # xt = torch.from_numpy(xt).float().cuda()
    # ytr = torch.from_numpy(ytr).float().cuda()
    # yt = torch.from_numpy(yt).float().cuda()

    # mu = torch.mean(ytr)
    # ytr = ytr-mu
    # yt = yt-mu

    basis = Basis(xtr.shape[0])
    optim = torch.optim.Adam(basis.parameters(), lr=1e-1, weight_decay=0)

    nepochs = 1000
    pperiod = 100

    def rmse(a,b):
        return torch.mean((a-b)**2)**0.5

    for epoch in range(nepochs):
        optim.zero_grad()
        xhat = basis.recon(epoch > 100)
        rloss = rmse(xtr, xhat)
        mloss = 0.1*basis.get_mask_loss() 
        sloss = 0.1*basis.get_similarity_loss()
        (mloss+rloss+sloss).backward()
        optim.step()
        if epoch % pperiod == 0 or epoch == nepochs:
            print(f'{epoch} {float(rloss)} {float(mloss)} {float(sloss)}')

    print('Complete')

    res.append(basis.s(basis.m).detach().cpu().numpy())

0 0.30041778087615967 0.06931471079587936 5.2317696827230975e-05
100 0.27650895714759827 0.0016299727139994502 0.00016258323739748448
200 0.2583323121070862 0.000595811870880425 1.2642681213037577e-05
300 0.24686454236507416 0.0003258798096794635 2.4507649868610315e-05
400 0.2390260100364685 0.0002105140156345442 0.00010135425691260025
500 0.23380713164806366 0.0001492653100285679 9.912601854011882e-06
600 0.2303261160850525 0.00011229042866034433 5.297432835504878e-06
700 0.22801250219345093 8.801214426057413e-05 6.346812733681872e-05
800 0.22647391259670258 7.108746649464592e-05 2.5211867978214286e-05
900 0.22545738518238068 5.873909321962856e-05 6.708497039653594e-06
Complete
0 0.30041778087615967 0.06931471079587936 0.0002032474149018526
100 0.2733308672904968 0.001649963902309537 3.424021997489035e-05
200 0.25229761004447937 0.0006095599965192378 9.37818822421832e-06
300 0.23649589717388153 0.0003344826400279999 1.8237487893202342e-05
400 0.2249564230442047 0.00021605254733003676 

KeyboardInterrupt: 

In [48]:
res[0]

array([4.9205755e-05, 4.6003031e-05, 9.9995482e-01, 5.0445648e-05,
       9.9995601e-01, 9.9995506e-01, 9.9995601e-01, 9.9995852e-01,
       4.7505637e-05, 9.9995518e-01, 9.9995494e-01, 9.9995625e-01,
       9.9995804e-01, 9.9995470e-01, 9.9995852e-01, 9.9995804e-01,
       9.9995697e-01, 9.9995530e-01, 9.9995673e-01, 9.9995673e-01,
       4.7431400e-05, 9.9996006e-01, 9.9995589e-01, 4.6634577e-05,
       9.9995494e-01, 4.7426649e-05, 9.9995506e-01, 9.9995613e-01,
       9.9995470e-01, 9.9995697e-01, 9.9995553e-01, 9.9995530e-01,
       9.9995482e-01, 9.9995840e-01, 9.9995911e-01, 9.9995577e-01,
       4.6670571e-05, 9.9995518e-01, 9.9995470e-01, 9.9995530e-01,
       9.9995458e-01, 9.9996114e-01, 9.9995685e-01, 4.6979032e-05,
       9.9995530e-01, 9.9995863e-01, 9.9996328e-01, 9.9995482e-01,
       9.9995506e-01, 4.7836224e-05, 4.8340848e-05, 9.9995530e-01,
       9.9995875e-01, 4.8886304e-05, 9.9995589e-01, 4.7244779e-05,
       9.9995863e-01, 4.6509911e-05, 9.9995530e-01, 4.6765908e