In [1]:
# Import relevant libraries
import numpy as np
import gymnasium as gym
from gymnasium import spaces

In [2]:
# Define the warehouse simulator
class WarehouseEnv(gym.Env):
    """Grid-world warehouse where an agent relocates items to keep popular ones near the exit.

    The warehouse is a ``grid_size x grid_size`` integer grid. A cell value of ``0`` means
    empty; values ``1..num_items`` are item IDs. Every item has a demand level, and the agent
    is rewarded for keeping high-demand items close to the exit at ``(0, 0)``.

    A single ``Discrete`` action encodes ``(item_id, target_row, target_col)`` as
    ``item_id * grid_size**2 + target_row * grid_size + target_col``. ``item_id == 0`` is an
    invalid (penalised) action.
    """

    def __init__(self, grid_size=5, num_items=5, max_steps=50,
                 item_demand=None, high_demand_threshold=7, goal_radius=2):
        """Create the environment and perform an initial reset.

        Args:
            grid_size: Side length of the square warehouse grid.
            num_items: Number of items placed on the grid (IDs ``1..num_items``).
            max_steps: Number of steps after which an episode is truncated.
            item_demand: Optional sequence of ``num_items + 1`` demand levels indexed by item
                ID (index 0 is unused). When omitted, levels are drawn from ``1..9`` with the
                global NumPy RNG. Pass the same values to several environments (e.g. train and
                eval) so that they share one demand profile.
            high_demand_threshold: Items with demand at least this value must be near the exit
                for the goal to be reached.
            goal_radius: Largest Manhattan distance from ``(0, 0)`` that counts as "near".

        Raises:
            ValueError: If the items cannot all fit on the grid, or ``item_demand`` has the
                wrong length.
        """
        super().__init__()

        num_cells = grid_size * grid_size
        if num_items > num_cells:
            # Placement needs one distinct cell per item; without this check it cannot succeed.
            raise ValueError(f"num_items ({num_items}) exceeds the number of grid cells ({num_cells})")

        self.grid_size = grid_size
        self.num_items = num_items
        self.max_steps = max_steps
        self.high_demand_threshold = high_demand_threshold
        self.goal_radius = goal_radius
        self.steps_taken = 0

        # Grid of item IDs (0 = empty) plus a reverse index item_id -> (row, col).
        self.warehouse = np.zeros((grid_size, grid_size), dtype=int)
        self.item_locations = {}

        # Demand levels per item (higher means more frequently accessed); index 0 is unused.
        if item_demand is None:
            # The global RNG is kept so that callers who seed np.random still get the same demand.
            self.item_demand = np.random.randint(1, 10, size=num_items + 1)
        else:
            self.item_demand = np.array(item_demand, dtype=int)
            if self.item_demand.shape != (num_items + 1,):
                raise ValueError(
                    f"item_demand must have length num_items + 1 ({num_items + 1}), "
                    f"got shape {self.item_demand.shape}"
                )

        # A single Discrete index per (item_id, target_row, target_col) triple. Discrete rather
        # than MultiDiscrete because the DQN policy used later selects one action per step.
        self.action_space = spaces.Discrete((num_items + 1) * num_cells)

        # Observation: the grid flattened row-major; values range over 0..num_items.
        self.observation_space = spaces.Box(low=0, high=num_items, shape=(num_cells,), dtype=int)

        self.reset()

    def reset(self, seed=None, options=None):
        """Clear the grid and scatter every item onto a distinct random cell.

        Placement draws from ``self.np_random``, which ``super().reset(seed=seed)`` seeds, so
        passing ``seed`` makes the initial layout reproducible.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        self.warehouse.fill(0)
        self.item_locations.clear()
        self.steps_taken = 0

        # Sampling flat cell indices without replacement guarantees unique positions
        # without a retry loop.
        cells = self.np_random.choice(self.grid_size * self.grid_size,
                                      size=self.num_items, replace=False)
        for item_id, cell in enumerate(cells, start=1):
            row, col = divmod(int(cell), self.grid_size)
            self.warehouse[row, col] = item_id
            self.item_locations[item_id] = (row, col)

        return self._get_obs(), {}

    def _get_obs(self):
        """Return the warehouse state as a flattened (copied) array."""
        return self.warehouse.flatten()

    def step(self, action):
        """Apply one relocation attempt.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``terminated`` is True once the goal
            is reached. ``truncated`` is True once ``max_steps`` steps have been taken. The
            truncation check runs on every path, including invalid actions, so episodes
            always end. Callers should stop on ``terminated or truncated``.

        An invalid action (``item_id == 0``) costs ``-1``. A move onto an occupied cell is
        ignored, and the reward is computed from where the item actually ends up. A blocked
        move is therefore not paid as if it had succeeded.
        """
        self.steps_taken += 1
        truncated = self.steps_taken >= self.max_steps

        # Decode the flat action index into (item_id, target_row, target_col).
        item_id, cell = divmod(int(action), self.grid_size * self.grid_size)
        target_row, target_col = divmod(cell, self.grid_size)

        if item_id == 0 or item_id not in self.item_locations:
            reward = -1  # Invalid move penalty
        else:
            current_row, current_col = self.item_locations[item_id]

            # Move only when the target cell is empty.
            if self.warehouse[target_row, target_col] == 0:
                self.warehouse[current_row, current_col] = 0
                self.warehouse[target_row, target_col] = item_id
                self.item_locations[item_id] = (target_row, target_col)

            # Reward high-demand items for being near (0, 0), based on their actual position.
            row, col = self.item_locations[item_id]
            distance = row + col  # Manhattan distance from the exit; coordinates are non-negative
            reward = self.item_demand[item_id] / (distance + 1)

        terminated = self._goal_reached()
        return self._get_obs(), reward, terminated, truncated, {}

    def _goal_reached(self):
        """Return True when every high-demand item lies within ``goal_radius`` of (0, 0).

        This is vacuously True when no item's demand reaches ``high_demand_threshold``.
        """
        return all(
            row + col <= self.goal_radius
            for item_id, (row, col) in self.item_locations.items()
            if self.item_demand[item_id] >= self.high_demand_threshold
        )

    def render(self):
        """Print the current grid of item IDs (0 = empty) to stdout."""
        print(self.warehouse)

In [3]:
import torch
import stable_baselines3
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor # To automatically track and log statistics about the agent's performance during training or evaluation. 
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback

# Check PyTorch GPU availability; the chosen device is passed to the model below,
# so CPU-only machines work as well.
if torch.cuda.is_available():
    print("CUDA available, Stable Baselines3 will use GPU.")
    device = 'cuda'
else:
    print("CUDA not available, Stable Baselines3 will use CPU.")
    device = 'cpu'

# Run configuration
SEED = 42               # seeds NumPy here and SB3 (torch, env resets) via DQN(seed=...)
total_timesteps = 20000
USE_CALLBACKS = False   # set True to save checkpoints and run periodic evaluation while training

# WarehouseEnv draws item demand from the global NumPy RNG when it is constructed,
# so seed that RNG before building the environments.
np.random.seed(SEED)

train_base_env = WarehouseEnv()
eval_base_env = WarehouseEnv()
# Each instance draws its own demand levels; copy them so evaluation measures the same task.
eval_base_env.item_demand = train_base_env.item_demand.copy()


def wrap_env(base_env):
    """Cap episode length at ``base_env.max_steps`` and record episode statistics.

    The explicit TimeLimit is a safeguard so every episode ends, even if the env's own
    ``step`` does not report the end of the episode on some path.
    """
    return Monitor(gym.wrappers.TimeLimit(base_env, max_episode_steps=base_env.max_steps))


# SB3 algorithms expect a vectorized env interface, even for a single environment.
env = DummyVecEnv([lambda: wrap_env(train_base_env)])
eval_env = DummyVecEnv([lambda: wrap_env(eval_base_env)])

dqn_kwargs = {
    "policy": "MlpPolicy",  # Multi-Layer Perceptron (MLP)
    "env": env,
    "learning_rate": 1e-3,
    "buffer_size": 10000,
    "learning_starts": 10,
    "batch_size": 64,
    "gamma": 0.99,
    "train_freq": (4, "step"),
    # NOTE(review): this interval is measured in environment steps, so with
    # total_timesteps = 20000 the target network is synced at most twice.
    # Consider a smaller value.
    "target_update_interval": 10000,
    "exploration_fraction": 0.1,
    "exploration_final_eps": 0.01,
    "verbose": 1,
    "tensorboard_log": "./logs/dqn_warehouse/",
    "device": device,
    "seed": SEED,
}

model = DQN(**dqn_kwargs)

checkpoint_callback = CheckpointCallback(save_freq=5000, save_path="./models/dqn/", name_prefix="warehouse_dqn")
eval_callback = EvalCallback(eval_env, best_model_save_path="./models/dqn_best",
                             log_path="./logs/dqn_eval/", eval_freq=5000, deterministic=True)
callbacks = [checkpoint_callback, eval_callback] if USE_CALLBACKS else None

model.learn(total_timesteps=total_timesteps, callback=callbacks)

env.close()
eval_env.close()

model.save("./models/dqn_warehouse_final")

print("Training Complete. Model saved successfully.")

CUDA available, Stable Baselines3 will use GPU.
Using cuda device
Logging to ./logs/dqn_warehouse/DQN_4
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 43.2     |
|    ep_rew_mean      | 73.4     |
|    exploration_rate | 0.914    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 23       |
|    time_elapsed     | 7        |
|    total_timesteps  | 173      |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 1.83     |
|    n_updates        | 41       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 46.8     |
|    ep_rew_mean      | 84.7     |
|    exploration_rate | 0.815    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 47       |
|    time_elapsed     | 7        |
|    total_timesteps  | 374      |
| train/             