In [None]:
# Reload edited Python modules (e.g. the local `src` package) automatically
# before each cell is executed, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Import libraries

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Configure the root logger: INFO level, echoed to stdout. The handler is added
# only when none exists yet, so re-running this cell does not duplicate output.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    _stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(_stdout_handler)

In [None]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
from src.four_dim_srda.config.config_loader import load_config
from src.four_dim_srda.config.experiment_config import CFDConfig
from src.four_dim_srda.utils.calc_statistics import (
    calc_maer,
    calc_maer_averaging_over_selected_iz,
    calc_mssim,
)
from src.four_dim_srda.utils.io_pickle import write_pickle
from src.four_dim_srda.utils.ssim import MSSIM
from src.four_dim_srda.utils.torch_interpolator import (
    interpolate_2d,
    interpolate_along_z,
)
from src.qg_model.qg_model import QGModel

# Use a serif font family for every matplotlib figure in this notebook.
plt.rcParams["font.family"] = "serif"

# Define constants

In [None]:
# Project root, derived from PYTHONPATH (configs are later read from
# `{ROOT_DIR}/python/configs`, so PYTHONPATH presumably points at `<root>/python`).
# PYTHONPATH may list several directories separated by `os.pathsep`; only the
# first entry is used, otherwise the whole list would be treated as one path.
ROOT_DIR = pathlib.Path(os.environ["PYTHONPATH"].split(os.pathsep)[0]).parent.resolve()

In [None]:
# Experiment whose configs, CFD data and results are read (it appears in every
# directory path built below).
experiment_name = "experiment7"

In [None]:
# Models whose SRDA results are compared below.
# "UNetVitVer02" is also registered in `_SRDA_CFG_NAMES` and can be added here.
model_names = ["ConvTransNetVer01", "UNetMaxVitVer01"]

CFG_DIR = f"{ROOT_DIR}/python/configs/four_dim_srda/{experiment_name}"

# srda: config-file stem (it encodes the hyper-parameters) of each model's SRDA run
_SRDA_CFG_NAMES = {
    "ConvTransNetVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_bias1_bs12_lr1e-04",
    "UNetMaxVitVer01": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nmb6_bias0_bs12_lr1e-04",
    "UNetVitVer02": "bea2_bed2_dspe360_nsls100_ogx08_ogy08_n3drb3_nvb4_bias0_vits0_bs12_lr1e-04",
}

# Fail loudly on a typo in `model_names` instead of silently dropping the model.
_unknown_models = [m for m in model_names if m not in _SRDA_CFG_NAMES]
if _unknown_models:
    raise ValueError(f"No SRDA config name registered for: {_unknown_models}")

dict_cfg_srda_name = {m_name: _SRDA_CFG_NAMES[m_name] for m_name in model_names}

# config srda
# Only the first model's config is loaded; its dataset settings
# (max_start_time_index, pv_min, pv_max) are used for all models below --
# presumably they are identical across models (TODO confirm).
CFG_SRDA_PATH = f"{CFG_DIR}/perform_4D_SRDA/{model_names[0]}/{dict_cfg_srda_name[model_names[0]]}.yml"

cfg_srda = load_config(model_name=model_names[0], config_path=CFG_SRDA_PATH)

# config cfd
CFG_CFD_PATH = f"{CFG_DIR}/cfd_simulation/qg_model/gpu_evaluation_config.yml"

cfg_cfd = CFDConfig.load(pathlib.Path(CFG_CFD_PATH))

DEVICE_CPU = "cpu"
# Device for the MSSIM evaluation. Fall back to an explicit CPU device (instead
# of `None`) when CUDA is unavailable, so every `.to(DEVICE_GPU)` is unambiguous.
DEVICE_GPU = torch.device("cuda" if torch.cuda.is_available() else DEVICE_CPU)

