# Common code for Climatology, Threshold, Mean SST, SST Anomaly, SST Trend and MHW days correlation analysis

In [None]:
# Common code for Climatology, Threshold, Mean SST, SST Anomaly, SST Trend and MHW days correlation analysis
# Run this code first, before running the other analysis codes

from glob import glob
from pathlib import Path
from typing import Dict, Tuple, Optional
import numpy as np, pandas as pd, xarray as xr
import matplotlib.pyplot as plt
import datetime as dt
import matplotlib.dates as mdates
import marineHeatWaves as mhw  

# Compatibility shim: make sure the legacy ``np.NaN`` alias exists (pointing at
# ``np.nan``) so that dependencies which still reference it keep working.
import numpy as _np

try:
    _np.NaN
except AttributeError:
    # Alias is missing in the installed NumPy -> recreate it.
    _np.NaN = _np.nan

# ---- Input data ------------------------------------------------------------
# Glob pattern of the daily-mean SST NetCDF files and the variable read from them.
FILES_GLOB = "/home/Desktop/Noah_data_1982-2024_SST_daily_mean/sst.day.mean.*.nc"
VAR = "sst"

# Baseline period (first year, last year) handed to the climatology computation.
CLIM_YEARS: Tuple[int, int] = (1982, 2024)

# Analysis boxes; modify the lat-lon bounds as per your region.
REGIONS: Dict[str, Dict[str, float]] = {
    "Arabian Sea": dict(lon_min=40.0, lon_max=78.0, lat_min=0.0, lat_max=30.0),
    "Bay Of Bengal": dict(lon_min=78.0, lon_max=110.0, lat_min=0.0, lat_max=30.0),
    "North Indian Ocean": dict(lon_min=40.0, lon_max=110.0, lat_min=0.0, lat_max=30.0),
}

# Season label -> calendar months belonging to it, plus the plotting order.
SEASONS = {
    "DJF": [12, 1, 2],
    "MAM": [3, 4, 5],
    "JJA": [6, 7, 8],
    "SON": [9, 10, 11],
}
season_order = list(SEASONS)

# Output folder (created on first run).
OUTDIR = Path("outputs")
OUTDIR.mkdir(parents=True, exist_ok=True)

# ---- Leap-aware reference year ---------------------------------------------
# 2000 is a leap year, so the reference calendar has 366 days.
_ref = pd.date_range(start="2000-01-01", end="2000-12-31", freq="D")
# Day-of-year (1..366) -> calendar month.
_doy_to_month = pd.Series(_ref.month.to_numpy(), index=np.arange(1, len(_ref) + 1))
# Calendar month -> 0-based day-of-year positions falling in that month.
_month_to_doy0 = {
    month: np.flatnonzero(_doy_to_month.to_numpy() == month)
    for month in range(1, 13)
}


def open_mfdataset(paths_glob: str, chunks={"time": 120}, engine: str = "netcdf4") -> xr.Dataset:
    """Open every file matching *paths_glob* as a single combined dataset.

    Args:
        paths_glob: Glob pattern of the input files; matches are opened in
            sorted order.
        chunks: Chunk specification forwarded to ``xr.open_mfdataset``.
        engine: Backend engine forwarded to ``xr.open_mfdataset``.

    Returns:
        The dataset combined by coordinates.

    Raises:
        FileNotFoundError: If the pattern matches no file.
    """
    matched = sorted(glob(paths_glob))
    if len(matched) == 0:
        raise FileNotFoundError(f"No files match: {paths_glob}")
    print(f"[open] {len(matched)} files")
    open_kwargs = dict(combine="by_coords", parallel=True, chunks=chunks, engine=engine)
    return xr.open_mfdataset(matched, **open_kwargs)

