# Composite vertical anomalous temperature structure of MHW events

------------------------------------------------
To do: 
- Redo phases so that there are three phases
- Make pcolormesh plots instead of contourf
- Add variability/error indication

## 0. Set up

In [None]:
# Import statements
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import cmocean
import pandas as pd
import matplotlib.ticker as ticker
import pyarrow.parquet as pq

## 1. Load data

In [None]:
# One NetCDF file per ensemble member (indices 0-99); concatenate them along "event"
# and load the result into memory.
filepaths = [
    '/glade/derecho/scratch/cassiacai/vertical_structure_events_ens_{}.nc'.format(member)
    for member in range(100)
]
all_events = xr.open_mfdataset(
    filepaths,
    combine='nested',
    concat_dim='event',
).compute()

In [None]:
# Previously stored mixed-layer-depth (MLD) statistics, read back as a DataFrame
# with a fresh 0..n-1 index.
dataset = pq.ParquetDataset("HMXL_stats_0_100.parquet")
merged_df = dataset.read().to_pandas().reset_index(drop=True)

# Rows 0, 1 and 3 each carry one problematic entry that is removed from both the
# 'mean' and 'max' arrays. Original notes: row 0 -> issue with 3 and 4;
# rows 1 and 3 -> issue with 2 and 3.
for column in ('mean', 'max'):
    for row, bad_idx in ((0, 3), (1, 2), (3, 2)):
        merged_df.at[row, column] = np.delete(merged_df.at[row, column], [bad_idx])

# Flatten the per-row arrays into one list each.
# NOTE(review): downstream code indexes these lists by event_id, so the flattened
# order presumably matches the event order of all_events — confirm.
flattened_MLD = []
for per_row in merged_df['mean']:
    flattened_MLD.extend(per_row)

max_flattened_MLD = []
for per_row in merged_df['max']:
    max_flattened_MLD.extend(per_row)

## 2. Storing phases and indices

In [None]:
# Split every event into five phases (pre, lead-up, peak, decline, post) and
# collect the per-phase vertical profiles together with the event's MLD stats.
all_pre = []
all_mid1 = []
all_mid2 = []
all_mid3 = []
all_post = []
all_ens_idx = []  # To store ensemble indices (not filled in this cell)
all_event_idx = []  # To store event indices (not filled in this cell)
mld_ls = []
max_mld_ls = []

for event_id in range(all_events.data.shape[0]):
    # One event, with time steps containing missing values removed
    event_data = all_events.isel(event=event_id).dropna(dim='time')

    # First 3 time steps are always PRE, last 3 are always POST
    pre = event_data.data.isel(time=slice(0, 3)).mean(dim='time')
    post = event_data.data.isel(time=slice(-3, None)).mean(dim='time')

    # Middle period: everything between PRE and POST
    middle_data = event_data.data.isel(time=slice(3, -3))
    middle_data_surf = middle_data.isel(z_t=0)

    # Events with no middle period are skipped entirely
    if len(middle_data_surf) == 0:
        continue

    # The peak is the middle time step with the largest value at the first z_t level
    peak_idx = middle_data_surf.argmax().item()

    # Each phase is assigned exactly once. (An if / if / else chain here would let
    # the trailing else overwrite the lead-up and decline means with the peak
    # profile for every event whose peak is not the first middle step.)
    mid2 = middle_data.isel(time=peak_idx).drop_vars('time')
    if peak_idx > 0:
        # Lead-up: mean over the middle steps before the peak
        mid1 = middle_data.isel(time=slice(0, peak_idx)).mean(dim='time')
    else:
        # Peak is the first middle step: no lead-up exists, reuse the peak profile
        mid1 = middle_data.isel(time=peak_idx).drop_vars('time')
    # Decline: mean over the middle steps after the peak.
    # NOTE(review): when the peak is the last middle step this slice is empty and
    # the mean is presumably NaN, so a later dropna over "event" would discard the
    # event — confirm that is intended rather than falling back to the peak profile.
    mid3 = middle_data.isel(time=slice(peak_idx + 1, None)).mean(dim='time')

    mld = flattened_MLD[event_id]
    max_mld = max_flattened_MLD[event_id]

    # Append data
    all_pre.append(pre)
    all_mid1.append(mid1)
    all_mid2.append(mid2)
    all_mid3.append(mid3)
    all_post.append(post)
    mld_ls.append(mld)
    max_mld_ls.append(max_mld)

