In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger at INFO level, writing to stdout so messages show up in the cell output.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    # Attach the stdout handler only when no handler exists yet, so that
    # re-running this cell does not add a second one.
    logger.addHandler(StreamHandler(stream=sys.stdout))

# Import libraries

In [None]:
import gc
import os
import pathlib
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import yaml
from cfd_model.filter.low_pass_periodic_channel_domain import LowPassFilter
from cfd_model.initialization.periodic_channel_jet_initializer import calc_jet_forcing
from ml_model.conv2d_cvae import Conv2dCvae
from ml_model.conv2d_sr_net import ConvSrNet
from src.sr_da_helper_2 import (
    get_observation_with_noise,
    get_testdataset,
    initialize_models,
    make_cfd_models,
    make_invprocessed_sr,
    make_preprocessed_lr,
    make_preprocessed_obs,
    read_all_hr_omegas_with_combining,
)
from src.utils import set_seeds
from tqdm.notebook import tqdm

In [None]:
# Let pandas display up to 500 rows / columns before truncating.
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)

# Set the cuBLAS workspace config before seeding, to make calculations deterministic.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = r":4096:8"
set_seeds(42, use_deterministic=True)

# Project root: the parent directory of the path stored in PYTHONPATH.
ROOT_DIR = str(pathlib.Path(os.environ["PYTHONPATH"]).joinpath("..").resolve())

# Define constants

In [None]:
# Torch device string used for the neural networks and passed to the CFD / index configs.
DEVICE = "cuda:0"

In [None]:
# Name of the YAML config; it also appears in the weight directory and the result file names.
CONFIG_NAME = "default_neural_nets"

In [None]:
# Directory that holds the trained weights and learning histories for this config.
_dir = f"{ROOT_DIR}/data/ModelWeights/{CONFIG_NAME}"

# Parse the experiment configuration (YAML).
with open(f"{ROOT_DIR}/pytorch/config/{CONFIG_NAME}.yml") as file:
    config = yaml.safe_load(file)

# Bundle the parsed config with the paths of the training artifacts.
RESULT_AND_CONFIG = dict(
    config=config,
    prior_weight_path=f"{_dir}/prior_model_weight.pth",
    prior_learning_history_path=f"{_dir}/prior_model_loss.csv",
    cvae_weight_path=f"{_dir}/cvae_weight.pth",
    cvae_learning_history_path=f"{_dir}/cvae_loss.csv",
    log_path=f"{_dir}/log.txt",
)

In [None]:
# Output directory for the SR-DA result arrays (.npy); created if it does not exist yet.
SRDA_DATA_DIR = f"{ROOT_DIR}/data/SRDA"
os.makedirs(SRDA_DATA_DIR, exist_ok=True)

In [None]:
# --- Assimilation schedule (in units of output time steps) ---
ASSIMILATION_PERIOD = 4
START_TIME_INDEX = 16

# --- Grid sizes of the low-resolution (LR) and high-resolution (HR) fields ---
LR_NX, LR_NY = 32, 17
HR_NX, HR_NY = 128, 65

# --- LR time stepping: each cycle integrates LR_NT steps of size LR_DT ---
LR_DT = 5e-4
LR_NT = 500
DT = LR_DT * LR_NT
T0 = START_TIME_INDEX * LR_DT * LR_NT

# --- Parameters passed to the jet forcing (y0, sigma, tau0) ---
Y0_MEAN = np.pi / 2.0
SIGMA_MEAN = 0.4
TAU0_MEAN = 0.3

# --- Coefficients of the CFD model ---
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
HR_COEFF_DIFFUSION = 1e-5
LR_COEFF_DIFFUSION = 5e-5

# Number of ensemble members processed together in one chunk
N_ENS_PER_CHUNK = 100

In [None]:
# Settings handed to `make_cfd_models` to build the LR CFD models.
LR_CFD_CONFIG = dict(
    nx=LR_NX,
    ny=LR_NY,
    coeff_linear_drag=COEFF_LINEAR_DRAG,
    coeff_diffusion=LR_COEFF_DIFFUSION,
    order_diffusion=ORDER_DIFFUSION,
    beta=BETA,
    device=DEVICE,
)

# Keyword arguments shared by the SR-DA helper functions (splatted as **INDEX_CONFIG).
INDEX_CONFIG = dict(
    assimilation_period=ASSIMILATION_PERIOD,
    n_ens=N_ENS_PER_CHUNK,
    lr_nx=LR_NX,
    lr_ny=LR_NY,
    hr_nx=HR_NX,
    hr_ny=HR_NY,
    device=DEVICE,
)

