In [None]:
!pip install torch_geometric --quiet

In [None]:
import torch
import pandas as pd


# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Base location of the synthetic train/val/test CSV files read by the loading cell.
INPUT_PATH = '/data/processed'
# No output location is configured; this name is not read by any cell visible here.
OUTPUT_PATH = None

In [None]:
# Load synth data
import os


def _read_split(split: str) -> pd.DataFrame:
    """Read one synthetic split ('train', 'val' or 'test') found under INPUT_PATH.

    The file name is combined with INPUT_PATH through os.path.join, so the
    resulting path is valid whether or not INPUT_PATH ends with a separator
    (plain string concatenation produced '/data/processedsynt-df_*.csv').

    The 'Unnamed: 0' column is dropped (KeyError if it is absent) and the ID
    columns are cast to the pandas 'category' dtype.
    """
    frame = pd.read_csv(os.path.join(INPUT_PATH, f'synt-df_{split}.csv'))
    frame = frame.drop(columns='Unnamed: 0')

    # fix types
    for id_col in ('card_id', 'merchant_id'):
        frame[id_col] = frame[id_col].astype('category')
    return frame


df_train = _read_split('train')
df_val = _read_split('val')
df_test = _read_split('test')

# Names of the task columns; one of them is selected as the target further down.
tasks = ['unique_merchant_id', 'count_gt_12_card_id',
       'count_eq_3_card_id', 'double_card_id_merchant_city_ONLINE',
       'duplicate_card_id_merchant_city']

# Build the graph

In [None]:
from typing import Any, Collection, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import HeteroData

