In [1]:
from __future__ import print_function
import argparse
import torch
import numpy as np
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions import Normal, Laplace, Independent, Bernoulli, Gamma, Uniform, Beta
from torch.distributions.kl import kl_divergence

In [19]:
class VAE(nn.Module):
    """Variational autoencoder with a pluggable prior, posterior family and likelihood.

    Args:
        encoder_layers: modules chained (``nn.Sequential``) into the encoder. Its
            output is split in half along dim 1; the two halves are the parameters
            passed to ``q_z``.
        decoder_layers: modules chained (``nn.Sequential``) into the decoder, which
            maps a latent sample ``z`` to the argument of ``loss_dist``.
        p_z: prior over z. Must be an ``Independent`` distribution, because
            ``loss_function`` uses its ``base_dist``.
        q_z: callable ``(param_a, param_b) -> Distribution`` building the
            approximate posterior q(z | x) from the two encoder-output halves.
        loss_dist: callable ``x_hat -> Distribution`` building the likelihood
            p(x | z) from the decoder output.
    """

    def __init__(self, encoder_layers, decoder_layers, p_z, q_z, loss_dist):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)
        self.p_z = p_z
        self.q_z = q_z
        self.loss_dist = loss_dist

    def encode(self, x):
        """Run the encoder and split its output into two equal halves along dim 1."""
        out = self.encoder(x)
        # Use the feature dimension directly (works for an empty batch, unlike len(out[0])).
        half = out.shape[1] // 2
        return out[:, :half], out[:, half:]

    def reparameterize(self, q_z_given_x):
        """Draw a differentiable sample z ~ q(z | x) via ``rsample``."""
        return q_z_given_x.rsample()

    def decode(self, z):
        """Map latent samples through the decoder network."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample, decode and score a batch.

        Args:
            x: batch of images; flattened with ``view(-1, 784)``.

        Returns:
            ``(x_hat, loss)``: the decoder output, and the negative ELBO summed
            over the batch as a scalar tensor.
        """
        x = x.view(-1, 784)
        # The meaning of the two halves (loc/scale, concentrations, ...) depends
        # on the family chosen for self.q_z; they are passed positionally.
        param_a, param_b = self.encode(x)

        # Treat the latent dimension as the event dimension of q(z | x).
        q_z_given_x = Independent(self.q_z(param_a, param_b), 1)
        z = self.reparameterize(q_z_given_x)
        x_hat = self.decode(z)

        # Treat the pixel dimension as the event dimension of p(x | z).
        p_x_given_z = Independent(self.loss_dist(x_hat), 1)

        loss = self.loss_function(x_hat, x, q_z_given_x, p_x_given_z, z)
        return x_hat, loss

    def loss_function(self, x_hat, x, q_z_given_x, p_x_given_z, z):
        """Negative ELBO summed over the batch: reconstruction NLL + KL(q || p).

        ``x_hat`` and ``z`` are unused; they are kept for interface compatibility.
        """
        # Reconstruction term -log p(x | z), summed over pixels and batch.
        # NOTE(review): with distribution argument validation enabled (torch's
        # default in recent versions), Bernoulli.log_prob rejects non-binary
        # targets; MNIST ToTensor pixels are continuous in [0, 1] — confirm the
        # torch version in use or binarize the inputs.
        recon_nll = torch.sum(-p_x_given_z.log_prob(x))
        # Analytic KL between the per-dimension base distributions. This uses the
        # prior stored on the instance, not a notebook-level global.
        kld = kl_divergence(q_z_given_x.base_dist, self.p_z.base_dist).sum()
        return recon_nll + kld

In [3]:
def normal_dist(mu, var):
    """Return a Normal distribution with location ``mu`` and scale ``var``.

    Despite its name, ``var`` is used as the standard deviation (``scale``),
    not as the variance.
    """
    return Normal(mu, var)

