In [None]:
# Cell 1: GitHub Setup and Enhanced Import Configuration

# Install required packages first
!pip install edgartools transformers torch requests beautifulsoup4 'lxml[html_clean]' uuid numpy newspaper3k --quiet

import os
import sys
import importlib
import importlib.util
import psycopg2
from edgar import set_identity

# GitHub credentials - use Kaggle secrets for security
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
github_token = user_secrets.get_secret("GITHUB_TOKEN")
# NOTE(review): the token is embedded in the clone URL. `git clone` normally records this URL as
# the `origin` remote in {local_path}/.git/config, so the token stays readable on disk after the
# clone. Consider a git credential helper instead of an inline token.
repo_url = f"https://{github_token}@github.com/amiralpert/SmartReach.git"
local_path = "/kaggle/working/SmartReach"

print("📦 Setting up GitHub repository...")

# Clone or update repo with force pull.
# Existing checkout: fetch, then hard-reset to origin/main (local changes are discarded).
# The update path uses the remote already stored in the checkout; `repo_url` is only used to clone.
# NOTE(review): IPython `!` commands do not raise on a non-zero exit status, so the
# "✅" messages below are printed even if a git command failed.
if os.path.exists(local_path):
    print(f"📂 Repository exists at {local_path}")
    print("🔄 Force updating from GitHub...")
    !cd {local_path} && git fetch origin
    !cd {local_path} && git reset --hard origin/main
    # Largely redundant after `reset --hard origin/main`; kept as a belt-and-braces update.
    !cd {local_path} && git pull origin main
    print("✅ Repository updated")
    
    # Show current commit
    !cd {local_path} && echo "Current commit:" && git log --oneline -1
else:
    print(f"📥 Cloning repository to {local_path}")
    !git clone {repo_url} {local_path}
    print("✅ Repository cloned")

# Clear cached project modules from previous runs so code pulled from GitHub is re-imported.
# Only modules whose source file lives inside the cloned repository are evicted. A bare
# substring test on the module name (e.g. 'clean') would also evict third-party modules such
# as `lxml_html_clean` / `lxml.html.clean` (pulled in by the pip install above) once imported.
repo_root = os.path.abspath(local_path) + os.sep
modules_to_clear = [
    name
    for name, module in list(sys.modules.items())
    if ('auto_logger' in name.lower() or 'clean' in name.lower())
    and (getattr(module, '__file__', None) or '').startswith(repo_root)
]
for mod in modules_to_clear:
    del sys.modules[mod]
    print(f"  🧹 Cleared cached module: {mod}")

# Put the project package directory first on the import path (move it if already present).
bizintel_path = f'{local_path}/BizIntel'
if bizintel_path in sys.path:
    sys.path.remove(bizintel_path)
sys.path.insert(0, bizintel_path)

print("✓ Python path configured for SEC entity extraction!")

# Configure EdgarTools authentication - REQUIRED by SEC
set_identity("SmartReach BizIntel amir@leanbio.consulting")
print("🆔 EdgarTools identity configured: SmartReach BizIntel amir@leanbio.consulting")

# Set up database configuration.
# The password must not live in notebook source, because notebooks get shared and committed.
# It is read from the NEON_PASSWORD environment variable when set. Otherwise it comes from the
# Kaggle secret of the same name, which must be configured for this notebook.
# Any password that was previously committed in this cell should be treated as leaked and rotated.
NEON_CONFIG = {
    'host': 'ep-royal-star-ad1gn0d4-pooler.c-2.us-east-1.aws.neon.tech',
    'database': 'BizIntelSmartReach',
    'user': 'neondb_owner',
    'password': os.environ.get('NEON_PASSWORD') or user_secrets.get_secret("NEON_PASSWORD"),
    'sslmode': 'require'
}

# Set up the new clean auto-logger.
# When setup succeeds, `logger_conn` is handed to setup_clean_logging and is therefore left open
# (presumably the logger writes through it for the rest of the session). If setup does not
# complete (module missing or any error), the connection is closed rather than left dangling.
logger = None
logger_conn = None
logger_ready = False
try:
    # Create connection for logger
    logger_conn = psycopg2.connect(**NEON_CONFIG)
    print("✓ Database connected for clean logger")

    # Import the redesigned clean auto-logger directly from the repository file
    logger_module_path = f"{local_path}/BizIntel/Scripts/KaggleLogger/auto_logger.py"
    if os.path.exists(logger_module_path):
        spec = importlib.util.spec_from_file_location("auto_logger", logger_module_path)
        auto_logger_module = importlib.util.module_from_spec(spec)
        # Register before executing, as in the importlib "import a source file directly" recipe
        sys.modules["auto_logger"] = auto_logger_module
        spec.loader.exec_module(auto_logger_module)

        # Use the new clean logging setup
        setup_clean_logging = auto_logger_module.setup_clean_logging
        logger = setup_clean_logging(logger_conn, "SEC_EntityExtraction")
        logger_ready = True

        print("✨ Clean auto-logging enabled!")
        print("📋 Features:")
        print("   • One row per cell execution")
        print("   • Complete output capture")
        print("   • Proper cell numbers from # Cell N: comments")
        print("   • Full error tracebacks")
        print("   • Execution timing")

    else:
        print(f"✗ Clean auto-logger not found at {logger_module_path}")
        logger = None

except Exception as e:
    print(f"⚠️ Clean logger setup failed: {e}")
    print("  Continuing without auto-logging...")
    logger = None

# Release the connection if no logger ended up owning it
if not logger_ready and logger_conn is not None:
    logger_conn.close()

print("\n🚀 Setup complete! SEC Entity Extraction Engine with Enhanced EdgarTools Integration ready.")
print("💡 Run cells with proper # Cell N: comments for best logging.")
print("🔧 EdgarTools configured with section extraction capabilities.")

In [None]:
# Cell 2: Database Connection Test

import psycopg2

# Test database connection
def test_database_connection():
    """Print row counts for the SEC-related tables and a sample of the filing distribution.

    Opens its own connection from NEON_CONFIG and always closes it, including when a
    query fails.

    Returns:
        True if every query succeeded; False otherwise (the error is printed, not raised).
    """
    conn = None
    try:
        conn = psycopg2.connect(**NEON_CONFIG)
        # psycopg2 cursors are closed automatically when the `with` block exits.
        with conn.cursor() as cursor:
            # Check SEC-related tables
            cursor.execute('''
                SELECT 
                    (SELECT COUNT(*) FROM raw_data.sec_filings) as sec_filings,
                    (SELECT COUNT(*) FROM core.companies) as companies,
                    (SELECT COUNT(*) FROM system_uno.sec_entities_raw) as sec_entities_extracted,
                    (SELECT COUNT(DISTINCT company_domain) FROM raw_data.sec_filings) as companies_with_filings,
                    (SELECT COUNT(*) FROM raw_data.sec_filings WHERE url IS NOT NULL) as filings_with_urls
            ''')

            counts = cursor.fetchone()
            print("✓ Database connected successfully!")
            print(f"  SEC Filings: {counts[0]}")
            print(f"  Companies: {counts[1]}")
            print(f"  Extracted SEC Entities: {counts[2]}")
            print(f"  Companies with SEC Filings: {counts[3]}")
            print(f"  SEC Filings with URLs: {counts[4]}")

            # Show sample SEC filing data
            cursor.execute('''
                SELECT company_domain, filing_type, COUNT(*) as count
                FROM raw_data.sec_filings 
                GROUP BY company_domain, filing_type 
                ORDER BY company_domain, count DESC
                LIMIT 10
            ''')
            filing_stats = cursor.fetchall()

        print("\n📊 SEC Filing Distribution:")
        for stat in filing_stats:
            print(f"  {stat[0]}: {stat[1]} ({stat[2]} filings)")

        return True

    except Exception as e:
        print(f"✗ Database connection failed: {e}")
        return False

    finally:
        # In psycopg2, `with conn:` only ends the transaction, so close the connection explicitly.
        if conn is not None:
            conn.close()


# Test connection
test_database_connection()

In [None]:
# Cell 3: Accession-Based SEC Content Extraction with EdgarTools

import edgar
from edgar import Filing, find, set_identity, Company
from edgar.documents import parse_html
from edgar.documents.extractors.section_extractor import SectionExtractor
import requests
import re
from typing import Dict, List, Optional, Tuple
from bs4 import BeautifulSoup

# Ensure identity is set. Repeated from Cell 1 so this cell is self-sufficient: the SEC requires
# a declared identity (name + email) before EdgarTools is used to fetch filings.
set_identity("SmartReach BizIntel amir@leanbio.consulting")

def extract_accession_from_url(url: Optional[str]) -> Optional[str]:
    """Extract the SEC accession number from a filing URL.

    EDGAR archive URLs typically look like
    ``https://www.sec.gov/Archives/edgar/data/<CIK>/<ACCESSION_WITHOUT_DASHES>/<filename>``,
    and index pages also embed the dashed form ``0001234567-23-000001``.

    Args:
        url: Filing URL. ``None`` or an empty string is accepted and yields ``None``.

    Returns:
        The accession number in dashed form (``##########-##-######``), or ``None`` if the
        URL contains neither the dashed nor the 18-digit compressed form.
    """
    if not url:
        return None

    # Dashed form anywhere in the URL, e.g. ".../0001234567-23-000001-index.htm"
    accession_match = re.search(r'(\d{10}-\d{2}-\d{6})', url)
    if accession_match:
        return accession_match.group(1)

    # Compressed 18-digit path segment, e.g. ".../000123456723000001/filename". The segment may
    # also be the last one in the URL (directory URLs without a trailing slash) or be followed
    # by a query string / fragment.
    compressed_match = re.search(r'/(\d{18})(?=[/?#]|$)', url)
    if compressed_match:
        accession = compressed_match.group(1)
        return f"{accession[:10]}-{accession[10:12]}-{accession[12:]}"

    return None

def _coerce_section_text(obj) -> Optional[str]:
    """Return ``obj.text`` as a string, calling it first when it is a method.

    Handles both a ``text()`` method and a plain ``text`` attribute. Returns ``None`` when
    the attribute is missing or does not produce a ``str``.
    """
    text = getattr(obj, 'text', None)
    if callable(text):
        text = text()
    return text if isinstance(text, str) else None


def get_filing_sections(accession_number: str, filing_type: Optional[str] = None) -> Dict[str, str]:
    """Fetch a filing by accession number and split it into named text sections.

    Uses EdgarTools' ``SectionExtractor``. If that yields no usable sections, the whole
    document text is returned under the key ``'full_document'``, but only when it has more
    than 100 characters after stripping.

    Args:
        accession_number: Dashed accession number, e.g. ``'0001234567-23-000001'``.
        filing_type: Form type passed to ``SectionExtractor``. When falsy it is taken from
            the filing's ``form`` attribute, falling back to ``'10-K'``.

    Returns:
        Mapping of section name -> stripped section text. Any error is printed and an empty
        dict is returned; the caller treats that as "no sections".
    """
    try:
        # Find filing using accession number
        filing = find(accession_number)
        if not filing:
            raise ValueError(f"Filing not found for accession: {accession_number}")

        # Auto-detect filing type if not provided (a missing/None form falls back to 10-K)
        if not filing_type:
            filing_type = getattr(filing, 'form', None) or '10-K'

        print(f"   📄 Found {filing_type} for {getattr(filing, 'company', 'Unknown Company')}")

        # Get structured HTML content
        html_content = filing.html()
        if not html_content:
            raise ValueError("No HTML content available")

        # Parse HTML to a Document and extract sections with SectionExtractor
        document = parse_html(html_content)
        sections = SectionExtractor(filing_type=filing_type).extract(document)
        print(f"   📑 SectionExtractor found {len(sections)} sections: {list(sections.keys())}")

        # Convert sections to a name -> text dictionary. Objects exposing `text` use it;
        # anything else is converted with str().
        section_texts = {}
        for section_name, section in sections.items():
            try:
                if hasattr(section, 'text'):
                    text = (_coerce_section_text(section) or '').strip()
                    source_note = ''
                else:
                    text = str(section).strip()
                    source_note = ' (via str)'
                if text:
                    section_texts[section_name] = text
                    print(f"      • {section_name}: {len(text):,} chars{source_note}")
            except Exception as section_e:
                # One bad section must not lose the others
                print(f"      ⚠️ Could not extract {section_name}: {section_e}")

        # If SectionExtractor returns no sections, fall back to full document text
        if not section_texts:
            print(f"   🔄 No structured sections found, using full document as 'full_document' section")
            full_text = _coerce_section_text(document)
            if full_text is None:
                full_text = str(document)
            full_text = full_text.strip()
            if len(full_text) > 100:  # Only use if substantial content
                section_texts['full_document'] = full_text
                print(f"      • full_document: {len(full_text):,} chars")

        return section_texts

    except Exception as e:
        print(f"   ❌ Section extraction failed: {e}")
        return {}

