In [1]:
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, count, avg, round, countDistinct
from pyspark.sql.types import IntegerType, FloatType, TimestampType

In [2]:
# Task 1: Establish PySpark Connection
def init_spark():
    """
    Create (or reuse) the Spark session used by the sales analysis.

    Output: the SparkSession registered under the application name
            "DMartSalesAnalysis", or None when the session cannot be
            created (the error is printed rather than raised).
    """
    try:
        builder = SparkSession.builder.appName("DMartSalesAnalysis")
        session = builder.getOrCreate()
        print("Spark session initialized successfully.")
    except Exception as exc:
        # Best-effort contract: report the problem and let the caller
        # decide what to do with a missing session.
        print(f"Error initializing Spark session: {exc}")
        session = None
    return session

In [3]:
# Task 2: Load Data into PySpark DataFrames
def load_data(spark, products_file, sales_file, customers_file):
    """
    Read the three source CSV files (first line treated as header) into
    PySpark DataFrames, printing the row count and schema of each one.

    Input: SparkSession, paths to products.csv, sales.csv, customers.csv
    Output: Tuple of DataFrames (products_df, sales_df, customers_df);
            (None, None, None) if any read fails (the error is printed).
    """
    def _read_and_report(path, file_label, schema_label):
        # Read one CSV, then echo its size and schema.
        frame = spark.read.option("header", "true").csv(path)
        row_total = frame.count()
        print(f"Loaded {file_label} with {row_total} rows.")
        print(f"{schema_label} DataFrame schema:")
        frame.printSchema()
        return frame

    # Processed strictly in this order: products, sales, customers.
    sources = (
        (products_file, "products.csv", "Products"),
        (sales_file, "sales.csv", "Sales"),
        (customers_file, "customers.csv", "Customers"),
    )
    try:
        products_df, sales_df, customers_df = [
            _read_and_report(*source) for source in sources
        ]
        return products_df, sales_df, customers_df
    except Exception as exc:
        print(f"Error loading data: {exc}")
        return None, None, None

