In [None]:
# Re-import edited modules (e.g. the project's src.* helpers) before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Configure the root logger to print to the notebook's stdout.
logger = getLogger()

# Guard against re-running this cell: only attach a stdout handler when no
# handler whose repr mentions "StreamHandler" is present yet, so log lines are
# not duplicated.
if not any("StreamHandler" in str(existing) for existing in logger.handlers):
    logger.addHandler(StreamHandler(sys.stdout))

logger.setLevel(INFO)

# Import libraries

In [None]:
import gc
import glob
import os
import pathlib
import re
import time
from collections import OrderedDict

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from cfd_model.initialization.periodic_channel_jet_initializer import calc_jet_forcing
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from IPython.display import display
from src.sr_da_helper import (
    get_observation_with_noise,
    get_testdataset,
    initialize_models,
    make_invprocessed_sr,
    make_models,
    make_preprocessed_lr,
    make_preprocessed_obs,
    read_all_hr_omegas_with_combining,
)
from src.utils import set_seeds
from tqdm.notebook import tqdm

# Let wide/long result tables render without column/row truncation.
pd.set_option("display.max_columns", 500, "display.max_rows", 500)

In [None]:
# cuBLAS needs a fixed workspace configuration for reproducible results when
# deterministic algorithms are requested; it is set before any CUDA work runs.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=":4096:8")
# Project helper; with use_deterministic=True it presumably also switches
# torch to deterministic algorithms — confirm in src.utils.
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# Project root = parent directory of the (first) PYTHONPATH entry.
# PYTHONPATH may list several directories separated by os.pathsep; joining the
# raw value with ".." would produce a bogus path in that case, so only the
# first non-empty entry is used. An unset/empty PYTHONPATH is an error rather
# than silently resolving relative to the current working directory.
_pythonpath_entries = [
    entry for entry in os.environ.get("PYTHONPATH", "").split(os.pathsep) if entry
]
if not _pythonpath_entries:
    raise KeyError(
        "PYTHONPATH must be set; its first entry's parent directory is used as ROOT_DIR"
    )
ROOT_DIR = str((pathlib.Path(_pythonpath_entries[0]) / "..").resolve())
ROOT_DIR

In [None]:
# Cached SR-DA arrays (.npy) live in the shared workspace location when it
# exists; otherwise a notebook-local ./data directory is used (and created).
_SHARED_TMP_DATA_DIR = "/workspace/all_data/notebook/paper_experiment_01/data"
TMP_DATA_DIR = (
    _SHARED_TMP_DATA_DIR if os.path.exists(_SHARED_TMP_DATA_DIR) else "./data"
)
os.makedirs(TMP_DATA_DIR, exist_ok=True)

