In [23]:
!pip3 install --upgrade tqdm

Collecting tqdm
[?25l  Downloading https://files.pythonhosted.org/packages/bb/62/6f823501b3bf2bac242bd3c320b592ad1516b3081d82c77c1d813f076856/tqdm-4.39.0-py2.py3-none-any.whl (53kB)
[K     |████████████████████████████████| 61kB 1.9MB/s 
[?25hInstalling collected packages: tqdm
  Found existing installation: tqdm 4.28.1
    Uninstalling tqdm-4.28.1:
      Successfully uninstalled tqdm-4.28.1
Successfully installed tqdm-4.39.0


In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import Adam
from torch.nn import NLLLoss, CrossEntropyLoss
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

import plotly.offline as py
import plotly.graph_objs as go
import plotly.express as px
import plotly.figure_factory as ff

In [0]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# torch.cuda.get_device_name(0)
device = 'cpu'

In [0]:
def sample_data():
    count = 100000
    rand = np.random.RandomState(0)
    a = [[-1.5, 2.5]] + rand.randn(count // 3, 2) * 0.2
    b = [[1.5, 2.5]] + rand.randn(count // 3, 2) * 0.2
    c = np.c_[2 * np.cos(np.linspace(0, np.pi, count // 3)),
    -np.sin(np.linspace(0, np.pi, count // 3))]

    c += rand.randn(*c.shape) * 0.2
    data_x = np.concatenate([a, b, c], axis=0)
    data_y = np.array([0] * len(a) + [1] * len(b) + [2] * len(c))
    perm = rand.permutation(len(data_x))
    return data_x[perm], data_y[perm]

In [0]:
samples_x, samples_y = sample_data()

In [0]:
train_input, dev_input = train_test_split(samples_x, random_state=1, test_size=0.2)
train_input, dev_input = torch.tensor(train_input), torch.tensor(dev_input)

In [0]:
batch_size = 64

train_data = TensorDataset(train_input)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

dev_data = TensorDataset(dev_input)
dev_sampler = SequentialSampler(dev_data)
dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=dev_input.shape[0])

In [0]:
class AutoregressiveFlow(nn.Module):
    def __init__(self, k: int):
        super().__init__()
        self.k = k
        self.pi1 = nn.Parameter(torch.randn(k))
        self.mu1 = nn.Parameter(torch.randn(k))
        self.sigma1 = nn.Parameter(torch.randn(k))

        # self.pi2 = nn.Parameter(torch.randn(k))
        # self.mu2 = nn.Parameter(torch.randn(k))
        # self.sigma2 = nn.Parameter(torch.randn(k))

        self.pi2 = nn.Linear(1, k)
        self.mu2 = nn.Linear(1, k)
        self.sigma2 = nn.Linear(1, k)

        self.params = [self.pi1, 
                       self.mu1, 
                       self.sigma1, 
                       *self.pi2.parameters(), 
                       *self.mu2.parameters(), 
                       *self.sigma2.parameters()]
                       
        for param in [self.pi2, self.mu2, self.sigma2]:
            nn.init.xavier_uniform(param.weight)

    def forward(self, X):
        x1, x2 = X[:,0].float(), X[:,1].float()

        normal1 = [Normal(mu, sigma) for mu, sigma in zip(self.mu1, self.sigma1.exp())]
        d1 = torch.stack([distr1.log_prob(x1).exp() for distr1 in normal1])
        result1 = (F.softmax(self.pi1, dim=0) @ d1).log()
        
        x1 = x1.reshape(-1,1)

        d2 = []
        for xi, mu2, sigma2 in zip(x2, self.mu2(x1), self.sigma2(x1)):
            d2.append(Normal(mu2, sigma2.exp()).log_prob(xi).exp())
        d2 = torch.stack(d2)
        result2 = (F.softmax(self.pi2(x1), dim=1) * d2).sum(dim=1).log()

        return - result1 - result2

    def parameters(self):
        return self.params

    def apply(self, X):
        x1, x2 = X[:,0].float(), X[:,1].float()

        normal1 = [Normal(mu, sigma) for mu, sigma in zip(self.mu1, self.sigma1.exp())]
        f1 = torch.stack([distr1.cdf(x1) for distr1 in normal1])
        result1 = F.softmax(self.pi1, dim=0) @ f1

        x1 = x1.reshape(-1,1)

        f2 = []
        for xi, mu2, sigma2 in zip(x2, self.mu2(x1), self.sigma2(x1)):
            f2.append(Normal(mu2, sigma2.exp()).cdf(xi))
        f2 = torch.stack(f2)
        result2 = (F.softmax(self.pi2(x1), dim=1) * f2).sum(dim=1)

        return torch.stack([result1, result2])

In [0]:
model = AutoregressiveFlow(10)
optimizer = Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999), eps=1e-8)


nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.



In [0]:
epochs = 25

train_nlll, dev_nlll = [], []
for epoch in tqdm(epochs):
    print(f'epoch={epoch}', end='')

    train_nlll.append(0)
    for i, batch in enumerate(train_dataloader):
        X = batch[0].to(device)
        optimizer.zero_grad()

        loss = torch.mean(model(X))
        train_nlll[-1] += loss.item()
        loss.backward()

        optimizer.step()

    dev_nlll.append(0)
    for i, batch in enumerate(dev_dataloader):
        X = batch[0].to(device)

        with torch.no_grad():
            loss = torch.mean(model(X))
            dev_nlll[-1] += loss.item()
    
    train_nlll[-1] /= len(train_dataloader)
    dev_nlll[-1] /= len(dev_dataloader)
    print(f', train_loss={train_nlll[-1]:.5f}, dev_loss={dev_nlll[-1]:.5f}')

  0%|          | 0/25 [00:00<?, ?it/s]

epoch=0

  4%|▍         | 1/25 [00:32<13:03, 32.63s/it]

, train_loss=3.14386, dev_loss=2.27210
epoch=1

  8%|▊         | 2/25 [01:05<12:31, 32.67s/it]

, train_loss=2.04496, dev_loss=1.90435
epoch=2

 12%|█▏        | 3/25 [01:38<11:58, 32.67s/it]

, train_loss=1.81547, dev_loss=1.77515
epoch=3

 16%|█▌        | 4/25 [02:10<11:27, 32.75s/it]

, train_loss=1.73136, dev_loss=1.72085
epoch=4

 20%|██        | 5/25 [02:43<10:55, 32.77s/it]

, train_loss=1.68518, dev_loss=1.68161
epoch=5

 24%|██▍       | 6/25 [03:17<10:25, 32.91s/it]

, train_loss=1.64896, dev_loss=1.64885
epoch=6

 28%|██▊       | 7/25 [03:50<09:53, 32.98s/it]

, train_loss=1.61752, dev_loss=1.61883
epoch=7

 32%|███▏      | 8/25 [04:23<09:22, 33.08s/it]

, train_loss=1.59006, dev_loss=1.59428
epoch=8

 36%|███▌      | 9/25 [04:56<08:49, 33.10s/it]

, train_loss=1.56846, dev_loss=1.57605
epoch=9

 40%|████      | 10/25 [05:29<08:16, 33.13s/it]

, train_loss=1.55250, dev_loss=1.56247
epoch=10

 44%|████▍     | 11/25 [06:03<07:44, 33.16s/it]

, train_loss=1.54027, dev_loss=1.55080
epoch=11

 48%|████▊     | 12/25 [06:36<07:11, 33.17s/it]

, train_loss=1.53025, dev_loss=1.54145
epoch=12

 52%|█████▏    | 13/25 [07:09<06:39, 33.25s/it]

, train_loss=1.52183, dev_loss=1.53358
epoch=13

 56%|█████▌    | 14/25 [07:42<06:05, 33.26s/it]

, train_loss=1.51478, dev_loss=1.52687
epoch=14

 60%|██████    | 15/25 [08:16<05:32, 33.25s/it]

, train_loss=1.50874, dev_loss=1.52036
epoch=15

 64%|██████▍   | 16/25 [08:49<04:59, 33.30s/it]

, train_loss=1.50341, dev_loss=1.51566
epoch=16

 68%|██████▊   | 17/25 [09:22<04:26, 33.29s/it]

, train_loss=1.49868, dev_loss=1.51069
epoch=17

 72%|███████▏  | 18/25 [09:56<03:52, 33.28s/it]

, train_loss=1.49434, dev_loss=1.50613
epoch=18

 76%|███████▌  | 19/25 [10:29<03:19, 33.20s/it]

, train_loss=1.49063, dev_loss=1.50232
epoch=19

 80%|████████  | 20/25 [11:02<02:45, 33.16s/it]

, train_loss=1.48696, dev_loss=1.49912
epoch=20

 84%|████████▍ | 21/25 [11:35<02:12, 33.18s/it]

, train_loss=1.48385, dev_loss=1.49582
epoch=21

 88%|████████▊ | 22/25 [12:08<01:39, 33.20s/it]

, train_loss=1.48077, dev_loss=1.49201
epoch=22

 92%|█████████▏| 23/25 [12:41<01:06, 33.22s/it]

, train_loss=1.47780, dev_loss=1.48934
epoch=23

 96%|█████████▌| 24/25 [13:15<00:33, 33.22s/it]

, train_loss=1.47499, dev_loss=1.48680
epoch=24

100%|██████████| 25/25 [13:48<00:00, 33.26s/it]

, train_loss=1.47245, dev_loss=1.48397





In [0]:
def plot_loss(train_loss, dev_loss):
    x = np.arange(len(train_loss)) + 1
    traces = [
        go.Scatter(x=x, y=train_loss, mode='lines', name='train loss'),
        go.Scatter(x=x, y=dev_loss, mode='lines', name='dev loss')
    ]
    figure = go.Figure(data=traces)
    figure.update_layout(title='losses')
    figure.show()


def plot_points(data_x, data_y):
    data_x = data_x.detach().numpy()[:,:10000]
    data_y = data_y[:10000]

    x0, x1, y = data_x[0], data_x[1], data_y

    points = []
    for c in range(3):
        scatter = go.Scatter(x=x0[y == c], y=x1[y == c], mode='markers', name=f'{c}')
        points.append(scatter)

    figure = go.Figure(data=points)
    figure.update_layout(title='latents', width=1000, height=1000)
    figure.show()

    figure = ff.create_2d_density(x0, x1, hist_color='rgb(255, 237, 222)', point_size=3)
    figure.update_layout(title='density plot', height=1000, width=1000)
    figure.show()


def plot_density(model):
    x, y = np.linspace(-4, 4, 200), np.linspace(-4, 4, 200)
    xv, yv = np.meshgrid(x, y)

    points = torch.tensor([(xx, yy) for xx, yy in zip(xv.flatten(), yv.flatten())])
    density = np.exp(-model(points).detach().numpy())
    
    z = density.reshape((200, 200))
    figure = go.Figure(data=go.Heatmap(x=x, y=y, z=z))
    figure.update_layout(title='points density', height=1000, width=1000)
    figure.show()

In [0]:
data_x, data_y = torch.tensor(samples_x), samples_y
data_x = model.apply(data_x)
plot_loss(train_nlll, dev_nlll)
plot_points(data_x, data_y)
plot_density(model)

In [0]:
class RealNVP(nn.Module):
    def __init__(self, net1, net2, n):
        super(RealNVP, self).__init__()
        
        self.n = n
        self.mask = nn.Parameter(torch.Tensor([[0, 1], [1, 0]] * n), requires_grad=False)
        self.s = torch.nn.ModuleList([net1() for _ in range(2 * n)])
        self.t = torch.nn.ModuleList([net2() for _ in range(2 * n)])

        self.params = []
        for s in self.s:
            self.params += s.parameters()
        for t in self.t:
            self.params += t.parameters()
      
    def forward(self, x):
        x = x.float()
        log_det_J = torch.zeros(x.shape[0]).float()

        for i in range(2 * self.n):
            x_ = self.mask[i] * x.float()
            s = self.s[i](x_)
            t = self.t[i](x_)
            x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
            log_det_J += (s * (1 - self.mask[i])).sum(dim=1)
        x = torch.sigmoid(x)
        return x, torch.abs(log_det_J + torch.log((x * (1 - x) + 1e-8).sum(dim=1)))
        
    def parameters(self):
        return self.params

In [0]:
net1 = lambda: nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 2), nn.Tanh())
net2 = lambda: nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 2))

In [0]:
model = RealNVP(net1, net2, 3)
optimizer = Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999), eps=1e-8)

In [73]:
epochs = 3

train_nlll, dev_nlll = [], []
for epoch in tqdm(range(epochs)):
    print(f'epoch={epoch}', end='')

    train_nlll.append(0)
    for i, batch in enumerate(train_dataloader):
        X = batch[0].to(device)
        optimizer.zero_grad()

        Z, loss = model(X)
        loss = torch.mean(loss)
        train_nlll[-1] += loss.item()
        loss.backward()

        optimizer.step()

    dev_nlll.append(0)
    for i, batch in enumerate(dev_dataloader):
        X = batch[0].to(device)

        with torch.no_grad():
            Z, loss = model(X)
            loss = torch.mean(loss)
            dev_nlll[-1] += loss.item()
    
    train_nlll[-1] /= len(train_dataloader)
    dev_nlll[-1] /= len(dev_dataloader)
    print(f', train_loss={train_nlll[-1]:.10f}, dev_loss={dev_nlll[-1]:.10f}')

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

epoch=0, train_loss=0.0249873076, dev_loss=0.0091702072
epoch=1, train_loss=0.0079918285, dev_loss=0.0041614645
epoch=2, train_loss=0.0070780415, dev_loss=0.0043306192



In [0]:
data_x, data_y = torch.tensor(samples_x), samples_y
data_x, _ = model(data_x)
# plot_loss(train_nlll, dev_nlll)
# plot_points(data_x.T, data_y)
# plot_density(model)

In [63]:
print(data_x)

tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        ...,
        [1., 1.],
        [1., 1.],
        [1., 1.]], grad_fn=<SigmoidBackward>)