# LR / HR / UHR base configs all use the CPU.
cfg_cfd.lr_base_config.device = DEVICE_CPU
cfg_cfd.hr_base_config.device = DEVICE_CPU
cfg_cfd.uhr_base_config.device = DEVICE_CPU

dict_cfg_srda_name

In [None]:
# Directory name of the LETKF run whose stored forecasts (`all_letkf_fcst.npy`)
# are compared against the ground truth below.
cfg_letkf_name = (
    "na3e-03_letkf_cfg_ogx08_ogy08_ne100_ch16e-04_cr6e+00_if12e-01_lr57e-01_bs6"
)

In [None]:
# UHR QG model built from the CFD config (its device was set to CPU above).
# NOTE(review): `uhr_model` is not referenced by any later cell shown here --
# confirm it is still needed.
uhr_model = QGModel(cfg_cfd.uhr_base_config, show_input_cfg_info=False)

In [None]:
DATA_DIR = f"{ROOT_DIR}/data/four_dim_srda"

# Both CFD datasets (LR and UHR PV fields) share this experiment's QG-model directory.
_qg_model_dir = f"{DATA_DIR}/{experiment_name}/cfd_simulation/qg_model"
LR_DATA_DIR = f"{_qg_model_dir}/lr_pv_narrow_jet"
UHR_DATA_DIR = f"{_qg_model_dir}/uhr_pv_narrow_jet"

In [None]:
# Root of this experiment's results; SRDA and LETKF outputs are read from below it.
_result_dir = f"{ROOT_DIR}/python/results/four_dim_srda/{experiment_name}"
# Output directory for the metrics computed in this notebook.
RESULT_DIR = f"{_result_dir}/analysis/use_narrow_jet/store_only_forecast"

In [None]:
ASSIMILATION_PERIOD = (
    cfg_cfd.da_config.segment_length - cfg_cfd.da_config.forecast_span - 1
)
FORECAST_SPAN = cfg_cfd.da_config.forecast_span

# In this research, FORECAST_SPAN_LR = FORECAST_SPAN // 4: the LR output files
# cover steps 0-200 while the UHR ones cover 0-800.
LR_TIME_RATIO = 4
# A remainder here would silently misalign LR and UHR time indices.
assert FORECAST_SPAN % LR_TIME_RATIO == 0, (
    f"FORECAST_SPAN ({FORECAST_SPAN}) must be a multiple of {LR_TIME_RATIO}"
)
FORECAST_SPAN_LR = FORECAST_SPAN // LR_TIME_RATIO

# Exclusive stop index of the UHR time slices taken below.
NUM_TIMES = (
    cfg_srda.dataset_config.max_start_time_index + ASSIMILATION_PERIOD + FORECAST_SPAN
)
# Exclusive stop index of the LR time slice, taken from the CFD time config.
NUM_TIMES_LR = cfg_cfd.time_config.end_time

# Define helper functions

In [None]:
def _preprocess(
    data: torch.Tensor, pv_min: float, pv_max: float, use_clipping: bool = False
) -> torch.Tensor:
    #
    # batch, time, z, y, x dims
    assert data.ndim == 5

    # normalization
    data = (data - pv_min) / (pv_max - pv_min)

    if use_clipping:
        data = torch.clamp(data, min=0.0, max=1.0)

    return data

# Prepare data

In [None]:
# Set it to use only data at last forecast time: slicing the time axis with
# this stride keeps one step (the last forecast) per cycle.
t_slice = FORECAST_SPAN

In [None]:
# Ground truth (UHR PV) for every UHR seed, stacked along dim 0.
# SRDA outputs are NaN at it = 0, 1, ..., FORECAST_SPAN, so the slice starts at
# 2 * FORECAST_SPAN; the t_slice stride then keeps only the last forecast of
# each cycle.
_gt_per_seed = []
for seed in range(
    cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
):
    _gt_np = np.load(
        f"{UHR_DATA_DIR}/seed{seed:05}/seed{seed:05}_start000_end800_uhr_pv.npy"
    )
    _gt_per_seed.append(
        torch.from_numpy(_gt_np[2 * FORECAST_SPAN : NUM_TIMES : t_slice])
    )

