In [None]:
from transformers import AutoTokenizer


In [8]:
# Load the tokenizer published with the "meta-llama/Llama-3.1-8B-Instruct"
# checkpoint; use_fast=True asks for the fast tokenizer implementation.
# NOTE(review): presumably needs access to that model repo on the Hugging Face Hub — confirm.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", use_fast=True)

In [None]:
# Smoke test: tokenize a short string and display the returned encoding.
tokenizer("Hello, world!")

In [None]:
from langgfm.data.graph_generator import InputGraphGenerator
# Create the graph generator registered under the "oag_scholar_interest" key,
# then build the graph (and its metadata) for sample 1.
generator = InputGraphGenerator.create("oag_scholar_interest")
G, metadata = generator.generate_graph(sample_id=1)

In [None]:
# Display the nodes of G together with their attribute data.
# NOTE(review): G is presumably a networkx graph — confirm.
G.nodes(data=True)

In [None]:
# Re-create the generator with the same dataset key as in the earlier cell.
generator = InputGraphGenerator.create("oag_scholar_interest")

In [None]:
# Inspect the edge_index stored for the ("author", "writes", "paper") relation
# of the generator's heterogeneous graph.
generator.hetero_graph["author", "writes", "paper"].edge_index

In [None]:
# Generate the graph and metadata for sample 1 (same call as earlier; rebinds G and metadata).
G, metadata = generator.generate_graph(sample_id=1)

In [None]:
# Display the nodes of the regenerated G together with their attribute data.
G.nodes(data=True)

In [None]:
from langgfm.data.graph_generator.utils.sampling import generate_node_centric_k_hop_subgraph

In [None]:
# Run the project's k-hop sampling helper for 1 hop around sample_id=1 with a
# fixed random seed; only the second return value is kept, as `nodes`.
# Note: binding `_` here overwrites IPython's "last output" variable.
# NOTE(review): neighbor_size=10 presumably caps the neighbors sampled per node — confirm against the helper.
_,nodes,_ = generate_node_centric_k_hop_subgraph(generator.graph, sample_id=1,num_hops=1,neighbor_size=10,random_seed=42,sampling=True)

In [None]:
# Same call as the previous cell, but seeded with the `nodes` it returned.
# `nodes` is overwritten with the new result, so this cell is not idempotent:
# re-running it alone expands from its own previous output.
_,nodes,_ = generate_node_centric_k_hop_subgraph(generator.graph, sample_id=nodes,num_hops=1,neighbor_size=10,random_seed=42,sampling=True)

In [None]:
# Display the graph object held by the generator.
generator.graph

In [None]:
# `torch` is otherwise first imported further down the notebook; import it here
# so this cell does not raise NameError when the notebook is run top to bottom
# on a fresh kernel.
import torch

# Keep the edges whose source is node 1 or whose target is node 3180866.
# The mask selects rows of the transposed edge_index, and the final .T turns
# the result back into the original orientation (assumes edge_index is (2, E)).
generator.graph.edge_index.T[torch.logical_or(generator.graph.edge_index[0]==1,generator.graph.edge_index[1]==3180866)].T

In [None]:
# Rows of the transposed edge_index whose second-row entry equals 3180866,
# i.e. one (edge_index[0], edge_index[1]) pair per row (assumes edge_index is (2, E)).
generator.graph.edge_index.T[generator.graph.edge_index[1]==3180866]

In [51]:
from torch_geometric.utils import k_hop_subgraph
# Generate the 2-hop subgraph around node 1 without sampling, using the
# 'source_to_target' flow. relabel_nodes=False keeps the original node ids.
# The third return value of k_hop_subgraph is discarded; the node subset,
# the subgraph edge_index and the edge mask are kept.
src_to_tgt_subset, src_to_tgt_edge_index, _, src_to_tgt_edge_mask = k_hop_subgraph(
    node_idx=1, num_hops=2, edge_index=generator.graph.edge_index,
    relabel_nodes=False, flow='source_to_target', directed=False
)

