# Homework 1. Likelihood-based models

- Seminar (10 points): Autoregressive Transformer
- **Task 1 (10 points): PixelCNN**
    - Unconditional (5 points)
    - **Conditional (5 points)**
- Task 2 (10 points): RealNVP
- \* Bonus (10+++ points)

## Task 1.2. Conditional PixelCNNs on Shapes and MNIST

In this part, implement and train a **class-conditional** PixelCNN on the binary Shapes and MNIST datasets. Condition on the class label by adding a conditional bias in each convolutional layer. More precisely, in the $\ell$-th convolutional layer, compute:
$$W_\ell * x + b_\ell + V_\ell y$$
where $W_\ell * x + b_\ell$ is a masked convolution (as in previous parts), $V_\ell$ is a 2D weight matrix with one column per class, and $y$ is a one-hot encoding of the class label (the conditional bias $V_\ell y$ is broadcast spatially and added channel-wise). You may find `nn.Embedding` useful here.
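
As a minimal sketch of the formula above (not a reference solution), a conditional masked convolution could look like the block below. The class name `ConditionalMaskedConv2d`, its argument order, and the `'A'`/`'B'` mask convention are assumptions made for illustration; your masked convolution from Task 1.1 may be organised differently. The sketch relies on the fact that looking up row $c$ of an `nn.Embedding` table gives the same vector as multiplying $V_\ell$ by the one-hot encoding of class $c$.

```python
import torch
from torch import nn
from torch.nn import functional as F


class ConditionalMaskedConv2d(nn.Conv2d):
    """Masked convolution with a class-conditional bias: W * x + b + V y.

    Hypothetical helper for illustration only.
    """

    def __init__(self, mask_type, num_classes, in_channels, out_channels,
                 kernel_size, **kwargs):
        assert mask_type in ('A', 'B')
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # Row c of the embedding table plays the role of column c of V,
        # so the lookup equals V @ one_hot(c) without building the one-hot.
        self.cond_bias = nn.Embedding(num_classes, out_channels)

        kh, kw = self.kernel_size
        mask = torch.zeros(kh, kw)
        mask[:kh // 2, :] = 1.0           # every row above the centre
        mask[kh // 2, :kw // 2] = 1.0     # pixels to the left of the centre
        if mask_type == 'B':
            mask[kh // 2, kw // 2] = 1.0  # type B may also see the centre
        self.register_buffer('mask', mask[None, None])

    def forward(self, x, y):
        # x: (N, C_in, H, W) float tensor, y: (N,) long tensor of class ids
        out = F.conv2d(x, self.weight * self.mask, self.bias,
                       self.stride, self.padding, self.dilation, self.groups)
        # (N, C_out) -> (N, C_out, 1, 1): broadcast over H and W
        return out + self.cond_bias(y)[:, :, None, None]


# Usage: a type-A layer for 1-channel images and 4 classes
layer = ConditionalMaskedConv2d('A', 4, 1, 8, 3, padding=1)
out = layer(torch.rand(2, 1, 5, 5), torch.tensor([0, 3]))  # shape (2, 8, 5, 5)
```

Note that the data in this notebook is stored as `(N, H, W, 1)`, so a model built from such layers would need to permute the input to channel-first layout before the first convolution.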

You can use a PixelCNN architecture similar to the one in Task 1.1. Training on the Shapes dataset should be quick, and training on MNIST should take around 10-15 minutes.

Feel free to reuse and modify any code from Task 1.1.

**You will provide these deliverables**

1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and of the test data (over your entire test set). Code is provided that automatically plots the training curves.
2. Report the final test set performance of your final model.
3. Generate 100 samples from the final trained model.
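
To make the "nats / dim" unit in the first deliverable concrete, here is a hedged sketch. It assumes the model returns one logit per pixel in channel-first layout and is called as `model(x, y)`; the helper names `nll_per_dim` and `evaluate` are hypothetical, not part of the provided code. For binary images, the mean binary cross-entropy over all pixels is exactly the average negative log-likelihood in nats per dimension.

```python
import torch
from torch.nn import functional as F


def nll_per_dim(logits, x):
    """Average negative log-likelihood in nats per pixel for binary images.

    logits: (N, 1, H, W) unnormalised log-odds of each pixel being 1
    x:      (N, 1, H, W) target pixels with values in {0, 1}
    """
    # The natural logarithm is used internally, so the result is in nats;
    # reduction='mean' divides by N * H * W, giving nats per dimension.
    return F.binary_cross_entropy_with_logits(logits, x.float(), reduction='mean')


@torch.no_grad()
def evaluate(model, loader, device='cpu'):
    """Test-set NLL: weight each batch by its size so that a smaller
    final batch does not skew the average."""
    model.eval()
    total, count = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        total += nll_per_dim(model(x, y), x).item() * x.shape[0]
        count += x.shape[0]
    return total / count
```

A useful sanity check under these assumptions: a model that outputs a logit of 0 for every pixel assigns probability 0.5 everywhere, so its loss is $\ln 2 \approx 0.693$ nats / dim. The test loss at initialization is often close to this value.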

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import math
from sklearn.model_selection import train_test_split
import random

%matplotlib inline

In [None]:
import pickle
from torchvision.utils import make_grid


def show_samples(samples, nrow=10, title='Samples'):
    samples = (torch.FloatTensor(samples)).permute(0, 3, 1, 2)
    grid_img = make_grid(samples, nrow=nrow)
    plt.figure()
    plt.title(title)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis('off')

    plt.show()
        

def load_data(fname, binarize=True, include_labels=False):
    with open(fname, 'rb') as data_file:
        data = pickle.load(data_file)

    train, test = data['train'], data['test']
    if binarize:
        # Threshold pixel intensities at the midpoint of [0, 255]
        train, test = train > 127.5, test > 127.5

    if include_labels:
        return train, test, data['train_labels'], data['test_labels']

    return train, test


class SimpleDataset(Dataset):
    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y
        
        assert len(X) == len(y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return self.X[index], self.y[index]

In [None]:
# For colab users: download data
# ! wget https://github.com/a4-edu/course_gmcv/raw/hw1/module1-likelihood/shapes.pkl
# ! wget https://github.com/a4-edu/course_gmcv/raw/hw1/module1-likelihood/mnist.pkl

In [None]:
########################
# MOST OF YOUR CODE GOES HERE
########################

In [None]:
def train_model(train_data, train_labels, test_data, test_labels, num_classes, model):
    """
    train_data: A (n_train, H, W, 1) uint8 numpy array of binary images with values in {0, 1}
    train_labels: A (n_train, 1) int64 numpy array of image labels with values in {0, 1, ..., num_classes - 1}
    test_data: A (n_test, H, W, 1) uint8 numpy array of binary images with values in {0, 1}
    test_labels: A (n_test, 1) int64 numpy array of image labels with values in {0, 1, ..., num_classes - 1}
    num_classes: int
    model: an nn.Module instance
    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - trained model
    """
    ################
    # YOUR CODE HERE
    ################

### First dataset: **Shapes**

In [None]:
shapes_train, shapes_test, shapes_labels_train, shapes_labels_test = \
    load_data('./shapes.pkl', include_labels=True)

In [None]:
show_samples(shapes_train[:100])

In [None]:
num_classes = shapes_labels_train.max() + 1
num_classes

In [None]:
H, W, _ = shapes_train[0].shape
model = ... 
train_losses, test_losses, shapes_model = train_model(
    shapes_train, shapes_labels_train,
    shapes_test, shapes_labels_test,
    num_classes,
    model)

In [None]:
def show_train_plots(train_losses, test_losses, title):
    plt.figure()
    n_epochs = len(test_losses) - 1
    x_train = np.linspace(0, n_epochs, len(train_losses))
    x_test = np.arange(n_epochs + 1)

    plt.plot(x_train, train_losses, label='train loss')
    plt.plot(x_test, test_losses, label='test loss')
    plt.legend()
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('NLL')
    plt.show()

In [None]:
show_train_plots(train_losses, test_losses, 'Shapes')

In [None]:
labels = [0] * 25 + [1] * 25 + [2] * 25 + [3] * 25

In [None]:
samples = shapes_model.sample(100, torch.tensor(labels, dtype=torch.long))
show_samples(samples)

### Second dataset: MNIST

In [None]:
mnist_train, mnist_test, mnist_labels_train, mnist_labels_test = \
    load_data('./mnist.pkl', include_labels=True)

In [None]:
show_samples(mnist_train[:100])

In [None]:
num_classes = mnist_labels_train.max() + 1
num_classes

In [None]:
H, W, _ = mnist_train[0].shape
model = ...
train_losses, test_losses, mnist_model = train_model(
    mnist_train, mnist_labels_train,
    mnist_test, mnist_labels_test,
    num_classes,
    model)

In [None]:
show_train_plots(train_losses, test_losses, 'MNIST')

In [None]:
# Ten labels per digit as a flat (100,) array, so each row of the sample grid shows one class
labels = np.repeat(np.arange(10), 10)

In [None]:
samples = mnist_model.sample(100, torch.tensor(labels, dtype=torch.long))
show_samples(samples)