In [13]:
# get obgn-arxiv dataset
# https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv
from ogb.nodeproppred import PygNodePropPredDataset

dataset = PygNodePropPredDataset(name='ogbn-arxiv')

In [14]:
dataGraph = dataset[0]

print(f"number of graphs: {len(dataset)}")
print(f"number of features: {dataset.num_features}")
print(f"number of classes: {dataset.num_classes}")
print(f"number of nodes: {dataGraph.num_nodes}")
print(f"number of edges: {dataGraph.num_edges}")

number of graphs: 1
number of features: 128
number of classes: 40
number of nodes: 169343
number of edges: 1166243


In [15]:
# HDC stuff
from hdc.bsc import BSC as HDC
from hdc.hv import HDCRepresentation
from hdc.itemmem import ItemMem
from hdc.encoder import Encoder
from hdc.levelencoder import LevelEncoder
from tqdm import tqdm
from typing import Type

hdc = HDC
N = 10240
encoder = LevelEncoder(hdc, N, -1, 1, 0.05)

num_intervals: 40
num_bits: 256


In [16]:
class HighResItemMem(ItemMem):
    def __init__(self, hdc: type[HDCRepresentation], lr = 1) -> None:
        super().__init__(hdc)
        self.lr = lr
    
    def cache(self, key, hv):
        key = int(key)
        if key not in self.caches:
            self.caches[key] = [np.zeros(len(hv)), 0]
        if key not in self.mem:
            sim = 0
        else:
            sim = self.hdc.dist(self.mem[key], hv)
        self.caches[key][0] += hv * (1 - sim) * self.lr
        self.caches[key][1] +=(1 - sim) * self.lr
        
    def decache(self, key, hv):
        key = int(key)
        if key not in self.caches:
            self.caches[key] = [np.zeros(len(hv)), 0]
        if key not in self.mem:
            sim = 0
        else:
            sim = self.hdc.dist(self.mem[key], hv)
        self.caches[key][0] -= hv * (1 - sim) * self.lr
        self.caches[key][1] -= (1 - sim) * self.lr
    
    def build(self):
        for key, cache_line in self.caches.items():
            new_mem = self.hdc.normalize(cache_line[0] / cache_line[1])
            self.mem[key] = new_mem if cache_line[1] > 0 else np.logical_not(new_mem)


In [17]:
id_hvs = [hdc.random_hypervector(N) for _ in range(dataGraph.num_features)]

In [18]:
from torch_geometric.utils import k_hop_subgraph
import networkx as nx
import numpy as np
import torch

def extract_subgraph(dataGraph, node_idx, depth):
    # get the subgraph around the node
    # node_idx: index of the node
    # depth: depth of the subgraph
    # return: subgraph
    # the subgraph is all nodes and edges within depth of the node
    # depth = 0 means only the node itself
    # depth = 1 means the node and its neighbors
    # depth = 2 means the node and its neighbors and their neighbors
    # etc.
    subset, edge_index, _, _ = k_hop_subgraph(node_idx, depth, dataGraph.edge_index)
    return subset, edge_index

# get the subgraph around the node
node_idx = 0
depth = 1
subnodes, edge_index = extract_subgraph(dataGraph, node_idx, depth)
print(subnodes.shape)
print(edge_index.shape)

torch.Size([290])
torch.Size([2, 1672])


In [19]:
def encode_instance(features):
    feature_hvs = []
    for i in range(len(features)):
        f = hdc.bind([id_hvs[i], encoder.encode(features[i])])
        feature_hvs.append(f)
    hv = hdc.bundle(feature_hvs)
    return hv

def encode_raw_nodes(dataGraph):
    node_hvs = []
    for i in tqdm(range(dataGraph.num_nodes)):
        hv = encode_instance(dataGraph.x[i])
        node_hvs.append(hv)
    return node_hvs

In [20]:
def encode_edges(node_hvs, edge_index):
    edge_hvs = []
    for i in range(edge_index.shape[1]):
        src = edge_index[0][i]
        dst = edge_index[1][i]
        hv = hdc.bind([node_hvs[src], node_hvs[dst]])
        edge_hvs.append(hv)
    return edge_hvs

