Install required libraries

In [None]:
!pip install kagglehub kaggle pandas psycopg2-binary sqlalchemy

Import Kaggle and Authenticate to Kaggle

In [None]:
import kagglehub
import kaggle

# Report the installed version rather than dumping dir(kagglehub), which only
# floods the cell output with attribute names.
print(f"kagglehub version: {kagglehub.__version__}")
print("Kaggle module is successfully installed!")

# Authenticate the Kaggle API client; an exception here stops the cell before
# the success message is printed.
kaggle.api.authenticate()
print("Kaggle API authentication successful!")

Pull Data from Kaggle

Import the Retail Fashion Dataset from Kaggle 
The dataset consists of 6 CSV files:
- transactions.csv - 773 MB in size and has 6.41 million records
- customers.csv - 183 MB in size and has 1.64 million records
- products.csv - 4.77 MB in size and has 17940 records
- discounts.csv - 18 KB in size and has 3801 records
- employees.csv - 15.2 KB in size and has 404 records
- stores.csv - 3 KB in size and has 35 records


In [None]:
import kagglehub
from kagglehub import KaggleDatasetAdapter
import os
import shutil
import pandas as pd


# Kaggle handle of the dataset to download
dataset_handle = "ricgomes/global-fashion-retail-stores-dataset"

# Define dataset folder (custom cache location). DATASET_PATH, the variable the
# pipeline reads later, overrides the default when it is already set.
dataset_path = os.getenv("DATASET_PATH", "/home/megin_mathew/fashion_dataset")

# Set KaggleHub cache override
os.environ["KAGGLEHUB_CACHE"] = dataset_path

# File names in the dataset
file_names = [
    "transactions.csv", "customers.csv", "discounts.csv",
    "employees.csv", "products.csv", "stores.csv"
]

# Ensure cache directory exists
os.makedirs(dataset_path, exist_ok=True)

# Download each file individually and copy it byte-for-byte next to the cache.
# Copying the raw file (instead of loading it with pandas and re-writing it)
# keeps the original text intact - e.g. leading zeros in ZIP codes - and avoids
# holding the largest CSV in memory.
for file_name in file_names:
    cached_file = kagglehub.dataset_download(dataset_handle, path=file_name)

    # Save to custom cache location
    shutil.copyfile(cached_file, os.path.join(dataset_path, file_name))
    print(f"Downloaded: {file_name} → Stored in: {dataset_path}")


Set Kaggle Cache

In [None]:
import os

# Point KaggleHub at the project's dataset folder instead of its default cache.
os.environ.update({"KAGGLEHUB_CACHE": "/home/megin_mathew/fashion_dataset"})

# Global Fashion Retail Data Pipeline

## Initial Setup and Configuration

In [None]:
import os
import sys
import logging
import pandas as pd
from datetime import datetime
from dotenv import load_dotenv

# Configure logging: every record goes both to data_pipeline.log and to the
# notebook/console output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('data_pipeline.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Load environment variables. load_dotenv() returns False when nothing was
# loaded (e.g. the file is missing); surface that here, because the database
# settings read later via os.getenv() would otherwise silently be None.
_env_file = "/home/megin_mathew/airflow/notebooks/.env"
if not load_dotenv(_env_file):
    logger.warning(f"No environment variables were loaded from {_env_file}")

## Database Models

In [None]:
from sqlalchemy import (
    create_engine, Column, Integer, String, Float, Date, DateTime,
    Boolean, ForeignKey, Text, text, PrimaryKeyConstraint, inspect
)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.exc import SQLAlchemyError

# Shared declarative base: every ORM model class below subclasses it, so all
# tables are collected on Base.metadata.
Base = declarative_base()
# Schema used by every model's __table_args__. The name is mixed-case, so raw
# SQL that references it has to double-quote it (as extract_postgres does).
SCHEMA_NAME = "GFRetail"

class Stores(Base):
    """ORM model for the `stores` table in the SCHEMA_NAME schema.

    One row per store, keyed by `store_id`. Employees, Transactions and
    DailyStoreSales reference this table through foreign keys.
    """
    __tablename__ = 'stores'
    __table_args__ = {'schema': SCHEMA_NAME}
    store_id = Column(Integer, primary_key=True)
    # transform_stores replaces missing country/city values with "Unknown"
    country = Column(String(100), nullable=True)
    city = Column(String(100), nullable=True)
    store_name = Column(String(100))
    number_of_employees = Column(Integer)
    # Stored as text rather than as a number
    zip_code = Column(String(20))
    latitude = Column(Float, nullable=True)
    longitude = Column(Float, nullable=True)

