## Heavily adapted from 
[here](https://github.com/pytorch/examples/blob/master/vae/main.py)

In [4]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
batch_size=128
epochs = 1000
log_interval = 100

In [7]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.ToTensor()
                  ),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

In [14]:
im_dim = 28*28

In [81]:
??nn.MaxUnpool2d()

In [118]:
latent_space_size = 20

In [141]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        #self.conv1 = nn.Conv2d(1, 5, 3, stride=1)
        #self.maxpool1= nn.MaxPool2d(kernel_size=2, return_indices=False)
        #self.fc21 = nn.Linear(845, 20)
        #self.fc22 = nn.Linear(845, 20)
        #self.fc3 = nn.Linear(20, 400)
        #self.unconv1 = nn.ConvTranspose2d(1, 3, 3, stride=1)
        #self.fc4 = nn.Linear(1452, 784)
        
        self.fc1 = nn.Conv2d(1,32, kernel_size=(28,28), stride=1, padding=0)
        self.fc21 = nn.Conv2d(32,latent_space_size, kernel_size=(1,1), stride=1, padding=0)
        self.fc22 = nn.Conv2d(32,latent_space_size, kernel_size=(1,1), stride=1, padding=0)
        
        self.fc3 = nn.ConvTranspose2d(latent_space_size,118, kernel_size=(1,1),  stride=1, padding=0)
        self.fc4 = nn.ConvTranspose2d(118,1, kernel_size=(28,28),stride=1, padding=0)

    #def encode(self, x):
    #    h1 = F.sigmoid(self.maxpool1(self.conv1(x)))
    #    h1 = h1.view(-1, 845)
    #    return self.fc21(h1), self.fc22(h1)

    #def reparameterize(self, mu, logvar):
    #    std = torch.exp(0.5*logvar)
    #    eps = torch.randn_like(std)
    #    return mu + eps*std

    #def decode(self, z):
    #    h3 = F.sigmoid(self.fc3(z)).view(-1,20,20)
    #    h3 = h3.view(-1,1,20,20)
    #    h3 = self.unconv1(h3)
    #    return torch.sigmoid(self.fc4(h3.view(-1,1452)))

    #def forward(self, x):
    #    mu, logvar = self.encode(x)
    #    z = self.reparameterize(mu, logvar)
    #    return self.decode(z), mu, logvar
    
    
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [142]:
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [143]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar, beta=1):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + beta * KLD

In [144]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar, beta=2)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [145]:
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [149]:
## Main
for epoch in range(1, epochs + 1):
    print(epoch)
    train(epoch)
    print("epoch")
    test(epoch)
    print("after")
    #if epochs % 10 == 0:
    #    with torch.no_grad():
    #        sample = torch.randn(64, 20).to(device)
    #        sample = model.decode(sample).cpu()
    #        save_image(sample.view(64, 1, 28, 28),'results/sample_' + str(epoch) + '.png')

1


  "Please ensure they have the same size.".format(target.size(), input.size()))




  "Please ensure they have the same size.".format(target.size(), input.size()))


====> Epoch: 1 Average loss: 151.6794
epoch


  "Please ensure they have the same size.".format(target.size(), input.size()))


====> Test set loss: 129.7117
after
2
====> Epoch: 2 Average loss: 147.1863
epoch
====> Test set loss: 125.7795
after
3
====> Epoch: 3 Average loss: 144.8802
epoch
====> Test set loss: 124.4922
after
4
====> Epoch: 4 Average loss: 143.4092
epoch
====> Test set loss: 123.1160
after
5
====> Epoch: 5 Average loss: 142.0382
epoch
====> Test set loss: 121.7854
after
6
====> Epoch: 6 Average loss: 140.8234
epoch
====> Test set loss: 121.0852
after
7
====> Epoch: 7 Average loss: 140.0654
epoch
====> Test set loss: 120.6847
after
8
====> Epoch: 8 Average loss: 139.5356
epoch
====> Test set loss: 119.9585
after
9
====> Epoch: 9 Average loss: 139.0575
epoch
====> Test set loss: 119.3949
after
10
====> Epoch: 10 Average loss: 138.5762
epoch
====> Test set loss: 119.1256
after
11
====> Epoch: 11 Average loss: 138.2068
epoch
====> Test set loss: 118.7414
after
12
====> Epoch: 12 Average loss: 137.8554
epoch
====> Test set loss: 118.2708
after
13
====> Epoch: 13 Average loss: 137.6021
epoch
====> Te

====> Test set loss: 115.9900
after
26
====> Epoch: 26 Average loss: 135.3937
epoch
====> Test set loss: 116.1227
after
27
====> Epoch: 27 Average loss: 135.3669
epoch
====> Test set loss: 115.7214
after
28
====> Epoch: 28 Average loss: 135.2675
epoch
====> Test set loss: 115.6508
after
29
====> Epoch: 29 Average loss: 135.1792
epoch
====> Test set loss: 115.6736
after
30
====> Epoch: 30 Average loss: 135.1389
epoch
====> Test set loss: 115.5053
after
31
====> Epoch: 31 Average loss: 135.0002
epoch
====> Test set loss: 115.7975
after
32
====> Epoch: 32 Average loss: 134.9570
epoch
====> Test set loss: 115.5605
after
33
====> Epoch: 33 Average loss: 134.8614
epoch
====> Test set loss: 115.4061
after
34
====> Epoch: 34 Average loss: 134.7850
epoch
====> Test set loss: 115.2468
after
35
====> Epoch: 35 Average loss: 134.6803
epoch
====> Test set loss: 115.0159
after
36
====> Epoch: 36 Average loss: 134.6707
epoch
====> Test set loss: 115.1891
after
37
====> Epoch: 37 Average loss: 134.656

====> Epoch: 49 Average loss: 134.1010
epoch
====> Test set loss: 114.2990
after
50
====> Epoch: 50 Average loss: 134.0529
epoch
====> Test set loss: 114.5496
after
51
====> Epoch: 51 Average loss: 134.0569
epoch
====> Test set loss: 114.5586
after
52
====> Epoch: 52 Average loss: 133.9944
epoch
====> Test set loss: 114.5029
after
53
====> Epoch: 53 Average loss: 133.9722
epoch
====> Test set loss: 114.3539
after
54
====> Epoch: 54 Average loss: 133.9264
epoch
====> Test set loss: 114.3696
after
55
====> Epoch: 55 Average loss: 133.8684
epoch
====> Test set loss: 114.2540
after
56
====> Epoch: 56 Average loss: 133.8427
epoch
====> Test set loss: 114.4515
after
57
====> Epoch: 57 Average loss: 133.8638
epoch
====> Test set loss: 114.2408
after
58
====> Epoch: 58 Average loss: 133.8182
epoch
====> Test set loss: 114.4384
after
59
====> Epoch: 59 Average loss: 133.7540
epoch
====> Test set loss: 114.0838
after
60
====> Epoch: 60 Average loss: 133.7625
epoch
====> Test set loss: 114.1253
a

====> Epoch: 73 Average loss: 133.4961
epoch
====> Test set loss: 113.9801
after
74
====> Epoch: 74 Average loss: 133.4821
epoch
====> Test set loss: 113.7848
after
75
====> Epoch: 75 Average loss: 133.4587
epoch
====> Test set loss: 113.8615
after
76
====> Epoch: 76 Average loss: 133.4327
epoch


KeyboardInterrupt: 

In [96]:
8448 * 22 / 128

1452.0

In [140]:
2560 / 28

91.42857142857143