In [None]:
"""
In this notebook, we recreate the simple linear Ornstein-Uhlenbeck example from,

Hua, M., Laurière, M., & Vanden-Eijnden, E. (2024). A Simulation-Free Deep Learning 
Approach to Stochastic Optimal Control. arXiv preprint arXiv:2410.05163.

In this reference, the "simulation-free" aspect originates from:
(1) differentiating the original loss, which is expressed in an importance-sampling / Girsanov-theorem setting, and
(2) a clever choice of the reference measure (i.e., on-policy): essentially an independent stop-gradient copy of the policy being learned.

This example is similar to the linear Ornstein-Uhlenbeck example found in,

Nüsken, N., & Richter, L. (2021). Solving high-dimensional Hamilton–Jacobi–Bellman PDEs using neural networks: 
perspectives from the theory of controlled diffusions and measures on path space. Partial differential equations and applications, 2(4), 48.


Work in progress... 
(2/12/25): With some minor corrections and modifications, we are able to get fairly strong performance,
           at least at the level of loss values (e.g., compare to the rightmost plot in Figure 1 of the primary reference).
           The loss variance can still be quite high.
(2/11/25): Debugging and trying to reduce the training instability/variance.
           The controlled case shows a notable reduction in cost compared to the uncontrolled case, which is a good sign.

These computations were run on a MacBook Pro with an M4 Pro and 24 GB of memory.
"""

In [14]:
from functools import partial
import time
from typing import Callable

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Root PRNG key for the whole notebook; later cells split sub-keys off it.
key = jax.random.key(0)

# Kept as the cell's last expression so the device list is actually displayed.
jax.devices()

In [16]:
# Process dynamics
# d:  state dimension (A_sys and B_sys are d x d)
# nu: scale of the Gaussian perturbation added to the +/- identity matrices below
d = 20
nu = 0.1

# state cost: nothing is defined for it in this cell
# initial distribution - sqrt(1/2) N(0,I)
init_stddev = jnp.sqrt(0.5)

# two independent sub-keys: one for the drift matrix, one for the noise/control matrix
key, subkey1, subkey2 = jax.random.split(key, 3)
# drift matrix: -I plus a small random perturbation
A_sys = -jnp.eye(d) + nu * jax.random.normal(subkey1, shape=(d,d))
# matrix applied to the noise increments (and, in the controlled dynamics, to the control): I plus a small random perturbation
B_sys = jnp.eye(d) + nu * jax.random.normal(subkey2, shape=(d,d))

def process_step(prev_x: jnp.array, dw: jnp.array, dt: float):
    '''
    Advance the uncontrolled linear Ornstein-Uhlenbeck process by a single
    Euler-Maruyama step.

    Args:
        prev_x: state at the start of the step.
        dw: noise increment for this step; it is scaled by sqrt(dt) here.
        dt: step size.

    Returns:
        (next_x, prev_x): the (carry, per-step output) pair expected by
        jax.lax.scan, so the scan records the state at the *start* of each step.
    '''
    drift = dt * (A_sys @ prev_x)
    diffusion = jnp.sqrt(dt) * (B_sys @ dw)
    return prev_x + drift + diffusion, prev_x

def process_scan(x, c):
    # scan adapter: each input row packs [dw (first d entries), dt (everything after)]
    return process_step(prev_x=x, dw=c[:d], dt=c[d:])


In [None]:
# Sanity check: simulate and plot one uncontrolled path on a uniform time grid.
num_tsteps = 1000
dt = 0.01*jnp.ones((num_tsteps,1))
key, subkey = jax.random.split(key, 2)
dw = jax.random.normal(subkey, shape=(num_tsteps, d))

x_init = jnp.zeros((d,))
# each scan input row is [dw_k (d entries), dt_k]
x_final, x_walk = jax.lax.scan(f=process_scan, init=x_init, xs=jnp.concat((dw, dt), axis=1))

# process_step emits the state at the *start* of each step, so x_walk[k] lives at
# t_k = dt_0 + ... + dt_{k-1} (with t_0 = 0). Plotting against cumsum(dt) would
# shift every path one step to the right.
t_walk = jnp.insert(jnp.cumsum(dt[:-1]), 0, 0)

for i in range(d):
    plt.plot(t_walk, x_walk[:, i])
