## SWARM Framework for Supply Chain Management with communication and 2 step decision making

In [1]:
import os
import re
import sys
import time
import numpy as np

from typing import List
from tqdm.notebook import tqdm

# CHANGED: Possibly import new config or reuse existing
from config import env_configs
np.random.seed(42)


Variable demand for t=0: 4


In [2]:
# Sanity check: sample every configured demand function once at t=0 and print the result.
for scenario_name, scenario_cfg in env_configs.items():
    demand_at_start = scenario_cfg['demand_fn'](0)
    print(f"{scenario_name} demand for t=0: {demand_at_start}")

constant_demand demand for t=0: 4
variable_demand demand for t=0: 3
larger_demand demand for t=0: 7
seasonal_demand demand for t=0: 4
normal_demand demand for t=0: 1
increasing_demand demand for t=0: 5
cyclical_demand demand for t=0: 5
demand_shock demand for t=0: 5
stochastic_demand demand for t=0: 0


## Initializing the environment

In [3]:
# Never commit a real key in this cell. A key that is already exported in the
# environment is kept as-is; the placeholder is only written when nothing is set
# and must be replaced (or OPENAI_API_KEY exported) before any API call can succeed.
os.environ.setdefault("OPENAI_API_KEY", "YOUR_OPENAI_KEY")

In [4]:
# setting up the communication with the API 
import requests

# HTTP headers for direct calls to the OpenAI REST API.
# The key is read from the environment at the moment this cell runs.
headers = {}
headers["Authorization"] = f"Bearer {os.getenv('OPENAI_API_KEY')}"
headers["Content-Type"] = "application/json"


In [5]:
# 1) Build environment
from env import  DecentralizedInventoryEnvWithComm
from config import env_configs
# Scenario to simulate; must be one of the keys of `env_configs`.
env_config_name = "seasonal_demand"
# Keyword arguments for the environment constructor for the chosen scenario.
env_config = env_configs[env_config_name]
print(env_config)  


# Build the environment from the scenario's keyword arguments.
im_env =  DecentralizedInventoryEnvWithComm(**env_config)
# Keep the raw config on the instance: run_simulation_swarm reads
# env.config["demand_fn"] for the stochastic scenario.
im_env.config = env_config
print(im_env)  

{'num_stages': 4, 'num_periods': 12, 'init_inventories': [12, 12, 12, 12], 'lead_times': [2, 2, 2, 2], 'demand_fn': <function <lambda> at 0x0000021EF952D8A0>, 'prod_capacities': [20, 20, 20, 20], 'sale_prices': [5, 5, 5, 5], 'order_costs': [5, 5, 5, 5], 'backlog_costs': [1, 1, 1, 1], 'holding_costs': [1, 1, 1, 1], 'stage_names': ['retailer', 'wholesaler', 'distributor', 'manufacturer'], 'comm_size': 4}
<DecentralizedInventoryEnvWithComm instance>


## Function definitions

In [6]:
def parse_observation(obs, stage_index, max_lead_time, num_stages, comm_size):
    """
    Render one stage's observation row as a single line of human-readable text.

    Row layout as consumed here:
      - indices 0..8, in order: production capacity, sale price, order cost,
        backlog cost, holding cost, lead time, inventory, backlog, upstream backlog
      - the next max_lead_time entries: past sales
      - the following max_lead_time entries: pending deliveries
      - everything past 9 + 2 * max_lead_time is the communication segment, which
        is sliced off and does not appear in the returned summary

    Only the most recent `lead time` entries of the sales and delivery windows are
    reported; a lead time of 0 yields empty lists for both.

    Parameters:
      obs (np.array): A single stage row (1D) or the full per-stage matrix (2D).
      stage_index (int): Row to read when obs is 2D; ignored for 1D input.
      max_lead_time (int): Width of the sales window and of the deliveries window.
      num_stages (int): Accepted for interface compatibility; not used here.
      comm_size (int): Accepted for interface compatibility; not used here.

    Returns:
      str: Comma-separated summary of the stage's base state.
    """
    # A 1D observation already is the stage row; otherwise pick the requested row.
    stage_row = obs if obs.ndim == 1 else obs[stage_index]

    # Cut away the trailing communication segment.
    state_width = 9 + 2 * max_lead_time
    base = stage_row[:state_width]

    # The nine scalar fields sit at fixed positions at the front of the row.
    (
        capacity,
        price,
        order_cost,
        backlog_cost,
        holding_cost,
        raw_lead_time,
        inventory,
        backlog,
        upstream_backlog,
    ) = (base[position] for position in range(9))
    lead_time = int(raw_lead_time)

    # Both windows are addressed from the end of the base state; only the last
    # `lead_time` entries of each window are relevant for this stage.
    sales, deliveries = [], []
    if lead_time > 0:
        sales_window = base[-2 * max_lead_time : -max_lead_time]
        delivery_window = base[-max_lead_time:]
        sales = sales_window[-lead_time:].tolist()
        deliveries = delivery_window[-lead_time:].tolist()

    summary_parts = [
        f"Production Capacity = {capacity}, Sale Price = {price}, Order Cost = {order_cost}, ",
        f"Backlog Cost = {backlog_cost}, Holding Cost = {holding_cost}, Stage Lead Time = {lead_time}, ",
        f"Inventory = {inventory}, Current Backlog (you owing to the downstream) = {backlog}, ",
        f"Upstream Backlog (your upstream owing to you) = {upstream_backlog}, ",
        f"Previous Sales (in the recent period(s), from old to new)={sales}, ",
        f"Arriving Deliveries (in this and the next period(s), from near to far) = {deliveries}",
    ]
    return "".join(summary_parts)


