In [2]:
import gymnasium
import numpy as np
from gymnasium import spaces
import matplotlib.pyplot as plt
from utils import visualise_pricing_strategy, visualise_episode_rewards,visualise_demand_data, external_demand_function
from Pricing_Environment import demand_calculator, action_strategy,pricing_env

import random
# Fix the stdlib and NumPy random generators so repeated runs draw the same numbers.
np.random.seed(42)
random.seed(42)

# Price bounds and starting demand level for the product being priced.
product_config = {"min_price": 10, "max_price": 100, "initial_demand": 0.5}

# Settings handed to the demand calculator: per-price-band demand probabilities,
# the overall low/high bounds, and whether seasonality is switched on.
demand_calculator_config = {
    "price_probability_ranges": {
        (0, 50): 0.8,    # 80% demand probability for prices between $0 and $50
        (51, 100): 0.6,  # 60% demand probability for prices between $51 and $100
        # Add more ranges and probabilities as needed.
        # NOTE(review): no band covers prices strictly between 50 and 51 -- confirm
        # how the demand calculator treats band edges when prices are continuous.
    },
    "low": 0,
    "high": 100,
    "seasonality": True,
}

# Settings handed to the action strategy. Action ids 0..4 run from the largest
# price decrease to the largest price increase.
action_strategy_config = {
    # Probability attached to each action id.
    "action_probabilities": {
        0: 0.1,  # decrease price significantly
        1: 0.2,  # decrease price slightly
        2: 0.4,  # keep price
        3: 0.2,  # increase price slightly
        4: 0.1,  # increase price significantly
    },
    # Price delta applied for each action id.
    "price_change_map": {
        0: -10,  # decrease significantly
        1: -5,   # decrease slightly
        2: 0,    # keep price
        3: 5,    # increase slightly
        4: 10,   # increase significantly
    },
}



In [3]:
import torch

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch too (the stdlib and NumPy generators are seeded separately) so
# that network initialisation and policy sampling are repeatable between runs.
torch.manual_seed(42)


import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal


In [4]:


