# Installs
Restart the kernel after installing the packages below.

In [None]:
%%time

from IPython.display import clear_output

! pip install -qq -U langchain-huggingface
! pip install -qq -U langchain-community
! pip install -qq -U langchain
! pip install -qq -U rouge_score
! pip install -qq -U bitsandbytes
! pip install -qq -U accelerate
! pip install -qq -U faiss-gpu
! pip install -qq -U peft
! pip install -qq -U torch

clear_output()

# Imports

In [1]:
%%time

from IPython.display import clear_output
import warnings
warnings.filterwarnings("ignore")

import os
import glob
import textwrap
import time
import pandas as pd
from datasets import Dataset, load_metric
from tqdm import tqdm
import re
from sklearn.model_selection import train_test_split
from peft import LoraConfig, get_peft_model, TaskType

import langchain

### loaders
from langchain.document_loaders import DirectoryLoader, TextLoader

### splits
from langchain.text_splitter import RecursiveCharacterTextSplitter

### prompts
from langchain import PromptTemplate, LLMChain

### vector stores
from langchain.vectorstores import FAISS

### models
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings

### retrievers
from langchain.chains import RetrievalQA

import torch
import transformers
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
    BitsAndBytesConfig, pipeline, GenerationConfig,
    Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
)

clear_output()

CPU times: user 9.36 s, sys: 1.52 s, total: 10.9 s
Wall time: 15.1 s


# CFG

In [51]:
class CFG:
    """Run configuration, read throughout the notebook as class attributes (e.g. `CFG.model_name`)."""

    # ---- LLM ----
    # A bigger LLM such as wizardLM or llama2-7b-chat can be used without fine-tuning.
    # Flan-T5 is small and easy to fine-tune, so fine-tuning is recommended for it.
    # Other checkpoints tried: TheBloke/wizardLM-7B-HF, llama2-7b-chat, mistral-7B.
    model_name: str = 'google/flan-t5-base'
    fine_tune_with_LoRA: bool = True  # True: LoRA-tune the seq2seq model; False: 4-bit causal LM

    # ---- generation ----
    temperature: float = 0
    top_p: float = 0.95
    repetition_penalty: float = 1.15

    # ---- document splitting ----
    split_chunk_size: int = 400
    split_overlap: int = 0

    # ---- retrieval ----
    k: int = 2  # number of similar passages retrieved per question

    # ---- vector-database embedding model ----
    embedding_model: str = 'sentence-transformers/all-mpnet-base-v2'

    # ---- paths ----
    DOCs_path: str = '/kaggle/input/questionanswer-dataset/text_data/text_data'
    Output_folder: str = './rag-vectordb'

# Preprocess Data

In [241]:
# The S08 question/answer pairs are stored as a tab-separated text file.
qa_pairs_path = '/kaggle/input/questionanswer-dataset/S08_question_answer_pairs.txt'
df_S08 = pd.read_csv(qa_pairs_path, sep='\t')
df_S08.head()

Unnamed: 0,ArticleTitle,Question,Answer,DifficultyFromQuestioner,DifficultyFromAnswerer,ArticleFile
0,Abraham_Lincoln,Was Abraham Lincoln the sixteenth President of...,yes,easy,easy,S08_set3_a4
1,Abraham_Lincoln,Was Abraham Lincoln the sixteenth President of...,Yes.,easy,easy,S08_set3_a4
2,Abraham_Lincoln,Did Lincoln sign the National Banking Act of 1...,yes,easy,medium,S08_set3_a4
3,Abraham_Lincoln,Did Lincoln sign the National Banking Act of 1...,Yes.,easy,easy,S08_set3_a4
4,Abraham_Lincoln,Did his mother die of pneumonia?,no,easy,medium,S08_set3_a4


In [242]:
print(f"Before removing NULL values: {df_S08.shape}")

# Drop rows only when a column used downstream is missing (title prefixing,
# retrieval and labels). A missing difficulty rating or article file should not
# discard an otherwise valid question/answer pair.
df_S08 = df_S08.dropna(subset=['ArticleTitle', 'Question', 'Answer'])

