In [None]:
# Re-import edited modules (e.g. the project's src.* helpers) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# getLogger() without a name returns the root logger, so records from imported
# modules are shown as well.
logger = getLogger()
logger.setLevel(INFO)
# Attach a stdout handler only once: re-running this cell must not add a second
# handler (which would print every record twice).
needs_handler = not logger.hasHandlers()
if needs_handler:
    logger.addHandler(StreamHandler(sys.stdout))

# Import libraries

In [None]:
import glob
import os
import pathlib
import time
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import yaml
from cfd_model.filter.low_pass_periodic_channel_domain import LowPassFilter
from cfd_model.initialization.periodic_channel_jet_initializer import (
    calc_init_omega,
    calc_init_perturbation_hr_omegas,
    calc_jet_forcing,
)
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from IPython.display import display
from src.sr_da_helper_2 import (
    get_observation_with_noise,
    get_testdataset,
    initialize_and_itegrate_srda_cfd_model_for_forecast,
    initialize_models,
    make_invprocessed_sr_for_forecast,
    make_models,
    make_preprocessed_lr_for_forecast,
    make_preprocessed_obs_for_forecast,
    read_all_hr_omegas_with_combining_for_forecast,
)
from src.ssim import SSIM
from src.utils import set_seeds
from tqdm.notebook import tqdm

plt.rcParams["font.family"] = "serif"  # serif font for all figures
# Raise pandas display limits so wide/long result tables are not truncated.
for _display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(_display_option, 500)

In [None]:
# cuBLAS requires a fixed workspace configuration for deterministic results when
# deterministic algorithms are enabled; set it before seeding.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# PYTHONPATH may list several directories separated by os.pathsep; using the raw
# string as a single path would yield a bogus ROOT_DIR. This assumes the project's
# source directory is the first entry, and takes ROOT_DIR as its parent directory.
_first_pythonpath_entry = os.environ["PYTHONPATH"].split(os.pathsep)[0]
ROOT_DIR = str((pathlib.Path(_first_pythonpath_entry) / "..").resolve())
ROOT_DIR

In [None]:
# --- Data-assimilation cycle -------------------------------------------------
# Observations are assimilated every ASSIMILATION_PERIOD steps and each cycle
# also forecasts FORECAST_SPAN steps ahead.
ASSIMILATION_PERIOD, FORECAST_SPAN = 4, 4
NUM_SIMULATIONS = 1

# Start-time index range passed to get_testdataset, and the number of time
# steps processed by the forecast loop.
MIN_START_TIME_INDEX, MAX_START_TIME_INDEX = -1, 88
START_TIME_INDEX = 0
NUM_TIMES = MAX_START_TIME_INDEX + ASSIMILATION_PERIOD + FORECAST_SPAN

# --- Grids -------------------------------------------------------------------
# Low-resolution (LR) model grid and its time stepping (step size, steps per snapshot).
LR_NX, LR_NY = 32, 17
LR_DT, LR_NT = 5e-4, 500

# High-resolution (HR) and ultra-high-resolution (UHR) grids.
HR_NX, HR_NY = 128, 65
UHR_NX, UHR_NY = 1024, 513

# --- Jet initial condition / forcing -----------------------------------------
Y0 = np.pi / 2.0
SIGMA = 0.4
U0 = 3.0
TAU0 = 0.3
PERTUB_NOISE = 0.0025

# --- CFD model coefficients ----------------------------------------------------
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
HR_COEFF_DIFFUSION = 1e-5
LR_COEFF_DIFFUSION = 5e-5

# Model time advanced by LR_NT steps of size LR_DT, and the start time.
DT = LR_DT * LR_NT
T0 = START_TIME_INDEX * LR_DT * LR_NT

In [None]:
DEVICE = "cpu"

# Only a CUDA device needs a GPU. The previous unconditional check raised even
# when DEVICE was "cpu", so the CPU benchmark could not run on CPU-only machines.
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    raise RuntimeError(f"DEVICE={DEVICE!r} requires a GPU, but CUDA is not available.")

