## SAC

# Dependancies

In [38]:
import random
import numpy as np
import gymnasium as gym
import panda_gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm

# Actor Network

In [39]:
class StochasticActor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim, action_bound):
        super(StochasticActor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(-20, 2)  # Clamping for numerical stability
        std = torch.exp(log_std)
        return mean, std

    def sample(self, state):
        mean, std = self.forward(state)
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()  # Reparameterization trick
        action = torch.tanh(x_t) * self.action_bound
        log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + 1e-6)
        return action, log_prob.sum(dim=-1)

# Critic Network

In [40]:
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q_value = nn.Linear(hidden_dim, 1)  # Output single Q-value

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)  # Concatenate along the feature axis
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.q_value(x)  # Shape [batch_size, 1]


# Replay Buffer

In [41]:
class ReplayBuffer:
    def __init__(self, buffer_size, state_dim, action_dim, device="cpu"):
        self.buffer_size = buffer_size
        self.device = torch.device(device)
        self.ptr = 0
        self.size = 0

        self.states = np.zeros((buffer_size, state_dim), dtype=np.float32)
        self.actions = np.zeros((buffer_size, action_dim), dtype=np.float32)
        self.rewards = np.zeros((buffer_size, 1), dtype=np.float32)  # Store rewards with shape [batch_size, 1]
        self.next_states = np.zeros((buffer_size, state_dim), dtype=np.float32)
        self.dones = np.zeros((buffer_size, 1), dtype=np.float32)  # Store dones with shape [batch_size, 1]

    def add(self, state, action, reward, next_state, done):
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.next_states[self.ptr] = next_state
        self.dones[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def sample_batch(self, batch_size):
        indices = np.random.randint(0, self.size, size=batch_size)
        return (
        torch.tensor(self.states[indices], device=self.device),
        torch.tensor(self.actions[indices], device=self.device),
        torch.tensor(self.rewards[indices].reshape(-1, 1), device=self.device),  # Shape [batch_size, 1]
        torch.tensor(self.next_states[indices], device=self.device),
        torch.tensor(self.dones[indices].reshape(-1, 1), device=self.device),    # Shape [batch_size, 1]
    )

    def __len__(self):
        return self.size
 

# SAC Agent

In [42]:
class SACAgent:
    def __init__(
        self, 
        state_dim, 
        action_dim, 
        action_bound, 
        actor, 
        critic_1, 
        critic_2, 
        buffer, 
        device, 
        gamma=0.99, 
        tau=0.005, 
        lr=3e-4, 
        alpha=0.2, 
        automatic_entropy_tuning=True, 
        batch_size=256, 
        updates_per_step=1
    ):
        self.actor = actor.to(device)
        self.critic_1 = critic_1.to(device)
        self.critic_2 = critic_2.to(device)
        self.target_critic_1 = critic_1.to(device)
        self.target_critic_2 = critic_2.to(device)

        self.replay_buffer = buffer
        self.device = device
        self.batch_size = batch_size
        self.updates_per_step = updates_per_step  # Added updates_per_step attribute

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.automatic_entropy_tuning = automatic_entropy_tuning

        if self.automatic_entropy_tuning:
            self.target_entropy = -action_dim  # Target entropy is heuristic
            self.log_alpha = torch.tensor(np.log(alpha), requires_grad=True, device=device)
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=lr)

        # Sync target critics
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())

    def select_action(self, state, deterministic=False):
        """Select an action given the current state.

        Args:
            state (np.ndarray): The current state.
            deterministic (bool): If True, select the mean action deterministically.

        Returns:
            np.ndarray: The selected action scaled to the environment's action range.
        """
        state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)  # Add batch dimension
        with torch.no_grad():
            mean, std = self.actor(state)  # Get the mean and standard deviation
            if deterministic:
                action = mean
            else:
                normal = torch.distributions.Normal(mean, std)
                action = normal.sample()  # Sample stochastically
            action = torch.tanh(action) * self.actor.action_bound  # Scale to the action range
        return action.squeeze(0).cpu().numpy()

    def update(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size

        states, actions, rewards, next_states, dones = self.replay_buffer.sample_batch(batch_size)

        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample(next_states)

            # Compute Q-values for the next state-action pairs
            q1_next = self.target_critic_1(next_states, next_actions)
            q2_next = self.target_critic_2(next_states, next_actions)

        #print("Before torch.min:")
        #print("q1_next shape:", q1_next.shape)  # Should be [batch_size, 1]
        #print("q2_next shape:", q2_next.shape)  # Should be [batch_size, 1]

        # Compute element-wise minimum
        q1_next = q1_next.view(-1, 1)  # Ensure shape [batch_size, 1]
        q2_next = q2_next.view(-1, 1)  # Ensure shape [batch_size, 1]
        min_q_next = torch.min(q1_next, q2_next)  # Element-wise minimum

        #print("After torch.min:")
        #print("min_q_next shape:", min_q_next.shape)  # Should be [batch_size, 1]

        # Compute target Q-values
        rewards = rewards.view(-1, 1)  # Ensure [batch_size, 1]
        dones = dones.view(-1, 1)      # Ensure [batch_size, 1]
        target_q = rewards + self.gamma * (1 - dones) * min_q_next


        # Update Q-functions
        current_q1 = self.critic_1(states, actions)
        current_q2 = self.critic_2(states, actions)

        critic_1_loss = F.mse_loss(current_q1, target_q)
        critic_2_loss = F.mse_loss(current_q2, target_q)

        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()

        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # Update policy
        sampled_actions, log_probs = self.actor.sample(states)
        q1 = self.critic_1(states, sampled_actions)
        q2 = self.critic_2(states, sampled_actions)
        actor_loss = (self.alpha * log_probs - torch.min(q1, q2)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update alpha if using automatic entropy tuning
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            self.alpha = self.log_alpha.exp()

        # Soft update target networks
        self.soft_update(self.target_critic_1, self.critic_1)
        self.soft_update(self.target_critic_2, self.critic_2)

        return critic_1_loss.item(), critic_2_loss.item(), actor_loss.item()

    def soft_update(self, target, source):
        """Perform a soft update for the target networks."""
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(self.tau * source_param.data + (1 - self.tau) * target_param.data)


# Training Loop

In [43]:
def train_sac_agent(env, agent, episodes, max_steps):
    """
    Training loop for Soft Actor-Critic (SAC) agent.
    
    Args:
        env: Gym environment.
        agent: Instance of SACAgent class.
        episodes: Number of episodes to train the agent.
        max_steps: Maximum number of steps per episode.
    
    Returns:
        rewards_log: List of cumulative rewards per episode.
        critic_1_losses, critic_2_losses, actor_losses, timesteps: Lists of losses and timesteps.
    """
    rewards_log = []
    critic_1_losses = []
    critic_2_losses = []
    actor_losses = []
    timesteps = []
    total_steps = 0  # Count global timesteps

    for episode in tqdm(range(1, episodes + 1), desc="Training Progress", unit="episode"):
        # Reset environment and retrieve initial state
        raw_state, _ = env.reset()
        if isinstance(raw_state, dict):
            state = raw_state.get("observation", raw_state)
        else:
            state = raw_state
        if not isinstance(state, np.ndarray):
            state = np.array(state, dtype=np.float32)
        
        episode_reward = 0

        for step in range(max_steps):
            # Select action from policy (with exploration)
            action = agent.select_action(state, deterministic=False)

            # Step environment
            raw_next_state, reward, done, _, info = env.step(action)
            if isinstance(raw_next_state, dict):
                next_state = raw_next_state.get("observation", raw_next_state)
            else:
                next_state = raw_next_state
            if not isinstance(next_state, np.ndarray):
                next_state = np.array(next_state, dtype=np.float32)

            # Store transition in replay buffer
            agent.replay_buffer.add(state, action, reward, next_state, done)

            # Perform updates (if replay buffer has enough samples)
            if len(agent.replay_buffer) >= agent.batch_size:
                for _ in range(agent.updates_per_step):  # Multiple updates per environment step
                    critic_loss_1, critic_loss_2, actor_loss = agent.update()
                    if critic_loss_1 is not None and critic_loss_2 is not None and actor_loss is not None:
                        critic_1_losses.append(critic_loss_1)
                        critic_2_losses.append(critic_loss_2)
                        actor_losses.append(actor_loss)
                        timesteps.append(total_steps)

            state = next_state
            episode_reward += reward
            total_steps += 1

            if done:
                break

        rewards_log.append(episode_reward)
        print(f"Episode {episode}/{episodes}, Reward: {episode_reward}")

    return rewards_log, critic_1_losses, critic_2_losses, actor_losses, timesteps




# Experimentation

In [44]:
if __name__ == '__main__':
    # Set seeds
    np.random.seed(1234)
    random.seed(1234)
    torch.manual_seed(1234)

    # Environment
    env = gym.make("PandaReach-v3")
    if isinstance(env.observation_space, gym.spaces.Dict):
        state_dim = env.observation_space["observation"].shape[0]
    else:
        state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_bound = env.action_space.high[0]

    # Replay buffer
    replay_buffer = ReplayBuffer(buffer_size=100_000,state_dim=state_dim,action_dim=action_dim)

    # Instantiate the actor and critics
    actor = StochasticActor(
    state_dim=state_dim,
    action_dim=action_dim,
    hidden_dim=256,  # Example hidden layer size
    action_bound=action_bound,
    )   

    critic_1 = Critic(
    state_dim=state_dim,
    action_dim=action_dim,
    hidden_dim=256,  # Example hidden layer size
    )

    critic_2 = Critic(
    state_dim=state_dim,
    action_dim=action_dim,
    hidden_dim=256,  # Example hidden layer size
    )

    # SAC Agent
    sac_agent = SACAgent(
    state_dim=state_dim,
    action_dim=action_dim,
    action_bound=action_bound,
    actor=actor,
    critic_1=critic_1,
    critic_2=critic_2,
    buffer=replay_buffer,
    device=torch.device("cpu"),
    gamma=0.99,
    tau=0.005,
    lr=5e-4,
    alpha=0.2,
    automatic_entropy_tuning=True,
    batch_size=256,
    updates_per_step=1,  # Perform one update per environment step
)

# Train the SAC Agent
train_rewards, critic_1_losses, critic_2_losses, actor_losses, timesteps = train_sac_agent(
    env, sac_agent, episodes=100, max_steps=1000
)

# Plot the losses
plt.figure()
plt.plot(timesteps, critic_1_losses, label="Critic 1 Loss")
plt.plot(timesteps, critic_2_losses, label="Critic 2 Loss")
plt.plot(timesteps, actor_losses, label="Actor Loss")
plt.xlabel("Timestep")
plt.ylabel("Loss")
plt.title("Losses vs Timestep")
plt.legend()
plt.show()

# Save the rewards and losses
np.save("sac_train_rewards.npy", train_rewards)
np.save("sac_critic_1_losses.npy", critic_1_losses)
np.save("sac_critic_2_losses.npy", critic_2_losses)
np.save("sac_actor_losses.npy", actor_losses)
#np.save("sac_timesteps.npy", timesteps)


Training Progress:   1%|          | 1/100 [00:16<26:31, 16.07s/episode]

Episode 1/100, Reward: -1000.0


Training Progress:   2%|▏         | 2/100 [00:42<36:34, 22.39s/episode]

Episode 2/100, Reward: -1000.0


Training Progress:   3%|▎         | 3/100 [01:09<39:25, 24.38s/episode]

Episode 3/100, Reward: -1000.0


Training Progress:   3%|▎         | 3/100 [01:36<52:05, 32.22s/episode]


KeyboardInterrupt: 