In [None]:
# Reload edited Python modules (such as the project modules imported below) automatically
# before code is executed, and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import itertools
import os
import sys
from logging import DEBUG, INFO, StreamHandler, getLogger

# Root logger configured to print to the notebook's stdout.
logger = getLogger()
# Re-executing this cell must not stack another stdout handler on the root logger
# (each record would then be printed once per handler), so an existing one is reused.
log_handler = next(
    (
        handler
        for handler in logger.handlers
        if isinstance(handler, StreamHandler) and handler.stream is sys.stdout
    ),
    None,
)
if log_handler is None:
    log_handler = StreamHandler(sys.stdout)
    logger.addHandler(log_handler)
logger.setLevel(DEBUG)

# Import libraries

In [None]:
# Repository root. Defaults to the previously hard-coded "/workspace"; set the
# ROOT_DIR environment variable to run this notebook from another location.
ROOT_DIR = os.environ.get("ROOT_DIR", "/workspace")

import matplotlib.pyplot as plt
import numpy as np
import yaml
from IPython.display import HTML, display

sys.path.append(f"{ROOT_DIR}/pytorch/src")

from utils import set_seeds
from velocity_dataloader import (
    make_velocity_dataloaders_for_barotropic_instability_spectral_nudging,
)
from vortex_dataloader import make_vortex_dataloaders_for_barotropic_instability_spectral_nudging

# Define constants

In [None]:
# Every data/config directory used here lives under the same experiment name, so it
# is defined once to keep the paths from drifting apart.
EXPERIMENT_NAME = "barotropic_instability_spectral_nudging"
EXPERIMENT_DATA_DIR = f"{ROOT_DIR}/data/pytorch/{EXPERIMENT_NAME}/DL_data"
EXPERIMENT_CONFIG_DIR = f"{ROOT_DIR}/pytorch/config/{EXPERIMENT_NAME}"

VORTICITY_DATA_DIR = f"{EXPERIMENT_DATA_DIR}/vorticity"
VELOCITY_DATA_DIR = f"{EXPERIMENT_DATA_DIR}/velocity"
CONFIG_DIR_NUDGE = f"{EXPERIMENT_CONFIG_DIR}/spectral_nudging"
CONFIG_DIR_AVERAGE = f"{EXPERIMENT_CONFIG_DIR}/average"

# Define methods

In [None]:
def read_yaml(path: str) -> dict:
    """Load a YAML file with ``yaml.safe_load`` and return the parsed document.

    The file is decoded as UTF-8 explicitly instead of with the platform's locale
    encoding, so the same config parses identically on every machine.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed document. Annotated as ``dict``; ``safe_load`` returns whatever
        the top-level YAML node is (``None`` for an empty file).
    """
    with open(path, encoding="utf-8") as file:
        return yaml.safe_load(file)


def _iter_checked_samples(dataloader, max_batches: int, max_samples_per_batch: int):
    """Yield shape-checked ``(X, y)`` sample pairs from the first batches of a dataloader.

    Only the first ``max_batches`` batches are requested from ``dataloader``, and from
    each of them only the first ``max_samples_per_batch`` samples are yielded. For every
    fetched batch the ``Xs``/``ys`` shapes are asserted equal; for every yielded sample
    the ``X``/``y`` shapes are printed and asserted equal.

    Args:
        dataloader: Iterable of ``(Xs, ys)`` batch pairs.
        max_batches: Number of batches to inspect.
        max_samples_per_batch: Number of samples to inspect from each batch.

    Yields:
        ``(X, y)`` pairs taken from ``zip(Xs, ys)``.

    Raises:
        AssertionError: If a batch pair or a sample pair has mismatched shapes.
    """
    # islice checks its limit before requesting the next item, so no batch beyond
    # `max_batches` is ever loaded.
    for Xs, ys in itertools.islice(dataloader, max_batches):
        assert Xs.shape == ys.shape
        for X, y in itertools.islice(zip(Xs, ys), max_samples_per_batch):
            print(X.shape, y.shape)
            assert X.shape == y.shape
            yield X, y


def check_vortex_data(dataloader, max_batches: int = 2, max_samples_per_batch: int = 1):
    """Print shapes of, and plot ``X`` next to ``y`` for, a few samples of ``dataloader``.

    ``set_seeds()`` is called first so that repeated calls inspect the same samples
    (presumably it fixes the RNGs used for shuffling — confirm in ``utils``).

    Args:
        dataloader: Iterable of ``(Xs, ys)`` batches; samples must support ``.numpy()``
            (presumably torch tensors) and squeeze to 2-D arrays for ``imshow``.
        max_batches: Number of batches to inspect (default 2).
        max_samples_per_batch: Number of samples to plot per batch (default 1).

    Raises:
        AssertionError: If input and target shapes differ.
    """
    set_seeds()
    for X, y in _iter_checked_samples(dataloader, max_batches, max_samples_per_batch):
        X, y = X.numpy().squeeze(), y.numpy().squeeze()
        fig, (ax_x, ax_y) = plt.subplots(1, 2)
        # Transposed so the array's first axis is drawn along the horizontal axis.
        ax_x.imshow(X.transpose())
        ax_x.set_title("X")
        ax_y.imshow(y.transpose())
        ax_y.set_title("y")
        fig.tight_layout()
        plt.show()
        # Release the figure explicitly; many figures can be created in one cell.
        plt.close(fig)


