In [1]:
import torch
print("PyTorch version:", torch.__version__)

PyTorch version: 2.2.2


# Importing datasets and preprocessing them into triples

## via torch_geometric

Importing `torch_geometric` is just to get the datasets, and will not be used further.

In [2]:
import torch_geometric
print("PyTorch Geometric version:", torch_geometric.__version__)

PyTorch Geometric version: 2.6.1


In [3]:
def get_knowledge_graph_summary(data, dataset_name):
    print(f"--- {dataset_name} ---")
    print(f"Number of nodes: {data.num_nodes}")
    print(f"Number of edges: {data.edge_index.shape[1]}")
    print(f"Edge index shape: {data.edge_index.shape}")
    print(f"Edge types shape: {data.edge_type.shape}")

    unique, counts = torch.unique(data.edge_type, return_counts=True)
    min_count = torch.min(counts).item()
    max_count = torch.max(counts).item()
    mean_count = torch.mean(counts.float()).item()
    median_count = torch.median(counts.float()).item()

    print(f"Number of unique relation types: {unique.numel()}")
    print(f"- Minimum relation occurrences: {min_count}")
    print(f"- Maximum relation occurrences: {max_count}")
    print(f"- Mean relation occurrences: {mean_count:.2f}")
    print(f"- Median relation occurrences: {median_count:.2f}")

In [4]:
def preprocess_dataset_to_triples(edge_index, edge_type):
    """
    Convert PyTorch Geometric tensors to a list of (head, relation, tail) triples.

    Args:
        edge_index (torch.Tensor): Tensor of shape [2, num_edges] representing (head, tail).
        edge_type (torch.Tensor): Tensor of shape [num_edges] representing relation types.

    Returns:
        list: A list of triples (head, relation, tail).
    """
    triples = []
    for i in range(edge_index.shape[1]):
        head = edge_index[0, i].item()
        tail = edge_index[1, i].item()
        relation = edge_type[i].item()
        triples.append((head, relation, tail))
    return triples

### WN18 & WN18RR

In [5]:
from torch_geometric.datasets import WordNet18, WordNet18RR

WN18 = WordNet18(root='data/WN18')[0]
get_knowledge_graph_summary(WN18, "WN18")

WN18RR = WordNet18RR(root='data/WN18RR')[0]
get_knowledge_graph_summary(WN18RR, "WN18RR")

--- WN18 ---
Number of nodes: 40943
Number of edges: 151442
Edge index shape: torch.Size([2, 151442])
Edge types shape: torch.Size([151442])
Number of unique relation types: 18
- Minimum relation occurrences: 86
- Maximum relation occurrences: 37221
- Mean relation occurrences: 8413.44
- Median relation occurrences: 3150.00
--- WN18RR ---
Number of nodes: 40943
Number of edges: 93003
Edge index shape: torch.Size([2, 93003])
Edge types shape: torch.Size([93003])
Number of unique relation types: 11
- Minimum relation occurrences: 86
- Maximum relation occurrences: 37221
- Mean relation occurrences: 8454.82
- Median relation occurrences: 3150.00


In [6]:
##### WN18 #####

# Accessing predefined splits
WN18_mask_train = WN18.train_mask
WN18_mask_val = WN18.val_mask
WN18_mask_test = WN18.test_mask

# Getting train, val, and test splits
WN18_edges_train = WN18.edge_index[:, WN18_mask_train]
WN18_types_train = WN18.edge_type[WN18_mask_train]

WN18_edges_val = WN18.edge_index[:, WN18_mask_val]
WN18_types_val = WN18.edge_type[WN18_mask_val]

WN18_edges_test = WN18.edge_index[:, WN18_mask_test]
WN18_types_test = WN18.edge_type[WN18_mask_test]

# Converting to list of triples
WN18_triples_train = preprocess_dataset_to_triples(WN18_edges_train, WN18_types_train)
WN18_triples_val = preprocess_dataset_to_triples(WN18_edges_val, WN18_types_val)
WN18_triples_test = preprocess_dataset_to_triples(WN18_edges_test, WN18_types_test)

# Checking split sizes and triples
print("===== WN18 =====")
print(f"- Train: {len(WN18_triples_train)} triples")
print(f"- Validation: {len(WN18_triples_val)} triples")
print(f"- Test: {len(WN18_triples_test)} triples")
print("\nSome (train) triples:")
print("\n".join(f"> {triple}" for triple in WN18_triples_train[:5]))

===== WN18 =====
- Train: 141442 triples
- Validation: 5000 triples
- Test: 5000 triples

Some (train) triples:
> (0, 5, 9534)
> (0, 15, 12878)
> (0, 10, 14747)
> (1, 10, 39788)
> (1, 2, 40217)


