# In [None]:
"""
PHASE 1:
Extract extremes ONLY on days where at least one fire occurred anywhere.
Then save the highest and lowest pollutant pixel for each fire day.
"""

import xarray as xr
import numpy as np
import pandas as pd
from tqdm import tqdm
import os

# -----------------------------
# INPUT FILES
# -----------------------------
# Pollutant label -> NetCDF input file read by extract_and_save_extremes.
# All files share one folder and one suffix; only the leading stem differs.
pollutant_files = {
    label: "D:\\IPMA\\Results\\" + stem + "_fire_meteo_Greece.nc"
    for label, stem in (
        ("CO", "co"),
        ("NO", "no"),
        ("NO2", "no2"),
        ("PM2.5", "pm2p5"),
        ("PM10", "pm10"),
    )
}


# -------------------------------------------------
# FUNCTION TO EXTRACT EXTREMES ONLY ON FIRE DAYS
# -------------------------------------------------
def _extreme_pixel(field, kind):
    """Locate the highest or lowest valid pixel of a single-day field.

    Parameters
    ----------
    field : xarray.DataArray
        Pollutant values for one day. Must carry ``latitude`` and
        ``longitude`` coordinates and contain at least one non-NaN value.
    kind : str
        ``"high"`` selects the maximum; any other value selects the minimum.

    Returns
    -------
    tuple of float
        ``(latitude, longitude, value)`` of the selected pixel.
    """
    locate = field.argmax if kind == "high" else field.argmin
    # Asking for the arg-extreme over *every* dimension yields one index per
    # dimension, so latitude and longitude always describe the same grid cell,
    # even when several cells tie for the extreme value.
    pixel = field.isel(locate(dim=list(field.dims)))
    return (
        float(pixel["latitude"].values),
        float(pixel["longitude"].values),
        float(pixel.values),
    )


def extract_and_save_extremes(pollutant, file_path):
    """Save the highest and lowest pollutant pixel of every fire day.

    A day is processed only when the maximum of ``fire_binary_Greece`` on
    that day equals 1 and the ``Mean`` field is not entirely NaN. For each
    such day two records are produced (``extreme`` = ``"high"`` / ``"low"``),
    and all records are written to ``<file_path stem>_extremes_fires.nc``
    along a ``record`` dimension.

    Parameters
    ----------
    pollutant : str
        Pollutant label. For ``"CO"`` the ``Mean`` variable is multiplied by
        1000 before the extremes are extracted.
    file_path : str
        Path of the input NetCDF file. It must contain the variables
        ``fire_binary_Greece`` and ``Mean`` and a ``time`` coordinate.

    Returns
    -------
    str or None
        Path of the written extremes file, or ``None`` when no day
        qualified (in that case nothing is written).

    Raises
    ------
    ValueError
        If the dataset has no ``fire_binary_Greece`` variable.
    """
    print(f"\n=== Processing {pollutant} ===")

    records = []

    # The context manager closes the input file on every exit path,
    # including exceptions raised while scanning.
    with xr.open_dataset(file_path) as ds:

        if "fire_binary_Greece" not in ds:
            raise ValueError(f"Dataset {file_path} has no variable 'fire_binary_Greece'!")

        # Convert CO from mg/m^3 to ug/m^3
        if pollutant == "CO" and "Mean" in ds:
            ds["Mean"] = ds["Mean"] * 1000.0

        # Loop through days
        for t in tqdm(ds.time.values, desc=f"Scanning fire days for {pollutant}", unit="day"):

            day = ds.sel(time=t)

            # Skip day if NO fire occurs anywhere
            has_fire = (day["fire_binary_Greece"].max().values == 1)
            if not has_fire:
                continue

            # Skip day if pollutant is entirely NaN
            field = day["Mean"]
            if np.isnan(field.values).all():
                continue

            # One record for the maximum pixel, then one for the minimum pixel
            for kind in ("high", "low"):
                lat, lon, value = _extreme_pixel(field, kind)
                records.append({
                    "time": t,
                    "latitude": lat,
                    "longitude": lon,
                    "Mean": value,
                    "extreme": kind
                })

    # Without any record the output table would have no columns to save.
    if not records:
        print(f"No fire day with valid data found for {pollutant}; nothing saved.")
        return None

    # -------------------------------------------------
    # SAVE RESULTS
    # -------------------------------------------------
    df = pd.DataFrame(records)

    ds_out = xr.Dataset(
        {
            "Mean": ("record", df["Mean"].values),
            "latitude": ("record", df["latitude"].values),
            "longitude": ("record", df["longitude"].values),
            "extreme": ("record", df["extreme"].astype(str).values),
        },
        coords={"time": ("record", df["time"].values)},
    )

    base = os.path.splitext(file_path)[0]
    save_path = f"{base}_extremes_fires.nc"

    ds_out.to_netcdf(save_path)

    print(f"‚úî Saved {pollutant} extremes: {save_path}")

    return save_path