In [7]:
def get_demand_description(env_config_name):
    """
    Return the natural-language description of a demand scenario.

    Parameters:
      env_config_name: Scenario key, e.g. "seasonal_demand".

    Returns:
      str: The description for the scenario, or "Unknown demand configuration."
      when the key is not one of the scenarios listed below.
    """
    scenario_text = {}
    scenario_text["constant_demand"] = (
        "The expected demand at the retailer(stage 1) is a constant value of 4 units for all 12 periods."
    )
    scenario_text["variable_demand"] = (
        "The expected demand at the retailer (stage 1) is a discrete uniform distribution U{0, 4} for all 12 periods."
    )
    scenario_text["larger_demand"] = (
        "The expected demand at the retailer (stage 1) is a discrete uniform distribution U{0, 9} for all 12 periods."
    )
    scenario_text["seasonal_demand"] = (
        "For the first four periods, demand at the retailer(stage 1) follows a discrete uniform distribution over {0, 1, 2, 3, 4}, and for the following eight periods, it follows a discrete uniform distribution over {5, 6, 7, 8}."
    )
    scenario_text["normal_demand"] = (
        "The expected demand at the retailer (stage 1) is a normal distribution N(4, 2^2), "
        "truncated at 0, for all 12 periods."
    )
    scenario_text["increasing_demand"] = (
        "The expected demand at the retailer (stage 1) is a linearly increasing demand by starting with an initial value 5 and growing by 1 unit every period"
    )
    scenario_text["cyclical_demand"] = (
        " The expected demand at the retailer (stage 1) is computed as a seasonal sine wave—with a 12-round period, a 5-unit amplitude, and a 5-unit upward shift—whose value is rounded to yield an integer."
    )
    scenario_text["demand_shock"] = (
        "The expected demand at the retailer (stage 1) is normally 5, but it jumps by 8 units to 13 during periods 8 through 10, capturing a temporary demand shock. "
    )
    scenario_text["stochastic_demand"] = (
        "The expected demand at the retailer (stage 1) is an Integer-Valued Autoregressive INAR(1) process with a thinning probability of 0.5, meaning 50% of the previous period's demand carries over. New demand is added as Poisson arrivals with a mean of 2, ensuring the overall demand remains an integer count."
    )

    # Known scenario: hand back its text; anything else gets the fallback sentence.
    if env_config_name in scenario_text:
        return scenario_text[env_config_name]
    return "Unknown demand configuration."


