In [2]:
import json
import numpy as np
import torch
from torch_geometric.data import Data
from job_shop_lib.dispatching.feature_observers import (
    feature_observer_factory,
    FeatureObserverType,
)
from job_shop_lib.graphs import build_resource_task_graph
from job_shop_lib.reinforcement_learning import (
    SingleJobShopGraphEnv,
    ObservationDict,
    ObservationSpaceKey,
)
from job_shop_lib.dispatching import Dispatcher, OptimalOperationsObserver
from job_shop_lib import Schedule, Operation
from gnn_scheduler.utils import get_data_path

In [3]:
DATA_PATH = get_data_path()
schedules_json = json.load(open(DATA_PATH / "small_random_instances_0.json"))

In [4]:
len(schedules_json)

100000

In [5]:
features_observers_types = [
    FeatureObserverType.DURATION,
    FeatureObserverType.EARLIEST_START_TIME,
    FeatureObserverType.IS_COMPLETED,
    FeatureObserverType.IS_SCHEDULED,
    FeatureObserverType.POSITION_IN_JOB,
    FeatureObserverType.REMAINING_OPERATIONS,
]

In [6]:
def get_optimal_actions(
    optimal_ops_observer: OptimalOperationsObserver,
    available_operations_with_ids: list[tuple[int, int, int]],
):
    """Creates a probability distribution over actions based on the optimal
    schedule.

    Args:
        optimal_ops_observer: The observer that provides optimal operations.
        available_operations_with_ids: List of available operations with their
        IDs (operation_id, machine_id, job_id).

    Returns:
        dict: A dictionary containing the action probabilities of each tuple
        (operation_id, machine_id, job_id) in the available actions
    """
    optimal_actions = {}
    optimal_ops = optimal_ops_observer.optimal_available
    optimal_ops_ids = [
        (op.operation_id, op.machine_id, op.job_id) for op in optimal_ops
    ]
    for operation_id, machine_id, job_id in available_operations_with_ids:
        is_optimal = (operation_id, machine_id, job_id) in optimal_ops_ids
        optimal_actions[(operation_id, machine_id, job_id)] = int(is_optimal)
    return optimal_actions

In [10]:
dataset = {}
for schedule_dict in schedules_json:
    observations = []
    action_probabilities_sequence = []
    schedule = Schedule.from_dict(**schedule_dict)
    graph = build_resource_task_graph(schedule.instance)
    env = SingleJobShopGraphEnv(
        graph, feature_observer_configs=features_observers_types
    )
    optimal_ops_observer = OptimalOperationsObserver(env.dispatcher, schedule)
    obs, info = env.reset()
    done = False
    while not done:
        action_probs = get_optimal_actions(
            optimal_ops_observer, info["available_operations_with_ids"]
        )
        if len(action_probs) > 1:
            observations.append(obs)
            action_probabilities_sequence.append(action_probs)
        _, machine_id, job_id = max(action_probs, key=action_probs.get)
        obs, reward, done, _, info = env.step((job_id, machine_id))
    makespan = env.dispatcher.schedule.makespan()
    assert makespan == schedule.makespan()
    dataset[schedule.instance.name] = (
        observations,
        action_probabilities_sequence,
        makespan,
    )
    print(len(env.reward_function.rewards))

    break

30


In [11]:
dataset["small_random_instance_1"]

([{'removed_nodes': array([False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False]),
   'edge_index': array([[ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,
            3,  3,  3,  3,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  6,  6,
            6,  6,  6,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  9,  9,  9,
            9,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12,
           12, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15,
           16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 19,
           19, 19, 19, 19, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 22, 22,
           22, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 25, 25, 25,
           25, 25, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28,
       

In [13]:
len(dataset["small_random_instance_1"][0])

29