<a href="https://colab.research.google.com/github/Abhijit85/FederatedRAG/blob/main/TransEE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Libraries

In [None]:
# Notebook-level configuration flags.
# whether you are using a GPU to run this Colab
# NOTE: the training cell further down reassigns `use_gpu` from
# torch.cuda.is_available(), so the value set here only acts as a default.
use_gpu = True
# whether you are using a custom GCE env to run the Colab (uses different CUDA)
# NOTE(review): not referenced anywhere else in the visible notebook - confirm it is still needed.
custom_GCE_env = False

In [None]:
%pip install openai
%pip install python-dotenv
# %pip install torch-geometric
# %pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
# %pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html

from dotenv import load_dotenv
from openai import OpenAI
import os
import re
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
import math
# from torch_geometric.data import InMemoryDataset, DataLoader
# import torch_geometric

Collecting python-dotenv
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.0.1


# Data Preparation and Processing

In [None]:
class CustomDataset:
    """Knowledge-graph triples loaded from plain text files (no PyTorch Geometric).

    The dataset directory must contain:

    * ``entities.dict`` and ``relations.dict`` - one ``<id><TAB><name>`` pair
      per line;
    * ``train.txt``, ``valid.txt`` and ``test.txt`` - one
      ``<head><TAB><relation><TAB><tail>`` triple per line, written with the
      names that appear in the two dictionaries.

    Blank lines in any of the files are ignored.

    Attributes:
        entity_dict (dict): entity name -> integer id.
        relation_dict (dict): relation name -> integer id.
        train_data / valid_data / test_data (list): ``(head_id, relation_id,
            tail_id)`` tuples for each split.
        num_entities (int): number of entries in ``entity_dict``.
        num_relations (int): number of entries in ``relation_dict``.
    """

    def __init__(self, data_path: str):
        """
        Load the dictionaries and all three splits from ``data_path``.

        Args:
            data_path (str): Path to the dataset directory.

        Raises:
            OSError: if one of the expected files cannot be opened.
            ValueError: if a line does not have the expected number of fields.
            KeyError: if a triple mentions a name missing from the dictionaries.
        """
        # Paths to files
        self.entity_dict_path = os.path.join(data_path, 'entities.dict')
        self.relation_dict_path = os.path.join(data_path, 'relations.dict')
        self.train_data_path = os.path.join(data_path, 'train.txt')
        self.valid_data_path = os.path.join(data_path, 'valid.txt')
        self.test_data_path = os.path.join(data_path, 'test.txt')

        # The dictionaries must be loaded first: _read_data maps names through them.
        self.entity_dict = self._read_dict(self.entity_dict_path)
        self.relation_dict = self._read_dict(self.relation_dict_path)

        self.train_data = self._read_data(self.train_data_path)
        self.valid_data = self._read_data(self.valid_data_path)
        self.test_data = self._read_data(self.test_data_path)

        self.num_entities = len(self.entity_dict)
        self.num_relations = len(self.relation_dict)

    def _read_dict(self, file_path: str):
        """
        Read an entity / relation dictionary file.

        Each non-blank line must be ``<id><TAB><name>``.

        Args:
            file_path (str): Path to the dictionary file.

        Returns:
            dict: mapping ``name -> int(id)`` (the name is the key).

        Raises:
            ValueError: if a non-blank line does not have exactly two
                tab-separated fields, or the id is not an integer.
        """
        element_dict = {}
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.strip().split('\t')
                if fields == ['']:
                    # Blank (or whitespace-only) line, e.g. a trailing newline.
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"{file_path}:{line_no}: expected '<id>\\t<name>', got {line!r}"
                    )
                id_, element = fields
                element_dict[element] = int(id_)

        return element_dict

    def _read_data(self, file_path):
        """
        Read a triples file and map every name to its integer id.

        Args:
            file_path (str): Path to a ``<head><TAB><relation><TAB><tail>`` file.

        Returns:
            list: ``(head_id, relation_id, tail_id)`` tuples in file order.

        Raises:
            ValueError: if a non-blank line does not have exactly three
                tab-separated fields.
            KeyError: if a head/tail/relation name is not in the dictionaries.
        """
        triples = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.strip().split('\t')
                if fields == ['']:
                    continue
                if len(fields) != 3:
                    raise ValueError(
                        f"{file_path}:{line_no}: expected "
                        f"'<head>\\t<relation>\\t<tail>', got {line!r}"
                    )
                h, r, t = fields
                triples.append((self.entity_dict[h], self.relation_dict[r], self.entity_dict[t]))
        return triples

    def get_edge_indices_and_types(self, data):
        """
        Convert triples into index tensors.

        Args:
            data: iterable of ``(head_id, relation_id, tail_id)`` tuples.

        Returns:
            tuple: ``(edge_index, edge_type)`` where ``edge_index`` is a long
            tensor of shape ``(2, num_edges)`` holding heads in row 0 and tails
            in row 1, and ``edge_type`` is a long tensor of shape
            ``(num_edges,)`` holding the relation ids. An empty input yields
            tensors of shape ``(2, 0)`` and ``(0,)``.
        """
        triples = list(data)
        if not triples:
            # zip(*[]) cannot be unpacked into three names; return empty tensors instead.
            return (torch.empty((2, 0), dtype=torch.long),
                    torch.empty((0,), dtype=torch.long))

        heads, relations, tails = zip(*triples)
        edge_index = torch.tensor([heads, tails], dtype=torch.long)  # Shape: (2, num_edges)
        edge_type = torch.tensor(relations, dtype=torch.long)  # Shape: (num_edges,)
        return edge_index, edge_type


