In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import random
import numpy as np
import torch
import tqdm
from job_shop_lib.dispatching.feature_observers import (
    FeatureObserverType,
)
from job_shop_lib.graphs import build_resource_task_graph
from job_shop_lib.reinforcement_learning import (
    SingleJobShopGraphEnv,
    ResourceTaskGraphObservation,
    get_optimal_actions,
)
from job_shop_lib.dispatching import OptimalOperationsObserver
from job_shop_lib import Schedule
from gnn_scheduler.utils import get_data_path
from gnn_scheduler.data import JobShopData

In [3]:
DATA_PATH = get_data_path()
schedules_json = json.load(open(DATA_PATH / "small_random_instances_0.json"))

In [4]:
# from job_shop_lib.constraint_programming import ORToolsSolver
# from job_shop_lib.generation import GeneralInstanceGenerator
# import tqdm

# instance_generator = GeneralInstanceGenerator(
#     duration_range=(1, 10),
#     num_jobs=4,
#     num_machines=3,
#     iteration_limit=1000,
#     seed=42,
# )
# schedules_json = []
# for instance in tqdm.tqdm(instance_generator):
#     solver = ORToolsSolver()
#     schedule = solver.solve(instance)
#     schedules_json.append(schedule.to_dict())

In [5]:
len(schedules_json)

100000

In [6]:
features_observers_types = [
    FeatureObserverType.DURATION,
    FeatureObserverType.EARLIEST_START_TIME,
    FeatureObserverType.IS_SCHEDULED,
    FeatureObserverType.POSITION_IN_JOB,
    FeatureObserverType.REMAINING_OPERATIONS,
]

In [7]:
OPERATION_FEATURES_TO_NORMALIZE = [
    0,  # Duration
    1,  # EarliestStartTime
    4,  # PositionInJob
    5,  # Job duration
    6,  # Job earliest start time
    9,  # Job remaining operations
]
MACHINE_FEATURES_TO_NORMALIZE = [
    0,  # Duration
    1,  # EarliestStartTime
    4,  # RemainingOperations
]

features_to_normalize = {
    "operation": OPERATION_FEATURES_TO_NORMALIZE,
    "machine": MACHINE_FEATURES_TO_NORMALIZE,
}

In [8]:
def _normalize_features(
    node_features_dict: dict[str, np.ndarray],
    indices_to_normalize: dict[str, list[int]] | None = None,
):
    if indices_to_normalize is None:
        indices_to_normalize = {
            "operation": np.arange(8),
            "machine": np.arange(4),
        }
    for key, indices in indices_to_normalize.items():
        # Divide by the maximum value checking for division by zero
        max_values = np.max(node_features_dict[key], axis=0)
        max_values[max_values == 0] = 1
        node_features_dict[key][:, indices] /= max_values[indices]

    return node_features_dict

In [9]:
def _map_available_operations_with_ids_to_original_ids(
    available_operations_with_ids, original_ids: dict[str, np.ndarray]
):
    new_ids = []
    for operation_id, machine_id, job_id in available_operations_with_ids:
        original_operation_id = original_ids["operation"][operation_id]
        original_machine_id = original_ids["machine"][machine_id]
        new_ids.append((original_operation_id, original_machine_id, job_id))
    return new_ids

In [25]:
dataset = {}
for schedule_dict in tqdm.tqdm(schedules_json):
    observations = []
    action_probabilities_sequence = []
    schedule = Schedule.from_dict(**schedule_dict)
    graph = build_resource_task_graph(schedule.instance)
    env = SingleJobShopGraphEnv(
        graph,
        feature_observer_configs=features_observers_types,
        ready_operations_filter=None,
    )
    env = ResourceTaskGraphObservation(env)
    optimal_ops_observer = OptimalOperationsObserver(
        env.unwrapped.dispatcher, schedule
    )
    correct_schedule = False
    attempts = 0
    while not correct_schedule:
        obs, info = env.reset()
        done = False
        while not done:

            action_probs = get_optimal_actions(
                optimal_ops_observer,
                _map_available_operations_with_ids_to_original_ids(
                    info["available_operations_with_ids"],
                    obs["original_ids_dict"],
                ),
            )
            if len(action_probs) > 1:
                obs["node_features_dict"] = _normalize_features(
                    obs["node_features_dict"]
                )
                observations.append(obs)
                action_probs_adjusted = {}
                assert len(info["available_operations_with_ids"]) == len(
                    action_probs
                )
                for key, value in zip(
                    info["available_operations_with_ids"], action_probs.values()
                ):
                    action_probs_adjusted[key] = value
                action_probabilities_sequence.append(action_probs_adjusted)
            if max(action_probs.values()) != 1.0:
                correct_schedule = False
                break
            optimal_actions = [
                action
                for action, value in action_probs.items()
                if value == 1.0
            ]
            action_choice = random.choice(optimal_actions)
            _, machine_id, job_id = action_choice
            # machine_id = obs["original_ids_dict"]["machine"][machine_id]
            obs, reward, done, _, info = env.step((job_id, machine_id))
        makespan = env.unwrapped.dispatcher.schedule.makespan()
        correct_schedule = makespan == schedule.makespan()
        if not correct_schedule:
            attempts += 1
            name = schedule.instance.name
            print(f"Failed to generate correct schedule for {name}")
            print(f"Attempt {attempts}")
            if attempts >= 100:
                raise ValueError("Failed to generate correct schedule")
    dataset[schedule.instance.name] = (
        observations,
        action_probabilities_sequence,
    )

  0%|          | 152/100000 [00:06<1:05:47, 25.29it/s]


