In [None]:
####################################
#ENVIRONMENT SETUP

In [None]:
#Importing Libraries
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as ticker
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import ScalarFormatter
import matplotlib.gridspec as gridspec
import xarray as xr

import sys; import os; import time; from datetime import timedelta
import pickle
import h5py

from tqdm import tqdm

from glob import glob

In [None]:
#MAIN DIRECTORIES
def GetDirectories():
    """Return the directory paths used throughout this notebook.

    Returns
    -------
    tuple of str
        ``(mainDirectory, mainCodeDirectory, scratchDirectory, codeDirectory)``:
        the hard-coded project root, its ``Code/CodeFiles/`` subdirectory,
        the hard-coded scratch directory, and the current working directory.
    """
    projectRoot = '/mnt/lustre/koa/koastore/torri_group/air_directory/Projects/DCI-Project/'
    scratchRoot = '/mnt/lustre/koa/scratch/air673/'
    return (
        projectRoot,
        os.path.join(projectRoot, "Code/CodeFiles/"),
        scratchRoot,
        os.getcwd(),
    )


mainDirectory, mainCodeDirectory, scratchDirectory, codeDirectory = GetDirectories()

In [None]:
#IMPORT CLASSES
sys.path.append(os.path.join(mainCodeDirectory,"2_Variable_Calculation"))
from CLASSES_Variable_Calculation import ModelData_Class, SlurmJobArray_Class, DataManager_Class

In [None]:
# Data loading class (simulation number 1)
ModelData = ModelData_Class(
    mainDirectory,
    scratchDirectory,
    simulationNumber=1,
)

# Data manager class, configured from the resolution / size attributes of ModelData
DataManager = DataManager_Class(
    mainDirectory,
    scratchDirectory,
    ModelData.res,
    ModelData.t_res,
    ModelData.Nz_str,
    ModelData.Np_str,
    dataType="Tracking_Algorithms",
    dataName="Lagrangian_UpdraftTracking",
    dtype='float32',
    codeSection="Project_Algorithms",
)

In [None]:
#IMPORT FUNCTIONS
sys.path.append(os.path.join(mainCodeDirectory,"2_Variable_Calculation"))
import FUNCTIONS_Variable_Calculation
from FUNCTIONS_Variable_Calculation import *

In [None]:
#IMPORT CLASSES
sys.path.append(os.path.join(mainCodeDirectory,"3_Project_Algorithms","2_Tracking_Algorithms"))
from CLASSES_TrackingAlgorithms import TrackingAlgorithms_DataLoading_Class, SlurmJobArray_Class, Results_InputOutput_Class, TrackedParcel_Loading_Class

In [None]:
##############################################
#SETUP

In [None]:
################################
#JOB ARRAY SETUP
################################
# Total number of jobs being run, i.e. array=1-100 ==> num_jobs=100.
# The counts are selected by the parcel-count tag found in ModelData.Np_str.
# NOTE: if Np_str contains neither tag, num_jobs / num_slurm_jobs are left
# undefined by this cell.
if '1e6' in ModelData.Np_str:   # 1M parcels
    num_jobs, num_slurm_jobs = 60, 10
if '50e6' in ModelData.Np_str:  # 50M parcels
    num_jobs, num_slurm_jobs = 200, 60
##############################

In [None]:
##############################################
#Data Loading Functions

In [None]:
# def GetDensityPotentialTemperature():
def CallVariables(ModelData, DataManager, timeString, varNames,zInterpolate=None):
    """Fetch several variables for one time string via ``CallVariable``.

    Parameters
    ----------
    ModelData, DataManager
        Passed through unchanged to ``CallVariable``.
    timeString
        Time identifier passed through to ``CallVariable``.
    varNames : iterable of str
        Names of the variables to fetch; each becomes a key of the result.
    zInterpolate : optional
        Forwarded to ``CallVariable`` as its ``zInterpolate`` keyword.

    Returns
    -------
    dict
        Mapping ``varName -> CallVariable(...)`` result, in ``varNames`` order.
    """
    return {
        name: CallVariable(ModelData, DataManager, timeString,
                           variableName=name, zInterpolate=zInterpolate)
        for name in varNames
    }

