In [9]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
from pathlib import Path
import xarray as xr
import matplotlib as mpl
import cartopy as cpy
import datetime as dt

In [10]:
# Configuration: where the LoSSETT output lives and which run / scale count to load.
DATA_DIR = "/gws/nopw/j04/kscale/USERS/dship/LoSSETT_out/"
simid = 'CTC5GAL'
n_scales = 32

# Path of the pre-computed time-mean file for this simulation.
fpath_mean = os.path.join(
    DATA_DIR,
    f"DR_test_{simid}_Nl_{n_scales}_time-mean.nc",
)

# Fail early with an explicit message: previously the cell only printed a note
# and then crashed inside xr.open_dataset with a less helpful error.
if not os.path.exists(fpath_mean):
    raise FileNotFoundError(
        f"mean file doesnt exist - compute it first: {fpath_mean}"
    )

ds = xr.open_dataset(fpath_mean)
# Display the available length scales as the cell output.
ds['length_scale']

##### Cell to choose single scale and change domain

In [120]:
# Target length scale; the nearest available value in `length_scale` is used.
ell = 1760000
da = ds['DR_indicator'].sel(length_scale=ell, method='nearest')

# Collapse the horizontal grid into one `sample` dimension -> (pressure, sample),
# then discard the latitude/longitude coordinates before tabulating.
da_flat = (
    da
    .stack(sample=("latitude", "longitude"))
    .drop_vars(['latitude', 'longitude'])
)

# Long-format table with the former index levels restored as ordinary columns.
df = da_flat.to_dataframe().reset_index()

##### Main plotting script

In [None]:
# Ridge plots: for every length scale, draw one KDE of DR_indicator per pressure
# level (stacked vertically, sharing the x-axis) and save the figure as a PNG.
PLOT_DIR = '/home/users/emg97/emgScripts/LoSSETT/plotting/Plots/Dlu_hist_ridges'
X_LIMIT = 5e-4  # symmetric x-axis half-width, in the units of DR_indicator
X_TICKS = [-5e-4, -4e-4, -3e-4, -2e-4, -1e-4, 0, 1e-4, 2e-4, 3e-4, 4e-4, 5e-4]

# Make sure the output directory exists so savefig does not fail on a fresh checkout.
os.makedirs(PLOT_DIR, exist_ok=True)

# Loop to run for all scales
for ell in ds['length_scale'].values:
    da = ds['DR_indicator'].sel(length_scale=ell, method='nearest')
    da_flat = da.stack(sample=("latitude", "longitude"))
    da_flat = da_flat.drop_vars(['latitude', 'longitude'])
    df = da_flat.to_dataframe().reset_index()

    # Compute the levels once; they drive the palette, the grid and the loop.
    pressure_levels = sorted(df['pressure'].unique())
    n_levels = len(pressure_levels)

    pal = sns.color_palette("plasma", n_levels)
    # squeeze=False always returns a 2-D array of Axes, so indexing also works
    # when there is only a single pressure level (plain subplots() would return
    # a bare Axes object and `axes[i]` would raise).
    fig, axes = plt.subplots(n_levels, 1, figsize=(5, 5), sharex=True, squeeze=False)
    axes = axes.ravel()

    for i, pressure in enumerate(pressure_levels):
        ax = axes[i]
        subset = df[df['pressure'] == pressure]  # Get data for this pressure level
        color = pal[i]  # Get color for this level

        # Plot KDE for this pressure level
        sns.kdeplot(data=subset, x="DR_indicator", ax=ax,
                    fill=True, color=color, alpha=0.7, linewidth=1.5,
                    bw_adjust=1.0, clip_on=True)

        # Add the pressure label just left of the axes (axes-fraction coordinates)
        ax.text(-0.00011, 0.5, f"{int(pressure)} hPa", color=color,
                ha="right", va="center", transform=ax.transAxes)

        # Mark the mean of the distribution at this level
        mean_value = subset["DR_indicator"].mean()
        ax.axvline(x=mean_value, color='red', linestyle='--', linewidth=1.5, label='Mean')

        # Axis formatting
        ax.set_yticks([])
        ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
        ax.set_ylabel('')

        # Top, right and left spines are hidden on every panel.
        for side in ('top', 'right', 'left'):
            ax.spines[side].set_visible(False)

        # Only the last (bottom) panel keeps its bottom spine, ticks and label.
        if i < n_levels - 1:
            ax.spines['bottom'].set_visible(False)
            ax.set_xlabel('')
            ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

        ax.set_xlim(-X_LIMIT, X_LIMIT)
        ax.set_xticks(X_TICKS)

        # Thin reference line at zero
        ax.axvline(x=0, color='k', linewidth=0.5)

    # Set x-label on the bottom axis only
    axes[-1].set_xlabel(r"$\mathcal{D}_{\ell}$ (m$^2$ s$^{-3}$)", fontsize=12)

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.1)
    fig.suptitle(f'{simid} DYAMOND Summer | ' + r'$\ell$' + f'={int(ell / 1000)} km', y=0.92)

    # Zero-pad the scale (in km) to 4 digits so the output files sort in scale order
    ell_km = f"{int(ell / 1000):04d}"

    # Save the figure for the current length scale
    output_path = os.path.join(PLOT_DIR, f'LO_ridge_{simid}DS_l{ell_km}km.png')
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)  # Close the figure to avoid memory issues