plt.xlabel('t')
plt.ylabel('x_i(t)')
plt.title('Uncontrolled process: one sample path, one line per coordinate');

In [None]:
# scan stacks one recorded state (length d) per step, so this should be (num_tsteps, d)
x_walk.shape

In [19]:
## MODEL

def random_layer_params(
        key: jax.random.key,
        in_dim: int,
        out_dim: int,
        b_scale: float = 0.0,
) -> tuple[jnp.array]:
    '''
    Draw the weight matrix and bias vector of one dense layer.

    Both are sampled uniformly from [-1/sqrt(in_dim), 1/sqrt(in_dim)); the
    bias draw is then multiplied by `b_scale`, so the default of 0.0 gives
    an all-zero bias.

    Args:
        key: PRNG key; it is split once into a weight key and a bias key.
        in_dim: number of inputs to the layer.
        out_dim: number of outputs of the layer.
        b_scale: multiplier applied to the sampled bias.

    Returns:
        (weights, biases) with shapes (out_dim, in_dim) and (out_dim,).
    '''
    weight_key, bias_key = jax.random.split(key, 2)
    bound = 1/jnp.sqrt(in_dim)

    weights = jax.random.uniform(
        weight_key, shape=(out_dim, in_dim), minval=-bound, maxval=bound,
    )
    unscaled_biases = jax.random.uniform(
        bias_key, shape=(out_dim,), minval=-bound, maxval=bound,
    )
    return (weights, b_scale * unscaled_biases)

def init_network_params(key: jax.random.key, layer_sizes: list[int], b_scale: float = 0.0) -> list[tuple[jnp.array, jnp.array]]:
    '''
    Initialise every dense layer of an MLP.

    Args:
        key: PRNG key, split into one sub-key per entry of `layer_sizes`.
        layer_sizes: widths [input, hidden..., output]; consecutive pairs
            define the (in_dim, out_dim) of each layer.
        b_scale: bias multiplier forwarded to random_layer_params.

    Returns:
        A list with len(layer_sizes) - 1 entries, one (weights, biases)
        tuple per layer, ordered from input to output.
    '''
    # There are only len(layer_sizes) - 1 layers, so zip leaves the last
    # sub-key unused. The split count is kept as-is on purpose: changing it
    # would presumably change the sub-keys and hence the initial weights
    # obtained from a given seed.
    keys = jax.random.split(key, len(layer_sizes))
    return [
        random_layer_params(k, in_dim, out_dim, b_scale)
        for in_dim, out_dim, k in zip(layer_sizes[:-1], layer_sizes[1:], keys)
    ]

def predict_single(params: list[tuple[jnp.array, jnp.array]], x: jnp.array, t: float) -> jnp.array:
    '''
    Evaluate the policy network at a single (state, time) pair.

    The input is the state with the time appended as one extra feature.
    The first layer uses a leaky ReLU (slope 0.1), every further hidden
    layer uses a plain ReLU, and the last layer is affine with no activation.

    Args:
        params: list of (weights, biases) tuples, one per layer, ordered
            from input to output.
        x: state vector.
        t: time at which the policy is evaluated.

    Returns:
        The output of the final affine layer.
    '''
    def relu(z: jnp.array):
        return jnp.maximum(0, z)

    def leaky_relu(z: jnp.array, alpha: float = 0.1):
        return jnp.maximum(alpha*z, z)

    # time enters the network as one more input feature
    h = jnp.append(x, t)

    readin_w, readin_b = params[0]
    h = leaky_relu(jnp.dot(readin_w, h) + readin_b)

    # hidden layers; empty when the network has only a read-in and a read-out layer
    for (w, b) in params[1:-1]:
        h = relu(jnp.dot(w, h) + b)

    readout_w, readout_b = params[-1]
    return jnp.dot(readout_w, h) + readout_b

# Vectorised policy evaluation: params are shared (None), while x and t are
# mapped over their leading axis, i.e. one (state, time) pair per row.
batch_predict = jax.vmap(predict_single, in_axes=(None, 0, 0))

