For our `ogbn-papers100M` experiments, we used a machine with 256 GB of memory. If you have enough memory, you can simply run `run_ogbn.py` for `ogbn-papers100M`. However, because of the dataset's huge size, we had to preprocess some data. We recommend running this Jupyter notebook before running `run_100m.py`.

In [None]:
import sys

# Make the project-local packages (e.g. `dataloaders`) importable; point this at your own checkout.
sys.path.extend(['/home/chendi/ibmb'])

In [None]:
import numpy as np
import torch
import pickle

## load original data

In [None]:
from ogb.nodeproppred import PygNodePropPredDataset

In [None]:
dataset = PygNodePropPredDataset(name="ogbn-papers100M", root='/nfs/students/qian')  # use your /path/to/data

# Official OGB split, converted from torch tensors to numpy index arrays.
splits = dataset.get_idx_split()
train_indices, val_indices, test_indices = (splits[key].numpy() for key in ('train', 'valid', 'test'))

# Cache the split so it can be reloaded from 'splits.pkl' later.
with open('splits.pkl', 'wb') as handle:
    pickle.dump(
        (train_indices, val_indices, test_indices),
        handle,
        protocol=pickle.HIGHEST_PROTOCOL,
    )

In [None]:
# Reload the cached (train, valid, test) index arrays.
with open('splits.pkl', 'rb') as handle:
    train_indices, val_indices, test_indices = pickle.load(handle)

In [None]:
graph = dataset[0]  # graph object at index 0; its `edge_index` and `num_nodes` are used below

## remove some currently unneeded data to free memory

In [None]:
# Keep only what the adjacency construction needs, then drop the dataset and
# graph objects so the rest of the data they hold may be freed.
num_nodes = graph.num_nodes
row, col = graph.edge_index[0], graph.edge_index[1]

del graph
del dataset

## process adj

In [None]:
from torch_sparse import SparseTensor
from dataloaders.BaseLoader import BaseLoader

# All-True edge values: the adjacency is stored as an unweighted boolean sparse matrix.
data = torch.ones_like(row, dtype=torch.bool)
adj = SparseTensor(
    row=row,
    col=col,
    value=data,
    sparse_sizes=(num_nodes,) * 2,
)

In [None]:
adj = adj + adj.t()  # symmetrize
adj = adj + SparseTensor.eye(num_nodes, dtype=torch.bool)  # add self-loops
# 'sym' normalization from BaseLoader (presumably D^-1/2 A D^-1/2 — confirm in BaseLoader.normalize_adjmat).
adj = BaseLoader.normalize_adjmat(adj, 'sym')

## save adj

In [None]:
from torch_sparse import SparseTensor

# Persist the normalized adjacency so it can be reloaded from disk in a later cell.
torch.save(adj, '/nfs/students/qian/adj.pt')  # use your /path/to/data

In [None]:
import inspect

# The adjacency was saved as a pickled torch_sparse ``SparseTensor`` object, not a
# plain tensor. PyTorch >= 2.6 defaults ``torch.load(weights_only=True)``, which
# refuses to unpickle such objects, so full unpickling is requested explicitly.
# This is acceptable only because the file was written by this notebook — never do
# it for untrusted files. PyTorch < 1.13 has no ``weights_only`` keyword and always
# unpickles fully, so the keyword is passed only when supported.
_load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
adj_t = torch.load('/nfs/students/qian/adj.pt', **_load_kwargs)  # use your /path/to/data

# CSR copy of the adjacency in SciPy format.
scipy_adj = adj_t.to_scipy('csr')

## calculate ppr matrices for train, val and test split

See https://github.com/TUM-DAML/pprgo_pytorch/blob/master/pprgo/ppr.py for the reference implementation of the method.

In [None]:
from dataloaders.utils import get_partitions, topk_ppr_matrix

In [None]:
# Hyper-parameters for `topk_ppr_matrix`. Presumably PPRGo-style: `topk` caps the PPR
# neighbors kept per node, `alpha` is the teleport probability and `eps` the
# push-approximation tolerance — confirm in dataloaders.utils.
topk, alpha, eps = 96, 0.05, 2e-5

### val

In [None]:
# Edge list of the normalized adjacency: row indices stacked on top of column indices.
edge_index = torch.stack([adj_t.storage.row(), adj_t.storage.col()], dim=0)

In [None]:
num_nodes = adj_t.size(0)  # the adjacency is square (num_nodes x num_nodes)

In [None]:
# Top-k PPR scores and neighbor lists for the validation nodes.
val_ppr_mat, val_neighbors = topk_ppr_matrix(
    edge_index, num_nodes, alpha, eps, val_indices, topk
)

# Cache the result so it can be reloaded without recomputation.
with open('papers100m_val_ppr.pkl', 'wb') as handle:
    pickle.dump(
        (val_ppr_mat, val_neighbors),
        handle,
        protocol=pickle.HIGHEST_PROTOCOL,
    )

### train

