In [15]:
# get obgn-arxiv dataset
# https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv
from ogb.nodeproppred import PygNodePropPredDataset

dataset = PygNodePropPredDataset(name='ogbn-arxiv')

In [16]:
dataGraph = dataset[0]

print(f"number of graphs: {len(dataset)}")
print(f"number of features: {dataset.num_features}")
print(f"number of classes: {dataset.num_classes}")
print(f"number of nodes: {dataGraph.num_nodes}")
print(f"number of edges: {dataGraph.num_edges}")

number of graphs: 1
number of features: 128
number of classes: 40
number of nodes: 169343
number of edges: 1166243


In [80]:
# HDC stuff
from hdc.bsc import BSC as HDC
from hdc.hv import HDCRepresentation
from hdc.itemmem import ItemMem
from hdc.encoder import Encoder
from hdc.levelencoder import LevelEncoder
from tqdm import tqdm
from typing import Type
from hdc.itemmem import HighResItemMem

hdc = HDC
N = 10240
encoder = LevelEncoder(hdc, N, -1, 1, 0.05)

num_intervals: 40
num_bits: 256


In [81]:
id_hvs = [hdc.random_hypervector(N) for _ in range(dataGraph.num_features)]

In [82]:
from torch_geometric.utils import k_hop_subgraph
import networkx as nx
import numpy as np
import torch

def extract_subgraph(dataGraph, node_idx, depth):
    # get the subgraph around the node
    # node_idx: index of the node
    # depth: depth of the subgraph
    # return: subgraph
    # the subgraph is all nodes and edges within depth of the node
    # depth = 0 means only the node itself
    # depth = 1 means the node and its neighbors
    # depth = 2 means the node and its neighbors and their neighbors
    # etc.
    subset, edge_index, _, _ = k_hop_subgraph(node_idx, depth, dataGraph.edge_index)
    return subset, edge_index

# return node indices of neighbors where there is an edge from the node to the neighbor
def get_neighbors_from_src(dataGraph, node_idx):
    # get the neighbors of the node
    # node_idx: index of the node
    # return: neighbors
    # neighbors are all nodes adjacent to the node
    neighbors = dataGraph.edge_index[1][dataGraph.edge_index[0] == node_idx]
    return neighbors

def get_neighbors_from_dst(dataGraph, node_idx):
    # get the neighbors of the node
    # node_idx: index of the node
    # return: neighbors
    # neighbors are all nodes adjacent to the node
    neighbors = dataGraph.edge_index[0][dataGraph.edge_index[1] == node_idx]
    return neighbors

def get_neighbors(dataGraph, node_idx, depth, pr, sample_size=100):
    # get the neighbors of the node
    # node_idx: index of the node
    # depth: depth of the subgraph
    # return: neighbors
    # neighbors are all nodes within depth of the node
    # depth = 0 means no neighbors
    # depth = 1 means only the neighbors of the node
    # depth = 2 means the neighbors of the node and their neighbors
    # etc.
    # BFS to get the neighbors
    # pr: the PageRank vector
    # sample_size: the number of neighbors to sample
    neighbors = [] # list of neighbors in tuples (node indices, distance, pr_val); distance is 1 for adjacent nodes -1 for nodes pointing to the node
    visited = set()
    queue = []
    min_pr = -float('inf')
    queue.append((node_idx, 0))
    visited.add(node_idx)
    while queue:
        node_idx, distance = queue.pop(0)
        neighbors.append((node_idx, distance))
        if pr[node_idx] > min_pr:
            min_pr = pr[node_idx]
        if distance >= 0 and distance < depth:
            # get neighbors of node_idx
            neighbors_of_node = get_neighbors_from_src(dataGraph, node_idx)
            for neighbor in neighbors_of_node:
                if neighbor not in visited:
                    if pr[neighbor] > min_pr:
                        queue.append((neighbor.item(), distance+1))
                    visited.add(neighbor)
        if distance <= 0 and distance > -depth:
            # get neighbors of node_idx
            neighbors_of_node = get_neighbors_from_dst(dataGraph, node_idx)
            for neighbor in neighbors_of_node:
                if neighbor not in visited:
                    if pr[neighbor] > min_pr:
                        queue.append((neighbor.item(), distance-1))
                    visited.add(neighbor)
    # sort by page rank
    neighbors.sort(key=lambda x: pr[x[0]], reverse=True)
    # sample neighbors
    if len(neighbors) > sample_size:
        neighbors = neighbors[:sample_size]
    return neighbors

# test get_neighbors
# idx0_neighbors = get_neighbors(dataGraph, 0, 1, )
# print(idx0_neighbors)
# print(len(idx0_neighbors))

In [83]:
def encode_instance(features):
    feature_hvs = []
    for i in range(len(features)):
        f = hdc.bind([id_hvs[i], encoder.encode(features[i])])
        feature_hvs.append(f)
    hv = hdc.bundle(feature_hvs)
    return hv

def encode_raw_nodes(dataGraph):
    node_hvs = []
    for i in tqdm(range(dataGraph.num_nodes)):
        hv = encode_instance(dataGraph.x[i])
        node_hvs.append(hv)
    return node_hvs

