In [1]:

# ## Overview
# This notebook implements a dimensional Data Lakehouse based on the AdventureWorks dataset to analyze sales transactions.
# 
# ## Architecture
# - Data Sources: MySQL (AdventureWorks), CSV files, JSON streaming files, MongoDB Atlas
# - Integration Pattern: ELT (Extract, Load, Transform)
# - Lakehouse Architecture: Databricks Bronze, Silver, Gold layers
# 
# ## Dimensional Model
# - Fact Table: Sales (from SalesOrderDetail)
# - Dimension Tables:
#   - Date
#   - Customer
#   - Product
#   - Employee (Sales Person)
#   - Territory

# Install necessary packages
%pip install pymysql pymongo pandas numpy

import pymysql
import pandas as pd
import numpy as np
import json
import os
from datetime import datetime, timedelta
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, date_format, year, month, dayofmonth, weekofyear, dayofweek, from_json, explode
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, DateType, TimestampType
import mysql.connector
import pandas as pd
import pymongo
from sqlalchemy import create_engine, text
import certifi
import json
import os
from sqlalchemy.exc import OperationalError
import numpy as np
import findspark
findspark.init()
print(findspark.find())

import os
import sys
import json
import time
import pymongo
import certifi
import shutil
import time
import pandas as pd

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window as W

Note: you may need to restart the kernel to use updated packages.
C:\spark-3.5.4-bin-hadoop3


In [2]:
# --------------------------------------------------------------------------------
# Directory layout: source data (batch + streaming) and the lakehouse layers
# --------------------------------------------------------------------------------
base_dir = os.path.join(os.getcwd(), 'lab_data')
data_dir = os.path.join(base_dir, 'adventureworks')
batch_dir, stream_dir = (
    os.path.join(data_dir, leaf) for leaf in ('batch', 'streaming')
)

# Lakehouse root plus one sub-directory per medallion layer.
lakehouse_dir = os.path.join(base_dir, 'adventureworks_lakehouse')
os.makedirs(lakehouse_dir, exist_ok=True)

bronze_dir, silver_dir, gold_dir = (
    os.path.join(lakehouse_dir, layer) for layer in ('bronze', 'silver', 'gold')
)

# Only the lakehouse directories are created here; the source directories are
# expected to exist already.
for dir_path in (bronze_dir, silver_dir, gold_dir):
    os.makedirs(dir_path, exist_ok=True)

In [3]:

# MongoDB connection settings.
# Credentials are read from the environment so that no secret is stored in the
# notebook: set MONGODB_USER and MONGODB_PASSWORD before starting the kernel.
# Any password that was previously saved in this notebook should be rotated.
# The remaining settings are not secret and keep their defaults unless overridden.
mongodb_args = {
    "user_name": os.environ.get("MONGODB_USER", ""),
    "password": os.environ.get("MONGODB_PASSWORD", ""),
    "cluster_name": os.environ.get("MONGODB_CLUSTER_NAME", "Cluster0"),
    "cluster_subnet": os.environ.get("MONGODB_CLUSTER_SUBNET", "38mdy"),
    "cluster_location": os.environ.get("MONGODB_CLUSTER_LOCATION", "atlas"),  # or "local"
    "db_name": os.environ.get("MONGODB_DB_NAME", "adventureworks"),
}

def get_mongo_client(**args):
    """Create a MongoDB client for either an Atlas cluster or a local server.

    Args:
        **args: Connection settings. ``cluster_location`` is always required.
            When it equals ``"atlas"``, ``user_name``, ``password``,
            ``cluster_name`` and ``cluster_subnet`` are required as well.

    Returns:
        pymongo.MongoClient: Atlas client when ``cluster_location == "atlas"``,
        otherwise a client for ``mongodb://localhost:27017/``.

    Raises:
        KeyError: If a required setting is missing.
    """
    from urllib.parse import quote_plus

    if args["cluster_location"] == "atlas":
        # Reserved characters (e.g. '@', ':', '/') in the credentials would
        # corrupt the URI, so both parts are percent-escaped.
        user = quote_plus(str(args['user_name']))
        password = quote_plus(str(args['password']))
        connect_str = (
            f"mongodb+srv://{user}:{password}"
            f"@{args['cluster_name']}.{args['cluster_subnet']}.mongodb.net"
        )
        return pymongo.MongoClient(connect_str, tlsCAFile=certifi.where())

    return pymongo.MongoClient("mongodb://localhost:27017/")

# Function to fetch MongoDB data
def get_mongo_dataframe(mongo_client, db_name, collection, query=None):
    """Fetch documents from a MongoDB collection into a pandas DataFrame.

    Args:
        mongo_client: Client object indexable by database name.
        db_name: Name of the database to read from.
        collection: Name of the collection to read from.
        query: Filter passed to ``find``; ``None`` (the default) selects
            every document.

    Returns:
        pandas.DataFrame: One row per matching document, without the ``_id``
        column. Empty when nothing matches.
    """
    # A fresh dict per call avoids sharing one mutable default across calls.
    selector = {} if query is None else query
    documents = list(mongo_client[db_name][collection].find(selector))

    dframe = pd.DataFrame(documents)
    if '_id' in dframe.columns:
        dframe = dframe.drop(columns=['_id'])
    return dframe

In [4]:
def extract_from_mysql(query):
    """Run a SQL query against the AdventureWorks MySQL database.

    Connection settings come from the environment variables ``MYSQL_HOST``,
    ``MYSQL_USER``, ``MYSQL_PASSWORD`` and ``MYSQL_DATABASE``. Each falls back
    to the local development value when unset; set ``MYSQL_PASSWORD``
    explicitly for any non-local server.

    Args:
        query: SQL text to execute.

    Returns:
        pandas.DataFrame: The query result.
    """
    connection = pymysql.connect(
        host=os.environ.get('MYSQL_HOST', 'localhost'),
        user=os.environ.get('MYSQL_USER', 'root'),
        password=os.environ.get('MYSQL_PASSWORD', 'password'),
        database=os.environ.get('MYSQL_DATABASE', 'adventureworks'),
    )

    try:
        return pd.read_sql(query, connection)
    finally:
        # Always release the connection, including when the query fails.
        connection.close()
        
