# 🌍 Air Quality Index (AQI) Forecasting and Early Warning System for Indian Cities

**Author:** Keerthi Amulya
**Date:** January 31, 2026  
**Challenge:** Databricks 14-Days AI Challenge - Capstone Project  
**Sponsors:** Databricks | Codebasics | Indian Data Club

---

## 📋 Table of Contents

1. Executive Summary & Problem Statement
2. Data Acquisition & Setup
3. Data Cleaning (bronze to silver)
4. Feature Engineering & Gold Layer
6. Machine Learning Models

%md
---
# 1. Executive Summary & Problem Statement

## 1.1 The Challenge: Air Pollution Crisis in India

Air pollution has become one of India's most pressing public health emergencies. According to the World Health Organization, **13 of the world's 20 most polluted cities are in India**. Every year, air pollution contributes to over 1.67 million premature deaths in the country, making it a silent killer that affects millions of citizens daily.

### Current Limitations:
- **Reactive Monitoring:** Current AQI systems only report historical data, not predictions
- **No Advance Warning:** Citizens cannot plan outdoor activities safely
- **Healthcare Unpreparedness:** Hospitals face sudden surges in respiratory cases
- **Policy Delays:** Government interventions happen after pollution events occur

## 1.2 Our Solution: AI-Powered Predictive System

We propose an **intelligent AQI forecasting and early warning system** that:

1. **Predicts AQI values 3-7 days in advance** using historical patterns and machine learning
2. **Identifies pollution trends** across different cities and seasons
3. **Generates early warnings** when AQI is predicted to exceed safe thresholds
4. **Provides actionable insights** for multiple stakeholders

## 1.3 Target Stakeholders

| Stakeholder | Use Case | Expected Benefit |
|-------------|----------|------------------|
| **Citizens** | Plan outdoor activities based on 7-day forecasts | Reduced health risks, better quality of life |
| **Healthcare** | Prepare for respiratory case surges | Optimal resource allocation, lives saved |
| **Policymakers** | Implement preventive traffic/industrial restrictions | Proactive pollution control, reduced severity |
| **Researchers** | Analyze long-term pollution patterns | Better understanding of pollution drivers |

## 1.4 Success Metrics

- **Forecast Accuracy:** RMSE < 20 for 3-day predictions
- **Classification Performance:** >85% accuracy for high pollution alerts
- **Early Warning Precision:** Catch 90%+ of severe pollution events with <15% false alarms
- **Business Impact:** Enable advance planning for 50+ million people across major cities

## 1.5 Dataset Overview

**Source:** Central Pollution Control Board (CPCB) Daily AQI Bulletins  
**Repository:** https://github.com/urbanemissionsinfo/AQI_bulletins  
**Time Period:** 2015-2023 (9 years of daily measurements)  
**Coverage:** 278+ cities across India  
**Records:** ~300,000 daily observations  

**Features:**
- `date`: Date of measurement (YYYY-MM-DD)
- `city`: Name of city
- `aqi`: Average Air Quality Index value
- `aqi_category`: Government-defined category (Good/Satisfactory/Moderate/Poor/Very Poor/Severe)
- `station_count`: Number of active monitoring stations
- `prominent_pollutant`: Primary pollutant driving AQI (PM2.5, PM10, NO2, CO, SO2, O3)

---
# 1. Setup and Imports

In [0]:
# Import required libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import pandas as pd
import numpy as np

print("Libraries imported successfully!")
print(f"PySpark version: {spark.version}")

Libraries imported successfully!
PySpark version: 4.0.0


%md
## 1.1 Unity Catalog Architecture Setup

I implement the **Medallion Architecture** (Bronze → Silver → Gold) using Unity Catalog for proper data governance and organization.

### Architecture Overview:
```
aqi_india (Catalog)
├── bronze (Raw Data Layer)
│   └── aqi_bulletins (Raw ingestion)
├── silver (Cleaned Data Layer)
│   └── aqi_cleaned (Validated & transformed)
└── gold (Analytics Layer)
    ├── aqi_ml_features (ML-ready dataset)
    ├── city_summary (Aggregated city stats)
    └── monthly_trends (Time-based aggregates)
```

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import *
import re

# Widgets (optional)
try:
    dbutils.widgets.text("raw_dir", "/Volumes/aqi_india/bronze/raw_data/")
    dbutils.widgets.text("raw_file_hint", "AllIndiaBulletins")
except Exception as e:
    print("Widgets not available:", e)

RAW_DIR = None
HINT = None
try:
    RAW_DIR = dbutils.widgets.get("raw_dir")
    HINT = dbutils.widgets.get("raw_file_hint")
except Exception:
    RAW_DIR = "/Volumes/aqi_india/bronze/raw_data/"
    HINT = "AllIndiaBulletins"

CATALOG = "aqi_india"

# Create UC objects if supported (safe to re-run)
try:
    spark.sql(f"CREATE CATALOG IF NOT EXISTS {CATALOG}")
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.bronze")
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.silver")
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.gold")
    spark.sql(f"CREATE VOLUME IF NOT EXISTS {CATALOG}.bronze.raw_data")
except Exception as e:
    print("Unity Catalog/Volume create skipped (not supported here):", e)

T_BRONZE = f"{CATALOG}.bronze.aqi_bulletins"
T_SILVER = f"{CATALOG}.silver.aqi_cleaned"
T_GOLD   = f"{CATALOG}.gold.aqi_ml_features"

# Auto-pick a CSV from RAW_DIR
raw_files = []
try:
    raw_files = [f.path for f in dbutils.fs.ls(RAW_DIR) if f.path.lower().endswith(".csv")]
except Exception as e:
    print("Could not list RAW_DIR:", RAW_DIR, e)

if not raw_files:
    raise Exception(f"No CSV found in {RAW_DIR}. Upload AllIndiaBulletins CSV to this folder first.")

