In [2]:
# Bring in the BSNIP cohort: per-subject FC arrays plus age / sex / race labels.
# (An earlier version of this comment said "PNC", but every path below is BSNIP.)

import pickle
from pathlib import Path
import numpy as np

newdir = '/home/anton/Documents/Tulane/Research/ImageNomer/data/anton/cohorts/BSNIP'
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(f'{newdir}/demographics.pkl', 'rb') as fh:
    newdemo = pickle.load(fh)

newfc = dict()
age = []
sex = []
race = []

for sub in newdemo['Age_cal']:
    # Keep only subjects that have sex and race labels and whose race code is AA or CA.
    # These dict lookups are cheap, so they run before touching the filesystem.
    if sub not in newdemo['Race'] or sub not in newdemo['sex'] or newdemo['Race'][sub] not in ['AA', 'CA']:
        continue
    f = Path(newdir) / 'fc' / f'{sub}_task-unk_fc.npy'
    if not f.exists():
        continue
    newfc[str(sub)] = np.load(f)
    age.append(newdemo['Age_cal'][sub])
    # True when the raw sex code is 's1.0'.
    # NOTE(review): which sex 's1.0' denotes is not visible here -- confirm.
    sex.append(newdemo['sex'][sub] == 's1.0')
    # True for AA, False for CA (only those two codes survive the filter above).
    race.append(newdemo['Race'][sub] == 'AA')

age = np.array(age)
sex = np.array(sex)
race = np.array(race)

# newfc (insertion order), age, sex and race must stay row-aligned; a duplicate
# key after str() conversion would silently break that, so check it.
assert len(newfc) == len(age) == len(sex) == len(race)

print(age.shape)
print(np.mean(age))
print(np.mean(sex))
print(np.mean(race))
print(len(newfc))

(1165,)
38.49871244635193
0.42660944206008583
0.3321888412017167
1165


In [3]:
# Angle estimate

import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

# Exact value of pi (a literal 3.14 is only accurate to ~1.6e-3).
pi = np.pi
pi2 = 2*pi

def tonp(t):
    """Return tensor `t` as a NumPy array, detached from autograd and moved to the CPU."""
    detached = t.detach()
    return detached.cpu().numpy()

class AngleBasis(nn.Module):
    """Mixture-of-phases model whose upper triangle is fitted to per-subject FC vectors.

    Each of ``mixn`` components holds an angle ``thetas[k, i]`` and a weight
    ``jitter[k, i]`` for every region ``i``. Component ``k`` predicts entry (i, j) as
    ``jitter[k,i] * jitter[k,j] * cos(thetas[k,i] - thetas[k,j])``. The model's
    estimate is the mean over components (``psum``), and ``pvec`` returns its
    strict upper triangle.

    Args:
        mixn: number of mixture components.
        n_rois: number of regions (matrix side length). Defaults to 264, the value
            the notebook was written for.
        device: torch device for the parameters. Defaults to ``'cuda'`` as in the
            original code; pass ``'cpu'`` to run without a GPU.
    """

    def __init__(self, mixn, n_rois=264, device='cuda'):
        super().__init__()
        self.mixn = mixn
        self.n_rois = n_rois
        # Angles start uniform in [pi/2, 3*pi/2), using exact pi. They are drawn on the
        # CPU and then moved, matching the original RNG usage.
        thetas = np.pi * torch.rand(mixn, n_rois) + np.pi / 2
        self.thetas = nn.Parameter(thetas.float().to(device))
        self.jitter = nn.Parameter(torch.ones(mixn, n_rois).float().to(device))
        # (row, col) indices of the strict upper triangle, computed once instead of on
        # every pvec() call.
        self._triu = np.triu_indices(n_rois, 1)

    def project(self):
        """Clamp ``jitter`` into [0, 1] in place (call after each optimizer step)."""
        with torch.no_grad():
            self.jitter.clamp_(0, 1)

    def phases(self):
        """Pairwise angle differences, shape (mixn, n_rois, n_rois): [k,i,j] = theta[k,i] - theta[k,j]."""
        t0 = self.thetas.unsqueeze(2)
        t1 = self.thetas.unsqueeze(1)
        return t0 - t1

    def jit(self):
        """Pairwise weight products, shape (mixn, n_rois, n_rois): [k,i,j] = jitter[k,i] * jitter[k,j]."""
        j0 = self.jitter.unsqueeze(2)
        j1 = self.jitter.unsqueeze(1)
        return j0 * j1

    def ps(self, jitter=True):
        """Per-component matrices cos(phase differences), optionally scaled by the jitter products."""
        p = torch.cos(self.phases())
        if jitter:
            p = self.jit() * p
        return p

    def dump(self):
        """Plain-dict snapshot (NumPy arrays) of the fitted parameters, suitable for pickling."""
        return dict(mixn=self.mixn, thetas=tonp(self.thetas), jitter=tonp(self.jitter))

    def psum(self):
        """Mean of the jittered component matrices, shape (n_rois, n_rois)."""
        return torch.mean(self.ps(), dim=0)

    def pvec(self):
        """Strict upper triangle (i < j, row-major order) of ``psum()`` as a 1-D tensor."""
        a, b = self._triu
        return self.psum()[a, b]

