# Homework 1. Likelihood-based models

- Seminar (10 points): Autoregressive Transformer
- **Task 1 (10 points): PixelCNN**
    - **Unconditional (5 points)**
    - Conditional (5 points)
- Task 2 (10 points): RealNVP
- \* Bonus (10+++ points)

## Task 1.1 PixelCNN on Shapes and MNIST

In this part, you will implement a simple PixelCNN architecture to model binary MNIST and shapes images.

Recap:

$$Mask_a
=
\begin{bmatrix}
1 & 1 & 1 \\
1 & 0 & 0 \\
0 & 0 & 0 \\
\end{bmatrix}$$

$$Mask_b
=
\begin{bmatrix}
1 & 1 & 1 \\
1 & 1 & 0 \\
0 & 0 & 0 \\
\end{bmatrix}$$

We recommend the following network design:
* A $7 \times 7$ masked type A convolution
* $5$ $7 \times 7$ masked type B convolutions
* $2$ $1 \times 1$ masked type B convolutions
* Appropriate nonlinearities in-between
* 64 convolutional filters
* Use normalization carefully: it must not break the autoregressive property. LayerNorm over the channel dimension is definitely OK

And the following hyperparameters:
* Batch size 128
* Learning rate $10^{-3}$
* 10 epochs
* AdamW Optimizer (this applies to all PixelCNN models trained in future parts)

Your model should output logits, after which you could apply a sigmoid over 1 logit, or a softmax over two logits (either is fine). It may also help to scale your input to $[-1, 1]$ before running it through the network. 

Training on the shapes dataset should be quick, and MNIST should take around 10 minutes

**You will provide these deliverables**

1.   Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves. 
2.   Report the final test set performance of your final model
3. 100 samples from the final trained model


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import math
from sklearn.model_selection import train_test_split
import random

%matplotlib inline

In [None]:
import pickle
from torchvision.utils import make_grid


def show_samples(samples, nrow=10, title='Samples'):
    """
    Show a batch of images as one grid figure.

    samples: array-like of images in channels-last layout, (N, H, W, C)
    nrow: passed through to `make_grid` as its `nrow` argument
    title: title drawn above the figure
    """
    # make_grid consumes channels-first tensors: (N, H, W, C) -> (N, C, H, W).
    batch = torch.FloatTensor(samples).permute(0, 3, 1, 2)
    grid = make_grid(batch, nrow=nrow)

    plt.figure()
    plt.title(title)
    # imshow wants channels-last again: (C, H, W) -> (H, W, C).
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.show()


def load_data(fname, binarize=True, include_labels=False):
    """
    Load the train/test image arrays stored in a pickle file.

    fname: path to a pickle file holding a dict with 'train' and 'test' entries
      (and 'train_labels' / 'test_labels' when include_labels is True)
    binarize: if True, threshold pixel values at 127.5 and return boolean arrays;
      if False, return the stored arrays unchanged
    include_labels: if True, also return the train and test labels
    Returns
      - (train, test), or (train, test, train_labels, test_labels)
        when include_labels is True
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(fname, 'rb') as data_file:
        data = pickle.load(data_file)

    train, test = data['train'], data['test']
    if binarize:
        train, test = train > 127.5, test > 127.5

    if include_labels:
        return train, test, data['train_labels'], data['test_labels']
    return train, test


class SimpleDataset(Dataset):
    """Map-style dataset that pairs each sample X[i] with its target y[i]."""

    def __init__(self, X, y):
        super().__init__()
        self.X, self.y = X, y

        # Samples and targets are indexed in parallel, so their lengths must agree.
        assert len(X) == len(y)

    def __len__(self):
        """Number of (sample, target) pairs."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the (sample, target) pair stored at `index`."""
        sample, target = self.X[index], self.y[index]
        return sample, target

First of all, we need to modify `Conv2d` with masking

In [None]:
class MaskedConv2D(nn.Conv2d):
    """
    Conv2d whose kernel is multiplied element-wise by a fixed binary mask.

    mask_type: 'A' or 'B'; all remaining arguments go to nn.Conv2d unchanged
    """
    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The mask has the same shape as the kernel; it is registered as a
        # buffer, so it is stored with the module but is not a parameter.
        empty_mask = torch.zeros_like(self.weight)
        self.register_buffer('mask', empty_mask)
        self.create_mask(mask_type)

    def forward(self, x):
        """
        x: (N, C_in, H_in, W_in) torch.Tensor
        Returns
          - out (N, C_out, H_out, W_out), which should equal
            conv2d(x, weight * mask) + bias
        """
        ################
        # YOUR CODE HERE
        ###############

    def create_mask(self, mask_type):
        """
        Fill in `self.mask` for the requested mask type.

        mask_type: 'A' or 'B'
        """
        assert mask_type in ('A', 'B')
        k = self.kernel_size[0]
        ################
        # YOUR CODE HERE
        ###############