In [84]:
def encode_edges(node_hvs, edge_index):
    edge_hvs = []
    for i in range(edge_index.shape[1]):
        src = edge_index[0][i]
        dst = edge_index[1][i]
        hv = hdc.bind([node_hvs[src], node_hvs[dst]])
        edge_hvs.append(hv)
    return edge_hvs

def encode_subgraph(dataGraph, node_hvs, node_idx, depth, pr):
    # get the subgraph around the node
    # node_idx: index of the node
    # depth: depth of the subgraph
    # return: subgraph
    # the subgraph is all nodes and edges within depth of the node
    # depth = 0 means only the node itself
    # depth = 1 means the node and its neighbors
    # depth = 2 means the node and its neighbors and their neighbors
    # etc.
    nodes = get_neighbors(dataGraph, node_idx, depth, pr)
    hvs_to_bundle = []
    for node, distance in nodes:
        hv = node_hvs[node]
        hvs_to_bundle.append(hdc.permute(hv, distance))
    hv = hdc.bundle(hvs_to_bundle)
    return hv

In [85]:
# learning
# for each node, encode the subgraph around it
def learn(node_hvs, dataGraph, depth, pr, lr = 0.1, exisiting_mem = None):
    itemmem = HighResItemMem(hdc, lr) if exisiting_mem is None else exisiting_mem
    correct = 0
    total = 0
    accuracy = 0
    with tqdm(total=dataGraph.num_nodes, desc="Learning", unit="node") as pbar:
        for i in range(dataGraph.num_nodes):
            label = dataGraph.y[i].item()
            hv = encode_subgraph(dataGraph, node_hvs, i, depth, pr)
            if len(hv) == N:
                # Test if the item is already in the mem
                itemmem.build()
                if label in itemmem.caches.keys():
                    pred = itemmem.query(hv)
                    if not pred or pred != label:
                        # If mispredicted, add to label cache
                        itemmem.cache(label, hv)
                        # subtract from the mispredicted label cache
                        itemmem.decache(pred, hv)
                    else:
                        correct += 1
                else:
                    # never seen this label before, add to label cache
                    itemmem.cache(label, hv)
                total += 1
                accuracy = correct / total
            pbar.set_postfix({"accuracy": accuracy})
            pbar.update(1)
    itemmem.build()
    return itemmem

In [86]:
def rand_test(node_hvs, dataGraph, itemmem, depth, pr):
    rand_indices = np.random.choice(dataGraph.num_nodes, 1000, replace=False)
    correct = 0
    skipped = 0
    for i in tqdm(rand_indices):
        label = dataGraph.y[i].item()
        hv = encode_subgraph(dataGraph, node_hvs, i, depth, pr)
        pred = itemmem.query(hv)
        if pred == label:
            correct += 1
    return correct / (len(rand_indices) - skipped)

In [87]:
node_hvs = encode_raw_nodes(dataGraph)

 44%|████▍     | 75187/169343 [03:47<04:36, 339.99it/s]

In [53]:
# calculate PageRank
from torch_geometric.utils import to_networkx
import networkx as nx

# convert to networkx graph
G = to_networkx(dataGraph)
# calculate PageRank
pr = nx.pagerank(G)

In [None]:
mem1 = learn(node_hvs, dataGraph, 1, pr, lr=0.1)

100%|██████████| 169343/169343 [13:33<00:00, 208.28it/s]


In [70]:
mem2 = learn(node_hvs, dataGraph, 1, pr, 0.01)

Learning: 100%|██████████| 169343/169343 [15:17<00:00, 184.49node/s, accuracy=0.402]


In [76]:
mem3 = learn(node_hvs, dataGraph, 2, pr, 0.1)

Learning:   1%|          | 1385/169343 [00:31<1:04:21, 43.50node/s, accuracy=0.108] 


KeyboardInterrupt: 

In [48]:
mem4 = learn(node_hvs, dataGraph, 1, 0.01, mem2)

100%|██████████| 169343/169343 [13:45<00:00, 205.11it/s]


In [51]:
mem5 = learn(node_hvs, dataGraph, 1, 0.001)

100%|██████████| 169343/169343 [14:55<00:00, 189.10it/s]


In [41]:
acc1 = rand_test(node_hvs, dataGraph, mem1, 1)

100%|██████████| 1000/1000 [00:04<00:00, 223.39it/s]


In [73]:
acc2 = rand_test(node_hvs, dataGraph, mem2, 1, pr)

100%|██████████| 1000/1000 [00:03<00:00, 308.99it/s]


In [42]:
acc3 = rand_test(node_hvs, dataGraph, mem3, 2)

100%|██████████| 1000/1000 [00:19<00:00, 51.50it/s]


In [49]:
acc4 = rand_test(node_hvs, dataGraph, mem4, 1)

100%|██████████| 1000/1000 [00:04<00:00, 237.48it/s]


In [54]:
acc5 = rand_test(node_hvs, dataGraph, mem5, 1)

100%|██████████| 1000/1000 [00:04<00:00, 220.35it/s]


In [74]:
print(acc2)

0.504


In [43]:
print(acc3)

0.314
