# Gated PixelCNN

This part is heavily inspired by the [uvadlc-notebooks](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial12/Autoregressive_Image_Modeling.html) tutorial on autoregressive image modeling.

In [1]:
# Report the GPU attached to this runtime (IPython shell escape, not plain Python).
!nvidia-smi

Sun Oct 23 18:24:01 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import time
import pathlib
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision
import torchvision.transforms as T
from torchvision.datasets import MNIST

In [3]:
## hyperparameters
NUM_EPOCHS = 50  # number of passes over the training split
BATCH_SIZE = 128  # batch size shared by the train and test dataloaders
NUM_CHANNELS = 64  # intended hidden channel width of the PixelCNN stacks
# NOTE(review): NUM_LAYERS is not read by PixelCNN, which builds a fixed list of 7 gated blocks.
NUM_LAYERS = 7
LEARNING_RATE = 1e-3  # intended Adam learning rate
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # GPU when available, else CPU

In [4]:
def discretize(sample):
    """Map a float image tensor in [0, 1] back to integer pixel levels 0..255.

    The float input is presumably the output of ``T.ToTensor`` (uint8 / 255).
    Multiplying back by 255 can land a hair below the original integer in
    floating point, and a plain ``long`` cast would truncate that to the level
    below. Rounding first makes the uint8 -> float -> long round trip exact.

    Args:
        sample: float tensor with values in [0, 1].

    Returns:
        ``torch.long`` tensor of the same shape with values in 0..255.
    """
    return (sample * 255).round().to(torch.long)

# Pipeline: PIL image -> float tensor (ToTensor) -> integer pixel levels (discretize).
transform = T.Compose([
    T.ToTensor(),
    discretize,
])

# Both splits share one root directory and the same transform. Only the
# training split may download; the test split expects its files to already
# be present under the same root.
train_dataset = MNIST(
    root='../datasets/',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = MNIST(
    root='../datasets/',
    train=False,
    download=False,
    transform=transform,
)

In [5]:
# Training batches are shuffled. drop_last keeps every training batch at
# exactly BATCH_SIZE images.
train_dataloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    drop_last=True,
    pin_memory=True,
)
# Evaluation gains nothing from shuffling. A fixed order gives the same batch
# composition every epoch, so the reported test BPD is comparable across
# epochs. pin_memory matches the training loader.
test_dataloader = data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [6]:
class MaskedConvolution(nn.Module):
    """2D convolution whose kernel is multiplied by a fixed mask on every forward.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        mask: tensor of shape (kernel_h, kernel_w). Positions set to 0 do not
            contribute to the output. Its shape also sets the kernel size.
        **kwargs: extra ``nn.Conv2d`` arguments (e.g. ``dilation``). ``padding``
            is computed here as ``dilation * (k - 1) // 2`` per axis, so with
            stride 1 and odd kernel sizes the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, mask, **kwargs):
        super().__init__()
        kernel_size = tuple(mask.shape)
        dilation = kwargs.get("dilation", 1)
        padding = tuple(dilation * (size - 1) // 2 for size in kernel_size)
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              padding=padding,
                              **kwargs)
        # Buffer rather than parameter: it follows .to(device) and is stored in
        # the state_dict, but is never trained.
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Apply the mask functionally instead of overwriting conv.weight in
        # place. This means:
        # - the parameter is never mutated inside forward (safe under
        #   inference_mode and autocast);
        # - the masked taps receive exactly zero gradient, so the optimizer
        #   never moves them.
        masked_weight = self.conv.weight * self.mask
        # _conv_forward is the same code path nn.Conv2d.forward uses. It honours
        # the module's stride/padding/dilation/groups and padding_mode.
        return self.conv._conv_forward(x, masked_weight, self.conv.bias)

In [7]:
# Smoke test: a MaskedConvolution with a random (non-binary) 3x3 mask and
# dilation 2 maps a 1x28x28 input to 32 channels. With the padding computed in
# MaskedConvolution, the spatial size stays 28x28.
mask = torch.rand(3, 3)
img = torch.randn(1, 1, 28, 28)
mask_conv = MaskedConvolution(1, 32, mask, dilation=2)
test = mask_conv(img)
# test = test.sum()
# test.backward()

In [8]:
# class MaskedConvolution(nn.Module):

#     def __init__(self, c_in, c_out, mask, **kwargs):
#         """
#         Implements a convolution with mask applied on its weights.
#         Inputs:
#             c_in - Number of input channels
#             c_out - Number of output channels
#             mask - Tensor of shape [kernel_size_H, kernel_size_W] with 0s where
#                    the convolution should be masked, and 1s otherwise.
#             kwargs - Additional arguments for the convolution
#         """
#         super().__init__()
#         # For simplicity: calculate padding automatically
#         kernel_size = (mask.shape[0], mask.shape[1])
#         dilation = 1 if "dilation" not in kwargs else kwargs["dilation"]
#         padding = tuple([dilation*(kernel_size[i]-1)//2 for i in range(2)])
#         # Actual convolution
#         self.conv = nn.Conv2d(c_in, c_out, kernel_size, padding=padding, **kwargs)

