In [None]:
# Reload edited modules automatically before each cell is executed, and render
# matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Configure the root logger at INFO level, attaching a stdout handler only once
# so re-running this cell does not duplicate every log line.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

# Import libraries

In [None]:
import gc
import glob
import os
import pathlib
import re
import time
from collections import OrderedDict

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from IPython.display import display
from src.ssim import SSIM
from src.utils import read_pickle, set_seeds
from tqdm.notebook import tqdm

# Show up to 500 columns and rows when displaying DataFrames.
pd.set_option("display.max_columns", 500, "display.max_rows", 500)
# Serif font for figures (later cells may adjust font settings further).
plt.rcParams["font.family"] = "serif"

In [None]:
# cuBLAS workspace setting required to make calculations deterministic,
# followed by seeding with deterministic algorithms enabled.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=r":4096:8")
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# Project root: the (resolved) parent of the directory named by PYTHONPATH.
ROOT_DIR = str(pathlib.Path(os.environ["PYTHONPATH"]).joinpath("..").resolve())
ROOT_DIR

In [None]:
DEVICE = "cuda:1"
# The SSIM evaluations below run on DEVICE, so fail fast when it is unusable.
if not torch.cuda.is_available():
    raise Exception("No GPU is available; this notebook requires CUDA.")
# DEVICE pins a specific GPU index; make sure that GPU actually exists too.
if (torch.device(DEVICE).index or 0) >= torch.cuda.device_count():
    raise Exception(
        f"{DEVICE} is not available: only {torch.cuda.device_count()} GPU(s) found."
    )

In [None]:
# Directory holding the cached intermediate .npz files read by read_all_tmp_files.
TMP_DATA_DIR = "./data"
pathlib.Path(TMP_DATA_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Directory for the error time-series CSV files written and read below.
CSV_DATA_DIR = "./csv"
pathlib.Path(CSV_DATA_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# FIG_DIR = f"{ROOT_DIR}/doc/james_1st_rev/fig"
# os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
CONFIG_DIR = f"{ROOT_DIR}/pytorch/config/paper_experiment_06"
# All experiment YAML configs, in a stable (sorted) order.
CONFIG_PATHS = glob.glob(f"{CONFIG_DIR}/*.yml")
CONFIG_PATHS.sort()

In [None]:
# Root of the saved EnKF results, under the paper_experiment_02 notebook's data directory.
ENKF_ROOT_DIR = "/".join([ROOT_DIR, "pytorch/notebook/paper_experiment_02/data/EnKF"])

In [None]:
ASSIMILATION_PERIOD = 4
FORECAST_SPAN = 4
START_TIME_INDEX = 0
MAX_TIME_INDEX_FOR_INTEGRATION = 96

# Low-resolution (LR) grid size and time stepping.
LR_NX, LR_NY = 32, 17
LR_DT, LR_NT = 5e-4, 500

# High-resolution (HR) grid size.
HR_NX, HR_NY = 128, 65

Y0_MEAN, SIGMA_MEAN, TAU0_MEAN = np.pi / 2.0, 0.4, 0.3

BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
HR_COEFF_DIFFUSION, LR_COEFF_DIFFUSION = 1e-5, 5e-5

# Time between consecutive time indices, and the time at START_TIME_INDEX.
DT = LR_DT * LR_NT
T0 = START_TIME_INDEX * LR_DT * LR_NT

N_ENS_PER_CHUNK = 125

In [None]:
# Map each config name (YAML file name without extension) to its parsed
# config and the paths of the trained weights and learning history.
CONFIGS = OrderedDict()

for config_path in CONFIG_PATHS:
    with open(config_path, encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # splitext strips only the extension, so names containing dots stay intact
    # (split(".")[0] would truncate them and could create false duplicates).
    config_name = os.path.splitext(os.path.basename(config_path))[0]
    assert config_name not in CONFIGS, f"Duplicate config name: {config_name}"

    # Name of the directory containing the config file.
    experiment_name = os.path.basename(os.path.dirname(config_path))
    _dir = f"{ROOT_DIR}/data/pytorch/DL_results/{experiment_name}/{config_name}"

    CONFIGS[config_name] = {
        "config": config,
        "model_name": config["model"]["model_name"],
        "experiment_name": experiment_name,
        "weight_path": f"{_dir}/weights.pth",
        "learning_history_path": f"{_dir}/learning_history.csv",
    }

In [None]:
# Observation point ratio (a fraction; multiplied by 100 for the "%" plot labels),
# keyed by observation grid interval. The values are close to 1 / interval**2 but
# appear to carry float32 rounding, so they are kept as exact literals.
OBS_GRID_RATIO = dict(
    [
        (0, 0.0),
        (4, 0.06250000093132257),
        (5, 0.03999999910593033),
        (6, 0.027777777363856632),
        (7, 0.02040816326530612),
        (8, 0.015625000116415322),
        (9, 0.012345679127323775),
        (10, 0.010000000149011612),
        (11, 0.008264463206306716),
        (12, 0.006944444625534945),
        (13, 0.005917159876284691),
        (14, 0.005102040977882487),
        (15, 0.004444444572759999),
        (16, 0.003906250014551915),
    ]
)

# Define methods

In [None]:
def _to_plot_space(label: str, data: np.ndarray, use_hr_space: bool) -> np.ndarray:
    """Return ``data`` as a 2-D field on the grid selected by ``use_hr_space``.

    Labels containing ``"LR"`` are treated as low-resolution fields and are
    upsampled with nearest-neighbour interpolation when plotting in HR space.
    All other labels are treated as high-resolution fields and are downsampled
    with ``interpolate``'s default mode when plotting in LR space.
    """
    if "LR" in label:
        if use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]), nx=HR_NX, ny=HR_NY, mode="nearest"
            ).numpy()
    elif not use_hr_space:
        data = interpolate(
            torch.from_numpy(data[None, :]), nx=LR_NX, ny=LR_NY
        ).numpy()

    field = np.squeeze(data)
    expected_shape = (HR_NX, HR_NY) if use_hr_space else (LR_NX, LR_NY)
    assert field.shape == expected_shape, f"{label}: {field.shape} != {expected_shape}"
    return field


