In [5]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
import pandas as pd
import torch  # Add this import

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:
class StockAllocationEnv(gym.Env):
    def __init__(self, location_data, article_to_idx, max_steps=50):
        super(StockAllocationEnv, self).__init__()
        self.location_data = location_data
        self.article_to_idx = article_to_idx
        self.idx_to_article = {v: k for k, v in article_to_idx.items()}
        self.num_articles = len(set(location_data["group_article"]))
        
        # Create location ID to index mapping
        self.locations = sorted(set(location_data["location_id"]))
        self.location_to_idx = {loc: idx for idx, loc in enumerate(self.locations)}
        self.idx_to_location = {idx: loc for loc, idx in self.location_to_idx.items()}
        self.num_locations = len(self.locations)
        
        self.max_steps = max_steps

        # Create index mapping for OTS classes
        unique_ots_classes = sorted(set(location_data["ots_main_class"]))
        self.ots_class_to_idx = {ots: idx for idx, ots in enumerate(unique_ots_classes)}
        self.idx_to_ots_class = {v: k for k, v in self.ots_class_to_idx.items()}

        # Extract unique OTS classes and their open_to_ship values per location
        self.ots_classes = {}
        for _, row in location_data.iterrows():
            ots_class = row["ots_main_class"]
            location = row["location_id"]
            key = (self.ots_class_to_idx[ots_class], location)
            if key not in self.ots_classes:
                self.ots_classes[key] = {
                    "open_to_ship": row["open_to_ship"],
                    "remaining_open_to_ship": row["open_to_ship"],
                    "articles": []
                }
            self.ots_classes[key]["articles"].append(self.article_to_idx[row["group_article"]])

        # Extract per-article-location info
        self.article_location_info = {}
        for _, row in location_data.iterrows():
            article_idx = self.article_to_idx[row["group_article"]]
            location = row["location_id"]
            if article_idx not in self.article_location_info:
                self.article_location_info[article_idx] = {
                    "allocatable_stock": row["allocatable_stock"],
                    "remaining_stock": row["allocatable_stock"],
                    "locations": {}
                }
            
            self.article_location_info[article_idx]["locations"][location] = {
                "article_price": row["article_price"],
                "predicted_sales": row["predicted_sales"],
                "predicted_sales_upper_bound": row["predicted_sales_upper_bound"],
                "ots_main_class": self.ots_class_to_idx[row["ots_main_class"]],
                "min_allocation": row["min_allocation"],
                "max_allocation": row["max_allocation"],
                "allocated": 0
            }

        # Define observation space
        self.observation_space = spaces.Dict({
            # Article-location features
            "article_location_features": spaces.Box(
                low=-np.inf,
                high=np.inf,
                shape=(self.num_articles, self.num_locations, 6),  # [remaining_stock, price, predicted_sales, min_alloc, max_alloc, allocated]
                dtype=np.float32
            ),
            # OTS class features per location
            "ots_features": spaces.Box(
                low=0,
                high=np.inf,
                shape=(len(unique_ots_classes), self.num_locations, 2),  # [total_open_to_ship, remaining_open_to_ship]
                dtype=np.float32
            ),
            # Current phase, article, and location
            "current_phase": spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
            "current_article": spaces.Box(low=-1, high=self.num_articles, shape=(1,), dtype=np.float32),
            "current_location": spaces.Box(low=-1, high=self.num_locations, shape=(1,), dtype=np.float32)
        })

        # Action space for both article/location selection and quantity allocation
        self.action_space = spaces.Box(
            low=-1,
            high=1,
            shape=(1,),
            dtype=np.float32
        )

        # Track available article-location combinations
        self.available_combinations = set()
        for article_idx in range(self.num_articles):
            for location_idx in range(self.num_locations):
                self.available_combinations.add((article_idx, location_idx))
        
        self.initial_combinations = self.available_combinations.copy()
        self.reset()

    def _get_observation(self):
        # Prepare article-location features
        article_location_features = np.zeros((self.num_articles, self.num_locations, 6), dtype=np.float32)
        for art_idx, art_info in self.article_location_info.items():
            for loc, loc_info in art_info["locations"].items():
                loc_idx = self.location_to_idx[loc]
                article_location_features[art_idx, loc_idx] = [
                    art_info["remaining_stock"],
                    loc_info["article_price"],
                    loc_info["predicted_sales"],
                    loc_info["min_allocation"],
                    loc_info["max_allocation"],
                    loc_info["allocated"]
                ]

        # Prepare OTS features
        ots_features = np.zeros((len(self.ots_class_to_idx), self.num_locations, 2), dtype=np.float32)
        for (ots_class_idx, location), ots_info in self.ots_classes.items():
            loc_idx = self.location_to_idx[location]
            ots_features[ots_class_idx, loc_idx] = [
                ots_info["open_to_ship"],
                ots_info["remaining_open_to_ship"]
            ]
            
        # Add mask for available combinations to the observation
        available_mask = np.zeros((self.num_articles, self.num_locations), dtype=np.float32)
        for article_idx, location_idx in self.available_combinations:
            available_mask[article_idx, location_idx] = 1.0

        return {
            "article_location_features": article_location_features,
            "ots_features": ots_features,
            "current_phase": np.array([self.phase], dtype=np.float32),
            "current_article": np.array([self.current_article], dtype=np.float32),
            "current_location": np.array([self.current_location], dtype=np.float32),
            "available_mask": available_mask  # Added mask for available combinations
        }

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.current_step = 0
        self.current_article = -1
        self.current_location = -1
        
        # Reset article stocks
        for art_info in self.article_location_info.values():
            art_info["remaining_stock"] = art_info["allocatable_stock"]
            for loc_info in art_info["locations"].values():
                loc_info["allocated"] = 0
        
        # Reset OTS class remaining open_to_ship
        for ots_info in self.ots_classes.values():
            ots_info["remaining_open_to_ship"] = ots_info["open_to_ship"]
            
        # Reset available combinations
        self.available_combinations = self.initial_combinations.copy()
        
        self.phase = 0  # 0: select article, 1: select location, 2: allocate quantity
        return self._get_observation(), {}

    def step(self, action):
        reward = 0
        done = False
        info = {}
        
        if self.phase == 0:  # Article selection phase
            article_idx = int((action + 1) * (self.num_articles / 2))
            article_idx = np.clip(article_idx, 0, self.num_articles - 1)
            
            # Check if there are any valid locations for this article
            valid_locations = [loc_idx for loc_idx in range(self.num_locations) 
                             if (article_idx, loc_idx) in self.available_combinations]
            
            if not valid_locations:
                return self._get_observation(), 0, False, False, info  # Skip invalid article
            
            self.current_article = article_idx
            self.phase = 1
            return self._get_observation(), 0, False, False, info

        elif self.phase == 1:  # Location selection phase
            location_idx = int((action + 1) * (self.num_locations / 2))
            location_idx = np.clip(location_idx, 0, self.num_locations - 1)
            
            # Check if this combination is valid
            if not self._is_combination_valid(self.current_article, location_idx):
                self.phase = 0  # Go back to article selection
                self.current_article = -1
                return self._get_observation(), 0, False, False, info
            
            self.current_location = location_idx
            self.phase = 2
            return self._get_observation(), 0, False, False, info

        else:  # Quantity allocation phase
            art_info = self.article_location_info[self.current_article]
            location = self.idx_to_location[self.current_location]
            loc_info = art_info["locations"][location]
            ots_key = (loc_info["ots_main_class"], location)
            ots_info = self.ots_classes[ots_key]

            # Calculate maximum possible allocation considering all constraints
            max_possible = min(
                loc_info["max_allocation"],
                art_info["remaining_stock"],
                ots_info["remaining_open_to_ship"]
            )
            min_possible = loc_info["min_allocation"]
            
            # Map action to quantity
            quantity = ((action + 1) / 2) * (max_possible - min_possible) + min_possible
            quantity = float(np.clip(quantity, min_possible, max_possible))

            # Update states
            art_info["remaining_stock"] -= quantity
            loc_info["allocated"] += quantity
            ots_info["remaining_open_to_ship"] -= quantity

            # Remove this combination from available ones
            self.available_combinations.remove((self.current_article, self.current_location))
            
            # Update available combinations based on new constraints
            self._remove_invalid_combinations()

            # Calculate reward based on predicted sales and price
            effective_sold = min(quantity, loc_info["predicted_sales"])
            reward = effective_sold * loc_info["article_price"]

            # Check done conditions
            self.current_step += 1
            done = (
                self.current_step >= self.max_steps or
                len(self.available_combinations) == 0 or  # No more valid combinations
                all(ots_info["remaining_open_to_ship"] <= 0 for ots_info in self.ots_classes.values()) or
                all(art_info["remaining_stock"] <= 0 for art_info in self.article_location_info.values())
            )

            if not done:
                self.phase = 0
                self.current_article = -1
                self.current_location = -1

            return self._get_observation(), reward, done, done, info
        
    def _is_combination_valid(self, article_idx, location_idx):
        """Check if an article-location combination is still valid for allocation"""
        if (article_idx, location_idx) not in self.available_combinations:
            return False
            
        art_info = self.article_location_info[article_idx]
        location = self.idx_to_location[location_idx]
        loc_info = art_info["locations"][location]
        ots_key = (loc_info["ots_main_class"], location)
        ots_info = self.ots_classes[ots_key]

        # Check if there's enough stock and OTS capacity for at least minimum allocation
        min_alloc = loc_info["min_allocation"]
        return (art_info["remaining_stock"] >= min_alloc and 
                ots_info["remaining_open_to_ship"] >= min_alloc)

    def _remove_invalid_combinations(self):
        """Remove combinations that are no longer valid"""
        invalid_combinations = set()
        for article_idx, location_idx in self.available_combinations:
            if not self._is_combination_valid(article_idx, location_idx):
                invalid_combinations.add((article_idx, location_idx))
        
        self.available_combinations -= invalid_combinations

