In [None]:
from dataclasses import dataclass
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

In [None]:
@dataclass
class Environment:
  """Parameters and mutable state of a multi-branch car-rental environment.

  Attributes:
    values: state-value table, indexed by the number of cars at each branch.
    money: accumulated money (not updated by any code in this class).
    cars: per-branch car counts (presumably the current state; not used here).
    pay: income per rented car.
    move_cost: cost per car moved between branches.
    max_cars: maximum number of cars a branch can hold.
    max_move: maximum number of cars that can be moved in one step.
    expected_rent: Poisson mean of rental requests, one per branch.
    expected_return: Poisson mean of returned cars, one per branch.
  """
  values: np.ndarray

  money: int
  cars: List[int]

  pay: int
  move_cost: int
  max_cars: int
  max_move: int

  expected_rent: List[int]
  expected_return: List[int]

  def _sample_poisson(self, rates: List[int], rng: Optional[np.random.Generator]) -> List[int]:
    """Draw one Poisson sample per rate, from `rng` if given, else from NumPy's global RNG."""
    poisson = np.random.poisson if rng is None else rng.poisson
    return [int(poisson(lam)) for lam in rates]

  def get_rental_requests(self, rng: Optional[np.random.Generator] = None) -> List[int]:
    """Sample today's rental requests for every branch.

    Args:
      rng: optional generator for reproducible draws; when omitted the legacy
        global `np.random` state is used, exactly as before.

    Returns:
      One non-negative int per branch, drawn from Poisson(expected_rent[b]).
    """
    return self._sample_poisson(self.expected_rent, rng)

  def get_returns(self, rng: Optional[np.random.Generator] = None) -> List[int]:
    """Sample today's returned cars for every branch, Poisson(expected_return[b])."""
    return self._sample_poisson(self.expected_return, rng)

In [None]:
@dataclass
class Agent:
  """Agent holding a tabular policy and a discount factor.

  Attributes:
    policy: per-state action table, indexed [i, j] like Environment.values.
    discont: discount factor gamma, required to lie in [0, 1]. The field name is
      a misspelling of "discount" kept for compatibility; the `discount`
      property is a correctly spelled read-only alias.
  """
  policy: np.ndarray
  discont: float

  def __post_init__(self) -> None:
    # gamma outside [0, 1] makes discounted returns meaningless or divergent.
    if not 0.0 <= self.discont <= 1.0:
      raise ValueError(f"discont must be in [0, 1], got {self.discont!r}")

  @property
  def discount(self) -> float:
    """Correctly spelled alias for `discont`."""
    return self.discont

  

In [1]:
# Policy iteration for a simplified, deterministic two-branch car-rental problem.
# Action k moves k cars from branch 0 to branch 1 overnight (k < 0 moves them the
# other way). The reward is `pay` per expected rental, min(cars, expected_rent),
# minus `move_cost` per car moved.
# NOTE(review): rentals/returns do not change the next state and `expected_return`
# is unused, so this is not the full Poisson model of Sutton & Barto Example 4.2.
num_branches = 2
max_cars = 20
n_states = max_cars + 1  # a branch can hold 0..max_cars cars, inclusive

initial_values = np.zeros([n_states for _ in range(num_branches)])
initial_cars = [0 for _ in range(num_branches)]

expected_rent = [3, 4]
expected_return = [3, 2]

env = Environment(values=initial_values, money=0, cars=initial_cars, pay=10, move_cost=2, max_cars=max_cars, max_move=5, expected_rent=expected_rent, expected_return=expected_return)
# policy[i, j] stores the action itself (cars moved 0 -> 1), starting with "move nothing".
agent = Agent(policy=np.zeros([n_states for _ in range(num_branches)], dtype=int), discont=0.9)

theta = 1e-6  # convergence threshold for policy evaluation
max_policy_iterations = 100  # safety cap so Restart & Run All cannot hang


def _is_feasible(env, i, j, k):
  """Return True if moving k cars from state (i, j) keeps both branches within [0, env.max_cars]."""
  return 0 <= i - k <= env.max_cars and 0 <= j + k <= env.max_cars


def _action_value(env, agent, i, j, k):
  """One-step lookahead: immediate reward plus discounted value of the post-move state."""
  reward = (env.pay * min(i, env.expected_rent[0])
            + env.pay * min(j, env.expected_rent[1])
            - env.move_cost * abs(k))
  return reward + agent.discont * env.values[i - k, j + k]


def _evaluate_policy(env, agent, theta):
  """Evaluate agent.policy in place into env.values; return the last sweep's max change."""
  while True:  # do-while: always perform at least one sweep
    delta = 0.0
    for i in range(env.max_cars + 1):
      for j in range(env.max_cars + 1):
        old_value = env.values[i, j]
        env.values[i, j] = _action_value(env, agent, i, j, int(agent.policy[i, j]))
        delta = max(delta, abs(old_value - env.values[i, j]))
    if delta < theta:
      return delta


def _improve_policy(env, agent, tol):
  """Make agent.policy greedy w.r.t. env.values; return True if no action changed."""
  actions = range(-env.max_move, env.max_move + 1)
  stable = True
  for i in range(env.max_cars + 1):
    for j in range(env.max_cars + 1):
      action_values = [
          _action_value(env, agent, i, j, k) if _is_feasible(env, i, j, k) else -np.inf
          for k in actions
      ]
      best_k = actions[int(np.argmax(action_values))]  # map argmax index back to an action
      old_k = int(agent.policy[i, j])
      # Switch only on a clear improvement so near-ties cannot make the loop oscillate.
      if action_values[best_k + env.max_move] > action_values[old_k + env.max_move] + tol:
        agent.policy[i, j] = best_k
        stable = False
  return stable


policy_stable = False
for policy_iteration in range(max_policy_iterations):
  delta = _evaluate_policy(env, agent, theta)         # Policy evaluation
  policy_stable = _improve_policy(env, agent, theta)  # Policy improvement
  if policy_stable:
    break

print(agent.policy)

NameError: name 'np' is not defined