In [1]:
!pip install -U -q torch-geometric==2.5.3

In [2]:
!pip install -U -q ogb==1.3.6 # graph benchmark

In [4]:
import os
import math
import pickle
import torch
import pandas as pd
import networkx as nx
from tqdm import tqdm
from torch_geometric.seed import seed_everything
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR
from torch_geometric.utils import train_test_split_edges, k_hop_subgraph, negative_sampling, to_undirected, is_undirected, to_networkx
from ogb.linkproppred import PygLinkPropPredDataset

# Define the utility functions

In [3]:
import numpy as np
import torch
import networkx as nx


def get_node_edge(graph):
    '''Return the node of ``graph`` that has the largest degree.

    ``graph.degree`` is iterated as ``(node, degree)`` pairs. The pairs are
    ordered by ascending degree with a stable sort, so when several nodes
    share the largest degree the one yielded last by ``graph.degree`` wins.
    '''
    ranked = list(graph.degree)
    ranked.sort(key=lambda pair: pair[1])

    # The last entry of the ascending ranking carries the largest degree.
    return ranked[-1][0]

def h_hop_neighbor(G, node, h):
    '''Return the nodes of ``G`` whose shortest-path length from ``node`` is exactly ``h``.

    Args:
        G: networkx graph to search.
        node: source node.
        h: distance to select.

    Returns:
        list of nodes at distance ``h`` from ``node``.

    NOTE(review): the distance comes from Dijkstra, so an edge ``weight``
    attribute (if the graph has one) is taken into account; it equals the hop
    count only on unweighted graphs.
    '''
    # cutoff=h stops the search once distances exceed h, instead of exploring
    # every node reachable from the source; nodes at distance == h are still
    # reported, so the result is unchanged.
    path_lengths = nx.single_source_dijkstra_path_length(G, node, cutoff=h)
    return [neighbor for neighbor, length in path_lengths.items() if length == h]
                    
def get_enclosing_subgraph(graph, edge_to_delete, num_hops=2):
    '''Collect the edges of the growing neighborhoods around an edge.

    Args:
        graph: networkx graph containing the edge.
        edge_to_delete: ``(s, t)`` pair of endpoint nodes.
        num_hops: largest neighborhood radius to build (default 2, the value
            that used to be hard-coded).

    Returns:
        dict mapping ``0`` to ``[edge_to_delete]`` and every ``h`` in
        ``1..num_hops`` to the list of edges of the subgraph induced by the
        endpoints together with all nodes within ``h`` of either endpoint.
    '''
    subgraph = {0: [edge_to_delete]}
    s, t = edge_to_delete

    # Accumulated across iterations so that level h contains every node at
    # distance 1..h, not only the nodes at distance exactly h.
    neighbor_s = []
    neighbor_t = []
    for h in range(1, num_hops + 1):
        neighbor_s += h_hop_neighbor(graph, s, h)
        neighbor_t += h_hop_neighbor(graph, t, h)

        nodes = neighbor_s + neighbor_t + [s, t]

        subgraph[h] = list(graph.subgraph(nodes).edges())

    return subgraph

@torch.no_grad()
def get_link_labels(pos_edge_index, neg_edge_index):
    '''Build the binary target vector for a batch of positive and negative edges.

    Both arguments are only inspected for their number of columns (edges).
    The result is a 1-D float tensor on the device of ``pos_edge_index``:
    ones for the positive edges followed by zeros for the negative edges.
    '''
    num_pos = pos_edge_index.size(1)
    num_neg = neg_edge_index.size(1)
    target_device = pos_edge_index.device

    positives = torch.ones(num_pos, dtype=torch.float, device=target_device)
    negatives = torch.zeros(num_neg, dtype=torch.float, device=target_device)
    return torch.cat([positives, negatives])

@torch.no_grad()
def get_link_labels_kg(pos_edge_index, neg_edge_index):
    '''Build the binary target vector for knowledge-graph link prediction.

    Returns a 1-D float tensor on the device of ``pos_edge_index`` holding
    ones for the positive edges followed by zeros for the negative edges.
    '''
    # The labelling is identical to the plain-graph case, so reuse it instead
    # of keeping a second copy of the same code that could drift.
    return get_link_labels(pos_edge_index, neg_edge_index)

