In [1]:
import os

# Use the eager mode
os.environ['PT_HPU_LAZY_MODE'] = '0'

# Verify the environment variable is set
print(f"PT_HPU_LAZY_MODE: {os.environ['PT_HPU_LAZY_MODE']}")

import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

import habana_frameworks.torch.core as htcore

# use rich traceback

from rich import traceback
traceback.install()

device = torch.device("hpu")

PT_HPU_LAZY_MODE: 0




2.4.0a0+git74cd574


  return isinstance(object, types.FunctionType)


# Link Prediction on MovieLens

This colab notebook shows how to load a set of `*.csv` files as input and construct a heterogeneous graph from it.
We will then use this dataset as input into a [heterogeneous graph model](https://pytorch-geometric.readthedocs.io/en/latest/notes/heterogeneous.html#hgtutorial), and use it for the task of link prediction.
A few code cells require user input to let the code run through successfully.
If you are stuck on cells that require input, take a look at the fully filled out tutorial [here](https://medium.com/@pytorch_geometric/link-prediction-on-heterogeneous-graphs-with-pyg-6d5c29677c70).

We are going to use the [MovieLens dataset](https://grouplens.org/datasets/movielens/) collected by the GroupLens research group.
This toy dataset describes ratings and tagging activity from MovieLens.
The dataset contains approximately 100k ratings across more than 9k movies from more than 600 users.
We are going to use this dataset to generate two node types holding data for movies and users, respectively, and one edge type connecting users and movies, representing the relation of whether a user has rated a specific movie.

The link prediction task then tries to predict missing ratings, and can, for example, be used to recommend users new movies.

## Heterogeneous Graph Creation

First, we download the dataset to an arbitrary folder (in this case, the current directory):

In [2]:
from torch_geometric.data import download_url, extract_zip

url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'
extract_zip(download_url(url, '.'), '.')

movies_path = './ml-latest-small/movies.csv'
ratings_path = './ml-latest-small/ratings.csv'

Using existing file ml-latest-small.zip
Extracting ./ml-latest-small.zip


Before we create the heterogeneous graph, let’s take a look at the data.

In [3]:
import pandas as pd

print('movies.csv:')
print('===========')
print(pd.read_csv(movies_path)[["movieId", "genres"]].head())
print()
print('ratings.csv:')
print('============')
print(pd.read_csv(ratings_path)[["userId", "movieId"]].head())

movies.csv:
   movieId                                       genres
0        1  Adventure|Animation|Children|Comedy|Fantasy
1        2                   Adventure|Children|Fantasy
2        3                               Comedy|Romance
3        4                         Comedy|Drama|Romance
4        5                                       Comedy

ratings.csv:
   userId  movieId
0       1        1
1       1        3
2       1        6
3       1       47
4       1       50


We see that the `movies.csv` file provides two useful columns: `movieId` assigns a unique identifier to each movie, while the `genres` column represent genres of the given movie.
We can make use of this column to define a feature representation that can be easily interpreted by machine learning models.

In [4]:
# Load the entire movie data frame into memory:
movies_df = pd.read_csv(movies_path, index_col='movieId')

# Split genres and convert into indicator variables:
genres = movies_df['genres'].str.get_dummies('|')
print(genres[["Action", "Adventure", "Drama", "Horror"]].head())

# Use genres as movie input features:
movie_feat = torch.from_numpy(genres.values).to(torch.float)
assert movie_feat.size() == (9742, 20)  # 20 genres in total.

         Action  Adventure  Drama  Horror
movieId                                  
1             0          1      0       0
2             0          1      0       0
3             0          0      0       0
4             0          0      1       0
5             0          0      0       0


The `ratings.csv` data connects users (as given by `userId`) and movies (as given by `movieId`).
Due to simplicity, we do not make use of the additional `timestamp` and `rating` information.
Here, we first read the `*.csv` file from disk, and create a mapping that maps entry IDs to a consecutive value in the range `{ 0, ..., num_rows - 1 }`.
This is needed as we want our final data representation to be as compact as possible, *e.g.*, the representation of a movie in the first row should be accessible via `x[0]`.

Afterwards, we obtain the final `edge_index` representation of shape `[2, num_ratings]` from `ratings.csv` by merging mapped user and movie indices with the raw indices given by the original data frame.

In [5]:
# Load the entire ratings data frame into memory:
ratings_df = pd.read_csv(ratings_path)

# Create a mapping from unique user indices to range [0, num_user_nodes):
unique_user_id = ratings_df['userId'].unique()
unique_user_id = pd.DataFrame(data={
    'userId': unique_user_id,
    'mappedID': pd.RangeIndex(len(unique_user_id)),
})
print("Mapping of user IDs to consecutive values:")
print("==========================================")
print(unique_user_id.head())
print()
# Create a mapping from unique movie indices to range [0, num_movie_nodes):
unique_movie_id = pd.DataFrame(data={
    'movieId': movies_df.index,
    'mappedID': pd.RangeIndex(len(movies_df)),
})
print("Mapping of movie IDs to consecutive values:")
print("===========================================")
print(unique_movie_id.head())

# Perform merge to obtain the edges from users and movies:
ratings_user_id = pd.merge(ratings_df['userId'], unique_user_id,
                            left_on='userId', right_on='userId', how='left')
ratings_user_id = torch.from_numpy(ratings_user_id['mappedID'].values)
ratings_movie_id = pd.merge(ratings_df['movieId'], unique_movie_id,
                            left_on='movieId', right_on='movieId', how='left')
ratings_movie_id = torch.from_numpy(ratings_movie_id['mappedID'].values)

# With this, we are ready to construct our `edge_index` in COO format
# following PyG semantics:
edge_index_user_to_movie = torch.stack([ratings_user_id, ratings_movie_id], dim=0)
assert edge_index_user_to_movie.size() == (2, 100836)

print()
print("Final edge indices pointing from users to movies:")
print("=================================================")
print(edge_index_user_to_movie)

Mapping of user IDs to consecutive values:
   userId  mappedID
0       1         0
1       2         1
2       3         2
3       4         3
4       5         4

Mapping of movie IDs to consecutive values:
   movieId  mappedID
0        1         0
1        2         1
2        3         2
3        4         3
4        5         4

Final edge indices pointing from users to movies:
tensor([[   0,    0,    0,  ...,  609,  609,  609],
        [   0,    2,    5,  ..., 9462, 9463, 9503]])


With this, we are ready to initialize our `HeteroData` object and pass the necessary information to it.
Note that we also pass in a `node_id` vector to each node type in order to reconstruct the original node indices from sampled subgraphs.
We also take care of adding reverse edges to the `HeteroData` object.
This allows our GNN model to use both directions of the edge for message passing:

In [6]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

data = HeteroData()

# Save node indices:
data["user"].node_id = torch.arange(len(unique_user_id))
data["movie"].node_id = torch.arange(len(movies_df))

# Add the node features and edge indices:
data["movie"].x = movie_feat
data["user", "rates", "movie"].edge_index = edge_index_user_to_movie

# We also need to make sure to add the reverse edges from movies to users
# in order to let a GNN be able to pass messages in both directions.
# We can leverage the `T.ToUndirected()` transform for this from PyG:

data = T.ToUndirected()(data)

print(data)

assert data.node_types == ["user", "movie"]
assert data.edge_types == [("user", "rates", "movie"),
                           ("movie", "rev_rates", "user")]
assert data["user"].num_nodes == 610
assert data["user"].num_features == 0
assert data["movie"].num_nodes == 9742
assert data["movie"].num_features == 20
assert data["user", "rates", "movie"].num_edges == 100836
assert data["movie", "rev_rates", "user"].num_edges == 100836

HeteroData(
  user={ node_id=[610] },
  movie={
    node_id=[9742],
    x=[9742, 20],
  },
  (user, rates, movie)={ edge_index=[2, 100836] },
  (movie, rev_rates, user)={ edge_index=[2, 100836] }
)


## Defining Edge-level Training Splits

Since our data is now ready-to-be-used, we can split the ratings of users into training, validation, and test splits.
This is needed in order to ensure that we leak no information about edges used during evaluation into the training phase.
For this, we make use of the [`transforms.RandomLinkSplit`](https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.RandomLinkSplit) transformation from PyG.
This transforms randomly divides the edges in the `("user", "rates", "movie")` into training, validation and test edges.
The `disjoint_train_ratio` parameter further separates edges in the training split into edges used for message passing (`edge_index`) and edges used for supervision (`edge_label_index`).
Note that we also need to specify the reverse edge type `("movie", "rev_rates", "user")`.
This allows the `RandomLinkSplit` transform to drop reverse edges accordingly to not leak any information into the training phase.

In [7]:
# For this, we first split the set of edges into
# training (80%), validation (10%), and testing edges (10%).
# Across the training edges, we use 70% of edges for message passing,
# and 30% of edges for supervision.
# We further want to generate fixed negative edges for evaluation with a ratio of 2:1.
# Negative edges during training will be generated on-the-fly, so we don't want to
# add them to the graph right away.
# Overall, we can leverage the `RandomLinkSplit()` transform for this from PyG:
transform = T.RandomLinkSplit(
    num_val=0.1,  # 10% for validation
    num_test=0.1,  # 10% for testing
    disjoint_train_ratio=0.3,  # 30% of training edges for supervision
    neg_sampling_ratio=2.0,  # Fixed negative edges with a ratio of 2:1
    add_negative_train_samples=False,  # Negative edges during training generated on-the-fly
    edge_types=("user", "rates", "movie"),
    rev_edge_types=("movie", "rev_rates", "user"),
)

train_data, val_data, test_data = transform(data)
print("Training data:")
print("==============")
print(train_data)
print()
print("Validation data:")
print("================")
print(val_data)

assert train_data["user", "rates", "movie"].num_edges == 56469
assert train_data["user", "rates", "movie"].edge_label_index.size(1) == 24201
assert train_data["movie", "rev_rates", "user"].num_edges == 56469
# No negative edges added:
assert train_data["user", "rates", "movie"].edge_label.min() == 1
assert train_data["user", "rates", "movie"].edge_label.max() == 1

assert val_data["user", "rates", "movie"].num_edges == 80670
assert val_data["user", "rates", "movie"].edge_label_index.size(1) == 30249
assert val_data["movie", "rev_rates", "user"].num_edges == 80670
# Negative edges with ratio 2:1:
assert val_data["user", "rates", "movie"].edge_label.long().bincount().tolist() == [20166, 10083]

Training data:
HeteroData(
  user={ node_id=[610] },
  movie={
    node_id=[9742],
    x=[9742, 20],
  },
  (user, rates, movie)={
    edge_index=[2, 56469],
    edge_label=[24201],
    edge_label_index=[2, 24201],
  },
  (movie, rev_rates, user)={ edge_index=[2, 56469] }
)

Validation data:
HeteroData(
  user={ node_id=[610] },
  movie={
    node_id=[9742],
    x=[9742, 20],
  },
  (user, rates, movie)={
    edge_index=[2, 80670],
    edge_label=[30249],
    edge_label_index=[2, 30249],
  },
  (movie, rev_rates, user)={ edge_index=[2, 80670] }
)


## Defining Mini-batch Loaders

We are now ready to create a mini-batch loader that will generate subgraphs that can be used as input into our GNN.
While this step is not strictly necessary for small-scale graphs, it is absolutely necessary to apply GNNs on larger graphs that do not fit onto GPU memory otherwise.
Here, we make use of the [`loader.LinkNeighborLoader`](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.LinkNeighborLoader) which samples multiple hops from both ends of a link and creates a subgraph from it.
Here, `edge_label_index` serves as the "seed links" to start sampling from.

In [8]:
# torch_geometric/loader/neighbor_sampler.py

from typing import Callable, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import Tensor

# from torch_geometric.typing import SparseTensor
import random


class EdgeIndex(NamedTuple):
    edge_index: Tensor
    e_id: Optional[Tensor]
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        edge_index = self.edge_index.to(*args, **kwargs)
        e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None
        return EdgeIndex(edge_index, e_id, self.size)


# class Adj(NamedTuple):
#     adj_t: SparseTensor
#     e_id: Optional[Tensor]
#     size: Tuple[int, int]

#     def to(self, *args, **kwargs):
#         adj_t = self.adj_t.to(*args, **kwargs)
#         e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None
#         return Adj(adj_t, e_id, self.size)


class DenseNeighborSampler(torch.utils.data.DataLoader):
    r"""The neighbor sampler from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, which allows
    for mini-batch training of GNNs on large-scale graphs where full-batch
    training is not feasible.

    Given a GNN with :math:`L` layers and a specific mini-batch of nodes
    :obj:`node_idx` for which we want to compute embeddings, this module
    iteratively samples neighbors and constructs bipartite graphs that simulate
    the actual computation flow of GNNs.

    More specifically, :obj:`sizes` denotes how much neighbors we want to
    sample for each node in each layer.
    This module then takes in these :obj:`sizes` and iteratively samples
    :obj:`sizes[l]` for each node involved in layer :obj:`l`.
    In the next layer, sampling is repeated for the union of nodes that were
    already encountered.
    The actual computation graphs are then returned in reverse-mode, meaning
    that we pass messages from a larger set of nodes to a smaller one, until we
    reach the nodes for which we originally wanted to compute embeddings.

    Hence, an item returned by :class:`NeighborSampler` holds the current
    :obj:`batch_size`, the IDs :obj:`n_id` of all nodes involved in the
    computation, and a list of bipartite graph objects via the tuple
    :obj:`(edge_index, e_id, size)`, where :obj:`edge_index` represents the
    bipartite edges between source and target nodes, :obj:`e_id` denotes the
    IDs of original edges in the full graph, and :obj:`size` holds the shape
    of the bipartite graph.
    For each bipartite graph, target nodes are also included at the beginning
    of the list of source nodes so that one can easily apply skip-connections
    or add self-loops.

    .. warning::

        :class:`~torch_geometric.loader.NeighborSampler` is deprecated and will
        be removed in a future release.
        Use :class:`torch_geometric.loader.NeighborLoader` instead.

    .. note::

        For an example of using :obj:`NeighborSampler`, see
        `examples/reddit.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        reddit.py>`_ or
        `examples/ogbn_products_sage.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        ogbn_products_sage.py>`_.

    Args:
        edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
            :class:`torch_sparse.SparseTensor` that defines the underlying
            graph connectivity/message passing flow.
            :obj:`edge_index` holds the indices of a (sparse) symmetric
            adjacency matrix.
            If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its shape
            must be defined as :obj:`[2, num_edges]`, where messages from nodes
            :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`
            (in case :obj:`flow="source_to_target"`).
            If :obj:`edge_index` is of type :class:`torch_sparse.SparseTensor`,
            its sparse indices :obj:`(row, col)` should relate to
            :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.
            The major difference between both formats is that we need to input
            the *transposed* sparse adjacency matrix.
        sizes ([int]): The number of neighbors to sample for each node in each
            layer. If set to :obj:`sizes[l] = -1`, all neighbors are included
            in layer :obj:`l`.
        node_idx (LongTensor, optional): The nodes that should be considered
            for creating mini-batches. If set to :obj:`None`, all nodes will be
            considered.
        num_nodes (int, optional): The number of nodes in the graph.
            (default: :obj:`None`)
        return_e_id (bool, optional): If set to :obj:`False`, will not return
            original edge indices of sampled edges. This is only useful in case
            when operating on graphs without edge features to save memory.
            (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(self, edge_index: Tensor,
                 sizes: List[int], node_idx: Optional[Tensor] = None,
                 num_nodes: Optional[int] = None, return_e_id: bool = True,
                 transform: Callable = None, **kwargs):

        edge_index = edge_index.to('cpu')

        # Remove for PyTorch Lightning:
        kwargs.pop('dataset', None)
        kwargs.pop('collate_fn', None)

        # Save for Pytorch Lightning < 1.6:
        self.edge_index = edge_index
        self.node_idx = node_idx
        self.num_nodes = num_nodes

        self.sizes = sizes
        self.return_e_id = return_e_id
        self.transform = transform
        # self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
        self.__val__ = None

        # Obtain a *transposed* `SparseTensor` instance.        
        if (num_nodes is None and node_idx is not None
                and node_idx.dtype == torch.bool):
            num_nodes = node_idx.size(0)
        if (num_nodes is None and node_idx is not None
                and node_idx.dtype == torch.long):
            num_nodes = max(int(edge_index.max()), int(node_idx.max())) + 1
        if num_nodes is None:
            num_nodes = int(edge_index.max()) + 1
        
        self.num_nodes = num_nodes
        self.value = torch.arange(edge_index.size(1)) if return_e_id else None
        self.row = edge_index[0]
        self.col = edge_index[1]
        self.sparse_sizes = (num_nodes, num_nodes)               
        
        self.adj_list = [[] for _ in range(num_nodes)]
        self.edge_ids = [[] for _ in range(num_nodes)]
        
        for e_id, (src, dst) in enumerate(edge_index.t()):
            self.adj_list[src.item()].append(dst.item())
            if return_e_id:
                self.edge_ids[src.item()].append(e_id)
        
        if node_idx is None:
            node_idx = torch.arange(num_nodes)
        elif node_idx.dtype == torch.bool:
            node_idx = node_idx.nonzero(as_tuple=False).view(-1)

        super().__init__(
            node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs)

    def sample_neighbors(self, nodes: Tensor, size: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Sample neighbors for given nodes"""
        rows, cols, e_ids = [], [], []
        
        for node in nodes.tolist():
            neighbors = self.adj_list[node]
            if len(neighbors) > 0:
                if size > 0:
                    num_samples = min(size, len(neighbors))
                    sampled_idx = random.sample(range(len(neighbors)), num_samples)
                    sampled_neighbors = [neighbors[i] for i in sampled_idx]
                    if self.return_e_id:
                        sampled_e_ids = [self.edge_ids[node][i] for i in sampled_idx]
                else:
                    sampled_neighbors = neighbors
                    sampled_e_ids = self.edge_ids[node] if self.return_e_id else []
                
                rows.extend([node] * len(sampled_neighbors))
                cols.extend(sampled_neighbors)
                if self.return_e_id:
                    e_ids.extend(sampled_e_ids)
        
        edge_index = torch.tensor([cols, rows], dtype=torch.long)
        e_id = torch.tensor(e_ids) if self.return_e_id else None
        return edge_index, e_id


    def sample(self, batch):
        if not isinstance(batch, Tensor):
            batch = torch.tensor(batch)

        batch_size: int = len(batch)

        adjs = []
        n_id = batch
        for size in self.sizes:
            # adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
            edge_index, e_id = self.sample_neighbors(n_id, size)
            
            # Get new nodes involved
            new_nodes = torch.unique(edge_index[0])
            n_id = torch.unique(torch.cat([n_id, new_nodes]))
            
            size = (n_id.size(0), len(batch))
            adjs.append(EdgeIndex(edge_index, e_id, size))

        adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]
        out = (batch_size, n_id, adjs)
        out = self.transform(*out) if self.transform is not None else out
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(sizes={self.sizes})'

In [10]:
torch.ops.load_library("/root/raw_torch_for_scatter/neighbor_sample/csrc/build/libneighbor_sample.so")

In [9]:
# torch_geometric/sampler/neighbor_sampler.py

import copy
import math
import sys
import warnings
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import (
    Data,
    FeatureStore,
    GraphStore,
    HeteroData,
    remote_backend_utils,
)
from torch_geometric.data.graph_store import EdgeLayout
from torch_geometric.sampler import (
    BaseSampler,
    EdgeSamplerInput,
    HeteroSamplerOutput,
    NegativeSampling,
    NodeSamplerInput,
    SamplerOutput,
)
from torch_geometric.sampler.base import DataType, NumNeighbors, SubgraphType
from torch_geometric.sampler.utils import remap_keys, to_csc, to_hetero_csc
from torch_geometric.typing import EdgeType, NodeType, OptTensor

NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]


class NeighborSampler(BaseSampler):
    r"""An implementation of an in-memory (heterogeneous) neighbor sampler used
    by :class:`~torch_geometric.loader.NeighborLoader`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        num_neighbors: NumNeighborsType,
        subgraph_type: Union[SubgraphType, str] = 'directional',
        replace: bool = False,
        disjoint: bool = False,
        temporal_strategy: str = 'uniform',
        time_attr: Optional[str] = None,
        weight_attr: Optional[str] = None,
        is_sorted: bool = False,
        share_memory: bool = False,
        # Deprecated:
        directed: bool = True,
    ):

        self.data_type = DataType.from_data(data)

        if self.data_type == DataType.homogeneous:
            self.num_nodes = data.num_nodes

            self.node_time: Optional[Tensor] = None
            self.edge_time: Optional[Tensor] = None

            if time_attr is not None:
                if data.is_node_attr(time_attr):
                    self.node_time = data[time_attr]
                elif data.is_edge_attr(time_attr):
                    self.edge_time = data[time_attr]
                else:
                    raise ValueError(
                        f"The time attribute '{time_attr}' is neither a "
                        f"node-level or edge-level attribute")

            # Convert the graph data into CSC format for sampling:
            self.colptr, self.row, self.perm = to_csc(
                data, device='cpu', share_memory=share_memory,
                is_sorted=is_sorted, src_node_time=self.node_time,
                edge_time=self.edge_time)

            if self.edge_time is not None and self.perm is not None:
                self.edge_time = self.edge_time[self.perm]

            self.edge_weight: Optional[Tensor] = None
            if weight_attr is not None:
                self.edge_weight = data[weight_attr]
                if self.perm is not None:
                    self.edge_weight = self.edge_weight[self.perm]

        elif self.data_type == DataType.heterogeneous:
            self.node_types, self.edge_types = data.metadata()

            self.num_nodes = {k: data[k].num_nodes for k in self.node_types}

            self.node_time: Optional[Dict[NodeType, Tensor]] = None
            self.edge_time: Optional[Dict[EdgeType, Tensor]] = None

            if time_attr is not None:
                is_node_level_time = is_edge_level_time = False

                for store in data.node_stores:
                    if time_attr in store:
                        is_node_level_time = True
                for store in data.edge_stores:
                    if time_attr in store:
                        is_edge_level_time = True

                if is_node_level_time and is_edge_level_time:
                    raise ValueError(
                        f"The time attribute '{time_attr}' holds both "
                        f"node-level and edge-level information")

                if not is_node_level_time and not is_edge_level_time:
                    raise ValueError(
                        f"The time attribute '{time_attr}' is neither a "
                        f"node-level or edge-level attribute")

                if is_node_level_time:
                    self.node_time = data.collect(time_attr)
                else:
                    self.edge_time = data.collect(time_attr)

            # Conversion to/from C++ string type: Since C++ cannot take
            # dictionaries with tuples as key as input, edge type triplets need
            # to be converted into single strings.
            self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
            self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}

            # Convert the graph data into CSC format for sampling:
            colptr_dict, row_dict, self.perm = to_hetero_csc(
                data, device='cpu', share_memory=share_memory,
                is_sorted=is_sorted, node_time_dict=self.node_time,
                edge_time_dict=self.edge_time)

            self.row_dict = remap_keys(row_dict, self.to_rel_type)
            self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

            if self.edge_time is not None:
                for edge_type, edge_time in self.edge_time.items():
                    if self.perm.get(edge_type, None) is not None:
                        edge_time = edge_time[self.perm[edge_type]]
                        self.edge_time[edge_type] = edge_time
                self.edge_time = remap_keys(self.edge_time, self.to_rel_type)

            self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None
            if weight_attr is not None:
                self.edge_weight = data.collect(weight_attr)
                for edge_type, edge_weight in self.edge_weight.items():
                    if self.perm.get(edge_type, None) is not None:
                        edge_weight = edge_weight[self.perm[edge_type]]
                        self.edge_weight[edge_type] = edge_weight
                self.edge_weight = remap_keys(self.edge_weight,
                                              self.to_rel_type)

        else:  # self.data_type == DataType.remote
            feature_store, graph_store = data

            # Obtain graph metadata:
            attrs = [attr for attr in feature_store.get_all_tensor_attrs()]

            edge_attrs = graph_store.get_all_edge_attrs()
            self.edge_types = list({attr.edge_type for attr in edge_attrs})

            if weight_attr is not None:
                raise NotImplementedError(
                    f"'weight_attr' argument not yet supported within "
                    f"'{self.__class__.__name__}' for "
                    f"'(FeatureStore, GraphStore)' inputs")

            if time_attr is not None:
                # If the `time_attr` is present, we expect that `GraphStore`
                # holds all edges sorted by destination, and within local
                # neighborhoods, node indices should be sorted by time.
                # TODO (matthias, manan) Find an alternative way to ensure.
                for edge_attr in edge_attrs:
                    if edge_attr.layout == EdgeLayout.CSR:
                        raise ValueError(
                            "Temporal sampling requires that edges are stored "
                            "in either COO or CSC layout")
                    if not edge_attr.is_sorted:
                        raise ValueError(
                            "Temporal sampling requires that edges are "
                            "sorted by destination, and by source time "
                            "within local neighborhoods")

                # We obtain all features with `node_attr.name=time_attr`:
                time_attrs = [
                    copy.copy(attr) for attr in attrs
                    if attr.attr_name == time_attr
                ]

            if not self.is_hetero:
                self.node_types = [None]
                self.num_nodes = max(edge_attrs[0].size)
                self.edge_weight: Optional[Tensor] = None

                self.node_time: Optional[Tensor] = None
                self.edge_time: Optional[Tensor] = None

                if time_attr is not None:
                    if len(time_attrs) != 1:
                        raise ValueError("Temporal sampling specified but did "
                                         "not find any temporal data")
                    time_attrs[0].index = None  # Reset index for full data.
                    time_tensor = feature_store.get_tensor(time_attrs[0])
                    # Currently, we determine whether to use node-level or
                    # edge-level temporal sampling based on the attribute name.
                    if time_attr == 'time':
                        self.node_time = time_tensor
                    else:
                        self.edge_time = time_tensor

                self.row, self.colptr, self.perm = graph_store.csc()

            else:
                node_types = [
                    attr.group_name for attr in attrs
                    if isinstance(attr.group_name, str)
                ]
                self.node_types = list(set(node_types))
                self.num_nodes = {
                    node_type: remote_backend_utils.size(*data, node_type)
                    for node_type in self.node_types
                }
                self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None

                self.node_time: Optional[Dict[NodeType, Tensor]] = None
                self.edge_time: Optional[Dict[EdgeType, Tensor]] = None

                if time_attr is not None:
                    for attr in time_attrs:  # Reset index for full data.
                        attr.index = None

                    time_tensors = feature_store.multi_get_tensor(time_attrs)
                    time = {
                        attr.group_name: time_tensor
                        for attr, time_tensor in zip(time_attrs, time_tensors)
                    }

                    group_names = [attr.group_name for attr in time_attrs]
                    if all([isinstance(g, str) for g in group_names]):
                        self.node_time = time
                    elif all([isinstance(g, tuple) for g in group_names]):
                        self.edge_time = time
                    else:
                        raise ValueError(
                            f"Found time attribute '{time_attr}' for both "
                            f"node-level and edge-level types")

                # Conversion to/from C++ string type (see above):
                self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
                self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}
                # Convert the graph data into CSC format for sampling:
                row_dict, colptr_dict, self.perm = graph_store.csc()
                self.row_dict = remap_keys(row_dict, self.to_rel_type)
                self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

        self.num_neighbors = num_neighbors
        self.replace = replace
        self.subgraph_type = SubgraphType(subgraph_type)
        self.disjoint = disjoint
        self.temporal_strategy = temporal_strategy

    @property
    def num_neighbors(self) -> NumNeighbors:
        return self._num_neighbors

    @num_neighbors.setter
    def num_neighbors(self, num_neighbors: NumNeighborsType):
        if isinstance(num_neighbors, NumNeighbors):
            self._num_neighbors = num_neighbors
        else:
            self._num_neighbors = NumNeighbors(num_neighbors)

    @property
    def is_hetero(self) -> bool:
        if self.data_type == DataType.homogeneous:
            return False
        if self.data_type == DataType.heterogeneous:
            return True

        # self.data_type == DataType.remote
        return self.edge_types != [None]

    @property
    def is_temporal(self) -> bool:
        return self.node_time is not None or self.edge_time is not None

    @property
    def disjoint(self) -> bool:
        return self._disjoint or self.is_temporal

    @disjoint.setter
    def disjoint(self, disjoint: bool):
        self._disjoint = disjoint

    # Node-based sampling #####################################################

    def sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        out = node_sample(inputs, self._sample)
        if self.subgraph_type == SubgraphType.bidirectional:
            out = out.to_bidirectional()
        return out

    # Edge-based sampling #####################################################

    def sample_from_edges(
        self,
        inputs: EdgeSamplerInput,
        neg_sampling: Optional[NegativeSampling] = None,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        out = edge_sample(inputs, self._sample, self.num_nodes, self.disjoint,
                          self.node_time, neg_sampling)
        if self.subgraph_type == SubgraphType.bidirectional:
            out = out.to_bidirectional()
        return out

    # Other Utilities #########################################################

    @property
    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
        return self.perm

    # Helper functions ########################################################

    def _sample(
        self,
        seed: Union[Tensor, Dict[NodeType, Tensor]],
        seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,
        **kwargs,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        r"""Implements neighbor sampling by calling either :obj:`pyg-lib` (if
        installed) or :obj:`torch-sparse` (if installed) sampling routines.
        """
        if isinstance(seed, dict):  # Heterogeneous sampling:
            if self.disjoint:
                if self.subgraph_type == SubgraphType.induced:
                    raise ValueError("'disjoint' sampling not supported "
                                        "for neighbor sampling with "
                                        "`subgraph_type='induced'`")
                else:
                    raise ValueError("'disjoint' sampling not supported "
                                        "for neighbor sampling via "
                                        "'torch-sparse'. Please install "
                                        "'pyg-lib' for improved and "
                                        "optimized sampling routines.")

            out = torch.ops.torch_sparse.hetero_neighbor_sample(
                self.node_types,
                self.edge_types,
                self.colptr_dict,
                self.row_dict,
                seed,  # seed_dict
                self.num_neighbors.get_mapped_values(self.edge_types),
                self.num_neighbors.num_hops,
                self.replace,
                self.subgraph_type != SubgraphType.induced,
            )
            node, row, col, edge, batch = out + (None, )
            num_sampled_nodes = num_sampled_edges = None

            if num_sampled_edges is not None:
                num_sampled_edges = remap_keys(
                    num_sampled_edges,
                    self.to_edge_type,
                )

            return HeteroSamplerOutput(
                node=node,
                row=remap_keys(row, self.to_edge_type),
                col=remap_keys(col, self.to_edge_type),
                edge=remap_keys(edge, self.to_edge_type),
                batch=batch,
                num_sampled_nodes=num_sampled_nodes,
                num_sampled_edges=num_sampled_edges,
            )

        else:  # Homogeneous sampling:
            if self.disjoint:
                raise ValueError("'disjoint' sampling not supported for "
                                    "neighbor sampling via 'torch-sparse'. "
                                    "Please install 'pyg-lib' for improved "
                                    "and optimized sampling routines.")

            out = torch.ops.torch_sparse.neighbor_sample(
                self.colptr,
                self.row,
                seed,  # seed
                self.num_neighbors.get_mapped_values(),
                self.replace,
                self.subgraph_type != SubgraphType.induced,
            )
            node, row, col, edge, batch = out + (None, )
            num_sampled_nodes = num_sampled_edges = None

            return SamplerOutput(
                node=node,
                row=row,
                col=col,
                edge=edge,
                batch=batch,
                num_sampled_nodes=num_sampled_nodes,
                num_sampled_edges=num_sampled_edges,
            )


# Sampling Utilities ##########################################################


def node_sample(
    inputs: NodeSamplerInput,
    sample_fn: Callable,
) -> Union[SamplerOutput, HeteroSamplerOutput]:
    r"""Performs sampling from a :class:`NodeSamplerInput`, leveraging a
    sampling function that accepts a seed and (optionally) a seed time as
    input. Returns the output of this sampling procedure.
    """
    if inputs.input_type is not None:  # Heterogeneous sampling:
        seed = {inputs.input_type: inputs.node}
        seed_time = None
        if inputs.time is not None:
            seed_time = {inputs.input_type: inputs.time}
    else:  # Homogeneous sampling:
        seed = inputs.node
        seed_time = inputs.time

    out = sample_fn(seed, seed_time)
    out.metadata = (inputs.input_id, inputs.time)

    return out


def edge_sample(
    inputs: EdgeSamplerInput,
    sample_fn: Callable,
    num_nodes: Union[int, Dict[NodeType, int]],
    disjoint: bool,
    node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None,
    neg_sampling: Optional[NegativeSampling] = None,
) -> Union[SamplerOutput, HeteroSamplerOutput]:
    r"""Performs sampling from an edge sampler input, leveraging a sampling
    function of the same signature as `node_sample`.
    """
    input_id = inputs.input_id
    src = inputs.row
    dst = inputs.col
    edge_label = inputs.label
    edge_label_time = inputs.time
    input_type = inputs.input_type

    src_time = dst_time = edge_label_time
    assert edge_label_time is None or disjoint

    assert isinstance(num_nodes, (dict, int))
    if not isinstance(num_nodes, dict):
        num_src_nodes = num_dst_nodes = num_nodes
    else:
        num_src_nodes = num_nodes[input_type[0]]
        num_dst_nodes = num_nodes[input_type[-1]]

    num_pos = src.numel()
    num_neg = 0

    # Negative Sampling #######################################################

    if neg_sampling is not None:
        # When we are doing negative sampling, we append negative information
        # of nodes/edges to `src`, `dst`, `src_time`, `dst_time`.
        # Later on, we can easily reconstruct what belongs to positive and
        # negative examples by slicing via `num_pos`.
        num_neg = math.ceil(num_pos * neg_sampling.amount)

        if neg_sampling.is_binary():
            # In the "binary" case, we randomly sample negative pairs of nodes.
            if isinstance(node_time, dict):
                src_node_time = node_time.get(input_type[0])
            else:
                src_node_time = node_time

            src_neg = neg_sample(src, neg_sampling, num_src_nodes, src_time,
                                 src_node_time, endpoint='src')
            src = torch.cat([src, src_neg], dim=0)

            if isinstance(node_time, dict):
                dst_node_time = node_time.get(input_type[-1])
            else:
                dst_node_time = node_time

            dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,
                                 dst_node_time, endpoint='dst')
            dst = torch.cat([dst, dst_neg], dim=0)

            if edge_label is None:
                edge_label = torch.ones(num_pos)
            size = (num_neg, ) + edge_label.size()[1:]
            edge_neg_label = edge_label.new_zeros(size)
            edge_label = torch.cat([edge_label, edge_neg_label])

            if edge_label_time is not None:
                src_time = dst_time = edge_label_time.repeat(
                    1 + math.ceil(neg_sampling.amount))[:num_pos + num_neg]

        elif neg_sampling.is_triplet():
            # In the "triplet" case, we randomly sample negative destinations.
            if isinstance(node_time, dict):
                dst_node_time = node_time.get(input_type[-1])
            else:
                dst_node_time = node_time

            dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,
                                 dst_node_time, endpoint='dst')
            dst = torch.cat([dst, dst_neg], dim=0)

            assert edge_label is None

            if edge_label_time is not None:
                dst_time = edge_label_time.repeat(1 + neg_sampling.amount)

    # Heterogeneous Neighborhood Sampling #####################################

    if input_type is not None:
        seed_time_dict = None
        if input_type[0] != input_type[-1]:  # Two distinct node types:

            if not disjoint:
                src, inverse_src = src.unique(return_inverse=True)
                dst, inverse_dst = dst.unique(return_inverse=True)

            seed_dict = {input_type[0]: src, input_type[-1]: dst}

            if edge_label_time is not None:  # Always disjoint.
                seed_time_dict = {
                    input_type[0]: src_time,
                    input_type[-1]: dst_time,
                }

        else:  # Only a single node type: Merge both source and destination.

            seed = torch.cat([src, dst], dim=0)

            if not disjoint:
                seed, inverse_seed = seed.unique(return_inverse=True)

            seed_dict = {input_type[0]: seed}

            if edge_label_time is not None:  # Always disjoint.
                seed_time_dict = {
                    input_type[0]: torch.cat([src_time, dst_time], dim=0),
                }

        out = sample_fn(seed_dict, seed_time_dict)

        # Enhance `out` by label information ##################################
        if disjoint:
            for key, batch in out.batch.items():
                out.batch[key] = batch % num_pos

        if neg_sampling is None or neg_sampling.is_binary():
            if disjoint:
                if input_type[0] != input_type[-1]:
                    edge_label_index = torch.arange(num_pos + num_neg)
                    edge_label_index = edge_label_index.repeat(2).view(2, -1)
                else:
                    edge_label_index = torch.arange(2 * (num_pos + num_neg))
                    edge_label_index = edge_label_index.view(2, -1)
            else:
                if input_type[0] != input_type[-1]:
                    edge_label_index = torch.stack([
                        inverse_src,
                        inverse_dst,
                    ], dim=0)
                else:
                    edge_label_index = inverse_seed.view(2, -1)

            out.metadata = (input_id, edge_label_index, edge_label, src_time)

        elif neg_sampling.is_triplet():
            if disjoint:
                src_index = torch.arange(num_pos)
                if input_type[0] != input_type[-1]:
                    dst_pos_index = torch.arange(num_pos)
                    # `dst_neg_index` needs to be offset such that indices with
                    # offset `num_pos` belong to the same triplet:
                    dst_neg_index = torch.arange(
                        num_pos, seed_dict[input_type[-1]].numel())
                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()
                else:
                    dst_pos_index = torch.arange(num_pos, 2 * num_pos)
                    dst_neg_index = torch.arange(
                        2 * num_pos, seed_dict[input_type[-1]].numel())
                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()
            else:
                if input_type[0] != input_type[-1]:
                    src_index = inverse_src
                    dst_pos_index = inverse_dst[:num_pos]
                    dst_neg_index = inverse_dst[num_pos:]
                else:
                    src_index = inverse_seed[:num_pos]
                    dst_pos_index = inverse_seed[num_pos:2 * num_pos]
                    dst_neg_index = inverse_seed[2 * num_pos:]

            dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)

            out.metadata = (
                input_id,
                src_index,
                dst_pos_index,
                dst_neg_index,
                src_time,
            )

    # Homogeneous Neighborhood Sampling #######################################

    else:

        seed = torch.cat([src, dst], dim=0)
        seed_time = None

        if not disjoint:
            seed, inverse_seed = seed.unique(return_inverse=True)

        if edge_label_time is not None:  # Always disjoint.
            seed_time = torch.cat([src_time, dst_time])

        out = sample_fn(seed, seed_time)

        # Enhance `out` by label information ##################################
        if neg_sampling is None or neg_sampling.is_binary():
            if disjoint:
                out.batch = out.batch % num_pos
                edge_label_index = torch.arange(seed.numel()).view(2, -1)
            else:
                edge_label_index = inverse_seed.view(2, -1)

            out.metadata = (input_id, edge_label_index, edge_label, src_time)

        elif neg_sampling.is_triplet():
            if disjoint:
                out.batch = out.batch % num_pos
                src_index = torch.arange(num_pos)
                dst_pos_index = torch.arange(num_pos, 2 * num_pos)
                # `dst_neg_index` needs to be offset such that indices with
                # offset `num_pos` belong to the same triplet:
                dst_neg_index = torch.arange(2 * num_pos, seed.numel())
                dst_neg_index = dst_neg_index.view(-1, num_pos).t()
            else:
                src_index = inverse_seed[:num_pos]
                dst_pos_index = inverse_seed[num_pos:2 * num_pos]
                dst_neg_index = inverse_seed[2 * num_pos:]
            dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)

            out.metadata = (
                input_id,
                src_index,
                dst_pos_index,
                dst_neg_index,
                src_time,
            )

    return out


def neg_sample(
    seed: Tensor,
    neg_sampling: NegativeSampling,
    num_nodes: int,
    seed_time: Optional[Tensor],
    node_time: Optional[Tensor],
    endpoint: Literal['str', 'dst'],
) -> Tensor:
    num_neg = math.ceil(seed.numel() * neg_sampling.amount)

    # TODO: Do not sample false negatives.
    if node_time is None:
        return neg_sampling.sample(num_neg, endpoint, num_nodes)

    # If we are in a temporal-sampling scenario, we need to respect the
    # timestamp of the given nodes we can use as negative examples.
    # That is, we can only sample nodes for which `node_time <= seed_time`.
    # For now, we use a greedy algorithm which randomly samples negative
    # nodes and discard any which do not respect the temporal constraint.
    # We iteratively repeat this process until we have sampled a valid node for
    # each seed.
    # TODO See if this greedy algorithm here can be improved.
    assert seed_time is not None
    num_samples = math.ceil(neg_sampling.amount)
    seed_time = seed_time.view(1, -1).expand(num_samples, -1)

    out = neg_sampling.sample(num_samples * seed.numel(), endpoint, num_nodes)
    out = out.view(num_samples, seed.numel())
    mask = node_time[out] > seed_time  # holds all invalid samples.
    neg_sampling_complete = False
    for i in range(5):  # pragma: no cover
        num_invalid = int(mask.sum())
        if num_invalid == 0:
            neg_sampling_complete = True
            break

        # Greedily search for alternative negatives.
        out[mask] = tmp = neg_sampling.sample(num_invalid, endpoint, num_nodes)
        mask[mask.clone()] = node_time[tmp] >= seed_time[mask]

    if not neg_sampling_complete:  # pragma: no cover
        # Not much options left. In that case, we set remaining negatives
        # to the node with minimum timestamp.
        out[mask] = node_time.argmin()

    return out.view(-1)[:num_neg]


In [11]:
from typing import Callable, Dict, List, Optional, Tuple, Union

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.link_loader import LinkLoader
from torch_geometric.sampler import NegativeSampling
# , NeighborSampler 
from torch_geometric.sampler.base import SubgraphType
from torch_geometric.typing import EdgeType, InputEdges, OptTensor


class LinkNeighborLoader(LinkLoader):
    r"""A link-based data loader derived as an extension of the node-based
    :class:`torch_geometric.loader.NeighborLoader`.
    This loader allows for mini-batch training of GNNs on large-scale graphs
    where full-batch training is not feasible.

    More specifically, this loader first selects a sample of edges from the
    set of input edges :obj:`edge_label_index` (which may or not be edges in
    the original graph) and then constructs a subgraph from all the nodes
    present in this list by sampling :obj:`num_neighbors` neighbors in each
    iteration.

    .. code-block:: python

        from torch_geometric.datasets import Planetoid
        from torch_geometric.loader import LinkNeighborLoader

        data = Planetoid(path, name='Cora')[0]

        loader = LinkNeighborLoader(
            data,
            # Sample 30 neighbors for each node for 2 iterations
            num_neighbors=[30] * 2,
            # Use a batch size of 128 for sampling training nodes
            batch_size=128,
            edge_label_index=data.edge_index,
        )

        sampled_data = next(iter(loader))
        print(sampled_data)
        >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368],
                 train_mask=[1368], val_mask=[1368], test_mask=[1368],
                 edge_label_index=[2, 128])

    It is additionally possible to provide edge labels for sampled edges, which
    are then added to the batch:

    .. code-block:: python

        loader = LinkNeighborLoader(
            data,
            num_neighbors=[30] * 2,
            batch_size=128,
            edge_label_index=data.edge_index,
            edge_label=torch.ones(data.edge_index.size(1))
        )

        sampled_data = next(iter(loader))
        print(sampled_data)
        >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368],
                 train_mask=[1368], val_mask=[1368], test_mask=[1368],
                 edge_label_index=[2, 128], edge_label=[128])

    The rest of the functionality mirrors that of
    :class:`~torch_geometric.loader.NeighborLoader`, including support for
    heterogeneous graphs.
    In particular, the data loader will add the following attributes to the
    returned mini-batch:

    * :obj:`n_id` The global node index for every sampled node
    * :obj:`e_id` The global edge index for every sampled edge
    * :obj:`input_id`: The global index of the :obj:`edge_label_index`
    * :obj:`num_sampled_nodes`: The number of sampled nodes in each hop
    * :obj:`num_sampled_edges`: The number of sampled edges in each hop

    .. note::
        Negative sampling is currently implemented in an approximate
        way, *i.e.* negative edges may contain false negatives.

    .. warning::
        Note that the sampling scheme is independent from the edge we are
        making a prediction for.
        That is, by default supervision edges in :obj:`edge_label_index`
        **will not** get masked out during sampling.
        In case there exists an overlap between message passing edges in
        :obj:`data.edge_index` and supervision edges in
        :obj:`edge_label_index`, you might end up sampling an edge you are
        making a prediction for.
        You can generally avoid this behavior (if desired) by making
        :obj:`data.edge_index` and :obj:`edge_label_index` two disjoint sets of
        edges, *e.g.*, via the
        :class:`~torch_geometric.transforms.RandomLinkSplit` transformation and
        its :obj:`disjoint_train_ratio` argument.

    Args:
        data (Any): A :class:`~torch_geometric.data.Data`,
            :class:`~torch_geometric.data.HeteroData`, or
            (:class:`~torch_geometric.data.FeatureStore`,
            :class:`~torch_geometric.data.GraphStore`) data object.
        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The
            number of neighbors to sample for each node in each iteration.
            If an entry is set to :obj:`-1`, all neighbors will be included.
            In heterogeneous graphs, may also take in a dictionary denoting
            the amount of neighbors to sample for each individual edge type.
        edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The edge indices for which neighbors are sampled to create
            mini-batches.
            If set to :obj:`None`, all edges will be considered.
            In heterogeneous graphs, needs to be passed as a tuple that holds
            the edge type and corresponding edge indices.
            (default: :obj:`None`)
        edge_label (Tensor, optional): The labels of edge indices for
            which neighbors are sampled. Must be the same length as
            the :obj:`edge_label_index`. If set to :obj:`None` its set to
            `torch.zeros(...)` internally. (default: :obj:`None`)
        edge_label_time (Tensor, optional): The timestamps for edge indices
            for which neighbors are sampled. Must be the same length as
            :obj:`edge_label_index`. If set, temporal sampling will be
            used such that neighbors are guaranteed to fulfill temporal
            constraints, *i.e.*, neighbors have an earlier timestamp than
            the ouput edge. The :obj:`time_attr` needs to be set for this
            to work. (default: :obj:`None`)
        replace (bool, optional): If set to :obj:`True`, will sample with
            replacement. (default: :obj:`False`)
        subgraph_type (SubgraphType or str, optional): The type of the returned
            subgraph.
            If set to :obj:`"directional"`, the returned subgraph only holds
            the sampled (directed) edges which are necessary to compute
            representations for the sampled seed nodes.
            If set to :obj:`"bidirectional"`, sampled edges are converted to
            bidirectional edges.
            If set to :obj:`"induced"`, the returned subgraph contains the
            induced subgraph of all sampled nodes.
            (default: :obj:`"directional"`)
        disjoint (bool, optional): If set to :obj: `True`, each seed node will
            create its own disjoint subgraph.
            If set to :obj:`True`, mini-batch outputs will have a :obj:`batch`
            vector holding the mapping of nodes to their respective subgraph.
            Will get automatically set to :obj:`True` in case of temporal
            sampling. (default: :obj:`False`)
        temporal_strategy (str, optional): The sampling strategy when using
            temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
            If set to :obj:`"uniform"`, will sample uniformly across neighbors
            that fulfill temporal constraints.
            If set to :obj:`"last"`, will sample the last `num_neighbors` that
            fulfill temporal constraints.
            (default: :obj:`"uniform"`)
        neg_sampling (NegativeSampling, optional): The negative sampling
            configuration.
            For negative sampling mode :obj:`"binary"`, samples can be accessed
            via the attributes :obj:`edge_label_index` and :obj:`edge_label` in
            the respective edge type of the returned mini-batch.
            In case :obj:`edge_label` does not exist, it will be automatically
            created and represents a binary classification task (:obj:`0` =
            negative edge, :obj:`1` = positive edge).
            In case :obj:`edge_label` does exist, it has to be a categorical
            label from :obj:`0` to :obj:`num_classes - 1`.
            After negative sampling, label :obj:`0` represents negative edges,
            and labels :obj:`1` to :obj:`num_classes` represent the labels of
            positive edges.
            Note that returned labels are of type :obj:`torch.float` for binary
            classification (to facilitate the ease-of-use of
            :meth:`F.binary_cross_entropy`) and of type
            :obj:`torch.long` for multi-class classification (to facilitate the
            ease-of-use of :meth:`F.cross_entropy`).
            For negative sampling mode :obj:`"triplet"`, samples can be
            accessed via the attributes :obj:`src_index`, :obj:`dst_pos_index`
            and :obj:`dst_neg_index` in the respective node types of the
            returned mini-batch.
            :obj:`edge_label` needs to be :obj:`None` for :obj:`"triplet"`
            negative sampling mode.
            If set to :obj:`None`, no negative sampling strategy is applied.
            (default: :obj:`None`)
        neg_sampling_ratio (int or float, optional): The ratio of sampled
            negative edges to the number of positive edges.
            Deprecated in favor of the :obj:`neg_sampling` argument.
            (default: :obj:`None`)
        time_attr (str, optional): The name of the attribute that denotes
            timestamps for either the nodes or edges in the graph.
            If set, temporal sampling will be used such that neighbors are
            guaranteed to fulfill temporal constraints, *i.e.* neighbors have
            an earlier or equal timestamp than the center node.
            Only used if :obj:`edge_label_time` is set. (default: :obj:`None`)
        weight_attr (str, optional): The name of the attribute that denotes
            edge weights in the graph.
            If set, weighted/biased sampling will be used such that neighbors
            are more likely to get sampled the higher their edge weights are.
            Edge weights do not need to sum to one, but must be non-negative,
            finite and have a non-zero sum within local neighborhoods.
            (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        transform_sampler_output (callable, optional): A function/transform
            that takes in a :class:`torch_geometric.sampler.SamplerOutput` and
            returns a transformed version. (default: :obj:`None`)
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`edge_index` is sorted by column.
            If :obj:`time_attr` is set, additionally requires that rows are
            sorted according to time within individual neighborhoods.
            This avoids internal re-sorting of the data and can improve
            runtime and memory efficiency. (default: :obj:`False`)
        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
            the returned data in each worker's subprocess.
            If set to :obj:`False`, will filter the returned data in the main
            process.
            If set to :obj:`None`, will automatically infer the decision based
            on whether data partially lives on the GPU
            (:obj:`filter_per_worker=True`) or entirely on the CPU
            (:obj:`filter_per_worker=False`).
            There exists different trade-offs for setting this option.
            Specifically, setting this option to :obj:`True` for in-memory
            datasets will move all features to shared memory, which may result
            in too many open file handles. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
        edge_label_index: InputEdges = None,
        edge_label: OptTensor = None,
        edge_label_time: OptTensor = None,
        replace: bool = False,
        subgraph_type: Union[SubgraphType, str] = 'directional',
        disjoint: bool = False,
        temporal_strategy: str = 'uniform',
        neg_sampling: Optional[NegativeSampling] = None,
        neg_sampling_ratio: Optional[Union[int, float]] = None,
        time_attr: Optional[str] = None,
        weight_attr: Optional[str] = None,
        transform: Optional[Callable] = None,
        transform_sampler_output: Optional[Callable] = None,
        is_sorted: bool = False,
        filter_per_worker: Optional[bool] = None,
        neighbor_sampler: Optional[NeighborSampler] = None,
        directed: bool = True,  # Deprecated.
        **kwargs,
    ):
        if (edge_label_time is not None) != (time_attr is not None):
            raise ValueError(
                f"Received conflicting 'edge_label_time' and 'time_attr' "
                f"arguments: 'edge_label_time' is "
                f"{'set' if edge_label_time is not None else 'not set'} "
                f"while 'time_attr' is "
                f"{'set' if time_attr is not None else 'not set'}. "
                f"Both arguments must be provided for temporal sampling.")

        if neighbor_sampler is None:
            neighbor_sampler = NeighborSampler(
                data,
                num_neighbors=num_neighbors,
                replace=replace,
                subgraph_type=subgraph_type,
                disjoint=disjoint,
                temporal_strategy=temporal_strategy,
                time_attr=time_attr,
                weight_attr=weight_attr,
                is_sorted=is_sorted,
                share_memory=kwargs.get('num_workers', 0) > 0,
                directed=directed,
            )

        super().__init__(
            data=data,
            link_sampler=neighbor_sampler,
            edge_label_index=edge_label_index,
            edge_label=edge_label,
            edge_label_time=edge_label_time,
            neg_sampling=neg_sampling,
            neg_sampling_ratio=neg_sampling_ratio,
            transform=transform,
            transform_sampler_output=transform_sampler_output,
            filter_per_worker=filter_per_worker,
            **kwargs,
        )


