In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window

# Initialize Spark session
spark = SparkSession.builder.appName("Retail_Transactions").getOrCreate()
spark 

<pyspark.sql.connect.session.SparkSession at 0x76d8e8d2ac90>

In [0]:
# 1. Load retail_data.csv and display schema
retail_df = spark.read.csv("/FileStore/tables/retail_data.csv", header=True)
retail_df.printSchema()

# 2. Infer schema as False, then manually cast columns
retail_df_no_infer = spark.read.csv("/FileStore/tables/retail_data.csv", header=True, inferSchema=False)
retail_df_manual = retail_df_no_infer.withColumn("Quantity", col("Quantity").cast(IntegerType())) \
    .withColumn("UnitPrice", col("UnitPrice").cast(IntegerType())) \
    .withColumn("TotalPrice", col("TotalPrice").cast(IntegerType())) \
    .withColumn("TransactionDate", col("TransactionDate").cast(DateType()))
retail_df_manual.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- Customer: string (nullable = true)
 |-- City: string (nullable = true)
 |-- Product: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Quantity: string (nullable = true)
 |-- UnitPrice: string (nullable = true)
 |-- TotalPrice: string (nullable = true)
 |-- TransactionDate: string (nullable = true)
 |-- PaymentMode: string (nullable = true)