def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    figsize: tuple = (20, 2),
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    use_hr_space: bool = True,
    vmin_omega: float = -9,
    vmax_omega: float = 9,
):
    """Plot vorticity snapshots side by side, one panel per entry of ``dict_data``.

    The entry keyed ``"HR"`` is the ground truth. Every other panel is titled
    with its MAER: the mean absolute error against the ground truth divided by
    the mean absolute ground truth. All fields are brought onto a common grid
    (HR or LR) before plotting.

    Args:
        dict_data: Mapping from panel label to a field (singleton axes are
            squeezed away). Must contain the key ``"HR"``; it may be in any position.
        t: Time shown in the figure title.
        obs: Observation field on the HR grid. Its non-NaN points are scattered
            on the ground-truth panel when ``use_hr_space`` is True.
        figsize: Figure size passed to ``plt.subplots``.
        write_out: Currently unused (saving is commented out).
        ttl_header: Prefix of the figure title.
        fig_file_name: Currently unused (saving is commented out).
        use_hr_space: Plot on the HR grid (True) or the LR grid (False).
        vmin_omega: Lower colour-scale limit.
        vmax_omega: Upper colour-scale limit.

    Raises:
        ValueError: If ``dict_data`` has no ``"HR"`` entry.
    """
    if "HR" not in dict_data:
        raise ValueError('dict_data must contain the ground-truth entry "HR"')

    nx, ny = (HR_NX, HR_NY) if use_hr_space else (LR_NX, LR_NY)
    x, y = np.meshgrid(
        np.linspace(0, 2 * np.pi, num=nx, endpoint=False),
        np.linspace(0, np.pi, num=ny, endpoint=True),
        indexing="ij",
    )

    # NOTE: these rcParams are set globally and persist after this call.
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 18
    # squeeze=False keeps `axes` a 2-D array even for a single panel, so
    # `axes.flat` always yields Axes objects.
    fig, axes = plt.subplots(
        1, len(dict_data), figsize=figsize, sharex=True, sharey=False, squeeze=False
    )

    # Prepare the ground truth first so MAER can be computed for every panel,
    # whatever position "HR" has in dict_data.
    gt = _to_plot_space("HR", dict_data["HR"], use_hr_space)

    for ax, (label, data) in zip(axes.flat, dict_data.items()):
        if label == "HR":
            d = gt
            ttl = "HR Ground Truth"
        else:
            d = _to_plot_space(label, data, use_hr_space)
            maer = np.mean(np.abs(gt - d)) / np.mean(np.abs(gt))
            ttl = f"{label}\nMAER={maer:.2f}"

        cnts = ax.pcolormesh(
            x, y, d, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )
        ax.set_title(ttl)

        fig.colorbar(
            cnts,
            ax=ax,
            ticks=[vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega],
            extend="both",
        )

        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        # Mark observed (non-NaN) points on the ground-truth panel. `obs` is
        # flattened against the HR grid, so this is only done in HR space.
        if label == "HR" and use_hr_space:
            obs_flat = np.squeeze(obs).flatten()
            is_observed = ~np.isnan(obs_flat)
            ax.scatter(
                x.flatten()[is_observed],
                y.flatten()[is_observed],
                marker=".",
                s=1,
                c="k",
            )

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig.suptitle(f"{ttl_header}Time = {t}")
    fig.tight_layout()

    # if write_out:
    #     fig.savefig(f"{FIG_DIR}/{fig_file_name}.jpg")

    plt.show()


