### LLM-assisted enrichment (completing fields of study, generating keywords, scoring, and an optional contribution summary)

#### Complete fields of study

In [2]:
!pip install backoff

Collecting backoff
  Downloading backoff-2.2.1-py3-none-any.whl.metadata (14 kB)
Downloading backoff-2.2.1-py3-none-any.whl (15 kB)
Installing collected packages: backoff
Successfully installed backoff-2.2.1




The `enhance_fields_of_study.py` script enriches an academic paper dataset by extracting relevant fields of study with DeepSeek's LLM API. It sends each paper's metadata (title and abstract) in a structured prompt and obtains 3-6 academic discipline labels per paper. The implementation features concurrent processing with rate limiting, exponential backoff for error handling, and checkpointing so that runs over large datasets can be resumed. Each enhanced record contains the extracted fields plus an extraction timestamp for traceability.

In [None]:
# enhance_fields_of_study.py
"""
Enhanced paper fields extraction using DeepSeek API with concurrent processing.

This script processes academic papers from a JSONL file, calls the DeepSeek API
to extract relevant academic fields of study for each paper, and saves the 
enhanced data to a new JSONL file with progress tracking and error handling.
"""

import json
import requests
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from tqdm import tqdm
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
import random
from datetime import datetime
import backoff

# Logging setup: every record goes both to a UTF-8 log file and to the console.
_log_handlers = [
    logging.FileHandler('api_processing.log', encoding='utf-8'),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

# API Configuration
# The API key is read from the DEEPSEEK_API_KEY environment variable so that
# no secret is stored in the notebook/source. NOTE: a key that was previously
# committed in this file must be treated as compromised and rotated.
import os

API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
API_URL = "https://api.deepseek.com/v1/chat/completions"
REQUEST_TIMEOUT = 30  # Seconds passed as the HTTP request timeout
MAX_WORKERS = 10  # Increased but with better rate limiting
BASE_DELAY = 1.0  # Base delay between requests
MAX_RETRIES = 5
BATCH_SIZE = 100  # Save progress every N papers

# Rate limiting and thread safety
rate_limit_lock = Lock()  # Serializes the inter-request delay across threads
last_request_time = 0  # Timestamp (time.time()) of the most recent request
total_requests = 0  # Number of API attempts made so far
failed_requests = 0  # Number of API attempts that ended in an error

def rate_limited_request():
    """Block until at least BASE_DELAY seconds separate consecutive API calls.

    The check and the sleep both happen while ``rate_limit_lock`` is held, so
    callers on different threads pass through one at a time. A small random
    jitter (up to 0.1 s) is added to every wait.
    """
    global last_request_time

    with rate_limit_lock:
        since_last = time.time() - last_request_time
        if since_last < BASE_DELAY:
            # Wait out the remainder of the interval, plus jitter.
            time.sleep(BASE_DELAY - since_last + random.uniform(0, 0.1))
        last_request_time = time.time()

@backoff.on_exception(
    backoff.expo,
    # ValueError is what the response validation below raises (and is the base
    # class of json.JSONDecodeError); listing it makes malformed replies
    # retryable instead of failing on the first attempt.
    (requests.exceptions.RequestException, KeyError, json.JSONDecodeError, ValueError),
    max_tries=MAX_RETRIES,
    max_time=300
)
def call_deepseek_api_with_backoff(prompt: str) -> Optional[str]:
    """
    Call the DeepSeek chat-completions API and return a JSON-array string.

    The call is rate limited via rate_limited_request() and retried with
    exponential backoff by the decorator (up to MAX_RETRIES attempts or 300 s).
    Every attempt increments ``total_requests``; attempts that end in an HTTP
    error or any non-timeout exception increment ``failed_requests``.

    Args:
        prompt: The prompt to send to the API

    Returns:
        The response content, reduced to the ``[...]`` portion when the model
        wrapped the array in extra text. This function does not return None;
        failures are raised.

    Raises:
        requests.exceptions.RequestException: Network, timeout or HTTP errors
            that persist after all retries.
        ValueError: The response was structurally invalid or contained no
            JSON array after all retries.
    """
    import re  # local import: this script's import block does not provide `re`

    global total_requests, failed_requests

    # Apply rate limiting
    rate_limited_request()

    total_requests += 1

    try:
        response = requests.post(
            API_URL,
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": "deepseek-chat",
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.2,
                "max_tokens": 200
            },
            timeout=REQUEST_TIMEOUT
        )

        # Honor the server's rate-limit hint before the error is raised (and
        # retried by the decorator).
        if response.status_code == 429:
            try:
                retry_after = float(response.headers.get('Retry-After', 60))
            except (TypeError, ValueError):
                # Retry-After may also be an HTTP date; fall back to the default.
                retry_after = 60.0
            logger.warning(f"Rate limited. Waiting {retry_after:g} seconds...")
            time.sleep(max(retry_after, 0.0))

        response.raise_for_status()

        data = response.json()

        # Validate response structure
        if not isinstance(data, dict):
            raise ValueError(f"Response is not a dictionary: {type(data)}")

        if "choices" not in data or not data["choices"]:
            raise ValueError(f"No choices in response: {data}")

        first_choice = data["choices"][0]
        message = first_choice.get("message", {}) if isinstance(first_choice, dict) else {}
        if not message or "content" not in message:
            raise ValueError(f"No content in message: {message}")

        raw_content = message["content"]
        if not isinstance(raw_content, str):
            raise ValueError(f"No content in message: {message}")
        content = raw_content.strip()

        # Validate that content is not empty
        if not content:
            raise ValueError("Empty response content")

        # Basic validation - should start with [ for JSON array
        if not content.startswith('[') or not content.endswith(']'):
            logger.warning(f"Response doesn't look like JSON array: {content[:100]}...")
            # Try to extract JSON array if it's wrapped in other text
            json_match = re.search(r'\[.*\]', content, re.DOTALL)
            if json_match:
                content = json_match.group(0)
            else:
                raise ValueError(f"Response not a valid JSON array: {content[:100]}...")

        return content

    except requests.exceptions.Timeout:
        logger.warning("API request timeout")
        raise
    except requests.exceptions.HTTPError as e:
        failed_requests += 1
        # Compare against None: a Response object is falsy for 4xx/5xx.
        err_response = e.response
        status = err_response.status_code if err_response is not None else None
        body = err_response.text[:200] if err_response is not None else ""
        label = "Server error" if status is not None and status >= 500 else "HTTP error"
        logger.error(f"{label} {status}: {body}")
        raise
    except Exception as e:
        failed_requests += 1
        logger.error(f"Unexpected API error: {str(e)[:200]}")
        raise

def call_deepseek_api(prompt: str, max_retries: int = MAX_RETRIES) -> Optional[str]:
    """
    Best-effort API call that converts any failure into a None result.

    Delegates to call_deepseek_api_with_backoff(); any Exception escaping it
    is logged (message truncated to 200 characters) and swallowed.

    Args:
        prompt: The prompt to send to the API
        max_retries: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        API response content as string, or None if the call ultimately failed
    """
    try:
        result = call_deepseek_api_with_backoff(prompt)
    except Exception as exc:
        logger.error(f"All retries failed for API call: {str(exc)[:200]}")
        result = None
    return result