In [None]:
# Per-event MLD statistics as 1-D arrays along "event"
mld_da = xr.DataArray(mld_ls, dims=["event"], name="MLD")  # mean
max_mld_da = xr.DataArray(max_mld_ls, dims=["event"], name="max MLD")  # max

# Concatenate each phase's events along "event", then stack the five phases
# along a new, labelled "phase" dimension.
per_phase = [
    xr.concat(phase_events, dim="event")
    for phase_events in (all_pre, all_mid1, all_mid2, all_mid3, all_post)
]
combined = xr.concat(
    per_phase,
    dim=xr.DataArray(["pre", "mid1", "mid2", "mid3", "post"], dims="phase"),
)

# Attach the MLD statistics as coordinates on the combined array
combined = combined.assign_coords(mld=mld_da, max_mld=max_mld_da)

# Drop events along the "event" dimension that contain missing values
combined_dropna = combined.dropna(dim='event')

In [None]:
# One small figure per MHW: phase on one axis, the first 15 levels of the last
# dimension on the other.
n_events = combined_dropna.shape[1]
for i in range(n_events):
    new_combined = combined_dropna[:, i, :15].transpose()
    fig, ax = plt.subplots(figsize=(3, 3))

    contourf_kwargs = dict(
        ax=ax,
        yincrease=False,        # y-axis values increase downward
        cmap='CMRmap',
        robust=True,
        levels=21,
        vmin=0, vmax=0.5,       # fixed colour range shared by all events
        add_colorbar=False,     # colorbar is added manually below
    )
    contour = new_combined.plot.contourf(**contourf_kwargs)

    cbar = plt.colorbar(contour, ax=ax, pad=0.02)
    cbar.set_label("Temperature Anomaly (°C)", fontsize=12)

    ax.grid(True, linestyle=':', alpha=0.5)
    # Dashed reference line at y = 0
    plt.axhline(y=0, c='k', linestyle='dashed')
    plt.tight_layout()
    plt.show()

## 3. Grouping MHWs by depth of maximum anomalous warming

In [None]:
# For each event: take the maximum over phases, then find the z_t level where
# that maximum is largest.
# NOTE(review): despite the name, z_t_meters is the raw z_t coordinate; the bin
# edges below (1000, 3000, ...) suggest it is in centimetres — confirm.
z_t_meters = combined_dropna.z_t
event_max_var = combined_dropna.max(dim='phase')
max_depths = event_max_var.argmax(dim='z_t')
max_z_t = z_t_meters[max_depths]

# Bin edges for the depth of maximum warming; each bin is labelled by its lower edge.
# (A finer alternative uses edges 0, 1000, 2000, ..., 8000, inf.)
depth_bins = [0, 1000, 3000, 5000, 7000, np.inf]
depth_labels = [str(lower_edge) for lower_edge in depth_bins[:-1]]

# Categorical depth bin per event, aligned on the "event" dimension
depth_category = xr.DataArray(
    pd.cut(max_z_t, bins=depth_bins, labels=depth_labels),
    dims=["event"],
    name="depth_category",
)

# Group the events by depth category
grouped = combined_dropna.groupby(depth_category)

## 4. Visualization

#### Composite plots of MHW vertical anomalous temperature structure, grouped by depth of maximum warming anomaly (20 m bins)

In [None]:
# Pull out each depth group by its label (the bin's lower edge)
ds_0, ds_1000, ds_3000, ds_5000, ds_7000 = (
    grouped[label] for label in ("0", "1000", "3000", "5000", "7000")
)

# Size of each group along its second axis (the event axis)
for subset in (ds_0, ds_1000, ds_3000, ds_5000, ds_7000):
    print(subset.data.shape[1])

In [None]:
# Print the mean of the `mld` coordinate for each depth group
for subset in (ds_0, ds_1000, ds_3000, ds_5000, ds_7000):
    print(subset.mld.mean().item())

# Mean of the `max_mld` coordinate for each depth group (stored, not printed)
max_0, max_1000, max_3000, max_5000, max_7000 = (
    subset.max_mld.mean().item()
    for subset in (ds_0, ds_1000, ds_3000, ds_5000, ds_7000)
)

In [None]:
# Depth-binned datasets to composite, one figure per entry.
# NOTE(review): the keys are nominal labels only. The groups were binned on the
# edges 0 / 1000 / 3000 / 5000 / 7000 / inf, so e.g. '1000-2000m' does not
# describe the full bin behind ds_1000 — confirm before using the keys in titles
# or filenames.
depth_datasets = {
    '0-1000m': ds_0,
    '1000-2000m': ds_1000,
    '3000-4000m': ds_3000,
    '5000-6000m': ds_5000,
    '7000-8000m': ds_7000,
}

