In [None]:
import pandas as pd

# Read 'filtered_arxiv_papers_metadata.csv' into a DataFrame for inspection in the cells below.
df = pd.read_csv('filtered_arxiv_papers_metadata.csv')

In [None]:
df['categories']

### First Approach 

In [3]:
# numpy is only imported in a later cell of this notebook; import it here so
# this cell also runs on a fresh kernel instead of raising NameError.
import numpy as np

# Load the embedding matrix stored in 'embeddings.npy'.
embeds = np.load('embeddings.npy')

In [None]:
embeds.shape

In [None]:
# Install required libraries (uncomment if needed)
# !pip install transformers datasets faiss-cpu sentence-transformers torch tqdm pandas numpy

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from tqdm import tqdm
import json
import logging
from pathlib import Path

# ==============================
# Configuration Variables
# ==============================

# Path to the arXiv metadata file that load_metadata() reads line by line
# (one JSON object per line); category filtering happens inside load_metadata().
INPUT_JSON = 'arxiv-metadata-oai-snapshot.json'  # Replace with your actual file path

# Path to save the FAISS index
FAISS_INDEX_PATH = 'faiss_index.index'

# Path to save the metadata mapping
METADATA_MAPPING_PATH = 'metadata_mapping.json'

# Number of top similar documents to retrieve
TOP_K = 5  # You can adjust this number as needed

# Path of the log file handed to logging.basicConfig below
LOG_FILE = 'rag_system_macos.log'

# Path to embeddings file (main() reuses it when it already exists)
EMBEDDINGS_PATH = 'embeddings.npy'

# Embedding model name
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

# Generation model name (using an instruction-tuned model)
GENERATION_MODEL_NAME = 'google/flan-t5-small'  # Switch to 'google/flan-t5-base' if resources allow

# Maximum input length for the generation model; generate_answer() uses it both
# as the tokenizer truncation limit and as max_length for generation
MAX_INPUT_LENGTH = 512

# Tokens subtracted from MAX_INPUT_LENGTH in main() to size the retrieved context
RESERVED_TOKENS = 100  # Adjust based on expected answer length

# ==============================
# Setup Logging
# ==============================

# Append INFO-and-above records to LOG_FILE with a timestamp and level prefix.
logging.basicConfig(
    filename=LOG_FILE,
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)

# ==============================
# Function Definitions
# ==============================

def load_metadata(json_path, categories):
    """
    Loads and filters the arXiv metadata from a JSON-lines file based on specified categories.

    Every line is parsed on its own. A line is skipped when it is not valid JSON,
    is not a JSON object, or shares no category with ``categories``. Missing or
    null text fields are treated as empty strings so that a single odd record
    cannot abort the whole load.

    Parameters:
        json_path (str): Path to the arXiv JSON dataset (one JSON object per line).
        categories (list): List of categories to filter by.

    Returns:
        pd.DataFrame: DataFrame containing filtered metadata. An empty DataFrame
        is returned when the file cannot be read (the error is logged and printed).
    """
    # Build the lookup set once instead of once per input line.
    wanted = set(categories)

    def text_field(paper, key):
        # A JSON null would otherwise reach .replace()/.strip() as None; the
        # resulting AttributeError was caught by the outer handler and turned
        # the entire load into an empty DataFrame.
        value = paper.get(key)
        return value if isinstance(value, str) else ''

    papers = []
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading and Filtering Papers"):
                try:
                    paper = json.loads(line)
                except json.JSONDecodeError:
                    continue  # Skip malformed lines
                if not isinstance(paper, dict):
                    continue  # Valid JSON but not a record (e.g. a bare list or number)

                paper_categories = text_field(paper, 'categories').split()
                if wanted.isdisjoint(paper_categories):
                    continue  # Skip papers not in the desired categories

                raw_authors = paper.get('authors')
                if isinstance(raw_authors, (list, tuple)):
                    # Tolerate a list of names as well as the plain string form.
                    authors = ', '.join(str(name) for name in raw_authors)
                else:
                    authors = text_field(paper, 'authors').strip()

                papers.append({
                    'arxiv_id': paper.get('id', ''),
                    'title': text_field(paper, 'title').replace('\n', ' ').strip(),
                    'authors': authors,
                    'abstract': text_field(paper, 'abstract').replace('\n', ' ').strip(),
                    'published_date': paper.get('published', ''),
                    'categories': ', '.join(paper_categories),
                    'arxiv_url': paper.get('link', ''),
                    'journal_ref': paper.get('journal-ref', ''),
                    'comment': paper.get('comment', ''),
                    'doi': paper.get('doi', '')
                })
        df = pd.DataFrame(papers)
        print(f"Loaded and filtered metadata for {len(df)} papers.")
        logging.info(f"Loaded and filtered metadata for {len(df)} papers.")
        return df
    except Exception as e:
        logging.error(f"Error loading metadata: {e}")
        print(f"Error loading metadata: {e}")
        return pd.DataFrame()

def preprocess_text(row):
    """
    Preprocesses the text by combining the title and abstract of one paper.

    Only 'title' and 'abstract' are read; the previously computed but unused
    'categories' value has been dropped, so rows without that column work too.

    Parameters:
        row (pd.Series): A row from the dataframe (any mapping with 'title' and
            'abstract' keys works).

    Returns:
        str: "<title>:<abstract>", with missing (None/NaN) values rendered as ''.
    """
    title = row['title'] if pd.notna(row['title']) else ''
    abstract = row['abstract'] if pd.notna(row['abstract']) else ''
    return f"{title}:{abstract}"

