In [1]:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql import functions as F
from pyspark.sql.functions import expr
from pyspark.sql import functions as F
from pyspark.ml.evaluation import RegressionEvaluator
from prophet import Prophet
import pandas as pd

In [2]:
spark = SparkSession.builder.appName('AirQualityAnalysisIndia').getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/11/07 11:04:17 WARN Utils: Your hostname, Karthikeya, resolves to a loopback address: 127.0.1.1; using 172.17.54.10 instead (on interface wlp1s0)
25/11/07 11:04:17 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/07 11:04:17 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
df = spark.read.parquet('hdfs://localhost:9000/processed/enriched_air_quality.parquet')

In [4]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from prophet import Prophet
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Define Pandas UDF with plotting
@pandas_udf("City string, Pollutant string, ds timestamp, yhat double, yhat_lower double, yhat_upper double", PandasUDFType.GROUPED_MAP)
def forecast_city_pollutant(pdf: pd.DataFrame) -> pd.DataFrame:
    city = pdf["City"].iloc[0]
    pollutant = pdf["Pollutant"].iloc[0]

    pdf = pdf.rename(columns={"Date": "ds", "Value": "y"})
    pdf["ds"] = pd.to_datetime(pdf["ds"])

    model = Prophet()
    model.fit(pdf[["ds", "y"]])

    future = model.make_future_dataframe(periods=365)
    forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]]

    fig = model.plot(forecast)
    plt.title(f"{city} - {pollutant}", fontsize=14)

    folder_path = f"Graphs/Prophet/{city}"
    os.makedirs(folder_path, exist_ok=True)
    plt.savefig(f"{folder_path}/{pollutant}.jpg", bbox_inches="tight")
    plt.close(fig)

    forecast["City"] = city
    forecast["Pollutant"] = pollutant
    return forecast[["City", "Pollutant", "ds", "yhat", "yhat_lower", "yhat_upper"]]


In [5]:
pollutants = ["PM25", "PM10", "NO", "NO2", "NOx", "NH3", "CO", "SO2", "O3", "Benzene", "Toluene"]

stack_expr = "stack({}, {}) as (Pollutant, Value)".format(
    len(pollutants),
    ", ".join([f"'{p}', {p}" for p in pollutants])
)

## Train, Test Split

In [6]:
df = df.withColumn("date", F.to_timestamp("date"))
df = df.withColumn("year", F.year("date"))

latest_year = df.agg(F.max("year")).collect()[0][0]

df_train = df.filter(F.col("year") < latest_year)
df_test = df.filter(F.col("year") == latest_year)

print("Train years:", df_train.select("year").distinct().orderBy("year").collect())
print("Test year:", df_test.select("year").distinct().collect())

df_long = df.selectExpr("City", "Date", stack_expr)
df_long_train = df_train.selectExpr("City", "Date", stack_expr)
df_long_test = df_test.selectExpr("City", "Date", stack_expr)

Train years: [Row(year=2015), Row(year=2016), Row(year=2017), Row(year=2018), Row(year=2019)]
Test year: [Row(year=2020)]


## Training

In [7]:
forecast_df_train = df_long_train.groupBy("City", "Pollutant").apply(forecast_city_pollutant)



## Results

In [8]:
df_results = df_long_test.withColumnRenamed("Date", "ds").join(forecast_df_train, on=["City", "Pollutant", "ds"], how="inner")

In [9]:
from pyspark.sql import functions as F

df_eval = df_results.withColumn("error", F.col("Value") - F.col("yhat"))

agg_df = (
    df_eval.groupBy("City", "Pollutant")
    .agg(
        F.mean(F.abs(F.col("error"))).alias("MAE"),
        F.mean(F.pow(F.col("error"), 2)).alias("MSE"),
        F.mean(F.col("Value")).alias("mean_actual")
    )
)

df_joined = df_eval.join(agg_df, on=["City", "Pollutant"], how="inner")

r2_df = (
    df_joined.groupBy("City", "Pollutant", "MAE", "MSE", "mean_actual")
    .agg(
        F.sum(F.pow(F.col("Value") - F.col("yhat"), 2)).alias("ss_res"),
        F.sum(F.pow(F.col("Value") - F.col("mean_actual"), 2)).alias("ss_tot")
    )
    .withColumn("RMSE", F.sqrt(F.col("MSE")))
    .withColumn("R2", 1 - (F.col("ss_res") / F.col("ss_tot")))
    .select("City", "Pollutant", "MAE", "RMSE", "R2")
)

r2_df.show(truncate=False)