class Employees(Base):
    """ORM model for the `employees` table in the SCHEMA_NAME schema.

    One row per employee, keyed by `employee_id`; `store_id` is a foreign key
    to `stores.store_id`.
    """
    __tablename__ = 'employees'
    __table_args__ = {'schema': SCHEMA_NAME}
    employee_id = Column(Integer, primary_key=True)
    store_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.stores.store_id'))
    # transform_employees replaces missing name/position values with "Unknown"
    name = Column(String(255), nullable=True)
    position = Column(String(100), nullable=True)

class Customers(Base):
    """ORM model for the `customers` table in the SCHEMA_NAME schema.

    One row per customer, keyed by `customer_id`. Transactions reference this
    table through `customer_id`.
    """
    __tablename__ = 'customers'
    __table_args__ = {'schema': SCHEMA_NAME}
    customer_id = Column(Integer, primary_key=True)
    name = Column(String(255), nullable=True)
    email = Column(String(255), nullable=True)
    join_date = Column(Date, nullable=True)
    telephone = Column(String(50), nullable=True)
    city = Column(String(100), nullable=True)
    country = Column(String(100), nullable=True)
    gender = Column(String(20), nullable=True)
    date_of_birth = Column(Date, nullable=True)
    job_title = Column(String(100), nullable=True)
    # NOTE(review): transform_customers also emits an `age` column, which has no
    # counterpart on this model - confirm whether it should be persisted.
    days_since_birth = Column(Integer, nullable=True)

class Products(Base):
    """ORM model for the `products` table in the SCHEMA_NAME schema.

    One row per product, keyed by `product_id`. Transactions reference this
    table through `product_id`.
    """
    __tablename__ = 'products'
    __table_args__ = {'schema': SCHEMA_NAME}
    product_id = Column(Integer, primary_key=True)
    # transform_products replaces missing category/color values with "Unknown"
    category = Column(String(100), nullable=True)
    sub_category = Column(String(100))
    # Free-text descriptions, one column per language code (pt, de, fr, es, en, zh)
    description_pt = Column(Text)
    description_de = Column(Text)
    description_fr = Column(Text)
    description_es = Column(Text)
    description_en = Column(Text)
    description_zh = Column(Text)
    color = Column(String(50), nullable=True)
    sizes = Column(String(100), nullable=True)
    production_cost = Column(Float)

class Discounts(Base):
    """ORM model for the `discounts` table in the SCHEMA_NAME schema.

    Each row describes a discount with a start/end date, a rate and the
    category / sub-category it is recorded against.
    """
    __tablename__ = 'discounts'
    __table_args__ = {'schema': SCHEMA_NAME}
    # Auto-incrementing surrogate key; transform_discounts does not supply it
    id = Column(Integer, primary_key=True, autoincrement=True)
    start_date = Column(Date)
    end_date = Column(Date)
    discount_rate = Column(Float)
    # transform_discounts replaces a missing description with "No description"
    description = Column(String(255), nullable=True)
    category = Column(String(100), nullable=True)
    sub_category = Column(String(100))

class Transactions(Base):
    """ORM model for the `transactions` table in the SCHEMA_NAME schema.

    One row per invoice line. Customer, product, store and employee ids are
    foreign keys to their respective tables.
    """
    __tablename__ = 'transactions'
    __table_args__ = {'schema': SCHEMA_NAME}
    # Surrogate key; transform_transactions does not supply it, so it is
    # presumably generated by the database - verify.
    id = Column(Integer, primary_key=True)
    invoice_id = Column(String(50), nullable=True)
    line_number = Column(Integer)
    customer_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.customers.customer_id'))
    product_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.products.product_id'))
    size = Column(String(10))
    color = Column(String(50), nullable=True)
    unit_price = Column(Float)
    quantity = Column(Integer, nullable=True)
    transaction_date = Column(DateTime)
    discount = Column(Float, nullable=True)
    line_total = Column(Float)
    store_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.stores.store_id'))
    employee_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.employees.employee_id'))
    currency = Column(String(10), nullable=True)
    currency_symbol = Column(String(5))
    sku = Column(String(100))
    transaction_type = Column(String(20))
    payment_method = Column(String(50))
    invoice_total = Column(Float)
    # Derived columns computed by transform_transactions
    is_return = Column(Boolean)
    net_line_total = Column(Float, nullable=True)
    transaction_day_of_week = Column(Integer, nullable=True)

