In [1]:
import os
import random
import uuid

import gym
import numpy as np
from gym.envs.registration import register
from stable_baselines import PPO2

import gym_snake

In [2]:
def inference(env, model, deterministic=False):
    """Roll out one full episode of ``model`` in ``env``.

    The environment is reset, then the model's action is applied each step
    until the environment reports ``done``.

    Args:
        env: A gym-style environment whose ``step`` returns
            ``(obs, reward, done, info)``.
        model: An object with ``predict(obs, deterministic=...)`` returning
            ``(action, state)``, e.g. a stable-baselines model.
        deterministic: Forwarded to ``model.predict``. The default ``False``
            matches the previous behaviour (predict's own default).

    Returns:
        ``(reward, length)``: the summed step rewards and the number of steps
        taken in the episode.
    """
    total_reward = 0
    episode_length = 0

    obs = env.reset()
    done = False
    while not done:
        # The second return value (recurrent state) is not used here.
        action, _ = model.predict(obs, deterministic=deterministic)
        obs, step_reward, done, _ = env.step(action)
        episode_length += 1
        total_reward += step_reward
    return total_reward, episode_length

In [14]:
def test_model(model_path, episodes=1000, kwargs=None):
    """Evaluate a saved PPO2 model on a freshly registered SnakeEnv.

    A new gym id is registered on every call, so each evaluation can use its
    own environment ``kwargs``.

    Args:
        model_path: Path to a model file accepted by ``PPO2.load``.
        episodes: Number of episodes to roll out.
        kwargs: Keyword arguments for ``gym_snake.envs:SnakeEnv``
            (e.g. ``{'sticky': False, 'level': 'empty'}``). ``None`` means
            no extra arguments.

    Returns:
        ``(rewards, lengths)``: per-episode total reward and step count, each
        a list with ``episodes`` entries.
    """
    # Avoid a shared mutable default; copy so the registry does not alias
    # the caller's dict.
    env_kwargs = {} if kwargs is None else dict(kwargs)

    # A uuid4-based id avoids the collisions a small random integer range can
    # produce within one kernel session. Depending on version, gym rejects
    # or overwrites a duplicate id.
    registered_name = f'Snake-inference_{uuid.uuid4().hex}-v0'
    register(id=registered_name,
             entry_point='gym_snake.envs:SnakeEnv',
             kwargs=env_kwargs)
    env = gym.make(registered_name)
    try:
        model = PPO2.load(model_path)

        # Perform inference
        rewards = []
        lengths = []
        for episode in range(episodes):
            print(f'Episode {episode+1}/{episodes}')
            reward, length = inference(env, model)
            rewards.append(reward)
            lengths.append(length)
    finally:
        # Release the environment even if loading or a rollout fails.
        env.close()
    return rewards, lengths

In [17]:
# Directory holding the trained models. Set the SNAKE_MODEL_DIR environment
# variable to point elsewhere instead of editing this path.
MODEL_DIR = os.environ.get(
    'SNAKE_MODEL_DIR', '/home/oskar/kth/kex/simple-env/experiments/models')
# Other models previously evaluated from MODEL_DIR:
#   'increasing_complexity_4e_1M.pkl', 'no_increased_complexity_4e_1M.pkl'
MODEL_NAME = 'empty_map_4e_1M.pkl'

rewards, lengths = test_model(
    model_path=os.path.join(MODEL_DIR, MODEL_NAME),
    episodes=10,
    kwargs={
        'sticky': False,
        # An 'obstacle_rate' (e.g. 0.08) was also tried here; presumably only
        # meaningful for non-empty levels. TODO confirm against SnakeEnv.
        'level': 'empty',
    },
)

# Summary statistics over the evaluated episodes.
for label, values in (('reward', rewards), ('length', lengths)):
    print(f'\nMin {label}:', min(values))
    print(f'Max {label}:', max(values))
    print(f'Avg {label}:', np.average(values))

Loading a model without an environment, this model cannot be trained until it has a valid environment.
Episode 1/10
Episode 2/10
Episode 3/10
Episode 4/10
Episode 5/10
Episode 6/10
Episode 7/10
Episode 8/10
Episode 9/10
Episode 10/10

Min reward: 106.09999999999982
Max reward: 265.39999999999975
Avg reward: 212.0399999999996

Min length: 163
Max length: 1224
Avg length: 598.5
