Install required libraries

In [None]:
!pip install kagglehub kaggle pandas psycopg2-binary sqlalchemy

Import the Kaggle libraries and authenticate with the Kaggle API

In [None]:
import kagglehub
import kaggle
# Both imports above succeeded, so the Kaggle client libraries are installed.
print("Kaggle module is successfully installed!")

# Authenticate the classic Kaggle API client. Credentials presumably come from
# ~/.kaggle/kaggle.json or KAGGLE_USERNAME/KAGGLE_KEY — confirm for your setup.
kaggle.api.authenticate()
print("Kaggle API authentication successful!")

Pull Data from Kaggle

Import the Global Fashion Retail Stores dataset from Kaggle
The dataset consists of six CSV files:
- transactions.csv - 773 MB in size and has 6.41 million records
- customers.csv - 183 MB in size and has 1.64 million records
- products.csv - 4.77 MB in size and has 17,940 records
- discounts.csv - 18 KB in size and has 3,801 records
- employees.csv - 15.2 KB in size and has 404 records
- stores.csv - 3 KB in size and has 35 records


In [None]:
import kagglehub
from kagglehub import KaggleDatasetAdapter
import os
import pandas as pd


import shutil

# Folder that receives the CSV files (also used as KaggleHub's cache location)
dataset_path = "/home/megin_mathew/fashion_dataset"

# Set KaggleHub cache override so downloads are cached under dataset_path
os.environ["KAGGLEHUB_CACHE"] = dataset_path

# Kaggle dataset handle and the CSV files it contains
DATASET_HANDLE = "ricgomes/global-fashion-retail-stores-dataset"
file_names = [
    "transactions.csv", "customers.csv", "discounts.csv",
    "employees.csv", "products.csv", "stores.csv"
]

# Ensure the target directory exists
os.makedirs(dataset_path, exist_ok=True)

# Download each file individually and copy the raw bytes into dataset_path.
# Copying the file as-is avoids two problems with loading it into pandas and
# calling to_csv:
#   - the ~773 MB transactions file would be held in memory;
#   - values could be rewritten, e.g. integer columns with blanks become "12.0".
for file_name in file_names:
    # Returns the local path of the (possibly already cached) file.
    cached_file = kagglehub.dataset_download(DATASET_HANDLE, path=file_name)
    target_file = os.path.join(dataset_path, file_name)
    if os.path.abspath(cached_file) != os.path.abspath(target_file):
        shutil.copyfile(cached_file, target_file)
    print(f"Downloaded: {file_name} → Stored in: {dataset_path}")


Set Kaggle Cache

In [None]:
import os

# Point KaggleHub's cache at the dataset folder for later kagglehub calls in this kernel.
# NOTE(review): the download cell above already sets this same value before downloading,
# so this cell only matters when run on its own (e.g. after a kernel restart).
os.environ["KAGGLEHUB_CACHE"] = "/home/megin_mathew/fashion_dataset"

Optimized cleansing and loading with Spark (Hadoop binaries are configured on Windows)

Cleaning, Preprocessing and Normalization

Cleansing, Preprocessing, Normalizing

In [None]:
import os
import logging
import sys
from typing import Optional, List, Dict, Any, Tuple

import pandas as pd
from dotenv import load_dotenv
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, String, Float, Date, DateTime, Boolean, ForeignKey, Text, text
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col, when, to_date, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, DateType, BooleanType

# --- Configuration and Setup ---

# Configure logging: every record is written to LOG_FILE_PATH and echoed to stdout.
LOG_DIRECTORY = '/home/megin_mathew/logs/'
LOG_FILE_PATH = os.path.join(LOG_DIRECTORY, 'data_loader.log')
os.makedirs(LOG_DIRECTORY, exist_ok=True)  # the FileHandler below needs the folder to exist

