In [9]:
import json
import numpy as np
import torch
from torch_geometric.data import Data
from job_shop_lib.dispatching.feature_observers import (
    feature_observer_factory,
    FeatureObserverType,
)
from job_shop_lib.graphs import build_agent_task_graph
from job_shop_lib.reinforcement_learning import (
    SingleJobShopGraphEnv,
    ObservationDict,
    ObservationSpaceKey,
)
from job_shop_lib import Schedule, Operation
from gnn_scheduler.utils import get_data_path

In [6]:
DATA_PATH = get_data_path()

# Load the stored schedule records. A context manager guarantees the file
# handle is closed once parsing finishes (json.load(open(...)) leaks it).
with open(
    DATA_PATH / "small_random_instances_0.json", encoding="utf-8"
) as schedules_file:
    schedules_json = json.load(schedules_file)

In [7]:
# Sanity check: number of schedule records loaded from the JSON file.
len(schedules_json)

100000

In [10]:
# Feature observers the environment uses to build node features for each
# observation (passed as `feature_observer_configs`).
# NOTE(review): create_action_probabilities reads column 0 of the job
# features as a 0/1 readiness flag, but no IS_READY observer is listed here.
# Confirm which observer actually populates that column.
features_observers_types = [
    FeatureObserverType.DURATION,
    FeatureObserverType.EARLIEST_START_TIME,
    FeatureObserverType.IS_COMPLETED,
    FeatureObserverType.IS_SCHEDULED,
    FeatureObserverType.POSITION_IN_JOB,
    FeatureObserverType.REMAINING_OPERATIONS,
]

In [8]:
# Build an environment for the first stored schedule so it can be inspected.
# (Explicit indexing replaces the former `for ...: ... break` loop.)
schedule_dict = schedules_json[0]
schedule = Schedule.from_dict(**schedule_dict)
graph = build_agent_task_graph(schedule.instance)

# The environment is built from the job shop graph, not from the Schedule:
# `graph` was computed above but previously never used, while the Schedule
# was passed in its place.
# NOTE(review): confirm the first parameter of SingleJobShopGraphEnv is the
# graph in the installed job_shop_lib version.
env = SingleJobShopGraphEnv(
    graph, feature_observer_configs=features_observers_types
)

[[S-Op(operation=O(m=0, d=40, j=2, p=0), start_time=0, machine_id=0), S-Op(operation=O(m=0, d=91, j=4, p=1), start_time=40, machine_id=0), S-Op(operation=O(m=0, d=97, j=1, p=2), start_time=131, machine_id=0), S-Op(operation=O(m=0, d=91, j=5, p=0), start_time=228, machine_id=0), S-Op(operation=O(m=0, d=75, j=0, p=4), start_time=319, machine_id=0), S-Op(operation=O(m=0, d=8, j=3, p=4), start_time=472, machine_id=0)], [S-Op(operation=O(m=1, d=65, j=1, p=0), start_time=0, machine_id=1), S-Op(operation=O(m=1, d=94, j=2, p=1), start_time=65, machine_id=1), S-Op(operation=O(m=1, d=52, j=0, p=2), start_time=159, machine_id=1), S-Op(operation=O(m=1, d=43, j=4, p=3), start_time=211, machine_id=1), S-Op(operation=O(m=1, d=82, j=3, p=1), start_time=254, machine_id=1), S-Op(operation=O(m=1, d=31, j=5, p=2), start_time=344, machine_id=1)], [S-Op(operation=O(m=2, d=6, j=0, p=0), start_time=0, machine_id=2), S-Op(operation=O(m=2, d=37, j=1, p=1), start_time=65, machine_id=2), S-Op(operation=O(m=2, d=7

In [None]:
def _uniform_over_jobs(job_ids: list, num_jobs: int) -> torch.Tensor:
    """Return a length-``num_jobs`` tensor spreading mass 1 over ``job_ids``.

    ``job_ids`` is expected to contain distinct job indices. An empty list
    yields an all-zero tensor.
    """
    probs = torch.zeros(num_jobs)
    if job_ids:
        probs[job_ids] = 1.0 / len(job_ids)
    return probs


def create_action_probabilities(
    env: SingleJobShopGraphEnv,
    schedule: Schedule,
    observation: ObservationDict,
) -> torch.Tensor:
    """
    Creates a target distribution over jobs from a solved (reference) schedule.

    For every machine in ``schedule``, the first scheduled operation (in list
    order) whose start time is at or after the dispatcher's current time is
    taken as that machine's "next" operation. Ready jobs owning such an
    operation share the probability mass uniformly.

    Args:
        env: The job shop environment. ``env.instance.num_jobs`` sets the
            output length and ``env.dispatcher.current_time()`` the cutoff.
        schedule: The solved schedule whose per-machine sequences are used
            as the reference.
        observation: Current observation from the environment. Column 0 of
            ``observation[ObservationSpaceKey.JOBS]`` is read as a readiness
            flag (``1.0`` means ready).

    Returns:
        torch.Tensor: Tensor of length ``num_jobs`` that sums to 1. It is
        uniform over the distinct "optimal" ready jobs when any exist, and
        otherwise uniform over all ready jobs. It is all zeros (and a
        warning is emitted) when no job is ready.
    """
    import warnings

    num_jobs = env.instance.num_jobs

    # NOTE(review): assumes column 0 of the job features is a 0/1 "is ready"
    # flag. Job features come from the configured feature observers, so this
    # holds only if the first job-level observer reports readiness. Confirm.
    ready_jobs = [
        job_id
        for job_id, job_features in enumerate(
            observation[ObservationSpaceKey.JOBS]
        )
        if job_features[0] == 1.0
    ]

    if not ready_jobs:
        # No legal actions - shouldn't happen, but return a zero vector
        # instead of failing.
        warnings.warn("No legal actions")
        return torch.zeros(num_jobs)

    ready_set = set(ready_jobs)
    current_time = env.dispatcher.current_time()

    candidate_jobs = []
    for machine_schedule in schedule.schedule:
        # First operation on this machine (in list order) that has not
        # started before the current time.
        next_op = next(
            (op for op in machine_schedule if op.start_time >= current_time),
            None,
        )
        # NOTE(review): only the job id is checked, not that `next_op` is
        # that job's next unscheduled operation. A ready job may be counted
        # because a *later* operation of it is next on some machine.
        if next_op is not None and next_op.job_id in ready_set:
            candidate_jobs.append(next_op.job_id)

    # A job can be "next" on several machines at once. De-duplicate (keeping
    # first-seen order) so the probabilities sum to 1 instead of leaving mass
    # unassigned.
    optimal_jobs = list(dict.fromkeys(candidate_jobs))

    if not optimal_jobs:
        # The reference schedule gives no usable signal here; fall back to a
        # uniform distribution over all ready jobs.
        return _uniform_over_jobs(ready_jobs, num_jobs)

    return _uniform_over_jobs(optimal_jobs, num_jobs)