In [None]:
import gym
import numpy as np
import random
from collections import defaultdict
from tqdm import tqdm # type: ignore

In [2]:
# Build the Blackjack environment that every later cell interacts with.
env = gym.make(id="Blackjack-v1")

In [19]:
# Tunable settings for the Monte Carlo run.
EPISODES = 5_000                 # number of episodes played by the training loop
GAMMA = 1.0                      # discount factor applied when accumulating returns
EPSILON = 0.1                    # probability of taking a uniformly random action
ACTION_SPACE = list(range(2))    # the selectable actions: [0, 1]

In [27]:
# Learning tables, both keyed by state and filled in lazily by the training loop:
#   Q[state]           -> array with one action-value estimate per action
#   returns[state][a]  -> list of sampled returns recorded for action `a`
Q, returns = {}, {}

In [28]:
def epsilon_greedy_policy(state, epsilon=EPSILON):
    """Pick an action for ``state`` using an epsilon-greedy rule over ``Q``.

    With probability ``epsilon`` a uniformly random action from
    ``ACTION_SPACE`` is returned (exploration). Otherwise the action with the
    highest current estimate in ``Q[state]`` is returned (exploitation).

    A state that has no entry in ``Q`` yet has no estimates to be greedy
    about, so a random action is returned instead of raising ``KeyError``.

    Args:
        state: Hashable environment observation used as the key into ``Q``.
        epsilon: Exploration probability in ``[0, 1]``.

    Returns:
        int: The chosen action; always a plain Python ``int`` on both the
        exploring and the greedy branch.
    """
    q_values = Q.get(state)
    if q_values is None:
        # Nothing learned for this state yet: explore.
        return random.choice(ACTION_SPACE)
    if random.random() < epsilon:
        return random.choice(ACTION_SPACE)
    # Q rows are indexed directly by action, so the argmax index is the action.
    # Cast so callers never receive a NumPy integer scalar.
    return int(np.argmax(q_values))

In [29]:
# First-visit Monte Carlo control with an epsilon-greedy behaviour policy.
# Each iteration plays one complete episode, then walks it backwards to
# accumulate discounted returns and refresh the affected Q estimates.
for episode in tqdm(range(EPISODES)):
    state = env.reset()[0]  # reset() returns a tuple; element 0 is the observation
    done = False
    episode_memory = []  # (state, action, reward) for every step, in play order

    while not done:
        if state not in Q:
            # Create the table rows the first time a state is encountered so
            # the policy and the update below can index them safely.
            Q[state] = np.zeros(len(ACTION_SPACE))
            returns[state] = {a: [] for a in ACTION_SPACE}
        action = epsilon_greedy_policy(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_memory.append((state, action, reward))
        state = next_state

    # Record the time step at which each (state, action) pair FIRST occurs.
    # Marking pairs as "seen" during the backward sweep would instead pick the
    # last occurrence, which is not the first-visit estimator.
    first_visit = {}
    for t, step in enumerate(episode_memory):
        first_visit.setdefault(step[:2], t)

    G = 0  # discounted return from step t to the end of the episode
    for t in reversed(range(len(episode_memory))):
        s, a, r = episode_memory[t]
        G = GAMMA * G + r
        if first_visit[(s, a)] == t:
            returns[s][a].append(G)
            # Q is the sample mean of all first-visit returns seen so far.
            Q[s][a] = np.mean(returns[s][a])

100%|██████████| 5000/5000 [00:00<00:00, 11279.55it/s]


In [30]:
# Show the learned action values, then the raw return samples behind them,
# one table per line.
print(Q, returns, sep="\n")

{(17, 10, True): array([-0.52941176, -0.71428571]), (16, 6, False): array([-0.11111111, -1.        ]), (5, 1, False): array([-1., -1.]), (17, 5, False): array([ 0.        , -0.66666667]), (12, 8, False): array([-1.        , -0.36111111]), (13, 10, True): array([-0.58333333, -1.        ]), (19, 1, False): array([-0.1875, -1.    ]), (13, 2, False): array([-0.25      , -0.71428571]), (18, 10, False): array([-0.3253012 , -0.75675676]), (14, 8, False): array([-0.40540541, -1.        ]), (15, 10, False): array([-0.5862069 , -0.55284553]), (20, 10, False): array([ 0.42487047, -1.        ]), (12, 10, False): array([-0.62162162, -0.53465347]), (7, 10, False): array([-0.5       , -0.66666667]), (16, 8, False): array([-0.72972973, -1.        ]), (19, 2, False): array([ 0.36363636, -1.        ]), (15, 1, False): array([-0.88571429, -1.        ]), (11, 4, False): array([-0.33333333,  0.8       ]), (20, 9, False): array([ 0.86486486, -1.        ]), (12, 6, False): array([-0.28205128, -1.        ]), 