# GraphLoader demo

In [1]:
import sys

# Necessary to import from sibling directory: appending ".." makes the parent
# directory importable, so the `pymdb` package that sits next to this
# notebook's folder can be found. Note that a relative sys.path entry is
# resolved against the kernel's current working directory, not the notebook
# file's location.
sys.path.append("..")

# NOTE(review): TYPE_CHECKING is not referenced in any visible cell; it looks
# like a leftover and is a candidate for removal.
from typing import TYPE_CHECKING

from pymdb import (
    MDBClient,
    EvalGraphLoader,
    SamplingGraphLoader,
    Sampler,  # NOTE(review): not used in any visible cell — confirm before removing
)


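## SamplingGraphLoader (with seed node resampling)

Each pass over the loader samples `num_seeds` seed nodes and yields them in batches of `batch_size`. With 64 seeds and a batch size of 16, each pass yields 4 subgraphs. The two passes below iterate the same loader object. Because seeds are resampled on every pass, the second pass produces different subgraphs from the first.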

In [3]:
with MDBClient() as client:
    # Two neighbourhood hops with 5 neighbours each; empty feature props mean
    # no node/edge features are requested.
    tgl = SamplingGraphLoader(
        client=client,
        batch_size=16,
        num_neighbors=[5, 5],
        node_feature_prop="",
        edge_feature_prop="",
        num_seeds=64,
    )

    # Iterate the *same* loader twice to show that each pass yields a
    # different sample.
    for pass_label in ("First sample", "Second sample"):
        print(pass_label)
        for graph in tgl:
            print(graph)

First sample
Graph(num_seeds=16, node_ids=[84], edge_ids=[94], edge_index=[2, 94])
Graph(num_seeds=16, node_ids=[96], edge_ids=[111], edge_index=[2, 111])
Graph(num_seeds=16, node_ids=[89], edge_ids=[102], edge_index=[2, 102])
Graph(num_seeds=16, node_ids=[92], edge_ids=[109], edge_index=[2, 109])
Second sample
Graph(num_seeds=16, node_ids=[94], edge_ids=[111], edge_index=[2, 111])
Graph(num_seeds=16, node_ids=[106], edge_ids=[132], edge_index=[2, 132])
Graph(num_seeds=16, node_ids=[111], edge_ids=[137], edge_index=[2, 137])
Graph(num_seeds=16, node_ids=[86], edge_ids=[100], edge_index=[2, 100])


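## EvalGraphLoader (for sampling the entire graph)

Unlike `SamplingGraphLoader`, this loader is meant to sweep the entire graph in batches of `batch_size` seeds. To keep the output short, only the first `k` batches are printed.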

In [4]:
with MDBClient() as client:
    # Same neighbourhood settings as the sampling demo; no seed count is
    # given because this loader is meant to cover the whole graph.
    tgl = EvalGraphLoader(
        client=client,
        batch_size=16,
        num_neighbors=[5, 5],
        node_feature_prop="",
        edge_feature_prop="",
    )

    # Number of batches to preview.
    k = 10
    print(f"First {k} batches from the entire graph")
    # zip() stops as soon as range(k) is exhausted. range(k) comes first, so
    # exactly k batches are requested from the loader. The previous
    # enumerate() + break pattern pulled a (k+1)-th batch before breaking;
    # that batch was never printed and may cost an extra sampling request.
    for _, graph in zip(range(k), tgl):
        print(graph)
    

First 10 batches from the entire graph
Graph(num_seeds=16, node_ids=[55], edge_ids=[55], edge_index=[2, 55])
Graph(num_seeds=16, node_ids=[58], edge_ids=[58], edge_index=[2, 58])
Graph(num_seeds=16, node_ids=[53], edge_ids=[53], edge_index=[2, 53])
Graph(num_seeds=16, node_ids=[53], edge_ids=[53], edge_index=[2, 53])
Graph(num_seeds=16, node_ids=[56], edge_ids=[56], edge_index=[2, 56])
Graph(num_seeds=16, node_ids=[52], edge_ids=[52], edge_index=[2, 52])
Graph(num_seeds=16, node_ids=[56], edge_ids=[56], edge_index=[2, 56])
Graph(num_seeds=16, node_ids=[51], edge_ids=[51], edge_index=[2, 51])
Graph(num_seeds=16, node_ids=[53], edge_ids=[53], edge_index=[2, 53])
Graph(num_seeds=16, node_ids=[50], edge_ids=[50], edge_index=[2, 50])
