# Part 2: Streaming application using Spark Structured Streaming

### 1.	Write code to create a SparkSession, which 1) uses four cores with a proper application name; 2) use the Melbourne timezone; 3) ensure a checkpoint location has been set.


In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import (
    col, from_json, from_unixtime, window, current_timestamp, 
    expr, year, month, dayofmonth, hour, minute, when, explode, 
    current_date, sum as spark_sum
)
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, 
    FloatType, DoubleType, TimestampType, BooleanType, ArrayType
)
from pyspark.sql.window import Window
from pyspark.ml import PipelineModel
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
import os


In [3]:
os.chdir('..')

In [3]:
# Initialize the Spark session with the required configurations
# Running in local mode with 4 cores, allocating 12GB memory for processing
spark = SparkSession.builder \
    .appName("KafkaSparkStreaming") \
    .master("local[6]") \
    .config("spark.executor.memory", "12g") \
    .config("spark.sql.session.timeZone", "Australia/Melbourne") \
    .config("spark.sql.streaming.checkpointLocation", "/tmp/spark-checkpoint") \
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0") \
    .getOrCreate()

# Verify if the Spark session is configured correctly
print("Spark Session Configuration:")
print(f"App Name: {spark.sparkContext.appName}")  # Display the application name
print(f"Master: {spark.sparkContext.master}")  # Verify where the Spark application is running
print(f"Timezone: {spark.conf.get('spark.sql.session.timeZone')}")  # Check if timezone is set correctly
print(f"Checkpoint Location: {spark.conf.get('spark.sql.streaming.checkpointLocation')}")  # Confirm the checkpoint path


Spark Session Configuration:
App Name: KafkaSparkStreaming
Master: local[6]
Timezone: Australia/Melbourne
Checkpoint Location: /tmp/spark-checkpoint


### 2.	Similar to assignment 2A, write code to define the data schema for the data files, following the data types suggested in the metadata file. Load the static datasets (e.g. customer, product, category) into data frames. (You can use your code from 2A.)



In [4]:
# Defining the schema for Customer.csv
customer_schema = StructType([
    StructField("customer_id", IntegerType(), nullable=False),  # ID for each customer (must be unique)
    StructField("first_name", StringType(), nullable=True),  # Customer's first name (optional)
    StructField("last_name", StringType(), nullable=True),  # Customer's last name (optional)
    StructField("username", StringType(), nullable=True),  # Username chosen by the customer
    StructField("email", StringType(), nullable=True),  # Customer's email (could be useful for contact)
    StructField("gender", StringType(), nullable=True),  # Gender of the customer
    StructField("birthdate", StringType(), nullable=True),  # Customer's birthdate (for age calculations)
    StructField("first_join_date", StringType(), nullable=True)  # The date when the customer first joined
])

# Defining the schema for Category.csv
category_schema = StructType([
    StructField("category_id", IntegerType(), nullable=False),  # Unique ID representing each category
    StructField("cat_level1", StringType(), nullable=True),  # Top-level category (e.g., Electronics)
    StructField("cat_level2", StringType(), nullable=True),  # Sub-category (e.g., Mobile Phones)
    StructField("cat_level3", StringType(), nullable=True)  # Further breakdown (e.g., Accessories)
])

# Defining the schema for Browsing_behaviour.csv
browsing_behaviour_schema = StructType([
    StructField("session_id", StringType(), nullable=False),  # Unique session ID per browsing session
    StructField("event_type", StringType(), nullable=True),  # The action user took (e.g., view, click)
    StructField("event_time", TimestampType(), nullable=True),  # When the event occurred
    StructField("traffic_source", StringType(), nullable=True),  # How user found the website (e.g., Google)
    StructField("device_type", StringType(), nullable=True)  # Type of device used (e.g., mobile, desktop)
])

# Defining the schema for Product.csv
product_schema = StructType([
    StructField("id", IntegerType(), nullable=False),  # ID for each product (primary key)
    StructField("gender", StringType(), nullable=True),  # Intended gender for the product (if any)
    StructField("baseColour", StringType(), nullable=True),  # Main color of the product
    StructField("season", StringType(), nullable=True),  # Product’s season (e.g., Winter collection)
    StructField("year", IntegerType(), nullable=True),  # Year the product was released or made
    StructField("usage", StringType(), nullable=True),  # Primary use case of the product (optional)
    StructField("productDisplayName", StringType(), nullable=True),  # Display name shown to users
    StructField("category_id", IntegerType(), nullable=True)  # Reference to product's category
])

