# DQN for Inventory Management — A2C Drop-in Comparison

**Architecture matches A2C Mod** (`training_220.ipynb`) exactly:
- Same state space: `[inventory, sales, waste]` per product → flat `[660]`
- Same action space: 14 discrete values
- Same reward function: `r = 1 - z - overstock - q - quan`
- Same data parsers and TFRecord files
- Same 600 episodes × 900 timesteps

**DQN Additions over A2C:**
- Experience Replay Buffer (100 000 capacity)
- Target Network (updated every 10 episodes)
- Double-DQN update rule
- Epsilon-Greedy exploration
- GroupNormalization(groups=1) after each hidden layer (matches Critic)

In [1]:
# ============================================================
# 0. DEPENDENCIES
# ============================================================
import sys
# Installs the pinned dependencies into this kernel's environment (comment out once installed):
!{sys.executable} -m pip install tensorflow==2.14 tensorflow-addons==0.22.0 numpy pandas matplotlib wandb




[notice] A new release of pip is available: 25.3 -> 26.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
# ============================================================
# 1. IMPORTS
# ============================================================
import os
import sys
import random
from collections import deque
from datetime import datetime

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# Wide NumPy printing: up to 25 edge items per axis, very long lines, 8 decimals,
# and suppress=True so small values are not shown in scientific notation.
np.set_printoptions(edgeitems=25, linewidth=10000, precision=8, suppress=True)

# Show the TensorFlow version and the GPUs TensorFlow can see in the cell output.
print(f"TensorFlow : {tf.__version__}")
print(f"GPU devices: {tf.config.list_physical_devices('GPU')}")


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



TensorFlow : 2.14.0
GPU devices: []


In [3]:
# ============================================================
# 2. CONFIGURATION  (mirrors training_220.ipynb FLAGS)
# ============================================================

class Config:
    """
    Central settings for the DQN run (mirrors the training_220.ipynb FLAGS).

    Every setting is a class attribute. An instance such as ``FLAGS = Config()``
    reads them through normal attribute lookup, but nothing is stored on the
    instance itself, so ``vars(instance)`` is an empty dict.

    ``num_features`` and ``num_actions`` are derived from ``num_products`` /
    ``num_features_per_prod`` and ``action_space`` so they cannot drift apart.
    """

    # ── Environment ─────────────────────────────────────────
    num_products          = 220
    num_features_per_prod = 3        # [inventory x, sales, waste q]
    num_features          = num_products * num_features_per_prod  # flat state: 660
    num_timesteps         = 900

    # ── Action space (same 14 values as A2C) ─────────────────
    action_space = [
        0, 0.005, 0.01, 0.0125, 0.015, 0.0175,
        0.02, 0.03, 0.04, 0.08, 0.12, 0.2, 0.5, 1
    ]
    num_actions = len(action_space)  # 14 — one Q-value per entry above

    # ── Shared training knobs ────────────────────────────────
    train_episodes = 600
    batch_size     = 32       # same as A2C
    gamma          = 0.99     # same as A2C
    waste          = 0.025    # same as A2C
    zero_inventory = 1e-5     # same as A2C

    # ── Network architecture ─────────────────────────────────
    hidden_size  = 128        # width of each hidden Dense layer
    dropout_prob = 0.1        # same as A2C actor/critic
    use_group_norm = True     # GroupNormalization(groups=1) like A2C Critic

    # ── DQN-specific ─────────────────────────────────────────
    learning_rate          = 0.001   # same as A2C actor/critic lr for fairness
    replay_buffer_size     = 100_000
    min_replay_size        = 1_000   # no gradient updates until the buffer holds this many
    epsilon_start          = 1.0
    epsilon_end            = 0.01
    epsilon_decay_episodes = 400     # decay over first 400 episodes
    target_update_freq     = 10      # episodes between target-network syncs

    # ── File paths (same as A2C 220-product setup) ────────────
    train_file    = 'data220/train.tfrecords'
    capacity_file = 'data220/capacity.tfrecords'
    stock_file    = 'data220/stock.tfrecords'
    predict_file  = 'data220/test.tfrecords'
    output_dir    = 'checkpoints_dqn_comparison3'
    # NOTE(review): the extension reads ".csv3" — looks like a run suffix was
    # appended after the extension; confirm before relying on the file type.
    output_file   = './output_dqn_comparison.csv3'

    # ── W&B ──────────────────────────────────────────────────
    use_wandb    = False          # set True to enable
    wandb_project = 'inventory-dqn-vs-a2c'


# Shared settings object used by every later cell; the checkpoint directory is
# created up front so saving never fails on a missing folder.
FLAGS = Config()
os.makedirs(FLAGS.output_dir, exist_ok=True)

# Echo the key hyper-parameters, one per line.
print(
    "Configuration:",
    f"  Products       : {FLAGS.num_products}",
    f"  State features : {FLAGS.num_features}  ({FLAGS.num_products} × {FLAGS.num_features_per_prod})",
    f"  Episodes       : {FLAGS.train_episodes}",
    f"  Timesteps/ep   : {FLAGS.num_timesteps}",
    f"  Batch size     : {FLAGS.batch_size}",
    f"  Learning rate  : {FLAGS.learning_rate}",
    f"  Hidden size    : {FLAGS.hidden_size}",
    f"  GroupNorm      : {FLAGS.use_group_norm}",
    f"  Replay buffer  : {FLAGS.replay_buffer_size}",
    f"  Gamma          : {FLAGS.gamma}",
    sep="\n",
)

Configuration:
  Products       : 220
  State features : 660  (220 × 3)
  Episodes       : 600
  Timesteps/ep   : 900
  Batch size     : 32
  Learning rate  : 0.001
  Hidden size    : 128
  GroupNorm      : True
  Replay buffer  : 100000
  Gamma          : 0.99


In [4]:
# ============================================================
# 3. DATA PARSERS  (identical to training_220.ipynb)
# ============================================================

def sales_parser(serialized_example):
    """
    Decode one serialized TFRecord entry into a dict with a ``"sales"`` tensor.

    The record must carry a float32 feature named ``"sales"`` holding exactly
    ``FLAGS.num_products`` values.
    """
    feature_spec = {
        "sales": tf.io.FixedLenFeature([FLAGS.num_products], tf.float32),
    }
    return tf.io.parse_single_example(serialized_example, features=feature_spec)


def capacity_parser(serialized_example):
    """
    Decode one serialized TFRecord entry into a dict with a ``"capacity"`` tensor.

    The record must carry a float32 feature named ``"capacity"`` holding exactly
    ``FLAGS.num_products`` values.
    """
    feature_spec = {
        "capacity": tf.io.FixedLenFeature([FLAGS.num_products], tf.float32),
    }
    return tf.io.parse_single_example(serialized_example, features=feature_spec)


def stock_parser(serialized_example):
    """
    Decode one serialized TFRecord entry into a dict with a ``"stock"`` tensor.

    The record must carry a float32 feature named ``"stock"`` holding exactly
    ``FLAGS.num_products`` values.
    """
    feature_spec = {
        "stock": tf.io.FixedLenFeature([FLAGS.num_products], tf.float32),
    }
    return tf.io.parse_single_example(serialized_example, features=feature_spec)


# Status line so the cell output shows the three parser definitions were executed.
print("Data parsers defined (identical to training_220.ipynb)")

Data parsers defined (identical to training_220.ipynb)


In [5]:
# ============================================================
# 4. ENVIRONMENT HELPERS  (identical to training_220.ipynb)
# ============================================================

def waste(x):
    """Return the wasted amount for inventory ``x``, i.e. ``FLAGS.waste * x``."""
    rate = FLAGS.waste
    return rate * x


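# NOTE: disabled alternative reward variant, kept for reference only. Its `def`
# line and body are commented out, so the triple-quoted text below is a bare
# string expression with no runtime effect; the active `calc_reward` follows it.
# def calc_reward(x_clip, sales_now, overstock):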
    """
    Reward per product (NumPy arrays) — CORRECTED LOGIC.

    r = 1 - z - overstock - q - quan

    where:
      z    = stockout indicator (1 if demand exceeds available stock)
      q    = waste of inventory AFTER replenishment (x_clip)
      quan = quantile spread (95th - 5th percentile of x_clip)

    FIXED ISSUES:
      ❌ Old: Checked stockout on OLD inventory (before action)
      ✅ New: Check stockout based on unfulfilled demand
      
      ❌ Old: Calculated waste on OLD inventory (before action)
      ✅ New: Calculate waste on inventory AFTER replenishment
      
      ❌ Old: Quantile spread on OLD inventory
      ✅ New: Quantile spread on inventory AFTER replenishment
    
    Timeline:
      t:   inventory x
      ↓    +action (replenishment)
      t+:  x_clip (after replenishment & capacity constraint)
      ↓    waste happens here: q = waste(x_clip)
      ↓    sales fulfillment    
      t+1: x_next (remaining inventory)
      ↓    stockout if sales_now > x_clip
    """
    # # CORRECTED: Check stockout based on unfulfilled demand
    # stockout_amount = np.maximum(0.0, sales_now - x_clip)        # [P] amount of unmet demand
    # z = (stockout_amount > FLAGS.zero_inventory).astype(np.float32)  # [P] binary indicator
    
    # # CORRECTED: Waste happens on inventory AFTER replenishment
    # q = waste(x_clip)                                            # [P] waste on post-action inventory
    
    # # CORRECTED: Quantile spread on inventory distribution AFTER replenishment
    # quan = float(np.quantile(x_clip, 0.95) - np.quantile(x_clip, 0.05))  # scalar
    # quan_vec = np.full(FLAGS.num_products, quan, dtype=np.float32)       # [P]
    
    # r = (1.0 - z - overstock - q - quan_vec).astype(np.float32)
    # return r, z, quan
def calc_reward(x_old, overstock):
    """
    Per-product reward  r = 1 - z - overstock - q - quan.

    Every penalty term is computed from ``x_old``, the inventory BEFORE the
    action is applied. The original author's note states this is meant to
    reproduce the A2C_mod reward (training1.py); that file is not part of
    this notebook, so the equivalence cannot be checked here.

        z    = 1 where x_old < FLAGS.zero_inventory, else 0   (stockout flag)
        q    = waste(x_old)
        quan = quantile(x_old, 0.95) - quantile(x_old, 0.05)  (one scalar,
               subtracted from every product's reward)

    Parameters
    ----------
    x_old     : array-like [P] — inventory BEFORE action
    overstock : np.ndarray [P] — overstock penalty, documented by the caller
                as max(0, x_old + u - 1)

    Returns
    -------
    r    : np.ndarray float32 — reward per product
    z    : np.ndarray float32 — stockout indicator per product
    quan : float              — quantile spread of ``x_old``
    """
    x_old = np.asarray(x_old)

    # z: stockout indicator on OLD inventory (before action)
    z = (x_old < FLAGS.zero_inventory).astype(np.float32)

    # q: waste on OLD inventory (before action)
    q = waste(x_old)

    # quan: spread between the 95th and 5th percentile of OLD inventory
    quan = float(np.quantile(x_old, 0.95) - np.quantile(x_old, 0.05))

    # The spread is a single number shared by all products. Subtracting it as a
    # float32 scalar lets NumPy broadcast it to whatever length x_old has,
    # instead of building a vector hard-wired to FLAGS.num_products.
    r = (1.0 - z - overstock - q - np.float32(quan)).astype(np.float32)
    return r, z, quan

# Status line so the cell output shows the helper definitions were executed.
print("Environment helpers defined")

Environment helpers defined


In [7]:
# ============================================================
# 5. Q-NETWORK  (Per-Product, matches A2C cloned-agent design)
#
# OLD (commented below): Global network [B, 660] -> [B, 220, 14]
#   All products share info through hidden layers — unfair vs A2C
#
# NEW: Per-product network [B*P, 3] -> [B*P, 14] -> [B, P, 14]
#   Each product only sees its own (x_i, sales_i, q_i)
#   Same weights applied to every product (cloned agent)
#   Matches A2C Actor architecture for fair comparison
# ============================================================

# ┌──────────────────────────────────────────────────────────┐
# │  OLD: Global Q-Network (COMMENTED OUT)                  │
# └──────────────────────────────────────────────────────────┘
# # ============================================================
# # 5. Q-NETWORK
# #
# # Architecture mirrors the A2C Actor/Critic:
# #   Dense → GroupNorm(groups=1) → ReLU → Dropout   (×3 hidden layers)
# #   Dense → reshape to [B, P, A]
# # ============================================================
#
# class MultiProductQNetwork(tf.keras.Model):
#     """
#     Q-Network for multi-product inventory management.
#
#     Input  : [B, num_features]         e.g. [B, 660]
#     Output : [B, num_products, num_actions]  e.g. [B, 220, 14]
#
#     Each hidden layer uses:
#         Dense → GroupNormalization(groups=1) → ReLU → Dropout
#     This matches the GroupNorm usage in the A2C Critic.
#     """
#
#     def __init__(
#         self,
#         num_features: int,
#         num_products: int,
#         num_actions: int,
#         hidden_size: int,
#         dropout_prob: float = 0.1,
#         use_group_norm: bool = True,
#         name: str | None = None,
#     ):
#         super().__init__(name=name)
#
#         self.num_products = num_products
#         self.num_actions  = num_actions
#
#         # ── Shared trunk (same depth as A2C Actor: 3 hidden layers) ──
#         self.dense1 = tf.keras.layers.Dense(hidden_size, activation=None, name="dense1")
#         self.dense2 = tf.keras.layers.Dense(hidden_size, activation=None, name="dense2")
#         self.dense3 = tf.keras.layers.Dense(hidden_size, activation=None, name="dense3")
#
#         # ── Output: one Q-value per (product, action) ──────────────
#         self.out = tf.keras.layers.Dense(num_products * num_actions, activation=None, name="output")
#
#         # ── Normalisation & regularisation ─────────────────────────
#         self._use_gn = use_group_norm
#         if use_group_norm:
#             # GroupNormalization(groups=1) == LayerNorm on the channel axis
#             # Matches tfa.layers.GroupNormalization(groups=1) from A2C Critic
#             self.gn1 = tfa.layers.GroupNormalization(groups=1, name="gn1")
#             self.gn2 = tfa.layers.GroupNormalization(groups=1, name="gn2")
#             self.gn3 = tfa.layers.GroupNormalization(groups=1, name="gn3")
#
#         self.drop1 = tf.keras.layers.Dropout(dropout_prob)
#         self.drop2 = tf.keras.layers.Dropout(dropout_prob)
#         self.drop3 = tf.keras.layers.Dropout(dropout_prob)
#
#     def call(self, state, training: bool = False):
#         # state : [B, num_features]
#         x = self.dense1(state)
#         if self._use_gn:
#             x = self.gn1(x, training=training)
#         x = tf.nn.relu(x)
#         x = self.drop1(x, training=training)
#
#         x = self.dense2(x)
#         if self._use_gn:
#             x = self.gn2(x, training=training)
#         x = tf.nn.relu(x)
#         x = self.drop2(x, training=training)
#
#         x = self.dense3(x)
#         if self._use_gn:
#             x = self.gn3(x, training=training)
#         x = tf.nn.relu(x)
#         x = self.drop3(x, training=training)
#
#         q = self.out(x)                          # [B, P*A]
#         bsz = tf.shape(state)[0]
#         q = tf.reshape(q, [bsz, self.num_products, self.num_actions])  # [B, P, A]
#         return q
#
#
# # Quick sanity-check
# _dummy = tf.zeros([2, FLAGS.num_features])
# _net   = MultiProductQNetwork(
#     FLAGS.num_features, FLAGS.num_products, FLAGS.num_actions,
#     FLAGS.hidden_size, FLAGS.dropout_prob, FLAGS.use_group_norm, name="test"
# )
# _out   = _net(_dummy, training=False)
# print(f"Q-Network output shape: {_out.shape}")
# assert _out.shape == (2, FLAGS.num_products, FLAGS.num_actions), "Shape mismatch!"
# del _dummy, _net, _out
# print("MultiProductQNetwork defined ✓")


# ┌──────────────────────────────────────────────────────────┐
# │  NEW: Per-Product Q-Network (matches A2C cloned-agent)  │
# └──────────────────────────────────────────────────────────┘

class MultiProductQNetwork(tf.keras.Model):
    """
    Per-Product Q-Network — each product is processed INDEPENDENTLY.

    External interface:
      Input  : [B, num_features]              e.g. [B, 660]
      Output : [B, num_products, num_actions]  e.g. [B, 220, 14]

    Internally:
      1. Reshape  [B, 660] -> [B, 3, 220] -> [B, 220, 3]  (split per product)
      2. Flatten  [B, 220, 3] -> [B*220, 3]
      3. Forward  [B*220, 3] -> (Dense -> GN -> ReLU -> Dropout) x3 -> Dense(num_actions)
      4. Reshape  [B*220, 14] -> [B, 220, 14]

    The same weights are applied to every product, and product i never sees
    product j's features. Per the notebook's design notes this is intended to
    mirror the A2C Actor (3 hidden layers, GroupNorm(groups=1) + ReLU + Dropout).

    Raises
    ------
    ValueError
        If ``num_features`` is not a whole multiple of a positive
        ``num_products`` (the per-product split would otherwise be truncated
        and fail later inside ``call`` with an opaque reshape error).
    """

    def __init__(
        self,
        num_features,
        num_products,
        num_actions,
        hidden_size,
        dropout_prob=0.1,
        use_group_norm=True,
        name=None,
    ):
        if num_products <= 0 or num_features % num_products != 0:
            raise ValueError(
                f"num_features ({num_features}) must be a whole multiple of "
                f"num_products ({num_products}), and num_products must be > 0"
            )

        super().__init__(name=name)

        self.num_products      = num_products
        self.num_actions        = num_actions
        self.features_per_prod  = num_features // num_products  # 660 // 220 = 3

        # ── Per-product trunk (3 hidden layers) ─────────────────
        # Attribute names and creation order are kept stable: checkpoints and
        # get_weights()/set_weights() ordering depend on them.
        self.dense1 = tf.keras.layers.Dense(hidden_size, activation=None, name="dense1")
        self.dense2 = tf.keras.layers.Dense(hidden_size, activation=None, name="dense2")
        self.dense3 = tf.keras.layers.Dense(hidden_size, activation=None, name="dense3")

        # ── Output: num_actions Q-values per product ───────────
        self.out = tf.keras.layers.Dense(num_actions, activation=None, name="output")

        # ── Normalisation & regularisation ─────────────────────
        self._use_gn = use_group_norm
        if use_group_norm:
            self.gn1 = tfa.layers.GroupNormalization(groups=1, name="gn1")
            self.gn2 = tfa.layers.GroupNormalization(groups=1, name="gn2")
            self.gn3 = tfa.layers.GroupNormalization(groups=1, name="gn3")

        self.drop1 = tf.keras.layers.Dropout(dropout_prob)
        self.drop2 = tf.keras.layers.Dropout(dropout_prob)
        self.drop3 = tf.keras.layers.Dropout(dropout_prob)

    def _hidden_block(self, x, dense, norm, dropout, training):
        """Apply one Dense -> (GroupNorm) -> ReLU -> Dropout stage and return the result."""
        x = dense(x)
        if norm is not None:
            x = norm(x, training=training)
        x = tf.nn.relu(x)
        return dropout(x, training=training)

    def call(self, state, training=False):
        """
        Compute Q-values for every (product, action) pair.

        state: [B, F]  where F = num_products * features_per_prod  (e.g. 660)

        The flat state is read feature-major, i.e. as
        [x_0..x_{P-1}, sales_0..sales_{P-1}, q_0..q_{P-1}], and rearranged to
        [B, P, 3] where the last axis is [x_i, sales_i, q_i].

        Returns a tensor of shape [B, num_products, num_actions].
        """
        B = tf.shape(state)[0]
        P = self.num_products
        F = self.features_per_prod  # 3

        # [B, 660] -> [B, 3, 220] -> [B, 220, 3]
        state_3d = tf.reshape(state, [B, F, P])       # [B, 3, 220]
        state_3d = tf.transpose(state_3d, [0, 2, 1])  # [B, 220, 3]

        # Fold products into the batch axis so one set of weights serves all of them.
        x = tf.reshape(state_3d, [B * P, F])          # [B*P, 3]

        # The gn* layers only exist when group norm is enabled.
        if self._use_gn:
            norm1, norm2, norm3 = self.gn1, self.gn2, self.gn3
        else:
            norm1 = norm2 = norm3 = None

        # ── Per-product forward (same weights for every product) ──
        x = self._hidden_block(x, self.dense1, norm1, self.drop1, training)  # [B*P, H]
        x = self._hidden_block(x, self.dense2, norm2, self.drop2, training)  # [B*P, H]
        x = self._hidden_block(x, self.dense3, norm3, self.drop3, training)  # [B*P, H]

        q = self.out(x)                                # [B*P, A]

        # Reshape back: [B*P, A] -> [B, P, A]
        q = tf.reshape(q, [B, P, self.num_actions])    # [B, 220, 14]
        return q


# Quick sanity-check
_dummy = tf.zeros([2, FLAGS.num_features])
_net   = MultiProductQNetwork(
    FLAGS.num_features, FLAGS.num_products, FLAGS.num_actions,
    FLAGS.hidden_size, FLAGS.dropout_prob, FLAGS.use_group_norm, name="test"
)
_out   = _net(_dummy, training=False)
print(f"Q-Network output shape:  {_out.shape}")
assert _out.shape == (2, FLAGS.num_products, FLAGS.num_actions)
print(f"Parameters: {sum(v.numpy().size for v in _net.trainable_variables):,}")
del _dummy, _net, _out
print("PerProduct MultiProductQNetwork defined ✓")

Q-Network output shape:  (2, 220, 14)
Parameters: 36,110
PerProduct MultiProductQNetwork defined ✓


In [8]:
# ============================================================
# 6. EXPERIENCE REPLAY BUFFER
# ============================================================

class ReplayBuffer:
    """
    Fixed-capacity FIFO store of transitions for experience replay.

    Each stored transition is the tuple
    ``(state, action_indices, reward_vector, next_state, done)``:
      state / next_state : flat feature vector        (converted to float32 on sampling)
      action_indices     : per-product action index   (converted to int32 on sampling)
      reward_vector      : per-product reward         (converted to float32 on sampling)
      done               : scalar episode-end flag    (converted to float32 on sampling)
    """

    # Column order of a stored transition and the dtype each column is cast to.
    _FIELD_DTYPES = (np.float32, np.int32, np.float32, np.float32, np.float32)

    def __init__(self, capacity: int):
        # deque(maxlen=...) silently drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action_indices, reward_vector, next_state, done):
        """Append one transition, evicting the oldest if the buffer is full."""
        transition = (state, action_indices, reward_vector, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """
        Draw ``batch_size`` distinct transitions uniformly at random.

        Returns a 5-tuple of stacked arrays:
          states      : [B, F]   float32
          actions     : [B, P]   int32
          rewards     : [B, P]   float32
          next_states : [B, F]   float32
          dones       : [B]      float32
        """
        picked = random.sample(self.buffer, batch_size)
        return tuple(
            np.array([transition[col] for transition in picked], dtype=dtype)
            for col, dtype in enumerate(self._FIELD_DTYPES)
        )

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)


# Status line so the cell output shows the class definition was executed.
print("ReplayBuffer defined ✓")

ReplayBuffer defined ✓


In [9]:
# ============================================================
# 7. DQN AGENT
# ============================================================

class MultiProductDQNAgent:
    """
    Double-DQN agent for multi-product inventory management.

    Components:
      - Online and target ``MultiProductQNetwork`` with identical architecture
      - Adam optimizer at ``config.learning_rate``
      - Huber loss on the TD error
      - Double-DQN: online net selects the next action, target net evaluates it
      - Per-product epsilon-greedy exploration
      - Vector reward: one reward (and one TD target) per product
    """

    def __init__(self, config: Config):
        """Build both networks, the optimizer, the replay buffer and the epsilon schedule."""
        self.config = config

        # ── Online Q-network (trained every step) ────────────
        self.q_network = MultiProductQNetwork(
            config.num_features, config.num_products, config.num_actions,
            config.hidden_size, config.dropout_prob, config.use_group_norm,
            name="q_network",
        )
        # ── Target Q-network (only changed via sync_target_network) ──
        self.target_network = MultiProductQNetwork(
            config.num_features, config.num_products, config.num_actions,
            config.hidden_size, config.dropout_prob, config.use_group_norm,
            name="target_network",
        )

        # ── Optimizer ─────────────────────────────────────────
        self.optimizer = tf.optimizers.Adam(config.learning_rate)

        # ── Replay buffer ─────────────────────────────────────
        self.replay_buffer = ReplayBuffer(config.replay_buffer_size)

        # ── Epsilon-greedy parameters ─────────────────────────
        # Linear schedule: epsilon drops by `epsilon_decay` per decay_epsilon()
        # call. A non-positive episode count is treated as 1 so the division
        # cannot fail.
        self.epsilon = config.epsilon_start
        decay_span = (
            config.epsilon_decay_episodes
            if config.epsilon_decay_episodes > 0 else 1
        )
        self.epsilon_decay = (
            (config.epsilon_start - config.epsilon_end) / decay_span
        )

        # ── Progress counter stored in checkpoints; the training loop reads
        #    it back as the episode index to resume from ─────
        self.global_step = tf.Variable(0, dtype=tf.int64)

        # ── Build networks so weights exist before set_weights ─
        _dummy = tf.zeros([1, config.num_features], dtype=tf.float32)
        _ = self.q_network(_dummy, training=False)
        _ = self.target_network(_dummy, training=False)
        self.sync_target_network()

        # ── Huber loss; reduction is left to _train_step_tf ───
        self._huber = tf.keras.losses.Huber(
            reduction=tf.keras.losses.Reduction.NONE
        )

        self.action_space_arr = np.array(config.action_space, dtype=np.float32)

    # ── Target network sync ────────────────────────────────────────
    def sync_target_network(self):
        """Hard-copy weights from online → target network."""
        self.target_network.set_weights(self.q_network.get_weights())

    # ── Action selection ────────────────────────────────────────────
    def select_actions(self, state, training: bool = True):
        """
        Vectorised epsilon-greedy action selection.

        Input  : state [num_features]     (660,)
        Output : action_indices [num_products]  — NumPy int32 indices into action_space

        With ``training=False`` the greedy action is returned for every product.
        Otherwise each product independently takes a uniformly random action
        with probability ``self.epsilon``. The network is always evaluated with
        ``training=False`` (dropout off) when choosing actions.
        """
        state_batch = tf.expand_dims(
            tf.convert_to_tensor(state, dtype=tf.float32), axis=0
        )                                              # [1, F]
        q_vals = self.q_network(state_batch, training=False)[0]  # [P, A]
        greedy = tf.argmax(q_vals, axis=1, output_type=tf.int32)  # [P]

        if not training:
            return greedy.numpy()

        # Per-product random exploration
        explore = (
            tf.random.uniform([self.config.num_products]) < self.epsilon
        )
        rand_acts = tf.random.uniform(
            [self.config.num_products], 0, self.config.num_actions, dtype=tf.int32
        )
        return tf.where(explore, rand_acts, greedy).numpy()

    # ── Double-DQN train step (compiled with tf.function) ──────────
    @tf.function
    def _train_step_tf(
        self, states, actions, rewards, next_states, dones
    ):
        """
        Double-DQN update with per-product (vector) rewards.

        Tensors:
          states      [B, F]
          actions     [B, P]   int32 — action indices
          rewards     [B, P]   float32
          next_states [B, F]
          dones       [B]      float32

        Loss:
          Huber( r + γ · Q_target(s', argmax_a Q_online(s',a)),  Q_online(s,a) )

        Returns the scalar loss tensor. Only the online network's variables are
        updated; gradients are clipped to a global norm of 10.
        """
        gamma       = tf.cast(self.config.gamma, tf.float32)
        states      = tf.cast(states,      tf.float32)
        next_states = tf.cast(next_states, tf.float32)
        actions     = tf.cast(actions,     tf.int32)
        rewards     = tf.cast(rewards,     tf.float32)
        dones       = tf.cast(dones,       tf.float32)

        B = tf.shape(states)[0]
        P = self.config.num_products

        # Index helpers: [B, P]
        b_idx = tf.repeat(tf.range(B)[:, tf.newaxis], P, axis=1)  # [B, P]
        p_idx = tf.repeat(tf.range(P)[tf.newaxis, :], B, axis=0)  # [B, P]

        # ── TD target: a constant w.r.t. the online weights ────────
        # Built outside the tape and wrapped in stop_gradient so no gradient
        # can ever leak through the bootstrap term.

        # Double-DQN: online net picks best next action
        nq_online  = self.q_network(next_states, training=False)        # [B, P, A]
        best_next  = tf.argmax(nq_online, axis=2, output_type=tf.int32)  # [B, P]

        # Target net evaluates that action
        nq_target  = self.target_network(next_states, training=False)   # [B, P, A]
        g_next_idx = tf.stack([b_idx, p_idx, best_next], axis=-1)       # [B, P, 3]
        next_q     = tf.gather_nd(nq_target, g_next_idx)                # [B, P]

        # `dones` is per transition, so it is broadcast across products.
        td_target = tf.stop_gradient(
            rewards + (1.0 - dones[:, tf.newaxis]) * gamma * next_q
        )                                                               # [B, P]

        with tf.GradientTape() as tape:
            # Q(s, a) from online network  →  [B, P]
            q_all = self.q_network(states, training=True)          # [B, P, A]
            g_idx = tf.stack([b_idx, p_idx, actions], axis=-1)     # [B, P, 3]
            q_sa  = tf.gather_nd(q_all, g_idx)                     # [B, P]

            # Huber loss, mean across batch × products
            loss = tf.reduce_mean(self._huber(td_target, q_sa))

        grads = tape.gradient(loss, self.q_network.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, 10.0)  # gradient clipping
        self.optimizer.apply_gradients(
            zip(grads, self.q_network.trainable_variables)
        )
        return loss

    # ── Public train-step wrapper ──────────────────────────────────
    def train_step(self):
        """
        Sample a mini-batch and perform one gradient update.

        Returns the loss as a Python float, or None while the replay buffer
        holds fewer than ``config.min_replay_size`` transitions.
        """
        if len(self.replay_buffer) < self.config.min_replay_size:
            return None
        batch = self.replay_buffer.sample(self.config.batch_size)
        loss  = self._train_step_tf(*batch)
        return float(loss.numpy())

    # ── Epsilon decay ──────────────────────────────────────────────
    def decay_epsilon(self):
        """Lower epsilon by one linear step, never going below ``config.epsilon_end``."""
        self.epsilon = max(self.config.epsilon_end, self.epsilon - self.epsilon_decay)


# Status line so the cell output shows the class definition was executed.
print("MultiProductDQNAgent defined ✓")

MultiProductDQNAgent defined ✓


In [10]:
# ============================================================
# 8. TRAINING LOOP
#
# Follows training_220.ipynb A2C loop structure:
#   - Load all sales from TFRecordDataset using sales_parser
#   - Normalise by capacity (same as A2C)
#   - 600 episodes × 900 timesteps
#   - Same reward formula: r = 1 - z - overstock - q - quan
#   - Checkpoint every 10 episodes
# ============================================================

def _load_normalised_sales():
    """Helper: read the training sales series and divide it by per-product capacity."""
    # Load all sales data (same as A2C: TFRecordDataset → sales_parser)
    sales_rows = [
        rec["sales"].numpy()
        for rec in tf.data.TFRecordDataset(FLAGS.train_file).map(sales_parser)
    ]
    all_sales_raw = np.array(sales_rows, dtype=np.float32)  # [T, P]

    # Capacity — single record (same as A2C)
    capacity = next(
        iter(tf.data.TFRecordDataset(FLAGS.capacity_file).map(capacity_parser))
    )["capacity"].numpy()  # [P]

    # Normalise sales by capacity (same as A2C)
    return all_sales_raw / capacity[np.newaxis, :]  # [T, P]


def _run_dqn_episode(agent, all_sales, episode):
    """Helper: play one training episode and return its mean metrics as a dict."""
    T = all_sales.shape[0]  # total timesteps in dataset

    # Random initial inventory in [0, 1], same as A2C
    x = np.random.uniform(0, 1, size=FLAGS.num_products).astype(np.float32)

    # Random start index in the time-series (A2C does a window slide);
    # episode 0 always starts at the beginning of the series.
    max_start = max(0, T - FLAGS.num_timesteps - 1)
    start_idx = np.random.randint(0, max_start + 1) if episode > 0 else 0
    # One timestep is held back because every transition reads all_sales[idx + 1].
    ep_len    = min(FLAGS.num_timesteps, T - start_idx - 1)

    # Episode-level metric accumulators
    ep_rewards   = []
    ep_losses    = []
    ep_stockouts = []
    ep_waste     = []
    ep_overstock = []
    ep_quantile  = []

    # ── Timestep loop ──────────────────────────────────────────────
    for t in range(ep_len):
        idx        = start_idx + t
        sales_now  = all_sales[idx]       # [P]  current period
        sales_next = all_sales[idx + 1]   # [P]  next-period forecast

        # ── Build state (matches A2C: [x, sales, q] flat) ──────────
        q_now = waste(x)                  # [P]  waste estimate
        state = np.concatenate(
            [x, sales_now, q_now], axis=0
        ).astype(np.float32)

        # ── Action selection ────────────────────────────────────────
        action_idx = agent.select_actions(state, training=True)  # [P]  int
        actions    = agent.action_space_arr[action_idx]          # [P]  float

        # ── Environment step (same dynamics as A2C) ────────────────
        x_rep  = x + actions                           # add replenishment
        over   = np.maximum(0.0, x_rep - 1.0)          # overstock before clip
        x_clip = np.minimum(1.0, x_rep)                # clip to capacity
        x_next = np.maximum(0.0, x_clip - sales_now)   # fulfill demand

        # ── Reward ──────────────────────────────────────────────────
        # calc_reward is given the PRE-replenishment inventory `x` plus the
        # overstock caused by the action. A variant that scores the
        # post-replenishment inventory, calc_reward(x_clip, sales_now, over),
        # was tried earlier and is currently disabled.
        r, z, quan = calc_reward(x, over)
        done = 1.0 if (t == ep_len - 1) else 0.0  # terminal flag at episode end

        # ── Build next-state ────────────────────────────────────────
        q_next     = waste(x_next)
        next_state = np.concatenate(
            [x_next, sales_next, q_next], axis=0
        ).astype(np.float32)

        # ── Store transition ────────────────────────────────────────
        agent.replay_buffer.add(state, action_idx, r, next_state, done)

        # ── Collect metrics ─────────────────────────────────────────
        ep_rewards.append(float(np.mean(r)))
        ep_stockouts.append(float(np.mean(z)))
        ep_waste.append(float(np.mean(waste(x))))  # waste on old inventory (matches reward)
        ep_overstock.append(float(np.mean(over)))
        ep_quantile.append(float(quan))

        # ── Gradient update ─────────────────────────────────────────
        loss = agent.train_step()
        if loss is not None:
            ep_losses.append(loss)

        # ── Advance state ───────────────────────────────────────────
        x = x_next

    # Keys double as the W&B metric names used by train_dqn().
    return {
        "reward":    float(np.mean(ep_rewards)),
        "loss":      float(np.mean(ep_losses)) if ep_losses else 0.0,
        "stockout":  float(np.mean(ep_stockouts)),
        "waste":     float(np.mean(ep_waste)),
        "overstock": float(np.mean(ep_overstock)),
        "quantile":  float(np.mean(ep_quantile)),
    }


def train_dqn():
    """Main DQN training loop — drop-in replacement for A2C train().

    Steps:
      1. Optionally start a Weights & Biases run (``FLAGS.use_wandb``).
      2. Load the training sales and normalise them by capacity.
      3. Build a ``MultiProductDQNAgent``; if ``FLAGS.output_dir`` holds a
         checkpoint, restore it and resume from the stored episode counter.
         The replay buffer is not part of the checkpoint, so it starts empty
         after a restore.
      4. For each episode: play it, decay epsilon, sync the target network
         every ``FLAGS.target_update_freq`` episodes, and log the metrics.
      5. Checkpoint every 10 episodes, at the end of training, and when the
         run is interrupted with Ctrl-C.

    Returns:
        MultiProductDQNAgent: the trained agent, for immediate use / evaluation.

    Raises:
        ValueError: if the dataset has fewer than two timesteps or
            ``FLAGS.num_timesteps`` is below 1, since no transition could be
            formed and every metric would silently become NaN.
        KeyboardInterrupt: re-raised after an emergency checkpoint is written.
    """
    # ── Optional W&B init ──────────────────────────────────────────
    # Imported once; the handle doubles as the "is W&B enabled" flag below.
    wandb = None
    if FLAGS.use_wandb:
        import wandb
        wandb.init(
            project=FLAGS.wandb_project,
            config=vars(FLAGS),
        )

    try:
        # ── Load data ───────────────────────────────────────────────
        print("Loading data...")
        all_sales = _load_normalised_sales()
        print(f"Sales loaded: {all_sales.shape}  (timesteps × products)")

        if all_sales.shape[0] < 2 or FLAGS.num_timesteps < 1:
            raise ValueError(
                "Training needs at least 2 timesteps of sales data and "
                f"num_timesteps >= 1 (got {all_sales.shape[0]} timesteps, "
                f"num_timesteps={FLAGS.num_timesteps})."
            )

        # ── Initialise agent ────────────────────────────────────────
        print("Initialising DQN agent...")
        agent = MultiProductDQNAgent(FLAGS)

        # ── Checkpoint setup ────────────────────────────────────────
        checkpoint = tf.train.Checkpoint(
            optimizer      = agent.optimizer,
            q_network      = agent.q_network,
            target_network = agent.target_network,
            step           = agent.global_step,
        )
        ckpt_manager = tf.train.CheckpointManager(
            checkpoint, FLAGS.output_dir, max_to_keep=5
        )

        start_episode      = 0
        last_saved_episode = None  # episode count stored in the newest checkpoint
        if ckpt_manager.latest_checkpoint:
            checkpoint.restore(ckpt_manager.latest_checkpoint)
            start_episode = int(agent.global_step.numpy())
            last_saved_episode = start_episode
            # Restore epsilon to where it would be after start_episode decays
            agent.epsilon = max(
                FLAGS.epsilon_end,
                FLAGS.epsilon_start - agent.epsilon_decay * start_episode
            )
            print(f"✓ Restored checkpoint from episode {start_episode}")
        else:
            print("Starting fresh training")

        print("=" * 60)
        print("DQN Training — A2C Fair Comparison")
        print(f"{'Episodes':>12} : {FLAGS.train_episodes}")
        print(f"{'Timesteps':>12} : {FLAGS.num_timesteps}")
        print(f"{'LR':>12} : {FLAGS.learning_rate}")
        print(f"{'Hidden':>12} : {FLAGS.hidden_size}")
        print(f"{'GroupNorm':>12} : {FLAGS.use_group_norm}")
        print("=" * 60)

        # ── Episode loop ────────────────────────────────────────────
        completed_episodes = start_episode
        try:
            for episode in range(start_episode, FLAGS.train_episodes):
                stats = _run_dqn_episode(agent, all_sales, episode)

                # ── End of episode: epsilon decay, target sync, logging ──
                agent.decay_epsilon()

                if (episode + 1) % FLAGS.target_update_freq == 0:
                    agent.sync_target_network()

                completed_episodes = episode + 1

                print(
                    f"Ep {episode+1:4d}/{FLAGS.train_episodes} | "
                    f"R={stats['reward']:+.4f}  L={stats['loss']:.4f}  "
                    f"SO={stats['stockout']:.4f}  W={stats['waste']:.4f}  "
                    f"O={stats['overstock']:.4f}  Q={stats['quantile']:.4f}  "
                    f"ε={agent.epsilon:.4f}  buf={len(agent.replay_buffer)}"
                )

                if wandb is not None:
                    wandb.log({
                        "episode": episode + 1,
                        **stats,
                        "epsilon": agent.epsilon,
                        "buffer_size": len(agent.replay_buffer),
                    })

                # Checkpoint every 10 episodes (same as A2C)
                if (episode + 1) % 10 == 0:
                    agent.global_step.assign(episode + 1)
                    ckpt_manager.save()
                    last_saved_episode = episode + 1
                    print(f"  ✓ Checkpoint saved at episode {episode+1}")
        except KeyboardInterrupt:
            # Keep the work done since the last periodic checkpoint. The step
            # counter records fully completed episodes only, so a resumed run
            # replays the interrupted episode from its start.
            agent.global_step.assign(completed_episodes)
            ckpt_manager.save()
            print(f"  ✓ Checkpoint saved at episode {completed_episodes} (interrupted)")
            raise

        # ── Final checkpoint ────────────────────────────────────────
        # Skipped when the newest checkpoint already holds this episode count,
        # so a duplicate does not evict an older checkpoint (max_to_keep=5).
        if last_saved_episode != FLAGS.train_episodes:
            agent.global_step.assign(FLAGS.train_episodes)
            ckpt_manager.save()
    finally:
        # Close the W&B run on success, error and interrupt alike.
        if wandb is not None:
            wandb.finish()

    print("=" * 60)
    print(f"Training complete! Checkpoints in: {FLAGS.output_dir}")
    print("=" * 60)

    return agent   # return agent for immediate use / evaluation


# Cell-completion marker: printed once this cell has finished executing.
print("train_dqn() defined ✓")

train_dqn() defined ✓


In [11]:
# ============================================================
# 9. PREDICTION / EVALUATION
#
# Same output format as A2C predict() in training_220.ipynb:
#   stock, action, overstock, sales, stockout, capacity  (one line each per timestep)
# ============================================================

def predict_dqn(checkpoint_dir=None):
    """Roll the greedy DQN policy over the prediction data and log every step.

    The latest checkpoint found in ``checkpoint_dir`` (``FLAGS.output_dir``
    when omitted) is loaded into a fresh agent. For every record of
    ``FLAGS.predict_file`` six text lines are appended to
    ``FLAGS.output_file``: stock, action, overstock, sales, stockout and
    capacity. Stockout is written as a non-positive number and capacity as
    ``capacity / capacity``.

    Args:
        checkpoint_dir: directory to search for checkpoints; a falsy value
            selects ``FLAGS.output_dir``.

    Returns:
        None. When no checkpoint exists a message is printed and nothing is
        written.
    """
    if not checkpoint_dir:
        checkpoint_dir = FLAGS.output_dir

    def _fmt(label, values):
        # "<label>:<v0>,<v1>,...\n" — same line format as A2C predict()
        return label + ":" + ",".join(map(str, values)) + "\n"

    # ── Input pipelines for the test period ─────────────────────────
    sales_dataset    = tf.data.TFRecordDataset(FLAGS.predict_file).map(sales_parser)
    capacity_dataset = tf.data.TFRecordDataset(FLAGS.capacity_file).map(capacity_parser)
    stock_dataset    = tf.data.TFRecordDataset(FLAGS.stock_file).map(stock_parser)

    capacity = next(iter(capacity_dataset))["capacity"]  # [P]
    stock    = next(iter(stock_dataset))["stock"]        # [P]  initial stock

    # ── Agent + weights ─────────────────────────────────────────────
    print("Initialising agent for prediction...")
    agent = MultiProductDQNAgent(FLAGS)
    ckpt = tf.train.Checkpoint(q_network=agent.q_network, step=agent.global_step)
    manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)

    latest = manager.latest_checkpoint
    if not latest:
        print("✗ No checkpoint found. Aborting.")
        return
    # Only the Q-network and the step counter are restored here.
    ckpt.restore(latest).expect_partial()
    print(f"✓ Loaded: {latest}")

    # ── Predict & write ─────────────────────────────────────────────
    print(f"Writing predictions to {FLAGS.output_file}...")
    with open(FLAGS.output_file, "w") as writer:
        # Identical for every timestep, so it is formatted once.
        capacity_line = _fmt("capacity", (capacity / capacity).numpy())

        for record in sales_dataset:
            demand   = tf.divide(record["sales"], capacity)  # normalise by capacity
            stock_np = stock.numpy()
            q        = waste(stock_np)

            state      = np.concatenate([stock_np, demand.numpy(), q], axis=0)
            greedy_idx = agent.select_actions(state, training=False)
            u = tf.constant(agent.action_space_arr[greedy_idx], dtype=tf.float32)

            replenished = stock + u
            overstock   = tf.maximum(0.0, replenished - 1.0)
            capped      = tf.minimum(1.0, replenished)
            balance     = capped - demand
            stockout    = tf.minimum(0.0, balance)

            writer.writelines([
                _fmt("stock", stock_np),
                _fmt("action", u.numpy()),
                _fmt("overstock", overstock.numpy()),
                _fmt("sales", demand.numpy()),
                _fmt("stockout", stockout.numpy()),
                capacity_line,
            ])

            # Inventory carried into the next record.
            stock = tf.maximum(0.0, balance)

    print(f"✓ Prediction complete — {FLAGS.output_file}")


# Cell-completion marker: printed once this cell has finished executing.
print("predict_dqn() defined ✓")

predict_dqn() defined ✓


In [12]:
# ============================================================
# 10. DATA FILE CHECK
# ============================================================

# The four inputs the training and prediction cells read from disk.
files = [
    FLAGS.train_file,
    FLAGS.capacity_file,
    FLAGS.stock_file,
    FLAGS.predict_file,
]

rule = "=" * 50
print("Checking data files:")
print(rule)

# Probe every path first, then report them in the same order.
present = [os.path.exists(fp) for fp in files]
for fp, ok in zip(files, present):
    mark = "✓" if ok else "✗"
    print(f"  {mark}  {fp}")
all_ok = all(present)

print(rule)
if not all_ok:
    print(
        "✗ Missing files. Run prepare_data.py first:\n"
        "  python prepare_data.py --number_of_products 220 --middle_time_period 900 \\\n"
        "    --train_tfrecords_file data220/train.tfrecords \\\n"
        "    --test_tfrecords_file data220/test.tfrecords \\\n"
        "    --capacity_tfrecords_file data220/capacity.tfrecords \\\n"
        "    --stock_tfrecords_file data220/stock.tfrecords"
    )
else:
    print("✓ All data files present — ready to train!")

Checking data files:
  ✓  data220/train.tfrecords
  ✓  data220/capacity.tfrecords
  ✓  data220/stock.tfrecords
  ✓  data220/test.tfrecords
✓ All data files present — ready to train!


In [13]:
# ============================================================
# 11. RUN TRAINING
# ============================================================

# Set FLAGS.use_wandb = True above if you want W&B tracking.

# Reset first so a failed or interrupted run cannot leave a stale agent from
# an earlier execution of this cell in the kernel namespace.
trained_agent = None

print("=" * 60)
print("STARTING DQN TRAINING (A2C fair-comparison mode)")
print("=" * 60)

try:
    trained_agent = train_dqn()
except KeyboardInterrupt:
    # Manual stop (Ctrl-C / kernel interrupt): report it without a traceback.
    print("\n" + "=" * 60)
    print("Training interrupted — checkpoints saved.")
    print("=" * 60)
except Exception as exc:
    # Show the full traceback but keep the notebook usable afterwards;
    # trained_agent stays None in this case.
    import traceback
    print("\n" + "=" * 60)
    print(f"ERROR: {exc}")
    traceback.print_exc()
    print("=" * 60)

STARTING DQN TRAINING (A2C fair-comparison mode)
Loading data...
Sales loaded: (900, 220)  (timesteps × products)
Initialising DQN agent...
✓ Restored checkpoint from episode 260
DQN Training — A2C Fair Comparison
    Episodes : 600
   Timesteps : 900
          LR : 0.001
      Hidden : 128
   GroupNorm : True
Ep  261/600 | R=+0.4661  L=0.0000  SO=0.0053  W=0.0207  O=0.0372  Q=0.4708  ε=0.3540  buf=899
Ep  262/600 | R=+0.4719  L=0.0302  SO=0.0054  W=0.0206  O=0.0376  Q=0.4645  ε=0.3516  buf=1798
Ep  263/600 | R=+0.4751  L=0.0261  SO=0.0051  W=0.0207  O=0.0377  Q=0.4614  ε=0.3491  buf=2697
Ep  264/600 | R=+0.4718  L=0.0263  SO=0.0054  W=0.0207  O=0.0374  Q=0.4647  ε=0.3466  buf=3596
Ep  265/600 | R=+0.4746  L=0.0261  SO=0.0054  W=0.0207  O=0.0375  Q=0.4618  ε=0.3441  buf=4495
Ep  266/600 | R=+0.4779  L=0.0269  SO=0.0052  W=0.0207  O=0.0380  Q=0.4583  ε=0.3417  buf=5394
Ep  267/600 | R=+0.4772  L=0.0274  SO=0.0052  W=0.0207  O=0.0375  Q=0.4594  ε=0.3392  buf=6293
Ep  268/600 | R=+0.4791 

In [14]:
# ============================================================
# 12. RUN PREDICTION
#     (run this cell AFTER training, or after restoring a checkpoint)
# ============================================================

import traceback

# Run inference; on failure print a short message followed by the traceback
# instead of letting the exception abort the cell.
try:
    predict_dqn()
except Exception as err:
    print(f"ERROR: {err}")
    traceback.print_exc()

Initialising agent for prediction...
✓ Loaded: checkpoints_dqn_comparison3\ckpt-50
Writing predictions to ./output_dqn_comparison.csv3...
✓ Prediction complete — ./output_dqn_comparison.csv3
