In [13]:
%load_ext autoreload
%autoreload 2

import gym
from matplotlib import pyplot as plt
from collections import deque
import numpy as np

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
# Prefer the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [15]:
def mish(input):
    """Mish activation, applied elementwise: x * tanh(softplus(x))."""
    gate = torch.tanh(F.softplus(input))
    return gate * input


class Mish(nn.Module):
    """`nn.Module` wrapper around `mish` so it can be used inside `nn.Sequential`."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return mish(input)

In [32]:
class Actor(nn.Module):
    """Gaussian policy for continuous actions.

    An MLP maps a state to per-dimension action means. A single learned vector,
    shared by all states, holds the log standard deviations.

    Args:
        state_dim: length of the observation vector fed to the MLP.
        n_actions: number of continuous action dimensions.
        activation: zero-argument callable returning the hidden-layer nonlinearity module.
    """

    def __init__(self, state_dim, n_actions, activation=nn.Tanh):
        super().__init__()
        self.n_actions = n_actions
        hidden = (512, 128)
        self.model = nn.Sequential(
            nn.Linear(state_dim, hidden[0]), activation(),
            nn.Linear(hidden[0], hidden[1]), activation(),
            nn.Linear(hidden[1], n_actions),
        )
        # log(std) starts at 0.1 for every action dimension (std ~= 1.105).
        self.logstds = nn.Parameter(torch.full((n_actions,), 0.1))

    def forward(self, X):
        mu = self.model(X)
        # Keep the std in a numerically safe range before building the distribution.
        sigma = self.logstds.exp().clamp(min=1e-3, max=50)
        return torch.distributions.Normal(mu, sigma)
    
## Critic module
class Critic(nn.Module):
    """State-value network: an MLP mapping an observation to a single scalar estimate.

    Args:
        state_dim: length of the observation vector.
        activation: zero-argument callable returning the hidden-layer nonlinearity module.
    """

    def __init__(self, state_dim, activation=nn.Tanh):
        super().__init__()
        widths = (state_dim, 1024, 512)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), activation()]
        layers.append(nn.Linear(widths[-1], 1))
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        return self.model(X)

class Memory:
    """Rollout buffer holding one segment of transitions for an actor-critic update.

    Each `add` call appends one transition to six parallel lists. `clear`
    empties them (the training loop calls it after every update).
    """

    def __init__(self):
        self.clear()

    def add(self, log_prob, value, done, reward, action, states):
        """Append a single transition.

        Args:
            log_prob: log-probability of `action` under the policy.
            value: critic estimate for the state the action was taken in.
            done: end-of-episode flag; `update` uses it to cut the discounted return.
            reward: reward received after taking `action`.
            action: the sampled action.
            states: the (single) observation the action was taken in. The
                parameter name is kept unchanged for backward compatibility.
        """
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.dones.append(done)
        self.rewards.append(reward)
        self.actions.append(action)
        # Previously appended the undefined name `state`, which only worked by
        # reading the notebook's global loop variable.
        self.states.append(states)

    def clear(self):
        """Drop every stored transition."""
        self.log_probs = []
        self.values = []
        self.dones = []
        self.rewards = []
        self.actions = []
        self.states = []

    def __len__(self):
        return len(self.log_probs)

In [33]:
def clip_grad_norm_(module, max_grad_norm=0.5):
    """Clip, in place, the total L2 norm of the gradients of all parameters an optimizer manages.

    Args:
        module: despite the name, an optimizer, i.e. anything exposing
            ``param_groups`` (a list of dicts each holding a ``"params"`` entry).
        max_grad_norm: ceiling for the combined gradient norm.
    """
    params = []
    for group in module.param_groups:
        params.extend(group["params"])
    nn.utils.clip_grad_norm_(params, max_grad_norm)

In [34]:
def _discounted_returns(rewards, dones, bootstrap_value, gamma):
    """Compute n-step discounted returns for one rollout segment, oldest transition first.

    The segment is walked backwards starting from `bootstrap_value`. A true
    `done` stops any value from beyond that step leaking in, but the reward of
    that step itself is still counted.

    Returns:
        A list with one return per transition, aligned index-for-index with
        `rewards` / `dones`.
    """
    returns = []
    running = bootstrap_value
    for reward, done in zip(reversed(rewards), reversed(dones)):
        running = float(reward) + gamma * running * (1.0 - float(done))
        returns.append(running)
    # Built newest-first; flip so returns[t] lines up with values[t].
    returns.reverse()
    return returns


def update(mem, next_state, entropy_coef=0.5, max_grad_norm=0.5):
    """Run one advantage actor-critic update on the transitions stored in `mem`.

    Reads the notebook-level `actor`, `critic`, `actor_optim`, `critic_optim`,
    `gamma` and `device`.

    Args:
        mem: `Memory` holding the segment (log_probs, values, dones, rewards, states).
        next_state: observation following the last stored transition; the critic's
            estimate for it bootstraps the returns.
        entropy_coef: weight of the entropy bonus in the actor loss.
        max_grad_norm: gradient-norm ceiling applied to both networks.
    """
    if len(mem) == 0:
        return

    # The bootstrap target must not carry gradients: the critic is regressed
    # towards it, not through it.
    with torch.no_grad():
        bootstrap = critic(torch.tensor(next_state).float().to(device))

    q_values = torch.stack(_discounted_returns(mem.rewards, mem.dones, bootstrap, gamma)).to(device)
    values = torch.stack(mem.values).to(device)
    log_probs = torch.stack(mem.log_probs).to(device)

    advantage = q_values - values

    # Critic: mean squared error to the n-step returns.
    critic_loss = advantage.pow(2).mean()
    critic_optim.zero_grad()
    critic_loss.backward()
    # Clip after backward so the gradients just computed are the ones limited.
    clip_grad_norm_(critic_optim, max_grad_norm)
    critic_optim.step()

    # Actor: the advantage is a fixed weight here, so it is detached; this keeps
    # the actor loss from reaching into the critic graph that was already stepped.
    states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in mem.states]).to(device)
    entropy = actor(states).entropy().mean()
    actor_loss = (-log_probs * advantage.detach()).mean() - entropy_coef * entropy

    actor_optim.zero_grad()
    actor_loss.backward()
    clip_grad_norm_(actor_optim, max_grad_norm)
    actor_optim.step()
    


In [35]:
# The Actor defined above outputs a Gaussian over continuous actions, so only
# environments with a continuous (Box) action space are supported. Discrete tasks
# such as CartPole-v1 or LunarLander-v2 would need a Categorical policy instead.
env = gym.make("Pendulum-v0")


state_space = env.observation_space.shape[0]
assert isinstance(env.action_space, gym.spaces.Box), "the Gaussian Actor needs a continuous (Box) action space"
action_space = env.action_space.shape[0]

In [36]:
# Training hyper-parameters.
num_episodes = 500   # hard cap on the number of training episodes
bootstrap_len = 16   # steps per rollout segment before an update bootstraps from the critic
gamma = 0.99         # discount factor

# Running-average score (over the score window kept by the training loop) that stops training.
# NOTE(review): Pendulum episode scores appear to be always negative (see the logged
# output), so a target of 0 looks unreachable; confirm the intended threshold.
win_condition = 0

In [37]:
# Rollout buffer shared by the training loop and `update`.
mem = Memory()

# Policy (actor) and state-value (critic) networks, both with Mish activations.
# The actor is built first, which keeps the random-initialisation order unchanged.
actor = Actor(state_dim=state_space, n_actions=action_space, activation=Mish).to(device)
critic = Critic(state_dim=state_space, activation=Mish).to(device)

# One Adam optimizer per network, with the same learning rate for both.
actor_optim, critic_optim = (optim.Adam(net.parameters(), lr=1e-3) for net in (actor, critic))

In [38]:
# Per-episode bookkeeping: a sliding window of the last 100 scores for the
# running average, plus full histories for plotting.
running_scores = deque(maxlen=100)
score_log, average_score_log = [], []

In [None]:
for i in range(num_episodes):
    state = env.reset()
    steps = 0
    score = 0

    while True:
        steps += 1

        dists = actor(torch.tensor(state).float().to(device))
        actions = dists.sample()
        # Clip each action dimension to its own bounds, not one global min/max.
        actions_clipped = np.clip(actions.cpu().numpy(), env.action_space.low, env.action_space.high)

        # Log-prob of the unclipped sample, i.e. what the policy actually drew.
        log_probs = dists.log_prob(actions)

        next_state, reward, done, info = env.step(actions_clipped)
        # Count every reward, including the one on the final step of the episode.
        score += reward

        value = critic(torch.tensor(state).float().to(device))

        # A time-limit cut-off is not a true terminal state, so keep bootstrapping
        # from next_state in that case. This relies on gym's TimeLimit wrapper
        # setting info["TimeLimit.truncated"]; without that key it reduces to `done`.
        terminal = done and not info.get("TimeLimit.truncated", False)
        mem.add(log_probs, value, terminal, reward, actions, state)

        # Update at the end of every bootstrap segment and at the end of the episode.
        if done or steps % bootstrap_len == 0:
            update(mem, next_state)
            mem.clear()

        if done:
            break

        state = next_state

    score_log.append(score)
    running_scores.append(score)
    avg_score = np.mean(running_scores)
    average_score_log.append(avg_score)

    print("\rEpisode: {}\taverage: {:.4f}\tReward: {:.4f}".format(i, avg_score, score), end="")

    if i % 100 == 0:
        print("\rEpisode: {}\taverage: {:.4f}\tReward: {:.4f}".format(i, avg_score, score))

    # "Solved" only counts once the score window is full, so one lucky episode cannot end training.
    if len(running_scores) == running_scores.maxlen and avg_score > win_condition:
        print("\rEnvironment Solved!")
        break

Episode: 0.0000	average: -1254.5170	Reward: -1254.5170
Episode: 48.0000	average: -1555.1978	Reward: -1795.3246

In [None]:
# Learning curve: raw per-episode score alongside its running average.
fig, ax = plt.subplots()
ax.plot(score_log, label="episode score")
ax.plot(average_score_log, label="running mean (last <=100 episodes)")
ax.set(title="Training scores", xlabel="episode", ylabel="score")
ax.legend()
plt.show()