In [1]:
import awr_configs
import learning.awr_agent as awr_agent
import gym
import tensorflow as tf
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import torch
import util.rl_path as rl_path
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

In [2]:
# Start from the stock LunarLander preset, but work on a shallow copy so the
# override below does not mutate the shared module-level AWR_CONFIGS entry
# (which would leak into anything else that reads the preset later).
configs = dict(awr_configs.AWR_CONFIGS['LunarLanderContinuous-v2'])
# Exploration noise handed to the agent constructor via **configs.
configs["action_std"] = 0.2

In [4]:
# Build the environment and a dedicated TensorFlow graph/session, then
# construct the AWR agent inside that session using the config overrides
# prepared above.
env = gym.make("LunarLanderContinuous-v2")
graph = tf.Graph()
# NOTE(review): `tf.Session` is the TF1-style API; the session is never
# closed explicitly in this notebook.
sess = tf.Session(graph=graph)
agent = awr_agent.AWRAgent(env=env, sess=sess, **configs)

In [5]:
# Restore the agent's parameters from a previously saved checkpoint.
# The path is relative to the notebook's working directory.
agent.load_model("output/model.ckpt")

INFO:tensorflow:Restoring parameters from output/model.ckpt
Model loaded from: output/model.ckpt


In [6]:
def sample_action(agent, s, action_std):
    """Sample actions from a Gaussian centred on the agent's policy mean.

    The mean is obtained by evaluating the "loc" parameter of
    `agent._norm_a_pd_tf` in the agent's session; noise with standard
    deviation `action_std` is then drawn with `torch.distributions.Normal`.

    Args:
        agent: object exposing `get_state_size()`, `_s_tf`, `_sess` and
            `_norm_a_pd_tf.parameters["loc"]`.
        s: array holding one state (1-D) or several states; it is reshaped
            to `(-1, agent.get_state_size())` before being fed to the graph.
        action_std: scale of the Normal distribution used for sampling.

    Returns:
        numpy array of sampled actions, one row per state. When `s` was
        1-D, the single action row is returned without the batch axis.
    """
    # Remember whether the caller passed a lone state before reshaping.
    single_state = len(s.shape) == 1
    batch = np.reshape(s, [-1, agent.get_state_size()])

    feed = {agent._s_tf: batch}
    fetches = [agent._norm_a_pd_tf.parameters["loc"]]
    mean = torch.tensor(agent._sess.run(fetches, feed_dict=feed)[0])

    noisy = torch.distributions.Normal(mean, scale=action_std).sample()
    # Round-trip through Python floats, exactly as callers have come to expect.
    actions = np.array(noisy.tolist())

    return actions[0] if single_state else actions

In [7]:
def rollout_path(agent, action_std):
    """Roll out one episode, acting with `sample_action`.

    Args:
        agent: agent providing `_env.reset()`, `_step_env(action)` and
            `_check_env_termination()`.
        action_std: standard deviation forwarded to `sample_action`.

    Returns:
        An `rl_path.RLPath` whose `states` holds one more entry than
        `actions` / `rewards` (the initial state plus one state per step),
        with `terminate` set from `agent._check_env_termination()`.
    """
    trajectory = rl_path.RLPath()

    state = np.array(agent._env.reset())
    trajectory.states.append(state)

    while True:
        action = sample_action(agent, state, action_std)
        next_state, reward, done, _ = agent._step_env(action)
        state = np.array(next_state)

        trajectory.states.append(state)
        trajectory.actions.append(action)
        trajectory.rewards.append(reward)

        # The step that ends the episode is still recorded above.
        if done:
            break

    trajectory.terminate = agent._check_env_termination()
    return trajectory

In [8]:
def gather_data(num_episodes, agent, action_std):
    """Collect `num_episodes` rollouts as per-episode transition arrays.

    Args:
        num_episodes: number of episodes to roll out.
        agent: agent forwarded to `rollout_path`.
        action_std: sampling standard deviation forwarded to `rollout_path`.

    Returns:
        List with one `(I, R, S2)` tuple per episode, where
        `I` stacks each state with the action taken in it (column-wise),
        `R` is the episode's reward list, and
        `S2` holds the corresponding next states.
    """
    def _to_transitions(path):
        # states has one more row than actions: drop the last row for the
        # inputs and the first row for the successors.
        states = np.array(path.states)
        inputs = np.hstack([states[:-1], np.array(path.actions)])
        return inputs, path.rewards, states[1:]

    return [
        _to_transitions(rollout_path(agent, action_std))
        for _ in tqdm(range(num_episodes))
    ]