#         # Mask as buffer => it is no parameter but still a tensor of the module
#         # (must be moved with the devices)
#         self.register_buffer('mask', mask[None,None])

#     def forward(self, x):
#         self.conv.weight.data *= self.mask # Ensures zero's at masked positions
#         return self.conv(x)


In [9]:
# class VerticalStackConvolution(MaskedConvolution):
#     def __init__(self, in_channels, out_channels, kernel_size=3, mask_type='B', dilation=1):
#         assert mask_type in ['A', 'B']
#         mask = torch.ones(kernel_size, kernel_size)
#         mask[kernel_size//2+1:,:] = 0
#         if mask_type=='A':
#             mask[kernel_size//2,:] = 0
        
#         super().__init__(in_channels, out_channels, mask, dilation=dilation)
        
# class HorizontalStackConvolution(MaskedConvolution):
#     def __init__(self, in_channels, out_channels, kernel_size=3, mask_type='B', dilation=1):
#         assert mask_type in ['A', 'B']
#         mask = torch.ones(1, kernel_size)
#         mask[0, kernel_size//2+1:] = 0
#         if mask_type=='A':
#             mask[0, kernel_size//2] = 0
#         super().__init__(in_channels, out_channels, mask, dilation=dilation)

In [10]:
class VerticalStackConvolution(MaskedConvolution):
    """Square masked convolution for the vertical stack of a gated PixelCNN.

    The kernel sees the rows above its centre and, unless ``mask_center`` is
    set, the centre row itself. Every row below the centre is hidden. The
    reduced-height kernel would be cheaper, but masking keeps things simple.

    Args:
        c_in: input channels.
        c_out: output channels.
        kernel_size: side length of the square kernel (default 3).
        mask_center: when true, the centre row is hidden as well (first layer).
        **kwargs: forwarded to MaskedConvolution (e.g. ``dilation``).
    """

    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):
        centre = kernel_size // 2
        # Number of leading rows that stay visible: everything above the centre,
        # plus the centre row unless it is masked.
        visible_rows = centre if mask_center else centre + 1
        mask = torch.zeros(kernel_size, kernel_size)
        mask[:visible_rows, :] = 1
        super().__init__(c_in, c_out, mask, **kwargs)

class HorizontalStackConvolution(MaskedConvolution):
    """Single-row masked convolution for the horizontal stack of a gated PixelCNN.

    The 1 x kernel_size kernel only looks along the current row. It sees the
    pixels to the left of the centre and, unless ``mask_center`` is set, the
    centre pixel. Pixels to the right are hidden.

    Args:
        c_in: input channels.
        c_out: output channels.
        kernel_size: width of the kernel (default 3).
        mask_center: when true, the centre pixel is hidden as well (first layer).
        **kwargs: forwarded to MaskedConvolution (e.g. ``dilation``).
    """

    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):
        centre = kernel_size // 2
        # Number of leading columns that stay visible: everything left of the
        # centre, plus the centre column unless it is masked.
        visible_cols = centre if mask_center else centre + 1
        mask = torch.zeros(1, kernel_size)
        mask[0, :visible_cols] = 1
        super().__init__(c_in, c_out, mask, **kwargs)


