In [None]:
#Importing Libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as ticker
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import ScalarFormatter
import matplotlib.gridspec as gridspec
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
import xarray as xr
import os; import time
import pickle
import h5py

In [None]:
# MAIN DIRECTORIES
# Project root on the shared Lustre store; data products and helper modules live under it.
mainDirectory = '/mnt/lustre/koa/koastore/torri_group/air_directory/Projects/DCI-Project/'
# Personal scratch area holding the high-resolution CM1 runs (simulations 2 and 3).
scratchDirectory = '/home/air673/koa_scratch/'
# Tracked-profile algorithm folder (note: no trailing slash).
codeDirectory = mainDirectory + 'Project_Algorithms/Tracked_Profiles'

In [None]:
#LOADING DATA
def GetDataDirectories(simulationNumber):
    """Return the CM1 output paths and run descriptors for a numbered simulation.

    Parameters
    ----------
    simulationNumber : int
        1 -> res='1km', t_res='5min', Np_str='1e6', Nz_str='34', under ``mainDirectory``;
        2 -> res='1km', t_res='1min', Np_str='50e6', Nz_str='95', under ``scratchDirectory``;
        3 -> res='250m', t_res='1min', Np_str='50e6', Nz_str='95', under ``scratchDirectory``.

    Returns
    -------
    tuple
        ``(dataDirectory, parcelDirectory, res, t_res, Np_str, Nz_str)``. The two paths are
        ``cm1out_{res}_{t_res}_{Nz_str}nz.nc`` and ``cm1out_pdata_{res}_{t_res}_{Np_str}np.nc``
        inside the simulation's directory.

    Raises
    ------
    ValueError
        If ``simulationNumber`` is not 1, 2 or 3. Previously this surfaced as an
        unhelpful UnboundLocalError on ``Directory``.
    """
    if simulationNumber == 1:
        Directory = os.path.join(mainDirectory, 'Model/cm1r20.3/run')
        res = '1km'; t_res = '5min'; Np_str = '1e6'; Nz_str = '34'
    elif simulationNumber == 2:
        Directory = scratchDirectory
        res = '1km'; t_res = '1min'; Np_str = '50e6'; Nz_str = '95'
    elif simulationNumber == 3:
        Directory = scratchDirectory
        res = '250m'; t_res = '1min'; Np_str = '50e6'; Nz_str = '95'
    else:
        raise ValueError(
            f"Unknown simulationNumber {simulationNumber!r}; expected 1, 2 or 3"
        )

    dataDirectory = os.path.join(Directory, f"cm1out_{res}_{t_res}_{Nz_str}nz.nc")
    parcelDirectory = os.path.join(Directory, f"cm1out_pdata_{res}_{t_res}_{Np_str}np.nc")
    return dataDirectory, parcelDirectory, res, t_res, Np_str, Nz_str
    
def GetData(dataDirectory, parcelDirectory):
    """Open the CM1 gridded output and the parcel output with xarray.

    Both files are opened via ``xr.open_dataset(..., decode_timedelta=True)``.

    Returns
    -------
    tuple
        ``(dataNC, parcelNC)``: the dataset for ``dataDirectory`` followed by the
        dataset for ``parcelDirectory``.
    """
    opened = tuple(
        xr.open_dataset(nc_path, decode_timedelta=True)
        for nc_path in (dataDirectory, parcelDirectory)
    )
    return opened