@torch.no_grad()
def negative_sampling_kg(edge_index, edge_type):
    '''Create negative edges by shuffling source nodes within each edge type.

    For every distinct value in ``edge_type`` the source nodes (row 0) of the
    edges carrying that type are randomly permuted among themselves. Target
    nodes (row 1) are left as they are, and ``edge_index`` itself is not
    modified: the corrupted edges are returned as a new tensor.
    '''
    corrupted = edge_index.clone()

    for relation in edge_type.unique():
        same_relation = edge_type == relation
        sources = corrupted[0, same_relation]
        # Write the sources back in a random order; edges of other types are untouched.
        corrupted[0, same_relation] = sources[torch.randperm(sources.shape[0])]

    return corrupted

# Define how to split the data into training, validation, and test sets

In [5]:
# Root directory for downloaded datasets and the generated split files.
data_dir = './data'

# Candidate Df sizes: 0.00-0.09, then 0.0-0.9, then 0-9.
# NOTE(review): the value 0 appears three times (once per sub-list) — confirm
# whether range(1, 10) was intended for some of them.
df_size = [i / 100 for i in range(10)] + [i / 10 for i in range(10)] + [i for i in range(10)]       # Df_size in percentage
# Random seeds; one split is generated per seed.
seeds = [42, 21, 13, 87, 100]
# The trailing [4:] keeps only 'ogbl-citation2' and 'ogbl-collab' from the full list.
graph_datasets = ['Cora', 'PubMed', 'DBLP', 'CS', 'ogbl-citation2', 'ogbl-collab'][4:]
# The trailing [-1:] keeps only 'ogbl-biokg' from the full list.
kg_datasets = ['FB15k-237', 'WordNet18', 'WordNet18RR', 'ogbl-biokg'][-1:]
# Create the data directory up front; exist_ok makes re-running the cell harmless.
os.makedirs(data_dir, exist_ok=True)


# Presumably the number of edge (relation) types per knowledge-graph dataset.
# NOTE(review): 'ogbl-biokg', the only entry left in kg_datasets, has no entry here.
num_edge_type_mapping = {
    'FB15k-237': 237,
    'WordNet18': 18,
    'WordNet18RR': 11
}

