In [1]:
import pickle
from natsort import natsorted

# Root folder of the BSNIP cohort and the pickled demographics file inside it.
basedir = '../../ImageNomer/data/anton/cohorts/BSNIP'
demoname = f'{basedir}/demographics.pkl'

# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open(demoname, 'rb') as f:
    demo = pickle.load(f)

# Subject IDs are the keys of the 'Age_cal' field, put in natural sort order.
subs = natsorted(demo['Age_cal'].keys())
print(len(subs))

1244


In [6]:
import numpy as np

# Label encoding: y=0 NC, y=1 SZ1 (130), y=2 SZ2 (69)
# (the numbers in parentheses are presumably subject counts -- confirm)

task = 'unk'
x = []
y = []

for sub in subs:
    group = demo['DXGROUP_1'][sub]
    # Keep only subjects whose DXGROUP_1 entry is 'SZP' or 'NC'.
    if group not in ['SZP', 'NC']:
        continue
    # One feature vector per subject, read from its own .npy file.
    p = np.load(f'{basedir}/fc/{sub}_task-{task}_fc.npy')
    x.append(p)
    if group == 'SZP':
        subtype = demo['sz_subtype'][sub]
        if subtype == '1':
            y.append(1)
        elif subtype == '2':
            y.append(2)
        else:
            print('Bad sz_subtype')
            # Raising a plain string is itself a TypeError in Python 3 and
            # hides the real problem; raise a proper exception that names
            # the offending subject and value instead.
            raise ValueError(f'Bad sz_subtype {subtype!r} for subject {sub}')
    else:
        y.append(0)

# Rows of x line up with the entries of y.
x = np.stack(x)
y = np.array(y).astype('int')

print(x.shape)
print(y.shape)
print(y[0:5])

(441, 34716)
(441,)
[2 1 0 1 2]


In [15]:
# Registry of fitted bases; starts empty and is filled in by later cells.
bases = {}
print(bases)

{}


In [32]:
from sklearn.decomposition import PCA

# Label (value of y) whose subjects are used to fit the basis.
basis_type = 0

# Fit a 20-component PCA on the rows of x carrying that label only.
group_rows = x[y == basis_type]
pca = PCA(n_components=20).fit(group_rows)
print(len(pca.components_))
print(sum(pca.explained_variance_ratio_))

# Remember the fitted model under the label it was fitted on.
bases[basis_type] = pca

20
0.4288267852657342


In [36]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Label whose basis is used for the projection.
basis_type = 2

# The PCA cell stores only the basis for the label it was last run with, so
# on a fresh top-to-bottom run bases[2] may not exist yet. Fit it on demand
# (same settings as the PCA cell) instead of failing with a KeyError.
if basis_type not in bases:
    bases[basis_type] = PCA(n_components=20).fit(x[y == basis_type])

# Classify label 0 vs label 1; rows labelled 2 are left out of the task.
keep = y != 2
xx = bases[basis_type].transform(x[keep])
yy = y[keep].astype('int')

accs = []

# 20 independent stratified 80/20 splits; no seed is set, so the numbers
# change from run to run.
for _ in range(20):
    xtr, xt, ytr, yt = train_test_split(xx, yy, stratify=yy, train_size=0.8)

    lr = LogisticRegression(max_iter=1000).fit(xtr, ytr)
    yhat = lr.predict(xt)
    acc = np.mean(yhat == yt)
    print(acc)

    accs.append(acc)

# Mean and standard deviation of the held-out accuracy over the splits.
print('---')
print(np.mean(accs))
print(np.std(accs))

0.6933333333333334
0.7733333333333333
0.7333333333333333
0.72
0.7733333333333333
0.7466666666666667
0.7066666666666667
0.7866666666666666
0.7733333333333333
0.76
0.7866666666666666
0.7733333333333333
0.7733333333333333
0.7466666666666667
0.72
0.6666666666666666
0.72
0.72
0.7333333333333333
0.84
---
0.7473333333333334
0.03852272056851644


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

basis_type = 0

