<a href="https://colab.research.google.com/github/XRater/DUL_2019/blob/hw2/DUL_HW2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import torch
import torch.distributions as D
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import plotly.graph_objs as go

In [0]:
def sample_data():
  count = 100000
  rand = np.random.RandomState(0)
  a = [[-1.5, 2.5]] + rand.randn(count // 3, 2) * 0.2
  b = [[1.5, 2.5]] + rand.randn(count // 3, 2) * 0.2
  c = np.c_[2 * np.cos(np.linspace(0, np.pi, count // 3)),
  -np.sin(np.linspace(0, np.pi, count // 3))]

  c += rand.randn(*c.shape) * 0.2
  data_x = np.concatenate([a, b, c], axis=0)
  data_y = np.array([0] * len(a) + [1] * len(b) + [2] * len(c))
  perm = rand.permutation(len(data_x))
  return data_x[perm], data_y[perm]

In [0]:
def show_points(data, size, range=[-20, 20]):
  xs, ys = data[:size, 0], data[:size, 1]
  data = go.Scatter(x=xs, y=ys, mode='markers')
  fig = go.Figure(data=data, layout=go.Layout(width=700, height=700))
  fig.update_xaxes(range=range)
  fig.update_yaxes(range=range)
  fig.show()

In [0]:
X, y = sample_data()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
X_train, X_test = torch.tensor(X_train), torch.tensor(X_test)

In [158]:
show_points(X, 5000, range=[-5, 5])

In [0]:
class FlowModel(nn.Module):
  def __init__(self, k):
    super(FlowModel, self).__init__()
    self.mu1 = nn.Parameter(torch.randn(k))
    self.mu2 = nn.Parameter(torch.randn(k))
    self.sigma1 = nn.Parameter(torch.randn(k))
    self.sigma2 = nn.Parameter(torch.randn(k))
    self.c1 = nn.Parameter(torch.randn(k))
    self.c2 = nn.Parameter(torch.randn(k))
    self.linear_mu = nn.Linear(1, k)
    self.linear_sigma = nn.Linear(1, k)
    self.linear_pi = nn.Linear(1, k)

  def forward(self, X):
    X = X.float()
    x, y = X[:, 0], X[:,1]

    pi1 = self.c1.softmax(dim=0)
    p1 = torch.stack([D.Normal(mu, sigma).log_prob(x).exp() for mu, sigma in zip(self.mu1, self.sigma1.exp())]).float()
    det1 = pi1 @ p1

    pi2 = self.linear_pi(x.reshape(-1, 1)).softmax(dim=1)
    mu2 = self.linear_mu(x.reshape(-1, 1))
    sigma2 = self.linear_sigma(x.reshape(-1, 1))
    p2 = torch.stack([D.Normal(mu, sigma).log_prob(yi).exp() for mu, sigma, yi in zip(mu2, sigma2.exp(), y)]).float()
    det2 = torch.sum(pi2 * p2, dim=1).T
    
    return -torch.log(det1 * det2)

  def get(self, X):
    X = X.float()
    x, y = X[:, 0], X[:,1]

    pi1 = self.c1.softmax(dim=0)
    cdf1 = torch.stack([D.Normal(mu, sigma).cdf(x) for mu, sigma in zip(self.mu1, self.sigma1.exp())]).float()
    xs = pi1 @ cdf1

    pi2 = self.linear_pi(x.reshape(-1, 1)).softmax(dim=1)
    mu2 = self.linear_mu(x.reshape(-1, 1))
    sigma2 = self.linear_sigma(x.reshape(-1, 1))
    cdf2 = torch.stack([D.Normal(mu, sigma).cdf(yi) for mu, sigma, yi in zip(mu2, sigma2.exp(), y)]).float()
    ys = torch.sum(pi2 * cdf2, dim=1).T

    return torch.stack([xs, ys]).T

In [0]:
def train(model, optimizer, train_loader, test_loader, num_epochs):
  for epoch in range(num_epochs):
    loss, val_loss = 0, 0
    for batch in train_loader:
      optimizer.zero_grad()
      curr_loss = model(batch).mean()
      loss += curr_loss.item()
      curr_loss.backward()
      optimizer.step()

    with torch.no_grad():
      for batch in test_loader:
          val_loss += model(batch).mean()

    print(f"After epoch {epoch} loss is {loss / len(train_loader)} and validation loss is {val_loss / len(test_loader)}")

In [161]:
batch_size = 64
num_epochs = 10

flow_model = FlowModel(5)
optimizer = torch.optim.Adam(flow_model.parameters(), lr=0.0005)
train_loader = torch.utils.data.DataLoader(X_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(X_test, batch_size=batch_size)
train(flow_model, optimizer, train_loader, test_loader, num_epochs)

After epoch 0 loss is 3.7264023294448854 and validation loss is 2.868361473083496
After epoch 1 loss is 2.4184557793617247 and validation loss is 2.145508289337158
After epoch 2 loss is 2.026787812423706 and validation loss is 1.9155350923538208
After epoch 3 loss is 1.8762198101043701 and validation loss is 1.8117215633392334
After epoch 4 loss is 1.8007801845550537 and validation loss is 1.7558103799819946
After epoch 5 loss is 1.753107769203186 and validation loss is 1.7156506776809692
After epoch 6 loss is 1.716541729259491 and validation loss is 1.6822893619537354
After epoch 7 loss is 1.6865291851997375 and validation loss is 1.6562867164611816
After epoch 8 loss is 1.661459656047821 and validation loss is 1.6332951784133911
After epoch 9 loss is 1.6409231266021729 and validation loss is 1.615531086921692


In [0]:
def showModelDensity(model):
    x, y = np.linspace(-4, 4, 200), np.linspace(-4, 4, 200)
    xv, yv = np.meshgrid(x, y)

    points = torch.tensor([(xx, yy) for xx, yy in zip(xv.flatten(), yv.flatten())])
    density = np.exp(-model(points).detach().numpy())
    
    z = density.reshape((200, 200))
    figure = go.Figure(data=go.Heatmap(x=x, y=y, z=z))
    figure.update_layout(height=700, width=700)
    figure.show()

def showModelPoints(model, size):
    show_points(model.get(torch.tensor(X[:size])).detach().numpy(), size, range=[0, 1])

In [184]:
showModelPoints(flow_model, 5000)

In [185]:
showModelDensity(flow_model)

In [0]:
class RealNVP(nn.Module):
    def __init__(self, nets, nett, masks):
        super(RealNVP, self).__init__()
        
        self.n = len(masks)
        self.masks = masks
        self.t = torch.nn.ModuleList([nett() for _ in range(self.n)])
        self.s = torch.nn.ModuleList([nets() for _ in range(self.n)])

    def run(self, Z):
        batch_size = Z.shape[0]
        det_sum = torch.zeros(batch_size).float() 
        for i in reversed(range(self.n)):
            Z_ = self.masks[i] * Z
            s = self.s[i](Z_.float())
            t = self.t[i](Z_.float())
            Z = Z_ + (1 - self.masks[i]) * (Z * torch.exp(s) + t)
            det_sum += (s * (1 - self.masks[i])).sum(dim=1)
        Z = torch.sigmoid(Z)
        det_sum += torch.log(Z * (1 - Z) + 1e-9).sum(dim=1)
        return Z, torch.abs(det_sum)
    
    def forward(self, X):
      return self.run(X)[1]  

    def get(self, X):
      return self.run(X)[0]

In [0]:
n = 3
nets = lambda: nn.Sequential(nn.Linear(2, 64), nn.LeakyReLU(), nn.Linear(64, 64), nn.LeakyReLU(), nn.Linear(64, 2), nn.Tanh())
nett = lambda: nn.Sequential(nn.Linear(2, 64), nn.LeakyReLU(), nn.Linear(64, 64), nn.LeakyReLU(), nn.Linear(64, 2))
masks = torch.from_numpy(np.array([[1, 0], [0, 1]] * n).astype(np.float32))

In [180]:
batch_size = 128
num_epochs = 10
  
realNVP_model = RealNVP(nets, nett, masks)
optimizer = torch.optim.Adam(realNVP_model.parameters(), lr=0.001)
train_loader = torch.utils.data.DataLoader(X_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(X_test, batch_size=batch_size)
train(realNVP_model, optimizer, train_loader, test_loader, num_epochs)

After epoch 0 loss is 1.6879110326766968 and validation loss is 1.4688535928726196
After epoch 1 loss is 1.4675629959106444 and validation loss is 1.471981406211853
After epoch 2 loss is 1.4534997371673584 and validation loss is 1.4218767881393433
After epoch 3 loss is 1.439774608230591 and validation loss is 1.42205810546875
After epoch 4 loss is 1.4411784254074096 and validation loss is 1.4283244609832764
After epoch 5 loss is 1.4335124338150025 and validation loss is 1.4138290882110596
After epoch 6 loss is 1.4266038166046142 and validation loss is 1.402419924736023
After epoch 7 loss is 1.4213792261123657 and validation loss is 1.4147553443908691
After epoch 8 loss is 1.4250451070785521 and validation loss is 1.3974970579147339
After epoch 9 loss is 1.4245886934280396 and validation loss is 1.385642409324646


In [186]:
showModelPoints(realNVP_model, 5000)

In [187]:
showModelDensity(realNVP_model)