In [2]:
%pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2024.12.0-py3-none-any.w

In [3]:
%pip install faiss-cpu --no-cache

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl (30.7 MB)
Installing collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0


In [4]:
import os
import torch
from datasets import load_dataset, Dataset
from transformers import (
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
)

# === Paths ===
dataset_path = "/content/pubmed_subset_dataset"
index_path = os.path.join(dataset_path, "faiss_index")

# === Step 1: Load & Shrink Dataset ===
print("Loading dataset...")
dataset = load_dataset("timaeus/dsir-pile-100k-filtered-for-pubmed-abstracts", split="train[:10]")  # Use fewer for faster indexing

# Prepare for RAG: a custom index expects "title", "text", and "embeddings" columns (embeddings are added in Step 2)
rag_dataset = Dataset.from_dict({"text": dataset["contents"]})
rag_dataset = rag_dataset.map(lambda x, i: {"title": f"PubMed Doc {i}"}, with_indices=True)

rag_dataset.save_to_disk(dataset_path)

# === Step 2: Embed with DPR Encoder ===
print("Embedding passages...")
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

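def embed_texts(batch):
    # Tokenize a batch of passages; anything beyond 256 tokens is truncated.
    inputs = ctx_tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt", max_length=256)
    with torch.no_grad():
        # pooler_output is the DPR passage embedding, one vector per passage.
        embeddings = ctx_encoder(**inputs).pooler_output
    return {"embeddings": embeddings.cpu().numpy()}

rag_dataset = rag_dataset.map(embed_texts, batched=True, batch_size=16)
# With no string_factory or metric_type, this builds a flat FAISS index using the default (L2) metric.
rag_dataset.add_faiss_index(column="embeddings")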

# === Step 3: Save Dataset + FAISS Index ===
print("Saving dataset and index...")

# Save FAISS index to disk
rag_dataset.get_index("embeddings").save(index_path)

# Drop the index from dataset before saving it
rag_dataset.drop_index("embeddings")

# Save dataset
rag_dataset.save_to_disk(dataset_path)

# === Step 4: Load RAG Components ===
print("Loading RAG model and retriever...")
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="custom",
    passages_path=dataset_path,
    index_path=index_path,
    use_dummy_dataset=False,
)

model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)


Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Embedding passages...


Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


Saving dataset and index...


Loading RAG model and retriever...


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizerFast'.

Some weights of the model checkpoint at facebook/rag-sequence-nq were not used when initializing RagSequenceForGeneration: ['rag.question_encoder.question_encoder.bert_model.pooler.dense.bias', 'rag.question_encoder.question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing RagSequenceForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RagSequenceForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
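
Note on the index metric: the index built in Step 2 uses the default flat FAISS index, which ranks by L2 distance, while DPR embeddings are trained to be compared by inner product. The two can rank the same passages differently. The sketch below shows this with tiny made-up 2-D vectors in plain Python; the vectors and helper names are illustrative assumptions, not values from this notebook.

```python
import math

def inner_product(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def l2_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def rank_by_inner_product(query, passages):
    # Higher inner product means a better match, so sort descending.
    return sorted(range(len(passages)), key=lambda i: -inner_product(query, passages[i]))

def rank_by_l2(query, passages):
    # Smaller distance means a better match, so sort ascending.
    return sorted(range(len(passages)), key=lambda i: l2_distance(query, passages[i]))

# Made-up toy embeddings (assumption: real DPR vectors are 768-dimensional).
query = [1.0, 0.0]
passages = [
    [0.9, 0.1],  # close to the query, small norm
    [3.0, 1.0],  # similar direction, large norm
    [0.0, 1.0],  # orthogonal to the query
]

ip_ranking = rank_by_inner_product(query, passages)
l2_ranking = rank_by_l2(query, passages)
print("inner product ranking:", ip_ranking)
print("L2 ranking:           ", l2_ranking)
```

If inner-product ranking is wanted in the notebook, one option (an assumption to verify against the installed `datasets` version) is to pass `metric_type=faiss.METRIC_INNER_PRODUCT` to `add_faiss_index`.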


In [7]:

# === Step 5: Ask a Question ===
print("Generating answer...")
question = "What is the treatment for influenza?"
inputs = tokenizer([question], return_tensors="pt")  # prepare_seq2seq_batch is deprecated

with torch.no_grad():
    generated = model.generate(input_ids=inputs["input_ids"])

answer = tokenizer.batch_decode(generated, skip_special_tokens=True)
print("Answer:", answer[0])

Generating answer...
Answer:  rectilinear duct


In [15]:
# === Step 6: Show Retrieved Documents ===
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

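# Load a standalone DPR question encoder. The RAG checkpoint carries its own question encoder
# (model.question_encoder), so the ranking below may differ from what generate() retrieved.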
question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

# Encode the same question
question_inputs = question_tokenizer(question, return_tensors="pt")
with torch.no_grad():
    question_hidden_states = question_encoder(**question_inputs).pooler_output.cpu().numpy()

# Use retriever to get top-k doc indices
retrieval_output = retriever(
    question_input_ids=question_inputs["input_ids"],
    question_hidden_states=question_hidden_states
)
k = 5
retrieved_doc_ids = retrieval_output["doc_ids"][0][:k].tolist()

# Print retrieved docs
print(f"\nTop {k} Retrieved Documents for: '{question}'\n" + "="*60)
for idx in retrieved_doc_ids:
    print(f"[Doc {idx}]\n{dataset[idx]['contents']}\n" + "-"*60)


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



Top 5 Retrieved Documents for: 'What is the treatment for influenza?'
[Doc 6]
Trends in classification usage in the mental retardation literature.
Trends in classification usage were analyzed through examination of all issues of Mental Retardation, American Journal of Mental Deficiency, and American Journal on Mental Retardation from 1980 through 1989. The research was undertaken to determine whether the recommendations by Taylor (1980) and MacMillan, Meyers, and Morrison (1980) regarding subject description had been implemented. Results indicated that the system of the American Association on Mental Retardation (previously the American Association on Mental Deficiency) was used in over 50% of the articles, whereas the American Educators system was used in only 10%. A further analysis regarding the use of various classification systems as a function of the age of the subjects was also conducted. Implications of these results were discussed.The girl, Chrissiya Berry, was listed as stab
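
Sanity check on corpus coverage: only 10 passages were indexed, so the retriever returns its top documents whether or not any of them concern the question. A quick keyword scan can show whether the topic appears in the corpus at all before the generated answer is trusted. The sketch below is a minimal, assumed helper; `passages_mentioning` and the sample strings are hypothetical and are not the real abstracts.

```python
def passages_mentioning(term, passages):
    """Return indices of passages whose text contains `term` (case-insensitive)."""
    needle = term.lower()
    return [i for i, text in enumerate(passages) if needle in text.lower()]

# Hypothetical stand-ins for rag_dataset["text"]; not the real abstracts.
sample_passages = [
    "Trends in classification usage in the mental retardation literature.",
    "A made-up sentence that happens to mention influenza.",
    "Another made-up sentence about an unrelated topic.",
]

hits = passages_mentioning("influenza", sample_passages)
print("Passages mentioning the topic:", hits)
```

In the notebook, `rag_dataset["text"]` could be passed in place of `sample_passages`. An empty result would suggest the generator had nothing relevant to condition on, which would be consistent with answers such as "rectilinear duct".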

In [4]:
# === Step 5: Ask a Question ===
print("Generating answer...")
question = "What disease does Borrelia miyamotoi cause?"
inputs = tokenizer([question], return_tensors="pt")

with torch.no_grad():
    generated = model.generate(input_ids=inputs["input_ids"])

answer = tokenizer.batch_decode(generated, skip_special_tokens=True)
print("Answer:", answer[0])

Generating answer...
Answer:  delayed repair