class DailyStoreSales(Base):
    """ORM model for the `daily_store_sales` table in the SCHEMA_NAME schema.

    One row per (store_id, sale_date) pair; the metric columns carry the same
    names as the aliases produced by aggregate_daily_sales.
    """
    __tablename__ = 'daily_store_sales'
    # Composite primary key plus the schema option; the options dict must be the
    # last element of the tuple.
    __table_args__ = (
        PrimaryKeyConstraint('store_id', 'sale_date'),
        {'schema': SCHEMA_NAME}
    )
    store_id = Column(Integer, ForeignKey(f'{SCHEMA_NAME}.stores.store_id'), nullable=False)
    sale_date = Column(Date, nullable=False)
    transaction_count = Column(Integer, nullable=True)
    total_sales = Column(Float, nullable=True)
    average_unit_price = Column(Float, nullable=True)
    total_quantity_sold = Column(Integer, nullable=True)
    total_discount_given = Column(Float, nullable=True)
    total_net_sales = Column(Float, nullable=True)


## Spark Session Management

In [None]:
from pyspark.sql import SparkSession

def create_spark_session():
    """Create (or reuse) the Spark session used by the ETL pipeline.

    On Windows the Hadoop home, PATH and PySpark interpreter variables are set
    first. The scratch directory comes from SPARK_TEMP_DIR (with a built-in
    default) and is created when missing. If SPARK_JARS_PACKAGES is set, its
    value is passed to ``spark.jars.packages`` so that extra JVM dependencies -
    such as the PostgreSQL JDBC driver required by ``extract_postgres`` - are
    available; when it is unset the session is built exactly as before.

    Returns:
        SparkSession: the active session.

    Raises:
        Exception: any error raised while preparing the environment or
            building the session is logged and re-raised.
    """
    try:
        if os.name == 'nt':
            hadoop_home = os.getenv('HADOOP_HOME_WIN', 'C:\\hadoop')
            os.environ['HADOOP_HOME'] = hadoop_home
            # Tolerate a missing PATH instead of raising KeyError
            os.environ['PATH'] = os.environ.get('PATH', '') + ';' + os.path.join(hadoop_home, 'bin')
            os.environ['PYSPARK_PYTHON'] = sys.executable
            os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

        temp_dir = os.getenv('SPARK_TEMP_DIR',  "/home/megin_mathew/spark_temp")
        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(temp_dir, exist_ok=True)

        builder = SparkSession.builder \
            .appName("FashionRetailETL") \
            .config("spark.local.dir", temp_dir) \
            .config("spark.executor.memory", "8g") \
            .config("spark.driver.memory", "8g") \
            .config("spark.sql.shuffle.partitions", "200")

        # Optional Maven coordinates for extra jars (e.g. the PostgreSQL driver)
        extra_packages = os.getenv('SPARK_JARS_PACKAGES')
        if extra_packages:
            builder = builder.config("spark.jars.packages", extra_packages)

        spark = builder.getOrCreate()

        logger.info("Spark session created successfully")
        return spark
    except Exception as e:
        logger.error(f"Failed to create Spark session: {str(e)}")
        raise

## Extraction Module

In [None]:
def extract_csv(spark, file_path):
    """Read a CSV file (with header row) into a Spark DataFrame.

    Paths containing 'transactions' or 'customers' (case-insensitive) are
    treated as large files: their schema is inferred with a 10% sampling ratio
    and the file is then read with that schema. Every other file is read with
    plain schema inference.

    Args:
        spark: active SparkSession.
        file_path: path of the CSV file to read.

    Returns:
        The DataFrame produced by ``spark.read.csv``.

    Raises:
        Exception: any read error is logged and re-raised.
    """
    try:
        logger.info(f"Extracting data from {file_path}")
        lowered_path = file_path.lower()
        is_large_file = any(tag in lowered_path for tag in ('transactions', 'customers'))

        if not is_large_file:
            return spark.read.csv(file_path, header=True, inferSchema=True)

        # Infer the schema with a sampling ratio, then read with that explicit schema
        sampled = spark.read.csv(file_path, header=True, inferSchema=True, samplingRatio=0.1)
        return spark.read.csv(file_path, header=True, schema=sampled.schema)
    except Exception as e:
        logger.error(f"Error extracting {file_path}: {str(e)}")
        raise

def extract_postgres(spark, table_name):
    """Read one table of the SCHEMA_NAME schema from PostgreSQL over JDBC.

    Connection settings come from the DB_HOST, DB_PORT (default '5432'),
    DB_NAME, DB_USER and DB_PASSWORD environment variables.

    Args:
        spark: active SparkSession.
        table_name: unqualified table name inside the schema.

    Returns:
        The DataFrame produced by ``spark.read.jdbc``.

    Raises:
        Exception: any error is logged and re-raised.
    """
    try:
        logger.info(f"Extracting {table_name} from PostgreSQL")

        db_host = os.getenv('DB_HOST')
        db_port = os.getenv('DB_PORT', '5432')
        db_name = os.getenv('DB_NAME')

        connection_props = dict(
            user=os.getenv('DB_USER'),
            password=os.getenv('DB_PASSWORD'),
            driver="org.postgresql.Driver",
            fetchsize="10000",
        )

        # Both identifiers are double-quoted so the mixed-case schema name is preserved
        qualified_table = f'"{SCHEMA_NAME}"."{table_name}"'

        return spark.read.jdbc(
            url=f"jdbc:postgresql://{db_host}:{db_port}/{db_name}",
            table=qualified_table,
            properties=connection_props,
        )
    except Exception as e:
        logger.error(f"Error extracting {table_name}: {str(e)}")
        raise


