In [1]:
import sys, os
sys.path.append(os.path.abspath(".."))

import torch
from torch.optim import Adam

from utils import create_hetero_graph
from utils import train, train_val_test_split
from utils.gformer import GFormerWrapper
from models.gformer.Params import args

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration: compute device, plus a fixed seed so the random split below is
# reproducible under Restart Kernel -> Run All.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
SEED = 42
# torch.manual_seed seeds the CPU generator and the generators of all CUDA devices.
# NOTE(review): if create_hetero_graph / train_val_test_split draw from Python's
# `random` or NumPy rather than torch, seed those generators here as well.
torch.manual_seed(SEED)

# 1) Build the heterogeneous graph, with holds added as graph nodes.
hetero_data = create_hetero_graph(holds_as_nodes=True)

# 2) Split the user-problem rating edges into message-passing / train / val / test
#    sets: 70% message, 10% train, 10% val. Presumably the remaining 10% becomes
#    the test set, and `by_user=True` splits each user's edges separately;
#    confirm both in utils.train_val_test_split.
edge_type = ("user", "rates", "problem")
message_data, train_data, val_data, test_data = train_val_test_split(
    hetero_data,
    edge_type=edge_type,
    message_p=0.7,
    train_p=0.1,
    val_p=0.1,
    by_user=True,
)

In [3]:
# 3) Wrap GFormer around the message-passing split of the graph.
model = GFormerWrapper(
    message_data,
    edge_type,
    device=device,
)

# 4) Adam over all model parameters. The learning rate and weight decay come from
#    the GFormer hyperparameter namespace `args`.
#    NOTE(review): confirm the wrapper does not also add an explicit `args.reg`
#    regularization loss; otherwise the regularization would be applied twice.
optimizer = Adam(
    params=model.parameters(),
    lr=args.lr,
    weight_decay=args.reg,
)

In [4]:
# 5) Fit with the shared training loop from `utils`.
#    hetero=True    -> train treats the model's embeddings as a per-node-type dict
#                      (presumably keyed 'user' / 'problem'; verify in utils.train).
#    features=False -> the model is called without node features, presumably as
#                      model(edge_index_dict); the GFormer wrapper ignores that dict.
train(
    model=model,
    optimizer=optimizer,
    device=device,
    # graph splits and the supervised edge type
    message_data=message_data,
    train_data=train_data,
    val_data=val_data,
    edge_type=edge_type,
    # how `train` interacts with the model
    hetero=True,
    features=False,
    # schedule, taken from the GFormer hyperparameters
    num_epochs=args.epoch,
    batch_size=args.batch,
)

Computing hard negative candidates
Starting training...
Epoch 1, average training loss: 0.6735
Validation Recall@20: 0.21136725677167476
Epoch 2, average training loss: 0.5170
Validation Recall@20: 0.2071394009896197
Epoch 3, average training loss: 0.4493
Validation Recall@20: 0.20680410858617923
Epoch 4, average training loss: 0.4348
Validation Recall@20: 0.20758076972444378
Epoch 5, average training loss: 0.4271
Validation Recall@20: 0.20690866546521933
Epoch 6, average training loss: 0.4195
Validation Recall@20: 0.20955185749031416
Epoch 7, average training loss: 0.4131
Validation Recall@20: 0.20972237625396462
Epoch 8, average training loss: 0.4041
Validation Recall@20: 0.20853493471114853
Epoch 9, average training loss: 0.3958
Validation Recall@20: 0.21036481280549538
Epoch 10, average training loss: 0.3862
Validation Recall@20: 0.2112049148818125
Epoch 11, average training loss: 0.3763
Validation Recall@20: 0.2118963060449858
Epoch 12, average training loss: 0.3667
Validation Rec