In [None]:
# Settings of the low-resolution CFD model; passed to make_models and to
# initialize_and_itegrate_srda_cfd_model_for_forecast.
LR_CFD_CONFIG = dict(
    nx=LR_NX,
    ny=LR_NY,
    coeff_linear_drag=COEFF_LINEAR_DRAG,
    coeff_diffusion=LR_COEFF_DIFFUSION,
    order_diffusion=ORDER_DIFFUSION,
    beta=BETA,
    device=DEVICE,
    dt=LR_DT,
    nt=LR_NT,
)

# Keyword arguments shared by the sr_da_helper_2 pre/post-processing helpers
# (forwarded to them via **INDEX_CONFIG).
INDEX_CONFIG = dict(
    assimilation_period=ASSIMILATION_PERIOD,
    forecast_span=FORECAST_SPAN,
    n_ens=1,
    lr_nx=LR_NX,
    lr_ny=LR_NY,
    hr_nx=HR_NX,
    hr_ny=HR_NY,
    device=DEVICE,
)

In [None]:
# Ratio per grid interval: each value is approximately 1 / key**2
# (e.g. 4 -> ~1/16, 8 -> ~1/64), and 0 maps to 0.0. Presumably this is the fraction
# of HR grid points that are observed for that interval — confirm.
# Several values deviate slightly from the exact 1 / key**2, so they are kept as
# literal constants rather than recomputed.
# NOTE(review): not referenced in the visible cells of this notebook.
OBS_GRID_RATIO = {
    0: 0.0,
    4: 0.06250000093132257,
    5: 0.03999999910593033,
    6: 0.027777777363856632,
    7: 0.02040816326530612,
    8: 0.015625000116415322,
    9: 0.012345679127323775,
    10: 0.010000000149011612,
    11: 0.008264463206306716,
    12: 0.006944444625534945,
    13: 0.005917159876284691,
    14: 0.005102040977882487,
    15: 0.004444444572759999,
    16: 0.003906250014551915,
}

# Define helper functions

In [None]:
def get_initial_hr_omega(seed: int):
    """Build the initial HR vorticity from the jet profile plus a random perturbation.

    The jet is computed on the HR grid with ``calc_jet_forcing``; only its first
    return value is used. A perturbation of amplitude ``PERTUB_NOISE`` is drawn
    with ``seed``. Both are combined by ``calc_init_omega`` with ``u0=U0``.

    Args:
        seed: Seed forwarded to ``calc_init_perturbation_hr_omegas``.

    Returns:
        Whatever ``calc_init_omega`` returns. The caller asserts the shape
        ``(1, HR_NX, HR_NY)`` for ``NUM_SIMULATIONS == 1``.
    """
    hr_grid = {"nx": HR_NX, "ny": HR_NY, "ne": NUM_SIMULATIONS}

    hr_jet, _unused_forcing = calc_jet_forcing(**hr_grid, y0=Y0, sigma=SIGMA, tau0=TAU0)

    hr_perturb = calc_init_perturbation_hr_omegas(
        **hr_grid, noise_amp=PERTUB_NOISE, seed=seed
    )

    return calc_init_omega(perturb_omega=hr_perturb, jet=hr_jet, u0=U0)