def get_file_info(path: str):
    """List the regular files directly inside ``path`` with size and mtime.

    Sub-directories are skipped and the result is ordered by file name.

    Returns:
        pandas.DataFrame: Columns ``name``, ``size`` (bytes) and
        ``modification_time`` (timestamp built from the file's mtime).
    """
    rows = []
    for entry in sorted(os.listdir(path)):
        full_path = os.path.join(path, entry)
        # Keep plain files only; directories are not reported.
        if not os.path.isfile(full_path):
            continue
        rows.append((
            entry,
            os.path.getsize(full_path),
            pd.to_datetime(os.path.getmtime(full_path), unit='s'),
        ))

    return pd.DataFrame(data=rows, columns=['name', 'size', 'modification_time'])

In [5]:
# ------------------------------------------------------------------------------
# Step 4: Initialize Spark Session
# ------------------------------------------------------------------------------
def get_spark_conf_args(spark_jars=None, **args):
    """Generate the argument dictionary used to configure Spark.

    Args:
        spark_jars: Optional list of jar paths; joined into one string.
        **args: Either ``mongo_uri`` (used as-is) or the pieces needed to
            build an Atlas URI: ``user_name``, ``password``, ``cluster_name``,
            ``cluster_subnet`` and ``db_name``.

    Returns:
        dict: Keys ``app_name``, ``worker_threads``, ``shuffle_partitions``,
        ``mongo_uri``, ``spark_jars`` and ``database_dir``.

    Raises:
        KeyError: If ``mongo_uri`` is absent and an Atlas URI piece is missing.
    """
    import os
    from multiprocessing import cpu_count as get_cpu_count
    from urllib.parse import quote_plus

    jars = ", ".join(spark_jars) if spark_jars else ""

    # Never configure fewer than two workers / shuffle partitions.
    worker_count = max(2, get_cpu_count() or 2)

    if "mongo_uri" in args:
        mongo_uri = args["mongo_uri"]
    else:
        # Built only when needed, so a caller passing `mongo_uri` does not
        # also have to supply the individual Atlas settings.
        mongo_uri = (
            f"mongodb+srv://{quote_plus(str(args['user_name']))}"
            f":{quote_plus(str(args['password']))}"
            f"@{args['cluster_name']}.{args['cluster_subnet']}.mongodb.net/{args['db_name']}"
        )

    return {
        "app_name": "PySpark AdventureWorks Data Lakehouse",
        # Master URL in the form expected by SparkConf.setMaster.
        "worker_threads": f"local[{worker_count}]",
        "shuffle_partitions": str(worker_count),
        "mongo_uri": mongo_uri,
        "spark_jars": jars,
        "database_dir": os.path.abspath('spark-warehouse'),
    }


def get_spark_conf(**args):
    """Build a SparkConf from the given configuration arguments.

    Args:
        **args: Must contain ``app_name``, ``worker_threads``,
            ``database_dir``, ``shuffle_partitions`` and ``spark_jars``.
            When ``mongo_uri`` is present, the MongoDB connector package and
            its input/output URIs are configured too.

    Returns:
        SparkConf: The populated configuration object.
    """
    sparkConf = SparkConf().setAppName(args['app_name']).setMaster(args['worker_threads'])

    # Core settings, applied in a fixed order.
    core_settings = (
        ('spark.driver.memory', '4g'),
        ('spark.executor.memory', '4g'),
        ('spark.sql.warehouse.dir', args['database_dir']),
        ('spark.sql.adaptive.enabled', 'false'),
        ('spark.sql.debug.maxToStringFields', '35'),
        ('spark.sql.shuffle.partitions', str(args['shuffle_partitions'])),
        ('spark.sql.streaming.forceDeleteTempCheckpointLocation', 'true'),
        ('spark.sql.streaming.schemaInference', 'true'),
        ('spark.streaming.stopGracefullyOnShutdown', 'true'),
    )
    for setting_name, setting_value in core_settings:
        sparkConf = sparkConf.set(setting_name, setting_value)

    # Extra jars are registered only when a non-empty value was supplied.
    if args['spark_jars']:
        sparkConf = sparkConf.set('spark.jars', args['spark_jars'])

    if 'mongo_uri' in args:
        mongo_settings = (
            ('spark.jars.packages', 'org.mongodb.spark:mongo-spark-connector_2.12:3.0.1'),
            ('spark.mongodb.input.uri', args['mongo_uri']),
            ('spark.mongodb.output.uri', args['mongo_uri']),
        )
        for setting_name, setting_value in mongo_settings:
            sparkConf = sparkConf.set(setting_name, setting_value)

    return sparkConf

In [6]:
# ------------------------------------------------------------------------------
# Step 4: Upload Employee Data to MongoDB
# ------------------------------------------------------------------------------
def upload_employee_data_to_mongodb():
    """Extract employee data from MySQL and upload it to MongoDB.

    The employee rows are joined with their department history and, when
    possible, with department and shift names. If that query fails, a second
    query without the name columns is used instead. The target collection
    ``employees_from_mysql`` is dropped and repopulated on every call.

    Returns:
        int: Number of employee records written to MongoDB.
    """
    print("Extracting employee data from MySQL and uploading to MongoDB...")
    employee_query = """
    SELECT 
        e.EmployeeID,
        e.NationalIDNumber,
        e.LoginID,
        e.ManagerID,
        e.Title,
        e.BirthDate,
        e.MaritalStatus,
        e.Gender,
        e.HireDate,
        e.SalariedFlag,
        e.VacationHours,
        e.SickLeaveHours,
        edh.DepartmentID,
        d.Name as DepartmentName,
        edh.ShiftID,
        s.Name as ShiftName,
        edh.StartDate,
        edh.EndDate
    FROM employee e
    JOIN employeedepartmenthistory edh ON e.EmployeeID = edh.EmployeeID
    JOIN department d ON edh.DepartmentID = d.DepartmentID
    JOIN shift s ON edh.ShiftID = s.ShiftID
    """
    try:
        employee_df = extract_from_mysql(employee_query)
        print(f"Extracted {len(employee_df)} employee records with department and shift")
    except Exception as e:
        # Best-effort fallback: retry without the department/shift lookups.
        print(f"Error extracting employee data with department and shift: {str(e)}")
        employee_query = """
        SELECT 
            e.EmployeeID,
            e.NationalIDNumber,
            e.LoginID,
            e.ManagerID,
            e.Title,
            e.BirthDate,
            e.MaritalStatus,
            e.Gender,
            e.HireDate,
            e.SalariedFlag,
            e.VacationHours,
            e.SickLeaveHours,
            edh.DepartmentID,
            edh.ShiftID,
            edh.StartDate,
            edh.EndDate
        FROM employee e
        JOIN employeedepartmenthistory edh ON e.EmployeeID = edh.EmployeeID
        """
        employee_df = extract_from_mysql(employee_query)
        print(f"Extracted {len(employee_df)} employee records (without department/shift names)")

    # Dates are stored as 'YYYY-MM-DD' strings.
    date_columns = employee_df.select_dtypes(include=['datetime64[ns]']).columns
    for column in date_columns:
        employee_df[column] = employee_df[column].dt.strftime('%Y-%m-%d')

    # Cast to object first: in a numeric column pandas would coerce the None
    # replacement back to NaN, which would then be stored in MongoDB as NaN.
    employee_df = employee_df.astype(object).where(pd.notnull(employee_df), None)
    records = employee_df.to_dict('records')

    collection_name = 'employees_from_mysql'
    mongo_client = get_mongo_client(**mongodb_args)
    try:
        collection = mongo_client[mongodb_args['db_name']][collection_name]
        # The records are fully prepared before the existing data is removed.
        collection.drop()

        # insert_many rejects an empty list, so skip it when nothing was extracted.
        if records:
            collection.insert_many(records)
        print(f"Stored {len(records)} employee records in MongoDB collection: {collection_name}")
    finally:
        mongo_client.close()

    return len(records)

