In [1]:
# Import relevant libraries
import numpy as np
import gymnasium as gym
from gymnasium import spaces

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

In [2]:
# Warehouse simulator with a moving agent
class WarehouseEnv(gym.Env):
    """
    A warehouse environment conforming to the modern Gymnasium API.
    """
    
    # Add metadata, which is standard practice in Gymnasium
    metadata = {"render_modes": ["rgb_array", "human"], "render_fps": 4}

    def __init__(self, grid_size=5, num_items=5, max_steps=100, render_mode=None):
        super().__init__()

        self.grid_size = grid_size
        self.num_items = num_items
        self.max_steps = max_steps
        self.agent_pos = [0, 0]
        self.agent_carries_item = 0
        self.steps_taken = 0
        self.warehouse = np.zeros((grid_size, grid_size), dtype=int)
        self.item_locations = {}
        self.item_demand = np.random.randint(1, 10, size=num_items + 1)
        self.item_demand[0] = 0
        
        # Store the render_mode and assert it's valid
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        self.action_space = spaces.Discrete(5)
        self.observation_space = spaces.Dict({
            "agent_pos": spaces.Box(low=0, high=grid_size - 1, shape=(2,), dtype=int),
            "agent_carries": spaces.Discrete(num_items + 1),
            "warehouse_layout": spaces.Box(low=0, high=num_items, shape=(grid_size, grid_size), dtype=int)
        })

    def _get_obs(self):
        return {
            "agent_pos": np.array(self.agent_pos),
            "agent_carries": self.agent_carries_item,
            "warehouse_layout": self.warehouse
        }

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.agent_pos = [0, 0]
        self.agent_carries_item = 0
        self.steps_taken = 0
        self.warehouse.fill(0)
        self.item_locations.clear()

        for item in range(1, self.num_items + 1):
            while True:
                row, col = self.np_random.integers(0, self.grid_size, size=2)
                if self.warehouse[row, col] == 0:
                    self.warehouse[row, col] = item
                    self.item_locations[item] = (row, col)
                    break
        obs = self._get_obs()
        return obs, {}

    def step(self, action):
        self.steps_taken += 1
        reward = -0.01

        if action < 4:
            new_pos = self.agent_pos.copy()
            if action == 0: new_pos[0] -= 1
            elif action == 1: new_pos[0] += 1
            elif action == 2: new_pos[1] -= 1
            elif action == 3: new_pos[1] += 1
            if 0 <= new_pos[0] < self.grid_size and 0 <= new_pos[1] < self.grid_size:
                self.agent_pos = new_pos
            else:
                reward -= 0.1
        elif action == 4:
            agent_r, agent_c = self.agent_pos
            if self.agent_carries_item > 0:
                if self.warehouse[agent_r, agent_c] == 0:
                    item_id = self.agent_carries_item
                    self.warehouse[agent_r, agent_c] = item_id
                    self.item_locations[item_id] = (agent_r, agent_c)
                    self.agent_carries_item = 0
                    distance = agent_r + agent_c
                    reward += self.item_demand[item_id] / (distance + 1) * 5
                else:
                    reward -= 0.2
            else:
                item_on_floor = self.warehouse[agent_r, agent_c]
                if item_on_floor > 0:
                    self.agent_carries_item = item_on_floor
                    self.warehouse[agent_r, agent_c] = 0
                    del self.item_locations[item_on_floor]
                    reward += 0.1
                else:
                    reward -= 0.2

        terminated = self.steps_taken >= self.max_steps
        truncated = False
        obs = self._get_obs()
        return obs, reward, terminated, truncated, {}

    
    def render(self):
        """ Render the environment with a legend for each color.
            - Red: Agent
            - Yellow: Agent carrying an item
            - Viridis colors: Items on the floor
            - White: Empty spaces"""
        
        if self.render_mode == "rgb_array":
            fig, ax = plt.subplots(figsize=(7,7))
            render_grid = np.copy(self.warehouse)
            agent_r, agent_c = self.agent_pos

            agent_render_value = self.num_items + 2
            if self.agent_carries_item > 0:
                agent_render_value += 1

            render_grid[agent_r, agent_c] = agent_render_value

            cmap_colors = [(1, 1, 1, 1)]
            viridis = plt.colormaps.get('viridis')
            item_colors = viridis(np.linspace(0, 1, self.num_items))
            cmap_colors.extend(item_colors)
            cmap_colors.append((1, 0, 0, 1)) # agent = red
            cmap_colors.append((1, 1, 0, 1)) # agent carrying = yellow

            custom_map = colors.ListedColormap(cmap_colors)
            bounds = np.arange(self.num_items + 4) - 0.5
            norm = colors.BoundaryNorm(bounds, custom_map.N)

            ax.imshow(render_grid, cmap=custom_map, norm=norm)

            ax.set_xticks(np.arange(-0.5, self.grid_size, 1), minor=True)
            ax.set_yticks(np.arange(-0.5, self.grid_size, 1), minor=True)
            ax.grid(which="minor", color="black", linewidth=0.5)
            ax.tick_params(which="both", bottom=False, left=False, labelbottom=False, labelleft=False)

            for r in range(self.grid_size):
                for c in range(self.grid_size):
                    item_id = self.warehouse[r,c]
                    if item_id > 0:
                        ax.text(c, r, str(item_id), ha='center', va='center', color='white', weight='bold')
            
            from matplotlib.patches import Patch
            legend_elements = [
                Patch(facecolor='red', edgecolor='black', label='Agent (Empty)'),
                Patch(facecolor='yellow', edgecolor='black', label='Agent (Carrying)'),
                Patch(facecolor=viridis(0.5), edgecolor='black', label='Item'),
                Patch(facecolor='white', edgecolor='black', label='Empty Space')]

            ax.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.05), loc='upper center', ncol=2)
            plt.tight_layout(rect=[0, 0.05, 1, 1])

            canvas = FigureCanvas(fig)
            canvas.draw()
            width, height = fig.canvas.get_width_height()
            image = np.frombuffer(canvas.buffer_rgba(), dtype='uint8').reshape(height, width, 4)[:, :, :3]
            plt.close(fig)
            return image