In [None]:
import torch
# Toy experiment to see how `flow` changes the result of k_hop_subgraph.
# The commented-out variants are the same graph with the two rows swapped and
# the same call with flow='source_to_target'.
# edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
#                            [2, 2, 4, 4, 6, 6]])
# Active graph: row 0 = [2, 2, 4, 4, 6, 6], row 1 = [0, 1, 2, 3, 4, 5].
edge_index = torch.tensor([[2, 2, 4, 4, 6, 6],
                            [0, 1, 2, 3, 4, 5]])
# subset, edge_index, mapping, edge_mask = k_hop_subgraph(
#     6, 2, edge_index, flow='source_to_target', num_nodes=None,)
# 2-hop subgraph around node 6 with flow='target_to_source'.
# Note: `edge_index` is rebound to the subgraph's edge_index returned by the call.
subset, edge_index, mapping, edge_mask = k_hop_subgraph(
    6, 2, edge_index, flow='target_to_source', num_nodes=None,)
print(subset)

In [None]:
# edge_index[1]

In [None]:
import torch
import random
from typing import Optional, Union, List

def to_undirected(edge_index: torch.Tensor) -> torch.Tensor:
    """
    Symmetrize an edge list by appending the reverse of every edge.

    The result holds the original edges first, followed by the same edges with
    their two endpoints swapped, so a (2, E) input yields a (2, 2E) output.
    No de-duplication is performed: if the input already contains both (u, v)
    and (v, u), each of them appears twice in the output.
    """
    # Unpacking into exactly two rows keeps the requirement that edge_index has two rows.
    src, dst = edge_index.unbind(0)
    flipped = torch.stack((dst, src), dim=0)
    return torch.cat((edge_index, flipped), dim=1)

def build_csr(edge_index: torch.Tensor, num_nodes: int):
    """
    Build a CSR-style adjacency (row_ptr, col_ind) from a (2, E) edge list.

    The edge list is first symmetrized with to_undirected(), so every edge
    contributes a neighbor entry to both of its endpoints.

    Returns
    -------
    row_ptr : torch.LongTensor of shape (num_nodes + 1,)
        row_ptr[i] is the offset in col_ind where node i's neighbors start;
        row_ptr[i + 1] is where they end.
    col_ind : torch.Tensor
        Neighbor ids grouped by node, i.e. col_ind[row_ptr[i]:row_ptr[i + 1]]
        lists the neighbors of node i.
    """
    # Symmetrize, then split into the "from" and "to" endpoint of every entry.
    src, dst = to_undirected(edge_index)

    # Reorder the entries so that all entries of one node sit next to each other.
    order = src.argsort()
    src, dst = src[order], dst[order]

    # Offsets are the running total of per-node entry counts, with a leading 0.
    degree = torch.bincount(src, minlength=num_nodes)
    row_ptr = torch.zeros(num_nodes + 1, dtype=torch.long)
    row_ptr[1:] = degree.cumsum(dim=0)

    return row_ptr, dst

