In [1]:
import torch
from transformers import (
    RagTokenForGeneration, 
    RagTokenizer, 
    RagRetriever, 
    RagConfig, 
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer
)
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from tqdm import tqdm
import logging
import os
from datasets import Dataset as HFDataset
import numpy as np
import faiss
import gc

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Optimized configuration for 8GB RAM
CONFIG = {
    "model_name": "facebook/rag-sequence-nq",
    "question_encoder_name": "facebook/dpr-question_encoder-single-nq-base",
    "max_length": 128,  # Further reduced
    "batch_size": 1,    # Minimum batch size
    "num_epochs": 3,
    "learning_rate": 1e-5,
    "dataset_path": "custom_dataset",
    "index_path": "custom_index.faiss",
    "chunk_size": 50    # Smaller chunks
}

# Memory management helper
def clear_memory():
    gc.collect()
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    if hasattr(torch.mps, 'empty_cache'):
        torch.mps.empty_cache()

# Create directories
os.makedirs("custom_dataset", exist_ok=True)

# Load data in chunks
def process_data_in_chunks(df, chunk_size=CONFIG["chunk_size"]):
    chunks = [df[i:i + chunk_size] for i in range(0, len(df), chunk_size)]
    processed_chunks = []
    
    for chunk in tqdm(chunks, desc="Processing chunks"):
        chunk_dataset = HFDataset.from_pandas(
            pd.DataFrame({
                'text': chunk['articles'].tolist(),
                'title': [f"Article {i}" for i in range(len(chunk))],
                'id': list(range(len(chunk)))
            })
        )
        processed_chunks.append(chunk_dataset)
        clear_memory()
    
    return processed_chunks

# Load and prepare passages in chunks
articles_df = pd.read_csv("articles.csv", usecols=["articles"])
dataset_chunks = process_data_in_chunks(articles_df)

# Initialize question encoder with memory optimization
question_encoder = DPRQuestionEncoder.from_pretrained(
    CONFIG["question_encoder_name"],
    torch_dtype=torch.float16  # Use half precision
)
question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(CONFIG["question_encoder_name"])

