## **Vanilla SAC**

#### **Imports**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import warnings
from collections import deque
import copy

import numpy as np

import gymnasium as gym
from gymnasium.wrappers import RescaleAction
import matplotlib.pyplot as plt

In [None]:
# Silence every UserWarning for the rest of the session.
# NOTE(review): this also hides PyTorch's shape-broadcast warnings (e.g. from
# F.mse_loss), which are useful when debugging the update step.
warnings.simplefilter('ignore', UserWarning)

#### **Device Setup**

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Device working on: {device}')

#### **Env Setup**

In [None]:
# Create the 'Humanoid-v5' environment with a 1000-step episode cap and wrap it
# so that the action range exposed to the agent is [-1.0, 1.0].
env = RescaleAction(
    gym.make('Humanoid-v5', max_episode_steps = 1000),
    min_action = -1.0,
    max_action = 1.0,
)

# Sizes read from the wrapped spaces; reused as globals by the cells below.
observation_shape = env.observation_space.shape
action_space = env.action_space

state_dim = observation_shape[0]
action_dim = action_space.shape[0]
max_action = action_space.high[0]   # upper bound of the first action component

print(f'State dim: {state_dim}\nAction dim: {action_dim} | Max Action range: {max_action}')

## **Architecture**

In [None]:
class Feature_extractor(nn.Module):
    """Eight-stage MLP trunk shared by the actor and the critic.

    Every stage is ``Linear -> SiLU``. Stages 2, 5 and 7 (1-based) are
    additionally preceded by a ``LayerNorm(h)``.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the returned features.
        h: width of all hidden layers.
    """

    def __init__(self, input_dim, output_dim, h):
        super().__init__()

        # Layer widths from input to output: one entry per stage boundary.
        widths = [input_dim] + [h] * 7 + [output_dim]

        # 0-based stage indices that get a LayerNorm placed directly in front.
        normed_stages = (1, 4, 6)

        layers = []
        for stage, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if stage in normed_stages:
                layers.append(nn.LayerNorm(h))
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.SiLU())

        # Modules are created and registered in the same order as a literal
        # nn.Sequential(...) listing, so state_dict keys are feature.0 ... feature.18.
        self.feature = nn.Sequential(*layers)

    def forward(self, x):
        """Return the features produced by running ``x`` through the stack."""
        return self.feature(x)

##### **Actor Design**

In [None]:
class Actor_network(nn.Module):
    """Stochastic policy: a tanh-squashed Gaussian over actions.

    Pipeline: ``Feature_extractor`` -> LayerNorm -> multi-head self-attention
    -> small MLP -> two linear heads (mean and log-std).

    Args:
        state_dim: size of the state vector.
        action_dim: number of action components.
        n1: feature / attention embedding width (must be divisible by the
            8 attention heads).
        n2, n3, n4: widths of the actor MLP layers.
        h1: hidden width of the feature extractor.
        max_action: factor the squashed action is multiplied by. Defaults to
            the module-level ``max_action`` at class-definition time.
    """

    def __init__(self, state_dim, action_dim, n1, n2, n3, n4, h1, max_action = max_action):
        super().__init__()

        self.max_action = max_action

        # Shared-architecture trunk that embeds the raw state.
        self.feature = Feature_extractor(input_dim = state_dim, output_dim = n1, h = h1)

        # Normalisation + self-attention applied to the embedded state.
        self.norm = nn.LayerNorm(n1)
        self.mha = nn.MultiheadAttention(embed_dim = n1, num_heads = 8, batch_first = True)

        # Actor MLP feeding the two distribution heads.
        self.actor = nn.Sequential(

            nn.Linear(n1, n2),
            nn.SiLU(),

            nn.LayerNorm(n2),
            nn.Linear(n2, n3),
            nn.SiLU(),

            nn.Linear(n3, n4)
        )

        # Mean head
        self.mu = nn.Linear(n4, action_dim)

        # Log-std head
        self.log_std = nn.Linear(n4, action_dim)

    def forward(self, state):
        """Sample an action and return it with its log-probability.

        Args:
            state: batch of states; assumed to be ``(batch, state_dim)``
                (the code inserts and removes axis 1 and sums over axis 1).

        Returns:
            ``(action, log_prob)`` where ``action`` is ``tanh(z) * max_action``
            and ``log_prob`` has shape ``(batch, 1)``.
        """

        feature = self.feature(state)

        x = self.norm(feature)

        # Treat each embedded state as a length-1 sequence for the attention layer.
        x = x.unsqueeze(1)
        x_2, _ = self.mha(x, x, x)
        x_2 = x_2.squeeze(1)

        x_3 = self.actor(x_2)

        # Both heads are squashed into [-1, 1] with an out-of-place tanh.
        mu = torch.tanh(self.mu(x_3))
        log_std = torch.tanh(self.log_std(x_3))

        # After the tanh above this clamp cannot trigger; kept as a safety net.
        log_std = log_std.clamp(min = -5, max = 2)
        std = torch.exp(log_std)

        # Reparameterization trick: z is the PRE-squash Gaussian sample.
        normal = torch.distributions.Normal(mu, std)
        z = normal.rsample()

        # The density and the Jacobian correction must both be evaluated at the
        # pre-squash z, so tanh is applied out of place. (An in-place tanh_ here
        # overwrites z and silently corrupts log_prob.)
        log_prob = normal.log_prob(z)
        action = torch.tanh(z) * self.max_action

        # Numerically stable log(1 - tanh(z)^2) = 2 * (log 2 - z - softplus(-2z)).
        # NOTE(review): the max_action scale is not included in the correction;
        # that is exact only when max_action == 1.
        squash = 2 * (torch.log(torch.tensor(2.0, device=z.device)) - z - F.softplus(-2 * z))
        log_prob = log_prob - squash
        log_prob = torch.sum(log_prob, dim = 1, keepdim = True)

        return action, log_prob