all_gt = torch.stack(_gt_per_seed, dim=0)
all_gt.shape, all_gt.dtype

In [None]:
# LR forecasts for every UHR seed, on the LR time axis (hence FORECAST_SPAN_LR
# and NUM_TIMES_LR).
# NOTE(review): unlike `all_gt`, no stride is applied to this slice; the later
# check `all_gt.shape == all_lr_fcst_uhr.shape` suggests the LR series already
# holds one step per forecast cycle -- TODO confirm.
_lr_per_seed = []
for seed in range(
    cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
):
    _lr_np = np.load(
        f"{LR_DATA_DIR}/seed{seed:05}/seed{seed:05}_start000_end200_lr_pv.npy"
    )
    _lr_per_seed.append(
        torch.from_numpy(_lr_np[2 * FORECAST_SPAN_LR : NUM_TIMES_LR])
    )

all_lr_fcst = torch.stack(_lr_per_seed, dim=0)
all_lr_fcst.shape, all_lr_fcst.dtype

In [None]:
# SRDA HR forecasts per model, each stacked over UHR seeds along dim 0 and
# sliced on the same time grid as `all_gt`.
dict_srda_fcsts = {}
for m_name in dict_cfg_srda_name.keys():
    srda_hr_fcst = []
    #
    for i_seed_uhr in range(
        cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
    ):
        _path = f"{_result_dir}/srda/{m_name}/use_narrow_jet/store_only_forecast/{dict_cfg_srda_name[m_name]}/UHR_seed_{i_seed_uhr:05}.npz"
        # `np.load` on an .npz returns an NpzFile that keeps the file open; the
        # context manager closes it once the needed array has been read.
        with np.load(_path) as _result_npz:
            _srda_hr_fcst = torch.from_numpy(
                _result_npz["srda_forecast"][2 * FORECAST_SPAN : NUM_TIMES : t_slice]
            )
        srda_hr_fcst.append(_srda_hr_fcst)
    #
    srda_hr_fcst = torch.stack(srda_hr_fcst, dim=0)
    dict_srda_fcsts[m_name] = srda_hr_fcst

dict_srda_fcsts[m_name].shape, dict_srda_fcsts[m_name].dtype, dict_srda_fcsts.keys()

In [None]:
# LR forecasts used inside each SRDA model's inference cycles, per model,
# stacked over UHR seeds along dim 0 and cast to float32.
dict_srda_lr = {}
for m_name in dict_cfg_srda_name.keys():
    srda_lr = []
    #
    for i_seed_uhr in range(
        cfg_cfd.seed_config.uhr_seed_start, cfg_cfd.seed_config.uhr_seed_end + 1
    ):
        # The path is intentionally modified (no `store_only_forecast`)
        # because the srda_lr's result at current time in one inference cycle is needed
        _path = f"{_result_dir}/srda/{m_name}/use_narrow_jet/{dict_cfg_srda_name[m_name]}/UHR_seed_{i_seed_uhr:05}.npz"
        # Close the NpzFile handle once the array has been read.
        with np.load(_path) as _result_npz:
            _srda_lr = torch.from_numpy(
                _result_npz["all_lr_forecast"][2 * FORECAST_SPAN_LR :]
            ).to(torch.float32)
        srda_lr.append(_srda_lr)
    #
    srda_lr = torch.stack(srda_lr, dim=0)
    dict_srda_lr[m_name] = srda_lr

dict_srda_lr[m_name].shape, dict_srda_lr[m_name].dtype, dict_srda_lr.keys()