11:04:36 - cmdstanpy - INFO - Chain [1] start processing            (0 + 3) / 3]
11:04:36 - cmdstanpy - INFO - Chain [1] done processing
11:04:36 - cmdstanpy - INFO - Chain [1] start processing
11:04:36 - cmdstanpy - INFO - Chain [1] start processing
11:04:36 - cmdstanpy - INFO - Chain [1] start processing
11:04:36 - cmdstanpy - INFO - Chain [1] start processing
11:04:36 - cmdstanpy - INFO - Chain [1] done processing
11:04:37 - cmdstanpy - INFO - Chain [1] done processing
11:04:37 - cmdstanpy - INFO - Chain [1] done processing
11:04:37 - cmdstanpy - INFO - Chain [1] done processing
11:04:37 - cmdstanpy - INFO - Chain [1] start processing
11:04:37 - cmdstanpy - INFO - Chain [1] done processing
11:04:38 - cmdstanpy - INFO - Chain [1] start processing
11:04:38 - cmdstanpy - INFO - Chain [1] start processing
11:04:38 - cmdstanpy - INFO - Chain [1] start processing
11:04:38 - cmdstanpy - INFO - Chain [1] done processing
11:04:38 - cmdstanpy - INFO - Chain [1] done processing
11:04:38 - cmds

+------------+---------+------------------+------------------+---------------------+
|City        |Pollutant|MAE               |RMSE              |R2                   |
+------------+---------+------------------+------------------+---------------------+
|Lucknow     |PM10     |0.0               |0.0               |1.0                  |
|Patna       |O3       |30.26969780233083 |31.40394454702936 |-7.032434206655509   |
|Ahmedabad   |Benzene  |2.166435715426622 |3.0996701224231678|-0.3843911324629523  |
|Brajrajnagar|Benzene  |60.8719755228628  |61.90499275600006 |-82.50548462996738   |
|Talcher     |O3       |13.822347901458267|17.830295632128223|0.2248649211222975   |
|Bengaluru   |PM25     |7.793336942823292 |9.73220624504468  |0.36634392596409937  |
|Hyderabad   |SO2      |3.29054429660866  |4.235782446049504 |-1.6781237313372075  |
|Shillong    |SO2      |7.674860575146957 |8.590273187111125 |-14.636059420399608  |
|Gurugram    |NH3      |10.333387978142074|15.732836341367769|-0.

In [10]:
pairs = [
    ("Delhi", "PM25"),
    ("Mumbai", "PM10"),
    ("Chennai", "NO2"),
    ('Bengaluru', 'O2')
]

condition = None
for city, pollutant in pairs:
    expr = (F.col("City") == city) & (F.col("Pollutant") == pollutant)
    condition = expr if condition is None else (condition | expr)

filtered_df = r2_df.filter(condition)
filtered_df.show(truncate=False)

11:06:01 - cmdstanpy - INFO - Chain [1] start processing
11:06:01 - cmdstanpy - INFO - Chain [1] start processing
11:06:01 - cmdstanpy - INFO - Chain [1] start processing
11:06:01 - cmdstanpy - INFO - Chain [1] start processing
11:06:01 - cmdstanpy - INFO - Chain [1] start processing
11:06:01 - cmdstanpy - INFO - Chain [1] start processing
11:06:02 - cmdstanpy - INFO - Chain [1] done processing
11:06:02 - cmdstanpy - INFO - Chain [1] done processing             (0 + 3) / 3]
11:06:02 - cmdstanpy - INFO - Chain [1] done processing
11:06:02 - cmdstanpy - INFO - Chain [1] done processing
11:06:02 - cmdstanpy - INFO - Chain [1] done processing
11:06:02 - cmdstanpy - INFO - Chain [1] done processing
11:06:02 - cmdstanpy - INFO - Chain [1] start processing
11:06:03 - cmdstanpy - INFO - Chain [1] done processing
11:06:03 - cmdstanpy - INFO - Chain [1] start processing
11:06:03 - cmdstanpy - INFO - Chain [1] start processing
11:06:03 - cmdstanpy - INFO - Chain [1] start processing
11:06:03 - cm

+-------+---------+-----------------+-----------------+--------------------+
|City   |Pollutant|MAE              |RMSE             |R2                  |
+-------+---------+-----------------+-----------------+--------------------+
|Mumbai |PM10     |39.95737052618425|55.16396239123597|-0.02349660934513076|
|Delhi  |PM25     |33.13171882674591|41.75688187561895|0.43894843331204236 |
|Chennai|NO2      |5.420187690296768|6.839270443068862|-1.5356558653558592 |
+-------+---------+-----------------+-----------------+--------------------+



                                                                                

In [11]:
avg_rmse = r2_df.agg(F.mean("RMSE").alias("avg_RMSE"))
avg_rmse.show()

