In [1]:
# Cell 1: GitHub Setup and Enhanced Import Configuration with PHASE 3 Improvements

# Install required packages first
!pip install edgartools transformers torch requests beautifulsoup4 'lxml[html_clean]' uuid numpy newspaper3k groq --quiet

import os
import sys
import importlib
import importlib.util
import psycopg2
from psycopg2.extras import execute_values
from psycopg2 import pool
import time
import json
import pickle
import traceback
from pathlib import Path
from functools import wraps
from contextlib import contextmanager
from collections import OrderedDict
from typing import Dict, List, Optional, Any
from datetime import datetime
from edgar import set_identity

# ============================================================================
# CENTRALIZED CONFIGURATION - All settings in one place
# ============================================================================

# Use Kaggle secrets for all sensitive credentials.
# Every secret below is fetched eagerly at import time via get_secret().
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

# Database Configuration (using Kaggle secrets for security).
# This dict is unpacked as keyword arguments into the psycopg2 connection pool
# in DatabaseManager._initialize_pool, so its keys must remain valid
# connection parameters.
NEON_CONFIG = {
    'host': user_secrets.get_secret("NEON_HOST"),
    'database': user_secrets.get_secret("NEON_DATABASE"),
    'user': user_secrets.get_secret("NEON_USER"), 
    'password': user_secrets.get_secret("NEON_PASSWORD"),
    'sslmode': 'require'
}

# Master Configuration Dictionary - PHASE 3 ENHANCED
# Single source of settings read by the helpers defined later in this cell.
CONFIG = {
    # GitHub Settings
    # NOTE(review): the token is embedded in repo_url, which is later passed to
    # `git clone`; it may therefore end up in shell output or the clone's git
    # config - confirm this is acceptable.
    'github': {
        'token': user_secrets.get_secret("GITHUB_TOKEN"),
        'repo_url': f"https://{user_secrets.get_secret('GITHUB_TOKEN')}@github.com/amiralpert/SmartReach.git",
        'local_path': "/kaggle/working/SmartReach"
    },
    
    # Database Settings (same object as NEON_CONFIG, not a copy)
    'database': NEON_CONFIG,
    
    # Connection Pool Settings - PHASE 3
    # min/max are the pool bounds; the keepalives* values are forwarded
    # unchanged as connection keyword arguments when the pool is created.
    'pool': {
        'min_connections': 2,
        'max_connections': 10,
        'keepalives': 1,
        'keepalives_idle': 30,
        'keepalives_interval': 10,
        'keepalives_count': 5
    },
    
    # Connection Retry Settings (consumed by retry_on_connection_error)
    'retry': {
        'max_attempts': 3,
        'initial_delay': 1,  # seconds
        'exponential_base': 2,  # wait = initial_delay * base ** attempt
        'max_delay': 30  # seconds; upper bound on a single wait
    },
    
    # Model Configuration - PHASE 3 ENHANCED
    'models': {
        'confidence_threshold': 0.8,  # minimum NER confidence required to store an entity
        'batch_size': 16,  # number of text chunks to process at one time through NER
        'max_length': 512,  # max chunk size for NER (512 token limit for BERT models)
        'chunk_overlap': 0.1,  # 10% overlap between chunks for complete entity extraction
        'warm_up_enabled': True,  # PHASE 3: Model warm-up
        'warm_up_text': 'Pfizer announced FDA approval for new cancer drug targeting BRCA mutations.'
    },
    
    # Cache Settings - PHASE 2 ENHANCED
    # NOTE(review): only 'enabled' and 'max_size_mb' are read in this part of
    # the notebook; 'ttl' and 'eviction_policy' are not referenced here.
    'cache': {
        'enabled': True,
        'max_size_mb': 100,  # Maximum cache size in MB
        'ttl': 3600,  # seconds
        'eviction_policy': 'LRU'  # Least Recently Used
    },
    
    # Processing Settings - PHASE 3 ENHANCED
    'processing': {
        'filing_batch_size': 3,
        'entity_batch_size': 100,  # Max entities per database insert
        'section_validation': True,  # Enforce section name validation
        'debug_mode': False,  # when True, log_error appends a stack trace
        'max_insert_batch': 500,  # Maximum batch for database inserts
        'deprecation_warnings': True,  # Show warnings for deprecated functions
        'checkpoint_enabled': True,  # PHASE 3: Enable checkpointing
        'checkpoint_dir': '/kaggle/working/checkpoints',  # PHASE 3: Checkpoint directory
        'deduplication_threshold': 0.85  # PHASE 3: Similarity threshold for dedup
    },
    
    # EdgarTools Settings (identity string passed to edgar.set_identity)
    'edgar': {
        'identity': "SmartReach BizIntel amir@leanbio.consulting"
    }
}

# Fail fast when the two mandatory secrets are missing or empty.
# Only the GitHub token and the database password are validated here; the
# remaining secrets (host, database, user) are not checked.
if not CONFIG['github']['token']:
    raise ValueError("❌ GITHUB_TOKEN is required in Kaggle secrets")

if not CONFIG['database']['password']:
    raise ValueError("❌ NEON_PASSWORD is required in Kaggle secrets")

# Startup summary of the non-secret settings (the host name is printed, the
# credentials are not).
print("✅ Configuration loaded from Kaggle secrets")
print(f"   Database: {CONFIG['database']['host']}")
print(f"   Processing: Batch size={CONFIG['processing']['filing_batch_size']}, Section validation={CONFIG['processing']['section_validation']}")
print(f"   Cache: Max size={CONFIG['cache']['max_size_mb']}MB, TTL={CONFIG['cache']['ttl']}s")
print(f"   Database batching: Max insert batch={CONFIG['processing']['max_insert_batch']}")
print(f"   Checkpointing: {'Enabled' if CONFIG['processing']['checkpoint_enabled'] else 'Disabled'}")

# ============================================================================
# CONNECTION RETRY DECORATOR
# ============================================================================