In [7]:

# ------------------------------------------------------------------------------
# Step 5: Extract Data from MySQL
# ------------------------------------------------------------------------------
def extract_and_save_dimension_data():
    """Extract the dimension data sets and save each as a CSV in ``batch_dir``.

    Customer, product and date data are read from MySQL; employee data is
    read from the ``employees_from_mysql`` MongoDB collection. The files
    written are ``dim_customer.csv``, ``dim_product.csv``,
    ``dim_employee.csv`` and ``dim_date.csv``.
    """
    print("Extracting dimension data from MySQL and MongoDB...")

    # The CSV writes below fail if the target directory is missing.
    os.makedirs(batch_dir, exist_ok=True)

    # Extract Customer data
    customer_query = """
    SELECT 
        c.CustomerID,
        c.AccountNumber,
        c.CustomerType,
        a.AddressLine1,
        a.AddressLine2,
        a.City,
        sp.Name AS StateProvinceName,
        a.PostalCode,
        cr.Name AS CountryRegionName
    FROM customer c
    LEFT JOIN customeraddress ca ON c.CustomerID = ca.CustomerID
    LEFT JOIN address a ON ca.AddressID = a.AddressID
    LEFT JOIN stateprovince sp ON a.StateProvinceID = sp.StateProvinceID
    LEFT JOIN countryregion cr ON sp.CountryRegionCode = cr.CountryRegionCode
    """
    customer_df = extract_from_mysql(customer_query)
    print(f"Extracted {len(customer_df)} customer records")
    customer_df.to_csv(os.path.join(batch_dir, 'dim_customer.csv'), index=False)

    # Extract Product data (finished goods only)
    product_query = """
    SELECT
        p.ProductID,
        p.Name AS ProductName,
        p.ProductNumber,
        p.Color,
        p.StandardCost,
        p.ListPrice,
        p.Size,
        p.Weight,
        p.ProductModelID,
        pm.Name AS ProductModelName,
        pc.Name AS ProductCategoryName,
        psc.Name AS ProductSubcategoryName
    FROM Product p
    LEFT JOIN ProductModel pm ON p.ProductModelID = pm.ProductModelID
    LEFT JOIN ProductSubcategory psc ON p.ProductSubcategoryID = psc.ProductSubcategoryID
    LEFT JOIN ProductCategory pc ON psc.ProductCategoryID = pc.ProductCategoryID
    WHERE p.FinishedGoodsFlag = 1
    """
    product_df = extract_from_mysql(product_query)
    print(f"Extracted {len(product_df)} product records")
    product_df.to_csv(os.path.join(batch_dir, 'dim_product.csv'), index=False)

    # Extract Employee data from MongoDB; the shared helper drops '_id'.
    mongo_client = get_mongo_client(**mongodb_args)
    try:
        employee_df = get_mongo_dataframe(
            mongo_client, mongodb_args['db_name'], 'employees_from_mysql'
        )
    finally:
        mongo_client.close()

    print(f"Extracted {len(employee_df)} employee records from MongoDB")
    employee_df.to_csv(os.path.join(batch_dir, 'dim_employee.csv'), index=False)

    # Extract Date data
    dim_date_query = """
    SELECT 
        date_key,
        full_date,
        day_name_of_week,
        month_name,
        calendar_year,
        calendar_quarter
    FROM dim_date
    """
    dim_date_df = extract_from_mysql(dim_date_query)
    print(f"Extracted {len(dim_date_df)} date dimension records")
    dim_date_df.to_csv(os.path.join(batch_dir, 'dim_date.csv'), index=False)

In [8]:
def _load_dimension_from_csv(spark, dest_database, csv_path, label, table_name,
                             view_name=None, order_column=None, key_column=None):
    """Load one dimension CSV into a table, optionally adding a surrogate key."""
    print(f"Loading {label} data from: {csv_path}")

    dim_df = spark.read.format('csv').options(header='true', inferSchema='true').load(csv_path)
    print(f"Loaded {dim_df.count()} {label} records")

    if view_name is not None:
        # The surrogate key is a row number ordered by the source identifier.
        dim_df.createOrReplaceTempView(view_name)
        dim_df = spark.sql(f"""
            SELECT *, ROW_NUMBER() OVER (ORDER BY {order_column}) AS {key_column}
            FROM {view_name}
        """)

    dim_df.write.saveAsTable(f"{dest_database}.{table_name}", mode="overwrite")
    print(f"Saved {label} data to {dest_database}.{table_name}")


def _show_dimension_sample(spark, dest_database, label, table_name):
    """Print two sample rows of a dimension table, reporting any error instead of raising."""
    try:
        print(f"\n{label.capitalize()} dimension:")
        spark.sql(f"SELECT * FROM {dest_database}.{table_name} LIMIT 2").show()
    except Exception as e:
        print(f"Error displaying {label} dimension: {str(e)}")


