<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Reinforcement_Learning_with_Soft_Actor_Critic_(SAC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np

# Define the actor and critic networks
class Actor(nn.Module):
    """Gaussian policy head: maps a state to a mean and a log-std per action dim.

    Two 256-unit ReLU layers form a shared trunk; two separate linear heads
    read from it, one for the mean and one for the log standard deviation.
    """

    # Bounds applied to the predicted log-std in ``forward``.
    LOG_STD_MIN = -20
    LOG_STD_MAX = 2

    def __init__(self, state_dim, action_dim):
        """Build the trunk and both output heads.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Number of action components (width of each head).
        """
        super().__init__()
        hidden_width = 256
        self.fc1 = nn.Linear(state_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, action_dim)      # mean head
        self.log_std = nn.Linear(hidden_width, action_dim)  # log-std head

    def forward(self, state):
        """Return ``(mean, log_std)`` for ``state``.

        ``log_std`` is clamped to ``[LOG_STD_MIN, LOG_STD_MAX]`` so that
        exponentiating it later stays finite.
        """
        hidden = nn.functional.relu(self.fc1(state))
        hidden = nn.functional.relu(self.fc2(hidden))
        raw_log_std = self.log_std(hidden)
        bounded_log_std = torch.clamp(
            raw_log_std, min=self.LOG_STD_MIN, max=self.LOG_STD_MAX
        )
        return self.fc3(hidden), bounded_log_std

class Critic(nn.Module):
    """Q-network: scores a (state, action) pair with a single scalar."""

    def __init__(self, state_dim, action_dim):
        """Build a three-layer MLP over the concatenated state and action.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Size of the last dimension of the action input.
        """
        super().__init__()
        hidden_width = 256
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, 1)

    def forward(self, state, action):
        """Return the Q-value estimate; the last output dimension has size 1."""
        # State and action are joined along the feature (last) axis.
        joint_input = torch.cat((state, action), dim=-1)
        hidden = nn.functional.relu(self.fc1(joint_input))
        hidden = nn.functional.relu(self.fc2(hidden))
        q_value = self.fc3(hidden)
        return q_value

