In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import gym
from gym import wrappers
from IPython import display
%matplotlib inline

from ASSETS.network_policy import Policy

In [2]:
def load_environment_and_network(folder_path):
    """Rebuild the gym environment and the trained policy saved in ``folder_path``.

    The folder must contain:

    * ``info.txt`` -- one ``key = value`` setting per line. The keys ``env``
      (passed to ``gym.make``) and ``hiddensize`` (an integer) are required;
      lines without ``" = "`` are ignored.
    * ``policy.pt`` -- a state dict loadable into ``Policy(n_inputs, n_actions, hiddensize)``.

    Args:
        folder_path: Directory holding ``info.txt`` and ``policy.pt``.

    Returns:
        ``(env, policy)``: a fresh environment from ``gym.make`` and a ``Policy``
        whose saved weights have been loaded onto the CPU.

    Raises:
        KeyError: if ``info.txt`` lacks ``env`` or ``hiddensize``.
        ValueError: if ``hiddensize`` is not an integer.
    """
    path_info = os.path.join(folder_path, "info.txt")
    path_policy = os.path.join(folder_path, "policy.pt")

    with open(path_info, "r") as f:
        cont = f.read()

    info_dict = {}
    for line in cont.splitlines():
        # Split on the first " = " only, so a value that itself contains " = "
        # is kept instead of the whole line being silently dropped.
        parts = line.split(" = ", 1)
        if len(parts) == 2:
            info_dict[parts[0].strip()] = parts[1].strip()

    missing = [key for key in ("env", "hiddensize") if key not in info_dict]
    if missing:
        raise KeyError(f"{path_info} is missing required setting(s): {', '.join(missing)}")

    env = gym.make(info_dict["env"])
    # Assumes a flat (1-D) observation vector and a Discrete action space
    # (`.n` only exists on Discrete spaces).
    n_inputs = env.observation_space.shape[0]
    n_actions = env.action_space.n
    hiddensize = int(info_dict["hiddensize"])

    policy = Policy(n_inputs, n_actions, hiddensize)
    # map_location="cpu": a checkpoint saved from a GPU would otherwise fail to load
    # on a CPU-only machine, and the rollout converts policy outputs with `.numpy()`,
    # which requires CPU tensors anyway.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    policy.load_state_dict(torch.load(path_policy, map_location="cpu"))
    return env, policy

def show_replay(monitor_env=None, video_dir="./gym-results"):
    """Embed the first episode recorded by a gym ``Monitor`` wrapper as an inline HTML5 video.

    The Monitor writes the replay to a local MP4 file named
    ``openaigym.video.<file_infix>.video000000.mp4``. This function reads that file
    and returns it base64-inlined in a ``<video>`` element so it plays inside the notebook.

    Args:
        monitor_env: The Monitor-wrapped environment that recorded the episode; its
            ``file_infix`` attribute identifies the recording. Defaults to the
            notebook-global ``env``, matching the original zero-argument call ``show_replay()``.
        video_dir: Directory the Monitor was told to write into.

    Returns:
        IPython.display.HTML: the embeddable video element.

    Raises:
        FileNotFoundError: if no recording of episode 0 exists in ``video_dir``.
    """
    import base64
    from IPython.display import HTML

    if monitor_env is None:
        # Backward-compatible fallback: earlier versions always read the global `env`.
        monitor_env = env

    video_path = os.path.join(
        video_dir, "openaigym.video.%s.video000000.mp4" % monitor_env.file_infix
    )
    if not os.path.exists(video_path):
        raise FileNotFoundError(
            f"No replay found at {video_path!r}; check that the Monitor recorded "
            "episode 0 and that env.close() was called to finalize the video."
        )

    # Read-only, and closed deterministically (the original left the handle open).
    with open(video_path, "rb") as f:
        encoded = base64.b64encode(f.read())
    return HTML(data='''
        <video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'''
    .format(encoded.decode('ascii')))


In [3]:
# Play one greedy episode with the trained policy while the Monitor wrapper records it.
rollout_limit = 500  # upper bound on steps; the loop stops earlier once the episode reports `done`
# NOTE(review): absolute, machine-specific path -- point this at the folder holding info.txt and policy.pt.
model_dir = r"C:\Source\DeepLearningProject\SHs\TESTO"

env, policy = load_environment_and_network(model_dir)
# Recordings go to ./gym-results, the same folder show_replay() reads from.
env = wrappers.Monitor(env, "./gym-results", force=True)
state = env.reset()

# Pure inference: no gradients needed, so don't build an autograd graph.
with torch.no_grad():
    for t in range(rollout_limit):
        # Keep the raw observation (`state`) separate from the batched tensor fed to the network.
        state_tensor = torch.from_numpy(np.atleast_2d(state)).float()
        policy_distribution = policy(state_tensor)
        # Greedy action: index of the highest policy output (batch of one, so the flat argmax is the action).
        action = policy_distribution.numpy().argmax()
        state, reward, done, _ = env.step(action)
        env.render()
        if done:
            break

# Closing the Monitor finalizes the MP4; gym's video recorder only finishes the file
# on close()/next reset, so without this the replay can be empty or truncated.
env.close()

show_replay()