<a href="https://colab.research.google.com/github/ashweta1/dronerl/blob/main/cs230_rectifying_facts_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Rectifying Factual Knowledge through RAG (Retrieval-Augmented Generation)

This Colab notebook uses RAG over a Wikipedia knowledge dataset: for each prompt, it retrieves relevant passages and prepends them to the prompt as context before generation.

RAG is provided by the `rag_wiki` library I wrote, installed from git+https://github.com/ashweta1/rag_wiki.git
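
To make the idea concrete, here is a minimal sketch of the retrieve-then-prepend step. It is not the `rag_wiki` implementation: the bag-of-words `embed` function, the `retrieve_top_k` helper, and the three toy passages are assumptions made for illustration, standing in for a real sentence encoder and the Wikipedia index.

```python
import numpy as np

def embed(texts, vocab):
    """Toy bag-of-words embedding (a stand-in for a real sentence encoder)."""
    vecs = np.zeros((len(texts), len(vocab)), dtype=np.float32)
    for i, text in enumerate(texts):
        for word in text.lower().split():
            word = word.strip(".,?!")
            if word in vocab:
                vecs[i, vocab[word]] += 1.0
    # L2-normalize so that a dot product equals cosine similarity.
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    return vecs / np.maximum(norms, 1e-8)

def retrieve_top_k(prompt, passages, vocab, top_k=1):
    """Return the top_k passages most similar to the prompt."""
    scores = embed(passages, vocab) @ embed([prompt], vocab)[0]
    best = np.argsort(-scores)[:top_k]
    return [passages[i] for i in best]

# Hypothetical mini knowledge base (the notebook uses Wikipedia instead).
passages = [
    "Paris is the capital of France.",
    "New Delhi is the capital of India.",
    "The Eiffel Tower is a landmark in Paris.",
]
words = sorted({w.strip(".,?!") for p in passages for w in p.lower().split()})
vocab = {w: i for i, w in enumerate(words)}

question = "What is the capital of India?"
context = retrieve_top_k(question, passages, vocab, top_k=1)

# Prepend the retrieved context to the original prompt.
augmented_prompt = "\n".join(context) + "\n" + question
print(augmented_prompt)
```

The language model then sees the retrieved passage before the question, so it can copy the fact from the context instead of relying on what it memorized during training.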

## Prepare environment

In [1]:
%%bash

# exit early when not running in Colab (the google.colab package is missing)
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit

# recreate the local home for this colab run
cd /content && rm -rf /content/home && mkdir home && cd home

# install the known facts dataset.
pip install git+https://github.com/kmeng01/rome.git/tree/main/dsets >> install.log 2>&1

# install hugging face datasets library
pip install datasets >> install.log 2>&1

pip install git+https://github.com/ashweta1/rag_wiki.git >> install.log 2>&1
pip list | grep rag_wiki

# install latest torch and faiss-gpu
pip uninstall -y torch faiss-cpu faiss-gpu >> install.log 2>&1
pip install torch faiss-gpu >> install.log 2>&1

# pip uninstall -y torch torchaudio torchvision torchtext torchdata faiss-gpu >> install.log 2>&1
# pip install torch torchaudio torchvision torchtext torchdata faiss-gpu >> install.log 2>&1

rag_wiki                           0.1.0


In [24]:
!python --version
!python -c "import torch; print(torch.__version__)"
!python -c "import faiss; print(faiss.__version__)"
!python -c "import numpy; print(numpy.__version__)"

Python 3.10.12
2.5.1+cu124
1.7.2
1.26.4


In [2]:
import os
import torch

# Pick the best available device, falling back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Device = ", device)

torch.set_grad_enabled(False)  # inference only: no model parameter updates

# Detect Colab and switch to the working directory created above.
IS_COLAB = False
try:
    import google.colab  # noqa: F401

    IS_COLAB = True
    os.chdir("/content/home")
except ModuleNotFoundError:
    pass

Device =  cuda


In [3]:
%%bash

# # Install my own rag_wiki library to use RAG from Wikipedia
# cd /content/home && rm -fr /content/home/rag_wiki

# git clone https://github.com/ashweta1/rag_wiki.git && cd /content/home/rag_wiki && pip install -e . && cd /content/home
# ls -R /content/home/rag_wiki/


In [7]:
# IPython magic to automatically reload imported modules when they change
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
!nvidia-smi

Sat Nov 23 03:22:14 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P8              11W /  70W |      3MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [9]:
# Get wikipedia embeddings loaded
import torch
from rag_wiki import rag

print("torch.cuda.is_available()", torch.cuda.is_available())
print(torch.__version__)

# Load dataset
print("Loading dataset...")
dataset = rag.load_wiki_dataset(num_examples=100000, debug=True)
print("Loading dataset...done")
print("")

# Preprocess the dataset
print("Preprocessing dataset...")
index, texts = rag.preprocess(dataset, batch_size=200, debug=True)
print("Preprocessing dataset...done")
print("")

torch.cuda.is_available() True
2.5.1+cu124
Loading dataset...
Loading dataset from kilt_wikipedia


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Streaming output truncated to the last 5000 lines.
Encoding the dataset embeddings...
torch embeddings tensor shape:  torch.Size([200, 384])
numpy embeddings shape:  (200, 384)
Encoding the dataset embeddings... done.
#Dataset embeddings =  (200, 384) d =  384
Adding embeddings to the index (dimension 384)...
Adding embeddings to the index... done.
Batch 2
Got 200 texts
Getting batch texts... done.
Encoding the dataset embeddings...
torch embeddings tensor shape:  torch.Size([200, 384])
numpy embeddings shape:  (200, 384)
Encoding the dataset embeddings... done.
#Dataset embeddings =  (200, 384) d =  384
Adding embeddings to the index (dimension 384)...
Adding embeddings to the index... done.
Batch 3
Got 200 texts
Getting batch texts... done.
Encoding the dataset embeddings...
torch embeddings tensor shape:  torch.Size([200, 384])
numpy embeddings shape:  (200, 384)
Encoding the dataset embeddings... done.
#Dataset embeddings =  (200, 384) d =  384
Adding embeddings to th
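
The log above shows each batch of 200 texts being encoded into 384-dimensional embeddings and appended to an index. The sketch below mirrors that batched pattern with a small NumPy class. `FlatL2Index` is a hypothetical stand-in for a FAISS flat index, and it assumes exact search by squared L2 distance; the random vectors replace real sentence embeddings.

```python
import numpy as np

class FlatL2Index:
    """Minimal exact nearest-neighbour index (hypothetical stand-in for a FAISS flat index)."""

    def __init__(self, dim):
        self.dim = dim
        self.vectors = np.empty((0, dim), dtype=np.float32)

    def add(self, batch):
        """Append a batch of embeddings with shape (n, dim)."""
        batch = np.asarray(batch, dtype=np.float32)
        assert batch.shape[1] == self.dim
        self.vectors = np.vstack([self.vectors, batch])

    def search(self, queries, k):
        """Return (distances, ids) of the k closest stored vectors per query."""
        queries = np.asarray(queries, dtype=np.float32)
        # Squared L2 distance between every query and every stored vector.
        diffs = queries[:, None, :] - self.vectors[None, :, :]
        dists = (diffs ** 2).sum(axis=2)
        ids = np.argsort(dists, axis=1)[:, :k]
        return np.take_along_axis(dists, ids, axis=1), ids

rng = np.random.default_rng(0)
toy_index = FlatL2Index(dim=384)
for _ in range(3):  # three batches of 200, as in the log
    toy_index.add(rng.normal(size=(200, 384)))

# Querying with a stored vector should return that vector's own position.
query = toy_index.vectors[42:43]
distances, ids = toy_index.search(query, k=1)
print(toy_index.vectors.shape, ids[0, 0])
```

Adding in batches keeps peak memory bounded by the batch size during encoding, while the row position in the index doubles as the lookup key into the list of texts.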

In [10]:
# Query the index and retrieve relevant texts
TOP_K_TEXTS = 1
prompts = ["What is the capital of India?",
           "Who is the president of the United States?",
           "What is the population of China?",
           "The captial of France is ",
           "Where is the Eiffel Tower located?"]

print("Retrieving relevant texts...")
print("Index: ", index)
print("Length of texts = ", len(texts))
retrieved_texts = rag.retrieve(prompts, index, texts, top_k=TOP_K_TEXTS, debug=False)
print("Retrieving relevant texts...done")
print("")

Retrieving relevant texts...
Index:  <faiss.swigfaiss.GpuIndexFlat; proxy of <Swig Object of type 'faiss::gpu::GpuIndexFlat *' at 0x78522c303750> >
Length of texts =  100000
Retrieving relevant texts...done



In [37]:
for p, ts in zip(prompts, retrieved_texts):
    print(f"Prompt: {p}")
    print(f"Retrieved texts: {len(ts)}")
    print("")

# Build one context-augmented prompt per question: the instruction, an [INFO]
# section with the first 200 characters of the retrieved text, the [QUESTION],
# and an empty [ANSWER] section for the model to complete.
prepend = "Based on the information provided, answer the question concisely. Do not repeat the prompt."
section1 = "[INFO]"
section2 = "[QUESTION]"
section3 = "[ANSWER]"
prompt_with_context = [
    "\n".join([prepend, section1, "\n".join(ts)[:200], section2, p, section3]) + "\n"
    for p, ts in zip(prompts, retrieved_texts)
]
print("Prompts with contexts: ", prompt_with_context)

Prompt: What is the capital of India?
Retrieved texts: 1

Prompt: Who is the president of the United States?
Retrieved texts: 1

Prompt: What is the population of China?
Retrieved texts: 1

Prompt: The captial of France is 
Retrieved texts: 1

Prompt: Where is the Eiffel Tower located?
Retrieved texts: 1

Prompts with contexts:  ['Based on the information provided, answer the question concisely. Do not repeat the prompt.\n[INFO]\n{\'paragraph\': [\'Bharat Ek Khoj\\n\', \'Bharat Ek Khoj ("The Discovery of India") is a 53-episode Indian historical drama based on the book "The Discovery of India" (1946) by Jawaharlal Nehru that covers \n[QUESTION]\nWhat is the capital of India?\n[ANSWER]\n', "Based on the information provided, answer the question concisely. Do not repeat the prompt.\n[INFO]\n{'paragraph': ['Presidential system\\n', 'A presidential system is a democratic and republican system of government where a head of government leads an executive branch that is separate from the legis
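
The `[INFO]` section above contains the raw record (`{'paragraph': [...`) cut at exactly 200 characters, so the context carries dictionary syntax and may end mid-word. One possible cleanup is sketched below. The `passage_text` helper is hypothetical, and it assumes a record is either a dict with a `'paragraph'` list or the string form of such a dict, which is what the printed output suggests.

```python
import ast

def passage_text(record, max_chars=200):
    """Extract plain paragraph text from a retrieved record and truncate it.

    Assumption: a record is a dict with a 'paragraph' list, or the string
    form of such a dict. Anything else is treated as plain text.
    """
    if isinstance(record, str):
        try:
            record = ast.literal_eval(record)
        except (ValueError, SyntaxError):
            return record[:max_chars]  # already plain text
    if not isinstance(record, dict):
        return str(record)[:max_chars]
    paragraphs = record.get("paragraph", [])
    text = " ".join(p.strip() for p in paragraphs)
    if len(text) <= max_chars:
        return text
    # Cut at the last whole word that fits within max_chars.
    return text[:max_chars].rsplit(" ", 1)[0]

example_record = str({
    "paragraph": ["Bharat Ek Khoj\n", "Bharat Ek Khoj is an Indian historical drama."]
})
print(passage_text(example_record, max_chars=40))
```

A cleaner context leaves more of the 200-character budget for actual content, which matters for a small model like GPT-2.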

In [38]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

def get_gpt2_model():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)
    #.to(device)
    return model, tokenizer
model, tokenizer = get_gpt2_model()

In [44]:
def generate_text(model, tokenizer, prompt, max_new_tokens=50):
    # Tokenize the prompt and place it on the same device as the model.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    outputs = model.generate(input_ids,
                             # Cap the continuation length (GPT-2's context window is 1024 tokens).
                             max_new_tokens=max_new_tokens,
                             do_sample=True,
                             num_beams=2,
                             temperature=0.1,
                             no_repeat_ngram_size=2,
                             early_stopping=True,
                             # Stop at the first "." so that a single sentence is generated.
                             eos_token_id=tokenizer.encode(".")[0])

    # Decode the generated sequence back to text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return generated_text
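
Because the decoded sequence includes the input tokens, the returned string starts with the prompt itself. A small helper can isolate the continuation. `extract_answer` is a hypothetical name, and the sketch assumes the decoded text reproduces the prompt verbatim as a prefix; if the tokenizer alters the prompt text, the string is returned unchanged apart from whitespace trimming.

```python
def extract_answer(generated_text, prompt):
    """Return only the continuation, dropping the echoed prompt if present."""
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    return generated_text.strip()

example_prompt = "What is the capital of India?"
example_output = (
    "What is the capital of India?\n\n"
    "India is a country with a population of over 1."
)
print(extract_answer(example_output, example_prompt))
```

This is most useful for the context-augmented prompts, where the echoed instruction and `[INFO]` section are much longer than the answer.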

In [40]:
generate_text(model, tokenizer, "The capital of France is")

'The capital of France is in a state of war with Russia.'

In [41]:
generate_text(model, tokenizer, "The capital of India is")

'The capital of India is home to more than 1.'

In [42]:
for p in prompts:
    print("Prompt: ", p)
    print("----")
    print(generate_text(model, tokenizer, p))
    print("=====")
    print("")

Prompt:  What is the capital of India?
----
What is the capital of India?

India is a country with a population of over 1.
=====

Prompt:  Who is the president of the United States?
----
Who is the president of the United States?

The president is elected by the people.
=====

Prompt:  What is the population of China?
----
What is the population of China?

The population is estimated to be about 1.
=====

Prompt:  The captial of France is 
----
The captial of France is vernacular, but it is not the only language spoken in the Middle East.
=====

Prompt:  Where is the Eiffel Tower located?
----
Where is the Eiffel Tower located?

Eiffels Tower is located in the heart of London.
=====
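
One rough way to judge outputs like these is to check whether an accepted answer string appears in the generated text. The sketch below is an assumption about how such a check could look, not part of the notebook's pipeline. `contains_answer` is a hypothetical helper, the two outputs are copied from the run above, and the reference answers are supplied by hand.

```python
def contains_answer(generated_text, expected_answers):
    """Case-insensitive check: does any accepted answer appear in the output?"""
    text = generated_text.lower()
    return any(answer.lower() in text for answer in expected_answers)

# (model output, accepted answers) pairs.
examples = [
    ("India is a country with a population of over 1.", ["New Delhi", "Delhi"]),
    ("Eiffels Tower is located in the heart of London.", ["Paris"]),
]
hits = [contains_answer(output, answers) for output, answers in examples]
print(f"{sum(hits)}/{len(hits)} answers contain the expected fact")
```

Substring matching is crude (it misses paraphrases and accepts incidental mentions), but it gives a quick baseline for comparing generation with and without retrieved context.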



In [46]:
for p in prompt_with_context:
    print("Prompt: ", p)
    print("----")
    print(generate_text(model, tokenizer, p))
    print("=====")
    print("")

Prompt:  Based on the information provided, answer the question concisely. Do not repeat the prompt.
[INFO]
{'paragraph': ['Bharat Ek Khoj\n', 'Bharat Ek Khoj ("The Discovery of India") is a 53-episode Indian historical drama based on the book "The Discovery of India" (1946) by Jawaharlal Nehru that covers 
[QUESTION]
What is the capital of India?
[ANSWER]

----
Based on the information provided, answer the question concisely. Do not repeat the prompt.
[INFO]
{'paragraph': ['Bharat Ek Khoj\n', 'Bharat Ek Khoj ("The Discovery of India") is a 53-episode Indian historical drama based on the book "The Discovery of India" (1946) by Jawaharlal Nehru that covers 
[QUESTION]
What is the capital of India?
[ANSWER]
India is an independent nation.
=====

Prompt:  Based on the information provided, answer the question concisely. Do not repeat the prompt.
[INFO]
{'paragraph': ['Presidential system\n', 'A presidential system is a democratic and republican system of government where a head of governm