candidates = [p for p in raw_files if re.search(HINT, p, re.I)]
RAW_PATH = candidates[0] if candidates else raw_files[0]
print("✅ Using RAW_PATH:", RAW_PATH)


✅ Using RAW_PATH: dbfs:/Volumes/aqi_india/bronze/raw_data/AllIndiaBulletins_Master.csv


---
# 2. Data Cleaning - Bronze to Silver

In [0]:
print("=" * 80)
print("DATA CLEANING PROCESS")
print("=" * 80)

correct_schema = StructType([
    StructField("date", StringType(), True),
    StructField("city", StringType(), True),
    StructField("station_count", IntegerType(), True),
    StructField("air_quality", StringType(), True),
    StructField("aqi", DoubleType(), True),
    StructField("prominent_pollutant", StringType(), True)
])

raw_df = spark.read.csv(
    RAW_PATH,
    header=True,
    schema=correct_schema
)

print(f"Raw records loaded: {raw_df.count():,}")
print("\nSchema (note: city is STRING, not double):")
raw_df.printSchema()

bronze_df = raw_df.withColumn("ingested_at", current_timestamp()).withColumn("raw_file", lit(RAW_PATH))
(bronze_df.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema","true")
    .saveAsTable(T_BRONZE)
)
print(f"Bronze table created: {T_BRONZE} (rows={bronze_df.count():,})")


DATA CLEANING PROCESS
Raw records loaded: 299,976

Schema (note: city is STRING, not double):
root
 |-- date: string (nullable = true)
 |-- city: string (nullable = true)
 |-- station_count: integer (nullable = true)
 |-- air_quality: string (nullable = true)
 |-- aqi: double (nullable = true)
 |-- prominent_pollutant: string (nullable = true)

Bronze table created: aqi_india.bronze.aqi_bulletins (rows=299,976)


In [0]:
print("Sample cities from raw data:")
raw_df.select("city").distinct().show(10, truncate=False)

Sample cities from raw data:
+---------+
|city     |
+---------+
|Lucknow  |
|NULL     |
|Agra     |
|Chennai  |
|Haldia   |
|Panchkula|
|Rohtak   |
|Jaipur   |
|Gaya     |
|Faridabad|
+---------+
only showing top 10 rows


## Step 1: Remove Corrupted Rows

In [0]:
print("STEP 1: Removing corrupted rows")
print("-" * 60)

starting_count = raw_df.count()

# Remove rows where critical fields are NULL or date is malformed
silver_df = raw_df.filter(
    col("city").isNotNull() & 
    col("aqi").isNotNull() & 
    col("date").rlike("^\\d{4}-\\d{2}-\\d{2}$")
)

removed = starting_count - silver_df.count()
print(f"Clean records: {silver_df.count():,}")
print(f"Removed: {removed:,} corrupted rows ({removed/starting_count*100:.2f}%)")

STEP 1: Removing corrupted rows
------------------------------------------------------------
Clean records: 299,972
Removed: 4 corrupted rows (0.00%)


## Step 2: Parse and Validate Dates

In [0]:
print("STEP 2: Converting and validating dates")
print("-" * 60)

# Convert string date to proper date type
silver_df = silver_df.withColumn("date_parsed", to_date(col("date"), "yyyy-MM-dd"))

# Filter only valid dates (2015-2023)
silver_df = silver_df.filter(
    (year(col("date_parsed")) >= 2015) & 
    (year(col("date_parsed")) <= 2023)
)

date_range = silver_df.agg(
    min("date_parsed").alias("min_date"),
    max("date_parsed").alias("max_date")
).collect()[0]

print(f"Valid dates: {silver_df.count():,}")
print(f"Date range: {date_range['min_date']} to {date_range['max_date']}")

STEP 2: Converting and validating dates
------------------------------------------------------------
Valid dates: 299,972
Date range: 2015-05-01 to 2023-12-31


## Step 3: Filter Invalid AQI Values

In [0]:
print("STEP 3: Filtering invalid AQI values")
print("-" * 60)

before_count = silver_df.count()

# Valid AQI range: 0-999
silver_df = silver_df.filter(
    (col("aqi") >= 0) & 
    (col("aqi") <= 999)
)

removed = before_count - silver_df.count()

aqi_stats = silver_df.agg(
    min("aqi").alias("min_aqi"),
    max("aqi").alias("max_aqi"),
    avg("aqi").alias("avg_aqi")
).collect()[0]

print(f"Valid AQI records: {silver_df.count():,}")
print(f"Removed: {removed:,} invalid AQI values")
print(f"AQI range: {aqi_stats['min_aqi']:.1f} to {aqi_stats['max_aqi']:.1f}")
print(f"Average AQI: {aqi_stats['avg_aqi']:.2f}")

STEP 3: Filtering invalid AQI values
------------------------------------------------------------
Valid AQI records: 299,972
Removed: 0 invalid AQI values
AQI range: 3.0 to 500.0
Average AQI: 124.69


## Step 4: Validate Station Counts

In [0]:
print("STEP 4: Handling station counts")
print("-" * 60)

# Fill NULL station counts with 1 (minimum reasonable value)
silver_df = silver_df.withColumn(
    "station_count",
    when(col("station_count").isNull(), 1)
    .when(col("station_count") <= 0, 1)
    .otherwise(col("station_count"))
)

station_stats = silver_df.agg(
    min("station_count").alias("min_stations"),
    max("station_count").alias("max_stations"),
    avg("station_count").alias("avg_stations")
).collect()[0]

print(f"Station range: {station_stats['min_stations']} to {station_stats['max_stations']}")
print(f"Average stations: {station_stats['avg_stations']:.2f}")

STEP 4: Handling station counts
------------------------------------------------------------
Station range: 1 to 39
Average stations: 1.80


## Step 5: Standardize City Names

In [0]:
print("STEP 5: Standardizing city names")
print("-" * 60)

before_cities = silver_df.select("city").distinct().count()