def SubsetDataVars(dataNC):
    """Select the state, surface-flux and budget variables used in this analysis.

    Parameters
    ----------
    dataNC : object indexable by a list of variable names
        Intended for the dataset returned by ``GetData``. Indexing an xarray Dataset
        with a list of names keeps only those variables.

    Returns
    -------
    ``dataNC[names]``, where ``names`` is a list of 47 variable names in this order:
    20 surface/state fields, then 11 ``ptb_*``, 7 ``qvb_*`` and 9 ``wb_*`` budget terms.
    """
    state_fields = (
        "thflux", "qvflux", "tsk", "cape",
        "cin", "lcl", "lfc", "th",
        "prs", "rho", "qv", "qc",
        "qr", "qi", "qs", "qg",
        "buoyancy", "uinterp", "vinterp", "winterp",
    )
    ptb_terms = (
        "ptb_hadv", "ptb_vadv", "ptb_hidiff", "ptb_vidiff",
        "ptb_hturb", "ptb_vturb", "ptb_mp", "ptb_rdamp",
        "ptb_rad", "ptb_div", "ptb_diss",
    )
    qvb_terms = (
        "qvb_hadv", "qvb_vadv", "qvb_hidiff", "qvb_vidiff",
        "qvb_hturb", "qvb_vturb", "qvb_mp",
    )
    wb_terms = (
        "wb_hadv", "wb_vadv", "wb_hidiff", "wb_vidiff",
        "wb_hturb", "wb_vturb", "wb_pgrad", "wb_rdamp", "wb_buoy",
    )

    # A list (not a tuple) is passed on purpose: xarray treats list indexing as
    # "select these variables".
    names = [*state_fields, *ptb_terms, *qvb_terms, *wb_terms]
    return dataNC[names]

# Select simulation 1 (res='1km', t_res='5min', Np_str='1e6') and open its gridded and parcel output.
simulation_choice = 1
(dataDirectory, parcelDirectory,
 res, t_res, Np_str, Nz_str) = GetDataDirectories(simulationNumber=simulation_choice)
data1, parcel1 = GetData(dataDirectory, parcelDirectory)

In [None]:
# Project root used to build every input/output path below.
# NOTE(review): this shadows the built-in dir(); it is kept because later cells read `dir`.
dir = ('/mnt/lustre/koa/koastore/torri_group/air_directory/'
       'Projects/DCI-Project/')

In [None]:
#########################################

In [None]:
# Make the project's helper modules importable and pull their names into this
# namespace. Later cells call helpers such as ProfileStandardError, SnapLimitsToTicks,
# apply_scientific_notation and MatchAxisLimits without a module prefix.
# NOTE(review): which of the three modules defines each helper is not visible here.
import sys

_FUNCTIONS_DIR = '/mnt/lustre/koa/koastore/torri_group/air_directory/Projects/DCI-Project/Functions/'
# Append only once, so re-running this cell does not keep growing sys.path.
if _FUNCTIONS_DIR not in sys.path:
    sys.path.append(_FUNCTIONS_DIR)

import NumericalFunctions
from NumericalFunctions import *  # import NumericalFunctions
import PlottingFunctions
from PlottingFunctions import *  # import PlottingFunctions

# (Re)bound after the star imports above, as before, so those imports cannot clobber them.
# dir2 is passed to SavePlot later in the notebook.
dir2 = '/mnt/lustre/koa/koastore/torri_group/air_directory/Projects/DCI-Project/'
path = dir2 + 'Functions/'

# Import StatisticalFunctions
import StatisticalFunctions
from StatisticalFunctions import *  # import StatisticalFunctions

