In [2]:
# Cell 0: Minimal Auto-Logger Bootstrap
# 
# Purpose: Initialize basic logging ONLY - all other setup in Cell 1
# This cell runs first to enable logging for subsequent cells

import os
import sys
import importlib.util

# Get GitHub token for logger access
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
GITHUB_TOKEN = user_secrets.get_secret("GITHUB_TOKEN")

print("🔧 Setting up minimal logger bootstrap...")

# Clone/update repo for logger access
# NOTE: the token is embedded in the URL, so it is passed on the git command line below.
REPO_URL = f"https://{GITHUB_TOKEN}@github.com/amiralpert/SmartReach.git"
LOCAL_PATH = "/kaggle/working/SmartReach"

# Lines starting with `!` are IPython shell escapes; {NAME} is interpolated from Python.
# stdout and stderr are redirected to /dev/null, so a failed pull/clone is silent here;
# a failed clone only shows up later as the "Logger module not found" message.
if os.path.exists(LOCAL_PATH):
    !cd {LOCAL_PATH} && git pull origin main > /dev/null 2>&1
else:
    !git clone {REPO_URL} {LOCAL_PATH} > /dev/null 2>&1

# Add to path (only once, so re-running the cell does not duplicate the entry)
if f'{LOCAL_PATH}/BizIntel' not in sys.path:
    sys.path.insert(0, f'{LOCAL_PATH}/BizIntel')

# Initialize logger with minimal setup: load auto_logger.py directly from its file path
# and register it in sys.modules under the name "auto_logger".
logger_path = f"{LOCAL_PATH}/BizIntel/Scripts/KaggleLogger/auto_logger.py"
if os.path.exists(logger_path):
    spec = importlib.util.spec_from_file_location("auto_logger", logger_path)
    auto_logger = importlib.util.module_from_spec(spec)
    sys.modules["auto_logger"] = auto_logger
    spec.loader.exec_module(auto_logger)
    
    # Simple logger setup - database manager will be provided by Cell 1
    # NOTE: only the module is loaded here; `logger` is None in both branches of this cell.
    logger = None  # Will be properly initialized after Cell 1 runs
    print("✅ Auto-logger module loaded")
else:
    logger = None
    print("⚠️  Logger module not found - continuing without logging")

print("✅ Cell 0 bootstrap complete")

🔧 Setting up minimal logger bootstrap...
✅ Auto-logger module loaded - database setup in Cell 1
✅ Cell 0 bootstrap complete - run Cell 1 for full setup


In [2]:
# Cell 1: GitHub Setup and Enhanced Import Configuration

# Install required packages first
# NOTE(review): no versions are pinned, so a fresh session may resolve different releases.
# NOTE(review): `%pip install` targets the running kernel's environment; consider it over `!pip`.
!pip install edgartools transformers torch accelerate huggingface_hub requests beautifulsoup4 'lxml[html_clean]' uuid numpy newspaper3k --quiet
!pip install -U bitsandbytes --quiet

import os
import sys
import importlib
import importlib.util
import psycopg2
from psycopg2.extras import execute_values
from psycopg2 import pool
import time
import json
import pickle
import traceback
from pathlib import Path
from functools import wraps
from contextlib import contextmanager
from collections import OrderedDict
from typing import Dict, List, Optional, Any
from datetime import datetime
from edgar import set_identity

# ============================================================================
# CENTRALIZED CONFIGURATION - All settings in one place
# ============================================================================

# Use Kaggle secrets for all sensitive credentials
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

# Database Configuration (using Kaggle secrets for security)
# This dict is unpacked with ** into the connection pool constructor in
# DatabaseManager._initialize_pool, so every key here is passed as a connection keyword.
NEON_CONFIG = {
    'host': user_secrets.get_secret("NEON_HOST"),
    'database': user_secrets.get_secret("NEON_DATABASE"),
    'user': user_secrets.get_secret("NEON_USER"), 
    'password': user_secrets.get_secret("NEON_PASSWORD"),
    'sslmode': 'require'
}

# Master Configuration Dictionary - PHASE 3 ENHANCED
# Single source of settings read by the helpers defined later in this cell
# (retry decorator, logging helpers, DatabaseManager, CheckpointManager, cache).
CONFIG = {
    # GitHub Settings
    # NOTE: 'repo_url' embeds the token; avoid printing it.
    'github': {
        'token': user_secrets.get_secret("GITHUB_TOKEN"),
        'repo_url': f"https://{user_secrets.get_secret('GITHUB_TOKEN')}@github.com/amiralpert/SmartReach.git",
        'local_path': "/kaggle/working/SmartReach"
    },
    
    # Database Settings
    'database': NEON_CONFIG,
    
    # Connection Pool Settings - PHASE 3
    # min/max are the pool bounds; the keepalives_* values are forwarded as
    # connection keywords by DatabaseManager._initialize_pool.
    'pool': {
        'min_connections': 2,
        'max_connections': 10,
        'keepalives': 1,
        'keepalives_idle': 30,
        'keepalives_interval': 10,
        'keepalives_count': 5
    },
    
    # Connection Retry Settings (read by retry_on_connection_error)
    'retry': {
        'max_attempts': 3,
        'initial_delay': 1,  # seconds
        'exponential_base': 2,
        'max_delay': 30  # seconds
    },
    
    # Model Configuration - PHASE 3 ENHANCED
    'models': {
        'confidence_threshold': 0.8,  # Minimum NER confidence required to store an entity
        'batch_size': 16,  # Number of text chunks processed at one time through NER
        'max_length': 512,  # Max chunk size for NER (512-token limit for BERT models)
        'chunk_overlap': 0.1,  # 10% overlap between chunks for complete entity extraction
        'warm_up_enabled': True,  # PHASE 3: Model warm-up
        'warm_up_text': 'Pfizer announced FDA approval for new cancer drug targeting BRCA mutations.'
    },
    
    # Cache Settings - PHASE 2 ENHANCED
    # NOTE(review): only 'max_size_mb' is read in this cell (SECTION_CACHE); the other
    # keys may be consumed elsewhere - confirm.
    'cache': {
        'enabled': True,
        'max_size_mb': 100,  # Maximum cache size in MB
        'ttl': 3600,  # seconds
        'eviction_policy': 'LRU'  # Least Recently Used
    },
    
    # Processing Settings - PHASE 3 ENHANCED
    'processing': {
        'filing_batch_size': 1,
        'filing_query_limit': 1,       # Explicit limit for get_unprocessed_filings()
        'enable_relationships': True,   # Enable/disable relationship extraction
        'entity_batch_size': 10000,    # Max entities per database insert
        'section_validation': True,    # Enforce section name validation
        'debug_mode': False,           # When True, log_error appends a stack trace
        'max_insert_batch': 50000,     # Maximum batch for database inserts
        'deprecation_warnings': True,  # Show warnings for deprecated functions
        'checkpoint_enabled': True,    # PHASE 3: Enable checkpointing
        'checkpoint_dir': '/kaggle/working/checkpoints',  # PHASE 3: Checkpoint directory
        'deduplication_threshold': 0.85  # PHASE 3: Similarity threshold for dedup
    },
    
    # Llama 3.1 Configuration - CENTRALIZED FOR OPTIMIZATION
    'llama': {
        'enabled': True,
        'model_name': 'meta-llama/Llama-3.1-8B-Instruct',
        'batch_size': 15,              # Entities per Llama call (for future batching)
        'max_new_tokens': 50,          # Reduced from 200 for speed
        'context_window': 400,         # Reduced from 1000 chars for speed
        'temperature': 0.3,            # Sampling temperature
        'entity_context_window': 400,  # Reduced from 500 chars for entity context
        'test_max_tokens': 50,         # For model testing
        'min_confidence_filter': 0.8,  # Entity filtering threshold
        'timeout_seconds': 30,         # Timeout for model calls
    },
    
    # EdgarTools Settings (identity string passed to edgar.set_identity)
    'edgar': {
        'identity': "SmartReach BizIntel amir@leanbio.consulting"
    }
}

# Fail fast when a required secret came back empty.
# NOTE: these checks run after CONFIG was built, so 'repo_url' has already been
# formatted with whatever value the token had at that point.
if not CONFIG['github']['token']:
    raise ValueError("❌ GITHUB_TOKEN is required in Kaggle secrets")

if not CONFIG['database']['password']:
    raise ValueError("❌ NEON_PASSWORD is required in Kaggle secrets")

# Summary of the effective settings (host name only; no credentials are printed).
print("✅ Configuration loaded from Kaggle secrets")
print(f"   Database: {CONFIG['database']['host']}")
print(f"   Processing: Filing batch={CONFIG['processing']['filing_batch_size']}, Query limit={CONFIG['processing']['filing_query_limit']}")
print(f"   Llama 3.1: Enabled={CONFIG['llama']['enabled']}, Tokens={CONFIG['llama']['max_new_tokens']}, Context={CONFIG['llama']['context_window']}")
print(f"   Cache: Max size={CONFIG['cache']['max_size_mb']}MB, TTL={CONFIG['cache']['ttl']}s")
print(f"   Relationships: {'Enabled' if CONFIG['processing']['enable_relationships'] else 'Disabled'}")

# ============================================================================
# CONNECTION RETRY DECORATOR
# ============================================================================