# Defining the schema for Transaction.csv
transaction_schema = StructType([
    StructField("created_at", TimestampType(), nullable=False),  # Timestamp of the transaction
    StructField("customer_id", IntegerType(), nullable=False),  # Customer ID involved in the transaction
    StructField("transaction_id", StringType(), nullable=False),  # Unique ID for each transaction
    StructField("session_id", StringType(), nullable=True),  # Session linked to the transaction
    StructField("product_metadata", StringType(), nullable=True),  # Metadata for products bought
    StructField("payment_method", StringType(), nullable=True),  # Payment method used (e.g., card, PayPal)
    StructField("payment_status", StringType(), nullable=True),  # Status (e.g., successful, failed)
    StructField("promo_amount", FloatType(), nullable=True),  # Discount applied (if any)
    StructField("promo_code", StringType(), nullable=True),  # Code used for promotion
    StructField("shipment_fee", FloatType(), nullable=True),  # Fee for shipping
    StructField("shipment_location_lat", DoubleType(), nullable=True),  # Latitude of the shipping address
    StructField("shipment_location_long", DoubleType(), nullable=True),  # Longitude of the shipping address
    StructField("total_amount", FloatType(), nullable=True),  # Total value of the transaction
    StructField("clear_payment", BooleanType(), nullable=True)  # Whether the payment was completed successfully
])

# Defining the schema for Customer_session.csv
customer_session_schema = StructType([
    StructField("session_id", StringType(), nullable=False),  # ID representing the browsing session
    StructField("customer_id", IntegerType(), nullable=False)  # Which customer was in the session
])

# Defining the schema for Fraud_transaction.csv
fraud_transaction_schema = StructType([
    StructField("transaction_id", StringType(), nullable=False),  # Transaction ID involved in fraud
    StructField("is_fraud", BooleanType(), nullable=False)  # Whether the transaction was fraudulent or not
])


In [5]:
print(os.getcwd())
curr_dir = os.getcwd() + '/A2/dataset/dataset'

# Reading in all CSV files with their respective schemas
cust_data = spark.read.csv(curr_dir + "/customer.csv", header=True, schema=customer_schema)
cat_data = spark.read.csv(curr_dir + "/category.csv", header=True, schema=category_schema)
browse_data = spark.read.csv(curr_dir + "/browsing_behaviour.csv", header=True, schema=browsing_behaviour_schema)
prod_data = spark.read.csv(curr_dir + "/product.csv", header=True, schema=product_schema)
trans_data = spark.read.csv(curr_dir + "/transactions.csv", header=True, schema=transaction_schema)
session_data = spark.read.csv(curr_dir + "/customer_session.csv", header=True, schema=customer_session_schema)
fraud_data = spark.read.csv(curr_dir + "/fraud_transaction.csv", header=True, schema=fraud_transaction_schema)

# Printing the schema for each DataFrame to make sure they loaded correctly
print("Customer dataset schema:")
cust_data.printSchema()

print("\nCategory dataset schema:")
cat_data.printSchema()

print("\nBrowsing Behaviour dataset schema:")
browse_data.printSchema()

print("\nProduct dataset schema:")
prod_data.printSchema()

print("\nTransaction dataset schema:")
trans_data.printSchema()

print("\nCustomer Session dataset schema:")
session_data.printSchema()

print("\nFraud Transaction dataset schema:")
fraud_data.printSchema()


/home/student/BIGDATA/LABORATORY/ASSIGNMENTS
Customer dataset schema:
root
 |-- customer_id: integer (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)
 |-- username: string (nullable = true)
 |-- email: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- birthdate: string (nullable = true)
 |-- first_join_date: string (nullable = true)


Category dataset schema:
root
 |-- category_id: integer (nullable = true)
 |-- cat_level1: string (nullable = true)
 |-- cat_level2: string (nullable = true)
 |-- cat_level3: string (nullable = true)


Browsing Behaviour dataset schema:
root
 |-- session_id: string (nullable = true)
 |-- event_type: string (nullable = true)
 |-- event_time: timestamp (nullable = true)
 |-- traffic_source: string (nullable = true)
 |-- device_type: string (nullable = true)


