### Calculation of the signal-to-noise (S/N) ratio

In [None]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
#In[2]:
# define function
import src.SAT_function as data_process
import src.Data_Preprocess as preprosess

### Load the forced trend and the ICV standard-deviation (ICV_std) trend data

In [None]:
# Location of the pre-computed IPSL trend files
dir_in = '/work/mh0033/m301036/Land_surf_temp/Disentangling_OBS_SAT_trend/LE_evaluation/Fig3_IPSL/output/'
forced_trend_file = dir_in + 'IPSL_forced_segmented_trend.nc'

# Keep only the `tas` variable, scaled by 10
# NOTE(review): the factor presumably converts per-year trends to per-decade — confirm.
IPSL_forced_ds = xr.open_mfdataset(forced_trend_file).tas * 10.0
print(IPSL_forced_ds)

In [None]:
# # Assuming `ds` is your dataset
# segment_lengths = range(10, 74, 1)  # Generate segment lengths from 10 to 73
# new_period_names = [f"forced_{length}yr_trend" for length in segment_lengths]  # Create new names

# # Replace the period coordinate
# if len(new_period_names) == len(IPSL_forced_ds['tas'].period):
#     IPSL_forced_ds['tas'] = IPSL_forced_ds['tas'].assign_coords(period=new_period_names)
#     print("Updated period dimension successfully.")
# else:
#     print(f"Error: Mismatch in length. New names: {len(new_period_names)}, Current period: {len(IPSL_forced_ds['tas'].period)}")

# # Inspect the updated dataset
# print(IPSL_forced_ds)


In [None]:
# Notebook display of the scaled forced-trend `tas` DataArray
IPSL_forced_ds

In [None]:
# String labels used to build keys and variable names:
#   variable_indices -> '1' ... '50' (presumably ensemble-member numbers; verify)
#   segment_lengths  -> '10' ... '73' (segment lengths in years)
variable_indices = np.arange(1, 51).astype(str)
segment_lengths = np.arange(10, 74).astype(str)

In [None]:
# Load the ICV noise standard-deviation trend patterns (one variable per segment length)
dir_std = '/work/mh0033/m301036/Land_surf_temp/Disentangling_OBS_SAT_trend/LE_evaluation/Fig3_IPSL/output/'
icv_std_file = dir_std + 'IPSL_ICV_noise_std_trend_pattern_1850_2022.nc'

IPSL_ICV_std_ds = xr.open_mfdataset(icv_std_file)

In [None]:
# Notebook display of the ICV standard-deviation dataset
IPSL_ICV_std_ds

In [None]:
# # check the zero value in the std
# std_trend_pattern = IPSL_ICV_std_ds['std_trend_10'].squeeze()
# # Check where std_trend_pattern is zero
# zero_mask = std_trend_pattern == 0

# # Count the number of zeros
# num_zeros = zero_mask.sum().item()
# print(f"Number of zeros in std_trend_pattern: {num_zeros}")

# # Optionally print locations (coordinates) of zero values
# if num_zeros > 0:
#     zero_coords = std_trend_pattern.where(zero_mask, drop=True)
#     print(f"Coordinates of zero values: {zero_coords}")


In [None]:
# Signal-to-noise ratio: trend pattern of a segment divided by the
# standard deviation of the trend pattern for the same segment length.
def SNR_trend_pattern(data, std_trend_pattern):
    """
    Return the element-wise ratio ``data / std_trend_pattern``.

    data: 2D array with dimensions [lat, lon] (the signal)
    std_trend_pattern: 2D array with dimensions [lat, lon] (the noise)

    Entries of ``std_trend_pattern`` that are exactly zero are replaced by
    1e-10 before dividing, so the result stays finite there.
    """
    zero_spread = std_trend_pattern == 0
    denominator = np.where(zero_spread, 1e-10, std_trend_pattern)
    return data / denominator

In [None]:
# Inspect the forced trend for the "2013-2022" period label
print(IPSL_forced_ds.sel(period="2013-2022"))

In [None]:
# Compute the S/N pattern for every segment length: the forced trend of the
# segment ending in END_YEAR divided by the ICV standard deviation for that length.
END_YEAR = 2022
SNR_trend_pattern_ds = xr.Dataset()  # one variable per segment length

