In [1]:
import numpy as np
import plotter as pltr
pltr.set_backend(pltr.MatplotlibBackend)

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torch.utils.data as dutils

In [3]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

In [4]:
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 2)
        self.fc22 = nn.Linear(400, 2)
        self.fc3 = nn.Linear(2, 400)
        self.fc4 = nn.Linear(400, 784)
        
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
    
    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.sigmoid(self.fc4(h3))
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
    
model = VAE().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [11]:
trainset = datasets.MNIST('/data/pytorch/mnist', train=True, download=False, transform=transforms.ToTensor())
train_loader = dutils.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=4)

testset = datasets.MNIST('/data/pytorch/mnist', train=False, download=False, transform=transforms.ToTensor())
test_loader = dutils.DataLoader(testset, batch_size=32, num_workers=4)

In [12]:
def loss_fn(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [13]:
# Train a single batch
model.train()
x, _ = next(iter(train_loader))
x = x.to(DEVICE)
optimizer.zero_grad()
recon_x, mu, logvar = model(x)



In [14]:
print(f'images.size() = {images.size()}')
print(f'recon_images.size() = {recon_images.size()}')
print(f'mu.size() = {mu.size()}')
print(f'logvar.size() = {logvar.size()}')

images.size() = torch.Size([2, 1, 28, 28])
recon_images.size() = torch.Size([2, 784])
mu.size() = torch.Size([2, 2])
logvar.size() = torch.Size([2, 2])


In [15]:
mu

tensor([[-0.0141,  0.1251],
        [ 0.0532,  0.0810]], device='cuda:0', grad_fn=<ThAddmmBackward>)

In [16]:
logvar

tensor([[0.1715, 0.0826],
        [0.1893, 0.0991]], device='cuda:0', grad_fn=<ThAddmmBackward>)

In [None]:
logvar + mu

In [None]:
0.0948 + 0.1682

In [None]:
0.2821 + 0.263

In [None]:
torch.sum(mu)