def route_sections_to_models(
    sections: Dict[str, str],
    filing_type: Optional[str],
    financial_keywords: Tuple[str, ...] = ('financial', 'statement', 'balance', 'income', 'cash'),
) -> Dict[str, List[str]]:
    """Decide which NER models should see each section of a filing.

    Rules (form type is matched case-insensitively):
      * 10-K / 10-Q: sections whose name contains any of ``financial_keywords`` go only to
        FinBERT; every other section goes to BioBERT, BERT-base and RoBERTa.
      * 8-K: every section goes to all four models.
      * Anything else, including a missing (``None``) form type: BioBERT, BERT-base, RoBERTa.

    The default keyword list matches the one used by
    ``EntityExtractionPipeline.get_routing_for_content``, so the routing reported here agrees
    with the routing the pipeline actually applies.

    Args:
        sections: Mapping of section name -> text (only the names are used).
        filing_type: SEC form type, e.g. ``'10-K'``; may be ``None``.
        financial_keywords: Lower-case substrings marking a financial-statement section.

    Returns:
        Model key -> list of section names, omitting models with no sections.
    """
    text_models = ('biobert', 'bert_base', 'roberta')
    all_models = ('biobert', 'bert_base', 'roberta', 'finbert')
    routing: Dict[str, List[str]] = {model: [] for model in all_models}
    form = (filing_type or '').upper()

    for section_name in sections:
        if form in ('10-K', '10-Q'):
            lowered = section_name.lower()
            if any(keyword in lowered for keyword in financial_keywords):
                targets = ('finbert',)  # FinBERT gets financial statements exclusively
            else:
                targets = text_models
        elif form == '8-K':
            targets = all_models
        else:
            targets = text_models

        for model in targets:
            routing[model].append(section_name)

    # Remove empty routing
    return {model: names for model, names in routing.items() if names}

def process_sec_filing_with_sections(filing_data: Dict) -> Dict:
    """Run one SEC filing row through accession lookup, section extraction and model routing.

    Args:
        filing_data: Row dict such as those returned by get_unprocessed_filings; the keys
            read here are 'id', 'url', 'filing_type' and 'company_domain'.

    Returns:
        On success: filing identifiers plus 'accession_number', 'sections', 'model_routing',
        'total_sections' and ``'processing_status': 'success'``.
        On any exception: filing identifiers, the 'error' message and
        ``'processing_status': 'failed'``. Exceptions raised inside the main body are not propagated.
    """
    try:
        filing_id = filing_data.get('id')
        filing_url = filing_data.get('url')
        filing_type = filing_data.get('filing_type', '10-K')
        company_domain = filing_data.get('company_domain', 'Unknown')

        for banner_line in (
            f"🏢 Processing {filing_type} for {company_domain}",
            f"   📄 Filing ID: {filing_id}",
            f"   🔗 URL: {filing_url}",
        ):
            print(banner_line)

        # Step 1: accession number from the filing URL
        accession_number = extract_accession_from_url(filing_url)
        if not accession_number:
            raise ValueError(f"Could not extract accession from URL: {filing_url}")
        print(f"   📋 Accession: {accession_number}")

        # Step 2: structured sections via EdgarTools
        sections = get_filing_sections(accession_number, filing_type)
        if not sections:
            raise ValueError("No sections extracted")
        section_count = len(sections)
        print(f"   📑 Extracted {section_count} sections")

        # Step 3: which models see which sections
        model_routing = route_sections_to_models(sections, filing_type)
        routing_summary = [f'{model}: {len(secs)} sections' for model, secs in model_routing.items()]
        print(f"   🎯 Model routing: {routing_summary}")

        result = {
            'filing_id': filing_id,
            'company_domain': company_domain,
            'filing_type': filing_type,
            'accession_number': accession_number,
            'url': filing_url,
            'sections': sections,
            'model_routing': model_routing,
            'total_sections': section_count,
            'processing_status': 'success',
        }
        return result

    except Exception as exc:
        print(f"   ❌ Processing failed: {exc}")
        failure = {
            'filing_id': filing_data.get('id'),
            'company_domain': filing_data.get('company_domain', 'Unknown'),
            'filing_type': filing_data.get('filing_type', 'Unknown'),
            'url': filing_data.get('url'),
            'error': str(exc),
            'processing_status': 'failed',
        }
        return failure