In [8]:
# Read forecast data
forecast_data = pd.read_csv('./forecasts.csv')

# filter plant == 9540   and remove columns plant 
forecast_data = forecast_data[forecast_data['plant'] == 9540]
# remove open_to_shiop <5
forecast_data = forecast_data.drop(columns=['plant'])

# Prepare training data
train_data = {
    "group_article": [],
    "location_id": [],
    "allocatable_stock": [],
    "article_price": [],
    "predicted_sales": [],
    "predicted_sales_upper_bound": [],
    "ots_main_class": [],
    "open_to_ship": [],
    "min_allocation": [],
    "max_allocation": []
}

# Process forecast data into required format
for _, row in forecast_data.iterrows():
    train_data["group_article"].append(row["group_article"])
    train_data["location_id"].append(row["location_id"]) 
    train_data["allocatable_stock"].append(row["allocatable_stock"])
    train_data["article_price"].append(row["article_price"])
    train_data["predicted_sales"].append(row["predicted_sales"])
    train_data["predicted_sales_upper_bound"].append(row["predicted_sales_upper_bound"]) # 20% buffer
    train_data["ots_main_class"].append(row["ots_main_class"])
    train_data["open_to_ship"].append(row["open_to_ship"])
    train_data["min_allocation"].append(row["min_allocation"])
    train_data["max_allocation"].append(500.0)