for segment_length in segment_lengths:
    # A segment of N years ending in END_YEAR starts at END_YEAR - N + 1
    # (10 yr -> 2013, ..., 73 yr -> 1950), which gives the "<start>-2022" period label.
    # Deriving the year here avoids a separate start-year array whose name the
    # loop variable could overwrite, so the cell can be re-run safely.
    start_year = END_YEAR - int(segment_length) + 1
    forced_key = f"{start_year}-{END_YEAR}"
    variable_name = f'SNR_{segment_length}_trend'

    # Debug: print the selected forced trend period
    print(f"Processing segment length {segment_length}, forced key: {forced_key}")
    forced_trend = IPSL_forced_ds.sel(period=forced_key)
    print(forced_trend)

    # Apply the S/N function over the (lat, lon) core dimensions
    result = xr.apply_ufunc(
        SNR_trend_pattern,
        forced_trend.chunk({"run": -1}),                                    # single chunk along run
        IPSL_ICV_std_ds[f'std_trend_{segment_length}'].chunk({"run": -1}),  # single chunk along run
        input_core_dims=[['lat', 'lon'], ['lat', 'lon']],
        output_core_dims=[['lat', 'lon']],
        vectorize=True,        # loop over the non-core dimensions
        dask='parallelized',
        output_dtypes=[float],
    )

    # Store the result under its per-length variable name
    SNR_trend_pattern_ds[variable_name] = result

# Final output
print(SNR_trend_pattern_ds)



In [None]:
# Notebook display of the dataset holding one SNR variable per segment length
SNR_trend_pattern_ds

In [None]:
# Write all per-segment S/N patterns to a single NetCDF file
dir_out = '/work/mh0033/m301036/Land_surf_temp/Disentangling_OBS_SAT_trend/LE_evaluation/Fig3_IPSL/output/'
snr_file = f"{dir_out}IPSL_SNR_segments_pattern_1850_2022.nc"
SNR_trend_pattern_ds.to_netcdf(snr_file)

In [None]:
# Sanity check: print each S/N variable's name followed by its maximum and minimum
for var, snr_field in SNR_trend_pattern_ds.data_vars.items():
    print(var)
    print(snr_field.max().values)
    print(snr_field.min().values)

In [None]:
# # save the output 
# dir_out = '/work/mh0033/m301036/Land_surf_temp/Disentangling_OBS_SAT_trend/Figure3/data/Ratio/'

# for segment_length in segment_lengths:
#     file_out = dir_out + 'SNR_trend_pattern_IPSL_'+segment_length+'yr_segments_ends_2022.nc'
#     SNR_trend_pattern_ds[f'SNR_{segment_length}_trend'].to_netcdf(file_out)

### Plot the signal-to-noise ratio (SNR) as a function of trend length

In [None]:
# Global matplotlib defaults applied to every figure created after this cell
plt.rcParams.update({
    'figure.figsize': (15, 10),
    'font.size': 16,
    # 'font.family': 'sans-serif',
    'axes.labelsize': 16,
    'ytick.direction': 'out',
    'ytick.minor.visible': True,
    'ytick.major.right': True,
    'ytick.right': True,
    'xtick.bottom': True,
    'savefig.transparent': True,
    'pdf.fonttype': 42,
    # 'legend.frameon': False,
})
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as mticker
import cartopy.feature as cfeature
import cartopy.mpl.ticker as cticker
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.gridspec as gridspec
import matplotlib as mpl
import seaborn as sns
from matplotlib.colors import ListedColormap
from matplotlib.colors import BoundaryNorm, ListedColormap