# Replay buffer for storing experiences
class ReplayBuffer:
    """Fixed-capacity store of transitions that overwrites the oldest entry.

    Attributes:
        max_size: Maximum number of transitions kept.
        buffer: List holding the stored transitions.
        ptr: Index of the slot the next overwrite will use once full.
    """

    def __init__(self, max_size):
        self.buffer = []
        self.ptr = 0
        self.max_size = max_size

    def add(self, transition):
        """Store ``transition``, replacing the oldest one when at capacity."""
        slot = self.ptr
        if len(self.buffer) >= self.max_size:
            # Full: recycle the slot under the write pointer.
            self.buffer[slot] = transition
        else:
            self.buffer.append(transition)
        # The pointer advances on every add and wraps at capacity.
        self.ptr = (slot + 1) % self.max_size

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A lazy iterator yielding one NumPy array per transition field
            (the transitions are transposed field-wise), suitable for tuple
            unpacking.

        Raises:
            ValueError: From ``np.random.choice`` when ``batch_size`` exceeds
                the number of stored transitions (sampling is without
                replacement).
        """
        stored = self.buffer
        picks = np.random.choice(len(stored), batch_size, replace=False)
        columns = zip(*(stored[idx] for idx in picks))
        return map(np.array, columns)

# Initialize environment, networks, and training
# Pendulum with the five-value step API (obs, reward, terminated, truncated, info).
env = gym.make("Pendulum-v1", new_step_api=True)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# One policy and two independent Q-networks. Construction order is
# actor -> critic_1 -> critic_2, which fixes the order random initial weights
# are drawn in.
actor = Actor(state_dim, action_dim)
critic_1, critic_2 = Critic(state_dim, action_dim), Critic(state_dim, action_dim)

# Every network gets its own Adam optimizer with the same step size.
LEARNING_RATE = 3e-4
actor_optimizer, critic_1_optimizer, critic_2_optimizer = (
    optim.Adam(network.parameters(), lr=LEARNING_RATE)
    for network in (actor, critic_1, critic_2)
)

replay_buffer = ReplayBuffer(max_size=1_000_000)

# Training loop
def _sample_squashed_action(policy, states, action_scale, action_bias):
    """Draw a differentiable, bounded action and its log-probability.

    Args:
        policy: Module returning ``(mean, log_std)`` for a batch of states.
        states: Float tensor of states with a leading batch dimension.
        action_scale: Tensor with half the width of the action range.
        action_bias: Tensor with the centre of the action range.

    Returns:
        ``(action, log_prob)``. ``action`` lies inside the action bounds and
        ``log_prob`` has shape ``(batch, 1)`` (summed over action dims).
    """
    mean, log_std = policy(states)
    gaussian = torch.distributions.Normal(mean, log_std.exp())
    # rsample() uses the reparameterisation trick, so gradients flow back into
    # mean and std; torch.normal() would cut the graph at the sample.
    pre_tanh = gaussian.rsample()
    squashed = torch.tanh(pre_tanh)
    action = squashed * action_scale + action_bias
    # Change of variables for a = scale * tanh(u) + bias; the epsilon keeps
    # the log finite when tanh saturates.
    log_prob = gaussian.log_prob(pre_tanh) - torch.log(
        action_scale * (1 - squashed.pow(2)) + 1e-6
    )
    return action, log_prob.sum(dim=-1, keepdim=True)


def train_sac(num_epochs, batch_size, gamma=0.99, tau=0.005, alpha=0.2,
              max_steps=200):
    """Train the module-level actor and critics with Soft Actor-Critic.

    Operates on the module globals ``env``, ``actor``, ``critic_1``,
    ``critic_2``, their optimizers and ``replay_buffer``. Each epoch is one
    episode; after every environment step one gradient update is performed
    once the buffer holds more than ``batch_size`` transitions. The episode
    reward is printed at the end of each epoch.

    Args:
        num_epochs: Number of episodes to run.
        batch_size: Number of transitions per gradient update.
        gamma: Discount factor for the TD target.
        tau: Polyak averaging coefficient for the target critics.
        alpha: Fixed entropy temperature (automatic tuning is not implemented).
        max_steps: Upper bound on environment steps per episode.

    Note:
        Assumes a bounded (finite low/high) continuous action space.
    """
    import copy  # local import so this block does not depend on file-level edits

    # Slowly-moving copies of the critics keep the bootstrap target stable.
    target_critic_1 = copy.deepcopy(critic_1)
    target_critic_2 = copy.deepcopy(critic_2)
    for frozen in (*target_critic_1.parameters(), *target_critic_2.parameters()):
        frozen.requires_grad_(False)

    # Affine map from tanh's (-1, 1) range onto the environment's action bounds.
    high = torch.as_tensor(env.action_space.high, dtype=torch.float32)
    low = torch.as_tensor(env.action_space.low, dtype=torch.float32)
    action_scale = (high - low) / 2
    action_bias = (high + low) / 2

    for epoch in range(num_epochs):
        reset_result = env.reset()
        # Some Gym versions return (observation, info) from reset().
        if (isinstance(reset_result, tuple) and len(reset_result) == 2
                and isinstance(reset_result[1], dict)):
            state = reset_result[0]
        else:
            state = reset_result
        episode_reward = 0
        for _ in range(max_steps):
            with torch.no_grad():
                state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
                action_tensor, _ = _sample_squashed_action(
                    actor, state_tensor, action_scale, action_bias
                )
            action = action_tensor.numpy()[0]
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Store only true termination: a time-limit truncation must still
            # bootstrap from next_state when building the TD target.
            replay_buffer.add((state, action, reward, next_state, terminated))

            if len(replay_buffer.buffer) > batch_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
                states = torch.as_tensor(states, dtype=torch.float32)
                actions = torch.as_tensor(actions, dtype=torch.float32)
                # Rewards and dones become (batch_size, 1) to match the critic output.
                rewards = torch.as_tensor(rewards, dtype=torch.float32).unsqueeze(1)
                next_states = torch.as_tensor(next_states, dtype=torch.float32)
                dones = torch.as_tensor(dones, dtype=torch.float32).unsqueeze(1)

                # Critic updates: soft Bellman target from the target critics.
                with torch.no_grad():
                    next_actions, next_log_prob = _sample_squashed_action(
                        actor, next_states, action_scale, action_bias
                    )
                    next_q = torch.min(
                        target_critic_1(next_states, next_actions),
                        target_critic_2(next_states, next_actions),
                    )
                    soft_value = next_q - alpha * next_log_prob
                    target_q = rewards + (1 - dones) * gamma * soft_value

                critic_1_loss = nn.functional.mse_loss(critic_1(states, actions), target_q)
                critic_2_loss = nn.functional.mse_loss(critic_2(states, actions), target_q)

                critic_1_optimizer.zero_grad()
                critic_1_loss.backward()
                critic_1_optimizer.step()

                critic_2_optimizer.zero_grad()
                critic_2_loss.backward()
                critic_2_optimizer.step()

                # Actor update: minimise alpha * log_pi - Q with fresh,
                # reparameterised actions. Gradients that land on the critics
                # here are discarded by the zero_grad() calls above on the
                # next update.
                new_actions, log_prob = _sample_squashed_action(
                    actor, states, action_scale, action_bias
                )
                q_new = torch.min(
                    critic_1(states, new_actions),
                    critic_2(states, new_actions),
                )
                actor_loss = (alpha * log_prob - q_new).mean()

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()

                # Polyak-average the live critics into their targets.
                with torch.no_grad():
                    for live_net, target_net in (
                        (critic_1, target_critic_1),
                        (critic_2, target_critic_2),
                    ):
                        for live_p, target_p in zip(
                            live_net.parameters(), target_net.parameters()
                        ):
                            target_p.mul_(1 - tau)
                            target_p.add_(live_p, alpha=tau)

            state = next_state
            episode_reward += reward
            if terminated or truncated:
                break

        print(f"Epoch {epoch + 1}, Reward: {episode_reward}")

# Run training for 50 epochs, sampling minibatches of 64 transitions per update.
train_sac(num_epochs=50, batch_size=64)