In [19]:
# get obgn-arxiv dataset
# https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv
from ogb.nodeproppred import PygNodePropPredDataset

dataset = PygNodePropPredDataset(name='ogbn-arxiv')

In [20]:
dataGraph = dataset[0]

print(f"number of graphs: {len(dataset)}")
print(f"number of features: {dataset.num_features}")
print(f"number of classes: {dataset.num_classes}")
print(f"number of nodes: {dataGraph.num_nodes}")
print(f"number of edges: {dataGraph.num_edges}")

number of graphs: 1
number of features: 128
number of classes: 40
number of nodes: 169343
number of edges: 1166243


In [21]:
# get the first node and print its features and label
node_idx = 0
print(f"node feature size: {dataGraph.x[node_idx].size()}")
print(f"node feature: {dataGraph.x[node_idx]}")
print(f"node label: {dataGraph.y[node_idx]}")

node feature size: torch.Size([128])
node feature: tensor([-0.0579, -0.0525, -0.0726, -0.0266,  0.1304, -0.2414, -0.4492, -0.0184,
        -0.0872,  0.1123, -0.0921, -0.2896, -0.0810,  0.0745, -0.1562, -0.0974,
         0.1194,  0.6458,  0.0774, -0.0939, -0.4004,  0.3114, -0.5418,  0.0805,
        -0.0069,  0.5423, -0.0122, -0.1808,  0.0165,  0.0508, -0.2083, -0.0870,
         0.0124,  0.2817,  0.1004, -0.1643,  0.0269,  0.0782,  0.0795, -0.0134,
         0.2915,  0.0416, -0.1414, -0.1345,  0.0162,  0.2810, -0.0919, -0.2403,
         0.4618,  0.1873,  0.1533,  0.0331,  0.0108,  0.0124, -0.1589,  0.0980,
         0.0305,  0.0162, -0.0957,  0.0521,  0.3218, -0.1057,  0.2229, -0.1206,
        -0.1723,  0.3954,  0.0883, -0.2219,  0.2310, -0.2096, -0.1125, -0.0644,
         0.0697, -0.1574,  0.0223, -0.4190,  0.1344,  0.2605,  0.0417, -0.0935,
        -0.0516, -0.0255,  0.7744,  0.0581,  0.0452,  0.0571, -0.5482, -0.0464,
         0.8728,  0.0119,  0.3891, -0.0859,  0.1116,  0.0618,  0.0015

In [22]:
# get a node and its neighbors
def get_neighbors(node_idx):
    return dataGraph.edge_index[1][dataGraph.edge_index[0] == node_idx]

def get_to_neighbors(node_idx):
    return dataGraph.edge_index[0][dataGraph.edge_index[1] == node_idx]

def get_from_neighbors(node_idx):
    return dataGraph.edge_index[1][dataGraph.edge_index[0] == node_idx]

node_idx = 0
print(f"node {node_idx} has {get_neighbors(node_idx).size(0)} neighbors")
print(f"node {node_idx} neighbors: {get_neighbors(node_idx)}")



node 0 has 2 neighbors
node 0 neighbors: tensor([93487, 52893])


In [24]:
# import HDC stuff
from hdc.bsc import BSC as HDC
from hdc.hv import HDCRepresentation
from hdc.itemmem import ItemMem
from hdc.encoder import Encoder
from hdc.levelencoder import LevelEncoder
from tqdm import tqdm
from typing import Type

hdc = HDC
N = 10240
encoder = LevelEncoder(hdc, N, -1, 1, 0.1)

num_intervals: 20
num_bits: 512


In [25]:
def encode_node(node_idx):
    feature_hvs = []
    for feature in dataGraph.x[node_idx]:
        feature_hvs.append(encoder.encode(feature))
    hv = hdc.sequence(feature_hvs)
    return hv

def encode_nodes():
    node_hvs = []
    for node_idx in tqdm(range(dataGraph.num_nodes), desc="Encoding nodes"):
        node_hvs.append(encode_node(node_idx))
    return node_hvs

In [26]:
node_hvs = encode_nodes()

Encoding nodes: 100%|██████████| 169343/169343 [09:12<00:00, 306.65it/s]


In [68]:
def encode_instance(node_idx):
    hv = node_hvs[node_idx]
    label = dataGraph.y[node_idx]
    neighbors_hvs = []
    for neighbor_idx in get_neighbors(node_idx):
        neighbors_hvs.append(node_hvs[neighbor_idx])
    for neighbor_idx in get_to_neighbors(node_idx):
        neighbors_hvs.append(node_hvs[neighbor_idx])
    if len(neighbors_hvs) == 0:
        return hv
    neighbor_hv = hdc.bundle(neighbors_hvs)
    return hdc.sequence([hv, neighbor_hv]) # [hv, neighbor_hv, neightbor_hv]

In [53]:
def learn_class():
    class_hvs = []
    for class_idx in tqdm(range(dataset.num_classes), desc="Learning classes"):
        all_instance_hvs_in_class = []
        all_node_idx_in_class = (dataGraph.y == class_idx).nonzero(as_tuple=True)[0]
        for node_idx in all_node_idx_in_class:
            all_instance_hvs_in_class.append(encode_instance(node_idx))
        class_hvs.append(hdc.bundle(all_instance_hvs_in_class))
    return class_hvs

In [54]:
def query(hv, class_hvs):
    min_dist = float("inf")
    pred = None
    for class_idx in range(len(class_hvs)):
        dist = hdc.dist(hv, class_hvs[class_idx])
        if dist < min_dist:
            min_dist = dist
            pred = class_idx
    return pred

In [69]:
class_hvs = learn_class()

Learning classes: 100%|██████████| 40/40 [05:40<00:00,  8.50s/it]


In [65]:
import torch

def rand_test(K):
    # randomly select K nodes
    rand_node_idx = torch.randint(0, dataGraph.num_nodes, (K,))
    # test model performance
    correct = 0
    for node_idx in rand_node_idx:
        hv = encode_instance(node_idx)
        pred = query(hv, class_hvs)
        if pred == dataGraph.y[node_idx].item():
            correct += 1
        else:
            print(f"node {node_idx} prediction: {pred}, label: {dataGraph.y[node_idx]}")
    return correct / K

In [70]:
acc = rand_test(1000)

print("len class_hvs: ", len(class_hvs))
print("len node_hvs: ", len(node_hvs))
print("acc: ", acc)

node 63017 prediction: 31, label: tensor([27])
node 2610 prediction: 5, label: tensor([34])
node 25226 prediction: 15, label: tensor([24])
node 162865 prediction: 21, label: tensor([36])
node 164848 prediction: 21, label: tensor([5])
node 65013 prediction: 12, label: tensor([9])
node 102745 prediction: 24, label: tensor([10])
node 110844 prediction: 38, label: tensor([26])
node 82473 prediction: 12, label: tensor([10])
node 47286 prediction: 12, label: tensor([28])
node 158584 prediction: 12, label: tensor([10])
node 58665 prediction: 21, label: tensor([34])
node 105943 prediction: 12, label: tensor([21])
node 52079 prediction: 12, label: tensor([2])
node 86256 prediction: 27, label: tensor([24])
node 17225 prediction: 21, label: tensor([24])
node 116970 prediction: 12, label: tensor([27])
node 111932 prediction: 12, label: tensor([24])
node 5059 prediction: 12, label: tensor([30])
node 122296 prediction: 21, label: tensor([26])
node 169080 prediction: 21, label: tensor([30])
node 1136