# Low-pass filter between the HR and LR grids; it is kept on the CPU.
low_pass_filter = LowPassFilter(
    nx_lr=LR_NX,
    ny_lr=LR_NY,
    nx_hr=HR_NX,
    ny_hr=HR_NY,
    device="cpu",
)

# Define methods

In [None]:
def get_all_result_paths(config_name: str = CONFIG_NAME):
    """Return the `.npy` result file paths belonging to `config_name`.

    All files live in `SRDA_DATA_DIR`. The returned 5-tuple is ordered as
    (SR prior, SR analysis, LR omega, HR omega, HR observation).
    """
    stems = ("sr_prior", "sr_analysis", "lr_omega", "hr_omega", "hr_obsrv")
    return tuple(f"{SRDA_DATA_DIR}/{stem}_{config_name}.npy" for stem in stems)


def read_all_result_files(config_name: str = CONFIG_NAME):
    """Load every result array saved for `config_name`.

    The arrays are read with `np.load` and returned as a 5-tuple in the same
    order as the paths from `get_all_result_paths`:
    (SR prior, SR analysis, LR omega, HR omega, HR observation).
    """
    return tuple(np.load(path) for path in get_all_result_paths(config_name))

# Perform SR-DA

In [None]:
# Re-seed so that the SR-DA run below is reproducible from this cell onwards.
set_seeds(42, use_deterministic=True)

(
    sr_prior_file_path,
    sr_analysis_file_path,
    lr_omega_file_path,
    hr_omega_file_path,
    hr_obsrv_file_path,
) = get_all_result_paths(CONFIG_NAME)

# The LR prior array is written by the final saving cell as well, so it is
# guarded here together with the other result files.
lr_prior_file_path = f"{SRDA_DATA_DIR}/lr_prior_{CONFIG_NAME}.npy"

# Refuse to continue if any result file is already present: the run is
# expensive and the saving cell would overwrite the existing arrays.
_existing_result_paths = [
    path
    for path in (
        sr_prior_file_path,
        sr_analysis_file_path,
        lr_omega_file_path,
        hr_omega_file_path,
        hr_obsrv_file_path,
        lr_prior_file_path,
    )
    if os.path.exists(path)
]
if _existing_result_paths:
    # FileExistsError is a subclass of Exception, so handlers that catch
    # Exception keep working; the message names the files that block the run.
    raise FileExistsError(f"Results already exist: {_existing_result_paths}")

In [None]:
# Build the test dataset from the loaded config; the cells below read its HR
# file paths and its vorticity normalization / clamp constants.
test_dataset = get_testdataset(ROOT_DIR, RESULT_AND_CONFIG["config"])
# The dataset's observation interval must agree with the assimilation period used here.
assert test_dataset.obs_time_interval == ASSIMILATION_PERIOD

In [None]:
# Read the HR ground-truth vorticity from the dataset's HR files as one combined tensor.
all_hr_omegas = read_all_hr_omegas_with_combining(test_dataset.hr_file_paths)
assert all_hr_omegas.shape == (500, 81, HR_NX, HR_NY)  # (ensemble member, time index, HR_NX, HR_NY)
# The ensemble axis is split into equally sized chunks later on, so it must divide evenly.
assert all_hr_omegas.shape[0] % N_ENS_PER_CHUNK == 0

In [None]:
# Prior SR network: build from the config, move to DEVICE, load the weights, switch to eval mode.
prior = ConvSrNet(**RESULT_AND_CONFIG["config"]["model"]["prior_model"])
prior = prior.to(DEVICE)
prior.load_state_dict(
    torch.load(RESULT_AND_CONFIG["prior_weight_path"], map_location=DEVICE)
)
_ = prior.eval()

# CVAE used for the analysis step: same procedure with its own config and weights.
cvae = Conv2dCvae(**RESULT_AND_CONFIG["config"]["model"]["vae_model"])
cvae = cvae.to(DEVICE)
cvae.load_state_dict(
    torch.load(RESULT_AND_CONFIG["cvae_weight_path"], map_location=DEVICE)
)
_ = cvae.eval()