In [12]:
# Collect the offline dataset: 10,000 episodes sampled with action std 0.2
# (the same value written to configs["action_std"] earlier in the notebook).
data = gather_data(10000, agent, 0.2)

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))




In [None]:
# import pickle
# pickle.dump(data, open("data.pkl", "wb"))
# data = pickle.load(open("data.pkl", "rb"))

In [14]:
import gym
import numpy as np
import random
from sklearn.ensemble import ExtraTreesRegressor
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore")

In [15]:
class Policy(nn.Module):
    """Stochastic policy that delegates to `sample_action`.

    Calling the module with a batch of states returns actions sampled
    around the wrapped agent's policy mean with standard deviation
    `action_std`. The module holds no trainable parameters of its own.
    """
    def __init__(self, agent, action_std):
        super(Policy, self).__init__()
        self.action_std = action_std
        self.agent = agent

    def forward(self, states):
        """Return sampled actions for `states` (see `sample_action`)."""
        return sample_action(agent=self.agent, s=states, action_std=self.action_std)

In [71]:
import torch
import torch.nn as nn

class Q(nn.Module):
    """Q-network: an MLP mapping concatenated (state, action) rows to a scalar.

    The class also exposes a minimal `predict` / `fit` pair so it can be
    used where a regressor with those two methods is expected.

    Inputs may be numpy arrays, nested lists or tensors. They are converted
    to float32 on whatever device the network's parameters currently live
    on, so the same code runs on CPU and on GPU (after `.cuda()` / `.to()`).

    Args:
        state_dim: number of state features per row.
        action_dim: number of action features per row.
        lr: learning rate for the Adam optimizer.
    """
    def __init__(self, state_dim, action_dim, lr):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.model = nn.Sequential(
            nn.Linear(self.state_dim + self.action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def _to_input(self, values):
        """Convert `values` to a float32 tensor on the model's device."""
        # Follow the parameters instead of hard-coding CUDA; as_tensor also
        # avoids re-copying (and warning about) inputs that are already tensors.
        device = next(self.model.parameters()).device
        return torch.as_tensor(values, dtype=torch.float32, device=device)

    def forward(self, state):
        """Return the network output, shape (..., 1), with gradients enabled."""
        return self.model(self._to_input(state))

    def predict(self, state):
        """Forward without gradients (used for predictions)."""
        with torch.no_grad():
            return self.model(self._to_input(state))

    def fit(self, state, true_value):
        """Take a single Adam step on the MSE between output and `true_value`.

        Args:
            state: (state, action) rows, one per sample.
            true_value: regression targets, one per sample.

        Returns:
            The loss value (Python float) computed before the update.
        """
        self.optimizer.zero_grad()
        # Drop only the trailing output axis so a one-row batch keeps its
        # batch dimension.
        out = self(state).squeeze(-1)
        # Align the targets with the output so e.g. (N, 1) targets cannot
        # silently broadcast to an (N, N) loss.
        true_value = self._to_input(true_value).reshape(out.shape)
        loss = self.criterion(out, true_value)
        loss.backward()
        self.optimizer.step()
        return loss.item()

In [72]:
class FittedQEvaluation(object):
    """Fitted Q Evaluation: regress Q(s, a) of a policy from logged episodes.

    The regressor only needs `predict(X)` and `fit(X, y)`; predictions may
    be numpy arrays or torch tensors.

    NOTE(review): `fit_Q` calls `predict` before the first `fit`, so the
    regressor must be able to predict while untrained. The default
    `ExtraTreesRegressor` presumably cannot - verify before relying on it.
    """
    def __init__(self, regressor=None):
        # Compare against None explicitly: truthiness would call __len__ on
        # regressors that define it (and misfire for "empty" ones).
        self.regressor = regressor if regressor is not None else ExtraTreesRegressor()

    def Q(self, state_actions):
        """Return the regressor's Q estimate for each (state, action) row."""
        return self.regressor.predict(state_actions)

    @staticmethod
    def _as_flat_numpy(values):
        """Flatten a prediction (torch tensor or array-like) to a 1-D numpy array."""
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    @staticmethod
    def _make_batches(episodes, num_batches):
        """Group episodes into at most `num_batches` contiguous, stacked batches."""
        if num_batches < 1:
            raise ValueError("num_batches must be at least 1")
        if len(episodes) == 0:
            raise ValueError("episodes must contain at least one episode")

        batches = []
        # array_split keeps every episode (no dropped remainder) and yields
        # empty chunks when there are fewer episodes than batches.
        for idx in np.array_split(np.arange(len(episodes)), num_batches):
            if len(idx) == 0:
                continue
            inputs, rewards, next_states = zip(*episodes[idx[0]:idx[-1] + 1])
            batches.append((
                np.concatenate(inputs, 0),
                np.concatenate(rewards, 0),
                np.concatenate(next_states, 0),
            ))
        return batches

    def fit_Q(self, eval_policy, episodes, num_iters=100, discount=0.95, num_batches=10):
        """Iteratively fit the regressor to one-step bootstrapped targets.

        Args:
            eval_policy: callable mapping a batch of next states to actions
                (one row per state).
            episodes: sequence of `(I, R, S2)` tuples, one per episode:
                `I` the (state, action) rows, `R` the rewards and `S2` the
                next states, all of equal length within an episode.
            num_iters: number of passes over all batches.
            discount: discount factor applied to the bootstrapped value.
            num_batches: number of batches the episodes are split into.
                When `len(episodes)` is a multiple of it, batches are equal
                contiguous slices; otherwise sizes differ by at most one
                episode and no episode is dropped.

        Raises:
            ValueError: if `episodes` is empty or `num_batches < 1`.

        NOTE(review): the target bootstraps from every next state,
        including the last one of each episode; the episode tuples carry no
        terminal flag, so terminal transitions are not masked - confirm
        whether that is intended.
        """
        batches = self._make_batches(episodes, num_batches)

        for _ in tqdm(range(num_iters)):
            for (Is, Rs, S2s) in batches:
                # Actions are re-sampled on every pass over the batch.
                pi_S2s = eval_policy(S2s)
                S2pi_S2s = np.hstack([S2s, pi_S2s])
                next_q = self._as_flat_numpy(self.Q(S2pi_S2s))
                Os = Rs + discount * next_q
                self.regressor.fit(Is, Os)

In [73]:
# Place the network on the GPU when one is available instead of demanding
# CUDA unconditionally at construction time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Input width is state size + action size; 0.001 is the Adam learning rate.
qnn = Q(agent.get_state_size(), agent.get_action_size(), 0.001).to(device)

In [74]:
# Use the neural Q-network as the regressor behind Fitted Q Evaluation.
FQE = FittedQEvaluation(qnn)
# Policy to evaluate: same agent, but sampled with std 0.1, whereas the
# dataset above was gathered with std 0.2.
policy = Policy(agent, 0.1)

In [75]:
# Run 200 fitting iterations over the collected episodes, using the
# agent's own discount factor for the bootstrapped targets.
FQE.fit_Q(policy, data, 200, agent._discount)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))




