In [1]:
# Core dependencies: PyTorch for tensors/inference, YAML for the config files.
import torch
import torch.nn.functional as F
import yaml
import warnings

# Silence all warnings. This runs before the torch_geometric imports below,
# so warnings raised while importing them are suppressed as well.
warnings.filterwarnings('ignore')

from torch_geometric.data import Data  # same import as used in train.py
from torch_geometric.utils import subgraph  # for extracting subgraphs
# Project-local module; its __all__ lists the available model classes.
import model

In [2]:
# Load the model and training configuration from YAML.
# Context managers make sure both file handles are closed right after parsing
# instead of being left open until garbage collection.
with open('config/model.yaml', 'r') as model_config_file:
    model_config = yaml.safe_load(model_config_file)
with open('config/training.yaml', 'r') as training_config_file:
    training_config = yaml.safe_load(training_config_file)

In [3]:
# Registry mapping each name exported in model.__all__ to the object it names.
model_dict = dict((cls_name, getattr(model, cls_name)) for cls_name in model.__all__)

In [4]:
# Pick the device used for data and model.
# GPU index 3 is preferred, but only when that GPU actually exists; on a CUDA
# machine with fewer GPUs we fall back to GPU 0 rather than failing later with
# an invalid device ordinal. Without CUDA the notebook runs on the CPU.
PREFERRED_GPU_INDEX = 3
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if PREFERRED_GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda:3


In [5]:
# Load the serialized graph object named in the training config and move it to
# the selected device. Both the path and the weights_only flag come from the
# 'data' section of training.yaml.
# NOTE(review): if weights_only is false, torch.load performs full unpickling,
# so only trusted files should be loaded this way.
data_cfg = training_config['data']
data = torch.load(data_cfg['data_path'], weights_only=data_cfg['weights_only']).to(device)
print(f"Loaded data: {data.num_nodes} nodes, {data.num_node_features} features per node")

Loaded data: 203768 nodes, 166 features per node


In [6]:
# Node IDs to run inference on: the first 25 nodes (IDs 0..24) of the graph.
nsubset_nodes = list(range(25))
# Long-typed index tensor on the same device as the data.
subset = torch.tensor(nsubset_nodes, device=device, dtype=torch.long)

In [7]:
# Extract the subgraph for the selected nodes. relabel_nodes=True renumbers
# the node IDs in the returned edge index so they index into the subset.
# Edge attributes are forwarded only if the data object exposes them.
full_edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None
edge_index_sub, edge_attr_sub = subgraph(
    subset, data.edge_index, full_edge_attr, relabel_nodes=True
)

In [8]:
# Node features of the selected nodes, in the same order as `subset`.
x_sub = data.x[subset]
# Wrap features and relabelled edges into a standalone graph object.
data_sub = Data(
    edge_index=edge_index_sub,
    x=x_sub,
)

In [9]:
# Carry edge attributes over to the subgraph only when the extraction step
# actually produced some and the source data exposes the attribute.
has_edge_features = edge_attr_sub is not None and hasattr(data, 'edge_attr')
if has_edge_features:
    data_sub.edge_attr = edge_attr_sub

In [10]:
# Build the configured model and restore the weights saved for one fold.
fold = 1  # change to the fold number you want

# Look up the model class by the type named in model.yaml and instantiate it
# with the dataset's feature width plus the configured hyper-parameters.
model_section = model_config['model']
model_cls = model_dict[model_section['type']]
model_net = model_cls(in_channels=data.num_node_features, **model_section['params'])
model_net = model_net.to(device)

# Checkpoints are laid out as <checkpoint>/<ClassName>/fold_<k>/<ClassName>_fold<k>.pt
arch_name = type(model_net).__name__
ckpt_path = f"{model_config['checkpoint']}/{arch_name}/fold_{fold}/{arch_name}_fold{fold}.pt"

# map_location remaps the stored tensors onto the active device.
checkpoint = torch.load(ckpt_path, map_location=device)
model_net.load_state_dict(checkpoint['model_state_dict'])
model_net.eval()  # switch to evaluation mode before inference
print(f"Loaded model from {ckpt_path}")

Loaded model from checkpoints/elliptic/SAGE/fold_1/SAGE_fold1.pt


In [11]:
# Run the model on the subgraph without tracking gradients.
with torch.no_grad():
    out = model_net(data_sub.x, data_sub.edge_index)
    probs = torch.softmax(out, dim=1)  # normalise over dim 1, one row per node
    preds = torch.argmax(probs, dim=1)  # index of the highest-probability class

# Probability assigned to each node's predicted class.
top_scores = probs.gather(1, preds.unsqueeze(1)).squeeze(1)

# Row i of the output corresponds to subset[i], so zipping the three sequences
# reports each prediction against its original node ID.
for node_id, label, score in zip(subset.tolist(), preds.tolist(), top_scores.tolist()):
    print(f"Original node {node_id}: predicted class {label} (confidence {score:.4f})")

Original node 0: predicted class 1 (confidence 0.8661)
Original node 1: predicted class 1 (confidence 0.9699)
Original node 2: predicted class 1 (confidence 1.0000)
Original node 3: predicted class 1 (confidence 0.9959)
Original node 4: predicted class 1 (confidence 0.9996)
Original node 5: predicted class 1 (confidence 0.9999)
Original node 6: predicted class 1 (confidence 0.9156)
Original node 7: predicted class 1 (confidence 0.9699)
Original node 8: predicted class 1 (confidence 1.0000)
Original node 9: predicted class 1 (confidence 0.8164)
Original node 10: predicted class 1 (confidence 0.7640)
Original node 11: predicted class 1 (confidence 0.9998)
Original node 12: predicted class 1 (confidence 0.9996)
Original node 13: predicted class 1 (confidence 0.9992)
Original node 14: predicted class 1 (confidence 0.9780)
Original node 15: predicted class 1 (confidence 1.0000)
Original node 16: predicted class 1 (confidence 1.0000)
Original node 17: predicted class 1 (confidence 0.8836)
Or