In [1]:
# Load the PNC cohort (demographics and per-task FC arrays) used to train the autoencoders

import pickle
import numpy as np

# NOTE(review): absolute, machine-specific path; consider making this configurable.
pncdir = '/home/anton/Documents/Tulane/Research/ImageNomer/data/anton/cohorts/PNC'

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# The context manager guarantees the file handle is closed after loading.
with open(f'{pncdir}/demographics.pkl', 'rb') as demo_file:
    pncdemo = pickle.load(demo_file)

age = []
sex = []
race = []
rest = []
nback = []
emoid = []

# Number of subjects dropped because a demographic field or an FC file was unusable
skipped = 0

for sub in pncdemo['age_at_cnb']:
    try:
        a = pncdemo['age_at_cnb'][sub]
        s = pncdemo['Sex'][sub]
        r = pncdemo['Race'][sub]
        # Only subjects whose Race entry is 'AA' or 'EA' are kept
        if r not in ['AA', 'EA']:
            continue
        # Load all three task arrays before appending anything so that a
        # subject missing any one of them contributes to none of the lists.
        pr = np.load(f'{pncdir}/fc/{sub}_task-rest_fc.npy')
        pn = np.load(f'{pncdir}/fc/{sub}_task-nback_fc.npy')
        pe = np.load(f'{pncdir}/fc/{sub}_task-emoid_fc.npy')
    except (KeyError, OSError, ValueError):
        # Best-effort loading: a missing demographic entry (KeyError), a
        # missing/unreadable file (OSError) or a malformed .npy (ValueError)
        # skips the subject. Anything else (typos, Ctrl-C) now surfaces.
        skipped += 1
        continue
    rest.append(pr)
    nback.append(pn)
    emoid.append(pe)
    age.append(a)
    sex.append(s == 'M')    # boolean: True when Sex is 'M'
    race.append(r == 'AA')  # boolean: True when Race is 'AA'

if not age:
    # np.stack on an empty list raises an unhelpful error; fail with context instead
    raise RuntimeError(f'No usable subjects found under {pncdir}')

rest = np.stack(rest)
nback = np.stack(nback)
emoid = np.stack(emoid)
age = np.array(age)
sex = np.array(sex)
race = np.array(race)

if skipped:
    print(f'Skipped {skipped} subjects with missing or unreadable data')

print(len(rest), len(nback), len(emoid), len(age), len(sex), len(race))

1193 1193 1193 1193 1193 1193


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Layer widths read by the model classes in this cell:
#   nA - hidden width of the FCAE encoder/decoder
#   nB - FCAE latent width (also the input width of CatSim/RegSim)
#   nC - per-input hidden width inside CatSim/RegSim
nA, nB, nC = 1000, 500, 10

class FCAE(nn.Module):
    """Fully connected autoencoder with one hidden layer on each side.

    The encoder (`fwd`) maps n_in -> n_hidden -> n_latent and the decoder
    (`rev`) maps n_latent -> n_hidden -> n_in, with a ReLU after the first
    layer of each half and no activation on the outputs.

    Args:
        n_in: Input/output feature count. Defaults to 34716, the value that
            was previously hard-coded (presumably the upper triangle of a
            264x264 FC matrix, 264*263/2 -- TODO confirm).
        n_hidden: Hidden width. Defaults to the module-level `nA`.
        n_latent: Latent width. Defaults to the module-level `nB`.
        device: Where to place the layers. Defaults to CUDA when available,
            otherwise CPU, so the class can also be built on GPU-less hosts.
    """

    def __init__(self, n_in=34716, n_hidden=None, n_latent=None, device=None):
        super(FCAE, self).__init__()
        # Globals are resolved at call time so that FCAE() keeps its old behaviour
        if n_hidden is None:
            n_hidden = nA
        if n_latent is None:
            n_latent = nB
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.fwd1 = nn.Linear(n_in, n_hidden).float().to(device)
        self.fwd2 = nn.Linear(n_hidden, n_latent).float().to(device)
        self.rev1 = nn.Linear(n_latent, n_hidden).float().to(device)
        self.rev2 = nn.Linear(n_hidden, n_in).float().to(device)

    def fwd(self, x):
        """Encode `x` (..., n_in) into a latent code (..., n_latent)."""
        x = F.relu(self.fwd1(x))
        x = self.fwd2(x)
        return x

    def rev(self, x):
        """Decode a latent code (..., n_latent) back to (..., n_in)."""
        x = F.relu(self.rev1(x))
        x = self.rev2(x)
        return x

    def forward(self, x):
        """Return `(reconstruction, latent)` for input `x`."""
        z = self.fwd(x)
        x = self.rev(z)
        return x, z

