In [None]:
# Primary requirements: python3, numpy, pandas, xarray, dask, bottleneck, netCDF4, cftime, matplotlib, marineHeatWaves
# run this script to check if all dependencies are installed and working
# also fixes np.NaN to np.nan in marineHeatWaves module

import sys, platform
import numpy as np
import pandas as pd
import xarray as xr, dask, bottleneck, netCDF4, cftime
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Union
import matplotlib.pyplot as plt
import os
import matplotlib.dates as mdates
import marineHeatWaves as mhw, re, pathlib
from importlib.metadata import version, PackageNotFoundError

# Provide the np.NaN alias (removed in NumPy 2.0) for code that still uses it.
# This only affects the running Python session.
if not hasattr(np, "NaN"):
    np.NaN = np.nan

# Check the package is working
try:
    mhw_ver = version('marineHeatWaves')
except PackageNotFoundError:
    mhw_ver = getattr(mhw, '__version__', 'unknown')
print('OK:', pd.__version__, xr.__version__, dask.__version__, 'mhw:', mhw_ver)
print('mhw module file:', getattr(mhw, '__file__', 'unknown'))

# Permanently replace np.NaN with np.nan in the installed marineHeatWaves source file.
# - The file is rewritten only when it still contains np.NaN (no needless writes).
# - A read-only install produces a warning instead of aborting the cell.
# - The already-imported module is not reloaded: the patched file is used after a kernel restart.
p = pathlib.Path(mhw.__file__)
src = p.read_text(encoding="utf-8")
patched = re.sub(r'\bnp\.NaN\b', 'np.nan', src)
if patched != src:
    try:
        p.write_text(patched, encoding="utf-8")
        print('patched np.NaN -> np.nan in', p)
    except OSError as err:
        print(f'[WARN] could not patch {p}: {err}')

print(sys.executable)
print(platform.python_version())


In [None]:
# Check that the marineHeatWaves package is working on a short synthetic series.

t = pd.date_range("2000-01-01", periods=40, freq="D")
ords = np.array([d.toordinal() for d in t], dtype=int)
temp = 25 + np.sin(np.linspace(0, 6.28, 40))
temp[15:25] += 2

# Hobday et al. (2016) style call: runs of >= 5 days above the 90th percentile
res, clim = mhw.detect(ords, temp, climatologyPeriod=[2000, 2000], pctile=90, minDuration=5, joinAcrossGaps=True)
print("events:", len(res.get("time_start", [])))

# Verify that the core libraries are importable and report their versions.
from importlib import import_module

want = ["numpy", "pandas", "xarray", "dask", "bottleneck", "netCDF4", "cftime", "marineHeatWaves"]
vers = {}
problems = []

for m in want:
    try:
        mod = import_module(m)
    except Exception as e:  # report any import failure, keep checking the others
        problems.append((m, str(e)))
        vers[m] = "IMPORT FAILED"
        continue
    try:
        vers[m] = version(m)
    except PackageNotFoundError:
        # importable but without distribution metadata: fall back to __version__
        vers[m] = getattr(mod, "__version__", "unknown")

print("Package versions detailes:")
for k in want:
    print(f"{k:16s} : {vers[k]}")
if problems:
    print("\n[ENV WARNING] Some packages failed to import:")
    for m, msg in problems:
        print(f" - {m}: {msg}")

# marineHeatWaves.detect accepts ordinal ints + a 1-D temperature array.
# The climatology period must lie inside the data: detect() locates the
# climatology years within the input series, so a period the series does not
# cover (e.g. 1991-2020 for data from 2000 only) raises instead of testing the package.
try:
    # synthetic 40-day series with a warm spell
    days = pd.date_range("2000-01-01", periods=40, freq="D")
    ords = np.array([d.toordinal() for d in days], dtype=int)
    temp = 25 + np.sin(np.linspace(0, 6.28, 40))
    temp[15:25] += 2.0  # warm event
    clim_years = [int(days.year.min()), int(days.year.max())]
    res, clim = mhw.detect(ords, temp, climatologyPeriod=clim_years, pctile=90, minDuration=5, joinAcrossGaps=True)
    print("\nmarineHeatWaves.detect is working correctly.")
    print(f"Detected events: {len(res.get('time_start', []))}")
except Exception as e:
    print("\n[ENV ERROR] marineHeatWaves.detect detection failed:", e)


In [None]:
# Check the input files and summarize the data

from glob import glob
import numpy as np, pandas as pd

# Input: daily-mean SST netCDF files, one per year (glob pattern)
FILES_GLOB = "/home/Desktop/Noah_data_1982-2024_SST_daily_mean/sst.day.mean.*.nc"

# Bounding box of the whole study area (degrees)
ROI = {"lat_min": 0.0, "lat_max": 30.0, "lon_min": 40.0, "lon_max": 110.0}

# North Indian Ocean sub-regions (degrees).
# NOTE(review): the Arabian Sea / Bay of Bengal split here is 78°E, while ROI_DICT in the
# MHW cell uses 80°E — confirm which boundary is intended.
REGIONS = {
    "Arabian Sea": {
        "lon_min": 40.0, "lon_max": 78.0, "lat_min": 0.0, "lat_max": 30.0,
    },
    "Bay Of Bengal": {
        "lon_min": 78.0, "lon_max": 110.0, "lat_min": 0.0, "lat_max": 30.0,
    },
    "North Indian Ocean": {
        "lon_min": 40.0, "lon_max": 110.0, "lat_min": 0.0, "lat_max": 30.0,
    },
}

# IO / plotting options
use_dask = True  # set False to fully load into memory
if use_dask:
    chunks = {"time": 90}
else:
    chunks = None
sample_map_date: Optional[str] = None
boxmean_var_name = "sst"  # variable to subset / summarize
climatologyPeriod = [1982, 2024]

