In [0]:
from pyspark.sql import SparkSession
import time
from pyspark.sql.functions import (
    col, count, avg, sum as spark_sum, max as spark_max, min as spark_min,
    desc, when, year, month, dayofmonth, round, lit, datediff, current_date,
    countDistinct
)

In [0]:

# loading data

print("Loading Daily Market Prices of Commodity in India from 2022 to 2025...")

df = spark.table("commodities_india")

print(f"Data loaded successfully")
print(f"   • Total rows: {df.count():,}")
print(f"   • Columns: {len(df.columns)}")
print()

# show schema
print("Schema:")
df.printSchema()
print()

# preview data
print("First 10 rows:")
df.show(10)

# show column names and types
print("Column Information:")
for i, (col_name, dtype) in enumerate(df.dtypes, 1):
    print(f"   {i:2d}. {col_name:20s} → {dtype}")
print()


# show missing values
print("Missing values:")
null_counts = df.select([
    spark_sum(when(col(c).isNull(), 1).otherwise(0)).alias(c) 
    for c in df.columns
])
null_counts.show()
print()

# show any duplicates
total_rows = df.count()
unique_rows = df.dropDuplicates().count()
duplicate_rows = total_rows - unique_rows

print(f"Duplicate Check:")
print(f"   • Total rows: {total_rows:,}")
print(f"   • Unique rows: {unique_rows:,}")
print(f"   • Duplicate rows: {duplicate_rows:,}")
print()


Loading Daily Market Prices of Commodity in India from 2022 to 2025...
Data loaded successfully
   • Total rows: 20,090,620
   • Columns: 11

Schema:
root
 |-- State: string (nullable = true)
 |-- District: string (nullable = true)
 |-- Market: string (nullable = true)
 |-- Commodity: string (nullable = true)
 |-- Variety: string (nullable = true)
 |-- Grade: string (nullable = true)
 |-- Arrival_Date: date (nullable = true)
 |-- Min_Price: double (nullable = true)
 |-- Max_Price: double (nullable = true)
 |-- Modal_Price: double (nullable = true)
 |-- Commodity_Code: long (nullable = true)


First 10 rows:
+--------------+------------+--------------------+--------------+-------------+-----+------------+---------+---------+-----------+--------------+
|         State|    District|              Market|     Commodity|      Variety|Grade|Arrival_Date|Min_Price|Max_Price|Modal_Price|Commodity_Code|
+--------------+------------+--------------------+--------------+-------------+-----+--------

In [0]:
# filter 1

print("FILTER 1: REMOVING INVALID PRICES")

# Filter out rows where prices are null or negative
df_filtered = df.filter(
    (col("Min_Price").isNotNull()) &
    (col("Max_Price").isNotNull()) &
    (col("Modal_Price").isNotNull()) &
    (col("Min_Price") > 0) &
    (col("Max_Price") > 0) &
    (col("Modal_Price") > 0)
)

rows_after_filter1 = df_filtered.count()
rows_removed = df.count() - rows_after_filter1

print(f"Filter applied: Remove invalid prices")
print(f"   • Rows removed: {rows_removed:,}")
print(f"   • Rows remaining: {rows_after_filter1:,}")
print()

# filter 2

print("FILTER 2: REMOVING DUPLICATE ROWS")

# checking first
total_before = df_filtered.count()
unique_before = df_filtered.dropDuplicates().count()
duplicates = total_before - unique_before

print(f"Before removing duplicates:")
print(f"   • Total rows: {total_before:,}")
print(f"   • Unique rows: {unique_before:,}")
print(f"   • Duplicate rows: {duplicates:,}")
print()

df_filtered = df_filtered.dropDuplicates()

rows_after_filter2 = df_filtered.count()

print(f"Filter applied: Remove duplicate rows")
print(f"   • Rows removed: {duplicates:,}")
print(f"   • Rows remaining: {rows_after_filter2:,}")
print()


FILTER 1: REMOVING INVALID PRICES
Filter applied: Remove invalid prices
   • Rows removed: 30,869
   • Rows remaining: 20,059,751