def train_test_split_edges_no_neg_adj_mask(data, val_ratio: float = 0.05, test_ratio: float = 0.1, two_hop_degree=None, kg=False, low_degree_threshold: int = 50):
    '''Split the edges of ``data`` into train / validation / test sets without adding ``neg_adj_mask``.

    ``data`` is modified in place and returned. Its ``edge_index``,
    ``edge_attr``, ``edge_weight``, ``edge_year`` and ``edge_type`` are
    cleared and replaced by:

    * ``train_pos_edge_index`` (plus ``train_pos_edge_attr`` when the input
      has ``edge_attr`` and ``kg`` is False),
    * ``val_pos_edge_index`` / ``val_neg_edge_index``,
    * ``test_pos_edge_index`` / ``test_neg_edge_index``,
    * for ``kg=True`` additionally ``edge_index`` / ``edge_type`` (the
      training edges, for message passing) and ``train_edge_type``,
      ``val_edge_type``, ``test_edge_type``.

    Args:
        data: graph object exposing ``edge_index``, ``edge_attr``,
            ``num_nodes`` and, for ``kg=True``, ``edge_type``.
        val_ratio: fraction of the (deduplicated) edges used for validation.
        test_ratio: fraction of the (deduplicated) edges used for testing.
        two_hop_degree: optional 1-D tensor with one value per edge, in the
            order of the edges kept after the ``row < col`` filter. When
            given, edges whose value is below ``low_degree_threshold`` are
            placed first, so the test set (then the validation set) is filled
            with them before any other edge.
        kg: treat the graph as a knowledge graph: keep every edge (no
            ``row < col`` filter), carry ``edge_type`` along, and sample
            negatives with ``negative_sampling_kg``.
        low_degree_threshold: cut-off applied to ``two_hop_degree``.

    Returns:
        The same ``data`` object.

    Raises:
        ValueError: if ``two_hop_degree`` does not hold one value per edge.
    '''

    # Read before the edge attributes are cleared, so the negative sampling
    # below always draws from the full node range of the input graph.
    num_nodes = data.num_nodes
    row, col = data.edge_index
    edge_attr = data.edge_attr
    edge_type = data.edge_type if kg else None
    data.edge_index = data.edge_attr = data.edge_weight = data.edge_year = data.edge_type = None

    if not kg:
        # Return upper triangular portion.
        mask = row < col
        row, col = row[mask], col[mask]

        if edge_attr is not None:
            edge_attr = edge_attr[mask]

    num_edges = row.size(0)
    n_v = int(math.floor(val_ratio * num_edges))
    n_t = int(math.floor(test_ratio * num_edges))

    if two_hop_degree is not None:          # Use low degree edges for test sets
        low_degree_mask = two_hop_degree < low_degree_threshold
        if low_degree_mask.numel() != num_edges:
            raise ValueError(
                f'two_hop_degree has {low_degree_mask.numel()} entries but there are {num_edges} edges to split')

        # flatten() keeps a 1-D index tensor even when exactly one edge
        # matches (squeeze() would collapse it to 0-dim and break size(0)).
        low = low_degree_mask.nonzero().flatten()
        high = (~low_degree_mask).nonzero().flatten()

        low = low[torch.randperm(low.size(0))]
        high = high[torch.randperm(high.size(0))]

        perm = torch.cat([low, high])

    else:
        perm = torch.randperm(num_edges)

    row = row[perm]
    col = col[perm]

    # Per-edge data must follow the same shuffle as the edges themselves,
    # otherwise the slices taken below belong to different edges.
    if kg:
        edge_type = edge_type[perm]
    elif edge_attr is not None:
        edge_attr = edge_attr[perm]

    # Train
    train_start = n_v + n_t
    r, c = row[train_start:], col[train_start:]

    if kg:
        train_edge_type = edge_type[train_start:]

        # Edges and types used for message passing. Reverse edges are not
        # added here; only the original direction is stored.
        data.edge_index = torch.stack([r, c], dim=0)
        data.edge_type = train_edge_type

        # Edges and types used for decoding (one direction only).
        data.train_pos_edge_index = torch.stack([r, c], dim=0)
        data.train_edge_type = train_edge_type

    else:
        # The training edges are kept directed (row < col), i.e. they are not
        # symmetrized with to_undirected; the attributes are sliced to match.
        data.train_pos_edge_index = torch.stack([r, c], dim=0)
        if edge_attr is not None:
            data.train_pos_edge_attr = edge_attr[train_start:]

        assert not is_undirected(data.train_pos_edge_index)

    # Test
    r, c = row[:n_t], col[:n_t]
    data.test_pos_edge_index = torch.stack([r, c], dim=0)

    if kg:
        data.test_edge_type = edge_type[:n_t]
        neg_edge_index = negative_sampling_kg(
            edge_index=data.test_pos_edge_index,
            edge_type=data.test_edge_type)
    else:
        neg_edge_index = negative_sampling(
            edge_index=data.test_pos_edge_index,
            num_nodes=num_nodes,
            num_neg_samples=data.test_pos_edge_index.shape[1])

    data.test_neg_edge_index = neg_edge_index

    # Valid
    r, c = row[n_t:n_t + n_v], col[n_t:n_t + n_v]
    data.val_pos_edge_index = torch.stack([r, c], dim=0)

    if kg:
        data.val_edge_type = edge_type[n_t:n_t + n_v]
        neg_edge_index = negative_sampling_kg(
            edge_index=data.val_pos_edge_index,
            edge_type=data.val_edge_type)
    else:
        neg_edge_index = negative_sampling(
            edge_index=data.val_pos_edge_index,
            num_nodes=num_nodes,
            num_neg_samples=data.val_pos_edge_index.shape[1])

    data.val_neg_edge_index = neg_edge_index

    return data

