In [3]:
import gym
import itertools
import gym_tic_tac_toe
import plotting
from plotting import EpisodeStats
from collections import defaultdict 
import numpy as np 
# from windy_gridworld import WindyGridworldEnv 

In [55]:
env = gym.make('tic_tac_toe-v1')

In [56]:
def hash_state(state):
    board = state['board']
    move = state['on_move']
    return ''.join(str(b) for b in board)+str(move)
    
def hash_action(action):
    return action[1]

In [79]:
def create_policy(Q, epsilon):

    def get_action_probs(state, num_actions): 
        state_hash = hash_state(state)
        action_probs = np.ones(num_actions, 
                dtype = float) * epsilon / num_actions 
                  
        best_action = np.argmax(Q[state_hash]) 
        action_probs[best_action] += (1.0 - epsilon) 
        return action_probs 
   
    return get_action_probs 

In [80]:
def q_learn(env, num_episodes, discount_factor = 1.0, alpha = 0.6, epsilon = 0.1): 
   
    Q = {} 
    
    stats = plotting.EpisodeStats( 
        episode_lengths = np.zeros(num_episodes), 
        episode_rewards = np.zeros(num_episodes))
    
    policy = create_policy(Q, epsilon) 
       
    for ith_episode in range(num_episodes): 
           
        state = env.reset() 
        for t in itertools.count(): 
               
            moves = env.move_generator()
            action_probabilities = policy(state, len(moves)) 
            action_idx = np.random.choice(np.arange( 
                      len(action_probabilities)), 
                       p = action_probabilities)
            
            action = moves[action_idx]
               
            next_state, reward, done, _ = env.step(action) 
   
            stats.episode_rewards[ith_episode] += reward 
            stats.episode_lengths[ith_episode] = t 
            
            nsh = hash_state(next_state)
            bnah = np.argmax(Q[nsh])
            sh = hash_state(state)
            ah = hash_action(action)
            td_target = reward + discount_factor * Q[nsh][bnah] 
            td_delta = td_target - Q[sh][ah] 
            Q[sh][ah] += alpha * td_delta 
   
            if done: 
                break
                   
            state = next_state 
       
    return Q, stats 

In [81]:
print(q_learn(env, 1000))

KeyError: '0000000001'