In [20]:
def generate_dt_xinits_and_dw(
        key: jax.random.key,
        n_timesteps: float,
        final_time: float,
        n_walkers: int,
) -> tuple:
    '''
    Sample everything one training batch needs: a random time grid, the
    walkers' initial states and their noise increments.

    Args:
        key: PRNG key for this batch.
        n_timesteps: number of grid points including both end points; it is
            used as an array size, so it must be an integer count despite
            the float annotation.
        final_time: right end point of the time grid.
        n_walkers: number of independent walkers in the batch.

    Returns:
        (dt, x_inits, dw) where
        dt has shape (n_timesteps-1,) and holds the gaps between grid points,
        x_inits has shape (n_walkers, d) and is drawn as init_stddev * N(0, I),
        dw has shape (n_walkers, n_timesteps-1, d) and is standard normal.
    '''
    # split four ways even though the first sub-key is never used
    _, grid_key, init_key, noise_key = jax.random.split(key, 4)

    # time grid: sorted uniform interior points, then 0 and final_time pinned
    # to the front and back (insert indices refer to the interior array)
    interior = jnp.sort(
        jax.random.uniform(grid_key, shape=(n_timesteps-2,), minval=0, maxval=final_time)
    )
    grid = jnp.insert(interior, jnp.array((0, n_timesteps-2)), jnp.array((0, final_time)))
    step_sizes = jnp.diff(grid)

    # initial states
    starts = init_stddev * jax.random.normal(key=init_key, shape=(n_walkers, d))

    # noise increments, one per walker per step
    increments = jax.random.normal(key=noise_key, shape=(n_walkers, n_timesteps-1, d))

    return step_sizes, starts, increments

def controlled_step(policy_params: list[jnp.array], prev_x: jnp.array, dw: jnp.array, dt: float, t: float):
    '''
    One Euler-Maruyama step of the controlled process.

    prev_x is the state at time t and the first returned value is the state
    at time t + dt. The control is the policy evaluated at (prev_x, t) and
    enters through the same matrix B_sys as the noise.

    Returns (next_x, prev_x), the (carry, per-step output) pair for jax.lax.scan.
    '''
    drift = dt * (A_sys @ prev_x)
    control = dt * (B_sys @ predict_single(policy_params, prev_x, t))
    noise = jnp.sqrt(dt) * (B_sys @ dw)
    return prev_x + drift + control + noise, prev_x

def simulate_walker(params: list[jnp.array], x_init: jnp.array, dt: jnp.array, dw: jnp.array):
    '''
    Roll out one walker under the current policy and return its scalar
    surrogate loss.

    The simulated trajectory is wrapped in stop_gradient, so `params` only
    receives gradients through the policy evaluations made along the
    (frozen) path.

    Args:
        params: policy network parameters.
        x_init: initial state of the walker.
        dt: 1-D array of step sizes, one per step.
        dw: noise increments, one row of d entries per step.

    Returns:
        A + (Abar + Bbar + g) * C, where A is the discretised control cost,
        Abar its stop-gradient copy, Bbar the (currently disabled) state
        cost, g the terminal cost and C the sum of sqrt(dt) * <u, dw>.
    '''
    def scan_body(x, packed):
        # packed row layout: [dw_k (d entries), dt_k, t_k]
        return controlled_step(policy_params=params, prev_x=x, dw=packed[0:d], dt=packed[d], t=packed[d+1])

    # start time of every step: t[0] = 0 and t[1:] = t[:-1] + dt[:-1]
    t = jnp.insert(jnp.cumsum(dt[:-1]), 0, 0)

    # forward simulation with the policy (no gradients flow through the path)
    scan_inputs = jnp.concat((dw, dt[:, None], t[:, None]), axis=1)
    rollout = jax.lax.scan(f=scan_body, init=x_init, xs=scan_inputs)
    x_final, x_path = (jax.lax.stop_gradient(part) for part in rollout)

    # policy evaluations on the visited states; this is where gradients enter
    u = batch_predict(params, x_path, t)

    control_cost = 0.5 * (dt * (u**2).sum(axis=-1)).sum()
    control_cost_detached = jax.lax.stop_gradient(control_cost)
    stoch_integral = (jnp.sqrt(dt) * (u * dw).sum(axis=-1)).sum()
    state_cost = 0.0 # 0.5 * (dt * (x_path**2).sum(axis=-1)).sum()  # state cost
    terminal_cost = x_final.sum() # final state cost

    return control_cost + (control_cost_detached + state_cost + terminal_cost) * stoch_integral