# Pre-process graph data

In [6]:
def _edge_two_hop_degree(data):
    '''Return, per edge with ``row < col``, the size of the union of both endpoints' two-hop neighborhoods.'''
    graph = to_networkx(data)

    # Nodes within two hops of every node (as reported by graph.neighbors).
    node_to_neighbors = {}
    for n in tqdm(graph.nodes(), desc='Two hop neighbors'):
        neighbor_1 = set(graph.neighbors(n))
        # Grow one set in place instead of concatenating lists with
        # sum(..., []), which copies the accumulated list at every step.
        neighbor = set(neighbor_1)
        for i in neighbor_1:
            neighbor.update(graph.neighbors(i))

        node_to_neighbors[n] = neighbor

    row, col = data.edge_index
    mask = row < col
    row, col = row[mask], col[mask]

    # tolist() converts the endpoints once rather than calling .item() twice per edge.
    two_hop_degree = [
        len(node_to_neighbors[r] | node_to_neighbors[c])
        for r, c in tqdm(zip(row.tolist(), col.tolist()), total=len(row))
    ]

    return torch.tensor(two_hop_degree)


def process_graph():
    '''Download every dataset in ``graph_datasets`` and write one edge split per seed.

    For each dataset ``d`` and seed ``s`` two files are written under
    ``data_dir/d``:

    * ``d_{s}.pkl``: pickled ``(dataset, data)`` where ``data`` holds the
      train / val / test edge split.
    * ``df_{s}.pt``: dict with boolean masks over the training edges, ``'in'``
      for edges inside the 2-hop subgraph around the test-edge endpoints and
      ``'out'`` for the remaining edges.

    Raises:
        NotImplementedError: for a dataset name that is not handled.
    '''
    for d in graph_datasets:

        if d in ['Cora', 'PubMed', 'DBLP']:
            dataset = CitationFull(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures())
        elif d in ['CS', 'Physics']:
            dataset = Coauthor(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures())
        elif d in ['Flickr']:
            dataset = Flickr(os.path.join(data_dir, d), transform=T.NormalizeFeatures())
        elif 'ogbl' in d:
            dataset = PygLinkPropPredDataset(root=os.path.join(data_dir, d), name=d)
        else:
            raise NotImplementedError

        print('Processing:', d)
        print(dataset)
        data = dataset[0]
        num_nodes = data.num_nodes

        # The two-hop statistics are only used by the 'ogbl' split below, so
        # the expensive neighborhood pass is skipped for the other datasets.
        two_hop_degree = _edge_two_hop_degree(data) if 'ogbl' in d else None

        for s in seeds:
            seed_everything(s)

            # D
            data = dataset[0]
            # Clear the node masks on the copy that is actually split and
            # saved (dataset[0] above was a different object).
            data.train_mask = data.val_mask = data.test_mask = None
            if 'ogbl' in d:
                data = train_test_split_edges_no_neg_adj_mask(data, test_ratio=0.05, two_hop_degree=two_hop_degree)
            else:
                data = train_test_split_edges_no_neg_adj_mask(data, test_ratio=0.05)
            print(s, data)

            with open(os.path.join(data_dir, d, f'd_{s}.pkl'), 'wb') as f:
                pickle.dump((dataset, data), f)

            # Two ways to sample Df from the training set
            ## 1. Df is within 2 hop local enclosing subgraph of Dtest
            ## 2. Df is outside of 2 hop local enclosing subgraph of Dtest

            # Get the 2 hop local enclosing subgraph for all test edges
            # NOTE(review): the training edges are stored in one direction
            # only (row < col); confirm k_hop_subgraph's default flow gives
            # the intended neighborhood on such an edge list.
            _, local_edges, _, mask = k_hop_subgraph(
                data.test_pos_edge_index.flatten().unique(),
                2,
                data.train_pos_edge_index,
                num_nodes=num_nodes)
            distant_edges = data.train_pos_edge_index[:, ~mask]
            print('Number of edges. Local: ', local_edges.shape[1], 'Distant:', distant_edges.shape[1])

            in_mask = mask
            out_mask = ~mask

            torch.save(
                {'out': out_mask, 'in': in_mask},
                os.path.join(data_dir, d, f'df_{s}.pt')
            )

