# PixelRNN

Partially inspired by https://github.com/heechan95/PixelRNN-pytorch/blob/master/PixelRNN%20pytorch.ipynb

In [2]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

In [3]:
# PARAMETERS
# Training schedule and model size.
NUM_EPOCHS = 10
BATCH_SIZE = 16
HIDDEN_DIM = 16
LR = 1e-3
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Masked Convolutions

In [None]:
class MaskedConvolution(nn.Module):
    """2-D convolution whose kernel only sees positions above / to the left.

    The kernel is multiplied by a fixed binary mask that keeps every row above the
    centre row, plus the taps of the centre row left of the centre. Mask type 'A'
    also hides the centre tap; mask type 'B' keeps it.

    The mask is applied functionally in ``forward``. The stored weight is never
    modified, and autograd gives the masked-out taps exactly zero gradient.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: ``(height, width)`` of the kernel, or a single int for a
            square kernel. Padding is ``(size - 1) // 2`` per axis, so odd sizes
            keep the spatial size unchanged.
        mask_type: ``'A'`` (centre tap hidden) or ``'B'`` (centre tap visible).

    Raises:
        ValueError: If ``mask_type`` is not ``'A'`` or ``'B'``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 mask_type='B'):
        super().__init__()

        if mask_type not in ('A', 'B'):
            raise ValueError(f"mask_type must be 'A' or 'B', got {mask_type!r}")
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_size = tuple(kernel_size)

        # Build the causal mask: all rows above the centre, plus the left part of
        # the centre row (including the centre tap itself for type 'B').
        center_row, center_col = kernel_size[0] // 2, kernel_size[1] // 2
        visible_cols = center_col + 1 if mask_type == 'B' else center_col
        mask = torch.zeros(kernel_size)
        mask[:center_row, :] = 1
        mask[center_row, :visible_cols] = 1
        # A buffer, so it follows the module across .to(device) and into state_dict.
        self.register_buffer('mask', mask)

        padding = tuple((size - 1) // 2 for size in kernel_size)
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              padding=padding)

    def forward(self, x):
        # (out, in, kh, kw) * (kh, kw) broadcasts the mask over both channel axes.
        # Masking a copy keeps the parameter untouched and makes the mask
        # differentiable (zero gradient for hidden taps).
        weight = self.conv.weight * self.mask
        return nn.functional.conv2d(x, weight, self.conv.bias,
                                    stride=self.conv.stride,
                                    padding=self.conv.padding,
                                    dilation=self.conv.dilation,
                                    groups=self.conv.groups)

In [None]:
# Smoke test: a 3x3 type-'B' masked convolution maps a 1-channel 28x28 image to
# 16 channels; with padding 1 the spatial size is preserved -> [1, 16, 28, 28].
mc = MaskedConvolution(
    in_channels=1,
    out_channels=16,
    kernel_size=(3, 3),
    mask_type='B',
).to(DEVICE)
data = torch.randn((1, 1, 28, 28), device=DEVICE)
mc(data).shape

## RowLSTM

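The LSTM cell processes one image row at a time: all pixels of a row are updated in parallel from the current row's input features and the previous row's hidden state.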

In [None]:
class RowLSTMCell(nn.Module):
    """One Row LSTM step, updating an entire image row in parallel.

    Args:
        hidden_dim: Number of hidden channels per pixel.

    ``forward(i_s, h_prev, c_prev)`` takes the precomputed input-to-state
    pre-activations ``i_s`` (``4 * hidden_dim`` channels) plus the previous row's
    hidden and cell states (``hidden_dim`` channels each), laid out as
    ``(batch, channels, width)`` for ``nn.Conv1d``. It returns the new ``(h, c)``.
    Gate order along the channel axis is output, forget, input, candidate.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # State-to-state term: an ordinary (unmasked) width-3 1-D convolution over
        # the previous hidden row, yielding pre-activations for all four gates.
        self.conv_s_s = nn.Conv1d(hidden_dim, 4 * hidden_dim, kernel_size=3, padding=1)

    def forward(self, i_s, h_prev, c_prev):
        gates = i_s + self.conv_s_s(h_prev)
        out_gate, forget_gate, input_gate, candidate = gates.split(self.hidden_dim, dim=1)

        c = torch.sigmoid(forget_gate) * c_prev + torch.sigmoid(input_gate) * torch.tanh(candidate)
        h = torch.sigmoid(out_gate) * torch.tanh(c)
        return h, c

In [None]:
# Smoke-test RowLSTMCell on a single 28-pixel row with 16 hidden channels:
# i_s holds the 4 * 16 gate pre-activations, and h / c should both come back
# with shape (1, 16, 28).
i_s = torch.randn(1, 16*4, 28, device=DEVICE)
prev_h = torch.randn(1, 16, 28, device=DEVICE)
prev_c = torch.randn(1, 16, 28, device=DEVICE)

cell = RowLSTMCell(16).to(DEVICE)
h, c = cell(i_s, prev_h, prev_c)
print(h.shape, c.shape)

In [None]:
class RowLSTM(nn.Module):
    """Row LSTM layer: scans the feature map top to bottom, one whole row per step.

    The input-to-state term for every row is computed up front with a 1x3
    mask-'B' convolution, which sees only the current and left-neighbour
    positions. Each step combines that term with the previous row's hidden state
    in ``RowLSTMCell``. The hidden rows are then projected back to
    ``2 * hidden_dim`` channels by a 1x1 convolution and added to the input
    (residual connection).

    Args:
        hidden_dim: Number of LSTM hidden channels.

    Input and output shape: ``(batch, 2 * hidden_dim, height, width)``. The
    recurrence starts from an all-zero hidden/cell state. That state is built per
    call from the input's batch size, width, device and dtype, so any
    ``hidden_dim`` and image size work.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.cell = RowLSTMCell(hidden_dim)
        self.i_s = MaskedConvolution(in_channels=hidden_dim*2,
                                     out_channels=hidden_dim*4,
                                     kernel_size=(1, 3))
        self.conv = nn.Conv2d(in_channels=hidden_dim,
                              out_channels=hidden_dim*2,
                              kernel_size=1)

    def forward(self, x):
        batch_size, _, num_rows, width = x.shape
        i_s = self.i_s(x)  # (batch, 4 * hidden_dim, height, width)

        # Zero initial state matching this input, on x's device and dtype.
        h = x.new_zeros(batch_size, self.hidden_dim, width)
        c = x.new_zeros(batch_size, self.hidden_dim, width)

        rows = []
        for row_idx in range(num_rows):
            # i_s is indexed (batch, channels, height, width); take one row.
            h, c = self.cell(i_s[:, :, row_idx, :], h, c)
            rows.append(h)

        out = self.conv(torch.stack(rows, dim=2))  # (batch, 2 * hidden_dim, height, width)
        # skip connection
        return out + x

In [None]:
# Smoke-test RowLSTM: input and output are (batch, 2 * hidden_dim, height, width),
# so the result should have the same shape as `image`, (1, 32, 28, 28).
image = torch.randn(1, 16*2, 28, 28, device=DEVICE)
row_lstm = RowLSTM(16).to(DEVICE)
row_lstm(image).shape

In [None]:
class PixelRNN(nn.Module):
    """Row-LSTM PixelRNN that outputs 256 logits for every pixel.

    Architecture: a 7x7 mask-'A' convolution (centre tap excluded), ``num_layers``
    stacked RowLSTM layers, then two 1x1 convolutions mapping the features to 256
    logits per pixel.

    Args:
        hidden_dim: Hidden size of each RowLSTM. Feature maps between layers have
            ``2 * hidden_dim`` channels.
        num_layers: Number of stacked RowLSTM layers (default 7).

    Input: ``(batch, 1, height, width)`` images.
    Output: ``(batch, 256, 1, height, width)`` logits. This layout can be passed
    to ``nn.CrossEntropyLoss`` together with ``(batch, 1, height, width)``
    integer targets.
    """

    def __init__(self, hidden_dim, num_layers=7):
        super().__init__()
        self.layers = nn.Sequential(
                MaskedConvolution(in_channels=1,
                                  out_channels=hidden_dim*2,
                                  kernel_size=(7, 7),
                                  mask_type='A'),
                nn.ReLU(),
                *[RowLSTM(hidden_dim) for _ in range(num_layers)],
                nn.ReLU(),
                nn.Conv2d(hidden_dim*2, hidden_dim*2, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(hidden_dim*2, 256, kernel_size=1)
        )

    def forward(self, x):
        logits = self.layers(x)  # (batch, 256, height, width)
        # Insert the singleton channel axis -> (batch, 256, 1, height, width).
        # Sizes come from the tensor itself rather than BATCH_SIZE / 28, so a
        # partial final batch or a single image is handled here too.
        return logits.unsqueeze(2)

In [None]:
# Smoke-test the full model on a random batch; expect one 256-way logit vector
# per pixel, i.e. shape (BATCH_SIZE, 256, 1, 28, 28).
test_images = torch.randn((BATCH_SIZE, 1, 28, 28), device=DEVICE)
model = PixelRNN(hidden_dim=16).to(DEVICE)
model(test_images).shape

In [None]:
# MNIST splits. PILToTensor keeps the raw integer pixel values (0-255); the
# training loop uses them both as the scaled input and as the 256-class targets.
train_dataset = torchvision.datasets.MNIST(root='../datasets/',
                                           train=True,
                                           transform=T.PILToTensor(),
                                           download=True)

# download=True is a no-op when the files are already present. It lets this split
# load even if the call above was skipped or pointed at a different root.
test_dataset = torchvision.datasets.MNIST(root='../datasets/',
                                          train=False,
                                          transform=T.PILToTensor(),
                                          download=True)

In [None]:
# Training: reshuffle every epoch and drop the final incomplete batch.
# Evaluation: fixed order, every test image included.
train_dataloader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              drop_last=True,
                              num_workers=2)
test_dataloader = DataLoader(test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             drop_last=False,
                             num_workers=2)

In [None]:
model = PixelRNN(HIDDEN_DIM).to(DEVICE)
# Per-pixel 256-way classification loss over the model's (B, 256, 1, H, W) logits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
# Multiply the LR by 0.1 after `patience` consecutive scheduler.step(metric) calls
# without improvement. The `verbose` flag is deprecated for LR schedulers; inspect
# optimizer.param_groups[0]['lr'] to see the current rate instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

In [None]:
def train():
    """Train the notebook-global ``model`` for ``NUM_EPOCHS`` epochs.

    Relies on the globals ``model``, ``criterion``, ``optimizer``, ``scheduler``,
    ``train_dataloader``, ``test_dataloader`` and ``DEVICE``. Each image batch is
    both the input (scaled by 1/255) and the target (its integer pixel values as
    class indices).

    After every epoch it prints the mean train and test cross-entropy and passes
    the test loss to ``scheduler.step``.
    """

    def batch_loss(features):
        # Move an image batch to DEVICE and return the per-pixel cross-entropy.
        features = features.to(DEVICE)
        logits = model(features.float() / 255)
        return criterion(logits, features.long())

    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_losses = []
        for features, _ in tqdm(train_dataloader, leave=False):
            loss = batch_loss(features)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        # inference
        model.eval()
        with torch.inference_mode():
            test_losses = [batch_loss(features).item() for features, _ in test_dataloader]

        ce_train = sum(train_losses) / len(train_losses)
        ce_test = sum(test_losses) / len(test_losses)
        # ReduceLROnPlateau only acts when it is given a metric; the test split
        # stands in for a validation set here.
        scheduler.step(ce_test)
        print(f'Epoch: {epoch}/{NUM_EPOCHS}, Cross Entropy Train: {ce_train:.4f}, Cross Entropy Test: {ce_test:.4f}')

In [None]:
# Run the training loop defined above; it prints one loss summary line per epoch.
train()

## PixelCNN