Minimal shocks retirement model
===========================

This notebook presents the step by step solution for a backward induction problem with a discrete and continuout choice variable for one period based on model_retirement.m

The comparison file for verification of the correct output of each step is retirement_minimal_shocks.m

In [223]:
import numpy as np
import math
import scipy.stats as scps
import matplotlib.pyplot as plt
import pickle

In [224]:
%load_ext autoreload

%autoreload 2

import numpy as np
from collections import namedtuple
from scipy.interpolate import InterpolatedUnivariateSpline
from scipy.stats import norm
from copy import *
from numpy.matlib import * 
from scipy.optimize import *
from dc_egm import interpolate, chop, upper_envelope,diff
from copy import *
import scipy.interpolate as scin


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [225]:
# There potentially is a pythonic solution that would make this function obsolete
def quadrature(n, lbnd, ubnd):
    
    x1 = lbnd
    x2 = ubnd
    x = np.zeros(n)
    w = x
    EPS = 3e-14
    m = int(round((n+EPS)/2)) # flor function in matlab, rounding to the lower integer
    xm = (x2+x1)/2
    xl = (x2-x1)/2
    z1 = 1e99

    x = np.full(n+1, np.nan)
    w = np.full(n+1, np.nan)

    i = 1

    while i <= m:

        z = math.cos(math.pi*(i - 0.25)/(n + 0.5))

        while abs(z - z1) > EPS:
            p1 = 1
            p2 = 0
            j = 1

            while j <= n:
                p3 = p2
                p2 = p1
                p1 = ((2*j -1)*z*p2 - (j-1)*p3)/j
                j += 1

            pp = n*(z*p1 - p2)/(z*z - 1)
            z1 = z
            z = z1 - p1/pp

        x[i] = xm - xl*z
        x[n + 1 - i] = xm + xl*z
        w[i] = 2*xl/((1-z*z)*pp*pp)
        w[n + 1 - i] = w[i]
        i += 1

    x = x[1:]
    w = w[1:]

    return x, w

In [226]:
# Model parameters (default)

Tbar = 25 # number of periods (fist period is t=1) 
ngridm = 500 # number of grid points over assets
mmax = 50 # maximum level of assets
expn = 5 # number of quadrature points used in calculation of expectations
nsims = 10 # number of simulations
init = [10, 30] # interval of the initial wealth
r = 0.05 # interest rate
df = 0.95 # discount factor
sigma = 0.25 # sigma parameter in income shocks
duw = 0.35 #disutility of work
theta = 1.95 # CRRA coefficient (log utility if ==1)
inc0 = 0.75 # income equation: constant
inc1 = 0.04 # income equation: age coef
inc2 = 0.0002 # income equation: age^2 coef
cfloor =0.001 # consumption floor (safety net in retirement)
lambda_ = 0.02 # scale of the EV taste shocks 

In [227]:
# Functions: utility and budget constraint

def util(consumption, working):
    """CRRA utility"""
    
    u = (consumption**(1-theta)-1)/(1-theta)
    u = u - duw*(working)
    
    return u

def mutil(consumption):
    """Marginal utility CRRA"""
    
    mu = consumption**(-theta)
    
    return mu

def imutil(mutil):
    """Inverse marginal utility CRRA
    Consumption as a function of marginal utility"""
    
    cons = mutil**(-1/theta)
    
    return cons


def income(it, shock):
    """Income in period it given normal shock"""
    
    age = it + 20 # matlab strats counting at 1, Python at zero
    w = np.exp(inc0 + inc1*age - inc2*age**2 + shock)
    
    return w


def budget(it, savings, shocks, working):
    """Wealth, M_{t+1} in period t+1, where it == t
    
    Arguments
    ---------
        savings: np.array of savings with length ngridm
        shocks: np.array of shocks with length expn
    
    Returns
    -------
        w1: matrix with dimension (expn, ngridm) of all possible
    next period wealths
    """
    
    w1 = np.full((ngridm, expn), income(it + 1, shocks)*working).T + np.full((expn, ngridm), savings*(1+r))
    
    return w1

def mbudget():
    """Marginal budget:
    Derivative of budget with respect to savings"""
    
    mw1 = np.full((expn, ngridm), (1+r))
    
    return mw1