In [None]:
# Per-config error time series (CSV) are written here.
CSV_DATA_DIR = "./csv"
pathlib.Path(CSV_DATA_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Figures saved by `plot(..., write_out=True)` go here.
FIG_DIR = "./fig"
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# All YAML experiment configs of paper_experiment_01, in sorted (deterministic) order.
CONFIG_DIR = "/".join([ROOT_DIR, "pytorch", "config", "paper_experiment_01"])
CONFIG_PATHS = glob.glob(f"{CONFIG_DIR}/*.yml")
CONFIG_PATHS.sort()

In [None]:
CFD_DIR_NAME = "jet02"

# An assimilation (SR-DA analysis) happens every ASSIMILATION_PERIOD cycles;
# the simulation starts at cycle START_TIME_INDEX.
ASSIMILATION_PERIOD = 4
START_TIME_INDEX = 16

# Low-resolution (LR) grid; one cycle = LR_NT integration steps of size LR_DT.
LR_NX, LR_NY = 32, 17
LR_DT, LR_NT = 5e-4, 500

# High-resolution (HR) grid.
HR_NX, HR_NY = 128, 65

# Mean jet-forcing parameters (y0, sigma, tau0 of calc_jet_forcing).
Y0_MEAN = np.pi / 2.0
SIGMA_MEAN = 0.4
TAU0_MEAN = 0.3

# Vorticity-model coefficients. The LR values feed LR_CFD_CONFIG;
# HR_COEFF_DIFFUSION is not used in this notebook's visible cells.
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
HR_COEFF_DIFFUSION = 1e-5
LR_COEFF_DIFFUSION = 5e-5

# Model time advanced per cycle, and the model time of the first cycle.
DT = LR_DT * LR_NT
T0 = START_TIME_INDEX * DT

# Ensemble members processed together in one SR-DA chunk.
N_ENS_PER_CHUNK = 125

In [None]:
# This notebook requires a GPU: stop early (after logging) when none is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    logger.error("No GPU. CPU is used.")
    raise Exception("No GPU. CPU is used.")
logger.info("GPU is used.")

In [None]:
# Settings for the LR CFD models, passed to make_models in the SR-DA loop.
LR_CFD_CONFIG = dict(
    nx=LR_NX,
    ny=LR_NY,
    coeff_linear_drag=COEFF_LINEAR_DRAG,
    coeff_diffusion=LR_COEFF_DIFFUSION,
    order_diffusion=ORDER_DIFFUSION,
    beta=BETA,
    device=DEVICE,
)

# Grid/ensemble/period settings splatted as **kwargs into the src.sr_da_helper
# functions (initialize_models, get_observation_with_noise, make_preprocessed_*).
INDEX_CONFIG = dict(
    assimilation_period=ASSIMILATION_PERIOD,
    n_ens=N_ENS_PER_CHUNK,
    lr_nx=LR_NX,
    lr_ny=LR_NY,
    hr_nx=HR_NX,
    hr_ny=HR_NY,
    device=DEVICE,
)

In [None]:
# One entry per YAML config: the parsed config plus where its trained weights
# and learning history live. Results are looked up under the shared
# /workspace data directory first, then under ROOT_DIR/data.
CONFIGS = OrderedDict()

for config_path in CONFIG_PATHS:
    with open(config_path, encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # splitext strips only the final extension, so a dot inside the base name
    # is kept (split(".")[0] would truncate the name at the first dot).
    config_name = os.path.splitext(os.path.basename(config_path))[0]
    assert config_name not in CONFIGS, f"Duplicate config name: {config_name}"

    # Name of the directory containing the config file.
    experiment_name = os.path.basename(os.path.dirname(config_path))
    _dir = (
        f"/workspace/all_data/data/pytorch/DL_results/{experiment_name}/{config_name}"
    )
    if not os.path.exists(_dir):
        _dir = f"{ROOT_DIR}/data/pytorch/DL_results/{experiment_name}/{config_name}"

    CONFIGS[config_name] = {
        "config": config,
        "model_name": config["model"]["model_name"],
        "experiment_name": experiment_name,
        "weight_path": f"{_dir}/weights.pth",
        "learning_history_path": f"{_dir}/learning_history.csv",
    }

In [None]:
# Fraction of HR grid points carrying an observation, keyed by the config's
# `obs_grid_interval` (0 = no observations). Each value is approximately
# 1 / interval**2; the stored numbers deviate slightly from that and are kept
# verbatim.
OBS_GRID_RATIO = dict(
    zip(
        (0, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16),
        (
            0.0,  # 0
            0.06250000093132257,  # 4
            0.03999999910593033,  # 5
            0.027777777363856632,  # 6
            0.02040816326530612,  # 7
            0.015625000116415322,  # 8
            0.012345679127323775,  # 9
            0.010000000149011612,  # 10
            0.008264463206306716,  # 11
            0.006944444625534945,  # 12
            0.005917159876284691,  # 13
            0.005102040977882487,  # 14
            0.004444444572759999,  # 15
            0.003906250014551915,  # 16
        ),
    )
)

# Define methods

In [None]:
def _plot_mesh(nx: int, ny: int):
    """Return ``(x, y)`` coordinate arrays of shape ``(nx, ny)`` for pcolormesh.

    x spans [0, 2*pi) without its end point; y spans [0, pi] inclusive.
    """
    xs = np.linspace(0, 2 * np.pi, num=nx, endpoint=False)
    ys = np.linspace(0, np.pi, num=ny, endpoint=True)
    return np.meshgrid(xs, ys, indexing="ij")


def _to_plot_space(label: str, data: np.ndarray, use_hr_space: bool) -> np.ndarray:
    """Resample one panel's field onto the plotting grid and squeeze it.

    The "LR" field is interpolated (nearest) to the HR grid when plotting in HR
    space; every other field is interpolated to the LR grid when plotting in LR
    space. Fields already on the plotting grid are only squeezed.
    """
    if label == "LR":
        if use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]), nx=HR_NX, ny=HR_NY, mode="nearest"
            ).numpy()
    elif not use_hr_space:
        data = interpolate(torch.from_numpy(data[None, :]), nx=LR_NX, ny=LR_NY).numpy()
    return np.squeeze(data)