##### **Critic Design**

In [None]:
class Critic_network(nn.Module):
    """Twin Q-value estimator with a shared trunk.

    The concatenated ``(state, action)`` vector goes through one
    ``Feature_extractor``, a LayerNorm and a multi-head self-attention layer;
    two independent MLP heads then each emit one scalar Q-value.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        n1: feature / attention embedding width.
        n2, n3, n4: widths of the layers inside each Q head.
        h1: hidden width of the feature extractor.
    """

    def __init__(self, state_dim, action_dim, n1, n2, n3, n4, h1):
        super().__init__()

        # Shared trunk over the joint state-action input.
        self.feature = Feature_extractor(input_dim = state_dim + action_dim, output_dim = n1, h = h1)

        self.norm = nn.LayerNorm(n1)
        self.mha = nn.MultiheadAttention(embed_dim = n1, num_heads = 8, batch_first = True)

        def build_head():
            # One Q head: n1 -> n2 -> n3 -> n4 -> 1.
            return nn.Sequential(
                nn.Linear(n1, n2),
                nn.SiLU(),
                nn.LayerNorm(n2),
                nn.Linear(n2, n3),
                nn.SiLU(),
                nn.Linear(n3, n4),
                nn.SiLU(),
                nn.Linear(n4, 1),
            )

        # Two heads with identical layout but separately created weights.
        self.critic_1 = build_head()
        self.critic_2 = build_head()

    def forward(self, state, action):
        """Return ``(q_1, q_2)`` for the given state/action batch.

        3-D inputs have their last axis squeezed before concatenation.
        NOTE(review): ``squeeze(-1)`` only removes a trailing axis of size 1;
        a ``(batch, 1, dim)`` tensor passes through unchanged - confirm callers.
        """

        state, action = (
            tensor.squeeze(-1) if tensor.dim() == 3 else tensor
            for tensor in (state, action)
        )

        joint = torch.cat((state, action), dim = -1)

        # Embed, normalise, then attend over a length-1 sequence.
        tokens = self.norm(self.feature(joint)).unsqueeze(1)
        attended, _ = self.mha(tokens, tokens, tokens)
        attended = attended.squeeze(1)

        return self.critic_1(attended), self.critic_2(attended)

#### **Set up**

In [None]:
# Layer widths shared by the actor and the critic
# (n1: embedding width, h1: feature-extractor hidden width, n2-n4: head widths).

n1, h1 = 128, 128
n2, n3, n4 = 256, 512, 256

# Build the policy network and show its layout.

actor_network = Actor_network(state_dim, action_dim, n1, n2, n3, n4, h1)
actor_network = actor_network.to(device)

print(actor_network)

print('-------------------------------------------------------------------------------')

# Build the twin critic and show its layout.

critic_network = Critic_network(state_dim, action_dim, n1, n2, n3, n4, h1)
critic_network = critic_network.to(device)

print(critic_network)


# Target critic starts as an independent copy of the freshly built critic.
target_critic = copy.deepcopy(critic_network)
target_critic = target_critic.to(device)

### **Soft Update**

In [None]:
def soft_update(source, target, tau = 0.067):
    """Blend ``source`` parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired in ``parameters()`` order, and the write goes through
    ``.data`` so autograd does not record it.

    Args:
        source: module whose parameters are read.
        target: module whose parameters are overwritten.
        tau: interpolation weight given to ``source``.
    """
    keep = 1 - tau

    for src_param, tgt_param in zip(source.parameters(), target.parameters()):
        blended = tau * src_param.data + keep * tgt_param.data
        tgt_param.data.copy_(blended)

## **Agent Setup**

