### Step 0: Define a drawing helper to visualize the graphs

In [3]:
import matplotlib.pyplot as plt
import networkx as nx
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx


# Color palette: index -> (single-character label, matplotlib color name).
# The index is the value stored as the node feature `x`.
IDX2COLOR = {
    0: ('R', 'red'),
    1: ('G', 'green'),
    2: ('B', 'blue'),
    3: ('W', 'white'),
    4: ('X', 'gray'),
}

# Reverse lookup: matplotlib color name -> palette index.
COLOR2IDX = {color_word: idx for idx, (_color_char, color_word) in IDX2COLOR.items()}


def draw_color_graph(graph: Data, is_retrieved=False):
    """Draw a colored graph with matplotlib and show it.

    Each node is filled with its color and labelled with the single-character
    code from ``IDX2COLOR`` (e.g. 'R' for red, 'X' for gray). A legend lists
    every color of the palette.

    Args:
        graph: PyG ``Data`` object. When ``is_retrieved`` is False, ``graph.x[i]``
            holds the palette index (a key of ``IDX2COLOR``) of node ``i``.
        is_retrieved: If True, read node colors from ``graph.node_color``
            instead. It is expected to hold one color-name string per node
            (presumably what ``HyperNet.reconstruct`` returns; confirm).

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    from matplotlib.patches import Patch

    G = to_networkx(graph, node_attrs=['x'], to_undirected=True)

    # Planar layouts avoid edge crossings; non-planar graphs fall back to a
    # seeded spring layout so they are still drawn reproducibly.
    if nx.check_planarity(G)[0]:
        pos = nx.planar_layout(G)
    else:
        pos = nx.spring_layout(G, seed=7)

    node_labels = []
    node_colors = []
    for i in range(graph.num_nodes):
        if is_retrieved:
            color = graph.node_color[i]
            # Use the same label codes as the legend (gray -> 'X', not 'G');
            # fall back to the first letter only for names outside the palette.
            if color in COLOR2IDX:
                label = IDX2COLOR[COLOR2IDX[color]][0]
            else:
                label = color[0].upper()
        else:
            label, color = IDX2COLOR[int(graph.x[i].item())]
        node_labels.append(label)
        node_colors.append(color)

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=800, edgecolors='black', linewidths=2.5)
    nx.draw_networkx_edges(G, pos, width=3.5)
    nx.draw_networkx_labels(G, pos, labels=dict(enumerate(node_labels)), font_size=16, font_color='black')

    legend_elements = [Patch(facecolor=color, edgecolor='black', label=label)
                       for label, color in IDX2COLOR.values()]
    plt.legend(handles=legend_elements, loc='upper left', frameon=True)

    plt.axis('off')
    plt.tight_layout()
    plt.show()

### Step 1: Create a simple synthetic dataset of colored planar graphs

In [None]:
from torch_geometric.utils import to_undirected
import random

from torch_geometric.data import InMemoryDataset
import matplotlib.colors as mcolors
from torch_geometric.data import Data


class HashableData(Data):
    """A PyG ``Data`` object that can be stored in sets and used as a dict key.

    Two instances compare equal when their ``x`` and ``edge_index`` tensors are
    equal according to ``torch.equal``; all other attributes are ignored. The
    hash is built from the raw bytes of those same two tensors.
    """

    def __eq__(self, other):
        if isinstance(other, Data):
            same_x = torch.equal(self.x, other.x)
            return same_x and torch.equal(self.edge_index, other.edge_index)
        return False

    def __hash__(self):
        tensor_hashes = tuple(
            hash(tensor.cpu().numpy().tobytes())
            for tensor in (self.x, self.edge_index)
        )
        return hash(tensor_hashes)


# Mark HashableData as a safe global for torch's weights-only unpickler
# (presumably needed when the processed dataset file is loaded back).
torch.serialization.add_safe_globals([HashableData])


class ColorGraphDataset(InMemoryDataset):
    """Synthetic in-memory dataset of small, colored, planar graphs.

    Each graph is built around one motif: a red/green/blue triangle or an
    alternating red/green 5-ring. It is then expanded with 1-5 extra nodes that
    are attached while keeping the graph planar. Extra nodes get a color that
    does not occur in the motif. Node features ``x`` are ``COLOR2IDX`` indices
    (float), and edges are stored undirected. Graphs are de-duplicated on ``x``
    and ``edge_index`` (see ``HashableData``). Each motif contributes
    ``num_graphs // 2`` distinct graphs, in generation order.

    Note:
        ``min_nodes``, ``max_nodes``, ``edge_p`` and ``use_rgb`` are stored
        but are not used by ``process``. PyG only runs ``process`` when the
        processed file is missing, so delete ``root/processed`` to regenerate.
    """

    def __init__(self,
                 root,
                 num_graphs=100,
                 min_nodes=5,
                 max_nodes=10,
                 edge_p=0.15,
                 transform=None,
                 pre_transform=None,
                 use_rgb = True,
                 pre_filter=None,
                 ):
        # These must be set before super().__init__, which may call process().
        self.num_graphs, self.min_nodes, self.max_nodes = num_graphs, min_nodes, max_nodes
        self.edge_p = edge_p
        self.use_rgb = use_rgb
        super().__init__(root, transform, pre_transform, pre_filter=pre_filter)
        self.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Generate, optionally filter/transform, and save the graphs."""
        motifs = [self.triangle_motif, self.ring_motif]
        per_motif = self.num_graphs // len(motifs)

        data_list = []
        for motif_func in motifs:
            # Collect `per_motif` distinct graphs per motif so the motifs are
            # evenly represented. The list keeps generation order; iterating a
            # set of byte-hashed graphs would reorder them from run to run.
            seen = set()
            while len(seen) < per_motif:
                base_colors, base_edges = motif_func()
                extra_nodes = random.randint(1, 5)
                total_colors, total_edges = self.expand_graph_planar(base_colors, base_edges, extra_nodes)

                # Node features: palette index of each node's color.
                x = torch.tensor([COLOR2IDX[c] for c in total_colors], dtype=torch.float)

                # Edges: (2, E) index tensor, made undirected.
                edge_index = torch.tensor(total_edges, dtype=torch.long).t().contiguous()
                edge_index = to_undirected(edge_index, num_nodes=len(total_colors))

                data = HashableData(x=x, edge_index=edge_index)
                data.validate(raise_on_error=True)
                if data not in seen:
                    seen.add(data)
                    data_list.append(data)

        if self.pre_filter is not None:
            # data_list is already flat: filter the graphs themselves.
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        self.save(data_list, self.processed_paths[0])

    @staticmethod
    def ring_motif():
        """Return colors and edges of a 5-cycle with alternating red/green nodes."""
        colors = ['red', 'green', 'red', 'green', 'red']
        nodes = list(range(5))
        edges = [(i, (i + 1) % 5) for i in nodes]
        return colors, edges

    @staticmethod
    def triangle_motif():
        """Return colors and edges of a red/green/blue triangle."""
        colors = ['red', 'green', 'blue']
        edges = [(0, 1), (1, 2), (2, 0)]
        return colors, edges

    @staticmethod
    def expand_graph_planar(base_colors, base_edges, num_extra_nodes, max_attempts=100):
        """Attach extra nodes to a base graph while keeping it planar.

        Each new node gets a random color that does not appear in
        ``base_colors``. It is connected to 1-3 randomly chosen existing nodes.
        If those edges break planarity, they are removed and the attempt is
        retried, up to ``max_attempts`` times. If every attempt fails, the node
        is kept without edges.

        Returns:
            ``(colors, edges)`` as new lists; the inputs are not modified.
        """
        n0 = len(base_colors)
        G = nx.Graph()
        G.add_nodes_from(range(n0))
        G.add_edges_from(base_edges)

        colors = list(base_colors)
        edges = list(base_edges)
        # Loop-invariant. Sorted so random.choice is reproducible under
        # random.seed (string-set order varies with hash randomization).
        extra_palette = sorted(set(COLOR2IDX).difference(base_colors))

        for i in range(num_extra_nodes):
            new_node_id = n0 + i
            G.add_node(new_node_id)
            colors.append(random.choice(extra_palette))

            for _ in range(max_attempts):
                # Connect to 1-3 of the nodes that already exist.
                k = random.randint(1, min(3, n0 + i))
                targets = random.sample(sorted(G.nodes - {new_node_id}), k)
                trial_edges = [(new_node_id, t) for t in targets]

                G.add_edges_from(trial_edges)
                if nx.check_planarity(G)[0]:
                    edges.extend(trial_edges)
                    break
                G.remove_edges_from(trial_edges)

        return colors, edges

