In [1]:
import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import time
from tqdm.notebook import tqdm

# --- Configuration -----------------------------------------------------------
BATCH_SIZE = 100
# Seed used only for the test/validation partition so that the split is the
# same after every kernel restart (otherwise metrics are not comparable).
SPLIT_SEED = 0

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Datasets ----------------------------------------------------------------
# MNIST is downloaded into ./data on first use.
train_set = datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
test_val_set = datasets.MNIST(root='data', train=False, download=True, transform=transforms.ToTensor())

# Split the official test partition into a test half and a validation half.
test_size = 0.5
test_set_size = int(len(test_val_set) * test_size)
val_set_size = len(test_val_set) - test_set_size
test_set, val_set = torch.utils.data.random_split(
    test_val_set,
    [test_set_size, val_set_size],
    # A dedicated generator keeps the split deterministic without touching
    # the global RNG that drives weight init and batch shuffling.
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# --- Loaders -----------------------------------------------------------------
# Only the training data is shuffled; evaluation order is kept fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_set, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)


In [2]:
from VAE import VAE, loss
# Build the model and place it on the selected device.
# NOTE(review): (1, 28, 28) presumably is the (channels, height, width) of an
# MNIST image and nhid the latent size -- confirm against VAE.py.
net = VAE((1, 28, 28), nhid=4)
net.to(device)

# File used to persist / restore the model and optimizer state.
save_name = "VAE.pt"

# Adam over the trainable parameters only, with a small L2 weight decay.
lr = 0.01
optimizer = torch.optim.Adam(
    (param for param in net.parameters() if param.requires_grad),
    lr=lr,
    weight_decay=0.0001,
)

def adjust_lr(optimizer, decay_rate=0.95):
    """Scale the learning rate of every parameter group in place.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts
            that each carry an ``'lr'`` entry.
        decay_rate: multiplicative factor applied to each ``'lr'`` value.

    Returns:
        None. The groups are modified in place.
    """
    for group in optimizer.param_groups:
        group['lr'] *= decay_rate


# --- Optional resume ---------------------------------------------------------
# If a checkpoint already exists, ask whether to retrain from scratch.
# Any answer other than 'y' restores the saved model and optimizer state.
if os.path.exists(save_name):
    print("Model parameters have already been trained before. Retrain ? (y/n)")
    ans = input()
    if ans != 'y':
        checkpoint = torch.load(save_name, map_location=device)
        net.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # The stored optimizer state carries a decayed learning rate;
        # restart the decay schedule from the configured base rate.
        for g in optimizer.param_groups:
            g['lr'] = lr

max_epochs = 1000
net = net.to(device)

print("training on ", device)

# Lowest per-sample training loss seen so far; used to decide when to write
# a checkpoint to `save_name`.
best_train_loss = float('inf')

# --- Training ----------------------------------------------------------------
for epoch in range(max_epochs):

    # Make sure training-mode behaviour (e.g. BatchNorm batch statistics) is
    # active even if another cell switched the model to eval mode.
    net.train()

    train_loss, n, start = 0.0, 0, time.time()
    for X, _ in tqdm(train_loader, ncols=50):
        X = X.to(device)
        X_hat, mean, logvar = net(X)

        batch_loss = loss(X, X_hat, mean, logvar).to(device)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        train_loss += batch_loss.item()
        n += X.shape[0]

    # NOTE(review): dividing by the sample count assumes `loss` returns a sum
    # over the batch rather than a mean -- confirm against VAE.py.
    train_loss /= n
    print('epoch %d, train loss %.4f , time %.1f sec'
          % (epoch, train_loss, time.time() - start))

    # Persist the best weights seen so far. Without this the file checked at
    # the top of the cell (and reloaded below) would never be written.
    if train_loss < best_train_loss:
        best_train_loss = train_loss
        torch.save(
            {"net": net.state_dict(), "optimizer": optimizer.state_dict()},
            save_name,
        )

    adjust_lr(optimizer)

# --- Restore the best checkpoint ----------------------------------------------
# The guard covers the case where no epoch produced a finite, improving loss
# and therefore nothing was saved during this run.
if os.path.exists(save_name):
    checkpoint = torch.load(save_name, map_location=device)
    net.load_state_dict(checkpoint["net"])




VAE(
  (encoder): Encoder(
    (encode): Sequential(
      (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
      (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ReLU(inplace=True)
      (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (14): Flatten()
      (15): MLP(
     

  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 0, train loss 140.3183 , time 11.6 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 1, train loss 128.0707 , time 10.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 2, train loss 125.1419 , time 10.7 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 3, train loss 123.3691 , time 11.1 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 4, train loss 122.0766 , time 11.5 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 5, train loss 120.8391 , time 12.3 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 6, train loss 120.1327 , time 11.3 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 7, train loss 119.2385 , time 11.0 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 8, train loss 118.5161 , time 10.6 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 9, train loss 117.7113 , time 11.0 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 10, train loss 117.4317 , time 10.7 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 11, train loss 116.8660 , time 11.5 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 12, train loss 116.4605 , time 11.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 13, train loss 116.0145 , time 11.0 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 14, train loss 115.3991 , time 11.2 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 15, train loss 115.2768 , time 10.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 16, train loss 114.8443 , time 10.7 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 17, train loss 114.6132 , time 10.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 18, train loss 114.3245 , time 11.4 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 19, train loss 113.9959 , time 11.2 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 20, train loss 113.7202 , time 11.0 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 21, train loss 113.3346 , time 10.7 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 22, train loss 113.3067 , time 10.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 23, train loss 112.9569 , time 11.1 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 24, train loss 112.8102 , time 10.3 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 25, train loss 112.6772 , time 9.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 26, train loss 112.4651 , time 9.7 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 27, train loss 112.3379 , time 10.2 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 28, train loss 112.1108 , time 10.7 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 29, train loss 111.8277 , time 10.0 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 30, train loss 111.7063 , time 12.3 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 31, train loss 111.5926 , time 12.0 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 32, train loss 111.4857 , time 11.0 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 33, train loss 111.3461 , time 10.1 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 34, train loss 111.2347 , time 10.0 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 35, train loss 111.0625 , time 9.7 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 36, train loss 110.9254 , time 9.6 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 37, train loss 110.8579 , time 9.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 38, train loss 110.7646 , time 10.1 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 39, train loss 110.6010 , time 10.5 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 40, train loss 110.5124 , time 9.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 41, train loss 110.5119 , time 9.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 42, train loss 110.3605 , time 10.3 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 43, train loss 110.2917 , time 9.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 44, train loss 110.1964 , time 10.2 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 45, train loss 110.1276 , time 10.2 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 46, train loss 110.0340 , time 11.3 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 47, train loss 109.9692 , time 11.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 48, train loss 109.8781 , time 10.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 49, train loss 109.8263 , time 9.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 50, train loss 109.8307 , time 9.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 51, train loss 109.7306 , time 9.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 52, train loss 109.6174 , time 9.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 53, train loss 109.6784 , time 9.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 54, train loss 109.5702 , time 10.2 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 55, train loss 109.5486 , time 9.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 56, train loss 109.4788 , time 10.2 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 57, train loss 109.4131 , time 10.3 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 58, train loss 109.4670 , time 9.8 sec


  0%|                     | 0/600 [00:00<?, ?it/s]

epoch 59, train loss 109.4109 , time 10.9 sec


  0%|                     | 0/600 [00:00<?, ?it/s]