In [None]:
# Re-import changed modules (e.g. the project's src package) automatically
# before each cell executes, so library edits take effect without a restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# Send root-logger output to stdout at INFO level. The guard prevents re-runs of
# this cell from stacking duplicate handlers; a handler counts as "already
# attached" when its repr mentions StreamHandler.
logger = getLogger()
if all("StreamHandler" not in str(h) for h in logger.handlers):
    logger.addHandler(StreamHandler(sys.stdout))
logger.setLevel(INFO)

# Import libraries

In [None]:
import gc
import glob
import os
import pathlib
import random
from collections import OrderedDict

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from cfd_model.interpolator.torch_interpolator import (
    interpolate,
    interpolate_time_series,
)
from IPython.display import display
from src.dataloader import (
    make_dataloaders_vorticity_making_observation_inside_time_series_splitted,
)
from src.model_maker import make_model
from src.ssim import SSIM
from src.utils import AverageMeter, set_seeds
from tqdm.notebook import tqdm

# Raise pandas' display limits so wide/long frames are shown in full.
pd.options.display.max_columns = 500
pd.options.display.max_rows = 500

In [None]:
# cuBLAS needs a fixed workspace configuration for deterministic results; set it
# before seeding with deterministic algorithms enabled.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=r":4096:8")
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
def _root_dir_from_pythonpath(pythonpath: str) -> str:
    """Return the resolved parent directory of the first entry of ``pythonpath``.

    PYTHONPATH may list several directories separated by ``os.pathsep``.
    Treating the whole value as one path would silently yield a wrong directory
    when more than one entry is present. Empty entries are skipped. If none
    remain, the result is the parent of the current working directory, which is
    what the single-path expression produced for an empty string.
    """
    first_entry = next(
        (entry for entry in pythonpath.split(os.pathsep) if entry), ""
    )
    return str((pathlib.Path(first_entry) / "..").resolve())


# Project root: assumed to be the parent of the first PYTHONPATH entry — TODO confirm.
# Raises KeyError if PYTHONPATH is not set.
ROOT_DIR = _root_dir_from_pythonpath(os.environ["PYTHONPATH"])
ROOT_DIR

