<a href="https://colab.research.google.com/github/Cosmox999/SOC-RL/blob/main/sac.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install numpy==1.23.5 gym==0.25.2

Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m67.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
scikit-image 0.25.2 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
chex 0.1.89 requires numpy>=1.24.1, but you have numpy 1.23.5 which is incompatible.
pymc 5.20.1 requires numpy>=1.25.0, but you have numpy 1.23.5 which is incompatible.
imbalanced-learn 0.13.0

In [3]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import matplotlib.pyplot as plt
import os
from gym.wrappers import RecordVideo

class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions with uniform (with-replacement) sampling.

    Transitions are stored as ``(state, action, reward, next_state, done)`` tuples.
    Once ``capacity`` entries exist, each new push overwrites the oldest slot.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index of the slot the next push writes to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) >= self.capacity:
            self.buffer[self.position] = transition
        else:
            # While filling, the write slot is always the current end of the list.
            self.buffer.append(transition)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns a 5-tuple of numpy arrays: states, actions, rewards, next_states, dones.
        """
        picks = np.random.choice(len(self.buffer), batch_size)
        states, actions, rewards, next_states, dones = zip(*(self.buffer[i] for i in picks))
        return tuple(np.array(column) for column in (states, actions, rewards, next_states, dones))

    def __len__(self):
        return len(self.buffer)

class QNetwork(nn.Module):
    """Critic MLP mapping a (state, action) pair to a scalar Q-value.

    The state and action are concatenated along dim 1, passed through two
    256-unit ReLU layers, and projected to an output of shape ``(batch, 1)``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 256
        self.fc1 = nn.Linear(state_dim + action_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, 1)

    def forward(self, state, action):
        h = torch.cat((state, action), dim=1)
        for hidden_layer in (self.fc1, self.fc2):
            h = hidden_layer(h).relu()
        return self.fc3(h)

class PolicyNetwork(nn.Module):
    """Squashed-Gaussian SAC actor.

    An MLP (two 256-unit ReLU layers) predicts the mean and clamped log-std of a
    diagonal Gaussian over a pre-squash variable ``x``; the emitted action is
    ``tanh(x) * action_scale + action_bias``.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Size of the action vector.
        action_scale: Half-width of the action range (scalar or per-dimension
            sequence). The default 2.0 is presumably chosen for Pendulum-v1's
            torque bound; for other envs use ``(high - low) / 2``.
        action_bias: Centre of the action range, i.e. ``(high + low) / 2``.
    """

    def __init__(self, state_dim, action_dim, action_scale=2.0, action_bias=0.0):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)

        self.mean = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)

        # Non-persistent buffers move with .to(device) but are excluded from
        # state_dict, so saved checkpoints keep the same keys as before.
        self.register_buffer(
            "action_scale", torch.as_tensor(action_scale, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "action_bias", torch.as_tensor(action_bias, dtype=torch.float32), persistent=False
        )

    def forward(self, state, deterministic=False):
        """Produce an action and its log-probability for a batch of states.

        Args:
            state: Batched observations, shape ``(batch, state_dim)``.
            deterministic: If True, squash the Gaussian mean instead of drawing a
                reparameterized sample (useful for evaluation).

        Returns:
            ``(action, log_prob)`` with shapes ``(batch, action_dim)`` and ``(batch, 1)``.
        """
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))

        mean = self.mean(x)
        # Clamp keeps std within [exp(-20), exp(2)] for numerical stability.
        log_std = torch.clamp(self.log_std(x), min=-20, max=2)
        std = torch.exp(log_std)

        dist = Normal(mean, std)
        # rsample() keeps the reparameterization path so the policy loss can backprop.
        x_t = mean if deterministic else dist.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        # Change of variables for a = scale * tanh(x) + bias; the 1e-6 guards
        # against log(0) when tanh saturates.
        log_prob = dist.log_prob(x_t) - torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob

def _soft_update(target_net, source_net, tau):
    """Polyak-average ``source_net``'s parameters into ``target_net`` in place.

    Computes ``target <- tau * source + (1 - tau) * target`` for each parameter pair.
    """
    with torch.no_grad():
        for target_param, param in zip(target_net.parameters(), source_net.parameters()):
            target_param.mul_(1.0 - tau).add_(param, alpha=tau)


def sac_train(env, episodes=150, batch_size=256, gamma=0.99, tau=0.005, alpha=0.2, lr=3e-4,
              save_path="sac_policy.pth", seed=None):
    """Train a Soft Actor-Critic agent with a fixed entropy temperature.

    Expects the old gym step API (``env.step`` -> ``(obs, reward, done, info)``,
    ``env.reset()`` -> ``obs``). Once the replay buffer holds more than
    ``batch_size`` transitions, every environment step performs one update of
    the twin critics, the policy and the target critics.

    Args:
        env: Environment with Box observation and action spaces.
        episodes: Number of training episodes.
        batch_size: Minibatch size sampled from the replay buffer.
        gamma: Discount factor.
        tau: Polyak coefficient for the target-critic update.
        alpha: Fixed entropy temperature.
        lr: Adam learning rate for all three optimizers.
        save_path: File the trained policy ``state_dict`` is saved to.
        seed: If given, seeds torch, numpy's global RNG and the first ``env.reset``.

    Returns:
        List of undiscounted returns, one per episode.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Twin critics (clipped double-Q) plus slowly-tracking target copies.
    q_net1 = QNetwork(state_dim, action_dim)
    q_net2 = QNetwork(state_dim, action_dim)
    target_q_net1 = QNetwork(state_dim, action_dim)
    target_q_net2 = QNetwork(state_dim, action_dim)
    policy_net = PolicyNetwork(state_dim, action_dim)

    target_q_net1.load_state_dict(q_net1.state_dict())
    target_q_net2.load_state_dict(q_net2.state_dict())

    q_optimizer1 = optim.Adam(q_net1.parameters(), lr=lr)
    q_optimizer2 = optim.Adam(q_net2.parameters(), lr=lr)
    policy_optimizer = optim.Adam(policy_net.parameters(), lr=lr)

    replay_buffer = ReplayBuffer(100000)

    rewards = []

    for episode in range(episodes):
        if seed is not None and episode == 0:
            state = env.reset(seed=seed)
        else:
            state = env.reset()
        done = False
        total_reward = 0

        while not done:
            # Acting does not need an autograd graph.
            with torch.no_grad():
                action, _ = policy_net(torch.FloatTensor(state).unsqueeze(0))
            action = action.numpy()[0]

            next_state, reward, done, info = env.step(action)
            # A TimeLimit cut-off ends the episode but is not a true terminal
            # state, so the critic target must still bootstrap from next_state.
            terminal = done and not info.get("TimeLimit.truncated", False)
            replay_buffer.push(state, action, reward, next_state, terminal)

            state = next_state
            total_reward += reward

            if len(replay_buffer) > batch_size:
                states, actions, rewards_batch, next_states, terminals = replay_buffer.sample(batch_size)

                states = torch.FloatTensor(states)
                actions = torch.FloatTensor(actions)
                rewards_batch = torch.FloatTensor(rewards_batch).unsqueeze(1)
                next_states = torch.FloatTensor(next_states)
                # 1.0 for transitions to bootstrap through, 0.0 for true terminals.
                not_done = torch.FloatTensor(1.0 - terminals.astype(np.float32)).unsqueeze(1)

                # Soft Bellman target: r + gamma * (min target Q - alpha * log pi(a'|s')).
                with torch.no_grad():
                    next_actions, next_log_probs = policy_net(next_states)
                    q1_target = target_q_net1(next_states, next_actions)
                    q2_target = target_q_net2(next_states, next_actions)
                    q_target = torch.min(q1_target, q2_target) - alpha * next_log_probs
                    q_target = rewards_batch + gamma * not_done * q_target

                q1_loss = (q_net1(states, actions) - q_target).pow(2).mean()
                q2_loss = (q_net2(states, actions) - q_target).pow(2).mean()

                q_optimizer1.zero_grad()
                q1_loss.backward()
                q_optimizer1.step()

                q_optimizer2.zero_grad()
                q2_loss.backward()
                q_optimizer2.step()

                # This backward pass also leaves gradients on the critics'
                # parameters; the critic zero_grad() calls above clear them on
                # the next update before they are used.
                actions_pi, log_probs = policy_net(states)
                q_pi = torch.min(q_net1(states, actions_pi), q_net2(states, actions_pi))
                policy_loss = (alpha * log_probs - q_pi).mean()

                policy_optimizer.zero_grad()
                policy_loss.backward()
                policy_optimizer.step()

                _soft_update(target_q_net1, q_net1, tau)
                _soft_update(target_q_net2, q_net2, tau)

        rewards.append(total_reward)

        if episode % 10 == 0:
            print(f"Episode: {episode}, Reward: {total_reward}")

    torch.save(policy_net.state_dict(), save_path)

    return rewards

def plot_rewards(rewards, title="SAC Rewards", save_path="sac_rewards.png"):
    """Plot per-episode returns and save the figure to ``save_path``.

    The figure is closed after saving, so nothing is rendered inline; open the
    saved image to view it.

    Args:
        rewards: Sequence of episode returns, plotted against episode index.
        title: Figure title.
        save_path: Output image path (defaults to the previous hard-coded name).
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(rewards)
    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel("Total Reward")
    fig.savefig(save_path)
    plt.close(fig)