In [None]:
# Sanity check on a 3x3 kernel, flattened row by row: type A keeps the four
# weights before the centre, type B additionally keeps the centre itself.
convA = MaskedConv2D('A', 1, 1, kernel_size=3)
convB = MaskedConv2D('B', 1, 1, kernel_size=3)
assert np.allclose(convA.mask.view(-1), [1., 1., 1., 1., 0., 0., 0., 0., 0.], atol=1e-6)
assert np.allclose(convB.mask.view(-1), [1., 1., 1., 1., 1., 0., 0., 0., 0.], atol=1e-6)

In [None]:
class PixelCNN(nn.Module):
    """
    Autoregressive PixelCNN over images with `n_colors` discrete values per pixel.

    input_shape: (C, H, W) shape of a single image
    n_colors: number of discrete values a pixel can take
    n_filters, kernel_size, n_layers: settings for the layer stack to be built
      in __init__ (not used by the provided scaffold itself)
    """
    def __init__(self, input_shape, n_colors=2, n_filters=64,
               kernel_size=7, n_layers=5):
        super().__init__()
        assert n_layers >= 2
        n_channels = input_shape[0]
        
        self.input_shape = input_shape
        self.n_channels = n_channels
        self.n_colors = n_colors
        
        ################
        # YOUR CODE HERE
        ###############
        
    def forward(self, x, cond=None):
        """
        x: (N, C, H, W) tensor with values in {0, ..., n_colors - 1}
        cond: optional conditioning input (unused by the provided scaffold)
        Returns
          - logits of shape (N, n_colors, C, H, W); this is the layout that
            `loss` and `sample` rely on
        """
        batch_size = x.shape[0]
        # Rescale values from {0, ..., n_colors - 1} to [-1, 1].
        x = (x.float() / (self.n_colors - 1) - 0.5) / 0.5
        ################
        # YOUR CODE HERE
        ###############

    def loss(self, x):
        """Mean cross-entropy between the predicted logits and the input pixels."""
        return F.cross_entropy(self.forward(x), x.long())

    def sample(self, n):
        """
        Draw n images one pixel at a time, row by row, column by column.

        Returns
          - (n, H, W, C) numpy array of sampled pixel values
        """
        # Sample on the device the model lives on instead of assuming CUDA.
        # A model without parameters falls back to CUDA if available, else CPU.
        first_param = next(self.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        samples = torch.zeros(n, *self.input_shape, device=device)
        with torch.no_grad():
            for r in range(self.input_shape[1]):
                for c in range(self.input_shape[2]):
                    for k in range(self.n_channels):
                        # Only the logits at (k, r, c) are used; pixels sampled
                        # so far are fed back in on the next forward pass.
                        logits = self.forward(samples)[:, :, k, r, c]
                        probs = F.softmax(logits, dim=1)
                        samples[:, k, r, c] = torch.multinomial(probs, 1).squeeze(-1)
        return samples.permute(0, 2, 3, 1).cpu().numpy()

Feel free to use and modify this train loop. You may want to show some logs or sampling results during training

In [None]:
def train(model, train_loader, optimizer):
    """
    Run one pass over train_loader, taking an optimizer step per minibatch.

    model: module exposing `loss(x)` that returns a scalar tensor
    train_loader: iterable of (x, _) batches; the second element is ignored
    optimizer: optimizer over the model's parameters
    Returns
      - list with the loss value of every minibatch, in order
    """
    model.train()
    # Move batches to the device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    train_losses = []
    for x, _ in train_loader:
        if first_param is not None:
            x = x.to(first_param.device)
        loss = model.loss(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    return train_losses


def eval_loss(model, data_loader):
    """
    Average `model.loss` over every sample yielded by data_loader.

    model: module exposing `loss(x)` that returns the mean loss of a batch
    data_loader: iterable of (x, _) batches; the second element is ignored
    Returns
      - sample-weighted mean loss as a Python float
    Raises
      - ValueError if data_loader yields no samples
    """
    model.eval()
    # Move batches to the device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for x, _ in data_loader:
            if first_param is not None:
                x = x.to(first_param.device)
            batch_len = x.shape[0]
            # Weight each batch mean by its size so a short last batch is not
            # over-represented.
            total_loss += model.loss(x).item() * batch_len
            n_samples += batch_len
    if n_samples == 0:
        raise ValueError('eval_loss received a data_loader with no samples')
    # Divide by the number of samples actually seen, not by the dataset size.
    return total_loss / n_samples


def train_epochs(model, train_loader, test_loader, train_args):
    """
    Train with AdamW for a number of epochs, evaluating on the test set after each.

    train_args: dict with keys 'epochs' (number of epochs) and 'lr' (learning rate)
    Returns
      - list of per-minibatch train losses across all epochs
      - list of test losses: one before training, then one per epoch
    """
    n_epochs = train_args['epochs']
    learning_rate = train_args['lr']
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

    # The first test loss is measured before any training step.
    test_losses = [eval_loss(model, test_loader)]
    train_losses = []
    for epoch in range(n_epochs):
        print(f'epoch {epoch} started')
        model.train()
        train_losses += train(model, train_loader, optimizer)
        test_losses.append(eval_loss(model, test_loader))
        # The reported train loss is the mean over the last 1000 minibatches.
        recent_train_loss = np.mean(train_losses[-1000:])
        print('train loss: {}, test_loss: {}'.format(recent_train_loss,
                                                     test_losses[-1]))

    return train_losses, test_losses


def train_model(train_data, test_data, model):
    """
    train_data: A (n_train, H, W, 1) numpy array of binary images with values in {0, 1}
      (a boolean array when it comes from load_data with its default arguments)
    test_data: A (n_test, H, W, 1) numpy array of binary images with values in {0, 1}
      (a boolean array when it comes from load_data with its default arguments)
    model: nn.Module instance; it must provide a loss(x) method that accepts a
      batch of images, as called by train and eval_loss
    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - trained model
    """
    ################
    # YOUR CODE HERE
    ###############

### First dataset: **Shapes**

In [None]:
# For colab users: download file
# ! wget https://github.com/a4-edu/course_gmcv/raw/hw1/module1-likelihood/shapes.pkl

In [None]:
# load_data thresholds at 127.5, so both splits are boolean arrays.
shapes_train, shapes_test = load_data('./shapes.pkl')

In [None]:
# Preview the first 100 training images as a grid.
show_samples(shapes_train[:100])

In [None]:
# Images are stored channels-last (H, W, C); PixelCNN takes a channels-first
# (C, H, W) shape with a single channel.
H, W, _ = shapes_train[0].shape
model = PixelCNN((1, H, W))
train_losses, test_losses, shapes_model = train_model(shapes_train, shapes_test, model)

In [None]:
def show_train_plots(train_losses, test_losses, title):
    """
    Plot train and test loss curves against a shared epoch axis.

    train_losses: sequence of per-minibatch train losses
    test_losses: sequence of test losses, one more entry than there are epochs
    title: title of the figure
    """
    # test_losses holds one value before training plus one per epoch.
    n_epochs = len(test_losses) - 1
    # Spread the minibatch losses evenly over the [0, n_epochs] range.
    train_axis = np.linspace(0, n_epochs, len(train_losses))
    test_axis = np.arange(n_epochs + 1)

    plt.figure()
    curves = (
        (train_axis, train_losses, 'train loss'),
        (test_axis, test_losses, 'test loss'),
    )
    for xs, ys, label in curves:
        plt.plot(xs, ys, label=label)
    plt.legend()
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('NLL')
    plt.show()

In [None]:
# Loss curves for the shapes run.
show_train_plots(train_losses, test_losses, 'Shapes')

In [None]:
# Draw 100 images from the trained shapes model and show them as a grid.
samples = shapes_model.sample(100)
show_samples(samples)

### Second dataset: **MNIST**

In [None]:
# For colab users: download file
# ! wget https://github.com/a4-edu/course_gmcv/raw/hw1/module1-likelihood/mnist.pkl

In [None]:
# load_data thresholds at 127.5, so both splits are boolean arrays.
mnist_train, mnist_test = load_data('./mnist.pkl')

In [None]:
# Preview the first 100 training images as a grid.
show_samples(mnist_train[:100])

In [None]:
# Images are stored channels-last (H, W, C); PixelCNN takes a channels-first
# (C, H, W) shape with a single channel. Note that this rebinds `model`,
# `train_losses` and `test_losses` from the shapes run.
H, W, _ = mnist_train[0].shape
model = PixelCNN((1, H, W))
train_losses, test_losses, mnist_model = train_model(mnist_train, mnist_test, model)

In [None]:
# Loss curves for the MNIST run.
show_train_plots(train_losses, test_losses, 'MNIST')

In [None]:
# Draw 100 images from the trained MNIST model and show them as a grid.
samples = mnist_model.sample(100)
show_samples(samples)