In [None]:
import logging
import os
from pathlib import Path

import pandas as pd
import xarray as xr
from cm1.input.sounding import Sounding, era5_aws, era5_model_level
from cm1.run import PBS, CM1Run
from cm1.utils import animate_cm1out_nc, skewt
from metpy.units import units

# Configure logging: show WARNING and above. force=True removes and closes any
# handlers already attached to the root logger, so re-running this cell (or a
# handler pre-installed by the kernel) does not swallow or duplicate messages.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s: %(message)s",
    level=logging.WARNING,
    force=True,
)

In [None]:
# Path to the CM1 repository. $SCRATCH is honored when set; otherwise fall back
# to ~/sysdisk1 instead of a single user's hard-coded absolute home directory.
scratch_dir = Path(os.getenv("SCRATCH", Path.home() / "sysdisk1")).expanduser()
cm1_path = scratch_dir / "CM1"
runlocal = False  # set True to run locally as on cmd line

In [None]:
testcase = "supercell"
run_dir = cm1_path / f"run_{testcase}"

# Submit through PBS unless running locally. Serial execution is requested by
# passing `runlocal` as CM1Run's `serial` argument, so no separate flag is kept.
if runlocal:
    pbs_config = None
else:
    pbs_config = PBS(
        name=testcase,
        account=os.getenv("PBS_ACCOUNT", "NMMM0001"),
        walltime="01:00:00",
        nodes=1,
    )

# 1-D or 2-D test cases (names starting with "scm_", or "nh_mountain_waves"),
# and local runs, use the CM1 build compiled for a single processor.
oneortwod = testcase.startswith("scm_") or testcase == "nh_mountain_waves"
executable_path = cm1_path / "run" / (
    "cm1.single.gnu.exe" if (oneortwod or runlocal) else "cm1.exe"
)
pbs_config

In [None]:
cm1_run = CM1Run(
    cm1_path=cm1_path,
    run_dir=run_dir,
    executable_path=executable_path,
    pbs_config=pbs_config,
    serial=runlocal,
)
# Modify values in namelist
namelist = cm1_run.namelist
# NOTE(review): per CM1's README.namelist, isnd=7 reads the base-state sounding
# from an external file (the one built from cm1_run.sounding) — confirm.
namelist["param2"]["isnd"] = 7
# NOTE(review): output_format=2 presumably selects netCDF output — confirm.
namelist["param9"]["output_format"] = 2
# Combine output in one netCDF file.
namelist["param9"]["output_filetype"] = 1

valid_time = pd.to_datetime("20240525")
lon = -95 * units.degree_E
lat = 31 * units.degree_N
# Cache the ERA5 sounding on disk so re-running this cell does not refetch it.
sndfile = scratch_dir / "tmp" / f"{testcase}.{valid_time:%Y%m%d%H}.{lat.m}N.{lon.m}E.nc"
if sndfile.exists():
    print(f"open {sndfile}")
    input_sounding_ds = Sounding(sndfile)
else:
    # Use the model-level ERA5 reader where /glade/campaign is mounted,
    # otherwise the AWS-backed reader.
    era5_func = era5_model_level if Path("/glade/campaign").exists() else era5_aws

    input_sounding_ds = era5_func(valid_time, lon=lon, lat=lat)
    # to_netcdf does not create missing directories; make sure the cache dir exists.
    sndfile.parent.mkdir(parents=True, exist_ok=True)
    input_sounding_ds.metpy.dequantify().to_netcdf(sndfile)
    print(f"wrote {sndfile}")

# NOTE(review): the cached branch yields a Sounding while the download branch
# yields whatever era5_func returns — confirm CM1Run.sounding accepts both.
cm1_run.sounding = input_sounding_ds
cm1_run

In [None]:
# Display the sounding that was assigned to cm1_run.sounding in the previous cell.
input_sounding_ds

In [None]:
# Launch the configured CM1 run.
# NOTE(review): with pbs_config set this presumably submits a PBS job rather than
# running CM1 inline — confirm before assuming cm1out.nc exists for later cells.
cm1_run.run()
# Show the run's readme text.
print(cm1_run.readme)

In [None]:
# Plot the input sounding attached to the run.
cm1_run.sounding.plot()

In [None]:
# Open NetCDF file (output_filetype=1 combined all output into cm1out.nc).
# decode_timedelta=True decodes variables with duration units (including time)
# into timedeltas, which animate_cm1out_nc uses for the frame titles.
output_ds = xr.open_dataset(run_dir / "cm1out.nc", decode_timedelta=True)