# -----------------------------
# RUN FOR ALL POLLUTANTS
# -----------------------------
if __name__ == "__main__":
    # Run the extraction for every configured pollutant, keeping the path
    # each call returns under the pollutant's label.
    saved_files = {
        label: extract_and_save_extremes(label, source)
        for label, source in pollutant_files.items()
    }

    print("\n=== PHASE 1 COMPLETE ===")
    for label, out_path in saved_files.items():
        print(f"{label} -> {out_path}")


# In [None]:
"""
PHASE 2: Rank extreme pollution days, compute meteorology composites, and plot
• Works for all pollutants
• Automatically detects extremes file dimension (e.g., 'record')
• Uses 'time' variable to select from full dataset
• Ignores NaNs and plots only coherent grid points
• Zooms to area with valid data + padding
• Side-by-side HIGH vs LOW meteorology composites, closer together
• Coastlines and country borders only, horizontal colorbar below
• Human-readable meteorology variable names in titles
• Custom colormaps per meteorological variable
"""

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import os

# ------------------------------
# USER INPUTS
# ------------------------------

# Pollutant label -> full dataset file. All files share one folder and one
# suffix; only the leading stem differs.
pollutant_full_files = {
    label: "D:\\IPMA\\Results\\" + stem + "_fire_meteo_Greece.nc"
    for label, stem in (
        ("CO", "co"),
        ("NO", "no"),
        ("NO2", "no2"),
        ("PM2.5", "pm2p5"),
        ("PM10", "pm10"),
    )
}

# Pollutant label -> extremes file: the full-dataset path with its extension
# replaced by "_extremes_fires.nc".
pollutant_extreme_files = {
    label: os.path.splitext(full_path)[0] + "_extremes_fires.nc"
    for label, full_path in pollutant_full_files.items()
}

N_days = 25  # how many records are taken from each end of the ranking

# One row per meteorological variable:
# (dataset variable, human-readable title, unit label, colormap name)
_METEO_SPECS = (
    ("precip_Total_Precipitation", "Total Precipitation", "m", "PuBu"),
    ("temp_Max", "Max Temperature", "¬∞C", "coolwarm"),
    ("wind_Max", "Max Wind Speed", "m/s", "Oranges"),
)

# Meteorological variables: variable -> (unit, colormap)
meteo_vars = {
    name: (unit, cmap) for name, _title, unit, cmap in _METEO_SPECS
}

# Human-readable titles: variable -> title
meteo_var_names = {
    name: title for name, title, _unit, _cmap in _METEO_SPECS
}

# ------------------------------
# PLOTTING FUNCTION
# ------------------------------
def _cell_edges(centers):
    """Return the cell-boundary coordinates for 1-D cell-centre coordinates.

    Interior edges are midpoints between neighbouring centres; the two outer
    edges extend half a cell beyond the first and last centre. The result
    has one more element than ``centers``.
    """
    centers = np.asarray(centers, dtype=float)
    if centers.size == 1:
        # A single centre has no neighbour to infer the spacing from.
        # NOTE(review): the unit cell width used here is an arbitrary
        # fallback; adjust if single-row/column grids need a real width.
        return np.array([centers[0] - 0.5, centers[0] + 0.5])

    inner = (centers[:-1] + centers[1:]) / 2
    first = centers[0] - (centers[1] - centers[0]) / 2
    last = centers[-1] + (centers[-1] - centers[-2]) / 2
    return np.concatenate(([first], inner, [last]))


