In [1]:
import random
import warnings
from collections import deque

warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import networkx as nx
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

from data.dataset import BCDataset, SubgraphDataset

In [2]:
# Both dataset wrappers are constructed with the same type/path arguments.
elliptic_args = {'type': 'elliptic', 'path': 'datasets/elliptic'}

# Full dataset: print the shapes of its features, labels and edge index.
idata = BCDataset(**elliptic_args)
print(*(tensor.shape for tensor in (idata.features, idata.labels, idata.edge_index)))

# Subgraph dataset: same shape report.
sub_data = SubgraphDataset(**elliptic_args)
print(*(tensor.shape for tensor in (sub_data.features, sub_data.labels, sub_data.edge_index)))

# Convert the subgraph dataset for downstream cells; the trailing bare name
# renders the converted object's repr as the cell output.
data = sub_data.to_torch_data()
data

torch.Size([203769, 166]) torch.Size([203769]) torch.Size([2, 468710])
Extracted 46564 subgraphs with max hops=6 and max nodes=None
46564 46564 46564 46564 46564
torch.Size([46564, 165]) torch.Size([46564]) torch.Size([2, 0])


Data(x=[46564, 164], edge_index=[2, 0], y=[46564], train_mask=[46564], val_mask=[46564], test_mask=[46564])

In [3]:
# Convert the BCDataset instance loaded earlier via its to_torch_data() method.
full = idata.to_torch_data()
# Bare expression so the notebook renders the object's repr as the cell output.
full

Data(x=[203769, 165], edge_index=[2, 468710], y=[203769], train_mask=[203769], val_mask=[203769], test_mask=[203769])

In [None]:
def _bfs_sample(G, start_node, max_nodes, max_edges):
    """Breadth-first expansion from ``start_node`` under node and edge budgets.

    Args:
        G (networkx.Graph): Undirected graph to explore.
        start_node: Node of ``G`` to start from; it is always included.
        max_nodes (int): Stop discovering new nodes once this many are held.
        max_edges (int): Stop collecting edges once this many are held.

    Returns:
        tuple[list, list[tuple]]: Nodes in discovery order, and edges as
        ``(u, v)`` tuples in discovery order. Each undirected edge appears at
        most once, so it is counted once against ``max_edges``.
    """
    nodes = [start_node]
    visited = {start_node}
    edges = []
    seen_edges = set()
    queue = deque([start_node])

    while queue and len(visited) < max_nodes and len(edges) < max_edges:
        current = queue.popleft()
        for neighbor in G.neighbors(current):
            if neighbor not in visited:
                if len(visited) >= max_nodes:
                    # Node budget exhausted: this neighbor stays outside the sample.
                    continue
                visited.add(neighbor)
                nodes.append(neighbor)
                queue.append(neighbor)
            # Orientation-independent key so (u, v) and (v, u) are one edge.
            key = (current, neighbor) if current <= neighbor else (neighbor, current)
            if key not in seen_edges:
                seen_edges.add(key)
                edges.append((current, neighbor))
                if len(edges) >= max_edges:
                    break

    return nodes, edges


def visualize_connected_subgraph(data, max_nodes=25, max_edges=40, seed=None):
    """
    Extract a connected subgraph with up to max_nodes and max_edges using BFS,
    visualize it, and return a PyTorch Geometric Data object.

    The start node is drawn at random from the nodes that have at least one
    neighbor other than themselves. If the graph has no such node, a single
    random node is returned with an empty ``edge_index`` of shape ``[2, 0]``.

    Args:
        data (torch_geometric.data.Data): Input graph data.
        max_nodes (int): Maximum number of nodes in the subgraph (default: 25).
            The start node is always included.
        max_edges (int): Maximum number of undirected edges in the subgraph
            (default: 40).
        seed (int, optional): Seeds both the start-node choice and the plot
            layout for reproducible output. With the default ``None`` the
            global ``random`` state is used.

    Returns:
        Data: A PyTorch Geometric Data object representing the subgraph. Nodes
        are relabelled ``0..k-1`` in BFS discovery order and each undirected
        edge is stored once in ``edge_index``. ``x`` and ``y`` hold the
        matching rows of ``data``; ``x`` falls back to an identity matrix when
        ``data.x`` is missing. ``edge_attr`` is aligned with ``edge_index``, or
        ``None`` when the input has no ``edge_attr`` or no edge was sampled.

    Raises:
        ValueError: If the graph has no nodes, or if a sampled edge cannot be
            located in ``data.edge_index`` while gathering ``edge_attr``.
    """
    # to_networkx labels nodes 0..num_nodes-1, so graph node labels are already
    # row indices into data.x / data.y and need no further remapping.
    G = to_networkx(data, to_undirected=True)
    if G.number_of_nodes() == 0:
        raise ValueError("Cannot sample a subgraph from a graph with no nodes.")

    rng = random if seed is None else random.Random(seed)
    all_nodes = list(G.nodes)
    # Starting from a node with a real neighbor yields a multi-node subgraph
    # whenever one exists, instead of relying on repeated random retries.
    candidates = [n for n in all_nodes if any(nb != n for nb in G.neighbors(n))]
    start_node = rng.choice(candidates or all_nodes)

    nodes, edges = _bfs_sample(G, start_node, max_nodes, max_edges)

    # Relabel the sampled nodes to a compact 0..k-1 range.
    relabel = {node: i for i, node in enumerate(nodes)}
    local_edges = [[relabel[u], relabel[v]] for u, v in edges]

    G_sub = nx.Graph()
    G_sub.add_nodes_from(range(len(nodes)))
    G_sub.add_edges_from(local_edges)

    fig, ax = plt.subplots(figsize=(10, 8))
    pos = nx.spring_layout(G_sub, seed=seed)
    nx.draw(G_sub, pos, ax=ax, with_labels=True, node_color='lightblue', node_size=500,
            font_size=10, edge_color='gray', width=2)
    ax.set_title("Connected Subgraph Visualization")
    plt.show()

    full_x = getattr(data, 'x', None)
    x = full_x[nodes] if full_x is not None else torch.eye(len(nodes))

    if local_edges:
        edge_index = torch.tensor(local_edges, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]) would be 1-D; keep the [2, num_edges] layout.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    edge_attr = None
    full_edge_attr = getattr(data, 'edge_attr', None)
    if full_edge_attr is not None and edges:
        src, dst = data.edge_index
        rows = []
        for u, v in edges:
            # The sampled graph is undirected, so accept either stored orientation.
            match = ((src == u) & (dst == v)) | ((src == v) & (dst == u))
            hits = torch.where(match)[0]
            if len(hits) == 0:
                # Skipping would silently misalign edge_attr with edge_index.
                raise ValueError(f"Edge ({u}, {v}) not found in data.edge_index.")
            rows.append(full_edge_attr[hits[0]])
        edge_attr = torch.stack(rows)

    subgraph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    full_y = getattr(data, 'y', None)
    if full_y is not None:
        subgraph.y = full_y[nodes]

    return subgraph

In [None]:
# Sample and plot a small connected neighbourhood (at most 10 nodes / 20 edges).
# NOTE(review): the recorded repr of `data` shows edge_index=[2, 0], i.e. no
# edges, so no multi-node connected subgraph can exist in it. Confirm whether
# `full` (edge_index=[2, 468710]) was the intended input here.
sub = visualize_connected_subgraph(data, max_nodes=10, max_edges=20)

In [None]:
# Bare expression: render the sampled subgraph's repr as the cell output.
sub