## Transformation Module

In [None]:
from pyspark.sql.functions import (
    col, when, lit, to_date, datediff, current_date,
    dayofweek, to_timestamp, count, sum, avg ,  floor, months_between
)
from pyspark.sql.types import FloatType

def transform_stores(df):
    """Transform the raw stores DataFrame into the `stores` table layout.

    Missing Country / City values are replaced with "Unknown", the source
    headers are renamed to snake_case, and the coordinates are cast to double
    precision. The previous 32-bit float cast discarded decimal digits of the
    latitude / longitude before they reached the database.

    Args:
        df: DataFrame with the stores.csv columns.

    Returns:
        DataFrame with columns store_id, country, city, store_name,
        number_of_employees, zip_code, latitude, longitude.
    """
    return df.withColumn("country", when(col("Country").isNull(), lit("Unknown")).otherwise(col("Country"))) \
            .withColumn("city", when(col("City").isNull(), lit("Unknown")).otherwise(col("City"))) \
            .select(
                col("Store ID").alias("store_id"),
                col("country"),
                col("city"),
                col("Store Name").alias("store_name"),
                col("Number of Employees").alias("number_of_employees"),
                col("ZIP Code").alias("zip_code"),
                # 64-bit cast keeps the full coordinate precision
                col("Latitude").cast("double").alias("latitude"),
                col("Longitude").cast("double").alias("longitude")
            )

def transform_employees(df):
    """Shape the raw employees DataFrame into the `employees` table layout.

    Missing Name / Position values become "Unknown"; the id columns are
    renamed to snake_case.

    Args:
        df: DataFrame with the employees.csv columns.

    Returns:
        DataFrame with columns employee_id, store_id, name, position.
    """
    name_filled = when(col("Name").isNull(), lit("Unknown")).otherwise(col("Name"))
    position_filled = when(col("Position").isNull(), lit("Unknown")).otherwise(col("Position"))

    output_columns = [
        col("Employee ID").alias("employee_id"),
        col("Store ID").alias("store_id"),
        name_filled.alias("name"),
        position_filled.alias("position"),
    ]
    return df.select(*output_columns)


def transform_customers(df):
    """Transform the raw customers DataFrame and derive age features.

    Steps:
      * back-fill the optional source columns (Join Date, Date Of Birth,
        Email, Name) with nulls when they are absent;
      * replace missing names / e-mails with placeholder values;
      * parse Join Date and Date Of Birth as ``yyyy-MM-dd`` dates, using
        1900-01-01 when the date of birth is null;
      * derive ``age`` (whole years) and ``days_since_birth`` from the date of
        birth. ``days_since_birth`` is the column declared on the Customers
        model and was previously never produced.

    Note: rows that fall back to the 1900-01-01 default get correspondingly
    large ``age`` / ``days_since_birth`` values.

    Args:
        df: DataFrame with the customers.csv columns.

    Returns:
        DataFrame with columns customer_id, name, email, join_date, telephone,
        city, country, gender, date_of_birth, job_title, age,
        days_since_birth.
    """
    # First check if required columns exist; add them as string-typed nulls so
    # the expressions below see a concrete column type.
    for col_name in ["Join Date", "Date Of Birth", "Email", "Name"]:
        if col_name not in df.columns:
            df = df.withColumn(col_name, lit(None).cast("string"))

    return (
        df
        # Handle NULL names → "Unknown"
        .withColumn(
            "name",
            when(col("Name").isNull(), lit("Unknown"))
            .otherwise(col("Name"))
        )
        # Handle NULL emails → "unknown@example.com"
        .withColumn(
            "email",
            when(col("Email").isNull(), lit("unknown@example.com"))
            .otherwise(col("Email"))
        )
        # Convert Join Date to date format
        .withColumn(
            "join_date",
            to_date(col("Join Date"), "yyyy-MM-dd")
        )
        # Convert Date Of Birth to date (default to 1900-01-01 if NULL)
        .withColumn(
            "date_of_birth",
            when(
                col("Date Of Birth").isNotNull(),
                to_date(col("Date Of Birth"), "yyyy-MM-dd")
            )
            .otherwise(to_date(lit("1900-01-01"), "yyyy-MM-dd"))
        )
        # Calculate age in years
        .withColumn(
            "age",
            floor(months_between(current_date(), col("date_of_birth")) / 12).cast("integer")
        )
        # Days elapsed since the date of birth (column declared on the model)
        .withColumn(
            "days_since_birth",
            datediff(current_date(), col("date_of_birth"))
        )
        # Select and rename columns
        .select(
            col("Customer ID").alias("customer_id"),
            col("name"),
            col("email"),
            col("join_date"),
            col("Telephone").alias("telephone"),
            col("City").alias("city"),
            col("Country").alias("country"),
            col("Gender").alias("gender"),
            col("date_of_birth"),
            col("Job Title").alias("job_title"),
            col("age"),
            col("days_since_birth")
        )
    )
            