In [None]:
# Jet forcing on the LR grid for a single member (ne=1); only the second
# return value of `calc_jet_forcing` is kept.
_, lr_forcing = calc_jet_forcing(
    nx=LR_NX, ny=LR_NY, ne=1, y0=Y0_MEAN, sigma=SIGMA_MEAN, tau0=TAU0_MEAN
)

# Accumulators that collect one entry per ensemble chunk.
all_hr_obsrv = []
all_lr_omega = []
all_sr_analysis = []
all_sr_prior = []
all_lr_prior = []

In [None]:
# Run SR-DA chunk by chunk over the ensemble axis of `all_hr_omegas`.
#
# For every chunk, two LR CFD models are created: `lr_model` runs freely,
# while `srda_model` is re-initialized from the CVAE analysis every
# ASSIMILATION_PERIOD cycles. At each cycle the LR states, the (masked)
# observations and the prior SR prediction are recorded.

# (Re)create the accumulators here so that re-running this cell starts from
# empty lists instead of appending duplicated chunks to earlier results.
all_hr_obsrv, all_lr_omega = [], []
all_sr_analysis, all_sr_prior = [], []
all_lr_prior = []

# Loop-invariant switch read from the config.
use_observation = RESULT_AND_CONFIG["config"]["data"]["use_observation"]