# Standardize major city name variations
silver_df = silver_df.withColumn("city_clean",
    when(col("city").like("%Delhi%"), "Delhi")
    .when(col("city").like("%Mumbai%"), "Mumbai")
    .when(col("city").like("%Bengaluru%"), "Bengaluru")
    .when(col("city").like("%Bangalore%"), "Bengaluru")
    .when(col("city").like("%Hyderabad%"), "Hyderabad")
    .when(col("city").like("%Chennai%"), "Chennai")
    .when(col("city").like("%Kolkata%"), "Kolkata")
    .when(col("city").like("%Pune%"), "Pune")
    .when(col("city").like("%Ahmedabad%"), "Ahmedabad")
    .when(col("city").like("%Jaipur%"), "Jaipur")
    .when(col("city").like("%Lucknow%"), "Lucknow")
    .when(col("city").like("%Kanpur%"), "Kanpur")
    .when(col("city").like("%Nagpur%"), "Nagpur")
    .when(col("city").like("%Indore%"), "Indore")
    .when(col("city").like("%Thane%"), "Thane")
    .when(col("city").like("%Bhopal%"), "Bhopal")
    .when(col("city").like("%Pimpri%"), "Pimpri-Chinchwad")
    .when(col("city").like("%Patna%"), "Patna")
    .when(col("city").like("%Vadodara%"), "Vadodara")
    .when(col("city").like("%Ghaziabad%"), "Ghaziabad")
    .when(col("city").like("%Agra%"), "Agra")
    .when(col("city").like("%Nashik%"), "Nashik")
    .when(col("city").like("%Faridabad%"), "Faridabad")
    .otherwise(trim(col("city")))
)

after_cities = silver_df.select("city_clean").distinct().count()

print(f"Cities before: {before_cities}")
print(f"Cities after: {after_cities}")
print(f"Merged {before_cities - after_cities} duplicate city name variations")

STEP 5: Standardizing city names
------------------------------------------------------------
Cities before: 277
Cities after: 275
Merged 2 duplicate city name variations


## Step 6: Remove Duplicate Records

In [0]:
print("STEP 6: Removing duplicate date-city combinations")
print("-" * 60)

before_count = silver_df.count()

# Keep only unique date-city combinations (keep first occurrence)
silver_df = silver_df.dropDuplicates(["date_parsed", "city_clean"])

removed = before_count - silver_df.count()

print(f"Unique records: {silver_df.count():,}")
print(f"Removed: {removed:,} duplicate records ({removed/before_count*100:.2f}%)")

STEP 6: Removing duplicate date-city combinations
------------------------------------------------------------
Unique records: 297,209
Removed: 2,763 duplicate records (0.92%)


## Step 7: Add Temporal Features

In [0]:
print("STEP 7: Adding temporal features")
print("-" * 60)

# Extract temporal components
silver_df = silver_df \
    .withColumn("year", year("date_parsed")) \
    .withColumn("month", month("date_parsed")) \
    .withColumn("day_of_week", dayofweek("date_parsed")) \
    .withColumn("day_of_month", dayofmonth("date_parsed")) \
    .withColumn("quarter", quarter("date_parsed")) \
    .withColumn("week_of_year", weekofyear("date_parsed")) \
    .withColumn("is_weekend", 
        when(col("day_of_week").isin([1, 7]), True).otherwise(False)
    ) \
    .withColumn("season",
        when(col("month").isin([12, 1, 2]), "Winter")
        .when(col("month").isin([3, 4, 5]), "Spring")
        .when(col("month").isin([6, 7, 8, 9]), "Monsoon")
        .otherwise("Autumn")
    )

print("Added temporal features: year, month, day_of_week, day_of_month")
print("Added: quarter, week_of_year, is_weekend, season")

STEP 7: Adding temporal features
------------------------------------------------------------
Added temporal features: year, month, day_of_week, day_of_month
Added: quarter, week_of_year, is_weekend, season


## Step 8: Create AQI Risk Categories

In [0]:
print("STEP 8: Creating AQI risk categorization")
print("-" * 60)

# Standard AQI risk levels per CPCB guidelines
silver_df = silver_df.withColumn("aqi_risk_level",
    when(col("aqi") <= 50, "Good")
    .when(col("aqi") <= 100, "Satisfactory")
    .when(col("aqi") <= 200, "Moderate")
    .when(col("aqi") <= 300, "Poor")
    .when(col("aqi") <= 400, "Very Poor")
    .otherwise("Severe")
)

# Numeric category for ML
silver_df = silver_df.withColumn("aqi_numeric_category",
    when(col("aqi") <= 50, 1)
    .when(col("aqi") <= 100, 2)
    .when(col("aqi") <= 200, 3)
    .when(col("aqi") <= 300, 4)
    .when(col("aqi") <= 400, 5)
    .otherwise(6)
)

# Binary risk flags
silver_df = silver_df.withColumn("is_high_pollution", col("aqi") > 200)
silver_df = silver_df.withColumn("is_severe_pollution", col("aqi") > 400)

print("Created risk levels: Good, Satisfactory, Moderate, Poor, Very Poor, Severe")
print("Created numeric categories 1-6")
print("Created binary flags: is_high_pollution, is_severe_pollution")

# Show distribution
silver_df.groupBy("aqi_risk_level").count().orderBy("count", ascending=False).show()

STEP 8: Creating AQI risk categorization
------------------------------------------------------------
Created risk levels: Good, Satisfactory, Moderate, Poor, Very Poor, Severe
Created numeric categories 1-6
Created binary flags: is_high_pollution, is_severe_pollution
+--------------+------+
|aqi_risk_level| count|
+--------------+------+
|  Satisfactory|104064|
|      Moderate| 95004|
|          Good| 47287|
|          Poor| 34378|
|     Very Poor| 13962|
|        Severe|  2514|
+--------------+------+



## Step 9: Handle Missing Values

In [0]:
print("STEP 9: Handling missing values")
print("-" * 60)