# loading the data
def open_sst(files, chunks=None, engine: str = "netcdf4") -> xr.Dataset:
    """
    Open SST netCDF input as one xarray Dataset.

    ``files`` may be:
      * a glob pattern (str or path containing any of ``*?[]``): all matches, sorted,
        are combined with ``xr.open_mfdataset``;
      * a single file path (str or ``os.PathLike``): opened with ``xr.open_dataset``;
      * any other iterable of paths: combined with ``xr.open_mfdataset``.

    ``chunks`` and ``engine`` are forwarded to xarray.

    Raises:
        FileNotFoundError: the glob matches nothing, a given path does not exist,
            or an empty iterable is passed.
    """
    if isinstance(files, (str, os.PathLike)):
        # Path objects used to fall into the iterable branch below and fail with TypeError
        spec = os.fspath(files)
        if any(ch in spec for ch in "*?[]"):
            paths = sorted(glob(spec))
            if not paths:
                raise FileNotFoundError(f"No files match glob: {files}")
            return xr.open_mfdataset(paths, combine="by_coords", parallel=True, chunks=chunks, engine=engine)
        if not Path(spec).exists():
            raise FileNotFoundError(f"File not found: {files}")
        return xr.open_dataset(spec, chunks=chunks, engine=engine)

    paths = [str(Path(f)) for f in files]
    if not paths:
        raise FileNotFoundError("No input files given")
    for f in paths:
        if not Path(f).exists():
            raise FileNotFoundError(f"File not found: {f}")
    return xr.open_mfdataset(paths, combine="by_coords", parallel=True, chunks=chunks, engine=engine)

# Subset to region of interest
def subset_roi(ds: xr.Dataset, roi: Dict[str, float], var: str = "sst") -> xr.Dataset:
    """
    Cut ``ds`` to the lat/lon box in ``roi`` and keep only variable ``var``.

    ``roi`` needs the keys lat_min, lat_max, lon_min, lon_max. Coordinates may be
    named lat/lon or latitude/longitude. If a coordinate index is stored in
    descending order, the slice bounds are swapped. Otherwise ``.sel`` with an
    ascending slice would silently return an empty selection.

    Raises:
        KeyError: ``var`` is not present in the subset.
    """
    lat_name = "lat" if "lat" in ds.coords else "latitude"
    lon_name = "lon" if "lon" in ds.coords else "longitude"

    def _bounds(name: str, lo: float, hi: float) -> slice:
        # Slice bounds must follow the index order of the coordinate
        index = ds.indexes.get(name)
        if index is not None and len(index) > 1 and index.is_monotonic_decreasing:
            return slice(hi, lo)
        return slice(lo, hi)

    ds2 = ds.sel({
        lat_name: _bounds(lat_name, roi["lat_min"], roi["lat_max"]),
        lon_name: _bounds(lon_name, roi["lon_min"], roi["lon_max"]),
    })
    # Keep only the variable of interest (and its coords)
    if var not in ds2:
        raise KeyError(f"Variable '{var}' not found. Available: {list(ds2.data_vars)}")
    return ds2[[var]]

# Print metadata
def print_metadata(ds: xr.Dataset, var: str = "sst") -> None:
    """
    Print a readable summary of ``ds``.

    Sections, in order:
      * a coordinate preview (first and last three values of coordinates longer than 6);
      * time coverage, when a ``time`` coordinate exists;
      * global attributes;
      * attributes of ``var``, when present;
      * dimension sizes.
    """
    # Coordinates preview (header first, then the values)
    print("Coordinates (first few):")
    for c in ds.coords:
        arr = ds.coords[c]
        try:
            vals = arr.values
            preview = f"{vals[:3]} ... {vals[-3:]}" if vals.size > 6 else vals
        except Exception:
            # best effort: fall back to the DataArray repr
            preview = arr
        print(f"{c}: {preview}")

    # Time coverage
    if "time" in ds.coords:
        t = pd.to_datetime(ds["time"].values)
        print("\nTime Coverage:")
        print(f"Start: {pd.Timestamp(t[0]).date()}, End: {pd.Timestamp(t[-1]).date()}, Length: {t.size} steps")

    # Global attrs
    print("\nGlobal attributes:")
    for k, v in ds.attrs.items():
        print(f"{k}: {v}")

    # Variable attrs
    if var in ds:
        print(f"\nVariable '{var}' attributes:")
        for k, v in ds[var].attrs.items():
            print(f"{k}: {v}")

    # Dimension sizes: printed once, after the attributes (previously the header was
    # repeated inside the attribute loop). ds.sizes replaces the deprecated ds.dims mapping.
    print("\nDimensions:")
    for k, v in ds.sizes.items():
        print(f"{k}: {v}")
   
def summarize_events_table(res: dict) -> pd.DataFrame:
    """
    Turn the ``res`` dict returned by ``marineHeatWaves.detect`` into a tidy per-event table.

    ``res["time_start"]`` and ``res["time_end"]`` hold day numbers as produced by
    Python's ``datetime.toordinal``; they are converted to timestamps. Duration and
    the intensity metrics are copied through unchanged.
    """
    def _ordinals_to_datetimes(values):
        stamps = [pd.Timestamp.fromordinal(int(day)) for day in np.asarray(values)]
        return pd.to_datetime(stamps)

    table = {
        "start_date": _ordinals_to_datetimes(res["time_start"]),
        "end_date": _ordinals_to_datetimes(res["time_end"]),
        "duration_days": res["duration"],
        "intensity_max_degC": res["intensity_max"],
        "intensity_mean_degC": res["intensity_mean"],
        "cumulative_intensity_degC": res["intensity_cumulative"],
    }
    return pd.DataFrame(table)

# Open files. Time chunks only apply when dask is used (see use_dask in the settings above).
ds = open_sst(FILES_GLOB, chunks={"time": 120} if use_dask else None, engine="netcdf4")

# Subset ROI & 'sst'
ds_roi = subset_roi(ds, ROI, var=boxmean_var_name)
if not use_dask:
    # use_dask=False means "fully load into memory": load only the ROI subset, not the global files
    ds_roi = ds_roi.load()
da = ds_roi[boxmean_var_name]

# Inspect metadata
print_metadata(ds_roi, var=boxmean_var_name)


# Common script to calculate marine heatwaves (MHWs) for the Arabian Sea, Bay of Bengal and North Indian Ocean

In [None]:
# common script to calculate the marine heat wave 
# Imports the libraries and dependencies
from glob import glob
from pathlib import Path
import numpy as np, pandas as pd, xarray as xr
import matplotlib.pyplot as plt
from dask.diagnostics import ProgressBar

# Oliver’s package (Hobday method implementation)
import marineHeatWaves as mhw

# Input: daily-mean SST netCDF files, one per year (glob pattern)
FILES_GLOB = "/home/Noah_data_1982-2024_SST_daily_mean/sst.day.mean.*.nc"
# Name of the SST variable inside the files
SST_VAR = "sst"

# Baseline years for the climatology/threshold (clipped to the data range in
# build_grid_baseline); also shown in the plot titles
BASELINE = (1982, 2024)