In [None]:
class SAC_Agent:
    """Learner that ties together the actor, the twin critic and its target copy.

    Args:
        actor_network: policy module returning ``(action, log_prob)``.
        critic_network: module returning ``(q_1, q_2)`` for ``(state, action)``.
        target_critic: slowly updated copy of ``critic_network``.
        actor_optimizer, actor_scheduler: optimizer / LR scheduler for the actor.
        critic_optimizer, critic_scheduler: optimizer / LR scheduler for the critic.
        gamma: discount factor used in the Bellman target.

    Relies on the module-level ``device``, ``action_dim`` and ``soft_update``.
    """

    def __init__(self, actor_network, critic_network, target_critic, actor_optimizer, actor_scheduler, critic_optimizer, critic_scheduler, gamma):

        # Networks
        self.actor = actor_network
        self.critic = critic_network
        self.target_critic = target_critic

        # Hyper-parameters
        self.gamma = gamma

        # Optimizers and schedulers
        self.actor_optimizer = actor_optimizer
        self.actor_scheduler = actor_scheduler
        self.critic_optimizer = critic_optimizer
        self.critic_scheduler = critic_scheduler

        # Entropy-temperature auto tuning: log(alpha) is the learned quantity.
        self.target_entropy = - action_dim * 1.5
        self.log_alpha = torch.tensor(np.log(0.2), requires_grad = True, device = device)
        self.alpha_min = 0.1
        self.alpha_optimizer = optim.AdamW([self.log_alpha], lr = 1e-4, weight_decay = 0.001)

    def compute_alpha(self):
        """Return the detached temperature, clamped to ``[alpha_min, 0.2]``.

        Also caches the value on ``self.alpha``.
        """

        self.alpha = self.log_alpha.exp().detach()
        self.alpha = self.alpha.clamp(min = self.alpha_min, max = 0.2)

        return self.alpha

    def update_target(self):
        """Soft-update the target critic towards the online critic."""

        return soft_update(self.critic, self.target_critic)

    def select_action(self, state):
        """Run the actor on ``state`` and return ``(action, log_prob)``.

        ``state`` may be an array-like or a tensor; a 1-D state is promoted to
        a batch of one. The returned tensors are attached to the actor's graph
        unless the caller is inside ``torch.no_grad()``.
        """

        if torch.is_tensor(state):
            # Reuse the tensor instead of copy-constructing it with torch.tensor(),
            # which warns and allocates on every call.
            state = state.detach().to(device = device, dtype = torch.float32)
        else:
            state = torch.tensor(state, dtype = torch.float32, device = device)

        if state.dim() == 1:
            state = state.unsqueeze(0)

        action, log_prob = self.actor(state)

        return action, log_prob

    def update(self, replay_buffer, batch_size):
        """Run one critic, actor, target and temperature update.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns
                ``(state, action, reward, next_state, done)`` tensors.
            batch_size: number of transitions to sample.

        Returns:
            ``(actor_loss, critic_loss, alpha)`` as Python floats.
        """

        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        # Move to device
        state = state.to(device)
        action = action.to(device)
        next_state = next_state.to(device)

        # Force reward/done into (batch, 1) columns. Scalars stored in the buffer
        # stack to shape (batch,), which would broadcast against the (batch, 1)
        # Q-values into a (batch, batch) target and corrupt the critic loss.
        reward = reward.to(device).reshape(-1, 1)
        done = done.to(device).reshape(-1, 1).to(torch.float32)

        # Compute the bootstrapped target value
        with torch.no_grad():

            next_action, _ = self.select_action(next_state)

            if next_action.dim() == 1:
                # unsqueeze is out of place: the result has to be kept.
                next_action = next_action.unsqueeze(0)

            target_1, target_2 = self.target_critic(next_state, next_action)

            # Blend of the pessimistic and optimistic twin estimates.
            # NOTE(review): the standard SAC target also subtracts
            # alpha * log_prob(next_action); that term is absent here - confirm
            # whether the omission is intentional.
            target_Q = ((0.75 * torch.min(target_1, target_2)) + (0.25 * torch.max(target_1, target_2)))
            target_value = reward + self.gamma * (1 - done) * target_Q

        # Critic update: both heads regress onto the same target.
        current_q1, current_q2 = self.critic(state, action)
        critic_loss_1 = F.mse_loss(current_q1, target_value)
        critic_loss_2 = F.mse_loss(current_q2, target_value)

        critic_loss = critic_loss_1 + critic_loss_2

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm = 0.5)
        self.critic_optimizer.step()
        self.critic_scheduler.step()

        # Actor update: minimise alpha * log_prob - Q for freshly sampled actions.
        new_action, new_log_prob = self.actor(state)
        old_log_prob = new_log_prob.detach()

        q_1, q_2 = self.critic(state, new_action)
        q_pi = ((0.75 * torch.min(q_1, q_2)) + (0.25 * torch.max(q_1, q_2)))

        actor_loss = ((self.compute_alpha() * new_log_prob - q_pi)).mean()

        # This backward pass also leaves gradients on the critic parameters;
        # they are cleared by critic_optimizer.zero_grad() on the next call.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm = 0.5)
        self.actor_optimizer.step()
        self.actor_scheduler.step()

        self.update_target()

        # Temperature update, using the detached log-probabilities from above.
        alpha_loss = - (self.log_alpha * (old_log_prob + self.target_entropy)).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm = 1.0)
        self.alpha_optimizer.step()

        return actor_loss.item(), critic_loss.item(), self.alpha.item()

