<a href="https://colab.research.google.com/github/PATELOM925/Data-Mining/blob/main/Project%20%20-%201st_Copy_(OM)_of_Idea_GNN_for_Structured_Problem_Representation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup
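This is split into two cells, as Colab needs to restart after installing the requirements from the Starjob repo.
- Downloads the Starjob codebase from GitHub: https://github.com/starjob42/Starjob.
- Installs the pip requirements:
    - First, unsloth and its dependencies, with versions pinned for Colab.
    - Then, the remaining packages (`wandb`, `datasets`, `torch_geometric`), installed manually because installing `Starjob/requirements.txt` directly conflicts with unsloth on Colab.
- Downloads the Starjob dataset from Hugging Face into `data/`, the folder expected by the fine-tuning script (`train_llama_3.py`). The `git clone` cannot fetch the dataset file `starjob130k.json` itself, because the repository has exceeded its Git LFS budget.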

In [None]:
!git clone https://github.com/starjob42/Starjob.git

# seems to conflict with requirements of unsloth on Colab
#!pip install -r Starjob/requirements.txt

Cloning into 'Starjob'...
remote: Enumerating objects: 40, done.
remote: Counting objects: 100% (40/40), done.
remote: Compressing objects: 100% (35/35), done.
remote: Total 40 (delta 14), reused 8 (delta 1), pack-reused 0 (from 0)
Receiving objects: 100% (40/40), 15.08 KiB | 5.03 MiB/s, done.
Resolving deltas: 100% (14/14), done.
Downloading starjob130k.json (1.9 GB)
Error downloading object: starjob130k.json (c9ec14e): Smudge error: Error downloading starjob130k.json (c9ec14ef703a5aa4cc360fe82d8f806fbe03cbee5846a8ff34c503baf6290dd1): batch response: This repository exceeded its LFS budget. The account responsible for the budget should increase it to restore access.

Errors logged to /content/Starjob/.git/lfs/logs/20251217T132207.107276431.log
Use `git lfs logs last` to view the log.
error: external filter 'git-lfs filter-process' failed
fatal: starjob130k.json: smudge filter lfs failed
You can inspect what was checked out with 'git status'
and retry with 'git restore --so

In [None]:
%%capture
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes

# Manually including packages from `requirements.txt`, since the versions conflict with unsloth
!pip install wandb datasets torch_geometric

!mkdir data/
!hf download mideavalwisard/Starjob --repo-type dataset --local-dir data/

# Fine-tuning
The content of `train_llama_3.py` is pasted below manually, split into cells at key points. Any other changes from the original script are highlighted in comments.

## Initial Configuration

- Uses plain variables instead of `argparse`, which fails in Colab because `sys.argv` contains the notebook kernel's own command-line arguments.
- Also changes `dtype` to `float16`, since the Tesla T4 does not support `bfloat16`.
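If keeping the script's `argparse` interface is preferable, one common workaround is to pass an explicit argument list instead of letting `argparse` read the kernel's `sys.argv`. A minimal sketch; the flag names below are hypothetical and simply mirror this notebook's variable names, not necessarily the flags defined in `train_llama_3.py`.

```python
import argparse

# Hypothetical subset of flags; names mirror this notebook's variables and may
# differ from the real train_llama_3.py.
parser = argparse.ArgumentParser(description="Starjob fine-tuning (sketch)")
parser.add_argument("--max_seq_length", type=int, default=50000)
parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16"])
parser.add_argument("--lora_r", type=int, default=64)
parser.add_argument("--lora_alpha", type=int, default=64)
parser.add_argument("--learning_rate", type=float, default=2e-4)

# parse_args() with no argument reads sys.argv[1:], which in a notebook holds the
# kernel's flags (e.g. "-f <connection file>") and makes argparse exit with an error.
# Passing an explicit list avoids that.
args = parser.parse_args([])                                            # all defaults
custom = parser.parse_args(["--lora_r", "16", "--dtype", "bfloat16"])  # CLI-style overrides

print(args.lora_r, args.dtype)      # 64 float16
print(custom.lora_r, custom.dtype)  # 16 bfloat16
```

`parser.parse_known_args()` is another option: it returns unrecognised arguments as a separate list instead of exiting.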

In [None]:
# =========================
# Imports
# =========================

import argparse
import torch
import wandb
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset

import re
import numpy as np
from typing import Tuple, List, Dict

from trl import SFTTrainer
from transformers import TrainingArguments

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
# =========================
# Arguments
# =========================
# NOTE: Original script uses argparse, which doesn't seem to work with Colab
# Converting to plain variables instead

# Model and data parameters
max_seq_length = 50000
dtype = 'float16'           # NOTE: Changed from "bfloat16", which is not supported by T4 GPU
load_in_4bit = True

# LoRA hyperparameters
lora_r = 64
lora_alpha = 64
lora_dropout = 0.0
bias = 'none'

# Additional configurations
use_gradient_checkpointing = 'unsloth'
random_state = 42
use_rslora = False
loftq_config = None

# Training hyperparameters
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
warmup_steps = 5
num_train_epochs = 2
learning_rate = 2e-4
logging_steps = 1
optim = 'adamw_8bit'
weight_decay = 0.01
lr_scheduler_type = 'linear'
seed = 42
save_total_limit = 50
save_step = 200
per_device_eval_batch_size = 2
train_lm_head = False
train_embed_tokens = False

# Output directory
output_dir = None

In [None]:
# =========================
# Generate Output Directory Name
# =========================

# Create an output directory name based on hyperparameters
if output_dir is None:
    dir_out = f"output_alpha{lora_alpha}_r{lora_r}_train_lm_head{train_lm_head}_train_embed_tok_{train_embed_tokens}_seq{max_seq_length}_b{per_device_train_batch_size}_ep{num_train_epochs}"
else:
    dir_out = output_dir

# =========================
# Initialize WandB
# =========================

# Initialize Weights & Biases for experiment tracking
wandb.init(
    project="llama3-jssp-clean",  # Change the project name if needed
    name=dir_out,
)

## Initialising model

In [None]:
# =========================
# Load Model and Tokenizer
# =========================

# Set dtype
dtype = torch.bfloat16 if dtype == 'bfloat16' else torch.float16

# Load the pre-trained model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
if train_lm_head:
    target_modules.append('lm_head')
if train_embed_tokens:
    target_modules.append('embed_tokens')

# Configure the model with PEFT (Parameter-Efficient Fine-Tuning)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_r,
    target_modules=target_modules,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias=bias,
    use_gradient_checkpointing=use_gradient_checkpointing,
    random_state=random_state,
    use_rslora=use_rslora,
    loftq_config=loftq_config,
)

==((====))==  Unsloth 2025.12.5: Fast Llama patching. Transformers: 4.57.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.12.5 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
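The trainable-parameter figure that Unsloth prints when training starts (167,772,160 of 8,198,033,408) follows directly from the LoRA settings above, since `lm_head` and `embed_tokens` are not trained. A sketch of the arithmetic, assuming the layer shapes from the public Llama 3.1 8B config; it also shows the scaling factor LoRA applies to the adapter update, with and without rsLoRA.

```python
import math

# ASSUMPTION: shapes from the public Llama 3.1 8B config
# (hidden 4096, MLP 14336, 8 KV heads x 128 head dim, 32 decoder layers).
hidden, intermediate, kv_dim, num_layers = 4096, 14336, 8 * 128, 32
lora_r, lora_alpha = 64, 64

# (in_features, out_features) of each adapted projection in one decoder layer
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}

# LoRA adds A (r x in) and B (out x r) per adapted weight: r * (in + out) parameters
per_layer = sum(lora_r * (d_in + d_out) for d_in, d_out in shapes.values())
trainable = per_layer * num_layers
print(f"{trainable:,}")  # 167,772,160

# Scaling applied to the LoRA update B @ A
scaling = lora_alpha / lora_r                     # standard LoRA (use_rslora=False): 1.0
rslora_scaling = lora_alpha / math.sqrt(lora_r)   # with use_rslora=True: 8.0
```

With `lora_alpha == lora_r`, the standard LoRA update is applied unscaled; switching on rsLoRA at the same settings would multiply it by 8.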


## Preparing dataset
- Moved the `alpaca_prompt` template string and `formatting_prompts_func` function into this area, for cleanliness.
- Added extra comments for clarity.

In [None]:
# TEMP: re-import what this section needs, so it can run while the GPU cells above are unavailable
from datasets import load_dataset
from pprint import pp

import re
import numpy as np
from typing import Tuple, List, Dict

import networkx as nx
from collections import defaultdict

import torch
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx

In [None]:
# Load the data from the ./data folder (train/test splits are created further down)
dataset = load_dataset('./data/', split="train")
print(dataset[:5])
print(dataset.column_names)

{'num_jobs': [40, 40, 30, 50, 30], 'num_machines': [40, 20, 20, 20, 15], 'prompt_jobs_first': ['J0:\nM10:122 M29:26 M15:178 M33:126 M17:174 M22:20 M3:98 M38:99 M6:111 M24:104 M34:89 M1:71 M25:184 M27:19 M11:182 M0:19 M8:9 M32:151 M36:12 M12:65 M39:175 M14:139 M23:139 M30:92 M21:66 M28:62 M19:100 M13:26 M5:137 M16:158 M9:52 M37:14 M26:113 M31:92 M2:126 M20:38 M35:81 M18:80 M7:113 M4:63 \nJ1:\nM27:113 M2:125 M22:44 M36:183 M30:131 M13:154 M26:183 M23:126 M12:186 M1:157 M17:141 M6:191 M4:192 M38:68 M29:79 M35:92 M18:17 M3:148 M32:79 M21:54 M16:139 M25:151 M33:20 M10:136 M5:117 M14:112 M9:100 M8:131 M24:53 M11:78 M31:103 M39:154 M19:160 M15:181 M37:67 M0:184 M20:160 M34:73 M7:120 M28:125 \nJ2:\nM6:139 M24:158 M34:178 M30:137 M13:72 M12:16 M25:38 M9:186 M35:112 M37:143 M29:93 M8:170 M3:86 M38:79 M31:155 M36:117 M0:180 M16:93 M32:51 M27:23 M18:162 M22:18 M21:144 M4:98 M23:54 M20:189 M11:142 M28:50 M15:41 M2:117 M10:180 M39:40 M5:115 M7:31 M33:75 M26:182 M17:113 M19:120 M1:169 M14:21 \nJ3:\nM

In [None]:
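# Preprocessing: keep only rows that have both an instruction and an output
# (old-style rows have no instruction; see "Issue with dataset" below)
missing_instruction = dataset.filter(lambda x: x['instruction'] is None)
missing_output = dataset.filter(lambda x: x['output'] is None)
preprocessed_ds = dataset.filter(lambda x: x['instruction'] is not None and x['output'] is not None)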
# df = preprocessed_ds.to_pandas()
# print(len(df))
# df[:5]

Filter:   0%|          | 0/129720 [00:00<?, ? examples/s]

Filter:   0%|          | 0/129720 [00:00<?, ? examples/s]

Filter:   0%|          | 0/129720 [00:00<?, ? examples/s]

In [None]:
# =========================
# Load and Prepare Dataset
# =========================

# Split the filtered data (loaded earlier from ./data) into train/test sets
# dataset = load_dataset('./data/', split="train")
seed = 42
split_dataset = preprocessed_ds.train_test_split(test_size=0.02, seed=seed)
print(len(split_dataset))


# Define the Alpaca-style prompt template
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""
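# Use the tokenizer's EOS token once the model cell has run; "END" is only a
# placeholder for when the GPU (and hence the tokenizer) is unavailable
EOS_TOKEN = tokenizer.eos_token if "tokenizer" in globals() else "END"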

# Define function to create prompts with this template, using columns from the dataset
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]
    texts = []
    for instruction, input_text, output in zip(instructions, inputs, outputs):
        text = alpaca_prompt.format(instruction, input_text, output) + EOS_TOKEN
        texts.append(text)
    return {"text": texts}


# Map the train/test splits through the above function, to create prompts for the LLM
train_dataset = split_dataset['train'].map(formatting_prompts_func, batched=True)
eval_dataset = split_dataset['test'].map(formatting_prompts_func, batched=True)

2


Map:   0%|          | 0/9525 [00:00<?, ? examples/s]

Map:   0%|          | 0/195 [00:00<?, ? examples/s]

In [None]:
print(split_dataset)
print("\n\n",split_dataset["train"][0])
print("\n\n",split_dataset["test"][0])

DatasetDict({
    train: Dataset({
        features: ['num_jobs', 'num_machines', 'prompt_jobs_first', 'instruction', 'prompt_machines_first', 'input', 'output', 'path', 'matrix'],
        num_rows: 9525
    })
    test: Dataset({
        features: ['num_jobs', 'num_machines', 'prompt_jobs_first', 'instruction', 'prompt_machines_first', 'input', 'output', 'path', 'matrix'],
        num_rows: 195
    })
})


 {'num_jobs': 50, 'num_machines': 20, 'prompt_jobs_first': 'J0:\nM18:138 M9:117 M4:112 M8:101 M3:64 M19:60 M12:99 M13:52 M16:57 M15:114 M10:65 M14:80 M1:138 M11:92 M5:63 M7:80 M17:98 M2:79 M6:108 M0:79 \nJ1:\nM6:144 M12:59 M11:95 M7:28 M8:39 M10:117 M14:53 M3:32 M17:87 M9:15 M0:124 M1:143 M15:96 M4:69 M18:69 M13:34 M19:53 M16:112 M5:69 M2:49 \nJ2:\nM8:125 M13:123 M2:50 M7:138 M3:113 M19:117 M6:75 M18:120 M5:142 M9:41 M4:126 M1:17 M17:121 M15:124 M14:66 M11:128 M16:47 M10:107 M12:81 M0:41 \nJ3:\nM10:128 M12:144 M9:125 M8:92 M17:67 M3:9 M11:21 M7:30 M14:18 M16:142 M0:60 M13:121 M15:13
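With `test_size=0.02`, the split sizes can be checked by hand. A sketch, assuming `train_test_split` rounds the test share up, which matches the sizes printed here; `split_sizes` is a hypothetical helper. The same arithmetic on the unfiltered 129,720 rows (the count in the `Filter` progress bars) gives 127,125 / 2,595, which are the counts that appear later in the trainer's tokenization log, rather than the 9,525 / 195 printed above.

```python
import math
from typing import Tuple

def split_sizes(num_rows: int, test_size: float) -> Tuple[int, int]:
    """Hypothetical helper: (train, test) row counts, rounding the test share up."""
    num_test = math.ceil(test_size * num_rows)
    return num_rows - num_test, num_test

filtered_rows = 9525 + 195   # rows kept by the None filter
full_rows = 129720           # rows in the raw dataset

print(split_sizes(filtered_rows, 0.02))  # (9525, 195)
print(split_sizes(full_rows, 0.02))      # (127125, 2595)
```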

### Issue with dataset
NOTE: There appears to be a significant issue with this script. As shown below, the old-style instances (i.e., from the original "LLMs can Schedule" paper) have _no_ instruction or input in their "text" column, only a response: `formatting_prompts_func` fills the missing fields with the literal string `None`. The LLM is therefore being trained on prompts that contain no problem description.
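The mechanism is easy to reproduce in isolation: Python's `str.format` turns `None` into the text `"None"` instead of raising an error, so incomplete rows pass through silently. A minimal sketch with an abbreviated template and a hypothetical guard, `format_or_skip`, that returns `None` for incomplete rows so they can be dropped; the `instruction`/`output` filter in the preprocessing cell plays the same role at the dataset level.

```python
# Abbreviated version of the Alpaca-style alpaca_prompt used in this notebook
template = "### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}"

# str.format() renders None as the text "None" rather than raising an error
print(template.format(None, None, "Solution: ..."))

def format_or_skip(instruction, input_text, output, eos_token="END"):
    """Hypothetical helper: return None rather than a prompt with missing fields."""
    if instruction is None or input_text is None or output is None:
        return None
    return template.format(instruction, input_text, output) + eos_token

rows = [
    ("Optimize schedule for 2 Jobs ...", "J0:\nM0:3 M1:2", "Solution: ..."),  # new-style
    (None, None, "Solution: ..."),                                            # old-style
]
texts = [format_or_skip(*row) for row in rows]
kept = [text for text in texts if text is not None]
print(len(kept))  # 1
```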


#### Screenshot


![Screenshot of an old-style instance with no instruction or input in its text column](https://drive.google.com/uc?id=1EteiynGyp7X5KDPh0kqnmDJWbBg4aYaF)

#### Code

In [None]:
# Print out one of each for comparison
# New-style: has a "matrix" column
new_sample = None
for i in range(len(train_dataset)):
    if train_dataset[i]['matrix'] is not None:
        new_sample = train_dataset[i]
        # print(new_sample)         # All columns
        print(new_sample["text"])   # Only the "text" column (i.e., the one used by the LLM)
        break
else:
    print("No new-style instance found")

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Optimize schedule for 50 Jobs (denoted as J) across 15 Machines (denoted as M) to minimize makespan. The makespan is the completion time of the last operation in the schedule. Each M can process only one J at a time, and once started, J cannot be interrupted.



### Input:
J0:
M3:69 M6:19 M9:5 M8:10 M12:66 M4:25 M2:42 M1:44 M7:25 M13:41 M10:41 M14:5 M11:28 M0:52 M5:38 
J1:
M13:18 M5:19 M12:17 M7:8 M10:43 M9:63 M4:42 M0:16 M14:43 M2:22 M8:69 M1:26 M6:11 M3:75 M11:21 
J2:
M9:42 M7:69 M13:16 M2:55 M0:15 M12:36 M4:64 M1:57 M8:7 M14:62 M11:21 M5:34 M3:10 M6:59 M10:41 
J3:
M6:20 M10:33 M9:40 M0:24 M14:12 M7:62 M13:19 M3:24 M12:19 M2:29 M1:50 M8:62 M5:53 M4:24 M11:61 
J4:
M4:36 M5:47 M9:73 M14:71 M11:46 M0:54 M13:17 M6:58 M10:74 M3:68 M1:66 M2:6 M8:59 M7:72 M12:13 
J5:
M7:29 M4:30 M11:34 M6:77 M3:30 M2:43 M10:6 M14:14 M

In [None]:
# Old-style: no "matrix" column
old_sample = None
for i in range(len(train_dataset)):
    if train_dataset[i]['matrix'] is None:
        old_sample = train_dataset[i]
        # print(old_sample)         # All columns
        print(old_sample["text"])   # Only the "text" column (i.e., the one used by the LLM)
        break
else:
    print("No old-style instance found")

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
None

### Input:
None

### Response:
Solution:

 Job 0 Operation 0 on Machine 0 : 0 + 3 -> 3 
 Job 5 Operation 0 on Machine 1 : 0 + 2 -> 2 
 Job 7 Operation 0 on Machine 2 : 0 + 16 -> 16 
 Job 3 Operation 0 on Machine 3 : 0 + 2 -> 2 
 Job 6 Operation 0 on Machine 1 : 2 + 2 -> 4 
 Job 1 Operation 0 on Machine 3 : 2 + 4 -> 6 
 Job 5 Operation 1 on Machine 4 : 2 + 29 -> 31 
 Job 3 Operation 1 on Machine 0 : 3 + 21 -> 24 
 Job 2 Operation 0 on Machine 1 : 4 + 6 -> 10 
 Job 8 Operation 0 on Machine 3 : 6 + 28 -> 34 
 Job 1 Operation 1 on Machine 1 : 10 + 18 -> 28 
 Job 0 Operation 1 on Machine 2 : 16 + 5 -> 21 
 Job 6 Operation 1 on Machine 2 : 21 + 19 -> 40 
 Job 7 Operation 1 on Machine 0 : 24 + 9 -> 33 
 Job 3 Operation 2 on Machine 1 : 28 + 3 -> 31 
 Job 0 Operation 2 on Machine 4 : 31 + 5 -> 36 
 Job 8 Operation 

## GNN Encoding of Instances

In [None]:
# =========================
# Convert Dataset to GNN Encoding
# =========================
# NOTE: Assuming that we only want to work with "new"-style instances (pending decision)

# Function to parse JSSP instances (from Starjob paper) into structured format.
def parse_jssp_instance(instance: Dict) -> Dict:
    """Returns:
    {
        num_jobs: number of jobs in problem
        num_machines: number of machines in problem
        num_operations: number of operations in problem
        job_operations: dictionary of job -> operations, where operations is a list of machine-duration pairs
    }
    """
    # Number of jobs and machines are already stored in the instance, just need to parse operations
    job_operations = {}
    lines = instance['input'].strip().split('\n')
    # Assumes each job header ("J<id>:") is on an even-numbered line,
    # with that job's machine:duration pairs on the following line
    for i in range(0, len(lines) - 1, 2):
        job_id = int(re.search(r'J(\d+):', lines[i]).group(1))
        operations = re.findall(r'M(\d+):(\d+)', lines[i + 1])
        job_operations[job_id] = [(int(m), int(d)) for m, d in operations]

    # Also extract number of operations
    # NOTE: We assume that each job has the same number of operations, so we just multiply
    # the number of operations in the first job by the number of jobs
    num_operations = instance['num_jobs'] * len(next(iter(job_operations.values())))

    # Return everything as structured dictionary
    return {
        'num_jobs': instance['num_jobs'],
        'num_machines': instance['num_machines'],
        'num_operations': num_operations,
        'job_operations': job_operations,
    }


# If we need to use old-style, we can use a different parsing function
def parse_jssp_instance_with_old(nl_description: str) -> Dict:
    """
    Parse natural language JSSP description (from Starjob paper) into dictionary format.

    Returns:
    {
        num_jobs: number of jobs in problem
        num_machines: number of machines in problem
        num_operations: number of operations in problem
        job_operations: dictionary of job -> operations, where operations is a list of machine-duration pairs
    }
    """
    lines = nl_description.strip().split('\n')

    # Extract number of jobs and machines
    problem_line = lines[0]
    num_jobs = int(re.search(r'(\d+)\s+Jobs', problem_line).group(1))
    num_machines = int(re.search(r'(\d+)\s+Machines', problem_line).group(1))

    # Parse operations
    job_operations = {}
    for line in lines:
        if line.startswith('J'):
            job_id = int(re.search(r'J(\d+):', line).group(1))
            operations = re.findall(r'M(\d+):(\d+)', line)
            job_operations[job_id] = [(int(m), int(d)) for m, d in operations]

    # Extract number of operations
    # NOTE: We assume that each job has the same number of operations, so we just multiply
    # the number of operations in the first job by the number of jobs
    num_operations = num_jobs * len(next(iter(job_operations.values())))

    return {
        'num_jobs': num_jobs,
        'num_machines': num_machines,
        'num_operations': num_operations,
        'job_operations': job_operations,
    }
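The two parsers above differ only in where a job's operations sit relative to its `J<id>:` header: on the next line (new-style) or on the same line (old-style). A minimal sketch of a single parser that handles both layouts by attaching each `M<id>:<duration>` pair to the most recent job header; `parse_job_operations` is a hypothetical name, and it assumes job headers always start a line.

```python
import re
from typing import Dict, List, Optional, Tuple

JOB_HEADER = re.compile(r"J(\d+):")
OPERATION = re.compile(r"M(\d+):(\d+)")

def parse_job_operations(text: str) -> Dict[int, List[Tuple[int, int]]]:
    """Hypothetical helper: job id -> [(machine_id, duration), ...] for either layout."""
    job_operations: Dict[int, List[Tuple[int, int]]] = {}
    current_job: Optional[int] = None
    for raw_line in text.strip().splitlines():
        line = raw_line.strip()
        header = JOB_HEADER.match(line)
        if header:
            # New job; operations on the same line belong to it (old layout)
            current_job = int(header.group(1))
            job_operations[current_job] = []
            line = line[header.end():]
        if current_job is not None:
            # Operations on their own line belong to the latest job (new layout)
            job_operations[current_job].extend(
                (int(m), int(d)) for m, d in OPERATION.findall(line)
            )
    return job_operations

new_layout = "J0:\nM0:3 M1:2 \nJ1:\nM1:4 M0:1 "
old_layout = ("Optimize schedule for 2 Jobs (denoted as J) across 2 Machines (denoted as M).\n"
              "J0: M0:3 M1:2\nJ1: M1:4 M0:1")
print(parse_job_operations(new_layout))  # {0: [(0, 3), (1, 2)], 1: [(1, 4), (0, 1)]}
print(parse_job_operations(old_layout) == parse_job_operations(new_layout))  # True
```

`num_jobs`, `num_machines` and `num_operations` can then be filled in as in the functions above.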


In [None]:
# Function to build disjunctive graph representation of the parsed (structured) JSSP dictionary
# Uses networkx.DiGraph to store the directed graph
def build_disjunctive_graph(parsed_problem: Dict) -> Tuple[nx.DiGraph, Dict]:
    """
    Return (G, node_to_operation), where G has structure:
    - Nodes: (job_id, operation_id) tuples + source and sink nodes (i.e., len = num_operations + 2)
    - Conjunctive edges: job precedence constraints (source -> first op -> ... -> last op -> sink)
    - Disjunctive edges: machine conflict constraints, stored as two opposite directed edges
    - Node attributes: node_type, job_id, operation_id, machine_id, processing_time
    """

    # Initialise the graph
    G = nx.DiGraph()

    # Extract the operations for this JSSP instance
    job_ops = parsed_problem['job_operations']

    # Add source and sink nodes
    source = ('source', 0)
    sink = ('sink', 0)
    G.add_node(source, node_type='source', processing_time=0)
    G.add_node(sink, node_type='sink', processing_time=0)

    # Initialise reference dictionaries: map operation nodes and create a node-info dictionary
    node_to_operation = {}  # node -> (job_id, op_id, machine_id, proc_time)
    operation_to_node = {}  # (job_id, op_id) -> node
    operations_by_machine = defaultdict(list)   # machine_id -> [nodes]

    # Iterate over each job to create nodes of each operation
    for job_id, operations in job_ops.items():
        # Then iterate over each operation of that job
        for op_id, (machine_id, proc_time) in enumerate(operations):
            # Each node stores the current job and operation
            node = (job_id, op_id)

            # Add the node, with relevant info for this operation
            G.add_node(node,
                      node_type='operation',
                      job_id=job_id,
                      operation_id=op_id,
                      machine_id=machine_id,
                      processing_time=proc_time)

            # Add same info to our reference dictionaries
            node_to_operation[node] = (job_id, op_id, machine_id, proc_time)
            operation_to_node[(job_id, op_id)] = node
            operations_by_machine[machine_id].append(node)

    # pp(operation_to_node)

    # Add conjunctive edges (precedence constraints within jobs)
    for job_id, operations in job_ops.items():
        # Source → first operation of each job
        # first_op = operation_to_node[(job_id, 0)]
        first_op = (job_id, 0)
        G.add_edge(source, first_op, edge_type='conjunctive', weight=0)

        # Operations in sequence within a job
        for op_idx in range(len(operations) - 1):
            # from_node = operation_to_node[(job_id, op_idx)]
            # to_node = operation_to_node[(job_id, op_idx + 1)]
            from_node = (job_id, op_idx)
            to_node = (job_id, op_idx + 1)
            proc_time = operations[op_idx][1]
            G.add_edge(from_node, to_node,
                      edge_type='conjunctive',
                     weight=proc_time)

        # Last operation of each job → sink
        # last_op = operation_to_node[(job_id, len(operations) - 1)]
        last_op = (job_id, len(operations) - 1)
        last_proc_time = operations[-1][1]
        G.add_edge(last_op, sink,
                  edge_type='conjunctive',
                  weight=last_proc_time)

    # Add disjunctive edges (machine constraints)
    # These will be undirected initially; orientation happens during scheduling
    for machine_id, ops_on_machine in operations_by_machine.items():
        # Build a complete graph (clique) over this machine's operations: every pair can conflict
        for i in range(len(ops_on_machine)):
            for j in range(i + 1, len(ops_on_machine)):
                op1, op2 = ops_on_machine[i], ops_on_machine[j]
                # Store the undirected edge as two opposite directed edges (oriented during scheduling)
                G.add_edge(op1, op2,
                           edge_type='disjunctive_undirected',
                           machine_id=machine_id)
                G.add_edge(op2, op1,
                           edge_type='disjunctive_undirected',
                           machine_id=machine_id)

    return G, node_to_operation
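The size of this graph can be predicted in advance, which gives a quick sanity check on the construction. A sketch, assuming every pair of operations on a machine is connected in both directions (as in the loop above) and that no job visits the same machine twice; `expected_graph_size` is a hypothetical helper. For the 50-job, 15-machine sample converted below, it gives the 752 nodes and 37,550 edges printed by `print(G)`; 36,750 of those edges are disjunctive, and that part grows quadratically with the number of operations per machine.

```python
from collections import Counter
from typing import Dict, List, Tuple

def expected_graph_size(job_ops: Dict[int, List[Tuple[int, int]]]) -> Tuple[int, int]:
    """Hypothetical helper: (nodes, edges) that build_disjunctive_graph should produce."""
    num_operations = sum(len(ops) for ops in job_ops.values())
    # Conjunctive: source->first, (k - 1) chain edges, last->sink  =>  k + 1 per job
    conjunctive = sum(len(ops) + 1 for ops in job_ops.values())
    # Disjunctive: both directions for every pair on the same machine  =>  n * (n - 1)
    per_machine = Counter(machine for ops in job_ops.values() for machine, _ in ops)
    disjunctive = sum(n * (n - 1) for n in per_machine.values())
    return num_operations + 2, conjunctive + disjunctive

# 50 jobs, each visiting all 15 machines once (durations do not affect the counts)
sample = {job: [(machine, 1) for machine in range(15)] for job in range(50)}
print(expected_graph_size(sample))  # (752, 37550)
```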


In [None]:
def graph_to_pygdata(G: nx.DiGraph,
                     node_to_operation: Dict) -> Data:
    """
    Convert disjunctive graph to PyTorch Geometric Data object.

    Node features:
    - machine_id (one-hot encoded)
    - processing_time (normalized)
    - job_id (one-hot encoded)
    - operation_index (normalized)
    - node_type (source/sink/operation)

    Edge features:
    - edge_type (conjunctive / disjunctive)
    - weight (processing time for conjunctive edges)
    """
    # Extract all operation nodes (exclude source/sink initially)
    operation_nodes = [n for n in G.nodes() if G.nodes[n]['node_type'] == 'operation']
    num_jobs = max([G.nodes[n]['job_id'] for n in operation_nodes]) + 1
    num_machines = max([G.nodes[n]['machine_id'] for n in operation_nodes]) + 1
    max_ops_per_job = max([G.nodes[n]['operation_id'] for n in operation_nodes]) + 1

    # Create node index mapping
    node_list = [('source', 0)] + operation_nodes + [('sink', 0)]
    node_idx = {node: i for i, node in enumerate(node_list)}

    # Create node feature matrix
    num_nodes = len(node_list)
    node_features = []

    # Every row must have the same width so the list converts to one tensor:
    # [machine one-hot | job one-hot | op_idx_norm, proc_time_norm | is_source, is_sink]
    for node in node_list:
        node_type = G.nodes[node]['node_type']
        if node_type in ('source', 'sink'):
            # Source/sink nodes: all zeros except their node-type flag
            features = [0] * (num_machines + num_jobs + 2)
            features += [1, 0] if node_type == 'source' else [0, 1]
        else:
            # Operation node
            machine_id = G.nodes[node]['machine_id']
            job_id = G.nodes[node]['job_id']
            op_idx = G.nodes[node]['operation_id']
            proc_time = G.nodes[node]['processing_time']

            # One-hot machine encoding
            machine_one_hot = [1 if i == machine_id else 0 for i in range(num_machines)]

            # One-hot job encoding
            job_one_hot = [1 if i == job_id else 0 for i in range(num_jobs)]

            # Normalized operation index and processing time
            op_idx_norm = op_idx / max(1, max_ops_per_job - 1)
            proc_time_norm = proc_time / 500.0  # Normalize by an assumed typical maximum processing time

            # Operation nodes have both node-type flags set to 0
            features = machine_one_hot + job_one_hot + [op_idx_norm, proc_time_norm, 0, 0]
        node_features.append(features)

    x = torch.tensor(node_features, dtype=torch.float32)

    # Create edge list
    edge_index = [[], []]
    edge_attr = []

    for u, v, data in G.edges(data=True):
        u_idx = node_idx[u]
        v_idx = node_idx[v]
        edge_index[0].append(u_idx)
        edge_index[1].append(v_idx)

        # Edge type: conjunctive=1, disjunctive=0
        edge_type = 1 if data['edge_type'] == 'conjunctive' else 0
        weight = data.get('weight', 0) / 500.0  # Normalize weight

        edge_attr.append([edge_type, weight])

    edge_index = torch.tensor(edge_index, dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float32) if edge_attr else torch.zeros((0, 2))

    # Create PyTorch Geometric Data object
    pyg_data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        num_nodes=num_nodes,
        num_jobs=num_jobs,
        num_machines=num_machines,
        node_mapping=node_idx
    )

    return pyg_data


In [None]:
# Use the utility functions above to convert a (new-style) Starjob dataset instance into PyG Data
# def convert_starjob_prompt_to_gnn(nl_prompt: str) -> Data:
def convert_starjob_instance_to_gnn(instance: Dict) -> Data:
    """
    Full pipeline: Instance → Disjunctive Graph → PyG Data
    """
    # Step 1: Parse
    parsed = parse_jssp_instance(instance)
    # pp(parsed, compact=True, width=100)
    # return parsed

    # Step 2: Build graph
    G, node_to_op = build_disjunctive_graph(parsed)
    print(G)
    # return G, node_to_op

    # Step 3: Convert to PyG
    pyg_data = graph_to_pygdata(G, node_to_op)

    return pyg_data

# Example usage with a Starjob dataset entry
# nl_prompt = """Optimize schedule for 3 Jobs (denoted as J) across 3 Machines (denoted as M) to minimize makespan.
# The makespan is the completion time of the last operation in the schedule.
# Each M can process only one J at a time, and once started, J cannot be interrupted.
# J0: M0:105 M1:29 M2:213
# J1: M0:193 M1:18 M2:213
# J2: M0:78 M1:74 M2:221"""
# pyg_data = convert_starjob_prompt_to_gnn(nl_prompt)

# print(new_sample['input'])
# print(type(new_sample))
pyg_data = convert_starjob_instance_to_gnn(new_sample)

# print(f"Nodes: {pyg_data.num_nodes}")
# print(f"Edges: {pyg_data.edge_index.shape[1]}")
# print(f"Node features shape: {pyg_data.x.shape}")
# print(f"Edge features shape: {pyg_data.edge_attr.shape}")


DiGraph with 752 nodes and 37550 edges


ValueError: expected sequence of length 69 at dim 1 (got 67)
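The `ValueError` above is a width mismatch in the node features: with 15 machines and 50 jobs, operation nodes get 15 + 50 + 2 = 67 features, while the source and sink nodes get 15 + 50 + 2 + 2 = 69, and `torch.tensor` cannot stack rows of different lengths. A sketch of a fixed-width layout that uses the two extra slots as source/sink flags (one reading of the docstring's "node_type" feature, assumed here rather than confirmed); `node_feature_vector` is a hypothetical helper, written without torch so it runs standalone.

```python
from typing import List

def node_feature_vector(node_type: str, num_machines: int, num_jobs: int,
                        machine_id: int = 0, job_id: int = 0,
                        op_idx_norm: float = 0.0, proc_time_norm: float = 0.0) -> List[float]:
    """Hypothetical helper: fixed-width features
    [machine one-hot | job one-hot | op_idx_norm, proc_time_norm | is_source, is_sink]."""
    features = [0.0] * (num_machines + num_jobs + 4)
    if node_type == "operation":
        features[machine_id] = 1.0
        features[num_machines + job_id] = 1.0
        features[num_machines + num_jobs] = op_idx_norm
        features[num_machines + num_jobs + 1] = proc_time_norm
    elif node_type == "source":
        features[-2] = 1.0
    elif node_type == "sink":
        features[-1] = 1.0
    else:
        raise ValueError(f"unknown node_type: {node_type}")
    return features

# Widths produced by the original cell for the 15-machine, 50-job sample
num_machines, num_jobs = 15, 50
original_operation_width = num_machines + num_jobs + 2        # 67
original_source_sink_width = num_machines + num_jobs + 2 + 2  # 69

rows = [
    node_feature_vector("source", num_machines, num_jobs),
    node_feature_vector("operation", num_machines, num_jobs, machine_id=3, job_id=0,
                        op_idx_norm=0.0, proc_time_norm=69 / 500.0),
    node_feature_vector("sink", num_machines, num_jobs),
]
print({len(row) for row in rows})  # {69}
```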

# Training

In [None]:
# =========================
# Initialize the Trainer
# =========================

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=20,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        fp16=not is_bfloat16_supported(),  # T4 has no bfloat16, so fall back to float16 mixed precision
        bf16=is_bfloat16_supported(),
        logging_steps=logging_steps,
        optim=optim,
        weight_decay=weight_decay,
        lr_scheduler_type=lr_scheduler_type,
        seed=seed,
        output_dir=dir_out,
        report_to="wandb",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        save_total_limit=save_total_limit,
        save_steps=save_step,
        eval_strategy="steps",
        eval_steps=save_step,
        per_device_eval_batch_size=per_device_eval_batch_size,
    ),
)


Unsloth: Tokenizing ["text"] (num_proc=6):   0%|          | 0/127125 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=6):   0%|          | 0/2595 [00:00<?, ? examples/s]

In [None]:
# =========================
# Monitor GPU Memory Usage
# =========================

gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

# =========================
# Start Training
# =========================

trainer_stats = trainer.train()

The model is already on multiple devices. Skipping the move to device specified in `args`.


GPU = Tesla T4. Max memory = 14.741 GB.
6.883 GB of memory reserved.


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 127,125 | Num Epochs = 2 | Total steps = 15,892
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 167,772,160 of 8,198,033,408 (2.05% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
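The step count in the Unsloth banner follows from the batch settings. A quick check, assuming partial batches are rounded up, which reproduces the logged 15,892. Note that the banner reports 127,125 examples, i.e. the training split of the unfiltered dataset rather than the 9,525 filtered rows printed earlier, so `train_dataset` appears to have been built before the `None` filter was in place when this cell ran; the filtered set would give far fewer steps.

```python
import math

num_examples = 127_125   # "Num examples" in the Unsloth banner
per_device_batch, grad_accum, num_gpus, epochs = 4, 4, 1, 2

effective_batch = per_device_batch * grad_accum * num_gpus    # 16
steps_per_epoch = math.ceil(num_examples / effective_batch)   # 7,946 (last batch is partial)
total_steps = steps_per_epoch * epochs                        # 15,892

# Same settings on the 9,525 filtered training rows
filtered_total = math.ceil(9_525 / effective_batch) * epochs  # 1,192

print(effective_batch, steps_per_epoch, total_steps, filtered_total)
```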