print(f"After removing NULL values: {df_S08.shape}")

Before removing NULL values: (1715, 6)
After removing NULL values: (1150, 6)


In [243]:
# The same question can appear on several rows with differently worded answers
# (see the head() output above); keep only the first row for each question text.
df_S08 = df_S08.drop_duplicates(subset=['Question'], keep='first')

print(f"After removing duplicates: {df_S08.shape}")

After removing duplicates: (602, 6)


In [244]:
def _with_title_prefix(row):
    """Prefix the question with its article title unless every title word already occurs in it."""
    title = row['ArticleTitle'].replace('_', ' ')
    if all(word in row['Question'] for word in title.split()):
        return row['Question']
    return title + ". " + row['Question']


# Many questions refer to their subject only by pronoun ("Did his mother ..."), so
# the title gives the retriever the article name to match on.
df_S08['Question'] = df_S08.apply(_with_title_prefix, axis=1)

# Loader

In [8]:
# Load every cleaned S08 article file from the corpus directory.
# NOTE(review): retrieved text later in the notebook contains mojibake such as
# "puá¹á¸Ã¡rÄ«ka", which suggests at least some files are UTF-8 being decoded
# as ISO-8859-1. Confirm the real encoding.
loader_settings = {
    "glob": "S08*.txt.clean",
    "loader_cls": TextLoader,
    "show_progress": True,
    "use_multithreading": True,
    "loader_kwargs": {"encoding": "ISO-8859-1"},
}
loader = DirectoryLoader(CFG.DOCs_path, **loader_settings)

documents = loader.load()

100%|██████████| 40/40 [00:00<00:00, 566.26it/s]


In [9]:
page_count = len(documents)
print(f'We have {page_count} pages in total')

We have 40 pages in total


In [10]:
# Peek at the raw (uncleaned) text of the first document.
first_page_preview = documents[0].page_content[:600]
print(first_page_preview)

otter



Otters are amphibious (or in one case aquatic) carnivorous mammals.  The otter subfamily Lutrinae forms part of the family Mustelidae, which also includes weasels, polecats, badgers, as well as others. With 13 species in 7 genera, otters have an almost worldwide distribution.

An otter's den is called a holt.  Male otters are dog-otters, females are bitches and babies are cubs or pups.  The collective noun romp is sometimes used for a group of otters, being descriptive of their often playful nature.




Otters have long, slim bodies and relatively short limbs, with webbed paws. Most h


In [11]:
# Function to clean text
def clean_text(text):
    """Collapse every run of whitespace (spaces, tabs, newlines) into one space and trim the ends.

    The result is a single line: paragraph breaks are NOT preserved. A separate
    newline-collapsing pass would be a no-op, because `\\s` already matches newlines.

    Args:
        text: raw document text.

    Returns:
        str: the whitespace-normalised text.
    """
    return re.sub(r'\s+', ' ', text).strip()

# Normalise the whitespace of every loaded document in place.
for document in documents:
    document.page_content = clean_text(document.page_content)

cleaned_preview = documents[0].page_content[:600]
print(cleaned_preview)

otter Otters are amphibious (or in one case aquatic) carnivorous mammals. The otter subfamily Lutrinae forms part of the family Mustelidae, which also includes weasels, polecats, badgers, as well as others. With 13 species in 7 genera, otters have an almost worldwide distribution. An otter's den is called a holt. Male otters are dog-otters, females are bitches and babies are cubs or pups. The collective noun romp is sometimes used for a group of otters, being descriptive of their often playful nature. Otters have long, slim bodies and relatively short limbs, with webbed paws. Most have sharp c


# Splitter

In [12]:
# Cut the cleaned pages into passages for the vector store.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_overlap=CFG.split_overlap,
    chunk_size=CFG.split_chunk_size,
)
texts = text_splitter.split_documents(documents)

print(f'We have created {len(texts)} chunks from {len(documents)} pages')

We have created 3162 chunks from 40 pages


# Create Embeddings

In [13]:
%%time

vectordb = FAISS.from_documents(
    texts,
    HuggingFaceEmbeddings(model_name=CFG.embedding_model)
)

### persist vector database
vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")
#vectordb = FAISS.load_local(f"{CFG.Output_folder}/faiss_index_rag", HuggingFaceEmbeddings(model_name=CFG.embedding_model), allow_dangerous_deserialization=True)

clear_output()

CPU times: user 15.3 s, sys: 1.03 s, total: 16.3 s
Wall time: 17.8 s


In [245]:
# `search_type` is an argument of `as_retriever` itself. Inside `search_kwargs` it is
# only passed through to the vector store's search call and does not select the mode.
retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": CFG.k})

# Retrieve the top-k chunks for every question and join them into one context string.
contexts = []
for question in tqdm(df_S08['Question'], desc="Fetching contexts"):
    results = retriever.invoke(question)
    contexts.append(" ".join(doc.page_content for doc in results))

# One context per row, in the same order as df_S08['Question'].
df_S08['Context'] = contexts

# Display the dataframe to verify
df_S08[['Question', 'Context']].head()

Fetching contexts: 100%|██████████| 602/602 [00:10<00:00, 56.07it/s]


Unnamed: 0,Question,Context
0,Was Abraham Lincoln the sixteenth President of...,"Abraham Lincoln Abraham Lincoln (February 12, ..."
2,Abraham Lincoln. Did Lincoln sign the National...,"Transcontinental Railroad, which was completed..."
4,Abraham Lincoln. Did his mother die of pneumonia?,born. Theodore Roosevelt's mother Mittie died ...
6,Abraham Lincoln. How many long was Lincoln's f...,"a frequent visitor to Kentucky, he would have ..."
8,Abraham Lincoln. When did Lincoln begin his po...,"not like killing animals, even for food. Thoug..."


# Define model

In [246]:
%%time

model_repo = CFG.model_name
        
tokenizer = AutoTokenizer.from_pretrained(model_repo)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

if CFG.fine_tune_with_LoRA:
    
    def base_model_init():
        return AutoModelForSeq2SeqLM.from_pretrained(model_repo, torch_dtype=torch.bfloat16, device_map = 'auto',)
    
    base_model = base_model_init()
    
    max_len = base_model.config.n_positions
    
else:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        model_repo,
        device_map = 'auto',
        quantization_config = bnb_config,
        low_cpu_mem_usage = True,
        trust_remote_code = True,
    )
    max_len = base_model.config.max_position_embeddings

clear_output()

CPU times: user 1.28 s, sys: 298 ms, total: 1.58 s
Wall time: 1.5 s


# Prompt

In [247]:
# The model is told to answer only from the retrieved passages; {context} and
# {question} are filled in for each example.
prompt_template = """
Use only the following pieces of context to answer the question.

{context}

Question: {question}
Answer:"""

PROMPT = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

# Tokenization & train/validation split

In [248]:
def tokenize_function(row):
    """Tokenize one QA row into seq2seq training features.

    The encoder input is PROMPT filled with the row's retrieved context and question.
    The labels are the token ids of the reference answer. Both are truncated to
    `max_len` tokens.

    Args:
        row: a DataFrame row with 'Question', 'Context' and 'Answer' fields.

    Returns:
        pd.Series with 'input_ids', 'attention_mask' and 'labels' token-id lists.
    """
    formatted_prompt = PROMPT.format(question=row['Question'], context=row['Context'])
    # NOTE(review): the question comes after the context in PROMPT, so truncation
    # removes the question first if a context is ever longer than max_len.
    inputs = tokenizer(formatted_prompt, max_length=max_len, truncation=True)
    # `text_target` tokenizes the answer as decoder-side text, which matters for
    # tokenizers that treat source and target text differently.
    labels = tokenizer(text_target=row["Answer"], max_length=max_len, truncation=True)

    return pd.Series({
        'input_ids': inputs.input_ids,
        'attention_mask': inputs.attention_mask,
        'labels': labels.input_ids
    })

if CFG.fine_tune_with_LoRA:
    # Register `progress_apply` on pandas objects, then tokenize every row with a progress bar.
    tqdm.pandas(desc="Tokenizing rows")
    token_columns = ['input_ids', 'attention_mask', 'labels']
    df_S08[token_columns] = df_S08.progress_apply(tokenize_function, axis=1)

Tokenizing rows: 100%|██████████| 602/602 [00:00<00:00, 726.40it/s]


In [249]:
hf_dataset = Dataset.from_pandas(df_S08)

# Named `dataset_splits` so it does not shadow sklearn's `train_test_split` function
# imported at the top of the notebook.
dataset_splits = hf_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = dataset_splits['train']
eval_dataset = dataset_splits['test']

print("Shapes of the datasets:")
print(f"Training: {len(train_dataset)} samples")
print(f"Validation: {len(eval_dataset)} samples")

Shapes of the datasets:
Training: 481 samples
Validation: 121 samples


# Perform PEFT with LoRA

In [250]:
%%time

peft_model = None

if CFG.fine_tune_with_LoRA:
    base_model = base_model_init()
    
    lora_config = LoraConfig(
        r=32, # Rank
        lora_alpha=32,
        target_modules=["q", "v"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM
    )

    peft_model = get_peft_model(base_model, lora_config)
    
clear_output()

CPU times: user 1.58 s, sys: 217 ms, total: 1.8 s
Wall time: 1.52 s


In [251]:
if CFG.fine_tune_with_LoRA:
    import inspect

    output_dir = f'./peft-qa-training-{str(int(time.time()))}'
    batch_size = 8

    # Newer transformers releases renamed `evaluation_strategy` -> `eval_strategy`
    # (training arguments) and `tokenizer` -> `processing_class` (trainer), and later
    # dropped the old names. The install cell upgrades packages with `-U`, so pick
    # whichever keyword the installed version accepts.
    eval_strategy_key = (
        "eval_strategy"
        if "eval_strategy" in inspect.signature(Seq2SeqTrainingArguments).parameters
        else "evaluation_strategy"
    )
    tokenizer_key = (
        "processing_class"
        if "processing_class" in inspect.signature(Seq2SeqTrainer.__init__).parameters
        else "tokenizer"
    )

    peft_training_args = Seq2SeqTrainingArguments(
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,

        num_train_epochs=1,
        learning_rate=1e-3,
        lr_scheduler_type="cosine",
        warmup_ratio=0.01,

        # Evaluation runs every `logging_steps` steps (no separate eval_steps is set).
        logging_steps=15,

        output_dir=output_dir,
        report_to="none",
        **{eval_strategy_key: "steps"},
    )

    # Pads inputs and labels per batch.
    data_collator = DataCollatorForSeq2Seq(tokenizer)

    peft_trainer = Seq2SeqTrainer(
        model=peft_model,
        args=peft_training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        **{tokenizer_key: tokenizer},
    )

In [252]:
# Only the LoRA configuration builds a trainer; keep the TrainOutput for inspection.
if CFG.fine_tune_with_LoRA:
    train_result = peft_trainer.train()

Step,Training Loss,Validation Loss
15,1.8404,1.392334
30,1.5883,1.328857
45,1.699,1.338013
60,1.7971,1.327759


# 🤗 Pipeline & Generation

In [253]:
# Generation settings shared by both pipeline variants.
generation_config = GenerationConfig(
    max_length=max_len,
    temperature=CFG.temperature,
    top_p=CFG.top_p,
    repetition_penalty=CFG.repetition_penalty,
)
# NOTE(review): for the causal "text-generation" branch, `max_length` may also count the
# prompt tokens; `max_new_tokens` might be the intended limit there. Confirm.

# The LoRA branch is an encoder-decoder (seq2seq) model; the other branch loads a causal LM.
if CFG.fine_tune_with_LoRA:
    task, generation_model = "text2text-generation", peft_model
else:
    task, generation_model = "text-generation", base_model
generation_model.eval()

pipe = pipeline(
    task=task,
    model=generation_model,
    tokenizer=tokenizer,
    device_map="auto",
    truncation=True,
    generation_config=generation_config,
)

llm = HuggingFacePipeline(pipeline=pipe)

clear_output()

In [254]:
def check_llm_response(dataset, indx):
    """Print the LLM's answer for one dataset example next to the reference answer.

    The example's stored 'Context' is inserted into PROMPT; no new retrieval is done.

    Args:
        dataset: indexable dataset whose rows have 'Question', 'Context' and 'Answer'.
        indx: position of the example to check.
    """
    example = dataset[indx]
    prompt_text = PROMPT.format(question=example['Question'], context=example['Context'])

    response = llm.invoke(prompt_text)
    # The text2text pipeline returns only the generated answer; prepend the prompt so
    # the printout shows what the model saw (the text-generation pipeline presumably
    # already echoes it).
    if CFG.fine_tune_with_LoRA:
        response = prompt_text + " " + response

    print(response)
    print(f"\nCorrect Answer: {example['Answer']}")

# To tune pipeline arguments on failures, call check_llm_response(wrong_ans, 0) once
# `wrong_ans` has been computed in the Evaluations section below.
check_llm_response(dataset=eval_dataset, indx=0)


Use only the following pieces of context to answer the question.

tail reaches 60 to 110cm. Shoulder height is 45 to 80 cm. Males are considerably larger than females and weigh 37 to 90 kg compared to 28 to 60 kg for females. Ronald M. Nowak: Walker's Mammals of the World. Johns Hopkins University Press, 1999 ISBN 0-8018-5789-9 One of many spotted cats, a leopard may be mistaken for a cheetah or a jaguar. The leopard has rosettes rather than cheetah's simple puá¹á¸Ã¡rÄ«ka ("tiger", among other things), then borrowed into Greek. The leopard is an agile and graceful predator. Although smaller than the other members of Panthera, the leopard is still able to take large prey given a massive skull that well utilizes powerful jaw muscles. Its body is comparatively long for a cat and its legs are short. Head and body length is between 90 and 190 cm, the

Question: How long is a leopard's tail?
Answer: 60 to 110cm

Correct Answer: 60 to 110cm


# Retriever chain

In [255]:
# `search_type` is an argument of `as_retriever` itself. Inside `search_kwargs` it is
# only passed through to the vector store's search call and does not select the mode.
retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": CFG.k})

# "stuff" puts all k retrieved chunks into a single PROMPT call
# (alternatives: map_reduce, map_rerank, refine).
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": PROMPT},
    return_source_documents=True,
    verbose=False,
)

# Post-process outputs

In [256]:
def wrap_text_preserve_newlines(text, width=700):
    """Wrap each line of `text` to at most `width` characters, keeping existing line breaks.

    Args:
        text: text that may contain newline characters.
        width: maximum line width passed to `textwrap.fill`.

    Returns:
        str: the re-wrapped text, with the original lines still separated by newlines.
    """
    return '\n'.join(
        textwrap.fill(segment, width=width)
        for segment in text.split('\n')
    )


def process_llm_response(llm_response):
    """Format a RetrievalQA result as the wrapped answer followed by its sources.

    Args:
        llm_response: dict returned by `qa_chain.invoke`, with a 'result' string and
            'source_documents' (documents with `metadata` and `page_content`).

    Returns:
        str: "<answer>\\n\\nSources: \\n<entries>". Each entry is the source file name
        without directory or extension(s), then " - page: N" when the metadata has a
        page, then the chunk text when fine-tuning with LoRA.
    """
    ans = wrap_text_preserve_newlines(llm_response['result'])

    def source_label(source):
        # Strip the directory and *all* extensions. The corpus files are named like
        # "S08_set3_a1.txt.clean", so cutting a fixed 4 characters left "S08_set3_a1.txt.c".
        label = os.path.basename(source.metadata['source']).split('.', 1)[0]
        if 'page' in source.metadata:
            label += ' - page: ' + str(source.metadata['page'])
        if CFG.fine_tune_with_LoRA:
            label += f'\nContent: {source.page_content}'
        return label

    sources_used = ' \n'.join(source_label(source) for source in llm_response['source_documents'])

    return ans + '\n\nSources: \n' + sources_used

