# MARL Training Demo – Actor-Critic on FJSP

This notebook runs a short training session of the simple actor-critic policy on the toy Brandimarte-style FJSP instance and visualizes learning curves and policy behaviour.

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from src.seed_utils import SeedConfig, set_global_seeds
from src.ac_training import TrainingConfig, run_training

# One seed value shared by the global seeding call and the training configuration
BASE_SEED = 2026

# Seed before any configuration objects are built, for notebook reproducibility
set_global_seeds(SeedConfig(base_seed=BASE_SEED))

# Toy Brandimarte-style instance used throughout the notebook
instance_path = Path("data/brandimarte_mk_toy.txt")

# Short CPU run: 100 episodes, at most 64 steps each, log_interval of 20
train_config = TrainingConfig(
    instance_path=instance_path,
    seed_config=SeedConfig(base_seed=BASE_SEED),
    num_episodes=100,
    max_steps_per_episode=64,
    log_interval=20,
    device="cpu",
)
train_config

In [None]:
# Smoke test: a tiny training run (5 episodes, 32 steps max, its own seed) to check
# that run_training executes end to end before the longer run below
SMOKE_EPISODES = 5

small_config = TrainingConfig(
    instance_path=instance_path,
    seed_config=SeedConfig(base_seed=2027),
    num_episodes=SMOKE_EPISODES,
    max_steps_per_episode=32,
    log_interval=SMOKE_EPISODES,  # same value as num_episodes
    device="cpu",
)
small_metrics = run_training(small_config)

# Display the per-episode rewards collected during the smoke run
small_metrics["episode_rewards"]

In [None]:
# Main (short) training run using the configuration defined at the top of the notebook
metrics = run_training(train_config)

# Convert each logged series to a NumPy array for slicing and plotting.
# NOTE(review): the original comment describes episode_rewards as "negative makespan sums";
# confirm against run_training.
episode_rewards, episode_makespans, losses, value_losses = (
    np.array(metrics[series_name])
    for series_name in ("episode_rewards", "episode_makespans", "losses", "value_losses")
)

# Peek at the first five episodes
(episode_rewards[:5], episode_makespans[:5])

In [None]:
# Learning curves: reward, makespan and losses in three stacked panels sharing the episode axis
fig, axes = plt.subplots(3, 1, figsize=(10, 10), sharex=True)
reward_ax, makespan_ax, loss_ax = axes

# 1-based episode index for the x-axis
episodes = np.arange(1, len(episode_rewards) + 1)

reward_ax.plot(episodes, episode_rewards, label="Episode reward (sum of step rewards)")
makespan_ax.plot(episodes, episode_makespans, label="Final makespan", color="tab:orange")
loss_ax.plot(episodes, losses, label="Total loss")
loss_ax.plot(episodes, value_losses, label="Value loss", linestyle="--")

# Label every panel and attach its legend once all curves are drawn
for panel, panel_ylabel in zip(axes, ("Reward", "Makespan", "Loss")):
    panel.set_ylabel(panel_ylabel)
    panel.legend()

# Only the bottom panel carries the shared x-axis label
loss_ax.set_xlabel("Episode")

fig.tight_layout()
plt.show()

## Policy behaviour visualisation

Load the last checkpoint and run a few evaluation episodes with a deterministic policy to visualise schedules.

In [None]:
import torch

from src.fjsp_env import FJSPEnv, FJSPEnvConfig
from src.graph_builder import build_graph_from_env_state
from src.marl_policy import FJSPActorCritic

# Rebuild the path of the checkpoint for the final training episode
# (assumes run_training saved one under this naming scheme — TODO confirm)
ckpt_name = f"{train_config.checkpoint_prefix}_ep{train_config.num_episodes}.pt"
ckpt_path = train_config.checkpoint_dir / ckpt_name
checkpoint = torch.load(ckpt_path, map_location="cpu")

# Restore the saved weights into a freshly constructed policy and call .eval() on it
policy_eval = FJSPActorCritic()
policy_eval.load_state_dict(checkpoint["model_state_dict"])
policy_eval.eval()

# Dedicated evaluation environment on the same instance, with its own seed
env_cfg = FJSPEnvConfig(
    instance_path=instance_path,
    seed_config=SeedConfig(base_seed=3030),
)
env = FJSPEnv(env_cfg)

def run_eval_episode(env: FJSPEnv, policy: FJSPActorCritic, max_steps: int = 64):
    """Roll out one evaluation episode with deterministic action selection.

    Args:
        env: Environment to evaluate on; it is reset at the start of the call.
        policy: Actor-critic whose ``get_action_and_value`` chooses each action
            (called with ``deterministic=True``).
        max_steps: Upper bound on the number of environment steps. Defaults to
            64, the limit that was previously hard-coded in the loop.

    Returns:
        A ``(trajectory, schedule, makespan)`` tuple. ``trajectory`` holds one
        ``(step, feasible_actions, action_idx, reward, done)`` tuple per step
        taken; ``schedule`` and ``makespan`` are ``env.last_schedule`` and
        ``env.last_makespan`` read after the rollout stops.
    """
    obs = env.reset()
    done = False
    step = 0
    trajectory = []

    # Evaluation never backpropagates, so disable autograd tracking for the rollout.
    with torch.no_grad():
        while not done and step < max_steps:
            # The graph is built from the env's internal per-step state, which
            # must be populated at this point.
            assert env._step_jobs is not None
            assert env._step_machines is not None

            graph = build_graph_from_env_state(env._step_jobs, env._step_machines)
            feasible_actions = obs["feasible_actions"]

            # Only the chosen action is needed here; logits and value are discarded.
            action_idx, _logits, _value = policy.get_action_and_value(
                graph, feasible_actions, deterministic=True
            )

            obs, reward, done, _info = env.step(action_idx)
            trajectory.append((step, feasible_actions, action_idx, reward, done))
            step += 1

    schedule = env.last_schedule
    makespan = env.last_makespan
    return trajectory, schedule, makespan

# Run one deterministic evaluation episode, then show its step count and makespan
eval_rollout = run_eval_episode(env, policy_eval)
traj, schedule, makespan = eval_rollout
(len(traj), makespan)

In [None]:
# Gantt-style view of the evaluated schedule: one row per machine, one bar per operation
fig, ax = plt.subplots(figsize=(8, 4))

# Rows are ordered by machine id
machine_ids = sorted(schedule.keys())

for row, m_id in enumerate(machine_ids):
    for job_id, op_idx, start, end in schedule[m_id]:
        # Bar spans [start, end] on this machine's row, labelled with job and operation index
        ax.barh(row, end - start, left=start, edgecolor="black")
        ax.text(start + 0.1, row, f"J{job_id}-O{op_idx}", va="center", ha="left", fontsize=8)

# One tick per machine row, labelled M<id>
yticks = list(range(len(machine_ids)))
yticklabels = [f"M{machine}" for machine in machine_ids]

ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Time")
ax.set_title(f"Schedule Gantt chart (makespan={makespan:.2f})")
fig.tight_layout()
plt.show()