In [1]:
import gym
from collections import deque, namedtuple
import random
import numpy as np
from tqdm.notebook import tqdm 
import time
import altair as alt
import pandas as pd
import torch
from torch import nn
from torch.optim import SGD, Adam
from torch.distributions.categorical import Categorical
import toolz
import sys

if "./" not in sys.path:
    sys.path.insert(0, "./")

from core import create_model, reward_to_go, discount_cumsum

In [2]:
class ValueFunction():
    """State-value estimator V(s) trained by regression onto target returns.

    Wraps a network built by ``create_model`` with a single output, an Adam optimizer
    and an MSE loss.
    """

    def __init__(self, input_size, lr=1e-3):
        """
        Args:
            input_size: length of a state vector fed to the network.
            lr: Adam learning rate (1e-3 is Adam's own default, so omitting it keeps
                the previous behaviour).
        """
        self.model = create_model(input_size, 1)
        self.optim = Adam(self.model.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()

    def __call__(self, x):
        """Return the raw network output for state(s) ``x``."""
        return self.model(x)

    def fit(self, states, rewards):
        """Take one optimizer step minimizing MSE between V(states) and ``rewards``.

        Args:
            states: tensor of states accepted by the network.
            rewards: target values, one per state; any shape with as many elements
                as there are predictions (e.g. ``(N,)`` or ``(N, 1)``).

        Returns:
            The loss before the step, as a Python float.
        """
        self.optim.zero_grad()
        # The network is built with a single output, so a batch presumably comes out
        # as (N, 1). Match it to the target's shape explicitly: otherwise MSELoss would
        # broadcast (N, 1) against (N,) into an (N, N) comparison.
        predictions = self(states).reshape_as(rewards)
        loss = self.loss_fn(predictions, rewards)
        loss.backward()
        self.optim.step()
        return loss.item()

## Vanilla Policy Gradient

In [3]:
class Agent():
    """Policy-gradient agent for a gym environment with a discrete action space.

    The policy is a ``Categorical`` distribution over the logits produced by a network
    from ``create_model``. Each ``train_step`` collects at least ``batch_size`` transitions
    and takes one gradient step on the policy, weighting log-probabilities by
    bootstrapped discounted rewards-to-go. GAE advantages are also computed and stored
    per step, but ``fit`` does not use them yet, and ``value_model`` is not trained here.
    """

    def __init__(self, env, lr=1e-2, gamma=0.999, lam=0.95, batch_size=5000):
        """
        Args:
            env: gym environment; ``observation_space.shape[0]`` sizes the network input
                and ``action_space.n`` the number of actions.
            lr: Adam learning rate for the policy network.
            gamma: discount factor.
            lam: GAE lambda used by ``_calculate_advantage``.
            batch_size: minimum number of transitions collected per ``train_step``.
        """
        self.env = env
        self.memory = self._create_memory()
        obs_size = self.env.observation_space.shape[0]
        self.model = create_model(obs_size, self.env.action_space.n)
        self.value_model = ValueFunction(obs_size)
        self.optim = Adam(self.model.parameters(), lr=lr)
        self.batch_size = batch_size
        self.df = pd.DataFrame()  # one row of training statistics per trained epoch
        self.epoch = 0  # epochs trained so far, across all calls to train()
        self.gamma = gamma
        self.lam = lam

    def _reset_memory(self):
        """Discard all stored transitions."""
        self.memory = self._create_memory()

    @staticmethod
    def _create_memory():
        """Return a fresh record with empty ``states``/``actions``/``rewards``/``advantages`` lists.

        A new namedtuple *class* is created on every call and its field attributes are
        overwritten with plain lists. The class object itself is then used as a mutable
        record (``memory.states += [...]`` works), while ``_fields`` still lists the keys.
        """
        m = namedtuple("Memory", ["states", "actions", "rewards", "advantages"])
        for key in m._fields:
            setattr(m, key, [])
        return m

    def _print_memory_lens(self):
        """Print how many entries are stored under each memory field."""
        for key in self.memory._fields:
            print(f"{key}: {len(getattr(self.memory, key))}")

    def loss_fn(self, states, actions, rewards):
        """Policy-gradient surrogate loss ``-mean(log pi(a|s) * weight)``.

        ``rewards`` holds the per-sample weights (here: discounted rewards-to-go).
        """
        log_ps = self.get_policy(states).log_prob(actions)
        return -(log_ps * rewards).mean()

    def get_policy(self, state):
        """Return the categorical action distribution for ``state`` (single or batched)."""
        return Categorical(logits=self.model(state))

    def get_action(self, state):
        """Sample an action index from the current policy for a single state tensor."""
        # Sampling needs no gradients; no_grad avoids building a graph every env step.
        with torch.no_grad():
            return self.get_policy(state).sample().item()

    def train(self, epochs, show_every=0):
        """Run ``epochs`` training steps and return the accumulated statistics table.

        Args:
            epochs: number of ``train_step`` calls.
            show_every: if non-zero, render one episode with ``play`` every
                ``show_every`` epochs (epoch 0 is skipped).

        Returns:
            ``self.df``, with one row per epoch ever trained: epoch index, loss, and the
            min/max/avg episode return and length within that epoch's batch.
        """
        pbar = tqdm(range(epochs))
        rows = []
        try:
            for epoch in pbar:
                if show_every and epoch and not (epoch % show_every):
                    self.play()
                loss, returns, lens = self.train_step()
                row = {
                    "epoch": self.epoch,
                    "loss": loss,
                    "min_return": min(returns),
                    "min_len": min(lens),
                    "max_return": max(returns),
                    "max_len": max(lens),
                    "avg_return": sum(returns) / len(returns),
                    "avg_len": sum(lens) / len(lens),
                }
                rows.append(row)
                pbar.set_postfix(row)
                self.epoch += 1
        finally:
            # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, and growing
            # a frame row by row is quadratic; add all new rows at once instead. Doing it in
            # `finally` keeps completed epochs if the loop is interrupted.
            if rows:
                new_rows = pd.DataFrame(rows)
                if self.df.empty:
                    self.df = new_rows
                else:
                    self.df = pd.concat([self.df, new_rows], ignore_index=True)
        return self.df

    def play(self, fps=60):
        """Run one episode with the current policy.

        Args:
            fps: when non-zero, render each frame, sleep ``1/fps`` seconds between steps
                and close the environment afterwards; ``fps=0`` runs without rendering.

        Returns:
            ``(episode_length, total_reward)``.
        """
        state = self.env.reset()
        done = False
        ep_len = 0
        total_reward = 0
        while not done:
            if fps:
                self.env.render()
            action = self.get_action(torch.as_tensor(state, dtype=torch.float32))
            state, reward, done, _ = self.env.step(action)
            if fps:
                time.sleep(1 / fps)
            total_reward += reward
            ep_len += 1
        if fps:
            self.env.close()
        return ep_len, total_reward

    def fit(self, states, actions, weights):
        """Take one policy optimizer step on ``loss_fn`` and return the loss as a float."""
        self.optim.zero_grad()
        loss = self.loss_fn(states, actions, weights)
        loss.backward()
        self.optim.step()
        return loss.item()

    def _value_of(self, obs):
        """Scalar value estimate for one observation tensor, computed without autograd."""
        with torch.no_grad():
            return self.value_model(obs).item()

    def _calculate_advantage(self, rewards, values, last_value):
        """Generalized Advantage Estimation for one episode.

        Args:
            rewards: array of rewards r_0 .. r_{T-1}.
            values: array of V(s_0) .. V(s_{T-1}), the value of the state each action
                was taken in.
            last_value: V(s_T), used to bootstrap a truncated episode (0 if terminal).

        Returns:
            T advantages, one per step (same length as ``rewards``), assuming
            ``discount_cumsum(x, d)[t] == sum_k d**k * x[t + k]``.
        """
        # Append the bootstrap value so every step t has a successor value V(s_{t+1}).
        values = np.append(values, last_value)
        deltas = rewards + self.gamma * values[1:] - values[:-1]
        return discount_cumsum(deltas, self.gamma * self.lam)

    def _calculate_discounted_rewards(self, rewards, last_value):
        """Discounted rewards-to-go for one episode, bootstrapped with ``last_value``.

        Returns one value per entry of ``rewards``.
        """
        rewards = np.append(rewards, last_value)
        return discount_cumsum(rewards, self.gamma)[:-1]

    def train_step(self):
        """Collect a batch of full episodes with the current policy and update it once.

        Episodes are run to completion until at least ``batch_size`` transitions are
        stored, so a batch can exceed ``batch_size`` by up to one episode.

        Returns:
            ``(loss, ep_returns, ep_lens)``: the policy loss, plus each collected
            episode's undiscounted return and length.
        """
        # Start from an empty buffer so an interrupted earlier call cannot mix stale
        # transitions from an older policy into this batch.
        self._reset_memory()
        ep_lens = []
        ep_returns = []
        while len(self.memory.states) < self.batch_size:
            state = self.env.reset()
            done = False
            states, actions, ep_rewards, values = [], [], [], []
            while not done:
                obs = torch.as_tensor(state, dtype=torch.float32)
                states.append(state.copy())
                # V(s_t) of the state the action is taken from, aligned with ep_rewards[t].
                values.append(self._value_of(obs))
                action = self.get_action(obs)
                state, reward, done, _ = self.env.step(action)
                actions.append(action)
                ep_rewards.append(reward)

            # Bootstrap with V(s_T) only when the episode stopped because the step limit
            # was reached (read from the env's private `_elapsed_steps`/`_max_episode_steps`);
            # otherwise s_T is terminal and contributes 0.
            truncated = self.env._elapsed_steps == self.env._max_episode_steps
            if truncated:
                last_value = self._value_of(torch.as_tensor(state, dtype=torch.float32))
            else:
                last_value = 0

            rewards_arr = np.asarray(ep_rewards)
            self.memory.states += states
            self.memory.actions += actions
            self.memory.rewards += self._calculate_discounted_rewards(rewards_arr, last_value).tolist()
            self.memory.advantages += self._calculate_advantage(
                rewards_arr, np.asarray(values), last_value
            ).tolist()

            ep_lens.append(len(ep_rewards))
            ep_returns.append(sum(ep_rewards))

        loss = self.fit(
            # Stack into one ndarray first: building a tensor from a list of arrays is slow.
            states=torch.as_tensor(np.asarray(self.memory.states), dtype=torch.float32),
            actions=torch.as_tensor(self.memory.actions, dtype=torch.int64),
            weights=torch.as_tensor(self.memory.rewards, dtype=torch.float32),
        )

        self._reset_memory()
        return loss, ep_returns, ep_lens
    
# Seed every RNG up front so Restart & Run All reproduces the same training run:
# torch drives weight initialisation and action sampling, the env seed drives its
# initial states; python/numpy are seeded for completeness.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = gym.make("CartPole-v0")
env.seed(SEED)
a = Agent(env)



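## TODO
* Weight the policy-gradient loss by the GAE advantages (they are computed in `train_step` but `fit` currently uses the discounted rewards-to-go)
* Train the value function (`value_model` is used to compute values but is never fitted)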

In [4]:
# Train for 90 epochs (one policy update per epoch), rendering an episode every 15 epochs;
# `results` is the per-epoch statistics table (loss, min/max/avg return and length).
results = a.train(90, show_every=15)

HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




In [5]:
# Lines: per-epoch policy loss, average return and average episode length.
# Shaded band: min-max episode return within each epoch, drawn from the un-melted
# frame so each epoch contributes a single point instead of one per melted series.
line_data = results.melt(
    id_vars=["epoch"],
    value_vars=["loss", "avg_return", "avg_len"],
    var_name="type",
)
lines = alt.Chart(line_data).mark_line().encode(
    x=alt.X("epoch:Q", title="Epoch"),
    y=alt.Y("value:Q", title="Value"),
    color="type:N",
)
return_band = alt.Chart(results).mark_area(color="blue", opacity=0.3).encode(
    x=alt.X("epoch:Q", title="Epoch"),
    y=alt.Y("min_return:Q", title="Value"),
    y2="max_return:Q",
)
(lines + return_band).properties(title="CartPole training progress")

In [7]:
# Render one episode with the trained policy; the output is (episode_length, total_reward).
a.play()

(200, 200.0)