In [11]:
# Print the masks built by both stack convolutions, without and with the centre
# masked (the variant used by the first PixelCNN layer). mask_center is passed
# by keyword; the earlier positional 'A' only worked because a non-empty string
# is truthy.
test = VerticalStackConvolution(1, 32, 3)
print(test.mask)
test = VerticalStackConvolution(1, 32, 3, mask_center=True)
print(test.mask)
test = HorizontalStackConvolution(1, 32, 3)
print(test.mask)
test = HorizontalStackConvolution(1, 32, 3, mask_center=True)
print(test.mask)
del test

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.]])
tensor([[1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.]])
tensor([[1., 1., 0.]])
tensor([[1., 0., 0.]])


In [12]:
class GatedMaskedConv(nn.Module):
    """One gated block acting on a (vertical stack, horizontal stack) pair.

    Each stack is convolved to 2 * c_in channels. The result is split into a
    value half and a gate half and combined as tanh(value) * sigmoid(gate).
    The pre-activation vertical features are also projected by a 1x1
    convolution and added to the horizontal pre-activation. The gated
    horizontal output gets a 1x1 projection plus a residual connection to the
    block's horizontal input.

    Args:
        c_in: channel count of both stacks (unchanged by the block).
        **kwargs: forwarded to both masked convolutions (e.g. ``dilation``).
    """

    def __init__(self, c_in, **kwargs):
        super().__init__()
        doubled = 2 * c_in
        self.conv_vert = VerticalStackConvolution(c_in, c_out=doubled, **kwargs)
        self.conv_horiz = HorizontalStackConvolution(c_in, c_out=doubled, **kwargs)
        self.conv_vert_to_horiz = nn.Conv2d(doubled, doubled, kernel_size=1, padding=0)
        self.conv_horiz_1x1 = nn.Conv2d(c_in, c_in, kernel_size=1, padding=0)

    @staticmethod
    def _gate(features):
        # Split along channels into value/gate halves and apply the gated tanh unit.
        value, gate = features.chunk(2, dim=1)
        return torch.tanh(value) * torch.sigmoid(gate)

    def forward(self, v_stack, h_stack):
        v_pre = self.conv_vert(v_stack)
        h_pre = self.conv_horiz(h_stack) + self.conv_vert_to_horiz(v_pre)
        v_out = self._gate(v_pre)
        h_out = self.conv_horiz_1x1(self._gate(h_pre)) + h_stack
        return v_out, h_out


In [13]:
# # test gated convolution
# gated = GatedConvolution(NUM_CHANNELS, dilation=2)
# v = torch.randn(1, NUM_CHANNELS, 28, 28)
# h = torch.randn(1, NUM_CHANNELS, 28, 28)
# v, h = gated(v, h)
# print(v.shape, h.shape)
# del gated, v, h

In [14]:
# class GatedPixelCNN(nn.Module):
#     def __init__(self, hidden_channels, num_gated=7):
#         super().__init__()
#         self.v = VerticalStackConvolution(in_channels=1,
#                                           out_channels=hidden_channels,
#                                           kernel_size=7,
#                                           mask_type='A')
#         self.h = HorizontalStackConvolution(in_channels=1,
#                                           kernel_size=7,
#                                           out_channels=hidden_channels,
#                                           mask_type='A')
        
#         self.gated_convolutions = nn.ModuleList([GatedConvolution(hidden_channels, kernel_size=7) 
#                                                  for _ in range(num_gated)])
#         # we apply a 256 way softmax
#         self.output = nn.Conv2d(in_channels=hidden_channels, 
#                                 out_channels=256,
#                                 kernel_size=1)
#     def forward(self, x):
#         x = (x.float() / 255.0) * 2 - 1

#         v = self.v(x)
#         h = self.h(x)
        
#         for gated_layer in self.gated_convolutions:
#             v, h = gated_layer(v, h)
#         out = self.output(F.elu(h))
#         # from Batch, Classes, Height, Width to Batch, Classes, Channel, Height, Width
#         out = out.unsqueeze(dim=2)
#         return out