def get_unprocessed_filings(limit: int = 5) -> List[Dict]:
    """Fetch SEC filings that have a URL but no extracted entities yet.

    A filing counts as processed once any row in system_uno.sec_entities_raw has
    ``sec_filing_ref = 'SEC_<id>'``.

    Args:
        limit: Maximum number of filings to return, newest ``filing_date`` first.

    Returns:
        List of dicts with keys id, company_domain, filing_type, url, filing_date, title.
        If the query fails, the error is printed and an empty list is returned.
    """
    conn = None
    try:
        conn = psycopg2.connect(**NEON_CONFIG)
        # psycopg2 cursors are closed automatically when the `with` block exits
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT sf.id, sf.company_domain, sf.filing_type, sf.url, sf.filing_date, sf.title
                FROM raw_data.sec_filings sf
                LEFT JOIN system_uno.sec_entities_raw ser ON ser.sec_filing_ref = CONCAT('SEC_', sf.id)
                WHERE sf.url IS NOT NULL 
                AND ser.sec_filing_ref IS NULL
                ORDER BY sf.filing_date DESC
                LIMIT %s
            """, (limit,))
            rows = cursor.fetchall()

    except Exception as e:
        print(f"❌ Database query failed: {e}")
        return []

    finally:
        # Always release the connection, including when the query raised
        if conn is not None:
            conn.close()

    columns = ('id', 'company_domain', 'filing_type', 'url', 'filing_date', 'title')
    return [dict(zip(columns, row)) for row in rows]

# Smoke-test the accession-based extraction on the newest unprocessed filing, if there is one.
print("🧪 Testing accession-based section extraction...")

test_filings = get_unprocessed_filings(limit=1)

if not test_filings:
    print("📭 No unprocessed filings found for testing")
else:
    print(f"📄 Found {len(test_filings)} test filing(s)")

    for filing in test_filings:
        result = process_sec_filing_with_sections(filing)

        if result['processing_status'] != 'success':
            print(f"   ❌ Failed: {result.get('error', 'Unknown error')}")
            continue

        section_names = list(result['sections'])
        print(f"   ✅ Success: {result['total_sections']} sections extracted")
        print(f"   📊 Available sections: {section_names}")

        # Per-model routing summary
        for model, section_list in result['model_routing'].items():
            print(f"   🎯 {model}: {len(section_list)} sections")

        # Preview the opening of the first extracted section
        first_section = section_names[0]
        sample_text = f"{result['sections'][first_section][:200]}..."
        print(f"   📝 Sample from '{first_section}': {sample_text}")

print("\n".join([
    "\n✅ Accession-based SEC section extraction ready!",
    "🔧 Features:",
    "   • Accession number extraction from URLs",
    "   • EdgarTools section extraction with SectionExtractor",
    "   • Fallback to full document if no sections found",
    "   • Automatic model routing based on filing type",
    "   • 10-K/10-Q: Financial sections → FinBERT, Others → BERT/RoBERTa/BioBERT",
    "   • 8-K: All sections → All models",
    "   • Section-level content targeting for optimal entity extraction",
]))

In [None]:
# Cell 4: EntityExtractionPipeline Class and NER Model Loading

import torch
import time
import uuid
import json
from datetime import datetime
from typing import List, Dict, Any, Optional
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed

# Announce the cell; the pipeline class is defined next and its models are loaded at the end of the cell.
print("🚀 Loading EntityExtractionPipeline and NER Models...")

class EntityExtractionPipeline:
    """Extensible pipeline for entity extraction from multiple data sources.

    Holds a model configuration and loads the configured Hugging Face ``ner`` pipelines. The
    configuration covers the NER checkpoints, per-source routing rules and label normalisation.
    The pipeline runs the models over text (chunked so long sections fit the encoders) and
    merges detections that several models made at the same character span.

    NOTE(review): ``ProsusAI/finbert`` is published as a sentiment (sequence-classification)
    checkpoint. Loading it with ``AutoModelForTokenClassification`` presumably initialises a new,
    untrained token-classification head, so its "entities" may be meaningless. Confirm this,
    and consider a financial NER checkpoint instead.
    """

    # Size of the character windows fed to a model in one call. Intended to keep each call well
    # under the ~512-token input limit of BERT/RoBERTa-family encoders.
    # TODO(review): lower this if a model still reports truncation on dense (numeric) text.
    DEFAULT_CHUNK_CHARS = 1000
    # Characters shared by consecutive windows so entities near a cut are still seen whole.
    DEFAULT_CHUNK_OVERLAP = 100

    def __init__(self, db_config: Dict, model_config: Dict = None):
        """Store configuration and initialise counters; no models are loaded here.

        Args:
            db_config: psycopg2 connection parameters (stored for callers; unused in this class).
            model_config: Full model/routing/label configuration; defaults to
                ``_get_default_model_config()``.
        """
        self.db_config = db_config
        self.model_config = model_config or self._get_default_model_config()
        self.loaded_models = {}
        self.pipeline_stats = {
            'documents_processed': 0,
            'total_entities_extracted': 0,
            'processing_time_total': 0,
            'data_sources_supported': ['sec_filings', 'press_releases', 'patents']
        }

        print(f"🔧 EntityExtractionPipeline initialized")
        print(f"   📊 Supported data sources: {self.pipeline_stats['data_sources_supported']}")

    def _get_default_model_config(self) -> Dict:
        """Default model configuration for biotech entity extraction"""
        return {
            'models': {
                'biobert': {
                    'model_name': 'alvaroalon2/biobert_diseases_ner',
                    'confidence_threshold': 0.5,
                    'description': 'Biomedical entities (diseases, medications, treatments)'
                },
                'bert_base': {
                    'model_name': 'dslim/bert-base-NER',
                    'confidence_threshold': 0.5,
                    'description': 'General entities (persons, organizations, locations)'
                },
                'finbert': {
                    'model_name': 'ProsusAI/finbert',
                    'confidence_threshold': 0.5,
                    'description': 'Financial entities and metrics'
                },
                'roberta': {
                    'model_name': 'Jean-Baptiste/roberta-large-ner-english',
                    'confidence_threshold': 0.6,
                    'description': 'High-precision general entities'
                }
            },
            'routing_rules': {
                'sec_filings': {
                    '10-K': {
                        'financial_sections': ['finbert'],
                        'other_sections': ['biobert', 'bert_base', 'roberta']
                    },
                    '10-Q': {
                        'financial_sections': ['finbert'],
                        'other_sections': ['biobert', 'bert_base', 'roberta']
                    },
                    '8-K': {
                        'all_sections': ['biobert', 'bert_base', 'roberta', 'finbert']
                    }
                },
                'press_releases': {
                    'all_content': ['biobert', 'bert_base', 'roberta']
                },
                'patents': {
                    'all_content': ['biobert', 'bert_base', 'roberta']
                }
            },
            'entity_type_mapping': {
                # BioBERT mappings
                'Disease': 'MEDICAL_CONDITION',
                'Chemical': 'MEDICATION',
                'CHEMICAL': 'MEDICATION',
                'DISEASE': 'MEDICAL_CONDITION',
                'DRUG': 'MEDICATION',
                'Drug': 'MEDICATION',
                'Compound': 'MEDICATION',
                'Treatment': 'THERAPY',
                'Therapy': 'THERAPY',

                # BERT-base mappings
                'PER': 'PERSON',
                'ORG': 'ORGANIZATION',
                'LOC': 'LOCATION',
                'MISC': 'MISCELLANEOUS',

                # Financial mappings
                'MONEY': 'FINANCIAL',
                'PERCENT': 'FINANCIAL',
                'NUMBER': 'METRIC'
            }
        }

    def load_models(self) -> Dict:
        """Load every NER model listed in ``self.model_config['models']``.

        Every model is attempted first. Afterwards, each model that failed to load gets a
        BERT-base fallback, provided BERT-base itself loaded. A fallback shares BERT-base's
        pipeline and threshold but keeps its own stats counters.

        Returns:
            ``self.loaded_models``: model key -> {'pipeline', 'threshold', 'description', 'stats'}.
        """
        configured = self.model_config['models']
        print(f"📦 Loading {len(configured)} NER models...")

        failed_models = []
        for model_key, model_info in configured.items():
            try:
                model_name = model_info['model_name']
                print(f"   🧠 Loading {model_key}: {model_name}")

                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForTokenClassification.from_pretrained(model_name)
                ner_pipeline = pipeline(
                    "ner",
                    model=model,
                    tokenizer=tokenizer,
                    aggregation_strategy="average",
                    device=0 if torch.cuda.is_available() else -1
                )

                self.loaded_models[model_key] = {
                    'pipeline': ner_pipeline,
                    'threshold': model_info['confidence_threshold'],
                    'description': model_info['description'],
                    'stats': {
                        'texts_processed': 0,
                        'entities_found': 0,
                        'processing_time': 0
                    }
                }

                print(f"      ✓ {model_key} loaded (threshold: {model_info['confidence_threshold']})")

            except Exception as e:
                print(f"      ❌ {model_key} failed: {e}")
                failed_models.append(model_key)

        # Fallbacks are applied after the loop so that models listed before 'bert_base'
        # (e.g. 'biobert') can also fall back to it.
        if 'bert_base' in self.loaded_models:
            for model_key in failed_models:
                fallback = dict(self.loaded_models['bert_base'])
                fallback['stats'] = {'texts_processed': 0, 'entities_found': 0, 'processing_time': 0}
                self.loaded_models[model_key] = fallback
                print(f"      🔄 Using BERT-base fallback for {model_key}")

        loaded_count = len(self.loaded_models)
        device = "GPU" if torch.cuda.is_available() else "CPU"

        print(f"   ✅ Successfully loaded: {loaded_count}/{len(configured)} models on {device}")

        return self.loaded_models

    def get_routing_for_content(self, data_source: str, filing_type: str = None, section_name: str = None) -> List[str]:
        """Return the model keys that should process a piece of content.

        Args:
            data_source: Key into ``routing_rules`` (e.g. 'sec_filings', 'press_releases').
                Unknown sources use BioBERT, BERT-base and RoBERTa.
            filing_type: SEC form type, matched case-insensitively; forms without their own
                rule use the 10-K rules.
            section_name: Section title. When the form's rules split financial and other
                sections, titles containing a financial keyword go to 'financial_sections'.

        Returns:
            A new list of model keys (safe for the caller to modify).
        """
        default_models = ['biobert', 'bert_base', 'roberta']
        source_rules = self.model_config['routing_rules'].get(data_source)
        if source_rules is None:
            # Default to all text models for unknown data sources
            return default_models

        if data_source == 'sec_filings' and filing_type:
            filing_rules = source_rules.get(filing_type.upper(), source_rules.get('10-K', {}))

            # Forms such as 8-K send every section to the same model set
            if 'all_sections' in filing_rules:
                return list(filing_rules['all_sections'])

            # Check if it's a financial section
            is_financial = bool(section_name) and any(
                fin_word in section_name.lower()
                for fin_word in ['financial', 'statement', 'balance', 'income', 'cash']
            )
            if is_financial:
                return list(filing_rules.get('financial_sections', ['finbert']))
            return list(filing_rules.get('other_sections', default_models))

        # Sources routed as a whole: press releases, patents, or any source registered through
        # add_data_source_support with an 'all_content' rule.
        if 'all_content' in source_rules:
            return list(source_rules['all_content'])

        # Default routing
        return default_models

    @staticmethod
    def _iter_text_chunks(text: str, max_chars: int, overlap: int):
        """Yield ``(offset, window)`` pairs that cover ``text`` from start to end.

        Each window is at most ``max_chars`` long. Where possible a window ends just after the
        last whitespace in its second half, so words are not split. The next window starts up to
        ``overlap`` characters earlier, moved forward to the start of a word. Every window starts
        strictly after the previous one, so the generator always terminates.
        """
        max_chars = max(1, max_chars)
        overlap = max(0, overlap)
        length = len(text)
        start = 0
        while True:
            end = min(start + max_chars, length)
            if end < length:
                # Search only the second half so windows never shrink below half size
                for cut in range(end, start + max_chars // 2, -1):
                    if text[cut - 1].isspace():
                        end = cut
                        break
            yield start, text[start:end]
            if end >= length:
                return
            next_start = max(end - overlap, start + 1)
            # Do not begin a window in the middle of a word
            while next_start < end and not text[next_start - 1].isspace():
                next_start += 1
            start = next_start

    def extract_entities_from_text(self, text: str, model_name: str, metadata: Dict = None,
                                   max_chunk_chars: int = None, chunk_overlap: int = None) -> List[Dict]:
        """Extract entities from ``text`` with one loaded model.

        Text longer than one window is split with ``_iter_text_chunks`` and each window is run
        separately. Returned character offsets always refer to the full ``text``. When the
        same label is found at the same span by overlapping windows, it is reported once
        (highest score kept).

        Args:
            text: Input text; ``None`` or blank text yields ``[]``.
            model_name: Key into ``self.loaded_models``; unknown keys yield ``[]``.
            metadata: Extra key/values copied into every returned entity dict.
            max_chunk_chars: Window size in characters (default ``DEFAULT_CHUNK_CHARS``).
            chunk_overlap: Overlap between windows (default ``DEFAULT_CHUNK_OVERLAP``).

        Returns:
            Entity dicts whose score is at or above the model's threshold. If the model
            raises, the error is printed and ``[]`` is returned.
        """
        if model_name not in self.loaded_models or not text or not text.strip():
            return []

        window = max_chunk_chars or self.DEFAULT_CHUNK_CHARS
        overlap = self.DEFAULT_CHUNK_OVERLAP if chunk_overlap is None else chunk_overlap

        try:
            start_time = time.time()
            model_info = self.loaded_models[model_name]
            label_map = self.model_config['entity_type_mapping']

            # (global_start, global_end, label) -> best raw detection for that span
            best_by_span = {}
            for offset, chunk in self._iter_text_chunks(text, window, overlap):
                if not chunk.strip():
                    continue
                for entity in model_info['pipeline'](chunk):
                    if entity['score'] < model_info['threshold']:
                        continue
                    span_key = (entity['start'] + offset, entity['end'] + offset, entity['entity_group'])
                    kept = best_by_span.get(span_key)
                    if kept is None or entity['score'] > kept['score']:
                        best_by_span[span_key] = entity

            processed_entities = []
            for (char_start, char_end, label), entity in best_by_span.items():
                processed_entity = {
                    'entity_text': entity['word'].strip(),
                    # Normalize entity type
                    'entity_type': label_map.get(label, label),
                    'confidence_score': float(entity['score']),
                    'char_start': char_start,
                    'char_end': char_end,
                    'model_source': model_name,
                    'original_label': label,
                    'extraction_id': str(uuid.uuid4()),
                    'extraction_timestamp': datetime.now().isoformat()
                }
                if metadata:
                    processed_entity.update(metadata)
                processed_entities.append(processed_entity)

            # Update statistics (one processed text per call, regardless of chunk count)
            processing_time = time.time() - start_time
            model_info['stats']['texts_processed'] += 1
            model_info['stats']['entities_found'] += len(processed_entities)
            model_info['stats']['processing_time'] += processing_time

            return processed_entities

        except Exception as e:
            print(f"   ❌ {model_name} extraction failed: {e}")
            return []

    def process_sec_filing_sections(self, filing_result: Dict) -> List[Dict]:
        """Run the routed NER models over every section of one processed SEC filing.

        Args:
            filing_result: Output of ``process_sec_filing_with_sections``; anything whose
                'processing_status' is not 'success' yields ``[]``.

        Returns:
            Entities from all sections after merging same-span detections across models.
            Also updates ``pipeline_stats`` (documents, entities, elapsed time).
        """
        if filing_result.get('processing_status') != 'success':
            return []

        start_time = time.time()
        filing_id = filing_result['filing_id']
        filing_type = filing_result['filing_type']
        company_domain = filing_result['company_domain']
        sections = filing_result['sections']

        print(f"🔍 Processing {len(sections)} sections for {company_domain} {filing_type}")

        all_entities = []

        for section_name, section_text in sections.items():
            # Get routing for this section, limited to models that actually loaded
            applicable_models = self.get_routing_for_content('sec_filings', filing_type, section_name)
            applicable_models = [m for m in applicable_models if m in self.loaded_models]

            if not applicable_models:
                continue

            print(f"   📑 Processing '{section_name}' with {len(applicable_models)} models")

            # Metadata for this section
            section_metadata = {
                'filing_id': filing_id,
                'company_domain': company_domain,
                'filing_type': filing_type,
                'section_name': section_name,
                'sec_filing_ref': f'SEC_{filing_id}',
                'data_source': 'sec_filings'
            }

            # Process with each applicable model
            for model_name in applicable_models:
                model_entities = self.extract_entities_from_text(section_text, model_name, section_metadata)
                all_entities.extend(model_entities)
                print(f"      • {model_name}: {len(model_entities)} entities")

        # Merge entities at same positions
        merged_entities = self.merge_position_overlaps(all_entities)
        print(f"   🔗 Merged: {len(all_entities)} → {len(merged_entities)} entities")

        # Keep the pipeline-level counters reported by get_pipeline_statistics in sync
        self.pipeline_stats['documents_processed'] += 1
        self.pipeline_stats['total_entities_extracted'] += len(merged_entities)
        self.pipeline_stats['processing_time_total'] += time.time() - start_time

        return merged_entities

    def merge_position_overlaps(self, entities: List[Dict]) -> List[Dict]:
        """Merge entities detected at the same position by different models.

        Entities are grouped by (filing_id, section_name, char_start, char_end). A tuple key is
        used so that different field values can never collide into the same group.

        Returns:
            One entity per position, in first-seen order. Groups of two or more are combined
            by ``_merge_entity_group``.
        """
        if not entities:
            return []

        position_groups: Dict[tuple, List[Dict]] = {}
        for entity in entities:
            pos_key = (
                entity.get('filing_id', 'unknown'),
                entity.get('section_name', 'unknown'),
                entity.get('char_start', 0),
                entity.get('char_end', 0),
            )
            position_groups.setdefault(pos_key, []).append(entity)

        return [
            group[0] if len(group) == 1 else self._merge_entity_group(group)
            for group in position_groups.values()
        ]

    def _merge_entity_group(self, entities: List[Dict]) -> Dict:
        """Merge entities from different models at the same position (input list is not modified)."""
        # Priority for biotech domain: BioBERT > FinBERT > RoBERTa > BERT-base
        priority = {'biobert': 4, 'finbert': 3, 'roberta': 2, 'bert_base': 1}

        # Sort by priority, then by confidence
        ranked = sorted(
            entities,
            key=lambda x: (priority.get(x['model_source'], 0), x['confidence_score']),
            reverse=True,
        )

        # Use best entity as base
        best = ranked[0].copy()

        # Add multi-model metadata
        best['models_detected'] = [e['model_source'] for e in ranked]
        best['all_confidences'] = {e['model_source']: e['confidence_score'] for e in ranked}
        best['primary_model'] = best['model_source']
        best['entity_variations'] = {e['model_source']: e['entity_text'] for e in ranked}
        best['is_merged'] = True
        best['confidence_score'] = max(e['confidence_score'] for e in ranked)

        return best

    def add_data_source_support(self, source_name: str, routing_config: Dict, processor_func: callable = None):
        """Dynamically add support for a new data source.

        Args:
            source_name: New key in ``routing_rules``.
            routing_config: Rules for the source; an ``'all_content'`` list is honoured by
                ``get_routing_for_content``.
            processor_func: Accepted for interface compatibility; currently not used.
        """
        # Add to routing rules
        self.model_config['routing_rules'][source_name] = routing_config

        # Add to supported sources
        if source_name not in self.pipeline_stats['data_sources_supported']:
            self.pipeline_stats['data_sources_supported'].append(source_name)

        print(f"✅ Added support for data source: {source_name}")
        print(f"   📊 Routing config: {routing_config}")

    def get_pipeline_statistics(self) -> Dict:
        """Return a snapshot of pipeline-level and per-model statistics."""
        return {
            'pipeline_stats': self.pipeline_stats.copy(),
            'model_stats': {
                name: info['stats'].copy()
                for name, info in self.loaded_models.items()
            },
            'loaded_models': list(self.loaded_models.keys()),
            'supported_data_sources': list(self.pipeline_stats['data_sources_supported']),
            'device': "GPU" if torch.cuda.is_available() else "CPU"
        }

# Build the pipeline (configuration only, no model I/O yet), then load all NER checkpoints.
entity_pipeline = EntityExtractionPipeline(NEON_CONFIG)
loaded_models = entity_pipeline.load_models()

# Smoke-test each loaded model on one short sentence: print its entity count and
# up to two example entities.
if loaded_models:
    test_text = (
        "Pfizer's COVID-19 vaccine generated $37 billion in revenue. "
        "The FDA approved treatment for Alzheimer's disease in Boston."
    )
    print("\n🧪 Testing models with sample text...")

    for model_name, model_info in loaded_models.items():
        try:
            test_entities = entity_pipeline.extract_entities_from_text(test_text, model_name)
            print(f"   {model_name}: Found {len(test_entities)} entities")
            for entity in test_entities[:2]:
                print(f"      • {entity['entity_type']}: '{entity['entity_text']}' ({entity['confidence_score']:.3f})")
        except Exception as e:
            print(f"   {model_name}: Test failed - {e}")

print("\n".join([
    "\n✅ EntityExtractionPipeline ready!",
    f"   🎯 Loaded models: {list(loaded_models)}",
    f"   🔧 Supported sources: {entity_pipeline.pipeline_stats['data_sources_supported']}",
    f"   💻 Device: {'GPU' if torch.cuda.is_available() else 'CPU'}",
    "   📊 Ready for biotech-optimized entity extraction!",
]))

In [None]:
# Cell 5: GPU-Optimized Entity Storage and Database Integration

import psycopg2
from psycopg2.extras import execute_values
import numpy as np

class PipelineEntityStorage:
    """Enhanced storage system for EntityExtractionPipeline results.

    Writes pipeline entities into ``system_uno.sec_entities_raw`` and reads
    per-filing summaries back for verification. Construction migrates the
    table in place (adds any missing pipeline columns and indexes). Every
    database operation opens its own connection and always closes it, even
    when a query raises.
    """

    # Maximum number of rows sent in a single INSERT statement.
    _INSERT_PAGE_SIZE = 100

    # VALUES template for one entity row: 18 columns in INSERT column order.
    # An explicit NULL would bypass extraction_timestamp's DEFAULT, so a missing
    # timestamp falls back to the database clock instead.
    _ROW_TEMPLATE = '(' + ', '.join(['%s'] * 16) + ', COALESCE(%s, CURRENT_TIMESTAMP), %s)'

    def __init__(self, db_config: Dict):
        """
        Args:
            db_config: Keyword arguments forwarded to ``psycopg2.connect``.
        """
        self.db_config = db_config
        self.storage_stats = {
            'total_entities_stored': 0,
            'filings_processed': 0,
            'merged_entities': 0,
            'single_model_entities': 0,
            'failed_inserts': 0
        }

        # Ensure enhanced table structure (best-effort; failures are printed)
        self._ensure_enhanced_table()

    def _ensure_enhanced_table(self):
        """Add any missing pipeline columns and indexes to sec_entities_raw.

        Best-effort: errors are printed rather than raised, and the connection is
        always closed. An uncommitted transaction is discarded on close.
        """
        conn = None
        try:
            conn = psycopg2.connect(**self.db_config)
            with conn.cursor() as cursor:
                # Check existing columns
                cursor.execute("""
                    SELECT column_name 
                    FROM information_schema.columns 
                    WHERE table_schema = 'system_uno' 
                    AND table_name = 'sec_entities_raw'
                """)
                existing_columns = {row[0] for row in cursor.fetchall()}

                # Required columns for pipeline
                required_columns = {
                    'models_detected': 'TEXT[]',
                    'all_confidences': 'JSONB',
                    'primary_model': 'TEXT',
                    'entity_variations': 'JSONB',
                    'is_merged': 'BOOLEAN DEFAULT FALSE',
                    'section_name': 'TEXT',
                    'data_source': 'TEXT DEFAULT \'sec_filings\'',
                    'extraction_timestamp': 'TIMESTAMP DEFAULT CURRENT_TIMESTAMP',
                    'original_label': 'TEXT'
                }

                # Add missing columns. Names and types are the fixed literals above,
                # never external input, so f-string DDL is safe here.
                for col_name, col_type in required_columns.items():
                    if col_name not in existing_columns:
                        cursor.execute(f'ALTER TABLE system_uno.sec_entities_raw ADD COLUMN {col_name} {col_type}')
                        print(f"   ✓ Added column: {col_name}")

                # Create indexes for performance
                index_queries = [
                    'CREATE INDEX IF NOT EXISTS idx_sec_entities_pipeline_position ON system_uno.sec_entities_raw (sec_filing_ref, section_name, character_start, character_end)',
                    'CREATE INDEX IF NOT EXISTS idx_sec_entities_models ON system_uno.sec_entities_raw USING GIN (models_detected)',
                    'CREATE INDEX IF NOT EXISTS idx_sec_entities_source ON system_uno.sec_entities_raw (data_source, primary_model)',
                    'CREATE INDEX IF NOT EXISTS idx_sec_entities_merged ON system_uno.sec_entities_raw (is_merged, entity_category)'
                ]
                for idx_query in index_queries:
                    cursor.execute(idx_query)

            conn.commit()
            print("✓ Enhanced table structure verified")

        except Exception as e:
            print(f"⚠️ Table enhancement failed: {e}")
        finally:
            if conn is not None:
                conn.close()

    @staticmethod
    def _entity_to_row(entity: Dict) -> tuple:
        """Convert one pipeline entity dict into an 18-value INSERT tuple.

        Missing multi-model fields are synthesized from the single
        ``model_source``. ``None`` text and numeric fields are treated as empty
        or zero instead of raising. A missing ``extraction_id`` gets a fresh
        UUID4, and a missing ``extraction_timestamp`` is passed as ``None``
        (``_ROW_TEMPLATE`` replaces it with ``CURRENT_TIMESTAMP``).
        """
        model_source = entity.get('model_source', 'unknown')

        models_detected = entity.get('models_detected', [model_source])
        if not isinstance(models_detected, list):
            models_detected = [str(models_detected)]

        all_confidences = entity.get('all_confidences', {model_source: entity.get('confidence_score', 0.0)})
        entity_variations = entity.get('entity_variations', {model_source: entity.get('entity_text', '')})

        return (
            entity.get('extraction_id') or str(uuid.uuid4()),
            entity.get('company_domain', ''),
            (entity.get('entity_text') or '').strip()[:1000],
            entity.get('entity_type', 'UNKNOWN'),
            float(entity.get('confidence_score') or 0.0),
            int(entity.get('char_start') or 0),
            int(entity.get('char_end') or 0),
            (entity.get('surrounding_text') or '')[:2000],
            entity.get('sec_filing_ref', ''),
            models_detected,
            json.dumps(all_confidences),
            entity.get('primary_model', model_source),
            json.dumps(entity_variations),
            entity.get('is_merged', False),
            entity.get('section_name', ''),
            entity.get('data_source', 'sec_filings'),
            entity.get('extraction_timestamp'),
            entity.get('original_label', '')
        )

    def store_pipeline_entities(self, entities: List[Dict]) -> bool:
        """Upsert entities from EntityExtractionPipeline into sec_entities_raw.

        Rows conflicting on ``extraction_id`` have their model/merge/section
        fields updated. Inserts are sent in pages of ``_INSERT_PAGE_SIZE`` rows
        and committed once at the end.

        Args:
            entities: Entity dicts produced by the pipeline.

        Returns:
            True on success (or when ``entities`` is empty). False on any
            failure; in that case ``storage_stats['failed_inserts']`` is
            incremented by ``len(entities)``.
        """
        if not entities:
            return True

        conn = None
        try:
            conn = psycopg2.connect(**self.db_config)

            print(f"💾 Storing {len(entities)} pipeline entities...")

            # Debug: check section names in entities
            unique_sections = {e.get('section_name', 'MISSING') for e in entities}
            print(f"   🔍 Section names in entities: {unique_sections}")

            insert_data = []
            merged_count = 0
            for entity in entities:
                if entity.get('is_merged', False):
                    merged_count += 1

                if not entity.get('section_name', ''):
                    # Debug missing section name
                    print(f"   ⚠️ Missing section name for entity: {(entity.get('entity_text') or 'unknown')[:30]}...")

                insert_data.append(self._entity_to_row(entity))

            # Batch insert with enhanced schema
            insert_query = """
                INSERT INTO system_uno.sec_entities_raw 
                (extraction_id, company_domain, entity_text, entity_category, 
                 confidence_score, character_start, character_end, surrounding_text, 
                 sec_filing_ref, models_detected, all_confidences, primary_model,
                 entity_variations, is_merged, section_name, data_source,
                 extraction_timestamp, original_label)
                VALUES %s
                ON CONFLICT (extraction_id) DO UPDATE SET
                    models_detected = EXCLUDED.models_detected,
                    all_confidences = EXCLUDED.all_confidences,
                    primary_model = EXCLUDED.primary_model,
                    entity_variations = EXCLUDED.entity_variations,
                    is_merged = EXCLUDED.is_merged,
                    section_name = EXCLUDED.section_name
            """

            rows_affected = 0
            with conn.cursor() as cursor:
                for offset in range(0, len(insert_data), self._INSERT_PAGE_SIZE):
                    page = insert_data[offset:offset + self._INSERT_PAGE_SIZE]
                    execute_values(cursor, insert_query, page,
                                   template=self._ROW_TEMPLATE, page_size=len(page))
                    # rowcount reflects only the most recent statement, so sum per page
                    rows_affected += cursor.rowcount

            conn.commit()

            # Update statistics
            self.storage_stats['total_entities_stored'] += len(entities)
            self.storage_stats['merged_entities'] += merged_count
            self.storage_stats['single_model_entities'] += len(entities) - merged_count

            print(f"   ✅ Stored {rows_affected} entities")
            print(f"   🔗 Merged: {merged_count}, Single-model: {len(entities) - merged_count}")
            print(f"   📑 Sections represented: {len(unique_sections)} unique section names")

            return True

        except Exception as e:
            print(f"   ❌ Storage failed: {e}")
            self.storage_stats['failed_inserts'] += len(entities)
            return False
        finally:
            if conn is not None:
                conn.close()

    def get_storage_verification(self, sec_filing_ref: str) -> Dict:
        """Summarize stored entities for one filing reference.

        Args:
            sec_filing_ref: Value of ``sec_filing_ref`` to filter on.

        Returns:
            Dict with ``filing_ref``, ``total_entities``, per-category stats and
            per-model stats. NULL averages are reported as 0.0. Returns ``{}``
            (after printing) on any failure.
        """
        def _as_float(value) -> float:
            # AVG() is NULL when every confidence_score in the group is NULL
            return float(value) if value is not None else 0.0

        conn = None
        try:
            conn = psycopg2.connect(**self.db_config)
            with conn.cursor() as cursor:
                # Enhanced verification with pipeline-specific fields
                cursor.execute("""
                    SELECT 
                        entity_category,
                        COUNT(*) as total_count,
                        AVG(confidence_score) as avg_confidence,
                        COUNT(*) FILTER (WHERE is_merged = true) as merged_count,
                        COUNT(DISTINCT primary_model) as models_used,
                        COUNT(DISTINCT section_name) as sections_processed,
                        MAX(extraction_timestamp) as latest_extraction
                    FROM system_uno.sec_entities_raw
                    WHERE sec_filing_ref = %s
                    GROUP BY entity_category
                    ORDER BY total_count DESC
                """, (sec_filing_ref,))
                category_stats = cursor.fetchall()

                # Model performance breakdown
                cursor.execute("""
                    SELECT 
                        primary_model,
                        COUNT(*) as entities,
                        AVG(confidence_score) as avg_conf,
                        COUNT(DISTINCT section_name) as sections
                    FROM system_uno.sec_entities_raw
                    WHERE sec_filing_ref = %s
                    GROUP BY primary_model
                    ORDER BY entities DESC
                """, (sec_filing_ref,))
                model_stats = cursor.fetchall()

            return {
                'filing_ref': sec_filing_ref,
                'total_entities': sum(stat[1] for stat in category_stats),
                'categories': [{
                    'category': stat[0],
                    'count': stat[1],
                    'avg_confidence': _as_float(stat[2]),
                    'merged_count': stat[3],
                    'models_used': stat[4],
                    'sections_processed': stat[5]
                } for stat in category_stats],
                'model_performance': [{
                    'model': stat[0],
                    'entities': stat[1],
                    'avg_confidence': _as_float(stat[2]),
                    'sections': stat[3]
                } for stat in model_stats]
            }

        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return {}
        finally:
            if conn is not None:
                conn.close()

# Initialize enhanced storage. The constructor also runs the table migration
# (_ensure_enhanced_table), which connects to the database immediately.
# `pipeline_storage` is a notebook global used by the processing functions below.
pipeline_storage = PipelineEntityStorage(NEON_CONFIG)

# Complete pipeline processing function
def process_filing_with_pipeline(filing_data: Dict) -> Dict:
    """Run one filing through section extraction, NER, storage and verification.

    Args:
        filing_data: Filing row dict. ``id``, ``company_domain`` and
            ``filing_type`` are read here; the whole dict is passed to
            ``process_sec_filing_with_sections``.

    Returns:
        On completion, a result dict whose ``success`` flag is the storage
        outcome, with section/entity counts, timing, the verification summary
        and up to three sample entities. Any failure (section extraction
        failed, no entities, or an exception) yields
        ``{'success': False, 'filing_id', 'error', 'processing_time'}``.
    """
    # Captured before the try so the exception handler can always report timing.
    start_time = time.time()
    filing_id = filing_data.get('id')

    def _failure(message: str) -> Dict:
        # Uniform failure payload for every early-exit path
        return {
            'success': False,
            'filing_id': filing_id,
            'error': message,
            'processing_time': time.time() - start_time
        }

    try:
        # Step 1: Extract sections using accession-based method
        section_result = process_sec_filing_with_sections(filing_data)

        if section_result.get('processing_status') != 'success':
            return _failure(section_result.get('error', 'Section extraction failed'))

        # Step 2: Process sections through EntityExtractionPipeline
        entities = entity_pipeline.process_sec_filing_sections(section_result)

        if not entities:
            return _failure('No entities extracted')

        # Step 3: Store entities
        storage_success = pipeline_storage.store_pipeline_entities(entities)

        # Step 4: Get verification (filing refs are stored as "SEC_<id>")
        verification = pipeline_storage.get_storage_verification(f"SEC_{filing_id}")

        processing_time = time.time() - start_time

        return {
            'success': storage_success,
            'filing_id': filing_id,
            'company_domain': filing_data.get('company_domain'),
            'filing_type': filing_data.get('filing_type'),
            'sections_processed': section_result.get('total_sections', 0),
            'entities_extracted': len(entities),
            'entities_stored': verification.get('total_entities', 0),
            'processing_time': round(processing_time, 2),
            'verification': verification,
            'sample_entities': entities[:3]  # First 3 entities as sample
        }

    except Exception as e:
        return _failure(str(e))

def process_filings_batch(limit: int = 3, pause_seconds: float = 1.0) -> Dict:
    """Process up to ``limit`` unprocessed filings through the complete pipeline.

    Each filing is handled by ``process_filing_with_pipeline``; a failing
    filing is reported and the batch continues. After the loop, the shared
    ``entity_pipeline.pipeline_stats`` and ``pipeline_storage.storage_stats``
    counters are updated with this batch's totals.

    Args:
        limit: Maximum number of filings requested from ``get_unprocessed_filings``.
        pause_seconds: Delay between consecutive filings (default 1 second);
            pass 0 to disable.

    Returns:
        Batch summary dict (counts, timing, per-filing ``results``), or
        ``{'success': False, 'message': 'No filings to process'}`` when nothing
        is pending.
    """
    print(f"🚀 Processing batch of {limit} SEC filings with complete pipeline...")

    batch_start = time.time()

    # Get unprocessed filings
    filings = get_unprocessed_filings(limit)

    if not filings:
        return {'success': False, 'message': 'No filings to process'}

    print(f"📄 Found {len(filings)} filings to process")

    results = []
    successful = 0
    total_entities = 0

    for i, filing in enumerate(filings, 1):
        # .get(): a filing missing a display field must not abort the whole batch
        print(f"\n📄 [{i}/{len(filings)}] Processing {filing.get('filing_type')} for {filing.get('company_domain')}")

        result = process_filing_with_pipeline(filing)
        results.append(result)

        if result['success']:
            successful += 1
            total_entities += result.get('entities_extracted', 0)
            print(f"   ✅ Success: {result.get('entities_extracted', 0)} entities extracted in {result.get('processing_time')}s")

            # Show sample entities with section names
            for entity in result.get('sample_entities', []):
                models = '+'.join(entity.get('models_detected', [entity.get('model_source', 'unknown')]))
                section = entity.get('section_name', 'NO_SECTION')
                confidence = float(entity.get('confidence_score') or 0.0)
                print(f"      • {entity.get('entity_type')}: '{entity.get('entity_text')}' ({models}, {confidence:.3f}, {section})")
        else:
            print(f"   ❌ Failed: {result.get('error', 'Unknown error')}")

        # Brief pause between filings (not after the last one)
        if pause_seconds > 0 and i < len(filings):
            time.sleep(pause_seconds)

    batch_time = time.time() - batch_start

    # Update pipeline statistics
    entity_pipeline.pipeline_stats['documents_processed'] += successful
    entity_pipeline.pipeline_stats['total_entities_extracted'] += total_entities
    entity_pipeline.pipeline_stats['processing_time_total'] += batch_time

    # Update storage statistics
    pipeline_storage.storage_stats['filings_processed'] += successful

    return {
        'success': successful > 0,
        'filings_processed': len(filings),
        'successful_filings': successful,
        'failed_filings': len(filings) - successful,
        'total_entities_extracted': total_entities,
        'batch_processing_time': round(batch_time, 2),
        'avg_time_per_filing': round(batch_time / len(filings), 2),
        'results': results
    }

# Test the complete pipeline on a single pending filing.
# NOTE: this is not a dry run — process_filing_with_pipeline() stores the
# extracted entities in system_uno.sec_entities_raw via pipeline_storage.
print("🧪 Testing complete pipeline with sample filing...")

test_filings = get_unprocessed_filings(limit=1)

if test_filings:
    test_result = process_filing_with_pipeline(test_filings[0])
    
    if test_result['success']:
        print(f"✅ Pipeline test successful!")
        print(f"   📊 Sections processed: {test_result['sections_processed']}")
        print(f"   🔍 Entities extracted: {test_result['entities_extracted']}")
        print(f"   💾 Entities stored: {test_result['entities_stored']}")
        print(f"   ⏱️ Processing time: {test_result['processing_time']}s")
        
        # Show verification details (top 3 categories; verification is {} when
        # get_storage_verification failed, which skips this block)
        if test_result['verification']:
            print(f"   📈 Verification:")
            for cat in test_result['verification']['categories'][:3]:
                print(f"      • {cat['category']}: {cat['count']} entities")
    else:
        print(f"❌ Pipeline test failed: {test_result.get('error')}")
else:
    print("📭 No test filings available")

# Usage banner for the batch entry point defined above
print(f"\n✅ Complete EntityExtractionPipeline with enhanced storage ready!")
print(f"🎯 Usage: batch_results = process_filings_batch(limit=5)")
print(f"📊 Features: Section extraction → NER processing → Database storage")
print(f"🔧 Enhanced: Multi-model merging, performance tracking, detailed verification")
print(f"🏷️ Fixed: Section name attribution and storage verification")

# Context Retrieval System for Relationship Extraction
class ContextRetrievalSystem:
    """System to retrieve context around entities for relationship extraction.

    Entity rows (text, section name, character offsets) are read from
    ``system_uno.sec_entities_raw``. The section text they were extracted from
    is re-derived from the original filing with EdgarTools and memoized in
    ``self._section_cache``, keyed by ``"<filing_url>#<section_name>"``. Every
    database access uses a short-lived connection that is always closed, even
    when the query raises.
    """

    def __init__(self, db_config: Dict):
        """
        Args:
            db_config: Keyword arguments forwarded to ``psycopg2.connect``.
        """
        self.db_config = db_config
        self.retrieval_stats = {
            'contexts_retrieved': 0,
            'section_fetches': 0,
            'failed_retrievals': 0,
            'cache_hits': 0
        }
        # Simple cache for section content to avoid repeated API calls
        self._section_cache = {}

    def _run_query(self, query: str, params: tuple, fetch_all: bool = False):
        """Execute a read-only query on a fresh connection.

        Returns ``fetchall()`` when ``fetch_all`` is true, otherwise
        ``fetchone()``. The connection is closed in ``finally``; exceptions
        propagate to the caller.
        """
        conn = psycopg2.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                return cursor.fetchall() if fetch_all else cursor.fetchone()
        finally:
            conn.close()

    def _get_filing_info(self, filing_id: str):
        """Return ``(url, filing_type, title)`` for a raw SEC filing id, or None."""
        return self._run_query("""
            SELECT url, filing_type, title
            FROM raw_data.sec_filings
            WHERE id = %s
        """, (filing_id,))

    def get_entity_context(self, entity_id: str, context_window: int = 500) -> Dict:
        """Get context around a specific entity using its position and section.

        Args:
            entity_id: ``extraction_id`` of the stored entity row.
            context_window: Characters to include on each side of the entity
                span (clamped to the section boundaries).

        Returns:
            On success, a dict with ``success=True``, ``context_text``, the
            entity's offset inside that context, and section/filing/metadata
            details. On failure, ``{'success': False, 'error', 'entity_id'}``.
            Only unexpected exceptions increment
            ``retrieval_stats['failed_retrievals']``.
        """
        try:
            # Get entity details with section and position
            entity_data = self._run_query("""
                SELECT 
                    entity_text, section_name, character_start, character_end,
                    sec_filing_ref, company_domain, entity_category,
                    confidence_score, primary_model
                FROM system_uno.sec_entities_raw
                WHERE extraction_id = %s
                  AND section_name IS NOT NULL 
                  AND section_name != ''
            """, (entity_id,))

            if not entity_data:
                return {
                    'success': False, 
                    'error': 'Entity not found or missing section name',
                    'entity_id': entity_id
                }

            (entity_text, section_name, char_start, char_end, sec_filing_ref,
             company_domain, category, confidence, model) = entity_data

            # Filing refs are stored as "SEC_<raw filing id>"
            filing_id = sec_filing_ref.replace('SEC_', '') if sec_filing_ref else None
            if not filing_id:
                return {
                    'success': False,
                    'error': 'Could not extract filing ID',
                    'entity_id': entity_id
                }

            # Get filing URL to extract accession
            filing_info = self._get_filing_info(filing_id)
            if not filing_info:
                return {
                    'success': False,
                    'error': f'Filing {filing_id} not found',
                    'entity_id': entity_id
                }

            filing_url, filing_type, filing_title = filing_info

            section_content = self._get_section_content(filing_url, filing_type, section_name)
            if not section_content:
                return {
                    'success': False,
                    'error': f'Could not retrieve section {section_name}',
                    'entity_id': entity_id
                }

            # Extract context around entity position (offsets are section-relative)
            context_start = max(0, char_start - context_window)
            context_end = min(len(section_content), char_end + context_window)
            context_text = section_content[context_start:context_end]

            self.retrieval_stats['contexts_retrieved'] += 1

            return {
                'success': True,
                'entity_id': entity_id,
                'entity_text': entity_text,
                'entity_category': category,
                'context_text': context_text,
                'entity_position_in_context': char_start - context_start,
                'entity_length': char_end - char_start,
                'context_window_used': context_window,
                'section_name': section_name,
                'filing_info': {
                    'filing_id': filing_id,
                    'company_domain': company_domain,
                    'filing_type': filing_type,
                    'filing_title': filing_title
                },
                'metadata': {
                    'confidence_score': confidence,
                    'primary_model': model,
                    'original_char_start': char_start,
                    'original_char_end': char_end,
                    'section_length': len(section_content)
                }
            }

        except Exception as e:
            self.retrieval_stats['failed_retrievals'] += 1
            return {
                'success': False,
                'error': str(e),
                'entity_id': entity_id
            }

    def _get_section_content(self, filing_url: str, filing_type: str, section_name: str) -> str:
        """Get section content using EdgarTools, with caching.

        The accession number is parsed from ``filing_url`` (dashed form, or the
        18-digit compressed path segment). The filing HTML is fetched and parsed,
        and the named section's text is returned and cached. Returns None when
        the accession cannot be parsed, the filing or its HTML is unavailable,
        the section is absent, or any step raises. Misses are not cached.
        """
        cache_key = f"{filing_url}#{section_name}"

        if cache_key in self._section_cache:
            self.retrieval_stats['cache_hits'] += 1
            return self._section_cache[cache_key]

        try:
            from edgar import find
            from edgar.documents import parse_html
            from edgar.documents.extractors.section_extractor import SectionExtractor
            import re

            # Extract accession from URL
            accession_match = re.search(r'(\d{10}-\d{2}-\d{6})', filing_url)
            if not accession_match:
                compressed_match = re.search(r'/(\d{18})/', filing_url)
                if compressed_match:
                    accession = compressed_match.group(1)
                    accession_number = f"{accession[:10]}-{accession[10:12]}-{accession[12:]}"
                else:
                    return None
            else:
                accession_number = accession_match.group(1)

            # Get filing and extract sections
            filing = find(accession_number)
            if not filing:
                return None

            html_content = filing.html()
            if not html_content:
                return None

            document = parse_html(html_content)
            extractor = SectionExtractor(filing_type=filing_type)
            sections = extractor.extract(document)

            self.retrieval_stats['section_fetches'] += 1

            # Find the specific section
            if section_name in sections:
                section_obj = sections[section_name]
                if hasattr(section_obj, 'text'):
                    content = section_obj.text() if callable(section_obj.text) else section_obj.text
                else:
                    content = str(section_obj)

                self._section_cache[cache_key] = content
                return content

            return None

        except Exception as e:
            print(f"⚠️ Section retrieval failed for {section_name}: {e}")
            return None

    @staticmethod
    def _pair_co_occurrences(rows, section_name: str, filing_id: str, distance_threshold: int) -> List[Dict]:
        """Pair entity rows whose gap (start of later - end of earlier) is within the threshold.

        ``rows`` are ``(id, text, category, start, end, confidence, model)``
        tuples. They are sorted by start offset (stable, so equal starts keep
        their input order). Overlapping pairs (negative gap) are skipped.
        """
        def as_dict(row) -> Dict:
            return {
                'id': row[0],
                'text': row[1],
                'category': row[2],
                'start': row[3],
                'end': row[4],
                'confidence': row[5],
                'model': row[6]
            }

        ordered = sorted(rows, key=lambda row: row[3])
        pairs = []
        for i in range(len(ordered)):
            first = ordered[i]
            for j in range(i + 1, len(ordered)):
                second = ordered[j]
                distance = second[3] - first[4]  # start of entity2 - end of entity1
                if distance > distance_threshold:
                    # Starts are non-decreasing, so every later row is farther away too
                    break
                if distance >= 0:
                    pairs.append({
                        'entity1': as_dict(first),
                        'entity2': as_dict(second),
                        'distance': distance,
                        'section_name': section_name,
                        'filing_id': filing_id
                    })
        return pairs

    def get_co_occurring_entities(self, filing_id: str, section_name: str, distance_threshold: int = 1000) -> List[Dict]:
        """Get entities that co-occur within a distance threshold in the same section.

        Args:
            filing_id: Raw filing id (looked up as ``"SEC_<filing_id>"``).
            section_name: Section whose entities are paired.
            distance_threshold: Maximum character gap between the end of the
                earlier entity and the start of the later one.

        Returns:
            List of pair dicts. Returns ``[]`` when fewer than two entities are
            found or on error (the error is printed).
        """
        try:
            rows = self._run_query("""
                SELECT 
                    extraction_id, entity_text, entity_category, character_start, character_end,
                    confidence_score, primary_model
                FROM system_uno.sec_entities_raw
                WHERE sec_filing_ref = %s
                  AND section_name = %s
                  AND section_name IS NOT NULL
                ORDER BY character_start
            """, (f'SEC_{filing_id}', section_name), fetch_all=True)

            if len(rows) < 2:
                return []

            return self._pair_co_occurrences(rows, section_name, filing_id, distance_threshold)

        except Exception as e:
            print(f"❌ Co-occurrence search failed: {e}")
            return []

    def get_context_for_entity_pair(self, entity1_id: str, entity2_id: str, context_window: int = 300) -> Dict:
        """Get shared context for a pair of entities from the same filing section.

        Returns:
            On success, the combined context text spanning both entities (plus
            ``context_window`` on each side, clamped) and each entity's
            position inside it. Otherwise ``{'success': False, 'error': ...}``.
        """
        try:
            context1 = self.get_entity_context(entity1_id, context_window)
            context2 = self.get_entity_context(entity2_id, context_window)

            if not (context1['success'] and context2['success']):
                return {
                    'success': False,
                    'error': 'Could not retrieve context for one or both entities'
                }

            if context1['section_name'] != context2['section_name']:
                return {
                    'success': False,
                    'error': 'Entities are not in the same section'
                }

            # Offsets are section-relative, so they are only comparable when both
            # entities also come from the same filing (section names repeat across filings).
            if context1['filing_info']['filing_id'] != context2['filing_info']['filing_id']:
                return {
                    'success': False,
                    'error': 'Entities are not in the same filing'
                }

            entity1_start = context1['metadata']['original_char_start']
            entity1_end = context1['metadata']['original_char_end']
            entity2_start = context2['metadata']['original_char_start']
            entity2_end = context2['metadata']['original_char_end']

            filing_info = self._get_filing_info(context1['filing_info']['filing_id'])
            if not filing_info:
                return {'success': False, 'error': 'Filing not found'}

            section_content = self._get_section_content(
                filing_info[0], filing_info[1], context1['section_name']
            )
            if not section_content:
                return {'success': False, 'error': 'Section content not found'}

            # Combined context boundaries, clamped to the section
            combined_start = max(0, min(entity1_start, entity2_start) - context_window)
            combined_end = min(len(section_content), max(entity1_end, entity2_end) + context_window)
            combined_context = section_content[combined_start:combined_end]

            return {
                'success': True,
                'entity1': context1['entity_text'],
                'entity2': context2['entity_text'],
                'combined_context': combined_context,
                'entity1_position': {'start': entity1_start - combined_start, 'end': entity1_end - combined_start},
                'entity2_position': {'start': entity2_start - combined_start, 'end': entity2_end - combined_start},
                'distance_between_entities': abs(entity1_start - entity2_start),
                'section_name': context1['section_name'],
                'filing_info': context1['filing_info']
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_retrieval_statistics(self) -> Dict:
        """Return retrieval counters plus cache size and cache hit rate.

        ``cache_hit_rate`` is the percentage of section lookups served from
        the cache: hits / (hits + fetches) * 100, so it never exceeds 100.
        """
        hits = self.retrieval_stats['cache_hits']
        lookups = hits + self.retrieval_stats['section_fetches']
        return {
            **self.retrieval_stats,
            'cache_size': len(self._section_cache),
            'cache_hit_rate': hits / max(1, lookups) * 100
        }

# Initialize context retrieval system. The constructor only stores the config
# and creates empty stats/cache; no database or network access happens here.
# `context_retriever` is a notebook global for later relationship-extraction cells.
context_retriever = ContextRetrievalSystem(NEON_CONFIG)

print(f"\n🔍 ContextRetrievalSystem initialized!")
print(f"   📊 Features: Entity context extraction, co-occurrence detection")
print(f"   🔧 Methods: get_entity_context(), get_co_occurring_entities(), get_context_for_entity_pair()")
print(f"   💾 Caching: Automatic section content caching to reduce API calls")
print(f"   📈 Ready for relationship extraction preparation!")

In [None]:
# Cell 6: Production Analytics and Monitoring Interface

def _analytics_fmt_conf(value) -> str:
    """Format an AVG(confidence) value to 3 decimals; SQL AVG yields NULL (None) when all inputs are NULL."""
    return "n/a" if value is None else f"{value:.3f}"


def _analytics_label(value) -> str:
    """Render a possibly-NULL GROUP BY key as text so width format specs never see None."""
    return "(none)" if value is None else str(value)


def _fetch_pipeline_db_analytics() -> Dict:
    """Run the dashboard queries against ``NEON_CONFIG`` and return the raw rows by name.

    Returns:
        Dict with ``overview`` (single row from ``fetchone``) and the row lists
        ``model_performance``, ``category_insights``, ``section_analysis`` and
        ``collaboration_stats``.

    The cursor and connection are closed even when a query raises, so a failed
    dashboard refresh does not leak a database connection.
    """
    from contextlib import closing

    with closing(psycopg2.connect(**NEON_CONFIG)) as conn, closing(conn.cursor()) as cursor:
        # NOTE(review): only this overview query filters on data_source = 'sec_filings';
        # the breakdown queries below scan every row. Confirm that difference is intended.
        cursor.execute("""
            SELECT 
                COUNT(*) as total_entities,
                COUNT(DISTINCT company_domain) as companies,
                COUNT(DISTINCT sec_filing_ref) as filings,
                COUNT(DISTINCT entity_category) as entity_types,
                AVG(confidence_score) as avg_confidence,
                COUNT(*) FILTER (WHERE is_merged = true) as merged_entities,
                COUNT(DISTINCT primary_model) as active_models,
                COUNT(DISTINCT section_name) as sections_processed,
                MAX(extraction_timestamp) as last_extraction
            FROM system_uno.sec_entities_raw
            WHERE data_source = 'sec_filings'
        """)
        overview = cursor.fetchone()

        # Model performance comparison
        cursor.execute("""
            SELECT 
                primary_model,
                COUNT(*) as entity_count,
                AVG(confidence_score) as avg_confidence,
                COUNT(DISTINCT entity_category) as categories_found,
                COUNT(DISTINCT sec_filing_ref) as filings_processed,
                COUNT(*) FILTER (WHERE is_merged = true) as merged_contributions
            FROM system_uno.sec_entities_raw
            WHERE primary_model IS NOT NULL
            GROUP BY primary_model
            ORDER BY entity_count DESC
        """)
        model_performance = cursor.fetchall()

        # Entity category insights (entity_category may be NULL -> a NULL group key)
        cursor.execute("""
            SELECT 
                entity_category,
                COUNT(*) as total_count,
                AVG(confidence_score) as avg_confidence,
                COUNT(DISTINCT primary_model) as models_detecting,
                COUNT(*) FILTER (WHERE is_merged = true) as multi_model_detections,
                COUNT(DISTINCT company_domain) as companies_with_category
            FROM system_uno.sec_entities_raw
            GROUP BY entity_category
            HAVING COUNT(*) >= 5
            ORDER BY total_count DESC
            LIMIT 15
        """)
        category_insights = cursor.fetchall()

        # Section-level analysis
        cursor.execute("""
            SELECT 
                section_name,
                COUNT(*) as entities_found,
                COUNT(DISTINCT entity_category) as categories,
                AVG(confidence_score) as avg_confidence,
                COUNT(DISTINCT primary_model) as models_used
            FROM system_uno.sec_entities_raw
            WHERE section_name IS NOT NULL AND section_name != ''
            GROUP BY section_name
            HAVING COUNT(*) >= 3
            ORDER BY entities_found DESC
            LIMIT 10
        """)
        section_analysis = cursor.fetchall()

        # Multi-model collaboration stats (array_length is NULL for empty arrays)
        cursor.execute("""
            SELECT 
                array_length(models_detected, 1) as num_models,
                COUNT(*) as entity_count,
                AVG(confidence_score) as avg_confidence
            FROM system_uno.sec_entities_raw
            WHERE models_detected IS NOT NULL
            GROUP BY array_length(models_detected, 1)
            ORDER BY num_models
        """)
        collaboration_stats = cursor.fetchall()

    return {
        'overview': overview,
        'model_performance': model_performance,
        'category_insights': category_insights,
        'section_analysis': section_analysis,
        'collaboration_stats': collaboration_stats,
    }


def _print_db_analytics(data: Dict) -> None:
    """Print the database-backed dashboard sections from ``_fetch_pipeline_db_analytics`` rows."""
    overview = data['overview']
    if overview and overview[0]:
        (total, companies, filings, categories, avg_conf,
         merged, models, sections, last_extraction) = overview
        print(f"\n📈 DATABASE OVERVIEW:")
        print(f"   Total Entities Extracted: {total:,}")
        print(f"   Companies Processed: {companies:,}")
        print(f"   SEC Filings Analyzed: {filings:,}")
        print(f"   Entity Categories Found: {categories:,}")
        print(f"   Average Confidence Score: {_analytics_fmt_conf(avg_conf)}")
        print(f"   Multi-Model Entities: {merged:,} ({merged / total * 100:.1f}%)")
        print(f"   Active Models: {models:,}")
        print(f"   Sections Processed: {sections:,}")
        print(f"   Last Extraction: {last_extraction or 'Never'}")

    if data['model_performance']:
        print(f"\n🤖 MODEL PERFORMANCE COMPARISON:")
        for model, count, conf, categories, filings, merged in data['model_performance']:
            merged_pct = (merged / count * 100) if count > 0 else 0
            print(f"   {_analytics_label(model):>12}: {count:>6,} entities | {_analytics_fmt_conf(conf)} conf | {categories:>2} categories | {filings:>3} filings | {merged_pct:>4.1f}% merged")

    if data['category_insights']:
        print(f"\n🏷️ ENTITY CATEGORY INSIGHTS:")
        print(f"   {'Category':<20} {'Count':<8} {'Avg Conf':<10} {'Models':<7} {'Multi-Model':<12} {'Companies':<10}")
        print(f"   {'-'*20} {'-'*8} {'-'*10} {'-'*7} {'-'*12} {'-'*10}")
        for cat, count, conf, models, merged, companies in data['category_insights']:
            merged_pct = (merged / count * 100) if count > 0 else 0
            # Format the percentage as one token so '%' stays attached to the number
            # instead of being printed after the column padding.
            merged_col = f"{merged_pct:.1f}%"
            print(f"   {_analytics_label(cat):<20} {count:<8,} {_analytics_fmt_conf(conf):<10} {models:<7} {merged_col:<12} {companies:<10}")

    if data['section_analysis']:
        print(f"\n📑 SECTION-LEVEL PERFORMANCE:")
        for section, entities, categories, conf, models in data['section_analysis']:
            print(f"   {_analytics_label(section):<25}: {entities:>4} entities | {categories:>2} categories | {_analytics_fmt_conf(conf)} conf | {models} models")

    if data['collaboration_stats']:
        print(f"\n🔗 MULTI-MODEL COLLABORATION:")
        for num_models, count, conf in data['collaboration_stats']:
            num_models = num_models or 1
            print(f"   {num_models} model(s): {count:,} entities (avg conf: {_analytics_fmt_conf(conf)})")


def _print_pipeline_runtime_stats() -> None:
    """Print in-memory statistics from ``entity_pipeline`` and ``pipeline_storage``."""
    pipeline_stats = entity_pipeline.get_pipeline_statistics()

    print(f"\n🔧 PIPELINE STATISTICS:")
    print(f"   Documents Processed: {pipeline_stats['pipeline_stats']['documents_processed']:,}")
    print(f"   Total Entities Found: {pipeline_stats['pipeline_stats']['total_entities_extracted']:,}")
    print(f"   Processing Time: {pipeline_stats['pipeline_stats']['processing_time_total']:.2f}s")
    print(f"   Device: {pipeline_stats['device']}")
    print(f"   Loaded Models: {', '.join(pipeline_stats['loaded_models'])}")
    print(f"   Supported Sources: {', '.join(pipeline_stats['supported_data_sources'])}")

    print(f"\n📊 INDIVIDUAL MODEL PERFORMANCE:")
    for model_name, stats in pipeline_stats['model_stats'].items():
        if stats['texts_processed'] > 0:
            entities_per_text = stats['entities_found'] / stats['texts_processed']
            avg_time = stats['processing_time'] / stats['texts_processed']
            print(f"   {model_name:>12}: {stats['texts_processed']:>4} texts | {stats['entities_found']:>5} entities | {entities_per_text:>4.1f} avg/text | {avg_time:>4.2f}s avg")

    storage_stats = pipeline_storage.storage_stats
    total_stored = storage_stats['total_entities_stored']
    if total_stored > 0:
        print(f"\n💾 STORAGE STATISTICS:")
        print(f"   Entities Stored: {total_stored:,}")
        print(f"   Filings Processed: {storage_stats['filings_processed']:,}")
        print(f"   Merged Entities: {storage_stats['merged_entities']:,}")
        print(f"   Single-Model Entities: {storage_stats['single_model_entities']:,}")
        print(f"   Failed Inserts: {storage_stats['failed_inserts']:,}")
        merge_rate = storage_stats['merged_entities'] / total_stored * 100
        print(f"   Multi-Model Detection Rate: {merge_rate:.1f}%")


def generate_pipeline_analytics_report() -> None:
    """Print the full analytics dashboard for the entity extraction pipeline.

    Sections:
      1. Database analytics from ``system_uno.sec_entities_raw`` (overview, per-model,
         per-category, per-section, multi-model collaboration). A failure in this part
         is reported as "Could not retrieve analytics" and the report continues.
      2. In-memory pipeline and storage statistics from ``entity_pipeline`` and
         ``pipeline_storage``. Errors here are not caught.
    """
    print("\n" + "="*80)
    print("📊 ENTITYEXTRACTIONPIPELINE ANALYTICS DASHBOARD")
    print("="*80)

    try:
        _print_db_analytics(_fetch_pipeline_db_analytics())
    except Exception as e:
        print(f"   ❌ Could not retrieve analytics: {e}")

    _print_pipeline_runtime_stats()

    print("\n" + "="*80)
    print("✅ EntityExtractionPipeline Analytics Complete!")
    print("="*80)

def _production_error_entry(filing: Dict, filing_id, error) -> Dict:
    """Build one record for ``processing_results['processing_errors']``."""
    return {
        'filing_id': filing_id,
        'company': filing.get('company_domain'),
        'error': error,
        'filing_type': filing.get('filing_type'),
    }


def run_production_batch(batch_size: int = 5, max_filings: int = 20) -> Dict:
    """Run entity extraction over unprocessed SEC filings in fixed-size batches.

    Args:
        batch_size: Filings per batch; must be >= 1. Batches are separated by a 2 s pause.
        max_filings: Passed as ``limit`` to ``get_unprocessed_filings``.

    Returns:
        If no filings are found: ``{'success': False, 'message': ..., 'recommendation': ...}``.
        Otherwise a results dict with ``batch_config``, ``processing_results``,
        ``model_performance`` / ``category_breakdown`` (tallied over each result's
        ``sample_entities`` only, not every extracted entity), ``timing``,
        ``quality_metrics``, ``success`` (success rate >= 0.8) and ``success_rate``.

    Raises:
        ValueError: if ``batch_size`` < 1. Previously 0 raised ZeroDivisionError and
            negative values silently processed nothing.
    """
    from collections import Counter

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    print(f"🚀 PRODUCTION BATCH PROCESSING")
    print(f"   📦 Batch size: {batch_size}")
    print(f"   📊 Max filings: {max_filings}")
    print(f"   🖥️ Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
    print(f"   🤖 Active models: {len(entity_pipeline.loaded_models)}")

    production_start = time.time()

    all_filings = get_unprocessed_filings(limit=max_filings)

    if not all_filings:
        return {
            'success': False,
            'message': 'No unprocessed filings found',
            'recommendation': 'All SEC filings may already be processed'
        }

    print(f"   📄 Found {len(all_filings)} unprocessed filings")

    production_results = {
        'batch_config': {
            'batch_size': batch_size,
            'max_filings': max_filings,
            'total_filings_found': len(all_filings)
        },
        'processing_results': {
            'successful_filings': 0,
            'failed_filings': 0,
            'total_entities_extracted': 0,
            'total_sections_processed': 0,
            'processing_errors': []
        },
        'model_performance': {},
        'category_breakdown': {},
        'timing': {
            'start_time': datetime.now(),
            'batch_times': [],
            'avg_time_per_filing': 0
        },
        # NOTE(review): avg_confidence_score and multi_model_detection_rate are never
        # populated by this function.
        'quality_metrics': {
            'avg_confidence_score': 0,
            'multi_model_detection_rate': 0,
            'entities_per_filing': 0
        }
    }
    processing = production_results['processing_results']
    model_perf = production_results['model_performance']
    category_counts = production_results['category_breakdown']

    total_batches = (len(all_filings) + batch_size - 1) // batch_size  # ceiling division

    for batch_num in range(total_batches):
        batch_start_idx = batch_num * batch_size
        batch_filings = all_filings[batch_start_idx:batch_start_idx + batch_size]

        print(f"\n📦 Processing batch {batch_num + 1}/{total_batches} ({len(batch_filings)} filings)")

        batch_start_time = time.time()
        batch_entities = 0
        batch_sections = 0

        for i, filing in enumerate(batch_filings, 1):
            print(f"   📄 [{batch_start_idx + i}/{len(all_filings)}] {filing['filing_type']} - {filing['company_domain']}")

            try:
                result = process_filing_with_pipeline(filing)

                if result['success']:
                    entities = result.get('entities_extracted', 0)
                    sections = result.get('sections_processed', 0)
                    samples = result.get('sample_entities') or []
                    model_hits = Counter(e.get('primary_model', 'unknown') for e in samples)
                    category_hits = Counter(e.get('entity_type', 'unknown') for e in samples)

                    # Commit counters only after everything above (which can raise on a
                    # malformed result) has run. Otherwise the except branch below would
                    # count the same filing as both successful and failed.
                    processing['successful_filings'] += 1
                    batch_entities += entities
                    batch_sections += sections
                    for model, hits in model_hits.items():
                        model_perf[model] = model_perf.get(model, 0) + hits
                    for category, hits in category_hits.items():
                        category_counts[category] = category_counts.get(category, 0) + hits

                    print(f"      ✅ {entities} entities, {sections} sections ({result.get('processing_time', '?')}s)")
                else:
                    error = result.get('error', 'Unknown error')
                    processing['failed_filings'] += 1
                    # Fall back to the filing's own id when the result does not echo it.
                    processing['processing_errors'].append(
                        _production_error_entry(filing, result.get('filing_id', filing.get('id')), error)
                    )
                    print(f"      ❌ Failed: {error}")

            except Exception as e:
                processing['failed_filings'] += 1
                processing['processing_errors'].append(
                    _production_error_entry(filing, filing.get('id'), str(e))
                )
                print(f"      ❌ Exception: {e}")

        batch_time = time.time() - batch_start_time
        production_results['timing']['batch_times'].append(batch_time)

        processing['total_entities_extracted'] += batch_entities
        processing['total_sections_processed'] += batch_sections

        print(f"   📊 Batch {batch_num + 1} complete: {batch_entities} entities, {batch_sections} sections ({batch_time:.1f}s)")

        # Brief pause between batches (not after the last one)
        if batch_num < total_batches - 1:
            time.sleep(2)

    production_time = time.time() - production_start
    production_results['timing']['end_time'] = datetime.now()
    production_results['timing']['total_processing_time'] = production_time

    successful = processing['successful_filings']
    if successful > 0:
        # NOTE(review): wall time includes failed filings and inter-batch sleeps but is
        # divided by successful filings only. Confirm this is the intended metric.
        production_results['timing']['avg_time_per_filing'] = production_time / successful
        production_results['quality_metrics']['entities_per_filing'] = processing['total_entities_extracted'] / successful

    # all_filings is non-empty here (the empty case returned early above).
    success_rate = successful / len(all_filings)
    production_results['success'] = success_rate >= 0.8  # 80% success rate threshold
    production_results['success_rate'] = success_rate

    return production_results

def display_production_summary(results: Dict) -> None:
    """Print a human-readable summary of a ``run_production_batch`` result dict.

    Handles both shapes ``run_production_batch`` returns:
      * the full results dict; every section (``batch_config``, ``processing_results``,
        ``timing``, ``quality_metrics``, ``model_performance``, ``category_breakdown``)
        is optional and defaults to empty/zero;
      * the early-exit dict ``{'success': False, 'message': ..., 'recommendation': ...}``.
        Its message and recommendation are printed rather than silently dropped.

    Model/category keys and error values may be None (they come from ``.get`` lookups
    on pipeline output). They are rendered as text instead of crashing the formatter.
    """
    from collections import Counter

    def _label(value) -> str:
        # Width format specs raise TypeError on None, so render None as 'unknown'.
        return 'unknown' if value is None else str(value)

    print("\n" + "="*80)
    print("🎯 PRODUCTION BATCH SUMMARY")
    print("="*80)

    rate_pct = results.get('success_rate', 0) * 100
    if results.get('success', False):
        print(f"✅ BATCH SUCCESSFUL (Success rate: {rate_pct:.1f}%)")
    else:
        print(f"❌ BATCH FAILED (Success rate: {rate_pct:.1f}%)")

    # Early-exit results (e.g. an empty queue) carry an explanation instead of metrics.
    if results.get('message'):
        print(f"   ℹ️ {results['message']}")
    if results.get('recommendation'):
        print(f"   💡 Recommendation: {results['recommendation']}")

    config = results.get('batch_config', {})
    processing = results.get('processing_results', {})
    timing = results.get('timing', {})
    quality = results.get('quality_metrics', {})

    print(f"\n📊 PROCESSING RESULTS:")
    print(f"   Total Filings Found: {config.get('total_filings_found', 0)}")
    print(f"   Successful Filings: {processing.get('successful_filings', 0)}")
    print(f"   Failed Filings: {processing.get('failed_filings', 0)}")
    print(f"   Total Entities Extracted: {processing.get('total_entities_extracted', 0):,}")
    print(f"   Total Sections Processed: {processing.get('total_sections_processed', 0):,}")

    print(f"\n⏱️ TIMING ANALYSIS:")
    print(f"   Total Processing Time: {timing.get('total_processing_time', 0):.1f}s")
    print(f"   Average Time per Filing: {timing.get('avg_time_per_filing', 0):.2f}s")
    batch_times = timing.get('batch_times')
    if batch_times:
        print(f"   Fastest Batch: {min(batch_times):.1f}s")
        print(f"   Slowest Batch: {max(batch_times):.1f}s")

    print(f"\n📈 QUALITY METRICS:")
    print(f"   Entities per Filing: {quality.get('entities_per_filing', 0):.1f}")

    model_perf = results.get('model_performance', {})
    if model_perf:
        print(f"\n🤖 MODEL CONTRIBUTIONS:")
        total_contributions = sum(model_perf.values())
        for model, count in sorted(model_perf.items(), key=lambda kv: kv[1], reverse=True):
            pct = (count / total_contributions * 100) if total_contributions > 0 else 0
            print(f"   {_label(model):>12}: {count:>5} entities ({pct:>4.1f}%)")

    categories = results.get('category_breakdown', {})
    if categories:
        print(f"\n🏷️ ENTITY CATEGORIES FOUND:")
        total_entities = sum(categories.values())
        for category, count in sorted(categories.items(), key=lambda kv: kv[1], reverse=True)[:10]:
            pct = (count / total_entities * 100) if total_entities > 0 else 0
            print(f"   {_label(category):<20}: {count:>4} entities ({pct:>4.1f}%)")

    errors = processing.get('processing_errors', [])
    if errors:
        print(f"\n❌ PROCESSING ERRORS ({len(errors)}):")
        # Group errors by the text before the first ':'; a missing/None error counts as 'Unknown'.
        error_types = Counter(str(err.get('error') or 'Unknown').split(':')[0] for err in errors)
        for error_type, count in error_types.most_common():
            print(f"   {error_type}: {count} occurrences")

    print("\n" + "="*80)

def analyze_model_agreement(filing_ref: str = None, limit: int = 100) -> Dict:
    """Measure how consistently the NER models scored the same entities.

    Reads up to ``limit`` rows (highest ``confidence_score`` first) from
    ``system_uno.sec_entities_raw`` where ``models_detected`` is set, optionally restricted
    to one ``sec_filing_ref``. For entities detected by more than one model, the population
    standard deviation of the per-model confidences in ``all_confidences`` decides whether
    the entity is high-agreement (std < 0.1) or a disagreement.

    Args:
        filing_ref: Optional ``sec_filing_ref`` to restrict the analysis to.
        limit: Maximum number of rows to analyze. It is coerced to ``int`` and bound as a
            query parameter, never formatted into the SQL text.

    Returns:
        ``{'total_entities', 'agreement_stats', 'high_agreement_entities',
        'disagreement_entities'}``, where ``agreement_stats`` maps number-of-models to
        entity count. Returns ``{}`` (after printing the error) if the query or the
        analysis fails.
    """
    import json
    import statistics
    from contextlib import closing

    try:
        query = """
            SELECT entity_text, models_detected, all_confidences, is_merged
            FROM system_uno.sec_entities_raw
            WHERE models_detected IS NOT NULL
        """
        params = []
        if filing_ref:
            query += " AND sec_filing_ref = %s"
            params.append(filing_ref)
        query += " ORDER BY confidence_score DESC LIMIT %s"
        params.append(int(limit))

        # closing() guarantees cursor and connection are released even if execute fails.
        with closing(psycopg2.connect(**NEON_CONFIG)) as conn, closing(conn.cursor()) as cursor:
            cursor.execute(query, params)
            rows = cursor.fetchall()

        analysis = {
            'total_entities': len(rows),
            'agreement_stats': {},
            'high_agreement_entities': [],
            'disagreement_entities': []
        }
        agreement_stats = analysis['agreement_stats']

        for entity_text, models, raw_confidences, _is_merged in rows:
            num_models = len(models) if models else 1
            agreement_stats[num_models] = agreement_stats.get(num_models, 0) + 1

            # If all_confidences is a json/jsonb column, psycopg2 already decodes it to a
            # dict; otherwise it is parsed as JSON text. Anything unparseable or non-dict
            # is treated as "no per-model confidences".
            if isinstance(raw_confidences, dict):
                confidences = raw_confidences
            elif raw_confidences:
                try:
                    confidences = json.loads(raw_confidences)
                except (TypeError, ValueError):
                    confidences = {}
            else:
                confidences = {}
            if not isinstance(confidences, dict):
                confidences = {}

            if num_models > 1 and confidences:
                conf_values = list(confidences.values())
                if len(conf_values) > 1:
                    # Population std (ddof=0), the same statistic np.std computes by default.
                    conf_std = statistics.pstdev(conf_values)
                    record = {
                        'entity': entity_text,
                        'models': models,
                        'confidences': confidences,
                        'std_dev': conf_std
                    }
                    if conf_std < 0.1:  # Low standard deviation = high agreement
                        analysis['high_agreement_entities'].append(record)
                    else:
                        analysis['disagreement_entities'].append(record)

        return analysis

    except Exception as e:
        print(f"❌ Model agreement analysis failed: {e}")
        return {}

# Render the dashboard once so the current pipeline state is visible when this cell runs.
generate_pipeline_analytics_report()

# Cheat-sheet for the production helpers in this cell. Joining with newlines and printing
# once emits exactly the same text as printing each line separately.
print("\n".join((
    "\n🎯 PRODUCTION COMMANDS:",
    "   • Small batch:  results = run_production_batch(batch_size=3, max_filings=10)",
    "   • Medium batch: results = run_production_batch(batch_size=5, max_filings=25)",
    "   • Large batch:  results = run_production_batch(batch_size=8, max_filings=50)",
    "   • Show results: display_production_summary(results)",
    "   • Analytics:    generate_pipeline_analytics_report()",
    "   • Model agreement: agreement = analyze_model_agreement(limit=50)",
    "\n✅ EntityExtractionPipeline Production Interface Ready!",
    "🚀 Features: Comprehensive monitoring, error tracking, quality metrics",
    "📊 Ready for large-scale biotech SEC entity extraction!",
)))