In [None]:
def plot_trend(trend_data, lats, lons, levels=None, extend=None, cmap=None, norm=None,
                                 title="", ax=None, show_xticks=False, show_yticks=False):
    """
    Draw a lat/lon field as filled contours on a map axis.

    NOTE: a later cell of this notebook defines another ``plot_trend`` whose
    positional order is (lons, lats, data); once that cell has run it replaces
    this definition.

    Parameters:
    - trend_data: 2D array of values to contour, matching (lats, lons).
    - lats, lons: 1D arrays of latitudes and longitudes.
    - levels, extend, cmap, norm: forwarded unchanged to ``ax.contourf``.
    - title: Title for the plot.
    - ax: Existing axis to plot on. If None, a new figure with a global
      Robinson-projection axis is created.
    - show_xticks, show_yticks: Whether to draw bottom / left gridline labels.

    Returns:
    - contour_obj: The object returned by ``ax.contourf``.
    """
    # Create a new figure/axis if none is provided
    if ax is None:
        _, ax = plt.subplots(figsize=(20, 15), subplot_kw={'projection': ccrs.Robinson()})
        ax.set_global()

    contour_obj = ax.contourf(lons, lats, trend_data, levels=levels, extend=extend, cmap=cmap,
                              norm=norm, transform=ccrs.PlateCarree(central_longitude=0))

    ax.coastlines(resolution='110m')
    gl = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False, linewidth=1, linestyle='--',
                      color='gray', alpha=0.35)

    # Never label the top and right edges; bottom/left follow the caller's flags
    gl.top_labels = False
    gl.right_labels = False
    gl.bottom_labels = bool(show_xticks)
    gl.left_labels = bool(show_yticks)
    gl.xformatter = cticker.LongitudeFormatter()
    gl.yformatter = cticker.LatitudeFormatter()
    gl.xlabel_style = {'size': 16}
    gl.ylabel_style = {'size': 16}

    ax.set_title(title, loc='center', fontsize=18, pad=5.0)

    return contour_obj

### Steps:
1. Stack the SNR values for each trend length: This will give each grid point a series of SNR values corresponding to different trend lengths.
2. Find the first trend length where SNR >= 1.0: For each grid point, check the SNR values across all trend lengths and identify the first trend length where SNR exceeds or equals 1.0.
3. Store the corresponding trend length: Create a 2D array (lat × lon) where each grid point holds the first trend length that meets the SNR condition.
4. Plot the results: Use contour shading to visualize the trend length where SNR first exceeds 1.0.

In [None]:
# Trend lengths (years) covered by the analysis
trend_lengths = np.arange(10, 74)

# Collect the per-length SNR fields and join them along a new 'trend_length'
# dimension whose coordinate holds the length in years
snr_per_length = [SNR_trend_pattern_ds[f'SNR_{length}_trend'] for length in trend_lengths]
stacked_snr = xr.concat(snr_per_length, dim='trend_length').assign_coords(trend_length=trend_lengths)

In [None]:
# Notebook display of the SNR fields stacked along 'trend_length'
stacked_snr

In [None]:
# Check the Python type of the 'trend_length' coordinate
type(stacked_snr.trend_length)

In [None]:
# # Define a function to find the first occurrence of SNR >= 1.0
# def first_valid_trend_length(snr_data):
#     # Use np.argmax to find the first index where SNR >= 1.0
#     condition = snr_data >= 1.0
#     first_valid_idx = np.argmax(condition, axis=0)
    
#     # If no valid index is found, set to NaN (ignore if all are less than 1.0)
#     no_valid = np.all(~condition, axis=0)
#     first_valid_idx[no_valid] = np.nan
    
#     return first_valid_idx

In [None]:
def first_valid_trend_length(snr_values):
    """
    Find the first trend length where |SNR| >= 1.0.

    The absolute value is used, so a strongly negative ratio counts as
    emerged just like a positive one. NaN entries never satisfy the condition.

    Args:
    - snr_values (np.ndarray): 1D array of SNR values for a specific grid point
      across all trend lengths (this is how ``xr.apply_ufunc(..., vectorize=True)``
      calls it).

    Returns:
    - first_valid_idx (int or float): The index of the first entry with |SNR| >= 1.0,
                                      or NaN if no entry reaches the threshold.
    """
    # Threshold is inclusive, matching the ">= 1.0" criterion stated for this analysis
    condition = np.abs(snr_values) >= 1.0

    # Nothing reaches the threshold: signal "no emergence" with NaN
    if not np.any(condition):
        return np.nan

    # argmax on a boolean array returns the position of the first True
    return np.argmax(condition, axis=0)

In [None]:
# Notebook display of the trend-length array
trend_lengths