# TransGPT Model

In [None]:
class TransEEnhanced(nn.Module):
    """Triple scorer that keeps a modulus and a phase embedding per entity/relation.

    A triple ``(head, relation, tail)`` receives a non-negative score made of
    two parts (lower means more plausible in the ranking loss below):

    * modulus part: L2 norm of ``h_mod * r_mod - t_mod``;
    * phase part: sum over dimensions of ``|sin((h_phase + r_phase - t_phase) / 2)|``.

    Args:
        num_entities: number of rows in the entity embedding tables.
        num_relations: number of rows in the relation embedding tables.
        embedding_dim: width of every embedding table.
        margin: margin used by :meth:`compute_loss`.
        distance_metric: stored on the instance; the scoring code in this
            class does not read it (the modulus part always uses the L2 norm).
        gamma, epsilon: define the modulus initialisation range
            ``(gamma + epsilon) / embedding_dim``.
        phase_weight, modulus_weight: weights of the two score parts.
    """

    def __init__(self, num_entities, num_relations, embedding_dim, margin, distance_metric="L1",
                 gamma=12.0, phase_weight=0.5, modulus_weight=1.0, epsilon=2.0):
        super().__init__()

        # Embedding tables. Creation order is kept stable because it fixes both
        # the parameter order and the order in which random numbers are drawn.
        self.entity_modulus = nn.Embedding(num_entities, embedding_dim)
        self.entity_phase = nn.Embedding(num_entities, embedding_dim)
        self.relation_modulus = nn.Embedding(num_relations, embedding_dim)
        self.relation_phase = nn.Embedding(num_relations, embedding_dim)

        # Loss / scoring hyper-parameters kept as plain attributes.
        self.margin = margin
        self.distance_metric = distance_metric
        self.phase_weight = phase_weight
        self.modulus_weight = modulus_weight

        # Frozen scalars stored as parameters so they follow the module
        # across devices and into the state dict.
        self.gamma = nn.Parameter(torch.Tensor([gamma]), requires_grad=False)
        init_range = (self.gamma.item() + epsilon) / embedding_dim
        self.embedding_range = nn.Parameter(torch.Tensor([init_range]), requires_grad=False)

        # Uniform initialisation: moduli in [-range, range], phases in [-pi, pi].
        modulus_bound = self.embedding_range.item()
        init_plan = (
            (self.entity_modulus, modulus_bound),
            (self.entity_phase, np.pi),
            (self.relation_modulus, modulus_bound),
            (self.relation_phase, np.pi),
        )
        for table, bound in init_plan:
            nn.init.uniform_(table.weight, a=-bound, b=bound)

    def forward(self, head, relation, tail):
        """Score triples given index tensors for heads, relations and tails.

        The three index tensors are looked up independently and combined
        element-wise, so they must be broadcast-compatible. The result has the
        shape of the index tensors (the embedding dimension is reduced away).
        """
        head_mod, head_phase = self.entity_modulus(head), self.entity_phase(head)
        rel_mod, rel_phase = self.relation_modulus(relation), self.relation_phase(relation)
        tail_mod, tail_phase = self.entity_modulus(tail), self.entity_phase(tail)

        # Modulus part: the relation rescales the head element-wise.
        modulus_residual = head_mod * rel_mod - tail_mod
        modulus_score = modulus_residual.norm(p=2, dim=-1)

        # Phase part: the relation shifts the head angle; half-angle sine
        # makes the mismatch periodic in 2*pi.
        half_angle_gap = (head_phase + rel_phase - tail_phase) / 2
        phase_score = half_angle_gap.sin().abs().sum(dim=-1)

        return self.modulus_weight * modulus_score + self.phase_weight * phase_score

    def compute_loss(self, positive_score, negative_score):
        """Mean margin ranking loss plus a small norm penalty on entity embeddings.

        The hinge term is ``relu(margin + positive_score - negative_score)``
        averaged over the batch. The penalty is ``1e-4`` times the sum of the
        per-row L2 norms of the entity modulus and entity phase tables (all
        rows, not only those in the batch).
        """
        hinge = F.relu(self.margin + positive_score - negative_score)

        modulus_penalty = self.entity_modulus.weight.norm(p=2, dim=-1).sum()
        phase_penalty = self.entity_phase.weight.norm(p=2, dim=-1).sum()

        return hinge.mean() + 1e-4 * (modulus_penalty + phase_penalty)