In [3]:
import torch
import stable_baselines3
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, BaseCallback
import imageio
import os
from gymnasium.wrappers import FlattenObservation

# Check PyTorch GPU availability
if torch.cuda.is_available():
    print("CUDA available, Stable Baselines3 will use GPU.")
    device = 'cuda'
else:
    print("CUDA not available, Stable Baselines3 will use CPU.")
    device = 'cpu'

# SaveAnimationCallback class (no changes needed here)
class SaveAnimationCallback(BaseCallback):
    def __init__(self, save_freq: int, output_path: str = './animation_frames/', verbose: int = 0):
        super(SaveAnimationCallback, self).__init__(verbose)
        self.save_freq = save_freq
        self.output_path = output_path
        os.makedirs(self.output_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.save_freq == 0:
            # self.training_env.render() is vectorized, so we access the underlying env
            img = self.training_env.get_images()[0]
            filename = os.path.join(self.output_path, f"frame_{self.num_timesteps:06d}.png")
            imageio.imwrite(filename, img)
        return True

# 1. Create the base environment
env = WarehouseEnv(render_mode="rgb_array")
# 2. Apply the FlattenObservation wrapper
env = FlattenObservation(env)
# 3. Wrap with Monitor for statistics
env = Monitor(env)
# 4. Vectorize the environment for Stable Baselines3
env = DummyVecEnv([lambda: env])
# ------------------------------------

# dqn_kwargs dictionary
dqn_kwargs = {
    "policy":"MlpPolicy",
    "env":env,
    "learning_rate": 5e-4, # Low learning rate for more stable learning
    "buffer_size": 50000,
    "learning_starts": 5000,
    "batch_size": 128,
    "gamma": 0.99,
    "train_freq": (4, "step"),
    "target_update_interval": 5000,
    "exploration_fraction": 0.2,
    "exploration_final_eps": 0.02,
    "verbose": 1,
    "tensorboard_log": "./logs/dqn_warehouse/",
    "device": device
}

model = DQN(**dqn_kwargs)

# Callbacks
checkpoint_callback = CheckpointCallback(save_freq=5000, save_path="./models/dqn/", name_prefix="warehouse_dqn")

eval_env = WarehouseEnv()
eval_env = FlattenObservation(eval_env) # Eval env also needs to be flattened
eval_env = DummyVecEnv([lambda: eval_env])
eval_callback = EvalCallback(eval_env, best_model_save_path="./models/dqn_best",
                             log_path="./logs/dqn_eval/", eval_freq=10000, deterministic=True)

animation_callback = SaveAnimationCallback(save_freq=5000)

# Learn
total_timesteps = 100000 # Increased more complex problem
model.learn(total_timesteps=total_timesteps, callback=[checkpoint_callback, eval_callback, animation_callback])

# Close environments
env.close()
eval_env.close()

model.save("./models/dqn_warehouse_final")
print("Training Complete. Model saved successfully.")

CUDA available, Stable Baselines3 will use GPU.
Using cuda device
Logging to ./logs/dqn_warehouse/DQN_3
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 100      |
|    ep_rew_mean      | 13.1     |
|    exploration_rate | 0.98     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 3901     |
|    time_elapsed     | 0        |
|    total_timesteps  | 400      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 100      |
|    ep_rew_mean      | 40.2     |
|    exploration_rate | 0.961    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 4477     |
|    time_elapsed     | 0        |
|    total_timesteps  | 800      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 100      |
|    ep_rew_mean     



Eval num_timesteps=10000, episode_reward=-9.00 +/- 9.80
Episode length: 100.00 +/- 0.00
----------------------------------
| eval/               |          |
|    mean_ep_length   | 100      |
|    mean_reward      | -9       |
| rollout/            |          |
|    exploration_rate | 0.51     |
| time/               |          |
|    total_timesteps  | 10000    |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.0189   |
|    n_updates        | 1249     |
----------------------------------
New best mean reward!
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 100      |
|    ep_rew_mean      | 22.7     |
|    exploration_rate | 0.51     |
| time/               |          |
|    episodes         | 100      |
|    fps              | 496      |
|    time_elapsed     | 20       |
|    total_timesteps  | 10000    |
----------------------------------
----------------------------------
| rollout/     

In [4]:
import imageio
from stable_baselines3 import DQN
from gymnasium.wrappers import FlattenObservation
import os

print("--- Starting Cell 4: Evaluation ---")

# --- Delete the frames directory if it exists ---
frames_path = "./evaluation_frames/"
if os.path.exists(frames_path):
    import shutil
    shutil.rmtree(frames_path)
os.makedirs(frames_path, exist_ok=True)
print(f"Cleaned and created directory: {frames_path}")
# --------------------------------------------------------------------

try:
    # Load the trained model
    model_path = './models/dqn_warehouse_final.zip'
    model = DQN.load(model_path)
    print(f"Successfully loaded model from: {model_path}")

    # Create the evaluation environment
    eval_env = WarehouseEnv(render_mode="rgb_array")
    eval_env_wrapped = FlattenObservation(eval_env)
    print("Evaluation environment created and wrapped.")

    output_gif_path = "warehouse_evaluation.gif"
    frames = []

    # Reset the wrapped environment
    obs, info = eval_env_wrapped.reset()
    print("Environment reset. Initial observation received.")

    # Limit the evaluation episode length
    for step in range(200):
        print(f"--- Step {step} ---")

        # Render the ORIGINAL unwrapped environment
        img = eval_env.render()
        if img is not None:
            frames.append(img)
            print(f"Frame {step} rendered and saved.")
        else:
            print(f"Warning: Frame {step} rendering returned None.")


        # Predict the action
        action, _ = model.predict(obs, deterministic=True)
        print(f"Action predicted: {action}")

        # Step the wrapped environment
        obs, reward, terminated, truncated, info = eval_env_wrapped.step(action)
        print(f"Environment stepped. Reward: {reward:.2f}, Terminated: {terminated}, Truncated: {truncated}")

        # The episode ends if done (terminated) or truncated is True
        if terminated or truncated:
            print(f"Episode finished at step {step}. Reason: Terminated={terminated}, Truncated={truncated}")
            # Add a final frame to show the end state
            img = eval_env.render()
            frames.append(img)
            break

    eval_env_wrapped.close()
    print("\n--- Loop Finished ---")

    # Check if frames were captured before saving
    if frames:
        print(f"Saving {len(frames)} frames to {output_gif_path}...")
        imageio.mimsave(output_gif_path, frames, fps=10)
        print("Evaluation GIF saved successfully.")
    else:
        print("Warning: No frames were captured, so no GIF was created.")

except Exception as e:
    print("\n--- AN ERROR OCCURRED ---")
    import traceback
    traceback.print_exc()

--- Starting Cell 4: Evaluation ---
Cleaned and created directory: ./evaluation_frames/
Successfully loaded model from: ./models/dqn_warehouse_final.zip
Evaluation environment created and wrapped.
Environment reset. Initial observation received.
--- Step 0 ---
Frame 0 rendered and saved.
Action predicted: 1
Environment stepped. Reward: -0.01, Terminated: False, Truncated: False
--- Step 1 ---
Frame 1 rendered and saved.
Action predicted: 1
Environment stepped. Reward: -0.01, Terminated: False, Truncated: False
--- Step 2 ---
Frame 2 rendered and saved.
Action predicted: 1
Environment stepped. Reward: -0.01, Terminated: False, Truncated: False
--- Step 3 ---
Frame 3 rendered and saved.
Action predicted: 1
Environment stepped. Reward: -0.01, Terminated: False, Truncated: False
--- Step 4 ---
Frame 4 rendered and saved.
Action predicted: 1
Environment stepped. Reward: -0.11, Terminated: False, Truncated: False
--- Step 5 ---
Frame 5 rendered and saved.
Action predicted: 1
Environment step

In [5]:
import os
import imageio
import re

input_folder = "./animation_frames/"
output_gif_path = "warehouse_training.gif"
fps = 2

try:
    image_files = [f for f in os.listdir(input_folder) if f.endswith('.png')]

    def sort_key(filename):
        match = re.search(r'(\d+)', filename)
        return int(match.group(1)) if match else 0

    image_files.sort(key=sort_key)

    if not image_files:
        print(f"Error: No image files found in '{input_folder}'. Cannot create GIF.")
    else:
        print(f"Found {len(image_files)} frames to animate.")

        frames = []
        for filename in image_files:
            file_path = os.path.join(input_folder, filename)
            frames.append(imageio.imread(file_path))

        print(f"Saving GIF to '{output_gif_path}'...")
        imageio.mimsave(output_gif_path, frames, fps=fps, loop=0)
        print("--- GIF created successfully! ---")


except FileNotFoundError:
    print(f"Error: The directory '{input_folder}' was not found.")
except Exception as e:
    print(f"An error occurred: {e}")      

Found 20 frames to animate.


  frames.append(imageio.imread(file_path))


Saving GIF to 'warehouse_training.gif'...
--- GIF created successfully! ---


In [6]:
pip install imageio[pyav]

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [10]:
import imageio.v3 as iio
import os

# --- Configuration ---
gif_path = "warehouse_training.gif"
mp4_path = "warehouse_training.mp4"
fps = 2 

print(f"--- Converting {gif_path} to {mp4_path} ---")

if not os.path.exists(gif_path):
    print(f"Error: The file '{gif_path}' was not found. Please ensure it is in the correct directory.")
else:
    try:
        print("Reading frames from GIF...")
        
        # Read all frames from the GIF into a list in memory.
        # index=None tells it to read all frames. This is a robust way to handle it.
        frames = iio.imread(gif_path, index=None)
        
        num_frames = len(frames)
        print(f"Successfully read {num_frames} frames.")

        print(f"Writing {num_frames} frames to MP4 with {fps} FPS...")
        
        # Write the list of frames to the new MP4 video file.
        iio.imwrite(mp4_path, frames, fps=fps, codec='libx264')

        print("--- Conversion successful! ---")

    except Exception as e:
        import traceback
        print(f"\nAn error occurred during conversion:")
        traceback.print_exc()

--- Converting warehouse_training.gif to warehouse_training.mp4 ---
Reading frames from GIF...
Successfully read 20 frames.
Writing 20 frames to MP4 with 2 FPS...
--- Conversion successful! ---
You can now find 'warehouse_training.mp4' in your file browser and upload it to LinkedIn.