In [8]:
def parse_order_and_comm(response_str: str, comm_size: int):
    """
    Extract the order quantity and the communication vector from an LLM reply.

    Recognised fragments (matched case-insensitively anywhere in the reply):
      [Order quantity: <non-negative integer>]
      [Comm vector: v1, v2, ...]

    The order quantity is parsed strictly; when the fragment is missing the order
    defaults to 0. The communication vector is parsed leniently with respect to
    formatting noise (a trailing comma or an extra pair of brackets/parentheses
    around the values), but any token that is not a number, or any value that is
    not finite once stored as float32 (nan, inf, overflow), replaces the whole
    vector with zeros so that malformed model output cannot leak into the
    environment state.

    Parameters:
      response_str (str): Raw text returned by the model.
      comm_size (int): Required length of the communication vector.

    Returns:
      tuple: (order_qty, comm_array) where order_qty is an int and comm_array is a
      float32 np.ndarray of shape (comm_size,); shorter vectors are zero-padded and
      longer ones truncated.
    """
    import logging
    logger = logging.getLogger(__name__)

    # 1) Order quantity: strict pattern, integer only.
    order_pattern = r"\[Order\s*quantity:\s*(\d+)\]"
    order_match = re.search(order_pattern, response_str, re.IGNORECASE)
    if order_match:
        order_qty = int(order_match.group(1))
    else:
        logger.warning("Order quantity not found. Defaulting to 0.")
        order_qty = 0

    # 2) Communication vector.
    comm_pattern = r"\[Comm\s*vector:\s*([^\]]+)\]"
    comm_match = re.search(comm_pattern, response_str, re.IGNORECASE)
    if not comm_match:
        logger.warning("Communication vector not found. Defaulting to zeros.")
        return order_qty, np.zeros(comm_size, dtype=np.float32)

    # Strip whitespace and stray brackets (e.g. "[Comm vector: [1, 2, 3, 4]]") from
    # every token and drop empty tokens left behind by a trailing comma.
    tokens = [token.strip(" \t\r\n[]()") for token in comm_match.group(1).split(",")]
    tokens = [token for token in tokens if token]
    try:
        values = [float(token) for token in tokens]
    except ValueError:
        logger.warning("Error parsing communication vector. Defaulting to zeros.")
        return order_qty, np.zeros(comm_size, dtype=np.float32)

    # Ensure the communication vector has exactly comm_size elements.
    if len(values) < comm_size:
        logger.warning(f"Communication vector has fewer than expected {comm_size} values; padding with zeros.")
        values += [0.0] * (comm_size - len(values))
    elif len(values) > comm_size:
        logger.warning(f"Communication vector has more than expected {comm_size} values; truncating.")
        values = values[:comm_size]

    comm_array = np.array(values, dtype=np.float32)
    # float("nan") / float("inf") parse without error, and large magnitudes overflow
    # to inf in float32, so the finiteness check is done on the final array.
    if not np.isfinite(comm_array).all():
        logger.warning("Communication vector contains non-finite values. Defaulting to zeros.")
        comm_array = np.zeros(comm_size, dtype=np.float32)

    return order_qty, comm_array


## Agent creation and simulation definition

In [9]:
from typing import List, Tuple, Dict, Any
from dataclasses import dataclass
from swarm import Agent, Swarm
import logging

# Configure logging: install a default root handler at INFO level and create the
# module-level logger used by the functions defined in this notebook.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Definition of a dataclass for agent actions
@dataclass
class Action:
    """A single stage's decision for one period: an order plus a message vector."""
    # Number of units the stage orders this period.
    order_quantity: float
    # Communication vector the stage emits alongside its order.
    comm_vector: np.ndarray

def create_agents(stage_names: List[str], demand_config_name: str = "", model: str = "o3-mini") -> List[Agent]:
    """
    Create one agent per supply-chain stage, each with role-specific instructions.

    Parameters:
      stage_names (List[str]): Stage names in order (stage 1 first).
      demand_config_name (str): Scenario key used to look up the demand description.
        When empty (the default), the notebook-level `env_config_name` is used, which
        keeps the original single-argument call working.
      model (str): Model name passed to every Agent.

    Returns:
      List[Agent]: One agent per entry of stage_names, in the same order.
    """
    num_stages = len(stage_names)
    # Fall back to the global scenario only when the caller did not name one.
    scenario_name = demand_config_name or env_config_name
    demand_description = get_demand_description(scenario_name)

    agents = []
    for stage, stage_name in enumerate(stage_names):
        instructions = (
            f"You represent stage {stage + 1} ('{stage_name}') in a {num_stages}-stage supply chain.\n\n"
            # Terminate the sentence and break the line so it no longer runs into
            # the "Your Objective:" heading.
            "You are an expert in inventory management and optimization.\n\n"
            "Your Objective:\n"
            "Your goal is to minimize the total cost (order, holding, backlog) by making ordering decisions and effectively communicating with neighboring stages.\n"
            f"{demand_description}\n\n"
            "Decision Process per Period:\n"
            "1. Initial Decision: Provide your initial order quantity and initial communication vector before receiving neighbor inputs.\n"
            "2. Updated Decision: After receiving upstream and downstream neighbor communications, revise and provide your final order quantity and final communication vector.\n\n"
            "Your communication vector has four dimensions:\n"
            "- v1: Current inventory level\n"
            "- v2: Order quantity placed this round\n"
            # Single newline: keeps v1-v4 together as one contiguous list.
            "- v3: Urgency level or risk indicator\n"
            "- v4: Current lead time\n\n"
        )
        agents.append(
            Agent(
                name=f"{stage_name.capitalize()}_Agent",
                instructions=instructions,
                model=model,
            )
        )

    return agents