def transform_products(df):
    """Shape the raw products DataFrame into the `products` table layout.

    Missing Category / Color values become "Unknown"; all other columns are
    renamed to snake_case, including one description column per language.

    Args:
        df: DataFrame with the products.csv columns.

    Returns:
        DataFrame with columns product_id, category, sub_category,
        description_pt/de/fr/es/en/zh, color, sizes, production_cost.
    """
    category_filled = when(col("Category").isNull(), lit("Unknown")).otherwise(col("Category"))
    color_filled = when(col("Color").isNull(), lit("Unknown")).otherwise(col("Color"))

    # One "Description XX" source column per language code, kept in source order
    description_columns = [
        col(f"Description {code}").alias(f"description_{code.lower()}")
        for code in ("PT", "DE", "FR", "ES", "EN", "ZH")
    ]

    return df.select(
        col("Product ID").alias("product_id"),
        category_filled.alias("category"),
        col("Sub Category").alias("sub_category"),
        *description_columns,
        color_filled.alias("color"),
        col("Sizes").alias("sizes"),
        col("Production Cost").alias("production_cost"),
    )

def transform_discounts(df):
    """Shape the raw discounts DataFrame into the `discounts` table layout.

    Start / End are parsed as ``yyyy-MM-dd`` dates and a missing Description
    becomes "No description".

    Args:
        df: DataFrame with the discounts.csv columns.

    Returns:
        DataFrame with columns start_date, end_date, discount_rate,
        description, category, sub_category.
    """
    date_format = "yyyy-MM-dd"
    description_filled = (
        when(col("Description").isNull(), lit("No description"))
        .otherwise(col("Description"))
    )

    output_columns = [
        to_date(col("Start"), date_format).alias("start_date"),
        to_date(col("End"), date_format).alias("end_date"),
        # "Discont" is presumably the header exactly as spelled in the CSV - verify
        col("Discont").alias("discount_rate"),
        description_filled.alias("description"),
        col("Category").alias("category"),
        col("Sub Category").alias("sub_category"),
    ]
    return df.select(*output_columns)

def transform_transactions(df):
    """Shape the raw transactions DataFrame and add derived columns.

    Defaults: a missing Invoice ID becomes "UNKNOWN", a missing Quantity 1 and
    a missing Discount 0.0. Derived columns: ``is_return`` (Transaction Type
    equals "Return"), ``transaction_date`` (Date parsed as
    ``yyyy-MM-dd HH:mm:ss``), ``net_line_total`` and
    ``transaction_day_of_week``.

    Args:
        df: DataFrame with the transactions.csv columns.

    Returns:
        DataFrame with the snake_case columns of the `transactions` table
        (without the surrogate ``id``).
    """
    invoice_filled = when(col("Invoice ID").isNull(), lit("UNKNOWN")).otherwise(col("Invoice ID"))
    quantity_filled = when(col("Quantity").isNull(), lit(1)).otherwise(col("Quantity"))
    discount_filled = when(col("Discount").isNull(), lit(0.0)).otherwise(col("Discount"))
    parsed_timestamp = to_timestamp(col("Date"), "yyyy-MM-dd HH:mm:ss")

    # NOTE(review): the discount is subtracted as an absolute amount here; if the
    # source column is a rate this formula needs revisiting - confirm.
    net_total = (col("Unit Price") * quantity_filled) - discount_filled

    return df.select(
        invoice_filled.alias("invoice_id"),
        col("Line").alias("line_number"),
        col("Customer ID").alias("customer_id"),
        col("Product ID").alias("product_id"),
        col("Size").alias("size"),
        col("Color").alias("color"),
        col("Unit Price").alias("unit_price"),
        quantity_filled.alias("quantity"),
        parsed_timestamp.alias("transaction_date"),
        discount_filled.alias("discount"),
        col("Line Total").alias("line_total"),
        col("Store ID").alias("store_id"),
        col("Employee ID").alias("employee_id"),
        col("Currency").alias("currency"),
        col("Currency Symbol").alias("currency_symbol"),
        col("SKU").alias("sku"),
        col("Transaction Type").alias("transaction_type"),
        col("Payment Method").alias("payment_method"),
        col("Invoice Total").alias("invoice_total"),
        (col("Transaction Type") == "Return").alias("is_return"),
        net_total.alias("net_line_total"),
        dayofweek(parsed_timestamp).alias("transaction_day_of_week"),
    )