# Fill missing pollutant with 'Unknown'
silver_df = silver_df.withColumn("prominent_pollutant",
    when(col("prominent_pollutant").isNull(), "Unknown")
    .otherwise(trim(col("prominent_pollutant")))
)

# Fill missing air_quality with derived risk level
silver_df = silver_df.withColumn("air_quality",
    when(col("air_quality").isNull(), col("aqi_risk_level"))
    .otherwise(trim(col("air_quality")))
)

# Verify no NULLs in critical columns
print("NULL counts in critical columns:")
silver_df.select(
    count(when(col("date_parsed").isNull(), 1)).alias("date_nulls"),
    count(when(col("city_clean").isNull(), 1)).alias("city_nulls"),
    count(when(col("aqi").isNull(), 1)).alias("aqi_nulls"),
    count(when(col("prominent_pollutant").isNull(), 1)).alias("pollutant_nulls")
).show()

STEP 9: Handling missing values
------------------------------------------------------------
NULL counts in critical columns:
+----------+----------+---------+---------------+
|date_nulls|city_nulls|aqi_nulls|pollutant_nulls|
+----------+----------+---------+---------------+
|         0|         0|        0|              0|
+----------+----------+---------+---------------+



## Step 10: Create Final Silver Schema

In [0]:
print("STEP 10: Creating final Silver layer schema")
print("-" * 60)

silver_final = silver_df.select(
    # Core fields
    col("date_parsed").alias("date"),
    col("city_clean").alias("city"),
    col("aqi"),
    col("station_count"),
    
    # Categories
    col("air_quality"),
    col("aqi_risk_level"),
    col("aqi_numeric_category").alias("aqi_category"),
    col("prominent_pollutant"),
    
    # Temporal features
    col("year"),
    col("month"),
    col("day_of_week"),
    col("day_of_month"),
    col("quarter"),
    col("week_of_year"),
    col("season"),
    col("is_weekend"),
    
    # Risk flags
    col("is_high_pollution"),
    col("is_severe_pollution"),
    
    # Metadata
    current_timestamp().alias("processed_timestamp"),
    lit("silver_v1").alias("data_quality_flag")
)

print(f"Final Silver records: {silver_final.count():,}")
print(f"Final Silver columns: {len(silver_final.columns)}")
print("\nSchema:")
silver_final.printSchema()

STEP 10: Creating final Silver layer schema
------------------------------------------------------------
Final Silver records: 297,209
Final Silver columns: 20

Schema:
root
 |-- date: date (nullable = true)
 |-- city: string (nullable = true)
 |-- aqi: double (nullable = true)
 |-- station_count: integer (nullable = true)
 |-- air_quality: string (nullable = true)
 |-- aqi_risk_level: string (nullable = false)
 |-- aqi_category: integer (nullable = false)
 |-- prominent_pollutant: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- day_of_month: integer (nullable = true)
 |-- quarter: integer (nullable = true)
 |-- week_of_year: integer (nullable = true)
 |-- season: string (nullable = false)
 |-- is_weekend: boolean (nullable = false)
 |-- is_high_pollution: boolean (nullable = true)
 |-- is_severe_pollution: boolean (nullable = true)
 |-- processed_timestamp: timestamp (nullable = false)

In [0]:
# Verify city column has values
print("Sample data from Silver layer:")
silver_final.select("date", "city", "aqi", "aqi_risk_level", "season").show(10)

Sample data from Silver layer:
+----------+---------+-----+--------------+-------+
|      date|     city|  aqi|aqi_risk_level| season|
+----------+---------+-----+--------------+-------+
|2015-05-07|    Delhi|254.0|          Poor| Spring|
|2015-07-06|Ahmedabad| 89.0|  Satisfactory|Monsoon|
|2015-10-01| Varanasi|164.0|      Moderate| Autumn|
|2015-11-13|   Mumbai|113.0|      Moderate| Autumn|
|2016-01-27|    Delhi|379.0|     Very Poor| Winter|
|2016-01-28|  Lucknow|374.0|     Very Poor| Winter|
|2016-03-16|     Gaya| 52.0|  Satisfactory| Spring|
|2016-07-05|  Lucknow| 89.0|  Satisfactory|Monsoon|
|2016-07-17|Panchkula| 39.0|          Good|Monsoon|
|2016-08-17|   Mumbai| 55.0|  Satisfactory|Monsoon|
+----------+---------+-----+--------------+-------+
only showing top 10 rows


## Save Silver Table

In [0]:
print("Saving Silver layer to Delta Lake...")
print("-" * 60)

# Write to Silver table with partitioning
silver_final.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .partitionBy("year", "month") \
    .saveAsTable(T_SILVER)

print("Silver table created: aqi_india.silver.aqi_cleaned")

Saving Silver layer to Delta Lake...
------------------------------------------------------------
Silver table created: aqi_india.silver.aqi_cleaned


In [0]:
# Optimize Silver table
print("Optimizing Silver table...")
try:
    spark.sql(f"OPTIMIZE {T_SILVER} ZORDER BY (city, date)")
except Exception as e:
    print("OPTIMIZE skipped:", e)
print("Table optimized with ZORDER (city, date)")

Optimizing Silver table...
Table optimized with ZORDER (city, date)


---
# 3. Feature Engineering - Gold Layer Creation

Creating ML-ready features:
1. Lag features (historical AQI)
2. Rolling statistics
3. Rate of change
4. City baselines
5. Cyclical encoding
6. Target variables

In [0]:
print("=" * 80)
print("FEATURE ENGINEERING - GOLD LAYER CREATION")
print("=" * 80)

# Load Silver data
df_silver = spark.table(T_SILVER)

print(f"\nStarting with: {df_silver.count():,} records from Silver layer")
print("Creating ML features...\n")

FEATURE ENGINEERING - GOLD LAYER CREATION

Starting with: 297,209 records from Silver layer
Creating ML features...



