In [None]:
import os
import gc
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, mean, stddev, when, concat_ws
from pyspark.sql.window import Window

# Initialize Spark session
spark = SparkSession.builder.appName("YieldClimateCorrelation").getOrCreate()

# User sets the crop here (one at a time)
target_crop = "maize"  # Change to: "maize", "wheat", "soybean" as needed

# Define paths
data_folder = "csv_data"
climate_folder = "climate_data"
crop_file = os.path.join(data_folder, f"{target_crop}.csv")

enso_file = os.path.join(climate_folder, "ENSO.csv")
temperature_file = os.path.join(climate_folder, "temperature.csv")
drought_file = os.path.join(climate_folder, "drought.csv")
plot_folder = "plots"
os.makedirs(plot_folder, exist_ok=True)

# Load climate datasets. Each one is reduced to a single row per Year so that
# joining it onto per-grid-cell crop rows by Year cannot duplicate those rows.
#
# ENSO: keep only the Nino 3.4 SST anomaly, restrict to 1982-2016, drop rows
# with missing values, then average per year like the other two datasets.
# NOTE(review): if ENSO.csv holds several rows per year (e.g. monthly values),
# skipping this aggregation multiplies every crop row in the join and skews
# the row count and correlations; with one row per year the mean is a no-op.
enso_df = spark.read.csv(enso_file, header=True, inferSchema=True) \
    .select(col("Year"), col("`Nino 3.4 SST Anomalies`").alias("ENSO_Index")) \
    .filter((col("Year") >= 1982) & (col("Year") <= 2016)) \
    .dropna() \
    .groupBy("Year").agg(mean("ENSO_Index").alias("ENSO_Index"))

# Temperature anomaly averaged per year.
temp_df = spark.read.csv(temperature_file, header=True, inferSchema=True) \
    .groupBy("Year").agg(mean("Temp_Anomaly").alias("Temp_Anomaly"))

# Drought index averaged per year.
drought_df = spark.read.csv(drought_file, header=True, inferSchema=True) \
    .groupBy("Year").agg(mean("Drought_Index").alias("Drought_Index"))

# Function to compute 5-year moving average and yield anomalies
def compute_anomalies(df, window_years=5, extreme_sigma=2.0):
    """Add trailing moving-average, anomaly and extreme-event columns.

    The input must provide the columns ``Grid_ID``, ``Year`` and ``Yield``.
    ``Year`` has to be numeric because the window is defined as a value range
    on it.

    Args:
        df: Spark DataFrame with one yield observation per grid cell and year.
        window_years: Length of the trailing window in calendar years,
            including the current year. Defaults to 5.
        extreme_sigma: Multiple of the windowed standard deviation beyond
            which an anomaly is flagged as extreme. Defaults to 2.0.

    Returns:
        The input DataFrame with these columns appended:
        ``MovingAvg`` (mean Yield over the trailing window),
        ``Anomaly`` (Yield minus MovingAvg),
        ``StdDev`` (standard deviation of Anomaly over the same window),
        ``ExtremeLow`` / ``ExtremeHigh`` (1 when Anomaly is below / above
        ``extreme_sigma * StdDev`` in the respective direction, else 0).

    Raises:
        ValueError: If ``window_years`` is smaller than 1.
    """
    if window_years < 1:
        raise ValueError("window_years must be at least 1")

    # rangeBetween bounds the window by the Year *value*, so a grid cell with
    # missing years still averages over at most `window_years` calendar years.
    # A row-count window (rowsBetween) would silently reach further back when
    # years are missing; on gap-free data both give the same result.
    trailing_window = (
        Window.partitionBy("Grid_ID")
        .orderBy("Year")
        .rangeBetween(-(window_years - 1), 0)
    )
    threshold = extreme_sigma * col("StdDev")

    # The flags fall through to 0 whenever the comparison is not true, which
    # includes the case where StdDev (and thus the comparison) is null.
    return (
        df.withColumn("MovingAvg", mean("Yield").over(trailing_window))
        .withColumn("Anomaly", col("Yield") - col("MovingAvg"))
        .withColumn("StdDev", stddev("Anomaly").over(trailing_window))
        .withColumn("ExtremeLow", when(col("Anomaly") < -threshold, 1).otherwise(0))
        .withColumn("ExtremeHigh", when(col("Anomaly") > threshold, 1).otherwise(0))
    )