Product dataset schema:
root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- baseColour: string (null

### 3. Using the Kafka topics from the producer in Task 1, ingest the streaming data into Spark Streaming, assuming all data comes in the String format. Except for the 'ts' column, you shall receive it as an Int type.


In [6]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_json
from pyspark.sql.types import StructField, StructType, StringType, IntegerType, FloatType

# Set Kafka broker configurations
hostip = "kafka"  # Replace this with the Kafka broker's IP or hostname if needed
browsing_topic = 'BrowsingBehaviour-Topic'  # Topic name for browsing events
transaction_topic = 'Transaction-Topic'  # Topic name for transaction events

# Define the schema for browsing data
browsing_schema = StructType([
    StructField("session_id", StringType(), True),  # Unique ID for each session
    StructField("event_type", StringType(), True),  # Type of event (click, view, etc.)
    StructField("event_time", StringType(), True),  # Time when the event occurred (as string)
    StructField("traffic_source", StringType(), True),  # Source of the traffic (e.g., Google)
    StructField("device_type", StringType(), True),  # Type of device used (e.g., mobile, desktop)
    StructField("customer_id", StringType(), True),  # ID of the customer, if available
    StructField("ts", IntegerType(), True)  # Unix timestamp for event time
])

# Define the schema for transaction data
transaction_schema = StructType([
    StructField("created_at", StringType(), True),  # Transaction creation time (as string)
    StructField("customer_id", StringType(), True),  # ID of the customer making the purchase
    StructField("transaction_id", StringType(), True),  # Unique ID for each transaction
    StructField("session_id", StringType(), True),  # Session linked to the transaction
    StructField("product_metadata", StringType(), True),  # Product details (as string)
    StructField("payment_method", StringType(), True),  # Payment method used
    StructField("payment_status", StringType(), True),  # Status of the payment
    StructField("promo_amount", StringType(), True),  # Amount discounted via promo code
    StructField("promo_code", StringType(), True),  # Promo code used (if any)
    StructField("shipment_fee", StringType(), True),  # Shipping fee for the order
    StructField("shipment_location_lat", StringType(), True),  # Latitude of delivery address
    StructField("shipment_location_long", StringType(), True),  # Longitude of delivery address
    StructField("total_amount", StringType(), True),  # Total transaction amount
    StructField("clear_payment", StringType(), True),  # Whether payment was cleared
    StructField("ts", IntegerType(), True)  # Unix timestamp for the transaction
])

# Read browsing data from Kafka
browsing_stream_df = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", f'{hostip}:9092') \
    .option("subscribe", browsing_topic) \
    .option("startingOffsets", "earliest") \
    .load()

# Read transaction data from Kafka
transaction_stream_df = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", f'{hostip}:9092') \
    .option("subscribe", transaction_topic) \
    .option("startingOffsets", "earliest") \
    .load()

# Parse browsing data with the schema
browsing_stream_parsed_df = browsing_stream_df \
    .selectExpr("CAST(value AS STRING)") \
    .select(from_json(col("value"), browsing_schema).alias("browsing_data")) \
    .select("browsing_data.*")

# Parse transaction data with the schema
transaction_stream_parsed_df = transaction_stream_df \
    .selectExpr("CAST(value AS STRING)") \
    .select(from_json(col("value"), transaction_schema).alias("transaction_data")) \
    .select("transaction_data.*")

# Print a success message to confirm dataframes creation
print("Successfully created DataFrames for Browsing and Transaction streams!")


Successfully created DataFrames for Browsing and Transaction streams!


In [7]:
# Stream browsing data to console
browsing_query = browsing_stream_parsed_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()

# Stream transaction data to console
transaction_query = transaction_stream_parsed_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()

print("Streaming browsing and transaction data to console...")


Streaming browsing and transaction data to console...


In [8]:

# To stop manually, you can call:
browsing_query.stop()
transaction_query.stop()

print("Browsing and Transaction streaming processes stopped.")

Browsing and Transaction streaming processes stopped.


### 4.	Then, the streaming data format should be transformed into the proper formats following the metadata file schema, similar to assignment 2A. Perform the following tasks:  
a)	For the 'ts' column, convert it to the timestamp format, we will use it as event_ts.  
b)	If the data is late for more than 2 minutes, discard it.  


In [7]:
# Convert 'ts' (Unix timestamp) to 'event_ts' and apply watermark for browsing data
browsing_stream_parsed_df = browsing_stream_parsed_df \
    .withColumn("event_ts", from_unixtime(col("ts")).cast("timestamp")) \
    .withWatermark("event_ts", "10 minutes")  # Watermark to allow streaming joins

