# In [2]:
import numpy as np
import torch
from networks import ActorNetwork, ValueNetwork, CriticNetwork
from memory_buffer import ReplyBuffer
from sac_agent import SAC_Agent

class InventoryAITraderSAC:
    """Market-making quoter whose risk-aversion ``gamma`` is chosen by a SAC agent.

    For each call to :meth:`get_quotes` the agent's actor is sampled on the
    scaled state ``[|q| / 100, rt]``; the sampled action is mapped to a
    positive ``gamma`` which is then plugged into reservation-price / spread
    formulas of the Avellaneda-Stoikov form to produce a bid and an ask quote.

    Attributes set by the constructor:
        sigma: volatility term used (squared) in the quote formulas.
        bid_k, ask_k: order-arrival decay parameters used in the spread terms.
        bid_a, ask_a, dt: stored as given; not read anywhere in this class.
        agent: the ``SAC_Agent`` instance providing ``pick_action``,
            ``remember``, ``learn`` and ``actor.normal_sample``.
    """

    def __init__(self, sigma, bid_k, bid_a, ask_k, ask_a, dt, memory_size=1000000):
        """Store the market parameters and build the SAC agent.

        Args:
            sigma: volatility used in the quote formulas.
            bid_k: decay parameter for the bid-side spread term.
            bid_a: stored on the instance; unused by this class.
            ask_k: decay parameter for the ask-side spread term.
            ask_a: stored on the instance; unused by this class.
            dt: stored on the instance; unused by this class.
            memory_size: replay-buffer capacity forwarded to ``SAC_Agent``.
        """
        self.sigma = sigma
        self.bid_k = bid_k
        self.bid_a = bid_a
        self.ask_k = ask_k
        self.ask_a = ask_a
        self.dt = dt

        # Create SAC agent: 2-dimensional state (scaled inventory, rt) and a
        # single action that is later mapped to gamma.
        self.agent = SAC_Agent(input_dims=[2], max_action=1, alpha=0.003, beta=0.003, gamma=0.99, n_actions=1,
                               memory_size=memory_size, tau=0.005, layer1=256, layer2=256, batch_size=256, reward_scale=2)

        # Set initial values for SAC agent
        # NOTE(review): the constructor above is given reward_scale=2 while this
        # sets ``scale`` to 1.0; which of the two the agent actually uses is not
        # visible from this file - confirm against SAC_Agent.
        self.agent.scale = 1.0  # Adjust reward scaling if needed

    def get_quotes(self, x, q, s, rt, train_mode=False):
        """Return ``(bid_quote, ask_quote)`` for the current market state.

        Args:
            x: accepted for interface compatibility; not used here.
            q: current inventory.
            s: current reference (mid) price.
            rt: time term multiplying ``gamma * sigma**2`` in the formulas
                (presumably remaining time - confirm with caller).
            train_mode: when true, the agent additionally picks an action,
                stores a transition and runs one learning step before quoting.

        Returns:
            Tuple ``(bid, ask)`` of float64 NumPy values with the shape of the
            action sampled from the actor.
        """
        cur_state = self.scale_state(q, rt)

        if train_mode:
            # Query the policy with the same scaled state that is written to
            # the replay buffer, so acting and learning see identical inputs.
            action = self.agent.pick_action(cur_state)
            # NOTE(review): the stored transition uses reward 0 and
            # next_state == state; the real reward is not available here.
            self.agent.remember(cur_state, action, 0, cur_state, False)
            self.agent.learn()

        # Sample the action used for quoting from the actor on the scaled state.
        action, _ = self.agent.actor.normal_sample(torch.FloatTensor(cur_state))
        # detach(): Tensor.numpy() refuses tensors that require grad.
        # float64: keep price arithmetic out of float32, which loses precision
        # at large price levels.
        action_np = np.asarray(action.detach().cpu().numpy(), dtype=np.float64)
        gamma = self.scale_gamma(action_np)

        # Pre-computed values
        two_div_gamma = 2 / gamma
        g_ss_rt = gamma * self.sigma**2 * rt

        # Calculate reservation price and spreads
        reservation_price = s - q * g_ss_rt
        spread_bid = (g_ss_rt + two_div_gamma * np.log(1 + gamma / self.bid_k)) / 2
        spread_ask = (g_ss_rt + two_div_gamma * np.log(1 + gamma / self.ask_k)) / 2

        return reservation_price - spread_bid, reservation_price + spread_ask  # Return bid quote, ask quote

    def scale_state(self, q, rt):
        """
        Scales state, assumes max inventory is 100

        Returns ``[abs(q) / 100, rt]``; the sign of the inventory is dropped.
        """
        return [abs(q) / 100, rt]

    def scale_gamma(self, action):
        """
        Scales action to gamma

        Maps ``action`` linearly via ``(action + 1) / 2`` (so -1 -> 0, 1 -> 1)
        and floors the result at 1e-5 so gamma stays strictly positive (it is
        used as a divisor). Works element-wise on arrays and preserves their
        shape; a scalar input yields a scalar.
        """
        return np.maximum(1e-5, (action + 1) / 2)


# Usage example:
# trader = InventoryAITraderSAC(sigma=..., bid_k=..., bid_a=..., ask_k=..., ask_a=..., dt=...)
# bid, ask = trader.get_quotes(x=..., q=..., s=..., rt=..., train_mode=True)
