In [0]:
spark

In [0]:
# Basics
from pyspark.sql.functions import *
from pyspark.sql.window import Window
# 1. Load retail_data.csv into a PySpark DataFrame and display schema.
retaildf=spark.read.csv("file:/Workspace/Shared/retail_data_17.csv",header=True,inferSchema=True)
retaildf.printSchema()
# 2. Infer schema as False — then manually cast columns.
df_raw = spark.read.option("header", True).option("inferSchema", False).csv("file:/Workspace/Shared/retail_data_17.csv")
df_casted = df_raw.withColumn("Quantity", col("Quantity").cast("int")) \
                  .withColumn("UnitPrice", col("UnitPrice").cast("double")) \
                  .withColumn("TotalPrice", col("TotalPrice").cast("double")) \
                  .withColumn("TransactionDate", to_date(col("TransactionDate")))\
                  .withColumn("PaymentMode", col("PaymentMode").cast("string"))
df_casted.printSchema()
# Data Exploration & Filtering
# 3. Filter transactions where TotalPrice > 40000 .
retaildf.filter(col("TotalPrice")>40000)
# 4. Get unique cities from the dataset.
retaildf.select("City").distinct().show()
# 5. Find all transactions from "Delhi" using .filter() and .where() .
retaildf.filter(col("City")=="Delhi").show()
retaildf.where(col("City")=="Delhi").show()
# Data Manipulation
# 6. Add a column DiscountedPrice = TotalPrice - 10%.
retaildf.withColumn("DiscountedPrice",col("TotalPrice")*0.9)
# 7. Rename TransactionDate to TxnDate .
retaildf.withColumnRenamed("TransactionDate","TxnDate")
# 8. Drop the column UnitPrice .
retaildf.drop("UnitPrice")
# Aggregations
# 9. Get total sales by city.
retaildf.groupBy("City").sum("TotalPrice").show()
# 10. Get average unit price by category.
retaildf.groupBy("Category").avg("UnitPrice").show()
# 11. Count of transactions grouped by PaymentMode.
retaildf.groupBy("PaymentMode").count().show()
# Window Functions
# 12. Use a window partitioned by City to rank transactions by TotalPrice .
w=Window.partitionBy("City").orderBy(col("TotalPrice").desc())
retaildf.withColumn("rank",rank().over(w)).show()
# 13. Use lag function to get previous transaction amount per city.
w1 = Window.partitionBy("City").orderBy("TransactionDate")
retaildf.withColumn("PrevTxnAmt", lag("TotalPrice").over(w1)).show()
# Joins
# 14. Create a second DataFrame city_region :
# City,Region
# Mumbai,West
# Delhi,North
# Bangalore,South
# Hyderabad,South
city_region = spark.createDataFrame([
    ("Mumbai", "West"),
    ("Delhi", "North"),
    ("Bangalore", "South"),
    ("Hyderabad", "South")
], ["City", "Region"])

joined = retaildf.join(city_region, "City", "left")
joined.show()
# 15. Join with main DataFrame and group total sales by Region.
joined.groupBy("Region").sum("TotalPrice").show()
# Nulls and Data Cleaning
# 16. Introduce some nulls and replace them with default values.
null = retaildf.withColumn("PaymentMode", when(col("PaymentMode") == "Cash", None).otherwise(col("PaymentMode")))
filled = null.fillna({"PaymentMode": "Unknown"})
# 17. Drop rows where Quantity is null.
filled.dropna(subset=["Quantity"]).show()
# 18. Fill null PaymentMode with "Unknown".
retaildf.fillna({"PaymentMode": "Unknown"}).show()
# Custom Functions
# 19. Write a UDF to label orders:
# def label_order(amount):
# if amount > 50000: return "High"
# elif amount >= 30000: return "Medium"
# else: return "Low"
# Apply this to classify TotalPrice .
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def label_order(amount):
    if amount > 50000:
        return "High"
    elif amount >= 30000:
        return "Medium"
    else:
        return "Low"

label_udf = udf(label_order, StringType())
retaildf = retaildf.withColumn("OrderLabel", label_udf(col("TotalPrice")))
retaildf.select("TransactionID", "TotalPrice", "OrderLabel").show()

# Date & Time
# 20. Extract year, month, and day from TxnDate .
retaildf = retaildf.withColumn("Year", year("TransactionDate")) \
       .withColumn("Month", month("TransactionDate")) \
       .withColumn("Day", dayofmonth("TransactionDate"))

# 21. Filter transactions that happened in February.
retaildf.filter(col("Month") == 2).show()
# Union & Duplicate Handling
# 22. Duplicate the DataFrame using union() and remove duplicates.
retaildup=retaildf.union(retaildf)
retaildup.dropDuplicates()

root
 |-- TransactionID: string (nullable = true)
 |-- Customer: string (nullable = true)
 |-- City: string (nullable = true)
 |-- Product: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- UnitPrice: integer (nullable = true)
 |-- TotalPrice: integer (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- PaymentMode: string (nullable = true)

root
 |-- TransactionID: string (nullable = true)
 |-- Customer: string (nullable = true)
 |-- City: string (nullable = true)
 |-- Product: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- TotalPrice: double (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- PaymentMode: string (nullable = true)

+---------+
|     City|
+---------+
|Bangalore|
|   Mumbai|
|    Delhi|
|Hyderabad|
+---------+

+-------------+--------+-----+-------+-----------+--------+------

DataFrame[TransactionID: string, Customer: string, City: string, Product: string, Category: string, Quantity: int, UnitPrice: int, TotalPrice: int, TransactionDate: date, PaymentMode: string, OrderLabel: string, Year: int, Month: int, Day: int]