In [None]:
def get_uhr_and_hr_omegas(uhr_result_dir: str):
    """Load UHR vorticity snapshots and block-average them onto the HR grid.

    The ``*.npy`` files in ``uhr_result_dir`` are read in sorted file-name order.
    The first ``NUM_TIMES`` of them form the time series.

    Args:
        uhr_result_dir: Directory holding one ``.npy`` snapshot per time step, each
            squeezable to shape ``(UHR_NX, UHR_NY)``.

    Returns:
        Tuple ``(all_uhr_omegas, all_hr_omegas)`` of tensors with shapes
        ``(NUM_TIMES, UHR_NX, UHR_NY)`` and ``(NUM_TIMES, HR_NX, HR_NY)``.

    Raises:
        AssertionError: if fewer than ``NUM_TIMES`` snapshots are found, a snapshot
            has an unexpected shape, or the UHR grid is not an integer refinement
            of the HR grid.
    """
    paths = sorted(glob.glob(f"{uhr_result_dir}/*.npy"))
    assert len(paths) >= NUM_TIMES, (
        f"Expected at least {NUM_TIMES} .npy files in {uhr_result_dir}, "
        f"found {len(paths)}"
    )

    # Integer coarsening factor between the UHR and HR grids (1024 -> 128 gives 8).
    factor = UHR_NX // HR_NX
    assert UHR_NX == factor * HR_NX and UHR_NY - 1 == factor * (HR_NY - 1), (
        f"UHR grid ({UHR_NX}, {UHR_NY}) is not a {factor}x refinement of "
        f"HR grid ({HR_NX}, {HR_NY})"
    )

    all_uhr_omegas = []
    # Only the first NUM_TIMES snapshots are used, so the remaining files are not loaded.
    for path in paths[:NUM_TIMES]:
        uhr = torch.from_numpy(np.load(path)).squeeze()
        assert uhr.shape == (UHR_NX, UHR_NY), (
            f"{path}: unexpected shape {tuple(uhr.shape)}"
        )
        all_uhr_omegas.append(uhr)
    # Stack along time dim
    all_uhr_omegas = torch.stack(all_uhr_omegas)
    assert all_uhr_omegas.shape == (NUM_TIMES, UHR_NX, UHR_NY)

    # y index 0 is excluded from the averaging and stays zero on the HR grid
    # (presumably a boundary row — confirm). The remaining UHR_NY - 1 rows are
    # averaged over factor x factor blocks. squeeze(1) removes only the channel
    # axis added for avg_pool2d.
    hr_interior = F.avg_pool2d(
        all_uhr_omegas[:, None, :, 1:], kernel_size=factor
    ).squeeze(1)

    all_hr_omegas = torch.zeros((NUM_TIMES, HR_NX, HR_NY), dtype=hr_interior.dtype)
    all_hr_omegas[:, :, 1:] = hr_interior

    return all_uhr_omegas, all_hr_omegas

# Perform SR-DA (super-resolution data assimilation) forecasts and measure wall time

In [None]:
# GRID_INTERVAL: presumably the observation-grid spacing. It fills the "og" field
#   of CONFIG_NAME and offsets the seed of the initial perturbation.
# I_SEED_UHR: seed of the UHR reference run; it selects the seedNNNNN data directory.
GRID_INTERVAL, I_SEED_UHR = 8, 9999

In [None]:
# Directory holding the YAML configs of this experiment.
CONFIG_DIR = "/".join((ROOT_DIR, "pytorch", "config", "paper_experiment_06"))

In [None]:
# Ground-truth UHR snapshots of the reference run.
UHR_RESULT_DIR = f"{ROOT_DIR}/data/pytorch/CFD/jet27/seed{I_SEED_UHR:05}"

# Seeds (the "sd" field) of the available configs: 221958, 771155, 832180, 465838, 359178
CONFIG_NAME = (
    f"lt4og{GRID_INTERVAL:02}_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd771155"
)
CONFIG_PATH = f"{CONFIG_DIR}/{CONFIG_NAME}.yml"

with open(CONFIG_PATH) as config_file:
    CONFIG = yaml.safe_load(config_file)

# The experiment name is the directory that contains the config file.
experiment_name = pathlib.Path(CONFIG_PATH).parent.name
_dir = f"{ROOT_DIR}/data/pytorch/DL_results/{experiment_name}/{CONFIG_NAME}"

# Paths of the trained model's artifacts for this config.
CONFIG_INFO = dict(
    config=CONFIG,
    model_name=CONFIG["model"]["model_name"],
    experiment_name=experiment_name,
    weight_path=f"{_dir}/weights.pth",
    learning_history_path=f"{_dir}/learning_history.csv",
    log=f"{_dir}/log.txt",
)

In [None]:
test_dataset = get_testdataset(
    ROOT_DIR,
    CONFIG,
    min_start_time_index=MIN_START_TIME_INDEX,
    # NOTE(review): +1 presumably makes the upper bound inclusive — confirm in get_testdataset.
    max_start_time_index=MAX_START_TIME_INDEX + 1,
)

# The forecast loop below relies on these dataset settings; test booleans by
# truthiness instead of comparing with == True / == False (PEP 8).
assert test_dataset.obs_time_interval == ASSIMILATION_PERIOD, (
    f"obs_time_interval={test_dataset.obs_time_interval} != {ASSIMILATION_PERIOD}"
)
assert not test_dataset.is_output_only_last, "expected is_output_only_last to be False"
assert test_dataset.is_last_obs_missing, "expected is_last_obs_missing to be True"