# Event definition based on Hobday et al. (2016)
MIN_DUR = 5  # minimum event duration (days)
MAX_GAP = 2  # events separated by a gap of at most this many days are joined

# Dask chunk sizes: time steps per chunk, and lat/lon points per chunk
CHTIME, CHXY = 160, 40

# Study regions: lon/lat box (degrees) plus a slug used as the output sub-folder name
ROI_DICT = {
    "Arabian Sea": {
        "lon_min": 40.0, "lon_max": 80.0,
        "lat_min": 0.0, "lat_max": 30.0,
        "slug": "arabian_sea",
    },
    "Bay Of Bengal": {
        "lon_min": 80.0, "lon_max": 110.0,
        "lat_min": 0.0, "lat_max": 30.0,
        "slug": "bay_of_bengal",
    },
    "North Indian Ocean": {
        "lon_min": 40.0, "lon_max": 110.0,
        "lat_min": 0.0, "lat_max": 30.0,
        "slug": "north_indian_ocean",
    },
}

# Root folder for all outputs (created if missing)
OUTROOT = Path("outputs_mhw")
OUTROOT.mkdir(parents=True, exist_ok=True)

# open SST data using xarray
def open_sst(files_glob: str, roi: dict) -> tuple[xr.Dataset, str, str]:
    """
    Open all files matching ``files_glob`` as one dataset and cut it to ``roi``.

    ``roi`` needs the keys lat_min, lat_max, lon_min, lon_max (extra keys are ignored).
    If a coordinate is stored in descending order, the slice bounds are swapped.
    Otherwise ``.sel`` with an ascending slice would select nothing.

    Returns:
        (ds, lat_name, lon_name), where the names are "lat"/"lon" or "latitude"/"longitude".

    Raises:
        FileNotFoundError: nothing matches ``files_glob``.
    """
    paths = sorted(glob(files_glob))
    if not paths:
        raise FileNotFoundError(f"No files match: {files_glob}")
    ds = xr.open_mfdataset(paths, combine="by_coords", chunks={"time": CHTIME}, engine="netcdf4")
    latn = "lat" if "lat" in ds.coords else "latitude"
    lonn = "lon" if "lon" in ds.coords else "longitude"

    def _bounds(name: str, lo: float, hi: float) -> slice:
        index = ds.indexes.get(name)
        if index is not None and len(index) > 1 and index.is_monotonic_decreasing:
            return slice(hi, lo)
        return slice(lo, hi)

    lat_slice = _bounds(latn, roi["lat_min"], roi["lat_max"])
    lon_slice = _bounds(lonn, roi["lon_min"], roi["lon_max"])
    ds = ds.sel({latn: lat_slice, lonn: lon_slice})
    return ds, latn, lonn
# calculate the area of the grids
def area_weights_1d(ds: xr.Dataset, latn: str) -> xr.DataArray:
    """Return cos(latitude) weights along ``latn``; they broadcast across longitude."""
    lat_radians = np.deg2rad(ds[latn])
    return np.cos(lat_radians)

# Hobday et al. (2016) event definition applied to one grid-cell time series
def _detect_mask_1d(sst_1d: np.ndarray, thresh_1d: np.ndarray, min_dur: int = MIN_DUR, max_gap: int = MAX_GAP) -> np.ndarray:
    """
    Boolean MHW mask for a single time series.

    Steps, in the order of the Hobday et al. (2016) definition (as in
    marineHeatWaves.detect):
      1. exceedance days: SST and threshold both finite and ``sst >= thresh``;
      2. keep only runs of consecutive exceedance days lasting >= ``min_dur`` days;
      3. join kept events separated by a gap of <= ``max_gap`` days; the gap days
         become event days.

    Because gaps are bridged only *between* two qualifying events, short exceedance
    runs are never chained into an event. Non-exceedance days at the end of the
    series are never filled.

    Args:
        sst_1d: 1-D SST series.
        thresh_1d: 1-D threshold series, same length as ``sst_1d``.
        min_dur: minimum event duration (days).
        max_gap: longest gap (days) that is joined across.

    Returns:
        Boolean array of the same length; True on event days.
    """
    sst_1d = np.asarray(sst_1d, dtype=float)
    thresh_1d = np.asarray(thresh_1d, dtype=float)
    ok = np.isfinite(sst_1d) & np.isfinite(thresh_1d)
    exc = ok & (sst_1d >= thresh_1d)
    if not exc.any():
        return exc

    # Start (inclusive) and end (exclusive) index of every run of exceedance days
    edges = np.diff(np.concatenate(([0], exc.astype(np.int8), [0])))
    run_start = np.flatnonzero(edges == 1)
    run_end = np.flatnonzero(edges == -1)

    # Step 2: drop runs shorter than min_dur
    long_enough = (run_end - run_start) >= min_dur
    run_start, run_end = run_start[long_enough], run_end[long_enough]

    mask = np.zeros(exc.shape, dtype=bool)
    if run_start.size == 0:
        return mask

    # Step 3: merge consecutive events whose gap (days strictly between them) <= max_gap
    cur_start, cur_end = run_start[0], run_end[0]
    for s, e in zip(run_start[1:], run_end[1:]):
        if s - cur_end <= max_gap:
            cur_end = e
        else:
            mask[cur_start:cur_end] = True
            cur_start, cur_end = s, e
    mask[cur_start:cur_end] = True
    return mask