# Helper function to create corrupted edges
def create_corrupted_edge_index(edge_index, edge_type, num_entities):
    """Build one negative sample per edge by replacing its head or its tail.

    For every column of ``edge_index`` (row 0 = heads, row 1 = tails) a coin
    flip picks which endpoint to replace and a uniformly drawn entity id in
    ``[0, num_entities)`` takes its place; the other endpoint is kept. The
    replacement is not checked against the original, so a sample may happen
    to equal the edge it was derived from.

    Returns a tensor shaped like ``edge_index`` on the same device.
    """
    target_device = edge_index.device
    per_edge_shape = edge_type.size()

    # Draw the side selector first and the replacement ids second.
    side = torch.randint(high=2, size=per_edge_shape, device=target_device)
    replacement = torch.randint(high=num_entities, size=per_edge_shape, device=target_device)

    original_heads = edge_index[0, :]
    original_tails = edge_index[1, :]

    # side == 1 -> replace the head, side == 0 -> replace the tail.
    replace_head = side == 1
    new_heads = torch.where(replace_head, replacement, original_heads)
    new_tails = torch.where(~replace_head, replacement, original_tails)

    return torch.stack((new_heads, new_tails), dim=0)

# # Helper function to prompt predictions to LLM

# def refine_predictions_with_llm(predictions, batch_edge_index, batch_edge_type, model, entity_dict, relation_dict):
#     # Decode the indices into human-readable entity and relation names
#     head_entities = [entity_dict[int(head)] for head in batch_edge_index[0].tolist()]
#     relation_names = [relation_dict[int(rel)] for rel in batch_edge_type.tolist()]
#     predicted_tails = [entity_dict[int(torch.argmin(pred))] for pred in predictions]

#     # Format the data for LLM
#     input_data = []
#     for h, r, t in zip(head_entities, relation_names, predicted_tails):
#         input_data.append(f"Head: {h}, Relation: {r}, Predicted Tail: {t}")

    # # Prompt the LLM (GPT-3.5-turbo/GPT-4 )
    # api_key = os.getenv("OPENAI_API_KEY")  # never hard-code secrets; the key previously committed here must be revoked
    # if not api_key:
    #     raise ValueError("OpenAI API key not found in environment variables")

    # client = OpenAI(api_key=api_key)
    # # openai.api_key = "your-openai-api-key"
    # response = client.chat.completions.create(
    #     model="gpt-4", # gpt-3.5-turbo
    #     messages=[
    #         {"role": "system", "content": "You are an expert in knowledge graphs and TransE embeddings."},
    #         {"role": "user", "content": "Please evaluate and improve the following predictions:\n" + "\n".join(input_data)}
    #     ]
    # )

    # # Parse LLM response
    # # refined_predictions = response['choices'][0]['message']['content']
    # refined_predictions = response.choices[0].message.content
    # print("Refined Predictions from LLM:")
    # print(refined_predictions)

