In [2]:
# Author: Anja Katzenberger

# This notebook creates an interactive 3D hysteresis plot of precipitation as a
# function of the upper (solar radiation) and lower (surface temperature) boundary conditions

In [3]:
%matplotlib inline

### LOAD MODULES
#-------------------

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import xarray as xr
import numpy as np
from pyts.decomposition import SingularSpectrumAnalysis
import plotly.graph_objects as go
import plotly.io as pio

# where to save figures
# NOTE(review): absolute, machine-specific path. None of the cells visible in
# this review reference `save_dir` — confirm whether it is still needed, and
# consider making it configurable/relative if figures are saved elsewhere.
save_dir = 'C:/Users/anjaka/Nextcloud/PhD/03_MonsoonPlanet_Hysteresis/Figures'


In [4]:
#%%
###  LOAD DATA
#-------------------
# One NetCDF file per slab depth is read from `data_dir`; the file name is
# built as 'slab<depth>m.nc'.
data_dir = 'C:/Users/anjaka/Nextcloud/PhD/03_MonsoonPlanet_Hysteresis/data/slab_10years_ydaymean'

slab_list = [50,100,200,500]  # available slabs
slab_list_sel = [50,100,200,500] # selected slabs

# Opened datasets, keyed by slab depth
ds_dict = {
    slab: xr.open_dataset(f'{data_dir}/slab{slab}m.nc')
    for slab in slab_list
}

# Per-variable dictionaries, keyed by slab depth.
# The offset/scale factors bring the fields to the units used on the plot axes
# (°C and mm/day); presumably the files store Kelvin and kg m-2 s-1 — TODO confirm.
sw = {slab: dataset['swdn_toa'] for slab, dataset in ds_dict.items()}
tsurf = {slab: dataset['t_surf']-273.15 for slab, dataset in ds_dict.items()}
pr = {slab: dataset['precip']*86400 for slab, dataset in ds_dict.items()}
wvp = {slab: dataset['WVP'] for slab, dataset in ds_dict.items()}



In [5]:

#%%
#  PROCESS DATA
#-------------------
# For every slab and every variable: cut out a latitude band, average over
# longitude, then take a cos(latitude)-weighted average over the band.

# Latitude band limits handed to .sel(lat=slice(l_m, u_m)); latitudes are
# treated as degrees (see np.deg2rad below).
l_m = 10
u_m = 20


# Zonal means (longitude averaged) within the band, keyed by slab depth
pr_m = {}
sw_m = {}
tsurf_m = {}
wvp_m = {}

# Declared but not filled in this cell
pr_m_z = {}
sw_m_z = {}
tsurf_m_z = {}
wvp_m_z = {}

# cos(lat)-weighted band means of the zonal means, keyed by slab depth
pr_m_z_mean = {}
sw_m_z_mean = {}
tsurf_m_z_mean = {}
wvp_m_z_mean = {}


for slab in slab_list:
    # (source field, zonal-mean store, band-mean store) for each variable
    for raw, zonal, band in (
        (pr, pr_m, pr_m_z_mean),
        (sw, sw_m, sw_m_z_mean),
        (tsurf, tsurf_m, tsurf_m_z_mean),
        (wvp, wvp_m, wvp_m_z_mean),
    ):
        # Restrict to the latitude band and average along longitude
        zonal[slab] = raw[slab].sel(lat=slice(l_m,u_m)).mean('lon')

        # Weight each latitude row by cos(lat) before averaging over latitude
        weights = np.cos(np.deg2rad(zonal[slab].lat))
        band[slab] = zonal[slab].weighted(weights).mean(dim=['lat'])



In [6]:

#%%
# Singular Spectrum Analysis
#-------------------
# Each band-mean series is passed through pyts' SingularSpectrumAnalysis and
# element [0, 0, :] of the transform output is kept for plotting.

L = 20 # window_size


