In [None]:
# This notebook trains & benchmarks flows on a multivariate gaussian problem.
# See flows.py for a library of modules.

In [3]:
import torch
from torch.distributions import *
from flows import *

# - See kl-estimator.ipynb for details & benchmark
def kl_estimate_log(log_px, log_qx, n):
    return (log_px - log_qx).mean()

# dataset is skewed, correlated multivariate gaussian
dim = 2
datapoints = 2500

mu = torch.tensor([4.5, -4.5])
sigma = torch.tensor([[3.0, 2], [2, 3.0]])
dist = MultivariateNormal(mu, sigma)

table = dist.sample_n(datapoints)
dataset = torch.utils.data.TensorDataset(table)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1000)
print(table)

tensor([[ 4.6453, -3.4431],
        [ 3.7798, -3.7737],
        [ 5.5684, -3.9274],
        ...,
        [ 4.9403, -4.7214],
        [ 3.2602, -4.9676],
        [ 3.5042, -6.5901]])


In [4]:
dim = 2
m = 4
datapoints = 2500

from torch.distributions import MixtureSameFamily

mix = Categorical(torch.ones(m,))
comp = Independent(Normal(
             torch.randn(m,2), torch.rand(m,2)), 1)
dist = MixtureSameFamily(mix, comp)

table = dist.sample_n(datapoints)
dataset = torch.utils.data.TensorDataset(table)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1000)

In [5]:
# flows convert to normal dist
normal_flows = Flows(
    DenseTriangularFlow(dim, True),
    DenseTriangularFlow(dim, False),
    DenseTriangularFlow(dim, True),
    DenseTriangularFlow(dim, False),
    SoftsquareFlow(dim),
    DenseTriangularFlow(dim, True),
    DenseTriangularFlow(dim, False),
    DenseTriangularFlow(dim, True),
    DenseTriangularFlow(dim, False),
    SoftsquareFlow(dim),
    DenseTriangularFlow(dim, True),
    DenseTriangularFlow(dim, False),
    DenseTriangularFlow(dim, True),
    DenseTriangularFlow(dim, False),
    SoftsquareFlow(dim),
)

# final flow converts to normal

# train on NLL
net = FlowModule(normal_flows, NegLogLikelihoodLoss(dim))

optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

In [6]:
final_loss = 0.0
for batch in range(200):  #1oop over the dataset multiple times
    for _, data in enumerate(dataloader, 0):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        loss = net(data[0])
        loss.backward()
        optimizer.step()
        final_loss = loss
print("final loss: %.03f" % final_loss)

final loss: 2.186


In [7]:
for f in normal_flows.flows:
    print(f)
    for name, param in f.named_parameters():
        print("  %s = %s" % (name, param.data))
    print()

DenseTriangularFlow()
  w = tensor([[ 0.6842, -0.7064],
        [ 0.0000,  0.8338]])
  b = tensor([ 0.0741, -0.2736])

DenseTriangularFlow()
  w = tensor([[ 0.8468,  0.0000],
        [-0.2732,  0.8877]])
  b = tensor([ 0.0494, -0.2820])

DenseTriangularFlow()
  w = tensor([[ 0.8459, -0.5287],
        [ 0.0000,  0.9183]])
  b = tensor([ 0.0527, -0.2920])

DenseTriangularFlow()
  w = tensor([[ 0.9603,  0.0000],
        [-0.2330,  0.9518]])
  b = tensor([ 0.0321, -0.2999])

SoftsquareFlow()
  a = tensor([1.0547, 1.0072])
  b = tensor([0.0022, 0.1885])

DenseTriangularFlow()
  w = tensor([[1.0057, 0.2035],
        [0.0000, 0.9777]])
  b = tensor([0.0586, 0.1998])

DenseTriangularFlow()
  w = tensor([[1.0249, 0.0000],
        [0.2095, 0.9839]])
  b = tensor([0.0544, 0.1973])

DenseTriangularFlow()
  w = tensor([[1.0210, 0.2053],
        [0.0000, 0.9853]])
  b = tensor([0.0577, 0.2055])

DenseTriangularFlow()
  w = tensor([[1.0423, 0.0000],
        [0.2206, 0.9932]])
  b = tensor([0.0529, 0.

In [8]:
n = 50000
points = dist.sample_n(n)
p_points = dist.log_prob(points)

normalized_points, log_det = normal_flows(points)
dist_target = MultivariateNormal(torch.zeros(dim), torch.eye(dim))
p_normalized_points = dist_target.log_prob(normalized_points) + log_det

uniform_with_y = torch.stack([p_points.exp(), p_normalized_points.exp()], dim=1)
print(uniform_with_y)

print()

print("~kl div:", kl_estimate_log(p_points, p_normalized_points, n))

tensor([[0.1997, 0.2417],
        [0.0609, 0.1166],
        [0.2301, 0.1625],
        ...,
        [0.1393, 0.2648],
        [0.0936, 0.2073],
        [2.3091, 0.1859]], grad_fn=<StackBackward>)

~kl div: tensor(0.4877, grad_fn=<MeanBackward0>)


In [9]:
n = 5000
points = dist.sample_n(n)
p_points = dist.log_prob(points)

for ps, log_det, f in zip(*normal_flows.forward_trace(points),normal_flows.flows):
    dist_target = MultivariateNormal(torch.zeros(dim), torch.eye(dim))
    p_normalized_points = dist_target.log_prob(ps) + log_det
    approx_kl_divergence = kl_estimate_log(p_points, p_normalized_points, n).item()
    params = list(f.parameters())
    print("%1.4f - %s:" % (approx_kl_divergence, f))

1.4280 - DenseTriangularFlow():
1.8065 - DenseTriangularFlow():
2.5409 - DenseTriangularFlow():
3.1455 - DenseTriangularFlow():
3.9743 - SoftsquareFlow():
3.2066 - DenseTriangularFlow():
2.6349 - DenseTriangularFlow():
2.1770 - DenseTriangularFlow():
1.8212 - DenseTriangularFlow():
1.8154 - SoftsquareFlow():
1.3888 - DenseTriangularFlow():
1.0681 - DenseTriangularFlow():
0.8200 - DenseTriangularFlow():
0.6772 - DenseTriangularFlow():
0.4593 - SoftsquareFlow():
