Dear tutor, something is wrong with my IDE: I can edit the file locally but cannot save my solution.
I also tried uploading my code to my personal GitHub repository, but the local changes still could not be saved.
I was, however, able to save the results of my Spark code.

So I am putting my code in this file. If you copy it into the original "hw07.ipynb" file and run it, it should work.
All of the solutions produce the correct results.
Thank you very much, and I apologize for the inconvenience.



In [1]:
from pyspark.sql import SparkSession
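# Explicit imports instead of a wildcard; note that round and sum still shadow the Python built-ins
from pyspark.sql.functions import avg, col, count, countDistinct, round, sum, when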

# When run locally, Spark uses a single (master) node with its own JVM, and no cluster manager is created
spark = SparkSession.builder.master("local").appName("Homework 07").getOrCreate()

spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/30 17:56:37 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/05/30 17:56:38 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## Define Helper Function

In [2]:
# Load a table from the TPCx-BB dataset and return it as a DataFrame read from Parquet format
get_table = lambda table: spark.read.option("header", "true").parquet(
    f"TPCx-BB-dataset/{table}.ptxt"
)

## 1. Query 07
List the top 10 states, in descending order of customer count, that have at least 10 customers who during a given month bought products priced at least 20% higher than the average price of products in the same category.

In [3]:
# implementation

item = get_table("item")

# Define high price ratio
q07_HIGHER_PRICE_RATIO = 1.2

# Calculate the average price for each category and calculate the high price threshold
category_avg_price = item.groupBy("i_category").agg(
    (avg("i_current_price") * q07_HIGHER_PRICE_RATIO).alias("avg_price")
)

# Rename columns to avoid column name conflicts when joining
item = item.withColumnRenamed("i_category", "item_category")
item = item.withColumnRenamed("i_current_price", "item_price")

# Join the items with the per-category thresholds and keep only the high-priced items
high_price_items = (item.join(category_avg_price, item.item_category == category_avg_price.i_category)
    .filter(item.item_price > category_avg_price.avg_price)
    .select(item.i_item_sk)
)

customer_address = get_table("customer_address")
customer = get_table("customer")
store_sales = get_table("store_sales")
date_dim = get_table("date_dim")

q07_YEAR = 2004
q07_MONTH = 7
q07_HAVING_COUNT_GE = 10
q07_LIMIT = 10

# Extract all dates of a specific month from date_dim
selected_dates = date_dim.filter(
    (date_dim.d_year == q07_YEAR) & (date_dim.d_moy == q07_MONTH)
).select("d_date_sk").rdd.map(lambda row: row[0]).collect()

query7_solution = (customer_address
    .join(customer, customer_address.ca_address_sk == customer.c_current_addr_sk)
    .join(store_sales, customer.c_customer_sk == store_sales.ss_customer_sk)
    .join(high_price_items, store_sales.ss_item_sk == high_price_items.i_item_sk)
    .filter(store_sales.ss_sold_date_sk.isin(selected_dates))
    .filter(customer_address.ca_state.isNotNull())
    .groupBy(customer_address.ca_state)
    .agg(count("*").alias("cnt"))
    .filter(col("cnt") >= q07_HAVING_COUNT_GE)
    .orderBy(col("cnt"), ascending=False)
    .limit(q07_LIMIT)
)

query7_solution.show()

+--------+---+
|ca_state|cnt|
+--------+---+
|      TX|396|
|      GA|247|
|      VA|233|
|      IL|205|
|      KY|176|
|      KS|170|
|      NC|164|
|      IA|163|
|      MO|156|
|      AL|139|
+--------+---+


In [4]:
# check the result
!cat queries/q07/results/q07-result

TX,396
GA,247
VA,233
IL,205
KY,176
KS,170
NC,164
IA,163
MO,156
AL,139
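
The visual comparison above can also be done programmatically. The following is a minimal sketch, assuming the reference file holds one `state,count` pair per line as printed above. The helper names `parse_reference` and `matches_reference` are hypothetical, and the rows are copied from the output above rather than collected from Spark.

```python
def parse_reference(text):
    """Parse 'STATE,COUNT' lines into a list of (state, count) tuples."""
    rows = []
    for line in text.strip().splitlines():
        state, count = line.split(",")
        rows.append((state.strip(), int(count)))
    return rows


def matches_reference(result_rows, reference_text):
    """Return True when the rows equal the reference, order included."""
    return list(result_rows) == parse_reference(reference_text)


# Rows as shown by query7_solution.show(). In the notebook they could be built with
# [(r["ca_state"], r["cnt"]) for r in query7_solution.collect()].
result_rows = [
    ("TX", 396), ("GA", 247), ("VA", 233), ("IL", 205), ("KY", 176),
    ("KS", 170), ("NC", 164), ("IA", 163), ("MO", 156), ("AL", 139),
]

# Contents of queries/q07/results/q07-result as printed above.
reference_text = """TX,396
GA,247
VA,233
IL,205
KY,176
KS,170
NC,164
IA,163
MO,156
AL,139
"""

print(matches_reference(result_rows, reference_text))
```

Comparing ordered lists also checks the descending order, which a set comparison would miss.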


## 2. Modify the DataFrame resulting from Query 07

### 2(a) 
Add a column named "hw07_id" and insert your own student ID number (as written on your USI badge, e.g. 12-345-678) only in the row where the state is Georgia. Print your result.

In [5]:
modified_a = query7_solution.withColumn(
    "hw07_id", 
    when(col("ca_state") == "GA", "22-981-591").otherwise(None)
)
modified_a.show()

+--------+---+----------+
|ca_state|cnt|   hw07_id|
+--------+---+----------+
|      TX|396|      NULL|
|      GA|247|22-981-591|
|      VA|233|      NULL|
|      IL|205|      NULL|
|      KY|176|      NULL|
|      KS|170|      NULL|
|      NC|164|      NULL|
|      IA|163|      NULL|
|      MO|156|      NULL|
|      AL|139|      NULL|
+--------+---+----------+


### 2(b)
Add another column named "hw07_surname" and insert your surname only where the count aggregate field is greater than 200. Print your result.

In [7]:
modified_b = modified_a.withColumn(
    "hw07_surname",
    when(col("cnt") > 200, "Feng").otherwise(None)
)
modified_b.show()

+--------+---+----------+------------+
|ca_state|cnt|   hw07_id|hw07_surname|
+--------+---+----------+------------+
|      TX|396|      NULL|        Feng|
|      GA|247|22-981-591|        Feng|
|      VA|233|      NULL|        Feng|
|      IL|205|      NULL|        Feng|
|      KY|176|      NULL|        NULL|
|      KS|170|      NULL|        NULL|
|      NC|164|      NULL|        NULL|
|      IA|163|      NULL|        NULL|
|      MO|156|      NULL|        NULL|
|      AL|139|      NULL|        NULL|
+--------+---+----------+------------+


### 2(c)
Remove the records that contain a null (missing) value in the "hw07_surname" column. Print your result.

In [8]:
modified_c = modified_b.filter(col("hw07_surname").isNotNull())
modified_c.show()

+--------+---+----------+------------+
|ca_state|cnt|   hw07_id|hw07_surname|
+--------+---+----------+------------+
|      TX|396|      NULL|        Feng|
|      GA|247|22-981-591|        Feng|
|      VA|233|      NULL|        Feng|
|      IL|205|      NULL|        Feng|
+--------+---+----------+------------+


### 2(d)
Replace the null (missing) values with the string "-". Print your result.

In [9]:
modified_d = modified_b.na.fill({"hw07_surname": "-"})
modified_d.show()

+--------+---+----------+------------+
|ca_state|cnt|   hw07_id|hw07_surname|
+--------+---+----------+------------+
|      TX|396|      NULL|        Feng|
|      GA|247|22-981-591|        Feng|
|      VA|233|      NULL|        Feng|
|      IL|205|      NULL|        Feng|
|      KY|176|      NULL|           -|
|      KS|170|      NULL|           -|
|      NC|164|      NULL|           -|
|      IA|163|      NULL|           -|
|      MO|156|      NULL|           -|
|      AL|139|      NULL|           -|
+--------+---+----------+------------+
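
Note that the call above fills only `hw07_surname`, so `hw07_id` keeps its nulls. As a Spark-free sketch of that column-restricted behaviour, the hypothetical helper `fill_nulls` below models rows as dictionaries. It is only an illustration of the semantics, not Spark's implementation, and the ID is the placeholder from the task text.

```python
def fill_nulls(rows, value, columns=None):
    """Return copies of the rows with None replaced by value.

    columns=None fills every column; otherwise only the named columns.
    (Simplification: Spark's na.fill with a string value only touches
    string-typed columns, which this model does not check.)
    """
    filled = []
    for row in rows:
        new_row = dict(row)  # copy so the input rows stay unchanged
        for name, cell in row.items():
            if cell is None and (columns is None or name in columns):
                new_row[name] = value
        filled.append(new_row)
    return filled


# Two sample rows shaped like modified_b (placeholder student ID).
rows = [
    {"ca_state": "GA", "cnt": 247, "hw07_id": "12-345-678", "hw07_surname": "Feng"},
    {"ca_state": "KY", "cnt": 176, "hw07_id": None, "hw07_surname": None},
]

only_surname = fill_nulls(rows, "-", columns=["hw07_surname"])
all_columns = fill_nulls(rows, "-")

print(only_surname[1])
print(all_columns[1])
```

In PySpark, the first variant corresponds to the dictionary form used above (or to passing `subset=["hw07_surname"]`), while `na.fill("-")` without a subset would also replace the nulls in `hw07_id`.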


## 3. Query 09
Aggregate the total quantity of items sold to customers who match given combinations of marital status, education status, and sales price, together with given combinations of state and net profit.

In [12]:
store_sales = get_table("store_sales")
date_dim = get_table("date_dim")
customer_address = get_table("customer_address")
store = get_table("store")
customer_demographics = get_table("customer_demographics")

q09_year = 2001

q09_part1_ca_country = "United States"
q09_part1_ca_state_IN = ['KY', 'GA', 'NM']
q09_part1_net_profit_min = 0
q09_part1_net_profit_max = 2000
q09_part1_education_status = "4 yr Degree"
q09_part1_marital_status = "M"
q09_part1_sales_price_min = 100
q09_part1_sales_price_max = 150

q09_part2_ca_country = "United States"
q09_part2_ca_state_IN = ['MT', 'OR', 'IN']
q09_part2_net_profit_min = 150
q09_part2_net_profit_max = 3000
q09_part2_education_status = "4 yr Degree"
q09_part2_marital_status = "M"
q09_part2_sales_price_min = 50
q09_part2_sales_price_max = 200

q09_part3_ca_country = "United States"
q09_part3_ca_state_IN = ['WI', 'MO', 'WV']
q09_part3_net_profit_min = 50
q09_part3_net_profit_max = 25000
q09_part3_education_status = "4 yr Degree"
q09_part3_marital_status = "M"
q09_part3_sales_price_min = 150
q09_part3_sales_price_max = 200

# Join conditions and filtering
query9_solution = store_sales.join(date_dim, store_sales.ss_sold_date_sk == date_dim.d_date_sk) \
    .join(customer_address, store_sales.ss_addr_sk == customer_address.ca_address_sk) \
    .join(store, store.s_store_sk == store_sales.ss_store_sk) \
    .join(customer_demographics, customer_demographics.cd_demo_sk == store_sales.ss_cdemo_sk) \
    .filter(
        (date_dim.d_year == q09_year) &
        (
            (
                (customer_demographics.cd_marital_status == q09_part1_marital_status) &
                (customer_demographics.cd_education_status == q09_part1_education_status) &
                (store_sales.ss_sales_price.between(q09_part1_sales_price_min, q09_part1_sales_price_max))
            ) |
            (
                (customer_demographics.cd_marital_status == q09_part2_marital_status) &
                (customer_demographics.cd_education_status == q09_part2_education_status) &
                (store_sales.ss_sales_price.between(q09_part2_sales_price_min, q09_part2_sales_price_max))
            ) |
            (
                (customer_demographics.cd_marital_status == q09_part3_marital_status) &
                (customer_demographics.cd_education_status == q09_part3_education_status) &
                (store_sales.ss_sales_price.between(q09_part3_sales_price_min, q09_part3_sales_price_max))
            )
        ) &
        (
            (
                (customer_address.ca_country == q09_part1_ca_country) &
                (customer_address.ca_state.isin(q09_part1_ca_state_IN)) &
                (store_sales.ss_net_profit.between(q09_part1_net_profit_min, q09_part1_net_profit_max))
            ) |
            (
                (customer_address.ca_country == q09_part2_ca_country) &
                (customer_address.ca_state.isin(q09_part2_ca_state_IN)) &
                (store_sales.ss_net_profit.between(q09_part2_net_profit_min, q09_part2_net_profit_max))
            ) |
            (
                (customer_address.ca_country == q09_part3_ca_country) &
                (customer_address.ca_state.isin(q09_part3_ca_state_IN)) &
                (store_sales.ss_net_profit.between(q09_part3_net_profit_min, q09_part3_net_profit_max))
            )
        )
    ) \
    .agg(sum("ss_quantity").alias("total_quantity"))

query9_solution.show()


+--------------+
|total_quantity|
+--------------+
|        5900.0|
+--------------+


In [13]:
# check the result
!cat queries/q09/results/q09-result

5900
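
The filter above is a conjunction of two disjunctions: a sale must match any one of the three demographic/price groups and, independently, any one of the three state/profit groups. A plain-Python sketch of that predicate, using the parameter values from the cell above, may make the structure easier to check. The function `matches_q09` and the dictionary-based rows are hypothetical, and both bounds are inclusive, as with `Column.between`.

```python
# (marital status, education status, inclusive sales price range)
DEMO_GROUPS = [
    {"marital": "M", "education": "4 yr Degree", "price": (100, 150)},
    {"marital": "M", "education": "4 yr Degree", "price": (50, 200)},
    {"marital": "M", "education": "4 yr Degree", "price": (150, 200)},
]

# (country, allowed states, inclusive net profit range)
GEO_GROUPS = [
    {"country": "United States", "states": {"KY", "GA", "NM"}, "profit": (0, 2000)},
    {"country": "United States", "states": {"MT", "OR", "IN"}, "profit": (150, 3000)},
    {"country": "United States", "states": {"WI", "MO", "WV"}, "profit": (50, 25000)},
]


def between(value, bounds):
    """Inclusive range check on both ends."""
    low, high = bounds
    return low <= value <= high


def matches_q09(row, year=2001):
    """Return True when a joined sales row passes the Query 09 filter."""
    demo_ok = any(
        row["cd_marital_status"] == g["marital"]
        and row["cd_education_status"] == g["education"]
        and between(row["ss_sales_price"], g["price"])
        for g in DEMO_GROUPS
    )
    geo_ok = any(
        row["ca_country"] == g["country"]
        and row["ca_state"] in g["states"]
        and between(row["ss_net_profit"], g["profit"])
        for g in GEO_GROUPS
    )
    return row["d_year"] == year and demo_ok and geo_ok


# Hypothetical joined row, for illustration only.
sample = {
    "d_year": 2001,
    "cd_marital_status": "M",
    "cd_education_status": "4 yr Degree",
    "ss_sales_price": 120,
    "ca_country": "United States",
    "ca_state": "GA",
    "ss_net_profit": 500,
}

print(matches_q09(sample))
print(matches_q09({**sample, "ca_state": "MT", "ss_net_profit": 100}))
```

Because the two disjunctions are evaluated independently, a sale may satisfy the price range of one part and the state list of another, which is also how the Spark filter above behaves.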


## 4. Query 20
Customer segmentation for return analysis: Customers are separated along the following dimensions: return frequency, return order ratio (total number of orders partially or fully returned versus the total number of orders), return item ratio (total number of items returned versus the number of items purchased), and return amount ratio (total monetary amount of items returned versus the amount purchased). Consider the store returns during a given year for the computation.
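
The ratio definitions above can be sketched in plain Python. This is a minimal illustration, assuming that a customer without returns (or with a zero denominator) gets a ratio of 0.0 and that ratios are rounded to 7 decimal places. The names `safe_ratio` and `segment` and the sample figures are hypothetical.

```python
def safe_ratio(returned, ordered, digits=7):
    """Return returned / ordered rounded to `digits`, or 0.0 when undefined."""
    if returned is None or ordered is None or ordered == 0:
        return 0.0
    return round(returned / ordered, digits)


def segment(orders, returns):
    """Compute the four dimensions for one customer.

    orders and returns are dicts with the keys 'count', 'items' and 'money';
    returns may be None for a customer who never returned anything.
    """
    returns = returns or {}
    return {
        "orderRatio": safe_ratio(returns.get("count"), orders["count"]),
        "itemsRatio": safe_ratio(returns.get("items"), orders["items"]),
        "monetaryRatio": safe_ratio(returns.get("money"), orders["money"]),
        "frequency": returns.get("count") or 0,
    }


# Hypothetical customers, for illustration only.
with_returns = segment(
    {"count": 3, "items": 25, "money": 1000.0},
    {"count": 1, "items": 1, "money": 17.55},
)
no_returns = segment({"count": 2, "items": 4, "money": 80.0}, None)

print(with_returns)
print(no_returns)
```

In this sketch `frequency` stays an integer. The Spark cell below reports it as a double (for example `3.0`) because its fallback literal is `0.0`.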

In [14]:
store_sales = get_table("store_sales")
store_returns = get_table("store_returns")

# Subquery 1: Order statistics
orders = store_sales.groupBy("ss_customer_sk").agg(
    countDistinct("ss_ticket_number").alias("orders_count"),
    count("ss_item_sk").alias("orders_items"),
    sum("ss_net_paid").alias("orders_money")
)

# Subquery 2: Return statistics
returns = store_returns.groupBy("sr_customer_sk").agg(
    countDistinct("sr_ticket_number").alias("returns_count"),
    count("sr_item_sk").alias("returns_items"),
    sum("sr_return_amt").alias("returns_money")
)

# Join subqueries and calculate ratios
query20_solution = orders.join(returns, orders.ss_customer_sk == returns.sr_customer_sk, "left_outer") \
    .select(
        col("ss_customer_sk").alias("user_sk"),
        round(
            when(
                (col("returns_count").isNull()) | (col("orders_count").isNull()) | ((col("returns_count") / col("orders_count")).isNull()), 
                0.0
            ).otherwise(col("returns_count") / col("orders_count")), 7
        ).alias("orderRatio"),
        round(
            when(
                (col("returns_items").isNull()) | (col("orders_items").isNull()) | ((col("returns_items") / col("orders_items")).isNull()), 
                0.0
            ).otherwise(col("returns_items") / col("orders_items")), 7
        ).alias("itemsRatio"),
        round(
            when(
                (col("returns_money").isNull()) | (col("orders_money").isNull()) | ((col("returns_money") / col("orders_money")).isNull()), 
                0.0
            ).otherwise(col("returns_money") / col("orders_money")), 7
        ).alias("monetaryRatio"),
        round(
            when(col("returns_count").isNull(), 0.0).otherwise(col("returns_count")), 0
        ).alias("frequency")
    )

# Sort by user_sk
query20_solution = query20_solution.orderBy("user_sk")

query20_solution.show()

+-------+----------+----------+-------------+---------+
|user_sk|orderRatio|itemsRatio|monetaryRatio|frequency|
+-------+----------+----------+-------------+---------+
|      2|       0.1| 0.0595745|     0.051983|      3.0|
|      3|       0.0|       0.0|          0.0|      0.0|
|      4|       0.0|       0.0|          0.0|      0.0|
|      7| 0.3333333|      0.04|    0.0175497|      1.0|
|     12|       0.0|       0.0|          0.0|      0.0|
|     16|       0.0|       0.0|          0.0|      0.0|
|     17|       1.0| 0.5714286|    0.2275411|      1.0|
|     19|       0.0|       0.0|          0.0|      0.0|
|     22|       0.0|       0.0|          0.0|      0.0|
|     23|       0.5| 0.1666667|    0.1366843|      1.0|
|     30|       0.0|       0.0|          0.0|      0.0|
|     31|       1.0|      0.25|    0.4410076|      1.0|
|     32|       0.0|       0.0|          0.0|      0.0|
|     37|       0.0|       0.0|          0.0|      0.0|
|     38|       0.0|       0.0|          0.0|   

In [15]:
# check the result
!cat queries/q20/results/q20-result-queryonly

+-------+----------+----------+-------------+---------+
|user_sk|orderRatio|itemsRatio|monetaryRatio|frequency|
+-------+----------+----------+-------------+---------+
|      2|       0.1| 0.0595745|     0.051983|        3|
|      3|       0.0|       0.0|          0.0|        0|
|      4|       0.0|       0.0|          0.0|        0|
|      7| 0.3333333|      0.04|    0.0175497|        1|
|     12|       0.0|       0.0|          0.0|        0|
|     16|       0.0|       0.0|          0.0|        0|
|     17|       1.0| 0.5714286|    0.2275411|        1|
|     19|       0.0|       0.0|          0.0|        0|
|     22|       0.0|       0.0|          0.0|        0|
|     23|       0.5| 0.1666667|    0.1366843|        1|
|     30|       0.0|       0.0|          0.0|        0|
|     31|       1.0|      0.25|    0.4410076|        1|
|     32|       0.0|       0.0|          0.0|        0|
|     37|       0.0|       0.0|          0.0|        0|
|     38|       0.0|       0.0|