def read_all_tmp_files(config_name, data_dir=None):
    """Load the cached arrays for one config from ``{data_dir}/{config_name}_with_init_noise.npz``.

    Args:
        config_name: Config name used as the file-name prefix.
        data_dir: Directory containing the .npz file; defaults to TMP_DATA_DIR.

    Returns:
        Tuple ``(sr_frcst, lr_omega, hr_omega, hr_obsrv)`` of arrays, in that order.
    """
    if data_dir is None:
        data_dir = TMP_DATA_DIR
    npz_file_path = f"{data_dir}/{config_name}_with_init_noise.npz"

    # np.load on an .npz returns a lazily-reading NpzFile that holds the file
    # open; indexing reads each array fully, and the context manager then closes it.
    with np.load(npz_file_path) as data:
        return data["sr_frcst"], data["lr_omega"], data["hr_omega"], data["hr_obsrv"]

# Vorticity snapshots

In [None]:
obs_grid_interval: int = 8  # observation grid interval used to pick the snapshot config below

In [None]:
# Config (seed 832180) whose cached SRDA/LR/HR arrays are used for the snapshots.
target_config_name = (
    f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd832180"
)

(
    all_sr_forecast,
    all_lr_omega,
    all_hr_omega,
    all_hr_obsrv,
) = read_all_tmp_files(target_config_name)

enkf_dir = f"{ENKF_ROOT_DIR}/{target_config_name}_with_init_noise"
# enkf_dir is used as a directory of result files below, so require exactly that.
assert os.path.isdir(enkf_dir), f"EnKF result directory not found: {enkf_dir}"

