<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Soft_Actor_Critic_(SAC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal

# Define the policy network
class PolicyNet(nn.Module):
    """Gaussian policy whose samples are squashed into (-1, 1) with tanh.

    A two-hidden-layer MLP (256 units, ReLU) feeds two linear heads that
    produce the mean and the log standard deviation of a diagonal Normal
    distribution over pre-squash actions.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Number of action components produced per state.
    """

    # Bounds applied to the predicted log standard deviation before exp().
    LOG_STD_MIN = -20
    LOG_STD_MAX = 2
    # Keeps the argument of log() positive when tanh saturates at +/-1.
    EPS = 1e-6

    def __init__(self, state_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        self.mu = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)

    def forward(self, x):
        """Return ``(mu, std)`` of the pre-squash Normal for input ``x``.

        Args:
            x: Tensor whose last dimension has size ``state_dim``.

        Returns:
            Tuple ``(mu, std)``; both have last dimension ``action_dim`` and
            ``std`` is strictly positive.
        """
        x = self.fc(x)
        mu = self.mu(x)
        # Clip log_std to prevent numerical issues when exponentiating.
        log_std = torch.clamp(self.log_std(x), self.LOG_STD_MIN, self.LOG_STD_MAX)
        std = torch.exp(log_std)
        return mu, std

    def sample(self, state):
        """Draw a reparameterised, tanh-squashed action and its log-probability.

        Args:
            state: Tensor whose last dimension has size ``state_dim``.

        Returns:
            Tuple ``(action, log_prob)``. ``action`` lies in [-1, 1] and has
            last dimension ``action_dim``. ``log_prob`` is the log-density of
            ``action`` itself (not of the pre-squash sample), summed over the
            action components, with a trailing dimension of size 1.
        """
        mu, std = self.forward(state)
        normal = Normal(mu, std)
        z = normal.rsample()  # rsample keeps the draw differentiable w.r.t. mu/std
        action = torch.tanh(z)
        # Change of variables for a = tanh(z):
        #   log p(a) = log p(z) - log(1 - tanh(z)^2)
        per_dim = normal.log_prob(z) - torch.log(1 - action.pow(2) + self.EPS)
        # Components are independent, so the joint log-density is their sum.
        log_prob = per_dim.sum(dim=-1, keepdim=True)
        return action, log_prob

# Define the value network
class ValueNet(nn.Module):
    """Three-layer MLP that maps an input vector to one scalar output.

    The constructor argument is simply the width of the input; elsewhere in
    this file the class is instantiated both with the state size alone and
    with the state size plus the action size.

    Args:
        state_dim: Size of the last dimension of the input tensor.
    """

    def __init__(self, state_dim):
        super().__init__()
        width = 256
        stack = [
            nn.Linear(state_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ]
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Return the network output for ``x``; the last dimension has size 1."""
        estimate = self.fc(x)
        return estimate

# Hyperparameters
# NOTE(review): `new_step_api` is a version-specific gym keyword; confirm the
# installed gym release accepts it. The training loop unpacks five values
# from env.step(), which matches this setting.
env = gym.make("Pendulum-v1", new_step_api=True)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
lr = 3e-4      # learning rate shared by every Adam optimizer below
gamma = 0.99   # discount applied to the bootstrapped target
alpha = 0.2    # weight of the log-probability (entropy) term in the losses
tau = 0.005    # interpolation weight used by soft_update

# Initialize networks and optimizer
policy_net = PolicyNet(state_dim, action_dim)
# The two Q-functions take the state and action concatenated as one vector.
value_net1 = ValueNet(state_dim + action_dim)
value_net2 = ValueNet(state_dim + action_dim)
target_value_net1 = ValueNet(state_dim + action_dim)
target_value_net2 = ValueNet(state_dim + action_dim)
# State-only value function and its target copy.
value_net = ValueNet(state_dim)
target_value_net = ValueNet(state_dim)

# Start every target network as an exact copy of its online network.
# Without this each target begins from its own random initialisation, so the
# first bootstrapped targets come from weights unrelated to the network
# being trained.
target_value_net1.load_state_dict(value_net1.state_dict())
target_value_net2.load_state_dict(value_net2.state_dict())
target_value_net.load_state_dict(value_net.state_dict())

# Only the online networks get optimizers; targets are moved by soft_update.
optimizer_policy = optim.Adam(policy_net.parameters(), lr=lr)
optimizer_value1 = optim.Adam(value_net1.parameters(), lr=lr)
optimizer_value2 = optim.Adam(value_net2.parameters(), lr=lr)
optimizer_value = optim.Adam(value_net.parameters(), lr=lr)

def soft_update(target, source):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``,
    where ``tau`` is the module-level constant. Parameters are paired in the
    order yielded by ``parameters()``; ``source`` is left untouched.

    Args:
        target: Module whose parameters are overwritten.
        source: Module whose parameters are blended in.
    """
    keep = 1 - tau
    pairs = zip(target.parameters(), source.parameters())
    for slow, fast in pairs:
        blended = tau * fast.data + keep * slow.data
        slow.data.copy_(blended)

# Training loop
# Each environment step is followed by exactly one gradient update computed
# from that single transition; there is no replay buffer or mini-batching.
for episode in range(1000):
    state = env.reset()
    total_reward = 0
    done = False

    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32)
        action, log_prob = policy_net.sample(state_tensor)
        # NOTE(review): the action is passed to the environment exactly as
        # sampled; confirm its range matches what the environment expects.
        action_np = action.detach().numpy()
        next_state, reward, done, truncated, info = env.step(action_np)
        total_reward += reward

        reward_tensor = torch.tensor(reward, dtype=torch.float32)
        next_state_tensor = torch.tensor(next_state, dtype=torch.float32)

        # --- Bootstrapped target for the two Q-functions -------------------
        with torch.no_grad():
            next_action, next_log_prob = policy_net.sample(next_state_tensor)
            next_input = torch.cat([next_state_tensor, next_action], dim=-1)
            target_q_value = torch.min(
                target_value_net1(next_input),
                target_value_net2(next_input)
            ) - alpha * next_log_prob
            # Do not bootstrap past a terminal transition. `done` is the
            # third value returned by env.step(); `truncated` is a separate
            # flag and deliberately does not zero the bootstrap term.
            not_terminal = 0.0 if done else 1.0
            target_value = reward_tensor + gamma * not_terminal * target_q_value

        # --- Q-function updates --------------------------------------------
        # The action is detached so these losses only produce gradients for
        # the Q-networks. That keeps the policy graph untouched until the
        # policy loss below and removes the need for retain_graph.
        critic_input = torch.cat([state_tensor, action.detach()], dim=-1)
        q_value1 = value_net1(critic_input)
        q_value2 = value_net2(critic_input)
        value1_loss = nn.functional.mse_loss(q_value1, target_value)
        value2_loss = nn.functional.mse_loss(q_value2, target_value)

        optimizer_value1.zero_grad()
        value1_loss.backward()
        optimizer_value1.step()

        optimizer_value2.zero_grad()
        value2_loss.backward()
        optimizer_value2.step()

        # --- State-value and policy updates --------------------------------
        # Re-evaluate the freshly updated Q-networks, this time keeping the
        # action attached so the policy loss can differentiate through it.
        policy_input = torch.cat([state_tensor, action], dim=-1)
        q_value = torch.min(value_net1(policy_input), value_net2(policy_input))

        # value_net is regressed onto the min-Q estimate. Its output is not
        # read by the target computation above, which uses only the two
        # target Q-networks.
        value_loss = nn.functional.mse_loss(value_net(state_tensor), q_value.detach())

        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()

        policy_loss = (alpha * log_prob - q_value).mean()

        # This backward pass also leaves gradients on the Q-networks; they
        # are cleared by the zero_grad() calls before the next Q update.
        optimizer_policy.zero_grad()
        policy_loss.backward()
        optimizer_policy.step()

        # Move every target network a small step toward its online network.
        soft_update(target_value_net1, value_net1)
        soft_update(target_value_net2, value_net2)
        soft_update(target_value_net, value_net)

        state = next_state

        # Leave the episode on either flag; the `while` condition alone
        # would miss `truncated`.
        if done or truncated:
            break

    print(f"Episode {episode}, Total Reward: {total_reward}")

print("Training complete.")