def retry_on_connection_error(func):
    """Decorator that retries ``func`` with exponential backoff on connection errors.

    Only ``psycopg2.OperationalError`` and ``psycopg2.InterfaceError`` trigger a
    retry; any other exception propagates immediately. Settings are read from
    ``CONFIG['retry']`` at call time:

    * ``max_attempts``     - total number of calls (at least one is always made)
    * ``initial_delay``    - base wait in seconds
    * ``exponential_base`` - multiplier applied per failed attempt
    * ``max_delay``        - upper bound for a single wait in seconds

    Returns:
        The wrapped function; it returns whatever ``func`` returns.

    Raises:
        The last connection error once every attempt has failed.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        retry_cfg = CONFIG['retry']
        # Guard against a zero/negative setting silently skipping the call altogether.
        max_attempts = max(1, retry_cfg['max_attempts'])
        delay = retry_cfg['initial_delay']

        for attempt in range(max_attempts):
            try:
                return func(*args, **kwargs)
            except (psycopg2.OperationalError, psycopg2.InterfaceError) as e:
                if attempt == max_attempts - 1:
                    # Out of attempts: report once, then re-raise the original error.
                    error_msg = f"ERROR [ConnectionRetry]: Failed after {max_attempts} attempts - {type(e).__name__}: {str(e)}"
                    print(error_msg)
                    # `logger` is created by another cell and may be missing or None.
                    active_logger = globals().get('logger')
                    if active_logger:
                        active_logger.log(error_msg)
                    raise

                # Attempts remain: back off exponentially, capped at max_delay.
                wait_time = min(delay * (retry_cfg['exponential_base'] ** attempt), retry_cfg['max_delay'])
                print(f"WARNING [ConnectionRetry]: Attempt {attempt + 1}/{max_attempts} failed. Retrying in {wait_time}s...")
                time.sleep(wait_time)
    return wrapper

# ============================================================================
# PHASE 3: ENHANCED ERROR LOGGING WITH STACK TRACES
# ============================================================================

def log_error(component: str, message: str, exception: Optional[Exception] = None, context: Optional[dict] = None):
    """Print and return a standardized ``ERROR [component]: message`` line.

    Args:
        component: Short tag naming the subsystem that reports the error.
        message: Human-readable description of what failed.
        exception: Optional exception; its type name and text are appended.
        context: Optional mapping appended as `` | Context: {...}`` when non-empty.

    Returns:
        The formatted message, which is also printed to stdout.

    When ``CONFIG['processing']['debug_mode']`` is true and an exception is
    given, a stack trace is appended as well.
    """
    if exception:
        error_msg = f"ERROR [{component}]: {message} - {type(exception).__name__}: {str(exception)}"
        # Add stack trace for debugging
        if CONFIG['processing'].get('debug_mode', False):
            if isinstance(exception, BaseException) and exception.__traceback__ is not None:
                # Use the traceback carried by the exception itself so the trace is
                # correct even when this is called after the `except` block has ended.
                stack_trace = "".join(
                    traceback.format_exception(type(exception), exception, exception.__traceback__)
                )
            else:
                # Exception was never raised (no traceback attached): fall back to
                # whatever exception is currently being handled, if any.
                stack_trace = traceback.format_exc()
            error_msg += f"\nStack trace:\n{stack_trace}"
    else:
        error_msg = f"ERROR [{component}]: {message}"

    if context:
        error_msg += f" | Context: {context}"

    print(error_msg)  # Auto-logger captures this
    return error_msg

def log_warning(component: str, message: str, context: dict = None):
    """Print and return a ``WARNING [component]: message`` line.

    A non-empty ``context`` is appended as `` | Context: {...}``.
    """
    pieces = [f"WARNING [{component}]: {message}"]
    if context:
        pieces.append(f" | Context: {context}")
    warning_msg = "".join(pieces)
    print(warning_msg)
    return warning_msg

def log_info(component: str, message: str):
    """Print and return an ``INFO [component]: message`` line."""
    formatted = f"INFO [{component}]: {message}"
    print(formatted)
    return formatted

# ============================================================================
# PHASE 3: ENHANCED DATABASE MANAGER WITH CONNECTION POOLING
# ============================================================================

class DatabaseManager:
    """Singleton database manager with connection pooling.

    Every ``DatabaseManager()`` call returns the same object. The underlying
    ``psycopg2`` ``ThreadedConnectionPool`` is stored on the class, created on
    first construction, and re-created on construction if it has been closed.
    """

    _instance = None  # the single shared manager object
    _pool = None      # shared connection pool (class-level, not per instance)

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(DatabaseManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every DatabaseManager() call, so only (re)build the pool
        # when there is none or when close_all() has shut the previous one down.
        pool_ref = DatabaseManager._pool
        if pool_ref is None or pool_ref.closed:
            self._initialize_pool()

    def _initialize_pool(self):
        """Create the connection pool from ``CONFIG['pool']`` and ``CONFIG['database']``.

        Raises:
            Whatever the pool constructor raises (logged first, then re-raised).
        """
        try:
            DatabaseManager._pool = psycopg2.pool.ThreadedConnectionPool(
                CONFIG['pool']['min_connections'],
                CONFIG['pool']['max_connections'],
                **CONFIG['database'],
                keepalives=CONFIG['pool']['keepalives'],
                keepalives_idle=CONFIG['pool']['keepalives_idle'],
                keepalives_interval=CONFIG['pool']['keepalives_interval'],
                keepalives_count=CONFIG['pool']['keepalives_count']
            )
            log_info("DatabaseManager", f"Connection pool initialized with {CONFIG['pool']['max_connections']} max connections")
        except Exception as e:
            log_error("DatabaseManager", "Failed to initialize connection pool", e)
            raise

    @contextmanager
    def get_connection(self):
        """Borrow a connection from the pool for the duration of a ``with`` block.

        Commits when the block finishes normally; rolls back and re-raises when the
        block (or the commit) raises. The connection is always handed back to the pool.
        """
        conn = None
        try:
            conn = DatabaseManager._pool.getconn()
            yield conn
            conn.commit()
        except Exception:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # A rollback on a dead connection can fail too; report it but keep
                    # the original exception as the one the caller sees.
                    log_warning(
                        "DatabaseManager",
                        f"Rollback failed: {type(rollback_error).__name__}: {rollback_error}"
                    )
            raise
        finally:
            if conn is not None:
                DatabaseManager._pool.putconn(conn)

    def close_all(self):
        """Close all connections in the pool. Safe to call more than once."""
        pool_ref = DatabaseManager._pool
        # closeall() on an already-closed pool raises, so skip it in that case.
        if pool_ref is not None and not pool_ref.closed:
            pool_ref.closeall()
            log_info("DatabaseManager", "All connections closed")

    def get_pool_status(self) -> Dict:
        """Return the pool's ``minconn``/``maxconn``/``closed`` values.

        Returns ``{'status': 'not initialized'}`` when no pool exists yet.
        """
        pool_ref = DatabaseManager._pool
        if pool_ref is not None:
            return {
                'minconn': pool_ref.minconn,
                'maxconn': pool_ref.maxconn,
                'closed': pool_ref.closed
            }
        return {'status': 'not initialized'}

# Initialize global database manager
# Constructing it builds the connection pool right away (see DatabaseManager.__init__),
# so this line raises if the pool cannot be created.
DB_MANAGER = DatabaseManager()

# PHASE 2: Keep backward compatibility
@contextmanager
def get_db_connection():
    """Legacy entry point that delegates to ``DB_MANAGER.get_connection()``.

    Yields the connection produced by that context manager, which is entered
    and exited around the caller's ``with`` body.
    """
    pooled_context = DB_MANAGER.get_connection()
    with pooled_context as connection:
        yield connection

# ============================================================================
# PHASE 3: CHECKPOINT MANAGER FOR FAILURE RECOVERY
# ============================================================================

class CheckpointManager:
    """Manage pipeline checkpoints (pickle files) for failure recovery.

    Checkpoints are written as ``<name>.pkl`` inside ``checkpoint_dir``. Listing,
    cleanup and the "latest checkpoint" lookup only consider files matching
    ``checkpoint_*.pkl``, so a custom ``checkpoint_name`` without that prefix is
    saved but not discovered by those methods.
    """

    def __init__(self, checkpoint_dir: str = None):
        """Use ``checkpoint_dir`` (default: ``CONFIG['processing']['checkpoint_dir']``),
        creating the directory if it does not exist."""
        self.checkpoint_dir = Path(checkpoint_dir or CONFIG['processing']['checkpoint_dir'])
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.current_checkpoint = None  # path of the checkpoint most recently saved by this instance
        log_info("CheckpointManager", f"Initialized with directory: {self.checkpoint_dir}")

    def save_checkpoint(self, state: Dict, checkpoint_name: str = None) -> Optional[str]:
        """Save pipeline state to a checkpoint file.

        Args:
            state: Mapping to persist. It is copied (shallowly) before the
                ``checkpoint_metadata`` entry is added, so the caller's dict is
                left untouched.
            checkpoint_name: File stem; defaults to ``checkpoint_<YYYYmmdd_HHMMSS>``.

        Returns:
            The checkpoint path as a string, or ``None`` if saving failed (the
            failure is logged, not raised).
        """
        tmp_path = None
        try:
            if not checkpoint_name:
                checkpoint_name = f"checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

            checkpoint_path = self.checkpoint_dir / f"{checkpoint_name}.pkl"

            # Add metadata to a copy so the caller's state is not mutated.
            payload = dict(state)
            payload['checkpoint_metadata'] = {
                'created_at': datetime.now().isoformat(),
                'pipeline_version': '3.0',
                # NOTE(review): hash() of a str is salted per interpreter process, so
                # this value is only comparable within the same kernel session.
                'config_hash': hash(str(CONFIG))
            }

            # Write to a sibling temp file and rename it into place, so an
            # interrupted write never leaves a truncated "latest" checkpoint.
            tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
            with open(tmp_path, 'wb') as f:
                pickle.dump(payload, f)
            tmp_path.replace(checkpoint_path)

            self.current_checkpoint = checkpoint_path
            log_info("CheckpointManager", f"Saved checkpoint: {checkpoint_name}")
            return str(checkpoint_path)

        except Exception as e:
            log_error("CheckpointManager", "Failed to save checkpoint", e)
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    pass  # temp file was never created or is already gone
            return None

    def load_checkpoint(self, checkpoint_path: str = None) -> Optional[Dict]:
        """Load pipeline state from a checkpoint.

        Resolution order: the explicit ``checkpoint_path``, then the checkpoint
        last saved by this instance, then the most recently modified
        ``checkpoint_*.pkl`` in the directory.

        Returns:
            The unpickled state, or ``None`` when nothing is found or loading fails.
        """
        try:
            if not checkpoint_path:
                checkpoint_path = self.current_checkpoint

            if not checkpoint_path:
                # Find latest checkpoint
                checkpoints = list(self.checkpoint_dir.glob("checkpoint_*.pkl"))
                if not checkpoints:
                    log_warning("CheckpointManager", "No checkpoints found")
                    return None
                checkpoint_path = max(checkpoints, key=lambda p: p.stat().st_mtime)

            # SECURITY: unpickling can execute arbitrary code; only load checkpoint
            # files that this pipeline wrote itself.
            with open(checkpoint_path, 'rb') as f:
                state = pickle.load(f)

            log_info("CheckpointManager", f"Loaded checkpoint: {Path(checkpoint_path).name}")
            return state

        except Exception as e:
            log_error("CheckpointManager", "Failed to load checkpoint", e)
            return None

    def list_checkpoints(self) -> List[Dict]:
        """List available ``checkpoint_*.pkl`` files, newest first.

        Each entry has ``name``, ``path``, ``size_mb`` and ``modified`` (ISO timestamp).
        """
        checkpoints = []
        for cp_file in self.checkpoint_dir.glob("checkpoint_*.pkl"):
            try:
                stats = cp_file.stat()
            except OSError:
                # File disappeared between glob() and stat(); skip it.
                continue
            checkpoints.append({
                'name': cp_file.stem,
                'path': str(cp_file),
                'size_mb': stats.st_size / (1024 * 1024),
                'modified': datetime.fromtimestamp(stats.st_mtime).isoformat()
            })
        return sorted(checkpoints, key=lambda x: x['modified'], reverse=True)

    def cleanup_old_checkpoints(self, keep_last: int = 5):
        """Delete all but the ``keep_last`` most recently modified checkpoints."""
        checkpoints = list(self.checkpoint_dir.glob("checkpoint_*.pkl"))
        checkpoints.sort(key=lambda p: p.stat().st_mtime, reverse=True)

        for cp_file in checkpoints[keep_last:]:
            try:
                cp_file.unlink()
                log_info("CheckpointManager", f"Deleted old checkpoint: {cp_file.name}")
            except OSError as e:
                # Best effort: report and move on to the next file.
                log_warning("CheckpointManager", f"Could not delete {cp_file.name}: {e}")

# Initialize global checkpoint manager
# Uses CONFIG['processing']['checkpoint_dir'] and creates that directory if needed.
CHECKPOINT_MANAGER = CheckpointManager()

# ============================================================================
# PHASE 2: SIZE-LIMITED LRU CACHE
# ============================================================================

class SizeLimitedLRUCache:
    """LRU cache bounded by an approximate total size in MB.

    ``cache`` maps key -> value in least- to most-recently-used order.
    The size charged for each entry is remembered at insert time so that
    eviction and replacement subtract exactly what was added.
    """

    def __init__(self, max_size_mb: int):
        self.max_size_bytes = max_size_mb * 1024 * 1024
        self.cache = OrderedDict()
        self._sizes = {}        # key -> bytes charged when the entry was stored
        self.current_size = 0   # sum of the charged sizes
        self.hits = 0
        self.misses = 0

    def _estimate_size(self, value, _seen=None) -> int:
        """Estimate size of cached value in bytes.

        Strings are measured by their UTF-8 length. dict/list/tuple/set values
        are measured recursively, because ``sys.getsizeof`` alone only counts the
        container and not the objects it holds. Anything else uses ``sys.getsizeof``.
        """
        if isinstance(value, str):
            return len(value.encode('utf-8'))

        shallow = sys.getsizeof(value)
        if not isinstance(value, (dict, list, tuple, set, frozenset)):
            return shallow

        # Track visited containers so self-referencing structures terminate.
        if _seen is None:
            _seen = set()
        if id(value) in _seen:
            return 0
        _seen.add(id(value))

        if isinstance(value, dict):
            members = list(value.keys()) + list(value.values())
        else:
            members = value
        return shallow + sum(self._estimate_size(member, _seen) for member in members)

    def _drop(self, key) -> None:
        """Remove ``key`` and release the size that was charged for it."""
        value = self.cache.pop(key)
        charged = self._sizes.pop(key, None)
        if charged is None:
            # Entry was placed into `cache` without going through put().
            charged = self._estimate_size(value)
        self.current_size -= charged

    def get(self, key: str):
        """Get item from cache, marking it most recently used. Returns None on a miss."""
        if key in self.cache:
            self.hits += 1
            # Move to end (most recently used)
            self.cache.move_to_end(key)
            return self.cache[key]
        self.misses += 1
        return None

    def put(self, key: str, value, size: int = None):
        """Put item in cache with LRU eviction.

        Args:
            key: Cache key; an existing entry under the same key is replaced.
            value: Object to store.
            size: Bytes to charge for the entry; estimated from ``value`` if omitted.

        An item larger than the whole cache is not stored (a warning is logged).
        """
        if size is None:
            size = self._estimate_size(value)

        # Release the old entry first so a replacement does not evict other
        # entries to make room for space the old value was already using.
        if key in self.cache:
            self._drop(key)

        if size > self.max_size_bytes:
            log_warning("Cache", f"Not caching {key}: {size} bytes exceeds limit of {self.max_size_bytes} bytes")
            return

        # Remove old entries if needed
        while self.current_size + size > self.max_size_bytes and self.cache:
            evicted_key = next(iter(self.cache))  # least recently used
            self._drop(evicted_key)
            log_info("Cache", f"Evicted {evicted_key} to maintain size limit")

        # Add new entry (appended last = most recently used)
        self.cache[key] = value
        self._sizes[key] = size
        self.current_size += size

    def get_stats(self) -> dict:
        """Get cache statistics: entries, size_mb, hits, misses, hit_rate (percent)."""
        lookups = self.hits + self.misses
        hit_rate = (self.hits / lookups * 100) if lookups > 0 else 0
        return {
            'entries': len(self.cache),
            'size_mb': self.current_size / (1024 * 1024),
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate
        }

# Initialize global cache for EdgarTools sections
# Size limit comes from CONFIG['cache']['max_size_mb'].
SECTION_CACHE = SizeLimitedLRUCache(CONFIG['cache']['max_size_mb'])

# ============================================================================
# PHASE 2: DEPRECATION WARNING SYSTEM
# ============================================================================

def deprecated(replacement_func: str = None):
    """Decorator factory that flags a function as deprecated.

    Each call to the decorated function first emits a "Deprecation" warning via
    ``log_warning`` (when ``CONFIG['processing']['deprecation_warnings']`` is on),
    naming ``replacement_func`` if one was given, and then runs the function as usual.
    """
    def decorator(func):
        def _notice() -> str:
            # Built at call time so it reflects the function's current __name__.
            text = f"Function '{func.__name__}' is deprecated and will be removed."
            if replacement_func:
                text += f" Use '{replacement_func}' instead."
            return text

        @wraps(func)
        def wrapper(*args, **kwargs):
            warnings_enabled = CONFIG['processing']['deprecation_warnings']
            if warnings_enabled:
                log_warning("Deprecation", _notice())
            return func(*args, **kwargs)

        return wrapper

    return decorator

# ============================================================================
# GITHUB SETUP
# ============================================================================

print("\n📦 Setting up GitHub repository...")
local_path = CONFIG['github']['local_path']
repo_url = CONFIG['github']['repo_url']

# Clone or update repo with force pull
# `!` lines are IPython shell escapes; their exit status is not checked, so the
# "successfully" messages below are logged regardless of what git reported.
if os.path.exists(local_path):
    log_info("GitHub", f"Repository exists at {local_path}")
    log_info("GitHub", "Force updating from main branch")
    # reset --hard discards any local changes in the working copy.
    !cd {local_path} && git fetch origin
    !cd {local_path} && git reset --hard origin/main
    !cd {local_path} && git pull origin main
    log_info("GitHub", "Repository updated successfully")
    
    # Show current commit
    !cd {local_path} && echo "Current commit:" && git log --oneline -1
else:
    log_info("GitHub", f"Cloning repository to {local_path}")
    # NOTE: repo_url contains the GitHub token.
    !git clone {repo_url} {local_path}
    log_info("GitHub", "Repository cloned successfully")

# Clear any cached modules from previous runs
# NOTE(review): this is a substring match, so any loaded module whose name contains
# 'auto_logger' or 'clean' is dropped from sys.modules - including unrelated ones.
modules_to_clear = [key for key in sys.modules.keys() if 'auto_logger' in key.lower() or 'clean' in key.lower()]
for mod in modules_to_clear:
    del sys.modules[mod]
    log_info("ModuleCache", f"Cleared cached module: {mod}")

# Add to Python path for regular imports (remove first so the entry ends up at position 0)
if f'{local_path}/BizIntel' in sys.path:
    sys.path.remove(f'{local_path}/BizIntel')
sys.path.insert(0, f'{local_path}/BizIntel')

log_info("Setup", "Python path configured for SEC entity extraction")

# Configure EdgarTools authentication - REQUIRED by SEC
set_identity(CONFIG['edgar']['identity'])
log_info("EdgarTools", f"Identity configured: {CONFIG['edgar']['identity']}")

# Status banner. Most lines restate CONFIG values or are fixed text.
# NOTE: `logger` is not defined in this cell; it must already exist (Cell 0 sets it),
# otherwise the "Logging" line raises NameError.
print("🚀 SEC ENTITY EXTRACTION ENGINE INITIALIZED - PHASE 3 PRODUCTION-READY")
print("="*80)
print(f"✅ GitHub: Repository at {CONFIG['github']['local_path']}")
print(f"✅ Database: Connected to {CONFIG['database']['host']}")
print(f"✅ EdgarTools: Configured as '{CONFIG['edgar']['identity']}'")
print(f"✅ Logging: {'Enabled' if logger else 'Disabled'}")
print(f"✅ Section Validation: {'ENFORCED' if CONFIG['processing']['section_validation'] else 'Disabled'}")
print(f"✅ PHASE 2 Enhancements:")
print(f"   • Database context manager: get_db_connection()")
print(f"   • Size-limited LRU cache: {CONFIG['cache']['max_size_mb']}MB")
print(f"   • Batch size limits: Max {CONFIG['processing']['max_insert_batch']} per insert")
print(f"   • Deprecation warnings: {'Enabled' if CONFIG['processing']['deprecation_warnings'] else 'Disabled'}")
print(f"✅ PHASE 3 PRODUCTION FEATURES:")
print(f"   • Connection Pool: {CONFIG['pool']['max_connections']} max connections")
print(f"   • Checkpoint System: Recovery from failures enabled")
# NOTE: log_error only appends a stack trace when CONFIG['processing']['debug_mode'] is True.
print(f"   • Enhanced Error Logging: Stack traces included")
print(f"   • Model Warm-up: Enabled for faster first inference")
print(f"   • Entity Deduplication: {CONFIG['processing']['deduplication_threshold']} threshold")
print("="*80)

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m60.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m65.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m90.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m68.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [3

2025-09-13 18:10:26.380699: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757787026.586030      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757787026.646647      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🚀 Loading Enhanced EntityExtractionPipeline with Quality Control...
🔧 Enhanced EntityExtractionPipeline initialized with quality control
📦 Loading 3 NER models with handlers...
   🧠 Loading biobert: alvaroalon2/biobert_diseases_ner


Device set to use cuda:0


      ✓ biobert loaded with BioBERTHandler
   🧠 Loading bert_base: dslim/bert-base-NER


Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cuda:0


      ✓ bert_base loaded with BERTHandler
   🧠 Loading roberta: Jean-Baptiste/roberta-large-ner-english


Device set to use cuda:0


      ✓ roberta loaded with RoBERTaHandler
   ✅ Successfully loaded: 3 models

✅ Enhanced EntityExtractionPipeline ready!
   🎯 Loaded models: ['biobert', 'bert_base', 'roberta']
   🔍 Quality control: Enabled
   📊 Consensus requirement: 2 models
   💻 Device: GPU
🚀 Loading In-Memory Pipeline with Storage and Local Llama 3.1-8B...
   ✅ Database tables verified/created
   🔐 Logging in to HuggingFace...
   ✅ Logged in to HuggingFace
   ⚙️ Configuring 4-bit quantization...
   📥 Loading Llama 3.1-8B-Instruct (this may take a minute)...
   ✅ Llama 3.1-8B loaded successfully (4-bit quantized)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


   🧪 Test response: system

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

user

What is a partnership? ...

✅ In-Memory Pipeline Components Ready!
   🎯 RelationshipExtractor with Llama 3.1
   💾 Atomic storage for entities + relationships
   🚀 In-memory processing (no DB round-trips)
   📊 Usage: batch_results = process_filings_batch(limit=5)
   🧪 Test: test_result = test_pipeline()
🚀 STARTING SEC FILING PROCESSING PIPELINE
📝 Relationship extraction enabled with local Llama 3.1-8B
   ℹ️ Using local Llama 3.1-8B for relationship extraction

📊 Checking for unprocessed filings...
INFO [DatabaseQuery]: Retrieved 10 unprocessed filings (excluded 1 problematic)
   Found 10 unprocessed filings

📋 Available filings to process:
   1. grail.com - 8-K (2025-05-16)
   2. grail.com - 10-Q (2025-05-14)
   3. grail.com - 8-K (2025-05-13)
   4. exactsciences.com - 10-Q (2025-05-01)
   5. exactsciences.com - 8-K (2025-05-01)

🔄 Processing 3 filings...
------------------------------------

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


      • biobert: 4 quality entities
      • bert_base: 47 quality entities




      • roberta: 33 quality entities
   ✅ Extracted 48 high-quality entities
   📊 Filtered out 0 low-quality entities
   🔍 Extracted 48 entities
   💾 Storing 48 entities to database...
      ✅ Entities stored: 48 entities
         • Merged entities: 0, Single-model: 48
   🤖 Starting Llama 3.1 relationship extraction (this may take several minutes)...
   🔍 Analyzing relationships for grail.com
      🎯 Filtered: 48 → 44 entities (removed 4 non-business entities)
      🎯 Total entities to analyze: 44
      📑 Processing 44 entities in 'full_document'
         🔬 Analyzing: death
         🔬 Analyzing: D. C
         ⏳ Progress: 10/44 entities (22%)
         🔬 Analyzing: Act
         🔬 Analyzing: Compensation Committee
         ⏳ Progress: 20/44 entities (45%)
         🔬 Analyzing: Company
         ⏱️  Estimated time remaining: 3.0 minutes
         🔬 Analyzing: Company
         ⏳ Progress: 30/44 entities (68%)
         🔬 Analyzing: Company
         🔬 Analyzing: Inc.
         ⏳ Progress: 40/44 

In [3]:
# Cell 2: Database Functions and ORM-like Models with Batching - FIXED HANGING ISSUE

# ============================================================================
# FIX STDOUT/STDERR AFTER CELL CANCELLATION
# ============================================================================
# This fixes the "I/O operation on closed file" error when restarting after cancel

import sys
# Probe stdout with an empty write + flush; if that raises, rebuild both output
# streams from the running IPython kernel.
try:
    # Test if stdout is working
    sys.stdout.write('')
    sys.stdout.flush()
except (ValueError, AttributeError, OSError):
    # stdout is broken, need to restore it
    try:
        from IPython import get_ipython
        from ipykernel.iostream import OutStream
        
        ip = get_ipython()
        # Only possible inside a kernel; when there is none, nothing is restored
        # and no message is emitted.
        if ip and hasattr(ip, 'kernel'):
            # Restore stdout and stderr using kernel's output streams
            sys.stdout = OutStream(ip.kernel.session, ip.kernel.iopub_socket, 'stdout')
            sys.stderr = OutStream(ip.kernel.session, ip.kernel.iopub_socket, 'stderr')
            print("✅ Restored stdout/stderr after cancellation")
    except Exception as e:
        # Use warnings rather than print, since stdout itself may be unusable here.
        import warnings
        warnings.warn(f"Could not restore output streams: {e}. You may need to restart the kernel.")

# Now safe to continue with Cell 2
import edgar
from edgar import Filing, find, set_identity, Company
from edgar.documents import parse_html
from edgar.documents.extractors.section_extractor import SectionExtractor
import requests
import re
import signal
from functools import wraps
from typing import Dict, List, Optional, Tuple
from bs4 import BeautifulSoup

print("Starting cell 2")


# Ensure identity is set
# NOTE: CONFIG is defined in Cell 1, so that cell must have run in this kernel first.
set_identity(CONFIG['edgar']['identity'])

# ============================================================================
# TIMEOUT HANDLER FOR EDGARTOOLS API CALLS - FIX FOR HANGING
# ============================================================================

class TimeoutError(Exception):
    """Raised by ``timeout_handler`` when a wrapped call runs past its time limit.

    NOTE: this name shadows the built-in ``TimeoutError`` inside this notebook's
    namespace; ``except TimeoutError`` here catches this class, not the built-in.
    """
    pass

def timeout_handler(signum, frame):
    """SIGALRM handler: abort the running call by raising ``TimeoutError``."""
    raise TimeoutError("EdgarTools API call timed out")

def with_timeout(seconds=30):
    """Decorator factory that limits a call to ``seconds`` using ``signal.alarm``.

    Args:
        seconds: Whole seconds before SIGALRM fires.

    The wrapped function returns what the original returns, or raises
    ``TimeoutError`` if the alarm fires first. Because it relies on
    ``signal.signal``/``signal.alarm`` it must be called from the main thread on
    a platform that provides SIGALRM. Nested use is not supported: the inner
    call cancels any alarm that was already pending.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Install our handler and remember the one it replaces.
            previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
            try:
                signal.alarm(seconds)
                return func(*args, **kwargs)
            finally:
                # Disable the alarm
                signal.alarm(0)
                # Put back the previous handler so this decorator does not leave
                # its handler installed process-wide. signal.signal() returns None
                # when the old handler was not installed from Python, and None
                # cannot be passed back in.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
        return wrapper
    return decorator