## 3.1 Lag Features

In [0]:
print("FEATURE SET 1: Lag features (historical AQI)")
print("-" * 60)

# Window partitioned by city, ordered by date
window_spec = Window.partitionBy("city").orderBy("date")

# Create lag features
gold_df = df_silver \
    .withColumn("aqi_lag_1", lag("aqi", 1).over(window_spec)) \
    .withColumn("aqi_lag_3", lag("aqi", 3).over(window_spec)) \
    .withColumn("aqi_lag_7", lag("aqi", 7).over(window_spec)) \
    .withColumn("aqi_lag_14", lag("aqi", 14).over(window_spec)) \
    .withColumn("aqi_lag_30", lag("aqi", 30).over(window_spec))

print("Created: aqi_lag_1, aqi_lag_3, aqi_lag_7, aqi_lag_14, aqi_lag_30")

FEATURE SET 1: Lag features (historical AQI)
------------------------------------------------------------
Created: aqi_lag_1, aqi_lag_3, aqi_lag_7, aqi_lag_14, aqi_lag_30


## 3.2 Rolling Statistics

In [0]:
print("FEATURE SET 2: Rolling window statistics")
print("-" * 60)

# Rolling windows
rolling_7 = Window.partitionBy("city").orderBy("date").rowsBetween(-6, 0)
rolling_14 = Window.partitionBy("city").orderBy("date").rowsBetween(-13, 0)
rolling_30 = Window.partitionBy("city").orderBy("date").rowsBetween(-29, 0)

# Rolling averages
gold_df = gold_df \
    .withColumn("aqi_rolling_avg_7", avg("aqi").over(rolling_7)) \
    .withColumn("aqi_rolling_avg_14", avg("aqi").over(rolling_14)) \
    .withColumn("aqi_rolling_avg_30", avg("aqi").over(rolling_30))

# Rolling std dev (volatility)
gold_df = gold_df \
    .withColumn("aqi_rolling_std_7", stddev("aqi").over(rolling_7)) \
    .withColumn("aqi_rolling_std_14", stddev("aqi").over(rolling_14))

# Rolling min/max
gold_df = gold_df \
    .withColumn("aqi_rolling_max_7", max("aqi").over(rolling_7)) \
    .withColumn("aqi_rolling_min_7", min("aqi").over(rolling_7))

print("Created rolling averages: 7, 14, 30 days")
print("Created rolling std dev: 7, 14 days")
print("Created rolling min/max: 7 days")

FEATURE SET 2: Rolling window statistics
------------------------------------------------------------
Created rolling averages: 7, 14, 30 days
Created rolling std dev: 7, 14 days
Created rolling min/max: 7 days


## 3.3 Rate of Change Features

In [0]:
print("FEATURE SET 3: Rate of change features")
print("-" * 60)

# Absolute changes
gold_df = gold_df \
    .withColumn("aqi_change_1d",
        when(col("aqi_lag_1").isNotNull(),
             col("aqi") - col("aqi_lag_1")).otherwise(None)
    ) \
    .withColumn("aqi_change_7d",
        when(col("aqi_lag_7").isNotNull(),
             col("aqi") - col("aqi_lag_7")).otherwise(None)
    )

# Percentage changes
gold_df = gold_df \
    .withColumn("aqi_pct_change_1d",
        when((col("aqi_lag_1").isNotNull()) & (col("aqi_lag_1") != 0),
             ((col("aqi") - col("aqi_lag_1")) / col("aqi_lag_1")) * 100).otherwise(None)
    ) \
    .withColumn("aqi_pct_change_7d",
        when((col("aqi_lag_7").isNotNull()) & (col("aqi_lag_7") != 0),
             ((col("aqi") - col("aqi_lag_7")) / col("aqi_lag_7")) * 100).otherwise(None)
    )

print("Created: aqi_change_1d, aqi_change_7d")
print("Created: aqi_pct_change_1d, aqi_pct_change_7d")

FEATURE SET 3: Rate of change features
------------------------------------------------------------
Created: aqi_change_1d, aqi_change_7d
Created: aqi_pct_change_1d, aqi_pct_change_7d


## 3.4 City-Level Baseline Features

In [0]:
print("FEATURE SET 4: City-level baseline statistics")
print("-" * 60)

# Calculate city-wide statistics
city_stats = df_silver.groupBy("city") \
    .agg(
        avg("aqi").alias("city_avg_aqi"),
        stddev("aqi").alias("city_std_aqi"),
        min("aqi").alias("city_min_aqi"),
        max("aqi").alias("city_max_aqi"),
        count("*").alias("city_record_count")
    )

# Join back to main dataset
gold_df = gold_df.join(city_stats, on="city", how="left")

# Calculate deviation from baseline
gold_df = gold_df \
    .withColumn("aqi_deviation_from_city_avg",
        col("aqi") - col("city_avg_aqi")
    ) \
    .withColumn("aqi_z_score",
        when(col("city_std_aqi") != 0,
             (col("aqi") - col("city_avg_aqi")) / col("city_std_aqi")
        ).otherwise(0)
    ) \
    .withColumn("aqi_percentile_in_city",
        when(col("city_max_aqi") != col("city_min_aqi"),
             (col("aqi") - col("city_min_aqi")) / (col("city_max_aqi") - col("city_min_aqi")) * 100
        ).otherwise(50)
    )

print("Created: city_avg_aqi, city_std_aqi")
print("Created: aqi_deviation_from_city_avg, aqi_z_score, aqi_percentile_in_city")

FEATURE SET 4: City-level baseline statistics
------------------------------------------------------------
Created: city_avg_aqi, city_std_aqi
Created: aqi_deviation_from_city_avg, aqi_z_score, aqi_percentile_in_city


## 3.5 Cyclical Temporal Encoding

In [0]:
print("FEATURE SET 5: Cyclical temporal encodings")
print("-" * 60)

import math
PI = 3.14159265359