In [None]:
# Snapshot comparison for one test case: `i_ensemble` indexes the first axis of
# the cached arrays and selects the matching EnKF result files.
for i_ensemble in [185]:

    enkf_lr = np.load(f"{enkf_dir}/ens_all_lr_{i_ensemble:04}_with_init_noise.npy")
    enkf_lr = np.mean(enkf_lr, axis=0)  # ensemble mean
    # Check against the grid constants instead of hard-coded 32 x 17.
    assert enkf_lr.shape == (MAX_TIME_INDEX_FOR_INTEGRATION, LR_NX, LR_NY)

    enkf_hr = read_pickle(
        f"{enkf_dir}/ens_mean_hr_{i_ensemble:04}_with_init_noise.pickle"
    )

    for i_cycle in [28, 56, 92]:
        t = (i_cycle + START_TIME_INDEX) * DT

        dict_data = {
            "HR": all_hr_omega[i_ensemble, i_cycle],
            "LR": all_lr_omega[i_ensemble, i_cycle],
            "EnKF(LR)": enkf_lr[i_cycle],
            "EnKF(HR)": enkf_hr[i_cycle],
            "SRDA": all_sr_forecast[i_ensemble, i_cycle],
        }

        plot(
            dict_data,
            t,
            obs=all_hr_obsrv[i_ensemble, i_cycle],
            figsize=[24, 4.0],
            ttl_header=f"i_ens = {i_ensemble}, ",
            fig_file_name=f"snapshot_srda_vs_enkf_t{int(t):03}",
            write_out=True,
            use_hr_space=False,
            vmin_omega=-10,
            vmax_omega=10,
        )

# Calculate time series of errors

In [None]:
# Per-sample SSIM evaluators (no averaging): Gaussian window first, then uniform window.
ssim_func_gauss, ssim_func_unfrm = (
    SSIM(size_average=False, use_gauss=use_gauss) for use_gauss in (True, False)
)

In [None]:
intervals = np.arange(4, 14, 2)

# For every (seed, observation grid interval), compute batch-averaged error time
# series of LR, SRDA and EnKF against the ground truth in HR and LR space, and
# cache them as a CSV (skipped when the CSV already exists).
for is_half in [False]:
    for seed in tqdm([221958, 771155, 832180, 465838, 359178]):
        for obs_grid_interval in tqdm(intervals, total=len(intervals)):

            config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{seed:06}"

            csv_path = f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}_with_init_noise.csv"
            if is_half:
                csv_path = f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}_with_init_noise_half_test_data.csv"

            if os.path.exists(csv_path):
                logger.info(f"CSV already exists. {config_name}")
                continue

            (
                all_sr_forecast,
                all_lr_omega,
                all_hr_omega,
                _,
            ) = read_all_tmp_files(config_name)

            num_batch_indices = all_hr_omega.shape[0]
            if is_half:
                num_batch_indices = num_batch_indices // 2
            logger.info(f"num_batch_indices = {num_batch_indices}")

            num_time_indices = MAX_TIME_INDEX_FOR_INTEGRATION
            logger.info(f"num_time_indices = {num_time_indices}")

            # EnKF results are always read from the sd832180 directory, whatever
            # the SRDA seed (presumably EnKF does not depend on it — confirm).
            enkf_dir = f"{ENKF_ROOT_DIR}/{config_name.replace(f'_sd{seed:06}', '_sd832180')}_with_init_noise"

            all_enkf_mean = [
                np.load(f"{enkf_dir}/ens_all_lr_{i:04}_with_init_noise.npy")
                for i in range(num_batch_indices)
            ]
            all_enkf_mean = np.stack(all_enkf_mean, axis=0)  # stack along batch dim
            all_enkf_mean = np.mean(all_enkf_mean, axis=1)  # mean over ensemble dim

            ts = [LR_DT * LR_NT * i_cycle + T0 for i_cycle in range(num_time_indices)]

            df_results = pd.DataFrame()
            df_results["Time"] = ts[:num_time_indices]

            for kind in ["HR", "LR"]:
                gt = all_hr_omega[:num_batch_indices, :num_time_indices]
                lr = all_lr_omega[:num_batch_indices, :num_time_indices]
                srda = all_sr_forecast[:num_batch_indices, :num_time_indices]
                enkf = all_enkf_mean[:num_batch_indices, :num_time_indices]

                # Bring every field onto the grid of `kind` before comparing.
                if kind == "HR":
                    lr = interpolate_time_series(
                        torch.from_numpy(lr), HR_NX, HR_NY
                    ).numpy()
                    enkf = interpolate_time_series(
                        torch.from_numpy(enkf), HR_NX, HR_NY
                    ).numpy()
                elif kind == "LR":
                    gt = interpolate_time_series(
                        torch.from_numpy(gt), LR_NX, LR_NY
                    ).numpy()
                    srda = interpolate_time_series(
                        torch.from_numpy(srda), LR_NX, LR_NY
                    ).numpy()
                else:
                    raise NotImplementedError()

                # MAER normaliser depends only on the ground truth: compute it
                # once per resolution instead of once per label.
                nrms = np.mean(np.abs(gt), axis=(-2, -1))

                for label, data in zip(["LR", "SRDA", "EnKF"], [lr, srda, enkf]):
                    assert gt.shape == data.shape

                    errs = np.mean(
                        np.abs(gt - data), axis=(-2, -1)
                    )  # mean over x and y

                    mae = np.mean(errs, axis=0)  # mean over batch
                    maer = np.mean(errs / nrms, axis=0)

                    # Only time dim remains.
                    assert mae.shape == maer.shape == (num_time_indices,)

                    df_results[f"MAE_{kind}_{label}"] = mae
                    df_results[f"MAER_{kind}_{label}"] = maer

                    for ssim_kind, ssim_func in zip(
                        ["Gauss", "Uniform"], [ssim_func_gauss, ssim_func_unfrm]
                    ):

                        ssim = ssim_func(
                            img1=torch.from_numpy(gt).to(DEVICE),
                            img2=torch.from_numpy(data).to(DEVICE),
                        )
                        ssim = torch.mean(
                            ssim, dim=(0, 2, 3)
                        )  # mean over batch, x, and y
                        # Convert to NumPy explicitly: assigning a torch tensor to a
                        # DataFrame column relies on pandas coercion that older
                        # versions do not perform (yielding an object column).
                        ssim = ssim.detach().cpu().numpy()

                        # Only time dim remains.
                        assert ssim.shape == (num_time_indices,)

                        # Stored as 1 - SSIM under the "SSIM_*" column name.
                        df_results[f"SSIM_{ssim_kind}_{kind}_{label}"] = 1 - ssim

            df_results.to_csv(csv_path, index=False)