In [None]:
# Initial HR vorticity for the SR-DA run. Its seed is derived from the UHR reference
# seed and GRID_INTERVAL, so different grid intervals get different perturbations.
init_hr_omega = get_initial_hr_omega(seed=I_SEED_UHR + GRID_INTERVAL * 100)
assert init_hr_omega.shape == (1, HR_NX, HR_NY)

# Only the second return value (the forcing on the LR grid) is needed here.
_, lr_forcing = calc_jet_forcing(
    nx=LR_NX,
    ny=LR_NY,
    ne=1,
    y0=Y0,
    sigma=SIGMA,
    tau0=TAU0,
)
assert lr_forcing.shape == (1, LR_NX, LR_NY)

# Filter configured for the LR/HR grid pair; handed to the LR model
# (re)initialization inside the forecast loop.
low_pass_filter = LowPassFilter(
    nx_lr=LR_NX, ny_lr=LR_NY, nx_hr=HR_NX, ny_hr=HR_NY, device=DEVICE
)

# Ground-truth UHR snapshots and their block-averaged HR counterparts.
uhr_omegas, hr_omegas = get_uhr_and_hr_omegas(UHR_RESULT_DIR)

# Noisy HR observations of the ground truth.
# NOTE(review): the noise may be drawn from the global RNG seeded earlier; keep the
# statement order of this cell unchanged so the observations stay reproducible.
hr_obsrvs = get_observation_with_noise(
    hr_omegas[None, ...],  # add ens channel (dummy channel)
    test_dataset,
    **INDEX_CONFIG,
).squeeze()

assert uhr_omegas.shape == (NUM_TIMES, UHR_NX, UHR_NY)
assert hr_omegas.shape == hr_obsrvs.shape == (NUM_TIMES, HR_NX, HR_NY)

In [None]:
NUM_BENCHMARK_RUNS = 5  # repetitions of the full SR-DA forecast used to average wall time