In [12]:
# In the first hop, we sample at most 20 neighbors.
# In the second hop, we sample at most 10 neighbors.
# In addition, during training, we want to sample negative edges on-the-fly with
# a ratio of 2:1.
# We can make use of the `loader.LinkNeighborLoader` from PyG:
# from torch_geometric.loader import LinkNeighborLoader

# Define seed edges:
edge_label_index = train_data["user", "rates", "movie"].edge_label_index
edge_label = train_data["user", "rates", "movie"].edge_label

train_loader = LinkNeighborLoader(
    data=train_data,  # Use the training data
    num_neighbors=[20, 10],  # Sample at most 20 neighbors in the first hop and 10 in the second hop
    neg_sampling_ratio=2.0,  # Sample negative edges on-the-fly with a ratio of 2:1
    edge_label_index=(("user", "rates", "movie"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

# Inspect a sample:
sampled_data = next(iter(train_loader))

print("Sampled mini-batch:")
print("===================")
print(sampled_data)

assert sampled_data["user", "rates", "movie"].edge_label_index.size(1) == 3 * 128
assert sampled_data["user", "rates", "movie"].edge_label.min() == 0
assert sampled_data["user", "rates", "movie"].edge_label.max() == 1

Sampled mini-batch:
HeteroData(
  user={
    node_id=[609],
    n_id=[609],
  },
  movie={
    node_id=[2765],
    x=[2765, 20],
    n_id=[2765],
  },
  (user, rates, movie)={
    edge_index=[2, 17239],
    edge_label=[384],
    edge_label_index=[2, 384],
    e_id=[17239],
    input_id=[128],
  },
  (movie, rev_rates, user)={
    edge_index=[2, 7779],
    e_id=[7779],
  }
)


## Creating a Heterogeneous Link-level GNN

We are now ready to create our heterogeneous GNN.
The GNN is responsible for learning enriched node representations from the surrounding subgraphs, which can be then used to derive edge-level predictions.
For defining our heterogenous GNN, we make use of [`nn.SAGEConv`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.SAGEConv) and the [`nn.to_hetero()`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.to_hetero_transformer.to_hetero) function, which transforms a GNN defined on homogeneous graphs to be applied on heterogeneous ones.

In addition, we define a final link-level classifier, which simply takes both node embeddings of the link we are trying to predict, and applies a dot-product on them.

As users do not have any node-level information, we choose to learn their features jointly via a `torch.nn.Embedding` layer. In order to improve the expressiveness of movie features, we do the same for movie nodes, and simply add their shallow embeddings to the pre-defined genre features.

In [14]:
from torch_geometric.nn import SAGEConv, to_hetero
import torch.nn.functional as F

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()

        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # Define a 2-layer GNN computation graph.
        # Use a *single* `ReLU` non-linearity in-between.        
        
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        
        return x
                
        

# Our final classifier applies the dot-product between source and destination
# node embeddings to derive edge-level predictions:
class Classifier(torch.nn.Module):
    def forward(self, x_user: Tensor, x_movie: Tensor, edge_label_index: Tensor) -> Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_user = x_user[edge_label_index[0]]
        edge_feat_movie = x_movie[edge_label_index[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (edge_feat_user * edge_feat_movie).sum(dim=-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and movies:
        self.movie_lin = torch.nn.Linear(20, hidden_channels)
        self.user_emb = torch.nn.Embedding(data["user"].num_nodes, hidden_channels)
        self.movie_emb = torch.nn.Embedding(data["movie"].num_nodes, hidden_channels)

        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)

        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())

        self.classifier = Classifier()

    def forward(self, data: HeteroData) -> Tensor:
        x_dict = {
          "user": self.user_emb(data["user"].node_id),
          "movie": self.movie_lin(data["movie"].x) + self.movie_emb(data["movie"].node_id),
        }

        # `x_dict` holds feature matrices of all node types
        # `edge_index_dict` holds all edge indices of all edge types
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        pred = self.classifier(
            x_dict["user"],
            x_dict["movie"],
            data["user", "rates", "movie"].edge_label_index,
        )

        return pred


model = Model(hidden_channels=64)

print(model)

Model(
  (movie_lin): Linear(in_features=20, out_features=64, bias=True)
  (user_emb): Embedding(610, 64)
  (movie_emb): Embedding(9742, 64)
  (gnn): GraphModule(
    (conv1): ModuleDict(
      (user__rates__movie): SAGEConv(64, 64, aggr=mean)
      (movie__rev_rates__user): SAGEConv(64, 64, aggr=mean)
    )
    (conv2): ModuleDict(
      (user__rates__movie): SAGEConv(64, 64, aggr=mean)
      (movie__rev_rates__user): SAGEConv(64, 64, aggr=mean)
    )
  )
  (classifier): Classifier()
)


## Training a Heterogeneous Link-level GNN

Training our GNN is then similar to training any PyTorch model.
We move the model to the desired device, and initialize an optimizer that takes care of adjusting model parameters via stochastic gradient descent.

The training loop then iterates over our mini-batches, applies the forward computation of the model, computes the loss from ground-truth labels and obtained predictions (here we make use of binary cross entropy), and adjusts model parameters via back-propagation and stochastic gradient descent.

In [None]:
import tqdm
import torch.nn.functional as F

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
print(f"Device: '{device}'")

model = Model(hidden_channels=64)

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1, 6):
    total_loss = total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):
        optimizer.zero_grad()

        # TODO: Move `sampled_data` to the respective `device`
        sampled_data = sampled_data.to(device)
        
        # TODO: Run `forward` pass of the model
        pred = model(sampled_data)
        
        # TODO: Apply binary cross entropy via        
        # `F.binary_cross_entropy_with_logits(pred, ground_truth)`
        
        loss = F.binary_cross_entropy_with_logits(pred, sampled_data["user", "rates", "movie"].edge_label)
                

        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

Device: 'cpu'


100%|██████████| 190/190 [00:16<00:00, 11.33it/s]


Epoch: 001, Loss: 0.4505


100%|██████████| 190/190 [00:16<00:00, 11.34it/s]


Epoch: 002, Loss: 0.3532


100%|██████████| 190/190 [00:16<00:00, 11.33it/s]


Epoch: 003, Loss: 0.3303


100%|██████████| 190/190 [00:16<00:00, 11.33it/s]


Epoch: 004, Loss: 0.3143


100%|██████████| 190/190 [00:16<00:00, 11.38it/s]

Epoch: 005, Loss: 0.2990





In [19]:
import tqdm
import torch.nn.functional as F

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
device = torch.device('hpu')
print(f"Device: '{device}'")

model = Model(hidden_channels=64)

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
# model = torch.compile(model,backend="hpu_backend")

for epoch in range(1, 6):
    total_loss = total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):
        optimizer.zero_grad()

        # TODO: Move `sampled_data` to the respective `device`
        sampled_data = sampled_data.to(device)
        
        # TODO: Run `forward` pass of the model
        pred = model(sampled_data)
        
        # TODO: Apply binary cross entropy via        
        # `F.binary_cross_entropy_with_logits(pred, ground_truth)`
        
        loss = F.binary_cross_entropy_with_logits(pred, sampled_data["user", "rates", "movie"].edge_label)

        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

Device: 'hpu'


  1%|          | 1/190 [00:00<00:22,  8.50it/s]

100%|██████████| 190/190 [00:19<00:00,  9.69it/s]


Epoch: 001, Loss: 0.4758


100%|██████████| 190/190 [00:20<00:00,  9.26it/s]


Epoch: 002, Loss: 0.3742


100%|██████████| 190/190 [00:19<00:00,  9.68it/s]


Epoch: 003, Loss: 0.3461


100%|██████████| 190/190 [00:19<00:00,  9.65it/s]


Epoch: 004, Loss: 0.3343


100%|██████████| 190/190 [00:20<00:00,  9.25it/s]

Epoch: 005, Loss: 0.3230





## Evaluating a Heterogeneous Link-level GNN

After training, we evaluate our model on useen data coming from the validation set.
For this, we define a new `LinkNeighborLoader` (which now iterates over the edges in the validation set), obtain the predictions on validation edges by running the model, and finally evaluate the performance of the model by computing the AUC score over the set of predictions and their corresponding ground-truth edges (including both positive and negative edges).

In [20]:
# Define the validation seed edges:
edge_label_index = val_data["user", "rates", "movie"].edge_label_index
edge_label = val_data["user", "rates", "movie"].edge_label

val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[20, 10],
    edge_label_index=(("user", "rates", "movie"), edge_label_index),
    edge_label=edge_label,
    batch_size=3 * 128,
    shuffle=False,
)

