# **TD3 Notebook (CleanRL)** #

Variable names changed relative to CleanRL:
- qf1_a_values -> q_est1
- qf1_loss -> q_loss1

---
## (A/B) Setup ##

### (B) Import ###

In [None]:
import os
import random
import time
import datetime

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

### (B2) Args ###

In [None]:
class Args(dict):
    """A ``dict`` whose entries double as attributes (``args.gamma`` is ``args["gamma"]``)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the instance namespace at the mapping itself, so attribute access
        # and item access read and write the same storage.
        self.__dict__ = self


# Run configuration. Keyword order is preserved, so the hyperparameter table
# logged to TensorBoard lists the keys in this order.
args = Args(
    algo="td3",
    env_id="Hopper-v4",  # alternative: "MountainCarContinuous-v0"
    seed=1,
    # --- TD3 hyperparameters ---
    total_timesteps=800_000,
    gamma=0.99,
    learning_rate=3e-4,  # paper uses 1e-3, CleanRL uses 3e-4
    buffer_size=int(1e6),  # paper keeps the entire history; the authors' GitHub code differs
    batch_size=256,
    tau=0.005,  # soft-update rate for the target networks
    policy_noise=0.2,  # std of the target-policy-smoothing noise (before action scaling)
    exploration_noise=0.1,  # scales the Gaussian noise added to the actor's actions in rollouts
    learning_starts=int(25e3),  # steps of uniformly random actions before training starts
    policy_update_interval=2,  # delayed updates: actor + targets are updated every N steps
    noise_clip=0.5,  # clip range for the target-policy-smoothing noise
    num_envs=int(1),  # my addition: number of sub-envs in the vector env
    # --- runtime ---
    torch_deterministic=True,
    cuda=True,
    capture_video=True,
)

In [None]:
# Month-day_hour-minute stamp keeps repeated runs of the same env/algo/seed apart.
start_datetime = f"{datetime.datetime.now():%m%d_%H%M}"
run_name = "__".join([args.env_id, args.algo, str(args.seed), start_datetime])

print("start_datetime = " + start_datetime)

### (B3) Hardware ###

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# torch.cuda.get_device_name() raises for a CPU device, so only query it when
# CUDA is actually in use.
if device.type == "cuda":
    print(f"device_name = {torch.cuda.get_device_name(device)}")
else:
    print("device_name = cpu")

### (B4) Tensorboard ###

In [None]:
writer = SummaryWriter(f"runs/{run_name}")
# Log the config as a two-column markdown table (header, separator, one row per
# key). Every row ends with "|" to match the header and separator rows.
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" % "\n".join(f"|{key}|{val}|" for key, val in args.items()),
)

### (B5) Seeding ###

In [None]:
# Seed every RNG the notebook draws from (Python, NumPy, PyTorch).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Honour the config flag instead of hard-coding True.
torch.backends.cudnn.deterministic = args.torch_deterministic
# Benchmark mode picks cuDNN algorithms by timing them, which can vary between runs.
torch.backends.cudnn.benchmark = False

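# Note: results were not reproducible across runs. Suspected causes: the vector env must also be seeded through `envs.reset(seed=...)` (seeding only the action space is not enough), and an earlier run may have used a different (SB3) kernel.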

## (C) Implementation ##

### (C1) env and vec_env ###

In [None]:
def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one sub-environment.

    Only the sub-env with index 0 records video (and only when ``capture_video``
    is set). Every env gets episode statistics and a seeded action space.
    """
    record_video = capture_video and idx == 0

    def thunk():
        if not record_video:
            env = gym.make(env_id)
        else:
            env = gym.wrappers.RecordVideo(
                gym.make(env_id, render_mode="rgb_array"), f"runs/{run_name}"
            )
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk

In [None]:
# One factory per sub-environment; only the factory with index 0 may record video.
envs = gym.vector.SyncVectorEnv([
    make_env(args.env_id, args.seed, env_idx, args.capture_video, run_name)
    for env_idx in range(args.num_envs)
])

### (C2) Agent ###

In [None]:
class Critic(nn.Module):
    """Q-network Q(s, a): an MLP over the concatenated observation and action.

    Two hidden layers of 256 ReLU units and a single linear output unit, so
    ``forward`` returns one value per row of the batched inputs.
    """

    def __init__(self, envs):
        super().__init__()

        obs_dim = np.prod(envs.single_observation_space.shape)
        act_dim = np.prod(envs.single_action_space.shape)
        width = 256
        # fc1/fc2/fc3 are the state_dict keys written by torch.save(critic.state_dict()).
        self.fc1 = nn.Linear(obs_dim + act_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, s, a):
        x = torch.cat([s, a], 1)
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

class Actor(nn.Module):
    """Deterministic policy mu(s): an MLP whose tanh output is rescaled into the action bounds.

    ``action_scale`` and ``action_bias`` are derived from the *per-environment*
    action space, so they have shape ``(act_dim,)``. They broadcast against a
    single observation or any batch of observations (e.g. a replay-buffer batch),
    regardless of how many sub-envs the vector env holds.
    """

    def __init__(self, envs):
        super().__init__()

        hidden_size = 256
        self.fc1 = nn.Linear(np.prod(envs.single_observation_space.shape), hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, np.prod(envs.single_action_space.shape))

        # Use single_action_space, not the batched envs.action_space. The batched
        # space would give (num_envs, act_dim) buffers, which fail to broadcast
        # with a (batch, act_dim) replay batch whenever num_envs > 1.
        low = envs.single_action_space.low
        high = envs.single_action_space.high
        # Buffers: stored in state_dict and moved by .to(device), but not trained.
        self.register_buffer(
            "action_scale",
            torch.tensor((high - low) / 2.0, dtype=torch.float32)
        )
        self.register_buffer(
            "action_bias",
            torch.tensor((high + low) / 2.0, dtype=torch.float32)
        )

    def forward(self, s):
        a = F.relu(self.fc1(s))
        a = F.relu(self.fc2(a))
        a = torch.tanh(self.fc3(a))  # in [-1, 1]
        return a * self.action_scale + self.action_bias  # mapped to [low, high]

# Note: the constructor argument was renamed from `env` to `envs` because a non-vectorized env has no `.single_observation_space` attribute.

### (C3) Training ###

#### (C3a) Init ####

In [None]:
# Online networks, their target copies, optimizers and the replay buffer.
# Targets start with identical weights via load_state_dict. Keep the construction
# order fixed: with a seeded RNG it determines every network's initial weights.
actor = Actor(envs).to(device)
actor_target = Actor(envs).to(device)
actor_target.load_state_dict(actor.state_dict())

# Twin critics; the TD target in the training loop uses the min of the two targets.
critic1 = Critic(envs).to(device)
critic2 = Critic(envs).to(device)
critic1_target = Critic(envs).to(device)
critic2_target = Critic(envs).to(device)
critic1_target.load_state_dict(critic1.state_dict())
critic2_target.load_state_dict(critic2.state_dict())

optimizer_actor = optim.Adam(actor.parameters(), lr=args.learning_rate)
# One optimizer steps both critics; their losses are summed before backward().
optimizer_critic = optim.Adam(list(critic1.parameters()) + list(critic2.parameters()), lr=args.learning_rate)

# MuJoCo envs use a float32 action space but a float64 observation space.
# Overriding the per-env observation dtype BEFORE the ReplayBuffer is built makes
# rb.sample() return float32 observations. Presumably SB3's ReplayBuffer allocates
# its storage with observation_space.dtype (TODO confirm for the installed SB3).
# The batched envs.observation_space.dtype still reports np.float64.
envs.single_observation_space.dtype = np.float32 
# Setting the batched dtype does NOT work:
# envs.observation_space.dtype = np.float32
# e.g. envs.observation_space.sample() then raises
# AttributeError: type object 'numpy.float32' has no attribute 'kind'

rb = ReplayBuffer(
    args.buffer_size, 
    envs.single_observation_space, 
    envs.single_action_space, 
    device, 
    # Presumably SB3's timeout handling expects SB3-VecEnv-style infos (a list of
    # dicts), while the gymnasium vector env returns a dict. It is disabled here;
    # rb.add() in the training loop receives only `terminations` as dones.
    handle_timeout_termination=False, # (I think) only usable with SB3's API
)


In [None]:
# learning sb3.ReplayBuffer
# Scratch notes on SB3's ReplayBuffer API. Everything below is inside a string
# literal, so running this cell is a no-op; executed, it would add two random
# transitions to `rb`.
# NOTE(review): step #2 stores `obs` again rather than the previous `obs1`.
"""
obs, _ = envs.reset()
# step #1
actions = envs.action_space.sample()
obs1, rewards, terminations, truncations, infos = envs.step(actions)
rb.add(obs, obs1, actions, rewards, terminations, infos)
# step #2
actions = envs.action_space.sample()
obs1, rewards, terminations, truncations, infos = envs.step(actions)
rb.add(obs, obs1, actions, rewards, terminations, infos)

rb.sample(2) # uses np.random.randint, draw WITH replacement (2024/06)
rb.sample(2).observations # 2 tensors
rb.sample(2).observations[0].size() # torch.Size([11])
"""

#### (C3b) Training Loop ####

In [None]:
def _soft_update(net, target_net, tau):
    """Polyak-average ``net`` into ``target_net``: target <- tau * online + (1 - tau) * target."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(
            tau * param.data + (1 - tau) * target_param.data
        )


start_time = time.time()
global_episode = 0
global_step = 0
# Loop-invariant per-env action bounds, shaped (1, act_dim) so they broadcast
# over a batch when clamping the smoothed target actions.
action_low = torch.tensor(envs.single_action_space.low).reshape((1, -1)).to(device)
action_high = torch.tensor(envs.single_action_space.high).reshape((1, -1)).to(device)

# Seed the first reset. make_env only seeds the action spaces, so without this
# the initial states (and therefore the whole run) are not reproducible.
obs, _ = envs.reset(seed=args.seed)
while global_step < args.total_timesteps:
    # 1) pick action: uniform random during warm-up, then actor + Gaussian exploration noise
    if global_step < args.learning_starts:
        actions = envs.action_space.sample()
    else:
        with torch.no_grad():
            actions = actor(torch.Tensor(obs).to(device))
            # randn_like draws an independent sample per env and action dimension,
            # whatever the shape of actor.action_scale.
            actions += torch.randn_like(actions) * actor.action_scale * args.exploration_noise
            actions = actions.cpu().numpy().clip(envs.action_space.low, envs.action_space.high)

    # 2) step
    obs1, rewards, terminations, truncations, infos = envs.step(actions)

    # 3) store transition in replay buffer.
    # The vector env auto-resets finished sub-envs (gymnasium < 1.0 semantics,
    # consistent with the "final_info" handling below), so their entry in obs1
    # is already the next episode's first observation. Truncated transitions are
    # stored with done=0 and bootstrap from next_observations, so they must keep
    # the true final observation.
    real_obs1 = obs1.copy()
    if "final_observation" in infos:
        for idx, truncated in enumerate(truncations):
            if truncated:
                real_obs1[idx] = infos["final_observation"][idx]
    rb.add(obs, real_obs1, actions, rewards, terminations, infos)

    # 4) rollout to next step
    obs = obs1
    global_step += envs.num_envs
    if "final_info" in infos:
        written = False
        for info in infos["final_info"]:
            if info:
                global_episode += 1
                # Log only the first finished episode of this step.
                if not written:
                    writer.add_scalar("charts/episodic_reward", info["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                    written = True
        writer.add_scalar("charts/episode", global_episode, global_step)

    # 5) training (after the warm-up phase)
    # 5a) update critics
    if global_step > args.learning_starts:
        batch = rb.sample(args.batch_size)
        with torch.no_grad():  # the TD target is a constant for the critic update
            actions1 = actor_target(batch.next_observations)
            # Target policy smoothing: (N(0,1) * policy_noise).clamp(-c, c), scaled to the
            # action range (same as CleanRL and the paper's code).
            actions1_noise = (torch.randn_like(batch.actions, device=device) * args.policy_noise)\
                    .clamp(-args.noise_clip, args.noise_clip) * actor_target.action_scale
            actions1 = (actions1 + actions1_noise).clamp(action_low, action_high)
            # Clipped double-Q: bootstrap from the smaller of the two target critics.
            q1 = critic1_target(batch.next_observations, actions1)
            q2 = critic2_target(batch.next_observations, actions1)
            qmin = torch.min(q1, q2)
            # Flatten once at the end (measured slightly faster than flattening each operand).
            td_target = (batch.rewards + (1. - batch.dones) * args.gamma * qmin).flatten()

        q_est1 = critic1(batch.observations, batch.actions).flatten()
        q_est2 = critic2(batch.observations, batch.actions).flatten()
        q_loss1 = F.mse_loss(td_target, q_est1)
        q_loss2 = F.mse_loss(td_target, q_est2)
        # Both critics regress onto the same target; they differ only by initialization.
        q_loss = q_loss1 + q_loss2

        optimizer_critic.zero_grad()
        q_loss.backward()
        optimizer_critic.step()

        # Delayed policy updates: actor and all targets update once every N steps.
        if global_step % args.policy_update_interval == 0:
            # 5b) update actor
            actor_loss = -critic1(batch.observations, actor(batch.observations)).mean()
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            # 5c) update targets
            _soft_update(actor, actor_target, args.tau)
            _soft_update(critic1, critic1_target, args.tau)
            _soft_update(critic2, critic2_target, args.tau)

        # 6) progress tracking
        if global_step % 100 == 0:
            writer.add_scalars("losses/q_est", {
                "q_est1": q_est1.mean().item(),
                "q_est2": q_est2.mean().item()
            }, global_step)
            writer.add_scalars("losses/q_loss", {
                "q_loss1": q_loss1.item(),
                "q_loss2": q_loss2.item()
            }, global_step)
            writer.add_scalar("losses/q_loss", q_loss.item(), global_step)
            writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)


In [None]:
# Training is finished: shut down the vector env, then close the TensorBoard writer.
for _resource in (envs, writer):
    _resource.close()

In [None]:
# Output folder for the evaluation videos; exist_ok lets the cell be re-run.
os.makedirs("upload", exist_ok=True)

## (D) Evaluation ##

In [None]:
def evaluate_agent(env, n_eval_episodes, policy, hyperparameters=args):
    """Roll out ``policy`` without exploration noise and summarize the episode rewards.

    Args:
        env: a single (non-vectorized) environment.
        n_eval_episodes: number of episodes to run.
        policy: callable mapping an observation tensor on ``device`` to an action tensor.
        hyperparameters: config whose ``env_id`` is recorded in the metadata;
            defaults to the module-level ``args`` bound when this function is defined.

    Returns:
        ``(mean_reward, std_reward, evaluate_data)``, where ``evaluate_data`` is a
        metadata dict (env id, reward stats, episode count, ISO-8601 timestamp).
    """
    def play_one(episode):
        # Run one episode to termination/truncation and return its summed reward.
        episode_return = 0
        observation, _ = env.reset()
        with torch.no_grad():
            finished = False
            while not finished:
                action = policy(torch.Tensor(observation).to(device))
                observation, reward, terminated, truncated, info = env.step(action.flatten().cpu().numpy())
                episode_return += reward
                finished = terminated or truncated
        print(f"episode {episode:2}: reward={episode_return:5.1f}")
        return episode_return

    # (1) evaluate
    episode_rewards = [play_one(episode) for episode in range(n_eval_episodes)]
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)

    # (2) metadata
    evaluate_data = dict(
        env_id=hyperparameters.env_id,
        mean_reward=mean_reward,
        std_reward=std_reward,
        n_evaluation_episodes=n_eval_episodes,
        eval_datetime=datetime.datetime.now().isoformat(),
    )
    return mean_reward, std_reward, evaluate_data

# Evaluate the trained actor for 10 episodes, recording video into ./upload.
eval_env = gym.wrappers.RecordVideo(
    gym.make(args.env_id, render_mode="rgb_array"),
    video_folder="upload",
    video_length=99999,
)

eval_mean_reward, eval_std_reward, evaluate_data = evaluate_agent(eval_env, 10, actor)
eval_env.close()
for stat_name, stat_value in (("mean_reward", eval_mean_reward), ("std_reward", eval_std_reward)):
    print(f"{stat_name}={stat_value:.2f}")


## (E) Save ##

In [None]:
# Checkpoint folder inside this run's directory; exist_ok lets the cell be re-run.
os.makedirs(f"runs/{run_name}/models", exist_ok=True)

In [None]:
# Save the online networks' weights (state_dicts only; target networks are not saved).
torch.save(actor.state_dict(), f"runs/{run_name}/models/actor.pt")
torch.save(critic1.state_dict(), f"runs/{run_name}/models/critic1.pt")
torch.save(critic2.state_dict(), f"runs/{run_name}/models/critic2.pt")

In [None]:
"""
t1 = torch.tensor([[.1,.2,.3,.4]], dtype=torch.float32, device=device)
t2 = torch.tensor([[.4,.3,.2,.1]], dtype=torch.float32, device=device)
t3 = torch.tensor([[.5,.6,.7,.8]], dtype=torch.float32, device=device)

%timeit -n 100_000 -r 5 t1 + t2 + t3 * 2.0
# 31.9 µs ± 2.13 µs per loop
# 28.5 µs ± 945 ns per loop

%timeit -n 100_000 -r 5 t1.flatten() + t2.flatten() + t3.flatten() * 2.0
# 38.5 µs ± 969 ns per loop
# 39.2 µs ± 1.05 µs per loop

%timeit -n 100_000 -r 5 (t1 + t2 + t3 * 2.0).flatten()
# 36.1 µs ± 2.16 µs per loop
# 36.3 µs ± 1.59 µs per loop

%timeit -n 100_000 -r 5 (t1 + t2 + t3 * 2.0).view(-1)
# 35.9 µs ± 1.23 µs per loop
# 34.8 µs ± 1.6 µs per loop
"""