# Plot error time series

In [None]:
# Per resolution and observation grid interval: the mean over SRDA seeds of each
# error time series, with the min/max spread over seeds as error bars for SRDA.
# "SSIM_*" columns in these CSVs hold 1 - SSIM.
seeds = [221958, 771155, 832180, 465838, 359178]
intervals = np.arange(4, 14, 2)

for resolution in ["HR", "LR"]:
    for obs_grid_interval in tqdm(intervals, total=len(intervals)):

        dict_dfs = {}
        for seed in seeds:
            config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{seed:06}"
            csv_path = f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}_with_init_noise.csv"
            dict_dfs[seed] = pd.read_csv(csv_path)

        grid_ratio = OBS_GRID_RATIO[obs_grid_interval] * 100

        fig, axes = plt.subplots(1, 3, sharex=True, figsize=[16, 3])

        for ax, ycol in zip(axes, ["MAER", "SSIM_Gauss", "SSIM_Uniform"]):
            xs = None

            for ls, label in zip(["-.", "-", "--"], ["LR", "SRDA", "EnKF"]):
                all_ys = []
                for df in dict_dfs.values():
                    all_ys.append(df[f"{ycol}_{resolution}_{label}"].values)
                    # Every CSV must share the same time axis.
                    if xs is None:
                        xs = df["Time"].values
                    else:
                        np.testing.assert_array_equal(xs, df["Time"].values)

                all_ys = np.stack(all_ys)
                # axis = seed, time; use the number of CSVs read, not a hard-coded 5.
                assert all_ys.shape == (len(dict_dfs), MAX_TIME_INDEX_FOR_INTEGRATION)

                ys = np.mean(all_ys, axis=0)
                diffs = all_ys - ys
                # Asymmetric error bars: distances from the mean down to the min
                # and up to the max over seeds.
                min_diffs = -np.min(diffs, axis=0)
                max_diffs = np.max(diffs, axis=0)
                errs = np.stack([min_diffs, max_diffs])
                assert errs.shape == (2, MAX_TIME_INDEX_FOR_INTEGRATION)
                assert np.nanmax(errs) >= 0

                if label == "SRDA":
                    ax.errorbar(xs, ys, yerr=errs, label=label, ls=ls, capsize=2)
                else:
                    ax.plot(xs, ys, label=label, ls=ls)

            ax.set_xlabel("Time")
            ax.set_ylabel(f"{ycol} ({resolution})")
            ax.set_title(f"Observation grid ratio: {grid_ratio:.2f} %")
            ax.legend()
            ax.set_xticks(np.linspace(4, 24, 6))

            if ycol == "MAER":
                ax.set_ylim(0, 1)
            else:
                ax.set_ylim(0.0, 0.4)

        plt.tight_layout()

        # if obs_grid_interval == 8:
        #     fig.savefig(f"{FIG_DIR}/err_time_series_enkf_vs_srda_{resolution}.jpg")

        plt.show()