In [None]:
# Apply the function to find, per grid point, the index of the first trend
# length that reaches the SNR threshold.
# `chunk` takes ONE mapping of dimension -> chunk size; a second positional
# argument is not treated as chunk sizes, so both dimensions go in one dict.
first_trend_idx = xr.apply_ufunc(
    first_valid_trend_length,
    stacked_snr.chunk({"trend_length": -1, "run": -1}),  # core dim must be a single chunk
    input_core_dims=[['trend_length']],  # reduce along trend_length at each grid point
    vectorize=True,                      # call the function once per 1D series
    dask='parallelized',                 # enable parallel computation with Dask
    output_dtypes=[float],               # float because the result may contain NaN
)

In [None]:
# Notebook display of the per-grid-point first-emergence indices
first_trend_idx

In [None]:
# Step 1: Convert first_trend_idx to integers; NaN (threshold never reached)
# becomes the sentinel -1 so the array can be used for indexing
first_trend_idx_int = first_trend_idx.fillna(-1).astype(int)
idx_values = first_trend_idx_int.values  # materialise once as a plain NumPy array

# Read the lengths from the coordinate the indices refer to, rather than from
# the global `trend_lengths`, so a later rebinding of that name cannot corrupt
# the mapping when this cell is re-run.
length_values = stacked_snr['trend_length'].values

# Step 2: Map indices to actual trend lengths (NaN where the sentinel is set).
# Dimensions and coordinates are taken from first_trend_idx instead of being hard-coded.
first_trend_length_array = xr.DataArray(
    np.where(idx_values >= 0, length_values[idx_values], np.nan),
    dims=first_trend_idx.dims,
    coords={dim: first_trend_idx[dim] for dim in first_trend_idx.dims},
)

In [None]:
# List the distinct emergence trend lengths present in the result
print(np.unique(first_trend_length_array.values))

In [None]:
# Check the Python type of the emergence array
type(first_trend_length_array)

In [None]:
# Print the shape of the emergence array
print(first_trend_length_array.shape)

In [None]:
# Notebook display of the emergence array
first_trend_length_array

In [None]:
# Write the per-run emergence time scale to NetCDF
dir_out = '/work/mh0033/m301036/Land_surf_temp/Disentangling_OBS_SAT_trend/LE_evaluation/Fig3_IPSL/output/'
emergence_file = f"{dir_out}IPSL_emergence_timescale.nc"
first_trend_length_array.to_netcdf(emergence_file)

In [None]:
# output the ensemble mean of the first valid trend length in run dimension
# NOTE(review): runs that are NaN at a grid point (threshold never reached) are
# presumably skipped by xarray's default NaN handling, so the mean would cover
# only the runs where emergence occurred — confirm this is intended.

first_trend_length_array_mean = first_trend_length_array.mean(dim='run')

In [None]:
def plot_trend(lons, lats, data, levels=None, extend=None, cmap=None, norm=None,
                                 title="", ax=None, show_xticks=False, show_yticks=False):
    """
    Draw a lat/lon field with ``pcolormesh`` on a map axis.

    Parameters:
    - lons, lats: 1D arrays of longitudes and latitudes.
    - data: 2D array of values to draw, matching (lats, lons).
    - levels, extend: accepted for call compatibility but NOT used by the
      pcolormesh drawing; discrete intervals must be supplied through ``norm``.
    - cmap, norm: forwarded to ``ax.pcolormesh``.
    - title: Title for the plot.
    - ax: Existing axis to plot on. If None, a new figure with a global
      Robinson-projection axis is created.
    - show_xticks, show_yticks: Whether to draw bottom / left gridline labels.

    Returns:
    - pcolormesh_obj: The object returned by ``ax.pcolormesh`` (usable as a
      colorbar mappable).
    """
    # Create a new figure/axis if none is provided
    if ax is None:
        _, ax = plt.subplots(figsize=(20, 15), subplot_kw={'projection': ccrs.Robinson()})
        ax.set_global()

    pcolormesh_obj = ax.pcolormesh(lons, lats, data, norm=norm, cmap=cmap, shading='auto',
                                   transform=ccrs.PlateCarree())

    ax.coastlines(resolution='110m')
    gl = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False, linewidth=1, linestyle='--',
                      color='gray', alpha=0.35)

    # Never label the top and right edges; bottom/left follow the caller's flags
    gl.top_labels = False
    gl.right_labels = False
    gl.bottom_labels = bool(show_xticks)
    gl.left_labels = bool(show_yticks)
    gl.xformatter = cticker.LongitudeFormatter()
    gl.yformatter = cticker.LatitudeFormatter()
    gl.xlabel_style = {'size': 16}
    gl.ylabel_style = {'size': 16}

    ax.set_title(title, loc='center', fontsize=24, pad=5.0)

    return pcolormesh_obj

