In [1]:
import os
import re
import pickle
import gym
import gym_snake
import numpy as np
from stable_baselines import PPO2
from gym.envs.registration import register

In [8]:
def inference(model, env, max_idle_steps=500, reward_threshold=5, idle_penalty=10):
    """Run a single episode of `model` in `env` and report its outcome.

    The episode ends when the environment signals `done`, or when the agent
    has gone more than `max_idle_steps` consecutive steps without receiving
    a step reward greater than `reward_threshold`. In the latter case the
    episode is cut short and `idle_penalty` is subtracted from the total.

    Args:
        model: Object exposing `predict(obs)` returning `(action, state)`.
        env: Object exposing `reset()` and `step(action)`, where `step`
            returns `(obs, reward, done, info)`.
        max_idle_steps: Number of consecutive low-reward steps tolerated
            before the episode is aborted.
        reward_threshold: A step reward must be strictly greater than this
            value to reset the idle counter.
        idle_penalty: Amount subtracted from the total reward when the
            episode is aborted for idling.

    Returns:
        Tuple `(total_reward, length)`: the accumulated reward (including
        any idle penalty) and the number of steps taken.
    """
    total_reward = 0
    length = 0
    since_last_reward = 0

    obs = env.reset()
    done = False
    while not done:
        action, _ = model.predict(obs)
        obs, step_reward, done, _ = env.step(action)
        total_reward += step_reward
        length += 1

        # Only a reward above the threshold counts as progress; anything
        # else extends the idle streak.
        if step_reward > reward_threshold:
            since_last_reward = 0
        else:
            since_last_reward += 1

        # Abort episodes where the agent wanders without making progress.
        # The check runs even if the environment finished on this step.
        if since_last_reward > max_idle_steps:
            done = True
            total_reward -= idle_penalty
    return total_reward, length

In [9]:
def test_model(model_path, env, episodes=100):
    """Evaluate a saved PPO2 model over several episodes.

    The model is loaded once from `model_path` and then run through
    `inference` for `episodes` episodes in `env`. A progress line is
    printed on every tenth episode.

    Args:
        model_path: Path passed to `PPO2.load`.
        env: Environment handed to `inference` for every episode.
        episodes: Number of episodes to run.

    Returns:
        Tuple `(rewards, lengths)` of two lists holding the per-episode
        values returned by `inference`, in episode order.
    """
    agent = PPO2.load(model_path)
    rewards, lengths = [], []
    for episode in range(1, episodes + 1):
        # Report progress before starting every tenth episode.
        at_checkpoint = episode % 10 == 0
        if at_checkpoint:
            print(f'Episode {episode}/{episodes}')
        episode_reward, episode_length = inference(agent, env)
        rewards.append(episode_reward)
        lengths.append(episode_length)
    return rewards, lengths

In [10]:
def get_filenames(directory):
    """Collect the pickled model files under `directory`, ordered by timestep.

    The tree rooted at `directory` is walked recursively. Every file whose
    name ends in '.pkl' must contain exactly one 'timestep_<digits>_complexity'
    fragment in its file name; the digits are parsed as the timestep count.

    Args:
        directory: Root directory to search.

    Returns:
        List of `(timesteps, filepath)` tuples sorted in ascending order.

    Raises:
        ValueError: If a '.pkl' file name does not contain exactly one
            'timestep_<digits>_complexity' fragment.
    """
    timestep_pattern = re.compile(r'timestep_(\d+)_complexity')
    filenames = []
    for path, _, files in os.walk(directory):
        for name in files:
            if not name.endswith('.pkl'):
                continue
            # Match against the file name only, so that directory names
            # resembling the pattern cannot interfere with the capture.
            matches = timestep_pattern.findall(name)
            if len(matches) != 1:
                raise ValueError(
                    f'Cannot extract timestep from model file name: '
                    f'{os.path.join(path, name)!r}')
            filenames.append((int(matches[0]), os.path.join(path, name)))
    filenames.sort()
    return filenames

In [11]:
def create_env(name, kwargs=None):
    """Register a Snake environment variant and return an instance of it.

    The environment is registered under the id 'Snake-inference-<name>-v0'
    with entry point 'gym_snake.envs:SnakeEnv' and then created through
    `gym.make`.

    Args:
        name: Label embedded in the registered environment id.
        kwargs: Optional dict of keyword arguments forwarded to `register`
            for the environment constructor. Defaults to no arguments.

    Returns:
        The environment returned by `gym.make` for the registered id.

    Note:
        `register` is called unconditionally, so calling this twice with the
        same `name` may fail if the installed gym version rejects duplicate
        ids - TODO confirm against the gym version in use.
    """
    # Use a None sentinel instead of a shared mutable default dict, which
    # would be reused across calls and handed to the registry.
    if kwargs is None:
        kwargs = {}
    registered_name = f'Snake-inference-{name}-v0'
    register(id=registered_name,
             entry_point='gym_snake.envs:SnakeEnv',
             kwargs=kwargs)
    env = gym.make(registered_name)
    return env

In [12]:
# Evaluate every saved checkpoint on the 'empty' level and pickle the
# per-episode rewards and lengths, one file per checkpoint.
ENV_NAME = 'empty'
MODEL_DIR = '/home/oskar/kth/kex/simple-env/experiments/10M/models/empty'
OUTPUT_DIR = 'inference'

try:
    env = create_env(name=ENV_NAME,
                     kwargs={
                        'sticky': False,
                        #'obstacle_rate': 0.08,
                        'level': 'empty'
                     })
except gym.error.Error:
    # Presumably raised because this cell was already run in the current
    # kernel and gym rejects the duplicate id. Catch only gym errors so that
    # interrupts and unrelated bugs are not silently swallowed, and rebuild
    # the environment from the existing registration so `env` is always
    # defined. The id must mirror the format used by create_env.
    print('Environment already registered\n')
    env = gym.make(f'Snake-inference-{ENV_NAME}-v0')

# Make sure the output directory exists before the (long) evaluation starts.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for timesteps, filename in get_filenames(MODEL_DIR):
    print('Evaluating model, timesteps:', timesteps)
    rewards, lengths = test_model(filename, env)
    print()
    with open(os.path.join(OUTPUT_DIR, f'data_{timesteps}.pkl'), 'wb') as file:
        pickle.dump((rewards, lengths), file)

Environment already registered

Evaluating model, timesteps: 50000
Loading a model without an environment, this model cannot be trained until it has a valid environment.
Episode 10/100
Episode 20/100
Episode 30/100
Episode 40/100
Episode 50/100
Episode 60/100
Episode 70/100
Episode 80/100
Episode 90/100
Episode 100/100

Evaluating model, timesteps: 100000
Loading a model without an environment, this model cannot be trained until it has a valid environment.
Episode 10/100
Episode 20/100
Episode 30/100
Episode 40/100
Episode 50/100
Episode 60/100
Episode 70/100
Episode 80/100
Episode 90/100
Episode 100/100

Evaluating model, timesteps: 150000
Loading a model without an environment, this model cannot be trained until it has a valid environment.
Episode 10/100
Episode 20/100
Episode 30/100
Episode 40/100
Episode 50/100
Episode 60/100
Episode 70/100
Episode 80/100
Episode 90/100
Episode 100/100

Evaluating model, timesteps: 200000
Loading a model without an environment, this model cannot be

KeyboardInterrupt: 