In [4]:
def transform_data(products_df, sales_df, customers_df):
    """
    Clean the three raw DataFrames and join them into one sales-level frame.

    Steps, in order:
      1. Rename the CSV headers to snake_case column names.
      2. Cast measure and date columns to numeric / timestamp types.
      3. Replace missing values with defaults.
      4. Left-join product and customer attributes onto every sales row.

    The identifier columns used as join keys (product_id, customer_id) are
    deliberately left as strings. Casting them to IntegerType turns any ID
    that is not purely numeric into NULL, and NULL keys never match in a
    join, so the sales rows would silently lose their product and customer
    attributes.

    Defaults are applied after the casts so that values which fail to parse
    (and therefore become NULL) are replaced as well.

    Input: Products, sales, and customers DataFrames
    Output: Integrated DataFrame after cleaning and joining, or None if any
            step raises (the error is printed).
    """
    try:
        # CSV header -> column name used by the rest of the pipeline.
        product_names = {
            "Product ID": "product_id",
            "Category": "category",
            "Sub-Category": "sub_category",
            "Product Name": "product_name",
        }
        sales_names = {
            "Order Line": "order_line",
            "Order ID": "order_id",
            "Order Date": "order_date",
            "Ship Date": "ship_date",
            "Ship Mode": "ship_mode",
            "Customer ID": "customer_id",
            "Product ID": "product_id",
            "Sales": "sales",
            "Quantity": "quantity",
            "Discount": "discount",
            "Profit": "profit",
        }
        customer_names = {
            "Customer ID": "customer_id",
            "Customer Name": "customer_name",
            "Segment": "customer_segment",
            "Age": "age",
            "Country": "country",
            "City": "city",
            "State": "state",
            "Postal Code": "postal_code",
            "Region": "region",
        }
        for old_name, new_name in product_names.items():
            products_df = products_df.withColumnRenamed(old_name, new_name)
        for old_name, new_name in sales_names.items():
            sales_df = sales_df.withColumnRenamed(old_name, new_name)
        for old_name, new_name in customer_names.items():
            customers_df = customers_df.withColumnRenamed(old_name, new_name)

        # Ensure correct data types. Join keys are intentionally absent from
        # these maps (see docstring).
        # NOTE(review): a plain cast to timestamp only understands ISO-style
        # date strings; confirm the CSV date format, other layouts become NULL.
        sales_types = {
            "order_line": IntegerType(),
            "quantity": IntegerType(),
            "discount": FloatType(),
            "sales": FloatType(),
            "profit": FloatType(),
            "order_date": TimestampType(),
            "ship_date": TimestampType(),
        }
        # NOTE(review): an integer postal_code loses leading zeros; kept to
        # preserve the existing output schema.
        customer_types = {
            "age": IntegerType(),
            "postal_code": IntegerType(),
        }
        for name, data_type in sales_types.items():
            sales_df = sales_df.withColumn(name, col(name).cast(data_type))
        for name, data_type in customer_types.items():
            customers_df = customers_df.withColumn(name, col(name).cast(data_type))

        # Handle missing values (after casting, so cast failures are covered).
        products_df = products_df.na.fill({"product_name": "Unknown"})
        sales_df = sales_df.na.fill({"quantity": 0, "discount": 0.0, "sales": 0.0, "profit": 0.0})
        customers_df = customers_df.na.fill({"age": 0, "customer_name": "Unknown"})

        # Join DataFrames: every sales row is kept, attributes are optional.
        product_attrs = products_df.select("product_id", "category", "sub_category", "product_name")
        customer_attrs = customers_df.select("customer_id", "customer_name", "customer_segment",
                                             "age", "country", "city", "state", "postal_code", "region")
        integrated_df = sales_df.join(product_attrs, "product_id", "left") \
                                .join(customer_attrs, "customer_id", "left")
        print(f"Task 3: Integrated DataFrame created with {integrated_df.count()} rows.")
        return integrated_df
    except Exception as e:
        print(f"Task 3: Error transforming data: {str(e)}")
        return None

In [5]:
def define_queries():
    """
    Build the catalogue of the 10 analytical questions.

    Output: List of tuples (query_name, query_function), in execution order.
    """
    print("Task 4: Formulating analytical questions.")

    # Insertion order of the mapping is the order the queries are run in.
    catalogue = {
        "Total sales by category": query_1,
        "Top customer by purchases": query_2,
        "Average discount": query_3,
        "Unique products by region": query_4,
        "Total profit by state": query_5,
        "Top sub-category by sales": query_6,
        "Average age by segment": query_7,
        "Orders by shipping mode": query_8,
        "Total quantity by city": query_9,
        "Top segment by profit margin": query_10,
    }
    return list(catalogue.items())

# ------------------- Task 5: Execute Analytical Queries -------------------

def run_queries(integrated_df, queries):
    """
    Run every analytical query against the integrated DataFrame.

    Input: Integrated DataFrame, list of (query_name, query_function)

    The first query that raises stops the run: its error is printed and the
    remaining queries are skipped. Nothing is returned.
    """
    try:
        print("\n--- Analytical Queries ---")
        for entry in queries:
            label, runner = entry
            print(f"\nRunning query: {label}")
            runner(integrated_df)
        print("Task 5: All queries executed successfully.")
    except Exception as exc:
        print(f"Task 5: Error executing queries: {exc}")

# ------------------- Main Execution -------------------

def _local_path(location):
    """Translate a ``file://`` URI into a filesystem path; any other string is returned unchanged."""
    # Imported locally so this helper does not depend on the notebook's import cell.
    from urllib.parse import urlparse
    from urllib.request import url2pathname

    parsed = urlparse(location)
    if parsed.scheme != "file":
        return location
    # url2pathname keeps the leading "/" of POSIX paths and turns "/C:/..."
    # into a drive path on Windows; a plain string replace gets POSIX wrong.
    return url2pathname(parsed.path)