def df_to_hetero_for_task(
    df: pd.DataFrame,
    *,
    features: List[str],
    task_col: str,
    row_encoder=None,  # keep your type if you have it
    row_node_type: str = "row",
    feature_node_prefix: str = "",
    include_missing: bool = True,
    missing_token: str = "__MISSING__",
    oov_token: str = "__OOV__",
    add_reverse_edges: bool = True,  # IMPORTANT: default True for row prediction
    store_value_strings: bool = False,
    x_dtype: torch.dtype = torch.float32,
    y_dtype: torch.dtype = torch.float32,
    value2idx_in: Optional[Dict[str, Dict[Any, int]]] = None,
    use_one_hot: Optional[Union[bool, Collection[str]]] = None,
    one_hot_max_size: int = 5000,  # safety guard
) -> Tuple[HeteroData, Dict[str, Dict[Any, int]]]:
    """Turn a table into a heterogeneous row/value graph for one target column.

    Every DataFrame row becomes a node of type ``row_node_type`` whose label is
    taken from ``task_col``. For each column in ``features`` a node type
    ``feature_node_prefix + column`` is created with one node per vocabulary
    entry, and every row gets an edge ``has_<column>`` to the node of the value
    it holds in that column.

    Args:
        df: Source table; must contain ``task_col`` and every feature column.
        features: Columns whose values become value nodes.
        task_col: Column stored as ``data[row_node_type].y`` (flattened, ``y_dtype``).
        row_encoder: Optional object whose ``transform(df)`` returns ``(X, names)``;
            ``X`` is stored as ``data[row_node_type].x``. When None, no row
            features are set.
        row_node_type: Node type name used for the rows.
        feature_node_prefix: Prefix prepended to a feature name to form its node type.
        include_missing: When True, NaN entries are replaced by ``missing_token``
            and get their own node. When False, NaN is left out of the
            vocabulary and such rows are linked to the ``oov_token`` node.
        missing_token: Vocabulary key representing missing values.
        oov_token: Vocabulary key every value absent from the vocabulary maps to.
        add_reverse_edges: Also add ``rev_has_<column>`` edges from value nodes
            back to rows.
        store_value_strings: When building a new vocabulary, store its keys as
            ``data[feat_type].value``.
        x_dtype: dtype of row features and one-hot features.
        y_dtype: dtype of the label tensor.
        value2idx_in: Existing vocabulary to reuse (validation/test). When None,
            a new vocabulary is built from ``df``.
        use_one_hot: True enables identity features for all features, a
            collection enables them for the listed ones, anything falsy disables them.
        one_hot_max_size: Identity features are only created for vocabularies
            with at most this many entries.

    Returns:
        ``(data, value2idx)``. ``value2idx`` maps each feature to its
        ``value -> node index`` dict; when ``value2idx_in`` is given, that same
        object is returned.

    Raises:
        KeyError: ``task_col`` or a feature is not a column of ``df``.
        ValueError: a feature's node type collides with ``row_node_type``, or a
            reused vocabulary lacks ``oov_token``.
    """
    if task_col not in df.columns:
        raise KeyError(f"task_col='{task_col}' not in df.columns.")
    for f in features:
        if f not in df.columns:
            raise KeyError(f"Feature '{f}' not in df.columns.")
        if f"{feature_node_prefix}{f}" == row_node_type:
            # Sharing a node type with the rows would overwrite the row store.
            raise ValueError(
                f"Feature '{f}' maps to node type '{row_node_type}', which is reserved for rows."
            )

    # Normalize one-hot selection
    if use_one_hot is True:
        one_hot_feats = set(features)
    elif use_one_hot:
        one_hot_feats = set(use_one_hot)
    else:
        one_hot_feats = set()

    data = HeteroData()
    n = len(df)

    # Row nodes
    data[row_node_type].num_nodes = n
    data[row_node_type].y = torch.tensor(df[task_col].to_numpy(), dtype=y_dtype).view(-1)

    if row_encoder is not None:
        X, _ = row_encoder.transform(df)
        data[row_node_type].x = torch.tensor(X, dtype=x_dtype)
    # else: leave row.x unset; model should initialize (embedding or constant + type emb)

    row_idx = np.arange(n, dtype=np.int64)

    # Build or reuse vocab
    value2idx: Dict[str, Dict[Any, int]] = {} if value2idx_in is None else value2idx_in

    for feat in features:
        feat_type = f"{feature_node_prefix}{feat}"
        rel_type = f"has_{feat}"
        rev_rel_type = f"rev_{rel_type}"

        col = df[feat]
        if isinstance(col.dtype, pd.CategoricalDtype):
            # A categorical column rejects values outside its categories, so
            # filling NaN with missing_token would raise. Work on plain objects.
            col = col.astype(object)
        if include_missing:
            col = col.where(col.notna(), other=missing_token)

        if value2idx_in is None:
            # TRAIN: build vocab (order of first appearance, then the special tokens)
            uniques = list(pd.unique(col if include_missing else col.dropna()))
            if include_missing and missing_token not in uniques:
                uniques.append(missing_token)
            if oov_token not in uniques:
                uniques.append(oov_token)
            v2i = {v: i for i, v in enumerate(uniques)}
            value2idx[feat] = v2i
        else:
            # VAL/TEST: reuse vocab
            v2i = value2idx_in[feat]
            if oov_token not in v2i:
                raise ValueError(f"oov_token='{oov_token}' missing in value2idx_in[{feat!r}]")

        num_vals = len(v2i)
        data[feat_type].num_nodes = num_vals

        # Optional one-hot for small vocab only
        if (feat in one_hot_feats) and (num_vals <= one_hot_max_size):
            data[feat_type].x = torch.eye(num_vals, dtype=x_dtype)

        if store_value_strings and value2idx_in is None:
            data[feat_type].value = list(v2i.keys())

        # Edges row -> value (updates value nodes); unknown values go to the OOV node
        oov_idx = v2i[oov_token]
        vals = col.to_numpy()

        dst = np.fromiter((v2i.get(v, oov_idx) for v in vals), dtype=np.int64, count=len(vals))
        edge_index = torch.tensor(np.vstack([row_idx, dst]), dtype=torch.long)
        data[(row_node_type, rel_type, feat_type)].edge_index = edge_index

        # Reverse edges value -> row (updates row nodes)  ***CRITICAL***
        if add_reverse_edges:
            data[(feat_type, rev_rel_type, row_node_type)].edge_index = edge_index.flip(0).contiguous()

    return data, value2idx


In [None]:
# Example usage: build hetero graphs for ONE binary task column
# Assumes df_train/df_val/df_test are pandas DataFrames

features = ["card_id", "merchant_id", "merchant_city"]  # columns turned into value nodes
task_col = "duplicate_card_id_merchant_city"  # the single target used below

