In [0]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm

In [0]:
GPU = True

device = "cuda" if GPU else "cpu"

In [0]:
dataset = datasets.MNIST('./mnist', train=True, download=True, transform=transforms.ToTensor())
testset = datasets.MNIST('./mnist', train=False, download=True, transform=transforms.ToTensor())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw
Processing...
Done!


In [0]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

In [0]:
class VAE(nn.Module):
  def __init__(self):
    super(VAE, self).__init__()

    self.fc1 = nn.Linear(784, 400)
    self.fc21 = nn.Linear(400, 20)
    self.fc22 = nn.Linear(400, 20)
    self.fc3 = nn.Linear(20, 400)
    self.fc4 = nn.Linear(400, 784)

  def encode(self, x):
    h1 = F.relu(self.fc1(x))
    return self.fc21(h1), self.fc22(h1)

  def reparameterize(self, mu, logvar):
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return mu + eps*std

  def decode(self, z):
    h3 = F.relu(self.fc3(z))
    return torch.sigmoid(self.fc4(h3))

  def forward(self, x):
    mu, logvar = self.encode(x.view(-1, 784))
    z = self.reparameterize(mu, logvar)
    return self.decode(z), mu, logvar


In [0]:
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [0]:
def loss_function(recon_x, x, mu, logvar):
  BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")

  KLD = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

  return BCE + KLD

In [0]:
def train(epoch):
  model.train()
  train_loss = 0
  for (data, _) in tqdm(dataloader):
    data = data.to(device)
    optimizer.zero_grad()
    recon_batch, mu, logvar = model(data)
    loss = loss_function(recon_batch, data, mu, logvar)
    loss.backward()
    train_loss += loss.item()
    optimizer.step()

  print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, (train_loss/len(dataloader.dataset)/100)))

In [0]:
def test(epoch):
  model.eval()
  test_loss = 0
  with torch.no_grad():
    for batch_idx, (data, _) in tqdm(enumerate(testloader)):
      data = data.to(device)
      recon_batch, mu, logvar = model(data)
      test_loss += loss_function(recon_batch, data, mu, logvar).item()
      if batch_idx == 0:
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n], recon_batch.view(4, 1, 28, 28)[:n]])
        save_image(comparison.cpu(), 'reconstruction_' + str(epoch) + ".png", nrow=n)
  
  test_loss /= len(testloader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss/100))


In [0]:
def run(mode, epochs):
  if mode == "TRAIN":
    for epoch in range(epochs):
      train(epoch + 1)
  elif mode == "TEST":
    for epoch in range(epochs):
      test(epoch + 1)

In [0]:
EPOCH = 30
MODE = "TRAIN"

run(MODE, EPOCH)