def main(
    products_file="file:///C:/Guvi/Project_5/Product.csv",
    sales_file="file:///C:/Guvi/Project_5/Sales.csv",
    customers_file="file:///C:/Guvi/Project_5/Customer.csv",
):
    """
    Main function to execute the PySpark pipeline and analysis.

    Input: locations of the product, sales and customer CSV files, given as
           ``file://`` URIs or plain paths. The defaults are the original
           hard-coded locations, so calling ``main()`` behaves as before.

    The function prints its progress and returns None. It returns early when
    an input file is missing, when Spark cannot be started, or when loading
    or transforming the data fails. Once a Spark session exists it is always
    stopped, including when a later step raises.
    """
    # Verify file existence before paying the cost of starting Spark.
    for location in (products_file, sales_file, customers_file):
        file_path = _local_path(location)
        if not os.path.exists(file_path):
            print(f"Error: File not found at {file_path}")
            return

    # Task 1: Initialize Spark
    spark = init_spark()
    if spark is None:
        return

    try:
        # Task 2: Load data
        products_df, sales_df, customers_df = load_data(spark, products_file, sales_file, customers_file)
        # Explicit None checks: the helpers signal failure with None, and the
        # truthiness of a DataFrame object is not a meaningful test.
        if any(frame is None for frame in (products_df, sales_df, customers_df)):
            return

        # Task 3: Transform and clean data
        integrated_df = transform_data(products_df, sales_df, customers_df)
        if integrated_df is None:
            return

        # Task 4: Define queries
        queries = define_queries()

        # Task 5: Run queries
        run_queries(integrated_df, queries)

        # Summary statistics, computed in a single pass over the data.
        try:
            summary = integrated_df.agg(
                round(sum("sales"), 2).alias("total_sales"),
                round(sum("profit"), 2).alias("total_profit"),
                count("*").alias("total_rows"),
            ).collect()[0]
            print("\nSummary:")
            print(f"Total records processed: {summary['total_rows']}")
            print(f"Total sales across all records: ${summary['total_sales']}")
            print(f"Total profit across all records: ${summary['total_profit']}")
        except Exception as e:
            print(f"Error calculating summary: {str(e)}")
    finally:
        # Stop Spark session on every exit path.
        spark.stop()
    print("Spark session stopped.")

# ------------------- Analytical Questions Code -------------------

"""
Analytical Questions Code
The following functions define the logic for the 10 analytical questions.
They are used in Task 4 (define_queries) and executed in Task 5 (run_queries).
"""

def query_1(df):
    """Show total sales (rounded to 2 decimals) per product category, largest first."""
    print("1. Total sales for each product category")
    total_sales = round(sum("sales"), 2).alias("total_sales")
    per_category = df.groupBy("category").agg(total_sales)
    per_category.orderBy("total_sales", ascending=False).show()

def query_2(df):
    """Show the single customer with the most rows (non-null order_line values)."""
    print("2. Customer with the highest number of purchases")
    purchases = count("order_line").alias("purchase_count")
    ranked = (
        df.groupBy("customer_id", "customer_name")
        .agg(purchases)
        .orderBy("purchase_count", ascending=False)
    )
    # Only the top-ranked customer is displayed.
    ranked.limit(1).show()

def query_3(df):
    """Show the average discount over all rows, rounded to 4 decimals."""
    print("3. Average discount given on sales across all products")
    mean_discount = round(avg("discount"), 4).alias("avg_discount")
    overall = df.agg(mean_discount)
    overall.show()

def query_4(df):
    """Show the number of distinct product IDs sold in each region, ordered by region."""
    print("4. Unique products sold in each region")
    distinct_products = countDistinct("product_id").alias("unique_products")
    per_region = df.groupBy("region").agg(distinct_products)
    per_region.orderBy("region").show()

def query_5(df):
    """Show total profit (rounded to 2 decimals) per state, largest first."""
    print("5. Total profit generated in each state")
    total_profit = round(sum("profit"), 2).alias("total_profit")
    per_state = df.groupBy("state").agg(total_profit)
    per_state.orderBy("total_profit", ascending=False).show()