# ============================================================================
# PROBLEMATIC FILINGS TO SKIP
# ============================================================================

# Known problematic filings that cause indefinite hangs
# Entries are accession numbers; get_filing_sections returns {} for any of them
# without contacting EdgarTools.
PROBLEMATIC_FILINGS = [
    '0001699031-25-000166',  # Grail 10-Q that caused 11+ hour hang
]

# ============================================================================
# TIMEOUT-WRAPPED EDGARTOOLS CALLS
# ============================================================================

@with_timeout(30)  # 30 second timeout
def find_filing_with_timeout(accession_number: str):
    """Look up a filing via EdgarTools ``find()`` under a 30-second limit.

    Prints a timestamped line before and after the lookup and returns whatever
    ``find()`` returned.
    """
    started_at = datetime.now().strftime('%H:%M:%S')
    print(f"[{started_at}] Starting EdgarTools find() for {accession_number}")

    filing = find(accession_number)

    finished_at = datetime.now().strftime('%H:%M:%S')
    print(f"[{finished_at}] find() completed successfully")
    return filing

@with_timeout(60)  # 60 second timeout for HTML download
def get_html_with_timeout(filing):
    """Fetch ``filing.html()`` under the 60-second alarm.

    Prints a timestamped start line, then either the size of the content or a
    note that the content was empty.

    Args:
        filing: Object exposing an ``html()`` method.

    Returns:
        The value returned by ``filing.html()`` (possibly empty or None).
    """
    print(f"[{datetime.now().strftime('%H:%M:%S')}] Starting html() fetch...")
    html_content = filing.html()
    if not html_content:
        print(f"[{datetime.now().strftime('%H:%M:%S')}] html() returned empty content")
        return html_content
    print(f"[{datetime.now().strftime('%H:%M:%S')}] html() completed, size: {len(html_content):,} bytes")
    return html_content

@with_timeout(30)  # 30 second timeout for parsing
def parse_html_with_timeout(html_content):
    """Run ``parse_html()`` on the given markup under the 30-second alarm.

    Prints a timestamped line before and after parsing.

    Args:
        html_content: Markup handed to ``parse_html()``.

    Returns:
        The object produced by ``parse_html()``.
    """
    print(f"[{datetime.now().strftime('%H:%M:%S')}] Starting HTML parsing...")
    parsed = parse_html(html_content)
    print(f"[{datetime.now().strftime('%H:%M:%S')}] HTML parsing completed")
    return parsed

def get_filing_sections(accession_number: str, filing_type: Optional[str] = None) -> Dict[str, str]:
    """Get structured sections from SEC filing using accession number

    ENHANCED: With timeouts, progress monitoring, and problematic filing skipping

    Flow: skip listed problematic filings -> consult SECTION_CACHE -> find the
    filing -> fetch its HTML -> parse -> run SectionExtractor -> convert each
    section to text -> fall back to the whole document text when no section
    produced text -> store the result in SECTION_CACHE.

    Args:
        accession_number: Accession number used for the lookup and as part of
            the cache key.
        filing_type: Form type passed to SectionExtractor. When falsy it is
            taken from the filing's ``form`` attribute (default '10-K').

    Returns:
        Mapping of section name to stripped section text. An empty dict is
        returned when the filing is skipped, when any step times out, or when
        any exception is raised inside the function (errors are logged, not
        propagated).
    """
    # Skip known problematic filings
    if accession_number in PROBLEMATIC_FILINGS:
        log_warning("EdgarTools", f"Skipping known problematic filing: {accession_number}")
        return {}
    
    # Check cache first
    # The key is built from the filing_type argument as passed in, so a call
    # without a type is cached under '...#UNKNOWN' even though the type is
    # auto-detected further down.
    cache_key = f"{accession_number}#{filing_type or 'UNKNOWN'}"
    cached_sections = SECTION_CACHE.get(cache_key)
    # NOTE(review): the lookup does not consult CONFIG['cache']['enabled'],
    # while the store at the end of this function does.
    if cached_sections:
        log_info("Cache", f"Cache hit for {accession_number}")
        return cached_sections
    
    try:
        # Find filing with timeout protection
        try:
            filing = find_filing_with_timeout(accession_number)
        except TimeoutError:
            log_error("EdgarTools", f"Timeout finding filing {accession_number} (30s exceeded)")
            return {}
        
        # Raised inside the outer try, so this surfaces as a logged error and
        # an empty result rather than an exception to the caller.
        if not filing:
            raise ValueError(f"Filing not found for accession: {accession_number}")
            
        # Auto-detect filing type if not provided
        if not filing_type:
            filing_type = getattr(filing, 'form', '10-K')
        
        log_info("EdgarTools", f"Found {filing_type} for {getattr(filing, 'company', 'Unknown Company')}")
        
        # Get structured HTML content with timeout
        try:
            html_content = get_html_with_timeout(filing)
        except TimeoutError:
            log_error("EdgarTools", f"Timeout fetching HTML for {accession_number} (60s exceeded)")
            return {}
        
        if not html_content:
            raise ValueError("No HTML content available")
        
        # Limit HTML size to prevent memory issues
        # NOTE(review): len() and the slice count elements of html_content; if
        # it is a str these are characters rather than bytes - confirm.
        MAX_HTML_SIZE = 10 * 1024 * 1024  # 10MB limit
        if len(html_content) > MAX_HTML_SIZE:
            log_warning("EdgarTools", f"HTML too large ({len(html_content):,} bytes), truncating to {MAX_HTML_SIZE:,}")
            html_content = html_content[:MAX_HTML_SIZE]
        
        # Parse HTML to Document object with timeout
        try:
            document = parse_html_with_timeout(html_content)
        except TimeoutError:
            log_error("EdgarTools", f"Timeout parsing HTML for {accession_number} (30s exceeded)")
            return {}
        
        # Extract sections using SectionExtractor
        # This step is not wrapped in a timeout, unlike the three calls above.
        extractor = SectionExtractor(filing_type=filing_type)
        sections = extractor.extract(document)
        
        log_info("EdgarTools", f"SectionExtractor found {len(sections)} sections")
        
        # Convert sections to text dictionary
        section_texts = {}
        for section_name, section in sections.items():
            try:
                if hasattr(section, 'text'):
                    # 'text' may be a method or a plain attribute; handle both.
                    text = section.text() if callable(section.text) else section.text
                    if isinstance(text, str) and text.strip():
                        section_texts[section_name] = text.strip()
                        # The printed length is measured before stripping.
                        print(f"      • {section_name}: {len(text):,} chars")
                # NOTE(review): every Python object has __str__, so this branch
                # acts as a catch-all for sections without a 'text' attribute.
                elif hasattr(section, '__str__'):
                    text = str(section).strip()
                    if text:
                        section_texts[section_name] = text
                        print(f"      • {section_name}: {len(text):,} chars (via str)")
            except Exception as section_e:
                # One bad section must not abort the rest of the filing.
                log_warning("EdgarTools", f"Could not extract section {section_name}", {"error": str(section_e)})
                continue
        
        # If SectionExtractor returns no sections, fall back to full document text
        if not section_texts:
            log_warning("EdgarTools", "No structured sections found, using full document fallback")
            full_text = document.text() if hasattr(document, 'text') and callable(document.text) else str(document)
            if full_text and len(full_text.strip()) > 100:  # Only use if substantial content
                # Limit full document size
                if len(full_text) > MAX_HTML_SIZE:
                    log_warning("EdgarTools", f"Full document too large ({len(full_text):,} chars), truncating")
                    full_text = full_text[:MAX_HTML_SIZE]
                section_texts['full_document'] = full_text.strip()
                log_info("EdgarTools", f"Using full document: {len(full_text):,} chars")
        
        # Cache the result
        # Empty results are never cached, so a failed extraction is retried on
        # the next call.
        if section_texts and CONFIG['cache']['enabled']:
            SECTION_CACHE.put(cache_key, section_texts)
            log_info("Cache", f"Cached sections for {accession_number} ({len(section_texts)} sections)")
        
        return section_texts
        
    except Exception as e:
        log_error("EdgarTools", f"Failed to fetch filing {accession_number}", e)
        return {}  # Return empty dict on network/API failure

def route_sections_to_models(sections: Dict[str, str], filing_type: Optional[str]) -> Dict[str, List[str]]:
    """Route sections to appropriate NER models based on filing type.

    Routing rules:
      * 10-K / 10-Q: sections whose name contains "financial" or "statement"
        (case-insensitive) go to ``finbert`` only; every other section goes to
        ``bert_base``, ``roberta`` and ``biobert``.
      * 8-K: every section goes to all four models.
      * Anything else, including a missing filing type: every section goes to
        ``bert_base``, ``roberta`` and ``biobert``.

    Args:
        sections: Mapping of section name to section text. Only the names are
            used for routing.
        filing_type: Form type. It is compared case-insensitively with
            surrounding whitespace ignored. ``None`` or an empty string falls
            through to the default routing instead of raising.

    Returns:
        Mapping of model key to the list of section names routed to it, in
        section order. Models with no sections are omitted.
    """
    # A missing type previously crashed on .upper(); treat it as "other".
    normalized_type = (filing_type or '').strip().upper()

    general_models = ('bert_base', 'roberta', 'biobert')
    all_models = ('biobert', 'bert_base', 'roberta', 'finbert')

    # Key order here fixes the key order of the returned dict.
    routing: Dict[str, List[str]] = {model: [] for model in all_models}

    for section_name in sections:
        if normalized_type in ('10-K', '10-Q'):
            lowered_name = section_name.lower()
            # FinBERT gets financial statements exclusively
            if 'financial' in lowered_name or 'statement' in lowered_name:
                targets = ('finbert',)
            else:
                targets = general_models
        elif normalized_type == '8-K':
            targets = all_models
        else:
            targets = general_models

        for model in targets:
            routing[model].append(section_name)

    # Remove empty routing
    return {model: names for model, names in routing.items() if names}

