In [None]:
####################################
#ENVIRONMENT SETUP

In [None]:
#Importing Libraries
import sys; import os; import time; from datetime import timedelta

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as ticker
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import ScalarFormatter
import matplotlib.gridspec as gridspec
import matplotlib.lines as mlines
import xarray as xr

import pickle
import h5py

from tqdm import tqdm

In [None]:
#MAIN DIRECTORIES
def GetDirectories():
    """Return the key project directories as a 4-tuple.

    Returns
    -------
    tuple of str
        (mainDirectory, mainCodeDirectory, scratchDirectory, codeDirectory):
        the project root, its ``Code/CodeFiles/`` sub-folder, the user's scratch
        area, and the current working directory.
    """
    projectRoot = '/mnt/lustre/koa/koastore/torri_group/air_directory/Projects/DCI-Project/'
    scratchRoot = '/mnt/lustre/koa/scratch/air673/'
    return (
        projectRoot,
        os.path.join(projectRoot, "Code/CodeFiles/"),
        scratchRoot,
        os.getcwd(),
    )

mainDirectory, mainCodeDirectory, scratchDirectory, codeDirectory = GetDirectories()

In [None]:
def GetPlottingDirectory(plotFileName, plotType):
    """Return the full path for a plot file, creating its directory if needed.

    The path is ``<mainDirectory>/Code/PLOTTING/<plotType>/<res>_<t_res>_<Nz_str>nz/<plotFileName>``,
    where res, t_res and Nz_str are read from the global ``ModelData``.

    Parameters
    ----------
    plotFileName : str
        File name (with extension) of the plot.
    plotType : str
        Sub-path identifying the kind of plot (may contain '/').

    Returns
    -------
    str
        Path to ``plotFileName`` inside the (now existing) plot directory.
    """
    # Only the plotting root is needed here; the previous chained assignment also
    # bound a local ``mainCodeDirectory`` that shadowed the global with a wrong path.
    plottingDirectory = os.path.join(mainDirectory, "Code", "PLOTTING")

    specificPlottingDirectory = os.path.join(
        plottingDirectory, plotType,
        f"{ModelData.res}_{ModelData.t_res}_{ModelData.Nz_str}nz")
    os.makedirs(specificPlottingDirectory, exist_ok=True)

    return os.path.join(specificPlottingDirectory, plotFileName)

def SaveFigure(fig,plotType, fileName):
    """Save ``fig`` as a 300-dpi JPEG in the plot directory for ``plotType``.

    The saved name is ``{fileName}_{res}_{t_res}_{Np_str}.jpg`` using values from
    the global ``ModelData``; the path comes from ``GetPlottingDirectory`` and is
    printed before writing.
    """
    runSuffix = f"{ModelData.res}_{ModelData.t_res}_{ModelData.Np_str}"
    destination = GetPlottingDirectory(f"{fileName}_{runSuffix}.jpg", plotType)
    print(f"Saving figure to {destination}")
    fig.savefig(destination, dpi=300, bbox_inches='tight')

In [None]:
#IMPORT CLASSES
sys.path.append(os.path.join(mainCodeDirectory,"2_Variable_Calculation"))
from CLASSES_Variable_Calculation import ModelData_Class, SlurmJobArray_Class, DataManager_Class

In [None]:
#IMPORT CLASSES
sys.path.append(os.path.join(mainCodeDirectory,"3_Project_Algorithms","1_Domain_Profiles"))
from CLASSES_DomainProfiles import DomainProfiles_Class, DomainProfiles_DataLoading_Class

In [None]:
#IMPORT CLASSES
sys.path.append(os.path.join(mainCodeDirectory,"3_Project_Algorithms","2_Tracking_Algorithms"))
from CLASSES_TrackingAlgorithms import TrackingAlgorithms_DataLoading_Class

In [None]:
#IMPORT FUNCTIONS
sys.path.append(os.path.join(mainCodeDirectory,"2_Variable_Calculation"))
import FUNCTIONS_Variable_Calculation
from FUNCTIONS_Variable_Calculation import *

