In [None]:
####################################
#ENVIRONMENT SETUP

In [None]:
#Importing Libraries
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as ticker
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import ScalarFormatter
import matplotlib.gridspec as gridspec
import xarray as xr

import sys; import os; import time; from datetime import timedelta
import pickle
import h5py

from tqdm import tqdm

from glob import glob

In [None]:
#MAIN DIRECTORIES
def GetDirectories():
    """
    Return the four directories this notebook works with.

    Returns
    -------
    tuple of str
        (project root, project code directory, scratch directory, current working directory).
    """
    projectRoot = '/mnt/lustre/koa/koastore/torri_group/air_directory/Projects/DCI-Project/'
    scratchRoot = '/mnt/lustre/koa/scratch/air673/'
    return (
        projectRoot,
        os.path.join(projectRoot, "Code/CodeFiles/"),
        scratchRoot,
        os.getcwd(),
    )

mainDirectory, mainCodeDirectory, scratchDirectory, codeDirectory = GetDirectories()

In [None]:
def GetPlottingDirectory(plotFileName, plotType):
    """
    Return the full path for a plot file, creating its directory if it does not exist.

    The directory is <mainDirectory>/Code/PLOTTING/<plotType>/<res>_<t_res>_<Nz_str>nz.

    Parameters
    ----------
    plotFileName : str
        Bare file name (no directory).
    plotType : str
        Sub-path below the PLOTTING directory (may contain '/').
    """
    # Plain assignment: the former `plottingDirectory = mainCodeDirectory = ...` only created a
    # dead local that shadowed the global mainCodeDirectory name.
    plottingDirectory = os.path.join(mainDirectory, "Code", "PLOTTING")

    specificPlottingDirectory = os.path.join(
        plottingDirectory, plotType,
        f"{ModelData.res}_{ModelData.t_res}_{ModelData.Nz_str}nz")
    os.makedirs(specificPlottingDirectory, exist_ok=True)

    return os.path.join(specificPlottingDirectory, plotFileName)

def SaveFigure(fig,plotType, fileName):
    """
    Save `fig` as a 300-dpi JPEG named <fileName>_<res>_<t_res>_<Np_str>.jpg in the
    directory returned by GetPlottingDirectory for `plotType`.
    """
    # NOTE(review): the directory name uses Nz_str but the file name uses Np_str — confirm intended.
    plotFileName = f"{fileName}_{ModelData.res}_{ModelData.t_res}_{ModelData.Np_str}.jpg"
    plottingFileName = GetPlottingDirectory(plotFileName, plotType)
    print(f"Saving figure to {plottingFileName}")
    fig.savefig(plottingFileName, dpi=300, bbox_inches='tight')

In [None]:
#IMPORT CLASSES
sys.path.append(os.path.join(mainCodeDirectory,"2_Variable_Calculation"))
from CLASSES_Variable_Calculation import ModelData_Class, SlurmJobArray_Class, DataManager_Class

In [None]:
# data loading class (simulation 1)
ModelData = ModelData_Class(mainDirectory, scratchDirectory, simulationNumber=1)

# data manager class for the "Lagrangian_UpdraftTracking" tracking-algorithm data at this grid setup
DataManager = DataManager_Class(
    mainDirectory,
    scratchDirectory,
    ModelData.res,
    ModelData.t_res,
    ModelData.Nz_str,
    ModelData.Np_str,
    dataType="Tracking_Algorithms",
    dataName="Lagrangian_UpdraftTracking",
    dtype='float32',
    codeSection="Project_Algorithms",
)

In [None]:
#IMPORT FUNCTIONS
sys.path.append(os.path.join(mainCodeDirectory,"2_Variable_Calculation"))
import FUNCTIONS_Variable_Calculation
from FUNCTIONS_Variable_Calculation import *

In [None]:
#IMPORT CLASSES
sys.path.append(os.path.join(mainCodeDirectory,"3_Project_Algorithms","2_Tracking_Algorithms"))
from CLASSES_TrackingAlgorithms import TrackingAlgorithms_DataLoading_Class, SlurmJobArray_Class, Results_InputOutput_Class, TrackedParcel_Loading_Class

In [None]:
import sys
dir2='/mnt/lustre/koa/koastore/torri_group/air_directory/Projects/DCI-Project/'
path=os.path.join(dir2,'Code/CodeFiles/Functions')
sys.path.append(path)

import PlottingFunctions
from PlottingFunctions import * # import PlottingFunctions

# # Get all functions in NumericalFunctions
# import inspect
# functions = [f[0] for f in inspect.getmembers(NumericalFunctions, inspect.isfunction)]
# functions

In [None]:
#############################################
#LOADING DATA

In [None]:
# READING BACK IN SUBSETTED TRACKED PARCEL DATA
# trackedArrays is iterated later as trackedArrays[parcel_type][depth] -> 2-D array
trackedArrays, LevelsDictionary = TrackedParcel_Loading_Class.LoadingSubsetParcelData(
    ModelData, DataManager, Results_InputOutput_Class)

In [None]:
def Get_AvgConvergence(t):
    """
    Load the "avgConvergence" array written by the Eulerian_CLTracking algorithm for time index t.

    The file is read from the "Eulerian_CLTracking" directory that sits next to
    DataManager.outputDataDirectory.
    """
    timeString = ModelData.timeStrings[t]
    outputDataDirectory = os.path.normpath(
        os.path.join(DataManager.outputDataDirectory, "..", "Eulerian_CLTracking"))
    Dictionary = TrackingAlgorithms_DataLoading_Class.LoadData(
        ModelData, DataManager, timeString,
        dataName="Eulerian_CLTracking", outputDataDirectory=outputDataDirectory,
        printstatement=False)
    return Dictionary["avgConvergence"]

def find_SBF_xmaxs():
    """
    Locate the sea-breeze front (SBF) at every time index as the position of maximum avgConvergence.

    Returns
    -------
    list of length ModelData.Ntime
        xmaxs[t] is the first index (along axis 0) where avgConvergence at time t equals its maximum.
        The index is along x if avgConvergence is 1-D along x (TODO confirm).
        The list is aligned with the absolute time index, so callers can use xmaxs[t] directly.
        Time 0 is not loaded and holds a NaN placeholder.

    Before this fix, entry i held time i+1 (the loop skipped t = 0 without a placeholder).
    That made every xmaxs[t] lookup one step late and made the last time index raise IndexError.
    Results cached by LoadorRun were built with that shifted list; delete the cache to recompute Xdiff.
    """
    xmaxs = [np.nan]  # t = 0: no front location is computed
    for t in range(1, ModelData.Ntime):
        avgConvergence = Get_AvgConvergence(t)
        avgConvergence_max = np.max(avgConvergence)
        xmax = np.where(avgConvergence == avgConvergence_max)[0][0]
        xmaxs.append(xmax)
    return xmaxs