FILTER 2: REMOVING DUPLICATE ROWS
Before removing duplicates:
   • Total rows: 20,059,751
   • Unique rows: 20,059,718
   • Duplicate rows: 33

Filter applied: Remove duplicate rows
   • Rows removed: 33
   • Rows remaining: 20,059,718



In [0]:
# join operation using regional data

print("JOIN OPERATION: Mapping States to Regions")

# create a state region lookup table
# group states in India by geographic region
state_region_data = [
    # South
    ("Andhra Pradesh", "South", "Southern Region"),
    ("Telangana", "South", "Southern Region"),
    ("Karnataka", "South", "Southern Region"),
    ("Kerala", "South", "Southern Region"),
    ("Tamil Nadu", "South", "Southern Region"),
    ("Puducherry", "South", "Southern Region"),
    ("Pondicherry", "South", "Southern Region"),
    ("Andaman and Nicobar", "South", "Southern Region"),
    
    # West
    ("Maharashtra", "West", "Western Region"),
    ("Gujarat", "West", "Western Region"),
    ("Goa", "West", "Western Region"),
    ("Rajasthan", "West", "Western Region"),
    ("Daman and Diu", "West", "Western Region"),
    ("Dadra and Nagar Haveli", "West", "Western Region"),
    
    # North
    ("Uttar Pradesh", "North", "Northern Region"),
    ("Punjab", "North", "Northern Region"),
    ("Haryana", "North", "Northern Region"),
    ("Delhi", "North", "Northern Region"),
    ("NCT of Delhi", "North", "Northern Region"), 
    ("Himachal Pradesh", "North", "Northern Region"),
    ("Uttarakhand", "North", "Northern Region"),
    ("Uttrakhand", "North", "Northern Region"),
    ("Jammu and Kashmir", "North", "Northern Region"),
    ("Chandigarh", "North", "Northern Region"),
    
    # East
    ("West Bengal", "East", "Eastern Region"),
    ("Bihar", "East", "Eastern Region"),
    ("Jharkhand", "East", "Eastern Region"),
    ("Odisha", "East", "Eastern Region"),
    
    # Northeast
    ("Assam", "Northeast", "Northeastern Region"),
    ("Manipur", "Northeast", "Northeastern Region"),
    ("Meghalaya", "Northeast", "Northeastern Region"),
    ("Tripura", "Northeast", "Northeastern Region"),
    ("Nagaland", "Northeast", "Northeastern Region"),
    ("Mizoram", "Northeast", "Northeastern Region"),
    ("Arunachal Pradesh", "Northeast", "Northeastern Region"),
    ("Sikkim", "Northeast", "Northeastern Region"),
    
    # Central
    ("Madhya Pradesh", "Central", "Central Region"),
    ("Chhattisgarh", "Central", "Central Region"),
    ("Chattisgarh", "Central", "Central Region")
]

# creating dataframe
state_region_df = spark.createDataFrame(
    state_region_data, 
    ["State", "Region", "Region_Full_Name"]
)


print("Regions in lookup table:")
state_region_df.select("Region", "Region_Full_Name").distinct().orderBy("Region").show(truncate=False)
print()

# perform join operation
print("LEFT JOIN")

print("Joining main data with region lookup...")
df_with_region = df_filtered.join(
    state_region_df,
    on="State",
    how="left" 
)

print(f"Join complete")
print(f"   • Records before join: {df_filtered.count():,}")
print(f"   • Records after join: {df_with_region.count():,}")
print()

# checking for unmatched states just in case 
unmatched = df_with_region.filter(col("Region").isNull())
unmatched_count = unmatched.count()

if unmatched_count > 0:
    print(f"Warning: {unmatched_count:,} records didn't match (Region is NULL)")
    print("   States without region mapping:")
    unmatched.select("State").distinct().show(10, truncate=False)
    print()
else:
    print("All records matched")
    print()