def _ssa_leading_series(series, window_size):
    """Apply SSA to a single series and return element ``[0, 0, :]`` of the output.

    Parameters
    ----------
    series : array-like
        One series; it is converted with ``np.array`` and wrapped in a
        length-1 outer array before being handed to ``transform``.
    window_size : int
        Forwarded to ``SingularSpectrumAnalysis(window_size=...)``.

    Returns
    -------
    numpy.ndarray
        ``transform(...)[0, 0, :]`` — presumably the first SSA component of
        the (only) input sample; confirm against the installed pyts version.
    """
    batch = np.array([np.array(series)])  # add a leading "sample" axis
    ssa = SingularSpectrumAnalysis(window_size=window_size)
    return ssa.transform(batch)[0, 0, :]


# Perform the Singular Spectrum Analysis for 'tsurf'
tsurf_m_z_mean_ssa = {
    slab: _ssa_leading_series(tsurf_m_z_mean[slab], L) for slab in slab_list
}

# Perform the Singular Spectrum Analysis for 'sw'
sw_m_z_mean_ssa = {
    slab: _ssa_leading_series(sw_m_z_mean[slab], L) for slab in slab_list
}

# Perform the Singular Spectrum Analysis for 'pr'
pr_m_z_mean_ssa = {
    slab: _ssa_leading_series(pr_m_z_mean[slab], L) for slab in slab_list
}

# Perform the Singular Spectrum Analysis for 'wvp'
wvp_m_z_mean_ssa = {
    slab: _ssa_leading_series(wvp_m_z_mean[slab], L) for slab in slab_list
}


In [20]:
### PLOT INTERACTIVE 3D PLOT
#-------------------
# One closed 3D curve per selected slab (x: solar radiation, y: surface
# temperature, z: precipitation), plus small markers at fixed time indices.

# One colour per entry of slab_list_sel, in the same order. zip() below stops
# at the shorter of the two lists, so add colours when selecting more slabs.
colors = ["#DC143C", "#FF8000", "#33A1C9", "#191970"]

# Time indices highlighted with a marker on every curve.
# Index 0 is the first time step; 183 is presumably the mid-year point of the
# day-of-year mean cycle — TODO confirm.
marker_indices = [0, 183]


def _closed(values):
    """Return `values` as a list with its first element appended, closing the loop."""
    return list(values) + [values[0]]


# Line traces: one closed curve per slab, labelled '<depth>m'
line_traces = [
    go.Scatter3d(
        x=_closed(sw_m_z_mean_ssa[slab]),
        y=_closed(tsurf_m_z_mean_ssa[slab]),
        z=_closed(pr_m_z_mean_ssa[slab]),
        mode='lines',
        name=f'{slab}m',
        line=dict(color=color),
    )
    for slab, color in zip(slab_list_sel, colors)
]

# Marker traces: for each highlighted index, one point per slab in the slab's
# colour (all slabs at index 0 first, then all slabs at index 183)
marker_traces = [
    go.Scatter3d(
        x=[sw_m_z_mean_ssa[slab][idx]],
        y=[tsurf_m_z_mean_ssa[slab][idx]],
        z=[pr_m_z_mean_ssa[slab][idx]],
        mode='markers',
        name='',
        marker=dict(size=3, color=color),
    )
    for idx in marker_indices
    for slab, color in zip(slab_list_sel, colors)
]

# Create a 3D line plot
fig = go.Figure(data=line_traces + marker_traces)


# Fixed figure size, axis titles, and a camera eye placed on the z axis.
# Kept as the last expression of the cell so the returned figure is rendered.
fig.update_layout(
    autosize=False,
    width=1000,
    height=800,
    scene=dict(
        xaxis_title='Solar radiation (W/m<sup>2</sup>)',
        yaxis_title='Surface temperature (°C)',
        zaxis_title='Precipitation (mm/day)',
        camera=dict(
            eye=dict(x=0, y=0, z=2.0)
        )
    )
)