In [228]:
# Value function for worker
# interpolate and extrapolate are potentially substitutable by the interpolate function below

def value_function(working, it, x):
    """Value function calculation for the """
    
    x = x.flatten('F')
    
    res = np.full(x.shape, np.nan)
    
    # Mark constrained region
    mask = x < value[1, 0, working, it] # credit constraint between 1st (M_{t+1) = 0) and second point (A_{t+1} = 0)
    
    # Calculate t+1 value function in the constrained region
    res[mask] = util(x[mask], working) + df*value[0, 1, working, it]
    
    # Calculate t+1 value function in non-constrained region
    # interpolate
    res[~mask] = np.interp(x[~mask], value[:, 0, working, it], value[:, 1, working, it])
    # extrapolate
    slope = (value[-2, 1, working, it] - value[-1, 1, working, it])/(value[-2, 0, working, it] - value[-1, 0, working, it])
    intercept = value[-1, 1, working, it] - value[-1, 0, working, it]*slope
    res[res == np.max(value[:, 1, working, it])] = intercept + slope*x[res == np.max(value[:, 1, working, it])]

    return res

In [229]:
# Calculation of probability to choose work, if a worker today
def chpr(x):
    """Calculate the probability of choosing work in t+1
    for state worker given t+1 value functions"""
    
    mx = np.amax(x, axis = 0)
    mxx = x - mx
    res = np.exp(mxx[1, :]/lambda_)/np.sum(np.exp(mxx/lambda_), axis = 0)
    
    return res

In [230]:
# Expected value function calculation in state worker
def logsum(x):
    """Calculate expected value function"""
    
    mx = np.amax(x, axis = 0)
    mxx = x - mx
    res = mx + lambda_*np.log(np.sum(np.exp(mxx/lambda_), axis = 0))
    
    return res

m0 parametrisation - minimal shocks
--------------------------------------------

In [231]:
# Minimal shocks
sigma = 0
lambda_ = 2.2204e-16

In [232]:
# Initialize grids
quadp, quadw = quadrature(expn,0,1)
quadstnorm = scps.norm.ppf(quadp)
savingsgrid = np.linspace(0, mmax, ngridm)

In [233]:
# Initialize containers

# Container for endogenous gridpoints of (beginning-of-period) assets
# and corresponding consumption
policy = np.full((ngridm + 1, 2, 2, Tbar), np.nan)

# Value functions
value = np.full((ngridm + 1, 2, 2, Tbar), np.nan)

In [234]:
# Handling of last period and first elements
# policy
policy[1:, 0, 0, Tbar-1] = savingsgrid
policy[1:, 0, 1, Tbar-1] = savingsgrid
policy[1:, 1, :, Tbar-1] = policy[1:, 0, :, Tbar-1]
policy[0, :, :, :] = 0.00

In [235]:
# value
value[2:, 0, :, Tbar-1] = util(policy[2:, 0, :, Tbar-1], 0)
value[2:, 1, :, Tbar-1] = util(policy[2:, 0, :, Tbar-1], 1)
value[0:2, :, :, Tbar -1] = 0.00
value[0, 0, :, :] = 0.00

