In [0]:
import numpy as np
import torch
import torch.distributions as D
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import plotly.graph_objs as go

In [0]:
def sample_data():
  count = 100000
  rand = np.random.RandomState(0)
  a = [[-1.5, 2.5]] + rand.randn(count // 3, 2) * 0.2
  b = [[1.5, 2.5]] + rand.randn(count // 3, 2) * 0.2
  c = np.c_[2 * np.cos(np.linspace(0, np.pi, count // 3)),
  -np.sin(np.linspace(0, np.pi, count // 3))]

  c += rand.randn(*c.shape) * 0.2
  data_x = np.concatenate([a, b, c], axis=0)
  data_y = np.array([0] * len(a) + [1] * len(b) + [2] * len(c))
  perm = rand.permutation(len(data_x))
  return data_x[perm], data_y[perm]

In [0]:
X, y = sample_data()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
X_train, X_test = torch.tensor(X_train), torch.tensor(X_test)

In [0]:
class FlowModel(nn.Module):
  def __init__(self, k):
    super(FlowModel, self).__init__()
    self.mu1 = nn.Parameter(torch.randn(k))
    self.mu2 = nn.Parameter(torch.randn(k))
    self.sigma1 = nn.Parameter(torch.randn(k))
    self.sigma2 = nn.Parameter(torch.randn(k))
    self.c1 = nn.Parameter(torch.randn(k))
    self.c2 = nn.Parameter(torch.randn(k))

  def forward(self, X):
    x, y = X[:, 0], X[:,1]
    pi1 = F.softmax(self.c1)
    pi2 = F.softmax(self.c2)
    p1 = torch.stack([D.Normal(mu, sigma).log_prob(x).exp() for mu, sigma in zip(self.mu1, self.sigma1.exp())]).float()
    p2 = torch.stack([D.Normal(mu, sigma).log_prob(x).exp() for mu, sigma in zip(self.mu2, self.sigma2.exp())]).float()
    det1 = pi1 @ p1
    det2 = pi2 @ p2
    return -torch.log(det1 * det2)

  def get(self, X):
    x, y = X[:, 0], X[:, 1]
    pi1 = F.softmax(self.c1)
    pi2 = F.softmax(self.c2)
    f2 = torch.stack([D.Normal(mu, sigma).cdf(x) for mu, sigma in zip(self.mu2, self.sigma2.exp())]).float()
    f1 = torch.stack([D.Normal(mu, sigma).cdf(x) for mu, sigma in zip(self.mu1, self.sigma1.exp())]).float()
    xs = pi1 @ f1
    ys = pi2 @ f2
    return xs, ys 

In [0]:
def train(model, optimizer, train_loader, test_loader, num_epochs):
  for epoch in range(num_epochs):
    loss, val_loss = 0, 0
    for batch in train_loader:
      optimizer.zero_grad()
      curr_loss = model(batch).mean()
      loss += curr_loss.item()
      curr_loss.backward()
      optimizer.step()

    with torch.no_grad():
      for batch in test_loader:
          val_loss += model(batch).mean()

    print(f"After epoch {epoch} loss is {loss / len(train_loader)} and validation loss is {val_loss / len(test_loader)}")

In [130]:
batch_size = 32
num_epochs = 10

flow_model = FlowModel(10)
optimizer = torch.optim.Adam(flow_model.parameters(), lr=0.0005)
train_loader = torch.utils.data.DataLoader(X_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(X_test, batch_size=batch_size)
train(flow_model, optimizer, train_loader, test_loader, num_epochs)


Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.


Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.



After epoch 0 loss is 3.0572311900615694 and validation loss is 2.4002645015716553
After epoch 1 loss is 2.1947801872253416 and validation loss is 2.127822160720825
After epoch 2 loss is 2.127400024986267 and validation loss is 2.1206793785095215
After epoch 3 loss is 2.124851322746277 and validation loss is 2.1190555095672607
After epoch 4 loss is 2.124057114458084 and validation loss is 2.1189491748809814
After epoch 5 loss is 2.123489771604538 and validation loss is 2.1183300018310547
After epoch 6 loss is 2.1231123933315277 and validation loss is 2.1179161071777344
After epoch 7 loss is 2.1228314596652984 and validation loss is 2.117908239364624
After epoch 8 loss is 2.1224295464038847 and validation loss is 2.116994619369507
After epoch 9 loss is 2.1220996685028077 and validation loss is 2.117205858230591


In [0]:
def showDensity(model, k):
  x = np.linspace(-4, 4, k)
  y = np.linspace(-4, 4, k)

  xx, yy = np.meshgrid(x, y)
  points = np.vstack([xx.reshape(-1), yy.reshape(-1)]).T
  res = model(torch.tensor(points)).detach().numpy()
  z = np.exp(-res)
  plt.pcolormesh(xx, yy, np.exp(-res).reshape(k, k))

In [0]:
def show(model, data):
  xs, ys = model.get(torch.tensor(X))
  data = go.Scatter(x=xs.detach().numpy(), y=ys.detach().numpy(), mode='markers')
  fig = go.Figure(data=data, layout=go.Layout())
  fig.show()

In [133]:
show(flow_model, X[:2000])


Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.


Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.



NameError: ignored

In [0]:
showDensity(flow_model, 200)

In [0]:
class RealNVP(nn.Module):
    def __init__(self, nets, nett, masks):
        super(RealNVP, self).__init__()
        
        self.n = len(masks)
        self.masks = masks
        self.t = torch.nn.ModuleList([nett() for _ in range(self.n)])
        self.s = torch.nn.ModuleList([nets() for _ in range(self.n)])

    def run(self, Z):
        batch_size = Z.shape[0]
        det_sum = torch.zeros(batch_size).float() 
        for i in reversed(range(self.n)):
            Z_ = self.masks[i] * Z
            s = self.s[i](Z_.float())
            t = self.t[i](Z_.float())
            Z = Z_ + (1 - self.masks[i]) * (Z * torch.exp(s) + t)
            det_sum += (s * (1 - self.masks[i])).sum(dim=1)
        Z = torch.sigmoid(Z)
        det_sum += torch.log(Z * (1 - Z) + 1e-9).sum(dim=1)
        return Z, torch.abs(det_sum)
    
    def forward(self, X):
      return self.run(X)[1]  

    def get(self, X):
      Z = self.run(X)[0]
      return Z[:, 0], Z[:, 1]        

In [0]:
n = 3
nets = lambda: nn.Sequential(nn.Linear(2, 64), nn.LeakyReLU(), nn.Linear(64, 64), nn.LeakyReLU(), nn.Linear(64, 2), nn.Sigmoid())
nett = lambda: nn.Sequential(nn.Linear(2, 64), nn.LeakyReLU(), nn.Linear(64, 64), nn.LeakyReLU(), nn.Linear(64, 2))
masks = torch.from_numpy(np.array([[0, 0], [0, 0]] * n).astype(np.float32))

In [159]:
batch_size = 128
num_epochs = 10

realNVP_model = RealNVP(nets, nett, masks)
optimizer = torch.optim.Adam(realNVP_model.parameters(), lr=0.001)
train_loader = torch.utils.data.DataLoader(X_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(X_test, batch_size=batch_size)
train(realNVP_model, optimizer, train_loader, test_loader, num_epochs)

After epoch 0 loss is 15.065730276489258 and validation loss is 14.485151290893555
After epoch 1 loss is 14.511493202209472 and validation loss is 14.47329044342041
After epoch 2 loss is 14.513773489379883 and validation loss is 14.474895477294922
After epoch 3 loss is 14.51212767944336 and validation loss is 14.477805137634277
After epoch 4 loss is 14.511634841918946 and validation loss is 14.478279113769531
After epoch 5 loss is 14.510469007873535 and validation loss is 14.481213569641113


KeyboardInterrupt: ignored

In [149]:
show(realNVP_model, X[:1000])