def aggregate_daily_sales(df):
    """Roll transaction rows up to one row per store and calendar day.

    Args:
        df: DataFrame with at least store_id, transaction_date, line_total,
            unit_price, quantity, discount and net_line_total columns.

    Returns:
        DataFrame keyed by (store_id, sale_date) with the metrics
        transaction_count, total_sales, average_unit_price,
        total_quantity_sold, total_discount_given and total_net_sales.
    """
    # `count`, `sum` and `avg` are the pyspark.sql.functions aggregates here
    metrics = [
        count("*").alias("transaction_count"),
        sum("line_total").alias("total_sales"),
        avg("unit_price").alias("average_unit_price"),
        sum("quantity").alias("total_quantity_sold"),
        sum("discount").alias("total_discount_given"),
        sum("net_line_total").alias("total_net_sales"),
    ]

    with_sale_date = df.withColumn("sale_date", to_date(col("transaction_date")))
    per_store_day = with_sale_date.groupBy("store_id", "sale_date")
    return per_store_day.agg(*metrics)

## Database Module (Load)

In [None]:
def setup_database(engine):
    """Create the schema, the tables and any missing columns.

    Everything runs on a single connection inside one transaction:

    1. Create the SCHEMA_NAME schema (and grant privileges to DB_USER) when it
       does not exist yet. The schema identifier is double-quoted so the
       mixed-case name is created exactly as the ORM models and
       ``extract_postgres`` reference it.
    2. Create all tables registered on ``Base.metadata``. This is issued on the
       same connection, so it sees a schema created in step 1 before commit.
    3. For every model table, add columns that exist on the model but not yet
       in the database.

    Args:
        engine: SQLAlchemy engine connected to the target database.

    Returns:
        bool: True on success, False when any step failed (the error is
        logged with its traceback).
    """
    try:
        quoted_schema = f'"{SCHEMA_NAME}"'

        with engine.begin() as conn:
            # Create schema if not exists
            schema_exists = conn.execute(
                text("SELECT 1 FROM information_schema.schemata WHERE schema_name = :schema"),
                {'schema': SCHEMA_NAME}
            ).scalar()

            if not schema_exists:
                logger.info(f"Creating schema {SCHEMA_NAME}")
                conn.execute(text(f"CREATE SCHEMA {quoted_schema}"))
                grant_user = os.getenv('DB_USER', 'postgres')
                # One statement per execute call
                conn.execute(text(
                    f'GRANT ALL ON SCHEMA {quoted_schema} TO "{grant_user}"'
                ))
                conn.execute(text(
                    f'ALTER DEFAULT PRIVILEGES IN SCHEMA {quoted_schema} '
                    f'GRANT ALL ON TABLES TO "{grant_user}"'
                ))

            # Create tables and add missing columns, on the same connection /
            # transaction as the schema creation above.
            Base.metadata.create_all(conn)
            inspector = inspect(conn)

            # Iterate Table objects: the metadata dict keys are schema-qualified
            # ("<schema>.<table>") and cannot be used as plain table names.
            for table in Base.metadata.tables.values():
                table_name = table.name
                if not inspector.has_table(table_name, schema=SCHEMA_NAME):
                    continue

                existing_columns = {
                    c['name'] for c in inspector.get_columns(table_name, schema=SCHEMA_NAME)
                }

                for column in table.columns:
                    if column.name not in existing_columns:
                        logger.info(f"Adding column {column.name} to {table_name}")
                        column_type = column.type.compile(engine.dialect)
                        conn.execute(text(
                            f'ALTER TABLE "{SCHEMA_NAME}"."{table_name}" '
                            f'ADD COLUMN "{column.name}" {column_type}'
                        ))

        logger.info("Database setup completed")
        return True
    except Exception as e:
        logger.error(f"Database setup failed: {str(e)}", exc_info=True)
        return False

