In [1]:
import numpy as np
import torch
import gymnasium as gym
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from torch.utils import tensorboard
from tqdm import tqdm

In [2]:
def mish(input):
    return input * torch.tanh(F.softplus(input))

class Mish(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, input): return mish(input)

In [3]:
# helper function to convert numpy arrays to tensors
def t(x): return torch.from_numpy(x).float()

In [4]:
# Actor module, categorical actions only
class Actor(nn.Module):
    def __init__(self, state_dim, n_actions, activation=nn.Tanh):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64),
            activation(),
            nn.Linear(64, 32),
            activation(),
            nn.Linear(32, n_actions),
            nn.Softmax()
        )
    
    def forward(self, X):
        return self.model(X)

In [5]:
# Critic module
class Critic(nn.Module):
    def __init__(self, state_dim, activation=nn.Tanh):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64),
            activation(),
            nn.Linear(64, 32),
            activation(),
            nn.Linear(32, 1)
        )
    
    def forward(self, X):
        return self.model(X)

In [6]:
env = gym.make("CartPole-v1")

In [7]:
# config
state_dim = env.observation_space.shape[0]
n_actions = env.action_space.n
actor = Actor(state_dim, n_actions, activation=Mish)
critic = Critic(state_dim, activation=Mish)
adam_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)
adam_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

torch.manual_seed(1)

<torch._C.Generator at 0x750a1a590610>

In [8]:
def clip_grad_norm_(module, max_grad_norm):
    nn.utils.clip_grad_norm_([p for g in module.param_groups for p in g["params"]], max_grad_norm)

def policy_loss(old_log_prob, log_prob, advantage, eps):
    ratio = (log_prob - old_log_prob).exp()
    clipped = torch.clamp(ratio, 1-eps, 1+eps)*advantage
    
    m = torch.min(ratio*advantage, clipped)
    return -m

In [9]:
episode_rewards = []
gamma = 0.98
eps = 0.2
log = False
if log:
    w = tensorboard.SummaryWriter('../runs/x')
s = 0
max_grad_norm = 0.5

for i in tqdm(range(800)):
    prev_prob_act = None
    done = False
    total_reward = 0
    state, _ = env.reset()


    while not done:
        s += 1
        probs = actor(t(state))
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample()
        prob_act = dist.log_prob(action)
        
        next_state, reward, done, _, info = env.step(action.detach().numpy())
        advantage = reward + (1-done)*gamma*critic(t(next_state)) - critic(t(state))
        
        if log:
            w.add_scalar("loss/advantage", advantage, global_step=s)
            w.add_scalar("actions/action_0_prob", dist.probs[0], global_step=s)
            w.add_scalar("actions/action_1_prob", dist.probs[1], global_step=s)
        
        total_reward += reward
        state = next_state
        
        if prev_prob_act:
            actor_loss = policy_loss(prev_prob_act.detach(), prob_act, advantage.detach(), eps)
            if log:
                w.add_scalar("loss/actor_loss", actor_loss, global_step=s)
            adam_actor.zero_grad()
            actor_loss.backward()
            # clip_grad_norm_(adam_actor, max_grad_norm)
            if log:
                w.add_histogram("gradients/actor",
                                torch.cat([p.grad.view(-1) for p in actor.parameters()]), global_step=s)
            adam_actor.step()

            critic_loss = advantage.pow(2).mean()
            if log:
                w.add_scalar("loss/critic_loss", critic_loss, global_step=s)
            adam_critic.zero_grad()
            critic_loss.backward()
            # clip_grad_norm_(adam_critic, max_grad_norm)
            if log:
                w.add_histogram("gradients/critic",
                             torch.cat([p.data.view(-1) for p in critic.parameters()]), global_step=s)
            adam_critic.step()
        
        prev_prob_act = prob_act
    
    if log:
        w.add_scalar("reward/episode_reward", total_reward, global_step=i)
    episode_rewards.append(total_reward)

  return self._call_impl(*args, **kwargs)
  0%|          | 1/800 [00:00<03:06,  4.28it/s]

