## Normalizing Flows

Your task is to implement a version of Normalizing Flow for image generation. Our implementation will be based on RealNVP (https://arxiv.org/pdf/1605.08803.pdf) and we will be training on one class from MNIST. Your task is to read the paper in detail and implement a simple version of the algorithm from the paper:


1. Implement simple CouplingLayers (see the RealNVP paper) with neural networks built from a few fully connected layers with hidden activations of your choice. More on CouplingLayers can also be found in https://arxiv.org/pdf/1410.8516.pdf. Remember to properly implement the calculation of the logarithm of the Jacobian determinant. Implement only the single-scale architecture; ignore the multiscale architecture with masked convolutions and batch normalization. (2 points)
2. Implement a RealNVP class combining many CouplingLayers with a proper masking pattern (remember to alternate which pixels are left unmodified) with forward and inverse flows. (1 point)
3. Implement a loss function `nf_loss` (data log-likelihood) for the model. Hint: check `torch.distributions` (1 point)
4. Train your model to achieve good-looking samples (similar to the training set - similar to those appended to the assignment on Moodle). The training process should take between 5 and 10 minutes. (2 points)
5. Sample from your model and pick 2 images (as visually different as possible) from your samples and plot 10 images that correspond to equally spaced linear interpolations in latent space between those images you picked. (1 point)
6. Use the method from section 5.2 of https://arxiv.org/pdf/1410.8516.pdf with the trained model and inpaint 5 sampled images with different random parts of each image occluded (50% of the image must be occluded). (2 points)
7. Write a report describing your solution, add loss plots and samples from the model. Write which hyperparameter sets worked for you and which did not. (1 point)

In [27]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np

In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [29]:
def plot(samples):
    """Show a batch of images side by side in a single row.

    Args:
        samples: Indexable, sized collection of tensors. Each element is
            moved to the CPU, converted to a NumPy array and drawn with
            ``imshow``.

    An empty batch is a no-op: nothing is drawn and no figure is created.
    """
    length = len(samples)
    if length == 0:
        # A subplot grid needs at least one column; nothing to draw anyway.
        return

    # squeeze=False keeps `ax` a 2-D array even when there is exactly one
    # image; with the default squeezing a single Axes object is returned and
    # indexing it fails.
    fig, ax = plt.subplots(1, length, figsize=(2*length, 2), squeeze=False)
    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    for j in range(length):
        ax[0, j].imshow(samples[j].cpu().numpy())
        ax[0, j].axis('off')
    plt.show()

def get_mask1d(length=784, odd=True):
    """Build a 1-D alternating 0/1 mask on the module-level ``device``.

    Args:
        length: Number of entries in the mask.
        odd: When true, entries at odd indices are 1 and the rest 0;
            when false, entries at even indices are 1.

    Returns:
        A tensor of shape ``(length,)`` in the default floating dtype.
    """
    # One vectorised parity comparison instead of `length` separate
    # element-wise writes into the tensor.
    positions = torch.arange(length, device=device)
    selected = positions % 2 == int(odd)
    return selected.to(torch.get_default_dtype())


class CombScaleTranslationModule(nn.Module):
    """Fully connected network producing both scale and translation.

    The network is a stack of ``depth`` linear layers. The first maps
    ``in_features`` to ``2 * in_features``; every following layer keeps the
    width at ``2 * in_features``. ReLU is applied after every layer except
    the last one. The final activation is split in half along dimension 1:
    the first half goes through ``tanh`` and is returned as the scale, the
    second half is returned unchanged as the translation.
    """

    def __init__(self, in_features, depth=2):
        super().__init__()

        self.depth = depth

        width = 2 * in_features
        # Input size of each layer: the raw input first, then the hidden width.
        fan_ins = ([in_features] + [width] * depth)[:depth]
        self.layers = nn.ModuleList([nn.Linear(fan_in, width) for fan_in in fan_ins])

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return ``(tanh(scale), translation)`` for a 2-D input batch ``x``."""
        hidden = x
        for index, linear in enumerate(self.layers):
            hidden = linear(hidden)
            if index == self.depth - 1:
                # No non-linearity after the output layer.
                continue
            hidden = self.relu(hidden)

        raw_scale, translation = torch.chunk(hidden, 2, dim=1)
        return self.tanh(raw_scale), translation


class Coupling_layer(nn.Module):
    """Affine coupling layer on flattened 784-dimensional inputs.

    A binary mask selects the entries that pass through unchanged; the
    remaining entries are scaled and shifted by amounts predicted from the
    unchanged part.

    Positional arguments:
        args[0]: layer index; layers with an odd index use ``odd_mask``,
            the others use its complement, so consecutive layers alternate
            which entries stay fixed.
        args[1] (optional): depth of the scale/translation network
            (defaults to 2).
    """

    # Shared mask built once with get_mask1d's default arguments.
    odd_mask = get_mask1d()

    def __init__(self, *args, **kwargs):
        super().__init__()

        self.layer_id = args[0]
        self.depth = args[1] if len(args) > 1 else 2

        if self.layer_id % 2 == 1:
            self.mask = self.odd_mask
        else:
            self.mask = 1 - self.odd_mask

        # The network keeps the dimensionality: 784 in, 784 scale + 784 translation out.
        self.nn_comb_scale_translation = CombScaleTranslationModule(784, self.depth)

    def forward_flow(self, x):
        """Apply the coupling transform.

        Returns:
            A tuple ``(y, log_det)`` where ``y`` has the shape of ``x`` and
            ``log_det`` is the sum of the scales over the transformed
            entries (last dimension), i.e. the log-determinant of the
            Jacobian.
        """
        transformed_mask = 1 - self.mask
        passthrough = self.mask * x  # entries forwarded unmodified

        scale, translation = self.nn_comb_scale_translation(passthrough)

        # Affine transformation restricted to the non-passthrough entries.
        transformed = transformed_mask * (x * torch.exp(scale) + translation)
        log_det = (transformed_mask * scale).sum(-1)
        return passthrough + transformed, log_det

    def inverse_flow(self, z):
        """Invert ``forward_flow``: recover ``x`` from the output ``z``."""
        transformed_mask = 1 - self.mask
        passthrough = self.mask * z  # these entries were never modified

        scale, translation = self.nn_comb_scale_translation(passthrough)

        # Inverse affine transformation on the entries that were modified.
        recovered = transformed_mask * (z - translation) * torch.exp(-scale)
        return passthrough + recovered
    
    
class RealNVP(nn.Module):
    """Stack of coupling layers forming an invertible flow.

    Positional arguments:
        args[0]: number of coupling layers.
        args[1] (optional): depth passed to every coupling layer
            (defaults to 2).

    Each coupling layer receives its index in the stack as first argument.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        self.layer_amount = args[0]
        self.coupling_layer_depth = 2
        if len(args) > 1:
            self.coupling_layer_depth = args[1]
        self.coupling_layers = nn.ModuleList(
            [
             Coupling_layer(idx, self.coupling_layer_depth)
             for idx in range(self.layer_amount)
            ]
        )

    def forward_flow(self, x):
        """Map data ``x`` to latent space.

        Returns:
            A tuple ``(z, log_det_j)`` where ``log_det_j`` has one entry per
            row of ``x`` and is the SUM of the log-determinants reported by
            all coupling layers.
        """
        # Allocate the accumulator next to the input so it works on any device.
        log_det_j = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)

        for coupling_layer in self.coupling_layers:
            # Keep the per-layer term in its own variable; reusing the
            # accumulator's name here would discard all previous layers.
            x, layer_log_det = coupling_layer.forward_flow(x)
            log_det_j = log_det_j + layer_log_det

        return x, log_det_j

    def inverse_flow(self, z):
        """Map latent ``z`` back to data space.

        The layers are undone in the opposite order to ``forward_flow``;
        a composition f_n(...f_1(x)) is inverted by applying the inverse of
        f_n first and the inverse of f_1 last.
        """
        for coupling_layer in reversed(self.coupling_layers):
            z = coupling_layer.inverse_flow(z)

        return z
        
        
