In [None]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [None]:
# Project-local helper modules
import src.SAT_function as data_process
import src.Data_Preprocess as preprosess

In [None]:
# Start a Dask cluster through the project's SLURM helper; the two returned values are
# bound to `client` and `scluster` for use by later cells.
# The helper module gets its own alias: binding the module itself to `scluster` and then
# rebinding `scluster` to the returned object would make a re-run of this cell fail.
import src.slurm_cluster as slurm_cluster_helpers
client, scluster = slurm_cluster_helpers.init_dask_slurm_cluster(walltime="02:00:00")

In [None]:
def func_mk(x):
    """
    Mann-Kendall trend test for a single series.

    Thin wrapper around ``data_process.apply_mannkendall`` that keeps only the
    first two items of its result.

    Returns
    -------
    tuple
        ``(slope, p_val)`` -- items 0 and 1 of the helper's result.
    """
    mk_out = data_process.apply_mannkendall(x)
    return mk_out[0], mk_out[1]

In [None]:
# Load the MIROC6 internal-variability anomalies for 1850-2022,
# opened lazily with one dask chunk per 'run'.
dir_residuals = '/work/mh0033/m301036/Land_surf_temp/Disentangling_OBS_SAT_trend/Figure2/MIROC6/'
ds_MIROC6_1850_2022 = xr.open_mfdataset(
    f'{dir_residuals}GSAT_MIROC6_Internal_Variability_anomalies_1850_2022.nc',
    chunks={'run': 1},
)

In [None]:
# Inspect the loaded dataset.
ds_MIROC6_1850_2022

In [None]:
# Uncomment if the file stores the data under the default unnamed-variable key:
# ds_MIROC6_1850_2022 = ds_MIROC6_1850_2022.rename({'__xarray_dataarray_variable__': 'tas'})

In [None]:
# Generate running windows of the SAT-OBS residuals:
#   windows of equal length (originally planned: lengths from 10 to 100 years in 5-year steps),
#   then compute the trend pattern of each window
#   and the ensemble standard deviation of those trend patterns for each window length.

# Split a DataArray into overlapping windows of consecutive years.
def generate_segments(data, segment_length):
    """
    Split ``data`` into overlapping windows of ``segment_length`` consecutive years.

    Windows start at every year from the first year up to the last start year
    that still leaves room for a full window, advancing one year at a time.

    Parameters
    ----------
    data : xarray.DataArray
        Array with a numeric ``year`` coordinate; any other dimensions are
        carried along unchanged.
    segment_length : int
        Window length in years; must be >= 1.

    Returns
    -------
    list
        One ``data.sel(year=...)`` slice per window start year, in
        chronological order. Empty when the record is shorter than
        ``segment_length``.

    Raises
    ------
    ValueError
        If ``segment_length`` < 1.
    """
    if segment_length < 1:
        raise ValueError(f"segment_length must be >= 1, got {segment_length}")

    first_year = int(data['year'].min().item())
    last_year = int(data['year'].max().item())

    # Label slices on 'year' use integer bounds that match the numeric coordinate;
    # string bounds only worked through implicit int/str coercion. Bounds are inclusive.
    return [
        data.sel(year=slice(start, start + segment_length - 1))
        for start in range(first_year, last_year - segment_length + 2)
    ]

In [None]:
# Build the running windows of the SAT-OBS residuals, keyed by window length in years.
time_interval = [60]

ICV_segments = {
    window_len: generate_segments(ds_MIROC6_1850_2022['tas'], segment_length=window_len)
    for window_len in time_interval
}

In [None]:
# Stack the windows of each length into one DataArray with a new 'segment' dimension and
# collect them in `new_ds` as variables named ICV_segments_<N>yr.
# ICV_segments: dict mapping window length (years) -> list of DataArray windows.
max_num_segments = max(len(segments) for segments in ICV_segments.values())
segment_lengths = ICV_segments.keys()

# Create a new Dataset to hold the stacked arrays
new_ds = xr.Dataset()

for segment_length in segment_lengths:
    segments_list = ICV_segments[segment_length]
    if not segments_list:
        # Without at least one window there is nothing to stack and no NaN template.
        raise ValueError(
            f"No windows were generated for segment_length={segment_length}; "
            "cannot stack or pad an empty list."
        )

    # Pad with all-NaN windows so every window length shares the same 'segment' size.
    nan_segment = xr.full_like(segments_list[0], np.nan)
    padded_segments = segments_list + [nan_segment] * (max_num_segments - len(segments_list))

    # A named pandas Index creates and labels the 'segment' dimension in one step.
    # join='outer' is explicit: each window carries different 'year' labels, so they are
    # aligned onto the union of years (NaN outside each window), and newer xarray
    # versions warn that the implicit default join will change.
    concatenated = xr.concat(
        padded_segments,
        dim=pd.Index(range(max_num_segments), name='segment'),
        join='outer',
    )

    new_ds[f'ICV_segments_{segment_length}yr'] = concatenated