def plot_composites(pol, comp_max, comp_min):
    """Plot HIGH vs LOW composites side by side for each meteorological variable.

    For every entry of ``meteo_vars`` one figure with two map panels is
    shown. Only grid points that are finite in both composites are drawn,
    both panels share one colour scale, and the view is cropped to the
    valid points plus one pixel of padding.

    Parameters
    ----------
    pol : str
        Pollutant label used in messages and panel titles.
    comp_max, comp_min : xarray.Dataset
        Composites for the high and low selections. They must provide
        ``latitude`` and ``longitude`` coordinates; each plotted variable is
        indexed as ``[latitude, longitude]`` (assumes that 2-D dimension
        order -- TODO confirm against the input files).

    Returns
    -------
    None
        Figures are displayed with ``plt.show()``.
    """
    for var, (unit, cmap) in meteo_vars.items():
        # The variable must exist in both composites to be compared.
        if var not in comp_max or var not in comp_min:
            print(f"‚ö† WARNING: {var} not found for {pol}. Skipped.")
            continue

        data_max = comp_max[var]
        data_min = comp_min[var]

        # Mask only points valid in both composites
        valid_mask = np.isfinite(data_max) & np.isfinite(data_min)
        if not bool(valid_mask.any()):
            # Nothing to draw, and the bounding box below would be undefined.
            print(f"No grid point with valid {var} in both composites for {pol}. Skipped.")
            continue

        data_max_masked = data_max.where(valid_mask)
        data_min_masked = data_min.where(valid_mask)

        # Shared color scale (both minima/maxima are finite: the mask is non-empty)
        vmin_all = min(float(data_min_masked.min()), float(data_max_masked.min()))
        vmax_all = max(float(data_min_masked.max()), float(data_max_masked.max()))

        lon_vals = comp_max["longitude"].values
        lat_vals = comp_max["latitude"].values

        # Indices of valid data
        lat_valid_idx, lon_valid_idx = np.nonzero(valid_mask.values)

        # Min/max indices + 1-pixel padding, clipped to the grid
        lat_min_idx = max(lat_valid_idx.min() - 1, 0)
        lat_max_idx = min(lat_valid_idx.max() + 1, len(lat_vals) - 1)
        lon_min_idx = max(lon_valid_idx.min() - 1, 0)
        lon_max_idx = min(lon_valid_idx.max() + 1, len(lon_vals) - 1)

        lat_slice = slice(lat_min_idx, lat_max_idx + 1)
        lon_slice = slice(lon_min_idx, lon_max_idx + 1)

        # Grid edges so that border pixels are drawn at full size
        lon_edges = _cell_edges(lon_vals[lon_slice])
        lat_edges = _cell_edges(lat_vals[lat_slice])

        data_max_plot = data_max_masked.values[lat_slice, lon_slice]
        data_min_plot = data_min_masked.values[lat_slice, lon_slice]

        # Subplots side by side, closer
        fig, axes = plt.subplots(1, 2, figsize=(9, 5),
                                 subplot_kw={'projection': ccrs.PlateCarree()},
                                 gridspec_kw={'wspace': 0.08})

        var_title = meteo_var_names.get(var, var)
        titles = [
            f"{pol} High Composite ‚Äì {var_title}",
            f"{pol} Low Composite ‚Äì {var_title}"
        ]

        for ax, data, title in zip(axes, [data_max_plot, data_min_plot], titles):
            img = ax.pcolormesh(
                lon_edges, lat_edges, data,
                cmap=cmap, vmin=vmin_all, vmax=vmax_all,
                transform=ccrs.PlateCarree()
            )
            ax.set_title(title, fontsize=11)
            ax.coastlines()
            ax.add_feature(cfeature.BORDERS, linewidth=0.8)
            ax.set_extent([lon_edges[0], lon_edges[-1], lat_edges[0], lat_edges[-1]],
                          crs=ccrs.PlateCarree())

        # Horizontal colorbar below; either panel's image works because both
        # were drawn with the same vmin/vmax and colormap.
        cbar_ax = fig.add_axes([0.25, 0.05, 0.5, 0.03])
        cbar = fig.colorbar(img, cax=cbar_ax, orientation='horizontal')
        cbar.set_label(f"{unit}")

        plt.tight_layout(rect=[0, 0.07, 1, 1])
        plt.show()


# ------------------------------
# PHASE 2 MAIN LOOP
# ------------------------------
# For every pollutant: rank the extreme records by "Mean", pick the N_days
# highest and N_days lowest, average the full dataset over the corresponding
# days, and plot the resulting meteorology composites.
for pol in pollutant_extreme_files:

    print(f"\n======================================================")
    print(f"üå´Ô∏è PROCESSING POLLUTANT: {pol}")
    print(f"======================================================")

    ext_file = pollutant_extreme_files[pol]
    full_file = pollutant_full_files[pol]

    if not (os.path.exists(ext_file) and os.path.exists(full_file)):
        print(f"‚ùå Missing data for {pol}. Skipping!")
        continue

    # Context managers close both files on every path out of this block,
    # including the early `continue` and any exception while plotting.
    with xr.open_dataset(ext_file) as ds_ext, xr.open_dataset(full_file) as ds_full:

        if "Mean" not in ds_ext:
            print(f"‚ö† Missing 'Mean' in extremes for {pol}. Cannot rank.")
            continue

        # Rank along the dimension of the variable that is being sorted.
        dim_name = ds_ext["Mean"].dims[0]

        # NOTE(review): the ranking uses every record, regardless of its
        # "extreme" label ("high"/"low") -- confirm this is intended.
        ds_sorted = ds_ext.sortby(ds_ext["Mean"], ascending=False)

        top_max = ds_sorted.isel({dim_name: slice(0, N_days)})
        top_min = ds_sorted.isel({dim_name: slice(-N_days, None)})

        # A day can own more than one selected record; keep each day once so
        # it is not weighted twice in the composite mean.
        top_max_times = np.unique(top_max["time"].values)
        top_min_times = np.unique(top_min["time"].values)

        print(f"üìå {pol} -> Number of high days selected: {len(top_max_times)}")
        print(f"üìå {pol} -> Number of low days selected: {len(top_min_times)}")

        # Compute composites (mean over selected times)
        comp_max = ds_full.sel(time=top_max_times).mean(dim="time", skipna=True)
        comp_min = ds_full.sel(time=top_min_times).mean(dim="time", skipna=True)

        # ----------------- PLOT COMPOSITES -----------------
        print(f"\nüìç Plotting meteorology composites for {pol}...")
        plot_composites(pol, comp_max, comp_min)

print("\nüéâ PHASE 2 COMPLETED SUCCESSFULLY!")
