In [1]:
# Reload edited project modules (e.g. gnn_scheduler) automatically before each
# cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import json
import numpy as np
import torch
from torch_geometric.data import HeteroData
from job_shop_lib.dispatching.feature_observers import (
    FeatureObserverType,
)
from job_shop_lib.graphs import build_resource_task_graph
from job_shop_lib.reinforcement_learning import (
    SingleJobShopGraphEnv,
    ResourceTaskGraphObservation,
    get_optimal_actions,
)
from job_shop_lib.dispatching import OptimalOperationsObserver
from job_shop_lib import Schedule
from gnn_scheduler.utils import get_data_path

In [3]:
DATA_PATH = get_data_path()
SCHEDULES_FILE = DATA_PATH / "small_random_instances_0.json"

# Use a context manager so the file handle is closed deterministically, and pin
# the encoding: JSON is UTF-8 by spec, whereas ``open`` otherwise falls back to
# the platform's locale encoding.
with open(SCHEDULES_FILE, encoding="utf-8") as schedules_file:
    schedules_json = json.load(schedules_file)

In [4]:
# Number of schedule records loaded (each is later passed to Schedule.from_dict).
len(schedules_json)

100000

In [5]:
# Feature observers attached to every environment built below.
# NOTE(review): the ``info["feature_names"]`` output further down lists the
# features in this same order, and not every observer applies to every node
# type (operations have no RemainingOperations column, machines and jobs no
# PositionInJob). The column indices in the *_FEATURES_TO_NORMALIZE lists
# presumably depend on this order, so keep them in sync when editing it.
features_observers_types = [
    getattr(FeatureObserverType, observer_name)
    for observer_name in (
        "DURATION",
        "EARLIEST_START_TIME",
        "IS_SCHEDULED",
        "POSITION_IN_JOB",
        "REMAINING_OPERATIONS",
    )
]

In [6]:
# Column indices, per node type, of the features to scale by their per-column
# maximum. Binary features such as ``IsScheduled`` are skipped: dividing a 0/1
# column by its maximum (or by 1 when the maximum is 0) leaves it unchanged.
#
# The layout follows the ``info["feature_names"]`` output shown further down and
# the observed matrix widths (operation x: 8 columns, machine x: 4 columns):
#   operation: Duration, EarliestStartTime, IsScheduled, PositionInJob,
#              Job Duration, Job EarliestStartTime, Job IsScheduled,
#              Job RemainingOperations
#   machine:   Duration, EarliestStartTime, IsScheduled, RemainingOperations
# NOTE(review): assumes each operation row is its 4 operation features followed
# by the 4 features of its job — confirm against ResourceTaskGraphObservation.
# The data-generation loop currently calls ``_normalize_features`` with its
# default (all columns) rather than with this mapping.
OPERATION_FEATURES_TO_NORMALIZE = [
    0,  # Duration
    1,  # EarliestStartTime
    3,  # PositionInJob
    4,  # Job duration
    5,  # Job earliest start time
    7,  # Job remaining operations
]
MACHINE_FEATURES_TO_NORMALIZE = [
    0,  # Duration
    1,  # EarliestStartTime
    3,  # RemainingOperations
]

features_to_normalize = {
    "operation": OPERATION_FEATURES_TO_NORMALIZE,
    "machine": MACHINE_FEATURES_TO_NORMALIZE,
}

In [7]:
def _normalize_features(
    node_features_dict: dict[str, np.ndarray],
    indices_to_normalize: dict[str, list[int]] | None = None,
):
    if indices_to_normalize is None:
        indices_to_normalize = {
            "operation": np.arange(8),
            "machine": np.arange(4),
        }
    for key, indices in indices_to_normalize.items():
        # Divide by the maximum value checking for division by zero
        max_values = np.max(node_features_dict[key], axis=0)
        max_values[max_values == 0] = 1
        node_features_dict[key][:, indices] /= max_values[indices]

    return node_features_dict