In [7]:
##### WN18RR #####

# Accessing predefined splits
WN18RR_mask_train = WN18RR.train_mask
WN18RR_mask_val = WN18RR.val_mask
WN18RR_mask_test = WN18RR.test_mask

# Getting train, val, and test splits
WN18RR_edges_train = WN18RR.edge_index[:, WN18RR_mask_train]
WN18RR_types_train = WN18RR.edge_type[WN18RR_mask_train]

WN18RR_edges_val = WN18RR.edge_index[:, WN18RR_mask_val]
WN18RR_types_val = WN18RR.edge_type[WN18RR_mask_val]

WN18RR_edges_test = WN18RR.edge_index[:, WN18RR_mask_test]
WN18RR_types_test = WN18RR.edge_type[WN18RR_mask_test]

# Converting to list of triples
WN18RR_triples_train = preprocess_dataset_to_triples(WN18RR_edges_train, WN18RR_types_train)
WN18RR_triples_val = preprocess_dataset_to_triples(WN18RR_edges_val, WN18RR_types_val)
WN18RR_triples_test = preprocess_dataset_to_triples(WN18RR_edges_test, WN18RR_types_test)

# Checking split sizes and triples
print("===== WN18RR =====")
print(f"- Train: {len(WN18RR_triples_train)} triples")
print(f"- Validation: {len(WN18RR_triples_val)} triples")
print(f"- Test: {len(WN18RR_triples_test)} triples")
print("\nSome (train) triples:")
print("\n".join(f"> {triple}" for triple in WN18RR_triples_train[:5]))

===== WN18RR =====
- Train: 86835 triples
- Validation: 3034 triples
- Test: 3134 triples

Some (train) triples:
> (0, 3, 10211)
> (0, 9, 25525)
> (1, 10, 3891)
> (1, 1, 5070)
> (1, 1, 7723)


### FB15k237

In [8]:
from torch_geometric.datasets import FB15k_237

FB15k237 = FB15k_237(root='data/FB15k_237')[0]
get_knowledge_graph_summary(FB15k237, "FB15k_237")

--- FB15k_237 ---
Number of nodes: 14541
Number of edges: 272115
Edge index shape: torch.Size([2, 272115])
Edge types shape: torch.Size([272115])
Number of unique relation types: 237
- Minimum relation occurrences: 37
- Maximum relation occurrences: 15989
- Mean relation occurrences: 1148.16
- Median relation occurrences: 373.00


In [9]:
# Loading splits
print("===== FB15k_237 (Raw) =====")

FB15k237_train = FB15k_237(root='data/FB15k_237', split='train')[0]
print(f"- FB15k_237 Train: {FB15k237_train.edge_index.shape}, {FB15k237_train.edge_type.shape}")

FB15k237_val = FB15k_237(root='data/FB15k_237', split='val')[0]
print(f"- FB15k_237 Validation: {FB15k237_val.edge_index.shape}, {FB15k237_val.edge_type.shape}")

FB15k237_test = FB15k_237(root='data/FB15k_237', split='test')[0]
print(f"- FB15k_237 Test: {FB15k237_test.edge_index.shape}, {FB15k237_test.edge_type.shape}")

# Converting to list of triples
FB15k237_triples_train = preprocess_dataset_to_triples(FB15k237_train.edge_index, FB15k237_train.edge_type)
FB15k237_triples_val = preprocess_dataset_to_triples(FB15k237_val.edge_index, FB15k237_val.edge_type)
FB15k237_triples_test = preprocess_dataset_to_triples(FB15k237_test.edge_index, FB15k237_test.edge_type)

# Checking split sizes and triples
print("\n===== FB15k_237 (List) =====")
print(f"- Train: {len(FB15k237_triples_train)} triples")
print(f"- Validation: {len(FB15k237_triples_val)} triples")
print(f"- Test: {len(FB15k237_triples_test)} triples")
print("\nSome (train) triples:")
print("\n".join(f"> {triple}" for triple in FB15k237_triples_train[:5]))

===== FB15k_237 (Raw) =====
- FB15k_237 Train: torch.Size([2, 272115]), torch.Size([272115])
- FB15k_237 Validation: torch.Size([2, 17535]), torch.Size([17535])
- FB15k_237 Test: torch.Size([2, 20466]), torch.Size([20466])

===== FB15k_237 (List) =====
- Train: 272115 triples
- Validation: 17535 triples
- Test: 20466 triples

Some (train) triples:
> (0, 0, 1)
> (2, 1, 3)
> (4, 2, 5)
> (6, 3, 7)
> (8, 4, 9)


