# Notebook 3: Numerical HJB Solver\n\n**Author:** Divyansh Atri\n\n## Overview\n\nThis notebook implements a finite difference solver for the Hamilton–Jacobi–Bellman (HJB) equation.\n\n**Topics:**\n1. Finite difference discretization\n2. Backward time stepping\n3. Hamiltonian minimization\n4. Stability analysis (CFL condition)\n5. Convergence studies

In [None]:
# Environment setup for the notebook.
import sys

import numpy as np
import matplotlib.pyplot as plt
# Imported for its side effect: on older matplotlib versions this registers
# the '3d' projection used by the surface plot.
from mpl_toolkits.mplot3d import Axes3D

# The project `utils` module lives one directory above the notebooks folder.
sys.path.append('..')
# Presumably provides ControlledBrownianMotion, QuadraticCost and HJBSolver,
# which later cells use without importing them explicitly.
from utils import *

plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams.update({'figure.figsize': (14, 6)})

print('Numerical HJB Solver - Ready')

## 1. Finite Difference Discretization\n\n### Spatial Grid\n$$x_i = x_{\min} + i\Delta x, \quad \Delta x = \frac{x_{\max} - x_{\min}}{N_x - 1}, \quad i = 0, \ldots, N_x-1$$\n\n### Time Grid\n$$t_n = n\Delta t, \quad \Delta t = \frac{T}{N_t - 1}, \quad n = 0, \ldots, N_t-1$$\n\n### Derivatives (central differences, second-order accurate in $\Delta x$)\n- First: $V_x \approx \frac{V_{i+1} - V_{i-1}}{2\Delta x}$\n- Second: $V_{xx} \approx \frac{V_{i+1} - 2V_i + V_{i-1}}{(\Delta x)^2}$

In [None]:
# Setup: controlled Brownian motion with a quadratic cost, plus the uniform
# space-time grid used by the solver.
import warnings

model = ControlledBrownianMotion(sigma=0.5)
cost_fn = QuadraticCost(q=1.0, r=1.0, q_terminal=10.0)

# Grid
x_min, x_max, nx = -3.0, 3.0, 101
T, nt = 2.0, 201

dx = (x_max - x_min) / (nx - 1)
dt = T / (nt - 1)

# Diffusion CFL number of the explicit backward step, lambda = 0.5*sigma^2*dt/dx^2.
# A forward-Euler step of the (1/2) sigma^2 V_xx term is stable only for
# lambda <= 1/2; above that the solution oscillates and blows up.
cfl = 0.5 * model.sigma**2 * dt / dx**2

print(f'Grid: x ∈ [{x_min}, {x_max}], {nx} points, dx = {dx:.4f}')
print(f'Time: t ∈ [0, {T}], {nt} points, dt = {dt:.4f}')
print(f'CFL number: {cfl:.4f}')
if cfl > 0.5:
    warnings.warn(
        f'CFL number {cfl:.4f} exceeds the explicit-scheme stability limit 0.5; '
        f'increase nt or decrease nx.'
    )

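## 2. HJB Solver Implementation\n\nThe value function satisfies the HJB equation\n\n$$V_t + \min_u H(x, u, V_x, V_{xx}) = 0, \qquad H(x, u, p, M) = f(x, u)\,p + \frac{1}{2}\sigma^2 M + L(x, u),$$\n\nwhere $f$ is the drift, $\sigma$ the diffusion coefficient, and $L$ the running cost. Starting from the terminal cost at $t_{N_t-1} = T$, we step backward in time with an explicit scheme:\n\n$$V^n_i = V^{n+1}_i + \Delta t \cdot \min_u H\big(x_i, u, (V_x)^{n+1}_i, (V_{xx})^{n+1}_i\big)$$\n\nThe plus sign follows from approximating $V_t \approx (V^{n+1}_i - V^n_i)/\Delta t$ in the HJB equation.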

In [None]:
# Build the finite-difference solver on the grid from the setup cell, then
# march backward in time from the terminal condition. (-5, 5) is passed as
# the admissible control interval `u_bounds`.
solver = HJBSolver(
    x_min, x_max, nx,   # spatial grid
    T, nt,              # horizon and number of time levels
    model, cost_fn,     # dynamics and cost
)

print('Solving HJB equation...')
V, u_opt = solver.solve_backward(verbose=True, u_bounds=(-5, 5))

print(f'\nSolution shape: V {V.shape}, u_opt {u_opt.shape}')

## 3. Visualization

In [None]:
# Plot the solution: value-function surface, value slices over time, and the
# optimal feedback control at the same times.
import os

fig = plt.figure(figsize=(16, 5))

# 3D surface. indexing='ij' makes T_grid/X_grid shaped
# (len(solver.t), len(solver.x)), matching the V[time_index, space_index]
# indexing used for the slices below.
ax1 = fig.add_subplot(131, projection='3d')
T_grid, X_grid = np.meshgrid(solver.t, solver.x, indexing='ij')
surf = ax1.plot_surface(T_grid, X_grid, V, cmap='viridis', alpha=0.9)
ax1.set_xlabel('Time $t$')
ax1.set_ylabel('State $x$')
ax1.set_zlabel('Value $V(t,x)$')
ax1.set_title('Value Function')

# Slice indices come from the solution itself rather than the global `nt`,
# so they stay in range even if `nt` was changed after solving. The set
# removes duplicate indices that appear when there are only a few levels.
n_levels = V.shape[0]
times = sorted({0, n_levels // 4, n_levels // 2, 3 * n_levels // 4, n_levels - 1})

# Value at different times
ax2 = fig.add_subplot(132)
for idx in times:
    ax2.plot(solver.x, V[idx, :], label=f't={solver.t[idx]:.2f}', linewidth=2)
ax2.set_xlabel('State $x$')
ax2.set_ylabel('Value $V(t,x)$')
ax2.set_title('Value Function Slices')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Optimal control
ax3 = fig.add_subplot(133)
for idx in times:
    ax3.plot(solver.x, u_opt[idx, :], label=f't={solver.t[idx]:.2f}', linewidth=2)
ax3.set_xlabel('State $x$')
ax3.set_ylabel('Control $u^*(t,x)$')
ax3.set_title('Optimal Control Policy')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.axhline(0, color='k', linestyle='--', alpha=0.5)

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
os.makedirs('../plots', exist_ok=True)
plt.savefig('../plots/hjb_solution.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Convergence Analysis

In [None]:
# Grid refinement study.
#
# The backward step (Section 2) is explicit in time, so the diffusion term is
# stable only while 0.5 * sigma^2 * dt / dx^2 <= 0.5. Refining dx with nt held
# fixed breaks this on the finest grid (nx = 201, nt = 201 gives ~1.39), so nt
# is raised per grid to keep the CFL number at or below `cfl_target`. Because
# the time grids then differ, solutions are compared at t = 0, the one time
# level that every grid shares.
import os

print('Testing convergence with grid refinement...')

nx_values = [51, 101, 201]
cfl_target = 0.4  # safety margin below the 0.5 stability limit
diffusion_coef = 0.5 * model.sigma**2
V_solutions = []

for nx_test in nx_values:
    dx_test = (x_max - x_min) / (nx_test - 1)
    if diffusion_coef > 0:
        # Smallest number of levels whose dt satisfies the CFL target; never
        # go below the baseline nt.
        dt_max = cfl_target * dx_test**2 / diffusion_coef
        nt_test = max(nt, int(np.ceil(T / dt_max)) + 1)
    else:
        nt_test = nt  # no diffusion term, no diffusive CFL restriction

    solver_test = HJBSolver(x_min, x_max, nx_test, T, nt_test, model, cost_fn)
    V_test, _ = solver_test.solve_backward(u_bounds=(-5, 5), verbose=False)
    V_solutions.append((nx_test, V_test, solver_test.x))

    # Grid index closest to x = 0 (does not rely on nx being odd).
    i0 = int(np.argmin(np.abs(solver_test.x)))
    print(f'  nx = {nx_test}, nt = {nt_test}: V(0, 0) = {V_test[0, i0]:.6f}')

# Compute errors between successive grids at t = 0.
errors = []
for (nx1, V1, x1), (nx2, V2, x2) in zip(V_solutions[:-1], V_solutions[1:]):
    # Sample the finer solution at the coarse nodes (np.interp needs x2
    # increasing). For these nested grids, nx2 - 1 = 2 * (nx1 - 1), every
    # coarse node is also a fine node, so this adds no interpolation error.
    V2_on_x1 = np.interp(x1, x2, V2[0, :])
    error = np.max(np.abs(V2_on_x1 - V1[0, :]))
    errors.append(error)
    print(f'Error (nx={nx1} vs {nx2}): {error:.6e}')

# Observed order p from successive differences: e_k / e_{k+1} ≈ r^p, where
# r is the refinement ratio of the spacing.
for k, (e_coarse, e_fine) in enumerate(zip(errors, errors[1:])):
    if e_coarse > 0 and e_fine > 0:
        refinement = (nx_values[k + 1] - 1) / (nx_values[k] - 1)
        order = np.log(e_coarse / e_fine) / np.log(refinement)
        print(f'Observed order (nx={nx_values[k]} to {nx_values[k + 2]}): {order:.2f}')

# Plot convergence on log-log axes, where the slope is minus the order.
if errors:
    plt.figure(figsize=(8, 6))
    plt.loglog(nx_values[1:], errors, 'bo-', linewidth=2, markersize=8)
    plt.xlabel('Number of spatial points $N_x$')
    plt.ylabel('Max difference at $t = 0$')
    plt.title('Convergence with Grid Refinement')
    plt.grid(True, which='both', alpha=0.3)
    plt.tight_layout()
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs('../plots', exist_ok=True)
    plt.savefig('../plots/hjb_convergence.png', dpi=150, bbox_inches='tight')
    plt.show()

## Summary\n\nWe implemented a finite difference solver for the HJB equation with:\n- Backward time stepping\n- Pointwise Hamiltonian minimization\n- Convergence analysis\n\n**Next:** Validate against analytical LQR solution.