# Train Function

## Model: TransE
**Embeddings:**
Each entity and relation is represented as a vector in a high-dimensional space.
The embeddings are initialized randomly and updated during training.
**Distance Metric:**
TransE scores a triple by the distance between head + relation and tail (the norm of head + relation - tail).
A lower distance indicates a more likely relationship.

## How Is the LLM Used?
1. Prediction Refinement
After the TransE model predicts relationships (e.g., a tail entity for a given head and relation), these predictions are passed to the LLM.
The LLM evaluates the predictions, identifies errors, and suggests corrections or more plausible results.

In [None]:
def train(model, data, optimizer, device, entity_dict, relation_dict, epochs=50, batch_size=128, valid_freq=5):
    """Train ``model`` with its ranking loss and report test metrics of the best epoch.

    Each epoch: the entity modulus embeddings are rescaled to unit L2 norm,
    the training triples are shuffled, one corrupted triple is sampled per
    positive triple, and the model is optimised batch by batch. Every
    ``valid_freq`` epochs the model is evaluated on the validation split;
    whenever the validation MRR improves, the test split is evaluated too.

    Args:
        model: scorer exposing ``entity_modulus``, ``compute_loss`` and a
            ``forward(head, relation, tail)`` call.
        data: object carrying ``train/valid/test_edge_index``,
            ``train/valid/test_edge_type`` tensors and ``num_entities``.
        optimizer: optimiser already bound to ``model``'s parameters.
        device: device the index tensors are moved to.
        entity_dict, relation_dict: accepted for interface compatibility;
            not used by this function.
        epochs: number of passes over the training triples.
        batch_size: number of triples per optimisation step.
        valid_freq: validate every this many epochs.

    Returns:
        tuple | None: ``(MRR, MR, Hits@10)`` on the test split, measured at
        the validation round with the best validation MRR, or ``None`` if no
        validation round improved on the initial best of 0.

    Raises:
        ValueError: if the training split contains no triples.
    """
    train_edge_index = data.train_edge_index.to(device)
    train_edge_type = data.train_edge_type.to(device)
    valid_edge_index = data.valid_edge_index.to(device)
    valid_edge_type = data.valid_edge_type.to(device)

    num_triples = train_edge_type.size(0)
    if num_triples == 0:
        # Without this check the per-epoch average below divides by zero.
        raise ValueError("train() needs at least one training triple")

    best_valid_score = 0
    test_scores = None

    for epoch in range(epochs):
        model.train()

        # Normalize entity embeddings (modulus only). Done in place and
        # outside autograd; the clamp keeps an all-zero row from turning
        # into NaN.
        with torch.no_grad():
            modulus_weight = model.entity_modulus.weight
            row_norms = torch.norm(modulus_weight, dim=1, keepdim=True).clamp_min(1e-12)
            modulus_weight.div_(row_norms)

        # Shuffle the training data; build the permutation on the same device
        # as the tensors it indexes.
        shuffled_indices = torch.randperm(num_triples, device=train_edge_type.device)
        shuffled_edge_index = train_edge_index[:, shuffled_indices]
        shuffled_edge_type = train_edge_type[shuffled_indices]

        # One corrupted triple per positive triple, resampled every epoch.
        negative_edge_index = create_corrupted_edge_index(shuffled_edge_index, shuffled_edge_type, data.num_entities)

        total_loss = 0
        total_size = 0

        for batch_start in range(0, num_triples, batch_size):
            batch_end = min(batch_start + batch_size, num_triples)
            current_batch = batch_end - batch_start
            batch_edge_index = shuffled_edge_index[:, batch_start:batch_end]
            batch_negative_edge_index = negative_edge_index[:, batch_start:batch_end]
            batch_edge_type = shuffled_edge_type[batch_start:batch_end]

            # The negative triple reuses the relation of its positive triple.
            positive_score = model(batch_edge_index[0], batch_edge_type, batch_edge_index[1])
            negative_score = model(batch_negative_edge_index[0], batch_edge_type, batch_negative_edge_index[1])

            loss = model.compute_loss(positive_score, negative_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a per-triple average
            # even when the last batch is short.
            total_loss += loss.item() * current_batch
            total_size += current_batch

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / total_size:.4f}")

        # Validation at regular intervals
        if (epoch + 1) % valid_freq == 0:
            mrr_score, mr_score, hits_at_10 = evaluate_model(
                model, valid_edge_index, valid_edge_type, data.num_entities, device
            )
            print(f"Validation score: MRR = {mrr_score:.4f}, MR = {mr_score:.4f}, Hits@10 = {hits_at_10:.4f}")

            # Only score the test split when validation MRR improves.
            if mrr_score > best_valid_score:
                best_valid_score = mrr_score
                test_scores = evaluate_model(
                    model, data.test_edge_index.to(device), data.test_edge_type.to(device), data.num_entities, device
                )

    print(f"Test scores from the best model (MRR, MR, Hits@10): {test_scores}")
    return test_scores


