In [2]:
import multiprocessing as mp
import os.path as osp
from multiprocessing import cpu_count
from pathlib import Path
from os import path
from zipfile import ZipFile

import kuzu
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from tqdm import tqdm
from torch_geometric.loader import NeighborLoader

from .src.modeling.graph_sage import GraphSAGE

In [4]:
DATASET_PATH = Path("dataset/ogbn_papers100M")
RAW_PATH = DATASET_PATH / "raw"
PROCESSED_PATH = DATASET_PATH / "processed"

In [5]:
with ZipFile(RAW_PATH / "data.npz", 'r') as data_zip:
    print('Extracting data.npz...')
    data_zip.extractall()

with ZipFile(RAW_PATH / "node-label.npz", 'r') as node_label_zip:
    print('Extracting node-label.npz...')
    node_label_zip.extractall()

Extracting data.npz...
Extracting node-label.npz...


In [None]:
edge_index = np.load(RAW_PATH / 'edge_index.npy', mmap_mode='r')
csvfile = open(RAW_PATH / 'edge_index.csv', 'w')
csvfile.write('src,dst\n')
for i in tqdm(range(edge_index.shape[1])):
    csvfile.write(str(edge_index[0, i]) + ',' + str(edge_index[1, i]) + '\n')
csvfile.close()

 41%|████▏     | 669821815/1615685872 [22:05<1:05:30, 240643.16it/s]

In [None]:
ids_path = RAW_PATH / 'ids.npy'
edge_index_path = RAW_PATH / 'edge_index.csv'
node_label_path = RAW_PATH / 'node_label.npy'
node_feature_path = RAW_PATH / 'node_feat.npy'
node_year_path = RAW_PATH / 'node_year.npy'

In [None]:
db = kuzu.Database('Papers100M')
conn = kuzu.Connection(db, num_threads=32)

In [None]:
conn.execute(
    "CREATE NODE TABLE paper(id INT64, x FLOAT[128], year INT64, y FLOAT, "
    "PRIMARY KEY (id));")
conn.execute("CREATE REL TABLE cites(FROM paper TO paper, MANY_MANY);")
conn.execute('COPY paper FROM ("%s",  "%s",  "%s", "%s") BY COLUMN;' %
             (ids_path, node_feature_path, node_year_path, node_label_path))
conn.execute('COPY cites FROM "%s";' % (edge_index_path))

In [None]:
NUM_EPOCHS = 1
LOADER_BATCH_SIZE = 1024

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the train set:
train_path = osp.join('.', 'papers100M-bin', 'split', 'time', 'train.csv.gz')
train_df = pd.read_csv(
    osp.abspath(train_path),
    compression='gzip',
    header=None,
)
input_nodes = torch.tensor(train_df[0].values, dtype=torch.long)

In [None]:
########################################################################
# The below code sets up the remote backend of Kùzu for PyG.
# Please refer to: https://kuzudb.com/docs/client-apis/python-api/overview.html
# for how to use the Python API of Kùzu.
########################################################################

# The buffer pool size of Kùzu is set to 40GB. You can change it to a smaller
# value if you have less memory.
KUZU_BM_SIZE = 40 * 1024**3

In [None]:
# Create Kùzu database:
db = kuzu.Database(osp.abspath(osp.join('.', 'papers100M')), KUZU_BM_SIZE)

# Get remote backend for PyG:
feature_store, graph_store = db.get_torch_geometric_remote_backend(
    mp.cpu_count())

In [None]:
# Plug the graph store and feature store into the `NeighborLoader`.
# Note that `filter_per_worker` is set to `False`. This is because the Kùzu
# database is already using multi-threading to scan the features in parallel
# and the database object is not fork-safe.
loader = NeighborLoader(
    data=(feature_store, graph_store),
    num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
    batch_size=LOADER_BATCH_SIZE,
    input_nodes=('paper', input_nodes),
    num_workers=4,
    filter_per_worker=False,
)

In [None]:
model = GraphSAGE(in_channels=128, hidden_channels=1024, out_channels=172,
                  num_layers=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [None]:
for epoch in range(1, NUM_EPOCHS + 1):
    total_loss = total_examples = 0
    for batch in tqdm(loader):
        batch = batch.to(device)
        batch_size = batch['paper'].batch_size

        optimizer.zero_grad()
        out = model(
            batch['paper'].x,
            batch['paper', 'cites', 'paper'].edge_index,
        )[:batch_size]
        y = batch['paper'].y[:batch_size].long().view(-1)
        loss = F.cross_entropy(out, y)

        loss.backward()
        optimizer.step()

        total_loss += float(loss) * y.numel()
        total_examples += y.numel()

    print(f'Epoch: {epoch:02d}, Loss: {total_loss / total_examples:.4f}')