xmaxs = find_SBF_xmaxs()

In [None]:
def Get_LagrangianArrays(t, dataType="VARS", dataName="VARS", varNames=("W",)):
    """
    Read selected datasets from one time step's Lagrangian-array HDF5 file.

    The file is
    <DataManager.inputDirectory>/../LagrangianArrays/<res>_<t_res>_<Nz_str>nz/<dataType>/
    <dataName>_<res>_<t_res>_<Nz_str>nz_<timeString>.h5

    Parameters
    ----------
    t : int
        Time index into ModelData.timeStrings.
    dataType, dataName : str
        Sub-directory and file-name prefix.
    varNames : iterable of str
        Dataset names to read. The default is an immutable tuple rather than a shared list.

    Returns
    -------
    dict
        Maps each name to the fully read dataset (``f[name][:]``).
    """
    gridTag = f"{ModelData.res}_{ModelData.t_res}_{ModelData.Nz_str}nz"
    inputDirectory = os.path.join(DataManager.inputDirectory,
                                  "..", "LagrangianArrays", gridTag, dataType)
    timeString = ModelData.timeStrings[t]
    FileName = os.path.join(inputDirectory, f"{dataName}_{gridTag}_{timeString}.h5")

    # every dataset is read into memory before the file is closed
    with h5py.File(FileName, 'r') as f:
        return {key: f[key][:] for key in varNames}

In [None]:
#############################################
#RUNNING FUNCTIONS

In [None]:
# numerical info: grid coordinates (km); y and x are shifted so that each starts at 0
zh = ModelData.zh
yh, xh = (coord - coord[0] for coord in (ModelData.yh, ModelData.xh))

In [None]:
def CollectData(trackedArray):
    """
    Gather position and thermodynamic samples for every (parcel, time) row of a tracked array.

    Parameters
    ----------
    trackedArray : 2-D array
        Column 0 holds parcel indices and column 1 holds time indices. Both are used directly
        as indices, so they are presumably integer-valued (TODO confirm).

    Returns
    -------
    T_List, Z_List, Y_List, X_List, Xdiff_List, QV_List, THETA_v_List : list
        One entry per row, ordered by time index.
        - Z/Y/X are the parcel's zh/yh/xh grid coordinates.
        - Xdiff is X minus the SBF x-coordinate xh[xmaxs[t]]; it is NaN when xmaxs[t] is a
          NaN placeholder (no front estimate for that time).
        - QV/THETA_v are the parcel's values at that time.
    """
    # getting parcel index and time
    ps = trackedArray[:, 0]
    ts = trackedArray[:, 1]

    # sort by time so each time step's arrays are loaded only once
    sort_idx = np.argsort(ts)
    ts_sorted = ts[sort_idx]
    ps_sorted = ps[sort_idx]

    T_List = []; Z_List = []; Y_List = []; X_List = []
    Xdiff_List = []
    QV_List = []
    THETA_v_List = []

    # time cache (avoids reloading arrays for consecutive rows at the same time)
    previous_t = None

    for t, p in tqdm(
        zip(ts_sorted, ps_sorted),
        total=len(ts_sorted),
        desc="Processing timesteps"):

        # X and VARS loading, once per time step
        if t != previous_t:
            previous_t = t
            timeString = ModelData.timeStrings[t]

            Z_t = CallLagrangianArray(ModelData, DataManager, timeString, 'Z')
            Y_t = CallLagrangianArray(ModelData, DataManager, timeString, 'Y')
            X_t = CallLagrangianArray(ModelData, DataManager, timeString, 'X')

            # a NaN entry means "no front location for this time": propagate NaN
            # rather than using it as an index into xh
            xmax_idx = xmaxs[t]
            xmaxs_t = np.nan if np.isnan(xmax_idx) else xh[xmax_idx]

            VARS = Get_LagrangianArrays(t, varNames=["QV", "THETA_v"])
            QV_t = VARS["QV"]
            THETA_v_t = VARS["THETA_v"]

        # DISTANCE METRICS: grid numbers -> grid coordinates (km)
        Z_tp = zh[Z_t[p]]
        Y_tp = yh[Y_t[p]]
        X_tp = xh[X_t[p]]

        # distance from the sea-breeze front
        Xdiff = X_tp - xmaxs_t

        T_List.append(t)
        Z_List.append(Z_tp)
        Y_List.append(Y_tp)
        X_List.append(X_tp)
        Xdiff_List.append(Xdiff)

        # VARIABLES
        QV_List.append(QV_t[p])
        THETA_v_List.append(THETA_v_t[p])

    return T_List,Z_List,Y_List,X_List,Xdiff_List, QV_List,THETA_v_List

In [None]:
def RunAllParcelTypes():
    """
    Run CollectData on every tracked array in the global `trackedArrays`.

    Returns
    -------
    dict
        results[outer_key][inner_key] is a dict holding the seven lists returned by CollectData,
        keyed "T_List", "Z_List", "Y_List", "X_List", "Xdiff_List", "QV_List", "THETA_v_List".
        None or empty tracked arrays are reported and left out (their outer key still exists).
    """
    listNames = ("T_List", "Z_List", "Y_List", "X_List",
                 "Xdiff_List", "QV_List", "THETA_v_List")
    results = {}

    for outer_key, inner_dict in trackedArrays.items():          # e.g. "CL"
        perType = results[outer_key] = {}
        for inner_key, trackedArray in inner_dict.items():       # e.g. "DEEP"
            print(f"\nRunning CollectData for {outer_key} - {inner_key}")

            isEmpty = trackedArray is None or len(trackedArray) == 0
            if isEmpty:
                print(f"  Skipping {outer_key}-{inner_key}: empty array")
                continue

            perType[inner_key] = dict(zip(listNames, CollectData(trackedArray)))
    return results

In [None]:
def LoadorRun():
    """
    Load the tracked-parcel results from a pickle cache if it exists; otherwise run
    RunAllParcelTypes() and cache the output.

    The cache file is <codeDirectory>/Tracked_Histogram_Output_<res>_<t_res>_<Nzh>nz.pkl.
    Its name encodes only res, t_res and Nzh. Delete the file to force a recompute after
    changing the data-collection code or switching to an input that shares those three values.

    Returns
    -------
    dict
        The nested results dictionary produced by RunAllParcelTypes.
    """
    fileName = f"Tracked_Histogram_Output_{ModelData.res}_{ModelData.t_res}_{ModelData.Nzh}nz.pkl"
    filePath = os.path.join(codeDirectory, fileName)

    if os.path.exists(filePath):
        # NOTE: pickle.load can execute arbitrary code; only load cache files this notebook wrote.
        with open(filePath, "rb") as f:
            results = pickle.load(f)
        print(f"Loaded results from {filePath}")
        return results

    print(f"No pickle file found, running RunAllParcelTypes()...")
    results = RunAllParcelTypes()

    # Dump to a temporary file and atomically move it into place, so an interrupted dump
    # cannot leave a truncated cache that the branch above would later try to load.
    tmpPath = filePath + ".tmp"
    with open(tmpPath, "wb") as f:
        pickle.dump(results, f)
    os.replace(tmpPath, filePath)
    print(f"Saved results to {filePath}")

    return results

