In [2]:
import torch
import torch.nn as nn
from torch.distributions.multivariate_normal import MultivariateNormal

from torchvision import datasets, transforms
from torchvision.utils import save_image
import torch.optim as optim

In [19]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal-Gaussian latent.

    Encoder: D_in -> H (sigmoid) -> two linear heads of size D_out producing the
    posterior means and log-variances. Decoder: D_out -> H (tanh) -> D_in
    followed by a sigmoid, so reconstructions lie in (0, 1).

    Args:
        D_in: size of the flattened input vector.
        H: width of the hidden layer (used by both encoder and decoder).
        D_out: dimensionality of the latent space.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()

        self.D_in = D_in
        self.H = H
        self.D_out = D_out

        # Attribute names (including the "inputer" spelling) are kept unchanged
        # so previously saved state_dicts still load.
        self.inputer_layer = nn.Linear(D_in, H)
        self.hidden_layer_mean = nn.Linear(H, D_out)
        self.hidden_layer_var = nn.Linear(H, D_out)

        self.recon_layer = nn.Linear(D_out, H)
        self.recon_output = nn.Linear(H, D_in)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def encode(self, inp):
        """Map inputs (..., D_in) to latent means and log-variances (..., D_out)."""
        h = self.sigmoid(self.inputer_layer(inp))
        means = self.hidden_layer_mean(h)
        log_vars = self.hidden_layer_var(h)
        return means, log_vars

    def reparameterize(self, means, log_vars):
        """Draw z = mu + sigma * eps, eps ~ N(0, I), independently per element.

        ``log_vars`` is the natural log of the variance, so
        sigma = exp(0.5 * log_var); this must match the KL term of the loss.
        """
        std_devs = torch.exp(0.5 * log_vars)
        # randn_like gives every row of the batch its own noise and matches
        # the dtype/device of std_devs (a single shared sample would not).
        eps = torch.randn_like(std_devs)
        return means + eps * std_devs

    def decoder(self, means, log_vars):
        """Sample a latent code with the reparameterization trick and decode it."""
        return self.reconstruct(self.reparameterize(means, log_vars))

    def forward(self, inp):
        """Encode, sample, decode; returns (reconstruction, means, log_vars)."""
        means, log_vars = self.encode(inp)
        output = self.decoder(means, log_vars)
        return output, means, log_vars

    def reconstruct(self, sample):
        """Decode latent codes (..., D_out) into outputs (..., D_in) in (0, 1)."""
        h_vec = self.tanh(self.recon_layer(sample))
        return self.sigmoid(self.recon_output(h_vec))

In [None]:
def compute_loss(inp, recon_inp, means, log_vars):
    """Negative ELBO for a Bernoulli-decoder VAE, summed over the batch.

    Args:
        inp: target values in [0, 1] (e.g. flattened pixel intensities).
        recon_inp: decoder outputs in (0, 1), same shape as ``inp``.
        means: latent posterior means.
        log_vars: latent posterior log-variances (natural log).

    Returns:
        Scalar tensor: summed binary cross-entropy + summed KL(q(z|x) || N(0, I)).
    """
    # Closed-form KL between N(mu, sigma^2) and N(0, 1). log_vars is ln(sigma^2),
    # so the variance is exp(log_vars), not 2 ** log_vars.
    kl_loss = -0.5 * torch.sum(1 + log_vars - means ** 2 - torch.exp(log_vars))

    # reduction="sum" puts the reconstruction term on the same summed scale as
    # the KL term. The old `reduce='sum'` hit the deprecated boolean flag and
    # silently produced a per-element *mean*.
    loss = nn.BCELoss(reduction="sum")
    recon_loss = loss(recon_inp, inp)
    return kl_loss + recon_loss

D_in = 28*28
H = 200
D_out = 20
BATCH_SIZE = 100
N_EPOCHS = 10
SEED = 0

# Weight init, shuffling and the latent noise are all random; seed for reruns.
torch.manual_seed(SEED)

vae = VAE(D_in, H, D_out)
vae.to("cpu")

optimizer = optim.Adam(vae.parameters(), lr=1e-4)

trainloader = torch.utils.data.DataLoader(
    # download=True fetches MNIST on a fresh machine; it is a no-op if present.
    datasets.MNIST("./ds", train=True, transform=transforms.ToTensor(), download=True),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

for i in range(1, N_EPOCHS + 1):
    running_loss = 0.0
    for imgs, _ in trainloader:
        # Flatten each image to a D_in vector for the fully connected encoder.
        imgs = imgs.view(-1, D_in)

        output, means, log_vars = vae(imgs)
        loss = compute_loss(imgs, output, means, log_vars)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain float; accumulating the tensor itself would
        # chain every batch's autograd graph into running_loss.
        running_loss += loss.item()

    # Average of the per-batch (batch-summed) loss over the epoch's batches.
    print("Epoch {}, Loss {:5.3f}".format(i, running_loss / len(trainloader)))

Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 2



Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 28])
After:  torch.Size([100, 784])
Before:  torch.Size([100, 1, 28, 2

KeyboardInterrupt: 

In [None]:
# Number of mini-batches per epoch: len() of a DataLoader counts batches, not samples.
len(trainloader)

60000

In [33]:
# Two equal-length 1-D integer tensors; stacking adds a new leading axis.
x = [torch.tensor(row) for row in ([1, 2, 3], [4, 5, 3])]
(len(x), torch.stack(x))

(2,
 tensor([[1, 2, 3],
         [4, 5, 3]]))

In [45]:
# dim=0 is torch.stack's default, so this matches the plain torch.stack(x) result above.
torch.stack(x, dim=0)

tensor([[1, 2, 3],
        [4, 5, 3]])

In [55]:
# cat joins the 1-D tensors end to end; reshaping to (len(x), -1) recovers
# exactly what torch.stack(x, dim=0) produces, for any number/length of rows.
cated = torch.cat(x)
cated, torch.all(cated.view(len(x), -1) == torch.stack(x, dim=0))

(tensor([1, 2, 3, 4, 5, 3]), tensor(True))