def setup_lakehouse():
    """Set up the Spark session and database for the data lakehouse.

    Creates (or reuses) a local Spark session, recreates the
    ``adventure_works_dw`` database and loads the dimension CSV files from
    ``batch_dir`` into tables. The territory dimension is loaded only when
    ``dim_territory.csv`` exists.

    Returns:
        tuple: ``(spark, dest_database)`` - the Spark session and the name of
        the database that was populated.
    """
    import os
    import shutil
    from multiprocessing import cpu_count as get_cpu_count
    from pyspark.sql import SparkSession

    cpu_count = get_cpu_count() or 2

    # No extra jars are registered; append JDBC connector paths here if needed.
    jars = []

    # Placeholder values only: get_spark_conf_args needs them to build a Mongo
    # URI, but that URI is not used by the session created below.
    placeholder_mongo_args = {
        "user_name": "mongodb_user",
        "password": "mongodb_password",
        "cluster_name": "cluster0",
        "cluster_subnet": "abc123",
        "db_name": "adventureworks"
    }
    sparkConf_args = get_spark_conf_args(jars, **placeholder_mongo_args)

    # Use one warehouse location for both the session and the cleanup below,
    # derived from the working directory rather than a machine-specific path.
    warehouse_dir = sparkConf_args["database_dir"]

    # Create Spark session
    spark = SparkSession.builder \
        .appName(sparkConf_args["app_name"]) \
        .master(f"local[{str(2 if cpu_count < 2 else cpu_count)}]") \
        .config("spark.sql.shuffle.partitions", sparkConf_args["shuffle_partitions"]) \
        .config("spark.sql.warehouse.dir", warehouse_dir) \
        .config("spark.jars", sparkConf_args["spark_jars"]) \
        .getOrCreate()

    dest_database = "adventure_works_dw"

    print(f"Checking if database {dest_database} exists...")
    databases = [db.name for db in spark.catalog.listDatabases()]
    db_location = os.path.join(warehouse_dir, f"{dest_database}.db")

    if dest_database in databases:
        print(f"Database {dest_database} already exists. Dropping it from Spark...")
        spark.sql(f"DROP DATABASE IF EXISTS {dest_database} CASCADE")

    # Remove any leftover files so the tables are rebuilt from scratch.
    if os.path.exists(db_location):
        try:
            print(f"Removing directory: {db_location}")
            shutil.rmtree(db_location)
            print(f"Successfully removed {db_location}")
        except Exception as e:
            print(f"Warning: Could not remove directory {db_location}: {str(e)}")
            print("You may need to manually delete this directory.")

    sql_create_db = f"""
        CREATE DATABASE IF NOT EXISTS {dest_database}
        COMMENT 'DS-2002 Capstone'
    """
    spark.sql(sql_create_db)
    spark.sql(f"USE {dest_database}")

    # The listing itself is discarded; the call fails fast when batch_dir is missing.
    get_file_info(batch_dir)

    # (label, table, temp view, order column, surrogate key, optional)
    dimensions = [
        ("customer", "dim_customer", "customers", "CustomerID", "customer_key", False),
        ("product", "dim_product", "products", "ProductID", "product_key", False),
        ("employee", "dim_employee", "employees", "EmployeeID", "employee_key", False),
        # The date dimension is stored as-is, without an added surrogate key.
        ("date", "dim_date", None, None, None, False),
        ("territory", "dim_territory", "territories", "TerritoryID", "territory_key", True),
    ]

    for label, table_name, view_name, order_column, key_column, optional in dimensions:
        csv_path = os.path.join(batch_dir, f"{table_name}.csv")
        if optional and not os.path.exists(csv_path):
            continue
        _load_dimension_from_csv(
            spark, dest_database, csv_path, label, table_name,
            view_name=view_name, order_column=order_column, key_column=key_column,
        )

    print("\nSample data from dimensions:")
    for label, table_name, *_ in dimensions:
        _show_dimension_sample(spark, dest_database, label, table_name)

    return spark, dest_database

In [9]:
def load_bronze_layer(spark, dest_database, bronze_dir, batch_dir, stream_dir):
    """Load raw dimension and sales data into the Bronze layer.

    Dimension CSV files from ``batch_dir`` are written as Parquet, the sales
    JSON files under ``<stream_dir>/fact_sales`` are streamed to Parquet, and
    every Bronze directory that contains Parquet files is registered as a
    table named ``bronze_<name>`` in ``dest_database``.

    Args:
        spark: Active Spark session.
        dest_database: Database in which the Bronze tables are registered.
        bronze_dir: Root directory for Bronze Parquet output.
        batch_dir: Directory containing the dimension CSV files.
        stream_dir: Directory containing the ``fact_sales`` JSON stream source.

    Raises:
        ValueError: If ``dest_database`` is empty.
    """
    import os
    import time
    import traceback
    from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType

    print("Loading Bronze layer...")

    spark.conf.set("spark.sql.shuffle.partitions", "10")

    if not dest_database:
        raise ValueError("Destination database name is required")

    bronze_paths = {
        "customers": os.path.join(bronze_dir, "dim_customer"),
        "products": os.path.join(bronze_dir, "dim_product"),
        "employees": os.path.join(bronze_dir, "dim_employee"),
        "dates": os.path.join(bronze_dir, "dim_date"),
        "territories": os.path.join(bronze_dir, "dim_territory"),
        "sales": os.path.join(bronze_dir, "fact_sales"),
    }

    for path in bronze_paths.values():
        os.makedirs(path, exist_ok=True)

    print(f"Creating database if not exists: {dest_database}")
    spark.sql(f"CREATE DATABASE IF NOT EXISTS {dest_database}")

    dim_tables = [
        ("customers", 'dim_customer.csv'),
        ("products", 'dim_product.csv'),
        ("employees", 'dim_employee.csv'),
        ("dates", 'dim_date.csv'),
        ("territories", 'dim_territory.csv'),
    ]

    for name, filename in dim_tables:
        source_path = os.path.join(batch_dir, filename)
        dest_path = bronze_paths[name]

        if not os.path.exists(source_path):
            print(f"Warning: Source file not found: {source_path}")
            continue

        # Up to three attempts per table; the else branch runs only when no
        # attempt reached `break`.
        for attempt in range(3):
            try:
                df = spark.read.csv(source_path, header=True, inferSchema=True)
                df.write.mode("overwrite").parquet(dest_path)
                print(f"Loaded {df.count()} records to Bronze: {name}")
                break
            except Exception as e:
                print(f"Error loading {name}, attempt {attempt+1}/3: {e}")
                time.sleep(2)
        else:
            print(f"Failed to load {name} after 3 attempts")

    fact_schema = StructType([
        StructField("SalesOrderID", IntegerType(), True),
        StructField("SalesOrderDetailID", IntegerType(), True),
        StructField("CustomerID", IntegerType(), True),
        StructField("ProductID", IntegerType(), True),
        StructField("EmployeeID", IntegerType(), True),
        StructField("TerritoryID", IntegerType(), True),
        StructField("OrderDate", StringType(), True),
        StructField("OrderQty", IntegerType(), True),
        StructField("UnitPrice", FloatType(), True),
        StructField("LineTotal", FloatType(), True),
    ])

    # Streaming failures are reported but do not abort the table registration below.
    try:
        streaming_source = os.path.join(stream_dir, "fact_sales")
        checkpoint = os.path.join(bronze_paths["sales"], "_checkpoint")
        os.makedirs(checkpoint, exist_ok=True)

        stream_df = spark.readStream.schema(fact_schema).option("maxFilesPerTrigger", 1).json(streaming_source)
        query = stream_df.writeStream.format("parquet").option("path", bronze_paths["sales"]).option("checkpointLocation", checkpoint).trigger(once=True).start()
        query.awaitTermination()
        print("Completed streaming to Bronze layer")
    except Exception as e:
        print(f"Streaming error: {e}")
        traceback.print_exc()

    # Register external tables
    print("Registering bronze tables...")
    spark.sql(f"USE {dest_database}")

    for name, path in bronze_paths.items():
        table_name = f"bronze_{name}"

        try:
            files = [f for f in os.listdir(path) if f.endswith(".parquet")]
            if not files:
                print(f"Skipping {table_name}: no Parquet files found in {path}")
                continue

            # Force a read so that unreadable data is reported before the
            # existing table is dropped.
            df = spark.read.parquet(path)
            df.limit(1).collect()

            spark.sql(f"DROP TABLE IF EXISTS {dest_database}.{table_name}")

            # Computed outside the f-string: a backslash inside a replacement
            # field is a syntax error before Python 3.12.
            location = path.replace("\\", "/")
            spark.sql(f"""
                CREATE TABLE {dest_database}.{table_name}
                USING PARQUET
                LOCATION '{location}'
            """)
            print(f"Registered external table: {table_name}")

            time.sleep(1)

        except Exception as e:
            print(f"Error registering table {table_name}: {e}")
            traceback.print_exc()

    print("Bronze layer processing completed.")