# Month cyclical (handles Dec->Jan transition)
gold_df = gold_df \
    .withColumn("month_sin", sin(col("month") * 2 * PI / 12)) \
    .withColumn("month_cos", cos(col("month") * 2 * PI / 12))

# Day of week cyclical
gold_df = gold_df \
    .withColumn("day_of_week_sin", sin(col("day_of_week") * 2 * PI / 7)) \
    .withColumn("day_of_week_cos", cos(col("day_of_week") * 2 * PI / 7))

# Day of month cyclical
gold_df = gold_df \
    .withColumn("day_of_month_sin", sin(col("day_of_month") * 2 * PI / 31)) \
    .withColumn("day_of_month_cos", cos(col("day_of_month") * 2 * PI / 31))

print("Created: month_sin, month_cos")
print("Created: day_of_week_sin, day_of_week_cos")
print("Created: day_of_month_sin, day_of_month_cos")

FEATURE SET 5: Cyclical temporal encodings
------------------------------------------------------------
Created: month_sin, month_cos
Created: day_of_week_sin, day_of_week_cos
Created: day_of_month_sin, day_of_month_cos


## 3.6 Target Variables for Forecasting

In [0]:
print("FEATURE SET 6: Target variables for prediction")
print("-" * 60)

# Future AQI values (prediction targets)
gold_df = gold_df \
    .withColumn("aqi_next_1d", lead("aqi", 1).over(window_spec)) \
    .withColumn("aqi_next_3d", lead("aqi", 3).over(window_spec)) \
    .withColumn("aqi_next_7d", lead("aqi", 7).over(window_spec))

# Binary classification targets
gold_df = gold_df \
    .withColumn("target_high_aqi_tomorrow",
        when(col("aqi_next_1d") > 200, 1).otherwise(0)
    ) \
    .withColumn("target_severe_aqi_tomorrow",
        when(col("aqi_next_1d") > 400, 1).otherwise(0)
    )

print("Created targets: aqi_next_1d, aqi_next_3d, aqi_next_7d")
print("Created binary: target_high_aqi_tomorrow, target_severe_aqi_tomorrow")

FEATURE SET 6: Target variables for prediction
------------------------------------------------------------
Created targets: aqi_next_1d, aqi_next_3d, aqi_next_7d
Created binary: target_high_aqi_tomorrow, target_severe_aqi_tomorrow


## 3.7 Interaction Features

In [0]:
print("FEATURE SET 7: Interaction features")
print("-" * 60)

# Winter high pollution cities flag
gold_df = gold_df \
    .withColumn("is_winter_high_pollution_city",
        when((col("season") == "Winter") & 
             col("city").isin(["Delhi", "Ghaziabad", "Noida", "Gurugram", "Faridabad"]), 1).otherwise(0)
    )

# Weekend effect
gold_df = gold_df \
    .withColumn("weekend_pollution_delta",
        when(col("is_weekend") == True,
             col("aqi") - col("city_avg_aqi")).otherwise(None)
    )

# Station reliability indicator
gold_df = gold_df \
    .withColumn("is_reliable_measurement",
        when(col("station_count") >= 3, True).otherwise(False)
    )

print("Created: is_winter_high_pollution_city")
print("Created: weekend_pollution_delta")
print("Created: is_reliable_measurement")

FEATURE SET 7: Interaction features
------------------------------------------------------------
Created: is_winter_high_pollution_city
Created: weekend_pollution_delta
Created: is_reliable_measurement


## 3.8 Final Gold Schema

In [0]:
print("FINAL: Creating Gold layer schema")
print("-" * 60)

# Select final feature set
gold_final = gold_df.select(
    # Identifiers
    "date", "city", "year", "month", "day_of_week", "quarter", "season",

    # Current values
    "aqi", "station_count", "aqi_category", "aqi_risk_level",
    "prominent_pollutant", "is_weekend",

    # Risk flags (from Silver layer - needed for aggregations)
    "is_high_pollution",
    "is_severe_pollution",

    # Lag features (5)
    "aqi_lag_1", "aqi_lag_3", "aqi_lag_7", "aqi_lag_14", "aqi_lag_30",

    # Rolling statistics (7)
    "aqi_rolling_avg_7", "aqi_rolling_avg_14", "aqi_rolling_avg_30",
    "aqi_rolling_std_7", "aqi_rolling_std_14",
    "aqi_rolling_max_7", "aqi_rolling_min_7",

    # Change features (4)
    "aqi_change_1d", "aqi_change_7d",
    "aqi_pct_change_1d", "aqi_pct_change_7d",

    # City baselines (5)
    "city_avg_aqi", "city_std_aqi",
    "aqi_deviation_from_city_avg", "aqi_z_score", "aqi_percentile_in_city",

    # Cyclical encodings (6)
    "month_sin", "month_cos",
    "day_of_week_sin", "day_of_week_cos",
    "day_of_month_sin", "day_of_month_cos",

    # Interaction features (3)
    "is_winter_high_pollution_city",
    "weekend_pollution_delta",
    "is_reliable_measurement",

    # Target variables (5)
    "aqi_next_1d", "aqi_next_3d", "aqi_next_7d",
    "target_high_aqi_tomorrow", "target_severe_aqi_tomorrow",

    # Metadata
    current_timestamp().alias("feature_timestamp"),
    lit("gold_v1").alias("feature_version")
)

print(f"Gold layer records: {gold_final.count():,}")
print(f"Gold layer features: {len(gold_final.columns)}")
print("\nSchema:")
gold_final.printSchema()


FINAL: Creating Gold layer schema
------------------------------------------------------------
Gold layer records: 297,209
Gold layer features: 52

