# Polymer Property Predictions 



In [None]:
# # general 
import pandas as pd
import numpy as np
import math 
from typing import Dict, Optional
import os 
import random 
from copy import deepcopy
import gc

# TensorFlow
import tensorflow as tf

# PyTorch
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Sampler
from torch.utils.data import Dataset
from torch.optim import AdamW, RMSprop
from torch.amp import GradScaler, autocast

# # PyTorch Geometric
from torch_geometric.nn import GINEConv
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch_geometric.utils import to_dense_adj
from torch_geometric.utils import to_dense_batch

# # OGB dataset 
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

# SKlearn 
from sklearn.model_selection import train_test_split

# Local imports 
from dataset_polymer_fixed import LMDBDataset

In [3]:
print("TensorFlow version:", tf.__version__)
print("Built with CUDA:", tf.test.is_built_with_cuda())
print("CUDA available:", tf.test.is_built_with_gpu_support())
print(tf.config.list_physical_devices('GPU'))
# list all GPUs
gpus = tf.config.list_physical_devices('GPU')

# check compute capability if GPU available
if gpus:
    for gpu in gpus:
        details = tf.config.experimental.get_device_details(gpu)
        print(f"Device: {gpu.name}")
        print(f"Compute Capability: {details.get('compute_capability')}")
else:
    print("No GPU found.")

TensorFlow version: 2.10.0
Built with CUDA: True
CUDA available: True
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Device: /physical_device:GPU:0
Compute Capability: (8, 6)


In [None]:
if os.path.exists('/kaggle'):
    DATA_ROOT = '/kaggle/input/neurips-open-polymer-prediction-2025'
    CHUNK_DIR = '/kaggle/working/processed_chunks'  # Writable directory
    BACKBONE_PATH = '/kaggle/input/polymer/hlgap-gnn3d-transformer-pcqm4mv2-v1.pt'
else:
    DATA_ROOT = 'data'
    CHUNK_DIR = os.path.join(DATA_ROOT, 'processed_chunks')
    BACKBONE_PATH = 'hlgap-gnn3d-transformer-pcqm4mv2-v1.pt'

TRAIN_LMDB = os.path.join(CHUNK_DIR, 'polymer_train3d_dist.lmdb')
TEST_LMDB = os.path.join(CHUNK_DIR, 'polymer_test3d_dist.lmdb')

print(f"Data root: {DATA_ROOT}")
print(f"LMDB directory: {CHUNK_DIR}")
print(f"Train LMDB: {TRAIN_LMDB}")
print(f"Test LMDB: {TEST_LMDB}")

# Create LMDBs if they don't exist
if not os.path.exists(TRAIN_LMDB) or not os.path.exists(TEST_LMDB):
    print('Building LMDBs...')
    os.makedirs(CHUNK_DIR, exist_ok=True)
    # Run the LMDB builders
    !python build_polymer_lmdb_fixed.py train
    !python build_polymer_lmdb_fixed.py test
    print('LMDB creation complete.')
else:
    print('LMDBs already exist.')

Data root: data
LMDB directory: data\processed_chunks
Train LMDB: data\processed_chunks\polymer_train3d_dist.lmdb
Test LMDB: data\processed_chunks\polymer_test3d_dist.lmdb
LMDBs already exist.


# Graph Transformer

In [None]:
# -------------------- Parent-aware wiring (CSV parents -> augmented LMDB key_ids) --------------------
label_cols = ['Tg','FFV','Tc','Density','Rg']
task2idx   = {k:i for i,k in enumerate(label_cols)}
AUG_KEY_MULT = 1000  

train_csv = pd.read_csv(os.path.join(DATA_ROOT, "train.csv"))
train_csv["id"] = train_csv["id"].astype(int)

# LMDB ids (augmented key_ids)
lmdb_ids_path = TRAIN_LMDB + ".ids.txt"
lmdb_ids = np.loadtxt(lmdb_ids_path, dtype=np.int64)
if lmdb_ids.ndim == 0: lmdb_ids = lmdb_ids.reshape(1)

# Parent map (preferred) -> key_id list per parent
pmap_path = TRAIN_LMDB + ".parent_map.tsv"
if os.path.exists(pmap_path):
    pmap = pd.read_csv(pmap_path, sep="\t") # cols: key_id, parent_id, aug_idx, seed
    pmap["key_id"] = pmap["key_id"].astype(np.int64)
    pmap["parent_id"] = pmap["parent_id"].astype(np.int64)
