### Install Required Libraries

In [None]:
# %%capture
# Dependency installation, kept commented out. Uncomment the lines in this cell
# on a runtime that does not have these packages yet. When re-enabled,
# `%%capture` has to remain the very first line of the cell.

# !pip install llama-index llama-index-llms-huggingface llama-index-embeddings-huggingface transformers accelerate bitsandbytes llama-index-readers-web matplotlib flash-attn

# The `colab-new` extra suggests these were written for Google Colab - TODO confirm.
# !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install --no-deps xformers peft trl

### Load the Important Libraries

In [None]:
from llama_index.llms.huggingface import HuggingFaceLLM
from unsloth import FastLanguageModel
import torch
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import logging
import sys
from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex,StorageContext
from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from IPython.display import Markdown, display
# Route log records of level INFO and above to stdout so they show up in cell output.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
)


### Document Load
Load the text file(s) to index; each file contains multiple paragraphs.

In [None]:
# Directory that SimpleDirectoryReader scans for input files in the next cell.
data_path = "/content/data/"

In [None]:
# Split documents into small sentence-based chunks (128 tokens, 20 overlap).
node_parser = SentenceSplitter(chunk_size=128, chunk_overlap=20)

# Read every file under `data_path` with the reader's built-in loaders.
# `file_extractor` is meant to map file extensions to reader objects, so the
# sentence splitter must not be passed there; chunking is done explicitly below.
documents = SimpleDirectoryReader(data_path).load_data()

# One node per chunk; these nodes are what the manual indexing loops iterate over.
nodes = node_parser.get_nodes_from_documents(documents)

In [None]:
# Number of chunks (nodes) produced by the sentence splitter.
len(nodes)

In [None]:
# Sanity check: show the text of the first chunk.
nodes[0].text


## LLM & Embedding Settings

In [None]:
# --- Checkpoints --------------------------------------------------------------
model_name = 'unsloth/Phi-3-mini-4k-instruct'   # language model loaded via FastLanguageModel
embd_model_name = "BAAI/bge-small-en-v1.5"      # embedding model for Settings.embed_model

# --- Options for FastLanguageModel.from_pretrained ------------------------------
max_seq_length = 2048   # free to choose; per the original note, RoPE scaling is handled by the loader
dtype = None            # None = auto-detect (float16 on Tesla T4/V100, bfloat16 on Ampere+)
load_in_4bit = True     # 4-bit quantization to reduce memory usage; may be set to False

# --- Options handed to HuggingFaceLLM -------------------------------------------
model_kwargs = dict(trust_remote_code=True)
generate_seq_args = dict(do_sample=True, temperature=0.1)

# --- Graph storage used later by the default KnowledgeGraphIndex ----------------
graph_store = SimpleGraphStore()
storage_context = StorageContext.from_defaults(graph_store=graph_store)