# Evaluate Model

## Prediction:
After training, the model can predict missing relationships by ranking possible tail entities for a given (head, relation, ?).
Example Query:
Input: (Steve Jobs, FounderOf, ?)
Output: Apple (highest-ranked entity).

In [None]:
def evaluate_model(model, edge_index, edge_type, num_entities, device, eval_batch_size=64):
    """Rank the true tail of every triple against all entities.

    For each ``(head, relation)`` pair the model scores every entity id in
    ``[0, num_entities)`` as a candidate tail. The resulting score matrix and
    the true tail ids are passed to ``mrr``, ``mr`` and ``hit_at_k`` (k=10),
    and the per-batch sums are averaged over all evaluated triples.

    Args:
        model: scorer called as ``model(heads, relations, tails)``; switched
            to eval mode here.
        edge_index: index tensor with heads in row 0 and tails in row 1.
        edge_type: relation id per triple.
        num_entities: number of candidate entities.
        device: device on which the candidate ids are created.
        eval_batch_size: triples per batch; every batch expands to
            ``batch * num_entities`` scored candidates.

    Returns:
        tuple: ``(mrr, mr, hits_at_10)`` averaged over the triples.
    """
    model.eval()
    num_triples = edge_type.size(0)
    num_batches = math.ceil(num_triples / eval_batch_size)

    reciprocal_rank_total = 0
    rank_total = 0
    hits_total = 0
    evaluated = 0

    with torch.no_grad():
        for batch_idx in range(num_batches):
            start = batch_idx * eval_batch_size
            stop = min(start + eval_batch_size, num_triples)

            heads = edge_index[0, start:stop]
            true_tails = edge_index[1, start:stop]
            relations = edge_type[start:stop]
            rows = relations.size(0)

            # Flattened (row-major) grid: each head/relation is repeated once
            # per candidate, and the candidate list is tiled once per row.
            candidates = torch.arange(num_entities, device=device)
            flat_heads = heads.repeat_interleave(num_entities)
            flat_relations = relations.repeat_interleave(num_entities)
            flat_tails = candidates.repeat(rows)

            scores = model(flat_heads, flat_relations, flat_tails).reshape(rows, -1)
            gt = true_tails.reshape(-1, 1)

            reciprocal_rank_total += mrr(scores, gt)
            rank_total += mr(scores, gt)
            hits_total += hit_at_k(scores, gt, device=device, k=10)
            evaluated += rows

    return (
        reciprocal_rank_total / evaluated,
        rank_total / evaluated,
        hits_total / evaluated,
    )

# Metric Functions
def mrr(predictions, gt):
    """Return the SUM of reciprocal ranks of the true entities (caller averages).

    ``predictions`` holds one row of scores per query (lower is better) and
    ``gt`` holds the true column index per row, shaped ``(rows, 1)``. Ranks
    are 1-based positions in the ascending sort of each row.
    """
    ascending_order = torch.argsort(predictions)
    # Column of the match in the sorted order == 0-based rank of the truth.
    zero_based_rank = torch.nonzero(ascending_order == gt)[:, 1]
    ranks = zero_based_rank.float() + 1.0
    reciprocal = 1.0 / ranks
    return reciprocal.sum().item()