def GetVariables(t):
    """Load the three fields used by this notebook at time index ``t``.

    Uses the module-level ``ModelData`` and ``DataManager`` objects.

    Parameters
    ----------
    t : int
        Index into ``ModelData.timeStrings`` and along the ``time`` dimension.

    Returns
    -------
    tuple
        ``(theta_v, qr, buoyancy)``: ``theta_v`` comes from ``CallVariables``,
        while ``qr`` and ``buoyancy`` are read with ``ModelData.GetVariable``.
    """
    # "theta_v" is obtained through CallVariables, keyed by the time string.
    theta_v = CallVariables(
        ModelData, DataManager, ModelData.timeStrings[t], ["theta_v"]
    )["theta_v"]

    # "qr" and "buoyancy" are read directly from the model output at index t.
    rain = ModelData.GetVariable(varName='qr', isel={'time': t})
    buoy = ModelData.GetVariable(varName='buoyancy', isel={'time': t})

    return theta_v, rain, buoy

def SelectZLevel(array, zLevel_meters):
    """Return ``array`` indexed at the level nearest to ``zLevel_meters``.

    The requested height is divided by 1e3 before being compared with the
    module-level ``ModelData.zh`` (presumably zh is in km -- TODO confirm).
    The index of the closest ``zh`` entry is applied to the first axis of
    ``array``.
    """
    targetHeight = zLevel_meters / 1e3
    distance = np.abs(ModelData.zh - targetHeight)
    nearestIndex = distance.argmin()
    return array[nearestIndex]

def GetPerturbation(array):
    """Return ``array`` minus its mean taken over all elements."""
    meanValue = np.mean(array)
    return array - meanValue
    
def GetNecessaryData(t):
    """Collect the fields needed for plotting at z = 100 m and z = 250 m.

    Parameters
    ----------
    t : int
        Time index forwarded to ``GetVariables``.

    Returns
    -------
    dict
        ``{zLevel: {"theta_v_prime", "qr", "buoyancy"}}`` for zLevel in
        (100, 250). ``theta_v_prime`` is the selected theta_v level with its
        own mean removed (``GetPerturbation``); ``qr`` and ``buoyancy`` are
        the selected levels unchanged.
    """
    # Load the full fields for this timestep once.
    theta_v, qr, buoyancy = GetVariables(t)

    def _fieldsAt(zLevel):
        # Perturbation is taken relative to the mean of the selected level.
        theta_v_level = SelectZLevel(theta_v, zLevel)
        return {
            "theta_v_prime": GetPerturbation(theta_v_level),
            "qr": SelectZLevel(qr, zLevel),
            "buoyancy": SelectZLevel(buoyancy, zLevel)
        }

    return {zLevel: _fieldsAt(zLevel) for zLevel in (100, 250)}

In [None]:
##############################################
#Calculation Functions

In [None]:
def GetMasks(dataDictionary_zLevel,
             threshold_1=-1,threshold_2=1e-6,threshold_3=0.005):
    """Build the four 0/1 integer masks for one height level.

    Parameters
    ----------
    dataDictionary_zLevel : dict
        Must provide the keys "theta_v_prime", "qr" and "buoyancy".
    threshold_1, threshold_2, threshold_3 : number
        Cut-offs for theta_v_prime (below), qr (above) and buoyancy (below).

    Returns
    -------
    tuple
        ``(mask_0, mask_1, mask_2, mask_3)`` where mask_0 is all ones,
        mask_1 is ``theta_v_prime < threshold_1``,
        mask_2 is ``qr > threshold_2`` and
        mask_3 is ``buoyancy < threshold_3``, each cast to int.
    """
    thetaPrime = dataDictionary_zLevel["theta_v_prime"]

    everywhere = np.ones_like(thetaPrime, dtype=int)
    isCold = thetaPrime < threshold_1
    isRainy = dataDictionary_zLevel["qr"] > threshold_2
    isLowBuoyancy = dataDictionary_zLevel["buoyancy"] < threshold_3

    return (
        everywhere,
        isCold.astype(int),
        isRainy.astype(int),
        isLowBuoyancy.astype(int),
    )

In [None]:
##############################################
#Plotting Functions

In [None]:
def _GetSharedContourLevels(fields, nbins=8):
    """Return one set of contour levels covering the finite range of all fields.

    Returns None when no finite range exists (no finite values, or a constant
    field), in which case the caller falls back to contourf's automatic levels.
    """
    lows, highs = [], []
    for field in fields:
        values = np.asarray(field, dtype=float)
        finite = values[np.isfinite(values)]
        if finite.size:
            lows.append(finite.min())
            highs.append(finite.max())

    if not lows or min(lows) >= max(highs):
        return None

    # MaxNLocator is imported at the top of the notebook; tick_values gives
    # "nice" increasing boundaries that bracket [min, max].
    levels = MaxNLocator(nbins=nbins, min_n_ticks=1).tick_values(min(lows), max(highs))
    return levels if len(levels) >= 2 else None


