### Knowledge graph embedding model 

- Uses `ComplEx` from `torch_geometric.nn.kge` as the knowledge graph embedding model.


In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch_geometric
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.nn as nn
from torch.optim import Optimizer
from torch_geometric.nn.kge import ComplEx




### Load the data

In [2]:
# Load the knowledge graph facts for books from a CSV file.
# The file must contain 'head', 'relation', and 'tail' columns, one triple per row.
df = pd.read_csv('data/books_graph_facts.csv')

### Setting the training parameters

- Model - `ComplEx` from `torch_geometric`. ComplEx represents entities and relations as complex-valued vectors and scores a triple with the Hermitian dot product: the real part of the trilinear product of the head embedding, the relation embedding, and the complex conjugate of the tail embedding.  
    <img src="./resources/ComplexModel.png" alt="Complex" width="300"/>  
- Optimizer - Updates the model parameters based on the computed gradients.  
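
The scoring function described above can be written out directly. The sketch below is a standalone illustration in plain PyTorch, not the library's internal code: it computes the ComplEx score with complex tensors and checks it against the equivalent expansion into real and imaginary parts. The embedding values are random placeholders and `dim` is an arbitrary toy size.

```python
import torch

torch.manual_seed(0)
dim = 4  # toy embedding dimension (the notebook uses 100)

# Random complex embeddings standing in for one head, relation, and tail
h = torch.randn(dim, dtype=torch.cfloat)
r = torch.randn(dim, dtype=torch.cfloat)
t = torch.randn(dim, dtype=torch.cfloat)

# ComplEx score: real part of the trilinear product with the conjugated tail
score_complex = (h * r * t.conj()).sum().real

# The same score expanded into real and imaginary components
score_expanded = (
    h.real * r.real * t.real
    + h.imag * r.real * t.imag
    + h.real * r.imag * t.imag
    - h.imag * r.imag * t.real
).sum()

print(f"complex form:  {score_complex.item():.6f}")
print(f"expanded form: {score_expanded.item():.6f}")
```

Because the tail is conjugated, swapping head and tail generally changes the score, which is what lets ComplEx model asymmetric relations.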

In [16]:
def train(
    dataloader: DataLoader,
    model: ComplEx,
    optimizer: Optimizer, 
    device: str,
) -> float:
    """
    Runs one training epoch for the ComplEx model.

    Args:
        dataloader (DataLoader): DataLoader object for batching the training data.
        model (ComplEx): The ComplEx model instance used for link prediction.
        optimizer (Optimizer): Optimizer for updating model weights.
        device (str): Device to run the training on, either 'cpu' or 'cuda'.

    Returns:
        float: Average training loss per triple over the epoch.

    This function iterates over the DataLoader batches and performs the following:
    - Moves the data to the specified device (CPU or GPU).
    - Calculates the loss using the ComplEx model's loss function.
    - Performs backpropagation and optimization to minimize the loss.
    - Prints the batch loss every 1000 batches for monitoring.
    """
    size = len(dataloader.dataset)
    print('###############Training Starts here#########################', size)
    
    model.train() 
    total_loss = total_examples = 0
    for batch, (head_index, rel_type, tail_index) in enumerate(dataloader):
        # Move head, relation, and tail tensors to device
        head_index, rel_type, tail_index = head_index.to(device), rel_type.to(device), tail_index.to(device)

        # Calculate the loss using ComplEx's loss function
        loss = model.loss(head_index, rel_type, tail_index)
        
        # Backpropagation and optimization
        optimizer.zero_grad()  # clear gradients from the previous batch to avoid accumulation
        loss.backward()  # compute gradients of the loss with respect to the model parameters
        optimizer.step()  # update the model weights
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
        if batch % 1000 == 0:
            loss_val = loss.item()
            current = batch * len(head_index)
            print(f"loss: {loss_val:>7f}  [{current:>5d}/{size:>5d}]")
    return total_loss / total_examples

### Setting the testing parameters

Evaluates the model on held-out data and calculates the metrics below:
- Mean Rank: The average position of the correct tail entity when all candidate entities are sorted by score. Lower is better.
- MRR (Mean Reciprocal Rank): The average of the reciprocal ranks of the correct entities, giving a single-score summary between 0 and 1. Higher is better.
- Hits@k: The fraction of test triples whose correct tail entity is ranked within the top k predictions (here k=10). Higher is better.
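
The three metrics can be illustrated on a handful of made-up scores. This is a minimal sketch, assuming 1-based ranks and that ties do not count against the correct entity; the helper `ranking_metrics` is hypothetical, and the library's own implementation may use different conventions (for example zero-based ranks or another tie-breaking rule).

```python
import torch

def ranking_metrics(scores, true_index, k=10):
    """Compute mean rank, MRR, and Hits@k from a [num_queries, num_entities] score matrix."""
    # Score assigned to the correct entity of each query
    true_scores = scores.gather(1, true_index.view(-1, 1))
    # Rank = 1 + number of candidates scored strictly higher than the correct one
    ranks = ((scores > true_scores).sum(dim=1) + 1).float()
    mean_rank = ranks.mean().item()
    mrr = (1.0 / ranks).mean().item()
    hits_at_k = (ranks <= k).float().mean().item()
    return mean_rank, mrr, hits_at_k

# Made-up scores for 3 queries over 4 candidate entities
scores = torch.tensor([
    [0.9, 0.1, 0.3, 0.2],  # correct entity 0 is ranked 1st
    [0.5, 0.4, 0.8, 0.1],  # correct entity 1 is ranked 3rd
    [0.2, 0.7, 0.6, 0.9],  # correct entity 2 is ranked 3rd
])
true_index = torch.tensor([0, 1, 2])

mean_rank, mrr, hits_at_2 = ranking_metrics(scores, true_index, k=2)
print(f"Mean Rank: {mean_rank:.4f}, MRR: {mrr:.4f}, Hits@2: {hits_at_2:.4f}")
```

With ranks of 1, 3, and 3, the mean rank is 7/3, the MRR is 5/9, and only one of the three queries counts as a hit at k=2.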

In [37]:
@torch.no_grad()
def test(
    dataloader: DataLoader,
    model: ComplEx,
    device: str,
    k: int = 10 
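) -> tuple[float, float, float]:
    """
    Test method for ComplEx model.
    Evaluates the model using Mean Rank, MRR, and Hits@k metrics.

    Returns:
        tuple: (mean rank, MRR, Hits@k) as returned by the model's `test()` method.
    """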
    model.eval()  # Set model to evaluation mode

    # Initialize lists to collect the full test set
    all_head_index = []
    all_rel_type = []
    all_tail_index = []

    # Loop over the DataLoader and collect all test data
    for batch, (head_index, rel_type, tail_index) in enumerate(dataloader):
        # Move head, relation, and tail tensors to device
        head_index, rel_type, tail_index = head_index.to(device), rel_type.to(device), tail_index.to(device)
        
        # Append each batch to the list
        all_head_index.append(head_index)
        all_rel_type.append(rel_type)
        all_tail_index.append(tail_index)

    # Concatenate all batches into a single tensor for each
    all_head_index = torch.cat(all_head_index, dim=0)
    all_rel_type = torch.cat(all_rel_type, dim=0)
    all_tail_index = torch.cat(all_tail_index, dim=0)

    # Call the built-in `test()` method from PyTorch Geometric's ComplEx model
    rank, mrr, hits_at_10 = model.test(
        head_index=all_head_index,
        rel_type=all_rel_type,
        tail_index=all_tail_index,
        batch_size=2000,
        k=k
    )


    # Note: the label is always "Validation", even when the test loader is passed in
    print(f'Validation Mean Rank: {rank:.2f}, Validation MRR: {mrr:.4f}, '
      f'Validation Hits@10: {hits_at_10:.4f}')

    return rank, mrr, hits_at_10
       

### Transform the data - ID mapping, train, validation and test split

In [20]:
# Create unique mappings for entities (from head and tail) and relations
entities = pd.concat([df['head'], df['tail']]).unique()  # Combine head and tail to get all unique entities
relations = df['relation'].unique()  # Get unique relations

In [21]:
# Create mappings to convert entities and relations into integer indices
entity_to_id = {entity: idx for idx, entity in enumerate(entities)}
relation_to_id = {relation: idx for idx, relation in enumerate(relations)}

In [22]:
# Apply the mappings to create new 'head_id', 'relation_id', and 'tail_id' columns
df['head_id'] = df['head'].map(entity_to_id)
df['relation_id'] = df['relation'].map(relation_to_id)
df['tail_id'] = df['tail'].map(entity_to_id)
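
To make the mapping step concrete, here is the same logic applied to a tiny made-up table. The three triples are invented for illustration and are not taken from the books dataset; the reverse lookup `id_to_entity` is an extra helper that the notebook does not define.

```python
import pandas as pd

# Three invented triples, for illustration only
toy = pd.DataFrame({
    'head': ['Dune', 'Dune', 'Frank Herbert'],
    'relation': ['written_by', 'has_genre', 'born_in'],
    'tail': ['Frank Herbert', 'Science Fiction', 'Tacoma'],
})

# Entities share one ID space across heads and tails; relations get their own
toy_entities = pd.concat([toy['head'], toy['tail']]).unique()
toy_entity_to_id = {entity: idx for idx, entity in enumerate(toy_entities)}
toy_relation_to_id = {rel: idx for idx, rel in enumerate(toy['relation'].unique())}

toy['head_id'] = toy['head'].map(toy_entity_to_id)
toy['relation_id'] = toy['relation'].map(toy_relation_to_id)
toy['tail_id'] = toy['tail'].map(toy_entity_to_id)

# Reverse lookup, handy for turning predicted IDs back into names
id_to_entity = {idx: entity for entity, idx in toy_entity_to_id.items()}

print(toy[['head_id', 'relation_id', 'tail_id']])
```

'Frank Herbert' appears both as a tail and as a head and receives the same ID in both columns, which is why the entity mapping is built from the concatenation of the two columns.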

In [23]:
# `df` now holds the integer 'head_id', 'relation_id', and 'tail_id' columns created above

# Create the tensors for head, relation, and tail indices
head_index = torch.tensor(df['head_id'].values, dtype=torch.long)
rel_type = torch.tensor(df['relation_id'].values, dtype=torch.long)
tail_index = torch.tensor(df['tail_id'].values, dtype=torch.long)

In [25]:
# Split the data into training and temporary sets (80% train, 20% held out for validation and test)
train_head, temp_head, train_rel, temp_rel, train_tail, temp_tail = train_test_split(
    head_index, rel_type, tail_index, test_size=0.2, random_state=42
)

# Split the temporary set evenly into validation and test sets (10% of the data each)
val_head, test_head, val_rel, test_rel, val_tail, test_tail = train_test_split(
    temp_head, temp_rel, temp_tail, test_size=0.5, random_state=42
)
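
Passing the three arrays to `train_test_split` in a single call applies the same shuffle to all of them, so each head stays paired with its relation and tail. The sketch below checks this on synthetic NumPy arrays; the offsets 1000 and 2000 are arbitrary markers chosen so that the alignment is easy to verify.

```python
import numpy as np
from sklearn.model_selection import train_test_split

n = 100
heads = np.arange(n)
rels = heads + 1000   # marker: relation value = head value + 1000
tails = heads + 2000  # marker: tail value = head value + 2000

# 80% train, 20% temporary
tr_h, tmp_h, tr_r, tmp_r, tr_t, tmp_t = train_test_split(
    heads, rels, tails, test_size=0.2, random_state=42
)
# Temporary set split evenly into validation and test
va_h, te_h, va_r, te_r, va_t, te_t = train_test_split(
    tmp_h, tmp_r, tmp_t, test_size=0.5, random_state=42
)

print(len(tr_h), len(va_h), len(te_h))
# Rows stay aligned across the three arrays after shuffling
print(bool(np.all(tr_r - tr_h == 1000) and np.all(te_t - te_h == 2000)))
```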

### Creating the dataloader and setting the batch

In [26]:
# Create TensorDatasets for each split
train_dataset = TensorDataset(train_head, train_rel, train_tail)
val_dataset = TensorDataset(val_head, val_rel, val_tail)
test_dataset = TensorDataset(test_head, test_rel, test_tail)

In [27]:
# Create DataLoaders for each split; only the training data is shuffled
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)
# Check the sizes of the splits
print(f"Training set size: {len(train_loader.dataset)}")
print(f"Validation set size: {len(val_loader.dataset)}")
print(f"Test set size: {len(test_loader.dataset)}")

Training set size: 480000
Validation set size: 60000
Test set size: 60000
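
As a quick illustration of how these loaders batch the triples, the sketch below wraps ten synthetic triples in a `TensorDataset` and iterates over it with a batch size of 4. The tensors are placeholders; the final lines work out how many batches the 480000 training triples yield at `batch_size=64` and how many progress lines a "print every 1000 batches" rule produces per epoch.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten synthetic triples: each dataset item is a (head, relation, tail) tuple
heads = torch.arange(10)
rels = torch.zeros(10, dtype=torch.long)
tails = torch.arange(10) + 10
loader = DataLoader(TensorDataset(heads, rels, tails), batch_size=4)

# Each batch unpacks into three aligned tensors; the last batch may be smaller
batch_sizes = [len(h) for h, r, t in loader]
print(batch_sizes)

# Batches per epoch for the training split shown above
batches_per_epoch = math.ceil(480000 / 64)
# Batches 0, 1000, 2000, ... trigger a progress line
progress_lines = len(range(0, batches_per_epoch, 1000))
print(batches_per_epoch, progress_lines)
```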


In [28]:
# Set up the ComplEx model
num_nodes = len(pd.concat([df['head'], df['tail']]).unique())  # Total number of unique entities (nodes)
num_relations = len(df['relation'].unique())  # Total number of unique relations
hidden_channels = 100  # Embedding dimension

In [29]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is available
# Initialize ComplEx model
model = ComplEx(num_nodes=num_nodes, num_relations=num_relations, hidden_channels=hidden_channels).to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

### Training the model, hyperparameter evaluation

In [39]:
for epoch in range(1, 3):
    loss = train(train_loader, model, optimizer, device=device)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % 2 == 0:
        # Evaluate on the validation split every second epoch
        rank, mrr, hits = test(val_loader, model, device)

# Final evaluation on the held-out test split
rank, mrr, hits_at_10 = test(test_loader, model, device)
print(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, '
      f'Test Hits@10: {hits_at_10:.4f}')


###############Training Starts here######################### 480000
loss: 0.245957  [    0/480000]
loss: 0.154188  [64000/480000]
loss: 0.153189  [128000/480000]
loss: 0.286057  [192000/480000]
loss: 0.274347  [256000/480000]
loss: 0.337183  [320000/480000]
loss: 0.252536  [384000/480000]
loss: 0.237739  [448000/480000]
Epoch: 001, Loss: 0.2179
###############Training Starts here######################### 480000
loss: 0.258786  [    0/480000]
loss: 0.167268  [64000/480000]
loss: 0.230701  [128000/480000]
loss: 0.243550  [192000/480000]
loss: 0.298122  [256000/480000]
loss: 0.286704  [320000/480000]
loss: 0.184250  [384000/480000]
loss: 0.249438  [448000/480000]
Epoch: 002, Loss: 0.2054


100%|██████████| 60000/60000 [04:50<00:00, 206.69it/s]


Validation Mean Rank: 1894.23, Validation MRR: 0.0322, Validation Hits@10: 0.0803


100%|██████████| 60000/60000 [03:11<00:00, 313.95it/s]

Validation Mean Rank: 1886.08, Validation MRR: 0.0327, Validation Hits@10: 0.0818
Test Mean Rank: 1886.08, Test MRR: 0.0327, Test Hits@10: 0.0818





### Hyperparameter tuning:
The model needs hyperparameter tuning to improve ranking performance.   
After two epochs the average training loss fell only from 0.2179 to 0.2054, and the validation metrics remain low (MRR 0.0322, Hits@10 0.0803). A few suggestions based on this training and evaluation output:

- Optimizer: Trying different optimizers such as Adagrad, Adam, or SGD may improve performance.
- Early stopping: Stop training when the validation metric has not improved for a set number of epochs.
- Learning rate adjustment: Lowering the learning rate or using a learning rate scheduler might help stabilize the training loss, which fluctuates between batches in the log above.
- Model complexity: A more complex or fine-tuned model, for example one with a larger embedding dimension than the current `hidden_channels = 100`, could potentially improve MRR and Hits@10.
- Data augmentation and regularization: If possible, enhancing the dataset or applying regularization methods could help the model generalize better.
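
Two of these suggestions, early stopping and a learning rate scheduler, can be sketched together. The `EarlyStopping` class is a hypothetical helper written for this example, the validation MRR values are invented, and a single dummy parameter stands in for the model. `ReduceLROnPlateau` is PyTorch's built-in scheduler that lowers the learning rate when a monitored metric stops improving.

```python
import torch

class EarlyStopping:
    """Signals a stop when a metric to maximize has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('-inf')
        self.bad_epochs = 0

    def step(self, metric):
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Dummy parameter standing in for the model's parameters
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.001)
# Halve the learning rate once the metric has stalled for more than one epoch
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=1
)
stopper = EarlyStopping(patience=3)

# Invented validation MRR values, one per epoch
val_mrr_history = [0.030, 0.032, 0.031, 0.032, 0.031, 0.030]
stopped_at = None
for epoch, val_mrr in enumerate(val_mrr_history, start=1):
    # In a real run, one training epoch and a validation pass would happen here
    scheduler.step(val_mrr)
    if stopper.step(val_mrr):
        stopped_at = epoch
        break

final_lr = optimizer.param_groups[0]['lr']
print(f"stopped at epoch {stopped_at}, best MRR {stopper.best:.3f}, lr {final_lr}")
```

In the notebook's loop, the value passed to both helpers would be the MRR returned by `test(val_loader, ...)`, which means validating every epoch rather than every second one.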

### Save the model for downstream tasks

In [40]:
import os
import torch

def save_model(model, optimizer, epoch, file_path="model/complex_model.pth"):
    """
    Saves the model and optimizer state dictionaries to a file for later use.
    """
    # Create the target directory if it does not exist yet
    os.makedirs(os.path.dirname(file_path) or '.', exist_ok=True)
    # Save the model's state dict and optimizer state dict
    torch.save({
        'epoch': epoch,  # Save the current epoch number
        'model_state_dict': model.state_dict(),  # Model parameters
        'optimizer_state_dict': optimizer.state_dict(),  # Optimizer parameters
    }, file_path)

In [41]:
# Save the model after the two training epochs above
save_model(model, optimizer, epoch=2)

Evaluating the embeddings on downstream tasks such as link prediction and classification will show how well the model performs in potential applications.
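
Before the embeddings can be used in a downstream task, the checkpoint has to be loaded back. The sketch below shows a counterpart to `save_model`, assuming the same dictionary keys ('epoch', 'model_state_dict', 'optimizer_state_dict'). The function `load_checkpoint` is hypothetical, and a small `torch.nn.Embedding` with a temporary file stands in for the ComplEx model and the real path so that the example runs on its own.

```python
import os
import tempfile
import torch

def load_checkpoint(model, optimizer, file_path, device='cpu'):
    """Restores model and optimizer state in place and returns the saved epoch."""
    checkpoint = torch.load(file_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

# Stand-in model and optimizer (the notebook would use ComplEx here)
model_a = torch.nn.Embedding(5, 3)
optimizer_a = torch.optim.Adam(model_a.parameters(), lr=0.001)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'checkpoint.pth')
    # Same checkpoint layout as save_model
    torch.save({
        'epoch': 2,
        'model_state_dict': model_a.state_dict(),
        'optimizer_state_dict': optimizer_a.state_dict(),
    }, path)

    # A freshly initialized model of the same shape receives the saved weights
    model_b = torch.nn.Embedding(5, 3)
    optimizer_b = torch.optim.Adam(model_b.parameters(), lr=0.001)
    restored_epoch = load_checkpoint(model_b, optimizer_b, path)

print(restored_epoch, torch.equal(model_a.weight, model_b.weight))
```

The model passed to the loader must be constructed with the same `num_nodes`, `num_relations`, and `hidden_channels` as the one that was saved, otherwise `load_state_dict` raises a shape mismatch error.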