def subset_box(ds: xr.Dataset, box: Dict[str, float]) -> xr.Dataset:
    """Cut *ds* down to the latitude/longitude rectangle described by *box*.

    Args:
        ds: Dataset whose coordinates are named ``lat``/``lon`` or
            ``latitude``/``longitude``.
        box: Mapping with keys ``lat_min``, ``lat_max``, ``lon_min``, ``lon_max``.

    Returns:
        The label-based selection of *ds* inside the box.
    """
    # NOTE(review): slice(min, max) presumably expects ascending coordinates — confirm for your files.
    lat_name = "latitude" if "lat" not in ds.coords else "lat"
    lon_name = "longitude" if "lon" not in ds.coords else "lon"
    lat_range = slice(box["lat_min"], box["lat_max"])
    lon_range = slice(box["lon_min"], box["lon_max"])
    return ds.sel({lat_name: lat_range, lon_name: lon_range})

def area_weighted_boxmean(da: xr.DataArray) -> xr.DataArray:
    """Average *da* over latitude and longitude using cos(latitude) weights.

    Args:
        da: Array with ``lat``/``lon`` or ``latitude``/``longitude`` coordinates
            (latitude is converted from degrees to radians for the weights).

    Returns:
        *da* reduced over both horizontal dimensions.
    """
    lat_name = "latitude" if "lat" not in da.coords else "lat"
    lon_name = "longitude" if "lon" not in da.coords else "lon"
    cos_lat = np.cos(np.deg2rad(da[lat_name]))
    return da.weighted(cos_lat).mean(dim=[lat_name, lon_name])


# The following script calculates and plots the Long-term Climatology, 90th and 80th percentile Threshold of Arabian Sea, Bay Of Bengal, North Indian Ocean

In [None]:
# Long-term Climatology, 90th and 80th percentile Threshold of Arabian Sea, Bay Of Bengal, North Indian Ocean

