In [None]:
import tensorflow as tf

In [None]:
def tf_loss(x, y, logvar, mu, kl_tolerance, z_size):
    """Return the VAE loss: reconstruction error plus a floored KL term.

    Args:
        x, y: tensors with at least four axes; the squared difference is
            summed over axes 1, 2 and 3 and then averaged over what is left
            (presumably the batch axis).
        logvar, mu: tensors reduced along axis 1 (presumably the latent axis).
        kl_tolerance: scalar; multiplied by ``z_size`` to form the KL floor.
        z_size: scalar; number of latent dimensions used for the KL floor.

    Returns:
        Scalar tensor ``r_loss + kl_loss``.
    """
    # Reconstruction term: per-sample sum of squared errors, then the mean.
    r_loss = tf.reduce_mean(
        tf.reduce_sum(tf.square(x - y), axis=[1, 2, 3])
    )

    # KL divergence to a standard normal, summed along axis 1 ...
    kl_loss = -0.5 * tf.reduce_sum(
        1 + logvar - tf.square(mu) - tf.exp(logvar),
        axis=1,
    )
    # ... then floored per sample at kl_tolerance * z_size, so values below
    # that threshold are replaced by the constant floor before averaging.
    kl_loss = tf.maximum(kl_loss, kl_tolerance * z_size)
    kl_loss = tf.reduce_mean(kl_loss)

    # The original computed this sum but never returned it.
    return r_loss + kl_loss

In [None]:
# KL floor settings: the losses floor the per-sample KL at
# kl_tolerance * z_size (0.5 * 32 = 16 with these values).
kl_tolerance, z_size = 0.5, 32

In [None]:
# Two independent standard-normal float32 tensors of shape [1, 1, 3, 5].
x, y = (
    tf.random.normal([1, 1, 3, 5], dtype=tf.float32),
    tf.random.normal([1, 1, 3, 5], dtype=tf.float32),
)

In [None]:
def r_loss_tf(x, y):
    """Reconstruction loss: squared error summed per sample, then averaged.

    Args:
        x, y: tensors with at least four axes; axes 1, 2 and 3 are summed out.

    Returns:
        Scalar tensor: the mean over the remaining axis of the per-sample sums.
    """
    # `axis` replaces the deprecated `reduction_indices` alias.
    per_sample = tf.reduce_sum(tf.square(x - y), axis=[1, 2, 3])
    return tf.reduce_mean(per_sample)

In [None]:
def kld_tf(logvar, mu, kl_tolerance=0.5, z_size=32):
    """KL divergence to a standard normal with a per-sample floor.

    Args:
        logvar, mu: tensors reduced along axis 1 (presumably the latent axis).
        kl_tolerance: scalar; multiplied by ``z_size`` to form the floor.
        z_size: scalar; number of latent dimensions used for the floor.

    Returns:
        Scalar tensor: mean of ``max(kl_per_sample, kl_tolerance * z_size)``.
    """
    # `axis` replaces the deprecated `reduction_indices` alias.
    kl_per_sample = -0.5 * tf.reduce_sum(
        1 + logvar - tf.square(mu) - tf.exp(logvar),
        axis=1,
    )
    # Floor each sample's KL before averaging.
    kl_floored = tf.maximum(kl_per_sample, kl_tolerance * z_size)
    return tf.reduce_mean(kl_floored)

In [None]:
# Display a float32 constant tensor filled with 2.0, shape [1, 1, 2, 3].
tf.constant(value=2.0, dtype=tf.float32, shape=[1, 1, 2, 3])

In [None]:
# Display a constant tensor filled with 1.0, shape [1, 1, 2, 3] (dtype inferred).
tf.constant(value=1.0, shape=[1, 1, 2, 3])

In [None]:
# Constant inputs make the result checkable by hand: every element differs
# by 1, so each sample sums to 1 * 2 * 3 = 6 and the batch mean is 6.
x = tf.constant(1.0, shape=[2, 1, 2, 3], dtype=tf.float32)
y = tf.constant(2.0, shape=[2, 1, 2, 3], dtype=tf.float32)
with tf.Session() as sess:
    # Evaluate both constants in one fetch; x and y become numpy arrays.
    x, y = sess.run([x, y])
    print(sess.run(r_loss_tf(x, y)))