In [None]:
#############################################
#RUNNING

In [None]:
# Load the per-parcel results from the pickle cache, or compute and cache them if no cache file exists
results = LoadorRun()

In [None]:
#############################################
#PLOTTING FUNCTIONS

In [None]:
# text styling for the histogram figures
plt.rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": ["DejaVu Sans", "Helvetica", "Arial"],
    "axes.titlesize": 8,
    "axes.titleweight": "normal",  # no bold
    "axes.labelsize": 7,           # small axis labels
    "xtick.labelsize": 10,         # tick labels (larger than the axis labels)
    "ytick.labelsize": 10,
})

In [None]:
def PlotHistogram(axis, dataList, xlabel, bins=50, orientation="vertical",
                  color='steelblue', title=None,
                  plotKDE=True):
    """
    Draw a histogram of `dataList` on `axis`, optionally overlaid with a count-scaled Gaussian KDE.

    Parameters
    ----------
    axis : matplotlib Axes
    dataList : array-like of numbers
        Non-finite values (NaN/inf) are dropped before plotting.
    xlabel : str
        Label for the data axis. This is y when orientation == "horizontal".
    bins : int or sequence
        Passed to ``axis.hist``.
    orientation : {"vertical", "horizontal"}
        Passed to ``axis.hist``. "horizontal" puts counts on x.
    color : str
        Histogram face colour.
    title : str, optional
        Axes title (fontsize 10).
    plotKDE : bool
        Overlay the KDE. It is skipped when fewer than two samples or zero spread would make the
        KDE covariance singular.

    Labels, title and grid are applied whether or not the KDE is drawn. The KDE line carries the
    label "_ignore_snap_" (kept as in the original).
    """
    data = np.asarray(dataList, dtype=float)
    # axis.hist cannot auto-range NaN/inf and gaussian_kde rejects them
    data = data[np.isfinite(data)]

    counts, bin_edges, _ = axis.hist(
        data,
        bins=bins,
        color=color,
        edgecolor='black',
        alpha=0.7,
        orientation=orientation,
    )

    horizontal = orientation == "horizontal"

    if plotKDE and data.size >= 2 and np.ptp(data) > 0:
        kde = gaussian_kde(data)
        x_vals = np.linspace(bin_edges[0], bin_edges[-1], 400)
        # scale the density to counts; assumes equal-width bins (true when `bins` is an int)
        bin_width = bin_edges[1] - bin_edges[0]
        kde_scaled = kde(x_vals) * data.size * bin_width
        xy = (kde_scaled, x_vals) if horizontal else (x_vals, kde_scaled)
        axis.plot(*xy, color='blue', linewidth=1.8, zorder=10, label="_ignore_snap_")

    if horizontal:
        axis.set_xlabel("Count")
        axis.set_ylabel(xlabel)
    else:
        axis.set_xlabel(xlabel)
        axis.set_ylabel("Count")

    if title:
        axis.set_title(title, fontsize=10, pad=10)

    axis.grid(True, linestyle='--', alpha=0.4)


In [None]:
#############################################
#SINGLE PLOTTING FUNCTIONS

In [None]:
# def PlotDistancesFunction(parcel_type):
#     # choose which outer key to plot
#     ptype = parcel_type
#     depth_types = ["ALL", "SHALLOW", "DEEP"]
    
#     # set up figure (2 rows × 3 columns)
#     fig = plt.figure(figsize=(12, 8))
#     gs  = gridspec.GridSpec(2, len(depth_types), figure=fig,
#                             wspace=0.3, hspace=0.35)
    
#     # loop through depth types
#     for j, depth in enumerate(depth_types):
#         # first row: X_List
#         ax_top = fig.add_subplot(gs[0, j])
#         if ptype in results and depth in results[ptype]:
#             data_x = results[ptype][depth]["X_List"]
#             data_x_mean = np.mean(data_x)
#             PlotHistogram(ax_top, data_x,
#                           xlabel="X distance from left side (km)",
#                           title=f"{ptype} – {depth}\n" 
#                           + r"$\mu$ = %.2f km" % data_x_mean)
    
#         # second row: Xdiff_List
#         ax_bottom = fig.add_subplot(gs[1, j])
#         if ptype in results and depth in results[ptype]:
#             data_xdiff = results[ptype][depth]["Xdiff_List"]
#             data_xdiff_mean = np.mean(data_xdiff)
#             PlotHistogram(ax_bottom, data_xdiff,
#                           xlabel="X distance from SBF (km)",
#                           title=f"{ptype} – {depth}\n"
#                                 + r"$\mu$ = %.2f km" % data_xdiff_mean)
#         else:
#             continue
 
#     fig.subplots_adjust(left=0.07, right=0.97,   
#                         bottom=0.08, top=0.90,
#                         wspace=0.35, hspace=0.35)
#     return fig

In [None]:
# def PlotVariablesFunction(parcel_type):
#     # choose which outer key to plot
#     ptype = parcel_type
#     depth_types = ["ALL", "SHALLOW", "DEEP"]
    
#     # set up figure (2 rows × 3 columns)
#     fig = plt.figure(figsize=(12, 8))
#     gs  = gridspec.GridSpec(2, len(depth_types), figure=fig,
#                             wspace=0.3, hspace=0.35)
    
#     # loop through depth types
#     for j, depth in enumerate(depth_types):
#         # first row: QV
#         ax_top = fig.add_subplot(gs[0, j])
#         if ptype in results and depth in results[ptype]:
#             data_x = results[ptype][depth]["QV_List"]
#             data_x = np.array(data_x)*1e3
#             data_x_mean = np.mean(data_x)
#             PlotHistogram(ax_top, data_x,
#                           xlabel="qv (g/kg)",
#                           title=f"{ptype} – {depth}\n"
#                                 + r"$\mu$ = %.2f g/kg" % data_x_mean)
    
#         # second row: TH
#         ax_bottom = fig.add_subplot(gs[1, j])
#         if ptype in results and depth in results[ptype]:
#             data_x = results[ptype][depth]["THETA_v_List"]
#             data_x_mean = np.mean(data_x)
#             PlotHistogram(ax_bottom, data_x,
#                           xlabel="th_v (K)",
#                           title=f"{ptype} – {depth}\n"
#                                 + r"$\mu$ = %.2f K" % data_x_mean)
#         else:
#             continue
 