if __name__ == "__main__":
    # RecordVideo writes episode videos into this directory.
    os.makedirs("videos", exist_ok=True)

    # First train the agent (200 episodes; sac_train's default is 150).
    train_env = gym.make('Pendulum-v1')
    rewards = sac_train(train_env, episodes=200)
    train_env.close()
    plot_rewards(rewards)

    # Then record one evaluation episode of the trained agent.
    env = RecordVideo(gym.make('Pendulum-v1'), "videos", name_prefix="sac")
    state = env.reset()
    done = False

    policy_net = PolicyNetwork(env.observation_space.shape[0], env.action_space.shape[0])
    # weights_only=True restricts unpickling to tensors/plain containers, which is
    # all a state_dict needs; map_location keeps loading CPU-only.
    policy_net.load_state_dict(torch.load("sac_policy.pth", map_location="cpu", weights_only=True))
    policy_net.eval()

    with torch.no_grad():
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action, _ = policy_net(state_tensor)
            state, _, done, _ = env.step(action.numpy()[0])

    env.close()

    # In Colab, you can view the video with:
    # from IPython.display import HTML
    # import base64
    #
    # def show_video(video_path):
    #     video_file = open(video_path, "r+b").read()
    #     video_url = f"data:video/mp4;base64,{base64.b64encode(video_file).decode()}"
    #     return HTML(f"""<video width="400" controls><source src="{video_url}"></video>""")
    #
    # video_path = "videos/sac-episode-0.mp4"
    # show_video(video_path)

Episode: 0, Reward: -1315.5934731809123
Episode: 10, Reward: -1178.662697110516
Episode: 20, Reward: -868.9250862488867
Episode: 30, Reward: -1040.6500903699769
Episode: 40, Reward: -647.0165651645594
Episode: 50, Reward: -648.9576955947495
Episode: 60, Reward: -1023.5574159953212
Episode: 70, Reward: -735.725062390191
Episode: 80, Reward: -792.0655109856966
Episode: 90, Reward: -126.63417408262525
Episode: 100, Reward: -123.71334244519775
Episode: 110, Reward: -371.95395358083374
Episode: 120, Reward: -247.07317151519567
Episode: 130, Reward: -4.8028083386638185
Episode: 140, Reward: -358.61009266040793
Episode: 150, Reward: -123.4839473142723
Episode: 160, Reward: -360.05168393455335
Episode: 170, Reward: -123.53525797267784
Episode: 180, Reward: -119.41148606879798
Episode: 190, Reward: -240.90747643913244


  logger.warn(
  logger.deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  from pkg_resources import resource_stream, resource_exists
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
  policy_net.load_state_dict(torch.load("sac_policy.pth"))
See here for more information: https://www.gym