# Run the preprocessing for every dataset in graph_datasets; this downloads the
# data if needed and writes the split files under data_dir.
process_graph()

This will download 2.14GB. Will you proceed? (y/N)
 y


Downloading http://snap.stanford.edu/ogb/data/linkproppred/citation-v2.zip


Downloaded 2.14 GB: 100%|██████████| 2189/2189 [02:10<00:00, 16.72it/s]


Extracting ./data/ogbl-citation2/citation-v2.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 14926.35it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 5714.31it/s]

Saving...



Done!


Processing: ogbl-citation2
PygLinkPropPredDataset()


Two hop neighbors: 100%|██████████| 2927963/2927963 [03:04<00:00, 15851.68it/s]
100%|██████████| 15228622/15228622 [08:11<00:00, 31001.82it/s] 


42 Data(num_nodes=2927963, x=[2927963, 128], node_year=[2927963, 1], train_pos_edge_index=[2, 13705760], test_pos_edge_index=[2, 761431], test_neg_edge_index=[2, 761431], val_pos_edge_index=[2, 761431], val_neg_edge_index=[2, 761431])
Number of edges. Local:  10304632 Distant: 3401128
21 Data(num_nodes=2927963, x=[2927963, 128], node_year=[2927963, 1], train_pos_edge_index=[2, 13705760], test_pos_edge_index=[2, 761431], test_neg_edge_index=[2, 761431], val_pos_edge_index=[2, 761431], val_neg_edge_index=[2, 761431])
Number of edges. Local:  10304406 Distant: 3401354
13 Data(num_nodes=2927963, x=[2927963, 128], node_year=[2927963, 1], train_pos_edge_index=[2, 13705760], test_pos_edge_index=[2, 761431], test_neg_edge_index=[2, 761431], val_pos_edge_index=[2, 761431], val_neg_edge_index=[2, 761431])
Number of edges. Local:  10319833 Distant: 3385927
87 Data(num_nodes=2927963, x=[2927963, 128], node_year=[2927963, 1], train_pos_edge_index=[2, 13705760], test_pos_edge_index=[2, 761431], test

Downloaded 0.11 GB: 100%|██████████| 117/117 [00:08<00:00, 14.04it/s]


Extracting ./data/ogbl-collab/collab.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 55.09it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 7073.03it/s]

Saving...



Done!


Processing: ogbl-collab
PygLinkPropPredDataset()


Two hop neighbors: 100%|██████████| 235868/235868 [00:28<00:00, 8255.64it/s] 
100%|██████████| 1179052/1179052 [01:24<00:00, 13966.88it/s]


42 Data(num_nodes=235868, x=[235868, 128], train_pos_edge_index=[2, 1061148], test_pos_edge_index=[2, 58952], test_neg_edge_index=[2, 58952], val_pos_edge_index=[2, 58952], val_neg_edge_index=[2, 58952])
Number of edges. Local:  299852 Distant: 761296
21 Data(num_nodes=235868, x=[235868, 128], train_pos_edge_index=[2, 1061148], test_pos_edge_index=[2, 58952], test_neg_edge_index=[2, 58952], val_pos_edge_index=[2, 58952], val_neg_edge_index=[2, 58952])
Number of edges. Local:  302507 Distant: 758641
13 Data(num_nodes=235868, x=[235868, 128], train_pos_edge_index=[2, 1061148], test_pos_edge_index=[2, 58952], test_neg_edge_index=[2, 58952], val_pos_edge_index=[2, 58952], val_neg_edge_index=[2, 58952])
Number of edges. Local:  302993 Distant: 758155
87 Data(num_nodes=235868, x=[235868, 128], train_pos_edge_index=[2, 1061148], test_pos_edge_index=[2, 58952], test_neg_edge_index=[2, 58952], val_pos_edge_index=[2, 58952], val_neg_edge_index=[2, 58952])
Number of edges. Local:  304415 Distant: