In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
import imageio
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

In [2]:
# Set to False to force CPU execution even when a GPU is present.
GPU = True

# Select CUDA only when it is both requested and actually available, so the
# notebook still runs on CPU-only machines instead of failing when tensors or
# the model are moved to a non-existent device.
device = "cuda" if GPU and torch.cuda.is_available() else "cpu"

In [3]:
# Build the MNIST train and test splits (in that order) under ./mnist,
# downloading if needed; every image is converted with ToTensor.
dataset, testset = (
  datasets.MNIST('./mnist', train=is_train, download=True, transform=transforms.ToTensor())
  for is_train in (True, False)
)

In [4]:
# Training loader: shuffled mini-batches of 32 images.
dataloader = torch.utils.data.DataLoader(
  dataset,
  batch_size=32,
  shuffle=True,
  num_workers=4,
)

# Test loader: fixed order, 4 images per batch.
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=4,
  shuffle=False,
  num_workers=2,
)

In [5]:
class VAE(nn.Module):
  """Fully connected variational autoencoder.

  The encoder maps a flattened input to the mean and log-variance of a
  diagonal Gaussian over the latent code; the decoder maps a latent sample
  back to per-element values in [0, 1] through a sigmoid.

  Args:
    input_dim: number of elements in one flattened input (default 784).
    hidden_dim: width of the single hidden layer on each side (default 400).
    latent_dim: size of the latent code (default 20).
  """

  def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
    super().__init__()

    # Remembered so forward() can flatten inputs without a hard-coded size.
    self.input_dim = input_dim

    # Attribute names and creation order are unchanged so that a
    # default-constructed model has the same parameters/state_dict keys.
    self.fc1 = nn.Linear(input_dim, hidden_dim)
    self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
    self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
    self.fc3 = nn.Linear(latent_dim, hidden_dim)
    self.fc4 = nn.Linear(hidden_dim, input_dim)

  def encode(self, x):
    """Return ``(mu, logvar)`` for a batch of flattened inputs."""
    h1 = F.relu(self.fc1(x))
    return self.fc21(h1), self.fc22(h1)

  def reparameterize(self, mu, logvar):
    """Draw ``mu + eps * std`` with ``eps`` standard normal and ``std = exp(logvar / 2)``.

    The noise is sampled on every call, regardless of train/eval mode.
    """
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return mu + eps*std

  def decode(self, z):
    """Map latent codes to flattened reconstructions in [0, 1]."""
    h3 = F.relu(self.fc3(z))
    return torch.sigmoid(self.fc4(h3))

  def forward(self, x):
    """Encode, sample and decode.

    ``x`` is reshaped to ``(-1, input_dim)`` before encoding.

    Returns:
      ``(reconstruction, mu, logvar)`` where ``reconstruction`` has shape
      ``(batch, input_dim)``.
    """
    mu, logvar = self.encode(x.view(-1, self.input_dim))
    z = self.reparameterize(mu, logvar)
    return self.decode(z), mu, logvar


In [6]:
# Instantiate the network, move it to the selected device, and optimise all of
# its parameters with Adam at a learning rate of 1e-3.
model = VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [7]:
def loss_function(recon_x, x, mu, logvar):
  """Return the summed VAE loss: reconstruction BCE plus KL divergence.

  Args:
    recon_x: reconstructions, compared element-wise against ``x`` reshaped
      to ``(-1, 784)``.
    x: original inputs (any shape that can be viewed as ``(-1, 784)``).
    mu: latent means.
    logvar: latent log-variances.

  Returns:
    A scalar tensor; both terms are summed (not averaged) over all elements.
  """
  target = x.view(-1, 784)
  reconstruction_term = F.binary_cross_entropy(recon_x, target, reduction="sum")

  # Closed-form KL between N(mu, exp(logvar)) and a standard normal.
  kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
  kl_term = -0.5*kl_inner.sum()

  return reconstruction_term + kl_term