def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    figsize: tuple = (20, 2),
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    use_hr_space: bool = True,
    vmin_omega: float = -9,
    vmax_omega: float = 9,
):
    """Plot vorticity fields side by side, one panel per entry of ``dict_data``.

    Args:
        dict_data: Mapping from panel label to field. The "HR" entry is the
            ground truth; when present, every other panel's title shows its MAE
            against it (independent of dict order). "LR" marks the
            low-resolution field; all other entries are treated as HR fields.
        t: Time shown in the figure title.
        obs: Observation field on the HR grid. Its non-NaN points are drawn as
            black dots on the "HR" panel (HR space only).
        figsize: Figure size passed to ``plt.subplots``.
        write_out: If True, save the figure as JPG and EPS under ``FIG_DIR``.
        ttl_header: Prefix of the figure title.
        fig_file_name: Base file name (no extension) used when saving.
        use_hr_space: Plot on the HR grid (True) or on the LR grid (False).
        vmin_omega, vmax_omega: Colour-scale limits.

    Raises:
        AssertionError: If a resampled field does not match the plotting grid.
    """
    if use_hr_space:
        x, y = _plot_mesh(HR_NX, HR_NY)
        expected_shape = (HR_NX, HR_NY)
    else:
        x, y = _plot_mesh(LR_NX, LR_NY)
        expected_shape = (LR_NX, LR_NY)

    # Resample every panel up front so the ground truth is known before any
    # MAE is computed, whatever position "HR" has in dict_data.
    panels = {
        label: _to_plot_space(label, data, use_hr_space)
        for label, data in dict_data.items()
    }
    for label, field in panels.items():
        assert field.shape == expected_shape, (
            f"{label}: shape {field.shape} != {expected_shape}"
        )
    gt = panels.get("HR")

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 18
    # squeeze=False keeps `axes` 2-D, so a single panel is iterated like many.
    fig, axes = plt.subplots(
        1, len(panels), figsize=figsize, sharex=True, sharey=False, squeeze=False
    )

    ticks = [vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega]
    for ax, (label, field) in zip(axes.flat, panels.items()):
        if label == "HR":
            ttl = "HR Ground Truth"
        elif gt is not None:
            mae = np.mean(np.abs(gt - field))
            ttl = f"{label}\nMAE={mae:.2f}"
        else:
            ttl = label

        cnts = ax.pcolormesh(
            x, y, field, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )
        ax.set_title(ttl)
        fig.colorbar(cnts, ax=ax, ticks=ticks, extend="both")

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if label == "HR" and use_hr_space:
            # Mark observed (non-NaN) grid points on the ground-truth panel.
            obs_flat = np.squeeze(obs).flatten()
            observed = ~np.isnan(obs_flat)
            ax.scatter(
                x.flatten()[observed], y.flatten()[observed], marker=".", s=1, c="k"
            )

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)

    plt.suptitle(f"{ttl_header}Time = {t}")
    plt.tight_layout()

    if write_out:
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.jpg", dpi=300)
        fig.savefig(f"{FIG_DIR}/{fig_file_name}.eps", dpi=300)

    plt.show()


def get_all_tmp_file_paths(config_name):
    """Return the cached ``.npy`` file paths of one config under ``TMP_DATA_DIR``.

    Returns:
        Tuple ``(sr_analysis, lr_omega, hr_omega, hr_obsrv)`` of path strings.
    """
    kinds = ("sr_analysis", "lr_omega", "hr_omega", "hr_obsrv")
    return tuple(f"{TMP_DATA_DIR}/{kind}_{config_name}.npy" for kind in kinds)


def read_all_tmp_files(config_name):
    """Load the cached arrays of one config.

    Returns:
        Tuple ``(sr_analysis, lr_omega, hr_omega, hr_obsrv)`` of NumPy arrays,
        in the order of ``get_all_tmp_file_paths``.

    Raises:
        FileNotFoundError: If any cache file is missing.
    """
    return tuple(map(np.load, get_all_tmp_file_paths(config_name)))