def get_khop_subgraph(
    edge_index: torch.Tensor,
    node_idx: Union[int, List[int], torch.Tensor],
    num_hops: int,
    max_neighbors_per_hop: Optional[Union[int, List[int]]] = None,
    sampling: bool = False,
    random_seed: Optional[int] = None
):
    """
    Performs a layer-wise BFS up to num_hops steps from the seed node(s) and optionally samples neighbors per node.

    Edges are treated as undirected during traversal (build_csr() symmetrizes the edge list).

    Parameters
    ----------
    edge_index : torch.Tensor, shape (2, E)
        The edge list of the graph. Node ids must be non-negative integers.
    node_idx : int, list[int], or torch.Tensor[int]
        The starting node(s) for BFS. A 0-d tensor is accepted and treated as a single seed.
        Every seed must lie in [0, max node id in edge_index].
    num_hops : int
        The maximum number of BFS layers.
    max_neighbors_per_hop : None or int or list[int], optional
        - When sampling=True, this parameter controls the maximum number of neighbors retained per node per hop.
          * If an int is given, the same limit is used for every hop.
          * If a list[int] is given, each hop uses a potentially different limit. For example, [15, 5] means:
            - For the 1st hop, each node keeps at most 15 neighbors.
            - For the 2nd hop, each node keeps at most 5 neighbors.
            Hops beyond the end of the list reuse the last entry; an empty list means no limit.
          * If None, no limit is applied, keeping all neighbors.
        - When sampling=False, this parameter is ignored and all neighbors are kept.
    sampling : bool
        Whether to perform neighbor sampling (True) or not (False).
    random_seed : int or None
        If set, the global Python and PyTorch random seeds are set to this value
        (a side effect visible to the caller) for reproducible sampling.

    Returns
    -------
    sub_graph_nodes : set[int]
        The set of nodes visited by BFS, intersected with nodes from edge_index[0].
    sub_graph_edge_mask : torch.BoolTensor of shape (E,)
        A boolean mask for the original edges. True means the edge belongs to the k-hop subgraph (both endpoints visited).

    Raises
    ------
    TypeError
        If node_idx is not an int, a list or a torch.Tensor.
    IndexError
        If a seed node id is negative or larger than the largest node id in edge_index.

    Notes
    -----
    This approach uses a BFS with a frontier set of nodes at each hop. When sampling=True, each node's neighbor list
    is randomly truncated (if longer than the specified limit). Otherwise, the full neighbor list is used.
    """
    # 1) Normalize node_idx into a list of plain Python ints
    if isinstance(node_idx, int):
        start_nodes = [node_idx]
    elif isinstance(node_idx, list):
        start_nodes = [int(n) for n in node_idx]
    elif isinstance(node_idx, torch.Tensor):
        # reshape(-1) first: .tolist() on a 0-d tensor returns a bare scalar, not a list
        start_nodes = [int(n) for n in node_idx.reshape(-1).tolist()]
    else:
        raise TypeError(f"Unsupported type for node_idx: {type(node_idx)}")

    # 2) Set random seed if specified
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)

    # 3) A graph without edges has nothing to traverse: no node appears in
    #    edge_index[0] and the edge mask is empty
    if edge_index.numel() == 0:
        return set(), torch.zeros(0, dtype=torch.bool, device=edge_index.device)

    # 4) Determine the number of nodes in the graph (largest id over both rows)
    num_nodes = int(edge_index[:2].max().item()) + 1

    # Reject seeds outside the graph up front. A negative id would otherwise be
    # interpreted as "count from the end" by tensor indexing and silently mark
    # an unrelated node as visited.
    invalid_seeds = [n for n in start_nodes if not 0 <= n < num_nodes]
    if invalid_seeds:
        raise IndexError(
            f"node_idx contains ids outside [0, {num_nodes - 1}]: {invalid_seeds}"
        )

    # 5) Build the CSR-like structure
    row_ptr, col_ind = build_csr(edge_index, num_nodes)

    # 6) Initialize BFS
    visited = set(start_nodes)
    frontier = set(start_nodes)

    # Helper function: get the sampling limit for the current hop
    def get_limit_for_hop(hop_idx: int) -> Optional[int]:
        # No limit when sampling is off or no limit was given
        if not sampling or max_neighbors_per_hop is None:
            return None

        # A single int applies to all hops
        if isinstance(max_neighbors_per_hop, int):
            return max_neighbors_per_hop

        # An empty list carries no limit at all
        if len(max_neighbors_per_hop) == 0:
            return None

        # Index by hop; hops past the end reuse the last entry
        return max_neighbors_per_hop[min(hop_idx, len(max_neighbors_per_hop) - 1)]

    # 7) Layer-wise BFS
    for hop in range(num_hops):
        if not frontier:
            break
        next_frontier = set()
        limit_this_hop = get_limit_for_hop(hop)

        for u in frontier:
            start = row_ptr[u].item()
            end = row_ptr[u + 1].item()
            neighbors = col_ind[start:end]

            # If we have a limit for neighbors, randomly sample
            if limit_this_hop is not None and len(neighbors) > limit_this_hop:
                perm = torch.randperm(len(neighbors))[:limit_this_hop]
                neighbors = neighbors[perm]

            for nbr in neighbors.tolist():
                if nbr not in visited:
                    visited.add(nbr)
                    next_frontier.add(nbr)

        frontier = next_frontier

    # 8) Build subgraph nodes and edge mask
    # Boolean "was visited" flag per node, kept on the same device as the edges
    # so it can be indexed by them directly
    visited_mask = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)
    if visited:
        visited_ids = torch.tensor(sorted(visited), dtype=torch.long, device=edge_index.device)
        visited_mask[visited_ids] = True

    edge_u = edge_index[0]
    edge_v = edge_index[1]
    src_visited = visited_mask[edge_u]

    # Visited nodes that also occur in edge_index[0]; computed on the tensor so
    # the full source row never has to be materialized as a Python list
    sub_graph_nodes = set(torch.unique(edge_u[src_visited]).tolist())

    # An edge belongs to the subgraph when both of its endpoints were visited
    sub_graph_edge_mask = src_visited & visited_mask[edge_v]

    return sub_graph_nodes, sub_graph_edge_mask