def create_db_engine():
    """Create the SQLAlchemy engine for the PostgreSQL target database.

    Connection settings are read from DB_USER, DB_PASSWORD, DB_HOST, DB_PORT
    (default 5432) and DB_NAME. The URL is assembled with ``URL.create`` so
    that credentials containing characters such as ``@``, ``:`` or ``/`` are
    escaped correctly instead of corrupting a hand-built URL string.

    Returns:
        Engine: pooled engine (pool_size=20, max_overflow=30) with pre-ping,
        hourly connection recycling and a 10 second connect timeout.

    Raises:
        ValueError: if DB_PORT is set to a non-numeric value.
    """
    from sqlalchemy.engine import URL

    connection_url = URL.create(
        drivername="postgresql+psycopg2",
        username=os.getenv("DB_USER"),
        password=os.getenv("DB_PASSWORD"),
        host=os.getenv("DB_HOST"),
        # An empty DB_PORT falls back to the default port
        port=int(os.getenv("DB_PORT") or "5432"),
        database=os.getenv("DB_NAME"),
    )
    return create_engine(
        connection_url,
        pool_size=20,
        max_overflow=30,
        pool_pre_ping=True,
        pool_recycle=3600,
        connect_args={'connect_timeout': 10}
    )

def load_data_to_postgres(spark_df, table_name, engine):
    """Load a Spark DataFrame into a PostgreSQL table in batches.

    The ORM model is looked up by table name in the declarative registry, so
    multi-word tables such as ``daily_store_sales`` resolve correctly. Rows are
    streamed from Spark with ``toLocalIterator`` and converted with
    ``Row.asDict``, which keeps SQL NULLs as ``None`` and avoids materialising
    the whole table on the driver. Columns of the DataFrame that the model does
    not declare are dropped (with a warning) before loading.

    Each batch is first bulk-inserted. When that fails, the batch is rolled
    back and its rows are processed one by one: a row whose primary key is
    present and already stored is updated, any other row is inserted. Rows
    that still fail are logged and skipped.

    Args:
        spark_df: Spark DataFrame whose column names match the model columns.
        table_name: unqualified target table name (e.g. ``'stores'``).
        engine: SQLAlchemy engine connected to the target database.

    Returns:
        None. Returns early (after logging an error) when no model is mapped
        to ``table_name``.

    Raises:
        Exception: errors outside the per-batch / per-row handling are logged
            and re-raised.
    """
    from itertools import islice

    try:
        # Resolve the model from the registry instead of guessing the class name
        table_class = next(
            (
                mapper.class_
                for mapper in Base.registry.mappers
                if mapper.local_table.name == table_name
            ),
            None,
        )
        if not table_class:
            logger.error(f"No model class found for {table_name}")
            return

        model_table = table_class.__table__
        primary_keys = [pk_col.name for pk_col in model_table.primary_key.columns]
        if not primary_keys:
            logger.error(f"No primary key defined for {table_name}")
            return

        # Only columns declared on the model can be persisted
        model_columns = {model_col.name for model_col in model_table.columns}
        load_columns = [name for name in spark_df.columns if name in model_columns]
        ignored_columns = [name for name in spark_df.columns if name not in model_columns]
        if ignored_columns:
            logger.warning(
                f"Ignoring columns not defined on {table_name}: {ignored_columns}"
            )

        logger.info(f"Starting to load data to {table_name}...")

        # Stream rows partition by partition instead of collecting everything
        batch_size = 1000
        row_iterator = spark_df.select(*load_columns).toLocalIterator()

        Session = sessionmaker(bind=engine)
        session = Session()

        loaded_rows = 0
        failed_rows = 0
        batch_number = 0

        try:
            while True:
                records = [row.asDict() for row in islice(row_iterator, batch_size)]
                if not records:
                    break
                batch_number += 1

                try:
                    # First try bulk insert for performance
                    session.bulk_insert_mappings(table_class, records)
                    session.commit()
                    loaded_rows += len(records)
                    logger.info(
                        f"Processed batch {batch_number} ({loaded_rows} rows loaded so far)"
                    )
                except Exception as bulk_error:
                    session.rollback()
                    logger.warning(f"Bulk insert failed, switching to individual upserts: {str(bulk_error)}")

                    for record in records:
                        try:
                            # Look up an existing row only when the record carries
                            # a complete primary key; DB-generated keys are absent.
                            pk_values = {pk: record.get(pk) for pk in primary_keys}
                            existing = None
                            if all(value is not None for value in pk_values.values()):
                                existing = session.query(table_class).filter_by(**pk_values).first()

                            if existing:
                                for key, value in record.items():
                                    setattr(existing, key, value)
                            else:
                                session.add(table_class(**record))
                            session.commit()
                            loaded_rows += 1
                        except Exception as rec_error:
                            session.rollback()
                            failed_rows += 1
                            logger.error(f"Error processing record: {str(rec_error)}")
                            continue

            # Verify count for non-large tables
            if table_name not in ['transactions', 'daily_store_sales']:
                stored_rows = session.query(table_class).count()
                logger.info(f"Verified {stored_rows} records in {table_name}")

        finally:
            session.close()

        if failed_rows:
            logger.warning(f"{failed_rows} records could not be loaded to {table_name}")
        logger.info(f"Successfully loaded data to {table_name}")
    except Exception as e:
        logger.error(f"Error loading to {table_name}: {str(e)}", exc_info=True)
        raise