In [None]:
# Top-k PPR scores and neighbor lists for the training nodes.
train_ppr_mat, train_neighbors = topk_ppr_matrix(
    edge_index, num_nodes, alpha, eps, train_indices, topk
)

# Cache the result so it can be reloaded without recomputation.
with open('papers100m_train_ppr.pkl', 'wb') as handle:
    pickle.dump(
        (train_ppr_mat, train_neighbors),
        handle,
        protocol=pickle.HIGHEST_PROTOCOL,
    )

### test

In [None]:
# Top-k PPR scores and neighbor lists for the test nodes.
test_ppr_mat, test_neighbors = topk_ppr_matrix(
    edge_index, num_nodes, alpha, eps, test_indices, topk
)

# Cache the result so it can be reloaded without recomputation.
with open('papers100m_test_ppr.pkl', 'wb') as handle:
    pickle.dump(
        (test_ppr_mat, test_neighbors),
        handle,
        protocol=pickle.HIGHEST_PROTOCOL,
    )

In [None]:
# Reload the cached per-split PPR results, so the cells below can run in a fresh
# kernel without recomputing PPR.
with open('papers100m_val_ppr.pkl', 'rb') as handle:
    val_ppr_mat, val_neighbors = pickle.load(handle)
# The validation PPR cell names this `val_neighbors`, while the batching cells use
# `val_neighbor`; bind both so the notebook works whichever cells were run.
val_neighbor = val_neighbors

with open('papers100m_train_ppr.pkl', 'rb') as handle:
    train_ppr_mat, train_neighbors = pickle.load(handle)

with open('papers100m_test_ppr.pkl', 'rb') as handle:
    test_ppr_mat, test_neighbors = pickle.load(handle)

## Node-wise IBMB batches

In [None]:
# Size budget passed to `prime_ppr_loader` (train uses it as-is, val/test use twice this).
num_output_node_per_batch = 5_000

In [None]:
from dataloaders.IBMBNodeLoader import get_pairs, prime_orient_merge, prime_post_process

def prime_ppr_loader(ppr_matrix, output_indices, neighbors, num_aux_per_node):
    """Build node-wise IBMB batches from one split's top-k PPR results.

    Output nodes are grouped by ``prime_orient_merge`` / ``prime_post_process``
    (from ``dataloaders.IBMBNodeLoader``) using the PPR scores restricted to the
    split's own output nodes. Each group's auxiliary nodes are the union of its
    members' PPR neighbor arrays.

    Args:
        ppr_matrix: sparse PPR matrix of the split whose columns are indexed by
            global node id (it is sliced with ``output_indices`` below).
            Presumably one row per entry of ``output_indices`` — TODO confirm
            against ``topk_ppr_matrix``.
        output_indices: numpy array of global ids of the split's output nodes.
        neighbors: per-output-node arrays of neighbor ids, aligned with
            ``output_indices`` (list, object array, or 2-D array).
        num_aux_per_node: forwarded to ``prime_orient_merge`` and
            ``prime_post_process``.

    Returns:
        List of ``(output_node_ids, aux_node_ids)`` tuples, where
        ``aux_node_ids`` is a sorted, de-duplicated int64 array.
    """
    # Keep only the columns of this split's output nodes; the resulting groups
    # are positions into ``output_indices`` / ``neighbors``.
    ppr_matrix = ppr_matrix[:, output_indices]
    ppr_pairs = get_pairs(ppr_matrix)

    output_list = prime_orient_merge(ppr_pairs, num_aux_per_node, len(output_indices))
    output_list = prime_post_process(output_list, num_aux_per_node)

    # Index ``neighbors`` element by element instead of converting it with
    # ``np.array(..., dtype=object)``: when every neighbor array has the same
    # length (common with a fixed top-k), that conversion builds a 2-D object
    # array of per-element Python objects.
    node_wise_out_aux_pairs = []
    for group in output_list:
        group_neighbors = [neighbors[i] for i in group]
        if group_neighbors:
            aux = np.unique(np.concatenate(group_neighbors)).astype(np.int64)
        else:
            # np.concatenate cannot take an empty list.
            aux = np.empty(0, dtype=np.int64)
        node_wise_out_aux_pairs.append((output_indices[group], aux))
    return node_wise_out_aux_pairs

In [None]:
# Evaluation batches get twice the training budget.
val_loader = prime_ppr_loader(val_ppr_mat, val_indices, val_neighbor,
                              2 * num_output_node_per_batch)

In [None]:
# Evaluation batches get twice the training budget.
test_loader = prime_ppr_loader(test_ppr_mat, test_indices, test_neighbors,
                               2 * num_output_node_per_batch)

In [None]:
# Training batches use the base budget.
train_loader = prime_ppr_loader(train_ppr_mat, train_indices, train_neighbors,
                                num_output_node_per_batch)