In [None]:
# Load the language model and its tokenizer using the options from the settings cell.
loader_options = dict(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
model, tokenizer = FastLanguageModel.from_pretrained(**loader_options)

# Register the embedding model globally in the LlamaIndex settings.
Settings.embed_model = HuggingFaceEmbedding(model_name=embd_model_name)


In [None]:

def messages_to_prompt(messages):
  """Flatten chat messages into one prompt string using <|role|> ... <|end|> tags.

  Each message becomes ``<|role|>\\n{content}<|end|>\\n``. Roles other than
  "system", "user" and "assistant" are written as user turns. The prompt always
  ends with an open ``<|assistant|>`` turn, and a default system turn is put in
  front when none of the messages has the "system" role.

  Args:
      messages: Iterable of objects exposing ``role`` and ``content``.

  Returns:
      The assembled prompt string.
  """
  turns = []
  has_system = False
  for msg in messages:
      # Roles are compared with ``==`` (not looked up in a dict) so that role
      # objects which merely compare equal to these strings keep working.
      tag = next(
          (name for name in ("system", "user", "assistant") if msg.role == name),
          "user",
      )
      has_system = has_system or tag == "system"
      turns.append(f"<|{tag}|>\n{msg.content}<|end|>\n")

  # Open assistant turn for the model to complete.
  body = "".join(turns) + "<|assistant|>\n"

  if has_system:
      return body
  return "<|system|>\nYou are a helpful AI assistant.<|end|>\n" + body

# Wrapper applied around plain query text; `{query_str}` is the placeholder for
# the query. It uses the same <|system|>/<|user|>/<|assistant|> tags that
# messages_to_prompt emits. The system text previously read "thequery".
query_wrapper_prompt = (
    "<|system|>\n"
    "You are a helpful AI assistant, who is going to understand given knowledge graph. "
    "Your job is to understand the query and write a detailed answer<|end|>\n"
    "<|user|>\n"
    "{query_str}<|end|>\n"
    "<|assistant|>\n"
)

# Register the already-loaded model and tokenizer as the global LlamaIndex LLM.
Settings.llm = HuggingFaceLLM(
    model=model,
    tokenizer=tokenizer,
    model_kwargs=model_kwargs,
    generate_kwargs=generate_seq_args,
    query_wrapper_prompt=query_wrapper_prompt,
    messages_to_prompt=messages_to_prompt,
    is_chat_model=True,
)

### Extract Triplets by using seq-to-seq Model

In [None]:
# Load model and tokenizer
# Seq2seq checkpoint whose generated text is parsed into triplets further down.
rebel_checkpoint = "Babelscape/rebel-large"
bbl_tokenizer = AutoTokenizer.from_pretrained(rebel_checkpoint)
bbl_model = AutoModelForSeq2SeqLM.from_pretrained(rebel_checkpoint)

# Keyword arguments forwarded to bbl_model.generate(): three beams, and all
# three candidate sequences are returned for every input.
gen_kwargs = dict(
    max_length=256,
    length_penalty=0,
    num_beams=3,
    num_return_sequences=3,
)

In [None]:
# Run the REBEL model on the GPU when one is present; otherwise keep it on the
# CPU instead of failing. seq_to_seq_prediction moves its inputs to
# bbl_model.device, so it works with either placement.
bbl_model = bbl_model.to('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
def seq_to_seq_prediction(text):
  """Generate candidate output sequences for ``text`` with the REBEL model.

  The text is tokenized (truncated to 256 tokens), moved to the model's device
  and passed to ``generate`` together with the module-level ``gen_kwargs``.

  Args:
      text: Input text for the tokenizer.

  Returns:
      The decoded generated sequences, with special tokens kept
      (``skip_special_tokens=False``).
  """
  encoded = bbl_tokenizer(
      text,
      max_length=256,
      padding=True,
      truncation=True,
      return_tensors='pt',
  )
  device = bbl_model.device

  generated = bbl_model.generate(
      encoded["input_ids"].to(device),
      attention_mask=encoded["attention_mask"].to(device),
      **gen_kwargs,
  )

  return bbl_tokenizer.batch_decode(generated, skip_special_tokens=False)


In [None]:

def extract_triplets(text):
    """Parse one decoded sequence into a list of triplet dicts.

    The sequence is read as whitespace-separated tokens after ``<s>``,
    ``<pad>`` and ``</s>`` are removed. ``<triplet>`` opens a new head entity,
    ``<subj>`` opens a tail entity and ``<obj>`` opens the relation label;
    every other token is appended to whichever part is currently open. Tokens
    seen before the first marker are ignored.

    Args:
        text: Decoded model output with its special tokens kept.

    Returns:
        A list of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts in the
        order the triplets were completed.
    """
    results = []
    head_words, type_words, tail_words = [], [], []

    def record():
        # Words contain no whitespace, so joining them equals the stripped text.
        results.append({
            'head': ' '.join(head_words),
            'type': ' '.join(type_words),
            'tail': ' '.join(tail_words),
        })

    cleaned = text.strip()
    for special in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(special, "")

    target = None  # list that receives ordinary words; None until a marker is seen
    for word in cleaned.split():
        if word == "<triplet>":
            # New head: flush the triplet in progress if it has a relation.
            if type_words:
                record()
                type_words.clear()
            head_words.clear()
            target = head_words
        elif word == "<subj>":
            # New tail for the same head: flush, but keep the relation until <obj>.
            if type_words:
                record()
            tail_words.clear()
            target = tail_words
        elif word == "<obj>":
            type_words.clear()
            target = type_words
        elif target is not None:
            target.append(word)

    # The trailing triplet is only kept when all three parts are present.
    if head_words and type_words and tail_words:
        record()
    return results

In [None]:
# Start from an empty knowledge-graph index and fill it manually with the
# triplets extracted from each chunk.
bbl_index = KnowledgeGraphIndex.from_documents([], include_embeddings=True)

for idx, node in enumerate(nodes):
  # Several candidate sequences are generated per chunk and they largely repeat
  # the same facts, so remember what was already inserted for this chunk.
  seen_triplets = set()
  for pr in seq_to_seq_prediction(node.text):
    for trip in extract_triplets(text=pr):
      # Explicit (head, relation, tail) order instead of relying on dict order.
      kg_val = (trip['head'], trip['type'], trip['tail'])
      # extract_triplets can emit triplets with an empty head or tail on
      # malformed output; skip those as well as repeats.
      if not all(kg_val) or kg_val in seen_triplets:
        continue
      seen_triplets.add(kg_val)
      bbl_index.upsert_triplet_and_node(kg_val, node)
  # Report only after the chunk has actually been processed.
  print(f'Processed Triplets for sentence {idx+1}')


In [None]:
def draw_network(indx):
  """Draw the graph returned by ``indx.get_networkx_graph()`` with networkx defaults.

  Args:
      indx: Object exposing ``get_networkx_graph()``.
  """
  import networkx as nx

  nx.draw(indx.get_networkx_graph())


In [None]:
# Visualize the graph of the index filled from the seq2seq-extracted triplets.
draw_network(indx=bbl_index)

In [None]:
# Query engine over the manually filled index; same settings as the other engines.
bbl_query_engine = bbl_index.as_query_engine(
    response_mode="tree_summarize",
    include_text=False,
)

# Indexing with the Default Knowledge Graph Template of LlamaIndex

In [None]:
# Build the index straight from the loaded documents, letting
# KnowledgeGraphIndex do the triplet extraction (at most 2 triplets per chunk)
# and store the graph in the shared storage context.
default_index = KnowledgeGraphIndex.from_documents(
    documents,
    storage_context=storage_context,
    include_embeddings=True,
    max_triplets_per_chunk=2,
)

default_query_engine = default_index.as_query_engine(
    response_mode="tree_summarize",
    include_text=False,
)

# Get custom KG from LLM

In [None]:
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.prompts.prompt_type import PromptType

# Few-shot prompt asking the LLM for "(subject, predicate, object)" lines.
# `{text}` is the placeholder for the chunk to analyse. Each example's "Text:"
# and "Triplets:" parts sit on their own lines so that the examples have the
# same layout as the final "Text: {text}\nTriplets:\n" section the model
# completes.
DEFAULT_KG_TRIPLET_EXTRACT_TMPL = (
    "You have been given the text below and you need to extract up to 3 "
    "knowledge graph triplets in the form of (subject, predicate, object). You should not extract any punctuation.\n"
    "---------------------\n"
    "Example:\n"
    "Text: ITMS (15 micrograms.kg-1) was injected via standard dural puncture.\n"
    "Triplets:\n(ITMS, get injected, standard dural puncture)\n"
    "Text: Recurrent ulceration and mucosal tags are well-described oral manifestations of Crohn's disease.\n"
    "Triplets:\n"
    "(Recurrent ulceration, well-described, oral manifestations)\n"
    "(mucosal tags, well-described, oral manifestations)\n"
    "(oral manifestations, is related, Crohn's disease)\n"
    "---------------------\n"
    "Text: {text}\n"
    "Triplets:\n"
)

# Wrap the template so it can be formatted with `.format(text=...)`.
DEFAULT_KG_TRIPLET_EXTRACT_PROMPT = PromptTemplate(
    DEFAULT_KG_TRIPLET_EXTRACT_TMPL,
    prompt_type=PromptType.KNOWLEDGE_TRIPLET_EXTRACT,
)

In [None]:
def extract_prompt_triplets(text):
  """Ask the configured LLM for triplets describing ``text``.

  The triplet-extraction prompt is filled with ``text`` and completed by
  ``Settings.llm``.

  Args:
      text: Chunk of text to extract triplets from.

  Returns:
      The raw response split into lines; turning the lines into triplets is
      left to the caller.
  """
  filled_prompt = DEFAULT_KG_TRIPLET_EXTRACT_PROMPT.format(text=text)
  completion = Settings.llm.complete(filled_prompt)
  return completion.text.split('\n')

In [None]:
# Leftover inspection cell: shows the last value assigned to `kg_val` by the
# seq2seq indexing loop earlier in the notebook.
# NOTE(review): raises NameError on a fresh run if that loop never assigned it.
kg_val

In [None]:
# Start from an empty knowledge-graph index and fill it with the triplets the
# LLM writes in response to the custom extraction prompt.
custom_index = KnowledgeGraphIndex.from_documents([], include_embeddings=True)

for idx, node in enumerate(nodes):
  for trip in extract_prompt_triplets(text=node.text):
    # Drop parentheses, split on commas and keep the first three parts.
    kg_val = tuple(
        one_entity.replace('(', '').replace(')', '').strip()
        for one_entity in trip.split(',')
    )[:3]
    # Keep only complete triplets: lines such as ", ," yield three empty
    # strings and would otherwise be stored as empty graph entities.
    if len(kg_val) == 3 and all(kg_val):
      custom_index.upsert_triplet_and_node(kg_val, node)
  # Report only after the chunk has actually been processed.
  print(f'Processed Triplets for sentence {idx+1}')


In [None]:
# Query engine over the LLM-prompt-built index; same settings as the other engines.
custom_query_engine = custom_index.as_query_engine(
    response_mode="tree_summarize",
    include_text=False,
)

In [None]:
def format_response(response):
    """Show ``response`` as bold Markdown in the notebook output.

    Args:
        response: Any object; its string form is wrapped in ``<b>...</b>``.

    Returns:
        Whatever ``display`` returns.
    """
    bold_markdown = Markdown(f"<b>{response}</b>")
    return display(bold_markdown)

In [None]:
# Ask the index built with the default extraction and render the answer.
question = "What are the causing factor for unilateral facial weakness?"
response = default_query_engine.query(question)
format_response(response)

In [None]:
# Show the 'kg_rel_texts' metadata of the first source node of the latest response.
# NOTE(review): fails if the response has no source nodes or the key is absent.
response.source_nodes[0].metadata['kg_rel_texts']

In [None]:
# Ask the same question of the seq2seq-triplet index and render the answer.
question = "What are the causing factor for unilateral facial weakness?"
response = bbl_query_engine.query(question)
format_response(response)

In [None]:
# Show the 'kg_rel_texts' metadata of the first source node of the latest response.
# NOTE(review): fails if the response has no source nodes or the key is absent.
response.source_nodes[0].metadata['kg_rel_texts']

In [None]:
# Ask the same question of the LLM-prompt-triplet index and render the answer.
question = "What are the causing factor for unilateral facial weakness?"
response = custom_query_engine.query(question)
format_response(response)

In [None]:
# Show the 'kg_rel_texts' metadata of the first source node of the latest response.
# NOTE(review): fails if the response has no source nodes or the key is absent.
response.source_nodes[0].metadata['kg_rel_texts']