In [6]:
import numpy as np
import scipy.misc
import scipy.io
import matplotlib.pyplot as plt
%matplotlib inline
from functools import lru_cache

In [7]:
import generate_state_matrices
def states(*args, **kwargs):
    """Return ``(S, T, R)`` from ``generate_state_matrices.states`` with T re-ordered.

    All positional and keyword arguments are forwarded unchanged.  Axis 2 of
    the transition array is moved to the front (other axes keep their
    relative order); ``S`` and ``R`` are returned as received.
    """
    state_rows, transitions, rewards = generate_state_matrices.states(*args, **kwargs)
    # Same permutation as np.rollaxis(transitions, 2, 0): axis 2 becomes axis 0.
    return state_rows, np.moveaxis(transitions, 2, 0), rewards

import mdptoolbox, mdptoolbox.example
def backward_induction(T, R, N):
    """Run mdptoolbox's finite-horizon solver and return its first value column.

    Parameters
    ----------
    T, R : transition and reward arrays, passed straight to
        ``mdptoolbox.mdp.FiniteHorizon``.
    N : number of stages for the solver.

    Returns
    -------
    Column 0 of the solver's ``V`` table after ``run()``.
    """
    solver = mdptoolbox.mdp.FiniteHorizon(T, R, 1, N=N)  # third argument (discount) fixed at 1
    solver.run()
    value_table = solver.V
    return value_table[:, 0]

def s_idx(s):
    """The index of a row in S"""
    level = int(sum(s) - 1)
    # 0 + 1 + ... + (level - 1), written in closed form; an empty range sums to 0.
    offset = level * (level - 1) // 2 if level > 0 else 0
    return int(offset + s[1] - 1)

from joblib import Memory
# joblib renamed Memory's `cachedir` argument to `location` (deprecated in
# 0.12) and current releases no longer accept the old keyword.  Try the
# current spelling first and fall back only for very old installs.
try:
    cache = Memory(location='.joblib_cache', verbose=0).cache
except TypeError:
    cache = Memory(cachedir='.joblib_cache', verbose=0).cache

In [39]:
# Cost of one observation: passed to states() as `cost` and subtracted in
# Q1's observe branch.
COST = 0.05
# Horizon: first argument to states() and the N given to backward_induction in V1.
MAX_STEPS = int(1/(4*COST)) + 3
# Number of arms in the full problem; Q_blinker handles action index N_ARM
# separately from the per-arm actions 0..N_ARM-1.
N_ARM = 2
S, T, R = states(MAX_STEPS, N_ARM, cost=COST)

# The one-bandit case is used to compute blinkered Q values
S1, T1, R1 = states(MAX_STEPS, 1, cost=COST)
R1_alt = R1.copy()  # we change this for each new V1

@cache
def _solve_single_arm(transitions, rewards, n_steps):
    """Disk-cached backward induction that receives every input explicitly.

    joblib keys its cache on the call arguments, so the arrays and the
    horizon are passed in rather than read from notebook globals.
    """
    return backward_induction(transitions, rewards, N=n_steps)


def V1(constant):
    """Value column of the one-arm problem when betting pays at least ``constant``.

    The last column of ``R1_alt`` is overwritten with
    ``max(constant, R1[:, -1])`` and its final row is then zeroed, after which
    the finite-horizon problem ``(T1, R1_alt, MAX_STEPS)`` is solved.

    The cache is keyed on the actual arrays and horizon, not just on
    ``constant``.  Previously, changing ``COST`` (and therefore ``T1``, ``R1``
    and ``MAX_STEPS``) and re-running the notebook would silently return values
    computed for the old settings from ``.joblib_cache``.  ``R1_alt`` is now
    also updated on every call instead of only on cache misses.
    """
    # betting chooses the constant if it's better than the bandit's expected reward
    R1_alt[:, -1] = np.maximum(constant, R1[:, -1])
    R1_alt[-1, -1] = 0  # NOTE(review): presumably the last row is a terminal state -- confirm
    return _solve_single_arm(T1, R1_alt, MAX_STEPS)

@lru_cache(maxsize=None)
def Q1(s, a, constant):
    """Q value of action ``a`` in one-arm state ``s`` against a fixed ``constant``.

    ``a == 1`` is the bet action and yields the larger of ``R1[s, 1]`` and
    ``constant``.  Any other ``a`` is treated as observe: the transition row
    ``T1[a, s]`` is dotted with ``V1(constant)`` and ``COST`` is subtracted.
    Results are memoised in memory per ``(s, a, constant)``.
    """
    if a != 1:  # observe
        successor_values = V1(constant)
        return T1[a, s] @ successor_values - COST
    # bet
    return max(R1[s,1], constant)

@lru_cache(maxsize=None)
def expected_values(s):
    """Per-arm ratio ``hits / (misses + hits)`` for row ``s`` of ``S``.

    Even positions of the row are read as hit counts and odd positions as
    miss counts.  The result is memoised, so repeated calls with the same
    ``s`` return the very same array object -- an in-place change made by a
    caller persists in the cache.
    """
    row = S[s]
    successes = row[::2]
    failures = row[1::2]
    return successes / (failures + successes)

def Q_blinker(s, a):
    """Blinkered Q value of action ``a`` in the multi-arm state at row ``s`` of ``S``.

    Parameters
    ----------
    s : row index into ``S``.
    a : action index.  ``a == N_ARM`` returns the largest per-arm expected
        value.  Any other ``a`` selects an arm, whose two counts are read
        from positions ``2*a`` and ``2*a + 1`` of the row.

    Returns
    -------
    For an arm action, ``Q1(s_arm, 0, alternative)``: the one-arm observe
    value for that arm's own state, with the best expected value among the
    *other* arms as the fixed alternative.
    """
    if a == N_ARM:
        return max(expected_values(s))

    # alternative is selecting current best
    # expected_values() is lru_cached and returns the cached array itself, so
    # mask a private copy.  Writing -inf into the shared array would corrupt
    # every later lookup for this state.
    mu = expected_values(s).copy()
    mu[a] = -np.inf
    alternative = mu.max()

    idx = a * 2
    s_arm = s_idx(S[s][[idx, idx+1]])
    return Q1(s_arm, 0, alternative)


# Spot check: action 0 in state row 0 (0 != N_ARM, so this goes through Q1).
Q_blinker(0, 0)

0.53333333333333321

In [193]:
# Build a small 2-arm model (first argument 4, cost 0.01) and dump its three
# arrays to a MATLAB file.
s = states(4, 2, cost=0.01)
a = dict(states=s[0], transition=s[1], rewards=s[2])
scipy.io.savemat('./file', a)