def extract_academic_fields(paper: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract academic fields of study from paper metadata using DeepSeek API.

    The title (first 200 characters) and abstract (first 1000 characters) are
    placed in a prompt asking for a JSON array of 3-6 fields. The reply is
    parsed as JSON; if that fails, quoted strings (or, failing that,
    comma/period-delimited word runs) are pulled from the raw text instead.
    At most 6 fields are kept. Any error results in an empty field list
    rather than an exception.

    Args:
        paper: Dictionary containing paper metadata with 'title' and 'abstract'

    Returns:
        A shallow copy of ``paper`` with two added keys: 'fields_of_study'
        (list of strings, possibly empty) and 'fields_extraction_time'
        (ISO-8601 local timestamp).
    """
    import re  # local import: this script's import block does not provide `re`

    paper_id = paper.get("paper_id", paper.get("id", "unknown"))
    # `or ""` also covers records where the key exists with a null value,
    # which would otherwise fail on slicing / len().
    title = (paper.get("title") or "")[:200]

    # Clean and prepare abstract
    abstract = paper.get("abstract") or ""
    if len(abstract) > 1000:
        abstract = abstract[:1000] + "..."

    prompt = f"""
You are an academic classification expert.

Given the following paper information, infer the most appropriate academic fields of study.

Rules:
- Output ONLY a JSON array of strings.
- Each item must be a broad academic field (e.g., "Computer Vision", "Machine Learning").
- Include 3 to 6 fields only.
- Do NOT include model names or datasets.

Title:
{title}

Abstract:
{abstract}
"""
    fields = []

    try:
        api_response = call_deepseek_api(prompt)

        if api_response is None:
            logger.debug(f"No API response for paper {paper_id}")
        else:
            try:
                parsed = json.loads(api_response)
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON for paper {paper_id}: {str(e)[:100]}")
                logger.debug(f"Raw response: {api_response[:500]}")

                # Fallback: try to extract fields from text response
                potential_fields = re.findall(r'"([^"]+)"', api_response)
                if not potential_fields:
                    potential_fields = re.findall(r'[\w\s]+(?=,|\.|$)', api_response)

                stripped = (f.strip() for f in potential_fields)
                # dict.fromkeys de-duplicates while keeping first-seen order,
                # so the result is reproducible across runs (a set is not).
                fields = list(dict.fromkeys(s for s in stripped if 3 < len(s) < 50))[:6]
            else:
                if isinstance(parsed, list):
                    # Keep non-empty strings of reasonable length, max 6.
                    fields = [
                        field.strip()
                        for field in parsed
                        if isinstance(field, str) and 0 < len(field.strip()) < 100
                    ][:6]

    except Exception as e:
        logger.error(f"Unexpected error extracting fields for paper {paper_id}: {str(e)[:200]}")

    # Create a copy of the paper with extracted fields
    enhanced_paper = paper.copy()
    enhanced_paper["fields_of_study"] = fields
    enhanced_paper["fields_extraction_time"] = datetime.now().isoformat()

    return enhanced_paper

def load_papers(input_path: Path) -> List[Dict[str, Any]]:
    """
    Read a JSONL file into a list, one parsed JSON value per line.

    Lines that are not valid JSON are reported with a warning (including the
    1-based line number) and skipped.

    Args:
        input_path: Path to input JSONL file

    Returns:
        List of parsed records, in file order

    Raises:
        FileNotFoundError: If input file doesn't exist
    """
    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    loaded = []
    with open(input_path, 'r', encoding='utf-8') as handle:
        for line_num, raw in enumerate(handle, start=1):
            try:
                record = json.loads(raw.strip())
            except json.JSONDecodeError as e:
                logger.warning(f"Invalid JSON on line {line_num}: {e}")
                continue
            loaded.append(record)

    logger.info(f"Loaded {len(loaded)} papers from {input_path}")
    return loaded

def save_checkpoint(papers: List[Dict[str, Any]], checkpoint_path: Path):
    """
    Save intermediate results to checkpoint file (one JSON object per line).

    The data is written to a sibling ``<name>.tmp`` file first and then moved
    over the checkpoint, so an interruption during the write cannot leave a
    truncated checkpoint in place of the previous complete one.

    Args:
        papers: List of enhanced paper dictionaries
        checkpoint_path: Path to checkpoint file
    """
    checkpoint_path = Path(checkpoint_path)
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + '.tmp')

    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(paper, ensure_ascii=False) + '\n' for paper in papers)

    # Replace the old checkpoint only after the new one is fully written.
    tmp_path.replace(checkpoint_path)

    logger.info(f"Checkpoint saved: {len(papers)} papers to {checkpoint_path}")

def load_checkpoint(checkpoint_path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """
    Read back a checkpoint file and collect the IDs of the papers it holds.

    A paper's ID is its 'paper_id' value, falling back to 'id'; falsy IDs are
    not recorded. Lines that are not valid JSON are silently skipped.

    Args:
        checkpoint_path: Path to checkpoint file

    Returns:
        Tuple of (papers list, set of processed paper IDs); both are empty
        when the checkpoint file does not exist
    """
    restored = []
    seen_ids = set()

    if not checkpoint_path.exists():
        return restored, seen_ids

    with open(checkpoint_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            try:
                record = json.loads(raw.strip())
            except json.JSONDecodeError:
                continue
            restored.append(record)
            record_id = record.get("paper_id", record.get("id"))
            if record_id:
                seen_ids.add(record_id)

    logger.info(f"Loaded {len(restored)} papers from checkpoint")
    return restored, seen_ids

def process_paper_batch(papers: List[Dict[str, Any]], 
                       max_workers: int = MAX_WORKERS,
                       checkpoint_path: Optional[Path] = None) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """
    Process multiple papers concurrently using ThreadPoolExecutor.

    Each paper is passed to extract_academic_fields() on a worker thread.
    Results are collected in completion order (not input order). A paper whose
    result has an empty 'fields_of_study' list counts as both "empty_fields"
    and "failed".

    When ``checkpoint_path`` is given, a checkpoint is written every
    BATCH_SIZE completed papers. Records already present in that file for
    papers that are NOT part of this batch are carried over into every save,
    so a resumed run extends the checkpoint instead of overwriting the
    results of the earlier run.

    Args:
        papers: List of paper dictionaries to process
        max_workers: Maximum number of concurrent threads (capped at 5)
        checkpoint_path: Optional path for checkpointing

    Returns:
        Tuple of (enhanced papers list for this batch, statistics dictionary
        with keys "successful", "failed", "empty_fields", "total")
    """
    enhanced_papers = []
    stats = {
        "successful": 0,
        "failed": 0,
        "empty_fields": 0,
        "total": len(papers)
    }

    # Earlier results stored in the checkpoint that this batch will not redo.
    carried_over = []
    if checkpoint_path:
        batch_ids = {p.get("paper_id", p.get("id")) for p in papers}
        earlier, _ = load_checkpoint(checkpoint_path)
        carried_over = [
            p for p in earlier
            if p.get("paper_id", p.get("id")) not in batch_ids
        ]

    # Adjust workers based on rate limits
    effective_workers = min(max_workers, 5)  # Conservative start

    with ThreadPoolExecutor(max_workers=effective_workers) as executor:
        # Submit all tasks up front; map each future back to its input paper.
        future_to_paper = {
            executor.submit(extract_academic_fields, paper): paper
            for paper in papers
        }

        logger.info(f"Processing {len(papers)} papers with {effective_workers} concurrent workers...")

        progress = tqdm(as_completed(future_to_paper), total=len(papers),
                        desc="Extracting academic fields", unit="paper")
        for completed, future in enumerate(progress, start=1):
            try:
                # as_completed only yields finished futures, so this never blocks.
                enhanced_paper = future.result()
                enhanced_papers.append(enhanced_paper)

                if enhanced_paper.get("fields_of_study", []):
                    stats["successful"] += 1
                else:
                    stats["empty_fields"] += 1
                    stats["failed"] += 1

            except Exception as e:
                logger.error(f"Error processing paper: {str(e)[:200]}")
                stats["failed"] += 1
                # Add original paper with empty fields as fallback
                original_paper = future_to_paper[future].copy()
                original_paper["fields_of_study"] = []
                enhanced_papers.append(original_paper)

            # Save checkpoint periodically
            if checkpoint_path and completed % BATCH_SIZE == 0:
                save_checkpoint(carried_over + enhanced_papers, checkpoint_path)

    return enhanced_papers, stats

def main():
    """
    Run the fields-of-study enrichment end to end.

    Loads the input JSONL, skips papers whose IDs are already present in the
    checkpoint file, processes the rest concurrently, writes the combined
    result to the output JSONL, logs a summary and finally removes the
    checkpoint file. File-not-found, Ctrl-C and unexpected errors are logged
    rather than propagated.
    """

    # Define file paths
    input_file = Path("../Data_Cleaning/papers_final_aligned.jsonl")
    output_file = Path("papers_enhanced_fields.jsonl")
    checkpoint_file = Path("papers_enhanced_checkpoint.jsonl")

    start_time = time.time()

    try:
        # Load papers
        papers = load_papers(input_file)

        if not papers:
            logger.error("No papers loaded from input file")
            return

        # Load checkpoint if exists
        checkpoint_papers, processed_ids = load_checkpoint(checkpoint_file)

        if processed_ids:
            # Filter out already processed papers
            papers_to_process = [
                paper for paper in papers
                if paper.get("paper_id", paper.get("id")) not in processed_ids
            ]

            logger.info(f"Already processed {len(checkpoint_papers)} papers. "
                       f"Processing {len(papers_to_process)} remaining papers.")

            # Process remaining papers
            enhanced_new_papers, stats = process_paper_batch(
                papers_to_process,
                checkpoint_path=checkpoint_file
            )

            # Combine checkpoint and new results
            enhanced_papers = checkpoint_papers + enhanced_new_papers
        else:
            # Process all papers
            enhanced_papers, stats = process_paper_batch(
                papers,
                checkpoint_path=checkpoint_file
            )

        # Final save
        with open(output_file, 'w', encoding='utf-8') as f:
            for paper in enhanced_papers:
                f.write(json.dumps(paper, ensure_ascii=False) + '\n')

        # Print summary
        elapsed_time = time.time() - start_time
        logger.info("=" * 60)
        logger.info("PROCESSING COMPLETE!")
        logger.info(f"Total papers processed: {stats['total']}")
        logger.info(f"Successfully extracted fields: {stats['successful']}")
        logger.info(f"Papers with empty fields: {stats['empty_fields']}")
        logger.info(f"Failed extractions: {stats['failed']}")
        logger.info(f"Total API requests: {total_requests}")
        logger.info(f"Failed API requests: {failed_requests}")
        # No requests are made when every paper was already checkpointed;
        # guard the ratio so the summary (and the cleanup below) still runs.
        if total_requests:
            success_rate = (total_requests - failed_requests) / total_requests * 100
            logger.info(f"Success rate: {success_rate:.1f}%")
        else:
            logger.info("Success rate: n/a (no API requests were made)")
        logger.info(f"Total time: {elapsed_time:.2f} seconds")
        if elapsed_time > 0:
            logger.info(f"Processing rate: {stats['total'] / elapsed_time:.2f} papers/second")
        logger.info(f"Results saved to: {output_file}")
        logger.info("=" * 60)

        # Clean up checkpoint file
        if checkpoint_file.exists():
            checkpoint_file.unlink()
            logger.info(f"Checkpoint file removed: {checkpoint_file}")

    except FileNotFoundError as e:
        logger.error(f"File error: {e}")
    except KeyboardInterrupt:
        logger.info("Processing interrupted by user.")
        logger.info(f"Checkpoint saved to {checkpoint_file}. Resume by running again.")
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)

if __name__ == "__main__":
    main()

#### Generate keywords for data augmentation

In [None]:
# enhance_keywords.py
"""
Enhanced academic keywords extraction using DeepSeek API with concurrent processing.

This script processes academic papers from a JSONL file, calls the DeepSeek API
to extract relevant academic keywords for each paper, and saves the enhanced data
to a new JSONL file with progress tracking, error handling, and checkpointing.
"""

import json
import requests
import time
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import random
from datetime import datetime
import backoff
from threading import Lock
import re

# Logging setup: records are written to a UTF-8 log file and echoed to the console.
_keyword_log_handlers = [
    logging.FileHandler('keywords_extraction.log', encoding='utf-8'),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_keyword_log_handlers,
)
logger = logging.getLogger(__name__)

# API Configuration
# The API key is read from the DEEPSEEK_API_KEY environment variable so that
# no secret is stored in the notebook/source. NOTE: a key that was previously
# committed in this file must be treated as compromised and rotated.
import os

API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
API_URL = "https://api.deepseek.com/v1/chat/completions"
REQUEST_TIMEOUT = 30  # Seconds passed as the HTTP request timeout
MAX_WORKERS = 8  # Adjust based on API rate limits
BASE_DELAY = 0.5  # Base delay between requests
MAX_RETRIES = 5
BATCH_SIZE = 50  # Save progress every N papers

# Rate limiting
rate_limit_lock = Lock()  # Serializes the inter-request delay across threads
last_request_time = 0  # Timestamp (time.time()) of the most recent request
request_counter = 0  # Number of API attempts made so far
error_counter = 0  # Number of API attempts that ended in an error

def rate_limited_request():
    """Pause as needed so API requests start at least BASE_DELAY seconds apart.

    The wait is performed while holding ``rate_limit_lock`` and includes up to
    0.1 s of random jitter.
    """
    global last_request_time

    with rate_limit_lock:
        gap = time.time() - last_request_time
        if gap < BASE_DELAY:
            jitter = random.uniform(0, 0.1)
            time.sleep(BASE_DELAY - gap + jitter)
        last_request_time = time.time()

@backoff.on_exception(
    backoff.expo,
    # ValueError is what the response validation below raises (and is the base
    # class of json.JSONDecodeError); listing it makes malformed replies
    # retryable instead of failing on the first attempt.
    (requests.exceptions.RequestException, KeyError, json.JSONDecodeError, ValueError),
    max_tries=MAX_RETRIES,
    max_time=300,
    jitter=backoff.full_jitter
)
def call_deepseek_with_backoff(prompt: str, paper_id: str = "unknown") -> Optional[str]:
    """
    Call the DeepSeek chat-completions API and return the reply text.

    The call is rate limited via rate_limited_request() and retried with
    exponential backoff by the decorator (up to MAX_RETRIES attempts or 300 s).
    Every attempt increments ``request_counter``; attempts that end in an HTTP
    error or any non-timeout exception increment ``error_counter``.

    Args:
        prompt: The prompt to send to the API
        paper_id: Paper identifier for logging

    Returns:
        The stripped, non-empty response content. This function does not
        return None; failures are raised.

    Raises:
        requests.exceptions.RequestException: Network, timeout or HTTP errors
            that persist after all retries.
        ValueError: The response was structurally invalid or empty after all
            retries.
    """
    global request_counter, error_counter

    rate_limited_request()
    request_counter += 1

    try:
        response = requests.post(
            API_URL,
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": "deepseek-chat",
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.3,
                "max_tokens": 300
            },
            timeout=REQUEST_TIMEOUT
        )

        # Honor the server's rate-limit hint before the error is raised (and
        # retried by the decorator).
        if response.status_code == 429:
            try:
                retry_after = float(response.headers.get('Retry-After', 60))
            except (TypeError, ValueError):
                # Retry-After may also be an HTTP date; fall back to the default.
                retry_after = 60.0
            logger.warning(f"Rate limited for paper {paper_id}. Waiting {retry_after:g} seconds...")
            time.sleep(max(retry_after, 0.0))

        response.raise_for_status()

        data = response.json()

        # Validate response structure
        if not isinstance(data, dict):
            raise ValueError(f"Response is not a dictionary: {type(data)}")

        choices = data.get("choices", [])
        if not choices:
            raise ValueError(f"No choices in response: {data}")

        first_choice = choices[0]
        message = first_choice.get("message", {}) if isinstance(first_choice, dict) else {}
        # `or ""` covers a null content value, which has no .strip().
        content = (message.get("content") or "").strip() if isinstance(message, dict) else ""

        if not content:
            raise ValueError("Empty response content")

        # Log success periodically
        if request_counter % 100 == 0:
            logger.info(f"Completed {request_counter} API requests")

        return content

    except requests.exceptions.Timeout:
        logger.warning(f"Timeout for paper {paper_id}")
        raise
    except requests.exceptions.HTTPError as e:
        error_counter += 1
        # Must compare against None: a requests.Response is falsy for 4xx/5xx,
        # so a truthiness test would discard the status of every error
        # response and the numeric comparison below would fail on a string.
        status_code = e.response.status_code if e.response is not None else None
        shown_status = status_code if status_code is not None else "unknown"
        logger.error(f"HTTP error {shown_status} for paper {paper_id}")
        if status_code is not None and status_code >= 500:
            # Server error, use longer backoff
            time.sleep(10)
        raise
    except Exception as e:
        error_counter += 1
        logger.error(f"API error for paper {paper_id}: {str(e)[:100]}")
        raise

def call_deepseek(prompt: str, paper_id: str = "unknown") -> Optional[str]:
    """
    Best-effort API call that converts any failure into a None result.

    Delegates to call_deepseek_with_backoff(); any Exception escaping it is
    logged (message truncated to 100 characters) and swallowed.

    Args:
        prompt: The prompt to send to the API
        paper_id: Paper identifier for logging

    Returns:
        API response content as string, or None if the call ultimately failed
    """
    try:
        reply = call_deepseek_with_backoff(prompt, paper_id)
    except Exception as exc:
        logger.error(f"All retries failed for paper {paper_id}: {str(exc)[:100]}")
        reply = None
    return reply

def extract_keywords(paper: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract academic keywords from paper metadata using DeepSeek API.

    The title and abstract (truncated to 3000 characters) are placed in a
    prompt asking for a JSON array of 5-8 keywords. The reply is parsed as
    JSON; if that fails, the first ``[...]`` span is tried, and finally the
    reply is split into lines. Keywords are de-duplicated in first-seen order
    and capped at 8. Any error results in an empty keyword list rather than
    an exception.

    Args:
        paper: Dictionary containing paper metadata

    Returns:
        A shallow copy of ``paper`` with two added keys: 'keywords' (list of
        strings, possibly empty) and 'keywords_extraction_time' (ISO-8601
        local timestamp).
    """
    paper_id = paper.get("paper_id", paper.get("id", "unknown"))

    # `or ""` also covers records where the key exists with a null value,
    # which would otherwise fail on .strip().
    title = (paper.get("title") or "").strip()
    abstract = (paper.get("abstract") or "").strip()

    # Truncate if too long (API has token limits)
    if len(abstract) > 3000:
        abstract = abstract[:3000] + "..."

    prompt = f"""
Extract high-quality academic keywords from the following paper.

Rules:
- Output ONLY a JSON array of strings.
- 5 to 8 keywords.
- Keywords should describe tasks, methods, or research problems.
- Avoid generic words like "model", "method", "framework".
- Each keyword should be specific and meaningful.

Title:
{title}

Abstract:
{abstract}
"""
    keywords = []

    try:
        api_response = call_deepseek(prompt, paper_id)

        if api_response is None:
            logger.debug(f"No API response for paper {paper_id}")
        else:
            try:
                parsed = json.loads(api_response)
            except json.JSONDecodeError:
                # Fallback: extract keywords from text response
                logger.debug(f"Failed to parse JSON for paper {paper_id}, extracting from text")

                # Try to find JSON array pattern
                match = re.search(r'\[.*?\]', api_response, re.DOTALL)
                if match:
                    try:
                        embedded = json.loads(match.group(0))
                    except json.JSONDecodeError:
                        # Only a parse failure is expected here; anything else
                        # (including Ctrl-C) must not be swallowed.
                        embedded = None
                    if isinstance(embedded, list):
                        candidates = (str(k).strip() for k in embedded)
                        # dict.fromkeys de-duplicates in first-seen order, so
                        # the kept keywords are reproducible (a set is not).
                        keywords = list(dict.fromkeys(c for c in candidates if c))[:8]

                # If still no keywords, use text-based extraction
                if not keywords:
                    for line in api_response.strip().split('\n'):
                        # Remove quotes, brackets, etc.
                        line = re.sub(r'^[\s\[\]\'"]+|[\s\[\]\'\"]+$', '', line.strip())
                        if (line and
                            3 <= len(line) <= 50 and
                            line.lower() not in ['model', 'method', 'framework'] and
                            ',' not in line and ';' not in line):
                            keywords.append(line)

                    keywords = list(dict.fromkeys(keywords))[:8]
            else:
                if isinstance(parsed, list):
                    generic_terms = ['model', 'method', 'framework',
                                     'approach', 'system', 'algorithm']
                    # Clean and validate keywords
                    for keyword in parsed:
                        if not isinstance(keyword, str):
                            continue
                        cleaned = keyword.strip()
                        # Filter out generic keywords and validate length
                        if 3 <= len(cleaned) <= 50 and cleaned.lower() not in generic_terms:
                            keywords.append(cleaned)

                    # Deduplicate (order-preserving) and limit to max 8
                    keywords = list(dict.fromkeys(keywords))[:8]

                    # Fewer than the requested 5 is accepted but noted
                    if 0 < len(keywords) < 5:
                        logger.debug(f"Paper {paper_id} has only {len(keywords)} keywords")

    except Exception as e:
        logger.error(f"Unexpected error extracting keywords for paper {paper_id}: {str(e)[:100]}")

    # Create enhanced paper
    enhanced_paper = paper.copy()
    enhanced_paper["keywords"] = keywords
    enhanced_paper["keywords_extraction_time"] = datetime.now().isoformat()

    return enhanced_paper

def load_papers(input_path: Path) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list, one parsed JSON value per line.

    Lines that are not valid JSON are reported with a warning (including the
    1-based line number) and skipped. A missing input file is logged as an
    error and the FileNotFoundError is re-raised.
    """
    records = []
    try:
        with open(input_path, 'r', encoding='utf-8') as handle:
            for line_num, raw in enumerate(handle, start=1):
                try:
                    record = json.loads(raw.strip())
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON on line {line_num}: {e}")
                else:
                    records.append(record)

        logger.info(f"Loaded {len(records)} papers from {input_path}")
    except FileNotFoundError:
        logger.error(f"Input file not found: {input_path}")
        raise

    return records

def save_checkpoint(papers: List[Dict[str, Any]], checkpoint_path: Path):
    """Save intermediate results to checkpoint file (one JSON object per line).

    The data is written to a sibling ``<name>.tmp`` file first and then moved
    over the checkpoint, so an interruption during the write cannot leave a
    truncated checkpoint in place of the previous complete one. Failures are
    logged and not raised, so a checkpoint problem does not abort processing.
    """
    try:
        checkpoint_path = Path(checkpoint_path)
        tmp_path = checkpoint_path.with_name(checkpoint_path.name + '.tmp')

        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(paper, ensure_ascii=False) + '\n' for paper in papers)

        # Replace the old checkpoint only after the new one is fully written.
        tmp_path.replace(checkpoint_path)

        logger.debug(f"Checkpoint saved: {len(papers)} papers")
    except Exception as e:
        logger.error(f"Failed to save checkpoint: {e}")

def load_checkpoint(checkpoint_path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """Read back a checkpoint file and collect the IDs of the papers it holds.

    A paper's ID is its 'paper_id' value, falling back to 'id'; falsy IDs are
    not recorded. Lines that are not valid JSON are skipped. Any other error
    while reading is logged and whatever was collected so far is returned.
    Returns two empty collections when the file does not exist.
    """
    restored = []
    seen_ids = set()

    if not checkpoint_path.exists():
        return restored, seen_ids

    try:
        with open(checkpoint_path, 'r', encoding='utf-8') as handle:
            for raw in handle:
                try:
                    record = json.loads(raw.strip())
                except json.JSONDecodeError:
                    continue
                restored.append(record)
                record_id = record.get("paper_id", record.get("id"))
                if record_id:
                    seen_ids.add(record_id)

        logger.info(f"Loaded {len(restored)} papers from checkpoint")
    except Exception as e:
        logger.error(f"Failed to load checkpoint: {e}")

    return restored, seen_ids

def process_concurrently(papers: List[Dict[str, Any]], 
                        max_workers: int = MAX_WORKERS,
                        checkpoint_path: Optional[Path] = None) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """
    Process papers concurrently using ThreadPoolExecutor.

    Each paper is passed to extract_keywords() on a worker thread. Results are
    stored by input index, so the returned list follows the input order. A
    paper whose result has an empty 'keywords' list counts as both
    "empty_keywords" and "failed".

    When ``checkpoint_path`` is given, a checkpoint is written every
    BATCH_SIZE completed papers. Records already present in that file for
    papers that are NOT part of this batch are carried over into every save,
    so a resumed run extends the checkpoint instead of overwriting the
    results of the earlier run.

    Args:
        papers: List of papers to process
        max_workers: Maximum concurrent threads (capped at 10)
        checkpoint_path: Optional checkpoint file path

    Returns:
        Tuple of (enhanced papers for this batch, statistics with keys
        "total", "successful", "failed", "with_keywords", "empty_keywords")
    """
    stats = {
        "total": len(papers),
        "successful": 0,
        "failed": 0,
        "with_keywords": 0,
        "empty_keywords": 0
    }

    # Earlier results stored in the checkpoint that this batch will not redo.
    carried_over = []
    if checkpoint_path:
        batch_ids = {p.get("paper_id", p.get("id")) for p in papers}
        earlier, _ = load_checkpoint(checkpoint_path)
        carried_over = [
            p for p in earlier
            if p.get("paper_id", p.get("id")) not in batch_ids
        ]

    # Adjust workers based on rate limits
    effective_workers = min(max_workers, 10)
    logger.info(f"Processing {len(papers)} papers with {effective_workers} workers")

    # One slot per input paper; filled as futures complete.
    results = [None] * len(papers)

    with ThreadPoolExecutor(max_workers=effective_workers) as executor:
        # Submit tasks, remembering each future's input position.
        future_to_index = {
            executor.submit(extract_keywords, paper): i
            for i, paper in enumerate(papers)
        }

        # Process with progress bar
        with tqdm(total=len(papers), desc="Extracting keywords", unit="paper") as pbar:
            for future in as_completed(future_to_index):
                idx = future_to_index[future]

                try:
                    # as_completed only yields finished futures, so this never blocks.
                    enhanced_paper = future.result()
                    results[idx] = enhanced_paper

                    if enhanced_paper.get("keywords", []):
                        stats["with_keywords"] += 1
                        stats["successful"] += 1
                    else:
                        stats["empty_keywords"] += 1
                        stats["failed"] += 1

                except Exception as e:
                    logger.error(f"Processing failed for paper at index {idx}: {str(e)[:100]}")
                    stats["failed"] += 1
                    # Add fallback
                    fallback_paper = papers[idx].copy()
                    fallback_paper["keywords"] = []
                    results[idx] = fallback_paper

                pbar.update(1)

                # Periodic checkpoint
                if checkpoint_path and pbar.n % BATCH_SIZE == 0:
                    completed = [r for r in results if r is not None]
                    save_checkpoint(carried_over + completed, checkpoint_path)

    # Filter out None results (shouldn't happen with proper error handling)
    enhanced_papers = [r for r in results if r is not None]

    return enhanced_papers, stats

def main():
    """
    Run the keyword-extraction pipeline end to end.

    Steps: load the input JSONL, reuse any results found in the checkpoint
    file, process the remaining papers with process_concurrently, write the
    combined results to the output JSONL (falling back to a timestamped
    backup file if that write fails), log a summary and finally delete the
    checkpoint file.

    On KeyboardInterrupt or any other exception the checkpoint file is left
    untouched so that a later run can resume from it.
    """

    # File paths
    input_file = Path("papers_final_aligned.jsonl")
    output_file = Path("papers_enhanced_keywords.jsonl")
    checkpoint_file = Path("keywords_extraction_checkpoint.jsonl")

    start_time = time.time()

    def paper_key(paper):
        """Return the identifier used to match a paper against the checkpoint."""
        return paper.get("paper_id", paper.get("id"))

    def write_jsonl(records, destination):
        """Write one JSON document per line to *destination* (UTF-8)."""
        with open(destination, 'w', encoding='utf-8') as f:
            for record in records:
                f.write(json.dumps(record, ensure_ascii=False) + '\n')

    try:
        # Load papers
        papers = load_papers(input_file)

        if not papers:
            logger.error("No papers loaded. Exiting.")
            return

        # Load checkpoint
        checkpoint_papers, processed_ids = load_checkpoint(checkpoint_file)

        # Determine which papers need processing
        if processed_ids:
            # Reuse only checkpoint records whose ID is tracked in
            # processed_ids. Records without a tracked ID cannot be matched
            # against the input, so they are processed again below; keeping
            # the checkpoint copy as well would duplicate them in the output.
            reusable_papers = [
                paper for paper in checkpoint_papers
                if isinstance(paper, dict) and paper_key(paper) in processed_ids
            ]
            papers_to_process = [
                paper for paper in papers
                if paper_key(paper) not in processed_ids
            ]

            logger.info(f"Resuming: {len(reusable_papers)} already processed, "
                       f"{len(papers_to_process)} remaining")

            # Process remaining papers
            enhanced_new_papers, stats = process_concurrently(
                papers_to_process,
                checkpoint_path=checkpoint_file
            )

            # Combine results
            enhanced_papers = reusable_papers + enhanced_new_papers
        else:
            # Process all papers
            enhanced_papers, stats = process_concurrently(
                papers,
                checkpoint_path=checkpoint_file
            )

        # Save final results
        try:
            write_jsonl(enhanced_papers, output_file)
            logger.info(f"Saved {len(enhanced_papers)} papers to {output_file}")
        except Exception as e:
            logger.error(f"Failed to save output: {e}")
            # Save to backup file
            backup_file = Path(f"{output_file.stem}_backup_{int(time.time())}.jsonl")
            write_jsonl(enhanced_papers, backup_file)
            logger.info(f"Saved backup to {backup_file}")

        # Print summary (the stats describe the papers processed in this run)
        elapsed_time = time.time() - start_time
        logger.info("=" * 60)
        logger.info("KEYWORDS EXTRACTION COMPLETE")
        logger.info("=" * 60)
        logger.info(f"Total papers: {stats['total']}")
        logger.info(f"Successfully processed: {stats['successful']}")
        logger.info(f"Papers with keywords: {stats['with_keywords']}")
        logger.info(f"Papers with empty keywords: {stats['empty_keywords']}")
        logger.info(f"Failed: {stats['failed']}")
        logger.info(f"Total API requests: {request_counter}")
        logger.info(f"API errors: {error_counter}")
        if request_counter > 0:
            logger.info(f"API success rate: {(request_counter - error_counter) / request_counter * 100:.1f}%")
        logger.info(f"Total time: {elapsed_time:.1f} seconds")
        # Guard against a zero elapsed time on coarse clocks
        if elapsed_time > 0:
            logger.info(f"Processing rate: {stats['total'] / elapsed_time:.2f} papers/sec")
        logger.info("=" * 60)

        # Clean up checkpoint
        if checkpoint_file.exists():
            try:
                checkpoint_file.unlink()
                logger.info(f"Removed checkpoint file: {checkpoint_file}")
            except Exception as e:
                logger.warning(f"Failed to remove checkpoint: {e}")

    except KeyboardInterrupt:
        logger.info("\nProcessing interrupted by user.")
        logger.info("Checkpoint saved. Resume by running the script again.")
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)

# Entry point: run main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

#### Score the article based on the title and abstract of the paper

This script (enhance_scoring.py) enhances academic paper datasets by adding automated quality assessments using the DeepSeek LLM API. It processes papers from a JSONL file, constructs structured prompts for AI evaluation, and appends quality scores to each paper before saving the enhanced data to a new file. The implementation includes basic API integration with rate limiting to prevent service overload.

Scoring System

The scoring model evaluates papers across four primary dimensions (0-10 integer scale) plus two comprehensive metrics:
• Novelty: Originality and innovative contribution

• Technical Depth: Methodological rigor and complexity

• Clarity: Presentation quality and logical structure

• Impact Potential: Field contribution and application value

• Overall Score: Weighted composite (float)

• Confidence: Model's assessment certainty (0-1 float)

The AI model generates scores by analyzing paper titles and abstracts, simulating peer-review assessment standards with structured JSON output for consistency.

In [None]:
# enhance_scoring.py
"""
Enhanced academic paper scoring using DeepSeek API with concurrent processing.

This script processes academic papers from a JSONL file, calls the DeepSeek API
to evaluate paper quality across multiple dimensions, and saves the enhanced data
to a new JSONL file with progress tracking, error handling, and checkpointing.
"""

import json
import requests
import time
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import random
from datetime import datetime
import backoff
from threading import Lock
import re

# Configure logging
# Records are sent both to 'paper_scoring.log' (UTF-8) and to the console stream.
# NOTE(review): logging.basicConfig() does nothing when the root logger already
# has handlers, so if an earlier script in the same interpreter session already
# called it, this file handler may never be installed — confirm whether that matters.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('paper_scoring.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# API Configuration
# SECURITY NOTE(review): the API key below is a credential committed in source.
# It should be loaded from the environment (or a secrets store) and the exposed
# key rotated.
API_KEY = "sk-79990d599cd74bc0a56f6ca2f200a621"
API_URL = "https://api.deepseek.com/v1/chat/completions"
# Passed as the `timeout` argument of requests.post (seconds).
REQUEST_TIMEOUT = 40  # Increased for scoring task
# Default thread count for process_concurrently, which additionally caps it at 8.
MAX_WORKERS = 6  # Conservative for scoring task
# Minimum spacing in seconds between requests, enforced by rate_limited_request.
BASE_DELAY = 0.8  # Increased delay for scoring
# Used as `max_tries` of the backoff decorator on call_deepseek_with_backoff.
MAX_RETRIES = 5
BATCH_SIZE = 30  # Save progress every N papers
# Sent as `max_tokens` in the chat-completion request body.
MAX_TOKENS = 150

# Rate limiting and tracking
# rate_limit_lock serializes access to last_request_time in rate_limited_request.
# request_counter / error_counter are updated in call_deepseek_with_backoff
# without holding the lock.
rate_limit_lock = Lock()
last_request_time = 0
request_counter = 0
error_counter = 0

# Score validation ranges
# Maps each score key to the (min, max) bounds that validate_scores clamps to.
SCORE_RANGES = {
    "novelty": (0, 10),
    "technical_depth": (0, 10),
    "clarity": (0, 10),
    "impact_potential": (0, 10),
    "overall_score": (0.0, 10.0),
    "confidence": (0.0, 1.0)
}

def rate_limited_request():
    """
    Block until at least BASE_DELAY seconds have passed since the last request.

    The wait happens while rate_limit_lock is held, so concurrent callers are
    released one at a time. When a wait is needed, up to 0.2 s of random
    jitter is added on top of the remaining delay.
    """
    global last_request_time

    with rate_limit_lock:
        since_last = time.time() - last_request_time
        if since_last < BASE_DELAY:
            remaining = BASE_DELAY - since_last
            time.sleep(remaining + random.uniform(0, 0.2))
        # Stamp the moment this caller is released, after any sleep
        last_request_time = time.time()

def validate_scores(scores: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate and normalize score values.

    Only keys listed in SCORE_RANGES are kept. "overall_score" and
    "confidence" are converted to float, every other score to a rounded int,
    and each value is clamped to its (min, max) range. A value that cannot be
    converted, or that is NaN/infinite, is replaced by 0 (0.0 for the float
    keys).

    If "overall_score" is absent but all four dimension scores are present it
    is derived as their mean rounded to one decimal. A default confidence of
    0.5 is added only when at least one score was found, so that an input
    without any usable score yields an empty dict instead of looking like a
    partially scored paper.

    Args:
        scores: Raw score mapping, typically parsed from the model output

    Returns:
        Dictionary of normalized scores; empty if *scores* is not a dict or
        contains none of the expected keys
    """
    import math  # local import: only needed for the finiteness check

    if not isinstance(scores, dict):
        return {}

    float_keys = ("overall_score", "confidence")
    dimension_keys = ("novelty", "technical_depth", "clarity", "impact_potential")
    validated = {}

    for key, (min_val, max_val) in SCORE_RANGES.items():
        if key not in scores:
            continue

        is_float_key = key in float_keys
        try:
            value = float(scores[key])
            # NaN would silently clamp to the maximum and infinity would
            # overflow int(); treat both as invalid input instead.
            if not math.isfinite(value):
                raise ValueError(f"non-finite value: {value}")
            if not is_float_key:
                value = int(round(value))
            # Clamp to valid range
            validated[key] = max(min_val, min(max_val, value))
        except (ValueError, TypeError, OverflowError):
            logger.debug(f"Invalid value for {key}: {scores[key]}")
            # Set default values
            validated[key] = 0.0 if is_float_key else 0

    # Calculate overall score if not provided but other scores exist
    if "overall_score" not in validated and all(k in validated for k in dimension_keys):
        validated["overall_score"] = round(
            sum(validated[k] for k in dimension_keys) / 4.0, 1
        )

    # Add confidence if missing, but never fabricate it for an empty result
    if validated and "confidence" not in validated:
        validated["confidence"] = 0.5

    return validated

@backoff.on_exception(
    backoff.expo,
    (requests.exceptions.RequestException, KeyError, json.JSONDecodeError),
    max_tries=MAX_RETRIES,
    max_time=300,
    jitter=backoff.full_jitter
)
def call_deepseek_with_backoff(prompt: str, paper_id: str = "unknown") -> Optional[str]:
    """
    Call DeepSeek API with exponential backoff for scoring task.

    Each attempt first waits in rate_limited_request() and increments
    request_counter. Exceptions listed in the decorator (request errors,
    KeyError, JSON decode errors) are retried; anything else, including the
    ValueError raised for a structurally invalid or empty response, is
    re-raised to the caller immediately.

    Args:
        prompt: The prompt to send to the API
        paper_id: Paper identifier for logging

    Returns:
        The stripped, non-empty message content of the first choice

    Raises:
        requests.exceptions.RequestException: Network/HTTP failure after the
            retries are exhausted
        ValueError: The response body has no usable content
    """
    global request_counter, error_counter

    rate_limited_request()
    request_counter += 1

    try:
        response = requests.post(
            API_URL,
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": "deepseek-chat",
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.1,
                "max_tokens": MAX_TOKENS
            },
            timeout=REQUEST_TIMEOUT
        )

        # Handle rate limiting: honour Retry-After, then let raise_for_status()
        # below raise the HTTPError that triggers a backoff retry.
        if response.status_code == 429:
            try:
                retry_after = max(0, int(response.headers.get('Retry-After', 90)))
            except (TypeError, ValueError):
                # Retry-After may also be an HTTP date; fall back to the default
                retry_after = 90
            logger.warning(f"Rate limited for paper {paper_id}. Waiting {retry_after} seconds...")
            time.sleep(retry_after)

        response.raise_for_status()

        data = response.json()

        # Validate response structure
        if not isinstance(data, dict):
            raise ValueError(f"Response is not a dictionary: {type(data)}")

        choices = data.get("choices", [])
        if not choices:
            raise ValueError(f"No choices in response: {data}")

        message = choices[0].get("message", {})
        # `or ""` covers a present-but-null content field
        content = (message.get("content") or "").strip()

        if not content:
            raise ValueError("Empty response content")

        # Log progress periodically
        if request_counter % 50 == 0:
            logger.info(f"Completed {request_counter} scoring requests")

        return content

    except requests.exceptions.Timeout:
        logger.warning(f"Timeout for paper {paper_id}")
        raise
    except requests.exceptions.HTTPError as e:
        error_counter += 1
        # A requests.Response is falsy for 4xx/5xx statuses, so it must be
        # compared against None rather than tested for truthiness.
        status_code = e.response.status_code if e.response is not None else None
        status_label = status_code if status_code is not None else "unknown"
        logger.error(f"HTTP error {status_label} for paper {paper_id}")
        if status_code is not None and status_code >= 500:
            time.sleep(15)  # Longer sleep for server errors
        raise
    except Exception as e:
        error_counter += 1
        logger.error(f"API error for paper {paper_id}: {str(e)[:100]}")
        raise

