In [50]:
import torch
from torch.nn import Embedding
import torch_geometric
from torch_geometric.utils import from_networkx
from typing import List, Optional, Dict, Any, Callable
import networkx as nx
from gnn_scheduler.jssp import JobShopInstance, load_from_benchmark, load_all_from_benchmark


class JobShopGraphData:
    def __init__(self, job_shop_instances: List[JobShopInstance], use_embeddings: bool = True):
        self.job_shop_instances = job_shop_instances
        self.use_embeddings = use_embeddings
        self.embedding_layers = self._init_embedding_layers() if use_embeddings else None

    def _init_embedding_layers(self) -> Dict[str, Embedding]:
        # Initialize embedding layers for different features
        return {
            'machine_id': Embedding(num_embeddings=11, embedding_dim=100),
            'job_id': Embedding(num_embeddings=21, embedding_dim=100),
            # Add more embeddings as needed
        }

    def construct_graph_data(self, extra_features: Optional[Dict[str, Any]] = None) -> List[torch_geometric.data.Data]:
        graph_data_list = []
        for instance in self.job_shop_instances:
            graph = instance.disjunctive_graph
            self._add_node_features(graph, extra_features)
            data = from_networkx(graph)
            data.y = torch.tensor([instance.optimum], dtype=torch.float) if instance.optimum is not None else None
            graph_data_list.append(data)

        return graph_data_list

    def _add_node_features(self, 
                           graph: nx.DiGraph, 
                           extra_features: list[Callable[[str, dict[str, Any], nx.DiGraph], list]]):
        for node, node_data in graph.nodes(data=True):
            features = []

            if self.use_embeddings:
                # Extract features from node_data and add embeddings
                machine_id = node_data.get('machine_id')
                job_id = node_data.get('job_id')
                features.extend([
                    self.embedding_layers['machine_id'](torch.tensor([machine_id])),
                    self.embedding_layers['job_id'](torch.tensor([job_id])),
                ])

            if extra_features:
                for feature_func in extra_features:
                    feature_value = feature_func(node, node_data, graph)
                    features.append(torch.tensor(feature_value, dtype=torch.float))

            graph.nodes[node]['x'] = torch.cat(features, dim=-1) if features else None

# Example usage:
# job_shop_instances = [instance1, instance2, ...]
# job_shop_graph_data = JobShopGraphData(job_shop_instances, use_embeddings=False)
# extra_features = {'degree': lambda op, pos: ...}
# graph_data_list = job_shop_graph_data.construct_graph_data(extra_features)


In [51]:
from functools import partial


ft06_instance = load_from_benchmark('ft06')
ft10_instance = load_from_benchmark('ft10')
job_shop_instances = [ft06_instance, ft10_instance]
job_shop_graph_data = JobShopGraphData(job_shop_instances, use_embeddings=True)
# We add as extra feature one hot encoding of machine_id
def one_hot_encoding(node, node_data, graph, feature_name, n_values=100):
    zeros = [0] * n_values
    machine_id = node_data.get(feature_name)
    if machine_id >= 0:
        zeros[machine_id] = 1
    return zeros

def get_degree(node, node_data, graph: nx.DiGraph):
    in_degree = graph.in_degree(node) / (graph.number_of_nodes() - 1)
    out_degree = graph.out_degree(node) / (graph.number_of_nodes() - 1)
    return [in_degree, out_degree]

extra_features = [
    # partial(one_hot_encoding, feature_name='machine_id'), 
                  get_degree]

graph_data_list = job_shop_graph_data.construct_graph_data(extra_features)
graph_data_list

IndexError: index out of range in self

In [None]:
data_ft06 = graph_data_list[0]
data_ft06.y

tensor([55.])

In [None]:
data_ft06.x

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.1622],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1622, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1622, 0.1622],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1622, 0.1622],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1622, 0.1622],
        [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1622, 0.1622],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1622, 0.1622],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1622, 0.1622],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,