# Walking through training and generating embeddings for a heterogeneous model using GiGL

This notebook walks you through using GiGL to train a heterogeneous model.
For this example, we will use a very small "toy graph" as our dataset.

By the end of this notebook you will have:

1. Pre-processed the toy graph using the GiGL data preprocessor
2. Done a forward and backward pass of the model using GiGL dataloaders
3. Completed inference on the model

This file is intended to be a companion to our example [heterogeneous_training.py](./heterogenous_training.py).

In [None]:
%load_ext autoreload
%autoreload 2
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TensorFlow logs


from gigl.common.utils.jupyter_magics import change_working_dir_to_gigl_root
change_working_dir_to_gigl_root()

## Visualize the dataset

First, let's visualize the toy graph :)

In [None]:
from torch_geometric.data import HeteroData

from gigl.common.utils.jupyter_magics import GraphVisualizer
from gigl.src.mocking.toy_asset_mocker import load_toy_graph


original_graph_heterodata: HeteroData = load_toy_graph(graph_config_path="examples/tutorial/KDD_2025/graph_config.yaml")
# Visualize the graph
GraphVisualizer.visualize_graph(original_graph_heterodata)
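
As a rough mental model for what the visualization shows, a heterogeneous graph keeps separate node sets per node type and separate edge lists per `(source type, relation, destination type)` triple. The sketch below uses plain Python dictionaries rather than PyG's `HeteroData`. Only the `user` and `story` type names come from this notebook; the node counts, edges, and helper functions are made up for illustration.

```python
from collections import defaultdict

# Hypothetical toy data: the type names mirror this notebook,
# but the node counts and edges are invented for illustration.
num_nodes = {"user": 3, "story": 4}
edges = {
    ("user", "to", "story"): [(0, 1), (0, 2), (1, 3)],
    ("story", "to", "user"): [(1, 0), (2, 0), (3, 1)],
}


def out_neighbors(edge_type, src):
    """Return destination node ids reachable from `src` via `edge_type`."""
    return [dst for s, dst in edges[edge_type] if s == src]


def degree_by_type(edge_type):
    """Count outgoing edges per source node for a single edge type."""
    degrees = defaultdict(int)
    for src, _ in edges[edge_type]:
        degrees[src] += 1
    return dict(degrees)


user_to_story = ("user", "to", "story")
print(out_neighbors(user_to_story, 0))
print(degree_by_type(user_to_story))
```

Node ids are local to their type here, so user `0` and story `0` are different nodes. That is why every lookup needs the edge type as well as the id.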

## Preprocessor
TODO(mkolodner)

In [None]:
# Do a simple forward/backward pass of the model

# TODO(mkolodner): Swap to the on-the-fly task config from pre-populator.
task_config_uri = "examples/tutorial/KDD_2025/toy_graph_task_config.yaml"
# First, we need to load the dataset
import torch

from gigl.distributed import (
    DistLinkPredictionDataset,
    build_dataset_from_task_config_uri,
)
# GiGL is meant to operate in a very large distributed setting, so we need to
# initialize the process group, even when running a single process as we do here.
torch.distributed.init_process_group(
    backend="gloo",  # Use the Gloo backend for CPU training.
    init_method="tcp://localhost:29500",
    rank=0,
    world_size=1,
)

# `build_dataset_from_task_config_uri` is a utility function
# to build a dataset in a distributed manner.
# It will:
# 1. Read the serialized graph data whose location is specified in the task config.
# 2. Load the graph data in a distributed manner.
# 3. Partition the graph data into shards for distributed training.
# 4. Optionally (when training), generate the training splits.
dataset: DistLinkPredictionDataset = build_dataset_from_task_config_uri(
    task_config_uri=task_config_uri,
    is_inference=False,
    _tfrecord_uri_pattern=".*tfrecord",  # Our example data uses a different tfrecord pattern.
)

# And instantiate a dataloader:
from gigl.distributed import DistABLPLoader

loader = DistABLPLoader(
    dataset=dataset,
    num_neighbors=[2, 2],  # Example neighbor sampling configuration.
    input_nodes=("user", torch.tensor([0])),  # Example input nodes, adjust as needed.
    batch_size=1,
    supervision_edge_type=("user", "to", "story"),  # Supervision edge type defined in the graph.
    pin_memory_device=torch.device("cpu"),  # Only CPU training for this example.
)
data: HeteroData = next(iter(loader))
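
To build intuition for `num_neighbors=[2, 2]`, here is a minimal sketch of fan-out neighbor sampling in plain Python. It is not GiGL's sampler: it assumes a single node type, a hypothetical adjacency list, and uniform sampling without replacement, whereas the real loader samples per edge type across a partitioned graph. The function name `sample_neighborhood` is invented for this illustration.

```python
import random

# Hypothetical adjacency lists; node ids are arbitrary.
adjacency = {
    0: [1, 2, 3],
    1: [4, 5],
    2: [5, 6, 7],
    3: [],
    4: [],
    5: [],
    6: [],
    7: [],
}


def sample_neighborhood(seeds, fanouts, rng):
    """Sample up to fanouts[k] neighbors of every frontier node at hop k."""
    visited = list(seeds)  # Nodes in discovery order, seeds first.
    frontier = list(seeds)
    nodes_per_hop = [len(seeds)]
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            neighbors = adjacency[node]
            picked = rng.sample(neighbors, min(fanout, len(neighbors)))
            for neighbor in picked:
                if neighbor not in visited:  # Keep each node only once.
                    visited.append(neighbor)
                    next_frontier.append(neighbor)
        nodes_per_hop.append(len(next_frontier))
        frontier = next_frontier
    return visited, nodes_per_hop


nodes, nodes_per_hop = sample_neighborhood([0], fanouts=[2, 2], rng=random.Random(0))
print(nodes, nodes_per_hop)
```

Each entry in `fanouts` bounds how many neighbors are drawn per node at that hop, so the list length sets the sampling depth. A two-layer model would typically pair with a two-entry list.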

In [None]:
# Now let's look at the data we just loaded.
print(data)

# You might notice a few things about the data that are different from vanilla PyG:
# * num_sampled_nodes and num_sampled_edges are present,
#   representing the number of nodes and edges sampled per hop.
# * y_positive is added; it is a dict of anchor node -> target nodes.

GraphVisualizer.visualize_graph(data)
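
The per-hop counts can be used to recover which nodes were sampled at which hop. The sketch below assumes, as is the case for PyG's neighbor samplers, that the sampled nodes of a given type are stored in the order they were discovered (seed nodes first). Whether the batch above follows the same ordering is an assumption to verify. The node ids, counts, and the helper `split_by_hop` are hypothetical.

```python
from itertools import accumulate


def split_by_hop(node_ids, num_sampled_nodes):
    """Group a hop-ordered list of node ids using per-hop counts."""
    offsets = [0, *accumulate(num_sampled_nodes)]
    if offsets[-1] != len(node_ids):
        raise ValueError("per-hop counts must sum to the number of nodes")
    return [node_ids[start:end] for start, end in zip(offsets, offsets[1:])]


# Hypothetical values for one node type: 1 seed node,
# then 2 nodes sampled at hop 1 and 3 nodes sampled at hop 2.
node_ids = [0, 7, 3, 9, 4, 2]
num_sampled_nodes = [1, 2, 3]

hops = split_by_hop(node_ids, num_sampled_nodes)
print(hops)
```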

In [None]:
# Initialize a model and do a forward pass
# You can interop with any PyG model, but we will use HGTConv for this example.
from torch_geometric.nn import HGTConv

model = HGTConv(
    in_channels=data.num_node_features,
    out_channels=16,  # Example output dimension.
    metadata=data.metadata(),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

# Do a forward pass
embeddings = model(data.x_dict, data.edge_index_dict)

print(f"Embeddings: {embeddings}")

In [None]:
# Now let's define a loss function for the link prediction task.
# TODO should we define this in some util file?

# Note that we really should wrap this

def compute_loss(model: torch.nn.Module, data: HeteroData) -> torch.Tensor:
    main_out: dict[str, torch.Tensor] = model(data.x_dict, data.edge_index_dict)
    # data.y_positive = {
    #   0: [1, 2],
    #   1: [3, 4, 5],
    # }
    anchor_nodes = torch.arange(data["user"].batch_size).repeat_interleave(
        torch.tensor([len(v) for v in data.y_positive.values()])
    )
    # anchor_nodes = [0, 0, 1, 1, 1]
    target_nodes = torch.cat([v for v in data.y_positive.values()])
    # target_nodes = [1, 2, 3, 4, 5]
    # Use MarginRankingLoss for link prediction
    loss_fn = torch.nn.MarginRankingLoss()
    query_embeddings = main_out["user"][anchor_nodes]
    target_embeddings = main_out["story"][target_nodes]
    loss = loss_fn(
        input1=query_embeddings,
        input2=target_embeddings,
        target=torch.ones_like(query_embeddings, dtype=torch.float32),
    )
    return loss

# Note that in practice you would want to wrap this in a training loop
# but for this example doing just one pass is sufficient.
# A training loop example can be found in:
# examples/tutorial/KDD_2025/heterogeneous_training.py
loss = compute_loss(model, data)
print(f"Loss: {loss.item()}")

# And we can do a backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
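
To make the index bookkeeping and the loss in `compute_loss` easier to follow, here is a dependency-free sketch of the same two steps. It flattens a `y_positive`-style dict into parallel anchor and target lists, using the position of each anchor the way `torch.arange(...).repeat_interleave(...)` does, and evaluates the margin ranking formula `max(0, -y * (x1 - x2) + margin)` averaged over all elements. The helper names and the scalar inputs are hypothetical; the notebook itself applies the loss to embedding tensors.

```python
def flatten_positives(y_positive):
    """Expand {anchor: [targets]} into aligned anchor/target index lists."""
    anchors, targets = [], []
    # The anchor index is the position within the batch, not the dict key.
    for position, positives in enumerate(y_positive.values()):
        anchors.extend([position] * len(positives))
        targets.extend(positives)
    return anchors, targets


def margin_ranking_loss(x1, x2, y, margin=0.0):
    """Mean of max(0, -y * (x1 - x2) + margin) over all elements."""
    terms = [max(0.0, -t * (a - b) + margin) for a, b, t in zip(x1, x2, y)]
    return sum(terms) / len(terms)


# Same illustrative dict as in the comments of compute_loss.
example_y_positive = {0: [1, 2], 1: [3, 4, 5]}
anchors, targets = flatten_positives(example_y_positive)
print(anchors, targets)

# Hypothetical scalar scores; y = 1 asks for x1 to rank above x2.
example_loss = margin_ranking_loss(
    x1=[0.5, -1.0, 2.0],
    x2=[0.0, 0.0, 3.0],
    y=[1, 1, 1],
)
print(example_loss)
```

With a target of `1`, a pair contributes zero loss once `x1` exceeds `x2` by at least the margin, and contributes the shortfall otherwise.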

In [None]:
# Now if we run the loss function again, we should see a different value.
loss = compute_loss(model, data)
print(f"Loss after backward pass: {loss.item()}")