wall_times = []
for _ in range(NUM_BENCHMARK_RUNS):
    # Re-seed at the start of every run so all repetitions perform identical work.
    set_seeds(555, use_deterministic=True)

    sr_model, _, _ = make_models(CONFIG, CONFIG_INFO["weight_path"], LR_CFD_CONFIG)
    _ = sr_model.eval()

    last_t0 = T0
    last_hr_omega0 = init_hr_omega

    hr_obs, sr_forecast = [], []

    # perf_counter is monotonic and high-resolution, so it suits interval timing
    # better than time.time().
    start_time = time.perf_counter()

    for i_cycle in tqdm(range(NUM_TIMES)):
        # Observations are kept only every ASSIMILATION_PERIOD steps; other steps are
        # masked with NaN. A dedicated name is used here (not `o`) so that, after the
        # loop, `o` still holds the last preprocessed observation fed to the model
        # rather than the raw observation slice of the final step.
        obs_t = hr_obsrvs[i_cycle][None, ...]  # add channel dim
        if i_cycle % ASSIMILATION_PERIOD != 0:
            obs_t = torch.full_like(obs_t, torch.nan)
        hr_obs.append(obs_t)

        if i_cycle > 0 and i_cycle % ASSIMILATION_PERIOD == 0:
            # Integrate the LR model over the past window plus the forecast span.
            lr_forecast = []
            initialize_and_itegrate_srda_cfd_model_for_forecast(
                lr_forecast=lr_forecast,
                num_integrate_steps=ASSIMILATION_PERIOD + FORECAST_SPAN,
                last_t0=last_t0,
                last_hr_omega0=last_hr_omega0,
                lr_ens_forcing=lr_forcing,
                cfd_config=LR_CFD_CONFIG,
                low_pass_filter=low_pass_filter,
            )
            # One entry more than the number of steps (presumably the initial state).
            assert len(lr_forecast) == ASSIMILATION_PERIOD + FORECAST_SPAN + 1

            x = make_preprocessed_lr_for_forecast(
                lr_forecast,
                test_dataset,
                **INDEX_CONFIG,
            )
            o = make_preprocessed_obs_for_forecast(
                hr_obs,
                test_dataset,
                **INDEX_CONFIG,
            )
            # Debug level: printing inside the timed region floods the output and
            # adds I/O time to the benchmark.
            logger.debug(f"x.shape = {tuple(x.shape)}, o.shape = {tuple(o.shape)}")

            # Check num of time dims
            _sum = ASSIMILATION_PERIOD + FORECAST_SPAN
            _nt = int(_sum / CONFIG["data"]["lr_time_interval"] + 1)
            assert x.shape[1] == _nt
            assert o.shape[1] == _sum + 1

            with torch.no_grad():
                sr = sr_model(x, o).detach().cpu().clone()
            sr = make_invprocessed_sr_for_forecast(
                sr,
                test_dataset,
                **INDEX_CONFIG,
            )

            # NOTE(review): the next window starts at last_t0 + ASSIMILATION_PERIOD * DT.
            # If sr[0] corresponds to last_t0 (as lr_forecast[0] presumably does), that
            # time is sr[ASSIMILATION_PERIOD], not sr[ASSIMILATION_PERIOD + 1]. Confirm
            # the time offset produced by make_invprocessed_sr_for_forecast.
            last_hr_omega0 = sr[ASSIMILATION_PERIOD + 1].clone()
            last_t0 += ASSIMILATION_PERIOD * LR_DT * LR_NT

            # No forecast exists for the time steps before the first assimilation, so
            # NaN placeholders keep sr_forecast aligned with hr_obs along time.
            if len(sr_forecast) == 0:
                dummy = torch.full(
                    size=(ASSIMILATION_PERIOD,) + sr.shape[1:],
                    fill_value=torch.nan,
                    dtype=sr.dtype,
                )
                sr_forecast.extend(dummy.unbind(0))

            i_start = ASSIMILATION_PERIOD
            i_end = ASSIMILATION_PERIOD + FORECAST_SPAN
            sr_forecast.extend(sr[i_start:i_end].unbind(0))

            logger.debug(f"Assimilation at i = {i_cycle}")

    # Stack along time dim
    hr_obs = torch.stack(hr_obs, dim=1).squeeze()
    sr_forecast = torch.stack(sr_forecast, dim=1).squeeze()

    assert (
        hr_obs.shape
        == sr_forecast.shape
        == hr_omegas.shape
        == (NUM_TIMES, HR_NX, HR_NY)
    )

    wall_time = time.perf_counter() - start_time
    wall_times.append(wall_time)
    logger.info(f"Wall time = {wall_time} sec")

# Average over the runs so it no longer has to be computed by hand.
logger.info(
    f"Mean wall time over {len(wall_times)} runs = "
    f"{sum(wall_times) / len(wall_times)} sec"
)

In [None]:
# GPU
# Mean wall time [sec] over 5 benchmark runs on GPU. The values were presumably
# copied by hand from the "Wall time = ... sec" log lines of the benchmark cell.

(
    53.415621280670166
    + 51.91998839378357
    + 51.53233218193054
    + 50.935877084732056
    + 50.46408128738403
) / 5

In [None]:
# CPU
# Mean wall time [sec] over 5 benchmark runs on CPU. The values were presumably
# copied by hand from the "Wall time = ... sec" log lines of the benchmark cell.

(
    41.11348748207092
    + 40.52914905548096
    + 40.035480260849
    + 40.37577414512634
    + 40.252901554107666
) / 5

In [None]:
# oni01, CPU
# Mean wall time [sec] over 5 CPU benchmark runs on another machine
# ("oni01", presumably a host name).
(
    57.561357259750366
    + 56.41942834854126
    + 55.99148988723755
    + 58.471670150756836
    + 57.50953793525696
) / 5

In [None]:
import torchinfo

In [None]:
# Shapes of the last tensors bound to x and o by the benchmark loop.
# NOTE(review): check that `o` is the model's observation input here.
tuple(tensor.shape for tensor in (x, o))

In [None]:
# Dummy input sizes for the two model inputs (LR forecast, HR observations).
# Presumably laid out as (batch, time, channel, y, x) — TODO confirm against the
# shapes shown in the previous cell.
lr_input_size = (1, 3, 1, 16, 32)
obs_input_size = (1, 9, 1, 64, 128)
torchinfo.summary(sr_model, input_size=[lr_input_size, obs_input_size])

In [None]:
# Display the module structure of the SR model built in the last benchmark run.
sr_model