In [257]:
def llm_ans(query):
    """Run the retrieval QA chain on `query` and return the formatted answer with sources.

    With LoRA fine-tuning, the text is prefixed with "Question: ...\\nLLM Answer: " so
    that the answer follows an "Answer:" marker.
    """
    formatted = process_llm_response(qa_chain.invoke(query))
    if not CFG.fine_tune_with_LoRA:
        return formatted
    return f"Question: {query}\nLLM Answer: " + formatted

# Evaluations
- Check the model on a single sample
- Calculate the average ROUGE recall on the validation dataset
- Inspect the wrong (low-recall) answers

In [258]:
# Load the ROUGE metric
# NOTE(review): `datasets.load_metric` is deprecated in recent `datasets` releases. The
# scoring code below relies on its aggregated output (`score['rouge1'].mid.recall`), so
# pin a `datasets` version that still provides it. TODO: confirm the version in use.
metric = load_metric("rouge")

In [259]:
def extract_prediction(llm_output):
    """Return the answer text from an `llm_ans`-style output string.

    The answer is the text between the first "Answer:" marker and the next "Answer:"
    or "Sources:" marker. If the output has no "Answer:" marker, the whole text
    (up to "Sources:") is used instead of raising IndexError, so one malformed
    generation does not abort a batched evaluation.

    Args:
        llm_output: text produced by `llm_ans`.

    Returns:
        str: the stripped answer text.
    """
    segments = llm_output.split("Answer:")
    answer_part = segments[1] if len(segments) > 1 else llm_output
    return answer_part.split("Sources:")[0].strip()

def evaluate_answer(dataset, indx):
    """Run one dataset question through the QA chain and print the answer and its ROUGE recall.

    Args:
        dataset: indexable dataset whose rows have 'Question' and 'Answer' fields.
        indx: position of the example to evaluate.
    """
    row = dataset[indx]
    query, correct_answer = row['Question'], row['Answer']

    # Full chain: retrieval -> prompt -> LLM -> formatted answer with sources.
    pred_ans = llm_ans(query)
    print(pred_ans)
    print(f"\nCorrect Answer: {correct_answer}")

    rouge_score = metric.compute(
        predictions=[extract_prediction(pred_ans)],
        references=[correct_answer],
        use_stemmer=True,
    )

    # Recall (not F-measure): the share of the reference answer's n-grams found in the prediction.
    recall_by_type = {
        rouge_type: rouge_score[rouge_type].mid.recall
        for rouge_type in ('rouge1', 'rouge2', 'rougeL', 'rougeLsum')
    }
    print(f"\nROUGE Recall Scores: {recall_by_type}")

# Spot-check one validation example end to end (retrieval, generation, ROUGE).
evaluate_answer(dataset=eval_dataset, indx=4)

Question: John Adams. With what party did Adams run for presidency?
LLM Answer: Federalist

Sources: 
S08_set3_a1.txt.c
Content: Hamilton. Because of Adams's seniority and the need for a northern president, he was elected as the Federalist nominee for president in 1796, over Thomas Jefferson, the leader of the opposition Democratic-Republican Party. His success was due to peace and prosperity; Washington and Hamilton had averted war with Britain by the Jay Treaty of 1795. Ferling (1992) pp 316-32 Adams' two terms as Vice 
S08_set3_a1.txt.c
Content: Quincy rather than actively campaign for the Presidency. He wanted to stay out of what he called the silly and wicked game. His party, however, campaigned for him, while the Republicans campaigned for Jefferson. It was expected that Adams would dominate the votes in New England, while Jefferson was expected to win in the Southern states. In the end, Adams won the election by a narrow margin of 71

Correct Answer: The Federalist Party

ROUGE 