In [None]:
# Write the ensemble-mean emergence time scale to NetCDF
dir_out = '/work/mh0033/m301036/Land_surf_temp/Disentangling_OBS_SAT_trend/LE_evaluation/Fig3_IPSL/output/'
emergence_mean_file = f"{dir_out}IPSL_emergence_timescale_mean.nc"
first_trend_length_array_mean.to_netcdf(emergence_mean_file)

In [None]:
import copy
import matplotlib as mpl
import matplotlib.colors as mcolors
import palettable
import cartopy.util as cutil
import numpy.ma as ma

In [None]:
# Map of the ensemble-mean emergence time scale (first trend length reaching the SNR threshold)
fig, ax = plt.subplots(figsize=(15, 10), subplot_kw={'projection': ccrs.Robinson(180)})

# Colour-bar bin edges in years. Kept under their own name so the
# `trend_lengths` array defined earlier (10..73) is not overwritten.
emergence_levels = np.arange(10, 80, 5)
cmap = plt.get_cmap("Spectral")
# BoundaryNorm maps the bin edges onto discrete colour intervals
norm = BoundaryNorm(emergence_levels, cmap.N)

extend = 'neither'  # No extension beyond the colormap range

data_plot = first_trend_length_array_mean.values
# Mask invalid data (NaN or Inf values)
masked_array = ma.masked_invalid(data_plot)

# Add cyclic point to data and longitude
Ratio_with_cyclic, lon_with_cyclic = cutil.add_cyclic_point(masked_array, coord=first_trend_length_array.lon)

# Check the shapes to ensure consistency
print("Shape of Ratio_with_cyclic:", Ratio_with_cyclic.shape)
print("Shape of lon_with_cyclic:", lon_with_cyclic.shape)
print("Shape of SNR_trend_pattern_ds.lat:", SNR_trend_pattern_ds.lat.shape)

# Draw the field; the returned object is the mappable used for the colorbar
pcolormesh_plot = plot_trend(lon_with_cyclic, SNR_trend_pattern_ds.lat, Ratio_with_cyclic,
                             levels=emergence_levels, extend=extend, cmap=cmap, norm=norm,
                             title='Emergence time scale (years)', ax=ax, show_xticks=True, show_yticks=True)

# Regional outlines: (label, (lon_west, lon_east), (lat_first, lat_second)).
# Each box is a closed rectangle and its label sits at the box midpoint, so
# outline and label are always derived from the same bounds.
regions = [
    ('ARC', (0, 360), (66.5, 90)),     # Arctic
    ('NAWH', (310, 350), (42, 60)),    # WH box
    ('SEP', (200, 280), (0, -25)),     # Southeast Pacific
    ('SOP', (220, 280), (-40, -60)),   # Extratropical South Pacific
]
for label, (lon_west, lon_east), (lat_first, lat_second) in regions:
    ax.plot([lon_west, lon_east, lon_east, lon_west, lon_west],
            [lat_first, lat_first, lat_second, lat_second, lat_first],
            color='lightgrey', linewidth=2.0, transform=ccrs.PlateCarree())
    ax.text((lon_west + lon_east) / 2, (lat_first + lat_second) / 2, label,
            color='black', fontsize=18, transform=ccrs.PlateCarree(),
            ha='center', va='center')

# Add colorbar for the plot
cbar_ax = fig.add_axes([0.275, 0.12, 0.5, 0.04])
cbar = plt.colorbar(pcolormesh_plot, cax=cbar_ax, orientation='horizontal')
# Customize the colorbar
cbar.ax.tick_params(labelsize=14)
cbar.set_label('Emergence time scale (years)', fontsize=16)

fig.savefig('Emergence_Trend_Length.png', dpi=300, bbox_inches='tight')
fig.savefig('Emergence_Trend_Length.pdf', dpi=300, bbox_inches='tight')
plt.show()

### 