In [None]:
# LETKF HR forecasts (presumably one entry per UHR seed along dim 0), sliced on
# the same time grid as `all_gt`: start at 2 * FORECAST_SPAN, stop at NUM_TIMES,
# and keep every t_slice-th step. The NUM_TIMES stop was missing before, so a
# longer stored series would not line up with `all_gt`.
_path = f"{_result_dir}/letkf/perform_letkf_hr_using_uhr/use_narrow_jet/store_only_forecast/{cfg_letkf_name}/all_letkf_fcst.npy"
letkf_hr_fcsts = torch.from_numpy(
    np.load(_path)[:, 2 * FORECAST_SPAN : NUM_TIMES : t_slice, ...]
)

letkf_hr_fcsts.shape

In [None]:
# interpolation: bring each seed's LR forecast onto the UHR grid (nearest-exact
# in x/y first, then along z) so it can be compared point-wise with `all_gt`.
_uhr_cfg = cfg_cfd.uhr_base_config
all_lr_fcst_uhr = torch.stack(
    [
        interpolate_along_z(
            data=interpolate_2d(
                data=_lr_seq,
                nx=_uhr_cfg.nx,
                ny=_uhr_cfg.ny,
                mode="nearest-exact",
            ),
            nz=_uhr_cfg.nz,
            mode="nearest-exact",
        )
        for _lr_seq in all_lr_fcst
    ],
    dim=0,
)
assert all_gt.shape == all_lr_fcst_uhr.shape

all_lr_fcst_uhr.shape

In [None]:
# interpolation: SRDA HR forecasts of each model -> UHR grid (x/y, then z).
_uhr_cfg = cfg_cfd.uhr_base_config
dict_srda_uhr_fcsts = {}
for m_model, _hr_fcsts in dict_srda_fcsts.items():
    _uhr_fcsts = torch.stack(
        [
            interpolate_along_z(
                data=interpolate_2d(
                    data=_seq,
                    nx=_uhr_cfg.nx,
                    ny=_uhr_cfg.ny,
                    mode="nearest-exact",
                ),
                nz=_uhr_cfg.nz,
                mode="nearest-exact",
            )
            for _seq in _hr_fcsts
        ],
        dim=0,
    )
    assert all_gt.shape == _uhr_fcsts.shape
    dict_srda_uhr_fcsts[m_model] = _uhr_fcsts

dict_srda_uhr_fcsts[m_model].shape

In [None]:
# interpolation: LR forecasts from each SRDA model's cycles -> UHR grid (x/y, then z).
_uhr_cfg = cfg_cfd.uhr_base_config
dict_srda_lr_uhr = {}
for m_model, _lr_fcsts in dict_srda_lr.items():
    _lr_on_uhr = torch.stack(
        [
            interpolate_along_z(
                data=interpolate_2d(
                    data=_seq,
                    nx=_uhr_cfg.nx,
                    ny=_uhr_cfg.ny,
                    mode="nearest-exact",
                ),
                nz=_uhr_cfg.nz,
                mode="nearest-exact",
            )
            for _seq in _lr_fcsts
        ],
        dim=0,
    )
    assert all_gt.shape == _lr_on_uhr.shape
    dict_srda_lr_uhr[m_model] = _lr_on_uhr

dict_srda_lr_uhr[m_model].shape

In [None]:
# interpolation: LETKF HR forecasts -> UHR grid (x/y, then z), for comparison with `all_gt`.
_uhr_cfg = cfg_cfd.uhr_base_config
letkf_uhr_fcsts = torch.stack(
    [
        interpolate_along_z(
            data=interpolate_2d(
                data=_letkf_seq,
                nx=_uhr_cfg.nx,
                ny=_uhr_cfg.ny,
                mode="nearest-exact",
            ),
            nz=_uhr_cfg.nz,
            mode="nearest-exact",
        )
        for _letkf_seq in letkf_hr_fcsts
    ],
    dim=0,
)
assert all_gt.shape == letkf_uhr_fcsts.shape

letkf_uhr_fcsts.shape

# Calculate the MAE ratio

