In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m62.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [5]:
import torch
import pandas as pd
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [22]:
INPUT_PATH = '/data/processed'
OUTPUT_PATH = '/runs/'

In [6]:
# Load retail data

df_train = pd.read_csv(INPUT_PATH+'synt-df_train.csv')
df_val = pd.read_csv(INPUT_PATH+'synt-df_val.csv')
df_test = pd.read_csv(INPUT_PATH+'synt-df_test.csv')
df_train.drop('Unnamed: 0', axis=1, inplace=True)
df_val.drop('Unnamed: 0', axis=1, inplace=True)
df_test.drop('Unnamed: 0', axis=1, inplace=True)

# fix types

for t in ['card_id', 'merchant_id']:
    df_train[t] = df_train[t].astype('category')
    df_val[t] = df_val[t].astype('category')
    df_test[t] = df_test[t].astype('category')

tasks = ['unique_merchant_id', 'count_gt_12_card_id',
       'count_eq_3_card_id', 'double_card_id_merchant_city_ONLINE',
       'duplicate_card_id_merchant_city']

# Build the graph

In [7]:
from typing import Any, Collection, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import HeteroData

def df_to_hetero_for_task(
    df: pd.DataFrame,
    *,
    features: List[str],
    task_col: str,
    row_encoder=None,  # keep your type if you have it
    row_node_type: str = "row",
    feature_node_prefix: str = "",
    include_missing: bool = True,
    missing_token: str = "__MISSING__",
    oov_token: str = "__OOV__",
    add_reverse_edges: bool = True,  # IMPORTANT: default True for row prediction
    store_value_strings: bool = False,
    x_dtype: torch.dtype = torch.float32,
    y_dtype: torch.dtype = torch.float32,
    value2idx_in: Optional[Dict[str, Dict[Any, int]]] = None,
    use_one_hot: Optional[Union[bool, Collection[str]]] = None,
    one_hot_max_size: int = 5000,  # safety guard
) -> Tuple[HeteroData, Dict[str, Dict[Any, int]]]:

    if task_col not in df.columns:
        raise KeyError(f"task_col='{task_col}' not in df.columns.")
    for f in features:
        if f not in df.columns:
            raise KeyError(f"Feature '{f}' not in df.columns.")

    # Normalize one-hot selection
    if use_one_hot is True:
        one_hot_feats = set(features)
    elif use_one_hot:
        one_hot_feats = set(use_one_hot)
    else:
        one_hot_feats = set()

    data = HeteroData()
    n = len(df)

    # Row nodes
    data[row_node_type].num_nodes = n
    data[row_node_type].y = torch.tensor(df[task_col].to_numpy(), dtype=y_dtype).view(-1)

    if row_encoder is not None:
        X, _names = row_encoder.transform(df)
        data[row_node_type].x = torch.tensor(X, dtype=x_dtype)
    # else: leave row.x unset; model should initialize (embedding or constant + type emb)

    row_idx = np.arange(n, dtype=np.int64)

    # Build or reuse vocab
    value2idx: Dict[str, Dict[Any, int]] = {} if value2idx_in is None else value2idx_in

    for feat in features:
        feat_type = f"{feature_node_prefix}{feat}"
        rel_type = f"has_{feat}"
        rev_rel_type = f"rev_{rel_type}"

        col = df[feat]
        if include_missing:
            col = col.where(~col.isna(), other=missing_token)

        if value2idx_in is None:
            # TRAIN: build vocab
            uniques = list(pd.unique(col if include_missing else df[feat].dropna()))
            if include_missing and missing_token not in uniques:
                uniques.append(missing_token)
            if oov_token not in uniques:
                uniques.append(oov_token)
            v2i = {v: i for i, v in enumerate(uniques)}
            value2idx[feat] = v2i
        else:
            # VAL/TEST: reuse vocab
            v2i = value2idx_in[feat]
            if oov_token not in v2i:
                raise ValueError(f"oov_token='{oov_token}' missing in value2idx_in[{feat!r}]")

        num_vals = len(v2i)
        data[feat_type].num_nodes = num_vals

        # Optional one-hot for small vocab only
        if (feat in one_hot_feats) and (num_vals <= one_hot_max_size):
            data[feat_type].x = torch.eye(num_vals, dtype=x_dtype)

        if store_value_strings and value2idx_in is None:
            data[feat_type].value = list(v2i.keys())

        # Edges row -> value (updates value nodes)
        oov_idx = v2i[oov_token]
        vals = col.to_numpy()

        dst = np.fromiter((v2i.get(v, oov_idx) for v in vals), dtype=np.int64, count=len(vals))
        edge_index = torch.tensor(np.vstack([row_idx, dst]), dtype=torch.long)
        data[(row_node_type, rel_type, feat_type)].edge_index = edge_index

        # Reverse edges value -> row (updates row nodes)  ***CRITICAL***
        if add_reverse_edges:
            data[(feat_type, rev_rel_type, row_node_type)].edge_index = edge_index.flip(0).contiguous()

    return data, value2idx


