# Agent Colab: Research Paper Entity Extraction Benchmark

This notebook sets up **Gemini 3 Pro** as an autonomous agent to solve the Research Paper Entity Extraction and Citation Analysis benchmark.

**Requirements:**
- Google Colab Pro (for native Gemini access via `google.colab.ai`)
- Dataset files uploaded or accessible

**Model Used:**
- `google/gemini-3-pro` - The most advanced reasoning model available in Colab Pro

**Implementation:**
- Uses `google.colab.ai` module for native Colab Pro AI integration
- No external API keys required - uses Colab Pro's built-in AI capabilities

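**Note:** This notebook is intended to run end-to-end without manual intervention, provided it is opened in the Colab UI (the model proxy secret cannot be fetched otherwise) and the dataset files are present in the working directory.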

## Setup and Dependencies

In [7]:
# Install required packages
%pip install -q pandas networkx

In [8]:
from google.colab import ai
import json
import pandas as pd
import numpy as np
from collections import defaultdict, Counter
from datetime import datetime
from typing import Dict, List, Any, Tuple
import re
import networkx as nx
import warnings
import unittest
warnings.filterwarnings('ignore')

# List available models in Colab Pro
print("Available AI models in Colab Pro:")
available_models = ai.list_models()
for model in available_models:
    print(f"  - {model}")

Available AI models in Colab Pro:


TimeoutException: Requesting secret MODEL_PROXY_API_KEY timed out. Secrets can only be fetched when running from the Colab UI.

## Agent Configuration

Select and configure Gemini 3 Pro from the available Colab Pro models.

In [None]:
# Select the most capable model for agentic tasks
# Using Gemini 3 Pro; older Gemini models serve as fallbacks below
MODEL_NAME = "google/gemini-3-pro"

# Verify the model is available
if MODEL_NAME in available_models:
    print(f"Model '{MODEL_NAME}' is available - SELECTED")
else:
    print(f"Warning: '{MODEL_NAME}' not found. Available models: {available_models}")
    # Fallback to other Pro/capable models
    fallback_order = ["google/gemini-2.5-pro", "google/gemini-2.0-flash", "google/gemini-2.5-flash"]
    for fallback in fallback_order:
        if fallback in available_models:
            MODEL_NAME = fallback
            print(f"Using fallback model: {MODEL_NAME}")
            break

print(f"\nAgent model selected: {MODEL_NAME}")

Agent model initialized: gemini-3-pro
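
The selection cell above leaves `MODEL_NAME` unchanged when neither the requested model nor any fallback is listed. A minimal sketch of the same preference-order logic as a pure function that fails loudly instead; `pick_model` and the example listing are hypothetical and stand in for whatever `ai.list_models()` returns:

```python
from typing import Iterable, Sequence


def pick_model(available: Iterable[str], preferred: Sequence[str]) -> str:
    """Return the first entry of `preferred` that appears in `available`."""
    available_set = set(available)
    for name in preferred:
        if name in available_set:
            return name
    # Nothing matched: stop here rather than continue with an unusable model
    raise LookupError(
        f"None of {list(preferred)} is available; got {sorted(available_set)}"
    )


# Hypothetical listing, standing in for ai.list_models()
example_available = ["google/gemini-2.5-flash", "google/gemini-2.5-pro"]

chosen = pick_model(
    example_available,
    ["google/gemini-3-pro", "google/gemini-2.5-pro", "google/gemini-2.0-flash"],
)
print(chosen)  # google/gemini-2.5-pro
```

Raising when nothing matches keeps later cells from sending requests to a model the runtime does not offer.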


## Load Dataset

Load the benchmark dataset files.

In [4]:
# Define file paths
PAPERS_FILE = "papers_metadata.json"
CITATIONS_FILE = "citations.csv"
AFFILIATIONS_FILE = "author_affiliations.json"

# Load the data files
with open(PAPERS_FILE, 'r', encoding='utf-8') as f:
    papers_raw = json.load(f)

citations_raw = pd.read_csv(CITATIONS_FILE)

with open(AFFILIATIONS_FILE, 'r', encoding='utf-8') as f:
    affiliations_raw = json.load(f)

print("Dataset loaded:")
print(f"  - Papers: {len(papers_raw)} records")
print(f"  - Citations: {len(citations_raw)} relationships")
print(f"  - Affiliations: {len(affiliations_raw.get('authors', {}))} authors, {len(affiliations_raw.get('institutions', {}))} institutions")

FileNotFoundError: [Errno 2] No such file or directory: 'papers_metadata.json'
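
The `FileNotFoundError` above appears only when the first file is opened. A small pre-flight check can list every missing file at once, so all uploads can be done in one pass. This is a sketch; `missing_files` and `REQUIRED_FILES` are hypothetical names, and the file names are the ones used in the loading cell:

```python
from pathlib import Path
from typing import Iterable, List


def missing_files(paths: Iterable[str], base_dir: str = ".") -> List[str]:
    """Return the entries of `paths` that are not regular files under `base_dir`."""
    base = Path(base_dir)
    return [p for p in paths if not (base / p).is_file()]


REQUIRED_FILES = [
    "papers_metadata.json",
    "citations.csv",
    "author_affiliations.json",
]

missing = missing_files(REQUIRED_FILES)
if missing:
    print(f"Upload these files before continuing: {missing}")
else:
    print("All dataset files are present")
```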

## Benchmark Prompt

The task specification for the agent.

In [None]:
BENCHMARK_PROMPT = """
# Research Paper Entity Extraction and Citation Analysis Benchmark

## Scenario

You are a data scientist tasked with building an automated pipeline for analyzing research paper metadata. 
Your goal is to extract structured information from a collection of research papers, resolve entity ambiguities, 
construct a citation network, and produce a comprehensive analytical report.

You must decide for yourself how to decompose the task, which intermediate computations to perform, and in what order.
Do not simply follow a fixed step-by-step structure.

## Context

You have access to three data files (already loaded):
1. papers_raw - List of paper metadata dictionaries
2. citations_raw - DataFrame of citation relationships  
3. affiliations_raw - Reference data about known authors and institutions

The data contains intentional challenges:
- Author names appear in different formats (e.g., "John Smith" vs "J. Smith")
- Institution names have variations (e.g., "MIT" vs "Massachusetts Institute of Technology")
- Some papers have missing fields
- The citation data may contain anomalies (orphan references, self-citations)

## Required Output Variables

You must produce these variables:

### Core Data Variables
- papers_df: pd.DataFrame with columns: paper_id, title, authors, institution, abstract, keywords, venue, year, publication_date
- citations_df: pd.DataFrame with columns: citing_paper, cited_paper
- affiliations_data: dict with 'authors' and 'institutions' keys

### Entity Extraction Variables
- extracted_authors: list[dict] with keys: name, paper_ids, name_variations
- extracted_institutions: list[dict] with keys: name, paper_ids, name_variations
- extracted_topics: dict[str, int] mapping topics to frequency counts
- methods_from_abstracts: list[str] of research methods found

### Entity Resolution Variables
- author_resolution_map: dict[str, str] mapping variations to canonical names
- institution_resolution_map: dict[str, str] mapping variations to canonical names
- resolved_author_count: int
- resolved_institution_count: int

### Citation Network Variables
- citation_graph: dict[str, list[str]] adjacency list
- in_degree: dict[str, int] incoming citations per paper
- out_degree: dict[str, int] outgoing citations per paper
- pagerank_scores: dict[str, float] PageRank scores
- top_cited_papers: list[str] top 10 most cited paper IDs
- orphan_citations: list[dict] citations to non-existent papers
- self_citations: list[str] papers that cite themselves

### Validation Dictionary
- validation_results: dict[str, bool] with keys:
  - papers_loaded_ok, citations_loaded_ok, affiliations_loaded_ok
  - no_duplicate_paper_ids, authors_extracted, institutions_extracted
  - resolution_maps_valid, citation_graph_built, pagerank_computed
  - orphans_identified, self_citations_identified, all_pagerank_finite

### Summary Statistics
- summary_stats: dict with keys:
  - total_papers, total_citations, unique_authors_raw, unique_authors_resolved
  - unique_institutions_raw, unique_institutions_resolved
  - papers_with_missing_abstract, papers_with_missing_keywords
  - orphan_citation_count, self_citation_count, avg_citations_per_paper
  - most_common_venue, year_range

### Final Report
- final_report: dict with structure:
  - metadata: {task, papers_analyzed, execution_timestamp}
  - entity_extraction: {authors, institutions, topics}
  - citation_analysis: {total_citations, top_10_cited_papers, orphan_citations, self_citations, network_statistics}
  - data_quality: {missing_abstracts, missing_keywords, missing_institutions, duplicate_author_entries}
  - validation_summary: {all_checks_passed, failed_checks}

## Constraints
1. Do not hardcode specific paper IDs, author names, or institution names
2. Entity resolution must use fuzzy matching or reference data
3. PageRank must use damping factor 0.85
4. Handle edge cases gracefully

## Success Criteria
1. All validation checks pass
2. Entity resolution reduces author count
3. Orphan citations are identified (at least one exists)
4. Self-citations are identified (at least one exists)
5. PageRank scores sum to approximately 1.0
6. Final report follows exact schema
7. All numeric values are finite

Write complete Python code to solve this task. Store all results in the specified variable names.
"""

print("Benchmark prompt loaded")
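
The prompt asks for resolution maps that send name variations such as "J. Smith" to a canonical form such as "John Smith". One way this could be approached is sketched below with the standard-library `difflib`. The helper names, the 0.9 threshold, and the example names are all assumptions for illustration, not part of the benchmark:

```python
import re
from difflib import SequenceMatcher
from typing import Dict, List


def normalize(name: str) -> str:
    """Lowercase, replace punctuation with spaces, and collapse whitespace."""
    return re.sub(r"\s+", " ", re.sub(r"[^\w\s]", " ", name.lower())).strip()


def same_person(a: str, b: str, threshold: float = 0.9) -> bool:
    """Heuristic match: equal surnames with compatible first names or initials,
    otherwise a high overall string similarity (to tolerate small typos)."""
    ta, tb = normalize(a).split(), normalize(b).split()
    if not ta or not tb:
        return False
    if ta[-1] == tb[-1]:
        first_a, first_b = ta[0], tb[0]
        # An initial such as "j" is compatible with any first name starting with "j"
        if len(first_a) == 1 or len(first_b) == 1:
            return first_a[0] == first_b[0]
        return first_a == first_b
    ratio = SequenceMatcher(None, " ".join(ta), " ".join(tb)).ratio()
    return ratio >= threshold


def build_resolution_map(variations: List[str], canonical: List[str]) -> Dict[str, str]:
    """Map each variation to the first matching canonical name, or to itself."""
    return {
        v: next((c for c in canonical if same_person(v, c)), v)
        for v in variations
    }


# Hypothetical names for illustration only
example_map = build_resolution_map(
    ["John Smith", "J. Smith", "Jane Doe"],
    canonical=["John Smith", "Jane Doe"],
)
print(example_map)
# {'John Smith': 'John Smith', 'J. Smith': 'John Smith', 'Jane Doe': 'Jane Doe'}
```

An initial alone is ambiguous ("J. Smith" could also be a "Jane Smith"), so a real pipeline would likely also compare co-authors or institutions from the reference data before merging.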

## Agent Execution

In [None]:
def run_agent_task(prompt, data_context):
    """Run the agent using google.colab.ai to generate code for the task."""
    
    # Prepare context with data samples
    context = f"""
You have access to the following data (already loaded in Python):

papers_raw: A list of {len(data_context['papers'])} paper dictionaries
Sample: {json.dumps(data_context['papers'][0], indent=2)}

citations_raw: A pandas DataFrame with {len(data_context['citations'])} rows
Columns: {data_context['citations'].columns.tolist()}
Sample:
{data_context['citations'].head(3).to_string()}

affiliations_raw: A dictionary with author and institution reference data
Keys: {list(data_context['affiliations'].keys())}
Sample author: {json.dumps(list(data_context['affiliations']['authors'].values())[0], indent=2)}
Sample institution: {json.dumps(list(data_context['affiliations']['institutions'].values())[0], indent=2)}

{prompt}
"""
    
    print("Sending task to agent...")
    print("="*50)
    
    # Use google.colab.ai to generate the response via Colab Pro's built-in AI access
    # generate_text takes the model identifier through the `model_name` argument
    response = ai.generate_text(
        context,
        model_name=MODEL_NAME,
    )
    
    return response


# Prepare data context
data_context = {
    'papers': papers_raw,
    'citations': citations_raw,
    'affiliations': affiliations_raw
}

# Run the agent
agent_response = run_agent_task(BENCHMARK_PROMPT, data_context)
print("Agent response received")
print("="*50)
print(agent_response[:2000] + "..." if len(agent_response) > 2000 else agent_response)

In [None]:
# Extract Python code from agent response and execute it
def extract_and_execute_code(response_text):
    """Extract Python code blocks from the response and execute them."""
    
    # Find all code blocks
    code_blocks = re.findall(r'```python\n(.*?)```', response_text, re.DOTALL)
    
    if not code_blocks:
        # Try without language specifier
        code_blocks = re.findall(r'```\n(.*?)```', response_text, re.DOTALL)
    
    if not code_blocks:
        print("No code blocks found in response")
        return None
    
    # Combine all code blocks
    full_code = "\n\n".join(code_blocks)
    
    print(f"Extracted {len(code_blocks)} code block(s)")
    print("Executing agent code...")
    print("="*50)
    
    # Execute the code
    exec_globals = {
        'papers_raw': papers_raw,
        'citations_raw': citations_raw,
        'affiliations_raw': affiliations_raw,
        'pd': pd,
        'np': np,
        'json': json,
        're': re,
        'nx': nx,
        'defaultdict': defaultdict,
        'Counter': Counter,
        'datetime': datetime,
        'Dict': Dict,
        'List': List,
        'Any': Any,
        'Tuple': Tuple,
    }
    
    try:
        exec(full_code, exec_globals)
        print("Code executed successfully!")
        return exec_globals
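    except Exception as e:
        # Include the exception type so failures in generated code are easier to diagnose
        print(f"Error executing code: {type(e).__name__}: {e}")
        return None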

# Execute the agent's code
exec_result = extract_and_execute_code(agent_response)

# If successful, extract variables to global scope
if exec_result:
    # Extract all required variables
    required_vars = [
        'papers_df', 'citations_df', 'affiliations_data',
        'extracted_authors', 'extracted_institutions', 'extracted_topics', 'methods_from_abstracts',
        'author_resolution_map', 'institution_resolution_map', 'resolved_author_count', 'resolved_institution_count',
        'citation_graph', 'in_degree', 'out_degree', 'pagerank_scores', 'top_cited_papers',
        'orphan_citations', 'self_citations',
        'validation_results', 'summary_stats', 'final_report'
    ]
    
    for var in required_vars:
        if var in exec_result:
            globals()[var] = exec_result[var]
            print(f"  Loaded: {var}")
        else:
            print(f"  Missing: {var}")
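
The extraction step above accepts either `python`-tagged or untagged fences. If the response also contains a fence in another language (for example a JSON sample), an untagged-fence pattern can pair that block's closing fence with unrelated text. A sketch of a single-pass alternative that captures the tag of every block and then filters; `extract_python_blocks` and `PYTHON_TAGS` are hypothetical names:

```python
import re
from typing import List

FENCE = "`" * 3  # built programmatically so the example can sit inside a fence

# Capture an optional language tag and the body of every fenced block
BLOCK_RE = re.compile(
    re.escape(FENCE) + r"([\w+-]*)[ \t]*\r?\n(.*?)" + re.escape(FENCE),
    re.DOTALL,
)
PYTHON_TAGS = {"", "py", "python", "python3"}


def extract_python_blocks(text: str) -> List[str]:
    """Return the bodies of fenced blocks that are untagged or tagged as Python."""
    return [
        body.strip()
        for lang, body in BLOCK_RE.findall(text)
        if lang.lower() in PYTHON_TAGS
    ]


sample = (
    "Plan:\n"
    + FENCE + "python\nx = 1\n" + FENCE + "\n"
    + "Expected output:\n"
    + FENCE + "json\n{\"x\": 1}\n" + FENCE + "\n"
    + "Then:\n"
    + FENCE + "\ny = x + 1\n" + FENCE + "\n"
)

blocks = extract_python_blocks(sample)
print(blocks)  # ['x = 1', 'y = x + 1']
```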

## Agent Output

In [None]:
# Display the agent's outputs
try:
    print("=== VALIDATION RESULTS ===")
    print(json.dumps(validation_results, indent=2))
    print("\n=== FINAL REPORT ===")
    print(json.dumps(final_report, indent=2))
except NameError as e:
    print(f"Variable not defined: {e}")
    print("Agent may not have completed the task successfully.")

---

# Unit Tests

Comprehensive tests to validate the agent's solution.

In [None]:
class TestDataLoading(unittest.TestCase):
    """Tests for data loading functionality."""
    
    def test_papers_df_exists_and_not_empty(self):
        self.assertIsInstance(papers_df, pd.DataFrame)
        self.assertGreater(len(papers_df), 0)
    
    def test_papers_df_has_required_columns(self):
        required = {'paper_id', 'title', 'authors', 'institution', 
                   'abstract', 'keywords', 'venue', 'year', 'publication_date'}
        self.assertTrue(required.issubset(set(papers_df.columns)))
    
    def test_citations_df_exists_and_not_empty(self):
        self.assertIsInstance(citations_df, pd.DataFrame)
        self.assertGreater(len(citations_df), 0)
    
    def test_citations_df_has_required_columns(self):
        required = {'citing_paper', 'cited_paper'}
        self.assertTrue(required.issubset(set(citations_df.columns)))
    
    def test_affiliations_data_structure(self):
        self.assertIsInstance(affiliations_data, dict)
        self.assertIn('authors', affiliations_data)
        self.assertIn('institutions', affiliations_data)
    
    def test_no_duplicate_paper_ids(self):
        self.assertEqual(papers_df['paper_id'].nunique(), len(papers_df))


class TestEntityExtraction(unittest.TestCase):
    """Tests for entity extraction functionality."""
    
    def test_extracted_authors_not_empty(self):
        self.assertGreater(len(extracted_authors), 0)
    
    def test_extracted_authors_structure(self):
        for author in extracted_authors:
            self.assertIn('name', author)
            self.assertIn('paper_ids', author)
            self.assertIn('name_variations', author)
    
    def test_extracted_institutions_not_empty(self):
        self.assertGreater(len(extracted_institutions), 0)
    
    def test_extracted_topics_is_dict(self):
        self.assertIsInstance(extracted_topics, dict)
    
    def test_methods_from_abstracts_is_list(self):
        self.assertIsInstance(methods_from_abstracts, list)


class TestEntityResolution(unittest.TestCase):
    """Tests for entity resolution functionality."""
    
    def test_author_resolution_map_not_empty(self):
        self.assertGreater(len(author_resolution_map), 0)
    
    def test_institution_resolution_map_not_empty(self):
        self.assertGreater(len(institution_resolution_map), 0)
    
    def test_resolution_reduces_author_count(self):
        # Non-strict bound: resolution may merge authors but must never add any
        raw_count = len(extracted_authors)
        self.assertLessEqual(resolved_author_count, raw_count)
    
    def test_resolved_counts_are_positive(self):
        self.assertGreater(resolved_author_count, 0)
        self.assertGreater(resolved_institution_count, 0)


class TestCitationNetwork(unittest.TestCase):
    """Tests for citation network functionality."""
    
    def test_citation_graph_not_empty(self):
        self.assertGreater(len(citation_graph), 0)
    
    def test_in_degree_covers_all_papers(self):
        self.assertEqual(len(in_degree), len(papers_df))
    
    def test_out_degree_covers_all_papers(self):
        self.assertEqual(len(out_degree), len(papers_df))
    
    def test_pagerank_scores_not_empty(self):
        self.assertGreater(len(pagerank_scores), 0)
    
    def test_pagerank_scores_sum_to_one(self):
        total = sum(pagerank_scores.values())
        self.assertAlmostEqual(total, 1.0, delta=0.01)
    
    def test_pagerank_scores_are_finite(self):
        for score in pagerank_scores.values():
            self.assertTrue(np.isfinite(score))
    
    def test_orphan_citations_identified(self):
        self.assertIsInstance(orphan_citations, list)
        self.assertGreater(len(orphan_citations), 0)
    
    def test_self_citations_identified(self):
        self.assertIsInstance(self_citations, list)
        self.assertGreater(len(self_citations), 0)


class TestValidationResults(unittest.TestCase):
    """Tests for validation results."""
    
    def test_validation_results_is_dict(self):
        self.assertIsInstance(validation_results, dict)
    
    def test_validation_results_has_required_keys(self):
        required_keys = {
            "papers_loaded_ok", "citations_loaded_ok", "affiliations_loaded_ok",
            "no_duplicate_paper_ids", "authors_extracted", "institutions_extracted",
            "resolution_maps_valid", "citation_graph_built", "pagerank_computed",
            "orphans_identified", "self_citations_identified", "all_pagerank_finite"
        }
        self.assertTrue(required_keys.issubset(set(validation_results.keys())))
    
    def test_all_validations_pass(self):
        failed = [k for k, v in validation_results.items() if not v]
        self.assertEqual(len(failed), 0, f"Failed validations: {failed}")


class TestSummaryStats(unittest.TestCase):
    """Tests for summary statistics."""
    
    def test_summary_stats_is_dict(self):
        self.assertIsInstance(summary_stats, dict)
    
    def test_summary_stats_has_required_keys(self):
        required_keys = {
            "total_papers", "total_citations", "unique_authors_raw",
            "unique_authors_resolved", "unique_institutions_raw",
            "unique_institutions_resolved", "papers_with_missing_abstract",
            "papers_with_missing_keywords", "orphan_citation_count",
            "self_citation_count", "avg_citations_per_paper",
            "most_common_venue", "year_range"
        }
        self.assertTrue(required_keys.issubset(set(summary_stats.keys())))


class TestFinalReport(unittest.TestCase):
    """Tests for final report structure."""
    
    def test_final_report_is_dict(self):
        self.assertIsInstance(final_report, dict)
    
    def test_final_report_has_metadata(self):
        self.assertIn('metadata', final_report)
    
    def test_final_report_has_entity_extraction(self):
        self.assertIn('entity_extraction', final_report)
    
    def test_final_report_has_citation_analysis(self):
        self.assertIn('citation_analysis', final_report)
    
    def test_final_report_has_data_quality(self):
        self.assertIn('data_quality', final_report)
    
    def test_final_report_has_validation_summary(self):
        self.assertIn('validation_summary', final_report)
    
    def test_all_checks_passed(self):
        self.assertTrue(final_report['validation_summary']['all_checks_passed'])
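
The citation-network tests above expect at least one orphan citation, at least one self-citation, and PageRank scores that are finite and sum to roughly 1.0. A pure-Python sketch of how those quantities could be computed with damping factor 0.85 follows. The function names and the three-paper example are hypothetical. Dropping orphan edges and self-loops before ranking is a design assumption of this sketch; `networkx.pagerank` keeps self-loops unless they are removed first:

```python
from typing import Dict, Iterable, List, Tuple

Edge = Tuple[str, str]  # (citing_paper, cited_paper)


def find_anomalies(edges: Iterable[Edge], paper_ids: Iterable[str]):
    """Return (orphan citations, self-citing paper IDs)."""
    known = set(paper_ids)
    edges = list(edges)
    orphans = [
        {"citing_paper": s, "cited_paper": t} for s, t in edges if t not in known
    ]
    self_citing = sorted({s for s, t in edges if s == t})
    return orphans, self_citing


def pagerank(nodes: Iterable[str], edges: Iterable[Edge], damping: float = 0.85,
             tol: float = 1e-10, max_iter: int = 200) -> Dict[str, float]:
    """Power-iteration PageRank; dangling mass is spread uniformly over all nodes."""
    nodes = list(nodes)
    n = len(nodes)
    if n == 0:
        return {}
    node_set = set(nodes)
    out_links = {v: set() for v in nodes}
    for s, t in edges:
        # Assumption: ignore orphan edges and self-loops when ranking
        if s in node_set and t in node_set and s != t:
            out_links[s].add(t)
    rank = {v: 1.0 / n for v in nodes}
    for _ in range(max_iter):
        dangling = sum(rank[v] for v in nodes if not out_links[v])
        new = {v: (1.0 - damping) / n + damping * dangling / n for v in nodes}
        for s in nodes:
            if out_links[s]:
                share = damping * rank[s] / len(out_links[s])
                for t in out_links[s]:
                    new[t] += share
        delta = sum(abs(new[v] - rank[v]) for v in nodes)
        rank = new
        if delta < tol:
            break
    return rank


# Hypothetical three-paper example
papers = ["P1", "P2", "P3"]
edges = [("P1", "P2"), ("P3", "P2"), ("P2", "P2"), ("P1", "P9")]

orphans, self_citing = find_anomalies(edges, papers)
scores = pagerank(papers, edges)
print(orphans)      # [{'citing_paper': 'P1', 'cited_paper': 'P9'}]
print(self_citing)  # ['P2']
print(round(sum(scores.values()), 6))  # 1.0
```

Because every iteration redistributes the full rank mass (teleport, dangling, and link shares), the scores sum to 1 by construction, which is what `test_pagerank_scores_sum_to_one` checks.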

In [None]:
# Run all unit tests
def run_tests():
    """Run all unit tests and report results."""
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    
    suite.addTests(loader.loadTestsFromTestCase(TestDataLoading))
    suite.addTests(loader.loadTestsFromTestCase(TestEntityExtraction))
    suite.addTests(loader.loadTestsFromTestCase(TestEntityResolution))
    suite.addTests(loader.loadTestsFromTestCase(TestCitationNetwork))
    suite.addTests(loader.loadTestsFromTestCase(TestValidationResults))
    suite.addTests(loader.loadTestsFromTestCase(TestSummaryStats))
    suite.addTests(loader.loadTestsFromTestCase(TestFinalReport))
    
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    
    print("\n" + "="*50)
    print(f"Tests run: {result.testsRun}")
    print(f"Failures: {len(result.failures)}")
    print(f"Errors: {len(result.errors)}")
    print(f"Success: {result.wasSuccessful()}")
    
    return result

# Execute tests
try:
    test_result = run_tests()
except Exception as e:
    print(f"Error running tests: {e}")
    print("Some required variables may not be defined.")
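
The `TestFinalReport` checks above look only at the top-level sections of `final_report`, while the prompt asks for an exact schema. A sketch of a stricter check that compares the nested keys as well; the key sets are copied from the prompt's Final Report section, and `REPORT_SCHEMA`, `schema_problems`, and the example report are hypothetical:

```python
from typing import Any, Dict, List

REPORT_SCHEMA = {
    "metadata": {"task", "papers_analyzed", "execution_timestamp"},
    "entity_extraction": {"authors", "institutions", "topics"},
    "citation_analysis": {
        "total_citations", "top_10_cited_papers", "orphan_citations",
        "self_citations", "network_statistics",
    },
    "data_quality": {
        "missing_abstracts", "missing_keywords",
        "missing_institutions", "duplicate_author_entries",
    },
    "validation_summary": {"all_checks_passed", "failed_checks"},
}


def schema_problems(report: Dict[str, Any]) -> List[str]:
    """Return human-readable schema deviations; an empty list means an exact match."""
    problems: List[str] = []
    for section, expected in REPORT_SCHEMA.items():
        value = report.get(section)
        if not isinstance(value, dict):
            problems.append(f"{section}: missing or not a dict")
            continue
        missing = sorted(expected - set(value))
        extra = sorted(set(value) - expected)
        if missing:
            problems.append(f"{section}: missing keys {missing}")
        if extra:
            problems.append(f"{section}: unexpected keys {extra}")
    unknown = sorted(set(report) - set(REPORT_SCHEMA))
    if unknown:
        problems.append(f"unexpected sections {unknown}")
    return problems


# Hypothetical, deliberately incomplete report
example_report = {
    "metadata": {"task": "demo", "papers_analyzed": 0, "execution_timestamp": "n/a"},
    "entity_extraction": {"authors": [], "institutions": []},
    "citation_analysis": "not a dict",
}
for problem in schema_problems(example_report):
    print(problem)
```

A test case could then assert `schema_problems(final_report) == []`, which reports every deviation in a single failure message.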

## Final Summary

In [None]:
# Final summary
print("="*60)
print("BENCHMARK EXECUTION SUMMARY")
print("="*60)

try:
    print(f"\nAgent Model: {MODEL_NAME}")
    print(f"Papers Analyzed: {len(papers_df)}")
    print(f"Citations Processed: {len(citations_df)}")
    print(f"\nEntity Resolution:")
    print(f"  Authors: {len(extracted_authors)} raw -> {resolved_author_count} resolved")
    print(f"  Institutions: {len(extracted_institutions)} raw -> {resolved_institution_count} resolved")
    print(f"\nCitation Network:")
    print(f"  Orphan citations found: {len(orphan_citations)}")
    print(f"  Self-citations found: {len(self_citations)}")
    print(f"  PageRank sum: {sum(pagerank_scores.values()):.4f}")
    print(f"\nValidation Summary:")
    failed = [k for k, v in validation_results.items() if not v]
    if failed:
        print(f"  FAILED checks: {failed}")
    else:
        print("  ALL CHECKS PASSED")
    print(f"\nTest Results:")
    print(f"  Tests run: {test_result.testsRun}")
    print(f"  Failures: {len(test_result.failures)}")
    print(f"  Errors: {len(test_result.errors)}")
    
    if test_result.wasSuccessful() and not failed:
        print("\n" + "="*60)
        print("BENCHMARK COMPLETED SUCCESSFULLY!")
        print("="*60)
    else:
        print("\n" + "="*60)
        print("BENCHMARK COMPLETED WITH ISSUES")
        print("="*60)
except Exception as e:
    print(f"\nError generating summary: {e}")