In [236]:
# Solve workers problem with EGM for period T-1, T-2 and T-3
# The EGM step already yields the same result as the matlab code for T-1 and T-2
# Difference in result for T-3 => DC step has to be performed after the EGM step
for period in [23, 22, 21]:
    
    for choice in[0, 1]:
        # M_{t+1}
        wk1 = budget(period, savingsgrid, quadstnorm*sigma, choice)
        wk1[wk1 < cfloor] = cfloor

        # Value function
        vl1 = np.full((2, ngridm * expn), np.nan)

        if period + 1 == Tbar - 1:
            vl1[0, :] = util(wk1, 0).flatten('F')
            vl1[1, :] = util(wk1, 1).flatten('F')
        else:
            vl1[1, :] = value_function(1, period + 1, wk1) # value function in t+1 if choice in t+1 is work
            vl1[0, :] = value_function(0, period + 1, wk1) # value function in t+1 if choice in t+1 is retiree

        # Probability of choosing work in t+1
        if choice == 0:
            # Probability of choosing work in t+1
            pr1 = np.full(2500, 0.00)
        else:
            pr1 = chpr(vl1)

        # Next period consumption based on interpolation and extrapolation
        # given grid points and associated consumption
        cons10 = np.interp(wk1, policy[:, 0, 0, period + 1], policy[:, 1, 0, period+1])
        # extrapolate linearly right of max grid point
        slope = (policy[-2, 1, 0, period + 1] - policy[-1, 1, 0, period + 1])/(policy[-2, 0, 0, period + 1] - policy[-1, 0, 0, period + 1])
        intercept = policy[-1, 1, 0, period + 1] - policy[-1, 0, 0, period + 1]*slope
        cons10[cons10 == np.max(policy[:, 1, 0, period+1])] = intercept + slope*wk1[cons10 == np.max(policy[:, 1, 0, period+1])]
        cons10_flat = cons10.flatten('F')

        cons11 = np.interp(wk1, policy[:, 0, 1, period + 1], policy[:, 1, 1, period+1])
        # extrapolate linearly right of max grid point
        slope = (policy[-2, 1, 1, period + 1] - policy[-1, 1, 1, period + 1])/(policy[-2, 0, 1, period + 1] - policy[-1, 0, 1, period + 1])
        intercept = policy[-1, 1, 1, period + 1] - policy[-1, 0, 1, period + 1]*slope
        cons11[cons11 == np.max(policy[:, 1, 1, period+1])] = intercept + slope*wk1[cons11 == np.max(policy[:, 1, 1, period+1])]
        cons11_flat = cons11.flatten('F')

        # Marginal utility of expected consumption next period
        mu1 = pr1*mutil(cons11_flat) + (1 - pr1)*mutil(cons10_flat)

        # Marginal budget
        # Note: Constant for this model formulation (1+r)
        mwk1 = mbudget()

        # RHS of Euler eq., p 337, integrate out error of y
        rhs = np.dot(quadw.T, np.multiply(mu1.reshape(wk1.shape, order = 'F'), mwk1))
        # Current period consumption from Euler equation
        
        cons0 = imutil(df*rhs)
        # Update containers related to consumption
        policy[1:, 1, choice, period] = cons0
        policy[1:, 0, choice, period] = savingsgrid + cons0


        if choice == 1:
            # Calculate continuation value
            ev = np.dot(quadw.T, logsum(vl1).reshape(wk1.shape, order = 'F'))
        else:
            ev = np.dot(quadw.T, vl1[0, :].reshape(wk1.shape, order = 'F'))

        # Update value function related containers
        value[1:, 1, choice, period] = util(cons0, choice) + df*ev
        value[1:, 0, choice, period] = savingsgrid + cons0
        value[0, 1, choice, period] = ev[0]

In [237]:
value[:, 1, 1, 23]