def call_deepseek(prompt: str, paper_id: str = "unknown") -> Optional[str]:
    """
    Safe front door for the API call: never raises, returns None on failure.

    Args:
        prompt: The prompt to send to the API
        paper_id: Paper identifier for logging

    Returns:
        The content returned by call_deepseek_with_backoff, or None when that
        call ended with an exception (the failure is logged)
    """
    try:
        content = call_deepseek_with_backoff(prompt, paper_id)
    except Exception as exc:
        logger.error(f"All retries failed for paper {paper_id}: {str(exc)[:100]}")
        return None
    return content

def _parse_scores_from_response(api_response: str, paper_id: Any) -> Dict[str, Any]:
    """Turn raw model output into a score dict, trying progressively looser parsers."""
    # Clean response - remove markdown code fences (with or without a language tag)
    clean_response = api_response.strip()
    clean_response = re.sub(r'^```(?:json)?\s*', '', clean_response, flags=re.IGNORECASE)
    clean_response = re.sub(r'\s*```$', '', clean_response)

    # 1) The whole response is valid JSON
    try:
        parsed = json.loads(clean_response)
    except json.JSONDecodeError as e:
        logger.debug(f"Failed to parse JSON for paper {paper_id}: {str(e)[:100]}")
    else:
        if isinstance(parsed, dict):
            return validate_scores(parsed)
        logger.warning(f"Response is not a dictionary for paper {paper_id}")
        return {}

    # 2) A JSON object embedded in surrounding text
    json_patterns = (
        r'\{[^{}]*\}',  # Simple object
        r'\{.*\}',      # Nested object (greedy)
    )
    for pattern in json_patterns:
        for match in re.findall(pattern, clean_response, re.DOTALL):
            try:
                candidate = json.loads(match)
                if isinstance(candidate, dict) and any(k in candidate for k in SCORE_RANGES):
                    scores = validate_scores(candidate)
                    if scores:
                        return scores
            except (ValueError, TypeError, ArithmeticError):
                # Not usable JSON/scores; try the next candidate
                continue

    # 3) Loose "key: number" pairs anywhere in the text
    logger.debug(f"Using text extraction fallback for paper {paper_id}")
    scores = {}
    for key in SCORE_RANGES:
        match = re.search(rf'"{key}"\s*:\s*([0-9]+(?:\.[0-9]+)?)', clean_response, re.IGNORECASE)
        if not match:
            # Try without quotes
            match = re.search(rf'{key}\s*:\s*([0-9]+(?:\.[0-9]+)?)', clean_response, re.IGNORECASE)
        if not match:
            continue
        try:
            value = float(match.group(1))
            if key in ("overall_score", "confidence"):
                scores[key] = value
            else:
                scores[key] = int(round(value))
        except (ValueError, OverflowError):
            continue
    return scores


def extract_paper_scores(paper: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract quality scores from paper metadata using DeepSeek API.

    Builds the review prompt from the paper's title and abstract (abstract
    truncated to 2500 characters), sends it through call_deepseek and parses
    the answer. Any failure is logged and results in whatever
    validate_scores returns for an empty score set; this function does not
    raise for API or parsing problems.

    Args:
        paper: Dictionary containing paper metadata

    Returns:
        Copy of the paper with "quality_scores" (validated score dict) and
        "scoring_time" (ISO timestamp) added
    """
    paper_id = paper.get("paper_id", paper.get("id", "unknown"))

    # Prepare content; `or ""` also covers fields that are present but null
    title = str(paper.get("title") or "").strip()
    abstract = str(paper.get("abstract") or "").strip()

    # Truncate if too long
    if len(abstract) > 2500:
        abstract = abstract[:2500] + "..."

    prompt = f"""
You are a senior conference reviewer with expertise in academic paper evaluation.

Evaluate the following paper strictly and objectively.

Score each dimension from 0 to 10 (integer only for novelty, technical_depth, clarity, impact_potential).
Overall score should be a float between 0.0 and 10.0 with one decimal place.
Confidence should be a float between 0.0 and 1.0 with two decimal places.

IMPORTANT: Return ONLY valid JSON in the following format, no additional text:
{{
  "novelty": int,
  "technical_depth": int,
  "clarity": int,
  "impact_potential": int,
  "overall_score": float,
  "confidence": float
}}

Evaluation Guidelines:
1. Novelty: Originality and new contributions (0=no novelty, 10=highly novel)
2. Technical Depth: Sophistication of methods and analysis (0=superficial, 10=very deep)
3. Clarity: Presentation quality and writing (0=unclear, 10=very clear)
4. Impact Potential: Potential influence on field (0=no impact, 10=high impact)
5. Overall Score: Weighted average considering all factors
6. Confidence: Your confidence in the evaluation (0.0=low, 1.0=high)

Title:
{title}

