In [1]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
import matplotlib.pyplot as plt
import time
import random
import torch
from torch import nn
from common.network import DuelingNetwork
from common.hparameter import *

In [2]:
"speed"
# Pursuer speed setting; one of "slow", "equal" or "fast".
# The setup cell maps this label to the pursuer's numeric speed.
speed = "slow"

In [3]:
""" seed """
# Seed every random number generator in use so the evaluation is repeatable.
# NOTE(review): `seed` is not defined in this notebook; presumably it is
# provided by the star import of common.hparameter -- confirm.
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Networks (one for the pursuer, one for the evader)
net_p = DuelingNetwork(10, 13).to(device)
net_e = DuelingNetwork(10, 13).to(device)

# Environment
# The environment module lives in a sibling directory, so that directory has
# to be on sys.path before the import below can succeed.
env_dir = os.path.join(os.pardir, 'c1ae')
sys.path.append(env_dir)
from chase1_and_escape import Chase1AndEscape

# Translate the speed label into the pursuer's numeric speed.
if speed == "fast":
    speed_p = 3.6
elif speed == "equal":
    speed_p = 3.0
elif speed == "slow":
    speed_p = 2.4
else:
    # Fail here with a clear message instead of a NameError on speed_p below.
    raise ValueError(
        'speed must be "slow", "equal" or "fast", got ' + repr(speed)
    )

speed_e = 3
max_step_episode = 300
env = Chase1AndEscape(speed_pursuer=speed_p, speed_evader=speed_e, max_step=max_step_episode)

# Load the trained weights.
# map_location lets checkpoints that were saved on a GPU load on a CPU-only
# machine (and vice versa).
# NOTE(review): both file names are keyed on the pursuer speed (speed_p),
# including the evader's -- kept as in the original; confirm this matches how
# the checkpoints were saved.
net_p.load_state_dict(torch.load("../model/c1ae/p_" + str(speed_p) + ".pth", map_location=device))
net_e.load_state_dict(torch.load("../model/c1ae/e_" + str(speed_p) + ".pth", map_location=device))

# Number of test episodes to roll out
num_episodes_test = 10

cpu


In [4]:
# Recorded trajectories: one entry per episode, each a list of per-step positions.
pursuer_pos_list = []
evader_pos_list = []

for episode in range(num_episodes_test):

    # Reset the environment and the per-episode bookkeeping.
    pursuer_pos_episode = []
    evader_pos_episode = []
    step_episode = 0
    obs_p, obs_e = env.reset()
    obs_p, obs_e = torch.Tensor(obs_p), torch.Tensor(obs_e)
    done = False
    total_reward_p = 0
    total_reward_e = 0

    while not done:

        # Pure evaluation: no autograd graph is needed for action selection.
        # NOTE(review): the second argument (0) is presumably the exploration
        # rate, i.e. act greedily -- confirm against DuelingNetwork.act.
        with torch.no_grad():
            action_p = net_p.act(obs_p.float().to(device), 0)
            action_e = net_e.act(obs_e.float().to(device), 0)

        next_obs_p, next_obs_e, reward_p, reward_e, done = env.step(action_p, action_e, step_episode)
        next_obs_p, next_obs_e = torch.Tensor(next_obs_p), torch.Tensor(next_obs_e)

        total_reward_p += reward_p
        total_reward_e += reward_e

        obs_p = next_obs_p
        obs_e = next_obs_e

        step_episode += 1

        # Store snapshots rather than references: if the environment updates
        # its position arrays in place, appending env.pos_p directly would make
        # every recorded step alias the final position.
        pos_pursuer = np.copy(env.pos_p)
        pos_evader = np.copy(env.pos_e)
        pursuer_pos_episode.append(pos_pursuer)
        evader_pos_episode.append(pos_evader)

    pursuer_pos_list.append(pursuer_pos_episode)
    evader_pos_list.append(evader_pos_episode)


In [5]:
%matplotlib notebook
from matplotlib import animation

# Colour palette as RGB triples scaled from 0-255 to the 0-1 range.
red = [channel / 255 for channel in (177, 24, 42)]
darkblue = [channel / 255 for channel in (4, 44, 88)]
blue = [channel / 255 for channel in (31, 100, 169)]
lightblue = [channel / 255 for channel in (65, 144, 194)]

# Single axes with equal scaling; extra bottom margin leaves room for the x label.
fig, ax = plt.subplots(figsize=(8, 3))
fig.subplots_adjust(bottom=0.2)
ax.set_aspect('equal')

# Concatenate the per-episode trajectories into one frame sequence per agent.
pursuer_pos_video = []
evader_pos_video = []
for j in range(num_episodes_test):
    pursuer_pos_video += pursuer_pos_list[j]
    evader_pos_video += evader_pos_list[j]

def update_func(i):
    """Redraw the arena for animation frame ``i``.

    Clears the shared axes, draws the square arena border, a fading trail of
    up to four previous positions for each agent, and the current positions
    of the evader (red) and the pursuer (dark blue).

    Args:
        i: Index into ``evader_pos_video`` / ``pursuer_pos_video``.
    """
    # Find the first frame of the episode that frame i belongs to, so the
    # trail never reaches back into the final steps of the previous episode
    # (the video lists are the episodes concatenated back to back).
    episode_start = 0
    for episode_track in evader_pos_list:
        if i < episode_start + len(episode_track):
            break
        episode_start += len(episode_track)

    ax.clear()

    # Arena border: the square [-1, 1] x [-1, 1].
    ax.plot([-1, -1], [1, -1], color="black")
    ax.plot([1, 1], [1, -1], color="black")
    ax.plot([1, -1], [-1, -1], color="black")
    ax.plot([1, -1], [1, 1], color="black")

    # Fading trail: older positions are drawn smaller and more transparent.
    offsets = [4, 3, 2, 1]
    sizes = [0.5, 1, 2, 3]
    alphas = [0.1, 0.2, 0.3, 0.4]
    for offset, size, alpha in zip(offsets, sizes, alphas):
        past = i - offset
        if past >= episode_start:
            ax.plot(evader_pos_video[past][0], evader_pos_video[past][1], 'o', markersize=size, color=red, alpha=alpha)
            ax.plot(pursuer_pos_video[past][0], pursuer_pos_video[past][1], 'o', markersize=size, color=darkblue, alpha=alpha)

    # Current positions.
    ax.plot(evader_pos_video[i][0], evader_pos_video[i][1], 'o', markersize=4, color=red)
    ax.plot(pursuer_pos_video[i][0], pursuer_pos_video[i][1], 'o', markersize=4, color=darkblue)

    # ax.clear() resets limits, ticks, labels and aspect, so restore them each frame.
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False, bottom=False, left=False, right=False, top=False)
    ax.set_xlabel("X-position", fontsize=14)
    ax.set_ylabel("Y-position", fontsize=14)
    ax.set_aspect('equal')
    
# Play every recorded frame once, 100 ms apart. The result is kept in `ani`
# so the animation object is not garbage-collected while it is displayed.
ani = animation.FuncAnimation(
    fig,
    update_func,
    frames=len(evader_pos_video),
    interval=100,
    repeat=False,
)
plt.show()


<IPython.core.display.Javascript object>