# Solving Hamilton-Jacobi-Bellman Equations (via FBSDEs)
#### Frederik Kelbel, Imperial College London

In [None]:
import torch
import plotly.graph_objects as go
import numpy as np
from operators import div, Δ, D, mdotb, bdotm, mdotm, bdotb
from FBSDEs import FBSDESolver
from pdes import FBSDE
from plotly.subplots import make_subplots
from configs import CONFIG_FBSDES as MODEL_CONFIG
from itertools import product
from torchsummary import summary

## Plotting

In [None]:
def plot_losses(losses, avg_over=10):
    """Plot a moving average of a loss history with Plotly.

    Args:
        losses: One-dimensional, non-empty sequence of loss values, one per
            training iteration.
        avg_over: Requested width of the moving-average window. It is clamped
            to the range ``[1, len(losses)]``.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    # With mode='valid', np.convolve slides the *shorter* operand inside the
    # longer one. If the kernel were longer than the loss history the result
    # would no longer be a moving average of the losses, so clamp the window.
    window = max(1, min(avg_over, len(losses)))
    avgs = np.convolve(losses, np.ones(window), 'valid') / window

    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(
        go.Scatter(x=np.arange(len(avgs)), y=avgs, mode='lines', name="Loss (moving average)"),
        row=1, col=1,
    )
    fig.update_layout(
        title="Loss",
        xaxis_title="Iterations",
        yaxis_title="Loss",
        font=dict(
            family="Courier New, monospace",
            size=14
        )
    )
    fig.show()

### The Merton Problem (Wealth Allocation Problem)

$$
\begin{cases}
dX_s = ((\mu -r)u_s + r)X_s ds + \sigma u_s X_s dW_s, \; s \in [0, T] \\
X_0 = x > 0
\end{cases},
$$

$$
\begin{cases}
\partial_t J(t, x) + \sup_{u} \Big\{ ((\mu-r)u + r)x \partial_x J(t, x) + \frac{1}{2} \sigma^2 u^2 x^2\partial_{xx} J(t, x) \Big\} = 0 \text{ on $[0, T] \times (0, \infty)$}
\\
J(T, x) = x^\gamma \text{ $\forall x > 0$}
\end{cases}
$$

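Since the Hamiltonian is maximised over $u$, the supremum is attained at $u^* = \frac{(r-\mu)J_x}{\sigma^2 x J_{xx}}$ (assuming $J_{xx} < 0$). Substituting $u^*$, the equation becomes
$J_t + r x J_x - \frac{1}{2} \frac{(r-\mu)^2 J_x^2}{\sigma^2 J_{xx}} = 0$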

In [None]:
# Hyper-parameters handed to FBSDESolver for the LQR example.
LQR_MODEL_CONFIG = dict(
    batch_size=128,
    num_discretisation_steps=30,
    hidden_dim=128,
    learning_rate=5e-3,
    lr_decay=0.99,
    network_type="MINI",
    optimiser="Adam",
)
# `model` is the notebook-wide handle for the currently active configuration.
model = LQR_MODEL_CONFIG
class LQR(FBSDE):
    """One-dimensional linear-quadratic control problem expressed as an FBSDE.

    The instance only stores coefficient callables and scalar settings; how
    they are consumed is up to ``FBSDESolver`` (not visible in this notebook).
    """

    def __init__(self):
        super().__init__()
        batch_size = model["batch_size"]
        # Constant coefficients, replicated once per batch entry as 1x1 matrices.
        sigma = 0.3*torch.ones((batch_size, 1, 1))
        M = 2.0*torch.ones((batch_size, 1, 1))
        # State cost C(x) = 2 x^2 and inverse of the control weight D = 0.2.
        # These three names were previously undefined in this cell, so calling
        # self.h raised NameError.
        C = lambda X: 2.0*X**2
        inv_D = torch.inverse(torch.tensor([[0.2]]))
        # Driver: C(X) + 1/2 * Z^T M (-D^{-1} M) Z, with Z indexed as
        # (batch, dim); the result gets a trailing singleton dimension.
        self.h = lambda X, Y, Z, t: C(X) + (1/2)*torch.einsum("bi, bij, bj -> b", Z, M, torch.einsum("bij, bj -> bi", (-inv_D @ M), Z)).unsqueeze(1)

        # Drift 0.1*X and constant diffusion of the forward process.
        self.b = lambda X, t: 0.1*X
        self.sigma = lambda X, t: sigma

        # Terminal cost g(x) = 0.001 x^2.
        self.terminal_condition = lambda X: 0.001*X**2

        self.var_dim = 1
        self.terminal_time = 1
        # Maps X to 2*(X - 0.5); presumably applied to uniform [0, 1) samples
        # to obtain initial states in [-1, 1) — TODO confirm against FBSDESolver.
        self.init_sampling_func = lambda X: (X-0.5)*2
        self.control_noise = 0.2

## This would require a 2FBSDE system, meaning we would have to compute the Hessian for every batch entry and more (see papers). What if we instead determined the optimal control function and performed a policy iteration over the control, as is done in PIADGM, i.e. transformed the problem into an FBSDE by taking the infimum at every iteration?

This
$$
\begin{cases}
\partial_t J(t, x) + \inf_{u} \Big\{ \frac{1}{2} \operatorname{tr}(\mathcal{H}_J \sigma \sigma^T) + [H x + M u] \nabla_x^T J(t, x) + C(x) + \frac{1}{2} u^T D u \Big\} = 0 \text{ on $[0, T] \times (-\infty, \infty)$}
\\
J(T, x) = Rx^2 \text{ $\forall x \in \mathbb{R}$}
\end{cases}
$$

becomes

$$
\begin{align*}
    &\begin{cases}
        dX_t = [H(t, X_t)+Mu] dt + \sigma(t, X_t) dW_t, \quad t \in [0, T] \\
        X_0 = x
    \end{cases}, \\
    &\begin{cases}
        dY_t = C(X_t) dt + \frac{1}{2} D u^2 dt + \nabla J^{* \; T}(t, X_t) \sigma(t, X_t)  dW_t, \quad t \in [0, T] \\
        Y_T = g(X_T)
    \end{cases}.
\end{align*}
$$

And take $u_\theta$ such that the loss $\mathcal{L}(u) = (M u) (\nabla J^*)^T + \frac{1}{2} u^T D u$ is minimized.

### Linear-quadratic control problem 1-dimensional (Riccati Equation) 

Let $(\Omega, \mathcal{F}, \{\mathcal{F}_t\}_{t\in [0, T]}, \mathbb{P})$ be a filtered probability space. We consider
$$
\begin{cases}
dX_s = [H_s(X_s) + M_s(X_s) u_s] ds + \sigma_s dW_s, \; s \in [0, T] \\
X_0 = x > 0
\end{cases},
$$

We aim to minimise
$$
J^u(t, x) := \mathbb{E}^{t, x} \Big[ \int_t^T \big( X_s^T C_s X_s + \frac{1}{2}u_s^T D_s u_s \big) ds + X_T^T R X_T\Big],
$$
with $C(t) = C \geq 0$, $R \geq 0$, and $D=D(t) > \delta > 0$ given and deterministic ($\delta > 0$ some constant).

We write down the problem in its primal form as
$$
\begin{cases}
\partial_t J(t, x) + \inf_{u} \Big\{ \frac{1}{2} \sigma^2 \partial_{xx} J(t, x) + [H x + M u] \partial_x J(t, x) + C x^2 + \frac{1}{2}D u^2 \Big\} = 0 \text{ on $[0, T] \times (-\infty, \infty)$}
\\
J(T, x) = Rx^2 \text{ $\forall x \in \mathbb{R}$}
\end{cases}
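$$

The infimum is attained at $u^* = -D^{-1} M^T \nabla J$, which eliminates the control and leaves the quadratic term $-\frac{1}{2} \nabla J^T M D^{-1} M^T \nabla J$. This leads to the FBSDE system

$$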
\begin{align*}
    &\begin{cases}
        dX_t = H(t, X_t) dt + \sigma(t, X_t) dW_t, \quad t \in [0, T] \\
        X_0 = x
    \end{cases}, \\
    &\begin{cases}
        dY_t = C(X_t) dt - \frac{1}{2}(\nabla J^{* \; T} M D^{-1} M^T \nabla J^*)(t, X_t) dt + \nabla J^{* \; T}(t, X_t) \sigma(t, X_t)  dW_t, \quad t \in [0, T] \\
        Y_T = g(X_T)
    \end{cases}.
\end{align*}
$$

In [None]:
# Solver settings for the one-dimensional LQR experiment.
LQR_MODEL_CONFIG = dict(
    batch_size=128,
    num_discretisation_steps=30,
    hidden_dim=128,
    learning_rate=5e-3,
    lr_decay=0.99,
    network_type="MINI",
    optimiser="Adam",
)
# Later cells read the active configuration through `model`.
model = LQR_MODEL_CONFIG
class LQR(FBSDE):
    """Coefficients of the one-dimensional linear-quadratic control FBSDE.

    Only callables and scalar settings are stored here; they are presumably
    read by ``FBSDESolver`` (its use of them is not visible in this notebook).
    """

    def __init__(self):
        super().__init__()
        batch = model["batch_size"]
        # Constant 1x1 matrices, one copy per batch entry.
        diffusion = 0.3*torch.ones((batch, 1, 1))
        M = 2.0*torch.ones((batch, 1, 1))
        # Inverse of the control weight D = 0.2.
        inv_D = torch.inverse(torch.tensor([[0.2]]))

        def state_cost(X):
            # C(x) = 2 x^2
            return 2.0*X**2

        def driver(X, Y, Z, t):
            # C(X) + 1/2 * Z^T M (-D^{-1} M) Z, with a trailing singleton dim.
            weighted_Z = torch.einsum("bij, bj -> bi", (-inv_D @ M), Z)
            quadratic = torch.einsum("bi, bij, bj -> b", Z, M, weighted_Z)
            return state_cost(X) + (1/2)*quadratic.unsqueeze(1)

        def drift(X, t):
            return 0.1*X

        def constant_sigma(X, t):
            return diffusion

        def terminal_cost(X):
            # g(x) = 0.001 x^2
            return 0.001*X**2

        def rescale(X):
            # Maps X to 2*(X - 0.5).
            return (X-0.5)*2

        self.h = driver

        self.b = drift
        self.sigma = constant_sigma

        self.terminal_condition = terminal_cost

        self.var_dim = 1
        self.terminal_time = 1
        self.init_sampling_func = rescale
        self.control_noise = 0.2

In [None]:
# Instantiate the LQR problem and a solver configured by the active `model` dict.
eq = LQR()
solver = FBSDESolver(model, eq)
# Print a torchsummary report of the solver's Y network for input size (1, 2)
# (presumably time and the 1-D state concatenated — TODO confirm).
summary(solver.Y_net, (1, 2))

In [None]:
# solver.train(400) is consumed as an iterable; every yielded value is collected
# into the loss history (400 is presumably the iteration count — TODO confirm).
loss = np.array(list(solver.train(400)))
plot_losses(loss)

In [None]:
# Evaluate the learned value function on a 100 x 100 grid over
# x in [-1, 1] and t in [0, 1] and render it as a 3-D surface.
fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'surface'}]])
xs = np.linspace(-1, 1, 100)
ts = np.linspace(0, 1, 100)
# Rows of us_pred index t, columns index x, matching go.Surface's z[y][x] layout.
us_pred = np.array([[solver.J(x, t).item() for x in xs] for t in ts])
fig.add_trace(go.Surface(x=xs, y=ts, z=us_pred), row=1, col=1)
fig.update_layout(title='Solution | Approximation',
                  scene = dict(
                    xaxis_title="x",
                    yaxis_title="t",
                    # Fixed: the label was missing its closing parenthesis.
                    zaxis_title="J(x, t)"),
                  scene2 = dict(
                    xaxis_title="x",
                    yaxis_title="t",
                    zaxis_title="J(x, t)"),
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

In [None]:
# Scalar copies of the control coefficient (2.0) and diffusion (0.3) that
# LQR.__init__ builds as batched tensors; used by the NumPy simulation below.
M=2.0
sigma= 0.3

In [None]:
# Euler-Maruyama simulation of one controlled and one uncontrolled state path
# driven by the same Brownian increments, together with their accumulated costs.
# NOTE(review): eq.H, eq.C, eq.D and solver.u are not defined anywhere in this
# notebook; they are assumed to be provided by FBSDE / FBSDESolver — confirm.
n = model["num_discretisation_steps"]
dt = 1/n
ts = np.linspace(0, 1, n)

# Both paths start from the same initial state.
c_xs = np.zeros(n)
c_xs[0] = -0.5
uc_xs = np.zeros(n)
uc_xs[0] = c_xs[0]

c_cum_cost = np.zeros(n)
uc_cum_cost = np.zeros(n)

uc = 0  # the uncontrolled path applies no control (loop invariant)
for i in range(n-1):
    t = i*dt
    dW = np.sqrt(dt)*np.random.randn()
    c = solver.u(c_xs[i], t).item()
    c_xs[i+1] = c_xs[i] + (eq.H(c_xs[i], t) + M*c)*dt + sigma*dW
    uc_xs[i+1] = uc_xs[i] + (eq.H(uc_xs[i], t) + M*uc)*dt + sigma*dW
    # Riemann sum of the running cost C(x) + 1/2 D u^2 from the objective:
    # each increment is weighted by dt and the control term carries the 1/2.
    c_cum_cost[i+1] = c_cum_cost[i] + (eq.C(c_xs[i]) + 0.5*eq.D*c**2)*dt
    uc_cum_cost[i+1] = uc_cum_cost[i] + (eq.C(uc_xs[i]) + 0.5*eq.D*uc**2)*dt

# Add the terminal cost g(X_T) at the final grid point.
c_cum_cost[-1] += eq.terminal_condition(c_xs[-1])
uc_cum_cost[-1] += eq.terminal_condition(uc_xs[-1])
    
# Left panel: state trajectories; right panel: cumulative costs.
# Green is the controlled path, yellow the uncontrolled one.
fig = make_subplots(rows=1, cols=2)
for path, label, colour in ((c_xs, "Controlled", "#00e476"), (uc_xs, "Uncontrolled", "#FFe476")):
    state_trace = go.Scatter(x=ts, y=path, mode='lines', name=label, line=dict(color=colour))
    fig.add_trace(state_trace, row=1, col=1)
# The cost curves reuse the colours above, so they stay out of the legend.
for cost, colour in ((c_cum_cost, "#00e476"), (uc_cum_cost, "#FFe476")):
    cost_trace = go.Scatter(x=ts, y=cost, mode='lines', showlegend=False, line=dict(color=colour))
    fig.add_trace(cost_trace, row=1, col=2)
layout_font = dict(family="Courier New, monospace", size=14)
fig.update_layout(
    title="Minimise amount of X | Minimise the costs (hold both close to zero)",
    xaxis_title="t",
    yaxis_title="X",
    font=layout_font,
)
fig.show()

### Black-Scholes-Barenblatt Equation N-dimensional

In [None]:
# Solver settings for the Black-Scholes-Barenblatt experiment.
BSB_MODEL_CONFIG = dict(
    batch_size=128,
    num_discretisation_steps=30,
    hidden_dim=64,
    learning_rate=5e-3,
    lr_decay=0.99,
    network_type="MINI",
    optimiser="Adam",
)
# Rebind the notebook-wide active configuration.
model = BSB_MODEL_CONFIG
class BSB(FBSDE):
    """Coefficients of the two-dimensional Black-Scholes-Barenblatt FBSDE.

    Only callables and scalar settings are stored; they are presumably read
    by ``FBSDESolver`` (its use of them is not visible in this notebook).
    """

    def __init__(self):
        super().__init__()
        rate = 0.05

        def driver(X, Y, Z, t):
            # r * (Y - sum_i Z_i X_i), with the inner product kept as a column.
            z_dot_x = torch.einsum("bi, bi -> b", Z, X).unsqueeze(1)
            return rate*(Y-z_dot_x)

        def zero_drift(X, t):
            return 0.0*X

        def proportional_diffusion(X, t):
            # Diagonal matrix with entries 0.3 * X_i.
            return 0.3*torch.diag_embed(X)

        def squared_norm(X):
            # g(x) = sum_i x_i^2, kept as a column.
            return torch.einsum("bi, bi-> b", X, X).unsqueeze(1)

        def rescale(X):
            # Maps X to 2*(X - 0.5).
            return (X-0.5)*2

        self.h = driver

        self.b = zero_drift
        self.sigma = proportional_diffusion

        self.terminal_condition = squared_norm

        self.var_dim = 2
        self.terminal_time = 1
        self.init_sampling_func = rescale
        self.control_noise = 0.0

In [None]:
# Build the BSB problem and a solver from the active `model` configuration.
eq = BSB()
solver = FBSDESolver(model, eq)
# Collect every value yielded by solver.train(600) as the loss history
# (600 is presumably the iteration count — TODO confirm) and plot it.
loss = np.array(list(solver.train(600)))
plot_losses(loss)

In [None]:
# Number of sample paths to simulate and plot.
num_samples = 2
def J_sol(X, t, r=0.05, sigma=0.3, T=1.0):
    """Reference solution used as ground truth for the BSB example.

    Evaluates ``exp((r + sigma**2) * (T - t)) * sum_i X_i**2``.

    Args:
        X: Array of states; the last axis holds the state components.
        t: Time(s) at which to evaluate; must broadcast against ``X`` with
            its last axis reduced to size one.
        r: Rate constant (default 0.05).
        sigma: Volatility constant (default 0.3).
        T: Terminal time (default 1.0).

    Returns:
        Array shaped like ``X`` with the last axis reduced to size one.
    """
    growth = np.exp((r + sigma**2)*(T - t))
    return growth*np.sum(X**2, axis=-1, keepdims=True)
# Simulate `num_samples` paths with the trained solver, then evaluate the
# reference solution on the same states and times.
# NOTE(review): J_sol uses NumPy ops, so Xs and ts are assumed to be NumPy
# arrays (or array-compatible) — confirm against simulate_processes.
Xs, Y_preds, ts = solver.simulate_processes(num_samples)
Y_sol = J_sol(Xs, ts)

In [None]:
# Overlay the predicted Y paths (yellow) on the reference solution (green),
# one pair of curves per simulated sample.
fig = make_subplots(rows=1, cols=1)
for i in range(num_samples):
    # Only the first pair contributes legend entries; the rest share its colours.
    in_legend = i == 0
    times = ts[:, i].flatten()
    fig.add_trace(go.Scatter(x=times, y=Y_preds[:, i].flatten(), mode='lines', name="Prediction", showlegend=in_legend, line=dict(color="#FFe476")), row=1, col=1)
    fig.add_trace(go.Scatter(x=times, y=Y_sol[:, i].flatten(), mode='lines', name="Ground truth", showlegend=in_legend, line=dict(color="#00e476")), row=1, col=1)
fig.update_layout(
    # Fixed: the title used to read "Loss", a leftover from the loss plot.
    title="Black-Scholes-Barenblatt: prediction vs. ground truth",
    xaxis_title="t",
    yaxis_title="J",
    font=dict(
        family="Courier New, monospace",
        size=14
    )
)
fig.show()

### Allen-Cahn

In [None]:
# Solver settings for the Allen-Cahn experiment.
AC_MODEL_CONFIG = dict(
    batch_size=128,
    num_discretisation_steps=15,
    hidden_dim=128,
    learning_rate=5e-3,
    lr_decay=0.99,
    network_type="MINI",
    optimiser="Adam",
)
# Rebind the notebook-wide active configuration.
model = AC_MODEL_CONFIG
class AC(FBSDE):
    """Coefficients of the 20-dimensional Allen-Cahn FBSDE.

    Only callables and scalar settings are stored; they are presumably read
    by ``FBSDESolver`` (its use of them is not visible in this notebook).
    """

    def __init__(self):
        super().__init__()

        def reaction(X, Y, Z, t):
            # Cubic nonlinearity Y - Y^3.
            return Y - Y**3

        def zero_drift(X, t):
            return 0.0*X

        def identity_diffusion(X, t):
            # Diagonal matrix of ones with the same trailing size as X.
            return torch.diag_embed(X*0+1)

        def terminal_value(X):
            # g(x) = 1 / (2 + 0.4 * |x|^2), reduced over the last axis.
            return 1/(2+0.4*torch.sum(X**2, dim=-1, keepdims=True))

        def start_at_zero(X):
            # Every sample is mapped to the origin.
            return X*0

        self.h = reaction

        self.b = zero_drift
        self.sigma = identity_diffusion

        self.terminal_condition = terminal_value

        self.var_dim = 20
        self.terminal_time = 0.3
        self.init_sampling_func = start_at_zero
        self.control_noise = 0.5
In [None]:
# Build the Allen-Cahn problem and a solver from the active `model` configuration.
eq = AC()
solver = FBSDESolver(model, eq)
# Collect every value yielded by solver.train(500) as the loss history
# (500 is presumably the iteration count — TODO confirm) and plot it.
loss = np.array(list(solver.train(500)))
plot_losses(loss)

In [None]:
# Simulate a handful of paths with the trained solver and draw the predicted
# Y process of each one in the same colour.
num_samples = 5
Xs, Y_preds, ts = solver.simulate_processes(num_samples)
fig = make_subplots(rows=1, cols=1)
for i in range(num_samples):
    # The first path carries the legend entry; the others are hidden from it.
    legend_kwargs = {"name": "Prediction"} if i == 0 else {"showlegend": False}
    path_trace = go.Scatter(
        x=ts[:, i].flatten(),
        y=Y_preds[:, i].flatten(),
        mode='lines',
        line=dict(color="#FFe476"),
        **legend_kwargs,
    )
    fig.add_trace(path_trace, row=1, col=1)
layout_font = dict(family="Courier New, monospace", size=14)
fig.update_layout(
    title="Allen-Cahn",
    xaxis_title="t",
    yaxis_title="J",
    font=layout_font,
)
fig.show()

In [None]:
# Query the learned value function: expands to solver.J(0, 0, 0, 0.1).
# NOTE(review): only three zeros are passed although AC.var_dim is 20 —
# confirm the argument convention of solver.J.
solver.J(*([0]*3), 0.1)