In [8]:
def train(epoch):
  """Run one optimisation pass over ``dataloader`` and print the epoch loss.

  Uses the notebook-level ``model``, ``optimizer``, ``dataloader`` and
  ``device``. ``epoch`` is only used to label the printed line.
  """
  model.train()
  running_total = 0
  for batch, _ in tqdm(dataloader):
    batch = batch.to(device)
    optimizer.zero_grad()
    reconstruction, mu, logvar = model(batch)
    batch_loss = loss_function(reconstruction, batch, mu, logvar)
    batch_loss.backward()
    optimizer.step()
    running_total += batch_loss.item()

  # Reported value: summed loss per sample, further divided by 100.
  scaled_mean = running_total/len(dataloader.dataset)/100
  print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, scaled_mean))

In [9]:
def test(epoch):
  """Evaluate the model on ``testloader`` and save a reconstruction grid.

  Uses the notebook-level ``model``, ``testloader`` and ``device``. For the
  first batch, up to 8 inputs and their reconstructions are stacked and
  written to ``result/reconstruction_<epoch>.png``. Prints the summed loss
  per test sample, divided by 100.

  Args:
    epoch: index used in the output file name.
  """
  model.eval()
  test_loss = 0
  # Make sure the output directory exists even if the setup cell was skipped.
  os.makedirs('result', exist_ok=True)
  with torch.no_grad():
    # Wrap the loader (not the enumerate object) so tqdm knows the total
    # number of batches and can render a real progress bar.
    for batch_idx, (data, _) in enumerate(tqdm(testloader)):
      data = data.to(device)
      recon_batch, mu, logvar = model(data)
      test_loss += loss_function(recon_batch, data, mu, logvar).item()
      if batch_idx == 0:
        n = min(data.size(0), 8)
        # Infer the batch dimension instead of hard-coding the loader's
        # batch size, so changing `testloader` does not break this reshape.
        recon_images = recon_batch.view(-1, 1, 28, 28)
        comparison = torch.cat([data[:n], recon_images[:n]])
        save_image(comparison.cpu(), 'result/reconstruction_' + str(epoch) + ".png", nrow=n)

  test_loss /= len(testloader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss/100))


In [10]:
def run(mode, epochs):
  """Call ``train`` or ``test`` once per epoch, with epochs numbered from 1.

  Args:
    mode: ``"TRAIN"`` to call ``train(epoch)``, ``"TEST"`` to call
      ``test(epoch)``.
    epochs: number of calls to make.

  Raises:
    ValueError: if ``mode`` is neither ``"TRAIN"`` nor ``"TEST"``, so that a
      mistyped mode is reported instead of silently doing nothing.
  """
  if mode == "TRAIN":
    step = train
  elif mode == "TEST":
    step = test
  else:
    raise ValueError('mode must be "TRAIN" or "TEST", got {!r}'.format(mode))

  for epoch in range(1, epochs + 1):
    step(epoch)

In [11]:
# Create the output directory for reconstruction images; exist_ok makes the
# cell safe to re-run and avoids the check-then-create race.
os.makedirs("result", exist_ok=True)

In [12]:
# Training run: 50 passes over the training set.
EPOCH = 50
MODE = "TRAIN"

run(mode=MODE, epochs=EPOCH)

