# HJB Optimal Insurance Control

## Overview
Solve a Hamilton-Jacobi-Bellman (HJB) PDE to find **globally optimal, state-dependent** insurance strategies (optimal up to the resolution of the state and control grids). We set up the state and control spaces, compare utility functions, solve the HJB equation, visualise the value function and policy surfaces, and benchmark HJB feedback control against static and time-varying baselines.

- **Prerequisites**: [optimization/01_optimization_overview](../optimization/01_optimization_overview.ipynb)
- **Estimated runtime**: 3-5 minutes (grid solve is CPU-intensive)
- **Audience**: [Developer]

> **Manual-run notebook.**  The HJB grid solve can take several minutes on
> modest hardware.  It is excluded from automated CI runs.
> Set `CI = True` in the Setup cell to skip the heavy solve; every later cell that depends on it is then skipped as well (no cached results are loaded).

In [None]:
"""Google Colab setup: mount Drive and install package dependencies.

Run this cell first. If prompted to restart the runtime, do so, then re-run all cells.
This cell is a no-op when running locally.
"""
import sys, os
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive')

    NOTEBOOK_DIR = '/content/drive/My Drive/Colab Notebooks/ei_notebooks/advanced'

    os.chdir(NOTEBOOK_DIR)
    if NOTEBOOK_DIR not in sys.path:
        sys.path.append(NOTEBOOK_DIR)

    !pip install git+https://github.com/AlexFiliakov/Ergodic-Insurance-Limits.git -q 2>&1 | tail -3
    print('\nSetup complete. If you see numpy/scipy import errors below,')
    print('restart the runtime (Runtime > Restart runtime) and re-run all cells.')

## Setup

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the "3d" projection on Matplotlib < 3.2)
import warnings
warnings.filterwarnings("ignore")

from ergodic_insurance.hjb_solver import (
    StateVariable, ControlVariable, StateSpace,
    LogUtility, PowerUtility, ExpectedWealth,
    HJBProblem, HJBSolver, HJBSolverConfig,
    create_custom_utility,
)
from ergodic_insurance.optimal_control import (
    ControlSpace, StaticControl, HJBFeedbackControl,
    TimeVaryingControl, OptimalController, create_hjb_controller,
)
from ergodic_insurance.config import ManufacturerConfig
from ergodic_insurance.manufacturer import WidgetManufacturer

plt.style.use("seaborn-v0_8-darkgrid")

# Reproducibility
SEED = 42
np.random.seed(SEED)

# Set to True in CI to skip the heavy HJB solve
CI = False

## 1. State and Control Spaces

Two state dimensions (wealth, time) and two control dimensions (insurance limit, retention).

In [None]:
state_variables = [
    StateVariable(name="wealth", min_value=1e6, max_value=1e8,
                  num_points=30, log_scale=True),
    StateVariable(name="time", min_value=0, max_value=5,
                  num_points=20, log_scale=False),
]
state_space = StateSpace(state_variables)

control_variables = [
    ControlVariable(name="limit", min_value=1e6, max_value=3e7, num_points=15),
    ControlVariable(name="retention", min_value=1e5, max_value=5e6, num_points=15),
]

print(f"State space : {state_space.ndim}D  shape={state_space.shape}  points={state_space.size}")
print(f"Controls    : limit ${control_variables[0].min_value/1e6:.0f}M-${control_variables[0].max_value/1e6:.0f}M, "
      f"retention ${control_variables[1].min_value/1e6:.1f}M-${control_variables[1].max_value/1e6:.0f}M")
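`StateVariable(..., log_scale=True)` presumably places the wealth points geometrically. A minimal NumPy-only sketch of the resulting 30 × 20 state grid, assuming spacing equivalent to `np.geomspace` (the library's exact construction is not shown here):

```python
import numpy as np

# Assumption: log_scale=True means geometric spacing, as np.geomspace produces.
wealth_grid = np.geomspace(1e6, 1e8, 30)   # $1M to $100M, 30 points
time_grid = np.linspace(0.0, 5.0, 20)      # 0 to 5 years, 20 points

ratios = wealth_grid[1:] / wealth_grid[:-1]
print(f"wealth ratio between neighbours: {ratios[0]:.4f}")  # identical for every pair
print(f"time step (years)             : {time_grid[1] - time_grid[0]:.4f}")
print(f"total state points            : {wealth_grid.size * time_grid.size}")  # 600
```

Each wealth step is about 17% of the current level, so relative resolution is the same at the bottom and the top of the range; 30 evenly spaced points would instead step by about 3.4M, a step wider than the 1M lower bound itself.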

## 2. Utility Functions

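Compare log (ergodic) and power (CRRA) utilities. A linear (risk-neutral) `ExpectedWealth` utility is also created for reference but not plotted: its marginal utility is constant, so it shows no risk aversion.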

In [None]:
# Log-spaced points so the log-log panel is sampled evenly
wealth = np.geomspace(1e6, 1e8, 100)
log_u  = LogUtility()
pow2_u = PowerUtility(risk_aversion=2.0)
pow4_u = PowerUtility(risk_aversion=4.0)
lin_u  = ExpectedWealth()  # risk-neutral reference, not plotted
curves = [(log_u, "Log"), (pow2_u, r"Power ($\gamma$=2)"), (pow4_u, r"Power ($\gamma$=4)")]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
for u, label in curves:
    vals = u.evaluate(wealth)
    # Rescale each curve to [0, 1]: raw CRRA utilities differ by many orders of magnitude
    ax.plot(wealth / 1e6, (vals - vals[0]) / (vals[-1] - vals[0]), lw=2, label=label)
ax.set_xlabel("Wealth ($M)")
ax.set_ylabel("Utility (rescaled to [0, 1])")
ax.set_title("Utility Functions")
ax.legend()
ax.grid(True, alpha=0.3)

ax = axes[1]
for u, label in curves:
    ax.loglog(wealth / 1e6, u.derivative(wealth), lw=2, label=label)
ax.set_xlabel("Wealth ($M)")
ax.set_ylabel("Marginal Utility")
ax.set_title("Marginal Utilities (log-log)")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
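Power utility's `risk_aversion` argument is the CRRA coefficient $\gamma$. A short numerical check of what it measures, assuming `PowerUtility(risk_aversion=γ)` follows the usual parametrisation $U(w) = w^{1-\gamma}/(1-\gamma)$ with $\ln w$ as the $\gamma \to 1$ limit: the Arrow-Pratt relative risk aversion $-wU''(w)/U'(w)$ equals $\gamma$ at every wealth level, and log utility is the $\gamma = 1$ member of the family. `crra` and `relative_risk_aversion` are illustrative helpers, not part of `ergodic_insurance`.

```python
import numpy as np

def crra(w, gamma):
    """CRRA utility w**(1-gamma)/(1-gamma); gamma == 1 is the log limit."""
    if gamma == 1.0:
        return np.log(w)
    return w ** (1.0 - gamma) / (1.0 - gamma)

def relative_risk_aversion(u, w, h=1e-3):
    """Arrow-Pratt relative risk aversion -w U''(w) / U'(w) via central differences."""
    dw = h * w                                # step proportional to wealth
    up, u0, um = u(w + dw), u(w), u(w - dw)
    first = (up - um) / (2 * dw)              # approximates U'(w)
    second = (up - 2 * u0 + um) / dw**2       # approximates U''(w)
    return -w * second / first

w = np.geomspace(1e6, 1e8, 5)
for gamma in (1.0, 2.0, 4.0):
    rra = relative_risk_aversion(lambda x: crra(x, gamma), w)
    print(f"gamma={gamma}: relative risk aversion = {np.round(rra, 4)}")
```

Because relative risk aversion is constant, rescaling wealth does not change attitudes to proportional risk, which is why these curves differ only in how sharply marginal utility falls on the log-log panel.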

## 3. Solve the HJB Equation

Define the wealth dynamics, the running reward, and the terminal value, then solve by backward induction from the horizon (T = 5 years) back to t = 0.
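For reference, a sketch of the equation being solved, assuming the solver targets the standard deterministic, discounted finite-horizon HJB form (no diffusion term is passed to `HJBProblem` here). With wealth drift $w\,g(L,R)$, where $g$ is `growth - prem_rate + loss_mit` from `company_dynamics`, discount rate $\rho = 0.05$ and horizon $T = 5$:

$$
\rho\,V(w,t) = \max_{(L,R)} \Big\{ \ln w + \frac{\partial V}{\partial t} + w\,g(L,R)\,\frac{\partial V}{\partial w} \Big\},
\qquad V(w,T) = \ln w,
$$

$$
g(L,R) = 0.08 - 0.015\,\frac{L}{2\times 10^{7}}\big(1 + e^{-R/10^{6}}\big) + 0.02\big(1 - e^{-L/10^{7}}\big).
$$

The $\partial V/\partial t$ term comes from treating time as a state with unit drift (`drift_t = 1`); backward induction steps from the terminal condition at $t = T$ toward $t = 0$.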

In [None]:
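def company_dynamics(state, control, time):
    """Drift of each state variable; the last axis indexes (wealth, time)."""
    w = state[..., 0]
    limit = control[..., 0]
    retention = control[..., 1]
    growth = 0.08  # baseline growth rate without insurance effects
    # Premium as a fraction of wealth: rises with the limit, falls as retention rises
    prem_rate = 0.015 * (limit / 2e7) * (1 + np.exp(-retention / 1e6))
    # Loss-mitigation benefit of cover, saturating at 2% a year as the limit grows
    loss_mit  = 0.02 * (1 - np.exp(-limit / 1e7))
    drift_w = w * (growth - prem_rate + loss_mit)  # drift proportional to wealth
    drift_t = np.ones_like(w)                      # time advances at unit rate
    return np.stack([drift_w, drift_t], axis=-1)

log_util = LogUtility()

def running_reward(state, control, time):
    # Log utility of current wealth, accrued continuously
    return log_util.evaluate(state[..., 0])

def terminal_value(state):
    # Log utility of wealth at the horizon
    return log_util.evaluate(state[..., 0])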

hjb_problem = HJBProblem(
    state_space=state_space,
    control_variables=control_variables,
    utility_function=log_util,
    dynamics=company_dynamics,
    running_cost=running_reward,
    terminal_value=terminal_value,
    discount_rate=0.05,
    time_horizon=5.0,
)

solver_config = HJBSolverConfig(
    time_step=0.05, max_iterations=50, tolerance=1e-3, verbose=True,
)

if not CI:
    print("Solving HJB equation...")
    hjb_solver = HJBSolver(hjb_problem, solver_config)
    value_function, optimal_policy = hjb_solver.solve()
    print(f"Value fn shape: {value_function.shape}")
    print(f"Policy keys   : {list(optimal_policy.keys())}")
else:
    print("CI mode -- HJB solve skipped.")
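For this model the value function has a closed form, assuming the solver targets the deterministic discounted HJB equation $\rho V = \max_u\{\ln w + \partial_t V + w\,g(u)\,\partial_w V\}$ with $V(w,T) = \ln w$, where $g(u)$ is `growth - prem_rate + loss_mit` from `company_dynamics`. Substituting the guess $V(w,t) = a(t)\ln w + b(t)$ gives $\partial_w V = a(t)/w$, so the drift term becomes $a(t)\,g(u)$ and wealth drops out of the maximisation. Matching the $\ln w$ terms and the constant terms, with $\rho = 0.05$, $T = 5$ and $g^{*} = \max_u g(u)$:

$$
a'(t) = \rho\,a(t) - 1,\quad a(T) = 1
\;\;\Longrightarrow\;\;
a(t) = \frac{1}{\rho} + \Big(1 - \frac{1}{\rho}\Big)e^{-\rho (T-t)},
$$

$$
b'(t) = \rho\,b(t) - g^{*}a(t),\quad b(T) = 0
\;\;\Longrightarrow\;\;
b(t) = g^{*}\Big[\frac{1 - e^{-\rho (T-t)}}{\rho^{2}} + \Big(1 - \frac{1}{\rho}\Big)(T-t)\,e^{-\rho (T-t)}\Big].
$$

Since $a(t) = 20 - 19e^{-0.05(5-t)} \ge 1 > 0$, the optimal control maximises $g(u)$ at every wealth and time, and the value function at $t = 0$ is a straight line in $\ln w$ with slope $a(0) = 20 - 19e^{-0.25} \approx 5.20$. Visible curvature on the semilog value-function plot would most likely come from discretisation or boundary effects in the numerical solve.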

## 4. Visualise Value Function and Optimal Policies

In [None]:
if not CI:
    wealth_grid = state_space.grids[0]
    time_grid  = state_space.grids[1]
    W, T = np.meshgrid(wealth_grid, time_grid, indexing="ij")

    fig = plt.figure(figsize=(16, 10))

    ax1 = fig.add_subplot(2, 2, 1)
    ax1.semilogx(wealth_grid / 1e6, value_function[:, 0], "b-", lw=2)
    ax1.set_xlabel("Wealth ($M)")
    ax1.set_ylabel("Value")
    ax1.set_title("Value Function at t=0")
    ax1.grid(True, alpha=0.3)

    ax2 = fig.add_subplot(2, 2, 2)
    ax2.semilogx(wealth_grid / 1e6, optimal_policy["limit"][:, 0] / 1e6, "r-", lw=2)
    ax2.set_xlabel("Wealth ($M)")
    ax2.set_ylabel("Optimal Limit ($M)")
    ax2.set_title("Insurance Limit at t=0")
    ax2.grid(True, alpha=0.3)

    ax3 = fig.add_subplot(2, 2, 3)
    ax3.semilogx(wealth_grid / 1e6, optimal_policy["retention"][:, 0] / 1e6, "g-", lw=2)
    ax3.set_xlabel("Wealth ($M)")
    ax3.set_ylabel("Optimal Retention ($M)")
    ax3.set_title("Retention at t=0")
    ax3.grid(True, alpha=0.3)

    ax4 = fig.add_subplot(2, 2, 4, projection="3d")
    ax4.plot_surface(np.log10(W / 1e6), T, value_function, cmap="viridis", alpha=0.8)
    ax4.set_xlabel("log10 Wealth ($M)")
    ax4.set_ylabel("Time")
    ax4.set_zlabel("Value")
    ax4.set_title("Value Function Surface")

    plt.tight_layout()
    plt.show()
else:
    print("Skipped (CI mode).")
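The policy plots can be checked against a direct calculation. With log utility and a drift proportional to wealth, a standard scaling argument (assuming a deterministic HJB with no diffusion) says the optimal control maximises the per-unit-wealth growth rate `growth - prem_rate + loss_mit` at every state. The sketch below evaluates that rate on the control grid, assuming `ControlVariable` spaces its 15 points evenly as `np.linspace` does:

```python
import numpy as np

# Assumption: ControlVariable grids are evenly spaced (np.linspace).
limits = np.linspace(1e6, 3e7, 15)
retentions = np.linspace(1e5, 5e6, 15)
L, R = np.meshgrid(limits, retentions, indexing="ij")  # axis 0: limit, axis 1: retention

# Per-unit-wealth growth rate, copied from company_dynamics
growth = 0.08
prem_rate = 0.015 * (L / 2e7) * (1 + np.exp(-R / 1e6))
loss_mit = 0.02 * (1 - np.exp(-L / 1e7))
g = growth - prem_rate + loss_mit

i, j = np.unravel_index(np.argmax(g), g.shape)
best_limit, best_retention = limits[i], retentions[j]
print(f"drift-maximising limit    : ${best_limit / 1e6:.2f}M")
print(f"drift-maximising retention: ${best_retention / 1e6:.2f}M")
print(f"growth rate at optimum    : {g[i, j]:.4%}")
```

Retention has no cost in these toy dynamics (it only lowers the premium), so the maximiser sits at the 5M upper bound for every limit; the limit trades premium against the saturating mitigation benefit and lands on the grid point near 9.29M, the closest neighbour of the continuous optimum at roughly 9.74M. Away from the wealth-grid edges, a solved policy that varies with wealth would most likely reflect discretisation rather than economics.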

## 5. Strategy Comparison

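Run 100 five-year Monte Carlo paths for each controller (HJB feedback, static, time-varying). The simulation loop is deliberately simple: a loss occurs with 10% probability each year, with lognormal(14, 1) severity (median about 1.2 million), and is recovered up to the program's total coverage. The loop itself deducts no premium, applies no retention, and ignores taxes.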
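Strategy comparisons are sharper when every strategy faces the same loss draws (common random numbers). A minimal sketch, using the same frequency (10% a year) and lognormal(14, 1) severity as the simulation loop; `loss_sequence` is a hypothetical helper built on NumPy's `default_rng`:

```python
import numpy as np

def loss_sequence(seed, years=5, n_sims=100, freq=0.1, mu=14.0, sigma=1.0):
    """Annual loss amounts, shape (n_sims, years); 0.0 in years with no loss."""
    rng = np.random.default_rng(seed)
    losses = np.zeros((n_sims, years))
    for s in range(n_sims):
        for y in range(years):
            # Same draw order as a per-year loop: occurrence first, then severity
            if rng.random() < freq:
                losses[s, y] = rng.lognormal(mu, sigma)
    return losses

# Two strategies evaluated with the same seed face identical losses.
losses_a = loss_sequence(seed=42)
losses_b = loss_sequence(seed=42)
print("identical draws:", np.array_equal(losses_a, losses_b))
print("loss years:", int((losses_a > 0).sum()), "of", losses_a.size)
```

Because the draws do not depend on the strategy, any difference in outcomes under a shared seed comes from the controls. With one shared global stream instead, each strategy would continue from where the previous one stopped, so part of any difference would be sampling noise.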

In [None]:
if not CI:
    manufacturer_config = ManufacturerConfig(
        initial_assets=2e7, asset_turnover_ratio=1.0,
        base_operating_margin=0.08, tax_rate=0.25, retention_ratio=0.6,
    )
    cs = ControlSpace(limits=[(1e6, 3e7)], retentions=[(1e5, 5e6)])

    hjb_ctrl   = OptimalController(HJBFeedbackControl(hjb_solver, cs), cs)
    static_ctrl = OptimalController(StaticControl(limits=[1.5e7], retentions=[1e6]), cs)
    tv_ctrl    = OptimalController(
        TimeVaryingControl(
            time_schedule=[0, 2.5, 5],
            limits_schedule=[[1e7], [2e7], [3e7]],
            retentions_schedule=[[5e5], [1e6], [2e6]],
        ), cs)

    def run_mc(controller, cfg, years=5, n_sims=100, seed=SEED):
        # Fresh generator per strategy: every strategy sees the same loss draws
        # (common random numbers), so differences reflect the controls.
        rng = np.random.default_rng(seed)
        rows = []
        for _ in range(n_sims):
            m = WidgetManufacturer(cfg)
            controller.reset()
            for yr in range(years):
                ins = controller.apply_control(m, time=yr)
                rev = m.assets * cfg.asset_turnover_ratio
                costs = rev * (1 - cfg.base_operating_margin)
                net_loss = 0.0
                if rng.random() < 0.1:                  # 10% annual loss frequency
                    loss = rng.lognormal(14, 1)         # median ~ $1.2M
                    # Simplified recovery up to total coverage; no retention, no premium deducted here
                    net_loss = max(loss - ins.get_total_coverage(), 0.0)
                # Operating profit less retained loss; taxes and dividends ignored
                m.assets += rev - costs - net_loss
            rows.append({"final": m.assets,
                         # Annualised log (continuously compounded) growth rate, not CAGR
                         "growth": np.log(m.assets / cfg.initial_assets) / years})
        return pd.DataFrame(rows)

    hjb_res = run_mc(hjb_ctrl, manufacturer_config)
    sta_res = run_mc(static_ctrl, manufacturer_config)
    tv_res  = run_mc(tv_ctrl, manufacturer_config)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for df, lbl in [(hjb_res, "HJB"), (sta_res, "Static"), (tv_res, "Time-Varying")]:
        axes[0].hist(df["final"] / 1e6, bins=20, alpha=0.5, label=lbl)
        axes[1].hist(df["growth"] * 100, bins=20, alpha=0.5, label=lbl)
    axes[0].set_xlabel("Final Wealth ($M)")
    axes[0].set_title("Final Wealth Distribution")
    axes[0].legend()
    axes[1].set_xlabel("Annualised log growth (%)")
    axes[1].set_title("Growth Rate Distribution")
    axes[1].legend()
    for ax in axes:
        ax.grid(True, alpha=0.3)
    plt.suptitle("Strategy Comparison", fontweight="bold")
    plt.tight_layout()
    plt.show()

    print(f"{'Strategy':<16s} {'Mean Wealth':>14s} {'P5 Wealth':>14s} {'Median growth':>14s}")
    for lbl, df in [("HJB", hjb_res), ("Static", sta_res), ("Time-Varying", tv_res)]:
        print(f"{lbl:<16s} ${df['final'].mean()/1e6:>12.2f}M ${df['final'].quantile(0.05)/1e6:>12.2f}M "
              f"{df['growth'].median()*100:>13.2f}%")
else:
    print("Skipped (CI mode).")
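The simulation loop nets each loss against the program's total coverage only, so retention never enters the loss calculation. For comparison, a standard per-occurrence excess-of-loss layer pays the part of each loss $X$ above the retention $R$, capped at the limit $L$: $\min(\max(X - R, 0), L)$. A hypothetical helper sketching that rule (not part of `ergodic_insurance`):

```python
def layer_recovery(loss: float, retention: float, limit: float) -> float:
    """Insurer payment for one loss under a per-occurrence layer of `limit` xs `retention`."""
    return min(max(loss - retention, 0.0), limit)


def retained_loss(loss: float, retention: float, limit: float) -> float:
    """What the company keeps: everything below the retention plus anything above the layer."""
    return loss - layer_recovery(loss, retention, limit)


# The static strategy's layer: $15M limit xs $1M retention
for loss in (5e5, 3e6, 2e7):
    print(f"loss ${loss / 1e6:5.1f}M  recovered ${layer_recovery(loss, 1e6, 1.5e7) / 1e6:5.1f}M  "
          f"retained ${retained_loss(loss, 1e6, 1.5e7) / 1e6:5.1f}M")
```

Under this rule a higher retention lowers recoveries on every loss, which is the cost that justifies its lower premium; a loop that ignores retention cannot show that trade-off.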

## 6. Convergence Metrics

In [None]:
if not CI:
    cm = hjb_solver.compute_convergence_metrics()
    print(f"Max HJB residual : {cm['max_residual']:.6e}")
    print(f"Mean HJB residual: {cm['mean_residual']:.6e}")
    for name, stats in cm["policy_stats"].items():
        print(f"{name}: min=${stats['min']/1e6:.2f}M  max=${stats['max']/1e6:.2f}M  "
              f"mean=${stats['mean']/1e6:.2f}M")
else:
    print("Skipped (CI mode).")
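A reported HJB residual is easier to judge against a benchmark with a known answer. A minimal sketch, assuming the residual means $\rho V - [\ln w + \partial_t V + w\,g\,\partial_w V]$ evaluated with finite differences (how `compute_convergence_metrics` defines it is not shown here). For log utility and a fixed drift rate $g$, the function $V = a(t)\ln w + b(t)$ below solves the equation exactly, so every nonzero residual is grid truncation error, largest at the edges where `np.gradient` falls back to one-sided differences. The value $g = 0.085$ is an assumption, close to the best drift rate the toy dynamics allow.

```python
import numpy as np

# Exact solution of  rho V = ln w + V_t + w g V_w,  V(w, T) = ln w,  for fixed g:
#   V = a(t) ln w + b(t),  a' = rho a - 1, a(T) = 1,  b' = rho b - g a, b(T) = 0.
rho, T, g = 0.05, 5.0, 0.085  # g: assumed drift rate at the optimal control

def a(t):
    return 1 / rho + (1 - 1 / rho) * np.exp(-rho * (T - t))

def b(t):
    tau = T - t
    return g * ((1 - np.exp(-rho * tau)) / rho**2 + (1 - 1 / rho) * tau * np.exp(-rho * tau))

w = np.geomspace(1e6, 1e8, 30)          # same resolution as the notebook's state grid
t = np.linspace(0.0, T, 20)
W, Tt = np.meshgrid(w, t, indexing="ij")
V = a(Tt) * np.log(W) + b(Tt)

# Finite-difference derivatives on the non-uniform grid
V_w, V_t = np.gradient(V, w, t)

residual = rho * V - (np.log(W) + V_t + W * g * V_w)
interior = np.abs(residual[1:-1, 1:-1])
print(f"max |residual| (interior): {interior.max():.2e}")
print(f"max |residual| (all)     : {np.abs(residual).max():.2e}")
```

On this grid the interior residual stays at the level of a few thousandths while the edge residual is roughly ten times larger, so a solver's maximum residual on a similar grid may be dominated by boundary points rather than by a failure to converge.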

## Key Takeaways

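- **HJB feedback control** yields insurance strategies that are optimal over the discretised state and control grids and can adapt to the company's current financial state.
- **Log utility** (ergodic) targets the time-average growth rate of wealth, the rate a single company actually experiences over time.
- In this toy model the drift scales with wealth, so under log utility the optimal policy should be essentially flat in wealth away from the grid edges: the limit that maximises the net drift rate, with retention at its upper bound because retention only lowers the premium here. Wealth-dependent limits and retentions need model features that break this scaling, such as fixed-size losses.
- The Monte Carlo benchmark compares HJB feedback with static and time-varying baselines on final wealth and median growth. Its loop deducts no premium and applies no retention, so it rewards coverage without pricing it; treat the ranking as indicative rather than conclusive.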

## Next Steps

- [advanced/02_walk_forward_validation](02_walk_forward_validation.ipynb) -- out-of-sample strategy testing
- [advanced/03_advanced_convergence](03_advanced_convergence.ipynb) -- Monte Carlo convergence diagnostics
- [optimization/01_optimization_overview](../optimization/01_optimization_overview.ipynb) -- lighter-weight optimisers