100%|██████████| 1875/1875 [00:12<00:00, 146.60it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 1 Average loss: 1.3622


100%|██████████| 1875/1875 [00:12<00:00, 147.82it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 2 Average loss: 1.1326


100%|██████████| 1875/1875 [00:12<00:00, 158.96it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 3 Average loss: 1.1017


100%|██████████| 1875/1875 [00:13<00:00, 143.23it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 4 Average loss: 1.0869


100%|██████████| 1875/1875 [00:12<00:00, 147.90it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 5 Average loss: 1.0779


100%|██████████| 1875/1875 [00:12<00:00, 161.07it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 6 Average loss: 1.0710


100%|██████████| 1875/1875 [00:12<00:00, 149.84it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 7 Average loss: 1.0664


100%|██████████| 1875/1875 [00:12<00:00, 150.54it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 8 Average loss: 1.0623


100%|██████████| 1875/1875 [00:12<00:00, 149.79it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 9 Average loss: 1.0583


100%|██████████| 1875/1875 [00:12<00:00, 149.05it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 10 Average loss: 1.0554


100%|██████████| 1875/1875 [00:12<00:00, 150.51it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 11 Average loss: 1.0524


100%|██████████| 1875/1875 [00:12<00:00, 148.90it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 12 Average loss: 1.0507


100%|██████████| 1875/1875 [00:12<00:00, 151.43it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 13 Average loss: 1.0482


100%|██████████| 1875/1875 [00:12<00:00, 151.18it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 14 Average loss: 1.0462


100%|██████████| 1875/1875 [00:12<00:00, 151.86it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 15 Average loss: 1.0450


100%|██████████| 1875/1875 [00:12<00:00, 149.43it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 16 Average loss: 1.0434


100%|██████████| 1875/1875 [00:12<00:00, 148.16it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 17 Average loss: 1.0416


100%|██████████| 1875/1875 [00:12<00:00, 150.13it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 18 Average loss: 1.0406


100%|██████████| 1875/1875 [00:12<00:00, 147.60it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 19 Average loss: 1.0387


100%|██████████| 1875/1875 [00:12<00:00, 160.19it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 20 Average loss: 1.0378


100%|██████████| 1875/1875 [00:12<00:00, 147.31it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 21 Average loss: 1.0367


100%|██████████| 1875/1875 [00:12<00:00, 149.11it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 22 Average loss: 1.0357


100%|██████████| 1875/1875 [00:12<00:00, 148.16it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 23 Average loss: 1.0348


100%|██████████| 1875/1875 [00:12<00:00, 149.44it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 24 Average loss: 1.0340


100%|██████████| 1875/1875 [00:12<00:00, 150.01it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 25 Average loss: 1.0331


100%|██████████| 1875/1875 [00:12<00:00, 149.61it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 26 Average loss: 1.0321


100%|██████████| 1875/1875 [00:12<00:00, 148.20it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 27 Average loss: 1.0312


100%|██████████| 1875/1875 [00:12<00:00, 149.88it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 28 Average loss: 1.0306


100%|██████████| 1875/1875 [00:12<00:00, 145.11it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 29 Average loss: 1.0294


100%|██████████| 1875/1875 [00:12<00:00, 149.72it/s]

====> Epoch: 30 Average loss: 1.0291





In [0]:
EPOCH = 30
MODE = "TEST"

run(MODE, EPOCH)

2500it [00:05, 423.06it/s]

====> Test set loss: 1.0297



2500it [00:05, 425.65it/s]

====> Test set loss: 1.0303



2500it [00:05, 433.59it/s]

====> Test set loss: 1.0298



2500it [00:05, 438.22it/s]

====> Test set loss: 1.0298



2500it [00:05, 429.32it/s]

====> Test set loss: 1.0296



2500it [00:05, 428.80it/s]

====> Test set loss: 1.0298



2500it [00:05, 421.22it/s]

====> Test set loss: 1.0301



2500it [00:05, 421.75it/s]


====> Test set loss: 1.0304


2500it [00:05, 444.43it/s]

====> Test set loss: 1.0298



2500it [00:05, 421.65it/s]

====> Test set loss: 1.0300



2500it [00:05, 431.88it/s]

====> Test set loss: 1.0297



2500it [00:05, 440.67it/s]

====> Test set loss: 1.0298



2500it [00:05, 424.72it/s]

====> Test set loss: 1.0302



2500it [00:05, 426.04it/s]

====> Test set loss: 1.0302



2500it [00:05, 432.93it/s]

====> Test set loss: 1.0297



2500it [00:05, 438.57it/s]


====> Test set loss: 1.0295


2500it [00:06, 410.58it/s]

====> Test set loss: 1.0298



2500it [00:06, 411.09it/s]

====> Test set loss: 1.0292



2500it [00:05, 431.16it/s]

====> Test set loss: 1.0303



2500it [00:06, 412.72it/s]

====> Test set loss: 1.0297



2500it [00:05, 425.92it/s]

====> Test set loss: 1.0305



2500it [00:05, 424.01it/s]

====> Test set loss: 1.0304



2500it [00:05, 418.06it/s]

====> Test set loss: 1.0302



2500it [00:06, 414.20it/s]

====> Test set loss: 1.0302



2500it [00:06, 415.87it/s]

====> Test set loss: 1.0297



2500it [00:05, 433.89it/s]

====> Test set loss: 1.0293



2500it [00:06, 408.14it/s]

====> Test set loss: 1.0299



2500it [00:06, 412.30it/s]

====> Test set loss: 1.0290



2500it [00:05, 422.37it/s]

====> Test set loss: 1.0304



2500it [00:06, 401.09it/s]

====> Test set loss: 1.0302