logging.basicConfig(
    handlers=[
        logging.FileHandler(LOG_FILE_PATH, encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Load environment variables. run_etl_pipeline reads the DB_* settings and DATASET_PATH
# via os.getenv; presumably they are defined in this file.
DOTENV_PATH = "/home/megin_mathew/airflow/notebooks/.env"
load_dotenv(DOTENV_PATH)


def create_spark_session() -> Optional[SparkSession]:
    """
    Create (or reuse, via getOrCreate) the Spark session used by this loader.

    On Windows, the following setup is applied:
      - HADOOP_HOME is set from HADOOP_HOME_WIN (default C:\\hadoop);
      - its bin folder is appended to PATH;
      - PySpark is pinned to the current Python interpreter.

    Spark scratch space goes to SPARK_TEMP_DIR (default /home/megin_mathew/spark_temp).
    Both variables are read from the environment, so they can come from the .env file.

    NOTE(review): if a session already exists in this process, getOrCreate returns it and
    driver-level settings such as spark.driver.memory are likely not re-applied.

    Returns:
        The SparkSession, or None if creation failed (the error is logged).
    """
    try:
        if os.name == 'nt':
            # Spark on Windows expects the Hadoop helper binaries under HADOOP_HOME/bin.
            hadoop_home = os.getenv('HADOOP_HOME_WIN', 'C:\\hadoop')
            if not os.path.exists(hadoop_home):
                logger.warning(f"HADOOP_HOME directory not found: {hadoop_home}. Spark might fail.")
            os.environ['HADOOP_HOME'] = hadoop_home
            os.environ['PATH'] = (
                os.environ.get('PATH', '') + os.pathsep + os.path.join(hadoop_home, 'bin')
            )

            # Make the Spark Python workers use the same interpreter as the driver.
            python_path = sys.executable
            os.environ['PYSPARK_PYTHON'] = python_path
            os.environ['PYSPARK_DRIVER_PYTHON'] = python_path

        # Scratch directory for shuffle/spill files (spark.local.dir).
        temp_dir = os.getenv('SPARK_TEMP_DIR', '/home/megin_mathew/spark_temp')
        os.makedirs(temp_dir, exist_ok=True)

        spark_session = (
            SparkSession.builder
            .appName("GlobalFashionRetailDataLoader")
            .config("spark.local.dir", temp_dir)
            .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
            .config("spark.executor.memory", "8g")
            .config("spark.driver.memory", "8g")
            .config("spark.memory.offHeap.enabled", "false")
            .config("spark.sql.shuffle.partitions", "200")
            .config("spark.default.parallelism", "200")
            .config("spark.sql.adaptive.enabled", "true")
            .config("spark.network.timeout", "600s")
            .config("spark.executor.heartbeatInterval", "60s")
            .config("spark.python.profile", "false")
            .config("spark.executor.instances", "4")
            .config("spark.executor.cores", "2")
            .config("spark.driver.maxResultSize", "4g")
            .config("spark.sql.execution.arrow.pyspark.enabled", "true")
            .config("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")
            .getOrCreate()
        )

        logger.info("Spark session created successfully.")
        return spark_session

    except Exception as e:
        logger.error(f"Failed to create Spark session: {str(e)}", exc_info=True)
        return None

# --- Database Schema Definition ---

# Declarative base shared by every ORM model below; setup_database() creates their tables.
Base = declarative_base()
# Target PostgreSQL schema for all tables. The name is mixed-case, which is why raw SQL in
# this cell double-quotes it (e.g. CREATE SCHEMA "GFRetail"); presumably PostgreSQL would
# otherwise fold the unquoted name to lowercase.
SCHEMA_NAME = "GFRetail"

class Stores(Base):
    """ORM mapping for the `stores` table in SCHEMA_NAME; rows come from process_stores."""
    __tablename__ = 'stores'
    __table_args__ = {'schema': SCHEMA_NAME}
    store_id = Column(Integer, primary_key=True)  # natural key from the CSV's "Store ID"
    country = Column(String(100), nullable=True)  # NULLs replaced with "Unknown" during transform
    city = Column(String(100), nullable=True)  # NULLs replaced with "Unknown" during transform
    store_name = Column(String(100))
    number_of_employees = Column(Integer)
    zip_code = Column(String(20))  # stored as text; process_stores casts it to StringType
    latitude = Column(Float, nullable=True)
    longitude = Column(Float, nullable=True)

class Employees(Base):
    """ORM mapping for the `employees` table in SCHEMA_NAME; rows come from process_employees."""
    __tablename__ = 'employees'
    __table_args__ = {'schema': SCHEMA_NAME}
    employee_id = Column(Integer, primary_key=True)  # natural key from the CSV's "Employee ID"
    # References stores, so stores.csv is loaded before employees.csv.
    store_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.stores.store_id'))
    name = Column(String(255), nullable=True)
    position = Column(String(100), nullable=True)

class Customers(Base):
    """ORM mapping for the `customers` table in SCHEMA_NAME; rows come from process_customers."""
    __tablename__ = 'customers'
    __table_args__ = {'schema': SCHEMA_NAME}
    customer_id = Column(Integer, primary_key=True)  # natural key from the CSV's "Customer ID"
    name = Column(String(255), nullable=True)
    email = Column(String(255), nullable=True)
    join_date = Column(Date, nullable=True)  # parsed with to_date(..., "yyyy-MM-dd") in process_customers
    telephone = Column(String(50), nullable=True)
    city = Column(String(100), nullable=True)
    country = Column(String(100), nullable=True)
    gender = Column(String(20), nullable=True)
    date_of_birth = Column(Date, nullable=True)  # parsed the same way as join_date
    job_title = Column(String(100), nullable=True)

class Products(Base):
    """
    ORM mapping for the `products` table in SCHEMA_NAME; rows come from process_products.

    Holds one free-text description column per language code (pt, de, fr, es, en, zh).
    """
    __tablename__ = 'products'
    __table_args__ = {'schema': SCHEMA_NAME}
    product_id = Column(Integer, primary_key=True)  # natural key from the CSV's "Product ID"
    category = Column(String(100), nullable=True)  # NULLs replaced with "Unknown" during transform
    sub_category = Column(String(100))
    description_pt = Column(Text)
    description_de = Column(Text)
    description_fr = Column(Text)
    description_es = Column(Text)
    description_en = Column(Text)
    description_zh = Column(Text)
    color = Column(String(50), nullable=True)
    sizes = Column(String(100), nullable=True)  # the raw "Sizes" value, stored as a single string
    production_cost = Column(Float)

class Discounts(Base):
    """
    ORM mapping for the `discounts` table in SCHEMA_NAME; rows come from process_discounts.

    `id` is generated by the database; process_discounts does not supply it, so rows
    loaded a second time cannot be matched to existing ones by key.
    """
    __tablename__ = 'discounts'
    __table_args__ = {'schema': SCHEMA_NAME}
    id = Column(Integer, primary_key=True, autoincrement=True)  # surrogate key
    start_date = Column(Date)
    end_date = Column(Date)
    discount_rate = Column(Float)  # taken from the input column "Discont"
    description = Column(String(255), nullable=True)
    category = Column(String(100), nullable=True)  # NULLs replaced with "General" during transform
    sub_category = Column(String(100))

class Transactions(Base):
    """
    ORM mapping for the `transactions` table in SCHEMA_NAME; rows come from process_transactions.

    Apparently one row per invoice line (invoice_id + line_number). `id` is a surrogate
    key generated by the database; process_transactions does not supply it, so rows
    loaded a second time cannot be matched to existing ones by key.
    """
    __tablename__ = 'transactions'
    __table_args__ = {'schema': SCHEMA_NAME}
    id = Column(Integer, primary_key=True, autoincrement=True)  # surrogate key
    invoice_id = Column(String(50), nullable=True)
    line_number = Column(Integer)
    # Foreign keys: customers, products, stores and employees are loaded before this table.
    customer_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.customers.customer_id'))
    product_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.products.product_id'))
    size = Column(String(10))
    color = Column(String(50), nullable=True)
    unit_price = Column(Float)
    quantity = Column(Integer, nullable=True)  # NULLs replaced with 1 during transform
    transaction_date = Column(DateTime)
    discount = Column(Float, nullable=True)  # NULLs replaced with 0.0 during transform
    line_total = Column(Float)
    store_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.stores.store_id'))
    employee_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.employees.employee_id'))
    currency = Column(String(10), nullable=True)  # NULLs replaced with "USD" during transform
    currency_symbol = Column(String(5))
    sku = Column(String(100))
    transaction_type = Column(String(20))
    payment_method = Column(String(50))
    invoice_total = Column(Float)
    is_return = Column(Boolean)  # derived: transaction_type == "Return"

# --- Extract ---

def extract_data(spark: SparkSession, file_path: str, table_name: str) -> Optional[DataFrame]:
    """
    Read one CSV file (with a header row) into a Spark DataFrame.

    For the large 'customers' and 'transactions' tables, the schema is inferred from a
    10% sample of the rows and the file is then read with that fixed schema. Other
    tables infer the schema from the file directly.

    Args:
        spark: Active Spark session.
        file_path: Path of the CSV file.
        table_name: Target table name; selects the schema-inference strategy.

    Returns:
        The DataFrame, or None when the file does not exist or cannot be read
        (the reason is logged).
    """
    if not os.path.exists(file_path):
        logger.warning(f"File not found: {file_path}. Skipping extraction for {table_name}.")
        return None

    # Name used in log messages. It must come from file_path: this function has no
    # `file_name` parameter, and referring to one would raise NameError (or pick up an
    # unrelated notebook global).
    source_name = os.path.basename(file_path)

    logger.info(f"Extracting data from {file_path}...")
    try:
        if table_name in ['customers', 'transactions']:
            # Inferring types from a sample avoids a full extra scan of the big files.
            sample_ratio = 0.1
            logger.info(f"Using sampling ratio {sample_ratio} for schema inference on {source_name}.")
            sample_df = spark.read.csv(file_path, header=True, inferSchema=True, samplingRatio=sample_ratio)
            schema = sample_df.schema
            spark_df = spark.read.csv(file_path, header=True, schema=schema)
            logger.info(f"Extracted {table_name} with schema from sample.")
        else:
            # For smaller files, infer schema directly
            spark_df = spark.read.csv(file_path, header=True, inferSchema=True)
            logger.info(f"Extracted {table_name} with inferred schema.")

        logger.info(f"Extracted DataFrame schema for {table_name}:")
        spark_df.printSchema()  # printed to stdout, not to the log file
        return spark_df

    except Exception as e:
        logger.error(f"Error extracting data from {file_path}: {str(e)}", exc_info=True)
        return None

# --- Transform ---

def process_stores(spark_df: DataFrame) -> DataFrame:
    """
    Rename and type-cast the raw stores columns to match the `stores` table.

    NULL Country/City values become the literal "Unknown". Every other column is
    renamed and cast without further cleaning.
    """
    logger.info("Transforming stores data...")

    def filled(source, target, default):
        # Replace NULLs in `source` with `default`, exposing the result as `target`.
        return when(col(source).isNull(), lit(default)).otherwise(col(source)).alias(target)

    columns = [
        col("Store ID").alias("store_id").cast(IntegerType()),
        filled("Country", "country", "Unknown"),
        filled("City", "city", "Unknown"),
        col("Store Name").alias("store_name").cast(StringType()),
        col("Number of Employees").alias("number_of_employees").cast(IntegerType()),
        col("ZIP Code").alias("zip_code").cast(StringType()),
        col("Latitude").alias("latitude").cast(FloatType()),
        col("Longitude").alias("longitude").cast(FloatType()),
    ]
    result = spark_df.select(*columns)
    logger.info("Stores data transformation complete.")
    return result

def process_employees(spark_df: DataFrame) -> DataFrame:
    """
    Map the raw employees columns onto the `employees` table layout.

    Employee and Store IDs are cast to integers. A missing Name or Position is
    replaced with the literal "Unknown".
    """
    logger.info("Transforming employees data...")

    id_columns = [
        col("Employee ID").alias("employee_id").cast(IntegerType()),
        col("Store ID").alias("store_id").cast(IntegerType()),
    ]
    # (source column, target column) pairs whose NULLs become "Unknown"
    text_columns = [
        when(col(source).isNull(), lit("Unknown")).otherwise(col(source)).alias(target)
        for source, target in (("Name", "name"), ("Position", "position"))
    ]

    result = spark_df.select(*id_columns, *text_columns)
    logger.info("Employees data transformation complete.")
    return result

def process_customers(spark_df: DataFrame) -> DataFrame:
    """
    Map the raw customers columns onto the `customers` table layout.

    The transformations are:
      - NULL Name becomes "Unknown" and NULL Email becomes "unknown@example.com";
      - Join Date and Date Of Birth are parsed with the yyyy-MM-dd pattern;
      - Telephone, City, Country, Gender and Job Title are renamed unchanged.
    """
    logger.info("Transforming customers data...")

    def default_if_null(source, target, fallback):
        # NULLs in `source` become `fallback`; the result is named `target`.
        return when(col(source).isNull(), lit(fallback)).otherwise(col(source)).alias(target)

    renamed = [
        col(source).alias(target)
        for source, target in (
            ("Telephone", "telephone"),
            ("City", "city"),
            ("Country", "country"),
            ("Gender", "gender"),
        )
    ]

    transformed = spark_df.select(
        col("Customer ID").alias("customer_id").cast(IntegerType()),
        default_if_null("Name", "name", "Unknown"),
        default_if_null("Email", "email", "unknown@example.com"),
        # NOTE(review): assumes the input has a 'Join Date' column.
        to_date(col("Join Date"), "yyyy-MM-dd").alias("join_date"),
        *renamed,
        to_date(col("Date Of Birth"), "yyyy-MM-dd").alias("date_of_birth"),
        col("Job Title").alias("job_title"),
    )
    logger.info("Customers data transformation complete.")
    return transformed


def process_products(spark_df: DataFrame) -> DataFrame:
    """
    Map the raw products columns onto the `products` table layout.

    The transformations are:
      - NULL Category, Color and Sizes become "Unknown";
      - the six "Description XX" columns are renamed to description_xx and cast to strings;
      - Production Cost becomes a float.
    """
    logger.info("Transforming products data...")

    def unknown_if_null(source, target):
        # NULLs in `source` become "Unknown"; the result is named `target`.
        return when(col(source).isNull(), lit("Unknown")).otherwise(col(source)).alias(target)

    # One description column per language code, in the table's column order.
    descriptions = [
        col(f"Description {code}").alias(f"description_{code.lower()}").cast(StringType())
        for code in ("PT", "DE", "FR", "ES", "EN", "ZH")
    ]

    transformed = spark_df.select(
        col("Product ID").alias("product_id").cast(IntegerType()),
        unknown_if_null("Category", "category"),
        col("Sub Category").alias("sub_category"),
        *descriptions,
        unknown_if_null("Color", "color"),
        unknown_if_null("Sizes", "sizes"),
        col("Production Cost").alias("production_cost").cast(FloatType()),
    )
    logger.info("Products data transformation complete.")
    return transformed

def process_discounts(spark_df: DataFrame) -> DataFrame:
    """
    Map the raw discounts columns onto the `discounts` table layout.

    The transformations are:
      - Start/End are parsed as yyyy-MM-dd dates;
      - the input column "Discont" (the code assumes that spelling in the header)
        becomes discount_rate;
      - NULL Description becomes "No description" and NULL Category becomes "General".

    The surrogate `id` column is generated by the database and is not selected here.
    """
    logger.info("Transforming discounts data...")

    date_format = "yyyy-MM-dd"
    description = col("Description")
    category = col("Category")

    transformed = spark_df.select(
        to_date(col("Start"), date_format).alias("start_date"),
        to_date(col("End"), date_format).alias("end_date"),
        col("Discont").alias("discount_rate").cast(FloatType()),
        when(description.isNull(), lit("No description")).otherwise(description).alias("description"),
        when(category.isNull(), lit("General")).otherwise(category).alias("category"),
        col("Sub Category").alias("sub_category"),
    )
    logger.info("Discounts data transformation complete.")
    return transformed

def process_transactions(spark_df: DataFrame) -> DataFrame:
    """
    Map the raw transactions columns onto the `transactions` table layout.

    Defaults applied to NULLs:
      - Invoice ID -> "UNKNOWN"
      - Color -> "Unknown"
      - Quantity -> 1
      - Discount -> 0.0
      - Currency -> "USD"

    `is_return` is True when Transaction Type is exactly "Return", and NULL when it is
    NULL. The surrogate `id` column is left to the database.

    Args:
        spark_df: Raw DataFrame as read from transactions.csv.

    Returns:
        DataFrame whose column names match the Transactions ORM model.
    """
    logger.info("Transforming transactions data...")
    transformed_df = spark_df.select(
        # id is auto-incrementing in DB, so we don't select it from source
        when(col("Invoice ID").isNull(), lit("UNKNOWN")).otherwise(col("Invoice ID")).alias("invoice_id"),
        col("Line").alias("line_number").cast(IntegerType()),
        col("Customer ID").alias("customer_id").cast(IntegerType()),
        col("Product ID").alias("product_id").cast(IntegerType()),
        col("Size").alias("size").cast(StringType()),
        when(col("Color").isNull(), lit("Unknown")).otherwise(col("Color")).alias("color"),
        col("Unit Price").alias("unit_price").cast(FloatType()),
        when(col("Quantity").isNull(), lit(1)).otherwise(col("Quantity")).alias("quantity").cast(IntegerType()),
        # Column.cast needs a Spark type. SQLAlchemy's DateTime() (the ORM column type)
        # is not a Spark DataType and makes cast() raise TypeError, so use the Spark
        # type name instead.
        col("Date").alias("transaction_date").cast("timestamp"),
        when(col("Discount").isNull(), lit(0.0)).otherwise(col("Discount")).alias("discount").cast(FloatType()),
        col("Line Total").alias("line_total").cast(FloatType()),
        col("Store ID").alias("store_id").cast(IntegerType()),
        col("Employee ID").alias("employee_id").cast(IntegerType()),
        when(col("Currency").isNull(), lit("USD")).otherwise(col("Currency")).alias("currency"),
        col("Currency Symbol").alias("currency_symbol"),
        col("SKU").alias("sku"),
        col("Transaction Type").alias("transaction_type"),
        col("Payment Method").alias("payment_method"),
        col("Invoice Total").alias("invoice_total").cast(FloatType()),
        (col("Transaction Type") == "Return").alias("is_return").cast(BooleanType())
    )
    logger.info("Transactions data transformation complete.")
    return transformed_df


# Per-table pipeline settings: table name -> (transform function, primary-key attribute).
# 'discounts' and 'transactions' are keyed by the database-generated surrogate 'id'.
TABLE_PROCESSING_MAP: Dict[str, Tuple[callable, str]] = dict(
    stores=(process_stores, 'store_id'),
    employees=(process_employees, 'employee_id'),
    customers=(process_customers, 'customer_id'),
    products=(process_products, 'product_id'),
    discounts=(process_discounts, 'id'),
    transactions=(process_transactions, 'id'),
)

# --- Load ---

def setup_database(engine: Engine) -> bool:
    """
    Ensure SCHEMA_NAME and all ORM tables exist.

    When the schema is missing, it is created and DB_USER (default 'postgres') is granted:
      - ALL on the schema;
      - via default privileges, ALL on tables later created in it.

    Tables are then created with Base.metadata.create_all, which only creates the tables
    that do not exist yet. Everything runs in one transaction on one connection.

    Returns:
        True on success, False if anything failed (the error is logged).
    """
    try:
        with engine.begin() as conn:
            # Check if schema exists
            schema_exists = conn.execute(
                text("SELECT 1 FROM information_schema.schemata WHERE schema_name = :schema"),
                {'schema': SCHEMA_NAME}
            ).scalar()

            if not schema_exists:
                logger.info(f"Creating schema {SCHEMA_NAME}")
                conn.execute(text(f'CREATE SCHEMA "{SCHEMA_NAME}"'))  # Quote schema name
                db_user = os.getenv('DB_USER', 'postgres')
                conn.execute(text(f'''
                    GRANT ALL ON SCHEMA "{SCHEMA_NAME}" TO "{db_user}";
                    ALTER DEFAULT PRIVILEGES IN SCHEMA "{SCHEMA_NAME}"
                    GRANT ALL ON TABLES TO "{db_user}";
                '''))  # Quote user name

            logger.info(f"Creating tables in schema {SCHEMA_NAME} if they don't exist...")
            # Run the DDL on this same connection. The CREATE SCHEMA above is not committed
            # until the `with` block exits. create_all(engine) would check out a different
            # pooled connection that cannot see the new schema yet, so on a first run the
            # table creation would fail.
            Base.metadata.create_all(conn)
            logger.info("Table creation process completed.")

        return True
    except Exception as e:
        logger.error(f"Database setup failed: {str(e)}", exc_info=True)
        return False


def upsert_data(
    session: Any,  # a sqlalchemy.orm.Session
    table_class: type,
    records: List[Dict[str, Any]],
    primary_key: str,
    batch_size: int = 1000
):
    """
    Insert or update ``records`` in ``table_class``'s table, ``batch_size`` rows at a time.

    Each batch is first tried as one bulk INSERT. If that fails (e.g. a primary-key
    conflict), the batch is rolled back and replayed row by row:
      - rows whose ``primary_key`` value already exists are updated in place;
      - the other rows are inserted.

    Every row runs inside its own SAVEPOINT, so a bad row is skipped without discarding
    the rows processed before it in the same batch.

    NOTE: for tables keyed by a database-generated ``id`` that the records do not carry
    (discounts, transactions) there is nothing to match on. Those rows are always
    inserted, so re-running a load appends duplicates.

    Args:
        session: Open SQLAlchemy session; it is committed once per batch.
        table_class: ORM model class of the target table.
        records: Row dictionaries keyed by the model's attribute names.
        primary_key: Name of the attribute used to find existing rows.
        batch_size: Rows per batch; must be >= 1.

    Errors on individual rows and on batch commits are logged, not raised. A batch whose
    commit fails is rolled back and processing continues with the next batch.
    """
    table_name = table_class.__tablename__
    total_records = len(records)
    if total_records == 0:
        logger.info(f"No records to upsert for {table_name}.")
        return
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    logger.info(f"Starting to upsert {total_records} records into {SCHEMA_NAME}.{table_name} in batches of {batch_size}")
    total_batches = (total_records + batch_size - 1) // batch_size  # ceiling division

    for start in range(0, total_records, batch_size):
        batch = records[start:start + batch_size]
        batch_num = start // batch_size + 1

        try:
            # Fast path: one bulk INSERT for the whole batch.
            session.bulk_insert_mappings(table_class, batch)
            session.commit()
            logger.debug(f"Successfully bulk inserted batch {batch_num}/{total_batches}")
            continue
        except Exception:
            session.rollback()
            logger.info(f"Bulk insert failed for batch {batch_num}. Switching to individual upsert.")

        # Slow path: row-by-row upsert, each row isolated in a SAVEPOINT.
        failed_rows = 0
        key_column = getattr(table_class, primary_key)
        for record in batch:
            key_value = record.get(primary_key)
            try:
                # Leaving the block flushes the row and releases the savepoint, so any
                # database error is raised here for this row. On error, only this row
                # is rolled back.
                with session.begin_nested():
                    existing = None
                    if key_value is not None:
                        existing = session.query(table_class).filter(key_column == key_value).first()

                    if existing is not None:
                        for key, value in record.items():
                            setattr(existing, key, value)
                    else:
                        session.add(table_class(**record))
            except Exception as rec_error:
                failed_rows += 1
                logger.error(
                    f"Error processing record {record.get(primary_key, 'N/A')} "
                    f"in batch {batch_num}: {str(rec_error)}", exc_info=True
                )

        try:
            session.commit()
            logger.info(f"Successfully upserted (bulk or individual) batch {batch_num}/{total_batches}")
            if failed_rows:
                logger.warning(f"{failed_rows} record(s) in batch {batch_num} were skipped due to errors.")
        except Exception as commit_error:
            session.rollback()
            logger.error(f"Failed to commit batch {batch_num}: {str(commit_error)}", exc_info=True)


def write_spark_df_to_postgres(spark_df: DataFrame, table_name: str, engine: Engine):
    """
    Stream a transformed Spark DataFrame into its PostgreSQL table via upsert_data.

    How the load works:
      - The ORM class is found by capitalising ``table_name`` (e.g. 'stores' -> Stores);
        the key attribute comes from TABLE_PROCESSING_MAP.
      - Rows are pulled to the driver with toLocalIterator and upserted in chunks of
        5,000, each chunk in its own session.
      - A chunk whose upsert raises is logged and dropped; loading continues with the
        next chunk.
      - For every table except 'transactions', the final row count is logged.

    Args:
        spark_df: DataFrame whose column names match the ORM model's attributes.
        table_name: Key into TABLE_PROCESSING_MAP (lowercase model name).
        engine: SQLAlchemy engine for the target database.

    Raises:
        Errors outside the per-chunk upsert (e.g. while iterating the DataFrame or
        counting rows) are logged and re-raised.
    """
    try:
        # Dynamically get the SQLAlchemy model class and primary key
        table_class = globals().get(table_name.capitalize())
        primary_key = TABLE_PROCESSING_MAP.get(table_name, (None, None))[1]

        if not table_class or not primary_key:
            logger.error(f"Database model or primary key not defined for table '{table_name}'. Skipping load.")
            return

        logger.info(f"Starting load process for table: {SCHEMA_NAME}.{table_name}")

        Session = sessionmaker(bind=engine)
        buffer_size = 5000  # rows held on the driver before each upsert

        def _flush(chunk: List[Dict[str, Any]], error_prefix: str) -> bool:
            """Upsert one chunk in a fresh session; log and return False if it raises."""
            session = Session()
            try:
                upsert_data(session, table_class, chunk, primary_key)
                return True
            except Exception as e:
                logger.error(f"{error_prefix}: {str(e)}", exc_info=True)
                session.rollback()
                return False
            finally:
                session.close()

        records_buffer: List[Dict[str, Any]] = []
        # toLocalIterator fetches partitions one at a time instead of collect()-ing the
        # whole DataFrame onto the driver.
        for i, row in enumerate(spark_df.toLocalIterator(), 1):
            records_buffer.append(row.asDict())

            if len(records_buffer) >= buffer_size:
                if _flush(records_buffer, f"Error during upsert for batch ending at record {i}"):
                    logger.info(f"Processed and upserted {i} records for {table_name}.")
                # Start a fresh buffer even after a failure. Keeping the failed rows would
                # leave the buffer permanently above buffer_size, so they would be re-sent
                # together with every later row.
                records_buffer = []

        # Upsert any remaining records in the buffer
        if records_buffer:
            if _flush(records_buffer, f"Error during final upsert for {table_name}"):
                logger.info(f"Processed and upserted final batch for {table_name}.")

        # Skip the final COUNT(*) for the very large transactions table
        if table_name not in ['transactions']:
            with engine.connect() as conn:
                # Quoted identifiers: SCHEMA_NAME is mixed-case
                count_query = text(f'SELECT COUNT(*) FROM "{SCHEMA_NAME}"."{table_name}"')
                count = conn.execute(count_query).scalar()
                logger.info(f"Final record count in {SCHEMA_NAME}.{table_name}: {count}")

        logger.info(f"Load process completed for table: {SCHEMA_NAME}.{table_name}.")

    except Exception as e:
        logger.error(f"Error writing to {table_name}: {str(e)}", exc_info=True)
        raise


# --- Main Pipeline Execution ---

def run_etl_pipeline():
    """
    Orchestrate extract -> transform -> load for every Global Fashion Retail CSV.

    Configuration is read with os.getenv (load_dotenv(DOTENV_PATH) runs at cell start):
      - DB_USER, DB_PASSWORD, DB_HOST, DB_NAME and DATASET_PATH are required; the
        process exits via sys.exit if one is missing;
      - DB_PORT defaults to '5432'.

    Files are processed parent tables first. Error handling:
      - A failure while extracting, transforming or loading one file is logged and that
        file is skipped; skipped files are listed at the end.
      - Errors while creating the Spark session or setting up the database are logged
        and re-raised.

    The database engine and the Spark session are always released.
    """
    from sqlalchemy.engine import URL  # builds the connection URL with escaped credentials

    db_user = os.getenv('DB_USER')
    db_password = os.getenv('DB_PASSWORD')
    db_host = os.getenv('DB_HOST')
    db_port = os.getenv('DB_PORT', '5432')  # Use default port if not set
    db_name = os.getenv('DB_NAME')
    dataset_path = os.getenv('DATASET_PATH')

    required_vars = {
        'DB_USER': db_user,
        'DB_PASSWORD': db_password,
        'DB_HOST': db_host,
        'DB_NAME': db_name,
        'DATASET_PATH': dataset_path
    }

    for var_name, var_value in required_vars.items():
        if not var_value:
            # Only the variable name is logged, never its value (it may be the password).
            logger.error(f"Missing required environment variable: {var_name}")
            sys.exit(f"Error: Missing required environment variable: {var_name}")

    engine: Optional[Engine] = None
    spark_session: Optional[SparkSession] = None
    failed_files: List[str] = []

    try:
        # 1. Create Spark Session
        spark_session = create_spark_session()
        if not spark_session:
            raise RuntimeError("Failed to create Spark session.")

        # 2. Setup Database Connection and Schema/Tables
        # URL.create escapes the credentials. Interpolating them into a URL string breaks
        # when the password contains characters such as '@', ':' or '/'.
        db_url = URL.create(
            drivername='postgresql+psycopg2',
            username=db_user,
            password=db_password,
            host=db_host,
            port=int(db_port),
            database=db_name,
        )
        logger.info(f"Attempting to connect to database: {db_host}:{db_port}/{db_name}")
        engine = create_engine(
            db_url,
            pool_size=20,
            max_overflow=30,
            pool_pre_ping=True,
            pool_recycle=3600,
            connect_args={'connect_timeout': 10}
        )

        # Test connection
        with engine.connect() as connection:
            connection.execute(text("SELECT 1"))
        logger.info("Database connection successful.")

        if not setup_database(engine):
            raise RuntimeError("Database setup failed.")

        # 3. ETL Process for each file.
        # Parent tables are loaded before the tables whose foreign keys reference them.
        file_configs = [
            ('stores.csv', 'stores'),
            ('employees.csv', 'employees'),
            ('customers.csv', 'customers'),
            ('products.csv', 'products'),
            ('discounts.csv', 'discounts'),
            ('transactions.csv', 'transactions'),
        ]

        for file_name, table_name in file_configs:
            file_path = os.path.join(dataset_path, file_name)
            processor, _ = TABLE_PROCESSING_MAP.get(table_name, (None, None))

            if processor is None:
                logger.error(f"No processing function defined for table '{table_name}'. Skipping.")
                failed_files.append(file_name)
                continue

            logger.info(f"--- Starting processing for {file_name} (Table: {table_name}) ---")
            process_start_time = datetime.now()

            # Extract
            spark_df = extract_data(spark_session, file_path, table_name)
            if spark_df is None:
                logger.warning(f"Skipping transformation and load for {file_name} due to extraction failure.")
                failed_files.append(file_name)
                continue

            # Transform
            try:
                transformed_df = processor(spark_df)
                logger.info(f"Transformation complete for {file_name}.")
            except Exception as e:
                logger.error(f"Error during transformation for {file_name}: {str(e)}", exc_info=True)
                logger.warning(f"Skipping load for {file_name} due to transformation failure.")
                failed_files.append(file_name)
                continue

            # Load
            try:
                write_spark_df_to_postgres(transformed_df, table_name, engine)
                logger.info(f"Load complete for {file_name}.")
            except Exception as e:
                logger.error(f"Error during load for {file_name}: {str(e)}", exc_info=True)
                failed_files.append(file_name)

            logger.info(f"--- Finished processing for {file_name} in {datetime.now() - process_start_time} ---")

        if failed_files:
            logger.warning(f"ETL pipeline finished, but these files were not loaded: {', '.join(failed_files)}")
        else:
            logger.info("ETL pipeline completed successfully.")

    except Exception as e:
        logger.error(f"An error occurred during the ETL pipeline: {str(e)}", exc_info=True)
        raise  # Re-raise the exception to indicate failure

    finally:
        # 4. Cleanup
        if engine:
            engine.dispose()
            logger.info("Database engine disposed.")
        if spark_session:
            spark_session.stop()
            logger.info("Spark session stopped.")

# --- Script Entry Point ---

if __name__ == "__main__":
    logger.info("Starting Global Fashion Retail Data Load process...")
    start_time = datetime.now()
    pipeline_failed = False
    try:
        run_etl_pipeline()
    except Exception:
        # run_etl_pipeline already logged the traceback; only the timing summary is logged below.
        pipeline_failed = True
    end_time = datetime.now()

    if pipeline_failed:
        logger.error(f"Global Fashion Retail Data Load process failed after {end_time - start_time}.")
        sys.exit(1)  # non-zero exit status signals failure

    logger.info(f"Global Fashion Retail Data Load process completed successfully in {end_time - start_time}.")
    sys.exit(0)  # zero exit status signals success

Cleanse and Aggregate for Reporting

In [None]:
import os
import logging
import sys
import pandas as pd
from dotenv import load_dotenv
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, String, Float, Date, DateTime, Boolean, ForeignKey, Text, text, PrimaryKeyConstraint
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from sqlalchemy.exc import SQLAlchemyError
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, to_date, lit, datediff, current_date, dayofweek, sum as spark_sum, avg as spark_avg, count as spark_count, to_timestamp
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, DateType, BooleanType, TimestampType

# Configure logging for the processor cell. It writes data_processor.log in the working
# directory, separate from the loader's data_loader.log.
# force=True (Python 3.8+) first removes any handlers already on the root logger. Without
# it, basicConfig() does nothing when the loader cell has already configured logging in
# this kernel, and data_processor.log would never be written.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('data_processor.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),  # stdout, same as the loader cell
    ],
    force=True,
)
logger = logging.getLogger(__name__)

# Initialize Spark with Hadoop workaround for Windows and JDBC driver config
def create_spark_session():
    """Create (or reuse) the SparkSession used by this stage.

    Configuration comes from environment variables:

    * ``POSTGRES_JDBC_DRIVER_PATH`` (required): path to the PostgreSQL JDBC
      driver JAR, added to ``spark.driver.extraClassPath``.
    * ``SPARK_TEMP_DIR`` (optional): value for ``spark.local.dir``. Defaults
      to ``D:/spark_temp`` on Windows and ``<system temp dir>/spark_temp``
      elsewhere. On POSIX, ``D:/spark_temp`` would be a *relative* path that
      creates a directory literally named ``D:`` in the working directory.
    * ``HADOOP_HOME_WIN`` (optional, Windows only): exported as
      ``HADOOP_HOME`` (default ``C:\\hadoop``). Its ``bin`` folder is appended
      to ``PATH`` once.

    Returns:
        The SparkSession from ``SparkSession.builder.getOrCreate()``.

    Raises:
        ValueError: ``POSTGRES_JDBC_DRIVER_PATH`` is not set.
        FileNotFoundError: No file exists at ``POSTGRES_JDBC_DRIVER_PATH``.
        Exception: Any other failure is logged with traceback and re-raised.
    """
    import tempfile  # only used for the portable non-Windows spark.local.dir default

    try:
        # Validate the JDBC driver first so a misconfiguration fails fast,
        # before any environment mutation or JVM start-up.
        jdbc_jar_path = os.getenv('POSTGRES_JDBC_DRIVER_PATH')
        if not jdbc_jar_path:
            logger.error("POSTGRES_JDBC_DRIVER_PATH environment variable is not set.")
            raise ValueError("POSTGRES_JDBC_DRIVER_PATH environment variable must be set to the path of the PostgreSQL JDBC driver JAR.")
        if not os.path.exists(jdbc_jar_path):
            logger.error(f"PostgreSQL JDBC driver not found at specified path: {jdbc_jar_path}")
            raise FileNotFoundError(f"PostgreSQL JDBC driver not found at {jdbc_jar_path}. Please download the JAR and set POSTGRES_JDBC_DRIVER_PATH.")

        if os.name == 'nt':
            # Spark on Windows often needs HADOOP_HOME even without HDFS.
            hadoop_home = os.getenv('HADOOP_HOME_WIN', 'C:\\hadoop')
            os.environ['HADOOP_HOME'] = hadoop_home
            hadoop_bin = os.path.join(hadoop_home, 'bin')
            current_path = os.environ.get('PATH', '')
            # Append only once, so repeated calls (e.g. re-running a notebook
            # cell) do not keep growing PATH.
            if os.path.exists(hadoop_bin) and hadoop_bin not in current_path.split(os.pathsep):
                os.environ['PATH'] = f"{current_path}{os.pathsep}{hadoop_bin}" if current_path else hadoop_bin

            # Pin PySpark (driver and workers) to the interpreter running this script.
            os.environ['PYSPARK_PYTHON'] = sys.executable
            os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
            default_temp_dir = 'D:/spark_temp'
        else:
            default_temp_dir = os.path.join(tempfile.gettempdir(), 'spark_temp')

        temp_dir = os.getenv('SPARK_TEMP_DIR', default_temp_dir)
        if not os.path.exists(temp_dir):
            # exist_ok avoids a crash if another process creates it between the check and mkdir.
            os.makedirs(temp_dir, exist_ok=True)
            logger.info(f"Created Spark temporary directory: {temp_dir}")

        spark_builder = (
            SparkSession.builder
            .appName("GlobalFashionRetailDataProcessor")
            .config("spark.local.dir", temp_dir)
            .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
            .config("spark.executor.memory", "8g")
            .config("spark.driver.memory", "8g")
            .config("spark.memory.offHeap.enabled", "false")
            .config("spark.sql.shuffle.partitions", "200")
            .config("spark.default.parallelism", "200")
            .config("spark.sql.adaptive.enabled", "true")
            .config("spark.network.timeout", "600s")
            .config("spark.executor.heartbeatInterval", "60s")
            .config("spark.python.profile", "false")
            .config("spark.executor.instances", "4")
            .config("spark.executor.cores", "2")
            .config("spark.driver.maxResultSize", "4g")
            .config("spark.sql.execution.arrow.pyspark.enabled", "true")
            .config("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")
            .config("spark.sql.debug.maxToStringFields", 100)  # wider dataframes in logging
        )

        # The JDBC driver is supplied as a local JAR on the driver classpath
        # (instead of spark.jars.packages), so no Maven download is needed.
        # NOTE(review): spark.driver.extraClassPath only takes effect if this
        # call launches the JVM. A JVM already running in this process (e.g.
        # from an earlier notebook cell) keeps its classpath. Confirm when
        # reusing a kernel.
        logger.info(f"Configuring Spark driver with extra classpath: {jdbc_jar_path}")
        spark_builder = spark_builder.config("spark.driver.extraClassPath", jdbc_jar_path)

        logger.info("Attempting to get or create Spark session...")
        spark_session = spark_builder.getOrCreate()
        logger.info("Spark session created successfully.")
        return spark_session

    except Exception as e:
        logger.error(f"Failed to create Spark session: {str(e)}", exc_info=True)
        raise

# Load environment variables from .env BEFORE creating the Spark session.
# create_spark_session() reads POSTGRES_JDBC_DRIVER_PATH, SPARK_TEMP_DIR and
# HADOOP_HOME_WIN via os.getenv, so values defined only in .env would
# otherwise be invisible to it.
load_dotenv()

# Initialize Spark (module-level session shared by the functions below)
spark = create_spark_session()

# Database schema definition
Base = declarative_base()
# Mixed-case name: PostgreSQL treats it case-sensitively only when quoted
# ("GFRetail"), so any raw SQL using it must quote the identifier.
SCHEMA_NAME = "GFRetail" # Ensure this matches your database schema name


class Stores(Base):
    """ORM model for the ``stores`` table in SCHEMA_NAME, keyed by ``store_id``."""
    __tablename__ = 'stores'
    __table_args__ = {'schema': SCHEMA_NAME}
    store_id = Column(Integer, primary_key=True)
    country = Column(String(100), nullable=True)
    city = Column(String(100), nullable=True)
    store_name = Column(String(100))
    number_of_employees = Column(Integer)
    zip_code = Column(String(20))  # stored as text (zip codes may contain letters/leading zeros)
    latitude = Column(Float, nullable=True)
    longitude = Column(Float, nullable=True)

class Employees(Base):
    """ORM model for the ``employees`` table in SCHEMA_NAME, keyed by ``employee_id``."""
    __tablename__ = 'employees'
    __table_args__ = {'schema': SCHEMA_NAME}
    employee_id = Column(Integer, primary_key=True)
    # Foreign key to the store the employee belongs to.
    store_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.stores.store_id'))
    name = Column(String(255), nullable=True)
    position = Column(String(100), nullable=True)

class Customers(Base):
    """ORM model for the ``customers`` table in SCHEMA_NAME, keyed by ``customer_id``."""
    __tablename__ = 'customers'
    __table_args__ = {'schema': SCHEMA_NAME}
    customer_id = Column(Integer, primary_key=True)
    name = Column(String(255), nullable=True)
    email = Column(String(255), nullable=True)
    join_date = Column(Date, nullable=True)
    telephone = Column(String(50), nullable=True)
    city = Column(String(100), nullable=True)
    country = Column(String(100), nullable=True)
    gender = Column(String(20), nullable=True)
    date_of_birth = Column(Date, nullable=True)
    job_title = Column(String(100), nullable=True)
    # Feature column computed by add_customer_features().
    # NOTE: Base.metadata.create_all() only creates missing tables. It does NOT
    # add this column to an already-existing customers table; that requires an
    # ALTER TABLE / migration.
    days_since_birth = Column(Integer, nullable=True)

class Products(Base):
    """ORM model for the ``products`` table in SCHEMA_NAME, keyed by ``product_id``."""
    __tablename__ = 'products'
    __table_args__ = {'schema': SCHEMA_NAME}
    product_id = Column(Integer, primary_key=True)
    category = Column(String(100), nullable=True)
    sub_category = Column(String(100))
    # Presumably one description per language (pt/de/fr/es/en/zh suffixes); verify against the source data.
    description_pt = Column(Text)
    description_de = Column(Text)
    description_fr = Column(Text)
    description_es = Column(Text)
    description_en = Column(Text)
    description_zh = Column(Text)
    color = Column(String(50), nullable=True)
    sizes = Column(String(100), nullable=True)  # stored as a single string, not a list
    production_cost = Column(Float)

class Discounts(Base):
    """ORM model for the ``discounts`` table in SCHEMA_NAME.

    Each row carries a validity window (start_date/end_date), a discount_rate
    and the category/sub_category it targets.
    """
    __tablename__ = 'discounts'
    __table_args__ = {'schema': SCHEMA_NAME}
    id = Column(Integer, primary_key=True, autoincrement=True) # Assuming this is the PK in DB
    start_date = Column(Date)
    end_date = Column(Date)
    discount_rate = Column(Float)
    description = Column(String(255), nullable=True)
    category = Column(String(100), nullable=True)
    sub_category = Column(String(100))

class Transactions(Base):
    """ORM model for the ``transactions`` table in SCHEMA_NAME.

    Presumably one row per invoice line (invoice_id + line_number); verify
    against the source data. References customers, products, stores and
    employees via foreign keys.
    """
    __tablename__ = 'transactions'
    __table_args__ = {'schema': SCHEMA_NAME}
    # Assuming 'id' is the primary key already in the DB
    id = Column(Integer, primary_key=True) # Assuming this is pre-existing PK
    invoice_id = Column(String(50), nullable=True)
    line_number = Column(Integer)
    customer_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.customers.customer_id'))
    product_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.products.product_id'))
    size = Column(String(10))
    color = Column(String(50), nullable=True)
    unit_price = Column(Float)
    quantity = Column(Integer, nullable=True)
    transaction_date = Column(DateTime) # Or Timestamp if that's the DB type
    discount = Column(Float, nullable=True)
    line_total = Column(Float)
    store_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.stores.store_id'))
    employee_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.employees.employee_id'))
    currency = Column(String(10), nullable=True)
    currency_symbol = Column(String(5))
    sku = Column(String(100))
    transaction_type = Column(String(20))
    payment_method = Column(String(50))
    invoice_total = Column(Float)
    is_return = Column(Boolean)
    # Feature columns computed by add_transaction_features().
    # NOTE: Base.metadata.create_all() only creates missing tables. It does NOT
    # add these columns to an already-existing transactions table; that
    # requires an ALTER TABLE / migration.
    net_line_total = Column(Float, nullable=True)
    transaction_day_of_week = Column(Integer, nullable=True)

class DailyStoreSales(Base):
    """Per-store, per-day sales summary produced by this stage.

    Rows come from aggregate_daily_store_sales() and are upserted by
    write_spark_df_to_postgres(). (store_id, sale_date) is the composite
    primary key, so both must be non-null.
    """
    __tablename__ = 'daily_store_sales'
    __table_args__ = (
        PrimaryKeyConstraint('store_id', 'sale_date'), # Using composite primary key
        {'schema': SCHEMA_NAME}
    )
    store_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.stores.store_id'), nullable=False)
    sale_date = Column(Date, nullable=False)
    transaction_count = Column(Integer, nullable=True)
    total_sales = Column(Float, nullable=True)
    average_unit_price = Column(Float, nullable=True)
    total_quantity_sold = Column(Integer, nullable=True)
    total_discount_given = Column(Float, nullable=True)
    total_net_sales = Column(Float, nullable=True)


# --- Data Processing Functions (Feature Engineering & Aggregation) ---
# These functions now assume input DataFrames have column names matching SQLAlchemy models (snake_case)

def add_customer_features(customers_df):
    """Return ``customers_df`` with an extra ``days_since_birth`` column.

    ``days_since_birth`` is the number of days between ``date_of_birth`` and
    the current date (Spark ``datediff``). It is null wherever
    ``date_of_birth`` is null. The input DataFrame is not modified.
    """
    logger.info("Adding customer features (days_since_birth)...")
    birth_col = col("date_of_birth")
    age_in_days = (
        when(birth_col.isNotNull(), datediff(current_date(), birth_col))
        .otherwise(lit(None))
    )
    enriched_df = customers_df.withColumn("days_since_birth", age_in_days)
    logger.info("Customer feature engineering complete.")
    return enriched_df

def add_transaction_features(transactions_df):
    """Return ``transactions_df`` with ``net_line_total`` and ``transaction_day_of_week`` added.

    The source columns are cast first so the arithmetic does not depend on the
    exact JDBC-inferred types:
    unit_price -> float, quantity -> int, discount -> float,
    transaction_date -> timestamp.

    * ``net_line_total = unit_price * quantity - discount``. A null
      ``discount`` (the column is nullable) is treated as 0. Otherwise the
      whole expression would be null and the line would silently drop out of
      ``total_net_sales`` in aggregate_daily_store_sales().
    * ``transaction_day_of_week`` is Spark's ``dayofweek`` (1 = Sunday ...
      7 = Saturday). It is null when ``transaction_date`` is null.
    """
    logger.info("Adding transaction features (net_line_total, day_of_week)...")
    typed_df = (
        transactions_df
        .withColumn("unit_price", col("unit_price").cast(FloatType()))
        .withColumn("quantity", col("quantity").cast(IntegerType()))
        .withColumn("discount", col("discount").cast(FloatType()))
        .withColumn("transaction_date", col("transaction_date").cast(TimestampType()))
    )

    # Null-safe discount used only in the calculation; the discount column
    # itself is left unchanged.
    discount_or_zero = when(col("discount").isNull(), lit(0.0)).otherwise(col("discount"))

    # NOTE(review): this treats `discount` as an absolute amount. If it is a
    # fractional rate (Discounts has a `discount_rate` column), the net total
    # should be unit_price * quantity * (1 - discount). Confirm against the
    # source data.
    processed_df = typed_df.withColumn(
        "net_line_total",
        (col("unit_price") * col("quantity")) - discount_or_zero,
    ).withColumn(
        "transaction_day_of_week",
        # dayofweek(null) is already null; kept explicit for readability.
        when(col("transaction_date").isNotNull(), dayofweek(col("transaction_date")))
        .otherwise(lit(None)),
    )
    logger.info("Transaction feature engineering complete.")
    return processed_df


def aggregate_daily_store_sales(transactions_df):
    """Summarise transactions per (store_id, calendar day) for ``daily_store_sales``.

    Expects snake_case columns ``store_id``, ``transaction_date``,
    ``line_total``, ``unit_price``, ``quantity``, ``discount`` and
    ``net_line_total``. The last is added by add_transaction_features().

    Output columns (matching the DailyStoreSales model):

    * transaction_count: number of input rows in the group. NOTE(review): rows
      here are transaction lines (the model has invoice_id/line_number), so
      this counts lines, not distinct invoices. Confirm this is intended.
    * total_sales: sum of line_total.
    * average_unit_price: unweighted mean of unit_price across lines.
    * total_quantity_sold: sum of quantity.
    * total_discount_given: sum of discount.
    * total_net_sales: sum of net_line_total.

    Rows with a null store_id or a null derived sale_date are excluded.
    (store_id, sale_date) is the NOT NULL composite primary key of
    daily_store_sales, so such groups would make the later upsert fail.
    NOTE(review): rows with is_return = true are not treated specially.
    Confirm whether returns should be netted or excluded.
    """
    logger.info("Starting daily store sales aggregation...")

    # Truncate the timestamp to its calendar date (in the Spark session time zone).
    dated_df = transactions_df.withColumn("sale_date", to_date(col("transaction_date")))
    keyed_df = dated_df.filter(col("store_id").isNotNull() & col("sale_date").isNotNull())

    daily_sales_df = keyed_df.groupBy("store_id", "sale_date").agg(
        spark_count("*").alias("transaction_count"),
        spark_sum("line_total").alias("total_sales"),
        spark_avg("unit_price").alias("average_unit_price"),
        spark_sum("quantity").alias("total_quantity_sold"),
        spark_sum("discount").alias("total_discount_given"),
        spark_sum("net_line_total").alias("total_net_sales"),
    )

    logger.info("Daily store sales aggregation complete.")
    # Explicit projection pins the column order to the DailyStoreSales model.
    return daily_sales_df.select(
        "store_id",
        "sale_date",
        "transaction_count",
        "total_sales",
        "average_unit_price",
        "total_quantity_sold",
        "total_discount_given",
        "total_net_sales",
    )

# --- Database Interaction Functions ---

def setup_database(engine):
    """Ensure the ``SCHEMA_NAME`` schema and all ORM tables exist.

    All DDL runs on one connection inside ``engine.begin()``, which commits
    when the ``with`` block exits.

    * If the schema is missing, it is created. ``DB_USER`` (default
      ``postgres``) is then granted ALL on it and, via default privileges, on
      tables the current role creates in it afterwards.
    * ``Base.metadata.create_all`` then creates any missing tables (e.g.
      ``daily_store_sales``). It never alters existing tables, so columns
      added to a model after its table was created are NOT added here.

    Args:
        engine: SQLAlchemy engine for the target PostgreSQL database.

    Returns:
        True on success. False if any step raised (the error is logged).
    """
    try:
        with engine.begin() as conn:
            # information_schema comparison is case-sensitive, matching a quoted identifier.
            schema_exists = conn.execute(
                text("SELECT 1 FROM information_schema.schemata WHERE schema_name = :schema"),
                {'schema': SCHEMA_NAME}
            ).scalar()

            if not schema_exists:
                logger.info(f"Creating schema {SCHEMA_NAME}")
                # The identifier must be quoted: SCHEMA_NAME is mixed-case.
                # Unquoted, PostgreSQL folds it to lowercase ("gfretail"),
                # which matches neither the check above nor SQLAlchemy's
                # (quoted) table DDL.
                conn.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{SCHEMA_NAME}"'))
                grant_user = os.getenv('DB_USER', 'postgres')
                conn.execute(text(f'GRANT ALL ON SCHEMA "{SCHEMA_NAME}" TO "{grant_user}"'))
                conn.execute(text(
                    f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{SCHEMA_NAME}" GRANT ALL ON TABLES TO "{grant_user}"'
                ))
                logger.info(f"Schema {SCHEMA_NAME} created and permissions granted to {grant_user}.")

            logger.info("Creating/Ensuring existence of database tables based on models...")
            # Use the same connection/transaction. A schema created above is
            # not committed yet, so a separate pooled connection (as with
            # create_all(engine)) would not see it.
            Base.metadata.create_all(bind=conn)
            logger.info("Table creation/check complete.")

        return True
    except Exception as e:
        logger.error(f"Database setup failed: {str(e)}", exc_info=True)
        return False

def load_data_from_postgres_to_spark(spark, db_host, db_port, db_name, db_user, db_password, schema_name, table_name, partition_options=None):
    """Read ``"schema_name"."table_name"`` from PostgreSQL into a Spark DataFrame via JDBC.

    Args:
        spark: Active SparkSession (the PostgreSQL JDBC driver must be on its classpath).
        db_host, db_port, db_name: Connection target used to build the JDBC URL.
        db_user, db_password: Credentials, passed as JDBC properties (never logged).
        schema_name, table_name: Table to read. Both are double-quoted, so
            case is preserved.
        partition_options: Optional dict of extra keyword arguments forwarded
            to ``DataFrameReader.jdbc`` for a parallel read. Example:
            ``{"column": "id", "lowerBound": 1, "upperBound": 6500000,
            "numPartitions": 16}``. Defaults to a single-partition read.

    Returns:
        The Spark DataFrame for the table.

    Raises:
        Exception: Any read failure is logged with traceback and re-raised.
    """
    logger.info(f"Reading data from database table: {schema_name}.{table_name}")
    try:
        jdbc_url = f'jdbc:postgresql://{db_host}:{db_port}/{db_name}'
        db_properties = {
            "user": db_user,
            "password": db_password,
            "driver": "org.postgresql.Driver",
            "fetchsize": "10000"  # rows per round-trip; helps with large reads
        }

        # Quoted identifiers preserve case / special characters.
        full_table_name = f'"{schema_name}"."{table_name}"'

        read_options = dict(partition_options or {})
        if read_options:
            logger.info(f"Applying partitioning for {table_name}: {sorted(read_options)}")
        elif table_name == 'transactions':
            logger.warning(f"Partitioning options are NOT set for {table_name}. This will be slow for large tables. Consider adding partitionColumn, lowerBound, upperBound, and numPartitions.")

        spark_df = spark.read.jdbc(
            url=jdbc_url,
            table=full_table_name,
            properties=db_properties,
            **read_options
        )
        # printSchema() prints to stdout and returns None, so log the schema string instead.
        logger.info(f"Successfully read data from {schema_name}.{table_name}. Schema: {spark_df.schema.simpleString()}")
        return spark_df
    except Exception as e:
        logger.error(f"Failed to read data from {schema_name}.{table_name} into Spark: {str(e)}", exc_info=True)
        raise

def upsert_data(session, table_class, records, primary_keys, batch_size=1000):
    """Merge ``records`` into ``table_class``'s table, committing once per batch.

    Each record (a dict or a Spark Row with ``asDict()``) is turned into a
    ``table_class`` instance and passed to ``session.merge``. That updates the
    row when its primary key already exists and inserts it otherwise.

    Args:
        session: Open SQLAlchemy session. It is committed after every batch
            and rolled back on error.
        table_class: ORM model class whose attributes match the record keys.
        records: Sequence of dicts / Rows to upsert.
        primary_keys: Key column names, used only to log the keys of a failing batch.
        batch_size: Number of records merged per commit.

    Raises:
        Exception: The first failing batch is rolled back, logged (with its
            primary keys where possible) and the error re-raised.
    """
    def _as_dict(item):
        # Spark Row objects expose asDict(); plain dicts pass through unchanged.
        return item if isinstance(item, dict) else item.asDict()

    try:
        total = len(records)
        table = table_class.__tablename__
        if not total:
            logger.info(f"No records to upsert for {table}.")
            return

        logger.info(f"Starting to upsert {total} records for {table} in batches of {batch_size}")
        starts = range(0, total, batch_size)
        batch_count = len(starts)

        for batch_index, offset in enumerate(starts, start=1):
            chunk = records[offset:offset + batch_size]
            try:
                for item in chunk:
                    session.merge(table_class(**_as_dict(item)))
                session.commit()
                logger.info(f"Successfully processed batch {batch_index}/{batch_count} for {table}")
            except Exception as e:
                session.rollback()
                logger.error(f"Error processing batch {batch_index} for {table}: {str(e)}", exc_info=True)
                # Best effort: report which keys were in the failing batch.
                try:
                    key_values = [{pk: _as_dict(item).get(pk) for pk in primary_keys} for item in chunk]
                    logger.error(f"Batch primary key values: {key_values}")
                except Exception as log_e:
                    logger.error(f"Could not log batch primary keys: {log_e}")
                raise  # stop on the first failing batch

    except Exception as e:
        session.rollback()
        logger.error(f"Critical error in upsert_data for {table_class.__tablename__}: {str(e)}", exc_info=True)
        raise

def write_spark_df_to_postgres(spark_df, table_name, engine):
    """Upsert every row of a Spark DataFrame into the ORM table named ``table_name``.

    The ORM class is looked up in this module's globals by converting
    snake_case to PascalCase (``daily_store_sales`` -> ``DailyStoreSales``).
    Rows are streamed to the driver with ``toLocalIterator()`` and handed to
    ``upsert_data`` in chunks of 5000. For a few tables, the final row count
    is logged afterwards.

    Args:
        spark_df: DataFrame whose column names match the ORM attributes.
        table_name: Target table name (without schema).
        engine: SQLAlchemy engine used for the session and the count query.

    Returns:
        None. If no ORM class or primary-key mapping exists for
        ``table_name``, an error is logged and nothing is written.

    Raises:
        Exception: Errors while streaming or upserting are logged and re-raised.
    """
    try:
        class_name = ''.join(part.capitalize() for part in table_name.split('_'))
        orm_class = globals().get(class_name)
        if not orm_class:
            logger.error(f"SQLAlchemy class not found for table name: {table_name}")
            return

        key_columns = {
            'stores': ['store_id'],
            'employees': ['employee_id'],
            'customers': ['customer_id'],
            'products': ['product_id'],
            'discounts': ['id'],
            'transactions': ['id'],  # assuming 'id' is the PK
            'daily_store_sales': ['store_id', 'sale_date'],  # composite key
        }.get(table_name)
        if not key_columns:
            logger.error(f"Primary key(s) not defined for table {table_name}")
            return

        logger.info(f"Starting write process for table {table_name}...")

        # toLocalIterator() pulls one partition at a time, keeping driver memory bounded.
        row_stream = spark_df.toLocalIterator()
        chunk = []
        chunk_size = 5000

        session = sessionmaker(bind=engine)()
        try:
            for seen, row in enumerate(row_stream, start=1):
                chunk.append(row.asDict())
                if seen % chunk_size == 0:
                    logger.info(f"Buffer size reached {chunk_size}, processing batch {seen // chunk_size}")
                    upsert_data(session, orm_class, chunk, key_columns)  # commits per batch
                    chunk = []

            if chunk:
                logger.info(f"Processing final batch of {len(chunk)} records")
                upsert_data(session, orm_class, chunk, key_columns)

            # Counting is cheap only for small tables, so restrict it to these.
            counted_tables = ('stores', 'employees', 'discounts', 'daily_store_sales')
            if table_name not in counted_tables:
                logger.info(f"Skipped final count verification for table {table_name} (read from DB).")
            else:
                try:
                    with engine.connect() as conn:
                        row_total = conn.execute(text(f'SELECT COUNT(*) FROM "{SCHEMA_NAME}"."{table_name}"')).scalar()
                        logger.info(f"Final count for {table_name}: {row_total} records")
                except Exception as count_e:
                    logger.warning(f"Could not verify count for {table_name}: {count_e}")

        except Exception as e:
            logger.error(f"Error during data processing and upsert for {table_name}: {str(e)}", exc_info=True)
            session.rollback()  # discard anything not yet committed
            raise
        finally:
            session.close()
            logger.info(f"Session closed for table {table_name}.")

    except Exception as e:
        logger.error(f"Fatal error in write_spark_df_to_postgres for {table_name}: {str(e)}", exc_info=True)
        raise


def process_and_aggregate_from_db():
    """Read GFRetail tables from PostgreSQL, engineer features, and load ``daily_store_sales``.

    Steps:
      1. Validate DB_USER / DB_PASSWORD / DB_HOST / DB_NAME (DB_PORT defaults to 5432).
      2. Create a SQLAlchemy engine and run setup_database().
      3. Read customers, products, transactions, stores, employees and
         discounts via JDBC into Spark.
      4. Add customer and transaction features.
      5. Aggregate transactions per store and day and upsert the result into
         ``daily_store_sales``.

    The engine is disposed and the module-level ``spark`` session is stopped
    in all cases.

    Raises:
        ValueError: A required DB_* environment variable is missing.
        RuntimeError: setup_database() reported failure.
        Exception: Any error while reading, processing or loading is logged and re-raised.
    """
    from sqlalchemy.engine import URL  # escapes credentials when building the engine URL

    DB_USER = os.getenv('DB_USER')
    DB_PASSWORD = os.getenv('DB_PASSWORD')
    DB_HOST = os.getenv('DB_HOST')
    DB_PORT = os.getenv('DB_PORT')
    DB_NAME = os.getenv('DB_NAME')

    required_vars = {
        'DB_USER': DB_USER,
        'DB_PASSWORD': DB_PASSWORD,
        'DB_HOST': DB_HOST,
        'DB_NAME': DB_NAME,
    }
    for var_name, var_value in required_vars.items():
        if not var_value:
            raise ValueError(f"Missing required environment variable: {var_name}")

    DB_PORT = DB_PORT if DB_PORT else '5432'

    engine = None  # defined before try so the finally block can test it

    try:
        # URL.create escapes reserved characters ('@', ':', '/', '%', ...) in
        # the credentials. An f-string URL would be corrupted by them.
        db_url = URL.create(
            drivername='postgresql+psycopg2',
            username=DB_USER,
            password=DB_PASSWORD,
            host=DB_HOST,
            port=int(DB_PORT),
            database=DB_NAME,
        )
        engine = create_engine(
            db_url,
            pool_size=15,
            max_overflow=25,
            pool_pre_ping=True,
            pool_recycle=3600,
            connect_args={'connect_timeout': 10}
        )

        # Ensures the schema and the daily_store_sales table exist (does not add columns to existing tables).
        if not setup_database(engine):
            raise RuntimeError("Database setup failed")

        # --- Read Data from Database into Spark ---
        logger.info("Starting to read data from database tables...")
        start_time_read_db = datetime.now()
        try:
            source_tables = {
                table: load_data_from_postgres_to_spark(spark, DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD, SCHEMA_NAME, table)
                for table in ('customers', 'products', 'transactions', 'stores', 'employees', 'discounts')
            }
            logger.info(f"Completed reading data from database in {datetime.now() - start_time_read_db}")
        except Exception as e:
            logger.error(f"Failed to read data from DB. Cannot proceed with processing. Error: {str(e)}", exc_info=True)
            raise

        # --- Feature Engineering ---
        logger.info("Starting feature engineering on DB-read dataframes...")
        start_time_fe = datetime.now()
        try:
            # The customer features are computed (lazily) but not yet consumed or persisted by this stage.
            processed_customer_df = add_customer_features(source_tables['customers'])
            processed_transaction_df = add_transaction_features(source_tables['transactions'])
            logger.info(f"Feature engineering complete in {datetime.now() - start_time_fe}")
        except Exception as e:
            logger.error(f"Error during feature engineering: {str(e)}", exc_info=True)
            raise

        # --- Aggregation and load ---
        logger.info("Starting aggregation for daily store sales...")
        start_time_agg = datetime.now()
        try:
            daily_sales_df = aggregate_daily_store_sales(processed_transaction_df)
            write_spark_df_to_postgres(daily_sales_df, 'daily_store_sales', engine)
            logger.info(f"Completed daily store sales aggregation and loading in {datetime.now() - start_time_agg}")
        except Exception as e:
            logger.error(f"Error during daily store sales aggregation or loading: {str(e)}", exc_info=True)
            raise

        logger.info("Overall data processing and aggregation completed.")

    except Exception as e:
        logger.error(f"Fatal error during overall process: {str(e)}", exc_info=True)
        raise  # let the entry point turn this into a non-zero exit code

    finally:
        if engine:
            engine.dispose()
            logger.info("Database engine disposed.")
        if spark:
            spark.stop()
            logger.info("Spark session stopped.")


if __name__ == "__main__":
    # Entry point: run the DB-read/aggregate job and map its outcome to an exit code.
    try:
        logger.info("Starting data processing and aggregation script (reading from DB)...")
        overall_start_time = datetime.now()
        process_and_aggregate_from_db()
        overall_end_time = datetime.now()
        elapsed = overall_end_time - overall_start_time
        logger.info(f"Script finished. Total elapsed time: {elapsed}")
        sys.exit(0)  # SystemExit is not an Exception, so it is not caught below
    except Exception as e:
        logger.error(f"Script terminated due to fatal error: {str(e)}", exc_info=True)
        sys.exit(1)  # non-zero status signals failure to the caller