# sample of joined data
print("Sample data with region:")
df_with_region.select(
    "State", "Region", "District", "Commodity", 
    "Modal_Price", "Arrival_Date"
).show(15, truncate=False)
print()


# look at join results

print("JOIN VERIFICATION")

# counting records by region
print("Records by region:")
region_counts = df_with_region \
    .filter(col("Region").isNotNull()) \
    .groupBy("Region", "Region_Full_Name") \
    .agg(
        count("*").alias("total_records"),
        countDistinct("State").alias("num_states")
    ) \
    .orderBy(desc("total_records"))

region_counts.show(truncate=False)
print()

# aggregate operation: Group by region

print("AGGREGATION AFTER JOIN: REGIONAL ANALYSIS")

regional_stats = df_with_region \
    .filter(col("Region").isNotNull()) \
    .groupBy("Region", "Region_Full_Name") \
    .agg(
        count("*").alias("total_records"),
        countDistinct("State").alias("num_states"),           
        countDistinct("Commodity").alias("unique_commodities"), 
        countDistinct("Market").alias("unique_markets"), 
        round(avg("Modal_Price"), 2).alias("avg_modal_price"),
        round(spark_min("Modal_Price"), 2).alias("min_price"),
        round(spark_max("Modal_Price"), 2).alias("max_price"),
        round(avg("Min_Price"), 2).alias("avg_min_price"),
        round(avg("Max_Price"), 2).alias("avg_max_price")
    ) \
    .orderBy(desc("avg_modal_price"))

print("Regional price comparison:")
regional_stats.show(truncate=False)
print()

JOIN OPERATION: Mapping States to Regions
Regions in lookup table:
+---------+-------------------+
|Region   |Region_Full_Name   |
+---------+-------------------+
|Central  |Central Region     |
|East     |Eastern Region     |
|North    |Northern Region    |
|Northeast|Northeastern Region|
|South    |Southern Region    |
|West     |Western Region     |
+---------+-------------------+


LEFT JOIN
Joining main data with region lookup...
Join complete
   • Records before join: 20,059,718
   • Records after join: 20,059,718

All records matched

Sample data with region:
+--------------+-------+------------+----------------------+-----------+------------+
|State         |Region |District    |Commodity             |Modal_Price|Arrival_Date|
+--------------+-------+------------+----------------------+-----------+------------+
|Maharashtra   |West   |Satara      |Tomato                |2000.0     |2023-08-27  |
|Rajasthan     |West   |Hanumangarh |Onion                 |2400.0     |2023-08-27 

In [0]:
# column transformations using withColumn

df_transformed = df_with_region

df_transformed = df_transformed \
    .withColumn("Avg_Price", (col("Min_Price") + col("Max_Price")) / 2)

df_transformed = df_transformed \
    .withColumn("Price_Volatility_Pct", 
                round(((col("Max_Price") - col("Min_Price")) / col("Modal_Price")) * 100, 2))

# Show sample
df_transformed.select(
    "Commodity", "Min_Price", "Max_Price", "Modal_Price", 
    "Avg_Price", "Price_Volatility_Pct"
).show(10)