# Load and process target crop
print(f"\n=== Processing {target_crop.upper()} ===")
df = spark.read.csv(crop_file, header=True, inferSchema=True)
# Restrict to the study period and build a per-cell key from lat/lon so the
# anomaly windows can be partitioned by grid cell.
df = df.filter((col("year") >= 1982) & (col("year") <= 2016)) \
       .withColumn("Grid_ID", concat_ws("_", col("lat").cast("string"), col("lon").cast("string"))) \
       .select(col("year").alias("Year"), "Grid_ID", col("var").alias("Yield"))

df_anomalies = compute_anomalies(df)

# Join with global-averaged climate data
df_joined = df_anomalies.join(enso_df, on="Year", how="left") \
                         .join(temp_df, on="Year", how="left") \
                         .join(drought_df, on="Year", how="left")

# df_clean feeds five separate Spark actions below (count, three correlations,
# the plot aggregation). Caching it avoids re-reading the CSV and recomputing
# the window functions and joins for each of them.
df_clean = df_joined.dropna(
    subset=["Anomaly", "ENSO_Index", "Temp_Anomaly", "Drought_Index"]
).cache()
print(f"{target_crop.upper()} - Valid data rows: {df_clean.count()}")

# Compute and display correlations
correlations = {
    "ENSO": df_clean.stat.corr("Anomaly", "ENSO_Index"),
    "Temp": df_clean.stat.corr("Anomaly", "Temp_Anomaly"),
    "Drought": df_clean.stat.corr("Anomaly", "Drought_Index"),
}
print("--- Climate correlations ---")
for name, val in correlations.items():
    print(f"{name} vs Yield Anomaly: {val:.4f}")

# Plot and save
# Reduce to one row per year on the cluster before collecting: df_clean holds
# one row per (grid cell, year), so sending it to the driver unaggregated
# would transfer every row and draw a line with many points per x value.
df_plot = (
    df_clean.groupBy("Year")
    .agg(
        mean("Anomaly").alias("Anomaly"),
        mean("ENSO_Index").alias("ENSO_Index"),
    )
    .orderBy("Year")
    .toPandas()
)
if not df_plot.empty:
    plt.figure(figsize=(12, 5))
    plt.plot(df_plot["Year"], df_plot["Anomaly"], label=f"{target_crop.capitalize()} Yield Anomaly")
    plt.plot(df_plot["Year"], df_plot["ENSO_Index"], label="ENSO Index", linestyle="--")
    plt.title(f"{target_crop.capitalize()} Yield Anomaly vs ENSO Index")
    plt.xlabel("Year")
    plt.ylabel("Anomaly / ENSO Index")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plot_path = os.path.join(plot_folder, f"{target_crop}_enso_plot.png")
    plt.savefig(plot_path)
    plt.close()
    print(f" Plot saved: {plot_path}")
else:
    print(f" No data to plot for {target_crop.upper()}")

# Free memory
# Dropping the Python references does not release Spark-side storage, so the
# cached DataFrame is unpersisted explicitly first.
df_clean.unpersist()
del df, df_anomalies, df_joined, df_clean, df_plot
gc.collect()



=== Processing MAIZE ===
MAIZE - Valid data rows: 6139284
--- Climate correlations ---
ENSO vs Yield Anomaly: 0.0300
Temp vs Yield Anomaly: 0.0445
Drought vs Yield Anomaly: -0.0287