# Plot learning curves

In [None]:
# By default this cell only checks that every learning history has one row per
# configured epoch. Set PLOT_LEARNING_CURVES = True to also draw the curves of
# the configs trained with LEARNING_CURVE_SEED.
PLOT_LEARNING_CURVES = False
LEARNING_CURVE_SEED = 221958

for config_name, config_info in CONFIGS.items():
    try:
        train_config = config_info["config"]["train"]
        df = pd.read_csv(config_info["learning_history_path"])
        assert len(df) == train_config["num_epochs"], (
            f"expected {train_config['num_epochs']} epochs, found {len(df)}"
        )

        # Filter on this config's own seed (not a `config` left over from
        # another cell).
        if not PLOT_LEARNING_CURVES or train_config["seed"] != LEARNING_CURVE_SEED:
            continue

        plt.rcParams["font.size"] = 15
        fig, ax = plt.subplots(figsize=[7, 5])

        df.plot(
            ax=ax,
            xlabel="Epochs",
            ylabel=train_config["loss"]["name"],
        )
        ax.set_title(config_name)
        ax.set_yscale("log")

        # fig.savefig(f"{FIG_DIR}/{config_name}_learning_curve.jpg")
        plt.show()
    except Exception as e:
        logger.error(f"{config_name}: {e}")

# Perform SR-DA

