In [1]:
import gym
import itertools
import gym_tic_tac_toe
import plotting
from plotting import EpisodeStats
from collections import defaultdict 
from copy import deepcopy
import numpy as np 
import operator

In [4]:
# Example of Q dictionary used in q_learn function
#
# Q = {
#     "0000000001": {
#         0: 0,
#         1: 0,
#         2: 0,
#         # ...
#         8: 0
#     },
#     # ...
#     "1-11-1-11-11-10-1": {
#         8: 0
#     },
# }

# ile zajmuje wyuczenie sie
# procent klasyfikacji, dokładność, 
# jak zależy od rozmiaru zbioru uczacego

In [5]:
def hash_state(state):
    board = state['board']
    move = state['on_move']
    return ''.join(str(b) for b in board)+str(move)
    
def hash_action(action):
    return action[1]

In [6]:
def get_best_action_idx(Q, state_hash, action_hashes):
    if state_hash in Q:
        best_action_idx = np.argmax(Q[state_hash]) 
    else:
        Q[state_hash] = dict((ah, 0) for ah in action_hashes)
        best_action_idx = np.random.choice(len(action_hashes))
        
    return best_action_idx

In [7]:
def create_policy(Q, epsilon):

    def get_action_probs(state_hash, action_hashes):
        num_actions = len(action_hashes)
        action_probs = np.ones(num_actions, dtype = float) * epsilon / num_actions 
        best_action_idx = get_best_action_idx(Q, state_hash, action_hashes)
        action_probs[best_action_idx] += (1.0 - epsilon)
        return action_probs
   
    return get_action_probs 

In [17]:
def q_learn(num_episodes, discount_factor = 1.0, alpha = 0.6, epsilon = 0.1, print_log = False):  
    Q = {}
    env = gym.make('tic_tac_toe-v1')
    
    stats = plotting.EpisodeStats( 
        episode_lengths = np.zeros(num_episodes), 
        episode_rewards = np.zeros(num_episodes))
    
    policy = create_policy(Q, epsilon) 
       
    for ith_episode in range(num_episodes): 
        if (ith_episode%1000==0):
            print(ith_episode)
        env.reset()
        state = deepcopy(env.state)
        for t in itertools.count(): 
               
            state_hash = hash_state(state)
            actions = env.move_generator()
            action_hashes = [hash_action(act) for act in actions]
            action_probabilities = policy(state_hash, action_hashes)
            
            action_idx = np.random.choice(np.arange(len(actions)), p=action_probabilities)
            action = actions[action_idx]
            action_hash = hash_action(action)
                           
            next_state, reward, done, _ = env.step(action)
   
            stats.episode_rewards[ith_episode] += reward 
            stats.episode_lengths[ith_episode] = t 
            
            next_state_hash = hash_state(next_state)
            next_action_hashes = [hash_action(act) for act in env.move_generator()]
            
            if len(next_action_hashes) == 0:
                next_action_score = 0
            else:   
                best_next_action_idx = get_best_action_idx(Q, next_state_hash, next_action_hashes)
                best_next_action_hash = next_action_hashes[best_next_action_idx]
                next_max = Q[next_state_hash][best_next_action_hash]
                
            old_value = Q[state_hash][action_hash]

            new_value = (1 - alpha) * old_value + alpha * (reward + discount_factor * next_max)
            Q[state_hash][action_hash] = new_value
            
            if print_log:
                print('\n\n----------- STATE -------------')
                env.render()
                print(state_hash)
                print(actions)
                print(action_hashes)
                print(action_probabilities)
                print(action_idx)
                print(action)
                print(action_hash)
                print(reward)
                print(td_target)
                print(td_delta)
                print(done)
            
            if done: 
                break
                   
            state = deepcopy(next_state)
       
    return Q, stats 

In [18]:

def play_game(Q, player = -1):
    env = gym.make('tic_tac_toe-v1')
    state = env.reset()
    env.render()
    
    on_move = state['on_move']
    reward = 0
    done = False
    
    while not done:
        on_move = state['on_move']
        
        if player == on_move:
            print('Pick a move index')
            moves = env.move_generator()
            print(list(enumerate(moves)))
            idx = int(input())
            action = moves[idx]
        else:
            actions = Q[hash_state(state)].items()
            print(actions)
            best_action_hash = max(actions, key=operator.itemgetter(1))
            print(best_action_hash)
            best_action_hash = best_action_hash[0]
            action = [on_move, best_action_hash]
        
        state, reward, done, _ = env.step(action) 

        env.render()
    
    if reward == 0:
        print("Draw!")
    elif on_move == player:
        print('You won!')
    else:
        print('AI won!')
        
    return env

In [19]:
(Q, stats) = q_learn(30000, print_log=False)


0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000


In [None]:
play_game(Q)

on move:  X
      
      
      
dict_items([(0, 1.0), (1, 1.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 0.0), (7, -1.0), (8, -1.0)])
(0, 1.0)
on move:  O
X     
      
      
Pick a move index
[(0, [-1, 1]), (1, [-1, 2]), (2, [-1, 3]), (3, [-1, 4]), (4, [-1, 5]), (5, [-1, 6]), (6, [-1, 7]), (7, [-1, 8])]
on move:  X
X     
  O   
      
dict_items([(1, 1.0), (2, 0.9984246610803995), (3, 0.9188485652062178), (5, -0.7784503039561823), (6, 0.0), (7, 43.621953019776925), (8, 0.9061484927998593)])
(7, 43.621953019776925)
on move:  O
X     
  O   
  X   
Pick a move index
[(0, [-1, 1]), (1, [-1, 2]), (2, [-1, 3]), (3, [-1, 5]), (4, [-1, 6]), (5, [-1, 8])]
on move:  X
X   O 
  O   
  X   
dict_items([(8, 0), (1, 0.0), (3, 0), (5, 0), (6, 0)])
(8, 0)
on move:  O
X   O 
  O   
  X X 
Pick a move index
[(0, [-1, 1]), (1, [-1, 3]), (2, [-1, 5]), (3, [-1, 6])]
