<a href="https://colab.research.google.com/github/Abhijit85/FederatedRAG/blob/main/TransGPT_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Libraries

In [None]:
# Whether a GPU is used to run this Colab.
# NOTE: the training cell further down reassigns `use_gpu` from
# torch.cuda.is_available(), so the value set here only acts as an initial default.
use_gpu = True
# Whether a custom GCE environment runs the Colab (uses a different CUDA version).
# NOTE(review): this flag is not read anywhere else in the visible notebook —
# confirm it is still needed.
custom_GCE_env = False

In [None]:
%pip install openai
%pip install python-dotenv
%pip install torch-geometric
%pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
%pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html

from dotenv import load_dotenv
from openai import OpenAI
import os
import re
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import math
from torch_geometric.data import InMemoryDataset, DataLoader
import torch_geometric

Collecting python-dotenv
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.0.1
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu111.html
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h

# Data Preparation and Processing

In [None]:
class FB15kDataset(torch_geometric.data.InMemoryDataset):
    r"""FB15k-237 knowledge-graph dataset derived from Freebase.
    Follows similar structure to torch_geometric.datasets.rel_link_pred_dataset

    The processed dataset holds a single :obj:`torch_geometric.data.Data`
    object with, for each split in ``train`` / ``valid`` / ``test``:

    * ``{split}_edge_index``: ``[2, num_triples]`` tensor of (head, tail) ids
    * ``{split}_edge_type``: ``[num_triples]`` tensor of relation ids

    plus the scalar attributes ``num_entities`` and ``num_relations``.

    Args:
      root (string): Root directory where the dataset should be saved.
      transform (callable, optional): A function/transform that takes in an
          :obj:`torch_geometric.data.Data` object and returns a transformed
          version. The data object will be transformed before every access.
          (default: :obj:`None`)
      pre_transform (callable, optional): A function/transform that takes in
          an :obj:`torch_geometric.data.Data` object and returns a
          transformed version. The data object will be transformed before
          being saved to disk. (default: :obj:`None`)
    """
    data_path = 'https://raw.githubusercontent.com/DeepGraphLearning/' \
                'KnowledgeGraphEmbedding/master/data/FB15k-237'

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # The processed file is produced locally by process() and contains a
        # pickled PyG Data object, not just raw tensors. Newer PyTorch releases
        # default to weights_only=True, which refuses such objects, so opt out
        # explicitly for this trusted, self-generated file.
        self.data, self.slices = torch.load(self.processed_paths[0],
                                            weights_only=False)

    @property
    def raw_file_names(self):
        """Names of the files fetched by :meth:`download`."""
        return ['train.txt', 'valid.txt', 'test.txt',
                'entities.dict', 'relations.dict']

    @property
    def processed_file_names(self):
        """Name of the single file written by :meth:`process`."""
        return ['data.pt']

    @property
    def raw_dir(self):
        """Directory (``<root>/raw``) where the raw files are stored."""
        return os.path.join(self.root, 'raw')

    def download(self):
        """Download every raw file from ``data_path`` into ``raw_dir``."""
        for file_name in self.raw_file_names:
            torch_geometric.data.download_url(f'{self.data_path}/{file_name}',
                                              self.raw_dir)

    @staticmethod
    def _read_tsv(path):
        """Return the tab-separated fields of every non-empty line in ``path``.

        Blank lines are skipped, so the last record is kept whether or not the
        file ends with a trailing newline.
        """
        with open(path, 'r') as f:
            stripped = (line.rstrip('\n') for line in f)
            return [line.split('\t') for line in stripped if line]

    def _read_index(self, file_name):
        """Load an ``<id>\\t<name>`` dictionary file as a ``name -> id`` mapping."""
        rows = self._read_tsv(os.path.join(self.raw_dir, file_name))
        return {name: int(idx) for idx, name in rows}

    def process(self):
        """Convert the raw text files into one ``Data`` object and save it."""
        entities_dict = self._read_index('entities.dict')
        relations_dict = self._read_index('relations.dict')

        kwargs = {}
        for split in ['train', 'valid', 'test']:
            # Each row is "<head>\t<relation>\t<tail>" using string identifiers.
            rows = self._read_tsv(os.path.join(self.raw_dir, f'{split}.txt'))
            heads = [entities_dict[row[0]] for row in rows]
            relations = [relations_dict[row[1]] for row in rows]
            tails = [entities_dict[row[2]] for row in rows]
            # Explicit dtype keeps index tensors integral even for an empty split.
            kwargs[f'{split}_edge_index'] = torch.tensor([heads, tails],
                                                         dtype=torch.long)
            kwargs[f'{split}_edge_type'] = torch.tensor(relations,
                                                        dtype=torch.long)

        _data = torch_geometric.data.Data(num_entities=len(entities_dict),
                                          num_relations=len(relations_dict),
                                          **kwargs)

        if self.pre_transform is not None:
            _data = self.pre_transform(_data)

        data, slices = self.collate([_data])

        torch.save((data, slices), self.processed_paths[0])

# Load dataset: the raw files and the processed file live under ./FB15k
# (the first run downloads and processes them, as the log below shows).
FB15k_dset = FB15kDataset(root='FB15k')
# The dataset holds a single Data object carrying the train/valid/test
# edge_index and edge_type tensors plus num_entities / num_relations.
data = FB15k_dset[0]

Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/train.txt
Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/valid.txt
Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/test.txt
Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/entities.dict
Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/relations.dict
Processing...
Done!
  self.data, self.slices = torch.load(self.processed_paths[0])


# TransGPT Model

In [None]:
class TransE(nn.Module):
    """TransE knowledge-graph embedding model with a margin ranking loss.

    Every entity and every relation owns one ``embedding_dim``-sized vector.
    A triplet (head, relation, tail) is scored by the L2 norm of
    ``head + relation - tail``; smaller means more plausible.

    Args:
        num_entities: number of rows in the entity embedding table.
        num_relations: number of rows in the relation embedding table.
        embedding_dim: length of each embedding vector.
        margin: margin passed to ``nn.MarginRankingLoss``.
        distance_metric: only stored on the instance; ``distance`` always
            uses the L2 norm regardless of this value.
        visualize: only stored on the instance; not read by this class.
    """

    def __init__(self, num_entities, num_relations, embedding_dim, margin, distance_metric='L1', visualize=False):
        super().__init__()
        self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embedding_dim)
        self.margin = margin
        self.distance_metric = distance_metric
        self.visualize = visualize

        # TransE-style initialisation: uniform in [-6/sqrt(dim), 6/sqrt(dim)].
        bound = 6 / np.sqrt(embedding_dim)
        for table in (self.entity_embeddings, self.relation_embeddings):
            table.weight.data.uniform_(-bound, bound)

    def forward(self, edge_index, negative_edge_index, edge_type):
        """Return the margin ranking loss between true and corrupted triplets."""
        true_dist = self.distance(edge_index, edge_type)
        corrupted_dist = self.distance(negative_edge_index, edge_type)
        return self.loss(true_dist, corrupted_dist)

    def predict(self, edge_index, edge_type):
        """Return the distance score of each triplet (lower is better)."""
        return self.distance(edge_index, edge_type)

    def distance(self, edge_index, edge_type):
        """Compute ``||h + r - t||_2`` per triplet, shaped ``[num_triples, 1]``.

        Row 0 of ``edge_index`` holds head ids, row 1 holds tail ids.
        """
        head_vecs = self.entity_embeddings(edge_index[0, :])
        relation_vecs = self.relation_embeddings(edge_type)
        tail_vecs = self.entity_embeddings(edge_index[1, :])
        translation_error = head_vecs + relation_vecs - tail_vecs
        return translation_error.norm(p=2., dim=1, keepdim=True)

    def loss(self, positive_distance, negative_distance):
        """Margin ranking loss pushing positive distances below negative ones.

        The target is -1 for every pair, which tells ``MarginRankingLoss``
        that the first input (positive distance) should be the smaller one.
        """
        target = torch.full(
            (positive_distance.size(0), 1),
            -1,
            dtype=torch.float,
            device=self.entity_embeddings.weight.device,
        )
        ranking_loss = nn.MarginRankingLoss(margin=self.margin)
        return ranking_loss(positive_distance, negative_distance, target)


# Helper function to create corrupted edges
def create_corrupted_edge_index(edge_index, edge_type, num_entities):
    """Build one negative sample per triplet by replacing its head or its tail.

    For every column of ``edge_index`` a coin flip picks which side is
    replaced by a uniformly drawn entity id in ``[0, num_entities)``; the
    other side is copied unchanged. ``edge_type`` is used only for its shape.

    Returns:
        A ``[2, num_triples]`` tensor of (head, tail) ids on the same device
        as ``edge_index``.
    """
    shape = edge_type.size()
    target_device = edge_index.device

    # Draw order matters for reproducibility: side selector first, then ids.
    side = torch.randint(high=2, size=shape, device=target_device)
    replacements = torch.randint(high=num_entities, size=shape,
                                 device=target_device)

    replace_head = side == 1  # 1 -> swap the head, 0 -> swap the tail
    original_heads, original_tails = edge_index[0, :], edge_index[1, :]
    new_heads = torch.where(replace_head, replacements, original_heads)
    new_tails = torch.where(~replace_head, replacements, original_tails)
    return torch.stack((new_heads, new_tails), dim=0)

# Helper function to prompt predictions to LLM

def refine_predictions_with_llm(predictions, batch_edge_index, batch_edge_type, model, entity_dict, relation_dict):
    """Ask an OpenAI chat model to review TransE tail predictions.

    Args:
        predictions: iterable of per-triplet score tensors; for each one the
            index of the smallest score is taken as the predicted tail id.
        batch_edge_index: tensor whose row 0 holds the head entity ids.
        batch_edge_type: tensor of relation ids, one per triplet.
        model: accepted for interface compatibility; not used here.
        entity_dict: mapping from entity id (int) to a readable name.
        relation_dict: mapping from relation id (int) to a readable name.

    Returns:
        The text content of the LLM reply (also printed).

    Raises:
        ValueError: if the ``OPENAI_API_KEY`` environment variable is unset
            or empty.
        KeyError: if an id is missing from ``entity_dict``/``relation_dict``.
    """
    # Decode the indices into human-readable entity and relation names
    head_entities = [entity_dict[int(head)] for head in batch_edge_index[0].tolist()]
    relation_names = [relation_dict[int(rel)] for rel in batch_edge_type.tolist()]
    predicted_tails = [entity_dict[int(torch.argmin(pred))] for pred in predictions]

    # Format the data for LLM
    input_data = [
        f"Head: {h}, Relation: {r}, Predicted Tail: {t}"
        for h, r, t in zip(head_entities, relation_names, predicted_tails)
    ]

    # Never hard-code the secret in the notebook: read it from the environment
    # (e.g. set via Colab secrets or a .env file loaded before this call).
    # Any key that was previously committed in this file must be considered
    # leaked and revoked.
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OpenAI API key not found in environment variables")

    # Prompt the LLM (GPT-3.5-turbo/GPT-4)
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4", # gpt-3.5-turbo
        messages=[
            {"role": "system", "content": "You are an expert in knowledge graphs and TransE embeddings."},
            {"role": "user", "content": "Please evaluate and improve the following predictions:\n" + "\n".join(input_data)}
        ]
    )

    # Parse LLM response
    refined_predictions = response.choices[0].message.content
    print("Refined Predictions from LLM:")
    print(refined_predictions)
    # Hand the text back so callers can use it instead of only reading stdout.
    return refined_predictions

# Train Function

## Model: TransE
**Embeddings:**
Each entity and relation is represented as a vector in a high-dimensional space.
The embeddings are initialized randomly and updated during training.
**Distance Metric:**
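TransE scores a triplet by the distance $\lVert \text{head} + \text{relation} - \text{tail} \rVert$, i.e. how far the tail embedding lies from the head embedding translated by the relation embedding; training minimizes this distance for valid triplets.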
A lower distance indicates a more likely relationship.

## How Is the LLM Used?
1. Prediction Refinement
After the TransE model predicts relationships (e.g., a tail entity for a given head and relation), these predictions are passed to the LLM.
The LLM evaluates the predictions, identifies errors, and suggests corrections or more plausible results.

In [None]:
def train(model, data, optimizer, device, entity_dict, relation_dict, epochs=50, batch_size=128, valid_freq=5):
    """Train a TransE-style model and track the test score of the best epoch.

    Each epoch renormalises the entity embeddings to unit L2 norm, shuffles
    the training triplets, draws one corrupted triplet per positive, and runs
    mini-batch updates. Every ``valid_freq`` epochs the model is evaluated on
    the validation split; whenever validation MRR improves, the test split is
    evaluated and those scores are remembered.

    Args:
        model: module exposing ``entity_embeddings`` and callable as
            ``model(edge_index, negative_edge_index, edge_type) -> loss``.
        data: object with ``{train,valid,test}_edge_index``,
            ``{train,valid,test}_edge_type`` and ``num_entities``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the tensors are moved to.
        entity_dict: accepted for interface compatibility; not used here.
        relation_dict: accepted for interface compatibility; not used here.
        epochs: number of passes over the training triplets.
        batch_size: number of triplets per optimisation step.
        valid_freq: validate every this many epochs.

    Returns:
        ``(MRR, MR, Hits@10)`` on the test split for the best validation MRR,
        or ``None`` if validation never ran or never improved on 0.
    """
    train_edge_index = data.train_edge_index.to(device)
    train_edge_type = data.train_edge_type.to(device)
    valid_edge_index = data.valid_edge_index.to(device)
    valid_edge_type = data.valid_edge_type.to(device)

    best_valid_score = 0
    test_scores = None

    for epoch in range(epochs):
        model.train()

        # Normalize entity embeddings in place (no autograd tracking) instead
        # of reassigning `.data`, so the parameter tensor itself is preserved.
        with torch.no_grad():
            entity_weights = model.entity_embeddings.weight
            entity_weights.div_(torch.norm(entity_weights, dim=1, keepdim=True))

        # Shuffle the training data; build the permutation on the same device
        # as the data to avoid a host-to-device index copy every epoch.
        num_triples = train_edge_type.size(0)
        shuffled_indices = torch.randperm(num_triples, device=train_edge_type.device)
        shuffled_edge_index = train_edge_index[:, shuffled_indices]
        shuffled_edge_type = train_edge_type[shuffled_indices]

        negative_edge_index = create_corrupted_edge_index(shuffled_edge_index, shuffled_edge_type, data.num_entities)

        total_loss = 0
        total_size = 0

        for batch_start in range(0, num_triples, batch_size):
            batch_end = min(batch_start + batch_size, num_triples)
            batch_edge_index = shuffled_edge_index[:, batch_start:batch_end]
            batch_negative_edge_index = negative_edge_index[:, batch_start:batch_end]
            batch_edge_type = shuffled_edge_type[batch_start:batch_end]

            loss = model(batch_edge_index, batch_negative_edge_index, batch_edge_type)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per-triplet.
            total_loss += loss.item() * (batch_end - batch_start)
            total_size += batch_end - batch_start

        # max(..., 1) avoids a ZeroDivisionError on an empty training split.
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / max(total_size, 1):.4f}")

        # Validation at regular intervals
        if (epoch + 1) % valid_freq == 0:
            mrr_score, mr_score, hits_at_10 = evaluate_model(
                model, valid_edge_index, valid_edge_type, data.num_entities, device
            )
            print(f"Validation score: MRR = {mrr_score:.4f}, MR = {mr_score:.4f}, Hits@10 = {hits_at_10:.4f}")

            # Track best validation score; the test split is only evaluated
            # when validation MRR improves.
            if mrr_score > best_valid_score:
                best_valid_score = mrr_score
                test_mrr, test_mr, test_hits_at_10 = evaluate_model(
                    model, data.test_edge_index.to(device), data.test_edge_type.to(device), data.num_entities, device
                )
                test_scores = (test_mrr, test_mr, test_hits_at_10)

    print(f"Test scores from the best model (MRR, MR, Hits@10): {test_scores}")
    return test_scores

# Evaluate Model

## Prediction:
After training, the model can predict missing relationships by ranking possible tail entities for a given (head, relation, ?).
Example Query:
Input: (Steve Jobs, FounderOf, ?)
Output: Apple (highest-ranked entity).

In [None]:
def evaluate_model(model, edge_index, edge_type, num_entities, device, eval_batch_size=64):
    """Rank every entity as a candidate tail and report MRR, MR and Hits@10.

    For each (head, relation, true_tail) triplet the model scores
    (head, relation, e) for all ``num_entities`` candidates ``e``; the
    metrics are computed from the position of the true tail when candidates
    are sorted by ascending score.

    Args:
        model: module with ``eval()`` and ``predict(edge_index, edge_type)``.
        edge_index: ``[2, num_triples]`` tensor (row 0 heads, row 1 tails).
        edge_type: ``[num_triples]`` tensor of relation ids.
        num_entities: number of candidate entities.
        device: device on which the candidate id tensor is created.
        eval_batch_size: number of triplets scored per batch.

    Returns:
        ``(mrr, mr, hits_at_10)`` averaged over all triplets.

    Raises:
        ZeroDivisionError: if ``edge_type`` is empty (nothing to average).
    """
    model.eval()
    num_triples = edge_type.size(0)
    reciprocal_rank_sum = 0
    rank_sum = 0
    hit_count = 0
    scored = 0

    num_batches = math.ceil(num_triples / eval_batch_size)
    with torch.no_grad():
        for batch_idx in range(num_batches):
            lo = batch_idx * eval_batch_size
            hi = min(lo + eval_batch_size, num_triples)
            heads = edge_index[0, lo:hi]
            true_tails = edge_index[1, lo:hi]
            relations = edge_type[lo:hi]
            n_batch = relations.size(0)

            # Flattened cross product: each triplet paired with every entity.
            candidate_tails = torch.arange(num_entities, device=device).repeat(n_batch)
            candidate_heads = heads.repeat_interleave(num_entities)
            candidate_relations = relations.repeat_interleave(num_entities)

            candidate_edges = torch.stack((candidate_heads, candidate_tails))
            scores = model.predict(candidate_edges, candidate_relations)
            scores = scores.reshape(n_batch, -1)  # one row of scores per triplet
            gt = true_tails.reshape(-1, 1)

            reciprocal_rank_sum += mrr(scores, gt)
            rank_sum += mr(scores, gt)
            hit_count += hit_at_k(scores, gt, device=device, k=10)
            scored += n_batch

    return (reciprocal_rank_sum / scored,
            rank_sum / scored,
            hit_count / scored)

# Metric Functions
def mrr(predictions, gt):
    """Sum of reciprocal ranks of the ground-truth entity over the batch.

    ``predictions`` is ``[batch, num_candidates]`` (lower score = better) and
    ``gt`` is ``[batch, 1]`` with the true candidate index per row. The
    caller divides the returned sum by the number of rows to get the mean.
    """
    ranking = predictions.argsort()
    # Column coordinate of each match = zero-based rank of the true entity.
    match_coords = (ranking == gt).nonzero()
    ranks = match_coords[:, 1].float().add(1.0)
    return ranks.reciprocal().sum().item()

def mr(predictions, gt):
    """Sum of (1-based) ranks of the ground-truth entity over the batch.

    ``predictions`` is ``[batch, num_candidates]`` (lower score = better) and
    ``gt`` is ``[batch, 1]`` with the true candidate index per row. The
    caller divides the returned sum by the number of rows to get the mean.
    """
    ranking = predictions.argsort()
    match_coords = (ranking == gt).nonzero()
    zero_based_ranks = match_coords[:, 1].float()
    return zero_based_ranks.add(1.0).sum().item()

def hit_at_k(predictions, gt, device, k=10):
    """Count rows whose ground-truth entity is among the ``k`` lowest scores.

    ``predictions`` is ``[batch, num_candidates]`` (lower score = better) and
    ``gt`` is ``[batch, 1]``. The one/zero indicator tensors are created on
    ``device``. Returns the number of hits as a Python int.
    """
    hit_value = torch.tensor([1], device=device)
    miss_value = torch.tensor([0], device=device)
    top_indices = predictions.topk(k=k, largest=False)[1]
    is_hit = top_indices == gt
    return torch.where(is_hit, hit_value, miss_value).sum().item()


# Start Training

## Positive Triplets:
The dataset provides positive examples in the form of valid (head, relation, tail) triplets.
## Negative Sampling:
For each positive triplet, a corrupted version is generated by replacing either the head or tail with a random entity.
## Loss Function:
The model uses margin-based ranking loss:
Ensures valid triplets are closer in embedding space than invalid ones by at least a predefined margin.

In [8]:
# Learning rate for plain SGD.
lr = 0.1

# Run the long schedule only when CUDA is available; on CPU do a short run
# that validates once, at the final epoch.
use_gpu = torch.cuda.is_available()
epochs, valid_freq = (80, 5) if use_gpu else (10, 10)

device = torch.device('cuda' if use_gpu else 'cpu')

# 400-dimensional TransE with margin 2.0, sized from the loaded dataset.
model = TransE(data.num_entities, data.num_relations, embedding_dim=400, margin=2.0)
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Placeholders: id -> name lookups are passed to train() but left empty here.
entity_dict = {}  # Load or define entity dictionary here
relation_dict = {}  # Load or define relation dictionary here

train(
    model,
    data,
    optimizer,
    device,
    entity_dict,
    relation_dict,
    epochs=epochs,
    valid_freq=valid_freq,
)

Epoch 1/80, Loss: 1.9707
Epoch 2/80, Loss: 1.9274
Epoch 3/80, Loss: 1.9024
Epoch 4/80, Loss: 1.8830
Epoch 5/80, Loss: 1.8662
Validation score: MRR = 0.2022, MR = 3387.3837, Hits@10 = 0.2926
Epoch 6/80, Loss: 1.8517
Epoch 7/80, Loss: 1.8387
Epoch 8/80, Loss: 1.8252
Epoch 9/80, Loss: 1.8134
Epoch 10/80, Loss: 1.8013
Validation score: MRR = 0.2211, MR = 2816.6613, Hits@10 = 0.3306
Epoch 11/80, Loss: 1.7910
Epoch 12/80, Loss: 1.7801
Epoch 13/80, Loss: 1.7705
Epoch 14/80, Loss: 1.7601
Epoch 15/80, Loss: 1.7524
Validation score: MRR = 0.2296, MR = 2441.1472, Hits@10 = 0.3510
Epoch 16/80, Loss: 1.7437
Epoch 17/80, Loss: 1.7352
Epoch 18/80, Loss: 1.7278
Epoch 19/80, Loss: 1.7207
Epoch 20/80, Loss: 1.7131
Validation score: MRR = 0.2344, MR = 2138.7070, Hits@10 = 0.3612
Epoch 21/80, Loss: 1.7051
Epoch 22/80, Loss: 1.6991
Epoch 23/80, Loss: 1.6921
Epoch 24/80, Loss: 1.6858
Epoch 25/80, Loss: 1.6781
Validation score: MRR = 0.2383, MR = 1888.2295, Hits@10 = 0.3694
Epoch 26/80, Loss: 1.6719
Epoch 27

# Example Workflow:
## Input:

**Dataset**: (Barack Obama, PresidentOf, United States), (Elon Musk, FounderOf, Tesla).
**Embedding Initialization**:

**Entities**: Barack Obama, United States, Elon Musk, Tesla.
**Relations**: PresidentOf, FounderOf.
**Training:**

**Positive Triplets**: (Barack Obama, PresidentOf, United States).
**Negative Sampling**: (Barack Obama, PresidentOf, RandomEntity).
**Evaluation:**

Metrics like **MRR, MR,** and **Hits@10** are computed during validation to measure the model’s performance.
**Prediction:**

**Query**: *(Elon Musk, FounderOf, ?)*
**Prediction**: Tesla (most likely tail).
