In [1]:
import pandas as pd
papers = pd.read_csv('C:/Users/Keerthana/Documents/machine_learning/ML_Project/Project_Code/data/arxiv_papers_small.csv')

In [None]:
import faiss
import numpy as np
import pandas as pd
import spacy
from transformers import BertTokenizer, BertModel, T5Tokenizer, T5ForConditionalGeneration
import matplotlib.pyplot as plt

# Load spaCy model
nlp = spacy.load("en_core_web_sm")

# Preprocessing function
def preprocess_text(text):
    doc = nlp(text.lower())
    return " ".join([token.text for token in doc if not token.is_stop and not token.is_punct])

# Assuming papers DataFrame is already loaded with columns id, title, and abstract
papers['abstract_processed'] = papers['abstract'].apply(preprocess_text)

# BERT Model Setup
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# T5 Model Setup
t5_tokenizer = T5Tokenizer.from_pretrained('t5-small')
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')

# Get embeddings for all papers
def get_paper_embeddings(texts):
    embeddings = []
    for text in texts:
        inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        outputs = bert_model(**inputs)
        embeddings.append(outputs.last_hidden_state[:, 0, :].detach().numpy())
    return np.vstack(embeddings)

# Build FAISS Index
def build_faiss_index(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index

# Generate query-specific summaries using full RAG
def generate_with_rag(query, retrieved_contexts):
    combined_input = f"query: {query} context: {' '.join(retrieved_contexts)}"
    inputs = t5_tokenizer.encode(combined_input, return_tensors="pt", max_length=512, truncation=True)
    outputs = t5_model.generate(inputs, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
    return t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Summarize a single paper
def summarize_paper(abstract):
    inputs = t5_tokenizer.encode(f"summarize: {abstract}", return_tensors="pt", truncation=True, max_length=512)
    outputs = t5_model.generate(inputs, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
    return t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

def normalize_distances(distances):
    min_distance = distances.min()
    max_distance = distances.max()
    return (distances - min_distance) / (max_distance - min_distance)

# Retrieve papers using FAISS
def retrieve_papers(query, faiss_index, paper_texts, k=5):
    query_embedding = get_paper_embeddings([query])  # Generate query embedding
    distances, indices = faiss_index.search(query_embedding, k)  # Search FAISS index

    # Normalize distances
    distances_normalized = normalize_distances(distances[0])
    
    retrieved_papers = []
    for i, index in enumerate(indices[0]):
        if index < len(papers):  # Ensure valid index
            distance_normalized = distances_normalized[i]
            relevance_level = (
                "Highly Relevant" if distance_normalized < 0.33 else
                "Moderately Relevant" if distance_normalized < 0.66 else
                "Not Relevant"
            )
            # Summarize each retrieved paper
            paper_summary = summarize_paper(papers.iloc[index]["abstract"])
            retrieved_papers.append({
                "id": papers.iloc[index]["id"],
                "title": papers.iloc[index]["title"],
                "abstract": papers.iloc[index]["abstract"],
                "summary": paper_summary,
                "distance": distance_normalized,
                "relevance": relevance_level
            })
    return retrieved_papers

# Visualization: Relevance Distribution
def plot_relevance_distribution(retrieved_papers):
    relevance_counts = {
        "Highly Relevant": sum(1 for p in retrieved_papers if p['relevance'] == "Highly Relevant"),
        "Moderately Relevant": sum(1 for p in retrieved_papers if p['relevance'] == "Moderately Relevant"),
        "Not Relevant": sum(1 for p in retrieved_papers if p['relevance'] == "Not Relevant"),
    }
    plt.bar(relevance_counts.keys(), relevance_counts.values())
    plt.title('Relevance Distribution')
    plt.xlabel('Relevance Level')
    plt.ylabel('Number of Papers')
    plt.show()

# Main workflow
if __name__ == "__main__":
    # Generate paper embeddings
    paper_embeddings = get_paper_embeddings(papers['abstract_processed'])

    # Build FAISS index
    faiss_index = build_faiss_index(paper_embeddings)

    # Example query
    query = "Sequential plan recognition in AI"
    
    # Retrieve top 5 papers
    retrieved_papers = retrieve_papers(query, faiss_index, papers['abstract_processed'], k=5)

    # Combine retrieved contexts
    retrieved_contexts = [paper['abstract'] for paper in retrieved_papers]

    # Generate summary using RAG
    final_summary = generate_with_rag(query, retrieved_contexts)
    
    # Display retrieved papers with summaries
    print("Retrieved Papers:")
    for paper in retrieved_papers:
        print(f"Paper ID: {paper['id']}")
        print(f"Title: {paper['title']}")
        print(f"Abstract: {paper['abstract']}")
        print(f"Summary: {paper['summary']}")
        print(f"Relevance: {paper['relevance']}")
        print(f"Normalized Distance: {paper['distance']}")
        print("\n")
    
    # Display final query-aware summary
    print("Final RAG-Based Summary:")
    print(final_summary)
    
    # Plot relevance distribution
    plot_relevance_distribution(retrieved_papers)


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