Schema:
root
 |-- date: date (nullable = true)
 |-- city: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- quarter: integer (nullable = true)
 |-- season: string (nullable = true)
 |-- aqi: double (nullable = true)
 |-- station_count: integer (nullable = true)
 |-- aqi_category: integer (nullable = true)
 |-- aqi_risk_level: string (nullable = true)
 |-- prominent_pollutant: string (nullable = true)
 |-- is_weekend: boolean (nullable = true)
 |-- is_high_pollution: boolean (nullable = true)
 |-- is_severe_pollution: boolean (nullable = true)
 |-- aqi_lag_1: double (nullable = true)
 |-- aqi_lag_3: double (nullable = true)
 |-- aqi_lag_7: double (nullable = true)
 |-- aqi_lag_14: double (nullable = true)
 |-- aqi_lag_30: double (nullable = tru

In [0]:
# Verify data quality
print("Sample Gold data:")
gold_final.select("date", "city", "aqi", "aqi_lag_1", "aqi_rolling_avg_7", "aqi_next_1d").show(10)

Sample Gold data:
+----------+----+-----+---------+------------------+-----------+
|      date|city|  aqi|aqi_lag_1| aqi_rolling_avg_7|aqi_next_1d|
+----------+----+-----+---------+------------------+-----------+
|2015-05-01|Agra|179.0|     NULL|             179.0|      135.0|
|2015-05-02|Agra|135.0|    179.0|             157.0|       84.0|
|2015-05-03|Agra| 84.0|    135.0|132.66666666666666|      104.0|
|2015-05-04|Agra|104.0|     84.0|             125.5|       88.0|
|2015-05-05|Agra| 88.0|    104.0|             118.0|       91.0|
|2015-05-06|Agra| 91.0|     88.0|             113.5|       65.0|
|2015-05-07|Agra| 65.0|     91.0|106.57142857142857|       62.0|
|2015-05-08|Agra| 62.0|     65.0| 89.85714285714286|      294.0|
|2015-05-09|Agra|294.0|     62.0|112.57142857142857|       56.0|
|2015-05-10|Agra| 56.0|    294.0|108.57142857142857|       78.0|
+----------+----+-----+---------+------------------+-----------+
only showing top 10 rows


## Save Gold Table

In [0]:
print("Saving Gold layer to Delta Lake...")
print("-" * 60)

# Filter to only records with valid lag features (need at least 30 days history)
gold_ml_ready = gold_final.filter(col("aqi_lag_30").isNotNull())

print(f"ML-ready records (with 30+ days history): {gold_ml_ready.count():,}")

# Save Gold ML features table
gold_ml_ready.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .partitionBy("year") \
    .saveAsTable(T_GOLD)

print("Gold table created: aqi_india.gold.aqi_ml_features")

Saving Gold layer to Delta Lake...
------------------------------------------------------------
ML-ready records (with 30+ days history): 289,034
Gold table created: aqi_india.gold.aqi_ml_features


In [0]:
# Optimize Gold table
print("Optimizing Gold table...")
try:
    spark.sql(f"OPTIMIZE {T_GOLD} ZORDER BY (city, date)")
except Exception as e:
    print("OPTIMIZE skipped:", e)
print("Gold table optimized!")

Optimizing Gold table...
Gold table optimized!


## Create Additional Gold Aggregates

In [0]:
print("Creating city summary table...")

city_summary = gold_final.groupBy("city") \
    .agg(
        count("*").alias("total_records"),
        min("date").alias("first_date"),
        max("date").alias("last_date"),
        avg("aqi").alias("avg_aqi"),
        stddev("aqi").alias("std_aqi"),
        min("aqi").alias("min_aqi"),
        max("aqi").alias("max_aqi"),
        avg("station_count").alias("avg_stations"),
        sum(when(col("is_high_pollution") == True, 1).otherwise(0)).alias("high_pollution_days"),
        sum(when(col("is_severe_pollution") == True, 1).otherwise(0)).alias("severe_pollution_days")
    ) \
    .withColumn("high_pollution_pct", col("high_pollution_days") / col("total_records") * 100) \
    .orderBy("avg_aqi", ascending=False)

city_summary.write \
    .format("delta") \
    .mode("overwrite") \
    .saveAsTable(f"{CATALOG}.gold.city_summary")

print("City summary table created: aqi_india.gold.city_summary")
city_summary.show(10)

Creating city summary table...
City summary table created: aqi_india.gold.city_summary
+-----------+-------------+----------+----------+------------------+------------------+-------+-------+------------------+-------------------+---------------------+------------------+
|       city|total_records|first_date| last_date|           avg_aqi|           std_aqi|min_aqi|max_aqi|      avg_stations|high_pollution_days|severe_pollution_days|high_pollution_pct|
+-----------+-------------+----------+----------+------------------+------------------+-------+-------+------------------+-------------------+---------------------+------------------+
| Jharsuguda|            1|2018-02-04|2018-02-04|             282.0|              NULL|  282.0|  282.0|               1.0|                  1|                    0|             100.0|
|   Byrnihat|          281|2022-12-24|2023-12-31|248.67259786476868| 93.40514437888999|   49.0|  442.0|               1.0|                191|                   11| 67.971530249

In [0]:
print("Creating monthly trends table...")

monthly_trends = gold_final.groupBy("year", "month") \
    .agg(
        count("*").alias("total_records"),
        countDistinct("city").alias("cities_reporting"),
        avg("aqi").alias("avg_aqi"),
        stddev("aqi").alias("std_aqi"),
        min("aqi").alias("min_aqi"),
        max("aqi").alias("max_aqi"),
        sum(when(col("is_high_pollution") == True, 1).otherwise(0)).alias("high_pollution_count"),
        sum(when(col("is_severe_pollution") == True, 1).otherwise(0)).alias("severe_pollution_count")
    ) \
    .orderBy("year", "month")

monthly_trends.write \
    .format("delta") \
    .mode("overwrite") \
    .saveAsTable(f"{CATALOG}.gold.monthly_trends")

print("Monthly trends table created: aqi_india.gold.monthly_trends")
monthly_trends.show(12)

Creating monthly trends table...
Monthly trends table created: aqi_india.gold.monthly_trends
+----+-----+-------------+----------------+------------------+------------------+-------+-------+--------------------+----------------------+
|year|month|total_records|cities_reporting|           avg_aqi|           std_aqi|min_aqi|max_aqi|high_pollution_count|severe_pollution_count|
+----+-----+-------------+----------------+------------------+------------------+-------+-------+--------------------+----------------------+
|2015|    5|          316|              11|146.80379746835442| 80.65131367951713|   42.0|  482.0|                  78|                     3|
|2015|    6|          286|              11|113.73776223776224|  67.0569627497433|   28.0|  356.0|                  32|                     0|
|2015|    7|          332|              11| 86.96385542168674| 50.50564762369056|   15.0|  351.0|                  11|                     0|
|2015|    8|          331|              13| 86.10574018

---
# 4. Final Summary

In [0]:
print("=" * 80)
print("DATA PIPELINE SUMMARY")
print("=" * 80)

bronze_count = spark.table(T_BRONZE).count()
silver_count = spark.table(T_SILVER).count()
gold_count = spark.table(T_GOLD).count()

print(f"\nLAYER SUMMARY:")
print(f"  Bronze (Raw):      {bronze_count:,} records")
print(f"  Silver (Cleaned):  {silver_count:,} records")
print(f"  Gold (ML-Ready):   {gold_count:,} records")

print(f"\nTABLES CREATED:")
print(f"  1. aqi_india.bronze.aqi_bulletins")
print(f"  2. aqi_india.silver.aqi_cleaned")
print(f"  3. aqi_india.gold.aqi_ml_features")
print(f"  4. aqi_india.gold.city_summary")
print(f"  5. aqi_india.gold.monthly_trends")

print(f"\nFEATURE GROUPS IN GOLD LAYER:")
print(f"  - Lag features: 5 (1d, 3d, 7d, 14d, 30d)")
print(f"  - Rolling stats: 7 (avg, std, min, max)")
print(f"  - Change features: 4 (absolute & percentage)")
print(f"  - City baselines: 5 (deviation, z-score, percentile)")
print(f"  - Cyclical encoding: 6 (sin/cos for time)")
print(f"  - Interaction features: 3")
print(f"  - Target variables: 5 (1d, 3d, 7d + binary)")

print(f"\nREADY FOR ML MODELING!")
print("=" * 80)

DATA PIPELINE SUMMARY

LAYER SUMMARY:
  Bronze (Raw):      299,976 records
  Silver (Cleaned):  297,209 records
  Gold (ML-Ready):   289,034 records

TABLES CREATED:
  1. aqi_india.bronze.aqi_bulletins
  2. aqi_india.silver.aqi_cleaned
  3. aqi_india.gold.aqi_ml_features
  4. aqi_india.gold.city_summary
  5. aqi_india.gold.monthly_trends

FEATURE GROUPS IN GOLD LAYER:
  - Lag features: 5 (1d, 3d, 7d, 14d, 30d)
  - Rolling stats: 7 (avg, std, min, max)
  - Change features: 4 (absolute & percentage)
  - City baselines: 5 (deviation, z-score, percentile)
  - Cyclical encoding: 6 (sin/cos for time)
  - Interaction features: 3
  - Target variables: 5 (1d, 3d, 7d + binary)

READY FOR ML MODELING!


In [0]:
# Final verification
print("Final Gold table schema:")
spark.table(T_GOLD).printSchema()

Final Gold table schema:
root
 |-- date: date (nullable = true)
 |-- city: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- quarter: integer (nullable = true)
 |-- season: string (nullable = true)
 |-- aqi: double (nullable = true)
 |-- station_count: integer (nullable = true)
 |-- aqi_category: integer (nullable = true)
 |-- aqi_risk_level: string (nullable = true)
 |-- prominent_pollutant: string (nullable = true)
 |-- is_weekend: boolean (nullable = true)
 |-- is_high_pollution: boolean (nullable = true)
 |-- is_severe_pollution: boolean (nullable = true)
 |-- aqi_lag_1: double (nullable = true)
 |-- aqi_lag_3: double (nullable = true)
 |-- aqi_lag_7: double (nullable = true)
 |-- aqi_lag_14: double (nullable = true)
 |-- aqi_lag_30: double (nullable = true)
 |-- aqi_rolling_avg_7: double (nullable = true)
 |-- aqi_rolling_avg_14: double (nullable = true)
 |-- aqi_rolling_avg_30: doub

In [0]:
# Sample ML-ready data
print("Sample ML-ready data:")
spark.table(T_GOLD) \
    .select("date", "city", "aqi", "aqi_lag_1", "aqi_rolling_avg_7", "aqi_z_score", "aqi_next_1d", "target_high_aqi_tomorrow") \
    .filter(col("city") == "Delhi") \
    .orderBy("date", ascending=False) \
    .show(10)

Sample ML-ready data:
+----------+-----+-----+---------+------------------+------------------+-----------+------------------------+
|      date| city|  aqi|aqi_lag_1| aqi_rolling_avg_7|       aqi_z_score|aqi_next_1d|target_high_aqi_tomorrow|
+----------+-----+-----+---------+------------------+------------------+-----------+------------------------+
|2023-12-31|Delhi|382.0|    400.0| 380.2857142857143|  1.58505810692669|       NULL|                       0|
|2023-12-30|Delhi|400.0|    382.0|384.42857142857144|1.7581160289544313|      382.0|                       1|
|2023-12-29|Delhi|382.0|    358.0|391.57142857142856|  1.58505810692669|      400.0|                       1|
|2023-12-28|Delhi|358.0|    380.0|395.42857142857144| 1.354314210889702|      382.0|                       1|
|2023-12-27|Delhi|380.0|    377.0|395.85714285714283|1.5658294489236078|      358.0|                       1|
|2023-12-26|Delhi|377.0|    383.0| 382.2857142857143|1.5369864619189841|      380.0|              