In [8]:
tasks = ['unique_StockCode',
       'count_gt_75_InvoiceNo', 'count_eq_15_InvoiceNo',
       'double_InvoiceNo_StockCode_23084', 'duplicate_CustomerID_StockCode']

features = ['InvoiceNo', 'StockCode', 'CustomerID']

task_col = 'unique_StockCode'
data_train, v2i_train = df_to_hetero_for_task(
    df_train,
    features=features,
    task_col=task_col,
    row_encoder=None,          # we will attach row.x ourselves
    add_reverse_edges=True,
    include_missing=True,
    missing_token="__MISSING__",
    oov_token="__OOV__",       # still useful for missing values
    value2idx_in=None,         # IMPORTANT: per-split vocab
    use_one_hot=None,
)

data_val, v2i_val = df_to_hetero_for_task(
    df_val,
    features=features,
    task_col=task_col,
    row_encoder=None,
    add_reverse_edges=True,
    include_missing=True,
    missing_token="__MISSING__",
    oov_token="__OOV__",
    value2idx_in=None,         # IMPORTANT: per-split vocab
    use_one_hot=None,
)

data_test, v2i_test = df_to_hetero_for_task(
    df_test,
    features=features,
    task_col=task_col,
    row_encoder=None,
    add_reverse_edges=True,
    include_missing=True,
    missing_token="__MISSING__",
    oov_token="__OOV__",
    value2idx_in=None,         # IMPORTANT: per-split vocab
    use_one_hot=None,
)


# Preprocess node level row fearures

In [9]:
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler
from typing import List


class RowFeaturePreprocessor:
    def __init__(
        self,
        *,
        text_col: str,
        numeric_cols: List[str],
        max_tfidf_features: int = 128,
        min_df: int = 2,
        lowercase: bool = True,
    ):
        self.text_col = text_col
        self.numeric_cols = numeric_cols
        self.max_tfidf_features = max_tfidf_features
        self.min_df = min_df
        self.lowercase = lowercase

        self._tfidf = None
        self._scaler = None
        self._fitted = False

    def fit(self, df: pd.DataFrame):
        # ---- text ----
        text = (
            df[self.text_col]
            .fillna("")
            .astype(str)
        )

        self._tfidf = TfidfVectorizer(
            max_features=self.max_tfidf_features,
            min_df=self.min_df,
            lowercase=self.lowercase,
            ngram_range=(1, 2),
        )
        self._tfidf.fit(text)

        # ---- numeric ----
        X_num = df[self.numeric_cols].astype(float).fillna(0.0).to_numpy()
        self._scaler = StandardScaler()
        self._scaler.fit(X_num)

        self._fitted = True
        return self

    def transform(self, df: pd.DataFrame) -> np.ndarray:
        if not self._fitted:
            raise RuntimeError("RowFeaturePreprocessor must be fitted first")

        # text features
        text = (
            df[self.text_col]
            .fillna("")
            .astype(str)
        )
        X_text = self._tfidf.transform(text).toarray()

        # numeric features
        X_num = df[self.numeric_cols].astype(float).fillna(0.0).to_numpy()
        X_num = self._scaler.transform(X_num)

        # concat
        return np.hstack([X_num, X_text])

    @property
    def output_dim(self) -> int:
        if not self._fitted:
            raise RuntimeError("Not fitted")
        return len(self.numeric_cols) + len(self._tfidf.get_feature_names_out())