100%|██████████| 1875/1875 [00:07<00:00, 263.04it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 1 Average loss: 1.3565


100%|██████████| 1875/1875 [00:07<00:00, 265.66it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 2 Average loss: 1.1298


100%|██████████| 1875/1875 [00:07<00:00, 266.36it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 3 Average loss: 1.0988


100%|██████████| 1875/1875 [00:07<00:00, 267.49it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 4 Average loss: 1.0841


100%|██████████| 1875/1875 [00:07<00:00, 263.63it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 5 Average loss: 1.0756


100%|██████████| 1875/1875 [00:07<00:00, 255.02it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 6 Average loss: 1.0689


100%|██████████| 1875/1875 [00:07<00:00, 261.57it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 7 Average loss: 1.0640


100%|██████████| 1875/1875 [00:06<00:00, 268.52it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 8 Average loss: 1.0607


100%|██████████| 1875/1875 [00:07<00:00, 264.46it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 9 Average loss: 1.0569


100%|██████████| 1875/1875 [00:07<00:00, 264.87it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 10 Average loss: 1.0538


100%|██████████| 1875/1875 [00:07<00:00, 263.99it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 11 Average loss: 1.0513


100%|██████████| 1875/1875 [00:06<00:00, 268.52it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 12 Average loss: 1.0489


100%|██████████| 1875/1875 [00:07<00:00, 261.06it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 13 Average loss: 1.0469


100%|██████████| 1875/1875 [00:07<00:00, 264.70it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 14 Average loss: 1.0451


100%|██████████| 1875/1875 [00:07<00:00, 266.57it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 15 Average loss: 1.0433


100%|██████████| 1875/1875 [00:07<00:00, 262.71it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 16 Average loss: 1.0424


100%|██████████| 1875/1875 [00:07<00:00, 267.53it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 17 Average loss: 1.0408


100%|██████████| 1875/1875 [00:07<00:00, 267.77it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 18 Average loss: 1.0392


100%|██████████| 1875/1875 [00:07<00:00, 255.96it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 19 Average loss: 1.0379


100%|██████████| 1875/1875 [00:07<00:00, 266.17it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 20 Average loss: 1.0367


100%|██████████| 1875/1875 [00:07<00:00, 266.91it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 21 Average loss: 1.0359


100%|██████████| 1875/1875 [00:07<00:00, 267.83it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 22 Average loss: 1.0354


100%|██████████| 1875/1875 [00:07<00:00, 262.76it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 23 Average loss: 1.0341


100%|██████████| 1875/1875 [00:07<00:00, 267.70it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 24 Average loss: 1.0329


100%|██████████| 1875/1875 [00:07<00:00, 260.42it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 25 Average loss: 1.0318


100%|██████████| 1875/1875 [00:07<00:00, 264.92it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 26 Average loss: 1.0312


100%|██████████| 1875/1875 [00:07<00:00, 266.10it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 27 Average loss: 1.0310


100%|██████████| 1875/1875 [00:07<00:00, 266.88it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 28 Average loss: 1.0303


100%|██████████| 1875/1875 [00:07<00:00, 259.62it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 29 Average loss: 1.0297


100%|██████████| 1875/1875 [00:07<00:00, 264.25it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 30 Average loss: 1.0282


100%|██████████| 1875/1875 [00:07<00:00, 261.05it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 31 Average loss: 1.0278


100%|██████████| 1875/1875 [00:06<00:00, 275.27it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 32 Average loss: 1.0270


100%|██████████| 1875/1875 [00:07<00:00, 256.04it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 33 Average loss: 1.0265


100%|██████████| 1875/1875 [00:07<00:00, 265.78it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 34 Average loss: 1.0263


100%|██████████| 1875/1875 [00:07<00:00, 267.00it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 35 Average loss: 1.0257


100%|██████████| 1875/1875 [00:07<00:00, 265.89it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 36 Average loss: 1.0252


100%|██████████| 1875/1875 [00:07<00:00, 264.40it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 37 Average loss: 1.0249


100%|██████████| 1875/1875 [00:07<00:00, 261.14it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 38 Average loss: 1.0244


100%|██████████| 1875/1875 [00:07<00:00, 263.35it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 39 Average loss: 1.0243


100%|██████████| 1875/1875 [00:07<00:00, 262.50it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 40 Average loss: 1.0230


100%|██████████| 1875/1875 [00:07<00:00, 264.60it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 41 Average loss: 1.0230


100%|██████████| 1875/1875 [00:07<00:00, 261.41it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 42 Average loss: 1.0222


100%|██████████| 1875/1875 [00:07<00:00, 264.21it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 43 Average loss: 1.0220


100%|██████████| 1875/1875 [00:07<00:00, 266.47it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 44 Average loss: 1.0217


100%|██████████| 1875/1875 [00:07<00:00, 258.57it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 45 Average loss: 1.0213


100%|██████████| 1875/1875 [00:07<00:00, 266.94it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 46 Average loss: 1.0209


100%|██████████| 1875/1875 [00:07<00:00, 265.48it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 47 Average loss: 1.0207


100%|██████████| 1875/1875 [00:07<00:00, 256.83it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 48 Average loss: 1.0203


100%|██████████| 1875/1875 [00:07<00:00, 261.71it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]

====> Epoch: 49 Average loss: 1.0194


100%|██████████| 1875/1875 [00:06<00:00, 268.24it/s]

====> Epoch: 50 Average loss: 1.0197





In [13]:
# Evaluation run: 20 passes over the test set, each saving one
# reconstruction image under result/.
EPOCH = 20
MODE = "TEST"

run(mode=MODE, epochs=EPOCH)

2500it [00:03, 807.58it/s]

====> Test set loss: 1.0206



2500it [00:03, 821.35it/s]

====> Test set loss: 1.0206



2500it [00:03, 816.99it/s]

====> Test set loss: 1.0221



2500it [00:02, 836.96it/s]

====> Test set loss: 1.0207



2500it [00:03, 825.03it/s]

====> Test set loss: 1.0211



2500it [00:03, 805.60it/s]

====> Test set loss: 1.0206



2500it [00:03, 801.08it/s]

====> Test set loss: 1.0204



2500it [00:03, 821.97it/s]

====> Test set loss: 1.0206



2500it [00:03, 829.56it/s]

====> Test set loss: 1.0205



2500it [00:03, 808.86it/s]

====> Test set loss: 1.0214



2500it [00:02, 838.97it/s]

====> Test set loss: 1.0208



2500it [00:03, 817.23it/s]

====> Test set loss: 1.0216



2500it [00:03, 832.25it/s]

====> Test set loss: 1.0212



2500it [00:03, 823.45it/s]

====> Test set loss: 1.0210



2500it [00:03, 823.97it/s]

====> Test set loss: 1.0207



2500it [00:03, 815.90it/s]

====> Test set loss: 1.0208



2500it [00:03, 801.41it/s]

====> Test set loss: 1.0210



2500it [00:03, 817.30it/s]

====> Test set loss: 1.0203



2500it [00:03, 816.68it/s]

====> Test set loss: 1.0217



2500it [00:03, 821.73it/s]

====> Test set loss: 1.0212





In [14]:
def plot_reconstruct(result_dir='./result', output_path='./animation.gif', fps=1):
    """Assemble the saved reconstruction images into an animated GIF.

    Only files named ``reconstruction_<number>.png`` inside ``result_dir`` are
    used, ordered by their number, so unrelated files or gaps in the numbering
    do not break the animation.

    Args:
        result_dir: directory containing the reconstruction images.
        output_path: where the GIF is written.
        fps: frames per second passed to ``imageio.mimsave``.

    Raises:
        FileNotFoundError: if ``result_dir`` holds no reconstruction images.
    """
    prefix, suffix = "reconstruction_", ".png"

    numbered = []
    for name in os.listdir(result_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        index_text = name[len(prefix):-len(suffix)]
        path = os.path.join(result_dir, name)
        if index_text.isdigit() and os.path.isfile(path):
            numbered.append((int(index_text), path))

    if not numbered:
        raise FileNotFoundError(
            "no reconstruction_<n>.png files found in {!r}".format(result_dir))

    # Sort by the numeric index (not lexicographically) so 10 follows 9.
    numbered.sort()
    steps = [imageio.imread(path) for _, path in numbered]

    imageio.mimsave(output_path, steps, fps=fps)

plot_reconstruct()

![Animation of VAE test reconstructions](animation.gif "animation")