In [8]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg

import torch
import geomloss

# we start with the trivial_1 case

In [9]:
def langevin_step(sample, theta, eps):
    """Take one unadjusted Langevin step toward ``theta``.

    The drift uses ``grad_V = sample - theta``, i.e. the gradient of the
    quadratic potential ``V(x) = 0.5 * (x - theta) ** 2``, and the update is

        sample - eps * grad_V + sqrt(2 * eps) * noise

    where ``noise`` is standard normal with the same shape, dtype and device
    as ``sample`` (drawn from torch's global RNG via ``torch.randn_like``).

    Args:
        sample: Floating-point tensor with the current positions.
        theta: Tensor or scalar broadcastable against ``sample``; the center
            of the potential.
        eps: Step size; a Python number or a tensor. When it is a tensor that
            requires grad, gradients flow through both the drift and the
            noise scale.

    Returns:
        Tensor holding the updated positions.
    """
    noise = torch.randn_like(sample)

    grad_V = sample - theta

    # Build the noise scale in the dtype/device of ``sample``. Unlike
    # ``torch.tensor(2 * eps)``, ``as_tensor`` neither detaches a tensor
    # ``eps`` from the autograd graph nor forces a float32 CPU scalar.
    noise_scale = torch.sqrt(
        torch.as_tensor(2 * eps, dtype=sample.dtype, device=sample.device)
    )

    return sample - eps * grad_V + noise_scale * noise

In [10]:
def evolve_langevin(samples, theta, eps=1e-2, n_evolution=10, seed=0):
    """Run ``n_evolution`` successive ``langevin_step`` updates on a copy of ``samples``.

    Note that torch's global RNG is re-seeded with ``seed`` on every call,
    so the global random state is overwritten as a side effect.

    Args:
        samples: Tensor of starting positions; it is cloned, not modified.
        theta: Forwarded unchanged to ``langevin_step``.
        eps: Step size forwarded to ``langevin_step``.
        n_evolution: Number of Langevin steps to apply.
        seed: Value passed to ``torch.manual_seed`` before stepping.

    Returns:
        The evolved tensor (a plain clone of ``samples`` when
        ``n_evolution`` is 0).
    """
    torch.manual_seed(seed)

    current = samples.clone()
    for _ in range(n_evolution):
        current = langevin_step(current, theta, eps)
    return current

In [11]:
# Sinkhorn loss object from geomloss (p=2, blur=0.05), shared by every lm_loss call.
sinkhorn_loss = geomloss.SamplesLoss(
    loss="sinkhorn",
    p=2,
    blur=0.05,
)

def lm_loss(samples, theta):
    """Sinkhorn loss between ``samples`` and their Langevin-evolved copy.

    ``samples`` are evolved with ``evolve_langevin`` using its default
    step size, step count and seed, and the module-level ``sinkhorn_loss``
    compares the original points with the evolved ones.

    Args:
        samples: Tensor of data points; a 1-D tensor is treated as a column
            of scalar points.
        theta: Parameter forwarded to ``evolve_langevin``.

    Returns:
        Whatever ``sinkhorn_loss`` returns for the two point sets
        (``lm_optimize`` calls ``.backward()`` and ``.item()`` on it).
    """
    def as_column(points):
        # 1-D inputs become an (N, 1) tensor; anything else passes through.
        return points.unsqueeze(1) if points.ndim == 1 else points

    evolved = evolve_langevin(samples, theta)
    return sinkhorn_loss(as_column(samples), as_column(evolved))

In [12]:
def lm_optimize(samples, n_epochs=300, lr=1e-2):
    """Fit a scalar ``theta`` by minimizing ``lm_loss(samples, theta)`` with Adam.

    ``theta`` starts at 0.0 and progress is printed every 50 epochs.

    Args:
        samples: Data tensor passed to ``lm_loss`` on every epoch.
        n_epochs: Number of optimizer steps.
        lr: Adam learning rate.

    Returns:
        The fitted ``theta`` as a detached 0-dim tensor.
    """
    log_every = 50

    theta = torch.tensor(0.0, requires_grad=True)
    adam = torch.optim.Adam([theta], lr=lr)

    for t in range(n_epochs):
        adam.zero_grad()
        loss = lm_loss(samples, theta)
        loss.backward()
        adam.step()

        if t % log_every != 0:
            continue
        # The loss shown was evaluated before this epoch's update, while
        # theta is printed after the update has been applied.
        print(f"step {t} | loss = {loss.item():.6f} | theta = {theta.item():.4f}")

    return theta.detach()

In [13]:
true_theta = 10.0

# Draw the synthetic data from a dedicated, seeded generator so that a
# Restart & Run All reproduces the same dataset. The seed deliberately
# differs from evolve_langevin's default seed (0) so the data noise is not
# the same random stream as the Langevin noise.
data_rng = torch.Generator().manual_seed(1234)
samples = torch.randn(500, generator=data_rng) + true_theta

theta_hat = lm_optimize(samples, n_epochs=3000, lr=1e-2)

print(f"\nfinal theta: {theta_hat.item():.4f}")

step 0 | loss = 0.458554 | theta = 0.0100
step 50 | loss = 0.414184 | theta = 0.5068
step 100 | loss = 0.372991 | theta = 0.9924
step 150 | loss = 0.334942 | theta = 1.4654
step 200 | loss = 0.299885 | theta = 1.9257
step 250 | loss = 0.267665 | theta = 2.3730
step 300 | loss = 0.238133 | theta = 2.8074
step 350 | loss = 0.211141 | theta = 3.2287
step 400 | loss = 0.186544 | theta = 3.6368
step 450 | loss = 0.164200 | theta = 4.0315
step 500 | loss = 0.143971 | theta = 4.4128
step 550 | loss = 0.125720 | theta = 4.7806
step 600 | loss = 0.109314 | theta = 5.1348
step 650 | loss = 0.094625 | theta = 5.4753
step 700 | loss = 0.081526 | theta = 5.8021
step 750 | loss = 0.069897 | theta = 6.1150
step 800 | loss = 0.059618 | theta = 6.4142
step 850 | loss = 0.050578 | theta = 6.6996
step 900 | loss = 0.042666 | theta = 6.9712
step 950 | loss = 0.035778 | theta = 7.2291
step 1000 | loss = 0.029817 | theta = 7.4733
step 1050 | loss = 0.024687 | theta = 7.7041
step 1100 | loss = 0.020300 | the