# Dependency on observation grid ratio

In [None]:
# Time-averaged error vs observation point ratio: the mean over SRDA seeds, with
# the min/max spread over seeds shown as error bars for SRDA.
seeds = [221958, 771155, 832180, 465838, 359178]
intervals = np.arange(4, 14, 2)
ycols = ["MAER", "SSIM_Gauss"]
labels = ["SRDA", "EnKF"]

for resolution in ["HR", "LR"]:
    xs = []
    ys = {ycol: {label: [] for label in labels} for ycol in ycols}
    y_min_errs = {ycol: {label: [] for label in labels} for ycol in ycols}
    y_max_errs = {ycol: {label: [] for label in labels} for ycol in ycols}

    for obs_grid_interval in tqdm(intervals, total=len(intervals)):
        grid_ratio = OBS_GRID_RATIO[obs_grid_interval] * 100
        xs.append(grid_ratio)

        dict_dfs = {}
        for seed in seeds:
            config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{seed:06}"
            csv_path = f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}_with_init_noise.csv"
            dict_dfs[seed] = pd.read_csv(csv_path)

        for ycol in ycols:
            for label in labels:
                # Time-mean of the metric, one value per seed.
                _ys = np.array(
                    [df[f"{ycol}_{resolution}_{label}"].mean() for df in dict_dfs.values()]
                )
                # One entry per CSV read, not a hard-coded 5.
                assert _ys.shape == (len(dict_dfs),)

                y = np.mean(_ys)
                ys[ycol][label].append(y)
                y_min_errs[ycol][label].append(-np.min(_ys - y))
                y_max_errs[ycol][label].append(np.max(_ys - y))

    fig, axes = plt.subplots(1, 2, sharex=True, figsize=[12, 4])

    for ax, ycol in zip(axes, ycols):
        for label, _ys in ys[ycol].items():
            if label == "SRDA":
                errs = np.array([y_min_errs[ycol][label], y_max_errs[ycol][label]])
                ax.errorbar(xs, _ys, yerr=errs, fmt="o-", label=label, capsize=5)
            else:
                ax.plot(xs, _ys, "o-", label=label)

        ax.set_xlabel("Observation point ratio [%]")
        ax.set_ylabel(f"{ycol} ({resolution})")
        ax.legend()

        if ycol == "MAER":
            ax.set_ylim(0.08, 0.4)
            ax.set_yticks(np.linspace(0.08, 0.4, 5))
        else:
            ax.set_ylim(0.00, 0.16)
            ax.set_yticks(np.linspace(0.00, 0.16, 5))

    plt.tight_layout()
    # fig.savefig(f"{FIG_DIR}/obs_point_vs_error_for_enkf_srda_{resolution}.jpg")
    plt.show()

