Run the cell below to install the MuJoCo dependencies (Gymnasium's MuJoCo extras and the `mujoco` package).

In [None]:
!pip install -U gymnasium[mujoco] mujoco

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# **IMPORTING LIBRARIES**

In [None]:
import torch
import random
import os
import torch.nn as nn
import numpy as np
import time
from collections import deque
import wandb

import gymnasium as gym
# from gymnasium.vector import SyncVectorEnv
from gymnasium.wrappers import RecordVideo

In [None]:
# Select MuJoCo's EGL rendering backend, used for the offscreen
# render_mode="rgb_array" environments created in this notebook.
os.environ.update(MUJOCO_GL="egl")

In [None]:
def create_environment(cfgs, eval = False):
  """Build the Gymnasium environment described by ``cfgs``.

  Args:
    cfgs: configuration object providing ``id`` and ``max_steps``.
    eval: currently unused; kept (despite shadowing the builtin name) so
      existing callers keep working.

  Returns:
    The environment from ``gym.make`` with ``render_mode="rgb_array"`` and
    episodes capped at ``cfgs.max_steps`` steps.
  """
  # Use the configuration passed in, not the notebook-global ``cfg``; the
  # original silently ignored the argument's ``max_steps``.
  env = gym.make(id=cfgs.id, render_mode="rgb_array", max_episode_steps=cfgs.max_steps)
  return env

# **WANDB RUN**

In [None]:
def wandb_runs(cfg):
  """Authenticate with Weights & Biases and start the run used for logging.

  Args:
    cfg: configuration object; all of its public, non-callable attributes
      (class-level defaults included) are stored as the run's config.

  Returns:
    The run object returned by ``wandb.init``.
  """
  # No hard-coded key: wandb.login() resolves credentials on its own (e.g. the
  # WANDB_API_KEY environment variable or a cached login), so secrets never
  # need to be pasted into the notebook.
  wandb.login()

  # vars(cfg) only returns *instance* attributes; settings declared as plain
  # class attributes would be dropped, leaving the W&B config empty.
  run_config = {
      name: getattr(cfg, name)
      for name in dir(cfg)
      if not name.startswith("_") and not callable(getattr(cfg, name))
  }

  run = wandb.init(
    entity="ajheshbasnet-kpriet",
    project="ddpg",
    name = "DDPG",
    config=run_config,
  )

  return run

# **CONFIGURATIONS**

In [None]:
from dataclasses import dataclass

@dataclass
class configuration:
  """Hyper-parameters and run settings for this notebook.

  Every setting carries a type annotation so that ``@dataclass`` registers it
  as a field. Without annotations the decorator does nothing: values stay
  class attributes, ``vars(cfg)`` is empty and nothing can be overridden at
  construction. Now e.g. ``configuration(batch_size=256)`` works.
  """
  id: str = "Ant-v5"               # Gymnasium environment id passed to gym.make
  n_rollouts: int = 100_000        # number of training episodes (outer loop iterations)
  max_steps: int = 1000            # episode cap for the training environment
  eval_steps: int = 10_000         # run an evaluation when global_steps is a multiple of this
  global_steps: int = 0            # environment-step counter, advanced by the training loop
  buffer_size: int = 800_000       # replay-buffer capacity (deque maxlen)
  eval_loops: int = 3              # episodes averaged per evaluation
  batch_size: int = 512            # transitions sampled per update
  wandb_log_steps: int = 50        # intended W&B logging interval in env steps
  start_training: int = 50_000     # buffer size that must be exceeded before updates begin
  training_step: int = 2           # perform an update every this many env steps
  actor_freq: int = 2              # period of the delayed actor update
  critic_lr: float = 2.5e-4        # AdamW learning rate for both critics
  actor_lr: float = 2.5e-4         # AdamW learning rate for the actor
  record_video: int = 500_000      # record evaluation videos when global_steps is a multiple of this
  eval_max_steps: int = 800        # episode cap for the evaluation environment
  # Evaluated once, when the class is defined.
  device: str = "cuda" if torch.cuda.is_available() else "cpu"

cfg = configuration()

In [None]:
# Single training environment built from the notebook configuration.
envs = create_environment(cfgs=cfg)

# **Actor and Critic Networks**

In [None]:
class Actor(nn.Module):
  """Deterministic policy network mapping a state to an action in [-1, 1].

  Two ReLU hidden layers (512 and 256 units) feed a linear head whose output
  is squashed with ``tanh``.
  """

  def __init__(self, input_dim, action_dim):
    super().__init__()
    layers = [
        nn.Linear(input_dim, 512), nn.ReLU(),
        nn.Linear(512, 256), nn.ReLU(),
        nn.Linear(256, action_dim), nn.Tanh(),
    ]
    # The attribute keeps the name ``sequential`` so state_dict keys are unchanged.
    self.sequential = nn.Sequential(*layers)

  def forward(self, x):
    """Return the tanh-squashed action(s) for state(s) ``x``."""
    return self.sequential(x)

In [None]:
class Critic(nn.Module):
  """Q-value network that scores a (state, action) pair with one scalar.

  ``input_dim`` must equal state_dim + action_dim, because ``forward`` feeds
  the concatenation of both through two ReLU hidden layers (512, 256 units).
  """

  def __init__(self, input_dim):
    super().__init__()
    hidden = [
        nn.Linear(input_dim, 512), nn.ReLU(),
        nn.Linear(512, 256), nn.ReLU(),
    ]
    self.sequential = nn.Sequential(*hidden, nn.Linear(256, 1))

  def forward(self, state, action):
    """Return Q(state, action); inputs are concatenated along dim 1 (batched)."""
    joint = torch.cat((state, action), dim=1)
    return self.sequential(joint)

In [None]:
# Show the observation and action spaces of the training environment.
print(envs.observation_space, envs.action_space, sep=" \t ")

In [None]:
# Dimensions of the training environment's observation and action vectors.
obs_dim = envs.observation_space.shape[0]  # type: ignore
act_dim = envs.action_space.shape[0]  # type: ignore

# Online networks: one actor and two critics; critics take state+action.
actornet = Actor(obs_dim, act_dim).to(cfg.device)
criticnet1 = Critic(obs_dim + act_dim).to(cfg.device)
criticnet2 = Critic(obs_dim + act_dim).to(cfg.device)

# Target networks, built in the same order and then synced to the online weights.
TargetActor = Actor(obs_dim, act_dim).to(cfg.device)
TargetCritic1 = Critic(obs_dim + act_dim).to(cfg.device)
TargetCritic2 = Critic(obs_dim + act_dim).to(cfg.device)

TargetActor.load_state_dict(actornet.state_dict())
TargetCritic1.load_state_dict(criticnet1.state_dict())
TargetCritic2.load_state_dict(criticnet2.state_dict())

In [None]:
# Parameter counts (in thousands, summed via p.numel()) for the actor and each critic.
print(f'''Parameters:
=================================================================
actor-network     : {sum(p.numel() for p in actornet.parameters())/1e3} k
critic-network(s) : {sum(p.numel() for p in criticnet1.parameters())/ 1e3} k + {sum(p.numel() for p in criticnet2.parameters())/ 1e3} k
=================================================================
      ''')

**Evaluation Loop**

In [None]:
def evaluation(actornet, record_video = False):
  """Run ``cfg.eval_loops`` noise-free episodes and return average results.

  A fresh environment (episodes capped at ``cfg.eval_max_steps``) is created
  for each call and is always closed, even if an episode raises.

  Args:
    actornet: policy network mapping a state tensor to an action tensor.
    record_video: if true, wrap the env in ``RecordVideo`` and save every
      evaluation episode under ``videos/<unix-timestamp>``.

  Returns:
    ``(mean_episode_reward, mean_episode_length)`` over the evaluated
    episodes; ``(0.0, 0.0)`` if ``cfg.eval_loops`` is 0.
  """
  eval_env = gym.make(id = cfg.id, render_mode = 'rgb_array' ,max_episode_steps=cfg.eval_max_steps)
  if record_video:
    video_dir = f"videos/{int(time.time())}"
    eval_env = RecordVideo(eval_env,  video_folder=video_dir, episode_trigger=lambda ep: True)

  net_reward = 0
  net_step = 0

  try:
    with torch.no_grad():

      for _ in range(cfg.eval_loops):

        done = False

        episodic_reward = 0
        episodic_step = 0
        state = eval_env.reset()[0]

        while not done:

          # Deterministic action: the raw actor output, no exploration noise.
          stateT = torch.tensor(state, dtype=torch.float32, device=cfg.device)
          action = actornet(stateT).cpu().numpy()
          nxt_state, reward, terminated, truncated, _ = eval_env.step(action)
          done = terminated or truncated
          state = nxt_state

          episodic_reward += float(reward)
          episodic_step += 1

        print(episodic_reward)
        net_reward += episodic_reward
        net_step  += episodic_step
  finally:
    # Release the renderer / video writer even when an episode fails.
    eval_env.close()

  # Guard against eval_loops == 0 (would otherwise divide by zero).
  n_episodes = max(cfg.eval_loops, 1)
  net_reward = net_reward / n_episodes
  net_step = net_step / n_episodes

  return net_reward, net_step

In [None]:
# Smoke-test the untrained policy and record videos of its evaluation episodes.
evaluation(actornet, record_video=True)

**Sampling mini-batches from the replay buffer**

In [None]:
def get_batches(memory, batch_size):
    """Uniformly sample ``batch_size`` transitions and collate them into tensors.

    Args:
        memory: replay buffer of ``(state, action, reward, next_state, done)``
            tuples of tensors.
        batch_size: number of transitions drawn without replacement.

    Returns:
        ``(state, action, reward, next_state, done)`` as float tensors on
        ``cfg.device``; ``reward`` and ``done`` are reshaped with ``view(-1, 1)``.
    """
    transitions = random.sample(memory, batch_size)
    # zip(*...) transposes the list of transitions into one tuple per field;
    # done is cast to float so it can be used directly in the TD target.
    state, action, reward, next_state, done = (
        torch.stack(field).float().to(cfg.device) for field in zip(*transitions)
    )
    return state, action, reward.view(-1, 1), next_state, done.view(-1, 1)

# **REPLAY MEMORY & HYPERPARAMETERS**

In [None]:
# Experience replay: a bounded deque that drops the oldest transitions once full.
replay_buffer = deque(maxlen=cfg.buffer_size)

# Exploration and update constants.
action_sigma = 0.1      # std of Gaussian noise added to the actor's actions while collecting data
policy_noise = 0.2      # std of the smoothing noise added to target-policy actions
noise_clip = 0.5        # the smoothing noise is clipped to [-noise_clip, +noise_clip]
tau = 0.001             # Polyak averaging rate for the target networks
gamma = 0.99            # discount factor in the TD target

# NOTE(review): not referenced in this notebook; the loop advances cfg.global_steps.
global_step = cfg.global_steps

In [None]:
# One AdamW optimizer per network; both critics share the critic learning rate.
critic_optimizer1 = torch.optim.AdamW(params=criticnet1.parameters(), lr=cfg.critic_lr)
critic_optimizer2 = torch.optim.AdamW(params=criticnet2.parameters(), lr=cfg.critic_lr)
actor_optimizer = torch.optim.AdamW(params=actornet.parameters(), lr=cfg.actor_lr)

**W&B RUNS TO LOG THE METRICS**

In [None]:
# Start the W&B run that every runs.log(...) call in the training loop writes to.
runs = wandb_runs(cfg=cfg)

# **Heart of the notebook: the TD3 (Twin Delayed DDPG) training loop**

In [None]:
from tqdm import tqdm

# TD3-style training loop:
#   * data is collected with the actor's action plus Gaussian noise (action_sigma),
#   * twin critics regress onto a clipped double-Q target computed with
#     target-policy smoothing (policy_noise, clipped to +/- noise_clip),
#   * the actor (and its target) is updated once every cfg.actor_freq critic updates,
#   * target networks track the online ones by Polyak averaging at rate tau.
critic_updates = 0      # critic gradient steps taken so far (drives the actor delay)
last_actor_loss = None  # most recent actor loss (detached tensor), for logging

for _ in tqdm(range(cfg.n_rollouts)):

  states = envs.reset()[0]
  statesT = torch.tensor(states, dtype=torch.float32, device=cfg.device)

  training_rewards = 0
  done = False

  while not done:

    # ---- Act ---------------------------------------------------------------
    with torch.no_grad():
      action = actornet(statesT).view(-1)
    action = action + action_sigma * torch.randn_like(action)
    action = torch.clamp(action, -1.0, 1.0).cpu().numpy()

    next_states, rewards, terminated, truncated, _ = envs.step(action)
    done = terminated or truncated
    training_rewards += float(rewards)

    # ---- Store the transition ----------------------------------------------
    next_statesT = torch.tensor(next_states, dtype=torch.float32, device=cfg.device)
    actionT = torch.tensor(action, dtype=torch.float32, device=cfg.device)
    rewardsT = torch.tensor(rewards, dtype=torch.float32, device=cfg.device)
    # Only a true terminal state stops bootstrapping. A time-limit truncation
    # still has a valid successor state, so storing `terminated | truncated`
    # would wrongly zero the TD target at every max_steps boundary.
    doneT = torch.tensor(terminated, dtype=torch.bool, device=cfg.device)

    replay_buffer.append((statesT.detach(), actionT.detach(), rewardsT.detach(), next_statesT.detach(), doneT.detach()))
    statesT = next_statesT

    # ---- Learn -------------------------------------------------------------
    if cfg.global_steps % cfg.training_step == 0 and len(replay_buffer) > cfg.start_training:
      states_b, action_b, reward_b, next_states_b, dones_b = get_batches(replay_buffer, cfg.batch_size)

      # Clipped double-Q target with target-policy smoothing.
      with torch.no_grad():
        next_action_ = TargetActor(next_states_b)
        noise_next_action = torch.clamp(torch.randn_like(next_action_) * policy_noise, -noise_clip, +noise_clip)
        next_action = torch.clamp(next_action_ + noise_next_action, -1.0, 1.0)

        target_next_q1 = TargetCritic1(next_states_b, next_action)
        target_next_q2 = TargetCritic2(next_states_b, next_action)
        target_q = reward_b + gamma * torch.min(target_next_q1, target_next_q2) * (1 - dones_b.float())

      current_q1 = criticnet1(states_b, action_b)
      current_q2 = criticnet2(states_b, action_b)

      critic_loss1 = torch.nn.functional.mse_loss(current_q1, target_q)
      critic_loss2 = torch.nn.functional.mse_loss(current_q2, target_q)

      critic_optimizer1.zero_grad()
      critic_loss1.backward()
      torch.nn.utils.clip_grad_norm_(criticnet1.parameters(), max_norm=1.0)
      critic_optimizer1.step()

      critic_optimizer2.zero_grad()
      critic_loss2.backward()
      torch.nn.utils.clip_grad_norm_(criticnet2.parameters(), max_norm=1.0)
      critic_optimizer2.step()

      critic_updates += 1

      # Delayed policy update, counted in *critic updates*. The previous check
      # `global_steps % actor_freq` was always 0 whenever training ran (both
      # periods are 2), so the actor was never actually delayed. The actor
      # loss is now also only computed when it is used.
      if critic_updates % cfg.actor_freq == 0:
        actor_actions = actornet(states_b)
        min_q = torch.min(criticnet1(states_b, actor_actions), criticnet2(states_b, actor_actions))
        actor_loss = -min_q.mean()

        actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(actornet.parameters(), max_norm=1.0)
        actor_optimizer.step()
        last_actor_loss = actor_loss.detach()

        for target_param, param in zip(TargetActor.parameters(), actornet.parameters()):
          target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

      # Soft-update both target critics after every critic step.
      for target_param, param in zip(TargetCritic1.parameters(), criticnet1.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
      for target_param, param in zip(TargetCritic2.parameters(), criticnet2.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

      # Throttle metric logging: every .item() forces a device sync.
      if cfg.global_steps % cfg.wandb_log_steps == 0:
        advantages = (target_q - torch.min(current_q1, current_q2)).detach().mean()
        metrics = {"critic-loss1": critic_loss1.item(), "critic-loss2": critic_loss2.item(), "advantages": advantages.item()}
        if last_actor_loss is not None:
          metrics["actor-loss"] = last_actor_loss.item()
        runs.log(metrics)

    # ---- Evaluate (only once learning has started) --------------------------
    if cfg.global_steps % cfg.eval_steps == 0 and cfg.global_steps > 0 and len(replay_buffer) > cfg.start_training:
      rec = cfg.global_steps % cfg.record_video == 0
      eval_reward, eval_length = evaluation(actornet, rec)
      runs.log({"eval-reward": eval_reward, "eval-episode-length": eval_length})

    cfg.global_steps += 1

  # One point per finished episode: its undiscounted return (previously a
  # partial running sum was logged on every environment step).
  runs.log({"training-reward": training_rewards, "global-steps": cfg.global_steps, "memory": len(replay_buffer)})

envs.close()
wandb.finish()