def mr(predictions, gt):
    """Return the SUM of 1-based ranks of the true entities (caller averages).

    ``predictions`` holds one row of scores per query (lower is better) and
    ``gt`` holds the true column index per row, shaped ``(rows, 1)``.
    """
    ascending_order = torch.argsort(predictions)
    # Column of the match in the sorted order == 0-based rank of the truth.
    zero_based_rank = torch.nonzero(ascending_order == gt)[:, 1]
    ranks = zero_based_rank.float() + 1.0
    return ranks.sum().item()

def hit_at_k(predictions, gt, device, k=10):
    """Count the rows whose true entity is among the ``k`` lowest scores.

    Args:
        predictions: one row of scores per query; lower is better.
        gt: true column index per row, shaped ``(rows, 1)``.
        device: unused; kept so existing callers keep working.
        k: size of the cut-off. Values larger than the number of candidate
            columns are clamped, so a small graph no longer makes ``topk``
            raise.

    Returns:
        int: number of hits summed over the rows (the caller averages).
    """
    # topk raises if k exceeds the size of the ranked dimension.
    effective_k = min(k, predictions.size(-1))
    _, indices = predictions.topk(k=effective_k, largest=False)
    # Each row matches gt in at most one column, so the sum counts rows.
    return (indices == gt).sum().item()


# Start Training

## Positive Triplets:
The dataset provides positive examples in the form of valid (head, relation, tail) triplets.
## Negative Sampling:
For each positive triplet, a corrupted version is generated by replacing either the head or tail with a random entity.
## Loss Function:
The model uses margin-based ranking loss:
Ensures valid triplets are closer in embedding space than invalid ones by at least a predefined margin.

In [None]:
# --- Run configuration -------------------------------------------------------
lr = 0.1
use_gpu = torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')

# Same epoch budget on either device; validation runs twice on GPU
# (every 5 epochs) and once on CPU (every 10 epochs).
epochs = 10
valid_freq = 5 if use_gpu else 10

# --- Data ---------------------------------------------------------------------
# Load dataset using CustomDataset class. `data` is the same object as
# `dataset`; the edge tensors are attached to it as attributes below because
# train() reads them from there.
data_path = '/content/sample_data'
dataset = CustomDataset(data_path)
data = dataset

entity_dict = dataset.entity_dict  # entity dictionary loaded by the dataset
relation_dict = dataset.relation_dict  # relation dictionary loaded by the dataset

# Convert each split to (edge_index, edge_type) tensors and attach them.
train_edge_index, train_edge_type = dataset.get_edge_indices_and_types(dataset.train_data)
data.train_edge_index, data.train_edge_type = train_edge_index, train_edge_type

valid_edge_index, valid_edge_type = dataset.get_edge_indices_and_types(dataset.valid_data)
data.valid_edge_index, data.valid_edge_type = valid_edge_index, valid_edge_type

test_edge_index, test_edge_type = dataset.get_edge_indices_and_types(dataset.test_data)
data.test_edge_index, data.test_edge_type = test_edge_index, test_edge_type

# --- Model and optimiser --------------------------------------------------------
model = TransEEnhanced(data.num_entities, data.num_relations, embedding_dim=400, margin=2.0).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# --- Training -------------------------------------------------------------------
train(
    model,
    data,
    optimizer,
    device,
    entity_dict,
    relation_dict,
    epochs=epochs,
    batch_size=64,
    valid_freq=valid_freq,
)

Epoch 1/10, Loss: 66.7723
Epoch 2/10, Loss: 79.7920
Epoch 3/10, Loss: 80.0953
Epoch 4/10, Loss: 80.3803
Epoch 5/10, Loss: 80.4606
Validation score: MRR = 0.0064, MR = 7336.6015, Hits@10 = 0.0067
Epoch 6/10, Loss: 80.5395


# Example Workflow:
## Input:

**Dataset**: (Barack Obama, PresidentOf, United States), (Elon Musk, FounderOf, Tesla).
**Embedding Initialization**:

**Entities**: Barack Obama, United States, Elon Musk, Tesla.
**Relations**: PresidentOf, FounderOf.
**Training:**

**Positive Triplets**: (Barack Obama, PresidentOf, United States).
**Negative Sampling**: (Barack Obama, PresidentOf, RandomEntity).
**Evaluation:**

Metrics like **MRR, MR,** and **Hits@10** are computed during validation to measure the model’s performance.
**Prediction:**

**Query**: *(Elon Musk, FounderOf, ?)*
**Prediction**: Tesla (most likely tail).