# Example usage
# Small hand-written graph: six edges whose second-row endpoint is 3180866,
# plus one separate edge between 5 and 199811.
edge_index = torch.tensor([
    [0,       1,       2,       3,       4,       7127060,  5],
    [3180866, 3180866, 3180866, 3180866, 3180866, 3180866,  199811]
])
# Two seeds, one in each of the two groups of edges above.
node_idx = [1,5]
num_hops = 2

# Without sampling: max_neighbors_per_hop is passed but ignored, so every
# neighbor is followed.
print("==== sampling=False ====")
sub_nodes_no_sample, sub_mask_no_sample = get_khop_subgraph(
    edge_index, 
    node_idx, 
    num_hops,
    max_neighbors_per_hop=[15, 5],
    sampling=False,       # Disable sampling
    random_seed=42
)
print("sub_graph_nodes =", sub_nodes_no_sample)
print("sub_graph_edge_mask =", sub_mask_no_sample)

# With sampling: each node keeps at most 3 neighbors on the first hop and at
# most 2 on the second.
print("==== sampling=True ====")
sub_nodes_sample, sub_mask_sample = get_khop_subgraph(
    edge_index, 
    node_idx, 
    num_hops,
    max_neighbors_per_hop=[3, 2],
    sampling=True,        # Enable sampling
    random_seed=0
)
print("sub_graph_nodes =", sub_nodes_sample)
print("sub_graph_edge_mask =", sub_mask_sample)


# # edge_index = generator.graph.edge_index.T[torch.logical_or(generator.graph.edge_index[0]==1,generator.graph.edge_index[1]==3180866)].T
# # edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
# # print(edge_index)
# edge_index = torch.tensor([[ 0,       1,       2,       3,       4, 7127060],
#                             [3180866, 3180866, 3180866, 3180866, 3180866, 3180866]])
# # subset, edge_index, mapping, edge_mask = k_hop_subgraph(
# #     1, 2, edge_index,flow="target_to_source")
# subset, _, _, edge_mask = k_hop_subgraph(
#     1, 3, edge_index,flow="target_to_source",directed=False)
# print(subset)
# # subset, _, mapping, edge_mask = k_hop_subgraph(
# #     [1], 2, edge_index,flow="source_to_target",directed=False)
# # print(subset)

In [None]:
# Display the node subset returned by the flow='source_to_target' k_hop_subgraph call.
src_to_tgt_subset

In [53]:
# Same 2-hop query around node 1 as the 'source_to_target' cell, but with
# flow='target_to_source'. relabel_nodes=False keeps the original node ids;
# the third return value is discarded.
tgt_to_src_subset, tgt_to_src_edge_index, _, tgt_to_src_edge_mask = k_hop_subgraph(
            node_idx=1, num_hops=2, edge_index=generator.graph.edge_index,
            relabel_nodes=False, flow='target_to_source', directed=False
)

In [None]:
# Display the node subset returned by the flow='target_to_source' k_hop_subgraph call.
tgt_to_src_subset

In [None]:
# Call the project's sampling helper for 2 hops around sample 1, leaving all
# other arguments at their defaults, and display what it returns.
# NOTE(review): presumably kept for comparison with the k_hop_subgraph results above — confirm.
generate_node_centric_k_hop_subgraph(graph=generator.graph, sample_id=1, num_hops=2)