Libraries

In [58]:

## Standard libraries
import os
import math
import numpy as np
import time

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

## Progress bar
from tqdm.notebook import tqdm


Constants

In [59]:
# Training tag
Train_tag = "first_attempt"
# Whether or not start training from previous checkpoints
Resume_from_checkpoint = False
# Path to the folder where the datasets are stored
DATASET_PATH = "../data"
# Path to the folder where the checkpoints are saved
CHECKPOINT_PATH = "../checkpoints"
# Path to the folder where the training logs are saved
LOG_PATH = "../tensorboard_log"
# Fetching the device
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
print("Using device", device)
# Learning rate
Learning_rate = 1e-3
# Max epoch
Max_epoch = 41

Using device cpu


Data

In [60]:
# Convert images from 0-1 to 0-255 (integers). We use the long datatype as we will use the images as labels as well
def discretize(sample):
    return (sample * 255).to(torch.long)

# Transformations applied on each image => only make them a tensor
transform = transforms.Compose([transforms.ToTensor(),
                                discretize])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])

# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=0)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=0)

In [61]:
first = True
for data in train_loader:
    if(first):
        imgs, tag = data
        print(imgs.shape)
        print(tag.shape)
        first = False

torch.Size([128, 1, 28, 28])
torch.Size([128])


Network Definition

