In [1]:
%matplotlib inline
import sys
from collections import defaultdict

# Make the parent directory (relative to the kernel's working directory)
# importable so the project-local `lib` package imported below can be found.
# The membership guard keeps re-runs of this cell from appending duplicates.
# This must stay ABOVE the `from lib...` imports.
if "../" not in sys.path:
    sys.path.append("../")
    
import numpy as np
from lib.envs.blackjack import BlackjackEnv
from lib.utils.draw import show_value_function
import matplotlib

# Notebook-wide plot style for the figures drawn below.
matplotlib.style.use('ggplot')
# Single Blackjack environment instance shared by the later cells.
env = BlackjackEnv()

In [2]:
def sample_policy(observation):
    """Deterministic threshold policy for Blackjack.

    Returns action 0 when the player's score is 20 or higher, otherwise
    action 1 (presumably "stick" and "hit" in BlackjackEnv -- confirm
    against the environment's action encoding).

    The observation must be a 3-element sequence
    (player score, dealer card, usable-ace flag); only the score is used.
    """
    player_score, _dealer_card, _has_usable_ace = observation
    if player_score < 20:
        return 1
    return 0

In [3]:
def mc_prediction(policy, env, num_episodes, gamma=1.0):
    """First-visit Monte Carlo prediction of the state-value function V(s).

    Samples `num_episodes` episodes by following `policy`, and estimates V(s)
    as the average discounted return observed after the FIRST visit to s
    in each episode.

    Args:
        policy: callable mapping an observation to an action.
        env: environment with `reset() -> observation` and
            `step(action) -> (observation, reward, done, info)`;
            `info` is ignored.
        num_episodes: number of episodes to sample.
        gamma: discount factor applied to future rewards.

    Returns:
        defaultdict(float) keyed by `tuple(observation)`. Visited states hold
        their average return; unvisited states read as 0.0.
    """
    # defaultdict(float) supplies 0.0 for missing keys, so the accumulators
    # need no explicit initialisation.
    returns_sum = defaultdict(float)
    returns_count = defaultdict(float)
    V = defaultdict(float)

    for i_episode in range(num_episodes):
        if i_episode % 1000 == 0:
            # "\r" returns to the start of the line so progress updates in place.
            print(f"\rEpisode :{i_episode}/{num_episodes}",end = "")
            sys.stdout.flush()

        # Roll out one episode. States are normalised to tuples once, so the
        # first-visit test and the value tables use the same hashable key
        # (this also works for list / numpy-array observations).
        episode_states = []
        episode_rewards = []
        state = env.reset()
        while True:
            action = policy(state)
            next_state, reward, done, _ = env.step(action)
            episode_states.append(tuple(state))
            episode_rewards.append(reward)
            if done:
                break
            state = next_state

        # Time step of each state's first occurrence in this episode. This
        # replaces rescanning the episode prefix at every step (O(T^2)).
        first_visit = {}
        for t, key in enumerate(episode_states):
            first_visit.setdefault(key, t)

        # Walk backwards so G accumulates the discounted return from step t.
        G = 0.0
        for t in range(len(episode_states) - 1, -1, -1):
            key = episode_states[t]
            G = gamma * G + episode_rewards[t]
            if first_visit[key] == t:
                returns_sum[key] += G
                returns_count[key] += 1.0
                V[key] = returns_sum[key] / returns_count[key]

    return V
        
        

In [4]:
# Evaluate the fixed policy from 10,000 sampled episodes and plot the estimate.
# The title is derived from the same count so the label cannot drift from the run.
num_episodes_10k = 10_000
V_10k = mc_prediction(sample_policy, env, num_episodes=num_episodes_10k)

show_value_function(V_10k, title=f"{num_episodes_10k:,} Episodes")

In [5]:
# Scratch cell: index of the first element equal to 1.
test = [1, 0, 1, 1, 1, 0, 1]
idx = test.index(1)
print(idx)

# Sample lists (not used by the cells visible here).
list1 = list(range(1, 6))
list2 = list("abcde")