# Optimized embedding computation
@torch.no_grad()
def compute_embeddings(batch):
    try:
        encodings = question_encoder_tokenizer(
            batch['text'],
            max_length=CONFIG["max_length"],
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        outputs = question_encoder(
            input_ids=encodings['input_ids'],
            attention_mask=encodings['attention_mask']
        )
        embeddings = outputs.pooler_output.cpu().numpy()
        clear_memory()
        
        return {'embeddings': embeddings}
    except Exception as e:
        logging.error(f"Error computing embeddings: {e}")
        return {'embeddings': np.zeros((len(batch['text']), 768))}

# Process embeddings in chunks
all_embeddings = []
for chunk in dataset_chunks:
    chunk_with_embeddings = chunk.map(
        compute_embeddings,
        batched=True,
        batch_size=CONFIG["batch_size"]
    )
    all_embeddings.extend(chunk_with_embeddings['embeddings'])
    clear_memory()

# Create and save FAISS index
dimension = 768
index = faiss.IndexFlatL2(dimension)
embeddings_array = np.array(all_embeddings, dtype=np.float32)
index.add(embeddings_array)
faiss.write_index(index, CONFIG["index_path"])
clear_memory()

# Save processed dataset
combined_dataset = HFDataset.from_pandas(
    pd.DataFrame({
        'text': articles_df['articles'].tolist(),
        'title': [f"Article {i}" for i in range(len(articles_df))],
        'id': list(range(len(articles_df))),
        'embeddings': all_embeddings
    })
)
combined_dataset.save_to_disk(CONFIG["dataset_path"])
clear_memory()

# Optimized Dataset class
class CQADataset(Dataset):
    def __init__(self, df, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.inputs = []
        self.targets = []
        
        for chunk_start in range(0, len(df), CONFIG["chunk_size"]):
            chunk = df[chunk_start:chunk_start + CONFIG["chunk_size"]]
            self._process_chunk(chunk)
            clear_memory()
    
    def _process_chunk(self, chunk):
        for _, row in chunk.iterrows():
            try:
                article = row["articles"]
                question = "What is the content of this article?"
                
                inputs = self.tokenizer.question_encoder(
                    question,
                    max_length=self.max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt"
                )
                
                targets = self.tokenizer(
                    article,
                    max_length=self.max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt"
                )
                
                self.inputs.append({
                    "input_ids": inputs["input_ids"].squeeze(),
                    "attention_mask": inputs["attention_mask"].squeeze(),
                })
                self.targets.append(targets["input_ids"].squeeze())
                
            except Exception as e:
                logging.error(f"Error processing row: {e}")
                continue

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {
            "input_ids": self.inputs[idx]["input_ids"],
            "attention_mask": self.inputs[idx]["attention_mask"],
            "labels": self.targets[idx],
        }

# Initialize model components with memory optimization
tokenizer = RagTokenizer.from_pretrained(CONFIG["model_name"])
config = RagConfig.from_pretrained(CONFIG["model_name"])
config.index_name = "custom"
config.passages_path = CONFIG["dataset_path"]
config.index_path = CONFIG["index_path"]

retriever = RagRetriever.from_pretrained(
    CONFIG["model_name"],
    index_name="custom",
    passages_path=CONFIG["dataset_path"],
    index_path=CONFIG["index_path"],
    config=config
)

# Initialize model with memory optimization
model = RagTokenForGeneration.from_pretrained(
    CONFIG["model_name"],
    config=config,
    torch_dtype=torch.float16  # Use half precision
)
model.set_retriever(retriever)

# Training setup
device = torch.device("mps")  # Use Metal Performance Shaders for M2 Mac
model.to(device)

# Initialize optimizer with gradient clipping
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Gradient accumulation steps
gradient_accumulation_steps = 8  # Increased for better memory handling

# Training loop with proper loss handling and error recovery
for epoch in range(CONFIG["num_epochs"]):
    model.train()
    total_loss = 0
    optimizer.zero_grad()
    
    train_dataset = CQADataset(articles_df, tokenizer, max_length=CONFIG["max_length"])
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=CONFIG["batch_size"],
        shuffle=True
    )
    
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{CONFIG['num_epochs']}")
    
    for i, batch in enumerate(progress_bar):
        try:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass with gradient computation
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            # Ensure loss is a scalar tensor
            loss = outputs.loss
            if not isinstance(loss, torch.Tensor):
                loss = torch.tensor(loss, requires_grad=True, device=device)
            
            # Scale loss for gradient accumulation
            scaled_loss = loss / gradient_accumulation_steps
            
            # Backward pass
            scaled_loss.backward()
            
            if (i + 1) % gradient_accumulation_steps == 0:
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                # Optimizer step
                optimizer.step()
                optimizer.zero_grad()
                clear_memory()

            current_loss = loss.item()
            total_loss += current_loss
            progress_bar.set_postfix({'loss': current_loss})
            
        except RuntimeError as e:
            logging.error(f"Runtime error during training: {e}")
            optimizer.zero_grad()
            clear_memory()
            continue
            
        except Exception as e:
            logging.error(f"Unexpected error during training: {e}")
            optimizer.zero_grad()
            clear_memory()
            continue
    
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")
    
    # Update learning rate
    scheduler.step(avg_loss)
    
    # Clear memory between epochs
    clear_memory()
    
    # Save checkpoint
    if (epoch + 1) % 1 == 0:
        checkpoint_dir = f"models/checkpoint-epoch-{epoch + 1}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        model.save_pretrained(checkpoint_dir)
        tokenizer.save_pretrained(checkpoint_dir)

# Save the final model
os.makedirs("models", exist_ok=True)
model.save_pretrained("models/rag_model_final")
tokenizer.save_pretrained("models/rag_model_final")

Processing chunks: 100%|██████████| 1/1 [00:00<00:00,  5.89it/s]
Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Map:   0%|          | 0/15 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/15 [00:00<?, ? examples/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

Epoch 1, Average Loss: nan


Epoch 2/3:   0%|          | 0/15 [00:00<?, ?it/s]2025-01-03 00:02:45,520 - ERROR - Unexpected error during training: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 768
Epoch 2/3:   7%|▋         | 1/15 [00:00<00:05,  2.57it/s]2025-01-03 00:02:45,630 - ERROR - Unexpected error during training: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 768
Epoch 2/3:  13%|█▎        | 2/15 [00:00<00:02,  4.46it/s]2025-01-03 00:02:45,745 - ERROR - Unexpected error during training: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 768
Epoch 2/3:  20%|██        | 3/15 [00:00<00:02,  5.77it/s]2025-01-03 00:02:45,869 - ERROR - Unexpect

Epoch 2, Average Loss: 0.0000


Epoch 3/3:   0%|          | 0/15 [00:00<?, ?it/s]2025-01-03 00:02:52,405 - ERROR - Unexpected error during training: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 768
Epoch 3/3:   7%|▋         | 1/15 [00:00<00:12,  1.14it/s]2025-01-03 00:02:52,717 - ERROR - Unexpected error during training: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 768
Epoch 3/3:  13%|█▎        | 2/15 [00:01<00:05,  2.24it/s]2025-01-03 00:02:52,844 - ERROR - Unexpected error during training: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 768
Epoch 3/3:  20%|██        | 3/15 [00:01<00:03,  3.34it/s]2025-01-03 00:02:52,982 - ERROR - Unexpect

Epoch 3, Average Loss: 0.0000