Abstract:
{abstract}
"""
    scores = {}

    try:
        api_response = call_deepseek(prompt, paper_id)

        if api_response is None:
            logger.debug(f"No API response for paper {paper_id}")
        else:
            scores = _parse_scores_from_response(api_response, paper_id)

    except Exception as e:
        logger.error(f"Unexpected error scoring paper {paper_id}: {str(e)[:100]}")

    # Validate and add default values for missing scores
    scores = validate_scores(scores)

    # Create enhanced paper
    enhanced_paper = paper.copy()
    enhanced_paper["quality_scores"] = scores
    enhanced_paper["scoring_time"] = datetime.now().isoformat()

    return enhanced_paper

def load_papers(input_path: Path) -> List[Dict[str, Any]]:
    """
    Load papers from JSONL file.

    Blank lines are skipped silently. Lines that are not valid JSON, or that
    decode to something other than a JSON object, are skipped with a warning
    because later stages access every paper as a dictionary.

    Args:
        input_path: Path of the JSONL file to read

    Returns:
        List of paper dictionaries in file order

    Raises:
        FileNotFoundError: If *input_path* does not exist (logged, then re-raised)
    """
    papers = []
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON on line {line_num}: {e}")
                    continue
                if not isinstance(record, dict):
                    logger.warning(f"Skipping non-object JSON on line {line_num}")
                    continue
                papers.append(record)

        logger.info(f"Loaded {len(papers)} papers from {input_path}")
    except FileNotFoundError:
        logger.error(f"Input file not found: {input_path}")
        raise

    return papers

def save_checkpoint(papers: List[Dict[str, Any]], checkpoint_path: Path):
    """
    Save intermediate results to checkpoint file.

    The data is first written to a sibling "<name>.tmp" file and then moved
    over the checkpoint, so a failure while serializing or writing leaves the
    previous checkpoint intact instead of truncated. Errors are logged and
    swallowed: checkpointing is best-effort and must not stop processing.

    Args:
        papers: Results to persist, one JSON document per line
        checkpoint_path: Destination checkpoint file
    """
    target = Path(checkpoint_path)
    tmp_path = target.with_name(target.name + ".tmp")

    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            for paper in papers:
                f.write(json.dumps(paper, ensure_ascii=False) + '\n')

        # Swap the complete file into place in a single step
        tmp_path.replace(target)

        logger.debug(f"Checkpoint saved: {len(papers)} papers")
    except Exception as e:
        logger.error(f"Failed to save checkpoint: {e}")
        # Best-effort removal of the partial temp file; it may not exist
        try:
            tmp_path.unlink()
        except OSError:
            pass

def load_checkpoint(checkpoint_path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """
    Load papers from checkpoint file and get processed paper IDs.

    Blank lines, lines that are not valid JSON and lines that do not decode
    to a JSON object are skipped individually, so one bad line does not
    discard the rest of the checkpoint. A paper's ID is its "paper_id" value,
    falling back to "id"; IDs that are None, empty strings or unhashable are
    not recorded.

    Args:
        checkpoint_path: Checkpoint file to read

    Returns:
        Tuple of (papers found in the checkpoint, set of their recorded IDs);
        both empty when the file does not exist
    """
    if not checkpoint_path.exists():
        return [], set()

    papers = []
    processed_ids = set()

    try:
        with open(checkpoint_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    paper = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(paper, dict):
                    continue

                papers.append(paper)
                paper_id = paper.get("paper_id", paper.get("id"))
                # Compare explicitly so that a legitimate ID of 0 is kept
                if paper_id is None or paper_id == "":
                    continue
                try:
                    processed_ids.add(paper_id)
                except TypeError:
                    # Unhashable ID (e.g. a list): cannot be tracked
                    continue

        logger.info(f"Loaded {len(papers)} papers from checkpoint")
    except Exception as e:
        logger.error(f"Failed to load checkpoint: {e}")

    return papers, processed_ids

def process_concurrently(papers: List[Dict[str, Any]], 
                        max_workers: int = MAX_WORKERS,
                        checkpoint_path: Optional[Path] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Process papers concurrently using ThreadPoolExecutor.

    Every paper is submitted to extract_paper_scores. Results are stored by
    input index, so the returned list keeps the input order even though
    tasks finish out of order. After every BATCH_SIZE completed papers the
    results collected so far are written to *checkpoint_path* (if given).

    Statistics: a paper counts as "no_scores"/"failed" when its
    "quality_scores" contain no actual score (a lone "confidence" value does
    not count, since it carries no assessment of the paper). Otherwise it is
    "complete_scores" with at least 4 of the 6 expected keys present and
    "partial_scores" with fewer; both count as "successful".

    Args:
        papers: List of papers to process
        max_workers: Maximum concurrent threads (capped at 8, at least 1)
        checkpoint_path: Optional checkpoint file path

    Returns:
        Tuple of (enhanced papers, statistics)
    """
    stats = {
        "total": len(papers),
        "successful": 0,
        "failed": 0,
        "complete_scores": 0,
        "partial_scores": 0,
        "no_scores": 0
    }
    required_keys = ["novelty", "technical_depth", "clarity", 
                    "impact_potential", "overall_score", "confidence"]

    # Adjust workers for scoring task; ThreadPoolExecutor rejects values below 1
    effective_workers = max(1, min(max_workers, 8))
    logger.info(f"Scoring {len(papers)} papers with {effective_workers} workers")

    # Track results by input index for proper ordering
    results = [None] * len(papers)

    with ThreadPoolExecutor(max_workers=effective_workers) as executor:
        # Submit tasks
        future_to_index = {
            executor.submit(extract_paper_scores, paper): i 
            for i, paper in enumerate(papers)
        }

        # Process with progress bar
        with tqdm(total=len(papers), desc="Evaluating papers", unit="paper") as pbar:
            for future in as_completed(future_to_index):
                idx = future_to_index[future]

                try:
                    enhanced_paper = future.result(timeout=REQUEST_TIMEOUT + 45)
                    results[idx] = enhanced_paper

                    scores = enhanced_paper.get("quality_scores", {}) or {}
                    present_keys = [k for k in required_keys if k in scores]
                    has_real_score = any(k != "confidence" for k in present_keys)

                    if not has_real_score:
                        stats["no_scores"] += 1
                        stats["failed"] += 1
                    elif len(present_keys) >= 4:  # At least 4 of 6 scores
                        stats["complete_scores"] += 1
                        stats["successful"] += 1
                    else:
                        stats["partial_scores"] += 1
                        stats["successful"] += 1

                except Exception as e:
                    logger.error(f"Processing failed for paper at index {idx}: {str(e)[:100]}")
                    stats["failed"] += 1
                    # Add fallback so the paper is still present in the output
                    fallback_paper = papers[idx].copy()
                    fallback_paper["quality_scores"] = {}
                    results[idx] = fallback_paper

                pbar.update(1)

                # Periodic checkpoint
                if checkpoint_path and pbar.n % BATCH_SIZE == 0:
                    # Collect completed results
                    completed = [r for r in results if r is not None]
                    save_checkpoint(completed, checkpoint_path)

                    # Update progress in log
                    logger.info(f"Progress: {pbar.n}/{len(papers)} papers scored "
                               f"({pbar.n/len(papers)*100:.1f}%)")

    # Filter out None results
    enhanced_papers = [r for r in results if r is not None]

    return enhanced_papers, stats

def calculate_statistics(papers: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Summarize the "quality_scores" attached to a list of papers.

    Papers whose "quality_scores" entry is missing or empty are ignored. For
    every key of SCORE_RANGES that occurs at least once, the mean, minimum,
    maximum and number of values are reported. For the four integer
    dimensions a histogram over the values 0..10 is added.

    Returns:
        {"total_papers_with_scores": 0} when no paper has scores, otherwise a
        dict with "total_papers_with_scores", "average_scores" and
        "score_distributions"
    """
    all_entries = (paper.get("quality_scores", {}) for paper in papers)
    scored = [entry for entry in all_entries if entry]

    if not scored:
        return {"total_papers_with_scores": 0}

    def collect(metric):
        # Values of one metric across all papers that report it
        return [entry[metric] for entry in scored if metric in entry]

    # Calculate averages
    averages = {}
    for metric in SCORE_RANGES:
        observed = collect(metric)
        if not observed:
            continue
        averages[metric] = {
            "mean": sum(observed) / len(observed),
            "min": min(observed),
            "max": max(observed),
            "count": len(observed)
        }

    # Calculate distributions for integer scores
    distributions = {}
    for metric in ("novelty", "technical_depth", "clarity", "impact_potential"):
        observed = collect(metric)
        if observed:
            distributions[metric] = {
                score: observed.count(score) for score in range(11)
            }

    return {
        "total_papers_with_scores": len(scored),
        "average_scores": averages,
        "score_distributions": distributions
    }

def main():
    """
    Run the paper-scoring pipeline end to end.

    Steps: load the input JSONL, reuse any results found in the checkpoint
    file, score the remaining papers with process_concurrently, write the
    combined results to the output JSONL (falling back to a timestamped
    backup file if that write fails), write aggregate score statistics to a
    JSON file, log a summary and finally delete the checkpoint file.

    On KeyboardInterrupt or any other exception the checkpoint file is left
    untouched so that a later run can resume from it.
    """

    # File paths
    input_file = Path("papers_final_aligned.jsonl")
    output_file = Path("papers_enhanced_scores.jsonl")
    checkpoint_file = Path("paper_scoring_checkpoint.jsonl")
    stats_file = Path("scoring_statistics.json")

    start_time = time.time()

    def paper_key(paper):
        """Return the identifier used to match a paper against the checkpoint."""
        return paper.get("paper_id", paper.get("id"))

    def write_jsonl(records, destination):
        """Write one JSON document per line to *destination* (UTF-8)."""
        with open(destination, 'w', encoding='utf-8') as f:
            for record in records:
                f.write(json.dumps(record, ensure_ascii=False) + '\n')

    try:
        # Load papers
        papers = load_papers(input_file)

        if not papers:
            logger.error("No papers loaded. Exiting.")
            return

        # Load checkpoint
        checkpoint_papers, processed_ids = load_checkpoint(checkpoint_file)

        # Determine which papers need processing
        if processed_ids:
            # Reuse only checkpoint records whose ID is tracked in
            # processed_ids. Records without a tracked ID cannot be matched
            # against the input, so they are scored again below; keeping the
            # checkpoint copy as well would duplicate them in the output.
            reusable_papers = [
                paper for paper in checkpoint_papers
                if isinstance(paper, dict) and paper_key(paper) in processed_ids
            ]
            papers_to_process = [
                paper for paper in papers
                if paper_key(paper) not in processed_ids
            ]

            logger.info(f"Resuming: {len(reusable_papers)} already processed, "
                       f"{len(papers_to_process)} remaining")

            # Process remaining papers
            enhanced_new_papers, stats = process_concurrently(
                papers_to_process,
                checkpoint_path=checkpoint_file
            )

            # Combine results
            enhanced_papers = reusable_papers + enhanced_new_papers
        else:
            # Process all papers
            enhanced_papers, stats = process_concurrently(
                papers,
                checkpoint_path=checkpoint_file
            )

        # Save final results
        try:
            write_jsonl(enhanced_papers, output_file)
            logger.info(f"Saved {len(enhanced_papers)} papers to {output_file}")
        except Exception as e:
            logger.error(f"Failed to save output: {e}")
            # Save to backup file
            backup_file = Path(f"{output_file.stem}_backup_{int(time.time())}.jsonl")
            write_jsonl(enhanced_papers, backup_file)
            logger.info(f"Saved backup to {backup_file}")

        # Calculate and save statistics
        scoring_stats = calculate_statistics(enhanced_papers)
        try:
            with open(stats_file, 'w', encoding='utf-8') as f:
                json.dump(scoring_stats, f, indent=2, ensure_ascii=False)
            logger.info(f"Saved scoring statistics to {stats_file}")
        except Exception as e:
            logger.error(f"Failed to save statistics: {e}")

        # Print summary (the stats describe the papers processed in this run)
        elapsed_time = time.time() - start_time
        hours, remainder = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(remainder, 60)

        logger.info("=" * 70)
        logger.info("PAPER SCORING COMPLETE")
        logger.info("=" * 70)
        logger.info(f"Total papers: {stats['total']}")
        logger.info(f"Successfully processed: {stats['successful']}")
        logger.info(f"Complete scores (4+ dimensions): {stats['complete_scores']}")
        logger.info(f"Partial scores (1-3 dimensions): {stats['partial_scores']}")
        logger.info(f"No scores: {stats['no_scores']}")
        logger.info(f"Failed: {stats['failed']}")
        logger.info(f"Total API requests: {request_counter}")
        logger.info(f"API errors: {error_counter}")
        if request_counter > 0:
            success_rate = (request_counter - error_counter) / request_counter * 100
            logger.info(f"API success rate: {success_rate:.1f}%")
        logger.info(f"Total time: {int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}")
        # Guard against a zero elapsed time on coarse clocks
        if elapsed_time > 0:
            logger.info(f"Processing rate: {stats['total'] / elapsed_time:.2f} papers/sec")

        # Show score statistics if available
        if scoring_stats["total_papers_with_scores"] > 0:
            logger.info("-" * 70)
            logger.info("SCORING STATISTICS:")
            for key, val in scoring_stats["average_scores"].items():
                logger.info(f"  {key:20s}: {val['mean']:.2f} (min: {val['min']}, max: {val['max']})")

        logger.info("=" * 70)

        # Clean up checkpoint
        if checkpoint_file.exists():
            try:
                checkpoint_file.unlink()
                logger.info(f"Removed checkpoint file: {checkpoint_file}")
            except Exception as e:
                logger.warning(f"Failed to remove checkpoint: {e}")

    except KeyboardInterrupt:
        logger.info("\nProcessing interrupted by user.")
        logger.info("Checkpoint saved. Resume by running the script again.")
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)

# Entry point: run main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

#### Generates Structured Research Contribution Summaries
This script (`enhance_optional_summary.py`) processes papers from a JSONL file, extracts the title and abstract, and creates a formatted summary containing four key components: problem statement, methodology, key contributions, and potential application scenarios. The enhanced data is saved to a new JSONL file with structured summaries appended to each paper record.

The script processes papers concurrently with a small thread pool (`MAX_WORKERS = 5`) and enforces a minimum delay of about 0.7 seconds (`BASE_DELAY`) between API calls to prevent rate limiting. It uses a structured prompt template to ensure consistent JSON output format from the LLM, with temperature set to 0.25 for balanced creativity and consistency. The output includes four structured fields: problem description, methodological approach, a list of 3-5 key contributions, and potential application scenarios, all extracted automatically from paper metadata.

In [None]:
# enhance_summary.py
"""
Enhanced paper contribution summarization using DeepSeek API with concurrent processing.

This script processes academic papers from a JSONL file, calls the DeepSeek API
to extract structured research contributions for each paper, and saves the enhanced data
to a new JSONL file with progress tracking, error handling, and checkpointing.
"""

import json
import requests
import time
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import random
from datetime import datetime
import backoff
from threading import Lock
import re

# Configure logging
# Records are sent both to 'contribution_extraction.log' (UTF-8) and to the console stream.
# NOTE(review): logging.basicConfig() does nothing when the root logger already
# has handlers, so if an earlier script in the same interpreter session already
# called it, this file handler may never be installed — confirm whether that matters.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('contribution_extraction.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# API Configuration
# SECURITY NOTE(review): the API key below is a credential committed in source.
# It should be loaded from the environment (or a secrets store) and the exposed
# key rotated.
API_KEY = "sk-79990d599cd74bc0a56f6ca2f200a621"
API_URL = "https://api.deepseek.com/v1/chat/completions"
# Passed as the `timeout` argument of requests.post (seconds).
REQUEST_TIMEOUT = 45  # Increased for summarization task
MAX_WORKERS = 5  # Conservative for summarization
# Minimum spacing in seconds between requests, enforced by rate_limited_request.
BASE_DELAY = 0.7  # Base delay between requests
# Used as `max_tries` of the backoff decorator on call_deepseek_with_backoff.
MAX_RETRIES = 5
BATCH_SIZE = 25  # Save progress every N papers
# Sent as `max_tokens` in the chat-completion request body.
MAX_TOKENS = 400  # Increased for longer summaries

# Rate limiting and tracking
# rate_limit_lock serializes access to last_request_time in rate_limited_request.
rate_limit_lock = Lock()
last_request_time = 0
request_counter = 0
error_counter = 0

# Expected contribution summary structure
# Maps each expected field name to the Python type of its value.
# NOTE(review): presumably used to validate the parsed model output — the
# consumer is not shown here; confirm.
CONTRIBUTION_STRUCTURE = {
    "problem": str,
    "method": str,
    "key_contributions": list,
    "application_scenarios": list
}

def rate_limited_request():
    """
    Block until at least BASE_DELAY seconds have passed since the previous call.

    The wait happens while holding rate_limit_lock, so concurrent callers are
    released one at a time. A small random jitter (0-0.15 s) is added to any
    wait so threads do not fire in lockstep.
    """
    global last_request_time

    with rate_limit_lock:
        since_last = time.time() - last_request_time
        if since_last < BASE_DELAY:
            # Sleep off the remainder of the window plus jitter
            time.sleep(BASE_DELAY - since_last + random.uniform(0, 0.15))
        # Stamp after sleeping so the next caller measures from the release time
        last_request_time = time.time()

@backoff.on_exception(
    backoff.expo,
    (requests.exceptions.RequestException, KeyError, json.JSONDecodeError),
    max_tries=MAX_RETRIES,
    max_time=300,
    jitter=backoff.full_jitter
)
def call_deepseek_with_backoff(prompt: str, paper_id: str = "unknown") -> Optional[str]:
    """
    Call DeepSeek API with exponential backoff for contribution extraction.

    The decorator retries the call (up to MAX_RETRIES attempts / 300 seconds)
    when a requests exception, KeyError or JSON decode error is raised.
    Structural problems in the response are raised as ValueError, which is
    not in the retry list and therefore propagates immediately.

    Args:
        prompt: The prompt to send to the API
        paper_id: Paper identifier for logging

    Returns:
        API response content as a non-empty string. Failures are signalled by
        exceptions rather than by returning None.

    Raises:
        requests.exceptions.RequestException: Network/HTTP failure after retries
        ValueError: Response body is missing choices or content
    """
    global request_counter, error_counter

    rate_limited_request()
    request_counter += 1

    try:
        response = requests.post(
            API_URL,
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": "deepseek-chat",
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.25,
                "max_tokens": MAX_TOKENS
            },
            timeout=REQUEST_TIMEOUT
        )

        # Handle rate limiting: honour Retry-After, then let raise_for_status
        # below turn the 429 into an HTTPError so the decorator retries.
        if response.status_code == 429:
            try:
                retry_after = int(response.headers.get('Retry-After', 75))
            except (TypeError, ValueError):
                # Retry-After may be an HTTP date rather than seconds
                retry_after = 75
            logger.warning(f"Rate limited for paper {paper_id}. Waiting {retry_after} seconds...")
            time.sleep(retry_after)

        response.raise_for_status()

        data = response.json()

        # Validate response structure
        if not isinstance(data, dict):
            raise ValueError(f"Response is not a dictionary: {type(data)}")

        choices = data.get("choices", [])
        if not choices:
            raise ValueError(f"No choices in response: {data}")

        # Tolerate explicit nulls for "message" / "content"
        message = choices[0].get("message") or {}
        content = (message.get("content") or "").strip()

        if not content:
            raise ValueError("Empty response content")

        # Log progress periodically
        if request_counter % 40 == 0:
            logger.info(f"Completed {request_counter} summarization requests")

        return content

    except requests.exceptions.Timeout:
        logger.warning(f"Timeout for paper {paper_id}")
        raise
    except requests.exceptions.HTTPError as e:
        error_counter += 1
        # A requests.Response is falsy for 4xx/5xx statuses, so compare against
        # None instead of relying on truthiness; otherwise the status is lost
        # and the numeric comparison below fails with TypeError.
        status_code = e.response.status_code if e.response is not None else None
        shown_status = status_code if status_code is not None else "unknown"
        logger.error(f"HTTP error {shown_status} for paper {paper_id}")
        if status_code is not None and status_code >= 500:
            time.sleep(12)  # Longer sleep for server errors
        raise
    except Exception as e:
        error_counter += 1
        logger.error(f"API error for paper {paper_id}: {str(e)[:100]}")
        raise

def call_deepseek(prompt: str, paper_id: str = "unknown") -> Optional[str]:
    """
    Safe front door for the API: never raises, returns None on failure.

    Args:
        prompt: The prompt to send to the API
        paper_id: Paper identifier for logging

    Returns:
        The response text from call_deepseek_with_backoff, or None when that
        call ultimately raised (the error is logged, truncated to 100 chars)
    """
    reply = None
    try:
        reply = call_deepseek_with_backoff(prompt, paper_id)
    except Exception as exc:
        logger.error(f"All retries failed for paper {paper_id}: {str(exc)[:100]}")
    return reply

def validate_contribution_summary(summary: Dict[str, Any]) -> Dict[str, Any]:
    """
    Coerce a raw summary into the shape described by CONTRIBUTION_STRUCTURE.

    Args:
        summary: Raw summary dictionary (anything else yields an empty dict)

    Returns:
        A dict holding every expected field: text fields as stripped strings
        (capped at 500 characters plus "..."), list fields as lists of
        non-empty stripped strings (capped at 10 items). Missing or
        wrongly-typed fields fall back to "" or [].
    """
    if not isinstance(summary, dict):
        return {}

    def _as_text(raw: Any) -> str:
        # None becomes "", strings are stripped, anything else is stringified
        if raw is None:
            return ""
        return raw.strip() if isinstance(raw, str) else str(raw).strip()

    def _as_items(raw: Any) -> list:
        # Non-lists are discarded; None and blank entries are dropped
        if not isinstance(raw, list):
            return []
        texts = (_as_text(entry) for entry in raw if entry is not None)
        return [text for text in texts if text]

    cleaned = {}
    for name, kind in CONTRIBUTION_STRUCTURE.items():
        if name not in summary:
            # Default for a missing field depends on its expected type
            cleaned[name] = "" if kind == str else []
        elif kind == str:
            cleaned[name] = _as_text(summary[name])
        elif kind == list:
            cleaned[name] = _as_items(summary[name])

    # Cap free-text fields
    for text_field in ("problem", "method"):
        text = cleaned.get(text_field)
        if text and len(text) > 500:
            cleaned[text_field] = text[:500] + "..."

    # Cap list fields at 10 items
    for list_field in ("key_contributions", "application_scenarios"):
        if list_field in cleaned:
            cleaned[list_field] = cleaned[list_field][:10]

    return cleaned

def extract_contribution_summary(paper: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract structured contribution summary from paper metadata using DeepSeek API.

    The model reply is parsed in three stages: direct JSON parsing (after
    stripping markdown code fences), regex extraction of an embedded JSON
    object, and finally keyword-based text extraction. Whatever is recovered
    is normalized with validate_contribution_summary.

    Args:
        paper: Dictionary containing paper metadata ("title", "abstract" and
            "paper_id"/"id" are read; null or missing values are tolerated)

    Returns:
        A shallow copy of the paper with two added keys:
        "contribution_summary" (validated summary dict, fields empty when
        extraction failed) and "summary_extraction_time" (ISO timestamp)
    """
    paper_id = paper.get("paper_id", paper.get("id", "unknown"))

    # Prepare content. `or ""` guards against explicit JSON nulls, which
    # would otherwise crash on .strip() before any error handling applies.
    title = str(paper.get("title") or "").strip()
    abstract = str(paper.get("abstract") or "").strip()

    # Truncate if too long
    if len(abstract) > 3500:
        abstract = abstract[:3500] + "..."

    prompt = f"""
You are an expert research analyst specializing in academic paper analysis.

Summarize the paper into structured research contributions based on the title and abstract.

IMPORTANT: Return ONLY a valid JSON object with the following keys:
- "problem": A concise description of the research problem addressed (1-2 sentences)
- "method": A brief overview of the methodology or approach used (1-2 sentences)
- "key_contributions": A list of 3-5 key contributions or innovations
- "application_scenarios": A list of 2-4 potential application scenarios

Guidelines:
- Be specific and technical, not generic
- Focus on what is novel or significant
- For contributions: use bullet-point style but return as JSON list
- For application scenarios: be concrete about where this research could be applied
- Keep each contribution and scenario item concise (1 sentence each)

Title:
{title}

Abstract:
{abstract}
"""
    summary = {}

    try:
        api_response = call_deepseek(prompt, paper_id)

        if api_response is None:
            logger.debug(f"No API response for paper {paper_id}")
        else:
            # Clean response - remove markdown code blocks and whitespace
            clean_response = api_response.strip()
            clean_response = re.sub(r'^```(?:json)?\s*', '', clean_response, flags=re.IGNORECASE)
            clean_response = re.sub(r'\s*```$', '', clean_response)

            # Try to parse as JSON
            try:
                parsed = json.loads(clean_response)
                summary = validate_contribution_summary(parsed)

            except json.JSONDecodeError as e:
                logger.debug(f"Failed to parse JSON for paper {paper_id}: {str(e)[:100]}")

                # Advanced JSON extraction
                json_patterns = [
                    r'\{[^{}]*\}',  # Simple object
                    r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}',  # Nested object
                ]

                for pattern in json_patterns:
                    matches = re.finditer(pattern, clean_response, re.DOTALL)
                    for match in matches:
                        try:
                            parsed = json.loads(match.group(0))
                            if isinstance(parsed, dict):
                                temp_summary = validate_contribution_summary(parsed)
                                # Check if we have at least some valid structure
                                if (temp_summary.get("problem") or
                                    temp_summary.get("method") or
                                    temp_summary.get("key_contributions")):
                                    summary = temp_summary
                                    break
                        except (ValueError, TypeError, RecursionError):
                            # Candidate is not usable JSON; try the next one.
                            # (Narrowed from a bare except so Ctrl-C still works.)
                            continue
                    if summary:
                        break

                # Fallback: extract structured information from text
                if not summary:
                    logger.debug(f"Using advanced text extraction for paper {paper_id}")
                    summary = extract_summary_from_text(clean_response)

    except Exception as e:
        logger.error(f"Unexpected error summarizing paper {paper_id}: {str(e)[:100]}")

    # Final validation (also fills in empty defaults when nothing was extracted)
    summary = validate_contribution_summary(summary)

    # Create enhanced paper
    enhanced_paper = paper.copy()
    enhanced_paper["contribution_summary"] = summary
    enhanced_paper["summary_extraction_time"] = datetime.now().isoformat()

    return enhanced_paper

