In [349]:
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [350]:
# Capital bounds of the game: V[MIN_CAPITAL] is fixed to 0 and V[MAX_CAPITAL] to 1.
MIN_CAPITAL, MAX_CAPITAL = 0, 100
θ = 1e-4  # value iteration stops once the largest change of any V[s] in a sweep is below this

In [351]:
def initialisation(rng=None):
  """Build the set of non-terminal states and an initial value-function estimate.

  Parameters
  ----------
  rng : numpy.random.Generator, optional
    Source of randomness for the initial values. When None (default) the
    legacy global ``np.random`` state is used, exactly as before; pass e.g.
    ``np.random.default_rng(0)`` to make a run reproducible.

  Returns
  -------
  S : ndarray of int
    Non-terminal capital levels MIN_CAPITAL+1 .. MAX_CAPITAL-1.
  V : ndarray of float, length MAX_CAPITAL + 1
    Initial estimates indexed by capital: random in [0, 1) for the
    non-terminal states, V[MIN_CAPITAL] = 0 and V[MAX_CAPITAL] = 1.
  """
  S = np.arange(MIN_CAPITAL + 1, MAX_CAPITAL)
  n_values = MAX_CAPITAL + 1  # V is indexed directly by capital 0..MAX_CAPITAL
  V = np.random.rand(n_values) if rng is None else rng.random(n_values)
  V[MIN_CAPITAL] = 0
  V[MAX_CAPITAL] = 1
  return S, V

In [352]:
def value_iteration(S, V, θ, p_h):
  """Run value iteration on V in place and return the greedy policy.

  Each sweep replaces V[s], for every s in S, by the best one-step expected
  return over the legal stakes. Updated values are used immediately within
  the same sweep. Sweeps repeat until the largest change in a sweep is
  below θ.

  Parameters
  ----------
  S : iterable of int
    States (capital levels) to update.
  V : ndarray of float
    Value estimates indexed by capital; MODIFIED IN PLACE.
  θ : float
    Convergence threshold on the largest per-sweep change.
  p_h : float
    Probability that the coin lands heads.

  Returns
  -------
  ndarray
    The greedy stake per non-terminal state, as computed by optimal_π.
  """
  # Bind A up front so the final call does not raise UnboundLocalError when S is empty.
  # optimal_π recomputes the legal stakes per state anyway.
  A = np.array([], dtype=int)
  while True:
    Δ = 0
    for s in S:
      A = np.arange(1, min(s, MAX_CAPITAL - s) + 1)  # legal stakes in state s
      v = V[s]
      V[s] = max(expected_rewards(V, A, s, p_h))
      Δ = max(Δ, abs(v - V[s]))  # change of the value estimate for state s
    if Δ < θ:  # converged
      break
  return optimal_π(V, A, p_h)

In [353]:
def expected_rewards(V, A, s, p_h):
  """Expected return of each stake in A when the gambler holds capital s.

  For a fixed stake a there are two possible transitions:
    - heads, with probability p_h: capital becomes s + a;
    - tails, with probability 1 - p_h: capital becomes s - a.
  When s + a reaches MAX_CAPITAL, heads contributes p_h directly (a reward
  of 1) instead of reading V[MAX_CAPITAL]. When the game is not won but
  s - a reaches MIN_CAPITAL, the tails term is dropped (ruin is worth 0).

  Parameters
  ----------
  V : array-like of float
    Value estimates indexed by capital.
  A : iterable of int
    Candidate stakes.
  s : int
    Current capital.
  p_h : float
    Probability that the coin lands heads.

  Returns
  -------
  list of float
    One expected return per stake, in the order of A.
  """
  rewards = []
  for a in A:
    if s + a >= MAX_CAPITAL:
      rewards.append(p_h + (1 - p_h) * V[s-a])
    elif s - a <= MIN_CAPITAL:
      rewards.append(p_h * V[s+a])
    else:
      rewards.append(p_h * V[s+a] + (1 - p_h) * V[s-a])
  return rewards

In [354]:
def optimal_π(V, A, p_h, decimals=None):
  """Greedy policy with respect to the value function V.

  Parameters
  ----------
  V : array-like of float
    Value estimates indexed by capital.
  A : ignored
    Kept only for backward compatibility; the legal stakes are recomputed
    for every state.
  p_h : float
    Probability that the coin lands heads.
  decimals : int, optional
    If given, action values are rounded to this many decimals before the
    argmax. Stakes whose values differ only by floating-point noise then
    tie, and the smallest such stake is chosen. None (default) uses the
    raw values, as before.

  Returns
  -------
  π : ndarray of float
    One entry per non-terminal capital MIN_CAPITAL+1 .. MAX_CAPITAL-1;
    entry i is the chosen stake for capital MIN_CAPITAL + 1 + i.
  """
  states = range(MIN_CAPITAL + 1, MAX_CAPITAL)
  π = np.zeros(len(states))
  for i, s in enumerate(states):
    stakes = np.arange(1, min(s, MAX_CAPITAL - s) + 1)
    returns = expected_rewards(V, stakes, s, p_h)
    if decimals is not None:
      returns = np.round(returns, decimals)
    # np.argmax returns the first maximum, i.e. the smallest stake among exact ties.
    π[i] = stakes[np.argmax(returns)]
  return π

In [357]:
def run(p_h):
  """Solve the gambler's problem for one coin bias and plot the result.

  Left panel: the converged value estimates V for every capital level,
  terminal states included. Right panel: the greedy stake π for each
  non-terminal capital level.

  Parameters
  ----------
  p_h : float
    Probability that the coin lands heads.
  """
  S, V = initialisation()
  π = value_iteration(S, V, θ, p_h)  # value_iteration also updates V in place

  # plot the result
  fig = make_subplots(rows=1, cols=2)

  # The x-axes are derived from the arrays rather than hard-coded (0..100 / 1..99),
  # so they stay aligned with V and π if MAX_CAPITAL / MIN_CAPITAL change.
  fig.add_trace(
      go.Scatter(x = np.arange(len(V)), y = V, name = 'Value Function'),
      row = 1, col = 1
  )

  fig.add_trace(
      go.Bar(x = np.arange(MIN_CAPITAL + 1, MIN_CAPITAL + 1 + len(π)), y = π, name = 'Policy'),
      row = 1, col = 2
  )

  fig.update_xaxes(title_text='Capital', row=1, col=1)
  fig.update_xaxes(title_text='Capital', row=1, col=2)
  fig.update_yaxes(title_text='Probability', row=1, col=1)
  fig.update_yaxes(title_text='Stake', row=1, col=2)
  fig.update_layout(height=450, width=1200, title_text= ('Pₕ = ' + str(p_h)))
  fig.show()

In [358]:
# Compare a coin that is unfavourable to the gambler with a favourable one.
run(p_h=0.25)  # unfavourable coin: Pₕ = 0.25
run(p_h=0.55)  # favourable coin:   Pₕ = 0.55