In [None]:
# Automatically reload edited modules (e.g. the project's `src` package)
# before each cell runs, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, WARNING, StreamHandler, getLogger

# Route root-logger output to stdout so it shows up in the notebook.
# The handler is attached only once: re-running this cell must not
# duplicate every log line.
logger = getLogger()
if all("StreamHandler" not in str(handler) for handler in logger.handlers):
    logger.addHandler(StreamHandler(sys.stdout))
logger.setLevel(INFO)

# Import libraries

In [None]:
import copy
import gc
import glob
import os
import pathlib
from collections import OrderedDict

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import yaml
from numpy.testing import assert_array_equal
from scipy import signal
from src.dataloader import make_evaluation_dataloader_without_random_cropping
from src.loss_maker import (
    AbsDiffDivergence,
    AbsDiffTemperature,
    ChannelwiseMse,
    DiffOmegaVectorNorm,
    DiffVelocityVectorNorm,
    MaskedL1Loss,
    MaskedL1LossNearWall,
    MaskedL2Loss,
    MaskedL2LossNearWall,
    MyL1Loss,
    MyL2Loss,
    ResidualContinuity,
    Ssim3dLoss,
    calc_mask_near_build_wall,
)
from src.model_maker import make_model
from src.optim_helper import evaluate
from src.utils import calc_early_stopping_patience, set_seeds
from tqdm.notebook import tqdm

# Widen pandas display limits so large result tables are not truncated.
pd.set_option("display.max_columns", 500, "display.max_rows", 500)

In [None]:
# A fixed cuBLAS workspace configuration is needed for deterministic GPU
# results; set it before any GPU work happens, then seed everything.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
set_seeds(42, use_deterministic=True)

# Define constants

In [None]:
# The project root is the parent of the directory on PYTHONPATH.
# PYTHONPATH may list several directories separated by os.pathsep; using the
# raw value would yield a nonsense path, so only the first entry is used.
# NOTE(review): confirm the first entry is the project's source directory.
ROOT_DIR = str(
    (pathlib.Path(os.environ["PYTHONPATH"].split(os.pathsep)[0]) / "..").resolve()
)
DL_DATA_DIR = pathlib.Path(f"{ROOT_DIR}/data/DL_data")
DL_INFERENCE_DIR = pathlib.Path(f"{ROOT_DIR}/data/DL_inferences")
EXPERIMENT_NAME = "unet_model"

In [None]:
if torch.cuda.is_available():
    DEVICE = "cuda"
    logger.info("GPU is used.")
else:
    DEVICE = "cpu"
    # Execution continues on the CPU (slowly); raise here instead if a GPU
    # should be mandatory for this notebook.
    logger.error("No GPU. CPU is used.")

In [None]:
# All YAML configs, in lexicographic order so iteration is deterministic.
CONFIG_PATHS = glob.glob(f"{ROOT_DIR}/pytorch/config/*.yml")
CONFIG_PATHS.sort()

In [None]:
# Map each config name to its parsed YAML plus the paths of its trained
# weights and learning history under DL_results/<EXPERIMENT_NAME>/<config_name>.
CONFIGS = OrderedDict()