In [None]:
#data loading class (simulation 2)
ModelData = ModelData_Class(mainDirectory, scratchDirectory, simulationNumber=2)

#data manager class, configured for the "Eulerian_CLTracking" Tracking_Algorithms dataset
DataManager = DataManager_Class(
    mainDirectory, scratchDirectory,
    ModelData.res, ModelData.t_res, ModelData.Nz_str, ModelData.Np_str,
    dataType="Tracking_Algorithms",
    dataName="Eulerian_CLTracking",
    dtype='float32',
    codeSection="Project_Algorithms",
)

In [None]:
import sys
path=os.path.join(mainCodeDirectory,'Functions/')
sys.path.append(path)

import NumericalFunctions
from NumericalFunctions import * # import NumericalFunctions 
import PlottingFunctions
from PlottingFunctions import * # import PlottingFunctions

# # Get all functions in NumericalFunctions
# import inspect
# functions = [f[0] for f in inspect.getmembers(NumericalFunctions, inspect.isfunction)]
# functions

In [None]:
#############################
#SETUP

In [None]:
#JOB ARRAY SETUP
UsingJobArray=True

def GetNumJobs(res,t_res):
    """Return the number of Slurm array tasks used for a (resolution, output interval) pair.

    Parameters
    ----------
    res : str
        Horizontal resolution label, e.g. '1km' or '250m'.
    t_res : str
        Output interval label, e.g. '5min' or '1min'.

    Raises
    ------
    ValueError
        If no job count is configured for the combination (previously this
        surfaced as an UnboundLocalError).
    """
    jobsByConfig = {
        ('1km', '5min'): 20,
        ('1km', '1min'): 20,
        ('250m', '1min'): 100,
    }
    try:
        return jobsByConfig[(res, t_res)]
    except KeyError:
        raise ValueError(
            f"No job-array size configured for res={res!r}, t_res={t_res!r}"
        ) from None
num_jobs = GetNumJobs(ModelData.res, ModelData.t_res)
SlurmJobArray = SlurmJobArray_Class(
    total_elements=ModelData.Ntime,
    num_jobs=num_jobs,
    UsingJobArray=UsingJobArray,
)
start_job, end_job = SlurmJobArray.start_job, SlurmJobArray.end_job

def GetNumElements():
    """Return the time indices handled by this task: ``arange(Ntime)[start_job:end_job]``."""
    allTimeIndices = np.arange(ModelData.Ntime)
    return allTimeIndices[start_job:end_job]

loop_elements = GetNumElements()

In [None]:
##############################################
#MODEL AND ALGORITHM NUMERICAL PARAMETERS

In [None]:
# Horizontal grid spacing, in the units of ModelData.xh (km).
# Kept as a float: int() truncated sub-km spacings (0.25 km on a 250 m grid) to 0.
dx = float(ModelData.xh[1] - ModelData.xh[0])
dy = dx
# Coordinates shifted so the first grid point sits at 0 km.
xh = ModelData.xh - ModelData.xh[0]
yh = ModelData.yh - ModelData.yh[0]

In [None]:
#########################################
#LOADING DATA

In [None]:
def GetConvergence(t,z):
    """Return the 'convergence' variable at time index ``t``, indexed by ``z`` on its first axis.

    The whole variable is fetched through ``CallVariable`` for the time string
    ``ModelData.timeStrings[t]`` before the level is selected.
    """
    fullField = CallVariable(ModelData, DataManager, ModelData.timeStrings[t], 'convergence')
    return fullField[z]
    
def LoadTrackedData(t):
    """Return the ``"maxConvergence_X"`` entry of the tracking output for time index ``t``."""
    trackedOutput = TrackingAlgorithms_DataLoading_Class.LoadData(
        ModelData, DataManager, ModelData.timeStrings[t])
    return trackedOutput["maxConvergence_X"]

In [None]:
#########################################
#PLOTTING FUNCTIONS

