# PixelRNN

Partially inspired by https://github.com/heechan95/PixelRNN-pytorch/blob/master/PixelRNN%20pytorch.ipynb

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

In [2]:
# parameters from the paper
# Number of images per batch.
batch_size = 16
# Not referenced by the code visible in this notebook yet; presumably the
# number of stacked Row LSTM layers -- TODO confirm once PixelRNN is written.
num_layers = 7
# Number of feature maps (channels) used by the masked convolutions and the
# Row LSTM hidden/cell state.
hidden_dim = 16

The architecture is designed for MNIST only and for static inputs.

In [3]:
# architecture
# 1. 7 x 7 with conv mask A
# 2. Row LSTM with residual blocks
#    a. i-s: 3x1 mask with conv mask B
#    b. s-s: 3x1 no mask
# 3. ReLU + 1x1 conv layer with mask B (2 layers)
# 4. 256-way softmax

In [4]:
# Type A masked Conv2d 
class MaskedConv2d(nn.Module):
    """2-D convolution whose kernel is restricted by a type-A mask.

    The kernel taps are numbered in raster order (left to right, top to
    bottom); only taps strictly before the centre tap are kept, so the centre
    itself and everything after it contribute nothing to the output.

    Args:
        in_channels: number of input channels.
        out_channels: number of output feature maps.
        kernel_size: side length of the square kernel (an int).
        padding: zero padding added on each side of the input.
    """

    def __init__(self, in_channels=1,
                 out_channels=hidden_dim,
                 kernel_size=7,
                 padding=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              padding=padding)

        # Raster index of the centre tap. It depends on the kernel size only,
        # not on the padding (the two coincide for the 7x7 / padding=3 default).
        half = kernel_size // 2
        center_idx = half * kernel_size + half
        tap_idxs = torch.arange(kernel_size ** 2).view(kernel_size, kernel_size)
        # A buffer follows the module across .to()/.cuda() but is not trained.
        self.register_buffer('mask', (tap_idxs < center_idx).float())

    def forward(self, x):
        """Convolve ``x`` with the masked kernel and return the result."""
        # Mask the weights functionally instead of rebinding
        # ``self.conv.weight`` to a fresh nn.Parameter: the parameter object
        # registered with an optimizer stays the same, masked taps receive a
        # zero gradient, and no in-place write invalidates the autograd graph.
        return nn.functional.conv2d(x,
                                    self.conv.weight * self.mask,
                                    self.conv.bias,
                                    self.conv.stride,
                                    self.conv.padding,
                                    self.conv.dilation,
                                    self.conv.groups)

In [5]:
# Smoke test: one single-channel 28x28 image through the default layer.
# Expected output shape is (1, 16, 28, 28): spatial size is kept and the
# channel count becomes hidden_dim.
test2dLayer = MaskedConv2d()
testdata = torch.randn(1, 1, 28, 28)
test2dLayer(testdata).shape

torch.Size([1, 16, 28, 28])