In [10]:
# encapsulating API call within a class method, ensuring that the setup is handled in one place
class Swarm:
    """
    Minimal client that sends one Chat Completions request per `run` call.

    NOTE(review): a later cell executes `from swarm import Swarm`, which rebinds the
    name `Swarm`; after that cell runs, code that instantiates `Swarm()` gets the
    library client rather than this class.
    """

    API_URL = "https://api.openai.com/v1/chat/completions"
    # Upper bound for one HTTP round trip so a stalled connection cannot hang the
    # notebook indefinitely; generous because reasoning models can be slow.
    REQUEST_TIMEOUT_SECONDS = 300

    def __init__(self):
        # No per-instance state is required.
        pass

    def run(self, agent: Agent, messages: List[Dict[str, str]], temperature: float = 0.7) -> Any:
        """
        Send `messages` to the chat-completions endpoint and return the decoded JSON body.

        Parameters:
          agent (Agent): Agent whose string `instructions`, if present, are sent as a
            leading system message.
          messages (List[Dict[str, str]]): Chat messages with "role"/"content" keys.
          temperature (float): Sampling temperature; only sent for models that are not
            o-series reasoning models.

        Returns:
          Any: The parsed JSON response.

        Raises:
          requests.HTTPError: If the API answers with an error status.
        """
        model = os.getenv("OPENAI_MODEL", "o3-mini")

        # Forward the agent's role instructions; without them the model never sees
        # which stage it is supposed to represent.
        chat = list(messages)
        instructions = getattr(agent, "instructions", None)
        if isinstance(instructions, str) and instructions:
            chat.insert(0, {"role": "system", "content": instructions})

        data = {"model": model, "messages": chat}
        if re.match(r"o\d", model):
            # o-series reasoning models take `reasoning_effort`; per the OpenAI API
            # reference they do not accept a custom temperature (confirm for the
            # exact model in use).
            data["reasoning_effort"] = "high"
        else:
            data["temperature"] = temperature

        # Build the headers at call time so a key set after start-up is picked up.
        request_headers = {
            "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            self.API_URL,
            json=data,
            headers=request_headers,
            timeout=self.REQUEST_TIMEOUT_SECONDS,
        )
        # Surface HTTP errors instead of handing an error payload to the caller.
        response.raise_for_status()
        return response.json()

In [11]:
from swarm import Swarm


# Configure the main logger.
# force=True is needed here: logging.basicConfig() was already called in an earlier
# cell, and a repeated call is ignored once the root logger has a handler. Without
# it the timestamped format below never takes effect and records keep the default
# "INFO:__main__:..." layout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    force=True,
)

# Reducing verbosity for external modules
logging.getLogger("httpx").setLevel(logging.WARNING)

# Module-level logger used by run_simulation_swarm.
logger = logging.getLogger(__name__)

def _request_stage_decision(client: Any, agent: Agent, prompt: str, comm_size: int,
                            stage_name: str, phase: str) -> Action:
    """Send one prompt to a stage agent and parse the reply into an Action.

    `phase` is "preliminary" or "updated"; it only selects the log wording.
    A reply that cannot be parsed yields a zero order and a zero vector.
    """
    response = client.run(
        agent=agent,
        messages=[{"role": "user", "content": prompt}]
    )
    assistant_message = response.messages[-1]["content"]
    try:
        order_qty, comm_vec = parse_order_and_comm(assistant_message, comm_size)
    except Exception as e:
        failure_context = " on update" if phase == "updated" else ""
        logger.warning(f"Parsing failed for {stage_name}{failure_context}, defaulting to zero. Error: {e}")
        order_qty, comm_vec = 0, np.zeros(comm_size, dtype=np.float32)

    # Log a compact summary of the decision instead of the full model reply.
    logger.info(f"{stage_name} {phase} decision: [Order quantity: {order_qty}] [Comm vector: {comm_vec.tolist()}]")
    return Action(order_quantity=order_qty, comm_vector=comm_vec)