# Vectorise over walkers: params and the time steps dt are shared (None);
# each walker gets its own initial state and its own noise increments (axis 0).
batch_walkers = jax.vmap(simulate_walker, in_axes=(None, 0, None, 0))

def simulation_free_loss(
    params: list[jnp.array],
    x_inits: jnp.array,
    dt: jnp.array,
    dw: jnp.array,
):
    '''
    Batch loss: the mean of the per-walker surrogate losses.

    Args:
        params: policy network parameters (the argument differentiated
            by the training code).
        x_inits: initial states, one row per walker.
        dt: step sizes shared by all walkers.
        dw: noise increments, leading axis indexes the walker.
    '''
    per_walker_loss = batch_walkers(params, x_inits, dt, dw)
    return per_walker_loss.mean()


In [21]:
## OPTIMIZATION

def moving_average_step(
    past_state: list[jnp.array],
    new_state: list[jnp.array],
    beta: float = 0.9,
) -> list[jnp.array]:
    '''
    Exponential-moving-average update applied layer by layer.

    Both states are lists of (weights, biases) tuples; each entry of the
    result is beta * past + (1 - beta) * new.
    '''
    def blend(old, new):
        return beta * old + (1-beta) * new

    averaged = []
    for (w_old, b_old), (w_new, b_new) in zip(past_state, new_state):
        averaged.append((blend(w_old, w_new), blend(b_old, b_new)))
    return averaged
    
def clip_grads(grads: list[jnp.array], clip_threshold: float = 1.0) -> list[jnp.array]:
    '''
    Clip gradients by norm, separately for every weight and bias array
    (each array is rescaled against its own norm, not a global one).
    '''
    clipped = []
    for w_grad, b_grad in grads:
        clipped.append((_clip(w_grad, clip_threshold), _clip(b_grad, clip_threshold)))
    return clipped

def _clip(g: jnp.array, clip_threshold: float, eps: float = 1e-7) -> jnp.array:
    '''
    Rescale `g` so that its norm does not exceed `clip_threshold`.

    The norm is floored at `eps` so an all-zero gradient does not divide by
    zero; arrays already within the threshold are returned with scale 1.
    '''
    g_norm = jnp.maximum(jnp.linalg.norm(g), eps)
    scale = jnp.minimum(clip_threshold/g_norm, 1.0)
    return scale * g

def gradient_step(
    params: list[jnp.array],
    grads: list[jnp.array],
    lr: float,
    weight_decay: float = 0,
) -> list[jnp.array]:
    '''
    Plain gradient-descent step with weight decay.

    Every weight and bias array is first shrunk by (1 - lr * weight_decay)
    and then moved by -lr times its gradient.
    '''
    shrink = 1 - lr * weight_decay

    stepped = []
    for (w, b), (dw, db) in zip(params, grads):
        stepped.append((shrink * w - lr * dw, shrink * b - lr * db))
    return stepped

def adamw_step(
    params: list[jnp.array],
    momentum: list[jnp.array],
    velocity: list[jnp.array],
    lr: float,
    weight_decay: float = 0,
    eps: float = 1e-8,
) -> list[jnp.array]:
    '''
    Parameter update from (already bias-corrected) first and second moments.

    Each array moves by -lr * momentum / (sqrt(velocity) + eps) after being
    shrunk by (1 - weight_decay).

    NOTE(review): the decay factor here is (1 - weight_decay), whereas
    gradient_step uses (1 - lr * weight_decay). Confirm this difference
    is intended.
    '''
    decay = 1 - weight_decay

    new_params = []
    for (w, b), (w_m, b_m), (w_v, b_v) in zip(params, momentum, velocity):
        w_direction = w_m / (jnp.sqrt(w_v) + eps)
        b_direction = b_m / (jnp.sqrt(b_v) + eps)
        new_params.append((decay * w - lr * w_direction, decay * b - lr * b_direction))
    return new_params

def cosine_lr_scheduler(min_lr: float, max_lr: float, current_epoch: int, epochs_per_cycle: int, decay_rate: float = 1.0) -> float:
    '''
    Cosine learning-rate schedule with restarts.

    Within each cycle of `epochs_per_cycle` epochs the rate falls from the
    cycle's peak towards `min_lr` along half a cosine; the peak itself is
    max_lr * decay_rate**current_epoch.
    '''
    peak_lr = decay_rate**current_epoch * max_lr
    cycle_phase = (current_epoch % epochs_per_cycle) / epochs_per_cycle * jnp.pi
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + jnp.cos(cycle_phase))