# Convert 'ts' to 'created_ts' and apply watermark for transaction data
transaction_stream_parsed_df = transaction_stream_parsed_df \
    .withColumn("event_ts", from_unixtime(col("ts")).cast("timestamp")) \
    .withWatermark("event_ts", "10 minutes")  # Watermark to allow streaming joins


In [8]:
# Task (a): Parse 'ts' to 'event_ts' and 'created_ts' to make them timestamp format
browsing_stream_parsed_df = browsing_stream_parsed_df \
    .withColumn("event_ts", from_unixtime(col("ts")).cast("timestamp"))  # Convert 'ts' to 'event_ts'

transaction_stream_parsed_df = transaction_stream_parsed_df \
    .withColumn("created_ts", from_unixtime(col("ts")).cast("timestamp"))  # Convert 'ts' to 'created_ts'


In [9]:
# Task (b):# Filter out late data (more than 2 minutes delay) for both streams
browsing_stream_with_watermark = browsing_stream_parsed_df \
    .withColumn("current_ts", current_timestamp()) \
    .filter(expr("current_ts - event_ts <= interval 2 minutes"))

transaction_stream_with_watermark = transaction_stream_parsed_df \
    .withColumn("current_ts", current_timestamp()) \
    .filter(expr("current_ts - created_ts <= interval 2 minutes"))


In [11]:
# Apply watermarks again to both streams (redundant but kept for clarity)
browsing_stream_with_watermark = browsing_stream_with_watermark \
    .withColumn("event_ts", from_unixtime(col("ts")).cast("timestamp")) \
    .withWatermark("event_ts", "10 minutes")  # Watermark event_ts

transaction_stream_with_watermark = transaction_stream_with_watermark \
    .withColumn("created_ts", from_unixtime(col("ts")).cast("timestamp")) \
    .withWatermark("created_ts", "10 minutes")  # Watermark created_ts

In [52]:
# Output the browsing stream to the console
browsing_final_stream_df_query = browsing_stream_with_watermark.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()

# Output the transaction stream to the console
transaction_stream_parsed_df_ts_parsed_query = transaction_stream_with_watermark.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()

In [11]:
browsing_final_stream_df_query.stop()
transaction_stream_parsed_df_ts_parsed_query.stop()

### 5.	Aggregate the streaming data frames and create features you used in your assignment 2A model.  
(note: customer ID has already been included in the stream.) Then, join the static data frames with the streaming data frame as our final data for prediction. Perform data type/column conversion according to your ML model and print out the Schema. (Again, you can reuse code from A2A).

In [13]:
browsing_stream_parsed_df_ts_parsed.printSchema()

root
 |-- browsing_session_id: string (nullable = true)
 |-- event_type: string (nullable = true)
 |-- event_time: string (nullable = true)
 |-- traffic_source: string (nullable = true)
 |-- device_type: string (nullable = true)
 |-- customer_id: string (nullable = true)
 |-- ts: integer (nullable = true)
 |-- event_ts: timestamp (nullable = true)
 |-- browsing_current_ts: timestamp (nullable = false)



In [14]:
transaction_stream_with_watermark.printSchema()

root
 |-- created_at: string (nullable = true)
 |-- customer_id: string (nullable = true)
 |-- transaction_id: string (nullable = true)
 |-- session_id: string (nullable = true)
 |-- product_metadata: string (nullable = true)
 |-- payment_method: string (nullable = true)
 |-- payment_status: string (nullable = true)
 |-- promo_amount: string (nullable = true)
 |-- promo_code: string (nullable = true)
 |-- shipment_fee: string (nullable = true)
 |-- shipment_location_lat: string (nullable = true)
 |-- shipment_location_long: string (nullable = true)
 |-- total_amount: string (nullable = true)
 |-- clear_payment: string (nullable = true)
 |-- ts: integer (nullable = true)
 |-- event_ts: timestamp (nullable = true)
 |-- created_ts: timestamp (nullable = true)
 |-- current_ts: timestamp (nullable = false)



In [12]:
# Rename columns in transaction stream to avoid naming conflicts during join
transaction_stream_parsed_df_ts_parsed = transaction_stream_with_watermark \
    .withColumnRenamed("session_id", "transaction_session_id") \
    .withColumnRenamed("current_ts", "transaction_current_ts") \
    .withColumnRenamed("event_ts", "transaction_event_ts") \
    .withColumnRenamed("ts", "transaction_ts")  # Rename Unix timestamp for clarity

