# Mitigate Climate Change

This notebook outlines the general workflow for the data within the [Mitigate Climate Change](https://oceancentral.org/track/mitigate-climate-change) page of the Ocean Central website.

## Utility functions and globals used to make all figures

In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
import cartopy.crs as ccrs
from shapely.geometry import box
import rioxarray
import re

from rasterio.features import geometry_mask
from scipy.stats import linregress
from tqdm import tqdm

# Open the biodiversity priority areas based on Zhao et al. 2020 (https://www.sciencedirect.com/science/article/abs/pii/S0006320719312182?via%3Dihub)
masked_data = rioxarray.open_rasterio('../Data/masked_top_30_percent_over_water.tif')

# Fall back to WGS84 lon/lat only when the GeoTIFF carries no CRS of its own.
# rioxarray exposes the CRS through `.rio.crs` (backed by the `spatial_ref`
# coordinate) rather than `.attrs`, so checking `attrs` would overwrite a
# valid CRS with EPSG:4326.
if masked_data.rio.crs is None:
    masked_data = masked_data.rio.write_crs('EPSG:4326')

# Load the IHO World Seas polygons used to aggregate results by sea/ocean area
seas_shapefile_path = '../Data/World_Seas_IHO_v3/World_Seas_IHO_v3.shp'
SEAS_DF = gpd.read_file(seas_shapefile_path)

# Calculate linear trend and p-value for each grid point
def calculate_trend_and_significance(x):
    """Fit an ordinary least-squares line to a 1-D series against its index.

    The regressor is the position in the series (0, 1, 2, ...), so the slope
    is in units of ``x`` per step. Non-finite entries (NaN/inf) are ignored.
    A grid point with a few missing steps therefore still gets a trend,
    instead of the single gap turning the whole result into NaN.

    Parameters
    ----------
    x : array-like of float
        The series for one grid point (e.g. annual means).

    Returns
    -------
    tuple of float
        ``(slope, intercept, p_value)``. All three are NaN when fewer than
        two finite values are available, because no line can be fitted.
    """
    values = np.asarray(x, dtype=float)
    steps = np.arange(values.size)
    finite = np.isfinite(values)
    if finite.sum() < 2:
        return np.nan, np.nan, np.nan
    slope, intercept, _, p_value, _ = stats.linregress(steps[finite], values[finite])
    return slope, intercept, p_value

# Calculate the trend and significance of the trend at each pixel in an xarray dataset
def calculate_trend_df(climate_df):
    """Per-pixel linear trend of annual means, plus a p < 0.05 significance flag.

    ``climate_df`` is averaged within each calendar year (grouped on
    ``time.year``). ``calculate_trend_and_significance`` is then applied along
    the resulting ``year`` axis at every remaining grid point.

    Parameters
    ----------
    climate_df : xarray.DataArray
        Variable with a ``time`` dimension plus spatial dimensions. The
        callers in this notebook pass a DataArray.

    Returns
    -------
    xarray.Dataset
        Contains ``trend`` (per step of the ``year`` axis, i.e. per year when
        the years are consecutive), ``p_value``, and a boolean ``significant``
        mask (``p_value < 0.05``). The dataset is tagged with CRS EPSG:4326.
    """
    annual = climate_df.groupby('time.year').mean()
    # Coordinates of one annual slice are the template for the 2-D outputs.
    template_coords = annual.isel(year=0).coords

    slope, _intercept, p_value = xr.apply_ufunc(
        calculate_trend_and_significance,
        annual,
        input_core_dims=[['year']],
        output_core_dims=[[], [], []],
        output_dtypes=[float, float, float],
        vectorize=True,
    )

    trend = xr.DataArray(slope, coords=template_coords, name='trend')
    pvals = xr.DataArray(p_value, coords=template_coords, name='p_value')
    is_significant = xr.DataArray(pvals < 0.05, coords=pvals.coords, name='significant')

    combined = xr.Dataset(
        {'trend': trend, 'p_value': pvals, 'significant': is_significant}
    )
    # Tag as geographic lon/lat so the rioxarray clipping done downstream works.
    return combined.rio.write_crs("epsg:4326")

# Calculate area-weighted trend, significance for each sea/ocean area
def area_trend(trend_significance_ds, SEAS_DF=SEAS_DF):
    """Summarise a per-pixel trend/significance grid over each sea/ocean polygon.

    Parameters
    ----------
    trend_significance_ds : xarray.Dataset
        Output of ``calculate_trend_df``, holding ``trend`` and ``significant``
        on a lon/lat grid. Dims may be ``lat``/``lon``,
        ``latitude``/``longitude``, or already ``y``/``x``.
    SEAS_DF : geopandas.GeoDataFrame
        Sea polygons with ``NAME``, ``area`` and ``geometry`` columns.
        Defaults to the module-level IHO World Seas layer.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per successfully processed sea, with these columns:

        * ``Weighted_Trend`` -- the cell-area-weighted mean trend over cells
          that have a trend value.
        * ``Sea_Area`` -- taken from ``SEAS_DF['area']``.
        * ``Significant_Area`` / ``Significant_Area_Percent``.
        * ``Biodiversity_Area`` / ``Biodiversity_Area_Percent`` -- cells that
          are significant and flagged in the biodiversity-priority raster.

        Areas are ``Sea_Area`` scaled by the gridded fraction. Seas that
        raise an error are reported on stdout and skipped.
    """
    # Normalise the spatial dimension names to the x/y names rioxarray uses.
    dims = trend_significance_ds.dims
    if 'lat' in dims and 'lon' in dims:
        trend_significance_ds = trend_significance_ds.rename({'lat': 'y', 'lon': 'x'})
    elif 'latitude' in dims and 'longitude' in dims:
        trend_significance_ds = trend_significance_ds.rename({'latitude': 'y', 'longitude': 'x'})

    y_coord = trend_significance_ds['y']
    x_coord = trend_significance_ds['x']

    # Sample the biodiversity priority raster onto the climate grid (nearest neighbour).
    masked_data_interp = masked_data.interp(x=x_coord, y=y_coord, method='nearest')

    # Approximate spherical cell areas in km²: A ≈ R² · sin(Δφ) · Δλ · cos φ.
    # abs() keeps areas positive when an axis is stored in descending order,
    # because np.gradient is negative in that case.
    R = 6371  # Earth radius in kilometers
    lat_rad = np.deg2rad(y_coord.values)
    lon_rad = np.deg2rad(x_coord.values)
    cell_areas = (
        np.abs(R**2 * np.outer(np.sin(np.gradient(lat_rad)), np.gradient(lon_rad)))
        * np.cos(lat_rad)[:, None]
    )

    # Independent of the sea, so build it once and only clip it per sea.
    cell_area_da = xr.DataArray(
        cell_areas,
        dims=['y', 'x'],
        coords={'y': y_coord, 'x': x_coord},
    ).rio.write_crs('EPSG:4326')

    trend = trend_significance_ds['trend']
    # Use 0/1 floats rather than bool: clipping pads outside cells with NaN,
    # which a bool array cannot hold cleanly.
    significant = trend_significance_ds['significant'].astype(float)

    records = []
    for _, row in SEAS_DF.iterrows():
        region_name = row['NAME']
        try:
            area = row['area']
            geom = row['geometry']

            # Clip each grid to the sea polygon (identical grids, identical windows).
            trend_clip = trend.rio.clip([geom], drop=True)
            sig_clip = significant.rio.clip([geom], drop=True).fillna(0)
            area_clip = cell_area_da.rio.clip([geom], drop=True)

            total_area = area_clip.sum().item()

            # Weight only by cells that actually have a trend; NaN (e.g. land)
            # cells would otherwise dilute the weighted mean toward zero.
            valid_area = area_clip.where(trend_clip.notnull()).sum().item()
            weighted_sum = (trend_clip * area_clip).sum(dim=('y', 'x')).item()
            weighted_trend = weighted_sum / valid_area if valid_area > 0 else np.nan

            significant_area = (sig_clip * area_clip).sum(dim=('y', 'x')).item()

            # The priority raster may carry a band dimension; keep the first band.
            biodiversity_sum = (sig_clip * area_clip * masked_data_interp).sum(dim=['x', 'y']).values
            biodiversity_area = float(np.ravel(biodiversity_sum)[0])

            records.append({
                'Region_Name': region_name,
                'geometry': geom,
                'Weighted_Trend': weighted_trend,
                'Sea_Area': area,
                'Significant_Area': area * significant_area / total_area,
                'Significant_Area_Percent': 100 * significant_area / total_area,
                'Biodiversity_Area': area * biodiversity_area / total_area,
                'Biodiversity_Area_Percent': 100 * biodiversity_area / total_area,
            })
        except Exception as e:
            # Best effort per sea: report and move on to the next polygon.
            print(f"Error processing {region_name}: {e}")

    columns = [
        'Region_Name', 'geometry', 'Weighted_Trend', 'Sea_Area',
        'Significant_Area', 'Significant_Area_Percent',
        'Biodiversity_Area', 'Biodiversity_Area_Percent',
    ]
    # Going through a DataFrame with explicit columns keeps an empty result valid.
    return gpd.GeoDataFrame(
        pd.DataFrame(records, columns=columns), geometry='geometry', crs=SEAS_DF.crs
    )

def area_heatwave(temp_df, SEAS_DF=SEAS_DF):
    """Summarise severe marine-heatwave exposure over each sea/ocean polygon.

    A grid cell counts as impacted when ``heatwave_category >= 3`` at any time
    step in ``temp_df``. For each sea, the impacted share of the gridded sea
    area is scaled to the sea's tabulated area. The same is done for the share
    that is also flagged in the biodiversity-priority raster.

    Parameters
    ----------
    temp_df : xarray.Dataset
        Must contain ``heatwave_category`` with a ``time`` dimension. Any of
        these spatial names are renamed to ``y``/``x`` when present:
        ``latdim``/``londim`` (dims) followed by ``lat``/``lon``,
        ``lat``/``lon``, or ``latitude``/``longitude``.
    SEAS_DF : geopandas.GeoDataFrame
        Sea polygons with ``NAME``, ``area`` and ``geometry`` columns.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per successfully processed sea, with ``Heatwave_Area``,
        ``Heatwave_Area_Percent``, ``Sea_Area``, ``Biodiversity_Area`` and
        ``Biodiversity_Area_Percent``. Seas that raise an error are reported
        on stdout and skipped.
    """

    def _first_value(arr):
        # Reductions may keep a length-1 band dimension from the priority raster.
        arr = arr.compute()
        return arr.item() if arr.size == 1 else arr.values[0]

    temp_df = temp_df.rio.write_crs("epsg:4326")

    # Rename only the spatial names that exist. The dims go first so that a
    # lat/lon variable along them can then become the y/x coordinate.
    for candidate in (
        {'latdim': 'y', 'londim': 'x'},
        {'lat': 'y', 'lon': 'x'},
        {'latitude': 'y', 'longitude': 'x'},
    ):
        present = {
            old: new for old, new in candidate.items()
            if old in temp_df.dims or old in temp_df.variables
        }
        if present:
            temp_df = temp_df.rename(present)

    # True where a severe (category >= 3) heatwave occurred at any time.
    # Evaluated once here, rather than re-running the lazy time reduction for
    # every sea polygon in the loop below.
    heatwave_mask = (temp_df['heatwave_category'] >= 3).any(dim='time').compute()
    # 0/1 floats so that clipping pads outside cells with NaN cleanly.
    heatwave_frac = heatwave_mask.astype('float32')

    y_coord = heatwave_frac['y']
    x_coord = heatwave_frac['x']

    # Sample the biodiversity priority raster onto the heatwave grid (nearest neighbour).
    masked_data_interp = masked_data.interp(x=x_coord, y=y_coord, method='nearest')

    # Approximate spherical cell areas in km²: A ≈ R² · sin(Δφ) · Δλ · cos φ.
    # abs() keeps areas positive for descending axes.
    R = 6371  # Earth radius in kilometers
    lat_rad = np.deg2rad(y_coord.values)
    lon_rad = np.deg2rad(x_coord.values)
    cell_areas = (
        np.abs(R**2 * np.outer(np.sin(np.gradient(lat_rad)), np.gradient(lon_rad)))
        * np.cos(lat_rad)[:, None]
    )
    cell_area_da = xr.DataArray(
        cell_areas,
        dims=['y', 'x'],
        coords={'y': y_coord, 'x': x_coord},
    ).rio.write_crs('EPSG:4326')

    records = []
    for _, row in tqdm(SEAS_DF.iterrows(), total=len(SEAS_DF), desc="Processing Sea Areas"):
        # Bound before the try so the error message can always name the sea.
        region_name = row['NAME']
        try:
            area = row['area']
            geom = row['geometry']

            hw_clip = heatwave_frac.rio.clip([geom], drop=True).fillna(0)
            area_clip = cell_area_da.rio.clip([geom], drop=True)

            heatwave_value = _first_value((hw_clip * area_clip).sum(dim=('y', 'x')))
            biodiversity_value = _first_value(
                (hw_clip * area_clip * masked_data_interp).sum(dim=['x', 'y'])
            )
            total_area_value = _first_value(area_clip.sum(dim=('y', 'x')))

            records.append({
                'Region_Name': region_name,
                'geometry': geom,
                'Heatwave_Area': area * heatwave_value / total_area_value,
                'Heatwave_Area_Percent': 100 * heatwave_value / total_area_value,
                'Sea_Area': area,
                'Biodiversity_Area': area * biodiversity_value / total_area_value,
                'Biodiversity_Area_Percent': 100 * biodiversity_value / total_area_value,
            })
        except Exception as e:
            # Best effort per sea: report and move on to the next polygon.
            print(f"Error processing {region_name}: {e}")

    columns = [
        'Region_Name', 'geometry', 'Heatwave_Area', 'Heatwave_Area_Percent',
        'Sea_Area', 'Biodiversity_Area', 'Biodiversity_Area_Percent',
    ]
    # Going through a DataFrame with explicit columns keeps an empty result valid.
    return gpd.GeoDataFrame(
        pd.DataFrame(records, columns=columns), geometry='geometry', crs=SEAS_DF.crs
    )


def area_coral_acidification_by_sea(
    trend_significance_ds,
    coral_gdf,
    SEAS_DF,
    all_touched=False,
):
    """
    Coral acidification exposure by marine area (IHO seas),
    using raster-based fraction × true coral area.

    For each sea, the coral polygons are intersected with the sea polygon and
    their true area is measured in an equal-area projection (ESRI:54009). The
    same coral geometry is also rasterized onto the trend grid. The fraction
    of those coral pixels whose trend is significant is then applied to the
    true coral area.

    Parameters
    ----------
    trend_significance_ds : xarray.Dataset
        Output of ``calculate_trend_df``, holding a ``significant`` mask on a
        lon/lat grid. Dims may be ``lat``/``lon``, ``latitude``/``longitude``,
        or already ``y``/``x``.
    coral_gdf : geopandas.GeoDataFrame
        Coral polygons in any CRS; they are reprojected to EPSG:4326.
    SEAS_DF : geopandas.GeoDataFrame
        Sea polygons with ``NAME`` and ``geometry``. The geometries are
        assumed to be in EPSG:4326 (TODO: confirm for other inputs).
    all_touched : bool
        Forwarded to ``rasterio.features.rasterize``. When True, every pixel
        touched by a coral polygon is burned in; otherwise only pixels whose
        centre is inside a polygon are.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per sea containing rasterizable coral, with ``Region_Name``,
        ``True_Coral_Area_km2``, ``Affected_Coral_Area_km2``,
        ``Affected_Coral_Area_Percent`` and ``geometry``.
    """
    # Local import so the function does not depend on module-level imports.
    from rasterio.features import rasterize

    results = []

    # ---------------------------
    # 1. Prepare raster
    # ---------------------------
    ds = trend_significance_ds.copy().rio.write_crs("EPSG:4326")
    for lat_name, lon_name in (("lat", "lon"), ("latitude", "longitude")):
        if lat_name in ds.dims and lon_name in ds.dims:
            ds = ds.rename({lat_name: "y", lon_name: "x"})
            break
    # North-up, (y, x)-ordered grid, so the array rows match the affine
    # transform handed to rasterize.
    ds = ds.sortby("y", ascending=False)

    # 0/1 floats rather than bool. Clipping pads outside cells with NaN, and
    # NaN cast straight back to bool would read as True.
    sig = ds["significant"].transpose("y", "x").astype("float32")

    # Derive the transform from the cell-centre coordinates. This handles the
    # half-pixel origin offset and the (negative) north-up y pixel size.
    transform = sig.rio.transform(recalc=True)
    out_shape = (sig.sizes["y"], sig.sizes["x"])

    # ---------------------------
    # 2. Prepare coral polygons
    # ---------------------------
    corals = coral_gdf.to_crs("EPSG:4326").copy()
    # buffer(0) is the usual trick to repair invalid (e.g. self-intersecting) polygons.
    corals["geometry"] = corals.geometry.buffer(0)

    # ---------------------------
    # 3. Loop over seas
    # ---------------------------
    for _, row in tqdm(SEAS_DF.iterrows(), total=len(SEAS_DF), desc="Processing Seas"):

        sea_name = row["NAME"]
        sea_geom = row["geometry"]

        try:
            # A. Coral geometry within this sea
            coral_sea = gpd.overlay(
                corals,
                gpd.GeoDataFrame(geometry=[sea_geom], crs="EPSG:4326"),
                how="intersection",
            )
            if coral_sea.empty:
                continue

            # True coral area in this sea (km²), from the equal-area projection
            true_coral_area_km2 = coral_sea.to_crs("ESRI:54009").geometry.area.sum() / 1e6

            # B. Rasterize the coral mask (already limited to the sea)
            shapes = [(g, 1) for g in coral_sea.geometry if not g.is_empty]
            if not shapes:
                continue
            coral_mask = rasterize(
                shapes,
                out_shape=out_shape,
                transform=transform,
                fill=0,
                all_touched=all_touched,
                dtype="uint8",
            ).astype(bool)

            coral_pixel_count = int(coral_mask.sum())
            if coral_pixel_count == 0:
                continue

            # C. Significance inside the sea; cells outside become 0, not NaN
            sig_sea = sig.rio.clip([sea_geom], drop=False).fillna(0).values > 0

            # D. Fraction of coral pixels with a significant trend
            affected_fraction = (coral_mask & sig_sea).sum() / coral_pixel_count
            affected_coral_area_km2 = affected_fraction * true_coral_area_km2

            # E. Store
            results.append({
                "Region_Name": sea_name,
                "True_Coral_Area_km2": true_coral_area_km2,
                "Affected_Coral_Area_km2": affected_coral_area_km2,
                "Affected_Coral_Area_Percent": affected_fraction * 100,
                "geometry": sea_geom,
            })

        except Exception as e:
            print(f"Error processing {sea_name}: {e}")

    columns = [
        "Region_Name", "True_Coral_Area_km2", "Affected_Coral_Area_km2",
        "Affected_Coral_Area_Percent", "geometry",
    ]
    # Going through a DataFrame with explicit columns keeps an empty result valid.
    return gpd.GeoDataFrame(
        pd.DataFrame(results, columns=columns), geometry="geometry", crs=SEAS_DF.crs
    )

# Temperature

## Figure 1

<p align="center">
  <img src="Figs/climate_temperature_1.png" style="width:50%;">
</p>

In [None]:
# GISTEMP v4 annual series from the GISS Surface Temperature Analysis:
# ocean sea surface temperature and global mean surface temperature (GMST).
ocean_data = pd.read_csv("../Data/GISTEMP_SST.csv")
gmst_data = pd.read_csv("../Data/GMST_GISTEMP4.csv")

# Plot the LOWESS-smoothed "Lowess(5)" column of each file.
ocean_data['Ocean_Annual'] = ocean_data['Lowess(5)']
gmst_data['GMST_Annual'] = gmst_data['Lowess(5)']

temp_data = ocean_data.merge(gmst_data, on='Year')

# Re-express both smoothed series as anomalies relative to their own
# 1880-1900 mean, used here as the pre-industrial proxy baseline.
# (The boolean selection copies, so the baseline is fixed before subtracting.)
base_period = temp_data[temp_data['Year'].between(1880, 1900)]
for series in ('Ocean_Annual', 'GMST_Annual'):
    temp_data[series] = temp_data[series] - base_period[series].mean()

# Least-squares linear trend line (intercept + slope * Year) for each anomaly series.
for series, trend_col in (('Ocean_Annual', 'ocean_trend'), ('GMST_Annual', 'gmst_trend')):
    fit = linregress(temp_data['Year'], temp_data[series])
    temp_data[trend_col] = fit.intercept + fit.slope * temp_data['Year']

# Constant reference line for the 1.5 °C Paris Agreement goal
temp_data['paris_goal'] = 1.5

figure_columns = ['Year', 'Ocean_Annual', 'ocean_trend', 'GMST_Annual', 'gmst_trend', 'paris_goal']

# Save out as JSON: one pretty-printed record (dict) per year
temp_data[figure_columns].to_json(
    "../Data/Figure_1_temperature.json",
    orient="records",
    indent=2,
)

# Display the figure data
temp_data[figure_columns]


## Figure 2

<p align="center">
  <img src="Figs/climate_temperature_2.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
import rioxarray

trends_df = xr.open_dataset("../Data/global_omi_tempsal_sst_trend_19932021_P20220331.nc")

# --------------------
# CONFIG
# --------------------
OUTPUT_TIF_8BIT   = "../Data/Figure_2_temperature.tif"   # uint8 display layer
OUTPUT_TIF_FLOAT  = "../Data/sst_trend_float32.tif"      # float32 native-value layer
USE_PERCENTILES   = True
P_LOW, P_HIGH     = 2, 98
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # only used when USE_PERCENTILES is False

# Collapse any time axis into a single 2-D trend field.
da = trends_df["sst_trends"]
if "time" in da.dims:
    da = da.mean(dim="time")

# Map CF-style coordinate names onto the x/y names rioxarray expects.
_axis_names = {"lat": "y", "latitude": "y", "lon": "x", "longitude": "x"}
rename_map = {
    name: axis for name, axis in _axis_names.items()
    if name in da.dims or name in da.coords
}
if rename_map:
    da = da.rename(rename_map)

# GeoTIFF rows run north -> south, so y must be descending.
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)

# Register spatial dims and default to WGS84 lon/lat when no CRS is present.
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