In [62]:
class MaskedConvolution(nn.Module):

    def __init__(self, c_in, c_out, mask, **kwargs):
        """
        Implements a convolution with mask applied on its weights.
        Inputs:
            c_in - Number of input channels
            c_out - Number of output channels
            mask - Tensor of shape [kernel_size_H, kernel_size_W] with 0s where
                   the convolution should be masked, and 1s otherwise.
            kwargs - Additional arguments for the convolution
        """
        super().__init__()
        # For simplicity: calculate padding automatically
        kernel_size = (mask.shape[0], mask.shape[1])
        dilation = 1 if "dilation" not in kwargs else kwargs["dilation"]
        padding = tuple([dilation*(kernel_size[i]-1)//2 for i in range(2)])
        # Actual convolution
        self.conv = nn.Conv2d(c_in, c_out, kernel_size, padding=padding, **kwargs)

        # Mask as buffer => it is no parameter but still a tensor of the module
        # (must be moved with the devices)
        self.register_buffer('mask', mask[None,None])

    def forward(self, x):
        self.conv.weight.data *= self.mask # Ensures zero's at masked positions
        return self.conv(x)

In [63]:
class VerticalStackConvolution(MaskedConvolution):

    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):
        # Mask out all pixels below. For efficiency, we could also reduce the kernel
        # size in height, but for simplicity, we stick with masking here.
        mask = torch.ones(kernel_size, kernel_size)
        mask[kernel_size//2+1:,:] = 0

        # For the very first convolution, we will also mask the center row
        if mask_center:
            mask[kernel_size//2,:] = 0

        super().__init__(c_in, c_out, mask, **kwargs)

class HorizontalStackConvolution(MaskedConvolution):

    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):
        # Mask out all pixels on the left. Note that our kernel has a size of 1
        # in height because we only look at the pixel in the same row.
        mask = torch.ones(1,kernel_size)
        mask[0,kernel_size//2+1:] = 0

        # For the very first convolution, we will also mask the center pixel
        if mask_center:
            mask[0,kernel_size//2] = 0

        super().__init__(c_in, c_out, mask, **kwargs)

In [64]:
class GatedMaskedConv(nn.Module):

    def __init__(self, c_in, **kwargs):
        """
        Gated Convolution block implemented the computation graph shown above.
        """
        super().__init__()
        self.conv_vert = VerticalStackConvolution(c_in, c_out=2*c_in, **kwargs)
        self.conv_horiz = HorizontalStackConvolution(c_in, c_out=2*c_in, **kwargs)
        self.conv_vert_to_horiz = nn.Conv2d(2*c_in, 2*c_in, kernel_size=1, padding=0)
        self.conv_horiz_1x1 = nn.Conv2d(c_in, c_in, kernel_size=1, padding=0)

    def forward(self, v_stack, h_stack):
        # Vertical stack (left)
        v_stack_feat = self.conv_vert(v_stack)
        v_val, v_gate = v_stack_feat.chunk(2, dim=1)
        v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate)

        # Horizontal stack (right)
        h_stack_feat = self.conv_horiz(h_stack)
        h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat)
        h_val, h_gate = h_stack_feat.chunk(2, dim=1)
        h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate)
        h_stack_out = self.conv_horiz_1x1(h_stack_feat)
        h_stack_out = h_stack_out + h_stack

        return v_stack_out, h_stack_out

In [65]:
class PixelCNN(nn.Module):

    def __init__(self, c_in, c_hidden):
        super().__init__()
        # self.save_hyperparameters()

        # Initial convolutions skipping the center pixel
        self.conv_vstack = VerticalStackConvolution(c_in, c_hidden, mask_center=True)
        self.conv_hstack = HorizontalStackConvolution(c_in, c_hidden, mask_center=True)
        # Convolution block of PixelCNN. We use dilation instead of downscaling
        self.conv_layers = nn.ModuleList([
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=2),
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=4),
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=2),
            GatedMaskedConv(c_hidden)
        ])
        # Output classification convolution (1x1)
        self.conv_out = nn.Conv2d(c_hidden, c_in * 256, kernel_size=1, padding=0)

        # self.example_input_array = train_set[0][0][None]

    def forward(self, x):
        """
        Forward image through model and return logits for each pixel.
        Inputs:
            x - Image tensor with integer values between 0 and 255.
        """
        # Scale input from 0 to 255 back to -1 to 1
        x = (x.float() / 255.0) * 2 - 1

        # Initial convolutions
        v_stack = self.conv_vstack(x)
        h_stack = self.conv_hstack(x)
        # Gated Convolutions
        for layer in self.conv_layers:
            v_stack, h_stack = layer(v_stack, h_stack)
        # 1x1 classification convolution
        # Apply ELU before 1x1 convolution for non-linearity on residual connection
        out = self.conv_out(F.elu(h_stack))

        # Output dimensions: [Batch, Classes, Channels, Height, Width]
        out = out.reshape(out.shape[0], 256, out.shape[1]//256, out.shape[2], out.shape[3])
        return out

    def calc_likelihood(self, x):
        # Forward pass with bpd likelihood calculation
        pred = self.forward(x)
        nll = F.cross_entropy(pred, x, reduction='none')
        bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1))
        return bpd.mean()

    @torch.no_grad()
    def sample(self, img_shape, img=None):
        """
        Sampling function for the autoregressive model.
        Inputs:
            img_shape - Shape of the image to generate (B,C,H,W)
            img (optional) - If given, this tensor will be used as
                             a starting image. The pixels to fill
                             should be -1 in the input tensor.
        """
        # Create empty image
        if img is None:
            img = torch.zeros(img_shape, dtype=torch.long).to(device) - 1
        # Generation loop
        for h in tqdm(range(img_shape[2]), leave=False):
            for w in range(img_shape[3]):
                for c in range(img_shape[1]):
                    # Skip if not to be filled (-1)
                    if (img[:,c,h,w] != -1).all().item():
                        continue
                    # For efficiency, we only have to input the upper part of the image
                    # as all other parts will be skipped by the masked convolutions anyways
                    pred = self.forward(img[:,:,:h+1,:])
                    probs = F.softmax(pred[:,:,c,h,w], dim=-1)
                    img[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
        return img

    # def configure_optimizers(self):
    #     optimizer = optim.Adam(self.parameters(), lr=1e-3)
    #     scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99)
    #     return [optimizer], [scheduler]

    # def training_step(self, batch, batch_idx):
    #     loss = self.calc_likelihood(batch[0])
    #     self.log('train_bpd', loss)
    #     return loss

    # def validation_step(self, batch, batch_idx):
    #     loss = self.calc_likelihood(batch[0])
    #     self.log('val_bpd', loss)

    # def test_step(self, batch, batch_idx):
    #     loss = self.calc_likelihood(batch[0])
    #     self.log('test_bpd', loss)

Training

In [66]:
PixelCNN_model = PixelCNN(c_in=1, c_hidden=64)
PixelCNN_model = PixelCNN_model.to(device)
optimizer = optim.Adam(PixelCNN_model.parameters(), lr=Learning_rate)
start_epoch = -1
loss_fn = 


if not os.path.isdir(os.path.join(LOG_PATH, Train_tag)):
        os.mkdir(os.path.join(LOG_PATH, Train_tag))
writer = SummaryWriter(os.path.join(LOG_PATH, Train_tag))
if not os.path.isdir(os.path.join(CHECKPOINT_PATH, Train_tag)):
        os.mkdir(os.path.join(CHECKPOINT_PATH, Train_tag))

# if resume from previous checkpoints
if Resume_from_checkpoint:
    path_checkpoint = CHECKPOINT_PATH 
    checkpoint = torch.load(path_checkpoint)  
    PixelCNN_model.load_state_dict(checkpoint['net']) 
    # optimizer.load_state_dict(checkpoint['optimizer']) 
    start_epoch = checkpoint['epoch']

epoch = start_epoch
train_iter = 0
validation_iter = 0
for epoch_iter in range(Max_epoch):
    epoch = epoch + 1
    
    # train
    for imgs, _ in tqdm(train_loader, desc="epoch {} training ".format(epoch), leave=False):
        outputs = PixelCNN_model(imgs)
        loss = (F.cross_entropy(outputs, imgs, reduction='none')).mean()
        writer.add_scalar("training_loss", loss.item(), train_iter)
        train_iter += 1 
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # validation
    with torch.no_grad():
        for imgs, _ in tqdm(val_loader, desc="epoch {} validating ".format(epoch), leave=False):
            outputs = PixelCNN_model(imgs)
            loss = (F.cross_entropy(outputs, imgs, reduction='none')).mean()
            writer.add_scalar("validation_loss", loss.item(), validation_iter) 
            validation_iter += 1
    
    # save checkpoint every 5 epochs
    if(epoch%5==0):
        checkpoint = {
            "net": PixelCNN_model.state_dict(),
            "optimizer":optimizer.state_dict(),
            "epoch": epoch 
        }
    torch.save(checkpoint, os.path.join(os.path.join(CHECKPOINT_PATH, Train_tag), "epoch_{}.ckpt".format(epoch)))
        
    

epoch 0 training :   0%|          | 0/390 [00:00<?, ?it/s]

epoch 0 validating :   0%|          | 0/79 [00:00<?, ?it/s]

NameError: name 'eepoch' is not defined