#     fig.subplots_adjust(left=0.07, right=0.97,   
#                         bottom=0.08, top=0.90,
#                         wspace=0.35, hspace=0.35)
#     return fig

In [None]:
#############################################
#PLOTTING

In [None]:
# parcel_types = ["CL", "nonCL", "SBF", "nonSBF"]
# for parcel_type in parcel_types:
#     fig = PlotDistancesFunction(parcel_type)

#     #saving
#     fileName=f"Tracked_Histograms_Distances_{parcel_type}" 
#     SaveFigure(fig,plotType="Project_Algorithms/Tracking_Algorithms/Tracked_Histograms",fileName=fileName)

In [None]:
# parcel_types = ["CL", "nonCL", "SBF"]
# for parcel_type in parcel_types:
#     fig = PlotVariablesFunction(parcel_type)

#     #saving
#     fileName=f"Tracked_Histograms_Variables_{parcel_type}" 
#     SaveFigure(fig,plotType="Project_Algorithms/Tracking_Algorithms/Tracked_Histograms",fileName=fileName)

In [None]:
#############################################
#COMBINED PLOTTING FUNCTIONS

In [None]:
def AddCrossLines(fig, outer_gs, pad=0.02):
    """
    Draw a centred cross (one vertical and one horizontal black line) in figure coordinates,
    visually separating the four blocks of a 2×2 layout.

    Parameters
    ----------
    fig : matplotlib Figure
    outer_gs : GridSpec
        Unused. The lines always sit at the figure centre (0.5), not at the grid's actual block
        boundaries. Kept so existing calls keep working.
    pad : float
        Gap, in figure-fraction units (0–0.5), between each line end and the figure edge.
    """
    style = dict(transform=fig.transFigure, color="black", linewidth=1.2)
    # fig.add_artist registers the line with the figure (sets its figure, marks it stale);
    # appending to fig.lines directly bypasses that bookkeeping.
    fig.add_artist(plt.Line2D([0.5, 0.5], [pad, 1 - pad], **style))  # vertical
    fig.add_artist(plt.Line2D([pad, 1 - pad], [0.5, 0.5], **style))  # horizontal