def laplace_dist(mu, var):
    """Return a Laplace distribution with location ``mu`` and scale ``var``.

    ``var`` is passed straight through as the Laplace scale parameter.
    """
    return Laplace(mu, var)

def gamma_dist(mu, var):
    """Return a Gamma distribution parameterised by shape and rate.

    ``mu`` is used as the concentration (shape) and ``var`` as the rate; they
    are not the mean and variance of the resulting distribution.
    """
    return Gamma(concentration=mu, rate=var)

def beta_dist(mu, var):
    """Return a Beta distribution with concentrations ``mu`` (alpha) and ``var`` (beta).

    Both arguments must be positive; they are not a mean and a variance.
    """
    return Beta(concentration1=mu, concentration0=var)

def bernoulli_loss(x_hat):
    """Bernoulli likelihood p(x | z) whose success probabilities are ``x_hat``."""
    return Bernoulli(probs=x_hat)

def laplace_loss(x_hat, scale=1e-2):
    """Laplace likelihood p(x | z) centred on the decoder output.

    Args:
        x_hat: decoder output, used as the Laplace location.
        scale: fixed Laplace scale. Defaults to 1e-2, the value that was
            previously hard-coded.

    Returns:
        A ``Laplace`` distribution with ``loc=x_hat`` and the given scale.
    """
    return Laplace(loc=x_hat, scale=scale)

In [4]:
def train(epoch):
    """Run one optimisation epoch over the notebook-level ``train_loader``.

    Uses the notebook-level ``model`` and ``optimizer``. Every 100th batch it
    prints that batch's loss divided by its size; at the end it prints the
    summed loss divided by the dataset size.

    Args:
        epoch: epoch number, used only in the log messages.
    """
    model.train()
    running_loss = 0
    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    for step, (batch, _labels) in enumerate(train_loader):
        optimizer.zero_grad()
        _, batch_loss = model(batch)
        batch_loss.backward()
        running_loss += batch_loss.item()
        optimizer.step()
        if step % 100 == 0:
            seen = step * len(batch)
            progress = 100. * step / n_batches
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, n_samples, progress,
                batch_loss.item() / len(batch)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, running_loss / n_samples))


def test(epoch):
    """Evaluate the notebook-level ``model`` on ``test_loader`` and save reconstructions.

    For the first batch, up to 8 inputs are saved next to their reconstructions
    in ``results/reconstruction_<epoch>.png``. Prints the summed loss divided by
    the test-set size.

    Args:
        epoch: epoch number, used in the log message and the output filename.
    """
    import os

    model.eval()
    test_loss = 0
    # save_image does not create missing directories.
    os.makedirs('results', exist_ok=True)
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            x_hat, loss = model(data)
            test_loss += loss.item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape only the first n reconstructions. The old
                # view(128, ...) failed whenever the batch size was not 128.
                recon = x_hat[:n].view(n, 1, 28, 28)
                comparison = torch.cat([data[:n], recon])
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [5]:
def load_data(batch_size, data_dir='../data'):
    """Load the MNIST train/test splits and wrap each in a DataLoader.

    Args:
        batch_size: mini-batch size for both loaders.
        data_dir: directory holding MNIST. The train split is downloaded there
            if missing.

    Returns:
        ``(train_data, train_loader, test_data, test_loader)``.
    """
    to_tensor = transforms.ToTensor()
    train_data = datasets.MNIST(data_dir, train=True, download=True,
                                transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size, shuffle=True)

    test_data = datasets.MNIST(data_dir, train=False, transform=to_tensor)
    # Bug fix: this loader previously wrapped train_data, so the reported test
    # loss was measured on the training set. A fixed order also keeps the
    # per-epoch reconstruction images comparable.
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size, shuffle=False)
    return train_data, train_loader, test_data, test_loader

In [33]:
x_dim = 784
z_dim = 2