class Basis(nn.Module):
    """A stack of learnable low-rank symmetric matrices.

    Entry ``k`` is stored as a factor ``A[k]`` of shape ``(n_nodes, rank)``.
    The matrix it stands for is ``A[k] @ A[k].T`` divided by its Frobenius
    norm.

    Args:
        dim: number of basis matrices to hold.
        n_nodes: side length of each matrix (default 264).
        rank: number of columns in each factor (default 2).
        device: where the parameter lives (default ``'cuda'``).
    """
    def __init__(self, dim, n_nodes=264, rank=2, device='cuda'):
        super(Basis, self).__init__()
        # Sample on the CPU, then move -- the same order of operations as
        # the former `.cuda()` call, so the random stream is unchanged.
        self.A = nn.Parameter(torch.randn(dim, n_nodes, rank).float().to(device))

    def _gram(self, dim):
        """Return the unit-Frobenius-norm matrix ``A[dim] @ A[dim].T``."""
        A = self.A[dim]
        A = A@A.T
        return A/torch.linalg.norm(A)

    def compute(self, dim=0):
        """Return the strict upper triangle of basis ``dim`` as a 1-D tensor.

        The result has ``n_nodes * (n_nodes - 1) / 2`` entries and stays
        attached to the autograd graph.
        """
        G = self._gram(dim)
        n = G.shape[0]
        a,b = torch.triu_indices(n, n, offset=1, device=G.device)
        return G[a,b]

    def to_img(self, dim=0):
        """Return basis ``dim`` as a detached ``(n_nodes, n_nodes)`` numpy array."""
        return self._gram(dim).detach().cpu().numpy()

    def scramble(self, dim):
        """Overwrite the factor of basis ``dim`` with fresh Gaussian noise.

        The slice is replaced with ``copy_`` rather than ``*= 0`` followed by
        ``+=``, because multiplying by zero leaves NaN/Inf entries in place.
        """
        with torch.no_grad():
            fresh = torch.randn(self.A.shape[1:]).float().to(self.A.device)
            self.A[dim].copy_(fresh)
    
# Fit on every row of x. A stratified 80/20 train/test split and centering
# of the labels by their mean used to be set up here; both are disabled.
xtr = torch.from_numpy(x).float().cuda()
n_subjects = xtr.shape[0]

# NOTE(review): `w` and `u` are not passed to the optimizer, and the fitting
# loop rebinds `w` to a least-squares solution. They are kept so the random
# stream consumed before `Basis(3)` is unchanged.
w = nn.Parameter(torch.randn(1, n_subjects).float().cuda())
u = nn.Parameter(torch.randn(1).float().cuda())

# Three bases, all optimised through one shared Adam instance.
basis = Basis(3)
optim = torch.optim.Adam(basis.parameters(), lr=1e-1, weight_decay=0)

# Loop settings: epochs per basis, print period, and the 1x1 identity used
# in the regularised normal equations.
nepochs = 200
pperiod = 10
eye = torch.eye(1).float().cuda()

def rmse(a,b):
    """Return the root-mean-square difference of `a` and `b` over all elements."""
    squared_err = (a - b) ** 2
    return torch.mean(squared_err) ** 0.5

# Greedy fit: basis n is trained on whatever bases 0..n-1 leave unexplained.
# NOTE(review): one Adam instance is shared by all bases, so its running
# moments for already-fitted bases may keep nudging them while a later basis
# trains -- confirm whether that is intended.
for n in range(basis.A.shape[0]):
    tgt = xtr
    cur = None
    print(f'Cur residual')
    with torch.no_grad():
        # Subtract the regularised projection onto each earlier basis in turn.
        for m in range(n):
            A = basis.compute(m)
            A = A.unsqueeze(1).detach()
            # Per-row coefficient from the 1x1 normal equations with a small
            # ridge term (1e-3) on the diagonal.
            w,_,_,_ = torch.linalg.lstsq(A.T@A+1e-3*eye, A.T@tgt.T)
            xhat = (A@w).T
            tgt = tgt - xhat
            cur = cur + xhat if cur is not None else xhat
            print(float(rmse(cur, xtr)))
        print(f'Fitting {n}')
    # Reference loss for the stall check: the residual left by the earlier
    # bases, or (for the first basis) the loss at epoch 0, set below.
    if n > 0:
        start_loss = float(rmse(cur, xtr))
    for epoch in range(nepochs):
        optim.zero_grad()
        A = basis.compute(n)
        A = A.unsqueeze(1)
        w,_,_,_ = torch.linalg.lstsq(A.T@A+1e-3*eye, A.T@tgt.T)
        xhat = (A@w).T
        rloss = rmse(tgt, xhat)
        rloss.backward()
        optim.step()
        if epoch == 0 and n == 0:
            start_loss = float(rloss)
        # Every 50 epochs, re-randomise a basis that has improved the loss by
        # less than 0.05% relative to the reference.
        if epoch % 50 == 49 and float(rloss)/start_loss > 0.9995:
            print('Scrambling')
            basis.scramble(n)
        # range(nepochs) stops at nepochs - 1, so that is the last epoch; the
        # former `epoch == nepochs` test could never be true.
        if epoch % pperiod == 0 or epoch == nepochs - 1:
            print(f'{epoch} {float(rloss)}')

print('Complete')