In [None]:
# Replay every reference schedule through the environment and record, at each
# decision point with more than one candidate, the normalized observation and
# the action-probability target returned by ``get_optimal_actions``.
dataset = {}
for schedule_dict in schedules_json:
    observations = []
    action_probabilities_sequence = []
    schedule = Schedule.from_dict(**schedule_dict)
    graph = build_resource_task_graph(schedule.instance)
    env = SingleJobShopGraphEnv(
        graph, feature_observer_configs=features_observers_types
    )
    env = ResourceTaskGraphObservation(env)
    optimal_ops_observer = OptimalOperationsObserver(
        env.unwrapped.dispatcher, schedule
    )
    obs, info = env.reset()
    done = False
    while not done:
        action_probs = get_optimal_actions(
            optimal_ops_observer, info["available_operations_with_ids"]
        )
        # Only record real choices. Presumably there is one entry per
        # available operation, so a single entry is a forced move.
        if len(action_probs) > 1:
            obs["node_features_dict"] = _normalize_features(
                obs["node_features_dict"]
            )
            observations.append(obs)
            action_probabilities_sequence.append(action_probs)
        # Follow the reference schedule by taking the highest-probability
        # candidate. Keys are 3-tuples whose last two items are the machine
        # and job ids (the first, presumably the operation id, is unused).
        _, machine_id, job_id = max(action_probs, key=action_probs.get)
        # Gymnasium-style step: stop on termination *or* truncation so a
        # truncated episode cannot keep looping.
        obs, reward, terminated, truncated, info = env.step(
            (job_id, machine_id)
        )
        done = terminated or truncated
    makespan = env.unwrapped.dispatcher.schedule.makespan()
    # Replaying the reference actions must reproduce the reference makespan.
    assert makespan == schedule.makespan()
    # NOTE(review): instances that share a name overwrite each other here.
    dataset[schedule.instance.name] = (
        observations,
        action_probabilities_sequence,
    )

In [21]:
# Inspect the ``info`` dict from the final ``env.step`` of the generation loop.
# This is state leaked from the loop, so it reflects only the last schedule
# processed. ``feature_names`` lists the feature names per feature type.
info

{'feature_names': defaultdict(list,
             {<FeatureType.OPERATIONS: 'operations'>: ['Duration',
               'EarliestStartTime',
               'IsScheduled',
               'PositionInJob'],
              <FeatureType.MACHINES: 'machines'>: ['Duration',
               'EarliestStartTime',
               'IsScheduled',
               'RemainingOperations'],
              <FeatureType.JOBS: 'jobs'>: ['Duration',
               'EarliestStartTime',
               'IsScheduled',
               'RemainingOperations']}),
 'available_operations_with_ids': []}

In [9]:
# Inspect one generated instance: the last key inserted into ``dataset`` (dicts
# preserve insertion order). Reading the key from ``dataset`` itself avoids
# depending on the ``schedule`` variable leaked from the generation loop.
assert dataset, "dataset is empty; run the generation cell first"
example_instance_name = next(reversed(dataset))
observations, action_probabilities_sequence = dataset[example_instance_name]
assert len(observations) == len(action_probabilities_sequence)

In [20]:
def _to_hetero_data(obs, action_probs) -> HeteroData:
    """Convert one recorded observation and its target into a ``HeteroData``.

    Args:
        obs: Observation dict with ``"node_features_dict"`` (node type ->
            feature matrix, stored as ``data[node_type].x``) and
            ``"edge_index_dict"`` (edge type -> edge-index array, stored as
            ``data[edge_type].edge_index``).
        action_probs: Mapping from candidate action tuple to its target
            probability.

    Returns:
        A ``HeteroData`` whose graph-level ``y`` holds the probabilities and
        ``valid_pairs`` holds the matching candidate tuples, in the same order.
    """
    hetero_data = HeteroData()
    for node_type, features in obs["node_features_dict"].items():
        # Cast to torch's default float32 so features match model weights.
        hetero_data[node_type].x = torch.from_numpy(features).float()
    for edge_type, edge_index in obs["edge_index_dict"].items():
        # PyG message passing expects int64 edge indices.
        hetero_data[edge_type].edge_index = torch.from_numpy(edge_index).long()
    hetero_data["y"] = torch.tensor(list(action_probs.values()))
    hetero_data["valid_pairs"] = torch.tensor(list(action_probs.keys()))
    return hetero_data