tensor(1.0271, grad_fn=<MeanBackward0>)
tensor(0.9866, grad_fn=<MeanBackward0>)
tensor(1.0155, grad_fn=<MeanBackward0>)
tensor(0.9889, grad_fn=<MeanBackward0>)
tensor(0.9775, grad_fn=<MeanBackward0>)
tensor(1.0272, grad_fn=<MeanBackward0>)
tensor(1.0151, grad_fn=<MeanBackward0>)
tensor(1.0017, grad_fn=<MeanBackward0>)
tensor(0.9879, grad_fn=<MeanBackward0>)
tensor(1.0148, grad_fn=<MeanBackward0>)
tensor(0.9789, grad_fn=<MeanBackward0>)
tensor(0.9660, grad_fn=<MeanBackward0>)
tensor(0.9528, grad_fn=<MeanBackward0>)
tensor(0.9392, grad_fn=<MeanBackward0>)
tensor(1.2088, grad_fn=<MeanBackward0>)
tensor(0.9803, grad_fn=<MeanBackward0>)
tensor(1.0132, grad_fn=<MeanBackward0>)
tensor(0.9826, grad_fn=<MeanBackward0>)
tensor(1.0093, grad_fn=<MeanBackward0>)
tensor(0.9844, grad_fn=<MeanBackward0>)
tensor(1.0064, grad_fn=<MeanBackward0>)
tensor(0.9927, grad_fn=<MeanBackward0>)
tensor(0.9788, grad_fn=<MeanBackward0>)
tensor(0.9652, grad_fn=<MeanBackward0>)
tensor(0.9520, grad_fn=<MeanBackward0>)


  0%|          | 2/800 [00:00<03:15,  4.08it/s]

tensor(1.1047, grad_fn=<MeanBackward0>)
tensor(1.0975, grad_fn=<MeanBackward0>)
tensor(1.0865, grad_fn=<MeanBackward0>)
tensor(1.0729, grad_fn=<MeanBackward0>)
tensor(0.9158, grad_fn=<MeanBackward0>)
tensor(1.0731, grad_fn=<MeanBackward0>)
tensor(1.0586, grad_fn=<MeanBackward0>)
tensor(1.0433, grad_fn=<MeanBackward0>)
tensor(0.9377, grad_fn=<MeanBackward0>)
tensor(0.9237, grad_fn=<MeanBackward0>)
tensor(0.8528, grad_fn=<MeanBackward0>)
tensor(1.0405, grad_fn=<MeanBackward0>)
tensor(1.0221, grad_fn=<MeanBackward0>)
tensor(0.9662, grad_fn=<MeanBackward0>)
tensor(0.9548, grad_fn=<MeanBackward0>)
tensor(1.0308, grad_fn=<MeanBackward0>)
tensor(0.9570, grad_fn=<MeanBackward0>)


  0%|          | 3/800 [00:00<03:09,  4.21it/s]

tensor(1.0277, grad_fn=<MeanBackward0>)
tensor(0.9586, grad_fn=<MeanBackward0>)
tensor(1.0252, grad_fn=<MeanBackward0>)
tensor(1.0093, grad_fn=<MeanBackward0>)
tensor(0.9743, grad_fn=<MeanBackward0>)
tensor(1.0066, grad_fn=<MeanBackward0>)
tensor(0.9764, grad_fn=<MeanBackward0>)
tensor(0.9641, grad_fn=<MeanBackward0>)
tensor(1.0186, grad_fn=<MeanBackward0>)
tensor(0.9640, grad_fn=<MeanBackward0>)
tensor(1.0180, grad_fn=<MeanBackward0>)
tensor(0.9636, grad_fn=<MeanBackward0>)
tensor(0.9513, grad_fn=<MeanBackward0>)
tensor(1.0324, grad_fn=<MeanBackward0>)
tensor(1.0189, grad_fn=<MeanBackward0>)
tensor(0.9606, grad_fn=<MeanBackward0>)
tensor(1.0188, grad_fn=<MeanBackward0>)
tensor(0.9596, grad_fn=<MeanBackward0>)
tensor(0.9466, grad_fn=<MeanBackward0>)
tensor(1.0337, grad_fn=<MeanBackward0>)
tensor(1.0210, grad_fn=<MeanBackward0>)
tensor(0.9543, grad_fn=<MeanBackward0>)
tensor(1.0217, grad_fn=<MeanBackward0>)
tensor(0.9519, grad_fn=<MeanBackward0>)
tensor(0.9382, grad_fn=<MeanBackward0>)


  1%|          | 6/800 [00:00<01:31,  8.64it/s]