def retry_on_connection_error(func):
    """Decorator to retry database operations on connection errors.

    The wrapped callable is retried when it raises
    ``psycopg2.OperationalError`` or ``psycopg2.InterfaceError``. The wait
    between attempts grows exponentially
    (``initial_delay * exponential_base ** attempt``) and is capped at
    ``max_delay``; all values come from ``CONFIG['retry']`` and are read on
    every call, so configuration changes take effect immediately.

    Returns:
        Whatever the wrapped callable returns.

    Raises:
        The last connection error once all attempts are exhausted. Any other
        exception type propagates immediately without a retry.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        retry_cfg = CONFIG['retry']
        # Always make at least one attempt, even if the setting is 0 or
        # negative; otherwise the call would be skipped silently.
        max_attempts = max(1, retry_cfg['max_attempts'])
        delay = retry_cfg['initial_delay']

        for attempt in range(max_attempts):
            try:
                return func(*args, **kwargs)
            except (psycopg2.OperationalError, psycopg2.InterfaceError) as e:
                if attempt == max_attempts - 1:
                    error_msg = f"ERROR [ConnectionRetry]: Failed after {max_attempts} attempts - {type(e).__name__}: {str(e)}"
                    print(error_msg)
                    # `logger` is only assigned after setup_logger() has run
                    # (and setup_logger itself is wrapped by this decorator),
                    # so look it up defensively instead of risking a NameError
                    # that would mask the real connection error.
                    active_logger = globals().get('logger')
                    if active_logger:
                        active_logger.log(error_msg)
                    raise

                wait_time = min(delay * (retry_cfg['exponential_base'] ** attempt), retry_cfg['max_delay'])
                print(f"WARNING [ConnectionRetry]: Attempt {attempt + 1}/{max_attempts} failed. Retrying in {wait_time}s...")
                time.sleep(wait_time)

    return wrapper

# ============================================================================
# PHASE 3: ENHANCED ERROR LOGGING WITH STACK TRACES
# ============================================================================

def log_error(component: str, message: str, exception: Exception = None, context: dict = None):
    """Print and return a standardized error line.

    Format: ``ERROR [component]: message[ - ExcType: text][ | Context: {...}]``.
    When ``CONFIG['processing']['debug_mode']`` is true and an exception is
    supplied, the exception's own traceback is appended after the first line.

    Args:
        component: Short name of the subsystem reporting the error.
        message: Human-readable description.
        exception: Optional exception instance to describe.
        context: Optional dict of extra details appended to the message.

    Returns:
        The formatted message that was printed.
    """
    error_msg = f"ERROR [{component}]: {message}"

    if exception is not None:
        error_msg += f" - {type(exception).__name__}: {str(exception)}"
        # Add stack trace for debugging
        if CONFIG['processing'].get('debug_mode', False):
            # Format the traceback attached to the exception that was passed
            # in. traceback.format_exc() would describe whatever exception is
            # currently being handled, which is "NoneType: None" when this
            # function is called outside an except block.
            stack = "".join(
                traceback.format_exception(type(exception), exception, exception.__traceback__)
            )
            error_msg += f"\nStack trace:\n{stack}"

    if context:
        error_msg += f" | Context: {context}"

    print(error_msg)  # Auto-logger captures this
    return error_msg

def log_warning(component: str, message: str, context: dict = None):
    """Print and return a ``WARNING [component]: message`` line.

    A non-empty ``context`` dict is appended as `` | Context: {...}``.
    """
    pieces = [f"WARNING [{component}]: {message}"]
    if context:
        pieces.append(f" | Context: {context}")
    warning_msg = "".join(pieces)
    print(warning_msg)
    return warning_msg

def log_info(component: str, message: str):
    """Print and return an ``INFO [component]: message`` line."""
    info_msg = "INFO [{}]: {}".format(component, message)
    print(info_msg)
    return info_msg

# ============================================================================
# PHASE 3: ENHANCED DATABASE MANAGER WITH CONNECTION POOLING
# ============================================================================

class DatabaseManager:
    """Singleton database manager with connection pooling.

    Every ``DatabaseManager()`` call returns the same instance, and the
    psycopg2 ``ThreadedConnectionPool`` is stored on the class so it is shared
    as well. Pool bounds and keepalive settings come from ``CONFIG['pool']``;
    connection parameters come from ``CONFIG['database']``.
    """

    _instance = None  # the single shared manager instance
    _pool = None      # the shared connection pool (class-level)

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(DatabaseManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every DatabaseManager() call, so only build the pool
        # when there is none yet, or when the previous one was shut down by
        # close_all() (a closed pool rejects every getconn()).
        existing_pool = DatabaseManager._pool
        if existing_pool is None or existing_pool.closed:
            self._initialize_pool()

    def _initialize_pool(self):
        """Create the shared connection pool.

        Raises:
            Any exception raised while creating the pool, after logging it.
        """
        try:
            DatabaseManager._pool = psycopg2.pool.ThreadedConnectionPool(
                CONFIG['pool']['min_connections'],
                CONFIG['pool']['max_connections'],
                **CONFIG['database'],
                keepalives=CONFIG['pool']['keepalives'],
                keepalives_idle=CONFIG['pool']['keepalives_idle'],
                keepalives_interval=CONFIG['pool']['keepalives_interval'],
                keepalives_count=CONFIG['pool']['keepalives_count']
            )
            log_info("DatabaseManager", f"Connection pool initialized with {CONFIG['pool']['max_connections']} max connections")
        except Exception as e:
            log_error("DatabaseManager", "Failed to initialize connection pool", e)
            raise

    @contextmanager
    def get_connection(self):
        """Yield a pooled connection; commit on success, roll back on error.

        The connection is always handed back to the pool when the ``with``
        block exits. Exceptions raised inside the block propagate unchanged.
        """
        conn = DatabaseManager._pool.getconn()
        try:
            yield conn
            conn.commit()
        except Exception:
            try:
                conn.rollback()
            except Exception as rollback_error:
                # The connection is most likely already dead. Report it, but
                # let the original exception (the real cause) propagate
                # instead of being replaced by the rollback failure.
                log_warning(
                    "DatabaseManager",
                    "Rollback failed while handling an error",
                    {"error": str(rollback_error)},
                )
            raise  # bare raise keeps the original traceback intact
        finally:
            DatabaseManager._pool.putconn(conn)

    def close_all(self):
        """Close all connections in pool"""
        if DatabaseManager._pool:
            DatabaseManager._pool.closeall()
            log_info("DatabaseManager", "All connections closed")

    def get_pool_status(self) -> Dict:
        """Return the pool's ``minconn``/``maxconn``/``closed`` attributes.

        Returns ``{'status': 'not initialized'}`` when no pool exists yet.
        """
        if DatabaseManager._pool:
            return {
                'minconn': DatabaseManager._pool.minconn,
                'maxconn': DatabaseManager._pool.maxconn,
                'closed': DatabaseManager._pool.closed
            }
        return {'status': 'not initialized'}

# Initialize global database manager
DB_MANAGER = DatabaseManager()

# PHASE 2: Keep backward compatibility
@contextmanager
def get_db_connection():
    """Legacy entry point that delegates to ``DB_MANAGER.get_connection()``.

    Yields the connection produced by the manager's context manager, so the
    commit/rollback and return-to-pool handling are whatever that manager
    implements.
    """
    with DB_MANAGER.get_connection() as pooled_conn:
        yield pooled_conn

# ============================================================================
# PHASE 3: CHECKPOINT MANAGER FOR FAILURE RECOVERY
# ============================================================================

class CheckpointManager:
    """Manage pipeline checkpoints for failure recovery.

    Checkpoints are pickled dicts stored as ``<name>.pkl`` inside
    ``checkpoint_dir``. Listing, "latest" lookup and cleanup only consider
    files matching ``checkpoint_*.pkl``, i.e. the default naming scheme.
    """

    def __init__(self, checkpoint_dir: str = None):
        """Create the checkpoint directory if needed.

        Args:
            checkpoint_dir: Directory to use; defaults to
                ``CONFIG['processing']['checkpoint_dir']``.
        """
        self.checkpoint_dir = Path(checkpoint_dir or CONFIG['processing']['checkpoint_dir'])
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Path of the checkpoint most recently saved by this instance
        self.current_checkpoint = None
        log_info("CheckpointManager", f"Initialized with directory: {self.checkpoint_dir}")

    def save_checkpoint(self, state: Dict, checkpoint_name: str = None) -> Optional[str]:
        """Save pipeline state to a checkpoint file.

        Args:
            state: Picklable dict with the pipeline state. It is not modified;
                the metadata is added to a shallow copy.
            checkpoint_name: File stem to use; defaults to a
                ``checkpoint_YYYYmmdd_HHMMSS`` timestamp.

        Returns:
            The checkpoint path as a string, or ``None`` if saving failed
            (the error is logged, not raised).
        """
        import hashlib  # local import keeps this block self-contained

        tmp_path = None
        try:
            if not checkpoint_name:
                checkpoint_name = f"checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

            checkpoint_path = self.checkpoint_dir / f"{checkpoint_name}.pkl"

            # Work on a shallow copy so the caller's dict is not mutated.
            payload = dict(state)

            # Built-in hash() of a str is salted per interpreter process, so
            # it cannot be compared between runs. Derive a stable integer from
            # a SHA-256 digest instead.
            config_digest = hashlib.sha256(str(CONFIG).encode('utf-8')).digest()

            # Add metadata
            payload['checkpoint_metadata'] = {
                'created_at': datetime.now().isoformat(),
                'pipeline_version': '3.0',
                'config_hash': int.from_bytes(config_digest[:8], 'big')
            }

            # Write to a temporary file and rename it into place so that a
            # crash in the middle of the dump never leaves a truncated .pkl
            # behind. The ".tmp" suffix keeps it out of the
            # "checkpoint_*.pkl" globs used below.
            tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
            with open(tmp_path, 'wb') as f:
                pickle.dump(payload, f)
            tmp_path.replace(checkpoint_path)

            self.current_checkpoint = checkpoint_path
            log_info("CheckpointManager", f"Saved checkpoint: {checkpoint_name}")
            return str(checkpoint_path)

        except Exception as e:
            log_error("CheckpointManager", "Failed to save checkpoint", e)
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    pass  # nothing was written, or it is already gone
            return None

    def load_checkpoint(self, checkpoint_path: str = None) -> Optional[Dict]:
        """Load pipeline state from a checkpoint.

        Resolution order: the explicit ``checkpoint_path``, then the checkpoint
        last saved by this instance, then the most recently modified
        ``checkpoint_*.pkl`` in the directory.

        Returns:
            The unpickled state (including ``checkpoint_metadata``), or
            ``None`` when nothing could be loaded (the error is logged).
        """
        try:
            if not checkpoint_path:
                checkpoint_path = self.current_checkpoint

            if not checkpoint_path:
                # Find latest checkpoint
                checkpoints = list(self.checkpoint_dir.glob("checkpoint_*.pkl"))
                if not checkpoints:
                    log_warning("CheckpointManager", "No checkpoints found")
                    return None
                checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime)

            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file. Only load checkpoints written by this pipeline.
            with open(checkpoint_path, 'rb') as f:
                state = pickle.load(f)

            log_info("CheckpointManager", f"Loaded checkpoint: {Path(checkpoint_path).name}")
            return state

        except Exception as e:
            log_error("CheckpointManager", "Failed to load checkpoint", e)
            return None

    def list_checkpoints(self) -> List[Dict]:
        """List ``checkpoint_*.pkl`` files, newest first.

        Returns:
            One dict per file with ``name``, ``path``, ``size_mb`` and
            ``modified`` (ISO timestamp). Files that cannot be stat'ed are
            skipped.
        """
        checkpoints = []
        for cp_file in self.checkpoint_dir.glob("checkpoint_*.pkl"):
            try:
                stats = cp_file.stat()
            except OSError:
                continue  # file vanished or is unreadable
            checkpoints.append({
                'name': cp_file.stem,
                'path': str(cp_file),
                'size_mb': stats.st_size / (1024 * 1024),
                'modified': datetime.fromtimestamp(stats.st_mtime).isoformat()
            })
        return sorted(checkpoints, key=lambda x: x['modified'], reverse=True)

    def cleanup_old_checkpoints(self, keep_last: int = 5):
        """Delete all but the ``keep_last`` most recently modified checkpoints."""
        checkpoints = list(self.checkpoint_dir.glob("checkpoint_*.pkl"))
        checkpoints.sort(key=lambda p: p.stat().st_mtime, reverse=True)

        for cp_file in checkpoints[keep_last:]:
            try:
                cp_file.unlink()
            except OSError:
                continue  # best effort: skip files that cannot be removed
            log_info("CheckpointManager", f"Deleted old checkpoint: {cp_file.name}")

# Initialize global checkpoint manager
CHECKPOINT_MANAGER = CheckpointManager()

# ============================================================================
# PHASE 2: SIZE-LIMITED LRU CACHE
# ============================================================================

class SizeLimitedLRUCache:
    """LRU cache with size limit in MB.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. The byte size charged for each entry is remembered, so that exactly
    the same amount is released when the entry is replaced or evicted.
    """

    def __init__(self, max_size_mb: int):
        self.max_size_bytes = max_size_mb * 1024 * 1024
        self.cache = OrderedDict()
        self.current_size = 0   # sum of the sizes charged for live entries
        self.hits = 0
        self.misses = 0
        self._sizes = {}        # key -> bytes charged when it was stored

    def _estimate_size(self, value) -> int:
        """Estimate size of cached value in bytes.

        Strings count as their UTF-8 length and bytes as their length.
        Dicts, lists, tuples and sets are summed recursively, because
        ``sys.getsizeof`` is shallow and would report only the container
        overhead. Anything else falls back to ``sys.getsizeof``.
        """
        if isinstance(value, str):
            return len(value.encode('utf-8'))
        if isinstance(value, (bytes, bytearray)):
            return len(value)
        if isinstance(value, dict):
            return sum(
                self._estimate_size(k) + self._estimate_size(v)
                for k, v in value.items()
            )
        if isinstance(value, (list, tuple, set, frozenset)):
            return sum(self._estimate_size(item) for item in value)
        return sys.getsizeof(value)

    def _release(self, key, value) -> None:
        """Subtract the bytes charged for ``key`` from the running total."""
        charged = self._sizes.pop(key, None)
        if charged is None:
            # Entry was inserted directly into self.cache; estimate instead.
            charged = self._estimate_size(value)
        self.current_size -= charged

    def get(self, key: str):
        """Return the cached value (marking it most recently used) or None."""
        if key in self.cache:
            self.hits += 1
            # Move to end (most recently used)
            self.cache.move_to_end(key)
            return self.cache[key]
        self.misses += 1
        return None

    def put(self, key: str, value, size: int = None):
        """Store ``value`` under ``key``, evicting LRU entries as needed.

        Args:
            key: Cache key.
            value: Value to store.
            size: Size in bytes to charge; estimated when omitted.

        Note: a value larger than the whole limit is still stored after every
        other entry has been evicted.
        """
        if size is None:
            size = self._estimate_size(value)

        # Release the entry being replaced first, so its bytes are neither
        # counted against the new value nor cause unnecessary evictions.
        if key in self.cache:
            self._release(key, self.cache.pop(key))

        # Remove old entries if needed
        while self.current_size + size > self.max_size_bytes and self.cache:
            evicted_key, evicted_value = self.cache.popitem(last=False)
            self._release(evicted_key, evicted_value)
            log_info("Cache", f"Evicted {evicted_key} to maintain size limit")

        # Add new entry; a fresh insert lands at the most-recently-used end
        self.cache[key] = value
        self._sizes[key] = size
        self.current_size += size

    def get_stats(self) -> dict:
        """Return entry count, size in MB, hits, misses and hit rate (%)."""
        lookups = self.hits + self.misses
        hit_rate = (self.hits / lookups * 100) if lookups > 0 else 0
        return {
            'entries': len(self.cache),
            'size_mb': self.current_size / (1024 * 1024),
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate
        }

# Initialize global cache for EdgarTools sections
SECTION_CACHE = SizeLimitedLRUCache(CONFIG['cache']['max_size_mb'])

# ============================================================================
# PHASE 2: DEPRECATION WARNING SYSTEM
# ============================================================================

def deprecated(replacement_func: str = None):
    """Decorator factory that flags a function as deprecated.

    Each call to the decorated function logs a "Deprecation" warning first,
    provided ``CONFIG['processing']['deprecation_warnings']`` is truthy, and
    then delegates to the original function unchanged.

    Args:
        replacement_func: Optional name of the function to use instead; it is
            mentioned in the warning text.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not CONFIG['processing']['deprecation_warnings']:
                return func(*args, **kwargs)

            hint = f" Use '{replacement_func}' instead." if replacement_func else ""
            log_warning(
                "Deprecation",
                f"Function '{func.__name__}' is deprecated and will be removed." + hint,
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator

# ============================================================================
# GITHUB SETUP
# ============================================================================

print("\n📦 Setting up GitHub repository...")
local_path = CONFIG['github']['local_path']
repo_url = CONFIG['github']['repo_url']

# Clone or update repo with force pull.
# The `!` lines are IPython shell escapes, so this section only runs inside a
# notebook kernel. `git reset --hard origin/main` discards any local changes
# in the working copy.
if os.path.exists(local_path):
    log_info("GitHub", f"Repository exists at {local_path}")
    log_info("GitHub", "Force updating from main branch")
    !cd {local_path} && git fetch origin
    !cd {local_path} && git reset --hard origin/main
    !cd {local_path} && git pull origin main
    # NOTE(review): this message is logged unconditionally; the exit status of
    # the git commands above is not checked.
    log_info("GitHub", "Repository updated successfully")
    
    # Show current commit
    !cd {local_path} && echo "Current commit:" && git log --oneline -1
else:
    log_info("GitHub", f"Cloning repository to {local_path}")
    # NOTE(review): repo_url embeds the GitHub token; confirm it cannot leak
    # through the clone output or the stored remote URL.
    !git clone {repo_url} {local_path}
    log_info("GitHub", "Repository cloned successfully")

# Clear any cached modules from previous runs so the freshly pulled code is
# re-imported.
# NOTE(review): the substring match on 'clean' removes every loaded module
# whose name contains it, not only the logger's - verify this is intended.
modules_to_clear = [key for key in sys.modules.keys() if 'auto_logger' in key.lower() or 'clean' in key.lower()]
for mod in modules_to_clear:
    del sys.modules[mod]
    log_info("ModuleCache", f"Cleared cached module: {mod}")

# Add to Python path for regular imports.
# Remove one existing occurrence first so the entry ends up at index 0.
if f'{local_path}/BizIntel' in sys.path:
    sys.path.remove(f'{local_path}/BizIntel')
sys.path.insert(0, f'{local_path}/BizIntel')

log_info("Setup", "Python path configured for SEC entity extraction")

# Configure EdgarTools authentication - REQUIRED by SEC
set_identity(CONFIG['edgar']['identity'])
log_info("EdgarTools", f"Identity configured: {CONFIG['edgar']['identity']}")

# ============================================================================
# AUTO-LOGGER SETUP
# ============================================================================

@retry_on_connection_error
def setup_logger():
    """Load the repository's auto-logger module and start clean logging.

    The module is loaded from
    ``<local_path>/BizIntel/Scripts/KaggleLogger/auto_logger.py``, registered
    in ``sys.modules`` as ``auto_logger`` and its ``setup_clean_logging`` is
    called with the shared ``DB_MANAGER`` and the name
    ``"SEC_EntityExtraction"``.

    Returns:
        The logger returned by ``setup_clean_logging``, or ``None`` when the
        module file does not exist (an error is logged in that case).
    """
    log_info("AutoLogger", "Initializing clean logger with connection pooling")

    # Import the redesigned clean auto-logger
    logger_module_path = f"{local_path}/BizIntel/Scripts/KaggleLogger/auto_logger.py"
    if not os.path.exists(logger_module_path):
        log_error("AutoLogger", f"Logger module not found at {logger_module_path}")
        return None

    module_spec = importlib.util.spec_from_file_location("auto_logger", logger_module_path)
    loaded_module = importlib.util.module_from_spec(module_spec)
    # Register before executing, then run the module body
    sys.modules["auto_logger"] = loaded_module
    module_spec.loader.exec_module(loaded_module)

    # Use the new clean logging setup with DB_MANAGER for connection pooling
    clean_logger = loaded_module.setup_clean_logging(DB_MANAGER, "SEC_EntityExtraction")

    log_info("AutoLogger", "Clean auto-logging enabled successfully")
    feature_lines = (
        "📋 Features:",
        "   • One row per cell execution",
        "   • Complete output capture",
        "   • Proper cell numbers from # Cell N: comments",
        "   • Full error tracebacks",
        "   • Execution timing",
        "   • Connection pooling for reliability",
    )
    for feature_line in feature_lines:
        print(feature_line)
    return clean_logger

# Start the auto-logger. Logging is optional: any failure is reported and the
# notebook continues with `logger` set to None.
try:
    logger = setup_logger()
except Exception as e:
    log_error("AutoLogger", "Logger setup failed", e)
    log_warning("AutoLogger", "Continuing without auto-logging")
    logger = None

# Startup banner summarising the active configuration.
# Lines built from CONFIG or `logger` reflect the real state. The
# "Checkpoint System", "Enhanced Error Logging" and "Model Warm-up" lines are
# fixed strings printed regardless of the corresponding settings (for example
# stack traces are only added by log_error when debug_mode is True).
print("\n" + "="*80)
print("🚀 SEC ENTITY EXTRACTION ENGINE INITIALIZED - PHASE 3 PRODUCTION-READY")
print("="*80)
print(f"✅ GitHub: Repository at {CONFIG['github']['local_path']}")
print(f"✅ Database: Connected to {CONFIG['database']['host']}")
print(f"✅ EdgarTools: Configured as '{CONFIG['edgar']['identity']}'")
print(f"✅ Logging: {'Enabled' if logger else 'Disabled'}")
print(f"✅ Section Validation: {'ENFORCED' if CONFIG['processing']['section_validation'] else 'Disabled'}")
print(f"✅ PHASE 2 Enhancements:")
print(f"   • Database context manager: get_db_connection()")
print(f"   • Size-limited LRU cache: {CONFIG['cache']['max_size_mb']}MB")
print(f"   • Batch size limits: Max {CONFIG['processing']['max_insert_batch']} per insert")
print(f"   • Deprecation warnings: {'Enabled' if CONFIG['processing']['deprecation_warnings'] else 'Disabled'}")
print(f"✅ PHASE 3 PRODUCTION FEATURES:")
print(f"   • Connection Pool: {CONFIG['pool']['max_connections']} max connections")
print(f"   • Checkpoint System: Recovery from failures enabled")
print(f"   • Enhanced Error Logging: Stack traces included")
print(f"   • Model Warm-up: Enabled for faster first inference")
print(f"   • Entity Deduplication: {CONFIG['processing']['deduplication_threshold']} threshold")
print("="*80)

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m54.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m55.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m81.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m51.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [3

In [2]:
# Cell 2: Accession-Based SEC Content Extraction with EdgarTools - SIMPLIFIED

import edgar
from edgar import Filing, find, set_identity, Company
from edgar.documents import parse_html
from edgar.documents.extractors.section_extractor import SectionExtractor
import requests
import re
from typing import Dict, List, Optional, Tuple
from bs4 import BeautifulSoup

# Ensure identity is set
set_identity(CONFIG['edgar']['identity'])

# REMOVED extract_accession_from_url - No longer needed since database has accession_number column

def get_filing_sections(accession_number: str, filing_type: str = None) -> Dict[str, str]:
    """Get structured sections from SEC filing using accession number.

    PHASE 2 ENHANCED: Using cache for sections to avoid repeated API calls.

    Args:
        accession_number: SEC accession number used to look up the filing.
        filing_type: Form type passed to the section extractor. When omitted,
            the filing's ``form`` attribute is used, falling back to '10-K'.

    Returns:
        Mapping of section name to stripped section text. If no structured
        section yields text, a single ``'full_document'`` entry is returned
        when the document text exceeds 100 characters. An empty dict is
        returned on any failure (the error is logged, not raised).
    """
    # Honour the cache switch for reads as well as writes, so disabling the
    # cache really bypasses it and does not skew the hit/miss statistics.
    cache_enabled = CONFIG['cache']['enabled']
    # The key uses the filing type as requested by the caller (before any
    # auto-detection), so lookups and stores for the same call always agree.
    cache_key = f"{accession_number}#{filing_type or 'UNKNOWN'}"

    # Check cache first
    if cache_enabled:
        cached_sections = SECTION_CACHE.get(cache_key)
        if cached_sections:
            log_info("Cache", f"Cache hit for {accession_number}")
            return cached_sections

    try:
        # Find filing using accession number
        filing = find(accession_number)

        if not filing:
            raise ValueError(f"Filing not found for accession: {accession_number}")

        # Auto-detect filing type if not provided
        if not filing_type:
            filing_type = getattr(filing, 'form', '10-K')

        log_info("EdgarTools", f"Found {filing_type} for {getattr(filing, 'company', 'Unknown Company')}")

        # Get structured HTML content
        html_content = filing.html()
        if not html_content:
            raise ValueError("No HTML content available")

        # Parse HTML to Document object
        document = parse_html(html_content)

        # Extract sections using SectionExtractor
        extractor = SectionExtractor(filing_type=filing_type)
        sections = extractor.extract(document)

        log_info("EdgarTools", f"SectionExtractor found {len(sections)} sections")

        # Convert sections to text dictionary
        section_texts = {}
        for section_name, section in sections.items():
            try:
                if hasattr(section, 'text'):
                    # `text` may be a method or a plain attribute
                    text = section.text() if callable(section.text) else section.text
                    if isinstance(text, str) and text.strip():
                        section_texts[section_name] = text.strip()
                        print(f"      • {section_name}: {len(text):,} chars")
                else:
                    # Every object supports str(), so this is the universal
                    # fallback for sections without a `text` attribute.
                    text = str(section).strip()
                    if text:
                        section_texts[section_name] = text
                        print(f"      • {section_name}: {len(text):,} chars (via str)")
            except Exception as section_e:
                # One unreadable section must not abort the whole filing
                log_warning("EdgarTools", f"Could not extract section {section_name}", {"error": str(section_e)})
                continue

        # If SectionExtractor returns no sections, fall back to full document text
        if not section_texts:
            log_warning("EdgarTools", "No structured sections found, using full document fallback")
            full_text = document.text() if hasattr(document, 'text') and callable(document.text) else str(document)
            if full_text and len(full_text.strip()) > 100:  # Only use if substantial content
                section_texts['full_document'] = full_text.strip()
                log_warning('EdgarTools', f'Processing entire document ({len(full_text):,} chars) - this may take several minutes')
                log_info("EdgarTools", f"Using full document: {len(full_text):,} chars")

        # Cache the result (empty results are never cached)
        if section_texts and cache_enabled:
            SECTION_CACHE.put(cache_key, section_texts)
            log_info("Cache", f"Cached sections for {accession_number} ({len(section_texts)} sections)")

        return section_texts

    except Exception as e:
        log_error("EdgarTools", f"Failed to fetch filing {accession_number}", e)
        return {}  # Return empty dict on network/API failure

def route_sections_to_models(sections: Dict[str, str], filing_type: str) -> Dict[str, List[str]]:
    """Route sections to appropriate NER models based on filing type.

    Routing rules:
      * 10-K / 10-Q: sections whose name contains "financial" or "statement"
        (case-insensitive) go to ``finbert`` only; every other section goes to
        ``bert_base``, ``roberta`` and ``biobert``.
      * 8-K: every section goes to all four models.
      * anything else (including a missing filing type): every section goes to
        ``bert_base``, ``roberta`` and ``biobert``.

    Args:
        sections: Mapping of section name to section text; only the names are
            used.
        filing_type: Form type, matched case-insensitively. ``None`` or an
            empty string selects the default routing.

    Returns:
        Mapping of model name to the section names routed to it, in the
        iteration order of ``sections``. Models without sections are omitted.
    """
    general_models = ('bert_base', 'roberta', 'biobert')
    all_models = general_models + ('finbert',)

    # A NULL filing_type column arrives here as None; treat it like an
    # unknown form instead of failing on None.upper().
    form = (filing_type or '').strip().upper()

    routing = {
        'biobert': [],
        'bert_base': [],
        'roberta': [],
        'finbert': []
    }

    for section_name in sections:
        if form in ('10-K', '10-Q'):
            lowered = section_name.lower()
            # FinBERT gets financial statements exclusively
            if 'financial' in lowered or 'statement' in lowered:
                targets = ('finbert',)
            else:
                targets = general_models
        elif form == '8-K':
            # 8-K: all item sections go to all four models
            targets = all_models
        else:
            # Default routing for other filing types
            targets = general_models

        for model in targets:
            routing[model].append(section_name)

    # Remove empty routing
    return {model: names for model, names in routing.items() if names}

def process_sec_filing_with_sections(filing_data: Dict) -> Dict:
    """Extract and route the sections of one SEC filing.

    SIMPLIFIED: the accession number is taken directly from the database row.

    Args:
        filing_data: Row dict with ``id``, ``accession_number`` and optionally
            ``filing_type`` (default '10-K'), ``company_domain`` (default
            'Unknown') and ``url``.

    Returns:
        On success, a dict with the filing identifiers, ``sections``,
        ``model_routing``, ``total_sections`` and
        ``processing_status == 'success'``. On any failure, a dict with the
        identifiers, ``error`` and ``processing_status == 'failed'``; the
        error is logged rather than raised.
    """
    try:
        filing_id = filing_data.get('id')
        filing_type = filing_data.get('filing_type', '10-K')
        company_domain = filing_data.get('company_domain', 'Unknown')
        accession_number = filing_data.get('accession_number')  # DIRECT FROM DATABASE
        filing_url = filing_data.get('url')  # Still keep for reference

        log_info("FilingProcessor", f"Processing {filing_type} for {company_domain}")
        print(f"   📄 Filing ID: {filing_id}")
        print(f"   📑 Accession: {accession_number}")

        # Guard: nothing can be fetched without an accession number
        if not accession_number:
            raise ValueError(f"Missing accession number for filing {filing_id}")

        # Fetch the structured sections using the accession directly
        sections = get_filing_sections(accession_number, filing_type)
        if not sections:
            raise ValueError("No sections extracted")

        log_info("FilingProcessor", f"Extracted {len(sections)} sections")

        # Decide which models see which sections
        model_routing = route_sections_to_models(sections, filing_type)
        routing_summary = [f'{model}: {len(secs)} sections' for model, secs in model_routing.items()]
        print(f"   🎯 Model routing: {routing_summary}")

        # Optional check for sections with an empty name
        if CONFIG['processing']['section_validation']:
            unnamed_count = sum(1 for name in sections if not name)
            if unnamed_count:
                log_warning("FilingProcessor", f"Found {unnamed_count} sections without names")

        # Cache statistics are only shown once there has been at least one hit
        cache_stats = SECTION_CACHE.get_stats()
        if cache_stats['hits'] > 0:
            print(f"   📊 Cache: {cache_stats['hit_rate']:.1f}% hit rate, {cache_stats['size_mb']:.1f}MB used")

        return dict(
            filing_id=filing_id,
            company_domain=company_domain,
            filing_type=filing_type,
            accession_number=accession_number,
            url=filing_url,
            sections=sections,
            model_routing=model_routing,
            total_sections=len(sections),
            processing_status='success',
        )

    except Exception as e:
        failure_context = {
            "filing_id": filing_data.get('id'),
            "accession": filing_data.get('accession_number'),
        }
        log_error("FilingProcessor", "Filing processing failed", e, failure_context)
        return dict(
            filing_id=filing_data.get('id'),
            company_domain=filing_data.get('company_domain', 'Unknown'),
            filing_type=filing_data.get('filing_type', 'Unknown'),
            accession_number=filing_data.get('accession_number'),
            error=str(e),
            processing_status='failed',
        )

@retry_on_connection_error
def get_unprocessed_filings(limit: int = 5) -> List[Dict]:
    """Get SEC filings that haven't been processed yet.

    ENHANCED: retrieves accession_number directly from the database.

    A filing counts as unprocessed when it has an accession number and no row
    in ``system_uno.sec_entities_raw`` references it as ``'SEC_<id>'``.
    Results are ordered by filing date, newest first.

    Args:
        limit: Maximum number of filings to return.

    Returns:
        One dict per filing with the keys ``id``, ``company_domain``,
        ``filing_type``, ``accession_number``, ``url``, ``filing_date`` and
        ``title``.
    """
    # Must stay in the same order as the SELECT list below
    columns = (
        'id',
        'company_domain',
        'filing_type',
        'accession_number',  # ACCESSION NUMBER INCLUDED
        'url',
        'filing_date',
        'title',
    )

    with get_db_connection() as conn:  # PHASE 2: Using context manager
        # The cursor context manager closes the cursor even when the query
        # fails, which the previous explicit close() did not guarantee.
        with conn.cursor() as cursor:
            cursor.execute("""
            SELECT 
                sf.id, 
                sf.company_domain, 
                sf.filing_type, 
                sf.accession_number,  -- DIRECT FROM DATABASE
                sf.url, 
                sf.filing_date, 
                sf.title
            FROM raw_data.sec_filings sf
            LEFT JOIN system_uno.sec_entities_raw ser 
                ON ser.sec_filing_ref = CONCAT('SEC_', sf.id)
            WHERE sf.accession_number IS NOT NULL  -- Must have accession
                AND ser.sec_filing_ref IS NULL     -- Not yet processed
            ORDER BY sf.filing_date DESC
            LIMIT %s
        """, (limit,))

            filings = cursor.fetchall()

        log_info("DatabaseQuery", f"Retrieved {len(filings)} unprocessed filings with accession numbers")

        return [dict(zip(columns, filing)) for filing in filings]

# Smoke test: run section extraction on one unprocessed filing (direct accession lookup)
log_info("Test", "Starting simplified section extraction test (using direct accession)")

test_filings = get_unprocessed_filings(limit=1)

if not test_filings:
    # Nothing left to process, so there is nothing to exercise
    log_info("Test", "No test filings available")
else:
    test_result = process_sec_filing_with_sections(test_filings[0])
    if test_result['processing_status'] != 'success':
        log_error("Test", f"❌ Section extraction failed: {test_result.get('error')}")
    else:
        log_info("Test", f"✅ Successfully extracted {test_result['total_sections']} sections")

print("✅ Cell 2 complete - EdgarTools section extraction ready")

INFO [Test]: Starting simplified section extraction test (using direct accession)
INFO [DatabaseQuery]: Retrieved 1 unprocessed filings with accession numbers
INFO [FilingProcessor]: Processing 10-Q for grail.com
   📄 Filing ID: 425
   📑 Accession: 0001699031-25-000166
INFO [EdgarTools]: Found 10-Q for GRAIL, Inc.
INFO [EdgarTools]: SectionExtractor found 6 sections
      • financial_statements: 46,388 chars
      • mda: 77,570 chars
      • market_risk: 2,134 chars
      • controls_procedures: 1,607 chars
      • legal_proceedings: 236 chars
      • risk_factors: 33,402 chars
INFO [Cache]: Cached sections for 0001699031-25-000166 (6 sections)
INFO [FilingProcessor]: Extracted 6 sections
   🎯 Model routing: ['biobert: 5 sections', 'bert_base: 5 sections', 'roberta: 5 sections', 'finbert: 1 sections']
INFO [Test]: ✅ Successfully extracted 6 sections
✅ Cell 2 complete - EdgarTools section extraction ready


In [3]:
# Cell 3: EntityExtractionPipeline Class and NER Model Loading

import torch
import time
import uuid
import json
from datetime import datetime
from typing import List, Dict, Any, Optional
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed

# Progress banner: marks the start of this cell's work in the notebook output
print("🚀 Loading EntityExtractionPipeline and NER Models...")

class EntityExtractionPipeline:
    """Extensible pipeline for entity extraction from multiple data sources"""
    
    def __init__(self, db_config: Dict, model_config: Dict = None):
        self.db_config = db_config
        self.model_config = model_config or self._get_default_model_config()
        self.loaded_models = {}
        self.pipeline_stats = {
            'documents_processed': 0,
            'total_entities_extracted': 0,
            'processing_time_total': 0,
            'data_sources_supported': ['sec_filings', 'press_releases', 'patents']
        }
        
        print(f"🔧 EntityExtractionPipeline initialized")
        print(f"   📊 Supported data sources: {self.pipeline_stats['data_sources_supported']}")
    
    def _get_default_model_config(self) -> Dict:
        """Default model configuration for biotech entity extraction"""
        return {
            'models': {
                'biobert': {
                    'model_name': 'alvaroalon2/biobert_diseases_ner',
                    'confidence_threshold': 0.5,
                    'description': 'Biomedical entities (diseases, medications, treatments)'
                },
                'bert_base': {
                    'model_name': 'dslim/bert-base-NER',
                    'confidence_threshold': 0.5,
                    'description': 'General entities (persons, organizations, locations)'
                },
                'finbert': {
                    'model_name': 'ProsusAI/finbert',
                    'confidence_threshold': 0.5,
                    'description': 'Financial entities and metrics'
                },
                'roberta': {
                    'model_name': 'Jean-Baptiste/roberta-large-ner-english',
                    'confidence_threshold': 0.6,
                    'description': 'High-precision general entities'
                }
            },
            'routing_rules': {
                'sec_filings': {
                    '10-K': {
                        'financial_sections': ['finbert'],
                        'other_sections': ['biobert', 'bert_base', 'roberta']
                    },
                    '10-Q': {
                        'financial_sections': ['finbert'],
                        'other_sections': ['biobert', 'bert_base', 'roberta']
                    },
                    '8-K': {
                        'all_sections': ['biobert', 'bert_base', 'roberta', 'finbert']
                    }
                },
                'press_releases': {
                    'all_content': ['biobert', 'bert_base', 'roberta']
                },
                'patents': {
                    'all_content': ['biobert', 'bert_base', 'roberta']
                }
            },
            'entity_type_mapping': {
                # BioBERT mappings
                'Disease': 'MEDICAL_CONDITION',
                'Chemical': 'MEDICATION',
                'CHEMICAL': 'MEDICATION',
                'DISEASE': 'MEDICAL_CONDITION',
                'DRUG': 'MEDICATION',
                'Drug': 'MEDICATION',
                'Compound': 'MEDICATION',
                'Treatment': 'THERAPY',
                'Therapy': 'THERAPY',
                
                # BERT-base mappings
                'PER': 'PERSON',
                'ORG': 'ORGANIZATION',
                'LOC': 'LOCATION',
                'MISC': 'MISCELLANEOUS',
                
                # Financial mappings
                'MONEY': 'FINANCIAL',
                'PERCENT': 'FINANCIAL',
                'NUMBER': 'METRIC'
            }
        }
    
    def load_models(self) -> Dict:
        """Load all NER models specified in configuration"""
        print(f"📦 Loading {len(self.model_config['models'])} NER models...")
        
        for model_key, model_info in self.model_config['models'].items():
            try:
                model_name = model_info['model_name']
                print(f"   🧠 Loading {model_key}: {model_name}")
                
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForTokenClassification.from_pretrained(model_name)
                ner_pipeline = pipeline(
                    "ner", 
                    model=model, 
                    tokenizer=tokenizer,
                    aggregation_strategy="average", 
                    device=0 if torch.cuda.is_available() else -1
                )
                
                self.loaded_models[model_key] = {
                    'pipeline': ner_pipeline,
                    'threshold': model_info['confidence_threshold'],
                    'description': model_info['description'],
                    'stats': {
                        'texts_processed': 0,
                        'entities_found': 0,
                        'processing_time': 0
                    }
                }
                
                print(f"      ✓ {model_key} loaded (threshold: {model_info['confidence_threshold']})")
                
            except Exception as e:
                print(f"      ❌ {model_key} failed: {e}")
                # Use BERT-base as fallback for failed models
                if model_key != 'bert_base' and 'bert_base' in self.loaded_models:
                    self.loaded_models[model_key] = self.loaded_models['bert_base']
                    print(f"      🔄 Using BERT-base fallback for {model_key}")
        
        loaded_count = len(self.loaded_models)
        device = "GPU" if torch.cuda.is_available() else "CPU"
        
        print(f"   ✅ Successfully loaded: {loaded_count}/4 models on {device}")
        
        return self.loaded_models
    
    def get_routing_for_content(self, data_source: str, filing_type: str = None, section_name: str = None) -> List[str]:
        """Get model routing for specific content"""
        routing_rules = self.model_config['routing_rules']
        
        if data_source not in routing_rules:
            # Default to all text models for unknown data sources
            return ['biobert', 'bert_base', 'roberta']
        
        source_rules = routing_rules[data_source]
        
        if data_source == 'sec_filings' and filing_type:
            filing_rules = source_rules.get(filing_type, source_rules.get('10-K'))
            
            # Check if it's a financial section
            if section_name and any(fin_word in section_name.lower() 
                                  for fin_word in ['financial', 'statement', 'balance', 'income', 'cash']):
                return filing_rules.get('financial_sections', ['finbert'])
            else:
                return filing_rules.get('other_sections', ['biobert', 'bert_base', 'roberta'])
        
        elif data_source in ['press_releases', 'patents']:
            return source_rules['all_content']
        
        # Default routing
        return ['biobert', 'bert_base', 'roberta']
    
    def extract_entities_from_text(self, text: str, model_name: str, metadata: Dict = None) -> List[Dict]:
        """Extract entities from text using specific model"""
        if model_name not in self.loaded_models or not text.strip():
            return []
        
        try:
            start_time = time.time()
            model_info = self.loaded_models[model_name]
            
            # Process in overlapping chunks to handle long text
            max_length = CONFIG["models"]["max_length"]  # 512 tokens
            overlap = int(max_length * CONFIG["models"]["chunk_overlap"])  # Use configured overlap percentage
            stride = max(1, max_length - overlap)  # 461 token stride
            
            # Create overlapping chunks
            all_raw_entities = []
            for chunk_start in range(0, len(text), stride):
                chunk_end = min(chunk_start + max_length, len(text))
                chunk = text[chunk_start:chunk_end]
                
                # Run NER on chunk
                chunk_entities = model_info["pipeline"](chunk)
                
                # Adjust positions to full text coordinates
                for entity in chunk_entities:
                    entity["start"] = entity.get("start", 0) + chunk_start
                    entity["end"] = entity.get("end", 0) + chunk_start
                    all_raw_entities.append(entity)
                
                # Break if we've processed the entire text
                if chunk_end >= len(text):
                    break
            
            # Deduplicate entities found in overlapping regions
            raw_entities = []
            seen_entities = set()
            for entity in all_raw_entities:
                # Create unique key for entity (position + text + type)
                entity_key = (entity.get("start"), entity.get("end"), entity.get("word"), entity.get("entity_group"))
                if entity_key not in seen_entities:
                    seen_entities.add(entity_key)
                    raw_entities.append(entity)
            
            # Process and filter results
            processed_entities = []
            for entity in raw_entities:
                if entity['score'] >= model_info['threshold']:
                    # Normalize entity type
                    entity_type = self.model_config['entity_type_mapping'].get(
                        entity['entity_group'], entity['entity_group']
                    )
                    
                    processed_entity = {
                        'entity_text': entity['word'].strip(),
                        'entity_type': entity_type,
                        'confidence_score': float(entity['score']),
                        'char_start': entity['start'],
                        'char_end': entity['end'],
                        'model_source': model_name,
                        'original_label': entity['entity_group'],
                        'extraction_id': str(uuid.uuid4()),
                        'extraction_timestamp': datetime.now().isoformat()
                    }
                    
                    # Add metadata if provided
                    if metadata:
                        processed_entity.update(metadata)
                    
                    processed_entities.append(processed_entity)
            
            # Update statistics
            processing_time = time.time() - start_time
            model_info['stats']['texts_processed'] += 1
            model_info['stats']['entities_found'] += len(processed_entities)
            model_info['stats']['processing_time'] += processing_time
            
            return processed_entities
            
        except Exception as e:
            print(f"   ❌ {model_name} extraction failed: {e}")
            return []
    
    def process_sec_filing_sections(self, filing_result: Dict) -> List[Dict]:
        """Process SEC filing sections with appropriate model routing"""
        if filing_result.get('processing_status') != 'success':
            return []
        
        filing_id = filing_result['filing_id']
        filing_type = filing_result['filing_type']
        company_domain = filing_result['company_domain']
        sections = filing_result['sections']
        
        print(f"🔍 Processing {len(sections)} sections for {company_domain} {filing_type}")
        
        all_entities = []
        
        for section_name, section_text in sections.items():
            # Get routing for this section
            applicable_models = self.get_routing_for_content('sec_filings', filing_type, section_name)
            applicable_models = [m for m in applicable_models if m in self.loaded_models]
            
            if not applicable_models:
                continue
            
            print(f"   📑 Processing '{section_name}' with {len(applicable_models)} models")
            
            # Metadata for this section
            section_metadata = {
                'filing_id': filing_id,
                'company_domain': company_domain,
                'filing_type': filing_type,
                'section_name': section_name,
                'sec_filing_ref': f'SEC_{filing_id}',
                'data_source': 'sec_filings'
            }
            
            # Process with each applicable model
            section_entities = []
            for model_name in applicable_models:
                model_entities = self.extract_entities_from_text(section_text, model_name, section_metadata)
                section_entities.extend(model_entities)
                print(f"      • {model_name}: {len(model_entities)} entities")
            
            all_entities.extend(section_entities)
        
        # Merge entities at same positions
        merged_entities = self.merge_position_overlaps(all_entities)
        print(f"   🔗 Merged: {len(all_entities)} → {len(merged_entities)} entities")
        
        return merged_entities
    
    def merge_position_overlaps(self, entities: List[Dict]) -> List[Dict]:
        """Merge entities detected at same position by different models"""
        if not entities:
            return []
        
        # Group entities by position
        position_groups = {}
        
        for entity in entities:
            # Create position key
            pos_key = f"{entity.get('filing_id', 'unknown')}_{entity.get('section_name', 'unknown')}_{entity.get('char_start', 0)}_{entity.get('char_end', 0)}"
            
            if pos_key not in position_groups:
                position_groups[pos_key] = []
            position_groups[pos_key].append(entity)
        
        merged_entities = []
        
        for pos_key, group in position_groups.items():
            if len(group) == 1:
                merged_entities.append(group[0])
            else:
                # Merge multiple detections
                merged = self._merge_entity_group(group)
                merged_entities.append(merged)
        
        return merged_entities
    
    def _merge_entity_group(self, entities: List[Dict]) -> Dict:
        """Merge entities from different models at same position"""
        # Priority for biotech domain: BioBERT > FinBERT > RoBERTa > BERT-base
        priority = {'biobert': 4, 'finbert': 3, 'roberta': 2, 'bert_base': 1}
        
        # Sort by priority, then by confidence
        entities.sort(key=lambda x: (priority.get(x['model_source'], 0), x['confidence_score']), reverse=True)
        
        # Use best entity as base
        best = entities[0].copy()
        
        # Add multi-model metadata
        best['models_detected'] = [e['model_source'] for e in entities]
        best['all_confidences'] = {e['model_source']: e['confidence_score'] for e in entities}
        best['primary_model'] = best['model_source']
        best['entity_variations'] = {e['model_source']: e['entity_text'] for e in entities}
        best['is_merged'] = True
        best['confidence_score'] = max(e['confidence_score'] for e in entities)
        
        return best
    
    def add_data_source_support(self, source_name: str, routing_config: Dict, processor_func: callable = None):
        """Dynamically add support for new data sources"""
        # Add to routing rules
        self.model_config['routing_rules'][source_name] = routing_config
        
        # Add to supported sources
        if source_name not in self.pipeline_stats['data_sources_supported']:
            self.pipeline_stats['data_sources_supported'].append(source_name)
        
        print(f"✅ Added support for data source: {source_name}")
        print(f"   📊 Routing config: {routing_config}")
    
    def get_pipeline_statistics(self) -> Dict:
        """Get comprehensive pipeline statistics"""
        return {
            'pipeline_stats': self.pipeline_stats.copy(),
            'model_stats': {
                name: info['stats'].copy() 
                for name, info in self.loaded_models.items()
            },
            'loaded_models': list(self.loaded_models.keys()),
            'supported_data_sources': self.pipeline_stats['data_sources_supported'],
            'device': "GPU" if torch.cuda.is_available() else "CPU"
        }


# Initialize the EntityExtractionPipeline
# No model_config is passed, so the pipeline builds its default configuration.
entity_pipeline = EntityExtractionPipeline(NEON_CONFIG)

# Load all NER models
# load_models() returns the pipeline's own loaded_models dict (model key -> info).
loaded_models = entity_pipeline.load_models()

# Summary of what is available to the rest of the notebook
print(f"\n✅ EntityExtractionPipeline ready!")
print(f"   🎯 Loaded models: {list(loaded_models.keys())}")
print(f"   🔧 Supported sources: {entity_pipeline.pipeline_stats['data_sources_supported']}")
print(f"   💻 Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

2025-09-08 02:11:06.632270: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757297466.904105      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757297466.979162      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🚀 Loading EntityExtractionPipeline and NER Models...
🔧 EntityExtractionPipeline initialized
   📊 Supported data sources: ['sec_filings', 'press_releases', 'patents']
📦 Loading 4 NER models...
   🧠 Loading biobert: alvaroalon2/biobert_diseases_ner


tokenizer_config.json:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/431M [00:00<?, ?B/s]

Device set to use cuda:0


model.safetensors:   0%|          | 0.00/431M [00:00<?, ?B/s]

      ✓ biobert loaded (threshold: 0.5)
   🧠 Loading bert_base: dslim/bert-base-NER


tokenizer_config.json:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/829 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cuda:0


      ✓ bert_base loaded (threshold: 0.5)
   🧠 Loading finbert: ProsusAI/finbert


tokenizer_config.json:   0%|          | 0.00/252 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/758 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Device set to use cuda:0


model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

      ✓ finbert loaded (threshold: 0.5)
   🧠 Loading roberta: Jean-Baptiste/roberta-large-ner-english


tokenizer_config.json:   0%|          | 0.00/255 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/849 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Device set to use cuda:0


      ✓ roberta loaded (threshold: 0.6)
   ✅ Successfully loaded: 4/4 models on GPU

✅ EntityExtractionPipeline ready!
   🎯 Loaded models: ['biobert', 'bert_base', 'finbert', 'roberta']
   🔧 Supported sources: ['sec_filings', 'press_releases', 'patents']
   💻 Device: GPU


In [4]:
# Cell 4: In-Memory Pipeline with Entity Storage and Relationship Extraction

import psycopg2
from psycopg2.extras import execute_values
import numpy as np
from groq import Groq
import re
from typing import List, Dict, Tuple, Optional
import json
import uuid
import time
import os
from datetime import datetime

# Progress banner: marks the start of this cell's work in the notebook output
print("🚀 Loading In-Memory Pipeline with Storage and Relationship Extraction...")

# ================================================================================
# RELATIONSHIP EXTRACTOR WITH LLAMA 3.1
# ================================================================================

class RelationshipExtractor:
    """Extract company-centric relationships using Llama 3.1"""
    
    def __init__(self, llama_config: Dict = None):
        """Initialize with Llama configuration"""
        self.config = llama_config or CONFIG.get('llama', {})
        self.client = None
        self.stats = {
            'entities_processed': 0,
            'relationships_found': 0,
            'llama_calls': 0,
            'processing_time': 0
        }
        
        # Initialize Groq client if API key available
        api_key = self.config.get('api_key') or os.getenv('GROQ_API_KEY') or user_secrets.get_secret("GROQ_API_KEY")
        if api_key:
            self.client = Groq(api_key=api_key)
            print(f"   ✅ Llama 3.1 client initialized via Groq")
        else:
            print(f"   ⚠️ No Groq API key found - relationship extraction disabled")
    

    def _filter_entities(self, entities: List[Dict]) -> List[Dict]:
        """Filter out non-business entities like SEC section headers and document structure"""
        filtered = []
        
        # Patterns to exclude (SEC filing structure elements)
        exclude_patterns = [
            r'^Item \d+[A-Za-z]?\.?\s*',  # Item 1, Item 2A, etc.
            r'^Part [IVX]+\.?\s*',  # Part I, Part II, etc.
            r'^Section \d+\.?\s*',  # Section 1, Section 2, etc.
            r'^Exhibit \d+\.?\s*',  # Exhibit 1, etc.
            r'^Schedule [A-Z]\.?\s*',  # Schedule A, etc.
            r'^Note \d+\.?\s*',  # Note 1, Note 2, etc.
            r'^Table of Contents',
            r'^Index to',
            r'^Forward.?Looking Statements',
            r'^Risk Factors$',
            r'^Legal Proceedings$',
            r'^Management.s Discussion',
            r'^Quantitative and Qualitative Disclosures',
            r'^Controls and Procedures',
            r'^Financial Statements',
            r'^Signatures$',
            r'^SIGNATURES$'
        ]
        
        for entity in entities:
            entity_text = entity.get('entity_text', '').strip()
            entity_type = entity.get('entity_type', '')
            
            # Skip empty or very short entities
            if not entity_text or len(entity_text) < 3:
                continue
            
            # Skip entities that are just numbers
            if entity_text.replace('.', '').replace(',', '').isdigit():
                continue
            
            # Skip if entity matches SEC structure patterns
            is_structure = False
            for pattern in exclude_patterns:
                if re.match(pattern, entity_text, re.IGNORECASE):
                    is_structure = True
                    break
            
            if is_structure:
                continue
            
            # Skip entities with category '0' that look like SEC headers
            if entity.get('entity_category') == '0' or entity_type == '0':
                # Check if it starts with common SEC filing patterns
                if any(entity_text.startswith(prefix) for prefix in ['Item ', 'Part ', 'Section ', 'Note ']):
                    continue
            
            # Additional filtering for low-quality entities
            if entity_type in ['MISCELLANEOUS', '0']:
                # For miscellaneous entities, apply stricter filtering
                # Keep only if confidence is high and text looks like a real entity
                confidence = entity.get('confidence_score', 0)
                if confidence < 0.8:
                    continue
                
                # Check if it looks like a real company/organization name
                # (contains capitalized words, not all caps, not all lowercase)
                words = entity_text.split()
                if len(words) > 0:
                    has_capital = any(word[0].isupper() for word in words if word)
                    all_caps = entity_text.isupper()
                    all_lower = entity_text.islower()
                    
                    if not has_capital or all_caps or all_lower:
                        continue
            
            # Entity passed all filters
            filtered.append(entity)
        
        print(f"      🎯 Filtered: {len(entities)} → {len(filtered)} entities (removed {len(entities) - len(filtered)} non-business entities)")
        return filtered

    def extract_company_relationships(self, 
                                     entities: List[Dict], 
                                     sections: Dict[str, str],
                                     company_domain: str) -> List[Dict]:
        """Extract relationships between company and all found entities"""
        if not self.client or not entities:
            return []
        # Filter out non-business entities (SEC headers, etc.)
        entities = self._filter_entities(entities)
        if not entities:
            print("      ⚠️ No valid business entities after filtering")
            return []

        
        print(f"   🔍 Analyzing relationships for {company_domain}")
        relationships = []
        
        # Group entities by section for context efficiency
        entities_by_section = {}
        for entity in entities:
            section = entity.get('section_name', 'unknown')
            if section not in entities_by_section:
                entities_by_section[section] = []
            entities_by_section[section].append(entity)
        
        # Process each section's entities
        for section_name, section_entities in entities_by_section.items():
            if section_name not in sections:
                continue
                
            section_text = sections[section_name]
            print(f"      📑 Processing {len(section_entities)} entities in '{section_name}'")
            
            # Analyze each entity's relationship to the company
            for entity in section_entities:
                # Skip self-references
                company_name = company_domain.replace('.com', '').replace('tx', '')
                if entity['entity_text'].lower() == company_name.lower():
                    continue
                
                # Get context around entity
                context = self._get_entity_context(entity, section_text)
                
                # Build and send prompt to Llama
                relationship = self._analyze_relationship(
                    company_domain, entity, context, section_name
                )
                
                if relationship:
                    relationships.append(relationship)
                    self.stats['relationships_found'] += 1
                
                self.stats['entities_processed'] += 1
        
        print(f"   ✅ Found {len(relationships)} relationships from {len(entities)} entities")
        return relationships
    
    def _get_entity_context(self, entity: Dict, section_text: str, window: int = 500) -> str:
        """Get context around an entity"""
        start = max(0, entity['char_start'] - window)
        end = min(len(section_text), entity['char_end'] + window)
        return section_text[start:end]
    
    def _analyze_relationship(self, 
                            company_domain: str,
                            entity: Dict,
                            context: str,
                            section_name: str) -> Optional[Dict]:
        """Analyze a single entity's relationship to the company using Llama"""
        if not self.client:
            return None
        
        try:
            # Build prompt
            prompt = f"""Analyze the business relationship between the filing company and the mentioned entity.

Filing Company: {company_domain}
Entity Found: {entity['entity_text']} (Type: {entity['entity_type']})
Section: {section_name}

Context:
{context}

Based on the context, determine:
1. Relationship type (choose ONE from: PARTNERSHIP, COMPETITOR, REGULATORY, CLINICAL_TRIAL, SUPPLIER, CUSTOMER, INVESTOR, ACQUISITION, LICENSING, RESEARCH, NONE)
2. Relationship direction (company_to_entity or entity_to_company)
3. Business impact (positive, negative, or neutral)
4. Confidence level (high, medium, or low)
5. Brief summary (one sentence)

Format your response as:
TYPE: <relationship_type>
DIRECTION: <direction>
IMPACT: <impact>
CONFIDENCE: <confidence>
SUMMARY: <one_sentence_summary>
"""
            
            # Call Llama 3.1
            response = self.client.chat.completions.create(
                model="llama-3.1-70b-versatile",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.3,
                max_tokens=200
            )
            
            self.stats['llama_calls'] += 1
            llama_response = response.choices[0].message.content
            
            # Parse response
            parsed = self._parse_llama_response(llama_response)
            
            if parsed and parsed['type'] != 'NONE':
                return {
                    'company_domain': company_domain,
                    'entity_text': entity['entity_text'],
                    'entity_type': entity['entity_type'],
                    'entity_id': entity['extraction_id'],
                    'relationship_type': parsed['type'],
                    'relationship_direction': parsed['direction'],
                    'business_impact': parsed['impact'],
                    'confidence_level': parsed['confidence'],
                    'summary': parsed['summary'],
                    'section_name': section_name,
                    'context_used': context[:1000],  # Store first 1000 chars
                    'llama_response': llama_response,
                    'extraction_timestamp': datetime.now().isoformat()
                }
            
            return None
            
        except Exception as e:
            print(f"         ⚠️ Llama analysis failed for {entity['entity_text']}: {e}")
            return None
    
    def _parse_llama_response(self, response: str) -> Optional[Dict]:
        """Parse structured response from Llama"""
        try:
            lines = response.strip().split('\n')
            parsed = {}
            
            for line in lines:
                if ':' in line:
                    key, value = line.split(':', 1)
                    key = key.strip().upper()
                    value = value.strip()
                    
                    if key == 'TYPE':
                        parsed['type'] = value
                    elif key == 'DIRECTION':
                        parsed['direction'] = value
                    elif key == 'IMPACT':
                        parsed['impact'] = value.lower()
                    elif key == 'CONFIDENCE':
                        parsed['confidence'] = value.lower()
                    elif key == 'SUMMARY':
                        parsed['summary'] = value
            
            # Validate required fields
            required = ['type', 'direction', 'impact', 'confidence', 'summary']
            if all(field in parsed for field in required):
                return parsed
            
            return None
            
        except Exception:
            return None

# ================================================================================
# ENHANCED STORAGE WITH ATOMIC TRANSACTIONS
# ================================================================================

class PipelineEntityStorage:
    """Store extracted entities and their relationships in PostgreSQL.

    Entities are written to ``system_uno.sec_entities_raw`` and relationships
    to ``system_uno.entity_relationships``.  Both inserts for a filing run in
    one transaction, so either everything for the filing is committed or
    nothing is.  Running counters are kept in ``storage_stats``.
    """

    # Number of rows execute_values packs into a single INSERT statement
    _PAGE_SIZE = 100

    def __init__(self, db_config: Dict):
        """Keep the connection settings and make sure the schema exists.

        Args:
            db_config: Keyword arguments passed verbatim to
                ``psycopg2.connect``.
        """
        self.db_config = db_config
        self.storage_stats = {
            'total_entities_stored': 0,
            'total_relationships_stored': 0,
            'filings_processed': 0,
            'transactions_completed': 0,
            'transactions_failed': 0,
            'merged_entities': 0,
            'single_model_entities': 0,
            'failed_inserts': 0
        }

        # Ensure table structures
        self._ensure_table_structures()

    @staticmethod
    def _close_quietly(conn) -> None:
        """Close *conn* if it was opened, without masking the caller's outcome."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            print(f"   ⚠️ Could not close database connection: {e}")

    def _ensure_table_structures(self):
        """Create/upgrade the entity and relationship tables (best effort).

        Failures are reported on stdout and otherwise ignored so that the
        notebook cell keeps loading; the connection is closed in every case.
        """
        conn = None
        try:
            conn = psycopg2.connect(**self.db_config)
            cursor = conn.cursor()

            # Add any pipeline columns missing from the entity table
            self._ensure_entity_table(cursor)

            # Create relationship table if not exists
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS system_uno.entity_relationships (
                    relationship_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    company_domain TEXT NOT NULL,
                    entity_id UUID REFERENCES system_uno.sec_entities_raw(extraction_id),
                    entity_text TEXT NOT NULL,
                    entity_type TEXT,
                    relationship_type TEXT NOT NULL,
                    relationship_direction TEXT,
                    business_impact TEXT,
                    confidence_level TEXT,
                    summary TEXT,
                    section_name TEXT,
                    context_used TEXT,
                    llama_response TEXT,
                    filing_ref TEXT,
                    extraction_timestamp TIMESTAMP,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Create indexes for relationships
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_relationships_company 
                ON system_uno.entity_relationships(company_domain)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_relationships_type 
                ON system_uno.entity_relationships(relationship_type)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_relationships_impact 
                ON system_uno.entity_relationships(business_impact)
            """)

            conn.commit()
            cursor.close()
            print("   ✅ Database tables verified/created")

        except Exception as e:
            print(f"   ⚠️ Table structure setup failed: {e}")
        finally:
            # Closing without commit discards a half-applied schema change
            self._close_quietly(conn)

    def _ensure_entity_table(self, cursor):
        """Add every pipeline column that ``sec_entities_raw`` does not have yet.

        Args:
            cursor: Open cursor; the caller owns the transaction and commits.
        """
        # Check existing columns
        cursor.execute("""
            SELECT column_name 
            FROM information_schema.columns 
            WHERE table_schema = 'system_uno' 
            AND table_name = 'sec_entities_raw'
        """)

        existing_columns = {row[0] for row in cursor.fetchall()}

        # Required columns for pipeline (name -> SQL type/default clause)
        required_columns = {
            'models_detected': 'TEXT[]',
            'all_confidences': 'JSONB',
            'primary_model': 'TEXT',
            'entity_variations': 'JSONB',
            'is_merged': 'BOOLEAN DEFAULT FALSE',
            'section_name': 'TEXT',
            'data_source': 'TEXT DEFAULT \'sec_filings\'',
            'extraction_timestamp': 'TIMESTAMP DEFAULT CURRENT_TIMESTAMP',
            'original_label': 'TEXT'
        }

        # Add missing columns.  The names and types interpolated here come
        # only from the literal dict above, never from external input.
        for col_name, col_type in required_columns.items():
            if col_name not in existing_columns:
                cursor.execute(f'ALTER TABLE system_uno.sec_entities_raw ADD COLUMN {col_name} {col_type}')
                print(f"      ✓ Added column: {col_name}")

    @staticmethod
    def _build_entity_row(entity: Dict, filing_ref: str) -> tuple:
        """Map one entity dict onto the column order of the entity INSERT."""
        model_source = entity.get('model_source', 'unknown')

        models_detected = entity.get('models_detected', [model_source])
        if not isinstance(models_detected, list):
            models_detected = [str(models_detected)]

        # Single-model entities carry no merge metadata; derive it from the
        # entity's own model, score and text.
        all_confidences = entity.get(
            'all_confidences', {model_source: entity.get('confidence_score', 0.0)})
        entity_variations = entity.get(
            'entity_variations', {model_source: entity.get('entity_text', '')})

        return (
            entity.get('extraction_id', str(uuid.uuid4())),
            entity.get('company_domain', ''),
            entity.get('entity_text', '').strip()[:1000],
            entity.get('entity_type', 'UNKNOWN'),  # stored in entity_category
            float(entity.get('confidence_score', 0.0)),
            int(entity.get('char_start', 0)),
            int(entity.get('char_end', 0)),
            (entity.get('surrounding_text') or '')[:2000],
            entity.get('sec_filing_ref', filing_ref),
            models_detected,
            json.dumps(all_confidences),
            entity.get('primary_model', model_source),
            json.dumps(entity_variations),
            entity.get('is_merged', False),
            entity.get('section_name', ''),
            entity.get('data_source', 'sec_filings'),
            entity.get('extraction_timestamp'),
            entity.get('original_label', '')
        )

    @staticmethod
    def _build_relationship_row(rel: Dict, filing_ref: str) -> tuple:
        """Map one relationship dict onto the column order of the relationship INSERT."""
        return (
            rel.get('company_domain', ''),
            rel.get('entity_id'),
            rel.get('entity_text', ''),
            rel.get('entity_type', ''),
            rel.get('relationship_type', ''),
            rel.get('relationship_direction', ''),
            rel.get('business_impact', ''),
            rel.get('confidence_level', ''),
            rel.get('summary', ''),
            rel.get('section_name', ''),
            rel.get('context_used', ''),
            rel.get('llama_response', ''),
            filing_ref,
            rel.get('extraction_timestamp')
        )

    def store_entities_and_relationships(self, 
                                        entities: List[Dict], 
                                        relationships: List[Dict],
                                        filing_ref: str) -> bool:
        """Store entities and relationships in a single atomic transaction.

        Args:
            entities: Entity dicts; an empty list is a no-op that succeeds.
            relationships: Relationship dicts (may be empty).
            filing_ref: Filing reference written to every relationship and
                used for entities that carry no ``sec_filing_ref``.

        Returns:
            ``True`` when the transaction committed (or there was nothing to
            store), ``False`` when it failed and was rolled back.
        """
        if not entities:
            return True

        conn = None
        try:
            conn = psycopg2.connect(**self.db_config)
            # Explicit transaction: nothing is visible until commit()
            conn.autocommit = False
            cursor = conn.cursor()

            print(f"   💾 Storing {len(entities)} entities and {len(relationships)} relationships...")

            # Count merged vs single-model entities BEFORE storage
            merged_count = sum(1 for e in entities if e.get('is_merged', False))
            single_model_count = len(entities) - merged_count

            # Step 1: Store entities
            entity_data = [self._build_entity_row(e, filing_ref) for e in entities]

            entity_query = """
                INSERT INTO system_uno.sec_entities_raw 
                (extraction_id, company_domain, entity_text, entity_category, 
                 confidence_score, character_start, character_end, surrounding_text, 
                 sec_filing_ref, models_detected, all_confidences, primary_model,
                 entity_variations, is_merged, section_name, data_source,
                 extraction_timestamp, original_label)
                VALUES %s
                ON CONFLICT (extraction_id) DO UPDATE SET
                    models_detected = EXCLUDED.models_detected,
                    all_confidences = EXCLUDED.all_confidences,
                    is_merged = EXCLUDED.is_merged
            """

            execute_values(cursor, entity_query, entity_data, page_size=self._PAGE_SIZE)
            # cursor.rowcount only covers the LAST page execute_values sent,
            # so it under-counts batches larger than one page.  Every row is
            # either inserted or updated by the upsert, so count the rows sent.
            entities_stored = len(entity_data)

            # Step 2: Store relationships
            relationships_stored = 0
            if relationships:
                relationship_data = [
                    self._build_relationship_row(rel, filing_ref) for rel in relationships
                ]

                relationship_query = """
                    INSERT INTO system_uno.entity_relationships
                    (company_domain, entity_id, entity_text, entity_type,
                     relationship_type, relationship_direction, business_impact,
                     confidence_level, summary, section_name, context_used,
                     llama_response, filing_ref, extraction_timestamp)
                    VALUES %s
                """

                execute_values(cursor, relationship_query, relationship_data,
                               page_size=self._PAGE_SIZE)
                relationships_stored = len(relationship_data)

            # Commit transaction
            conn.commit()
            cursor.close()

            # Update statistics
            self.storage_stats['total_entities_stored'] += entities_stored
            self.storage_stats['total_relationships_stored'] += relationships_stored
            self.storage_stats['transactions_completed'] += 1
            self.storage_stats['merged_entities'] += merged_count
            self.storage_stats['single_model_entities'] += single_model_count

            print(f"      ✅ Transaction complete: {entities_stored} entities, {relationships_stored} relationships")
            print(f"         • Merged entities: {merged_count}, Single-model: {single_model_count}")

            return True

        except Exception as e:
            print(f"      ❌ Transaction failed: {e}")
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # A dead connection cannot roll back; still record the failure
                    print(f"      ⚠️ Rollback failed: {rollback_error}")
            self.storage_stats['transactions_failed'] += 1
            self.storage_stats['failed_inserts'] += len(entities)
            return False

        finally:
            self._close_quietly(conn)

    def get_storage_verification(self, filing_ref: str) -> Dict:
        """Read back aggregate counts for what is stored under *filing_ref*.

        Returns:
            ``{'filing_ref', 'entities': {...}, 'relationships': {...}}`` with
            counts (and the average entity confidence), or ``{}`` when the
            queries could not be run.
        """
        conn = None
        try:
            conn = psycopg2.connect(**self.db_config)
            cursor = conn.cursor()

            # Get entity statistics
            cursor.execute("""
                SELECT 
                    COUNT(*) as total_entities,
                    COUNT(DISTINCT entity_category) as unique_categories,
                    COUNT(DISTINCT section_name) as sections_processed,
                    AVG(confidence_score) as avg_confidence
                FROM system_uno.sec_entities_raw
                WHERE sec_filing_ref = %s
            """, (filing_ref,))

            entity_stats = cursor.fetchone()

            # Get relationship statistics
            cursor.execute("""
                SELECT 
                    COUNT(*) as total_relationships,
                    COUNT(DISTINCT relationship_type) as relationship_types,
                    COUNT(CASE WHEN business_impact = 'positive' THEN 1 END) as positive_impact,
                    COUNT(CASE WHEN business_impact = 'negative' THEN 1 END) as negative_impact,
                    COUNT(CASE WHEN business_impact = 'neutral' THEN 1 END) as neutral_impact
                FROM system_uno.entity_relationships
                WHERE filing_ref = %s
            """, (filing_ref,))

            rel_stats = cursor.fetchone()

            cursor.close()

            return {
                'filing_ref': filing_ref,
                'entities': {
                    'total': entity_stats[0] if entity_stats else 0,
                    'categories': entity_stats[1] if entity_stats else 0,
                    'sections': entity_stats[2] if entity_stats else 0,
                    # AVG() is NULL when the filing has no entities
                    'avg_confidence': float(entity_stats[3]) if entity_stats and entity_stats[3] else 0
                },
                'relationships': {
                    'total': rel_stats[0] if rel_stats else 0,
                    'types': rel_stats[1] if rel_stats else 0,
                    'positive': rel_stats[2] if rel_stats else 0,
                    'negative': rel_stats[3] if rel_stats else 0,
                    'neutral': rel_stats[4] if rel_stats else 0
                }
            }

        except Exception as e:
            print(f"   ❌ Verification failed: {e}")
            return {}

        finally:
            self._close_quietly(conn)

# ================================================================================
# REFACTORED MAIN PIPELINE WITH IN-MEMORY PROCESSING
# ================================================================================

# Initialize components
# Module-level singletons used by the pipeline functions below.
# Constructing PipelineEntityStorage runs _ensure_table_structures(), so this
# line connects to the database as soon as the cell is executed.
pipeline_storage = PipelineEntityStorage(NEON_CONFIG)
# NOTE(review): RelationshipExtractor is defined outside this excerpt; it
# presumably sets up the Llama/Groq client on construction — confirm.
relationship_extractor = RelationshipExtractor()

def process_filing_with_pipeline(filing_data: Dict) -> Dict:
    """Run one filing through section, entity and relationship extraction and store it.

    The stages run in order and share their results in memory: section
    extraction, entity extraction, relationship extraction, one atomic storage
    transaction, then a read-back verification.

    Args:
        filing_data: Filing record; ``filing_type`` and ``company_domain``
            are required, ``id`` is used to build the ``SEC_<id>`` reference.

    Returns:
        A result dict that always contains ``success``, ``filing_id`` and
        ``processing_time``.  Failed runs also carry ``error``; completed runs
        carry the extraction/storage counts, the verification dict and up to
        three sample entities and relationships.  Exceptions raised by any
        stage are converted into a failed result instead of propagating.
    """
    # Bound before the try block so the except handler can always use it
    start_time = time.time()

    try:
        # Step 1: Extract sections (Cell 2 function)
        print(f"\n📄 Processing {filing_data['filing_type']} for {filing_data['company_domain']}")
        section_result = process_sec_filing_with_sections(filing_data)

        if section_result['processing_status'] != 'success':
            return {
                'success': False,
                'filing_id': filing_data.get('id'),
                'error': section_result.get('error', 'Section extraction failed'),
                'processing_time': time.time() - start_time
            }

        # Keep sections in memory for context retrieval
        sections_dict = section_result['sections']

        # Step 2: Extract entities (Cell 3 function) - keep in memory
        entities = entity_pipeline.process_sec_filing_sections(section_result)

        if not entities:
            return {
                'success': False,
                'filing_id': filing_data.get('id'),
                'error': 'No entities extracted',
                'processing_time': time.time() - start_time
            }

        print(f"   🔍 Extracted {len(entities)} entities")

        # Step 3: Extract relationships using in-memory entities and sections
        relationships = relationship_extractor.extract_company_relationships(
            entities, 
            sections_dict,
            filing_data['company_domain']
        )

        # Step 4: Store everything in single transaction
        filing_ref = f"SEC_{filing_data.get('id')}"
        storage_success = pipeline_storage.store_entities_and_relationships(
            entities, 
            relationships,
            filing_ref
        )

        # Step 5: Verify storage
        verification = pipeline_storage.get_storage_verification(filing_ref)

        processing_time = time.time() - start_time

        result = {
            'success': storage_success,
            'filing_id': filing_data.get('id'),
            'company_domain': filing_data.get('company_domain'),
            'filing_type': filing_data.get('filing_type'),
            'sections_processed': len(sections_dict),
            'entities_extracted': len(entities),
            'relationships_found': len(relationships),
            'entities_stored': verification.get('entities', {}).get('total', 0),
            'relationships_stored': verification.get('relationships', {}).get('total', 0),
            'processing_time': round(processing_time, 2),
            'verification': verification,
            'sample_entities': entities[:3],
            'sample_relationships': relationships[:3]
        }

        if not storage_success:
            # Without this, callers can only report "Unknown error" for a
            # rolled-back storage transaction.
            result['error'] = 'Storage transaction failed'

        return result

    except Exception as e:
        return {
            'success': False,
            'filing_id': filing_data.get('id'),
            'error': str(e),
            'processing_time': time.time() - start_time
        }

def process_filings_batch(limit: int = 3) -> Dict:
    """Fetch up to *limit* unprocessed filings and run each through the pipeline.

    Filings are processed one after another with a one-second pause between
    them.  Per-filing progress is printed, and the counters of
    ``entity_pipeline`` and ``pipeline_storage`` are advanced by the number of
    successful filings.

    Returns:
        ``{'success': False, 'message': ...}`` when there is nothing to
        process; otherwise a summary dict with success/failure counts, entity
        and relationship totals, timings and the list of per-filing results.
    """
    print(f"\n🚀 Processing batch of {limit} SEC filings with in-memory pipeline...")

    started_at = time.time()

    # Get unprocessed filings
    filings = get_unprocessed_filings(limit)
    if not filings:
        return {'success': False, 'message': 'No filings to process'}

    filing_count = len(filings)
    print(f"📊 Found {filing_count} filings to process")

    outcomes = []
    ok_count = 0
    entity_total = 0
    relationship_total = 0

    for position, filing in enumerate(filings, start=1):
        print(f"\n[{position}/{filing_count}] Processing {filing['filing_type']} for {filing['company_domain']}")

        outcome = process_filing_with_pipeline(filing)
        outcomes.append(outcome)

        if not outcome['success']:
            print(f"   ❌ Failed: {outcome.get('error', 'Unknown error')}")
        else:
            ok_count += 1
            entity_total += outcome.get('entities_extracted', 0)
            relationship_total += outcome.get('relationships_found', 0)

            print(f"   ✅ Success: {outcome['entities_extracted']} entities, {outcome['relationships_found']} relationships")
            print(f"   ⏱️ Processing time: {outcome['processing_time']}s")

            # Preview at most two of the sample relationships
            samples = outcome.get('sample_relationships', [])
            for rel in samples[:2]:
                print(f"      • {rel['entity_text']} → {rel['relationship_type']} ({rel['business_impact']})")

        # Pause between filings, but not after the final one
        if position != filing_count:
            time.sleep(1)

    elapsed = time.time() - started_at

    # Fold this batch into the long-running pipeline/storage counters
    entity_pipeline.pipeline_stats['documents_processed'] += ok_count
    entity_pipeline.pipeline_stats['total_entities_extracted'] += entity_total
    pipeline_storage.storage_stats['filings_processed'] += ok_count

    summary = {
        'success': ok_count > 0,
        'filings_processed': filing_count,
        'successful_filings': ok_count,
        'failed_filings': filing_count - ok_count,
        'total_entities_extracted': entity_total,
        'total_relationships_found': relationship_total,
        'batch_processing_time': round(elapsed, 2),
        'avg_time_per_filing': round(elapsed / filing_count, 2) if filings else 0,
        'results': outcomes
    }
    return summary

# ================================================================================
# QUICK ACCESS FUNCTIONS
# ================================================================================

def test_pipeline(company_domain: str = None, search_limit: int = 25):
    """Run the pipeline on a single unprocessed filing as a smoke test.

    Args:
        company_domain: When given, only a filing whose ``company_domain``
            equals this value is used.  Previously the argument was accepted
            but ignored.
        search_limit: How many unprocessed filings to fetch when looking for
            one that matches *company_domain*.

    Returns:
        The result dict of ``process_filing_with_pipeline`` for the chosen
        filing, or ``None`` when no suitable filing is available.
    """
    print("\n🧪 Testing in-memory pipeline...")

    if company_domain:
        # get_unprocessed_filings is only called with a limit here, so fetch
        # a window of candidates and filter on this side.
        candidates = get_unprocessed_filings(limit=search_limit) or []
        filings = [
            filing for filing in candidates
            if filing.get('company_domain') == company_domain
        ][:1]
    else:
        filings = get_unprocessed_filings(limit=1)

    if not filings:
        print("❌ No test filings available")
        return None

    result = process_filing_with_pipeline(filings[0])

    if result['success']:
        print(f"\n✅ Pipeline test successful!")
        print(f"   📊 Sections: {result['sections_processed']}")
        print(f"   🔍 Entities: {result['entities_extracted']}")
        print(f"   🔗 Relationships: {result['relationships_found']}")
        print(f"   💾 Stored: {result['entities_stored']} entities, {result['relationships_stored']} relationships")
        print(f"   ⏱️ Time: {result['processing_time']}s")
    else:
        print(f"\n❌ Pipeline test failed: {result.get('error')}")

    return result

# Announce that the cell finished loading and show the two entry points.
for _banner_line in (
    "\n✅ In-Memory Pipeline Components Ready!",
    "   🎯 RelationshipExtractor with Llama 3.1",
    "   💾 Atomic storage for entities + relationships",
    "   🚀 In-memory processing (no DB round-trips)",
    "   📊 Usage: batch_results = process_filings_batch(limit=5)",
    "   🧪 Test: test_result = test_pipeline()",
):
    print(_banner_line)

🚀 Loading In-Memory Pipeline with Storage and Relationship Extraction...
   ✅ Database tables verified/created
   ✅ Llama 3.1 client initialized via Groq

✅ In-Memory Pipeline Components Ready!
   🎯 RelationshipExtractor with Llama 3.1
   💾 Atomic storage for entities + relationships
   🚀 In-memory processing (no DB round-trips)
   📊 Usage: batch_results = process_filings_batch(limit=5)
   🧪 Test: test_result = test_pipeline()


In [None]:
# Cell 5: Execute Pipeline - Process SEC Filings
#
# Driver cell: lists the unprocessed filings, runs one batch through
# process_filings_batch(), prints a summary and, when it is available, the
# analytics report.

print("="*80)
print("🚀 STARTING SEC FILING PROCESSING PIPELINE")
print("="*80)

# Configure processing parameters
BATCH_SIZE = 3  # Number of filings to process in this run
# NOTE(review): this flag only controls the warning below; nothing in this
# cell passes it on to the pipeline, so relationships are still extracted.
PROCESS_RELATIONSHIPS = True  # Set to False to skip relationship extraction

# Check if we have a Groq API key for relationship extraction.
# The secrets lookup is guarded so that a failing lookup still falls through
# to the environment variable instead of aborting the cell.
try:
    groq_secret = user_secrets.get_secret("GROQ_API_KEY")
except Exception:
    groq_secret = None
has_groq = bool(groq_secret or os.getenv('GROQ_API_KEY'))
if not has_groq and PROCESS_RELATIONSHIPS:
    print("⚠️ No Groq API key found - relationship extraction will be skipped")
    print("   To enable relationships, add GROQ_API_KEY to Kaggle secrets")

# Check for available unprocessed filings
print("\n📊 Checking for unprocessed filings...")
available_filings = get_unprocessed_filings(limit=10)
print(f"   Found {len(available_filings)} unprocessed filings")

if available_filings:
    print("\n📋 Available filings to process:")
    for i, filing in enumerate(available_filings[:5], 1):
        print(f"   {i}. {filing['company_domain']} - {filing['filing_type']} ({filing['filing_date']})")
    
    # Process the batch
    print(f"\n🔄 Processing {min(BATCH_SIZE, len(available_filings))} filings...")
    print("-"*60)
    
    # Run the pipeline (it fetches its own list of unprocessed filings)
    batch_results = process_filings_batch(limit=BATCH_SIZE)
    
    # Display results summary
    print("\n" + "="*80)
    print("📊 PROCESSING SUMMARY")
    print("="*80)
    
    if batch_results['success']:
        print(f"✅ Successfully processed {batch_results['successful_filings']}/{batch_results['filings_processed']} filings")
        print(f"   • Total entities extracted: {batch_results['total_entities_extracted']:,}")
        print(f"   • Total relationships found: {batch_results['total_relationships_found']:,}")
        print(f"   • Total processing time: {batch_results['batch_processing_time']:.1f}s")
        print(f"   • Average time per filing: {batch_results['avg_time_per_filing']:.1f}s")
        
        # Show detailed results for each filing
        print(f"\n📈 Detailed Results:")
        for i, result in enumerate(batch_results['results'], 1):
            if result['success']:
                print(f"\n   Filing {i}: {result['company_domain']} - {result['filing_type']}")
                print(f"      ✓ Sections: {result['sections_processed']}")
                print(f"      ✓ Entities: {result['entities_extracted']}")
                print(f"      ✓ Relationships: {result['relationships_found']}")
                print(f"      ✓ Time: {result['processing_time']:.1f}s")
            else:
                print(f"\n   Filing {i}: FAILED - {result.get('error', 'Unknown error')}")
        
        # Show pipeline statistics
        print(f"\n📊 Pipeline Statistics:")
        print(f"   • Documents processed (total): {entity_pipeline.pipeline_stats['documents_processed']}")
        print(f"   • Entities extracted (total): {entity_pipeline.pipeline_stats['total_entities_extracted']}")
        print(f"   • Storage transactions: {pipeline_storage.storage_stats['transactions_completed']} successful, {pipeline_storage.storage_stats['transactions_failed']} failed")
        print(f"   • Merged entities: {pipeline_storage.storage_stats['merged_entities']}")
        print(f"   • Single-model entities: {pipeline_storage.storage_stats['single_model_entities']}")
        
    else:
        print(f"❌ Processing failed: {batch_results.get('message', 'Unknown error')}")
    
    # Generate analytics report.  The report function is defined in the NEXT
    # cell, so on a top-to-bottom run it does not exist yet; look it up
    # instead of raising NameError after the batch has already been processed.
    print("\n" + "="*80)
    _analytics_report = globals().get('generate_pipeline_analytics_report')
    if callable(_analytics_report):
        _analytics_report()
    else:
        print("ℹ️ Analytics report not loaded yet - run Cell 6, then call generate_pipeline_analytics_report()")
    
else:
    print("\n⚠️ No unprocessed filings found in raw_data.sec_filings")
    print("   All available filings have already been processed")
    print("\n💡 To add new filings:")
    print("   1. Insert new records into raw_data.sec_filings with accession_number")
    print("   2. Make sure the accession_number is valid (20 characters)")
    print("   3. Run this cell again to process them")

print("\n✅ Pipeline execution complete!")

In [None]:
# Cell 6: Production Analytics and Monitoring Interface

def generate_pipeline_analytics_report() -> None:
    """Print an analytics dashboard for the entity extraction pipeline.

    Two independent sections are printed: aggregate figures queried from
    ``system_uno.sec_entities_raw`` (rows with ``data_source = 'sec_filings'``)
    and the in-memory statistics of ``entity_pipeline`` and
    ``pipeline_storage``.  Each section is best effort: a failure is reported
    on stdout and the rest of the report is still produced.
    """
    print("\n" + "="*80)
    print("📊 ENTITYEXTRACTIONPIPELINE ANALYTICS DASHBOARD")
    print("="*80)
    
    # Database overview with enhanced metrics
    conn = None
    try:
        conn = psycopg2.connect(**NEON_CONFIG)
        cursor = conn.cursor()
        
        # Enhanced database statistics
        cursor.execute("""
            SELECT 
                COUNT(*) as total_entities,
                COUNT(DISTINCT company_domain) as companies,
                COUNT(DISTINCT sec_filing_ref) as filings,
                COUNT(DISTINCT entity_category) as entity_types,
                AVG(confidence_score) as avg_confidence,
                COUNT(*) FILTER (WHERE is_merged = true) as merged_entities,
                COUNT(DISTINCT primary_model) as active_models,
                COUNT(DISTINCT section_name) as sections_processed,
                COUNT(*) FILTER (WHERE section_name IS NOT NULL AND section_name != '') as entities_with_sections,
                MAX(extraction_timestamp) as last_extraction
            FROM system_uno.sec_entities_raw
            WHERE data_source = 'sec_filings'
        """)
        
        db_overview = cursor.fetchone()
        cursor.close()
        
        if db_overview and db_overview[0] > 0:
            total_entities = db_overview[0]
            entities_with_sections = db_overview[8]
            section_success_rate = entities_with_sections / total_entities * 100
            
            # AVG() is NULL when no row has a confidence score; formatting
            # None with :.3f would abort the whole overview.
            avg_confidence = db_overview[4]
            avg_confidence_text = f"{avg_confidence:.3f}" if avg_confidence is not None else "n/a"
            
            print(f"\n📈 DATABASE OVERVIEW:")
            print(f"   Total Entities Extracted: {db_overview[0]:,}")
            print(f"   Companies Processed: {db_overview[1]:,}")
            print(f"   SEC Filings Analyzed: {db_overview[2]:,}")
            print(f"   Entity Categories Found: {db_overview[3]:,}")
            print(f"   Average Confidence Score: {avg_confidence_text}")
            print(f"   Multi-Model Entities: {db_overview[5]:,} ({db_overview[5]/db_overview[0]*100:.1f}%)")
            print(f"   Active Models: {db_overview[6]:,}")
            print(f"   Unique Sections Found: {db_overview[7]:,}")
            print(f"   🎯 SECTION SUCCESS RATE: {entities_with_sections:,}/{total_entities:,} ({section_success_rate:.1f}%)")
            print(f"   Last Extraction: {db_overview[9] or 'Never'}")
            
            # Alert if section success rate is low (ignored for tiny samples)
            if section_success_rate < 90 and total_entities > 10:
                print(f"   🚨 WARNING: Section success rate is {section_success_rate:.1f}% - Pipeline routing issue!")
            elif section_success_rate >= 90:
                print(f"   ✅ EXCELLENT: Section success rate is {section_success_rate:.1f}% - Pipeline working correctly!")
        else:
            print(f"\n📈 DATABASE OVERVIEW: No entities found - database is clean for testing")
                
    except Exception as e:
        print(f"   ❌ Could not retrieve analytics: {e}")
    finally:
        # Always release the connection, including when a query or a print
        # above failed.
        if conn is not None:
            try:
                conn.close()
            except Exception as close_error:
                print(f"   ⚠️ Could not close database connection: {close_error}")
    
    # Pipeline statistics
    try:
        pipeline_stats = entity_pipeline.get_pipeline_statistics()
        
        print(f"\n🔧 PIPELINE STATISTICS:")
        print(f"   Documents Processed: {pipeline_stats['pipeline_stats']['documents_processed']:,}")
        print(f"   Total Entities Found: {pipeline_stats['pipeline_stats']['total_entities_extracted']:,}")
        print(f"   Processing Time: {pipeline_stats['pipeline_stats']['processing_time_total']:.2f}s")
        print(f"   Device: {pipeline_stats['device']}")
        print(f"   Loaded Models: {', '.join(pipeline_stats['loaded_models'])}")
        print(f"   Supported Sources: {', '.join(pipeline_stats['supported_data_sources'])}")
        
        # Individual model statistics (models that processed nothing are skipped)
        print(f"\n📊 INDIVIDUAL MODEL PERFORMANCE:")
        for model_name, stats in pipeline_stats['model_stats'].items():
            if stats['texts_processed'] > 0:
                entities_per_text = stats['entities_found'] / stats['texts_processed']
                avg_time = stats['processing_time'] / stats['texts_processed']
                print(f"   {model_name:>12}: {stats['texts_processed']:>4} texts | {stats['entities_found']:>5} entities | {entities_per_text:>4.1f} avg/text | {avg_time:>4.2f}s avg")
        
        # Storage statistics (only once something has been stored)
        storage_stats = pipeline_storage.storage_stats
        if storage_stats['total_entities_stored'] > 0:
            print(f"\n💾 STORAGE STATISTICS:")
            print(f"   Entities Stored: {storage_stats['total_entities_stored']:,}")
            print(f"   Filings Processed: {storage_stats['filings_processed']:,}")
            print(f"   Merged Entities: {storage_stats['merged_entities']:,}")
            print(f"   Single-Model Entities: {storage_stats['single_model_entities']:,}")
            print(f"   Failed Inserts: {storage_stats['failed_inserts']:,}")
            
            merge_rate = storage_stats['merged_entities'] / storage_stats['total_entities_stored'] * 100
            print(f"   Multi-Model Detection Rate: {merge_rate:.1f}%")
    
    except Exception as e:
        print(f"\n🔧 Pipeline statistics unavailable: {e}")
    
    print("\n" + "="*80)
    print("✅ EntityExtractionPipeline Analytics Complete!")
    print("="*80)

# Print the command cheat-sheet followed by the interface-ready notices.
for _notice in (
    "\n🎯 PRODUCTION COMMANDS:",
    "   • Process new filings: batch_results = process_filings_batch(limit=5)",
    "   • Check results:       generate_pipeline_analytics_report()",
    "   • View statistics:     context_retriever.get_retrieval_statistics()",
    "\n✅ EntityExtractionPipeline Production Interface Ready!",
    "🔧 SINGLE ENTRY POINT: process_filings_batch() - ensures all extractions use section-based pipeline",
    "📊 Database cleared - ready for fresh testing with guaranteed section extraction!",
):
    print(_notice)