encoder_layers = [
    nn.Linear(x_dim, 400),
    nn.ReLU(True),
    nn.Linear(400, 40),
    nn.ReLU(True),
    nn.Linear(40, z_dim*2),
    nn.Softplus()  # keeps both posterior parameters positive (required by Beta)
    ]

decoder_layers = [
    nn.Linear(z_dim, 40),
    nn.ReLU(True),
    nn.Linear(40, 400),
    nn.ReLU(True),
    nn.Linear(400, x_dim),
    nn.Sigmoid()
    ]

lr = 1e-3
batch_size = 128
epochs = 20
# Set True to decode 64 prior samples to results/sample_<epoch>.png after each epoch.
save_samples = False

# Seed weight init, data shuffling and latent sampling for reproducible runs.
torch.manual_seed(0)

# Prior p(z): factorised Beta(0.3, 0.3) over every latent dimension, sized by z_dim.
# Standard-normal alternative: Independent(Normal(torch.zeros(z_dim), torch.ones(z_dim)), 1)
p_z = Independent(Beta(torch.full((z_dim,), 0.3), torch.full((z_dim,), 0.3)), 1)
# Approximate posterior family q(z | x), built from the two encoder-output halves.
q_z = beta_dist

# Likelihood p(x | z) used for the reconstruction term.
loss_dist = bernoulli_loss

train_data, train_loader, test_data, test_loader = load_data(batch_size)
model = VAE(encoder_layers, decoder_layers, p_z, q_z, loss_dist)
optimizer = optim.Adam(model.parameters(), lr=lr)
if __name__ == "__main__":
    for epoch in range(1, epochs + 1):
        train(epoch)
        test(epoch)
        if save_samples:
            with torch.no_grad():
                # Sample latent codes from the prior so their dimension always
                # matches z_dim. The earlier draft used torch.randn(64, 20).
                sample = model.decode(p_z.sample((64,)))
                save_image(sample.view(64, 1, 28, 28),
                           'results/sample_' + str(epoch) + '.png')




KeyboardInterrupt: 

In [24]:
# Decode a regular grid of 2-D latent points and tile the decoded digits into one image.
# NOTE(review): the grid spans [0, 4) per axis, but a Beta prior/posterior has
# support (0, 1), so most of this grid lies outside it — adjust for the model in use.
# This cell also assumes z_dim == 2 and reuses `epoch` left over from the training loop.
grid_start, grid_stop, grid_step = 0, 4, .05
xv = np.arange(grid_start, grid_stop, grid_step)
yv = np.arange(grid_start, grid_stop, grid_step)
# Row-major order: point k = i * len(yv) + j is (xv[i], yv[j]).
grid_x, grid_y = np.meshgrid(xv, yv, indexing='ij')
sample = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

with torch.no_grad():
    images = model.decode(torch.tensor(sample, dtype=torch.float)).numpy()
# Tile block (i, j) of the output holds the 28x28 decode of (xv[i], yv[j]).
image = (images.reshape(len(xv), len(yv), 28, 28)
         .transpose(0, 2, 1, 3)
         .reshape(len(xv) * 28, len(yv) * 28)
         .astype(np.float64))

save_image(torch.tensor(image),'results/sample_norm2laplace' + str(epoch) + '.png')

In [71]:
x_dim = 784
z_dim = 2
# Smaller single-hidden-layer architecture. This rebinds encoder_layers /
# decoder_layers; they only take effect once a new VAE is built from them.
encoder_layers = [
    nn.Linear(x_dim, 128),
    nn.ReLU(True),
    nn.Linear(128, z_dim*2)
    ]

decoder_layers = [
    nn.Linear(z_dim, 128),  # was hard-coded as 2; must match the latent size
    nn.ReLU(True),
    nn.Linear(128, x_dim)
    ]
print(encoder_layers)

[Linear(in_features=784, out_features=128, bias=True), ReLU(inplace), Linear(in_features=128, out_features=4, bias=True)]