In [20]:
def load_silver_layer(spark, silver_dir, silver_database, bronze_database):
    """Transform the Bronze tables and load them into the Silver layer.

    Each Bronze table is read through a temporary view that fills selected
    null values and adds a ``LastUpdated`` timestamp; the sales view also
    derives date parts from ``OrderDate``. Every view is then written as a
    Parquet table under ``silver_dir``. Existing Silver directories are
    deleted first.

    Args:
        spark: Active Spark session.
        silver_dir: Root directory for Silver Parquet output.
        silver_database: Database in which the Silver tables are created.
        bronze_database: Database that holds the ``bronze_*`` source tables.

    Raises:
        ValueError: If any of the three string arguments is empty.
    """
    import os
    import shutil

    print("Step 5: Transforming and loading Silver layer...")

    if not silver_dir:
        raise ValueError("silver_dir must be provided and cannot be empty")

    if not silver_database:
        raise ValueError("silver_database must be provided and cannot be empty")

    if not bronze_database:
        raise ValueError("bronze_database must be provided and cannot be empty")

    # Define Silver paths
    silver_paths = {
        "customers": os.path.join(silver_dir, "dim_customer"),
        "products": os.path.join(silver_dir, "dim_product"),
        "employees": os.path.join(silver_dir, "dim_employee"),
        "dates": os.path.join(silver_dir, "dim_date"),
        "territories": os.path.join(silver_dir, "dim_territory"),
        "sales": os.path.join(silver_dir, "fact_sales"),
    }

    # Start every table from an empty directory.
    for path in silver_paths.values():
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    # Create Silver database if not exists
    spark.sql(f"CREATE DATABASE IF NOT EXISTS {silver_database}")
    spark.sql(f"USE {silver_database}")

    def write_silver_table(view_name, table_name, path):
        """Materialize a temporary view as a Parquet table stored at ``path``."""
        # Computed outside the f-string: a backslash inside a replacement
        # field is a syntax error before Python 3.12.
        location = path.replace("\\", "/")
        spark.sql(f"DROP TABLE IF EXISTS {silver_database}.{table_name}")
        spark.sql(f"""
            CREATE TABLE {silver_database}.{table_name}
            USING parquet
            LOCATION '{location}'
            AS SELECT * FROM {view_name}
        """)
        print(f" Created silver table: {silver_database}.{table_name}")

    # ---------------------------
    # Transform Customer Dimension
    # ---------------------------
    print("Transforming customer dimension...")
    spark.sql(f"""
        CREATE OR REPLACE TEMPORARY VIEW silver_customers AS
        SELECT
            CustomerID,
            AccountNumber,
            CustomerType,
            COALESCE(AddressLine1, 'Unknown') AS AddressLine1,
            COALESCE(AddressLine2, '') AS AddressLine2,
            COALESCE(City, 'Unknown') AS City,
            COALESCE(StateProvinceName, 'Unknown') AS StateProvinceName,
            COALESCE(PostalCode, 'Unknown') AS PostalCode,
            COALESCE(CountryRegionName, 'Unknown') AS CountryRegionName,
            current_timestamp() AS LastUpdated
        FROM {bronze_database}.bronze_customers
    """)
    write_silver_table("silver_customers", "silver_customers", silver_paths["customers"])

    # ---------------------------
    # Transform Product Dimension
    # ---------------------------
    print("Transforming product dimension...")
    spark.sql(f"""
        CREATE OR REPLACE TEMPORARY VIEW silver_products AS
        SELECT
            ProductID,
            COALESCE(ProductName, 'Unknown') AS ProductName,
            ProductNumber,
            COALESCE(Color, 'N/A') AS Color,
            StandardCost,
            ListPrice,
            COALESCE(Size, 'N/A') AS Size,
            Weight,
            ProductModelID,
            COALESCE(ProductModelName, 'Unknown') AS ProductModelName,
            COALESCE(ProductCategoryName, 'Unknown') AS ProductCategoryName,
            COALESCE(ProductSubcategoryName, 'Unknown') AS ProductSubcategoryName,
            current_timestamp() AS LastUpdated
        FROM {bronze_database}.bronze_products
    """)
    write_silver_table("silver_products", "silver_products", silver_paths["products"])

    # ---------------------------
    # Transform Employee Dimension
    # ---------------------------
    print("Transforming employee dimension...")
    spark.sql(f"""
        CREATE OR REPLACE TEMPORARY VIEW silver_employees AS
        SELECT
            EmployeeID,
            NationalIDNumber,
            LoginID,
            ManagerID,
            COALESCE(Title, 'Unknown') AS Title,
            BirthDate,
            MaritalStatus,
            Gender,
            HireDate,
            SalariedFlag,
            VacationHours,
            SickLeaveHours,
            DepartmentID,
            ShiftID,
            StartDate,
            EndDate,
            current_timestamp() AS LastUpdated
        FROM {bronze_database}.bronze_employees
    """)
    write_silver_table("silver_employees", "silver_employees", silver_paths["employees"])

    # ---------------------------
    # Transform Date Dimension
    # ---------------------------
    print("Transforming date dimension...")
    spark.sql(f"""
        CREATE OR REPLACE TEMPORARY VIEW silver_dates AS
        SELECT
            date_key,
            full_date,
            day_name_of_week,
            month_name,
            calendar_year,
            calendar_quarter,
            current_timestamp() AS LastUpdated
        FROM {bronze_database}.bronze_dates
    """)
    write_silver_table("silver_dates", "silver_dates", silver_paths["dates"])

    # ---------------------------
    # Transform Territory Dimension
    # ---------------------------
    print("Transforming territory dimension...")
    spark.sql(f"""
        CREATE OR REPLACE TEMPORARY VIEW silver_territories AS
        SELECT
            TerritoryID,
            COALESCE(TerritoryName, 'Unknown') AS TerritoryName,
            COALESCE(CountryRegionCode, 'Unknown') AS CountryRegionCode,
            COALESCE(TerritoryGroup, 'Unknown') AS TerritoryGroup,
            current_timestamp() AS LastUpdated
        FROM {bronze_database}.bronze_territories
    """)
    write_silver_table("silver_territories", "silver_territories", silver_paths["territories"])

    # ---------------------------
    # Transform Sales Fact Table
    # ---------------------------
    # Missing employee/territory references are replaced with -1.
    print("Transforming sales fact data...")
    spark.sql(f"""
        CREATE OR REPLACE TEMPORARY VIEW silver_sales AS
        SELECT
            SalesOrderID,
            SalesOrderDetailID,
            CustomerID,
            ProductID,
            COALESCE(EmployeeID, -1) AS EmployeeID,
            COALESCE(TerritoryID, -1) AS TerritoryID,
            to_date(OrderDate) AS OrderDate,
            year(to_date(OrderDate)) AS OrderYear,
            month(to_date(OrderDate)) AS OrderMonth,
            dayofmonth(to_date(OrderDate)) AS OrderDay,
            quarter(to_date(OrderDate)) AS OrderQuarter,
            OrderQty,
            UnitPrice,
            LineTotal,
            current_timestamp() AS LastUpdated
        FROM {bronze_database}.bronze_sales
    """)
    write_silver_table("silver_sales", "silver_sales", silver_paths["sales"])

    print("Silver layer transformation complete.")