class Actor(nn.Module):
    """Gaussian policy network with tanh-squashed, scaled actions.

    Two 256-unit ReLU layers feed two linear heads: one for the mean and one
    for the log standard deviation of a diagonal Normal distribution.

    Args:
        state_dim: Number of input features per state row.
        action_dim: Number of action components produced per state row.
        max_action: Scale applied after ``tanh``; sampled actions lie in
            ``[-max_action, max_action]``.
    """

    # Bounds applied to the predicted log-std before exponentiation. Without
    # them a very negative output underflows to std == 0 (which
    # torch.distributions.Normal rejects) and a very positive one overflows
    # to inf.
    LOG_STD_MIN = -20.0
    LOG_STD_MAX = 2.0

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mean = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, state):
        """Return ``(mean, std)`` of the pre-squash Normal for each state row.

        ``std`` is always finite and strictly positive because the log-std
        head is clamped to ``[LOG_STD_MIN, LOG_STD_MAX]`` before ``exp``.
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), self.LOG_STD_MIN, self.LOG_STD_MAX)
        std = torch.exp(log_std)  # standard deviation must be positive
        return mean, std

    def sample(self, state):
        """Draw one reparameterised action per state row.

        Returns:
            Tensor shaped like the mean head's output, equal to
            ``tanh(z) * max_action`` with ``z ~ Normal(mean, std)``.
        """
        mean, std = self(state)
        normal = torch.distributions.Normal(mean, std)
        # rsample keeps the draw differentiable w.r.t. mean and std
        # (z = mean + std * eps, eps ~ N(0, 1)).
        z = normal.rsample()
        return torch.tanh(z) * self.max_action

class Critic(nn.Module):
    """Q-network that scores a (state, action) pair with a single value.

    The state and action rows are concatenated and passed through a
    three-layer MLP with two 256-unit ReLU hidden layers.

    Args:
        state_dim: Number of features in each state row.
        action_dim: Number of components in each action row.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 256
        # Layers are created in input-to-output order and stored under
        # ``network`` so parameter names are network.0, network.2, network.4.
        layers = [
            nn.Linear(state_dim + action_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state, action):
        """Return one Q-value per row for the given states and actions."""
        joint_input = torch.cat((state, action), dim=1)
        return self.network(joint_input)

class SACAgent:
    """Actor-critic agent with twin critics, target critics and delayed actor updates.

    NOTE(review): despite the name, the update implemented here has no entropy
    term or temperature parameter; the actor is trained only to maximise
    ``critic1``. Confirm whether entropy regularisation is intended.

    Args:
        state_dim: Number of state features, forwarded to ``Actor`` / ``Critic``.
        action_dim: Number of action components, forwarded to ``Actor`` / ``Critic``.
        max_action: Action scale forwarded to ``Actor``.
        device: ``torch.device`` on which all networks and batches are placed.
        lr: Learning rate shared by the three Adam optimisers.
        discount: Discount factor applied to the bootstrapped target.
        tau: Interpolation factor used for the soft target-network update.
        policy_delay: The actor and target networks are updated on every
            ``policy_delay``-th iteration of ``train``.
    """

    def __init__(self, state_dim, action_dim, max_action, device,
                 lr=3e-4, discount=0.99, tau=0.005, policy_delay=2):
        self.device = device
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.critic1 = Critic(state_dim, action_dim).to(device)
        self.critic2 = Critic(state_dim, action_dim).to(device)
        self.critic1_target = Critic(state_dim, action_dim).to(device)
        self.critic2_target = Critic(state_dim, action_dim).to(device)
        # Target critics start as exact copies of the online critics.
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=lr)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=lr)
        self.discount = discount
        self.tau = tau
        self.policy_delay = policy_delay

    def _to_tensor(self, values):
        """Convert array-like ``values`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(values, dtype=torch.float32, device=self.device)

    def _soft_update(self, source, target):
        """Move ``target`` parameters towards ``source`` by a factor of ``tau``."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def select_action(self, state):
        """Sample one action for a single state.

        Args:
            state: Array-like holding one state; it is flattened to a single row.

        Returns:
            1-D NumPy array containing the sampled action.
        """
        state = self._to_tensor(state).reshape(1, -1)
        # No gradient is needed when acting, so skip building the autograd graph.
        with torch.no_grad():
            action = self.actor.sample(state)
        return action.cpu().numpy().flatten()

    def train(self, replay_buffer, batch_size=256, iterations=None):
        """Run gradient updates on mini-batches drawn from ``replay_buffer``.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, done)`` array-likes.
            batch_size: Number of transitions requested per update.
            iterations: Number of update steps to run. ``None`` keeps the
                historical behaviour of running ``batch_size`` steps; pass an
                explicit value to decouple step count from batch size.
        """
        if iterations is None:
            iterations = batch_size

        for it in range(iterations):
            # Sample a batch of transitions from the replay buffer.
            state, action, next_state, reward, done = replay_buffer.sample(batch_size)
            state = self._to_tensor(state)
            action = self._to_tensor(action)
            next_state = self._to_tensor(next_state)
            # Force column vectors: a 1-D reward/done would broadcast against the
            # (batch, 1) critic output and silently produce a (batch, batch) target.
            reward = self._to_tensor(reward).reshape(-1, 1)
            done = self._to_tensor(done).reshape(-1, 1)

            # Bootstrapped target from the smaller of the two target critics.
            with torch.no_grad():
                next_action = self.actor.sample(next_state)
                target_Q1 = self.critic1_target(next_state, next_action)
                target_Q2 = self.critic2_target(next_state, next_action)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = reward + ((1 - done) * self.discount * target_Q)

            # Current Q estimates and the summed critic loss.
            current_Q1 = self.critic1(state, action)
            current_Q2 = self.critic2(state, action)
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

            # Optimise both critics from the shared loss.
            self.critic1_optimizer.zero_grad()
            self.critic2_optimizer.zero_grad()
            critic_loss.backward()
            self.critic1_optimizer.step()
            self.critic2_optimizer.step()

            # Delayed policy and target updates.
            if it % self.policy_delay == 0:
                actor_loss = -self.critic1(state, self.actor.sample(state)).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                self._soft_update(self.critic1, self.critic1_target)
                self._soft_update(self.critic2, self.critic2_target)


In [5]:
# Build the pricing environment with is_continuous=True and text rendering,
# wiring in the three configuration dicts defined earlier in the notebook.
env = pricing_env.PricingEnvironment(
    render_mode="text",
    is_continuous=True,
    product_config=product_config,
    demand_calculator_config=demand_calculator_config,
    action_strategy_config=action_strategy_config,
)


In [6]:

# Derive the network dimensions from the environment's observation and action spaces.
state_dim = env.observation_space.shape[0]
if env.is_continuous:
    # Continuous case: width and scale are read from the action space itself.
    action_dim = env.action_space.shape[0]
    max_action = env.action_space.high[0]
else:
    # Discrete case: one entry per action, with a unit scale.
    action_dim = env.action_space.n
    max_action = 1


In [7]:
# Build the agent from the dimensions derived from the environment instead of
# hard-coded values, and keep a reference so later cells can use it.
agent = SACAgent(
    state_dim=state_dim,
    action_dim=action_dim,
    max_action=float(max_action),  # plain float rather than a NumPy scalar
    device=device,
)
agent

: 