In [1]:
import gym
from gym import wrappers
from gym.spaces.utils import flatdim

import torch
from torch import nn
from torch.functional import F
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import cv2
from tqdm import tqdm

from copy import deepcopy
from moviepy.editor import ImageSequenceClip
import collections

# Comment out for debugging
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm
ALSA lib confmisc.c:855:(parse_card) cannot find card '0'
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_card_inum returned error: No such file or directory
ALSA lib confmisc.c:422:(snd_func_concat) error evaluating strings
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_concat returned error: No such file or directory
ALSA lib confmisc.c:1334:(snd_func_refer) error evaluating name
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_refer returned error: No such file or directory
ALSA lib conf.c:5701:(snd_config_expand) Evaluate error: No such file or directory
ALSA lib pcm.c:2664:(snd_pcm_open_noupdate) Unknown PCM default
ALSA lib confmisc.c:855:(parse_card) cannot find card '0'
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_card_inum returned error: No such file or directory
ALSA lib confmisc.c:422:(snd_func_concat) error evaluating strings
ALSA lib conf.c:5178:(_snd_config_evalua

In [20]:
# Simple linear annealing schedule for random exploration
def linear_schedule(start_e, end_e, total_steps, step):
    """Linearly interpolate from ``start_e`` towards ``end_e`` and hold at ``end_e``.

    Args:
        start_e: Value of the schedule at ``step == 0``.
        end_e: Final value; the schedule never moves past it.
        total_steps: Number of steps over which the value is annealed.
        step: Current step.

    Returns:
        The scheduled value at ``step``. If ``total_steps`` is not positive
        there is nothing to anneal over, so ``end_e`` is returned directly.
    """
    # Guard against division by zero (e.g. a schedule length derived from a
    # very small step budget).
    if total_steps <= 0:
        return end_e

    slope = (end_e - start_e) / total_steps
    value = slope * step + start_e

    # Clamp on the side of the target: a decaying schedule must not drop below
    # end_e, and a growing schedule must not rise above it.
    if end_e <= start_e:
        return max(value, end_e)
    return min(value, end_e)

# Simple Dense network for environments that have small, vector-based states
class ActorCriticNet(nn.Module):
    """Separate actor and critic MLPs for flat observations and discrete actions.

    Both heads are two-hidden-layer (64 units, tanh) networks over the
    flattened observation. The actor outputs one logit per action
    (``env.action_space.n``); the critic outputs a single state value.
    """

    def __init__(self, env):
        """Size both networks from ``env.observation_space.shape`` and ``env.action_space.n``."""
        super().__init__()
        # Plain int so nn.Linear gets a well-defined size even for scalar shapes.
        obs_dim = int(np.prod(env.observation_space.shape))
        n_actions = env.action_space.n

        self.actor = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, n_actions)
        )

        self.critic = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def value(self, x):
        """Return the critic's value estimate for observation(s) ``x``."""
        return self.critic(x)

    def action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on ``x``.

        Args:
            x: Observation tensor whose last dimension matches the actor input.
            action: Optional action(s) to score. When omitted, an action is
                sampled from the categorical policy.

        Returns:
            Tuple ``(action, log_prob(action), policy entropy, critic value)``.
        """
        logits = self.actor(x)
        probs = torch.distributions.Categorical(logits=logits)
        # Identity check: ``==`` on a tensor is an element-wise/overloaded
        # comparison, not a test for "no action supplied".
        if action is None:
            action = probs.sample()

        return action, probs.log_prob(action), probs.entropy(), self.critic(x)


# Baseline hyperparameters that train reasonably well; tune them for faster optimization.
DEFAULT_HPARAMS = dict(
    training_steps=500_000,  # total env steps; divided by rollout_length to get the update count
    rollout_length=128,      # env steps collected before each optimization phase
    mini_batch_size=4,       # samples per gradient step within a rollout
    gamma=0.9,               # discount factor used in the advantage recursion
    gae_lambda=0.95,         # lambda of the generalized advantage estimate
    clip_coef=0.2,           # clipping range applied to the policy probability ratio
    learning_rate=2.5e-4,    # Adam learning rate
    epochs=4,                # passes over each rollout per update
)

def train_model(env, agent: ActorCriticNet, hparams, path="./results"):
    """Train ``agent`` on ``env`` with clipped-objective PPO and return it.

    Each update collects ``rollout_length`` environment steps, computes GAE
    advantages and returns, then runs ``epochs`` passes of shuffled minibatch
    gradient steps over that rollout. Scalars are written to TensorBoard under
    ``path`` and the whole agent is saved to ``{path}/q_{num_updates}.pt``.

    Args:
        env: Environment whose ``step`` returns ``(obs, reward, done, info)``
            and whose ``reset()`` returns the observation. Episode statistics
            are logged whenever ``info`` contains an ``"episode"`` key
            (presumably added by ``RecordEpisodeStatistics`` - confirm for
            custom envs). The environment is closed when training completes.
        agent: Network exposing ``value(obs)`` and
            ``action_and_value(obs, action=None)``; it is optimized in place.
        hparams: Dict providing ``training_steps``, ``rollout_length``,
            ``mini_batch_size``, ``gamma``, ``gae_lambda``, ``clip_coef``,
            ``learning_rate`` and ``epochs``.
        path: Directory for TensorBoard logs and the saved model.

    Returns:
        The trained ``agent`` (the same object that was passed in).
    """
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = "cpu"
    writer = SummaryWriter(path)

    # The writer is closed in ``finally`` so logs are flushed even when the
    # run is interrupted (e.g. KeyboardInterrupt from the notebook).
    try:
        # Hyperparams
        training_steps = hparams["training_steps"]
        rollout_length = hparams["rollout_length"]
        mini_batch_size = hparams["mini_batch_size"]
        gamma = hparams["gamma"]
        gae_lambda = hparams["gae_lambda"]
        clip_coef = hparams["clip_coef"]
        lr = hparams["learning_rate"]
        epochs = hparams["epochs"]

        # One update consumes exactly one rollout, so convert the env-step
        # budget into a number of updates.
        batch_size = rollout_length
        training_steps = int(training_steps / batch_size)

        # Initialize optimizer
        optimizer = torch.optim.Adam(agent.parameters(), lr=lr)

        # Initialize storage for one rollout
        obs = torch.zeros((rollout_length, *env.observation_space.shape)).to(device)
        actions = torch.zeros((rollout_length, 1)).to(device)
        logprobs = torch.zeros((rollout_length, 1)).to(device)
        rewards = torch.zeros((rollout_length, 1)).to(device)
        dones = torch.zeros((rollout_length, 1)).to(device)
        values = torch.zeros((rollout_length, 1)).to(device)

        obs_n = torch.Tensor(env.reset()).to(device)
        done = torch.zeros(1).to(device)
        ep_return = 0
        ep_length = 0
        global_step = 0
        for _ in tqdm(range(training_steps)):

            # Env rollouts
            #########################################################
            for env_step in range(rollout_length):
                global_step += 1
                # dones[t] flags whether obs[t] is the first observation of a
                # new episode (i.e. the previous step terminated).
                obs[env_step] = obs_n
                dones[env_step] = done

                # Select action
                with torch.no_grad():
                    action, logprob, _, value = agent.action_and_value(obs_n)

                    # Stored inside no_grad so the rollout carries no graph
                    values[env_step] = value.flatten()

                actions[env_step] = action
                logprobs[env_step] = logprob

                obs_n, r, done, info = env.step(action.cpu().numpy())
                obs_n = torch.Tensor(obs_n).to(device)
                done = torch.Tensor([done]).to(device)

                rewards[env_step] = torch.tensor(r).to(device).view(-1)

                # Record statistics seems to be broken for some envs (TODO: find out why borked)
                ep_return += r
                ep_length += 1
                if "episode" in info.keys():
                    writer.add_scalar("charts/episodic_return", ep_return, global_step)
                    writer.add_scalar("charts/episodic_length", ep_length, global_step)

                if done:
                    obs_n = torch.Tensor(env.reset()).to(device)
                    ep_return = 0
                    ep_length = 0

            #######################################################

            # Compute advantages/returns
            with torch.no_grad():
                # Bootstrap from the observation that FOLLOWS the last stored
                # transition. Using obs[-1] here would re-use the value of the
                # state the last action was taken in and bias the final
                # advantage of every rollout. If the episode just ended, obs_n
                # is a reset observation but (1 - done) masks it out below.
                next_value = agent.value(obs_n).reshape(1, -1)

                # Essentially Q - V, + good action/ - bad action
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0

                for t in reversed(range(rollout_length)):
                    if t == rollout_length - 1:
                        nextnonterminal = 1.0 - done
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        nextvalues = values[t + 1]

                    bellman_error = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                    lastgaelam = bellman_error + gamma * gae_lambda * nextnonterminal * lastgaelam
                    advantages[t] = lastgaelam
                returns = advantages + values

            # Flatten rollout batch
            batch_obs = obs.reshape((-1,) + env.observation_space.shape)
            batch_logprobs = logprobs.reshape(-1)
            batch_actions = actions.reshape((-1,) + env.action_space.shape)
            batch_advantages = advantages.reshape(-1)
            batch_returns = returns.reshape(-1)
            batch_values = values.reshape(-1)

            # Optimize Agent
            #######################################################
            batch_inds = np.arange(batch_size)
            clipfracs = []
            for epoch in range(epochs):
                np.random.shuffle(batch_inds)
                for ind_start in range(0, batch_size, mini_batch_size):
                    ind_end = ind_start + mini_batch_size
                    batch = batch_inds[ind_start:ind_end]

                    # Re-evaluate the stored actions under the current policy
                    _, new_logprobs, entropy, new_values = agent.action_and_value(batch_obs[batch], batch_actions.long()[batch])
                    logratio = new_logprobs - batch_logprobs[batch]  # Difference in policy log-probs
                    ratio = logratio.exp()

                    # Diagnostics only (approx KL and clipped fraction), following
                    # https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
                    with torch.no_grad():
                        old_approx_kl = (-logratio).mean()
                        approx_kl = ((ratio - 1) - logratio).mean()
                        clipfracs += [((ratio - 1.0).abs() > clip_coef).float().mean().item()]

                    # Policy loss: pessimistic (max of negated) clipped surrogate
                    policy_loss_a = - batch_advantages[batch] * ratio
                    policy_loss_b = - batch_advantages[batch] * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
                    policy_loss = torch.max(policy_loss_a, policy_loss_b).mean()

                    # Value Loss
                    new_values = new_values.view(-1)
                    value_loss = 0.5 * F.mse_loss(new_values, batch_returns[batch])
                    entropy_loss = entropy.mean()

                    # Weights from link above
                    loss = policy_loss - 0.01 * entropy_loss + 0.5 * value_loss
                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
                    optimizer.step()

            # Fraction of return variance explained by the rollout-time value estimates
            y_pred, y_true = batch_values.cpu().numpy(), batch_returns.cpu().numpy()
            var_y = np.var(y_true)
            explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

            # Loss scalars reflect the last minibatch of the update
            writer.add_scalar("losses/value_loss", value_loss.item(), global_step)
            writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
            writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
            writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
            writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
            writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
            writer.add_scalar("losses/explained_variance", explained_var, global_step)

            #######################################################

        env.close()
    finally:
        writer.close()

    torch.save(agent, path + f"/q_{training_steps}.pt")
    return agent
    
def generate_video(net, env, path="test"):
    """Record one greedy (argmax) episode of ``net`` acting in ``env``.

    Args:
        net: Either a module that maps a batched observation to per-action
            scores when called, or a module exposing an ``actor`` sub-network
            that does (e.g. ``ActorCriticNet``, which defines no ``forward``).
        env: Environment using the 4-tuple ``step`` API; it is wrapped in
            ``gym.wrappers.RecordVideo`` and closed when the episode ends.
        path: Folder passed to ``RecordVideo`` for the recording.

    Returns:
        None.
    """
    env = gym.wrappers.RecordVideo(env, path)
    # Close in ``finally`` so the recorder is released even if the rollout
    # raises part-way through the episode.
    try:
        obs = env.reset()
        done = False

        while not done:
            # Pure inference: no autograd graph needed.
            with torch.no_grad():
                batch = torch.Tensor(obs).unsqueeze(0)
                scores = net.actor(batch) if hasattr(net, "actor") else net(batch)
            action = torch.argmax(scores).cpu().numpy()
            obs, _, done, info = env.step(action)
    finally:
        env.close()
    return


In [21]:
# Train on the Acrobot environment: https://gymnasium.farama.org/environments/classic_control/acrobot/
env = gym.make("Acrobot-v1")
env = gym.wrappers.RecordEpisodeStatistics(env)

# Copy the defaults so per-experiment tweaks never leak into DEFAULT_HPARAMS.
hparams = dict(DEFAULT_HPARAMS)

# Bind the agent before training: train_model optimizes it in place, so the
# partially trained network stays reachable if the run is interrupted.
agent = ActorCriticNet(env)
agent = train_model(env, agent, hparams, path="./acrobot")

 29%|██▉       | 1148/3906 [10:20<24:51,  1.85it/s]


KeyboardInterrupt: 

In [None]:
# Train on the CartPole environment: https://gymnasium.farama.org/environments/classic_control/cart_pole/
env = gym.make("CartPole-v1")
env = gym.wrappers.RecordEpisodeStatistics(env)

# Copy the defaults so per-experiment tweaks never leak into DEFAULT_HPARAMS.
hparams = dict(DEFAULT_HPARAMS)

# Bind the agent before training: train_model optimizes it in place, so the
# partially trained network stays reachable if the run is interrupted.
agent = ActorCriticNet(env)
agent = train_model(env, agent, hparams, path="./cartpole")

In [None]:
# NOTE(review): make_env, ConvQNet and train_atari_model are not defined in the
# cells visible here - confirm they are defined/imported before this cell runs.
env = make_env("ALE/Pong-v5")
env = gym.wrappers.RecordEpisodeStatistics(env)
q = ConvQNet(env)

# Build a fresh dict instead of aliasing DEFAULT_HPARAMS: assigning into the
# alias would silently change the defaults used by every other experiment.
params = {
    **DEFAULT_HPARAMS,
    "buffer_size": 50_000,
    "training_steps": 1_000_000,
}

q = train_atari_model(env, q, params)

generate_video(q, env, "./pong/eval")