In [21]:
def load_gold_layer(spark, gold_dir, dest_database):
    """Materialise the Gold-layer aggregate tables as parquet tables.

    Five temporary views are defined over the Silver relations
    (``silver_sales``, ``silver_products``, ``silver_territories`` and
    ``silver_customers``, referenced by unqualified name) and each one is
    written to ``{dest_database}.<table>`` with an explicit LOCATION under
    ``gold_dir``.

    Args:
        spark: Object exposing ``sql(statement)``; every statement is issued
            through it.
        gold_dir: Root directory for the Gold tables; one sub-directory per
            table is created beneath it.
        dest_database: Database that receives the Gold tables. It is created
            if missing and made the current database.

    Raises:
        ValueError: If ``gold_dir`` or ``dest_database`` is empty/None.
    """
    print("Loading Gold layer...")

    if not gold_dir:
        raise ValueError("gold_dir must be provided and cannot be empty")

    if not dest_database:
        raise ValueError("dest_database must be provided and cannot be empty")

    # Define Gold paths
    gold_paths = {
        "sales_by_product": os.path.join(gold_dir, "sales_by_product"),
        "sales_by_territory": os.path.join(gold_dir, "sales_by_territory"),
        "sales_by_customer": os.path.join(gold_dir, "sales_by_customer"),
        "sales_by_date": os.path.join(gold_dir, "sales_by_date"),
        "top_products": os.path.join(gold_dir, "top_products"),
    }

    # Create directories
    for path in gold_paths.values():
        os.makedirs(path, exist_ok=True)

    # Use the destination database
    spark.sql(f"CREATE DATABASE IF NOT EXISTS {dest_database}")
    spark.sql(f"USE {dest_database}")

    def write_gold_table(view_name, table_name, path):
        """Drop and recreate one Gold table from ``view_name`` at ``path``."""
        # Computed outside the f-string: a backslash inside an f-string
        # replacement field is a SyntaxError on Python versions before 3.12.
        location = path.replace("\\", "/")

        spark.sql(f"DROP TABLE IF EXISTS {dest_database}.{table_name}")

        # Dropping a table that was created with an explicit LOCATION leaves
        # its data files in place, and CTAS into a non-empty location is
        # rejected by Spark >= 3.2. Clear leftovers from a previous run so the
        # function can be re-executed. Non-local paths are left untouched
        # because os.path.isdir() is False for them.
        if os.path.isdir(path) and os.listdir(path):
            shutil.rmtree(path)
            os.makedirs(path, exist_ok=True)

        spark.sql(f"""
            CREATE TABLE {dest_database}.{table_name}
            USING parquet
            LOCATION '{location}'
            AS SELECT * FROM {view_name}
        """)
        print(f" Created gold table: {dest_database}.{table_name}")

    # ------------------------------
    # Sales by Product
    # ------------------------------
    print("Creating sales by product analytics view...")
    spark.sql("""
        CREATE OR REPLACE TEMPORARY VIEW gold_sales_by_product AS
        SELECT
            p.ProductID,
            p.ProductName,
            p.ProductCategoryName,
            p.ProductSubcategoryName,
            SUM(s.OrderQty) AS TotalQuantity,
            SUM(s.LineTotal) AS TotalSales,
            COUNT(DISTINCT s.SalesOrderID) AS NumberOfOrders,
            AVG(s.UnitPrice) AS AverageUnitPrice
        FROM silver_sales s
        JOIN silver_products p ON s.ProductID = p.ProductID
        GROUP BY p.ProductID, p.ProductName, p.ProductCategoryName, p.ProductSubcategoryName
    """)
    write_gold_table("gold_sales_by_product", "gold_sales_by_product", gold_paths["sales_by_product"])

    # ------------------------------
    # Sales by Territory
    # ------------------------------
    # NOTE(review): AverageOrderValue here (and in the customer/date views)
    # is AVG(LineTotal), i.e. a mean per sales row rather than per
    # SalesOrderID - confirm that this is the intended metric.
    print("Creating sales by territory analytics view...")
    spark.sql("""
        CREATE OR REPLACE TEMPORARY VIEW gold_sales_by_territory AS
        SELECT
            t.TerritoryID,
            t.TerritoryName,
            t.CountryRegionCode,
            t.TerritoryGroup,
            SUM(s.LineTotal) AS TotalSales,
            COUNT(DISTINCT s.SalesOrderID) AS NumberOfOrders,
            COUNT(DISTINCT s.CustomerID) AS NumberOfCustomers,
            AVG(s.LineTotal) AS AverageOrderValue
        FROM silver_sales s
        JOIN silver_territories t ON s.TerritoryID = t.TerritoryID
        GROUP BY t.TerritoryID, t.TerritoryName, t.CountryRegionCode, t.TerritoryGroup
    """)
    write_gold_table("gold_sales_by_territory", "gold_sales_by_territory", gold_paths["sales_by_territory"])

    # ------------------------------
    # Sales by Customer
    # ------------------------------
    print("Creating sales by customer analytics view...")
    spark.sql("""
        CREATE OR REPLACE TEMPORARY VIEW gold_sales_by_customer AS
        SELECT
            c.CustomerID,
            c.AccountNumber,
            c.CustomerType,
            c.CountryRegionName,
            c.StateProvinceName,
            c.City,
            SUM(s.LineTotal) AS TotalSpend,
            COUNT(DISTINCT s.SalesOrderID) AS NumberOfOrders,
            COUNT(DISTINCT s.ProductID) AS NumberOfUniqueProducts,
            MIN(s.OrderDate) AS FirstPurchaseDate,
            MAX(s.OrderDate) AS LastPurchaseDate,
            AVG(s.LineTotal) AS AverageOrderValue
        FROM silver_sales s
        JOIN silver_customers c ON s.CustomerID = c.CustomerID
        GROUP BY c.CustomerID, c.AccountNumber, c.CustomerType, c.CountryRegionName, c.StateProvinceName, c.City
    """)
    write_gold_table("gold_sales_by_customer", "gold_sales_by_customer", gold_paths["sales_by_customer"])

    # ------------------------------
    # Sales by Date
    # ------------------------------
    print("Creating sales by date analytics view...")
    spark.sql("""
        CREATE OR REPLACE TEMPORARY VIEW gold_sales_by_date AS
        SELECT
            s.OrderDate,
            s.OrderYear,
            s.OrderMonth,
            s.OrderQuarter,
            SUM(s.LineTotal) AS TotalSales,
            COUNT(DISTINCT s.SalesOrderID) AS NumberOfOrders,
            COUNT(DISTINCT s.CustomerID) AS NumberOfCustomers,
            AVG(s.LineTotal) AS AverageOrderValue,
            COUNT(DISTINCT s.ProductID) AS NumberOfUniqueProducts
        FROM silver_sales s
        GROUP BY s.OrderDate, s.OrderYear, s.OrderMonth, s.OrderQuarter
    """)
    write_gold_table("gold_sales_by_date", "gold_sales_by_date", gold_paths["sales_by_date"])

    # ------------------------------
    # Top Products
    # ------------------------------
    # Both ranks use ROW_NUMBER over the whole result set, so products with
    # equal totals still receive distinct rank numbers.
    print("Creating top products analytics view...")
    spark.sql("""
        CREATE OR REPLACE TEMPORARY VIEW gold_top_products AS
        SELECT
            p.ProductID,
            p.ProductName,
            p.ProductCategoryName,
            p.ProductSubcategoryName,
            SUM(s.LineTotal) AS TotalSales,
            SUM(s.OrderQty) AS TotalQuantity,
            ROW_NUMBER() OVER (ORDER BY SUM(s.LineTotal) DESC) AS SalesRank,
            ROW_NUMBER() OVER (ORDER BY SUM(s.OrderQty) DESC) AS QuantityRank
        FROM silver_sales s
        JOIN silver_products p ON s.ProductID = p.ProductID
        GROUP BY p.ProductID, p.ProductName, p.ProductCategoryName, p.ProductSubcategoryName
    """)
    write_gold_table("gold_top_products", "gold_top_products", gold_paths["top_products"])

    print(" Gold layer analytics tables created successfully.")