def generate_embeddings(texts, model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """
    Encodes texts with a freshly loaded SentenceTransformer model.

    Parameters:
        texts (list of str): Strings to embed.
        model_name (str): Name of the SentenceTransformer model to load.

    Returns:
        np.ndarray: Embeddings returned by ``encode`` (a progress bar is shown).
    """
    encoder = SentenceTransformer(model_name)
    return encoder.encode(texts, show_progress_bar=True, convert_to_numpy=True)

def build_faiss_index(embeddings, index_path):
    """
    Creates an inner-product FAISS index over the embeddings and writes it to disk.

    The array passed in is handed to ``faiss.normalize_L2`` directly (its result
    is not reassigned), so the caller's array is the one that gets normalized.
    Failures are logged and printed rather than raised.

    Parameters:
        embeddings (np.ndarray): Embedding matrix; its second dimension sets the index size.
        index_path (str): Destination file for the FAISS index.
    """
    try:
        # Unit-length vectors make inner product behave as cosine similarity.
        faiss.normalize_L2(embeddings)

        dimension = embeddings.shape[1]
        flat_index = faiss.IndexFlatIP(dimension)
        flat_index.add(embeddings)
        faiss.write_index(flat_index, index_path)

        message = f"FAISS index built and saved to {index_path}"
        print(message)
        logging.info(message)
    except Exception as e:
        failure = f"Error building FAISS index: {e}"
        logging.error(failure)
        print(failure)

def save_metadata_mapping(df, mapping_path):
    """
    Writes the DataFrame rows to a JSON file as a list of record dictionaries.

    The file is UTF-8 encoded, keeps non-ASCII characters as-is and is indented
    by four spaces. Failures are logged and printed rather than raised.

    Parameters:
        df (pd.DataFrame): Metadata to persist.
        mapping_path (str): Destination JSON file.
    """
    try:
        records = df.to_dict(orient='records')
        with open(mapping_path, 'w', encoding='utf-8') as out_file:
            json.dump(records, out_file, ensure_ascii=False, indent=4)
        message = f"Metadata mapping saved to {mapping_path}"
        print(message)
        logging.info(message)
    except Exception as e:
        failure = f"Error saving metadata mapping: {e}"
        logging.error(failure)
        print(failure)

def load_faiss_index(index_path):
    """
    Reads a FAISS index from disk.

    Parameters:
        index_path (str): File holding the serialized FAISS index.

    Returns:
        faiss.Index: The loaded index, or None when reading fails (the error is
        logged and printed).
    """
    try:
        loaded_index = faiss.read_index(index_path)
        message = f"FAISS index loaded from {index_path}"
        print(message)
        logging.info(message)
    except Exception as e:
        failure = f"Error loading FAISS index: {e}"
        logging.error(failure)
        print(failure)
        return None
    return loaded_index

def load_metadata_mapping(mapping_path):
    """
    Reads the metadata mapping back from its JSON file.

    Parameters:
        mapping_path (str): JSON file to read.

    Returns:
        list of dict: Parsed JSON content, or an empty list when the file cannot
        be read or parsed (the error is logged and printed).
    """
    try:
        with open(mapping_path, 'r', encoding='utf-8') as source:
            records = json.load(source)
        message = f"Metadata mapping loaded from {mapping_path}"
        print(message)
        logging.info(message)
    except Exception as e:
        failure = f"Error loading metadata mapping: {e}"
        logging.error(failure)
        print(failure)
        return []
    return records

def load_generation_model(model_name='google/flan-t5-base'):
    """
    Loads the seq2seq generation model with its tokenizer and places it on a device.

    Device preference follows the checks below: MPS first, then CUDA, then CPU.

    Parameters:
        model_name (str): Hugging Face model name.

    Returns:
        tuple: (tokenizer, model, device), or (None, None, None) when loading
        fails (the error is logged and printed).
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Pick the device name and the matching status line in one place.
        if torch.backends.mps.is_available():
            device_name = "mps"
            status = "Generation model loaded on macOS GPU (MPS)."
        elif torch.cuda.is_available():
            device_name = "cuda"
            status = "Generation model loaded on CUDA GPU."
        else:
            device_name = "cpu"
            status = "Generation model loaded on CPU."

        device = torch.device(device_name)
        model.to(device)
        print(status)

        logging.info(f"Generation model '{model_name}' loaded successfully on {device}.")
        return tokenizer, model, device
    except Exception as e:
        failure = f"Error loading generation model: {e}"
        logging.error(failure)
        print(failure)
        return None, None, None

def query_faiss(index, query_embedding, top_k=5):
    """
    Queries the FAISS index to retrieve the positions of the top_k most similar embeddings.

    The query array is normalized by ``faiss.normalize_L2`` before searching (the
    call's result is not reassigned, so the caller's array is the one modified).

    Parameters:
        index (faiss.Index): FAISS index.
        query_embedding (np.ndarray): Embedding of the query, one row per query;
            only the first row's results are returned.
        top_k (int): Number of top results to retrieve.

    Returns:
        np.ndarray of int: Positions of up to top_k matches, best first. Fewer
        than top_k entries are returned when the index has fewer vectors. An
        empty list is returned when the search fails (the error is logged and
        printed).
    """
    try:
        # Normalize the query embedding
        faiss.normalize_L2(query_embedding)

        # Perform the search
        _, indices = index.search(query_embedding, top_k)
        hits = indices[0]
        # FAISS fills missing results with -1 when fewer than top_k vectors are
        # available. Callers use these values to index a list, where -1 would
        # silently select the last paper, so drop the placeholders here.
        return hits[hits >= 0]
    except Exception as e:
        logging.error(f"Error querying FAISS index: {e}")
        print(f"Error querying FAISS index: {e}")
        return []

def generate_answer(tokenizer, model, device, prompt):
    """
    Runs beam-search generation on a prompt and returns the decoded text.

    The prompt is truncated to MAX_INPUT_LENGTH tokens, and the same constant is
    passed as ``max_length`` for the generated sequence.

    Parameters:
        tokenizer: Tokenizer of the generation model.
        model: Generation model.
        device: Device the encoded inputs are moved to.
        prompt (str): The input prompt.

    Returns:
        str: First generated sequence with special tokens removed, or a fixed
        apology string when generation fails (the error is logged and printed).
    """
    try:
        encoded = tokenizer(
            prompt,
            return_tensors='pt',
            max_length=MAX_INPUT_LENGTH,
            truncation=True,
        ).to(device)
        generated = model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_length=MAX_INPUT_LENGTH,
            num_beams=5,
            early_stopping=True
        )
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        failure = f"Error generating answer: {e}"
        logging.error(failure)
        print(failure)
        return "I'm sorry, I couldn't generate an answer for your query."

# ==============================
# Main Execution
# ==============================

def _build_context(tokenizer, papers, max_tokens):
    """
    Builds the retrieval context without exceeding a token budget.

    Parameters:
        tokenizer: Tokenizer of the generation model.
        papers (list of dict): Retrieved papers with 'title' and 'abstract' keys.
        max_tokens (int): Maximum number of context tokens.

    Returns:
        str: "title: abstract" snippets joined by blank lines.
    """
    snippets = []
    remaining = max_tokens
    for paper in papers:
        if remaining <= 0:
            break
        text = f"{paper['title']}: {paper['abstract']}"
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        if len(token_ids) > remaining:
            # Cut the snippet to what is left of the budget; decoding the kept
            # ids gives text that actually fits, unlike appending the full text.
            token_ids = token_ids[:remaining]
            text = tokenizer.decode(token_ids, skip_special_tokens=True)
        snippets.append(text)
        remaining -= len(token_ids)
    return "\n\n".join(snippets)


def main():
    """
    Runs the retrieval-augmented QA pipeline end to end.

    Steps: load and filter metadata, embed the papers (or reuse the cached
    embeddings file), build and persist the FAISS index and metadata mapping,
    load the generation model, then answer questions read from stdin until the
    user types 'exit' or 'quit'.
    """
    # Step 1: Load and filter the metadata
    categories = ['cs.AI', 'cs.LG', 'cs.CV', 'cs.CL', 'stat.ML']
    df = load_metadata(INPUT_JSON, categories)
    if df.empty:
        print("No data to process. Exiting.")
        return

    # Step 2: Preprocess text for embeddings
    print("Preprocessing text for embeddings...")
    df['combined_text'] = df.apply(preprocess_text, axis=1)
    combined_texts = df['combined_text'].tolist()

    # Step 3: Generate embeddings
    print("Generating embeddings...")
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
    if Path(EMBEDDINGS_PATH).exists():
        print(f"Embeddings file '{EMBEDDINGS_PATH}' found. Loading embeddings...")
        embeddings = np.load(EMBEDDINGS_PATH)
        # Row i of the embeddings must describe row i of the metadata. A cache
        # built from a different paper set would make every lookup return the
        # wrong paper, so stop instead of answering from misaligned data.
        if embeddings.shape[0] != len(df):
            message = (
                f"Embeddings file '{EMBEDDINGS_PATH}' has {embeddings.shape[0]} rows "
                f"but {len(df)} papers were loaded. Delete or replace the file and rerun."
            )
            logging.error(message)
            print(message)
            return
    else:
        # Reuse the model loaded above instead of loading a second copy.
        embeddings = embedding_model.encode(combined_texts, show_progress_bar=True, convert_to_numpy=True)
        np.save(EMBEDDINGS_PATH, embeddings)
        print(f"Embeddings saved to {EMBEDDINGS_PATH}")

    # Step 4: Build and save FAISS index
    print("Building FAISS index...")
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])  # Inner Product for cosine similarity
    index.add(embeddings)
    faiss.write_index(index, FAISS_INDEX_PATH)
    print(f"FAISS index built and saved to {FAISS_INDEX_PATH}")
    logging.info(f"FAISS index built and saved to {FAISS_INDEX_PATH}")

    # Step 5: Save metadata mapping
    print("Saving metadata mapping...")
    mapping = df.to_dict(orient='records')
    with open(METADATA_MAPPING_PATH, 'w', encoding='utf-8') as f:
        json.dump(mapping, f, ensure_ascii=False, indent=4)
    print(f"Metadata mapping saved to {METADATA_MAPPING_PATH}")
    logging.info(f"Metadata mapping saved to {METADATA_MAPPING_PATH}")

    # Step 6: Load FAISS index and metadata mapping
    index = load_faiss_index(FAISS_INDEX_PATH)
    if index is None:
        print("Failed to load FAISS index. Exiting.")
        return
    metadata_mapping = load_metadata_mapping(METADATA_MAPPING_PATH)
    if not metadata_mapping:
        print("Failed to load metadata mapping. Exiting.")
        return

    # Step 7: Load generation model
    print("Loading generation model...")
    tokenizer, model, device = load_generation_model(model_name=GENERATION_MODEL_NAME)
    if tokenizer is None or model is None:
        print("Failed to load generation model. Exiting.")
        return

    # Step 8: Interactive Querying
    print("\nRAG System is ready. You can now ask questions related to AI papers.")
    print("Type 'exit' to quit.\n")

    # Tokens left for retrieved context; the remainder of MAX_INPUT_LENGTH is
    # kept free for the instruction, the question and the "Answer:" cue.
    max_context_length = MAX_INPUT_LENGTH - RESERVED_TOKENS

    while True:
        user_query = input("Your Question: ").strip()
        if user_query.lower() in ['exit', 'quit']:
            print("Exiting RAG System. Goodbye!")
            break
        if not user_query:
            continue  # Nothing to search for

        # Step 8a: Generate embedding for the query (query_faiss normalizes it)
        query_embedding = embedding_model.encode([user_query], convert_to_numpy=True).astype('float32')

        # Step 8b: Retrieve top_k similar papers
        top_indices = query_faiss(index, query_embedding, top_k=TOP_K)
        # Skip placeholder (-1) and out-of-range positions so a negative index
        # can never wrap around to the end of the list.
        retrieved_papers = [
            metadata_mapping[int(idx)]
            for idx in top_indices
            if 0 <= idx < len(metadata_mapping)
        ]
        if not retrieved_papers:
            print("\nNo relevant papers were found for this question.\n")
            continue

        # Step 8c: Prepare context for generation
        # Without a budget the prompt exceeds MAX_INPUT_LENGTH and tokenizer
        # truncation cuts off its tail, i.e. the question itself.
        context = _build_context(tokenizer, retrieved_papers, max_context_length)

        # Step 8d: Create prompt for generation
        prompt = (
            f"Please provide a concise and informative answer to the question based on the context below.\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {user_query}\n"
            f"Answer:"
        )

        # Print the prompt for debugging
        print(f'Prompt:\n{prompt}')

        # Step 8e: Generate answer
        answer = generate_answer(tokenizer, model, device, prompt)
        print(f"\nAnswer:\n{answer}\n")

if __name__ == "__main__":
    main()

### Second Approach

In [None]:
# Install required libraries (uncomment if needed)
# !pip install transformers datasets faiss-cpu sentence-transformers torch tqdm pandas numpy

import json
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import logging
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, Trainer, TrainingArguments
from datasets import load_dataset, load_from_disk, Dataset
import torch
from pathlib import Path  # Added import for Path

In [2]:
# ==============================
# Configuration Variables
# ==============================

INPUT_JSON = 'arxiv-metadata-oai-snapshot.json'  # Path to your arXiv metadata JSON
CATEGORIES = ['cs.AI', 'cs.LG', 'cs.CV', 'cs.CL', 'stat.ML']  # Categories of interest
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'  # Embedding model
FAISS_INDEX_PATH = 'faiss_index.index'  # File the FAISS index is written to and read from
METADATA_MAPPING_PATH = 'metadata_mapping.json'  # File the metadata records are written to and read from
EMBEDDINGS_PATH = 'embeddings.npy'  # Cached embeddings; main() reuses the file when it exists
LOG_FILE = 'knowledge_base_preparation.log'  # Log file handed to logging.basicConfig below
QA_DATASET_NAME = 'qasper'  # Hugging Face dataset name

# Directory the fine-tuned model, tokenizer and retriever are saved to and loaded from
FINE_TUNED_MODEL_PATH = './rag-finetuned-qasper'

# ==============================
# Setup Logging
# ==============================

# NOTE(review): if an earlier cell already configured the root logger (the first
# approach also calls logging.basicConfig), this call presumably has no effect
# and records keep going to the earlier file — confirm, and consider force=True.
logging.basicConfig(
    filename=LOG_FILE,
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)


In [3]:

# ==============================
# Function Definitions
# ==============================

def load_and_filter_arxiv(json_path, categories):
    """
    Loads arXiv metadata from a JSON-lines file and keeps papers in the given categories.

    Every line is parsed on its own. A line is skipped when it is not valid JSON,
    is not a JSON object, or shares no category with ``categories``. Missing or
    null text fields are treated as empty strings.

    Parameters:
        json_path (str): Path to the arXiv JSON dataset (one JSON object per line).
        categories (list): Categories to keep.

    Returns:
        pd.DataFrame: One row per kept paper. An empty DataFrame is returned when
        the file cannot be read (the error is logged and printed).
    """
    # Build the lookup set once instead of once per input line.
    wanted = set(categories)

    def text_field(paper, key):
        # Guards against JSON null / non-string values, which would otherwise
        # raise AttributeError and abort the whole load via the outer handler.
        value = paper.get(key)
        return value if isinstance(value, str) else ''

    papers = []
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading and Filtering Papers"):
                try:
                    paper = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(paper, dict):
                    continue
                paper_categories = text_field(paper, 'categories').split()
                if wanted.isdisjoint(paper_categories):
                    continue

                raw_authors = paper.get('authors')
                if isinstance(raw_authors, (list, tuple)):
                    authors = ', '.join(str(name) for name in raw_authors)
                else:
                    # ', '.join() on a plain string joins its characters
                    # ("A, n, n, ..."), so a string is kept as-is.
                    authors = text_field(paper, 'authors').strip()

                papers.append({
                    'arxiv_id': paper.get('id', ''),
                    'title': text_field(paper, 'title').replace('\n', ' ').strip(),
                    'authors': authors,
                    'abstract': text_field(paper, 'abstract').replace('\n', ' ').strip(),
                    'published_date': paper.get('published', ''),
                    'categories': ', '.join(paper_categories),
                    'arxiv_url': paper.get('link', ''),
                    'journal_ref': paper.get('journal-ref', ''),
                    'comment': paper.get('comment', ''),
                    'doi': paper.get('doi', '')
                })
        df = pd.DataFrame(papers)
        logging.info(f"Loaded and filtered metadata for {len(df)} papers.")
        print(f"Loaded and filtered metadata for {len(df)} papers.")
        return df
    except Exception as e:
        logging.error(f"Error loading metadata: {e}")
        print(f"Error loading metadata: {e}")
        return pd.DataFrame()

def preprocess_documents(df):
    """
    Builds one text passage per DataFrame row from its title, abstract and categories.

    Parameters:
        df (pd.DataFrame): Frame with 'title', 'abstract' and 'categories' columns.

    Returns:
        list of str: "Title: ...\\nAbstract: ...\\nCategories: ..." strings in row
        order, with missing values rendered as empty strings.
    """
    def or_blank(value):
        # Missing values (None/NaN) become an empty string.
        return value if pd.notna(value) else ''

    return [
        f"Title: {or_blank(record['title'])}\n"
        f"Abstract: {or_blank(record['abstract'])}\n"
        f"Categories: {or_blank(record['categories'])}"
        for _, record in df.iterrows()
    ]

def generate_embeddings(texts, model_name):
    """Load the named SentenceTransformer and return its encodings of ``texts`` as a NumPy array."""
    encoder = SentenceTransformer(model_name)
    return encoder.encode(texts, show_progress_bar=True, convert_to_numpy=True)

def build_faiss_index(embeddings, index_path):
    """
    Build an inner-product FAISS index over ``embeddings`` and write it to ``index_path``.

    The array passed in is handed to ``faiss.normalize_L2`` directly, so the
    caller's array is the one that gets normalized. Failures are logged and
    printed rather than raised.
    """
    try:
        faiss.normalize_L2(embeddings)
        dimension = embeddings.shape[1]
        flat_index = faiss.IndexFlatIP(dimension)
        flat_index.add(embeddings)
        faiss.write_index(flat_index, index_path)
        message = f"FAISS index built and saved to {index_path}"
        logging.info(message)
        print(message)
    except Exception as e:
        failure = f"Error building FAISS index: {e}"
        logging.error(failure)
        print(failure)

def save_metadata_mapping(df, mapping_path):
    """
    Write the DataFrame rows to ``mapping_path`` as a JSON list of record dictionaries.

    The file is UTF-8 encoded with non-ASCII characters kept as-is and a
    four-space indent. Failures are logged and printed rather than raised.
    """
    try:
        records = df.to_dict(orient='records')
        with open(mapping_path, 'w', encoding='utf-8') as out_file:
            json.dump(records, out_file, ensure_ascii=False, indent=4)
        message = f"Metadata mapping saved to {mapping_path}"
        logging.info(message)
        print(message)
    except Exception as e:
        failure = f"Error saving metadata mapping: {e}"
        logging.error(failure)
        print(failure)

def load_faiss_index(index_path):
    """Read a FAISS index from ``index_path``; return None (after logging and printing) on failure."""
    try:
        loaded_index = faiss.read_index(index_path)
        message = f"FAISS index loaded from {index_path}"
        print(message)
        logging.info(message)
    except Exception as e:
        failure = f"Error loading FAISS index: {e}"
        logging.error(failure)
        print(failure)
        return None
    return loaded_index

def load_metadata_mapping(mapping_path):
    """Parse the JSON file at ``mapping_path``; return [] (after logging and printing) on failure."""
    try:
        with open(mapping_path, 'r', encoding='utf-8') as source:
            records = json.load(source)
        message = f"Metadata mapping loaded from {mapping_path}"
        print(message)
        logging.info(message)
    except Exception as e:
        failure = f"Error loading metadata mapping: {e}"
        logging.error(failure)
        print(failure)
        return []
    return records

def load_qa_dataset(dataset_name):
    """
    Load a dataset by name via ``load_dataset`` and report the size of its 'train' split.

    Returns the loaded dataset, or None when loading (or reading the 'train'
    split) fails; the error is logged and printed.
    """
    try:
        qa_data = load_dataset(dataset_name)
        summary = f"Loaded QA dataset '{dataset_name}' with {len(qa_data['train'])} samples."
        print(summary)
        logging.info(summary)
        return qa_data
    except Exception as e:
        failure = f"Error loading QA dataset '{dataset_name}': {e}"
        logging.error(failure)
        print(failure)
        return None

In [4]:
def fine_tune_rag(qa_dataset, faiss_index_path, metadata_mapping_path):
    """
    Fine-tunes the "facebook/rag-sequence-nq" checkpoint on a QA dataset and saves the result.

    Parameters:
        qa_dataset: Mapping of split name to dataset; 'train' is required and
            'validation' is used for evaluation when present (otherwise 'train').
        faiss_index_path (str): FAISS index file assigned to the retriever.
        metadata_mapping_path (str): JSON file whose content is passed to the retriever.

    Returns:
        None. Any exception is logged and printed, not raised, so callers cannot
        tell from the return value whether training succeeded.
    """
    try:
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq",
            index_name="custom",
            passages_path="passages_dataset",  # Directory written by main() via save_to_disk
            use_dummy_dataset=False
        )
        # NOTE(review): the retriever built above is never passed to the model or
        # to the Trainer in this function — confirm whether retrieval is actually
        # wired in during training.
        model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

        # Load FAISS index and set passages
        # NOTE(review): could not confirm that RagRetriever supports assigning
        # `.index` to a raw FAISS index or exposes `set_passages` — verify against
        # the transformers documentation. If either fails, the exception lands in
        # the handler below and the function only logs it.
        retriever.index = faiss.read_index(faiss_index_path)
        with open(metadata_mapping_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        retriever.set_passages(metadata)

        training_args = TrainingArguments(
            output_dir=FINE_TUNED_MODEL_PATH,
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            num_train_epochs=3,
            learning_rate=5e-5,
            evaluation_strategy="steps",
            save_steps=500,
            save_total_limit=2,
            logging_steps=100,
            fp16=True,  # NOTE(review): presumably needs a CUDA device — confirm for MPS/CPU runs
            dataloader_num_workers=4,
        )

        # The dataset splits are handed to the Trainer as-is; no tokenization or
        # data collator is set up in this function.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=qa_dataset['train'],
            eval_dataset=qa_dataset['validation'] if 'validation' in qa_dataset else qa_dataset['train'],
        )

        trainer.train()
        # Model, tokenizer and retriever all go to the same output directory.
        trainer.save_model(FINE_TUNED_MODEL_PATH)
        tokenizer.save_pretrained(FINE_TUNED_MODEL_PATH)
        retriever.save_pretrained(FINE_TUNED_MODEL_PATH)
        logging.info(f"Fine-tuned RAG model saved to {FINE_TUNED_MODEL_PATH}")
        print(f"Fine-tuned RAG model saved to {FINE_TUNED_MODEL_PATH}")
    except Exception as e:
        logging.error(f"Error during fine-tuning: {e}")
        print(f"Error during fine-tuning: {e}")


In [5]:
def load_finetuned_rag(model_path):
    """
    Loads the tokenizer, retriever and model saved under ``model_path``.

    The FAISS index and metadata are read from the module-level FAISS_INDEX_PATH
    and METADATA_MAPPING_PATH, not from ``model_path``.

    Parameters:
        model_path (str): Directory holding the saved tokenizer, retriever and model.

    Returns:
        tuple: (tokenizer, retriever, model), or (None, None, None) when anything
        fails (the error is logged and printed).
    """
    try:
        tokenizer = RagTokenizer.from_pretrained(model_path)
        retriever = RagRetriever.from_pretrained(
            model_path,
            index_name="custom",
            passages_path="passages_dataset",  # Directory written by main() via save_to_disk
            use_dummy_dataset=False
        )
        # NOTE(review): the model is loaded without being given the retriever —
        # confirm that generation can retrieve passages in this configuration.
        model = RagSequenceForGeneration.from_pretrained(model_path)

        # Load FAISS index and set passages
        # NOTE(review): could not confirm that RagRetriever supports assigning
        # `.index` or exposes `set_passages` — verify against the transformers
        # documentation; a failure here is swallowed by the handler below.
        retriever.index = faiss.read_index(FAISS_INDEX_PATH)
        with open(METADATA_MAPPING_PATH, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        retriever.set_passages(metadata)

        return tokenizer, retriever, model
    except Exception as e:
        logging.error(f"Error loading fine-tuned RAG model: {e}")
        print(f"Error loading fine-tuned RAG model: {e}")
        return None, None, None

In [6]:
def answer_query(tokenizer, retriever, model, query, device, max_new_tokens=200):
    """
    Generates a sampled answer for a query and returns the first decoded sequence.

    Parameters:
        tokenizer: Callable tokenizer that also provides ``batch_decode`` and ``eos_token_id``.
        retriever: Accepted for interface compatibility; not used in this function.
        model: Model exposing ``generate``.
        query (str): The user's question.
        device: Device the encoded inputs are moved to.
        max_new_tokens (int): Upper bound on newly generated tokens.

    Returns:
        str: Decoded answer, or a fixed apology string when generation fails
        (the error is logged and printed).
    """
    sampling_options = {
        "max_new_tokens": max_new_tokens,
        "num_return_sequences": 1,
        "no_repeat_ngram_size": 2,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
    }
    try:
        encoded = tokenizer(query, return_tensors="pt").to(device)
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            pad_token_id=tokenizer.eos_token_id,
            **sampling_options
        )
        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as e:
        failure = f"Error generating answer: {e}"
        logging.error(failure)
        print(failure)
        return "I'm sorry, I couldn't generate an answer for your query."


In [None]:
def interactive_querying(tokenizer, retriever, model, device):
    """
    Reads questions from stdin and prints an answer for each until the user quits.

    Typing 'exit' or 'quit' (in any letter case) ends the loop. All arguments are
    forwarded unchanged to ``answer_query``.
    """
    print("\nRAG System is ready. You can now ask questions related to AI papers.")
    print("Type 'exit' to quit.\n")

    stop_words = ('exit', 'quit')
    while True:
        question = input("Your Question: ")
        if question.lower() not in stop_words:
            reply = answer_query(tokenizer, retriever, model, question, device)
            print(f"\nAnswer:\n{reply}\n")
            continue
        print("Exiting RAG System. Goodbye!")
        return

def main():
    """
    Runs the second approach end to end.

    Steps: build the knowledge base (metadata, embeddings, FAISS index, metadata
    mapping, passages dataset), load the QA dataset, fine-tune the RAG model,
    reload it, move it to the best available device and start the question loop.
    Returns early with a message when a step yields nothing usable.
    """
    # Step 1: Prepare Knowledge Base
    df = load_and_filter_arxiv(INPUT_JSON, CATEGORIES)
    if df.empty:
        print("No data loaded. Exiting.")
        return
    combined_texts = preprocess_documents(df)

    # Reuse cached embeddings when the file is present; otherwise compute and cache them.
    if not Path(EMBEDDINGS_PATH).exists():
        embeddings = generate_embeddings(combined_texts, EMBEDDING_MODEL_NAME)
        np.save(EMBEDDINGS_PATH, embeddings)
        print(f"Embeddings saved to {EMBEDDINGS_PATH}")
    else:
        print(f"Embeddings file '{EMBEDDINGS_PATH}' found. Loading embeddings...")
        embeddings = np.load(EMBEDDINGS_PATH)

    build_faiss_index(embeddings, FAISS_INDEX_PATH)
    save_metadata_mapping(df, METADATA_MAPPING_PATH)

    # Create and save passages dataset
    print("Creating passages dataset...")
    passage_columns = {
        'title': df['title'].tolist(),
        'abstract': df['abstract'].tolist(),
        'categories': df['categories'].tolist(),
        'text': combined_texts
    }
    passages_dataset = Dataset.from_dict(passage_columns)
    passages_dataset.save_to_disk('passages_dataset')
    print("Passages dataset saved to 'passages_dataset' directory.")

    # Step 2: Load QA Dataset
    qa_dataset = load_qa_dataset(QA_DATASET_NAME)
    if qa_dataset is None:
        print("Failed to load QA dataset. Exiting.")
        return

    # Step 3: Fine-Tune RAG Model
    fine_tune_rag(qa_dataset, FAISS_INDEX_PATH, METADATA_MAPPING_PATH)

    # Step 4: Inference
    tokenizer, retriever, model = load_finetuned_rag(FINE_TUNED_MODEL_PATH)
    if any(part is None for part in (tokenizer, retriever, model)):
        print("Failed to load the fine-tuned RAG model. Exiting.")
        return

    # Device preference: MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device_name = "mps"
        status = "Using macOS GPU (MPS) for inference."
    elif torch.cuda.is_available():
        device_name = "cuda"
        status = "Using CUDA GPU for inference."
    else:
        device_name = "cpu"
        status = "Using CPU for inference."
    device = torch.device(device_name)
    model.to(device)
    print(status)

    interactive_querying(tokenizer, retriever, model, device)

if __name__ == "__main__":
    main()

In [None]:
# Install required libraries (uncomment if needed)
# !pip install transformers datasets peft accelerate torch tqdm pandas numpy

import pandas as pd
import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from peft import get_peft_model, LoraConfig, TaskType
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
import torch
import numpy as np

# ==============================
# Configuration Variables
# ==============================

# Path to the arXiv metadata file that load_metadata() reads line by line
# (one JSON object per line); category filtering happens inside load_metadata().
INPUT_JSON = 'arxiv-metadata-oai-snapshot.json'  # Replace with your actual file path

# Path to save the fine-tuned model
# NOTE(review): the name mentions GPT-Neo, but GENERATION_MODEL_NAME below is a
# flan-t5 checkpoint — confirm the intended directory name.
OUTPUT_DIR = 'fine_tuned_gpt_neo'

# Number of training epochs
NUM_EPOCHS = 3

# Learning rate
LEARNING_RATE = 5e-5

# Number of top similar documents to retrieve initially
INITIAL_TOP_K = 50

# Number of documents after re-ranking to include in context
TOP_N = 5

# Batch size per device during training
BATCH_SIZE = 1  # Adjust based on memory availability

# Embedding model name
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

# Generation model name (using an instruction-tuned model)
GENERATION_MODEL_NAME = 'google/flan-t5-small'  # Change to 'google/flan-t5-base' if feasible

# Maximum input length for the generation model
MAX_INPUT_LENGTH = 512

# Reserve tokens for the answer and prompt
RESERVED_TOKENS = 100

# Categories to filter arXiv papers
CATEGORIES = ['cs.AI', 'cs.LG', 'cs.CV', 'cs.CL', 'stat.ML']

# ==============================
# Function Definitions
# ==============================

def load_metadata(json_path, categories):
    """
    Loads and filters the arXiv metadata from a JSON-lines file based on specified categories.

    Every line is parsed on its own. A line is skipped when it is not valid JSON,
    is not a JSON object, or shares no category with ``categories``. Missing or
    null text fields are treated as empty strings.

    Parameters:
        json_path (str): Path to the arXiv JSON dataset (one JSON object per line).
        categories (list): List of categories to filter by.

    Returns:
        pd.DataFrame: DataFrame with 'arxiv_id', 'title', 'authors' and 'abstract'
        columns. An empty DataFrame is returned when the file cannot be read
        (the error is printed).
    """
    # Build the lookup set once instead of once per input line.
    wanted = set(categories)

    def text_field(paper, key):
        # Guards against JSON null / non-string values, which would otherwise
        # raise AttributeError and abort the whole load via the outer handler.
        value = paper.get(key)
        return value if isinstance(value, str) else ''

    papers = []
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading and Filtering Papers"):
                try:
                    paper = json.loads(line)
                except json.JSONDecodeError:
                    continue  # Skip malformed lines
                if not isinstance(paper, dict):
                    continue  # Valid JSON but not a record

                paper_categories = text_field(paper, 'categories').split()
                if wanted.isdisjoint(paper_categories):
                    continue  # Skip papers not in the desired categories

                raw_authors = paper.get('authors')
                if isinstance(raw_authors, (list, tuple)):
                    authors = ', '.join(str(name) for name in raw_authors)
                else:
                    # ', '.join() on a plain string joins its characters
                    # ("A, n, n, ..."), so a string is kept as-is.
                    authors = text_field(paper, 'authors').strip()

                papers.append({
                    'arxiv_id': paper.get('id', ''),
                    'title': text_field(paper, 'title').replace('\n', ' ').strip(),
                    'authors': authors,
                    'abstract': text_field(paper, 'abstract').replace('\n', ' ').strip()
                })
        df = pd.DataFrame(papers)
        print(f"Loaded and filtered metadata for {len(df)} papers.")
        return df
    except Exception as e:
        print(f"Error loading metadata: {e}")
        return pd.DataFrame()

def preprocess_text(row):
    """
    Builds the single text string used for a paper from its title and abstract.

    Missing values (as judged by ``pd.notna``) are rendered as empty strings.

    Parameters:
        row (pd.Series): A row from the dataframe; must provide 'title' and 'abstract'.

    Returns:
        str: "Title: <title>" and "Abstract: <abstract>" joined by a newline.
    """
    sections = []
    for label, column in (("Title", 'title'), ("Abstract", 'abstract')):
        value = row[column]
        if not pd.notna(value):
            value = ''
        sections.append(f"{label}: {value}")
    return "\n".join(sections)

def generate_embeddings(texts, model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """
    Encodes a list of texts with a freshly loaded SentenceTransformer model.

    Parameters:
        texts (list of str): Text strings to embed.
        model_name (str): Name of the SentenceTransformer model to load.

    Returns:
        np.ndarray: The embeddings, as returned by ``encode`` with
        ``convert_to_numpy=True``.
    """
    encoder = SentenceTransformer(model_name)
    return encoder.encode(
        texts,
        show_progress_bar=True,
        convert_to_numpy=True,
    )

def load_model_with_lora(model_name, device):
    """
    Loads a pretrained seq2seq model and its tokenizer, then wraps the model with LoRA adapters.

    Note that ``device`` is only consulted to pick the weight dtype; placement
    of the weights is delegated to ``device_map="auto"``.

    Parameters:
        model_name (str): Hugging Face model name.
        device (torch.device): Target device; float16 weights are used when its
            type is 'mps', float32 otherwise.

    Returns:
        tuple: (tokenizer, model) where model is the LoRA-wrapped model.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Half precision on MPS, full precision everywhere else.
    if device.type == 'mps':
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=weights_dtype,
    )

    # Wrap the base model with trainable LoRA adapters (training mode).
    peft_model = get_peft_model(
        base_model,
        LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
        ),
    )
    # Report how many parameters remain trainable after wrapping.
    peft_model.print_trainable_parameters()

    return tokenizer, peft_model