+--------------------+---------+---------+-----------+---------+--------------------+
|           Commodity|Min_Price|Max_Price|Modal_Price|Avg_Price|Price_Volatility_Pct|
+--------------------+---------+---------+-----------+---------+--------------------+
|              Tomato|   1000.0|   2000.0|     2000.0|   1500.0|                50.0|
|               Onion|   2400.0|   2400.0|     2400.0|   2400.0|                 0.0|
|        Bitter gourd|   2590.0|   2690.0|     2640.0|   2640.0|                3.79|
|             Cabbage|    400.0|   1800.0|     1000.0|   1100.0|               140.0|
|Little gourd (Kun...|   1400.0|   1800.0|     1600.0|   1600.0|                25.0|
|           Colacasia|   1200.0|   1500.0|     1500.0|   1350.0|                20.0|
|Bhindi (Ladies Fi...|   3000.0|   3000.0|     3000.0|   3000.0|                 0.0|
|                Rice|   2730.0|   2830.0|     2780.0|   2780.0|                 3.6|
|              Potato|   1130.0|   1340.0|     1235.0|

In [0]:
# sql queries

df_transformed.createOrReplaceTempView("df_transformed")

print("SQL QUERY 1: Top 10 Most Volatile Commodities by Region")

query1 = """
SELECT 
    Region,
    Commodity,
    ROUND(AVG(Price_Volatility_Pct), 2) as Avg_Volatility,
    COUNT(*) as Record_Count,
    ROUND(AVG(Avg_Price), 2) as Avg_Price
FROM df_transformed
WHERE Price_Volatility_Pct IS NOT NULL
GROUP BY Region, Commodity
ORDER BY Avg_Volatility DESC
LIMIT 10
"""

result1 = spark.sql(query1)
result1.show(10, truncate=False)

print()


print("SQL QUERY 2: Price Comparison - Above vs Below Modal Price by State")


query2 = """
SELECT 
    State,
    COUNT(CASE WHEN Avg_Price > Modal_Price THEN 1 END) as Above_Modal_Count,
    COUNT(CASE WHEN Avg_Price <= Modal_Price THEN 1 END) as Below_Modal_Count,
    ROUND(AVG(CASE WHEN Avg_Price > Modal_Price THEN Avg_Price END), 2) as Avg_Price_Above_Modal,
    ROUND(AVG(CASE WHEN Avg_Price <= Modal_Price THEN Avg_Price END), 2) as Avg_Price_Below_Modal,
    COUNT(*) as Total_Records
FROM df_transformed
WHERE Avg_Price IS NOT NULL AND Modal_Price IS NOT NULL
GROUP BY State
HAVING COUNT(*) > 100
ORDER BY Above_Modal_Count DESC
LIMIT 15
"""

result2 = spark.sql(query2)
result2.show(15, truncate=False)

print()

SQL QUERY 1: Top 10 Most Volatile Commodities by Region
+-------+-----------------------+--------------+------------+---------+
|Region |Commodity              |Avg_Volatility|Record_Count|Avg_Price|
+-------+-----------------------+--------------+------------+---------+
|West   |Brinjal                |210.48        |71133       |3545.68  |
|North  |Mace                   |128.0         |1           |3400.0   |
|Central|Garlic                 |124.48        |34784       |6824.65  |
|West   |Dry Grapes             |124.29        |236         |14531.36 |
|West   |She Buffalo            |115.69        |461         |46369.82 |
|Central|Ashwagandha            |115.57        |425         |20076.24 |
|West   |Cow                    |112.46        |484         |43931.49 |
|North  |Custard Apple (Sharifa)|109.84        |224         |7824.35  |
|West   |Goat                   |108.26        |341         |7443.78  |
|East   |Goat                   |108.12        |600         |9672.17  |
+-------

In [0]:

# Write the transformed data to a Parquet table
print("Writing transformed data to Parquet table...")
df_transformed.write.mode("overwrite").saveAsTable("commodity_analysis_results")

print("Results written to table: commodity_analysis_results")
print()

# Verify the write
print("Verifying saved data:")
saved_df = spark.sql("SELECT * FROM commodity_analysis_results")
print(f"   • Rows saved: {saved_df.count():,}")
print()

saved_df.show(10)
print()

Writing transformed data to Parquet table...
Results written to table: commodity_analysis_results

Verifying saved data:
   • Rows saved: 20,059,718

+--------------+-------------+-------------+--------------------+-----------+------+------------+---------+---------+-----------+--------------+-------+----------------+---------+--------------------+
|         State|     District|       Market|           Commodity|    Variety| Grade|Arrival_Date|Min_Price|Max_Price|Modal_Price|Commodity_Code| Region|Region_Full_Name|Avg_Price|Price_Volatility_Pct|
+--------------+-------------+-------------+--------------------+-----------+------+------------+---------+---------+-----------+--------------+-------+----------------+---------+--------------------+
|Madhya Pradesh|      Khandwa|Khandwa (F&V)|        Sponge gourd|      Other|   FAQ|  2023-08-27|   1200.0|   3500.0|     2545.0|           311|Central|  Central Region|   2350.0|               90.37|
|   Maharashtra|         Pune|         Pune|  

In [0]:
print("Execution plan for the aggregation query:")
result1.explain(True)

Execution plan for the aggregation query:
== Parsed Logical Plan ==
'GlobalLimit 10
+- 'LocalLimit 10
   +- 'Sort ['Avg_Volatility DESC NULLS LAST], true
      +- 'Aggregate ['Region, 'Commodity], ['Region, 'Commodity, 'ROUND('AVG('Price_Volatility_Pct), 2) AS Avg_Volatility#17878, 'COUNT(1) AS Record_Count#17879, 'ROUND('AVG('Avg_Price), 2) AS Avg_Price#17880]
         +- 'Filter isnotnull('Price_Volatility_Pct)
            +- 'UnresolvedRelation [df_transformed], [], false

