# KG generation with customized LM

In [1]:
import os
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.llms import HuggingFaceLLM
from llama_index.prompts import PromptTemplate
from transformers import BitsAndBytesConfig
from IPython.display import Markdown, display
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    BitsAndBytesConfig
)
from typing import Optional, List, Mapping, Any, Tuple
from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
from llama_index import (
    ServiceContext, 
    SimpleDirectoryReader, 
#     LangchainEmbedding, 
#     ListIndex,
    KnowledgeGraphIndex
)
from llama_index.callbacks import CallbackManager
from llama_index.llms import (
    CustomLLM, 
    CompletionResponse, 
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.storage.storage_context import StorageContext
from llama_index.graph_stores import NebulaGraphStore
from llama_index.llms.base import llm_completion_callback

### 4bit quantization

In [2]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

In [4]:
# set context window size
context_window = 2048
# set number of output tokens
num_output = 256

# model_name = "bofenghuang/vigostral-7b-chat"
# model_name = "Open-Orca/Mistral-7B-OpenOrca"
model_name = "HuggingFaceH4/zephyr-7b-beta"
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", quantization_config=quantization_config, trust_remote_code=True
)

Downloading shards:   0%|          | 0/8 [00:00<?, ?it/s]

Downloading (…)of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

Downloading (…)of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

Downloading (…)of-00008.safetensors:   0%|          | 0.00/816M [00:00<?, ?B/s]


BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64
Loading CUDA version: BNB_CUDA_VERSION=116


  warn((f'\n\n{"="*80}\n'


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

### Config the model

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [12]:
DEFAULT_KG_TRIPLET_EXTRACT_TMPL = (
    "Some text is provided below. Given the text, extract up to "
    "{max_knowledge_triplets} "
    "knowledge triplets in the form of (subject, predicate, object). Avoid stopwords.\n"
    "---------------------\n"
    "Example:"
    "Text: Alice is Bob's mother."
    "Triplets:\n(Alice, is mother of, Bob)\n"
    "Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
    "Triplets:\n"
    "(Philz, is, coffee shop)\n"
    "(Philz, founded in, Berkeley)\n"
    "(Philz, founded in, 1982)\n"
    "---------------------\n"
    "Text: {text}\n"
    "Triplets:\n"
)
encoded = tokenizer.encode("Les chercheurs ont découvert une nouvelle espèce de papillon, qui vit exclusivement dans les forêts tropicales humides.")
len(encoded)

31

In [18]:
test = "La découverte de la pénicilline par Alexander Fleming en 1928 a révolutionné le domaine médical. Cette avancée majeure a ouvert la voie à la fabrication d'antibiotiques, qui ont depuis sauvé d'innombrables vies. La pénicilline, un antibiotique naturel produit par le champignon Penicillium, a démontré son efficacité dans le traitement des infections bactériennes. La capacité de la pénicilline à tuer les bactéries a été le sujet de nombreuses études scientifiques. Ces recherches ont permis de comprendre comment la pénicilline agit en inhibant la synthèse de la paroi cellulaire des bactéries. Cette découverte a jeté les bases de l'utilisation des antibiotiques dans le traitement des infections et a ouvert la porte à de futures avancées dans le domaine de la médecine."
encoded = tokenizer.encode(test)
print(len(test.split(" ")))
len(encoded)

119


236

In [None]:
class ZephyrLLM(CustomLLM):
    context_window: int = 2048
    num_output: int = 256
    model_name: str = "custom"
    dummy_response: str = "My response"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window,
            num_output=num_output,
            model_name=model_name
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt_length = len(prompt)
        response = pipeline(prompt, max_new_tokens=num_output)[0]["generated_text"]

        # only return newly generated tokens
        text = response[prompt_length:]
        return CompletionResponse(text=text)
    
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()

In [5]:
class FalconLLM(CustomLLM):

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window,
            num_output=num_output,
            model_name=model_name
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt_length = len(prompt)
        response = pipeline(prompt, max_new_tokens=num_output)[0]["generated_text"]

        # only return newly generated tokens
        text = response[prompt_length:]
        return CompletionResponse(text=text)
    
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()


In [6]:
context_window = 2048
# set number of output tokens
num_output = 256
chunk_size = 512

# define our LLM
llm = FalconLLM()

embed_model = HuggingFaceBgeEmbeddings(model_name="dangvantuan/sentence-camembert-large")

service_context = ServiceContext.from_defaults(
    llm=llm, 
    embed_model=embed_model,
    context_window=context_window, 
    chunk_size=chunk_size,
    num_output=num_output
)

No sentence-transformers model found with name /home/xli/.cache/torch/sentence_transformers/dangvantuan_sentence-camembert-large. Creating a new one with MEAN pooling.
Loading the tokenizer from the `special_tokens_map.json` and the `added_tokens.json` will be removed in `transformers 5`,  it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again. You will see the new `added_tokens_decoder` attribute that will store the relevant information.


### Create nebula space

In [7]:
%load_ext ngql
connection_string = f"--address 127.0.0.1 --port 9669 --user root --password nebula"
%ngql {connection_string}

Connection Pool Created


Unnamed: 0,Name
0,digital_safety
1,rag_workshop


rag_workshop (index used for demo)

In [None]:
%%ngql
CREATE SPACE IF NOT EXISTS rag_workshop(vid_type=FIXED_STRING(256), partition_num=1, replica_factor=1);
USE rag_workshop;
CREATE TAG IF NOT EXISTS entity(name string);
CREATE EDGE IF NOT EXISTS relationship(relationship string);
CREATE TAG INDEX IF NOT EXISTS entity_index ON entity(name(256));

digital_safety

In [13]:
%%ngql
CREATE SPACE IF NOT EXISTS digital_safety(vid_type=FIXED_STRING(256), partition_num=1, replica_factor=1);
USE digital_safety;
CREATE TAG IF NOT EXISTS entity(name string);
CREATE EDGE IF NOT EXISTS relationship(relationship string);
CREATE TAG INDEX IF NOT EXISTS entity_index ON entity(name(256));

In [8]:
os.environ['NEBULA_USER'] = "root"
os.environ['NEBULA_PASSWORD'] = "nebula"
os.environ['NEBULA_ADDRESS'] = "127.0.0.1:9669"

space_name = "digital_safety"
edge_types, rel_prop_names = ["relationship"], ["relationship"]
tags = ["entity"]

graph_store = NebulaGraphStore(
    space_name=space_name,
    edge_types=edge_types,
    rel_prop_names=rel_prop_names,
    tags=tags,
)
storage_context = StorageContext.from_defaults(graph_store=graph_store)

## Generate KG

### Falcon 7b

### Mistral 7b

In [9]:
def chat_completion(messages):
    prompt = ""
    for m in messages:
        if m["role"]=="user":
            prompt += "[INST]" +  m["content"] + "[/INST]"
        elif m["role"]=="assistant":
            prompt += "<s>" + m["content"] + "</s>"
    
    return prompt.strip()

In [11]:
test = "The company's CEO, who is known for his innovative ideas, announced a groundbreaking partnership with a global tech giant."
assistant_message = """
You're an NLP expert, your task is to extract the triplets (subject, predicate, object) presented in given sentences.
Here is an example:
Text: She loves ice cream.
Subject: She
Predicate: loves
Object: ice cream
Here is another example:
Text: Despite his initial reluctance, John successfully completed the challenging project.
Subject: John
Predicate: completed
Object: challenging project
"""
messages = [
    {"role": "assistant", "content": assistant_message},
    {"role": "user", "content": f"What is the subject, predicate and object in this sentence: {test} ?"}
]
completed_prompt = chat_completion(messages)
print(completed_prompt)

<s>
You're an NLP expert, your task is to extract the triplets (subject, predicate, object) presented in given sentences.
Here is an example:
Text: She loves ice cream.
Subject: She
Predicate: loves
Object: ice cream
Here is another example:
Text: Despite his initial reluctance, John successfully completed the challenging project.
Subject: John
Predicate: completed
Object: challenging project
</s>[INST]What is the subject, predicate and object in this sentence: The company's CEO, who is known for his innovative ideas, announced a groundbreaking partnership with a global tech giant. ?[/INST]


In [None]:
class MistralKGIndex(KnowledgeGraphIndex):

    @staticmethod
    def _parse_triplet_response(
        response: str, max_length: int = 128
    ) -> List[Tuple[str, str, str]]:
        knowledge_strs = response.strip().split("\n")
        results = []
        for text in knowledge_strs:
            if not text or text[0] != "(" or text[-1] != ")":
                # skip empty lines and non-triplets
                continue
            tokens = text[1:-1].split(",")
            if len(tokens) != 3:
                continue

            if any(len(s.encode("utf-8")) > max_length for s in tokens):
                # We count byte-length instead of len() for UTF-8 chars,
                # will skip if any of the tokens are too long.
                # This is normally due to a poorly formatted triplet
                # extraction, in more serious KG building cases
                # we'll need NLP models to better extract triplets.
                continue

            subj, pred, obj = map(str.strip, tokens)
            if not subj or not pred or not obj:
                # skip partial triplets
                continue
            results.append((subj, pred, obj))
        return results

In [15]:
kg_index = MistralKGIndex.from_documents(
    documents,
    storage_context=storage_context,
    service_context=service_context,
    kg_triple_extract_template=prompt,
    max_triplets_per_chunk=10,
    space_name=space_name,
    edge_types=edge_types,
    rel_prop_names=rel_prop_names,
    tags=tags,
)









































..............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

In [16]:
kg_index.storage_context.persist(persist_dir='/mnt/c/Users/xli.ASSYSTEM/Documents/Digital safety/data/fr_embed_storage_graph')

In [17]:
%ngql USE digital_safety;
%ngql MATCH ()-[e]->() RETURN e

Unnamed: 0,e
0,"(""stock"")-[:relationship@-1179187948981095342{..."
1,"(""new york"")-[:relationship@219232360648163354..."
2,"(""is"")-[:relationship@-8433750674388488407{rel..."
3,"(""is"")-[:relationship@-7462421433918528579{rel..."
4,"(""is"")-[:relationship@-7462421433918528579{rel..."
5,"(""the"")-[:relationship@-7462421433918528579{re..."
6,"(""installations"")-[:relationship@4754892476402..."
7,"(""exchange"")-[:relationship@392570128547222489..."
8,"(""torch relay"")-[:relationship@346655847438871..."
9,"(""The"")-[:relationship@-7810131783760243813{re..."


In [18]:
%ng_draw

<class 'pyvis.network.Network'> |N|=25 |E|=21

## Load previously generated KG

In [20]:
from llama_index import load_index_from_storage

storage_context = StorageContext.from_defaults(persist_dir='/mnt/c/Users/xli.ASSYSTEM/Documents/Digital safety/data/fr_embed_storage_graph', graph_store=graph_store)
kg_index = load_index_from_storage(
    storage_context=storage_context,
    service_context=service_context,
    max_triplets_per_chunk=10,
    space_name=space_name,
    edge_types=edge_types,
    rel_prop_names=rel_prop_names,
    tags=tags,
    verbose=True,
)

In [24]:
%ngql USE rag_workshop;
%ngql MATCH ()-[e]->() RETURN e

Unnamed: 0,e
0,"(""évents"")-[:relationship@-3431488967660501439..."
1,"(""zone_de_surpression"")-[:relationship@-738609..."
2,"(""zone_de_surpression"")-[:relationship@-738609..."
3,"(""zone_de_surpression"")-[:relationship@-738609..."
4,"(""zone_de_surpression"")-[:relationship@-738609..."
...,...
972,"(""fumées"")-[:relationship@-3640410747914980111..."
973,"(""fonctionnement normal"")-[:relationship@-7396..."
974,"(""Implantation"")-[:relationship@65928469527764..."
975,"(""Implantation"")-[:relationship@65928469527764..."


In [25]:
%ng_draw

<class 'pyvis.network.Network'> |N|=898 |E|=977

In [35]:
kg_index.graph_store

<llama_index.graph_stores.nebulagraph.NebulaGraphStore at 0x7f983175d7e0>

In [42]:
import networkx as nx
import matplotlib.pyplot as plt
g = kg_index.get_networkx_graph()
for n in g.nodes:
    print(n)

UP1
usine
objet
champ d’application
Usine
Marcoule
Usine UP1
Usine Marcoule
Philz
1982
Berkeley
coffee shop
PT
générale
risque
opération
atelier
section
libellé
Plutonium
limit
criticité
sûreté
H
PT spécifique
I
J
K
L
M
N
O
P
L.1
L.3
L.4
L.6
L.8
L.10
L.12
L.14
L.16
L.18
L.20
L.22
L.24
L.26
L.28
L.30
L.32
L.34
L.36
P.1
P.3
effluents
solutions actives
assainissement
masse
soluble
cumul
inférieure
bat 117
traitement
text: Philz
RDS
R0
Page
10
/ 10
013413
9
32
1
boîte à gants
procédé
démantelée
MAR 09 013413
site
bâtiment 100
bâtiment 117
température
température minimale
température maximale
température moyenne
température moyenne des mois d’hiver
température moyenne des mois d’été
température de 30°C
nombre de jours de gelée sous abri
humidité de l’air
précipitation
vent dominant
vitesse moyenne des vents
mistral de l’ordre de 70 à 80 km/h
barrière dynamique
vitesse de passage
barrière statique
taux de renouvellement
page_label
2
ventilation
e
t
l
o
,
<
[
s
-
 
n
a
i
m
]
g
f
>
û
é
r
u
q
d

In [None]:
kg_index_query_engine = kg_index.as_query_engine(
    retriever_mode="keyword",
    verbose=True,
    response_mode="tree_summarize",
)
response_graph_rag = kg_index_query_engine.query("Résume moi")

display(Markdown(f"<b>{response_graph_rag}</b>"))

In [None]:
response_graph_rag = kg_index_query_engine.query("Résume moi")

display(Markdown(f"<b>{response_graph_rag}</b>"))

In [None]:

index = ListIndex.from_documents(documents, service_context=service_context)

# Query and print response
query_engine = index.as_query_engine()
response = query_engine.query("Quel est le titre du chapitre 5 ?")
print(response)

In [None]:
index.save_to_disk("/mnt/c/Users/xli.ASSYSTEM/Documents/Digital safety/data")

In [None]:
index.index_id

In [None]:
index.storage_context.persist("/mnt/c/Users/xli.ASSYSTEM/Documents/Digital safety/data/index")

In [None]:
from llama_index import StorageContext, load_index_from_storage
storage_context = StorageContext.from_defaults(persist_dir="/mnt/c/Users/xli.ASSYSTEM/Documents/Digital safety/data/index")
index = ListIndex.from_documents(documents, service_context=service_context)

new_index = load_index_from_storage(storage_context, service_context=service_context)
new_query_engine = new_index.as_query_engine()
response = new_query_engine.query("Quel est le titre du chapitre 5 ?")
print(response)