<a href="https://colab.research.google.com/github/apoorvapu/data_science/blob/main/RAG-literatureReview.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Generative AI in Research: using Large Language Models (LLMs) to enhance and streamline the academic literature review process.

leverage RAG Techniques for summarizing papers, identifying connections across papers (authors, references, methods), uncovering key themes in them.

Download 2 papers (related to diffusion model) and convert them to .txt files in a directory named "data". Use these .txt files as input papers and evaluate if the RAG technique is giving good results.

In [2]:
!pip install pymupdf requests

Collecting pymupdf
  Downloading pymupdf-1.25.5-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (3.4 kB)
Downloading pymupdf-1.25.5-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (20.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.0/20.0 MB[0m [31m74.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pymupdf
Successfully installed pymupdf-1.25.5


### download any 2 papers of diffusion model and convert their pdf to .txt files

In [3]:
!rm -r data

rm: cannot remove 'data': No such file or directory


In [4]:
import os
import requests
import fitz  # PyMuPDF

# Sample arXiv paper IDs related to diffusion models
arxiv_ids = [
    "2006.11239",  # Denoising Diffusion Probabilistic Models
    "2105.05233",  # Improved Denoising Diffusion Probabilistic Models
]

def download_pdf(arxiv_id, output_folder):
    url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
    pdf_path = os.path.join(output_folder, f"{arxiv_id}.pdf")
    response = requests.get(url)
    with open(pdf_path, "wb") as f:
        f.write(response.content)
    print(f"Downloaded {arxiv_id}")
    return pdf_path

def pdf_to_text(pdf_path, txt_path):
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(text)
    print(f"Converted to text: {txt_path}")

def main():
    data_dir = "data"
    os.makedirs(data_dir, exist_ok=True)

    for arxiv_id in arxiv_ids:
        pdf_path = download_pdf(arxiv_id, data_dir)
        txt_path = os.path.join(data_dir, f"{arxiv_id}.txt")
        pdf_to_text(pdf_path, txt_path)

if __name__ == "__main__":
    main()


Downloaded 2006.11239
Converted to text: data/2006.11239.txt
Downloaded 2105.05233
Converted to text: data/2105.05233.txt


In [5]:
!rm data/*.pdf

In [6]:
# Install dependencies
!pip install langchain langchain_community faiss-cpu sentence-transformers transformers networkx matplotlib spacy
!python -m spacy download en_core_web_sm

Collecting langchain_community
  Downloading langchain_community-0.3.22-py3-none-any.whl.metadata (2.4 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Collecting langchain-core<1.0.0,>=0.3.51 (from langchain)
  Downloading langchain_core-0.3.55-py3-none-any.whl.metadata (5.9 kB)
Collecting langchain
  Downloading langchain-0.3.24-py3-none-any.whl.metadata (7.8 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain_community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain_community)
  Downloading pydantic_settings-2.9.1-py3-none-any.whl.metadata (3.8 kB)
Collecting httpx-sse<1.0.0,>=0.4.0 (from langchain_community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB

In [42]:
import os, glob
import gc
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import networkx as nx
import matplotlib.pyplot as plt
import spacy
import re
gc.collect()
import torch
torch.cuda.empty_cache()

def clean_text(text):
    import re
    # Remove inline citations like [14], [14, 27]
    text = re.sub(r"\[[0-9,\s]+\]", "", text)
    # Remove URLs
    text = re.sub(r"http\S+|www\.\S+", "", text)
    # Remove LaTeX math expressions
    text = re.sub(r"\$.*?\$", "", text)
    # Remove repeated words
    text = re.sub(r"\b(\w+)( \1\b)+", r"\1", text)
    # Remove special characters
    text = re.sub(r"[^a-zA-Z0-9.,;:?!\s]", "", text)
    # Remove excessive whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text


# Load a document and return its content
def load_document(file_path):
    loader = TextLoader(file_path)
    docs = loader.load()
    for doc in docs:
        doc.page_content = clean_text(doc.page_content)  # Apply cleaning here
    return docs

# Split the document into chunks ensuring each chunk is under the token limit
def split_document(docs, chunk_size=1500, chunk_overlap=100):
    # Make sure docs is always a list
    if not isinstance(docs, list):
        docs = [docs]

    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(docs)

# Vector store (use sentence embeddings)
def create_faiss_index(docs):
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    return FAISS.from_documents(docs, embeddings)

# Load HuggingFace LLM
def load_llm():
    model_id = "google/flan-t5-xl"  # Better GPU utilization, faster than flan-t5-large
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to("cuda")  # Move model to GPU
    #pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0, max_new_tokens=1024)

    return HuggingFacePipeline(pipeline=pipe)


# Build RAG chain
def build_qa_chain(llm, vectorstore):
    return RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever(), chain_type="stuff")

# Analyze a single document in chunks and store results with a chunk limit
def analyze_document(llm, docs, batch_size=1, max_chunks=10):
    results = {"Summary": [], "Connections": [], "Themes": []}
    chunks = split_document(docs)
    chunks = chunks[:max_chunks]

    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]

        for label, q in [
            ("Summary", "Summarize the following scientific paper text in concise bullet points. Include: Main contribution, Dataset used, method, Evaluation metrics and Key results."),
            ("Connections", "You are reading several research papers. Based on the passage below, what connections or similarities can you identify with other papers on diffusion models? Mention common techniques, models, datasets, or evaluation strategies."),
            ("Themes", "What are the central *research themes* in the following paper? List them as concise topics.")
        ]:
            prompts = [f"{q}\n\n---\n{chunk.page_content.strip()}" for chunk in batch]
            try:
                responses = llm.pipeline(prompts)
                for response in responses:
                    text = response['generated_text'].strip()
                    results[label].append(text)
                    print(f"\n🔍 {label}:\n{text}")
            except Exception as e:
                print(f"Error during {label} batch: {e}")
                results[label].extend(["Error"] * len(batch))

        del batch
        torch.cuda.empty_cache()
        gc.collect()

    return results


# Main pipeline
def process_all_documents(data_dir="data", max_chunks=10):
    files = glob.glob(os.path.join(data_dir, "*.txt"))
    results = {"Summary": [], "Connections": [], "Themes": []}

    llm = load_llm()

    for file_path in files:
        print(f"\n📄 Processing: {file_path}")
        # Load and clean the document
        doc = load_document(file_path)

        # Process the document in chunks
        doc_results = analyze_document(llm, docs=doc, max_chunks=max_chunks)

        # Collect results
        for label in results:
            results[label].extend(doc_results.get(label, []))

        # Free up GPU memory after processing each document
        del doc
        torch.cuda.empty_cache()
        gc.collect()  # Run garbage collection to free memory

    # Clean-up results (e.g., remove empty strings or redundant entries)
    for label in results:
        flat = [str(item).strip() for sublist in results[label] for item in (sublist if isinstance(sublist, list) else [sublist])]
        cleaned = [s for s in flat if s and s.lower() != "error"]
        combined_text = "\n".join(cleaned)

        if label == "Themes":
          final_theme = summarize_combined_output(llm, combined_text, label) # Remove duplicates
          themes = list(dict.fromkeys([line.strip() for line in final_theme.split("\n") if line.strip()]))
          results[label] = "\n".join(themes)
        else:
          results[label] = summarize_combined_output(llm, combined_text, label)


    return results

# Summarize combined chunk outputs into a single final output
def summarize_combined_output(llm, text, label):
    prompts = {
        "Summary": "You are a helpful scientific assistant. Based on the following combined summaries of a scientific paper, provide a single concise overall summary. Mention the main contribution, dataset used, method, evaluation metrics, and key results.",
        "Connections": "You are reading several research papers. Based on the following notes, summarize the common connections or similarities across papers, focusing on shared techniques, datasets, or models.",
        "Themes": "Summarize the central research themes mentioned in the combined text below. List them as concise, broad topics."
    }
    prompt = f"{prompts[label]}\n\n{text}"
    try:
        response = llm.pipeline(prompt)
        return response[0]["generated_text"].strip()
    except Exception as e:
        print(f"Error generating final {label}: {e}")
        return "Error"


# Main execution
results = process_all_documents(data_dir="data", max_chunks=10)  # Limit the number of chunks for faster processing
# Final outputs
final_connections = results.get("Connections", "")
final_themes = results.get("Themes", "")
final_summary = results.get("Summary", "")

# Display summaries
print("\nFinal Summary:")
print(final_summary)
print("\nFinal Connections Across Papers:")
print(final_connections)
print("\nFinal Themes:")
print(final_themes)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0



📄 Processing: data/2105.05233.txt

🔍 Summary:
We show that diffusion models can achieve image sample quality superior to the current stateoftheart generative models. We achieve this on unconditional im age synthesis by nding a better architecture through a series of ablations. For conditional image synthesis, we further improve sample quality with classier guidance: a simple, computeefcient method for trading off diversity for delity using gradients from a classier. We achieve an FID of 2.97 on ImageNet 128128, 4.59 on ImageNet 256256, and 7.72 on ImageNet 512512, and we match BigGANdeep even with as few as 25 forward passes per sample. Finally, we nd that classier guidance combines well with upsampling diffusion models, further improving FID to 3.94 on ImageNet 256256 and 3.85 on ImageNet 512512. We release our code at 1 Introduction

🔍 Connections:
ImageNet 512512 images, they are not yet capable of producing highquality images.

🔍 Themes:
Image Synthesis

🔍 Summary:
Main contributi

Token indices sequence length is longer than the specified maximum sequence length for this model (570 > 512). Running this sequence through the model will result in indexing errors



🔍 Summary:
--- Diffusion models are a generalization of the forward process model. They are characterized by the following properties: --- The forward process is xed to a Markov chain that gradually adds Gaussian noise to the data according to a variance schedule 1, . . . , T: qx1:T x0 : T Y t1 qxtxt1, qxtxt1 : Nxt; p 1 txt1, tI 2 Training is performed by optimizing the usual variational bound on negative log likelihood: E log px0 Eq log px0:T qx1:T x0 Eq log pxT X t1 log pxt1xt qxtxt1 : L 3 The forward process variances t can be learned by reparameterization or held constant as hyperparameters, and expressiveness of the reverse process is ensured in part by the choice of Gaussian conditionals in pxtxt; tx0, 1 tI 4 2 Efficient training is therefore possible by optimizing random terms of L with stochastic gradient descent. Further improvements come from variance reduction by rewriting L 3 as: Eq DKLqxT x0 pxT z LT X t1 DKLqxt1xt, x0 pxt1xt z Lt1 log px0x1 z L0

🔍 Connections:
KL diverg

3. Metadata Extraction & Knowledge Graph
The system uses an LLM to extract structured metadata from each paper, including:

Title, authors, publication year
Keywords and methodologies
Abstract
Cited papers

This metadata forms the basis of a knowledge graph built with NetworkX, where:

Papers, authors, journals, methodologies, and keywords are nodes
Relationships (authored, cites, uses) are edges

This graph representation allows visualization of relationships between papers and helps identify key authors, methodologies, and research themes across multiple papers.

4. Analysis Capabilities

Paper Summarization
For each paper, the system generates a comprehensive summary covering:

Main research questions
Methodology
Key findings
Limitations and future work

This helps quickly understand individual papers without reading the full text.


Paper Comparison
The system can compare multiple papers to identify:

Common themes
Differences in methodology
Complementary or contradictory findings
Research gaps


Theme Extraction
The system analyzes all papers to identify common themes, showing which papers address each theme and how they contribute to it.


Question Answering
When you ask a question:

The system retrieves the most relevant chunks from across all papers
It provides the chunks as context to the LLM
The LLM synthesizes an answer based on this context, citing the relevant papers



In [None]:

class AcademicRAG:
    def __init__(self, api_key: str, model_name: str = "gpt-4o"):
        """
        Initialize the Academic RAG system.

        Args:
            api_key: OpenAI API key
            model_name: LLM model to use
        """
        self.api_key = api_key
        openai.api_key = api_key

        # Initialize embedding model
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Initialize LLM
        self.llm = ChatOpenAI(
            model_name=model_name,
            openai_api_key=api_key,
            temperature=0.1
        )

        # Initialize vector store
        self.vector_store = None

        # Initialize document storage
        self.documents = {}
        self.chunks = {}
        self.metadata = {}
        self.knowledge_graph = nx.DiGraph()

    def load_pdf(self, file_path: str) -> str:
        """
        Extract text from PDF.

        Args:
            file_path: Path to PDF file

        Returns:
            Full text content of the PDF
        """
        doc_id = Path(file_path).stem
        pdf_document = fitz.open(file_path)
        text = ""

        for page_num in range(len(pdf_document)):
            page = pdf_document[page_num]
            text += page.get_text()

        # Store the full document
        self.documents[doc_id] = text

        print(f"Loaded document: {doc_id} ({len(text)} characters)")
        return text

    def load_multiple_pdfs(self, directory: str) -> Dict[str, str]:
        """
        Load multiple PDFs from a directory.

        Args:
            directory: Directory containing PDF files

        Returns:
            Dictionary mapping document IDs to full text content
        """
        pdf_files = [f for f in os.listdir(directory) if f.lower().endswith('.pdf')]

        for pdf_file in pdf_files:
            file_path = os.path.join(directory, pdf_file)
            self.load_pdf(file_path)

        return self.documents

    def chunk_document(self, doc_id: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> List[Document]:
        """
        Split document into chunks.

        Args:
            doc_id: Document identifier
            chunk_size: Size of each chunk
            chunk_overlap: Overlap between chunks

        Returns:
            List of document chunks with metadata
        """
        if doc_id not in self.documents:
            raise ValueError(f"Document {doc_id} not found. Load it first.")

        text = self.documents[doc_id]

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )

        chunks = text_splitter.create_documents([text])

        # Add metadata to chunks
        for i, chunk in enumerate(chunks):
            chunk.metadata = {
                "doc_id": doc_id,
                "chunk_id": i,
                "source": doc_id
            }

        self.chunks[doc_id] = chunks
        print(f"Split {doc_id} into {len(chunks)} chunks")

        return chunks

    def chunk_all_documents(self, chunk_size: int = 1000, chunk_overlap: int = 200) -> Dict[str, List[Document]]:
        """
        Split all loaded documents into chunks.

        Args:
            chunk_size: Size of each chunk
            chunk_overlap: Overlap between chunks

        Returns:
            Dictionary mapping document IDs to lists of chunks
        """
        for doc_id in self.documents:
            self.chunk_document(doc_id, chunk_size, chunk_overlap)

        return self.chunks

    def build_vector_store(self):
        """
        Build vector store from all document chunks.
        """
        all_chunks = []
        for doc_chunks in self.chunks.values():
            all_chunks.extend(doc_chunks)

        # Create embedding class that adapts SentenceTransformer to LangChain
        class STEmbeddings(Embeddings):
            def __init__(self, model):
                self.model = model

            def embed_documents(self, texts):
                return self.model.encode(texts).tolist()

            def embed_query(self, text):
                return self.model.encode(text).tolist()

        embeddings = STEmbeddings(self.embedding_model)

        # Create FAISS vector store
        self.vector_store = FAISS.from_documents(all_chunks, embeddings)
        print(f"Built vector store with {len(all_chunks)} chunks")

    def extract_metadata(self):
        """
        Extract metadata from papers using LLM.
        """
        schema = {
            "properties": {
                "title": {"type": "string"},
                "authors": {"type": "array", "items": {"type": "string"}},
                "publication_year": {"type": "integer"},
                "journal": {"type": "string"},
                "abstract": {"type": "string"},
                "keywords": {"type": "array", "items": {"type": "string"}},
                "methodology": {"type": "array", "items": {"type": "string"}},
                "cited_papers": {"type": "array", "items": {"type": "string"}},
            },
            "required": ["title", "authors"],
        }

        extraction_chain = create_extraction_chain(schema, self.llm)

        for doc_id, text in self.documents.items():
            # Use the first chunk plus any detected abstract section for metadata extraction
            first_chunk = self.chunks[doc_id][0].page_content
            abstract_match = re.search(r"(?i)abstract(.*?)(?:introduction|keywords|background|related work)", text, re.DOTALL)
            abstract_text = abstract_match.group(1) if abstract_match else ""

            input_text = first_chunk + "\n\n" + abstract_text

            # Truncate if too long
            if len(input_text) > 5000:
                input_text = input_text[:5000]

            try:
                result = extraction_chain.invoke({"input": input_text})
                metadata = result["extracted"][0] if result["extracted"] else {}
                self.metadata[doc_id] = metadata
                print(f"Extracted metadata for {doc_id}: {metadata['title'] if 'title' in metadata else 'Unknown'}")
            except Exception as e:
                print(f"Error extracting metadata for {doc_id}: {e}")
                self.metadata[doc_id] = {"title": doc_id, "authors": ["Unknown"]}

    def build_knowledge_graph(self):
        """
        Build knowledge graph from extracted paper metadata.
        """
        G = nx.DiGraph()

        # Add papers as nodes
        for doc_id, metadata in self.metadata.items():
            title = metadata.get("title", doc_id)
            authors = metadata.get("authors", ["Unknown"])
            year = metadata.get("publication_year", "Unknown")
            journal = metadata.get("journal", "Unknown")

            # Add paper node
            G.add_node(title,
                       type="paper",
                       authors=authors,
                       year=year,
                       journal=journal,
                       doc_id=doc_id)

            # Add author nodes and connections
            for author in authors:
                if not G.has_node(author):
                    G.add_node(author, type="author")
                G.add_edge(author, title, type="authored")

            # Add journal node and connection
            if journal != "Unknown":
                if not G.has_node(journal):
                    G.add_node(journal, type="journal")
                G.add_edge(title, journal, type="published_in")

            # Add methodology nodes
            methods = metadata.get("methodology", [])
            for method in methods:
                if not G.has_node(method):
                    G.add_node(method, type="methodology")
                G.add_edge(title, method, type="uses")

            # Add keyword nodes
            keywords = metadata.get("keywords", [])
            for keyword in keywords:
                if not G.has_node(keyword):
                    G.add_node(keyword, type="keyword")
                G.add_edge(title, keyword, type="contains")

        # Add citation links
        for doc_id, metadata in self.metadata.items():
            source_title = metadata.get("title", doc_id)
            cited_papers = metadata.get("cited_papers", [])

            for cited_paper in cited_papers:
                # Try to match with existing paper nodes
                matching_papers = [node for node in G.nodes() if G.nodes[node].get("type") == "paper" and cited_paper.lower() in node.lower()]

                if matching_papers:
                    G.add_edge(source_title, matching_papers[0], type="cites")
                else:
                    # Add as a new node if not found
                    G.add_node(cited_paper, type="external_paper")
                    G.add_edge(source_title, cited_paper, type="cites")

        self.knowledge_graph = G
        print(f"Built knowledge graph with {len(G.nodes())} nodes and {len(G.edges())} edges")

        return G

    def visualize_knowledge_graph(self, output_file: str = "knowledge_graph.html"):
        """
        Visualize the knowledge graph.

        Args:
            output_file: Path to save the HTML visualization
        """
        G = self.knowledge_graph

        # Create pyvis network
        net = Network(height="750px", width="100%", notebook=False, directed=True)

        # Node colors by type
        color_map = {
            "paper": "#4285F4",  # blue
            "author": "#EA4335",  # red
            "journal": "#FBBC05",  # yellow
            "methodology": "#34A853",  # green
            "keyword": "#800080",  # purple
            "external_paper": "#A0A0A0"  # gray
        }

        # Add nodes
        for node in G.nodes():
            node_type = G.nodes[node].get("type", "unknown")
            title = f"Type: {node_type}"

            if node_type == "paper":
                authors = ", ".join(G.nodes[node].get("authors", ["Unknown"]))
                year = G.nodes[node].get("year", "Unknown")
                journal = G.nodes[node].get("journal", "Unknown")
                title = f"Paper: {node}\nAuthors: {authors}\nYear: {year}\nJournal: {journal}"

            net.add_node(node,
                         title=title,
                         color=color_map.get(node_type, "#000000"),
                         size=20 if node_type == "paper" else 10)

        # Add edges
        edge_colors = {
            "authored": "#EA4335",
            "published_in": "#FBBC05",
            "uses": "#34A853",
            "contains": "#800080",
            "cites": "#4285F4"
        }

        for source, target, data in G.edges(data=True):
            edge_type = data.get("type", "unknown")
            net.add_edge(source, target,
                         title=edge_type,
                         color=edge_colors.get(edge_type, "#000000"),
                         arrows="to")

        # Set physics layout
        net.force_atlas_2based(gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.08)
        net.toggle_physics(True)

        # Save visualization
        net.show(output_file)
        print(f"Knowledge graph visualization saved to {output_file}")

    def retrieve_similar_chunks(self, query: str, k: int = 5) -> List[Tuple[Document, float]]:
        """
        Retrieve similar chunks for a query.

        Args:
            query: Query text
            k: Number of chunks to retrieve

        Returns:
            List of (document, similarity) tuples
        """
        if not self.vector_store:
            raise ValueError("Vector store not built. Call build_vector_store() first.")

        results = self.vector_store.similarity_search_with_score(query, k=k)

        return results

    def generate_paper_summary(self, doc_id: str) -> str:
        """
        Generate a summary for a specific paper.

        Args:
            doc_id: Document identifier

        Returns:
            Summary text
        """
        if doc_id not in self.documents:
            raise ValueError(f"Document {doc_id} not found.")

        prompt = PromptTemplate(
            input_variables=["text"],
            template="""
            You are an academic research assistant. Summarize the following academic paper excerpt:

            {text}

            Provide a comprehensive summary that includes:
            1. The main research question/problem
            2. The methodology used
            3. Key findings and contributions
            4. Limitations and future work

            Summary:
            """
        )

        # Use the first few chunks (up to 10,000 characters)
        chunks = self.chunks.get(doc_id, [])
        combined_text = ""
        for chunk in chunks:
            combined_text += chunk.page_content + "\n\n"
            if len(combined_text) > 10000:
                break
        combined_text = combined_text[:10000]

        chain = LLMChain(llm=self.llm, prompt=prompt)
        summary = chain.invoke({"text": combined_text})

        return summary["text"]

    def compare_papers(self, doc_ids: List[str]) -> str:
        """
        Compare multiple papers.

        Args:
            doc_ids: List of document identifiers

        Returns:
            Comparison text
        """
        # Validate all doc_ids
        for doc_id in doc_ids:
            if doc_id not in self.documents:
                raise ValueError(f"Document {doc_id} not found.")

        # Get metadata for each paper
        paper_info = []
        for doc_id in doc_ids:
            metadata = self.metadata.get(doc_id, {})
            title = metadata.get("title", doc_id)
            authors = ", ".join(metadata.get("authors", ["Unknown"]))
            year = metadata.get("publication_year", "Unknown")

            # Get a brief summary by using first chunk
            first_chunk = self.chunks[doc_id][0].page_content
            paper_info.append(f"Title: {title}\nAuthors: {authors}\nYear: {year}\nExcerpt: {first_chunk[:500]}...")

        # Generate comparison using LLM
        prompt = PromptTemplate(
            input_variables=["papers"],
            template="""
            You are an academic research assistant. Compare and contrast the following academic papers:

            {papers}

            Provide a detailed comparison including:
            1. Common themes and research questions
            2. Differences in methodology
            3. Complementary or contradictory findings
            4. How they build upon each other's work
            5. Potential gaps or areas for future research

            Comparison:
            """
        )

        chain = LLMChain(llm=self.llm, prompt=prompt)
        comparison = chain.invoke({"papers": "\n\n".join(paper_info)})

        return comparison["text"]

    def answer_question(self, question: str, k: int = 5) -> str:
        """
        Answer a question based on the papers.

        Args:
            question: Question text
            k: Number of chunks to retrieve

        Returns:
            Answer text
        """
        # Retrieve relevant chunks
        relevant_chunks = self.retrieve_similar_chunks(question, k=k)

        # Prepare context from chunks
        context = []
        for chunk, score in relevant_chunks:
            doc_id = chunk.metadata["doc_id"]
            metadata = self.metadata.get(doc_id, {})
            title = metadata.get("title", doc_id)
            authors = ", ".join(metadata.get("authors", ["Unknown"]))
            context.append(f"From '{title}' by {authors}:\n{chunk.page_content}")

        # Generate answer using LLM
        prompt = PromptTemplate(
            input_variables=["question", "context"],
            template="""
            You are an academic research assistant with expertise in synthesizing information from academic papers.

            Question: {question}

            Here are relevant excerpts from academic papers:

            {context}

            Please provide a comprehensive answer to the question based on the given context.
            Cite the papers you're referencing in your response.
            If the information provided is insufficient to answer the question, clearly state what's missing.

            Answer:
            """
        )

        chain = LLMChain(llm=self.llm, prompt=prompt)
        answer = chain.invoke({"question": question, "context": "\n\n".join(context)})

        return answer["text"]

    def extract_themes(self) -> Dict:
        """
        Extract common themes across papers.

        Returns:
            Dictionary of themes with supporting evidence
        """
        # Prepare paper information
        papers_info = []
        for doc_id, metadata in self.metadata.items():
            title = metadata.get("title", doc_id)
            abstract = metadata.get("abstract", "")
            keywords = metadata.get("keywords", [])

            if not abstract:
                # Try to find abstract in the document
                doc_text = self.documents[doc_id]
                abstract_match = re.search(r"(?i)abstract(.*?)(?:introduction|keywords|background|related work)", doc_text, re.DOTALL)
                abstract = abstract_match.group(1) if abstract_match else "No abstract found"

            papers_info.append(f"Title: {title}\nAbstract: {abstract}\nKeywords: {', '.join(keywords)}")

        # Generate themes using LLM
        prompt = PromptTemplate(
            input_variables=["papers"],
            template="""
            You are an academic research assistant. Analyze the following academic papers and extract common themes:

            {papers}

            Identify 3-5 major themes across these papers. For each theme:
            1. Provide a clear name and description
            2. List which papers address this theme
            3. Describe how each paper contributes to or approaches this theme
            4. Note any contradictions or complementary findings within the theme

            Format your response as a JSON structure where each theme is a key, and the value contains the description and paper relationships.

            Themes:
            """
        )

        chain = LLMChain(llm=self.llm, prompt=prompt)
        themes_result = chain.invoke({"papers": "\n\n".join(papers_info)})

        # Extract themes from the text response
        themes_text = themes_result["text"]

        # Use LLM to convert to structured format
        extraction_prompt = PromptTemplate(
            input_variables=["themes_text"],
            template="""
            Convert the following themes analysis into a proper JSON structure:

            {themes_text}

            JSON format:
            ```
            {{
                "Theme 1 Name": {{
                    "description": "Description of theme 1",
                    "papers": [
                        {{
                            "title": "Paper Title 1",
                            "contribution": "How Paper 1 contributes to this theme"
                        }},
                        ...
                    ]
                }},
                ...
            }}
            ```

            Only respond with the valid JSON, nothing else.
            """
        )

        chain = LLMChain(llm=self.llm, prompt=extraction_prompt)
        json_result = chain.invoke({"themes_text": themes_text})

        try:
            # Extract JSON from the response (it might be wrapped in backticks)
            json_text = re.search(r'```json\n(.*?)\n```', json_result["text"], re.DOTALL)
            if json_text:
                json_str = json_text.group(1)
            else:
                json_str = json_result["text"]

            # Clean up any extra backticks
            json_str = json_str.replace('```', '').strip()

            import json
            themes = json.loads(json_str)
            return themes
        except Exception as e:
            print(f"Error parsing themes JSON: {e}")
            return {"error": "Could not parse themes", "raw_text": themes_text}

    def run_full_analysis(self, pdf_directory: str, output_dir: str = "./output"):
        """
        Run full analysis pipeline on a directory of PDFs.

        Args:
            pdf_directory: Directory containing PDF files
            output_dir: Directory to save output files
        """
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # 1. Load PDFs
        print("Loading PDFs...")
        self.load_multiple_pdfs(pdf_directory)

        # 2. Chunk documents
        print("Chunking documents...")
        self.chunk_all_documents()

        # 3. Build vector store
        print("Building vector store...")
        self.build_vector_store()

        # 4. Extract metadata
        print("Extracting metadata...")
        self.extract_metadata()

        # 5. Build knowledge graph
        print("Building knowledge graph...")
        self.build_knowledge_graph()

        # 6. Visualize knowledge graph
        print("Visualizing knowledge graph...")
        self.visualize_knowledge_graph(os.path.join(output_dir, "knowledge_graph.html"))

        # 7. Generate summaries for each paper
        print("Generating summaries...")
        summaries = {}
        for doc_id in self.documents:
            summaries[doc_id] = self.generate_paper_summary(doc_id)

            # Save individual summaries
            with open(os.path.join(output_dir, f"{doc_id}_summary.txt"), "w") as f:
                f.write(summaries[doc_id])

        # 8. Compare all papers
        print("Comparing papers...")
        comparison = self.compare_papers(list(self.documents.keys()))
        with open(os.path.join(output_dir, "paper_comparison.txt"), "w") as f:
            f.write(comparison)

        # 9. Extract themes
        print("Extracting themes...")
        themes = self.extract_themes()

        import json
        with open(os.path.join(output_dir, "themes.json"), "w") as f:
            json.dump(themes, f, indent=2)

        print(f"Analysis complete! Results saved in {output_dir}")

        return {
            "summaries": summaries,
            "comparison": comparison,
            "themes": themes
        }



In [20]:
# Example usage
if __name__ == "__main__":
    # Replace with your OpenAI API key
    OPENAI_API_KEY = "your-api-key-here"

    # Initialize the RAG system
    rag = AcademicRAG(api_key=OPENAI_API_KEY)

    # Run full analysis
    results = rag.run_full_analysis("./papers")

    # Ask questions
    question = "What are the main methodologies used across these papers and how do they compare?"
    answer = rag.answer_question(question)
    print(f"Q: {question}\nA: {answer}")

    # Get more paper relationships
    print("Paper relationships:")
    for source, target, data in rag.knowledge_graph.edges(data=True):
        if data.get("type") == "cites":
            print(f"  {source} cites {target}")

NameError: name 'AcademicRAG' is not defined