In this notebook we will practice working with Normalizing Flows.

# Installations and imports

In [None]:
# Install the third-party packages imported below (-q keeps pip's output short).
! pip install -q pytorch_lightning
! pip install -q torchvision

In [None]:
import torch
from torch import nn
from torch import distributions
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.uniform import Uniform
from torch.distributions.transforms import SigmoidTransform
from torch.distributions.transforms import AffineTransform
from torch.utils.data import DataLoader, random_split

from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl

import numpy as np
from pylab import rcParams
from sklearn import datasets

import matplotlib.pyplot as plt
%matplotlib inline

# Default figure size and resolution for every plot in this notebook.
rcParams.update({
    'figure.figsize': (10, 8),
    'figure.dpi': 300,
})

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DEVICE

# NICE implementation for MNIST

Recall that we discussed coupling flows and, in particular, the NICE (Non-linear Independent Component Estimation) model. The input $x$ is split into two parts, $x^A$ and $x^B$, and a neural network $m$ computes a shift for $x^A$ from $x^B$:

$$y^A = m(x^B) + x^A$$

$$y^B = x^B$$

Let's train such a model on the MNIST dataset.

In [None]:
# Data loaders
# MNIST is downloaded on first use (root='' is a path relative to the working directory).
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
# hold out 5,000 of the 60,000 samples for validation
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
# reshuffle the training samples every epoch; random_split alone fixes one
# permutation, so without shuffle=True every epoch would see identical batches
train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(mnist_val, batch_size=32, num_workers=4)

In [None]:
class StandardLogisticDistribution:
    """Factorised logistic distribution over ``data_dim`` coordinates.

    Every coordinate is a Uniform(0, 1) variable pushed through the inverse
    sigmoid and then through an affine map with zero shift and unit scale.
    """

    def __init__(self, data_dim=28 * 28, device='cpu'):
        zeros = torch.zeros(data_dim, device=device)
        ones = torch.ones(data_dim, device=device)

        base = Uniform(zeros, ones)
        to_logistic = [SigmoidTransform().inv, AffineTransform(zeros, ones)]
        self.m = TransformedDistribution(base, to_logistic)

    def log_pdf(self, z):
        """Return the log-density of each row of ``z``, summed over dim 1."""
        per_coordinate = self.m.log_prob(z)
        return per_coordinate.sum(dim=1)

    def sample(self):
        """Draw one sample (a tensor with ``data_dim`` entries)."""
        return self.m.sample()