In [None]:
# Folder for temporary data, relative to the notebook's working directory;
# created (with any missing parents) if it does not exist yet.
TMP_DATA_DIR = "./data"
pathlib.Path(TMP_DATA_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Folder for CSV outputs, relative to the notebook's working directory;
# created (with any missing parents) if it does not exist yet.
CSV_DATA_DIR = "./csv"
pathlib.Path(CSV_DATA_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# All YAML configs of the paper_experiment_06 runs, in lexicographic path order.
CONFIG_DIR = f"{ROOT_DIR}/pytorch/config/paper_experiment_06"
CONFIG_PATHS = glob.glob(f"{CONFIG_DIR}/*.yml")
CONFIG_PATHS.sort()

In [None]:
# Time-window indices (presumably in units of saved snapshots, i.e. of DT below — TODO confirm).
ASSIMILATION_PERIOD, FORECAST_SPAN = 4, 4
START_TIME_INDEX, MAX_START_TIME_INDEX = 16, 92

# Low-resolution (LR) grid size and solver time stepping.
LR_NX, LR_NY = 32, 17
LR_DT, LR_NT = 5e-4, 500

# High-resolution (HR) grid size; `plot` draws HR fields on an HR_NX x (HR_NY - 1) mesh.
HR_NX, HR_NY = 128, 65

# Mean values of the y0 / sigma / tau0 parameters (their meaning is defined by
# the simulation code, which is not shown here).
Y0_MEAN, SIGMA_MEAN, TAU0_MEAN = np.pi / 2.0, 0.4, 0.3

# Physical coefficients; diffusion is set separately for each resolution.
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
HR_COEFF_DIFFUSION, LR_COEFF_DIFFUSION = 1e-5, 5e-5

# Derived times. DT spans LR_NT steps of LR_DT (= 0.25). T0 is the time at
# START_TIME_INDEX (= 4.0, identical to START_TIME_INDEX * LR_DT * LR_NT).
DT = LR_DT * LR_NT
T0 = START_TIME_INDEX * DT

# Ensemble members per chunk (not referenced by the cells below).
N_ENS_PER_CHUNK = 125

In [None]:
# The model and input batches are moved to this GPU; there is no CPU fallback.
DEVICE = "cuda:1"
if not torch.cuda.is_available():
    raise Exception("No GPU is available, but this notebook requires CUDA.")
# Fail fast when the machine has fewer GPUs than DEVICE's index requires,
# instead of erroring later inside model.to(DEVICE).
_device_index = torch.device(DEVICE).index or 0
if _device_index >= torch.cuda.device_count():
    raise Exception(
        f"{DEVICE} is not available: only {torch.cuda.device_count()} GPU(s) detected."
    )

In [None]:
# Map each YAML file name (up to its first ".") to its parsed config, plus the
# locations of that run's saved weights and learning history.
CONFIGS = OrderedDict()

for num, config_path in enumerate(sorted(CONFIG_PATHS)):
    with open(config_path) as file:
        config = yaml.safe_load(file)

    config_name = os.path.basename(config_path).partition(".")[0]
    assert config_name not in CONFIGS

    # Name of the directory that holds the YAML file (e.g. "paper_experiment_06").
    experiment_name = config_path.rsplit("/", 2)[-2]
    _dir = f"{ROOT_DIR}/data/pytorch/DL_results/{experiment_name}/{config_name}"

    CONFIGS[config_name] = dict(
        config=config,
        model_name=config["model"]["model_name"],
        experiment_name=experiment_name,
        weight_path=f"{_dir}/weights.pth",
        learning_history_path=f"{_dir}/learning_history.csv",
        number=num,
    )

In [None]:
# Lookup from an integer k (0 or 4..16) to a ratio. For k > 0 each value is
# approximately 1 / k**2 (several match float32 rounding). Presumably this is the
# fraction of grid points observed when observations are spaced k cells apart —
# TODO confirm. Not referenced by the cells below.
OBS_GRID_RATIO = dict(
    [
        (0, 0.0),
        (4, 0.06250000093132257),
        (5, 0.03999999910593033),
        (6, 0.027777777363856632),
        (7, 0.02040816326530612),
        (8, 0.015625000116415322),
        (9, 0.012345679127323775),
        (10, 0.010000000149011612),
        (11, 0.008264463206306716),
        (12, 0.006944444625534945),
        (13, 0.005917159876284691),
        (14, 0.005102040977882487),
        (15, 0.004444444572759999),
        (16, 0.003906250014551915),
    ]
)

# Define methods

In [None]:
def _project_to_plot_space(
    label: str, data: np.ndarray, use_hr_space: bool
) -> np.ndarray:
    """Resample one panel's field onto the grid being plotted and squeeze it to 2-D.

    When ``use_hr_space`` is True, the "LR" field is upsampled to the HR grid
    with nearest-neighbour interpolation and other fields are used as-is. When
    it is False, every field except "LR" is resampled to the LR grid with
    ``interpolate``'s default mode.

    Raises:
        AssertionError: if the squeezed field does not match the target grid.
    """
    if label == "LR":
        if use_hr_space:
            data = interpolate(
                torch.from_numpy(data[None, :]),
                nx=HR_NX,
                ny=HR_NY - 1,
                mode="nearest",
            ).numpy()
    elif not use_hr_space:
        data = interpolate(
            torch.from_numpy(data[None, :]), nx=LR_NX, ny=LR_NY - 1
        ).numpy()

    field = np.squeeze(data)
    expected_shape = (HR_NX, HR_NY - 1) if use_hr_space else (LR_NX, LR_NY - 1)
    assert field.shape == expected_shape, (
        f"{label}: expected shape {expected_shape}, got {field.shape}"
    )
    return field


def plot(
    dict_data: dict,
    t: float,
    obs: np.ndarray,
    figsize: tuple = (20, 2),
    write_out: bool = False,
    ttl_header: str = "",
    fig_file_name: str = "",
    use_hr_space: bool = True,
    vmin_omega: float = -9,
    vmax_omega: float = 9,
    fig_dir: "str | None" = None,
):
    """Draw fields side by side, each titled with its MAE against the "HR" field.

    Each entry of ``dict_data`` becomes one panel, in dict order. All panels are
    drawn on the HR grid (``use_hr_space=True``) or on the LR grid (``False``).

    Args:
        dict_data: Mapping from panel label to a field array. The label "HR"
            marks the ground truth (MAE reference). "LR" marks a field stored on
            the LR grid. After resampling, every field must squeeze to
            (HR_NX, HR_NY - 1) or (LR_NX, LR_NY - 1) respectively.
        t: Snapshot time. Currently not drawn.
        obs: Observation field on the HR grid. Its non-NaN cells are scattered as
            black dots on the "HR" panel when ``use_hr_space`` is True.
        figsize: Matplotlib figure size.
        write_out: If True, save the figure as ``{fig_dir}/{fig_file_name}.jpg``.
        ttl_header: Currently not drawn (see ``t``).
        fig_file_name: File stem used when ``write_out`` is True.
        use_hr_space: Plot on the HR grid if True, otherwise on the LR grid.
        vmin_omega: Lower colour limit shared by all panels.
        vmax_omega: Upper colour limit shared by all panels.
        fig_dir: Output directory used when ``write_out`` is True. Defaults to a
            global ``FIG_DIR`` if one is defined, else the current directory.

    Raises:
        ValueError: if ``dict_data`` is empty.
    """
    if not dict_data:
        raise ValueError("dict_data must contain at least one field to plot.")

    nx, ny = (HR_NX, HR_NY) if use_hr_space else (LR_NX, LR_NY)
    x, y = np.meshgrid(
        np.linspace(0, 2 * np.pi, num=nx, endpoint=False),
        np.linspace(0, np.pi, num=ny - 1, endpoint=False),
        indexing="ij",
    )

    # Project the ground truth up front, so each panel's MAE works no matter
    # where "HR" appears in dict_data (or is skipped when "HR" is absent).
    gt = (
        _project_to_plot_space("HR", dict_data["HR"], use_hr_space)
        if "HR" in dict_data
        else None
    )

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.size"] = 18
    # squeeze=False keeps `axes` a 2-D array even when there is a single panel.
    fig, axes = plt.subplots(
        1, len(dict_data), figsize=figsize, sharex=True, sharey=False, squeeze=False
    )

    for ax, (label, data) in zip(axes[0], dict_data.items()):
        if label == "HR":
            field = gt
            title = "HR Ground Truth"
        else:
            field = _project_to_plot_space(label, data, use_hr_space)
            if gt is None:
                title = label
            else:
                mae = np.mean(np.abs(gt - field))
                title = f"{label}\nMAE={mae:.2f}"

        mesh = ax.pcolormesh(
            x, y, field, cmap="twilight_shifted", vmin=vmin_omega, vmax=vmax_omega
        )
        ax.set_title(title)
        fig.colorbar(
            mesh,
            ax=ax,
            ticks=[vmin_omega, vmin_omega / 2, 0, vmax_omega / 2, vmax_omega],
            extend="both",
        )
        ax.set_xlim([0, 2 * np.pi])
        ax.set_ylim([0, np.pi])

        if label == "HR" and use_hr_space:
            # Mark observed grid points (non-NaN entries of obs) on the truth panel.
            obs_flat = np.squeeze(obs).flatten()
            observed = ~np.isnan(obs_flat)
            ax.scatter(
                x.flatten()[observed], y.flatten()[observed], marker=".", s=1, c="k"
            )

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    plt.tight_layout()

    if write_out:
        out_dir = fig_dir if fig_dir is not None else globals().get("FIG_DIR", ".")
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f"{out_dir}/{fig_file_name}.jpg")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Examine observation features

In [None]:
# Config to analyse below; must be a key of CONFIGS (a YAML file stem in CONFIG_DIR).
config_name: str = "lt4og08_on1e-01_ep1000_lr1e-04_scT_bT_muT_a02_b02_sd359178"

In [None]:
# Rebuild the network described by the selected config and move it to DEVICE.
model = make_model(CONFIGS[config_name]["config"])
model = model.to(DEVICE)

# Restore the trained weights (mapped directly onto DEVICE) and switch to eval mode.
# NOTE(review): torch.load unpickles the file, so only load weights from trusted
# runs (weights_only=True is safer on PyTorch versions that support it).
model.load_state_dict(
    torch.load(CONFIGS[config_name]["weight_path"], map_location=DEVICE)
)
_ = model.eval()

In [None]:
# Build only the "test" split's dataloader for the selected config; the second
# value returned by the factory is not needed here.
dict_dataloaders, _ = make_dataloaders_vorticity_making_observation_inside_time_series_splitted(
    root_dir=ROOT_DIR,
    config=CONFIGS[config_name]["config"],
    train_valid_test_kinds=["test"],
)

In [None]:
def _squeeze_keep_batch(tensor: torch.Tensor) -> torch.Tensor:
    """Drop every size-1 dimension except the leading (batch) one.

    Unlike ``Tensor.squeeze()``, a batch of size 1 keeps its batch axis, so the
    per-batch chunks can always be concatenated along dim 0. For batches larger
    than 1 the result equals ``tensor.squeeze()``.
    """
    kept_sizes = [size for size in tensor.shape[1:] if size != 1]
    return tensor.reshape(tensor.shape[0], *kept_sizes)


# For every test batch, collect index 0 along dim 1 of the observation, of the
# model's observation feature and latent, and of the ground truth.
encoder_block = 2
obs_chunks, feat_chunks, latent_chunks, gt_chunks = [], [], [], []
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for Xs, obs, gt in dict_dataloaders["test"]:
        feat, latent = model.get_obs_feature(
            Xs.to(DEVICE), obs.to(DEVICE), encoder_block=encoder_block
        )
        # NOTE(review): the (3, 64, 64, 128) layout is hard-coded for this
        # config/encoder_block; the size-3 axis is split out of the leading dim —
        # confirm its meaning against model.get_obs_feature.
        feat = feat.reshape(-1, 3, 64, 64, 128).cpu()
        latent = latent.reshape((-1, 3) + latent.shape[-3:]).cpu()

        obs_chunks.append(_squeeze_keep_batch(obs[:, 0]))
        feat_chunks.append(_squeeze_keep_batch(feat[:, 0]))
        latent_chunks.append(_squeeze_keep_batch(latent[:, 0]))
        gt_chunks.append(_squeeze_keep_batch(gt[:, 0]))

all_obs = torch.cat(obs_chunks, dim=0).numpy()
all_feat = torch.cat(feat_chunks, dim=0).numpy()
all_latent = torch.cat(latent_chunks, dim=0).numpy()
all_gt = torch.cat(gt_chunks, dim=0).numpy()

In [None]:
# Display the stacked feature/latent array shapes as a sanity check.
(all_feat.shape, all_latent.shape)

In [None]:
# Test-set samples to inspect; one figure is drawn per (channel j, sample i).
SAMPLE_INDICES = [6]
# Channels present in both arrays (all_feat has 64 from its reshape). Taking the
# minimum avoids an IndexError if the latent has fewer channels.
n_channels = min(all_feat.shape[1], all_latent.shape[1])


def _robust_limits(values: np.ndarray) -> tuple:
    """Return the 5th and 95th percentiles over all elements of ``values``."""
    low, high = np.quantile(values, [0.05, 0.95])
    return low, high


# Colour limits depend only on the sample, not the channel, so compute them once.
feat_limits = {i: _robust_limits(all_feat[i]) for i in SAMPLE_INDICES}
latent_limits = {i: _robust_limits(all_latent[i]) for i in SAMPLE_INDICES}

for j in range(n_channels):
    for i in SAMPLE_INDICES:
        fig, axes = plt.subplots(1, 4, figsize=[20, 4])
        panels = [
            ("Observation value", all_obs[i], None),
            ("Observation feature", all_feat[i][j], feat_limits[i]),
            ("Observation latent feature", all_latent[i][j], latent_limits[i]),
            ("Ground truth", all_gt[i], None),
        ]
        for ax, (title, image, limits) in zip(axes, panels):
            vmin, vmax = limits if limits is not None else (None, None)
            ax.pcolormesh(image, vmin=vmin, vmax=vmax)
            ax.set_title(title)

        fig.suptitle(f"i = {i}, j = {j}")
        plt.show()
        # Close explicitly: this loop can produce dozens of figures.
        plt.close(fig)