# Literature Review Agent

This notebook implements the literature review component. It loads BERT and SPECTER encoders and uses BERT embeddings to pick out the sentences most representative of each research paper.

In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import torch
from transformers import AutoTokenizer, AutoModel
from utils.config import setup_logging, MODEL_CONFIGS, RESEARCH_PAPERS_DIR, OUTPUTS_DIR
from utils.helpers import load_research_papers, save_json
import numpy as np
from typing import List, Dict
from tqdm import tqdm

In [None]:
# Logger dedicated to the literature review agent.
logger = setup_logging('literature_review')

# Run on the GPU whenever torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# BERT encoder: extract_key_insights reads bert_tokenizer / bert_model
# as module-level globals.
bert_name = MODEL_CONFIGS['literature_review']['bert_model']
bert_tokenizer = AutoTokenizer.from_pretrained(bert_name)
bert_model = AutoModel.from_pretrained(bert_name).to(device)

# SPECTER encoder, configured for citation analysis (loaded on the same
# device as BERT).
specter_name = MODEL_CONFIGS['literature_review']['specter_model']
specter_tokenizer = AutoTokenizer.from_pretrained(specter_name)
specter_model = AutoModel.from_pretrained(specter_name).to(device)

In [None]:
def _split_sentences(text: str) -> List[str]:
    """Split text into stripped, non-empty sentences (heuristic).

    Breaks on runs of '.', '!' or '?' only when they are followed by whitespace
    or the end of the text, so decimals ("92.5%") and version strings ("v1.0")
    are not cut in half. Abbreviations such as "Fig. 3" are still split.
    """
    import re  # local import: the notebook's import cell does not import re

    pieces = re.split(r'[.!?]+(?=\s|$)', text)
    return [piece.strip() for piece in pieces if piece.strip()]


def _masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over non-padding positions only.

    ``hidden`` is a model's ``last_hidden_state`` (batch, seq, dim) and
    ``attention_mask`` the tokenizer's mask (batch, seq), 1 for real tokens.
    For an unpadded sequence this equals ``hidden.mean(dim=1)``.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    # clamp guards against division by zero for an all-padding row.
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)


def extract_key_insights(text: str, max_length: int = 512, *, top_k: int = 3,
                         sentence_max_length: int = 128,
                         batch_size: int = 32) -> Dict:
    """Extract key insights from text using BERT.

    Each sentence is embedded with BERT and scored by cosine similarity to an
    embedding of the whole text. The highest-scoring sentences are returned.

    Args:
        text: Document text to summarize.
        max_length: Token limit for the whole-document embedding. NOTE: only
            the first ``max_length`` tokens shape the reference embedding,
            while sentences from the entire text are scored against it.
        top_k: Number of sentences to return.
        sentence_max_length: Token limit applied to each sentence.
        batch_size: Sentences encoded per forward pass (values < 1 act as 1).

    Returns:
        ``{'key_points': [...], 'importance_scores': [...]}``, both sorted by
        descending score and holding at most ``top_k`` entries. Both lists are
        empty when the text contains no sentences.
    """
    sentences = _split_sentences(text)
    if not sentences:
        return {'key_points': [], 'importance_scores': []}

    step = max(1, batch_size)
    scores: List[float] = []
    with torch.no_grad():
        # Reference embedding of the (possibly truncated) whole document.
        doc_inputs = bert_tokenizer(text, return_tensors='pt', max_length=max_length,
                                    truncation=True, padding=True)
        doc_inputs = {k: v.to(device) for k, v in doc_inputs.items()}
        doc_hidden = bert_model(**doc_inputs).last_hidden_state
        doc_emb = _masked_mean(doc_hidden, doc_inputs['attention_mask'])

        # Encode sentences in padded batches. The masked mean keeps padding out
        # of each sentence's embedding, so scores match per-sentence encoding.
        for start in range(0, len(sentences), step):
            batch = sentences[start:start + step]
            enc = bert_tokenizer(batch, return_tensors='pt',
                                 max_length=sentence_max_length,
                                 truncation=True, padding=True)
            enc = {k: v.to(device) for k, v in enc.items()}
            sent_hidden = bert_model(**enc).last_hidden_state
            sent_emb = _masked_mean(sent_hidden, enc['attention_mask'])
            # doc_emb (1, dim) broadcasts against sent_emb (batch, dim).
            scores.extend(torch.cosine_similarity(doc_emb, sent_emb).tolist())

    # sorted() is stable, so ties keep document order.
    ranked = sorted(zip(sentences, scores), key=lambda pair: pair[1],
                    reverse=True)[:top_k]

    return {
        'key_points': [sentence for sentence, _ in ranked],
        'importance_scores': [score for _, score in ranked],
    }

In [None]:
def analyze_papers() -> Dict:
    """Analyze all papers in the research directory and save the results.

    Loads papers with ``load_research_papers(RESEARCH_PAPERS_DIR)``, runs
    ``extract_key_insights`` on each paper's ``'content'`` and stores the top
    sentences and their scores under the paper's ``'title'``. The results are
    written to ``OUTPUTS_DIR / 'literature_analysis.json'`` via ``save_json``.

    A failure on one paper is logged with its traceback and that paper is
    skipped; the remaining papers are still processed.

    Returns:
        ``{title: {'key_insights': [...], 'importance_scores': [...]}}``.
        Returns an empty dict, with nothing saved, when no papers are found.
    """
    logger.info('Starting paper analysis')

    papers = load_research_papers(RESEARCH_PAPERS_DIR)
    if not papers:
        logger.warning('No papers found in research directory')
        return {}

    analysis_results = {}
    for index, paper in enumerate(tqdm(papers, desc='Analyzing papers')):
        # Resolve the title once, up front. Previously it was re-read inside the
        # except handler, so a paper without 'title' raised a second KeyError
        # there and aborted the whole run.
        # NOTE(review): assumes each paper is a dict, as the original key
        # indexing implies; confirm against load_research_papers.
        title = paper.get('title', f'untitled paper #{index}')

        try:
            insights = extract_key_insights(paper['content'])
        except Exception as e:
            # Best-effort per paper: keep the traceback in the log and move on.
            logger.exception(f'Error analyzing paper {title}: {e}')
            continue

        if title in analysis_results:
            logger.warning(f'Duplicate paper title {title!r}; '
                           f'earlier result will be overwritten')

        analysis_results[title] = {
            'key_insights': insights['key_points'],
            'importance_scores': insights['importance_scores'],
        }
        logger.info(f'Successfully analyzed paper: {title}')

    output_path = OUTPUTS_DIR / 'literature_analysis.json'
    save_json(analysis_results, output_path)
    logger.info(f'Saved analysis results to {output_path}')

    return analysis_results

In [None]:
if __name__ == "__main__":
    results = analyze_papers()

    # Summary: the paper count, then each paper's ranked insights with scores.
    print(f"Analyzed {len(results)} papers")
    for paper_title, paper_analysis in results.items():
        print(f"\n{paper_title}:")
        ranked_pairs = zip(paper_analysis['key_insights'],
                           paper_analysis['importance_scores'])
        for rank, (point, value) in enumerate(ranked_pairs, start=1):
            print(f"{rank}. {point} (score: {value:.3f})")