def check_velocity_data(
    dataloader,
    max_batches: int = 2,
    max_samples_per_batch: int = 1,
    channels=(0, 1),
):
    """Print shapes of, and plot ``X`` next to ``y`` per channel for, a few samples.

    One figure is drawn per inspected sample, with one row per entry of ``channels``:
    the left column shows ``X[channel]`` and the right column ``y[channel]``.
    ``set_seeds()`` is called first so that repeated calls inspect the same samples
    (presumably it fixes the RNGs used for shuffling — confirm in ``utils``).

    Args:
        dataloader: Iterable of ``(Xs, ys)`` batches; samples must support ``.numpy()``
            (presumably torch tensors) and be indexable as ``[channel, :, :]``.
        max_batches: Number of batches to inspect (default 2).
        max_samples_per_batch: Number of samples to plot per batch (default 1).
        channels: Leading-axis indices to plot (default ``(0, 1)``).

    Raises:
        AssertionError: If input and target shapes differ.
    """
    set_seeds()
    for X, y in _iter_checked_samples(dataloader, max_batches, max_samples_per_batch):
        X, y = X.numpy(), y.numpy()
        # squeeze=False keeps `axes` 2-D even when a single channel is requested.
        fig, axes = plt.subplots(len(channels), 2, squeeze=False)
        for row, channel in enumerate(channels):
            # Transposed so the array's first spatial axis is drawn horizontally.
            axes[row, 0].imshow(X[channel, :, :].transpose())
            axes[row, 0].set_title(f"X[{channel}]")
            axes[row, 1].imshow(y[channel, :, :].transpose())
            axes[row, 1].set_title(f"y[{channel}]")
        fig.tight_layout()
        plt.show()
        # Release the figure explicitly; many figures can be created in one cell.
        plt.close(fig)

# Check vorticity

## Spectral nudging

In [None]:
# Config for the vorticity checks in this section, read from the spectral-nudging config dir.
config_name = "dscms_Z.yml"
config = read_yaml("/".join((CONFIG_DIR_NUDGE, config_name)))

In [None]:
# Build the vorticity dataloaders from the config loaded above; the result is indexed
# by split name ("train" / "valid" / "test") in the next cell.
dataloaders = make_vortex_dataloaders_for_barotropic_instability_spectral_nudging(
    VORTICITY_DATA_DIR, config
)

In [None]:
# Keep the root logger at INFO while the loaders are iterated; switch this to DEBUG
# to also see DEBUG-level records emitted during iteration.
logger.setLevel(INFO)
for data_kind in ("train", "valid", "test"):
    display(HTML("<h3>{}</h3>".format(data_kind)))
    check_vortex_data(dataloaders[data_kind])

## Average

In [None]:
# Config for the vorticity checks in this section, read from the average config dir.
config_name = "dscms_Z.yml"
config = read_yaml("/".join((CONFIG_DIR_AVERAGE, config_name)))

In [None]:
# Build the vorticity dataloaders from the average-config loaded above (same factory as
# the spectral-nudging section); indexed by split name in the next cell.
dataloaders = make_vortex_dataloaders_for_barotropic_instability_spectral_nudging(
    VORTICITY_DATA_DIR, config
)

In [None]:
# Keep the root logger at INFO while the loaders are iterated; switch this to DEBUG
# to also see DEBUG-level records emitted during iteration.
logger.setLevel(INFO)
for data_kind in ("train", "valid", "test"):
    display(HTML("<h3>{}</h3>".format(data_kind)))
    check_vortex_data(dataloaders[data_kind])

# Check velocity

## Spectral nudging

In [None]:
# Config for the velocity checks in this section, read from the spectral-nudging config dir.
config_name = "dscms_V.yml"
config = read_yaml("/".join((CONFIG_DIR_NUDGE, config_name)))

In [None]:
# Build the velocity dataloaders from the config loaded above; the result is indexed
# by split name ("train" / "valid" / "test") in the next cell.
dataloaders = make_velocity_dataloaders_for_barotropic_instability_spectral_nudging(
    VELOCITY_DATA_DIR, config
)

In [None]:
# Keep the root logger at INFO while the loaders are iterated; switch this to DEBUG
# to also see DEBUG-level records emitted during iteration.
logger.setLevel(INFO)
for data_kind in ("train", "valid", "test"):
    display(HTML("<h3>{}</h3>".format(data_kind)))
    check_velocity_data(dataloaders[data_kind])

## Average

In [None]:
# Config for the velocity checks in this section, read from the average config dir.
config_name = "dscms_V.yml"
config = read_yaml("/".join((CONFIG_DIR_AVERAGE, config_name)))

In [None]:
# Build the velocity dataloaders from the average-config loaded above (same factory as
# the spectral-nudging section); indexed by split name in the next cell.
dataloaders = make_velocity_dataloaders_for_barotropic_instability_spectral_nudging(
    VELOCITY_DATA_DIR, config
)

In [None]:
# Keep the root logger at INFO while the loaders are iterated; switch this to DEBUG
# to also see DEBUG-level records emitted during iteration.
logger.setLevel(INFO)
for data_kind in ("train", "valid", "test"):
    display(HTML("<h3>{}</h3>".format(data_kind)))
    check_velocity_data(dataloaders[data_kind])