def process_sec_filing_with_sections(filing_data: Dict) -> Dict:
    """Extract sections for one filing record and decide which models see them.

    Reads ``id``, ``accession_number``, ``filing_type`` (default ``'10-K'``),
    ``company_domain`` (default ``'Unknown'``) and ``url`` from ``filing_data``.

    Args:
        filing_data: Filing record as a dict.

    Returns:
        A dict that always carries ``filing_id``, ``company_domain``,
        ``filing_type``, ``accession_number`` and ``processing_status``:
          * ``'success'``: also ``url``, ``sections``, ``model_routing`` and
            ``total_sections``.
          * ``'skipped'``: the accession number is in PROBLEMATIC_FILINGS;
            also ``error``.
          * ``'timeout'`` / ``'failed'``: an exception was caught and logged;
            also ``error``.
    """
    def _identity_fields() -> Dict:
        # Shared head of the timeout/failure payloads. Values are re-read from
        # the raw record because the locals below may not have been assigned.
        return {
            'filing_id': filing_data.get('id'),
            'company_domain': filing_data.get('company_domain', 'Unknown'),
            'filing_type': filing_data.get('filing_type', 'Unknown'),
            'accession_number': filing_data.get('accession_number'),
        }

    def _log_context() -> Dict:
        # Extra context attached to error log entries.
        return {"filing_id": filing_data.get('id'), "accession": filing_data.get('accession_number')}

    try:
        filing_id = filing_data.get('id')
        accession_number = filing_data.get('accession_number')  # DIRECT FROM DATABASE
        filing_type = filing_data.get('filing_type', '10-K')
        company_domain = filing_data.get('company_domain', 'Unknown')
        filing_url = filing_data.get('url')  # kept for reference in the result

        log_info("FilingProcessor", f"Processing {filing_type} for {company_domain}")
        print(f"   📄 Filing ID: {filing_id}")
        print(f"   📑 Accession: {accession_number}")

        # Nothing can be fetched without an accession number.
        if not accession_number:
            raise ValueError(f"Missing accession number for filing {filing_id}")

        # Known-bad filings are reported as skipped, not failed.
        if accession_number in PROBLEMATIC_FILINGS:
            log_warning("FilingProcessor", f"Skipping problematic filing: {accession_number}")
            return {
                'filing_id': filing_id,
                'company_domain': company_domain,
                'filing_type': filing_type,
                'accession_number': accession_number,
                'error': 'Skipped - known problematic filing',
                'processing_status': 'skipped'
            }

        # An empty result from the extractor is treated as a failure.
        sections = get_filing_sections(accession_number, filing_type)
        if not sections:
            raise ValueError("No sections extracted")

        log_info("FilingProcessor", f"Extracted {len(sections)} sections")

        model_routing = route_sections_to_models(sections, filing_type)
        print(f"   🎯 Model routing: {[f'{model}: {len(secs)} sections' for model, secs in model_routing.items()]}")

        # Optional sanity check: warn about falsy section names.
        if CONFIG['processing']['section_validation']:
            missing_sections = [name for name in sections if not name]
            if missing_sections:
                log_warning("FilingProcessor", f"Found {len(missing_sections)} sections without names")

        # Cache statistics are shown only once there has been at least one hit.
        cache_stats = SECTION_CACHE.get_stats()
        if cache_stats['hits'] > 0:
            print(f"   📊 Cache: {cache_stats['hit_rate']:.1f}% hit rate, {cache_stats['size_mb']:.1f}MB used")

        return {
            'filing_id': filing_id,
            'company_domain': company_domain,
            'filing_type': filing_type,
            'accession_number': accession_number,
            'url': filing_url,
            'sections': sections,
            'model_routing': model_routing,
            'total_sections': len(sections),
            'processing_status': 'success'
        }

    except TimeoutError as e:
        log_error("FilingProcessor", "Filing processing timed out", e,
                  _log_context())
        outcome = _identity_fields()
        outcome['error'] = 'Processing timeout'
        outcome['processing_status'] = 'timeout'
        return outcome
    except Exception as e:
        log_error("FilingProcessor", "Filing processing failed", e,
                  _log_context())
        outcome = _identity_fields()
        outcome['error'] = str(e)
        outcome['processing_status'] = 'failed'
        return outcome

@retry_on_connection_error
def get_unprocessed_filings(limit: int = 5) -> List[Dict]:
    """Get SEC filings that haven't been processed yet.

    A filing counts as unprocessed when it has an accession number and no row
    in ``system_uno.sec_entities_raw`` references it as ``SEC_<id>``. Accession
    numbers listed in PROBLEMATIC_FILINGS are excluded. Results are ordered by
    filing date, newest first.

    The exclusion list is sent as bound query parameters (one ``%s``
    placeholder per entry) rather than being spliced into the SQL text, and
    the cursor is closed even when the query raises.

    Args:
        limit: Maximum number of filings to return.

    Returns:
        One dict per filing with keys ``id``, ``company_domain``,
        ``filing_type``, ``accession_number``, ``url``, ``filing_date`` and
        ``title``.
    """
    # Snapshot the list so the placeholders and parameters always line up.
    excluded = list(PROBLEMATIC_FILINGS)
    if excluded:
        placeholders = ", ".join(["%s"] * len(excluded))
        exclusion_clause = f"AND sf.accession_number NOT IN ({placeholders})"
    else:
        exclusion_clause = ""

    # Only placeholder text is interpolated here; all values are bound below.
    query = f"""
        SELECT 
            sf.id, 
            sf.company_domain, 
            sf.filing_type, 
            sf.accession_number,
            sf.url, 
            sf.filing_date, 
            sf.title
        FROM raw_data.sec_filings sf
        LEFT JOIN system_uno.sec_entities_raw ser 
            ON ser.sec_filing_ref = CONCAT('SEC_', sf.id)
        WHERE sf.accession_number IS NOT NULL  -- Must have accession
            AND ser.sec_filing_ref IS NULL     -- Not yet processed
            {exclusion_clause}                 -- Skip problematic filings
        ORDER BY sf.filing_date DESC
        LIMIT %s
    """

    column_names = ('id', 'company_domain', 'filing_type', 'accession_number',
                    'url', 'filing_date', 'title')

    with get_db_connection() as conn:  # PHASE 2: Using context manager
        cursor = conn.cursor()
        try:
            # Parameter order matches placeholder order: exclusions, then LIMIT.
            cursor.execute(query, (*excluded, limit))
            filings = cursor.fetchall()
        finally:
            cursor.close()

        log_info("DatabaseQuery", f"Retrieved {len(filings)} unprocessed filings (excluded {len(PROBLEMATIC_FILINGS)} problematic)")

        return [
            {name: filing[index] for index, name in enumerate(column_names)}
            for filing in filings
        ]

# Test the simplified extraction with timeout protection
# Smoke test executed every time this cell runs: fetch one unprocessed filing
# via get_unprocessed_filings() and push it through
# process_sec_filing_with_sections().
log_info("Test", "Starting section extraction test with timeout protection")

test_filings = get_unprocessed_filings(limit=1)

if test_filings:
    print(f"\n🧪 Testing with filing: {test_filings[0]['company_domain']} - {test_filings[0]['filing_type']}")
    print(f"   Accession: {test_filings[0]['accession_number']}")
    
    test_result = process_sec_filing_with_sections(test_filings[0])
    
    # Report the outcome; 'processing_status' is one of
    # 'success' / 'timeout' / 'skipped' / 'failed'.
    if test_result['processing_status'] == 'success':
        log_info("Test", f"✅ Successfully extracted {test_result['total_sections']} sections")
    elif test_result['processing_status'] == 'timeout':
        log_warning("Test", f"⏱️ Processing timed out - filing may be too large or slow")
    elif test_result['processing_status'] == 'skipped':
        log_info("Test", f"⏭️ Skipped problematic filing")
    else:
        log_error("Test", f"❌ Section extraction failed: {test_result.get('error')}")
else:
    # An empty list means the query returned no candidate filing.
    log_info("Test", "No test filings available (all may be processed or problematic)")

print("✅ Cell 2 complete - EdgarTools section extraction with timeout protection ready")

In [4]:
# Cell 3: Enhanced Entity Extraction Pipeline with Quality Control

import torch
import time
import uuid
import json
import re
from datetime import datetime
from typing import List, Dict, Any, Optional, Set, Tuple
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np

# Progress marker printed before the handler and pipeline classes are defined.
print("🚀 Loading Enhanced EntityExtractionPipeline with Quality Control...")

class ModelHandler:
    """Default per-model handler: runs a pipeline, filters and scores entities.

    Subclasses override ``filter_entities`` and ``score_entity`` to add
    model-specific rules; this base class filters nothing.
    """

    def __init__(self, model_name: str, pipeline_obj):
        """Store the model key, the callable pipeline and zeroed counters."""
        self.model_name = model_name
        self.pipeline = pipeline_obj
        self.stats = dict(
            texts_processed=0,
            entities_found=0,
            entities_filtered=0,
            processing_time=0,
        )

    def process_chunk(self, text: str, chunk_start: int = 0) -> List[Dict]:
        """Run the pipeline on ``text`` and shift offsets by ``chunk_start``.

        The entity dicts returned by the pipeline are updated in place; a
        missing ``start``/``end`` is treated as 0 before shifting.
        """
        detections = self.pipeline(text)

        # Translate chunk-relative offsets into full-text offsets.
        for detection in detections:
            for bound in ("start", "end"):
                detection[bound] = detection.get(bound, 0) + chunk_start

        return detections

    def filter_entities(self, entities: List[Dict]) -> List[Dict]:
        """Return ``entities`` unchanged (no filtering in the base handler)."""
        return entities

    def score_entity(self, entity: Dict) -> float:
        """Return the entity's ``score`` (default 0.5) with a length penalty.

        Words shorter than 3 characters are halved; words longer than 100
        characters are scaled by 0.7.
        """
        quality = entity.get('score', 0.5)
        word_length = len(entity.get('word', ''))

        if word_length > 100:
            # Likely a phrase/sentence, not an entity
            quality *= 0.7
        elif word_length < 3:
            # Very short spans are penalized
            quality *= 0.5

        return quality


class BioBERTHandler(ModelHandler):
    """ModelHandler variant with medical-entity filtering and scoring."""

    def filter_entities(self, entities: List[Dict]) -> List[Dict]:
        """Drop entities labelled '0' and spans that look like sentences.

        A span is treated as a sentence when it is longer than 150 characters
        or contains more than one period. Every dropped entity increments
        ``stats['entities_filtered']``.
        """
        kept = []
        for candidate in entities:
            surface = candidate.get('word', '')
            # Short-circuit order matters: the label check comes first so the
            # text checks never run for '0' entities.
            if (candidate.get('entity_group') == '0'
                    or len(surface) > 150
                    or surface.count('.') > 1):
                self.stats['entities_filtered'] += 1
                continue
            kept.append(candidate)

        return kept

    def score_entity(self, entity: Dict) -> float:
        """Base score, boosted for medical keywords and short acronyms.

        A 1.2x boost applies when the lower-cased word contains any keyword;
        a further 1.1x applies when the word is 3-5 uppercase letters. The
        result is capped at 1.0.
        """
        quality = super().score_entity(entity)

        lowered = entity.get('word', '').lower()
        medical_keywords = ('disease', 'syndrome', 'cancer', 'therapy', 'drug',
                            'treatment', 'clinical', 'trial', 'fda', 'phase')

        mentions_medical_term = False
        for keyword in medical_keywords:
            if keyword in lowered:
                mentions_medical_term = True
                break
        if mentions_medical_term:
            quality *= 1.2

        # Acronym check uses the original casing, not the lower-cased copy.
        if re.match(r'^[A-Z]{3,5}$', entity.get('word', '')):
            quality *= 1.1

        return min(quality, 1.0)


class BERTHandler(ModelHandler):
    """ModelHandler variant for the general-purpose BERT NER model."""

    def filter_entities(self, entities: List[Dict]) -> List[Dict]:
        """Keep everything except vague MISC entities.

        A MISC entity survives only when its word is 3-50 characters long and
        starts with an uppercase letter. Dropped entities increment
        ``stats['entities_filtered']``.
        """
        kept = []
        for candidate in entities:
            if candidate.get('entity_group') == 'MISC':
                surface = candidate.get('word', '')
                # The length test runs first, so surface[0] is never read on
                # an empty string.
                looks_like_entity = 3 <= len(surface) <= 50 and surface[0].isupper()
                if not looks_like_entity:
                    self.stats['entities_filtered'] += 1
                    continue

            kept.append(candidate)

        return kept

    def score_entity(self, entity: Dict) -> float:
        """Base score, scaled 1.1x for PER/ORG/LOC and 0.9x for MISC, capped at 1.0."""
        quality = super().score_entity(entity)

        label = entity.get('entity_group', '')
        if label == 'MISC':
            quality *= 0.9
        elif label in ('PER', 'ORG', 'LOC'):
            quality *= 1.1

        return min(quality, 1.0)


class RoBERTaHandler(ModelHandler):
    """ModelHandler variant for RoBERTa with minimal filtering."""

    def filter_entities(self, entities: List[Dict]) -> List[Dict]:
        """Keep entities whose word is at least two characters long."""
        kept = []
        for candidate in entities:
            if len(candidate.get('word', '')) >= 2:
                kept.append(candidate)
        return kept

    def score_entity(self, entity: Dict) -> float:
        """Base score with a 5% boost, capped at 1.0."""
        boosted = super().score_entity(entity) * 1.05
        return min(boosted, 1.0)