# Animate winterp on the zh level nearest 2 (NOTE(review): zh presumably in km — confirm).
# NOTE(review): on a fresh Run All this calls the cm1.utils import; once the cell
# further down that redefines animate_cm1out_nc has run, re-running this cell uses
# the notebook-local version instead (hidden state).
animate_cm1out_nc(output_ds.winterp.sel(zh=2, method='nearest'))

In [None]:
# Show CM1Run's built-in documentation (exploratory; consider removing from the
# finished notebook).
help(CM1Run)

In [None]:
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML


def animate_cm1out_nc(
    data: xr.DataArray,
    interval: int = 200,
    **kwargs,
):
    """
    Create an animation of a user-specified 2D field over its time dimension.

    The user is responsible for selecting any other dimensions (e.g., vertical)
    before passing the data to this function.

    Parameters:
    - data: xr.DataArray. A DataArray with a 'time' dimension, ready to be plotted.
      Example: ds['cref'] or ds['u'].sel(zh=1.0, method='nearest')
    - interval: int, Interval between frames in milliseconds (default: 200ms).
    - **kwargs: Additional keyword arguments passed to ``DataArray.plot.imshow``.
      ``origin`` defaults to "lower" but may be overridden. If none of
      ``vmin``, ``vmax``, ``norm``, ``levels``, ``robust`` or ``center`` is
      given, the color limits are taken from the whole time series so that
      every frame shares one color scale.

    Returns:
    - IPython.display.HTML wrapping the JavaScript animation (``to_jshtml``).

    Raises:
    - ValueError: if ``data`` has no 'time' dimension.
    """
    if "time" not in data.dims:
        raise ValueError("Input DataArray must have a 'time' dimension.")

    # Frame labels. CM1 output opened with decode_timedelta=True has a timedelta
    # time axis. Datetime axes are kept as-is (pd.to_timedelta would reject
    # them); any other dtype is converted with pd.to_timedelta as before.
    time_index = data["time"].to_index()
    if not isinstance(time_index, pd.DatetimeIndex):
        time_index = pd.to_timedelta(time_index)

    # Static part of the title: the variable name and, if present, a scalar
    # vertical coordinate value.
    title_parts = []
    if data.name is not None and str(data.name):
        title_parts.append(str(data.name))

    for coord_name in ("zh", "z", "level", "pressure"):
        if coord_name in data.coords and data[coord_name].size == 1:
            level_val = data[coord_name].item()
            # Not named `units`: that would shadow metpy's units registry here.
            level_units = data[coord_name].attrs.get("units", "")
            try:
                level_str = f"{level_val:.2f}"
            except (TypeError, ValueError):
                # Non-numeric coordinate values (e.g. strings) cannot take .2f.
                level_str = str(level_val)
            title_parts.append(f"at {level_str} {level_units}".rstrip())
            break

    def frame_title(frame):
        # Join per frame instead of calling str.format on a template, so braces
        # in the variable name or units cannot break the title.
        return ", ".join([*title_parts, f"Time: {time_index[frame]}"])

    plot_kwargs = {"origin": "lower", **kwargs}
    # xarray derives color limits from the array it plots, which is only frame 0
    # here. A uniform first frame (e.g. a field that is still all zeros at t=0)
    # would leave every later frame saturated. Unless the caller already controls
    # the color scaling, use the range of the full time series instead.
    color_scale_keys = {"vmin", "vmax", "norm", "levels", "robust", "center"}
    if color_scale_keys.isdisjoint(plot_kwargs):
        data_min = float(data.min())
        data_max = float(data.max())
        if pd.notna(data_min) and pd.notna(data_max):
            if data_min < 0 < data_max:
                # Mirror xarray's default for fields that cross zero: a diverging
                # colormap with limits symmetric about 0.
                plot_kwargs.update(vmax=max(-data_min, data_max), center=0)
            else:
                plot_kwargs.update(vmin=data_min, vmax=data_max)

    img = data.isel(time=0).plot.imshow(**plot_kwargs)

    # Animation function: swap in the frame's values and retitle.
    def update(frame):
        img.set_array(data.isel(time=frame).values)
        img.axes.set_title(frame_title(frame))
        return [img]

    try:
        ani = animation.FuncAnimation(
            img.figure, update, frames=len(time_index), interval=interval, blit=False
        )
        anim_html = HTML(ani.to_jshtml())
    finally:
        # Close the figure so the notebook does not also show a static image,
        # and so the figure is not leaked if rendering fails.
        plt.close(img.figure)
    return anim_html