In [1]:
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision
from torchvision.utils import save_image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import math
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class MaskedConv2d(nn.Conv2d):
    def __init__(self,mask_type,c_in,c_out,k_size,stride,pad):
        super(MaskedConv2d,self).__init__(c_in, c_out, k_size, stride, pad, bias=False)
        self.mask_type = mask_type
        ch_out, ch_in, height, width = self.weight.size()
        mask  = torch.ones(ch_out, ch_in, height, width)
        if self.mask_type == 'A':
            mask[:, :, height // 2, width // 2:] = 0
            mask[:, :, height // 2 + 1:] = 0
        else:
            mask[:, :, height // 2, width // 2 + 1:] = 0
            mask[:, :, height // 2] = 0
        self.register_buffer('mask', mask)
    
    def forward(self,x):
        self.weight.data *= self.mask
        return super(MaskedConv2d, self).forward(x)

def MaskAConv2d(c_in=3, c_out=256, k_size=7, stride=1, pad=3):
    return nn.Sequential(MaskedConv2d('A', c_in, c_out, k_size, stride, pad),nn.BatchNorm2d(c_out))

class MaskBConvBlock(nn.Module):
    def __init__(self, h=128, k_size=3, stride=1, pad=1):
        super(MaskBConvBlock, self).__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(2 * h, h, 1),  # 1x1
            nn.BatchNorm2d(h),
            nn.ReLU(),
            MaskedConv2d('B', h, h, k_size, stride, pad),
            nn.BatchNorm2d(h),
            nn.ReLU(),
            nn.Conv2d(h, 2 * h, 1),  # 1x1
            nn.BatchNorm2d(2 * h)
        )

    def forward(self, x):
        return self.net(x) + x     

In [3]:
class PixelCNN(nn.Module):
    def __init__(self,n_channel=3, h=128, discrete_channel=256):
        super(PixelCNN,self).__init__()
        self.discrete_channel = discrete_channel
        
        self.MaskAConv = MaskAConv2d(n_channel, 2 * h, k_size=7, stride=1, pad=3)
        MaskBConv = []
        for i in range(15):
            MaskBConv.append(MaskBConvBlock(h, k_size=3, stride=1, pad=1))
        self.MaskBConv = nn.Sequential(*MaskBConv)

        self.out = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(2 * h, 1024, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, n_channel * discrete_channel, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        batch_size, c_in, height, width = x.size()
        x = self.MaskAConv(x)
        x = self.MaskBConv(x)
        x = self.out(x)
        x = x.view(batch_size, c_in, self.discrete_channel, height, width)
        x = x.permute(0, 1, 3, 4, 2)
        return x
        

In [4]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
dataset = torchvision.datasets.CIFAR10(root='./cifar10',train=False,download=True,transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=20,shuffle=False, num_workers=0)

Files already downloaded and verified


In [5]:
net = PixelCNN()
net = net.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(),lr=0.001)
n_epochs = 50

In [None]:
for epoch in range(n_epochs):
    batch_loss_history = []
    for batch , (image,label) in enumerate(dataloader):
        image = Variable(image).cuda()
        logit = net(image)
        logit = logit.contiguous()
        logit = logit.view(-1, 256)
        target = Variable(image.data.view(-1)*127 + 128).long()
        loss = criterion(logit,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss = float(loss.data)
        batch_loss_history.append(loss)
    epoch_loss = np.mean(batch_loss_history)
    print('Epoch {0} Loss: {1}'.format(epoch,epoch_loss))
net.save_state_dict('pixelcnn.pt')

In [26]:
dataiter = iter(dataloader)
images, labels = dataiter.next()
images[:,:,16:,:] =1
save_image(images, "./samples/input.png" , nrow=5)
image_path = "./samples/output.png"
sample = images
for i in range(32):
    for j in range(32):
        out = net(Variable(sample, volatile=True).cuda())
        probs = F.softmax(out[:, :, i, j], dim=2).data
        for k in range(3):
            pixel = torch.multinomial(probs[:, k], 1).float() / 255.
            sample[:, k, i, j] = pixel.squeeze(1)
#save_image(sample, image_path , nrow=5)

  if __name__ == '__main__':


RuntimeError: CUDA error: out of memory