hetero_dataset = [
    _to_hetero_data(obs, action_probs)
    for obs, action_probs in zip(
        observations, action_probabilities_sequence, strict=True
    )
]

In [11]:
# Edge-index tensors of the first sample, keyed by (src, relation, dst) type.
hetero_dataset[0].edge_index_dict

{('operation',
  'to',
  'operation'): tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  4,  4,
           4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,  7,  7,  8,  8,  8,  8,
           9,  9,  9,  9, 10, 10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 13, 13,
          13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 17, 17, 17, 17,
          18, 18, 18, 18, 19, 19, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 22, 22,
          22, 22, 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26,
          27, 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29],
         [ 1,  2,  3,  4,  0,  2,  3,  4,  0,  1,  3,  4,  0,  1,  2,  4,  0,  1,
           2,  3,  6,  7,  8,  9,  5,  7,  8,  9,  5,  6,  8,  9,  5,  6,  7,  9,
           5,  6,  7,  8, 11, 12, 13, 14, 10, 12, 13, 14, 10, 11, 13, 14, 10, 11,
          12, 14, 10, 11, 12, 13, 16, 17, 18, 19, 15, 17, 18, 19, 15, 16, 18, 19,
          15, 16, 17, 19, 15, 16, 17, 18, 21, 22, 23, 24, 20, 22, 23, 24, 20, 21,
 

In [22]:
from gnn_scheduler.model import ResidualSchedulingGNN, HeteroMetadata

node_types = ["operation", "machine"]
metadata = HeteroMetadata(node_types=node_types)

# Smoke test: one forward pass of an untrained model on the first sample.
example = hetero_dataset[0]
# Derive each node type's input width from the data instead of hard-coding it,
# so the model stays in sync with the configured feature observers.
in_channels_dict = {
    node_type: example[node_type].x.shape[1] for node_type in node_types
}
model = ResidualSchedulingGNN(
    metadata=metadata, in_channels_dict=in_channels_dict
)

# One score row per candidate in ``valid_pairs``. The name ``a`` is kept
# because a later cell uses it.
a = model(example.x_dict, example.edge_index_dict, example.valid_pairs)

In [14]:
# Summary of the first sample: node feature shapes, edge counts per relation,
# and the ``y`` / ``valid_pairs`` targets.
hetero_dataset[0]

HeteroData(
  y=[6],
  valid_pairs=[6, 3],
  operation={ x=[30, 8] },
  machine={ x=[5, 4] },
  (operation, to, operation)={ edge_index=[2, 120] },
  (operation, to, machine)={ edge_index=[2, 30] },
  (machine, to, operation)={ edge_index=[2, 30] },
  (machine, to, machine)={ edge_index=[2, 20] }
)

In [24]:
# Scratch check: concatenating two score tensors along dim 0 stacks their rows
# (6 + 6 -> 12 rows here). NOTE(review): presumably exploring how per-sample
# scores would be combined for batching; remove once no longer needed.
torch.cat([a, a], dim=0)

tensor([[ 0.0851],
        [-0.1582],
        [-0.0196],
        [-0.3940],
        [ 0.0272],
        [-0.2747],
        [ 0.0851],
        [-0.1582],
        [-0.0196],
        [-0.3940],
        [ 0.0272],
        [-0.2747]], grad_fn=<CatBackward0>)