In [1]:
%pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

In [2]:
%pip install faiss-cpu --no-cache

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl (30.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m87.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0


In [10]:
import os
import torch
from datasets import load_dataset, Dataset
from transformers import (
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
)

# === Paths ===
dataset_path = "/content/pubmed_subset_dataset"
index_path = os.path.join(dataset_path, "faiss_index")

# === Step 1: Load & Shrink Dataset ===
print("Loading dataset...")
dataset = load_dataset("timaeus/dsir-pile-100k-filtered-for-pubmed-abstracts", split="train[:10]")  # Use fewer for faster indexing

# Prepare for RAG: only need 'text'
rag_dataset = Dataset.from_dict({"text": dataset["contents"]})
rag_dataset = rag_dataset.map(lambda x, i: {"title": f"PubMed Doc {i}"}, with_indices=True)

rag_dataset.save_to_disk(dataset_path)

# === Step 2: Embed with DPR Encoder ===
print("Embedding passages...")
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

def embed_texts(batch):
    inputs = ctx_tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt", max_length=256)
    with torch.no_grad():
        embeddings = ctx_encoder(**inputs).pooler_output
    return {"embeddings": embeddings.cpu().numpy()}

rag_dataset = rag_dataset.map(embed_texts, batched=True, batch_size=16)
rag_dataset.add_faiss_index(column="embeddings")

# === Step 3: Save Dataset + FAISS Index ===
print("Saving dataset and index...")

# Save FAISS index to disk
rag_dataset.get_index("embeddings").save(index_path)

# Drop the index from dataset before saving it
rag_dataset.drop_index("embeddings")

# Save dataset
rag_dataset.save_to_disk(dataset_path)

# === Step 4: Load RAG Components ===
print("Loading RAG model and retriever...")
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="custom",
    passages_path=dataset_path,
    index_path=index_path,
    use_dummy_dataset=False,
)

model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

# === Step 5: Ask a Question ===
print("Generating answer...")
question = "What is the treatment for influenza?"
inputs = tokenizer.prepare_seq2seq_batch([question], return_tensors="pt")

with torch.no_grad():
    generated = model.generate(input_ids=inputs["input_ids"])

answer = tokenizer.batch_decode(generated, skip_special_tokens=True)
print("Answer:", answer[0])

Loading dataset...


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Embedding passages...


Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokeniz

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Saving dataset and index...


Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Loading RAG model and retriever...


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

pytorch_model.bin:   0%|          | 0.00/2.06G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.06G [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/rag-sequence-nq were not used when initializing RagSequenceForGeneration: ['rag.question_encoder.question_encoder.bert_model.pooler.dense.bias', 'rag.question_encoder.question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing RagSequenceForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RagSequenceForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Generating answer...
Answer:  rectilinear duct


In [11]:
# === Step 5: Ask a Question ===
print("Generating answer...")
question = "What disease does Borrelia miyamotoi cause?"
inputs = tokenizer.prepare_seq2seq_batch([question], return_tensors="pt")

with torch.no_grad():
    generated = model.generate(input_ids=inputs["input_ids"])

answer = tokenizer.batch_decode(generated, skip_special_tokens=True)
print("Answer:", answer[0])

Generating answer...
Answer:  delayed repair