tensor(0.9655, grad_fn=<MeanBackward0>)
tensor(0.9520, grad_fn=<MeanBackward0>)
tensor(0.9390, grad_fn=<MeanBackward0>)
tensor(0.9265, grad_fn=<MeanBackward0>)
tensor(0.9142, grad_fn=<MeanBackward0>)
tensor(1.0650, grad_fn=<MeanBackward0>)
tensor(1.0565, grad_fn=<MeanBackward0>)
tensor(1.0452, grad_fn=<MeanBackward0>)
tensor(1.0318, grad_fn=<MeanBackward0>)
tensor(0.9221, grad_fn=<MeanBackward0>)
tensor(0.3350, grad_fn=<MeanBackward0>)
tensor(0.9514, grad_fn=<MeanBackward0>)
tensor(1.0123, grad_fn=<MeanBackward0>)
tensor(0.9922, grad_fn=<MeanBackward0>)
tensor(0.9720, grad_fn=<MeanBackward0>)
tensor(0.9897, grad_fn=<MeanBackward0>)
tensor(0.9652, grad_fn=<MeanBackward0>)
tensor(0.9948, grad_fn=<MeanBackward0>)
tensor(0.9810, grad_fn=<MeanBackward0>)
tensor(0.9671, grad_fn=<MeanBackward0>)
tensor(0.9533, grad_fn=<MeanBackward0>)
tensor(1.0065, grad_fn=<MeanBackward0>)
tensor(0.9901, grad_fn=<MeanBackward0>)
tensor(0.9724, grad_fn=<MeanBackward0>)
tensor(0.9544, grad_fn=<MeanBackward0>)


  1%|▏         | 11/800 [00:01<00:55, 14.17it/s]

tensor(0.9051, grad_fn=<MeanBackward0>)
tensor(1.0288, grad_fn=<MeanBackward0>)
tensor(0.9029, grad_fn=<MeanBackward0>)
tensor(0.8848, grad_fn=<MeanBackward0>)
tensor(1.0534, grad_fn=<MeanBackward0>)
tensor(1.0329, grad_fn=<MeanBackward0>)
tensor(0.8963, grad_fn=<MeanBackward0>)
tensor(0.8780, grad_fn=<MeanBackward0>)
tensor(1.0560, grad_fn=<MeanBackward0>)
tensor(0.8728, grad_fn=<MeanBackward0>)
tensor(0.8555, grad_fn=<MeanBackward0>)
tensor(0.8393, grad_fn=<MeanBackward0>)
tensor(1.1008, grad_fn=<MeanBackward0>)
tensor(0.0536, grad_fn=<MeanBackward0>)
tensor(0.9597, grad_fn=<MeanBackward0>)
tensor(0.9368, grad_fn=<MeanBackward0>)
tensor(0.9814, grad_fn=<MeanBackward0>)
tensor(0.9295, grad_fn=<MeanBackward0>)
tensor(0.9847, grad_fn=<MeanBackward0>)
tensor(0.9218, grad_fn=<MeanBackward0>)
tensor(0.8984, grad_fn=<MeanBackward0>)
tensor(1.0087, grad_fn=<MeanBackward0>)
tensor(0.8853, grad_fn=<MeanBackward0>)
tensor(0.8620, grad_fn=<MeanBackward0>)
tensor(0.0023, grad_fn=<MeanBackward0>)


  2%|▏         | 14/800 [00:01<00:49, 15.99it/s]

tensor(0.9849, grad_fn=<MeanBackward0>)
tensor(0.8972, grad_fn=<MeanBackward0>)
tensor(0.9832, grad_fn=<MeanBackward0>)
tensor(0.9369, grad_fn=<MeanBackward0>)
tensor(0.9486, grad_fn=<MeanBackward0>)
tensor(0.9303, grad_fn=<MeanBackward0>)
tensor(0.8851, grad_fn=<MeanBackward0>)
tensor(1.0073, grad_fn=<MeanBackward0>)
tensor(0.8732, grad_fn=<MeanBackward0>)
tensor(0.8302, grad_fn=<MeanBackward0>)
tensor(1.0691, grad_fn=<MeanBackward0>)
tensor(1.0332, grad_fn=<MeanBackward0>)
tensor(0.9940, grad_fn=<MeanBackward0>)
tensor(0.8818, grad_fn=<MeanBackward0>)
tensor(0.8346, grad_fn=<MeanBackward0>)
tensor(1.0478, grad_fn=<MeanBackward0>)
tensor(0.4828, grad_fn=<MeanBackward0>)
tensor(0.9376, grad_fn=<MeanBackward0>)
tensor(0.8937, grad_fn=<MeanBackward0>)
tensor(0.9833, grad_fn=<MeanBackward0>)
tensor(0.8980, grad_fn=<MeanBackward0>)
tensor(0.9803, grad_fn=<MeanBackward0>)
tensor(0.9022, grad_fn=<MeanBackward0>)
tensor(0.9770, grad_fn=<MeanBackward0>)
tensor(0.9061, grad_fn=<MeanBackward0>)
