<a href="https://colab.research.google.com/github/Nicolasalan/td3/blob/main/TD3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Setup**

In [45]:
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


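To keep the Colab runtime from disconnecting during long training runs, open the browser's developer tools (`Inspect` => `Console`) and paste the following script. It clicks Colab's Connect button once every 60 seconds: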

```javascript
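function ConnectButton(){
    console.log("Connected");
    // Click Colab's "Connect" button, which lives inside the toolbar's shadow DOM
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click()
}
// Run every 60000 ms (1 minute)
setInterval(ConnectButton,60000);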
```

## **Install**

In [46]:
!sudo apt-get install swig

!pip install gymnasium
!pip install gymnasium[box2d]

# !pip install torch
# !pip install matplotlib
# !pip install numpy
!pip install wandb

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
swig is already the newest version (4.0.2-1ubuntu1).
0 upgraded, 0 newly installed, 0 to remove and 45 not upgraded.


In [47]:
!pip install lightning



## **Import**

In [48]:
try:
    import torch
    assert int(torch.__version__.split(".")[1]) >= 12 or int(torch.__version__.split(".")[0]) == 2, "torch version should be 1.12+"
    print(f"torch version: {torch.__version__}")
except (ImportError, AssertionError):
    print("[INFO] torch version does not meet the requirement (1.12+); please install a newer version.")

torch version: 2.3.0+cu121
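The assertion above only inspects the minor version number (or checks for a 2.x major), so a future 3.x release would be rejected even though it satisfies "1.12+". A minimal sketch of a sturdier check, assuming the version string starts with numeric `major.minor.patch` fields and that suffixes such as `+cu121` can be ignored. `parse_version` and `meets_minimum` are hypothetical helpers, not part of PyTorch; the `packaging` library's `Version` class is a more complete alternative.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn e.g. '2.3.0+cu121' into (2, 3, 0).

    Hypothetical helper: keeps only the leading dotted run of integers,
    so local suffixes ('+cu121') and pre-release tags ('a0', '.dev...') are ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(version: str, minimum: Tuple[int, ...] = (1, 12)) -> bool:
    """Compare version tuples lexicographically, so (2, 3) >= (1, 12) holds."""
    return parse_version(version)[: len(minimum)] >= minimum


print(parse_version("2.3.0+cu121"))  # (2, 3, 0)
print(meets_minimum("2.3.0+cu121"))  # True
print(meets_minimum("1.11.0"))       # False
```

Comparing tuples avoids the special case for major version 2 and keeps working for later major releases.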


In [49]:
# Make sure we're using an NVIDIA GPU
if torch.cuda.is_available():
  gpu_info = !nvidia-smi
  gpu_info = '\n'.join(gpu_info)
  if gpu_info.find("failed") >= 0:
    print("Not connected to a GPU, to leverage the best of PyTorch 2.0, you should connect to a GPU.")

  # Get GPU name
  gpu_name = !nvidia-smi --query-gpu=gpu_name --format=csv
  gpu_name = gpu_name[1]
  GPU_NAME = gpu_name.replace(" ", "_") # replace spaces with underscores for easier saving
  print(f'GPU name: {GPU_NAME}')

  # Get GPU capability score
  GPU_SCORE = torch.cuda.get_device_capability()
  print(f"GPU capability score: {GPU_SCORE}")
  if GPU_SCORE >= (8, 0):
    print(f"GPU score higher than or equal to (8, 0), PyTorch 2.x speedup features available.")
  else:
    print(f"GPU score lower than (8, 0), PyTorch 2.x speedup features will be limited (PyTorch 2.x speedups happen most on newer GPUs).")

  # Print GPU info
  print(f"GPU information:\n{gpu_info}")

else:
  print("PyTorch couldn't find a GPU, to leverage the best of PyTorch 2.0, you should connect to a GPU.")

PyTorch couldn't find a GPU, to leverage the best of PyTorch 2.0, you should connect to a GPU.


In [50]:
# Check available GPU memory and total GPU memory
try:
  total_free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
  print(f"Total free GPU memory: {round(total_free_gpu_memory * 1e-9, 3)} GB")
  print(f"Total GPU memory: {round(total_gpu_memory * 1e-9, 3)} GB")
except:
  print("Please check that you have an NVIDIA GPU and installed a driver from ")

Please check that you have an NVIDIA GPU and installed a driver from 


In [51]:
# Set batch size depending on amount of GPU memory
try:
  total_free_gpu_memory_gb = round(total_free_gpu_memory * 1e-9, 3)
  if total_free_gpu_memory_gb >= 16:
    BATCH_SIZE = 128 # Note: you could experiment with higher values here if you like.
    print(f"GPU memory available is {total_free_gpu_memory_gb} GB, using batch size of {BATCH_SIZE}")
  else:
    BATCH_SIZE = 32
    print(f"GPU memory available is {total_free_gpu_memory_gb} GB, using batch size of {BATCH_SIZE}")
except:
  BATCH_SIZE = 32

In [52]:
BUFFER_SIZE = int(1e5)  # replay buffer size
BATCH_SIZE = 100        # minibatch size (overrides the GPU-based value chosen in the previous cell)
GAMMA = 0.99            # discount factor
TAU = 1e-3              # for soft update of target parameters
LR_ACTOR = 1e-3         # learning rate of the actor
LR_CRITIC = 1e-3        # learning rate of the critic
UPDATE_EVERY_STEP = 2   # how often to update the target and actor networks
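`TAU` controls the soft (Polyak) target update performed later by `soft_update`, θ_target = τ*θ_local + (1 - τ)*θ_target. A minimal pure-Python sketch on plain lists (an illustration only, not the notebook's tensor version) shows how slowly the target moves with `TAU = 1e-3`: after n updates toward a fixed local value, a fraction (1 - τ)^n of the original gap remains, so roughly 693 updates are needed to halve it.

```python
from typing import List


def soft_update(local: List[float], target: List[float], tau: float) -> List[float]:
    """Return tau * local + (1 - tau) * target, element by element."""
    return [tau * local_p + (1.0 - tau) * target_p for local_p, target_p in zip(local, target)]


TAU = 1e-3
local = [1.0]   # the online network's (fixed) parameter
target = [0.0]  # the target network starts far away

for _ in range(693):
    target = soft_update(local, target, TAU)

# The remaining gap is (1 - TAU) ** 693, i.e. about half of the original gap
print(round(1.0 - target[0], 3))  # 0.5
```

A small τ keeps the targets in the Bellman backup slowly varying, which is the stabilizing effect the target networks are there for.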

## **Replay**

In [70]:
from collections import namedtuple, deque
from typing import Tuple
import random
import numpy as np
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ReplayBuffer:

    def __init__(self, buffer_size: int, batch_size: int):

        self.memory = deque(maxlen=buffer_size)  # internal memory (deque)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])

    def add(self, state: np.ndarray, action: np.ndarray, reward: np.float64, next_state: np.ndarray, done: bool) -> None:
        """Add an experience to the buffer.

        Params
        ======
            state (np.ndarray): agent state
            action (np.ndarray): agent action
            reward (np.float64): agent reward
            next_state (np.ndarray): agent next state
            done (bool): whether the episode ended after this step
        """

        assert isinstance(state, np.ndarray), "State is not of data structure (np.ndarray) in REPLAY BUFFER -> state: {}.".format(type(state))
        assert isinstance(action, np.ndarray), "Action is not of data structure (np.ndarray) in REPLAY BUFFER -> action: {}.".format(type(action))
        assert isinstance(next_state, np.ndarray), "Next State is not of data structure (np.ndarray) in REPLAY BUFFER -> next state: {}.".format(type(next_state))

        assert isinstance(state[0], np.float32), "State is not of type (np.float32) in REPLAY BUFFER -> state type: {}.".format(type(state[0]))
        assert isinstance(action[0], np.float32), "Action is not of type (np.float32) in REPLAY BUFFER -> action type: {}.".format(type(action[0]))
        assert isinstance(reward, (int, np.float64)), "Reward is not of type (np.float64 / int) in REPLAY BUFFER -> reward: {}.".format(type(reward))
        assert isinstance(next_state[0], np.float32), "Next State is not of type (np.float32) in REPLAY BUFFER -> next state type: {}.".format(type(next_state[0]))
        assert isinstance(done, bool), "Done is not of type (bool) in REPLAY BUFFER -> done type: {}.".format(type(done))

        assert state.shape[0] == 24, "The size of the state is not (24) in REPLAY BUFFER -> state size: {}.".format(state.shape[0])
        assert action.shape[0] == 4, "The size of the action is not (4) in REPLAY BUFFER -> action size: {}.".format(action.shape[0])
        if isinstance(reward, np.float64):
          assert reward.size == 1, "The size of the reward is not (1) in REPLAY BUFFER -> reward size: {}.".format(reward.size)
        assert next_state.shape[0] == 24, "The size of the next_state is not (24) in REPLAY BUFFER -> next_state size: {}.".format(next_state.shape[0])

        assert state.ndim == 1, "The ndim of the state is not (1) in REPLAY BUFFER -> state ndim: {}.".format(state.ndim)
        assert action.ndim == 1, "The ndim of the action is not (1) in REPLAY BUFFER -> action ndim: {}.".format(action.ndim)
        if isinstance(reward, np.float64):
          assert reward.ndim == 0, "The ndim of the reward is not (0) in REPLAY BUFFER -> reward ndim: {}.".format(reward.ndim)
        assert next_state.ndim == 1, "The ndim of the next_state is not (1) in REPLAY BUFFER -> next_state ndim: {}.".format(next_state.ndim)

        # Add a new experience to memory.
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Randomly sample a batch of experiences from memory."""
        experiences = random.sample(self.memory, k=self.batch_size)

        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(DEVICE)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(DEVICE)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(DEVICE)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None])).int().to(DEVICE)

        assert isinstance(states, torch.Tensor), "State is not of type torch.Tensor in REPLAY BUFFER."
        assert isinstance(actions, torch.Tensor), "Actions is not of type torch.Tensor in REPLAY BUFFER."
        assert isinstance(rewards, torch.Tensor), "Rewards is not of type torch.Tensor in REPLAY BUFFER."
        assert isinstance(next_states, torch.Tensor), "Next states is not of type torch.Tensor in REPLAY BUFFER."
        assert isinstance(dones, torch.Tensor), "Dones is not of type torch.Tensor in REPLAY BUFFER."

        assert states.dtype == torch.float32, "The (state) tensor elements are not of type torch.float32 in the REPLAY BUFFER -> {}.".format(states.dtype)
        assert actions.dtype == torch.float32,"The (actions) tensor elements are not of type torch.float32 in the REPLAY BUFFER -> {}.".format(actions.dtype)
        assert rewards.dtype == torch.float32, "The (rewards) tensor elements are not of type torch.float32 in the REPLAY BUFFER -> {}.".format(rewards.dtype)
        assert next_states.dtype == torch.float32, "The (next_states) tensor elements are not of type torch.float32 in the REPLAY BUFFER -> {}.".format(next_states.dtype)
        assert dones.dtype == torch.int, "The (dones) tensor elements are not of type torch.int32 in the REPLAY BUFFER -> {}.".format(dones.dtype)

        # TODO
        # assert all(tensor.device.type == DEVICE for tensor in [states, actions, rewards, next_states, dones]), "Each tensor must be on the same device in REPLAY BUFFER"

        return (
            states, actions, rewards, next_states, dones
        )

    def __len__(self) -> int:
        """Return the current size of internal memory."""
        return len(self.memory)
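`sample` draws `batch_size` experiences uniformly at random without replacement (`random.sample`) and stacks each field into its own tensor, while `deque(maxlen=buffer_size)` silently drops the oldest experience once the buffer is full. A minimal pure-Python sketch of just that bookkeeping, using a hypothetical `TinyReplayBuffer` with no NumPy or PyTorch conversion:

```python
import random
from collections import deque, namedtuple
from typing import List, Tuple

Experience = namedtuple("Experience", ["state", "action", "reward", "next_state", "done"])


class TinyReplayBuffer:
    """Hypothetical, dependency-free stand-in for ReplayBuffer (no tensors)."""

    def __init__(self, buffer_size: int, batch_size: int):
        self.memory = deque(maxlen=buffer_size)  # oldest entries are evicted first
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done) -> None:
        self.memory.append(Experience(state, action, reward, next_state, done))

    def sample(self) -> Tuple[List, ...]:
        # Uniform sampling without replacement; requires len(memory) >= batch_size
        experiences = random.sample(self.memory, k=self.batch_size)
        # Transpose rows (one per experience) into columns (one per field)
        return tuple(list(column) for column in zip(*experiences))

    def __len__(self) -> int:
        return len(self.memory)


buffer = TinyReplayBuffer(buffer_size=3, batch_size=2)
for t in range(5):
    buffer.add(state=t, action=-t, reward=float(t), next_state=t + 1, done=(t == 4))

print(len(buffer))                              # 3: steps 0 and 1 were evicted
print(sorted(e.state for e in buffer.memory))   # [2, 3, 4]
states, actions, rewards, next_states, dones = buffer.sample()
print(len(states))                              # 2
```

Note that `random.sample` raises `ValueError` if fewer than `batch_size` experiences are stored, which is why training waits until the buffer holds more than `BATCH_SIZE` entries.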

In [71]:
from torch.utils.data.dataset import IterableDataset
from typing import Iterator

class RLDataset(IterableDataset):
    """Iterable Dataset containing the ReplayBuffer, which is updated with new experiences during training.

    >>> RLDataset(ReplayBuffer(5, 2))  # doctest: +ELLIPSIS
    <...RLDataset object at ...>

    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        """
        Args:
            buffer: replay buffer
            sample_size: number of experiences to sample at a time (currently unused;
                ReplayBuffer.sample() always returns `batch_size` experiences)
        """
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator:
        # Unpack in the order returned by ReplayBuffer.sample()
        states, actions, rewards, next_states, dones = self.buffer.sample()
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], next_states[i], dones[i]
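Because `RLDataset.__iter__` unpacks the tuple returned by `sample()`, the names must follow the buffer's order `(states, actions, rewards, next_states, dones)`; each iteration yields one transition and the `DataLoader` stacks them back into a batch. One `__iter__` call yields `batch_size` transitions (not `sample_size`), so with both sizes at 100 each epoch contains a single batch. A minimal pure-Python sketch of that round trip, assuming plain lists in place of tensors; `iterate_transitions` and `collate` are hypothetical helpers that only mimic `__iter__` and the default per-field batching:

```python
from typing import Iterator, List, Sequence, Tuple


def iterate_transitions(columns: Sequence[List]) -> Iterator[Tuple]:
    """Mimic RLDataset.__iter__: walk the sampled columns index by index."""
    states, actions, rewards, next_states, dones = columns
    for i in range(len(dones)):
        yield states[i], actions[i], rewards[i], next_states[i], dones[i]


def collate(transitions: List[Tuple]) -> Tuple[List, ...]:
    """Mimic default batching: regroup per-transition tuples by field."""
    return tuple(list(field) for field in zip(*transitions))


sampled = ([0, 1], [10, 11], [0.5, 1.5], [1, 2], [False, True])
batch = collate(list(iterate_transitions(sampled)))
print(batch == sampled)  # True: the field order survives the round trip
```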


## **Model**

In [55]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

def hidden_init(layer):
    # nn.Linear stores its weight as (out_features, in_features), so the fan-in is size(1)
    fan_in = layer.weight.data.size()[1]
    lim = 1. / np.sqrt(fan_in)
    return (-lim, lim)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Actor(nn.Module):

    def __init__(self, state_size: int, action_size: int, max_action: float, l1=400, l2=300):
        super(Actor, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(state_size, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.ReLU(),
            nn.Linear(l2, action_size),
            nn.Tanh()
        )
        self.reset_parameters()
        self.max_action = max_action

    def reset_parameters(self):
        self.net[0].weight.data.uniform_(*hidden_init(self.net[0]))
        self.net[2].weight.data.uniform_(*hidden_init(self.net[2]))
        self.net[4].weight.data.uniform_(-3e-3, 3e-3)

    def forward(self, state) -> torch.Tensor:
        assert isinstance(state, torch.Tensor), "State is not of type torch.Tensor in ACTOR."
        assert state.dtype == torch.float32, "Tensor elements are not of type torch.float32 in ACTOR."
        assert state.shape[0] <= 24 or state.shape[0] >= BATCH_SIZE, "The tensor shape is not torch.Size([24]) in ACTOR."
        assert str(state.device.type) == str(DEVICE), "The state must be on the same device in ACTOR."

        # x = self.net(state)
        # action = self.max_action * x
        return self.net(state)


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, l1=400, l2=300):
        super(Critic, self).__init__()

        # Critic Q1
        self.net1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.ReLU(),
            nn.Linear(l2, 1)
        )

        # Critic Q2
        self.net2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.ReLU(),
            nn.Linear(l2, 1)
        )

        self.reset_parameters()

    def reset_parameters(self):
        self.net1[0].weight.data.uniform_(*hidden_init(self.net1[0]))
        self.net1[2].weight.data.uniform_(*hidden_init(self.net1[2]))
        self.net1[4].weight.data.uniform_(-3e-3, 3e-3)

        self.net2[0].weight.data.uniform_(*hidden_init(self.net2[0]))
        self.net2[2].weight.data.uniform_(*hidden_init(self.net2[2]))
        self.net2[4].weight.data.uniform_(-3e-3, 3e-3)

    def forward(self, state, action) -> Tuple[torch.Tensor, torch.Tensor]:
        assert isinstance(state, torch.Tensor), "State is not of type torch.Tensor in CRITIC."
        assert state.dtype == torch.float32, "Tensor elements are not of type torch.float32 in CRITIC."
        assert state.shape[0] == BATCH_SIZE, "The tensor shape is not torch.Size([100]) in CRITIC."
        assert str(state.device.type) == str(DEVICE), "The state must be on the same device in CRITIC."

        assert isinstance(action, torch.Tensor), "Action is not of type torch.Tensor in CRITIC."
        assert action.dtype == torch.float32, "Tensor elements are not of type torch.float32 in CRITIC."
        assert action.shape[0] == BATCH_SIZE, "The action shape is not torch.Size([100]) in CRITIC."
        assert str(action.device.type) == str(DEVICE), "The action must be on the same device in CRITIC."

        sa = torch.cat([state, action], dim=1)

        return self.net1(sa), self.net2(sa)

    def Q1(self, state, action) -> torch.Tensor:
        assert isinstance(state, torch.Tensor), "State is not of type torch.Tensor in CRITIC."
        assert state.dtype == torch.float32, "Tensor elements are not of type torch.float32 in CRITIC."
        assert state.shape[0] == BATCH_SIZE, "The tensor shape is not torch.Size([100]) in CRITIC."
        assert str(state.device.type) == str(DEVICE), "The state must be on the same device in CRITIC."

        assert isinstance(action, torch.Tensor), "Action is not of type torch.Tensor in CRITIC."
        assert action.dtype == torch.float32, "Tensor elements are not of type torch.float32 in CRITIC."
        assert action.shape[0] == BATCH_SIZE, "The action shape is not torch.Size([100]) in CRITIC."
        assert str(action.device.type) == str(DEVICE), "The action must be on the same device in CRITIC."

        sa = torch.cat([state, action], dim=1)

        return self.net1(sa)
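`hidden_init` follows the DDPG-style fan-in initialization: hidden-layer weights are drawn from U(-1/sqrt(f), 1/sqrt(f)), where f is the layer's fan-in, while the output layer uses U(-3e-3, 3e-3) to keep initial outputs near zero. Since `nn.Linear` stores its weight as `(out_features, in_features)`, the fan-in is `weight.size(1)`; reading `size(0)` would give the first actor layer 1/sqrt(400) = 0.05 instead. A small pure-Python sketch computing the bounds for the actor's hidden layers, assuming the BipedalWalker dimensions used in this notebook (24 -> 400 -> 300):

```python
import math
from typing import List, Tuple


def fan_in_limit(in_features: int) -> float:
    """Half-width of the uniform init range: 1 / sqrt(fan_in)."""
    return 1.0 / math.sqrt(in_features)


# (in_features, out_features) for the actor's hidden layers: 24 -> 400 -> 300
layers: List[Tuple[int, int]] = [(24, 400), (400, 300)]

for in_features, out_features in layers:
    lim = fan_in_limit(in_features)
    print(f"Linear({in_features}, {out_features}): U(-{lim:.4f}, {lim:.4f})")
# Linear(24, 400): U(-0.2041, 0.2041)
# Linear(400, 300): U(-0.0500, 0.0500)
```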

## **Wrapper**

In [56]:
import gym
import wandb
from typing import Union
from gym import spaces
from gym.spaces import Box

class CustomWrapper(gym.Wrapper):

    def __init__(
        self,
        env: gym.Env,
        min_action: Union[float, int, np.ndarray],
        max_action: Union[float, int, np.ndarray],
    ):
        """Initializes the wrapper (action rescaling adapted from gym's :class:`RescaleAction`; observations are truncated to 24 dimensions).
        Args:
            env (Env): The environment to apply the wrapper
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
        """
        assert isinstance(
            env.action_space, spaces.Box
        ), f"expected Box action space, got {type(env.action_space)}"
        assert np.less_equal(min_action, max_action).all(), (min_action, max_action)

        super().__init__(env)
        self.min_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
        )
        self.max_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
        )
        self.action_space = spaces.Box(
            low=min_action,
            high=max_action,
            shape=env.action_space.shape,
            dtype=env.action_space.dtype,
        )
        low = self.observation_space.low[:24]
        high = self.observation_space.high[:24]
        self.observation_space = Box(low, high, dtype=np.float32)

    def step(self, action):
        obs, reward, terminated, info = self.env.step(action)
        obs = obs[:24]
        return obs, reward, terminated, info

    def reset(self):
        obs = self.env.reset()
        obs = obs[:24]
        return obs

    def action(self, action):
        """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.
        Note: unlike gym.ActionWrapper, gym.Wrapper does not call this method automatically in step().
        Args:
            action: The action to rescale
        Returns:
            The rescaled action
        """
        assert np.all(np.greater_equal(action, self.min_action)), (
            action,
            self.min_action,
        )
        assert np.all(np.less_equal(action, self.max_action)), (action, self.max_action)
        low = self.env.action_space.low
        high = self.env.action_space.high
        action = low + (high - low) * (
            (action - self.min_action) / (self.max_action - self.min_action)
        )
        action = np.clip(action, low, high)
        return action

    def seed(self, seed):
        torch.manual_seed(seed)
        np.random.seed(seed)

# gym.logger.set_level(40)
# env = CustomWrapper(gym.make("BipedalWalker-v3"),  min_action = -1.0,  max_action = 1.0)
# env.seed(0)

# agent = TD3Agent(state_size=env.observation_space.shape[0], \
#                  action_size=env.action_space.shape[0], \
#                  max_action=env.action_space.high, \
#                  min_action=env.action_space.low, continue_training=False)
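`CustomWrapper.action` maps an action affinely from `[min_action, max_action]` onto the base environment's `[low, high]`, a' = low + (high - low) * (a - min_action) / (max_action - min_action), and then clips. A minimal element-wise sketch in pure Python; `rescale` is a hypothetical helper and the bounds in the usage line are made up for illustration. BipedalWalker-v3 itself uses [-1, 1] on all four joints, so with `min_action=-1.0, max_action=1.0` the mapping is the identity in this notebook.

```python
from typing import List


def rescale(action: List[float], min_action: float, max_action: float,
            low: List[float], high: List[float]) -> List[float]:
    """Affinely map each component from [min_action, max_action] to [low, high], then clip."""
    out = []
    for a, lo, hi in zip(action, low, high):
        scaled = lo + (hi - lo) * ((a - min_action) / (max_action - min_action))
        out.append(min(max(scaled, lo), hi))  # clip to the base env's bounds
    return out


# Illustrative bounds: map [-1, 1] onto [0, 10] and onto [-2, 2]
print(rescale([0.0, 1.0], -1.0, 1.0, low=[0.0, -2.0], high=[10.0, 2.0]))  # [5.0, 2.0]
```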


## **Agent**

In [72]:
from typing import Tuple

import torch
import copy
import numpy as np
import torch.optim as optim

from numpy import inf

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Agent():
    """Interacts with and learns from the environment."""

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer, max_action, min_action) -> None:
        """Initialize an Agent object.

        Params
        ======
            env (gym.Env): the environment to interact with
            replay_buffer (ReplayBuffer): buffer that stores the collected experiences
            max_action (ndarray): the maximum valid value for each action vector
            min_action (ndarray): the minimum valid value for each action vector
        """
        self.env = env
        self.reset()  # also sets self.state

        self.max_action = max_action
        self.min_action = min_action

        # Replay memory
        self.memory = replay_buffer

    def action(self, actor: nn.Module, device=None) -> np.ndarray:
        """Returns actions for the current state as per the current policy."""
        # Default to the device the actor's weights live on
        if device is None:
            device = next(actor.parameters()).device

        # Add a batch dimension: shape (1, state_size)
        state = torch.as_tensor(self.state, dtype=torch.float32, device=device).unsqueeze(0)

        action = actor(state).detach().cpu().numpy()

        # Keep every action component inside the valid range
        action = action.clip(self.min_action[0], self.max_action[0])

        return action

    def reset(self) -> None:
        """Resets the environment and updates the state."""
        self.state = self.env.reset()

    @torch.no_grad()
    def step(self, actor: nn.Module, device=None) -> Tuple[float, bool]:
        """Takes one action in the environment and stores the transition in the replay buffer."""
        action = self.action(actor, device)

        next_state, reward, done, _ = self.env.step(action[0])

        self.memory.add(self.state, action[0], reward, next_state, done)

        self.state = next_state

        if done:
            self.reset()
        return reward, done

env = CustomWrapper(gym.make("BipedalWalker-v3"),  min_action = -1.0,  max_action = 1.0)
memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE)
agent = Agent(env=env, replay_buffer=memory, max_action=env.action_space.high, min_action=env.action_space.low)
actor = Actor(env.observation_space.shape[0], env.action_space.shape[0], float(env.action_space.high[0])).to(device)

for i in range(10):
  agent.step(actor)
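TD3 usually collects experience with a noisy version of its deterministic policy, a = clip(pi(s) + eps, low, high) with eps ~ N(0, sigma^2). `Agent.action` above clips the actor's output but adds no noise, and `NOISE_STD = 0.1`, defined in the Lightning cell, is not used anywhere. A minimal pure-Python sketch of Gaussian exploration noise, assuming sigma = 0.1 and the [-1, 1] action bounds; `explore` is a hypothetical helper and is not wired into `Agent`:

```python
import random
from typing import List, Optional


def explore(policy_action: List[float], sigma: float = 0.1,
            low: float = -1.0, high: float = 1.0,
            rng: Optional[random.Random] = None) -> List[float]:
    """Add N(0, sigma^2) noise to each action component and clip to [low, high]."""
    rng = rng or random.Random()
    return [min(max(a + rng.gauss(0.0, sigma), low), high) for a in policy_action]


rng = random.Random(0)  # seeded for reproducibility
noisy = explore([0.0, 0.95, -1.0, 0.3], sigma=0.1, rng=rng)
print(noisy)  # every component stays inside [-1, 1]
```

Clipping after adding the noise matters at the bounds: an action already at -1.0 can only be pushed inward.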


In [79]:
from typing import List, Tuple
from lightning.pytorch import LightningModule
from collections import OrderedDict, deque, namedtuple

import torch
import copy
import numpy as np
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger

from numpy import inf

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NOISE = 0.2
NOISE_STD = 0.1
NOISE_CLIP = 0.5

class TD3Lightning(LightningModule):
    """Interacts with and learns from the environment."""

    def __init__(self,
                 env,
                 state_size: int,
                 action_size: int,
                 max_action: np.ndarray,
                 min_action: np.ndarray,
                 sync_rate: int = 10,
                 lr: float = 1e-2,
                 batch_size: int = 100,
                 episode_length: int = 50,
                 warm_start_steps: int = 200):
        """Initialize a TD3 LightningModule.

        Params
        ======
            env (gym.Env): the environment to interact with
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            max_action (ndarray): the maximum valid value for each action vector
            min_action (ndarray): the minimum valid value for each action vector
            sync_rate (int): currently unused
            lr (float): not used by the optimizers that perform the updates (they use LR_ACTOR and LR_CRITIC)
            batch_size (int): minibatch size used by the DataLoader
            episode_length (int): passed to RLDataset as `sample_size`
            warm_start_steps (int): steps taken to fill the replay buffer before training
        """
        super(TD3Lightning, self).__init__()
        self.warm_start_steps = warm_start_steps

        self.state_size = state_size
        self.action_size = action_size
        self.max_action = max_action
        self.min_action = min_action
        self.env = env
        self.nb_optim_iters = 4
        self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE)
        self.agent = Agent(self.env, self.memory, self.max_action, self.min_action)

        self.total_reward = 0
        self.episode_reward = 0
        self.lr = lr
        self.sync_rate = sync_rate
        self.batch_size = batch_size
        self.episode_length = episode_length

        # Set the device globally
        torch.set_default_device(device)

        # Actor and target actor (the target starts as an exact copy of the actor)
        self.actor = Actor(state_size, action_size, float(max_action[0])).to(device)
        self.actor_target = Actor(state_size, action_size, float(max_action[0])).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=LR_ACTOR)

        self.critic = Critic(state_size, action_size).to(device)
        self.critic_target = Critic(state_size, action_size).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=LR_CRITIC)

        self.populate(self.warm_start_steps)

        self.automatic_optimization = False


    def populate(self, steps: int = 1000) -> None:
        """Carries out several steps with the current (untrained) actor to initially fill up the replay buffer with
        experiences.

        Args:
            steps: number of steps to populate the buffer with

        """
        for i in range(steps):
            self.agent.step(self.actor)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Passes a batch of states through the actor and both critics.
        Args:
            state: batch of environment states
            action: unused; the action is recomputed by the actor
        Returns:
            Tuple of (scaled action, Q1 estimate, Q2 estimate)
        """
        action = self.actor(state)
        action = action * float(self.max_action[0])
        Q1_expected, Q2_expected = self.critic(state, action)

        return action, Q1_expected, Q2_expected

    def actor_loss(self, state, action) -> torch.Tensor:
        actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

        return actor_loss

    def critic_loss(self, Q1_expected, Q2_expected, Q_targets) -> torch.Tensor:

        critic_loss = F.mse_loss(Q1_expected, Q_targets) + F.mse_loss(Q2_expected, Q_targets)

        return critic_loss

    def training_step(self, batch: Tuple[torch.Tensor, ...], batch_idx):
        """ Update policy and value parameters using a batch of experience tuples.

        Params
        ======
            batch (tuple): (state, action, reward, next_state, done) tensors
            batch_idx (int): index of the batch within the current epoch
        """

        # Collect one new transition with the current actor so the buffer keeps growing
        env_reward, env_done = self.agent.step(self.actor)
        self.episode_reward += env_reward
        if env_done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0
        self.log("total_reward", float(self.total_reward))

        if len(self.memory) > BATCH_SIZE:
            state, action, reward, next_state, done = batch

            # Count critic updates to schedule the delayed actor/target updates
            self.total_it = getattr(self, "total_it", 0) + 1

            # ---------------------------- update critic ---------------------------- #
            # Get predicted next-state actions and Q values from target models
            with torch.no_grad():
                # Target policy smoothing: clipped Gaussian noise on the target action
                noise = (torch.randn_like(action) * NOISE).clamp(-NOISE_CLIP, NOISE_CLIP)
                actions_next = (self.actor_target(next_state) + noise).clamp(
                    float(self.min_action[0]), float(self.max_action[0]))

                Q1_targets_next, Q2_targets_next = self.critic_target(next_state, actions_next)

                # Clipped double-Q: use the smaller of the two target estimates
                Q_targets_next = torch.min(Q1_targets_next, Q2_targets_next)

                # Compute Q targets for current states (y_i)
                Q_targets = reward + GAMMA * Q_targets_next * (1 - done)

            # Compute critic loss
            Q1_expected, Q2_expected = self.critic(state, action)

            self.critic_optimizer.zero_grad()
            critic_loss = self.critic_loss(Q1_expected, Q2_expected, Q_targets)

            self.manual_backward(critic_loss)
            self.critic_optimizer.step()
            self.log("critic_loss", critic_loss)

            if self.total_it % UPDATE_EVERY_STEP == 0:
                # ---------------------------- update actor ---------------------------- #
                # Compute actor loss
                self.actor_optimizer.zero_grad()
                actor_loss = self.actor_loss(state, action)

                self.manual_backward(actor_loss)
                self.actor_optimizer.step()
                self.log("actor_loss", actor_loss)

                # ----------------------- update target networks ----------------------- #
                self.soft_update(self.critic, self.critic_target, TAU)
                self.soft_update(self.actor, self.actor_target, TAU)


    def soft_update(self, local_model, target_model, tau) -> None:
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)

    def configure_optimizers(self) -> List[Optimizer]:
        """Return the Adam optimizers created in __init__ (updates are applied manually in training_step)."""
        return [self.actor_optimizer, self.critic_optimizer]


    def optimizer_step(self, *args, **kwargs):
        """
        Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic
        for each data sample.
        """
        for i in range(self.nb_optim_iters):
            super().optimizer_step(*args, **kwargs)

    def save(self, filename, version) -> None:
        """ Save the model and optimizer states """
        torch.save(self.critic.state_dict(), filename + "_critic_" + version + ".pth")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer_" + version + ".pth")

        torch.save(self.actor.state_dict(), filename + "_actor_" + version + ".pth")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer_" + version + ".pth")

    def load(self, filename, version) -> None:
        """ Load the model and optimizer states written by `save` with the same filename and version """
        self.critic.load_state_dict(torch.load(filename + "_critic_" + version + ".pth"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer_" + version + ".pth"))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(filename + "_actor_" + version + ".pth"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer_" + version + ".pth"))
        self.actor_target = copy.deepcopy(self.actor)

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        dataset = RLDataset(self.memory, self.episode_length)
        return DataLoader(dataset=dataset, batch_size=self.batch_size, sampler=None)

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch."""
        return batch[0].device.index if self.on_gpu else "cpu"


from lightning.pytorch import Trainer, cli_lightning_logo, seed_everything

def main() -> None:
    env = CustomWrapper(gym.make("BipedalWalker-v3"),  min_action = -1.0,  max_action = 1.0)
    model = TD3Lightning(env=env,
                         state_size=env.observation_space.shape[0],
                         action_size=env.action_space.shape[0],
                         max_action=env.action_space.high,
                         min_action=env.action_space.low,
                         sync_rate=10, lr=1e-2,
                         episode_length=200,
                         batch_size=100,
                         warm_start_steps=1000)
    trainer = Trainer(accelerator="cpu", devices=1, val_check_interval=100, max_epochs=1000, logger=WandbLogger(log_model="all"))
    trainer.fit(model)

main()

INFO: GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs


wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


INFO: 
  | Name          | Type   | Params
-----------------------------------------
0 | actor         | Actor  | 131 K 
1 | actor_target  | Actor  | 131 K 
2 | critic        | Critic | 264 K 
3 | critic_target | Critic | 264 K 
-----------------------------------------
791 K     Trainable params
0         Non-trainable params
791 K     Total params
3.167     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1000` reached.
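The model summary can be checked by hand. The layer sizes are not shown in this section, so the sketch below assumes the architecture from the original TD3 paper: an actor 24 → 400 → 300 → 4 and twin critics, each (24 + 4) → 400 → 300 → 1, for BipedalWalker-v3's 24-dimensional observation and 4-dimensional action. Under that assumption the counts reproduce the table above: 131 K per actor, 264 K per critic pair, 791 K in total (online plus target copies), and 3.167 MB at 4 bytes per float32 parameter.

```python
# Assumed layer sizes (TD3 paper setup, not confirmed by this notebook):
#   Actor:  state -> 400 -> 300 -> action
#   Critic: two Q-heads, each (state + action) -> 400 -> 300 -> 1


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def mlp_params(sizes: list) -> int:
    """Total parameters of a stack of Linear layers with the given widths."""
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))


STATE, ACTION = 24, 4  # BipedalWalker-v3 observation and action sizes
HIDDEN = [400, 300]    # assumption

actor = mlp_params([STATE, *HIDDEN, ACTION])
critic = 2 * mlp_params([STATE + ACTION, *HIDDEN, 1])  # twin critics
total = 2 * actor + 2 * critic  # online + target networks
size_mb = total * 4 / 1e6       # float32 = 4 bytes per parameter

print(actor, critic, total, round(size_mb, 3))  # prints: 131504 264402 791812 3.167
```

Lightning truncates these counts to whole thousands, which is why 791,812 appears as "791 K".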


## **Train**

In [None]:
import wandb

wandb.init(project="td3")  # credentials are read from ~/.netrc or the WANDB_API_KEY environment variable

In [None]:
import time
from collections import deque

import numpy as np

def td3(n_episodes=2000, max_t=2000):
    scores_deque = deque(maxlen=100)  # scores of the last 100 episodes
    times_deque = deque(maxlen=100)   # durations (s) of the last 100 episodes
    scores = []
    solved = False
    for i_episode in range(1, n_episodes + 1):
        state = env.reset()
        score = 0
        start_time = time.time()
        for t in range(max_t):
            action = agent.predict(state)
            next_state, reward, done, _ = env.step(action)
            agent.step(state, action, reward, next_state, done)
            state = next_state
            score += reward

            if done or t == (max_t - 1):
                # agent.learn is called once per episode, with the episode length t
                critic_loss, actor_loss, avg_q, max_q = agent.learn(t)
                break

        duration = time.time() - start_time

        scores_deque.append(score)
        times_deque.append(duration)
        scores.append(score)
        mean_score = np.mean(scores_deque)
        mean_times = np.mean(times_deque)

        # wandb.log({'Score': mean_score, 'Critic loss': critic_loss, 'Actor loss': actor_loss, 'Average Q': avg_q, 'Max. Q': max_q, 'Duration': mean_times}, step=i_episode)

        print('\rEpisode {}\tAverage Score: {:.2f}\tScore: {:.2f}'.format(i_episode, mean_score, score), end="")
        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, mean_score))
        if i_episode % 500 == 0:
            agent.save("checkpoint", str(i_episode))
        # Treat an average of 300 over the last 100 episodes as solved (BipedalWalker-v3's reward threshold is 300)
        if mean_score >= 300 and not solved:
            print('\rSolved at Episode {}!\tAverage Score: {:.2f}'.format(i_episode, mean_score))
            agent.save("checkpoint")
            solved = True

    return scores

scores = td3()
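`agent.learn` is defined earlier in the notebook and is not shown in this section. For reference, the standard TD3 critic update (Fujimoto et al., 2018) builds its target in two steps: it adds clipped Gaussian noise to the target actor's action (target policy smoothing) and bootstraps from the smaller of the two target critics (clipped double-Q), i.e. $y = r + \gamma (1 - d) \min_{j=1,2} Q'_j(s', \tilde a)$ with $\tilde a = \mathrm{clip}(\mu'(s') + \mathrm{clip}(\epsilon, -c, c), a_{\min}, a_{\max})$ and $\epsilon \sim \mathcal N(0, \sigma^2)$. The numpy sketch below uses plain callables in place of the target networks and the paper's defaults ($\gamma = 0.99$, $\sigma = 0.2$, $c = 0.5$); `td3_target` is a hypothetical name and is not necessarily how `agent.learn` is written.

```python
import numpy as np


def td3_target(rewards, next_states, dones, actor_target, critic1_target, critic2_target,
               min_action, max_action, gamma=0.99, policy_noise=0.2, noise_clip=0.5, rng=None):
    """TD3 critic target: target policy smoothing plus clipped double-Q."""
    rng = np.random.default_rng() if rng is None else rng
    next_actions = actor_target(next_states)
    # Target policy smoothing: add clipped Gaussian noise, then respect the action bounds
    noise = np.clip(rng.normal(0.0, policy_noise, size=next_actions.shape), -noise_clip, noise_clip)
    next_actions = np.clip(next_actions + noise, min_action, max_action)
    # Clipped double-Q: use the smaller of the two target critic estimates
    target_q = np.minimum(critic1_target(next_states, next_actions),
                          critic2_target(next_states, next_actions))
    # Terminal transitions (done = 1) do not bootstrap
    return rewards + gamma * (1.0 - dones) * target_q


# Toy usage with stand-in networks (hypothetical, for illustration only)
rng = np.random.default_rng(0)
next_states = rng.uniform(size=(5, 24))
rewards = np.ones(5)
dones = np.array([0.0, 0.0, 1.0, 0.0, 1.0])


def toy_actor(s):
    return np.tanh(s[:, :4])


def toy_critic1(s, a):
    return a.sum(axis=1) + 1.0


def toy_critic2(s, a):
    return a.sum(axis=1)  # always the smaller estimate in this toy setup


y = td3_target(rewards, next_states, dones, toy_actor, toy_critic1, toy_critic2, -1.0, 1.0, rng=rng)
y_no_noise = td3_target(rewards, next_states, dones, toy_actor, toy_critic1, toy_critic2, -1.0, 1.0,
                        policy_noise=0.0, rng=rng)
print(y.shape)
```

Taking the minimum of two critics counters the overestimation bias of a single bootstrapped Q-estimate, and the smoothing noise keeps the actor from exploiting narrow peaks in the critic.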

## **Result**

In [None]:
!apt-get install python3-opengl -y
!apt install xvfb -y
!pip install pyvirtualdisplay
!pip install pyglet

In [None]:
from pyvirtualdisplay import Display
Display().start()  # virtual X display so the environment can render in headless Colab

import gym
import numpy as np
import torch
from IPython import display
import matplotlib.pyplot as plt
%matplotlib inline

# Restore the networks and optimizers saved during training
agent.actor.load_state_dict(torch.load('/content/checkpoint_actor_300.pth'))
agent.critic.load_state_dict(torch.load('/content/checkpoint_critic_300.pth'))
agent.actor_optimizer.load_state_dict(torch.load('/content/checkpoint_actor_optimizer_300.pth'))
agent.critic_optimizer.load_state_dict(torch.load('/content/checkpoint_critic_optimizer_300.pth'))

env = gym.make('BipedalWalker-v3')
state = env.reset()
score = 0
img = plt.imshow(env.render('rgb_array'))
while True:
    # Redraw the current frame in place
    img.set_data(env.render('rgb_array'))
    display.display(plt.gcf())
    display.clear_output(wait=True)
    action = agent.predict(state)
    next_state, reward, done, _ = env.step(action)
    state = next_state
    score += reward
    if done:
        break

print("Score: {}".format(score))
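Both the training loop and this evaluation cell use the pre-0.26 Gym API: `reset()` returns only the observation and `step()` returns `(obs, reward, done, info)`. Gymnasium (installed in the Install section) and Gym 0.26+ instead return `(obs, info)` from `reset()` and `(obs, reward, terminated, truncated, info)` from `step()`; rendering also moves to `gym.make("BipedalWalker-v3", render_mode="rgb_array")` followed by a bare `env.render()`. If the notebook is run against the newer API, a small shim such as the hypothetical `reset_compat` / `step_compat` below keeps the loops unchanged apart from those two calls.

```python
from typing import Any, Tuple


def reset_compat(env) -> Any:
    """Return only the initial observation under either API."""
    out = env.reset()
    # Gymnasium (and Gym >= 0.26) return (obs, info); older Gym returns obs alone
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]
    return out


def step_compat(env, action) -> Tuple[Any, float, bool, dict]:
    """Return (obs, reward, done, info) under either API."""
    out = env.step(action)
    if len(out) == 5:
        # Gymnasium splits episode end into terminated (true end) and truncated (time limit)
        obs, reward, terminated, truncated, info = out
        return obs, reward, bool(terminated or truncated), info
    return out


# Stand-in environments that mimic the two return conventions
class NewStyleEnv:
    def reset(self):
        return [0.0], {}

    def step(self, action):
        return [1.0], 2.0, False, True, {}


class OldStyleEnv:
    def reset(self):
        return [0.0]

    def step(self, action):
        return [1.0], 2.0, True, {}


for fake_env in (NewStyleEnv(), OldStyleEnv()):
    print(reset_compat(fake_env), step_compat(fake_env, action=None))
```

In the loops, `env.reset()` would become `reset_compat(env)` and `env.step(action)` would become `step_compat(env, action)`. Merging `terminated` and `truncated` into one `done` flag reproduces the old behaviour, but for the TD3 target it is more accurate to stop bootstrapping only on `terminated`, since a time-limit cut does not make the state terminal.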

## **Test**

In [None]:
import random
import unittest

import numpy as np
import torch

class TestModel(unittest.TestCase):

    def setUp(self):
        self.env = gym.make("BipedalWalker-v3")

        # Model and replay buffer parameters
        self.batch_size = 100
        self.buffer_size = int(1e5)
        self.random_seed = 0
        self.error = 0

        # Action / state dimensions
        self.action_size = self.env.action_space.shape[0]
        self.state_size = self.env.observation_space.shape[0]

        # Min / max action
        self.min_action = self.env.action_space.low
        self.max_action = self.env.action_space.high

        # Range of the random test states (not the environment's real observation bounds)
        self.min_state = 0
        self.max_state = 1

        # Range of the random test rewards
        self.min_reward = -300
        self.max_reward = 300

        self.model = TD3Agent(state_size=self.state_size, action_size=self.action_size,
                              max_action=self.max_action, min_action=self.min_action,
                              random_seed=self.random_seed)

        self.memory = ReplayBufferPer(self.buffer_size)

        # Number of random inputs tried in test_predict_act
        self.num_attempts = 150

    def _randomStates(self):
        # Random state vector with the environment's observation size (24 for BipedalWalker-v3)
        states = np.array([random.uniform(self.min_state, self.max_state) for _ in range(self.state_size)], dtype=np.float32)
        return states

    def _randomAction(self):
        action = np.random.uniform(self.min_action, self.max_action, self.action_size)
        return action

    def _randomDone(self):
        done = random.choice([True, False])
        return done

    def _randomReward(self):
        reward = random.randint(self.min_reward, self.max_reward)
        return reward

    def test_predict_act(self):
        """Check that the action predicted by the network lies within the
        minimum and maximum action values required by the environment.

            Input: numpy.ndarray [24]
            Output: numpy.ndarray [4]
        """
        for attempt in range(self.num_attempts):
            states = self._randomStates()
            action = self.model.predict(states)
            is_valid = (isinstance(action, np.ndarray)
                        and np.all(action >= self.min_action)
                        and np.all(action <= self.max_action))

            if not is_valid:
                self.fail("Test failed on attempt {}".format(attempt + 1))

    def test_buffer_type(self):
        # Fill the buffer with random transitions until it holds more than one batch
        while len(self.memory) <= self.batch_size:
            next_state, reward, done, action, state = self._randomStates(), self._randomReward(), self._randomDone(), self._randomAction(), self._randomStates()
            self.memory.add((state, action, reward, next_state, done), reward)

        (states, actions, rewards, next_states, dones), idxs, is_weights = self.memory.sample(self.batch_size)

        self.assertIsInstance(states, torch.Tensor)
        self.assertIsInstance(actions, torch.Tensor)
        self.assertIsInstance(rewards, torch.Tensor)
        self.assertIsInstance(next_states, torch.Tensor)
        self.assertIsInstance(dones, torch.Tensor)

    def test_buffer_size(self):
        while len(self.memory) <= self.batch_size:
            next_state, reward, done, action, state = self._randomStates(), self._randomReward(), self._randomDone(), self._randomAction(), self._randomStates()
            self.memory.add((state, action, reward, next_state, done), reward)

        (states, actions, rewards, next_states, dones), idxs, is_weights = self.memory.sample(self.batch_size)

        expected_batch_size = self.batch_size
        self.assertEqual(states.size(0), expected_batch_size)
        self.assertEqual(actions.size(0), expected_batch_size)
        self.assertEqual(rewards.size(0), expected_batch_size)
        self.assertEqual(next_states.size(0), expected_batch_size)
        self.assertEqual(dones.size(0), expected_batch_size)

    def test_buffer_range(self):
        while len(self.memory) <= self.batch_size:
            next_state, reward, done, action, state = self._randomStates(), self._randomReward(), self._randomDone(), self._randomAction(), self._randomStates()
            self.memory.add((state, action, reward, next_state, done), reward)

        (states, actions, rewards, next_states, dones), idxs, weights = self.memory.sample(self.batch_size)

        def to_numpy(tensor):
            return tensor.detach().cpu().numpy()

        # Every sampled transition must lie within the ranges used to generate it
        self.assertTrue(np.all(to_numpy(states) >= self.min_state) and np.all(to_numpy(states) <= self.max_state))
        self.assertTrue(np.all(to_numpy(actions) >= self.min_action) and np.all(to_numpy(actions) <= self.max_action))
        self.assertTrue(np.all(to_numpy(rewards) >= self.min_reward) and np.all(to_numpy(rewards) <= self.max_reward))
        self.assertTrue(np.all(to_numpy(next_states) >= self.min_state) and np.all(to_numpy(next_states) <= self.max_state))
        self.assertTrue(np.all(to_numpy(dones) >= 0.) and np.all(to_numpy(dones) <= 1.))

if __name__ == "__main__":
    unittest.main(argv=['first-arg-is-ignored'], exit=False)
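The buffer tests unpack `idxs` and `is_weights` from `ReplayBufferPer.sample`, which points to prioritized experience replay (Schaul et al., 2016). The buffer's implementation is not shown in this section; the numpy sketch below illustrates the standard proportional scheme it presumably follows: sampling probability $P(i) = p_i^\alpha / \sum_k p_k^\alpha$ and importance-sampling weight $w_i = (N \cdot P(i))^{-\beta}$, here normalised by the largest weight in the sampled batch. The function name `per_sample` and the defaults `alpha=0.6`, `beta=0.4` are assumptions taken from the paper's proportional variant, not from this notebook.

```python
import numpy as np


def per_sample(priorities, batch_size, alpha=0.6, beta=0.4, rng=None):
    """Sample indices proportionally to priority**alpha and return IS weights."""
    rng = np.random.default_rng() if rng is None else rng
    priorities = np.asarray(priorities, dtype=np.float64)
    probs = priorities ** alpha
    probs /= probs.sum()
    idxs = rng.choice(len(priorities), size=batch_size, p=probs)
    # Importance-sampling correction for the non-uniform sampling
    weights = (len(priorities) * probs[idxs]) ** (-beta)
    weights /= weights.max()  # scale so the largest weight in the batch is 1
    return idxs, weights


rng = np.random.default_rng(0)
priorities = rng.uniform(0.1, 5.0, size=1000)
idxs, weights = per_sample(priorities, batch_size=100, rng=rng)
print(idxs.shape, weights.max())
```

With equal priorities every transition has $P(i) = 1/N$, so all weights are 1 and the scheme reduces to uniform replay; the weights only matter once priorities (typically TD errors) diverge.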