In [47]:
# learning
# for each node, encode the subgraph around it
def learn(node_hvs, dataGraph, depth, lr = 0.1, exisiting_mem = None):
    itemmem = HighResItemMem(hdc, lr) if exisiting_mem is None else exisiting_mem
    for i in tqdm(range(dataGraph.num_nodes)):
        label = dataGraph.y[i].item()
        _, edge_index = extract_subgraph(dataGraph, i, depth)
        if edge_index.shape[1] == 0:
            continue
        edge_hvs = encode_edges(node_hvs, edge_index)
        hv = hdc.bundle(edge_hvs)
        if len(hv) == N:
            # Test if the item is already in the mem
            itemmem.build()
            if label in itemmem.caches.keys():
                pred = itemmem.query(hv)
                if not pred or pred != label:
                    # If mispredicted, add to label cache
                    itemmem.cache(label, hv)
                    # subtract from the mispredicted label cache
                    itemmem.decache(pred, hv)
                # if correct, do nothing
            else:
                # never seen this label before, add to label cache
                itemmem.cache(label, hv)
    itemmem.build()
    return itemmem

In [22]:
def rand_test(node_hvs, dataGraph, itemmem, depth):
    rand_indices = np.random.choice(dataGraph.num_nodes, 1000, replace=False)
    correct = 0
    skipped = 0
    for i in tqdm(rand_indices):
        label = dataGraph.y[i].item()
        _, edge_index = extract_subgraph(dataGraph, int(i), depth)
        if edge_index.shape[1] == 0:
            skipped += 1
            continue
        edge_hvs = encode_edges(node_hvs, edge_index)
        hv = hdc.bundle(edge_hvs)
        pred = itemmem.query(hv)
        if pred == label:
            correct += 1
    return correct / (len(rand_indices) - skipped)

In [23]:
node_hvs = encode_raw_nodes(dataGraph)

  0%|          | 0/169343 [00:00<?, ?it/s]

100%|██████████| 169343/169343 [08:23<00:00, 336.05it/s]


In [28]:
mem1 = learn(node_hvs, dataGraph, 1)

100%|██████████| 169343/169343 [13:33<00:00, 208.28it/s]


In [34]:
mem2 = learn(node_hvs, dataGraph, 1, 0.01)

100%|██████████| 169343/169343 [13:52<00:00, 203.50it/s]


In [40]:
mem3 = learn(node_hvs, dataGraph, 1, 0.2)

100%|██████████| 169343/169343 [13:54<00:00, 202.81it/s]


In [48]:
mem4 = learn(node_hvs, dataGraph, 1, 0.01, mem2)

100%|██████████| 169343/169343 [13:45<00:00, 205.11it/s]


In [51]:
mem5 = learn(node_hvs, dataGraph, 1, 0.001)

100%|██████████| 169343/169343 [14:55<00:00, 189.10it/s]


In [41]:
acc1 = rand_test(node_hvs, dataGraph, mem1, 1)

100%|██████████| 1000/1000 [00:04<00:00, 223.39it/s]


In [42]:
acc2 = rand_test(node_hvs, dataGraph, mem2, 1)

100%|██████████| 1000/1000 [00:04<00:00, 236.87it/s]


In [43]:
acc3 = rand_test(node_hvs, dataGraph, mem3, 1)

100%|██████████| 1000/1000 [00:05<00:00, 194.76it/s]


In [49]:
acc4 = rand_test(node_hvs, dataGraph, mem4, 1)

100%|██████████| 1000/1000 [00:04<00:00, 237.48it/s]


In [54]:
acc5 = rand_test(node_hvs, dataGraph, mem5, 1)

100%|██████████| 1000/1000 [00:04<00:00, 220.35it/s]


In [56]:
print(acc1)
print(acc2)
print(acc3)
print(acc4)
print(acc5)

0.43381180223285487
0.4780564263322884
0.43010752688172044
0.4534351145038168
0.2229299363057325