In [None]:
# For 100 fresh rollouts (action std 0.1), pair the realised discounted
# return with the learned Q estimate at the first (state, action).
vals0 = []

for _ in tqdm(range(100)):
    path = rollout_path(agent, 0.1)

    gamma = agent._discount
    true = sum(r * (gamma ** i) for i, r in enumerate(path.rewards))

    first_sa = np.hstack([path.states[0], path.actions[0]]).reshape(1, -1)
    pred = FQE.regressor.predict(first_sa)[0].item()

    vals0.append([true, pred])

# vals1 = []
# for _ in tqdm(range(500)):
#     path = rollout_path(agent, 0.4)
#     true = sum([r * (agent._discount ** i) for i,r in enumerate(path.rewards)])
#     pred = FQE.Q(np.hstack([path.states[0], path.actions[0]]).reshape(1,-1))[0]
#     vals1.append([true, pred])

In [None]:
# Scatter of realised discounted return (x) against the FQE estimate (y).
true_returns = [true_value for true_value, _ in vals0]
predictions = [predicted for _, predicted in vals0]
plt.scatter(true_returns, predictions, color="r")
# plt.hold()
# plt.scatter([val[0] for val in vals1], [val[1] for val in vals1], color="b")
plt.xlabel("True Value")
plt.ylabel("Prediction")