In [6]:
# Type B masked Conv1d
class MaskedConv1d(nn.Module):
    """1-D convolution whose kernel is restricted by a type-B mask.

    Every kernel tap to the right of the centre is zeroed, so each output
    position depends only on input positions at or to the left of it.

    Args:
        in_channels: number of input channels.
        out_channels: number of output feature maps.
        kernel_size: length of the kernel (an int).
        padding: zero padding added on each side of the input.
    """

    def __init__(self,
                 in_channels=hidden_dim,
                 out_channels=hidden_dim,
                 kernel_size=3,
                 padding=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              padding=padding)

        mask = torch.ones(kernel_size)
        # Drop every tap after the centre. For kernel_size == 3 this is just
        # index 2; a slice keeps the mask correct for other kernel sizes.
        mask[kernel_size // 2 + 1:] = 0
        # A buffer follows the module across .to()/.cuda() but is not trained.
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Convolve ``x`` with the masked kernel and return the result."""
        # Mask functionally rather than rebinding ``self.conv.weight`` to a
        # new nn.Parameter on every call: the registered parameter (and hence
        # any optimizer holding it) stays valid, and the layer can be called
        # repeatedly inside a recurrent loop before a single backward pass.
        return nn.functional.conv1d(x,
                                    self.conv.weight * self.mask,
                                    self.conv.bias,
                                    self.conv.stride,
                                    self.conv.padding,
                                    self.conv.dilation,
                                    self.conv.groups)

In [12]:
# Smoke test for the 1-D masked convolution; the output shape should equal
# the input shape, (1, 16, 28).
test1dLayer = MaskedConv1d()
# 16 -> hidden size (number of feature maps), 28 -> width
# the layer processes one row at a time
testdata = torch.randn(1, 16, 28)
test1dLayer(testdata).shape

torch.Size([1, 16, 28])

The LSTM cell receives one row at a time.

In [51]:
class RowLSTMCell(nn.Module):
    """One Row LSTM step: consumes a single image row and the previous state.

    Both convolutions produce ``4 * hidden_dim`` channels, which are split
    into the output, forget and input gates and the candidate content.
    """

    def __init__(self):
        super().__init__()
        # Input-to-state: masked (type B) so that a position in the current
        # row never reads pixels to its right, matching the architecture
        # notes ("i-s: 3x1 mask with conv mask B").
        self.conv_i_s = MaskedConv1d(in_channels=hidden_dim,
                                     out_channels=4 * hidden_dim,
                                     kernel_size=3,
                                     padding=1)

        # State-to-state: unmasked ("s-s: 3x1 no mask"). It only reads the
        # hidden state of the previous row, so no masking is required.
        self.conv_s_s = nn.Conv1d(in_channels=hidden_dim,
                                  out_channels=4 * hidden_dim,
                                  kernel_size=3,
                                  padding=1)

    def forward(self, x, h_prev, c_prev):
        """Advance the LSTM by one row.

        Args:
            x: features of the current row; must have ``hidden_dim`` channels
                on dim 1 (the ``in_channels`` of ``conv_i_s``).
            h_prev: hidden state from the previous row, ``hidden_dim`` channels.
            c_prev: cell state from the previous row, same shape as the gates.

        Returns:
            Tuple ``(h, c)`` with the new hidden and cell state.
        """
        gates = self.conv_i_s(x) + self.conv_s_s(h_prev)

        # Four chunks of hidden_dim channels each along the channel dim.
        o, f, i, g = torch.split(gates, hidden_dim, dim=1)
        o = torch.sigmoid(o)
        f = torch.sigmoid(f)
        i = torch.sigmoid(i)
        g = torch.tanh(g)

        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, c

In [55]:
# test Cell
# Feed one random row plus random previous states through the cell; h and c
# should both come back with the input shape, (1, 16, 28).
row = torch.randn(1, 16, 28)
prev_h = torch.randn(1, 16, 28)
prev_c = torch.randn(1, 16, 28)

cell = RowLSTMCell()
h, c = cell(row, prev_h, prev_c)
print(h.shape, c.shape)

torch.Size([1, 16, 28]) torch.Size([1, 16, 28])


In [93]:
class RowLSTM(nn.Module):
    """Runs a RowLSTMCell over an image from the top row to the bottom row.

    No residual connection is applied here; the output is the stack of
    per-row hidden states only.
    """

    def __init__(self):
        super().__init__()
        self.cell = RowLSTMCell()

    def forward(self, x):
        """Return the hidden state of every row, stacked along dim 2.

        Args:
            x: 4-D tensor unpacked as (batch, channels, rows, width); the
                channel count must match what ``RowLSTMCell`` expects
                (``hidden_dim``).

        Returns:
            Tensor of shape (batch, hidden_dim, rows, width).
        """
        num_images, _, num_rows, width = x.shape

        # Size the initial state from the input instead of the global
        # batch_size and a hard-coded width of 28, and create it on x's
        # device/dtype, so short final batches, GPU tensors and other image
        # widths all work.
        h_prev = x.new_zeros(num_images, hidden_dim, width)
        c_prev = x.new_zeros(num_images, hidden_dim, width)

        hidden_rows = []
        for row_idx in range(num_rows):
            image_row = x[:, :, row_idx, :]
            h_prev, c_prev = self.cell(image_row, h_prev, c_prev)
            hidden_rows.append(h_prev)
        # stack() inserts the row dimension back at position 2.
        return torch.stack(hidden_rows, dim=2)

In [94]:
# test row lstm
# A full batch of random 16-channel 28x28 feature maps; the output should
# keep the input shape, (batch_size, 16, 28, 28).
image = torch.randn(batch_size, 16, 28, 28)
row_lstm = RowLSTM()
row_lstm(image).shape

torch.Size([16, 16, 28, 28])

In [None]:
class PixelRNN(nn.Module):
    """Placeholder for the full PixelRNN model; not implemented yet.

    Planned architecture, per the notes earlier in this notebook:
      1. 7x7 convolution with mask A
      2. Row LSTM layers with residual blocks
      3. ReLU + 1x1 convolution with mask B (2 layers)
      4. 256-way softmax
    """

    def __init__(self):
        # nn.Module has to be initialised before the instance can register
        # parameters/submodules or be called at all.
        super().__init__()

    def forward(self, x):
        # Fail loudly rather than silently returning None from a stub.
        raise NotImplementedError('PixelRNN.forward is not implemented yet')

In [10]:
# MNIST training split, stored under ../datasets/ and downloaded on request
# (download=True); each image is passed through T.ToTensor().
train_dataset = torchvision.datasets.MNIST(root='../datasets/', 
                                           train=True, 
                                           transform=T.ToTensor(), 
                                           download=True)

# MNIST test split from the same root. download=False, so the files must
# already be on disk -- presumably fetched by the training-split call above;
# verify if this cell is run on its own.
test_dataset = torchvision.datasets.MNIST(root='../datasets/', 
                                           train=False, 
                                           transform=T.ToTensor(), 
                                           download=False)