class EntityExtractionPipeline:
    """Enhanced pipeline with quality control and model-specific handlers"""
    
    def __init__(self, db_config: Dict, model_config: Dict = None):
        self.db_config = db_config
        self.model_config = model_config or self._get_default_model_config()
        self.model_handlers = {}
        self.pipeline_stats = {
            'documents_processed': 0,
            'total_entities_extracted': 0,
            'entities_filtered': 0,
            'processing_time_total': 0
        }
        
        print(f"🔧 Enhanced EntityExtractionPipeline initialized with quality control")
    
    def _get_default_model_config(self) -> Dict:
        """Build the default configuration: models, thresholds and label map.

        Returns:
            Dict with three keys: ``models`` (per-model checkpoint name,
            confidence threshold, handler class and description),
            ``quality_thresholds`` and ``entity_type_mapping`` (raw model
            label -> normalized entity type).
        """
        def _model_entry(model_name, confidence_threshold, handler_class, description):
            # One 'models' entry; key order matches the configuration layout.
            return {
                'model_name': model_name,
                'confidence_threshold': confidence_threshold,
                'handler_class': handler_class,
                'description': description,
            }

        models = {
            'biobert': _model_entry(
                'alvaroalon2/biobert_diseases_ner',
                0.6,  # Raised threshold
                BioBERTHandler,
                'Medical entities (diseases, treatments)',
            ),
            'bert_base': _model_entry(
                'dslim/bert-base-NER',
                0.5,
                BERTHandler,
                'General entities (persons, organizations)',
            ),
            'roberta': _model_entry(
                'Jean-Baptiste/roberta-large-ner-english',
                0.6,
                RoBERTaHandler,
                'High-precision general entities',
            ),
        }

        quality_thresholds = {
            'min_entity_length': 2,
            'max_entity_length': 100,
            'min_consensus_models': 2,  # Require 2+ models to agree
            'min_quality_score': 0.5,
        }

        # BioBERT labels
        entity_type_mapping = {
            'Disease': 'MEDICAL_CONDITION',
            'Chemical': 'MEDICATION',
            'DISEASE': 'MEDICAL_CONDITION',
            'DRUG': 'MEDICATION',
        }
        # BERT labels
        entity_type_mapping.update({
            'PER': 'PERSON',
            'ORG': 'ORGANIZATION',
            'LOC': 'LOCATION',
            'MISC': 'MISCELLANEOUS',
        })

        return {
            'models': models,
            'quality_thresholds': quality_thresholds,
            'entity_type_mapping': entity_type_mapping,
        }
    
    def load_models(self) -> Dict:
        """Load NER models with specific handlers"""
        print(f"📦 Loading {len(self.model_config['models'])} NER models with handlers...")
        
        for model_key, model_info in self.model_config['models'].items():
            try:
                model_name = model_info['model_name']
                print(f"   🧠 Loading {model_key}: {model_name}")
                
                # Skip FinBERT - it's for sentiment, not NER
                if 'finbert' in model_key.lower():
                    print(f"      ⏭️ Skipping FinBERT (sentiment model, not NER)")
                    continue
                
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForTokenClassification.from_pretrained(model_name)
                ner_pipeline = pipeline(
                    "ner", 
                    model=model, 
                    tokenizer=tokenizer,
                    aggregation_strategy="average",
                    device=0 if torch.cuda.is_available() else -1
                )
                
                # Create model-specific handler
                handler_class = model_info.get('handler_class', ModelHandler)
                handler = handler_class(model_key, ner_pipeline)
                handler.confidence_threshold = model_info['confidence_threshold']
                
                self.model_handlers[model_key] = handler
                
                print(f"      ✓ {model_key} loaded with {handler_class.__name__}")
                
            except Exception as e:
                print(f"      ❌ {model_key} failed: {e}")
        
        print(f"   ✅ Successfully loaded: {len(self.model_handlers)} models")
        return self.model_handlers
    
    def extract_entities_from_text(self, text: str, model_name: str) -> List[Dict]:
        """Extract entities using specific model handler"""
        if model_name not in self.model_handlers or not text.strip():
            return []
        
        try:
            handler = self.model_handlers[model_name]
            start_time = time.time()
            
            # Process in chunks
            max_length = 512
            overlap = 50
            stride = max_length - overlap
            
            all_entities = []
            for chunk_start in range(0, len(text), stride):
                chunk_end = min(chunk_start + max_length, len(text))
                chunk = text[chunk_start:chunk_end]
                
                # Get raw entities from chunk
                chunk_entities = handler.process_chunk(chunk, chunk_start)
                
                # Apply model-specific filtering
                filtered_entities = handler.filter_entities(chunk_entities)
                
                all_entities.extend(filtered_entities)
                
                if chunk_end >= len(text):
                    break
            
            # Deduplicate entities
            unique_entities = self._deduplicate_entities(all_entities)
            
            # Filter by confidence and quality
            final_entities = []
            for entity in unique_entities:
                if entity['score'] >= handler.confidence_threshold:
                    # Calculate quality score
                    quality_score = handler.score_entity(entity)
                    
                    if quality_score >= self.model_config['quality_thresholds']['min_quality_score']:
                        processed = {
                            'extraction_id': str(uuid.uuid4()),
                            'entity_text': entity['word'].strip(),
                            'entity_type': self.model_config['entity_type_mapping'].get(
                                entity['entity_group'], entity['entity_group']
                            ),
                            'confidence_score': float(entity['score']),
                            'quality_score': quality_score,
                            'char_start': entity['start'],
                            'char_end': entity['end'],
                            'model_source': model_name,
                            'extraction_timestamp': datetime.now().isoformat()
                        }
                        final_entities.append(processed)
            
            # Update stats
            handler.stats['texts_processed'] += 1
            handler.stats['entities_found'] += len(final_entities)
            handler.stats['processing_time'] += time.time() - start_time
            
            return final_entities
            
        except Exception as e:
            print(f"   ❌ {model_name} extraction failed: {e}")
            return []
    
    def _deduplicate_entities(self, entities: List[Dict]) -> List[Dict]:
        """Remove duplicate entities at same positions"""
        seen = set()
        unique = []
        
        for entity in entities:
            key = (entity.get('start'), entity.get('end'), entity.get('word'))
            if key not in seen:
                seen.add(key)
                unique.append(entity)
        
        return unique
    
    def process_sec_filing_sections(self, filing_result: Dict) -> List[Dict]:
        """Process SEC sections with multi-model consensus"""
        if filing_result.get('processing_status') != 'success':
            return []
        
        sections = filing_result['sections']
        print(f"🔍 Processing {len(sections)} sections with quality control")
        
        all_entities = []
        
        for section_name, section_text in sections.items():
            if not section_text.strip():
                continue
            
            print(f"   📑 Processing '{section_name}'")
            
            # Extract with all applicable models
            section_entities_by_model = {}
            for model_name in self.model_handlers:
                entities = self.extract_entities_from_text(section_text, model_name)
                if entities:
                    section_entities_by_model[model_name] = entities
                    print(f"      • {model_name}: {len(entities)} quality entities")
            
            # Apply consensus merging
            consensus_entities = self._apply_consensus_merge(
                section_entities_by_model,
                section_name,
                filing_result
            )
            
            all_entities.extend(consensus_entities)
        
        # Final quality filtering
        high_quality_entities = self._final_quality_filter(all_entities)
        
        print(f"   ✅ Extracted {len(high_quality_entities)} high-quality entities")
        print(f"   📊 Filtered out {len(all_entities) - len(high_quality_entities)} low-quality entities")
        
        self.pipeline_stats['total_entities_extracted'] += len(high_quality_entities)
        self.pipeline_stats['entities_filtered'] += len(all_entities) - len(high_quality_entities)
        
        return high_quality_entities
    
    def _apply_consensus_merge(self, entities_by_model: Dict[str, List[Dict]], 
                              section_name: str, filing_result: Dict) -> List[Dict]:
        """Merge entities with consensus scoring"""
        if not entities_by_model:
            return []
        
        # Group entities by position
        position_groups = {}
        
        for model_name, entities in entities_by_model.items():
            for entity in entities:
                # Create position key with tolerance
                start = entity['char_start']
                end = entity['char_end']
                
                # Find matching position group (within 5 char tolerance)
                matched = False
                for pos_key in list(position_groups.keys()):
                    key_start, key_end = map(int, pos_key.split('_'))
                    if abs(start - key_start) <= 5 and abs(end - key_end) <= 5:
                        position_groups[pos_key].append((model_name, entity))
                        matched = True
                        break
                
                if not matched:
                    pos_key = f"{start}_{end}"
                    position_groups[pos_key] = [(model_name, entity)]
        
        # Create consensus entities
        consensus_entities = []
        min_models = self.model_config['quality_thresholds']['min_consensus_models']
        
        for pos_key, model_entities in position_groups.items():
            # Single model detection - apply stricter quality threshold
            if len(model_entities) == 1:
                model_name, entity = model_entities[0]
                if entity.get('quality_score', 0) >= 0.7:  # Higher threshold for single model
                    entity['extraction_id'] = entity.get('extraction_id', str(uuid.uuid4()))
                    entity['consensus_count'] = 1
                    entity['detecting_models'] = [model_name]
                    entity['section_name'] = section_name
                    entity['filing_id'] = filing_result['filing_id']
                    entity['company_domain'] = filing_result['company_domain']
                    entity['sec_filing_ref'] = f"SEC_{filing_result['filing_id']}"
                    consensus_entities.append(entity)
            
            # Multi-model consensus - more confident
            else:
                # Take highest quality entity as base
                best_entity = max(model_entities, key=lambda x: x[1].get('quality_score', 0))[1].copy()
                
                # Add consensus metadata
                best_entity['extraction_id'] = str(uuid.uuid4())  # New ID for merged entity
                best_entity['consensus_count'] = len(model_entities)
                best_entity['detecting_models'] = [m[0] for m in model_entities]
                best_entity['consensus_score'] = np.mean([e[1].get('quality_score', 0) 
                                                         for e in model_entities])
                
                # Boost quality score for consensus
                best_entity['quality_score'] = min(best_entity['quality_score'] * 1.2, 1.0)
                
                # Add filing metadata
                best_entity['section_name'] = section_name
                best_entity['filing_id'] = filing_result['filing_id']
                best_entity['company_domain'] = filing_result['company_domain']
                best_entity['sec_filing_ref'] = f"SEC_{filing_result['filing_id']}"
                
                consensus_entities.append(best_entity)
        
        return consensus_entities
    
    def _final_quality_filter(self, entities: List[Dict]) -> List[Dict]:
        """Apply final quality filters"""
        filtered = []
        
        # Track entity text frequency (common boilerplate detection)
        text_frequency = {}
        for entity in entities:
            text = entity['entity_text'].lower()
            text_frequency[text] = text_frequency.get(text, 0) + 1
        
        for entity in entities:
            text = entity['entity_text']
            
            # Skip if too short or too long
            if len(text) < self.model_config['quality_thresholds']['min_entity_length']:
                continue
            if len(text) > self.model_config['quality_thresholds']['max_entity_length']:
                continue
            
            # Skip common boilerplate (appears too frequently)
            if text_frequency.get(text.lower(), 0) > 10:
                continue
            
            # Skip if looks like a sentence
            if text.count(' ') > 5 or text.count('.') > 1:
                continue
            
            # Skip common non-entities
            skip_patterns = [
                r'^\d+$',  # Just numbers
                r'^[A-Z]$',  # Single letters
                r'^(the|and|or|of|in|to|for|with|on|at|from|by|about)$',  # Common words
                r'^\W+$',  # Just punctuation
            ]
            
            if any(re.match(pattern, text, re.IGNORECASE) for pattern in skip_patterns):
                continue
            
            filtered.append(entity)
        
        return filtered
    
    def get_pipeline_statistics(self) -> Dict:
        """Return a snapshot of pipeline-level and per-model counters.

        Returns:
            Dict with four keys:
            ``pipeline_stats`` - shallow copy of ``self.pipeline_stats``;
            ``model_stats`` - handler name mapped to a shallow copy of that
            handler's ``stats``;
            ``loaded_models`` - list of handler names;
            ``quality_settings`` - the ``quality_thresholds`` dict from
            ``self.model_config`` (the live object, not a copy).
        """
        handlers = self.model_handlers

        # Copy each handler's counters so callers cannot mutate them
        per_model = {
            model_name: model_handler.stats.copy()
            for model_name, model_handler in handlers.items()
        }

        snapshot = {}
        snapshot['pipeline_stats'] = self.pipeline_stats.copy()
        snapshot['model_stats'] = per_model
        snapshot['loaded_models'] = list(handlers.keys())
        snapshot['quality_settings'] = self.model_config['quality_thresholds']
        return snapshot


# Build the extraction pipeline from the NEON_CONFIG settings
entity_pipeline = EntityExtractionPipeline(NEON_CONFIG)

# Load the models; the keys of the returned mapping are reported below
loaded_models = entity_pipeline.load_models()

# Readiness summary: one output line per argument
print(
    "\n✅ Enhanced EntityExtractionPipeline ready!",
    f"   🎯 Loaded models: {list(loaded_models.keys())}",
    "   🔍 Quality control: Enabled",
    f"   📊 Consensus requirement: {entity_pipeline.model_config['quality_thresholds']['min_consensus_models']} models",
    f"   💻 Device: {'GPU' if torch.cuda.is_available() else 'CPU'}",
    sep="\n",
)

tokenizer_config.json:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/431M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/431M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/829 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/255 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/849 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

In [5]:
# Cell 4: Relationship Extractor with Local Llama 3.1-8B

import psycopg2
from psycopg2.extras import execute_values
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
import re
from typing import List, Dict, Tuple, Optional
import json
import uuid
import time
import os
from datetime import datetime

# Announce the start of this cell's setup (storage classes + local Llama relationship extractor)
print("🚀 Loading In-Memory Pipeline with Storage and Local Llama 3.1-8B...")

# ================================================================================
# RELATIONSHIP EXTRACTOR WITH LOCAL LLAMA 3.1-8B
# ================================================================================