data = da.data.astype(np.float32)
valid = np.isfinite(data)

if USE_PERCENTILES:
    vmin, vmax = (float(np.nanpercentile(data, pct)) for pct in (P_LOW, P_HIGH))
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

if not (np.isfinite(vmin) and np.isfinite(vmax) and vmin < vmax):
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")

# 8-bit layer: 0 is reserved for NoData; valid cells map linearly onto 1..255.
norm = np.clip((data - vmin) / (vmax - vmin), 0.0, 1.0)
scaled = np.where(valid, norm * 254 + 1, 0.0).astype(np.uint8)

da8 = da.copy(data=scaled)
# Drop CF encodings inherited from the source (e.g. a huge float _FillValue),
# which cannot be represented in uint8.
for key in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    da8.encoding.pop(key, None)
da8 = da8.rio.write_nodata(0, encoded=False, inplace=False)

_tiff_options = dict(compress="LZW", tiled=True, blockxsize=256, blockysize=256)
da8.rio.to_raster(OUTPUT_TIF_8BIT, dtype="uint8", **_tiff_options)

# float32 layer with the native trend values; NaN marks NoData.
daf = da.where(np.isfinite(da))
for key in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    daf.encoding.pop(key, None)
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)
daf.rio.to_raster(OUTPUT_TIF_FLOAT, dtype="float32", **_tiff_options)

print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"Scale for 8-bit: vmin={vmin:.4f}, vmax={vmax:.4f} (units of sst_trends)."
)


## Figure 3

<p align="center">
  <img src="Figs/climate_temperature_3.png" style="width:50%;">
</p>

**This figure calculates the total area of each major sea region globally. It also calculates (1) the area of each sea region impacted by a severe marine heatwave (category 3 or higher) in 2023 and (2) the area of each sea that both contains priority biodiversity areas AND is impacted by a severe marine heatwave.**

In [None]:
# Year 2023 marine heatwave data downloaded from NOAA's Coral Reef Watch
# https://coralreefwatch.noaa.gov/product/marine_heatwave/
import xarray as xr
import glob

files = sorted(glob.glob("../Data/2023/*.nc"))
if not files:
    raise FileNotFoundError("No NetCDF files found matching ../Data/2023/*.nc")

temp_df = xr.open_mfdataset(files)

area_df = area_heatwave(temp_df)

# Save the GeoDataFrame to a GeoJSON file
area_df.to_file("../Data/Figure_3_temperature.geojson", driver="GeoJSON")

## Figure 4

<p align="center">
  <img src="Figs/climate_temperature_4.png" style="width:50%;">
</p>

