### Calculation of the ratio S/N

In [None]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
#In[2]:
# define function
import src.SAT_function as data_process
import src.Data_Preprocess as preprosess

In [None]:
import cmocean

In [None]:
# import src.slurm_cluster as scluster
# client, scluster = scluster.init_dask_slurm_cluster()

In [None]:
# print(client)

### Input both the forced and ICV_std trend data 

In [None]:
dir_forced = './Figure4_Regional_separation/reversed_trend_cal/data/'
# Each forced-trend file covers a segment "<start year>-2022". Start years and
# segment lengths are aligned element-wise: 2013 <-> 10 yr, ..., 1950 <-> 73 yr.
# The start years get their own name so the loop body below can freely bind
# `variable_name` without clobbering the array being iterated.
start_years = np.arange(2013, 1949, -1).astype(str)
segment_lengths = np.arange(10, 74, 1).astype(str)
if len(start_years) != len(segment_lengths):
    # zip() would otherwise silently drop the unmatched tail.
    raise ValueError('start_years and segment_lengths must have the same length')

# input into the dataset: one variable per segment length
HadCRUT5_forced_ds = xr.Dataset()
for var, segment_length in zip(start_years, segment_lengths):
    file = dir_forced + 'forced_HadCRUT5_annual_' + var + '-2022_trend.nc'
    ds = xr.open_mfdataset(file).rename({'__xarray_dataarray_variable__': 'trend_forced'})
    variable_name = f'forced_{segment_length}_trend'
    # * 10.0 presumably converts per-year trends to per-decade trends — TODO confirm units
    HadCRUT5_forced_ds[variable_name] = ds['trend_forced'] * 10.0

print(HadCRUT5_forced_ds)

In [None]:
# Display the forced-trend dataset (notebook cell output).
HadCRUT5_forced_ds

In [None]:
# Load, for every segment length, the standard deviation of the internal-variability
# (ICV) trends estimated from the SAT-OBS residuals.
dir_std = './Figure3/data/ICV_STD_whole/'

HadCRUT5_ICV_std_ds = xr.Dataset()
for segment_length in segment_lengths:
    file = f'{dir_std}GSAT_HadCRUT5_Internal_Variability_trend_{segment_length}yr_segments_1850_2022_std.nc'
    ds = xr.open_mfdataset(file)
    # Name of the std variable inside the file vs. the name used in this notebook.
    source_name = f'ICV_segments_{segment_length}yr_trend_std'
    variable_name = f'ICV_{segment_length}_std'
    HadCRUT5_ICV_std_ds[variable_name] = ds[source_name]

print(HadCRUT5_ICV_std_ds)

In [None]:
# Display the ICV trend standard-deviation dataset (notebook cell output).
HadCRUT5_ICV_std_ds

In [None]:
# Signal-to-noise ratio of a trend pattern: the forced trend divided by the
# standard deviation of internally generated trends of the same length.
def SNR_trend_pattern(data, std_trend_pattern):
    """
    Return the element-wise ratio ``data / std_trend_pattern``.

    data: trend pattern (the signal); in this notebook an xarray DataArray
          over [lat, lon].
    std_trend_pattern: standard deviation of the trend pattern (the noise),
          same or broadcast-compatible shape as ``data``.
    """
    signal, noise = data, std_trend_pattern
    ratio = signal / noise
    return ratio

In [None]:
# For every segment length, divide the forced trend map by the standard deviation
# of the internal-variability trends of the same length, giving one S/N map per
# length (variables named 'SNR_<length>_trend').
SNR_trend_pattern_ds = xr.Dataset()
for segment_length in segment_lengths:
    variable_name = f'SNR_{segment_length}_trend'
    SNR_trend_pattern_ds[variable_name] = SNR_trend_pattern(HadCRUT5_forced_ds[f'forced_{segment_length}_trend'], HadCRUT5_ICV_std_ds[f'ICV_{segment_length}_std'])

# Show the last map computed (the longest segment length).
print(SNR_trend_pattern_ds[variable_name])

In [None]:
# Summarise the S/N variables. `Dataset.values` is the Mapping method, so a bare
# `.values` attribute access would only display a bound method, not the data.
SNR_trend_pattern_ds.data_vars

In [None]:
# Sanity check: for every S/N variable print its name, then its max, then its min,
# each on its own line.
for var, snr_da in SNR_trend_pattern_ds.data_vars.items():
    print(var, snr_da.max().values, snr_da.min().values, sep='\n')

In [None]:
# # save the output 
# dir_out = './Figure3/data/Ratio/'

# for segment_length in segment_lengths:
#     file_out = dir_out + 'SNR_trend_pattern_HadCRUT5_'+segment_length+'yr_segments_ends_2022.nc'
#     SNR_trend_pattern_ds[f'SNR_{segment_length}_trend'].to_netcdf(file_out)

### Plot the S/N ratio as a function of trend length

In [None]:
# Global matplotlib styling for the figures below (applied in the listed order).
plt.rcParams.update({
    'figure.figsize': (15, 10),
    'font.size': 16,
    # 'font.family': 'sans-serif',
    'axes.labelsize': 16,
    'ytick.direction': 'out',
    'ytick.minor.visible': True,
    'ytick.major.right': True,
    'ytick.right': True,
    'xtick.bottom': True,
    'savefig.transparent': True,
    'pdf.fonttype': 42,  # embed TrueType fonts so PDF text stays editable
})
# plt.rcParams['legend.frameon']      = False
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as mticker
import cartopy.feature as cfeature
import cartopy.mpl.ticker as cticker
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.gridspec as gridspec
import matplotlib as mpl
import seaborn as sns
from matplotlib.colors import ListedColormap
from matplotlib.colors import BoundaryNorm, ListedColormap

### Steps:
1. Stack the SNR values for each trend length: This will give each grid point a series of SNR values corresponding to different trend lengths.
2. Find the emergence trend length: For each grid point, check the SNR values across all trend lengths and identify the first trend length at which SNR exceeds 1.0 and remains above 1.0 for every longer trend length.
3. Store the corresponding trend length: Create a 2D array (lat × lon) where each grid point holds the first trend length that meets the SNR condition.
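4. Plot the results: Use pcolormesh shading to visualize the trend length at which SNR first persistently exceeds 1.0, and hatch grid points where |S/N| falls back below 1.0 after first exceeding it.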

In [None]:
# Trend lengths (years) covered by the S/N maps: 10, 11, ..., 73.
trend_lengths = np.arange(10, 74)

# Stack the per-length S/N maps along a new `trend_length` dimension. Passing a
# named pandas Index as `dim` creates the dimension and its coordinate in one step.
stacked_snr = xr.concat(
    [SNR_trend_pattern_ds[f'SNR_{length}_trend'] for length in trend_lengths],
    dim=pd.Index(trend_lengths, name='trend_length'),
)

In [None]:
# Inspect the type of the trend_length coordinate (notebook display).
type(stacked_snr.trend_length)

In [None]:
def first_valid_trend_length(snr_values, threshold=1.0, use_abs=False):
    """
    Find the first trend length from which SNR stays above a threshold for good.

    Returns the smallest index ``i`` such that ``snr_values[j] > threshold`` for
    every ``j >= i``, i.e. the signal emerges at ``i`` and never falls back.

    Args:
    - snr_values (array-like): 1-D SNR values for one grid point, ordered by
      increasing trend length.
    - threshold (float): SNR level that must be strictly exceeded (default 1.0).
    - use_abs (bool): if True, test ``|SNR| > threshold`` so that emerging
      negative signals count too. The default False keeps the signed test.
      Note that ``is_monotonic_after_first`` in this notebook tests ``|S/N|``.

    Returns:
    - first_valid_idx (int or float): index of the emergence trend length, or
      NaN if the last value does not exceed the threshold (this includes empty
      input). NaN SNR values never count as exceeding the threshold.
    """
    values = np.asarray(snr_values)
    if use_abs:
        values = np.abs(values)
    condition = values > threshold

    # A persistent run must extend to the last element, so the answer is one
    # past the last failing position. This is O(n), instead of re-checking
    # every suffix.
    failing = np.flatnonzero(~condition)
    first_idx = int(failing[-1]) + 1 if failing.size else 0

    # first_idx == len means even the last value fails: no emergence.
    return first_idx if first_idx < condition.size else np.nan

In [None]:
# Quick sanity check of first_valid_trend_length on a toy series: SNR first
# exceeds 1.0 at index 2 and stays above it, so index 2 is expected.
snr_values = np.array([0.5, 0.8, 1.2, 1.5, 1.6])  # Example data
result = first_valid_trend_length(snr_values)
print("First valid trend length index:", result)

In [None]:
# Display the trend lengths (years) currently bound to `trend_lengths`.
trend_lengths

In [None]:
# Apply first_valid_trend_length at every grid point. For each point this gives
# the index (into trend_length) from which SNR > 1.0 holds for all longer trend
# lengths, or NaN if there is no such index.
first_trend_idx = xr.apply_ufunc(
    first_valid_trend_length, 
    stacked_snr.chunk(dict(trend_length=-1)),  # One chunk along trend_length: the core dimension must not be split across dask chunks
    input_core_dims=[['trend_length']],  # trend_length is consumed; the function sees one 1-D series per grid point
    vectorize=True,                      # Loop the scalar-returning function over all grid points
    dask='parallelized',                 # Enable parallel computation with Dask
    output_dtypes=[float],               # Output will be float (since it may contain NaN)
)

In [None]:
# Map each grid point's emergence index to its trend length in years, keeping
# NaN where SNR never persistently exceeds 1.0.
# Relies on `trend_lengths` still being np.arange(10, 74) from the stacking cell.
# Compute the dask graph once, and build the result with copy(data=...) so it
# keeps first_trend_idx's own dims/coords instead of assuming ('lat', 'lon').
first_trend_idx_computed = first_trend_idx.compute()
valid_mask = first_trend_idx_computed.notnull().values

# Step 1: Convert indices to integers; NaNs become -1 temporarily
first_trend_idx_int = first_trend_idx_computed.fillna(-1).astype(int)

# Step 2: Look up trend lengths for valid indices and use NaN for invalid ones.
# Index -1 is harmless here because np.where discards those entries.
first_trend_length_array = first_trend_idx_computed.copy(
    data=np.where(valid_mask, trend_lengths[first_trend_idx_int.values], np.nan)
)

In [None]:
# List the distinct emergence trend lengths (years) present in the map.
print(np.unique(first_trend_length_array.values))

In [None]:
# Confirm the emergence map is an xarray DataArray (notebook display).
type(first_trend_length_array)

In [None]:
# Display the emergence trend-length map (notebook cell output).
first_trend_length_array

In [None]:
# Shape of the emergence map (one value per grid point).
print(first_trend_length_array.shape)

In [None]:
# Per-grid-point persistence test for |S/N|; applied to every grid point with
# xr.apply_ufunc in the next cell.
def is_monotonic_after_first(snr_values):
    """
    Test whether |S/N| stays above 1.0 once it has first exceeded 1.0.

    Args:
        snr_values (np.ndarray): 1-D S/N values for one grid point, ordered by
            increasing trend length.

    Returns:
        bool: False when |S/N| never exceeds 1.0. Otherwise, whether every
        value from the first exceedance onward also has |S/N| > 1.0. NaN
        values count as not exceeding 1.0.
    """
    above = np.abs(snr_values) > 1.0
    if not np.any(above):
        # Never exceeds the threshold: nothing emerges.
        return False
    # argmax of a boolean array is the position of its first True.
    onset = np.argmax(above)
    return np.all(above[onset:])

In [None]:
# Boolean map: True where |S/N| stays above 1.0 once it first exceeds 1.0.
# False where it later drops back below 1.0, or never exceeds it.
monotonic_map = xr.apply_ufunc(
    is_monotonic_after_first,
    stacked_snr.chunk(dict(trend_length=-1)),  # One chunk along trend_length: the core dimension must not be split across dask chunks
    input_core_dims=[['trend_length']],  # trend_length is consumed; the function sees one 1-D series per grid point
    vectorize=True,                      # Loop the scalar-returning function over all grid points
    dask='parallelized',                 # Enable parallel computation with Dask
    output_dtypes=[bool]  # Output is boolean (True/False)
)

In [None]:
# Display the monotonicity map (notebook cell output).
monotonic_map

In [None]:
# Save the monotonicity test as a netCDF variable named "monotonicity".
import os

dir_out = './Revised_main_figures/Figure4_Emergence_timescale/data/'
# Create the output directory if needed, so to_netcdf does not fail on a fresh checkout.
os.makedirs(dir_out, exist_ok=True)
monotonic_map.to_dataset(name="monotonicity").to_netcdf(dir_out+'obs_emergence_monotonicity.nc')

In [None]:
def plot_trend(lons, lats, data, levels=None, extend=None, cmap=None, norm=None,
               title="", ax=None, show_xticks=False, show_yticks=False):
    """
    Draw a gridded field on a map with pcolormesh, coastlines and labelled gridlines.

    Parameters:
    - lons, lats: 1D arrays of longitudes and latitudes; the data are
      interpreted in PlateCarree (plain lon/lat) coordinates.
    - data: 2D array of values on the (lats, lons) grid.
    - levels: accepted for interface compatibility but unused, because
      pcolormesh takes no levels. Use ``norm`` for discrete bins.
    - extend: accepted for interface compatibility but unused. Pass ``extend``
      to the colorbar instead.
    - cmap, norm: colormap and normalisation forwarded to pcolormesh.
    - title: Title for the plot.
    - ax: Existing cartopy axis to draw on. If None, a new figure with a
      global Robinson-projection axis is created.
    - show_xticks, show_yticks: whether to label longitudes (bottom edge) and
      latitudes (left edge).

    Returns:
    - The mappable returned by ``ax.pcolormesh``. Use it for the colorbar.
    """
    # Create a new figure/axis if none is provided
    if ax is None:
        _, ax = plt.subplots(figsize=(20, 15), subplot_kw={'projection': ccrs.Robinson()})
        ax.set_global()

    pcolormesh_obj = ax.pcolormesh(lons, lats, data, norm=norm, cmap=cmap, shading='auto',
                                   transform=ccrs.PlateCarree())

    ax.coastlines(resolution='110m')
    gl = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False, linewidth=1,
                      linestyle='--', color='gray', alpha=0.35)

    # Never label the top/right edges; label bottom/left only when requested.
    gl.top_labels = False
    gl.right_labels = False
    gl.bottom_labels = bool(show_xticks)
    gl.left_labels = bool(show_yticks)
    gl.xformatter = cticker.LongitudeFormatter()
    gl.yformatter = cticker.LatitudeFormatter()
    gl.xlabel_style = {'size': 16}
    gl.ylabel_style = {'size': 16}

    ax.set_title(title, loc='center', fontsize=24, pad=5.0)

    return pcolormesh_obj

In [None]:
# Check for NaN values: grid points where no trend length gives a persistent
# SNR > 1.0 (i.e. the signal never emerges).
nan_indices = np.isnan(first_trend_length_array)

# Print the (row, column) index arrays of those grid points
print("NaN indices:", np.where(nan_indices))


In [None]:
# Report the grid points where |S/N| stays above 1.0 after first exceeding it.
print("Monotonic map shape:", monotonic_map.shape)
# monotonic_map is boolean, so nonzero() returns the True positions directly;
# this is the same result as np.where(values == True), without the comparison.
true_indices = np.nonzero(monotonic_map.values)
print("True indices:", true_indices)

In [None]:
# # WH box (lat 42 to 60, lon 310 to 350)
# wh_box = first_trend_length_array.sel(lat=slice(42, 60), lon=slice(310, 350))
# print("WH Box First Trend Length Values:\n", wh_box.values)

In [None]:
# so_box = first_trend_length_array.sel(lat=slice(-90, -42), lon=slice(180, 200))
# print("SO Box First Trend Length Values:\n", so_box.values)

In [None]:
import copy
import matplotlib as mpl
import matplotlib.colors as mcolors
import palettable
import cartopy.util as cutil
import numpy.ma as ma

# cmdict = cmocean.cm.matter
# norm = mcolors.Normalize(vmin=10, vmax=74)

In [None]:
# Show the values currently bound to `trend_lengths`.
print(trend_lengths)

In [None]:
# Map the emergence time scale. For each grid point this is the first trend length
# from which SNR > 1.0 holds for every longer trend length (first_trend_length_array).
fig, ax = plt.subplots(figsize=(15, 10), subplot_kw={'projection': ccrs.Robinson(180)})
# Colour-bin edges: [10, 15, ..., 70, 75] (14 edges -> 13 bins).
# NOTE(review): this rebinds the notebook-level `trend_lengths`, which earlier held
# np.arange(10, 74) and was used to map emergence indices to years. Re-running that
# mapping cell after this one would silently use these bin edges instead.
trend_lengths = np.append(np.arange(10, 75, 5), 75)
cmap_base = plt.get_cmap('OrRd_r')
# One colour per bin (one less than the number of edges).
# NOTE(review): `colors` rebinds the name imported earlier as `matplotlib.colors`.
colors = cmap_base(np.linspace(0, 1, len(trend_lengths) - 1))
# Append white as an extra colour. BoundaryNorm below is told ncolors=len(edges)-1,
# so the white entry is presumably meant for values beyond 75 yr and the
# colorbar's 'max' extension — confirm against the rendered colorbar.
colors = np.vstack([colors, [1, 1, 1, 1]])
custom_cmap = ListedColormap(colors)
# Discrete normalisation onto the bins defined by the edges above.
norm = BoundaryNorm(trend_lengths, ncolors=len(trend_lengths) - 1)

# Mask invalid data (NaN or Inf values). NaN marks grid points that never emerge;
# masked cells take the colormap's 'bad' colour rather than a bin colour.
masked_array = ma.masked_invalid(first_trend_length_array)

# Append a cyclic longitude column (data and coordinate) so there is no gap
# between the last and first longitude when drawing.
Ratio_with_cyclic, lon_with_cyclic = cutil.add_cyclic_point(masked_array, coord=first_trend_length_array.lon)

# Check the shapes to ensure consistency
print("Shape of Ratio_with_cyclic:", Ratio_with_cyclic.shape)
print("Shape of lon_with_cyclic:", lon_with_cyclic.shape)
print("Shape of SNR_trend_pattern_ds.lat:", SNR_trend_pattern_ds.lat.shape)

# Draw the emergence map. plot_trend draws with pcolormesh and ignores
# `levels`/`extend`; the returned mesh is the colorbar mappable used below.
pcolormesh_plot = plot_trend(lon_with_cyclic, SNR_trend_pattern_ds.lat, Ratio_with_cyclic,
                                    levels=trend_lengths, extend='max', cmap=custom_cmap, norm=norm,
                                        title='Emergence time scale (years)', ax=ax, show_xticks=True, show_yticks=True)
# monotonic_map (computed earlier) is boolean: True where |S/N| stays above 1.0
# after first exceeding it, False otherwise.

# Add cyclic point to monotonic_map for plotting
monotonic_with_cyclic, lon_with_cyclic = cutil.add_cyclic_point(monotonic_map.values, coord=monotonic_map.lon)

# Overlay the monotonicity mask as hatching. `~` turns non-monotonic points into
# True (1), which falls inside levels [0.5, 1.5] and is hatched; colors='none'
# keeps the map underneath visible.
# NOTE(review): confirm the hatches actually render with alpha=0 on the
# matplotlib version in use.
ax.contourf(
    lon_with_cyclic, SNR_trend_pattern_ds.lat, ~monotonic_with_cyclic,
    levels=[0.5, 1.5], hatches=['///'], colors='none', transform=ccrs.PlateCarree(), alpha=0
)

# Add the regional outlines and calculate midpoints for labels:
# Arctic box (all longitudes, 66.5N-90N)
arctic_lon_mid = (0 + 360) / 2
arctic_lat_mid = (66.5 + 90) / 2
ax.plot([0, 360, 360, 0, 0], [66.5, 66.5, 90, 90, 66.5],
        color='lightgrey', linewidth=2.0, transform=ccrs.PlateCarree())
ax.text(arctic_lon_mid, arctic_lat_mid, 'ARC', color='black', fontsize=18, transform=ccrs.PlateCarree(),
        ha='center', va='center')  # Label for Arctic

# North Atlantic warming-hole box (310-350E, 42N-60N)
wh_lon_mid = (310 + 350) / 2
wh_lat_mid = (42 + 60) / 2
box_lons = np.array([310, 350, 350, 310, 310])
box_lats = np.array([42, 42, 60, 60, 42])
ax.plot(box_lons, box_lats, color='lightgrey', linewidth=2.0, transform=ccrs.PlateCarree())
ax.text(wh_lon_mid, wh_lat_mid, 'NAWH', color='black', fontsize=18, transform=ccrs.PlateCarree(),
        ha='center', va='center')  # Label for WH

# Southeast Pacific region: the outline is a trapezoid with vertices (200,0),
# (280,0), (280,-25), (250,-25); `% 360` is a no-op for these longitudes.
# NOTE(review): the label longitude uses (200 + 320) / 2 = 260, and 320 is not a
# vertex of the outline — confirm this placement is intended.
sep_lon_mid = (200 + 320) / 2
sep_lat_mid = (0 + -25) / 2
ax.plot([200%360, 280%360, 280%360, 250%360, 200%360], [0, 0, -25, -25, 0],
        color='lightgrey', linewidth=2.0, transform=ccrs.PlateCarree())
ax.text(sep_lon_mid, sep_lat_mid, 'SEP', color='black', fontsize=18, transform=ccrs.PlateCarree(),
        ha='center', va='center')  # Label for SEP

# Extratropical South Pacific box (220-280E, 60S-40S)
sop_lon_mid = (220 + 280) / 2
sop_lat_mid = (-40 + -60) / 2
ax.plot([220%360, 280%360, 280%360, 220%360, 220%360], [-40, -40, -60, -60, -40],
        color='lightgrey', linewidth=2.0, transform=ccrs.PlateCarree())
ax.text(sop_lon_mid, sop_lat_mid, 'SOP', color='black', fontsize=18, transform=ccrs.PlateCarree(),
        ha='center', va='center')  # Label for SOP
# Add a horizontal colorbar in a dedicated axis below the map

cbar_ax = fig.add_axes([0.275, 0.12, 0.5, 0.04])
cbar = plt.colorbar(pcolormesh_plot, cax=cbar_ax, orientation='horizontal', extend='max')
# Customize the colorbar
cbar.ax.tick_params(labelsize=14)
cbar.set_label('Emergence time scale (years)', fontsize=16)

# Save raster and vector versions of the figure, then display it
fig.savefig('Emergence_Trend_Length.png', dpi=300, bbox_inches='tight')
fig.savefig('Emergence_Trend_Length.pdf', dpi=300, bbox_inches='tight')
plt.show()

### 