class RelationshipExtractor:
    """Extract company-centric relationships using local Llama 3.1-8B"""
    
    def __init__(self, llama_config: Dict = None):
        """Initialize with local Llama 3.1-8B model"""
        self.config = llama_config or CONFIG.get('llama', {})
        self.model = None
        self.tokenizer = None
        self.stats = {
            'entities_processed': 0,
            'relationships_found': 0,
            'llama_calls': 0,
            'processing_time': 0
        }
        
                
        # Verify CONFIG is available
        if not CONFIG.get('llama', {}).get('enabled', False):
            print("   ⚠️ Llama configuration disabled in CONFIG")
            return
        
        try:
            # Auto-login to HuggingFace using Kaggle secret
            print("   🔐 Logging in to HuggingFace...")
            user_secrets = UserSecretsClient()
            hf_token = user_secrets.get_secret('HUGGINGFACE_TOKEN')
            
            if hf_token:
                login(token=hf_token, add_to_git_credential=False)
                print("   ✅ Logged in to HuggingFace")
            else:
                print("   ⚠️ No HUGGINGFACE_TOKEN found in Kaggle secrets")
                return
            
            # Configure 4-bit quantization for memory efficiency
            print("   ⚙️ Configuring 4-bit quantization...")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
            
            # Load Llama 3.1-8B model
            print("   📥 Loading Llama 3.1-8B-Instruct (this may take a minute)...")
            model_name = CONFIG["llama"]["model_name"]
            
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True
            )
            
            print("   ✅ Llama 3.1-8B loaded successfully (4-bit quantized)")
            
            # Test the model
            test_messages = [
                {"role": "user", "content": "What is a partnership? Answer in one sentence."}
            ]
            test_input = self.tokenizer.apply_chat_template(test_messages, return_tensors="pt", tokenize=True)
            
            with torch.no_grad():
                outputs = self.model.generate(test_input, max_new_tokens=CONFIG["llama"]["test_max_tokens"], temperature=CONFIG["llama"]["temperature"])
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            print(f"   🧪 Test response: {response[:100]}...")
            
        except Exception as e:
            print(f"   ❌ Failed to load Llama 3.1-8B: {e}")
            print("   ⚠️ Relationship extraction will be disabled")
            self.model = None
            self.tokenizer = None
    
    def _filter_entities(self, entities: List[Dict]) -> List[Dict]:
        """Filter out non-business entities like SEC section headers and document structure"""
        filtered = []
        
        # Patterns to exclude (SEC filing structure elements)
        exclude_patterns = [
            r'^Item \d+[A-Za-z]?\.?\s*',  # Item 1, Item 2A, etc.
            r'^Part [IVX]+\.?\s*',  # Part I, Part II, etc.
            r'^Section \d+\.?\s*',  # Section 1, Section 2, etc.
            r'^Exhibit \d+\.?\s*',  # Exhibit 1, etc.
            r'^Schedule [A-Z]\.?\s*',  # Schedule A, etc.
            r'^Note \d+\.?\s*',  # Note 1, Note 2, etc.
            r'^Table of Contents',
            r'^Index to',
            r'^Forward.?Looking Statements',
            r'^Risk Factors$',
            r'^Legal Proceedings$',
            r'^Management.s Discussion',
            r'^Quantitative and Qualitative Disclosures',
            r'^Controls and Procedures',
            r'^Financial Statements',
            r'^Signatures$',
            r'^SIGNATURES$'
        ]
        
        for entity in entities:
            entity_text = entity.get('entity_text', '').strip()
            entity_type = entity.get('entity_category', entity.get('entity_type', ''))
            
            # Skip empty or very short entities
            if not entity_text or len(entity_text) < 3:
                continue
            
            # Skip entities that are just numbers
            if entity_text.replace('.', '').replace(',', '').isdigit():
                continue
            
            # Skip if entity matches SEC structure patterns
            is_structure = False
            for pattern in exclude_patterns:
                if re.match(pattern, entity_text, re.IGNORECASE):
                    is_structure = True
                    break
            
            if is_structure:
                continue
            
            # Skip entities with category '0' that look like SEC headers
            if entity.get('entity_category') == '0' or entity_type == '0':
                # Check if it starts with common SEC filing patterns
                if any(entity_text.startswith(prefix) for prefix in ['Item ', 'Part ', 'Section ', 'Note ']):
                    continue
            
            # Additional filtering for low-quality entities
            if entity_type in ['MISCELLANEOUS', '0']:
                # For miscellaneous entities, apply stricter filtering
                confidence = entity.get('confidence_score', entity.get('confidence', 0))
                if confidence < CONFIG["llama"]["min_confidence_filter"]:
                    continue
                
                # Check if it looks like a real company/organization name
                words = entity_text.split()
                if len(words) > 0:
                    has_capital = any(word[0].isupper() for word in words if word)
                    all_caps = entity_text.isupper()
                    all_lower = entity_text.islower()
                    
                    if not has_capital or all_caps or all_lower:
                        continue
            
            # Entity passed all filters
            filtered.append(entity)
        
        print(f"      🎯 Filtered: {len(entities)} → {len(filtered)} entities (removed {len(entities) - len(filtered)} non-business entities)")
        return filtered
    
    def extract_company_relationships(self, 
                                     entities: List[Dict], 
                                     sections: Dict[str, str],
                                     company_domain: str) -> List[Dict]:
        """Extract relationships between company and all found entities"""
        if not self.model or not self.tokenizer or not entities:
            return []
        
        print(f"   🔍 Analyzing relationships for {company_domain}")
        
        # Filter out non-business entities first
        entities = self._filter_entities(entities)
        if not entities:
            print("      ⚠️ No valid business entities after filtering")
            return []
        
        relationships = []
        
        # Group entities by section for context efficiency
        entities_by_section = {}
        for entity in entities:
            section = entity.get('section_name', 'unknown')
            if section not in entities_by_section:
                entities_by_section[section] = []
            entities_by_section[section].append(entity)
        
        # Process each section's entities
        # Progress tracking initialization
        total_entities_to_process = sum(len(ents) for ents in entities_by_section.values())
        entities_processed_count = 0
        start_time = time.time()
        print(f"      🎯 Total entities to analyze: {total_entities_to_process}")

        for section_name, section_entities in entities_by_section.items():
            if section_name not in sections:
                continue
                
            section_text = sections[section_name]
            section_relationships_count = 0  # Initialize counter for this section
            print(f"      📑 Processing {len(section_entities)} entities in '{section_name}'")
            
            # Analyze each entity's relationship to the company
            for entity_idx, entity in enumerate(section_entities, 1):
                # Skip self-references
                entities_processed_count += 1
                
                # Show progress every 10 entities
                if entities_processed_count % 10 == 0:
                    progress_pct = (entities_processed_count * 100) // total_entities_to_process
                    print(f"         ⏳ Progress: {entities_processed_count}/{total_entities_to_process} entities ({progress_pct}%)")
                
                # Show time estimate every 25 entities
                if entities_processed_count % 25 == 0:
                    elapsed = time.time() - start_time
                    if entities_processed_count > 0:
                        rate = entities_processed_count / elapsed
                        remaining = (total_entities_to_process - entities_processed_count) / rate
                        print(f"         ⏱️  Estimated time remaining: {remaining/60:.1f} minutes")
                
                # Show current entity being processed (every 5th)
                if entity_idx % 5 == 1:
                    print(f"         🔬 Analyzing: {entity['entity_text'][:50]}..." if len(entity.get('entity_text', '')) > 50 else f"         🔬 Analyzing: {entity.get('entity_text', 'Unknown')}")

                company_name = company_domain.replace('.com', '').replace('tx', '')
                if entity['entity_text'].lower() == company_name.lower():
                    continue
                
                # Get context around entity
                context = self._get_entity_context(entity, section_text)
                
                # Build and send prompt to Llama
                relationship = self._analyze_relationship(
                    company_domain, entity, context, section_name
                )
                
                if relationship:
                    relationships.append(relationship)
                    section_relationships_count += 1
                    self.stats['relationships_found'] += 1
                
                self.stats['entities_processed'] += 1
        
        print(f"   ✅ Found {len(relationships)} relationships from {len(entities)} entities")
        return relationships
    
    def _get_entity_context(self, entity: Dict, section_text: str, window: int = None) -> str:
        """Get context around an entity"""
        if window is None:
            window = CONFIG["llama"]["entity_context_window"]
        start = max(0, entity.get('character_start', entity.get('char_start', 0)) - window)
        end = min(len(section_text), entity.get('character_end', entity.get('char_end', 0)) + window)
        return section_text[start:end]
    
    def _analyze_relationship(self, 
                            company_domain: str,
                            entity: Dict,
                            context: str,
                            section_name: str) -> Optional[Dict]:
        """Analyze a single entity's relationship to the company using local Llama"""
        if not self.model or not self.tokenizer:
            return None
        
        try:
            # Build prompt
            prompt = f"""Analyze the business relationship between the filing company and the mentioned entity.

Filing Company: {company_domain}
Entity Found: {entity['entity_text']} (Type: {entity.get('entity_category', entity.get('entity_type', 'UNKNOWN'))})
Section: {section_name}

Context:
{context[:CONFIG["llama"]["context_window"]]}

Based on the context, determine:
1. Relationship type (choose ONE from: PARTNERSHIP, COMPETITOR, REGULATORY, CLINICAL_TRIAL, SUPPLIER, CUSTOMER, INVESTOR, ACQUISITION, LICENSING, RESEARCH, NONE)
2. Relationship direction (company_to_entity or entity_to_company)
3. Business impact (positive, negative, or neutral)
4. Confidence level (high, medium, or low)
5. Brief summary (one sentence)

Format your response EXACTLY as:
TYPE: <relationship_type>
DIRECTION: <direction>
IMPACT: <impact>
CONFIDENCE: <confidence>
SUMMARY: <one_sentence_summary>"""
            
            # Create messages for chat format
            messages = [
                {"role": "system", "content": "You are an expert at analyzing business relationships from SEC filings. Always respond in the exact format requested."},
                {"role": "user", "content": prompt}
            ]
            
            # Apply chat template
            inputs = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                tokenize=True
            )
            
            # Generate response with local model
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=CONFIG["llama"]["max_new_tokens"],
                    temperature=CONFIG["llama"]["temperature"],
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            # Decode response
            llama_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract just the assistant's response
            if "assistant" in llama_response:
                llama_response = llama_response.split("assistant")[-1].strip()
            
            self.stats['llama_calls'] += 1
            
            # Parse response
            parsed = self._parse_llama_response(llama_response)
            
            if parsed and parsed['type'] != 'NONE':
                return {
                    'company_domain': company_domain,
                    'entity_text': entity['entity_text'],
                    'entity_type': entity.get('entity_category', entity.get('entity_type', 'UNKNOWN')),
                    'entity_id': entity.get('extraction_id', str(uuid.uuid4())),
                    'relationship_type': parsed['type'],
                    'relationship_direction': parsed['direction'],
                    'business_impact': parsed['impact'],
                    'confidence_level': parsed['confidence'],
                    'summary': parsed['summary'],
                    'section_name': section_name,
                    'context_used': context[:CONFIG["llama"]["context_window"]],  # Store first 1000 chars
                    'llama_response': llama_response[:500],  # Store trimmed response
                    'extraction_timestamp': datetime.now().isoformat()
                }
            
            return None
            
        except Exception as e:
            print(f"         ⚠️ Llama analysis failed for {entity['entity_text']}: {e}")
            return None
    
    def _parse_llama_response(self, response: str) -> Optional[Dict]:
        """Parse structured response from Llama"""
        try:
            lines = response.strip().split('\n')
            parsed = {}
            
            for line in lines:
                if ':' in line:
                    key, value = line.split(':', 1)
                    key = key.strip().upper()
                    value = value.strip()
                    
                    if key == 'TYPE':
                        parsed['type'] = value
                    elif key == 'DIRECTION':
                        parsed['direction'] = value
                    elif key == 'IMPACT':
                        parsed['impact'] = value.lower()
                    elif key == 'CONFIDENCE':
                        parsed['confidence'] = value.lower()
                    elif key == 'SUMMARY':
                        parsed['summary'] = value
            
            # Validate required fields
            required = ['type', 'direction', 'impact', 'confidence', 'summary']
            if all(field in parsed for field in required):
                return parsed
            
            return None
            
        except Exception:
            return None


# ================================================================================
# ENHANCED STORAGE WITH ATOMIC TRANSACTIONS
# ================================================================================

