# Import dependencies

In [1]:
import os
import chromadb
import nltk
import pandas
import spacy
import transformers
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.vectorstores import Chroma
from nltk.stem import WordNetLemmatizer
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import MistralConfig
from tqdm.notebook import tqdm

transformers.logging.set_verbosity_error()

# Define helper functions

In [2]:
def extract_keywords(string):
    # Extract keywords from the prompt
    lemmatizer = WordNetLemmatizer()
    nlp = spacy.load("en_core_web_sm")
    # Extract keywords from the prompt
    doc = nlp(string)
    keywords = set()
    for chunk in doc.noun_chunks:
        if not chunk.text.lower().strip() in nltk.corpus.stopwords.words('english'):
            text_doc = nlp(chunk.text)
            # Remove indirect articles and convert to lowercase
            text_words = [token.text for token in text_doc if not token.is_stop]
            text = ' '.join(text_words)
            # Keyword must be longer than 2 chars to be valid
            if len(text) > 2:
                keywords.add(text.lower())
    # Convert keywords to their singular forms
    keywords = list(keywords)
    keywords_singular = [lemmatizer.lemmatize(word) for word in keywords]
    return keywords_singular

def contains_keywords_filter(keywords, docs):
    # Filter data by keywords
    filtered_data = []
    if len(keywords) > 0:
        for doc in docs:
            el = doc[0].page_content.lower()
            if any(keyword in el for keyword in keywords):
                filtered_data.append(doc)
        return filtered_data
    else:
        return []

def format_docs_for_LLM(docs):
    formated_documents = ""
    for idx, doc in enumerate(docs):
        page_content = "ID {}:\n".format(idx)
        page_content += "Title: {}\n".format(doc[0].metadata['title'])
        page_content += doc[0].page_content.replace("search_document: ", '', 1)
        page_content += "\n\n"
        formated_documents += page_content
    return formated_documents


def test_output_soft(model_output, ids_max_count=4, max_id=7):
    # empty list is a valid output
    if len(model_output) == 0:
        return True
    # if the list is not empty, at least one element
    # has to be valid
    for idx, el in enumerate(model_output):
        # if an elemnt is an empty string, continue
        if len(el) == 0:
            continue
        # if an elemnt is an integer in the range [-1, max_id], then it is valid 
        elif el.isdigit() and int(el) >= -1 and int(el) <= max_id:
            return True
        # if an element doesn't meet requirements, continue
        else:
            continue
    return False
    
def test_output_strict(model_output, ids_max_count=4, max_id=7):
    valid_els = []
    # empty list is a valid output
    if len(model_output) == 0:
        return True
    # single -1 value is a valid output
    elif len(model_output) == 1 and model_output[0].isdigit() and int(model_output[0]) == -1:
        return True
    # returning more ids than ids_max_count constitutes an invalid output 
    elif len(model_output) > ids_max_count:
        return False
    # if the list is not empty, and doesn't contain a single -1 value perform further tests
    for idx, el in enumerate(model_output):
        # if any elemnt is an empty string, then the output is invalid
        if len(el) == 0:
            return False
        # if an elemnt is an integer in the range [0, max_id], then it is valid 
        if el.isdigit() and int(el) >= 0 and int(el) <= max_id:
            valid_els.append(el)
        # if any element is invalid, then the whole output is invalid
        else:
            return False
    # if there are duplicated IDs, then the output is invalid
    if len(set(valid_els)) != len(model_output):
        return False
    # if all elements are valid, then the output is valid
    return True
                

# Prepare vectorstore

In [3]:
embedding_model = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    model_kwargs={
        'device': 'cuda',
        'trust_remote_code': True
    }
)

chroma_client = chromadb.PersistentClient(path='chroma_data')
langchain_vector_db = Chroma(client=chroma_client, embedding_function=embedding_model)

def search_vector_db(query, vector_db, k=512):
    query = 'search_query: ' + query
    most_similar_docs = vector_db.similarity_search_with_relevance_scores(query, k=k)
    return most_similar_docs

# Peform initial search to load everything into memory
search_vector_db("Sample query", langchain_vector_db, k=1);

<All keys matched successfully>


# Initiate the LLM pipeline

### Option 1 - benchmark the default model

In [4]:
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", config=MistralConfig, device_map='cuda')
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=8, device=0)
# LLM = HuggingFacePipeline(pipeline=pipe)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Option 2 - benchmark the fine-tunned model

In [4]:
base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", config=MistralConfig, device_map='cuda')
model = PeftModel.from_pretrained(base_model, os.path.join('fine_tuning', 'fine_tuned_models'))
model.load_adapter(os.path.join('fine_tuning', 'fine_tuned_models'), 'document_extraction_adapter')
model.set_adapter('document_extraction_adapter')
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# pipe = pipeline(task="text-generation", model=base_model, tokenizer=tokenizer, max_new_tokens=8, device=0)
# pipe.model = model
# LLM = HuggingFacePipeline(pipeline=pipe)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

# Define other variables

In [5]:
prompt_template = """<s>[INST] Below is a list of documents. Return up to 4 IDs of documents most useful for solving the user_prompt. If no documents are relevant, output -1. {format}. 

<documents>
{documents}
</documents>

user_prompt: {user_prompt}

[/INST]IDs: """

test_data = pandas.read_csv(os.path.join('data', 'benchmarks', 'RAG_test_data.csv'))
output_parser = CommaSeparatedListOutputParser()

# Run the benchmark

In [6]:
sample_count = len(test_data)
valid_outputs_standard = 0
valid_outputs_strict = 0
invalid_outputs_list = []
for idx, row in tqdm(test_data.iterrows(), total=sample_count):  
    user_prompt = row['text']
    docs_with_score = search_vector_db(user_prompt, langchain_vector_db)

    keywords = extract_keywords(user_prompt)
    filtered_docs = contains_keywords_filter(keywords, docs_with_score)
    filtered_docs = filtered_docs[:8]
    documents = format_docs_for_LLM(filtered_docs)

    # prompt = PromptTemplate.from_template(prompt_template)
    # chain = prompt | LLM
    # response = chain.invoke({'format':output_parser.get_format_instructions(), 'documents': documents, 'user_prompt': user_prompt})
    # converted_response = output_parser.parse(response)
    
    RAG_prompt = prompt_template.format(format=output_parser.get_format_instructions(), documents=documents, user_prompt=user_prompt)
    tokenized_context = tokenizer(RAG_prompt, return_tensors="pt").to('cuda')
    response = model.generate(tokenized_context.input_ids, attention_mask=tokenized_context.attention_mask, do_sample=False, max_new_tokens=8)
    response = response[0][tokenized_context.input_ids.shape[1]:] # Remove the input from the output
    output = tokenizer.decode(response, skip_special_tokens=True)
    document_IDs = output_parser.parse(output)
    if test_output_soft(document_IDs):
        valid_outputs_standard += 1
        if test_output_strict(document_IDs):
            valid_outputs_strict += 1
    else:
        invalid_outputs_list.append(document_IDs)

  0%|          | 0/225 [00:00<?, ?it/s]

# Display the results

In [7]:
print('Valid outputs: {}%'.format(valid_outputs_standard/sample_count * 100))
print('Valid outputs (strict): {}%'.format(valid_outputs_strict/sample_count * 100))

print('Invalid outpus:')
#print(invalid_outputs_list)

Valid outputs: 99.11111111111111%
Valid outputs (strict): 13.777777777777779%
Invalid outpus:
