In [1]:
import os
import random
import sys

import numpy as np
import torch
from torch.optim import Adam

sys.path.append(os.path.abspath(".."))

from utils import create_hetero_graph
from utils import train, train_val_test_split
from utils.gformer import GFormerWrapper
from models.gformer.Params import args

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Single seed for every RNG the pipeline may touch. It is set before the split so
# that the split, the model initialisation in the next cell and training are
# reproducible under Restart Kernel -> Run All.
# NOTE(review): which RNG `train_val_test_split` / `train` draw from is not
# visible from this notebook, so Python, NumPy and torch are all seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

# 1) Build the heterogeneous graph with holds materialised as their own nodes.
hetero_data = create_hetero_graph(holds_as_nodes=True)

# 2) Split the user–problem "rates" edges into message / train / val / test sets.
#    70% message-passing, 10% train, 10% val; the remaining share is presumably
#    the test set (TODO confirm in utils). `by_user=True` presumably splits within
#    each user's edges -- verify against `train_val_test_split`.
edge_type = ("user", "rates", "problem")
message_data, train_data, val_data, test_data = train_val_test_split(
    hetero_data,
    edge_type=edge_type,
    message_p=0.7,
    train_p=0.1,
    val_p=0.1,
    by_user=True,
)

In [3]:
# 3) Instantiate GFormer wrapper
# 3) Wrap GFormer behind the common model interface, constructed from the
#    message-passing split and the rated-edge type, on the selected device.
model = GFormerWrapper(
    message_data,
    edge_type,
    device=device,
)

# 4) Adam over all wrapper parameters; learning rate and weight decay are taken
#    from GFormer's own hyper-parameter namespace (`args.lr`, `args.reg`).
optimizer = Adam(
    params=model.parameters(),
    lr=args.lr,
    weight_decay=args.reg,
)

In [4]:
# 5) Train using your generic training loop (hetero=True, features=False)
# 5) Run the shared training loop.
#    hetero=True    -> the model's embeddings are a dict keyed by node type
#                      ('user', 'problem').
#    features=False -> the loop calls model(edge_index_dict) without node features;
#                      the GFormer wrapper ignores the dict it receives.
#    Epoch count and batch size come from GFormer's hyper-parameters (`args`).
train_summary = train(
    model=model,
    optimizer=optimizer,
    message_data=message_data,
    train_data=train_data,
    val_data=val_data,
    edge_type=edge_type,
    device=device,
    hetero=True,
    features=False,
    num_epochs=args.epoch,
    batch_size=args.batch,
)
# NOTE(review): in the recorded run below, validation Recall@20 peaked at epoch 1
# while training loss kept falling -- consider early stopping or a smaller lr.
train_summary

Computing hard negative candidates
Starting training...
Epoch 1, average training loss: 0.6390
Validation Recall@20: 0.2113 (95% CI [0.2063, 0.2162])
Epoch 2, average training loss: 0.5033
Validation Recall@20: 0.2086 (95% CI [0.2036, 0.2136])
Epoch 3, average training loss: 0.4967
Validation Recall@20: 0.2052 (95% CI [0.2002, 0.2101])
Epoch 4, average training loss: 0.4972
Validation Recall@20: 0.2046 (95% CI [0.1997, 0.2096])
Epoch 5, average training loss: 0.4923
Validation Recall@20: 0.2042 (95% CI [0.1993, 0.2092])
Epoch 6, average training loss: 0.4864
Validation Recall@20: 0.2061 (95% CI [0.2011, 0.2110])
Epoch 7, average training loss: 0.4805
Validation Recall@20: 0.2044 (95% CI [0.1995, 0.2093])
Epoch 8, average training loss: 0.4757
Validation Recall@20: 0.2040 (95% CI [0.1991, 0.2090])
Epoch 9, average training loss: 0.4711
Validation Recall@20: 0.2041 (95% CI [0.1992, 0.2091])
Epoch 10, average training loss: 0.4671
Validation Recall@20: 0.2051 (95% CI [0.2002, 0.2100])
Epo

{'best_recall': 0.21125715301230105,
 'best_epoch': 1,
 'val_ci_low': 0.20630080339970538,
 'val_ci_high': 0.21621350262489672,
 'val_n_users': 20103}