<a href="https://colab.research.google.com/github/ashweta1/interp/blob/main/cs230_rectifying_facts_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Rectifying Factual Knowledge through RAG (Retrieval-Augmented Generation)

This Colab notebook uses RAG over a Wikipedia knowledge dataset: passages retrieved for each prompt are prepended as context before the prompt is completed by GPT-2.

Retrieval is provided by my library `rag_wiki` (install with `pip install git+https://github.com/ashweta1/rag_wiki.git`).

Evaluation dataset: ROME's known-facts datasets (`dsets`), https://github.com/kmeng01/rome/tree/main/dsets


## Prepare environment

In [None]:
%%bash

# Exit early (skipping all installs) when NOT running on Colab: the
# google/colab package directory only exists on Colab VMs.
! (stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit

# recreate the local home for this colab run
cd /content && rm -rf /content/home && mkdir home && cd home

# Fetch ROME for its known-facts dataset loaders (dsets/). A GitHub
# ".../tree/main/dsets" web path is not a valid pip VCS URL, so clone the
# repository into /content/home/rome instead of pip-installing it.
git clone --depth 1 https://github.com/kmeng01/rome.git >> install.log 2>&1

# install hugging face datasets library
pip install datasets >> install.log 2>&1

# install the rag_wiki retrieval library
pip install git+https://github.com/ashweta1/rag_wiki.git >> install.log 2>&1
# Confirm it installed; pip may list the distribution as rag_wiki or rag-wiki.
pip list | grep -i 'rag[-_]wiki'

# install latest torch and faiss-gpu
# NOTE(review): PyPI faiss-gpu wheels may be unavailable for newer Python
# versions; check install.log if this step fails.
pip uninstall -y torch faiss-cpu faiss-gpu >> install.log 2>&1
pip install torch faiss-gpu >> install.log 2>&1

In [None]:
import os

import torch

# Detect Colab. Only Colab has the /content/home directory created by the
# install cell, so only Colab changes into it below.
try:
    import google.colab  # noqa: F401
    IS_COLAB = True
except ModuleNotFoundError:
    IS_COLAB = False

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# Defined unconditionally because later cells use `device` on every platform.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Device = ", device)

if IS_COLAB:
    os.chdir("/content/home")

# Inference only: disable autograd globally (no model parameter updates).
torch.set_grad_enabled(False)

In [None]:
%%bash

# # Install my own rag_wiki library to use RAG from wikipedia
# cd /content/home && rm -fr /content/home/rag_wiki

# git clone https://github.com/ashweta1/rag_wiki.git && cd /content/home/rag_wiki && pip install -e . && cd /content/home
# ls -R /content/home/rag_wiki/


In [None]:
# IPython magics: reload all imported modules automatically before executing
# each cell, so edits to installed packages take effect without a kernel restart.
%load_ext autoreload
%autoreload 2


In [None]:
# Diagnostic only: list the NVIDIA GPU(s) visible to this runtime.
!nvidia-smi

In [None]:
# Build the retrieval corpus: load a slice of Wikipedia through rag_wiki, then
# preprocess it into the (index, texts) pair consumed by rag.retrieve below.
import torch
from rag_wiki import rag

NUM_WIKI_EXAMPLES = 100      # forwarded to rag.load_wiki_dataset(num_examples=...)
PREPROCESS_BATCH_SIZE = 200  # forwarded to rag.preprocess(batch_size=...)


def _announce(label, action):
    """Run `action()` between '<label>...' and '<label>...done' log lines."""
    print(f"{label}...")
    outcome = action()
    print(f"{label}...done")
    print("")
    return outcome


print("torch.cuda.is_available()", torch.cuda.is_available())
print(torch.__version__)

dataset = _announce(
    "Loading dataset",
    lambda: rag.load_wiki_dataset(num_examples=NUM_WIKI_EXAMPLES, debug=True),
)
index, texts = _announce(
    "Preprocessing dataset",
    lambda: rag.preprocess(dataset, batch_size=PREPROCESS_BATCH_SIZE, debug=True),
)

In [None]:
# Query the index and retrieve relevant texts
TOP_K_TEXTS = 1
prompts = ["What is the capital of India?",
           "Who is the president of the United States?",
           "What is the population of China?",
           # No trailing space: GPT-2 BPE tokens carry their leading space
           # (" Paris"), so a trailing space in the prompt hurts completion.
           "The capital of France is",
           "Where is the Eiffel Tower located?"]

print("Retrieving relevant texts...")
print("Index: ", index)
print("Length of texts = ", len(texts))
retrieved_texts = rag.retrieve(prompts, index, texts, top_k=TOP_K_TEXTS, debug=False)
print("Retrieving relevant texts...done")
print("")

# zip() below would silently drop prompts if retrieve returned fewer results.
assert len(retrieved_texts) == len(prompts), "expected one retrieval result per prompt"

for prompt, passages in zip(prompts, retrieved_texts):
    print(f"Prompt: {prompt}")
    print(f"Retrieved texts: {passages}")
    print("")

# RAG prompt = retrieved passages (space-joined) followed by the original prompt.
# Assumes each retrieval result is an iterable of passage strings.
contexts = [f"{' '.join(passages)} {prompt}" for prompt, passages in zip(prompts, retrieved_texts)]
print("Prompts with contexts: ", contexts)

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def get_gpt2_model(model_name='gpt2'):
    """Load a GPT-2 language model and its tokenizer onto the notebook `device`.

    Args:
        model_name: Hugging Face GPT-2 checkpoint name. Defaults to 'gpt2';
            larger variants such as 'gpt2-medium' also work.

    Returns:
        (model, tokenizer) tuple.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # GPT-2 defines no pad token; reuse EOS so generate() has a pad id.
    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).to(device)
    return model, tokenizer


model, tokenizer = get_gpt2_model()

In [None]:
def generate_text(model, tokenizer, prompt, max_length=50, device=None, max_new_tokens=None):
    """Generate a short, near-deterministic continuation of `prompt`.

    Args:
        model: A causal LM exposing ``generate`` (e.g. the GPT-2 model above).
        tokenizer: The matching tokenizer; it is called with ``return_tensors='pt'``.
        prompt: Text to continue.
        max_length: Upper bound on the TOTAL sequence length, prompt tokens
            included (Hugging Face ``generate`` semantics). Long prompts such
            as RAG contexts need a larger value, or use ``max_new_tokens``.
        device: Device for the input tensors. Defaults to ``model.device`` so
            the inputs always land where the model's weights are.
        max_new_tokens: If given, bounds only the generated tokens and is used
            instead of ``max_length``.

    Returns:
        The decoded output sequence (prompt plus continuation), with special
        tokens stripped.
    """
    if device is None:
        device = model.device

    # Calling the tokenizer (rather than .encode) also yields an attention mask.
    # Without one, generate() cannot tell padding apart, because the pad id
    # equals the EOS id here.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    if max_new_tokens is not None:
        length_limit = {"max_new_tokens": max_new_tokens}
    else:
        length_limit = {"max_length": max_length}

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # Beam sampling at a tiny temperature: effectively greedy beam search.
        do_sample=True,
        num_beams=2,
        temperature=0.001,
        no_repeat_ngram_size=2,
        early_stopping=True,
        # Treat "." as end-of-sequence so the answer ends after one sentence.
        eos_token_id=tokenizer.encode(".")[0],
        **length_limit,
    )

    # Decode the first (best) generated sequence back to text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Baseline completion with no retrieved context prepended.
generate_text(model=model, tokenizer=tokenizer, prompt="The capital of France is")

In [None]:
# Baseline completion with no retrieved context prepended.
generate_text(model=model, tokenizer=tokenizer, prompt="The capital of India is")

In [None]:
# Baseline: complete each raw prompt without retrieved context.
# The generated text was previously computed and discarded; print it.
for prompt in prompts:
    print(f"Prompt: {prompt}")
    print(f"Generated: {generate_text(model, tokenizer, prompt)}")
    print()

In [None]:
# RAG: complete each prompt with its retrieved passages prepended.
# The max_length given to Hugging Face generate() counts prompt tokens too, so
# size it per context: prompt length plus ANSWER_TOKENS of continuation.
# Otherwise a long passage can exhaust the budget and nothing new is generated.
ANSWER_TOKENS = 50
for context in contexts:
    length_budget = len(tokenizer.encode(context)) + ANSWER_TOKENS
    print(generate_text(model, tokenizer, context, max_length=length_budget))
    print()