def nf_loss(z, logdetJ):
    """Negative mean log-likelihood of a batch under the flow.

    Minimising this value maximises the data log-likelihood.

    Args:
        z: Latent codes of shape ``(batch, features)``.
        logdetJ: Per-sample log-determinant of the Jacobian of the flow,
            shape ``(batch,)``.

    Returns:
        Scalar tensor: ``-mean(sum_i log N(z_i; 0, 1) + logdetJ)``.
    """
    # Build the standard-normal prior on the same device and dtype as `z`
    # so the function does not depend on any module-level device setting.
    prior = torch.distributions.normal.Normal(
        torch.zeros(1, device=z.device, dtype=z.dtype),
        torch.ones(1, device=z.device, dtype=z.dtype),
    )

    return -(prior.log_prob(z).sum(1) + logdetJ).mean()

In [30]:
# Load (downloading if necessary) the MNIST training split.
dataset = torchvision.datasets.MNIST(
    root=r'./mnist/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Work on the raw image tensor directly: divide by 255 and shift by 0.5
# (assuming raw pixel values in 0..255, this gives values in [-0.5, 0.5]).
y = dataset.targets
x = dataset.data.float() / 255. - 0.5

# Keep only the images labelled as the digit 5.
x = x[y == 5]

dataloader = DataLoader(x, batch_size=128, shuffle=True)

In [34]:
# Training hyperparameters: epoch count, Adam learning rate and Adam epsilon.
n_epochs, lr, eps = 3000, 5e-4, 1e-3

# Flow with 10 coupling layers, each using a depth-3 scale/translation network.
model = RealNVP(10, 3)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, eps=eps)

In [35]:
# Fixed latent noise for 5 images, drawn once and reused at every checkpoint
# so the plotted samples are generated from the same latent codes each time.
sample_noise = torch.randn(5,28,28).to(device)

for i in range(n_epochs):
    model.train()
    loss_acc = 0  # running sum of batch losses for this epoch
    for j, x in enumerate(dataloader):
        optimizer.zero_grad()
        # Add small Gaussian noise (standard deviation 1/64) to the batch
        # before moving it to the training device.
        x = (x.float() + torch.randn(x.shape) / 64.).to(device)
        # Flatten each 28x28 image into a 784-dimensional vector.
        x = x.view(-1, 28**2)
        z, logdetJ = model.forward_flow(x)
        loss = nf_loss(z, logdetJ)
        loss_acc += loss.item()
        loss.backward()
        optimizer.step()
    
    # Every 10th epoch: report the mean batch loss and plot samples.
    if i%10 == 0:
        # j is the index of the last batch, so j+1 is the number of batches.
        print(f'Epoch: {i + 1}/{n_epochs} Loss: {(loss_acc / (j+1)):.4f}')
        with torch.no_grad():
            model.eval()
            sample_noise = sample_noise.view(-1, 28**2)
            # Push the fixed latent noise through the inverse flow to get images.
            samples = model.inverse_flow(sample_noise)
            plot(samples.view(-1, 28, 28))

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Reshape the latent codes `z` (presumably left over from the last training
# batch of the loop above) into 28x28 images and plot them.
test_img = z.view(-1, 28, 28)
plot(test_img.detach())