In [None]:
# z-level indices over which the metrics are additionally averaged.
selected_iz = list(range(18, 23))

# Result container: metric name -> method (-> model for SRDA variants) -> value.
dict_all_maer = {
    metric: {} for metric in ("maer", "maer_selected_iz", "maer_time_avg")
}

In [None]:
# MAE ratio of the LR forecast (interpolated to UHR) against the ground truth.
maer = calc_maer(all_gt=all_gt, all_fcst=all_lr_fcst_uhr)
maer_selected_iz = calc_maer_averaging_over_selected_iz(
    all_gt=all_gt, all_fcst=all_lr_fcst_uhr, selected_iz=selected_iz
)
maer_time_avg = maer.mean(dim=0)

for _metric, _value in (
    ("maer", maer),
    ("maer_selected_iz", maer_selected_iz),
    ("maer_time_avg", maer_time_avg),
):
    dict_all_maer[_metric]["lr_fcst"] = _value

dict_all_maer["maer"]["lr_fcst"].shape, dict_all_maer["maer_time_avg"]["lr_fcst"]

In [None]:
# MAE ratio of each SRDA model's forecast (interpolated to UHR) against the ground truth.
for _metric in ("maer", "maer_selected_iz", "maer_time_avg"):
    dict_all_maer[_metric]["srda"] = {}

for m_model, _srda_fcst in dict_srda_uhr_fcsts.items():
    assert all_gt.shape == _srda_fcst.shape

    maer = calc_maer(all_gt=all_gt, all_fcst=_srda_fcst)
    maer_selected_iz = calc_maer_averaging_over_selected_iz(
        all_gt=all_gt, all_fcst=_srda_fcst, selected_iz=selected_iz
    )
    maer_time_avg = maer.mean(dim=0)

    dict_all_maer["maer"]["srda"][m_model] = maer
    dict_all_maer["maer_selected_iz"]["srda"][m_model] = maer_selected_iz
    dict_all_maer["maer_time_avg"]["srda"][m_model] = maer_time_avg

(
    dict_all_maer["maer"]["srda"].keys(),
    dict_all_maer["maer"]["srda"][m_model].shape,
    dict_all_maer["maer_time_avg"]["srda"][m_model],
)

In [None]:
# MAE ratio of the LR forecasts from each SRDA model's cycles (interpolated to UHR).
for _metric in ("maer", "maer_selected_iz", "maer_time_avg"):
    dict_all_maer[_metric]["srda_lr"] = {}

for m_model, _srda_lr_fcst in dict_srda_lr_uhr.items():
    assert all_gt.shape == _srda_lr_fcst.shape

    maer = calc_maer(all_gt=all_gt, all_fcst=_srda_lr_fcst)
    maer_selected_iz = calc_maer_averaging_over_selected_iz(
        all_gt=all_gt, all_fcst=_srda_lr_fcst, selected_iz=selected_iz
    )
    maer_time_avg = maer.mean(dim=0)

    dict_all_maer["maer"]["srda_lr"][m_model] = maer
    dict_all_maer["maer_selected_iz"]["srda_lr"][m_model] = maer_selected_iz
    dict_all_maer["maer_time_avg"]["srda_lr"][m_model] = maer_time_avg

(
    dict_all_maer["maer"]["srda_lr"].keys(),
    dict_all_maer["maer"]["srda_lr"][m_model].shape,
    dict_all_maer["maer_time_avg"]["srda_lr"][m_model],
)

In [None]:
# MAE ratio of the LETKF forecast (interpolated to UHR) against the ground truth.
maer = calc_maer(all_gt=all_gt, all_fcst=letkf_uhr_fcsts)
maer_selected_iz = calc_maer_averaging_over_selected_iz(
    all_gt=all_gt, all_fcst=letkf_uhr_fcsts, selected_iz=selected_iz
)
maer_time_avg = maer.mean(dim=0)

