# Retrieval Augmented Generation (RAG) model

[Documentation](https://huggingface.co/docs/transformers/main/en/model_doc/rag#transformers.RagModel)

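Unlike a plain seq2seq model, RAG combines two separate components: a retriever (a DPR question encoder plus a FAISS index over passage embeddings) and a seq2seq generator. The generator conditions on the retrieved passages when producing its answer.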

[Paper](https://arxiv.org/pdf/2005.11401.pdf)
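
As a rough illustration of how the retriever and the generator combine, the paper's RAG-Sequence formulation weights the generator's likelihood of an answer by the retriever's probability for each passage and sums over the retrieved passages. The sketch below uses made-up numbers and hypothetical helper names (`softmax`, `marginal_likelihood`); it is not how the `transformers` implementation is structured.

```python
import math


def softmax(scores):
    """Turn raw retrieval scores into a probability distribution."""
    peak = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def marginal_likelihood(doc_scores, answer_probs_given_doc):
    """Sum over passages of p(passage | question) * p(answer | question, passage)."""
    doc_probs = softmax(doc_scores)
    return sum(p_doc * p_ans for p_doc, p_ans in zip(doc_probs, answer_probs_given_doc))


# Hypothetical values for three retrieved passages.
doc_scores = [2.0, 1.0, 0.0]    # retriever inner-product scores
answer_probs = [0.9, 0.2, 0.1]  # generator likelihood of one candidate answer, per passage

p_answer = marginal_likelihood(doc_scores, answer_probs)
print(round(p_answer, 3))
```

The answer's overall probability is dominated by the highest-scoring passage, which is why the quality of the index built below matters so much.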

In [None]:
!pip install -q condacolab
import condacolab
condacolab.install()
!conda install -c pytorch faiss-gpu
%pip install datasets transformers

⏬ Downloading https://github.com/jaimergp/miniforge/releases/latest/download/Mambaforge-colab-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:18
🔁 Restarting kernel...
Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Collecting package metadata (repodata.json):

In [None]:
from google.colab import drive
import os
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
device = 'cuda:0'  # GPU used by the DPR context encoder in embed()

In [None]:
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import torch
from datasets import Features, Sequence, Value, load_dataset

import faiss
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)

In [None]:

logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)


def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]


def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}


def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy()}
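
To see what the splitting helpers do before running them on the real data, here is a small self-contained check. It re-declares `split_text` and `split_documents` with the same logic as above and feeds them a made-up batch; the titles and texts are purely illustrative.

```python
from typing import List


def split_text(text: str, n: int = 100, character: str = " ") -> List[str]:
    """Split the text every n-th occurrence of character."""
    words = text.split(character)
    return [character.join(words[i : i + n]).strip() for i in range(0, len(words), n)]


def split_documents(documents: dict) -> dict:
    """Split each document of a batch into passages of at most 100 words."""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}


# Made-up batch: a 250-word document, a short one without a title, and one without text.
batch = {
    "title": ["Doc A", None, "Doc C"],
    "text": [" ".join(f"w{i}" for i in range(250)), "short text", None],
}
result = split_documents(batch)
print(len(result["text"]))
print(result["title"])
```

Three input rows become four passages, and the row without text is dropped. Because the output has a different number of rows than the input, the function has to be applied with `batched=True` in `dataset.map`.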




In [None]:
######################################
logger.info("Step 1 - Create the dataset")
######################################

# The dataset needed for RAG must have three columns:
# - title (string): title of the document
# - text (string): text of a passage of the document
# - embeddings (array of dimension d): DPR representation of the passage

csv_path = "/content/drive/MyDrive/data/data-wiki.csv"

# Load the CSV file as a Dataset object
dataset = load_dataset(
    "csv", data_files=[csv_path], split="train", delimiter=",", column_names=["title", "text"]
)

# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files





Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-ca00b307380d8bf0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-ca00b307380d8bf0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.
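
For reference, a CSV in the shape this cell expects can be produced with the standard library. This is a sketch under assumptions: the file name and rows are hypothetical, and it assumes that, because `column_names` is passed to `load_dataset`, the first line of the file is read as data rather than as a header, so no header row is written.

```python
import csv
import os
import tempfile

# Hypothetical documents: one (title, text) pair per row.
rows = [
    ("Hypothetical page A", "Text of the first document."),
    ("Hypothetical page B", "Text with a comma, which the csv module quotes."),
]

tmp_dir = tempfile.mkdtemp()
example_csv_path = os.path.join(tmp_dir, "example-wiki.csv")

# Write two comma-separated columns and no header row.
with open(example_csv_path, "w", newline="", encoding="utf-8") as f:
    csv.writer(f, delimiter=",").writerows(rows)

# Read the file back to confirm every row still has exactly two fields.
with open(example_csv_path, newline="", encoding="utf-8") as f:
    loaded = [tuple(row) for row in csv.reader(f, delimiter=",")]
print(loaded[1])
```

Quoting matters here: text containing commas must be quoted, otherwise it spills into extra columns when the file is parsed.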


In [None]:

# Then split the documents into passages of 100 words
dataset = dataset.map(split_documents, batched=True, num_proc=1, batch_size=50)  ### <- CHANGED THIS

dpr_ctx_encoder_model_name = "facebook/dpr-ctx_encoder-multiset-base"

# And compute the embeddings
ctx_encoder = DPRContextEncoder.from_pretrained(dpr_ctx_encoder_model_name).to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(dpr_ctx_encoder_model_name)
new_features = Features(
    {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
)  # optional, save as float32 instead of float64 to save space
dataset = dataset.map(
    partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
    batched=True,
    batch_size=16,
    features=new_features,
)

# And finally save your dataset
passages_path = os.path.join("/content/drive/MyDrive/data/", "reg_wiki_knowledge_dataset")
dataset.save_to_disk(passages_path)
# from datasets import load_from_disk
# dataset = load_from_disk(passages_path)  # to reload the dataset


  0%|          | 0/5 [00:00<?, ?ba/s]

Downloading:   0%|          | 0.00/492 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-multiset-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.weight', 'ctx_encoder.bert_model.pooler.dense.bias']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizerFast'.


  0%|          | 0/14 [00:00<?, ?ba/s]
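
The comment about saving as float32 can be made concrete with a back-of-the-envelope estimate. The passage count below is hypothetical and the calculation only covers the raw embedding values, not any on-disk format overhead.

```python
def embedding_bytes(num_passages: int, dim: int = 768, bytes_per_value: int = 4) -> int:
    """Raw size of the embeddings column: one dim-sized vector per passage."""
    return num_passages * dim * bytes_per_value


num_passages = 10_000  # hypothetical corpus size, not the size of the dataset above
float32_size = embedding_bytes(num_passages, bytes_per_value=4)
float64_size = embedding_bytes(num_passages, bytes_per_value=8)
print(f"float32: {float32_size / 1e6:.1f} MB, float64: {float64_size / 1e6:.1f} MB")
```

Storing float32 halves the size of the embeddings column, which adds up quickly when the dataset lives on Google Drive.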

In [None]:
# Reload the dataset saved in Step 1 instead of recomputing the embeddings
from datasets import load_from_disk

passages_path = os.path.join("/content/drive/MyDrive/data/", "reg_wiki_knowledge_dataset")
dataset = load_from_disk(passages_path)

In [None]:
######################################
logger.info("Step 2 - Index the dataset")
######################################

# HNSW arguments for FAISS
# - dimensionality of the embedding (768 for the BERT-base DPR context encoder)
d = 768
# - number of bi-directional links for every new element during index construction
m = 128

# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
index = faiss.IndexHNSWFlat(d, m, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)

# And save the index
index_path = os.path.join('/content/drive/MyDrive/data/', "reg_wiki_knowledge_dataset_hnsw_index.faiss")
dataset.get_index("embeddings").save(index_path)
# dataset.load_faiss_index("embeddings", index_path)  # to reload the index

  0%|          | 0/1 [00:00<?, ?it/s]
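
HNSW is an approximate search: it trades a little recall for speed. What it approximates is an exact maximum inner product search, which can be sketched in plain Python for a handful of vectors. The three-dimensional vectors and the `exact_top_k` helper are hypothetical stand-ins for the 768-dimensional DPR embeddings and the FAISS index.

```python
import heapq


def inner_product(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def exact_top_k(query, passages, k=2):
    """Return (index, score) pairs for the k passages with the largest inner product."""
    scored = ((i, inner_product(query, vec)) for i, vec in enumerate(passages))
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


# Toy passage embeddings standing in for the "embeddings" column.
passage_embeddings = [
    [1.0, 0.0, 0.0],
    [0.0, 2.0, 0.0],
    [0.25, 0.25, 0.0],
    [0.0, 0.0, 3.0],
]
query_embedding = [1.0, 1.0, 0.0]

top = exact_top_k(query_embedding, passage_embeddings, k=2)
print(top)
```

This brute-force version scores every passage for every query, so its cost grows linearly with the corpus. The HNSW graph avoids that by visiting only a small neighbourhood of candidates.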

In [None]:
index_path = os.path.join('/content/drive/MyDrive/data/', "reg_wiki_knowledge_dataset_hnsw_index.faiss")
dataset.load_faiss_index("embeddings", index_path)  # to reload the index

In [None]:
######################################
logger.info("Step 3 - Load RAG")
######################################

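# Note: "facebook/rag-token-nq" is the RAG-Token checkpoint, but it is loaded below with
# RagSequenceForGeneration. The matching pairs are "facebook/rag-sequence-nq" with
# RagSequenceForGeneration, or "facebook/rag-token-nq" with RagTokenForGeneration.
rag_model_name = "facebook/rag-token-nq"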

# Easy way to load the model
retriever = RagRetriever.from_pretrained(
    rag_model_name, index_name="custom", indexed_dataset=dataset
)
model = RagSequenceForGeneration.from_pretrained(rag_model_name, retriever=retriever)
tokenizer = RagTokenizer.from_pretrained(rag_model_name)

# For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
# retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)



  f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The t

In [None]:
######################################
logger.info("Step 4 - Have fun")
######################################

#question = "What is the new starter checklist? "
#question = "What are regular events at the Turing?"
question = "The regular events organised by REG are"
input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
generated = model.generate(input_ids)
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
print("Q: " + question)
print("A: " + generated_string)



Q: The regular events organised by REG are
A:  hack week


In [None]:
print(tokenizer.batch_decode(generated, skip_special_tokens=True))

[' town hall']