class PipelineEntityStorage:
    """Enhanced storage system with atomic entity+relationship storage"""
    
    def __init__(self, db_config: Dict):
        self.db_config = db_config
        self.storage_stats = {
            'total_entities_stored': 0,
            'total_relationships_stored': 0,
            'filings_processed': 0,
            'transactions_completed': 0,
            'transactions_failed': 0,
            'merged_entities': 0,
            'single_model_entities': 0,
            'failed_inserts': 0
        }
        
        # Ensure table structures
        self._ensure_table_structures()
    
    def _ensure_table_structures(self):
        """Ensure both entity and relationship tables exist with proper schema"""
        try:
            conn = psycopg2.connect(**self.db_config)
            cursor = conn.cursor()
            
            # Ensure entity table columns (same as before)
            self._ensure_entity_table(cursor)
            
            # Create relationship table if not exists
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS system_uno.entity_relationships (
                    relationship_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    company_domain TEXT NOT NULL,
                    entity_id UUID REFERENCES system_uno.sec_entities_raw(extraction_id),
                    entity_text TEXT NOT NULL,
                    entity_type TEXT,
                    relationship_type TEXT NOT NULL,
                    relationship_direction TEXT,
                    business_impact TEXT,
                    confidence_level TEXT,
                    summary TEXT,
                    section_name TEXT,
                    context_used TEXT,
                    llama_response TEXT,
                    filing_ref TEXT,
                    extraction_timestamp TIMESTAMP,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            
            # Create indexes for relationships
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_relationships_company 
                ON system_uno.entity_relationships(company_domain)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_relationships_type 
                ON system_uno.entity_relationships(relationship_type)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_relationships_impact 
                ON system_uno.entity_relationships(business_impact)
            """)
            
            conn.commit()
            cursor.close()
            conn.close()
            print("   ✅ Database tables verified/created")
            
        except Exception as e:
            print(f"   ⚠️ Table structure setup failed: {e}")
    
    def _ensure_entity_table(self, cursor):
        """Ensure entity table has all required columns"""
        # Check existing columns
        cursor.execute("""
            SELECT column_name 
            FROM information_schema.columns 
            WHERE table_schema = 'system_uno' 
            AND table_name = 'sec_entities_raw'
        """)
        
        existing_columns = {row[0] for row in cursor.fetchall()}
        
        # Required columns for pipeline
        required_columns = {
            'models_detected': 'TEXT[]',
            'all_confidences': 'JSONB',
            'primary_model': 'TEXT',
            'entity_variations': 'JSONB',
            'is_merged': 'BOOLEAN DEFAULT FALSE',
            'section_name': 'TEXT',
            'data_source': 'TEXT DEFAULT \'sec_filings\'',
            'extraction_timestamp': 'TIMESTAMP DEFAULT CURRENT_TIMESTAMP',
            'original_label': 'TEXT',
            'quality_score': 'REAL DEFAULT 0.0',
            'consensus_count': 'INTEGER DEFAULT 1',
            'detecting_models': 'JSONB',
            'consensus_score': 'REAL DEFAULT 0.0',
        }
        
        # Add missing columns
        for col_name, col_type in required_columns.items():
            if col_name not in existing_columns:
                cursor.execute(f'ALTER TABLE system_uno.sec_entities_raw ADD COLUMN {col_name} {col_type}')
                print(f"      ✓ Added column: {col_name}")
    
    def store_entities(self, entities: List[Dict], filing_ref: str) -> bool:
        """Store only entities (without relationships) - used before Llama processing"""
        if not entities:
            return True
        
        conn = None
        try:
            conn = psycopg2.connect(**self.db_config)
            cursor = conn.cursor()
            
            # Start transaction
            conn.autocommit = False
            
            print(f"   💾 Storing {len(entities)} entities to database...")
            
            # Count merged vs single-model entities BEFORE storage
            merged_count = sum(1 for e in entities if e.get('is_merged', False))
            single_model_count = len(entities) - merged_count
            
            # Store entities
            entity_data = []
            for entity in entities:
                # Prepare entity data
                models_detected = entity.get('models_detected', [entity.get('model_source', 'unknown')])
                if not isinstance(models_detected, list):
                    models_detected = [str(models_detected)]
                
                all_confidences = entity.get('all_confidences', 
                    {entity.get('model_source', 'unknown'): entity.get('confidence_score', 0.0)})
                entity_variations = entity.get('entity_variations', 
                    {entity.get('model_source', 'unknown'): entity.get('entity_text', '')})
                
                entity_tuple = (
                    entity.get('extraction_id', str(uuid.uuid4())),
                    entity.get('company_domain', ''),
                    entity.get('entity_text', '').strip()[:1000],
                    entity.get('entity_type', 'UNKNOWN'),
                    float(entity.get('confidence_score', 0.0)),
                    int(entity.get('char_start', 0)),
                    int(entity.get('char_end', 0)),
                    entity.get('surrounding_text', '')[:2000] if entity.get('surrounding_text') else '',
                    entity.get('sec_filing_ref', filing_ref),
                    models_detected,
                    json.dumps(all_confidences),
                    entity.get('primary_model', entity.get('model_source', 'unknown')),
                    json.dumps(entity_variations),
                    entity.get('is_merged', False),
                    entity.get('section_name', ''),
                    entity.get('data_source', 'sec_filings'),
                    entity.get('extraction_timestamp'),
                    entity.get('original_label', '')
                )
                entity_data.append(entity_tuple)
            
            # Batch insert entities
            entity_query = """
                INSERT INTO system_uno.sec_entities_raw 
                (extraction_id, company_domain, entity_text, entity_category, 
                 confidence_score, character_start, character_end, surrounding_text, 
                 sec_filing_ref, models_detected, all_confidences, primary_model,
                 entity_variations, is_merged, section_name, data_source,
                 extraction_timestamp, original_label)
                VALUES %s
                ON CONFLICT (extraction_id) DO UPDATE SET
                    models_detected = EXCLUDED.models_detected,
                    all_confidences = EXCLUDED.all_confidences,
                    is_merged = EXCLUDED.is_merged
            """
            
            execute_values(cursor, entity_query, entity_data, page_size=100)
            entities_stored = cursor.rowcount
            
            # Commit transaction
            conn.commit()
            
            # Update statistics
            self.storage_stats['total_entities_stored'] += entities_stored
            self.storage_stats['transactions_completed'] += 1
            self.storage_stats['merged_entities'] += merged_count
            self.storage_stats['single_model_entities'] += single_model_count
            
            print(f"      ✅ Entities stored: {entities_stored} entities")
            print(f"         • Merged entities: {merged_count}, Single-model: {single_model_count}")
            
            cursor.close()
            conn.close()
            return True
            
        except Exception as e:
            print(f"      ❌ Entity storage failed: {e}")
            if conn:
                conn.rollback()
                conn.close()
            self.storage_stats['transactions_failed'] += 1
            self.storage_stats['failed_inserts'] += len(entities)
            return False


    def store_relationships(self, relationships: List[Dict], filing_ref: str) -> bool:
        """Store relationship data using correct db_config pattern"""
        if not relationships:
            log_warning("RelationshipStorage", "No relationships to store")
            return True
        
        conn = None
        try:
            conn = psycopg2.connect(**self.db_config)
            cursor = conn.cursor()
            conn.autocommit = False
            
            # Prepare batch data for relationships
            relationship_data = []
            for rel in relationships:
                relationship_data.append((
                    rel.get('company_domain', ''),
                    rel.get('entity_id', str(uuid.uuid4())),
                    rel.get('entity_text', ''),
                    rel.get('entity_type', 'UNKNOWN'),
                    rel.get('relationship_type', 'UNKNOWN'),
                    rel.get('relationship_direction', ''),
                    rel.get('business_impact', ''),
                    rel.get('confidence_level', ''),
                    rel.get('summary', ''),
                    rel.get('section_name', ''),
                    rel.get('context_used', '')[:2000],  # Limit context size
                    rel.get('llama_response', '')[:1000],  # Limit response size
                    filing_ref,  # Add filing_ref parameter
                    rel.get('extraction_timestamp', datetime.now().isoformat())
                ))
            
            # Use execute_values for efficient batch insert
            execute_values(
                cursor,
                """INSERT INTO system_uno.entity_relationships 
                   (company_domain, entity_id, entity_text, entity_type, relationship_type,
                    relationship_direction, business_impact, confidence_level, summary, 
                    section_name, context_used, llama_response, filing_ref, extraction_timestamp)
                   VALUES %s""",
                relationship_data,
                template=None,
                page_size=CONFIG['processing']['max_insert_batch']
            )
            
            conn.commit()
            cursor.close()
            
            # Update statistics
            self.storage_stats['total_relationships_stored'] += len(relationships)
            self.storage_stats['transactions_completed'] += 1
            
            log_info("RelationshipStorage", f"Stored {len(relationships)} relationships for {filing_ref}")
            return True
            
        except Exception as e:
            if conn:
                conn.rollback()
            self.storage_stats['transactions_failed'] += 1
            log_error("RelationshipStorage", f"Failed to store {len(relationships)} relationships", e)
            return False
        finally:
            if conn:
                conn.close()

    def get_storage_verification(self, filing_ref: str) -> Dict:
        """Verify stored entities and relationships for a filing"""
        try:
            conn = psycopg2.connect(**self.db_config)
            cursor = conn.cursor()
            
            # Get entity statistics
            cursor.execute("""
                SELECT 
                    COUNT(*) as total_entities,
                    COUNT(DISTINCT entity_category) as unique_categories,
                    COUNT(DISTINCT section_name) as sections_processed,
                    AVG(confidence_score) as avg_confidence
                FROM system_uno.sec_entities_raw
                WHERE sec_filing_ref = %s
            """, (filing_ref,))
            
            entity_stats = cursor.fetchone()
            
            # Get relationship statistics
            cursor.execute("""
                SELECT 
                    COUNT(*) as total_relationships,
                    COUNT(DISTINCT relationship_type) as relationship_types,
                    COUNT(CASE WHEN business_impact = 'positive' THEN 1 END) as positive_impact,
                    COUNT(CASE WHEN business_impact = 'negative' THEN 1 END) as negative_impact,
                    COUNT(CASE WHEN business_impact = 'neutral' THEN 1 END) as neutral_impact
                FROM system_uno.entity_relationships
                WHERE filing_ref = %s
            """, (filing_ref,))
            
            rel_stats = cursor.fetchone()
            
            cursor.close()
            conn.close()
            
            return {
                'filing_ref': filing_ref,
                'entities': {
                    'total': entity_stats[0] if entity_stats else 0,
                    'categories': entity_stats[1] if entity_stats else 0,
                    'sections': entity_stats[2] if entity_stats else 0,
                    'avg_confidence': float(entity_stats[3]) if entity_stats and entity_stats[3] else 0
                },
                'relationships': {
                    'total': rel_stats[0] if rel_stats else 0,
                    'types': rel_stats[1] if rel_stats else 0,
                    'positive': rel_stats[2] if rel_stats else 0,
                    'negative': rel_stats[3] if rel_stats else 0,
                    'neutral': rel_stats[4] if rel_stats else 0
                }
            }
            
        except Exception as e:
            print(f"   ❌ Verification failed: {e}")
            return {}

# ================================================================================
# REFACTORED MAIN PIPELINE WITH IN-MEMORY PROCESSING
# ================================================================================

# Instantiate the storage layer (its constructor checks the DB schema) and the
# Llama-backed relationship extractor (its constructor loads the model)
pipeline_storage = PipelineEntityStorage(db_config=NEON_CONFIG)
relationship_extractor = RelationshipExtractor(llama_config=None)

def process_filing_with_pipeline(filing_data: Dict) -> Dict:
    """Run one SEC filing through section, entity and relationship extraction.

    Sections, entities and relationships are kept in memory between steps.
    Entities are written to storage before the long-running relationship
    extraction starts; relationships are written afterwards in a separate call.

    Args:
        filing_data: Filing record. 'filing_type' and 'company_domain' are
            required keys; 'id' is read with .get() and used to build the
            storage reference ``SEC_<id>``.

    Returns:
        A dict that always contains 'success', 'filing_id' and
        'processing_time'. On failure it also contains 'error'. When all
        extraction steps ran it additionally carries the counts, the storage
        verification result and up to three sample entities/relationships.

    Exceptions raised by any pipeline step are caught and reported through
    the 'error' field instead of propagating.
    """
    # Captured before the try block so the except handler can always compute elapsed time.
    start_time = time.time()
    filing_id = filing_data.get('id')

    try:
        # Step 1: Extract sections (Cell 2 function)
        print(f"\n📄 Processing {filing_data['filing_type']} for {filing_data['company_domain']}")
        section_result = process_sec_filing_with_sections(filing_data)

        if section_result['processing_status'] != 'success':
            return {
                'success': False,
                'filing_id': filing_id,
                'error': section_result.get('error', 'Section extraction failed'),
                'processing_time': time.time() - start_time
            }

        # Keep sections in memory for context retrieval
        sections_dict = section_result['sections']

        # Step 2: Extract entities (Cell 3 function) - keep in memory
        entities = entity_pipeline.process_sec_filing_sections(section_result)

        if not entities:
            return {
                'success': False,
                'filing_id': filing_id,
                'error': 'No entities extracted',
                'processing_time': time.time() - start_time
            }

        print(f"   🔍 Extracted {len(entities)} entities")

        # Step 3: Store entities IMMEDIATELY, before the long relationship step,
        # so they are persisted even if that step fails later.
        filing_ref = f"SEC_{filing_id}"
        entity_storage_success = pipeline_storage.store_entities(entities, filing_ref)

        if not entity_storage_success:
            print(f"   ⚠️ Failed to store entities, but continuing with relationship extraction...")

        # Step 4: Extract relationships using in-memory entities and sections (LONG PROCESS)
        print(f"   🤖 Starting Llama 3.1 relationship extraction (this may take several minutes)...")

        # Normalise a None/empty result to [] so the len() and slicing below
        # cannot fail and discard an otherwise successful filing.
        relationships = relationship_extractor.extract_company_relationships(
            entities,
            sections_dict,
            filing_data['company_domain']
        ) or []

        # Step 5: Store relationships separately
        if relationships:
            relationship_storage_success = pipeline_storage.store_relationships(relationships, filing_ref)
        else:
            relationship_storage_success = True
            print(f"   ℹ️ No relationships found to store")

        # Step 6: Verify storage
        verification = pipeline_storage.get_storage_verification(filing_ref)

        processing_time = time.time() - start_time

        # Overall success requires both storage steps to have succeeded.
        storage_success = bool(entity_storage_success and relationship_storage_success)

        result = {
            'success': storage_success,
            'filing_id': filing_id,
            'company_domain': filing_data.get('company_domain'),
            'filing_type': filing_data.get('filing_type'),
            'sections_processed': len(sections_dict),
            'entities_extracted': len(entities),
            'relationships_found': len(relationships),
            'entities_stored': verification.get('entities', {}).get('total', 0),
            'relationships_stored': verification.get('relationships', {}).get('total', 0),
            'processing_time': round(processing_time, 2),
            'verification': verification,
            'sample_entities': entities[:3],
            'sample_relationships': relationships[:3]
        }

        if not storage_success:
            # Give callers a concrete reason instead of a bare success=False.
            failed_steps = [
                step_name
                for step_name, step_ok in (
                    ('entity', entity_storage_success),
                    ('relationship', relationship_storage_success),
                )
                if not step_ok
            ]
            result['error'] = f"Storage failed for: {', '.join(failed_steps)}"

        return result

    except Exception as e:
        return {
            'success': False,
            'filing_id': filing_id,
            'error': str(e),
            'processing_time': time.time() - start_time
        }

def process_filings_batch(limit: int = 3) -> Dict:
    """Fetch up to ``limit`` unprocessed filings and run each through the pipeline.

    Args:
        limit: Maximum number of filings requested from get_unprocessed_filings().

    Returns:
        ``{'success': False, 'message': ...}`` when no filings are available;
        otherwise a summary dict with per-batch counts, timings and the list of
        per-filing results under 'results'.

    Side effects:
        Prints progress, pauses one second between filings, and adds the batch
        totals to entity_pipeline.pipeline_stats and pipeline_storage.storage_stats.
    """
    print(f"\n🚀 Processing batch of {limit} SEC filings with in-memory pipeline...")

    started_at = time.time()

    pending = get_unprocessed_filings(limit)
    if not pending:
        return {'success': False, 'message': 'No filings to process'}

    filing_count = len(pending)
    print(f"📊 Found {filing_count} filings to process")

    outcomes = []
    for position, filing in enumerate(pending, start=1):
        print(f"\n[{position}/{filing_count}] Processing {filing['filing_type']} for {filing['company_domain']}")

        outcome = process_filing_with_pipeline(filing)
        outcomes.append(outcome)

        if not outcome['success']:
            print(f"   ❌ Failed: {outcome.get('error', 'Unknown error')}")
        else:
            print(f"   ✅ Success: {outcome['entities_extracted']} entities, {outcome['relationships_found']} relationships")
            print(f"   ⏱️ Processing time: {outcome['processing_time']}s")

            # At most two sample relationships are echoed per filing.
            for sample in outcome.get('sample_relationships', [])[:2]:
                print(f"      • {sample['entity_text']} → {sample['relationship_type']} ({sample['business_impact']})")

        # No pause after the final filing.
        if position < filing_count:
            time.sleep(1)

    elapsed = time.time() - started_at

    # Totals only count filings whose result reported success.
    succeeded = [outcome for outcome in outcomes if outcome['success']]
    ok_count = len(succeeded)
    entity_total = sum(outcome.get('entities_extracted', 0) for outcome in succeeded)
    relationship_total = sum(outcome.get('relationships_found', 0) for outcome in succeeded)

    # Fold this batch into the long-lived pipeline/storage counters.
    entity_pipeline.pipeline_stats['documents_processed'] += ok_count
    entity_pipeline.pipeline_stats['total_entities_extracted'] += entity_total
    pipeline_storage.storage_stats['filings_processed'] += ok_count

    summary = {
        'success': ok_count > 0,
        'filings_processed': filing_count,
        'successful_filings': ok_count,
        'failed_filings': filing_count - ok_count,
        'total_entities_extracted': entity_total,
        'total_relationships_found': relationship_total,
        'batch_processing_time': round(elapsed, 2),
        'avg_time_per_filing': round(elapsed / filing_count, 2) if pending else 0,
        'results': outcomes,
    }
    return summary

# ================================================================================
# QUICK ACCESS FUNCTIONS
# ================================================================================

def test_pipeline(company_domain: str = None):
    """Test the pipeline with a single filing"""
    print("\n🧪 Testing in-memory pipeline...")
    
    # Get one filing
    if company_domain:
        # Modify get_unprocessed_filings to accept company filter
        # For now, just get any filing
        filings = get_unprocessed_filings(limit=1)
    else:
        filings = get_unprocessed_filings(limit=1)
    
    if not filings:
        print("❌ No test filings available")
        return None
    
    result = process_filing_with_pipeline(filings[0])
    
    if result['success']:
        print(f"\n✅ Pipeline test successful!")
        print(f"   📊 Sections: {result['sections_processed']}")
        print(f"   🔍 Entities: {result['entities_extracted']}")
        print(f"   🔗 Relationships: {result['relationships_found']}")
        print(f"   💾 Stored: {result['entities_stored']} entities, {result['relationships_stored']} relationships")
        print(f"   ⏱️ Time: {result['processing_time']}s")
    else:
        print(f"\n❌ Pipeline test failed: {result.get('error')}")
    
    return result

# Readiness banner: names the components set up in this cell and shows how to
# invoke the batch and single-filing entry points.
print("\n".join([
    "\n✅ In-Memory Pipeline Components Ready!",
    "   🎯 RelationshipExtractor with Llama 3.1",
    "   💾 Atomic storage for entities + relationships",
    "   🚀 In-memory processing (no DB round-trips)",
    "   📊 Usage: batch_results = process_filings_batch(limit=5)",
    "   🧪 Test: test_result = test_pipeline()",
]))

def generate_pipeline_analytics_report() -> None:
    """Print an analytics dashboard for the entity extraction pipeline.

    The report has two independent parts, each guarded so that a failure in
    one does not prevent the other from printing:

    1. Database overview: aggregate counts over system_uno.sec_entities_raw
       (rows with data_source = 'sec_filings'), including the share of
       entities that carry a section name.
    2. In-process statistics read from entity_pipeline.get_pipeline_statistics()
       and pipeline_storage.storage_stats.

    Returns:
        None. All output goes to stdout; errors are printed, not raised.
    """
    print("\n" + "="*80)
    print("📊 ENTITYEXTRACTIONPIPELINE ANALYTICS DASHBOARD")
    print("="*80)

    # Database overview with enhanced metrics
    conn = None
    cursor = None
    try:
        conn = psycopg2.connect(**NEON_CONFIG)
        cursor = conn.cursor()

        # Enhanced database statistics
        cursor.execute("""
            SELECT 
                COUNT(*) as total_entities,
                COUNT(DISTINCT company_domain) as companies,
                COUNT(DISTINCT sec_filing_ref) as filings,
                COUNT(DISTINCT entity_category) as entity_types,
                AVG(confidence_score) as avg_confidence,
                COUNT(*) FILTER (WHERE is_merged = true) as merged_entities,
                COUNT(DISTINCT primary_model) as active_models,
                COUNT(DISTINCT section_name) as sections_processed,
                COUNT(*) FILTER (WHERE section_name IS NOT NULL AND section_name != '') as entities_with_sections,
                MAX(extraction_timestamp) as last_extraction
            FROM system_uno.sec_entities_raw
            WHERE data_source = 'sec_filings'
        """)

        db_overview = cursor.fetchone()

        if db_overview and db_overview[0] > 0:
            total_entities = db_overview[0]
            entities_with_sections = db_overview[8]
            # total_entities is known to be > 0 in this branch.
            section_success_rate = entities_with_sections / total_entities * 100

            # AVG() is NULL when every confidence_score is NULL; formatting None
            # with :.3f would raise and hide the whole overview.
            avg_confidence = db_overview[4]
            avg_confidence_text = f"{avg_confidence:.3f}" if avg_confidence is not None else "N/A"

            print(f"\n📈 DATABASE OVERVIEW:")
            print(f"   Total Entities Extracted: {db_overview[0]:,}")
            print(f"   Companies Processed: {db_overview[1]:,}")
            print(f"   SEC Filings Analyzed: {db_overview[2]:,}")
            print(f"   Entity Categories Found: {db_overview[3]:,}")
            print(f"   Average Confidence Score: {avg_confidence_text}")
            print(f"   Multi-Model Entities: {db_overview[5]:,} ({db_overview[5]/db_overview[0]*100:.1f}%)")
            print(f"   Active Models: {db_overview[6]:,}")
            print(f"   Unique Sections Found: {db_overview[7]:,}")
            print(f"   🎯 SECTION SUCCESS RATE: {entities_with_sections:,}/{total_entities:,} ({section_success_rate:.1f}%)")
            print(f"   Last Extraction: {db_overview[9] or 'Never'}")

            # Alert if section success rate is low (only once there are more than 10 entities)
            if section_success_rate < 90 and total_entities > 10:
                print(f"   🚨 WARNING: Section success rate is {section_success_rate:.1f}% - Pipeline routing issue!")
            elif section_success_rate >= 90:
                print(f"   ✅ EXCELLENT: Section success rate is {section_success_rate:.1f}% - Pipeline working correctly!")
        else:
            print(f"\n📈 DATABASE OVERVIEW: No entities found - database is clean for testing")

    except Exception as e:
        print(f"   ❌ Could not retrieve analytics: {e}")
    finally:
        # Release the cursor and connection even when the query or formatting failed.
        for db_resource in (cursor, conn):
            if db_resource is not None:
                try:
                    db_resource.close()
                except Exception as close_error:
                    print(f"   ⚠️ Could not close database resource: {close_error}")

    # Pipeline statistics
    try:
        pipeline_stats = entity_pipeline.get_pipeline_statistics()

        print(f"\n🔧 PIPELINE STATISTICS:")
        print(f"   Documents Processed: {pipeline_stats['pipeline_stats']['documents_processed']:,}")
        print(f"   Total Entities Found: {pipeline_stats['pipeline_stats']['total_entities_extracted']:,}")
        print(f"   Processing Time: {pipeline_stats['pipeline_stats']['processing_time_total']:.2f}s")
        print(f"   Device: {pipeline_stats['device']}")
        print(f"   Loaded Models: {', '.join(pipeline_stats['loaded_models'])}")
        print(f"   Supported Sources: {', '.join(pipeline_stats['supported_data_sources'])}")

        # Individual model statistics (models that processed no text are skipped)
        print(f"\n📊 INDIVIDUAL MODEL PERFORMANCE:")
        for model_name, stats in pipeline_stats['model_stats'].items():
            if stats['texts_processed'] > 0:
                entities_per_text = stats['entities_found'] / stats['texts_processed']
                avg_time = stats['processing_time'] / stats['texts_processed']
                print(f"   {model_name:>12}: {stats['texts_processed']:>4} texts | {stats['entities_found']:>5} entities | {entities_per_text:>4.1f} avg/text | {avg_time:>4.2f}s avg")

        # Storage statistics (only shown once something has been stored)
        storage_stats = pipeline_storage.storage_stats
        if storage_stats['total_entities_stored'] > 0:
            print(f"\n💾 STORAGE STATISTICS:")
            print(f"   Entities Stored: {storage_stats['total_entities_stored']:,}")
            print(f"   Filings Processed: {storage_stats['filings_processed']:,}")
            print(f"   Merged Entities: {storage_stats['merged_entities']:,}")
            print(f"   Single-Model Entities: {storage_stats['single_model_entities']:,}")
            print(f"   Failed Inserts: {storage_stats['failed_inserts']:,}")

            merge_rate = storage_stats['merged_entities'] / storage_stats['total_entities_stored'] * 100
            print(f"   Multi-Model Detection Rate: {merge_rate:.1f}%")

    except Exception as e:
        print(f"\n🔧 Pipeline statistics unavailable: {e}")

    print("\n" + "="*80)
    print("✅ EntityExtractionPipeline Analytics Complete!")
    print("="*80)


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

In [None]:
# Cell 5: Main Processing Pipeline with Relationship Extraction
#
# Checks for unprocessed filings, runs process_filings_batch() on them, then
# prints a summary followed by the analytics report. The query limit and the
# batch size are read from CONFIG["processing"].

print("="*80)
print("🚀 STARTING SEC FILING PROCESSING PIPELINE")
print("="*80)

# Relationship extraction uses the locally loaded Llama 3.1-8B model, so no API key is needed.
print("📝 Relationship extraction enabled with local Llama 3.1-8B")
print("   ℹ️ Using local Llama 3.1-8B for relationship extraction")

# Check for available unprocessed filings
print("\n📊 Checking for unprocessed filings...")
available_filings = get_unprocessed_filings(limit=CONFIG["processing"]["filing_query_limit"])
print(f"   Found {len(available_filings)} unprocessed filings")

if available_filings:
    print("\n📋 Available filings to process:")
    for i, filing in enumerate(available_filings[:5], 1):
        print(f"   {i}. {filing['company_domain']} - {filing['filing_type']} ({filing['filing_date']})")

    # Computed outside the f-string: reusing double quotes inside an f-string
    # expression is a SyntaxError on Python versions before 3.12.
    cell5_batch_size = CONFIG["processing"]["filing_batch_size"]
    cell5_planned_count = min(cell5_batch_size, len(available_filings))
    print(f"\n🔄 Processing {cell5_planned_count} filings...")
    print("-"*60)

    # Run the pipeline
    batch_results = process_filings_batch(limit=cell5_batch_size)

    # Display results summary
    print("\n" + "="*80)
    print("📊 PROCESSING SUMMARY")
    print("="*80)

    if batch_results['success']:
        print(f"✅ Successfully processed {batch_results['successful_filings']}/{batch_results['filings_processed']} filings")
        print(f"   • Total entities extracted: {batch_results['total_entities_extracted']:,}")
        print(f"   • Total relationships found: {batch_results['total_relationships_found']:,}")
        print(f"   • Total processing time: {batch_results['batch_processing_time']:.1f}s")
        print(f"   • Average time per filing: {batch_results['avg_time_per_filing']:.1f}s")

        # Show detailed results for each filing
        print(f"\n📈 Detailed Results:")
        for i, result in enumerate(batch_results['results'], 1):
            if result['success']:
                print(f"\n   Filing {i}: {result['company_domain']} - {result['filing_type']}")
                print(f"      ✓ Sections: {result['sections_processed']}")
                print(f"      ✓ Entities: {result['entities_extracted']}")
                print(f"      ✓ Relationships: {result['relationships_found']}")
                print(f"      ✓ Time: {result['processing_time']:.1f}s")
            else:
                print(f"\n   Filing {i}: FAILED - {result.get('error', 'Unknown error')}")

        # Show pipeline statistics.
        # NOTE(review): the storage counter keys below are not confirmed against
        # PipelineEntityStorage from this cell; they are read with .get() so a
        # missing counter cannot abort the summary after a long run.
        cell5_storage_stats = pipeline_storage.storage_stats
        print(f"\n📊 Pipeline Statistics:")
        print(f"   • Documents processed (total): {entity_pipeline.pipeline_stats['documents_processed']}")
        print(f"   • Entities extracted (total): {entity_pipeline.pipeline_stats['total_entities_extracted']}")
        print(f"   • Storage transactions: {cell5_storage_stats.get('transactions_completed', 0)} successful, {cell5_storage_stats.get('transactions_failed', 0)} failed")
        print(f"   • Merged entities: {cell5_storage_stats.get('merged_entities', 0)}")
        print(f"   • Single-model entities: {cell5_storage_stats.get('single_model_entities', 0)}")

    else:
        print(f"❌ Processing failed: {batch_results.get('message', 'Unknown error')}")

    # Generate analytics report
    print("\n" + "="*80)
    generate_pipeline_analytics_report()

else:
    print("\n⚠️ No unprocessed filings found in raw_data.sec_filings")
    print("   All available filings have already been processed")
    print("\n💡 To add new filings:")
    print("   1. Insert new records into raw_data.sec_filings with accession_number")
    print("   2. Make sure the accession_number is valid (20 characters)")
    print("   3. Run this cell again to process them")

print("\n✅ Pipeline execution complete!")

In [None]:
# Cell 6: Production command reference
# Prints the commands a user can run next; this cell performs no processing itself.

print("\n".join([
    "\n🎯 PRODUCTION COMMANDS:",
    "   • Process new filings: batch_results = process_filings_batch(limit=5)",
    "   • Check results:       generate_pipeline_analytics_report()",
    "   • View statistics:     context_retriever.get_retrieval_statistics()",
]))

print("\n".join([
    "\n✅ EntityExtractionPipeline Production Interface Ready!",
    "🔧 SINGLE ENTRY POINT: process_filings_batch() - ensures all extractions use section-based pipeline",
    "📊 Database cleared - ready for fresh testing with guaranteed section extraction!",
]))

In [None]:
# Cell 7: TEST ENTITY STORAGE - Quick Verification Without Full Pipeline
# This cell tests ONLY the storage functionality with mock entities.
# It stores three mock entities under a dedicated test filing reference,
# verifies them with a direct query, and always attempts to delete them again.

print("="*80)
print("🧪 ENTITY STORAGE TEST - Verify Storage Works Without 90-Minute Run")
print("="*80)

import traceback
import uuid
from datetime import datetime

# One reference is used for the mock entities, the store call, the verification
# query and the cleanup, so all four steps address the same rows.
TEST_FILING_ID = 999999
TEST_FILING_REF = "SEC_TEST_999999"


def _make_test_entity(*, text, entity_type, original_label, confidence, start, end,
                      surrounding_text, model_confidences, variations,
                      primary_model, quality_score):
    """Build one mock entity dict carrying every field the storage test supplies.

    Several fields are provided under two names (entity_type/entity_category,
    char_*/character_*, models_detected/detecting_models, primary_model/model_source).
    """
    models = list(model_confidences)
    return {
        'extraction_id': str(uuid.uuid4()),
        'company_domain': 'test.com',
        'entity_text': text,
        'entity_type': entity_type,
        'entity_category': entity_type,  # Some code uses entity_category
        'confidence_score': confidence,
        'char_start': start,
        'char_end': end,
        'character_start': start,  # Some code uses character_start
        'character_end': end,
        'surrounding_text': surrounding_text,
        'models_detected': list(models),
        'all_confidences': dict(model_confidences),
        'primary_model': primary_model,
        'entity_variations': dict(variations),
        'is_merged': len(models) > 1,
        'section_name': 'test_section',
        'data_source': 'sec_filings',
        'extraction_timestamp': datetime.now().isoformat(),
        'original_label': original_label,
        'model_source': primary_model,  # Fallback field
        'quality_score': quality_score,
        'consensus_count': len(models),
        'detecting_models': list(models),
        'consensus_score': confidence,
        'filing_id': TEST_FILING_ID,
        'sec_filing_ref': TEST_FILING_REF,
    }


# Create mock entities that match the expected structure
print("\n📝 Creating mock entities for storage test...")

test_entities = [
    _make_test_entity(
        text='Test Company Inc', entity_type='ORGANIZATION', original_label='ORG',
        confidence=0.95, start=100, end=117,
        surrounding_text='This is Test Company Inc in the context',
        model_confidences={'bert_base': 0.94, 'roberta': 0.96},
        variations={'bert_base': 'Test Company Inc', 'roberta': 'Test Company Inc.'},
        primary_model='roberta', quality_score=0.92,
    ),
    _make_test_entity(
        text='John Smith', entity_type='PERSON', original_label='PER',
        confidence=0.88, start=200, end=210,
        surrounding_text='CEO John Smith announced',
        model_confidences={'bert_base': 0.88},
        variations={'bert_base': 'John Smith'},
        primary_model='bert_base', quality_score=0.85,
    ),
    _make_test_entity(
        text='FDA', entity_type='ORGANIZATION', original_label='ORG',
        confidence=0.99, start=300, end=303,
        surrounding_text='approved by the FDA for clinical',
        model_confidences={'bert_base': 0.98, 'roberta': 0.99, 'biobert': 1.0},
        variations={'bert_base': 'FDA', 'roberta': 'FDA', 'biobert': 'FDA'},
        primary_model='biobert', quality_score=0.99,
    ),
]

print(f"✅ Created {len(test_entities)} test entities")
print("\n📊 Test entity details:")
for i, entity in enumerate(test_entities, 1):
    print(f"   {i}. {entity['entity_text']} ({entity['entity_type']}) - confidence: {entity['confidence_score']:.2f}")

# Test storage
print("\n💾 Testing entity storage...")
print("   Using PipelineEntityStorage from Cell 4...")

try:
    # NEON_CONFIG must already be defined by an earlier setup cell.
    test_storage = PipelineEntityStorage(NEON_CONFIG)
    print("   ✅ Storage initialized successfully")

    # Attempt to store test entities
    print(f"\n   📤 Attempting to store {len(test_entities)} entities...")

    success = test_storage.store_entities(test_entities, TEST_FILING_REF)

    if success:
        print("   ✅ STORAGE SUCCESSFUL! Entities stored to database")

        # Verify by querying the database
        print("\n   🔍 Verifying stored entities...")
        import psycopg2

        verify_conn = psycopg2.connect(**NEON_CONFIG)
        try:
            verify_cursor = verify_conn.cursor()
            verify_cursor.execute("""
                SELECT COUNT(*), 
                       COUNT(DISTINCT entity_text),
                       AVG(confidence_score)::numeric(4,3),
                       AVG(quality_score)::numeric(4,3),
                       AVG(consensus_count)
                FROM system_uno.sec_entities_raw
                WHERE sec_filing_ref = %s
            """, (TEST_FILING_REF,))
            summary_row = verify_cursor.fetchone()
            verify_cursor.close()
        finally:
            # Close the verification connection even if the query failed.
            verify_conn.close()

        if summary_row and summary_row[0] > 0:
            print(f"   ✅ VERIFIED: {summary_row[0]} entities found in database")
            print(f"      • Unique entities: {summary_row[1]}")
            print(f"      • Avg confidence: {summary_row[2]}")
            print(f"      • Avg quality: {summary_row[3]}")
            print(f"      • Avg consensus: {summary_row[4]}")
        else:
            print("   ❌ WARNING: Entities not found in database after storage")

    else:
        print("   ❌ STORAGE FAILED! Check error messages above")
        print("   ⚠️  The entity storage is NOT working correctly")

except Exception as e:
    print(f"   ❌ ERROR during storage test: {e}")
    print("   ⚠️  This error needs to be fixed before running the full pipeline")
    traceback.print_exc()

finally:
    # Always attempt cleanup: the mock rows use data_source 'sec_filings', so any
    # that were written would otherwise be counted by the analytics queries.
    print("\n   🧹 Cleaning up test data...")
    try:
        import psycopg2

        cleanup_conn = psycopg2.connect(**NEON_CONFIG)
        try:
            cleanup_cursor = cleanup_conn.cursor()
            cleanup_cursor.execute(
                "DELETE FROM system_uno.sec_entities_raw WHERE sec_filing_ref = %s",
                (TEST_FILING_REF,),
            )
            deleted = cleanup_cursor.rowcount
            cleanup_conn.commit()
            cleanup_cursor.close()
            print(f"   ✅ Cleaned up {deleted} test entities")
        finally:
            cleanup_conn.close()
    except Exception as cleanup_error:
        print(f"   ⚠️  Cleanup failed - remove rows with sec_filing_ref = '{TEST_FILING_REF}' manually: {cleanup_error}")

print("\n" + "="*80)
print("🏁 STORAGE TEST COMPLETE")
print("="*80)
print("\n📝 Summary:")
print("   • If you see '✅ STORAGE SUCCESSFUL', the storage is working")
print("   • If you see '❌ STORAGE FAILED', check the error and fix before running full pipeline")
print("   • This test takes <10 seconds vs 90 minutes for full pipeline")
