In [None]:
### Walsh-Lawler Seasonality Index (intra-annual variability for precipitation and potential evapotranspiration)
# -*- coding: utf-8 -*-
import xarray as xr
import pandas as pd
import numpy as np
import rasterio
from rasterio.transform import from_origin
import os

def compute_and_save_seasonality_index(nc_path, output_tif_path, var_name='etp'):
    """
    Compute the Walsh-Lawler Seasonality Index (SI) from monthly data and save the result as a GeoTIFF.

    Following Walsh and Lawler (1981). For every year with monthly values x_1..x_12
    and annual total R = x_1 + ... + x_12:

        SI_year = (1 / R) * sum_n |x_n - R / 12|

    The per-pixel SI is the mean of SI_year over all years. A year with any missing
    month yields NaN for that year (the deviation sum uses skipna=False). Such years
    are skipped by the multi-year mean.

    Parameters:
    nc_path (str): Path to the input NetCDF file holding a monthly variable with
        time, lat and lon coordinates. Time values are parsed by pandas with the
        '%Y-%m' format.
    output_tif_path (str): Path of the GeoTIFF to write (single float32 band,
        EPSG:4326, NaN nodata). Missing parent directories are created.
    var_name (str): Data variable to process. Defaults to 'etp' (potential
        evapotranspiration). Pass the precipitation variable name to compute the
        precipitation SI from a file of the same layout.

    Returns:
    xarray.DataArray: The (lat, lon) SI grid that was written, arranged north-up.
    """
    # Load the variable into memory and release the file handle right away, so the
    # NetCDF file is not left open (and locked on some platforms).
    with xr.open_dataset(nc_path) as ds:
        data = ds[var_name].load()

    # Time may be stored as 'YYYY-MM' strings; normalize it so 'time.year' grouping works.
    data = data.assign_coords(time=pd.to_datetime(data['time'].values, format='%Y-%m'))

    # Mean monthly value of each year, i.e. R / 12.
    yearly_means = data.groupby('time.year').mean(dim='time')

    # Absolute deviation of every month from its own year's mean monthly value.
    monthly_abs_diff = abs(data.groupby('time.year') - yearly_means)

    # Sum of the monthly deviations per year; a single missing month makes the year NaN.
    yearly_sum = monthly_abs_diff.groupby('time.year').sum(dim='time', skipna=False)

    # SI per year = summed deviations / annual total (R = 12 * mean monthly value).
    yearly_SI = yearly_sum / (yearly_means * 12)

    # Average the Seasonality Index across years (NaN years are skipped).
    result = yearly_SI.mean(dim='year')

    # Arrange the grid north-up (rows = descending lat, columns = ascending lon).
    # The transform below assumes a top-left origin and positive pixel sizes;
    # without this, ascending-latitude inputs would be written upside down.
    result = result.transpose('lat', 'lon').sortby('lon').sortby('lat', ascending=False)

    # Pixel resolution from the (now sorted) cell-centre coordinates.
    lon = result['lon'].values
    lat = result['lat'].values
    xsize = float(lon[1] - lon[0])
    ysize = float(lat[0] - lat[1])

    # Coordinates are cell centres, so shift the origin by half a pixel to the outer edge.
    transform = from_origin(
        west=float(lon.min()) - xsize / 2,
        north=float(lat.max()) + ysize / 2,
        xsize=xsize,
        ysize=ysize,
    )

    # Create the output directory if needed. A bare filename has an empty dirname,
    # which os.makedirs would reject.
    out_dir = os.path.dirname(output_tif_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save the result as a single-band float32 GeoTIFF.
    with rasterio.open(
        output_tif_path,
        'w',
        driver='GTiff',
        height=result.shape[0],
        width=result.shape[1],
        count=1,
        dtype='float32',
        crs='EPSG:4326',
        transform=transform,
        nodata=np.nan,
    ) as dst:
        # Cast explicitly so the array matches the declared band dtype.
        dst.write(result.values.astype('float32'), 1)

    print(f"Seasonality Index saved to: {output_tif_path}")
    return result

In [None]:
### Normalized inter-annual variability for precipitation and potential evapotranspiration
import xarray as xr
import numpy as np
import pandas as pd
import rasterio
from rasterio.transform import from_origin
import os

def compute_and_save_interannual_variability(nc_path, output_tif_path, var_name="etp"):
    """
    Compute normalized inter-annual variability of monthly data
    from a NetCDF file and save the result as a GeoTIFF.

    Monthly values are summed into yearly totals. The coefficient of variation
    (population std, ddof=0, divided by the mean) of those totals across years is
    written per pixel. Pixels with a missing value at any time step are set to NaN.

    Parameters:
    nc_path (str): Path to the NetCDF file containing monthly data with time, lat,
        and lon dimensions. Time values are parsed by pandas with the '%Y-%m' format.
    output_tif_path (str): Path to save the output GeoTIFF file (EPSG:4326, NaN
        nodata). Missing parent directories are created.
    var_name (str): Data variable to process. Defaults to "etp" (potential
        evapotranspiration). Pass the precipitation variable name for precipitation.

    Returns:
    xarray.DataArray: The (lat, lon) variability grid that was written, arranged north-up.
    """
    # Load the variable into memory and close the NetCDF file right away.
    with xr.open_dataset(nc_path) as ds:
        data = ds[var_name].load()

    # Ensure the time dimension is in datetime format (supports 'YYYY-MM' strings).
    data = data.assign_coords(time=pd.to_datetime(data['time'].values, format='%Y-%m'))

    # Grid cells with a NaN at any time step are excluded entirely.
    mask = data.isnull().any(dim="time")

    # Yearly totals; NaNs are skipped here, and the affected cells are re-masked below.
    yearly_total = data.groupby("time.year").sum(dim="time", skipna=True)
    yearly_total = yearly_total.where(~mask, np.nan)

    # Coefficient of variation of the yearly totals across years.
    std_dev = yearly_total.std(dim='year')
    mean_val = yearly_total.mean(dim='year')
    interannual_variability = std_dev / mean_val

    # Arrange the grid north-up (rows = descending lat, columns = ascending lon).
    # This matches the top-left-origin transform below; ascending-latitude inputs
    # would otherwise be written upside down.
    interannual_variability = (
        interannual_variability.transpose("lat", "lon")
        .sortby("lon")
        .sortby("lat", ascending=False)
    )

    # Spatial resolution and bounds from the sorted cell-centre coordinates.
    lon = interannual_variability['lon'].values
    lat = interannual_variability['lat'].values
    xsize = float(lon[1] - lon[0])
    ysize = float(lat[0] - lat[1])

    # Shift from cell centres to the outer edge by half a pixel.
    transform = from_origin(
        west=float(lon.min()) - xsize / 2,
        north=float(lat.max()) + ysize / 2,
        xsize=xsize,
        ysize=ysize,
    )

    # Create the output directory if needed; a bare filename has an empty dirname.
    out_dir = os.path.dirname(output_tif_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save the result to a GeoTIFF file.
    with rasterio.open(
        output_tif_path,
        'w',
        driver='GTiff',
        height=interannual_variability.shape[0],
        width=interannual_variability.shape[1],
        count=1,
        dtype=interannual_variability.dtype.name,
        crs="EPSG:4326",  # Assuming WGS84 projection
        transform=transform,
        nodata=np.nan,
    ) as dst:
        dst.write(interannual_variability.values, 1)

    print(f"Interannual variability saved to: {output_tif_path}")
    return interannual_variability

In [None]:
### Normalized intra- and inter-annual variability for AI and soil moisture
import xarray as xr
import numpy as np
import pandas as pd
import rasterio
from rasterio.transform import from_origin
from sklearn.linear_model import LinearRegression
import os

def compute_intra_annual_variability(data_array):
    """
    Mean intra-annual coefficient of variation.

    For each calendar year, the within-year standard deviation (xarray default
    ddof=0) is divided by the within-year mean; both skip NaNs. The per-year
    ratios are then averaged over all years, again skipping NaNs.

    NOTE(review): the temperature cell below defines another function with this
    same name. It has a (data_array, mask) signature and returns a plain standard
    deviation. Once that cell has run, it replaces this definition, so AI / soil
    moisture calls must happen before it (or one of the two should be renamed).

    Parameters:
        data_array (xarray.DataArray): A 3D array with dimensions (time, lat, lon)

    Returns:
        xarray.DataArray: A 2D array of intra-annual variability (lat, lon)
    """
    per_year = data_array.groupby("time.year")
    yearly_cv = per_year.std(dim="time", skipna=True) / per_year.mean(dim="time", skipna=True)
    return yearly_cv.mean(dim="year", skipna=True)


def compute_inter_annual_variability(data_array):
    """
    Compute inter-annual variability: detrend the annual means and compute
    the standard deviation of residuals divided by the multi-year mean.

    Each pixel's annual-mean series is detrended with an ordinary least-squares
    line fitted against the calendar year. The population standard deviation
    (ddof=0) of the residuals is divided by the mean of the annual means. Pixels
    with fewer than two valid years are NaN.

    Parameters:
        data_array (xarray.DataArray): A 3D array with dimensions (time, lat, lon)

    Returns:
        xarray.DataArray: A 2D array of inter-annual variability (lat, lon)
    """

    def _ols_residuals(values, years):
        # Vectorized per-pixel OLS residuals of values (year, ...) against years (year,).
        # Replaces a Python loop of one regression fit per pixel.
        valid = ~np.isnan(values)
        n_valid = valid.sum(axis=0)
        # Broadcast the year regressor over every pixel.
        x = np.broadcast_to(
            np.asarray(years, dtype=float).reshape((-1,) + (1,) * (values.ndim - 1)),
            values.shape,
        )
        # 0/0 for pixels with fewer than two valid years is expected and masked below.
        with np.errstate(invalid="ignore", divide="ignore"):
            x_mean = np.where(valid, x, 0.0).sum(axis=0) / n_valid
            y_mean = np.where(valid, values, 0.0).sum(axis=0) / n_valid
            dx = np.where(valid, x - x_mean, 0.0)
            dy = np.where(valid, values - y_mean, 0.0)
            slope = (dx * dy).sum(axis=0) / (dx * dx).sum(axis=0)
            trend = y_mean + slope * (x - x_mean)
        # Residuals stay NaN where the input year is NaN. A line needs at least two points.
        return np.where(n_valid > 1, values - trend, np.nan)

    # Enforce (year, lat, lon) order explicitly instead of relying on the groupby output layout.
    annual_mean = (
        data_array.groupby("time.year")
        .mean(dim="time", skipna=True)
        .transpose("year", "lat", "lon")
    )
    values = annual_mean.values.astype(float)

    # Regress on the actual year values so gaps in the record are spaced correctly.
    detrended = _ols_residuals(values, annual_mean["year"].values)

    residual_std = np.nanstd(detrended, axis=0)
    mean_overall = np.nanmean(values, axis=0)
    result = (residual_std / mean_overall).astype(annual_mean.dtype, copy=False)
    return xr.DataArray(
        result, coords=[annual_mean["lat"], annual_mean["lon"]], dims=["lat", "lon"]
    )

In [None]:
### Intra- and inter-annual variability for temperature
def compute_intra_annual_variability(data_array, mask):
    """
    Intra-annual variability as the multi-year mean of within-year standard deviations.

    Each calendar year's standard deviation is computed while skipping NaNs. Those
    values are averaged over years (skipping NaNs), and pixels flagged by `mask`
    are set to NaN.

    NOTE(review): this reuses the name of the AI / soil-moisture function defined in
    the previous cell. That function takes only `data_array` and returns a
    coefficient of variation. Running this cell replaces it, so calls that expect
    the coefficient-of-variation version will fail with a missing-argument error.

    Parameters:
        data_array (xarray.DataArray): Temperature data with time dimension.
        mask (xarray.DataArray): Boolean mask indicating pixels with missing values across the full time period.

    Returns:
        xarray.DataArray: Mean intra-annual standard deviation for each pixel.
    """
    std_by_year = data_array.groupby("time.year").std(dim="time", skipna=True)
    averaged = std_by_year.mean(dim="year", skipna=True)
    # Blank out the pixels flagged by the caller's mask.
    return averaged.where(~mask, np.nan)


def compute_interannual_variability(data_array, mask):
    """
    Compute interannual variability (between-year variability after detrending).

    Each pixel's annual-mean series is detrended with an ordinary least-squares
    line fitted against the calendar year. The population standard deviation
    (ddof=0) of the residuals is returned. Pixels with fewer than two valid years,
    or flagged by `mask`, are NaN.

    Parameters:
        data_array (xarray.DataArray): Temperature data with time dimension.
        mask (xarray.DataArray): Boolean mask indicating pixels with missing values across the full time period.

    Returns:
        xarray.DataArray: Standard deviation of detrended annual mean temperature for each pixel.
    """

    def detrend(annual_mean_data, years):
        # Vectorized per-pixel OLS residuals of annual_mean_data (year, lat, lon) against
        # years (year,). Replaces a Python loop of one regression fit per pixel.
        values = np.asarray(annual_mean_data, dtype=float)
        valid = ~np.isnan(values)
        n_valid = valid.sum(axis=0)
        # Broadcast the year regressor over every pixel.
        x = np.broadcast_to(
            np.asarray(years, dtype=float).reshape((-1,) + (1,) * (values.ndim - 1)),
            values.shape,
        )
        # 0/0 for pixels with fewer than two valid years is expected and masked below.
        with np.errstate(invalid="ignore", divide="ignore"):
            x_mean = np.where(valid, x, 0.0).sum(axis=0) / n_valid
            y_mean = np.where(valid, values, 0.0).sum(axis=0) / n_valid
            dx = np.where(valid, x - x_mean, 0.0)
            dy = np.where(valid, values - y_mean, 0.0)
            slope = (dx * dy).sum(axis=0) / (dx * dx).sum(axis=0)
            trend = y_mean + slope * (x - x_mean)
        # Residuals stay NaN where the input year is NaN. A line needs at least two points.
        return np.where(n_valid > 1, values - trend, np.nan)

    # Annual mean temperature, with (year, lat, lon) order enforced explicitly.
    annual_mean = (
        data_array.groupby("time.year")
        .mean(dim="time", skipna=True)
        .transpose("year", "lat", "lon")
    )

    # Regress on the actual year values so gaps in the record are spaced correctly.
    detrended = detrend(annual_mean.values, annual_mean["year"].values)

    # Standard deviation of the residuals, kept in the input's floating dtype.
    interannual_std = np.nanstd(detrended, axis=0).astype(annual_mean.dtype, copy=False)

    interannual_std_da = xr.DataArray(
        interannual_std,
        coords=[annual_mean["lat"], annual_mean["lon"]],
        dims=["lat", "lon"],
    )

    # Apply mask to remove pixels with missing data.
    return interannual_std_da.where(~mask, np.nan)