In [None]:
def LoadAllCloudBase():
    """Unpickle and return the tracked cloud-base data for the active run.

    Reads ``<dir>Project_Algorithms/Tracking_Algorithms/OUTPUT/all_cloudbase_{res}_{t_res}_{Np_str}.pkl``,
    where ``dir``, ``res``, ``t_res`` and ``Np_str`` are module globals.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use on trusted,
    locally produced files.
    """
    pkl_name = f"all_cloudbase_{res}_{t_res}_{Np_str}.pkl"
    pkl_path = dir + 'Project_Algorithms/Tracking_Algorithms/OUTPUT/' + pkl_name
    with open(pkl_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


# Only the NaN-ignoring minimum is kept. NOTE: the name `all_cloudbase` is reused for
# that minimum; PlotBudgetProfiles draws it as a horizontal reference line.
min_all_cloudbase = np.nanmin(LoadAllCloudBase())
all_cloudbase = min_all_cloudbase
print(f"Minimum Cloudbase is: {all_cloudbase}\n")

In [None]:
def LoadMeanLFC():
    """Unpickle and return the mean LFC for the active run.

    Reads ``<dir>Project_Algorithms/Tracking_Algorithms/OUTPUT/MeanLFC_{res}_{t_res}_{Np_str}.pkl``,
    where ``dir``, ``res``, ``t_res`` and ``Np_str`` are module globals.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    output_folder = dir + 'Project_Algorithms/Tracking_Algorithms/OUTPUT/'
    pkl_path = output_folder + f"MeanLFC_{res}_{t_res}_{Np_str}.pkl"
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)


# PlotBudgetProfiles divides this by 1000 before drawing it on the z (km) axis,
# so it is presumably stored in metres; confirm against the producer.
MeanLFC = LoadMeanLFC()
print(f"Mean LFC is: {MeanLFC}\n")

In [None]:
# LIMITING Y AXIS
# True  -> budget panels use y in [0, 7] and figures are saved in a "SLICE" sub-folder.
# False -> panels use y in [0, 20] (full column).
limit_y = False

In [None]:
def limit_axes_to_y(ax, y_min=0, y_max=7, buffer_frac=0.1):
    """Clamp the y-axis to [y_min, y_max] and fit the x-axis to the data visible in that band.

    x-values are gathered from the Line2D objects in ``ax.get_lines()`` and from the
    path vertices of every collection in ``ax.collections`` (e.g. the PolyCollection
    created by ``fill_betweenx``). Only points whose y lies in the band and whose x is
    finite are kept. The x-limits become ``[min - pad, max + pad]``, where
    ``pad = buffer_frac * (max - min)``, or a fixed 0.1 when all x values coincide.

    Artists whose transform is not ``ax.transData`` are skipped, because their
    coordinates are not data values. For example, ``axhline`` stores x in axes
    fractions [0, 1], and scatter marker paths describe the marker shape. Including
    them would distort the fitted x-range.

    Prints a warning instead of raising when a collection cannot be processed, when no
    data is visible in the band, or when the limits are non-finite.
    """
    ax.set_ylim(y_min, y_max)

    def _in_data_coords(artist):
        # Objects without a transform are treated as data-space (keeps duck-typed artists working).
        get_transform = getattr(artist, "get_transform", None)
        return get_transform is None or get_transform() is ax.transData

    def _visible_x(xdata, ydata):
        # x-values whose y lies inside [y_min, y_max], with NaN/inf removed.
        xdata, ydata = np.asarray(xdata), np.asarray(ydata)
        y_mask = (ydata >= y_min) & (ydata <= y_max)
        x_visible = xdata[y_mask]
        return x_visible[np.isfinite(x_visible)]

    chunks = []

    # Handle lines (from ax.plot)
    for line in ax.get_lines():
        if _in_data_coords(line):
            chunks.append(_visible_x(line.get_xdata(), line.get_ydata()))

    # Handle fill_betweenx (PolyCollection) and other collections
    for collection in ax.collections:
        try:
            if not _in_data_coords(collection):
                continue
            for path in collection.get_paths():
                verts = path.vertices  # Nx2 array of (x, y)
                chunks.append(_visible_x(verts[:, 0], verts[:, 1]))
        except Exception as e:
            print("Warning: failed to process collection:", e)

    x_limited = np.concatenate(chunks) if chunks else np.empty(0)
    if x_limited.size == 0:
        print("Warning: No visible x data within y limits to set xlim")
        return

    x_min, x_max = np.min(x_limited), np.max(x_limited)
    if not (np.isfinite(x_min) and np.isfinite(x_max)):
        print("Warning: Non-finite x-limits detected, skipping set_xlim")
        return

    x_range = x_max - x_min
    buffer = buffer_frac * x_range if x_range > 0 else 0.1
    ax.set_xlim(x_min - buffer, x_max + buffer)


In [None]:
#LOADING DATA
################################################################

In [None]:
#CL vs nonCL
################################################################

In [None]:
# Load the CL / nonCL tracked-property profiles into module globals (one global per
# HDF5 dataset). PlotBudgetProfiles and later cells read them by name.
data_type = "Tracked_Properties"
type1 = 'CL'; type2 = 'nonCL'
dir3 = dir + 'Project_Algorithms/Tracked_Profiles/OUTPUT_FILES/'
filePath = dir3 + f"{data_type}_" + f"{type1}_{type2}_tracked_profiles_{res}_{t_res}_{Np_str}.h5"
key_list = []
with h5py.File(filePath, 'r') as h5f:
    for key in h5f.keys():
        globals()[key] = h5f[key][:]
        if '_squares' not in key:
            key_list.append(key)

# Multiplier applied to <key>_SE when PlotSE shades the uncertainty band.
# Defined explicitly instead of as a side effect of the loop below, so it exists even
# when key_list is empty. For a ±1 standard-deviation band, use
# ProfileStandardDeviation with factor = 1 instead.
factor = 1.96

for key in key_list:
    # CALCULATING STANDARD ERROR from the profile and its matching "<key>_squares" dataset
    globals()[key + "_SE"] = ProfileStandardError(globals()[key], globals()[key + "_squares"])
    # globals()[key + "_SE"] = ProfileStandardDeviation(globals()[key], globals()[key + "_squares"])

    # MULTIPLYING mixing-ratio ('Q') profiles (column 0) and their SE by 1000
    # (presumably kg/kg -> g/kg). Done after the SE so both are scaled consistently.
    if 'Q' in key:
        globals()[key][:, 0] *= 1000
        globals()[key + "_SE"][:, 0] *= 1000

In [None]:
#NEEDED TO PLOT THE CORRECT DATA
# Load the CL / nonCL budget profiles (file tag "Tracked_WQVTH_Budgets") into module
# globals. PlotBudgetProfiles reads them by name.
data_type = "Tracked_WQVTH_Budgets"

type1 = 'CL'; type2 = 'nonCL'
dir3 = dir + 'Project_Algorithms/Tracked_Profiles/OUTPUT_FILES/'
filePath = dir3 + f"{data_type}_" + f"{type1}_{type2}_tracked_profiles_{res}_{t_res}_{Np_str}.h5"
key_list = []
with h5py.File(filePath, 'r') as h5f:
    for key in h5f.keys():
        globals()[key] = h5f[key][:]
        if '_squares' not in key:
            key_list.append(key)

# Multiplier applied to <key>_SE by PlotSE (via PlotBudgetProfiles). Defined explicitly
# so it does not depend on the loop below having iterated at least once.
# For a ±1 standard-deviation band, use ProfileStandardDeviation with factor = 1 instead.
factor = 1.96

for key in key_list:
    # CALCULATING STANDARD ERROR from the profile and its matching "<key>_squares" dataset
    globals()[key + "_SE"] = ProfileStandardError(globals()[key], globals()[key + "_squares"])
    # globals()[key + "_SE"] = ProfileStandardDeviation(globals()[key], globals()[key + "_squares"])

    # MULTIPLYING QV-budget terms (column 0) and their SE by 1000 (plotted as g/kg/s)
    if 'QVB' in key:
        globals()[key][:, 0] *= 1000
        globals()[key + "_SE"][:, 0] *= 1000

In [None]:
#SBZ vs nonSBZ
################################################################

In [None]:
# Load the SBZ / nonSBZ tracked-property profiles into module globals (one global per
# HDF5 dataset). PlotBudgetProfiles and later cells read them by name.
data_type = "Tracked_Properties"
type1 = 'SBZ'; type2 = 'nonSBZ'
dir3 = dir + 'Project_Algorithms/Tracked_Profiles/OUTPUT_FILES/'
filePath = dir3 + f"{data_type}_" + f"{type1}_{type2}_tracked_profiles_{res}_{t_res}_{Np_str}.h5"
key_list = []
with h5py.File(filePath, 'r') as h5f:
    for key in h5f.keys():
        globals()[key] = h5f[key][:]
        if '_squares' not in key:
            key_list.append(key)

# Multiplier applied to <key>_SE by PlotSE. This cell previously relied on an earlier
# cell having set `factor`; it is now defined here with the same value.
# For a ±1 standard-deviation band, use ProfileStandardDeviation with factor = 1 instead.
factor = 1.96

for key in key_list:
    # CALCULATING STANDARD ERROR from the profile and its matching "<key>_squares" dataset
    globals()[key + "_SE"] = ProfileStandardError(globals()[key], globals()[key + "_squares"])
    # globals()[key + "_SE"] = ProfileStandardDeviation(globals()[key], globals()[key + "_squares"])

    # MULTIPLYING mixing-ratio ('Q') profiles (column 0) and their SE by 1000
    # (presumably kg/kg -> g/kg).
    if 'Q' in key:
        globals()[key][:, 0] *= 1000
        globals()[key + "_SE"][:, 0] *= 1000

In [None]:
#NEEDED TO PLOT THE CORRECT DATA #*#*
# Load the SBZ / nonSBZ budget profiles (file tag "Tracked_WQVTH_Budgets") into module
# globals. PlotBudgetProfiles reads them by name.
data_type = "Tracked_WQVTH_Budgets"

type1 = 'SBZ'; type2 = 'nonSBZ'
dir3 = dir + 'Project_Algorithms/Tracked_Profiles/OUTPUT_FILES/'
filePath = dir3 + f"{data_type}_" + f"{type1}_{type2}_tracked_profiles_{res}_{t_res}_{Np_str}.h5"
key_list = []
with h5py.File(filePath, 'r') as h5f:
    for key in h5f.keys():
        globals()[key] = h5f[key][:]
        if '_squares' not in key:
            key_list.append(key)

# Multiplier applied to <key>_SE by PlotSE. This cell previously relied on an earlier
# cell having set `factor`; it is now defined here with the same value.
# For a ±1 standard-deviation band, use ProfileStandardDeviation with factor = 1 instead.
factor = 1.96

for key in key_list:
    # CALCULATING STANDARD ERROR from the profile and its matching "<key>_squares" dataset
    globals()[key + "_SE"] = ProfileStandardError(globals()[key], globals()[key + "_squares"])
    # globals()[key + "_SE"] = ProfileStandardDeviation(globals()[key], globals()[key + "_squares"])

    # MULTIPLYING QV-budget terms (column 0) and their SE by 1000 (plotted as g/kg/s)
    if 'QVB' in key:
        globals()[key][:, 0] *= 1000
        globals()[key + "_SE"][:, 0] *= 1000

In [None]:
#ColdPool
################################################################

# data_type="Tracked_Properties"
# type1='ColdPool'
# dir3=dir+'Project_Algorithms/Tracked_Profiles/OUTPUT_FILES/'
# filePath=dir3+f"{data_type}_"+f"{type1}_tracked_profiles_{res}_{t_res}_{Np_str}.h5"
# key_list=[]
# with h5py.File(filePath, 'r') as h5f:
#     for key in h5f.keys():
#         globals()[key] = h5f[key][:]
#         if '_squares' not in key:
#             key_list.append(key)

# #CALCULATING STANDARD DEVIATION
# for key in key_list:
#     # globals()[key+f"_SE"]=ProfileStandardError(globals()[key],globals()[key+f"_squares"])
#     globals()[key+f"_SE"]=ProfileStandardDeviation(globals()[key],globals()[key+f"_squares"])

# # #MULTIPLING QV BY 1000
# # for key in key_list:
# #     if 'Q' in key:
# #         globals()[key][:,0]*=1000
# #         globals()[key+f"_SE"][:,0]*=1000






# #NEEDED TO PLOT THE CORRECT DATA #*#*
# data_type="Tracked_WQVTH_Budgets"

# type1='ColdPool'
# dir3=dir+'Project_Algorithms/Tracked_Profiles/OUTPUT_FILES/'
# filePath=dir3+f"{data_type}_"+f"{type1}_tracked_profiles_{res}_{t_res}_{Np_str}.h5"
# key_list=[]
# with h5py.File(filePath, 'r') as h5f:
#     for key in h5f.keys():
#         globals()[key] = h5f[key][:]
#         if '_squares' not in key:
#             key_list.append(key)

# #CALCULATING STANDARD DEVIATION
# for key in key_list:
#     # globals()[key+f"_SE"]=ProfileStandardError(globals()[key],globals()[key+f"_squares"])
#     globals()[key+f"_SE"]=ProfileStandardDeviation(globals()[key],globals()[key+f"_squares"])

# # #MULTIPLING QV BY 1000
# # for key in key_list:
# #     if 'QVB' in key:
# #         globals()[key][:,0]*=1000
# #         globals()[key+f"_SE"][:,0]*=1000

In [None]:
#############################################
#PLOTTING

In [None]:
# Produce averaged profiles for plotting
def averaged_profiles(profile):
    """Turn an accumulated profile into a (mean value, height) array.

    Parameters
    ----------
    profile : 2-D array with at least 3 columns
        Column 0 is the accumulated value, column 1 is the counter, and column 2 is
        the vertical coordinate (plotted against a 'z (km)' label).

    Returns
    -------
    ndarray of shape (n_kept, 2)
        ``[value / count, z]`` for every row whose counter is strictly greater than 1.
        Rows with a count of 0 or 1 are dropped.
    """
    keep = profile[:, 1] > 1
    rows = profile[keep]
    mean_value = rows[:, 0] / rows[:, 1]
    return np.column_stack((mean_value, rows[:, 2]))

In [None]:
def PlotSE(ax, profile, SE_profile, color, factor, switch=1, alpha=0.1, min_value=None):
    """Shade a ±(factor * error * switch) band around a profile with ``ax.fill_betweenx``.

    Parameters
    ----------
    ax : object with a ``fill_betweenx`` method (a matplotlib Axes here)
    profile : 2-D array
        Column 0 is the profile value (x); the last column is the vertical coordinate (y).
    SE_profile : 2-D array
        Column 0 holds the error for each row. It must broadcast against ``profile[:, 0]``.
    color, alpha : forwarded to ``fill_betweenx``.
    factor : multiplier on the error (e.g. 1.96 for a standard error).
    switch : extra multiplier on the half-width; 0 collapses the band.
    min_value : if not None, the lower edge is raised to at least this value.

    Example: PlotSE(ax, profile, SE_profile, color='black', factor=factor, min_value=min_value)
    """
    centre = profile[:, 0]
    half_width = factor * SE_profile[:, 0] * switch
    band_low = centre - half_width
    band_high = centre + half_width
    if min_value is not None:
        band_low = np.maximum(band_low, min_value)
    heights = profile[:, -1]
    ax.fill_betweenx(heights, band_low, band_high, color=color, alpha=alpha)

In [None]:
def PlotBudgetProfiles(types, linestyles, vars, var_unit, budget_unit, title_tag, colors, axs, min_value=None):
    """Plot averaged budget-term profiles, with error shading, for several tracked types on one row of 3 axes.

    For every ``type3`` in ``types`` (drawn with the matching entry of ``linestyles``)
    and every cloud type ("all" -> ``axs[0]``, "shallow" -> ``axs[1]``, "deep" -> ``axs[2]``),
    each term in ``vars`` is read from the module global
    ``f"{type3}_{CLOUDTYPE}_profile_array_{VAR}"`` and its ``_SE`` companion. It is then
    averaged with ``averaged_profiles``, drawn in the matching entry of ``colors``, and
    shaded with ``PlotSE`` using the global ``factor``.

    Side effects: (re)defines the module globals ``profile_<var>``, ``profile_<var>_SE`` and
    ``out_<var>`` for every var (kept for any later cells that read them). Also reads the
    globals ``factor``, ``limit_y``, ``all_cloudbase`` and ``MeanLFC``.

    Parameters
    ----------
    types, linestyles : sequences paired element-wise (tracked type -> line style).
    vars : list of budget-term names. Column titles are set only when ``vars[0]`` contains 'w'.
    budget_unit : str used in the x-axis label.
    colors : sequence paired element-wise with ``vars``.
    axs : sequence of 3 Axes (ALL, SHALLOW, DEEP).
    var_unit, title_tag : currently unused; kept for call compatibility.
    min_value : currently unused.
        NOTE(review): it is not forwarded to PlotSE. Forwarding ``min_value=0`` would clip
        the lower band of tendencies that can be negative, so confirm intent before wiring it up.

    Returns
    -------
    first_legend : Legend or None
        Legend on the SHALLOW axis. Only lines of 'solid'-linestyle types are labelled, so
        each var appears once.
    custom_lines : list of Line2D
        One proxy handle per entry of ``types``, using the linestyle it was actually drawn
        with, for a figure-level legend. Display names map 'SBZ' -> 'SBF' and
        'nonSBZ' -> 'nonSBF'; other types keep their own name.
    """
    first_legend = None  # Legend created on the SHALLOW axis

    # Cloud type -> (axis, column title)
    panels = {
        "all": (axs[0], 'ALL'),
        "shallow": (axs[1], 'SHALLOW'),
        "deep": (axs[2], 'DEEP'),
    }

    def plotting(out_var, axis, label, color, linestyle, linewidth=1.25):
        axis.plot(out_var[:, 0], out_var[:, 1], label=label, color=color, linestyle=linestyle, linewidth=linewidth)
        axis.grid(True)

    for type3, linestyle in zip(types, linestyles):
        for cloud_type in ["all", "shallow", "deep"]:
            # Publish the raw profile, its SE and the averaged profile as module globals.
            for var in vars:
                profile_name = f"{type3}_{cloud_type.upper()}_profile_array_{var.upper()}"
                globals()[f"profile_{var}"] = globals()[profile_name]
                globals()[f"profile_{var}_SE"] = globals()[f"{profile_name}_SE"]
                globals()[f"out_{var}"] = averaged_profiles(globals()[f"profile_{var}"])

            axis_rest, panel_title = panels[cloud_type]
            # Only the W-budget row (vars such as 'wb_*') gets column titles.
            if 'w' in vars[0]:
                axis_rest.set_title(panel_title)

            plotted_any = False
            for var, color in zip(vars, colors):
                out_var = globals()[f"out_{var}"]
                # QV scaling by 1000 is already applied when the data is loaded.
                SE_profile = globals()[f"profile_{var}_SE"]
                # NOTE(review): out_var has rows with count <= 1 removed, while SE_profile is
                # used as loaded. This assumes ProfileStandardError yields matching rows; confirm.
                label = var if linestyle == 'solid' else ""
                plotting(out_var, axis_rest, label=label, color=color, linestyle=linestyle)
                PlotSE(axis_rest, out_var.copy(), SE_profile, color=color, factor=factor)
                axis_rest.set_ylabel('z (km)')
                axis_rest.set_xlabel(f'({budget_unit})')
                plotted_any = True

            # Rebuilt after each type, so the final legend covers every labelled line.
            if cloud_type == 'shallow' and plotted_any:
                first_legend = axis_rest.legend(loc='upper left', fontsize=8)

    # Full column (0-20) or sliced (0-7) vertical range.
    y_max = 7 if limit_y else 20
    for ax in axs:
        ax.set_ylim(bottom=0)
        limit_axes_to_y(ax, y_min=0, y_max=y_max)

    # ===== FIX TICKS =====
    SnapLimitsToTicks(axs, dim='x')
    apply_scientific_notation(axs, decimals=2, scientific=True)

    # Reference heights. MeanLFC is divided by 1000 for the km axis (presumably stored in m).
    for axis in axs:
        axis.axhline(all_cloudbase, color='purple', linestyle='dashed', zorder=-100)
        axis.axhline(MeanLFC / 1000, color='green', linestyle='dashed', zorder=-100)

    ##### SECOND LEGEND FOR LINESTYLES #####
    # Built from the (type, linestyle) pairs actually plotted. The old hard-coded list showed
    # SBF as 'dashdot' while callers draw it 'dotted'.
    display_names = {'SBZ': 'SBF', 'nonSBZ': 'nonSBF'}
    custom_lines = [
        Line2D([0], [0], color='black', linestyle=style, linewidth=1, label=display_names.get(name, name))
        for name, style in zip(types, linestyles)
    ]
    # The caller adds this legend once on the whole figure.
    return first_legend, custom_lines


In [None]:
def SavePlot(fig, type1, res, t_res, Np_str, dir2, limit_y=False):
    """Save ``fig`` as a 300-dpi JPEG (tight bounding box) and print its path.

    Destination folder:
    ``dir2/Project_Algorithms/Tracked_Profiles/PLOTS/Tracked_Budget/<res>_<t_res>_<Np_str>/``,
    plus an extra ``SLICE/`` level when ``limit_y == True``. The folder is created if
    missing. The file is named ``<type1>_Budget_Tracked_Profiles_<res>_<t_res>_<Np_str>.jpg``.
    """
    folder_parts = [
        dir2, 'Project_Algorithms', 'Tracked_Profiles', 'PLOTS', 'Tracked_Budget',
        f'{res}_{t_res}_{Np_str}',
    ]
    if limit_y == True:
        folder_parts.append("SLICE")
    target_dir = os.path.join(*folder_parts)
    os.makedirs(target_dir, exist_ok=True)

    save_path = os.path.join(target_dir, f'{type1}_Budget_Tracked_Profiles_{res}_{t_res}_{Np_str}.jpg')
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    print(f"Saved figure to: {save_path}")


In [None]:
# Budget figure: rows = W, QV, TH budgets; columns = ALL, SHALLOW, DEEP cloud types.
fig = plt.figure(figsize=(18, 15))  # width x height, adjust as needed
gs = GridSpec(nrows=3, ncols=3, figure=fig, hspace=0.15, wspace=0.15)

# Axes are created row by row (left to right), so fig.get_axes() lists them in that order.
axs_row1, axs_row2, axs_row3 = (
    [fig.add_subplot(gs[row, col]) for col in range(3)] for row in range(3)
)

# Tracked types and the line style each one is drawn with.
types = ['CL', 'nonCL', 'SBZ']
linestyles = ['solid', 'dashed', 'dotted']
# One colour per budget term (shared by all rows).
budget_colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'grey']

budget_rows = [
    (axs_row1, dict(
        vars=['wb_hadv', 'wb_vadv', 'wb_hidiff', 'wb_vidiff',
              'wb_hturb', 'wb_vturb', 'wb_pgrad', 'wb_buoy'],
        var_unit='m/s',
        budget_unit=r'$m/s^2$',
        title_tag=type1 + ' vs ' + type2 + ' W',
    )),
    (axs_row2, dict(
        vars=['qvb_hadv', 'qvb_vadv', 'qvb_hidiff', 'qvb_vidiff',
              'qvb_hturb', 'qvb_vturb', 'qvb_mp'],
        var_unit='g/kg',
        budget_unit='g/kg/s',
        title_tag=type1 + ' vs ' + type2 + ' QV',
        min_value=0,
    )),
    (axs_row3, dict(
        vars=['ptb_hadv', 'ptb_vadv', 'ptb_hidiff', 'ptb_vidiff',
              'ptb_hturb', 'ptb_vturb', 'ptb_mp'],
        var_unit='K',
        budget_unit='K/s',
        title_tag=type1 + ' vs ' + type2 + ' TH',
    )),
]

# Rows are drawn in order W, QV, TH.
row_results = [
    PlotBudgetProfiles(types=types, linestyles=linestyles, colors=budget_colors, axs=row_axes, **row_kwargs)
    for row_axes, row_kwargs in budget_rows
]
(first_legend1, custom_lines), (first_legend2, _), (first_legend3, _) = row_results

# Add the custom legend for linestyles once on the big figure
fig.legend(
    handles=custom_lines,
    loc='upper center',
    ncol=4,
    fontsize=10,
    title='Types',
    title_fontsize=12,
    bbox_to_anchor=(0.51, 0.95),  # Lower the legend below the top edge
    borderaxespad=0
)

# ===== FIXING TICKS ===== (share axis limits within each row of three panels)
axes = fig.get_axes()
for row_start in range(0, 9, 3):
    MatchAxisLimits(axes[row_start:row_start + 3])

# SAVING FIGURE
SavePlot(fig, 'CLnonCL_SBF', res, t_res, Np_str, dir2, limit_y)

In [None]:
limit_y = True  # switch to the sliced y-range (0-7) and the "SLICE" output sub-folder