def get_training_arguments(output_dir, per_device_train_batch_size=1, num_train_epochs=3):
    """
    Defines the training arguments for the Trainer.

    The learning rate is read from the module-level ``LEARNING_RATE`` constant.
    Checkpoints are saved once per epoch (keeping at most two) and evaluation
    is disabled.

    Parameters:
        output_dir (str): Directory to save the fine-tuned model.
        per_device_train_batch_size (int): Batch size per device.
        num_train_epochs (int): Number of training epochs.

    Returns:
        TrainingArguments: Hugging Face TrainingArguments instance.
    """
    # fp16 mixed precision is only enabled on CUDA. Requesting it on an
    # MPS-only machine (Apple Silicon) makes TrainingArguments raise, so it
    # must not be switched on just because MPS is available.
    use_fp16 = torch.cuda.is_available()

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=8,  # Adjust based on GPU memory
        evaluation_strategy="no",
        save_strategy="epoch",
        logging_steps=100,
        fp16=use_fp16,
        save_total_limit=2,
        dataloader_num_workers=4,
        optim="adamw_torch",
        learning_rate=LEARNING_RATE,
        weight_decay=0.01,
        report_to="none",  # Disable logging to external systems
    )
    return training_args

def prepare_dataset(df, tokenizer, max_length=512):
    """
    Prepares the dataset for training by tokenizing the text.

    Each row is turned into one string with ``preprocess_text`` (title plus
    abstract) and tokenized with truncation and padding to ``max_length``.
    The input DataFrame is left unmodified.

    Parameters:
        df (pd.DataFrame): DataFrame containing the dataset; rows must provide
            the fields ``preprocess_text`` reads.
        tokenizer: Tokenizer instance.
        max_length (int): Maximum sequence length.

    Returns:
        Dataset: Hugging Face Dataset object with the 'text' column plus the
        tokenizer outputs.
    """
    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )

    # Combine title and abstract. `assign` works on a copy, so the caller's
    # DataFrame does not silently gain a 'text' column.
    text_frame = df.assign(text=df.apply(preprocess_text, axis=1))[['text']]
    dataset = Dataset.from_pandas(text_frame)
    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    return tokenized_dataset

