In [68]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [69]:
def langevin_step(sample, theta, eps, key):
    
    noise = jax.random.normal(key, shape=sample.shape)
    grad_V = sample - theta

    return sample - eps * grad_V + jnp.sqrt(2 * eps) * noise

In [None]:
def evolve_langevin(samples, theta, counter, eps=1e-2, n_evolution=1, seed=0):
    
    key = jax.random.PRNGKey(seed + counter)
    samples_evolved = samples

    for i in range(n_evolution):
        key, subkey = jax.random.split(key)
        samples_evolved = langevin_step(samples_evolved, theta, eps, subkey)

    return samples_evolved

In [75]:
from ott.tools import sinkhorn_divergence

def lm_loss(samples, theta, counter):

    samples = samples[:, None]

    evolved_samples = evolve_langevin(samples, theta, counter)

    loss = sinkhorn_divergence.sinkdiv(samples, evolved_samples)[0]

    return loss

In [76]:
def lm_optimize(samples, n_epochs=300, lr=1e-2):

    theta = jnp.array(0.0)

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(theta)
    counter = 0

    def loss_fn(theta):
        return lm_loss(samples, theta, counter)

    for t in range(n_epochs):
        loss_val, grad = jax.value_and_grad(loss_fn)(theta)
        updates, opt_state = optimizer.update(grad, opt_state)
        theta = optax.apply_updates(theta, updates)
        counter += 1

        if t % 5 == 0:
            print(f"step {t} | loss = {loss_val:.6f} | theta = {theta:.4f}")

    return theta

In [77]:
true_theta = 10.0
key = jax.random.PRNGKey(0)

samples = jax.random.normal(key, shape=(500,)) + true_theta

theta_hat = lm_optimize(samples, n_epochs=3000, lr=0.5)

print(f"\nfinal theta: {theta_hat:.4f}")

0
step 0 | loss = 0.013957 | theta = 0.5000
1
2
3
4
5
step 5 | loss = 0.006663 | theta = 2.9385
6
7
8
9
10
step 10 | loss = 0.002856 | theta = 5.2465
11
12
13
14
15
step 15 | loss = 0.001306 | theta = 7.3427
16
17
18
19
20
step 20 | loss = 0.000684 | theta = 9.0574
21
22
23
24
25
step 25 | loss = 0.000574 | theta = 10.2947
26
27
28
29
30
step 30 | loss = 0.000402 | theta = 11.0322
31
32
33
34
35
step 35 | loss = 0.000887 | theta = 11.3354
36
37
38
39
40
step 40 | loss = 0.001094 | theta = 11.2074
41
42
43
44
45
step 45 | loss = 0.000726 | theta = 10.8241
46
47
48
49
50
step 50 | loss = 0.000413 | theta = 10.3558
51
52
53
54
55
step 55 | loss = 0.000311 | theta = 10.0120
56
57
58
59
60
step 60 | loss = 0.000847 | theta = 9.8248
61
62
63
64
65
step 65 | loss = 0.001411 | theta = 9.7615
66
67
68
69
70
step 70 | loss = 0.000458 | theta = 9.9274
71
72
73
74
75
step 75 | loss = 0.000398 | theta = 10.0421
76
77
78
79
80
step 80 | loss = 0.000311 | theta = 10.0433
81
82
83
84
85
step 85 | loss

KeyboardInterrupt: 