# Rename columns in browsing stream to avoid ambiguity
browsing_stream_parsed_df_ts_parsed = browsing_stream_with_watermark \
    .withColumnRenamed("session_id", "browsing_session_id") \
    .withColumnRenamed("current_ts", "browsing_current_ts")  # Differentiate current timestamps

# Create aliases to make joins clearer and easier to read
transaction_alias = transaction_stream_parsed_df_ts_parsed.alias("transaction")  
browsing_alias = browsing_stream_parsed_df_ts_parsed.alias("browsing")  

# Perform an inner join on customer_id and matching timestamps within a 5-minute window
joined_stream_df = transaction_alias.join(
    browsing_alias,
    expr("""
        transaction.customer_id = browsing.customer_id AND
        browsing.event_ts BETWEEN transaction.created_ts - INTERVAL 5 MINUTES
        AND transaction.created_ts + INTERVAL 5 MINUTES
    """),
    how="inner"  # Use inner join to keep only matching records
)


In [12]:
# Output the streaming data to the console
final_stream_df_query = joined_stream_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()



In [13]:
final_stream_df_query.stop()

In [13]:
customer_final_stream = joined_stream_df.join(
    cust_data,
    on="customer_id",
    how="inner"
)
customer_final_stream.printSchema()

root
 |-- customer_id: string (nullable = true)
 |-- created_at: string (nullable = true)
 |-- transaction_id: string (nullable = true)
 |-- transaction_session_id: string (nullable = true)
 |-- product_metadata: string (nullable = true)
 |-- payment_method: string (nullable = true)
 |-- payment_status: string (nullable = true)
 |-- promo_amount: string (nullable = true)
 |-- promo_code: string (nullable = true)
 |-- shipment_fee: string (nullable = true)
 |-- shipment_location_lat: string (nullable = true)
 |-- shipment_location_long: string (nullable = true)
 |-- total_amount: string (nullable = true)
 |-- clear_payment: string (nullable = true)
 |-- transaction_ts: integer (nullable = true)
 |-- transaction_event_ts: timestamp (nullable = true)
 |-- created_ts: timestamp (nullable = true)
 |-- transaction_current_ts: timestamp (nullable = false)
 |-- browsing_session_id: string (nullable = true)
 |-- event_type: string (nullable = true)
 |-- event_time: string (nullable = true)
 |--

In [14]:
# Extract date and time components from 'created_ts'
enriched_df = customer_final_stream \
    .withColumn('created_year', year(col('created_ts'))) \
    .withColumn('created_month', month(col('created_ts'))) \
    .withColumn('created_day', dayofmonth(col('created_ts'))) \
    .withColumn('created_hour', hour(col('created_ts'))) \
    .withColumn('created_minute', minute(col('created_ts')))

# Extract components from 'birthdate' for demographic insights
enriched_df = enriched_df \
    .withColumn('birth_year', year(col('birthdate'))) \
    .withColumn('birth_month', month(col('birthdate'))) \
    .withColumn('birth_day', dayofmonth(col('birthdate'))) 

# Extract join date components from 'first_join_date'
enriched_df = enriched_df \
    .withColumn('join_year', year(col('first_join_date'))) \
    .withColumn('join_month', month(col('first_join_date'))) \
    .withColumn('join_day', dayofmonth(col('first_join_date'))) 

# Define schema for product metadata
product_schema = ArrayType(
    StructType([
        StructField("product_id", IntegerType(), True),  # Product ID
        StructField("quantity", IntegerType(), True),  # Quantity purchased
        StructField("item_price", IntegerType(), True)  # Price per item
    ])
)

# Calculate age based on birth year
enriched_df = enriched_df.withColumn('age', year(current_date()) - year(col('birthdate')))

# Extract the year from the 'first_join_date'
enriched_df = enriched_df.withColumn('first_join_year', year(col('first_join_date')))

# Mark if a promo code was used (1 if used, 0 if not)
enriched_df = enriched_df.withColumn('promo_code_used', (col('promo_code').isNotNull()).cast('int'))

# Parse product metadata and explode it into individual rows for detailed analysis
enriched_df = enriched_df \
    .withColumn("product_data", from_json(col("product_metadata"), product_schema)) \
    .withColumn("product_exploded", explode(col("product_data"))) \
    .withColumn("product_id", col("product_exploded.product_id")) \
    .withColumn("quantity", col("product_exploded.quantity")) \
    .withColumn("item_price", col("product_exploded.item_price")) \
    .drop("product_exploded", "product_data")  # Drop intermediate columns