else:
    # fallback if parent_map missing: derive parent by integer division
    pmap = pd.DataFrame({
        "key_id": lmdb_ids.astype(np.int64),
        "parent_id": (lmdb_ids // AUG_KEY_MULT).astype(np.int64),
    })
parents_in_lmdb = np.sort(pmap["parent_id"].unique().astype(np.int64))

def parents_with_label(task: str) -> np.ndarray:
    m = ~train_csv[task].isna()
    have = train_csv.loc[m, "id"].astype(int).values # parents that have this label
    return np.intersect1d(have, parents_in_lmdb, assume_unique=False)

# Split by parents so we have no leakage and then we can expand to augmented key_ids 
def task_parent_split_keys(task: str, test_size=0.2, seed=42):
    parents_labeled = parents_with_label(task)
    if parents_labeled.size == 0:
        raise ValueError(f"No parents with labels for {task}")
    p_tr, p_va = train_test_split(parents_labeled, test_size=test_size, random_state=seed)
    tr_keys = pmap.loc[pmap.parent_id.isin(p_tr), "key_id"].astype(np.int64).values
    va_keys = pmap.loc[pmap.parent_id.isin(p_va), "key_id"].astype(np.int64).values
    return np.sort(tr_keys), np.sort(va_keys), np.sort(p_tr), np.sort(p_va)

# Build pools (augmented key_ids) per task
task_pools = {}
task_parent_splits = {}
for t in label_cols:
    tr_keys, va_keys, p_tr, p_va = task_parent_split_keys(t, test_size=0.2, seed=42)
    task_pools[t] = (tr_keys, va_keys)
    task_parent_splits[t] = (p_tr, p_va)

for t in label_cols:
    tr_keys, va_keys = task_pools[t]
    p_tr, p_va = task_parent_splits[t]
    print(f"{t:>7} → parents train={len(p_tr):5d} val={len(p_va):5d} | aug rows train={len(tr_keys):6d} val={len(va_keys):6d}")

     Tg → parents train=  408 val=  103 | aug rows train=  4080 val=  1030
    FFV → parents train= 5624 val= 1406 | aug rows train= 56240 val= 14060
     Tc → parents train=  589 val=  148 | aug rows train=  5890 val=  1480
Density → parents train=  490 val=  123 | aug rows train=  4900 val=  1230
     Rg → parents train=  491 val=  123 | aug rows train=  4910 val=  1230


In [None]:
# -------------------- KEY_SIZES (key_id -> parent n_atoms_2d) and simple finite samplers --------------------
parent_meta_path = TRAIN_LMDB + ".parent_meta.tsv"
if not os.path.exists(parent_meta_path):
    raise FileNotFoundError(f"Missing {parent_meta_path}. Rebuild LMDB with parent_meta.tsv enabled.")

parent_meta = pd.read_csv(parent_meta_path, sep="\t") # cols: parent_id, n_atoms_2d, star_count, replacement_Z
parent_meta["parent_id"] = parent_meta["parent_id"].astype(np.int64)
parent_meta = parent_meta.drop_duplicates(subset=["parent_id"])

key_sizes_df = pmap.merge(parent_meta[["parent_id","n_atoms_2d"]], on="parent_id", how="left")

if key_sizes_df["n_atoms_2d"].isna().any():
    med = int(key_sizes_df["n_atoms_2d"].median())
    key_sizes_df["n_atoms_2d"] = key_sizes_df["n_atoms_2d"].fillna(med)

KEY_SIZES: Dict[int,int] = dict(
    zip(key_sizes_df["key_id"].astype(np.int64).tolist(),
        key_sizes_df["n_atoms_2d"].astype(int).tolist())
)

class EmptyBatchSampler(Sampler):
    def __iter__(self):
        return iter(())           
    def __len__(self):
        return 0

class TokenBucketBatchSampler(Sampler):
    """
    Precompute a finite list of batches from (keys, sizes) under token/quadratic/max_batch constraints.
    __iter__ just yields those batches. __len__ returns exact count.
    """
    def __init__(self, keys, sizes_dict: Dict[int,int], *, max_tokens: int, max_quadratic: int, max_batch_size: int, shuffle: bool, seed: int, bins: int = 8):
        self.keys = np.asarray(keys, dtype=np.int64)
        self.sizes_dict = sizes_dict
        self.max_tokens = int(max_tokens)
        self.max_quadratic = int(max_quadratic)
        self.max_batch_size = int(max_batch_size)
        self.shuffle = bool(shuffle)
        self.seed = int(seed)
        self.bins = int(max(1, bins))
        self._batches = self._pack_once() # precompute finite list

    def __len__(self):
        return len(self._batches)

    def __iter__(self):
        return iter(self._batches)

    def _pack_once(self):
        if self.keys.size == 0:
            return []

        # materialize (key,size), guard size>=1
        pairs = [(int(k), max(1, int(self.sizes_dict.get(int(k), 1)))) for k in self.keys]

        # shuffle or deterministic order
        rng = np.random.default_rng(self.seed)
        if self.shuffle:
            rng.shuffle(pairs)
        else:
            pairs.sort(key=lambda t: (t[1], t[0]))

        sizes = np.array([s for _, s in pairs], dtype=np.int32)
        # Bins by size to reduce padding 
        B = int(min(self.bins, max(1, sizes.size)))
        try:
            qs = np.quantile(sizes, np.linspace(0, 1, B + 1)) if sizes.size > 1 else np.array([sizes[0], sizes[0]])
        except Exception:
            qs = np.array([sizes.min(), sizes.max()])

        bins = [[] for _ in range(B)]
        for i, (k, s) in enumerate(pairs):
            b = int(np.searchsorted(qs, s, side="right")) - 1
            b = max(0, min(B - 1, b))
            bins[b].append((k, s))

        if self.shuffle:
            for b in range(B):
                rng.shuffle(bins[b])

        batches = []
        # Round robin draw from bins so each item is used once...finite
        progress = True
        while progress:
            progress = False
            for b in range(B):
                if not bins[b]:
                    continue
                progress = True
                cur, cur_tokens, cur_quad = [], 0, 0
                while bins[b]:
                    k, s = bins[b][0]
                    next_len = len(cur) + 1
                    next_tokens = cur_tokens + s
                    next_quad = cur_quad + s * s
                    if (next_len <= self.max_batch_size
                        and next_tokens <= self.max_tokens
                        and next_quad  <= self.max_quadratic):
                        cur.append(k)
                        cur_tokens = next_tokens
                        cur_quad = next_quad
                        bins[b].pop(0)
                    else:
                        break
                if cur:
                    batches.append(cur)

        if not batches: # Safety 
            batches = [[int(k)] for (k, _) in pairs]
        return batches

class ChainBatchSamplers(Sampler):
    """Concatenate multiple precomputed batch samplers into one finite sequence."""
    def __init__(self, samplers, *, shuffle_order=False, seed=0):
        self.samplers = list(samplers)
        self.shuffle_order = bool(shuffle_order)
        self.seed = int(seed)
        # precompute concatenation so __len__ is exact
        order = list(range(len(self.samplers)))
        if self.shuffle_order:
            rnd = random.Random(self.seed)
            rnd.shuffle(order)
        self._batches = []
        for k in order:
            s = self.samplers[k]
            if len(s) == 0:
                continue
            self._batches.extend(list(s))

    def __len__(self):
        return len(self._batches)

    def __iter__(self):
        return iter(self._batches)

def build_bucket_sampler_for_keys(
    keys: np.ndarray, *,
    max_tokens: int,
    max_quadratic: Optional[int],
    max_batch_size: int,
    shuffle: bool,
    seed: int,
    ) -> Sampler:
    keys = np.asarray(keys, dtype=np.int64)

    if keys.size == 0:
        return EmptyBatchSampler()
    
    return TokenBucketBatchSampler(
        keys, 
        KEY_SIZES,
        max_tokens=max_tokens,
        max_quadratic=max_quadratic if max_quadratic is not None else 1_200_000,
        max_batch_size=max_batch_size,
        shuffle=shuffle, 
        seed=seed, 
        bins=8
    )

In [None]:
def _get_rdkit_feats_from_record(rec):
    arr = getattr(rec, "rdkit_feats", None)
    if arr is None:
        return torch.zeros(1, 15, dtype=torch.float32) # keep (1, D)
    v = torch.as_tensor(np.asarray(arr, np.float32).reshape(1, -1), dtype=torch.float32)
    return v # (1, D)

class LMDBtoPyGSingleTask(Dataset):
    def __init__(self, ids, lmdb_path, target_index=None, *, use_mixed_edges: bool = True, include_extra_atom_feats: bool = True):
        self.ids = np.asarray(ids, dtype=np.int64)
        self.base = LMDBDataset(self.ids, lmdb_path)
        self.t = target_index
        self.use_mixed_edges = use_mixed_edges
        self.include_extra_atom_feats = include_extra_atom_feats

    def __len__(self): return len(self.base)

    def __getitem__(self, idx):
        rec = self.base[idx]
        x = torch.as_tensor(rec.x, dtype=torch.long)
        ei = torch.as_tensor(rec.edge_index, dtype=torch.long)
        ea = torch.as_tensor(rec.edge_attr)

        # Mixed edges: 3 categorical + 32 RBF. Categorical-only if disabled
        edge_attr = ea.to(torch.float32) if self.use_mixed_edges else ea[:, :3].to(torch.long)

        d = Data(x=x, edge_index=ei, edge_attr=edge_attr, rdkit_feats=_get_rdkit_feats_from_record(rec)) # (1, D)

        if hasattr(rec, "pos"):
                d.pos  = torch.as_tensor(rec.pos, dtype=torch.float32)
        if self.include_extra_atom_feats and hasattr(rec, "extra_atom_feats"):
             d.extra_atom_feats = torch.as_tensor(rec.extra_atom_feats, dtype=torch.float32)
        if hasattr(rec, "has_xyz"):
             d.has_xyz = torch.as_tensor(rec.has_xyz, dtype=torch.float32) # (1,)
        if hasattr(rec, "dist"):
             d.hops = torch.as_tensor(rec.dist, dtype=torch.long).unsqueeze(0) # (1,L,L)

        if (self.t is not None) and hasattr(rec, "y"):
            yv = torch.as_tensor(rec.y, dtype=torch.float32).view(-1)
            if self.t < yv.numel(): 
                 d.y = yv[self.t:self.t+1] # (1,)
        return d

In [None]:
# -------------------- Token-bucket loaders with key->index mapping --------------------
WHALE_CUTOFF = 86 # p99 from EDA

class MapKeyBatchesToIndexBatches(Sampler):
    """Wrap a batch-sampler that yields key_ids; convert to dataset indices."""
    def __init__(self, key_batch_sampler, id2pos):
        self.key_batch_sampler = key_batch_sampler # yields lists of key_ids
        self.id2pos = id2pos # {key_id: index_in_dataset}
        # pre-map once for speed and safety
        self._batches = [[self.id2pos[int(k)] for k in batch] for batch in key_batch_sampler]

    def __len__(self):
        return len(self._batches)

    def __iter__(self):
        return iter(self._batches)

def make_loaders_for_task_from_pools(
        task, 
        task_pools,
        *,
        normal_max_tokens=9000,
        normal_max_quadratic=1080000,
        whale_max_tokens=2500,
        whale_max_quadratic=400000,
        max_batch_size=1024,
        use_mixed_edges=True,
        include_extra_atom_feats=True,
        num_workers=0,
        pin_memory=False,
        debug_single_process=True,
        whale_cutoff=WHALE_CUTOFF,
        ):
    assert 'KEY_SIZES' in globals(), "KEY_SIZES dict must be built first"

    t = task2idx[task]
    tr_keys, va_keys = task_pools[task]
    if len(tr_keys) == 0 or len(va_keys) == 0:
        raise ValueError(f"Empty pools for {task}. Check splits.")

    # Datasets stay in key-id order
    tr_ds = LMDBtoPyGSingleTask(tr_keys, TRAIN_LMDB, target_index=t, use_mixed_edges=use_mixed_edges, include_extra_atom_feats=include_extra_atom_feats)
    va_ds = LMDBtoPyGSingleTask(va_keys, TRAIN_LMDB, target_index=t, use_mixed_edges=use_mixed_edges, include_extra_atom_feats=include_extra_atom_feats)

    # Map key_id -> position within each dataset
    tr_id2pos = {int(k): i for i, k in enumerate(tr_keys)}
    va_id2pos = {int(k): i for i, k in enumerate(va_keys)}

    # Split ids by size using KEY_SIZES
    def split_keys(keys):
        small, whales = [], []
        for k in keys:
            if KEY_SIZES.get(int(k), 1) > whale_cutoff:
                whales.append(int(k))
            else:
                small.append(int(k))
        return np.array(small, dtype=np.int64), np.array(whales, dtype=np.int64)

    tr_small, tr_whales = split_keys(tr_keys)
    va_small, va_whales = split_keys(va_keys)

    # Build key-id samplers (finite)
    tr_small_s = build_bucket_sampler_for_keys(tr_small, max_tokens=normal_max_tokens, max_quadratic=normal_max_quadratic, max_batch_size=max_batch_size, shuffle=True, seed=42)
    tr_whale_s = build_bucket_sampler_for_keys(tr_whales, max_tokens=whale_max_tokens, max_quadratic=whale_max_quadratic, max_batch_size=max_batch_size, shuffle=True, seed=43)
    tr_keys_sampler = ChainBatchSamplers([tr_small_s, tr_whale_s], shuffle_order=True, seed=123)

    va_small_s = build_bucket_sampler_for_keys(va_small, max_tokens=normal_max_tokens, max_quadratic=normal_max_quadratic, max_batch_size=max_batch_size, shuffle=False, seed=123)
    va_whale_s = build_bucket_sampler_for_keys(va_whales, max_tokens=whale_max_tokens, max_quadratic=whale_max_quadratic, max_batch_size=max_batch_size, shuffle=False, seed=123)
    va_keys_sampler = ChainBatchSamplers([va_small_s, va_whale_s], shuffle_order=False, seed=0)

    # Wrap: key-id batches -> dataset index batches
    tr_sampler = MapKeyBatchesToIndexBatches(tr_keys_sampler, tr_id2pos)
    va_sampler = MapKeyBatchesToIndexBatches(va_keys_sampler, va_id2pos)

    # DataLoaders
    if debug_single_process:
        tr_loader = GeoDataLoader(tr_ds, batch_sampler=tr_sampler, num_workers=0, pin_memory=False)
        va_loader = GeoDataLoader(va_ds, batch_sampler=va_sampler, num_workers=0, pin_memory=False)
    else:
        tr_loader = GeoDataLoader(
            tr_ds, 
            batch_sampler=tr_sampler,
            num_workers=num_workers, 
            pin_memory=pin_memory,
            persistent_workers=(num_workers > 0),
            **({} if num_workers == 0 else dict(prefetch_factor=2))
        )
        va_loader = GeoDataLoader(
            va_ds, 
            batch_sampler=va_sampler,
            num_workers=num_workers, 
            pin_memory=pin_memory,
            persistent_workers=(num_workers > 0),
            **({} if num_workers == 0 else dict(prefetch_factor=2))
        )
    return tr_loader, va_loader

train_loader_tg,  val_loader_tg  = make_loaders_for_task_from_pools("Tg", task_pools, debug_single_process=True)
train_loader_den, val_loader_den = make_loaders_for_task_from_pools("Density", task_pools, debug_single_process=True)
train_loader_rg,  val_loader_rg  = make_loaders_for_task_from_pools("Rg", task_pools, debug_single_process=True)
train_loader_ffv, val_loader_ffv = make_loaders_for_task_from_pools("FFV", task_pools, debug_single_process=True)
train_loader_tc,  val_loader_tc  = make_loaders_for_task_from_pools("Tc", task_pools, debug_single_process=True)

## Step 5: Define the Model


In [None]:
def _batch_len_stats(b):
    # counts nodes per-graph from PyG's batch vector
    sizes = torch.bincount(b.batch, minlength=b.num_graphs)
    s = sizes.to(torch.int32).cpu().numpy()
    if s.size == 0:
        return "L: n=0"
    tokens = int(s.sum())
    quad = int((sizes.to(torch.int64)**2).sum().item())
    q50, q90, q95, q99 = np.quantile(s, [0.5, 0.9, 0.95, 0.99])
    return (f"L: n={len(s)} tokens={tokens} sumL2={quad}"
            f"p50={int(q50)} p90={int(q90)} p95={int(q95)} p99={int(q99)} max={int(s.max())}")


def free_cuda_memory(tag: str = ""):
    try:
        torch.cuda.synchronize()
    except Exception:
        pass
    torch.cuda.empty_cache()
    gc.collect()
    if tag:
        try:
            alloc = torch.cuda.memory_allocated() / (1024**2)
            reserv = torch.cuda.memory_reserved() / (1024**2)
            print(f"[mem:{tag}] allocated={alloc:.1f}MB reserved={reserv:.1f}MB")
        except Exception:
            pass

def reset_cuda_stats():
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

def train_hybrid_gnn_sota(
    model: nn.Module,
    train_loader,
    val_loader,
    *,
    lr: float = 5e-4,
    optimizer: str = "AdamW",
    weight_decay: float = 1e-5,
    epochs: int = 120,
    warmup_epochs: int = 5,
    patience: int = 15,
    clip_norm: float = 1.0,
    amp: bool = True,
    loss_name: str = "mse",   # "mse" or "huber"
    save_dir: str = "saved_models/gnn",
    tag: str = "model_sota",
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    os.makedirs(save_dir, exist_ok=True)
    model = model.to(device)

    # optimizer
    opt_name = optimizer.lower()
    if opt_name == "rmsprop":
        opt = RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.0)
    else:
        opt = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # cosine schedule w/ warmup
    def lr_factor(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / max(1, warmup_epochs)
        t = (epoch - warmup_epochs) / max(1, (epochs - warmup_epochs))
        return 0.5 * (1 + math.cos(math.pi * t))
    scaler = GradScaler("cuda", enabled=amp)

    def loss_fn(pred, target):
        if loss_name.lower() == "huber":
            return F.huber_loss(pred, target, delta=1.0)
        return F.mse_loss(pred, target)

    @torch.no_grad()
    def eval_once(loader):
        model.eval()
        preds, trues = [], []
        with torch.inference_mode():
            for i, b in enumerate(loader, 1):
                # print(_batch_len_stats(b))  # keep if you want
                b = b.to(device, non_blocking=True)
                p = model(b)
                preds.append(p.cpu())
                trues.append(b.y.view(-1,1).cpu())
                del p, b
        preds = torch.cat(preds).numpy(); trues = torch.cat(trues).numpy()
        mae = np.mean(np.abs(preds - trues))
        rmse = float(np.sqrt(np.mean((preds - trues)**2)))
        r2 = float(1 - np.sum((preds - trues)**2) / np.sum((trues - trues.mean())**2))
        return mae, rmse, r2

    best_mae = float("inf")
    best = None
    best_path = os.path.join(save_dir, f"{tag}.pt")
    bad = 0 

    for ep in range(1, epochs+1):
        for g in opt.param_groups:
            g["lr"] = lr * lr_factor(ep-1)

        model.train()
        total, count = 0.0, 0
        for step, b in enumerate(train_loader, start=1):
            b = b.to(device, non_blocking=True)
            if ep == 1 and step % 50 == 1:
            # if (ep <= 2) or (step % 50 == 1):
                print(_batch_len_stats(b))
            with autocast("cuda", enabled=amp):
                pred = model(b)
                loss = loss_fn(pred, b.y.view(-1,1))

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            if clip_norm is not None:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            scaler.step(opt); scaler.update()

            total += loss.item() * b.num_graphs
            count += b.num_graphs
            del pred, loss, b  

        free_cuda_memory(tag=f"after_epoch_{ep}")

        tr_mse = total / max(1, count)
        mae, rmse, r2 = eval_once(val_loader)
        print(f"Epoch {ep:03d} | tr_MSE {tr_mse:.5f} | val_MAE {mae:.5f} | val_RMSE {rmse:.5f} | R2 {r2:.4f}")

        if mae < best_mae - 1e-6:
            best_mae = mae
            best = deepcopy(model.state_dict())
            torch.save(best, best_path)
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print("Early stopping.")
                break


    if best is not None:
        model.load_state_dict(best)
    else:
        model.load_state_dict(torch.load(best_path, map_location=device))

    final_mae, final_rmse, final_r2 = eval_once(val_loader)
    print(f"[{tag}] Best Val — MAE {final_mae:.6f} | RMSE {final_rmse:.6f} | R2 {final_r2:.4f}")
    return model, best_path, {"MAE": final_mae, "RMSE": final_rmse, "R2": final_r2}

In [None]:
def _act(name: str):
    name = (name or "relu").lower()
    if name == "gelu": 
        return nn.GELU()
    if name in ("silu", "swish"): 
        return nn.SiLU()
    return nn.ReLU()


class AttnBiasFull(nn.Module):
    """
    Produces additive per-head attention bias of shape (B, H, L0, L0)
    from geometry (xyz), adjacency, SPD buckets, and categorical edge types.

    Accepts both old arg names (use_geo/use_adj_const/spd_max/rbf_K) and
    new ones (use_geo_bias/use_adj_bias/spd_buckets/rbf_k/edge_cats).
    """
    def __init__(
        self,
        n_heads: int,
        *,
        # older names I used
        use_geo: bool = None, 
        use_adj_const: bool = None, 
        use_spd: bool = True,
        spd_max: int = None, 
        rbf_K: int = None,
        # new names
        use_geo_bias: bool = None, 
        use_adj_bias: bool = None,
        spd_buckets: int = None, 
        rbf_k: int = None,
        edge_cats: tuple = (5, 6, 2),
        use_edge_bias: bool = True,
        # shared
        rbf_beta: float = 5.0, 
        activation: str = "relu",
        edge_cont_dim: int = 32,  # (kept for compatibility but not currently being used here)
        use_headnorm: bool = True,
        bound_scale: float = 0.1, # tanh scale for gentle bounding
    ):
        super().__init__()
        self.n_heads = int(n_heads)
        self.bound_scale = float(bound_scale)
        self.use_headnorm = bool(use_headnorm)

        def pick(*vals, default):
            for v in vals:
                if v is not None:
                    return v
            return default

        self.use_geo = bool(pick(use_geo, use_geo_bias, default=True))
        self.use_adj_const = bool(pick(use_adj_const, use_adj_bias, default=True))

        # SPD: if spd_buckets given then we will use exactly that; otherwise, we will use spd_max + 2 (0..spd_max + catch-all)
        if spd_buckets is not None:
            self.spd_buckets = int(spd_buckets)
        else:
            smax = 5 if spd_max is None else int(spd_max)
            self.spd_buckets = smax + 2 # 0..smax + 1(>=)

        K = int(pick(rbf_K, rbf_k, default=16))
        self.rbf_beta = float(rbf_beta)

        #  geometry -> per-head bias 
        if self.use_geo:
            centers = torch.linspace(0.0, 10.0, K)
            self.register_buffer("centers", centers, persistent=False)
            self.geo_mlp = nn.Sequential(
                nn.Linear(K, self.n_heads), # simple per-head projection
            )

        if self.use_adj_const:
            self.adj_bias = nn.Parameter(torch.zeros(self.n_heads))

        # SPD buckets -> per-head bias 
        self.use_spd = bool(use_spd)
        if self.use_spd:
            self.spd_emb = nn.Embedding(self.spd_buckets, self.n_heads)

        # edge categorical bias (configurable widths)
        t, s, c = edge_cats
        self.use_edge_bias = bool(use_edge_bias)
        if self.use_edge_bias:
            self.edge_emb0 = nn.Embedding(int(t), self.n_heads)
            self.edge_emb1 = nn.Embedding(int(s), self.n_heads)
            self.edge_emb2 = nn.Embedding(int(c), self.n_heads)
        else:
            self.edge_emb0 = self.edge_emb1 = self.edge_emb2 = None

        # per-component learnable scalers 
        self.alpha_geo = nn.Parameter(torch.tensor(0.2))
        self.alpha_spd = nn.Parameter(torch.tensor(0.2))
        self.alpha_adj = nn.Parameter(torch.tensor(0.2))
        self.alpha_edge = nn.Parameter(torch.tensor(0.2))

        # simple head-wise LayerNorms (normalize across H)
        if self.use_headnorm:
            self.ln_geo = nn.LayerNorm(self.n_heads)
            self.ln_spd = nn.LayerNorm(self.n_heads)
            self.ln_edge = nn.LayerNorm(self.n_heads)

    # ------------------------------------------ helpers ------------------------------------------
    def _apply_ln_heads(self, t: torch.Tensor, ln: nn.LayerNorm) -> torch.Tensor:
        """Apply LayerNorm across heads for a (B,H,L,L) tensor."""
        # (B,H,L,L) -> (B,L,L,H) -> LN(H) -> (B,H,L,L)
        t = t.permute(0, 2, 3, 1)
        t = ln(t)
        t = t.permute(0, 3, 1, 2).contiguous()
        return t

    def _bound(self, t: torch.Tensor) -> torch.Tensor:
        """Bound magnitudes to avoid dominating softmax; keeps gradients smooth."""
        return self.bound_scale * torch.tanh(t)

    @torch.no_grad()
    def _spd_bias(self, hops: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        """
        hops: (B, MAX_NODES, MAX_NODES) or (B, L0, L0) shortest-path distances (uint8/long)
        valid_mask: (B, L0, L0) bool, True where both tokens are real (not PAD)
        returns: (B, H, L0, L0) additive per-head bias
        """
        if hops.dim() == 2: # (L,L) -> (1,L,L)
            hops = hops.unsqueeze(0)

        B, L0, _ = valid_mask.shape

        # align SPD to current L0 (top-left block)
        if hops.size(1) != L0 or hops.size(2) != L0:
            hops = hops[:, :L0, :L0]

        # bucketize SPD: last bucket = catch-all (>= last)
        last = self.spd_buckets - 1
        raw = hops.to(valid_mask.device).long().clamp_min_(0)
        catch_all = raw >= last
        raw = raw.clamp_max(last - 1)
        bucket = torch.where(catch_all, raw.new_full(raw.shape, last), raw)

        # wipe invalid pairs
        bucket = torch.where(valid_mask, bucket, torch.zeros_like(bucket))

        emb = self.spd_emb(bucket) # (B, L0, L0, H)
        return emb.permute(0, 3, 1, 2).contiguous() # (B, H, L0, L0)

    def _edge_bias(self, edge_index, edge_attr, batch, L0, ptr=None) -> torch.Tensor:
        """
        Per-head additive bias from categorical bond attributes.
        Returns: (B, H, L0, L0)
        """
        u, v = edge_index
        be   = batch[u] # graph id per edge

        if ptr is None:
            B = int(batch.max().item()) + 1
            counts = torch.bincount(batch, minlength=B)
            ptr = torch.zeros(B + 1, dtype=torch.long, device=batch.device)
            ptr[1:] = torch.cumsum(counts, dim=0)
        B = int(ptr.numel() - 1)

        start = ptr[be]
        u_loc = (u - start).long()
        v_loc = (v - start).long()

        cat = edge_attr[:, :3].long()
        eh  = ( self.edge_emb0(cat[:, 0])
              + self.edge_emb1(cat[:, 1])
              + self.edge_emb2(cat[:, 2]) )  # (E,H)

        H = self.n_heads
        eb = torch.zeros((B, H, L0, L0), device=edge_attr.device, dtype=torch.float32)
        for b in range(B):
            m = (be == b)
            if not torch.any(m):
                continue
            eb[b, :, u_loc[m], v_loc[m]] += eh[m].T
        return eb

    # ------------------------------------------ forward ------------------------------------------
    def forward(self, pos, edge_index, edge_attr, batch, key_padding_mask, hops=None, ptr=None):
        """
        Returns (B, H, L0, L0) additive bias. PAD rows/cols are filled with large negative.
        """
        A = to_dense_adj(edge_index, batch=batch).squeeze(1) # (B,L0,L0)
        B, L0, _ = A.shape
        H = self.n_heads
        device = A.device

        valid = ~key_padding_mask# (B,L0)
        valid2d = valid.unsqueeze(2) & valid.unsqueeze(1) # (B,L0,L0)

        # geometry
        if self.use_geo and (pos is not None):
            pad_pos, _ = to_dense_batch(pos, batch) # (B,L0,3)
            diff = pad_pos.unsqueeze(2) - pad_pos.unsqueeze(1) # (B,L0,L0,3)
            dist = torch.sqrt(torch.clamp((diff**2).sum(-1), min=0.0)) # (B,L0,L0)
            centers = self.centers.to(dist.device)
            rbf = torch.exp(-self.rbf_beta * (dist.unsqueeze(-1) - centers)**2)
            geo = self.geo_mlp(rbf).permute(0, 3, 1, 2).contiguous() # (B,H,L0,L0)
        else:
            geo = torch.zeros((B, H, L0, L0), device=device)

        # adjacency constant per head
        if self.use_adj_const:
            adj = A.unsqueeze(1) * self.adj_bias.view(1, H, 1, 1) # (B,H,L0,L0)
        else:
            adj = torch.zeros_like(geo)

        # SPD
        if self.use_spd and (hops is not None):
            spd = self._spd_bias(hops, valid2d) # (B,H,L0,L0)
        else:
            spd = torch.zeros_like(geo)

        # edge categorical
        if self.use_edge_bias and (edge_attr is not None):
            edg = self._edge_bias(edge_index, edge_attr, batch, L0, ptr) # (B,H,L0,L0)
        else:
            edg = torch.zeros_like(geo)

        # normalize and bound each component, then scale 
        if self.use_headnorm:
            if self.use_geo:
                geo = self._apply_ln_heads(geo,  self.ln_geo)
            if self.use_spd:  
                spd = self._apply_ln_heads(spd,  self.ln_spd)
            if self.use_edge_bias: 
                edg = self._apply_ln_heads(edg, self.ln_edge)

        # gently bound to keep attention stable
        if self.use_geo:       
            geo = self._bound(geo)
        if self.use_spd:       
            spd = self._bound(spd)
        if self.use_edge_bias: 
            edg = self._bound(edg)
        # typically don't bound adj because it’s already a small learned scalar per head

        bias = (self.alpha_geo  * geo + self.alpha_spd  * spd + self.alpha_adj  * adj + self.alpha_edge * edg)

        # mask PAD rows/cols. keep diagonal 0 for valid tokens
        pad = key_padding_mask
        big_neg = torch.tensor(-1e4, device=bias.device, dtype=bias.dtype)
        bias = bias.masked_fill(pad.view(B, 1, L0, 1), big_neg)
        bias = bias.masked_fill(pad.view(B, 1, 1, L0), big_neg)
        I = torch.eye(L0, device=device, dtype=torch.bool).view(1, 1, L0, L0)
        bias = torch.where(I, bias.new_zeros(()), bias)

        return bias

In [None]:
class GINEBlock(nn.Module):
    def __init__(self, dim, activation="silu", dropout=0.1):
        super().__init__()
        act = _act(activation)
        self.norm1 = nn.LayerNorm(dim)
        self.conv = GINEConv(
            nn.Sequential(
                nn.Linear(dim, dim), act, 
                nn.Linear(dim, dim)
                ))
        self.drop1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 2*dim), act, 
            nn.Dropout(dropout), 
            nn.Linear(2*dim, dim)
            )
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_emb):
        h = self.conv(self.norm1(x), edge_index, edge_emb)
        x = x + self.drop1(h)
        x = x + self.drop2(self.ffn(self.norm2(x)))
        return x