root
 |-- TransactionID: string (nullable = true)
 |-- Customer: string (nullable = true)
 |-- City: string (nullable = true)
 |-- Product: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- UnitPrice: integer (nullable = true)
 |-- TotalPrice: integer (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- PaymentMode: string (nullable = true)



In [0]:
# 3. Filter transactions where TotalPrice > 40000
high_value = retail_df.filter(col("TotalPrice") > 40000)
high_value.show()

# 4. Get unique cities from the dataset
unique_cities = retail_df.select("City").distinct()
unique_cities.show()

# 5. Find all transactions from "Delhi" using .filter() and .where()
delhi_filter = retail_df.filter(col("City") == "Delhi")
delhi_where = retail_df.where(col("City") == "Delhi")
delhi_filter.show()
delhi_where.show()

+-------------+--------+---------+-------+-----------+--------+---------+----------+---------------+-----------+
|TransactionID|Customer|     City|Product|   Category|Quantity|UnitPrice|TotalPrice|TransactionDate|PaymentMode|
+-------------+--------+---------+-------+-----------+--------+---------+----------+---------------+-----------+
|        T1002|    Neha|Bangalore| Tablet|Electronics|       2|    30000|     60000|     2024-01-20|        UPI|
|        T1005|   Karan|   Mumbai|  Phone|Electronics|       1|    50000|     50000|     2024-02-15|       Card|
|        T1001|     Ali|   Mumbai| Laptop|Electronics|       1|    70000|     70000|     2024-01-15|       Card|
+-------------+--------+---------+-------+-----------+--------+---------+----------+---------------+-----------+

+---------+
|     City|
+---------+
|Hyderabad|
|Bangalore|
|   Mumbai|
|    Delhi|
+---------+

+-------------+--------+-----+-------+-----------+--------+---------+----------+---------------+-----------+
|T

In [0]:
# 6. Add a column DiscountedPrice = TotalPrice - 10%
retail_df = retail_df.withColumn("DiscountedPrice", col("TotalPrice") * 0.9)
retail_df.show()

# 7. Rename TransactionDate to TxnDate
retail_df = retail_df.withColumnRenamed("TransactionDate", "TxnDate")
retail_df.show()

# 8. Drop the column UnitPrice
retail_df = retail_df.drop("UnitPrice")
retail_df.show()

+-------------+--------+---------+-------+-----------+--------+----------+----------+-----------+---------------+
|TransactionID|Customer|     City|Product|   Category|Quantity|TotalPrice|   TxnDate|PaymentMode|DiscountedPrice|
+-------------+--------+---------+-------+-----------+--------+----------+----------+-----------+---------------+
|        T1003|    Ravi|Hyderabad|   Desk|  Furniture|       1|     15000|2024-02-10|Net Banking|        13500.0|
|        T1002|    Neha|Bangalore| Tablet|Electronics|       2|     60000|2024-01-20|        UPI|        54000.0|
|        T1005|   Karan|   Mumbai|  Phone|Electronics|       1|     50000|2024-02-15|       Card|        45000.0|
|        T1001|     Ali|   Mumbai| Laptop|Electronics|       1|     70000|2024-01-15|       Card|        63000.0|
|        T1006|   Farah|    Delhi|  Mouse|Electronics|       3|      3000|2024-02-18|       Cash|         2700.0|
|        T1004|    Zoya|    Delhi|  Chair|  Furniture|       4|     20000|2024-02-12|   

In [0]:
# 9. Get total sales by city
total_sales_by_city = retail_df.groupBy("City").agg(sum("TotalPrice").alias("TotalSales"))
total_sales_by_city.show()

# 10. Get average unit price by category
avg_price_by_category = retail_df.groupBy("Category").agg(
    (sum("TotalPrice")/sum("Quantity")).alias("AvgUnitPrice"))
avg_price_by_category.show()

# 11. Count of transactions grouped by PaymentMode
txn_count_by_payment = retail_df.groupBy("PaymentMode").count()
txn_count_by_payment.show()

+---------+----------+
|     City|TotalSales|
+---------+----------+
|Hyderabad|   15000.0|
|Bangalore|   60000.0|
|   Mumbai|  120000.0|
|    Delhi|   23000.0|
+---------+----------+

+-----------+-----------------+
|   Category|     AvgUnitPrice|
+-----------+-----------------+
|  Furniture|           7000.0|
|Electronics|26142.85714285714|
+-----------+-----------------+

+-----------+-----+
|PaymentMode|count|
+-----------+-----+
|Net Banking|    1|
|        UPI|    1|
|       Card|    3|
|       Cash|    1|
+-----------+-----+



In [0]:
# 12. Use a window partitioned by City to rank transactions by TotalPrice
window_spec = Window.partitionBy("City").orderBy(col("TotalPrice").desc())
ranked_df = retail_df.withColumn("rank", rank().over(window_spec))
ranked_df.show()

# 13. Use lag function to get previous transaction amount per city
window_spec_lag = Window.partitionBy("City").orderBy("TxnDate")
lag_df = retail_df.withColumn("prev_amount", lag("TotalPrice").over(window_spec_lag))
lag_df.show()

+-------------+--------+---------+-------+-----------+--------+----------+----------+-----------+---------------+----+
|TransactionID|Customer|     City|Product|   Category|Quantity|TotalPrice|   TxnDate|PaymentMode|DiscountedPrice|rank|
+-------------+--------+---------+-------+-----------+--------+----------+----------+-----------+---------------+----+
|        T1002|    Neha|Bangalore| Tablet|Electronics|       2|     60000|2024-01-20|        UPI|        54000.0|   1|
|        T1006|   Farah|    Delhi|  Mouse|Electronics|       3|      3000|2024-02-18|       Cash|         2700.0|   1|
|        T1004|    Zoya|    Delhi|  Chair|  Furniture|       4|     20000|2024-02-12|       Card|        18000.0|   2|
|        T1003|    Ravi|Hyderabad|   Desk|  Furniture|       1|     15000|2024-02-10|Net Banking|        13500.0|   1|
|        T1001|     Ali|   Mumbai| Laptop|Electronics|       1|     70000|2024-01-15|       Card|        63000.0|   1|
|        T1005|   Karan|   Mumbai|  Phone|Electr

In [0]:
# 14. Create city_region DataFrame
city_region_data = [
    ("Mumbai", "West"),
    ("Delhi", "North"),
    ("Bangalore", "South"),
    ("Hyderabad", "South")
]
city_region_df = spark.createDataFrame(city_region_data, ["City", "Region"])

# 15. Join with main DataFrame and group total sales by Region
joined_df = retail_df.join(city_region_df, "City", "left")
sales_by_region = joined_df.groupBy("Region").agg(sum("TotalPrice").alias("TotalSales"))
sales_by_region.show()

+------+----------+
|Region|TotalSales|
+------+----------+
| South|   75000.0|
|  West|  120000.0|
| North|   23000.0|
+------+----------+



In [0]:
# 16. Introduce some nulls and replace them with default values
# First create a copy with nulls
from pyspark.sql.functions import when, rand

retail_with_nulls = retail_df.withColumn("Quantity", 
    when(rand() > 0.7, None).otherwise(col("Quantity"))) \
    .withColumn("PaymentMode", 
    when(rand() > 0.7, None).otherwise(col("PaymentMode")))

# Replace nulls with defaults
retail_filled = retail_with_nulls.fillna({
    "Quantity": 1,
    "PaymentMode": "Unknown"
})
retail_filled.show()

# 17. Drop rows where Quantity is null
retail_no_null_qty = retail_with_nulls.na.drop(subset=["Quantity"])
retail_no_null_qty.show()

# 18. Fill null PaymentMode with "Unknown"
retail_filled_payment = retail_with_nulls.fillna("Unknown", subset=["PaymentMode"])
retail_filled_payment.show()

+-------------+--------+---------+-------+-----------+--------+----------+----------+-----------+---------------+
|TransactionID|Customer|     City|Product|   Category|Quantity|TotalPrice|   TxnDate|PaymentMode|DiscountedPrice|
+-------------+--------+---------+-------+-----------+--------+----------+----------+-----------+---------------+
|        T1003|    Ravi|Hyderabad|   Desk|  Furniture|       1|     15000|2024-02-10|Net Banking|        13500.0|
|        T1002|    Neha|Bangalore| Tablet|Electronics|       2|     60000|2024-01-20|    Unknown|        54000.0|
|        T1005|   Karan|   Mumbai|  Phone|Electronics|       1|     50000|2024-02-15|    Unknown|        45000.0|
|        T1001|     Ali|   Mumbai| Laptop|Electronics|       1|     70000|2024-01-15|       Card|        63000.0|
|        T1006|   Farah|    Delhi|  Mouse|Electronics|       3|      3000|2024-02-18|       Cash|         2700.0|
|        T1004|    Zoya|    Delhi|  Chair|  Furniture|       4|     20000|2024-02-12|   

In [0]:
# 19. Write a UDF to label orders
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# 1. First ensure TotalPrice is numeric
retail_df = retail_df.withColumn("TotalPrice", col("TotalPrice").cast("integer"))

# 2. Define the UDF
def label_order(amount):
    if amount > 50000: 
        return "High"
    elif amount >= 30000: 
        return "Medium"
    else: 
        return "Low"

label_order_udf = udf(label_order, StringType())

# 3. Apply the UDF
retail_df = retail_df.withColumn("OrderLabel", label_order_udf(col("TotalPrice")))

# 4. Show results
retail_df.select("TransactionID", "TotalPrice", "OrderLabel").show()

+-------------+----------+----------+
|TransactionID|TotalPrice|OrderLabel|
+-------------+----------+----------+
|        T1003|     15000|       Low|
|        T1002|     60000|      High|
|        T1005|     50000|    Medium|
|        T1001|     70000|      High|
|        T1006|      3000|       Low|
|        T1004|     20000|       Low|
+-------------+----------+----------+



In [0]:
# 20. Extract year, month, and day from TxnDate
retail_df = retail_df.withColumn("Year", year(col("TxnDate"))) \
    .withColumn("Month", month(col("TxnDate"))) \
    .withColumn("Day", dayofmonth(col("TxnDate")))
retail_df.show()

# 21. Filter transactions that happened in February
feb_txns = retail_df.filter(month(col("TxnDate")) == 2)
feb_txns.show()

+-------------+--------+---------+-------+-----------+--------+----------+----------+-----------+---------------+----------+----+-----+---+
|TransactionID|Customer|     City|Product|   Category|Quantity|TotalPrice|   TxnDate|PaymentMode|DiscountedPrice|OrderLabel|Year|Month|Day|
+-------------+--------+---------+-------+-----------+--------+----------+----------+-----------+---------------+----------+----+-----+---+
|        T1003|    Ravi|Hyderabad|   Desk|  Furniture|       1|     15000|2024-02-10|Net Banking|        13500.0|       Low|2024|    2| 10|
|        T1002|    Neha|Bangalore| Tablet|Electronics|       2|     60000|2024-01-20|        UPI|        54000.0|      High|2024|    1| 20|
|        T1005|   Karan|   Mumbai|  Phone|Electronics|       1|     50000|2024-02-15|       Card|        45000.0|    Medium|2024|    2| 15|
|        T1001|     Ali|   Mumbai| Laptop|Electronics|       1|     70000|2024-01-15|       Card|        63000.0|      High|2024|    1| 15|
|        T1006|   Fa

In [0]:
# 22. Duplicate the DataFrame using union() and remove duplicates
duplicated_df = retail_df.union(retail_df)
deduped_df = duplicated_df.dropDuplicates()
print(f"Original count: {retail_df.count()}, Duplicated count: {duplicated_df.count()}, Deduped count: {deduped_df.count()}")

Original count: 6, Duplicated count: 12, Deduped count: 6