# Calculate time series of errors for all simulations

In [None]:
# Re-create the per-sample SSIM evaluators (Gaussian window, then uniform window).
ssim_func_gauss, ssim_func_unfrm = (
    SSIM(size_average=False, use_gauss=True),
    SSIM(size_average=False, use_gauss=False),
)

In [None]:
intervals = np.arange(4, 14, 2)

# Like the aggregated cell above, but keep one column per simulation (batch
# index) instead of averaging over the batch. Cached as a CSV per config.
for is_half in [False]:
    for seed in tqdm([221958, 771155, 832180, 465838, 359178]):
        for obs_grid_interval in tqdm(intervals, total=len(intervals)):

            config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{seed:06}"

            csv_path = f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}_with_init_noise_all_simulations.csv"
            if is_half:
                csv_path = f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}_with_init_noise_half_test_data_all_simulations.csv"

            if os.path.exists(csv_path):
                logger.info(f"CSV already exists. {config_name}")
                continue

            (
                all_sr_forecast,
                _,
                all_hr_omega,
                _,
            ) = read_all_tmp_files(config_name)
            # Drop the unused arrays straight away to limit peak memory.
            del _
            gc.collect()

            num_batch_indices = all_hr_omega.shape[0]
            if is_half:
                num_batch_indices = num_batch_indices // 2

            num_time_indices = MAX_TIME_INDEX_FOR_INTEGRATION

            # EnKF results are always read from the sd832180 directory, whatever
            # the SRDA seed (presumably EnKF does not depend on it — confirm).
            enkf_dir = f"{ENKF_ROOT_DIR}/{config_name.replace(f'_sd{seed:06}', '_sd832180')}_with_init_noise"

            all_enkf_mean = [
                np.load(f"{enkf_dir}/ens_all_lr_{i:04}_with_init_noise.npy")
                for i in range(num_batch_indices)
            ]
            all_enkf_mean = np.stack(all_enkf_mean, axis=0)  # stack along batch dim
            all_enkf_mean = np.mean(all_enkf_mean, axis=1)  # mean over ensemble dim
            gc.collect()

            ts = [LR_DT * LR_NT * i_cycle + T0 for i_cycle in range(num_time_indices)]
            logger.info("All data have been read.")

            # Column blocks (time first, then one block per metric) concatenated once at the end.
            df_all_results = []
            df_times = pd.DataFrame()
            df_times["time"] = ts[:num_time_indices]
            df_all_results.append(df_times)

            for kind in ["HR", "LR"]:
                gt = all_hr_omega[:num_batch_indices, :num_time_indices]
                srda = all_sr_forecast[:num_batch_indices, :num_time_indices]
                enkf = all_enkf_mean[:num_batch_indices, :num_time_indices]

                # Bring every field onto the grid of `kind` before comparing.
                if kind == "HR":
                    enkf = interpolate_time_series(
                        torch.from_numpy(enkf), HR_NX, HR_NY
                    ).numpy()
                elif kind == "LR":
                    gt = interpolate_time_series(
                        torch.from_numpy(gt), LR_NX, LR_NY
                    ).numpy()
                    srda = interpolate_time_series(
                        torch.from_numpy(srda), LR_NX, LR_NY
                    ).numpy()
                else:
                    raise NotImplementedError()

                # MAER normaliser depends only on the ground truth: compute it
                # once per resolution instead of once per label.
                nrms = np.mean(np.abs(gt), axis=(-2, -1))

                for label, data in zip(["SRDA", "EnKF"], [srda, enkf]):
                    logger.info(f"{kind}:{label} is being evaluated")
                    assert gt.shape == data.shape

                    mae = np.mean(np.abs(gt - data), axis=(-2, -1))  # mean over x and y
                    maer = mae / nrms

                    # batch and time dims remain.
                    assert (
                        mae.shape == maer.shape == (num_batch_indices, num_time_indices)
                    )

                    cols = [
                        f"MAE_{kind}_{label}_{i:03}" for i in range(num_batch_indices)
                    ]
                    df_all_results.append(
                        pd.DataFrame(data=mae.transpose(), columns=cols)
                    )

                    cols = [
                        f"MAER_{kind}_{label}_{i:03}" for i in range(num_batch_indices)
                    ]
                    df_all_results.append(
                        pd.DataFrame(data=maer.transpose(), columns=cols)
                    )

                    for ssim_kind, ssim_func in zip(
                        ["Gauss", "Uniform"], [ssim_func_gauss, ssim_func_unfrm]
                    ):
                        logger.info(f"SSIM: {ssim_kind}")
                        ssim = ssim_func(
                            img1=torch.from_numpy(gt).to(DEVICE),
                            img2=torch.from_numpy(data).to(DEVICE),
                        )
                        ssim = torch.mean(ssim, dim=(2, 3))  # mean over x, and y
                        ssim = ssim.detach().cpu().numpy()

                        # batch and time dims remain.
                        assert ssim.shape == (num_batch_indices, num_time_indices)

                        # NOTE(review): raw SSIM values are stored here, whereas the
                        # aggregated CSVs store 1 - SSIM under the same "SSIM_*"
                        # prefix. Confirm which convention is intended.
                        cols = [
                            f"SSIM_{ssim_kind}_{kind}_{label}_{i:03}"
                            for i in range(num_batch_indices)
                        ]
                        df_all_results.append(
                            pd.DataFrame(data=ssim.transpose(), columns=cols)
                        )

            df_all_results = pd.concat(df_all_results, axis=1)
            df_all_results.to_csv(csv_path, index=False)
            logger.info("Finished\n")