array([0.91252708, 1.42959855, 1.4328254 , 1.43597492, 1.4390499 ,
       1.44205296, 1.44498662, 1.44785329, 1.45065524, 1.45339466,
       1.45607365, 1.45869419, 1.46125819, 1.46376747, 1.46622379,
       1.46862882, 1.47098416, 1.47329135, 1.47555186, 1.4777671 ,
       1.47993843, 1.48206717, 1.48415454, 1.48620178, 1.48821001,
       1.49018037, 1.49211392, 1.49401169, 1.49587468, 1.49770383,
       1.49950008, 1.5012643 , 1.50299736, 1.50470007, 1.50637324,
       1.50801764, 1.509634  , 1.51122304, 1.51278546, 1.51432192,
       1.51583308, 1.51731955, 1.51878194, 1.52022084, 1.52163681,
       1.52303041, 1.52440215, 1.52575256, 1.52708214, 1.52839136,
       1.5296807 , 1.5309506 , 1.53220152, 1.53343386, 1.53464805,
       1.5358445 , 1.53702358, 1.53818567, 1.53933116, 1.54046038,
       1.54157369, 1.54267142, 1.54375391, 1.54482147, 1.5458744 ,
       1.54691302, 1.54793761, 1.54894845, 1.54994583, 1.55093002,
       1.55190127, 1.55285984, 1.55380598, 1.55473994, 1.55566

In [82]:
with open('m0_value.pkl', 'rb') as file : 
    m0_value = pickle.load(file)

with open('m0_policy.pkl', 'rb') as file : 
    m0_policy = pickle.load(file)

In [60]:
# Verify that EGM already yields correct solution for t-1 and t-2 in value
np.testing.assert_almost_equal(m0_value[0:501, 1, 1, 22], value[:, 1, 1, 22])

In [61]:
np.testing.assert_almost_equal(m0_value[0:501, 1, 0, 22], value[:, 1, 0, 22])

In [62]:
# Verify that EGM already yields correct solution for t-1 and t-2 in policy
np.testing.assert_almost_equal(m0_policy[0:501, 1, 1, 22], policy[:, 1, 1, 22])

In [63]:
np.testing.assert_almost_equal(m0_policy[0:501, 1, 0, 22], policy[:, 1, 0, 22])

In [148]:
# Difference in t-3
value[0:100, 1, 0, 21]

array([-2.12269774e+03, -2.75980810e+03, -1.12058718e+03, -4.41607496e+02,
       -2.81369826e+01, -2.16432559e+01, -1.65814677e+01, -1.26072883e+01,
       -1.05387511e+01, -8.78154880e+00, -7.30430272e+00, -6.28485664e+00,
       -5.39026599e+00, -4.61280653e+00, -4.00421184e+00, -3.46156104e+00,
       -2.97982042e+00, -2.57460113e+00, -2.21011855e+00, -1.88137302e+00,
       -1.59178817e+00, -1.33002698e+00, -1.09084792e+00, -8.73615204e-01,
       -6.76243625e-01, -4.94092139e-01, -3.25598043e-01, -1.70843491e-01,
       -2.72854700e-02,  1.07209644e-01,  2.31876542e-01,  3.48076254e-01,
        4.57907344e-01,  5.60531714e-01,  6.56616505e-01,  7.47984688e-01,
        8.33971184e-01,  9.14824022e-01,  9.92010735e-01,  1.06512608e+00,
        1.13416048e+00,  1.20021748e+00,  1.26316811e+00,  1.32284343e+00,
        1.38000500e+00,  1.43472125e+00,  1.48691996e+00,  1.53685962e+00,
        1.58482405e+00,  1.63094007e+00,  1.67493625e+00,  1.71734840e+00,
        1.75839370e+00,  

In [173]:
m0_value[:100, 1, 0, 21]

array([-2.12269774e+03, -2.75980810e+03, -1.12058718e+03, -4.41607496e+02,
       -2.81369826e+01, -2.16432559e+01, -1.65814677e+01, -1.26072883e+01,
       -1.05387511e+01, -8.78154880e+00, -7.30430272e+00, -6.28485664e+00,
       -5.39026599e+00, -4.61280653e+00, -4.00421184e+00, -3.46156104e+00,
       -2.97982042e+00, -2.57460113e+00, -2.21011855e+00, -1.88137302e+00,
       -1.59178817e+00, -1.33002698e+00, -1.09084792e+00, -8.73615204e-01,
       -6.76243625e-01, -4.94092139e-01, -3.25598043e-01, -1.70843491e-01,
       -2.72854700e-02,  1.07209644e-01,  2.31876542e-01,  3.48076254e-01,
        4.57907344e-01,  5.60531714e-01,  6.56616505e-01,  7.47984688e-01,
        8.33971184e-01,  9.14824022e-01,  9.92010735e-01,  1.06512608e+00,
        1.13416048e+00,  1.20021748e+00,  1.26316811e+00,  1.32284343e+00,
        1.38000500e+00,  1.43472125e+00,  1.48691996e+00,  1.53685962e+00,
        1.58482405e+00,  1.63094007e+00,  1.67493625e+00,  1.71734840e+00,
        1.75839370e+00,  

In [22]:
policy[:, 0, 1, period]

array([ 0.        ,  5.52589268,  5.66298055,  5.80006842,  5.93715629,
        6.07424416,  6.21133203,  6.3484199 ,  6.48550777,  6.62259564,
        6.75968351,  6.89677138,  7.03385925,  7.17094712,  7.308035  ,
        7.44512287,  7.58221074,  7.71929861,  7.85638648,  7.99347435,
        8.13056222,  8.26765009,  8.40473796,  8.54182583,  8.6789137 ,
        8.81600157,  8.95308944,  6.36285742,  6.49994529,  6.63703316,
        6.77412103,  6.9112089 ,  7.04829677,  7.18538464,  7.32247251,
        7.45956038,  7.59664825,  7.73373612,  7.87082399,  8.00791186,
        8.14499974,  8.28208761,  8.41917548,  8.55626335,  8.69335122,
        8.83043909,  8.96752696,  9.10461483,  9.2417027 ,  9.37879057,
        9.51587844,  9.65296631,  9.79005418,  9.92714205, 10.06422992,
       10.20131779, 10.33840566, 10.47549353, 10.61258141, 10.74966928,
       10.88675715, 11.02384502, 11.16093289, 11.29802076, 11.43510863,
       11.5721965 , 11.70928437, 11.84637224, 11.98346011, 12.12

Start secondary evelope
-----------------------------

To Do: Secondary envelope as a function and not line-by-line
To Do: Handling of discontinuity in the credit constrained region, retirement_model.m 137-148

In [204]:
def secondary_envelope(obj):
    result = []
    newdots = []
    index_removed = []

    for k in range(obj.shape[0]):
        sect = []
        cur = deepcopy(obj[k])
        # Find discontinutiy
        ii = cur[0][1:]>cur[0][:-1]
        # Substitute for matlab while true loop
        i=1
        while_operator = True
        while while_operator:
            j = np.where([ii[counter] != ii[0] for counter in range(len(ii))])[0]
            if len(j) == 0:
                if i >1:
                    sect += [cur]
                while_operator=False
            else:
                j = min(j)

                sect_container, cur = chop(cur, j, True)
                sect += [sect_container]
                ii = ii[j:]
                i += 1
        # yes we can use np.sort instead of the pre-specified function from the upper envelope notebook
        if len(sect) > 1:
            sect = [np.sort(i) for i in sect]
            result_container, newdots_container = upper_envelope(sect, True, True)
            index_removed_container = diff(obj[k], result_container, 10)
        else:
            result_container = obj[k]
            index_removed_container = np.array([])
            newdots_container = np.stack([np.array([]), np.array([])])
        
        result += [result_container]
        newdots += [newdots_container]
        index_removed += [index_removed_container]
        
    return result, newdots, index_removed


In [205]:
for point in [21]:
    obj = np.stack([value[1:, :, 1, point].T, value[1:, :, 0, point].T])
    r1, new, rem = secondary_envelope(obj)
    # Values are now equal 
    np.testing.assert_almost_equal(m0_value[0:501, 0, 1, point].T[~np.isnan(m0_value[0:501, :, 1, point].T[0])][1:], r1[0][0])
    np.testing.assert_almost_equal(m0_value[0:501, 1, 1, point].T[~np.isnan(m0_value[0:501, :, 1, point].T[0])][1:], r1[0][1])
    np.testing.assert_almost_equal(m0_value[0:501, 0, 0, point].T[~np.isnan(m0_value[0:501, :, 0, point].T[0])][1:], r1[1][0])
    np.testing.assert_almost_equal(m0_value[0:501, 1, 0, point].T[~np.isnan(m0_value[0:501, :, 0, point].T[0])][1:], r1[1][1])

In [219]:
value[:, 0, 1, 21] = np.append(np.append(np.full(18, np.nan),np.array([0.0])) , r1[0][0])
value[:, 1, 1, 21] = np.append(np.append(np.full(18, np.nan),value[0, 1, 1, 21]) , r1[0][1])


ValueError: could not broadcast input array from shape (493) into shape (501)

array([       nan,        nan,        nan,        nan,        nan,
              nan,        nan,        nan,        nan,        nan,
              nan,        nan,        nan,        nan,        nan,
              nan,        nan,        nan, 2.05931703, 2.45149424,
       2.45635048, 2.461171  , 2.46588366, 2.47055756, 2.47519777,
       2.47973316, 2.4842344 , 2.48869924, 2.49307278, 2.49741036,
       2.50170953, 2.50593035, 2.51011269, 2.51034068, 2.51080362,
       2.51983117, 2.5287087 , 2.53743052, 2.5459949 , 2.55442686,
       2.56270043, 2.57083315, 2.57885725, 2.58671531, 2.59444954,
       2.60209492, 2.60956709, 2.61693268, 2.62422575, 2.63133893,
       2.63836114, 2.6453203 , 2.65210264, 2.65880848, 2.66545529,
       2.67193087, 2.67834133, 2.68469649, 2.6908866 , 2.69702015,
       2.70309779, 2.70902657, 2.71490019, 2.72071481, 2.7264027 ,
       2.73203201, 2.73760106, 2.74306257, 2.74846197, 2.75380139,
       2.75904988, 2.76423268, 2.76935705, 2.77440485, 2.77938

In [211]:
# Solve workers problem with EGM for period T-1, T-2 and T-3
# The EGM step already yields the same result as the matlab code for T-1 and T-2
# Difference in result for T-3 => DC step has to be performed after the EGM step
for period in [20]:
    
    for choice in[0, 1]:
        # M_{t+1}
        wk1 = budget(period, savingsgrid, quadstnorm*sigma, choice)
        wk1[wk1 < cfloor] = cfloor

        # Value function
        vl1 = np.full((2, ngridm * expn), np.nan)

        if period + 1 == Tbar - 1:
            vl1[0, :] = util(wk1, 0).flatten('F')
            vl1[1, :] = util(wk1, 1).flatten('F')
        else:
            vl1[1, :] = value_function(1, period + 1, wk1) # value function in t+1 if choice in t+1 is work
            vl1[0, :] = value_function(0, period + 1, wk1) # value function in t+1 if choice in t+1 is retiree

        # Probability of choosing work in t+1
        if choice == 0:
            # Probability of choosing work in t+1
            pr1 = np.full(2500, 0.00)
        else:
            pr1 = chpr(vl1)

        # Next period consumption based on interpolation and extrapolation
        # given grid points and associated consumption
        cons10 = np.interp(wk1, policy[:, 0, 0, period + 1], policy[:, 1, 0, period+1])
        # extrapolate linearly right of max grid point
        slope = (policy[-2, 1, 0, period + 1] - policy[-1, 1, 0, period + 1])/(policy[-2, 0, 0, period + 1] - policy[-1, 0, 0, period + 1])
        intercept = policy[-1, 1, 0, period + 1] - policy[-1, 0, 0, period + 1]*slope
        cons10[cons10 == np.max(policy[:, 1, 0, period+1])] = intercept + slope*wk1[cons10 == np.max(policy[:, 1, 0, period+1])]
        cons10_flat = cons10.flatten('F')

        cons11 = np.interp(wk1, policy[:, 0, 1, period + 1], policy[:, 1, 1, period+1])
        # extrapolate linearly right of max grid point
        slope = (policy[-2, 1, 1, period + 1] - policy[-1, 1, 1, period + 1])/(policy[-2, 0, 1, period + 1] - policy[-1, 0, 1, period + 1])
        intercept = policy[-1, 1, 1, period + 1] - policy[-1, 0, 1, period + 1]*slope
        cons11[cons11 == np.max(policy[:, 1, 1, period+1])] = intercept + slope*wk1[cons11 == np.max(policy[:, 1, 1, period+1])]
        cons11_flat = cons11.flatten('F')

        # Marginal utility of expected consumption next period
        mu1 = pr1*mutil(cons11_flat) + (1 - pr1)*mutil(cons10_flat)

        # Marginal budget
        # Note: Constant for this model formulation (1+r)
        mwk1 = mbudget()

        # RHS of Euler eq., p 337, integrate out error of y
        rhs = np.dot(quadw.T, np.multiply(mu1.reshape(wk1.shape, order = 'F'), mwk1))
        # Current period consumption from Euler equation
        
        cons0 = imutil(df*rhs)
        # Update containers related to consumption
        policy[1:, 1, choice, period] = cons0
        policy[1:, 0, choice, period] = savingsgrid + cons0


        if choice == 1:
            # Calculate continuation value
            ev = np.dot(quadw.T, logsum(vl1).reshape(wk1.shape, order = 'F'))
        else:
            ev = np.dot(quadw.T, vl1[0, :].reshape(wk1.shape, order = 'F'))

        # Update value function related containers
        value[1:, 1, choice, period] = util(cons0, choice) + df*ev
        value[1:, 0, choice, period] = savingsgrid + cons0
        value[0, 1, choice, period] = ev[0]

  if sys.path[0] == '':


In [212]:
for point in [20]:
    obj = np.stack([value[1:, :, 1, point].T, value[1:, :, 0, point].T])
    r1, new, rem = secondary_envelope(obj)


In [218]:
r1[0][0][:10], m0_value[0:501, 0, 1, 21].T[~np.isnan(m0_value[0:501, :, 1, 21].T[0])][1:11]

(array([5.56300853, 5.69155521, 5.82010189, 5.94864857, 6.07719525,
        6.20574193, 6.33428861, 6.4628353 , 6.59138198, 6.71992866]),
 array([5.52589268, 5.66298055, 5.80006842, 5.93715629, 6.07424416,
        6.21133203, 6.3484199 , 6.48550777, 6.62259564, 6.75968351]))

In [164]:
import scipy.interpolate as scin
interpolation = scin.interp1d(r1[0][0], r1[0][1], bounds_error=False, fill_value='extrapolate')
interpolation(0.0)

array(2.25574365)

In [175]:
len(r1[0][0])

482

In [184]:
np.append(np.full(18, np.nan), r1[0][0])

array([        nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,  5.52589268,  5.66298055,
        5.80006842,  5.93715629,  6.07424416,  6.21133203,  6.3484199 ,
        6.48550777,  6.62259564,  6.75968351,  6.89677138,  7.03385925,
        7.17094712,  7.308035  ,  7.31557897,  7.32247251,  7.45956038,
        7.59664825,  7.73373612,  7.87082399,  8.00791186,  8.14499974,
        8.28208761,  8.41917548,  8.55626335,  8.69335122,  8.83043909,
        8.96752696,  9.10461483,  9.2417027 ,  9.37879057,  9.51587844,
        9.65296631,  9.79005418,  9.92714205, 10.06422992, 10.20131779,
       10.33840566, 10.47549353, 10.61258141, 10.74966928, 10.88675715,
       11.02384502, 11.16093289, 11.29802076, 11.43510863, 11.5721965 ,
       11.70928437, 11.84637224, 11.98346011, 12.12054798, 12.25

In [398]:
# Discontinuity in values
obj[0, 0, 24:50]

array([8.81600157, 8.95308944, 6.36285742, 6.49994529, 6.63703316,
       6.77412103, 6.9112089 , 7.04829677, 7.18538464, 7.32247251,
       7.45956038, 7.59664825, 7.73373612, 7.87082399, 8.00791186,
       8.14499974, 8.28208761, 8.41917548, 8.55626335, 8.69335122,
       8.83043909, 8.96752696, 9.10461483, 9.2417027 , 9.37879057,
       9.51587844])

In [117]:
range(len(ii))

range(0, 2)

In [118]:
i = 1
sect = []

In [119]:
len(ii)

2

In [112]:
# Explore wether one can simply use numpy for sorting insted of the translated matlab function
sect[1] = np.sort(sect[1])
sect[1]

array([[6.36285742, 8.95308944],
       [2.44271185, 2.55666561]])

Upper envelope

In [34]:
# aux_function used in upper_evelope below
def aux_function(x, obj1, obj2):
    x = [x]
    value, extr = np.subtract(interpolate(x,obj1), interpolate(x, obj2))
    return value

In [35]:
# interpolate used in upper_evelope below
def interpolate(xx, obj, one=False):    
    if not one:
        interpolation = InterpolatedUnivariateSpline(obj[0], obj[1], k=1)
        container = interpolation(xx)
        extrapolate = [True if (i>max(obj[0])) |(i<min(obj[0])) else False for i in xx]
    else:
        container = []
        extrapolate = []
        
        for poly  in obj:
            interpolation = InterpolatedUnivariateSpline(poly[0], poly[1], k=1)
            container += [interpolation(xx)]
            extrapolate += [np.array([True if (i>max(poly[0])) |(i<min(poly[0])) else False for i in xx])]
    return container, extrapolate

In [37]:
# Perform upper envelope calculation
result_container, newdots_container = upper_envelope(sect, True, True)

In [38]:
# Array of same length as MatLab code
# Same number of points removed
result_container[0].shape

(482,)

In [39]:
# Verify result for x values of value after upper envelope
# Same points removed as in MatLab code
np.testing.assert_array_almost_equal(m0_value[19:501, 1, 0, 21], result_container[0])

In [40]:
# Verify result for y values of value after upper envelope
# Same points removed as in MatLab code
np.testing.assert_array_almost_equal(m0_value[19:501, 0, 0, 21], result_container[1])

In [41]:
# Verify output of result_inter
# One new point added, one intersection, same as MatLab code
newdots_container

array([[7.31557897],
       [2.51034068]])

Finish up secondary envelope

diff(obj, result_container)

!!! Function not working with this input

In [42]:
# Find indexes of missing elements
missing_elements = np.setdiff1d(obj[0], result_container[0])
missing_elements

array([6.36285742, 6.49994529, 6.63703316, 6.77412103, 6.9112089 ,
       7.04829677, 7.18538464, 7.44512287, 7.58221074, 7.71929861,
       7.85638648, 7.99347435, 8.13056222, 8.26765009, 8.40473796,
       8.54182583, 8.6789137 , 8.81600157, 8.95308944])

In [43]:
indexremoved = []

for value in missing_elements:
    indexremoved.append(obj[0].tolist().index(value))

indexremoved = np.array(sort(indexremoved))

In [44]:
# Same indexes of removed points as in MatLab code
indexremoved

array([14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
       31, 32])

Back to retirement_model.m solve_dsegm line 149ff.

In [45]:
len(indexremoved) > 0

True

In [46]:
# All points below
# Note what MatLab function find is doing
# If one simply does "<" in Python, result would be wrong
# Current workaround might not be robust to all cases - find a better wat to pythonise MatLabs find function
j = arange(0, (np.where(policy[:, 0, 1, 21] > newdots_container[0][0]))[0][0])
j

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])

In [47]:
# Points that were not deleted
j_new = np.setdiff1d(j, indexremoved)
j_new

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])

In [48]:
j = max(j_new)
j

13

In [49]:
# Potentially a better way of setting j exists, such that indexes here are simplified
new_left = np.interp(newdots_container[0][0], policy[j+1:j+3, 0, 1, 21], policy[j+1:j+3, 1, 1, 21])
new_left

6.007459709059365

In [50]:
# Perform similar operation for the upper/right side
# All comments from above apply here
j = np.arange(np.where(policy[:, 0, 1, 21] < newdots_container[0][0])[0][-1], policy.shape[0]+1)
j

array([ 33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,
        46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,
        59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
        72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
        85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
        98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
       111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
       124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
       137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149,
       150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
       163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
       176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188,
       189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201,
       202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 21

In [51]:
j_new = np.setdiff1d(j, indexremoved)
j_new

array([ 33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,
        46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,
        59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
        72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
        85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
        98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
       111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
       124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
       137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149,
       150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
       163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
       176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188,
       189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201,
       202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 21

In [52]:
j = min(j_new)
j

33

In [53]:
new_right = np.interp(newdots_container[0][0], policy[j:j+2, 0, 1, period], policy[j:j+2, 1, 1, period])
new_right

4.014004377221639

In [62]:
# Remove inferior points from policy
# Means: Remove all points with indexes in indexremoved
policy_thinout_x = policy[:, 0, 1, period].tolist()
policy_thinout_y = policy[:, 1, 1, period].tolist()

In [64]:
del policy_thinout_x[indexremoved[0]+1 : indexremoved[-1]+2]
del policy_thinout_y[indexremoved[0]+1 : indexremoved[-1]+2]

In [65]:
# Add new point twice
policy_thinout_x.append(newdots_container[0][0] - 1e3*2.2204e-16)
policy_thinout_x.append(newdots_container[0][0])

In [66]:
# Add new point twice
policy_thinout_y.append(new_left)
policy_thinout_y.append(new_right)

In [67]:
len(policy_thinout_x)

484

In [69]:
new_policy = np.full((2, ngridm + 1), np.nan)
new_policy[0, - len(policy_thinout_x):] = np.array(policy_thinout_x)
new_policy[1, - len(policy_thinout_y):] = np.array(policy_thinout_y)

In [70]:
# Ensure that points are in the right position
new_policy = np.sort(new_policy)

In [75]:
# Verify correct solution for t-3 in policy
np.testing.assert_almost_equal(m0_policy[17:501, 1, 0, 21], new_policy[0, 0:484])

In [76]:
# Finish up
policy[:, :, 1, period] = new_policy.T