In [22]:

# ------------------------------------------------------------------------------
# Step 10: Query the Gold Layer
# ------------------------------------------------------------------------------
def query_gold_layer(spark, dest_database):
    """Print five sample analytics reports read from the Gold tables.

    Switches the session to ``dest_database`` and then, for every report,
    prints its heading, runs the SQL, converts the result with ``toPandas()``
    and prints it followed by a blank separator.

    Args:
        spark: Object exposing ``sql(statement)`` whose results provide
            ``toPandas()``.
        dest_database: Database that holds the ``gold_*`` tables.
    """
    print("\nRunning sample analytics queries on Gold layer...\n")

    # Unqualified gold_* names below resolve against this database.
    spark.sql(f"USE {dest_database}")

    # (heading, SQL) pairs, executed in the order listed.
    sample_reports = (
        # Sample query 1: Top 5 products by total sales
        ("Top 5 products by total sales:", """
    SELECT ProductName, ProductCategoryName, ProductSubcategoryName, 
           TotalSales, SalesRank
    FROM gold_top_products
    WHERE SalesRank <= 5
    ORDER BY SalesRank
    """),
        # Sample query 2: Sales by territory group
        ("Sales by territory group:", """
    SELECT TerritoryGroup, 
           SUM(TotalSales) AS GroupSales, 
           SUM(NumberOfOrders) AS GroupOrders,
           AVG(AverageOrderValue) AS GroupAvgOrderValue
    FROM gold_sales_by_territory
    GROUP BY TerritoryGroup
    ORDER BY GroupSales DESC
    """),
        # Sample query 3: Monthly sales trend
        ("Monthly sales trend:", """
    SELECT OrderYear, OrderMonth, 
           SUM(TotalSales) AS MonthlySales,
           SUM(NumberOfOrders) AS MonthlyOrders
    FROM gold_sales_by_date
    GROUP BY OrderYear, OrderMonth
    ORDER BY OrderYear, OrderMonth
    """),
        # Sample query 4: Top customer types by spending
        ("Customer type analysis:", """
    SELECT CustomerType, 
           COUNT(DISTINCT CustomerID) AS CustomerCount,
           SUM(TotalSpend) AS TotalSpend,
           AVG(TotalSpend) AS AvgCustomerSpend,
           AVG(NumberOfOrders) AS AvgOrdersPerCustomer
    FROM gold_sales_by_customer
    GROUP BY CustomerType
    ORDER BY TotalSpend DESC
    """),
        # Sample query 5: Product category performance
        ("Product category performance:", """
    SELECT ProductCategoryName, 
           SUM(TotalSales) AS CategorySales,
           SUM(TotalQuantity) AS CategoryQuantity,
           COUNT(DISTINCT ProductID) AS NumberOfProducts,
           AVG(AverageUnitPrice) AS AveragePriceInCategory
    FROM gold_sales_by_product
    GROUP BY ProductCategoryName
    ORDER BY CategorySales DESC
    """),
    )

    for heading, statement in sample_reports:
        print(heading)
        report_frame = spark.sql(statement).toPandas()
        print(report_frame)
        print("\n")

