In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, dayofmonth, lag, month, rank, when, year
from pyspark.sql.types import *
from pyspark.sql.window import Window

# Obtain (or reuse) the Spark session used by every cell below; the
# application is registered under the name "RetailTransactions".
spark = (
    SparkSession.builder
    .appName("RetailTransactions")
    .getOrCreate()
)

Basics

1. Load retail_data.csv into a PySpark DataFrame and display schema.

In [0]:
# Read the CSV treating its first row as the header and letting Spark infer
# the column types, then print the resulting schema.
df = spark.read.csv(
    "file:/Workspace/Shared/retail_data.csv",
    header=True,
    inferSchema=True,
)
df.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- Customer: string (nullable = true)
 |-- City: string (nullable = true)
 |-- Product: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- UnitPrice: integer (nullable = true)
 |-- TotalPrice: integer (nullable = true)
 |-- TransactionDa: date (nullable = true)



2. Load the file again with inferSchema set to False, then manually cast the columns to their proper types.

In [0]:
# Re-read the same file with schema inference switched off; the numeric and
# date columns are therefore cast explicitly in the select below.
df_manual = spark.read.csv(
    "file:/Workspace/Shared/retail_data.csv",
    header=True,
    inferSchema=False,
)

# Columns passed by name are kept as read; the others are cast to the type
# the later cells rely on (int quantity, double prices, date).
df_casted = df_manual.select(
    "TransactionID",
    "Customer",
    "City",
    "Product",
    "Category",
    col("Quantity").cast("int"),
    col("UnitPrice").cast("double"),
    col("TotalPrice").cast("double"),
    col("TransactionDate").cast("date"),
    "PaymentMode",
)

Data Exploration & Filtering

3. Filter transactions where TotalPrice > 40000 .

In [0]:
# Display only the transactions whose TotalPrice is strictly above 40000.
is_high_value = col("TotalPrice") > 40000
df_casted.where(is_high_value).show()

+-------------+--------+---------+-------+-----------+--------+---------+----------+-------------+
|TransactionID|Customer|     City|Product|   Category|Quantity|UnitPrice|TotalPrice|TransactionDa|
+-------------+--------+---------+-------+-----------+--------+---------+----------+-------------+
|        T1001|     Ali|   Mumbai| Laptop|Electronics|       1|  70000.0|   70000.0|   2024-01-15|
|        T1002|    Neha|Bangalore| Tablet|Electronics|       2|  30000.0|   60000.0|   2024-01-20|
|        T1005|   Karan|   Mumbai|  Phone|Electronics|       1|  50000.0|   50000.0|   2024-02-15|
+-------------+--------+---------+-------+-----------+--------+---------+----------+-------------+



4. Get unique cities from the dataset.

In [0]:
# Project the City column and keep one row per distinct value.
cities = df_casted.select("City")
cities.dropDuplicates().show()

+---------+
|     City|
+---------+
|Bangalore|
|   Mumbai|
|    Delhi|
|Hyderabad|
+---------+



5. Find all transactions from "Delhi" using .filter() and .where()

In [0]:
# Apply the same "City is Delhi" predicate through both .filter() and
# .where() so the two outputs can be compared side by side.
is_delhi = col("City") == "Delhi"
df_casted.filter(is_delhi).show()
df_casted.where(is_delhi).show()

+-------------+--------+-----+-------+-----------+--------+---------+----------+-------------+
|TransactionID|Customer| City|Product|   Category|Quantity|UnitPrice|TotalPrice|TransactionDa|
+-------------+--------+-----+-------+-----------+--------+---------+----------+-------------+
|        T1004|    Zoya|Delhi|  Chair|  Furniture|       4|   5000.0|   20000.0|   2024-02-12|
|        T1006|   Farah|Delhi|  Mouse|Electronics|       3|   1000.0|    3000.0|   2024-02-18|
+-------------+--------+-----+-------+-----------+--------+---------+----------+-------------+

+-------------+--------+-----+-------+-----------+--------+---------+----------+-------------+
|TransactionID|Customer| City|Product|   Category|Quantity|UnitPrice|TotalPrice|TransactionDa|
+-------------+--------+-----+-------+-----------+--------+---------+----------+-------------+
|        T1004|    Zoya|Delhi|  Chair|  Furniture|       4|   5000.0|   20000.0|   2024-02-12|
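|        T1006|   Farah|Delhi|  Mouse|Electronics|       3|   1000.0|    3000.0|   2024-02-18|
+-------------+--------+-----+-------+-----------+--------+---------+----------+-------------+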

Data Manipulation

6. Add a column DiscountedPrice equal to TotalPrice minus a 10% discount (TotalPrice × 0.9).

In [0]:
# DiscountedPrice is TotalPrice multiplied by 0.9, i.e. the total after a
# 10% discount.
discount_factor = 0.9
df_discounted = df_casted.withColumn(
    "DiscountedPrice",
    col("TotalPrice") * discount_factor,
)

7. Rename TransactionDate to TxnDate .

In [0]:
# Rename the TransactionDate column to TxnDate; all other columns are kept.
df_renamed = df_discounted.withColumnRenamed(
    existing="TransactionDate",
    new="TxnDate",
)

8. Drop the column UnitPrice .

In [0]:
# Remove UnitPrice from the renamed frame.
columns_to_drop = ["UnitPrice"]
df_dropped = df_renamed.drop(*columns_to_drop)

Aggregations

9. Get total sales by city.

In [0]:
# Sum TotalPrice per city and expose the aggregate under the name TotalSales.
sales_by_city = (
    df_dropped
    .groupBy("City")
    .agg({"TotalPrice": "sum"})
    .withColumnRenamed("sum(TotalPrice)", "TotalSales")
)
sales_by_city.show()

+---------+----------+
|     City|TotalSales|
+---------+----------+
|Bangalore|   60000.0|
|   Mumbai|  120000.0|
|    Delhi|   23000.0|
|Hyderabad|   15000.0|
+---------+----------+



10. Get average unit price by category.

In [0]:
# Average UnitPrice per category, exposed under the name AvgUnitPrice.
# df_casted is used here because UnitPrice was dropped from df_dropped.
avg_price_by_category = (
    df_casted
    .groupBy("Category")
    .agg({"UnitPrice": "avg"})
    .withColumnRenamed("avg(UnitPrice)", "AvgUnitPrice")
)
avg_price_by_category.show()

+-----------+------------+
|   Category|AvgUnitPrice|
+-----------+------------+
|Electronics|     37750.0|
|  Furniture|     10000.0|
+-----------+------------+



11. Count of transactions grouped by PaymentMode.

In [0]:
# Number of transactions recorded for each PaymentMode value.
by_payment_mode = df_casted.groupBy("PaymentMode")
by_payment_mode.count().show()

+-----------+-----+
|PaymentMode|count|
+-----------+-----+
|Net Banking|    1|
|       Card|    3|
|       Cash|    1|
|        UPI|    1|
+-----------+-----+



Window Functions

12. Use a window partitioned by City to rank transactions by TotalPrice .

In [0]:
# Rank the transactions inside each city from the highest TotalPrice to the
# lowest. `col`, `rank` and `Window` come from the notebook's import cell, so
# they are not re-imported here.
window_spec = (
    Window
    .partitionBy("City")
    .orderBy(col("TotalPrice").desc())
)

df_casted.withColumn("Rank", rank().over(window_spec)).show()

13. Use lag function to get previous transaction amount per city.

In [0]:
# "Previous transaction" means the chronologically preceding transaction in
# the same city, so the window is ordered by TransactionDate rather than by
# TotalPrice. TransactionID is a tie-breaker that keeps the order
# deterministic when two transactions share a date.
# `lag` and `Window` come from the notebook's import cell.
window_spec = Window.partitionBy("City").orderBy("TransactionDate", "TransactionID")

# PrevTransaction holds the TotalPrice of the row one position earlier in the
# window.
df_with_lag = df_casted.withColumn("PrevTransaction", lag("TotalPrice", 1).over(window_spec))

df_with_lag.select(
    "TransactionID", "City", "TransactionDate", "TotalPrice", "PrevTransaction"
).show()


Joins

14. Create a second DataFrame city_region

In [0]:
# Small lookup table mapping each city to its region.
columns = ["City", "Region"]
data = [
    ("Mumbai", "West"),
    ("Delhi", "North"),
    ("Bangalore", "South"),
    ("Hyderabad", "South"),
]
city_region = spark.createDataFrame(data, schema=columns)

15. Join with main DataFrame and group total sales by Region.

In [0]:
# Left-join the transactions to the city/region lookup on City, then sum
# TotalPrice per Region and expose it as TotalSales.
joined_df = df_casted.join(city_region, "City", "left")

sales_by_region = (
    joined_df
    .groupBy("Region")
    .agg({"TotalPrice": "sum"})
    .withColumnRenamed("sum(TotalPrice)", "TotalSales")
)
sales_by_region.show()

Nulls and Data Cleaning

16. Introduce some nulls and replace them with default values.

In [0]:
# Blank out Quantity for transaction T1003 to simulate missing data; every
# other row keeps its original Quantity.
is_t1003 = col("TransactionID") == "T1003"
quantity_with_null = when(is_t1003, None).otherwise(col("Quantity"))
df_nulls = df_casted.withColumn("Quantity", quantity_with_null)

# Replace the null Quantity values with the default 0 and display the result.
df_filled = df_nulls.na.fill({"Quantity": 0})
df_filled.show()

17. Drop rows where Quantity is null.

In [0]:
# Drop the rows whose Quantity is null. This must run on df_nulls (the frame
# from step 16 in which Quantity was nulled for T1003); running it on
# df_casted, where no Quantity was nulled, would drop nothing.
df_nulls.dropna(subset=["Quantity"]).show()

18. Fill null PaymentMode with "Unknown".

In [0]:
# Substitute the placeholder "Unknown" wherever PaymentMode is null.
df_filled_pm = df_casted.na.fill("Unknown", subset=["PaymentMode"])

Custom Functions

19. Write a UDF to label orders

In [0]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def label_order(amount):
    """Classify an order amount as "High", "Medium" or "Low".

    Args:
        amount: Numeric order total, or None when the value is missing.

    Returns:
        "High" when amount is above 50000, "Medium" when it is between
        30000 and 50000 inclusive, "Low" otherwise. None is returned for a
        None amount so that a null input yields a null label instead of
        raising TypeError on the comparison.
    """
    # A null column value reaches the UDF as None; comparing None with an
    # int raises TypeError in Python 3, so short-circuit first.
    if amount is None:
        return None
    if amount > 50000:
        return "High"
    if amount >= 30000:
        return "Medium"
    return "Low"

# Wrap label_order as a Spark UDF that returns a string.
label_udf = udf(f=label_order, returnType=StringType())

# Add the OrderLabel column computed from TotalPrice, then show just the
# identifying column, the amount and its label.
order_label = label_udf(col("TotalPrice"))
df_labeled = df_casted.withColumn("OrderLabel", order_label)
df_labeled.select("TransactionID", "TotalPrice", "OrderLabel").show()

Date & Time

20. Extract year, month, and day from the transaction date (TransactionDate, which step 7 renamed to TxnDate).

In [0]:
# Add Year, Month and Day columns derived from TransactionDate and display
# the widened frame.
txn_date = col("TransactionDate")
with_date_parts = (
    df_casted
    .withColumn("Year", year(txn_date))
    .withColumn("Month", month(txn_date))
    .withColumn("Day", dayofmonth(txn_date))
)
with_date_parts.show()

21. Filter transactions that happened in February.

In [0]:
# Keep only the transactions whose TransactionDate falls in month 2 (February).
in_february = month(col("TransactionDate")) == 2
df_casted.where(in_february).show()

Union & Duplicate Handling

22. Duplicate the DataFrame using union() and remove duplicates.

In [0]:
# Stack the frame on top of itself so every row appears twice, then collapse
# the fully identical rows back to a single copy each.
df_union = df_casted.union(df_casted)
df_unique = df_union.distinct()
df_unique.show()