In [None]:
def PlotAllHistograms_Heights(parcel_types, results):
    """
    Build one figure of parcel-height histograms arranged as a 2×2 grid of parcel-type blocks.

    Blocks are filled row-major from `parcel_types`. With ["CL", "nonCL", "SBF", "nonSBF"]:
      [CL,  nonCL ]
      [SBF, nonSBF]
    Each block is a 1×3 row of horizontal histograms (ALL, SHALLOW, DEEP) of Z_List × 1000 (m).
    The height axis is limited to 0–1000 m.

    Parameters
    ----------
    parcel_types : sequence of str
        Outer keys of `results`, in block order. Only the first four are used.
    results : dict
        results[parcel_type][depth]["Z_List"]. A missing entry leaves its panel empty.

    Returns
    -------
    matplotlib Figure
    """
    depth_types = ["ALL", "SHALLOW", "DEEP"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

    fig = plt.figure(figsize=(14, 8))
    outer_gs = gridspec.GridSpec(2, 2, figure=fig, wspace=0.2, hspace=0.45)

    for (r, c), ptype in zip(positions, parcel_types):
        # a single row of panels per block, so only wspace matters here
        inner_gs = gridspec.GridSpecFromSubplotSpec(
            1, len(depth_types), subplot_spec=outer_gs[r, c], wspace=0.4)
        depthResults = results.get(ptype, {})

        for j, depth in enumerate(depth_types):
            ax = fig.add_subplot(inner_gs[0, j])
            if depth not in depthResults:
                continue

            data_z = np.array(depthResults[depth]["Z_List"]) * 1000  # km -> m
            PlotHistogram(
                ax, data_z,
                orientation='horizontal',
                xlabel="Z distance (m)",
                title=f"{ptype} – {depth}\n" + r"$\mu$ = %.2f m" % np.nanmean(data_z),
                plotKDE=True,
            )
            ax.set_ylim(bottom=0, top=1000)

    # overall margins so the blocks do not overlap
    fig.subplots_adjust(left=0.06, right=0.97, bottom=0.06, top=0.94)

    AddCrossLines(fig, outer_gs)
    return fig

In [None]:
def PlotAllHistograms_Distances(parcel_types, results):
    """
    Build one figure of parcel x-distance histograms arranged as a 2×2 grid of parcel-type blocks.

    Blocks are filled row-major from `parcel_types`. With ["CL", "nonCL", "SBF", "nonSBF"]:
      [CL,  nonCL ]
      [SBF, nonSBF]
    Each block is a 2×3 grid with columns ALL, SHALLOW, DEEP.
    - Top row: X_List (km from the left edge), limited to the domain. A green line marks 1/4 of
      the domain length.
    - Bottom row: Xdiff_List (km from the SBF). A red line marks 0. The "SBF" block uses ±10 km
      limits; other blocks use ±half the domain.
    Means in the titles ignore NaN samples.

    Parameters
    ----------
    parcel_types : sequence of str
        Outer keys of `results`, in block order. Only the first four are used.
    results : dict
        results[parcel_type][depth] with "X_List" and "Xdiff_List". A missing entry leaves its
        panels empty.

    Returns
    -------
    matplotlib Figure
    """
    depth_types = ["ALL", "SHALLOW", "DEEP"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

    # domain length (km) from the cell-face coordinates; loop-invariant
    x_domain = (ModelData.xf - ModelData.xf[0])[-1]
    halflength = x_domain / 2

    fig = plt.figure(figsize=(14, 10))
    outer_gs = gridspec.GridSpec(2, 2, figure=fig, wspace=0.2, hspace=0.45)

    for (r, c), ptype in zip(positions, parcel_types):
        inner_gs = gridspec.GridSpecFromSubplotSpec(
            2, len(depth_types), subplot_spec=outer_gs[r, c],
            wspace=0.4, hspace=0.6  # extra vertical room between the two rows
        )
        depthResults = results.get(ptype, {})

        for j, depth in enumerate(depth_types):
            has_data = depth in depthResults

            # --- TOP ROW: X_List ---
            ax_top = fig.add_subplot(inner_gs[0, j])
            if has_data:
                data_x = np.asarray(depthResults[depth]["X_List"], dtype=float)
                PlotHistogram(
                    ax_top, data_x,
                    xlabel="X distance (km)",
                    title=f"{ptype} – {depth}\n" + r"$\mu$ = %.2f km" % np.nanmean(data_x)
                )
                # marker at 1/4 of the domain (presumably the coastline — TODO confirm)
                ax_top.axvline(x_domain * 1/4, color='green')
                ax_top.set_xlim(left=0, right=x_domain)

            # --- BOTTOM ROW: Xdiff_List ---
            ax_bottom = fig.add_subplot(inner_gs[1, j])
            if has_data:
                data_xdiff = np.asarray(depthResults[depth]["Xdiff_List"], dtype=float)
                PlotHistogram(
                    ax_bottom, data_xdiff,
                    xlabel="X distance from SBF (km)",
                    title=r"$\mu$ = %.2f km" % np.nanmean(data_xdiff)
                )
                ax_bottom.axvline(0, color='red')
                # SBF parcels get a ±10 km zoom; every other type spans ±half the domain
                lim = 10.0 if ptype == "SBF" else halflength
                ax_bottom.set_xlim(left=-lim, right=lim)

    # overall margins so the blocks do not overlap
    fig.subplots_adjust(left=0.06, right=0.97, bottom=0.06, top=0.94)

    AddCrossLines(fig, outer_gs)
    return fig

In [None]:
def PlotAllHistograms_Variables(parcel_types, results):
    """
    Build one figure of q_v and theta_v histograms arranged as a 2×2 grid of parcel-type blocks.

    Blocks are filled row-major from `parcel_types`. With ["CL", "nonCL", "SBF", "nonSBF"]:
      [CL,  nonCL ]
      [SBF, nonSBF]
    Each block is a 2×3 grid with columns ALL, SHALLOW, DEEP.
    - Top row: q_v in g/kg (QV_List × 1e3).
    - Bottom row: theta_v in K (THETA_v_List).
    Means in the titles ignore NaN samples.

    For each depth, axes are created top then bottom (even when a panel stays empty).
    So fig.axes[::2] are the q_v panels and fig.axes[1::2] are the theta_v panels.

    Parameters
    ----------
    parcel_types : sequence of str
        Outer keys of `results`, in block order. Only the first four are used.
    results : dict
        results[parcel_type][depth] with "QV_List" and "THETA_v_List".

    Returns
    -------
    matplotlib Figure
    """
    depth_types = ["ALL", "SHALLOW", "DEEP"]
    positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

    fig = plt.figure(figsize=(14, 10))
    outer_gs = gridspec.GridSpec(2, 2, figure=fig, wspace=0.2, hspace=0.45)

    for (r, c), ptype in zip(positions, parcel_types):
        inner_gs = gridspec.GridSpecFromSubplotSpec(
            2, len(depth_types), subplot_spec=outer_gs[r, c],
            wspace=0.4, hspace=0.55
        )
        depthResults = results.get(ptype, {})

        for j, depth in enumerate(depth_types):
            has_data = depth in depthResults

            # --- TOP ROW: QV (g/kg) ---
            ax_top = fig.add_subplot(inner_gs[0, j])
            if has_data:
                data_qv = np.array(depthResults[depth]["QV_List"], dtype=float) * 1e3
                PlotHistogram(ax_top, data_qv,
                              xlabel=r"$q_v$ (g/kg)",
                              title=f"{ptype} – {depth}\n"
                              + r"$\mu$ = %.2f g kg$^{-1}$" % np.nanmean(data_qv))

            # --- BOTTOM ROW: THv (K) ---
            ax_bottom = fig.add_subplot(inner_gs[1, j])
            if has_data:
                data_th = np.array(depthResults[depth]["THETA_v_List"], dtype=float)
                PlotHistogram(ax_bottom, data_th,
                              xlabel=r"$\theta_v$ (K)",
                              title=r"$\mu$ = %.2f K" % np.nanmean(data_th))

    # overall margins so the blocks do not overlap
    fig.subplots_adjust(left=0.06, right=0.97, bottom=0.06, top=0.94)

    AddCrossLines(fig, outer_gs)
    return fig

In [None]:
#############################################
#COMBINED PLOTTING

In [None]:
parcel_types = ["CL", "nonCL", "SBF", "nonSBF"]
fig = PlotAllHistograms_Heights(parcel_types, results)

# the height histograms are horizontal, so the count axis is x
axes = fig.get_axes()
SnapLimitsToTicks(axes, dim='x')

SaveFigure(fig,
           plotType="Project_Algorithms/Tracking_Algorithms/Tracked_Histograms",
           fileName="Tracked_Histograms_Heights")

In [None]:
parcel_types = ["CL", "nonCL", "SBF", "nonSBF"]
fig = PlotAllHistograms_Distances(parcel_types, results)

# vertical histograms: distance along x, counts along y
axes = fig.get_axes()
SetEvenTicks(axes, dim='x', n_ticks=4, decimals=0)
SnapLimitsToTicks(axes, dim='y')

SaveFigure(fig,
           plotType="Project_Algorithms/Tracking_Algorithms/Tracked_Histograms",
           fileName="Tracked_Histograms_Distances")

In [None]:
parcel_types = ["CL", "nonCL", "SBF", "nonSBF"]
fig = PlotAllHistograms_Variables(parcel_types, results)

# axes alternate q_v (top) / theta_v (bottom) panel per depth: match x-limits within each variable
axes = fig.get_axes()
qv_axes, thv_axes = axes[::2], axes[1::2]
MatchAxisLimits(qv_axes, dim='x')
MatchAxisLimits(thv_axes, dim='x')
SnapLimitsToTicks(axes, dim='y')

SaveFigure(fig,
           plotType="Project_Algorithms/Tracking_Algorithms/Tracked_Histograms",
           fileName="Tracked_Histograms_Variables")

In [None]:
#############################################
#2D FUNCTIONS

In [None]:
#(1) make a 2d (y,x) binned average-field of q_v/th_v at parcel initial time

In [None]:
#############################################
#CALCULATING FUNCTIONS

In [None]:
def SumArray(Dictionary):
    """
    Accumulate parcel q_v and theta_v onto the model (y, x) grid.

    Each sample goes to the cell whose yh / xh coordinate equals its Y_List / X_List value
    exactly. In CollectData those values were taken from yh / xh themselves. Samples with no
    matching coordinate are ignored.

    Parameters
    ----------
    Dictionary : dict
        Must hold equal-length 'Y_List', 'X_List', 'QV_List' and 'THETA_v_List'.

    Returns
    -------
    qv_array, th_v_array, count : ndarray, shape (ModelData.Nyh, ModelData.Nxh)
        Per-cell sums of q_v and theta_v and the number of samples. Divide with TakeAverage.
    """
    shape = (ModelData.Nyh, ModelData.Nxh)
    qv_array = np.zeros(shape)
    th_v_array = np.zeros(shape)
    count = np.zeros(shape)

    # coordinate value -> grid index, replacing a full np.where scan of yh/xh per sample
    y_lookup = {value: idx for idx, value in enumerate(yh)}
    x_lookup = {value: idx for idx, value in enumerate(xh)}
    y_idx = np.array([y_lookup.get(v, -1) for v in Dictionary['Y_List']], dtype=int)
    x_idx = np.array([x_lookup.get(v, -1) for v in Dictionary['X_List']], dtype=int)
    keep = (y_idx >= 0) & (x_idx >= 0)
    cells = (y_idx[keep], x_idx[keep])

    # np.add.at accumulates correctly when several samples fall in the same cell
    np.add.at(qv_array, cells, np.asarray(Dictionary['QV_List'], dtype=float)[keep])
    np.add.at(th_v_array, cells, np.asarray(Dictionary['THETA_v_List'], dtype=float)[keep])
    np.add.at(count, cells, 1)
    return qv_array, th_v_array, count

In [None]:
def SumArray_SBFLocation(Dictionary, x_edges=None):
    """
    Accumulate parcel q_v and theta_v onto a (y, SBF-relative x) grid.

    Parameters
    ----------
    Dictionary : dict
        Must hold equal-length 'Y_List', 'Xdiff_List' (km from the SBF), 'QV_List' and
        'THETA_v_List'.
    x_edges : 1-D array, optional
        Increasing SBF-relative bin edges (km). Defaults to
        np.linspace(-256, 256, ModelData.Nxh + 1), the original hard-coded binning.

    Returns
    -------
    qv_array, th_v_array, count : ndarray, shape (ModelData.Nyh, len(x_edges) - 1)
        Per-cell sums and sample counts. Divide with TakeAverage.

    Some samples are skipped instead of being clipped into the first/last bin, where they would
    bias the edge averages:
    - Xdiff that is non-finite or outside [x_edges[0], x_edges[-1]];
    - Y with no exact match in yh.
    A value equal to the last edge falls in the last bin (closed right edge, as in np.histogram).
    """
    if x_edges is None:
        x_edges = np.linspace(-256, 256, ModelData.Nxh + 1)
    x_edges = np.asarray(x_edges, dtype=float)

    Ny = ModelData.Nyh
    Nx = len(x_edges) - 1  # number of bins

    qv_array = np.zeros((Ny, Nx))
    th_v_array = np.zeros((Ny, Nx))
    count = np.zeros((Ny, Nx))

    # SBF-relative km -> bin index; NaN and out-of-range values land at -1 or Nx
    xdiff = np.asarray(Dictionary['Xdiff_List'], dtype=float)
    x_bin = np.digitize(xdiff, x_edges) - 1
    x_bin[xdiff == x_edges[-1]] = Nx - 1
    x_ok = np.isfinite(xdiff) & (x_bin >= 0) & (x_bin < Nx)

    # y coordinate value -> grid index (values were taken from yh in CollectData)
    y_lookup = {value: idx for idx, value in enumerate(yh)}
    y_idx = np.array([y_lookup.get(v, -1) for v in Dictionary['Y_List']], dtype=int)

    keep = x_ok & (y_idx >= 0)
    cells = (y_idx[keep], x_bin[keep])

    np.add.at(qv_array, cells, np.asarray(Dictionary['QV_List'], dtype=float)[keep])
    np.add.at(th_v_array, cells, np.asarray(Dictionary['THETA_v_List'], dtype=float)[keep])
    np.add.at(count, cells, 1)

    return qv_array, th_v_array, count


In [None]:
def Run2DAverageFields(results):
    """
    Bin every parcel type / depth class in `results` with SumArray and SumArray_SBFLocation.

    Returns
    -------
    dict
        out[parcel_type][depth] holds the sums and counts (not yet averaged):
        "qv_array", "th_v_array", "count" on the (y, x) grid, and
        "qv_SBFLocation", "thv_SBFLocation", "count_SBFLocation" on the SBF-relative grid.
    """
    outputDictionary = {}

    for parcelType, depthDict in tqdm(results.items(), desc="Parcel types"):
        perDepth = {}
        outputDictionary[parcelType] = perDepth

        for depth, collected in depthDict.items():
            qv_sum, thv_sum, n = SumArray(collected)
            qv_sbf, thv_sbf, n_sbf = SumArray_SBFLocation(collected)

            perDepth[depth] = {
                "qv_array": qv_sum,
                "th_v_array": thv_sum,
                "count": n,

                "qv_SBFLocation": qv_sbf,
                "thv_SBFLocation": thv_sbf,
                "count_SBFLocation": n_sbf,
            }
    return outputDictionary

In [None]:
def TakeAverage(array,count):
    """
    Element-wise average `array / count` as a float array, with NaN wherever count == 0.
    """
    averaged = np.full_like(array, np.nan, dtype=float)
    np.divide(array, count, out=averaged, where=(count != 0))
    return averaged

In [None]:
#############################################
#CALCULATING 

In [None]:
# Bin every parcel type / depth class onto (y, x) and (y, SBF-relative x) grids (sums and counts; averaged later with TakeAverage)
outputDictionary = Run2DAverageFields(results)

In [None]:
##################
#PLOTTING FUNCTIONS

In [None]:
# uniform 10-pt text for the 2-D field plots
plt.rcParams.update({
    "axes.titlesize": 10,
    "axes.labelsize": 10,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "figure.titlesize": 10,
})

In [None]:
def MakeAxis_Single():
    """Create the 2×2, 10×8-inch figure used by the single-parcel-type plots; returns (fig, axes)."""
    return plt.subplots(2, 2, figsize=(10, 8))

In [None]:
def AddVLines(fig, ocean_percent=1/4, domain_length=512):
    """
    Draw a dashed black vertical line at x = domain_length * ocean_percent on every axes of `fig`.

    Parameters
    ----------
    fig : matplotlib Figure
        Every axes returned by fig.get_axes() receives the line.
    ocean_percent : float
        Fraction of the domain at which the line is drawn (default 1/4, the original value).
    domain_length : float
        Domain length in x-axis units (default 512, the original hard-coded value).
        NOTE(review): PlotAllHistograms_Distances derives the domain length from ModelData.xf
        instead. Consider passing (ModelData.xf - ModelData.xf[0])[-1] here.
    """
    x_line = domain_length * ocean_percent
    for ax in fig.get_axes():
        ax.axvline(x_line, color='black', linestyle='--', linewidth=1.0)

In [None]:
def PlotSingle_XLocation(parcel_type='CL',parcel_depth='ALL'):
    """
    2×2 figure of binned-average fields for one parcel type / depth class on the model (y, x) grid.

    Left column: q_v (g/kg). Right column: theta_v (K).
    Top row: the averaged (y, x) field. Bottom row: its NaN-ignoring mean over y as a function of x.
    AddVLines then marks every axes.
    """
    fig, axes = MakeAxis_Single()

    entry = outputDictionary[parcel_type][parcel_depth]
    n = entry['count']
    qv_field = TakeAverage(entry['qv_array'], n)
    th_v_field = TakeAverage(entry['th_v_array'], n)

    panels = ((0, qv_field, 1e3, 'qv (g/kg)'),
              (1, th_v_field, 1, 'th_v (K)'))
    for col, field, multiplier, profileLabel in panels:
        mapAx = axes[0, col]
        mesh = mapAx.pcolormesh(xh, yh, field * multiplier)
        fig.colorbar(mesh, ax=mapAx)
        mapAx.set_ylabel('y (km)')
        mapAx.set_xlabel('x (km)')
        # NOTE(review): Nxh/Nyh are grid-point counts used as limits on km axes;
        # they match the domain size only at 1 km grid spacing.
        mapAx.set_xlim(0,ModelData.Nxh)
        mapAx.set_ylim(0,ModelData.Nyh)

        lineAx = axes[1, col]
        lineAx.plot(xh, np.nanmean(field, axis=0) * multiplier)
        lineAx.set_ylabel(profileLabel)
        lineAx.set_xlabel('x (km)')
        lineAx.set_xlim(0,ModelData.Nxh)

    AddVLines(fig)

    return fig

In [None]:
def PlotSingle_SBFLocation(parcel_type='CL', parcel_depth='ALL'):
    """
    2×2 figure of binned-average fields for one parcel type / depth class, with x measured
    relative to the sea-breeze front (SBF).

    Left column: q_v (g/kg). Right column: theta_v (K).
    Top row: the averaged (y, SBF-relative x) field. Bottom row: its NaN-ignoring mean over y.
    A dashed line at x = 0 is drawn on every axes.

    NOTE(review): the x positions come from GetBins(), while SumArray_SBFLocation builds its own
    bin edges. Confirm the two binnings agree.
    """
    fig, axes = MakeAxis_Single()

    entry = outputDictionary[parcel_type][parcel_depth]
    n = entry['count_SBFLocation']
    qv_field = TakeAverage(entry['qv_SBFLocation'], n)
    th_v_field = TakeAverage(entry['thv_SBFLocation'], n)

    _, x_centers_SBF = GetBins()

    # (column, averaged field, display scale or None, profile y-label)
    panels = ((0, qv_field, 1e3, "qv (g/kg)"),
              (1, th_v_field, None, "th_v (K)"))
    for col, field, scale, profileLabel in panels:
        mapAx = axes[0, col]
        shown = field if scale is None else field * scale
        mesh = mapAx.pcolormesh(x_centers_SBF, yh, shown, shading='auto')
        fig.colorbar(mesh, ax=mapAx)
        mapAx.set_ylabel("y (km)")
        mapAx.set_xlabel("SBF-relative x (km)")
        mapAx.set_ylim(0, ModelData.Nyh)

        profile = np.nanmean(field, axis=0)
        lineAx = axes[1, col]
        lineAx.plot(x_centers_SBF, profile if scale is None else profile * scale)
        lineAx.set_ylabel(profileLabel)
        lineAx.set_xlabel("SBF-relative x (km)")

    # SBF position (x = 0) on every axes
    for a in fig.get_axes():
        a.axvline(0, color='k', linestyle='--', linewidth=1)

    fig.tight_layout()
    return fig


In [None]:
def MakeAxis_All():
    """Create the 4×2, 10×8-inch figure used by the all-parcel-type profile plots; returns (fig, axes)."""
    return plt.subplots(4, 2, figsize=(10, 8))

In [None]:
def PlotAll_XLocation(outputDictionary):
    """
    4×2 figure of y-averaged binned fields versus model x for every parcel type.

    Rows are CL, nonCL, SBF, nonSBF. The left column shows q_v (g/kg) and the right column
    theta_v (K). Each panel overlays the ALL / SHALLOW / DEEP depth classes.

    Parameters
    ----------
    outputDictionary : dict
        outputDictionary[ptype][depth] with "qv_array", "th_v_array" and "count", as built by
        Run2DAverageFields. Missing parcel types or depth classes are skipped instead of raising
        KeyError. RunAllParcelTypes leaves out empty tracked arrays, so entries can be missing.

    Returns
    -------
    matplotlib Figure
    """
    fig, axes = MakeAxis_All()

    parcel_order = ["CL", "nonCL", "SBF", "nonSBF"]
    depth_order = ["ALL", "SHALLOW", "DEEP"]
    colors = {"ALL": "black", "SHALLOW": "green", "DEEP": "blue"}

    for row, ptype in enumerate(parcel_order):
        ax_qv, ax_thv = axes[row, 0], axes[row, 1]   # column 0: qv, column 1: th_v
        depthFields = outputDictionary.get(ptype, {})

        for depth in depth_order:
            if depth not in depthFields:
                continue
            fields = depthFields[depth]
            count = fields["count"]

            # average the 2-D sums, then take the NaN-ignoring mean over y
            qv_profile = np.nanmean(TakeAverage(fields["qv_array"], count), axis=0)
            thv_profile = np.nanmean(TakeAverage(fields["th_v_array"], count), axis=0)

            ax_qv.plot(xh, qv_profile * 1e3, label=depth, color=colors[depth])
            ax_thv.plot(xh, thv_profile, label=depth, color=colors[depth])

        ax_qv.set_ylabel(f"{ptype}\nqv (g/kg)")
        ax_thv.set_ylabel(f"{ptype}\nth_v (K)")
        for ax in (ax_qv, ax_thv):
            # NOTE(review): Nxh is a grid-point count used as a km limit; equal to the
            # domain length only at 1 km grid spacing.
            ax.set_xlim(0, ModelData.Nxh)
            ax.grid(alpha=0.3)

        # legend only on the first row
        if row == 0:
            ax_qv.legend(title="Depth", fontsize=8)
            ax_thv.legend(title="Depth", fontsize=8)

    axes[-1, 0].set_xlabel("x (km)")
    axes[-1, 1].set_xlabel("x (km)")

    AddVLines(fig)

    fig.tight_layout()
    return fig


In [None]:
def PlotAll_SBFLocation(outputDictionary):
    """Plot averaged qv and th_v in SBF-relative x for every parcel type and depth class.

    The layout matches the model-x figure. Row ``i`` of the ``MakeAxis_All``
    grid holds parcel type ``i`` of (CL, nonCL, SBF, nonSBF). Column 0 shows
    qv in g/kg and column 1 shows th_v in K. Each panel overlays the ALL,
    SHALLOW and DEEP depth classes. Every axis in the figure also gets a dashed
    vertical line at x = 0, which marks the SBF position.

    Parameters
    ----------
    outputDictionary : dict
        ``outputDictionary[parcel_type][depth]`` must provide the accumulated
        arrays ``"qv_SBFLocation"`` and ``"thv_SBFLocation"`` and their
        ``"count_SBFLocation"``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure produced by ``MakeAxis_All``, with all panels drawn.

    Notes
    -----
    Depends on the notebook-level helpers ``MakeAxis_All``, ``GetBins`` and
    ``TakeAverage``.
    """
    fig, axes = MakeAxis_All()

    parcelTypes = ("CL", "nonCL", "SBF", "nonSBF")
    depthStyles = (("ALL", "black"), ("SHALLOW", "green"), ("DEEP", "blue"))

    # GetBins must return (edges, centers); only the bin centers are plotted.
    _, binCenters = GetBins()

    for rowIndex, parcelType in enumerate(parcelTypes):
        qvAxis, thvAxis = axes[rowIndex, 0], axes[rowIndex, 1]

        for depthName, lineColor in depthStyles:
            entry = outputDictionary[parcelType][depthName]
            counts = entry["count_SBFLocation"]

            qvMean = TakeAverage(entry["qv_SBFLocation"], counts)
            thvMean = TakeAverage(entry["thv_SBFLocation"], counts)

            # Collapse axis 0 (described by the author as y) to one value per
            # SBF-relative bin.
            qvLine = np.nanmean(qvMean, axis=0) * 1e3   # kg/kg -> g/kg
            thvLine = np.nanmean(thvMean, axis=0)

            lineStyle = dict(label=depthName, color=lineColor, linewidth=2)
            qvAxis.plot(binCenters, qvLine, **lineStyle)
            thvAxis.plot(binCenters, thvLine, **lineStyle)

        for axis, yLabel in ((qvAxis, f"{parcelType}\nqv (g/kg)"),
                             (thvAxis, f"{parcelType}\nth_v (K)")):
            axis.set_ylabel(yLabel)
            axis.grid(alpha=0.3)
            if rowIndex == 0:
                # Depth colors repeat in every row, so only the top row gets a legend.
                axis.legend(title="Depth", fontsize=8)

    for column in (0, 1):
        axes[-1, column].set_xlabel("SBF-relative x (km)")

    # Dashed reference line at the SBF position (x = 0) on every axis in the figure.
    for axis in fig.get_axes():
        axis.axvline(0, color='k', linestyle='--', linewidth=1)

    fig.tight_layout()
    return fig


In [None]:
def DeepMinusShallow_SBFLocation(outputDictionary):
    """Plot DEEP minus SHALLOW differences of averaged qv and th_v in SBF-relative x.

    For each parcel type (CL, nonCL, SBF, nonSBF), one row of the
    ``MakeAxis_All`` grid shows a single black difference curve. Column 0 is
    the qv difference in g/kg and column 1 is the th_v difference in K. Each
    panel also gets a gray dashed zero line and a black dashed vertical line at
    x = 0, which marks the SBF position.

    Parameters
    ----------
    outputDictionary : dict
        ``outputDictionary[parcel_type]["SHALLOW" | "DEEP"]`` must provide
        ``"qv_SBFLocation"``, ``"thv_SBFLocation"`` and ``"count_SBFLocation"``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure produced by ``MakeAxis_All``, with all panels drawn.

    Notes
    -----
    Depends on the notebook-level helpers ``MakeAxis_All``, ``GetBins`` and
    ``TakeAverage``.
    """
    fig, axes = MakeAxis_All()   # 4 rows x 2 columns

    # GetBins returns (edges, centers); only the bin centers are plotted.
    _, binCenters = GetBins()

    def meanProfile(entry, fieldKey):
        # Normalise the accumulated field by its counts, then average over axis 0.
        return np.nanmean(TakeAverage(entry[fieldKey], entry["count_SBFLocation"]), axis=0)

    for rowIndex, parcelType in enumerate(("CL", "nonCL", "SBF", "nonSBF")):
        qvAxis, thvAxis = axes[rowIndex, 0], axes[rowIndex, 1]

        shallow = outputDictionary[parcelType]["SHALLOW"]
        deep = outputDictionary[parcelType]["DEEP"]

        qvShallow = meanProfile(shallow, "qv_SBFLocation")
        qvDeep = meanProfile(deep, "qv_SBFLocation")
        thvShallow = meanProfile(shallow, "thv_SBFLocation")
        thvDeep = meanProfile(deep, "thv_SBFLocation")

        panels = (
            (qvAxis, (qvDeep - qvShallow) * 1e3,   # kg/kg -> g/kg
             f"{parcelType}\n(DEEP − SHALLOW)\nqv diff (g/kg)"),
            (thvAxis, thvDeep - thvShallow,        # K
             f"{parcelType}\n(DEEP − SHALLOW)\nth_v diff (K)"),
        )

        for axis, difference, yLabel in panels:
            axis.plot(binCenters, difference, color="k", linewidth=2)
            axis.set_ylabel(yLabel)
            axis.grid(alpha=0.3)
            axis.axhline(0, color='gray', linestyle='--')   # zero-difference reference
            axis.axvline(0, color='k', linestyle='--')      # SBF position

    for column in (0, 1):
        axes[-1, column].set_xlabel("SBF-relative x (km)")

    fig.tight_layout()
    return fig


In [None]:
##################
#PLOTTING 

In [None]:
# fig = PlotSingle_XLocation(parcel_type='CL',parcel_depth='DEEP')

In [None]:
# Model-x panels (qv and th_v per parcel type and depth class), saved with
# the other tracked-histogram figures.
fig = PlotAll_XLocation(outputDictionary)
SaveFigure(fig,
           "Project_Algorithms/Tracking_Algorithms/Tracked_Histograms",
           "Tracked_Histograms_2DAverageFields_XLocation")

In [None]:
# fig = PlotSingle_SBFLocation(parcel_type='CL',parcel_depth='DEEP')

In [None]:
# SBF-relative panels (qv and th_v per parcel type and depth class), saved
# with the other tracked-histogram figures.
fig = PlotAll_SBFLocation(outputDictionary)
SaveFigure(fig,
           "Project_Algorithms/Tracking_Algorithms/Tracked_Histograms",
           "Tracked_Histograms_2DAverageFields_SBFLocation")

In [None]:
# Capture the returned figure. Previously the return value was discarded, so
# SaveFigure wrote whatever `fig` was left over from an earlier cell (the
# PlotAll_SBFLocation figure) under the DEEP-minus-SHALLOW file name.
fig = DeepMinusShallow_SBFLocation(outputDictionary)
SaveFigure(
    fig,
    plotType="Project_Algorithms/Tracking_Algorithms/Tracked_Histograms",
    fileName="Tracked_Histograms_2DAverageFields_SBFLocation_DEEPminusSHALLOW"
)