In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gym
import numpy as np
import random
from collections import deque

#replay buffer: default capacity shared by the buffer and the agent
default_buffer_size = 100000
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) tuples.

    The underlying deque is bounded by ``max_size``, so pushing into a full
    buffer silently discards the oldest transition.
    """

    def __init__(self, max_size=default_buffer_size):
        self.buffer = deque(maxlen=max_size)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the right end of the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns a 5-tuple of arrays (states, actions, rewards, next_states,
        dones), each stacked along a new leading batch axis.
        """
        transitions = random.sample(self.buffer, batch_size)
        #transpose the list of transitions into one column per field, then stack
        states, actions, rewards, next_states, dones = (
            np.stack(column) for column in zip(*transitions)
        )
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

#bounds applied to log(std) before exponentiating
LOG_STD_MIN = -20
LOG_STD_MAX = 2
#actor network: state -> diagonal Gaussian over pre-squash actions
class Actor(nn.Module):
    """Gaussian policy with a shared two-layer ReLU trunk and two linear heads.

    ``forward`` returns the mean and standard deviation of the pre-squash
    Gaussian; ``sample`` draws a reparameterized action squashed by tanh and
    returns its log-probability.
    """

    def __init__(self, num_inputs, num_actions, hidden_size=256):
        super().__init__()
        trunk_layers = [
            nn.Linear(num_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk_layers)
        #separate heads for the Gaussian mean and log standard deviation
        self.mean = nn.Linear(hidden_size, num_actions)
        self.log_std = nn.Linear(hidden_size, num_actions)

    def forward(self, state):
        """Return ``(mean, std)`` with log(std) clamped to [LOG_STD_MIN, LOG_STD_MAX]."""
        features = self.net(state)
        mu = self.mean(features)
        bounded_log_std = torch.clamp(
            self.log_std(features), min=LOG_STD_MIN, max=LOG_STD_MAX
        )
        return mu, torch.exp(bounded_log_std)

    def sample(self, state):
        """Sample a tanh-squashed action and its log-probability.

        Returns ``(action, log_prob)`` where ``action`` lies in (-1, 1) and
        ``log_prob`` is summed over dim 1 with that dimension kept.
        """
        mu, sigma = self(state)
        gaussian = torch.distributions.Normal(mu, sigma)
        pre_tanh = gaussian.rsample() #reparameterized so gradients reach mu/sigma
        action = torch.tanh(pre_tanh)
        #change-of-variables term for tanh; 1e-6 keeps the log argument positive
        squash_correction = torch.log(1 - action.pow(2) + 1e-6)
        per_dim_log_prob = gaussian.log_prob(pre_tanh) - squash_correction
        return action, per_dim_log_prob.sum(dim=1, keepdim=True)

#critic network: two independent Q-functions over (state, action)
class Critic(nn.Module):
    """Twin Q-networks that each map a concatenated state-action vector to a scalar."""

    def __init__(self, num_inputs, num_actions, hidden_size=256):
        super().__init__()
        joint_dim = num_inputs + num_actions

        def make_q_net():
            #two hidden ReLU layers followed by a single-output linear layer
            return nn.Sequential(
                nn.Linear(joint_dim, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        self.q1_net = make_q_net()
        self.q2_net = make_q_net()

    def forward(self, state, action):
        """Return ``(q1, q2)`` evaluated on ``state`` and ``action`` joined along dim 1."""
        joint = torch.cat((state, action), dim=1)
        return self.q1_net(joint), self.q2_net(joint)

#define agent to instantiate actor/critic networks and other hyperparameters
class SACAgent:
    """Soft Actor-Critic agent: policy, twin critics, target critics,
    learnable entropy temperature and replay buffer.

    Args:
        env: environment whose ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` size the networks.
        gamma: discount factor used in the Q-target.
        tau: interpolation factor of the soft target-critic update.
        target_entropy: entropy target for temperature tuning; ``None``
            selects ``-action_dim``.
        lr: learning rate shared by the actor, critic and temperature optimizers.
        buffer_size: capacity passed to the replay buffer.
        batch_size: minibatch size drawn on every update.

    Notes:
        A module-level ``device`` must exist before an agent is instantiated.
        NOTE(review): actions are tanh-squashed to (-1, 1) and are not rescaled
        to the environment's action bounds anywhere in this class.
    """

    def __init__(self, env, gamma=0.99, tau=0.005, target_entropy=None,
                 lr=3e-4, buffer_size=default_buffer_size, batch_size=64):
        self.env = env
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size

        #get observation and action space dimensions
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        self.actor = Actor(obs_dim, action_dim).to(device)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr)

        #the critic holds two Q-networks; their minimum is used to curb overestimation
        self.critic = Critic(obs_dim, action_dim).to(device)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr)
        self.critic_target = Critic(obs_dim, action_dim).to(device)
        #target critic starts as an exact copy of the online critic
        self.critic_target.load_state_dict(self.critic.state_dict())

        if target_entropy is None:
            self.target_entropy = -action_dim #default target: negative action dimensionality
        else:
            self.target_entropy = target_entropy
        #log(alpha) is the trained parameter; float32 keeps it in the same dtype as the networks
        self.log_alpha = torch.tensor(np.log(0.2), dtype=torch.float32,
                                      requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=lr)

        self.buffer = ReplayBuffer(buffer_size)

    @property
    def alpha(self):
        """Entropy temperature, ``exp(log_alpha)``, which is always positive."""
        return self.log_alpha.exp()

    def select_action(self, state, evaluate=False):
        """Return an action for a single (unbatched) observation as a numpy array.

        With ``evaluate=True`` the tanh of the policy mean is returned;
        otherwise an action is sampled from the policy.
        """
        obs = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        #acting needs no gradients, so skip building the autograd graph
        with torch.no_grad():
            if evaluate:
                mean, _ = self.actor(obs)
                action = torch.tanh(mean) #same squashing as used when sampling
            else:
                action, _ = self.actor.sample(obs)
        return action.cpu().numpy()[0]

    def update(self):
        """Run one gradient step on the critic, the actor and the temperature,
        then softly update the target critic.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if len(self.buffer) < self.batch_size:
            return
        batch = self.buffer.sample(self.batch_size)
        state, action, reward, next_state, done = (
            torch.as_tensor(field, dtype=torch.float32, device=device) for field in batch
        )
        #reward/done arrive as (batch,); add a column so they broadcast against (batch, 1) Q-values
        reward = reward.unsqueeze(1)
        done = done.unsqueeze(1)

        #temperature is a constant for the critic and actor losses
        alpha = self.alpha.detach()

        #compute target q-value
        with torch.no_grad():
            next_action, next_log_prob = self.actor.sample(next_state)
            q1_next, q2_next = self.critic_target(next_state, next_action)
            q_next = torch.min(q1_next, q2_next) - alpha * next_log_prob #min of twin targets plus entropy bonus
            q_target = reward + (1 - done) * self.gamma * q_next #no bootstrapping when done == 1

        #critic update: both Q-networks regress onto the same target
        q1, q2 = self.critic(state, action)
        critic_loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        #actor update: maximize min-Q plus entropy of freshly sampled actions
        action_new, log_prob_new = self.actor.sample(state)
        q1_new, q2_new = self.critic(state, action_new)
        q_new = torch.min(q1_new, q2_new)
        actor_loss = (alpha * log_prob_new - q_new).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        #temperature update: the log-prob term is detached so only log_alpha gets a gradient
        alpha_loss = -(self.log_alpha * (log_prob_new + self.target_entropy).detach()).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        #soft update: target <- tau * online + (1 - tau) * target
        with torch.no_grad():
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def store_transition(self, state, action, reward, next_state, done):
        """Push one transition into the replay buffer."""
        self.buffer.push(state, action, reward, next_state, done)

#device shared by the agent's networks and tensors: GPU when available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#main training loop
def train(env_name='Pendulum-v1', max_episodes=200, max_steps=200):
    """Train a SACAgent on a Gym environment and print a running average reward.

    Args:
        env_name: environment id passed to ``gym.make``.
        max_episodes: number of episodes to run.
        max_steps: upper bound on environment steps per episode.

    Every 10 episodes the mean reward of the last 10 episodes is printed.
    The environment is closed even if training raises.
    """
    #instantiate environment and agent
    env = gym.make(env_name)
    try:
        agent = SACAgent(env)
        total_rewards = []

        for episode in range(max_episodes):
            state, _ = env.reset()
            episode_reward = 0
            for step in range(max_steps):
                #select and execute an action
                action = agent.select_action(state)
                next_state, reward, done, truncated, _ = env.step(action)
                #store only the termination flag: the agent uses it to switch off
                #bootstrapping, which should still happen after a time-limit truncation
                agent.store_transition(state, action, reward, next_state, done)
                agent.update()
                state = next_state
                episode_reward += reward
                #a truncated episode is over too; stepping past it requires a reset
                if done or truncated:
                    break
            #periodic reward reporting
            total_rewards.append(episode_reward)
            if (episode + 1) % 10 == 0:
                avg_reward = np.mean(total_rewards[-10:])
                print(f"Episode {episode+1}, Average Reward: {avg_reward:.2f}")
    finally:
        env.close()

#start training with the default arguments when this file is executed directly (skipped on import)
if __name__ == "__main__":
    train()
