#Run pretrained PixelCNNs in jupyter

This code is made for loading our trained PixelCNN models in jyputer notebook and displaying some general results. 



In [2]:
#Load dependencies

import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim, cuda, backends
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, transforms, utils
from pixel_functions import *

In [3]:
#Load datasets (MNIST)

trans = transforms.Compose([transforms.ToTensor(), DynamicBinarization()])

tr_bin = data.DataLoader(datasets.MNIST('../data', train=True, download=True, transform=trans),
                     batch_size=128, shuffle=True, num_workers=0, pin_memory=True)
te_bin = data.DataLoader(datasets.MNIST('../data', train=False, download=True, transform=trans),
                     batch_size=128, shuffle=False, num_workers=0, pin_memory=True)
tr_256 = data.DataLoader(datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()),
                     batch_size=128, shuffle=True, num_workers=0, pin_memory=True)
te_256 = data.DataLoader(datasets.MNIST('../data', train=False, download=True, transform=transforms.ToTensor()),
                     batch_size=128, shuffle=False, num_workers=0, pin_memory=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
#Define models and load pytorch state dicts

net_pixel_bin  = PixelCNN(num_filters=64, color_levels=1)
net_pixel_bin.cuda()
checkpoint = torch.load("Standard_Binary_e25.pth")
net_pixel_bin.load_state_dict(checkpoint)

net_pixel_256 = PixelCNN(num_filters=64, color_levels=256)
net_pixel_256.cuda()
checkpoint = torch.load("Standard_256_e49.pth")
net_pixel_256.load_state_dict(checkpoint)

net_gated_bin = GatedPixelCNN(hidden_fmaps=64, causal_ksize=7, hidden_ksize=3, num_layers=12, 
                    out_hidden_fmaps=256, color_levels=1)
net_gated_bin.cuda()
checkpoint = torch.load('Gated_Binary_e50.pth')
net_gated_bin.load_state_dict(checkpoint)

net_gated_256 = GatedPixelCNN(hidden_fmaps=64, causal_ksize=7, hidden_ksize=3, num_layers=12, 
                    out_hidden_fmaps=256, color_levels=256)
net_gated_256.cuda()
checkpoint = torch.load('Gated_256_e50.pth')
net_gated_256.load_state_dict(checkpoint)

net_space_256 = GatedPixelCNN_space(hidden_fmaps=64, causal_ksize=7, hidden_ksize=3, num_layers=12, 
                    out_hidden_fmaps=256, color_levels=256)
net_space_256.cuda()
checkpoint = torch.load('Spacial_Gated_256_e60.pth')
net_space_256.load_state_dict(checkpoint)

net_space_bin = GatedPixelCNN_space(hidden_fmaps=64, causal_ksize=7, hidden_ksize=3, num_layers=12, 
                    out_hidden_fmaps=256, color_levels=1)
net_space_bin.cuda()
checkpoint = torch.load('Spacial_Gated_Binary_e30.pth')
net_space_bin.load_state_dict(checkpoint)


<All keys matched successfully>

In [None]:
# Generate a (ns x ns) grid of samples from scratch ~5 min

ns = 8
s_pixel_bin = sample_images(net_pixel_bin,num_colors=1,num_samples=ns,label_bool=False)
s_pixel_256 = sample_images(net_pixel_256,num_colors=256,num_samples=ns,label_bool=False)
s_gated_bin = sample_images(net_gated_bin,num_colors=1,num_samples=ns,label_bool=True)
s_gated_256 = sample_images(net_gated_256,num_colors=256,num_samples=ns,label_bool=True)
s_space_bin = sample_images(net_space_bin,num_colors=1,num_samples=ns,label_bool=True)
s_space_256 = sample_images(net_space_256,num_colors=256,num_samples=ns,label_bool=True)

utils.save_image(s_pixel_bin, 'pixel_bin.png', nrow=8, padding=0)
utils.save_image(s_pixel_256, 'pixel_256.png', nrow=8, padding=0)
utils.save_image(s_gated_bin, 'gated_bin.png', nrow=8, padding=0)
utils.save_image(s_gated_256, 'gated_256.png', nrow=8, padding=0)
utils.save_image(s_space_bin, 'space_bin.png', nrow=8, padding=0)
utils.save_image(s_space_256, 'space_256.png', nrow=8, padding=0)

In [None]:
#Get data for half images

input, labels = next(iter(te_256))
label = Variable(labels.cuda().long())[[[torch.where(labels==i)[0][0] for i in range(10)]]]
label = torch.repeat_interleave(label, 10, dim=0)
input_bin = Variable(Binarize(input.cuda()))[[[torch.where(labels==i)[0][0] for i in range(10)]]]
input = Variable(input.cuda())[[[torch.where(labels==i)[0][0] for i in range(10)]]]
input = torch.repeat_interleave(input, 10, dim=0)
input_bin = torch.repeat_interleave(input_bin, 10, dim=0)

In [None]:
#Finish half images for all models (10 times for each model and each digit) ~5 min

utils.save_image(input_bin, 'Half_images_GT_bin.png', nrow=10, padding=0)
utils.save_image(input, 'Half_images_GT.png', nrow=10, padding=0)

sample_pixel_bin = plot_half_boys(net = net_pixel_bin,input=input_bin,num_colors=1,label = None)
utils.save_image(sample_pixel_bin, 'Half_images_pixel_bin_1.png', nrow=10, padding=0)
sample_pixel_256 = plot_half_boys(net = net_pixel_256,input=input,num_colors=256,label = None)
utils.save_image(sample_pixel_256, 'Half_images_pixel_256_1.png', nrow=10, padding=0)

sample_gated_bin = plot_half_boys(net = net_gated_bin,input=input_bin,num_colors=1,label = label)
utils.save_image(sample_gated_bin, 'Half_images_gated_bin_1.png', nrow=10, padding=0)
sample_gated_256 = plot_half_boys(net = net_gated_256,input=input,num_colors=256,label = label)
utils.save_image(sample_gated_256, 'Half_images_gated_256_1.png', nrow=10, padding=0)

sample_space_bin = plot_half_boys(net = net_space_bin,input=input_bin,num_colors=1,label = label)
utils.save_image(sample_space_bin, 'Half_images_space_bin_1.png', nrow=10, padding=0)
sample_space_256 = plot_half_boys(net = net_space_256,input=input,num_colors=256,label = label)
utils.save_image(sample_space_256, 'Half_images_space_256_1.png', nrow=10, padding=0)

In [5]:
#get test losses for the models

loss_pixel_256 = []
loss_gated_256 = []
loss_space_256 = []
loss_pixel_bin = []
loss_gated_bin = []
loss_space_bin = []

criterion_256 = F.cross_entropy
criterion_bin = nn.BCEWithLogitsLoss(reduction='mean')

with torch.no_grad():
        for input, label in te_256:#tr_256 for training set instead
            label = Variable(label.cuda().long())
            target = Variable((input.data[:,0]*255).long().cuda())
            input = Variable((input*2-1).cuda())
            loss_pixel_256.append(criterion_256(net_pixel_256(input), target))
            loss_gated_256.append(criterion_256(net_gated_256(input, label), target))
            loss_space_256.append(criterion_256(net_space_256(input, label), target))

        for input, label in te_bin:#tr_bin for training set instead
            label = Variable(label.cuda().long())
            input = Variable(input.cuda())
            target = Variable(input.data[:,0].unsqueeze(1))/2+1/2*torch.ones_like(input)
            loss_pixel_bin.append(criterion_bin(net_pixel_bin(input), target))
            loss_gated_bin.append(criterion_bin(net_gated_bin(input, label), target))
            loss_space_bin.append(criterion_bin(net_space_bin(input, label), target))

print("Mean loss for standard 256: ",torch.tensor(loss_pixel_256).mean().cpu().numpy())
print("Mean loss for gated 256: ",torch.tensor(loss_gated_256).mean().cpu().numpy())
print("Mean loss for spacial 256: ",torch.tensor(loss_space_256).mean().cpu().numpy())
print("Mean loss for standard bin: ",torch.tensor(loss_pixel_bin).mean().cpu().numpy())
print("Mean loss for gated bin: ",torch.tensor(loss_gated_bin).mean().cpu().numpy())
print("Mean loss for spacial bin: ",torch.tensor(loss_space_bin).mean().cpu().numpy())

Mean loss for standard 256:  0.7208883
Mean loss for gated 256:  0.5989351
Mean loss for spacial 256:  0.6005826
Mean loss for standard bin:  0.10873116
Mean loss for gated bin:  0.0998669
Mean loss for spacial bin:  0.09946648