== Analyzed Logical Plan ==
Region: string, Commodity: string, Avg_Volatility: double, Record_Count: bigint, Avg_Price: double
GlobalLimit 10
+- LocalLimit 10
   +- Sort [Avg_Volatility#17878 DESC NULLS LAST], true
      +- Aggregate [Region#16594, Commodity#16236], [Region#16594, Commodity#16236, round(avg(Price_Volatility_Pct#16599), 2) AS Avg_Volatility#17878, count(1) AS Record_Count#17879L, round(avg(Avg_Price#16597), 2) AS Avg_Price#17880]
         +- Filter isnotnull(Price_Volatility_Pct#16599)
            

In [0]:
# Actions vs. Transformations

import time
from pyspark.sql.functions import col

df = spark.table("commodities_india")

# Transformation: Select specific columns (lazy)
start_transformation_time = time.time()
transformation = df.select("State", "Commodity", "Market", "Min_Price", "Max_Price", "Modal_Price")
end_transformation_time = time.time()
transformation_time = end_transformation_time - start_transformation_time

# Action: Count the number of records (eager)
start_action_count_time = time.time()
record_count = transformation.count()
end_action_count_time = time.time()
action_count_time = end_action_count_time - start_action_count_time

# Action: Show the first few rows (eager)
start_action_show_time = time.time()
sample_data = transformation.show()
end_action_show_time = time.time()
action_show_time = end_action_show_time - start_action_show_time

# results
print(f"Transformation Time: {transformation_time:.4f} seconds")
print(f"Record Count: {record_count:,}")
print(f"Action (Show) Time: {action_show_time:.4f} seconds")

+--------------+----------------+--------------------+---------+---------+-----------+
|         State|       Commodity|              Market|Min_Price|Max_Price|Modal_Price|
+--------------+----------------+--------------------+---------+---------+-----------+
|Andhra Pradesh|   Gur (Jaggery)|            Chittoor|   3200.0|   3500.0|     3500.0|
|    Tamil Nadu|          Garlic|Gudalur (Uzhavar ...|  28000.0|  30000.0|    30000.0|
|    Tamil Nadu|  Ginger (Green)|Gudalur (Uzhavar ...|   7000.0|   8000.0|     8000.0|
|    Tamil Nadu|    Green Chilli|Gudalur (Uzhavar ...|   4500.0|   5000.0|     5000.0|
|    Tamil Nadu|   Mint (Pudina)|Gudalur (Uzhavar ...|   4500.0|   5000.0|     5000.0|
|    Tamil Nadu|           Onion|Gudalur (Uzhavar ...|   4200.0|   4400.0|     4400.0|
|    Tamil Nadu|     Onion Green|Gudalur (Uzhavar ...|   7000.0|   7400.0|     7400.0|
|    Tamil Nadu|          Potato|Gudalur (Uzhavar ...|   4500.0|   5000.0|     5000.0|
|    Tamil Nadu|         Pumpkin|Gudalur (U