class CatSim(nn.Module):
    """Pairwise classifier over two latent codes.

    Both inputs pass through the same `fc1` layer (shared weights) followed
    by a ReLU; the two results are concatenated on the last dimension and
    mapped by `fc2` to `n_classes` unnormalised scores.

    Args:
        n_latent: Width of each input. Defaults to the module-level `nB`.
        n_hidden: Width of the shared projection. Defaults to module-level `nC`.
        n_classes: Number of output scores. Defaults to 4, the value that was
            previously hard-coded.
        device: Where to place the layers. Defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, n_latent=None, n_hidden=None, n_classes=4, device=None):
        super(CatSim, self).__init__()
        # Globals are resolved at call time so that CatSim() keeps its old behaviour
        if n_latent is None:
            n_latent = nB
        if n_hidden is None:
            n_hidden = nC
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.fc1 = nn.Linear(n_latent, n_hidden).float().to(device)
        self.fc2 = nn.Linear(2 * n_hidden, n_classes).float().to(device)

    def forward(self, x, y):
        """Return scores of shape (..., n_classes) for the pair `(x, y)`."""
        x = F.relu(self.fc1(x))
        y = F.relu(self.fc1(y))
        xy = torch.cat([x, y], dim=-1)
        xy = self.fc2(xy)
        return xy

class RegSim(nn.Module):
    """Pairwise regressor over two latent codes.

    Both inputs pass through the same `fc1` layer (shared weights) followed
    by a ReLU; the two results are concatenated on the last dimension and
    mapped by `fc2` to a single value per pair.

    Args:
        n_latent: Width of each input. Defaults to the module-level `nB`.
        n_hidden: Width of the shared projection. Defaults to module-level `nC`.
        device: Where to place the layers. Defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, n_latent=None, n_hidden=None, device=None):
        super(RegSim, self).__init__()
        # Globals are resolved at call time so that RegSim() keeps its old behaviour
        if n_latent is None:
            n_latent = nB
        if n_hidden is None:
            n_hidden = nC
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.fc1 = nn.Linear(n_latent, n_hidden).float().to(device)
        self.fc2 = nn.Linear(2 * n_hidden, 1).float().to(device)

    def forward(self, x, y):
        """Return one prediction per pair, shape `x.shape[:-1]`."""
        x = F.relu(self.fc1(x))
        y = F.relu(self.fc1(y))
        xy = torch.cat([x, y], dim=-1)
        # Drop only the trailing size-1 feature axis. A bare squeeze() would
        # also remove the batch axis when the batch holds a single pair.
        return self.fc2(xy).squeeze(-1)

# Training schedule: total number of epochs and how often (in epochs) to log
nepochs, pperiod = 5000, 50

# Convert every loaded array to a float32 tensor resident on the GPU
rest_t, nback_t, emoid_t, age_t, sex_t, race_t = (
    torch.from_numpy(array).float().cuda()
    for array in (rest, nback, emoid, age, sex, race)
)

# One autoencoder per task, plus the pairwise heads (one regressor, two classifiers).
# Construction order is kept so that weight initialisation consumes the RNG identically.
fcaerest, fcaenback, fcaeemoid = FCAE(), FCAE(), FCAE()
asim = RegSim()
ssim, rsim = CatSim(), CatSim()