def extract_summary_from_text(text: str) -> Dict[str, Any]:
    """
    Extract structured summary from text response when JSON parsing fails.

    Each non-blank line is scanned for "Problem"/"Method"-style keywords
    (the last matching line wins) and for "Contributions"/"Applications"
    headings, after which the following lines are collected as list items.
    Collection stops at the next heading line that ends with a colon, so
    items of one section are not attributed to the previous one.

    Args:
        text: Raw text response from API

    Returns:
        Structured summary dictionary with keys "problem", "method",
        "key_contributions" (at most 5) and "application_scenarios" (at most 4)
    """
    summary = {
        "problem": "",
        "method": "",
        "key_contributions": [],
        "application_scenarios": []
    }

    problem_patterns = [
        r'[Pp]roblem[:：]?\s*(.+)',
        r'[Rr]esearch\s+[Pp]roblem[:：]?\s*(.+)',
        r'[Tt]he\s+problem\s+is\s+(.+)',
    ]
    method_patterns = [
        r'[Mm]ethod(?:ology)?[:：]?\s*(.+)',
        r'[Aa]pproach[:：]?\s*(.+)',
        r'[Tt]echnique[:：]?\s*(.+)',
    ]

    # Letters/whitespace only (colon optional): looks like a heading, not an item
    heading_like = re.compile(r'^[A-Za-z\s]+[:：]?$')
    # Same, but the trailing colon is required: an unambiguous section heading
    section_heading = re.compile(r'^[A-Za-z\s]+[:：]$')
    # Leading bullets or "1." / "1)" markers. A digit is only stripped as part
    # of such a marker, so items like "3D ..." or "2x ..." stay intact.
    list_marker = re.compile(r'^(?:\d+[\.\)]\s+|[•\-\*\s])+')

    lines = text.split('\n')

    def first_capture(patterns, line):
        """Return the stripped capture of the first matching pattern, or None."""
        for pattern in patterns:
            match = re.search(pattern, line)
            if match:
                return match.group(1).strip()
        return None

    def collect_items(start, window):
        """Gather list items from up to `window` lines beginning at index `start`."""
        items = []
        for raw in lines[start:start + window]:
            candidate = raw.strip()
            if not candidate:
                continue
            if section_heading.match(candidate):
                # The next section begins here; its items are not ours
                break
            if heading_like.match(candidate):
                continue
            item = list_marker.sub('', candidate)
            if len(item) > 5:
                items.append(item)
        return items

    for i, raw_line in enumerate(lines):
        line = raw_line.strip()
        if not line:
            continue

        # Check for problem
        found = first_capture(problem_patterns, line)
        if found is not None:
            summary["problem"] = found

        # Check for method
        found = first_capture(method_patterns, line)
        if found is not None:
            summary["method"] = found

        # Look for contributions (items on the 9 lines after the heading)
        if re.match(r'^[Cc]ontributions?[:：]?', line) or re.match(r'^[Kk]ey\s+[Cc]ontributions?[:：]?', line):
            summary["key_contributions"].extend(collect_items(i + 1, 9))
            # Limit to 5 contributions
            summary["key_contributions"] = summary["key_contributions"][:5]

        # Look for applications (items on the 7 lines after the heading)
        if re.match(r'^[Aa]pplications?[:：]?', line) or re.match(r'^[Uu]ses?[:：]?', line):
            summary["application_scenarios"].extend(collect_items(i + 1, 7))
            # Limit to 4 applications
            summary["application_scenarios"] = summary["application_scenarios"][:4]

    return summary

def load_papers(input_path: Path) -> List[Dict[str, Any]]:
    """
    Read a JSONL file into a list, one parsed JSON value per line.

    Lines that are not valid JSON are logged with their 1-based line number
    and skipped. A missing input file is logged and the FileNotFoundError is
    re-raised to the caller.
    """
    loaded = []
    try:
        with open(input_path, 'r', encoding='utf-8') as source:
            for line_num, raw in enumerate(source, start=1):
                try:
                    record = json.loads(raw.strip())
                except json.JSONDecodeError as err:
                    logger.warning(f"Invalid JSON on line {line_num}: {err}")
                else:
                    loaded.append(record)

        logger.info(f"Loaded {len(loaded)} papers from {input_path}")
    except FileNotFoundError:
        logger.error(f"Input file not found: {input_path}")
        raise

    return loaded

def save_checkpoint(papers: List[Dict[str, Any]], checkpoint_path: Path):
    """
    Save intermediate results to checkpoint file.

    The data is first written to a sibling "<name>.tmp" file and then moved
    over the checkpoint, so an interruption in the middle of a write (for
    example Ctrl-C) cannot leave a truncated checkpoint behind. Failures are
    logged and swallowed: checkpointing is best-effort and must not abort
    processing.

    Args:
        papers: Enhanced papers to persist, one JSON document per line
        checkpoint_path: Destination checkpoint file
    """
    target = Path(checkpoint_path)
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            for paper in papers:
                f.write(json.dumps(paper, ensure_ascii=False) + '\n')

        # Swap the complete file into place; replaces any existing checkpoint
        tmp_path.replace(target)

        logger.debug(f"Checkpoint saved: {len(papers)} papers")
    except Exception as e:
        logger.error(f"Failed to save checkpoint: {e}")
        # Best-effort removal of a partial temp file; the previous checkpoint
        # (if any) is still intact at this point.
        try:
            tmp_path.unlink()
        except OSError:
            pass