for depth_range, ds in depth_datasets.items():
    # Mean across events, restricted to the first 15 z_t levels
    new_combined = ds.mean(dim='event').isel(z_t=slice(0, 15))

    # z_t is deliberately left in its native units: the y-limits, the MLD lines
    # and the /100 tick relabelling below all work in those same units, so
    # rescaling the coordinate here would misplace every one of them.
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.set_ylim(500, 14000)

    profile = new_combined.transpose()

    # Filled contours on a fixed 0-0.5 colour range
    filled = profile.plot.contourf(
        ax=ax,
        yincrease=False,
        robust=True,
        levels=26,
        cmap='CMRmap',
        vmin=0, vmax=0.5,
        add_colorbar=False,
    )
    # Thin black contour lines at selected anomaly levels
    contour = profile.plot.contour(
        ax=ax,
        yincrease=False,
        levels=[0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4],
        cmap='k',
        linewidths=0.5,
    )

    # Group-mean MLD lines: solid = mean of max_mld, dotted = mean of mld
    ax.axhline(y=ds.max_mld.mean().item(), c='lime', linestyle='-')
    ax.axhline(y=ds.mld.mean().item(), c='lime', linestyle='dotted')

    # Faint dividers between the five phases
    for x in [0.5, 1.5, 2.5, 3.5]:
        ax.axvline(x, color='white', linestyle='--', linewidth=0.75, alpha=0.3)

    # Optional colorbar:
    # cbar = plt.colorbar(filled, ax=ax, pad=0.05)
    # cbar.formatter = ticker.StrMethodFormatter("{x:.2f}")
    # cbar.update_ticks()

    # Labels
    # ax.set_title(f"n={ds.data.shape[1]} events", fontsize=14)
    ax.set_xlabel("MHW Phase", fontsize=12)
    ax.set_ylabel("Depth (m)", fontsize=12)
    ax.set_xticks([0., 1., 2, 3, 4])
    ax.set_xticklabels(["Pre", "Lead up", "Max", "Decline", "Post"], fontsize=10)

    # Relabel the y ticks as tick value / 100 so the axis reads in the units
    # named in the y-label
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{int(tick/100)}" for tick in yticks])

    plt.tight_layout()
    # plt.savefig(f"mhw_structure_{depth_range.replace('-','_')}.png", dpi=300)
    plt.show()

#### Bar plot of MHW event percentages by depth of maximum warming anomaly (10 m bins)

In [None]:
####### Bar plot of depth of max warming anomaly MHW counts splitting in 10m intervals
# event_max_var = combined_dropna.max(dim='phase')
# max_depths = event_max_var.argmax(dim='z_t')
# z_t_meters = combined_dropna.z_t
# max_z_t = z_t_meters[max_depths]

# # # Bin by depth
# depth_bins = [0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, np.inf]
# depth_labels = ["0", "1000", "2000", "3000", "4000", "5000", "6000", "7000"]
# # # Bin by depth
# # depth_bins = [0, 1000, 3000, 5000, 7000, np.inf]
# # depth_labels = ["0", "1000", "3000", "5000","7000"]

# depth_category = xr.DataArray(
#     pd.cut(max_z_t, bins=depth_bins, labels=depth_labels),
#     dims=["event"],
#     name="depth_category"
# )

# # Group by depth category
# grouped = combined_dropna.groupby(depth_category)

# ds_0 = grouped["0"]
# ds_1000 = grouped["1000"]
# ds_2000 = grouped["2000"]
# ds_3000 = grouped["3000"]
# ds_4000 = grouped["4000"]
# ds_5000 = grouped["5000"]
# ds_6000 = grouped["6000"]
# ds_7000 = grouped["7000"]

# print(ds_0.data.shape[1])
# print(ds_1000.data.shape[1])
# print(ds_2000.data.shape[1])
# print(ds_3000.data.shape[1])
# print(ds_4000.data.shape[1])
# print(ds_5000.data.shape[1])
# print(ds_6000.data.shape[1])
# print(ds_7000.data.shape[1])

