<a href="https://colab.research.google.com/github/atharvanaik10/CryptoSAC/blob/main/soft_actor_critic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Soft Actor-Critic RL for discrete action spaces and continuous state spaces

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pickle
from google.colab import drive
from collections import namedtuple, deque
import random
from scipy import signal

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.modules.activation import ReLU

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Experience Replay Memory Buffer

In [None]:
Transition = namedtuple('Transition', 
                        ('state', 'action', 'reward', 'next_state', 'done'))

class ReplayMemory(object):
  def __init__(self, capacity):
    self.capacity = capacity
    self.memory = deque([], maxlen=capacity)

  def push(self, *args):
    self.memory.append(Transition(*args))

  def sample(self, batch_size):
    return random.sample(self.memory, batch_size)

  def __len__(self):
    return len(self.memory)

## Critic Network
With action space $A$, the critic is a function approximator that maps each state to a vector of Q-values, one per action:
$$ Q : \mathcal{S} \rightarrow \mathbb{R}^{|A|} $$

In [None]:
class Critic(nn.Module):
    """Q-network for discrete actions: maps a state to one Q-value per action.

    Args:
        state_dims: length of the input state vector.
        action_dims: number of discrete actions (size of the output layer).
        learning_rate: learning rate of this network's Adam optimizer.
        layer1_size: width of the first hidden layer.
        layer2_size: width of the second hidden layer.
    """

    def __init__(self, state_dims, action_dims, learning_rate,
                 layer1_size, layer2_size):
        super(Critic, self).__init__()

        self.state_dims = state_dims
        self.action_dims = action_dims

        # One output per action (Q : S -> R^|A|) so callers can gather Q(s, a)
        # by action index and take expectations over the policy distribution.
        self.model = nn.Sequential(nn.Linear(state_dims, layer1_size),
                                   nn.ReLU(),
                                   nn.Linear(layer1_size, layer2_size),
                                   nn.ReLU(),
                                   nn.Linear(layer2_size, action_dims))

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.to(device)

    def forward(self, state, action=None):
        """Return the Q-values of every action for ``state``.

        Args:
            state: float tensor whose last dimension is ``state_dims``.
            action: ignored. It is accepted only for backward compatibility
                with the former ``forward(state, action)`` signature, because
                all action values are returned at once.

        Returns:
            Tensor of shape ``(..., action_dims)``.
        """
        return self.model(state)

## Actor Network
With action space $A$, the actor is a function approximator that maps each state to a probability distribution over the actions:
$$ \pi : \mathcal{S} \rightarrow [0,1]^{|A|} $$

In [None]:
class Actor(nn.Module):
    """Policy network: maps a state to a probability for each discrete action.

    Args:
        state_dims: length of the input state vector.
        action_dims: number of discrete actions (size of the output).
        learning_rate: learning rate of this network's Adam optimizer.
        layer1_size: width of the first hidden layer.
        layer2_size: width of the second hidden layer.
    """

    def __init__(self, state_dims, action_dims, learning_rate,
                 layer1_size, layer2_size):
        super(Actor, self).__init__()

        self.state_dims = state_dims
        self.action_dims = action_dims

        layers = [
            nn.Linear(state_dims, layer1_size),
            nn.ReLU(),
            nn.Linear(layer1_size, layer2_size),
            nn.ReLU(),
            nn.Linear(layer2_size, action_dims),
            # Normalise the last dimension into a probability distribution.
            nn.Softmax(dim=-1),
        ]
        self.model = nn.Sequential(*layers)

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.to(device)

    def forward(self, state):
        """Return pi(a | s) for every action; the last dimension sums to 1."""
        return self.model(state)

## Agent


