<a href="https://colab.research.google.com/github/MahdiTheGreat/Intro-to-language-modeling/blob/main/Retrieval_augmented_text_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
generative_model="meta-llama/Llama-3.2-1B-Instruct"
sentence_model = "dmis-lab/biobert-base-cased-v1.1"

In [22]:
from google.colab import userdata

In [161]:
!pip install langchain_huggingface
!pip install -qU langchain-text-splitters
!pip install -qU "langchain-chroma>=0.1.2"
!pip install -U langchain-community
!pip install nltk



In [162]:
!wget https://raw.githubusercontent.com/pubmedqa/pubmedqa/refs/heads/master/data/ori_pqal.json

--2024-12-11 13:28:50--  https://raw.githubusercontent.com/pubmedqa/pubmedqa/refs/heads/master/data/ori_pqal.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2584787 (2.5M) [text/plain]
Saving to: ‘ori_pqal.json.2’


2024-12-11 13:28:51 (43.7 MB/s) - ‘ori_pqal.json.2’ saved [2584787/2584787]



We build two DataFrames:

    ‘questions’: each question with its gold label (yes/no), gold long answer, gold document ID, and year.
    ‘documents’: each abstract (CONTEXTS and LONG_ANSWER concatenated) and its year.
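As a plain-Python sketch of what the next cell does with pandas, the snippet below applies the same filtering and concatenation to a made-up two-record input shaped like ori_pqal.json. The field names (QUESTION, CONTEXTS, LONG_ANSWER, final_decision, YEAR) are the ones the notebook uses; the record IDs and texts are hypothetical.

```python
# Toy records shaped like the ori_pqal.json entries (IDs and text are made up).
raw = {
    "00000001": {
        "QUESTION": "Does X cause Y?",
        "CONTEXTS": ["Background sentence.", "Methods sentence."],
        "LONG_ANSWER": "X causes Y.",
        "final_decision": "yes",
        "YEAR": "2001",
    },
    "00000002": {
        "QUESTION": "Is Z useful?",
        "CONTEXTS": ["Some context."],
        "LONG_ANSWER": "Unclear.",
        "final_decision": "maybe",
        "YEAR": "2005",
    },
}

# Keep only yes/no items, mirroring the filter on final_decision.
kept = {pmid: r for pmid, r in raw.items() if r["final_decision"] in ("yes", "no")}

# 'documents': abstract = context passages + long answer joined by spaces, plus year.
documents = {
    pmid: {"abstract": " ".join(r["CONTEXTS"] + [r["LONG_ANSWER"]]), "year": r["YEAR"]}
    for pmid, r in kept.items()
}

# 'questions': the gold document ID is simply the shared record key.
questions = {
    pmid: {
        "question": r["QUESTION"],
        "year": r["YEAR"],
        "gold_label": r["final_decision"],
        "gold_context": r["LONG_ANSWER"],
        "gold_document_id": pmid,
    }
    for pmid, r in kept.items()
}

print(documents["00000001"]["abstract"])
```

Because both tables are keyed by the same record ID, a question's gold document can be looked up directly by its gold_document_id.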


In [163]:
import pandas as pd
tmp_data = pd.read_json("/content/ori_pqal.json").T
# Some questions are labelled "maybe"; keep only the yes/no ones
tmp_data = tmp_data[tmp_data.final_decision.isin(["yes", "no"])]

# Each abstract is its context passages followed by the long answer
documents = pd.DataFrame({"abstract": tmp_data.apply(lambda row: " ".join(row.CONTEXTS + [row.LONG_ANSWER]), axis=1),
                          "year": tmp_data.YEAR})
# Both tables share the same index (the PubMed ID), so gold_document_id points at the matching row of documents
questions = pd.DataFrame({"question": tmp_data.QUESTION,
                          "year": tmp_data.YEAR,
                          "gold_label": tmp_data.final_decision,
                          "gold_context": tmp_data.LONG_ANSWER,
                          "gold_document_id": documents.index})

An example query:

In [164]:
questions.iloc[0].question

'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?'

An example of a document that can be retrieved to answer the queries:

In [165]:
text = documents.iloc[0].abstract

In [166]:
from langchain_huggingface import HuggingFaceEmbeddings

model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
hf = HuggingFaceEmbeddings(
    model_name=sentence_model,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)
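dmis-lab/biobert-base-cased-v1.1 is a plain BERT checkpoint rather than a sentence-transformers model. To our understanding, when sentence-transformers (which HuggingFaceEmbeddings wraps) loads such a checkpoint it adds a mean-pooling layer, so each text's embedding is the average of its token vectors over non-padding positions. The sketch below shows that computation on tiny hand-written vectors; the function name mean_pool is our own.

```python
def mean_pool(token_embeddings, attention_mask):
    """Average token vectors, skipping padding positions (mask value 0)."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            summed = [s + v for s, v in zip(summed, vec)]
            count += 1
    # Guard against an all-padding input to avoid division by zero
    return [s / max(count, 1) for s in summed]

tokens = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]  # the last row is a padding token
mask = [1, 1, 0]
print(mean_pool(tokens, mask))  # [2.0, 3.0]
```

Since encode_kwargs sets normalize_embeddings=False, these pooled vectors are stored with whatever length they happen to have.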



In [175]:
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [192]:
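from langchain_text_splitters import NLTKTextSplitter
import re

# Split the abstract at sentence boundaries into chunks of at most ~1000 characters,
# with up to 200 characters of overlap between consecutive chunks
text_splitter = NLTKTextSplitter(chunk_size=1000, chunk_overlap=200)

# Split the text
chunks = text_splitter.split_text(text)

# Collapse runs of whitespace (including the newlines between sentences) into single spaces
chunks = [re.sub(r'\s+', ' ', chunk) for chunk in chunks]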

# Output the chunks
for i, chunk in enumerate(chunks):
    print(f"Chunk {i+1}: {chunk}\n")

Chunk 1: Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD).

Chunk 2: Window stage leaves were stained with the mitochondrial dye MitoTr
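To illustrate what chunk_size and chunk_overlap control, here is a simplified sentence-packing sketch: sentences are added to a chunk until the next one would exceed chunk_size, and the following chunk starts with trailing sentences totalling at most chunk_overlap characters. This approximates the idea behind the splitter above; it is not LangChain's exact merge logic, and pack_sentences is a hypothetical name.

```python
def pack_sentences(sentences, chunk_size=1000, chunk_overlap=200, sep=" "):
    """Greedily pack sentences into chunks of at most chunk_size characters.

    Trailing sentences (up to chunk_overlap characters) are carried into the next
    chunk. A single sentence longer than chunk_size becomes its own oversized chunk.
    """
    chunks, current, current_len = [], [], 0
    for sent in sentences:
        add_len = len(sent) + (len(sep) if current else 0)
        if current and current_len + add_len > chunk_size:
            chunks.append(sep.join(current))
            # Drop sentences from the front until what remains fits the overlap budget
            while current and (current_len > chunk_overlap or current_len + add_len > chunk_size):
                removed = current.pop(0)
                current_len -= len(removed) + (len(sep) if current else 0)
            add_len = len(sent) + (len(sep) if current else 0)
        current.append(sent)
        current_len += add_len
    if current:
        chunks.append(sep.join(current))
    return chunks

print(pack_sentences(["aaaa.", "bbbb.", "cccc.", "dddd."], chunk_size=11, chunk_overlap=5))
```

With an overlap of 5 characters, each chunk repeats the last sentence of the previous one; with chunk_overlap=0 the chunks would not share any sentences.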

In [193]:
print(chunks[0])
print(chunks[1])

Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD).
Window stage leaves were stained with the mitochondrial dye MitoTracker Red CMXRos an

In [194]:
# Generate unique ids for the chunks (Chroma expects string IDs)
chunk_ids = [str(i) for i in range(len(chunks))]
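chunk_ids restarts at 0 for every text passed to the splitter, so these IDs are unique only because a single abstract is chunked; chunking several abstracts into one collection would make them collide. One option is to prefix each chunk index with its document ID and keep that ID in the metadata, so a retrieved chunk can be traced back to gold_document_id. The helper build_chunk_records below is hypothetical; split_fn can be any callable returning a list of strings, such as text_splitter.split_text.

```python
def build_chunk_records(abstracts, split_fn):
    """Split each abstract and give every chunk an ID that is unique across documents.

    abstracts: mapping of document ID (e.g. a PubMed ID) -> abstract text.
    split_fn: callable returning a list of chunk strings for one text.
    """
    records = []
    for doc_id, text in abstracts.items():
        for chunk_idx, chunk in enumerate(split_fn(text)):
            records.append({
                "id": f"{doc_id}-{chunk_idx}",  # string ID combining document and chunk
                "text": chunk,
                "metadata": {"doc_id": str(doc_id), "chunk": chunk_idx},
            })
    return records

# Toy splitter for illustration: split on "|"
records = build_chunk_records({"111": "a|b", "222": "c"}, lambda t: t.split("|"))
print([r["id"] for r in records])  # ['111-0', '111-1', '222-0']
```

In the notebook this could be called as `build_chunk_records(dict(zip(documents.index, documents.abstract)), text_splitter.split_text)`, with each record turned into `Document(page_content=r["text"], metadata=r["metadata"], id=r["id"])`.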

To trace model calls automatically with LangSmith, set your LangSmith API key by uncommenting the lines below:

In [182]:
# import getpass, os
# os.environ["LANGSMITH_API_KEY"] = getpass.getpass("Enter your LangSmith API key: ")
# os.environ["LANGSMITH_TRACING"] = "true"

In [195]:
from langchain_core.documents import Document

using_chunks = True  # True: index the chunks of the first abstract; False: index whole abstracts
docs = []
split_idx = 10  # maximum number of chunks/abstracts to index

if using_chunks:
  ids = chunk_ids

  for i, doc in enumerate(chunks[:split_idx]):
    temp_doc = Document(
        page_content=doc,
        metadata={"source": "tweet"},
        id=ids[i],
    )
    docs.append(temp_doc)

else:
  # Use the PubMed IDs (the DataFrame index) as document IDs
  ids = documents.iloc[:split_idx].index

  for i, doc in enumerate(documents.iloc[:split_idx].values):
    temp_doc = Document(
        page_content=doc[0],  # the "abstract" column
        metadata={"source": "tweet"},
        id=ids[i],
    )
    docs.append(temp_doc)

In [196]:
from langchain_chroma import Chroma

try:
  vector_store.delete_collection()  # Delete the existing collection when re-running this cell
except Exception:
  pass  # vector_store is not defined yet (first run) or its collection is already gone

vector_store = Chroma.from_documents(docs, hf)
#retriever = vector_store.as_retriever(...)

In [197]:
print(vector_store._collection.count())

3


In [198]:
# With Chroma's default settings the returned score is a distance: lower means more similar
results = vector_store.similarity_search_with_score(
    "What is programmed cell death?", k=10
)
for res, score in results:
    print(f"* [SIM={score:3f}] {res.page_content} [{res.metadata}]")



* [SIM=40.248814] Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). [{'source': 'tweet'}]
* [SIM=44.876183] Window stage leaves were st
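The "SIM" values above are distances, not similarities: with Chroma's default collection space (squared L2, to our understanding), the first result (40.25) is closer to the query than the second (44.88). Because normalize_embeddings=False, these distances also depend on vector length, so L2 ranking can disagree with cosine ranking. The hand-picked 3-d vectors below are made up to show this.

```python
import math

def squared_l2(a, b):
    # Squared Euclidean distance (assumed to match Chroma's default "l2" space)
    return sum((x - y) ** 2 for x, y in zip(a, b))

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

query = [3.0, 4.0, 0.0]
doc_a = [6.0, 8.0, 0.0]  # same direction as the query, twice the length
doc_b = [0.0, 4.0, 3.0]  # different direction, same length as the query

# Raw vectors: L2 prefers doc_b (18 < 25) although cosine prefers doc_a (1.0 > 0.64)
print(squared_l2(query, doc_a), squared_l2(query, doc_b))
print(cosine_similarity(query, doc_a), cosine_similarity(query, doc_b))

# Unit vectors: squared L2 equals 2 - 2 * cosine, so both give the same ranking
print(squared_l2(normalize(query), normalize(doc_b)))
```

For unit-length vectors $\lVert a-b\rVert^2 = 2 - 2\cos(a,b)$, so setting normalize_embeddings=True would make the L2 ranking identical to a cosine-similarity ranking.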

In [199]:
from huggingface_hub import login
login(userdata.get('huggingface_token'))

In [None]:
! huggingface-cli download $generative_model --local-dir ./$generative_model

Fetching 13 files:   0% 0/13 [00:00<?, ?it/s]Downloading 'USE_POLICY.md' to 'meta-llama/Llama-3.2-1B-Instruct/.cache/huggingface/download/USE_POLICY.md.ac3c5f21b9779e3da0677d6d3c587778fe3a331e.incomplete'
Downloading '.gitattributes' to 'meta-llama/Llama-3.2-1B-Instruct/.cache/huggingface/download/.gitattributes.a6344aac8c09253b3b630fb776ae94478aa0275b.incomplete'
Downloading 'config.json' to 'meta-llama/Llama-3.2-1B-Instruct/.cache/huggingface/download/config.json.3e3aaf51a035cb5092d9f6827a0dc074657ba88c.incomplete'
Downloading 'original/params.json' to 'meta-llama/Llama-3.2-1B-Instruct/.cache/huggingface/download/original/params.json.9cd8dbdf2dc6f4d8abb60bdb5ce64f4bec2fdfd9.incomplete'
Downloading 'generation_config.json' to 'meta-llama/Llama-3.2-1B-Instruct/.cache/huggingface/download/generation_config.json.75ae08310d6d23df373ee2644b497192b3cce6d8.incomplete'
Downloading 'model.safetensors' to 'meta-llama/Llama-3.2-1B-Instruct/.cache/huggingface/download/model.safetensors.1ff795ff6

In [38]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("./"+generative_model)
model_hf = AutoModelForCausalLM.from_pretrained("./"+generative_model)

OSError: Can't load tokenizer for './meta-llama/Llama-3.2-1B-Instruct'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure './meta-llama/Llama-3.2-1B-Instruct' is the correct path to a directory containing all relevant files for a LlamaTokenizerFast tokenizer.
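The OSError means ./meta-llama/Llama-3.2-1B-Instruct did not contain the tokenizer files. The download cell above has no execution count and its log stops partway through, which suggests the download never completed. A hypothetical guard is to check for a couple of expected files and otherwise fall back to the Hub ID; the required file names below are an assumption about what the snapshot contains.

```python
import os

def resolve_model_path(local_dir, hub_id, required=("config.json", "tokenizer_config.json")):
    """Return local_dir if it looks like a complete local snapshot, else hub_id.

    The default `required` file names are an assumption; adjust them to the
    files your checkpoint actually ships.
    """
    if all(os.path.isfile(os.path.join(local_dir, name)) for name in required):
        return local_dir
    return hub_id

print(resolve_model_path("./does-not-exist", "meta-llama/Llama-3.2-1B-Instruct"))
```

Usage in the notebook could look like `path = resolve_model_path("./" + generative_model, generative_model)` followed by `AutoTokenizer.from_pretrained(path)` and `AutoModelForCausalLM.from_pretrained(path)`; loading by Hub ID relies on the earlier `login` call for this gated model.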

In [None]:
retriever = vector_store.as_retriever()
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""

prompt = ChatPromptTemplate.from_template(template)

# Create a pipeline for text generation
pipe = pipeline(
    "text-generation",
    model=model_hf,
    tokenizer=tokenizer,
    max_new_tokens=512,
    device="cuda"
    )

# Wrap the pipeline in a LangChain HuggingFacePipeline object
llm = HuggingFacePipeline(pipeline=pipe)

retrieval_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

response = retrieval_chain.invoke("Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?")
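In this chain the retriever's list of Document objects is inserted into {context} as-is, so the prompt contains Python representations of the documents, including their metadata. A common alternative is a small formatting function that keeps only the text. The sketch below uses a minimal stand-in class (FakeDocument, hypothetical) with the page_content attribute that LangChain's Document also has.

```python
from dataclasses import dataclass, field

@dataclass
class FakeDocument:
    # Stand-in for langchain_core.documents.Document: only the fields used here
    page_content: str
    metadata: dict = field(default_factory=dict)

def format_docs(docs):
    """Join retrieved chunks into one plain-text context, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)

template = """Answer the question based only on the following context:
{context}

Question: {question}
"""

docs = [FakeDocument("PCD is the regulated death of cells."),
        FakeDocument("Mitochondria were stained.", {"source": "tweet"})]
print(template.format(context=format_docs(docs), question="What is PCD?"))
```

In the chain this would be wired as `{"context": retriever | format_docs, "question": RunnablePassthrough()}`; LangChain turns a plain function on the right of `|` into a runnable.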

In [None]:

# The chain ends with StrOutputParser, so the response is a plain string
print(response)


In [None]:
import re
from tqdm import tqdm

template = """Answer the question, only with yes or no, based only on the following context:
{context}
Do not write anything else. Do not explain your answer.
Question: {question}
"""

prompt = ChatPromptTemplate.from_template(template)

# Create a pipeline for text generation; return_full_text=False keeps only the generated answer
pipe = pipeline(
    "text-generation",
    model=model_hf,
    tokenizer=tokenizer,
    max_new_tokens=10,
    return_full_text=False,
    device="cuda"
    )

# Wrap the pipeline in a LangChain HuggingFacePipeline object
llm = HuggingFacePipeline(pipeline=pipe)

retrieval_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

# Count answers whose yes/no prediction matches the gold label
correct_count = 0
print(f"Number of questions: {len(questions)}")
for question, gold_label in tqdm(zip(questions.question, questions.gold_label), total=len(questions)):
    answer = retrieval_chain.invoke(question)
    # Take the first standalone "yes" or "no" in the output (words like "know" or "not" do not match)
    match = re.search(r"\b(yes|no)\b", answer.lower())
    predicted = match.group(1) if match else None
    if predicted == gold_label.lower():
        print("Correct")
        correct_count += 1
    else:
        print("Incorrect")

# Print the total number of correct answers and the accuracy
print(f"Total correct answers: {correct_count} / {len(questions)} ({correct_count / len(questions):.1%})")
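A substring test such as `"no" in answer` also fires on words like "know" or "nothing", and a single accuracy number hides how many outputs contained no yes/no at all. The sketch below (parse_yes_no and evaluate are hypothetical names; the example outputs are made up) parses standalone yes/no words and tallies (gold, predicted) pairs, counting unparseable outputs as wrong.

```python
import re
from collections import Counter

def parse_yes_no(text):
    """Return 'yes' or 'no' for the first standalone yes/no word in text, else None."""
    match = re.search(r"\b(yes|no)\b", text.lower())
    return match.group(1) if match else None

def evaluate(predictions, gold_labels):
    """Return accuracy and a Counter of (gold, predicted) pairs; None marks unparseable output."""
    parsed = [parse_yes_no(p) for p in predictions]
    pairs = Counter(zip(gold_labels, parsed))
    correct = sum(1 for g, p in zip(gold_labels, parsed) if g == p)
    return correct / len(gold_labels), pairs

acc, table = evaluate(["Yes.", "no", "I don't know", "Nothing suggests yes"],
                      ["yes", "no", "no", "no"])
print(acc)  # 0.5
print(table)
```

In the notebook, collecting each `answer` in a list during the loop and passing it with `list(questions.gold_label)` would give the same accuracy as the loop plus the breakdown by label.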