In [15]:
class PixelCNN(nn.Module):
    """Gated PixelCNN predicting a 256-way categorical for every pixel/channel.

    Args:
        c_in: number of image channels.
        c_hidden: feature channels carried by the vertical and horizontal stacks.
    """

    def __init__(self, c_in, c_hidden):
        super().__init__()

        # Initial convolutions skipping the center pixel, so no output position
        # can see its own input value.
        self.conv_vstack = VerticalStackConvolution(c_in, c_hidden, mask_center=True)
        self.conv_hstack = HorizontalStackConvolution(c_in, c_hidden, mask_center=True)
        # Convolution block of PixelCNN. We use dilation instead of downscaling
        # to grow the receptive field.
        self.conv_layers = nn.ModuleList([
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=2),
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=4),
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=2),
            GatedMaskedConv(c_hidden)
        ])
        # Output classification convolution (1x1): 256 logits per input channel.
        self.conv_out = nn.Conv2d(c_hidden, c_in * 256, kernel_size=1, padding=0)

    def forward(self, x):
        """
        Forward image through model and return logits for each pixel.
        Inputs:
            x - Image tensor with integer values between 0 and 255.
        Returns:
            Logits of shape [Batch, 256, Channels, Height, Width].
        """
        # Scale input from 0 to 255 back to -1 to 1
        x = (x.float() / 255.0) * 2 - 1

        # Initial convolutions
        v_stack = self.conv_vstack(x)
        h_stack = self.conv_hstack(x)
        # Gated Convolutions
        for layer in self.conv_layers:
            v_stack, h_stack = layer(v_stack, h_stack)
        # 1x1 classification convolution
        # Apply ELU before 1x1 convolution for non-linearity on residual connection
        out = self.conv_out(F.elu(h_stack))

        # Output dimensions: [Batch, Classes, Channels, Height, Width]
        out = out.reshape(out.shape[0], 256, out.shape[1]//256, out.shape[2], out.shape[3])
        return out

    def calc_likelihood(self, x):
        """Return the batch-mean negative log-likelihood in bits per dimension.

        Inputs:
            x - Integer image tensor [B, C, H, W] with values 0..255. It is used
                both as the model input and as the cross-entropy target.
        """
        pred = self.forward(x)
        nll = F.cross_entropy(pred, x, reduction='none')
        # Average over C, H, W per image, then convert nats to bits (x log2(e)).
        bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1))
        return bpd.mean()

    @torch.no_grad()
    def sample(self, img_shape, img=None):
        """
        Sampling function for the autoregressive model.
        Inputs:
            img_shape - Shape of the image to generate (B,C,H,W)
            img (optional) - If given, this tensor will be used as
                             a starting image. The pixels to fill
                             should be -1 in the input tensor.
        Returns:
            The filled image tensor (torch.long when created here).
        """
        # Create empty image on the same device as the model's weights. The
        # original code referenced an undefined global `device` here.
        if img is None:
            device = self.conv_out.weight.device
            img = torch.zeros(img_shape, dtype=torch.long, device=device) - 1
        # Generation loop
        for h in tqdm(range(img_shape[2]), leave=False):
            for w in range(img_shape[3]):
                for c in range(img_shape[1]):
                    # Skip if not to be filled (-1)
                    if (img[:,c,h,w] != -1).all().item():
                        continue
                    # For efficiency, we only have to input the upper part of the image
                    # as all other parts will be skipped by the masked convolutions anyways
                    pred = self.forward(img[:,:,:h+1,:])
                    probs = F.softmax(pred[:,:,c,h,w], dim=-1)
                    img[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
        return img

In [16]:
# # test pixelcnn
# gated_pixel_cnn = GatedPixelCNN(NUM_CHANNELS)
# #print(gated_pixel_cnn)
# img = torch.randn(1, 1, 28, 28)
# labels = (torch.rand(1, 1, 28, 28) * 256).long()
# output = gated_pixel_cnn(img)
# print(output.shape, labels.shape)
# test = F.cross_entropy(output, labels, reduction='mean')
# # print(test)
# del gated_pixel_cnn, img, labels, output, test

In [17]:
# Build the model and optimizer from the hyperparameter cell instead of
# repeating their values (64 hidden channels, lr 1e-3) as literals.
model = PixelCNN(c_in=1, c_hidden=NUM_CHANNELS).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Divide the LR by 10 when the metric passed to scheduler.step() has not
# improved for 5 consecutive steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)

In [18]:
# Mount Google Drive (Colab only) and choose where checkpoints are written.
from google.colab import drive
drive.mount('/content/gdrive')
drive_path = pathlib.Path('/content/gdrive')
weights_path = drive_path / 'MyDrive/weights'
# Make sure the checkpoint directory exists before the training loop saves into it.
weights_path.mkdir(parents=True, exist_ok=True)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [19]:
# training loop
# Each epoch:
# - train on the bits-per-dimension objective (fp16 autocast + gradient
#   scaling when running on CUDA);
# - evaluate on the test split;
# - checkpoint whenever the test BPD improves;
# - step the plateau scheduler on the test BPD.
use_amp = DEVICE.type == "cuda"  # mixed precision only where it is supported
best_loss = float("inf")
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    # Per-batch mean BPD plus batch sizes. Epoch averages are weighted per
    # image, so a smaller final test batch does not skew them.
    train_losses, train_sizes = [], []
    test_losses, test_sizes = [], []

    model.train()
    for features, _ in train_dataloader:
        optimizer.zero_grad()

        features = features.to(DEVICE)

        with torch.amp.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=use_amp):
            loss = model.calc_likelihood(features)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_losses.append(loss.item())
        train_sizes.append(features.size(0))

    # evaluate on the test dataset
    model.eval()
    with torch.inference_mode():
        for features, _ in test_dataloader:
            features = features.to(DEVICE)
            with torch.amp.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=use_amp):
                loss = model.calc_likelihood(features)
            test_losses.append(loss.item())
            test_sizes.append(features.size(0))

    end_time = time.time()
    bpd_train = float(np.average(train_losses, weights=train_sizes))
    bpd_test = float(np.average(test_losses, weights=test_sizes))
    print(f'Epoch: {epoch+1}/{NUM_EPOCHS}, BPD Train: {bpd_train:.4f}, BPD Test: {bpd_test:.4f}, Elapsed Time: {end_time-start_time:.2f}sec')

    if bpd_test < best_loss:
        print("Saving Weights")
        best_loss = bpd_test
        torch.save(model.state_dict(), f=weights_path / 'pixel_cnn.pt')
    scheduler.step(bpd_test)

