In [8]:
import numpy as np
import scipy.misc
import scipy.io
import matplotlib.pyplot as plt
%matplotlib inline

In [9]:
def n_bib(balls,bins):
    return scipy.misc.comb(balls+bins-1,balls)

In [10]:
def n_st(n_balls,n_arms): #excluding terminal state
    return int(np.sum([n_bib(i,n_arms*2) for i in range(n_balls+1)]))

In [11]:
def check(T,b):
    a = np.where(np.all(T==b,axis=1))
    if len(a[0]) > 0:
        return a[0][0]
    else:
        return -1

In [12]:
n_arms = 2
n_balls = 2
S = np.ones((1,n_arms*2)) # All possible states
balls = 0
ipb = 0 # index of previous basis
count_added = 0 # number of states added with the current number of balls
while balls < n_balls:
    for j in range(count_added+1): # use each of the states added with balls-1 as bases
        for i in range(n_arms*2): # Distribute the ball to each bin
            new = S[ipb].copy()
            new[i] += 1 
            S = np.vstack((S,new))
            count_added += 1
        count_added -= 1
        ipb += 1
    balls += 1

In [120]:
def states(n_balls,n_arms,rewardCorrect=1.0,cost=-1, constant=0):
    
    if cost == -1: # Set cost to horizon-bounding value by default
        cost = rewardCorrect/n_balls
        
    S = np.ones((1,n_arms*2)) # Matrix of all possible states

    n_states = n_st(n_balls,n_arms) # Number of all possible states
    T1 = np.zeros((n_states+1,n_states+1,n_arms+1)) # Transition tensor

    balls = 0 # Counter for which observation is happening
    ipb = 0 # Index of previous basis
    count_added = 0 # Number of states added with the current number of balls
    state_count = 0

    while balls < n_balls:
        for j in range(count_added+1): # Use each of the states added with balls-1 observations as bases
            for i in range(n_arms*2): # Distribute the ball to each bin
                new = S[ipb].copy() # Copy the previous basis

                # Get the probability of making this observation
                if i%2 == 0: #ON or OFF observation
                    p = new[i]/(new[i] + new[i+1])
                else:
                    p = new[i]/(new[i] + new[i-1])

                # Prepare and add the new state generated by the observation
                new[i] += 1 
                k = check(S,new)

                if k == -1: #If it isn't already added
                    S = np.vstack((S,new))
                    count_added += 1
                    state_count += 1
                    T1[ipb,state_count,i//2] = p 
                else:
                    T1[ipb,k,i//2] = p
            count_added -= 1 # Remove the added one to balance after the initial case
            ipb += 1 # Move down the previously added states at observation balls-1
        balls += 1   

    S = np.vstack((S,-np.ones((1,n_arms*2)))) #Add the terminal state

    for i in range(n_states):
        if np.sum(S[i]) >= 2*n_arms+n_balls:
            T1[i,-1] = np.ones(n_arms+1)

    # The terminal state always goes back to itself
    t = np.zeros((n_states+1,n_arms+1))
    t[-1,:] = 1
    T1[-1] = t

    T2 = np.zeros((n_states+1,n_states+1,1))
    T2[:,-1] = 1

    T = np.concatenate((T1,T2),axis=2)

    R = -cost*np.ones((n_states+1,n_arms+2))
    p = np.max([[s[2*j]/(s[2*j+1]+s[2*j]) for j in range(n_arms)] for s in S],1)
    R[:,-1] = rewardCorrect*p
    R[:,-2] = constant
    R[-1,:] = 0
    
    T = np.rollaxis(T, 2, 0)    
    return S,T,R

In [139]:
T.shape

(5, 925, 925)

In [121]:
S,T,R = states(2, 2, cost=0.01)
print(*(x.shape for x in (S, T, R)), sep='\n')

(16, 4)
(4, 16, 16)
(16, 4)


In [97]:
print(*(x.shape for x in (S1, T1, R1)), sep='\n')

(925, 6)
(4, 925, 925)
(925, 4)


In [66]:
import mdptoolbox, mdptoolbox.example

def backward_induction(T, R, N):
    fh = mdptoolbox.mdp.FiniteHorizon(T, R, 1, N=N)
    fh.run()
    return fh.V[:, 0]


In [136]:
S,T,R = states(6, 3, cost=0.01)
backward_induction(T1, R1, 10)[0]



0.63861111111111102

In [15]:
from joblib import Memory
cache = Memory(cachedir='.joblib_cache', verbose=0).cache

In [17]:
def expected_values(s):
    hits = s[0:None:2]
    misses = s[1:None:2]
    return hits/(misses+hits)

In [99]:
s,t,r = states(100, 2)
# backward_inductiontop(t,r,100)

MemoryError: 

In [92]:
COST = 0.05
MAX_STEPS = int(1/(4*COST)) + 3
S1, T1, R1 = states(MAX_STEPS, 1, cost=COST)
R1_alt = R1.copy()

def s_idx(s):
    return int(sum(range(int(sum(s)-1))) + s[1] - 1)

@cache
def V1(constant):
    R1_alt[:, -1] = np.maximum(constant, R1[:, -1])
    return backward_induction(T1, R1_alt, N=MAX_STEPS)

def Q1(s, a, constant):
    if a == 1:  # bet
        return max(s[0] / sum(s), constant)
    else:  # observe
        V = V1(constant)
        print(V.max())
        return T1[a, s_idx(s)] @ V - COST
    
Q1([1,1], 0, 0.5)
# T1[0, 0]

4.4


4.0333333333333341

In [78]:
N_ARM = 2
S, T, R = states(MAX_STEPS, N_ARM, cost=COST)

def Q_blinker(s, a):
    if a == N_ARM:
        return max(expected_values(S[s]))

    # alternative is selecting current best
    mu = expected_values(S[s])
    mu[a] = -np.inf
    alternative = mu.max()
    idx = a * 2
    s_arm = to_int_s(S[s][[idx, idx+1]])
    return Q1(s_arm, 0, alternative)


Q_blinker(6,2)

0.5

In [75]:
S[6]

array([ 2.,  2.,  1.,  1.])

In [50]:
S

array([[ 1.,  1.,  1.,  1.],
       [ 2.,  1.,  1.,  1.],
       [ 1.,  2.,  1.,  1.],
       ..., 
       [ 1.,  1.,  2.,  8.],
       [ 1.,  1.,  1.,  9.],
       [-1., -1., -1., -1.]])

In [28]:
s, a = 0, 0
idx = a * 2
s_arm = to_int_s(S[s][[idx, idx+1]])

S[s][[idx, idx+1]]
# mu = expected_values(S[s])
# mu[s] = -np.inf
# alternative = mu.max()


array([ 1.,  1.])

In [None]:
def Q1(s, a, constant=0):
    if a == 1:  # bet on the one arm
        return max(s[0] / s.sum(), constant)
    else:  # observe
        alternative = s[1] / s.sum()  
        V = V1(alternative)
        return T[a, s] @ V - COST


In [287]:
def Q(s, a):
    return R[s, a] + T[a, s] @ V

In [None]:
def Q_one()

In [205]:


def Q(s, a):
    
    
    
# print('recursion', V(2))
print('induction', backward_induction(T, R))

induction 0.588333333333


In [None]:
from functools import lru_cache

n_actions = len(R[1])
actions = range(n_actions)
final_state = len(T) - 1

@lru_cache(None)
def V(s):
    if s == final_state:
        return 0
    return max(Q(s, a) for a in actions)

@lru_cache(None)
def Q(s, a):
    s1_options = T[s, :, a].nonzero()[0]
    s1_p = T[s, s1_options, a]
    return R[s, a] + sum(p * V(s1) for s1, p in zip(s1_options, s1_p))

In [193]:
s = states(4,2,cost=0.01)
a = {'states':s[0],'transition':s[1],'rewards':s[2]}
scipy.io.savemat('./file',a)