In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# ============================================================
# Choose mode: "FRP" or "POLLUTANTS"
# ============================================================
mode = "POLLUTANTS"

# ============================================================
# File paths for each pollutant
# (insertion order determines the order of the subplot rows)
# ============================================================
files = {
    'CO':     r"D:\IPMA\Results\co_fire_meteo_Greece.nc",
    'NO':     r"D:\IPMA\Results\no_fire_meteo_Greece.nc",
    'NO₂':    r"D:\IPMA\Results\no2_fire_meteo_Greece.nc",
    'PM₂.₅':  r"D:\IPMA\Results\pm2p5_fire_meteo_Greece.nc",
    'PM₁₀':   r"D:\IPMA\Results\pm10_fire_meteo_Greece.nc",
}

# ============================================================
# Units for pollutants (used in the y-axis labels)
# ============================================================
units = {
    'CO': 'µg/m³',
    'NO': 'µg/m³',
    'NO₂': 'µg/m³',
    'PM₂.₅': 'µg/m³',
    'PM₁₀': 'µg/m³',
}

# ============================================================
# Selected meteorological variables
# dataset variable name -> (axis label, spatial statistic)
# ============================================================
meteo_vars = {
    'precip_Total_Precipitation': ('Total Precipitation (m)', 'mean'),
    'temp_Max': ('Max Temperature (°C)', 'mean'),
    'wind_Max': ('Max Wind Speed (m/s)', 'mean'),
}

# ============================================================
# Load sample dataset
# The first pollutant file is the one the meteorological (and FRP)
# variables are read from further down.
# ============================================================
sample_file = next(iter(files.values()))
ds_sample = xr.open_dataset(sample_file)

# ============================================================
# ✅ Select a specific year to plot
# ============================================================
selected_year = 2017  # <-- Change this to the year you want
time = pd.to_datetime(ds_sample.time.values)
year_mask = time.year == selected_year

# Fail early with a clear message instead of silently producing empty
# plots (and NaN warnings) when the requested year is not in the file.
if not year_mask.any():
    raise ValueError(
        f"No timesteps found for year {selected_year} in '{sample_file}' "
        f"(file covers {time.min()} to {time.max()})"
    )

# Positional boolean selection: does not depend on time labels being unique.
ds_sample = ds_sample.isel(time=year_mask)

# NOTE(review): `time` is still the unfiltered index here, so tmin/tmax span
# the whole file rather than `selected_year` — confirm this is intended.
tmin, tmax = time.min(), time.max()

# ============================================================
# Colors
# ============================================================
pollutant_colors = {
    'CO': 'darkgreen',
    'NO': 'tab:red',
    'NO₂': 'tab:purple',
    'PM₂.₅': 'tab:orange',
    'PM₁₀': 'mediumvioletred',
}
frp_color = 'crimson'

# ============================================================
# Main loop
# ============================================================
# Every spatial statistic below collapses these two dimensions.
spatial_dims = ['latitude', 'longitude']
supported_stats = ('mean', 'sum', 'max', 'min', 'median')
valid_modes = ("FRP", "POLLUTANTS")

# Factor applied to a pollutant's spatial mean before plotting.
pollutant_scale = {'CO': 1000}  # mg/m³ → µg/m³

# ✅ Individual Y-axis limits per pollutant (others keep autoscaling)
pollutant_ylim = {
    'CO': (0, 4000),
    'NO': (0, 200),
    'NO₂': (0, 100),
    'PM₂.₅': (0, 500),
    'PM₁₀': (0, 700),
}


def _spatial_stat(da, var_name, stat):
    """Reduce *da* over latitude/longitude with the statistic named *stat*.

    Raises ValueError when *stat* is not one of ``supported_stats``.
    """
    if stat not in supported_stats:
        raise ValueError(f"Unsupported stat '{stat}' for variable '{var_name}'")
    return getattr(da, stat)(dim=spatial_dims)


def _load_pollutant_series(file_path, year, scale=1):
    """Return the spatial-mean 'Mean' series of *file_path* for *year*.

    The file is opened in a ``with`` block so its handle is released as soon
    as the values have been read into memory. The result is a NumPy array
    multiplied by *scale*.
    """
    with xr.open_dataset(file_path) as ds:
        in_year = pd.to_datetime(ds.time.values).year == year
        series = ds['Mean'].isel(time=in_year).mean(dim=spatial_dims).values
    return series * scale


# Reject an unknown mode up front instead of looping and plotting nothing.
if mode not in valid_modes:
    raise ValueError(f"Unknown mode '{mode}'; expected one of {valid_modes}")

# The y-series do not depend on the meteorological variable, so they are
# computed once here rather than on every pass through the loop.
if mode == "FRP":
    frp = ds_sample['frp_sum_Greece'].mean(dim=spatial_dims).values
else:
    pollutant_series = {
        name: _load_pollutant_series(path, selected_year, pollutant_scale.get(name, 1))
        for name, path in files.items()
    }

for var_name, (label, stat) in meteo_vars.items():
    print(f"\nPlotting correlations for: {label} ({stat})")

    # Get meteorological variable (spatial statistic)
    meteo_values = _spatial_stat(ds_sample[var_name], var_name, stat).values
    meteo_valid = np.isfinite(meteo_values)

    # ========================================================
    # MODE 1: FRP scatter comparison
    # ========================================================
    if mode == "FRP":
        # Keep only timesteps where both series are finite.
        valid = meteo_valid & np.isfinite(frp)

        # Scatter
        fig, ax = plt.subplots(figsize=(7, 5))
        ax.scatter(meteo_values[valid], frp[valid],
                   color=frp_color, alpha=0.7, s=30, label='FRP (MW)')

        # Labels and formatting
        ax.set_xlabel(label)
        ax.set_ylabel('FRP (MW)')
        ax.set_ylim(0, 10000)  # Adjust as needed
        ax.legend()
        ax.grid(False)
        ax.set_title(f'{label} vs FRP (Scatter Plot) - {selected_year}')
        plt.tight_layout()
        plt.show()

    # ========================================================
    # MODE 2: Pollutant scatter comparison
    # ========================================================
    else:
        n_pollutants = len(pollutant_series)
        # squeeze=False keeps `axs` two-dimensional even for a single
        # pollutant, so the row iteration below always works.
        fig, axs = plt.subplots(n_pollutants, 1, figsize=(7, 3 * n_pollutants),
                                sharex=False, squeeze=False)
        fig.suptitle(f'{label} vs Pollutants ({selected_year})', fontsize=14)

        for ax, (pollutant_name, pollutant) in zip(axs[:, 0], pollutant_series.items()):
            # The masks are combined element-wise, so both series must
            # cover the same timesteps.
            if pollutant.shape != meteo_values.shape:
                raise ValueError(
                    f"Time axis mismatch for '{pollutant_name}': "
                    f"{pollutant.shape} pollutant values vs "
                    f"{meteo_values.shape} values of '{var_name}'"
                )

            valid = meteo_valid & np.isfinite(pollutant)

            color = pollutant_colors.get(pollutant_name, 'tab:blue')
            ax.scatter(meteo_values[valid], pollutant[valid],
                       color=color, alpha=0.6, s=25, label=pollutant_name)
            ax.set_xlabel(f"{label}")
            ax.set_ylabel(f"{pollutant_name} ({units[pollutant_name]})")
            ax.legend()
            ax.grid(False)

            ylim = pollutant_ylim.get(pollutant_name)
            if ylim is not None:
                ax.set_ylim(*ylim)

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()