KeyboardInterrupt: 

In [26]:
observations = []
action_probabilities_sequence = []
for obs, action_probs in dataset.values():
    observations.extend(obs)
    action_probabilities_sequence.extend(action_probs)

In [27]:
hetero_dataset = []
for obs, action_probs in tqdm.tqdm(
    zip(observations, action_probabilities_sequence)
):
    job_shop_data = JobShopData()
    for key, value in obs.items():
        for subkey, subvalue in value.items():
            if key == "node_features_dict":
                job_shop_data[subkey].x = torch.from_numpy(subvalue)
            elif key == "edge_index_dict":
                job_shop_data[subkey].edge_index = torch.from_numpy(subvalue)
    job_shop_data["y"] = torch.tensor(
        list(action_probs.values()), dtype=torch.float32
    )
    job_shop_data["valid_pairs"] = torch.tensor(list(action_probs.keys()))
    hetero_dataset.append(job_shop_data)

3093it [00:00, 3325.14it/s]


In [28]:
def print_dtypes(data):
    """Print dtypes of all tensors in a HeteroData object"""
    # Print global attributes
    for key in data.keys():
        if not isinstance(data[key], (dict, list)):
            print(f"{key}: {type(data[key])}")
            if isinstance(data[key], torch.Tensor):
                print(f"  Shape: {data[key].shape}, Dtype: {data[key].dtype}")

    # Print node attributes
    for node_type in data.node_types:
        print(f"\nNode type: {node_type}")
        for key, value in data[node_type].items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: Shape {value.shape}, Dtype: {value.dtype}")

    # Print edge attributes
    for edge_type in data.edge_types:
        print(f"\nEdge type: {edge_type}")
        for key, value in data[edge_type].items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: Shape {value.shape}, Dtype: {value.dtype}")


# Usage
print_dtypes(hetero_dataset[0])

x: <class 'torch_geometric.data.storage.NodeStorage'>
y: <class 'torch.Tensor'>
  Shape: torch.Size([6]), Dtype: torch.float32
edge_index: <class 'torch_geometric.data.storage.NodeStorage'>
valid_pairs: <class 'torch.Tensor'>
  Shape: torch.Size([6, 3]), Dtype: torch.int64

Node type: operation
  x: Shape torch.Size([30, 8]), Dtype: torch.float32

Node type: machine
  x: Shape torch.Size([5, 4]), Dtype: torch.float32

Node type: x

Node type: edge_index

Edge type: ('operation', 'to', 'operation')
  edge_index: Shape torch.Size([2, 120]), Dtype: torch.int64

Edge type: ('operation', 'to', 'machine')
  edge_index: Shape torch.Size([2, 30]), Dtype: torch.int64

Edge type: ('machine', 'to', 'operation')
  edge_index: Shape torch.Size([2, 30]), Dtype: torch.int64

Edge type: ('machine', 'to', 'machine')
  edge_index: Shape torch.Size([2, 20]), Dtype: torch.int64


In [29]:
from gnn_scheduler.model import ResidualSchedulingGNN, HeteroMetadata

metadata = HeteroMetadata(node_types=["operation", "machine"])


model = ResidualSchedulingGNN(
    metadata=metadata, in_channels_dict={"operation": 8, "machine": 4}
)

model(
    hetero_dataset[0].x_dict,
    hetero_dataset[0].edge_index_dict,
    hetero_dataset[0].valid_pairs,
)

tensor([[ 0.6649],
        [-0.0698],
        [ 0.3180],
        [ 0.6859],
        [-0.5058],
        [ 0.5251]], grad_fn=<AddmmBackward0>)