def _build_neighbor_messages(stage_labels: List[str], preliminary_actions: List[Action]):
    """Build, for every stage, the text it receives from its two neighbours.

    Returns (downstream_messages, upstream_messages); entry i of each list is the
    message addressed to stage i, or None when that neighbour does not exist.
    """
    num_stages = len(stage_labels)
    downstream_messages = [None] * num_stages
    upstream_messages = [None] * num_stages
    for stage_idx, action in enumerate(preliminary_actions):
        stage_name = stage_labels[stage_idx]
        order_qty = action.order_quantity
        comm_vec = action.comm_vector

        # Stage i reports its order and vector to stage i + 1 ...
        if stage_idx < num_stages - 1:
            downstream_messages[stage_idx + 1] = f"[From {stage_name}]: Ordered {order_qty} units. Comm vector: {comm_vec.tolist()}"
        # ... and notifies stage i - 1 of its planned order.
        if stage_idx > 0:
            upstream_messages[stage_idx - 1] = f"[Notification from {stage_name}]: Order {order_qty} units planned."
    return downstream_messages, upstream_messages


def run_simulation_swarm(env_config_name: str, env: Any, stage_agents: List[Agent]) -> float:
    """
    Run one episode in which every stage decides in two steps per period.

    Per period: (1) each stage is asked for a preliminary order and communication
    vector; (2) those preliminary decisions are turned into neighbour messages;
    (3) each stage is asked again, now seeing its own preliminary decision and the
    neighbour messages, and the resulting actions are passed to `env.step`.

    Parameters:
      env_config_name (str): Scenario key; "stochastic_demand" triggers generation of
        a new demand series.
      env (Any): Environment exposing reset/step plus num_periods, num_stages,
        stage_names, max_lead_time, comm_size and config.
      stage_agents (List[Agent]): One agent per stage, indexed like the stages.

    Returns:
      float: Sum of all stage rewards over the episode.
    """
    obs, _ = env.reset()

    if env_config_name == 'stochastic_demand':
        # NOTE(review): the new series is generated after env.reset(); confirm that
        # reset() does not already consume the previous series.
        logger.info("Generating a new stochastic demand series...")
        env.config["demand_fn"].generate_new_series()

    client = Swarm()
    done = False
    total_reward = 0.0
    period = 0

    logger.info(f"Simulation started: {env_config_name}\n{'=' * 80}")

    while not done and period < env.num_periods:
        logger.info(f"--- PERIOD {period + 1} ---")

        stage_indices = range(env.num_stages)
        stage_labels = [env.stage_names[stage_idx].capitalize() for stage_idx in stage_indices]
        # The observation does not change between the two decision steps, so the
        # textual state of every stage is rendered once per period.
        stage_states = [
            parse_observation(obs[f"stage_{stage_idx}"], stage_idx, env.max_lead_time, env.num_stages, env.comm_size)
            for stage_idx in stage_indices
        ]

        # Preliminary Decisions (sequential)
        preliminary_actions: List[Action] = []
        for stage_idx in stage_indices:
            stage_name = stage_labels[stage_idx]
            input_message = (
                f"--- {stage_name} (Stage {stage_idx + 1}) ---\n"
                f"Current period: {period + 1}\n"
                f"Overview: {stage_states[stage_idx]}\n\n"
                "Decision Task: Make an initial order decision.\n"
                "Format: [Order quantity: X] [Comm vector: v1, v2, v3, v4]\n"
            )
            preliminary_actions.append(
                _request_stage_decision(
                    client, stage_agents[stage_idx], input_message, env.comm_size, stage_name, "preliminary"
                )
            )

        # Build communication messages after preliminary decisions
        downstream_messages, upstream_messages = _build_neighbor_messages(stage_labels, preliminary_actions)

        # Updated Decisions (sequential)
        final_actions: List[Action] = []
        for stage_idx in stage_indices:
            stage_name = stage_labels[stage_idx]
            own_initial = preliminary_actions[stage_idx]

            input_message = (
                f"--- {stage_name} (Stage {stage_idx + 1}) UPDATE ---\n"
                f"Current period: {period + 1}\n"
                f"Local state: {stage_states[stage_idx]}\n\n"
                # Every request is a fresh single-message conversation, so the stage's
                # own first-step decision has to be restated for it to be adjustable.
                f"Your initial decision: [Order quantity: {own_initial.order_quantity}] "
                f"[Comm vector: {own_initial.comm_vector.tolist()}]\n"
            )
            downstream_msg = downstream_messages[stage_idx]
            upstream_msg = upstream_messages[stage_idx]
            if downstream_msg:
                input_message += f"Downstream message: {downstream_msg}\n"
            if upstream_msg:
                input_message += f"Upstream message: {upstream_msg}\n"
            input_message += (
                "\nUpdated Decision Task: Adjust your order decision.\n"
                "Format: [Order quantity: X] [Comm vector: v1, v2, v3, v4]\n"
            )

            final_actions.append(
                _request_stage_decision(
                    client, stage_agents[stage_idx], input_message, env.comm_size, stage_name, "updated"
                )
            )

        # Apply the final actions and advance the environment by one period.
        action_dict: Dict[str, Tuple[float, np.ndarray]] = {
            f"stage_{m}": (action.order_quantity, action.comm_vector)
            for m, action in enumerate(final_actions)
        }
        next_obs, reward, terminations, truncations, info = env.step(action_dict)
        step_reward = sum(reward.values())
        total_reward += step_reward
        done = terminations["__all__"]

        logger.info(f"Reward: {reward}, Total Reward: {total_reward}")
        logger.info("=" * 80)
        obs = next_obs
        period += 1

    logger.info("Simulation finished.")
    return total_reward


