In [None]:
import torch
from torch import nn, Tensor
import numpy as np
import matplotlib.pyplot as plt


# Choose device

In [None]:
# Pick the compute device once; every later cell refers to this `device` global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('cuda available' if device.type == 'cuda' else 'cuda not available')
# Optional experiment tracking (disabled):
# mlflow.log_param('device', torch.cuda.get_device_name(device) if device.type == 'cuda' else 'cpu')

# Define model

In [None]:
class MyModel(torch.nn.Module):
    """VAE-style model for inputs whose last dimension has size 1.

    ``mlp1`` (the encoder) maps the input to ``2 * latent_dim`` values that are
    split into the log standard deviation and the mean of a diagonal Gaussian
    over the latent code. ``mlp2`` (the decoder) maps a latent code of size
    ``latent_dim`` back to a single value.

    Args:
        latent_dim: Size of the latent code.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: 1 -> 64 -> 64 -> 2 * latent_dim (log-sigma and mu, concatenated).
        self.mlp1 = torch.nn.Sequential(
            torch.nn.Linear(1, 64),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64, latent_dim * 2),
        )
        # Decoder: latent_dim -> 64 -> 64 -> 1.
        self.mlp2 = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, 64),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64, 1),
        )

    def forward(self, x, return_latent=False):
        """Encode ``x``, draw one latent sample and decode it.

        Args:
            x: Tensor whose last dimension has size 1.
            return_latent: If True, also return the encoder outputs.

        Returns:
            ``(logsig_i, mu_i, reconstruction)`` if ``return_latent`` is True,
            otherwise the 1-tuple ``(reconstruction,)``.
        """
        logsig_mu_i = self.mlp1(x)
        # First half of the encoder output is log-sigma, second half is mu.
        logsig_i = logsig_mu_i[..., :self.latent_dim]
        mu_i = logsig_mu_i[..., self.latent_dim:]

        # Reparameterized sample: z = mu + sigma * eps with eps ~ N(0, I).
        eps = torch.randn_like(logsig_i)
        z_i = eps * torch.exp(logsig_i) + mu_i
        x = self.mlp2(z_i)

        if return_latent:
            return logsig_i, mu_i, x
        else:
            return (x, )

    def sample(self, n_samples=1):
        """Decode ``n_samples`` latent codes drawn from a standard normal.

        The noise is created on the same device and with the same dtype as the
        model parameters, so this works wherever the model lives and does not
        depend on any notebook-level ``device`` variable.

        Returns:
            Tensor of shape ``(n_samples, 1)``.
        """
        reference = next(self.parameters())
        z = torch.randn(n_samples, self.latent_dim,
                        device=reference.device, dtype=reference.dtype)
        return self.mlp2(z)

In [None]:
# Size of the latent space. A value of 2 would make the latent space easy to
# visualize; a higher value (16 is what is actually used here) is more useful in practice.
LATENT_DIM = 16

model = MyModel(latent_dim=LATENT_DIM).to(device=device)
print(model)
n_params = sum(p.numel() for p in model.parameters())
print('Total nr of parameters:', n_params)

# Training

In [None]:
def KL_loss(logsig_i, mu_i, logsig_f=torch.tensor(0.0), mu_f=torch.tensor(0.0)):
    """KL divergence KL(N(mu_i, sig_i^2) || N(mu_f, sig_f^2)) for diagonal Gaussians.

    Per dimension this is
    ``log(sig_f / sig_i) + (sig_i^2 + (mu_i - mu_f)^2) / (2 * sig_f^2) - 1/2``.
    The per-dimension terms are summed over the last axis and the result is
    averaged over all remaining axes.

    Args:
        logsig_i: Tensor with the log standard deviations of the first Gaussian.
        mu_i: Tensor with the means of the first Gaussian.
        logsig_f: Log standard deviation of the second Gaussian; a tensor or a
            plain Python number. Defaults to 0 (unit variance).
        mu_f: Mean of the second Gaussian; a tensor or a plain Python number.
            Defaults to 0.

    Returns:
        Scalar tensor with the mean KL divergence.
    """
    def _as_tensor(value):
        # Tensors are passed through untouched; plain numbers are promoted so
        # that torch.exp() below accepts them.
        if torch.is_tensor(value):
            return value
        return torch.as_tensor(value, dtype=logsig_i.dtype, device=logsig_i.device)

    logsig_f = _as_tensor(logsig_f)
    mu_f = _as_tensor(mu_f)

    # Work in the log domain: exp(2 * (a - b)) equals exp(a)**2 / exp(b)**2 but
    # avoids an inf / inf when both log-sigmas are large.
    var_ratio = torch.exp(2 * (logsig_i - logsig_f))
    scaled_sq_diff = (mu_f - mu_i) ** 2 * torch.exp(-2 * logsig_f)
    per_dim = 2 * (logsig_f - logsig_i) - 1 + var_ratio + scaled_sq_diff

    return torch.mean(0.5 * torch.sum(per_dim, dim=-1))

In [None]:
# Hyperparameters of the optimisation loop.
LEARNING_RATE = 1e-3
N_EPOCHS = 10000   # every "epoch" is one freshly drawn batch
BATCH_SIZE = 256
LOG_EVERY = 100    # print the losses every LOG_EVERY epochs
KL_increase_range = 50  # nr of epochs over which we increase the weight of the KL loss

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

for epoch in range(N_EPOCHS):
    # Linear warm-up of the KL weight from 0 to 0.5 over the first
    # KL_increase_range epochs. Kept as a plain Python float so that no NumPy
    # scalar is mixed into the autograd expression below.
    KL_weight = min(max(epoch / KL_increase_range, 0.0), 1.0) / 2

    # Training data: each entry is -1 or +1 with equal probability.
    x = torch.randint(0, 2, (BATCH_SIZE, 1), dtype=torch.float32, device=device)*2-1
    logsig_i, mu_i, out = model(x, return_latent=True)

    optimizer.zero_grad()

    MSE_loss = loss_fn(out, x)
    KL_loss_value = KL_loss(logsig_i, mu_i)
    # Convex combination of reconstruction and KL terms.
    loss = (1.0-KL_weight)*MSE_loss + KL_weight*KL_loss_value
    loss.backward()

    optimizer.step()

    if epoch % LOG_EVERY == 0:
        print(f'epoch {epoch}, total loss: {loss.item():.4f}, MSE loss: {MSE_loss.item():.4f}, KL loss: {KL_loss_value.item():.4f}, KL weight: {KL_weight:.4f}')

# Sampling

In [None]:
# Draw values from the trained decoder and plot their empirical density.
samples = model.sample(10000).detach().cpu().numpy()
bins = np.linspace(-2, 2, 50)

fig, ax = plt.subplots(figsize=(4, 3), dpi=200)
ax.hist(samples, bins=bins, density=True)
# ax.set_yscale('log')
ax.set_xlabel('x')
ax.set_ylabel('Density')
fig.tight_layout()
plt.show()

# Wasserstein distance

In [None]:
from scipy.stats import wasserstein_distance

In [None]:
# Compare generated values with the two-point set {-1, +1}, the values the
# training batches are drawn from.
n_samples = 10000
# Use n_samples here so that changing it above actually changes the sample size.
samples = model.sample(n_samples).cpu().detach().numpy().flatten()
real = np.array([-1, 1])

print('samples.shape:', samples.shape)
print('real.shape:', real.shape)

print('Wasserstein distance:', wasserstein_distance(samples, real))

In [None]:
# Sanity check of wasserstein_distance on a case that can be worked out by hand:
# moving a point mass at 0 onto equal masses at +1 and -1 shifts every bit of
# mass by exactly 1, so the printed distance should be 1.0.
point_mass = [0.0]
two_points = [1.0, -1.0]
print('Wasserstein distance:', wasserstein_distance(point_mass, two_points))