In [None]:
class Agent():
    """Discrete-action Soft Actor-Critic agent with twin critics and
    automatic entropy-temperature tuning.

    Args:
        env: environment handle; stored on the agent but not used by the
            methods of this class.
        alpha: initial entropy temperature.
        gamma: discount factor.
        learning_rate: Adam learning rate shared by the actor, the critics and
            the temperature. Note that the default ``10e-4`` equals ``1e-3``.
        tau: interpolation factor of the soft (Polyak) target update.
        buffer_capacity: maximum number of transitions kept for replay.
        batch_size: minibatch size; no learning happens until the buffer holds
            at least this many transitions.
        state_dims: length of the state vector (5 for the crypto environment).
        action_dims: number of discrete actions (3 for the crypto environment).
    """

    def __init__(self, env, alpha=1, gamma=0.99, learning_rate=10e-4,
                 tau=0.01, buffer_capacity=1000000, batch_size=100,
                 state_dims=5, action_dims=3):
        self.env = env
        self.state_dims = state_dims
        self.action_dims = action_dims

        self.alpha = alpha
        self.gamma = gamma
        self.tau = tau
        self.learning_rate = learning_rate

        self.buffer_capacity = buffer_capacity
        self.batch_size = batch_size

        self.replay_buffer = ReplayMemory(self.buffer_capacity)

        # Twin local critics (trained directly) and twin target critics
        # (slowly tracking the locals via soft_update).
        self.local_critic1 = Critic(self.state_dims, self.action_dims,
                                    self.learning_rate, 256, 256)
        self.local_critic2 = Critic(self.state_dims, self.action_dims,
                                    self.learning_rate, 256, 256)
        self.target_critic1 = Critic(self.state_dims, self.action_dims,
                                     self.learning_rate, 256, 256)
        self.target_critic2 = Critic(self.state_dims, self.action_dims,
                                     self.learning_rate, 256, 256)

        # tau = 1 is a hard copy, so the targets start identical to the locals.
        self.soft_update(1)

        self.actor = Actor(self.state_dims, self.action_dims,
                           self.learning_rate, 256, 256)

        # Target 98% of the maximum achievable entropy, log(|A|).
        self.target_entropy = float(0.98 * -np.log(1 / self.action_dims))

        # Learn log(alpha) rather than alpha so the temperature stays positive.
        self.log_alpha = torch.tensor(np.log(self.alpha), dtype=torch.float32,
                                      device=device, requires_grad=True)
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=self.learning_rate)

    # ----- Actor helpers -----
    def get_next_action(self, state, on_policy=False):
        """Choose an action index for a single state.

        Args:
            state: one state (tensor or array-like) of length ``state_dims``,
                optionally with a leading batch dimension of size 1.
            on_policy: if True, return the most probable action (greedy);
                otherwise sample an action from the policy distribution.

        Returns:
            An ``int`` in ``range(self.action_dims)``.
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=device)
        with torch.no_grad():
            policy, _ = self.get_action_probs(state)
        # .cpu() is required before .numpy() when running on the GPU.
        policy = policy.squeeze(0).cpu().numpy().astype(np.float64)
        if on_policy:
            return int(np.argmax(policy))
        # Renormalise in float64 so np.random.choice's sum-to-one check passes.
        policy /= policy.sum()
        return int(np.random.choice(self.action_dims, p=policy))

    def get_action_probs(self, state):
        """Return ``(pi(.|s), log pi(.|s))`` for a state tensor.

        Both tensors have the actor's output shape ``(..., action_dims)``.
        """
        action_probs = self.actor(state)
        # Add a tiny epsilon only where a probability is exactly zero, so that
        # log(0) yields a large negative number instead of -inf (which would
        # produce NaNs once multiplied by the zero probability).
        zero_mask = (action_probs == 0.0).float() * 1e-8
        log_action_probs = torch.log(action_probs + zero_mask)
        return action_probs, log_action_probs

    def get_actor_loss(self, state):
        """Compute the discrete-SAC policy loss for a batch of states.

        Returns:
            ``(actor_loss, expected_log_probs)``. ``expected_log_probs`` is
            ``sum_a pi(a|s) * log pi(a|s)`` per state, i.e. the negative policy
            entropy, and is the quantity the temperature loss needs.
        """
        action_probs, log_action_probs = self.get_action_probs(state)

        # The critics are optimised separately; keep them out of this graph.
        with torch.no_grad():
            q_locals = self.get_quality(state)

        # E_{a~pi}[alpha * log pi(a|s) - Q(s, a)]; only the entropy term is
        # scaled by alpha.
        actor_loss = (action_probs *
                      (self.alpha * log_action_probs - q_locals)).sum(dim=1).mean()
        expected_log_probs = (action_probs * log_action_probs).sum(dim=1)
        return actor_loss, expected_log_probs

    # ----- Temperature helper -----
    def get_log_alpha_loss(self, log_action_probs):
        """Temperature loss.

        Its gradient raises alpha when the policy entropy is below
        ``target_entropy`` and lowers alpha when it is above.

        Args:
            log_action_probs: per-state expected log-probability, as returned
                by ``get_actor_loss``.
        """
        return -(self.log_alpha *
                 (log_action_probs + self.target_entropy).detach()).mean()

    # ----- Critic helpers -----
    def get_quality(self, state):
        """Element-wise minimum of the two local critics' Q-values."""
        return torch.min(self.local_critic1(state), self.local_critic2(state))

    def get_target_quality(self, state):
        """Element-wise minimum of the two target critics' Q-values."""
        return torch.min(self.target_critic1(state), self.target_critic2(state))

    def get_critic_loss(self, state, action, reward, next_state, done):
        """MSE losses of both local critics against the soft Bellman target.

        Args:
            state, next_state: float tensors of shape ``(batch, state_dims)``.
            action: integer tensor of shape ``(batch,)``. It presumably holds
                action indices in ``range(action_dims)``, as returned by
                ``get_next_action``, not the environment's {-1, 0, 1} values.
                TODO confirm at the call site.
            reward: float tensor of shape ``(batch,)``.
            done: bool or 0/1 tensor of shape ``(batch,)``.

        Returns:
            ``(critic1_loss, critic2_loss)`` scalar tensors.
        """
        with torch.no_grad():
            next_probs, next_log_probs = self.get_action_probs(next_state)
            q_next_target = self.get_target_quality(next_state)
            # Soft state value V(s') = E_{a'~pi}[Q_target(s', a') - alpha * log pi(a'|s')]
            soft_state_values = (next_probs *
                                 (q_next_target - self.alpha * next_log_probs)).sum(dim=1)
            # 1 - done works for both bool and integer done flags (unlike ~done).
            not_done = 1.0 - done.float()
            q_next = reward + not_done * self.gamma * soft_state_values

        action_index = action.long().unsqueeze(-1)
        q_soft1 = self.local_critic1(state).gather(1, action_index).squeeze(-1)
        q_soft2 = self.local_critic2(state).gather(1, action_index).squeeze(-1)

        critic1_loss = F.mse_loss(q_soft1, q_next)
        critic2_loss = F.mse_loss(q_soft2, q_next)
        return critic1_loss, critic2_loss

    def soft_update(self, tau):
        """Move each target critic toward its local critic.

        Each target parameter becomes ``tau * local + (1 - tau) * target``.
        """
        pairs = ((self.target_critic1, self.local_critic1),
                 (self.target_critic2, self.local_critic2))
        for target_net, local_net in pairs:
            for target_param, local_param in zip(target_net.parameters(),
                                                 local_net.parameters()):
                target_param.data.copy_(tau * local_param.data +
                                        (1 - tau) * target_param.data)

    # ----- Training -----
    def train(self, state, action, reward, next_state, done):
        """Store one transition and, once enough data exists, run one SAC update.

        The update consists of a critic step, an actor step, a temperature
        step and a soft target update, all computed on a sampled minibatch.
        """
        self.replay_buffer.push(state, action, reward, next_state, done)
        if len(self.replay_buffer) < self.batch_size:
            return

        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Networks are float32 and live on `device`; numpy defaults to float64.
        state_tensor = torch.tensor(np.array(states), dtype=torch.float32,
                                    device=device)
        action_tensor = torch.tensor(np.array(actions), dtype=torch.long,
                                     device=device)
        reward_tensor = torch.tensor(np.array(rewards), dtype=torch.float32,
                                     device=device)
        next_state_tensor = torch.tensor(np.array(next_states),
                                         dtype=torch.float32, device=device)
        done_tensor = torch.tensor(np.array(dones), dtype=torch.bool,
                                   device=device)

        # Critic update
        self.local_critic1.optimizer.zero_grad()
        self.local_critic2.optimizer.zero_grad()
        critic1_loss, critic2_loss = self.get_critic_loss(state_tensor,
                                                          action_tensor,
                                                          reward_tensor,
                                                          next_state_tensor,
                                                          done_tensor)
        critic1_loss.backward()
        critic2_loss.backward()
        self.local_critic1.optimizer.step()
        self.local_critic2.optimizer.step()

        # Actor update
        self.actor.optimizer.zero_grad()
        actor_loss, expected_log_probs = self.get_actor_loss(state_tensor)
        actor_loss.backward()
        self.actor.optimizer.step()

        # Temperature update
        self.log_alpha_optimizer.zero_grad()
        log_alpha_loss = self.get_log_alpha_loss(expected_log_probs)
        log_alpha_loss.backward()
        self.log_alpha_optimizer.step()
        # Store a plain float so later losses do not backprop into log_alpha.
        self.alpha = self.log_alpha.exp().item()

        self.soft_update(self.tau)

## Environment Setup

We create a custom environment that fetches price data from the Coinbase Exchange API.

We define the state space as

$$ \mathcal{S} = \mathbb{R}^5 $$

Where state $s_t$ = `[current holding, value(t), value(t-1), value(t-2), value(t-3)]`

We define the action space as discrete
$$ \mathcal{A} = \{-1, 0, 1\} $$

Where $a_t = -1 \implies$ sell all, $a_t = 0 \implies$ hold all, $a_t = 1 \implies$ buy all.

In the future, this action space can be changed from discrete to the continuous interval $[-1, 1]$, allowing partial buys and sells for more precise trades.