sampled_data = next(iter(val_loader))

print("Sampled mini-batch:")
print("===================")
print(sampled_data)

assert sampled_data["user", "rates", "movie"].edge_label_index.size(1) == 3 * 128
assert sampled_data["user", "rates", "movie"].edge_label.min() >= 0
assert sampled_data["user", "rates", "movie"].edge_label.max() <= 1

Sampled mini-batch:
HeteroData(
  user={
    node_id=[609],
    n_id=[609],
  },
  movie={
    node_id=[2652],
    x=[2652, 20],
    n_id=[2652],
  },
  (user, rates, movie)={
    edge_index=[2, 18960],
    edge_label=[384],
    edge_label_index=[2, 384],
    e_id=[18960],
    input_id=[384],
  },
  (movie, rev_rates, user)={
    edge_index=[2, 7814],
    e_id=[7814],
  }
)


In [21]:
from sklearn.metrics import roc_auc_score

preds = []
ground_truths = []
for sampled_data in tqdm.tqdm(val_loader):
    with torch.no_grad():
        # TODO: Collect predictions and ground-truths and write them into
        # `preds` and `ground_truths`.
        
        sampled_data = sampled_data.to(device)
        pred = model(sampled_data)
        preds.append(pred.cpu())
        ground_truths.append(sampled_data["user", "rates", "movie"].edge_label.cpu())

pred = torch.cat(preds, dim=0).cpu().numpy()
ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
auc = roc_auc_score(ground_truth, pred)
print()
print(f"Validation AUC: {auc:.4f}")

100%|██████████| 79/79 [00:07<00:00, 10.73it/s]


Validation AUC: 0.9233