In [None]:
# x plays the role of logvar (all 1.0) and y the role of mu (all 2.0).
x = tf.constant(1.0, shape=[2, 32], dtype=tf.float32)
y = tf.constant(2.0, shape=[2, 32], dtype=tf.float32)
with tf.Session() as sess:
    # Evaluate both constants in one fetch; x and y become numpy arrays.
    x, y = sess.run([x, y])
    print(sess.run(kld_tf(logvar=x, mu=y)))

In [None]:
import torch
import torch.nn.functional as F

In [None]:
def r_loss_pytorch(x, y, reduction='none'):
    """Squared-error reconstruction loss between ``x`` and ``y``.

    Args:
        x, y: tensors of matching shape.
        reduction: forwarded to ``F.mse_loss`` ('none', 'mean' or 'sum').
            The default 'none' returns the element-wise squared error with
            the same shape as the inputs.

    Returns:
        Tensor produced by ``F.mse_loss`` for the chosen reduction.

    Note:
        A per-sample sum averaged over the batch (the quantity `r_loss_tf`
        computes) is ``r_loss_pytorch(x, y, reduction='sum') / x.size(0)``.
    """
    return F.mse_loss(x, y, reduction=reduction)

In [None]:
def kld_loss_pytorch(logvar, mu, kl_tolerance=None, z_size=None):
    """KL divergence to a standard normal, averaged over dim 0.

    Args:
        logvar, mu: tensors reduced along dim 1 (presumably the latent axis).
        kl_tolerance: optional scalar. When given, each sample's KL is
            floored at ``kl_tolerance * z_size`` before averaging, as
            `kld_tf` does. ``None`` (default) applies no floor.
        z_size: number of latent dimensions used for the floor; defaults to
            ``logvar.size(1)``.

    Returns:
        Scalar tensor with the mean (optionally floored) KL.
    """
    kld_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    if kl_tolerance is not None:
        if z_size is None:
            z_size = logvar.size(1)
        kld_per_sample = torch.clamp(kld_per_sample, min=kl_tolerance * z_size)
    return torch.mean(kld_per_sample)

In [None]:
# Constant inputs of shape (2, 1, 2, 3): x is all ones, y is all twos.
x = torch.ones((2, 1, 2, 3))
y = 2 * torch.ones((2, 1, 2, 3))

In [None]:
# Display the squared error between the two constant tensors.
r_loss_pytorch(x=x, y=y)

In [None]:
# Constant inputs of shape (2, 32): x is all ones, y is all twos.
x = torch.ones((2, 32))
y = 2 * torch.ones((2, 32))

In [None]:
# x is passed as the log-variance and y as the mean.
kld_loss_pytorch(logvar=x, mu=y)

In [None]:
import torch
import torch.nn as nn
from utils import initialize_weights
from utils.gelu import GELU


class InferenceNetwork(nn.Module):
    """Encoder: one hidden layer followed by linear mean and log-variance heads.

    ``params`` must expose ``input_dim``, ``hidden_dim`` and ``latent_dim``.
    """

    def __init__(self, params):
        super(InferenceNetwork, self).__init__()
        self.params = params
        self.fc = nn.Linear(in_features=self.params.input_dim, out_features=self.params.hidden_dim)
        self.fc_mu = nn.Linear(in_features=self.params.hidden_dim, out_features=self.params.latent_dim)
        self.fc_logvar = nn.Linear(in_features=self.params.hidden_dim, out_features=self.params.latent_dim)
        self.activation_fn = GELU()
        initialize_weights(self)

    def forward(self, x):
        """Encode ``x`` into the parameters of the approximate posterior.

        Args:
            x: tensor that can be viewed as ``(-1, params.input_dim)``.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(N, params.latent_dim)``.
        """
        x = x.view(-1, self.params.input_dim)
        h1 = self.activation_fn(self.fc(x))
        # The heads stay linear. Passing them through the activation would
        # bound mu and logvar from below, so the mean could not go negative
        # and the variance could not shrink.
        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar


class GenerativeNetwork(nn.Module):
    """Decoder: latent code -> hidden layer -> sigmoid output.

    ``params`` must expose ``latent_dim``, ``hidden_dim`` and ``input_dim``.
    """

    def __init__(self, params):
        super(GenerativeNetwork, self).__init__()
        self.params = params
        self.fc1 = nn.Linear(
            in_features=self.params.latent_dim,
            out_features=self.params.hidden_dim,
        )
        self.fc2 = nn.Linear(
            in_features=self.params.hidden_dim,
            out_features=self.params.input_dim,
        )
        self.activation_fn = GELU()
        initialize_weights(self)

    def forward(self, z):
        """Map latent codes ``z`` to outputs in (0, 1) of width ``input_dim``."""
        hidden = self.fc1(z)
        hidden = self.activation_fn(hidden)
        logits = self.fc2(hidden)
        return torch.sigmoid(logits)


class VariationalAutoencoder(nn.Module):
    """VAE that wires an InferenceNetwork encoder to a GenerativeNetwork decoder.

    ``params`` is passed to both sub-networks; ``sample`` additionally reads
    ``params.latent_dim`` and ``params.num_examples_to_generate``.
    """

    def __init__(self, params):
        super(VariationalAutoencoder, self).__init__()
        self.params = params
        self.inference_network = InferenceNetwork(params=params)
        self.generative_network = GenerativeNetwork(params=params)

    def sample(self, eps=None):
        """Decode latent codes into images of shape ``(N, 1, 28, 28)``.

        Args:
            eps: optional latent codes of shape ``(N, latent_dim)``. When
                omitted, ``num_examples_to_generate`` codes are drawn from a
                standard normal.

        Returns:
            Decoder output viewed as ``(N, 1, 28, 28)``.
        """
        if eps is None:
            # The decoder consumes latent_dim features, and one code is
            # needed per generated example.
            eps = torch.randn(self.params.num_examples_to_generate, self.params.latent_dim)
        # Use the actual batch size so caller-supplied eps of any length works.
        return self.decode(eps).view(eps.size(0), 1, 28, 28)

    @staticmethod
    def reparameterization(mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def encode(self, x):
        """Return ``(mu, logvar)`` from the inference network."""
        mu, logvar = self.inference_network(x)
        return mu, logvar

    def decode(self, z):
        """Return the generative network's output for latent codes ``z``."""
        return self.generative_network(z)

    def forward(self, x):
        """Encode, sample a latent code, decode.

        Returns:
            Tuple ``(x_reconstructed, mu, logvar, z)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterization(mu, logvar)
        x_reconstructed = self.decode(z)
        return x_reconstructed, mu, logvar, z

In [None]:
import os

# Path to the experiment's hyper-parameter JSON file. Set the VAE_PARAMS_JSON
# environment variable to point at your own checkout; the original
# machine-specific absolute path is kept only as the fallback.
json_path = os.environ.get(
    'VAE_PARAMS_JSON',
    '/home/aktersnurra/Documents/Projects/variational-autoencoders/experiments/vae/params.json',
)

In [None]:
from utils.misc import create_dir, create_log_dir, load_checkpoint, save_checkpoint, tab_printer, Params, set_logger

In [None]:
# Build the project's Params object from the JSON file at json_path.
# Presumably it exposes the JSON keys as attributes (the model classes read
# params.input_dim, params.hidden_dim, params.latent_dim, ...) - confirm in utils.misc.
params = Params(json_path)

In [None]:
# Instantiate the model with the loaded hyper-parameters.
vae = VariationalAutoencoder(params=params)

In [None]:
# Random stand-in batch: 2 tensors of shape (1, 28, 28).
img = torch.randn((2, 1, 28, 28))

In [None]:
# Full forward pass: VariationalAutoencoder.forward returns the reconstruction,
# the posterior mean and log-variance, and the sampled latent code.
X_reconstructed, mu, logvar, z = vae(img)

In [None]:
# Inspect the log-variance shape; presumably (2, latent_dim) when
# params.input_dim is 784 so each image flattens to one row - confirm.
logvar.shape