# Importing Modules

In [7]:
from transformers import pipeline
import torch
import numpy as np
import random
import os
from llama_index.core.graph_stores import SimpleGraphStore

  from .autonotebook import tqdm as notebook_tqdm


# Config

In [8]:
class Config:
    """Notebook-wide settings shared by the cells below."""
    # Use the GPU when torch can see one, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Seed handed to every random number generator for reproducible runs.
    seed = 42

# Loading the REBEL Model

In [9]:
# Load the REBEL relation-extraction checkpoint as a text2text-generation pipeline.
# Config.device is passed explicitly so the device selected in Config is the one
# the model actually runs on; without it the pipeline picks its own default device.
triplet_extractor = pipeline(
    'text2text-generation',
    model='Babelscape/rebel-large',
    tokenizer='Babelscape/rebel-large',
    device=Config.device,
)

# Helper functions

In [10]:
def extract_triplets(text):
    """
    Parse a decoded REBEL generation into a list of relation triplets.

    The expected layout of ``text`` is::

        <triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation ...]

    Each marker switches the field that the following tokens are appended to:
    tokens after ``<triplet>`` form the head, tokens after ``<subj>`` form the
    tail, and tokens after ``<obj>`` form the relation type. Several
    ``<subj> ... <obj> ...`` groups may follow one head; each yields its own
    triplet with that head.

    Args:
        text: Decoded model output. The ``<s>``, ``<pad>`` and ``</s>`` tokens
            are removed before parsing; the remaining text is split on
            whitespace.

    Returns:
        A list of dicts with the keys ``'head'``, ``'type'`` and ``'tail'``
        (all whitespace-stripped), in the order they appear in ``text``.
        Triplets with any empty field (e.g. from truncated or malformed
        output) are skipped rather than returned half-filled.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    # 'x' = no marker seen yet, so tokens before the first marker are ignored.
    current = 'x'

    def flush():
        # Emit the triplet collected so far, but only when it is complete.
        # The same rule is applied inside the loop and at the end of the text.
        if subject != '' and relation != '' and object_ != '':
            triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})

    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            # A new head starts: close the pending triplet and reset head and relation.
            current = 't'
            flush()
            relation = ''
            subject = ''
        elif token == "<subj>":
            # Another tail for the same head: close the pending triplet, keep the head.
            current = 's'
            flush()
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # The last triplet has no closing marker, so flush it explicitly.
    flush()
    return triplets

def set_seed(cls=None, seed: int = None):
    """
    Seed every random number generator used in this notebook.

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU and CUDA), puts
    cuDNN into deterministic mode with benchmarking disabled, and exports
    ``PYTHONHASHSEED``.

    Args:
        cls: Legacy first positional parameter, kept so existing calls keep
            working. When ``seed`` is not given and ``cls`` is an integer it is
            used as the seed, so ``set_seed(123)`` really seeds with 123
            instead of silently falling back to the default.
        seed: Seed value. When neither argument supplies an integer,
            ``Config.seed`` is used.
    """
    if seed is None:
        # A seed passed positionally lands in `cls`; honour it before using the default.
        seed = cls if isinstance(cls, (int, np.integer)) else Config.seed
    # Plain int so that random.seed and str() behave the same for NumPy integers.
    seed = int(seed)

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Hash randomisation is fixed when an interpreter starts, so this only
    # affects Python processes launched after this point, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)

# Apply the notebook-wide seed from Config before the example cells below run.
set_seed(Config.seed)

# Running an example

In [11]:
# Run REBEL on a single sentence, asking for token ids rather than text,
# then decode them ourselves so the marker tokens stay in the string.
text = "Punta Cana is a resort town in the municipality of Higuey, in La Altagracia Province, the eastern most province of the Dominican Republic"
generation_outputs = triplet_extractor(text, return_tensors=True, return_text=False)
generated_token_ids = generation_outputs[0]["generated_token_ids"]
extracted_text = triplet_extractor.tokenizer.batch_decode([generated_token_ids])

In [12]:
# Show the decoded generation for the single input sentence.
extracted_text[0]

'<s><triplet> Punta Cana <subj> La Altagracia Province <obj> located in the administrative territorial entity <subj> Dominican Republic <obj> country <triplet> Higuey <subj> La Altagracia Province <obj> located in the administrative territorial entity <subj> Dominican Republic <obj> country <triplet> La Altagracia Province <subj> Dominican Republic <obj> country <triplet> Dominican Republic <subj> La Altagracia Province <obj> contains administrative territorial entity</s>'

In [13]:
# Parse the decoded string into a list of {'head', 'type', 'tail'} dicts.
extracted_triplets = extract_triplets(extracted_text[0])
print(extracted_triplets)

[{'head': 'Punta Cana', 'type': 'located in the administrative territorial entity', 'tail': 'La Altagracia Province'}, {'head': 'Punta Cana', 'type': 'country', 'tail': 'Dominican Republic'}, {'head': 'Higuey', 'type': 'located in the administrative territorial entity', 'tail': 'La Altagracia Province'}, {'head': 'Higuey', 'type': 'country', 'tail': 'Dominican Republic'}, {'head': 'La Altagracia Province', 'type': 'country', 'tail': 'Dominican Republic'}, {'head': 'Dominican Republic', 'type': 'contains administrative territorial entity', 'tail': 'La Altagracia Province'}]


In [15]:
# Print one triplet per line so each relation can be read on its own.
for triplet in extracted_triplets:
    print(triplet)

{'head': 'Punta Cana', 'type': 'located in the administrative territorial entity', 'tail': 'La Altagracia Province'}
{'head': 'Punta Cana', 'type': 'country', 'tail': 'Dominican Republic'}
{'head': 'Higuey', 'type': 'located in the administrative territorial entity', 'tail': 'La Altagracia Province'}
{'head': 'Higuey', 'type': 'country', 'tail': 'Dominican Republic'}
{'head': 'La Altagracia Province', 'type': 'country', 'tail': 'Dominican Republic'}
{'head': 'Dominican Republic', 'type': 'contains administrative territorial entity', 'tail': 'La Altagracia Province'}


# Loading Embedding Model

In [16]:
from sentence_transformers import SentenceTransformer

# Load the sentence-embedding model (weights are downloaded on first use).
# NOTE(review): the '-zh' suffix looks like the Chinese-language BGE checkpoint,
# while the sample text in this notebook is English — confirm this is intended
# (the English counterpart is presumably 'BAAI/bge-large-en-v1.5').
embedder = SentenceTransformer('BAAI/bge-large-zh-v1.5')

modules.json: 100%|██████████| 349/349 [00:00<00:00, 751kB/s]
config_sentence_transformers.json: 100%|██████████| 124/124 [00:00<00:00, 647kB/s]
README.md: 100%|██████████| 30.3k/30.3k [00:00<00:00, 40.2MB/s]
sentence_bert_config.json: 100%|██████████| 52.0/52.0 [00:00<00:00, 36.8kB/s]
config.json: 100%|██████████| 1.00k/1.00k [00:00<00:00, 5.43MB/s]
model.safetensors: 100%|██████████| 1.30G/1.30G [01:30<00:00, 14.3MB/s]
tokenizer_config.json: 100%|██████████| 394/394 [00:00<00:00, 673kB/s]
vocab.txt: 100%|██████████| 110k/110k [00:00<00:00, 478kB/s]
tokenizer.json: 100%|██████████| 439k/439k [00:00<00:00, 637kB/s]
special_tokens_map.json: 100%|██████████| 125/125 [00:00<00:00, 572kB/s]
1_Pooling/config.json: 100%|██████████| 191/191 [00:00<00:00, 638kB/s]
