In [None]:
import collections
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Blackjack environment with sab=True (per the Gymnasium docs this selects the
# Sutton & Barto rule set). Add render_mode="human" to watch the games.
env = gym.make("Blackjack-v1", sab=True)

In [None]:
def play_env(env, agent):
    """Play one episode of ``env`` with ``agent`` and let the agent update itself.

    Args:
        env: environment whose ``reset()`` returns ``(observation, info)`` and
            whose ``step(action)`` returns
            ``(observation, reward, terminated, truncated, info)``.
        agent: object providing ``action(observation)``,
            ``observe(observation, reward)`` and ``estimating()``.

    The agent is shown each observation together with the reward obtained for
    the action taken from it; ``estimating()`` is called once after the
    episode has ended.
    """
    observation, info = env.reset()

    done = False
    while not done:
        action = agent.action(observation)

        new_observation, reward, terminated, truncated, info = env.step(action)

        # Pair the reward with the observation the action was chosen from.
        agent.observe(observation, reward)

        observation = new_observation
        # An episode ends on truncation as well as termination; checking only
        # `terminated` would keep stepping a finished environment forever.
        done = terminated or truncated

    agent.estimating()

In [None]:
class MCFirstVisit():
    """First-visit Monte Carlo estimation of state values under a fixed policy.

    Usage per episode: call ``action(state)`` to pick an action, call
    ``observe(state, reward)`` with the state the action was taken from and
    the reward it produced, then call ``estimating()`` once the episode is
    over to fold the episode's returns into ``state_value``.

    States are used as dictionary keys and must therefore be hashable.
    """

    def __init__(self, gamma, policy):
        """
        Args:
            gamma: discount factor applied to later rewards.
            policy: callable mapping a state to an action.
        """
        self.gamma  = gamma
        self.policy = policy

        # state -> mean of the returns recorded for it; unseen states read as 0.
        self.state_value = collections.defaultdict(int)
        # state -> list of recorded returns (one per episode containing the state).
        self.returns = collections.defaultdict(list)
        # state -> running total of self.returns[state]; keeps each mean update
        # O(1) instead of re-summing the whole history on every episode.
        self._return_sums = collections.defaultdict(int)

        # Trajectory of the episode currently being played.
        self.states = []
        self.rewards = []

    def action(self, state):
        """Return the action the policy chooses in ``state``."""
        return self.policy(state)

    def observe(self, state, reward):
        """Record one step: ``reward`` was received for acting from ``state``."""
        self.states.append(state)
        self.rewards.append(reward)

    def estimating(self):
        """Update the value estimates from the recorded episode, then clear it.

        Walks the episode backwards accumulating the discounted return
        ``g = gamma * g + reward`` and, for the first occurrence of each state
        in the episode only, appends ``g`` to ``returns[state]`` and refreshes
        ``state_value[state]`` with the mean. An empty episode is a no-op.
        """
        # Earliest time step at which each state occurs in this episode.
        first_visit = {}
        for t, state in enumerate(self.states):
            first_visit.setdefault(state, t)

        g = 0
        # Start at the final step so the last state of the episode also
        # receives its return (the reward of the final action).
        for t in range(len(self.states) - 1, -1, -1):
            g = self.gamma * g + self.rewards[t]

            state = self.states[t]
            if first_visit[state] != t:
                continue  # a later revisit; only the first visit is counted

            self.returns[state].append(g)
            self._return_sums[state] += g
            self.state_value[state] = self._return_sums[state] / len(self.returns[state])

        self.states = []
        self.rewards = []


In [None]:
# Create a random policy
def random_policy(state):
    """Return action 0 or 1 uniformly at random, ignoring ``state``.

    ``np.random.randint`` excludes ``high``, so ``high=2`` is needed for
    action 1 to ever be drawn (``high=1`` always yields 0).
    """
    return np.random.randint(low=0, high=2, size=1)[0]

def stick_policy(state):
    """Return 0 when the first element of ``state`` is 20 or 21, otherwise 1.

    ``state[0]`` is treated as the player's score. In Blackjack-v1's action
    encoding (per the Gymnasium docs) 0 means stick and 1 means hit.
    """
    return 0 if state[0] in (20, 21) else 1

# Evaluate stick_policy with undiscounted returns (gamma=1).
agent = MCFirstVisit(gamma=1, policy=stick_policy)

# Play a single episode as a quick check before the long run.
play_env(env, agent)

In [None]:
# Inspect the value table after the single episode played so far.
agent.state_value

In [None]:
# Play 500,000 more episodes; each play_env call runs one episode and ends by
# having the agent update its estimates from it.
for i in range(500_000):
    play_env(env, agent)

In [None]:
# Number of entries currently in the state-value table.
len(agent.state_value)

In [None]:
# Inspect the value table after the long run.
agent.state_value

In [None]:
# Wireframe plots of the estimated state values.
#
# Blackjack-v1 observations are (player_sum, dealer_card, usable_ace) per the
# Gymnasium docs. States that differ only in usable_ace are distinct, so each
# case gets its own grid and subplot; sharing one grid would let them
# overwrite each other in dictionary order.
n_player_sums, n_dealer_cards = 24, 12

# Z[usable_ace, player_sum, dealer_card]; NaN leaves never-visited cells blank.
Z = np.full((2, n_player_sums, n_dealer_cards), np.nan)

for (player_sum, dealer_card, usable_ace), value in agent.state_value.items():
    Z[int(usable_ace), player_sum, dealer_card] = value

X, Y = np.meshgrid(np.arange(n_dealer_cards), np.arange(n_player_sums))

fig = plt.figure(figsize=(12, 5))

for ace_index, title in enumerate(("No usable ace", "Usable ace")):
    ax = fig.add_subplot(1, 2, ace_index + 1, projection='3d')
    ax.plot_wireframe(X, Y, Z[ace_index])
    ax.set_title(title)
    ax.set_xlabel("Dealer showing")
    ax.set_ylabel("Player sum")
    ax.set_zlabel("State value")

plt.show()