for config_path in CONFIG_PATHS:
    config_file_name = os.path.basename(config_path)

    # Skip tutorial configs. Match on the file name only: testing the full path
    # would silently drop every config if ROOT_DIR itself contains "tutorial".
    if "tutorial" in config_file_name:
        continue
    # YAML is Unicode text; don't depend on the platform's default encoding.
    with open(config_path, encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # Name = file name up to its first "." (e.g. "foo.yml" -> "foo"); the
    # result directory below is keyed on this name.
    config_name = config_file_name.split(".")[0]
    assert config_name not in CONFIGS, f"Duplicate config name: {config_name}"

    _dir = f"{ROOT_DIR}/data/DL_results/{EXPERIMENT_NAME}/{config_name}"

    CONFIGS[config_name] = {
        "config": config,
        "model_name": config["model"]["model_name"],
        "experiment_name": EXPERIMENT_NAME,
        "weight_path": f"{_dir}/weights.pth",
        "learning_history_path": f"{_dir}/learning_history.csv",
    }

# Define helper functions

In [None]:
def dimensionalize(data, means, scales, num_channels=4):
    """Undo per-channel standardization, returning ``data * scales + means``.

    Works with any array type supporting broadcasting arithmetic and
    ``.ndim`` / ``.shape`` (e.g. NumPy arrays or torch tensors).

    Args:
        data: Normalized values whose axis 1 is the channel axis.
        means: Per-channel offsets with the same number of dimensions as
            ``data``, so they broadcast against it (e.g. shape
            ``(1, C, 1, 1, 1)``).
        scales: Per-channel scale factors (the notebook passes the training
            standard deviations), laid out like ``means``.
        num_channels: Expected size of ``data``'s channel axis. Defaults to 4.

    Returns:
        The broadcast result of ``data * scales + means``.

    Raises:
        AssertionError: If the three inputs differ in rank, or ``data`` does
            not have ``num_channels`` channels on axis 1.
    """
    assert data.ndim == means.ndim == scales.ndim, (
        f"rank mismatch: data={data.ndim}, means={means.ndim}, scales={scales.ndim}"
    )
    assert data.shape[1] == num_channels, (
        f"expected {num_channels} channels on axis 1, got {data.shape[1]}"
    )
    return data * scales + means

# Run inference

In [None]:
# For every config: load its trained weights, run the model on each test
# sample, de-normalize the prediction and save it as one .npy file per sample.
# Samples whose output file already exists are skipped, so an interrupted run
# can be resumed by re-running this cell.
for config_name, config_info in tqdm(CONFIGS.items(), total=len(CONFIGS)):

    inference_dir = DL_INFERENCE_DIR / config_info["experiment_name"] / config_name
    os.makedirs(inference_dir, exist_ok=True)

    logger.info(f"\n{config_name} is being evaluated.")

    config = config_info["config"]

    # batch_size=1 so each batch corresponds to exactly one entry of
    # `test_file_paths` below.
    test_loader = make_evaluation_dataloader_without_random_cropping(
        config, DL_DATA_DIR, batch_size=1
    )

    model = make_model(config).to(DEVICE)
    model.load_state_dict(torch.load(config_info["weight_path"], map_location=DEVICE))
    _ = model.eval()

    # Per-channel normalization statistics, reshaped to broadcast over 5-D
    # (batch, channel, d1, d2, d3) predictions.
    means = torch.Tensor(config["data"]["means"])[None, :, None, None, None]
    scales = torch.Tensor(config["data"]["stds"])[None, :, None, None, None]
    test_file_paths = test_loader.dataset.lr_files

    # Pairing paths with batches via zip assumes the loader yields samples in
    # `lr_files` order (no shuffling). NOTE(review): confirm in src.dataloader.
    assert len(test_file_paths) == len(test_loader)

    for path, (Xs, bs, ys) in tqdm(
        zip(test_file_paths, test_loader), total=len(test_loader)
    ):
        out_file_name = os.path.basename(path).replace("LR", "SR")
        out_file_path = str(inference_dir / out_file_name)
        # np.save appends ".npy" to names lacking it. Use that same name here,
        # otherwise the resume check below never finds the written file.
        if not out_file_path.endswith(".npy"):
            out_file_path += ".npy"

        if os.path.exists(out_file_path):
            continue

        bs = bs.unsqueeze(1)  # add channel dim
        # Low-resolution inputs are 4x coarser than the targets on each of the
        # three trailing axes.
        assert Xs.shape[1:] == (4, 8, 80, 80)
        assert bs.shape[1:] == (1, 32, 320, 320)
        assert ys.shape[1:] == (4, 32, 320, 320)

        with torch.no_grad():
            preds = model(Xs.to(DEVICE), bs.to(DEVICE)).cpu()
            preds = dimensionalize(preds, means, scales)

        np.save(out_file_path, preds.numpy())

    # Free this config's model and loader before loading the next one.
    del model, test_loader, test_file_paths
    _ = gc.collect()
    if DEVICE == "cuda":
        # Return cached blocks held by the freed model to the CUDA driver.
        torch.cuda.empty_cache()