In [1]:
import os
import numpy as np
import pyvista as pv
from morphomatics.geom import Surface
from morphomatics.manifold import FundamentalCoords, PointDistributionModel, util
from morphomatics.stats import StatisticalShapeModel
import torch 
import torch.nn as nn 
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


# max number of objects (for debugging and dev)
nObjects=1000
dataPath = 'adni_hippos_hackathon/'
# load all data
meshes = []
labels = []
for (dirpath, dirnames, filenames) in os.walk(dataPath+"AD"):
    for file in filenames:
        if file[-3:] == "obj" and nObjects > len(meshes):
            path = os.sep.join([dirpath, file])
            mesh = pv.read(path)
            meshes.append(mesh)
            labels.append(0)

for (dirpath, dirnames, filenames) in os.walk(dataPath+"CN"):
    for file in filenames:
        if file[-3:] == "obj" and nObjects > len(meshes):
            path = os.sep.join([dirpath, file])
            load = mesh_straight = pv.read(path)
            meshes.append(load)
            labels.append(1)


# to Surface type
as_surface = lambda mesh: Surface(mesh.points, mesh.faces.reshape(-1, 4)[:, 1:])
surfaces = [as_surface(m) for m in meshes]

# construct model
SSM = StatisticalShapeModel(lambda ref: FundamentalCoords(ref)) # replace me with PointDistributionModel
SSM.construct(surfaces)

coeffs = SSM.coeffs
#print shape
print("Coeffs Shape: " + str(len(coeffs)) + " " + str(len(coeffs[0])))


|grad|=0.23973914301374424
|grad|=0.0009564107294780184
|grad|=2.4561595037871647e-06
|grad|=2.8440123054478917e-08
|grad|=4.0954989297610907e-10
|grad|=6.125792969023894e-12
|grad|=9.409030725833846e-14
9.377399291627178
|grad|=0.01988499049380428
|grad|=3.228057153144882e-05
|grad|=3.294534810829896e-07
|grad|=4.366021670229755e-09
|grad|=6.611715356867535e-11
|grad|=1.0812418833606873e-12
|grad|=1.850722683051106e-14
0.16063346900603956
|grad|=0.03172870833461714
|grad|=4.5086254228981205e-05
|grad|=4.2012664424070966e-07
|grad|=5.2599583021068625e-09
|grad|=7.62889242044419e-11
|grad|=1.1967226960095914e-12
|grad|=1.967629190672057e-14
0.01440858120403914
|grad|=0.031703121638177886
|grad|=4.47953245676806e-05
|grad|=4.178297588792206e-07
|grad|=5.233306458442942e-09
|grad|=7.584089926324345e-11
|grad|=1.1879419889817283e-12
|grad|=1.9497257126118636e-14
0.0007388793561633242
tol 0.006683843450775681 reached
|grad|=0.2378210044724232
|grad|=0.0008617717956731069
|grad|=2.4179836850

In [2]:
coeffs = np.array(coeffs)
coeffs = torch.from_numpy(coeffs).float()
labels = torch.from_numpy(np.array(labels)).float()

x_test = coeffs[54:60]
y_test = labels[54:60]

x_train = coeffs[:54]
y_train = labels[:54]


x_train_norm = coeffs[60:114]
y_train_norm = labels[60:114]

x_test_norm = coeffs[114:]
y_test_norm = labels[114:]

print(y_train_norm)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


In [15]:

from FrEIA.framework import InputNode, OutputNode, Node, ReversibleGraphNet, ConditionNode
from FrEIA.modules import GLOWCouplingBlock

device = 'cuda'

DIMENSION = 119
neural_net = nn.Sequential(nn.Linear(DIMENSION,256),nn.ReLU(),
                           nn.Linear(256,256),nn.ReLU(),
                           nn.Linear(256,1)).to(device)

neural_net2 = nn.Sequential(nn.Linear(DIMENSION,256),nn.ReLU(),
                           nn.Linear(256,256),nn.ReLU(),
                           nn.Linear(256,1)).to(device)

optimizer = torch.optim.Adam(neural_net.parameters(), lr = 1e-3)
optimizer2 = torch.optim.Adam(neural_net2.parameters(), lr = 1e-3)

batch_size = 16

def create_INN(num_layers, sub_net_size,dimension=119):
    def subnet_fc(c_in, c_out):
        return nn.Sequential(nn.Linear(c_in, sub_net_size), nn.ReLU(),
                             nn.Linear(sub_net_size, sub_net_size), nn.ReLU(),
                             nn.Linear(sub_net_size,  c_out))
    nodes = [InputNode(dimension, name='input')]
    for k in range(num_layers):
        nodes.append(Node(nodes[-1],
                          GLOWCouplingBlock,
                          {'subnet_constructor':subnet_fc, 'clamp':1.4},
                          name=F'coupling_{k}'))
    nodes.append(OutputNode(nodes[-1], name='output'))

    model = ReversibleGraphNet(nodes, verbose=False).to(device)
    return model




def langevin_step(x, stepsize, neural_net, lang_steps):
        log_det = torch.zeros((x.shape[0], 1), device = device)
        beta=1.
        for i in range(lang_steps):
            x = x.requires_grad_(True)

            eta = torch.randn_like(x, device = device)
            out = neural_net(x)
            grad_x = torch.autograd.grad(out.sum(), x,create_graph=True)[0]
            y = x - stepsize * grad_x + np.sqrt(2*stepsize/beta) * eta

            x = y
        return x

