In [None]:
import gym
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
# Blackjack environment shared by every later cell. The cells below use the
# classic gym API: env.reset() returns the observation and env.step() returns
# a 4-tuple (observation, reward, done, info).
# NOTE(review): the 'Blackjack-v0' id and this API presumably require an older
# gym release — confirm against the installed version.
env = gym.make('Blackjack-v0')

In [None]:
# Number of Monte Carlo episodes to simulate.
n_episodes = 500000

# Policy table: state -> action, where a state is
# (player sum 12..21, dealer card 1..10, usable-ace flag).
# Every state starts out mapped to action 0.
Pi = {
    (player_sum, dealer_card, usable_ace): 0
    for player_sum in range(12, 22)
    for dealer_card in range(1, 11)
    for usable_ace in (True, False)
}

# Action-value table: (state, action) -> (visit count, running mean return),
# zero-initialised for both actions of every state in Pi.
Q = {(state, action): (0., 0.) for state in Pi for action in range(2)}

def do_act(state):
    """Return the action that the policy table ``Pi`` currently assigns to ``state``.

    ``state`` must be a key of ``Pi``; an unknown state raises ``KeyError``.
    """
    chosen_action = Pi[state]
    return chosen_action

# Monte Carlo control with exploring starts (first-visit, undiscounted).
# Each iteration plays one episode and then updates Q and the greedy policy Pi.
for _ in range(n_episodes):
    # Exploring start: a random first action, and a start state whose player
    # sum is at least 12 (smaller sums are not keys of Q / Pi, so redraw).
    start_state = env.reset()   # player's sum, dealer's showing card, usable ace
    start_action = np.random.choice(range(2))
    while start_state[0] < 12:
        start_state = env.reset()

    # Roll out the episode: the first step uses the random action, every
    # later step follows the current policy. One list of
    # (state, action, reward) triples replaces the two parallel lists.
    episode = []
    state, action = start_state, start_action
    done = False
    while not done:
        next_state, reward, done, _info = env.step(action)
        episode.append((state, action, reward))
        if not done:
            state = next_state
            action = do_act(state)

    # Index of the first occurrence of each (state, action) pair, computed in
    # one pass instead of re-slicing a reversed list at every step.
    first_visit = {}
    for t, (state, action, _reward) in enumerate(episode):
        if (state, action) not in first_visit:
            first_visit[state, action] = t

    # Walk the episode backwards, accumulating the return G (no discounting).
    G = 0
    for t in range(len(episode) - 1, -1, -1):
        state, action, reward = episode[t]
        G = G + reward
        if first_visit[state, action] == t:
            # Q stores (visit count, running mean of returns); incremental mean update.
            count, mean = Q[state, action]
            count += 1
            mean = mean + (1. / count) * (G - mean)
            Q[state, action] = (count, mean)
            # Greedy improvement; ties go to action 0. Stored as a plain int
            # rather than a numpy scalar.
            Pi[state] = int(np.argmax([Q[state, candidate][1] for candidate in (0, 1)]))

In [None]:
# Text table of the learned policy for hands with a usable ace:
# one row per player sum (21 down to 12), one column per dealer card (1..10).
# 'S' marks action 0 and 'H' marks any other action.
for player_sum in range(21, 11, -1):
    cells = [str(player_sum)]
    cells.extend(
        'S' if Pi[player_sum, dealer_card, True] == 0 else 'H'
        for dealer_card in range(1, 11)
    )
    print('\t'.join(cells))
# Footer row: the dealer card values, aligned under the policy columns.
print('\t' + '\t'.join(map(str, range(1, 11))))

In [None]:
def plot_policy(policy_grid, title):
    """Draw one policy grid as a grayscale image and return ``(fig, ax)``.

    Parameters
    ----------
    policy_grid : list of lists
        Actions indexed as ``policy_grid[row][col]``; rows are labelled with
        player sums 21 down to 12 and columns with dealer cards 1 to 10.
    title : str
        Title shown above the image.
    """
    fig, ax = plt.subplots()
    ax.set_yticks(np.arange(len(policy_grid)))
    ax.set_xticks(np.arange(len(policy_grid[0])))
    ax.set_yticklabels(np.arange(21, 11, -1))
    ax.set_xticklabels(np.arange(1, 11))
    ax.set_xlabel('Dealer showing')
    ax.set_ylabel('Player sum')
    ax.set_title(title)
    ax.imshow(policy_grid, cmap='gray')
    return fig, ax


# Policy grids: rows run over player sums 21..12, columns over dealer cards 1..10.
data = [[Pi[i, j, True] for j in range(1, 11)] for i in range(21, 11, -1)]    # usable ace
data2 = [[Pi[i, j, False] for j in range(1, 11)] for i in range(21, 11, -1)]  # no usable ace

# One figure per case; ending on an assignment keeps the AxesImage repr out of the cell output.
fig, ax = plot_policy(data, 'Usable ace')
fig, ax = plot_policy(data2, 'No usable ace')

In [None]:
# Switch the matplotlib backend for the figure drawn in the next cell.
# NOTE(review): the `notebook` backend is presumably chosen so the 3D plot can
# be rotated interactively; it may be unavailable in newer Jupyter front ends
# (where `%matplotlib widget` via ipympl is the usual substitute) — confirm.
%matplotlib notebook

In [None]:
# Which half of the state space to plot; set to True for the usable-ace surface.
usable_ace = False

# Grid of (player sum, dealer card) pairs and the value estimate at each point.
X, Y = np.meshgrid(range(12, 22), range(1, 11))
Z = np.zeros(X.shape)

for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        # Plain ints so the key has the same form as the keys of Pi / Q.
        state = int(X[i, j]), int(Y[i, j]), usable_ace
        opt_act = Pi[state]
        # Q stores (visit count, mean return); the mean return of the policy's
        # action is used as the state's value.
        Z[i, j] = Q[state, opt_act][1]

fig = plt.figure()
# Create the 3D axes through the figure: on current Matplotlib, Axes3D(fig)
# no longer attaches the axes to the figure, which leaves the plot blank.
ax = fig.add_subplot(111, projection='3d')
ax.set_zlim(-1, 1)
ax.invert_yaxis()
ax.plot_wireframe(X, Y, Z)
ax.set_ylabel('Dealer showing')
ax.set_xlabel('Player sum')
ax.set_zlabel('Value')
ax.set_title('Usable ace' if usable_ace else 'No usable ace')