In [1]:
%matplotlib inline

import base64
import io
import glob
import os
import random

from IPython import display
from IPython.display import HTML
import matplotlib.pyplot as plt
import numpy as np

import torch

from gym.wrappers import Monitor
from pyvirtualdisplay import Display

import sys
sys.path.append('../src')

from common.atari_wrappers import make_atari
from nets import DQN
from utils import wrap_atari_dqn


vdisplay = Display(visible=0, size=(1400, 900))
display_params = vdisplay.start()

In [2]:
def get_state(obs, args=None):
    obs = np.asarray(obs)
    obs = obs.transpose(2, 0, 1)  # -> (c, h, w)
    obs = obs / 255.
    state = torch.from_numpy(obs).type(torch.float).unsqueeze(0)
    return state


def get_action(state, model, epsilon):
    r = random.random()
    if r > epsilon:
        action = model(state).max(dim=1)[1]
        action = action.view(1, 1)
    else:
        action = torch.tensor([[random.randrange(env.action_space.n)]], device=device, dtype=torch.long)
    return action


def validate(env, model, epsilon):
    model.eval()

    obs = env.reset()
    state = get_state(obs).to(device)
    G = 0
    while True:
        env.render(mode='rgb_array')
        action = get_action(state, policy_net, epsilon)
        obs, reward, done, info = env.step(action.item())
        G += reward
        next_state = get_state(obs).to(device)
        if done:
            break
        else:
            state = next_state
        
    print('EpisodeReward {}'.format(G))
    return G


def show_video(path):
    video = io.open(path, 'r+b').read()
    encoded = base64.b64encode(video)
    display.display(
        HTML(data='''<video alt="test" autoplay loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'''.format(encoded.decode('ascii'))
        )
    )


"""
def show_video(logdir):
  mp4list = glob.glob(os.path.join(logdir, 'video/*.mp4'))
  if len(mp4list) > 0:
    mp4 = mp4list[0]
    # mp4 = mp4list[-1]
    print(mp4)
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    display.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
  else: 
    print("Could not find video")
"""

'\ndef show_video(logdir):\n  mp4list = glob.glob(os.path.join(logdir, \'video/*.mp4\'))\n  if len(mp4list) > 0:\n    mp4 = mp4list[0]\n    # mp4 = mp4list[-1]\n    print(mp4)\n    video = io.open(mp4, \'r+b\').read()\n    encoded = base64.b64encode(video)\n    display.display(HTML(data=\'\'\'<video alt="test" autoplay \n                loop controls style="height: 400px;">\n                <source src="data:video/mp4;base64,{0}" type="video/mp4" />\n             </video>\'\'\'.format(encoded.decode(\'ascii\'))))\n  else: \n    print("Could not find video")\n'

In [3]:
logdir = '../logs'

# expid = '20200724121347'
# # checkpoint_name = 'checkpoint-000027500-14.0.pt'
# checkpoint_name = 'checkpoint-000028400-13.0.pt'
# env_name = 'BreakoutNoFrameskip-v4'

env_name = 'PongNoFrameskip-v4'
# expid = '20200725044029'
# checkpoint_name = 'checkpoint-000000600--11.0.pt'
# checkpoint_name = 'checkpoint-000000900-15.0.pt'
expid = '20200725042931'
checkpoint_name = 'checkpoint-000002500-19.0.pt'


checkpoint_path = os.path.join(logdir, expid, checkpoint_name)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
epsilon = -1

In [4]:
env = make_atari(env_name)
env = wrap_atari_dqn(env)
_path = os.path.join(logdir, expid, 'video')

In [5]:
checkpoint = torch.load(checkpoint_path)
policy_net = DQN(84, 84, env.unwrapped.action_space.n)
policy_net.load_state_dict(checkpoint['policy_net'])
policy_net.to(device)
policy_net.eval()

DQN(
  (conv1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=3136, out_features=512, bias=True)
  (out): Linear(in_features=512, out_features=6, bias=True)
)

In [6]:
env = Monitor(env, _path, force=True)
validate(env, policy_net, epsilon)
validate(env, policy_net, epsilon)
validate(env, policy_net, epsilon)
validate(env, policy_net, epsilon)

EpisodeReward 20.0
EpisodeReward 20.0
EpisodeReward 20.0
EpisodeReward 20.0


20.0

In [8]:
!ls ../logs/20200725042931/video

openaigym.video.0.6461.video000000.meta.json
openaigym.video.0.6461.video000000.mp4
openaigym.video.0.6461.video000001.meta.json
openaigym.video.0.6461.video000001.mp4


In [10]:
video_name = 'openaigym.video.0.6461.video000000.mp4'
show_video(os.path.join(logdir, expid, 'video', video_name))

In [9]:
raise 

RuntimeError: No active exception to reraise

In [7]:
!ls ../logs/20200725044029/video

openaigym.video.0.6061.video000000.meta.json
openaigym.video.0.6061.video000000.mp4
openaigym.video.0.6061.video000001.meta.json
openaigym.video.0.6061.video000001.mp4


In [9]:
video_name = 'openaigym.video.0.6061.video000001.mp4'
show_video(os.path.join(logdir, expid, 'video', video_name))

In [21]:
video_name = 'openaigym.video.1.5591.video000000.mp4'
show_video(os.path.join(logdir, expid, 'video', video_name))