In [None]:
class NICE(pl.LightningModule):
    """NICE normalizing flow: additive coupling layers plus a final element-wise scaling.

    Args:
        distribution: prior over the flow output; ``compute_loss`` calls its
            ``log_pdf`` method on the transformed batch.
        data_dim: number of features of one flattened sample. Must be even,
            because every coupling layer splits the features into two halves
            of ``data_dim // 2`` entries.
        hidden_dim: width of the hidden layers of each coupling network.
        n_transformations: number of additive coupling layers.

    Raises:
        ValueError: if ``data_dim`` is odd.
    """

    def __init__(self, distribution, data_dim=28 * 28, hidden_dim=1000, n_transformations=4):
        super().__init__()

        if data_dim % 2 != 0:
            # the two halves would have different sizes and the coupling
            # networks (built for data_dim // 2 features) could not be applied
            raise ValueError(f"data_dim must be even, got {data_dim}")

        self.distribution = distribution
        self.data_dim = data_dim
        self.n_transformations = n_transformations

        # NN-transformations: one network per coupling layer, mapping the
        # untouched half of the features to a shift for the other half
        self.m = torch.nn.ModuleList([nn.Sequential(
            nn.Linear(data_dim // 2, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, data_dim // 2)
        ) for _ in range(n_transformations)])

        # log-scales of the final element-wise scaling layer
        self.s = torch.nn.Parameter(torch.randn(data_dim))

        # we will alternate the indices to transform,
        # that is, one half of the input will be transformed in an even transformation
        # and the other half in an odd transformation
        self.idxs_even = np.full(data_dim, False)
        self.idxs_even[::2] = True
        self.idxs_odd = np.full(data_dim, False)
        self.idxs_odd[1::2] = True

    def _split_idxs(self, i):
        """Return the (transformed, pass-through) feature masks of coupling layer ``i``."""
        if i % 2 == 0:
            return self.idxs_even, self.idxs_odd
        return self.idxs_odd, self.idxs_even

    def forward(self, x):
        """Map a batch of data to the latent space.

        Args:
            x: tensor of shape ``(batch, data_dim)``.

        Returns:
            Tuple ``(y, log_det_J)``: the transformed batch and the log-determinant
            of the Jacobian, which for NICE is the scalar ``sum(self.s)``.
        """
        _x = x

        # normalizing flow
        for i, m in enumerate(self.m):

            # split data
            idxs_a, idxs_b = self._split_idxs(i)
            _x_a = _x[:, idxs_a]
            _x_b = _x[:, idxs_b]

            # here is our formula from the lecture
            _y_a = _x_a + m(_x_b)
            _y_b = _x_b

            # transformation output; empty_like keeps both dtype and device of the input
            _out = torch.empty_like(_x)
            _out[:, idxs_a] = _y_a
            _out[:, idxs_b] = _y_b
            _x = _out

        # flow output
        y = torch.exp(self.s) * _x
        log_det_J = torch.sum(self.s)

        return y, log_det_J

    def invert(self, y):
        """Map a batch of latent points back to data space (inverse of ``forward``).

        Args:
            y: tensor of shape ``(batch, data_dim)``.

        Returns:
            Tensor ``x`` of the same shape as ``y``.
        """
        # undo the final scaling layer first
        _y = y / torch.exp(self.s)

        # inversion: walk through the coupling layers in reverse order
        for i in reversed(range(len(self.m))):

            # split data
            idxs_a, idxs_b = self._split_idxs(i)
            _y_a = _y[:, idxs_a]
            _y_b = _y[:, idxs_b]

            # here is our formula from the lecture
            _x_a = _y_a - self.m[i](_y_b)
            _x_b = _y_b

            # make output; empty_like keeps both dtype and device of the input
            _out = torch.empty_like(_y)
            _out[:, idxs_a] = _x_a
            _out[:, idxs_b] = _x_b
            _y = _out

        # inverse output
        return _y

    def compute_loss(self, batch, batch_idx):
        """Return the negative log-likelihood of ``batch``, summed over its samples."""
        x, _ = batch  # the second element of the batch is not used
        # flatten every sample to a vector of data_dim features
        x = x.view(-1, self.data_dim)
        y, log_det_J = self.forward(x)
        # change of variables: log p(x) = log p(y) + log|det J|
        log_likelihood = self.distribution.log_pdf(y) + log_det_J
        loss = -log_likelihood.sum()
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch, batch_idx)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.compute_loss(batch, batch_idx)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        loss = self.compute_loss(batch, batch_idx)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [None]:
# Build the NICE flow with a standard logistic prior and move it to DEVICE.
model = NICE(distribution=StandardLogisticDistribution(device=DEVICE)).to(DEVICE)

In [None]:
# Trainer settings. Request a GPU with 16-bit precision only when CUDA is
# available and fall back to CPU with 32-bit precision otherwise, so the
# notebook also runs on machines without a GPU (mirrors the choice of DEVICE).
_use_gpu = torch.cuda.is_available()
trainer_kwargs = {
    'gpus': 1 if _use_gpu else 0,
    'max_epochs': 5,
    'precision': 16 if _use_gpu else 32,
    'progress_bar_refresh_rate': 5,
    'weights_summary': "full"
}

In [None]:
# Train the NICE model with the settings from trainer_kwargs;
# train_loader and val_loader are passed to fit in that order.
trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(model, train_loader, val_loader)

In [None]:
# Draw nb_data x nb_data samples from the prior, map them to image space with
# the inverse flow and show them on a grid.
model.eval().to(DEVICE)
nb_data = 10
fig, axs = plt.subplots(nb_data, nb_data, figsize=(10, 10))
logistic_distribution = StandardLogisticDistribution(device=DEVICE)
# sampling needs no gradients, so do not build an autograd graph for the inverse passes
with torch.no_grad():
    for i in range(nb_data):
        for j in range(nb_data):
            z = logistic_distribution.sample().unsqueeze(0)  # add a batch dimension
            x = model.invert(z).cpu().numpy()
            # clip to the [0, 1] range before display
            axs[i, j].imshow(x.reshape(28, 28).clip(0, 1), cmap='gray')
            axs[i, j].set_xticks([])
            axs[i, j].set_yticks([])
plt.show()

# Task (6/6 points)

Make RealNVP. Recap:

* There are two neural networks: $s$ for scaling and $t$ for shifting;

* Forward:

$$ y^A = x^A \times \exp(s(x^B)) + t(x^B) $$

$$ y^B = x^B $$

* Jacobian determinant:

$$ \det(J) = \exp \Big( \sum_{j} s(x^B)_j \Big) $$

where the sum runs over all components of $s(x^B)$, one per coordinate of $x^A$.

* Inverse:

$$ x^A = f^{-1}(y^A) = (y^A - t(y^B)) \times \exp(-s(y^B)) $$

$$ x^B = f^{-1}(y^B) = y^B $$

In [None]:
class RealNVP(NICE):
    def __init__(self, distribution, data_dim=28 * 28, hidden_dim=1000, n_transformations=4):
        super().__init__(distribution)

        self.distribution = distribution
        self.n_transformations = n_transformations

        # NN-transformations
        self.t = torch.nn.ModuleList([nn.Sequential(
            nn.Linear(data_dim // 2, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, data_dim // 2)
        ) for i in range(n_transformations)])
        self.s = torch.nn.ModuleList([nn.Sequential(
            nn.Linear(data_dim // 2, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, data_dim // 2), nn.Tanh()
        ) for i in range(n_transformations)])
        
    def forward(self, x):
        
        _x = x.clone()
        log_det_J = x.new_zeros(x.shape[0])
        
        # normalizing flow
        for i in range(len(self.s)):
            
            # split data
            idxs_a = self.idxs_even if (i % 2 == 0) else self.idxs_odd
            idxs_b = self.idxs_odd if (i % 2 == 0) else self.idxs_even
            _x_a = _x[:, idxs_a]
            _x_b = _x[:, idxs_b]
            
            # here is our formula from the lecture
            s = <<your code here>>
            t = <<your code here>>
            _y_a = <<your code here>>
            _y_b = <<your code here>>
            log_det_J += torch.sum(s, dim=1) # log(exp(x)) = x, so we use just sum
            
            # trainsformation output
            _x = torch.empty(_x.shape, device=_x.device)
            _x[:, idxs_a] = _y_a
            _x[:, idxs_b] = _y_b
            
        # flow output
        y = _x
            
        return y, log_det_J

    def invert(self, y):
        
        _y = y.clone()
        
        # inversion
        for i in range(len(self.s) - 1, -1, -1):
            
            # split data
            idxs_a = self.idxs_even if (i % 2) == 0 else self.idxs_odd
            idxs_b = self.idxs_odd if (i % 2) == 0 else self.idxs_even

            _y_a = _y[:, idxs_a]
            _y_b = _y[:, idxs_b]
            
            # here is our formula from the lecture
            s = <<your code here>>
            t = <<your code here>>
            _x_a = <<your code here>>
            _x_b = <<your code here>>
            
            # make output
            _y = torch.empty(y.shape, device=y.device)
            _y[:, idxs_a] = _x_a
            _y[:, idxs_b] = _x_b
            
        # inverse output
        x = _y
            
        return x

In [None]:
# Build the RealNVP flow with a standard logistic prior and move it to DEVICE.
your_model = RealNVP(distribution=StandardLogisticDistribution(device=DEVICE)).to(DEVICE)

As we discussed in the lecture, normalizing flows are bijective transformations. Let us check this:

In [None]:
# test: send one latent sample through invert and then through forward again
y = logistic_distribution.sample().unsqueeze(0)
x = your_model.invert(y)
y_reconstructed, _ = your_model.forward(x)

In [None]:
# forward(invert(y)) has to reproduce y up to an absolute tolerance of 1e-2
round_trip_ok = torch.allclose(y, y_reconstructed, atol=1e-02)
assert round_trip_ok
print("you seem to have written the right code")

Let's train this:

In [None]:
# Train RealNVP with a fresh Trainer built from the same trainer_kwargs
# and the same data loaders that were used for NICE.
trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(your_model, train_loader, val_loader)

In [None]:
# another test: the loss logged by test_step on val_loader must be below -10000
test_results = trainer.test(your_model, val_loader)
test_loss = test_results[0]['test_loss']
assert test_loss < -10000
print("it looks even more like you wrote the right code")

In [None]:
# Draw nb_data x nb_data samples from the prior, map them to image space with
# the inverse RealNVP flow and show them on a grid.
your_model.eval().to(DEVICE)
nb_data = 10
fig, axs = plt.subplots(nb_data, nb_data, figsize=(10, 10))
logistic_distribution = StandardLogisticDistribution(device=DEVICE)
# sampling needs no gradients, so do not build an autograd graph for the inverse passes
with torch.no_grad():
    for i in range(nb_data):
        for j in range(nb_data):
            z = logistic_distribution.sample().unsqueeze(0)  # add a batch dimension
            x = your_model.invert(z).cpu().numpy()
            # clip to the [0, 1] range before display
            axs[i, j].imshow(x.reshape(28, 28).clip(0, 1), cmap='gray')
            axs[i, j].set_xticks([])
            axs[i, j].set_yticks([])
plt.show()