In [7]:
# gnn_link_predictor_pipeline.py
import random
import itertools
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv, Linear
from torch_geometric.loader import LinkNeighborLoader

# -------------------------
# Utilities
# -------------------------
def edit_distance(a: str, b: str) -> int:
    """Return the Levenshtein distance between ``a`` and ``b``.

    The distance is the minimum number of single-character deletions,
    insertions and substitutions that turn ``a`` into ``b``. Only the
    previous row of the dynamic-programming table is kept.
    """
    # Distances from the empty prefix of ``a`` to every prefix of ``b``.
    prev_row = list(range(len(b) + 1))
    for row_no, ch_a in enumerate(a, start=1):
        # Column 0: turning a[:row_no] into "" takes row_no deletions.
        cur_row = [row_no]
        for col_no, ch_b in enumerate(b, start=1):
            delete = prev_row[col_no] + 1
            insert = cur_row[col_no - 1] + 1
            substitute = prev_row[col_no - 1] + (0 if ch_a == ch_b else 1)
            cur_row.append(min(delete, insert, substitute))
        prev_row = cur_row
    return prev_row[-1]

def jaccard(a: set, b: set) -> float:
    """Return the Jaccard similarity |a & b| / |a | b| of two sets.

    Two empty sets are defined to be identical (similarity 1.0).
    """
    combined = a | b
    if not combined:
        return 1.0
    return len(a & b) / len(combined)

# -------------------------
# Synthetic data
# -------------------------
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

# Reference catalogue: model number -> exact set of components it ships with.
model_dict = {
    "MX100": {"motor", "frame", "sensorA", "battery"},
    "MX200": {"motor", "frame", "sensorB", "battery", "shield"},
    "AX1":   {"motor", "frame", "sensorA", "battery", "antenna"},
    "AX2":   {"motor", "frame", "sensorC", "battery"},
    "BX9":   {"motor", "frame", "battery"}
}

# Build synthetic historical orders for training (noisy)
history_orders = []
component_pool = set().union(*model_dict.values()).union({"cable","extraPart","sensorD"})
# Sets of str iterate in an order that depends on per-process hash
# randomisation, so drawing from list(some_set) is NOT reproducible even with
# a fixed seed. Always draw from a sorted sequence instead.
component_choices = sorted(component_pool)
model_keys = list(model_dict.keys())
for i in range(500):
    model = random.choice(model_keys)
    base = set(model_dict[model])
    # randomly drop one component (12% of orders, never emptying the order)
    if random.random() < 0.12 and len(base) > 1:
        base = set(random.sample(sorted(base), len(base) - 1))
    # randomly add one component from the pool (12% of orders; may already be present)
    if random.random() < 0.12:
        base = base.union({random.choice(component_choices)})
    history_orders.append({"order_id": f"H-{i:03d}", "model": model, "components": base})

# Build a few example query orders to run inference on later
example_orders = [
    {"order_id": "O-001", "model_number": "MX100", "components": {"motor", "frame", "sensorA", "battery"}},  # exact -> ACCEPT
    {"order_id": "O-002", "model_number": "MX100", "components": {"motor", "frame", "sensorA"}},  # exact model exists but missing -> REJECT
    {"order_id": "O-003", "model_number": "MX1X", "components": {"motor", "frame", "sensorA", "antenna"}},  # non-exact -> bootstrap predict
    {"order_id": "O-004", "model_number": "AXX", "components": {"motor", "frame", "battery", "shield"}},  # non-exact -> bootstrap
    {"order_id": "O-005", "model_number": "BX9",  "components": {"motor", "frame", "battery", "extraPart"}}   # exact model exists but extra -> REJECT
]

# -------------------------
# Build HeteroData graph (order nodes and component nodes)
# -------------------------
# Create index mapping for components and orders
# Component vocabulary: the pool plus anything seen in an order, sorted so that
# node indices are stable.
all_components = sorted(component_pool.union(*(o["components"] for o in history_orders)))
comp_to_idx = {name: pos for pos, name in enumerate(all_components)}

# Order nodes are the historical orders, in the order they were generated.
order_nodes = [ho["order_id"] for ho in history_orders]
order_to_idx = {oid: pos for pos, oid in enumerate(order_nodes)}

data = HeteroData()

num_order_nodes = len(order_nodes)
num_comp_nodes = len(all_components)
order_feat_dim = 32
comp_feat_dim = 32

# Node features are random vectors drawn once here. They are plain tensors
# created without requires_grad, i.e. fixed inputs rather than trained embeddings.
data['order'].x = torch.randn((num_order_nodes, order_feat_dim), dtype=torch.float)
data['component'].x = torch.randn((num_comp_nodes, comp_feat_dim), dtype=torch.float)

# One (order index, component index) pair per component listed on an order.
edge_pairs = [
    (order_to_idx[ho['order_id']], comp_to_idx[c])
    for ho in history_orders
    for c in ho['components']
    if c in comp_to_idx
]
src_orders = [o_pos for o_pos, _ in edge_pairs]
dst_components = [c_pos for _, c_pos in edge_pairs]

# order -> has -> component
data['order', 'has', 'component'].edge_index = torch.tensor([src_orders, dst_components], dtype=torch.long)
# reverse relation so messages also flow component -> order
data['component', 'rev_has', 'order'].edge_index = torch.tensor([dst_components, src_orders], dtype=torch.long)