In [None]:
def CL_plotting(axis, t, zlev, 
                clim,
                font_size=12, index_adjust=0, ocean_fraction=2/8):
    """Plot convergence at one level with the tracked convergence maxima overlaid.

    Parameters
    ----------
    axis : matplotlib.axes.Axes
        Axes to draw into; a colorbar is attached to it.
    t : int
        Time index (into ``ModelData.timeStrings`` / ``ModelData.time``).
    zlev : int
        Vertical level index into ``ModelData.zh`` (km).
    clim : (float, float) or None
        Convergence range used to build 40 fixed contour levels so every frame
        shares one colour scale; ``None`` lets matplotlib choose 40 levels from
        this frame's data.
    font_size : int, optional
        Font size for labels, ticks, colorbar and legend.
    index_adjust : int, optional
        Offset added to ``t`` in the legend's time label only.
    ocean_fraction : float, optional
        Coastline position as a fraction of ``ModelData.Nxf`` grid points.
    """
    # Convergence is multiplied by 1e3 for readable tick labels; the colorbar
    # label states the resulting unit.
    # NOTE(review): assumes CallVariable returns convergence in s^-1 — confirm.
    scale = 1000

    # --- convergence field ---
    conv_z = GetConvergence(t, zlev)
    if clim is not None:
        levels = np.linspace(clim[0] * scale, clim[1] * scale, 40)
        contour = axis.contourf(xh, yh, conv_z * scale, levels=levels)
    else:
        contour = axis.contourf(xh, yh, conv_z * scale, levels=40)
    cbar = plt.colorbar(contour, ax=axis, pad=0)
    cbar.set_label(r'$-\nabla \cdot \vec{V}_H\ (10^{-3}\ s^{-1})$', fontsize=font_size)
    cbar.ax.tick_params(labelsize=font_size)
    cbar.ax.yaxis.label.set_size(font_size)

    axis.set_xlabel('x (km)', fontsize=font_size)
    axis.set_ylabel('y (km)', fontsize=font_size)
    axis.tick_params(axis='both', which='major', labelsize=font_size)

    # --- tracked convergence maxima ---
    maxConvergence_X = LoadTrackedData(t)
    # Gather all points and draw a single scatter instead of one scatter call
    # (and one artist) per y row.
    point_x, point_y = [], []
    for yind in range(ModelData.Nyh):
        local_maxes = maxConvergence_X[zlev, yind]
        # Entries equal to -1 are discarded; the rest are used as indices into xh.
        local_maxes = local_maxes[local_maxes != -1].astype(int)
        point_x.append(np.asarray(xh[local_maxes], dtype=float))
        point_y.append(np.full(len(local_maxes), float(yh[yind])))
    if point_x:
        axis.scatter(np.concatenate(point_x), np.concatenate(point_y), color='red', s=1)

    # --- coastline ---
    # Nxf * ocean_fraction is a count of grid points while the x axis is in km;
    # scaling by the grid spacing places the line correctly on non-1 km grids
    # (unchanged when the spacing is exactly 1 km).
    grid_spacing = float(xh[1] - xh[0]) if len(xh) > 1 else 1.0
    axis.axvline(x=ModelData.Nxf * ocean_fraction * grid_spacing,
                 color='black', linewidth=1.5, label='Coastline')

    # --- legend ---
    ([days, hours, mins], _) = get_time(ModelData.time, t, (0, 6, 0))
    handle_pts = mlines.Line2D([], [], color='red', marker='o', linestyle='None', markersize=6, label='Convergence Local Y-Maxima')
    handle_time = mlines.Line2D([], [], color='none', label=f't = {t + index_adjust} = {days}:{hours}:{mins}')
    handle_z = mlines.Line2D([], [], color='none', label=f'z = {ModelData.zh[zlev]*1000:.0f} m')
    handle_c = mlines.Line2D([], [], color='black', lw=3, label='Coastline')
    legend = axis.legend(handles=[handle_pts, handle_c, handle_time, handle_z], loc='upper left', fontsize=font_size)
    for text in legend.get_texts():
        text.set_fontsize(font_size)