#### **Tensor Conversion**

In [None]:
def safe_tensor(x):
    """Return ``x`` as a detached float32 tensor on the module-level ``device``.

    Non-tensor inputs (scalars, lists, NumPy arrays) are copied into a new
    tensor. Tensor inputs are detached from their autograd graph and moved or
    cast only when needed. Without the detach, a tensor produced by a network
    would keep its graph alive for as long as it is stored.

    Args:
        x: tensor or anything accepted by ``torch.tensor``.

    Returns:
        A float32 tensor on ``device`` with ``requires_grad == False``.
    """

    if torch.is_tensor(x):
        return x.detach().to(device = device, dtype = torch.float32)

    return torch.tensor(x, dtype = torch.float32, device = device)

### **Replay Buffer**

In [None]:
class SAC_Buffer:
    """Fixed-capacity ring buffer of ``(state, action, reward, next_state, done)`` tuples.

    Storage is a plain list: it grows by appending until ``capacity`` is
    reached, after which ``ptr`` marks the next slot to overwrite. A list is
    used because sampling indexes at random positions, which is O(1) on a list.

    Args:
        capacity: maximum number of stored transitions.
        batch_size: default number of transitions returned by ``sample``.
    """

    def __init__(self, capacity, batch_size):

        self.capacity = capacity
        self.batch_size = batch_size
        self.sac_buffer = []
        self.ptr = 0        # next slot to overwrite once the buffer is full

    def add(self, state,  action, reward, next_state, done):
        """Store one transition, overwriting the oldest slot when full."""

        # Every field is converted with safe_tensor before storage.
        experience = tuple(
            safe_tensor(field)
            for field in (state, action, reward, next_state, done)
        )

        if len(self.sac_buffer) < self.capacity:
            self.sac_buffer.append(experience)
        else:
            self.sac_buffer[self.ptr] = experience
            self.ptr = (1 + self.ptr) % self.capacity

    def __len__(self):
        return len(self.sac_buffer)

    def sample(self, batch_size = None):
        """Draw a random batch (with replacement) and stack each field.

        Args:
            batch_size: number of transitions; defaults to the ``batch_size``
                given at construction.

        Returns:
            ``(states, actions, rewards, next_states, dones)``, each stacked
            along a new leading axis and moved to ``device``.

        Raises:
            ValueError: if the buffer is empty.
        """

        if batch_size is None:
            batch_size = self.batch_size

        if not self.sac_buffer:
            raise ValueError('cannot sample from an empty replay buffer')

        indices = np.random.choice(len(self.sac_buffer), batch_size)

        # Transpose the list of transitions into one tuple per field.
        fields = zip(*[self.sac_buffer[ind] for ind in indices])

        states, actions, rewards, next_states, dones = (
            torch.stack(field).to(device) for field in fields
        )

        return (states, actions, rewards, next_states, dones)

#### **Set up**

In [None]:
# Replay memory: room for 500_000 transitions, constructed with a batch size of 512.
sac_buffer = SAC_Buffer(500_000, 512)

### **Hyper Params**

In [None]:
# Hyper params

gamma = 0.997       # discount factor passed to the agent
max_iter = 10_000   # used as T_max for both cosine LR schedules below

# Actor params: AdamW with a cosine-annealed learning rate

actor_optimizer = optim.AdamW(
    actor_network.parameters(),
    lr = 3e-4,
    weight_decay = 0.001,
)
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(actor_optimizer, T_max = max_iter)

# Critic params: same recipe with a smaller base learning rate

critic_optimizer = optim.AdamW(
    critic_network.parameters(),
    lr = 1e-4,
    weight_decay = 0.001,
)
critic_scheduler = optim.lr_scheduler.CosineAnnealingLR(critic_optimizer, T_max = max_iter)


# SAC agent wiring the networks, optimizers and schedulers together
agent = SAC_Agent(
    actor_network,
    critic_network,
    target_critic,
    actor_optimizer,
    actor_scheduler,
    critic_optimizer,
    critic_scheduler,
    gamma,
)