def get_trainer(model, tokenizer, training_args, train_dataset):
    """
    Initializes the Hugging Face Trainer.

    The tokenizer is used for the data collator and is also registered on the
    Trainer itself, so that it is available from the trainer afterwards (for
    example when the fine-tuned model is saved).

    Parameters:
        model: The model to train.
        tokenizer: The tokenizer.
        training_args: TrainingArguments instance.
        train_dataset: The training dataset.

    Returns:
        Trainer: Hugging Face Trainer instance.
    """
    import inspect  # local import: only needed to probe the Trainer signature

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Newer transformers releases take the tokenizer as `processing_class`;
    # older ones take `tokenizer`. Use whichever this installation accepts.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    if 'processing_class' in trainer_params:
        tokenizer_kwarg = 'processing_class'
    else:
        tokenizer_kwarg = 'tokenizer'

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        **{tokenizer_kwarg: tokenizer},
    )
    return trainer

def save_fine_tuned_model(trainer, output_dir):
    """
    Saves the fine-tuned model and, when the trainer holds one, its tokenizer.

    The tokenizer is looked up on the trainer as ``processing_class`` first and
    ``tokenizer`` second. If neither is set, the model is still saved and a
    warning is printed instead of failing after training has finished.

    Parameters:
        trainer: Hugging Face Trainer instance.
        output_dir (str): Directory to save the model.
    """
    trainer.save_model(output_dir)

    tokenizer = getattr(trainer, 'processing_class', None)
    if tokenizer is None:
        tokenizer = getattr(trainer, 'tokenizer', None)

    if tokenizer is not None:
        tokenizer.save_pretrained(output_dir)
    else:
        print("Warning: trainer has no tokenizer attached; tokenizer was not saved.")

    print(f"Fine-tuned model saved to {output_dir}")