for _metric, _value in (
    ("maer", maer),
    ("maer_selected_iz", maer_selected_iz),
    ("maer_time_avg", maer_time_avg),
):
    dict_all_maer[_metric]["letkf"] = _value

dict_all_maer["maer"]["letkf"].shape, dict_all_maer["maer_time_avg"]["letkf"]

# Save the MAE ratio results

In [None]:
# RESULT_DIR is not created anywhere above, and `write_pickle` may not create
# missing parent directories, so make sure it exists on a fresh checkout.
pathlib.Path(RESULT_DIR).mkdir(parents=True, exist_ok=True)
write_pickle(
    data=dict_all_maer, file_path=f"{RESULT_DIR}/all_maer_only_forecast_result.pkl"
)

# Calculate and save the MSSIM loss

## Normalization

In [None]:
# Min-max normalization with the SRDA dataset's PV range. No clipping is applied,
# so the range checks below detect any value outside [pv_min, pv_max].
pv_min = cfg_srda.dataset_config.pv_min
pv_max = cfg_srda.dataset_config.pv_max
# Misspelled alias of `pv_min`, kept in case a later cell still refers to it.
pv_miv = pv_min


def _assert_in_unit_range(data: torch.Tensor, label: str) -> None:
    """Assert that a normalized tensor lies within [0, 1].

    Args:
        data: Normalized tensor to check.
        label: Name reported in the assertion message.

    Raises:
        AssertionError: If any value is below 0 or above 1 (also when the
            tensor contains NaN, because the comparisons are then False).
    """
    data_min, data_max = data.min().item(), data.max().item()
    assert data_min >= 0 and data_max <= 1, (
        f"{label}: normalized values span [{data_min}, {data_max}], outside [0, 1]"
    )


all_gt_norm = _preprocess(data=all_gt, pv_min=pv_min, pv_max=pv_max, use_clipping=False)
_assert_in_unit_range(all_gt_norm, "all_gt")

all_lr_fcst_norm = _preprocess(
    data=all_lr_fcst_uhr, pv_min=pv_min, pv_max=pv_max, use_clipping=False
)
_assert_in_unit_range(all_lr_fcst_norm, "lr_fcst")

dict_srda_uhr_fcsts_norm = {}
for m_model, _fcst in dict_srda_uhr_fcsts.items():
    dict_srda_uhr_fcsts_norm[m_model] = _preprocess(
        _fcst, pv_min=pv_min, pv_max=pv_max, use_clipping=False
    )
    _assert_in_unit_range(dict_srda_uhr_fcsts_norm[m_model], f"srda/{m_model}")

dict_srda_lr_uhr_norm = {}
for m_model, _fcst in dict_srda_lr_uhr.items():
    dict_srda_lr_uhr_norm[m_model] = _preprocess(
        _fcst, pv_min=pv_min, pv_max=pv_max, use_clipping=False
    )
    _assert_in_unit_range(dict_srda_lr_uhr_norm[m_model], f"srda_lr/{m_model}")

letkf_uhr_fcsts_norm = _preprocess(
    data=letkf_uhr_fcsts,
    pv_min=pv_min,
    pv_max=pv_max,
    use_clipping=False,
)
_assert_in_unit_range(letkf_uhr_fcsts_norm, "letkf")

## Uniform window

In [None]:
# Number of entries along dim 0 of the normalized ground truth (one per UHR
# seed, as stacked when loading); every MSSIM loop below iterates over them.
num_batch = len(all_gt_norm)

In [None]:
# MSSIM loss (1 - MSSIM) of every method against the ground truth, computed on
# normalized fields with a uniform (non-Gaussian) 5 x 11 x 11 window.
dict_all_mssim_loss_uniform_wsz5 = {
    "mssim_loss": {},
    "mssim_loss_selected_iz": {},
}

