In [1]:
import gymnasium as gym
from gymnasium import Env
import torch
from torch import nn
from torch import optim
import random
import numpy as np
import time
from typing import Type

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class EnvGenerator:
    """Factory that builds fresh Gymnasium environments with fixed settings.

    Calling an instance returns a new environment created with ``gym.make``.
    """

    def __init__(self, env_id="Pendulum-v1", render_mode=None, seed=None):
        """Store environment parameters.

        Args:
            env_id: Environment id passed to ``gym.make``.
            render_mode: Default render mode for every environment this factory creates.
            seed: If not None, each new environment is reset with this seed and its
                action space is seeded with it too.
        """
        self.env_id = env_id
        self.render_mode = render_mode
        self.seed = seed

    def __call__(self, render_mode=None):
        """Create and return a new environment instance.

        Args:
            render_mode: When not None, overrides the stored render mode for this
                environment only.

        Returns:
            The environment returned by ``gym.make``.
        """
        mode = self.render_mode if render_mode is None else render_mode
        env = gym.make(self.env_id, render_mode=mode)
        if self.seed is not None:
            env.reset(seed=self.seed)
            # reset(seed=...) does not seed the action space, which has its own RNG.
            # Without this, action_space.sample() (used for warm-up exploration)
            # would not be reproducible even though a seed was requested.
            env.action_space.seed(self.seed)
        return env

class Actor(nn.Module):
    """Deterministic policy network mapping an observation to an action.

    Two hidden layers of 200 ReLU units feed a single tanh output. The output is
    scaled by 2, so actions lie in [-2, 2] (presumably the torque bounds of the
    target environment).
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept fixed (obs_in, hidden, a_out, then model) so that
        # parameter initialisation and the state_dict keys stay the same.
        self.obs_in = nn.Linear(3, 200)   # input: (x, y, ang_vel)
        self.hidden = nn.Linear(200, 200)
        self.a_out = nn.Linear(200, 1)    # output: 1 action (torque)
        stack = [self.obs_in, nn.ReLU(), self.hidden, nn.ReLU(), self.a_out, nn.Tanh()]
        self.model = nn.Sequential(*stack)

    def forward(self, obs):
        """Return the action for ``obs`` (last dimension 3), within [-2, 2]."""
        squashed = self.model(obs)
        return squashed * 2

class Critic(nn.Module):
    """Q-function approximator mapping an (observation, action) pair to a scalar value.

    The observation and action are concatenated along dim 1 and passed through two
    200-unit ReLU layers. The final linear output is left unbounded.
    """

    def __init__(self):
        super().__init__()
        obs_dim, act_dim, width = 3, 1, 200
        # Creation order is kept fixed so parameter initialisation and state_dict keys match.
        self.obs_a_in = nn.Linear(obs_dim + act_dim, width)  # (x, y, ang_vel, torque)
        self.hidden = nn.Linear(width, width)
        self.v_out = nn.Linear(width, 1)  # 1 value
        # No output activation: Q-values are not restricted to [-1, 1].
        self.model = nn.Sequential(self.obs_a_in, nn.ReLU(), self.hidden, nn.ReLU(), self.v_out)

    def forward(self, obs, a):
        """Return Q(obs, a); ``obs`` and ``a`` are batched and concatenated along dim 1."""
        joint = torch.cat((obs, a), dim=1)
        return self.model(joint)

class ReplayBuffer():
    """Fixed-capacity ring buffer of transitions with uniform random sampling."""

    def __init__(self, size):
        self.size = size    # maximum number of transitions retained
        self.buffer = []    # stored transitions; grows until it holds `size` items
        self.pointer = 0    # index of the slot the next transition is written to

    def store(self, transition):
        """Add ``transition``, overwriting the oldest entry once the buffer is full."""
        is_full = len(self.buffer) >= self.size
        if is_full:
            self.buffer[self.pointer] = transition
        else:
            self.buffer.append(transition)
        self.pointer = (self.pointer + 1) % self.size

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        ``random.sample`` raises ValueError when ``batch_size`` exceeds the number stored.
        """
        population = self.buffer
        return random.sample(population, batch_size)

