In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Display the selected device together with the library versions in use.
DEVICE, gym.__version__, torch.__version__, np.__version__

(device(type='cuda'), '1.1.1', '2.6.0+cu118', '1.24.4')

In [1]:
# Smoke test: build the Hopper environment, grab the initial observation/info
# pair (reused by the following cells), then release the environment.
# NOTE(review): the recorded output of this cell shows a gymnasium deprecation
# warning - presumably about the "Hopper-v4" id; confirm whether a newer
# environment version should be used.
env = gym.make("Hopper-v4", render_mode="human")  # Change to 'rgb_array' if headless
try:
    observation, information = env.reset(seed=42)
    print("Environment initialized successfully!")
finally:
    # Always release the simulator/render window, even if reset() raises.
    env.close()

  logger.deprecation(


Environment initialized successfully!


In [2]:
# Initial observation returned by env.reset(seed=42).
observation

array([ 1.24938878e+00,  3.58597920e-03,  1.97368029e-03, -4.05822652e-03,
        4.75622352e-03,  2.61139702e-03,  2.86064305e-03, -3.71886367e-03,
       -4.96140621e-04, -1.29201976e-03,  4.26764989e-03])

In [3]:
# Info dict returned alongside the initial observation by env.reset(seed=42).
information

{}

### Actor Net

In [6]:
class Actor(nn.Module):
    """Deterministic policy network: maps a batch of states to bounded actions.

    The network is a three-layer MLP with ReLU activations; the final layer is
    squashed with ``tanh`` and multiplied by ``max_action``.

    Args:
        state_dim: Number of input features per state.
        action_dim: Number of output features per action.
        max_action: Multiplier applied to the ``tanh`` output.
        hidden_units_l1: Width of the first hidden layer (default 400).
        hidden_units_l2: Width of the second hidden layer (default 300).
    """

    def __init__(self, state_dim, action_dim, max_action, hidden_units_l1=400, hidden_units_l2=300):
        super().__init__()
        # Layer attribute names are kept as l1/l2/l3 so state_dict keys are stable.
        self.l1 = nn.Linear(state_dim, hidden_units_l1)
        self.l2 = nn.Linear(hidden_units_l1, hidden_units_l2)
        self.l3 = nn.Linear(hidden_units_l2, action_dim)
        self.max_action = max_action

    def forward(self, state):
        """Return ``max_action * tanh(MLP(state))`` for the given state batch."""
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))

### Critic Net

In [10]:
class Critic(nn.Module):
    """Twin Q-value network.

    Two separate three-layer MLPs (layers ``l1``-``l3`` and ``l4``-``l6``)
    each score the concatenation of a state batch and an action batch with a
    single value per row.

    Args:
        state_dim: Number of features per state.
        action_dim: Number of features per action.
        hidden_units_l1: Width of the first hidden layer of each head.
        hidden_units_l2: Width of the second hidden layer of each head.
    """

    def __init__(self, state_dim, action_dim, hidden_units_l1=400, hidden_units_l2=300):
        super().__init__()
        input_width = state_dim + action_dim

        # First Q-head: l1 -> l2 -> l3
        self.l1 = nn.Linear(input_width, hidden_units_l1)
        self.l2 = nn.Linear(hidden_units_l1, hidden_units_l2)
        self.l3 = nn.Linear(hidden_units_l2, 1)

        # Second Q-head: l4 -> l5 -> l6
        self.l4 = nn.Linear(input_width, hidden_units_l1)
        self.l5 = nn.Linear(hidden_units_l1, hidden_units_l2)
        self.l6 = nn.Linear(hidden_units_l2, 1)

    @staticmethod
    def _run_head(joint, first, second, out):
        """Apply one head: Linear -> ReLU -> Linear -> ReLU -> Linear."""
        hidden = F.relu(first(joint))
        hidden = F.relu(second(hidden))
        return out(hidden)

    def forward(self, state, action):
        """Return the ``(q1, q2)`` estimates of both heads for ``(state, action)``."""
        joint = torch.cat((state, action), dim=1)
        first_q = self._run_head(joint, self.l1, self.l2, self.l3)
        second_q = self._run_head(joint, self.l4, self.l5, self.l6)
        return first_q, second_q

    def Q1(self, state, action):
        """Return only the first head's estimate for ``(state, action)``."""
        joint = torch.cat((state, action), dim=1)
        return self._run_head(joint, self.l1, self.l2, self.l3)

### Replay Buffer

In [13]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random mini-batch sampling.

    Each transition is expected to be a 5-tuple
    ``(state, action, reward, next_state, done)``; once ``max_size`` items are
    held, appending a new one drops the oldest.

    Args:
        max_size: Maximum number of transitions kept.
    """

    def __init__(self, max_size=1_000_000):
        self.buffer = deque(maxlen=max_size)

    def push(self, transition):
        """Append one transition, evicting the oldest when the buffer is full."""
        self.buffer.append(transition)

    @staticmethod
    def _to_tensor(array):
        """Convert a NumPy array to a float32 tensor created directly on DEVICE."""
        return torch.as_tensor(array, dtype=torch.float32, device=DEVICE)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(state, action, reward, next_state, done)`` of float32
            tensors on ``DEVICE``. ``reward`` and ``done`` always have shape
            ``(batch_size, 1)``, whether they were stored as scalars or as
            one-element arrays.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored transitions.
        """
        available = len(self.buffer)
        if batch_size > available:
            raise ValueError(
                f"Cannot sample {batch_size} transitions from a buffer holding {available}"
            )

        # NOTE: random.sample indexes into the deque; deque indexing is O(n)
        # away from the ends, so sampling slows down as the buffer grows.
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))

        # reshape (not unsqueeze) keeps reward/done as (B, 1) columns even when
        # they were pushed as shape-(1,) arrays, so they never broadcast to
        # (B, 1, 1) against the critic outputs.
        rows = len(batch)
        return (
            self._to_tensor(state),
            self._to_tensor(action),
            self._to_tensor(reward).reshape(rows, 1),
            self._to_tensor(next_state),
            self._to_tensor(done).reshape(rows, 1),
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

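### TD3 (Twin Delayed DDPG)
- Target nets - slowly-moving copies of the actor and critic used to compute the TD target; they are refreshed only on the delayed policy-update steps to keep learning stable
- Actor - Takes in the state, spits out an action
- Critic - Takes in the state and the action taken for that state, spits out two Q-value estimates (Q1, Q2); the smaller of the two is used in the TD target
- Polyak Averaging - soft target update: $\theta_{\text{target}} \leftarrow \tau\,\theta + (1 - \tau)\,\theta_{\text{target}}$
- Target policy smoothing is intentionally left out of this implementation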

In [15]:
"""
For documentation:

"""
class TD3:
    """TD3-style actor-critic agent with twin critics and delayed policy updates.

    Holds an actor and a twin-Q critic, each with a target copy initialised to
    the same weights and each trained with Adam (lr=1e-3). Target policy
    smoothing is intentionally not applied when building the TD target.

    Args:
        state_dim: Number of features per state.
        action_dim: Number of features per action.
        max_action: Scale handed to the actor networks.
    """

    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(DEVICE)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(DEVICE)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-3)

        self.critic = Critic(state_dim, action_dim).to(DEVICE)
        self.critic_target = Critic(state_dim, action_dim).to(DEVICE)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)

        self.max_action = max_action
        # Counts calls to train(); drives the delayed actor/target updates.
        self.total_iterations = 0

    def select_action(self, state):
        """Return the actor's action for one state as a flat NumPy array.

        ``state`` may be a NumPy array or any array-like (e.g. a list); it is
        flattened into a single row before being passed to the actor.
        """
        state = torch.as_tensor(
            np.asarray(state, dtype=np.float32).reshape(1, -1), device=DEVICE
        )
        # Pure inference: no autograd graph is needed here.
        with torch.no_grad():
            return self.actor(state).cpu().numpy().flatten()

    def train(self, replay_buffer, batch_size=256, gamma=0.99, tau=0.005, policy_delay=2):
        """Run one training iteration on a mini-batch drawn from ``replay_buffer``.

        The critic is updated on every call; the actor and both target networks
        are updated only on every ``policy_delay``-th call.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, reward, next_state, done)`` tensors.
            batch_size: Number of transitions per update.
            gamma: Discount factor applied to the target Q-value.
            tau: Interpolation factor of the soft target update.
            policy_delay: Number of critic updates per actor/target update.
        """
        self.total_iterations += 1

        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        # Build the TD target without tracking gradients through the target
        # networks (no target-policy smoothing, by design).
        with torch.no_grad():
            target_action = self.actor_target(next_state)
            target_Q1, target_Q2 = self.critic_target(next_state, target_action)
            # Clipped double-Q: take the smaller of the two target estimates.
            target_Q = reward + (1 - done) * gamma * torch.min(target_Q1, target_Q2)

        # Current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_iterations % policy_delay == 0:
            # Gradients this also leaves on the critic are cleared by the
            # critic_optimizer.zero_grad() call at the next iteration.
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Polyak averaging
            self._soft_update(self.critic, self.critic_target, tau)
            self._soft_update(self.actor, self.actor_target, tau)

    @staticmethod
    def _soft_update(source, target, tau):
        """Blend ``source`` weights into ``target``: target <- tau*source + (1-tau)*target."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)