# -------------------------
# Model: Hetero GNN + link predictor head
# -------------------------
class HeteroGNNEncoder(torch.nn.Module):
    """Two-layer heterogeneous GraphSAGE encoder for the order/component graph.

    Each layer is a ``HeteroConv`` holding one ``SAGEConv`` per relation
    (``order -has-> component`` and ``component -rev_has-> order``), combined
    with ``aggr='mean'``. A ReLU sits between the two layers, and a separate
    linear projection per node type is applied to the second layer's output.

    Args:
        hidden_channels: output width of the first layer.
        out_channels: output width of the second layer and of both projections.
    """

    _RELATIONS = (
        ('order','has','component'),
        ('component','rev_has','order'),
    )

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = self._make_layer(hidden_channels)
        self.conv2 = self._make_layer(out_channels)
        # one linear projection per node type (same width in and out)
        self.order_lin = Linear(out_channels, out_channels)
        self.comp_lin  = Linear(out_channels, out_channels)

    @classmethod
    def _make_layer(cls, channels):
        """Build one HeteroConv with a SAGEConv of width ``channels`` per relation."""
        # (-1, -1): source/target input sizes are left unspecified here.
        convs = {rel: SAGEConv((-1, -1), channels) for rel in cls._RELATIONS}
        return HeteroConv(convs, aggr='mean')

    def forward(self, x_dict, edge_index_dict):
        """Return ``{'order': ..., 'component': ...}`` embeddings for all nodes."""
        hidden = self.conv1(x_dict, edge_index_dict)
        hidden = {node_type: F.relu(feat) for node_type, feat in hidden.items()}
        out = self.conv2(hidden, edge_index_dict)
        x_order = self.order_lin(out['order'])
        x_comp  = self.comp_lin(out['component'])
        return {'order': x_order, 'component': x_comp}