def load_checkpoint(checkpoint_path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """
    Restore previously processed papers from a checkpoint file.

    Returns:
        (papers, processed_ids): the parsed checkpoint records and the set of
        their truthy "paper_id" (or, failing that, "id") values. Both are
        empty when the file does not exist. Unparseable lines are skipped;
        any other error is logged and whatever was read so far is returned.
    """
    restored = []
    seen_ids = set()

    if checkpoint_path.exists():
        try:
            with open(checkpoint_path, 'r', encoding='utf-8') as handle:
                for raw in handle:
                    try:
                        record = json.loads(raw.strip())
                    except json.JSONDecodeError:
                        continue
                    restored.append(record)
                    identifier = record.get("paper_id", record.get("id"))
                    if identifier:
                        seen_ids.add(identifier)

            logger.info(f"Loaded {len(restored)} papers from checkpoint")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")

    return restored, seen_ids

def process_concurrently(papers: List[Dict[str, Any]], 
                        max_workers: int = MAX_WORKERS,
                        checkpoint_path: Optional[Path] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Process papers concurrently using ThreadPoolExecutor.

    Every BATCH_SIZE completed papers the results so far are written to the
    checkpoint. Papers already stored in an existing checkpoint that are not
    part of this batch (the resume case) are carried over into each rewrite,
    so a second interruption does not lose the earlier progress.

    Args:
        papers: List of papers to process
        max_workers: Maximum concurrent threads (capped at 6)
        checkpoint_path: Optional checkpoint file path

    Returns:
        Tuple of (enhanced papers, statistics). The enhanced papers are in
        input order and contain only the papers passed in, never the
        carried-over checkpoint entries. A summary counts as "complete" with
        3+ filled fields, "partial" with 1-2, and "empty"/failed otherwise.
    """
    stats = {
        "total": len(papers),
        "successful": 0,
        "failed": 0,
        "complete_summaries": 0,
        "partial_summaries": 0,
        "empty_summaries": 0
    }
    summary_fields = ("problem", "method", "key_contributions", "application_scenarios")

    # Adjust workers for summarization task
    effective_workers = min(max_workers, 6)
    logger.info(f"Summarizing {len(papers)} papers with {effective_workers} workers")

    # Entries from an earlier run that must survive the periodic rewrites.
    # Only entries with an ID outside the current batch are kept; this mirrors
    # how main() decides which papers are already processed.
    preserved = []
    if checkpoint_path:
        batch_ids = set()
        for paper in papers:
            pid = paper.get("paper_id", paper.get("id"))
            if pid:
                batch_ids.add(pid)
        previous, _ = load_checkpoint(checkpoint_path)
        for entry in previous:
            if not isinstance(entry, dict):
                continue
            pid = entry.get("paper_id", entry.get("id"))
            if pid and pid not in batch_ids:
                preserved.append(entry)

    with ThreadPoolExecutor(max_workers=effective_workers) as executor:
        # Submit tasks
        future_to_index = {
            executor.submit(extract_contribution_summary, paper): i 
            for i, paper in enumerate(papers)
        }

        # Slots indexed by input position so output order matches input order
        results = [None] * len(papers)

        # Process with progress bar
        with tqdm(total=len(papers), desc="Extracting contributions", unit="paper") as pbar:
            for future in as_completed(future_to_index):
                idx = future_to_index[future]

                try:
                    enhanced_paper = future.result(timeout=REQUEST_TIMEOUT + 50)
                    results[idx] = enhanced_paper

                    summary = enhanced_paper.get("contribution_summary", {})

                    # Count how many of the four fields carry content
                    filled = sum(bool(summary.get(name)) for name in summary_fields) if summary else 0

                    if filled >= 3:
                        stats["complete_summaries"] += 1
                        stats["successful"] += 1
                    elif filled >= 1:
                        stats["partial_summaries"] += 1
                        stats["successful"] += 1
                    else:
                        stats["empty_summaries"] += 1
                        stats["failed"] += 1

                except Exception as e:
                    logger.error(f"Processing failed for paper at index {idx}: {str(e)[:100]}")
                    stats["failed"] += 1
                    # Add fallback so the paper is not dropped from the output
                    fallback_paper = papers[idx].copy()
                    fallback_paper["contribution_summary"] = {}
                    results[idx] = fallback_paper

                pbar.update(1)

                # Periodic checkpoint
                if checkpoint_path and pbar.n % BATCH_SIZE == 0:
                    # Earlier-run entries first, then everything finished so far
                    completed = [r for r in results if r is not None]
                    save_checkpoint(preserved + completed, checkpoint_path)

                    # Update progress in log
                    progress_pct = pbar.n / len(papers) * 100
                    logger.info(f"Progress: {pbar.n}/{len(papers)} papers summarized "
                               f"({progress_pct:.1f}%)")

    # Filter out None results
    enhanced_papers = [r for r in results if r is not None]

    return enhanced_papers, stats

def calculate_summary_statistics(papers: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Aggregate completeness, length and quality figures over all summaries.

    Only papers whose "contribution_summary" is truthy are analysed. The
    result always carries "total_papers" and "papers_with_summary"; the
    "field_completeness", "average_lengths" and "summary_quality" sections
    stay empty when no paper has a summary.
    """
    valid = [s for s in (p.get("contribution_summary", {}) for p in papers) if s]

    report = {
        "total_papers": len(papers),
        "papers_with_summary": len(valid),
        "field_completeness": {},
        "average_lengths": {},
        "summary_quality": {}
    }
    if not valid:
        return report

    def describe(values):
        # mean / min / max of a non-empty list of numbers
        return {
            "mean": sum(values) / len(values),
            "min": min(values),
            "max": max(values)
        }

    # Field completeness: how many summaries have a non-empty value per field
    n_valid = len(valid)
    for field in CONTRIBUTION_STRUCTURE.keys():
        filled = len([s for s in valid if s.get(field)])
        report["field_completeness"][field] = {
            "count": filled,
            "percentage": filled / n_valid * 100
        }

    # Character lengths of the text fields (only where present)
    for field in ("problem", "method"):
        text_lengths = [len(s.get(field, "")) for s in valid if s.get(field)]
        if text_lengths:
            report["average_lengths"][field] = describe(text_lengths)

    # Item counts of the list fields (every valid summary contributes)
    for list_field in ("key_contributions", "application_scenarios"):
        item_counts = [len(s.get(list_field, [])) for s in valid]
        report["average_lengths"][list_field] = describe(item_counts)

    # Quality score = number of non-empty fields (0-4) per summary
    scored_fields = ("problem", "method", "key_contributions", "application_scenarios")
    scores = [sum(1 for name in scored_fields if s.get(name)) for s in valid]
    bucket_labels = ("0_fields", "1_field", "2_fields", "3_fields", "4_fields")
    report["summary_quality"] = {
        "average_score": sum(scores) / len(scores),
        "distribution": {
            label: scores.count(points) for points, label in enumerate(bucket_labels)
        }
    }

    return report

def main():
    """
    Run the contribution-summarization pipeline end to end.

    Loads the input JSONL, resumes from a checkpoint when one with paper IDs
    exists, summarizes the remaining papers concurrently, writes the enhanced
    papers and a statistics JSON file, logs a run report and finally removes
    the checkpoint. Ctrl-C and unexpected errors are logged, not re-raised.
    """
    # File paths
    source_path = Path("papers_final_aligned.jsonl")
    result_path = Path("papers_enhanced_contributions.jsonl")
    resume_path = Path("contribution_summary_checkpoint.jsonl")
    report_path = Path("contribution_statistics.json")

    def write_jsonl(target, records):
        # One JSON document per line, non-ASCII characters kept as-is
        with open(target, 'w', encoding='utf-8') as sink:
            for record in records:
                sink.write(json.dumps(record, ensure_ascii=False) + '\n')

    info = logger.info
    started = time.time()

    try:
        # Load papers
        papers = load_papers(source_path)
        if not papers:
            logger.error("No papers loaded. Exiting.")
            return

        # Load checkpoint
        restored, done_ids = load_checkpoint(resume_path)

        if done_ids:
            # Resume: only papers whose ID is not in the checkpoint are processed
            pending = [
                paper for paper in papers
                if paper.get("paper_id", paper.get("id")) not in done_ids
            ]
            info(f"Resuming: {len(restored)} already processed, "
                 f"{len(pending)} remaining")

            fresh, stats = process_concurrently(pending, checkpoint_path=resume_path)
            enhanced_papers = restored + fresh
        else:
            # Fresh run: process everything
            enhanced_papers, stats = process_concurrently(papers, checkpoint_path=resume_path)

        # Save final results, falling back to a timestamped backup file
        try:
            write_jsonl(result_path, enhanced_papers)
            info(f"Saved {len(enhanced_papers)} papers to {result_path}")
        except Exception as e:
            logger.error(f"Failed to save output: {e}")
            backup_path = Path(f"{result_path.stem}_backup_{int(time.time())}.jsonl")
            write_jsonl(backup_path, enhanced_papers)
            info(f"Saved backup to {backup_path}")

        # Calculate and save statistics
        summary_stats = calculate_summary_statistics(enhanced_papers)
        try:
            with open(report_path, 'w', encoding='utf-8') as f:
                json.dump(summary_stats, f, indent=2, ensure_ascii=False)
            info(f"Saved contribution statistics to {report_path}")
        except Exception as e:
            logger.error(f"Failed to save statistics: {e}")

        # Run report
        elapsed_time = time.time() - started
        hours, remainder = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(remainder, 60)

        info("=" * 70)
        info("CONTRIBUTION SUMMARIZATION COMPLETE")
        info("=" * 70)
        info(f"Total papers: {stats['total']}")
        info(f"Successfully processed: {stats['successful']}")
        info(f"Complete summaries (3+ fields): {stats['complete_summaries']}")
        info(f"Partial summaries (1-2 fields): {stats['partial_summaries']}")
        info(f"Empty summaries: {stats['empty_summaries']}")
        info(f"Failed: {stats['failed']}")
        info(f"Total API requests: {request_counter}")
        info(f"API errors: {error_counter}")
        if request_counter > 0:
            success_rate = (request_counter - error_counter) / request_counter * 100
            info(f"API success rate: {success_rate:.1f}%")
        info(f"Total time: {int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}")
        info(f"Processing rate: {stats['total'] / elapsed_time:.2f} papers/sec")

        # Summary statistics, only when at least one paper has a summary
        with_summary = summary_stats["papers_with_summary"]
        if with_summary > 0:
            info("-" * 70)
            info("SUMMARY STATISTICS:")
            share = with_summary / summary_stats['total_papers'] * 100
            info(f"Papers with summary: {with_summary} "
                 f"({share:.1f}%)")

            for field, completeness in summary_stats["field_completeness"].items():
                info(f"  {field:20s}: {completeness['count']:4d} papers "
                     f"({completeness['percentage']:.1f}%)")

            quality = summary_stats["summary_quality"]
            if quality:
                info(f"  Average quality score: {quality['average_score']:.1f}/4.0")

        info("=" * 70)

        # The run finished, so the checkpoint is no longer needed
        if resume_path.exists():
            try:
                resume_path.unlink()
                info(f"Removed checkpoint file: {resume_path}")
            except Exception as e:
                logger.warning(f"Failed to remove checkpoint: {e}")

    except KeyboardInterrupt:
        info("\nProcessing interrupted by user.")
        info("Checkpoint saved. Resume by running the script again.")
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)

if __name__ == "__main__":
    main()

In [None]:
# enhance_optional_summary.py
"""
Enhanced paper contribution summarization with optimized concurrent processing.

This script processes academic papers from a JSONL file, calls the DeepSeek API
to extract structured research contributions using concurrent processing with
adaptive rate limiting and request batching for maximum throughput.
"""

import json
import requests
import time
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple, Union
from concurrent.futures import ThreadPoolExecutor, as_completed, Future
from tqdm import tqdm
import random
from datetime import datetime, timedelta
from threading import Lock, Semaphore
import re
import queue
from dataclasses import dataclass, field
from collections import deque

# Configure logging: records are written both to a UTF-8 log file and to a
# stream handler. The level is INFO, so logger.debug(...) output is dropped.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('contribution_extraction_optimized.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# API Configuration
# NOTE(review): the API key is hard-coded in source. Consider loading it from
# an environment variable or secret store and rotating this key.
API_KEY = "sk-79990d599cd74bc0a56f6ca2f200a621"
API_URL = "https://api.deepseek.com/v1/chat/completions"
REQUEST_TIMEOUT = 35  # Default HTTP timeout (seconds) for call_deepseek_fast
MAX_WORKERS = 12  # Increased for higher concurrency (presumably the thread-pool size; consumer not in view)
MIN_DELAY = 0.15  # Minimum delay between requests (optimized); RateLimiter floor and starting delay
MAX_DELAY = 2.0   # Maximum delay for backoff; RateLimiter ceiling
MAX_RETRIES = 4  # Number of attempts in the call_deepseek_fast retry loop
BATCH_SIZE = 50  # presumably the checkpoint interval, as in enhance_summary.py - TODO confirm
MAX_TOKENS = 350  # "max_tokens" sent to the API; optimized for faster responses

# Adaptive rate limiting
@dataclass
class RateLimiter:
    """
    Adaptive rate limiter that adjusts based on success/failure rates.

    The delay grows by 50% on every recorded failure (up to max_delay) and
    shrinks by 10% on a success once the ten most recent outcomes are all
    successes (never below min_delay). All state changes happen under `lock`.
    """
    min_delay: float = MIN_DELAY
    max_delay: float = MAX_DELAY
    current_delay: float = MIN_DELAY
    # Rolling record of the last 50 outcomes (True = success, False = failure)
    success_window: deque = field(default_factory=lambda: deque(maxlen=50))
    lock: Lock = field(default_factory=Lock)

    def success(self):
        """Record a successful request and relax the delay when the recent run is clean."""
        with self.lock:
            window = self.success_window
            window.append(True)
            if len(window) < 10:
                # Not enough history yet to judge the recent success rate
                return
            recent = list(window)[-10:]
            success_rate = sum(1 for outcome in recent if outcome) / 10
            # Strictly above 90%: with a 10-sample window that means all ten
            # recent requests succeeded.
            if success_rate > 0.9 and self.current_delay > self.min_delay:
                self.current_delay = max(self.min_delay, self.current_delay * 0.9)

    def failure(self):
        """Record a failed request and lengthen the delay, capped at max_delay."""
        with self.lock:
            self.success_window.append(False)
            self.current_delay = min(self.max_delay, self.current_delay * 1.5)

    def get_delay(self) -> float:
        """Return the current delay with up to +/-5% jitter, never below min_delay."""
        with self.lock:
            wobble = random.uniform(-0.05, 0.05) * self.current_delay
            return max(self.min_delay, self.current_delay + wobble)

# Global rate limiter and statistics (module-level state shared by all callers)
rate_limiter = RateLimiter()  # Supplies the current delay to adaptive_rate_limit
request_counter = 0  # Incremented once per call_deepseek_fast invocation
error_counter = 0  # Declared global in call_deepseek_fast; its updates are outside this view
last_request_time = 0  # time.time() value recorded by the most recent adaptive_rate_limit call
rate_lock = Lock()  # Serializes the spacing logic in adaptive_rate_limit

# Expected contribution summary structure: field name -> required Python type
CONTRIBUTION_STRUCTURE = {
    "problem": str,
    "method": str,
    "key_contributions": list,
    "application_scenarios": list
}

def adaptive_rate_limit():
    """
    Space out requests using the delay currently suggested by rate_limiter.

    Holds rate_lock while waiting, so callers pass through one at a time.

    Returns:
        The delay (seconds) that was required for this call
    """
    global last_request_time

    with rate_lock:
        waited = time.time() - last_request_time
        target_gap = rate_limiter.get_delay()

        if waited < target_gap:
            # Only sleep for the part of the gap that has not yet elapsed
            time.sleep(target_gap - waited)

        last_request_time = time.time()

    return target_gap

# Guards the module-level request/error counters: call_deepseek_fast() runs on
# several worker threads and "+=" on a global is not an atomic operation.
_counter_lock = Lock()


def _record_api_failure() -> None:
    """Count one failed API call and tell the rate limiter to slow down."""
    global error_counter
    with _counter_lock:
        error_counter += 1
    rate_limiter.failure()


def call_deepseek_fast(prompt: str, paper_id: str = "unknown", timeout: int = REQUEST_TIMEOUT) -> Optional[str]:
    """
    Send one prompt to the chat-completions endpoint, retrying on transient errors.

    The call first waits in adaptive_rate_limit(), then makes up to
    MAX_RETRIES attempts:

    * HTTP 429: the rate limiter is told about the failure and the attempt is
      retried after an exponential back-off (capped at 30s).
    * Timeouts, HTTP 5xx and any other exception: retried after a short,
      growing pause.
    * Other HTTP errors (4xx): not retried.

    No pause is taken after the final attempt, since nothing would follow it.

    Args:
        prompt: The prompt to send to the API
        paper_id: Paper identifier for logging
        timeout: Request timeout in seconds

    Returns:
        The stripped response content, or None if every attempt failed.
        Each call that returns None increments ``error_counter`` once.
    """
    global request_counter

    # Apply adaptive rate limiting
    adaptive_rate_limit()
    with _counter_lock:
        request_counter += 1

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": "deepseek-chat",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.25,
        "max_tokens": MAX_TOKENS,
        "stream": False
    }

    for attempt in range(MAX_RETRIES):
        is_last_attempt = attempt == MAX_RETRIES - 1

        try:
            response = requests.post(
                API_URL,
                headers=headers,
                json=payload,
                timeout=timeout
            )

            if response.status_code == 429:
                # Tell the limiter right away so other workers slow down
                # while this one is backing off.
                rate_limiter.failure()
                if is_last_attempt:
                    logger.warning(f"Rate limited for paper {paper_id}, no retries left")
                    break
                wait_time = min(30, 2 ** attempt + random.uniform(0, 1))
                logger.warning(f"Rate limited for paper {paper_id}, waiting {wait_time:.1f}s")
                time.sleep(wait_time)
                continue

            response.raise_for_status()

            data = response.json()
            message = data.get("choices", [{}])[0].get("message", {})
            # "content" may be present but null; treat that like an empty reply.
            content = (message.get("content") or "").strip()

            if not content:
                raise ValueError("Empty response content")

            rate_limiter.success()
            return content

        except requests.exceptions.Timeout:
            logger.debug(f"Timeout for paper {paper_id}, attempt {attempt + 1}")
            if is_last_attempt:
                break
            time.sleep(1 << attempt)  # Exponential backoff

        except requests.exceptions.HTTPError as e:
            status = e.response.status_code
            if status < 500:
                # Client error: retrying the same request would not help.
                logger.debug(f"HTTP error for paper {paper_id}: {status}")
                break
            logger.debug(f"Server error for paper {paper_id}: {status}")
            if is_last_attempt:
                break
            time.sleep(3 * (attempt + 1))

        except Exception as e:
            # Covers connection errors, malformed JSON and empty replies.
            logger.debug(f"Request error for paper {paper_id}: {str(e)[:100]}")
            if is_last_attempt:
                break
            time.sleep(0.5 * (attempt + 1))

    # Every failure path ends here exactly once.
    _record_api_failure()
    return None

