In [None]:
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import StandardScaler
import importlib
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
%matplotlib inline
import GnnScheduleDataset as GnnScheduleDataset_Module
import MultiCriteriaGNNModel as MultiCriteriaGNNModel_Module

importlib.reload(GnnScheduleDataset_Module) # in case of updates
importlib.reload(MultiCriteriaGNNModel_Module) # in case of updates

from GnnScheduleDataset import GnnScheduleDataset
from MultiCriteriaGNNModel import MultiCriteriaGNNModel
import subprocess
import sys

#helper for manually (re)installing a package from inside the notebook
def install_package(package_name, use_index_url=True,
                    index_url="https://download.pytorch.org/whl/cu126"):
    """Force-reinstall a package into the interpreter running this notebook.

    Any installed copy is removed first with ``pip uninstall -y``; the package
    is then installed again with ``pip install``. Both commands are run as
    ``sys.executable -m pip`` so they act on the current interpreter's
    environment.

    Args:
        package_name: Name passed verbatim to pip.
        use_index_url: When True, install with ``--index-url index_url``;
            when False, install without an explicit index.
        index_url: Package index used when ``use_index_url`` is True.
            Defaults to the PyTorch ``cu126`` wheel index.

    Raises:
        subprocess.CalledProcessError: If either pip command exits non-zero
            (both are run with ``check=True``).
    """
    print(f"Installing {package_name}...")
    pip = [sys.executable, "-m", "pip"]

    #remove whatever is installed first so the install below always starts clean
    subprocess.run(pip + ["uninstall", "-y", package_name], check=True, text=True)

    #run: python.exe -m pip install [package_name] [--index-url ...]
    install_cmd = pip + ["install", package_name]
    if use_index_url:
        install_cmd += ["--index-url", index_url]
    subprocess.run(install_cmd, check=True, text=True)

    print(f"Successfully installed {package_name}!")

# Try to import, if it fails, install it
# try:
#     import torch
#     print("Torch is already available.")
# except: 
#     #install_package('torch')
#     # After installing, you must use importlib to refresh or restart the script
#     import torch
#     print("Torch imported successfully after installation.")

#torch-scatter torch-sparse torch-cluster torch-spline-conv pyg-lib
#torch-scatter torch-sparse torch-cluster torch-spline-conv
# install_package('torch-scatter')
# install_package('torch-sparse')
# install_package('torch-cluster')
# install_package('torch-spline-conv')

# try:
#     import torch_geometric
#     print("Torch is already available.")
# except: 
#     install_package('torch_geometric', False)
#     # After installing, you must use importlib to refresh or restart the script
#     import torch_geometric
#     print("Torch imported successfully after installation.")


#input data locations (despite the *_DIR suffix, the first four are CSV file paths)
MISSION_BATCH_DIR = "./datasets/mini-batch/Batch10M_distanced.csv"
MISSION_BATCH_TRAVEL_DIR = "./datasets/mini-batch/Batch10M_travel_distanced.csv"
UDC_TYPES_DIR = "./datasets/WM_UDC_TYPE.csv"
FORK_LIFTS_DIR = "./datasets/ForkLifts10W.csv"
#MISSION_TYPES_DIR = "./datasets/MissionTypes.csv"

#directory holding the schedule instances
SCHEDULE_DIR = "./schedules/mini-batch/"

#training hyperparameters
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10
BATCH_SIZE = 8  #ideally equal to the number of mini-batch instances


### Loss Definition 

In [8]:
def _bce_as_column(pred, target):
    """Binary cross-entropy between ``pred`` and ``target``, both as [N, 1] columns."""
    pred = pred.reshape(-1, 1)
    #F.binary_cross_entropy needs the target to match the input's shape and to be
    #floating point; labels stored as integers or predictions given as a flat
    #vector would otherwise make it raise
    target = target.reshape(-1, 1).to(pred.dtype)
    return F.binary_cross_entropy(pred, target)


def weighted_loss(predictions, ground_truth, u_batch):
    """
    computes weighted BCE loss for activation, assignment, and sequence heads.
    total Loss = Beta * act_loss + alpha * (assign_loss + seq_loss)

    Args:
        predictions: mapping with keys 'activation', 'assignment' and 'sequence',
            each holding a tensor of probabilities in [0, 1] (required by BCE).
        ground_truth: object indexable by 'operator',
            ('operator', 'assign', 'order') and ('order', 'to', 'order'), where
            each entry exposes its labels as a ``.y`` tensor.
        u_batch: 2-D tensor; column 0 is read as alpha and column 1 as beta.

    Returns:
        (total_loss, act_loss, assign_loss, seq_loss): ``total_loss`` is a
        tensor that can be back-propagated; the three per-head losses are
        plain Python floats intended for logging.
    """
    #per-head BCE losses (predictions and labels are both brought to [N, 1])
    loss_act = _bce_as_column(
        predictions['activation'], ground_truth['operator'].y)
    loss_assign = _bce_as_column(
        predictions['assignment'], ground_truth['operator', 'assign', 'order'].y)
    loss_seq = _bce_as_column(
        predictions['sequence'], ground_truth['order', 'to', 'order'].y)

    #extract alpha/beta (mean over the rows of u_batch, i.e. one shared weight
    #pair for the whole batch)
    alpha = u_batch[:, 0].mean()
    beta = u_batch[:, 1].mean()

    #weighted sum
    #Note that alpha/beta need to be scaled down if they are large (e.g. 100) to prevent explosion
    #or rely on the optimizer (Adam) to handle scaling.
    total_loss = (beta * loss_act) + (alpha * (loss_assign + loss_seq))

    return total_loss, loss_act.item(), loss_assign.item(), loss_seq.item()

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