In [None]:
# Setting CUDA_LAUNCH_BLOCKING=1 for debugging
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [39]:
from torch_geometric.data import DataLoader
from torch import nn

# BCEWithLogitsLoss is used because the model outputs logits

loss_function = nn.BCEWithLogitsLoss()

loader = DataLoader(hetero_dataset, batch_size=32, shuffle=True)


def train(
    model: nn.Module,
    loader,
    loss_function: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for data in tqdm.tqdm(loader, desc="Training"):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x_dict, data.edge_index_dict, data.valid_pairs)
        loss = loss_function(output.squeeze(1), data.y)
        loss.backward()
        optimizer.step()
        
        # Calculate accuracy
        pred = torch.sigmoid(output.squeeze(1)) > 0.5
        correct += (pred == (data.y > 0.5)).sum().item()
        total += len(data.y)
        
        total_loss += loss.item()
    
    accuracy = correct / total if total > 0 else 0
    return total_loss / len(loader), accuracy


device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1000):
    loss, accuracy = train(model, loader, loss_function, optimizer, device)
    print(f"Epoch {epoch} | Loss: {loss:.4f} | Accuracy: {accuracy:.4f}")
    
    # Early stopping could be added here
    if accuracy > 0.99:
        print(f"Reached high accuracy. Stopping at epoch {epoch}")
        break

Training: 100%|██████████| 97/97 [00:04<00:00, 21.52it/s]


Epoch 0 | Loss: 0.1827 | Accuracy: 0.9192


Training: 100%|██████████| 97/97 [00:04<00:00, 21.84it/s]


Epoch 1 | Loss: 0.1606 | Accuracy: 0.9296


Training: 100%|██████████| 97/97 [00:04<00:00, 23.69it/s]


Epoch 2 | Loss: 0.1454 | Accuracy: 0.9375


Training: 100%|██████████| 97/97 [00:04<00:00, 23.37it/s]


Epoch 3 | Loss: 0.1331 | Accuracy: 0.9462


Training: 100%|██████████| 97/97 [00:04<00:00, 22.56it/s]


Epoch 4 | Loss: 0.1209 | Accuracy: 0.9477


Training: 100%|██████████| 97/97 [00:05<00:00, 18.02it/s]


Epoch 5 | Loss: 0.1166 | Accuracy: 0.9517


Training: 100%|██████████| 97/97 [00:03<00:00, 24.39it/s]


Epoch 6 | Loss: 0.1076 | Accuracy: 0.9570


Training: 100%|██████████| 97/97 [00:04<00:00, 21.91it/s]


Epoch 7 | Loss: 0.0905 | Accuracy: 0.9624


Training: 100%|██████████| 97/97 [00:11<00:00,  8.19it/s]


Epoch 8 | Loss: 0.0874 | Accuracy: 0.9642


Training: 100%|██████████| 97/97 [00:11<00:00,  8.28it/s]


Epoch 9 | Loss: 0.0906 | Accuracy: 0.9624


Training: 100%|██████████| 97/97 [00:04<00:00, 20.97it/s]


Epoch 10 | Loss: 0.0830 | Accuracy: 0.9682


Training: 100%|██████████| 97/97 [00:03<00:00, 25.89it/s]


Epoch 11 | Loss: 0.0753 | Accuracy: 0.9707


Training: 100%|██████████| 97/97 [00:03<00:00, 24.93it/s]


Epoch 12 | Loss: 0.0781 | Accuracy: 0.9699


Training: 100%|██████████| 97/97 [00:04<00:00, 23.12it/s]


Epoch 13 | Loss: 0.0657 | Accuracy: 0.9733


Training: 100%|██████████| 97/97 [00:04<00:00, 22.49it/s]


Epoch 14 | Loss: 0.0662 | Accuracy: 0.9764


Training: 100%|██████████| 97/97 [00:03<00:00, 25.84it/s]


Epoch 15 | Loss: 0.0547 | Accuracy: 0.9782


Training: 100%|██████████| 97/97 [00:03<00:00, 25.09it/s]


Epoch 16 | Loss: 0.0551 | Accuracy: 0.9787


Training: 100%|██████████| 97/97 [00:03<00:00, 24.49it/s]


Epoch 17 | Loss: 0.0527 | Accuracy: 0.9797


Training: 100%|██████████| 97/97 [00:03<00:00, 25.20it/s]


Epoch 18 | Loss: 0.0555 | Accuracy: 0.9782


Training: 100%|██████████| 97/97 [00:03<00:00, 25.02it/s]