class DDPG():
    """Deep Deterministic Policy Gradient agent with target networks and a replay buffer.

    Transitions from the environment are stored in a ``ReplayBuffer``. The agent
    periodically trains an ``Actor`` (deterministic policy) and a ``Critic``
    (Q-function), then Polyak-averages the corresponding target networks.

    Args:
        env_gen_cls: Zero-argument callable (e.g. ``EnvGenerator``) whose result,
            when called, returns an environment. ``test`` also calls that result
            with a ``render_mode`` keyword.
        hyperparam: Dict with keys ``polyak``, ``n_epoch``, ``steps_per_epoch``,
            ``replay_buffer_size``, ``batch_size``, ``gamma``, ``actor_lr``,
            ``critic_lr``, ``update_after``, ``update_every``, ``add_noise``,
            ``noise_decay`` and ``start_steps``. Optional keys: ``noise_std``
            (initial exploration-noise std, default 0.1) and ``noise_decay_every``
            (step interval between noise decays, default 5000).
    """

    def __init__(self, env_gen_cls: Type, hyperparam):
        self.env_gen = env_gen_cls()
        self.env = self.env_gen()
        self.polyak = hyperparam['polyak']
        self.n_epoch = hyperparam['n_epoch']
        self.steps_per_epoch = hyperparam['steps_per_epoch']
        self.replay_buffer_size = hyperparam['replay_buffer_size']
        self.batch_size = hyperparam['batch_size']
        self.gamma = hyperparam['gamma']          # reward discount factor
        self.actor_lr = hyperparam['actor_lr']    # policy network
        self.critic_lr = hyperparam['critic_lr']  # value network
        self.update_after = hyperparam['update_after']
        self.update_every = hyperparam['update_every']
        # Uncorrelated mean-zero Gaussian noise for exploration
        self.add_noise = hyperparam['add_noise']
        self.noise_decay = hyperparam['noise_decay']  # 1.0 keeps the noise scale constant
        self.noise_std = hyperparam.get('noise_std', 0.1)
        self.noise_decay_every = hyperparam.get('noise_decay_every', 5000)
        # Number of initial steps that use uniformly random actions
        self.start_steps = hyperparam['start_steps']

        # Actor/critic and their targets live on `device`, the same device the
        # training batches are moved to in `update`.
        self.actor = Actor().to(device)
        self.target_actor = Actor().to(device)
        self.critic = Critic().to(device)
        self.target_critic = Critic().to(device)
        # Start the targets as exact copies of the online networks.
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic.load_state_dict(self.critic.state_dict())
        # Target networks are only changed by Polyak averaging, never by optimizers.
        for p in self.target_actor.parameters():
            p.requires_grad = False
        for p in self.target_critic.parameters():
            p.requires_grad = False

        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=self.critic_lr)

        self.replay_buffer = ReplayBuffer(self.replay_buffer_size)

    def _policy_action(self, obs, action_space, noise_std=None):
        """Return the actor's action for ``obs`` as a NumPy array clipped to ``action_space``.

        When ``noise_std`` is not None, zero-mean Gaussian noise with that standard
        deviation is added before clipping.
        """
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
            act = self.actor(obs_t)
            if noise_std is not None:
                act = act + torch.normal(0.0, noise_std, size=act.shape, device=act.device)
            act = act.cpu().numpy()
        return np.clip(act, action_space.low, action_space.high)

    def train(self, save_model=False):
        """Run the training loop for ``n_epoch * steps_per_epoch`` environment steps.

        The first ``start_steps`` steps use uniformly random actions. After that the
        actor is used, with Gaussian exploration noise when ``add_noise`` is set.
        From ``update_after`` onward, every ``update_every`` steps triggers
        ``update_every`` gradient updates. The greedy policy is evaluated with
        ``test`` at the end of each epoch.

        Args:
            save_model: If True, save the actor and critic ``state_dict``s to
                ``param/ddpg_actor_params2.pkl`` and ``param/ddpg_critic_params2.pkl``.
                The ``param`` directory is created if missing.
        """
        import os

        noise_std_dev = self.noise_std
        obs, _ = self.env.reset()
        max_step_count = self.n_epoch * self.steps_per_epoch
        for step_count in range(max_step_count):
            # 1. Select an action: uniform random during warm-up, then the actor (+ noise).
            if step_count < self.start_steps:
                act = self.env.action_space.sample()
            elif self.add_noise:
                act = self._policy_action(obs, self.env.action_space, noise_std_dev)
                if step_count % self.noise_decay_every == 0:
                    noise_std_dev *= self.noise_decay
            else:
                act = self._policy_action(obs, self.env.action_space)

            # 2. Execute the action.
            next_obs, rew, terminated, truncated, _ = self.env.step(act)

            # 3. Store (s, a, r, s', terminal). Only a true termination stops
            # bootstrapping; a time-limit truncation is not a terminal state, so the
            # target must still include gamma * Q(s', mu(s')) for it.
            self.replay_buffer.store((obs, act, rew, next_obs, 1 if terminated else 0))

            # The episode ends on either signal, so reset on both.
            if terminated or truncated:
                obs, _ = self.env.reset()
            else:
                obs = next_obs

            # 4. Periodically run a burst of `update_every` gradient updates, once the
            # buffer holds at least one full batch.
            if (step_count >= self.update_after
                    and step_count % self.update_every == 0
                    and len(self.replay_buffer.buffer) >= self.batch_size):
                for _ in range(self.update_every):
                    batch = self.replay_buffer.sample(self.batch_size)
                    self.update(batch)

            # 5. Evaluate the greedy policy at the end of every epoch.
            if (step_count + 1) % self.steps_per_epoch == 0:
                epoch = (step_count + 1) // self.steps_per_epoch
                test_avg_rew = self.test()
                print(f"====== Epoch {epoch}/{self.n_epoch}, avg reward = {test_avg_rew} ======")

        print("Training done.")
        if save_model:
            os.makedirs('param', exist_ok=True)
            torch.save(self.actor.state_dict(), 'param/ddpg_actor_params2.pkl')
            torch.save(self.critic.state_dict(), 'param/ddpg_critic_params2.pkl')

    def update(self, batch_data):
        """Perform one critic update, one actor update and one Polyak target update.

        Args:
            batch_data: Sequence of ``(obs, act, rew, next_obs, done)`` tuples as
                stored by ``train``.

        Returns:
            ``(critic_loss, actor_loss)`` as Python floats.
        """
        obs_b, act_b, rew_b, next_obs_b, done_b = zip(*batch_data)
        s = torch.tensor(np.array(obs_b), dtype=torch.float, device=device)
        a = torch.tensor(np.array(act_b), dtype=torch.float, device=device).view(-1, 1)
        r = torch.tensor(np.array(rew_b), dtype=torch.float, device=device).view(-1, 1)
        s_prime = torch.tensor(np.array(next_obs_b), dtype=torch.float, device=device)
        done = torch.tensor(np.array(done_b), dtype=torch.float, device=device).view(-1, 1)

        # Critic: regress Q(s, a) onto the Bellman target built from the target networks.
        self.optimizer_critic.zero_grad()
        q_eval = self.critic(s, a)
        with torch.no_grad():
            q_next = self.target_critic(s_prime, self.target_actor(s_prime))
            q_target = r + self.gamma * (1 - done) * q_next
        c_loss = nn.MSELoss()(q_eval, q_target)
        c_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0)
        self.optimizer_critic.step()

        # Actor: gradient ascent on Q(s, mu(s)). The critic is frozen so that no
        # gradients are accumulated for its parameters; try/finally guarantees it
        # is unfrozen even if the actor step raises.
        for p in self.critic.parameters():
            p.requires_grad = False
        try:
            self.optimizer_actor.zero_grad()
            a_loss = -self.critic(s, self.actor(s)).mean()
            a_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=1.0)
            self.optimizer_actor.step()
        finally:
            for p in self.critic.parameters():
                p.requires_grad = True

        # Polyak averaging: target <- polyak * target + (1 - polyak) * online.
        with torch.no_grad():
            for target_param, eval_param in zip(self.target_actor.parameters(), self.actor.parameters()):
                target_param.copy_(self.polyak * target_param + (1 - self.polyak) * eval_param)
            for target_param, eval_param in zip(self.target_critic.parameters(), self.critic.parameters()):
                target_param.copy_(self.polyak * target_param + (1 - self.polyak) * eval_param)

        return c_loss.item(), a_loss.item()

    def load_model(self, actor_path, critic_path):
        """Load actor/critic weights onto ``device`` and re-sync the target networks.

        NOTE: ``torch.load`` unpickles the file; only load checkpoints you trust.
        """
        self.actor.load_state_dict(torch.load(actor_path, map_location=device))
        self.critic.load_state_dict(torch.load(critic_path, map_location=device))
        # Keep the targets consistent with the loaded weights. Otherwise continued
        # training would bootstrap from the previous (e.g. randomly initialised) targets.
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic.load_state_dict(self.critic.state_dict())

    def test(self, verbose=False, render=False):
        """Evaluate the deterministic policy (no exploration noise) over 5 episodes.

        Args:
            verbose: Print every step's observation, action and reward (with a short sleep).
            render: Create the evaluation env with ``render_mode='human'`` and call ``render()``.

        Returns:
            The average undiscounted episode return.
        """
        test_count = 5
        total_reward = 0
        test_env = self.env_gen(render_mode='human' if render else None)
        try:
            for _ in range(test_count):
                obs, _ = test_env.reset()
                done = False
                while not done:
                    act = self._policy_action(obs, test_env.action_space)
                    obs, rew, terminated, truncated, _ = test_env.step(act)
                    done = terminated or truncated
                    total_reward += rew
                    if render:
                        test_env.render()
                    if verbose:
                        print(f"obs = {obs}, act = {act}, rew = {rew}")
                        time.sleep(0.01)
        finally:
            # A fresh env is created per call; close it so renderers/resources are released.
            test_env.close()

        avg_reward = total_reward / test_count
        print(f"Average reward = {avg_reward}")
        return avg_reward

    def print_model(self):
        """Print the actor and critic module structures."""
        print(self.actor)
        print(self.critic)



