In [1]:
import torch
from transformers import (
    RagTokenForGeneration, 
    RagTokenizer, 
    RagRetriever, 
    RagConfig, 
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer
)
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from tqdm import tqdm
import logging
import os
from datasets import Dataset as HFDataset
import numpy as np
import faiss
import gc

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# More aggressive memory optimization configuration
CONFIG = {
    "model_name": "facebook/rag-sequence-nq",
    "question_encoder_name": "facebook/dpr-question_encoder-single-nq-base",
    "generator_name": "facebook/bart-large",
    "max_length": 64,           # Further reduced
    "batch_size": 1,
    "num_epochs": 3,
    "learning_rate": 1e-5,
    "dataset_path": "custom_dataset",
    "index_path": "custom_index.faiss",
    "chunk_size": 25,          # Smaller chunks
    "gradient_accumulation_steps": 16,  # Increased
    "max_retrieved_passages": 1,  # Reduce number of retrieved passages
    "train_n_passages": 1       # Reduce training passages
}

# Add environment variable to control MPS memory usage
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.5"  # Limit MPS memory usage to 50%

# More aggressive memory management
def clear_memory():
    gc.collect()
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    if hasattr(torch.mps, 'empty_cache'):
        torch.mps.empty_cache()
    # Force Python garbage collection
    gc.collect(generation=2)

# Modified RagConfig initialization
def get_optimized_config():
    config = RagConfig.from_pretrained(CONFIG["model_name"])
    config.index_name = "custom"
    config.passages_path = CONFIG["dataset_path"]
    config.index_path = CONFIG["index_path"]
    # Memory optimization settings
    config.n_docs = CONFIG["max_retrieved_passages"]
    config.max_combined_length = CONFIG["max_length"]
    config.train_n_passages = CONFIG["train_n_passages"]
    return config

# Modified embedding computation with memory optimization
@torch.no_grad()
def compute_embeddings(batch):
    try:
        # Process one item at a time to reduce memory usage
        embeddings_list = []
        for text in batch['text']:
            encoding = question_encoder_tokenizer(
                text,
                max_length=CONFIG["max_length"],
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            
            output = question_encoder(
                input_ids=encoding['input_ids'],
                attention_mask=encoding['attention_mask']
            )
            embedding = output.pooler_output.cpu().numpy()
            embeddings_list.append(embedding[0])
            clear_memory()
        
        return {'embeddings': np.stack(embeddings_list)}
    except Exception as e:
        logging.error(f"Error computing embeddings: {e}")
        return {'embeddings': np.zeros((len(batch['text']), 768))}

# Modified Dataset class with memory optimization
class CQADataset(Dataset):
    def __init__(self, df, rag_tokenizer, max_length=64):
        self.rag_tokenizer = rag_tokenizer
        self.max_length = max_length
        self.data = []
        
        for chunk_start in range(0, len(df), CONFIG["chunk_size"]):
            chunk = df[chunk_start:chunk_start + CONFIG["chunk_size"]]
            self._process_chunk(chunk)
            clear_memory()
    
    def _process_chunk(self, chunk):
        for _, row in chunk.iterrows():
            try:
                article = row["articles"]
                question = "What is the content of this article?"
                
                # Process with smaller max length
                encoded = self.rag_tokenizer(
                    question,
                    article,
                    max_length=self.max_length,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt"
                )
                
                self.data.append({
                    "input_ids": encoded["input_ids"].squeeze(),
                    "attention_mask": encoded["attention_mask"].squeeze(),
                    "labels": encoded["labels"].squeeze()
                })
                
            except Exception as e:
                logging.error(f"Error processing row: {e}")
                continue
            
            # Clear memory after each item
            clear_memory()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Modified training loop with memory optimization
def train_epoch(model, dataloader, optimizer, scheduler, device, epoch):
    model.train()
    total_loss = 0
    optimizer.zero_grad()
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{CONFIG['num_epochs']}")
    
    for i, batch in enumerate(progress_bar):
        try:
            # Move tensors to device one at a time
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass with memory optimization
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                n_docs=CONFIG["max_retrieved_passages"]
            )
            
            loss = outputs.loss
            if not isinstance(loss, torch.Tensor):
                loss = torch.tensor(loss, requires_grad=True, device=device)
            
            # Gradient accumulation
            scaled_loss = loss / CONFIG["gradient_accumulation_steps"]
            scaled_loss.backward()
            
            if (i + 1) % CONFIG["gradient_accumulation_steps"] == 0:
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)  # Reduced max norm
                optimizer.step()
                optimizer.zero_grad()
                clear_memory()

            # Update progress
            current_loss = loss.item()
            total_loss += current_loss
            progress_bar.set_postfix({'loss': current_loss})
            
            # Clear memory after each batch
            del input_ids, attention_mask, labels, outputs, loss, scaled_loss
            clear_memory()
            
        except RuntimeError as e:
            logging.error(f"Runtime error during training: {e}")
            optimizer.zero_grad()
            clear_memory()
            continue
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss

# Main training setup
def main():
    # Initialize components with memory optimization
    config = get_optimized_config()
    
    # Initialize model with memory optimization
    model = RagTokenForGeneration.from_pretrained(
        CONFIG["model_name"],
        config=config,
        torch_dtype=torch.float16
    )
    
    # Move model to device
    device = torch.device("mps")
    model.to(device)
    
    # Initialize optimizer with lower memory usage
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CONFIG["learning_rate"],
        weight_decay=0.01,
        eps=1e-8
    )
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=2,
        verbose=True
    )
    
    # Training loop
    for epoch in range(CONFIG["num_epochs"]):
        try:
            train_dataset = CQADataset(
                articles_df,
                rag_tokenizer,
                max_length=CONFIG["max_length"]
            )
            train_dataloader = DataLoader(
                train_dataset,
                batch_size=CONFIG["batch_size"],
                shuffle=True
            )
            
            avg_loss = train_epoch(
                model,
                train_dataloader,
                optimizer,
                scheduler,
                device,
                epoch
            )
            
            print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")
            scheduler.step(avg_loss)
            
            # Save checkpoint with memory optimization
            if (epoch + 1) % 1 == 0:
                checkpoint_dir = f"models/checkpoint-epoch-{epoch + 1}"
                os.makedirs(checkpoint_dir, exist_ok=True)
                model.save_pretrained(checkpoint_dir)
                rag_tokenizer.save_pretrained(checkpoint_dir)
            
            clear_memory()
            
        except Exception as e:
            logging.error(f"Error during epoch {epoch + 1}: {e}")
            clear_memory()
            continue

if __name__ == "__main__":
    main()


RuntimeError: invalid low watermark ratio 1.4