# Create environments
train_df = pd.DataFrame(train_data)

# check the unique locations count for each group_article
# get unique locations to list
unique_locations = train_df['location_id'].unique().tolist()

# Create new rows for missing location combinations
new_rows = []
for article in train_df['group_article'].unique():
    article_df = train_df[train_df['group_article'] == article]
    missing_locations = set(unique_locations) - set(article_df['location_id'])
    
    for loc in missing_locations:
        # Copy the first row for this article as template
        new_row = article_df.iloc[0].copy()
        new_row['location_id'] = loc
        # Set default values for the new location
        new_row['predicted_sales'] = 0
        new_row['predicted_sales_upper_bound'] = 0
        new_row['open_to_ship'] = 0
        new_rows.append(new_row)

# Append new rows to train_df
if new_rows:
    train_df = pd.concat([train_df, pd.DataFrame(new_rows)], ignore_index=True)



In [9]:

# Create unique article mapping
unique_articles = sorted(set(train_data["group_article"]))
article_to_idx = {art: idx for idx, art in enumerate(unique_articles)}

# Create environments
train_env = StockAllocationEnv(train_df, article_to_idx)

# Create and train model
model = PPO(
    "MultiInputPolicy",
    train_env,
    verbose=1,
    device=device,
    policy_kwargs=dict(
        net_arch=dict(
            pi=[256, 128, 64],
            vf=[256, 128, 64]
        )
    )
)

model.learn(total_timesteps=100)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 553      |
|    ep_rew_mean     | 5.81e+05 |
| time/              |          |
|    fps             | 24       |
|    iterations      | 1        |
|    time_elapsed    | 81       |
|    total_timesteps | 2048     |
---------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x2381410bc40>

In [None]:

# # Testing data
# test_data = {
#     "group_article": ["D", "D", "D", "E", "E", "E", "F", "F", "F"],  # Articles repeated for each location
#     "location_id": [1, 2, 3, 1, 2, 3, 1, 2, 3],  # Multiple locations
#     "allocatable_stock": [120, 120, 120, 180, 180, 180, 130, 130, 130],  # Same stock per article
#     "article_price": [12, 14, 13, 18, 20, 19, 13, 15, 14],  # Different prices per location
#     "predicted_sales": [90, 95, 92, 140, 145, 142, 110, 115, 112],  # Different predictions per location
#     "predicted_sales_upper_bound": [110, 115, 112, 170, 175, 172, 150, 155, 152],
#     "ots_main_class": [2, 2, 2, 2, 2, 2, 2, 2, 2],
#     "open_to_ship": [220, 250, 230, 220, 250, 230, 220, 250, 230],  # Different per location
#     "min_allocation": [12, 14, 13, 14, 16, 15, 11, 13, 12],
#     "max_allocation": [55, 60, 57, 95, 100, 97, 65, 70, 67],
# }
# test_df = pd.DataFrame(test_data)
# unique_articles = sorted(set(test_df["group_article"]))
# article_to_idx = {art: idx for idx, art in enumerate(unique_articles)}
# test_env = StockAllocationEnv(test_df, article_to_idx)
# # Test the model
# obs, info = test_env.reset()
# for _ in range(10):
#     action, _states = model.predict(obs)
#     action = action.item()
#     obs, rewards, dones, __, info = test_env.step(action)
#     if dones:
#         break