# A single Adam optimizer updates the parameters of all six modules
optim = torch.optim.Adam(
    [
        param
        for module in (fcaerest, fcaenback, fcaeemoid, asim, ssim, rsim)
        for param in module.parameters()
    ],
    lr=1e-4,
    weight_decay=1e-4,
)

def rmse(yhat, y):
    """Root-mean-square error between `yhat` and `y`, averaged over all elements."""
    residual = y - yhat
    mean_sq = torch.mean(residual ** 2)
    return mean_sq ** 0.5

def fmt(num):
    """Render `num` (anything `float()` accepts) with two decimal places."""
    value = float(num)
    return f'{value:.2f}'

# Only needed by the (currently disabled) pairwise classification losses
ce = nn.CrossEntropyLoss()

# NOTE: the pairwise age/sex/race similarity losses built on asim/ssim/rsim
# are disabled. The earlier draft of that code referred to latents named
# z1/z2/z3 that this loop does not compute, so it must be rewritten against
# `z` before it can be switched back on.

for e in range(nepochs):
    should_log = e % pperiod == 0 or e == nepochs - 1
    # One optimisation step per task: encode that task's FC with its own
    # encoder, then reconstruct all three tasks from the shared latent.
    for source, encoder in zip([rest_t, nback_t, emoid_t], [fcaerest, fcaenback, fcaeemoid]):
        optim.zero_grad()
        batch = np.random.permutation(len(rest_t))[:20]
        _, z = encoder(source[batch])
        targets = (rest_t[batch], nback_t[batch], emoid_t[batch])
        decoders = (fcaerest, fcaenback, fcaeemoid)
        losses = [rmse(decoder.rev(z), target) for decoder, target in zip(decoders, targets)]
        loss = sum(losses)
        loss.backward()
        optim.step()
        if should_log:
            # Epoch index followed by the rest/nback/emoid reconstruction RMSEs
            print(f'{e} ' + ''.join(f'{fmt(ls)} ' for ls in losses))

print('Complete')

0 0.41 0.38 0.39 
0 0.37 0.34 0.36 
0 0.41 0.36 0.37 
50 0.19 0.19 0.20 
50 0.21 0.16 0.21 
50 0.22 0.19 0.18 
100 0.19 0.19 0.19 
100 0.21 0.16 0.20 
100 0.22 0.20 0.17 
150 0.18 0.19 0.20 
150 0.21 0.17 0.19 
150 0.22 0.18 0.18 
200 0.19 0.19 0.19 
200 0.21 0.16 0.19 
200 0.21 0.17 0.16 
250 0.17 0.17 0.20 
250 0.19 0.17 0.17 
250 0.21 0.18 0.16 
300 0.20 0.19 0.19 
300 0.21 0.17 0.20 
300 0.22 0.17 0.16 
350 0.18 0.17 0.18 
350 0.20 0.15 0.18 
350 0.21 0.18 0.16 
400 0.19 0.19 0.19 
400 0.23 0.17 0.19 
400 0.20 0.17 0.17 
450 0.18 0.19 0.18 
450 0.20 0.16 0.18 
450 0.21 0.17 0.17 
500 0.20 0.21 0.20 
500 0.20 0.16 0.19 
500 0.20 0.18 0.17 
550 0.18 0.16 0.17 
550 0.19 0.17 0.18 
550 0.20 0.18 0.17 
600 0.20 0.18 0.19 
600 0.19 0.17 0.19 
600 0.19 0.17 0.16 
650 0.18 0.17 0.17 
650 0.22 0.16 0.17 
650 0.20 0.17 0.16 
700 0.18 0.17 0.18 
700 0.20 0.16 0.17 
700 0.19 0.17 0.16 
750 0.18 0.19 0.19 
750 0.18 0.15 0.18 
750 0.19 0.17 0.16 
800 0.18 0.18 0.18 
800 0.19 0.16 0.18 
800 0.20 

