In [1]:
import numpy as np
rng = np.random.default_rng(42)

In [22]:
class State():
  def __init__(self, p_g, reward, quiter):
    self.p_g = p_g
    self.p_b = 1 - self.p_g
    self.reward = reward
    self.quiter = quiter
    assert self.p_g + self.p_b == 1
  
  def get_succes(self):
    res = rng.random()
    if res <= self.p_g:
      return True
    else:
      return False

  def get_prob(self):
    return self.p_g
    
  def get_reward(self):
    return self.reward
  
  def get_quit(self):
    return self.quiter

In [48]:
class MC():
  def __init__(self, states):
    self.chain = states

  def calc_reward(self, choices):
    sum = 0
    for i in range(len(choices)):
      if choices[i] == False:
        sum += self.chain[i].get_quit()
        break
      else:
        if self.chain[i].get_succes():
          sum += self.chain[i].get_reward()
        else:
          break
    return sum

In [49]:
prob = [0.9, 0.75, 0.5, 0.1]
rewards = [100, 1000, 10000, 50000]
quiter = [0, 100, 1100, 11100]
states = []
for i in range(len(prob)):
  states.append(State(prob[i], rewards[i], quiter[i]))

In [65]:
q = 0.5
choices = np.random.choice(a=[True, False], size=len(prob), p=[q, 1-q]) 

In [68]:
chain = MC(states)

Test playability of chain

In [69]:
choices

array([ True,  True, False,  True])

In [70]:
chain.calc_reward(choices)

100

Now we want to calculate the q values for all potential policies. Therefore we will calculate a matrix for every state-action pair and update the q values using the known function based on their expected value.

In [15]:
options = [True, False]
Q = np.zeros((len(rewards)-1, len(options))) # Amount of states by play or quit

In [71]:
ite = 3
Q[0, 0] = states[0].get_reward()
g = 0.99
for n in range(ite):
  Qn = Q.copy()
  for i in range(len(states)-2):
    p = states[i+1].get_prob()
    r0 = states[i].get_reward()
    r1 = states[i].get_quit() 
    qp = max(Qn[i+1])
    Qn[i] = p * (r0+g * qp) + (1-p) * (r1+g*qp)
  print(Qn)

[[ 1065.  1065.]
 [10450. 10450.]
 [10000.  1100.]
 [50000. 11100.]]
[[ 1065.  1065.]
 [10450. 10450.]
 [10000.  1100.]
 [50000. 11100.]]
[[ 1065.  1065.]
 [10450. 10450.]
 [10000.  1100.]
 [50000. 11100.]]


In [None]:
# Calculate q values
for n in range(3):
  Qn = Q.copy()
  Qn[0, 0] = rewards[0] # prob is 1
  for i in range(1,len(Q)):
    Qn[i,0] = states[i].get_reward()
    Qn[i,1] = states[i].get_quit()
  print(Q)