def query_6(df):
    """Show the single product sub-category with the largest total sales."""
    print("6. Product sub-category with the highest sales")
    total_sales = round(sum("sales"), 2).alias("total_sales")
    ranked = (
        df.groupBy("sub_category")
        .agg(total_sales)
        .orderBy("total_sales", ascending=False)
    )
    # Only the best-selling sub-category is displayed.
    ranked.limit(1).show()

def query_7(df):
    """
    Show the average customer age (rounded to 2 decimals) per segment.

    The input has one row per sales line, so a customer appears once for
    every line they bought. The rows are reduced to one per customer_id
    before averaging; otherwise each age would be weighted by how many
    lines that customer purchased.
    """
    print("7. Average age of customers in each segment")
    # NOTE(review): ages that the cleaning step defaulted to 0 still count
    # towards the average; decide whether they should be excluded.
    customers = df.select("customer_id", "customer_segment", "age") \
                  .dropDuplicates(["customer_id"])
    customers.groupBy("customer_segment") \
      .agg(round(avg("age"), 2).alias("avg_age")) \
      .orderBy("customer_segment") \
      .show()

def query_8(df):
    """
    Show the number of orders per shipping mode, largest first.

    An order can span several sales lines, so distinct order_id values are
    counted; counting order_line would report sales lines instead of orders.
    """
    print("8. Orders shipped in each shipping mode")
    df.groupBy("ship_mode") \
      .agg(countDistinct("order_id").alias("order_count")) \
      .orderBy("order_count", ascending=False) \
      .show()

def query_9(df):
    """Show the total quantity sold per city, largest first."""
    print("9. Total quantity of products sold in each city")
    total_quantity = sum("quantity").alias("total_quantity")
    per_city = df.groupBy("city").agg(total_quantity)
    per_city.orderBy("total_quantity", ascending=False).show()

def query_10(df):
    """
    Show the customer segment with the highest profit margin.

    The margin is the mean of the per-row profit / sales ratios, rounded to
    4 decimals. Rows whose sales are zero are removed first: the cleaning
    step defaults missing sales to 0.0, and dividing by zero raises an error
    when Spark runs in ANSI SQL mode. Without ANSI mode such rows produced
    NULL ratios that the average ignored, so the ranking is unchanged there;
    the only difference is that a segment with nothing but zero-sales rows
    is now absent instead of carrying a NULL margin.
    """
    print("10. Customer segment with the highest profit margin")
    # NOTE(review): this averages row-level margins; if the intended metric
    # is total profit / total sales per segment, the aggregate must change.
    df.filter(col("sales") != 0) \
      .groupBy("customer_segment") \
      .agg(round(avg(col("profit") / col("sales")), 4).alias("profit_margin")) \
      .orderBy("profit_margin", ascending=False) \
      .limit(1) \
      .show()

# Run the full pipeline only when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()

Spark session initialized successfully.
Loaded products.csv with 1862 rows.
Products DataFrame schema:
root
 |-- Product ID: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Sub-Category: string (nullable = true)
 |-- Product Name: string (nullable = true)

Loaded sales.csv with 9994 rows.
Sales DataFrame schema:
root
 |-- Order Line: string (nullable = true)
 |-- Order ID: string (nullable = true)
 |-- Order Date: string (nullable = true)
 |-- Ship Date: string (nullable = true)
 |-- Ship Mode: string (nullable = true)
 |-- Customer ID: string (nullable = true)
 |-- Product ID: string (nullable = true)
 |-- Sales: string (nullable = true)
 |-- Quantity: string (nullable = true)
 |-- Discount: string (nullable = true)
 |-- Profit: string (nullable = true)

Loaded customers.csv with 793 rows.
Customers DataFrame schema:
root
 |-- Customer ID: string (nullable = true)
 |-- Customer Name: string (nullable = true)
 |-- Segment: string (nullable = true)
 |-- Age: string