### Step 2: Use the generated dataset

In [None]:
from pathlib import Path

# Datasets live beside the notebook folder: <project>/datasets.
# (`notebook_dir` is not named `root`, which the next cell uses for the dataset root.)
notebook_dir = Path.cwd().resolve()
project_dir = notebook_dir.parent
datasets = project_dir / "datasets"

In [None]:
# Build the dataset, or load it from the processed cache. PyG runs
# `process()` only when the processed file is missing, so delete
# `dataset_root / "processed"` to regenerate after changing the arguments.
dataset_root = datasets / 'test' / 'ColorPlanar05'
dataset = ColorGraphDataset(root=dataset_root, num_graphs=100)
print(len(dataset))

# Preview the first two and last two graphs (the two motifs are stored in
# order). draw_color_graph shows the figure itself and returns None.
n_graphs = len(dataset)
preview_indices = sorted({0, 1, n_graphs - 2, n_graphs - 1} & set(range(n_graphs)))
for idx in preview_indices:
    draw_color_graph(dataset[idx])

### Step 3: Test the encoding and retrieval process

In [None]:

from graph_hdc.special.colors import NominalColorEncoder
from graph_hdc.models import HyperNet
from torch_geometric.loader import DataLoader

dim = 1000
hyper_net = HyperNet(
    hidden_dim=dim,
    depth=2,
    node_encoder_map={
        'node_color': NominalColorEncoder(dim, list(COLOR2IDX.keys()), seed=7),
    },
    seed=7
)