def RunTrackedPlot(t,zlev, 
                   clim=None,
                   SAVING=False,CLOSE=False):
    """Build one convergence-tracking figure for time index ``t`` at level ``zlev``.

    The figure is a single wide panel (20 in wide, 5:1 aspect, 72 dpi) filled by
    ``CL_plotting``. When ``SAVING`` it is written through ``SaveFigure`` under the
    Eulerian_CLTracking plot type; when ``CLOSE == True`` it is closed afterwards.
    """
    figwidth, channel_aspect_ratio = 20, 5
    fig, axis = plt.subplots(nrows=1, ncols=1,
                             figsize=(figwidth, figwidth / channel_aspect_ratio),
                             dpi=72)

    CL_plotting(axis, t, zlev, clim=clim)

    if SAVING:
        SaveFigure(fig,
                   plotType="Project_Algorithms/Tracking_Algorithms/Eulerian_CLTracking",
                   fileName=f"Eulerian_CLTracking_{ModelData.timeStrings[t]}")

    # Closing matters when rendering many frames in a loop.
    if CLOSE == True:
        plt.close(fig)


In [None]:
#########################################
#COLORBAR LIMITS (shared by all frames)

In [None]:
import os
import pickle
import numpy as np
from tqdm import tqdm

def GetCLim(ModelData):
    """Return global (min, max) convergence at the level nearest 350 m, cached on disk.

    Per-timestep minima and maxima over all ``ModelData.Ntime`` steps are stored
    in ``convergence_stats_{res}_{t_res}.pkl`` in the current working directory.
    If that file already exists it is loaded instead of recomputed.

    Returns
    -------
    tuple
        (smallest per-step minimum, largest per-step maximum), in the units
        returned by ``GetConvergence``.
    """
    output_pkl = f"convergence_stats_{ModelData.res}_{ModelData.t_res}.pkl"

    # ------------------------------------------------------------
    # If file exists -> load and skip calculations
    # ------------------------------------------------------------
    if os.path.exists(output_pkl):
        # Self-written cache; never point this at an untrusted pickle.
        with open(output_pkl, "rb") as f:
            conv_stats = pickle.load(f)
        print(f"Loaded existing convergence stats from: {output_pkl}")

    # ------------------------------------------------------------
    # Otherwise compute and save
    # ------------------------------------------------------------
    else:
        print("Computing convergence stats...")
        conv_stats = _ComputeConvergenceStats(ModelData)
        _DumpPickleAtomically(conv_stats, output_pkl)
        print(f"Saved convergence stats to: {output_pkl}")

    return (np.min(conv_stats['min']), np.max(conv_stats['max']))


def _ComputeConvergenceStats(ModelData):
    """Per-timestep nanmin/nanmax of convergence at the level closest to 350 m."""
    zlev = int(np.abs(ModelData.zh - 350/1e3).argmin())
    mins, maxs = [], []
    for t in tqdm(range(ModelData.Ntime), desc="Computing convergence stats"):
        conv_z = GetConvergence(t, zlev)  # 2D field
        mins.append(np.nanmin(conv_z))
        maxs.append(np.nanmax(conv_z))
    return {
        "min": np.array(mins),
        "max": np.array(maxs),
        # Optionally store associated times
        "time": getattr(ModelData, "time", None),
    }


def _DumpPickleAtomically(obj, path):
    """Pickle ``obj`` to ``path`` via a temporary file that is renamed into place.

    Every Slurm array task runs GetCLim; writing directly let one task find the
    cache file while another was still writing it and fail to unpickle it.
    """
    import socket
    # Host + pid keeps temp names unique across tasks on different nodes.
    tmp_path = f"{path}.{socket.gethostname()}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        # os.replace is atomic on POSIX when both paths share a filesystem.
        os.replace(tmp_path, path)
    finally:
        # Only left behind if dumping or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