def train(x_train,neural_net, optimizer):
    perm = torch.randperm(len(x_train))[:batch_size]
    xs = x_train[perm].to(device)
    optimizer.zero_grad()
    loss = torch.mean(neural_net(xs))
    mcmc_samples = langevin_step(torch.randn(batch_size,DIMENSION, device = device), 1e-4, neural_net, lang_steps = 300)
    loss += -1*torch.mean(neural_net(mcmc_samples))
    loss.backward()
    optimizer.step()
    return loss

def val_step():
    xs = x_val.to(device)
    loss = torch.mean(neural_net(xs))
    mcmc_samples = langevin_step(torch.randn(len(xs),DIMENSION, device = device), 1e-4, neural_net, lang_steps = 300)
    loss += -1*torch.mean(neural_net(mcmc_samples))
    return loss



In [None]:
for i in range(1000):
    loss = train(x_train,neural_net, optimizer)
    print(loss)
    
for i in range(1000):
    loss = train(x_train_norm,neural_net2, optimizer2)
    print(loss)

tensor(-0.0122, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-0.2160, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-0.4676, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-0.6737, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-0.9256, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1.1567, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1.4129, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1.8053, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-2.1806, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-2.2975, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-2.8218, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-3.3223, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-3.6690, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-3.9989, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-4.4614, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-4.9297, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-5.5297, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-5.7758, device='cuda:0'

tensor(-479.8286, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-466.8793, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-488.1471, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-504.1988, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-502.3183, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-501.8760, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-522.2477, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-526.2699, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-538.2120, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-512.1511, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-536.7583, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-545.8595, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-569.8512, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-567.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-591.8408, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-583.5853, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-575.8260, device='cuda:0', grad_fn=<AddBackward0

tensor(-936.5078, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-901.6244, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-860.8546, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-895.2538, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-881.5646, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-919.5294, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-955.3659, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-958.5787, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-895.2845, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-929.6266, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-913.5258, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-942.4113, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-899.0986, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-919.6689, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1011.9750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-959.1148, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-948.1602, device='cuda:0', grad_fn=<AddBackward

tensor(-964.6086, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-959.8680, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-975.3490, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1010.1038, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-985.9919, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-986.5522, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-969.9723, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-956.3768, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-928.1321, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-981.3726, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-953.8739, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-982.3711, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-999.2524, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1019.1527, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-994.6411, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-948.9855, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-962.2110, device='cuda:0', grad_fn=<AddBackwar

tensor(-998.6959, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1013.3466, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-996.3119, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-967.4206, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-938.3942, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-994.3204, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-938.9814, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-999.8709, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1010.6653, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-940.3483, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-964.2704, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1009.4615, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1009.0516, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-934.1050, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-983.4280, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1020.1366, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-955.1414, device='cuda:0', grad_fn=<AddBack

tensor(-961.1126, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1064.5890, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-996.0427, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1008.5964, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-999.2488, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1062.4352, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1027.7898, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-946.1623, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1027.3102, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-983.7078, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1041.3511, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1005.2628, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1015.0450, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-973.4646, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-997.0715, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-959.0812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-971.3987, device='cuda:0', grad_fn=<AddB

tensor(-977.8472, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1016.4179, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1036.1426, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1044.1993, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1049.1528, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-987.5551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1036.0126, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1037.3893, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-996.4769, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1012.1517, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1075.7494, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1074.0488, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1004.8546, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1042.9478, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-945.6807, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1006.5997, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1059.7645, device='cuda:0', grad_fn=

tensor(-1196.1167, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1175.2965, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1169.2560, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-1219.8737, device='cuda:0', grad_fn=<AddBackward0>)


In [14]:
def estimate_norm_constant(neural_net):
    INN = create_INN(4,256)
    opti_INN = torch.optim.Adam(INN.parameters(),lr = 1e-3)
    for i in range(2000):
        z = torch.randn(32, DIMENSION, device = device)
        out, jac = INN(z)
        loss = torch.mean(neural_net(out))-torch.mean(jac)
        opti_INN.zero_grad()
        loss.backward()
  
        opti_INN.step()
    z = torch.randn(64, DIMENSION, device = device)
    out, jac = INN(z)
    loss =  torch.mean(neural_net(out))- torch.mean(jac)      
    return loss
norm1 = estimate_norm_constant(neural_net)
norm2 = estimate_norm_constant(neural_net2)
print(norm1)
print(norm2)
probs_alz = neural_net(x_test.to(device))+norm1
probs_alz2 = neural_net2(x_test.to(device))+norm2

print(probs_alz)
print(probs_alz2)

probs1 = neural_net(x_test_norm.to(device))+norm1
probs2 = neural_net2(x_test_norm.to(device))+norm2
print(probs1)
print(probs2)

tensor(363.7380, device='cuda:0', grad_fn=<SubBackward0>)
tensor(368.5469, device='cuda:0', grad_fn=<SubBackward0>)
tensor([[382.1083],
        [379.5258],
        [381.2970],
        [384.3909],
        [385.0454],
        [380.8163]], device='cuda:0', grad_fn=<AddBackward0>)
tensor([[393.0344],
        [384.4833],
        [386.8255],
        [389.2781],
        [395.1958],
        [386.6518]], device='cuda:0', grad_fn=<AddBackward0>)
tensor([[385.1948],
        [386.6317],
        [378.1729],
        [378.3347],
        [378.4503],
        [400.5706]], device='cuda:0', grad_fn=<AddBackward0>)
tensor([[386.8570],
        [389.9632],
        [382.5577],
        [385.3625],
        [381.8767],
        [399.1158]], device='cuda:0', grad_fn=<AddBackward0>)