def decay_lr_scheduler(lr: float, decay_rate: float = 0.95) -> float:
    '''Geometric decay: return the learning rate multiplied by `decay_rate`.'''
    decayed_lr = lr * decay_rate
    return decayed_lr

In [22]:
## training settings

# network
# input is the state plus one time feature (d+1), two hidden layers of width 128, output is a d-dimensional control
layers = [d+1, 128, 128, d]
# multiplier on the randomly drawn biases (0.0 would start all biases at zero)
b_scale = 1.0
key, subkey = jax.random.split(key, 2)
params = init_network_params(subkey, layers, b_scale)

# loss

# number of time-grid points per batch (so n_timesteps-1 steps), the time horizon, and the batch size
n_timesteps = 500
final_time = 2
n_walkers = 2048

# optimization
n_epochs = 1000
# first/second moment accumulators start at zero, with the same layout as params;
# the zero start is compensated by the bias correction in update()
momentum = [(jnp.zeros_like(w), jnp.zeros_like(b)) for w,b in params]
velocity = [(jnp.zeros_like(w), jnp.zeros_like(b)) for w,b in params]
# moving-average coefficients for the first and second moments
beta1 = 0.9
beta2 = 0.999
# learning-rate range of the cosine schedule; training starts at max_lr
min_lr = 1e-4
max_lr = 3e-4
lr = max_lr
epochs_per_cycle = 100
weight_decay = 1e-5

# scheduler with everything but the epoch bound; decay_rate = 1.0 keeps the peak rate constant across cycles
lr_scheduler = partial(
    cosine_lr_scheduler,
    min_lr=min_lr,
    max_lr=max_lr,
    epochs_per_cycle=epochs_per_cycle,
    decay_rate = 1.0,
)


# loss value and gradient with respect to the first argument (the policy params)
value_and_grad_fn = jax.value_and_grad(simulation_free_loss, argnums=0)

@jax.jit
def update(
    params: list[jnp.array],
    dt: jnp.array,
    x_inits: jnp.array,
    dw: jnp.array,
    lr: float,
    momentum: list[jnp.array],
    velocity: list[jnp.array],
    epoch: int,
)->tuple:
    '''
    One jitted training step: loss and gradient, per-array gradient clipping,
    moment updates with bias correction, then the AdamW-style parameter step.

    beta1, beta2 and weight_decay are read from the enclosing notebook scope.

    Returns:
        (loss_val, new_params, momentum, velocity); the returned moments are
        the uncorrected moving averages, to be fed back in on the next call.
    '''
    loss_val, grads = value_and_grad_fn(params, x_inits, dt, dw)
    clipped_grads = clip_grads(grads, 1.0)

    # adamW, based on:
    # Loshchilov, I., & Hutter, F. (2017). Fixing weight decay regularization in adam.
    # arXiv preprint arXiv:1711.05101, 5.
    squared_grads = [(g_w**2, g_b**2) for g_w, g_b in clipped_grads]
    momentum = moving_average_step(momentum, clipped_grads, beta1)
    velocity = moving_average_step(velocity, squared_grads, beta2)

    # bias corrections for the zero-initialised moving averages
    step_count = epoch + 1
    m_correction = 1-beta1**step_count
    v_correction = 1-beta2**step_count
    mhat = [(m_w / m_correction, m_b / m_correction) for m_w, m_b in momentum]
    vhat = [(v_w / v_correction, v_b / v_correction) for v_w, v_b in velocity]

    new_params = adamw_step(params, mhat, vhat, lr, weight_decay, eps=1e-4)
    return loss_val, new_params, momentum, velocity

In [None]:
# Training loop: each epoch draws a fresh batch (time grid, initial states,
# noise) and takes one optimiser step on it.
losses = []

