# Adaptive Tensor Product Grids - Applied to HANK

The **choice of grids** when solving an economic model is somewhat of an **art**. Choosing grids in a smart way can substantially reduce the computational burden of solving a model while leaving the results essentially unchanged.

In **general equilibrium** and **structural estimation** applications, the model needs to be solved repeatedly, and it is therefore particularly beneficial to spend some time up front choosing the grids wisely.

In this notebook, I propose a method for choosing **tensor product grids in an adaptive manner**. Algorithms exist for choosing adaptive sparse grids, but they require changing the solution method to handle these more complex grids. This is not always possible, or might be cumbersome.

The main idea is to **remove grid points where the function of interest is linear**.
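
To see why linearity is the right criterion, consider dropping an interior point $x_i$ with neighbors $x_{i-1} < x_i < x_{i+1}$. Linear interpolation replaces $f(x_i)$ by $w f(x_{i-1}) + (1-w) f(x_{i+1})$ with $w = (x_{i+1}-x_i)/(x_{i+1}-x_{i-1})$, which is exact whenever $f$ is linear on $[x_{i-1},x_{i+1}]$. A toy check of this, with a hypothetical helper `drop_point_error` that is not part of the notebook or of consav:

```python
import numpy as np


def drop_point_error(f, x_lo, x_mid, x_hi):
    """Absolute error at x_mid when f(x_mid) is replaced by linear
    interpolation between the neighbors x_lo and x_hi."""
    w = (x_hi - x_mid) / (x_hi - x_lo)  # weight on the lower neighbor
    approx = w * f(x_lo) + (1 - w) * f(x_hi)
    return abs(f(x_mid) - approx)


linear = lambda x: 2.0 * x + 1.0
curved = lambda x: np.sqrt(x)

print(drop_point_error(linear, 1.0, 2.0, 4.0))  # zero up to rounding
print(drop_point_error(curved, 1.0, 2.0, 4.0))  # strictly positive
```

Only where the function bends does leaving out a point cost accuracy, so those are the points the algorithm keeps.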

The **algorithm** in words is:

1. Solve the model on a **fine tensor product grid** of your choice
2. **Initialize** the **adaptive tensor product grid** to the fine grid
3. For every *second* grid point (excluding the last) in each dimension, compute the **maximum absolute relative error from leaving this grid point out and using linear interpolation instead**. The error is computed using the result on the fine tensor product grid and interpolation across the values in the current adaptive tensor product grid. The maximum is taken across all combinations of the grid points in the remaining dimensions. Checking only every second point ensures that the two neighbors used for the interpolation are never removed in the same pass.
4. **Remove grid points** where the maximum absolute relative error is below some tolerance, and **update the adaptive tensor product grid** accordingly.
5. **Stop if** no grid points were removed
6. (*Optional*) **Re-solve the model** on the adaptive tensor product grid (otherwise use the values from the fine-grid solution).
7. Return to step 3

**Note I:** The above algorithm only requires solving the model once if step 6 is skipped.

**Note II:** Re-solving the model on the adaptive grid will not exactly reproduce the fine-grid values at the retained grid points. The optional step 6 accounts for this.

**Note III:** A termination criterion based on simulation results could be added.
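
Steps 2 to 5 and 7 (skipping the optional re-solve in step 6) can be sketched on a toy problem. This is a minimal NumPy sketch, not the implementation used below: the names `rel_err` and `adaptive_grid_indices` are hypothetical, and the 2D array `values[iz, ia]` stands in for the model solution, with axis 0 playing the role of all remaining dimensions over which the error is maximized. Where the true value is zero, the sketch falls back to the absolute error.

```python
import numpy as np


def rel_err(true, approx):
    """Absolute relative error; falls back to the absolute error where true == 0."""
    denom = np.where(true != 0, np.abs(true), 1.0)
    return np.abs(true - approx) / denom


def adaptive_grid_indices(grid, values, tol, max_it=100):
    """Return indices into the fine `grid` that make up the adaptive grid.

    `values` has shape (Nz, Na): the solution on the fine tensor product grid.
    Axis 1 is the dimension being thinned; axis 0 stands in for all remaining
    dimensions, over which the error is maximized.
    """
    idx = np.arange(grid.size)  # step 2: start from the fine grid
    for _ in range(max_it):
        keep = np.ones(idx.size, dtype=np.bool_)
        # step 3: every second interior point; its neighbors are never
        # removed in the same pass, so they stay valid interpolation nodes
        for j in range(1, idx.size - 1, 2):
            i, i_lo, i_hi = idx[j], idx[j - 1], idx[j + 1]
            w = (grid[i_hi] - grid[i]) / (grid[i_hi] - grid[i_lo])  # weight on lower neighbor
            approx = w * values[:, i_lo] + (1 - w) * values[:, i_hi]
            if np.max(rel_err(values[:, i], approx)) <= tol:
                keep[j] = False  # step 4: drop the point
        if keep.all():  # step 5: nothing removed, so stop
            break
        idx = idx[keep]  # step 7: repeat on the thinned grid
    return idx


# toy example: a kink at a = 0.5 and linear pieces on either side
grid = np.linspace(0.0, 1.0, 201)
z = np.array([0.5, 1.0, 2.0])
values = np.maximum(grid[None, :] - 0.5, 0.0) * z[:, None]
idx = adaptive_grid_indices(grid, values, tol=1e-4)
print(f"{idx.size} of {grid.size} points kept:", grid[idx])
```

In this toy case the loop ends with the two endpoints, the kink at $a = 0.5$, and at most one leftover point that sits at a position not checked in the final pass.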

# Setup

In [1]:
from consav.runtools import write_numba_config
write_numba_config(disable=0,threads=20,threading_layer='omp')

In [2]:
import os
import time
import numpy as np

import matplotlib.pyplot as plt
plt.style.use('seaborn-whitegrid')
colors = [x['color'] for x in plt.style.library['seaborn']['axes.prop_cycle']]
markers = ['s','P','D','v','^','*']

%load_ext autoreload
%autoreload 2

from consav.misc import elapsed
from consav import linear_interp
from TwoAssetModelCont import TwoAssetModelContClass

# Solve HANK on a fine grid in the $a$-dimension

In [3]:
t0 = time.time()
model = TwoAssetModelContClass(name='HANK',like_HANK=True)
print(f'model created in {elapsed(t0)}')

model created in 28.3 secs


In [4]:
# a. set fine grid in a dimension
model.par.Na = 200
model.par.KFEtol = 1e-8 # with the larger (original) KFEtol, it won't converge
# b. solve for fine grid
model.solve()
# c. save model solution for plotting later
model.calculate_moments(do_MPC=True)
model_fine = model.copy()

Grids created in 5.2 secs
Solution prepared in 5.8 secs
Solving HJB:
    1: 84.1700730791792466
    2: 38.2038971698338514
    3: 31.1125314480001975
    4: 30.4469082546748666


KeyboardInterrupt: 

# Adaptive tensor product grid algorithm

In [None]:
tol = 1e-4

par = model.par
sol = model.sol

# fine indices
grid_a = par.grid_a
grid_ia = np.arange(par.Na,dtype=np.int64)
grid_ia_approx = np.arange(par.Na,dtype=np.int64)
sol_true_s = sol.s.copy()
sol_approx_s = sol.s.copy()
sol_true_d = sol.d.copy()
sol_approx_d = sol.d.copy()
sol_true_d_adj = sol.d_adj.copy()
sol_approx_d_adj = sol.d_adj.copy()

In [None]:
# iterate
it = 0
while it < 100:
    
    # a. number of grid points
    Na_approx = grid_ia_approx.size
    
    # b. indicators for keeping each grid point (True = keep)
    Ia_approx = np.ones(Na_approx,dtype=np.bool_)

    # c. check every second point in the a dimension
    for ia_ in range(1,Na_approx,2):
        
        if ia_ == Na_approx-1: continue # last point has no upper neighbor, interpolation not possible
        
        # i. solution at fine
        ia = grid_ia_approx[ia_]
        true_s = sol_true_s[:,ia,:]
        true_d = sol_true_d[:,ia,:]
        true_d_adj = sol_true_d_adj[:,ia,:]
        
        # ii. neighbors in adaptive
        ia_u = grid_ia_approx[ia_+1]
        ia_d = grid_ia_approx[ia_-1]
        
        # iii. linear interpolation (w is the weight on the lower neighbor)
        w = (grid_a[ia_u]-grid_a[ia])/(grid_a[ia_u]-grid_a[ia_d])
        approx_s = w*sol_true_s[:,ia_d,:] + (1-w)*sol_true_s[:,ia_u,:]
        approx_d = w*sol_true_d[:,ia_d,:] + (1-w)*sol_true_d[:,ia_u,:]
        approx_d_adj = w*sol_true_d_adj[:,ia_d,:] + (1-w)*sol_true_d_adj[:,ia_u,:]
        
        # iv. maximum absolute relative error (absolute error where the true value is zero)
        max_abs_error_s = np.max(np.abs(true_s-approx_s)/np.where(true_s != 0,np.abs(true_s),1.0))
        max_abs_error_d = np.max(np.abs(true_d-approx_d)/np.where(true_d != 0,np.abs(true_d),1.0))
        max_abs_error_d_adj = np.max(np.abs(true_d_adj-approx_d_adj)/np.where(true_d_adj != 0,np.abs(true_d_adj),1.0))
        max_abs_error = max(max_abs_error_s,max_abs_error_d,max_abs_error_d_adj)
        Ia_approx[ia_] = max_abs_error > tol # keep the point only if leaving it out is too inaccurate
    
    # d. update adaptive
    grid_ia_approx = grid_ia_approx[Ia_approx]
    
    sol_approx_s = sol_approx_s[:,Ia_approx,:]
    sol_approx_d = sol_approx_d[:,Ia_approx,:]
    sol_approx_d_adj = sol_approx_d_adj[:,Ia_approx,:]

    # e. check
    removed = np.sum(~Ia_approx)
    share = grid_ia_approx.size/grid_a.size
    print(f'{it}: {removed:6d} grid points removed, share of grid points remaining {share:.4f}')
    
    if removed == 0: break
    it += 1

# construct adaptive grid
grid_a_approx = grid_a[grid_ia_approx]

**Effectiveness:** Roughly 15% of the grid points in the $a$-dimension are removed.

In [None]:
print(f'grid points in the a-dimension: {grid_a_approx.size} [fine: {grid_a.size}]')

# Approximation error

In [None]:
from consav import linear_interp

# absolute relative error (absolute error where the true value is zero)
def rel_abs_error(true,approx):
    return abs(true-approx)/(abs(true) if true != 0 else 1.0)

# maximum error across all points of the fine grid
max_abs_error_s = max_abs_error_d = max_abs_error_d_adj = 0.0
for iz,z in enumerate(par.grid_z):
    for ia,a in enumerate(par.grid_a):
        for ib,b in enumerate(par.grid_b):
            interp_s = linear_interp.interp_3d(par.grid_z,grid_a_approx,par.grid_b,sol_approx_s,z,a,b)
            interp_d = linear_interp.interp_3d(par.grid_z,grid_a_approx,par.grid_b,sol_approx_d,z,a,b)
            interp_d_adj = linear_interp.interp_3d(par.grid_z,grid_a_approx,par.grid_b,sol_approx_d_adj,z,a,b)
            true_s = sol_true_s[iz,ia,ib]
            true_d = sol_true_d[iz,ia,ib]
            true_d_adj = sol_true_d_adj[iz,ia,ib]
            max_abs_error_s = np.fmax(max_abs_error_s,rel_abs_error(true_s,interp_s))
            max_abs_error_d = np.fmax(max_abs_error_d,rel_abs_error(true_d,interp_d))
            max_abs_error_d_adj = np.fmax(max_abs_error_d_adj,rel_abs_error(true_d_adj,interp_d_adj))

max_abs_error = max(max_abs_error_s,max_abs_error_d,max_abs_error_d_adj)
print(f'max_rel_abs_error = {max_abs_error:.8f}')

**Precision:** The maximum absolute relative error is somewhat below the tolerance of 1e-4.
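
Because only the $a$-dimension is thinned while the $z$- and $b$-grids are left unchanged, evaluating the trilinear interpolant at fine-grid nodes reduces to 1D linear interpolation along $a$. A possibly cheaper cross-check is sketched below with NumPy's `np.interp`. The helper name is hypothetical, and the `(Nz, Na, Nb)` array layout is an assumption based on how `sol_true_s[iz,ia,ib]` is indexed in the cell above.

```python
import numpy as np


def max_rel_error_along_axis(grid_fine, idx_adapt, values_fine):
    """Maximum absolute relative error over all fine-grid points when
    `values_fine` (shape (Nz, Na, Nb)) is approximated by linear interpolation
    along axis 1 using only the adaptive nodes grid_fine[idx_adapt]."""
    grid_adapt = grid_fine[idx_adapt]  # increasing, since idx_adapt is sorted
    Nz, _, Nb = values_fine.shape
    max_err = 0.0
    for iz in range(Nz):
        for ib in range(Nb):
            true = values_fine[iz, :, ib]
            approx = np.interp(grid_fine, grid_adapt, true[idx_adapt])
            denom = np.where(true != 0, np.abs(true), 1.0)  # absolute error where true == 0
            max_err = max(max_err, float(np.max(np.abs(true - approx) / denom)))
    return max_err


# toy check: a quadratic in the middle dimension
grid = np.linspace(0.0, 1.0, 11)
values = np.ones((2, 1, 3)) * (grid**2)[None, :, None]
print(max_rel_error_along_axis(grid, np.arange(grid.size), values))  # all nodes kept
print(max_rel_error_along_axis(grid, np.array([0, 5, 10]), values))  # coarse grid
```

Applied to, e.g., `sol_true_s` with `grid_ia_approx`, it targets the same quantity as the `s` part of the triple loop above, but with one vectorized `np.interp` call per $(z,b)$ pair instead of one `interp_3d` call per fine-grid point.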

# Figures

**Re-solve the model on the small grid**

In [None]:
# a. set reduced grid in a dimension
model.par.Na = grid_a_approx.size
# b. solve on the reduced (adaptive) grid
model.solve(load_grid=True,grid_a=grid_a_approx)
# c. save for plotting
model.calculate_moments(do_MPC=True)
model_small = model.copy()

**Plot comparison**

In [None]:
fig = plt.figure(figsize=(12,6))

ax = fig.add_subplot(2,2,1)
ax.set_xscale('symlog')
ax.plot(model_fine.par.grid_a*100,model_fine.moms['a_margcum'])
ax.plot(model_small.par.grid_a*100,model_small.moms['a_margcum'])
ax.set_xlabel(r'$100 \cdot a_t$')
ax.set_ylabel('CDF')
ax.set_ylim([-0.01,1.01])
ax.legend(['Fine','Small'],frameon=True)

ax = fig.add_subplot(2,2,2)
ax.set_xscale('symlog')
ax.plot(model_fine.par.grid_b*100,model_fine.moms['b_margcum'])
ax.plot(model_small.par.grid_b*100,model_small.moms['b_margcum'])
ax.set_xlabel(r'$100 \cdot b_t$')
ax.set_ylabel('CDF')
ax.set_ylim([-0.01,1.01])
ax.legend(['Fine','Small'],frameon=True)

print(f"AY with a fine grid: {model_fine.moms['AY']:.8f}")
print(f"AY with a small grid: {model_small.moms['AY']:.8f}")

print(f"Avg. MPC with a fine grid: {model_fine.moms['MPC']:.8f}")
print(f"Avg. MPC with a small grid: {model_small.moms['MPC']:.8f}")
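
The figures compare the CDFs visually. To summarize the gap in one number, one could interpolate the small-grid CDF onto the fine grid and take the largest absolute difference, a Kolmogorov-Smirnov-type distance. The sketch below does this with `np.interp`; the helper name is hypothetical, and treating the marginal CDFs as piecewise linear between grid points is an assumption.

```python
import numpy as np


def max_cdf_gap(grid_fine, cdf_fine, grid_small, cdf_small):
    """Largest absolute gap between two CDFs after linearly interpolating the
    small-grid CDF onto the fine grid (a Kolmogorov-Smirnov-type distance)."""
    cdf_small_on_fine = np.interp(grid_fine, grid_small, cdf_small)
    return float(np.max(np.abs(cdf_fine - cdf_small_on_fine)))


# toy check with made-up CDF values
grid_fine = np.linspace(0.0, 1.0, 5)
cdf_fine = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
grid_small = np.array([0.0, 0.5, 1.0])
cdf_small = np.array([0.0, 0.6, 1.0])
print(max_cdf_gap(grid_fine, cdf_fine, grid_small, cdf_small))
```

For the $a$-dimension, assuming the arrays used in the plotting cell are 1D, the call would look like `max_cdf_gap(model_fine.par.grid_a, model_fine.moms['a_margcum'], model_small.par.grid_a, model_small.moms['a_margcum'])`.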