# PixelRNN

In [8]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

In [3]:
# Hyperparameters quoted from the PixelRNN paper (van den Oord et al., 2016).
num_layers = 7    # number of layers in the model
hidden_dim = 16   # hidden feature maps; also the default width of the masked conv layers below
batch_size = 16   # mini-batch size

The architecture below is designed for MNIST only (single-channel 28×28 images) and for static inputs.

In [4]:
# architecture
# 1. 7x7 convolution with mask A
# 2. Row LSTM layers with residual connections
#    a. input-to-state (i-s): 3x1 convolution with mask B
#    b. state-to-state (s-s): 3x1 convolution, no mask
# 3. ReLU + 1x1 convolution with mask B (2 layers)
# 4. 256-way softmax

In [5]:
# Type A masked Conv2d
class MaskedConv2d(nn.Module):
    """2D convolution with a PixelRNN type-A mask.

    Kernel taps at the centre and at every position after the centre in
    raster-scan order are zeroed. With ``padding == kernel_size // 2``, the
    output at pixel (r, c) therefore depends only on input pixels above row r,
    or to the left of column c on row r -- never on (r, c) itself.

    Args:
        in_channels: number of input feature maps (default 1).
        out_channels: number of output feature maps; defaults to the
            notebook-level ``hidden_dim``.
        kernel_size: side length of the square kernel; use an odd value so
            the mask has a well-defined centre.
        padding: zero-padding forwarded to ``nn.Conv2d``; ``kernel_size // 2``
            keeps the spatial size unchanged.
    """

    def __init__(self, in_channels=1,
                 out_channels=hidden_dim,
                 kernel_size=7,
                 padding=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              padding=padding)

        # The mask centre is a property of the kernel, not of the padding.
        # Deriving it from ``padding`` breaks the mask whenever
        # padding != kernel_size // 2 (e.g. padding=0 would zero every tap).
        center = kernel_size // 2
        mask_idxs = torch.arange(kernel_size ** 2).view(kernel_size, kernel_size)
        # Keep only taps strictly before the centre in raster order (type A
        # excludes the centre itself).
        mask = (mask_idxs < center * kernel_size + center).float()
        # A buffer follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Apply the mask functionally rather than re-wrapping self.conv.weight
        # in a new nn.Parameter on every call. Replacing the Parameter object
        # disconnects it from an optimizer built on model.parameters(), so
        # training would update a stale tensor. Multiplying here also gives
        # the masked taps a zero gradient.
        return nn.functional.conv2d(x,
                                    self.conv.weight * self.mask,
                                    self.conv.bias,
                                    stride=self.conv.stride,
                                    padding=self.conv.padding,
                                    dilation=self.conv.dilation,
                                    groups=self.conv.groups)

In [6]:
test2dLayer = MaskedConv2d()
testdata = torch.randn(1, 1, 28, 28)

# Causality check for mask A. Perturbing pixel (row, col) and every later
# pixel in raster-scan order must leave the output at (row, col) unchanged.
# Perturbing the pixel immediately to its left must change it.
row, col = 14, 14
future = testdata.clone()
future.view(-1)[row * 28 + col:] += 1.0
past = testdata.clone()
past[..., row, col - 1] += 1.0
with torch.no_grad():
    out = test2dLayer(testdata)
    out_future = test2dLayer(future)
    out_past = test2dLayer(past)
assert torch.allclose(out[..., row, col], out_future[..., row, col])
assert not torch.allclose(out[..., row, col], out_past[..., row, col])

out.shape

torch.Size([1, 16, 28, 28])

In [48]:
# Type B masked Conv1d
class MaskedConv1d(nn.Module):
    """1D convolution along a row with a type-B mask.

    Kernel taps strictly to the right of the centre are zeroed. With
    ``padding == kernel_size // 2``, the output at column w therefore depends
    on input columns <= w. Column w itself is included, which distinguishes
    mask B from mask A.

    Args:
        in_channels: number of input feature maps; defaults to ``hidden_dim``.
        out_channels: number of output feature maps; defaults to ``hidden_dim``.
        kernel_size: kernel width; use an odd value so the centre is well defined.
        padding: zero-padding forwarded to ``nn.Conv1d``; ``kernel_size // 2``
            keeps the row length unchanged.
    """

    def __init__(self,
                 in_channels=hidden_dim,
                 out_channels=hidden_dim,
                 kernel_size=3,
                 padding=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              padding=padding)

        mask = torch.ones(kernel_size)
        # Zero every tap after the centre. Deriving the cut-off from
        # kernel_size keeps the mask causal for any odd kernel size, and
        # avoids an IndexError for kernel_size < 3.
        mask[kernel_size // 2 + 1:] = 0
        # A buffer follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Mask functionally; re-wrapping self.conv.weight in a new
        # nn.Parameter each call would detach it from the optimizer.
        return nn.functional.conv1d(x,
                                    self.conv.weight * self.mask,
                                    self.conv.bias,
                                    stride=self.conv.stride,
                                    padding=self.conv.padding,
                                    dilation=self.conv.dilation,
                                    groups=self.conv.groups)

In [53]:
test1dLayer = MaskedConv1d()
# 16 -> hidden size (number of feature maps), 28 -> width
# the layer processes one row at a time
testdata = torch.randn(1, 16, 28)

# Causality check for mask B. Columns strictly right of `col` must not affect
# the output at `col`. Column `col` itself must affect it (centre tap kept).
col = 14
right = testdata.clone()
right[..., col + 1:] += 1.0
centre = testdata.clone()
centre[..., col] += 1.0
with torch.no_grad():
    out = test1dLayer(testdata)
    out_right = test1dLayer(right)
    out_centre = test1dLayer(centre)
assert torch.allclose(out[..., col], out_right[..., col])
assert not torch.allclose(out[..., col], out_centre[..., col])

out.shape

torch.Size([1, 16, 28])

In [9]:
class RowLSTMCell(nn.Module):
    """Placeholder for a single Row LSTM step -- not implemented yet.

    TODO: implement the input-to-state (3x1 conv with mask B) and
    state-to-state (3x1 conv, no mask) updates.
    """

    def __init__(self):
        # nn.Module.__init__ must run before the object can be used as a
        # module; without it, calling the instance fails with AttributeError.
        super().__init__()

    def forward(self, x):
        # Fail loudly instead of silently returning None.
        raise NotImplementedError("RowLSTMCell.forward is not implemented yet")

In [10]:
class RowLSTM(nn.Module):
    """Placeholder for the Row LSTM layer -- not implemented yet.

    TODO: build this from RowLSTMCell steps applied row by row.
    """

    def __init__(self):
        # Required for the object to behave as an nn.Module.
        super().__init__()

    def forward(self, x=None):
        # `x` (defaulting to None for backward compatibility) mirrors
        # RowLSTMCell.forward; a real implementation needs an input tensor.
        raise NotImplementedError("RowLSTM.forward is not implemented yet")

In [11]:
# MNIST train/test splits; T.ToTensor() turns each image into a float tensor in [0, 1].
# Both splits pass download=True, so the test split does not rely on the train
# split having fetched the files first. torchvision skips the download when
# the files are already present.
DATA_ROOT = '../datasets/'

train_dataset = torchvision.datasets.MNIST(root=DATA_ROOT,
                                           train=True,
                                           transform=T.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root=DATA_ROOT,
                                          train=False,
                                          transform=T.ToTensor(),
                                          download=True)