In [1]:
import numpy as np
from tqdm import tqdm
import random
from game_env import ACTION_SPACE

In [2]:
def initialize_q_table():
    """Create a fresh, empty Q-table.

    Returns:
        A new empty ``dict``. Elsewhere in this notebook it is keyed by
        ``(state, action)`` tuples and missing keys are read as ``0.0``.
    """
    return {}

def get_q_value(Qtable, state, action):
    """Look up the stored value for a ``(state, action)`` pair.

    Args:
        Qtable: Mapping keyed by ``(state, action)`` tuples.
        state: Hashable state key.
        action: Hashable action key.

    Returns:
        The stored value, or ``0.0`` when the pair has no entry. The table
        is never modified by a lookup.
    """
    key = (state, action)
    if key in Qtable:
        return Qtable[key]
    return 0.0

def get_max(Qtable, state):
    """Find the highest Q-value for ``state`` and the action that attains it.

    Every action in ``ACTION_SPACE`` is scanned in order; pairs without an
    entry count as ``0.0`` (see ``get_q_value``). The comparison is strict,
    so on ties the earliest action in ``ACTION_SPACE`` wins.

    Returns:
        A ``(max_q_value, best_action)`` tuple. If ``ACTION_SPACE`` is empty
        the result is ``(float('-inf'), None)``.
    """
    top_value, top_action = float('-inf'), None
    for candidate in ACTION_SPACE:
        candidate_value = get_q_value(Qtable, state, candidate)
        if candidate_value > top_value:
            top_value, top_action = candidate_value, candidate
    return top_value, top_action

def hash_state(state):
    """Convert an environment state into a form usable as a dict key.

    ``state[1]`` is treated as an iterable of rows and frozen into a tuple
    of tuples; ``state[0]`` is passed through unchanged (it is assumed to
    be hashable already -- TODO confirm against ``Game_Env``).

    Returns:
        The pair ``(state[0], rows_as_nested_tuples)``.
    """
    frozen_rows = tuple(map(tuple, state[1]))
    head = state[0]
    return (head, frozen_rows)

In [3]:
def greedy_policy(Qtable, state):
    """Exploit: return the action with the highest Q-value for ``state``.

    Delegates to ``get_max`` and discards the value, keeping only the action.
    """
    _, best_action = get_max(Qtable, state)
    return best_action

In [4]:
def epsilon_greedy_policy(Qtable, state, epsilon):
    """Choose an action with an epsilon-greedy rule.

    A number is drawn uniformly from [0, 1]. If it is strictly greater than
    ``epsilon`` the greedy action for ``state`` is returned (exploitation);
    otherwise an action is picked at random from ``ACTION_SPACE``
    (exploration).
    """
    draw = random.uniform(0, 1)

    # Exploitation: the draw exceeded epsilon, so follow the current Q-table.
    if draw > epsilon:
        return greedy_policy(Qtable, state)

    # Exploration: any action from the action space is equally acceptable.
    return random.choice(ACTION_SPACE)

In [25]:
# Training parameters
n_training_episodes = 20  # Number of episodes train() runs
learning_rate = 0.7  # Step size of the Q update; train() reads this as a module-level global

# Evaluation parameters (not used by any cell visible in this notebook)
n_eval_episodes = 10000  # Total number of test episodes

# Environment parameters
max_steps = 30  # Max steps per episode
gamma = 0.9  # Discounting rate; train() reads this as a module-level global
eval_seed = []  # The evaluation seed of the environment (unused in the visible cells)

# Exploration parameters
# train() computes: epsilon = min_epsilon + (max_epsilon - min_epsilon) * exp(-decay_rate * episode)
# NOTE(review): with decay_rate = 0.005 and only 20 episodes, epsilon only falls from 1.0
# to roughly 0.91, so almost every action is random -- confirm this is intended.
max_epsilon = 1.0  # Exploration probability at start
min_epsilon = 0.01  # Minimum exploration probability
decay_rate = 0.005  # Exponential decay rate for exploration prob