#init dataset
dataset = GnnScheduleDataset(
    schedule_dir=SCHEDULE_DIR,
    mission_base_path=MISSION_BATCH_DIR,
    edge_base_path=MISSION_BATCH_TRAVEL_DIR,
    pallet_types_file_path=UDC_TYPES_DIR,
    fork_path=FORK_LIFTS_DIR
)

print(f"found {len(dataset)} valid schedule instances.")

#how often (in batches) the per-batch loss breakdown is logged
LOG_EVERY_N_BATCHES = 2

if len(dataset) == 0:
    #report the empty case explicitly instead of silently doing nothing
    print("no valid schedule instances found; skipping training.")
else:
    #create DataLoader using the dataset
    #batch_size can be > 1 to train on multiple graphs at once
    #built inside the guard: with shuffle=True the loader's random sampler
    #rejects an empty dataset, so constructing it before the emptiness check
    #would raise before the check is ever reached
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    #init model; the graph metadata is taken from the first sample
    sample_data = dataset[0]
    model = MultiCriteriaGNNModel(
        metadata=sample_data.metadata(),
        hidden_dim=64,
        num_layers=3,
        heads=4
    ).to(device)

    #adam optimizer is a standard for GNNs
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0

        for batch_idx, batch in tqdm(enumerate(loader), total=len(loader), desc=f"epoch {epoch}/{NUM_EPOCHS}"):
            batch = batch.to(device)
            optimizer.zero_grad()

            #construct batch_dict from the per-node `batch` vectors of both node types
            batch_dict_arg = {
                'operator': batch['operator'].batch,
                'order': batch['order'].batch
            }

            #forward pass
            preds = model(
                batch.x_dict,
                batch.edge_index_dict,
                batch.edge_attr_dict,
                batch.u,
                batch_dict=batch_dict_arg
            )

            #backward step and optimization
            loss, l_act, l_assign, l_seq = weighted_loss(preds, batch, batch.u)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_epoch_loss += batch_loss

            #print mini-batch progress; tqdm.write keeps the message from
            #being interleaved with the progress bar
            if batch_idx % LOG_EVERY_N_BATCHES == 0:
                tqdm.write(f"[Batch {batch_idx}] loss: {batch_loss:.4f} (act_loss: {l_act:.3f}, assign_loss: {l_assign:.3f}, seq_loss: {l_seq:.3f})")

        #plain mean over batches (each batch counts equally, whatever its size)
        avg_loss = total_epoch_loss / len(loader)
        print(f"Epoch {epoch} complete. average loss: {avg_loss:.4f}")

Using device: cuda
found 9 valid schedule instances.


epoch 1/10: 100%|██████████| 2/2 [00:00<00:00,  7.53it/s]


[Batch 0] loss: 75.3580 (act_loss: 0.727, assign_loss: 2.331, seq_loss: 0.316)
Epoch 1 complete. average loss: 69.9799


epoch 2/10: 100%|██████████| 2/2 [00:00<00:00,  8.29it/s]


[Batch 0] loss: 74.0735 (act_loss: 0.724, assign_loss: 1.389, seq_loss: 0.253)
Epoch 2 complete. average loss: 82.4990


epoch 3/10: 100%|██████████| 2/2 [00:00<00:00,  8.18it/s]


[Batch 0] loss: 70.0956 (act_loss: 0.691, assign_loss: 0.749, seq_loss: 0.266)
Epoch 3 complete. average loss: 73.4632


epoch 4/10: 100%|██████████| 2/2 [00:00<00:00,  8.45it/s]


[Batch 0] loss: 68.3060 (act_loss: 0.676, assign_loss: 0.442, seq_loss: 0.262)
Epoch 4 complete. average loss: 67.8510


epoch 5/10: 100%|██████████| 2/2 [00:00<00:00,  8.47it/s]


[Batch 0] loss: 68.6943 (act_loss: 0.682, assign_loss: 0.241, seq_loss: 0.271)
Epoch 5 complete. average loss: 69.2531


epoch 6/10: 100%|██████████| 2/2 [00:00<00:00,  7.05it/s]


[Batch 0] loss: 66.5053 (act_loss: 0.661, assign_loss: 0.181, seq_loss: 0.258)
Epoch 6 complete. average loss: 68.2081


epoch 7/10: 100%|██████████| 2/2 [00:00<00:00,  6.21it/s]


[Batch 0] loss: 65.5304 (act_loss: 0.651, assign_loss: 0.171, seq_loss: 0.245)
Epoch 7 complete. average loss: 63.9690


epoch 8/10: 100%|██████████| 2/2 [00:00<00:00,  8.45it/s]


[Batch 0] loss: 66.4515 (act_loss: 0.660, assign_loss: 0.172, seq_loss: 0.246)
Epoch 8 complete. average loss: 58.7624


epoch 9/10: 100%|██████████| 2/2 [00:00<00:00,  7.92it/s]


[Batch 0] loss: 66.0396 (act_loss: 0.656, assign_loss: 0.174, seq_loss: 0.253)
Epoch 9 complete. average loss: 63.8811


epoch 10/10: 100%|██████████| 2/2 [00:00<00:00,  8.58it/s]

[Batch 0] loss: 68.1502 (act_loss: 0.677, assign_loss: 0.183, seq_loss: 0.244)
Epoch 10 complete. average loss: 59.3330



