In [1]:
import gym
import gym_Snake
import time
import sys
from IPython.display import clear_output
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import os
from Sarsa import SARSA
from QLearning import QLearning

In [2]:
# ---- Evaluation configuration ------------------------------------------
# Environment shape used for this evaluation run.
env_name = 'Double_v'
# Environment shape the loaded model was trained on
# (it is part of the saved results file name).
train_env = 'Double_v'
# Algorithm whose saved Q-table is evaluated: 'QL' (Q-learning) or 'SARSA'.
my_algo_name = 'QL'

# Parameters of the training run; used here only to rebuild the name of the
# saved results file.
nb_iterations = 100_000
eps_min_after = 70_000

# How many evaluation episodes to play.
nb_tests = 200

In [3]:
# Reward scheme handed to the environment. `reward_eat` is also used by
# play_epoch to recognise the step on which a target was eaten.
# REWARD_TOWARD / REWARD_AWAY are presumably shaping rewards for moving
# toward / away from the target; both are disabled (0) here.
reward_eat = 1
custom_rewards = dict(
    REWARD_TARGET=reward_eat,
    REWARD_COLLISION=-1,
    REWARD_TOWARD=0,
    REWARD_AWAY=0,
)



# Evaluation environment: a 10x10 'Snake-v0' board with solid borders, laid
# out according to `env_name` and using the reward scheme defined above.
env = gym.make(
    'Snake-v0',
    player='computer',
    shape=env_name,
    state_mode='states',
    reward_mode='normal',
    width=10,
    height=10,
    solid_border=True,
    rewards=custom_rewards,
)

In [4]:
# Both agents are built with identical hyper-parameters. Exploration starts
# at its minimum because this notebook only evaluates a pre-trained model:
# the selected agent's Q-table is overwritten with a saved one further down.
min_epsilon = 0.01

QL = QLearning(
    n_actions=env.action_space.n,
    n_states=env.observation_space.n,
    discount=0.9,
    alpha=0.2,
    epsilon=min_epsilon,
    min_epsilon=min_epsilon,
)

SA = SARSA(
    n_actions=env.action_space.n,
    n_states=env.observation_space.n,
    discount=0.9,
    alpha=0.2,
    epsilon=min_epsilon,
    min_epsilon=min_epsilon,
)


In [5]:
def play_epoch(algo, env, render = False, sleep_time = 0.5, target_reward = None, max_steps = None):
    """Play one episode of `env` with agent `algo` and collect its statistics.

    Parameters
    ----------
    algo : agent object
        Must provide ``reset(obs)``, ``act()`` and ``update(action, reward, obs)``.
        NOTE(review): ``update`` is called on every step, so the agent keeps
        learning while it is being evaluated -- confirm this is intended.
    env : environment object
        Gym-style: ``reset()`` returns an observation and ``step(action)``
        returns ``(obs, reward, done, info)``.
    render : bool
        If True, call ``env.render()`` before each step and sleep
        ``sleep_time`` seconds after it.
    sleep_time : float
        Pause, in seconds, between rendered steps.
    target_reward : number or None
        Reward value that counts as "ate a target". ``None`` (default) uses the
        notebook-level ``reward_eat``, looked up at call time.
    max_steps : int or None
        Optional cap on the episode length. ``None`` (default) runs until the
        environment reports ``done``; a cap guards against a policy that
        cycles forever without dying or eating.

    Returns
    -------
    tuple
        ``(total_reward, targets_eaten, n_steps)``.
    """
    if target_reward is None:
        # Backward compatible: fall back to the global defined in the rewards cell.
        target_reward = reward_eat

    obs = env.reset()
    algo.reset(obs)

    done = False
    total_rew = 0
    n_eaten = 0
    n_steps = 0

    while not done and (max_steps is None or n_steps < max_steps):
        if render:
            env.render()
        new_act = algo.act()
        obs, reward, done, _info = env.step(new_act)
        total_rew += reward
        # A target is recognised by its exact reward value.
        if reward == target_reward:
            n_eaten += 1
        algo.update(new_act, reward, obs)
        if render:
            # Slow down rendering so the episode can be watched.
            time.sleep(sleep_time)
        n_steps += 1

    return total_rew, n_eaten, n_steps

In [6]:


# Run from the directory the kernel was started in, so that the relative
# `QL_results/` path below resolves. `_dh` is IPython's directory history and
# `_dh[0]` is its first entry, so this only works inside IPython/Jupyter.
filepath = globals()['_dh'][0]
os.chdir(filepath)

# Select the agent explicitly: an unrecognised `my_algo_name` must fail loudly
# instead of silently evaluating the Q-learning agent.
agents_by_name = {'QL': QL, 'SARSA': SA}
if my_algo_name not in agents_by_name:
    raise ValueError(
        f"my_algo_name must be one of {sorted(agents_by_name)}, got {my_algo_name!r}"
    )

# Q-table saved by the training run described in the config cell.
table_path = (
    f'QL_results/{my_algo_name}_epochs_{nb_iterations}_batch_1000_x_5000'
    f'_epsilon_{eps_min_after}_train_{train_env}.npy'
)
table = np.load(table_path)

my_algo_test = agents_by_name[my_algo_name]
my_algo_test.Q = table
# Evaluate with the same small exploration rate the agents were built with.
my_algo_test.epsilon = min_epsilon

# Per-episode statistics.
times = []
steps = []
eated = []
rewards = []

for _ in tqdm(range(nb_tests)):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    r, e, i = play_epoch(algo = my_algo_test, env = env, render = False)
    stop = time.perf_counter()
    times.append(stop - start)
    steps.append(i)
    eated.append(e)
    rewards.append(r)

print('\n#########################')
print(f'After {nb_tests} tests with  {my_algo_name}:')
print(f'(trained on {train_env}, test on {env_name})')
print()
print(f'Average targets eated:       {np.mean(eated)}')
print(f'Max targets eated:           {np.max(eated)}')
# Report the exact median: round() would snap x.5 medians to the nearest even integer.
print(f'Median of targets eated:     {np.median(eated)}')
print(f'Std of targets eated:        {round(np.std(eated), 3)}')
print(f'Average time per simulation: {round(np.mean(times), 4)}')
print(f'Average time per step:       {round(np.sum(times) / np.sum(steps), 6)}')

100%|█████████████████████████████████████████| 200/200 [00:03<00:00, 61.42it/s]


#########################
After 200 tests with  QL:
(trained on Double_v, test on Double_v)

Average targets eated:       7.57
Max rewards eated:           23
Median of rewards eated:     7
Std of rewards eated:        4.978
Average time per simulation: 0.0162
Average time per step:       0.000109





In [7]:
# from IPython.display import HTML, Javascript, display
# def restart_kernel_and_run_all_cells():
#     display(HTML(
#         '''
#             <script>
#                 code_show = false;
#                 function restart_run_all(){
#                     IPython.notebook.kernel.restart();
#                     setTimeout(function(){
#                         IPython.notebook.execute_all_cells();
#                     }, 10000)
#                 }
#                 function code_toggle() {
#                     if (code_show) {
#                         $('div.input').hide(200);
#                     } else {
#                         $('div.input').show(200);
#                     }
#                     code_show = !code_show
#                 }
#                 code_toggle() 
#                 restart_run_all()
#             </script>

#         '''
#     ))
# restart_kernel_and_run_all_cells()