In [None]:
"""
This notebook assesses the quality of the Sentinel-2 output by estimating the 
NDVI (Normalized Difference Vegetation Index) using the red and NIR bands, and 
verifying the consistency and completeness of the time series data.
"""

import sys
import os

# Set up imports: put the parent directory on sys.path so the
# `preprocessing` package can be imported below.
# NOTE(review): ".." is resolved against the current working directory —
# assumes the kernel was started in this notebook's folder; confirm.
project_root = os.path.abspath("..")
# Re-running this cell must not pile up duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

from preprocessing.spark_session import spark  # Reuse the preconfigured SparkSession

In [None]:
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
import pyspark.sql.functions as F
import pandas as pd
import plotly.graph_objects as go

# Define NDVI function
def ndvi_calc(b04, b08):
    """Compute the Normalized Difference Vegetation Index for one observation.

    NDVI = (NIR - red) / (NIR + red).

    Args:
        b04: Red band value, or None when the observation is missing.
        b08: Near-infrared (NIR) band value, or None when missing.

    Returns:
        The NDVI as a float, or None when either band is missing or when
        the two bands sum to zero (the index is undefined in both cases).
    """
    # Spark hands SQL NULLs to Python UDFs as None; without this guard the
    # arithmetic below raises TypeError and fails the whole job.
    if b04 is None or b08 is None:
        return None

    band_sum = b08 + b04
    if band_sum == 0:
        # Avoid ZeroDivisionError; NDVI is undefined for a zero denominator.
        return None
    return float((b08 - b04) / band_sum)

# Wrap ndvi_calc as a Spark UDF whose result column is FloatType
ndvi_udf = udf(f=ndvi_calc, returnType=FloatType())


# Helper function to show NDVI for a given location and year
def show_ndvi_for_point(df_s2_sampled, lon, lat, cdl, year, ndvi_udf):
    """Print and return the rows for one point/crop/year with an NDVI column.

    Args:
        df_s2_sampled: Spark DataFrame providing the `lon`, `lat`, `CDL`,
            `year`, `red` and `nir` columns used below.
        lon: Value matched exactly against the `lon` column.
        lat: Value matched exactly against the `lat` column.
        cdl: Value matched against the `CDL` column.
        year: Value matched against the `year` column.
        ndvi_udf: Callable applied to the `red` and `nir` columns (in that
            order) to build the `NDVI` column.

    Returns:
        The filtered DataFrame with the extra `NDVI` column; it is also
        printed via `.show()` before being returned.
    """
    # All four conditions must hold for a row to be kept.
    matches_point = (
        (F.col("lon") == lon)
        & (F.col("lat") == lat)
        & (F.col("CDL") == cdl)
        & (F.col("year") == year)
    )
    point_rows = df_s2_sampled.where(matches_point)

    ndvi_column = ndvi_udf(F.col("red"), F.col("nir"))
    point_with_ndvi = point_rows.withColumn("NDVI", ndvi_column)

    point_with_ndvi.show()
    return point_with_ndvi

def plot_ndvi_timeseries(df, lon, lat, cdl, year):
    """Render a Plotly line chart of NDVI over scene date for one point.

    The NDVI column is computed with the module-level `ndvi_udf` from the
    `red` and `nir` columns.

    Args:
        df: Spark DataFrame providing the `lon`, `lat`, `CDL`, `year`,
            `red`, `nir` and `scene_date` columns used below.
        lon: Value matched exactly against the `lon` column.
        lat: Value matched exactly against the `lat` column.
        cdl: Value matched against the `CDL` column; also shown in the title.
        year: Value matched against the `year` column; also shown in the title.

    Returns:
        None. The figure is displayed with `fig.show()`.
    """
    # Keep only the rows for the requested point, crop label and year.
    point_filter = (
        (F.col("lon") == lon)
        & (F.col("lat") == lat)
        & (F.col("CDL") == cdl)
        & (F.col("year") == year)
    )
    point_rows = df.where(point_filter)

    # Add NDVI, then reduce to the two plotted columns in date order.
    with_ndvi = point_rows.withColumn("NDVI", ndvi_udf(F.col("red"), F.col("nir")))
    df_ndvi = with_ndvi.select("scene_date", "NDVI").orderBy("scene_date")

    # Collect into pandas so Plotly can consume the series.
    pd_ndvi = df_ndvi.toPandas()
    pd_ndvi["scene_date"] = pd.to_datetime(pd_ndvi["scene_date"])

    ndvi_trace = go.Scatter(
        x=pd_ndvi["scene_date"],
        y=pd_ndvi["NDVI"],
        mode="lines+markers",
        name="NDVI",
    )
    fig = go.Figure()
    fig.add_trace(ndvi_trace)

    fig.update_layout(
        title=f"NDVI Time Series ({cdl}, {year})",
        xaxis_title="Date",
        yaxis_title="NDVI",
        xaxis={"tickformat": "%Y-%m-%d"},
        height=400,
    )

    fig.show()

In [None]:
# Load the Parquet dataset (adjust the path if needed) and preview five rows
s2_parquet_path = "../data/s2_unique_scene.parquet"
df_s2_sampled = spark.read.parquet(s2_parquet_path)
df_s2_sampled.show(n=5)

In [None]:
# Apply NDVI function to specific points of interest
# Cotton, 2019: (lon, lat, CDL label, year) bundled once and reused for both calls
cotton_point = (-90.5923870410212, 35.57624322094004, "Cotton", 2019)
show_ndvi_for_point(df_s2_sampled, *cotton_point, ndvi_udf)
plot_ndvi_timeseries(df_s2_sampled, *cotton_point)

In [None]:
# Rice, 2019: (lon, lat, CDL label, year) bundled once and reused for both calls
rice_point = (-90.57506989498134, 35.575181872687295, "Rice", 2019)
show_ndvi_for_point(df_s2_sampled, *rice_point, ndvi_udf)
plot_ndvi_timeseries(df_s2_sampled, *rice_point)

In [None]:
# Double crop winter wheat/soybeans, 2019: (lon, lat, CDL label, year)
dbl_crop_point = (-90.55773411916678, 35.57438554009128, "Dbl Crop WinWht/Soybeans", 2019)
show_ndvi_for_point(df_s2_sampled, *dbl_crop_point, ndvi_udf)
plot_ndvi_timeseries(df_s2_sampled, *dbl_crop_point)


In [None]:
# Soybeans, 2019: (lon, lat, CDL label, year) bundled once and reused for both calls
soybeans_point = (-90.57745082006127, 35.56052757960141, "Soybeans", 2019)
show_ndvi_for_point(df_s2_sampled, *soybeans_point, ndvi_udf)
plot_ndvi_timeseries(df_s2_sampled, *soybeans_point)

In [None]:
# Corn, 2019: (lon, lat, CDL label, year) bundled once and reused for both calls
corn_point = (-90.58575725883986, 35.57540310857804, "Corn", 2019)
show_ndvi_for_point(df_s2_sampled, *corn_point, ndvi_udf)
plot_ndvi_timeseries(df_s2_sampled, *corn_point)