def MakePlot(dataDictionary, plotMode="mask"):
    """Plot a 2 x 4 grid of progressively stricter masks at z = 100 m and 250 m.

    Rows are the height levels; columns (left to right) are: all points,
    mask_1, mask_1 & mask_2, mask_1 & mask_2 & mask_3, with the masks taken
    from ``GetMasks``.

    Parameters
    ----------
    dataDictionary : dict
        ``{zLevel: {"theta_v_prime", "qr", "buoyancy"}}``; must contain the
        keys 100 and 250 (the structure returned by ``GetNecessaryData``).
    plotMode : {"mask", "variable"}
        "mask" draws the 0/1 mask itself. "variable" draws ``theta_v_prime``
        with every point outside the mask replaced by NaN.

    Raises
    ------
    ValueError
        If ``plotMode`` is neither "mask" nor "variable". The check runs
        before the figure is created so no empty figure is left open.

    Notes
    -----
    Reads the module-level names ``threshold_1``, ``threshold_2``,
    ``threshold_3`` and ``threshold_1_string`` .. ``threshold_3_string`` at
    call time; they must be defined before this function is called.

    In "variable" mode all panels share one set of contour levels computed
    from the unmasked ``theta_v_prime`` fields, so the single colorbar is
    valid for every panel rather than only for the last one drawn.
    """
    if plotMode not in ("mask", "variable"):
        raise ValueError("plotMode must be 'mask' or 'variable'")

    zLevels = [100, 250]

    if plotMode == "mask":
        # Two bins centred on the values 0 and 1.
        Levels = [-0.5, 0.5, 1.5]
    else:
        # Common levels across all panels; None falls back to auto levels.
        Levels = _GetSharedContourLevels(
            [dataDictionary[zLevel]["theta_v_prime"] for zLevel in zLevels]
        )

    ColumnTitles = [
        "All Mask",
        f"{threshold_1_string}",
        f"{threshold_1_string}\n& {threshold_2_string}",
        f"{threshold_1_string}\n& {threshold_2_string}\n& {threshold_3_string}"
    ]

    fig, Axes = plt.subplots(
        nrows=2,
        ncols=4,
        figsize=(12, 7),
        constrained_layout=True
    )

    Mappable = None  # last contourf handle, used for the shared colorbar

    for rowIndex, zLevel in enumerate(zLevels):

        dataDictionary_zLevel = dataDictionary[zLevel]

        mask_0, mask_1, mask_2, mask_3 = GetMasks(
            dataDictionary_zLevel,
            threshold_1,threshold_2,threshold_3)

        # Each column adds one more criterion to the previous one.
        MaskList = [
            mask_0,
            mask_1,
            mask_1 & mask_2,
            mask_1 & mask_2 & mask_3
        ]

        theta_v_prime = dataDictionary_zLevel["theta_v_prime"]

        for colIndex, mask in enumerate(MaskList):

            ax = Axes[rowIndex, colIndex]
            ax.set_xticks([])
            ax.set_yticks([])

            if plotMode == "mask":
                FieldToPlot = mask.astype(int)
            else:
                # Copy so the caller's array is not modified by the masking.
                FieldToPlot = theta_v_prime.copy()
                FieldToPlot[mask == 0] = np.nan

            Mappable = ax.contourf(
                FieldToPlot,
                levels=Levels
            )

            if rowIndex == 0:
                ax.set_title(ColumnTitles[colIndex])

        Axes[rowIndex, 0].set_ylabel(f"z = {zLevel} m")

    # ---- Shared colorbar ----
    Cbar = fig.colorbar(
        Mappable,
        ax=Axes,
        orientation="vertical",
        fraction=0.03,
        pad=0.02
    )

    if plotMode == "mask":
        Cbar.set_ticks([0, 1])
        Cbar.set_ticklabels(["False", "True"])
        Cbar.set_label("Mask")

    else:
        Cbar.set_label("θᵥ′ (K)")

    plt.show()


In [None]:
##############################################
#Plotting

In [None]:
# Mask thresholds used by MakePlot (these override the GetMasks defaults):
#   threshold_1 -> upper bound on theta_v_prime
#   threshold_2 -> lower bound on qr
#   threshold_3 -> upper bound on buoyancy
threshold_1, threshold_2, threshold_3 = -1.5, 1e-6, 0

# Panel-title labels describing each criterion.
threshold_1_string = f"(θv′ < {threshold_1})"
threshold_2_string = f"(qr > {threshold_2})"
threshold_3_string = f"(B < {threshold_3})"

In [None]:
# Gather the z = 100 m / 250 m fields at time index 100, then plot
# theta_v_prime restricted to each successive mask ("variable" mode).
dataDictionary = GetNecessaryData(t=100)    
MakePlot(dataDictionary,plotMode="variable")