# Importing Modules

In [22]:
from transformers import pipeline
import torch
import numpy as np
import random
import os
from llama_index.core.graph_stores import SimpleGraphStore
from pprint import pprint

# Config

In [5]:
class Config:
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    seed = 42

# Loading Rebel Model

In [6]:
triplet_extractor = pipeline(
    'text2text-generation', 
    model='Babelscape/rebel-large', 
    tokenizer='Babelscape/rebel-large',
)

# Helper functions

In [7]:
def extract_triplets(text):
    """
    
    Function to extract triplets from a text chunk
    
    """
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
    return triplets

def set_seed(cls, seed: int = Config.seed):
    """
    
    Function to set the seed for the entire notebook
    
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(Config.seed)

# Running a sample example

In [8]:
text = "Punta Cana is a resort town in the municipality of Higuey, in La Altagracia Province, the eastern most province of the Dominican Republic"
extracted_text = triplet_extractor.tokenizer.batch_decode(
    [
        triplet_extractor(
            text, 
            return_tensors=True, 
            return_text=False
        )[0]["generated_token_ids"]
    ]
)

In [9]:
extracted_text[0]

'<s><triplet> Punta Cana <subj> La Altagracia Province <obj> located in the administrative territorial entity <subj> Dominican Republic <obj> country <triplet> Higuey <subj> La Altagracia Province <obj> located in the administrative territorial entity <subj> Dominican Republic <obj> country <triplet> La Altagracia Province <subj> Dominican Republic <obj> country <triplet> Dominican Republic <subj> La Altagracia Province <obj> contains administrative territorial entity</s>'

In [10]:
extracted_triplets = extract_triplets(extracted_text[0])
print(extracted_triplets)

[{'head': 'Punta Cana', 'type': 'located in the administrative territorial entity', 'tail': 'La Altagracia Province'}, {'head': 'Punta Cana', 'type': 'country', 'tail': 'Dominican Republic'}, {'head': 'Higuey', 'type': 'located in the administrative territorial entity', 'tail': 'La Altagracia Province'}, {'head': 'Higuey', 'type': 'country', 'tail': 'Dominican Republic'}, {'head': 'La Altagracia Province', 'type': 'country', 'tail': 'Dominican Republic'}, {'head': 'Dominican Republic', 'type': 'contains administrative territorial entity', 'tail': 'La Altagracia Province'}]


In [11]:
for et in extracted_triplets:
    print(et)

{'head': 'Punta Cana', 'type': 'located in the administrative territorial entity', 'tail': 'La Altagracia Province'}
{'head': 'Punta Cana', 'type': 'country', 'tail': 'Dominican Republic'}
{'head': 'Higuey', 'type': 'located in the administrative territorial entity', 'tail': 'La Altagracia Province'}
{'head': 'Higuey', 'type': 'country', 'tail': 'Dominican Republic'}
{'head': 'La Altagracia Province', 'type': 'country', 'tail': 'Dominican Republic'}
{'head': 'Dominican Republic', 'type': 'contains administrative territorial entity', 'tail': 'La Altagracia Province'}


# Loading Embedding Model

In [12]:
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer('BAAI/bge-large-zh-v1.5')

In [18]:
triplets = []
for et in extracted_triplets:
    triplets.append(str(tuple(et.values())))
triplets

["('Punta Cana', 'located in the administrative territorial entity', 'La Altagracia Province')",
 "('Punta Cana', 'country', 'Dominican Republic')",
 "('Higuey', 'located in the administrative territorial entity', 'La Altagracia Province')",
 "('Higuey', 'country', 'Dominican Republic')",
 "('La Altagracia Province', 'country', 'Dominican Republic')",
 "('Dominican Republic', 'contains administrative territorial entity', 'La Altagracia Province')"]

In [19]:
queries = ['What is Punta Cana?', 'Where is Punta Cana located?']
instruction = "Generate a representation for this sentence for use in retrieving related articles."

q_embeddings = embedder.encode([instruction+q for q in queries], normalize_embeddings=True)
p_embeddings = embedder.encode(triplets, normalize_embeddings=True)
scores = q_embeddings @ p_embeddings.T
scores

array([[0.5565389 , 0.5961919 , 0.44976604, 0.4719923 , 0.45241246,
        0.4475163 ],
       [0.62759954, 0.5839532 , 0.5231231 , 0.45175192, 0.45139718,
        0.44331726]], dtype=float32)

In [26]:
for idx, query in enumerate(queries):
    scores_mapper = dict(zip(triplets, scores[idx].tolist()))
    print(query)
    print("="*50)
    pprint(scores_mapper)

What is Punta Cana?
{"('Dominican Republic', 'contains administrative territorial entity', 'La Altagracia Province')": 0.4475162923336029,
 "('Higuey', 'country', 'Dominican Republic')": 0.4719923138618469,
 "('Higuey', 'located in the administrative territorial entity', 'La Altagracia Province')": 0.44976603984832764,
 "('La Altagracia Province', 'country', 'Dominican Republic')": 0.4524124562740326,
 "('Punta Cana', 'country', 'Dominican Republic')": 0.5961918830871582,
 "('Punta Cana', 'located in the administrative territorial entity', 'La Altagracia Province')": 0.5565388798713684}
Where is Punta Cana located?
{"('Dominican Republic', 'contains administrative territorial entity', 'La Altagracia Province')": 0.4433172643184662,
 "('Higuey', 'country', 'Dominican Republic')": 0.4517519176006317,
 "('Higuey', 'located in the administrative territorial entity', 'La Altagracia Province')": 0.5231230854988098,
 "('La Altagracia Province', 'country', 'Dominican Republic')": 0.45139718055