def validate_contribution_summary(summary: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce a parsed model reply into the CONTRIBUTION_STRUCTURE layout.

    Args:
        summary: The decoded JSON value returned by the model.

    Returns:
        ``{}`` when ``summary`` is not a dict. Otherwise a dict with one entry
        per field in CONTRIBUTION_STRUCTURE:

        * string fields: the value converted to ``str``, stripped and cut to
          300 characters (``""`` when missing or null);
        * list fields: at most 8 items, each converted to ``str``, stripped
          and cut to 200 characters. Null items and items that are empty after
          stripping are dropped. A non-list value becomes ``[]``.
    """
    if not isinstance(summary, dict):
        return {}

    validated = {}

    for name, expected_type in CONTRIBUTION_STRUCTURE.items():
        value = summary.get(name)

        if expected_type is str:
            validated[name] = "" if value is None else str(value).strip()[:300]

        elif expected_type is list:
            cleaned = []
            if isinstance(value, list):
                for item in value:
                    if item is None:
                        continue
                    text = str(item).strip()[:200]
                    # Skip blanks instead of storing "" entries.
                    if not text:
                        continue
                    cleaned.append(text)
                    if len(cleaned) >= 8:  # Limit to 8 items
                        break
            validated[name] = cleaned

    return validated

def extract_contribution_summary_fast(paper: Dict[str, Any]) -> Dict[str, Any]:
    """
    Ask the model for a structured contribution summary of one paper.

    The title (first 200 characters) and abstract (first 1500 characters) are
    sent in a single prompt. The reply is decoded in three steps, stopping at
    the first that works:

    1. parse the whole reply as JSON;
    2. parse the first brace-delimited span that contains no nested braces;
    3. if such a span exists but is not valid JSON, fall back to mapping the
       reply's non-empty lines onto the summary fields.

    Args:
        paper: Dictionary containing paper metadata. Missing or null
            ``title``/``abstract`` values are treated as empty strings.

    Returns:
        A shallow copy of ``paper`` with a ``contribution_summary`` key. The
        summary is ``{}`` when the API call failed or nothing could be decoded.
    """
    paper_id = paper.get("paper_id", paper.get("id", "unknown"))

    # Prepare content with length limits; "or" guards against explicit nulls,
    # which would otherwise raise before the try block below.
    title = str(paper.get("title") or "").strip()[:200]
    abstract = str(paper.get("abstract") or "").strip()[:1500]

    # Optimized prompt for speed
    prompt = f"""Summarize this paper into JSON:

Title: {title}
Abstract: {abstract}

Return JSON with:
- "problem": research problem (1 sentence)
- "method": approach used (1 sentence)
- "key_contributions": list of 2-4 contributions
- "application_scenarios": list of 1-3 applications

JSON ONLY:"""

    summary = {}

    try:
        api_response = call_deepseek_fast(prompt, paper_id, timeout=25)

        if api_response:
            try:
                summary = validate_contribution_summary(json.loads(api_response))
            except json.JSONDecodeError:
                # The reply is not pure JSON (e.g. wrapped in extra text):
                # look for an embedded flat JSON object.
                json_match = re.search(r'\{[^{}]*\}', api_response, re.DOTALL)
                if json_match:
                    try:
                        summary = validate_contribution_summary(json.loads(json_match.group(0)))
                    except json.JSONDecodeError:
                        # Minimal text extraction: line 1 -> problem,
                        # line 2 -> method, then two lines each for the lists.
                        lines = [l.strip() for l in api_response.split('\n') if l.strip()]
                        if lines:
                            summary = {
                                "problem": lines[0][:200],
                                "method": lines[1][:200] if len(lines) > 1 else "",
                                "key_contributions": lines[2:4],
                                "application_scenarios": lines[4:6]
                            }
    except Exception as e:
        logger.debug(f"Error for paper {paper_id}: {str(e)[:80]}")

    # Create enhanced paper
    enhanced_paper = paper.copy()
    enhanced_paper["contribution_summary"] = summary

    return enhanced_paper

class BatchProcessor:
    """Runs extract_contribution_summary_fast over a list of papers on a thread pool."""

    def __init__(self, max_workers: int = MAX_WORKERS, batch_size: int = BATCH_SIZE):
        """Store the pool size and batch size and create the helper primitives.

        Args:
            max_workers: Number of threads given to the executor.
            batch_size: Stored on the instance; not read by process_batch.
        """
        self.max_workers = max_workers
        self.batch_size = batch_size
        # Acquired only around each executor.submit() call in process_batch.
        self.semaphore = Semaphore(max_workers * 2)
        self.results_lock = Lock()

    def process_batch(self, papers: List[Dict[str, Any]],
                     desc: str = "Processing") -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """
        Summarize every paper in ``papers`` concurrently.

        Args:
            papers: List of papers to process
            desc: Description for progress bar

        Returns:
            Tuple of (processed papers in input order, statistics). A paper
            whose worker raised is returned as a copy with an empty
            ``contribution_summary`` and counted as failed; a paper whose
            summary came back empty is also counted as failed.
        """
        paper_count = len(papers)
        stats = {
            "total": paper_count,
            "successful": 0,
            "failed": 0,
            "with_summary": 0
        }

        # One slot per input paper so the output keeps the input order.
        ordered_results = [None] * paper_count

        with ThreadPoolExecutor(max_workers=self.max_workers,
                              thread_name_prefix="api_worker") as executor:

            # Map each future back to the position of its paper.
            position_of = {}
            for position, paper in enumerate(papers):
                with self.semaphore:
                    position_of[executor.submit(extract_contribution_summary_fast, paper)] = position

            with tqdm(total=paper_count, desc=desc, unit="paper",
                     mininterval=0.5, maxinterval=1.0) as pbar:

                for done_count, future in enumerate(as_completed(position_of), start=1):
                    idx = position_of[future]

                    try:
                        outcome = future.result(timeout=30)
                        ordered_results[idx] = outcome

                        if outcome.get("contribution_summary", {}):
                            stats["with_summary"] += 1
                            stats["successful"] += 1
                        else:
                            stats["failed"] += 1

                    except Exception as e:
                        logger.debug(f"Batch processing error at index {idx}: {str(e)[:80]}")
                        stats["failed"] += 1
                        # Keep the paper in the output, just without a summary.
                        placeholder = papers[idx].copy()
                        placeholder["contribution_summary"] = {}
                        ordered_results[idx] = placeholder

                    pbar.update(1)

                    # Refresh the throughput/delay readout every 20 papers.
                    if done_count % 20 == 0:
                        pbar.set_postfix({
                            "rate": f"{done_count/(time.time()-pbar.start_t):.1f}/s",
                            "delay": f"{rate_limiter.current_delay:.2f}s"
                        })

        processed_papers = [entry for entry in ordered_results if entry is not None]

        return processed_papers, stats

def load_papers_fast(input_path: Path) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list, one decoded value per line.

    Lines that are not valid JSON (including blank lines) are skipped without
    a message.

    Raises:
        FileNotFoundError: if ``input_path`` does not exist (logged first).
    """
    loaded = []
    try:
        with open(input_path, 'r', encoding='utf-8') as source:
            for raw_line in source:
                try:
                    record = json.loads(raw_line.strip())
                except json.JSONDecodeError:
                    continue
                loaded.append(record)
    except FileNotFoundError:
        logger.error(f"Input file not found: {input_path}")
        raise

    logger.info(f"Loaded {len(loaded)} papers from {input_path}")
    return loaded

def save_checkpoint_fast(papers: List[Dict[str, Any]], checkpoint_path: Path):
    """Write ``papers`` to ``checkpoint_path`` as JSONL without risking the old file.

    The records are first written to a sibling ``<name>.tmp`` file, which then
    replaces ``checkpoint_path``. If serialization or writing fails, or the
    process is interrupted mid-write, the previous checkpoint is left intact
    instead of being truncated.

    Errors are logged and swallowed (best-effort save); nothing is returned.
    """
    target = Path(checkpoint_path)
    scratch = target.with_name(target.name + ".tmp")
    try:
        with open(scratch, 'w', encoding='utf-8') as f:
            for paper in papers:
                f.write(json.dumps(paper, ensure_ascii=False) + '\n')
        # Swap the finished file into place in one step.
        scratch.replace(target)
    except Exception as e:
        logger.error(f"Checkpoint save error: {e}")
        # Do not leave a half-written temp file behind.
        try:
            scratch.unlink()
        except OSError:
            pass

def load_checkpoint_fast(checkpoint_path: Path) -> Tuple[List[Dict[str, Any]], set]:
    """Read a JSONL checkpoint and collect the ids of the papers it holds.

    Returns:
        ``(papers, processed_ids)``. Both are empty when the file does not
        exist. Lines that are not valid JSON are skipped. A paper contributes
        to ``processed_ids`` only if its ``paper_id`` (or, when that key is
        absent, ``id``) is truthy. Any other error while reading is logged and
        whatever was loaded up to that point is returned.
    """
    restored: List[Dict[str, Any]] = []
    seen_ids = set()

    if not checkpoint_path.exists():
        return restored, seen_ids

    try:
        with open(checkpoint_path, 'r', encoding='utf-8') as source:
            for raw_line in source:
                try:
                    record = json.loads(raw_line.strip())
                    restored.append(record)
                    identifier = record.get("paper_id", record.get("id"))
                    if identifier:
                        seen_ids.add(identifier)
                except json.JSONDecodeError:
                    continue

        logger.info(f"Loaded {len(restored)} papers from checkpoint")
    except Exception as e:
        logger.error(f"Checkpoint load error: {e}")

    return restored, seen_ids

def _write_jsonl_or_raise(papers: List[Dict[str, Any]], destination: Path) -> None:
    """Write ``papers`` to ``destination`` as JSONL, letting any error propagate.

    The data goes to a sibling ``<name>.tmp`` file first and is then moved into
    place, so a failed write never leaves a truncated destination file.
    """
    scratch = destination.with_name(destination.name + ".tmp")
    with open(scratch, 'w', encoding='utf-8') as f:
        for paper in papers:
            f.write(json.dumps(paper, ensure_ascii=False) + '\n')
    scratch.replace(destination)


def main():
    """Summarize all papers in chunks, checkpointing after each chunk.

    Papers whose id already appears in the checkpoint are not processed
    again. When every chunk is done the combined result is written to the
    output file, a run summary is logged and the checkpoint is deleted. The
    checkpoint is only deleted after the final write has succeeded: a failed
    write raises, is logged by the outer handler, and leaves the checkpoint in
    place for the next run.
    """

    # File paths
    input_file = Path("papers_final_aligned.jsonl")
    output_file = Path("papers_enhanced_contributions_fast.jsonl")
    checkpoint_file = Path("fast_contribution_checkpoint.jsonl")

    start_time = time.time()
    processor = BatchProcessor(max_workers=MAX_WORKERS, batch_size=BATCH_SIZE)

    try:
        # Load papers
        logger.info("Loading papers...")
        papers = load_papers_fast(input_file)

        if not papers:
            logger.error("No papers loaded. Exiting.")
            return

        # Load checkpoint
        checkpoint_papers, processed_ids = load_checkpoint_fast(checkpoint_file)

        # Determine remaining papers
        if processed_ids:
            remaining_papers = [
                paper for paper in papers
                if paper.get("paper_id", paper.get("id")) not in processed_ids
            ]
            logger.info(f"Resuming: {len(checkpoint_papers)} processed, {len(remaining_papers)} remaining")
        else:
            logger.info(f"Starting fresh: {len(papers)} papers to process")
            remaining_papers = papers

        # Process in chunks so a checkpoint is written every chunk_size papers
        chunk_size = 500
        all_processed = []
        total_stats = {
            "total": 0,
            "successful": 0,
            "failed": 0,
            "with_summary": 0
        }

        for chunk_start in range(0, len(remaining_papers), chunk_size):
            chunk = remaining_papers[chunk_start:chunk_start + chunk_size]

            logger.info(f"Processing chunk {chunk_start//chunk_size + 1}/{(len(remaining_papers)-1)//chunk_size + 1}")

            # Process chunk
            processed_chunk, stats = processor.process_batch(
                chunk,
                desc=f"Chunk {chunk_start//chunk_size + 1}"
            )

            # Update statistics
            for key in total_stats:
                total_stats[key] += stats[key]

            # Save checkpoint
            all_processed.extend(processed_chunk)
            save_checkpoint_fast(checkpoint_papers + all_processed, checkpoint_file)

            # Show intermediate stats
            logger.info(f"Chunk complete: {stats['successful']}/{stats['total']} successful "
                       f"({stats['successful']/stats['total']*100:.1f}%)")

        # Combine all results
        enhanced_papers = checkpoint_papers + all_processed

        # Save final results. This must raise on failure so the checkpoint
        # below is never removed unless the output file really exists.
        logger.info("Saving final results...")
        _write_jsonl_or_raise(enhanced_papers, output_file)

        # Print comprehensive summary. total may be 0 when a resumed run finds
        # nothing left to do, so every ratio is guarded.
        elapsed_time = time.time() - start_time
        processed_total = total_stats["total"]
        papers_per_second = processed_total / elapsed_time if elapsed_time > 0 else 0.0
        summary_pct = total_stats["with_summary"] / processed_total * 100 if processed_total else 0.0

        logger.info("=" * 70)
        logger.info("OPTIMIZED PROCESSING COMPLETE")
        logger.info("=" * 70)
        logger.info(f"Total papers processed: {processed_total}")
        logger.info(f"Papers with summaries: {total_stats['with_summary']} "
                   f"({summary_pct:.1f}%)")
        logger.info(f"Successful: {total_stats['successful']}")
        logger.info(f"Failed: {total_stats['failed']}")
        logger.info(f"Total API requests: {request_counter}")
        logger.info(f"API errors: {error_counter}")
        if request_counter > 0:
            success_rate = (request_counter - error_counter) / request_counter * 100
            logger.info(f"API success rate: {success_rate:.1f}%")
        logger.info(f"Final rate limit delay: {rate_limiter.current_delay:.2f}s")
        logger.info(f"Total time: {elapsed_time:.0f}s ({elapsed_time/60:.1f} min)")
        logger.info(f"Processing rate: {papers_per_second:.2f} papers/sec")
        if papers_per_second > 0:
            logger.info(f"Estimated completion time for 10k papers: {10000/papers_per_second/60:.1f} min")
        logger.info(f"Output saved to: {output_file}")
        logger.info("=" * 70)

        # Remove checkpoint file
        if checkpoint_file.exists():
            try:
                checkpoint_file.unlink()
                logger.info(f"Checkpoint removed: {checkpoint_file}")
            except Exception as e:
                logger.warning(f"Could not remove checkpoint: {e}")

    except KeyboardInterrupt:
        logger.info("\nProcessing interrupted. Checkpoint saved for resume.")
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)

if __name__ == "__main__":
    # Raise the "urllib3" and "requests" loggers to WARNING so their DEBUG and
    # INFO records are dropped before main() starts issuing HTTP requests.
    logging.getLogger("urllib3").setLevel(logging.WARNING)
    logging.getLogger("requests").setLevel(logging.WARNING)
    
    main()

#### Simple merge of all the generated JSONL files without filtering

In [None]:
# build_simple_dataset.py
import json
from collections import defaultdict
from typing import Dict, Any, List
import logging

# Configure logging: INFO level, each record formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)

def safe_string(value: Any) -> str:
    """Return ``value`` as a stripped string; ``None`` becomes ``""``."""
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()

def load_jsonl(path: str) -> Dict[str, Any]:
    """Read a JSONL file into a dict keyed by paper id.

    The key is the record's ``paper_id``, else its ``id``, else the synthetic
    ``paper_<line number>``. Later records overwrite earlier ones with the
    same key. Blank lines are ignored; lines that cannot be decoded or keyed
    are logged as warnings and skipped. A missing or unreadable file is logged
    as an error and whatever was loaded so far (possibly nothing) is returned.
    """
    records = {}
    try:
        with open(path, "r", encoding="utf-8") as source:
            for lineno, raw in enumerate(source, start=1):
                text = raw.strip()
                if text == "":
                    continue

                try:
                    entry = json.loads(text)
                    # Try multiple possible ID fields
                    key = entry.get("paper_id") or entry.get("id") or f"paper_{lineno}"
                    records[key] = entry
                except json.JSONDecodeError as e:
                    logger.warning(f"Line {lineno} in {path}: JSON decode error - {e}")
                except Exception as e:
                    logger.warning(f"Line {lineno} in {path}: Error - {e}")
    except FileNotFoundError:
        logger.error(f"File not found: {path}")
    except Exception as e:
        logger.error(f"Error loading {path}: {e}")

    logger.info(f"Loaded {len(records)} items from {path}")
    return records

def clean_list(lst: Any, max_len: int = None) -> List[str]:
    """Return the stripped, de-duplicated string form of a list's items.

    ``None`` items and items that are empty after stripping are dropped, the
    first occurrence of each value wins, and order is preserved. When
    ``max_len`` is truthy the scan stops once that many items are collected.
    Anything that is not a non-empty list yields ``[]``.
    """
    if not lst or not isinstance(lst, list):
        return []

    # A dict keeps insertion order and gives the uniqueness check for free.
    unique = {}

    for entry in lst:
        if entry is None:
            continue

        text = (entry if isinstance(entry, str) else str(entry)).strip()
        if text:
            unique.setdefault(text, None)

        if max_len and len(unique) >= max_len:
            break

    return list(unique)

def safe_int(value: Any) -> int:
    """Convert ``value`` to an int, returning 0 when that is not possible.

    Besides what ``int()`` accepts directly, numeric strings with a fractional
    part (e.g. ``"2021.0"`` or ``"7.9"``) are parsed as floats and truncated
    toward zero. ``None``, non-numeric input, NaN and infinities all give 0.
    """
    if value is None:
        return 0
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float("inf")). Fall through to the string retry.
        pass

    if isinstance(value, str):
        try:
            return int(float(value))
        except (ValueError, OverflowError):
            return 0
    return 0

def safe_float(value: Any) -> float:
    """Convert ``value`` to a finite float, returning 0.0 when that is not possible.

    ``None``, non-numeric input, integers too large for a float, NaN and
    infinities all give 0.0. Non-finite values are rejected because NaN does
    not survive ``min``/``max`` clamping predictably.
    """
    import math

    if value is None:
        return 0.0
    try:
        number = float(value)
    except (ValueError, TypeError, OverflowError):
        return 0.0
    return number if math.isfinite(number) else 0.0

def normalize_scores(scores: Dict[str, Any]) -> Dict[str, Any]:
    """Return quality scores with a fixed set of keys and clamped values.

    The four integer scores (``novelty``, ``technical_depth``, ``clarity``,
    ``impact_potential``) are clamped to 0-10, ``overall_score`` to 0.0-10.0
    and ``confidence`` to 0.0-1.0. Missing entries, and any input that is
    empty or not a dict, yield zeros.
    """
    if not scores or not isinstance(scores, dict):
        return {
            "novelty": 0,
            "technical_depth": 0,
            "clarity": 0,
            "impact_potential": 0,
            "overall_score": 0.0,
            "confidence": 0.0
        }

    result = {}

    # Integer scores (0-10)
    for name in ("novelty", "technical_depth", "clarity", "impact_potential"):
        raw = scores.get(name)
        result[name] = 0 if raw is None else max(0, min(10, safe_int(raw)))

    # Float scores
    overall_raw = safe_float(scores.get("overall_score"))
    result["overall_score"] = max(0.0, min(10.0, overall_raw))

    confidence_raw = safe_float(scores.get("confidence"))
    result["confidence"] = max(0.0, min(1.0, confidence_raw))

    return result

def normalize_contributions(contrib: Dict[str, Any]) -> Dict[str, Any]:
    """Return a contribution summary with a fixed set of keys.

    ``problem`` and ``method`` are passed through safe_string; the two list
    fields are passed through clean_list with at most 10 items each. Input
    that is empty or not a dict yields empty strings and empty lists.
    """
    if not contrib or not isinstance(contrib, dict):
        return {
            "problem": "",
            "method": "",
            "key_contributions": [],
            "application_scenarios": []
        }

    return {
        # String fields
        "problem": safe_string(contrib.get("problem")),
        "method": safe_string(contrib.get("method")),
        # List fields
        "key_contributions": clean_list(contrib.get("key_contributions"), max_len=10),
        "application_scenarios": clean_list(contrib.get("application_scenarios"), max_len=10),
    }

def main():
    """Merge the base papers with every enrichment file, keeping all papers.

    Each base paper becomes one output record holding its cleaned core
    metadata, any optional metadata that is present, and the fields of study,
    keywords, quality scores and contribution summary looked up by paper id
    (empty/zero defaults when an enrichment is missing). The records are
    written to ``papers_master_simple.jsonl`` (or a ``.backup`` file if that
    write fails) and availability and average-score statistics are logged.
    """
    logger.info("Starting simple merge of all papers...")

    # Load all data files
    base = load_jsonl("papers_final_aligned.jsonl")
    fields = load_jsonl("papers_enhanced_fields.jsonl")
    keywords = load_jsonl("papers_enhanced_keywords.jsonl")
    scores = load_jsonl("papers_enhanced_scores.jsonl")
    contribs = load_jsonl("papers_enhanced_contributions.jsonl")

    if not base:
        logger.error("No base papers found. Exiting.")
        return

    for label, table in (
        ("Base papers", base),
        ("Fields data", fields),
        ("Keywords data", keywords),
        ("Scores data", scores),
        ("Contributions data", contribs),
    ):
        logger.info(f"{label}: {len(table)}")

    # Optional base fields and the cleaner applied to each when present.
    optional_fields = (
        ("authors", lambda raw: clean_list(raw, max_len=20)),
        ("publish_year", safe_int),
        ("venue", safe_string),
        ("citation_count", safe_int),
        ("url", safe_string),
    )
    # (lookup table, record key, statistics suffix) for list-valued enrichments.
    list_enrichments = (
        (fields, "fields_of_study", "fields"),
        (keywords, "keywords", "keywords"),
    )

    final_papers = []
    stats = defaultdict(int)

    for pid, paper in base.items():
        # Basic paper info (always include)
        merged = {
            "paper_id": pid,
            "title": safe_string(paper.get("title")),
            "abstract": safe_string(paper.get("abstract")),
            "source": safe_string(paper.get("source")),
            "abstract_source": safe_string(paper.get("abstract_source")),
        }

        # Optional fields (include if available)
        for name, convert in optional_fields:
            raw = paper.get(name)
            if raw is not None:
                merged[name] = convert(raw)
                stats[f"has_{name}"] += 1

        # Fields of study and keywords
        for table, key, tag in list_enrichments:
            if pid in table:
                merged[key] = clean_list(table[pid].get(key), max_len=8)
                stats[f"has_{tag}"] += 1
            else:
                merged[key] = []
                stats[f"missing_{tag}"] += 1

        # Quality scores
        if pid in scores:
            merged["quality_scores"] = normalize_scores(scores[pid].get("quality_scores", {}))
            stats["has_scores"] += 1
        else:
            merged["quality_scores"] = normalize_scores({})
            stats["missing_scores"] += 1

        # Contribution summary
        if pid in contribs:
            merged["contribution_summary"] = normalize_contributions(
                contribs[pid].get("contribution_summary", {})
            )
            stats["has_contributions"] += 1
        else:
            merged["contribution_summary"] = normalize_contributions({})
            stats["missing_contributions"] += 1

        # Add to final papers (NO FILTERING)
        final_papers.append(merged)

    # Save all papers
    output_file = "papers_master_simple.jsonl"
    logger.info(f"Saving {len(final_papers)} papers to {output_file}")

    def write_all(target):
        """Write every merged record to ``target`` as one JSON object per line."""
        with open(target, "w", encoding="utf-8") as out:
            for record in final_papers:
                out.write(json.dumps(record, ensure_ascii=False) + "\n")

    try:
        write_all(output_file)
        logger.info("Successfully saved all papers")
    except Exception as e:
        logger.error(f"Failed to save output file: {e}")
        # Try backup
        backup_file = f"{output_file}.backup"
        try:
            write_all(backup_file)
            logger.info(f"Saved backup to {backup_file}")
        except Exception as e2:
            logger.error(f"Failed to save backup: {e2}")

    # Print statistics
    logger.info("=" * 60)
    logger.info("SIMPLE MERGE COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Total papers processed: {len(final_papers)}")
    logger.info(f"Total papers in base: {len(base)}")

    logger.info("-" * 60)
    logger.info("DATA AVAILABILITY STATISTICS:")
    for label, stat_key in (
        ("authors", "has_authors"),
        ("publish_year", "has_publish_year"),
        ("venue", "has_venue"),
        ("citation_count", "has_citation_count"),
        ("url", "has_url"),
        ("fields_of_study", "has_fields"),
        ("keywords", "has_keywords"),
        ("quality_scores", "has_scores"),
        ("contribution_summary", "has_contributions"),
    ):
        count = stats[stat_key]
        logger.info(f"  Papers with {label}: {count} ({count/len(base)*100:.1f}%)")

    # Calculate average scores for all papers
    if final_papers:
        integer_keys = ["novelty", "technical_depth", "clarity", "impact_potential"]
        totals = defaultdict(float)
        for paper in final_papers:
            qs = paper.get("quality_scores", {})
            for key in integer_keys + ["overall_score", "confidence"]:
                if key in qs:
                    totals[key] += qs[key]

        paper_total = len(final_papers)

        logger.info("-" * 60)
        logger.info("AVERAGE QUALITY SCORES:")
        for key in integer_keys:
            if key in totals:
                avg = totals[key] / paper_total
                logger.info(f"  {key:20s}: {avg:.2f}/10")

        if "overall_score" in totals:
            avg = totals["overall_score"] / paper_total
            logger.info(f"  overall_score:        {avg:.2f}/10")

        if "confidence" in totals:
            avg = totals["confidence"] / paper_total
            logger.info(f"  confidence:           {avg:.3f}")

    logger.info("=" * 60)