In [26]:
# Build the environment instance that is later passed to train().
# NOTE(review): this import lives outside the notebook's first import cell.
from game_env import Game_Env
env = Game_Env()

In [27]:
def train(n_training_episodes, min_epsilon, max_epsilon, decay_rate, env, max_steps, Qtable):
    """Run tabular Q-learning on ``env``, updating ``Qtable`` in place.

    Args:
        n_training_episodes: Number of episodes to run.
        min_epsilon: Lower bound of the exploration probability.
        max_epsilon: Exploration probability for episode 0.
        decay_rate: Exponential decay rate applied to epsilon per episode.
        env: Environment whose ``reset()`` returns ``(state, info)`` and whose
            ``step(action)`` returns ``(new_state, reward, terminated, result)``.
        max_steps: Maximum number of steps per episode; an episode that
            reaches this limit simply ends without a terminal update.
        Qtable: Dict keyed by ``(hashed_state, action)``; mutated in place.

    Returns:
        The same ``Qtable`` object that was passed in.

    Note:
        ``learning_rate`` and ``gamma`` are not parameters; they are read from
        module-level globals, so the hyperparameter cell must have been run.
    """
    for episode in tqdm(range(n_training_episodes)):
        # Reduce epsilon (because we need less and less exploration)
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)

        # Reset the environment and make the start state usable as a dict key
        state, _info = env.reset()
        state = hash_state(state)

        for _step in range(max_steps):
            # Choose the action with the epsilon-greedy policy
            action = epsilon_greedy_policy(Qtable, state, epsilon)

            # Take the action and observe the reward and the next state
            new_state, reward, terminated, _result = env.step(action)
            new_state = hash_state(new_state)

            # TD target: a terminal transition has no future return, so do not
            # bootstrap from the next state's Q-values in that case.
            if terminated:
                target = reward
            else:
                target = reward + gamma * get_max(Qtable, new_state)[0]

            # Q(s,a) := Q(s,a) + lr * (target - Q(s,a))
            current_q = get_q_value(Qtable, state, action)
            Qtable[(state, action)] = current_q + learning_rate * (target - current_q)

            # The environment reported the end of the episode
            if terminated:
                break

            # Our next state is the new state
            state = new_state
    return Qtable

In [28]:
# Start from an empty table; get_q_value() reads any missing (state, action) pair as 0.0.
Qtable = initialize_q_table()

In [29]:
# train() mutates and returns the dict it is given, so re-running this cell continues
# training from the current Qtable instead of starting from scratch.
Qtable = train(n_training_episodes, min_epsilon, max_epsilon, decay_rate, env, max_steps, Qtable)

100%|██████████| 20/20 [00:00<00:00, 2530.65it/s]


In [32]:
# Print one line per distinct state listing its strictly positive Q-values.
# Qtable is keyed by (state, action), so iterating its keys directly would visit a
# state once per stored action and repeat its line; dict.fromkeys() removes the
# duplicates while keeping first-seen order.
for state in dict.fromkeys(state for state, _ in Qtable):
    positive_values = []
    for action in ACTION_SPACE:
        q_value = get_q_value(Qtable, state, action)
        if q_value > 0:
            positive_values.append(q_value)
    # States with no positive value produce no output line at all.
    if positive_values:
        print(*positive_values, end=' \n')
print('qtable printed done!')

7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
7.0 
qtable printed done!


In [31]:
# Number of (state, action) entries stored in the Q-table.
print(len(Qtable))

423


In [33]:
import pickle

In [34]:
# Persist the trained Q-table (a dict keyed by (state, action) tuples) to qtable.pkl
# in the current working directory. Pickle files should only be loaded back from
# trusted sources, since unpickling can execute arbitrary code.
with open('qtable.pkl', 'wb') as file:
    pickle.dump(Qtable, file)