for hr_omegas in tqdm(
    torch.split(all_hr_omegas, N_ENS_PER_CHUNK),
    total=(all_hr_omegas.shape[0] // N_ENS_PER_CHUNK),
):
    lr_model, srda_model = make_cfd_models(cfd_config=LR_CFD_CONFIG)

    initialize_models(
        t0=T0,
        hr_omega0=hr_omegas[:, 0],  # first time index
        lr_forcing=lr_forcing,
        lr_model=lr_model,
        srda_model=srda_model,
        **INDEX_CONFIG,
    )

    hr_obsrv = get_observation_with_noise(hr_omegas, test_dataset, **INDEX_CONFIG)
    assert hr_obsrv.shape == hr_omegas.shape == (N_ENS_PER_CHUNK, 81, HR_NX, HR_NY)

    # Per-chunk time series; one entry is appended per cycle.
    ts, hr_obs, lr_omega, lr_forecast = [], [], [], []
    sr_analysis, sr_prior = [], []
    lr_prior = []
    last_omega0 = None  # LR analysis state of the previous assimilation (None before the first one)
    max_time_index = hr_omegas.shape[1]

    for i_cycle in tqdm(range(max_time_index)):
        logger.debug(f"Start: i_cycle = {i_cycle}, t = {lr_model.t:.2f}")

        ts.append(lr_model.t)
        lr_omega.append(lr_model.omega.cpu().clone())
        lr_forecast.append(srda_model.omega.cpu().clone())

        # Observations are only available at assimilation times; every other
        # cycle gets a NaN-filled placeholder of the same shape.
        if use_observation and i_cycle % ASSIMILATION_PERIOD == 0:
            hr_obs.append(hr_obsrv[:, i_cycle])
        else:
            hr_obs.append(torch.full_like(hr_obsrv[:, i_cycle], torch.nan))

        lr_prior.append(srda_model.omega.cpu().clone())

        # Calc prior HR vorticity
        # Preprocess: add a channel axis, drop the last y point, normalize,
        # clamp, and swap the two spatial axes for the network input.
        x = srda_model.omega.clone()
        x = x[:, None, :, :-1].to(torch.float32)
        x = (x - test_dataset.vorticity_bias) / test_dataset.vorticity_scale
        x = torch.clamp(x, min=test_dataset.clamp_min, max=test_dataset.clamp_max)
        x = x.permute(0, 1, 3, 2).contiguous()
        assert x.shape == (N_ENS_PER_CHUNK, 1, LR_NY - 1, LR_NX)

        with torch.no_grad():
            pred = prior(x.to(DEVICE)).detach().cpu().clone()

        assert pred.shape == (N_ENS_PER_CHUNK, 1, HR_NY - 1, HR_NX)
        # Squeeze only the channel axis: a bare squeeze() would also remove
        # the batch axis when the chunk holds a single ensemble member.
        pred = pred.squeeze(1).permute(0, 2, 1)
        pred = pred * test_dataset.vorticity_scale
        pred = pred + test_dataset.vorticity_bias
        # Put the prediction back on the full HR grid; the last y point,
        # dropped before the network, stays zero.
        p = torch.zeros((N_ENS_PER_CHUNK, HR_NX, HR_NY), dtype=torch.float32)
        p[..., :-1] = pred

        sr_prior.append(p)

        if i_cycle > 0 and i_cycle % ASSIMILATION_PERIOD == 0:
            x = make_preprocessed_lr(
                lr_forecast,
                last_omega0,
                test_dataset,
                **INDEX_CONFIG,
            )

            o = make_preprocessed_obs(
                hr_obs,
                test_dataset,
                **INDEX_CONFIG,
            )

            # last time step
            x = x[:, -1]
            o = o[:, -1]

            with torch.no_grad():
                _, sr, _ = cvae(x, o)  # return mu
                sr = sr.detach().cpu().clone()

            assert sr.shape == (N_ENS_PER_CHUNK, 1, HR_NY - 1, HR_NX)

            # Repeat the single analysis over the whole assimilation window.
            sr = torch.broadcast_to(
                sr[:, None, ...],
                (N_ENS_PER_CHUNK, ASSIMILATION_PERIOD + 1, 1, HR_NY - 1, HR_NX),
            )
            sr = make_invprocessed_sr(
                sr,
                test_dataset,
                **INDEX_CONFIG,
            )

            # After the inverse processing the leading axis is indexed as
            # time below (sr[-1] is the newest step).
            assert sr[-1].shape == (N_ENS_PER_CHUNK, HR_NX, HR_NY)
            last_omega0 = low_pass_filter.apply(sr[-1])

            # Restart the SR-DA model from the low-pass-filtered analysis.
            srda_model.initialize(t0=srda_model.t, omega0=last_omega0)
            srda_model.calc_grid_data()

            # The first window contributes all of its steps; later windows
            # skip their first step (presumably because it coincides with
            # the last step of the previous window - verify).
            i_start = 0 if len(sr_analysis) == 0 else 1
            for it in range(i_start, sr.shape[0]):
                sr_analysis.append(sr[it])

            logger.debug(f"Assimilation at i = {i_cycle}")

        lr_model.time_integrate(dt=LR_DT, nt=LR_NT, hide_progress_bar=True)
        lr_model.calc_grid_data()

        srda_model.time_integrate(dt=LR_DT, nt=LR_NT, hide_progress_bar=True)
        srda_model.calc_grid_data()

    # Stack along time dim
    all_hr_obsrv.append(torch.stack(hr_obs, dim=1))
    all_lr_omega.append(torch.stack(lr_omega, dim=1))
    all_sr_analysis.append(torch.stack(sr_analysis, dim=1))
    all_sr_prior.append(torch.stack(sr_prior, dim=1))
    all_lr_prior.append(torch.stack(lr_prior, dim=1))

    # Drop the models, collect garbage first, then release cached GPU memory,
    # so that tensors kept alive only by reference cycles are freed as well.
    del lr_model, srda_model
    _ = gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Concatenate the per-chunk tensors along the ensemble (batch) dim and
# convert everything to float32 NumPy arrays.
all_sr_prior = torch.cat(all_sr_prior, dim=0).float().numpy()
all_sr_analysis = torch.cat(all_sr_analysis, dim=0).float().numpy()
all_hr_obsrv = torch.cat(all_hr_obsrv, dim=0).float().numpy()
all_lr_omega = torch.cat(all_lr_omega, dim=0).float().numpy()
all_lr_prior = torch.cat(all_lr_prior, dim=0).float().numpy()
# The HR ground truth is already one tensor, so it is only converted.
all_hr_omegas = all_hr_omegas.float().numpy()

# Every HR-grid array must have the full (500, 81, HR_NX, HR_NY) shape ...
for _hr_array in (all_hr_obsrv, all_sr_prior, all_sr_analysis, all_hr_omegas):
    assert _hr_array.shape == (500, 81, HR_NX, HR_NY)

# ... and every LR-grid array the corresponding LR shape.
for _lr_array in (all_lr_omega, all_lr_prior):
    assert _lr_array.shape == (500, 81, LR_NX, LR_NY)

lr_prior_file_path = f"{SRDA_DATA_DIR}/lr_prior_{CONFIG_NAME}.npy"

# Write the arrays to disk.
for _path, _array in (
    (sr_prior_file_path, all_sr_prior),
    (hr_obsrv_file_path, all_hr_obsrv),
    (lr_omega_file_path, all_lr_omega),
    (sr_analysis_file_path, all_sr_analysis),
    (hr_omega_file_path, all_hr_omegas),
    (lr_prior_file_path, all_lr_prior),
):
    np.save(_path, _array)