# ==============================
# Main Execution
# ==============================

def main():
    """
    Runs the fine-tuning pipeline end to end.

    Steps: pick a device, load and filter the metadata, build (or load cached)
    embeddings, load the generation model with LoRA, tokenize the dataset,
    train, and save the result to ``OUTPUT_DIR``. Returns early when no papers
    were loaded.
    """
    # Detect device
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MacOS GPU (MPS).")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA GPU.")
    else:
        device = torch.device("cpu")
        print("Using CPU.")

    # Step 1: Load and filter the metadata
    df = load_metadata(INPUT_JSON, CATEGORIES)
    if df.empty:
        print("No data to process. Exiting.")
        return

    # Step 2: Preprocess text for embeddings
    print("Preprocessing text for embeddings...")
    df['combined_text'] = df.apply(preprocess_text, axis=1)
    combined_texts = df['combined_text'].tolist()

    # Step 3: Generate embeddings
    # The embedding model is loaded inside generate_embeddings(), and only
    # when there is no cached file, so it is not loaded here as well.
    print("Generating embeddings...")
    embeddings_file = 'embeddings.npy'
    # Check if embeddings file exists
    if Path(embeddings_file).exists():
        print(f"Embeddings file 'embeddings.npy' found. Loading embeddings...")
        embeddings = np.load(embeddings_file)
        # A cached file may come from a different run or filter; warn when it
        # does not line up with the papers that were just loaded.
        if embeddings.shape[:1] != (len(combined_texts),):
            print(
                f"Warning: cached embeddings have shape {embeddings.shape} but "
                f"{len(combined_texts)} papers were loaded; the cache may be stale."
            )
    else:
        embeddings = generate_embeddings(combined_texts, model_name=EMBEDDING_MODEL_NAME)
        np.save(embeddings_file, embeddings)
        print(f"Embeddings saved to 'embeddings.npy'")

    # Step 4: Build FAISS index (optional for re-ranking)
    # If you plan to use FAISS for initial retrieval, implement it here.
    # For fine-tuning, this might not be necessary.

    # Step 5: Load the generation model with LoRA
    print("Loading generation model with LoRA...")
    tokenizer, model = load_model_with_lora(GENERATION_MODEL_NAME, device)

    # Step 6: Prepare the dataset for training
    print("Preparing the dataset for training...")
    tokenized_dataset = prepare_dataset(df, tokenizer, max_length=MAX_INPUT_LENGTH)

    # Step 7: Define training arguments
    print("Setting up training arguments...")
    training_args = get_training_arguments(OUTPUT_DIR, per_device_train_batch_size=BATCH_SIZE, num_train_epochs=NUM_EPOCHS)

    # Step 8: Initialize the Trainer
    print("Initializing the Trainer...")
    trainer = get_trainer(model, tokenizer, training_args, tokenized_dataset)

    # Step 9: Fine-Tune the Model
    print("Starting fine-tuning...")
    trainer.train()

    # Step 10: Save the Fine-Tuned Model
    print("Saving the fine-tuned model...")
    save_fine_tuned_model(trainer, OUTPUT_DIR)

    print("Fine-tuning complete.")

if __name__ == "__main__":
    main()