class LinkPredictor(torch.nn.Module):
    """Score (order, component) pairs with a small MLP.

    The two embeddings are concatenated along the last dimension, passed
    through ``Linear -> ReLU -> Linear`` and squashed with a sigmoid, giving
    one probability per pair.

    Args:
        emb_dim: width of each of the two input embeddings.
    """

    def __init__(self, emb_dim):
        super().__init__()
        layers = [
            nn.Linear(2 * emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, order_emb, comp_emb):
        """Return sigmoid scores with the trailing size-1 dimension removed."""
        pair = torch.cat((order_emb, comp_emb), dim=-1)
        logits = self.mlp(pair)
        return logits.sigmoid().squeeze(-1)

# -------------------------
# Negative sampling helpers and dataset preparation
# -------------------------
# Build adjacency set for fast negative sampling
# Every existing order->component edge is a positive example (order_idx, comp_idx).
positive_pairs = list(zip(src_orders, dst_components))

# order index -> set of component indices it is linked to; used to reject
# candidate negatives that are actually real edges.
adj = defaultdict(set)
for order_pos, comp_pos in positive_pairs:
    adj[order_pos].add(comp_pos)

def sample_negative_for_order(o_idx, k=1):
    """Sample ``k`` component indices that are NOT linked to order ``o_idx``.

    Uses rejection sampling against the module-level ``adj`` map: a uniform
    index in ``range(num_comp_nodes)`` is drawn and kept only if it is not a
    known neighbour. Sampling is with replacement, so the result may contain
    repeats.

    Args:
        o_idx: index of the order node.
        k: number of negatives to return.

    Returns:
        A list of ``k`` component indices.

    Raises:
        ValueError: if negatives are requested but the order is already linked
            to every component, in which case rejection sampling could never
            succeed (the previous behaviour was an infinite loop).
    """
    linked = adj[o_idx]
    # adj only holds valid component indices, so a neighbour set as large as
    # the component vocabulary means no negative exists.
    if k > 0 and len(linked) >= num_comp_nodes:
        raise ValueError(
            f"order {o_idx} is linked to all {num_comp_nodes} components; "
            "no negative component can be sampled"
        )
    negatives = []
    while len(negatives) < k:
        c = random.randrange(num_comp_nodes)
        if c not in linked:
            negatives.append(c)
    return negatives

# Build training triples: for each positive sample, add a negative sample
# Training triples (order_idx, comp_idx, label): every positive edge is
# immediately followed by one sampled negative for the same order.
train_pairs = []
for o, c in positive_pairs:
    train_pairs.append((o, c, 1))
    train_pairs.extend((o, nc, 0) for nc in sample_negative_for_order(o, k=1))

# Convert to a single long tensor, one row per triple (train on all, small-scale)
pairs_tensor = torch.tensor(train_pairs, dtype=torch.long)

# -------------------------
# Training loop
# -------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
enc = HeteroGNNEncoder(hidden_channels=64, out_channels=64).to(device)
pred = LinkPredictor(emb_dim=64).to(device)

# Encoder and predictor are optimised jointly by a single Adam instance.
trainable_params = [*enc.parameters(), *pred.parameters()]
opt = torch.optim.Adam(trainable_params, lr=0.01, weight_decay=1e-4)
bce_loss = nn.BCELoss()

data = data.to(device)

# Plain dicts (edge type -> edge_index, node type -> features) passed to the
# encoder on every forward call.
edge_index_dict = {
    rel: data[rel].edge_index
    for rel in (('order','has','component'), ('component','rev_has','order'))
}
x_dict = {node_type: data[node_type].x for node_type in ('order', 'component')}

pairs = pairs_tensor.to(device)

def evaluate_model():
    """Return the ROC-AUC of the current model on the module-level ``train_pairs``.

    Puts ``enc`` and ``pred`` into eval mode, encodes the whole graph once and
    scores every ``(order_idx, comp_idx, label)`` triple. Note that these are
    the same pairs the model is trained on, so this is a training-fit metric,
    not a held-out estimate.

    All pairs are scored in a single batched predictor call instead of one
    forward pass (and one device->host transfer) per pair.
    """
    enc.eval(); pred.eval()
    with torch.no_grad():
        emb = enc(x_dict, edge_index_dict)
        order_embs = emb['order']
        comp_embs = emb['component']
        # One row per triple: columns are order index, component index, label.
        triples = torch.tensor(train_pairs, dtype=torch.long, device=order_embs.device)
        scores = pred(order_embs[triples[:, 0]], comp_embs[triples[:, 1]])
    y_true = triples[:, 2].tolist()
    y_score = scores.tolist()
    auc = roc_auc_score(y_true, y_score)
    return auc

# Training
epochs = 40
# Full-batch training: the index columns and labels are the same every epoch,
# so slice them out of `pairs` once.
o_idx = pairs[:,0]
c_idx = pairs[:,1]
labels = pairs[:,2].float().to(device)
for epoch in range(1, epochs+1):
    enc.train()
    pred.train()
    opt.zero_grad()
    # Re-encode the whole graph, then gather the embeddings of each training pair.
    emb = enc(x_dict, edge_index_dict)
    order_embs, comp_embs = emb['order'], emb['component']
    o_emb = order_embs[o_idx]
    c_emb = comp_embs[c_idx]
    scores = pred(o_emb, c_emb)
    loss = bce_loss(scores, labels)
    loss.backward()
    opt.step()
    # Report on the first epoch and then every fifth one.
    if epoch == 1 or epoch % 5 == 0:
        auc = evaluate_model()
        print(f"Epoch {epoch:03d} - Loss {loss.item():.4f} - AUC {auc:.4f}")

# final evaluation on the same pairs used for training
final_auc = evaluate_model()
print("Final AUC on training pairs:", final_auc)

# -------------------------
# Inference logic: nearest-match with edit distance + exact/non-exact policy
# -------------------------
# helper to find nearest model by edit distance
def nearest_model_by_edit_distance(model_number: str, model_dict: dict):
    """Return ``(closest_key, distance)`` for ``model_number`` among ``model_dict`` keys.

    Closeness is measured with ``edit_distance``. On ties the key that comes
    first in the dict's iteration order wins. An empty ``model_dict`` yields
    ``(None, None)``.
    """
    scored = [(edit_distance(model_number, name), name) for name in model_dict]
    if not scored:
        return None, None
    # min() keeps the first of several equal minima, matching a strict "<" scan.
    best_dist, best = min(scored, key=lambda entry: entry[0])
    return best, best_dist

# link scoring function using trained model (returns probability)
def score_link_for_order_component(order_id: str, comp_name: str):
    """Return the predicted probability that order ``order_id`` contains ``comp_name``.

    Args:
        order_id: id of the order to score. If it is one of the training
            orders (a key of ``order_to_idx``) its encoder embedding is used;
            otherwise an all-zeros vector stands in for the order embedding.
            No bootstrap/context averaging is done here; see
            ``score_component_given_bootstrap`` for that.
        comp_name: component name to score.

    Returns:
        A float produced by ``pred``, or ``0.0`` when ``comp_name`` is not a
        known component.
    """
    enc.eval(); pred.eval()
    # Unknown component -> low probability. Decide this before running the
    # full-graph encoder pass, which would be wasted work.
    cidx = comp_to_idx.get(comp_name)
    if cidx is None:
        return 0.0
    with torch.no_grad():
        emb = enc(x_dict, edge_index_dict)
        order_embs = emb['order']
        comp_embs = emb['component']

        oidx = order_to_idx.get(order_id)
        if oidx is None:
            # Order not seen in training: fall back to a zeros vector of the
            # same width as the order embeddings.
            o_vec = torch.zeros((1, order_embs.shape[1]), device=device)
        else:
            o_vec = order_embs[oidx:oidx+1]

        cvec = comp_embs[cidx:cidx+1]
        return float(pred(o_vec, cvec).cpu().item())

# A helper that scores components given a bootstrap set: create a temporary order embedding as mean of bootstrap components embeddings
def score_component_given_bootstrap(bootstrap_components: set, candidate: str):
    """Score ``candidate`` against a pseudo-order built from ``bootstrap_components``.

    The pseudo-order embedding is the mean of the component embeddings of the
    known bootstrap components. When there are none (empty set, or no member is
    in ``comp_to_idx``) the mean of all order embeddings is used instead.
    Returns ``0.0`` if ``candidate`` itself is not a known component.
    """
    enc.eval(); pred.eval()
    with torch.no_grad():
        emb = enc(x_dict, edge_index_dict)
        comp_embs = emb['component']
        if candidate not in comp_to_idx:
            return 0.0
        cand_pos = comp_to_idx[candidate]
        cvec = comp_embs[cand_pos:cand_pos + 1]

        known = [comp_to_idx[c] for c in bootstrap_components if c in comp_to_idx]
        if known:
            # Component-space mean used directly as the order embedding
            # (no transform into order space is applied).
            context = comp_embs[known].mean(dim=0, keepdim=True)
        else:
            # no usable context -> global prior: mean of all order embeddings
            context = emb['order'].mean(dim=0, keepdim=True)
        return float(pred(context, cvec).cpu().item())

# Now run through example orders applying the decision rules
print("\n--- Inference decisions for example orders ---\n")
for qo in example_orders:
    qid, q_model = qo['order_id'], qo['model_number']
    q_comps = set(qo['components'])
    nearest, dist = nearest_model_by_edit_distance(q_model, model_dict)
    model_comps = model_dict[nearest]

    print(f"Order {qid}: model_number='{q_model}', nearest_model='{nearest}' (edit_dist={dist})")

    # Exact model-number path: the decision is a pure set comparison.
    exact_model = q_model in model_dict and q_model == nearest
    if exact_model:
        if q_comps == model_comps:
            print(" -> DECISION: ACCEPT (exact model number & exact component set match)\n")
        else:
            print(" -> DECISION: REJECT (model number exists but component sets differ)\n")
        continue

    # Non-exact path: the shared components act as context ("bootstrap") and the
    # link predictor scores each component on which query and model disagree.
    bootstrap = q_comps & model_comps
    query_uniques = q_comps - model_comps
    model_uniques = model_comps - q_comps

    print(" Bootstrap (intersection):", bootstrap)

    # components only in the query: are they plausible given the bootstrap?
    if not query_uniques:
        print(" Query has no unique components (no extras)")
    else:
        print(" Query-unique components (scores):")
        for c in sorted(query_uniques):
            p = score_component_given_bootstrap(bootstrap, c)
            if p < 0.5:
                flag = "suspect_extra"
            else:
                flag = "ok_extra"
            print(f"   {c}: p={p:.3f} -> {flag}")

    # components only in the nearest model: are they likely missing from the query?
    if not model_uniques:
        print(" Model has no unique components (query covers model components)")
    else:
        print(" Model-unique components (scores):")
        for c in sorted(model_uniques):
            p = score_component_given_bootstrap(bootstrap, c)
            if p > 0.5:
                flag = "should_be_included"
            else:
                flag = "could_be_missing"
            print(f"   {c}: p={p:.3f} -> {flag}")

    print("")  # newline between orders

# End of script

Epoch 001 - Loss 0.6914 - AUC 0.9179
Epoch 005 - Loss 0.3606 - AUC 0.9275
Epoch 010 - Loss 0.3244 - AUC 0.9213
Epoch 015 - Loss 0.2945 - AUC 0.9443
Epoch 020 - Loss 0.2329 - AUC 0.9719
Epoch 025 - Loss 0.1716 - AUC 0.9872
Epoch 030 - Loss 0.1416 - AUC 0.9873
Epoch 035 - Loss 0.1164 - AUC 0.9892
Epoch 040 - Loss 0.1035 - AUC 0.9904
Final AUC on training pairs: 0.9904479091895186

--- Inference decisions for example orders ---

Order O-001: model_number='MX100', nearest_model='MX100' (edit_dist=0)
 -> DECISION: ACCEPT (exact model number & exact component set match)

Order O-002: model_number='MX100', nearest_model='MX100' (edit_dist=0)
 -> DECISION: REJECT (model number exists but component sets differ)

Order O-003: model_number='MX1X', nearest_model='MX100' (edit_dist=2)
 Bootstrap (intersection): {'motor', 'frame', 'sensorA'}
 Query-unique components (scores):
   antenna: p=0.482 -> suspect_extra
 Model-unique components (scores):
   battery: p=0.994 -> should_be_included

Order O-00

In [2]:
!pip install torch torch_geometric scikit-learn
# NOTE: torch_geometric installation often requires following their install instructions:
# https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.1/63.1 kB 4.5 MB/s eta 0:00:00
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 55.6 MB/s eta 0:00:00
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
# extended_gnn_pipeline_with_validation.py
# Full script: builds the hetero graph, trains a GNN link predictor, then validates
# on synthetic "good" and "bad" holdout orders producing a confusion matrix.

import random
import itertools
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, confusion_matrix
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv, Linear

# -------------------------
# Utilities
# -------------------------
def edit_distance(a: str, b: str) -> int:
    """Levenshtein distance between ``a`` and ``b`` (full-table dynamic programming)."""
    rows, cols = len(a) + 1, len(b) + 1
    # table[r][c] = distance between a[:r] and b[:c]; the first row and column
    # are the cost of building/erasing a prefix from/to the empty string.
    table = [
        [r + c if r == 0 or c == 0 else 0 for c in range(cols)]
        for r in range(rows)
    ]
    for r in range(1, rows):
        for c in range(1, cols):
            mismatch = 0 if a[r - 1] == b[c - 1] else 1
            table[r][c] = min(
                table[r - 1][c] + 1,             # deletion
                table[r][c - 1] + 1,             # insertion
                table[r - 1][c - 1] + mismatch,  # substitution
            )
    return table[-1][-1]

def jaccard(a: set, b: set) -> float:
    """Jaccard similarity of two sets; two empty sets count as identical (1.0)."""
    either = a | b
    return len(a & b) / len(either) if either else 1.0

# -------------------------
# Synthetic data (same as earlier)
# -------------------------
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

# Reference catalogue: model number -> exact set of components it ships with.
model_dict = {
    "MX100": {"motor", "frame", "sensorA", "battery"},
    "MX200": {"motor", "frame", "sensorB", "battery", "shield"},
    "AX1":   {"motor", "frame", "sensorA", "battery", "antenna"},
    "AX2":   {"motor", "frame", "sensorC", "battery"},
    "BX9":   {"motor", "frame", "battery"}
}

# Historical orders used for training
history_orders = []
component_pool = set().union(*model_dict.values()).union({"cable","extraPart","sensorD"})
# Sets of str iterate in an order that depends on per-process hash
# randomisation, so drawing from list(some_set) is NOT reproducible even with
# a fixed seed. Always draw from a sorted sequence instead.
component_choices = sorted(component_pool)
model_keys = list(model_dict.keys())
for i in range(500):
    model = random.choice(model_keys)
    base = set(model_dict[model])
    # noise: drop one component (12% of orders, never emptying the order)
    if random.random() < 0.12 and len(base) > 1:
        base = set(random.sample(sorted(base), len(base) - 1))
    # noise: add one component from the pool (12% of orders; may already be present)
    if random.random() < 0.12:
        base = base.union({random.choice(component_choices)})
    history_orders.append({"order_id": f"H-{i:03d}", "model": model, "components": base})

# Example queries (for manual inspection later)
example_orders = [
    {"order_id": "O-001", "model_number": "MX100", "components": {"motor", "frame", "sensorA", "battery"}},  # exact -> ACCEPT
    {"order_id": "O-002", "model_number": "MX100", "components": {"motor", "frame", "sensorA"}},  # exact model exists but missing -> REJECT
    {"order_id": "O-003", "model_number": "MX1X", "components": {"motor", "frame", "sensorA", "antenna"}},  # non-exact -> bootstrap predict
    {"order_id": "O-004", "model_number": "AXX", "components": {"motor", "frame", "battery", "shield"}},  # non-exact -> bootstrap
    {"order_id": "O-005", "model_number": "BX9",  "components": {"motor", "frame", "battery", "extraPart"}}   # exact model exists but extra -> REJECT
]

# -------------------------
# Build HeteroData graph (order nodes and component nodes)
# -------------------------
# Sorted component vocabulary (pool plus anything seen in an order) -> node index
all_components = sorted(component_pool.union(*(o["components"] for o in history_orders)))
comp_to_idx = {name: pos for pos, name in enumerate(all_components)}

# One order node per historical order, in generation order
order_nodes = [ho["order_id"] for ho in history_orders]
order_to_idx = {oid: pos for pos, oid in enumerate(order_nodes)}

data = HeteroData()

num_order_nodes = len(order_nodes)
num_comp_nodes = len(all_components)
order_feat_dim = 32
comp_feat_dim = 32

# Node features: random vectors drawn once. These are plain tensors created
# without requires_grad, i.e. fixed inputs rather than trained embeddings.
data['order'].x = torch.randn((num_order_nodes, order_feat_dim), dtype=torch.float)
data['component'].x = torch.randn((num_comp_nodes, comp_feat_dim), dtype=torch.float)

# One (order index, component index) pair per component listed on an order
edge_pairs = [
    (order_to_idx[ho['order_id']], comp_to_idx[c])
    for ho in history_orders
    for c in ho['components']
    if c in comp_to_idx
]
src_orders = [o_pos for o_pos, _ in edge_pairs]
dst_components = [c_pos for _, c_pos in edge_pairs]

# Forward relation plus its reverse, so messages flow in both directions
data['order', 'has', 'component'].edge_index = torch.tensor([src_orders, dst_components], dtype=torch.long)
data['component', 'rev_has', 'order'].edge_index = torch.tensor([dst_components, src_orders], dtype=torch.long)

# -------------------------
# Hetero GNN + Link Predictor
# -------------------------
class HeteroGNNEncoder(torch.nn.Module):
    """Two-layer heterogeneous GraphSAGE encoder for the order/component graph.

    Both layers are ``HeteroConv`` modules with one ``SAGEConv`` per relation
    (``order -has-> component`` and ``component -rev_has-> order``) and
    ``aggr='mean'``. ReLU is applied after the first layer; the second layer's
    output is passed through a separate linear projection per node type.

    Args:
        hidden_channels: output width of the first layer.
        out_channels: output width of the second layer and of both projections.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()

        def sage_layer(width):
            # (-1, -1): source/target input sizes are left unspecified here.
            return HeteroConv({
                rel: SAGEConv((-1, -1), width)
                for rel in (('order','has','component'), ('component','rev_has','order'))
            }, aggr='mean')

        self.conv1 = sage_layer(hidden_channels)
        self.conv2 = sage_layer(out_channels)
        self.order_lin = Linear(out_channels, out_channels)
        self.comp_lin  = Linear(out_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return ``{'order': ..., 'component': ...}`` embeddings for all nodes."""
        first = self.conv1(x_dict, edge_index_dict)
        activated = {node_type: F.relu(feat) for node_type, feat in first.items()}
        second = self.conv2(activated, edge_index_dict)
        x_order = self.order_lin(second['order'])
        x_comp  = self.comp_lin(second['component'])
        return {'order': x_order, 'component': x_comp}

class LinkPredictor(torch.nn.Module):
    """MLP head turning an (order, component) embedding pair into a probability.

    Args:
        emb_dim: width of each of the two input embeddings.
    """

    def __init__(self, emb_dim):
        super().__init__()
        hidden = nn.Linear(emb_dim * 2, emb_dim)
        readout = nn.Linear(emb_dim, 1)
        self.mlp = nn.Sequential(hidden, nn.ReLU(), readout)

    def forward(self, order_emb, comp_emb):
        """Concatenate on the last dim, apply the MLP and a sigmoid; drop the size-1 dim."""
        joined = torch.cat((order_emb, comp_emb), dim=-1)
        prob = torch.sigmoid(self.mlp(joined))
        return prob.squeeze(-1)

# -------------------------
# Prepare training pairs (positive + negative)
# -------------------------
# Every existing order->component edge is a positive example (order_idx, comp_idx)
positive_pairs = list(zip(src_orders, dst_components))

# order index -> set of linked component indices (used to filter negatives)
adj = defaultdict(set)
for order_pos, comp_pos in positive_pairs:
    adj[order_pos].add(comp_pos)

def sample_negative_for_order(o_idx, k=1):
    """Sample ``k`` component indices that are NOT linked to order ``o_idx``.

    Rejection sampling against the module-level ``adj`` map: a uniform index in
    ``range(num_comp_nodes)`` is drawn and kept only if it is not a known
    neighbour. Sampling is with replacement, so repeats are possible.

    Args:
        o_idx: index of the order node.
        k: number of negatives to return.

    Returns:
        A list of ``k`` component indices.

    Raises:
        ValueError: if negatives are requested but the order is already linked
            to every component, in which case rejection sampling could never
            succeed (the previous behaviour was an infinite loop).
    """
    linked = adj[o_idx]
    # adj only holds valid component indices, so a neighbour set as large as
    # the component vocabulary means no negative exists.
    if k > 0 and len(linked) >= num_comp_nodes:
        raise ValueError(
            f"order {o_idx} is linked to all {num_comp_nodes} components; "
            "no negative component can be sampled"
        )
    negatives = []
    while len(negatives) < k:
        c = random.randrange(num_comp_nodes)
        if c not in linked:
            negatives.append(c)
    return negatives

# Training triples (order_idx, comp_idx, label): each positive edge is
# immediately followed by one sampled negative for the same order.
train_pairs = []
for o, c in positive_pairs:
    train_pairs.append((o, c, 1))
    train_pairs.extend((o, nc, 0) for nc in sample_negative_for_order(o, k=1))

# One row per triple
pairs_tensor = torch.tensor(train_pairs, dtype=torch.long)

# -------------------------
# Training
# -------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
enc = HeteroGNNEncoder(hidden_channels=64, out_channels=64).to(device)
pred = LinkPredictor(emb_dim=64).to(device)

# Encoder and predictor are optimised jointly by a single Adam instance.
trainable_params = [*enc.parameters(), *pred.parameters()]
opt = torch.optim.Adam(trainable_params, lr=0.01, weight_decay=1e-4)
bce_loss = nn.BCELoss()

data = data.to(device)
# Plain dicts (edge type -> edge_index, node type -> features) fed to the encoder.
edge_index_dict = {
    rel: data[rel].edge_index
    for rel in (('order','has','component'), ('component','rev_has','order'))
}
x_dict = {node_type: data[node_type].x for node_type in ('order', 'component')}
pairs = pairs_tensor.to(device)

def evaluate_model_auc(pairs_list):
    """Return the ROC-AUC of the current model on ``pairs_list``.

    Args:
        pairs_list: ``(order_idx, comp_idx, label)`` integer triples, given
            either as an iterable of triples or as an ``(N, 3)`` tensor.

    Returns:
        The value of ``roc_auc_score`` computed from the labels and the
        predictor's scores.

    ``enc`` and ``pred`` are switched to eval mode and the graph is encoded
    once. All pairs are then scored in a single batched predictor call instead
    of one forward pass (and one device->host transfer) per pair.
    """
    enc.eval(); pred.eval()
    if not torch.is_tensor(pairs_list):
        # Materialise generators / other iterables so they can be tensorised.
        pairs_list = list(pairs_list)
    with torch.no_grad():
        emb = enc(x_dict, edge_index_dict)
        order_embs = emb['order']
        comp_embs = emb['component']
        # One row per triple: columns are order index, component index, label.
        triples = torch.as_tensor(pairs_list, dtype=torch.long, device=order_embs.device)
        scores = pred(order_embs[triples[:, 0]], comp_embs[triples[:, 1]])
    y_true = triples[:, 2].tolist()
    y_score = scores.tolist()
    return roc_auc_score(y_true, y_score)

epochs = 40
# Full-batch training: the index columns and labels are the same every epoch,
# so slice them out of `pairs` once.
o_idx = pairs[:,0]
c_idx = pairs[:,1]
labels = pairs[:,2].float().to(device)
for epoch in range(1, epochs+1):
    enc.train()
    pred.train()
    opt.zero_grad()
    # Re-encode the whole graph, then gather the embeddings of each training pair.
    emb = enc(x_dict, edge_index_dict)
    order_embs, comp_embs = emb['order'], emb['component']
    o_emb = order_embs[o_idx]
    c_emb = comp_embs[c_idx]
    scores = pred(o_emb, c_emb)
    loss = bce_loss(scores, labels)
    loss.backward()
    opt.step()
    # Report on the first epoch and then every fifth one.
    if epoch == 1 or epoch % 5 == 0:
        auc = evaluate_model_auc(train_pairs)
        print(f"Epoch {epoch:03d} - Loss {loss.item():.4f} - AUC {auc:.4f}")

# AUC on the same pairs used for training (not a held-out estimate)
final_auc = evaluate_model_auc(train_pairs)
print("Final AUC on training-like pairs:", final_auc)

# -------------------------
# Inference helper functions (reuse encoder & predictor)
# -------------------------
def nearest_model_by_edit_distance(model_number: str, model_dict: dict):
    """Return ``(closest_key, distance)`` for ``model_number`` among ``model_dict`` keys.

    Closeness is measured with ``edit_distance``. On ties the key that comes
    first in the dict's iteration order wins. An empty ``model_dict`` yields
    ``(None, None)``.
    """
    scored = [(edit_distance(model_number, name), name) for name in model_dict]
    if not scored:
        return None, None
    # min() keeps the first of several equal minima, matching a strict "<" scan.
    best_dist, best = min(scored, key=lambda entry: entry[0])
    return best, best_dist

# Score candidate component given bootstrap context using trained models
def score_component_given_bootstrap(bootstrap_components: set, candidate: str, enc_model, pred_model, x_dict_local, edge_index_local, comp_to_idx_local):
    """Score ``candidate`` against a pseudo-order built from ``bootstrap_components``.

    Args:
        bootstrap_components: component names used as context.
        candidate: component name to score.
        enc_model: encoder; called as ``enc_model(x_dict_local, edge_index_local)``
            and expected to return a dict with ``'order'`` and ``'component'`` entries.
        pred_model: predictor; called as ``pred_model(order_vec, comp_vec)``.
        x_dict_local: node features passed to the encoder.
        edge_index_local: edge indices passed to the encoder.
        comp_to_idx_local: mapping component name -> row in the component embeddings.

    Returns:
        ``0.0`` if ``candidate`` is unknown; otherwise the predictor's score as
        a float. The order-side input is the mean embedding of the known
        bootstrap components, or the mean of all order embeddings when no
        bootstrap component is known.
    """
    enc_model.eval(); pred_model.eval()
    with torch.no_grad():
        emb = enc_model(x_dict_local, edge_index_local)
        comp_embs = emb['component']
        if candidate not in comp_to_idx_local:
            return 0.0
        cand_pos = comp_to_idx_local[candidate]
        cvec = comp_embs[cand_pos:cand_pos + 1]

        known = [comp_to_idx_local[c] for c in bootstrap_components if c in comp_to_idx_local]
        if known:
            # component-space mean used as a proxy order embedding
            context = comp_embs[known].mean(dim=0, keepdim=True)
        else:
            # no usable context -> mean of all order embeddings
            context = emb['order'].mean(dim=0, keepdim=True)
        return float(pred_model(context, cvec).cpu().item())

# Decision function: returns True if accept, False if reject
def decide_accept_or_reject(query_order, model_dict, comp_to_idx_local, enc_model, pred_model, x_dict_local, edge_index_local, threshold=0.5):
    """Decide whether a query order should be accepted.

    Two routes exist:

    * Exact route - the query's model number is a key of ``model_dict``. The
      order is accepted only if its component set equals that model's set.
    * Non-exact route - the closest known model (by edit distance) is used as
      a reference. Components shared with it form the "bootstrap" context.
      The order is rejected if any component it adds scores below
      ``threshold``, or if any reference component it lacks scores above
      ``threshold``; otherwise it is accepted.

    Args:
        query_order: dict with keys ``'model_number'`` and ``'components'``
            (``'order_id'`` may be present but is not used for the decision).
        model_dict: mapping of known model number to its components.
        comp_to_idx_local, enc_model, pred_model, x_dict_local, edge_index_local:
            forwarded unchanged to ``score_component_given_bootstrap``.
        threshold: score cut-off used by both non-exact checks.

    Returns:
        ``(accept, reason)`` where ``reason`` is a dict that always carries
        ``'route'``, ``'nearest'`` and ``'dist'`` plus route-specific details.

    Raises:
        KeyError: on the non-exact route when ``model_dict`` is empty, because
            no nearest model exists to look up.
    """
    qmodel = query_order['model_number']
    qcomps = set(query_order['components'])

    # Exact model-number path. A key that is present is at edit distance 0
    # from itself and every other key is at distance >= 1, so the scan over
    # all models is unnecessary here.
    if qmodel in model_dict:
        # Copy so the reason dict never aliases the caller's model_dict entry.
        model_comps = set(model_dict[qmodel])
        if qcomps == model_comps:
            return True, {"route": "exact_match_components_equal", "nearest": qmodel, "dist": 0}
        return False, {"route": "exact_model_components_differ", "nearest": qmodel, "dist": 0,
                       "model_components": model_comps}

    # Non-exact: bootstrap + link predictor scoring
    nearest, dist = nearest_model_by_edit_distance(qmodel, model_dict)
    model_comps = set(model_dict[nearest])
    bootstrap = qcomps & model_comps
    query_uniques = qcomps - model_comps
    model_uniques = model_comps - qcomps

    def _score(component):
        return score_component_given_bootstrap(bootstrap, component, enc_model, pred_model,
                                               x_dict_local, edge_index_local, comp_to_idx_local)

    # If bootstrap is empty we still proceed; scoring then falls back to its
    # no-context behaviour.
    # Components are visited in sorted order so the reported offender is the
    # same on every run instead of depending on set iteration order.

    # Extra components: reject if any is unlikely given the bootstrap.
    for c in sorted(query_uniques, key=str):
        p = _score(c)
        if p < threshold:
            return False, {"route": "reject_extra", "nearest": nearest, "dist": dist, "bootstrap": bootstrap, "offending_component": c, "prob": p}

    # Missing components: reject if any is likely given the bootstrap.
    for c in sorted(model_uniques, key=str):
        p = _score(c)
        if p > threshold:
            return False, {"route": "reject_missing", "nearest": nearest, "dist": dist, "bootstrap": bootstrap, "missing_component": c, "prob": p}

    # Passed both checks -> accept
    return True, {"route": "non_exact_accept", "nearest": nearest, "dist": dist, "bootstrap": bootstrap}

# -------------------------
# Generate holdout "good" and "bad" orders
# -------------------------
# Good orders: should be accepted by policy. Create:
#  - exact matches for some models
#  - near variants that include plausible optional parts (co-occurring)
# Bad orders:
#  - missing core component(s)
#  - include unlikely extras (random extras that don't co-occur)
# Bad orders are collected further down; start with an empty list.
bad_holdout = []

# 1) exact-good: one order per known model, carrying exactly that model's components
good_holdout = [
    {"order_id": f"G-exact-{pos}", "model_number": name, "components": set(model_dict[name])}
    for pos, name in enumerate(model_keys)
]

# 2) plausible-variant-good: bootstrap will support optional parts (we create extras that co-occur often)
# Build frequency co-occurrence from training history to pick plausible extras
# cooc_counts[(a, b)] is bumped once for every ordered pair of differing
# components (a, b) found together in a history order.
cooc_counts = defaultdict(int)
for ho in history_orders:
    comps = list(ho['components'])
    for a, b in itertools.product(comps, repeat=2):
        if a == b:
            continue
        cooc_counts[(a, b)] += 1

# For each model, pick a common co-occurring part (if exists) and add to make it still "good"
for variant_no, model_name in enumerate(model_keys):
    core_parts = set(model_dict[model_name])
    # Rank every non-core component by its summed co-occurrence with the core parts.
    ranked = []
    for part in all_components:
        if part not in core_parts:
            support = sum(cooc_counts.get((core, part), 0) for core in core_parts)
            if support > 0:
                ranked.append((support, part))
    if not ranked:
        # nothing ever co-occurred with this model's parts; no variant is created
        continue
    # Highest support wins; equal support is broken by tuple ordering on the name.
    _, best_extra = max(ranked)
    # A good order variant: the model's own parts plus the most plausible extra.
    good_holdout.append({
        "order_id": f"G-variant-{variant_no}",
        "model_number": model_name + "X",
        "components": core_parts | {best_extra},
    })

# 3) bad orders: missing required components (drop core parts)
for order_no, model_name in enumerate(model_keys):
    parts = set(model_dict[model_name])
    if len(parts) <= 1:
        # a single-part model has nothing that can be dropped; skip it
        continue
    # Remove one randomly chosen part (no distinction between core and optional here).
    removed = random.choice(list(parts))
    bad_holdout.append({
        "order_id": f"B-missing-{order_no}",
        "model_number": model_name + "BAD",
        "components": parts.difference({removed}),
    })

# 4) bad orders: add unlikely extras (random), pick parts with low co-occurrence
# A component is "rare" when it was paired fewer than 3 times with every component.
rare_candidates = [c for c in all_components if all(cooc_counts.get((b,c),0) < 3 for b in all_components)]
for i in range(len(model_keys)):
    m = random.choice(model_keys)
    base = set(model_dict[m])
    # The extra must not already belong to the model: otherwise the union below
    # equals the model's own component set and the order would be labelled
    # "bad" without containing any extra part at all.
    extras = [c for c in rare_candidates if c not in base]
    if extras:
        rare = random.choice(extras)
    elif "extraPart" not in base:
        rare = "extraPart"
    else:
        # no usable extra for this model; skip rather than emit a mislabelled order
        continue
    bad_holdout.append({"order_id": f"B-extra-{i}", "model_number": m + "BADX", "components": base.union({rare})})

# Ensure holdouts are reasonably sized
print("Generated holdout sizes:", len(good_holdout), "good |", len(bad_holdout), "bad")

# -------------------------
# Validation: run decision for each holdout and produce confusion matrix
# -------------------------
y_true = []
y_pred = []

# Threshold for link predictor decisions
threshold = 0.5

# Run the decision policy over both holdout sets, pairing each set with its
# expected label: 1 for good orders (should be accepted), 0 for bad ones.
for expected_label, holdout in ((1, good_holdout), (0, bad_holdout)):
    for qo in holdout:
        accept, info = decide_accept_or_reject(
            qo, model_dict, comp_to_idx, enc, pred, x_dict, edge_index_dict,
            threshold=threshold,
        )
        y_true.append(expected_label)
        y_pred.append(1 if accept else 0)
        # For debugging, print qo['order_id'], the verdict and `info` here.

# Compute confusion matrix: order of labels default (0,1) -> we'll map to standard
# With labels=[0, 1] the matrix rows are the true label and the columns the
# predicted label, so ravel() yields the counts in the order tn, fp, fn, tp.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0,1]).ravel()
print("\nConfusion matrix counts:")
print(f"TP (accept good): {tp}")
print(f"FN (reject good): {fn}")
print(f"TN (reject bad): {tn}")
print(f"FP (accept bad): {fp}")

# Print simple metrics; every ratio falls back to 0.0 when its denominator is empty
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
total = tp + tn + fp + fn
# Guarded like precision/recall so an empty holdout does not divide by zero.
accuracy = (tp + tn) / total if total > 0 else 0.0
print(f"\nValidation metrics: precision={precision:.3f}, recall={recall:.3f}, accuracy={accuracy:.3f}")

# Optionally, show per-case decisions for inspection
print("\nDetailed decisions for holdouts (first 10 shown):")
# Tag every holdout order with its expected class, good orders first.
combined = [
    (tag, order)
    for tag, orders in (("GOOD", good_holdout), ("BAD", bad_holdout))
    for order in orders
]
for label, qo in combined[:10]:
    # Re-run the policy so the route taken can be shown next to the verdict.
    accept, info = decide_accept_or_reject(
        qo, model_dict, comp_to_idx, enc, pred, x_dict, edge_index_dict,
        threshold=threshold,
    )
    verdict = 'ACCEPT' if accept else 'REJECT'
    print(f"{label} {qo['order_id']} -> {verdict} via {info['route']} (nearest={info.get('nearest')}, dist={info.get('dist')})")

# -------------------------
# Notes about training paradigm updates for this application
# -------------------------
# Free-form follow-up recommendations kept as a plain string; the text is
# printed verbatim below and is not consumed by any other code in this section.
notes = """
Notes / Recommended updates to training paradigm for better validation fidelity:

1) Context sampling during training:
   - During training we must simulate the exact inference 'bootstrap' situation: randomly sub-sample each training order's components
     to create partial contexts and train the link predictor to predict missing/extra components given the partial context. This was
     approximated by sampling modified historical orders earlier, but you can make it explicit: for each positive example, sample a
     bootstrap subset B subset-of C_order and use (B, candidate) as training context.

2) Hard-negative mining:
   - For validation, the model must handle subtle plausible negatives (e.g., components that often co-occur but are wrong for this variant).
     Hard-negative mining will improve discrimination.

3) Threshold calibration:
   - We used 0.5 as default. In production calibrate threshold on a validation set (ROC/Precision-Recall or cost-sensitive loss depending on FP/FN costs).

4) Deterministic features & order encoder:
   - Replace random node initializations with deterministic features (one-hot model family, product type, tokenized model number embeddings).
     Also add an order-encoder (MLP that consumes a binary bag-of-components or component embeddings) to create consistent inference-time
     order embeddings from bootstrap components.

5) Larger/representative holdout:
   - Validate across multiple families, time-splits (simulate new model numbers), and noisy patterns present in production.

6) Logging and human-in-the-loop:
   - For ambiguous cases (probability near threshold or tiny bootstrap), route to review and collect labelled corrections to improve training data.
"""
# Emit the recommendations as the last piece of output of this section.
print(notes)


Epoch 001 - Loss 0.6920 - AUC 0.9212
Epoch 005 - Loss 0.3649 - AUC 0.9285
Epoch 010 - Loss 0.3114 - AUC 0.9240
Epoch 015 - Loss 0.2828 - AUC 0.9446
Epoch 020 - Loss 0.2121 - AUC 0.9773
Epoch 025 - Loss 0.1635 - AUC 0.9863
Epoch 030 - Loss 0.1372 - AUC 0.9860
Epoch 035 - Loss 0.1171 - AUC 0.9886
Epoch 040 - Loss 0.1192 - AUC 0.9895
Final AUC on training-like pairs: 0.9894556508624753
Generated holdout sizes: 10 good | 10 bad

Confusion matrix counts:
TP (accept good): 8
FN (reject good): 2
TN (reject bad): 9
FP (accept bad): 1

Validation metrics: precision=0.889, recall=0.800, accuracy=0.850

Detailed decisions for holdouts (first 10 shown):
GOOD G-exact-0 -> ACCEPT via exact_match_components_equal (nearest=MX100, dist=0)
GOOD G-exact-1 -> ACCEPT via exact_match_components_equal (nearest=MX200, dist=0)
GOOD G-exact-2 -> ACCEPT via exact_match_components_equal (nearest=AX1, dist=0)
GOOD G-exact-3 -> ACCEPT via exact_match_components_equal (nearest=AX2, dist=0)
GOOD G-exact-4 -> ACCEPT v