$CO_2$ data retrieved from [Lan, X., P. Tans, & K.W. Thoning (2025). Trends in globally-averaged CO₂ determined from NOAA Global Monitoring Laboratory measurements (Version 2025-11) NOAA Global Monitoring Laboratory. https://doi.org/10.15138/9N0H-ZH07](https://gml.noaa.gov/webdata/ccgg/trends/co2/co2_annmean_mlo.txt). Temperature data retrieved from GISS Surface Temperature Analysis (v4).

## Figure 5

<p align="center">
  <img src="Figs/climate_temperature_5.png" style="width:50%;">
</p>

Data were retrieved from [Hughes, T. P., et al. (2018). Spatial and temporal patterns of mass bleaching of corals in the Anthropocene. Science. – processed by Our World in Data](https://ourworldindata.org/grapher/coral-bleaching-events).

## Figure 6

<p align="center">
  <img src="Figs/climate_temperature_6.png" style="width:50%;">
</p>

In [None]:
#!/usr/bin/env python3
import xarray as xr
import rioxarray  # pip install rioxarray rasterio

# --- Count, per pixel, the time steps with any marine heatwave (category >= 1) ---
# The files are presumably daily, which makes this a count of heatwave days
# (TODO confirm the file cadence).
ds = xr.open_mfdataset("../Data/2023/*.nc", combine="by_coords")
hw_count = (ds["heatwave_category"] >= 1).sum(dim="time").astype("float32")

# --- 2nd / 98th percentiles over space ---
# Dask quantiles need every reduced dimension to sit in a single chunk.
spatial_dims = [dim for dim in hw_count.dims if dim != "time"]
hw_q = hw_count.chunk(dict.fromkeys(spatial_dims, -1))
q = hw_q.quantile([0.02, 0.98], dim=spatial_dims, skipna=True).compute()
p2, p98 = (float(q.sel(quantile=level).values) for level in (0.02, 0.98))

print(f"2nd percentile: {p2}, 98th percentile: {p98}")

if p98 <= p2:
    raise ValueError("98th percentile is not greater than 2nd percentile; cannot scale safely.")

# --- Linear stretch of [p2, p98] onto 0..255 (clipped, NaN -> 0), rounded to uint8 ---
scaled = ((hw_count - p2) / (p98 - p2) * 255.0).clip(0, 255).fillna(0)
scaled_uint8 = scaled.round().astype("uint8")

# --- Georeference for Mapbox: find the x/y dimension names, default to WGS84 ---
cands_x = ["lon", "longitude", "x", "londim"]
cands_y = ["lat", "latitude", "y", "latdim"]
x_dim = next(dim for dim in cands_x if dim in scaled_uint8.dims)
y_dim = next(dim for dim in cands_y if dim in scaled_uint8.dims)

scaled_uint8 = scaled_uint8.rio.set_spatial_dims(x_dim=x_dim, y_dim=y_dim, inplace=False)
if not scaled_uint8.rio.crs:
    scaled_uint8 = scaled_uint8.rio.write_crs("EPSG:4326")

# 0 doubles as NoData, so it renders transparent in Mapbox.
# NOTE(review): pixels at or below the 2nd percentile also scale to 0 and
# will render transparent too; confirm this is intended.
scaled_uint8 = scaled_uint8.rio.write_nodata(0)

out_path = "Figure_6_temperature.tif"
scaled_uint8.rio.to_raster(out_path, dtype="uint8")

print(f"Saved GeoTIFF: {out_path}")


## Figure 7

<p align="center">
  <img src="Figs/climate_temperature_7.png" style="width:50%;">
</p>

Data were retrieved from [Jones et al. (2024) – with major processing by Our World in Data](https://ourworldindata.org/co2-and-greenhouse-gas-emissions#explore-data-on-co2-and-greenhouse-gas-emissions).

## Figure 8

<p align="center">
  <img src="Figs/climate_temperature_8.png" style="width:50%;">
</p>

Data were retrieved from [van Woesik and Kratochwill](https://springernature.figshare.com/articles/dataset/Global_Coral_Bleaching_Database/17076290?backTo=%2Fcollections%2FA_Global_Coral-Bleaching_Database_GCBD_1998_2020%2F5314466&file=31573496).

In [None]:
import sqlite3
import pandas as pd
import geopandas as gpd

from contextlib import closing
from pathlib import Path

# --- paths ---
db_path = "../Data/Global_Coral_Bleaching_Database_SQLite_11_24_21.db"
out_geojson = "../Data/Figure_8_temperature.geojson"

# Distinct sites with at least one bleaching record (> 0 % bleached) and coordinates.
query = """
SELECT DISTINCT
    s.Site_ID,
    s.Site_Name,
    s.Latitude_Degrees AS lat,
    s.Longitude_Degrees AS lon
FROM Site_Info_tbl s
JOIN Sample_Event_tbl se ON s.Site_ID = se.Site_ID
JOIN Bleaching_tbl b ON se.Sample_ID = b.Sample_ID
WHERE b.Percent_Bleached IS NOT NULL
  AND b.Percent_Bleached > 0
  AND s.Latitude_Degrees IS NOT NULL
  AND s.Longitude_Degrees IS NOT NULL
"""

# Open read-only (mode=ro): a wrong path then raises instead of silently
# creating an empty database file. closing() releases the connection even if
# the query fails; sqlite3's own context manager only commits/rolls back and
# does not close.
db_uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
with closing(sqlite3.connect(db_uri, uri=True)) as conn:
    bleach_sites = pd.read_sql_query(query, conn)

print(f"Found {len(bleach_sites)} sites with bleaching records.")

# --- fix bytes columns (decode to str) ---
for col in bleach_sites.columns:
    if bleach_sites[col].dtype == object:
        bleach_sites[col] = bleach_sites[col].apply(
            lambda v: v.decode("utf-8", errors="ignore") if isinstance(v, (bytes, bytearray)) else v
        )

# --- convert to GeoDataFrame (WGS84 lon/lat points) ---
gdf = gpd.GeoDataFrame(
    bleach_sites,
    geometry=gpd.points_from_xy(bleach_sites["lon"], bleach_sites["lat"]),
    crs="EPSG:4326"
)

# --- save as GeoJSON ---
gdf.to_file(out_geojson, driver="GeoJSON")
print(f"Saved to {out_geojson}")


# Salinity

## Figure 1

<p align="center">
  <img src="Figs/climate_salinity_1.png" style="width:50%;">
</p>

In [None]:
import xarray as xr
import pandas as pd

import numpy as np

salt_df = xr.open_dataset("../Data/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")

# Area-weighted global mean. On a regular lat/lon grid, cell area scales with
# cos(latitude); an unweighted mean would over-represent high-latitude cells.
# (This assumes a regular lat/lon grid; TODO confirm for other products.)
# Weighted mean skips NaN (land) cells and renormalises by the valid weights.
area_weights = np.cos(np.deg2rad(salt_df['lat']))
salt_df = (
    salt_df['salinity']
    .weighted(area_weights)
    .mean(dim=['lat', 'lon'])
    .resample(time='Y')
    .mean()
)

final_subset = salt_df.sel(time=slice('1994-01-01', None))

# Create a pandas DataFrame with these columns
df = pd.DataFrame({
    'time': final_subset['time'].values,
    'salinity': final_subset.values,
})

# Convert 'time' to datetime
df['time'] = pd.to_datetime(df['time'])

# Convert datetime to a numerical value for linear regression (ordinal days)
df['time_ordinal'] = df['time'].map(pd.Timestamp.toordinal)

# Perform linear regression to find the slope and intercept
slope, intercept, _, _, _ = linregress(df['time_ordinal'], df['salinity'])

# Calculate the trend line (y = mx + b) for each time point
df['linear_trend'] = intercept + slope * df['time_ordinal']

df[['time','salinity','linear_trend']].to_csv("../Data/Figure_1_salinity.csv")

# Display the updated DataFrame
df[['time','salinity','linear_trend']].head()

## Figure 2

<p align="center">
  <img src="Figs/climate_salinity_2.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
import rioxarray

from rasterio.enums import Resampling

salt_df = xr.open_dataset("../Data/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")

trend_significance_ds = calculate_trend_df(salt_df['salinity'])

# --------------------
# CONFIG
# --------------------
OUTPUT_TIF_8BIT   = "./Figure_2_salinity.tif"
OUTPUT_TIF_FLOAT  = "./salinity_trend_float32.tif"
USE_PERCENTILES   = True
P_LOW, P_HIGH     = 2, 98
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # if USE_PERCENTILES=False

# 1) Select the dataarray: per-year salinity trend from calculate_trend_df
da = trend_significance_ds['trend']

# 2) Rename to x/y if needed
rename_map = {}
if "lat" in da.dims or "lat" in da.coords: rename_map["lat"] = "y"
if "latitude" in da.dims or "latitude" in da.coords: rename_map["latitude"] = "y"
if "lon" in da.dims or "lon" in da.coords: rename_map["lon"] = "x"
if "longitude" in da.dims or "longitude" in da.coords: rename_map["longitude"] = "x"
if rename_map:
    da = da.rename(rename_map)

# 3) Ensure y is north→south (descending)
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)
# ensure (y, x) order
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")

# 4) Register spatial dims & force EPSG:4326
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)

# If CRS is missing, assume geographic lon/lat
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# Reproject only when the CRS is not already EPSG:4326. Comparing EPSG codes is
# more robust than comparing str(crs), which can render as WKT.
if da.rio.crs.to_epsg() != 4326:
    # Bilinear resampling suits a continuous field such as a salinity trend
    da = da.rio.reproject("EPSG:4326", resampling=Resampling.bilinear)

# Reassert dims order after any reprojection (just in case)
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")
# 5) Compute scaling range (Dask-safe)
import dask.array as dsa

def compute_percentiles_safe(da, p_low, p_high):
    """
    Return the (p_low, p_high) percentiles of ``da``, ignoring NaNs.

    For dask-backed arrays, ``dask.array.nanpercentile`` (imported as ``dsa``)
    is tried first. If the array is not chunked, or that call fails, a strided
    subsample is loaded instead and ``numpy.nanpercentile`` is used. The
    stride is ``size // 512`` (at least 1) along each of ``y``/``x`` that is
    present, which keeps memory small.

    Returns
    -------
    tuple of float
        ``(vmin, vmax)``.
    """
    if getattr(da, "chunks", None):
        try:
            vmin = float(dsa.nanpercentile(da.data, p_low).compute())
            vmax = float(dsa.nanpercentile(da.data, p_high).compute())
            return vmin, vmax
        except Exception:
            pass  # fall back to the sampled approach below

    # Stride only along spatial dims that exist; indexing a missing dim raises.
    strides = {
        dim: slice(0, None, max(da.sizes[dim] // 512, 1))
        for dim in ("y", "x")
        if dim in da.dims
    }
    sample = da.isel(strides).load().values  # small enough to load
    vmin = float(np.nanpercentile(sample, p_low))
    vmax = float(np.nanpercentile(sample, p_high))
    return vmin, vmax

if USE_PERCENTILES:
    vmin, vmax = compute_percentiles_safe(da, P_LOW, P_HIGH)
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")


# 6) Build the 8-bit layer with xr.where (no boolean indexing)
# Reserve 0 for NoData; valid cells map to [1,255]
norm = ((da - vmin) / (vmax - vmin)).clip(0, 1)
scaled_da = xr.where(np.isfinite(da), norm * 254.0 + 1.0, 0.0).astype("uint8")

# Clear CF encodings (e.g. a float _FillValue) that cannot apply to uint8
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    scaled_da.encoding.pop(k, None)

# Set nodata appropriate for uint8 (0)
scaled_da = scaled_da.rio.write_nodata(0, encoded=False, inplace=False)

# 7a) Write 8-bit GeoTIFF
scaled_da.rio.to_raster(
    OUTPUT_TIF_8BIT,
    dtype="uint8",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# 7b) Write float32 GeoTIFF with native values (NaN = NoData)
daf = da.where(np.isfinite(da)).astype("float32")
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    daf.encoding.pop(k, None)
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)

daf.rio.to_raster(
    OUTPUT_TIF_FLOAT,
    dtype="float32",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# The trend from calculate_trend_df is salinity change per step of the year axis.
print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"8-bit scale: vmin={vmin:.4f}, vmax={vmax:.4f} (salinity trend units per year)."
)


## Figure 3

<p align="center">
  <img src="Figs/climate_salinity_3.png" style="width:50%;">
</p>

In [None]:
# Per-pixel salinity trends, summarised per IHO sea and written to GeoJSON.
salt_df = xr.open_dataset("../Data/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")
trend_significance_ds = calculate_trend_df(salt_df['salinity'])

area_df = area_trend(trend_significance_ds)
area_df.to_file("../Data/Figure_3_salinity.geojson", driver="GeoJSON")

# Release the large objects once the file is written.
del area_df
del salt_df

## Figure 4

<p align="center">
  <img src="Figs/climate_salinity_4.png" style="width:50%;">
</p>

**This figure reproduces Figure 3A from [Delworth et al. (2022)](https://www.pnas.org/doi/10.1073/pnas.2116655119). It plots an index of the Atlantic Meridional Overturning Circulation (AMOC) over time. There are six scenarios: two have historical data ("HISTORICAL", representing the historical evolution of anthropogenic emissions, and "NATURAL", with no changes in anthropogenic emissions), and five have projections ("NATURAL" again, plus four Shared Socioeconomic Pathway (SSP) scenarios describing how emissions may change in the future). Each scenario is run several times as an ensemble, which provides the measures of uncertainty shown in the figure.**

In [None]:
# Figure 4 (salinity section): AMOC index at 45N from the SPEAR ensembles
# (Delworth et al. 2022, Fig. 3A). The raw text file holds one header line per
# ensemble member ("SPEAR_c192_o1_<SCENARIO>_ENS_<NN>"), followed by the
# numeric values of that member.

AMOC_PATH = '../Data/TAR_FIGURE_3_AMOC_45N'
AMOC_OUT_CSV = "../Data/Figure_4_salinity.csv"
# Years assigned to each member's values.
# NOTE(review): range(1921, 2100) stops at 2099. If the members include 2100,
# the last value is dropped — confirm against the per-member series length.
AMOC_YEARS = range(1921, 2100)
# Scenario tokens, matched as substrings of the member column names.
AMOC_SCENARIOS = ['SSP119', 'SSP245', 'SSP534', 'SSP585', 'NATURAL']
# The SSP runs are reported as a single "Historical" series up to this year.
AMOC_HISTORICAL_SCENARIOS = ['SSP119', 'SSP245', 'SSP534', 'SSP585']
AMOC_LAST_HISTORICAL_YEAR = 2014

_AMOC_HEADER_RE = re.compile(r'SPEAR_c192_o1_(.+)_ENS_(\d+)')
# The optional sign must apply to both the decimal and the integer form.
# (r"[-+]?\d*\.\d+|\d+" binds the sign to the first alternative only, so "-5"
# parsed as 5.) An optional exponent keeps "1.5E+01" as a single number.
_AMOC_NUMBER_RE = re.compile(r"[-+]?(?:\d*\.\d+|\d+)(?:[eE][-+]?\d+)?")


def parse_amoc_scenarios(lines):
    """Group the numeric values in `lines` by ensemble-member header.

    Parameters
    ----------
    lines : iterable of str
        Raw lines of the AMOC text file.

    Returns
    -------
    dict[str, list[float]]
        Maps "<scenario>_<ensemble>" (e.g. "HIST_SSP585_ALLForc_01") to the
        values that follow that member's header. Numbers that appear before
        the first header are ignored. A repeated header restarts its list.
    """
    scenario_data = {}
    current = None
    for line in lines:
        header = _AMOC_HEADER_RE.search(line)
        if header:
            current = f"{header.group(1)}_{header.group(2)}"
            scenario_data[current] = []
        elif current is not None:
            scenario_data[current].extend(
                float(value) for value in _AMOC_NUMBER_RE.findall(line)
            )
    return scenario_data


def build_ensemble_frame(scenario_data, years):
    """Return a DataFrame with a 'Year' column plus one float column per member.

    Series longer than `years` are truncated. Shorter ones are padded with NaN
    rather than raising a length-mismatch ValueError.
    """
    years = list(years)
    members = {
        name: pd.Series(values[:len(years)], dtype=float)
        for name, values in scenario_data.items()
    }
    # Building the frame in one call (instead of inserting columns one by one)
    # avoids pandas' fragmentation warning when there are many members.
    return pd.DataFrame({'Year': years, **members}, index=pd.RangeIndex(len(years)))


def summarize_scenarios(df, scenarios, lower_q=0.05, upper_q=0.95):
    """Return the ensemble mean and percentile spread per scenario, by year.

    Member columns are selected by substring match on each scenario token.
    Scenarios without any matching column are skipped. The *_CI_Lower and
    *_CI_Upper columns hold the `lower_q` / `upper_q` percentiles across
    members (the ensemble spread, not a confidence interval of the mean). The
    names are kept for the CSV schema.
    """
    result = pd.DataFrame({'Year': df['Year']})
    for scenario in scenarios:
        cols = [col for col in df.columns if scenario in col]
        if not cols:
            continue
        members = df[cols]
        result[f'{scenario}_Mean'] = members.mean(axis=1)
        result[f'{scenario}_CI_Lower'] = members.quantile(lower_q, axis=1)
        result[f'{scenario}_CI_Upper'] = members.quantile(upper_q, axis=1)
    return result


def split_historical(result, hist_scenarios, last_year):
    """Move the SSP statistics up to `last_year` into 'Historical_*' columns.

    For rows with Year <= last_year:
    - Historical_<stat> is the across-scenario mean of the <scenario>_<stat>
      columns.
    - Those SSP columns are set to NaN.

    After last_year, the Historical_* columns are NaN. Scenarios absent from
    `result` are ignored rather than raising KeyError. Modifies and returns
    `result`.

    NOTE(review): averaging per-scenario percentiles is not the percentile of
    the pooled historical members. The original method is kept to preserve
    the published series.
    """
    mask = result['Year'] <= last_year
    present = [s for s in hist_scenarios if f'{s}_Mean' in result.columns]
    for stat in ('Mean', 'CI_Lower', 'CI_Upper'):
        cols = [f'{s}_{stat}' for s in present]
        # The masked-row Series aligns on the index, so later years get NaN.
        result[f'Historical_{stat}'] = result.loc[mask, cols].mean(axis=1)
        if cols:
            result.loc[mask, cols] = np.nan
    return result


if __name__ == "__main__":
    with open(AMOC_PATH, 'r') as file:
        lines = file.readlines()

    scenario_data = parse_amoc_scenarios(lines)
    df = build_ensemble_frame(scenario_data, AMOC_YEARS)
    result = summarize_scenarios(df, AMOC_SCENARIOS)
    result = split_historical(result, AMOC_HISTORICAL_SCENARIOS, AMOC_LAST_HISTORICAL_YEAR)

    result.to_csv(AMOC_OUT_CSV)
    print(result.head())


## Figure 5

<p align="center">
  <img src="Figs/climate_salinity_5.png" style="width:50%;">
</p>

Monthly data were retrieved from CMEMS [Mercator Ocean International / Copernicus Marine Service (2023). Global Ocean Physics Reanalysis (GLOBAL_MULTIYEAR_PHY_001_030) [Data set]. Copernicus Marine Service. https://doi.org/10.48670/moi-00021](https://data.marine.copernicus.eu/product/GLOBAL_MULTIYEAR_PHY_001_030/description).

In [None]:
#!/usr/bin/env python3
import os, json, datetime as dt
import numpy as np
import xarray as xr
import rioxarray
from affine import Affine

# ---------------------------
# CONFIG
# ---------------------------
# Input NetCDF (monthly CMEMS reanalysis subset) and output folder for the
# per-month GeoTIFFs.
IN_PATH = "../Data/cmems_mod_glo_phy_my_0.083deg_P1M-m_1760469819927.nc"
OUT_DIR = "../Data/Figure_5_salinity"
VAR_NAME_HINTS = ["so", "salinity"]   # data-variable names tried in order by pick_var()
WRITE_FLOAT32 = False   # set True to write Float32 GeoTIFFs instead of 8-bit
LOW_Q, HIGH_Q = 0.02, 0.98   # robust percentiles for 8-bit scaling
COMPRESS = "DEFLATE"    # GeoTIFF compression
NODATA_UINT8 = 0        # 0 reserved as nodata in 8-bit output
OPEN_WITH_CHUNKS = True # try to enable dask-friendly quantiles

# Created when this cell runs, so main() can write into it.
os.makedirs(OUT_DIR, exist_ok=True)

def pick_var(ds: xr.Dataset) -> xr.DataArray:
    """Return the first data variable whose name is in VAR_NAME_HINTS.

    Hints are tried in order. If none matches, the dataset's first data
    variable is returned. An empty dataset raises StopIteration.
    """
    hinted = next((name for name in VAR_NAME_HINTS if name in ds.data_vars), None)
    if hinted is not None:
        return ds[hinted]
    return next(iter(ds.data_vars.values()))

def ensure_xy(ds_like: xr.DataArray) -> xr.DataArray:
    """Return a georeferenced copy of a 2-D lon/lat slice with dims named x/y.

    Steps:
    - Renames lon/longitude -> x and lat/latitude -> y.
    - Sorts y descending (north at the top) when it is strictly increasing.
    - Writes CRS EPSG:4326 and a north-up affine transform built from the
      mean coordinate spacing, treating coordinates as pixel centres.

    Raises
    ------
    ValueError
        If no 'x' or 'y' dimension is present after renaming.
    """
    da = ds_like
    rename_map = {}
    if "lon" in da.dims: rename_map["lon"] = "x"
    if "longitude" in da.dims: rename_map["longitude"] = "x"
    if "lat" in da.dims: rename_map["lat"] = "y"
    if "latitude" in da.dims: rename_map["latitude"] = "y"
    if rename_map:
        da = da.rename(rename_map)
    if "x" not in da.dims or "y" not in da.dims:
        raise ValueError(f"Expected 'x' and 'y' dims, found {list(da.dims)}")
    if np.all(np.diff(da["y"].values) > 0):
        da = da.sortby("y", ascending=False)
    da = da.rio.write_crs("EPSG:4326", inplace=False)
    xv, yv = da["x"].values, da["y"].values
    dx = float(np.mean(np.diff(xv)))
    # y is normally descending here, so its raw spacing is negative. Use the
    # magnitude: the signed value yielded a positive pixel height (south-up)
    # and put the top edge half a pixel below the northernmost row.
    dy = abs(float(np.mean(np.diff(yv))))
    x0 = xv.min() - dx / 2.0
    y0 = yv.max() + dy / 2.0
    transform = Affine(dx, 0.0, x0, 0.0, -dy, y0)
    da = da.rio.write_transform(transform, inplace=False)
    return da

def compute_global_bounds(var: xr.DataArray, low_q=LOW_Q, high_q=HIGH_Q):
    """
    Return robust (vmin, vmax) scaling bounds for the whole of `var`.

    The bounds are the `low_q` / `high_q` quantiles over the time/y/x
    (or lat/lon/latitude/longitude) dimensions of `var`. If xarray's quantile
    raises, every value is loaded and np.nanquantile is used instead (this can
    need a lot of memory). If the result is non-finite or vmin == vmax, the
    plain NaN-skipping min/max of `var` is returned instead.
    """
    reduce_dims = [
        d for d in var.dims
        if d in ("time", "y", "x", "lat", "lon", "latitude", "longitude")
    ]
    try:
        q = var.quantile([low_q, high_q], dim=reduce_dims, skipna=True)
        vmin = float(q.sel(quantile=low_q).values)
        vmax = float(q.sel(quantile=high_q).values)
    except Exception:
        # In-memory fallback — may be large!
        lo, hi = np.nanquantile(var.values.astype(np.float32), [low_q, high_q])
        vmin, vmax = float(lo), float(hi)

    degenerate = not (np.isfinite(vmin) and np.isfinite(vmax)) or vmin == vmax
    if not degenerate:
        return vmin, vmax

    # Very flat or missing data: fall back to the global min/max.
    try:
        return float(var.min(skipna=True).values), float(var.max(skipna=True).values)
    except Exception:
        values = var.values.astype(np.float32)
        return float(np.nanmin(values)), float(np.nanmax(values))

def scale_to_uint8_fixed(da: xr.DataArray, vmin: float, vmax: float, nodata_val=NODATA_UINT8):
    """
    Map `da` linearly onto 1..255 using fixed bounds. Non-finite cells get `nodata_val`.

    Values are clipped to [vmin, vmax] and truncated (not rounded) to uint8.
    If the bounds are unusable (non-finite or equal), the slice's own
    NaN-skipping min/max are used instead. If those are unusable too, the
    whole result is `nodata_val`. Returns a numpy uint8 array shaped like `da`.
    """
    values = da.values.astype(np.float32)
    invalid = ~np.isfinite(values)

    def _usable(lo, hi):
        return np.isfinite(lo) and np.isfinite(hi) and lo != hi

    if not _usable(vmin, vmax):
        # avoid divide-by-zero
        vmin, vmax = np.nanmin(values), np.nanmax(values)
        if not _usable(vmin, vmax):
            return np.full(values.shape, nodata_val, dtype=np.uint8)

    unit = np.clip((values - vmin) / (vmax - vmin + 1e-12), 0.0, 1.0)
    out = (unit * 254.0 + 1.0).astype(np.uint8)
    out[invalid] = nodata_val
    return out

def main():
    """
    Export every timestep of the salinity variable as a GeoTIFF in OUT_DIR.

    One pair of scaling bounds (LOW_Q/HIGH_Q quantiles over the whole series,
    from compute_global_bounds) is shared by all timesteps, so 8-bit values are
    comparable between months. Files are named salinity_<YYYYMM>.tif (8-bit)
    or salinity_<YYYYMM>_float32.tif. In 8-bit mode, a JSON sidecar with the
    bounds and the file list is written as well.

    NOTE(review): each time slice is assumed to be 2-D after squeeze (i.e. the
    download holds a single depth level) — confirm for new downloads.

    Raises
    ------
    ValueError
        If the selected variable has no 'time' dimension.
    """
    open_kwargs = {"chunks": "auto"} if OPEN_WITH_CHUNKS else {}
    # The context manager closes the NetCDF file once every raster is written,
    # instead of leaving the handle open in the kernel.
    with xr.open_dataset(IN_PATH, **open_kwargs) as ds:
        var = pick_var(ds)
        if "time" not in var.dims:
            raise ValueError("Expected a 'time' dimension in salinity DataArray.")

        # --- Compute one pair of robust bounds across ALL timesteps ---
        global_vmin, global_vmax = compute_global_bounds(var, LOW_Q, HIGH_Q)
        print(f"Global scaling bounds ({LOW_Q:.2f}-{HIGH_Q:.2f} quantiles): vmin={global_vmin:.4f}, vmax={global_vmax:.4f}")

        # Sidecar metadata so consumers can map 1..255 back to native values.
        scaling_meta = {
            "variable": str(var.name),
            "quantiles": [LOW_Q, HIGH_Q],
            "global_vmin": float(global_vmin),
            "global_vmax": float(global_vmax),
            "nodata_uint8": NODATA_UINT8,
            "files": []
        }

        for i in range(var.sizes["time"]):
            slice_t = var.isel(time=i).squeeze(drop=True)
            da = ensure_xy(slice_t)

            # File label: YYYYMM from the slice timestamp, else the time index.
            if "time" in slice_t.coords:
                tlabel = np.datetime_as_string(slice_t["time"].values, unit="D").replace("-", "")[:6]  # YYYYMM
            else:
                tlabel = f"t{i:04d}"

            if WRITE_FLOAT32:
                out_path = os.path.join(OUT_DIR, f"salinity_{tlabel}_float32.tif")
                da = da.rio.write_nodata(np.nan, inplace=False)
                da.rio.to_raster(out_path, dtype="float32", tiled=True, compress=COMPRESS)
                print(f"Wrote {out_path}")
                scaling_meta["files"].append({"file": os.path.basename(out_path), "tindex": i, "tlabel": tlabel})
            else:
                scaled = scale_to_uint8_fixed(da, global_vmin, global_vmax, NODATA_UINT8)
                out_path = os.path.join(OUT_DIR, f"salinity_{tlabel}.tif")
                da8 = xr.DataArray(
                    scaled, dims=("y","x"), coords={"y": da["y"], "x": da["x"]},
                    name="salinity_uint8",
                    attrs={"long_name": "Sea Water Practical Salinity (scaled 8-bit, global bounds)",
                           "units": "unitless (0=nodata)",
                           "scale_vmin": global_vmin, "scale_vmax": global_vmax}
                ).rio.write_crs(da.rio.crs).rio.write_transform(da.rio.transform())
                da8 = da8.rio.write_nodata(NODATA_UINT8, inplace=False)
                da8.rio.to_raster(out_path, dtype="uint8", tiled=True, compress=COMPRESS)
                scaling_meta["files"].append({"file": os.path.basename(out_path), "tindex": i, "tlabel": tlabel})
                print(f"Wrote {out_path} using global scaling [{global_vmin:.4f}, {global_vmax:.4f}]")

    if not WRITE_FLOAT32:
        meta_path = os.path.join(OUT_DIR, "salinity_scaling_metadata.json")
        with open(meta_path, "w") as f:
            json.dump(scaling_meta, f, indent=2)
        print(f"Saved {meta_path}")

if __name__ == "__main__":
    main()


## Figure 6

<p align="center">
  <img src="Figs/climate_salinity_6.png" style="width:50%;">
</p>

Monthly data were retrieved from CMEMS [Mercator Ocean International / Copernicus Marine Service (2023). Global Ocean Physics Reanalysis (GLOBAL_MULTIYEAR_PHY_001_030) [Data set]. Copernicus Marine Service. https://doi.org/10.48670/moi-00021](https://data.marine.copernicus.eu/product/GLOBAL_MULTIYEAR_PHY_001_030/description).

In [None]:
#!/usr/bin/env python3
import os, json
import numpy as np
import xarray as xr
import rioxarray
from affine import Affine

# ---------------------------
# CONFIG
# ---------------------------
# NOTE(review): these module-level names (OUT_DIR, WRITE_FLOAT32, LOW_Q, HIGH_Q,
# COMPRESS, NODATA_UINT8, OPEN_WITH_CHUNKS) overwrite the same-named settings
# of the Figure 5 cell. Its main() reads them at call time, so re-run that
# cell's config before calling its main() again.
# Files holding the 'uo' and 'vo' variables that main() combines into speed.
U_PATH = "../Data/cmems_mod_glo_phy_my_0.083deg_P1M-m_1760473504514.nc"  # uo
V_PATH = "../Data/cmems_mod_glo_phy_my_0.083deg_P1M-m_1760473531668.nc"  # vo
OUT_DIR = "../Data/Figure_6_salinity"
WRITE_FLOAT32 = False        # True -> Float32 GeoTIFFs; False -> 8-bit scaled GeoTIFFs
LOW_Q, HIGH_Q = 0.02, 0.98   # robust percentile bounds used globally across ALL timesteps
COMPRESS = "DEFLATE"         # GeoTIFF compression
NODATA_UINT8 = 0             # 0 reserved as nodata in 8-bit output
OPEN_WITH_CHUNKS = True      # dask-friendly

# Created when this cell runs, so main() can write into it.
os.makedirs(OUT_DIR, exist_ok=True)

# ---------------------------
# helpers
# ---------------------------
def _ensure_xy(da: xr.DataArray) -> xr.DataArray:
    """Rename lon/lat to x/y if needed, set CRS, ensure y desc, attach affine transform.

    The transform is north-up and built from the mean coordinate spacing, with
    coordinates treated as pixel centres.

    Raises
    ------
    ValueError
        If no 'x' or 'y' dimension is present after renaming.
    """
    rename_map = {}
    if "lon" in da.dims: rename_map["lon"] = "x"
    if "longitude" in da.dims: rename_map["longitude"] = "x"
    if "lat" in da.dims: rename_map["lat"] = "y"
    if "latitude" in da.dims: rename_map["latitude"] = "y"
    if rename_map:
        da = da.rename(rename_map)
    if "x" not in da.dims or "y" not in da.dims:
        raise ValueError(f"Expected 'x' and 'y' dims, found {list(da.dims)}")
    # y descending (north -> south)
    if np.all(np.diff(da["y"].values) > 0):
        da = da.sortby("y", ascending=False)
    # georef
    da = da.rio.write_crs("EPSG:4326", inplace=False)
    xv, yv = da["x"].values, da["y"].values
    dx = float(np.mean(np.diff(xv)))
    # y is normally descending here, so its raw spacing is negative. Use the
    # magnitude: the signed value yielded a positive pixel height (south-up)
    # and put the top edge half a pixel below the northernmost row.
    dy = abs(float(np.mean(np.diff(yv))))
    x0 = xv.min() - dx / 2.0
    y0 = yv.max() + dy / 2.0
    transform = Affine(dx, 0.0, x0, 0.0, -dy, y0)
    da = da.rio.write_transform(transform, inplace=False)
    return da

def _pick_surface(da: xr.DataArray) -> xr.DataArray:
    """
    Reduce `da` to a single vertical level if it has a vertical dimension.

    The first dim named depth/lev/z/depthu/depthv/deptho (case-insensitive) is
    the vertical one. If it has a numeric coordinate, the level nearest 0 is
    chosen when 0 is among the (6-decimal rounded) levels; otherwise the level
    nearest the minimum coordinate is chosen. Without a numeric coordinate, or
    on any error, the first level is taken. Arrays without a vertical dim are
    returned unchanged.
    """
    vertical = [d for d in da.dims if d.lower() in ("depth","lev","z","depthu","depthv","deptho")]
    if not vertical:
        return da
    z = vertical[0]
    try:
        numeric_coord = z in da.coords and np.issubdtype(da.coords[z].dtype, np.number)
        if not numeric_coord:
            return da.isel({z: 0})
        levels = np.asarray(da.coords[z]).round(6)
        target = 0 if 0 in set(levels) else float(da.coords[z].values.min())
        return da.sel({z: target}, method="nearest")
    except Exception:
        return da.isel({z: 0})

def _compute_global_bounds(var: xr.DataArray, low_q=LOW_Q, high_q=HIGH_Q):
    """Return robust (vmin, vmax) over time/y/x; fall back to min/max if degenerate.

    Quantiles come from xarray. If that raises, np.nanquantile is applied to
    all values loaded in memory. The result is degenerate when it is
    non-finite or vmin >= vmax.
    """
    reduce_dims = [
        d for d in var.dims
        if d in ("time", "y", "x", "lat", "lon", "latitude", "longitude")
    ]
    try:
        q = var.quantile([low_q, high_q], dim=reduce_dims, skipna=True)
        vmin = float(q.sel(quantile=low_q).values)
        vmax = float(q.sel(quantile=high_q).values)
    except Exception:
        lo, hi = np.nanquantile(var.values.astype(np.float32), [low_q, high_q])
        vmin, vmax = float(lo), float(hi)

    usable = np.isfinite(vmin) and np.isfinite(vmax) and not vmin >= vmax
    if usable:
        return vmin, vmax

    # fallback to global min/max
    try:
        return float(var.min(skipna=True).values), float(var.max(skipna=True).values)
    except Exception:
        values = var.values.astype(np.float32)
        return float(np.nanmin(values)), float(np.nanmax(values))

def _scale_to_uint8_fixed(da: xr.DataArray, vmin: float, vmax: float, nodata_val=NODATA_UINT8):
    """Scale with fixed vmin/vmax to 8-bit (1..255), reserving `nodata_val` for non-finite cells.

    Unusable bounds (non-finite or vmin >= vmax) are replaced by the slice's
    own NaN-skipping min/max. If those are unusable too, every cell is
    `nodata_val`. Values are clipped and truncated to uint8.
    """
    values = da.values.astype(np.float32)
    invalid = ~np.isfinite(values)

    def _bad(lo, hi):
        return not np.isfinite(lo) or not np.isfinite(hi) or lo >= hi

    if _bad(vmin, vmax):
        vmin, vmax = np.nanmin(values), np.nanmax(values)
        if _bad(vmin, vmax):
            return np.full(values.shape, nodata_val, dtype=np.uint8)

    fraction = np.clip((values - vmin) / (vmax - vmin + 1e-12), 0.0, 1.0)
    out = (fraction * 254.0 + 1.0).astype(np.uint8)  # 1..255; 0 reserved as nodata
    out[invalid] = nodata_val
    return out

# ---------------------------
# main
# ---------------------------
def main():
    """
    Export monthly current speed, sqrt(uo**2 + vo**2), as GeoTIFFs in OUT_DIR.

    Processing steps:
    - uo and vo are read from U_PATH/V_PATH and reduced to one vertical level
      with _pick_surface.
    - They are inner-aligned and combined into a speed field.
    - One pair of LOW_Q/HIGH_Q quantile bounds over the whole series scales
      every month to 8-bit. Float32 values are written instead when
      WRITE_FLOAT32 is set.
    - In 8-bit mode, a JSON sidecar records the bounds and the file list.

    Raises
    ------
    ValueError
        If 'uo'/'vo' are missing, or either lacks a 'time' dimension.
    """
    open_kwargs = {"chunks": "auto"} if OPEN_WITH_CHUNKS else {}
    # Merge u and v by coords. The context manager closes both files once every
    # raster is written, instead of leaving them open in the kernel.
    with xr.open_mfdataset([U_PATH, V_PATH], combine="by_coords", **open_kwargs) as ds:
        if "uo" not in ds.data_vars or "vo" not in ds.data_vars:
            raise ValueError(f"Could not find variables 'uo' and 'vo' in the provided datasets. Found: {list(ds.data_vars)}")

        # Surface slice if needed
        uo_sfc = _pick_surface(ds["uo"])
        vo_sfc = _pick_surface(ds["vo"])

        # Ensure time alignment
        if "time" not in uo_sfc.dims or "time" not in vo_sfc.dims:
            raise ValueError("Expected a 'time' dimension in both uo and vo.")
        # align by time intersection to be safe
        uo_sfc, vo_sfc = xr.align(uo_sfc, vo_sfc, join="inner")

        # Compute speed magnitude (lazy when the inputs are dask-backed)
        speed = xr.apply_ufunc(
            lambda a, b: np.sqrt(a*a + b*b),
            uo_sfc, vo_sfc,
            dask="parallelized", output_dtypes=[np.float32]
        )
        speed.name = "current_speed"
        speed.attrs.update({"long_name": "Ocean current speed", "units": "m s-1"})

        # Compute one pair of robust bounds across ALL timesteps
        global_vmin, global_vmax = _compute_global_bounds(speed, LOW_Q, HIGH_Q)
        print(f"Global speed scaling bounds ({LOW_Q:.2f}-{HIGH_Q:.2f} quantiles): "
              f"vmin={global_vmin:.5f}, vmax={global_vmax:.5f}")

        # Metadata holder
        scaling_meta = {
            "variable": "current_speed",
            "source_vars": ["uo","vo"],
            "units": "m s-1",
            "quantiles": [LOW_Q, HIGH_Q],
            "global_vmin": float(global_vmin),
            "global_vmax": float(global_vmax),
            "nodata_uint8": NODATA_UINT8,
            "files": []
        }

        # Iterate each time slice
        for i in range(speed.sizes["time"]):
            slice_t = speed.isel(time=i).squeeze(drop=True)

            # Ensure x/y + georeferencing
            da = _ensure_xy(slice_t)

            # Timestamp label YYYYMM (falls back to the time index)
            if "time" in slice_t.coords:
                tlabel = np.datetime_as_string(slice_t["time"].values, unit="D").replace("-", "")[:6]
            else:
                tlabel = f"t{i:04d}"

            if WRITE_FLOAT32:
                out_path = os.path.join(OUT_DIR, f"current_speed_{tlabel}_float32.tif")
                da = da.rio.write_nodata(np.nan, inplace=False)
                da.rio.to_raster(out_path, dtype="float32", tiled=True, compress=COMPRESS)
                print(f"Wrote {out_path}")
                scaling_meta["files"].append({"file": os.path.basename(out_path), "tindex": i, "tlabel": tlabel})
            else:
                scaled = _scale_to_uint8_fixed(da, global_vmin, global_vmax, NODATA_UINT8)
                out_path = os.path.join(OUT_DIR, f"current_speed_{tlabel}.tif")
                da8 = xr.DataArray(
                    scaled, dims=("y","x"), coords={"y": da["y"], "x": da["x"]},
                    name="current_speed_uint8",
                    attrs={
                        "long_name": "Ocean current speed (scaled 8-bit, global bounds)",
                        "units": "unitless (0=nodata)",
                        "scale_vmin": global_vmin,
                        "scale_vmax": global_vmax
                    }
                ).rio.write_crs(da.rio.crs).rio.write_transform(da.rio.transform())
                da8 = da8.rio.write_nodata(NODATA_UINT8, inplace=False)
                da8.rio.to_raster(out_path, dtype="uint8", tiled=True, compress=COMPRESS)
                scaling_meta["files"].append({"file": os.path.basename(out_path), "tindex": i, "tlabel": tlabel})
                print(f"Wrote {out_path} using global scaling [{global_vmin:.5f}, {global_vmax:.5f}]")

    # Save sidecar metadata (for legends, rescaling)
    if not WRITE_FLOAT32:
        meta_path = os.path.join(OUT_DIR, "current_speed_scaling_metadata.json")
        with open(meta_path, "w") as f:
            json.dump(scaling_meta, f, indent=2)
        print(f"Saved {meta_path}")

if __name__ == "__main__":
    main()


# Acidity

## Figure 1

<p align="center">
  <img src="Figs/climate_acidity_1.png" style="width:50%;">
</p>

**This figure plots the annual mean pH (`ph_total`, total scale) from OceanSODA-ETHZ, together with its linear trend over the period covered by the data file (1982–2022).**

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
from scipy.stats import linregress

# Annual global-mean pH with a least-squares linear trend.
with xr.open_dataset("../Data/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc") as ph_ds:
    # Grid-cell area shrinks with cos(latitude). Without these weights, a plain
    # lat/lon mean over-represents high-latitude cells. This matches the
    # weighting used for the Figure 5 time series in this section.
    lat_weights = np.cos(np.deg2rad(ph_ds['lat']))
    acid_df = (
        ph_ds['ph_total']
        .weighted(lat_weights)
        .mean(dim=['lat', 'lon'])
        .resample(time='Y')
        .mean()
        .load()  # materialize before the file is closed
    )

# Create a pandas DataFrame with these columns
df = pd.DataFrame({
    'time': acid_df['time'].values,
    'ph_total': acid_df.values,
})

# Convert 'time' to datetime
df['time'] = pd.to_datetime(df['time'])

# Convert datetime to a numerical value for linear regression (using ordinal
# format), so the slope is in pH units per day.
df['time_ordinal'] = df['time'].map(pd.Timestamp.toordinal)

# Perform linear regression to find the slope and intercept
slope, intercept, _, _, _ = linregress(df['time_ordinal'], df['ph_total'])

# Calculate the trend line (y = mx + b) for each time point
df['linear_trend'] = intercept + slope * df['time_ordinal']

df[['time','ph_total','linear_trend']].to_csv("../Data/Figure_1_acidity.csv")

# Display the updated DataFrame
df[['time','ph_total','linear_trend']]

## Figure 2

<p align="center">
  <img src="Figs/climate_acidity_2.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
import rioxarray

# rio.reproject expects a rasterio.enums.Resampling member. rasterio is already
# required by rioxarray and by the notebook's imports.
from rasterio.enums import Resampling

acid_df = xr.open_dataset("../Data/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")

# Per-pixel linear pH trend: annual means, then linregress (see calculate_trend_df).
trend_significance_ds = calculate_trend_df(acid_df['ph_total'])

# --------------------
# CONFIG
# --------------------
# NOTE(review): unlike most outputs of this section, these rasters go to the
# working directory rather than ../Data — confirm the intended location.
OUTPUT_TIF_8BIT   = "./Figure_2_acidity.tif"
OUTPUT_TIF_FLOAT  = "./acidity_trend_float32.tif"
USE_PERCENTILES   = True      # scale the 8-bit layer to the P_LOW..P_HIGH percentiles
P_LOW, P_HIGH     = 2, 98
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # if USE_PERCENTILES=False

# 1) Select the dataarray: per-pixel pH trend
da = trend_significance_ds['trend']

# 2) Rename to x/y if needed
rename_map = {}
if "lat" in da.dims or "lat" in da.coords: rename_map["lat"] = "y"
if "latitude" in da.dims or "latitude" in da.coords: rename_map["latitude"] = "y"
if "lon" in da.dims or "lon" in da.coords: rename_map["lon"] = "x"
if "longitude" in da.dims or "longitude" in da.coords: rename_map["longitude"] = "x"
if rename_map:
    da = da.rename(rename_map)

# 3) Ensure y is north→south (descending)
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)
# ensure (y, x) order
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")

# 4) Register spatial dims & force EPSG:4326
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)

# If CRS is missing, assume geographic lon/lat
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# If CRS is not 4326, reproject it so the saved rasters are truly EPSG:4326
if str(da.rio.crs) != "EPSG:4326":
    # Bilinear resampling suits a continuous field such as this trend map.
    da = da.rio.reproject("EPSG:4326", resampling=Resampling.bilinear)

# Reassert dims order after any reprojection (just in case)
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")


# 5) Compute scaling range (Dask-safe)
import dask.array as dsa

def compute_percentiles_safe(da, p_low, p_high):
    """
    Return (vmin, vmax): the `p_low` / `p_high` NaN-skipping percentiles of `da`.

    Chunked (dask-backed) arrays are first tried with dask's nanpercentile.
    If that fails, or `da` is not chunked, the percentiles come from a strided
    subsample loaded into memory. The stride is len // 512 (at least 1) along
    each of y and x.
    """
    if getattr(da, "chunks", None):
        try:
            lo = float(dsa.nanpercentile(da.data, p_low).compute())
            hi = float(dsa.nanpercentile(da.data, p_high).compute())
        except Exception:
            pass  # fall through to the subsampled estimate below
        else:
            return lo, hi

    def _stride(dim):
        return max(int(da.sizes[dim] // 512), 1) if dim in da.dims else 4

    window = {
        dim: slice(0, None, _stride(dim)) if dim in da.dims else slice(None)
        for dim in ("y", "x")
    }
    sample = da.isel(window).load()  # small enough to load
    values = sample.values
    return float(np.nanpercentile(values, p_low)), float(np.nanpercentile(values, p_high))

# Choose the 8-bit scaling range: robust percentiles of the data or fixed bounds.
if USE_PERCENTILES:
    vmin, vmax = compute_percentiles_safe(da, P_LOW, P_HIGH)
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")


# 6) Build the 8-bit layer with xr.where (no boolean indexing)
# Reserve 0 for NoData; valid cells map to [1,255] (truncated, clipped to range)
norm = ((da - vmin) / (vmax - vmin)).clip(0, 1)
scaled_da = xr.where(np.isfinite(da), norm * 254.0 + 1.0, 0.0).astype("uint8")

# Clear inherited CF encodings so they do not override the uint8 nodata/dtype
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    scaled_da.encoding.pop(k, None)

# Set nodata appropriate for uint8 (0)
scaled_da = scaled_da.rio.write_nodata(0, encoded=False, inplace=False)

# 7a) Write 8-bit GeoTIFF (works with Dask; will stream-chunk if array is chunked)
scaled_da.rio.to_raster(
    OUTPUT_TIF_8BIT,
    dtype="uint8",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# 7b) Write float32 GeoTIFF with native values (non-finite -> NaN)
daf = da.where(np.isfinite(da)).astype("float32")
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    daf.encoding.pop(k, None)
# Use NaN as nodata for float32
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)

daf.rio.to_raster(
    OUTPUT_TIF_FLOAT,
    dtype="float32",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# The trend comes from linregress over yearly means (calculate_trend_df), so
# its native unit is pH per year (previously mislabelled "SLA units").
print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"8-bit scale: vmin={vmin:.4f}, vmax={vmax:.4f} (native units: pH per year)."
)


## Figure 3

<p align="center">
  <img src="Figs/climate_acidity_3.png" style="width:50%;">
</p>

In [None]:
# Open the OceanSODA-ETHZ product. calculate_trend_df averages each year, then
# fits a per-pixel linear regression of pH (trend, intercept, p-value).
acid_df = xr.open_dataset("../Data/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")

trend_significance_ds = calculate_trend_df(acid_df['ph_total'])

# area_trend is defined elsewhere in this notebook. Presumably it turns the
# trend map into per-region polygons with summary attributes — confirm.
area_df = area_trend(trend_significance_ds)

# Save the GeoDataFrame to a GeoJSON file
area_df.to_file("../Data/Figure_3_acidity.geojson",driver="GeoJSON")

# Release the dataset and GeoDataFrame. `trend_significance_ds` (the pH trend)
# is not deleted and stays available to later cells.
del acid_df, area_df

## Figure 4

<p align="center">
  <img src="Figs/climate_acidity_4.png" style="width:50%;">
</p>

In [None]:
# Reef-forming coral footprint polygons.
corals = gpd.read_file(
    "../Data/REEF_FORMING_CORALS/reef_forming_corals_dissolved.shp"
)

# Combine the per-pixel pH trend with the coral footprint. The helper is
# defined elsewhere in this notebook; presumably it aggregates per IHO sea
# from SEAS_DF — confirm. The pH trend is `trend_significance_ds`, computed
# from OceanSODA `ph_total` in the Figure 3 cell above, so run that cell first.
coral_by_sea = area_coral_acidification_by_sea(
    trend_significance_ds=trend_significance_ds,
    coral_gdf=corals,
    SEAS_DF=SEAS_DF,
)

# NOTE(review): unlike most outputs of this section, this file is written to
# the working directory rather than ../Data — confirm the intended location.
out_path = "Figure_4_acidity.geojson"

coral_by_sea.to_file(
    out_path,
    driver="GeoJSON"
)

print(f"Saved GeoJSON to: {out_path}")

## Figure 5

<p align="center">
  <img src="Figs/climate_acidity_5.png" style="width:50%;">
</p>

In [None]:
import xarray as xr
import numpy as np
import json

co2_df = xr.open_dataset("../Data/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")

# cos(latitude) area weights for the spatial average.
weights = np.cos(np.deg2rad(co2_df['lat']))

# Area-weighted lat/lon mean of pH and spco2: one row per timestep, with time
# as a regular column.
df_out = (
    co2_df[['ph_total', 'spco2']]
    .weighted(weights)
    .mean(dim=['lat', 'lon'])
    .to_pandas()
    .reset_index()
)

# JSON records (a list of {"time", "spco2", "ph_total"} objects, ISO dates).
out_json = df_out[['time', 'spco2', 'ph_total']].to_json(orient="records", date_format="iso")

with open("../Data/Figure_5_acidity.json", "w") as f:
    f.write(out_json)


## Figure 6

<p align="center">
  <img src="Figs/climate_acidity_6.png" style="width:50%;">
</p>

In [None]:
import sqlite3
import pandas as pd
import geopandas as gpd

from contextlib import closing
from pathlib import Path

# --- paths ---
db_path = "../Data/Global_Coral_Bleaching_Database_SQLite_11_24_21.db"
out_geojson = "../Data/Figure_6_acidity.geojson"

# Distinct sites with at least one sample reporting a non-zero bleaching
# percentage and a known latitude/longitude.
BLEACHED_SITES_QUERY = """
SELECT DISTINCT
    s.Site_ID,
    s.Site_Name,
    s.Latitude_Degrees AS lat,
    s.Longitude_Degrees AS lon
FROM Site_Info_tbl s
JOIN Sample_Event_tbl se ON s.Site_ID = se.Site_ID
JOIN Bleaching_tbl b ON se.Sample_ID = b.Sample_ID
WHERE b.Percent_Bleached IS NOT NULL
  AND b.Percent_Bleached > 0
  AND s.Latitude_Degrees IS NOT NULL
  AND s.Longitude_Degrees IS NOT NULL
"""


def _decode_if_bytes(value):
    """Decode bytes/bytearray as UTF-8 (dropping undecodable bytes); pass other values through."""
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8", errors="ignore")
    return value


def load_bleached_sites(path):
    """Return the bleached sites (Site_ID, Site_Name, lat, lon) from the SQLite DB at `path`.

    The database is opened read-only: a wrong path raises
    sqlite3.OperationalError instead of silently creating an empty file. The
    connection is closed even if the query fails. Bytes values in object
    columns are decoded to str.
    """
    uri = Path(path).resolve().as_uri() + "?mode=ro"
    with closing(sqlite3.connect(uri, uri=True)) as conn:
        sites = pd.read_sql_query(BLEACHED_SITES_QUERY, conn)
    for col in sites.columns:
        if sites[col].dtype == object:
            sites[col] = sites[col].apply(_decode_if_bytes)
    return sites


if __name__ == "__main__":
    bleach_sites = load_bleached_sites(db_path)
    print(f"Found {len(bleach_sites)} sites with bleaching records.")

    # --- convert to GeoDataFrame ---
    gdf = gpd.GeoDataFrame(
        bleach_sites,
        geometry=gpd.points_from_xy(bleach_sites["lon"], bleach_sites["lat"]),
        crs="EPSG:4326"
    )

    # --- save as GeoJSON ---
    gdf.to_file(out_geojson, driver="GeoJSON")
    print(f"Saved to {out_geojson}")


## Figure 7

<p align="center">
  <img src="Figs/climate_acidity_7.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
import rioxarray

# rio.reproject expects a rasterio.enums.Resampling member. rasterio is already
# required by rioxarray and by the notebook's imports.
from rasterio.enums import Resampling

co2_df = xr.open_dataset("../Data/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")

# --------------------
# CONFIG
# --------------------
OUTPUT_TIF_8BIT   = "../Data/Figure_7_acidity.tif"
# NOTE(review): despite "trend" in the name, this raster holds the time-mean
# spco2 (see step 1) — consider renaming.
OUTPUT_TIF_FLOAT  = "../Data/co2_trend_float32.tif"
USE_PERCENTILES   = True      # scale the 8-bit layer to the P_LOW..P_HIGH percentiles
P_LOW, P_HIGH     = 2, 98
# NOTE(review): these fixed bounds look copied from a trend map. Presumably
# spco2 is in µatm (values in the hundreds) — adjust before disabling
# USE_PERCENTILES.
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # if USE_PERCENTILES=False

# 1) Select the dataarray: per-pixel time-mean of spco2 over the whole record
da = co2_df['spco2'].mean(dim='time')

# 2) Rename to x/y if needed
rename_map = {}
if "lat" in da.dims or "lat" in da.coords: rename_map["lat"] = "y"
if "latitude" in da.dims or "latitude" in da.coords: rename_map["latitude"] = "y"
if "lon" in da.dims or "lon" in da.coords: rename_map["lon"] = "x"
if "longitude" in da.dims or "longitude" in da.coords: rename_map["longitude"] = "x"
if rename_map:
    da = da.rename(rename_map)

# 3) Ensure y is north→south (descending)
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)
# ensure (y, x) order
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")

# 4) Register spatial dims & force EPSG:4326
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)

# If CRS is missing, assume geographic lon/lat
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# If CRS is not 4326, reproject it so the saved rasters are truly EPSG:4326
if str(da.rio.crs) != "EPSG:4326":
    # Bilinear resampling suits a continuous field such as mean spco2.
    da = da.rio.reproject("EPSG:4326", resampling=Resampling.bilinear)

# Reassert dims order after any reprojection (just in case)
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")


# 5) Compute scaling range (Dask-safe)
import dask.array as dsa

def compute_percentiles_safe(da, p_low, p_high, max_samples_per_dim=512):
    """
    Return ``(vmin, vmax)``: the ``p_low``-th and ``p_high``-th NaN-ignoring percentiles of ``da``.

    For Dask-backed arrays, ``dask.array.nanpercentile`` is tried first. If the array is not
    chunked, or that call fails (e.g. older dask/xarray), the percentiles are computed exactly
    on a strided subsample (every Nth cell along the ``y``/``x`` dims that exist), so only a
    small array is loaded into memory.

    Parameters
    ----------
    da : xarray.DataArray
        Field to summarise.
    p_low, p_high : float
        Percentiles in [0, 100].
    max_samples_per_dim : int, optional
        Approximate number of samples kept along each of ``y`` and ``x`` in the fallback
        (stride = ``max(size // max_samples_per_dim, 1)``). Default 512.

    Returns
    -------
    tuple of float
        ``(vmin, vmax)``; NaN if the sampled values are all NaN.
    """
    if getattr(da, "chunks", None):
        try:
            vmin = float(dsa.nanpercentile(da.data, p_low).compute())
            vmax = float(dsa.nanpercentile(da.data, p_high).compute())
            return vmin, vmax
        except Exception:
            pass  # deliberate best effort: fall back to the subsampled path

    # Stride only along spatial dims that actually exist. isel() rejects unknown dimension
    # names, so passing slice(None) for a missing y/x (as before) could never work.
    strides = {
        dim: max(da.sizes[dim] // max_samples_per_dim, 1)
        for dim in ("y", "x")
        if dim in da.dims
    }
    sample = da.isel({dim: slice(None, None, step) for dim, step in strides.items()}).load()
    values = sample.values
    return float(np.nanpercentile(values, p_low)), float(np.nanpercentile(values, p_high))

if USE_PERCENTILES:
    vmin, vmax = compute_percentiles_safe(da, P_LOW, P_HIGH)
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

# Refuse to build a colour scale from a non-finite or degenerate range
if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")


# 6) Build the 8-bit layer with xr.where (no boolean indexing)
# Reserve 0 for NoData; finite cells are clipped to [vmin, vmax] and mapped linearly onto
# [1, 255]. astype truncates, so 255 is only reached at or above vmax.
norm = ((da - vmin) / (vmax - vmin)).clip(0, 1)
scaled_da = xr.where(np.isfinite(da), norm * 254.0 + 1.0, 0.0).astype("uint8")

# Clear encodings inherited from the source file that do not apply to the uint8 output
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    scaled_da.encoding.pop(k, None)

# Set nodata appropriate for uint8 (0)
scaled_da = scaled_da.rio.write_nodata(0, encoded=False, inplace=False)

# 7a) Write 8-bit GeoTIFF (works with Dask; will stream-chunk if array is chunked)
scaled_da.rio.to_raster(
    OUTPUT_TIF_8BIT,
    dtype="uint8",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# 7b) Write float32 GeoTIFF with native values (non-finite cells become NaN)
daf = da.where(np.isfinite(da)).astype("float32")
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    daf.encoding.pop(k, None)
# Use NaN as nodata for float32
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)

daf.rio.to_raster(
    OUTPUT_TIF_FLOAT,
    dtype="float32",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# The data here is spco2, not sea-level anomaly: report the scale in its own units.
print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"8-bit scale: vmin={vmin:.4f}, vmax={vmax:.4f} (native spCO2 units)."
)


# Sea Level Rise

## Figure 1

<p align="center">
  <img src="Figs/climate_slr_1.png" style="width:50%;">
</p>

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import io
import requests
import numpy as np
from scipy import stats
from datetime import datetime, timedelta

# Load the data from the file, obtain unique hyperlink from https://sealevel.nasa.gov/
# NOTE(review): this is a signed, expiring URL (see `Expires=`) that embeds a user id and a
# signature. Regenerate it before each run and avoid committing it.
url = 'https://deotb6e7tfubr.cloudfront.net/s3-edaf5da92e0ce48fb61175c28b67e95d/podaac-ops-cumulus-protected.s3.us-west-2.amazonaws.com/NASA_SSH_GMSL_INDICATOR/NASA_SSH_GMSL_INDICATOR.txt?A-userid=ps4813&Expires=1758135329&Signature=h-v-~pNMAoeIxGWvEFIMUGYPX-1~zNUUVVD1ugzD7z9Jujn0kT3ZQBIgmcRPfaByPVVl2JhQiudnWl0sGQIPnJi1IrVCMwVd0Ye6KF~5-Wi6UuKTFJVSXae75hw0SJu4H64TiqpZWVCpRgEi-Q2LZ1kUJdVA2VNc4TMbbW2QVLIkazBXJHlZme~omHoLvGrAIF5GtDSEj2HTjZRCP9r6OBybnbNK7huNmSvJNPCt9y-GhqqB8TbUDkxpg78KBXkF-oOnu~pWLVUDtWatlv1UOhljqULhg4-QawOSohnBEEXoSZAKN~daGhAyccVjpAeaPR8VQpmS~ZUgugyMxoFxyQ__&Key-Pair-Id=K3RGFTW2DFGMID'
# Fetch the content. The timeout keeps the cell from hanging forever, and raise_for_status()
# surfaces HTTP errors (e.g. an expired signature) instead of parsing an error page.
response = requests.get(url, timeout=60)
response.raise_for_status()
content = response.text

# Split the content into lines (splitlines() also copes with CRLF line endings)
lines = content.splitlines()

# Find the line containing "Header_End"; the data table starts right after it
header_end_index = next((i for i, line in enumerate(lines) if "Header_End" in line), None)
if header_end_index is None:
    raise ValueError("'Header_End' marker not found in the GMSL file; check the URL and file format.")

# Read the whitespace-delimited data rows. The regex is a raw string: '\s' in a normal
# literal is an invalid escape sequence (DeprecationWarning/SyntaxWarning).
raw_data = pd.read_csv(io.StringIO('\n'.join(lines[header_end_index + 1:])),
                       sep=r'\s+',
                       header=None)

# Column 0 is the decimal-year stamp; column 2 is the series used as SLR
# (presumably GMSL per the NASA indicator header — confirm the column layout).
df = pd.DataFrame({
    'date': raw_data[0],
    'SLR': raw_data[2] - raw_data[2].iloc[0]  # Shifting SLR so that the first value is 0
})

# Function to convert fractional year to datetime (year, month, day only)
def fractional_year_to_datetime(year):
    """Convert a decimal year (e.g. ``2000.5``) to a ``datetime.date``.

    The fractional part is scaled by the length of that calendar year
    (365 or 366 days) and added to 1 January; the time of day is discarded.
    """
    whole = int(year)
    jan_first = datetime(whole, 1, 1)
    year_length_days = (datetime(whole + 1, 1, 1) - jan_first).days
    offset = timedelta(days=(year - whole) * year_length_days)
    return (jan_first + offset).date()

# Decimal-year stamps -> calendar dates (the helper drops the time of day)
df['date'] = df['date'].apply(fractional_year_to_datetime)

# Calendar year of each observation: the grouping key for the annual means
df['year'] = df['date'].map(lambda stamp: stamp.year)

# Annual mean of every numeric column (the object-dtype 'date' column is skipped)
df_grouped = (
    df.groupby('year')
    .mean(numeric_only=True)
    .reset_index()
)

# Ordinary least-squares trend of annual-mean SLR versus year
slope, intercept, r_value, p_value, std_err = stats.linregress(
    df_grouped['year'], df_grouped['SLR']
)
df_grouped['linear_trend'] = slope * df_grouped['year'] + intercept

# Persist the annual series and its fitted trend as JSON records
df_grouped.to_json("../Data/Figure_1_SLR.json", orient="records", date_format="iso")

# Preview the first rows of the annual table
df_grouped.head()


## Figure 2

<p align="center">
  <img src="Figs/climate_slr_2.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
import rioxarray
# The Resampling enum lives in rasterio; `rioxarray.rio.reproject.Resampling` does not
# exist and raised AttributeError whenever the reprojection branch below was taken.
from rasterio.enums import Resampling

# Copernicus Climate Change Service, Climate Data Store, (2018): Sea level daily gridded data from satellite observations for the global ocean from 1993 to present. Copernicus Climate Change Service (C3S) Climate Data Store (CDS)
SLR_df = xr.open_mfdataset("../Data/dataset-satellite-sea-level-global-dc7f92ea-2d3e-4fc6-b767-836a5b8c0bff/*.nc")

# --------------------
# CONFIG
# --------------------
OUTPUT_TIF_8BIT   = "./Figure_2_SLR.tif"
OUTPUT_TIF_FLOAT  = "./sla_mean_float32.tif"
USE_PERCENTILES   = True
P_LOW, P_HIGH     = 2, 98
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # if USE_PERCENTILES=False

# 1) Select the DataArray: mean SLA over time (if a time axis is present)
da = SLR_df["sla"]
if "time" in da.dims:
    da = da.mean(dim="time")

# 2) Rename lat/lon-style coordinates to the x/y names used for rioxarray below
rename_map = {
    src: dst
    for src, dst in (("lat", "y"), ("latitude", "y"), ("lon", "x"), ("longitude", "x"))
    if src in da.dims or src in da.coords
}
if rename_map:
    da = da.rename(rename_map)

# 3) Ensure y is north→south (descending) and the dims are ordered (y, x)
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")

# 4) Register spatial dims & make sure the data is in EPSG:4326
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)

# If CRS is missing, assume geographic lon/lat
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# Reproject only when the CRS is not EPSG:4326. Comparing EPSG codes (rather than str(crs))
# avoids a needless resample when an equivalent CRS is stored as WKT.
if da.rio.crs.to_epsg() != 4326:
    # Bilinear resampling for continuous SLA
    da = da.rio.reproject("EPSG:4326", resampling=Resampling.bilinear)

# Reassert dims order after any reprojection (just in case)
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")


# 5) Compute scaling range (Dask-safe)
import dask.array as dsa

def compute_percentiles_safe(da, p_low, p_high, max_samples_per_dim=512):
    """
    Return ``(vmin, vmax)``: the ``p_low``-th and ``p_high``-th NaN-ignoring percentiles of ``da``.

    For Dask-backed arrays, ``dask.array.nanpercentile`` is tried first. If the array is not
    chunked, or that call fails (e.g. older dask/xarray), the percentiles are computed exactly
    on a strided subsample (every Nth cell along the ``y``/``x`` dims that exist), so only a
    small array is loaded into memory.

    Parameters
    ----------
    da : xarray.DataArray
        Field to summarise.
    p_low, p_high : float
        Percentiles in [0, 100].
    max_samples_per_dim : int, optional
        Approximate number of samples kept along each of ``y`` and ``x`` in the fallback
        (stride = ``max(size // max_samples_per_dim, 1)``). Default 512.

    Returns
    -------
    tuple of float
        ``(vmin, vmax)``; NaN if the sampled values are all NaN.
    """
    if getattr(da, "chunks", None):
        try:
            vmin = float(dsa.nanpercentile(da.data, p_low).compute())
            vmax = float(dsa.nanpercentile(da.data, p_high).compute())
            return vmin, vmax
        except Exception:
            pass  # deliberate best effort: fall back to the subsampled path

    # Stride only along spatial dims that actually exist. isel() rejects unknown dimension
    # names, so passing slice(None) for a missing y/x (as before) could never work.
    strides = {
        dim: max(da.sizes[dim] // max_samples_per_dim, 1)
        for dim in ("y", "x")
        if dim in da.dims
    }
    sample = da.isel({dim: slice(None, None, step) for dim, step in strides.items()}).load()
    values = sample.values
    return float(np.nanpercentile(values, p_low)), float(np.nanpercentile(values, p_high))

# 5) Scaling range for the 8-bit rendering: robust percentiles or the fixed bounds
vmin, vmax = (
    compute_percentiles_safe(da, P_LOW, P_HIGH)
    if USE_PERCENTILES
    else (float(VMIN_FIXED), float(VMAX_FIXED))
)

range_ok = np.isfinite(vmin) and np.isfinite(vmax) and vmin < vmax
if not range_ok:
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")


# 6) Quantise to uint8 via xr.where: 0 is reserved for NoData, finite cells land in 1..255
span = vmax - vmin
fraction = ((da - vmin) / span).clip(0, 1)
scaled_da = xr.where(np.isfinite(da), fraction * 254.0 + 1.0, 0.0).astype("uint8")

# Encoding entries inherited from the source NetCDF that must not leak into the GeoTIFFs
stale_encoding = ("_FillValue", "missing_value", "scale_factor", "add_offset")
for key in stale_encoding:
    scaled_da.encoding.pop(key, None)
scaled_da = scaled_da.rio.write_nodata(0, encoded=False, inplace=False)

# Shared GeoTIFF layout for both outputs: LZW-compressed 256x256 tiles
gtiff_options = {"compress": "LZW", "tiled": True, "blockxsize": 256, "blockysize": 256}

# 7a) 8-bit GeoTIFF (streams chunk by chunk when the array is Dask-backed)
scaled_da.rio.to_raster(OUTPUT_TIF_8BIT, dtype="uint8", **gtiff_options)

# 7b) float32 GeoTIFF with the native values; non-finite cells become NaN = nodata
daf = da.where(np.isfinite(da)).astype("float32")
for key in stale_encoding:
    daf.encoding.pop(key, None)
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)
daf.rio.to_raster(OUTPUT_TIF_FLOAT, dtype="float32", **gtiff_options)

print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"8-bit scale: vmin={vmin:.4f}, vmax={vmax:.4f} (native SLA units)."
)


## Figure 3

<p align="center">
  <img src="Figs/climate_slr_3.png" style="width:50%;">
</p>

The data were derived from [Horwath et al.](https://essd.copernicus.org/articles/14/411/2022/). Country-level data were pulled from [Sea Level Explorer](https://earth.gov/sealevel/sea-level-explorer/).

## Figures 4 and 9

<p align="center">
  <img src="Figs/climate_slr_4.png" style="width:50%;">
</p>

<p align="center">
  <img src="Figs/climate_slr_9.png" style="width:50%;">
</p>

Global trends were retrieved from [NOAA](https://tidesandcurrents.noaa.gov/sltrends/).

## Figures 5 and 10

<p align="center">
  <img src="Figs/climate_slr_5.png" style="width:50%;">
</p>

In [None]:
import requests
import pandas as pd
import geopandas as gpd
from pathlib import Path
import time

# === CONFIGURATION — adapt paths/names to your data ===
EEZ_SHP = "../Data/Intersect_EEZ_IHO_v5_20241010/Intersect_EEZ_IHO_v5_20241010.shp"  # adjust path
OUTPUT_DIR = Path("sea_level_flood_data")
OUTPUT_DIR.mkdir(exist_ok=True)
MASTER_CSV = Path("Figure_5_SLR.csv")

BASE_URL = "https://d3qt3aobtsas2h.cloudfront.net/edge/ws/search/sealevelgovglobal"
REQUEST_TIMEOUT_S = 60
REQUEST_PAUSE_S = 0.5  # pause after every request attempt, successful or not

# Flood-day columns expected in each downloaded table (summed below)
fldays_cols = [
    "High Tide Flood Days Minor",
    "High Tide Flood Days Moderate",
    "High Tide Flood Days Major"
]

# Load EEZ shapefile — must contain both EEZ_MRGID and ISO_SOV1
eez = gpd.read_file(EEZ_SHP)
required = ["EEZ_MRGID", "ISO_SOV1"]
for field in required:
    if field not in eez.columns:
        raise ValueError(f"Shapefile missing required field: {field}. Found columns: {eez.columns}")

# One request per distinct (EEZ id, sovereign ISO) pair
eez_records = eez[["EEZ_MRGID", "ISO_SOV1"]].drop_duplicates()
print("Will attempt to fetch data for", len(eez_records), "EEZ records")

master = []

for _, rec_meta in eez_records.iterrows():
    mrg = rec_meta["EEZ_MRGID"]
    iso = rec_meta["ISO_SOV1"]

    # try/finally so the throttle runs on every path, including each `continue`
    # (previously failures skipped the pause and fired requests back-to-back).
    try:
        # Request CSV output for this EEZ
        params = {"mrg_id": mrg, "format": "csv"}
        try:
            resp = requests.get(BASE_URL, params=params, timeout=REQUEST_TIMEOUT_S)
            resp.raise_for_status()
        except requests.RequestException as e:
            print(f"[{iso} / MRGID {mrg}] Download failed: {e}")
            continue

        content = resp.content
        content_type = resp.headers.get("Content-Type", "")
        # If the service returns a workbook anyway, detect it by MIME type or by the ZIP
        # "PK" signature (.xlsx files are ZIP archives).
        if "spreadsheetml" in content_type or content.startswith(b"PK"):
            ext = "xlsx"
        else:
            ext = "csv"

        # Save the raw response; the reader below parses it from disk
        fn = OUTPUT_DIR / f"{iso}_{mrg}.{ext}"
        fn.write_bytes(content)

        # Read the flood table ("Recent-Flooding" sheet for workbooks)
        try:
            if ext == "xlsx":
                df = pd.read_excel(fn, sheet_name="Recent-Flooding")
            else:
                df = pd.read_csv(fn)
        except Exception as e:
            print(f"[{iso} / MRGID {mrg}] Could not read Recent-Flooding sheet: {e}")
            continue

        # str() guards against non-string headers (e.g. numeric column labels)
        df.columns = [str(c).strip() for c in df.columns]

        missing = [c for c in fldays_cols if c not in df.columns]
        if missing:
            print(f"[{iso} / MRGID {mrg}] Missing Flood-Days columns: {missing}. Available columns: {df.columns.tolist()}")
            continue

        # Sum minor + moderate + major per row, then over all rows (years).
        # NOTE(review): this assumes the three categories are disjoint counts; if the source
        # reports cumulative counts (minor including moderate/major), this double counts.
        total_flood_days = df[fldays_cols].sum(axis=1).sum()

        master.append({
            "ISO_SOV1": iso,
            "MRGID": mrg,
            "Number_of_High_Tide_Flood_Days": total_flood_days
        })
    finally:
        time.sleep(REQUEST_PAUSE_S)

# Save master CSV
if master:
    pd.DataFrame(master).to_csv(MASTER_CSV, index=False)
    print("Saved flood-day summary to:", MASTER_CSV)
else:
    print("No flood-day data collected. Check for errors or missing sheets.")


## Figures 6 and 11

<p align="center">
  <img src="Figs/climate_slr_6.png" style="width:50%;">
</p>

Data were retrieved from Supplementary Data 1 of Kulp, S.A., Strauss, B.H. New elevation data triple estimates of global vulnerability to sea-level rise and coastal flooding. Nat Commun 10, 4844 (2019). https://doi.org/10.1038/s41467-019-12808-z.

## Figures 7 and 12

<p align="center">
  <img src="Figs/climate_slr_7.png" style="width:50%;">
</p>

<p align="center">
  <img src="Figs/climate_slr_10.png" style="width:50%;">
</p>

The mangrove extent raster was created from the Global Mangrove Watch Dataset from 2020. See the Restore Ecosystems notebook for more details.

In [None]:
import numpy as np
import xarray as xr
import rioxarray
import geopandas as gpd
from tqdm import tqdm

def latlon_cell_areas_km2(x, y):
    """
    Compute spherical grid-cell areas (km²) for a lat/lon grid.

    Parameters
    ----------
    x : array-like, shape (nx,)
        Cell-centre longitudes in degrees (monotonic).
    y : array-like, shape (ny,)
        Cell-centre latitudes in degrees (monotonic, ascending or descending).

    Returns
    -------
    numpy.ndarray, shape (ny, nx)
        Cell areas in km² on a sphere of radius 6371 km.

    Notes
    -----
    Cell widths come from ``np.gradient`` (centred differences inside, one-sided at the
    edges), so each axis needs at least two points. Latitude cell edges are clipped to
    [-90°, 90°]: otherwise a row centred on a pole gets an edge beyond the pole, and since
    sin() folds back past ±90° that row's area collapses to ~0.
    """
    R = 6371.0  # Earth radius in km
    lon_rad = np.deg2rad(np.asarray(x, dtype=float))
    lat_rad = np.deg2rad(np.asarray(y, dtype=float))

    dlon = np.abs(np.gradient(lon_rad))   # (x,)
    dlat = np.abs(np.gradient(lat_rad))   # (y,)

    half_pi = np.pi / 2
    lat_upper = np.clip(lat_rad + 0.5 * dlat, -half_pi, half_pi)
    lat_lower = np.clip(lat_rad - 0.5 * dlat, -half_pi, half_pi)

    # Area of a lat/lon cell = R² · |sin(φ_upper) − sin(φ_lower)| · Δλ
    row_term = np.abs(np.sin(lat_upper) - np.sin(lat_lower))
    return (R ** 2) * row_term[:, None] * dlon[None, :]


def prep_aqueduct_inun_mask_from_tif(
    tif_path: str,
    depth_m_threshold: float = 0.0,
    assume_band: int = 1,
    to_crs: str = "EPSG:4326",
):
    """
    Build a 0/1 uint8 inundation mask from an Aqueduct Floods hazard GeoTIFF.

    The raster is read with nodata masked to NaN, reduced to one band, reprojected to
    ``to_crs`` when its CRS string differs, and thresholded.

    A cell is flagged (1) when its value is strictly positive AND ``>= depth_m_threshold``;
    every other cell, including NaN/nodata, is 0. The positivity test makes the default
    ``depth_m_threshold=0.0`` select "any modelled inundation" instead of every valid cell:
    cells whose depth is exactly 0 are never flagged.

    Parameters
    ----------
    tif_path : str
        Path to the GeoTIFF (values assumed to be inundation depth in metres).
    depth_m_threshold : float
        Minimum depth for a cell to count as inundated (e.g. 0.1 for >= 10 cm).
    assume_band : int
        Band selected if a ``band`` dimension remains after squeezing.
    to_crs : str
        Target CRS; compared textually with the raster's CRS.

    Returns
    -------
    xarray.DataArray
        uint8 mask with ``x``/``y`` spatial dims.

    Raises
    ------
    ValueError
        If the raster carries no CRS.
    """
    da = rioxarray.open_rasterio(tif_path, masked=True).squeeze()

    # If there's still a band dimension, select one
    if "band" in da.dims:
        da = da.sel(band=assume_band)

    # Ensure CRS is known
    if da.rio.crs is None:
        raise ValueError(
            "Raster has no CRS. Check the .tif metadata and set it via da.rio.write_crs(...)."
        )

    # Reproject for consistent vector clipping & area weighting downstream
    if str(da.rio.crs) != to_crs:
        da = da.rio.reproject(to_crs)

    # Register the spatial dims explicitly (open_rasterio already names them y/x)
    da = da.rio.set_spatial_dims(x_dim="x", y_dim="y")

    # NaN compares False on both sides, so nodata cells end up 0
    inundated = (da > 0) & (da >= depth_m_threshold)
    return inundated.astype("uint8")


def saltmarsh_slr_overlap_by_country_fast(
    slr_mask_u8,
    countries,
    saltmarsh_gdf,
    country_name_col="NAME",
    equal_area_crs="ESRI:54009",
):
    """
    Estimate, per country, how much saltmarsh area lies in cells flagged by a raster mask.

    For each country polygon, saltmarsh polygons are pre-filtered with a spatial index and
    an exact ``intersects`` test, clipped to the country and dissolved into one geometry,
    whose area (km²) is measured in ``equal_area_crs``. The mask and a per-cell area grid
    (from ``latlon_cell_areas_km2``) are then clipped to that geometry; the area-weighted
    share of clipped cells with mask value 1 is applied to the vector saltmarsh area.

    Parameters
    ----------
    slr_mask_u8 : xarray.DataArray
        0/1 mask with ``x``/``y`` coordinates. EPSG:4326 is written onto it unconditionally
        (no reprojection), so it must already be on a lon/lat grid.
    countries : geopandas.GeoDataFrame
        Country polygons; reprojected to EPSG:4326 internally.
    saltmarsh_gdf : geopandas.GeoDataFrame
        Saltmarsh polygons; reprojected to EPSG:4326 internally.
    country_name_col : str
        Column holding the country name; "unknown" is used when the column is absent.
    equal_area_crs : str
        CRS used to measure saltmarsh area; the ``/ 1e6`` conversion assumes its unit is metres.

    Returns
    -------
    geopandas.GeoDataFrame (EPSG:4326)
        One row per country that has saltmarsh: Country_Name, Total_Saltmarsh_Area_km2,
        Flood_Affected_Saltmarsh_Area_km2, Flood_Affected_Saltmarsh_Area_Percent, and the
        country geometry. Countries without saltmarsh (or zero area) are omitted.

    Notes
    -----
    Any exception while processing a country is printed and that country is skipped.
    If the clipped area grid sums to zero the fraction is 0. NOTE(review): rio.clip may
    instead raise when no cell overlaps the geometry, which lands in the except branch —
    confirm for small polygons.
    """
    slr_mask_u8 = slr_mask_u8.rio.write_crs("EPSG:4326")

    # Cell areas (km²) for EPSG:4326 grid
    cell_area = latlon_cell_areas_km2(
        slr_mask_u8["x"].values,
        slr_mask_u8["y"].values,
    )

    # Wrap the area grid on the mask's coordinates so both can be clipped identically
    cell_area_da = xr.DataArray(
        cell_area,
        dims=["y", "x"],
        coords={"y": slr_mask_u8["y"], "x": slr_mask_u8["x"]},
    ).rio.write_crs("EPSG:4326")

    # Prep vectors: reproject to EPSG:4326; buffer(0) is commonly used to clean invalid polygons
    countries = countries.to_crs("EPSG:4326").copy()
    countries["geometry"] = countries.geometry.buffer(0)

    saltmarsh = saltmarsh_gdf.to_crs("EPSG:4326").copy()
    saltmarsh["geometry"] = saltmarsh.geometry.buffer(0)

    # Spatial index so each country only tests saltmarsh polygons with overlapping bounds
    sindex = saltmarsh.sindex

    records = []

    for _, country in tqdm(countries.iterrows(), total=len(countries), desc="Processing countries"):
        country_name = country.get(country_name_col, "unknown")
        country_geom = country.geometry

        try:
            # 1) Saltmarsh candidates for this country: bounding-box hits from the index,
            #    then an exact intersects() test
            bbox_matches = list(sindex.intersection(country_geom.bounds))
            if not bbox_matches:
                continue

            saltmarsh_sub = saltmarsh.iloc[bbox_matches]
            saltmarsh_sub = saltmarsh_sub[saltmarsh_sub.intersects(country_geom)]
            if saltmarsh_sub.empty:
                continue

            # 2) Clip to the country and dissolve into a single geometry
            saltmarsh_in_country = gpd.clip(saltmarsh_sub, country_geom)
            saltmarsh_union = saltmarsh_in_country.unary_union

            saltmarsh_country_gdf = gpd.GeoDataFrame(
                [{"geometry": saltmarsh_union}], crs="EPSG:4326"
            )

            # 3) Saltmarsh area (km²) measured in the equal-area CRS
            saltmarsh_eq = saltmarsh_country_gdf.to_crs(equal_area_crs)
            true_saltmarsh_km2 = saltmarsh_eq.geometry.area.sum() / 1e6
            if true_saltmarsh_km2 <= 0:
                continue

            # 4) Raster clip (mask + cell areas) to the dissolved saltmarsh geometry
            mask_clip = slr_mask_u8.rio.clip(saltmarsh_country_gdf.geometry, drop=True)
            area_clip = cell_area_da.rio.clip(saltmarsh_country_gdf.geometry, drop=True)

            # Area-weighted fraction of clipped cells that the mask flags
            affected_pix_km2 = (mask_clip * area_clip).sum(dim=("y", "x")).item()
            total_pix_km2 = area_clip.sum(dim=("y", "x")).item()

            frac = affected_pix_km2 / total_pix_km2 if total_pix_km2 > 0 else 0.0
            print(country_name, ": ", frac)

            # Apply the raster fraction to the vector area
            records.append({
                "Country_Name": country_name,
                "Total_Saltmarsh_Area_km2": true_saltmarsh_km2,
                "Flood_Affected_Saltmarsh_Area_km2": frac * true_saltmarsh_km2,
                "Flood_Affected_Saltmarsh_Area_Percent": 100 * frac,
                "geometry": country_geom,
            })

        except Exception as e:
            # Best effort: report and move on to the next country
            print(f"Error processing {country_name}: {e}")

    return gpd.GeoDataFrame(records, crs="EPSG:4326")


# -------------------------
# Inputs
# -------------------------
FLOOD_TIF = "../Data/inuncoast_rcp8p5_wtsub_2080_rp0050_0_perc_50.tif"
SALTMARSH_SHP = "../Data/WCMC027_Saltmarsh_v6_1/01_Data/WCMC027_Saltmarshes_Py_v6_1.shp"
COUNTRY_SHP = "../Data/ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp"

# NOTE(review): the markdown above this cell mentions the Global Mangrove Watch mangrove
# extent, but this cell processes the WCMC saltmarsh layer; confirm which figure it feeds.

# Vector layers: saltmarsh polygons and country boundaries
saltmarsh = gpd.read_file(SALTMARSH_SHP)
countries = gpd.read_file(COUNTRY_SHP)

# Inundation mask from the Aqueduct coastal flood raster. depth_m_threshold is in metres
# (e.g. 0.1 for at least 10 cm); see prep_aqueduct_inun_mask_from_tif for how cells at
# zero depth are treated.
flood_mask = prep_aqueduct_inun_mask_from_tif(FLOOD_TIF, depth_m_threshold=0.0)

# Per-country share of saltmarsh area in flagged cells, saved as GeoJSON
saltmarsh_flood_by_country = saltmarsh_slr_overlap_by_country_fast(
    slr_mask_u8=flood_mask,
    countries=countries,
    saltmarsh_gdf=saltmarsh,
)
saltmarsh_flood_by_country.to_file("Figure_7_SLR.geojson", driver="GeoJSON")

saltmarsh_flood_by_country


## Figures 8 and 13

<p align="center">
  <img src="Figs/climate_slr_8.png" style="width:50%;">
</p>

<p align="center">
  <img src="Figs/climate_slr_10.png" style="width:50%;">
</p>

The saltmarsh data were retrieved from Mcowen C, Weatherdon LV, Bochove J, Sullivan E, Blyth S, Zockler C, Stanwell-Smith D, Kingston N, Martin CS, Spalding M, Fletcher S (2017). A global map of saltmarshes. Biodiversity Data Journal 5: e11764. Paper doi: https://doi.org/10.3897/BDJ.5.e11764; Data URL: http://data.unep-wcmc.org/datasets/43. See the Restore Ecosystems notebook for further details.

In [None]:
import numpy as np
import xarray as xr
import rioxarray
import geopandas as gpd
from tqdm import tqdm

def latlon_cell_areas_km2(x, y):
    """
    Compute spherical grid-cell areas (km²) for a lat/lon grid.

    Parameters
    ----------
    x : array-like, shape (nx,)
        Cell-centre longitudes in degrees (monotonic).
    y : array-like, shape (ny,)
        Cell-centre latitudes in degrees (monotonic, ascending or descending).

    Returns
    -------
    numpy.ndarray, shape (ny, nx)
        Cell areas in km² on a sphere of radius 6371 km.

    Notes
    -----
    Cell widths come from ``np.gradient`` (centred differences inside, one-sided at the
    edges), so each axis needs at least two points. Latitude cell edges are clipped to
    [-90°, 90°]: otherwise a row centred on a pole gets an edge beyond the pole, and since
    sin() folds back past ±90° that row's area collapses to ~0.
    """
    R = 6371.0  # Earth radius in km
    lon_rad = np.deg2rad(np.asarray(x, dtype=float))
    lat_rad = np.deg2rad(np.asarray(y, dtype=float))

    dlon = np.abs(np.gradient(lon_rad))   # (x,)
    dlat = np.abs(np.gradient(lat_rad))   # (y,)

    half_pi = np.pi / 2
    lat_upper = np.clip(lat_rad + 0.5 * dlat, -half_pi, half_pi)
    lat_lower = np.clip(lat_rad - 0.5 * dlat, -half_pi, half_pi)

    # Area of a lat/lon cell = R² · |sin(φ_upper) − sin(φ_lower)| · Δλ
    row_term = np.abs(np.sin(lat_upper) - np.sin(lat_lower))
    return (R ** 2) * row_term[:, None] * dlon[None, :]


def prep_aqueduct_inun_mask_from_tif(
    tif_path: str,
    depth_m_threshold: float = 0.0,
    assume_band: int = 1,
    to_crs: str = "EPSG:4326",
):
    """
    Build a 0/1 uint8 inundation mask from an Aqueduct Floods hazard GeoTIFF.

    The raster is read with nodata masked to NaN, reduced to one band, reprojected to
    ``to_crs`` when its CRS string differs, and thresholded.

    A cell is flagged (1) when its value is strictly positive AND ``>= depth_m_threshold``;
    every other cell, including NaN/nodata, is 0. The positivity test makes the default
    ``depth_m_threshold=0.0`` select "any modelled inundation" instead of every valid cell:
    cells whose depth is exactly 0 are never flagged.

    Parameters
    ----------
    tif_path : str
        Path to the GeoTIFF (values assumed to be inundation depth in metres).
    depth_m_threshold : float
        Minimum depth for a cell to count as inundated (e.g. 0.1 for >= 10 cm).
    assume_band : int
        Band selected if a ``band`` dimension remains after squeezing.
    to_crs : str
        Target CRS; compared textually with the raster's CRS.

    Returns
    -------
    xarray.DataArray
        uint8 mask with ``x``/``y`` spatial dims.

    Raises
    ------
    ValueError
        If the raster carries no CRS.
    """
    da = rioxarray.open_rasterio(tif_path, masked=True).squeeze()

    # If there's still a band dimension, select one
    if "band" in da.dims:
        da = da.sel(band=assume_band)

    # Ensure CRS is known
    if da.rio.crs is None:
        raise ValueError(
            "Raster has no CRS. Check the .tif metadata and set it via da.rio.write_crs(...)."
        )

    # Reproject for consistent vector clipping & area weighting downstream
    if str(da.rio.crs) != to_crs:
        da = da.rio.reproject(to_crs)

    # Register the spatial dims explicitly (open_rasterio already names them y/x)
    da = da.rio.set_spatial_dims(x_dim="x", y_dim="y")

    # NaN compares False on both sides, so nodata cells end up 0
    inundated = (da > 0) & (da >= depth_m_threshold)
    return inundated.astype("uint8")


def mangrove_slr_overlap_by_country_fast(
    slr_mask_u8,
    country_gdf,
    mangrove_gdf,
    country_name_col="NAME",
    equal_area_crs="ESRI:54009",
):
    """
    Estimate, per country, how much mangrove area lies in cells flagged by a raster mask.

    For each country polygon, mangrove polygons are pre-filtered with a spatial index and
    an exact ``intersects`` test, clipped to the country and dissolved into one geometry,
    whose area (km²) is measured in ``equal_area_crs``. The mask and a per-cell area grid
    (from ``latlon_cell_areas_km2``) are then clipped to that geometry; the area-weighted
    share of clipped cells with mask value 1 is applied to the vector mangrove area.

    Parameters
    ----------
    slr_mask_u8 : xarray.DataArray
        0/1 mask with ``x``/``y`` coordinates. EPSG:4326 is written onto it unconditionally
        (no reprojection), so it must already be on a lon/lat grid.
    country_gdf : geopandas.GeoDataFrame
        Country polygons; reprojected to EPSG:4326 internally.
    mangrove_gdf : geopandas.GeoDataFrame
        Mangrove polygons; reprojected to EPSG:4326 internally.
    country_name_col : str
        Column holding the country name; "unknown" is used when the column is absent.
    equal_area_crs : str
        CRS used to measure mangrove area; the ``/ 1e6`` conversion assumes its unit is metres.

    Returns
    -------
    geopandas.GeoDataFrame (EPSG:4326)
        One row per country that has mangroves: Country_Name, Total_Mangrove_Area_km2,
        Flood_Affected_Mangrove_Area_km2, Flood_Affected_Mangrove_Area_Percent, and the
        country geometry. Countries without mangroves (or zero area) are omitted.

    Notes
    -----
    Any exception while processing a country is printed and that country is skipped.
    If the clipped area grid sums to zero the fraction is 0. NOTE(review): rio.clip may
    instead raise when no cell overlaps the geometry, which lands in the except branch —
    confirm for small polygons.
    """
    slr_mask_u8 = slr_mask_u8.rio.write_crs("EPSG:4326")

    # Cell areas (km²) for EPSG:4326 grid
    cell_area = latlon_cell_areas_km2(
        slr_mask_u8["x"].values,
        slr_mask_u8["y"].values,
    )

    # Wrap the area grid on the mask's coordinates so both can be clipped identically
    cell_area_da = xr.DataArray(
        cell_area,
        dims=["y", "x"],
        coords={"y": slr_mask_u8["y"], "x": slr_mask_u8["x"]},
    ).rio.write_crs("EPSG:4326")

    # Prep vectors: reproject to EPSG:4326; buffer(0) is commonly used to clean invalid polygons
    countries = country_gdf.to_crs("EPSG:4326").copy()
    countries["geometry"] = countries.geometry.buffer(0)

    mangroves = mangrove_gdf.to_crs("EPSG:4326").copy()
    mangroves["geometry"] = mangroves.geometry.buffer(0)

    # Spatial index so each country only tests mangrove polygons with overlapping bounds
    sindex = mangroves.sindex

    records = []

    for _, country in tqdm(countries.iterrows(), total=len(countries), desc="Processing countries"):
        country_name = country.get(country_name_col, "unknown")
        country_geom = country.geometry

        try:
            # 1) Mangrove candidates for this country: bounding-box hits from the index,
            #    then an exact intersects() test
            bbox_matches = list(sindex.intersection(country_geom.bounds))
            if not bbox_matches:
                continue

            mangroves_sub = mangroves.iloc[bbox_matches]
            mangroves_sub = mangroves_sub[mangroves_sub.intersects(country_geom)]
            if mangroves_sub.empty:
                continue

            # 2) Clip to the country and dissolve into a single geometry
            mangrove_in_country = gpd.clip(mangroves_sub, country_geom)
            mangrove_union = mangrove_in_country.unary_union

            mangrove_country_gdf = gpd.GeoDataFrame(
                [{"geometry": mangrove_union}], crs="EPSG:4326"
            )

            # 3) Mangrove area (km²) measured in the equal-area CRS
            mangrove_eq = mangrove_country_gdf.to_crs(equal_area_crs)
            true_mangrove_km2 = mangrove_eq.geometry.area.sum() / 1e6
            if true_mangrove_km2 <= 0:
                continue

            # 4) Raster clip (mask + cell areas) to the dissolved mangrove geometry
            mask_clip = slr_mask_u8.rio.clip(mangrove_country_gdf.geometry, drop=True)
            area_clip = cell_area_da.rio.clip(mangrove_country_gdf.geometry, drop=True)

            # Area-weighted fraction of clipped cells that the mask flags
            affected_pix_km2 = (mask_clip * area_clip).sum(dim=("y", "x")).item()
            total_pix_km2 = area_clip.sum(dim=("y", "x")).item()

            frac = affected_pix_km2 / total_pix_km2 if total_pix_km2 > 0 else 0.0
            print(country_name, ": ", frac)

            # Apply the raster fraction to the vector area
            records.append({
                "Country_Name": country_name,
                "Total_Mangrove_Area_km2": true_mangrove_km2,
                "Flood_Affected_Mangrove_Area_km2": frac * true_mangrove_km2,
                "Flood_Affected_Mangrove_Area_Percent": 100 * frac,
                "geometry": country_geom,
            })

        except Exception as e:
            # Best effort: report and move on to the next country
            print(f"Error processing {country_name}: {e}")

    return gpd.GeoDataFrame(records, crs="EPSG:4326")


# -------------------------
# Inputs
# -------------------------
FLOOD_TIF = "../Data/inuncoast_rcp8p5_wtsub_2080_rp0050_0_perc_50.tif"
MANGROVE_SHP = "../Data/gmw_v3_2020_vec/gmw_v3_2020_vec.shp"
COUNTRY_SHP = "../Data/ne_50m_admin_0_countries/ne_50m_admin_0_countries.shp"

# NOTE(review): the markdown above this cell describes the saltmarsh dataset
# (Mcowen et al. 2017), but this cell processes Global Mangrove Watch polygons;
# confirm which figure it feeds.

# Vector layers: mangrove polygons and country boundaries
mangrove = gpd.read_file(MANGROVE_SHP)
country = gpd.read_file(COUNTRY_SHP)

# Inundation mask from the Aqueduct coastal flood raster. depth_m_threshold is in metres
# (e.g. 0.1 for at least 10 cm); see prep_aqueduct_inun_mask_from_tif for how cells at
# zero depth are treated.
flood_mask = prep_aqueduct_inun_mask_from_tif(FLOOD_TIF, depth_m_threshold=0.0)

# Per-country share of mangrove area in flagged cells, saved as GeoJSON
mangrove_flood_by_country = mangrove_slr_overlap_by_country_fast(
    slr_mask_u8=flood_mask,
    country_gdf=country,
    mangrove_gdf=mangrove,
)
mangrove_flood_by_country.to_file("Figure_8_SLR.geojson", driver="GeoJSON")

mangrove_flood_by_country


# Sea Ice

## Figure 1

<p align="center">
  <img src="Figs/climate_ice_1.png" style="width:50%;">
</p>

**This figure uses monthly sea ice extent time-series data from [NSIDC](https://noaadata.apps.nsidc.org/NOAA/G02135/) and fits a linear trend to the annual means over the period 1978 to present.**

In [None]:
import pandas as pd
import requests
from io import StringIO

# NSIDC Sea Ice Index (G02135) monthly extent directories, one per hemisphere
base_urls = {
    "north": "https://noaadata.apps.nsidc.org/NOAA/G02135/north/monthly/data/",
    "south": "https://noaadata.apps.nsidc.org/NOAA/G02135/south/monthly/data/"
}

# Twelve per-month CSV names for each hemisphere, prefixed N_ (north) or S_ (south)
file_names = {
    region: [f"{prefix}_{month:02d}_extent_v3.0.csv" for month in range(1, 13)]
    for region, prefix in (("north", "N"), ("south", "S"))
}

# Function to download and load a single file
def download_and_load(base_url, file_name, timeout=60):
    """
    Download one NSIDC monthly extent CSV and return it as a DataFrame.

    The month number is parsed from the file name (``N_07_extent_v3.0.csv`` -> 7) and
    stored in a new ``'mo'`` column.

    Parameters
    ----------
    base_url : str
        Directory URL, ending with '/'.
    file_name : str
        File name appended to ``base_url``.
    timeout : float, optional
        Seconds to wait for the server before giving up (default 60).

    Returns
    -------
    pandas.DataFrame or None
        None (after printing a message) if the request fails or the server answers with a
        non-200 status, so callers can drop failed months.
    """
    url = base_url + file_name
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        # Treat network errors/timeouts like a failed download instead of aborting the cell
        print(f"Failed to download {file_name}: {err}")
        return None
    if response.status_code != 200:
        print(f"Failed to download {file_name}")
        return None
    df = pd.read_csv(StringIO(response.text))
    df['mo'] = int(file_name.split('_')[1])  # Extract month from filename
    return df

# Import here so the cell does not rely on an earlier cell for stats.linregress
from scipy import stats

# Download and load all files for North and South (failed downloads come back as None)
dataframes = {}
for region in base_urls:
    dataframes[region] = [download_and_load(base_urls[region], file) for file in file_names[region]]

# Drop failed downloads, then stack each hemisphere's months into one time-sorted table
for region in dataframes:
    frames = [df for df in dataframes[region] if df is not None]
    if not frames:
        # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate"
        raise RuntimeError(f"No monthly files could be downloaded for the {region} hemisphere.")
    dataframes[region] = (
        pd.concat(frames, ignore_index=True)
        .sort_values(['year', 'mo'])
        .reset_index(drop=True)
    )

# Pair the hemispheres month by month (inner join on year + month)
combined_data = pd.merge(
    dataframes['north'],
    dataframes['south'],
    on=['year', 'mo'],
    suffixes=('_north', '_south')
)

# Total extent; the source column is ' extent' (leading space), hence ' extent_north' etc.
combined_data['total_extent'] = combined_data[' extent_north'] + combined_data[' extent_south']

# -9999 is treated as a missing-value sentinel. Drop a month if EITHER hemisphere carries
# it; filtering only the north column let south gaps leak into the annual means below.
combined_data = combined_data.query("` extent_north` != -9999 and ` extent_south` != -9999")

# Calculate the annual average for extent_north and extent_south
annual_avg = combined_data.groupby('year').mean(numeric_only=True)[[' extent_north', ' extent_south']]

# Calculate the linear trend for extent_north
slope_north, intercept_north, r_value_north, p_value_north, std_err_north = stats.linregress(
    annual_avg.index, annual_avg[' extent_north']
)

# Calculate the linear trend for extent_south
slope_south, intercept_south, r_value_south, p_value_south, std_err_south = stats.linregress(
    annual_avg.index, annual_avg[' extent_south']
)

# Add the linear trend values as new columns to the DataFrame
annual_avg['linear_trend_north'] = slope_north * annual_avg.index + intercept_north
annual_avg['linear_trend_south'] = slope_south * annual_avg.index + intercept_south

annual_avg.to_csv("../Data/Figure_1_sea_ice.csv")

# Display the first few rows of the annual averages with trends
annual_avg.head()

## Figure 2

<p align="center">
  <img src="Figs/climate_ice_2.png" style="width:50%;">
</p>

**The videos of Northern and Southern sea ice concentrations are a raw pass-through of sea ice concentration data from [NSIDC](https://noaadata.apps.nsidc.org/NOAA/G02135/).**

## Figure 3

<p align="center">
  <img src="Figs/climate_ice_3.png" style="width:50%;">
</p>

$CO_2$ data retrieved from [Lan, X., P. Tans, & K.W. Thoning (2025). Trends in globally-averaged CO₂ determined from NOAA Global Monitoring Laboratory measurements (Version 2025-11) NOAA Global Monitoring Laboratory. https://doi.org/10.15138/9N0H-ZH07](https://gml.noaa.gov/webdata/ccgg/trends/co2/co2_annmean_mlo.txt). Temperature data retrieved from GISS Surface Temperature Analysis (v4).

## Figure 4

<p align="center">
  <img src="Figs/climate_ice_4.png" style="width:50%;">
</p>


In [None]:
"""
Sea-level components (MONTHLY):
  1) NASA_SSH_GMSL_INDICATOR.txt  (GMSL, cm)
  2) ocean_mass_200204_202506.txt (GRACE/GRACE-FO ocean mass, mm; deseasoned)
  3) mean_thermosteric_sea_level_anomaly_0-2000_seasonal.nc (steric 0–2000 m; time units: 'months since 1955-01-01 00:00:00')

Outputs:
  - ../Data/Figure_4_sea_ice.png
  - ../Data/Figure_4_sea_ice.csv
"""

from pathlib import Path
from io import StringIO
import re
from datetime import datetime
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# --------------------
# Config
# --------------------
# Input files (paths are relative to the notebook directory)
GMSL_TXT   = "../Data/NASA_SSH_GMSL_INDICATOR.txt"
MASS_TXT   = "../Data/ocean_mass_200204_202506.txt"
STERIC_NC  = "../Data/mean_thermosteric_sea_level_anomaly_0-2000_seasonal.nc"

# Output files
OUT_PNG = "../Data/Figure_4_sea_ice.png"
OUT_CSV = "../Data/Figure_4_sea_ice.csv"
# Window (in monthly samples) of the trailing rolling mean used for the plot/CSV
SMOOTH_WINDOW_MONTHS = 6

# --------------------
# Helpers
# --------------------
def decimal_year_to_timestamp(y: float) -> pd.Timestamp:
    """Convert a decimal year (e.g. ``2002.5``) into a timestamp.

    The fractional part is read as a fraction of that calendar year's actual
    length, so leap years are handled naturally.
    """
    whole_year = int(np.floor(y))
    fraction = float(y) - whole_year
    year_start = datetime(whole_year, 1, 1)
    year_length = datetime(whole_year + 1, 1, 1) - year_start
    return pd.Timestamp(year_start + year_length * fraction)

def read_decimal_table(path: str, col_names, usecols=None) -> pd.DataFrame:
    """Read a whitespace-delimited text table whose data rows start with a decimal year.

    Header and comment lines are skipped: only lines beginning with ``YYYY.``
    are parsed. The leading columns are named from ``col_names``. Any extra
    trailing columns in the file beyond ``len(col_names)`` are dropped instead
    of causing a length-mismatch error. The result is optionally restricted to
    ``usecols``. The first remaining column is converted to a timestamp index
    and then removed.

    Args:
        path: Path to the text file.
        col_names: Names for the leading columns; the first must be the decimal-year column.
        usecols: Optional subset of ``col_names`` to keep; must start with the time column.

    Returns:
        DataFrame indexed by timestamp, sorted ascending.

    Raises:
        ValueError: if the file contains no decimal-year data rows.
    """
    raw = Path(path).read_text()
    data_lines = "\n".join(ln for ln in raw.splitlines() if re.match(r"^\s*[0-9]{4}\.", ln))
    if not data_lines:
        raise ValueError(f"No decimal-year data rows found in {path!r}")
    df = pd.read_csv(StringIO(data_lines), sep=r"\s+", header=None, engine="python")
    # The file may carry more columns than we have names for; keep only the named ones.
    df = df.iloc[:, : len(col_names)]
    df.columns = list(col_names[: df.shape[1]])
    if usecols is not None:
        df = df[usecols]
    t = pd.to_datetime(df.iloc[:, 0].apply(decimal_year_to_timestamp))
    return df.set_index(t).drop(columns=df.columns[0]).sort_index()

def to_mm(s: pd.Series, units_hint: str | None) -> pd.Series:
    if not units_hint: return s
    u = units_hint.lower()
    if u in {"cm","centimeter","centimeters"}: return s * 10.0
    if u in {"m","meter","meters"}:           return s * 1000.0
    return s

def months_since_to_datetime(months: np.ndarray, ref_str: str) -> pd.DatetimeIndex:
    """Convert fractional "months since <ref>" offsets into timestamps.

    The integer part picks the calendar month, counted from ``ref_str``'s month.
    The fractional part is applied as a fraction of that month's real length in
    days. Only the reference's year and month matter; its day and time of day
    do not shift the result.
    """
    origin = pd.Timestamp(ref_str)
    stamps = []
    for value in months.astype(float):
        whole = int(np.floor(value))
        shifted = origin + pd.DateOffset(months=whole)
        first_of_month = pd.Timestamp(shifted.year, shifted.month, 1)
        n_days = ((first_of_month + pd.DateOffset(months=1)) - first_of_month).days
        stamps.append(first_of_month + pd.Timedelta(days=(float(value) - whole) * n_days))
    return pd.DatetimeIndex(stamps)

def best_month_shift(steric_m: pd.Series, anchor_idx: pd.DatetimeIndex, candidates=(-1,0,1)) -> int:
    """Choose the month shift of ``steric_m`` that maximises overlap with ``anchor_idx``.

    Each candidate ``k`` shifts the series' index by ``k`` month-starts. The
    score is the number of shifted timestamps also present in ``anchor_idx``.
    On a tie, the smallest ``|k|`` wins, and among equal ``|k|`` the earliest
    candidate wins.

    NOTE(review): when ``steric_m`` already sits on a contiguous month-start grid,
    the overlap differs only at the ends of the series. A non-zero shift can then
    win merely because it extends the overlap by one month — confirm this is intended.
    """
    chosen, chosen_overlap = 0, -1
    for shift in candidates:
        shifted_index = steric_m.shift(shift, freq="MS").index
        overlap = len(anchor_idx.intersection(shifted_index))
        better = overlap > chosen_overlap
        tie_but_closer = overlap == chosen_overlap and abs(shift) < abs(chosen)
        if better or tie_but_closer:
            chosen, chosen_overlap = shift, overlap
    return chosen

# --------------------
# 1) GMSL (NASA SSH) — cm → mm → MONTHLY
# --------------------
# Read the NASA GMSL table (decimal-year time + GMSL in cm; only the first two
# columns are kept), convert cm → mm, average onto month-start bins, and fill
# gaps between observed months by time-weighted interpolation.
gmsl = read_decimal_table(GMSL_TXT, ["time_dec","gmsl_cm","gmsl60_cm"], usecols=["time_dec","gmsl_cm"])
gmsl_m = to_mm(gmsl["gmsl_cm"], "cm").resample("MS").mean().interpolate("time")

# --------------------
# 2) Ocean mass (GRACE/GRACE-FO) — deseasoned mm → MONTHLY
# --------------------
# Only the deseasoned column is used; it is already in mm, so no unit conversion.
mass = read_decimal_table(MASS_TXT,
                          ["time_dec","ocean_mass_mm","sigma_mm","ocean_mass_deseasoned_mm"],
                          usecols=["time_dec","ocean_mass_deseasoned_mm"])
mass_m = mass["ocean_mass_deseasoned_mm"].resample("MS").mean().interpolate("time")

# --------------------
# 3) Steric (0–2000 m) — months since 1955-01-01 → MONTHLY (robust)
# --------------------
# Times are decoded manually because the file uses "months since ..." units.
ds = xr.open_dataset(STERIC_NC, decode_times=False)
if "seas_a_mm_WO" not in ds: raise KeyError("Missing 'seas_a_mm_WO' in steric NetCDF.")

# Extract the reference date from the time units and convert the offsets.
time_units = ds["time"].attrs.get("units","")
m = re.match(r"months since\s+([0-9]{4}-[0-9]{2}-[0-9]{2}(?:\s+[0-9:]+)?)", time_units)
if not m: raise ValueError(f"Unexpected steric time units: {time_units!r}")
times = months_since_to_datetime(ds["time"].values, m.group(1))

# assumes seas_a_mm_WO is 1-D along time (a global series) — TODO confirm
steric = pd.Series(ds["seas_a_mm_WO"].values.astype("float64"), index=times)
# Drop fill sentinels (presumably int32-min/max style fill values) and missing values
steric = steric.replace([-2147483648, -2.14748365e9, 2.14748365e9], np.nan).dropna()
steric = to_mm(steric, ds["seas_a_mm_WO"].attrs.get("units","mm")).sort_index()

# Map to month-start stamps FIRST (so anchors survive), collapse duplicates
steric_ms = steric.copy()
steric_ms.index = steric_ms.index.to_period("M").to_timestamp(how="start")
steric_ms = steric_ms.groupby(level=0).mean()

# Now build full monthly index and interpolate between anchors
start = steric_ms.index.min()
end   = steric_ms.index.max()
monthly_index = pd.date_range(start, end, freq="MS")
steric_m = steric_ms.reindex(monthly_index).interpolate(method="time")

# Auto-fix 1-month anchor offset if present.
# NOTE(review): steric_m is already on a contiguous month-start grid here, so the
# overlap count only changes at the series ends; a ±1 shift may be chosen just to
# gain one overlapping month — confirm this heuristic is intended.
anchor_idx = gmsl_m.dropna().index.intersection(mass_m.dropna().index)
k = best_month_shift(steric_m, anchor_idx, candidates=(-1,0,1))
if k != 0:
    steric_m = steric_m.shift(k, freq="MS")
    print(f"[info] steric month shift applied: {k:+d}")

# --------------------
# Align ocean mass vertically at first common month
# --------------------
# Combine the three monthly series on the union of their month-start indices
df_raw = pd.concat({"GMSL_mm": gmsl_m, "OceanMass_mm": mass_m, "Steric_mm": steric_m}, axis=1)
common = df_raw.dropna(subset=["GMSL_mm","OceanMass_mm","Steric_mm"])
if common.empty: raise ValueError("No common month across GMSL, OceanMass, Steric after alignment.")

# Offset the mass series so that OceanMass + Steric == GMSL at the first month
# where all three exist (presumably because the anomalies use different baselines).
t0 = common.index.min()
delta = (common.loc[t0,"GMSL_mm"] - common.loc[t0,"Steric_mm"]) - common.loc[t0,"OceanMass_mm"]
mass_m_aligned = mass_m + delta
print(f"[info] ocean mass vertical delta: {delta:.3f} mm at {t0.date()}")

# --------------------
# Final MONTHLY dataframe + smoothing
# --------------------
df_m = pd.concat({"GMSL_mm": gmsl_m, "OceanMass_mm": mass_m_aligned, "Steric_mm": steric_m}, axis=1)
df_m["MassPlusSteric_mm"] = df_m["OceanMass_mm"] + df_m["Steric_mm"]  # sum of components (CSV only; not plotted)
df_sm = df_m.rolling(window=SMOOTH_WINDOW_MONTHS, min_periods=1).mean()  # trailing window; partial windows allowed

# --------------------
# Plot + save
# --------------------
# Plot the smoothed components, then save the figure and the smoothed monthly CSV
fig, ax = plt.subplots(figsize=(10,6))
df_sm["GMSL_mm"].plot(ax=ax, linewidth=2, label="Global mean sea level (NASA)")
df_sm["OceanMass_mm"].plot(ax=ax, linewidth=2, label="Global ocean mass (JPL, aligned)")
df_sm["Steric_mm"].plot(ax=ax, linewidth=2, label="Global steric (SIO, 0–2000 m)")

ax.set_ylabel("Sea level (mm)")
ax.set_title("Global sea level rise & components (monthly; mass aligned)")
ax.grid(True, alpha=0.3); ax.legend(loc="upper left")
plt.tight_layout(); plt.savefig(OUT_PNG, dpi=200)

# The CSV holds the smoothed series (including MassPlusSteric_mm)
df_sm.to_csv(OUT_CSV, index_label="date")
print(f"Saved figure → {OUT_PNG}")
print(f"Saved CSV    → {OUT_CSV}")

# --------------------
# Save JSON for Figure 4
# --------------------
import json

JSON_PATH = "../Data/Figure_4_sea_ice.json"

# Use the *unsmoothed* monthly series; require all three components present.
# rename_axis("date") fixes the name of the date column explicitly. Otherwise
# reset_index would name it after whatever the concatenated index happens to be
# called ("index" only when the merged index is unnamed).
json_df = (
    df_m[["Steric_mm", "OceanMass_mm", "GMSL_mm"]]
    .dropna(subset=["Steric_mm", "OceanMass_mm", "GMSL_mm"])
    .rename_axis("date")
    .reset_index()
    .rename(columns={"Steric_mm": "steric",
                     "OceanMass_mm": "mass",
                     "GMSL_mm": "total"})
)

# Format dates as YYYY-MM-DD and round values to 3 decimals
json_records = [
    {
        "date": d.strftime("%Y-%m-%d"),
        "steric": float(round(s, 3)),
        "mass": float(round(ms, 3)),
        "total": float(round(t, 3)),
    }
    for d, s, ms, t in zip(
        json_df["date"], json_df["steric"], json_df["mass"], json_df["total"]
    )
]

with open(JSON_PATH, "w") as f:
    json.dump(json_records, f, indent=2)

print(f"Saved JSON    → {JSON_PATH}")



## Figure 5

<p align="center">
  <img src="Figs/climate_ice_5.png" style="width:50%;">
</p>

**Net incoming solar radiation (incoming solar flux minus top-of-atmosphere shortwave flux) was computed from [NASA CERES EBAF Edition 4.2](https://asdc.larc.nasa.gov/project/CERES/CERES_EBAF_Edition4.2) data.**

In [None]:
import json
import numpy as np
import pandas as pd
import xarray as xr
from datetime import datetime

a = xr.open_mfdataset("../Data/CERES_EBAF_Edition4.2_200003-202407.nc")

# Years with fewer monthly samples than this are left out of the annual series.
# Per the file name the record spans 2000-03 to 2024-07. The first and last
# calendar years therefore average only part of the seasonal cycle and are not
# comparable with full-year means. Set to 1 to keep partial years.
MIN_MONTHS_PER_YEAR = 12

# 1) Difference: incoming solar flux minus TOA shortwave flux (toa_sw_all_mon)
diff = a["solar_mon"] - a["toa_sw_all_mon"]

# 2) Area weights (cosine latitude)
lat_name = "lat" if "lat" in a.coords else "latitude"
lon_name = "lon" if "lon" in a.coords else "longitude"

weights = np.cos(np.deg2rad(a[lat_name]))
# Broadcast to full grid
weights_2d = weights.broadcast_like(diff)

# 3) Weighted mean over lat/lon at each time step. The denominator only counts
#    weights where diff is finite, so missing cells do not dilute the mean.
num = (diff * weights_2d).sum(dim=[lat_name, lon_name], skipna=True)
den = weights_2d.where(np.isfinite(diff)).sum(dim=[lat_name, lon_name])
awm_ts = num / den  # (time,) DataArray

# 4) Annual averages (works with regular or cftime calendars)
annual = awm_ts.groupby("time.year").mean("time")  # (year,) DataArray

# 4b) Keep only years with at least MIN_MONTHS_PER_YEAR time steps (counted
#     from the time coordinate; assumes one step per month).
sample_years, samples_per_year = np.unique(awm_ts["time"].dt.year.values, return_counts=True)
complete_years = sample_years[samples_per_year >= MIN_MONTHS_PER_YEAR]
dropped_years = sample_years[samples_per_year < MIN_MONTHS_PER_YEAR]
if dropped_years.size:
    print(f"[info] dropping incomplete years: {dropped_years.tolist()}")
annual = annual.isel(year=np.flatnonzero(np.isin(annual["year"].values, complete_years)))

# 5) Convert to JSON; write each year at Jan 1 for a clean ISO date (or keep as int)
records = []
for y, val in zip(annual["year"].values.tolist(), annual.values.tolist()):
    # Use ISO date "YYYY-01-01T00:00:00" for nicer time axis downstream
    records.append({
        "date": datetime(int(y), 1, 1).isoformat(),
        "value": float(val) if np.isfinite(val) else None
    })

out_path = "../Data/Figure_5_sea_ice.json"
with open(out_path, "w") as f:
    json.dump(records, f, indent=2)

print(f"Saved {out_path} with {len(records)} annual rows.")


## Figure 6

<p align="center">
  <img src="Figs/climate_ice_6.png" style="width:50%;">
</p>

In [None]:
# ============================
# IUCN: Species in Sea-Ice-Related Habitats (latest status per species)
# ============================
import os
import time
import json
import requests
import pandas as pd
import requests
import matplotlib.pyplot as plt
import numpy as np
import pycountry

from dotenv import load_dotenv 
from tqdm import tqdm
from sklearn.linear_model import LinearRegression

# Load environment variables (load_dotenv also reads a local .env file, if present)
load_dotenv()
TOKEN = os.getenv("IUCN_API_KEY")

# ---------------------------
# CONFIG
# ---------------------------
API_BASE = "https://api.iucnredlist.org/api/v4"

# Sent on every request; the API token goes in the Authorization header.
HEADERS = {
    "accept": "application/json",
    "Authorization": TOKEN
}

# Output paths (directory is created if missing)
OUT_DIR = "../Data"
os.makedirs(OUT_DIR, exist_ok=True)
OUT_HABITATS = os.path.join(OUT_DIR, "iucn_habitats_all.csv")
OUT_ASSESS_RAW = os.path.join(OUT_DIR, "iucn_sea_ice_assessments_raw.csv")
OUT_SPECIES_CSV = os.path.join(OUT_DIR, "sea_ice_species_latest.csv")
OUT_SPECIES_JSON = os.path.join(OUT_DIR, "sea_ice_species_latest.json")

# ---- Sea-ice habitat whitelist ----
# High-precision core:
# NOTE(review): as labelled below, these are general marine habitat classes
# (Epipelagic, Neritic Pelagic), not sea-ice-specific ones; narrowing to polar
# species relies entirely on the location filter further down.
CORE_SEA_ICE_CODES = {"10_1", "9_1"}  # Epipelagic + Neritic Pelagic
# Optional haul-out / colony associates near ice edges:
COASTAL_ADDONS = {"13_1", "12_1", "12_2"}  # Coastal cliffs/islands, rocky & sandy shorelines

# Choose which to use:
SEA_ICE_HABITAT_CODES = CORE_SEA_ICE_CODES  # or CORE_SEA_ICE_CODES | COASTAL_ADDONS

# Optional: keep only records with locations in polar regions
USE_POLAR_LOCATION_FILTER = True

# Arctic ISO2 (+ territories frequently present in IUCN locations)
ARCTIC_ISO2 = {
    "CA",  # Canada
    "GL",  # Greenland
    "IS",  # Iceland
    "NO",  # Norway
    "RU",  # Russia
    "US",  # United States (Alaska)
    "SJ",  # Svalbard & Jan Mayen
    "FO",  # Faroe Islands (edge case; include to be safe)
    "DK",  # Denmark (parent for GL; some records use DK)
}
# Antarctic / sub-Antarctic ISO2 (+ external territories used by IUCN)
ANTARCTIC_ISO2 = {
    "AQ",  # Antarctica
    "GS",  # South Georgia & South Sandwich Islands (UK)
    "TF",  # French Southern Territories
    "HM",  # Heard Island & McDonald Islands (AU)
    "BV",  # Bouvet Island (NO)
    "FK",  # Falkland Islands
    # Some subantarctic islands may appear under parent countries:
    "NZ",  # Campbell, Auckland, Antipodes, etc. (NZ)
    "AU",  # Macquarie (AU)
    "ZA",  # Prince Edward & Marion (ZA)
    "AR",  # South Orkney/South Shetland in some datasets under AR/CL
    "CL",  # Chilean Antarctic Territory
    "GB",  # UK parent for GS sometimes
}
# NOTE(review): the filter in main() keeps a species if ANY of its location codes
# is in this set. Whole-country codes (US, RU, CA, AU, NZ, ZA, AR, CL, GB, NO, DK)
# therefore admit non-polar species recorded anywhere in those countries.
POLAR_ISO2 = ARCTIC_ISO2 | ANTARCTIC_ISO2

# ---------------------------
# HELPERS
# ---------------------------
def _retry_get(url, headers, retries=6, backoff=0.6, timeout=60):
    """GET ``url``, retrying transient failures with exponential backoff.

    Up to ``retries`` attempts are made, sleeping ``backoff * 2**attempt``
    seconds between them. Retries happen on HTTP 429, on HTTP 5xx, and on
    network-level errors (connection failures and timeouts). For 429 responses,
    a numeric ``Retry-After`` header is honoured when it asks for a longer
    pause. Any other non-200 status fails immediately.

    Returns:
        The successful ``requests.Response``.

    Raises:
        RuntimeError: on a non-retryable status, or once retries are exhausted
            (chained to the last network error, if any).
    """
    last_error = None
    for i in range(retries):
        wait = backoff * (2 ** i)
        try:
            r = requests.get(url, headers=headers, timeout=timeout)
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            # Transient network failure: back off and retry instead of aborting the crawl.
            last_error = e
            print(f"[network] {type(e).__name__}. Retry in {wait:.1f}s…")
            time.sleep(wait)
            continue
        if r.status_code == 200:
            return r
        if r.status_code == 429:
            retry_after = (r.headers.get("Retry-After") or "").strip()
            if retry_after.isdigit():
                wait = max(wait, float(retry_after))
            print(f"[429] Rate-limited. Sleeping {wait:.1f}s…")
            time.sleep(wait)
            continue
        if 500 <= r.status_code < 600:
            print(f"[{r.status_code}] Server error. Retry in {wait:.1f}s…")
            time.sleep(wait)
            continue
        raise RuntimeError(f"GET {url} failed: {r.status_code} - {r.text[:300]}")
    raise RuntimeError(f"GET {url} exhausted retries.") from last_error

def list_all_habitats():
    """Download all habitat definitions page by page and cache them to OUT_HABITATS.

    Paging stops at the first empty page, or once the ``total-pages`` response
    header is reached. When that header is absent, it defaults to the current
    page number.

    Returns:
        DataFrame of raw habitat records. It is empty if nothing was returned;
        the CSV cache is only written when it is non-empty.
    """
    per_page = 200
    collected = []
    page = 1
    while True:
        resp = _retry_get(f"{API_BASE}/habitats/?page={page}&per_page={per_page}", HEADERS)
        payload = resp.json()
        batch = payload.get("habitats", payload.get("data", [])) or []
        if not batch:
            break
        collected.extend(batch)
        last_page = int(resp.headers.get("total-pages", page))
        if page >= last_page:
            break
        page += 1
    habitats = pd.DataFrame(collected)
    if not habitats.empty:
        habitats.to_csv(OUT_HABITATS, index=False)
    return habitats

def pick_whitelisted_habitats(hdf: pd.DataFrame, codes=None) -> pd.DataFrame:
    """Return ``code``/``description_en`` rows for the whitelisted habitat codes.

    Args:
        hdf: Habitat listing (e.g. from ``list_all_habitats``); may be None or empty.
        codes: Habitat codes to keep; defaults to ``SEA_ICE_HABITAT_CODES``.

    Returns:
        DataFrame with columns ``code`` and ``description_en``, de-duplicated
        and sorted by code. When ``hdf`` is None or empty, or has no ``code``
        column, a minimal table of the requested codes with
        ``description_en=None`` is synthesized.

        English descriptions come from the ``description`` column, normalised
        row by row: ``{"en": ...}`` dicts or plain strings. Rows without one
        fall back to ``name`` when that column exists.
    """
    wanted = SEA_ICE_HABITAT_CODES if codes is None else set(codes)
    if hdf is None or hdf.empty or "code" not in hdf.columns:
        # Fallback: synthesize minimal table
        ordered = sorted(wanted)
        return pd.DataFrame({"code": ordered, "description_en": [None] * len(ordered)})

    h = hdf.copy()
    h["code"] = h["code"].astype(str)

    def _english(d):
        # Descriptions may be nested {"en": ...} dicts or plain strings; anything else -> None.
        if isinstance(d, dict):
            return d.get("en")
        return d if isinstance(d, str) else None

    if "description" in h.columns:
        desc = h["description"].apply(_english)
    else:
        desc = pd.Series([None] * len(h), index=h.index, dtype=object)
    if "name" in h.columns:
        desc = desc.where(desc.notna(), h["name"])
    h["description_en"] = desc
    return h[h["code"].isin(wanted)][["code", "description_en"]].drop_duplicates().sort_values("code")

def get_assessments_for_habitat(habitat_code: str):
    """
    Return every assessment record listed under one habitat.

    Endpoint: /api/v4/habitats/{id_or_code}
    IUCN accepts codes like '10_1' and '9_1' here.

    When the ``total-pages`` response header is present, pages are fetched up
    to that count. When it is absent, the function does not assume a single
    page. It keeps paging until a page is shorter than ``per_page`` or empty,
    or until a page repeats the previous one (a server ignoring ``page``).
    """
    page, per_page = 1, 100
    out = []
    prev_batch = None
    while True:
        url = f"{API_BASE}/habitats/{habitat_code}?page={page}&per_page={per_page}"
        r = _retry_get(url, HEADERS)
        batch = r.json().get("assessments", []) or []
        total_pages = r.headers.get("total-pages")
        if total_pages is not None:
            out.extend(batch)
            if page >= int(total_pages):
                break
        else:
            if batch == prev_batch:
                break  # page parameter ignored; avoid looping forever
            out.extend(batch)
            if len(batch) < per_page:
                break  # short (or empty) page is the last one
            prev_batch = batch
        page += 1
    return out

def get_assessment_detail(assessment_id: int):
    """Fetch one IUCN assessment and return its decoded JSON body."""
    return _retry_get(f"{API_BASE}/assessment/{assessment_id}", HEADERS).json()

def extract_locations(det_json):
    """Return the unique, non-empty location ``code`` values of an assessment as a sorted list.

    A missing or None ``locations`` entry yields ``[]``. The codes are presumably
    ISO 3166-1 alpha-2 — confirm against the API payload.
    """
    unique_codes = set()
    for entry in det_json.get("locations", []) or []:
        code = entry.get("code")
        if code:
            unique_codes.add(code)
    return sorted(unique_codes)

def extract_threat_codes(det_json):
    """Return the unique, non-empty threat ``code`` values of an assessment, sorted."""
    entries = det_json.get("threats", []) or []
    present = filter(None, (entry.get("code") for entry in entries))
    return sorted(set(present))

# ---------------------------
# MAIN
# ---------------------------
def main():
    """Build the sea-ice-associated species table from the IUCN Red List API.

    Steps:
      1. List habitats (for descriptions only; failures are tolerated).
      2. Pull assessments for the whitelisted habitat codes.
      3. Keep one (most recent) assessment per species.
      4. Enrich each with its detail record.
      5. Optionally keep only species with a polar location code.
      6. Write OUT_ASSESS_RAW, OUT_SPECIES_CSV and OUT_SPECIES_JSON.

    Raises:
        SystemExit: if no API token is configured.
        RuntimeError: if no assessments or no species details could be assembled.
    """
    if TOKEN == "REPLACE_ME_WITH_YOUR_TOKEN" or not TOKEN:
        # TOKEN is read from the IUCN_API_KEY environment variable (or .env file).
        raise SystemExit("Please set IUCN_API_KEY env var or edit TOKEN.")

    # 1) Habitats listing (for descriptions); safe to proceed if this fails
    try:
        habitats_df = list_all_habitats()
    except Exception as e:
        print(f"Warning: could not list habitats ({e}). Proceeding without descriptions.")
        habitats_df = pd.DataFrame()

    whitelist_df = pick_whitelisted_habitats(habitats_df)
    if whitelist_df.empty:
        print("No habitat descriptions found for whitelist; proceeding with codes only.")
        whitelist_df = pd.DataFrame({"code": sorted(SEA_ICE_HABITAT_CODES), "description_en": [None]*len(SEA_ICE_HABITAT_CODES)})

    print("\nUsing habitat codes:")
    print(whitelist_df.to_string(index=False))

    # 2) Pull assessments for selected habitats (a failing habitat is reported and skipped)
    all_assess_raw = []
    for _, row in tqdm(whitelist_df.iterrows(), total=len(whitelist_df), desc="Habitats"):
        code = row["code"]
        try:
            assessments = get_assessments_for_habitat(code)
            for a in assessments:
                all_assess_raw.append({
                    "habitat_code": code,
                    "habitat_desc_en": row.get("description_en"),
                    "assessment_id": a.get("assessment_id"),
                    "sis_taxon_id": a.get("sis_taxon_id"),
                    "year_published": a.get("year_published"),
                    "latest": a.get("latest"),
                })
        except Exception as e:
            print(f"Failed habitat {code}: {e}")

    # Check before building the frame: an empty list gives a column-less DataFrame,
    # on which dropna(subset=...) would raise KeyError instead of this error.
    if not all_assess_raw:
        raise RuntimeError("No assessments returned for selected habitats.")
    raw_df = pd.DataFrame(all_assess_raw).dropna(subset=["assessment_id"])
    if raw_df.empty:
        raise RuntimeError("No assessments returned for selected habitats.")
    raw_df["assessment_id"] = raw_df["assessment_id"].astype(int)
    raw_df = raw_df.drop_duplicates()
    raw_df.to_csv(OUT_ASSESS_RAW, index=False)

    # 3) Keep latest assessment per species: newest year first, then the 'latest'
    #    flag, then the highest assessment_id as tie-breakers.
    latest_idx = (raw_df
        .sort_values(["sis_taxon_id", "year_published", "latest", "assessment_id"],
                     ascending=[True, False, False, False])
        .groupby("sis_taxon_id", as_index=False)
        .head(1)
        .index
    )
    latest_df = raw_df.loc[latest_idx].reset_index(drop=True)

    # 4) Enrich with detail (status, names, trend, locations, threats)
    enriched = []
    for _, r in tqdm(latest_df.iterrows(), total=len(latest_df), desc="Details"):
        aid = int(r["assessment_id"])
        try:
            det = get_assessment_detail(aid)
        except Exception as e:
            print(f"Detail failed {aid}: {e}")
            continue

        red = det.get("red_list_category", {}) or {}
        tax = det.get("taxon", {}) or {}
        common_names = tax.get("common_names", []) or []
        eng_common = next((n["name"] for n in common_names if n.get("language") == "eng"), None)
        locations = extract_locations(det)
        threat_codes = extract_threat_codes(det)

        # population_trend may be a nested object; keep only its code
        pop_trend_val = det.get("population_trend")
        if isinstance(pop_trend_val, dict):
            pop_trend_val = pop_trend_val.get("code")

        enriched.append({
            "assessment_id": aid,
            "sis_taxon_id": det.get("sis_taxon_id"),
            "scientific_name": tax.get("scientific_name"),
            "english_common_name": eng_common,
            "class_name": tax.get("class_name"),
            "order_name": tax.get("order_name"),
            "family_name": tax.get("family_name"),
            "year_published": det.get("year_published"),
            "status_code": red.get("code"),
            "status_text": red.get("text"),
            "population_trend": pop_trend_val,
            "locations_iso2": ",".join(locations) if locations else None,
            "threat_codes": ",".join(threat_codes) if threat_codes else None,
            "habitat_code": r.get("habitat_code"),
            "habitat_desc_en": r.get("habitat_desc_en"),
            "detail_url": det.get("url"),
            "is_latest": bool(r.get("latest")),
        })

    # Same guard as above: drop_duplicates(subset=...) on an empty frame raises KeyError.
    if not enriched:
        raise RuntimeError("No species details assembled. Check token/permissions or reduce filters.")
    species_df = pd.DataFrame(enriched).drop_duplicates(subset=["sis_taxon_id"]).reset_index(drop=True)

    # 5) Optional polar location filter: keep species with ANY location in POLAR_ISO2
    if USE_POLAR_LOCATION_FILTER:
        def _has_polar_loc(s):
            if not isinstance(s, str) or not s:
                return False
            codes = {c.strip() for c in s.split(",") if c.strip()}
            return bool(codes & POLAR_ISO2)
        species_df = species_df[species_df["locations_iso2"].apply(_has_polar_loc)].reset_index(drop=True)

    # 6) Save outputs
    species_df = species_df.sort_values(
        ["status_code", "class_name", "scientific_name"], na_position="last"
    )
    species_df.to_csv(OUT_SPECIES_CSV, index=False)

    minimal_records = species_df[[
        "sis_taxon_id", "scientific_name", "english_common_name",
        "class_name", "order_name", "family_name",
        "status_code", "status_text", "population_trend",
        "year_published", "habitat_code", "habitat_desc_en",
        "locations_iso2", "threat_codes", "detail_url"
    ]].to_dict(orient="records")
    with open(OUT_SPECIES_JSON, "w") as f:
        json.dump(minimal_records, f, ensure_ascii=False, indent=2)

    print(f"\nSaved: {OUT_SPECIES_CSV}\nSaved: {OUT_SPECIES_JSON}")

# Run the IUCN pipeline when this code is executed as a script/cell
if __name__ == "__main__":
    main()


In [None]:
import pandas as pd

# Load CSV
df = pd.read_csv("../Data/sea_ice_species_latest.csv")

# Keep only records flagged as the latest assessment
latest = df.query("is_latest").copy()

# Mapping for IUCN Red List category codes (exported under the "Threat" key)
threats_mapping = {
    'DD': 'Data Deficient',
    'LC': 'Least Concern',
    'NT': 'Near Threatened',
    'VU': 'Vulnerable',
    'EN': 'Endangered',
    'CR': 'Critically Endangered',
    'EW': 'Extinct in the Wild',
    'EX': 'Extinct',
    'NE': 'Not Evaluated',
}

# Map category codes to labels; codes not in the mapping are kept unchanged
latest['Threat'] = latest['status_code'].replace(threats_mapping)

# Map numeric population-trend codes to text labels. Series.map always creates
# the "Trend" column (NaN for unmatched codes). Row-wise .loc assignment would
# leave the column missing when no row matches, and the column selection below
# would then fail with a KeyError.
trend_labels = {0: "Increasing", 1: "Decreasing", 2: "Stable", 3: "Unknown"}
latest['Trend'] = latest["population_trend"].map(trend_labels)

# Deduplicate by sis_taxon_id to avoid repeats
latest = latest.drop_duplicates(subset="sis_taxon_id", keep="first")

# Keep only relevant columns
out = latest[["english_common_name", "Threat", "Trend"]].copy()

# Save to JSON
out.to_json("../Data/Figure_6_sea_ice.json", orient="records", force_ascii=False, indent=2)

print(f"✅ Saved {len(out)} unique species to Figure_6_sea_ice.json")


In [None]:
# Quick look: species whose latest Red List category is Endangered (EN)
latest.query("status_code == 'EN'")

## Figure 7

<p align="center">
  <img src="Figs/climate_ice_7.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
import rioxarray

# Export the SST trend field as an 8-bit display GeoTIFF plus a float32 GeoTIFF.
trends_df = xr.open_dataset("../Data/global_omi_tempsal_sst_trend_19932021_P20220331.nc")

# --------------------
# CONFIG
# --------------------
OUTPUT_TIF_8BIT   = "../Data/Figure_7_sea_ice.tif"
OUTPUT_TIF_FLOAT  = "../Data/sst_trend_float32.tif"
USE_PERCENTILES   = True
P_LOW, P_HIGH     = 2, 98
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # if USE_PERCENTILES=False

# 1) Select the dataarray (collapse any time dimension to a single 2-D field)
da = trends_df["sst_trends"]
if "time" in da.dims:
    da = da.mean(dim="time")

# 2) Rename to x/y if needed (rioxarray's spatial dims)
rename_map = {}
if "lat" in da.dims or "lat" in da.coords: rename_map["lat"] = "y"
if "latitude" in da.dims or "latitude" in da.coords: rename_map["latitude"] = "y"
if "lon" in da.dims or "lon" in da.coords: rename_map["lon"] = "x"
if "longitude" in da.dims or "longitude" in da.coords: rename_map["longitude"] = "x"
if rename_map:
    da = da.rename(rename_map)

# 3) Ensure y is north→south (descending)
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)

# 4) Register spatial dims & CRS (assume plain lat/lon if none is recorded)
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# 5) Prepare data and scaling range (robust percentiles or fixed limits)
data = da.data.astype(np.float32)
valid = np.isfinite(data)

if USE_PERCENTILES:
    vmin = float(np.nanpercentile(data, P_LOW))
    vmax = float(np.nanpercentile(data, P_HIGH))
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")

# 6) Build the 8-bit layer (reserve 0 for NoData): values are clipped to
#    [vmin, vmax] and mapped linearly onto 1..255.
scaled = np.zeros_like(data, dtype=np.uint8)  # 0 = NoData
scaled_valid = (np.clip((data[valid] - vmin) / (vmax - vmin), 0.0, 1.0) * 254 + 1).astype(np.uint8)
scaled[valid] = scaled_valid

da8 = da.copy(data=scaled)

# --- CRITICAL: clear CF encodings that carry a massive _FillValue ---
da8.encoding.pop("_FillValue", None)
da8.encoding.pop("missing_value", None)
# (optional) also clear scale/offset if present
da8.encoding.pop("scale_factor", None)
da8.encoding.pop("add_offset", None)

# Set nodata appropriate for uint8 (0)
da8 = da8.rio.write_nodata(0, encoded=False, inplace=False)

# 7a) Write 8-bit GeoTIFF
da8.rio.to_raster(
    OUTPUT_TIF_8BIT,
    dtype="uint8",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# 7b) Optional: write float32 with native values (non-finite values become NaN)
daf = da.where(np.isfinite(da))
# Clear encodings here too
daf.encoding.pop("_FillValue", None)
daf.encoding.pop("missing_value", None)
daf.encoding.pop("scale_factor", None)
daf.encoding.pop("add_offset", None)

# Use NaN as nodata for float32 (supported by rasterio/GTiff)
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)

daf.rio.to_raster(
    OUTPUT_TIF_FLOAT,
    dtype="float32",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"Scale for 8-bit: vmin={vmin:.4f}, vmax={vmax:.4f} (units of sst_trends)."
)


## Figure 8

<p align="center">
  <img src="Figs/climate_ice_8.png" style="width:50%;">
</p>

In [None]:
# Copernicus Climate Change Service, Climate Data Store, (2018): Sea level daily gridded data from satellite observations for the global ocean from 1993 to present. Copernicus Climate Change Service (C3S) Climate Data Store (CDS)
SLR_df = xr.open_mfdataset("../Data/dataset-satellite-sea-level-global-dc7f92ea-2d3e-4fc6-b767-836a5b8c0bff/*.nc")

# Adjust longitudes from 0-360 to -180 to 180
SLR_df = SLR_df.assign_coords(longitude=(((SLR_df.longitude + 180) % 360) - 180)).sortby('longitude')

# Per-pixel linear trend of the annual-mean sea level anomaly (sla units per year)
trend_significance_ds = calculate_trend_df(SLR_df['sla'].load())

# --------------------
# CONFIG
# --------------------
# Written next to the other figure outputs in ../Data (all other figures use this directory)
OUTPUT_TIF_8BIT   = "../Data/Figure_8_sea_ice.tif"
OUTPUT_TIF_FLOAT  = "../Data/SLR_trend_float32.tif"
USE_PERCENTILES   = True
P_LOW, P_HIGH     = 2, 98
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # if USE_PERCENTILES=False

# 1) Select the SLA trend computed above. Without this step the cell would export
#    whatever `da` an earlier cell (the Figure 7 SST trend) left behind.
if isinstance(trend_significance_ds, xr.Dataset):
    da = trend_significance_ds["trend"]
elif isinstance(trend_significance_ds, xr.DataArray):
    da = trend_significance_ds
else:
    # NOTE(review): assumes calculate_trend_df returns (trend, ...) — confirm.
    da = trend_significance_ds[0]
if "time" in da.dims:
    da = da.mean(dim="time")

# 2) Rename to x/y if needed
rename_map = {}
if "lat" in da.dims or "lat" in da.coords: rename_map["lat"] = "y"
if "latitude" in da.dims or "latitude" in da.coords: rename_map["latitude"] = "y"
if "lon" in da.dims or "lon" in da.coords: rename_map["lon"] = "x"
if "longitude" in da.dims or "longitude" in da.coords: rename_map["longitude"] = "x"
if rename_map:
    da = da.rename(rename_map)

# 3) Ensure y is north→south (descending)
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)

# 4) Register spatial dims & CRS
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# 5) Prepare data and scaling
data = da.data.astype(np.float32)
valid = np.isfinite(data)

if USE_PERCENTILES:
    vmin = float(np.nanpercentile(data, P_LOW))
    vmax = float(np.nanpercentile(data, P_HIGH))
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")

# 6) Build the 8-bit layer (reserve 0 for NoData); [vmin, vmax] -> 1..255 with clipping
scaled = np.zeros_like(data, dtype=np.uint8)  # 0 = NoData
scaled_valid = (np.clip((data[valid] - vmin) / (vmax - vmin), 0.0, 1.0) * 254 + 1).astype(np.uint8)
scaled[valid] = scaled_valid

da8 = da.copy(data=scaled)

# --- CRITICAL: clear CF encodings that carry a massive _FillValue ---
da8.encoding.pop("_FillValue", None)
da8.encoding.pop("missing_value", None)
# (optional) also clear scale/offset if present
da8.encoding.pop("scale_factor", None)
da8.encoding.pop("add_offset", None)

# Set nodata appropriate for uint8 (0)
da8 = da8.rio.write_nodata(0, encoded=False, inplace=False)

# 7a) Write 8-bit GeoTIFF
da8.rio.to_raster(
    OUTPUT_TIF_8BIT,
    dtype="uint8",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# 7b) Optional: write float32 with native values
daf = da.where(np.isfinite(da))
# Clear encodings here too
daf.encoding.pop("_FillValue", None)
daf.encoding.pop("missing_value", None)
daf.encoding.pop("scale_factor", None)
daf.encoding.pop("add_offset", None)

# Use NaN as nodata for float32 (supported by rasterio/GTiff)
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)

daf.rio.to_raster(
    OUTPUT_TIF_FLOAT,
    dtype="float32",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"Scale for 8-bit: vmin={vmin:.4f}, vmax={vmax:.4f} (units of sla per year)."
)

## Figure 9

<p align="center">
  <img src="Figs/climate_ice_9.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
from scipy.stats import linregress

a = xr.open_mfdataset("../Data/CERES_EBAF_Edition4.2_200003-202407.nc")

# Calendar years with fewer monthly samples than this are excluded from the fit.
# Per the file name the record spans 2000-03 to 2024-07, so the first and last
# years average only part of the seasonal cycle. As regression end points they
# would bias the per-pixel slope. Set to 1 to keep partial years.
MIN_MONTHS_PER_YEAR = 12

# Per-pixel incoming solar minus TOA shortwave (toa_sw_all_mon), as annual means
diff = a["solar_mon"] - a["toa_sw_all_mon"]
diff_ann = diff.resample(time="YS").mean("time")

# Drop incomplete years (counted from the time coordinate; assumes one step per month)
sample_years, samples_per_year = np.unique(diff["time"].dt.year.values, return_counts=True)
complete_years = sample_years[samples_per_year >= MIN_MONTHS_PER_YEAR]
diff_ann = diff_ann.isel(time=np.flatnonzero(np.isin(diff_ann["time"].dt.year.values, complete_years)))

# Make time a single chunk (safe & cheap)
diff_ann = diff_ann.chunk({"time": -1})

# Regressor: calendar year of each annual bin, as float
years = (diff_ann.time.dt.year if hasattr(diff_ann.time, "dt")
         else xr.DataArray([t.year for t in diff_ann.time.values], coords={"time": diff_ann.time}, dims="time")).astype(float)
x = xr.DataArray(years, coords={"time": diff_ann.time}, dims="time").chunk({"time": -1})

def linregress_1d(y, x):
    """Ordinary least-squares fit of ``y`` on ``x`` over positions where both are finite.

    Returns ``(slope, intercept, rvalue, pvalue, stderr)``. All five are NaN
    when fewer than two finite pairs remain.
    """
    finite = np.isfinite(y) & np.isfinite(x)
    if np.count_nonzero(finite) >= 2:
        fit = linregress(x[finite], y[finite])
        return fit.slope, fit.intercept, fit.rvalue, fit.pvalue, fit.stderr
    return (np.nan,) * 5

# Fit a linear trend of the annual means against the year at every non-"time"
# position. vectorize=True loops linregress_1d over the pixels, and
# dask="parallelized" keeps the result lazy for the dask-backed input.
slope, intercept, r, p, stderr = xr.apply_ufunc(
    linregress_1d, diff_ann, x,
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[[], [], [], [], []],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[float, float, float, float, float],
    # If you can’t rechunk time (above), enable this instead:
    # dask_gufunc_kwargs={"allow_rechunk": True}
)

# Bundle the regression outputs; slope_per_decade is slope_per_year * 10 and
# n_obs counts the non-NaN annual values per pixel.
trend = xr.Dataset(
    {
        "slope_per_year": slope,
        "slope_per_decade": slope * 10.0,
        "intercept": intercept,
        "r": r, "p": p, "stderr": stderr,
        "n_obs": diff_ann.count("time"),
        "y_mean": diff_ann.mean("time"),
    }
)


import numpy as np
import xarray as xr
import rioxarray

# --------------------
# CONFIG
# --------------------
# Output paths: the 8-bit (Mapbox-friendly) raster and the native-value float raster.
OUTPUT_TIF_8BIT = "../Data/Figure_9_sea_ice.tif"
OUTPUT_TIF_FLOAT = "../Data/slope_per_decade_float32.tif"

# Range used to scale the 8-bit layer: data percentiles P_LOW..P_HIGH when
# USE_PERCENTILES is True, otherwise the fixed bounds below.
USE_PERCENTILES = True
P_LOW = 2
P_HIGH = 98
VMIN_FIXED = -0.2  # fixed bounds, in the units of the exported layer; adjust as needed
VMAX_FIXED = 0.2

# True: compute percentiles with xarray's quantile (works on dask-backed data).
# False: pull the array into NumPy and use np.nanpercentile.
DASK_SAFE_PERCENTILES = True

# 1) Select the per-pixel decadal trend; collapse "time" if one is present
da = trend["slope_per_decade"]  # (lat, lon) or (latitude, longitude)
if "time" in da.dims:
    da = da.mean(dim="time")

# 2) Rename lat/lon (or latitude/longitude) to the x/y names used below
rename_map = {}
if "lat" in da.dims or "lat" in da.coords: rename_map["lat"] = "y"
if "latitude" in da.dims or "latitude" in da.coords: rename_map["latitude"] = "y"
if "lon" in da.dims or "lon" in da.coords: rename_map["lon"] = "x"
if "longitude" in da.dims or "longitude" in da.coords: rename_map["longitude"] = "x"
if rename_map:
    da = da.rename(rename_map)

# 2b) Wrap longitudes from [0, 360) to [-180, 180) and sort west→east
if "x" in da.coords:
    x_wrapped = ((da["x"] + 180) % 360) - 180  # 0..360 -> -180..180
    da = da.assign_coords(x=x_wrapped).sortby("x")

# 3) Ensure y is north→south (descending), as raster rows expect
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)

# 4) Register spatial dims & CRS
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# 4b) Evaluate the lazy regression graph once. `trend` is built from
# open_mfdataset + apply_ufunc(dask="parallelized"), so without this the
# percentile step, the 8-bit conversion and the float32 write below would each
# re-run the whole per-pixel regression. The 8-bit conversion materialises the
# full grid as NumPy anyway, so this does not raise peak memory.
da = da.compute()

# 5) Determine the range [vmin, vmax] mapped onto the 8-bit values
if USE_PERCENTILES:
    if DASK_SAFE_PERCENTILES:
        # Request both percentiles in a single quantile() call so (possibly
        # lazy, dask-backed) data are evaluated once rather than once per bound.
        bounds = da.quantile([P_LOW / 100.0, P_HIGH / 100.0], skipna=True).compute()
        vmin, vmax = (float(v) for v in bounds.values)
    else:
        # Pull into memory (fast/simple, but not great for very large rasters)
        data_tmp = da.data
        if hasattr(data_tmp, "compute"):  # dask array
            data_tmp = data_tmp.compute()
        data_tmp = data_tmp.astype(np.float32)
        # One nanpercentile call for both bounds (single pass over the data)
        vmin, vmax = (float(v) for v in np.nanpercentile(data_tmp, [P_LOW, P_HIGH]))
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

# All-NaN input yields NaN percentiles; a degenerate range would divide by zero
if not (np.isfinite(vmin) and np.isfinite(vmax)) or vmin >= vmax:
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")

# 6) Quantise to uint8: 0 is reserved for NoData; finite values are mapped
#    linearly from [vmin, vmax] onto 1..255, clipping anything outside.
data = da.data
data = data.compute() if hasattr(data, "compute") else data  # dask -> numpy
data = data.astype(np.float32)

valid = np.isfinite(data)
scaled_valid = (np.clip((data[valid] - vmin) / (vmax - vmin), 0.0, 1.0) * 254 + 1).astype(np.uint8)
scaled = np.zeros_like(data, dtype=np.uint8)
scaled[valid] = scaled_valid

da8 = da.copy(data=scaled)

# Drop inherited encodings that could otherwise emit an unsuitable
# _FillValue or scaling metadata into the uint8 GeoTIFF.
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    da8.encoding.pop(k, None)

# 0 is the nodata value of the 8-bit layer
da8 = da8.rio.write_nodata(0, encoded=False, inplace=False)

# 7a) Write 8-bit GeoTIFF (Mapbox-friendly): tiled 256x256, LZW-compressed
uint8_tif_options = {
    "dtype": "uint8",
    "compress": "LZW",
    "tiled": True,
    "blockxsize": 256,
    "blockysize": 256,
}
da8.rio.to_raster(OUTPUT_TIF_8BIT, **uint8_tif_options)

# 7b) Optional: also write the native float values (non-finite cells -> NaN)
daf = da.where(np.isfinite(da))

# Drop inherited encodings so they are not carried into the float GeoTIFF
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    daf.encoding.pop(k, None)

# NaN is the nodata value of the float32 layer
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)

float_tif_options = dict(
    dtype="float32",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)
daf.rio.to_raster(OUTPUT_TIF_FLOAT, **float_tif_options)

# Report both outputs and the range used for the 8-bit scaling
summary = (
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"8-bit scale: vmin={vmin:.6g}, vmax={vmax:.6g} (units of slope_per_decade)."
)
print(summary)


## Figure 10

<p align="center">
  <img src="Figs/climate_ice_10.png" style="width:50%;">
</p>

Data were retrieved from the [IUCN SSC Polar Bear Specialist Group (PBSG)](https://www.iucn-pbsg.org/population-status/).

In [None]:
# NOTE(review): the markdown above describes IUCN PBSG polar-bear population
# data, but this cell reads the Figure 3 (acidity) GeoJSON — looks like a
# copy-paste from the Figure 3 section; confirm the intended input file.
df = gpd.read_file("../Data/Figure_3_acidity.geojson")

# Ratio of summed `Biodiversity_Area` to summed `Sea_Area` over all features.
# Left as a bare final expression so the notebook displays the value.
(df['Biodiversity_Area'].sum() / df['Sea_Area'].sum())