In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import yaml
import datetime

from IPython.display import clear_output
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default plot styling to subsequent matplotlib figures.
sns.set()

In [3]:
import math
import os 
import random 
import numpy as np 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 
import torch.nn.functional as F

In [4]:
import itertools
import pickle
# Print numpy floats with 3 decimal places.
np.set_printoptions(precision=3)

In [5]:
# Global seed for every RNG used below (originally read from sys.argv[1]).
seed_value = 2 * 324267

# NOTE: setting PYTHONHASHSEED at runtime only affects child processes,
# not hash randomisation in the already-running interpreter.
os.environ['PYTHONHASHSEED'] = str(seed_value)

# Seed Python, NumPy and PyTorch (CPU) generators.
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)

# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [6]:
import gym

# CartPole-v0 Environment. In this notebook it is only queried for its
# observation/action space sizes when building the network.
env_id = "CartPole-v0"
env = gym.make(env_id)

# Seed the environment. Env.seed() was removed in gym 0.26, where seeding is
# done through reset(seed=...); support both APIs.
if hasattr(env, 'seed'):
    env.seed(seed_value)
else:
    env.reset(seed=seed_value)

In [7]:
class DQN(nn.Module): #base model
    """Fully-connected Q-network with two ReLU hidden layers.

    Args:
        num_inputs: length of the observation vector.
        num_actions: number of discrete actions; also exposed as ``action_dim``.
        HIDDEN_LAYER_WIDTH: width of both hidden layers.

    ``forward`` maps a batch of observations to one Q-value per action.
    """

    def __init__(self, num_inputs, num_actions, HIDDEN_LAYER_WIDTH):
        super().__init__()
        self.action_dim = num_actions

        width = HIDDEN_LAYER_WIDTH
        # The attribute name ``layers`` and this exact module order determine
        # the state_dict keys (layers.0 / layers.2 / layers.4) that saved
        # checkpoints are loaded against.
        self.layers = nn.Sequential(
            nn.Linear(num_inputs, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, num_actions),
        )

    def forward(self, x):
        """Return the Q-values computed by the MLP for input ``x``."""
        q_values = self.layers(x)
        return q_values
    
def get_action(policy_net, device, state, epsilon):
    """Choose an action epsilon-greedily.

    With probability ``1 - epsilon`` the action with the highest Q-value under
    ``policy_net`` is returned; otherwise a uniformly random action index is
    drawn.

    Args:
        policy_net: network exposing ``action_dim`` (as ``DQN`` and
            ``DuelingDQN`` do) whose output has one Q-value per action along
            dim 1 for a batched input.
        device: torch device the state tensor is moved to.
        state: a single observation (sequence of floats).
        epsilon: exploration probability.

    Returns:
        int: the selected action index.
    """
    if random.random() > epsilon:
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(dim=0).to(device)
            q_values = policy_net(state_t)
        return q_values.max(dim=1)[1].item()
    # This is a free function (there is no ``self``); the action count comes
    # from the network itself.
    return random.randrange(policy_net.action_dim)

In [8]:
class DuelingDQN(nn.Module):
    """Dueling Q-network: a shared feature layer feeding separate value and
    advantage streams, recombined as

        Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')

    Args:
        num_inputs: length of the observation vector.
        num_actions: number of discrete actions; also exposed as ``action_dim``.
        HIDDEN_LAYER_WIDTH: accepted for signature parity with ``DQN`` but NOT
            used - every hidden layer is 128 units wide.

    Action selection is done by the module-level ``get_action`` helper.
    """

    def __init__(self, num_inputs, num_actions, HIDDEN_LAYER_WIDTH):
        super(DuelingDQN, self).__init__()
        self.action_dim = num_actions

        # NOTE(review): width is hard-coded to 128 instead of using
        # HIDDEN_LAYER_WIDTH. Kept because saved checkpoints were presumably
        # trained with this architecture and load_state_dict needs matching
        # shapes - confirm against the configs before changing.
        hidden = 128

        # Attribute names and module order fix the state_dict keys
        # (feature.0, advantage.0/2, value.0/2) used when loading checkpoints.
        self.feature = nn.Sequential(
            nn.Linear(num_inputs, hidden),
            nn.ReLU()
        )

        self.advantage = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_actions)
        )

        self.value = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )

    def forward(self, x):
        """Return Q-values with one entry per action along the last axis."""
        x = self.feature(x)
        advantage = self.advantage(x)
        value = self.value(x)
        # Centre the advantages per sample, over the action axis only. A bare
        # advantage.mean() also averages across the batch, which makes each
        # sample's Q-values depend on the rest of the batch; both forms agree
        # for a single state.
        return value + advantage - advantage.mean(dim=-1, keepdim=True)

In [9]:
# Identifiers of the saved training runs. The text before the first '_' is
# used as the experiment key; the full string names the saved model file.
exp_list = [
    # D1QN / DQN
    'D1QN_D1QN_Naive_1000_freq_100_324267_04171400',
    'D1QN-PER_D1QN_NaivePER_1000_freq_100_324267_04164957',
    'DQN_DQN_Naive_1000_freq_100_324267_04171233',
    'DQN-PER_DQN_NaivePER_1000_freq_100_324267_04164826',
    'DQN-PER-original_DQN_NaivePER_100000_freq_1000_324267_04163756',
    # D2QN
    'D2QN_D2QN_Naive_1000_freq_100_324267_04170846',
    'D2QN-PER_D2QN_NaivePER_1000_freq_100_324267_04165444',
    # DuDQN
    'DuDQN_DuDQN_Naive_1000_freq_100_324267_04170649',
    'DuDQN-PER_DuDQN_NaivePER_1000_freq_100_324267_04165300',
    # DuD2QN
    'DuD2QN_DuD2QN_Naive_1000_freq_100_324267_04170223',
    'DuD2QN-PER_DuD2QN_NaivePER_1000_freq_100_324267_04165910',
]

In [10]:
# Experiment key (text before the first '_') -> full run identifier.
# If two runs shared a key, the later one in exp_list would win.
exp_dict = {run_id.partition('_')[0]: run_id for run_id in exp_list}

In [11]:
# Experiment to inspect; must be one of the keys of exp_dict.
experiment = "DuDQN"
log_name = exp_dict[experiment]

In [12]:
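# Optional (disabled): load the replay memories saved for this run.
# NOTE: pickle.load can execute arbitrary code - only load files you trust.
# NOTE: np.bool was removed in NumPy 1.24, so the builtin bool is used below.
# MEM_FILE = './memories/' + log_name + '.mpk'

# # Load Memories
# with open(MEM_FILE, 'rb') as fpr:
#     memories = np.array(list(pickle.load(fpr)))

# visited_states = np.stack(memories[:,0]).squeeze()
# actions = memories[:,1].astype(np.float32)
# rewards = memories[:,2].astype(np.float32)
# next_states = np.stack(memories[:,3]).squeeze()
# done = memories[:,4].astype(bool)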


In [13]:
# FROM CONFIG FILE
# Per-experiment config at ./configs/<experiment>.yaml (originally sys.argv[2]).
config_path = './configs/' + experiment + '.yaml'
# The context manager closes the file handle once the YAML is parsed.
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

USE_GPU = config['USE_GPU']
# Use CUDA only when the config asks for it AND a GPU is actually available.
USE_CUDA = torch.cuda.is_available() and USE_GPU

if USE_CUDA:
    torch.cuda.manual_seed(seed_value)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# MODEL
# Dueling variants use DuelingDQN; DQN, D1QN, D2QN and any unrecognised
# MODEL_NAME fall back to the plain DQN network. Only the policy network is
# needed here to inspect the learned Q-function.
DUELING_MODEL_NAMES = ('DuDQN', 'DuD2QN')
model_cls = DuelingDQN if config['MODEL_NAME'] in DUELING_MODEL_NAMES else DQN
policy_net = model_cls(env.observation_space.shape[0],
                       env.action_space.n,
                       config['HIDDEN_LAYER_WIDTH']).to(device)

# Load learned model parameters.
# map_location lets a checkpoint saved from a GPU run load on a CPU-only
# machine. NOTE: torch.load unpickles the file - only load trusted checkpoints.
MODEL_FILE = './models/' + log_name + '.pth'
policy_net.load_state_dict(torch.load(MODEL_FILE, map_location=device));
policy_net.eval();

In [14]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

In [15]:
# Plot/slider bounds for cart position and pole angle (12 degrees, in
# radians); presumably chosen to match CartPole's episode-termination limits.
cart_pos_threshold = 2.4
theta_threshold = 24 * np.pi / 360  # 12 degrees ~ 0.2094 rad

In [16]:
def get_qvalue(x_pos=0.0, x_vel=0.0, a_pos=0.0, a_vel=0.0):
    """Bar-plot the policy network's Q-value for each action at one state.

    The state fed to ``policy_net`` is ``[x_pos, x_vel, a_pos, a_vel]``
    (cart position, cart velocity, pole angle, pole angular velocity).
    Action 0 is drawn as the red "left" bar and action 1 as the green "right"
    bar, each annotated with its Q-value. Returns None, so ``interact`` shows
    only the figure.
    """
    state = torch.FloatTensor([x_pos, x_vel, a_pos, a_vel]).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        q_values = policy_net(state).cpu().numpy().squeeze(0)

    bar_width = 0.35
    bar_pos = [0, bar_width]
    # Labels sit at a fixed height instead of rect.get_height(), presumably so
    # they stay visible for negative Q-values.
    label_y = 1.05 * 50

    fig, ax = plt.subplots()
    rects = ax.bar(bar_pos, q_values, width=bar_width, color=['r', 'g'])
    ax.set_ylim([-200, 200])
    ax.set_xticks(bar_pos)
    ax.set_xticklabels(['left', 'right'])
    ax.set_ylabel('Q-value')
    ax.set_title('Q-values at the selected state')

    for rect in rects:
        ax.text(rect.get_x() + rect.get_width() / 2., label_y,
                '%.2f' % rect.get_height(),
                ha='center', va='bottom')

In [17]:
# Q-values for each action for each of the state space
# One slider per state component: position and angle are bounded by
# cart_pos_threshold / theta_threshold, both velocities by +/-5.
x_pos_slider = widgets.FloatSlider(min=-cart_pos_threshold, max=cart_pos_threshold,
                                   step=0.1, value=0.0)
x_vel_slider = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0.0)
a_pos_slider = widgets.FloatSlider(min=-theta_threshold, max=theta_threshold,
                                   step=0.001, value=0.0)
a_vel_slider = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0.0)

interact(get_qvalue,
         x_pos=x_pos_slider,
         x_vel=x_vel_slider,
         a_pos=a_pos_slider,
         a_vel=a_vel_slider);

interactive(children=(FloatSlider(value=0.0, description='x_pos', max=2.4, min=-2.4), FloatSlider(value=0.0, d…

In [18]:
# Heatmap of 2d state space q-value function
def get_heatmap(xvel=0.0, avel=0.0):
    """Plot the policy's Q-values over a grid of cart positions x pole angles.

    Cart velocity and pole angular velocity are held at ``xvel`` and ``avel``;
    cart position and pole angle each sweep their threshold range on a
    50-point grid. Three panels are drawn: Q(left), Q(right) and the greedy
    action (argmax: 0 = left in red, 1 = right in green). In every panel rows
    index cart position and columns index pole angle.

    Returns None, so ``interact`` shows only the figure.
    """
    grid_points = 50
    xpos = np.linspace(-cart_pos_threshold, cart_pos_threshold, grid_points)
    apos = np.linspace(-theta_threshold, theta_threshold, grid_points)

    # q[i, j] holds the Q-values at (cart position xpos[i], pole angle apos[j]).
    q = np.zeros([len(xpos), len(apos), 2])
    for i, cart_x in enumerate(xpos):
        for j, pole_angle in enumerate(apos):
            # The angle follows the column index j; indexing apos with the row
            # index i would give every cell in a row the same state.
            state = torch.FloatTensor([cart_x, xvel, pole_angle, avel]).unsqueeze(dim=0).to(device)
            with torch.no_grad():
                q[i, j] = policy_net(state).cpu().numpy().squeeze(0)
    # Ties resolve to the first maximal action, as with a per-state np.argmax.
    greedy_action = np.argmax(q, axis=2)

    q_left = q[:, :, 0]
    q_right = q[:, :, 1]

    fig, ax = plt.subplots(1, 3, figsize=(10, 3))
    sns.heatmap(q_left, ax=ax[0], cbar=True, xticklabels=False, yticklabels=False, vmin=-200, vmax=200, cmap="Reds")
    sns.heatmap(q_right, ax=ax[1], cbar=True, xticklabels=False, yticklabels=False, vmin=-200, vmax=200, cmap="Greens")
    sns.heatmap(greedy_action, ax=ax[2], cbar=True, xticklabels=False, yticklabels=False, vmin=0, vmax=1, cmap=['r', 'g'])

    for panel, title in zip(ax, ['Left', 'Right', 'Greedy action']):
        panel.set_title(title)
        panel.set_xlabel('pole angle')
        panel.set_ylabel('cart position')
    fig.tight_layout()


In [19]:
# Sliders fix the two velocities; get_heatmap sweeps position and angle.
xvel_slider = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0.0)
avel_slider = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0.0)

interact(get_heatmap, xvel=xvel_slider, avel=avel_slider);

interactive(children=(FloatSlider(value=0.0, description='xvel', max=5.0, min=-5.0), FloatSlider(value=0.0, de…