In [None]:
import sys

sys.path.insert(0, "..")

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm.notebook import tqdm

# Our modules
from src.agents.dqn_agent import DQNAgent
from src.mdp import RewardFunction
from src.agents.callbacks.learning_curve_callback import LearningCurveCallback
from src.agents.metrics import TrainingMetrics, EvaluationMetrics
from src.seeds import generate_seeds

# Plotting defaults: seaborn "whitegrid" theme, 12 x 4 default figure size
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (12, 4)})

## 2. Environment Setup

In [None]:
# Configuration
EPISODE_LENGTH = 1000  # days

# One seed each for training and validation, plus a batch of 1000 seeds for
# the final evaluation, all drawn through generate_seeds.
# NOTE(review): EVAL_SEEDS starts at start_index=0 and holds 1000 entries,
# while the training/validation seeds are taken at start_index=20 and 21.
# If generate_seeds is an index-addressed sequence, both of those seeds would
# also appear in EVAL_SEEDS — confirm against src.seeds.
TRAINING_SEED = generate_seeds(1, start_index=20)[0]
EVAL_SEED = generate_seeds(1, start_index=21)[0]
EVAL_SEEDS = generate_seeds(1000, start_index=0)

# Summary of the seeds in use, one line per role
print(
    "🎲 DQN Training Seeds:",
    f"   Training: {TRAINING_SEED}",
    f"   Validation: {EVAL_SEED}",
    f"   Final evaluation: {len(EVAL_SEEDS)} seeds starting at {EVAL_SEEDS[0]}",
    sep="\n",
)

In [None]:
from src.environment.gym_env import InventoryEnvironment


def make_env(
    episode_length: int = EPISODE_LENGTH,
    random_seed: int = TRAINING_SEED,
) -> InventoryEnvironment:
    """Build an ``InventoryEnvironment`` with this notebook's fixed settings.

    Args:
        episode_length: Forwarded as ``episode_length``; defaults to
            ``EPISODE_LENGTH``.
        random_seed: Forwarded as ``random_seed``; defaults to
            ``TRAINING_SEED``.

    Returns:
        A newly constructed environment. ``k`` and ``Q_max`` are always 30.
    """
    # Fixed problem settings first, then the per-call overrides
    env_settings = {
        "k": 30,
        "Q_max": 30,
        "episode_length": episode_length,
        "random_seed": random_seed,
    }
    return InventoryEnvironment(**env_settings)


# Training environment, seeded with TRAINING_SEED; this is the `env` handed
# to the DQN agent in the next section.
env = make_env(random_seed=TRAINING_SEED)
print(env)

## 3. Create and Train DQN Agent

In [None]:
# DQN hyperparameters gathered in one mapping so they are easy to scan and
# tweak; they are unpacked into the DQNAgent constructor below.
dqn_hyperparams = {
    "learning_rate": 1e-4,
    "gamma": 0.99,
    "buffer_size": 100_000,
    "batch_size": 64,
    "exploration_fraction": 0.3,
    "exploration_final_eps": 0.05,
    "target_update_interval": 1000,
    "learning_starts": 1000,
    "train_freq": 4,
    "tensorboard_log": None,  # no TensorBoard log directory configured
    "policy_kwargs": {"net_arch": [256, 256]},
    "seed": TRAINING_SEED,
    "verbose": 0,
}

agent = DQNAgent(env=env, **dqn_hyperparams)

print(f"DQN Agent created with seed {TRAINING_SEED}")

## 5. Training

Train the DQN agent with periodic evaluation.

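**Monitor with TensorBoard** (needs a log directory: the agent above is created with `tensorboard_log=None`, so set it to e.g. `"./logs"` first):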

```bash
tensorboard --logdir=./logs
```


In [None]:
TOTAL_TIMESTEPS = 5_000_000

# Banner: total budget and the nominal episode count it corresponds to
print(
    f"🚀 Starting training for {TOTAL_TIMESTEPS:,} timesteps...",
    f"   ≈ {TOTAL_TIMESTEPS // EPISODE_LENGTH:,} episodes\n",
    sep="\n",
)

# Plot helpers reused by the analysis cells further down
training_plots = TrainingMetrics()
evaluation_plots = EvaluationMetrics()

# Callback whose episode_timesteps / episode_rewards feed the learning curve
learning_curve_callback = LearningCurveCallback()

train_options = {
    "total_timesteps": TOTAL_TIMESTEPS,
    "progress_bar": True,
    "callbacks": learning_curve_callback,
}
agent.train(**train_options)

print("Training complete!")

## 6. Save the Trained Model

In [None]:
# Persist the trained agent. No path is passed, so the destination is whatever
# DQNAgent.save uses by default — TODO confirm where that is.
agent.save()

---

# Phase 1: Training Analysis


### 1. Learning Curve (Training Stability)

Average Reward per Episode vs. Timesteps with baseline comparison.


In [None]:
# Window = 10% of the nominal number of training episodes
# (TOTAL_TIMESTEPS // EPISODE_LENGTH); presumably a smoothing window.
window = int(0.1 * (TOTAL_TIMESTEPS // EPISODE_LENGTH))

# Learning Curve: Use data from callback
training_plots.plot_learning_curve(
    episode_timesteps=learning_curve_callback.episode_timesteps,
    episode_rewards=learning_curve_callback.episode_rewards,
    window=max(window, 5),  # never pass a window smaller than 5
    title="Learning Curve: DQN Agent Training Progress",
)

### 2. Exploration vs Exploitation

Epsilon decay over training timesteps.


In [None]:
# Exploration settings are read back from agent.hyperparams instead of being
# repeated here (presumably the values passed when the agent was created).
training_plots.plot_epsilon_decay(
    total_timesteps=TOTAL_TIMESTEPS,
    exploration_fraction=agent.hyperparams["exploration_fraction"],
    exploration_final_eps=agent.hyperparams["exploration_final_eps"],
    title="Exploration vs Exploitation: Epsilon Decay",
)

---

# Phase 2: Evaluation & Testing

Run one independent test episode per seed in `EVAL_SEEDS` (N=1000) with the deterministic policy (ε=0).


In [None]:
N_TEST_EPISODES = len(EVAL_SEEDS)

reward_fn = RewardFunction()
test_episodes_dqn = []

print(f"🧪 Running {N_TEST_EPISODES} test episodes with deterministic seeds...")

# Per-day series recorded for every test episode (one list per key)
logged_series = (
    "net_inv_0",
    "net_inv_1",
    "q0",
    "q1",
    "demand_0",
    "demand_1",  # Daily demand per product
    "ordering_cost",
    "holding_cost",
    "shortage_cost",
    "total_daily_cost",
)

# One fresh environment per evaluation seed; the episode runs until the
# environment reports terminated or truncated.
seed_progress = tqdm(EVAL_SEEDS, total=N_TEST_EPISODES, desc="Evaluating DQN")
for ep, seed in enumerate(seed_progress):
    dqn_env = make_env(random_seed=seed)
    obs, _ = dqn_env.reset()

    dqn_data = {series: [] for series in logged_series}

    done = False
    while not done:
        # Deterministic action, decoded into an action object so the
        # per-product order quantities can be logged
        action = agent.select_action(obs, deterministic=True)
        action_obj = dqn_env.action_space_config.get_action(action)
        obs, reward, terminated, truncated, info = dqn_env.step(action)

        # Log inventory, order quantity and demand for both products
        net_inventory = info["net_inventory"]
        order_quantities = action_obj.order_quantities
        total_demand = info["total_demand"]
        for product in (0, 1):
            dqn_data[f"net_inv_{product}"].append(net_inventory[product])
            dqn_data[f"q{product}"].append(order_quantities[product])
            dqn_data[f"demand_{product}"].append(total_demand[product])

        # Daily cost components via RewardFunction.
        # NOTE(review): the state is read *after* step(), so costs are derived
        # from the post-step state plus the action just taken — confirm this
        # matches how the environment itself computes the reward.
        state = dqn_env.get_current_state()
        costs = reward_fn.calculate_costs(state, action_obj)

        for component in ("ordering_cost", "holding_cost", "shortage_cost"):
            dqn_data[component].append(getattr(costs, component))
        dqn_data["total_daily_cost"].append(costs.total_cost)

        done = terminated or truncated

    test_episodes_dqn.append(dqn_data)

print(f"Collected {N_TEST_EPISODES} test episodes for DQN")

### Warm-up Period Analysis (Welch's Graphical Procedure)

Since the simulation starts with initial inventory conditions, the early data may be biased (transient phase).  
We use Welch's procedure to identify when steady-state begins.


In [None]:
# Welch's Graphical Procedure for Warm-up Detection
# The call returns three values, unpacked as n_days, n_reps and WARMUP_LENGTH.
# WARMUP_LENGTH is reused by the later analysis cells (as warmup_length /
# start_day / slice start) and n_days by the cost breakdown plot.
n_days, n_reps, WARMUP_LENGTH = evaluation_plots.plot_welch_procedure(
    test_episodes=test_episodes_dqn,
    window_size=25,
    title="Warm-up Period Analysis",
)

## Cost Component Breakdown (Economic Analysis)

In [None]:
# Compute and print evaluation statistics
# Uses the warm-up length found by the Welch cell; the return value is kept
# in `stats`.
stats = evaluation_plots.print_evaluation_statistics(
    test_episodes=test_episodes_dqn,
    warmup_length=WARMUP_LENGTH,
)

### Daily Cost Evolution

Aggregated daily cost statistics across all test episodes.


In [None]:
# Daily Cost Analysis
# Plots from all collected test episodes, with the warm-up length from the
# Welch cell passed along.
evaluation_plots.plot_daily_cost_analysis(
    test_episodes=test_episodes_dqn,
    warmup_length=WARMUP_LENGTH,
    title="Daily Cost Evolution - DQN Agent",
)

Grouped bar chart of costs by product, decomposed into Ordering, Holding, and Shortage costs.


In [None]:
# Cost breakdown by product; n_days and WARMUP_LENGTH both come from the
# Welch procedure cell.
evaluation_plots.plot_cost_breakdown_by_product(
    test_episodes=test_episodes_dqn,
    warmup_length=WARMUP_LENGTH,
    n_days=n_days,
    title="Cost Component Breakdown by Product",
)

### Operational Time Series (Behavioral Analysis)

Snapshot of inventory levels, orders, and demand over time for selected episodes.


In [None]:
# Mean daily cost of every test episode over the steady-state window
# (days before WARMUP_LENGTH are dropped)
episode_ss_costs = [
    np.mean(episode["total_daily_cost"][WARMUP_LENGTH:])
    for episode in test_episodes_dqn
]
global_ss_mean = np.mean(episode_ss_costs)

# Representative episode: smallest absolute deviation from the global mean
abs_deviation = np.abs(np.array(episode_ss_costs) - global_ss_mean)
representative_idx = np.argmin(abs_deviation)

# Best = lowest steady-state cost, worst = highest
best_idx = np.argmin(episode_ss_costs)
worst_idx = np.argmax(episode_ss_costs)

print("📊 Episode Selection (Based on Steady-State Costs):")
print(
    f"   Representative: Episode {representative_idx} (cost: ${episode_ss_costs[representative_idx]:.2f})"
)
print(
    f"   Best:           Episode {best_idx} (cost: ${episode_ss_costs[best_idx]:.2f})"
)
print(
    f"   Worst:          Episode {worst_idx} (cost: ${episode_ss_costs[worst_idx]:.2f})"
)
print(f"   Mean across all episodes: ${global_ss_mean:.2f}")

In [None]:
print("\n📈 Plotting Representative Episode (Steady-State Period)...")
# Use the index computed by the episode-selection cell instead of a literal
# episode number: a hard-coded index goes stale whenever the agent is
# retrained or the evaluation seeds change, and raises IndexError when fewer
# episodes are collected.
evaluation_plots.plot_operational_timeseries(
    episode_data=test_episodes_dqn[representative_idx],
    title="Operational Time Series - DQN Agent (Representative Episode)",
    start_day=WARMUP_LENGTH,  # start the plot at the end of the warm-up period
    max_days=200,  # Show only 200 days for clarity
)

### Inventory Distribution Histogram (Risk Profile)

Distribution of Net Inventory levels over all test episodes. Red = Backlog (I < 0), Green = On-Hand (I > 0).


In [None]:
# Inventory Distribution Histogram (Steady-State only)
# All collected test episodes are passed in, together with the warm-up length
# from the Welch cell.
evaluation_plots.plot_inventory_histogram(
    test_episodes=test_episodes_dqn,
    warmup_length=WARMUP_LENGTH,
    title="Inventory Distribution - Risk Profile (DQN Agent, Steady-State)",
)