In [1]:
# evaluation file

In [22]:
#inverse kinematics test
import torch
import torch.nn as nn
import numpy as np
from FrEIA.framework import InputNode, OutputNode, Node, ReversibleGraphNet
from FrEIA.modules import GLOWCouplingBlock, PermuteRandom
import random

In [33]:
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

# use no padding
ndim_tot = 4
ndim_x = 4
ndim_y = 2
ndim_z = 2

def subnet_fc(c_in, c_out):
    return nn.Sequential(nn.Linear(c_in, 256), nn.ReLU(),
                         nn.Linear(256, 256), nn.ReLU(),
                         nn.Linear(256,  c_out))

nodes = [InputNode(ndim_tot, name='input')]

for k in range(7):
    nodes.append(Node(nodes[-1],
                      GLOWCouplingBlock,
                      {'subnet_constructor':subnet_fc, 'clamp': 0.7},
                      name=F'coupling_{k}'))
    nodes.append(Node(nodes[-1],
                      PermuteRandom,
                      {'seed':k},
                      name=F'permute_{k}'))

nodes.append(OutputNode(nodes[-1], name='output'))

model = ReversibleGraphNet(nodes, verbose=False).to(device)

# number of gaussians to be used
no_gaussians = 2 
scale = 0.3
means = torch.tensor([[-2.,0],[2.,0]], device = device)



prob_net = nn.Sequential(nn.Linear(ndim_y,64), nn.ReLU(),nn.Linear(64,64), nn.ReLU(),
                        nn.Linear(64,no_gaussians), nn.Softmax(dim = 1)).to(device)

model.load_state_dict(torch.load('Gumbel_new.pt'))
prob_net.load_state_dict(torch.load('Gumbel_probs_new.pt'))

def f(x):
    y1 = x[:,0] + 0.5*torch.sin(x[:,2]-x[:,1])+0.5*torch.sin(x[:,1])+torch.sin(x[:,3]-x[:,2]-x[:,1])
    y2 = 0.5*torch.cos(x[:,2]-x[:,1])+0.5*torch.cos(x[:,1])+torch.cos(x[:,3]-x[:,2]-x[:,1])
    y1 = y1.view(1,1)
    y2 = y2.view(1,1)
    return torch.cat((y1,y2),1)

In [34]:
def MMD_multiscale(x, y):
    xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t())

    rx = (xx.diag().unsqueeze(0).expand_as(xx))
    ry = (yy.diag().unsqueeze(0).expand_as(yy))

    dxx = rx.t() + rx - 2.*xx
    dyy = ry.t() + ry - 2.*yy
    dxy = rx.t() + ry - 2.*zz

    XX, YY, XY = (torch.zeros(xx.shape).to(device),
                  torch.zeros(xx.shape).to(device),
                  torch.zeros(xx.shape).to(device))

    for a in [0.05, 0.2, 0.9]:
        XX += a**2 * (a**2 + dxx)**-1
        YY += a**2 * (a**2 + dyy)**-1
        XY += a**2 * (a**2 + dxy)**-1

    return torch.mean(XX + YY - 2.*XY)

no_runs = 50

In [35]:
mean_mmd = 0
mean_resim = 0

tens = np.load('rejection_samples_multi.npy', allow_pickle = True)
tens = torch.tensor(tens).to(device).float()

for i in range(no_runs):

    y = torch.zeros(len(tens),2, device = device)
    y[:,0] = 0.
    y[:,1] = 1.5
    z = scale*torch.randn(len(tens),2, device = device)
    # update the means according to the probabilities
    probs = prob_net(y.to(device))
    for i in range(len(z)):
        num = random.random()
        if num < probs[0][0]:
            z[i] = z[i] + means[0]
        else: 
            z[i] = z[i] + means[1]     
    inp = torch.cat((z,y),1).to(device)
    out = model(inp, rev = True)

    s = 0 
    for j in range(len(out)):
        # resimulation error
        fx = f(out[j].view(1,4))[0]
        s += torch.sum((fx-torch.tensor([0,1.5], device = device))**2)  
    
    mean_resim += s/len(out)

    mean_mmd += MMD_multiscale(tens,out)