# Colour limits (min, max) passed to every frame below so all frames share one scale.
clim = GetCLim(ModelData)

In [None]:
# #TESTING INDIVIDUAL PLOTS
# #########################
# #setting time
# t = 100 if ModelData.t_res == '5min' else 100 * 5
# # t+=20

# #getting z level
# zlev = int(np.abs(ModelData.zh - 350/1e3).argmin())

# #running
# RunTrackedPlot(t=t,zlev=zlev,
#                clim=clim,
#                SAVING=False,CLOSE=False)

In [None]:
#OUTPUTTING ALL PLOTS
#########################
# Level index closest to 350 m (ModelData.zh is in km).
zlev = int(np.argmin(np.abs(ModelData.zh - 350/1e3)))

# Render, save and close one frame per time index assigned to this job-array task.
for t in tqdm(loop_elements, desc="Processing timesteps"):
    RunTrackedPlot(t=t, zlev=zlev, clim=clim, SAVING=True, CLOSE=True)

In [None]:
########################################
#MAKING ANIMATION
# Leave False while the job array is running; set True in a separate manual run
# to stitch the saved frames into an MP4.
animating = False

In [None]:
#IMPORT ANIMATION CLASSES
sys.path.append(os.path.join(mainCodeDirectory,"1_Initial_Figures","Animations"))
import CLASSES_AnimationPlotting
from CLASSES_AnimationPlotting import AnimationPlotting_Class

In [None]:
def GetPNGFileList(folder):
    """Return the image frames in ``folder``, ordered by the time stamp in their names.

    Despite the name, ``*.jpg``, ``*.jpeg`` and ``*.png`` files are all collected.
    Only the file name (not the full path) is searched for ``_<H>-<M>-<S>``
    followed by ``_``, ``.`` or the end of the name. Files are sorted by that
    time in seconds, and files without such a stamp go last. Ties are broken by
    path, so the result does not depend on directory-listing order.
    """
    import glob, os, re

    # glob.escape keeps characters such as '[' in the folder name from being
    # interpreted as wildcards.
    searchRoot = glob.escape(folder)
    imageFiles = []
    for extension in ("jpg", "jpeg", "png"):
        imageFiles.extend(glob.glob(f"{searchRoot}/*.{extension}"))

    time_re = re.compile(r'_(\d+)-(\d+)-(\d+)(?:_|\.|$)')

    def sort_key(path):
        # Searching only the file name stops time-like text in a parent
        # directory from being mistaken for the frame time.
        m = time_re.search(os.path.basename(path))
        if m is None:
            return (float("inf"), path)
        h, mnt, s = map(int, m.groups())
        return (h * 3600 + mnt * 60 + s, path)

    return sorted(imageFiles, key=sort_key)

In [None]:
# Resolve the frame folder (GetPlottingDirectory also creates it), collect the
# saved JPG/JPEG/PNG frames, and choose the MP4 output path.
filePath = GetPlottingDirectory(
    plotFileName="temporary.png",  # placeholder: only its directory is used
    plotType="Project_Algorithms/Tracking_Algorithms/Eulerian_CLTracking",
)
folder = os.path.dirname(filePath)
imageFiles = GetPNGFileList(folder)

fileName = "Eulerian_CLTracking.mp4"
outputFile = os.path.join(folder, fileName)

In [None]:
if animating:
    # Minutes between frames come from the run's output interval
    # ('5min' -> 5, '1min' -> 1) instead of assuming 5 for every run;
    # unrecognised formats keep the previous value of 5.
    t_res_digits = ModelData.t_res[:-3] if ModelData.t_res.endswith('min') else ''
    time_interval_minutes = int(t_res_digits) if t_res_digits.isdigit() else 5
    fps = AnimationPlotting_Class.CalculateFPS(num_frames=ModelData.Ntime,
                                               time_interval_minutes=time_interval_minutes,
                                               desired_duration_min=1)
    AnimationPlotting_Class.PNGsToMP4(imageFiles, outputFile, fps=fps)