<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Soft_Actor_Critic_(SAC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
from collections import deque
import random
import numpy as np

class Actor(nn.Module):
    """Policy network producing a probability for each discrete action.

    Two ReLU hidden layers of 256 units feed a linear output layer whose
    activations are normalised with a softmax over the last dimension.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 256
        # The attribute names double as state_dict keys, so they stay fixed.
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_dim)

    def forward(self, state):
        """Return action probabilities for ``state`` (softmax over dim -1)."""
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = nn.functional.relu(layer(hidden))
        logits = self.fc3(hidden)
        return nn.functional.softmax(logits, dim=-1)

class Critic(nn.Module):
    """Twin Q-network: two independent MLP heads over (state, action) pairs.

    Each head maps the concatenated state/action vector through two ReLU
    hidden layers of 256 units to a single scalar Q-value.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        input_width = state_dim + action_dim
        # Heads are built in q1, q2 order so seeded initialisation and the
        # state_dict key layout (q1.0, q1.2, q1.4, q2.0, ...) stay the same.
        self.q1 = self._build_head(input_width)
        self.q2 = self._build_head(input_width)

    @staticmethod
    def _build_head(input_width, hidden_units=256):
        """Create one Q-value head: Linear-ReLU-Linear-ReLU-Linear(1)."""
        layers = [
            nn.Linear(input_width, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
        ]
        return nn.Sequential(*layers)

    def forward(self, state, action):
        """Return ``(q1, q2)`` for batched ``state`` and ``action`` tensors.

        The inputs are concatenated along dimension 1, so both must be 2-D.
        """
        joint_input = torch.cat((state, action), dim=1)
        return self.q1(joint_input), self.q2(joint_input)

def one_hot_encoding(action, action_dim):
    """Return a float vector of length ``action_dim`` with a 1 at ``action``.

    All other entries are 0. ``action`` is used directly as an index into
    the vector.
    """
    encoded = torch.zeros(action_dim)
    encoded[action] = 1
    return encoded

# Build the environment, then the actor, the critic and a target copy of the critic.
# NOTE(review): `new_step_api=True` presumably selects the 5-tuple
# (obs, reward, terminated, truncated, info) return of `env.step` that the
# training loop unpacks — confirm against the installed gym version.
env = gym.make('CartPole-v1', new_step_api=True)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n  # Use the number of discrete actions
actor = Actor(state_dim, action_dim)
critic = Critic(state_dim, action_dim)
critic_target = Critic(state_dim, action_dim)
# The target critic starts from the same weights as the online critic.
critic_target.load_state_dict(critic.state_dict())

# Define optimizers (the target critic has none; update_networks blends it
# towards the online critic using `tau`).
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

# Define replay buffer and training hyperparameters
replay_buffer = deque(maxlen=1000000)  # oldest transitions are dropped once full
batch_size = 256  # transitions sampled per update; also the minimum buffer size before updating
gamma = 0.99  # discount factor applied to the bootstrapped next-state value
tau = 0.005  # interpolation weight for the soft target-network update

def _q_values_for_all_actions(q_network, states, num_actions):
    """Evaluate both Q-heads of ``q_network`` for every discrete action.

    The critic consumes (state, one-hot action) pairs, so each state is
    paired with every one-hot action and the outputs are reshaped so that
    row ``i`` / column ``j`` holds Q(state_i, action_j).

    Returns:
        Two tensors of shape ``(batch, num_actions)``.
    """
    batch = states.shape[0]
    # s0, s0, ..., s1, s1, ...  paired with  a0, a1, ..., a0, a1, ...
    tiled_states = states.repeat_interleave(num_actions, dim=0)
    tiled_actions = torch.eye(num_actions, dtype=states.dtype).repeat(batch, 1)
    q1, q2 = q_network(tiled_states, tiled_actions)
    return q1.view(batch, num_actions), q2.view(batch, num_actions)


def update_networks(alpha=0.2):
    """Run one critic update, one soft target update and one actor update.

    Does nothing until the replay buffer holds at least ``batch_size``
    transitions.

    The previous implementation built the actor loss from actions sampled
    with ``multinomial`` and re-encoded as fresh one-hot tensors. Sampling is
    not differentiable, so no gradient ever reached the actor and the policy
    was never trained. Because the action space is discrete, the expectation
    over actions is computed exactly instead: the loss is weighted by the
    (differentiable) action probabilities.

    Args:
        alpha: Entropy temperature. The soft value used in both the critic
            target and the actor loss is
            ``sum_a pi(a|s) * (min(Q1, Q2)(s, a) - alpha * log pi(a|s))``.
            Pass ``0.0`` to disable the entropy bonus.
    """
    if len(replay_buffer) < batch_size:
        return

    batch = random.sample(replay_buffer, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    # Stack the per-transition arrays with NumPy first; building a tensor
    # directly from a sequence of ndarrays is very slow.
    states = torch.as_tensor(np.array(states), dtype=torch.float32)
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.int64)
    actions_one_hot = torch.nn.functional.one_hot(actions, action_dim).to(torch.float32)
    rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)
    dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)

    # Clamp before the log so a probability that underflows to 0 cannot
    # produce -inf (and then NaN when multiplied by that 0 probability).
    min_prob = 1e-8

    # --- Critic update -----------------------------------------------------
    with torch.no_grad():
        next_probs = actor(next_states)
        next_log_probs = torch.log(next_probs.clamp_min(min_prob))
        next_q1, next_q2 = _q_values_for_all_actions(critic_target, next_states, action_dim)
        next_min_q = torch.min(next_q1, next_q2)
        # Exact expectation over next actions of the soft value.
        next_value = (next_probs * (next_min_q - alpha * next_log_probs)).sum(dim=1, keepdim=True)
        target_q = rewards + gamma * (1 - dones) * next_value

    current_q1, current_q2 = critic(states, actions_one_hot)
    critic_loss = (
        nn.functional.mse_loss(current_q1, target_q)
        + nn.functional.mse_loss(current_q2, target_q)
    )

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    # --- Soft (Polyak) update of the target critic -------------------------
    with torch.no_grad():
        for param, target_param in zip(critic.parameters(), critic_target.parameters()):
            target_param.copy_(tau * param + (1 - tau) * target_param)

    # --- Actor update ------------------------------------------------------
    action_probs = actor(states)
    log_probs = torch.log(action_probs.clamp_min(min_prob))
    # Q-values are constants for the policy step; evaluating them without
    # autograd also stops stray gradients from accumulating on the critic.
    with torch.no_grad():
        q1_all, q2_all = _q_values_for_all_actions(critic, states, action_dim)
        min_q = torch.min(q1_all, q2_all)
    actor_loss = (action_probs * (alpha * log_probs - min_q)).sum(dim=1).mean()

    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

# Training loop: collect one transition per step, store it in the replay
# buffer and run one network update after every environment step.
num_episodes = 500
max_timesteps = 1000

for episode in range(num_episodes):
    state = env.reset()
    total_reward = 0

    for t in range(max_timesteps):
        # Keep the observation as a float32 array: it is both the policy
        # input and what gets stored in the replay buffer.
        state = np.asarray(state, dtype=np.float32)

        # Action selection does not need an autograd graph.
        with torch.no_grad():
            action_probs = actor(torch.as_tensor(state))
        action = action_probs.multinomial(1).item()

        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Store only true termination as the "done" flag. A truncated episode
        # was merely cut off, so its next state still has value; storing
        # `done` here would wrongly zero the bootstrap term in the critic
        # target for those transitions.
        replay_buffer.append((state, action, reward, next_state, bool(terminated)))

        state = next_state
        total_reward += reward

        update_networks()

        # The episode itself still ends on either termination or truncation.
        if done:
            break

    print(f"Episode {episode}, total reward: {total_reward}")

# Save the actor and critic models
torch.save(actor.state_dict(), 'actor_sac.pth')
torch.save(critic.state_dict(), 'critic_sac.pth')