# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from collections import deque
import random

class ReplayBuffer:
    """Bounded FIFO store of transitions for off-policy SAC updates.

    Holds at most ``capacity`` transitions; once full, appending a new
    transition drops the oldest one.
    """

    def __init__(self, capacity=100000):
        # A deque with ``maxlen`` handles oldest-first eviction on append.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        """Append a single ``(state, action, reward, next_state, done)`` tuple."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions at random, without replacement.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)`` of
            NumPy arrays, each with the batch along the first axis.

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                exceeds the number of stored transitions.
        """
        drawn = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into per-field columns and turn
        # each column into an array.
        states, actions, rewards, next_states, dones = map(np.array, zip(*drawn))
        return states, actions, rewards, next_states, dones

class SoftQNetwork(nn.Module):
    """Critic that maps a (state, action) pair to a scalar soft Q-value.

    Reference: SAC paper (Haarnoja et al., 2018), Section 4.2.
    Layout: three hidden layers of ``hidden_dim`` units (128 by default, as
    specified by the Dynamic Pricing paper) followed by a linear output.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # The first layer consumes the state and action concatenated together.
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        # Linear head producing one Q-value per input row.
        self.fc4 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return Q(state, action) with a trailing dimension of size 1.

        ``state`` and ``action`` are concatenated along ``dim=1``, so both
        must be 2-D with matching first dimension.
        """
        hidden = torch.cat((state, action), dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)

class PolicyNetwork(nn.Module):
    """Stochastic SAC actor: a Gaussian policy squashed through tanh.

    Reference: SAC paper Section 4.2, Equations 11-13.
    A three-layer ReLU trunk feeds two linear heads that produce the mean
    and the log standard deviation of the pre-squash Gaussian.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128, log_std_min=-20, log_std_max=2):
        super().__init__()
        # Shared feature trunk.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)

        # Output heads for the Gaussian parameters.
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

        # Bounds applied to the predicted log-std in ``forward``.
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def forward(self, state):
        """Return ``(mean, log_std)`` with ``log_std`` clamped to its bounds."""
        features = state
        for layer in (self.fc1, self.fc2, self.fc3):
            features = F.relu(layer(features))

        mu = self.mean(features)
        bounded_log_sigma = self.log_std(features).clamp(self.log_std_min, self.log_std_max)
        return mu, bounded_log_sigma

    def sample(self, state, epsilon=1e-6):
        """Draw a squashed action with the reparameterization trick.

        Reference: SAC paper Equation 11 and Appendix C.

        Returns:
            A tuple ``(action, log_prob, mean_action)`` where ``action`` is
            ``tanh`` of a reparameterized Gaussian sample, ``log_prob`` is its
            log-density summed over ``dim=1`` (kept as a size-1 dimension),
            and ``mean_action`` is ``tanh(mean)``.
        """
        mu, log_sigma = self.forward(state)
        gaussian = torch.distributions.Normal(mu, log_sigma.exp())

        # rsample keeps the draw differentiable w.r.t. mean and std.
        pre_squash = gaussian.rsample()
        squashed = torch.tanh(pre_squash)

        # Change-of-variables term for the tanh squashing (Appendix C, Eq 21);
        # epsilon keeps the log finite when |action| approaches 1.
        squash_correction = torch.log(1 - squashed.pow(2) + epsilon)
        per_dim_log_prob = gaussian.log_prob(pre_squash) - squash_correction
        total_log_prob = per_dim_log_prob.sum(1, keepdim=True)

        return squashed, total_log_prob, torch.tanh(mu)