def mhw_seas_thresh_doy(
    da_box: xr.DataArray, clim_years: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build 366-day climatology and 90th/80th percentile threshold curves.

    ``mhw.detect`` is run twice on the box-mean series (``pctile=90`` and
    ``pctile=80``); the per-time-step ``clim["seas"]`` / ``clim["thresh"]``
    arrays it returns are then averaged per calendar day.

    Args:
        da_box: 1-D time series with a ``time`` coordinate.
        clim_years: Requested (first_year, last_year) baseline; it is clipped
            to the years actually present in ``da_box``.

    Returns:
        ``(seas, thresh_90th, thresh_80th)``, each a float array of length 366
        indexed by leap-year day-of-year minus one. Days that never occur in
        the series are NaN.
    """
    t_index = da_box["time"].to_index()
    years = np.asarray(t_index.year)

    # safety-clip baseline to data span
    y0 = int(max(clim_years[0], years.min()))
    y1 = int(min(clim_years[1], years.max()))

    ords = np.array([d.toordinal() for d in t_index], dtype=int)
    temp = da_box.values.astype(float)

    # clim['seas'] and clim['thresh'] for the 90th percentile
    _, clim_90th = mhw.detect(ords, temp, climatologyPeriod=[y0, y1], pctile=90, minDuration=5, joinAcrossGaps=True)
    # only clim['thresh'] is needed from the 80th percentile run
    _, clim_80th = mhw.detect(ords, temp, climatologyPeriod=[y0, y1], pctile=80, minDuration=5, joinAcrossGaps=True)

    # Leap-aligned day-of-year. Plain ``dayofyear`` numbers 1 March as 60 in a
    # non-leap year but 61 in a leap year, so grouping on it would blend
    # neighbouring calendar dates from 1 March onwards. Shifting the
    # post-February days of non-leap years by one gives every calendar date its
    # day number in a leap year (the 366-day reference used for plotting).
    is_leap = (years % 4 == 0) & ((years % 100 != 0) | (years % 400 == 0))
    after_feb = np.asarray(t_index.month) > 2
    doy = np.asarray(t_index.dayofyear) + (~is_leap & after_feb)

    # NaN-skipping mean per day, then expand to the full 1..366 range.
    per_doy = (
        pd.DataFrame(
            {
                "seas": np.asarray(clim_90th["seas"], dtype=float),
                "p90": np.asarray(clim_90th["thresh"], dtype=float),
                "p80": np.asarray(clim_80th["thresh"], dtype=float),
            }
        )
        .groupby(doy)
        .mean()
        .reindex(np.arange(1, 367))
    )
    return per_doy["seas"].to_numpy(), per_doy["p90"].to_numpy(), per_doy["p80"].to_numpy()

# Open the full daily SST record once; every region is cut from it.
ds_annual = open_mfdataset(FILES_GLOB)

# Region name -> its three day-of-year curves.
curves_annual = {}
for name in REGIONS:
    box = REGIONS[name]
    # Keep only the SST variable, crop to the box, then collapse lat/lon.
    da = subset_box(ds_annual[[VAR]], box)[VAR]
    da_box = area_weighted_boxmean(da)
    seas_doy, thresh_90th_doy, thresh_80th_doy = mhw_seas_thresh_doy(da_box, CLIM_YEARS)
    curves_annual[name] = dict(
        zip(
            ("seas", "90th thresh", "80th thresh"),
            (seas_doy, thresh_90th_doy, thresh_80th_doy),
        )
    )

# Plot one stacked panel per region, all sharing the X-axis.
# The curves have 366 points, so they are drawn against the leap year 2000.
x_dates = pd.date_range("2000-01-01", "2000-12-31", freq="D")

# Derive the panel count from the computed regions so that editing REGIONS
# neither drops a region nor leaves an empty panel. With three regions the
# figure size is the same (11, 8) as before.
n_panels = len(curves_annual)
fig, axes = plt.subplots(
    nrows=n_panels, ncols=1, figsize=(11, 8 * n_panels / 3),
    sharex=True, constrained_layout=True, squeeze=False,
)
axes = axes[:, 0]  # squeeze=False always yields a 2-D array; keep a 1-D column

# use common y-limits for comparability
all_vals = np.concatenate([np.r_[v["seas"], v["90th thresh"], v["80th thresh"]] for v in curves_annual.values()])
ymin = float(np.nanmin(all_vals)) - 0.2
ymax = float(np.nanmax(all_vals)) + 0.2

for ax, (name, ct) in zip(axes, curves_annual.items()):
    ax.plot(x_dates, ct["seas"], label="Climatology (Annual)", color="C0", lw=1.6)
    ax.plot(x_dates, ct["90th thresh"], label="90th perc Threshold", ls="--", color="#ff7f0e", lw=1.6)
    ax.plot(x_dates, ct["80th thresh"], label="80th perc Threshold", ls="--", color="#24fa11", lw=1.6)
    ax.set_ylabel("Temp (°C)")
    ax.set_title(f"{name}: Long- Term Climatology, 90th and 80th perc Threshold ({CLIM_YEARS[0]}–{CLIM_YEARS[1]})")
    ax.legend(loc="upper right")
    ax.set_ylim(ymin, ymax)

# Ticks on the 1st and 15th of every month plus the last day of the reference year.
year = 2000
tick_dates = [pd.Timestamp(year, month, day) for month in range(1, 13) for day in (1, 15)]
tick_dates.append(pd.Timestamp(year, 12, 31))

axes[-1].set_xticks(tick_dates)
axes[-1].xaxis.set_major_formatter(mdates.DateFormatter("%d %b"))
axes[-1].set_xlim(x_dates[0], x_dates[-1])
axes[-1].set_xlabel("Day of Year")
fig.autofmt_xdate()


# The following script calculates and plots the Monthly Climatology, 90th & 80th percentile thresholds for Arabian Sea, Bay Of Bengal, North Indian Ocean

In [None]:
# Monthly Climatology, 90th & 80th percentile thresholds for Arabian Sea, Bay Of Bengal, North Indian Ocean

from glob import glob
from pathlib import Path
from typing import Dict, Tuple
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import marineHeatWaves as mhw  


# Glob pattern of the daily-mean SST NetCDF files and the variable read from them.
FILES_GLOB = "/home/Desktop/Noah_data_1982-2024_SST_daily_mean/sst.day.mean.*.nc"
VAR = "sst"

# Baseline period (first year, last year) handed to the climatology computation.
CLIM_YEARS: Tuple[int, int] = (1982, 2024)

# NOTE(review): these bounds (lon 20-100, lat 0-25) differ from the REGIONS set
# in the common-setup cell (lon 40-110, lat 0-30). Running this cell replaces
# that earlier definition for every cell executed afterwards — confirm which
# set of boxes is intended.
REGIONS: Dict[str, Dict[str, float]] = {
    "Arabian Sea": dict(lon_min=20.0, lon_max=78.0, lat_min=0.0, lat_max=25.0),
    "Bay Of Bengal": dict(lon_min=78.0, lon_max=100.0, lat_min=0.0, lat_max=25.0),
    "North Indian Ocean": dict(lon_min=20.0, lon_max=100.0, lat_min=0.0, lat_max=25.0),
}

# Output folder (created on first run).
OUTDIR = Path("outputs")
OUTDIR.mkdir(parents=True, exist_ok=True)

# Leap-aware reference year: 2000 has 366 days.
_ref = pd.date_range(start="2000-01-01", end="2000-12-31", freq="D")
# Day-of-year (1..366) -> calendar month.
_doy_to_month = pd.Series(_ref.month.to_numpy(), index=np.arange(1, len(_ref) + 1))

# Calendar month -> 0-based day-of-year positions falling in that month.
_month_to_doy0 = {
    month: np.flatnonzero(_doy_to_month.to_numpy() == month)
    for month in range(1, 13)
}


def open_mfdataset(paths_glob: str, chunks={"time": 120}, engine: str = "netcdf4") -> xr.Dataset:
    """Open every file matching *paths_glob* as a single combined dataset.

    Args:
        paths_glob: Glob pattern of the input files; matches are opened in
            sorted order.
        chunks: Chunk specification forwarded to ``xr.open_mfdataset``.
        engine: Backend engine forwarded to ``xr.open_mfdataset``.

    Returns:
        The dataset combined by coordinates.

    Raises:
        FileNotFoundError: If the pattern matches no file.
    """
    matched = sorted(glob(paths_glob))
    if len(matched) == 0:
        raise FileNotFoundError(f"No files match: {paths_glob}")
    open_kwargs = dict(combine="by_coords", parallel=True, chunks=chunks, engine=engine)
    return xr.open_mfdataset(matched, **open_kwargs)

def subset_box(ds: xr.Dataset, box: Dict[str, float]) -> xr.Dataset:
    """Cut *ds* down to the latitude/longitude rectangle described by *box*.

    Args:
        ds: Dataset whose coordinates are named ``lat``/``lon`` or
            ``latitude``/``longitude``.
        box: Mapping with keys ``lat_min``, ``lat_max``, ``lon_min``, ``lon_max``.

    Returns:
        The label-based selection of *ds* inside the box.
    """
    # NOTE(review): slice(min, max) presumably expects ascending coordinates — confirm for your files.
    lat_name = "latitude" if "lat" not in ds.coords else "lat"
    lon_name = "longitude" if "lon" not in ds.coords else "lon"
    lat_range = slice(box["lat_min"], box["lat_max"])
    lon_range = slice(box["lon_min"], box["lon_max"])
    return ds.sel({lat_name: lat_range, lon_name: lon_range})

def area_weighted_boxmean(da: xr.DataArray) -> xr.DataArray:
    """Average *da* over latitude and longitude using cos(latitude) weights.

    Args:
        da: Array with ``lat``/``lon`` or ``latitude``/``longitude`` coordinates
            (latitude is converted from degrees to radians for the weights).

    Returns:
        *da* reduced over both horizontal dimensions.
    """
    lat_name = "latitude" if "lat" not in da.coords else "lat"
    lon_name = "longitude" if "lon" not in da.coords else "lon"
    cos_lat = np.cos(np.deg2rad(da[lat_name]))
    return da.weighted(cos_lat).mean(dim=[lat_name, lon_name])

# Load dataset
ds = open_mfdataset(FILES_GLOB, chunks={"time": 120}, engine="netcdf4")
assert VAR in ds, f"{VAR} not found in dataset"

# Monthly climatology / threshold values for each region
results = {}
days = np.arange(1, 367)    # leap-year day-of-year axis
months = np.arange(1, 13)

for name, box in REGIONS.items():
    # Box-mean SST series, loaded into memory before handing it to mhw.detect.
    da = subset_box(ds[[VAR]], box)[VAR]
    da_box = area_weighted_boxmean(da).rename(f"sst_boxmean_{name}")
    da_box = da_box.compute()
    time_np = pd.to_datetime(da_box.time.values)
    temp_np = np.asarray(da_box.values, dtype=float)
    assert temp_np.shape[0] == time_np.shape[0], "time/temperature length mismatch"

    # Clip the baseline to the years actually present in the data.
    y0 = max(CLIM_YEARS[0], int(time_np.year.min()))
    y1 = min(CLIM_YEARS[1], int(time_np.year.max()))
    ords = np.array([d.toordinal() for d in time_np], dtype=int)

    # clim['seas'] and clim['thresh'] for the 90th percentile
    res90, clim90 = mhw.detect(ords, temp_np, climatologyPeriod=[y0, y1], pctile=90, minDuration=5, joinAcrossGaps=True)
    seas_full = np.asarray(clim90["seas"], dtype=float)  # daily, same length as time_np
    thresh90_full = np.asarray(clim90["thresh"], dtype=float)
    assert seas_full.shape[0] == temp_np.shape[0] == time_np.shape[0]

    # clim['thresh'] for the 80th percentile
    res80, clim80 = mhw.detect(
        ords, temp_np, climatologyPeriod=[y0, y1],
        pctile=80, minDuration=5, joinAcrossGaps=True
    )
    thresh80_full = np.asarray(clim80["thresh"], dtype=float)
    assert thresh80_full.shape[0] == temp_np.shape[0]

    # Leap-aligned day-of-year. Plain ``dayofyear`` numbers 1 March as 60 in a
    # non-leap year but 61 in a leap year, so grouping on it would shift
    # non-leap-year values by one slot from 1 March onwards and leak
    # first-of-month days into the previous month of ``_month_to_doy0``
    # (built from the leap year 2000). Shift those days by one instead.
    after_feb_non_leap = ~np.asarray(time_np.is_leap_year) & (np.asarray(time_np.month) > 2)
    leap_doy = np.asarray(time_np.dayofyear) + after_feb_non_leap

    df = pd.DataFrame({"seas": seas_full, "p90": thresh90_full, "p80": thresh80_full}, index=time_np)
    g = df.groupby(leap_doy).mean(numeric_only=True)
    g = g.reindex(days)

    seas_366 = g["seas"].to_numpy()
    thresh_90th_366 = g["p90"].to_numpy()
    thresh_80th_366 = g["p80"].to_numpy()

    # Average the day-of-year curves over the days belonging to each month.
    monthly_seas = np.array([np.nanmean(seas_366[_month_to_doy0[m]]) for m in months])
    monthly_90 = np.array([np.nanmean(thresh_90th_366[_month_to_doy0[m]]) for m in months])
    monthly_80 = np.array([np.nanmean(thresh_80th_366[_month_to_doy0[m]]) for m in months])

    results[name] = {
        "monthly_seas": monthly_seas,
        "monthly_90_thresh": monthly_90,
        "monthly_80_thresh": monthly_80,
    }

# Plot the monthly curves: one stacked panel per region with a shared X-axis.
month_labels = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]
x = np.arange(1, 13)
region_order = list(REGIONS.keys())

# Derive the panel count from the regions so that editing REGIONS neither
# drops a region nor leaves an empty panel. With three regions the figure
# size is the same (11, 8) as before.
n_panels = len(region_order)
fig, axes = plt.subplots(
    nrows=n_panels, ncols=1, figsize=(11, 8 * n_panels / 3),
    sharex=True, constrained_layout=True, squeeze=False,
)
axes = axes[:, 0]  # squeeze=False always yields a 2-D array; keep a 1-D column

# Common y-limits for comparability
all_vals = np.concatenate([
    results[nm][key]
    for nm in region_order
    for key in ("monthly_seas", "monthly_90_thresh", "monthly_80_thresh")
])
ymin = float(np.nanmin(all_vals)) - 0.2
ymax = float(np.nanmax(all_vals)) + 0.2

for ax, nm in zip(axes, region_order):
    ms = results[nm]["monthly_seas"]
    mt90 = results[nm]["monthly_90_thresh"]
    mt80 = results[nm]["monthly_80_thresh"]

    ax.plot(x, ms, marker="o", label="Climatology (Monthly)", color="C0", lw=1.6)
    ax.plot(x, mt90, marker="s", label="90th perc Threshold", ls="--", color="#ff7f0e", lw=1.6)
    ax.plot(x, mt80, marker="^", label="80th perc Threshold", ls=":", color="C2", lw=1.6)

    ax.set_ylabel("Temp (°C)")
    ax.set_title(f"{nm}: Monthly Climatology, 90th & 80th Thresholds ({CLIM_YEARS[0]}–{CLIM_YEARS[1]})")
    ax.legend(loc="upper right")
    ax.set_ylim(ymin, ymax)

# Only the bottom panel carries the month tick labels (X-axis is shared).
axes[-1].set_xticks(x)
axes[-1].set_xticklabels(month_labels)
axes[-1].set_xlabel("Month")

plt.show()



# The following script calculates and plots the Seasonal Climatology, 90th and 80th percentile Threshold of Arabian Sea, Bay Of Bengal, North Indian Ocean

In [None]:
# Seasonal Climatology, 90th and 80th percentile Threshold of Arabian Sea, Bay Of Bengal, North Indian Ocean

from glob import glob
from pathlib import Path
import numpy as np, pandas as pd, xarray as xr
import matplotlib.pyplot as plt

# Baseline period (first year, last year) handed to the climatology computation.
CLIM_YEARS: Tuple[int, int] = (1982, 2024)

# Season label -> calendar months belonging to it.
SEASONS = {"DJF": [12, 1, 2], "MAM": [3, 4, 5], "JJA": [6, 7, 8], "SON": [9, 10, 11]}

# Order in which the seasons are aggregated and plotted.
season_order = ["DJF", "MAM", "JJA", "SON"]

# Output folder (created on first run).
OUTDIR = Path("outputs")
OUTDIR.mkdir(parents=True, exist_ok=True)

# Leap-aware DOY -> month mapping (2000 is a leap year, hence 366 entries).
ref = pd.date_range(start="2000-01-01", end="2000-12-31", freq="D")
doy_month = ref.month.values      # month of each reference day
doy_index = np.arange(1, 367)     # day-of-year labels 1..366

# Season label -> boolean mask over the 366 reference days.
season_to_doymask0 = dict(
    zip(SEASONS, (np.isin(doy_month, season_months) for season_months in SEASONS.values()))
)

def open_ds(globpat, chunks={"time": 120}, engine="netcdf4"):
    """Open every file matching *globpat* as a single combined dataset.

    Args:
        globpat: Glob pattern of the input files; matches are opened in
            sorted order.
        chunks: Chunk specification forwarded to ``xr.open_mfdataset``.
        engine: Backend engine forwarded to ``xr.open_mfdataset``.

    Returns:
        The dataset combined by coordinates.

    Raises:
        FileNotFoundError: If the pattern matches no file.
    """
    paths = sorted(glob(globpat))
    # Fail with the offending pattern in the message, as the other openers in
    # this file do, instead of passing an empty list on to xarray.
    if not paths:
        raise FileNotFoundError(f"No files match: {globpat}")
    return xr.open_mfdataset(paths, combine="by_coords",
                             chunks=chunks, engine=engine)

def subset_box(ds: xr.Dataset, box: Dict[str, float]) -> xr.Dataset:
    """Cut *ds* down to the latitude/longitude rectangle described by *box*.

    Args:
        ds: Dataset whose coordinates are named ``lat``/``lon`` or
            ``latitude``/``longitude``.
        box: Mapping with keys ``lat_min``, ``lat_max``, ``lon_min``, ``lon_max``.

    Returns:
        The label-based selection of *ds* inside the box.
    """
    # NOTE(review): slice(min, max) presumably expects ascending coordinates — confirm for your files.
    lat_name = "latitude" if "lat" not in ds.coords else "lat"
    lon_name = "longitude" if "lon" not in ds.coords else "lon"
    lat_range = slice(box["lat_min"], box["lat_max"])
    lon_range = slice(box["lon_min"], box["lon_max"])
    return ds.sel({lat_name: lat_range, lon_name: lon_range})

def area_weighted_boxmean(da: xr.DataArray) -> xr.DataArray:
    """Average *da* over latitude and longitude using cos(latitude) weights.

    Args:
        da: Array with ``lat``/``lon`` or ``latitude``/``longitude`` coordinates
            (latitude is converted from degrees to radians for the weights).

    Returns:
        *da* reduced over both horizontal dimensions.
    """
    lat_name = "latitude" if "lat" not in da.coords else "lat"
    lon_name = "longitude" if "lon" not in da.coords else "lon"
    cos_lat = np.cos(np.deg2rad(da[lat_name]))
    return da.weighted(cos_lat).mean(dim=[lat_name, lon_name])

def boxmean_doy_curves_seas(da_box: xr.DataArray, baseline=(1982, 2024)):
    """Build 366-day climatology and 90th/80th percentile threshold curves.

    ``mhw.detect`` is run twice on the box-mean series (``pctile=90`` and
    ``pctile=80``); the per-time-step ``clim["seas"]`` / ``clim["thresh"]``
    arrays it returns are then averaged per calendar day.

    Args:
        da_box: 1-D time series with a ``time`` coordinate; it is computed
            (loaded into memory) here.
        baseline: Requested (first_year, last_year) baseline; clipped to the
            years actually present in ``da_box``.

    Returns:
        ``(seas, p90, p80)`` arrays aligned with the module-level ``doy_index``
        (leap-year day-of-year labels). Days that never occur in the series
        are NaN.
    """
    da_box = da_box.compute()

    t = pd.to_datetime(da_box.time.values)
    temp = np.asarray(da_box.values, dtype=float)
    assert len(t) == len(temp), "time/temperature length mismatch"

    # Clip baseline to available data
    y0 = max(baseline[0], int(t.year.min()))
    y1 = min(baseline[1], int(t.year.max()))

    ords = np.array([d.toordinal() for d in t], dtype=int)

    # clim['seas'] and clim['thresh'] for the 90th percentile
    _, clim90 = mhw.detect(
        ords, temp, climatologyPeriod=[y0, y1],
        pctile=90, minDuration=5, joinAcrossGaps=True
    )
    seas_full = np.asarray(clim90["seas"], dtype=float)
    thresh90_full = np.asarray(clim90["thresh"], dtype=float)
    assert seas_full.shape[0] == len(t) and thresh90_full.shape[0] == len(t)

    # only clim['thresh'] is needed from the 80th percentile run
    _, clim80 = mhw.detect(
        ords, temp, climatologyPeriod=[y0, y1],
        pctile=80, minDuration=5, joinAcrossGaps=True
    )
    thresh80_full = np.asarray(clim80["thresh"], dtype=float)
    assert thresh80_full.shape[0] == len(t)

    # Leap-aligned day-of-year. Plain ``dayofyear`` numbers 1 March as 60 in a
    # non-leap year but 61 in a leap year, so grouping on it would blend
    # neighbouring calendar dates from 1 March onwards and misalign the curves
    # with the leap-year season masks. Shift the post-February days of
    # non-leap years by one so every date gets its leap-year day number.
    after_feb_non_leap = ~np.asarray(t.is_leap_year) & (np.asarray(t.month) > 2)
    leap_doy = np.asarray(t.dayofyear) + after_feb_non_leap

    # Build DataFrame aligned on SAME index, then DOY groupby
    df = pd.DataFrame(
        {"seas": seas_full, "p90": thresh90_full, "p80": thresh80_full},
        index=t
    )
    g = df.groupby(leap_doy).mean(numeric_only=True).reindex(doy_index)

    return g["seas"].to_numpy(), g["p90"].to_numpy(), g["p80"].to_numpy()

# Open the full daily SST record once; every region is cut from it.
ds0 = open_ds(FILES_GLOB, engine="netcdf4")

# Region name -> seasonal climatology / threshold values (in season_order).
results = {}
for name, box in REGIONS.items():
    ds_r = subset_box(ds0[[VAR]], box)
    da_box = area_weighted_boxmean(ds_r[VAR]).rename("sst_boxmean")

    seas_doy, p90_doy, p80_doy = boxmean_doy_curves_seas(da_box, CLIM_YEARS)

    # Seasonal aggregation from DOY: NaN-skipping mean over each season's days.
    masks = [season_to_doymask0[label] for label in season_order]
    results[name] = {
        "seas": np.array([np.nanmean(seas_doy[mask]) for mask in masks]),
        "p90": np.array([np.nanmean(p90_doy[mask]) for mask in masks]),
        "p80": np.array([np.nanmean(p80_doy[mask]) for mask in masks]),
    }

# Plot the seasonal values: one stacked panel per region with a shared X-axis.
x = np.arange(4)  # 4 seasons

# Derive the panel count from the computed regions so that editing REGIONS
# neither drops a region nor leaves an empty panel. With three regions the
# figure size is the same (11, 8) as before.
n_panels = len(results)
fig, axes = plt.subplots(
    nrows=n_panels, ncols=1, figsize=(11, 8 * n_panels / 3),
    sharex=True, constrained_layout=True, squeeze=False,
)
axes = axes[:, 0]  # squeeze=False always yields a 2-D array; keep a 1-D column

# common y-limits across regions and curves
all_vals = np.concatenate([np.r_[v["seas"], v["p90"], v["p80"]] for v in results.values()])
ymin = float(np.nanmin(all_vals)) - 0.2
ymax = float(np.nanmax(all_vals)) + 0.2

for ax, nm in zip(axes, results.keys()):
    ms = results[nm]["seas"]
    mt90 = results[nm]["p90"]
    mt80 = results[nm]["p80"]

    ax.plot(x, ms, marker="o", lw=1.6, color="C0", label="Climatology (Seasonal)")
    ax.plot(x, mt90, marker="s", lw=1.6, ls="--", color="#ff7f0e", label="90th percentile threshold")
    ax.plot(x, mt80, marker="^", lw=1.6, ls=":", color="C2", label="80th percentile threshold")
    ax.set_ylabel("Temp (°C)")
    ax.set_title(f"{nm}: Seasonal Climatology, 90th & 80th thresholds ({CLIM_YEARS[0]}–{CLIM_YEARS[1]})")
    ax.legend(loc="upper right")
    ax.set_ylim(ymin, ymax)

# Only the bottom panel carries the season tick labels (X-axis is shared).
axes[-1].set_xticks(x, season_order)
axes[-1].set_xlabel("Season")
plt.show()