# Run the merge only when this file is executed as a script.
if __name__ == "__main__":
    main()

#### Merge all the generated JSONL files with filtering

In [28]:
import json
from collections import defaultdict
from pathlib import Path
import logging

# Configure logging: INFO level, each record formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger used by the functions defined in this cell.
logger = logging.getLogger(__name__)

def load_jsonl(path):
    """Read a JSONL file into a dict keyed by ``paper_id`` (or ``id``).

    Blank lines are ignored. Records without a truthy id, lines that are not
    valid JSON and lines that fail for any other reason are logged as warnings
    and counted as skipped. Later records overwrite earlier ones with the same
    id. A missing file is logged as an error and yields an empty dict.
    """
    records = {}
    skipped_count = 0

    if not Path(path).exists():
        logger.error(f"File not found: {path}")
        return records

    with open(path, "r", encoding="utf-8") as source:
        for lineno, raw in enumerate(source, start=1):
            text = raw.strip()
            if text == "":
                continue

            try:
                entry = json.loads(text)
                key = entry.get("paper_id") or entry.get("id")

                if not key:
                    logger.warning(f"Line {lineno}: No paper_id found")
                    skipped_count += 1
                else:
                    records[key] = entry
            except json.JSONDecodeError as e:
                logger.warning(f"Line {lineno}: JSON decode error - {e}")
                skipped_count += 1
            except Exception as e:
                logger.warning(f"Line {lineno}: Unexpected error - {e}")
                skipped_count += 1

    logger.info(f"Loaded {len(records)} papers from {path}, skipped {skipped_count} lines")
    return records

def clean_string(value):
    """Return ``value`` as a stripped string; ``None`` becomes ``""``."""
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()

def clean_list(lst, max_len=None):
    """Return the stripped, de-duplicated string form of a list's items.

    ``None`` items and items that are empty after stripping are dropped, the
    first occurrence of each value wins, and order is preserved. When
    ``max_len`` is truthy the scan stops once that many items are collected.
    Anything that is not a non-empty list yields ``[]``.
    """
    if not lst or not isinstance(lst, list):
        return []

    # A dict keeps insertion order and gives the uniqueness check for free.
    unique = {}

    for entry in lst:
        if entry is None:
            continue

        text = (entry if isinstance(entry, str) else str(entry)).strip()
        if text:
            unique.setdefault(text, None)

        if max_len and len(unique) >= max_len:
            break

    return list(unique)

def safe_int(value):
    """Convert ``value`` to an int, returning ``None`` when that is not possible.

    * Strings: the first run of digits (with an optional sign) found anywhere
      in the stripped text is used, so ``"2021-05"`` gives 2021 and
      ``"3.7"`` gives 3.
    * ints and floats: converted with ``int()`` (truncation toward zero).
      NaN and infinities give ``None``.
    * Anything else, including ``None``: ``None``.
    """
    import re

    if value is None:
        return None

    try:
        if isinstance(value, str):
            match = re.search(r'[-+]?\d+', value.strip())
            return int(match.group()) if match else None
        if isinstance(value, (int, float)):
            return int(value)
        return None
    except (ValueError, TypeError, OverflowError):
        # ValueError: int(nan); OverflowError: int(inf).
        return None

def safe_float(value):
    """Safely convert a value to a float, returning None on failure.

    Strings are first parsed whole with ``float()``; if that fails, the
    first number-like substring is extracted and parsed instead. Ints and
    floats are passed through ``float()``. ``None``, unsupported types and
    integers too large for a float all yield ``None`` instead of raising.
    """
    if value is None:
        return None

    try:
        if isinstance(value, str):
            cleaned = value.strip()
            try:
                return float(cleaned)
            except ValueError:
                # Fall back to the first number embedded in the text
                import re
                match = re.search(r'[-+]?\d*\.?\d+', cleaned)
                return float(match.group()) if match else None
        if isinstance(value, (int, float)):
            # float() raises OverflowError for ints beyond the float range
            return float(value)
        return None
    except (ValueError, TypeError, OverflowError):
        return None

def get_nested_value(obj, keys, default=None):
    """Follow ``keys`` through nested dictionaries and return what is found.

    ``default`` is returned when ``obj`` is empty or not a dict, or as soon
    as a step lands on a non-dict or a key that is absent. A value that is
    present is returned as-is, even if it is ``None``.
    """
    if not (obj and isinstance(obj, dict)):
        return default

    node = obj
    for step in keys:
        if isinstance(node, dict) and step in node:
            node = node[step]
        else:
            return default

    return node

def _normalize_quality_scores(raw_scores):
    """Return a complete, range-checked quality-score dict.

    Anything that is not a dict (missing entry, JSON null, stray string)
    is treated as having no scores, so every key falls back to zero.
    """
    scores_data = raw_scores if isinstance(raw_scores, dict) else {}
    quality_scores = {}

    # Integer scores must lie in 0-10; anything else is zeroed
    for int_key in ["novelty", "technical_depth", "clarity", "impact_potential"]:
        int_val = safe_int(scores_data.get(int_key))
        if int_val is not None and 0 <= int_val <= 10:
            quality_scores[int_key] = int_val
        else:
            quality_scores[int_key] = 0

    # Float scores
    overall = safe_float(scores_data.get("overall_score"))
    if overall is not None and 0 <= overall <= 10:
        quality_scores["overall_score"] = round(overall, 1)
    else:
        quality_scores["overall_score"] = 0.0

    confidence = safe_float(scores_data.get("confidence"))
    if confidence is not None and 0 <= confidence <= 1:
        quality_scores["confidence"] = round(confidence, 2)
    else:
        quality_scores["confidence"] = 0.0

    return quality_scores


def _normalize_contribution_summary(raw_summary):
    """Return a complete contribution-summary dict with cleaned values.

    Anything that is not a dict is treated as an empty summary.
    """
    contrib_data = raw_summary if isinstance(raw_summary, dict) else {}
    cleaned_contrib = {}

    # Problem and method are free text, capped at 500 characters
    for str_key in ["problem", "method"]:
        value = contrib_data.get(str_key)
        if isinstance(value, str) and value.strip():
            cleaned_contrib[str_key] = value.strip()[:500]
        else:
            cleaned_contrib[str_key] = ""

    # List-valued entries, capped at 10 items each
    for list_key in ["key_contributions", "application_scenarios"]:
        value = contrib_data.get(list_key)
        if isinstance(value, list):
            cleaned_contrib[list_key] = clean_list(value, max_len=10)
        else:
            cleaned_contrib[list_key] = []

    return cleaned_contrib


def main():
    """Merge the base papers with all enrichment files and write the result.

    Loads the base JSONL plus the fields, keywords, scores and
    contributions JSONL files (each keyed by paper id), builds one merged
    record per base paper, drops papers that fail the score thresholds or
    have an empty title/abstract, writes the survivors to
    ``papers_master_final.jsonl`` and logs a summary of the run.
    """
    # Define file paths
    base_path = "papers_final_aligned.jsonl"
    fields_path = "papers_enhanced_fields.jsonl"
    keywords_path = "papers_enhanced_keywords.jsonl"
    scores_path = "papers_enhanced_scores.jsonl"
    contribs_path = "papers_enhanced_contributions.jsonl"
    output_path = "papers_master_final.jsonl"

    logger.info("Loading data files...")

    # Load all data files
    base = load_jsonl(base_path)
    fields = load_jsonl(fields_path)
    keywords = load_jsonl(keywords_path)
    scores = load_jsonl(scores_path)
    contribs = load_jsonl(contribs_path)

    if not base:
        logger.error("No base papers loaded. Exiting.")
        return

    logger.info(f"Base papers: {len(base)}")
    logger.info(f"Fields data: {len(fields)}")
    logger.info(f"Keywords data: {len(keywords)}")
    logger.info(f"Scores data: {len(scores)}")
    logger.info(f"Contributions data: {len(contribs)}")

    final_papers = []
    dropped = 0
    merge_stats = defaultdict(int)

    logger.info("Merging and filtering papers...")

    for pid, paper in base.items():
        merged = {}

        # Basic paper info
        merged["source"] = paper.get("source")
        merged["paper_id"] = pid
        merged["title"] = clean_string(paper.get("title"))
        merged["abstract"] = clean_string(paper.get("abstract"))
        merged["abstract_source"] = paper.get("abstract_source")
        merged["authors"] = clean_list(paper.get("authors"), max_len=20)
        merged["publish_year"] = safe_int(paper.get("publish_year"))
        merged["venue"] = clean_string(paper.get("venue"))
        merged["citation_count"] = safe_int(paper.get("citation_count")) or 0
        merged["url"] = paper.get("url")

        # Fields of study
        if pid in fields:
            merged["fields_of_study"] = clean_list(
                get_nested_value(fields[pid], ["fields_of_study"]),
                max_len=8
            )
            merge_stats["has_fields"] += 1
        else:
            merged["fields_of_study"] = []
            merge_stats["missing_fields"] += 1

        # Keywords
        if pid in keywords:
            merged["keywords"] = clean_list(
                get_nested_value(keywords[pid], ["keywords"]),
                max_len=8
            )
            merge_stats["has_keywords"] += 1
        else:
            merged["keywords"] = []
            merge_stats["missing_keywords"] += 1

        # Quality scores. The stored value may be null or malformed even
        # when the paper id is present, so it is always normalised.
        if pid in scores:
            merged["quality_scores"] = _normalize_quality_scores(
                get_nested_value(scores[pid], ["quality_scores"], {})
            )
            merge_stats["has_scores"] += 1
        else:
            merged["quality_scores"] = _normalize_quality_scores(None)
            merge_stats["missing_scores"] += 1

        # Contribution summary (same normalisation as the scores)
        if pid in contribs:
            merged["contribution_summary"] = _normalize_contribution_summary(
                get_nested_value(contribs[pid], ["contribution_summary"], {})
            )
            merge_stats["has_contribs"] += 1
        else:
            merged["contribution_summary"] = _normalize_contribution_summary(None)
            merge_stats["missing_contribs"] += 1

        # Apply quality filters - title/abstract length is no longer checked
        title = merged["title"]
        abstract = merged["abstract"]
        qs = merged["quality_scores"]

        try:
            # Check if paper meets minimum quality criteria
            score_too_low = qs.get("overall_score", 0) < 2
            depth_too_low = qs.get("technical_depth", 0) < 2
            confidence_too_low = qs.get("confidence", 0) < 0.6

            # Filter on the score thresholds only
            if score_too_low or depth_too_low or confidence_too_low:
                dropped += 1

                # Record every failing reason; one paper may count in several
                if score_too_low:
                    merge_stats["dropped_score"] += 1
                if depth_too_low:
                    merge_stats["dropped_depth"] += 1
                if confidence_too_low:
                    merge_stats["dropped_confidence"] += 1

                continue

            # Additional validation - title and abstract must still be non-empty
            if not title or not abstract:
                dropped += 1
                merge_stats["dropped_empty"] += 1
                continue

            final_papers.append(merged)

        except Exception as e:
            logger.warning(f"Error filtering paper {pid}: {e}")
            dropped += 1
            merge_stats["dropped_error"] += 1
            continue

    # Save final papers
    logger.info(f"Saving {len(final_papers)} papers to {output_path}")

    try:
        with open(output_path, "w", encoding="utf-8") as f:
            for p in final_papers:
                f.write(json.dumps(p, ensure_ascii=False) + "\n")

        logger.info("Successfully saved final papers")
    except Exception as e:
        logger.error(f"Failed to save output file: {e}")
        # Try to save backup
        backup_path = f"{output_path}.backup"
        try:
            with open(backup_path, "w", encoding="utf-8") as f:
                for p in final_papers:
                    f.write(json.dumps(p, ensure_ascii=False) + "\n")
            logger.info(f"Saved backup to {backup_path}")
        except Exception:
            # Keep Ctrl-C / SystemExit propagating and record the traceback
            logger.exception("Failed to save backup")

    # Print summary
    logger.info("=" * 60)
    logger.info("MASTER MERGE COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Base papers: {len(base)}")
    logger.info(f"Final papers: {len(final_papers)}")
    logger.info(f"Dropped papers: {dropped}")
    logger.info(f"Retention rate: {len(final_papers)/len(base)*100:.1f}%")

    logger.info("-" * 60)
    logger.info("DATA AVAILABILITY:")
    for label, stat_key in [
        ("fields", "has_fields"),
        ("keywords", "has_keywords"),
        ("scores", "has_scores"),
        ("contributions", "has_contribs"),
    ]:
        count = merge_stats[stat_key]
        logger.info(f"Papers with {label}: {count} ({count/len(base)*100:.1f}%)")

    if dropped > 0:
        logger.info("-" * 60)
        logger.info("DROP REASONS:")
        # Title/abstract length reasons were removed along with that filter
        for label, stat_key in [
            ("Score too low", "dropped_score"),
            ("Technical depth too low", "dropped_depth"),
            ("Confidence too low", "dropped_confidence"),
            ("Empty title/abstract", "dropped_empty"),
            ("Processing error", "dropped_error"),
        ]:
            if merge_stats[stat_key]:
                logger.info(f"  {label}: {merge_stats[stat_key]}")

    # Calculate average scores for final papers
    if final_papers:
        avg_scores = defaultdict(float)
        for paper in final_papers:
            qs = paper.get("quality_scores", {})
            for key in ["novelty", "technical_depth", "clarity", "impact_potential", "overall_score", "confidence"]:
                if key in qs:
                    avg_scores[key] += qs[key]

        logger.info("-" * 60)
        logger.info("AVERAGE SCORES IN FINAL DATASET:")
        for key in ["novelty", "technical_depth", "clarity", "impact_potential"]:
            if key in avg_scores:
                avg = avg_scores[key] / len(final_papers)
                logger.info(f"  {key:20s}: {avg:.2f}/10")

        if "overall_score" in avg_scores:
            avg = avg_scores["overall_score"] / len(final_papers)
            logger.info(f"  overall_score:        {avg:.2f}/10")

        if "confidence" in avg_scores:
            avg = avg_scores["confidence"] / len(final_papers)
            logger.info(f"  confidence:           {avg:.3f}")

    logger.info("=" * 60)

# Run the merge only when executed as a script (or notebook cell), not on import
if __name__ == "__main__":
    main()

2025-12-23 12:28:16,998 - INFO - Loading data files...
2025-12-23 12:28:17,036 - INFO - Loaded 6242 papers from papers_final_aligned.jsonl, skipped 0 lines
2025-12-23 12:28:17,077 - INFO - Loaded 6242 papers from papers_enhanced_fields.jsonl, skipped 0 lines
2025-12-23 12:28:17,266 - INFO - Loaded 6242 papers from papers_enhanced_keywords.jsonl, skipped 0 lines
2025-12-23 12:28:17,316 - INFO - Loaded 6242 papers from papers_enhanced_scores.jsonl, skipped 0 lines
2025-12-23 12:28:17,376 - INFO - Loaded 6242 papers from papers_enhanced_contributions.jsonl, skipped 0 lines
2025-12-23 12:28:17,376 - INFO - Base papers: 6242
2025-12-23 12:28:17,377 - INFO - Fields data: 6242
2025-12-23 12:28:17,378 - INFO - Keywords data: 6242
2025-12-23 12:28:17,378 - INFO - Scores data: 6242
2025-12-23 12:28:17,378 - INFO - Contributions data: 6242
2025-12-23 12:28:17,379 - INFO - Merging and filtering papers...
2025-12-23 12:28:17,473 - INFO - Saving 3236 papers to papers_master_final.jsonl
2025-12-23 12