In [None]:
import gym
from collections import deque, namedtuple
import random
from tqdm.notebook import tqdm 
import time
import pandas as pd
import torch
from torch import nn
from torch.optim import SGD, Adam
from torch.distributions.categorical import Categorical

## Vanilla Policy Gradient

In [7]:
def create_model(input_size, output_size):
    return nn.Sequential(
        nn.Linear(input_size, 32),
        nn.Tanh(),
        nn.Linear(32, output_size),
        nn.Identity()
    )

class Agent():
    def __init__(self, env, lr=1e-2, batch_size=5000):
        self.env = env
        self.memory = self._create_memory()
        self.model = create_model(self.env.observation_space.shape[0], self.env.action_space.n)
        self.optim = Adam(self.model.parameters(), lr=lr)
        self.batch_size = batch_size
        self.df = pd.DataFrame()
        self.epoch = 0
        
    def _reset_memory(self):
        self.memory = self._create_memory()
    
    @staticmethod
    def _create_memory():
        m = namedtuple("Memory", ["states", "actions", "rewards", "ep_lens", "ep_rewards", "weights"])
        for key in m._fields:
            setattr(m, key, [])
        return m
    
    def loss_fn(self, states, actions, rewards):
        log_ps = self.get_policy(states).log_prob(actions)
        return -(log_ps * rewards).mean()
    
    def get_policy(self, state):
        return Categorical(logits=self.model(state))
    
    def get_action(self, state):
        return self.get_policy(state).sample().item()
    
    def train(self, epochs, show_every=0):
        pbar = tqdm(range(epochs))
        data = []
        for epoch in pbar:
            if show_every and epoch and not (epoch % show_every): self.play()
            loss, returns, lens = self.train_step()
            row = pd.Series({
                "epoch": self.epoch,
                "loss": loss.item(),
                "max_return": max(returns),
                "max_len": max(lens),
                "avg_return": sum(returns) / len(returns),
                "avg_len": sum(lens) / len(lens),
            })
            self.df = self.df.append(row, ignore_index=True)
            pbar.set_postfix(row)
            self.epoch += 1
        return self.df
            
    def play(self, fps=0):
        state = self.env.reset()
        done = False
        ep_len = 0
        total_reward = 0
        while not done:
            if fps: self.env.render()
            action = self.get_action(torch.as_tensor(state, dtype=torch.float32))
            state, reward, done, _ = self.env.step(action)
            if fps: time.sleep(1/fps)
            
            total_reward += reward
            ep_len += 1
            
        if fps: self.env.close()
        return ep_len, total_reward
    
    def train_step(self):  
        while len(self.memory.states) < self.batch_size:
            state = self.env.reset()
            done = False
            total_reward = 0
            ep_len = 0
            while not done:
                self.memory.states.append(state.copy())
                action = self.get_action(torch.as_tensor(state, dtype=torch.float32))
                state, reward, done, _ = self.env.step(action)
                
                self.memory.actions.append(action)
                self.memory.rewards.append(reward)
                ep_len += 1
                total_reward += reward
                
            self.memory.ep_lens.append(ep_len)
            self.memory.ep_rewards.append(total_reward)
            self.memory.weights += [total_reward] * ep_len
        
        self.optim.zero_grad()
        loss = self.loss_fn(
            torch.as_tensor(self.memory.states, dtype=torch.float32),
            torch.as_tensor(self.memory.actions, dtype=torch.int32),
            torch.as_tensor(self.memory.weights, dtype=torch.float32)
        )
        loss.backward()
        self.optim.step()

        ep_returns = self.memory.ep_rewards
        ep_lens = self.memory.ep_lens
        self._reset_memory()
        return loss, ep_returns, ep_lens
a = Agent(gym.make("CartPole-v0"))

In [9]:
a.train(15)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




Unnamed: 0,avg_len,avg_return,epoch,loss,max_len,max_return
0,21.5,21.5,0.0,18.721119,72.0,72.0
1,24.105769,24.105769,1.0,20.568045,93.0,93.0
2,29.25,29.25,2.0,26.438005,156.0,156.0
3,32.160256,32.160256,3.0,28.966616,117.0,117.0
4,31.936306,31.936306,4.0,25.274162,87.0,87.0
5,35.744681,35.744681,5.0,30.278904,160.0,160.0
6,35.85,35.85,6.0,28.335466,122.0,122.0
7,42.075,42.075,7.0,34.235592,136.0,136.0
8,42.091667,42.091667,8.0,31.296282,109.0,109.0
9,48.153846,48.153846,9.0,36.680809,123.0,123.0


In [103]:
a.play(fps=0)

(127, 127.0)