Epoch: 1/50, BPD Train: 1.3411, BPD Test: 1.1581, Elapsed Time: 68.71sec
Saving Weights
Epoch: 2/50, BPD Train: 1.1454, BPD Test: 1.1353, Elapsed Time: 66.82sec
Saving Weights
Epoch: 3/50, BPD Train: 1.1207, BPD Test: 1.1120, Elapsed Time: 66.42sec
Saving Weights
Epoch: 4/50, BPD Train: 1.1041, BPD Test: 1.0898, Elapsed Time: 66.45sec
Saving Weights
Epoch: 5/50, BPD Train: 1.0919, BPD Test: 1.0829, Elapsed Time: 66.57sec
Saving Weights
Epoch: 6/50, BPD Train: 1.0730, BPD Test: 1.0610, Elapsed Time: 67.33sec
Saving Weights
Epoch: 7/50, BPD Train: 1.0664, BPD Test: 1.0527, Elapsed Time: 66.76sec
Saving Weights


KeyboardInterrupt: ignored

In [None]:
# sample an image
@torch.no_grad()
def sample(model, img_shape):
    """Generate images pixel by pixel from an autoregressive model.

    Args:
        model: module mapping an integer image (B, C, h, W) to logits of shape
            (B, 256, C, h, W), the layout returned by PixelCNN.forward. It must
            have at least one parameter; its device is used for generation.
        img_shape: (B, C, H, W) of the images to generate.

    Returns:
        A ``torch.long`` tensor of shape ``img_shape`` with values in 0..255, on
        the model's device. This matches the dtype the model is trained on.
    """
    device = next(model.parameters()).device
    # -1 marks pixels that have not been sampled yet.
    img = torch.full(img_shape, -1, dtype=torch.long, device=device)
    # Generation loop
    _, channels, height, width = img_shape
    for h in range(height):
        for w in range(width):
            for c in range(channels):
                # Assumes a causally masked model: rows below h cannot influence
                # (h, w), so only rows 0..h are fed in.
                logits = model(img[:, :, :h + 1, :])
                probs = F.softmax(logits[:, :, c, h, w], dim=-1)
                img[:, c, h, w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
    return img

In [None]:
# Draw one 1x28x28 sample from the trained model and display it in grayscale.
img = sample(model, img_shape=(1, 1, 28, 28))
print(img.shape)
plt.imshow(img.cpu().long().squeeze(), cmap="gray")
plt.show()

## Conditional PixelCNN