In [None]:
# For every config: run the LR model and the SR-DA cycle over the test
# ensemble, chunk by chunk, then cache HR observations, LR vorticity, SR
# analyses and the HR ground truth as .npy files (see get_all_tmp_file_paths).
# Configs whose four cache files already exist, or that have no trained
# weights, are skipped.
for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    # Re-seed per config so results do not depend on which configs were skipped
    # (e.g. the random stream presumably used for observation noise).
    set_seeds(42, use_deterministic=True)

    (
        sr_analysis_file_path,
        lr_omega_file_path,
        hr_omega_file_path,
        hr_obsrv_file_path,
    ) = get_all_tmp_file_paths(config_name)

    if (
        os.path.exists(sr_analysis_file_path)
        and os.path.exists(lr_omega_file_path)
        and os.path.exists(hr_omega_file_path)
        and os.path.exists(hr_obsrv_file_path)
    ):
        logger.info(f"Results already exist. So skip {config_name}")
        continue

    config_info = CONFIGS[config_name]
    weight_path = config_info["weight_path"]

    if not os.path.exists(weight_path):
        logger.info(f"No weight file for {config_name}")
        continue
    logger.info(f"{config_name} is being evaluated")

    config = config_info["config"]

    test_dataset = get_testdataset(ROOT_DIR, config)
    assert test_dataset.obs_time_interval == ASSIMILATION_PERIOD

    # Expected layout: (ensemble, time, HR_NX, HR_NY) with 81 time indices
    # (hard-coded); the ensemble size must split evenly into chunks.
    all_hr_omegas = read_all_hr_omegas_with_combining(test_dataset.hr_file_paths)
    assert all_hr_omegas.shape[1:] == (81, HR_NX, HR_NY)
    assert all_hr_omegas.shape[0] % N_ENS_PER_CHUNK == 0

    # Jet forcing with the mean parameters on the LR grid (single member, ne=1).
    _, lr_forcing = calc_jet_forcing(
        nx=LR_NX,
        ny=LR_NY,
        ne=1,
        y0=Y0_MEAN,
        sigma=SIGMA_MEAN,
        tau0=TAU0_MEAN,
    )

    all_hr_obsrv, all_sr_analysis, all_lr_omega = [], [], []

    for hr_omegas in tqdm(
        torch.split(all_hr_omegas, N_ENS_PER_CHUNK),
        total=(all_hr_omegas.shape[0] // N_ENS_PER_CHUNK),
    ):
        # Fresh models per chunk so no model state carries over between chunks.
        sr_model, lr_model, srda_model = make_models(config, weight_path, LR_CFD_CONFIG)

        # Both the plain LR model and the SR-DA model start from the HR field at
        # the first time index (presumably downsampled inside the helper — confirm).
        initialize_models(
            t0=T0,
            hr_omega0=hr_omegas[:, 0],  # first time index
            lr_forcing=lr_forcing,
            lr_model=lr_model,
            srda_model=srda_model,
            **INDEX_CONFIG,
        )

        hr_obsrv = get_observation_with_noise(hr_omegas, test_dataset, **INDEX_CONFIG)
        assert hr_obsrv.shape == hr_omegas.shape

        # Per-cycle histories. `ts` is collected but not used after this loop.
        ts, hr_obs, lr_omega, lr_forecast, sr_analysis = [], [], [], [], []
        last_omega0 = None
        max_time_index = hr_omegas.shape[1]

        for i_cycle in tqdm(range(max_time_index)):
            logger.debug(f"Start: i_cycle = {i_cycle}, t = {lr_model.t:.2f}")

            # Record the state at the start of this cycle (copied to CPU).
            ts.append(lr_model.t)
            lr_omega.append(lr_model.omega.cpu().clone())
            lr_forecast.append(srda_model.omega.cpu().clone())

            # Observations are available only every ASSIMILATION_PERIOD cycles
            # (and only if the config uses them); other cycles get an all-NaN
            # placeholder, so the cached hr_obsrv is mostly NaN.
            if config["data"]["use_observation"] and i_cycle % ASSIMILATION_PERIOD == 0:
                hr_obs.append(hr_obsrv[:, i_cycle])
            else:
                hr_obs.append(torch.full_like(hr_obsrv[:, i_cycle], torch.nan))

            # Analysis step at every assimilation time except the initial one.
            if i_cycle > 0 and i_cycle % ASSIMILATION_PERIOD == 0:
                # Network inputs built from the SR-DA forecast history (x) and
                # the observation history (o).
                x = make_preprocessed_lr(
                    lr_forecast,
                    last_omega0,
                    test_dataset,
                    **INDEX_CONFIG,
                )
                o = make_preprocessed_obs(
                    hr_obs,
                    test_dataset,
                    **INDEX_CONFIG,
                )

                # Observation-only variants: hide the LR forecast from the network.
                if not config["data"].get("use_lr_forecast", True):
                    x = torch.full_like(x, test_dataset.missing_value)

                with torch.no_grad():
                    sr = sr_model(x, o).detach().cpu().clone()
                sr = make_invprocessed_sr(
                    sr,
                    test_dataset,
                    **INDEX_CONFIG,
                )

                # `sr` is indexed time-first below (sr[it] per time step), so
                # sr[-1] is the latest analysis; bring it back to the LR grid and
                # restart the SR-DA model from it.
                last_omega0 = interpolate(sr[-1, :], nx=LR_NX, ny=LR_NY, mode="bicubic")
                srda_model.initialize(t0=srda_model.t, omega0=last_omega0)

                # After the first window, skip sr[0]; it presumably duplicates the
                # last frame appended by the previous window — confirm.
                i_start = 0 if len(sr_analysis) == 0 else 1
                for it in range(i_start, sr.shape[0]):
                    sr_analysis.append(sr[it])

                logger.debug(f"Assimilation at i = {i_cycle}")

            # Advance both models by one cycle (LR_NT steps of LR_DT).
            lr_model.time_integrate(dt=LR_DT, nt=LR_NT, hide_progress_bar=True)
            lr_model.calc_grid_data()

            srda_model.time_integrate(dt=LR_DT, nt=LR_NT, hide_progress_bar=True)
            srda_model.calc_grid_data()

        # Stack along time dim
        all_hr_obsrv.append(torch.stack(hr_obs, dim=1))
        all_lr_omega.append(torch.stack(lr_omega, dim=1))
        all_sr_analysis.append(torch.stack(sr_analysis, dim=1))

        # Release GPU memory before building the next chunk's models.
        del sr_model, lr_model, srda_model
        torch.cuda.empty_cache()
        _ = gc.collect()

    # Stack along ensemble (batch) dim
    all_hr_obsrv = torch.cat(all_hr_obsrv, dim=0).to(torch.float32).numpy()
    all_lr_omega = torch.cat(all_lr_omega, dim=0).to(torch.float32).numpy()
    all_sr_analysis = torch.cat(all_sr_analysis, dim=0).to(torch.float32).numpy()
    all_hr_omegas = all_hr_omegas.to(torch.float32).numpy()

    np.save(hr_obsrv_file_path, all_hr_obsrv)
    np.save(lr_omega_file_path, all_lr_omega)
    np.save(sr_analysis_file_path, all_sr_analysis)
    np.save(hr_omega_file_path, all_hr_omegas)

# Analyze errors

## Make csv

In [None]:
# Expected layout of the cached arrays: (ensemble, time, x, y) — 500 test
# ensemble members, 81 cycles, HR grid.
EXPECTED_OMEGA_SHAPE = (500, 81, HR_NX, HR_NY)


def _mae_time_series(pred, truth, truth_abs_mean):
    """Return per-time ensemble-mean MAE and ensemble-mean relative MAE.

    Args:
        pred, truth: Arrays of shape (ensemble, time, x, y).
        truth_abs_mean: Spatial mean of |truth|, shape (ensemble, time).

    Returns:
        Tuple ``(mae, mae_ratio)``, each of shape (time,).
    """
    err = np.mean(np.abs(pred - truth), axis=(-2, -1))  # mean over x and y
    return np.mean(err, axis=0), np.mean(err / truth_abs_mean, axis=0)


# One CSV per config with the LR and SR error time series. Existing CSVs are
# kept as-is, so delete them after regenerating the cached .npy arrays.
for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    try:
        csv_file = f"{CSV_DATA_DIR}/hr_err_time_series_{config_name}_with_mae_ratio.csv"

        if os.path.exists(csv_file):
            continue

        all_sr_analysis, all_lr_omega, all_hr_omega, _ = read_all_tmp_files(config_name)

        max_time_index = all_hr_omega.shape[1]
        ts = [DT * i_cycle + T0 for i_cycle in range(max_time_index)]

        # NOTE(review): the "(bicubic)" column names assume interpolate_time_series
        # interpolates bicubically by default — confirm in cfd_model.
        _lr = interpolate_time_series(
            torch.from_numpy(all_lr_omega), HR_NX, HR_NY
        ).numpy()
        assert _lr.shape == all_hr_omega.shape == EXPECTED_OMEGA_SHAPE
        assert all_sr_analysis.shape == all_hr_omega.shape == EXPECTED_OMEGA_SHAPE

        # Denominator of the relative error, computed once for both LR and SR.
        hr_abs_mean = np.mean(np.abs(all_hr_omega), axis=(-2, -1))
        lr, lr_ratio = _mae_time_series(_lr, all_hr_omega, hr_abs_mean)
        sr, sr_ratio = _mae_time_series(all_sr_analysis, all_hr_omega, hr_abs_mean)

        df = pd.DataFrame(
            {
                "Time": ts,
                "ErrLR(bicubic)": lr,
                "ErrRatioLR(bicubic)": lr_ratio,
                "ErrSR": sr,
                "ErrRatioSR": sr_ratio,
            }
        )
        df.to_csv(csv_file, index=False)
    except Exception as e:
        # Best-effort per config, but say which config failed.
        logger.error(f"{config_name}: {e}")

## Analyze csv

In [None]:
# Summary table: one row per config with SR error statistics and the config
# settings used to group/filter in the figures below. Configs without a CSV
# keep an all-NaN row.
df_results = pd.DataFrame(index=CONFIGS.keys())

for config_name in tqdm(CONFIGS.keys(), total=len(CONFIGS)):
    csv_file = f"{CSV_DATA_DIR}/hr_err_time_series_{config_name}_with_mae_ratio.csv"
    if not os.path.exists(csv_file):
        logger.error(f"{csv_file} does not exist!")
        continue

    df = pd.read_csv(csv_file)

    config = CONFIGS[config_name]["config"]

    # Statistics of the SR error time series.
    err_sr = df["ErrSR"]
    error_stats = {
        "AveErrSR": err_sr.mean(),
        "MaxErrSR": err_sr.max(),
        "MinErrSR": err_sr.min(),
        "StdErrSR": err_sr.std(),
    }
    error_stats["Max-MinErrSR"] = error_stats["MaxErrSR"] - error_stats["MinErrSR"]
    for column, value in error_stats.items():
        df_results.loc[config_name, column] = value

    # Config settings; ObsGridRatio is stored in percent.
    data_cfg = config["data"]
    settings = {
        "UseObs": data_cfg["use_observation"],
        "ObsGridInterval": data_cfg["obs_grid_interval"],
        "ObsGridRatio": OBS_GRID_RATIO[data_cfg["obs_grid_interval"]] * 100,
        "ObsNoiseStd": data_cfg["obs_noise_std"],
        "LrTimeInterval": data_cfg["lr_time_interval"],
        "UseSkipConn": config["model"]["use_global_skip_connection"],
        "UseMixup": data_cfg["use_mixup"],
        "alpha": data_cfg["beta_dist_alpha"],
        "beta": data_cfg["beta_dist_beta"],
        "UseLrForecast": data_cfg.get("use_lr_forecast", True),
        "Seed": config["train"]["seed"],
    }
    for column, value in settings.items():
        df_results.loc[config_name, column] = value

In [None]:
# Configs grouped by observation ratio, best (lowest max error) first.
# ObsGridRatio is already stored in percent (OBS_GRID_RATIO * 100), so it is
# printed as-is rather than scaled by 100 a second time.
for grid_ratio, grp in df_results.groupby("ObsGridRatio"):
    display(f"Grid Ratio = {grid_ratio:.2f} %")
    display(grp.sort_values("MaxErrSR"))

# Make figures

## Error vs. observation grid ratio, with and without mixup

In [None]:
def _plot_min_max_errorbar(ax, df, label, ycol, n_ratios=5, n_per_ratio=5):
    """Plot the mean of ``ycol`` per ObsGridRatio with min/max error bars.

    The counts are checked against ``n_ratios`` distinct ratios with
    ``n_per_ratio`` configs each (presumably one per training seed).
    """
    stats = df.groupby("ObsGridRatio")[ycol].agg(func=["mean", "count", "min", "max"])
    ys = stats["mean"]
    errs = np.array([ys - stats["min"], stats["max"] - ys])
    assert len(ys) == n_ratios and set(stats["count"]) == {n_per_ratio}

    ax.errorbar(stats.index, ys, yerr=errs, fmt="o-", label=label, capsize=5)


plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 20
ycol = "AveErrSR"

fig, ax = plt.subplots()

df = df_results.sort_values("ObsGridRatio")
# `== True` (not truthiness) because these columns may hold NaN for configs
# without results.
df = df[(df["UseObs"] == True) & (df["UseLrForecast"] == True)]  # noqa: E712

_plot_min_max_errorbar(ax, df[df["UseMixup"] == False], "No Mixup", ycol)  # noqa: E712
_plot_min_max_errorbar(
    ax, df[(df["UseMixup"] == True) & (df["beta"] == 2)], "Use Mixup", ycol  # noqa: E712
)

ax.set_ylabel("Ave MAE")
ax.set_xlabel("Observation point ratio [%]")
ax.legend()

plt.show()

## Observations with vs. without the LR forecast (observation-only models)

In [None]:
# Per seed: time-averaged SR error vs observation ratio for models that use
# both observations and the LR forecast, and for their observation-only
# (`_noLR`) counterparts.
ycol = "AveErrSR"

for seed in [221958, 771155, 832180, 465838, 359178]:
    xs, ys, ys_no_lr = [], [], []
    for obs_grid_interval in range(4, 14, 2):
        config_name_using_lr = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd{seed:06}"
        config_name_no_lr = config_name_using_lr.replace("_muT_", "_muF_") + "_noLR"
        try:
            err = df_results.loc[config_name_using_lr, ycol]
            err_no_lr = df_results.loc[config_name_no_lr, ycol]
        except KeyError as e:
            logger.error(
                f"seed={seed}, obs_grid_interval={obs_grid_interval}: missing {e}"
            )
            continue

        # OBS_GRID_RATIO holds fractions; convert to percent to match the x label.
        xs.append(OBS_GRID_RATIO[obs_grid_interval] * 100)
        ys.append(err)
        ys_no_lr.append(err_no_lr)

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 20

    fig, ax = plt.subplots()

    ax.plot(xs, ys, "o-", label="Using Obs and LR")
    ax.plot(xs, ys_no_lr, "o-", label="Using Obs but No LR")

    ax.set_ylabel("MAE (time ave.)")
    ax.set_xlabel("Observation grid ratio [%]")
    ax.set_title(f"Seed = {seed}")

    ax.legend()
    plt.show()

## Error time series

In [None]:
obs_grid_interval = 12


def _read_err_csv(config_name):
    """Load the per-time error CSV written for ``config_name``."""
    return pd.read_csv(
        f"{CSV_DATA_DIR}/hr_err_time_series_{config_name}_with_mae_ratio.csv"
    )


# Per seed: SR error over time for SR-DA, SR without observations,
# observation-only, and the plain LR model (no SR, no DA).
for seed in [221958, 771155, 832180, 465838, 359178]:
    config_name_without_obs = (
        f"lt4og00_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd{seed:06}"
    )
    config_name_using_obs = config_name_without_obs.replace(
        "og00", f"og{obs_grid_interval:02}"
    )
    config_name_only_obs = config_name_using_obs.replace("muT", "muF") + "_noLR"

    df_without_obs = _read_err_csv(config_name_without_obs)
    df_using_obs = _read_err_csv(config_name_using_obs)
    df_only_obs = _read_err_csv(config_name_only_obs)

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 20

    fig, ax = plt.subplots(figsize=[8, 4])

    ax.plot(df_using_obs["Time"], df_using_obs["ErrSR"], "-", label="SRDA")
    ax.plot(df_without_obs["Time"], df_without_obs["ErrSR"], "--", label="Only SR")
    ax.plot(df_only_obs["Time"], df_only_obs["ErrSR"], "--", label="Only Obs")

    ax.plot(
        df_without_obs["Time"],
        df_without_obs["ErrLR(bicubic)"],
        "-.",
        label="No SR or DA",
    )

    # A single combined title: a second set_title call would replace the first.
    ax.set_title(
        f"Seed = {seed}\n"
        f"Observation point ratio = {OBS_GRID_RATIO[obs_grid_interval]*100:.2f} %"
    )
    ax.set_xlabel("Time")
    ax.set_ylabel("MAE")

    ax.legend(
        bbox_to_anchor=(1.05, 1.0),
        loc="upper left",
        ncol=1,
        fontsize=16,
        framealpha=1,
        edgecolor="k",
    )
    plt.tight_layout()
    plt.show()

## Vorticity snapshots

In [None]:
# NOTE(review): the two configs below differ only in `_muT_` vs `_muF_`
# (presumably mixup on/off) and share the observation noise `on1e-01`, so the
# `with_noise` / `without_noise` names (and the "SRDA(noise)" labels) look
# misleading — confirm before publishing figures.
target_config_name = "lt4og12_on1e-01_ep1000_lr1e-04_scT_muT_a02_b02_sd221958"
(
    all_sr_analysis_with_noise,
    all_lr_omega,
    all_hr_omega,
    all_hr_obsrv,
) = read_all_tmp_files(target_config_name)

# Same config with `_muF_`; only its SR analyses are needed.
target_config_name = target_config_name.replace("_muT_", "_muF_")
all_sr_analysis_without_noise, _, _, _ = read_all_tmp_files(target_config_name)

In [None]:
# Vorticity snapshots for selected ensemble members at two cycles.
for i_ensemble in (43, 46, 48, 70, 91, 107, 154, 156, 163, 185):
    for i_cycle in (40, 80):
        # Model time of this cycle (cycles are DT apart, starting at START_TIME_INDEX).
        t = DT * (START_TIME_INDEX + i_cycle)

        sources = (
            ("HR", all_hr_omega),
            ("LR", all_lr_omega),
            ("SRDA(noise)", all_sr_analysis_with_noise),
            ("SRDA(no noise)", all_sr_analysis_without_noise),
        )
        dict_data = {label: arr[i_ensemble, i_cycle] for label, arr in sources}

        plot(
            dict_data,
            t,
            obs=all_hr_obsrv[i_ensemble, i_cycle],
            figsize=[12, 4.0],
            ttl_header=f"i_ens = {i_ensemble}, ",
            fig_file_name=f"snapshot_use_obs_{i_ensemble:03}_t_{int(t):03}",
            use_hr_space=True,
            vmin_omega=-10,
            vmax_omega=10,
            write_out=False,
        )