# Plot error time series for all simulations

In [None]:
resolution = "HR"
max_simulations = 250

for rc_key, rc_value in (("font.family", "serif"), ("font.size", 18)):
    plt.rcParams[rc_key] = rc_value

intervals = np.arange(4, 14, 2)

# For each seed and observation grid interval, summarise the per-simulation
# difference "SRDA minus EnKF" of each metric by its 10th/50th/90th percentiles
# over the first `max_simulations` simulations.
for seed in tqdm([221958, 771155, 832180, 465838, 359178]):
    for obs_grid_interval in tqdm(intervals, total=len(intervals)):

        config_name = f"lt4og{obs_grid_interval:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd{seed:06}"
        df_all_results = pd.read_csv(
            f"{CSV_DATA_DIR}/enkf_srda_err_time_series_{config_name}_with_init_noise_all_simulations.csv"
        )

        fig, axes = plt.subplots(1, 2, sharex=True, figsize=[16, 4])

        for ax, ycol in zip(axes, ["MAER", "SSIM_Gauss"]):
            column_prefix = f"{ycol}_{resolution}"
            # axis = simulation, time
            all_diffs = np.stack(
                [
                    df_all_results[f"{column_prefix}_SRDA_{i:03}"].values
                    - df_all_results[f"{column_prefix}_EnKF_{i:03}"].values
                    for i in range(max_simulations)
                ],
                axis=0,
            )
            times = df_all_results["time"].values

            ax.plot(times, np.quantile(all_diffs, 0.9, axis=0), "-", label="90%tile")
            ax.plot(times, np.median(all_diffs, axis=0), "-", label="50%tile")
            ax.plot(times, np.quantile(all_diffs, 0.1, axis=0), "-", label="10%tile")

            ax.axhline(0, ls="--", color="k")
            ax.set_xlabel("Time")
            ax.set_ylabel(f"Diff. {ycol}")
            ax.set_title(f"Difference of {ycol}, SRDA minus EnKF")
            ax.legend()

        obs_point_ratio = OBS_GRID_RATIO[obs_grid_interval] * 100
        plt.suptitle(f"Seed = {seed}, Obs Point Ratio = {obs_point_ratio:.2f} %")
        plt.tight_layout()
        plt.show()