## Main Pipeline

In [None]:
def run_pipeline():
    """Run the full ETL pipeline: CSV files → transforms → PostgreSQL.

    Steps:
      1. Validate that DATASET_PATH is configured, then create the Spark
         session and database engine and set up the schema/tables.
      2. For each dataset file (in dependency order: stores, employees,
         products, discounts, customers, transactions) extract, transform and
         load it. Files that do not exist are skipped with a warning. For
         products.csv a pandas-based read is attempted as a fallback when the
         Spark path fails.
      3. Read the loaded transactions back from PostgreSQL, aggregate them per
         store and day, and load the result into ``daily_store_sales``.

    The engine and Spark session are always released in the ``finally`` block.

    Raises:
        RuntimeError: if DATASET_PATH is not set or database setup fails.
        Exception: any other failure is logged with its traceback and
            re-raised.
    """
    spark = None
    engine = None

    try:
        # Fail fast with a clear message instead of a TypeError from
        # os.path.join(None, ...) after Spark has already been started.
        dataset_dir = os.getenv('DATASET_PATH')
        if not dataset_dir:
            raise RuntimeError("DATASET_PATH environment variable is not set")

        # Initialize
        spark = create_spark_session()
        engine = create_db_engine()

        if not setup_database(engine):
            raise RuntimeError("Database setup failed")

        # Process each data file (parents before the tables that reference them)
        data_files = [
            ('stores.csv', transform_stores, 'stores'),
            ('employees.csv', transform_employees, 'employees'),
            ('products.csv', transform_products, 'products'),
            ('discounts.csv', transform_discounts, 'discounts'),
            ('customers.csv', transform_customers, 'customers'),
            ('transactions.csv', transform_transactions, 'transactions')
        ]

        for file_name, transform_func, table_name in data_files:
            file_path = os.path.join(dataset_dir, file_name)
            if not os.path.exists(file_path):
                logger.warning(f"File not found: {file_path}")
                continue

            logger.info(f"Processing {file_name}...")
            start_time = datetime.now()

            try:
                df = extract_csv(spark, file_path)
                transformed_df = transform_func(df)
                load_data_to_postgres(transformed_df, table_name, engine)
                logger.info(f"Completed {file_name} in {datetime.now() - start_time}")
            except Exception as e:
                logger.error(f"Error processing {file_name}: {str(e)}")
                if file_name == 'products.csv':
                    # Fallback: parse the file with pandas and hand it to Spark
                    logger.info("Trying alternative approach for products")
                    try:
                        products_pd = pd.read_csv(file_path)
                        products_spark = spark.createDataFrame(products_pd)
                        transformed_df = transform_func(products_spark)
                        load_data_to_postgres(transformed_df, table_name, engine)
                    except Exception as alt_e:
                        logger.error(f"Alternative approach failed: {str(alt_e)}")
                        raise
                else:
                    raise

        # Generate and load daily sales aggregations
        logger.info("Generating daily sales aggregations...")
        start_time = datetime.now()

        transactions_df = extract_postgres(spark, 'transactions')
        daily_sales_df = aggregate_daily_sales(transactions_df)
        load_data_to_postgres(daily_sales_df, 'daily_store_sales', engine)

        logger.info(f"Daily sales completed in {datetime.now() - start_time}")
        logger.info("Pipeline completed successfully!")

    except Exception as e:
        logger.error(f"Pipeline failed: {str(e)}", exc_info=True)
        raise
    finally:
        # Release database connections and stop Spark even on failure
        if engine:
            engine.dispose()
        if spark:
            spark.stop()

## Execute Pipeline

In [None]:
if __name__ == "__main__":
    # Entry point: run the pipeline, log the elapsed time, and exit with 0 on
    # success or 1 when the pipeline raised.
    logger.info("Starting pipeline execution...")
    start_time = datetime.now()
    exit_code = 1

    try:
        run_pipeline()
        logger.info(f"Pipeline completed in {datetime.now() - start_time}")
        exit_code = 0
    except Exception as e:
        logger.error(f"Pipeline failed: {str(e)}", exc_info=True)

    sys.exit(exit_code)