In [None]:
# Persist the node-wise batches of every split, one pickle file per split.
for pickle_path, split_batches in (
    ('papers100m_train_ppr_batches.pkl', train_loader),
    ('papers100m_val_ppr_batches.pkl', val_loader),
    ('papers100m_test_ppr_batches.pkl', test_loader),
):
    with open(pickle_path, 'wb') as handle:
        pickle.dump(split_batches, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Batch-wise IBMB

This is tricky for `ogbn-papers100M` dataset. 

Because the dataset is quite large, METIS partitioning cannot be directly applied. 

For each split, e.g. the train split, we collect the nodes whose PPR score with respect to at least one node of the split exceeds a threshold, and take the subgraph they induce.

Then we do partitioning on the subgraph.

Finally, the primary (output) nodes that fall into each partition form one batch, and the batch's auxiliary nodes are the union of those nodes' top-k PPR neighbors.

In [None]:
from scipy.sparse import find

In [None]:
def partition_ppr_loader(partitions, prime_indices, neighbor_list):
    """Turn graph partitions into batch-wise IBMB batches.

    For each partition, the nodes of ``prime_indices`` that fall into it become
    the batch's output nodes, and the union of their neighbor arrays from
    ``neighbor_list`` becomes the batch's auxiliary nodes.

    Args:
        partitions: sequence of numpy arrays of global node ids, one per partition.
        prime_indices: numpy array of global ids of the split's output nodes.
        neighbor_list: per-node neighbor-id arrays aligned with ``prime_indices``
            (list, object array, or 2-D array).

    Returns:
        List of ``(output_ids, aux_ids)`` tuples; both arrays are sorted and
        de-duplicated. Partitions that contain no node of ``prime_indices`` are
        skipped: they would give a batch with nothing to predict, and
        ``np.concatenate`` cannot take an empty list.
    """
    batches = []
    for part in partitions:
        intersect = np.intersect1d(part, prime_indices)
        if intersect.size == 0:
            continue
        # Positions in ``prime_indices`` (hence in ``neighbor_list``) of the output
        # nodes inside this partition. ``np.isin`` replaces ``np.in1d``, which is
        # deprecated since NumPy 2.0; indexing element by element avoids building
        # an object array from ``neighbor_list``.
        positions = np.flatnonzero(np.isin(prime_indices, intersect))
        seconds = np.unique(np.concatenate([neighbor_list[i] for i in positions]))
        batches.append((intersect, seconds))
    return batches

In [None]:
# PPR score threshold for selecting the nodes of each split's subgraph, and the
# number of METIS partitions per split.
thresh = 5e-4
train_parts = 256
val_parts = 32
test_parts = 48

In [None]:
def _build_partition_batches(ppr_mat, prime_indices, neighbors, num_parts, naming):
    """Partition the subgraph relevant to one split and build its batches.

    This runs inside a function so that the large per-split intermediates (the
    COO triplets of ``ppr_mat`` and the induced adjacency) are released when it
    returns. Otherwise they would stay alive at notebook scope while the next
    split is processed. Reads the notebook globals ``thresh`` and ``adj_t``.
    """
    _, cols, vals = find(ppr_mat)
    # Global ids of the nodes with a PPR score above ``thresh`` in at least one
    # entry of this split's PPR matrix; METIS runs on the subgraph they induce.
    node_ids = np.unique(cols[vals > thresh])
    del cols, vals
    node_ids_t = torch.from_numpy(node_ids).long()

    sub_adj_t = adj_t[node_ids_t, :][:, node_ids_t]
    print(f'processed {naming} adj')

    # ``perm`` lists local node positions grouped by partition, with ``partptr``
    # delimiting the groups.
    _, partptr, perm = sub_adj_t.partition(num_parts=num_parts, recursive=False, weighted=False)
    del sub_adj_t
    print(f'partitioned {naming} adj')

    bounds = partptr.tolist()
    partitions = [node_ids[perm[start:stop].numpy()] for start, stop in zip(bounds[:-1], bounds[1:])]
    print(f'obtained {naming} partitions')

    batches = partition_ppr_loader(partitions, prime_indices, neighbors)
    print(f'obtained {naming} batches')
    return batches


for indices, mat, neighbor, num_parts, naming in zip([train_indices, val_indices, test_indices],
                                                     [train_ppr_mat, val_ppr_mat, test_ppr_mat],
                                                     [train_neighbors, val_neighbor, test_neighbors],
                                                     [train_parts, val_parts, test_parts],
                                                     ['train', 'val', 'test']):
    batches = _build_partition_batches(mat, indices, neighbor, num_parts, naming)

    with open(f'papers100m_{naming}_part_batches.pkl', 'wb') as handle:
        pickle.dump(batches, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print(f'saved {naming} batches')

### visualize weight distribution

In [None]:
# Only the non-zero PPR scores are needed for the histogram; don't keep the
# row/col index arrays (one entry per non-zero) alive at notebook scope.
val = find(train_ppr_mat)[2]

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

f, ax = plt.subplots(figsize=(7, 7))
# Histogram of the training PPR scores on a log x-axis. ``log_scale=True`` makes
# the 50 bins log-spaced and sets the axis scale. Linear bins drawn on a
# log-scaled axis would put almost all scores into the leftmost bin(s).
sns.histplot(val, ax=ax, bins=50, log_scale=True)