I was also going to import **FB15k** dataset (alongside FB15k_237), but it does not seem to be available on `torch_geometric`, and even though it is available in HuggingFace, no one seems to be using it after 2019.

> "*The original FB15k dataset suffers from major test leakage through inverse relations, where a large number of test triples could be obtained by inverting triples in the training set. In order to create a dataset without this characteristic, the FB15k_237 describes a subset of FB15k where inverse relations are removed.*"

Hence, I decided to skip the replication for that dataset even though it is present in the 2019 conference publication of AnyBURL.

## via HuggingFace

### YAGO3-10

In [10]:
from datasets import load_dataset
YAGO = load_dataset("VLyb/YAGO3-10")
print(YAGO)

DatasetDict({
    train: Dataset({
        features: ['head', 'relation', 'tail'],
        num_rows: 1079040
    })
    validation: Dataset({
        features: ['head', 'relation', 'tail'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['head', 'relation', 'tail'],
        num_rows: 5000
    })
})


In [11]:
# Doing the splits
YAGO_train = YAGO["train"]
YAGO_triples_train = [tuple(row.values()) for row in YAGO_train]

YAGO_val = YAGO["validation"]
YAGO_triples_val = [tuple(row.values()) for row in YAGO_val]

YAGO_test = YAGO["test"]
YAGO_triples_test = [tuple(row.values()) for row in YAGO_test]

# Checking split sizes and triples
print("===== YAGO3-10 =====")
print(f"- Train: {len(YAGO_triples_train)} triples")
print(f"- Validation: {len(YAGO_triples_val)} triples")
print(f"- Test: {len(YAGO_triples_test)} triples")
print("\nSome (train) triples:")
print("\n".join(f"> {triple}" for triple in YAGO_triples_train[:5]))

===== YAGO3-10 =====
- Train: 1079040 triples
- Validation: 5000 triples
- Test: 5000 triples

Some (train) triples:
> ('Chatou', 'isLocatedIn', 'France')
> ('Boo_Young-tae', 'playsFor', 'Yangju_Citizen_FC')
> ('Toni_Kuivasto', 'isAffiliatedTo', 'Helsingin_Jalkapalloklubi')
> ('Josh_Smith_(soccer)', 'playsFor', 'Trinity_University_(Texas)')
> ('Albrecht_Dürer', 'diedIn', 'Nuremberg')


# Saving

In [12]:
import os
import pickle

output_dir = "data/triples"
os.makedirs(output_dir, exist_ok=True)

def save_if_not_exists(data, filename):
    file_path = os.path.join(output_dir, filename)
    if not os.path.exists(file_path):
        with open(file_path, "wb") as f:
            pickle.dump(data, f)
        print(f"Saved: {file_path}")
    else:
        print(f"File already exists: {file_path}")

# WN18
save_if_not_exists(WN18_triples_train, "WN18_triples_train.pkl")
save_if_not_exists(WN18_triples_val, "WN18_triples_val.pkl")
save_if_not_exists(WN18_triples_test, "WN18_triples_test.pkl")

# WN18RR
save_if_not_exists(WN18RR_triples_train, "WN18RR_triples_train.pkl")
save_if_not_exists(WN18RR_triples_val, "WN18RR_triples_val.pkl")
save_if_not_exists(WN18RR_triples_test, "WN18RR_triples_test.pkl")

# FB15k-237
save_if_not_exists(FB15k237_triples_train, "FB15k237_triples_train.pkl")
save_if_not_exists(FB15k237_triples_val, "FB15k237_triples_val.pkl")
save_if_not_exists(FB15k237_triples_test, "FB15k237_triples_test.pkl")

# YAGO3-10
save_if_not_exists(YAGO_triples_train, "YAGO_triples_train.pkl")
save_if_not_exists(YAGO_triples_val, "YAGO_triples_val.pkl")
save_if_not_exists(YAGO_triples_test, "YAGO_triples_test.pkl")

Saved: data/triples/WN18_triples_train.pkl
Saved: data/triples/WN18_triples_val.pkl
Saved: data/triples/WN18_triples_test.pkl
Saved: data/triples/WN18RR_triples_train.pkl
Saved: data/triples/WN18RR_triples_val.pkl
Saved: data/triples/WN18RR_triples_test.pkl
Saved: data/triples/FB15k237_triples_train.pkl
Saved: data/triples/FB15k237_triples_val.pkl
Saved: data/triples/FB15k237_triples_test.pkl
Saved: data/triples/YAGO_triples_train.pkl
Saved: data/triples/YAGO_triples_val.pkl
Saved: data/triples/YAGO_triples_test.pkl
