Minimal shocks retirement model
===========================

This notebook presents the step by step solution for a backward induction problem with a discrete and continuout choice variable for one period based on model_retirement.m

The comparison file for verification of the correct output of each step is retirement_minimal_shocks.m

In [1]:
import numpy as np
import math
import scipy.stats as scps
import matplotlib.pyplot as plt
import pickle

In [2]:
%load_ext autoreload

%autoreload 2

import numpy as np
from collections import namedtuple
from scipy.interpolate import InterpolatedUnivariateSpline
from scipy.stats import norm
from copy import *
from numpy.matlib import * 
from scipy.optimize import *
from dc_egm import interpolate, chop, upper_envelope,diff
from copy import *
import scipy.interpolate as scin


In [3]:
# There potentially is a pythonic solution that would make this function obsolete
def quadrature(n, lbnd, ubnd):
    
    x1 = lbnd
    x2 = ubnd
    x = np.zeros(n)
    w = x
    EPS = 3e-14
    m = int(round((n+EPS)/2)) # flor function in matlab, rounding to the lower integer
    xm = (x2+x1)/2
    xl = (x2-x1)/2
    z1 = 1e99

    x = np.full(n+1, np.nan)
    w = np.full(n+1, np.nan)

    i = 1

    while i <= m:

        z = math.cos(math.pi*(i - 0.25)/(n + 0.5))

        while abs(z - z1) > EPS:
            p1 = 1
            p2 = 0
            j = 1

            while j <= n:
                p3 = p2
                p2 = p1
                p1 = ((2*j -1)*z*p2 - (j-1)*p3)/j
                j += 1

            pp = n*(z*p1 - p2)/(z*z - 1)
            z1 = z
            z = z1 - p1/pp

        x[i] = xm - xl*z
        x[n + 1 - i] = xm + xl*z
        w[i] = 2*xl/((1-z*z)*pp*pp)
        w[n + 1 - i] = w[i]
        i += 1

    x = x[1:]
    w = w[1:]

    return x, w

In [81]:
# Model parameters (default)

Tbar = 25 # number of periods (fist period is t=1) 
ngridm = 500 # number of grid points over assets
mmax = 50 # maximum level of assets
expn = 5 # number of quadrature points used in calculation of expectations
nsims = 10 # number of simulations
init = [10, 30] # interval of the initial wealth
r = 0.05 # interest rate
df = 0.95 # discount factor
sigma = 0.25 # sigma parameter in income shocks
duw = 0.35 #disutility of work
theta = 1.95 # CRRA coefficient (log utility if ==1)
inc0 = 0.75 # income equation: constant
inc1 = 0.04 # income equation: age coef
inc2 = 0.0002 # income equation: age^2 coef
cfloor =0.001 # consumption floor (safety net in retirement)
lambda_ = 0.02 # scale of the EV taste shocks 

In [92]:
# Functions: utility and budget constraint

def util(consumption, working):
    """CRRA utility"""
    
    u = (consumption**(1-theta)-1)/(1-theta)
    u = u - duw*(working)
    
    return u

def mutil(consumption):
    """Marginal utility CRRA"""
    
    mu = consumption**(-theta)
    
    return mu

def imutil(mutil):
    """Inverse marginal utility CRRA
    Consumption as a function of marginal utility"""
    
    cons = mutil**(-1/theta)
    
    return cons


def income(it, shock):
    """Income in period it given normal shock"""
    
    age = it + 20 # matlab strats counting at 1, Python at zero
    print(inc0 + inc1*age - inc2*age**2)
    w = np.exp(inc0 + inc1*age - inc2*age**2 + shock)
    
    return w


def budget(it, savings, shocks, working):
    """Wealth, M_{t+1} in period t+1, where it == t
    
    Arguments
    ---------
        savings: np.array of savings with length ngridm
        shocks: np.array of shocks with length expn
    
    Returns
    -------
        w1: matrix with dimension (expn, ngridm) of all possible
    next period wealths
    """
    
    w1 = np.full((ngridm, expn), income(it, shocks)*working).T + np.full((expn, ngridm), savings*(1+r))
    
    return w1

def mbudget():
    """Marginal budget:
    Derivative of budget with respect to savings"""
    
    mw1 = np.full((expn, ngridm), (1+r))
    
    return mw1

In [93]:
# Value function for worker
# interpolate and extrapolate are potentially substitutable by the interpolate function below

def value_function(working, it, x):
    """Value function calculation for the """
    
    x = x.flatten('F')
    
    res = np.full(x.shape, np.nan)
    
    # Mark constrained region
    mask = x < value[1, 0, working, it] # credit constraint between 1st (M_{t+1) = 0) and second point (A_{t+1} = 0)
    
    # Calculate t+1 value function in the constrained region
    res[mask] = util(x[mask], working) + df*value[0, 1, working, it]
    
    # Calculate t+1 value function in non-constrained region
    # interpolate
    res[~mask] = np.interp(x[~mask], value[:, 0, working, it], value[:, 1, working, it])
    # extrapolate
    slope = (value[-2, 1, working, it] - value[-1, 1, working, it])/(value[-2, 0, working, it] - value[-1, 0, working, it])
    intercept = value[-1, 1, working, it] - value[-1, 0, working, it]*slope
    res[res == np.max(value[:, 1, working, it])] = intercept + slope*x[res == np.max(value[:, 1, working, it])]

    return res

In [94]:
# Calculation of probability to choose work, if a worker today
def chpr(x):
    """Calculate the probability of choosing work in t+1
    for state worker given t+1 value functions"""
    
    mx = np.amax(x, axis = 0)
    mxx = x - mx
    res = np.exp(mxx[1, :]/lambda_)/np.sum(np.exp(mxx/lambda_), axis = 0)
    
    return res

In [95]:
# Expected value function calculation in state worker
def logsum(x):
    """Calculate expected value function"""
    
    mx = np.amax(x, axis = 0)
    mxx = x - mx
    res = mx + lambda_*np.log(np.sum(np.exp(mxx/lambda_), axis = 0))
    
    return res

m0 parametrisation - minimal shocks
--------------------------------------------

In [96]:
# Minimal shocks
sigma = 0
lambda_ = 2.2204e-16

In [97]:
# Initialize grids
quadp, quadw = quadrature(expn,0,1)
quadstnorm = scps.norm.ppf(quadp)
savingsgrid = np.linspace(0, mmax, ngridm)

In [98]:
# Initialize containers

# Container for endogenous gridpoints of (beginning-of-period) assets
# and corresponding consumption
policy = np.full((ngridm + 1, 2, 2, Tbar), np.nan)

# Value functions
value = np.full((ngridm + 1, 2, 2, Tbar), np.nan)

In [99]:
# Handling of last period and first elements
# policy
policy[1:, 0, 0, Tbar-1] = savingsgrid
policy[1:, 0, 1, Tbar-1] = savingsgrid
policy[1:, 1, :, Tbar-1] = policy[1:, 0, :, Tbar-1]
policy[0, :, :, :] = 0.00

In [100]:
# value
value[2:, 0, :, Tbar-1] = util(policy[2:, 0, :, Tbar-1], 0)
value[2:, 1, :, Tbar-1] = util(policy[2:, 0, :, Tbar-1], 1)
value[0:2, :, :, Tbar -1] = 0.00
value[0, 0, :, :] = 0.00

In [101]:
# Solve workers problem with EGM for period T-1, T-2 and T-3
# The EGM step already yields the same result as the matlab code for T-1 and T-2
# Difference in result for T-3 => DC step has to be performed after the EGM step
for period in [23]:
    
    for choice in[0, 1]:
        # M_{t+1}
        print(quadstnorm*sigma)
        wk1 = budget(period, savingsgrid, quadstnorm*sigma, choice)
        wk1[wk1 < cfloor] = cfloor
        print(wk1)
        # Value function
        vl1 = np.full((2, ngridm * expn), np.nan)

        if period + 1 == Tbar - 1:
            vl1[0, :] = util(wk1, 0).flatten('F')
            vl1[1, :] = util(wk1, 1).flatten('F')
        else:
            vl1[1, :] = value_function(1, period + 1, wk1) # value function in t+1 if choice in t+1 is work
            vl1[0, :] = value_function(0, period + 1, wk1) # value function in t+1 if choice in t+1 is retiree

        # Probability of choosing work in t+1
        if choice == 0:
            # Probability of choosing work in t+1
            pr1 = np.full(2500, 0.00)
        else:
            pr1 = chpr(vl1)

        # Next period consumption based on interpolation and extrapolation
        # given grid points and associated consumption
        cons10 = np.interp(wk1, policy[:, 0, 0, period + 1], policy[:, 1, 0, period+1])
        # extrapolate linearly right of max grid point
        slope = (policy[-2, 1, 0, period + 1] - policy[-1, 1, 0, period + 1])/(policy[-2, 0, 0, period + 1] - policy[-1, 0, 0, period + 1])
        intercept = policy[-1, 1, 0, period + 1] - policy[-1, 0, 0, period + 1]*slope
        cons10[cons10 == np.max(policy[:, 1, 0, period+1])] = intercept + slope*wk1[cons10 == np.max(policy[:, 1, 0, period+1])]
        cons10_flat = cons10.flatten('F')

        cons11 = np.interp(wk1, policy[:, 0, 1, period + 1], policy[:, 1, 1, period+1])
        # extrapolate linearly right of max grid point
        slope = (policy[-2, 1, 1, period + 1] - policy[-1, 1, 1, period + 1])/(policy[-2, 0, 1, period + 1] - policy[-1, 0, 1, period + 1])
        intercept = policy[-1, 1, 1, period + 1] - policy[-1, 0, 1, period + 1]*slope
        cons11[cons11 == np.max(policy[:, 1, 1, period+1])] = intercept + slope*wk1[cons11 == np.max(policy[:, 1, 1, period+1])]
        cons11_flat = cons11.flatten('F')

        # Marginal utility of expected consumption next period
        mu1 = pr1*mutil(cons11_flat) + (1 - pr1)*mutil(cons10_flat)

        # Marginal budget
        # Note: Constant for this model formulation (1+r)
        mwk1 = mbudget()

        # RHS of Euler eq., p 337, integrate out error of y
        rhs = np.dot(quadw.T, np.multiply(mu1.reshape(wk1.shape, order = 'F'), mwk1))
        # Current period consumption from Euler equation
        
        cons0 = imutil(df*rhs)
        # Update containers related to consumption
        policy[1:, 1, choice, period] = cons0
        policy[1:, 0, choice, period] = savingsgrid + cons0


        if choice == 1:
            # Calculate continuation value
            ev = np.dot(quadw.T, logsum(vl1).reshape(wk1.shape, order = 'F'))
        else:
            ev = np.dot(quadw.T, vl1[0, :].reshape(wk1.shape, order = 'F'))

        # Update value function related containers
        value[1:, 1, choice, period] = util(cons0, choice) + df*ev
        value[1:, 0, choice, period] = savingsgrid + cons0
        value[0, 1, choice, period] = ev[0]

[-0. -0.  0.  0.  0.]
2.1001999999999996
[[1.00000000e-03 1.05210421e-01 2.10420842e-01 ... 5.22895792e+01
  5.23947896e+01 5.25000000e+01]
 [1.00000000e-03 1.05210421e-01 2.10420842e-01 ... 5.22895792e+01
  5.23947896e+01 5.25000000e+01]
 [1.00000000e-03 1.05210421e-01 2.10420842e-01 ... 5.22895792e+01
  5.23947896e+01 5.25000000e+01]
 [1.00000000e-03 1.05210421e-01 2.10420842e-01 ... 5.22895792e+01
  5.23947896e+01 5.25000000e+01]
 [1.00000000e-03 1.05210421e-01 2.10420842e-01 ... 5.22895792e+01
  5.23947896e+01 5.25000000e+01]]
[-0. -0.  0.  0.  0.]
2.1001999999999996
[[ 8.16780331  8.27301373  8.37822415 ... 60.45738247 60.56259289
  60.66780331]
 [ 8.16780331  8.27301373  8.37822415 ... 60.45738247 60.56259289
  60.66780331]
 [ 8.16780331  8.27301373  8.37822415 ... 60.45738247 60.56259289
  60.66780331]
 [ 8.16780331  8.27301373  8.37822415 ... 60.45738247 60.56259289
  60.66780331]
 [ 8.16780331  8.27301373  8.37822415 ... 60.45738247 60.56259289
  60.66780331]]


In [None]:
with open('m0_value.pkl', 'rb') as file : 
    m0_value = pickle.load(file)

with open('m0_policy.pkl', 'rb') as file : 
    m0_policy = pickle.load(file)

In [None]:
# Verify that EGM already yields correct solution for t-1 and t-2 in value
np.testing.assert_almost_equal(m0_value[0:501, 1, 1, 22], value[:, 1, 1, 22])

In [None]:
np.testing.assert_almost_equal(m0_value[0:501, 1, 0, 22], value[:, 1, 0, 22])

In [None]:
# Verify that EGM already yields correct solution for t-1 and t-2 in policy
np.testing.assert_almost_equal(m0_policy[0:501, 1, 1, 22], policy[:, 1, 1, 22])

In [None]:
np.testing.assert_almost_equal(m0_policy[0:501, 1, 0, 22], policy[:, 1, 0, 22])

In [None]:
# Difference in t-3
value[0:100, 1, 0, 21]

In [None]:
m0_value[:100, 1, 0, 21]

In [None]:
policy[:, 0, 1, period]

Start secondary evelope
-----------------------------

To Do: Secondary envelope as a function and not line-by-line
To Do: Handling of discontinuity in the credit constrained region, retirement_model.m 137-148

In [None]:
def secondary_envelope(obj):
    result = []
    newdots = []
    index_removed = []

    for k in range(obj.shape[0]):
        sect = []
        cur = deepcopy(obj[k])
        # Find discontinutiy
        ii = cur[0][1:]>cur[0][:-1]
        # Substitute for matlab while true loop
        i=1
        while_operator = True
        while while_operator:
            j = np.where([ii[counter] != ii[0] for counter in range(len(ii))])[0]
            if len(j) == 0:
                if i >1:
                    sect += [cur]
                while_operator=False
            else:
                j = min(j)

                sect_container, cur = chop(cur, j, True)
                sect += [sect_container]
                ii = ii[j:]
                i += 1
        # yes we can use np.sort instead of the pre-specified function from the upper envelope notebook
        if len(sect) > 1:
            sect = [np.sort(i) for i in sect]
            result_container, newdots_container = upper_envelope(sect, True, True)
            index_removed_container = diff(obj[k], result_container, 10)
        else:
            result_container = obj[k]
            index_removed_container = np.array([])
            newdots_container = np.stack([np.array([]), np.array([])])
        
        result += [result_container]
        newdots += [newdots_container]
        index_removed += [index_removed_container]
        
    return result, newdots, index_removed


In [None]:
for point in [21]:
    obj = np.stack([value[1:, :, 1, point].T, value[1:, :, 0, point].T])
    r1, new, rem = secondary_envelope(obj)
    # Values are now equal 
    np.testing.assert_almost_equal(m0_value[0:501, 0, 1, point].T[~np.isnan(m0_value[0:501, :, 1, point].T[0])][1:], r1[0][0])
    np.testing.assert_almost_equal(m0_value[0:501, 1, 1, point].T[~np.isnan(m0_value[0:501, :, 1, point].T[0])][1:], r1[0][1])
    np.testing.assert_almost_equal(m0_value[0:501, 0, 0, point].T[~np.isnan(m0_value[0:501, :, 0, point].T[0])][1:], r1[1][0])
    np.testing.assert_almost_equal(m0_value[0:501, 1, 0, point].T[~np.isnan(m0_value[0:501, :, 0, point].T[0])][1:], r1[1][1])

In [None]:
value[:, 0, 1, 21] = np.append(np.append(np.full(18, np.nan),np.array([0.0])) , r1[0][0])
value[:, 1, 1, 21] = np.append(np.append(np.full(18, np.nan),value[0, 1, 1, 21]) , r1[0][1])


In [None]:
# Solve workers problem with EGM for period T-1, T-2 and T-3
# The EGM step already yields the same result as the matlab code for T-1 and T-2
# Difference in result for T-3 => DC step has to be performed after the EGM step
for period in [20]:
    
    for choice in[0, 1]:
        # M_{t+1}
        wk1 = budget(period, savingsgrid, quadstnorm*sigma, choice)
        wk1[wk1 < cfloor] = cfloor

        # Value function
        vl1 = np.full((2, ngridm * expn), np.nan)

        if period + 1 == Tbar - 1:
            vl1[0, :] = util(wk1, 0).flatten('F')
            vl1[1, :] = util(wk1, 1).flatten('F')
        else:
            vl1[1, :] = value_function(1, period + 1, wk1) # value function in t+1 if choice in t+1 is work
            vl1[0, :] = value_function(0, period + 1, wk1) # value function in t+1 if choice in t+1 is retiree

        # Probability of choosing work in t+1
        if choice == 0:
            # Probability of choosing work in t+1
            pr1 = np.full(2500, 0.00)
        else:
            pr1 = chpr(vl1)

        # Next period consumption based on interpolation and extrapolation
        # given grid points and associated consumption
        cons10 = np.interp(wk1, policy[:, 0, 0, period + 1], policy[:, 1, 0, period+1])
        # extrapolate linearly right of max grid point
        slope = (policy[-2, 1, 0, period + 1] - policy[-1, 1, 0, period + 1])/(policy[-2, 0, 0, period + 1] - policy[-1, 0, 0, period + 1])
        intercept = policy[-1, 1, 0, period + 1] - policy[-1, 0, 0, period + 1]*slope
        cons10[cons10 == np.max(policy[:, 1, 0, period+1])] = intercept + slope*wk1[cons10 == np.max(policy[:, 1, 0, period+1])]
        cons10_flat = cons10.flatten('F')

        cons11 = np.interp(wk1, policy[:, 0, 1, period + 1], policy[:, 1, 1, period+1])
        # extrapolate linearly right of max grid point
        slope = (policy[-2, 1, 1, period + 1] - policy[-1, 1, 1, period + 1])/(policy[-2, 0, 1, period + 1] - policy[-1, 0, 1, period + 1])
        intercept = policy[-1, 1, 1, period + 1] - policy[-1, 0, 1, period + 1]*slope
        cons11[cons11 == np.max(policy[:, 1, 1, period+1])] = intercept + slope*wk1[cons11 == np.max(policy[:, 1, 1, period+1])]
        cons11_flat = cons11.flatten('F')

        # Marginal utility of expected consumption next period
        mu1 = pr1*mutil(cons11_flat) + (1 - pr1)*mutil(cons10_flat)

        # Marginal budget
        # Note: Constant for this model formulation (1+r)
        mwk1 = mbudget()

        # RHS of Euler eq., p 337, integrate out error of y
        rhs = np.dot(quadw.T, np.multiply(mu1.reshape(wk1.shape, order = 'F'), mwk1))
        # Current period consumption from Euler equation
        
        cons0 = imutil(df*rhs)
        # Update containers related to consumption
        policy[1:, 1, choice, period] = cons0
        policy[1:, 0, choice, period] = savingsgrid + cons0


        if choice == 1:
            # Calculate continuation value
            ev = np.dot(quadw.T, logsum(vl1).reshape(wk1.shape, order = 'F'))
        else:
            ev = np.dot(quadw.T, vl1[0, :].reshape(wk1.shape, order = 'F'))

        # Update value function related containers
        value[1:, 1, choice, period] = util(cons0, choice) + df*ev
        value[1:, 0, choice, period] = savingsgrid + cons0
        value[0, 1, choice, period] = ev[0]

In [None]:
for point in [20]:
    obj = np.stack([value[1:, :, 1, point].T, value[1:, :, 0, point].T])
    r1, new, rem = secondary_envelope(obj)


In [None]:
r1[0][0][:10], m0_value[0:501, 0, 1, 21].T[~np.isnan(m0_value[0:501, :, 1, 21].T[0])][1:11]

In [None]:
import scipy.interpolate as scin
interpolation = scin.interp1d(r1[0][0], r1[0][1], bounds_error=False, fill_value='extrapolate')
interpolation(0.0)

In [None]:
len(r1[0][0])

In [None]:
np.append(np.full(18, np.nan), r1[0][0])

In [None]:
# Discontinuity in values
obj[0, 0, 24:50]

In [None]:
range(len(ii))

In [None]:
i = 1
sect = []

In [None]:
len(ii)

In [None]:
# Explore wether one can simply use numpy for sorting insted of the translated matlab function
sect[1] = np.sort(sect[1])
sect[1]

Upper envelope

In [None]:
# aux_function used in upper_evelope below
def aux_function(x, obj1, obj2):
    x = [x]
    value, extr = np.subtract(interpolate(x,obj1), interpolate(x, obj2))
    return value

