In [1]:
import gym
import time
import argparse
import numpy as np
import torch

from lib import wrappers
from lib import dqn_model

In [2]:
# Gym environment id the agent is played on; the default checkpoint file name
# is built from it as "<env id>-best.dat".
DEFAULT_ENV_NAME = "PongNoFrameskip-v4"
DEFAULT_MODEL = DEFAULT_ENV_NAME + "-best.dat"
# Default directory for the video recording of the played episode.
DEFAULT_RECORD = "video"

# Playback speed (frames per second) used when rendering the game on screen.
FPS = 25

In [3]:
if __name__ == "__main__":
    # Command-line options: the model file to load, the Gym environment to
    # create, and the directory where recordings are stored.
    # parse_known_args() presumably lets the cell run inside a notebook kernel,
    # whose own extra command-line arguments are collected in `unknown`.
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", default=DEFAULT_MODEL, help="Model file to load")
    parser.add_argument("-e", "--env", default=DEFAULT_ENV_NAME, help="Environment name to use, default=" + DEFAULT_ENV_NAME)
    parser.add_argument("-r", "--record", default=DEFAULT_RECORD, help="Directory to store video recording")
    args, unknown = parser.parse_known_args()

    # Build the environment and wrap it in gym's Monitor unless --record is an
    # empty string (the default record directory is non-empty, so recording is
    # on by default).
    env = wrappers.make_env(args.env)
    if args.record:
        env = gym.wrappers.Monitor(env, args.record)

    try:
        net = dqn_model.DQN(env.observation_space.shape, env.action_space.n)
        # map_location keeps every loaded tensor on the CPU, so a checkpoint
        # saved from a GPU training run can still be loaded on a CPU-only machine.
        net.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))
        # Inference only: put the network in evaluation mode.
        net.eval()

        # Play one episode greedily. Unlike play_step() used during training
        # there is no epsilon-greedy exploration: the network computes Q-values
        # for the current observation and the action with the largest value is
        # taken. render() draws every frame so the game can be watched.
        state = env.reset()
        total_reward = 0.0
        while True:
            start_ts = time.time()
            env.render()
            # Add a batch dimension of size 1. Do not pass copy=False here:
            # building an array from a list always copies, and NumPy >= 2
            # raises ValueError when copy=False cannot be honoured.
            state_v = torch.tensor(np.array([state]))
            # No gradients are needed to pick an action.
            with torch.no_grad():
                q_vals = net(state_v).numpy()[0]
            action = int(np.argmax(q_vals))

            # Apply the action, accumulate the reward, and stop when the episode ends.
            state, reward, done, _ = env.step(action)
            total_reward += reward
            if done:
                break
            # Sleep off the remaining frame budget so playback runs at about FPS.
            delta = 1 / FPS - (time.time() - start_ts)
            if delta > 0:
                time.sleep(delta)
    finally:
        # Always close the environment: this releases the render window and,
        # when the Monitor wrapper is active, flushes/finalizes the recording
        # (a long-lived notebook kernel would otherwise never do so).
        env.close()
    print("Total reward : %.2f" % total_reward)

Total reward : 20.00