KeyboardInterrupt: 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Wider layers for this experiment: nA is the FCAE hidden width, nB its latent width
nA, nB = 4000, 1000

class FCAE(nn.Module):
    """Fully connected autoencoder with one hidden layer on each side.

    The encoder (`fwd`) maps n_in -> n_hidden -> n_latent and the decoder
    (`rev`) maps n_latent -> n_hidden -> n_in, with a ReLU after the first
    layer of each half and no activation on the outputs.

    Args:
        n_in: Input/output feature count. Defaults to 34716, the value that
            was previously hard-coded (presumably the upper triangle of a
            264x264 FC matrix, 264*263/2 -- TODO confirm).
        n_hidden: Hidden width. Defaults to the module-level `nA`.
        n_latent: Latent width. Defaults to the module-level `nB`.
        device: Where to place the layers. Defaults to CUDA when available,
            otherwise CPU, so the class can also be built on GPU-less hosts.
    """

    def __init__(self, n_in=34716, n_hidden=None, n_latent=None, device=None):
        super(FCAE, self).__init__()
        # Globals are resolved at call time so that FCAE() keeps its old behaviour
        if n_hidden is None:
            n_hidden = nA
        if n_latent is None:
            n_latent = nB
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.fwd1 = nn.Linear(n_in, n_hidden).float().to(device)
        self.fwd2 = nn.Linear(n_hidden, n_latent).float().to(device)
        self.rev1 = nn.Linear(n_latent, n_hidden).float().to(device)
        self.rev2 = nn.Linear(n_hidden, n_in).float().to(device)

    def fwd(self, x):
        """Encode `x` (..., n_in) into a latent code (..., n_latent)."""
        x = F.relu(self.fwd1(x))
        x = self.fwd2(x)
        return x

    def rev(self, x):
        """Decode a latent code (..., n_latent) back to (..., n_in)."""
        x = F.relu(self.rev1(x))
        x = self.rev2(x)
        return x

    def forward(self, x):
        """Return `(reconstruction, latent)` for input `x`."""
        z = self.fwd(x)
        x = self.rev(z)
        return x, z

def rmse(yhat, y):
    """Root-mean-square error between `yhat` and `y`, averaged over all elements."""
    squared_error = (y - yhat) ** 2
    mean_sq = torch.mean(squared_error)
    return mean_sq ** 0.5

# Single autoencoder trained to predict nback FC from rest FC
ae = FCAE()
optim = torch.optim.Adam(ae.parameters(), lr=1e-4, weight_decay=1e-4)

# Inputs (rest) and targets (nback) as float32 tensors on the GPU
x1 = torch.from_numpy(rest).float().cuda()
x2 = torch.from_numpy(nback).float().cuda()

nepochs = 5000
pperiod = 50

for e in range(nepochs):
    # Gradients only need clearing once per step (the call was duplicated before)
    optim.zero_grad()
    # Random mini-batch of 100 subjects; the same indices select input and target
    idcs = np.random.permutation(len(x1))[:100]
    x1b = x1[idcs]
    x2b = x2[idcs]
    # The latent code is not used here, only the decoded output
    x2hat, _ = ae(x1b)
    loss = rmse(x2hat, x2b)
    loss.backward()
    optim.step()
    if e%pperiod == 0 or e == nepochs-1:
        print(f'{e} {float(loss):.3f}')

print('Complete')

0 0.351
50 0.193
100 0.189
150 0.182
200 0.174
250 0.174
300 0.174
350 0.168
400 0.178
450 0.169
500 0.178
550 0.160
600 0.168
650 0.158
700 0.163
750 0.163
800 0.166
850 0.161
900 0.159
950 0.159
1000 0.161
1050 0.162
1100 0.163
1150 0.158
1200 0.165
1250 0.166


KeyboardInterrupt: 