In [1]:
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim, cuda, backends
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, transforms, utils

In [2]:
class MaskedConv2d(nn.Conv2d):
    """2-D convolution whose kernel is multiplied by a fixed binary mask.

    The mask keeps every kernel row above the centre row and, on the centre
    row, every column left of the centre. With ``mask_type='B'`` the centre
    tap itself is kept as well; with ``mask_type='A'`` it is zeroed.

    Args:
        mask_type: ``'A'`` or ``'B'``.
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.

    Raises:
        AssertionError: if ``mask_type`` is neither ``'A'`` nor ``'B'``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)
        assert mask_type in {'A', 'B'}
        kernel_h, kernel_w = self.weight.shape[-2:]
        centre_row = kernel_h // 2
        centre_col = kernel_w // 2
        # Type B keeps the centre tap, so zeroing starts one column later.
        first_hidden_col = centre_col + 1 if mask_type == 'B' else centre_col

        # Same shape/dtype as the weight; registered as a buffer so it moves
        # with the module (.cuda(), state_dict) but is not a trainable parameter.
        visibility = self.weight.data.clone().fill_(1)
        visibility[:, :, centre_row, first_hidden_col:] = 0
        visibility[:, :, centre_row + 1:] = 0
        self.register_buffer('mask', visibility)

    def forward(self, x):
        """Zero the masked weights in place, then run the ordinary convolution."""
        # Re-applied on every call because an optimizer step may have written
        # non-zero values back into the masked positions.
        self.weight.data.mul_(self.mask)
        return super(MaskedConv2d, self).forward(x)

In [3]:
# PixelCNN body: one type-'A' masked 7x7 convolution on the single input
# channel, then seven type-'B' masked 7x7 convolutions, each followed by
# batch norm and an in-place ReLU. A final 1x1 convolution produces 256
# output channels per pixel (the training cell uses them as class logits).
fm = 64
layers = []
for mask_type, in_channels in [('A', 1)] + [('B', fm)] * 7:
    layers.append(MaskedConv2d(mask_type, in_channels, fm, 7, 1, 3, bias=False))
    layers.append(nn.BatchNorm2d(fm))
    layers.append(nn.ReLU(True))
layers.append(nn.Conv2d(fm, 256, 1))
net = nn.Sequential(*layers)
net.cuda()

Sequential(
  (0): MaskedConv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace)
  (6): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ReLU(inplace)
  (9): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU(inplace)
  (12): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ReL

In [4]:
# MNIST training split, downloaded to ./data on first use and converted to
# float tensors by ToTensor().
mnist_train = datasets.MNIST('./data', train=True, download=True,
                             transform=transforms.ToTensor())
tr = data.DataLoader(mnist_train, batch_size=128, shuffle=True,
                     num_workers=1, pin_memory=True)
# Uninitialised GPU buffer for 144 single-channel 28x28 images; the sampling
# cell zeroes it before use.
sample = torch.Tensor(144, 1, 28, 28).cuda()
print(len(tr))

469


In [17]:
# Train the network with Adam for 45 epochs. Each pixel is treated as a
# 256-way classification problem: the network's 256 output channels are
# scored with cross-entropy against the pixel's integer intensity. The mean
# batch loss is printed and the whole model is checkpointed after each epoch.
optimizer = optim.Adam(net.parameters(), lr=0.001)
for epoch in range(45):
    err_tr = []
    net.train(True)
    for image, _ in tr:  # labels are unused; the image is its own target
        image = image.cuda()
        # Recover integer intensities from the [0, 1] floats. Round before the
        # integer cast so floating-point error in x / 255 * 255 can never
        # truncate an intensity to the class below it.
        target = (image[:, 0] * 255).round().long()
        out = net(image)
        loss = F.cross_entropy(out, target)
        # .item() extracts the Python float; indexing a 0-dim tensor with [0]
        # is an error on current PyTorch.
        err_tr.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(np.mean(err_tr))
    # Pickles the full module (not just a state_dict); loading it requires the
    # MaskedConv2d class to be importable.
    torch.save(net, "./pixelCnn.pt")

  # Remove the CWD from sys.path while we load stuff.


0.6020797


  "type " + obj.__name__ + ". It won't be checked "


0.60016525
0.5983626
0.59798384
0.59683716
0.5964019
0.5940644
0.5937077
0.592965
0.592362
0.59102774
0.5901135
0.5889411
0.5894492
0.5877539
0.5887813
0.5871245
0.587147
0.5859145
0.58436733
0.5846595
0.58380866
0.5846182
0.5830016
0.58235675
0.5830938
0.5817681
0.5810481
0.5812944
0.5825168
0.5799343
0.5790081
0.5790843
0.5788196
0.579169
0.57815367
0.57738197
0.5771584
0.57826227
0.5765786
0.5758998
0.5766243
0.5762317
0.57582337
0.5750248


In [18]:
# Generate images one pixel at a time in raster order: run the network on the
# partially filled batch, draw pixel (i, j) of every image from the softmax
# over the 256 output channels, write it back, and repeat. The finished batch
# is saved as a single image grid with 12 images per row.
sample.fill_(0)
net.eval()
_, _, height, width = sample.shape
# Sampling never needs gradients; no_grad() stops autograd from recording a
# graph on each of the height * width forward passes.
with torch.no_grad():
    for i in range(height):
        for j in range(width):
            out = net(sample)
            # out[:, :, i, j] is (batch, 256); normalise over the class axis.
            probs = F.softmax(out[:, :, i, j], dim=1)
            # Sampled class index in 0..255 mapped back to the [0, 1] range.
            sample[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.
# save_image does not create missing directories.
os.makedirs('./samples', exist_ok=True)
utils.save_image(sample, './samples/sample.png', nrow=12, padding=0)

  """
  
