In [1]:
import sys, os, pickle
import random
sys.path.append(os.path.abspath(".."))

import numpy as np
import torch
from utils.graph_creation import create_hetero_graph
from utils.training_utils import train_val_test_split, train
from models.custom.custom import Custom

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration: choose the compute device and seed the Python, NumPy and
# PyTorch RNGs so the edge split and training below are reproducible when the
# notebook is re-run from a fresh kernel.
# NOTE: seeding alone does not make every CUDA kernel deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA generators
device = "cuda" if torch.cuda.is_available() else "cpu"

In [8]:
# Load the heterogeneous graph built with holds_as_nodes=True and
# standardize=True from a local pickle cache; build and cache it on first run.
# The cache is keyed only by file name: delete the file to force a rebuild
# after changing create_hetero_graph or its arguments.
# NOTE: pickle.load can execute arbitrary code — only load caches this project wrote.
GRAPH_CACHE_PATH = "../data/hetero_data_with_holds_std.pkl"

if os.path.exists(GRAPH_CACHE_PATH):
    with open(GRAPH_CACHE_PATH, "rb") as f:
        hetero_data = pickle.load(f)
else:
    hetero_data = create_hetero_graph(holds_as_nodes=True, standardize=True)
    # On a fresh checkout ../data may not exist yet; create it so the write succeeds.
    os.makedirs(os.path.dirname(GRAPH_CACHE_PATH), exist_ok=True)
    with open(GRAPH_CACHE_PATH, "wb") as f:
        pickle.dump(hetero_data, f)

In [9]:
# Split the ("user", "rates", "problem") edges into message-passing, train,
# validation and test subsets.
# NOTE(review): the three fractions sum to 0.9; presumably the remaining ~0.1
# becomes the test split, and by_user=True presumably splits each user's
# ratings separately — confirm in utils.training_utils.train_val_test_split.
MESSAGE_P = 0.7  # presumably edges used only for message passing
TRAIN_P = 0.1
VAL_P = 0.1

edge_type = ("user", "rates", "problem")
message_data, train_data, val_data, test_data = train_val_test_split(
    hetero_data,
    edge_type=edge_type,
    message_p=MESSAGE_P,
    train_p=TRAIN_P,
    val_p=VAL_P,
    by_user=True,
)

In [10]:
# Instantiate the model from the message-passing graph (per the original note it
# shares the same metadata and feature dims as the other splits) and move it to
# `device` *before* the optimizer is built from its parameters, as torch.optim
# recommends.
# NOTE(review): assumes Custom is a torch.nn.Module (it exposes .parameters()).
model = Custom(
    hetero_data=message_data,
    hidden_channels=64,
    num_layers=2,
    dropout=0.1,
).to(device)

In [11]:
# Adam over all of the model's parameters, with a small weight decay.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

In [12]:
# Train on the train edges, using message_data as the message-passing graph.
# Validation Recall@20 is printed after every epoch (see the log below); the
# loop itself lives in utils.training_utils.train.
# NOTE(review): the training loss jumps after epochs 5 and 10 in the log, which
# suggests hn_increase_rate=5 steps up hard-negative difficulty every 5 epochs
# — confirm in train().
train(
    model=model,
    optimizer=optimizer,
    device=device,
    # graphs: message-passing graph plus supervised train/val splits
    message_data=message_data,
    train_data=train_data,
    val_data=val_data,
    edge_type=edge_type,
    # flags forwarded to train()
    hetero=True,
    features=True,
    # schedule
    num_epochs=20,
    batch_size=1024,
    hn_increase_rate=5,
)

Computing hard negative candidates
Starting training...
Epoch 1, average training loss: 0.1381
Validation Recall@20: 0.2003
Epoch 2, average training loss: 0.0905
Validation Recall@20: 0.1958
Epoch 3, average training loss: 0.0809
Validation Recall@20: 0.1924
Epoch 4, average training loss: 0.0749
Validation Recall@20: 0.1943
Epoch 5, average training loss: 0.0735
Validation Recall@20: 0.1950
Epoch 6, average training loss: 0.2695
Validation Recall@20: 0.1891
Epoch 7, average training loss: 0.2488
Validation Recall@20: 0.1975
Epoch 8, average training loss: 0.2390
Validation Recall@20: 0.2073
Epoch 9, average training loss: 0.2359
Validation Recall@20: 0.2120
Epoch 10, average training loss: 0.2327
Validation Recall@20: 0.2091
Epoch 11, average training loss: 0.2708
Validation Recall@20: 0.1958
Epoch 12, average training loss: 0.2680
Validation Recall@20: 0.1981
Epoch 13, average training loss: 0.2654
Validation Recall@20: 0.2012
Epoch 14, average training loss: 0.2636
Validation Recal