In [260]:
# Define a function to make predictions
def predict(batch):
    """Answer a batch of questions with the QA chain and score each answer.

    Intended for `Dataset.map(batched=True)`.

    Args:
        batch: dict of column lists containing at least 'Question' and 'Answer'.

    Returns:
        dict with per-example 'prediction' (extracted answer text), 'reference'
        (the gold answer) and 'recalls' (ROUGE-1 recall of prediction vs reference).
    """
    extracted_preds = [extract_prediction(llm_ans(query)) for query in batch['Question']]

    # Use the same ROUGE settings (stemming on) as the aggregate and single-example
    # evaluations, so the per-example recalls used to flag wrong answers agree with
    # the reported scores.
    recalls = [
        metric.compute(predictions=[pred], references=[ref], use_stemmer=True)['rouge1'].mid.recall
        for pred, ref in zip(extracted_preds, batch['Answer'])
    ]

    return {
        'prediction': extracted_preds,
        'reference': batch['Answer'],
        'recalls': recalls
    }

# Run the QA chain over the whole validation split, 16 questions per `predict` call.
# The mapped dataset keeps eval_dataset's row order.
predicted_dataset = eval_dataset.map(predict, batched=True, batch_size=16, desc="Processing predictions")

Processing predictions:   0%|          | 0/8 [00:00<?, ?ba/s]

In [262]:
# Aggregate ROUGE over every validation prediction (with stemming).
rouge_score = metric.compute(
    references=predicted_dataset['reference'],
    predictions=predicted_dataset['prediction'],
    use_stemmer=True,
    use_aggregator=True,
)

# Despite the variable name, these are recall values (share of reference n-grams recovered).
fmeasures = {}
for rouge_type in ('rouge1', 'rouge2', 'rougeL', 'rougeLsum'):
    fmeasures[rouge_type] = rouge_score[rouge_type].mid.recall

print(f"ROUGE Recall Scores: {fmeasures}")

ROUGE Recall Scores: {'rouge1': 0.6733852916203682, 'rouge2': 0.20900857515690052, 'rougeL': 0.6699366198633716, 'rougeLsum': 0.6717573666004761}


In [263]:
# An answer counts as wrong when its ROUGE-1 recall is below 0.5. Positions in
# predicted_dataset line up with eval_dataset because map preserves row order.
low_recall_indices = [
    position
    for position, recall in enumerate(predicted_dataset['recalls'])
    if recall < 0.5
]
wrong_ans = eval_dataset.select(low_recall_indices)

wrong_count, total_count = len(wrong_ans), len(eval_dataset)
percentage_wrong = round((wrong_count / total_count) * 100)
print(f"Number of wrong answers: {len(wrong_ans)} ({percentage_wrong}%)")

Number of wrong answers: 41 (34%)


In [265]:
# Inspect the first low-recall example; indexing an empty selection would raise IndexError.
if len(wrong_ans) > 0:
    evaluate_answer(wrong_ans, 0)
else:
    print("No wrong answers to inspect.")

Question: Are otters herbivores?
LLM Answer: Yes

Sources: 
S08_set1_a7.txt.c
Content: species hunt for 3 to 5 hours a day, and nursing mothers up to 8 hours a day. Most otters have fish as the primary item in their diet, supplemented by frogs, crayfish and crabs. Some are expert at opening shellfish, and others will take any available small mammals or birds. This prey-dependence leaves otters very vulnerable to prey depletion. Otters are very active, chasing prey in the water or 
S08_set1_a7.txt.c
Content: humans hunted them almost to extinction. By the time the 1911 Fur Seal Treaty gave them protection, so few sea otters remained that the fur trade had become unprofitable. Sea otters eat shellfish and other invertebrates (especially clams, abalone, and sea urchins ), frequently using rocks as crude tools to smash open shells. They grow to 1 to 1.5 m (2.5 to 5 feet) in length and weigh 30 kg (about

Correct Answer: No

ROUGE Recall Scores: {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0,

# Conclusions

- Factors that had the most impact on the model's output quality in my experiments:
    - Splitting: chunk size, overlap
    - Search: k (number of retrieved chunks)
    - Pipeline parameters (temperature, top_p, repetition penalty)
    - Embedding model
    - Whether the article title is prepended to the question