<img src="./images/logo.png" alt="Drawing" style="width: 500px;"/>

# **Exercise 3:** Exploring Retail Data with Apache Spark

This exercise will introduce **Apache Spark on HPE AI Essentials**. We'll leverage Spark's powerful distributed processing capabilities to analyze and fix the sales information.

In this exercise, you will:

- Set up a Spark session for interacting with data.
- Generate sample sales data for different countries and currencies.
- Explore techniques for data loading, transformation, and analysis using Spark SQL and DataFrames.
- Create Delta Tables and perform version control.

Feel free to modify and extend the code examples to suit your specific data analysis needs.

Let's get started!

### **Prerequisites:**

As instructed in the [Introductory notebook](./00.introduction.ipynb), ensure that you have run `pip install -r requirements.txt` in a Terminal window, located in the same working directory, prior to running this notebook. 

<div class="alert alert-block alert-danger">
    <b>Important:</b> Make sure you selected <b>PySpark</b> for your notebook kernel - check the top right corner!
</div>

## **1. Create Spark Session**

Think about the most recent Excel spreadsheet you edited. It probably had tens or even hundreds of rows across tens of columns. When you ran an Excel function, such as *SUM()* or *VLOOKUP()*, you may have noticed that it took a fair bit of time to process. Maybe the fans of your laptop even sped up a bit as your computer worked to crunch the numbers.

Now, scale that same command out to a spreadsheet with tens of **millions** of rows across **thousands** of columns. That is the Big Data that companies must work with on a daily basis, and no single PC is going to run any *VLOOKUP* command on data of that size.

Instead of spreadsheets, the enterprise world is largely built upon **tables** in a variety of formats. To query these tables to retrieve certain data takes a **mammoth** amount of compute. It makes no sense to have a single **compute server** executing these queries - it would be far faster to parallelize queries across several computers. Enter **Apache Spark**.

### Introduction to Apache Spark on HPE AI Essentials

Apache Spark is a popular open-source big data framework that **distributes the computations** required to perform queries on large sets of data. This distribution, along with working with data in-memory rather than directly from storage disks, drastically brings down the time usually taken to query and index data. The combination of speed, versatility, and ease of use made Spark the go-to framework when working with big data. 

Apache Spark comes pre-installed with **HPE Ezmeral AI Essentials** and can leverage as much or as little of the compute available in an AIE cluster as a user desires. The core components of an Apache Spark deployment include:

<img src="./images/exercise1/spark_archi.PNG" alt="Drawing" style="width: 60%;"/>

**Driver:** The driver program coordinates the execution of Spark jobs. It submits tasks to executors, schedules operations, and manages communication between various components.

**Workers:** These are machines in the Spark cluster that manage executors. Each worker runs one or more executors. When running Spark on an HPE AI Essentials deployment, Spark Workers are Kubernetes pods distributed among the worker nodes of the AIE cluster, allowing them to scale across multiple machines as required.

**Executors:** Executors reside on worker nodes and carry out the actual computations assigned by the driver program. The workload is split into partitions, and each executor processes its share in parallel with the others.

**JVM:** Spark uses the Java Virtual Machine (JVM) on each worker node to run executors.

On **HPE AI Essentials**, you will use Apache Spark to analyze large datasets at high speed with a unified platform for batch processing, streaming, and machine learning.
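
To make the division of labour between driver and executors concrete, here is a plain-Python sketch. It is only an analogy, not Spark code: threads stand in for executors, and the function names (`split_into_partitions`, `executor_task`, `driver_sum`) are made up for illustration.

```python
from concurrent.futures import ThreadPoolExecutor


def split_into_partitions(rows, num_partitions):
    """Split a list into roughly equal contiguous chunks (the 'partitions')."""
    size, remainder = divmod(len(rows), num_partitions)
    partitions, start = [], 0
    for i in range(num_partitions):
        # The first `remainder` partitions get one extra row each
        end = start + size + (1 if i < remainder else 0)
        partitions.append(rows[start:end])
        start = end
    return partitions


def executor_task(partition):
    """Work done by one 'executor': a partial sum over its own partition."""
    return sum(partition)


def driver_sum(rows, num_executors=4):
    """The 'driver' hands out partitions, then combines the partial results."""
    partitions = split_into_partitions(rows, num_executors)
    with ThreadPoolExecutor(max_workers=num_executors) as pool:
        partial_sums = list(pool.map(executor_task, partitions))
    return sum(partial_sums)


total = driver_sum(list(range(1, 101)))
print(total)
```

In real Spark the partitions live on different machines and the driver only ever sees the small partial results, which is what makes the approach scale.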

### Create a Spark Interactive Session

Let's begin using Spark! Here, you use HPE AI Essentials' native integration of **Apache Livy** to create and manage an interactive Spark session. Livy is an open-source REST service that enables remote and interactive analytics on Apache Spark clusters. It lets you interact with Spark clusters programmatically through a REST API, so you can submit Spark jobs, run interactive queries, and manage Spark sessions from web applications without needing a dedicated Spark client. As a result, multiple AIE users can interact with your Spark cluster concurrently and reliably!
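
As a rough illustration of what "REST service" means here, the sketch below builds (but does not send) the two HTTP requests that a client such as Sparkmagic would issue against Livy: one to create a PySpark session and one to run a statement in it. The address in `LIVY_URL` is a placeholder and the helper names are hypothetical; the single sign-on authentication used on AIE is not shown.

```python
import json
from urllib.request import Request

# Assumption: placeholder address. Use the Livy address shown in the Add Endpoint tab.
LIVY_URL = "http://livy.example.local:8998"


def build_create_session_request(name):
    """Request for POST /sessions, which asks Livy to start a PySpark session."""
    payload = {"kind": "pyspark", "name": name}
    return Request(
        f"{LIVY_URL}/sessions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def build_statement_request(session_id, code):
    """Request for POST /sessions/{id}/statements, which runs code in a session."""
    payload = {"code": code}
    return Request(
        f"{LIVY_URL}/sessions/{session_id}/statements",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


create_req = build_create_session_request("retail-demo")
stmt_req = build_statement_request(28, "spark.range(5).count()")
print(create_req.get_method(), create_req.full_url)
print(stmt_req.get_method(), stmt_req.full_url)
```

In this notebook you never write these requests yourself; the `%manage_spark` widget and the PySpark kernel send them for you.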

First, let's connect to the Livy endpoint and create a new Spark interactive session. The Spark interactive
session is particularly useful for exploratory data analysis, prototyping, and iterative development. It allows you to
interactively work with large datasets, perform transformations, apply analytical operations, and build ML models using
Spark's distributed computing capabilities. 

To communicate with Livy and manage your sessions you use Sparkmagic, an open-source tool that provides a Jupyter kernel
extension. Sparkmagic integrates with Livy, to provide the underlying communication layer between the Jupyter kernel and
the Spark cluster.

**Execute the cell below**, then:

1. Select the `Add Endpoint` tab.
1. Select `Single Sign-on` and ensure there is a Livy address in the `Address` field. 
1. Click `Add Endpoint`.
1. Select the `Create Session` tab.
1. Provide a name (e.g. `retail-demo`).
1. Select `python` under the Language field.
1. Click `Create Session` (right side).

Your session will take a few minutes to initialize.

Once ready, the Manage Sessions pane will activate, displaying
your session ID. When the session state turns to idle, you're all set!

In [1]:
%manage_spark


Now, let's check the status of the session.

1. Navigate back to the AIE dashboard.
1. In the sidebar navigation menu, select `Spark Interactive Sessions`.

![image.png](./images/exercise1/menu.PNG)

3. Here, you can check the status of your session. It will take 2-3 minutes to start. When the `State` says `Idle`, the session is ready. 

![image.png](./images/exercise1/session.PNG)

4. Scroll back up to the notebook cell that started the session (the `%manage_spark` command). Confirm under the `Manage Sessions` tab that the session now shows as `Idle` too.

![image.png](./images/exercise1/session2.PNG)
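
If you ever need to script this wait instead of watching the dashboard, the idea can be sketched as a small polling loop. The function below is hypothetical and deliberately takes the state lookup as a parameter (`get_state`), so it makes no assumption about how you query Livy; the state names follow the ones Livy reports for sessions, such as `starting`, `idle` and `dead`.

```python
import time

# Session states after which waiting any longer is pointless
TERMINAL_STATES = {"error", "dead", "killed", "shutting_down"}


def wait_until_idle(get_state, timeout_s=300, poll_interval_s=5, sleep=time.sleep):
    """Poll get_state() until it returns 'idle', fails, or the timeout expires."""
    waited = 0
    while waited <= timeout_s:
        state = get_state()
        if state == "idle":
            return state
        if state in TERMINAL_STATES:
            raise RuntimeError(f"Session ended in state '{state}'")
        sleep(poll_interval_s)
        waited += poll_interval_s
    raise TimeoutError(f"Session not idle after {timeout_s} seconds")


# Demo with a fake state source; sleeping is disabled so the demo runs instantly
fake_states = iter(["starting", "starting", "idle"])
result = wait_until_idle(lambda: next(fake_states), sleep=lambda seconds: None)
print(result)
```

Passing `sleep` as a parameter is a testing convenience; in normal use you would leave the default.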

### Configure Spark Interactive Session

1. Run the `%config_spark` magic command.
2. Leave the settings as they are. Click `Submit`.

<div class="alert alert-block alert-danger">
    <b>Important:</b> Ignore the resulting message and <b>do not</b> restart the kernel.
</div>

In [None]:
%config_spark

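Next, let's create the output directory for the Delta tables and import the required libraries for working with Spark in this notebook.

In [2]:
import os
# os functions expect a plain filesystem path, not a file:// URI
os.makedirs("/mounts/shared-volume/shared/retail-data/delta-tables/vince", exist_ok=True)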


In [34]:
import random
from datetime import datetime, timedelta
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
from delta.tables import DeltaTable

# Initialize Spark session
spark = SparkSession.builder \
    .appName("RetailDataPipeline") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .getOrCreate()

# Configuration
delta_path = "file:///mounts/shared-volume/shared/retail-data/delta-tables/vince/"


In [36]:
# List of tables to extract
tables = [
    "source_catalog",
    "source_customers",
    "source_orders",
    "source_order_products",
    "source_stock"
]

def extract_and_save_table(table_name):
    """Extract a single table from Presto and save to Delta"""
    try:
        print(f"Processing table: {table_name}")
        
        # Presto connection configuration
        # (DOMAIN, catalog, schema and user are globals that must be set in a separate cell first)
        uri = f"jdbc:presto://ezpresto.{DOMAIN}:443/{catalog}/{schema}"
        query = f"SELECT * FROM {catalog}.{schema}.{table_name}"
        
        # Read from Presto
        df = spark.read.format("jdbc") \
            .option("driver", "com.facebook.presto.jdbc.PrestoDriver") \
            .option("url", uri) \
            .option("user", user) \
            .option("SSL", "true") \
            .option("IgnoreSSLChecks", "true") \
            .option("query", query) \
            .load()
        
        # Write to Delta format
        df.write.format("delta") \
            .mode("overwrite") \
            .save(f"{delta_path}{table_name}")
            
        print(f"Successfully saved {table_name} to Delta format")
        return True
        
    except Exception as e:
        print(f"Error processing table {table_name}: {str(e)}")
        return False

# Process all tables
for table in tables:
    success = extract_and_save_table(table)
    if not success:
        print(f"Failed to process table {table}")
        # Continue with next table or break based on your requirements

print("All tables processed")


Processing table: source_catalog
Successfully saved source_catalog to Delta format
Processing table: source_customers
Successfully saved source_customers to Delta format
Processing table: source_orders
Successfully saved source_orders to Delta format
Processing table: source_order_products
Successfully saved source_order_products to Delta format
Processing table: source_stock
Successfully saved source_stock to Delta format
All tables processed
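
Note that the loop above prints "All tables processed" even if some tables failed. One possible way to make failures visible, sketched here with a hypothetical helper `run_all` and a stand-in for `extract_and_save_table`, is to collect the failed names and report them at the end.

```python
def run_all(table_names, process_table):
    """Run process_table on every table and return the names that failed."""
    failed = []
    for name in table_names:
        # process_table is expected to return True on success, False on failure
        if not process_table(name):
            failed.append(name)
    return failed


def fake_extract(table_name):
    """Stand-in for extract_and_save_table: pretends one table cannot be read."""
    return table_name != "source_stock"


failed_tables = run_all(["source_catalog", "source_orders", "source_stock"], fake_extract)
if failed_tables:
    print(f"Failed tables: {failed_tables}")
else:
    print("All tables processed successfully")
```

In the notebook you would call `run_all(tables, extract_and_save_table)` and decide whether a non-empty result should stop the pipeline.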

In [47]:
from pyspark.sql.functions import col, trim, when, lit

# Clean product names and categories
cleaned_catalog = spark.read.format("delta").load(f"{delta_path}source_catalog") \
    .withColumn("product_name", trim(col("product_name"))) \
    .withColumn("product_category", 
        when(col("product_category") == "Toyz", "Toys")
        .when(col("product_category") == "Clothng", "Clothing")
        .when(col("product_category") == "Eletronics", "Electronics")
        .otherwise(col("product_category"))) \
    .filter(col("product_id").isNotNull()) \
    .filter(col("price_cents") > 0)  # Remove zero or negative prices

cleaned_catalog.write.format("delta") \
    .mode("overwrite") \
    .save(f"{delta_path}source_catalog")
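
The cleaning rules in this cell are easy to sanity-check without a cluster. The sketch below mirrors them on plain Python dictionaries; `clean_catalog_row` is a hypothetical helper, not part of the pipeline. The first two sample rows reuse values that appear in the catalog preview in this notebook, and the third is invented to exercise the price filter.

```python
CATEGORY_FIXES = {"Toyz": "Toys", "Clothng": "Clothing", "Eletronics": "Electronics"}


def clean_catalog_row(row):
    """Return a cleaned copy of the row, or None if the row would be filtered out."""
    price = row["price_cents"]
    # Mirrors the two filters: non-null product_id and a strictly positive price
    if row["product_id"] is None or price is None or price <= 0:
        return None
    name = row["product_name"]
    category = row["product_category"]
    return {
        "product_id": row["product_id"],
        # Spark's trim() removes leading and trailing spaces; nulls stay null
        "product_name": name.strip(" ") if name is not None else None,
        # Unknown categories (including None) pass through unchanged
        "product_category": CATEGORY_FIXES.get(category, category),
        "price_cents": price,
    }


rows = [
    {"product_id": 4, "product_name": "Eletronic 4", "product_category": "Toyz", "price_cents": 3426},
    {"product_id": 5, "product_name": None, "product_category": "Clothng", "price_cents": 8863},
    {"product_id": 99, "product_name": " Sample ", "product_category": "Books", "price_cents": -1},
]
cleaned = [clean_catalog_row(r) for r in rows]
for row in cleaned:
    print(row)
```

Notice that the rules fix misspelled categories but leave misspelled product names and null names untouched, which is also what the Spark cell does.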


In [48]:
# Clean customer data
cleaned_customers = spark.read.format("delta").load(f"{delta_path}source_customers") \
    .withColumn("customer_name", trim(col("customer_name"))) \
    .withColumn("customer_surname", trim(col("customer_surname"))) \
    .withColumn("customer_email",
        when(
            (col("customer_email").contains("@")) & 
            (col("customer_email").contains(".")),
            col("customer_email")
        ).otherwise(lit(None))) \
    .filter(col("customer_id").isNotNull())

cleaned_customers.write.format("delta") \
    .mode("overwrite") \
    .save(f"{delta_path}source_customers")
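
The email check above only requires an `@` and a `.` somewhere in the string, so a value such as `@.` would pass. If you want something stricter, one option is a regular expression. The pattern and the helper name `is_plausible_email` below are assumptions for illustration, shown in plain Python so they can be tried without a Spark session.

```python
import re

# Assumption: a deliberately simple pattern, something@something.something,
# with no spaces and exactly one "@". It is not a full email validator.
EMAIL_PATTERN = r"^[^@\s]+@[^@\s]+\.[^@\s]+$"


def is_plausible_email(value):
    """Return True if value looks like an email address under EMAIL_PATTERN."""
    if value is None:
        return False
    return re.fullmatch(EMAIL_PATTERN, value) is not None


for candidate in ["jane.doe@example.com", "@.", "no-at-sign.com", None]:
    print(repr(candidate), is_plausible_email(candidate))
```

In Spark, the same pattern could be applied with `col("customer_email").rlike(EMAIL_PATTERN)` in place of the two `contains` calls.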


In [49]:
from pyspark.sql.functions import col, current_date

# Clean stock data
cleaned_stock = spark.read.format("delta").load(f"{delta_path}source_stock") \
    .filter(col("product_quantity") > 0) \
    .filter(col("entry_date") <= current_date()) \
    .filter(col("product_id").isNotNull()) \
    .filter(col("purchase_price_cents") > 0)

cleaned_stock.write.format("delta") \
    .mode("overwrite") \
    .save(f"{delta_path}source_stock")


In [50]:
from pyspark.sql.functions import current_date

# Clean orders data
cleaned_orders = spark.read.format("delta").load(f"{delta_path}source_orders") \
    .filter(col("order_date") <= current_date()) \
    .filter(col("customer_id").isNotNull()) \
    .withColumn("order_status",
        when(col("order_status").isin(["completed", "pending", "cancelled", "shipped"]),
            col("order_status")
        ).otherwise(lit("pending")))

cleaned_orders.write.format("delta") \
    .mode("overwrite") \
    .save(f"{delta_path}source_orders")


An error was encountered:
Invalid status code '400' from http://livy-0.livy-svc.spark.svc.cluster.local:8998/sessions/28/statements/44 with error payload: {"msg":"requirement failed: Session isn't active."}


In [51]:
# Clean order products data
cleaned_order_products = spark.read.format("delta").load(f"{delta_path}source_order_products") \
    .filter(col("product_quantity") > 0) \
    .filter(col("order_id").isNotNull()) \
    .filter(col("product_id").isNotNull())

cleaned_order_products.write.format("delta") \
    .mode("overwrite") \
    .save(f"{delta_path}source_order_products")

An error was encountered:
Session 28 unexpectedly reached final status 'dead'. See logs:
stdout: 

stderr: 
	at org.apache.logging.log4j.LogManager.getContext(LogManager.java:157)
	at org.apache.spark.internal.Logging$.islog4j2DefaultConfigured(Logging.scala:258)
	at org.apache.spark.internal.Logging.initializeLogging(Logging.scala:133)
	at org.apache.spark.internal.Logging.initializeLogIfNecessary(Logging.scala:114)
	at org.apache.spark.internal.Logging.initializeLogIfNecessary$(Logging.scala:108)
	at org.apache.spark.deploy.SparkSubmit.initializeLogIfNecessary(SparkSubmit.scala:76)
	at org.apache.spark.deploy.SparkSubmit.doSubmit(SparkSubmit.scala:84)
	at org.apache.spark.deploy.SparkSubmit$$anon$2.doSubmit(SparkSubmit.scala:1137)
	at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:1146)
	at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)
ERROR StatusConsoleListener Could not create plugin of type class org.apache.logging.log4j.core.async.AsyncLoggerConfig for

# Now add fancy text and go back to presto to add the delta table there

In [6]:
%%sql 
SELECT * FROM retailvince.public.source_catalog

Hey, The code failed because of a fatal error:
	Error sending http request and maximum retry encountered..

Some things to try:
a) Make sure Spark has enough available resources for Jupyter to create a Spark context.
b) Contact your Jupyter administrator to make sure the Spark magics library is configured correctly.
c) Restart the kernel.


In [32]:
DOMAIN = "hpepcai.ezmeral.demo.local"
user = "admin-901d042c"
catalog = "retailvince"
schema = "public"
table = "source_catalog"
# The JDBC URL takes the form host:port/catalog/schema
uri = f"jdbc:presto://ezpresto.{DOMAIN}:443/{catalog}/{schema}"
query = f"select * from {catalog}.{schema}.{table}"
df = spark.read.format("jdbc") \
    .option("driver", "com.facebook.presto.jdbc.PrestoDriver") \
    .option("url", uri) \
    .option("user", user) \
    .option("SSL", "true") \
    .option("IgnoreSSLChecks", "true") \
    .option("query", query) \
    .load()
# show() returns None, so call it separately to keep the DataFrame in df
df.show()


+----------+------------+----------------+-----------+
|product_id|product_name|product_category|price_cents|
+----------+------------+----------------+-----------+
|         1|Electronic 1|           Books|       3980|
|         2|    Clothn 2|      Home Decor|       4471|
|         3|       Toy 3|        Clothing|       5279|
|         4| Eletronic 4|            Toyz|       3426|
|         5|        NULL|         Clothng|       8863|
|         6|   Clothin 6|            Toyz|       3410|
|         7| Home Deco 7|        Clothing|       1195|
|         8|    Clothn 8|            NULL|       3458|
|         9|       Toy 9|        Clothing|       7435|
|        10|  Clothin 10|            Toys|       6627|
|        11|Eletronic 11|            Toys|       2758|
|        13|        NULL|         Clothng|       8983|
|        14|      Toy 14|            NULL|       7845|
|        15|Eletronic 15|      Home Decor|       8619|
|        16|      Toy 16|         Clothng|       9183|
|        1

In [15]:
dfReader = spark.read.format("EzPresto")
dfReader.option("presto_url", "https://ezpresto-sts-mst-0.admin-901d042c.svc.cluster.local:8081")
dfReader.option("dal_url", "https://ezpresto-sts-mst-0.admin-901d042c.svc.cluster.local:9090")
dfReader.option("ignore_ssl_check", "true")
dfReader.option("query", "SELECT * FROM retailvince.public.source_catalog")

df = dfReader.load()
df.show()



An error was encountered:
An error occurred while calling o155.load.
: com.ezsql.sparkconnector.utils.dal.DalException: DalClient failed to extract schema
	at com.ezsql.sparkconnector.utils.dal.DalClient.extractSchema(DalClient.java:41)
	at com.ezsql.sparkconnector.PrestoTable.<init>(PrestoTable.java:30)
	at com.ezsql.sparkconnector.Presto.getTable(Presto.java:41)
	at org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils$.getTableFromProvider(DataSourceV2Utils.scala:92)
	at org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils$.loadV2Source(DataSourceV2Utils.scala:140)
	at org.apache.spark.sql.DataFrameReader.$anonfun$load$1(DataFrameReader.scala:210)
	at scala.Option.flatMap(Option.scala:271)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:208)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:172)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodA

In [9]:
!cat /opt/conda/envs/sparkmagic/lib/python3.11/site-packages/sparkmagic/magics/remotesparkmagics.py


An error was encountered:
invalid syntax (<stdin>, line 1)
  File "<stdin>", line 1
    !cat /opt/conda/envs/sparkmagic/lib/python3.11/site-packages/sparkmagic/magics/remotesparkmagics.py
    ^
SyntaxError: invalid syntax



In [None]:
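# Port 443 and the user/SSL options mirror the JDBC cell that ran successfully above
presto_df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:presto://ezpresto.hpepcai.ezmeral.demo.local:443/retailvince/public") \
    .option("dbtable", "source_catalog") \
    .option("driver", "com.facebook.presto.jdbc.PrestoDriver") \
    .option("user", user) \
    .option("SSL", "true") \
    .option("IgnoreSSLChecks", "true") \
    .load()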


In [None]:
def extract_sql_tables():
    """Extract data from SQL tables using spark.sql queries"""
    print("Extracting data from SQL tables...")

    # Extract each table into a DataFrame
    catalog_df = spark.sql("SELECT * FROM retailvince.public.source_catalog")
    customers_df = spark.sql("SELECT * FROM retailvince.public.source_customers")
    orders_df = spark.sql("SELECT * FROM retailvince.public.source_orders")
    order_products_df = spark.sql("SELECT * FROM retailvince.public.source_order_products")
    stock_df = spark.sql("SELECT * FROM retailvince.public.source_stock")

    return {
        "catalog": catalog_df,
        "customers": customers_df,
        "orders": orders_df,
        "order_products": order_products_df,
        "stock": stock_df
    }

In [None]:
def create_delta_tables(dataframes):
    """Create Delta tables from the extracted data"""
    print("Creating Delta tables...")
    
    # Write each DataFrame to Delta format
    dataframes["catalog"].write.format("delta").mode("overwrite").save(f"{delta_path}source_catalog")
    dataframes["customers"].write.format("delta").mode("overwrite").save(f"{delta_path}source_customers")
    dataframes["orders"].write.format("delta").mode("overwrite").save(f"{delta_path}source_orders")
    dataframes["order_products"].write.format("delta").mode("overwrite").save(f"{delta_path}source_order_products")
    dataframes["stock"].write.format("delta").mode("overwrite").save(f"{delta_path}source_stock")
    
    print("Delta tables created successfully")

In [None]:
def transform_data():
    """Transform the data in Delta tables"""
    print("Transforming data...")
    
    # Example transformation: Clean product names in catalog
    catalog_delta = DeltaTable.forPath(spark, f"{delta_path}source_catalog")
    catalog_df = catalog_delta.toDF()
    
    # Apply transformations
    cleaned_catalog = catalog_df.withColumn(
        "product_name_clean",
        trim(col("product_name"))
    )
    
    # Write transformed data back
    cleaned_catalog.write.format("delta").mode("overwrite").save(f"{delta_path}source_catalog")
    
    print("Data transformation completed")

In [None]:
def validate_data():
    """Validate the data quality"""
    print("Validating data...")
    
    # Example validation: Check for null product names
    null_names = spark.sql(f"""
    SELECT COUNT(*) AS null_names_count 
    FROM delta.`{delta_path}source_catalog` 
    WHERE product_name IS NULL OR TRIM(product_name) = ''
    """).collect()[0]["null_names_count"]
    
    print(f"Found {null_names} products with null or empty names")
    
    # Add more validations as needed

In [None]:
def main():
    """Main pipeline execution"""
    try:
        # Step 1: Extract data from SQL tables
        dataframes = extract_sql_tables()
        
        # Step 2: Create Delta tables
        create_delta_tables(dataframes)
        
        # Step 3: Transform data
        transform_data()
        
        # Step 4: Validate data
        validate_data()
        
        print("Pipeline executed successfully")
    except Exception as e:
        print(f"Pipeline failed: {str(e)}")
        raise

if __name__ == "__main__":
    main()

In [None]:
def create_delta_tables():
    """Create empty Delta tables with proper schema"""
    # Create products table
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS delta.`{delta_path}source_catalog` (
        product_id INT,
        product_name STRING,
        product_category STRING,
        price_cents INT
    ) USING DELTA
    """)
    
    # Create customers table
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS delta.`{delta_path}source_customers` (
        customer_id INT,
        customer_name STRING,
        customer_surname STRING,
        customer_email STRING
    ) USING DELTA
    """)
    
    # Create stock table
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS delta.`{delta_path}source_stock` (
        entry_id INT,
        product_id INT,
        product_quantity INT,
        purchase_price_cents INT,
        entry_date DATE
    ) USING DELTA
    """)
    
    # Create orders table
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS delta.`{delta_path}source_orders` (
        order_id INT,
        customer_id INT,
        order_status STRING,
        order_date DATE
    ) USING DELTA
    """)
    
    # Create order products table
    spark.sql(f"""
    CREATE TABLE IF NOT EXISTS delta.`{delta_path}source_order_products` (
        transaction_id INT,
        order_id INT,
        product_id INT,
        product_quantity INT
    ) USING DELTA
    """)
    
    print("Delta tables created successfully")

def add_constraints():
    """Add constraints and metadata after table creation"""
    # For Delta Lake, we can add constraints as table properties or enforce them through application logic
    # This is a placeholder for where you would implement your constraint logic
    print("Constraints would be enforced here through application logic")

    # Example: You could create a function that validates data before writing
    # or use Delta Lake's data validation features

In [None]:
def load_raw_data_to_delta():
    """Load generated sample data to Delta tables"""
    data = generate_sample_data()
    
    # Convert each dataset to Spark DataFrame and write to Delta
    spark.createDataFrame(data["products"]).write \
        .format("delta") \
        .mode("append") \
        .save(f"{delta_path}source_catalog")
    
    spark.createDataFrame(data["customers"]).write \
        .format("delta") \
        .mode("append") \
        .save(f"{delta_path}source_customers")
    
    # Similarly for other tables...
    print("Sample data loaded to Delta tables")

In [None]:
from pyspark.sql.functions import col, expr, trim

def transform_data():
    """Apply transformations to the raw data"""
    # Example transformation: Clean product names
    products_df = DeltaTable.forPath(spark, f"{delta_path}source_catalog").toDF()
    
    # trim() strips leading and trailing spaces (Column has no rtrim()/ltrim() methods)
    cleaned_products = products_df.withColumn(
        "product_name_clean",
        trim(col("product_name"))
    ).drop("product_name").withColumnRenamed("product_name_clean", "product_name")
    
    # Example transformation: Standardize categories
    category_mapping = {
        "Toyz": "Toys",
        "Clothng": "Clothing",
        "Eletronics": "Electronics"
    }
    
    mapping_expr = "CASE "
    for wrong, correct in category_mapping.items():
        mapping_expr += f"WHEN product_category = '{wrong}' THEN '{correct}' "
    mapping_expr += "ELSE product_category END"
    
    cleaned_products = cleaned_products.withColumn(
        "product_category_clean",
        expr(mapping_expr)
    ).drop("product_category").withColumnRenamed("product_category_clean", "product_category")
    
    # Save transformed data to new Delta table
    cleaned_products.write \
        .format("delta") \
        .mode("overwrite") \
        .save(f"{delta_path}transformed_catalog")
    
    print("Data transformations completed")
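
To see what the loop inside `transform_data` actually passes to `expr()`, the string-building step can be run on its own in plain Python. The helper name `build_case_expr` is hypothetical; the sketch assumes the column name and the mapping values contain no quote characters, since they are interpolated directly into SQL.

```python
def build_case_expr(column, mapping):
    """Build a SQL CASE expression that rewrites known misspellings of a column."""
    clauses = " ".join(
        f"WHEN {column} = '{wrong}' THEN '{correct}'"
        for wrong, correct in mapping.items()
    )
    # Values not listed in the mapping fall through unchanged
    return f"CASE {clauses} ELSE {column} END"


category_mapping = {
    "Toyz": "Toys",
    "Clothng": "Clothing",
    "Eletronics": "Electronics",
}
case_sql = build_case_expr("product_category", category_mapping)
print(case_sql)
```

Printing the generated expression before handing it to Spark is a cheap way to catch a missing space or quote.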

In [None]:
def validate_data():
    """Run data quality checks"""
    # Example validation: Check for null product names
    null_names = spark.sql(f"""
    SELECT COUNT(*) AS null_names_count 
    FROM delta.`{delta_path}source_catalog` 
    WHERE product_name IS NULL OR TRIM(product_name) = ''
    """).collect()[0]["null_names_count"]
    
    print(f"Found {null_names} products with null or empty names")
    
    # Add more validations as needed

In [None]:
"""Orchestrate the data pipeline"""
try:
    # Step 1: Create Delta tables
    create_delta_tables()
    
    # Step 2: Load raw data
    load_raw_data_to_delta()
    
    # Step 3: Validate raw data
    validate_data()
    
    # Step 4: Transform data
    transform_data()
    
    # Step 5: Validate transformed data
    validate_data()
    
    print("Pipeline executed successfully")
except Exception as e:
    print(f"Pipeline failed: {str(e)}")
    raise

In [None]:
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from pyspark.sql.functions import udf, col, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
import os

We will also define the paths that Spark will read files from and save files to. These paths are specific to the AIE directory structure and should be left as they are.

In [None]:
file_root = "file:///mounts/shared-volume/shared/retail-data/raw-data"
delta_root = "file:///mounts/shared-volume/shared/retail-data/delta-tables/"
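
Both roots are `file://` URIs. Spark understands them, but Python's `os` functions expect a plain filesystem path and would treat the `file:` prefix as a directory name. A small conversion helper, sketched here under the hypothetical name `to_local_path`, avoids that mismatch.

```python
from urllib.parse import urlparse


def to_local_path(uri):
    """Convert a file:// URI (or an already plain path) to a local filesystem path."""
    parsed = urlparse(uri)
    if parsed.scheme not in ("", "file"):
        raise ValueError(f"Not a local file URI: {uri}")
    return parsed.path


local_delta_root = to_local_path("file:///mounts/shared-volume/shared/retail-data/delta-tables/")
print(local_delta_root)
```

Use the URI form when calling Spark readers and writers, and the converted path for calls such as `os.listdir` or `os.makedirs`.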

You can now instantiate the Spark session. We'll add delta extensions to the configuration to be able to interact with the delta tables.

In [None]:
# Set up the Spark session
spark = SparkSession.builder \
    .appName("DataCleaningWithSpark") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .config("spark.local.dir", "/mnt/shared/end2end-main-exercises/exercises") \
    .getOrCreate()

print("Pyspark session started")

## **2. Generating and Preparing Sales Data**

In this section, we are going to synthetically generate several years of sales data from our three retail stores located in three countries: Switzerland, Germany and the Czech Republic. This sales data will provide the basis for the remaining exercises, where we will learn to analyze, graph and build dashboards to gather insights between and across regions. 

**Optional:** To use `Data Sources` connected through AIE (such as MySQL, MariaDB and PostgreSQL databases), follow **this**.


### Generating Sales Data

A Python script has been provided which can generate the sales data for the three given country locations. 

The parameters for this script are:

- c: Country in which the stores are located.
- cu: Currency, to account for conversions between stores.
- s: Number of stores in that region.
- sy: Start year.
- ey: End year.
- csv: Name of the resulting CSV file.
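
The source of `create_csv.py` is not shown in this notebook, so the following is only a guess at how a script could accept these flags with `argparse`. The `dest` names, types and defaults are all assumptions made for illustration.

```python
import argparse


def build_parser():
    """Hypothetical parser accepting the same flags as the commands in this section."""
    parser = argparse.ArgumentParser(description="Generate synthetic sales data")
    parser.add_argument("-c", dest="country", required=True, help="Country name")
    parser.add_argument("-cu", dest="currency", required=True, help="Currency code")
    parser.add_argument("-s", dest="stores", type=int, default=1, help="Number of stores")
    parser.add_argument("-sy", dest="start_year", type=int, required=True, help="Start year")
    parser.add_argument("-ey", dest="end_year", type=int, required=True, help="End year")
    parser.add_argument("-csv", dest="csv_name", required=True, help="Output file name")
    return parser


args = build_parser().parse_args(
    ["-c", "Germany", "-cu", "EUR", "-s", "5", "-sy", "2019", "-ey", "2023",
     "-csv", "germany_sales_data_2019_2023.csv"]
)
print(args.country, args.currency, args.stores, args.start_year, args.end_year, args.csv_name)
```

Reading the flags this way shows why each command below passes the country, currency, store count, year range and file name together.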

We'll see the first 10 rows of the newly created table. 

In [None]:
%run resources/create_csv.py -c "Germany" -cu EUR -s 5 -sy 2019 -ey 2023 -csv "germany_sales_data_2019_2023.csv" 

In [None]:
%run resources/create_csv.py -c "Czech Republic" -cu CZK -s 5 -sy 2019 -ey 2023 -csv "czech_sales_data_2019_2023.csv"

In [None]:
%run resources/create_csv.py -c "Swiss" -cu CHF -s 5 -sy 2019 -ey 2023 -csv "swiss_sales_data_2019_2023.csv"

Next, we'll ensure that our Spark Interactive session can access the data.

In [None]:
# Define the directory path
data_path = file_root

# List files in the directory
files = spark.sparkContext.wholeTextFiles(data_path)

# Display the list of files
for file_path, _ in files.collect():
    print(file_path)

## **3. Create Delta Tables**

In this section, we will create Delta Tables from our CSV files that we can query using AIE. Delta Tables are tables stored in Delta Lake, an open-source storage layer that keeps the data in Apache Parquet files and adds a transaction log on top.

### Define an ETL Pipeline to create Delta Tables 

First, let's define some functions that will:

1. Load the data from a CSV file and return a Spark DataFrame.

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

def load_data(spark, country, data_path):
    # Define the path to the CSV file
    csv_path = f"{data_path}/{country}_sales_data_2019_2023.csv"

    # Define the schema with specific data types
    schema = StructType([
        StructField("PRODUCTID", IntegerType(), True),
        StructField("PRODUCT", StringType(), True),
        StructField("TYPE", StringType(), True),
        StructField("UNITPRICE", DoubleType(), True),
        StructField("UNIT", StringType(), True),
        StructField("QTY", IntegerType(), True),
        StructField("TOTALSALES", DoubleType(), True),
        StructField("CURRENCY", StringType(), True),
        StructField("STORE", StringType(), True),
        StructField("COUNTRY", StringType(), True),
        StructField("YEAR", IntegerType(), True)
    ])

    # Read data from the CSV file with the specified schema
    df = spark.read \
        .format("csv") \
        .schema(schema) \
        .option("header", "true") \
        .load(csv_path)

    return df

2. Clean the data, in this case by converting the total sales of each item to euros.

In [None]:
def clean_data(df, spark, country):
    # Define a UDF to convert currencies to EUR
    convert_udf = udf(lambda currency, amount: amount / CZK_TO_EUR_RATE if currency == "CZK" else amount / CHF_TO_EUR_RATE if currency == "CHF" else amount, DoubleType())

    # Apply the UDFs to the DataFrame
    corrected_df = df.withColumn("totalsales", convert_udf(col("currency"), col("totalsales"))) \
                     .withColumn("currency", lit("EUR"))

    # Show the results
    corrected_df.show()

    return corrected_df

3. Save the data as Delta Tables (Parquet files plus a transaction log).

In [None]:
def write_data(df, country):
    delta_path = delta_root + country

    # delta_path is a file:// URI for Spark; os functions need the plain local path
    local_path = delta_path.replace("file://", "", 1)

    # Create the directory if it doesn't exist yet
    os.makedirs(local_path, exist_ok=True)

    df.write.format("delta").mode("overwrite").save(delta_path)

Great! We've just created functions that will **extract** the data from our generated CSV files, **transform** them into Delta Tables with the currency standardized, then **load** them into a new directory.

You guessed it! We have just created an **ETL pipeline!** 

After declaring our country list and our currency conversion rates, we can run the pipeline.

In [None]:
# Constants
COUNTRY_LIST = ["czech", "germany", "swiss"]
CZK_TO_EUR_RATE = 25
CHF_TO_EUR_RATE = 1
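
Each constant states how many units of the foreign currency correspond to one euro, so converting to euros is a division. The plain-Python sketch below mirrors the arithmetic of the UDF in `clean_data`; the helper `to_eur` and the `RATES_PER_EUR` dictionary are illustrative names, and the CHF rate of 1 is the tutorial's own simplification.

```python
# Units of each currency per 1 EUR, matching the constants in the cell above
RATES_PER_EUR = {"EUR": 1, "CZK": 25, "CHF": 1}


def to_eur(currency, amount):
    """Convert an amount to euros; unknown currencies are returned unchanged."""
    if amount is None:
        return None
    return amount / RATES_PER_EUR.get(currency, 1)


print(to_eur("CZK", 250.0))
print(to_eur("CHF", 40.0))
print(to_eur("EUR", 12.5))
```

For example, a sale of 250 CZK becomes 250 / 25 = 10 EUR, while EUR and CHF amounts keep their value.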

<div class="alert alert-block alert-warning">
<b>Hint:</b> As you can tell by the parameters to the create_csv.py functions in Section 2, we can synthetically generate data for as many stores in as many European countries as we want! Feel free to experiment, so long as the countries are declared in the cell above <b>and the countries that are already there remain.</b>
</div>

In [None]:
for country in COUNTRY_LIST:
    # Load data from the CSV files
    df = load_data(spark, country, data_path)
    df.show()
    
    # Clean the data
    cleaned_df = clean_data(df, spark, country)
    cleaned_df.printSchema()
    
    # Write the cleaned data back to the Delta Table
    write_data(cleaned_df, country)

Now, we'll confirm the Delta Tables were created correctly.

In [None]:
for country in COUNTRY_LIST:
    # List files in the table directory (os.listdir needs the path without the file:// scheme)
    selected_country_path = (delta_root + country).replace("file://", "", 1)
    files = os.listdir(selected_country_path)
    print("Table:", country)
    
    for file in files:
        if file.endswith(".parquet"):
            full_path = os.path.join(selected_country_path, file)
            print("Saved in:", full_path)

    print()
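
Besides the Parquet files, each Delta table directory contains a `_delta_log` folder in which every write is recorded as a numbered JSON commit file. This log is what makes version control possible. The sketch below lists the versions found in such a folder; `list_delta_versions` is a hypothetical helper, and the demo runs against a temporary directory with empty stand-in files rather than a real table.

```python
import os
import re
import tempfile

# Commit files are named with a 20-digit, zero-padded version number
COMMIT_FILE = re.compile(r"^(\d{20})\.json$")


def list_delta_versions(table_dir):
    """Return the sorted version numbers recorded in table_dir/_delta_log."""
    log_dir = os.path.join(table_dir, "_delta_log")
    versions = []
    for name in os.listdir(log_dir):
        match = COMMIT_FILE.match(name)
        if match:
            versions.append(int(match.group(1)))
    return sorted(versions)


# Demo: build a fake table directory with two commits and one unrelated file
with tempfile.TemporaryDirectory() as table_dir:
    log_dir = os.path.join(table_dir, "_delta_log")
    os.makedirs(log_dir)
    for name in ("00000000000000000000.json", "00000000000000000001.json", "notes.txt"):
        open(os.path.join(log_dir, name), "w").close()
    demo_versions = list_delta_versions(table_dir)

print(demo_versions)
```

On the real tables, `DeltaTable.forPath(spark, path).history()` reports the same versions with timestamps and operations, and an earlier version can be read back with `spark.read.format("delta").option("versionAsOf", 0).load(path)`.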

# **Conclusion**

In this exercise, you learned to perform the basics of data engineering - all within a single notebook! 

**HPE AI Essentials** makes this possible by natively including the most widely used open-source data tools and frameworks and making them available out of the box. As a result, you spent this time on invaluable data preparation for the upcoming exercises instead of hours installing and connecting tools!

In the next exercise, you will learn how to use EzPresto on HPE AI Essentials to prepare these datasets for visualization and modelling. 

