In [1]:
import numpy as np
import torch
import gymnasium as gym
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from torch.utils import tensorboard
from tqdm import tqdm

In [2]:
def mish(input):
    return input * torch.tanh(F.softplus(input))

class Mish(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, input): return mish(input)

In [3]:
# helper function to convert numpy arrays to tensors
def t(x): return torch.from_numpy(x).float()

In [4]:
# Actor module, categorical actions only
class Actor(nn.Module):
    def __init__(self, state_dim, n_actions, activation=nn.Tanh):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64),
            activation(),
            nn.Linear(64, 32),
            activation(),
            nn.Linear(32, n_actions),
            nn.Softmax()
        )
    
    def forward(self, X):
        return self.model(X)

In [5]:
# Critic module
class Critic(nn.Module):
    def __init__(self, state_dim, activation=nn.Tanh):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64),
            activation(),
            nn.Linear(64, 32),
            activation(),
            nn.Linear(32, 1)
        )
    
    def forward(self, X):
        return self.model(X)

In [6]:
env = gym.make("CartPole-v1")

In [7]:
# config
state_dim = env.observation_space.shape[0]
n_actions = env.action_space.n
actor = Actor(state_dim, n_actions, activation=Mish)
critic = Critic(state_dim, activation=Mish)
adam_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)
adam_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)
np.random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True
env.action_space.seed(1)
env.observation_space.seed(1)

1

In [8]:
def clip_grad_norm_(module, max_grad_norm):
    nn.utils.clip_grad_norm_([p for g in module.param_groups for p in g["params"]], max_grad_norm)

def policy_loss(old_log_prob, log_prob, advantage, eps):
    ratio = (log_prob - old_log_prob).exp()
    clipped = torch.clamp(ratio, 1-eps, 1+eps)*advantage
    
    m = torch.min(ratio*advantage, clipped)
    return -m

In [9]:
episode_rewards = []
gamma = 0.98
eps = 0.2
log = True
if log:
    w = tensorboard.SummaryWriter('../runs/rollout_dones_fixed')
s = 0
max_grad_norm = 0.5
total_timesteps = 8000
episode_length = 2048
for i in tqdm(range(int(total_timesteps))):
    prev_prob_act = None
    done = False
    total_reward = 0
    state, _ = env.reset()

    observations, actions, logprobs, rewards, dones, values = torch.zeros((episode_length,)+env.observation_space.shape, dtype=torch.float32), torch.zeros((episode_length,)+env.action_space.shape, dtype=torch.float32), torch.zeros((episode_length,), dtype=torch.float32), torch.zeros((episode_length,), dtype=torch.float32), torch.zeros((episode_length,), dtype=torch.float32), torch.zeros((episode_length,), dtype=torch.float32)
    j = 0
    while not done and j < episode_length:
        # gathering rollout data
        s += 1
        with torch.no_grad():
            probs = actor(torch.from_numpy(state).float())
        value = critic(torch.from_numpy(state).float())
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample()
        prob_act = dist.log_prob(action)
        
        observations[j] = torch.from_numpy(state).float()
        actions[j] = action
        logprobs[j] = prob_act
        next_state, reward, done, _, info = env.step(action.detach().numpy())
        rewards[j] = reward
        dones[j] = done
        values[j] = value
        state = next_state
        j += 1
    
    # if (episode_length - j)/episode_length < 0.05: episode_length *=2 
    # print(f'Rollout ended with done={done}, reward={rewards.sum().numpy()}')
    # with torch.no_grad():
    # advantage calculation
    advantages = torch.zeros((episode_length,), dtype=torch.float32)
    # advantage = reward + (1-done)*gamma*critic(t(next_state)) - critic(t(state))
    done_index = dones.nonzero().max().item() if dones.any() else episode_length
    for t in reversed(range(done_index)):
        if t == episode_length - 1:
            advantages[t] = rewards[t] + (1-dones[t])*gamma*values[t]
        else:
            advantages[t] = rewards[t] + (1-dones[t+1])*gamma*(values[t+1] - values[t])
    
    if log:
        w.add_scalar("loss/advantage", advantages.detach().numpy().mean(), global_step=i)
        # w.add_scalar("actions/action_0_prob", dist.probs[0], global_step=s)
        # w.add_scalar("actions/action_1_prob", dist.probs[1], global_step=s)
    
    total_reward = rewards.sum().numpy()
    if log:
        w.add_scalar("reward/episode_reward", total_reward, global_step=i)

    probs = actor(observations)
    dist = torch.distributions.Categorical(probs=probs)
    new_logprobs = dist.log_prob(actions)
    actor_loss = policy_loss(logprobs, new_logprobs, advantages.detach(), eps).mean()
    adam_actor.zero_grad()
    actor_loss.backward()
    if log:
        w.add_scalar("loss/actor_loss", actor_loss.detach(), global_step=i)
        w.add_histogram("gradients/actor",
                        torch.cat([p.grad.view(-1) for p in actor.parameters()]), global_step=i)
    adam_actor.step()
    state = next_state

    critic_loss = advantages.pow(2).mean()
    adam_critic.zero_grad()
    critic_loss.backward()
    if log:
        w.add_scalar("loss/critic_loss", critic_loss.detach(), global_step=j)
        w.add_histogram("gradients/critic",
                        torch.cat([p.grad.view(-1) for p in critic.parameters()]), global_step=i)
    adam_critic.step()

    episode_rewards.append(total_reward)

  return self._call_impl(*args, **kwargs)
 33%|███▎      | 2621/8000 [18:04<2:24:57,  1.62s/it]