# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from model import ConsumptionSavingModel
from algos import TISolver, EGMSolver, RolloutSolver
from plot import plot_policy_comparison, plot_policy_heat, plot_convergence

# Seed NumPy's global RNG so repeated runs of this notebook are reproducible.
np.random.seed(42)

# Model primitives (conventional calibration from the literature):
#   beta  - discount factor
#   R     - gross interest rate
#   sigma - coefficient of relative risk aversion
#   b     - borrowing limit
beta, R, sigma, b = 0.96, 1.04, 2.0, 0.0

# Capital grid: M evenly spaced points on [k_min, k_max].
# NOTE(review): the grid starts at k_min = 0.1, strictly above the borrowing
# limit b = 0.0 -- confirm the solvers do not need a grid point at b itself.
M = 100
k_min, k_max = 0.1, 20.0
k_grid = np.linspace(k_min, k_max, num=M)

# Income process: log income follows a stationary AR(1) around log(w_mean),
#     z' = rho * z + eps,   z = log(w) - log(w_mean),   eps ~ N(0, eps_std**2),
# discretised on N log-spaced states spanning +/- 2 stationary std devs.
w_mean = 1.0
w_std = 0.2  # stationary std of log income; sets the +/- 2*w_std grid width
N = 5        # number of income states
rho = 0.9    # persistence of log income

log_w_dev = np.linspace(-2*w_std, 2*w_std, N)  # log deviations from w_mean
w_grid = np.exp(log_w_dev) * w_mean

# Innovation std implied by the stationary std: Var(z) = eps_std**2 / (1 - rho**2).
# This keeps the transition kernel consistent with the grid width above.
eps_std = w_std * np.sqrt(1 - rho**2)

# Markov transition matrix (rows = today's state i, columns = tomorrow's j):
# Gaussian innovation density evaluated at each grid point, then normalised
# row-wise. The kernel must be evaluated on LOG deviations, not income levels:
# in levels the conditional mean rho * w lies below w for every state, which
# biases every transition downward and makes the chain asymmetric.
P = np.exp(
    -(log_w_dev[np.newaxis, :] - rho * log_w_dev[:, np.newaxis])**2
    / (2 * eps_std**2)
)
P = P / P.sum(axis=1, keepdims=True)

# Build the model once; all three solvers share this single instance.
model = ConsumptionSavingModel(beta, R, sigma, b, k_grid, w_grid, P)

# Identical stopping rules for every method so the comparison is like-for-like.
_solver_settings = {"tol": 1e-6, "max_iter": 1000}
ti_solver = TISolver(model, **_solver_settings)
egm_solver = EGMSolver(model, **_solver_settings)
rollout_solver = RolloutSolver(model, **_solver_settings)

# Solve sequentially, announcing each method before it starts.
# NOTE(review): solve() presumably returns the consumption policy array --
# confirm against algos.
print("Solving using Time Iteration...")
c_ti = ti_solver.solve()
print("Solving using EGM...")
c_egm = egm_solver.solve()
print("Solving using Forward Rollout...")
c_rollout = rollout_solver.solve()

# Plot style: the bare 'seaborn' style was deprecated in Matplotlib 3.6
# (renamed 'seaborn-v0_8') and no longer exists in newer releases, where
# plt.style.use('seaborn') raises OSError. Use whichever name is available.
_seaborn_style = (
    'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
)
plt.style.use(_seaborn_style)

# Overlay the three solvers' policies for direct comparison.
fig1 = plot_policy_comparison(model,
                              [c_ti, c_egm, c_rollout],
                              ['Time Iteration', 'EGM', 'Forward Rollout'])
plt.show()

# One heatmap per solver.
fig2 = plot_policy_heat(model, c_ti, 'Time Iteration')
plt.show()
fig3 = plot_policy_heat(model, c_egm, 'EGM')
plt.show()
fig4 = plot_policy_heat(model, c_rollout, 'Forward Rollout')
plt.show()

# Report the sup-norm gap between every pair of solver policies.
# NOTE(review): assumes all three policies share one array shape; mismatched
# shapes would broadcast silently (or raise) -- confirm in algos.
for _label_a, _label_b, _pol_a, _pol_b in (
    ("TI", "EGM", c_ti, c_egm),
    ("TI", "Rollout", c_ti, c_rollout),
    ("EGM", "Rollout", c_egm, c_rollout),
):
    _gap = np.max(np.abs(_pol_a - _pol_b))
    print(f"Max difference between {_label_a} and {_label_b}: {_gap:.2e}")

# Out: Solving using Time Iteration...
