In [1]:
import torch
import geomloss

import matplotlib.pyplot as plt

In [2]:
def generate_gaussian_params(d, sigma_mu=0.1, sigma_cov=0.1, seed=0):
    
    torch.manual_seed(seed)
    
    mu = sigma_mu * torch.randn(d)

    A = sigma_cov * torch.randn(d, d)
    cov = A @ A.T + 1e-2 * torch.eye(d)

    return mu, cov

def generate_gaussian_data(mu, cov, n_samples, seed=0):

    torch.manual_seed(seed)

    L = torch.linalg.cholesky(cov)
    d = mu.shape[0]

    z = torch.randn(n_samples, d)
    samples = mu + z @ L.T

    return samples

In [3]:
def potential_grad(samples, mu, cov):

    diff = samples - mu
    precision = torch.linalg.inv(cov)

    grad = torch.linalg.solve(cov, diff.T).T
    
    return grad

In [4]:
def evolve_langevin(samples, mu, cov, eps=1e-2, n_evolution=1, seed=0):

    torch.manual_seed(seed)
    evolved_samples = samples.clone()

    for i in range(n_evolution):
        grad = potential_grad(evolved_samples, mu, cov)
        noise = torch.randn_like(evolved_samples)
        evolved_samples = evolved_samples - eps * grad + torch.sqrt(torch.tensor(2 * eps)) * noise

    return evolved_samples

In [5]:
sinkhorn_loss = geomloss.SamplesLoss(loss="sinkhorn", p=2, blur=0.1)
energy_loss = geomloss.SamplesLoss(loss='energy')

def lm_loss(samples, mu, cov):

    evolved_samples = evolve_langevin(samples, mu, cov)

    return energy_loss(samples, evolved_samples)

In [6]:
def lm_optimize(samples, n_epochs=1000, lr=1e-2, seed=10, eps=1e-2, plot_every=500):

    d = samples.shape[1]
    torch.manual_seed(seed)

    mu = torch.randn(d, requires_grad=True)
    a = torch.randn(d, d, requires_grad=True)

    optimizer = torch.optim.Adam([mu, a], lr=lr)

    for epoch in range(n_epochs):
        optimizer.zero_grad()
        cov = a.T @ a
        loss = lm_loss(samples, mu, cov)
        loss.backward()
        optimizer.step()

        if epoch % plot_every == 0 or epoch == n_epochs - 1:
            print(f"epoch {epoch} | loss = {loss.item():.6f} | mean = {mu} | cov = {cov}")

    return mu.detach(), cov.detach()

In [7]:
d = 1
mu_true, cov_true = generate_gaussian_params(d, sigma_mu=10, sigma_cov=2, seed=0)
samples = generate_gaussian_data(mu_true, cov_true, n_samples=500, seed=0)
print(mu_true, "\n\n", cov_true, "\n\n\n\n")

mu_hat, cov_hat = lm_optimize(samples, n_epochs=40000, lr=1e-2)

print("\n--------- results: ---------")
print("mu true:     ", mu_true)
print("mu hat:  ", mu_hat)
print("cov true:\n", cov_true)
print("cov hat:\n", cov_hat)

tensor([15.4100]) 

 tensor([[0.3544]]) 




epoch 0 | loss = 0.013927 | mean = tensor([-0.5914], requires_grad=True) | cov = tensor([[1.0246]], grad_fn=<MmBackward0>)
epoch 500 | loss = 0.004637 | mean = tensor([1.1505], requires_grad=True) | cov = tensor([[4.8351]], grad_fn=<MmBackward0>)
epoch 1000 | loss = 0.004493 | mean = tensor([2.0515], requires_grad=True) | cov = tensor([[6.5231]], grad_fn=<MmBackward0>)
epoch 1500 | loss = 0.004450 | mean = tensor([2.7948], requires_grad=True) | cov = tensor([[7.7272]], grad_fn=<MmBackward0>)
epoch 2000 | loss = 0.004430 | mean = tensor([3.4765], requires_grad=True) | cov = tensor([[8.6654]], grad_fn=<MmBackward0>)
epoch 2500 | loss = 0.004420 | mean = tensor([4.1342], requires_grad=True) | cov = tensor([[9.4064]], grad_fn=<MmBackward0>)
epoch 3000 | loss = 0.004413 | mean = tensor([4.7882], requires_grad=True) | cov = tensor([[9.9671]], grad_fn=<MmBackward0>)
epoch 3500 | loss = 0.004410 | mean = tensor([5.4601], requires_grad=True) | cov = 

  """Creates a criterion that computes distances between sampled measures on a vector space.
  """Implements kernel ("gaussian", "laplacian", "energy") norms between sampled measures.


KeyboardInterrupt: 