# The forward pass takes a DataBatch; use the first graph of the dataset.
data_batch = next(iter(DataLoader(dataset, batch_size=1)))
# The encoder map is keyed by 'node_color', so expose the color indices under
# that attribute name (presumably the attribute HyperNet reads; confirm).
data_batch.node_color = data_batch.x
draw_color_graph(data_batch)

# Encode the batch into its graph hypervector.
result = hyper_net.forward(data_batch)
graph_embedding = result['graph_embedding']


In [None]:
# Decode the graph hypervector, then rebuild a PyG Data object from the
# returned dict so it can be drawn like the input graph.
rec_dict = hyper_net.reconstruct(
    graph_embedding,
    learning_rate=1.0,
    num_iterations=10,
    batch_size=10,
    low=0.0,
    high=1.0,
)
rec_x = torch.tensor(rec_dict['node_indices'], dtype=torch.float)
rec_edge_index = torch.tensor(rec_dict['edge_indices'].T, dtype=torch.long)
rec_graph = Data(x=rec_x, edge_index=rec_edge_index)
# Per-node color names, read by draw_color_graph(..., is_retrieved=True).
rec_graph.node_color = rec_dict['node_color']
rec_graph

In [None]:
# Draw the original input graph first, then its reconstruction.
for graph_to_draw, retrieved in ((data_batch, False), (rec_graph, True)):
    draw_color_graph(graph_to_draw, is_retrieved=retrieved)