In [None]:
# Inspect the stacked windows (one ICV_segments_<N>yr variable per window length).
new_ds

In [None]:
# Combine the original anomalies and the stacked windows into a single Dataset.
ds_combined = xr.merge([ds_MIROC6_1850_2022, new_ds])

In [None]:
# Inspect the merged dataset.
ds_combined

In [None]:
# Optional sanity check: value range of the stacked windows (uncomment to run).
# ds_combined['ICV_segments_60yr'].min().values, ds_combined['ICV_segments_60yr'].max().values

In [None]:
# NaN-aware standard deviation used to measure the spread of window trends.
def std_trend_pattern(data, axis=0):
    """
    NaN-aware population standard deviation (ddof=0) of ``data`` along ``axis``.

    Parameters
    ----------
    data : array_like
        Trend values, e.g. the trends of all windows at one location.
    axis : int, default 0
        Axis to reduce. ``xr.apply_ufunc`` moves core dimensions to the *last*
        axis, so pass ``axis=-1`` when calling it through apply_ufunc without
        ``vectorize=True`` (with ``vectorize=True`` each call sees a 1-D array
        and the default is fine).

    Returns
    -------
    numpy.ndarray or numpy scalar
        ``np.nanstd(data, axis=axis)``; NaN where a slice is entirely NaN.
    """
    return np.nanstd(data, axis=axis)

In [None]:
# Calculate the trend pattern of each segment
#       and calculate the ensemble standard deviation of the trend pattern of each interval of segments
# Mann-Kendall trend and p-value of every window along 'year', for each window length;
# trends are then scaled by 10 to express them per decade.
for segment_length in segment_lengths:
    key_segments = f'ICV_segments_{segment_length}yr'
    key_trend = f'ICV_segments_{segment_length}yr_trend'
    key_p = f'ICV_segments_{segment_length}yr_p_values'

    trend, p_values = xr.apply_ufunc(
        func_mk,
        ds_combined[key_segments],
        input_core_dims=[['year']],
        output_core_dims=[[], []],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[float, float],
        # A core dimension must sit in a single dask chunk; let dask merge chunks in case
        # 'year' spans several (matches the std cell's settings).
        dask_gufunc_kwargs={'allow_rechunk': True},
    )
    # Assumes func_mk's slope is per year (one step along 'year'), so x10 -> degC/decade.
    ds_combined[key_trend] = trend * 10.0
    ds_combined[key_p] = p_values

In [None]:
# Spread of the window trends for each window length: NaN-aware population std (ddof=0)
# across the 'segment' dimension -- the same quantity std_trend_pattern (np.nanstd)
# computes per location, but done by xarray's built-in reduction. This avoids a Python-level
# loop over every element (vectorize=True) and stays lazy on dask-backed data.
for segment_length in segment_lengths:
    ds_combined[f'ICV_segments_{segment_length}yr_std_trend_pattern'] = (
        ds_combined[f'ICV_segments_{segment_length}yr_trend'].std(dim='segment', skipna=True)
    )

In [None]:
# Inspect ds_combined after adding the trend, p-value and trend-spread variables.
ds_combined

In [None]:
# (Currently disabled) Calculate the ensemble mean of the trend pattern for each window length
#     and save it to the dataset as ICV_segments_<N>yr_trend_mean.
# for segment_length in segment_lengths:
#     key_trend = f'ICV_segments_{segment_length}yr_trend'
#     key_mean = f'ICV_segments_{segment_length}yr_trend_mean'

#     if key_trend in ds_combined:
#         # Calculate mean
#         data = np.nanmean(ds_combined[key_trend], axis=0)
        
#         # Check if the mean key exists, if not, initialize it
#         if key_mean not in ds_combined:
#             ds_combined[key_mean] = []

#         # Append data
#         ds_combined[key_mean]= (['lat', 'lon'], data)

In [None]:
# Write the spread of the 60-year window trends to NetCDF in the MIROC6 output folder.
# Other window lengths (e.g. 10yr, 30yr) can be saved the same way once they are computed.
ds_output = '/work/mh0033/m301036/Land_surf_temp/Disentangling_OBS_SAT_trend/Figure2/MIROC6/'
out_var = 'ICV_segments_60yr_std_trend_pattern'
ds_combined[out_var].to_netcdf(f'{ds_output}{out_var}.nc')