print(mean_mmd/no_runs)
print(mean_resim/no_runs)

tensor(0.0291, grad_fn=<DivBackward0>)
tensor(0.0003, grad_fn=<DivBackward0>)


In [36]:
mean_mmd = 0
mean_resim = 0

tens = np.load('rejection_samples_uni.npy', allow_pickle = True)
tens = torch.tensor(tens).to(device).float()

for i in range(no_runs):
    y = torch.zeros(len(tens),2, device = device)
    y[:,0] = 0.5
    y[:,1] = 1.5
    z = scale*torch.randn(len(tens),2, device = device)
    # update the means according to the probabilities
    probs = prob_net(y.to(device))
    for i in range(len(z)):
        num = random.random()
        if num < probs[0][0]:
            z[i] = z[i] + means[0]
        else: 
            z[i] = z[i] + means[1]     
    inp = torch.cat((z,y),1).to(device)
    out = model(inp, rev = True)

    s = 0 
    for j in range(len(out)):
        # resimulation error
        fx = f(out[j].view(1,4))[0]
        s += torch.sum((fx-torch.tensor([0.5,1.5], device = device))**2)  
    
    mean_resim += s/len(out)

    mean_mmd += MMD_multiscale(tens,out)

print(mean_mmd/no_runs)
print(mean_resim/no_runs)

tensor(0.0105, grad_fn=<DivBackward0>)
tensor(0.0014, grad_fn=<DivBackward0>)


In [37]:
def subnet_fc(c_in, c_out):
    return nn.Sequential(nn.Linear(c_in, 256), nn.ReLU(),
                         nn.Linear(256, 256), nn.ReLU(),
                         nn.Linear(256,  c_out))

nodes = [InputNode(ndim_tot, name='input')]

for k in range(7):
    nodes.append(Node(nodes[-1],
                      GLOWCouplingBlock,
                      {'subnet_constructor':subnet_fc, 'clamp':2.0},
                      name=F'coupling_{k}'))
    nodes.append(Node(nodes[-1],
                      PermuteRandom,
                      {'seed':k},
                      name=F'permute_{k}'))

nodes.append(OutputNode(nodes[-1], name='output'))

model = ReversibleGraphNet(nodes, verbose=False).to(device)

model.load_state_dict(torch.load('inn.pt'))


<All keys matched successfully>

In [38]:
mean_mmd = 0
mean_resim = 0

tens = np.load('rejection_samples_multi.npy', allow_pickle = True)
tens = torch.tensor(tens, device = device).float()

for i in range(no_runs):

    y = torch.zeros(len(tens),2, device = device)
    y[:,0] = 0.
    y[:,1] = 1.5
    z = torch.randn(len(tens),2, device = device)

    inp = torch.cat((z,y),1).to(device)
    out = model(inp, rev = True)

    s = 0 
    for j in range(len(out)):
        # resimulation error
        fx = f(out[j].view(1,4))[0]
        s += torch.sum((fx-torch.tensor([0,1.5], device = device))**2)  
    mean_resim += (s/len(out))
    mean_mmd += (MMD_multiscale(tens,out))
    
print(mean_resim/no_runs)
print(mean_mmd/no_runs)

tensor(0.0084, grad_fn=<DivBackward0>)
tensor(0.0361, grad_fn=<DivBackward0>)


In [None]:
mean_mmd = 0
mean_resim = 0

tens = np.load('rejection_samples_uni.npy', allow_pickle = True)
tens = torch.tensor(tens, device = device).float()


for i in range(no_runs):
    y = torch.zeros(len(tens),2, device = device)
    y[:,0] = 0.5
    y[:,1] = 1.5
    z = torch.randn(len(tens),2, device = device)

    inp = torch.cat((z,y),1).to(device)
    out = model(inp, rev = True)
    s = 0 
    for j in range(len(out)):
        # resimulation error
        fx = f(out[j].view(1,4))[0]
        s += torch.sum((fx-torch.tensor([0.5,1.5], device = device))**2)  
    mean_resim += (s/len(out))
    mean_mmd += (MMD_multiscale(tens,out))
    
print(mean_resim/no_runs)
print(mean_mmd/no_runs)