# Mitigate Climate Change

This notebook outlines the general workflow for the data within the [Mitigate Climate Change](https://oceancentral.org/track/mitigate-climate-change) page of the Ocean Central website.

## Utility functions and globals used for making all figures

In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
import cartopy.crs as ccrs
from shapely.geometry import box
import rioxarray
import re

from rasterio.features import geometry_mask
from scipy.stats import linregress
from tqdm import tqdm

# Open the biodiversity priority areas based on Zhao et al. 2020 (https://www.sciencedirect.com/science/article/abs/pii/S0006320719312182?via%3Dihub)
masked_data = rioxarray.open_rasterio('../Data/masked_top_30_percent_over_water.tif')

# Only assign EPSG:4326 when the raster carries no CRS of its own.
# The CRS is queried through the `.rio` accessor (the same check the figure
# cells below use) rather than by looking for a 'crs' key in `attrs`, so a
# CRS that was read from the GeoTIFF is not overwritten.
if masked_data.rio.crs is None:
    masked_data.rio.write_crs('EPSG:4326', inplace=True)

# Load the IHO World Seas polygons; used as the default regions by the
# area_* helper functions below
seas_shapefile_path = '../Data/World_Seas_IHO_v3/World_Seas_IHO_v3.shp'
SEAS_DF = gpd.read_file(seas_shapefile_path)

# Calculate linear trend and p-value for each grid point
def calculate_trend_and_significance(x):
    """Fit a straight line to a 1-D series and return its trend statistics.

    The regression uses the sample position (0, 1, 2, ...) as the
    independent variable, so the slope is expressed per step of the series.
    Non-finite samples are dropped before fitting, while the remaining
    samples keep their original positions, so gaps do not distort the slope.

    Parameters
    ----------
    x : array-like
        1-D series of values (one value per step).

    Returns
    -------
    tuple of float
        ``(slope, intercept, p_value)``. All three are NaN when fewer than
        three finite samples are available: two points always fit a line
        exactly, so their p-value carries no information.
    """
    values = np.asarray(x, dtype=float)
    finite = np.isfinite(values)
    if finite.sum() < 3:
        return np.nan, np.nan, np.nan

    positions = np.arange(values.size)
    slope, intercept, _, p_value, _ = stats.linregress(positions[finite], values[finite])
    return slope, intercept, p_value

# Calculate the trend and significance of the trend at each pixel in an xarray dataset
def calculate_trend_df(climate_df):
    """Compute the per-pixel linear trend of yearly means and its significance.

    The input is averaged per calendar year (``groupby('time.year')``), then
    ``calculate_trend_and_significance`` is applied along the resulting
    ``year`` axis for every remaining position.

    Parameters
    ----------
    climate_df : xarray object with a ``time`` coordinate
        Data to analyse; must support ``groupby('time.year')``.

    Returns
    -------
    xarray.Dataset
        Variables ``trend`` (slope per year), ``p_value`` and ``significant``
        (True where ``p_value < 0.05``), tagged with CRS ``epsg:4326``
        through the ``rio`` accessor.
    """
    # One mean value per calendar year
    yearly = climate_df.groupby('time.year').mean()

    # Vectorised per-pixel regression; the intercept output is not used
    slopes, _, pvals = xr.apply_ufunc(
        calculate_trend_and_significance,
        yearly,
        input_core_dims=[['year']],
        output_core_dims=[[], [], []],
        output_dtypes=[float, float, float],
        vectorize=True,
    )

    # Coordinates of a single yearly slice describe the per-pixel outputs
    pixel_coords = yearly.isel(year=0).coords
    trends_da = xr.DataArray(slopes, coords=pixel_coords, name='trend')
    pvalues_da = xr.DataArray(pvals, coords=pixel_coords, name='p_value')

    # Boolean mask of pixels whose trend is significant at the 5% level
    significant_da = xr.DataArray(pvalues_da < 0.05, coords=pvalues_da.coords, name='significant')

    combined = xr.Dataset({
        'trend': trends_da,
        'p_value': pvalues_da,
        'significant': significant_da,
    })

    # Tag the result with the geographic CRS used by the sea polygons
    return combined.rio.write_crs("epsg:4326")

# Calculate area-weighted trend, significance for each sea/ocean area
def area_trend(trend_significance_ds, SEAS_DF=SEAS_DF):
    """Summarise per-pixel trends for every sea/ocean polygon.

    Parameters
    ----------
    trend_significance_ds : xarray.Dataset
        Dataset with ``trend`` and ``significant`` variables on a regular
        lat/lon grid (as produced by ``calculate_trend_df``). Dimensions named
        ``lat``/``lon`` or ``latitude``/``longitude`` are renamed to ``y``/``x``.
    SEAS_DF : geopandas.GeoDataFrame
        Regions to summarise; the columns ``NAME``, ``area`` and ``geometry``
        are read from each row.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per successfully processed region with the columns
        ``Region_Name``, ``geometry``, ``Weighted_Trend``, ``Sea_Area``,
        ``Significant_Area``, ``Significant_Area_Percent``,
        ``Biodiversity_Area`` and ``Biodiversity_Area_Percent``. The ``*_Area``
        columns scale the region's ``area`` attribute by the corresponding
        grid-cell area fraction. Regions that raise during processing are
        reported on stdout and skipped.
    """
    area_weighted_trends = []

    # Check if 'lat' and 'lon' are in the dataset, otherwise check for 'latitude' and 'longitude'
    if 'lat' in trend_significance_ds.dims and 'lon' in trend_significance_ds.dims:
        trend_significance_ds = trend_significance_ds.rename({'lat': 'y', 'lon': 'x'})
    elif 'latitude' in trend_significance_ds.dims and 'longitude' in trend_significance_ds.dims:
        trend_significance_ds = trend_significance_ds.rename({'latitude': 'y', 'longitude': 'x'})

    # Interpolate biodiversity priority areas to the same resolution as the climate data
    masked_data_interp = masked_data.interp(
        x=trend_significance_ds['x'],
        y=trend_significance_ds['y'],
        method='nearest'
    )

    # Approximate the area of each grid cell (assumes a regular lat/lon grid)
    lat_rad = np.deg2rad(trend_significance_ds['y'].values)
    lon_rad = np.deg2rad(trend_significance_ds['x'].values)

    # Earth radius in kilometers
    R = 6371
    dlat = np.gradient(lat_rad)
    dlon = np.gradient(lon_rad)

    # abs() keeps the areas positive when an axis is stored in descending
    # order; every result below is a ratio of area sums, so this does not
    # change any reported value.
    cell_areas = np.abs((R**2 * np.outer(np.sin(dlat), dlon)) * np.cos(lat_rad[:, None]))

    # The full-grid area array does not depend on the region: build it once
    # and only clip it inside the loop.
    cell_areas_da = xr.DataArray(
        cell_areas,
        dims=['y', 'x'],
        coords={'y': trend_significance_ds['y'], 'x': trend_significance_ds['x']}
    ).rio.write_crs('EPSG:4326')

    for i, row in SEAS_DF.iterrows():
        # Resolved outside the try block so the error message can always name the region
        region_name = row.get('NAME', i)
        try:
            area = row['area']
            geom = row['geometry']

            # Restrict the trend, significance mask and cell areas to the sea geometry
            masked_trends = trend_significance_ds['trend'].rio.clip([geom], drop=True)
            masked_significance = trend_significance_ds['significant'].rio.clip([geom], drop=True)
            cell_areas_clipped = cell_areas_da.rio.clip([geom], drop=True)

            total_area = cell_areas_clipped.sum().item()

            # Area-weighted mean trend. The weights are summed only over cells
            # that actually have a trend value; otherwise cells without data
            # would inflate the denominator and bias the mean towards zero.
            valid_area = cell_areas_clipped.where(masked_trends.notnull()).sum()
            weighted_trend = (masked_trends * cell_areas_clipped).sum(dim=('y', 'x')) / valid_area

            # Total grid-cell area with a significant trend
            significant_masked_areas = (masked_significance * cell_areas_clipped).where(masked_significance, 0)
            total_significant_area = significant_masked_areas.sum(dim=('y', 'x')).item()

            # Significant area that also lies in a biodiversity priority cell
            # (the sum keeps the raster's leading dimension, hence the [0] below)
            area_biodiversity = ((masked_significance * cell_areas_clipped) * masked_data_interp).sum(dim=['x', 'y']).values

            # Store the result
            area_weighted_trends.append({
                'Region_Name': region_name,
                'geometry': geom,
                'Weighted_Trend': weighted_trend.item(),
                'Sea_Area': area,
                'Significant_Area': area*total_significant_area/total_area,
                'Significant_Area_Percent': 100*total_significant_area/total_area,
                'Biodiversity_Area': area*area_biodiversity[0]/total_area,
                'Biodiversity_Area_Percent': 100*area_biodiversity[0]/total_area,
            })
        except Exception as e:
            # Best effort: report the failing region and continue with the rest
            print(f"Error processing {region_name}: {e}")

    # Convert the results to a GeoDataFrame for easy viewing
    area_weighted_trends_gdf = gpd.GeoDataFrame(area_weighted_trends, crs=SEAS_DF.crs)
    return area_weighted_trends_gdf

def area_heatwave(temp_df, SEAS_DF=SEAS_DF):
    """Summarise severe marine heatwave exposure for every sea/ocean polygon.

    A grid cell counts as impacted when ``heatwave_category`` is >= 3 at any
    time step of ``temp_df``.

    Parameters
    ----------
    temp_df : xarray.Dataset
        Dataset with a ``heatwave_category`` variable and a ``time`` dimension.
        Dimensions ``latdim``/``londim`` and coordinates ``lat``/``lon`` are
        renamed to ``y``/``x``.
    SEAS_DF : geopandas.GeoDataFrame
        Regions to summarise; the columns ``NAME``, ``area`` and ``geometry``
        are read from each row.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per successfully processed region with the columns
        ``Region_Name``, ``geometry``, ``Heatwave_Area``,
        ``Heatwave_Area_Percent``, ``Sea_Area``, ``Biodiversity_Area`` and
        ``Biodiversity_Area_Percent``. Regions that raise during processing
        are reported on stdout and skipped.
    """
    records = []

    # Set CRS and rename dimensions and coordinates
    temp_df = temp_df.rio.write_crs("epsg:4326")
    temp_df = temp_df.rename({'latdim': 'y', 'londim': 'x'}).rename({'lat': 'y', 'lon': 'x'})  # Adjust based on your dimensions

    # Select heatwave categories >= 3 and aggregate over time
    temp_df = (temp_df['heatwave_category'] >= 3).any(dim='time')

    # Interpolate biodiversity priority areas to the same resolution as the climate data
    masked_data_interp = masked_data.interp(
        x=temp_df['x'],
        y=temp_df['y'],
        method='nearest'
    )

    # Approximate the area of each grid cell (assumes a regular lat/lon grid)
    lat_rad = np.deg2rad(temp_df['y'].values)
    lon_rad = np.deg2rad(temp_df['x'].values)

    # Earth radius in kilometers
    R = 6371
    dlat = np.gradient(lat_rad)
    dlon = np.gradient(lon_rad)

    # abs() keeps the areas positive when an axis is stored in descending
    # order; every result below is a ratio of area sums, so this does not
    # change any reported value.
    cell_areas = np.abs((R**2 * np.outer(np.sin(dlat), dlon)) * np.cos(lat_rad[:, None]))

    # The full-grid area array does not depend on the region: build it once
    # and only clip it inside the loop.
    cell_areas_da = xr.DataArray(
        cell_areas,
        dims=['y', 'x'],
        coords={'y': temp_df['y'], 'x': temp_df['x']}
    ).rio.write_crs('EPSG:4326')

    # Use tqdm to track progress through SEAS_DF.iterrows()
    for i, row in tqdm(SEAS_DF.iterrows(), total=len(SEAS_DF), desc="Processing Sea Areas"):
        # Resolved outside the try block: the except handler formats this
        # name, so it must be bound even when the very first lookup fails.
        region_name = row.get('NAME', i)
        try:
            area = row['area']
            geom = row['geometry']

            # Restrict the heatwave mask and the cell areas to the sea geometry
            masked_df = temp_df.rio.clip([geom], drop=True)
            cell_areas_clipped = cell_areas_da.rio.clip([geom], drop=True)

            # Compute the total area that is impacted by a severe heatwave
            heatwave_area = (masked_df * cell_areas_clipped).sum(dim=('y', 'x')).compute()  # Compute to convert from Dask array

            # Heatwave area that also lies in a biodiversity priority cell
            area_biodiversity = ((masked_df * cell_areas_clipped) * masked_data_interp).sum(dim=['x', 'y']).compute()

            total_area = cell_areas_clipped.sum(dim=('y', 'x')).compute()  # Ensure computation

            # Reduce each result to a plain scalar (first element when an
            # extra leading dimension survives the sum)
            heatwave_value = heatwave_area.item() if heatwave_area.size == 1 else heatwave_area.values[0]
            total_area_value = total_area.item() if total_area.size == 1 else total_area.values[0]
            area_biodiversity = area_biodiversity.item() if area_biodiversity.size == 1 else area_biodiversity.values[0]

            # Store the result
            records.append({
                'Region_Name': region_name,
                'geometry': geom,
                'Heatwave_Area': area*heatwave_value/total_area_value,
                'Heatwave_Area_Percent': 100*heatwave_value/total_area_value,
                'Sea_Area': area,
                'Biodiversity_Area': area*area_biodiversity/total_area_value,
                'Biodiversity_Area_Percent': 100*area_biodiversity/total_area_value,
            })
        except Exception as e:
            # Best effort: report the failing region and continue with the rest
            print(f"Error processing {region_name}: {e}")

    # Convert the results to a GeoDataFrame for easy viewing
    area_heatwave_gdf = gpd.GeoDataFrame(records, crs=SEAS_DF.crs)
    return area_heatwave_gdf
    

# Temperature

## Figure 1

<p align="center">
  <img src="Figs/climate_temperature_1.png" style="width:50%;">
</p>

In [None]:
# Both series were downloaded from the GISS Surface Temperature Analysis (v4)
ocean_data = pd.read_csv("../Data/GISTEMP_SST.csv")
gmst_data = pd.read_csv("../Data/GMST_GISTEMP4.csv")

# Use the 5-year Lowess-smoothed column of each file as the annual series
ocean_data['Ocean_Annual'] = ocean_data['Lowess(5)']
gmst_data['GMST_Annual'] = gmst_data['Lowess(5)']

temp_data = ocean_data.merge(gmst_data, on='Year')

# Re-express both series as anomalies relative to their 1880-1900 mean
base_period = temp_data[temp_data['Year'].between(1880, 1900)]
mean_sst_base_period = base_period['Ocean_Annual'].mean()
mean_gmst_base_period = base_period['GMST_Annual'].mean()
temp_data['Ocean_Annual'] -= mean_sst_base_period
temp_data['GMST_Annual'] -= mean_gmst_base_period

# Fit a straight line (y = mx + b) to each anomaly series against the year
# and store the fitted value for every row
for series_col, trend_col in (('Ocean_Annual', 'ocean_trend'), ('GMST_Annual', 'gmst_trend')):
    slope, intercept, *_ = linregress(temp_data['Year'], temp_data[series_col])
    temp_data[trend_col] = intercept + slope * temp_data['Year']

# Constant reference line at 1.5
temp_data['paris_goal'] = 1.5

figure_columns = ['Year', 'Ocean_Annual', 'ocean_trend', 'GMST_Annual', 'gmst_trend', 'paris_goal']

# Save out as JSON: a pretty-printed list with one dict per row
temp_data[figure_columns].to_json(
    "../Data/Figure_1_temperature.json",
    orient="records",
    indent=2,
)

# Display the exported columns
temp_data[figure_columns]


## Figure 2

<p align="center">
  <img src="Figs/climate_temperature_2.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
import rioxarray

trends_df = xr.open_dataset("../Data/global_omi_tempsal_sst_trend_19932021_P20220331.nc")

# --------------------
# CONFIG
# --------------------
OUTPUT_TIF_8BIT   = "../Data/Figure_2_temperature.tif"
OUTPUT_TIF_FLOAT  = "../Data/sst_trend_float32.tif"
USE_PERCENTILES   = True
P_LOW, P_HIGH     = 2, 98
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # only used when USE_PERCENTILES is False

# Encoding keys inherited from the NetCDF file that are cleared before writing
ENCODING_KEYS = ("_FillValue", "missing_value", "scale_factor", "add_offset")
# GeoTIFF creation options shared by both outputs
GTIFF_OPTIONS = dict(compress="LZW", tiled=True, blockxsize=256, blockysize=256)

# 1) Pick the trend variable; average over time when a time axis is present
da = trends_df["sst_trends"]
if "time" in da.dims:
    da = da.mean(dim="time")

# 2) Map whichever lat/lon naming the file uses onto y/x
rename_map = {}
for old_name, new_name in (("lat", "y"), ("latitude", "y"), ("lon", "x"), ("longitude", "x")):
    if old_name in da.dims or old_name in da.coords:
        rename_map[old_name] = new_name
if rename_map:
    da = da.rename(rename_map)

# 3) Store rows north -> south (descending y)
y_values = da["y"].values
if y_values[0] < y_values[-1]:
    da = da.sortby("y", ascending=False)

# 4) Register the spatial dims and assign EPSG:4326 when no CRS is present
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# 5) Work on a float32 copy and remember which cells hold finite values
data = da.data.astype(np.float32)
valid = np.isfinite(data)

if USE_PERCENTILES:
    vmin = float(np.nanpercentile(data, P_LOW))
    vmax = float(np.nanpercentile(data, P_HIGH))
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

if not (np.isfinite(vmin) and np.isfinite(vmax) and vmin < vmax):
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")

# 6) 8-bit layer: 0 stays reserved for NoData, valid cells are clipped to
#    [vmin, vmax] and mapped linearly onto 1..255
scaled = np.zeros_like(data, dtype=np.uint8)
scaled[valid] = (np.clip((data[valid] - vmin) / (vmax - vmin), 0.0, 1.0) * 254 + 1).astype(np.uint8)

da8 = da.copy(data=scaled)

# Drop the fill value / scaling encodings carried over from the source file
# so they are not applied to the uint8 output
for key in ENCODING_KEYS:
    da8.encoding.pop(key, None)

# Set nodata appropriate for uint8 (0)
da8 = da8.rio.write_nodata(0, encoded=False, inplace=False)

# 7a) Write 8-bit GeoTIFF
da8.rio.to_raster(OUTPUT_TIF_8BIT, dtype="uint8", **GTIFF_OPTIONS)

# 7b) Also write the native values as float32, with NaN as nodata
daf = da.where(np.isfinite(da))
for key in ENCODING_KEYS:
    daf.encoding.pop(key, None)
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)

daf.rio.to_raster(OUTPUT_TIF_FLOAT, dtype="float32", **GTIFF_OPTIONS)

print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"Scale for 8-bit: vmin={vmin:.4f}, vmax={vmax:.4f} (units of sst_trends)."
)


## Figure 3

<p align="center">
  <img src="Figs/climate_temperature_3.png" style="width:50%;">
</p>

**This figure calculates the total area for each major sea region globally. It also calculates 1) the area for each sea region impacted by a severe marine heatwave in the year 2023 as well as 2) the area for each sea that both has priority biodiversity areas AND is impacted by a severe marine heatwave.**

In [None]:
# Year 2023 heatwave data downloaded from NOAA's Coral Reef Watch https://coralreefwatch.noaa.gov/product/marine_heatwave/
import xarray as xr
import glob

# Resolve the NetCDF files once, in sorted order, and open exactly that list
files = sorted(glob.glob("../Data/2023/*.nc"))
if not files:
    # Fail early with a clear message instead of deep inside open_mfdataset
    raise FileNotFoundError("No NetCDF files found in ../Data/2023/")

temp_df = xr.open_mfdataset(files)

# Per-sea area impacted by a severe heatwave, overall and within
# biodiversity priority areas
area_df = area_heatwave(temp_df)

# Save the GeoDataFrame to a GeoJSON file
area_df.to_file("../Data/Figure_3_temperature.geojson",driver="GeoJSON")

## Figure 4

<p align="center">
  <img src="Figs/climate_temperature_4.png" style="width:50%;">
</p>

$CO_2$ data retrieved from [Lan, X., P. Tans, & K.W. Thoning (2025). Trends in globally-averaged CO₂ determined from NOAA Global Monitoring Laboratory measurements (Version 2025-11) NOAA Global Monitoring Laboratory. https://doi.org/10.15138/9N0H-ZH07](https://gml.noaa.gov/webdata/ccgg/trends/co2/co2_annmean_mlo.txt). Temperature data retrieved from GISS Surface Temperature Analysis (v4).

## Figure 5

<p align="center">
  <img src="Figs/climate_temperature_5.png" style="width:50%;">
</p>

Data were retrieved from [Hughes, T. P., et al. (2018). Spatial and temporal patterns of mass bleaching of corals in the Anthropocene. Science. – processed by Our World in Data](https://ourworldindata.org/grapher/coral-bleaching-events).

## Figure 6

<p align="center">
  <img src="Figs/climate_temperature_6.png" style="width:50%;">
</p>

In [None]:
#!/usr/bin/env python3
import xarray as xr
import rioxarray  # pip install rioxarray rasterio

# --- Load the 2023 files and count heatwave days per cell ---
ds = xr.open_mfdataset("../Data/2023/*.nc", combine="by_coords")

# Number of time steps with heatwave_category >= 1, stored as float32
hw_count = (ds["heatwave_category"] >= 1).sum(dim="time").astype("float32")

# --- Rechunk for the quantile computation ---
# Every remaining dimension is spatial; each one is put into a single chunk
spatial_dims = [d for d in hw_count.dims if d != "time"]
hw_q = hw_count.chunk(dict.fromkeys(spatial_dims, -1))

# --- 2nd and 98th percentiles over all cells ---
q = hw_q.quantile([0.02, 0.98], dim=spatial_dims, skipna=True).compute()
p2 = float(q.sel(quantile=0.02).values)
p98 = float(q.sel(quantile=0.98).values)

print(f"2nd percentile: {p2}, 98th percentile: {p98}")

if p98 <= p2:
    raise ValueError("98th percentile is not greater than 2nd percentile; cannot scale safely.")

# --- Stretch [p2, p98] onto 0-255, clip the tails and send missing cells to 0 ---
scaled = ((hw_count - p2) / (p98 - p2) * 255.0).clip(0, 255).fillna(0)

# Round and cast to uint8
scaled_uint8 = scaled.round().astype("uint8")

# --- Spatial dims and CRS for Mapbox ---
# The first candidate name found among the dims is used for each axis
cands_x = ["lon", "longitude", "x", "londim"]
cands_y = ["lat", "latitude", "y", "latdim"]

x_dim = next(d for d in cands_x if d in scaled_uint8.dims)
y_dim = next(d for d in cands_y if d in scaled_uint8.dims)

scaled_uint8 = scaled_uint8.rio.set_spatial_dims(x_dim=x_dim, y_dim=y_dim, inplace=False)

# Assign WGS84 when no CRS is attached
if not scaled_uint8.rio.crs:
    scaled_uint8 = scaled_uint8.rio.write_crs("EPSG:4326")

# 0 is declared as nodata so it can be rendered transparent in Mapbox
scaled_uint8 = scaled_uint8.rio.write_nodata(0)

# --- Save as GeoTIFF ---
out_path = "Figure_6_temperature.tif"
scaled_uint8.rio.to_raster(out_path, dtype="uint8")

print(f"Saved GeoTIFF: {out_path}")


## Figure 7

<p align="center">
  <img src="Figs/climate_temperature_7.png" style="width:50%;">
</p>

Data were retrieved from [Jones et al. (2024) – with major processing by Our World in Data](https://ourworldindata.org/co2-and-greenhouse-gas-emissions#explore-data-on-co2-and-greenhouse-gas-emissions).

# Salinity

In [None]:
# Reload the per-sea heatwave summary exported for temperature Figure 3
df = gpd.read_file("../Data/Figure_3_temperature.geojson")

# Share of the total sea area that is both a biodiversity priority area and
# impacted by a severe heatwave
biodiversity_total = df['Biodiversity_Area'].sum()
sea_total = df['Sea_Area'].sum()
biodiversity_total / sea_total

## Figure 1

<p align="center">
  <img src="Figs/climate_salinity_1.png" style="width:50%;">
</p>

In [None]:
import xarray as xr
import pandas as pd

import numpy as np

salt_df = xr.open_dataset("~/Downloads/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")

# Global-mean salinity per year. On a regular lat/lon grid the cells shrink
# towards the poles, so the spatial mean is weighted by cos(latitude);
# a plain mean over lat/lon would over-represent high-latitude cells.
lat_weights = np.cos(np.deg2rad(salt_df['lat']))
salt_df = (
    salt_df['salinity']
    .weighted(lat_weights)
    .mean(dim=['lat', 'lon'])
    .resample(time='Y')
    .mean()
)

# Keep the annual values from 1994 onwards
final_subset = salt_df.sel(time=slice('1994-01-01', None))

# Create a pandas DataFrame with these columns
df = pd.DataFrame({
    'time': final_subset['time'].values,
    'salinity': final_subset.values,
})

# Convert 'time' to datetime
df['time'] = pd.to_datetime(df['time'])

# Convert datetime to a numerical value for linear regression (using ordinal format)
df['time_ordinal'] = df['time'].map(pd.Timestamp.toordinal)

# Perform linear regression to find the slope and intercept
slope, intercept, _, _, _ = linregress(df['time_ordinal'], df['salinity'])

# Calculate the trend line (y = mx + b) for each time point
df['linear_trend'] = intercept + slope * df['time_ordinal']

df[['time','salinity','linear_trend']].to_csv("../Data/Figure_1_salinity.csv")

# Display the updated DataFrame
df[['time','salinity','linear_trend']].head()

## Figure 2

<p align="center">
  <img src="Figs/climate_salinity_2.png" style="width:50%;">
</p>

In [None]:
import numpy as np
import xarray as xr
import rioxarray

# Resampling enum for rio.reproject; imported from rasterio, where it is defined
from rasterio.enums import Resampling

salt_df = xr.open_dataset("~/Downloads/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc")

# Per-pixel linear trend of the yearly mean salinity
trend_significance_ds = calculate_trend_df(salt_df['salinity'])

# --------------------
# CONFIG
# --------------------
OUTPUT_TIF_8BIT   = "./Figure_2_salinity.tif"
OUTPUT_TIF_FLOAT  = "./salinity_trend_float32.tif"
USE_PERCENTILES   = True
P_LOW, P_HIGH     = 2, 98
VMIN_FIXED, VMAX_FIXED = -0.5, 0.5  # if USE_PERCENTILES=False

# 1) Select the dataarray: salinity trend
da = trend_significance_ds['trend']

# 2) Rename to x/y if needed
rename_map = {}
if "lat" in da.dims or "lat" in da.coords: rename_map["lat"] = "y"
if "latitude" in da.dims or "latitude" in da.coords: rename_map["latitude"] = "y"
if "lon" in da.dims or "lon" in da.coords: rename_map["lon"] = "x"
if "longitude" in da.dims or "longitude" in da.coords: rename_map["longitude"] = "x"
if rename_map:
    da = da.rename(rename_map)

# 3) Ensure y is north→south (descending)
if da["y"].values[0] < da["y"].values[-1]:
    da = da.sortby("y", ascending=False)
# ensure (y, x) order
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")

# 4) Register spatial dims & force EPSG:4326
da = da.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=False)

# If CRS is missing, assume geographic lon/lat
if da.rio.crs is None:
    da = da.rio.write_crs("EPSG:4326", inplace=False)

# If CRS is not 4326, reproject it so the saved rasters are truly EPSG:4326
if str(da.rio.crs) != "EPSG:4326":
    # Bilinear resampling suits a continuous field such as the salinity
    # trend; nearest would be the choice for categorical data
    da = da.rio.reproject("EPSG:4326", resampling=Resampling.bilinear)

# Reassert dims order after any reprojection (just in case)
if tuple(da.dims) != ("y", "x"):
    da = da.transpose("y", "x")


# 5) Compute scaling range (Dask-safe)
import dask.array as dsa

def compute_percentiles_safe(da, p_low, p_high):
    """Return the ``(p_low, p_high)`` percentiles of ``da``, ignoring NaNs.

    For a chunked array the percentiles are first attempted with
    ``dask.array.nanpercentile``. If the array is not chunked, or that
    attempt fails, the array is subsampled along ``y`` and ``x`` (roughly
    512 samples per axis), loaded, and passed to ``numpy.nanpercentile``.

    Parameters
    ----------
    da : xarray.DataArray
        Array to analyse.
    p_low, p_high : float
        Percentiles to compute, in the range 0-100.

    Returns
    -------
    tuple of float
        ``(vmin, vmax)``.
    """
    if getattr(da, "chunks", None):
        try:
            vmin = float(dsa.nanpercentile(da.data, p_low).compute())
            vmax = float(dsa.nanpercentile(da.data, p_high).compute())
            return vmin, vmax
        except Exception:
            # Deliberate best effort: any failure of the dask path falls
            # through to the sampled computation below
            pass

    # Fallback: sample every Nth pixel (keeps memory tiny). Only dimensions
    # that actually exist are indexed; passing an absent dimension to isel
    # would raise.
    indexers = {}
    for dim in ("y", "x"):
        if dim in da.dims:
            step = max(int(da.sizes[dim] // 512), 1)
            indexers[dim] = slice(0, None, step)
    das = da.isel(indexers).load()  # small enough to load

    vmin = float(np.nanpercentile(das.values, p_low))
    vmax = float(np.nanpercentile(das.values, p_high))
    return vmin, vmax

# Scaling range for the 8-bit layer: robust percentiles or the fixed limits
if USE_PERCENTILES:
    vmin, vmax = compute_percentiles_safe(da, P_LOW, P_HIGH)
else:
    vmin, vmax = float(VMIN_FIXED), float(VMAX_FIXED)

if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
    raise ValueError(f"Bad scaling range: vmin={vmin}, vmax={vmax}")


# 6) Build the 8-bit layer with xr.where (no boolean indexing)
# Reserve 0 for NoData; valid cells map to [1,255]
norm = ((da - vmin) / (vmax - vmin)).clip(0, 1)
scaled_da = xr.where(np.isfinite(da), norm * 254.0 + 1.0, 0.0).astype("uint8")

# Clear fill-value / scaling encodings so they are not applied to the uint8 output
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    scaled_da.encoding.pop(k, None)

# Set nodata appropriate for uint8 (0)
scaled_da = scaled_da.rio.write_nodata(0, encoded=False, inplace=False)

# 7a) Write 8-bit GeoTIFF
scaled_da.rio.to_raster(
    OUTPUT_TIF_8BIT,
    dtype="uint8",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# 7b) Write float32 GeoTIFF with native values
daf = da.where(np.isfinite(da)).astype("float32")
for k in ("_FillValue", "missing_value", "scale_factor", "add_offset"):
    daf.encoding.pop(k, None)
# Use NaN as nodata for float32
daf = daf.rio.write_nodata(np.nan, encoded=False, inplace=False)

daf.rio.to_raster(
    OUTPUT_TIF_FLOAT,
    dtype="float32",
    compress="LZW",
    tiled=True,
    blockxsize=256,
    blockysize=256,
)

# The rasters hold the salinity trend, so the message names that quantity
print(
    f"Wrote {OUTPUT_TIF_8BIT} (uint8; 0=NoData, 1–255=data) and {OUTPUT_TIF_FLOAT} (float32). "
    f"8-bit scale: vmin={vmin:.4f}, vmax={vmax:.4f} (native salinity-trend units)."
)


## Figure 3

<p align="center">
  <img src="Figs/climate_salinity_3.png" style="width:50%;">
</p>

In [None]:
# Source dataset and destination of the per-sea summary
SALINITY_NC = "~/Downloads/OceanSODA_ETHZ-v2023.OCADS.01_1982-2022.nc"
FIGURE_3_GEOJSON = "../Data/Figure_3_salinity.geojson"

salt_df = xr.open_dataset(SALINITY_NC)

# Per-pixel salinity trend and its significance
trend_significance_ds = calculate_trend_df(salt_df['salinity'])

# Aggregate the per-pixel results for each sea/ocean region
area_df = area_trend(trend_significance_ds)

# Save the GeoDataFrame to a GeoJSON file
area_df.to_file(FIGURE_3_GEOJSON, driver="GeoJSON")

# Release the large objects once the file is written
del salt_df, area_df

## Figure 4

<p align="center">
  <img src="Figs/climate_salinity_4.png" style="width:50%;">
</p>

**This figure reproduces Figure 3A from [Delworth et al. (2022)](https://www.pnas.org/doi/10.1073/pnas.2116655119). It plots an index of the Atlantic Meridional Overturning Circulation (AMOC) and how it changes over time. There are 6 scenarios: 2 of them have historical data ("HISTORICAL", representing a historical evolution of anthropogenic emissions, and "NATURAL", with no changes to anthropogenic emissions), and 5 of them have projections ("NATURAL" again, as well as four Shared Socioeconomic Pathway (SSP) scenarios for how emissions will change in the future). Each scenario is run several times, which provides the measures of uncertainty presented in the figure.**

In [None]:
# Read the raw AMOC text file; it is parsed line by line below
with open('../Data/TAR_FIGURE_3_AMOC_45N', 'r') as file:
    lines = file.readlines()

# A header line names a run (scenario + ensemble number); any other line is
# scanned for numbers belonging to the most recent header
header_pattern = re.compile(r'SPEAR_c192_o1_(.+)_ENS_(\d+)')
number_pattern = re.compile(r"[-+]?\d*\.\d+|\d+")

scenario_data = {}
current_scenario = None
time_range = range(1921, 2100)

for line in lines:
    header = header_pattern.search(line)
    if header:
        # Key: full scenario name (like HIST_SSP585_ALLForc) plus ensemble number
        current_scenario = header.group(1) + "_" + header.group(2)
        scenario_data[current_scenario] = []
        continue

    numbers = number_pattern.findall(line.strip())
    if numbers and current_scenario:
        scenario_data[current_scenario].extend(float(value) for value in numbers)

# One row per year, one column per scenario/ensemble run
df = pd.DataFrame({'Year': list(time_range)})
n_years = len(time_range)
for scenario, values in scenario_data.items():
    df[scenario] = values[:n_years]  # truncate to the time range

# Scenario tags used to group the run columns
scenarios = ['SSP119', 'SSP245', 'SSP534', 'SSP585', 'NATURAL']
summary_stats = ('Mean', 'CI_Lower', 'CI_Upper')

result = pd.DataFrame()
result['Year'] = df['Year']

# Ensemble mean and 5th/95th percentile band for every scenario
for scenario in scenarios:
    scenario_columns = [col for col in df.columns if scenario in col]
    if not scenario_columns:
        continue  # no runs for this scenario in the file

    members = df[scenario_columns]
    result[f'{scenario}_Mean'] = members.mean(axis=1)
    result[f'{scenario}_CI_Lower'] = members.quantile(0.05, axis=1)
    result[f'{scenario}_CI_Upper'] = members.quantile(0.95, axis=1)

# The historical series is the average of the SSP summaries up to 2014
historical_columns = ['SSP119', 'SSP245', 'SSP534', 'SSP585']
mask_historical = result['Year'] <= 2014

for stat in summary_stats:
    ssp_stat_columns = [f'{scenario}_{stat}' for scenario in historical_columns]
    result[f'Historical_{stat}'] = result.loc[mask_historical, ssp_stat_columns].mean(axis=1)

# Blank the SSP columns for the historical years ...
for scenario in historical_columns:
    for stat in summary_stats:
        result.loc[mask_historical, f'{scenario}_{stat}'] = np.nan

# ... and the historical columns for the years after 2014
for stat in summary_stats:
    result.loc[~mask_historical, f'Historical_{stat}'] = np.nan

result.to_csv("../Data/Figure_4_salinity.csv")

result.head()


## Figure 5

<p align="center">
  <img src="Figs/climate_salinity_5.png" style="width:50%;">
</p>

Monthly data were retrieved from CMEMS [Mercator Ocean International / Copernicus Marine Service (2023). Global Ocean Physics Reanalysis (GLOBAL_MULTIYEAR_PHY_001_030) [Data set]. Copernicus Marine Service. https://doi.org/10.48670/moi-00021](https://data.marine.copernicus.eu/product/GLOBAL_MULTIYEAR_PHY_001_030/description).

In [None]:
#!/usr/bin/env python3
import os, json, datetime as dt
import numpy as np
import xarray as xr
import rioxarray
from affine import Affine

# ---------------------------
# CONFIG
# ---------------------------
# Input NetCDF file and the directory that receives this cell's outputs
IN_PATH = "../Data/cmems_mod_glo_phy_my_0.083deg_P1M-m_1760469819927.nc"
OUT_DIR = "../Data/Figure_5_salinity"
# Variable names tried in order by pick_var before falling back to the first variable
VAR_NAME_HINTS = ["so", "salinity"]
WRITE_FLOAT32 = False   # set True to write Float32 GeoTIFFs instead of 8-bit
LOW_Q, HIGH_Q = 0.02, 0.98   # robust percentiles for 8-bit scaling
COMPRESS = "DEFLATE"    # GeoTIFF compression
NODATA_UINT8 = 0        # 0 reserved as nodata in 8-bit output
OPEN_WITH_CHUNKS = True # try to enable dask-friendly quantiles

# Create the output directory up front; no error if it already exists
os.makedirs(OUT_DIR, exist_ok=True)

def pick_var(ds: xr.Dataset) -> xr.DataArray:
    """Return the data variable to plot.

    The names in ``VAR_NAME_HINTS`` are tried in order; when none of them is
    present, the first data variable of ``ds`` is returned.
    """
    hinted = [name for name in VAR_NAME_HINTS if name in ds.data_vars]
    if hinted:
        return ds[hinted[0]]
    return next(iter(ds.data_vars.values()))

def ensure_xy(ds_like: xr.DataArray) -> xr.DataArray:
    """Return *ds_like* on ``x``/``y`` dims with CRS and affine transform attached.

    Steps:
      * rename ``lon``/``longitude`` to ``x`` and ``lat``/``latitude`` to ``y``;
      * flip a strictly ascending ``y`` axis so rows run north -> south;
      * write EPSG:4326 as the CRS;
      * write an affine transform derived from the coordinate spacing.

    The pixel size is inferred with ``np.diff``, so each axis needs at least
    two coordinate values for the transform to be meaningful.

    Raises:
        ValueError: If the array has no ``x`` or ``y`` dimension after renaming.
    """
    da = ds_like
    rename_map = {}
    if "lon" in da.dims: rename_map["lon"] = "x"
    if "longitude" in da.dims: rename_map["longitude"] = "x"
    if "lat" in da.dims: rename_map["lat"] = "y"
    if "latitude" in da.dims: rename_map["latitude"] = "y"
    if rename_map:
        da = da.rename(rename_map)
    if "x" not in da.dims or "y" not in da.dims:
        raise ValueError(f"Expected 'x' and 'y' dims, found {list(da.dims)}")
    if np.all(np.diff(da["y"].values) > 0):
        da = da.sortby("y", ascending=False)
    da = da.rio.write_crs("EPSG:4326", inplace=False)
    xv, yv = da["x"].values, da["y"].values
    # Signed pixel steps in array order: dy is negative once y runs north -> south.
    dx = float(np.mean(np.diff(xv)))
    dy = float(np.mean(np.diff(yv)))
    # Coordinates are pixel centres, so the outer corner of the first pixel sits
    # half a (signed) step before the first coordinate on each axis. Using the
    # signed steps keeps the transform consistent with the actual array order
    # (the previous `-dy` / `y.max() + dy/2` form yielded a south-up transform
    # with a misplaced origin for descending y).
    x0 = float(xv[0]) - dx / 2.0
    y0 = float(yv[0]) - dy / 2.0
    transform = Affine(dx, 0.0, x0, 0.0, dy, y0)
    da = da.rio.write_transform(transform, inplace=False)
    return da

def compute_global_bounds(var: xr.DataArray, low_q=LOW_Q, high_q=HIGH_Q):
    """Return ``(vmin, vmax)`` scaling bounds for *var* as Python floats.

    The bounds are the *low_q* / *high_q* quantiles taken over the time and
    horizontal dimensions. ``DataArray.quantile`` is tried first; if that
    fails for any reason the values are loaded and ``np.nanquantile`` is used.
    When the quantiles are non-finite or identical, the overall minimum and
    maximum are returned instead.
    """
    try:
        reduce_dims = [
            d for d in var.dims
            if d in ("time","y","x","lat","lon","latitude","longitude")
        ]
        quantiles = var.quantile([low_q, high_q], dim=reduce_dims, skipna=True)
        vmin = float(quantiles.sel(quantile=low_q).values)
        vmax = float(quantiles.sel(quantile=high_q).values)
    except Exception:
        # Loads the whole variable into memory -- may be large!
        loaded = var.values.astype(np.float32)
        q_low, q_high = np.nanquantile(loaded, [low_q, high_q])
        vmin, vmax = float(q_low), float(q_high)

    usable = np.isfinite(vmin) and np.isfinite(vmax) and vmin != vmax
    if usable:
        return vmin, vmax

    # Very flat / missing data: use the global min/max instead
    try:
        vmin = float(var.min(skipna=True).values)
        vmax = float(var.max(skipna=True).values)
    except Exception:
        loaded = var.values.astype(np.float32)
        vmin, vmax = float(np.nanmin(loaded)), float(np.nanmax(loaded))
    return vmin, vmax

def scale_to_uint8_fixed(da: xr.DataArray, vmin: float, vmax: float, nodata_val=NODATA_UINT8):
    """Scale *da* to 8-bit with fixed bounds, reserving *nodata_val* for gaps.

    Finite values are mapped linearly from ``[vmin, vmax]`` onto ``1..255``
    (values outside the bounds are clipped); non-finite cells are written as
    *nodata_val*. If the supplied bounds are non-finite, equal or inverted,
    the slice's own min/max are used; if those are unusable too, an array
    filled with *nodata_val* is returned.

    Returns:
        numpy.ndarray of dtype ``uint8`` with the same shape as ``da.values``.
    """
    arr = da.values.astype(np.float32)
    mask = ~np.isfinite(arr)
    if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
        # Unusable bounds (also avoids divide-by-zero): try this slice's range
        vmin = np.nanmin(arr)
        vmax = np.nanmax(arr)
        if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
            return np.full(arr.shape, nodata_val, dtype=np.uint8)
    scaled = (arr - vmin) / (vmax - vmin + 1e-12)
    scaled = np.clip(scaled, 0.0, 1.0)
    scaled = scaled * 254.0 + 1.0  # 1..255; 0 stays free for nodata
    # Casting NaN to an integer type is undefined and emits a RuntimeWarning,
    # so neutralise the invalid cells before the cast and stamp nodata after.
    out = np.where(mask, 0.0, scaled).astype(np.uint8)
    out[mask] = nodata_val
    return out

def main():
    """Export every time step of the salinity field in IN_PATH as a GeoTIFF.

    One pair of scaling bounds is computed across all time steps so that all
    8-bit rasters share the same value mapping. Files are named with a YYYYMM
    label taken from the time coordinate (or ``tNNNN`` when there is none).
    In 8-bit mode a JSON sidecar with the bounds and the list of written files
    is saved next to the rasters.

    Raises:
        ValueError: If the selected variable has no ``time`` dimension.
    """
    open_kwargs = {"chunks": "auto"} if OPEN_WITH_CHUNKS else {}
    # The context manager closes the NetCDF file even if an export step fails.
    with xr.open_dataset(IN_PATH, **open_kwargs) as ds:
        var = pick_var(ds)
        if "time" not in var.dims:
            raise ValueError("Expected a 'time' dimension in salinity DataArray.")

        # --- Compute one pair of robust bounds across ALL timesteps ---
        global_vmin, global_vmax = compute_global_bounds(var, LOW_Q, HIGH_Q)
        print(f"Global scaling bounds ({LOW_Q:.2f}-{HIGH_Q:.2f} quantiles): vmin={global_vmin:.4f}, vmax={global_vmax:.4f}")

        scaling_meta = {
            "variable": str(var.name),
            "quantiles": [LOW_Q, HIGH_Q],
            "global_vmin": float(global_vmin),
            "global_vmax": float(global_vmax),
            "nodata_uint8": NODATA_UINT8,
            "files": []
        }

        for i in range(var.sizes["time"]):
            # Drop any remaining size-1 dimensions so the slice is a plain 2-D grid
            slice_t = var.isel(time=i).squeeze(drop=True)
            da = ensure_xy(slice_t)

            if "time" in slice_t.coords:
                tlabel = np.datetime_as_string(slice_t["time"].values, unit="D").replace("-", "")[:6]  # YYYYMM
            else:
                tlabel = f"t{i:04d}"

            if WRITE_FLOAT32:
                out_path = os.path.join(OUT_DIR, f"salinity_{tlabel}_float32.tif")
                da = da.rio.write_nodata(np.nan, inplace=False)
                da.rio.to_raster(out_path, dtype="float32", tiled=True, compress=COMPRESS)
                print(f"Wrote {out_path}")
                scaling_meta["files"].append({"file": os.path.basename(out_path), "tindex": i, "tlabel": tlabel})
            else:
                scaled = scale_to_uint8_fixed(da, global_vmin, global_vmax, NODATA_UINT8)
                out_path = os.path.join(OUT_DIR, f"salinity_{tlabel}.tif")
                # Re-wrap the raw uint8 array with the slice's coordinates and georeferencing
                da8 = xr.DataArray(
                    scaled, dims=("y","x"), coords={"y": da["y"], "x": da["x"]},
                    name="salinity_uint8",
                    attrs={"long_name": "Sea Water Practical Salinity (scaled 8-bit, global bounds)",
                           "units": "unitless (0=nodata)",
                           "scale_vmin": global_vmin, "scale_vmax": global_vmax}
                ).rio.write_crs(da.rio.crs).rio.write_transform(da.rio.transform())
                da8 = da8.rio.write_nodata(NODATA_UINT8, inplace=False)
                da8.rio.to_raster(out_path, dtype="uint8", tiled=True, compress=COMPRESS)
                scaling_meta["files"].append({"file": os.path.basename(out_path), "tindex": i, "tlabel": tlabel})
                print(f"Wrote {out_path} using global scaling [{global_vmin:.4f}, {global_vmax:.4f}]")

    # The sidecar is only needed to interpret the 8-bit values
    if not WRITE_FLOAT32:
        meta_path = os.path.join(OUT_DIR, "salinity_scaling_metadata.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(scaling_meta, f, indent=2)
        print(f"Saved {meta_path}")

if __name__ == "__main__":
    main()


## Figure 6

<p align="center">
  <img src="Figs/climate_salinity_6.png" style="width:50%;">
</p>

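Monthly eastward (`uo`) and northward (`vo`) ocean current velocity data were retrieved from CMEMS [Mercator Ocean International / Copernicus Marine Service (2023). Global Ocean Physics Reanalysis (GLOBAL_MULTIYEAR_PHY_001_030) [Data set]. Copernicus Marine Service. https://doi.org/10.48670/moi-00021](https://data.marine.copernicus.eu/product/GLOBAL_MULTIYEAR_PHY_001_030/description). The cell below combines the two components into current speed, $\sqrt{uo^2 + vo^2}$, rescales it to 8-bit using a single, global pair of 2nd–98th percentile bounds (or keeps Float32 values when `WRITE_FLOAT32` is set) and exports one GeoTIFF per month to `../Data/Figure_6_salinity`.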

In [None]:
#!/usr/bin/env python3
import os, json
import numpy as np
import xarray as xr
import rioxarray
from affine import Affine

# ---------------------------
# CONFIG
# ---------------------------
# Input NetCDF files (CMEMS monthly downloads), one per velocity component
U_PATH = "../Data/cmems_mod_glo_phy_my_0.083deg_P1M-m_1760473504514.nc"  # uo
V_PATH = "../Data/cmems_mod_glo_phy_my_0.083deg_P1M-m_1760473531668.nc"  # vo
# Directory receiving one current-speed GeoTIFF per time step plus the JSON sidecar
OUT_DIR = "../Data/Figure_6_salinity"
WRITE_FLOAT32 = False        # True -> Float32 GeoTIFFs; False -> 8-bit scaled GeoTIFFs
LOW_Q, HIGH_Q = 0.02, 0.98   # robust percentile bounds used globally across ALL timesteps
COMPRESS = "DEFLATE"         # GeoTIFF compression
NODATA_UINT8 = 0             # 0 reserved as nodata in 8-bit output
OPEN_WITH_CHUNKS = True      # dask-friendly

# Create the output directory up front; a pre-existing directory is not an error
os.makedirs(OUT_DIR, exist_ok=True)

# ---------------------------
# helpers
# ---------------------------
def _ensure_xy(da: xr.DataArray) -> xr.DataArray:
    """Rename lon/lat to x/y if needed, set CRS, ensure y desc, attach affine transform.

    The pixel size is inferred from the coordinate spacing with ``np.diff``,
    so each axis needs at least two coordinate values for the transform to be
    meaningful.

    Raises:
        ValueError: If the array has no ``x`` or ``y`` dimension after renaming.
    """
    rename_map = {}
    if "lon" in da.dims: rename_map["lon"] = "x"
    if "longitude" in da.dims: rename_map["longitude"] = "x"
    if "lat" in da.dims: rename_map["lat"] = "y"
    if "latitude" in da.dims: rename_map["latitude"] = "y"
    if rename_map:
        da = da.rename(rename_map)
    if "x" not in da.dims or "y" not in da.dims:
        raise ValueError(f"Expected 'x' and 'y' dims, found {list(da.dims)}")
    # y descending (north -> south)
    if np.all(np.diff(da["y"].values) > 0):
        da = da.sortby("y", ascending=False)
    # georef
    da = da.rio.write_crs("EPSG:4326", inplace=False)
    xv, yv = da["x"].values, da["y"].values
    # Signed pixel steps in array order: dy is negative once y runs north -> south.
    dx = float(np.mean(np.diff(xv)))
    dy = float(np.mean(np.diff(yv)))
    # Coordinates are pixel centres, so the outer corner of the first pixel sits
    # half a (signed) step before the first coordinate on each axis. Using the
    # signed steps keeps the transform consistent with the actual array order
    # (the previous `-dy` / `y.max() + dy/2` form yielded a south-up transform
    # with a misplaced origin for descending y).
    x0 = float(xv[0]) - dx / 2.0
    y0 = float(yv[0]) - dy / 2.0
    transform = Affine(dx, 0.0, x0, 0.0, dy, y0)
    da = da.rio.write_transform(transform, inplace=False)
    return da

def _pick_surface(da: xr.DataArray) -> xr.DataArray:
    """
    Return the surface level of *da* if it has a vertical dimension.

    The first dimension whose lower-cased name is one of ``depth``, ``lev``,
    ``z``, ``depthu``, ``depthv`` or ``deptho`` is treated as vertical. When
    that dimension has a numeric coordinate, the level closest to 0 is
    selected; otherwise (or if the selection fails) the first level is taken.
    Arrays without a vertical dimension are returned unchanged.
    """
    z_names = [d for d in da.dims if d.lower() in ("depth","lev","z","depthu","depthv","deptho")]
    if not z_names:
        return da
    z = z_names[0]
    try:
        if z in da.coords and np.issubdtype(da.coords[z].dtype, np.number):
            # Nearest-to-zero is the surface whether depths are stored as
            # positive-down or negative-down values; picking the minimum
            # would return the deepest level for a negative-down axis.
            da = da.sel({z: 0}, method="nearest")
        else:
            da = da.isel({z: 0})
    except Exception:
        # Deliberate best effort: any failure of the label-based selection
        # falls back to the first level along the vertical dimension.
        da = da.isel({z: 0})
    return da

def _compute_global_bounds(var: xr.DataArray, low_q=LOW_Q, high_q=HIGH_Q):
    """Return ``(vmin, vmax)`` scaling bounds for *var* as Python floats.

    The bounds are the *low_q* / *high_q* quantiles over the time and
    horizontal dimensions, computed with ``DataArray.quantile`` and, if that
    fails, with ``np.nanquantile`` on the loaded values. Non-finite or
    non-increasing bounds are replaced by the overall minimum and maximum.
    """
    try:
        reduce_dims = [
            d for d in var.dims
            if d in ("time","y","x","lat","lon","latitude","longitude")
        ]
        quantiles = var.quantile([low_q, high_q], dim=reduce_dims, skipna=True)
        vmin = float(quantiles.sel(quantile=low_q).values)
        vmax = float(quantiles.sel(quantile=high_q).values)
    except Exception:
        # Loads the whole variable into memory
        loaded = var.values.astype(np.float32)
        q_low, q_high = np.nanquantile(loaded, [low_q, high_q])
        vmin, vmax = float(q_low), float(q_high)

    usable = np.isfinite(vmin) and np.isfinite(vmax) and vmin < vmax
    if usable:
        return vmin, vmax

    # Degenerate quantiles: use the global min/max instead
    try:
        vmin = float(var.min(skipna=True).values)
        vmax = float(var.max(skipna=True).values)
    except Exception:
        loaded = var.values.astype(np.float32)
        vmin, vmax = float(np.nanmin(loaded)), float(np.nanmax(loaded))
    return vmin, vmax

def _scale_to_uint8_fixed(da: xr.DataArray, vmin: float, vmax: float, nodata_val=NODATA_UINT8):
    """Scale with fixed vmin/vmax to 8-bit, reserving 0 as nodata.

    Finite values are mapped linearly from ``[vmin, vmax]`` onto ``1..255``
    (values outside the bounds are clipped); non-finite cells are written as
    *nodata_val*. If the supplied bounds are non-finite or not increasing,
    the slice's own min/max are used; if those are unusable too, an array
    filled with *nodata_val* is returned.

    Returns:
        numpy.ndarray of dtype ``uint8`` with the same shape as ``da.values``.
    """
    arr = da.values.astype(np.float32)
    mask = ~np.isfinite(arr)
    if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
        # Unusable bounds: try this slice's own range
        vmin = np.nanmin(arr)
        vmax = np.nanmax(arr)
        if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
            return np.full(arr.shape, nodata_val, dtype=np.uint8)
    scaled = (arr - vmin) / (vmax - vmin + 1e-12)
    scaled = np.clip(scaled, 0.0, 1.0)
    scaled = scaled * 254.0 + 1.0  # 1..255; 0 reserved as nodata
    # Casting NaN to an integer type is undefined and emits a RuntimeWarning,
    # so neutralise the invalid cells before the cast and stamp nodata after.
    out = np.where(mask, 0.0, scaled).astype(np.uint8)
    out[mask] = nodata_val
    return out

# ---------------------------
# main
# ---------------------------
def main():
    """Export monthly surface current speed from U_PATH / V_PATH as GeoTIFFs.

    The ``uo`` and ``vo`` components are merged, reduced to the surface level,
    and combined into speed ``sqrt(uo**2 + vo**2)``. One pair of scaling
    bounds is computed across all time steps so that all 8-bit rasters share
    the same value mapping. Files are named with a YYYYMM label taken from the
    time coordinate (or ``tNNNN`` when there is none). In 8-bit mode a JSON
    sidecar with the bounds and the list of written files is saved as well.

    Raises:
        ValueError: If ``uo``/``vo`` are missing from the merged dataset or
            either component lacks a ``time`` dimension.
    """
    open_kwargs = {"chunks": "auto"} if OPEN_WITH_CHUNKS else {}
    # Merge u and v by coords; the context manager closes both files on exit,
    # including when an export step fails.
    with xr.open_mfdataset([U_PATH, V_PATH], combine="by_coords", **open_kwargs) as ds:

        if "uo" not in ds.data_vars or "vo" not in ds.data_vars:
            # No alternative names are searched for: fail with the list of
            # variables that were actually found.
            raise ValueError(f"Could not find variables 'uo' and 'vo' in the provided datasets. Found: {list(ds.data_vars)}")

        uo = ds["uo"]
        vo = ds["vo"]

        # Surface slice if needed
        uo_sfc = _pick_surface(uo)
        vo_sfc = _pick_surface(vo)

        # Ensure time alignment
        if "time" not in uo_sfc.dims or "time" not in vo_sfc.dims:
            raise ValueError("Expected a 'time' dimension in both uo and vo.")
        # align by time intersection to be safe
        uo_sfc, vo_sfc = xr.align(uo_sfc, vo_sfc, join="inner")

        # Compute speed magnitude
        speed = xr.apply_ufunc(
            lambda a, b: np.sqrt(a*a + b*b),
            uo_sfc, vo_sfc,
            dask="parallelized", output_dtypes=[np.float32]
        )
        speed.name = "current_speed"
        speed.attrs.update({"long_name": "Ocean current speed", "units": "m s-1"})

        # Compute one pair of robust bounds across ALL timesteps
        global_vmin, global_vmax = _compute_global_bounds(speed, LOW_Q, HIGH_Q)
        print(f"Global speed scaling bounds ({LOW_Q:.2f}-{HIGH_Q:.2f} quantiles): "
              f"vmin={global_vmin:.5f}, vmax={global_vmax:.5f}")

        # Metadata holder
        scaling_meta = {
            "variable": "current_speed",
            "source_vars": ["uo","vo"],
            "units": "m s-1",
            "quantiles": [LOW_Q, HIGH_Q],
            "global_vmin": float(global_vmin),
            "global_vmax": float(global_vmax),
            "nodata_uint8": NODATA_UINT8,
            "files": []
        }

        # Iterate each time slice
        T = speed.sizes["time"]
        for i in range(T):
            # Drop any remaining size-1 dimensions so the slice is a plain 2-D grid
            slice_t = speed.isel(time=i).squeeze(drop=True)

            # Ensure x/y + georeferencing
            da = _ensure_xy(slice_t)

            # Timestamp label YYYYMM
            if "time" in slice_t.coords:
                tlabel = np.datetime_as_string(slice_t["time"].values, unit="D").replace("-", "")[:6]
            else:
                tlabel = f"t{i:04d}"

            if WRITE_FLOAT32:
                out_path = os.path.join(OUT_DIR, f"current_speed_{tlabel}_float32.tif")
                da = da.rio.write_nodata(np.nan, inplace=False)
                da.rio.to_raster(out_path, dtype="float32", tiled=True, compress=COMPRESS)
                print(f"Wrote {out_path}")
                scaling_meta["files"].append({"file": os.path.basename(out_path), "tindex": i, "tlabel": tlabel})
            else:
                scaled = _scale_to_uint8_fixed(da, global_vmin, global_vmax, NODATA_UINT8)
                out_path = os.path.join(OUT_DIR, f"current_speed_{tlabel}.tif")
                # Re-wrap the raw uint8 array with the slice's coordinates and georeferencing
                da8 = xr.DataArray(
                    scaled, dims=("y","x"), coords={"y": da["y"], "x": da["x"]},
                    name="current_speed_uint8",
                    attrs={
                        "long_name": "Ocean current speed (scaled 8-bit, global bounds)",
                        "units": "unitless (0=nodata)",
                        "scale_vmin": global_vmin,
                        "scale_vmax": global_vmax
                    }
                ).rio.write_crs(da.rio.crs).rio.write_transform(da.rio.transform())
                da8 = da8.rio.write_nodata(NODATA_UINT8, inplace=False)
                da8.rio.to_raster(out_path, dtype="uint8", tiled=True, compress=COMPRESS)
                scaling_meta["files"].append({"file": os.path.basename(out_path), "tindex": i, "tlabel": tlabel})
                print(f"Wrote {out_path} using global scaling [{global_vmin:.5f}, {global_vmax:.5f}]")

    # Save sidecar metadata (for legends, rescaling)
    if not WRITE_FLOAT32:
        meta_path = os.path.join(OUT_DIR, "current_speed_scaling_metadata.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(scaling_meta, f, indent=2)
        print(f"Saved {meta_path}")

if __name__ == "__main__":
    main()