---

## Initialize the DDPG agent with hyperparameters

In [2]:
agent = DDPG(
    EnvGenerator,
    hyperparam=dict(
        # Schedule: n_epoch * steps_per_epoch environment steps in total.
        n_epoch=40,
        steps_per_epoch=2000,
        # From step update_after on, every update_every steps run update_every gradient updates.
        update_after=1000,
        update_every=50,
        # Target update: target <- polyak * target + (1 - polyak) * online.
        polyak=0.995,
        replay_buffer_size=10000,
        batch_size=32,
        gamma=0.99,
        actor_lr=0.0001,
        critic_lr=0.0001,
        # Gaussian exploration noise; a decay factor of 1.0 keeps its scale constant.
        add_noise=True,
        noise_decay=1.0,
        # Uniformly random actions for the first start_steps steps.
        start_steps=1000,
    ),
)


In [3]:
# Print the actor and critic module structures.
agent.print_model()

Actor(
  (obs_in): Linear(in_features=3, out_features=200, bias=True)
  (hidden): Linear(in_features=200, out_features=200, bias=True)
  (a_out): Linear(in_features=200, out_features=1, bias=True)
  (model): Sequential(
    (0): Linear(in_features=3, out_features=200, bias=True)
    (1): ReLU()
    (2): Linear(in_features=200, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=1, bias=True)
    (5): Tanh()
  )
)
Critic(
  (obs_a_in): Linear(in_features=4, out_features=200, bias=True)
  (hidden): Linear(in_features=200, out_features=200, bias=True)
  (v_out): Linear(in_features=200, out_features=1, bias=True)
  (model): Sequential(
    (0): Linear(in_features=4, out_features=200, bias=True)
    (1): ReLU()
    (2): Linear(in_features=200, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=1, bias=True)
  )
)




---


## Train the agent

In [None]:
# Run n_epoch * steps_per_epoch environment steps, evaluating after every epoch,
# then save the actor/critic weights under param/.
agent.train(save_model=True)




---


## Test the agent

In [None]:
    # Evaluate the current actor over 5 rendered episodes, printing every step.
    # NOTE(review): this cell's code is indented; plain Python would reject the
    # leading indent (IPython presumably strips it) — consider dedenting.
    agent.test(verbose=True, render=True)