In [10]:
pp = RowFeaturePreprocessor(
    text_col="Description",
    numeric_cols=["Quantity", "UnitPrice"],
    max_tfidf_features=128,   # adjust as needed
).fit(df_train)

data_train["row"].x = torch.tensor(pp.transform(df_train), dtype=torch.float32)
data_val["row"].x   = torch.tensor(pp.transform(df_val),   dtype=torch.float32)
data_test["row"].x  = torch.tensor(pp.transform(df_test),  dtype=torch.float32)

print("row.x dim:", data_train["row"].x.shape[1])


row.x dim: 130


# Define GNN and training loop

In [11]:
from __future__ import annotations

from typing import Dict, List, Tuple
import hashlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv


def _stable_hash_to_bin(s: str, dim: int) -> int:
    h = hashlib.md5(s.encode("utf-8")).hexdigest()
    return int(h[:8], 16) % dim


# Option 1: HeteroSage
class HeteroSAGE_RowX_Structural(nn.Module):
    """
    HeteroGraphSAGE for:
      - row nodes with real features: data['row'].x may be None
      - categorical/value nodes without x: initialized structurally (type vector + fixed hash)
    Intended for inductive (separate-graph) evaluation.
    """

    def __init__(
        self,
        metadata: Tuple[List[str], List[Tuple[str, str, str]]],
        *,
        max_num_nodes_by_type: Dict[str, int],
        row_node_type: str = "row",
        row_in_dim: int = 0,
        hidden_channels: int = 64,
        num_layers: int = 2,
        dropout: float = 0.1,
        sage_aggr: str = "sum",
        hash_scale: float = 1e-3,
        hash_seed: int = 12345,
        head_hidden: int = 64,
    ):
        super().__init__()
        self.node_types, self.edge_types = metadata
        self.row_node_type = row_node_type
        self.hidden = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.hash_scale = hash_scale

        # --- Type vectors (learnable) ---
        self.type_vec = nn.ParameterDict()
        for ntype in self.node_types:
            p = nn.Parameter(torch.zeros(hidden_channels))
            nn.init.normal_(p, mean=0.0, std=0.02)
            self.type_vec[ntype] = p

        # --- Row feature projection (OPTIONAL) ---
        self.row_in_dim = int(row_in_dim)
        self.row_proj = nn.Linear(self.row_in_dim, hidden_channels) if self.row_in_dim > 0 else None

        # --- Fixed hash buffers per node type (deterministic symmetry breaker) ---
        g = torch.Generator()
        g.manual_seed(hash_seed)

        self.hash_buf = nn.ModuleDict()
        for ntype in self.node_types:
            if ntype not in max_num_nodes_by_type:
                raise KeyError(f"max_num_nodes_by_type missing node type: {ntype!r}")
            nmax = int(max_num_nodes_by_type[ntype])
            m = nn.Module()
            H = torch.randn((nmax, hidden_channels), generator=g) * hash_scale
            m.register_buffer("H", H)
            self.hash_buf[ntype] = m

        # --- Hetero GraphSAGE layers ---
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(
                HeteroConv(
                    {etype: SAGEConv((-1, -1), hidden_channels, aggr=sage_aggr) for etype in self.edge_types},
                    aggr="sum",
                )
            )

        # --- Head (binary logits on row nodes) ---
        self.head = nn.Sequential(
            nn.Linear(hidden_channels, head_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1),
        )

    def _edge_index_dict(self, data: HeteroData):
        device = next(self.parameters()).device
        out = {}
        for etype in self.edge_types:
            store = data[etype]
            if "edge_index" in store and store["edge_index"] is not None and store["edge_index"].numel() > 0:
                out[etype] = store["edge_index"].to(device)
        return out

    def _init_x_dict(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        device = next(self.parameters()).device
        x_dict: Dict[str, torch.Tensor] = {}

        # --- Row nodes: type vec + fixed hash (+ optional projected row.x) ---
        n_row = int(data[self.row_node_type].num_nodes)

        base_row = self.type_vec[self.row_node_type].to(device).unsqueeze(0).expand(n_row, -1)
        H_row = self.hash_buf[self.row_node_type].H[:n_row].to(device)

        has_row_x = hasattr(data[self.row_node_type], "x") and data[self.row_node_type].x is not None

        if has_row_x:
            if self.row_proj is None:
                raise ValueError("row_in_dim=0, but data['row'].x was provided.")
            row_x = data[self.row_node_type].x.to(device)
            if row_x.size(-1) != self.row_proj.in_features:
                raise ValueError(
                    f"row.x dim mismatch: got {row_x.size(-1)} but model expects {self.row_proj.in_features}."
                )
            x_row = base_row + H_row + self.row_proj(row_x)
        else:
            # structural-only
            x_row = base_row + H_row

        x_dict[self.row_node_type] = x_row

        # --- Other node types: type vec + fixed hash (no .x assumed) ---
        for ntype in self.node_types:
            if ntype == self.row_node_type:
                continue
            n = int(data[ntype].num_nodes)
            base = self.type_vec[ntype].to(device).unsqueeze(0).expand(n, -1)
            H = self.hash_buf[ntype].H[:n].to(device)
            x_dict[ntype] = base + H

        return x_dict

    def forward(self, data: HeteroData) -> torch.Tensor:
        data = data.to(next(self.parameters()).device)

        x_dict = self._init_x_dict(data)
        edge_index_dict = self._edge_index_dict(data)

        for li, conv in enumerate(self.convs):
            x_dict = conv(x_dict, edge_index_dict)
            if li < self.num_layers - 1:
                for ntype in x_dict:
                    x_dict[ntype] = F.relu(x_dict[ntype])
                    x_dict[ntype] = F.dropout(x_dict[ntype], p=self.dropout, training=self.training)

        return self.head(x_dict[self.row_node_type])


In [13]:
# Option 2: Sage with Relation Gate

from __future__ import annotations

from collections import defaultdict
from typing import Dict, List, Tuple
import hashlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv


class HeteroSAGE_RowX_RelGate(nn.Module):
    """
    Full-graph hetero GraphSAGE with:
      - row nodes: real features (row.x) projected to hidden
      - other nodes: type vec + fixed hash init
      - SAGEConv(sum) per relation
      - learned relation gating across incoming relations per dst node type
    """

    def __init__(
        self,
        metadata: Tuple[List[str], List[Tuple[str, str, str]]],
        *,
        max_num_nodes_by_type: Dict[str, int],
        row_node_type: str = "row",
        row_in_dim: int = 1,
        hidden_channels: int = 64,
        num_layers: int = 3,
        dropout: float = 0.01,
        sage_aggr: str = "sum",
        hash_scale: float = 1e-3,
        hash_seed: int = 12345,
        rel_gate_hidden: int = 64,
        temperature: float = 1.0,
        head_hidden: int = 64,
    ):
        super().__init__()
        self.node_types, self.edge_types = metadata
        self.row_node_type = row_node_type
        self.hidden = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.temperature = temperature

        # Type vectors
        self.type_vec = nn.ParameterDict()
        for ntype in self.node_types:
            p = nn.Parameter(torch.zeros(hidden_channels))
            nn.init.normal_(p, mean=0.0, std=0.02)
            self.type_vec[ntype] = p

        # Row feature projection
        self.row_proj = nn.Linear(row_in_dim, hidden_channels)

        # Fixed hash buffers sized by max nodes per type
        g = torch.Generator()
        g.manual_seed(hash_seed)
        self.hash_buf = nn.ModuleDict()
        for ntype in self.node_types:
            nmax = int(max_num_nodes_by_type[ntype])
            m = nn.Module()
            H = torch.randn((nmax, hidden_channels), generator=g) * hash_scale
            m.register_buffer("H", H)
            self.hash_buf[ntype] = m

        # Convs with aggr=None so we can gate relations ourselves
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(
                HeteroConv(
                    {etype: SAGEConv((-1, -1), hidden_channels, aggr=sage_aggr) for etype in self.edge_types},
                    aggr=None,
                )
            )

        # Relation gate per layer and destination type
        self.rel_gate = nn.ModuleList()
        for _ in range(num_layers):
            md = nn.ModuleDict()
            for dst_type in self.node_types:
                md[dst_type] = nn.Sequential(
                    nn.Linear(2 * hidden_channels, rel_gate_hidden),
                    nn.ReLU(),
                    nn.Linear(rel_gate_hidden, 1),
                )
            self.rel_gate.append(md)

        # Head
        self.head = nn.Sequential(
            nn.Linear(hidden_channels, head_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1),
        )

    def _edge_index_dict(self, data: HeteroData):
        device = next(self.parameters()).device
        out = {}
        for etype in self.edge_types:
            store = data[etype]
            if "edge_index" in store and store["edge_index"] is not None and store["edge_index"].numel() > 0:
                out[etype] = store["edge_index"].to(device)
        return out

    def _init_x_dict(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        device = next(self.parameters()).device
        x_dict: Dict[str, torch.Tensor] = {}

        for ntype in self.node_types:
            n = int(data[ntype].num_nodes)
            base = self.type_vec[ntype].to(device).unsqueeze(0).expand(n, -1)
            H = self.hash_buf[ntype].H[:n].to(device)

            if ntype == self.row_node_type:
                if not (hasattr(data[ntype], "x") and data[ntype].x is not None):
                    raise ValueError("Expected data['row'].x for row features.")
                row_x = data[ntype].x.to(device)
                x_dict[ntype] = base + H + self.row_proj(row_x)
            else:
                x_dict[ntype] = base + H

        return x_dict

    def _rel_gated_aggregate(
        self,
        layer_i: int,
        out_by_rel: Dict[Tuple[str, str, str], torch.Tensor],
        x_prev: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        # Group by destination type
        grouped = defaultdict(list)
        for etype, h in out_by_rel.items():
            dst = etype[2]
            grouped[dst].append(h)

        x_new: Dict[str, torch.Tensor] = {}

        for dst_type in self.node_types:
            if dst_type not in grouped:
                x_new[dst_type] = x_prev[dst_type]
                continue

            rel_hs = grouped[dst_type]
            if len(rel_hs) == 1:
                x_new[dst_type] = rel_hs[0]
                continue

            h_prev = x_prev[dst_type]
            gate_net = self.rel_gate[layer_i][dst_type]

            scores = []
            for h_rel in rel_hs:
                scores.append(gate_net(torch.cat([h_rel, h_prev], dim=-1)))  # (N,1)

            score_stack = torch.stack(scores, dim=0)  # (R,N,1)
            alpha = F.softmax(score_stack / self.temperature, dim=0)
            rel_stack = torch.stack(rel_hs, dim=0)  # (R,N,H)
            x_new[dst_type] = torch.sum(alpha * rel_stack, dim=0)

        return x_new

    def forward(self, data: HeteroData) -> torch.Tensor:
        data = data.to(next(self.parameters()).device)

        x_dict = self._init_x_dict(data)
        edge_index_dict = self._edge_index_dict(data)

        for li, conv in enumerate(self.convs):
            x_prev = x_dict
            out_by_rel = conv(x_prev, edge_index_dict)  # etype -> dst tensor
            x_dict = self._rel_gated_aggregate(li, out_by_rel, x_prev)

            if li < self.num_layers - 1:
                for ntype in x_dict:
                    x_dict[ntype] = F.relu(x_dict[ntype])
                    x_dict[ntype] = F.dropout(x_dict[ntype], p=self.dropout, training=self.training)

        return self.head(x_dict[self.row_node_type])


## training loop

In [14]:
from __future__ import annotations

import copy
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, roc_auc_score


@torch.no_grad()
def _predict_probs(model: nn.Module, data, device: torch.device):
    model.eval()
    d = data.to(device)
    logits = model(d).view(-1)
    probs = torch.sigmoid(logits).detach().cpu().numpy()
    y = d["row"].y.view(-1).detach().cpu().numpy().astype(np.int32)
    return probs, y


def _roc_auc_safe(y: np.ndarray, probs: np.ndarray) -> float:
    if len(np.unique(y)) < 2:
        return float("nan")
    return float(roc_auc_score(y, probs))


def _best_f1_threshold(y: np.ndarray, probs: np.ndarray, grid: int = 400) -> tuple[float, float]:
    ts = np.linspace(0.001, 0.999, grid)
    best_t, best_f1 = 0.5, -1.0
    for t in ts:
        f1 = f1_score(y, (probs >= t).astype(np.int32))
        if f1 > best_f1:
            best_f1, best_t = float(f1), float(t)
    return best_t, best_f1


@torch.no_grad()
def evaluate_at_threshold(model: nn.Module, data, device: torch.device, threshold: float) -> dict:
    probs, y = _predict_probs(model, data, device)
    y_pred = (probs >= threshold).astype(np.int32)
    return {"f1": float(f1_score(y, y_pred)), "roc_auc": _roc_auc_safe(y, probs)}


def train_fullbatch_binary_best_f1(
    model: nn.Module,
    data_train,
    data_val,
    data_test=None,
    *,
    device: torch.device,
    lr: float = 5e-4,
    weight_decay: float = 1e-6,
    epochs: int = 100,
    patience: int = 15,
    grad_clip: float | None = 1.0,
    threshold_grid: int = 400,
    print_every: int = 1,
):
    model = model.to(device)

    # pos_weight from TRAIN
    y_train_np = data_train["row"].y.detach().cpu().numpy()
    pos = float((y_train_np == 1).sum())
    neg = float((y_train_np == 0).sum())
    pos_weight = neg / max(pos, 1.0)

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    best = {
        "val_f1": -1.0,
        "val_auc": float("nan"),
        "threshold": 0.5,
        "epoch": -1,
        "state_dict": None,
    }
    bad_epochs = 0

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()

        dtr = data_train.to(device)
        logits = model(dtr).view(-1)
        y = dtr["row"].y.view(-1).float()

        loss = criterion(logits, y)
        loss.backward()

        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()

        # choose threshold on VAL
        val_probs, val_y = _predict_probs(model, data_val, device)
        t_star, val_f1 = _best_f1_threshold(val_y, val_probs, grid=threshold_grid)
        val_auc = _roc_auc_safe(val_y, val_probs)

        improved = val_f1 > best["val_f1"] + 1e-6
        if improved:
            best["val_f1"] = float(val_f1)
            best["val_auc"] = float(val_auc)
            best["threshold"] = float(t_star)
            best["epoch"] = epoch
            best["state_dict"] = copy.deepcopy({k: v.detach().cpu() for k, v in model.state_dict().items()})
            bad_epochs = 0
        else:
            bad_epochs += 1

        if epoch % print_every == 0:
            train_metrics = evaluate_at_threshold(model, data_train, device, threshold=t_star)
            print(
                f"Epoch {epoch:03d} | loss={loss.item():.4f} | thr*={t_star:.3f} | "
                f"train_f1={train_metrics['f1']:.4f} train_auc={train_metrics['roc_auc']:.4f} | "
                f"val_f1={val_f1:.4f} val_auc={val_auc:.4f} | bad={bad_epochs}/{patience}"
            )

        if bad_epochs >= patience:
            break

    if best["state_dict"] is not None:
        model.load_state_dict(best["state_dict"])

    report = {
        "best_epoch": best["epoch"],
        "best_val_f1": best["val_f1"],
        "best_val_auc": best["val_auc"],
        "best_threshold": best["threshold"],
        "pos_weight": float(pos_weight),
    }

    if data_test is not None:
        test_metrics = evaluate_at_threshold(model, data_test, device, threshold=best["threshold"])
        report.update({"test_f1": test_metrics["f1"], "test_auc": test_metrics["roc_auc"]})

    return model, report


In [16]:
metadata = (data_train.node_types, data_train.edge_types)
row_in_dim = data_train["row"].x.shape[1]

max_num_nodes_by_type = {}
for ntype in data_train.node_types:
    max_num_nodes_by_type[ntype] = max(
        int(data_train[ntype].num_nodes),
        int(data_val[ntype].num_nodes),
        int(data_test[ntype].num_nodes),
    )


model = HeteroSAGE_RowX_Structural(
    metadata,
    max_num_nodes_by_type=max_num_nodes_by_type,
    row_node_type="row",
    row_in_dim=row_in_dim,
    hidden_channels=64,
    num_layers=3,
    dropout=0.0,
    sage_aggr="sum",
    hash_scale=1e-3,
    head_hidden=64,
).to(device)

"""model = HeteroSAGE_RowX_RelGate(
    metadata,
    max_num_nodes_by_type=max_num_nodes_by_type,
    row_node_type="row",
    row_in_dim=row_in_dim,
    hidden_channels=128,
    num_layers=2,
    dropout=0.0,
    sage_aggr="sum",
    hash_scale=1e-3,
    head_hidden=64,
).to(device)"""

'model = HeteroSAGE_RowX_RelGate(\n    metadata,\n    max_num_nodes_by_type=max_num_nodes_by_type,\n    row_node_type="row",\n    row_in_dim=row_in_dim,\n    hidden_channels=128,\n    num_layers=2,\n    dropout=0.0,\n    sage_aggr="sum",\n    hash_scale=1e-3,\n    head_hidden=64,\n).to(device)'

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model, report = train_fullbatch_binary_best_f1(
    model,
    data_train,
    data_val,
    data_test,
    device=device,
    lr=1e-3,
    weight_decay=1e-6,
    epochs=200,
)

print(report)


Epoch 001 | loss=94.4181 | thr*=0.864 | train_f1=0.2719 train_auc=0.5228 | val_f1=0.1314 val_auc=0.4197 | bad=0/15
Epoch 002 | loss=15946.8330 | thr*=0.864 | train_f1=0.2716 train_auc=0.5255 | val_f1=0.1332 val_auc=0.4255 | bad=0/15
Epoch 003 | loss=10149.8193 | thr*=0.884 | train_f1=0.2717 train_auc=0.5057 | val_f1=0.1317 val_auc=0.4174 | bad=1/15
Epoch 004 | loss=637.9578 | thr*=0.879 | train_f1=0.3128 train_auc=0.6236 | val_f1=0.1283 val_auc=0.4085 | bad=2/15
Epoch 005 | loss=57.0777 | thr*=0.776 | train_f1=0.3127 train_auc=0.6325 | val_f1=0.1414 val_auc=0.4646 | bad=0/15
Epoch 006 | loss=42.6914 | thr*=0.634 | train_f1=0.3261 train_auc=0.6294 | val_f1=0.1321 val_auc=0.4233 | bad=1/15
Epoch 007 | loss=16.3681 | thr*=0.636 | train_f1=0.2380 train_auc=0.4832 | val_f1=0.1349 val_auc=0.4249 | bad=2/15
Epoch 008 | loss=1671.0797 | thr*=0.574 | train_f1=0.2390 train_auc=0.4881 | val_f1=0.1361 val_auc=0.4292 | bad=3/15
Epoch 009 | loss=1444.7683 | thr*=0.534 | train_f1=0.4510 train_auc=0.7

In [21]:
df = pd.DataFrame([report])
df["model"] = "Sage"
df["task"] = task_col

df.to_csv(OUTPUT_PATH +f"Retail-GNN-{task_col}.csv")