In [2]:
"""
Supply Chain Intelligence Hub - ETL Pipeline
============================================
Production-grade ETL pipeline with comprehensive error handling,
data quality validation, and logging.
"""

import logging
import sys
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
from dataclasses import dataclass, field
import warnings

warnings.filterwarnings('ignore')


"""
#Connecting with SQLAlchemy
# SQLAlchemy connection string (inside Docker network)
DB_USER = "analytics_user"
DB_PASS = "analyticspass123"
DB_HOST = "mysql"          # service name from docker-compose
DB_PORT = "3306"
DB_NAME = "supply_chain_db"

connection_string = f"mysql+pymysql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
engine = create_engine(connection_string)

# Quick test: list tables
with engine.connect() as conn:
    tables = pd.read_sql("SHOW TABLES;", conn)
tables
"""

@dataclass
class ETLConfig:
    """Configuration for ETL pipeline

    Every field has a default, so ``ETLConfig()`` yields a usable
    configuration; override individual fields by keyword.
    """
    # --- Database connection (interpolated into a mysql+pymysql URL) ---
    db_host: str = "mysql"
    db_port: int = 3306
    db_user: str = "analytics_user"
    # NOTE(review): credentials are hard-coded as defaults here; consider
    # supplying them from the environment or a secrets store instead.
    db_password: str = "analyticspass123"
    db_name: str = "supply_chain_db"
    # --- Loading ---
    # Number of rows written per ``to_sql`` call by DataLoader.load_data.
    batch_size: int = 1000
    # NOTE(review): max_retries and retry_delay are not referenced by any
    # code visible in this file; presumably a count and a delay in seconds.
    max_retries: int = 3
    retry_delay: int = 5
    # --- Data quality thresholds, expressed as fractions of the row count ---
    # A column whose null fraction exceeds this fails validation.
    null_threshold: float = 0.05
    # A table whose duplicate-row fraction exceeds this fails validation.
    duplicate_threshold: float = 0.01
    # --- Logging ---
    # Name of a ``logging`` level attribute, e.g. "INFO" or "DEBUG".
    log_level: str = "INFO"
    # Path handed to ``logging.FileHandler``.
    log_file: str = "etl_pipeline.log"


@dataclass
class DataQualityReport:
    """Outcome of the data quality checks run against a single table.

    A fresh report counts as passing; recording any issue through
    ``add_issue`` marks it as failed.
    """
    table_name: str
    total_rows: int
    # Per-column null counts, only for columns that contain nulls.
    null_count: Dict[str, int] = field(default_factory=dict)
    duplicate_count: int = 0
    # Per-foreign-key-column count of rows without a matching parent.
    missing_foreign_keys: Dict[str, int] = field(default_factory=dict)
    validation_passed: bool = True
    issues: List[str] = field(default_factory=list)
    # Creation time of the report (naive local time from ``datetime.now``).
    timestamp: datetime = field(default_factory=datetime.now)

    def add_issue(self, issue: str):
        """Record *issue* and flag the whole report as failed."""
        self.validation_passed = False
        self.issues.append(issue)


def setup_logging(config: "ETLConfig") -> logging.Logger:
    """Configure logging with file and console handlers

    Configures (or re-configures) the shared ``'ETL_Pipeline'`` logger with
    one console handler at INFO and one file handler at DEBUG.

    Args:
        config: Supplies ``log_level`` (name of a ``logging`` level, matched
            case-insensitively) and ``log_file`` (path of the log file).

    Returns:
        The configured ``'ETL_Pipeline'`` logger.

    Raises:
        AttributeError: If ``config.log_level`` does not name a ``logging``
            level.
        OSError: If the log file cannot be opened.
    """
    logger = logging.getLogger('ETL_Pipeline')
    logger.setLevel(getattr(logging, str(config.log_level).upper()))

    # Close handlers left over from a previous call before detaching them;
    # simply dropping the list would leak the open log file descriptor every
    # time the pipeline is re-created (e.g. when a notebook cell is re-run).
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_format = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(console_format)

    # Log messages contain non-ASCII symbols (✓, ✗, ⚠); pin UTF-8 so the
    # file handler does not depend on the platform's default encoding.
    file_handler = logging.FileHandler(config.log_file, encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    file_format = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
    )
    file_handler.setFormatter(file_format)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger


class DatabaseConnection:
    """Context manager for database connections with error handling

    On entry an engine is created and one connection is opened; both are
    exposed as ``engine`` and ``connection``. On exit the connection is
    closed and the engine's pool is disposed. Exceptions raised inside the
    ``with`` body are logged and then propagate to the caller.
    """

    def __init__(self, config: "ETLConfig", logger: logging.Logger):
        """Store the settings; no database work happens until ``__enter__``."""
        self.config = config
        self.logger = logger
        self.engine = None
        self.connection = None

    def __enter__(self):
        """Open the connection and return this object.

        Raises:
            SQLAlchemyError: If the engine cannot connect. Anything that was
                partially opened is released before re-raising.
        """
        # Local import keeps this block self-contained (stdlib only).
        from urllib.parse import quote_plus

        try:
            # User and password are URL-escaped so that characters such as
            # '@', ':' or '/' in the credentials cannot corrupt the URL.
            connection_string = (
                f"mysql+pymysql://{quote_plus(str(self.config.db_user))}"
                f":{quote_plus(str(self.config.db_password))}"
                f"@{self.config.db_host}:{self.config.db_port}/{self.config.db_name}"
            )
            self.engine = create_engine(
                connection_string,
                pool_pre_ping=True,
                pool_recycle=3600,
                echo=False
            )
            self.connection = self.engine.connect()
            self.logger.info(f"✓ Connected to database: {self.config.db_name}")
            return self
        except SQLAlchemyError as e:
            self.logger.error(f"✗ Database connection failed: {str(e)}")
            # __exit__ is not called when __enter__ raises, so clean up here.
            self._release()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the connection and engine; never suppresses exceptions."""
        had_connection = self.connection is not None
        self._release()
        if had_connection:
            self.logger.info("✓ Database connection closed")
        if exc_type:
            self.logger.error(f"✗ Error during database operation: {exc_val}")
        return False

    def _release(self):
        """Close the connection and dispose the engine pool, if present."""
        try:
            if self.connection is not None:
                self.connection.close()
        finally:
            self.connection = None
            # Dispose even if close() raised, so pooled connections are
            # not left open.
            if self.engine is not None:
                try:
                    self.engine.dispose()
                finally:
                    self.engine = None


class DataExtractor:
    """Extract data from various sources"""

    def __init__(self, connection: "DatabaseConnection", logger: logging.Logger):
        """Keep the open connection wrapper and the logger."""
        self.connection = connection
        self.logger = logger

    def extract_table(self, table_name: str,
                     date_column: Optional[str] = None,
                     start_date: Optional[datetime] = None,
                     end_date: Optional[datetime] = None) -> pd.DataFrame:
        """Read a table, optionally restricted to a date range.

        Args:
            table_name: Table to read. It is interpolated into the SQL text
                (identifiers cannot be bound), so it must come from trusted
                code, never from user input.
            date_column: Column to filter on; the same trust caveat applies.
            start_date: Inclusive lower bound, or None for no lower bound.
            end_date: Inclusive upper bound, or None for no upper bound.

        Returns:
            All matching rows as a DataFrame.

        Raises:
            SQLAlchemyError: If the query fails (logged, then re-raised).
        """
        try:
            self.logger.info(f"Extracting data from table: {table_name}")

            query = f"SELECT * FROM {table_name}"
            params = {}
            has_start = start_date is not None
            has_end = end_date is not None

            if date_column:
                # A single bound is honoured as an open-ended range instead
                # of being ignored, which used to return the whole table.
                if has_start and has_end:
                    query += f" WHERE {date_column} BETWEEN :start_date AND :end_date"
                elif has_start:
                    query += f" WHERE {date_column} >= :start_date"
                elif has_end:
                    query += f" WHERE {date_column} <= :end_date"
                if has_start:
                    params['start_date'] = start_date
                if has_end:
                    params['end_date'] = end_date
            elif has_start or has_end:
                self.logger.warning(
                    f"Date bounds given without date_column for {table_name}; no filter applied"
                )

            df = self._read(query, params)

            self.logger.info(f"✓ Extracted {len(df):,} rows from {table_name}")
            return df
        except SQLAlchemyError as e:
            self.logger.error(f"✗ Failed to extract from {table_name}: {str(e)}")
            raise

    def extract_with_joins(self, query: str, params: Optional[Dict] = None) -> pd.DataFrame:
        """Run an arbitrary SELECT and return its rows.

        Args:
            query: SQL text; use ``:name`` placeholders together with *params*.
            params: Values for the placeholders, or None/empty for none.

        Raises:
            SQLAlchemyError: If the query fails (logged, then re-raised).
        """
        try:
            self.logger.info("Executing custom extraction query")
            df = self._read(query, params)
            self.logger.info(f"✓ Extracted {len(df):,} rows from custom query")
            return df
        except SQLAlchemyError as e:
            self.logger.error(f"✗ Custom query failed: {str(e)}")
            raise

    def _read(self, query: str, params: Optional[Dict] = None) -> pd.DataFrame:
        """Execute *query* on the wrapped connection.

        With bind parameters the query is wrapped in ``text()``; without
        them it is passed through as a plain string.
        """
        if params:
            return pd.read_sql_query(text(query), self.connection.connection, params=params)
        return pd.read_sql_query(query, self.connection.connection)


class DataTransformer:
    """Transform and clean extracted data

    Every method returns a DataFrame and leaves the frame it was given
    unmodified.
    """

    # table name -> ((column, strictly_positive), ...). A row is kept when
    # the column is > 0 (strict) or >= 0 (non-strict).
    _BUSINESS_RULES = {
        'products': (('unit_cost', True), ('reorder_level', False)),
        'inventory': (('quantity_on_hand', False), ('quantity_reserved', False)),
        'orders': (('order_quantity', True), ('order_cost', False)),
        'sales': (('quantity_sold', True), ('revenue', False)),
    }

    def __init__(self, logger: logging.Logger):
        self.logger = logger

    def clean_nulls(self, df: pd.DataFrame, strategy: str = 'drop') -> pd.DataFrame:
        """Handle null values according to *strategy*.

        Args:
            df: Input frame.
            strategy: ``'drop'`` (remove rows containing any null),
                ``'fill_mean'`` / ``'fill_median'`` (fill numeric columns
                with their mean / median) or ``'fill_forward'`` (propagate
                the previous row's value). Any other value logs a warning
                and returns *df* unchanged.
        """
        null_count = df.isnull().sum().sum()
        self.logger.info(f"Handling {null_count} null values using strategy: {strategy}")

        if strategy == 'drop':
            return df.dropna()
        elif strategy == 'fill_mean':
            return df.fillna(df.mean(numeric_only=True))
        elif strategy == 'fill_median':
            return df.fillna(df.median(numeric_only=True))
        elif strategy == 'fill_forward':
            # fillna(method='ffill') is deprecated in pandas 2.x; ffill() is
            # the supported spelling of the same operation.
            return df.ffill()
        else:
            self.logger.warning(f"Unknown strategy '{strategy}', returning original DataFrame")
            return df

    def remove_duplicates(self, df: pd.DataFrame,
                         subset: Optional[List[str]] = None) -> pd.DataFrame:
        """Drop duplicate rows, keeping the first occurrence.

        Args:
            df: Input frame.
            subset: Columns that define a duplicate; None compares all columns.
        """
        initial_rows = len(df)
        df_clean = df.drop_duplicates(subset=subset, keep='first')
        removed = initial_rows - len(df_clean)

        if removed > 0:
            self.logger.warning(f"Removed {removed} duplicate rows")
        else:
            self.logger.info("✓ No duplicates found")

        return df_clean

    def standardize_dates(self, df: pd.DataFrame,
                         date_columns: List[str],
                         date_format: str = '%Y-%m-%d') -> pd.DataFrame:
        """Convert the named columns to datetimes.

        Values that cannot be parsed become ``NaT``. Names in *date_columns*
        that are not columns of *df* are skipped.

        Args:
            df: Input frame.
            date_columns: Candidate column names.
            date_format: Accepted for backward compatibility only; parsing
                lets pandas infer the format and does not use this value.
        """
        present = [col for col in date_columns if col in df.columns]
        if not present:
            return df

        # Work on a copy so the caller's frame is not changed behind its back.
        df = df.copy()
        for col in present:
            try:
                df[col] = pd.to_datetime(df[col], errors='coerce')
                self.logger.info(f"✓ Standardized date column: {col}")
            except Exception as e:
                self.logger.error(f"✗ Failed to standardize {col}: {str(e)}")
        return df

    def add_derived_columns(self, df: pd.DataFrame, table_name: str) -> pd.DataFrame:
        """Add computed columns for the tables that define them.

        * ``inventory``: ``quantity_available`` = on hand - reserved.
        * ``orders``: ``delivery_delay_days`` (actual - expected, in whole
          days) and ``is_late`` (delay > 0).
        * ``sales``: ``unit_price`` = revenue / quantity sold; NaN when the
          quantity is zero.

        Columns are only added when all of their inputs exist; other tables
        are returned unchanged.
        """
        self.logger.info(f"Adding derived columns for {table_name}")

        if table_name == 'inventory':
            if 'quantity_on_hand' in df.columns and 'quantity_reserved' in df.columns:
                df = df.copy()
                df['quantity_available'] = df['quantity_on_hand'] - df['quantity_reserved']
                self.logger.info("✓ Added 'quantity_available' column")

        elif table_name == 'orders':
            if 'expected_delivery_date' in df.columns and 'actual_delivery_date' in df.columns:
                df = df.copy()
                df['delivery_delay_days'] = (
                    pd.to_datetime(df['actual_delivery_date']) -
                    pd.to_datetime(df['expected_delivery_date'])
                ).dt.days
                df['is_late'] = df['delivery_delay_days'] > 0
                self.logger.info("✓ Added delivery metrics columns")

        elif table_name == 'sales':
            if 'revenue' in df.columns and 'quantity_sold' in df.columns:
                df = df.copy()
                quantity = df['quantity_sold']
                # Mask zero quantities so the division yields NaN, not inf.
                df['unit_price'] = df['revenue'] / quantity.where(quantity != 0)
                self.logger.info("✓ Added 'unit_price' column")

        return df

    def apply_business_rules(self, df: pd.DataFrame, table_name: str) -> pd.DataFrame:
        """Drop rows that violate the sign rules defined for *table_name*.

        Rules live in ``_BUSINESS_RULES``. Rows whose checked value is null
        are dropped as well (the comparison is False for them). A rule whose
        column is absent is skipped with a warning instead of raising
        ``KeyError``. Tables without rules are returned unchanged.
        """
        self.logger.info(f"Applying business rules for {table_name}")
        initial_rows = len(df)

        for column, strictly_positive in self._BUSINESS_RULES.get(table_name, ()):
            if column not in df.columns:
                self.logger.warning(
                    f"Column '{column}' not found in {table_name}; business rule skipped"
                )
                continue
            keep = df[column] > 0 if strictly_positive else df[column] >= 0
            df = df[keep]

        removed = initial_rows - len(df)
        if removed > 0:
            self.logger.warning(f"Removed {removed} rows violating business rules")

        return df

class DataQualityValidator:
    """Comprehensive data quality checks"""

    # child table -> {foreign key column: parent table}. The parent's key
    # column is assumed to carry the same name as the foreign key column.
    _FK_MAPPING = {
        'products': {'supplier_id': 'suppliers'},
        'inventory': {'product_id': 'products', 'warehouse_id': 'warehouses'},
        'orders': {'supplier_id': 'suppliers'},
        'sales': {'product_id': 'products', 'warehouse_id': 'warehouses'},
        'price_history': {'product_id': 'products', 'supplier_id': 'suppliers'}
    }

    def __init__(self, connection: "DatabaseConnection",
                 config: "ETLConfig", logger: logging.Logger):
        """Keep the connection (for parent-key lookups), thresholds and logger."""
        self.connection = connection
        self.config = config
        self.logger = logger

    def validate_table(self, df: pd.DataFrame,
                      table_name: str,
                      required_columns: Optional[List[str]] = None) -> "DataQualityReport":
        """Run all checks on *df* and return the resulting report.

        Checks, in order: required columns present; per-column null fraction
        against ``config.null_threshold``; fully duplicated rows against
        ``config.duplicate_threshold``; foreign keys against parent tables.

        Args:
            df: Data to validate.
            table_name: Logical table name; selects the foreign-key rules.
            required_columns: Columns that must exist, if any.

        Returns:
            A report whose ``validation_passed`` is False if any check
            recorded an issue.
        """
        self.logger.info(f"=== Validating data quality for: {table_name} ===")
        total_rows = len(df)
        report = DataQualityReport(table_name=table_name, total_rows=total_rows)

        if required_columns:
            missing_cols = set(required_columns) - set(df.columns)
            if missing_cols:
                report.add_issue(f"Missing required columns: {missing_cols}")
                self.logger.error(f"✗ Missing columns: {missing_cols}")

        for col, raw_count in df.isnull().sum().items():
            # Plain int rather than numpy.int64, as the report type declares.
            count = int(raw_count)
            if count == 0:
                continue
            # count > 0 implies total_rows > 0, so this cannot divide by zero.
            null_pct = count / total_rows
            report.null_count[col] = count

            if null_pct > self.config.null_threshold:
                report.add_issue(
                    f"Column '{col}' has {null_pct:.1%} nulls (threshold: {self.config.null_threshold:.1%})"
                )
                self.logger.warning(f"⚠ High null count in '{col}': {count} ({null_pct:.1%})")

        if total_rows > 0:
            duplicate_count = int(df.duplicated().sum())
            report.duplicate_count = duplicate_count

            if duplicate_count > 0:
                dup_pct = duplicate_count / total_rows
                if dup_pct > self.config.duplicate_threshold:
                    report.add_issue(f"Found {duplicate_count} duplicates ({dup_pct:.1%})")
                    self.logger.warning(f"⚠ Duplicates found: {duplicate_count}")

        fk_issues = self._validate_foreign_keys(df, table_name)
        if fk_issues:
            report.missing_foreign_keys = fk_issues
            for fk, count in fk_issues.items():
                report.add_issue(f"Missing foreign key references in '{fk}': {count} rows")
                self.logger.error(f"✗ Foreign key issue in '{fk}': {count} orphaned rows")

        if report.validation_passed:
            self.logger.info(f"✓ Data quality validation PASSED for {table_name}")
        else:
            self.logger.error(f"✗ Data quality validation FAILED for {table_name}")

        return report

    def _validate_foreign_keys(self, df: pd.DataFrame, table_name: str) -> Dict[str, int]:
        """Count rows whose foreign key has no matching parent row.

        Null foreign keys are not counted: they reference nothing, and they
        are already covered by the null-fraction check.

        The check is best effort. If the parent keys cannot be read, a
        warning is logged and that foreign key is left out of the result,
        which means it is not verified.

        Returns:
            Mapping of foreign key column to orphaned row count, containing
            only columns with at least one orphan.
        """
        issues = {}

        for fk_col, parent_table in self._FK_MAPPING.get(table_name, {}).items():
            if fk_col not in df.columns:
                continue

            try:
                parent_ids = pd.read_sql_query(
                    f"SELECT {fk_col} FROM {parent_table}",
                    self.connection.connection
                )
                valid_ids = set(parent_ids.iloc[:, 0])
                child_keys = df[fk_col]
                orphaned = child_keys.notna() & ~child_keys.isin(valid_ids)
                orphaned_count = int(orphaned.sum())

                if orphaned_count > 0:
                    issues[fk_col] = orphaned_count
            except Exception as e:
                self.logger.warning(f"Could not validate FK {fk_col}: {str(e)}")

        return issues

    def generate_quality_summary(self, reports: List["DataQualityReport"]) -> pd.DataFrame:
        """Flatten *reports* into one summary row per table.

        Returns:
            A DataFrame with one row per report; an empty DataFrame without
            columns when *reports* is empty.
        """
        summary_data = [
            {
                'table_name': report.table_name,
                'total_rows': report.total_rows,
                'null_columns': len(report.null_count),
                'total_nulls': sum(report.null_count.values()),
                'duplicates': report.duplicate_count,
                'fk_issues': len(report.missing_foreign_keys),
                'validation_passed': report.validation_passed,
                'issue_count': len(report.issues),
                'timestamp': report.timestamp
            }
            for report in reports
        ]
        return pd.DataFrame(summary_data)


class DataLoader:
    """Load transformed data into database"""

    def __init__(self, connection: "DatabaseConnection",
                 config: "ETLConfig", logger: logging.Logger):
        """Keep the connection wrapper, the config (``batch_size``) and logger."""
        self.connection = connection
        self.config = config
        self.logger = logger

    def load_data(self, df: pd.DataFrame, table_name: str,
                  if_exists: str = 'append',
                  create_backup: bool = True) -> Tuple[int, int]:
        """Write *df* to *table_name* in batches of ``config.batch_size`` rows.

        A batch that fails with a database error is logged and counted as
        failed; the remaining batches are still attempted, so a load can be
        partial.

        Args:
            df: Rows to write.
            table_name: Destination table.
            if_exists: pandas ``to_sql`` mode for the first batch
                (``'append'``, ``'replace'`` or ``'fail'``); later batches
                always append.
            create_backup: When True and *if_exists* is ``'replace'``, copy
                the current table to a timestamped backup table first.

        Returns:
            ``(rows_loaded, rows_failed)``.

        Raises:
            ValueError: If ``config.batch_size`` is not positive.
            Exception: Any other non-database error is logged and re-raised.
        """
        self.logger.info(f"Loading {len(df)} rows into table: {table_name}")

        rows_loaded = 0
        rows_failed = 0

        try:
            batch_size = self.config.batch_size
            # A negative size would make range() empty and silently load
            # nothing; zero would fail with an unrelated range() error.
            if batch_size <= 0:
                raise ValueError(f"batch_size must be positive, got {batch_size}")

            # Nothing would be written, so do not create a backup either.
            if len(df) == 0:
                self.logger.warning(f"No rows to load into {table_name}; nothing written")
                return rows_loaded, rows_failed

            if create_backup and if_exists == 'replace':
                self._create_backup(table_name)

            for start in range(0, len(df), batch_size):
                batch = df.iloc[start:start + batch_size]

                try:
                    batch.to_sql(
                        name=table_name,
                        con=self.connection.engine,
                        # Only the first batch may replace the table.
                        if_exists=if_exists if start == 0 else 'append',
                        index=False,
                        method='multi'
                    )
                    rows_loaded += len(batch)
                except IntegrityError as e:
                    self.logger.error(f"Integrity error in batch: {str(e)}")
                    rows_failed += len(batch)
                except SQLAlchemyError as e:
                    self.logger.error(f"Database error in batch: {str(e)}")
                    rows_failed += len(batch)

            self.logger.info(f"✓ Load completed: {rows_loaded} rows loaded, {rows_failed} rows failed")
            return rows_loaded, rows_failed
        except Exception as e:
            self.logger.error(f"✗ Load failed: {str(e)}")
            raise

    def _create_backup(self, table_name: str):
        """Copy *table_name* into ``<table>_backup_<YYYYmmdd_HHMMSS>``.

        Best effort: a database error is logged as a warning and the caller
        carries on without a backup.
        """
        backup_name = f"{table_name}_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        try:
            # Use a dedicated connection and transaction from the engine.
            # Calling begin() on the shared connection can fail once earlier
            # statements have started a transaction on it, and that failure
            # would skip the backup.
            with self.connection.engine.begin() as backup_conn:
                backup_conn.execute(
                    text(f"CREATE TABLE {backup_name} AS SELECT * FROM {table_name}")
                )
            self.logger.info(f"✓ Created backup table: {backup_name}")
        except SQLAlchemyError as e:
            self.logger.warning(f"Could not create backup: {str(e)}")


class ETLPipeline:
    """Main ETL pipeline orchestrator

    NOTE(review): as written, the pipeline extracts, transforms and
    validates each table but does not write anything back; no load step is
    invoked.
    """

    def __init__(self, config: Optional["ETLConfig"] = None):
        """Set up logging and an empty report list.

        Args:
            config: Pipeline settings; defaults to ``ETLConfig()``.
        """
        self.config = config or ETLConfig()
        self.logger = setup_logging(self.config)
        # Reports of the most recent run.
        self.quality_reports = []

        self.logger.info("=" * 70)
        self.logger.info("ETL PIPELINE INITIALIZED")
        self.logger.info("=" * 70)

    def run_full_pipeline(self, tables: List[str]) -> Dict[str, object]:
        """Extract, transform and validate each table in *tables*.

        A failure in one table is logged and does not stop the others. A
        failure outside the per-table loop (for example the database
        connection) is logged and reported through ``success`` instead of
        being raised.

        Returns:
            A dict with keys ``success`` (True only if no table failed),
            ``tables_processed``, ``tables_failed``, ``quality_reports``
            (reports of this run) and ``summary`` (a DataFrame, or ``{}``
            if the run aborted before the summary was built).
        """
        # Start every run with a clean slate; otherwise reports from an
        # earlier run would leak into this run's results and summary.
        self.quality_reports = []
        results = {
            'success': False,
            'tables_processed': [],
            'tables_failed': [],
            'quality_reports': [],
            'summary': {}
        }

        try:
            with DatabaseConnection(self.config, self.logger) as db:
                extractor = DataExtractor(db, self.logger)
                transformer = DataTransformer(self.logger)
                validator = DataQualityValidator(db, self.config, self.logger)

                for table in tables:
                    try:
                        passed = self._process_table(table, extractor, transformer, validator)
                    except Exception as e:
                        self.logger.error(f"✗ Failed to process {table}: {str(e)}")
                        results['tables_failed'].append(table)
                        continue

                    if passed:
                        results['tables_processed'].append(table)
                        self.logger.info(f"✓ Successfully processed {table}")
                    else:
                        self.logger.warning(f"⚠ Quality validation failed for {table}, skipping load")
                        results['tables_failed'].append(table)

                # Hand out a copy so later runs cannot alter this result.
                results['quality_reports'] = list(self.quality_reports)
                results['summary'] = validator.generate_quality_summary(self.quality_reports)
                results['success'] = len(results['tables_failed']) == 0

                self.logger.info("\n" + "="*70)
                self.logger.info("ETL PIPELINE COMPLETED")
                self.logger.info("="*70)
        except Exception as e:
            self.logger.error(f"✗ Pipeline failed: {str(e)}")
            results['success'] = False

        return results

    def _process_table(self, table: str, extractor, transformer, validator) -> bool:
        """Run extract, transform and validate for one table.

        The quality report is appended to ``self.quality_reports``.

        Returns:
            True if the table passed validation, False otherwise.
        """
        self.logger.info(f"\n{'='*70}")
        self.logger.info(f"Processing table: {table}")
        self.logger.info(f"{'='*70}")

        df = extractor.extract_table(table)
        # Null cleaning is intentionally disabled here, so the validator
        # sees the nulls that are actually present:
        # df = transformer.clean_nulls(df, strategy='drop')
        df = transformer.remove_duplicates(df)
        df = transformer.standardize_dates(
            df,
            # Any column whose name contains 'date' is treated as a date.
            date_columns=[col for col in df.columns if 'date' in col.lower()]
        )
        df = transformer.add_derived_columns(df, table)
        df = transformer.apply_business_rules(df, table)

        report = validator.validate_table(df, table)
        self.quality_reports.append(report)
        return report.validation_passed


if __name__ == "__main__":
    # Script entry point: run the pipeline with default settings over the
    # tables listed below and print a short outcome summary. The names stay
    # at module level so they remain inspectable after the run.
    tables_to_process = [
        'suppliers',
        'products',
        'warehouses',
        'inventory',
        'orders',
        'sales',
        'price_history',
    ]
    config = ETLConfig()
    pipeline = ETLPipeline(config)
    results = pipeline.run_full_pipeline(tables_to_process)

    outcome = (
        "✓ Pipeline completed successfully!"
        if results['success']
        else "✗ Pipeline completed with errors"
    )
    print(f"\n{outcome}")

    processed_total = len(results['tables_processed'])
    failed_total = len(results['tables_failed'])
    print(f"\nProcessed: {processed_total} tables")
    print(f"Failed: {failed_total} tables")


2026-02-01 17:53:06 - ETL_Pipeline - INFO - ETL PIPELINE INITIALIZED
2026-02-01 17:53:06 - ETL_Pipeline - INFO - ✓ Connected to database: supply_chain_db
2026-02-01 17:53:06 - ETL_Pipeline - INFO - 
2026-02-01 17:53:06 - ETL_Pipeline - INFO - Processing table: suppliers
2026-02-01 17:53:06 - ETL_Pipeline - INFO - Extracting data from table: suppliers
2026-02-01 17:53:06 - ETL_Pipeline - INFO - ✓ Extracted 30 rows from suppliers
2026-02-01 17:53:06 - ETL_Pipeline - INFO - ✓ No duplicates found
2026-02-01 17:53:06 - ETL_Pipeline - INFO - ✓ Standardized date column: created_date
2026-02-01 17:53:06 - ETL_Pipeline - INFO - Adding derived columns for suppliers
2026-02-01 17:53:06 - ETL_Pipeline - INFO - Applying business rules for suppliers
2026-02-01 17:53:06 - ETL_Pipeline - INFO - === Validating data quality for: suppliers ===
2026-02-01 17:53:06 - ETL_Pipeline - INFO - ✓ Data quality validation PASSED for suppliers
2026-02-01 17:53:06 - ETL_Pipeline - INFO - ✓ Successfully processed sup