In [25]:
# Output the streaming data to the console
enriched_df_stream = enriched_df.limit(5)
customer_final_stream_query = enriched_df_stream.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()



In [26]:

customer_final_stream_query.stop()

In [4]:
# Load the saved pipeline model from the specified path
model_path = os.getcwd() + '/A2/best_model'  # Adjust path if needed
saved_model = PipelineModel.load(model_path)

# Initialize VectorAssembler to extract training features
vector_assembler = None  
for stage in saved_model.stages:  
    if isinstance(stage, VectorAssembler):  # Check if the stage is a VectorAssembler
        vector_assembler = stage  
        break  # Exit the loop once found

# Retrieve and print the features used during training
if vector_assembler:
    training_feature_columns = vector_assembler.getInputCols()  
    print("Features used in Model Training:")
    print(training_feature_columns)  
else:
    print("VectorAssembler not found in the pipeline model.")  # Handle case if not found


Features used in Model Training:
['promo_amount', 'shipment_fee', 'total_amount', 'median_hour', 'age', 'first_join_year', 'created_year', 'created_month', 'created_day', 'created_hour', 'created_minute', 'promo_code_used', 'birth_year', 'birth_month', 'birth_day', 'join_year', 'join_month', 'join_day', 'product_id', 'quantity', 'item_price']


In [15]:
model_path = os.getcwd() + '/A2/best_model'  
saved_model = PipelineModel.load(model_path)

In [16]:
from pyspark.sql.functions import lit

# Ensure 'new_time_of_day_index' column exists, initialize with 0 if missing
if 'new_time_of_day_index' not in enriched_df.columns:
    enriched_df = enriched_df.withColumn('new_time_of_day_index', lit(0))

# Ensure 'median_hour' column exists, initialize with 0 if missing    
if 'median_hour' not in enriched_df.columns:
    enriched_df = enriched_df.withColumn('median_hour', lit(0))


In [23]:
# Output the streaming data to the console
enriched_df_stream = enriched_df.limit(5)
customer_final_stream_query = enriched_df_stream.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()



In [24]:
customer_final_stream_query.stop()

In [17]:
required_columns = [
    'promo_amount', 'shipment_fee', 'total_amount', 'median_hour', 'age',
    'first_join_year', 'created_year', 'created_month', 'created_day',
    'created_hour', 'created_minute', 'promo_code_used', 'birth_year',
    'birth_month', 'birth_day', 'join_year', 'join_month', 'join_day', 'event_ts',
    'product_id', 'quantity', 'item_price', 'payment_status', 'payment_method',
    'new_time_of_day_index', 'median_hour'
]

# Keep only the required columns from the DataFrame to avoid unnecessary data processing
filtered_df = enriched_df.select(*required_columns)

# Ensure numeric columns are cast to 'int' for consistency
filtered_df = filtered_df.select(
    *[F.col(c).cast('int') if c in ['promo_amount', 'shipment_fee', 'total_amount'] 
      else F.col(c) for c in required_columns]
)

# Double-check the schema to make sure only the necessary columns are present
filtered_df.printSchema()


root
 |-- promo_amount: integer (nullable = true)
 |-- shipment_fee: integer (nullable = true)
 |-- total_amount: integer (nullable = true)
 |-- median_hour: integer (nullable = false)
 |-- age: integer (nullable = true)
 |-- first_join_year: integer (nullable = true)
 |-- created_year: integer (nullable = true)
 |-- created_month: integer (nullable = true)
 |-- created_day: integer (nullable = true)
 |-- created_hour: integer (nullable = true)
 |-- created_minute: integer (nullable = true)
 |-- promo_code_used: integer (nullable = false)
 |-- birth_year: integer (nullable = true)
 |-- birth_month: integer (nullable = true)
 |-- birth_day: integer (nullable = true)
 |-- join_year: integer (nullable = true)
 |-- join_month: integer (nullable = true)
 |-- join_day: integer (nullable = true)
 |-- event_ts: timestamp (nullable = true)
 |-- product_id: integer (nullable = true)
 |-- quantity: integer (nullable = true)
 |-- item_price: integer (nullable = true)
 |-- payment_status: string (n

In [18]:
filtered_df = filtered_df.limit(5)
fraud_count_query_predictions_df = filtered_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", "false") \
    .trigger(processingTime="10 seconds") \
    .start()


In [19]:
fraud_count_query_predictions_df.stop()

### 6.	The company is interested in the number of potential frauds as they happen and the products in customers’ shopping carts (so that they can plan their stock level ahead.) Load your ML model, and use the model to predict/process each browsing session/transaction as follows:  
a)	Every 10 seconds, show the total number of potential frauds (prediction = 1) in the last 2 minutes, and persist the raw data (see 7a).  
b)	Every 30 seconds, find the top 20 products (order by quantity descending) in the last 30 seconds, show product ID, name and total quantity. We only need the non-fraud transactions (prediction=0) by extracting customer shopping cart details (sum of all items of ADD_TO_CART(ATC) events from browsing behaviour, you can also extract it from transactions).