Epoch 19 | Loss: 0.0483 | Accuracy: 0.9805


Training: 100%|██████████| 97/97 [00:03<00:00, 24.88it/s]


Epoch 20 | Loss: 0.0483 | Accuracy: 0.9820


Training: 100%|██████████| 97/97 [00:03<00:00, 24.57it/s]


Epoch 21 | Loss: 0.0442 | Accuracy: 0.9823


Training: 100%|██████████| 97/97 [00:03<00:00, 25.56it/s]


Epoch 22 | Loss: 0.0417 | Accuracy: 0.9842


Training: 100%|██████████| 97/97 [00:03<00:00, 25.95it/s]


Epoch 23 | Loss: 0.0419 | Accuracy: 0.9847


Training: 100%|██████████| 97/97 [00:04<00:00, 23.55it/s]


Epoch 24 | Loss: 0.0387 | Accuracy: 0.9855


Training: 100%|██████████| 97/97 [00:04<00:00, 23.89it/s]


Epoch 25 | Loss: 0.0341 | Accuracy: 0.9873


Training: 100%|██████████| 97/97 [00:04<00:00, 23.74it/s]


Epoch 26 | Loss: 0.0443 | Accuracy: 0.9828


Training: 100%|██████████| 97/97 [00:03<00:00, 24.99it/s]


Epoch 27 | Loss: 0.0346 | Accuracy: 0.9869


Training: 100%|██████████| 97/97 [00:03<00:00, 24.49it/s]


Epoch 28 | Loss: 0.0330 | Accuracy: 0.9879


Training: 100%|██████████| 97/97 [00:04<00:00, 23.47it/s]


Epoch 29 | Loss: 0.0376 | Accuracy: 0.9862


Training: 100%|██████████| 97/97 [00:03<00:00, 25.46it/s]


Epoch 30 | Loss: 0.0335 | Accuracy: 0.9879


Training: 100%|██████████| 97/97 [00:03<00:00, 24.88it/s]


Epoch 31 | Loss: 0.0354 | Accuracy: 0.9864


Training: 100%|██████████| 97/97 [00:03<00:00, 25.17it/s]


Epoch 32 | Loss: 0.0303 | Accuracy: 0.9890


Training: 100%|██████████| 97/97 [00:03<00:00, 24.70it/s]


Epoch 33 | Loss: 0.0285 | Accuracy: 0.9893


Training: 100%|██████████| 97/97 [00:03<00:00, 25.32it/s]


Epoch 34 | Loss: 0.0309 | Accuracy: 0.9881


Training: 100%|██████████| 97/97 [00:03<00:00, 25.32it/s]


Epoch 35 | Loss: 0.0384 | Accuracy: 0.9855


Training: 100%|██████████| 97/97 [00:03<00:00, 24.98it/s]

Epoch 36 | Loss: 0.0247 | Accuracy: 0.9915
Reached high accuracy. Stopping at epoch 36





In [40]:
len(hetero_dataset)

3093

In [32]:
action_probs

{(2, 1, 1): 1, (3, 0, 2): 1}

In [33]:
data["operation"].x.shape

torch.Size([36, 8])

In [34]:
data["valid_pairs"]

tensor([[ 0,  1,  0],
        [ 6,  0,  1],
        [10,  4,  2],
        [17,  4,  3],
        [20,  2,  4],
        [26,  4,  5],
        [33,  5,  2],
        [34,  9,  3]])

In [23]:
obs

{'edge_index_dict': {('operation',
   'to',
   'operation'): array([], shape=(2, 0), dtype=int32),
  ('operation',
   'to',
   'machine'): array([[0, 1, 2],
         [2, 1, 0]]),
  ('machine',
   'to',
   'operation'): array([[0, 1, 2],
         [2, 1, 0]]),
  ('machine',
   'to',
   'machine'): array([[0, 0, 1, 1, 2, 2],
         [1, 2, 0, 2, 0, 1]])},
 'node_features_dict': {'operation': array([[  0.85915494,   0.        ,   0.        ,   0.        ,
            1.        ,   0.        ,   0.        ,   1.        ],
         [  0.7605634 ,   0.        ,   0.        ,   0.        ,
            0.8852459 ,   0.        ,   0.        ,   1.        ],
         [  1.        , -63.        ,   1.        ,   0.        ,
            0.        ,   0.        ,   1.        ,   0.        ]],
        dtype=float32),
  'machine': array([[   0.       , -302.       ,    1.       ,    0.       ],
         [   0.8852459,    0.       ,    0.       ,    1.       ],
         [   1.       ,    0.       ,   