In [1]:
# Cell 1: Setup and Dataset Preparation
import torch
from transformers import (
    RagTokenForGeneration, 
    RagTokenizer, 
    RagRetriever, 
    RagConfig, 
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer
)
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from tqdm.notebook import tqdm
import logging
import os
from datasets import Dataset as HFDataset
import numpy as np
import faiss

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configuration
CONFIG = {
    "model_name": "facebook/rag-sequence-nq",
    "question_encoder_name": "facebook/dpr-question_encoder-single-nq-base",
    "max_length": 512,
    "batch_size": 4,
    "num_epochs": 3,
    "learning_rate": 1e-5,
    "dataset_path": "custom_dataset",
    "index_path": "custom_index.faiss"
}

# Create directories
os.makedirs("custom_dataset", exist_ok=True)

# Load and prepare passages
articles_df = pd.read_csv("articles.csv", usecols=["articles"])

# Create HuggingFace dataset
custom_dataset = HFDataset.from_pandas(
    pd.DataFrame({
        'text': articles_df['articles'].tolist(),
        'title': [f"Article {i}" for i in range(len(articles_df))],
        'id': list(range(len(articles_df)))
    })
)

# Initialize question encoder and tokenizer
question_encoder = DPRQuestionEncoder.from_pretrained(CONFIG["question_encoder_name"])
question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(CONFIG["question_encoder_name"])

# Function to compute embeddings
def compute_embeddings(batch):
    # First tokenize the texts
    encodings = question_encoder_tokenizer(
        batch['text'],
        max_length=CONFIG["max_length"],
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    # Then get the embeddings
    with torch.no_grad():
        outputs = question_encoder(
            input_ids=encodings['input_ids'],
            attention_mask=encodings['attention_mask']
        )
        embeddings = outputs.pooler_output
    
    return {'embeddings': embeddings.numpy()}

# Add embeddings to dataset
custom_dataset = custom_dataset.map(
    compute_embeddings,
    batched=True,
    batch_size=CONFIG["batch_size"]
)

# Create FAISS index
dimension = 768  # DPR embedding dimension
index = faiss.IndexFlatL2(dimension)
embeddings_array = np.array(custom_dataset['embeddings'])
index.add(embeddings_array)

# Save dataset and index
custom_dataset.save_to_disk(CONFIG["dataset_path"])
faiss.write_index(index, CONFIG["index_path"])

# Cell 2: Dataset Class and Model Setup
class CQADataset(Dataset):
    def __init__(self, df, tokenizer, max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.inputs = []
        self.targets = []
        
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Preparing dataset"):
            try:
                article = row["articles"]
                question = "What is the content of this article?"
                
                inputs = self.tokenizer.question_encoder(
                    question,
                    max_length=self.max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt",
                )
                
                targets = self.tokenizer(
                    article,
                    max_length=self.max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt",
                )
                
                self.inputs.append({
                    "input_ids": inputs["input_ids"].squeeze(),
                    "attention_mask": inputs["attention_mask"].squeeze(),
                })
                self.targets.append(targets["input_ids"].squeeze())
                
            except Exception as e:
                logging.error(f"Error processing row: {e}")
                continue

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {
            "input_ids": self.inputs[idx]["input_ids"],
            "attention_mask": self.inputs[idx]["attention_mask"],
            "labels": self.targets[idx],
        }

# Initialize RAG components
tokenizer = RagTokenizer.from_pretrained(CONFIG["model_name"])
config = RagConfig.from_pretrained(CONFIG["model_name"])
config.index_name = "custom"
config.passages_path = CONFIG["dataset_path"]
config.index_path = CONFIG["index_path"]

# Initialize retriever with custom dataset and index
retriever = RagRetriever.from_pretrained(
    CONFIG["model_name"],
    index_name="custom",
    passages_path=CONFIG["dataset_path"],
    index_path=CONFIG["index_path"],
    config=config
)

# Initialize model
model = RagTokenForGeneration.from_pretrained(CONFIG["model_name"], config=config)
model.set_retriever(retriever)

# Prepare dataset for training
train_dataset = CQADataset(articles_df, tokenizer, max_length=CONFIG["max_length"])
train_dataloader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)

# Cell 3: Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])

for epoch in range(CONFIG["num_epochs"]):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{CONFIG['num_epochs']}")
    
    for batch in progress_bar:
        optimizer.zero_grad()
        
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")

# Save the final model
os.makedirs("models", exist_ok=True)
model.save_pretrained("models/rag_model_final")
tokenizer.save_pretrained("models/rag_model_final")

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Map:   0%|          | 0/15 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/15 [00:00<?, ? examples/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

Preparing dataset:   0%|          | 0/15 [00:00<?, ?it/s]

Epoch 1/3:   0%|          | 0/4 [00:00<?, ?it/s]

: 