# ------------------------------------------------------------------------------
# Main Execution
# ------------------------------------------------------------------------------
def main():
    """Run the AdventureWorks lakehouse pipeline end to end.

    Derives the source and lakehouse directories under ``<cwd>/lab_data``,
    uploads employee data to MongoDB, extracts the dimension data, sets up the
    lakehouse and then loads the Bronze, Silver and Gold layers before running
    the sample Gold queries.

    Any exception raised while loading or querying the layers is reported
    (message plus full traceback) without being re-raised, and the Spark
    session is stopped in every case.
    """
    # Local import: only needed for the diagnostic output in the handler below.
    import traceback

    base_dir = os.path.join(os.getcwd(), 'lab_data')
    data_dir = os.path.join(base_dir, 'adventureworks')
    batch_dir = os.path.join(data_dir, 'batch')
    stream_dir = os.path.join(data_dir, 'streaming')

    # Set up the output directories for our lakehouse
    lakehouse_dir = os.path.join(base_dir, 'adventureworks_lakehouse')
    os.makedirs(lakehouse_dir, exist_ok=True)

    # Set up paths for bronze, silver, gold layers
    bronze_dir = os.path.join(lakehouse_dir, 'bronze')
    silver_dir = os.path.join(lakehouse_dir, 'silver')
    gold_dir = os.path.join(lakehouse_dir, 'gold')
    print("=" * 80)
    print("AdventureWorks Data Lakehouse Implementation")
    print("=" * 80)

    # Step 1: Create directory structure
    print("\nStep 1: Setting up directory structure...")
    os.makedirs(batch_dir, exist_ok=True)
    os.makedirs(stream_dir, exist_ok=True)

    upload_employee_data_to_mongodb()

    # Step 2: Extract data from MySQL and store in CSV and JSON files
    print("\nStep 2: Extracting data from source systems...")
    extract_and_save_dimension_data()

    # Step 3: Setup the Data Lakehouse
    print("\nStep 3: Setting up the Data Lakehouse...")
    spark, dest_database = setup_lakehouse()

    # Steps 4-7: populate and query the lakehouse layers.
    # This banner previously repeated the "Step 3" label.
    print("\nPopulating the Data Lakehouse (Steps 4-7)...")

    try:
        # Step 4: Load data into Bronze layer
        print("\nStep 4: Loading Bronze layer...")
        load_bronze_layer(spark, "adventureworks_bronze", bronze_dir, batch_dir, stream_dir)

        # Step 5: Transform data into Silver layer
        print("\nStep 5: Transforming and loading Silver layer...")
        load_silver_layer(spark, silver_dir, silver_database="adventureworks_silver", bronze_database="adventureworks_bronze")

        # Step 6: Create Gold layer analytics views
        print("\nStep 6: Creating Gold layer analytics views...")
        load_gold_layer(spark, gold_dir, dest_database)

        # Step 7: Run sample queries on Gold layer
        print("\nStep 7: Running sample analytics queries...")
        query_gold_layer(spark, dest_database)

        print("\nData Lakehouse implementation completed successfully!")

    except Exception as e:
        # Still best-effort (nothing is re-raised), but keep the traceback so
        # the failing step can be located instead of seeing only the message.
        print(f"Error during execution: {str(e)}")
        traceback.print_exc()
    finally:
        # Always release the Spark session, whether or not a step failed.
        spark.stop()
        print("\nSpark session stopped")

if __name__ == "__main__":
    main()

AdventureWorks Data Lakehouse Implementation

Step 1: Setting up directory structure...
Extracting employee data from MySQL and uploading to MongoDB...
Extracted 296 employee records with department and shift


  df = pd.read_sql(query, connection)


Stored 296 employee records in MongoDB collection: employees_from_mysql

Step 2: Extracting data from source systems...
Extracting dimension data from MySQL and MongoDB...


  df = pd.read_sql(query, connection)


Extracted 19220 customer records
Extracted 295 product records
Extracted 296 employee records from MongoDB
Extracted 4018 date dimension records

Step 3: Setting up the Data Lakehouse...


  df = pd.read_sql(query, connection)


Checking if database adventure_works_dw exists...
Removing directory: C:/Users/Baner/Downloads/spark-warehouse\adventure_works_dw.db
Successfully removed C:/Users/Baner/Downloads/spark-warehouse\adventure_works_dw.db
Loading customer data from: C:\Users\Baner\Downloads\lab_data\adventureworks\batch\dim_customer.csv
Loaded 19220 customer records
Saved customer data to adventure_works_dw.dim_customer
Loading product data from: C:\Users\Baner\Downloads\lab_data\adventureworks\batch\dim_product.csv
Loaded 295 product records
Saved product data to adventure_works_dw.dim_product
Loading employee data from: C:\Users\Baner\Downloads\lab_data\adventureworks\batch\dim_employee.csv
Loaded 296 employee records
Saved employee data to adventure_works_dw.dim_employee
Loading date data from: C:\Users\Baner\Downloads\lab_data\adventureworks\batch\dim_date.csv
Loaded 4018 date records
Saved date data to adventure_works_dw.dim_date
Loading territory data from: C:\Users\Baner\Downloads\lab_data\adventurew