class EdgeEncoderMixed(nn.Module):
    def __init__(self, emb_dim: int, cont_dim: int = 32, activation="silu"):
        super().__init__()
        act = _act(activation)
        self.emb0 = nn.Embedding(5, emb_dim)
        self.emb1 = nn.Embedding(6, emb_dim)
        self.emb2 = nn.Embedding(2, emb_dim)
        self.mlp_cont = nn.Sequential(
            nn.Linear(cont_dim, emb_dim), act,
            nn.Linear(emb_dim, emb_dim),
            nn.LayerNorm(emb_dim),      
        )

    def forward(self, edge_attr):
        cat  = edge_attr[:, :3].long()
        cont = edge_attr[:, 3:].float()
        e_cat  = self.emb0(cat[:,0]) + self.emb1(cat[:,1]) + self.emb2(cat[:,2])
        e_cont = self.mlp_cont(cont)
        return e_cat + 0.5 * e_cont # gentle scale on cont branch

In [None]:
class GraphTransformerGPS(nn.Module):
    def __init__(
        self,
        *,
        d_model: int = 256,
        nhead: int = 8,
        nlayers: int = 6,
        dropout: float = 0.2,
        drop_path: float = 0.0, # (kept for future work)
        activation: str = "silu",
        rdkit_dim: int = 15,
        use_extra_atom_feats: bool = True,
        extra_atom_dim: int = 5,
        # local GNN (GPS) settings
        local_layers: int = 2,
        use_mixed_edges: bool = True,
        cont_dim: int = 32,
        # bias knobs
        use_geo_bias: bool = True,
        use_spd_bias: bool = True,
        spd_max: int = 5,
        use_adj_const: bool = True,
        use_edge_bias: bool = True,
        # readout
        use_cls: bool = True,
        use_has_xyz: bool = True,
        head_hidden: int = 512,
    ):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.use_cls = use_cls
        self.use_has_xyz = use_has_xyz
        self.use_extra_atom_feats = use_extra_atom_feats
        self.bias_builder = AttnBiasFull(
            n_heads=nhead,
            rbf_k=32,
            rbf_beta=5.0,
            use_geo_bias=use_geo_bias,          
            use_adj_bias=use_adj_const,        
            use_spd=use_spd_bias,             
            spd_buckets=(spd_max + 1), # was spd_max; +1 gives the ">= spd_max" bucket
            use_edge_bias=use_edge_bias,
            edge_cats=(5, 6, 2),
            activation=activation,
        )


        act = _act(activation)

        # encoders
        self.atom_enc = AtomEncoder(emb_dim=d_model)
        if use_extra_atom_feats:
            self.extra_proj = nn.Sequential(
                nn.Linear(extra_atom_dim, d_model), act, 
                nn.Linear(d_model, d_model)
                )
            self.extra_gate = nn.Sequential(
                nn.Linear(2*d_model, d_model), act
                )

        # local GNN stack
        self.use_mixed_edges = use_mixed_edges
        if use_mixed_edges:
            self.edge_enc = EdgeEncoderMixed(d_model, cont_dim=cont_dim, activation=activation)
        else:
            self.edge_enc = BondEncoder(emb_dim=d_model)
        self.local_blocks = nn.ModuleList([GINEBlock(d_model, activation=activation, dropout=dropout) for _ in range(local_layers)])

        # transformer stack (PyTorch encoder)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=4*d_model,
            dropout=dropout, 
            activation=activation, 
            batch_first=True, 
            norm_first=True
            )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=nlayers, enable_nested_tensor=False)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_token, std=0.02)

        # readout: concat mean + max + (optional) CLS + attention pool
        self.gate_pool = nn.Sequential(
            nn.Linear(d_model, d_model//2), act, 
            nn.Linear(d_model//2, 1)
            )
        # features: mean(d), max(d), attn(d) = 3d, (+cls d) optional, + rdkit, + has_xyz
        pooled_dim = 3*d_model + (d_model if use_cls else 0)
        head_in = pooled_dim + rdkit_dim + (1 if use_has_xyz else 0)

        self.head = nn.Sequential(
            nn.LayerNorm(head_in),
            nn.Linear(head_in, head_hidden), act, 
            nn.Dropout(dropout),
            nn.Linear(head_hidden, head_hidden//2), act, 
            nn.Dropout(dropout),
            nn.Linear(head_hidden//2, 1),
        )

    def forward(self, data):
        # 1. atom encoding  and optional per-atom extras
        x = self.atom_enc(data.x) # (N,D)
        if getattr(self, "use_extra_atom_feats", False) and hasattr(data, "extra_atom_feats"):
            xa = self.extra_proj(data.extra_atom_feats.float()) # (N,D)
            x  = self.extra_gate(torch.cat([x, xa], dim=1)) # (N,D)

        # 2. local GNN over sparse graph
        e = self.edge_enc(data.edge_attr)
        for blk in self.local_blocks:
            x = blk(x, data.edge_index, e) # (N,D)

        # 3. pack to dense (no CLS yet)
        x_pad, valid = to_dense_batch(x, data.batch) # (B,L0,D)
        B, L0, D = x_pad.shape
        key_padding = ~valid # (B,L0) True == PAD

        # 4, head-wise attention bias on L0 tokens (B,H,L0,L0), pre-CLS
        # AttnBiasFull supports SPD, geometry, adjacency, edges
        hops = getattr(data, "hops", None) # (B,MAX_NODES,MAX_NODES) or None
        ptr = getattr(data, "ptr", None)
        attn_bias = self.bias_builder(
            pos=(data.pos if hasattr(data, "pos") else None),
            edge_index=data.edge_index,
            edge_attr=(data.edge_attr if hasattr(data, "edge_attr") else None),
            batch=data.batch,
            key_padding_mask=key_padding, # (B,L0), True=PAD
            hops=getattr(data, "hops", None),
            ptr=ptr
        )  # (B,H,L0,L0)                                              

        # 5. finalize bias (mask PAD rows/cols, keep diagonal 0), then optionally append CLS
        B, H, L = attn_bias.shape[0], attn_bias.shape[1], attn_bias.shape[-1]
        pad = key_padding # (B,L)
        huge = attn_bias.new_tensor(-1e4)

        # rows FROM PAD, cols TO PAD
        attn_bias = attn_bias.masked_fill(pad.view(B, 1, L, 1), huge)
        attn_bias = attn_bias.masked_fill(pad.view(B, 1, 1, L), huge)

        # keep diagonal = 0 on valid tokens
        I = torch.eye(L, device=attn_bias.device, dtype=torch.bool).view(1, 1, L, L)
        attn_bias = torch.where(I, attn_bias.new_zeros(()), attn_bias)

        # (optional) append CLS token at the end
        if getattr(self, "use_cls", False):
            # append CLS embedding
            cls = self.cls_token.expand(B, 1, D)# (B,1,D)
            x_pad = torch.cat([x_pad, cls], dim=1) # (B,L+1,D)

            # extend key_padding: CLS is always valid (False)
            key_padding = torch.cat(
                [key_padding, torch.zeros(B, 1, dtype=torch.bool, device=x_pad.device)],
                dim=1
            ) # (B,L+1)

            # pad bias by one row/col with zeros for CLS -> (B,H,L+1,L+1)
            attn_bias = F.pad(attn_bias, (0, 1, 0, 1), value=0.0)
            L = L + 1

        # 6. transformer encoder with 3D additive mask (B*H,L,L)
        attn_mask_3d = attn_bias.reshape(B * H, L, L).to(x_pad.dtype)
        h = self.encoder( # returns (B,L,D) when batch_first=True
            x_pad,
            mask=attn_mask_3d, # additive float mask 
        )

        # 7. pooling (mean + max + gated attention), plus optional CLS and then RDKit/has_xyz and head 
        # exclude CLS from token pools
        h_tok = h[:, :L0, :] # (B,L0,D)
        mask_f = valid.float()# (B,L0)

        mean = (h_tok * mask_f.unsqueeze(-1)).sum(1) / (mask_f.sum(1, keepdim=True) + 1e-8)  # (B,D)
        mmax, _ = (h_tok + (1.0 - mask_f.unsqueeze(-1)) * (-1e4)).max(dim=1) # (B,D)

        gate_logits = self.gate_pool(h_tok).squeeze(-1)# (B,L0)
        gate = torch.softmax(gate_logits.masked_fill(~valid, -1e4), dim=1)
        attn_pool = (h_tok * gate.unsqueeze(-1)).sum(1) # (B,D)

        parts = [mean, mmax, attn_pool]

        if getattr(self, "use_cls", False):
            parts.append(h[:, L-1, :]) # CLS vector (B,D)

        # RDKit globals
        rd = data.rdkit_feats.view(B, -1).float() # (B, rdkit_dim)
        parts.append(rd)

        # optional has_xyz scalar if present
        if getattr(self, "use_has_xyz", False) and hasattr(data, "has_xyz"):
            parts.append(data.has_xyz.view(B, 1).float())

        out = torch.cat(parts, dim=1)
        return self.head(out) # (B,1)

In [None]:
b = next(iter(train_loader_tg))
rd_dim = int(b.rdkit_feats.shape[-1])

model_tg = GraphTransformerGPS(
    d_model=256, nhead=8, nlayers=5, dropout=0.15,
    rdkit_dim=rd_dim, activation="gelu",
    use_extra_atom_feats=True, extra_atom_dim=5,
    local_layers=2, use_mixed_edges=True, cont_dim=32,
    use_geo_bias=False, use_spd_bias=False, spd_max=5,
    use_adj_const=False, use_edge_bias=False,
    use_cls=True, use_has_xyz=True, head_hidden=512
).to(b.x.device)

model_tg, ckpt_tg, met_tg = train_hybrid_gnn_sota(
    model_tg, train_loader_tg, val_loader_tg,
    lr=5e-4, optimizer="AdamW", weight_decay=1e-5,
    epochs=200, warmup_epochs=10, patience=20,
    clip_norm=1.0, amp=True, loss_name="mse",
    save_dir="saved_models/gt_tg_spd", tag="graphtransformer_tg_spd"
)
del model_tg, train_loader_tg, val_loader_tg
free_cuda_memory(tag="after_Tg")
reset_cuda_stats()

model_den = GraphTransformerGPS(
    d_model=256, nhead=8, nlayers=5, dropout=0.15,
    rdkit_dim=rd_dim, activation="gelu",
    use_extra_atom_feats=True, extra_atom_dim=5,
    local_layers=2, use_mixed_edges=True, cont_dim=32,
    use_geo_bias=False, use_spd_bias=False, spd_max=5,
    use_adj_const=False, use_edge_bias=False,
    use_cls=False, use_has_xyz=False, head_hidden=512
).to(b.x.device)

model_den, ckpt_den, met_den = train_hybrid_gnn_sota(
    model_den, train_loader_den, val_loader_den,
    lr=5e-4, optimizer="AdamW", weight_decay=1e-5,
    epochs=200, warmup_epochs=10, patience=20,
    clip_norm=1.0, amp=True, loss_name="mse",
    save_dir="saved_models/gt_den_spd", tag="graphtransformer_den_spd"
)
del model_den, train_loader_den, val_loader_den
free_cuda_memory(tag="after_Density")
reset_cuda_stats()

# Rg
model_rg = GraphTransformerGPS(
    d_model=256, nhead=8, nlayers=5, dropout=0.15,
    rdkit_dim=rd_dim, activation="gelu",
    use_extra_atom_feats=True, extra_atom_dim=5,
    local_layers=2, use_mixed_edges=True, cont_dim=32,
    use_geo_bias=True, use_spd_bias=True, spd_max=5,
    use_adj_const=False, use_edge_bias=False,
    use_cls=False, use_has_xyz=True, head_hidden=512
).to(b.x.device)

model_rg, ckpt_rg, met_rg = train_hybrid_gnn_sota(
    model_rg, train_loader_rg, val_loader_rg,
    lr=5e-4, optimizer="AdamW", weight_decay=1e-5,
    epochs=200, warmup_epochs=10, patience=20,
    clip_norm=1.0, amp=True, loss_name="mse",
    save_dir="saved_models/gt_rg_spd", tag="graphtransformer_rg_spd"
)
del model_rg, train_loader_rg, val_loader_rg
free_cuda_memory(tag="after_Rg")
reset_cuda_stats()
# Tc
model_tc = GraphTransformerGPS(
    d_model=256, nhead=8, nlayers=5, dropout=0.15,
    rdkit_dim=rd_dim, activation="gelu",
    use_extra_atom_feats=True, extra_atom_dim=5,
    local_layers=2, use_mixed_edges=True, cont_dim=32,
    use_geo_bias=True, use_spd_bias=False, spd_max=5,
    use_adj_const=False, use_edge_bias=False,
    use_cls=True, use_has_xyz=False, head_hidden=512
).to(b.x.device)

model_tc, ckpt_tc, met_tc = train_hybrid_gnn_sota(
    model_tc, train_loader_tc, val_loader_tc,
    lr=5e-4, optimizer="AdamW", weight_decay=1e-5,
    epochs=200, warmup_epochs=10, patience=20,
    clip_norm=1.0, amp=True, loss_name="mse",
    save_dir="saved_models/gt_tc_spd", tag="graphtransformer_tc_spd"
)
del model_tc, train_loader_tc, val_loader_tc
free_cuda_memory(tag="after_Tc")
reset_cuda_stats()

L: n=10 tokens=960 sumL2=92160 p50=96 p90=96 p95=96 p99=96 max=96


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


[mem:after_epoch_1] allocated=105.3MB reserved=174.0MB
Epoch 001 | tr_MSE 22771.91854 | val_MAE 100.23416 | val_RMSE 130.70927 | R2 -0.7897
[mem:after_epoch_2] allocated=127.2MB reserved=168.0MB
Epoch 002 | tr_MSE 22484.83189 | val_MAE 98.10220 | val_RMSE 128.20306 | R2 -0.7217
[mem:after_epoch_3] allocated=126.7MB reserved=184.0MB
Epoch 003 | tr_MSE 20913.60447 | val_MAE 90.39704 | val_RMSE 118.47663 | R2 -0.4704
[mem:after_epoch_4] allocated=127.4MB reserved=172.0MB
Epoch 004 | tr_MSE 16876.24435 | val_MAE 79.91265 | val_RMSE 100.89502 | R2 -0.0664
[mem:after_epoch_5] allocated=126.7MB reserved=186.0MB
Epoch 005 | tr_MSE 13248.30115 | val_MAE 83.10323 | val_RMSE 99.83572 | R2 -0.0441
[mem:after_epoch_6] allocated=126.7MB reserved=190.0MB
Epoch 006 | tr_MSE 13417.76134 | val_MAE 87.05917 | val_RMSE 103.53501 | R2 -0.1229
[mem:after_epoch_7] allocated=126.8MB reserved=188.0MB
Epoch 007 | tr_MSE 13460.40146 | val_MAE 68.70322 | val_RMSE 86.08450 | R2 0.2237
[mem:after_epoch_8] allocated

[graphtransformer_tg_spd] Best Val — MAE 48.938103 | RMSE 60.409924 | R2 0.6177
[graphtransformer_den_spd] Best Val — MAE 0.041639 | RMSE 0.080110 | R2 0.6345
[graphtransformer_rg_spd] Best Val — MAE 1.897893 | RMSE 2.820058 | R2 0.6387
[graphtransformer_tc_spd] Best Val — MAE 0.029127 | RMSE 0.043332 | R2 0.7801
[graphtransformer_ffv_spd] Best Val — MAE 0.007025 | RMSE 0.011341 | R2 0.8756

# Original LMDB
Minimal baseline:
[graphtransformer_tg_spd] Best Val — MAE 52.682880 | RMSE 66.310356 | R2 0.5394
[graphtransformer_den_spd] Best Val — MAE 0.033281 | RMSE 0.069495 | R2 0.7250
[graphtransformer_rg_spd] Best Val — MAE 2.274377 | RMSE 3.533569 | R2 0.4327
[graphtransformer_tc_spd] Best Val — MAE 0.027580 | RMSE 0.044943 | R2 0.7635
[graphtransformer_ffv_spd] Best Val — MAE 0.005713 | RMSE 0.008959 | R2 0.9223
use_edge_bias=True:
[graphtransformer_tg_spd] Best Val — MAE 59.157528 | RMSE 78.540993 | R2 0.3538
[graphtransformer_den_spd] Best Val — MAE 0.040733 | RMSE 0.075703 | R2 0.6736
[graphtransformer_rg_spd] Best Val — MAE 2.184664 | RMSE 3.215525 | R2 0.5302
[graphtransformer_tc_spd] Best Val — MAE 0.029891 | RMSE 0.045550 | R2 0.7570
[graphtransformer_ffv_spd] Best Val — MAE 0.014560 | RMSE 0.024105 | R2 0.4378
use_adj_const=True
[graphtransformer_tg_spd] Best Val — MAE 55.685497 | RMSE 71.508026 | R2 0.4644
[graphtransformer_den_spd] Best Val — MAE 0.043538 | RMSE 0.076304 | R2 0.6684
[graphtransformer_rg_spd] Best Val — MAE 2.194414 | RMSE 3.418880 | R2 0.4689
[graphtransformer_tc_spd] Best Val — MAE 0.030531 | RMSE 0.047150 | R2 0.7397
[graphtransformer_ffv_spd] Best Val — MAE 0.015711 | RMSE 0.025118 | R2 0.3895
use_spd_bias=True
[graphtransformer_tg_spd] Best Val — MAE 56.409588 | RMSE 74.155319 | R2 0.4240
[graphtransformer_den_spd] Best Val — MAE 0.051897 | RMSE 0.086738 | R2 0.5716
[graphtransformer_rg_spd] Best Val — MAE 2.043078 | RMSE 3.004847 | R2 0.5897
[graphtransformer_tc_spd] Best Val — MAE 0.030650 | RMSE 0.046136 | R2 0.7507
[graphtransformer_ffv_spd] Best Val — MAE 0.014980 | RMSE 0.024944 | R2 0.3980
use_geo_bias=True
[graphtransformer_tg_spd] Best Val — MAE 57.246323 | RMSE 71.635178 | R2 0.4624
[graphtransformer_den_spd] Best Val — MAE 0.040926 | RMSE 0.073173 | R2 0.6951
[graphtransformer_rg_spd] Best Val — MAE 2.186664 | RMSE 3.221086 | R2 0.5286
[graphtransformer_tc_spd] Best Val — MAE 0.029472 | RMSE 0.045144 | R2 0.7613
[graphtransformer_ffv_spd] Best Val — MAE 0.010738 | RMSE 0.018613 | R2 0.6648
use_cls=True
[graphtransformer_tg_spd] Best Val — MAE 51.327293 | RMSE 65.371567 | R2 0.5523
[graphtransformer_den_spd] Best Val — MAE 0.048267 | RMSE 0.078447 | R2 0.6495
[graphtransformer_rg_spd] Best Val — MAE 2.215581 | RMSE 3.335439 | R2 0.4945
[graphtransformer_tc_spd] Best Val — MAE 0.031363 | RMSE 0.046542 | R2 0.7463
[graphtransformer_ffv_spd] Best Val — MAE 0.011230 | RMSE 0.019117 | R2 0.6464
use_has_xyz=True

# New LMDB with canonicalize_psmiles
Minimal baseline:
[graphtransformer_tg_spd] Best Val — MAE 55.100464 | RMSE 71.059502 | R2 0.4711
[graphtransformer_den_spd] Best Val — MAE 0.030763 | RMSE 0.068418 | R2 0.7334
[graphtransformer_rg_spd] Best Val — MAE 2.184328 | RMSE 3.163936 | R2 0.5452
[graphtransformer_tc_spd] Best Val — MAE 0.027480 | RMSE 0.044562 | R2 0.7675
[graphtransformer_ffv_spd] Best Val — MAE 0.006242 | RMSE 0.013769 | R2 0.8166
use_edge_bias=True:
[graphtransformer_tg_spd] Best Val — MAE 64.235008 | RMSE 84.792068 | R2 0.2469
[graphtransformer_den_spd] Best Val — MAE 0.041322 | RMSE 0.079782 | R2 0.6375
[graphtransformer_rg_spd] Best Val — MAE 2.271991 | RMSE 3.328248 | R2 0.4967
[graphtransformer_tc_spd] Best Val — MAE 0.030659 | RMSE 0.046286 | R2 0.7491
[graphtransformer_ffv_spd] Best Val — MAE 0.012568 | RMSE 0.021181 | R2 0.5659
use_adj_const=True
[graphtransformer_tg_spd] Best Val — MAE 56.366352 | RMSE 73.193703 | R2 0.4388
[graphtransformer_den_spd] Best Val — MAE 0.047595 | RMSE 0.079254 | R2 0.6423
[graphtransformer_rg_spd] Best Val — MAE 2.042789 | RMSE 2.847963 | R2 0.6315
[graphtransformer_tc_spd] Best Val — MAE 0.033112 | RMSE 0.048311 | R2 0.7267
[graphtransformer_ffv_spd] Best Val — MAE 0.014251 | RMSE 0.022048 | R2 0.5297
use_spd_bias=True
[graphtransformer_tg_spd] Best Val — MAE 52.129559 | RMSE 64.974182 | R2 0.5578
[graphtransformer_den_spd] Best Val — MAE 0.050838 | RMSE 0.088112 | R2 0.5579
[graphtransformer_rg_spd] Best Val — MAE 2.168722 | RMSE 3.208059 | R2 0.5324
[graphtransformer_tc_spd] Best Val — MAE 0.033525 | RMSE 0.048122 | R2 0.7288
[graphtransformer_ffv_spd] Best Val — MAE 0.018180 | RMSE 0.029006 | R2 0.1860
use_geo_bias=True
[graphtransformer_tg_spd] Best Val — MAE 49.387646 | RMSE 63.215427 | R2 0.5814
[graphtransformer_den_spd] Best Val — MAE 0.045089 | RMSE 0.074725 | R2 0.6820
[graphtransformer_rg_spd] Best Val — MAE 2.082205 | RMSE 2.920620 | R2 0.6124
[graphtransformer_tc_spd] Best Val — MAE 0.033126 | RMSE 0.045319 | R2 0.7595
[graphtransformer_ffv_spd] Best Val — MAE 0.012233 | RMSE 0.020871 | R2 0.5785
use_cls=True
[graphtransformer_tg_spd] Best Val — MAE 51.995239 | RMSE 68.995934 | R2 0.5013
[graphtransformer_den_spd] Best Val — MAE 0.042371 | RMSE 0.080913 | R2 0.6272
[graphtransformer_rg_spd] Best Val — MAE 2.145705 | RMSE 3.241805 | R2 0.5225
[graphtransformer_tc_spd] Best Val — MAE 0.031370 | RMSE 0.044972 | R2 0.7632
[graphtransformer_ffv_spd] Best Val — MAE 0.015359 | RMSE 0.023831 | R2 0.4505
use_has_xyz=True
[graphtransformer_tg_spd] Best Val — MAE 52.416542 | RMSE 67.310928 | R2 0.5254
[graphtransformer_den_spd] Best Val — MAE 0.049355 | RMSE 0.081608 | R2 0.6207
[graphtransformer_rg_spd] Best Val — MAE 2.096542 | RMSE 2.964315 | R2 0.6007
[graphtransformer_tc_spd] Best Val — MAE 0.031574 | RMSE 0.046309 | R2 0.7489
[graphtransformer_ffv_spd] Best Val — MAE 0.013870 | RMSE 0.020263 | R2 0.6028



# Conclusions

| Model Type | Feature | MAE | RMSE | R2 |
|---|---|---|---|---|
| RF3D | FFV | 0.007621 | 0.017553 | 0.6605 |
| RF3D_Aug | FFV | 0.007578 | 0.017404 | 0.6662 |
| GNN2 | FFV | 0.013817 | 0.023902 | 0.4473 |
| GNN2_Aug | FFV | 0.013092 | 0.022793 | 0.4974 |
| ET | FFV | 0.006651 | 0.016818 | 0.6883 |
| ET_Aug | FFV | 0.006635 | 0.016826 | 0.6880 |
| **GT** | **FFV** | **0.005713** | **0.008959** | **0.9223** |
| GT_Aug | FFV | 0.007025 | 0.011341 | 0.8756 |
| RF3D | Tg | 58.315801 | 74.296699 | 0.5846 |
| RF3D_Aug | Tg | 58.143107 | 74.521032 | 0.5821 |
| **GNN2** | **Tg** | **47.105114** | **61.480179** | **0.6040** |
| GNN2_Aug | Tg | 51.539692 | 70.575638 | 0.4782 |
| ET | Tg | 58.973811 | 74.658978 | 0.5806 |
| ET_Aug | Tg | 58.521052 | 74.475532 | 0.5826 |
| GT | Tg | 78.903389 | 98.401192 |-0.0143 |
| GT_Aug | Tg | 52.365578 | 67.529610 | 0.5223 |
| RF3D | Tc | 0.029937 | 0.045036 | 0.7313 |
| RF3D_Aug | Tc | 0.029675 | 0.044853 | 0.7335 |
| **GNN2** | **Tc** | **0.025115** | **0.041331** | **0.8000** |
| **GNN2_Aug** | **Tc** | **0.025252** | **0.039670** | **0.8157** |
| ET | Tc | 0.028888 | 0.043469 | 0.7497 |
| ET_Aug | Tc | 0.027990 | 0.042644 | 0.7591 |
| GT | Tc | 0.032644 | 0.046613 | 0.7456 |
| GT_Aug | Tc | 0.028590 | 0.043121 | 0.7822 |
| RF3D | Rg | 1.648818 | 2.493712 | 0.7299 |
| RF3D_Aug | Rg | 1.668425 | 2.517235 | 0.7248 |
| GNN2 | Rg | 2.115880 | 2.801481 | 0.6434 |
| GNN2_Aug | Rg | 1.532573 | 2.405382 | 0.7371 |
| ET | Rg | 1.619464 | 2.522478 | 0.7237 |
| **ET_Aug** | **Rg** | **1.609396** | **2.526705** | **0.7227** |
| GT | Rg | 2.579300 | 3.521387 | 0.4366 |
| GT_Aug | Rg | 2.134301 | 3.066199 | 0.5728 |
| RF3D | Density | 0.037793 | 0.070932 | 0.7847 |
| RF3D_Aug | Density | 0.037123 | 0.070212 | 0.7891 |
| GNN2 | Density | 0.031735 | 0.067845 | 0.7379 |
| GNN2_Aug | Density | 0.030458 | 0.070372 | 0.7180 |
| ET | Density | 0.028492 | 0.052839 | 0.8805 |
| **ET_Aug** | **Density** | **0.028135** | **0.051842** | **0.8850** |
| GT | Density | 0.104749 | 0.134771 | -0.0343 |
| GT_Aug | Density | 0.087159 | 0.126079 | 0.0948 |