## Running the simulation

In [12]:

# Number of independent episodes to run. Each episode issues two LLM requests per
# stage per period, so this loop is slow and costs API calls.
num_iterations = 5  # or however many iterations you want
# Total reward of each completed episode, in run order.
all_rewards = []

for i in tqdm(range(num_iterations)):
    # NOTE(review): run_simulation_swarm also calls env.reset() itself, so the
    # environment is reset twice per iteration — confirm this is intended.
    im_env.reset()
    
    # Create a fresh set of stage agents for this episode.
    stage_agents = create_agents(env_config["stage_names"])
    
    # Run one full episode and get its summed reward.
    total_reward = run_simulation_swarm(env_config_name, im_env, stage_agents)
    
    # Collect the reward from this iteration.
    all_rewards.append(total_reward)
    print(f"Iteration {i+1} finished, total reward: {total_reward}")

# Print out summary statistics over the completed episodes
# (np.std uses its default, i.e. the population standard deviation).
print("Simulation finished.")
print("All rewards:", all_rewards)
print("Average reward:", np.mean(all_rewards))
print("Standard deviation reward:", np.std(all_rewards))


  0%|          | 0/5 [00:00<?, ?it/s]

INFO:__main__:Simulation started: seasonal_demand
INFO:__main__:--- PERIOD 1 ---
INFO:__main__:Retailer preliminary decision: [Order quantity: 0] [Comm vector: [12.0, 0.0, 0.0, 2.0]]
INFO:__main__:Wholesaler preliminary decision: [Order quantity: 0] [Comm vector: [12.0, 0.0, 0.0, 2.0]]
INFO:__main__:Distributor preliminary decision: [Order quantity: 4] [Comm vector: [12.0, 4.0, 0.0, 2.0]]
INFO:__main__:Manufacturer preliminary decision: [Order quantity: 8] [Comm vector: [12.0, 8.0, 1.0, 2.0]]
INFO:__main__:Retailer updated decision: [Order quantity: 0] [Comm vector: [12.0, 0.0, 0.0, 2.0]]
INFO:__main__:Wholesaler updated decision: [Order quantity: 0] [Comm vector: [12.0, 0.0, 0.0, 2.0]]
INFO:__main__:Distributor updated decision: [Order quantity: 0] [Comm vector: [12.0, 0.0, 0.0, 2.0]]


KeyboardInterrupt: 