def detect_mask_time(sst: xr.DataArray, thresh: xr.DataArray) -> xr.DataArray:
    """
    Apply ``_detect_mask_1d`` along 'time' for every grid cell and return a boolean mask.

    Works lazily with dask; 'time' must be a single chunk in both inputs.
    """
    time_core = [["time"], ["time"]]
    return xr.apply_ufunc(
        _detect_mask_1d,
        sst,
        thresh,
        input_core_dims=time_core,
        output_core_dims=[["time"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[bool],
    )

# per grid climatology & threshold 
def _clim_thresh_time_1d(ords_1d: np.ndarray, temp_1d: np.ndarray, y0: int, y1: int, pct: int,
                         min_dur: int, max_gap: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Run ``mhw.detect`` on one grid-cell series and return its time-aligned
    climatology (``seas``) and threshold (``thresh``) as float arrays.

    Series with fewer than 30 finite values (e.g. land) get two all-NaN arrays
    of the input length instead of an unstable fit.
    """
    length = temp_1d.size
    if np.count_nonzero(np.isfinite(temp_1d)) < 30:
        return np.full(length, np.nan, float), np.full(length, np.nan, float)

    detect_kwargs = dict(
        climatologyPeriod=[int(y0), int(y1)],
        pctile=int(pct),
        minDuration=int(min_dur),
        joinAcrossGaps=True,
        maxGap=int(max_gap),
    )
    _events, clim = mhw.detect(ords_1d.astype(int), temp_1d.astype(float), **detect_kwargs)
    return np.asarray(clim["seas"], float), np.asarray(clim["thresh"], float)
                             
# compute time-aligned climatology and threshold for every grid
def build_grid_baseline(ds: xr.Dataset, latn: str, lonn: str, pctile: int = 90) -> tuple[xr.DataArray, xr.DataArray]:
    """
    Lazily compute the time-aligned climatology ("seas_t") and percentile threshold
    ("thresh_t") for every grid cell of ``ds[SST_VAR]``.

    The baseline years are BASELINE clipped to the years present in ``ds``. Each
    grid cell is processed by ``_clim_thresh_time_1d`` with a single chunk along time.
    """
    times = pd.to_datetime(ds["time"].values)
    first_year = max(BASELINE[0], times.year.min())
    last_year = min(BASELINE[1], times.year.max())

    ordinals = np.array([stamp.toordinal() for stamp in times], dtype=int)
    # a single time chunk is required by the per-series function
    ords_da = xr.DataArray(ordinals, coords={"time": ds["time"]}, dims=["time"]).chunk({"time": -1})
    sst = ds[SST_VAR].chunk({"time": -1, latn: CHXY, lonn: CHXY})

    per_cell_kwargs = dict(y0=int(first_year), y1=int(last_year), pct=int(pctile), min_dur=MIN_DUR, max_gap=MAX_GAP)
    seas_t, thresh_t = xr.apply_ufunc(
        _clim_thresh_time_1d,
        ords_da,
        sst,
        input_core_dims=[["time"], ["time"]],
        output_core_dims=[["time"], ["time"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float, float],
        kwargs=per_cell_kwargs,
    )
    return seas_t.rename("seas_t"), thresh_t.rename("thresh_t")



In [None]:
# Region processed by this cell; must be a key of ROI_DICT (its box and output slug are used below).
REG = "Arabian Sea"
ROI = ROI_DICT[REG]

# Use None or 0 to show every year.
YEAR_TICK_STEP  = 5
# (Optional) force a start/end for ticks; leave as None to auto-fit data range
YEAR_TICK_START = None
YEAR_TICK_END   = None
# (Optional) label rotation (degrees)
YEAR_TICK_ROT   = 0

def apply_year_ticks(ax, years, step=YEAR_TICK_STEP, start=YEAR_TICK_START, end=YEAR_TICK_END, rotate=YEAR_TICK_ROT):
    """
    Put integer year ticks and labels on the x axis of ``ax``.

    The tick range is [start, end], defaulting to the min/max of ``years``.
    With a truthy ``step``, ticks are every ``step`` years from ``start``.
    With ``step`` None/0, every distinct year in ``years`` inside [start, end] is
    used (start/end were previously ignored in this mode). Empty ``years`` leaves
    the axis untouched.
    """
    years = np.asarray(years, dtype=int)
    if years.size == 0:
        return
    a = int(years.min()) if start is None else int(start)
    b = int(years.max()) if end is None else int(end)
    if not step:
        ticks = np.unique(years)
        ticks = ticks[(ticks >= a) & (ticks <= b)]
    else:
        ticks = np.arange(a, b + 1, int(step), dtype=int)
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(y) for y in ticks], rotation=rotate)

# Open SST (ROI)
ds, latn, lonn = open_sst(FILES_GLOB, ROI)

# Land mask: a grid cell is ocean if it has at least one valid SST value.
# Evaluated once here; left lazy it would be re-derived from the SST files by every later computation.
ocean = ds[SST_VAR].notnull().any("time").compute()
sst   = ds[SST_VAR].where(ocean).chunk({"time": CHTIME, latn: CHXY, lonn: CHXY})

# Per-grid climatology & 90th-percentile threshold (mhw.detect per grid cell), time-aligned.
# Still lazy, and by far the most expensive part of the graph.
seas_t, thresh_t = build_grid_baseline(ds, latn, lonn, pctile=90)

# Event mask per grid cell (detect_mask_time needs a single chunk along 'time')
evt_mask_bool = detect_mask_time(sst.chunk({"time": -1}), thresh_t.chunk({"time": -1})).rename("mhw_mask")  # bool
# Float version: 1 on event days, 0 otherwise, NaN on land (NaN survives the Zarr write)
evt_mask = evt_mask_bool.where(ocean).astype("float32")

# Daily metrics for later analysis
intensity = (sst - seas_t).rename("intensity")                      # all days
excess    = (sst - thresh_t).where(evt_mask == 1).rename("excess")  # event days only

# SAVE the daily fields FIRST.
# Each .compute()/.to_zarr() re-executes the lazy graph from the raw files. Deriving the
# yearly summaries and region means from the lazy mask would therefore re-run the per-grid
# baseline several times. Writing once and re-reading the stores runs it exactly once.
out_dir = OUTROOT / ROI_DICT[REG]["slug"]
out_dir.mkdir(parents=True, exist_ok=True)

daily_ds = xr.Dataset(
    {"mhw_mask": evt_mask, "intensity": intensity, "excess": excess, "seas_t": seas_t, "thresh_t": thresh_t},
    coords={"time": ds["time"], latn: ds[latn], lonn: ds[lonn]},
).chunk({"time": CHTIME, latn: CHXY, lonn: CHXY})

# Keep only core coords to avoid Zarr alignment issues
keep = {"time", latn, lonn}
dropc = [c for c in daily_ds.coords if c not in keep]
if dropc:
    daily_ds = daily_ds.reset_coords(dropc, drop=True)

print("Writing Zarr outputs… (can take time)")
with ProgressBar():
    daily_ds.to_zarr(out_dir / "mhw_daily.zarr", mode="w", safe_chunks=False, align_chunks=True)

# From here on, read the stored mask instead of the lazy detection graph
evt_mask = xr.open_zarr(out_dir / "mhw_daily.zarr")["mhw_mask"]
evt_mask_bool = evt_mask == 1  # land (NaN) -> False, as in the detection mask

# Yearly per-grid summaries; an event is counted in the year it starts
starts_bool = evt_mask_bool & ~(evt_mask_bool.shift(time=1, fill_value=False))
starts = starts_bool.where(ocean).astype("float32")  # NaN on land
events_per_year_grid = starts.groupby("time.year").sum("time").rename("events_per_year").astype("float32")
days_per_year_grid = evt_mask.groupby("time.year").sum("time").rename("days_per_year").astype("float32")

with ProgressBar():
    events_per_year_grid.to_zarr(out_dir / "mhw_events_per_year_grid.zarr", mode="w", safe_chunks=False, align_chunks=True)
    days_per_year_grid.to_zarr(out_dir / "mhw_days_per_year_grid.zarr", mode="w", safe_chunks=False, align_chunks=True)
events_per_year_grid = xr.open_zarr(out_dir / "mhw_events_per_year_grid.zarr")["events_per_year"]
days_per_year_grid = xr.open_zarr(out_dir / "mhw_days_per_year_grid.zarr")["days_per_year"]

# Area-weighted regional means (per year), ocean cells only
w_lat = area_weights_1d(ds, latn)
freq_region = events_per_year_grid.where(ocean).weighted(w_lat).mean(dim=[latn, lonn]).compute()
days_region = days_per_year_grid.where(ocean).weighted(w_lat).mean(dim=[latn, lonn]).compute()
total_events_region = float(freq_region.sum().values)
total_days_region   = float(days_region.sum().values)
print(f"{REG} totals — events: {total_events_region:.1f}, days: {total_days_region:.1f}")

# Quick plots (region-mean series); groupby("time.year") always yields a 'year' coordinate
years_freq = freq_region["year"].values.astype(int)
years_days = days_region["year"].values.astype(int)

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(years_freq, freq_region.values, color="#5b8def", width=0.8)
ax.set_title(f" MHW events counts in {REG} {BASELINE[0]}–{BASELINE[1]}")
ax.set_xlabel("Year"); ax.set_ylabel("Events/year")
ax.grid(True, ls="--", alpha=0.4)
ax.text(0.5, 0.92, f"Total Events: {total_events_region:.1f}", transform=ax.transAxes, ha="center")
apply_year_ticks(ax, years_freq)
plt.tight_layout(); plt.show()

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(years_days, days_region.values, color="#f6a300", width=0.8)
ax.set_title(f"MHW days in {REG} {BASELINE[0]}–{BASELINE[1]}")
ax.set_xlabel("Year"); ax.set_ylabel("Days/year")
ax.grid(True, ls="--", alpha=0.4)
ax.text(0.5, 0.92, f"Total MHW days: {total_days_region:.1f}", transform=ax.transAxes, ha="center")
apply_year_ticks(ax, years_days)
plt.tight_layout(); plt.show()

# CSV outputs
freq_region.to_dataframe().to_csv(out_dir / "mhw_region_mean_events_per_year.csv")
days_region.to_dataframe().to_csv(out_dir / "mhw_region_mean_days_per_year.csv")
pd.DataFrame({"mhw_total_events_region_mean": [total_events_region], "mhw_total_days_region_mean": [total_days_region]}).to_csv(out_dir / "mhw_region_totals.csv", index=False)

print("Arabian Sea — done.")


In [None]:
# Region processed by this cell; must be a key of ROI_DICT (its box and output slug are used below).
REG = "Bay Of Bengal"
ROI = ROI_DICT[REG]

# Use None or 0 to show every year.
YEAR_TICK_STEP  = 5
# (Optional) force a start/end for ticks
YEAR_TICK_START = None
YEAR_TICK_END   = None
# (Optional) label rotation (degrees)
YEAR_TICK_ROT   = 0

def apply_year_ticks(ax, years, step=YEAR_TICK_STEP, start=YEAR_TICK_START, end=YEAR_TICK_END, rotate=YEAR_TICK_ROT):
    """
    Put integer year ticks and labels on the x axis of ``ax``.

    The tick range is [start, end], defaulting to the min/max of ``years``.
    With a truthy ``step``, ticks are every ``step`` years from ``start``.
    With ``step`` None/0, every distinct year in ``years`` inside [start, end] is
    used (start/end were previously ignored in this mode). Empty ``years`` leaves
    the axis untouched.
    """
    years = np.asarray(years, dtype=int)
    if years.size == 0:
        return
    a = int(years.min()) if start is None else int(start)
    b = int(years.max()) if end is None else int(end)
    if not step:
        ticks = np.unique(years)
        ticks = ticks[(ticks >= a) & (ticks <= b)]
    else:
        ticks = np.arange(a, b + 1, int(step), dtype=int)
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(y) for y in ticks], rotation=rotate)

# Open SST (ROI)
ds, latn, lonn = open_sst(FILES_GLOB, ROI)

# Land mask: a grid cell is ocean if it has at least one valid SST value.
# Evaluated once here; left lazy it would be re-derived from the SST files by every later computation.
ocean = ds[SST_VAR].notnull().any("time").compute()
sst   = ds[SST_VAR].where(ocean).chunk({"time": CHTIME, latn: CHXY, lonn: CHXY})

# Per-grid climatology & 90th-percentile threshold (mhw.detect per grid cell), time-aligned.
# Still lazy, and by far the most expensive part of the graph.
seas_t, thresh_t = build_grid_baseline(ds, latn, lonn, pctile=90)

# Event mask per grid cell (detect_mask_time needs a single chunk along 'time')
evt_mask_bool = detect_mask_time(sst.chunk({"time": -1}), thresh_t.chunk({"time": -1})).rename("mhw_mask")  # bool
# Float version: 1 on event days, 0 otherwise, NaN on land (NaN survives the Zarr write)
evt_mask = evt_mask_bool.where(ocean).astype("float32")

# Daily metrics for later analysis
intensity = (sst - seas_t).rename("intensity")                      # all days
excess    = (sst - thresh_t).where(evt_mask == 1).rename("excess")  # event days only

# SAVE the daily fields FIRST.
# Each .compute()/.to_zarr() re-executes the lazy graph from the raw files. Deriving the
# yearly summaries and region means from the lazy mask would therefore re-run the per-grid
# baseline several times. Writing once and re-reading the stores runs it exactly once.
out_dir = OUTROOT / ROI_DICT[REG]["slug"]
out_dir.mkdir(parents=True, exist_ok=True)

daily_ds = xr.Dataset(
    {"mhw_mask": evt_mask, "intensity": intensity, "excess": excess, "seas_t": seas_t, "thresh_t": thresh_t},
    coords={"time": ds["time"], latn: ds[latn], lonn: ds[lonn]},
).chunk({"time": CHTIME, latn: CHXY, lonn: CHXY})

# Keep only core coords to avoid Zarr alignment issues
keep = {"time", latn, lonn}
dropc = [c for c in daily_ds.coords if c not in keep]
if dropc:
    daily_ds = daily_ds.reset_coords(dropc, drop=True)

print("Writing Zarr outputs… (can take time)")
with ProgressBar():
    daily_ds.to_zarr(out_dir / "mhw_daily.zarr", mode="w", safe_chunks=False, align_chunks=True)

# From here on, read the stored mask instead of the lazy detection graph
evt_mask = xr.open_zarr(out_dir / "mhw_daily.zarr")["mhw_mask"]
evt_mask_bool = evt_mask == 1  # land (NaN) -> False, as in the detection mask

# Yearly per-grid summaries; an event is counted in the year it starts
starts_bool = evt_mask_bool & ~(evt_mask_bool.shift(time=1, fill_value=False))
starts = starts_bool.where(ocean).astype("float32")  # NaN on land
events_per_year_grid = starts.groupby("time.year").sum("time").rename("events_per_year").astype("float32")
days_per_year_grid = evt_mask.groupby("time.year").sum("time").rename("days_per_year").astype("float32")

with ProgressBar():
    events_per_year_grid.to_zarr(out_dir / "mhw_events_per_year_grid.zarr", mode="w", safe_chunks=False, align_chunks=True)
    days_per_year_grid.to_zarr(out_dir / "mhw_days_per_year_grid.zarr", mode="w", safe_chunks=False, align_chunks=True)
events_per_year_grid = xr.open_zarr(out_dir / "mhw_events_per_year_grid.zarr")["events_per_year"]
days_per_year_grid = xr.open_zarr(out_dir / "mhw_days_per_year_grid.zarr")["days_per_year"]

# Area-weighted regional means (per year), ocean cells only
w_lat = area_weights_1d(ds, latn)
freq_region = events_per_year_grid.where(ocean).weighted(w_lat).mean(dim=[latn, lonn]).compute()
days_region = days_per_year_grid.where(ocean).weighted(w_lat).mean(dim=[latn, lonn]).compute()

total_events_region = float(freq_region.sum().values)
total_days_region   = float(days_region.sum().values)
print(f"{REG} totals — events: {total_events_region:.1f}, days: {total_days_region:.1f}")

# Plot the region-mean series; groupby("time.year") always yields a 'year' coordinate
years_freq = freq_region["year"].values.astype(int)
years_days = days_region["year"].values.astype(int)

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(years_freq, freq_region.values, color="#5b8def", width=0.8)
ax.set_title(f" MHW events counts in {REG} {BASELINE[0]}–{BASELINE[1]}")
ax.set_xlabel("Year"); ax.set_ylabel("Events/year")
ax.grid(True, ls="--", alpha=0.4)
ax.text(0.5, 0.92, f"Total Events: {total_events_region:.1f}", transform=ax.transAxes, ha="center")
apply_year_ticks(ax, years_freq)
plt.tight_layout(); plt.show()

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(years_days, days_region.values, color="#f6a300", width=0.8)
ax.set_title(f"MHW days in {REG} {BASELINE[0]}–{BASELINE[1]}")
ax.set_xlabel("Year"); ax.set_ylabel("Days/year")
ax.grid(True, ls="--", alpha=0.4)
ax.text(0.5, 0.92, f"Total MHW days: {total_days_region:.1f}", transform=ax.transAxes, ha="center")
apply_year_ticks(ax, years_days)
plt.tight_layout(); plt.show()

# CSV outputs
freq_region.to_dataframe().to_csv(out_dir / "mhw_region_mean_events_per_year.csv")
days_region.to_dataframe().to_csv(out_dir / "mhw_region_mean_days_per_year.csv")
pd.DataFrame({"mhw_total_events_region_mean": [total_events_region], "mhw_total_days_region_mean": [total_days_region]}).to_csv(out_dir / "mhw_region_totals.csv", index=False)
print("Bay of Bengal — done.")



In [None]:
# Region processed by this cell; must be a key of ROI_DICT (its box and output slug are used below).
REG = "North Indian Ocean"
ROI = ROI_DICT[REG]
 
# Use None or 0 to show every year.
YEAR_TICK_STEP  = 5
# (Optional) force a start/end for ticks; leave as None to auto-fit data range
YEAR_TICK_START = None
YEAR_TICK_END   = None
# (Optional) label rotation (degrees)
YEAR_TICK_ROT   = 0

def apply_year_ticks(ax, years, step=YEAR_TICK_STEP, start=YEAR_TICK_START, end=YEAR_TICK_END, rotate=YEAR_TICK_ROT):
    """
    Put integer year ticks and labels on the x axis of ``ax``.

    The tick range is [start, end], defaulting to the min/max of ``years``.
    With a truthy ``step``, ticks are every ``step`` years from ``start``.
    With ``step`` None/0, every distinct year in ``years`` inside [start, end] is
    used (start/end were previously ignored in this mode). Empty ``years`` leaves
    the axis untouched.
    """
    years = np.asarray(years, dtype=int)
    if years.size == 0:
        return
    a = int(years.min()) if start is None else int(start)
    b = int(years.max()) if end is None else int(end)
    if not step:
        ticks = np.unique(years)
        ticks = ticks[(ticks >= a) & (ticks <= b)]
    else:
        ticks = np.arange(a, b + 1, int(step), dtype=int)
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(y) for y in ticks], rotation=rotate)

# Open SST (ROI)
ds, latn, lonn = open_sst(FILES_GLOB, ROI)

# Land mask: a grid cell is ocean if it has at least one valid SST value.
# Evaluated once here; left lazy it would be re-derived from the SST files by every later computation.
ocean = ds[SST_VAR].notnull().any("time").compute()
sst   = ds[SST_VAR].where(ocean).chunk({"time": CHTIME, latn: CHXY, lonn: CHXY})

# Per-grid climatology & 90th-percentile threshold (mhw.detect per grid cell), time-aligned.
# Still lazy, and by far the most expensive part of the graph.
seas_t, thresh_t = build_grid_baseline(ds, latn, lonn, pctile=90)

# Event mask per grid cell (detect_mask_time needs a single chunk along 'time')
evt_mask_bool = detect_mask_time(sst.chunk({"time": -1}), thresh_t.chunk({"time": -1})).rename("mhw_mask")  # bool
# Float version: 1 on event days, 0 otherwise, NaN on land (NaN survives the Zarr write)
evt_mask = evt_mask_bool.where(ocean).astype("float32")

# Daily metrics for later analysis
intensity = (sst - seas_t).rename("intensity")                      # all days
excess    = (sst - thresh_t).where(evt_mask == 1).rename("excess")  # event days only

# SAVE the daily fields FIRST.
# Each .compute()/.to_zarr() re-executes the lazy graph from the raw files. Deriving the
# yearly summaries and region means from the lazy mask would therefore re-run the per-grid
# baseline several times. Writing once and re-reading the stores runs it exactly once.
out_dir = OUTROOT / ROI_DICT[REG]["slug"]
out_dir.mkdir(parents=True, exist_ok=True)

daily_ds = xr.Dataset(
    {"mhw_mask": evt_mask, "intensity": intensity, "excess": excess, "seas_t": seas_t, "thresh_t": thresh_t},
    coords={"time": ds["time"], latn: ds[latn], lonn: ds[lonn]},
).chunk({"time": CHTIME, latn: CHXY, lonn: CHXY})

# Keep only core coords to avoid Zarr alignment issues
keep = {"time", latn, lonn}
dropc = [c for c in daily_ds.coords if c not in keep]
if dropc:
    daily_ds = daily_ds.reset_coords(dropc, drop=True)

print("Writing Zarr outputs… (can take time)")
with ProgressBar():
    daily_ds.to_zarr(out_dir / "mhw_daily.zarr", mode="w", safe_chunks=False, align_chunks=True)

# From here on, read the stored mask instead of the lazy detection graph
evt_mask = xr.open_zarr(out_dir / "mhw_daily.zarr")["mhw_mask"]
evt_mask_bool = evt_mask == 1  # land (NaN) -> False, as in the detection mask

# Yearly per-grid summaries; an event is counted in the year it starts
starts_bool = evt_mask_bool & ~(evt_mask_bool.shift(time=1, fill_value=False))
starts = starts_bool.where(ocean).astype("float32")  # NaN on land
events_per_year_grid = starts.groupby("time.year").sum("time").rename("events_per_year").astype("float32")
days_per_year_grid = evt_mask.groupby("time.year").sum("time").rename("days_per_year").astype("float32")

with ProgressBar():
    events_per_year_grid.to_zarr(out_dir / "mhw_events_per_year_grid.zarr", mode="w", safe_chunks=False, align_chunks=True)
    days_per_year_grid.to_zarr(out_dir / "mhw_days_per_year_grid.zarr", mode="w", safe_chunks=False, align_chunks=True)
events_per_year_grid = xr.open_zarr(out_dir / "mhw_events_per_year_grid.zarr")["events_per_year"]
days_per_year_grid = xr.open_zarr(out_dir / "mhw_days_per_year_grid.zarr")["days_per_year"]

# Area-weighted regional means (per year), ocean cells only
w_lat = area_weights_1d(ds, latn)
freq_region = events_per_year_grid.where(ocean).weighted(w_lat).mean(dim=[latn, lonn]).compute()
days_region = days_per_year_grid.where(ocean).weighted(w_lat).mean(dim=[latn, lonn]).compute()

total_events_region = float(freq_region.sum().values)
total_days_region   = float(days_region.sum().values)
print(f"{REG} totals — events: {total_events_region:.1f}, days: {total_days_region:.1f}")

# Quick plots; groupby("time.year") always yields a 'year' coordinate
years_freq = freq_region["year"].values.astype(int)
years_days = days_region["year"].values.astype(int)

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(years_freq, freq_region.values, color="#5b8def", width=0.8)
ax.set_title(f" MHW events counts in {REG} {BASELINE[0]}–{BASELINE[1]}")
ax.set_xlabel("Year"); ax.set_ylabel("Events/year")
ax.grid(True, ls="--", alpha=0.4)
ax.text(0.5, 0.92, f"Total Events: {total_events_region:.1f}", transform=ax.transAxes, ha="center")
apply_year_ticks(ax, years_freq)
plt.tight_layout(); plt.show()

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(years_days, days_region.values, color="#f6a300", width=0.8)
ax.set_title(f"MHW days in {REG} {BASELINE[0]}–{BASELINE[1]}")
ax.set_xlabel("Year"); ax.set_ylabel("Days/year")
ax.grid(True, ls="--", alpha=0.4)
ax.text(0.5, 0.92, f"Total MHW days: {total_days_region:.1f}", transform=ax.transAxes, ha="center")
apply_year_ticks(ax, years_days)
plt.tight_layout(); plt.show()

# CSV outputs
freq_region.to_dataframe().to_csv(out_dir / "mhw_region_mean_events_per_year.csv")
days_region.to_dataframe().to_csv(out_dir / "mhw_region_mean_days_per_year.csv")
pd.DataFrame({"mhw_total_events_region_mean": [total_events_region], "mhw_total_days_region_mean": [total_days_region]}).to_csv(out_dir / "mhw_region_totals.csv", index=False)

print("North Indian Ocean MHW — done.")


# The following script computes and plots a combined bar graph of MHW events for the AS, BoB and NIO regions

In [None]:
# compute and plot the  combined bar graph of AS, BoB and NIO regions
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Region-mean MHW events per year, as written by the region cells
# (mhw_region_mean_events_per_year.csv). Point these paths at your outputs_mhw/<slug>/ folders.
as_event_file  = "/home/Desktop/arabian_sea/mhw_region_mean_events_per_year.csv"
bob_event_file = "/home/Desktop/bay_of_bengal/mhw_region_mean_events_per_year.csv"
nio_event_file = "/home/Desktop/north_indian_ocean/mhw_region_mean_events_per_year.csv"   #provide the proper path here

# Load each CSV (AS, BoB, NIO in that order) and rename its value column to the region code
as_event_df, bob_event_df, nio_event_df = (
    pd.read_csv(path).rename(columns={"events_per_year": code})
    for code, path in (("AS", as_event_file), ("BoB", bob_event_file), ("NIO", nio_event_file))
)

# One row per year present in all three files (inner merge on 'year')
event_df = as_event_df.merge(bob_event_df, on="year").merge(nio_event_df, on="year")
event_years = event_df["year"].values

# Grouped bars: one slot per year, three bars shifted by one bar width each
x = np.arange(len(event_years))
w = 0.25   # bar width

fig, ax = plt.subplots(figsize=(14,6))
bar_specs = [
    (-w, "AS", "#0072B2", "Arabian Sea"),           # blue
    (0.0, "BoB", "#D55E00", "Bay of Bengal"),       # vermillion
    (w, "NIO", "#009E73", "North Indian Ocean"),    # bluish green
]
for offset, column, color, label in bar_specs:
    ax.bar(x + offset, event_df[column], width=w, color=color, label=label)

ax.set_title("MHW Events Comparison of AS, BoB and NIO (1982–2024)")
ax.set_ylabel("Events / year")
ax.set_xlabel("Year")

# Label every 3rd year (ticks sit at bar-slot positions, labelled with the year)
tick_idx = np.arange(0, len(event_years), 3)
ax.set_xticks(tick_idx)
ax.set_xticklabels(event_years[tick_idx])
ax.legend()
ax.grid(True, ls="--", alpha=0.4)
plt.tight_layout()
plt.show()




# The following script computes and plots a combined bar graph of total MHW days for the AS, BoB and NIO regions

In [None]:
# compute and plot the  combined bar graph total MHW days of AS, BoB and NIO regions
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Region-mean MHW days per year, as written by the region cells
# (mhw_region_mean_days_per_year.csv). Point these paths at your outputs_mhw/<slug>/ folders.
as_day_file  = "/home/Desktop/Jupyter files/outputs_mhw/arabian_sea/mhw_region_mean_days_per_year.csv"
bob_day_file = "/home/Desktop/Jupyter files/outputs_mhw/bay_of_bengal/mhw_region_mean_days_per_year.csv"
nio_day_file = "/home/Desktop/Jupyter files/outputs_mhw/north_indian_ocean/mhw_region_mean_days_per_year.csv"   #provide the proper path here

# Load the days CSV files
as_day_df  = pd.read_csv(as_day_file)
bob_day_df = pd.read_csv(bob_day_file)
nio_day_df = pd.read_csv(nio_day_file)

# Rename the value column to the region code
as_day_df  = as_day_df.rename(columns={"days_per_year": "AS"})
bob_day_df = bob_day_df.rename(columns={"days_per_year": "BoB"})
nio_day_df = nio_day_df.rename(columns={"days_per_year": "NIO"})

# One row per year present in all three files (inner merge on 'year')
day_df = as_day_df.merge(bob_day_df, on="year").merge(nio_day_df, on="year")
day_years = day_df["year"].values

# Grouped bars: one slot per year, three bars shifted by one bar width each
x = np.arange(len(day_years))
w = 0.25   # bar width

fig, ax = plt.subplots(figsize=(14,6))
ax.bar(x - w, day_df["AS"],  width=w, color="#440154", label="Arabian Sea")         # dark purple
ax.bar(x,     day_df["BoB"], width=w, color="#21918c", label="Bay of Bengal")       # teal
ax.bar(x + w, day_df["NIO"], width=w, color="#fde725", label="North Indian Ocean")  # yellow

ax.set_title("MHW days Comparison of AS, BoB and NIO (1982–2024)")
ax.set_ylabel(" MHW Days")
ax.set_xlabel("Year")

# Label every 3rd year. Use this cell's day_years: event_years belongs to the events
# cell and may be undefined or cover different years.
tick_idx = np.arange(0, len(day_years), 3)
ax.set_xticks(tick_idx)
ax.set_xticklabels(day_years[tick_idx])
ax.legend()
ax.grid(True, ls="--", alpha=0.4)
plt.tight_layout()
plt.show()




# The following script calculates and plots the mean, maximum and cumulative MHW intensity per year in the Arabian Sea

In [None]:
# calculte and plot the Mean, Max and Cumulative intensity per year in Arabian sea

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

import dask

# Zarr file path (daily output written by the Arabian Sea cell)
zarr_path = "/home/deepak/Desktop/CAS_deepak/Jupyter files/outputs_mhw/arabian_sea/mhw_daily.zarr"
BASELINE = (1982, 2024)
ds = xr.open_zarr(zarr_path)

# Variables expected from the pipeline
mask      = ds["mhw_mask"]   # 1 on event days, 0 otherwise, NaN on land
intensity = ds["intensity"]  # SST - climatology

latn = "lat" if "lat" in ds.coords else "latitude"
lonn = "lon" if "lon" in ds.coords else "longitude"

# Area weights
w = np.cos(np.deg2rad(ds[latn]))

# Ocean cells: the stored mask is NaN on land for every day
ocean = mask.notnull().any("time")

# Keep only MHW days
int_mhw = intensity.where(mask == 1)
per_year = int_mhw.groupby("time.year")


def _ocean_weighted_mean(per_grid):
    """Area-weighted mean over ocean cells only (land set to NaN and skipped)."""
    return per_grid.where(ocean).weighted(w).mean(dim=[latn, lonn])


# Per-grid yearly mean / max / sum of intensity on MHW days, then the regional mean.
# - The sum needs the ocean mask: an all-NaN land series sums to 0 and would
#   otherwise dilute the regional cumulative intensity.
# - All three are evaluated in one dask pass instead of three reads of the store.
mean_intensity_year, max_intensity_year, cum_intensity_year = dask.compute(
    _ocean_weighted_mean(per_year.mean("time")),
    _ocean_weighted_mean(per_year.max("time")),
    _ocean_weighted_mean(per_year.sum("time")),
)

# Extract years
years = mean_intensity_year["year"].values.astype(int)


def _plot_yearly_bar(values, color, title, ylabel):
    """Bar chart of one yearly series with a tick every 5th year."""
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.bar(years, values, color=color, width=0.8)
    ax.set_title(title)
    ax.set_xlabel("Year")
    ax.set_ylabel(ylabel)
    tick_idx = np.arange(0, len(years), 5)
    ax.set_xticks(years[tick_idx])
    ax.grid(True, ls="--", alpha=0.4)
    plt.tight_layout()
    plt.show()


_plot_yearly_bar(mean_intensity_year, "#0072B2", f"Mean MHW Intensity in Arabian sea ({BASELINE[0]}–{BASELINE[1]})", "°C")
_plot_yearly_bar(max_intensity_year, "#D55E00", f"Maximum MHW Intensity in Arabian sea ({BASELINE[0]}–{BASELINE[1]})", "°C")
_plot_yearly_bar(cum_intensity_year, "#009E73", f"Cumulative MHW Intensity in Arabian sea ({BASELINE[0]}–{BASELINE[1]})", "°C · days")