6a

In [20]:
# Use the pre-trained model to make predictions on the filtered streaming data
predictions_df = saved_model.transform(filtered_df)

# Filter out potential fraud cases (where prediction value is 1)
potential_frauds_df = predictions_df.filter(col("gbt_prediction") == 1)

# Aggregate potential frauds in the last 2 minutes, updating every 10 seconds
fraud_count_df = potential_frauds_df \
    .groupBy(window(col("event_ts"), "2 minutes", "10 seconds")) \
    .count()


In [21]:
predictions_df = predictions_df.limit(2)
fraud_count_query_predictions_df = predictions_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", "false") \
    .trigger(processingTime="10 seconds") \
    .start()


In [22]:
fraud_count_query_predictions_df.stop()

In [23]:
fraud_count_query = fraud_count_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", "false") \
    .trigger(processingTime="10 seconds") \
    .start()

In [25]:
fraud_count_query.stop()


6b

In [26]:
# Apply the saved model to the filtered data to generate predictions
predictions_df = saved_model.transform(filtered_df)

# Filter non-fraud transactions (where prediction = 0) and focus on 'ADD_TO_CART' events
non_fraud_df = predictions_df.filter((col("gbt_prediction") == 0) & (col("event_type") == "ATC"))

# Identify the top 20 products in the last 30 seconds using a tumbling window
top_products_df = non_fraud_df \
    .withWatermark("event_ts", "2 minutes") \
    .groupBy(window(col("event_ts"), "30 seconds"), "product_id") \
    .agg(spark_sum("quantity").alias("total_quantity")) \
    .limit(20)  # Limiting results to the top 20 products


In [27]:

# Step 11: Output top 20 products to the console every 30 seconds
top_products_query = top_products_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", "false") \
    .trigger(processingTime="30 seconds") \
    .start()



In [31]:
top_products_query.stop()

### 7.	Write a Parquet file and save the following data frames (tip: you may look at part 3 and think about what columns to save):  
a.	Persist the raw data from 6a in parquet format. Every student may have different features/columns in their data frames depending on their model, at the bare minimum, we need some IDs to identify those frauds later on (transaction_id and/or session_id). After that, read the parquet file and show a few rows to verify it is saved correctly.  
b.	Persist the data from 6b in another parquet file.  

7a

In [27]:
def save_to_parquet(df, batch_id):
    # Save the DataFrame as a Parquet file, appending each batch to the existing data
    df.write.parquet(f"{os.getcwd()}/A2_B/fraud_raw_data_1.parquet", mode="append")

# Persist the raw fraud data to Parquet every 10 seconds using foreachBatch
raw_fraud_query = potential_frauds_df.writeStream \
    .outputMode("append") \
    .foreachBatch(save_to_parquet) \
    .start()


In [34]:
raw_fraud_query.stop()


7b

In [28]:
def save_to_parquet_b(df, batch_id):
    # Save the DataFrame as a Parquet file, appending each batch to the existing data
    df.write.parquet(f"{os.getcwd()}/A2_B/top_products_1.parquet", mode="append")

# Persist the top products data to Parquet every 10 seconds using foreachBatch
top_products_persist_query = top_products_df.writeStream \
    .outputMode("append") \
    .foreachBatch(save_to_parquet_b) \
    .start()


In [33]:
top_products_persist_query.stop()


### 8.	Read the two parquet files from task 7 as data streams and send to Kafka topics with appropriate names.
(Note: You shall read the parquet files as a streaming data frame and send messages to the Kafka topic when new data appears in the parquet file.)