class SACAgent:
    """Soft Actor-Critic agent for a pricing competition.

    Based on: Haarnoja et al. (2018) "Soft Actor-Critic".
    Hyperparameters from: Dynamic Pricing paper Section 4 & SAC paper Table 1.

    The agent owns a stochastic actor, two critics with their target copies,
    an optional learnable entropy coefficient, and a replay buffer.
    """

    def __init__(
        self,
        agent_id,
        state_dim,
        action_dim,
        hidden_dim=128,  # Dynamic Pricing paper: "128 nodes"
        actor_lr=3e-4,   # SAC paper Table 1
        critic_lr=3e-4,  # SAC paper Table 1
        alpha_lr=3e-4,   # For automatic entropy tuning
        gamma=0.99,      # SAC paper Table 1
        tau=0.005,       # SAC paper Table 1: "target smoothing coefficient"
        alpha=0.2,       # Entropy coefficient (used only when tuning is off)
        automatic_entropy_tuning=True,
        buffer_size=1000000,  # SAC paper: 10^6
        batch_size=256,       # SAC paper Table 1
        seed=None,
        price_min=0.0,
        price_max=2.0
    ):
        """Build the networks, optimizers and replay buffer.

        Args:
            agent_id: Identifier stored on the agent; not otherwise used here.
            state_dim: Size of the state vector fed to actor and critics.
            action_dim: Size of the action vector.
            hidden_dim: Width of every hidden layer.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for both critics.
            alpha_lr: Adam learning rate for ``log_alpha`` (tuning only).
            gamma: Discount factor of the Bellman backup.
            tau: Interpolation weight of the soft target update.
            alpha: Fixed entropy coefficient. It is only used when
                ``automatic_entropy_tuning`` is False; with tuning enabled
                ``log_alpha`` starts at 0, i.e. alpha starts at 1.
            automatic_entropy_tuning: Learn alpha towards a target entropy of
                ``-action_dim``.
            buffer_size: Replay buffer capacity.
            batch_size: Number of transitions per update.
            seed: If given, seeds the *global* torch, NumPy and ``random``
                generators before the networks are created.
            price_min: Price that the action value -1 maps to.
            price_max: Price that the action value +1 maps to.
        """
        self.agent_id = agent_id
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.price_min = price_min
        self.price_max = price_max
        self.automatic_entropy_tuning = automatic_entropy_tuning

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # Actor network (Policy π_φ)
        self.actor = PolicyNetwork(state_dim, action_dim, hidden_dim).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)

        # Two Q-networks (SAC uses two Q-functions to mitigate positive bias)
        # Source: SAC paper Section 4.2, following Fujimoto et al. (2018)
        self.critic_1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(self.device)
        self.critic_2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(self.device)

        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)

        # Target Q-networks start as exact copies of the online critics and
        # are only ever changed through ``_soft_update``.
        self.critic_1_target = SoftQNetwork(state_dim, action_dim, hidden_dim).to(self.device)
        self.critic_2_target = SoftQNetwork(state_dim, action_dim, hidden_dim).to(self.device)

        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())

        # Entropy coefficient (alpha) - can be learned automatically
        if automatic_entropy_tuning:
            # Target entropy is minus the action dimensionality.
            self.target_entropy = -float(action_dim)
            # Optimizing log(alpha) keeps alpha positive.
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
            self.alpha = self.log_alpha.exp().item()
        else:
            self.alpha = alpha

        # Replay buffer
        self.replay_buffer = ReplayBuffer(buffer_size)

    def _to_tensor(self, data):
        """Convert array-like ``data`` to a float32 tensor on the agent's device."""
        return torch.as_tensor(data, dtype=torch.float32, device=self.device)

    def select_action(self, state, evaluate=False):
        """Select an action for a single state using the current policy.

        Args:
            state: One state vector (array-like of length ``state_dim``).
            evaluate: If True, use the deterministic ``tanh(mean)`` action;
                otherwise sample from the stochastic policy.

        Returns:
            A tuple ``(scaled_price, action)``: ``action`` is the flattened
            policy output in [-1, 1] as a NumPy array, and ``scaled_price`` is
            its first component mapped linearly to [price_min, price_max].
        """
        state_t = self._to_tensor(state).unsqueeze(0)

        # Acting never needs gradients, so skip building an autograd graph.
        with torch.no_grad():
            if evaluate:
                # Read the mean directly instead of going through ``sample``,
                # so evaluation does not consume random numbers.
                mean, _ = self.actor(state_t)
                action_t = torch.tanh(mean)
            else:
                action_t, _, _ = self.actor.sample(state_t)

        # Convert from [-1, 1] to [price_min, price_max]
        action = action_t.cpu().numpy().flatten()
        scaled_price = self.price_min + (self.price_max - self.price_min) * (action[0] + 1) / 2

        return scaled_price, action

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer."""
        self.replay_buffer.push(state, action, reward, next_state, done)

    def update_parameters(self):
        """Run one SAC gradient step on critics, actor and (optionally) alpha.

        Source: SAC paper Algorithm 1.
        Does nothing until the replay buffer holds at least ``batch_size``
        transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        # Sample batch from replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        states = self._to_tensor(states)
        actions = self._to_tensor(actions)
        # Rewards and done flags get a trailing dimension so they line up with
        # the critics' (batch, 1) outputs.
        rewards = self._to_tensor(rewards).unsqueeze(1)
        next_states = self._to_tensor(next_states)
        dones = self._to_tensor(dones).unsqueeze(1)

        with torch.no_grad():
            # Sample actions from current policy for next states
            next_actions, next_log_probs, _ = self.actor.sample(next_states)

            # Soft state value: minimum of the two target critics plus the
            # entropy bonus -alpha * log pi.
            target_q1 = self.critic_1_target(next_states, next_actions)
            target_q2 = self.critic_2_target(next_states, next_actions)
            target_q = torch.min(target_q1, target_q2) - self.alpha * next_log_probs

            # Bellman backup: r + γ * (1 - done) * V(s')
            target_q = rewards + (1 - dones) * self.gamma * target_q

        # Update critics
        current_q1 = self.critic_1(states, actions)
        current_q2 = self.critic_2(states, actions)

        critic_1_loss = F.mse_loss(current_q1, target_q)
        critic_2_loss = F.mse_loss(current_q2, target_q)

        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()

        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # Update actor
        new_actions, log_probs, _ = self.actor.sample(states)
        q1_new = self.critic_1(states, new_actions)
        q2_new = self.critic_2(states, new_actions)
        q_new = torch.min(q1_new, q2_new)

        # Actor loss: maximize expected return while maximizing entropy
        # J_π = E[α * log π(a|s) - Q(s,a)]
        actor_loss = (self.alpha * log_probs - q_new).mean()

        # This backward pass also leaves gradients on the critic parameters;
        # they are cleared by the critics' zero_grad() before their next step.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update entropy coefficient (if automatic tuning is enabled)
        if self.automatic_entropy_tuning:
            # log_probs is detached so only log_alpha receives a gradient.
            alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            self.alpha = self.log_alpha.exp().item()

        # Soft update target networks
        self._soft_update(self.critic_1, self.critic_1_target)
        self._soft_update(self.critic_2, self.critic_2_target)

    def _soft_update(self, local_model, target_model):
        """Soft update: θ_target = τ*θ_local + (1 - τ)*θ_target"""
        with torch.no_grad():
            for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
                target_param.copy_(self.tau * local_param + (1.0 - self.tau) * target_param)