# Simple GraphHD Encoder/Decoder

# In [1]:
import torch
import torchhd
from torch_geometric.data import Data

class GraphHDEncoder:
    """Encode a graph (node features plus edges) into one hypervector.

    Each node gets a random identity hypervector. Its quantized feature values
    are mapped to level hypervectors, summed, and bound to that identity.
    Each edge is encoded by binding the identities of its two endpoints.
    All node and edge terms are bundled (summed) into a single graph
    hypervector.

    Args:
        dim: Hypervector dimensionality.
        levels: Number of quantization levels for feature values in [0, 1].

    Attributes:
        level_space: ``(levels, dim)`` level codebook. ``torchhd.level`` is
            random, so a decoder must reuse this exact tensor; a codebook
            generated separately cannot decode this encoder's output.
        node_ids: ``(num_nodes, dim)`` identity hypervectors from the most
            recent ``encode`` call (regenerated on every call).
    """

    def __init__(self, dim=10000, levels=256):
        self.dim = dim
        self.levels = levels
        self.level_space = torchhd.level(levels, dim)  # for continuous features

    def encode(self, data: Data):
        """Return the graph hypervector (shape ``(dim,)``) for ``data``.

        Args:
            data: Graph whose ``x`` holds node features expected in [0, 1]
                (shape ``(num_nodes,)`` or ``(num_nodes, num_features)``) and
                whose optional ``edge_index`` is a ``(2, num_edges)`` index
                tensor.

        Side effects:
            Overwrites ``self.node_ids`` with fresh random identities; the
            decoder needs these to unbind the result.
        """
        num_nodes = data.num_nodes

        # Random hypervector per node (symbolic identity), kept for decoding.
        self.node_ids = torchhd.random(num_nodes, self.dim)

        x = data.x
        if x.dim() == 1:
            # Treat a flat feature vector as one scalar feature per node.
            x = x.unsqueeze(1)

        # Quantize to level indices. Clamp first: a negative value would
        # otherwise turn into a negative index and silently wrap to the top
        # levels, and a value above 1 would index past the codebook. Round
        # instead of truncating so the error is at most half a level.
        node_values = torch.round(x.clamp(0.0, 1.0) * (self.levels - 1)).long()  # (num_nodes, num_features)
        feature_vectors = self.level_space[node_values]  # (num_nodes, num_features, dim)
        # NOTE: features are summed without a per-feature key, so with more
        # than one feature the individual values cannot be told apart later.
        feature_encodings = feature_vectors.sum(dim=1)  # (num_nodes, dim)

        # Bind each node ID with its encoded features, then bundle.
        node_encodings = torchhd.bind(self.node_ids, feature_encodings)
        graph_hv = node_encodings.sum(dim=0)  # (dim,)

        # Optionally include edge info (purely symbolic). PyG ``Data`` may
        # report ``edge_index`` as ``None`` rather than omitting it, so check
        # the value itself, not just attribute presence.
        edge_index = getattr(data, "edge_index", None)
        if edge_index is not None and edge_index.numel() > 0:
            src, dst = edge_index
            edge_encodings = torchhd.bind(self.node_ids[src], self.node_ids[dst])
            graph_hv = graph_hv + edge_encodings.sum(dim=0)

        return graph_hv

# In [6]:
class GraphHDDecoder:
    """Recover per-node feature values from a ``GraphHDEncoder`` hypervector.

    Decoding only works with the *same* level codebook the encoder used.
    ``torchhd.level`` draws random hypervectors, so a codebook created here
    independently is unrelated to the encoder's and yields arbitrary values.
    Pass ``level_space=encoder.level_space`` (or assign ``decoder.level_space``
    afterwards).

    Args:
        dim: Hypervector dimensionality; must match the encoder.
        levels: Number of quantization levels; must match the encoder.
        level_space: Optional ``(levels, dim)`` codebook to reuse, normally
            ``encoder.level_space``. If omitted, a new random codebook is
            generated (kept for backward compatibility; it will not match any
            existing encoder).

    Raises:
        ValueError: If ``level_space`` is given and its shape is not
            ``(levels, dim)``.
    """

    def __init__(self, dim=10000, levels=256, level_space=None):
        self.dim = dim
        self.levels = levels
        if level_space is None:
            # A fresh random codebook: only useful if it is later replaced by
            # (or is itself shared with) the encoder's codebook.
            level_space = torchhd.level(levels, dim)
        elif tuple(level_space.shape) != (levels, dim):
            raise ValueError(
                f"level_space has shape {tuple(level_space.shape)}, "
                f"expected ({levels}, {dim})"
            )
        self.level_space = level_space

    def decode_node_features(self, graph_hv: torch.Tensor, node_ids: torch.Tensor):
        """Estimate one feature value per node.

        Args:
            graph_hv: Graph hypervector of shape ``(dim,)`` returned by
                ``GraphHDEncoder.encode``.
            node_ids: The encoder's ``node_ids`` from that same ``encode``
                call, shape ``(num_nodes, dim)``.

        Returns:
            Float tensor of shape ``(num_nodes, 1)`` with values in [0, 1],
            quantized to multiples of ``1 / (levels - 1)``. If the encoder
            summed several features per node, this is only the single
            best-matching level, not the individual features.
        """
        # Unbinding a node ID leaves that node's feature encoding plus
        # crosstalk from the other node and edge terms bundled into graph_hv.
        estimates = torchhd.bind(graph_hv.unsqueeze(0), torchhd.inverse(node_ids))  # (num_nodes, dim)

        # Compare every estimate with every level in one call: (num_nodes, levels).
        sim = torchhd.cosine_similarity(estimates, self.level_space)
        best_level = sim.argmax(dim=-1)
        return (best_level.to(torch.float32) / (self.levels - 1)).unsqueeze(1)


# In [7]:
# Toy example: a 3-node path graph (0 - 1 - 2) with one scalar feature per node
x = torch.tensor([[0.1], [0.5], [0.9]])  # node features normalized to [0, 1]
# Edges listed as (src, dst) pairs, transposed into PyG's (2, num_edges) layout.
edge_index = torch.tensor([[0, 1], [1, 2]])
data = Data(x=x, edge_index=edge_index.t().contiguous())

encoder = GraphHDEncoder(dim=10000)
decoder = GraphHDDecoder(dim=10000)
# torchhd.level is random, so the decoder must reuse the encoder's level
# codebook; with its own codebook the decoded values are unrelated to x.
decoder.level_space = encoder.level_space

graph_hv = encoder.encode(data)
decoded = decoder.decode_node_features(graph_hv, encoder.node_ids)

print("Original x:")
print(x)
print("Decoded x:")
print(decoded)  # should be close to original x, up to quantization/crosstalk error

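# Recorded output:
# Original x:
# tensor([[0.1000],
#         [0.5000],
#         [0.9000]])
# Decoded x:
# tensor([[0.9569],
#         [0.3412],
#         [0.1373]])
# NOTE: these decoded values do not track x because the decoder built its own
# random level codebook with torchhd.level instead of reusing the encoder's.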