In [None]:
# interpolate used in upper_evelope below
def interpolate(xx, obj, one=False):    
    if not one:
        interpolation = InterpolatedUnivariateSpline(obj[0], obj[1], k=1)
        container = interpolation(xx)
        extrapolate = [True if (i>max(obj[0])) |(i<min(obj[0])) else False for i in xx]
    else:
        container = []
        extrapolate = []
        
        for poly  in obj:
            interpolation = InterpolatedUnivariateSpline(poly[0], poly[1], k=1)
            container += [interpolation(xx)]
            extrapolate += [np.array([True if (i>max(poly[0])) |(i<min(poly[0])) else False for i in xx])]
    return container, extrapolate

In [None]:
# Perform upper envelope calculation
result_container, newdots_container = upper_envelope(sect, True, True)

In [None]:
# Array of same length as MatLab code
# Same number of points removed
result_container[0].shape

In [None]:
# Verify result for x values of value after upper envelope
# Same points removed as in MatLab code
np.testing.assert_array_almost_equal(m0_value[19:501, 1, 0, 21], result_container[0])

In [None]:
# Verify result for y values of value after upper envelope
# Same points removed as in MatLab code
np.testing.assert_array_almost_equal(m0_value[19:501, 0, 0, 21], result_container[1])

In [None]:
# Verify output of result_inter
# One new point added, one intersection, same as MatLab code
newdots_container

Finish up secondary envelope

diff(obj, result_container)

!!! Function not working with this input

In [None]:
# Find indexes of missing elements
missing_elements = np.setdiff1d(obj[0], result_container[0])
missing_elements

In [None]:
indexremoved = []

for value in missing_elements:
    indexremoved.append(obj[0].tolist().index(value))

indexremoved = np.array(sort(indexremoved))

In [None]:
# Same indexes of removed points as in MatLab code
indexremoved

Back to retirement_model.m solve_dsegm line 149ff.

In [None]:
len(indexremoved) > 0

In [None]:
# All points below
# Note what MatLab function find is doing
# If one simply does "<" in Python, result would be wrong
# Current workaround might not be robust to all cases - find a better wat to pythonise MatLabs find function
j = arange(0, (np.where(policy[:, 0, 1, 21] > newdots_container[0][0]))[0][0])
j

In [None]:
# Points that were not deleted
j_new = np.setdiff1d(j, indexremoved)
j_new

In [None]:
j = max(j_new)
j

In [None]:
# Potentially a better way of setting j exists, such that indexes here are simplified
new_left = np.interp(newdots_container[0][0], policy[j+1:j+3, 0, 1, 21], policy[j+1:j+3, 1, 1, 21])
new_left

In [None]:
# Perform similar operation for the upper/right side
# All comments from above apply here
j = np.arange(np.where(policy[:, 0, 1, 21] < newdots_container[0][0])[0][-1], policy.shape[0]+1)
j

In [None]:
j_new = np.setdiff1d(j, indexremoved)
j_new

In [None]:
j = min(j_new)
j

In [None]:
new_right = np.interp(newdots_container[0][0], policy[j:j+2, 0, 1, period], policy[j:j+2, 1, 1, period])
new_right

In [None]:
# Remove inferior points from policy
# Means: Remove all points with indexes in indexremoved
policy_thinout_x = policy[:, 0, 1, period].tolist()
policy_thinout_y = policy[:, 1, 1, period].tolist()

In [None]:
del policy_thinout_x[indexremoved[0]+1 : indexremoved[-1]+2]
del policy_thinout_y[indexremoved[0]+1 : indexremoved[-1]+2]

In [None]:
# Add new point twice
policy_thinout_x.append(newdots_container[0][0] - 1e3*2.2204e-16)
policy_thinout_x.append(newdots_container[0][0])

In [None]:
# Add new point twice
policy_thinout_y.append(new_left)
policy_thinout_y.append(new_right)

In [None]:
len(policy_thinout_x)

In [None]:
new_policy = np.full((2, ngridm + 1), np.nan)
new_policy[0, - len(policy_thinout_x):] = np.array(policy_thinout_x)
new_policy[1, - len(policy_thinout_y):] = np.array(policy_thinout_y)

In [None]:
# Ensure that points are in the right position
new_policy = np.sort(new_policy)

In [None]:
# Verify correct solution for t-3 in policy
np.testing.assert_almost_equal(m0_policy[17:501, 1, 0, 21], new_policy[0, 0:484])

In [None]:
# Finish up
policy[:, :, 1, period] = new_policy.T