mssim_loss_params = {
    "window_3d_size": (5, 11, 11),
    # NOTE(review): presumably only used when `use_gaussian=True` -- TODO confirm in MSSIM.
    "sigma_3d": (0.7, 1.5, 1.5),
    "value_magnitude": 1.0,
    "use_gaussian": False,
}


def _calc_mssim_loss_over_batch(all_fcst_norm: torch.Tensor) -> torch.Tensor:
    """Return ``1 - MSSIM`` of ``all_fcst_norm`` vs ``all_gt_norm``, averaged over dim 0.

    Reads ``all_gt_norm``, ``num_batch``, ``mssim_loss_params`` and ``DEVICE_GPU``
    from the notebook namespace. Each batch element is evaluated separately on
    ``DEVICE_GPU`` with a freshly constructed ``MSSIM``, and its result is moved
    back to the CPU before the next element is processed.

    Args:
        all_fcst_norm: Normalized forecasts, indexed like ``all_gt_norm`` along dim 0.

    Returns:
        The batch-mean MSSIM loss as a CPU tensor.
    """
    per_batch = []
    # Evaluation only: no autograd graph is needed for the metric.
    with torch.no_grad():
        for ib in range(num_batch):
            ssim = MSSIM(**mssim_loss_params)
            _r = calc_mssim(
                all_gt=all_gt_norm[ib].unsqueeze(0).to(DEVICE_GPU),
                all_fcst=all_fcst_norm[ib].unsqueeze(0).to(DEVICE_GPU),
                mssim=ssim,
            )
            per_batch.append(_r.cpu())
            # Release device memory before the next batch element.
            del _r
            torch.cuda.empty_cache()
    # mean over batch dim and calc loss
    return 1.0 - torch.mean(torch.stack(per_batch), dim=0)


def _mean_over_selected_iz(loss: torch.Tensor) -> torch.Tensor:
    """Average an MSSIM loss over the z-levels listed in ``selected_iz``.

    Assumes dim 1 of ``loss`` is the z axis (e.g. a (time, z) tensor) -- TODO confirm.
    """
    return torch.mean(loss[:, selected_iz], dim=1)


# lr
mssim_loss = _calc_mssim_loss_over_batch(all_lr_fcst_norm)
dict_all_mssim_loss_uniform_wsz5["mssim_loss"]["lr_fcst"] = mssim_loss
dict_all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz"]["lr_fcst"] = (
    _mean_over_selected_iz(mssim_loss)
)

# srda and srda_lr: one entry per SRDA model
for _method, _dict_fcsts_norm in (
    ("srda", dict_srda_uhr_fcsts_norm),
    ("srda_lr", dict_srda_lr_uhr_norm),
):
    dict_all_mssim_loss_uniform_wsz5["mssim_loss"][_method] = {}
    dict_all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz"][_method] = {}
    for m_model, _fcsts_norm in _dict_fcsts_norm.items():
        mssim_loss = _calc_mssim_loss_over_batch(_fcsts_norm)
        dict_all_mssim_loss_uniform_wsz5["mssim_loss"][_method][m_model] = mssim_loss
        dict_all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz"][_method][
            m_model
        ] = _mean_over_selected_iz(mssim_loss)

# letkf
mssim_loss = _calc_mssim_loss_over_batch(letkf_uhr_fcsts_norm)
dict_all_mssim_loss_uniform_wsz5["mssim_loss"]["letkf"] = mssim_loss
dict_all_mssim_loss_uniform_wsz5["mssim_loss_selected_iz"]["letkf"] = (
    _mean_over_selected_iz(mssim_loss)
)

# save (RESULT_DIR may not exist yet on a fresh checkout)
pathlib.Path(RESULT_DIR).mkdir(parents=True, exist_ok=True)
write_pickle(
    data=dict_all_mssim_loss_uniform_wsz5,
    file_path=f"{RESULT_DIR}/all_mssim_loss_uniform_wsz5_only_forecast_result.pkl",
)