# Options that are identical for all three splits.
graph_options = dict(
    features=features,
    task_col=task_col,
    row_encoder=None,            # no dense row features
    add_reverse_edges=True,      # value -> row edges so row nodes receive messages
    include_missing=True,
    missing_token="__MISSING__",
    oov_token="__OOV__",
    use_one_hot=None,            # no identity features on value nodes
)

# 1) TRAIN graph; the vocabulary (value2idx) is built from train only
data_train, value2idx = df_to_hetero_for_task(df_train, **graph_options)

# 2) VAL graph on the train vocabulary (unseen values -> __OOV__)
data_val, _ = df_to_hetero_for_task(df_val, value2idx_in=value2idx, **graph_options)

# 3) TEST graph on the train vocabulary
data_test, _ = df_to_hetero_for_task(df_test, value2idx_in=value2idx, **graph_options)

# Optional sanity checks
for split_label, split_graph in (
    ("Train row nodes:", data_train),
    ("Val row nodes:  ", data_val),
    ("Test row nodes: ", data_test),
):
    print(split_label, split_graph["row"].num_nodes)

print("Node types:", data_train.node_types)
print("Edge types:", data_train.edge_types)

# y is on row nodes
for split_label, split_graph in (
    ("Train labels shape:", data_train),
    ("Val labels shape:  ", data_val),
    ("Test labels shape: ", data_test),
):
    print(split_label, split_graph["row"].y.shape)


# Define GNN and training loop

In [None]:
from __future__ import annotations

from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv


class StructuralInductiveHeteroSAGE(nn.Module):
    """
    Heterogeneous GraphSAGE that scores 'row' nodes from graph structure alone.

    How it works:
      - Node inputs ignore any features stored on the graph: every node starts
        from a learnable vector shared by all nodes of its type.
      - When ``init_noise_std > 0``, Gaussian noise is added to those inputs on
        every forward pass (training and evaluation alike), so repeated
        forward passes are not bit-identical unless the RNG is seeded.
      - Each layer is a HeteroConv holding one SAGEConv per edge type
        (aggregation ``sage_aggr``, default "sum"); results of different
        relations are summed.
      - ReLU and dropout follow every layer except the last one.
      - A two-layer MLP head maps the row embeddings to one logit per row.

    The original author recommends num_layers=2 for
    "row connected to value with degree k" style tasks.
    """

    def __init__(
        self,
        metadata: Tuple[List[str], List[Tuple[str, str, str]]],
        row_node_type: str = "row",
        hidden_channels: int = 64,
        num_layers: int = 2,
        dropout: float = 0.02,
        sage_aggr: str = "sum",
        init_noise_std: float = 1e-3,  # symmetry breaker
        head_hidden: int = 64,
    ):
        """Build type vectors, ``num_layers`` hetero layers and the logit head.

        ``metadata`` is the ``(node_types, edge_types)`` pair of the graphs the
        model will be applied to.
        """
        super().__init__()
        node_types, edge_types = metadata
        self.node_types = node_types
        self.edge_types = edge_types
        self.row_node_type = row_node_type
        self.hidden = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.init_noise_std = init_noise_std

        # One learnable vector per node type, shared by all nodes of that type.
        self.type_vec = nn.ParameterDict()
        for node_type in self.node_types:
            vec = nn.Parameter(torch.zeros(hidden_channels))
            nn.init.normal_(vec, mean=0.0, std=0.02)
            self.type_vec[node_type] = vec

        # Stack of hetero layers; relation outputs are summed per destination type.
        layers = []
        for _ in range(num_layers):
            per_relation = {}
            for edge_type in self.edge_types:
                per_relation[edge_type] = SAGEConv((-1, -1), hidden_channels, aggr=sage_aggr)
            layers.append(HeteroConv(per_relation, aggr="sum"))
        self.convs = nn.ModuleList(layers)

        # MLP head producing a single logit per row node.
        self.head = nn.Sequential(
            nn.Linear(hidden_channels, head_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1),
        )

    def _init_x_dict(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Return the initial ``(num_nodes, hidden)`` input for every node type."""
        dev = next(self.parameters()).device
        inputs: Dict[str, torch.Tensor] = {}
        for node_type in self.node_types:
            count = int(data[node_type].num_nodes)
            feats = self.type_vec[node_type].to(dev).unsqueeze(0).expand(count, -1)

            # Fresh noise on each call so nodes of one type are not exactly identical.
            if self.init_noise_std > 0:
                noise = torch.randn((count, self.hidden), device=dev)
                feats = feats + self.init_noise_std * noise

            inputs[node_type] = feats
        return inputs

    def _edge_index_dict(self, data: HeteroData) -> Dict[Tuple[str, str, str], torch.Tensor]:
        """Collect the non-empty edge indices of the known edge types on the model device."""
        dev = next(self.parameters()).device
        collected = {}
        for edge_type in self.edge_types:
            store = data[edge_type]
            if "edge_index" not in store:
                continue
            edge_index = store["edge_index"]
            if edge_index is None or edge_index.numel() == 0:
                continue
            collected[edge_type] = edge_index.to(dev)
        return collected

    def forward(self, data: HeteroData) -> torch.Tensor:
        """Return the head output for the row nodes (one logit per row)."""
        dev = next(self.parameters()).device
        data = data.to(dev)

        hidden = self._init_x_dict(data)
        edges = self._edge_index_dict(data)

        for depth, layer in enumerate(self.convs):
            hidden = layer(hidden, edges)

            # The last layer stays linear; earlier ones get ReLU + dropout.
            if depth >= self.num_layers - 1:
                continue
            hidden = {
                node_type: F.dropout(F.relu(feats), p=self.dropout, training=self.training)
                for node_type, feats in hidden.items()
            }

        return self.head(hidden[self.row_node_type])


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model only needs the node/edge type names of the training graph.
metadata = (data_train.node_types, data_train.edge_types)

model_config = dict(
    metadata=metadata,
    row_node_type="row",
    hidden_channels=64,
    num_layers=4,
    dropout=0.0,
    sage_aggr="sum",
    init_noise_std=1e-3,
)
model = StructuralInductiveHeteroSAGE(**model_config)
model = model.to(device)


In [None]:
# Class balance of the training rows -> weight of the positive class in the loss.
# Counts are converted to Python floats and the tensor dtype is pinned to float32
# so the weight matches the model's float32 logits (a NumPy float64 scalar would
# otherwise yield a float64 tensor), mirroring train_binary_best_f1.
y = data_train["row"].y.cpu().numpy()
pos = float((y == 1).sum())
neg = float((y == 0).sum())
pos_weight = neg / max(pos, 1.0)  # max(..., 1.0) avoids dividing by zero without positives

criterion = torch.nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor([pos_weight], dtype=torch.float32, device=device)
)

# Sanity check before training: one forward pass in eval mode and the spread of
# the resulting logits across rows.
model.eval()
with torch.no_grad():
    logits = model(data_train.to(device)).view(-1)
print("logits std:", logits.std().item())


In [None]:
from __future__ import annotations

import copy
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, roc_auc_score


@torch.no_grad()
def _predict_probs(model: nn.Module, data, device: torch.device):
    model.eval()
    d = data.to(device)
    logits = model(d).view(-1)
    probs = torch.sigmoid(logits).detach().cpu().numpy()
    y = d["row"].y.view(-1).detach().cpu().numpy().astype(np.int32)
    return probs, y


def _roc_auc_safe(y: np.ndarray, probs: np.ndarray) -> float:
    if len(np.unique(y)) < 2:
        return float("nan")
    return float(roc_auc_score(y, probs))


def _best_f1_threshold(y: np.ndarray, probs: np.ndarray, grid: int = 200) -> tuple[float, float]:
    # threshold grid in (0,1)
    ts = np.linspace(0.001, 0.999, grid)
    best_t, best_f1 = 0.5, -1.0
    for t in ts:
        f1 = f1_score(y, (probs >= t).astype(np.int32))
        if f1 > best_f1:
            best_f1, best_t = float(f1), float(t)
    return best_t, best_f1


@torch.no_grad()
def evaluate(model: nn.Module, data, device: torch.device, threshold: float) -> dict:
    """Score ``data`` with ``model`` and return ``{"f1": ..., "roc_auc": ...}``.

    F1 is computed on predictions ``probs >= threshold``; ROC AUC uses the raw
    probabilities and is NaN when only one class is present in the labels.
    """
    probs, y = _predict_probs(model, data, device)
    hard_labels = (probs >= threshold).astype(np.int32)

    metrics = {}
    metrics["f1"] = float(f1_score(y, hard_labels))
    metrics["roc_auc"] = _roc_auc_safe(y, probs)
    return metrics


def train_binary_best_f1(
    model: nn.Module,
    data_train,
    data_val,
    data_test,
    *,
    device: torch.device,
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    epochs: int = 50,
    grad_clip: float | None = 1.0,
    threshold_grid: int = 200,
):
    """Full-batch training of a binary row classifier, model selection by validation F1.

    Each epoch performs one optimisation step on the whole training graph
    (BCE-with-logits, positive class weighted by ``neg / max(pos, 1)`` counted
    on the training labels, AdamW, optional gradient-norm clipping). It then
    picks the validation threshold with the best F1, evaluates the training
    graph at that threshold for logging, and keeps a CPU copy of the weights
    whenever validation F1 strictly improves.

    After the last epoch the best weights are restored (if any were saved) and
    the test graph is evaluated at the threshold chosen on validation.

    Returns:
        ``(model, report)`` where ``report`` holds ``best_epoch``,
        ``best_val_f1``, ``best_val_auc``, ``best_threshold``, ``test_f1``,
        ``test_auc`` and ``pos_weight``.
    """
    model = model.to(device)

    # Positive-class weight from the TRAIN labels only.
    train_labels = data_train["row"].y.detach().cpu().numpy()
    n_pos = float((train_labels == 1).sum())
    n_neg = float((train_labels == 0).sum())
    pos_weight = n_neg / max(n_pos, 1.0)

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Best checkpoint seen so far (by validation F1).
    best = dict(
        val_f1=-1.0,
        threshold=0.5,
        state_dict=None,
        epoch=-1,
        val_auc=float("nan"),
    )

    for epoch in range(1, epochs + 1):
        # --- one full-batch optimisation step ---
        model.train()
        optimizer.zero_grad()

        train_graph = data_train.to(device)
        train_logits = model(train_graph).view(-1)
        targets = train_graph["row"].y.view(-1).float()

        loss = criterion(train_logits, targets)
        loss.backward()

        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()

        # --- threshold selection on VAL ---
        val_probs, val_y = _predict_probs(model, data_val, device)
        t_star, val_f1 = _best_f1_threshold(val_y, val_probs, grid=threshold_grid)
        val_auc = _roc_auc_safe(val_y, val_probs)

        # Train metrics at the validation-selected threshold (logging only).
        train_metrics = evaluate(model, data_train, device, threshold=t_star)

        # --- checkpoint on strict improvement of validation F1 ---
        if val_f1 > best["val_f1"]:
            cpu_weights = {name: tensor.detach().cpu() for name, tensor in model.state_dict().items()}
            best.update(
                val_f1=float(val_f1),
                val_auc=float(val_auc),
                threshold=float(t_star),
                epoch=epoch,
                state_dict=copy.deepcopy(cpu_weights),
            )

        log_fields = [
            f"Epoch {epoch:03d}",
            f"loss={loss.item():.4f}",
            f"thr*={t_star:.3f}",
            f"train_f1={train_metrics['f1']:.4f} train_auc={train_metrics['roc_auc']:.4f}",
            f"val_f1={float(val_f1):.4f} val_auc={float(val_auc):.4f}",
        ]
        print(" | ".join(log_fields))

    # Restore best model
    if best["state_dict"] is not None:
        model.load_state_dict(best["state_dict"])

    # TEST is scored once, with the threshold chosen on VAL.
    test_metrics = evaluate(model, data_test, device, threshold=best["threshold"])

    report = {
        "best_epoch": best["epoch"],
        "best_val_f1": best["val_f1"],
        "best_val_auc": best["val_auc"],
        "best_threshold": best["threshold"],
        "test_f1": test_metrics["f1"],
        "test_auc": test_metrics["roc_auc"],
        "pos_weight": float(pos_weight),
    }
    return model, report


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(task_col)

# Hyper-parameters of this run; everything else uses the function defaults.
train_options = dict(
    device=device,
    lr=1e-3,
    weight_decay=1e-4,
    epochs=100,
)
model, report = train_binary_best_f1(model, data_train, data_val, data_test, **train_options)

print(report)