In [None]:
# # Depth range to number of events mapping
# depth_counts = {
#     '0-10m': ds_0.data.shape[1],
#     '10-20m': ds_1000.data.shape[1],
#     '20-30m': ds_2000.data.shape[1],
#     '30-40m': ds_3000.data.shape[1],
#     '40-50m': ds_4000.data.shape[1],
#     '50-60m': ds_5000.data.shape[1],
#     '60-70m': ds_6000.data.shape[1],
#     '70-80m': ds_7000.data.shape[1],
#     '80-90m': ds_8000.data.shape[1],
#     # Add more if needed
# }

# # Extract labels and counts
# bin_labels = list(depth_counts.keys())
# counts = list(depth_counts.values())

# # Create horizontal bar plot
# plt.figure(figsize=(5, 4))
# bars = plt.barh(bin_labels, counts, color='teal')
# # Add count labels next to bars
# for bar in bars:
#     width = bar.get_width()
#     plt.text(width + 1, bar.get_y() + bar.get_height() / 2,
#              f'{width}', va='center', fontsize=12)

# # Plot formatting
# plt.xlabel('Number of Events')
# plt.ylabel('Depth Maximum (m)')
# plt.grid(axis='x', linestyle='--', alpha=0.6)
# plt.tight_layout()
# plt.xticks(fontsize=15)
# plt.yticks(fontsize=15)
# plt.show()

In [None]:
# Percentage of events whose maximum warming anomaly falls in each 10 m depth bin.
#
# The counts are derived here from max_z_t (depth of maximum warming per event)
# instead of from ds_* variables: the live cells only define the coarser groups
# ds_0 / ds_1000 / ds_3000 / ds_5000 / ds_7000, so ds_2000, ds_4000, ds_6000 and
# ds_8000 do not exist, and the existing ds_* groups cover wider bins than the
# 10 m labels used here.
# Edges are in the same units as z_t (1000 per 10 m); the last bin is open-ended.
fine_bin_edges = [0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, np.inf]
fine_bin_labels = ['0-10m', '10-20m', '20-30m', '30-40m', '40-50m',
                   '50-60m', '60-70m', '70-80m', '80+m']

# pd.cut returns a Categorical; its integer codes index into the label list,
# with -1 marking values outside every bin.
fine_codes = pd.cut(np.asarray(max_z_t), bins=fine_bin_edges,
                    labels=fine_bin_labels).codes
fine_counts = np.bincount(fine_codes[fine_codes >= 0],
                          minlength=len(fine_bin_labels))

# Depth range to number of events mapping (empty bins are kept with a count of 0)
depth_counts = dict(zip(fine_bin_labels, fine_counts.tolist()))

# Extract labels and counts
bin_labels = list(depth_counts.keys())
counts = list(depth_counts.values())

# Convert to percentages
total = sum(counts)
percentages = [count / total * 100 for count in counts]

# Create horizontal bar plot
plt.figure(figsize=(3, 3))
bars = plt.barh(bin_labels, percentages, color='teal')

# Add percentage labels next to bars
for bar, pct in zip(bars, percentages):
    plt.text(pct + 0.5, bar.get_y() + bar.get_height() / 2,
             f'{pct:.1f}%', va='center', fontsize=12, fontweight='bold')

# Plot formatting
plt.xlabel('Percentage of Events', fontsize=12)
plt.grid(axis='x', linestyle='--', alpha=0.6)
plt.tight_layout()
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()


#### Bar plot of MHW event percentages by depth of maximum warming anomaly (20 m bins)

In [None]:
# Number of events in each coarse depth group, keyed by its display label
depth_counts = {
    label: subset.data.shape[1]
    for label, subset in (
        ('0-10m', ds_0),
        ('10-30m', ds_1000),
        ('30-50m', ds_3000),
        ('50-70m', ds_5000),
        ('70+m', ds_7000),
    )
}

# Labels and counts in dictionary order
bin_labels = list(depth_counts)
counts = [depth_counts[label] for label in bin_labels]

# Each group's share of all events, in percent
total = sum(counts)
percentages = [count / total * 100 for count in counts]

# Horizontal bars, one per depth group
plt.figure(figsize=(3, 3))
bars = plt.barh(bin_labels, percentages, color='teal')

# Annotate each bar with its percentage, just right of the bar end and
# vertically centred on the bar
for rect, share in zip(bars, percentages):
    y_mid = rect.get_y() + rect.get_height() / 2
    plt.text(share + 0.5, y_mid, f'{share:.1f}%',
             va='center', fontsize=12, fontweight='bold')

# Axis formatting
plt.xlabel('Percentage of Events', fontsize=12)
plt.grid(axis='x', linestyle='--', alpha=0.6)
plt.tight_layout()
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()