11:07:33 - cmdstanpy - INFO - Chain [1] start processing
11:07:33 - cmdstanpy - INFO - Chain [1] start processing
11:07:33 - cmdstanpy - INFO - Chain [1] start processing
11:07:33 - cmdstanpy - INFO - Chain [1] start processing
11:07:33 - cmdstanpy - INFO - Chain [1] start processing
11:07:33 - cmdstanpy - INFO - Chain [1] start processing
11:07:33 - cmdstanpy - INFO - Chain [1] done processing
11:07:33 - cmdstanpy - INFO - Chain [1] done processing
11:07:33 - cmdstanpy - INFO - Chain [1] done processing
11:07:33 - cmdstanpy - INFO - Chain [1] done processing             (0 + 3) / 3]
11:07:33 - cmdstanpy - INFO - Chain [1] done processing
11:07:33 - cmdstanpy - INFO - Chain [1] done processing
11:07:34 - cmdstanpy - INFO - Chain [1] start processing
11:07:34 - cmdstanpy - INFO - Chain [1] start processing
11:07:34 - cmdstanpy - INFO - Chain [1] done processing
11:07:34 - cmdstanpy - INFO - Chain [1] done processing
11:07:34 - cmdstanpy - INFO - Chain [1] start processing
11:07:34 - cmd

+-----------------+
|         avg_RMSE|
+-----------------+
|25.26477026737655|
+-----------------+



                                                                                

## Prediction for 1 Year (2021-2022)

In [12]:
forecast_df = df_long.groupBy("City", "Pollutant").apply(forecast_city_pollutant)

In [13]:
forecast_df = forecast_df.withColumnRenamed("ds", "Date").withColumnRenamed("yhat", "Value")
wide_df = (forecast_df.groupBy("City", "Date").pivot("Pollutant").agg(F.first("Value")).orderBy("City", "Date"))

11:09:07 - cmdstanpy - INFO - Chain [1] start processing
11:09:07 - cmdstanpy - INFO - Chain [1] start processing
11:09:08 - cmdstanpy - INFO - Chain [1] start processing
11:09:08 - cmdstanpy - INFO - Chain [1] start processing
11:09:08 - cmdstanpy - INFO - Chain [1] done processing             (0 + 4) / 4]
11:09:08 - cmdstanpy - INFO - Chain [1] done processing
11:09:08 - cmdstanpy - INFO - Chain [1] done processing
11:09:08 - cmdstanpy - INFO - Chain [1] done processing
11:09:09 - cmdstanpy - INFO - Chain [1] start processing
11:09:09 - cmdstanpy - INFO - Chain [1] start processing
11:09:09 - cmdstanpy - INFO - Chain [1] start processing
11:09:09 - cmdstanpy - INFO - Chain [1] done processing
11:09:09 - cmdstanpy - INFO - Chain [1] done processing
11:09:09 - cmdstanpy - INFO - Chain [1] done processing
11:09:09 - cmdstanpy - INFO - Chain [1] start processing
11:09:09 - cmdstanpy - INFO - Chain [1] start processing
11:09:09 - cmdstanpy - INFO - Chain [1] done processing
11:09:09 - cmd

In [15]:
wide_df.write.mode("overwrite").parquet("hdfs://localhost:9000/output/forecast_results/")

11:10:26 - cmdstanpy - INFO - Chain [1] start processing
11:10:26 - cmdstanpy - INFO - Chain [1] start processing
11:10:26 - cmdstanpy - INFO - Chain [1] start processing
11:10:26 - cmdstanpy - INFO - Chain [1] start processing
11:10:26 - cmdstanpy - INFO - Chain [1] done processing
11:10:26 - cmdstanpy - INFO - Chain [1] done processing             (0 + 4) / 4]
11:10:26 - cmdstanpy - INFO - Chain [1] done processing
11:10:26 - cmdstanpy - INFO - Chain [1] done processing
11:10:27 - cmdstanpy - INFO - Chain [1] start processing
11:10:27 - cmdstanpy - INFO - Chain [1] start processing
11:10:27 - cmdstanpy - INFO - Chain [1] done processing
11:10:27 - cmdstanpy - INFO - Chain [1] start processing
11:10:27 - cmdstanpy - INFO - Chain [1] done processing
11:10:27 - cmdstanpy - INFO - Chain [1] start processing
11:10:27 - cmdstanpy - INFO - Chain [1] done processing
11:10:28 - cmdstanpy - INFO - Chain [1] start processing
11:10:28 - cmdstanpy - INFO - Chain [1] done processing
11:10:28 - cmd

In [16]:
spark.stop()