In [236]:
import dgl
import torch

In [265]:
"""Random walk routines
"""

from dgl import backend as F, ndarray as nd, utils
from dgl._ffi.function import _init_api
from dgl.base import DGLError

__all__ = ["random_walk", "pack_traces"]


def random_walk(
    g,
    nodes,
    *,
    metapath=None,
    length=None,
    prob=None,
    restart_prob=None,
    return_eids=False
):
    """Generate random walk traces from an array of starting nodes based on the given metapath.
    """
    n_etypes = len(g.canonical_etypes)
    n_ntypes = len(g.ntypes)

    if metapath is None:
        if n_etypes > 1 or n_ntypes > 1:
            raise DGLError(
                "metapath not specified and the graph is not homogeneous."
            )
        if length is None:
            raise ValueError(
                "Please specify either the metapath or the random walk length."
            )
        metapath = [0] * length
    else:
        metapath = [g.get_etype_id(etype) for etype in metapath]

    gidx = g._graph
    nodes = utils.prepare_tensor(g, nodes, "nodes")
    nodes = F.to_dgl_nd(nodes)
    # (Xin) Since metapath array is created by us, safe to skip the check
    #       and keep it on CPU to make max_nodes sanity check easier.
    metapath = F.to_dgl_nd(F.astype(F.tensor(metapath), g.idtype))
    # print(nodes)
    # Load the probability tensor from the edge frames
    ctx = utils.to_dgl_context(g.device)
    if prob is None:
        p_nd = [nd.array([], ctx=ctx) for _ in g.canonical_etypes]
    else:
        p_nd = []
        for etype in g.canonical_etypes:
            if prob in g.edges[etype].data:
                prob_nd = F.to_dgl_nd(g.edges[etype].data[prob])
            else:
                prob_nd = nd.array([], ctx=ctx)
            p_nd.append(prob_nd)
    # print(p_nd)
    # Actual random walk
    if restart_prob is None:
        print(gidx, nodes, metapath, p_nd)
        traces, eids, types = _CAPI_DGLSamplingRandomWalk(
            gidx, nodes, metapath, p_nd
        )
    elif F.is_tensor(restart_prob):
        restart_prob = F.to_dgl_nd(restart_prob)
        traces, eids, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart(
            gidx, nodes, metapath, p_nd, restart_prob
        )
    elif isinstance(restart_prob, float):
        traces, eids, types = _CAPI_DGLSamplingRandomWalkWithRestart(
            gidx, nodes, metapath, p_nd, restart_prob
        )
    else:
        raise TypeError("restart_prob should be float or Tensor.")

    print(traces)
    traces = F.from_dgl_nd(traces)
    types = F.from_dgl_nd(types)
    eids = F.from_dgl_nd(eids)
    
    # print((traces, types))

    return (traces, eids, types) if return_eids else (traces, types)


_init_api("dgl.sampling.randomwalks", __name__)

In [None]:
g2 = dgl.heterograph({
    ('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),
    ('user', 'view', 'item'): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),
    ('item', 'viewed-by', 'user'): ([0, 1, 1, 2, 2, 1], [0, 0, 1, 2, 3, 3])})

In [266]:
g2 = dgl.heterograph({
    ('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),
    ('user', 'view', 'item'): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),
    ('item', 'viewed-by', 'user'): ([0, 1, 1, 2, 2, 1], [0, 0, 1, 2, 3, 3])}, device=torch.device('cpu'))
# g2.pin_memory_()
display(g2)
random_walk(g2, [0, 1, 2, 0], metapath=['follow', 'view', 'viewed-by'] * 2)

# dgl.sampling.random_walk(g2, [0, 1, 2, 0], metapath=['follow', 'view', 'viewed-by'] * 2, restart_prob=torch.FloatTensor([0, 0.5, 0, 0, 0.5, 0]))



Graph(num_nodes={'item': 3, 'user': 4},
      num_edges={('item', 'viewed-by', 'user'): 6, ('user', 'follow', 'user'): 5, ('user', 'view', 'item'): 6},
      metagraph=[('item', 'user', 'viewed-by'), ('user', 'user', 'follow'), ('user', 'item', 'view')])

(tensor([[             0,              1,              1,              3,
                       0,              0,              0],
         [           145,      228631568, 47001428110336,             16,
                       8,            144,             49],
         [     234414112,      228631616,             16,              8,
                      96,             48,      234489552],
         [             3,             24,              8,            240,
                      80,              0,              0]]),
 tensor([1, 1, 0, 1, 1, 0, 1]))