for epoch in range(n_epochs):
    key, subkey = jax.random.split(key, 2)

    # Learning rate for *this* epoch. Calling the scheduler after the update
    # would make every epoch train with the previous epoch's rate.
    lr = lr_scheduler(current_epoch=epoch)

    t_start = time.perf_counter()

    dt, x_inits, dw = generate_dt_xinits_and_dw(
        key=subkey,
        n_timesteps=n_timesteps,
        final_time=final_time,
        n_walkers=n_walkers,
    )

    loss_val, params, momentum, velocity = update(
        params=params,
        dt=dt,
        x_inits=x_inits,
        dw=dw,
        lr=lr,
        momentum=momentum,
        velocity=velocity,
        epoch=epoch,
    )

    # JAX dispatches work asynchronously, so wait for the step to finish
    # before stopping the clock; otherwise only the dispatch is timed.
    jax.block_until_ready((loss_val, params))
    t_epoch = time.perf_counter() - t_start

    # store a plain Python float rather than a device array
    loss_val = float(loss_val)
    print(f'epoch {epoch} | lr {lr:0.3e} | loss {loss_val:0.5e} | epoch time: {t_epoch:0.3e}s')
    losses.append(loss_val)



In [None]:
# training curve: one loss value was recorded per epoch
plt.plot(losses)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Training loss per epoch');

In [None]:
# `controlled_step` is already defined in the loss cell earlier in the notebook.
# The identical second copy that used to sit here is dropped so there is a
# single definition to maintain.

def controlled_scan(x, c):
    '''
    scan adapter for the trained policy.

    Each input row packs [dw (d entries), dt, t]. `params` is looked up in
    the notebook scope when the function is called, so this uses whatever
    parameters are current at evaluation time.
    '''
    return controlled_step(policy_params=params, prev_x=x, dw=c[0:d], dt=c[d], t=c[d+1])

def uncontrolled_step(prev_x: jnp.array, dw: jnp.array, dt: float):
    '''
    One Euler-Maruyama step of the process with the control switched off
    (drift A_sys, noise through B_sys).

    Returns (next_x, prev_x), the (carry, per-step output) pair for jax.lax.scan.
    '''
    drift = dt * (A_sys @ prev_x)
    noise = jnp.sqrt(dt) * (B_sys @ dw)
    return prev_x + drift + noise, prev_x

def uncontrolled_scan(x, c):
    # scan adapter: each input row packs [dw (d entries), dt]
    return uncontrolled_step(prev_x=x, dw=c[:d], dt=c[d])
# Evaluation: simulate ONE controlled and ONE uncontrolled path from the same
# initial state and the same noise, on a uniform grid, and compare them.
# This is a single-sample comparison, so the printed costs are noisy.
num_tsteps = 200
dt = 0.01*jnp.ones((num_tsteps,1))
# start time of every step: t[0] = 0, t[k+1] = t[k] + dt[k]
t = jnp.insert(jnp.cumsum(dt[:-1]), 0, 0)

key, subkey1, subkey2 = jax.random.split(key, 3)
dw = jax.random.normal(subkey1, shape=(num_tsteps, d))
x_init = init_stddev * jax.random.normal(subkey2, shape=(d,))

xc_final, xc_path = jax.lax.scan(f=controlled_scan, init=x_init, xs=jnp.concat((dw, dt, t[:, None]), axis=1))
xuc_final, xuc_path = jax.lax.scan(f=uncontrolled_scan, init=x_init, xs=jnp.concat((dw, dt), axis=1))

# only the first coordinate of each family gets a legend entry
for i in range(d):
    plt.plot(t, xc_path[:,i], label='controlled' if i == 0 else None)

for i in range(d):
    plt.plot(t, xuc_path[:,i], c='k', linestyle='--', alpha=0.3, label='uncontrolled' if i == 0 else None)

plt.xlabel('t')
plt.ylabel('x_i(t)')
plt.title('Controlled vs. uncontrolled sample path (same noise)')
plt.legend()

# compare final costs:
# The training loss contains the running control cost 0.5 * sum(dt * |u|^2)
# in addition to the terminal cost sum(x_final), so the controlled total must
# include it. The uncontrolled path has u = 0 and pays only the terminal cost.
u_path = batch_predict(params, xc_path, t)
control_cost = 0.5 * (dt[:, 0] * (u_path**2).sum(axis=-1)).sum()
terminal_cost_uc = xuc_final.sum()
terminal_cost_c = xc_final.sum()

print('Cost objective being minimized')
print(f'Uncontrolled: {terminal_cost_uc}')
print(f'Controlled: {terminal_cost_c + control_cost} (terminal {terminal_cost_c} + control {control_cost})')