In [None]:
# Automatically reload imported modules (mode 2 = all modules) before executing code,
# so edits to project sources such as velocity_dataloader are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Route root-logger records at INFO and above to the notebook's stdout.
logger = getLogger()

# Reuse an existing stdout handler if this cell has already been run; adding a
# fresh StreamHandler on every re-run would print each log record once per run.
log_handler = next(
    (
        handler
        for handler in logger.handlers
        if isinstance(handler, StreamHandler) and handler.stream is sys.stdout
    ),
    None,
)
if log_handler is None:
    log_handler = StreamHandler(sys.stdout)
    logger.addHandler(log_handler)

logger.setLevel(INFO)

# Import libraries

In [None]:
import matplotlib.pyplot as plt
import torch
import yaml
from IPython.display import HTML, display

# Project root; every config, data and source path below is built from it.
ROOT_DIR = "/workspace"

# Make the project's PyTorch sources importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
SRC_DIR = f"{ROOT_DIR}/pytorch/src"
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
from velocity_dataloader import make_velocity_dataloaders_for_decaying_turbulence

# Plot dataset

## Subsample

In [None]:
# Build the dataloaders for the "subsample" decaying-turbulence configuration.
CONFIG_PATH = f"{ROOT_DIR}/pytorch/config/decaying_turbulence/subsample/dscms_scale09_V.yml"
DATA_DIR = f"{ROOT_DIR}/data/pytorch/decaying_turbulence/DL_data/velocity"

# Explicit encoding: open()'s default is locale-dependent.
with open(CONFIG_PATH, encoding="utf-8") as file:
    CONFIG = yaml.safe_load(file)
# safe_load returns None for an empty file; fail here rather than deep inside the loader.
assert CONFIG is not None, f"empty config file: {CONFIG_PATH}"

dict_dataloaders = make_velocity_dataloaders_for_decaying_turbulence(
    data_dir=DATA_DIR,
    config=CONFIG,
    num_workers=2,
    seed=0,
)

In [None]:
# Expected spatial grid of every snapshot.
NX, NY = 128, 128
# Size of the leading axis of each squeezed sample (presumably the velocity
# components, given the "_V" config — confirm against the dataloader).
N_CHANNELS = 2
# Number of samples from the first batch to plot for each channel.
N_SAMPLES = 1

# For each split, plot the first N_SAMPLES samples of the first batch, one
# figure per channel: X on the left, y on the right, both on y's colour range
# so the two panels are directly comparable.
for kind, loader in dict_dataloaders.items():
    display(HTML(f"<h3>{kind}</h3>"))
    for Xs, ys in loader:
        for j in range(N_CHANNELS):
            for i, (X, y) in enumerate(zip(Xs, ys)):
                if i >= N_SAMPLES:
                    break
                # Squeeze once so the shape check and the channel indexing
                # below operate on the same tensor.
                X, y = X.squeeze(), y.squeeze()
                assert X.shape == (N_CHANNELS, NX, NY)
                assert y.shape == (N_CHANNELS, NX, NY)
                X_j, y_j = X[j].numpy(), y[j].numpy()
                vmin, vmax = y_j.min(), y_j.max()

                fig, (ax_X, ax_y) = plt.subplots(1, 2)
                ax_X.imshow(X_j, vmin=vmin, vmax=vmax, interpolation=None)
                ax_X.set_title(f"X, channel {j}")
                ax_y.imshow(y_j, vmin=vmin, vmax=vmax, interpolation=None)
                ax_y.set_title(f"y, channel {j}")
                plt.show()
                plt.close(fig)
        # Only the first batch of each split is inspected.
        break

## Average

In [None]:
# Build the dataloaders for the "average" decaying-turbulence configuration.
CONFIG_PATH = f"{ROOT_DIR}/pytorch/config/decaying_turbulence/average/dscms_scale09_V.yml"
DATA_DIR = f"{ROOT_DIR}/data/pytorch/decaying_turbulence/DL_data/velocity"

# Explicit encoding: open()'s default is locale-dependent.
with open(CONFIG_PATH, encoding="utf-8") as file:
    CONFIG = yaml.safe_load(file)
# safe_load returns None for an empty file; fail here rather than deep inside the loader.
assert CONFIG is not None, f"empty config file: {CONFIG_PATH}"

dict_dataloaders = make_velocity_dataloaders_for_decaying_turbulence(
    data_dir=DATA_DIR,
    config=CONFIG,
    num_workers=2,
    seed=0,
)

In [None]:
# Expected spatial grid of every snapshot.
NX, NY = 128, 128
# Size of the leading axis of each squeezed sample (presumably the velocity
# components, given the "_V" config — confirm against the dataloader).
N_CHANNELS = 2
# Number of samples from the first batch to plot for each channel.
N_SAMPLES = 1

# For each split, plot the first N_SAMPLES samples of the first batch, one
# figure per channel: X on the left, y on the right, both on y's colour range
# so the two panels are directly comparable.
for kind, loader in dict_dataloaders.items():
    display(HTML(f"<h3>{kind}</h3>"))
    for Xs, ys in loader:
        for j in range(N_CHANNELS):
            for i, (X, y) in enumerate(zip(Xs, ys)):
                if i >= N_SAMPLES:
                    break
                # Squeeze once so the shape check and the channel indexing
                # below operate on the same tensor.
                X, y = X.squeeze(), y.squeeze()
                assert X.shape == (N_CHANNELS, NX, NY)
                assert y.shape == (N_CHANNELS, NX, NY)
                X_j, y_j = X[j].numpy(), y[j].numpy()
                vmin, vmax = y_j.min(), y_j.max()

                fig, (ax_X, ax_y) = plt.subplots(1, 2)
                ax_X.imshow(X_j, vmin=vmin, vmax=vmax, interpolation=None)
                ax_X.set_title(f"X, channel {j}")
                ax_y.imshow(y_j, vmin=vmin, vmax=vmax, interpolation=None)
                ax_y.set_title(f"y, channel {j}")
                plt.show()
                plt.close(fig)
        # Only the first batch of each split is inspected.
        break