In [36]:
from tqdm import tqdm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from model import VAE

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
print(device)

cuda:0


## Sanity check

In [7]:
vae = VAE(input_dim=28).to(device)

In [8]:
x = torch.randn(4, 1, 28, 28).to(device)

In [9]:
y = vae(x)

In [14]:
print(type(y))
print(y[0].shape)
print(y[1].shape)
print(y[2].shape)

<class 'tuple'>
torch.Size([4, 1, 28, 28])
torch.Size([4, 1, 28, 200])
torch.Size([4, 1, 28, 200])


## Training

In [22]:
input_dim = 784
# input_dim = (1, 28, 28)
h_dim = 200
z_dim = 20
num_epochs = 10
batch_size = 32
learning_rate = 3e-4

In [17]:
dataset = datasets.MNIST(
    root="data/", train=True, transform=transforms.ToTensor(), download=True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:07<00:00, 1262068.29it/s]


Extracting data/MNIST\raw\train-images-idx3-ubyte.gz to data/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 160913.51it/s]


Extracting data/MNIST\raw\train-labels-idx1-ubyte.gz to data/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 709222.25it/s]


Extracting data/MNIST\raw\t10k-images-idx3-ubyte.gz to data/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4507933.93it/s]

Extracting data/MNIST\raw\t10k-labels-idx1-ubyte.gz to data/MNIST\raw






In [19]:
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [23]:
model = VAE(input_dim=input_dim, h_dim=h_dim, z_dim=z_dim).to(device)

In [24]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [25]:
loss_fn = nn.BCELoss(reduction="sum")

### Actual training loop

In [28]:
for epoch in range(num_epochs):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) in loop:
        # forward pass
        x = x.to(device).view(x.shape[0], input_dim)
        x_reconstucted, mean, std = model(x)

        # loss
        reconstruction_loss = loss_fn(x_reconstucted, x)
        kl_div = -torch.sum(1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2))

        # backpropagation
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update tqdm loop
        loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        loop.set_postfix(
            loss=loss.item(),
            reconstruction_loss=reconstruction_loss.item(),
            kl_div=kl_div.item(),
        )

Epoch [1/10]: : 1875it [00:25, 72.44it/s, kl_div=781, loss=4.69e+3, reconstruction_loss=3.91e+3] 
Epoch [2/10]: : 1875it [00:24, 77.57it/s, kl_div=935, loss=4.71e+3, reconstruction_loss=3.78e+3]    
Epoch [3/10]: : 1875it [00:27, 68.28it/s, kl_div=1.05e+3, loss=4.8e+3, reconstruction_loss=3.75e+3] 
Epoch [4/10]: : 1875it [00:26, 69.46it/s, kl_div=951, loss=4.38e+3, reconstruction_loss=3.43e+3]    
Epoch [5/10]: : 1875it [00:26, 70.93it/s, kl_div=1.06e+3, loss=4.78e+3, reconstruction_loss=3.72e+3]
Epoch [6/10]: : 1875it [00:27, 68.57it/s, kl_div=1.06e+3, loss=4.36e+3, reconstruction_loss=3.3e+3] 
Epoch [7/10]: : 1875it [00:27, 67.58it/s, kl_div=1.08e+3, loss=4.4e+3, reconstruction_loss=3.32e+3] 
Epoch [8/10]: : 1875it [00:26, 69.88it/s, kl_div=1.09e+3, loss=4.85e+3, reconstruction_loss=3.76e+3]
Epoch [9/10]: : 1875it [00:26, 70.95it/s, kl_div=1.16e+3, loss=4.72e+3, reconstruction_loss=3.56e+3]
Epoch [10/10]: : 1875it [00:25, 73.47it/s, kl_div=1.14e+3, loss=4.46e+3, reconstruction_loss=3

In [29]:
torch.save(model.state_dict(), "vae.pth")

In [34]:
# sanity check model load
test_model = torch.load("vae.pth")
print(type(test_model))
print(test_model.keys())

<class 'collections.OrderedDict'>
odict_keys(['img_2_hidden.weight', 'img_2_hidden.bias', 'hidden_2_mean.weight', 'hidden_2_mean.bias', 'hidden_2_std.weight', 'hidden_2_std.bias', 'z_2_hidden.weight', 'z_2_hidden.bias', 'hidden_2_img.weight', 'hidden_2_img.bias'])


## Evaluation

In [35]:
model.eval()

VAE(
  (relu): ReLU(inplace=True)
  (img_2_hidden): Linear(in_features=784, out_features=200, bias=True)
  (hidden_2_mean): Linear(in_features=200, out_features=20, bias=True)
  (hidden_2_std): Linear(in_features=200, out_features=20, bias=True)
  (z_2_hidden): Linear(in_features=20, out_features=200, bias=True)
  (hidden_2_img): Linear(in_features=200, out_features=784, bias=True)
)

In [39]:
def plot_reconstructed_image(model, x):
    with torch.no_grad():
        x_reconstructed, _, _ = model(x)
        x_reconstructed = x_reconstructed.view(-1, 1, 28, 28)
        save_image(x, "real_image.jpg")
        save_image(x_reconstructed, "reconstructed_image.jpg")

In [40]:
for x, _ in train_loader:
    x = x.to(device).view(x.shape[0], input_dim)
    plot_reconstructed_image(model, x)
    break