# Import dependencies

In [1]:
import os
import chromadb
import nltk
import pandas
import spacy
import transformers
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.vectorstores import Chroma
from nltk.stem import WordNetLemmatizer
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import MistralConfig
from tqdm.notebook import tqdm

transformers.logging.set_verbosity_error()

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


# Define helper functions

In [2]:
def extract_keywords(string):
    # Extract keywords from the prompt
    lemmatizer = WordNetLemmatizer()
    nlp = spacy.load("en_core_web_sm")
    doc = nlp(string)
    keywords = []
    for chunk in doc.noun_chunks:
        if not chunk.text.lower().strip() in nltk.corpus.stopwords.words('english'):
            text = chunk.text
            # Remove indirect articles
            text = text.replace('a ', '').replace('an ', '').strip()
            keywords.append(text)
    # Convert keywords to their singular forms
    keywords_singular = [lemmatizer.lemmatize(word) for word in keywords]
    return keywords_singular

def contains_keywords_filter(keywords, docs):
    # Filter data by keywords
    filtered_data = []
    if len(keywords) > 0:
        for doc in docs:
            el = doc[0].page_content
            if any(keyword in el for keyword in keywords):
                filtered_data.append(doc)
        return filtered_data
    else:
        return docs

def format_docs_for_LLM(docs):
    formated_documents = ""
    for idx, doc in enumerate(docs):
        page_content = "ID {}:\n".format(idx)
        page_content += "Title: {}\n".format(doc[0].metadata['title'])
        page_content += doc[0].page_content.replace("search_document: ", '', 1)
        page_content += "\n\n"
        formated_documents += page_content
    return formated_documents


def test_output_soft(model_output, ids_max_count=4, max_id=7):
    is_output_valid = False
    if len(model_output) == 0:
        is_output_valid = True
    for idx, el in enumerate(model_output):
        if idx == ids_max_count:
            break
        if type(el) == str:
            if el[0].isdigit() and int(el[0]) >= 0 and int(el[0]) <= max_id:
                is_output_valid = True
            else:
                is_output_valid = False
                break
    return is_output_valid
    
def test_output_strict(model_output, ids_max_count=4, max_id=7):
    is_output_valid = False
    valid_els = []
    if len(model_output) == 0:
        is_output_valid = True
    for idx, el in enumerate(model_output):
        if idx == ids_max_count:
            break
        if type(el) == str:
            if len(model_output) > ids_max_count:
                break
                
            if el.isdigit() and int(el) >= 0 and int(el) <= max_id:
                is_output_valid = True
                valid_els.append(el)
            else:
                is_output_valid = False
                break
                
    if len(set(valid_els)) != len(model_output):
        is_output_valid = False
    return is_output_valid
                

# Prepare vectorstore

In [3]:
embedding_model = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    model_kwargs={
        'device': 'cuda',
        'trust_remote_code': True
    }
)

chroma_client = chromadb.PersistentClient(path='chroma_data')
langchain_vector_db = Chroma(client=chroma_client, embedding_function=embedding_model)

def search_vector_db(query, vector_db, k=512):
    query = 'search_query: ' + query
    most_similar_docs = vector_db.similarity_search_with_relevance_scores(query, k=k)
    return most_similar_docs

# Peform initial search to load everything into memory
search_vector_db("Sample query", langchain_vector_db, k=1);

You try to use a model that was created with version 2.4.0.dev0, however, your version is 2.3.1. This might cause unexpected behavior or errors. In that case, try to update to the latest version.



<All keys matched successfully>


# Initiate the LLM pipeline

### Option 1 - benchmark the default model

In [5]:
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", config=MistralConfig)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=16, device=0)
LLM = HuggingFacePipeline(pipeline=pipe)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Option 2 - benchmark the fine-tunned model

In [4]:
base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", config=MistralConfig)
model = PeftModel.from_pretrained(base_model, os.path.join('fine_tuning', 'fine_tuned_models'))
model.load_adapter(os.path.join('fine_tuning', 'fine_tuned_models'), 'document_extraction_adapter')
model.set_adapter('document_extraction_adapter')
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

pipe = pipeline(task="text-generation", model=base_model, tokenizer=tokenizer, max_new_tokens=16, device=0)
pipe.model = model
LLM = HuggingFacePipeline(pipeline=pipe)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

# Define other variables

In [5]:
prompt_template = """<s>[INST] Below is a list of documents. Return up to 4 IDs of documents most useful for solving the user_prompt. If no documents are relevant, output -1. {format}. 

<documents>
{documents}
</documents>

user_prompt: {user_prompt}

[/INST]IDs: """

test_data = pandas.read_csv(os.path.join('data', 'benchmarks', 'RAG_test_data.csv'))
output_parser = CommaSeparatedListOutputParser()

# Run the benchmark

In [7]:
sample_count = len(test_data)
valid_outputs_standard = 0
valid_outputs_strict = 0
invalid_outputs_list = []
for idx, row in tqdm(test_data.iterrows(), total=sample_count):  
    user_prompt = row['text']
    docs_with_score = search_vector_db(user_prompt, langchain_vector_db)

    keywords = extract_keywords(user_prompt)
    filtered_docs = contains_keywords_filter(keywords, docs_with_score)
    filtered_docs = filtered_docs[:8]
    documents = format_docs_for_LLM(filtered_docs)

    prompt = PromptTemplate.from_template(prompt_template)
    chain = prompt | LLM

    response = chain.invoke({'format':output_parser.get_format_instructions(), 'documents': documents, 'user_prompt': user_prompt})
    converted_response = output_parser.parse(response)
    if test_output_soft(converted_response):
        valid_outputs_standard += 1
        if test_output_strict(converted_response):
            valid_outputs_strict += 1
    else:
        invalid_outputs_list.append(converted_response)

  0%|          | 0/225 [00:00<?, ?it/s]



# Display the results

In [8]:
print('Valid outputs: {}%'.format(valid_outputs_standard/sample_count * 100))
print('Valid outputs (strict): {}%'.format(valid_outputs_strict/sample_count * 100))

print('Invalid outpus:')
print(invalid_outputs_list)

Valid outputs: 98.66666666666667%
Valid outputs (strict): 23.11111111111111%
Invalid outpus:
[['0', '5', '6. The documents discuss themes of love', 'night'], ['0', '3', '6. The documents contain various French texts', 'including'], ['0', '-1\n\nExplanation:\n\nDocument 0 may']]