print('Complete')

Complete


In [5]:
# Optimisation schedule used for every per-subject fit.
nepochs = 5_000  # Adam steps per random restart
pperiod = 500    # progress is printed when epoch % pperiod == 1 (and on the final epoch)

def fromnp(x, device='cuda'):
    """Convert a NumPy array to a float32 torch tensor on `device`.

    Args:
        x: NumPy array (``torch.from_numpy`` rejects other input types).
        device: target device. Defaults to ``'cuda'`` as before; pass ``'cpu'`` to
            run without a GPU.

    Returns:
        float32 tensor. When `x` is already float32 and `device` is the CPU, the
        result shares memory with `x`.
    """
    return torch.from_numpy(x).float().to(device)

def rmse(yhat, y):
    """Root-mean-square error between `yhat` and `y`.

    NumPy is used when `yhat` is an ndarray and torch otherwise, so tensor inputs
    stay differentiable. Returns a NumPy scalar or a 0-d tensor respectively.
    """
    mean = np.mean if isinstance(yhat, np.ndarray) else torch.mean
    squared_error = (y - yhat) ** 2
    return mean(squared_error) ** 0.5

# Fit one AngleBasis per subject, keeping the best of several random restarts.
# Re-running this cell starts over from an empty `bases`.
bases = dict()
nrestarts = 2    # random restarts per subject
ncomponents = 5  # AngleBasis mixn
i = 0

for subtask in newfc:
    # NOTE(review): assumes newfc[subtask] is a 1-D upper-triangle FC vector in the
    # same (i < j, row-major) order as AngleBasis.pvec() -- TODO confirm.
    x = fromnp(newfc[subtask])

    # +inf guarantees the first finite restart is kept. The old fixed threshold (10)
    # could leave `sav` as None and crash at sav.dump().
    min_loss = float('inf')
    sav = None
    final_loss = float('nan')

    for _ in range(nrestarts):
        basis = AngleBasis(ncomponents)
        optim = torch.optim.Adam(basis.parameters(), lr=1e-2, weight_decay=0)

        for e in range(nepochs):
            optim.zero_grad()
            xhat = basis.pvec()
            loss = rmse(xhat, x)
            loss.backward()
            optim.step()
            basis.project()
            if e == nepochs-1 or e % pperiod == 1:
                print(f'{e} {float(loss)}')

        # `loss` was measured before the last step/projection; score the parameters
        # that would actually be saved.
        with torch.no_grad():
            final_loss = float(rmse(basis.pvec(), x))

        if final_loss < min_loss:
            print('Saving')
            min_loss = final_loss
            sav = basis

    if sav is None:
        # Only reachable if every restart produced a NaN loss.
        raise RuntimeError(f'No finite fit for {subtask} (last loss: {final_loss})')

    bases[subtask] = sav.dump()
    i += 1

    print(f'Finished {i} {subtask}')

print('Complete')

1 0.38084694743156433
501 0.07718787342309952
1001 0.07378187030553818
1501 0.07312768697738647


KeyboardInterrupt: 

In [25]:
import pickle

# If the fitting loop was interrupted, `bases` holds only the subjects finished so
# far, so report the count alongside the output path.
outpath = '/home/anton/Documents/Tulane/Research/BSNIP/AngleBasisLong5.pkl'
with open(outpath, 'wb') as fh:  # context manager guarantees flush + close
    pickle.dump(bases, fh)

print(f'Saved {len(bases)} bases to {outpath}')
print('Complete')

Complete
