In [None]:
# Reload edited Python modules (e.g. cfd_model, src.utils imported below) automatically
# before each cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import sys
from logging import DEBUG, INFO, WARNING, StreamHandler, getLogger

# getLogger() with no name returns the root logger.
logger = getLogger()

# Re-running this cell must not stack a second stdout handler (every record
# would then be printed twice). Existing handlers are recognised by their repr.
if all("StreamHandler" not in str(handler) for handler in logger.handlers):
    logger.addHandler(StreamHandler(sys.stdout))

logger.setLevel(INFO)

# Import libraries

In [None]:
import gc
import os
import pathlib
import sys
import time
import traceback
from logging import INFO, WARNING, FileHandler, StreamHandler, getLogger
from typing import Callable

import numpy as np
import pandas as pd
import torch
from cfd_model.cfd.periodic_channel_domain import TorchSpectralModel2D
from cfd_model.initialization.periodic_channel_jet_initializer import (
    calc_init_omega,
    calc_init_perturbation_hr_omegas,
    calc_jet_forcing,
)
from cfd_model.interpolator.torch_interpolator import interpolate
from scipy.ndimage import sobel
from src.utils import set_seeds

# Progress bars: tqdm.auto picks the Jupyter widget bar inside a notebook kernel
# (when ipywidgets is available) and the plain text bar otherwise. This avoids a
# hand-rolled sys.modules check that hard-imports tqdm.notebook and fails when
# ipywidgets is missing.
from tqdm.auto import tqdm

# cuBLAS needs a fixed workspace size to give deterministic results when PyTorch's
# deterministic algorithms are enabled on CUDA. The variable must be in place before
# the first cuBLAS call, i.e. before any CUDA work is done in this notebook.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = r":4096:8"  # to make calculations deterministic

# Define constants

In [None]:
DEVICE: str = "cpu"  # torch device for TorchSpectralModel2D; the timing cells below compare CPU and GPU runs

In [None]:
# Low-resolution (LR) grid size, time step and steps per cycle.
LR_NX = 32
LR_NY = 17
LR_DT = 5e-4
LR_NT = 500

# The high-resolution (HR) settings currently mirror the LR ones exactly.
HR_NX = LR_NX
HR_NY = LR_NY
HR_DT = LR_DT
HR_NT = LR_NT  # steps per cycle: each cycle calls time_integrate(dt=HR_DT, nt=HR_NT)

N_CYCLES = 96  # number of integration cycles
ASSIM_PERIOD = 4  # presumably cycles per assimilation window -- TODO confirm
N_ENSEMBLES = 1  # ensemble members, forwarded as `ne` to the initializers

# N_CYCLES must be an integer multiple of ASSIM_PERIOD.
assert N_CYCLES % ASSIM_PERIOD == 0

# Jet/forcing parameters forwarded to calc_jet_forcing (y0, sigma, tau0)
# and calc_init_omega (u0).
Y0 = np.pi / 2.0
SIGMA = 0.4
U0 = 3.0
TAU0 = 0.3
# Amplitude passed as `noise_amp` to calc_init_perturbation_hr_omegas.
PERTURB_NOISE = 0.0025
PERTUB_NOISE = PERTURB_NOISE  # misspelled legacy name, kept so existing code keeps working

# Coefficients passed to TorchSpectralModel2D; LR_COEFF_DIFFUSION is the LR
# counterpart of HR_COEFF_DIFFUSION.
BETA = 0.1
COEFF_LINEAR_DRAG = 1e-2
ORDER_DIFFUSION = 2
HR_COEFF_DIFFUSION = 1e-5
LR_COEFF_DIFFUSION = 1e-5

# ROOT_DIR is the parent directory of the project's PYTHONPATH entry; the seeds CSV
# is resolved relative to it. PYTHONPATH may contain several entries joined by
# os.pathsep, or start with an empty entry (e.g. ":/repo/pytorch" after
# `export PYTHONPATH=$PYTHONPATH:/repo/pytorch`), so the first non-empty entry is used
# instead of the raw variable. NOTE(review): assumes that entry is the project one.
_PYTHONPATH_FIRST_ENTRY = next(
    (entry for entry in os.environ["PYTHONPATH"].split(os.pathsep) if entry), ""
)
ROOT_DIR = str((pathlib.Path(_PYTHONPATH_FIRST_ENTRY) / "..").resolve())

# Table of simulation seeds indexed by simulation number.
DF_SEEDS = pd.read_csv(f"{ROOT_DIR}/pytorch/config/cfd_seeds/seeds01.csv").set_index(
    "SimulationNumber"
)

# Define methods

In [None]:
def make_and_initialize_hr_model(
    n_ensembles: int, seed: int, t0: float = 0.0
) -> "TorchSpectralModel2D":
    """Build the HR ``TorchSpectralModel2D`` and initialize it from a perturbed jet.

    Grid, jet and model coefficients come from the module-level constants
    (``HR_NX``, ``HR_NY``, ``Y0``, ``SIGMA``, ``TAU0``, ``U0``, ``PERTUB_NOISE``,
    ``COEFF_LINEAR_DRAG``, ``HR_COEFF_DIFFUSION``, ``ORDER_DIFFUSION``, ``BETA``,
    ``DEVICE``).

    Args:
        n_ensembles: Number of ensemble members, forwarded as ``ne`` to the
            initializer helpers.
        seed: Seed forwarded to ``calc_init_perturbation_hr_omegas``.
        t0: Initial model time passed to ``initialize``.

    Returns:
        The initialized model, after one ``calc_grid_data()`` call.
    """
    # Mute INFO records while the model is built. The caller's level is restored
    # (rather than a hard-coded INFO) in `finally`, so an exception during
    # construction does not leave logging silenced.
    previous_level = logger.level
    logger.setLevel(WARNING)
    try:
        hr_jet, hr_forcing = calc_jet_forcing(
            nx=HR_NX,
            ny=HR_NY,
            ne=n_ensembles,
            y0=Y0,
            sigma=SIGMA,
            tau0=TAU0,
        )

        hr_perturb = calc_init_perturbation_hr_omegas(
            nx=HR_NX, ny=HR_NY, ne=n_ensembles, noise_amp=PERTUB_NOISE, seed=seed
        )

        hr_omega0 = calc_init_omega(
            perturb_omega=hr_perturb,
            jet=hr_jet,
            u0=U0,
        )

        hr_model = TorchSpectralModel2D(
            nx=HR_NX,
            ny=HR_NY,
            coeff_linear_drag=COEFF_LINEAR_DRAG,
            coeff_diffusion=HR_COEFF_DIFFUSION,
            order_diffusion=ORDER_DIFFUSION,
            beta=BETA,
            device=DEVICE,
        )
        hr_model.initialize(t0=t0, omega0=hr_omega0, forcing=hr_forcing)
        hr_model.calc_grid_data()
    finally:
        logger.setLevel(previous_level)

    return hr_model

# Measure computation time

In [None]:
# Benchmark: rebuild the HR model N_TIMING_TRIALS times from the same seed and time
# N_CYCLES rounds of time_integrate + calc_grid_data on DEVICE.
N_TIMING_TRIALS = 5
seed = 42  # identical for every trial

elapsed_times_sec = []
for _ in range(N_TIMING_TRIALS):
    # Re-seed before each build so every trial starts from the same RNG state.
    set_seeds(seed, use_deterministic=True)
    hr_model = make_and_initialize_hr_model(seed=seed, n_ensembles=N_ENSEMBLES)

    # Only the integration loop is timed; model construction is excluded.
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()

    for i_cycle in tqdm(range(N_CYCLES), total=N_CYCLES):
        hr_model.time_integrate(dt=HR_DT, nt=HR_NT, hide_progress_bar=True)
        hr_model.calc_grid_data()

    end_time = time.perf_counter()
    elapsed_times_sec.append(end_time - start_time)
    logger.info(f"Total elapsed time = {end_time - start_time} sec")

    # Shape follows the configured ensemble/grid sizes instead of hard-coded numbers.
    assert hr_model.omega.shape == (N_ENSEMBLES, HR_NX, HR_NY)

    # Release the model (and any cached CUDA memory) before the next trial.
    del hr_model
    gc.collect()
    torch.cuda.empty_cache()
    _ = gc.collect()

logger.info(
    f"Mean elapsed time over {N_TIMING_TRIALS} trials = {np.mean(elapsed_times_sec)} sec"
)

In [None]:
# CPU (DEVICE = "cpu"): per-trial integration times [sec], presumably copied from the
# "Total elapsed time" log lines of the timing cell above.
cpu_elapsed_times_sec = [
    20.561907291412354,
    19.555678844451904,
    19.424622774124146,
    19.882634162902832,
    20.068907499313354,
]
# np.mean divides by the actual number of samples, so no hard-coded divisor.
cpu_mean_elapsed_sec = float(np.mean(cpu_elapsed_times_sec))
cpu_mean_elapsed_sec

In [None]:
# GPU: per-trial integration times [sec], presumably copied from the
# "Total elapsed time" log lines of the timing cell run with a CUDA DEVICE.
gpu_elapsed_times_sec = [
    26.677797079086304,
    26.47084355354309,
    26.37486433982849,
    26.40730857849121,
    26.638855934143066,
]
# np.mean divides by the actual number of samples, so no hard-coded divisor.
gpu_mean_elapsed_sec = float(np.mean(gpu_elapsed_times_sec))
gpu_mean_elapsed_sec

In [None]:
# oni01, cpu: per-trial integration times [sec] measured on host "oni01" with
# DEVICE = "cpu", presumably copied from the timing cell's log lines.
oni01_cpu_elapsed_times_sec = [
    28.967981100082397,
    28.795180082321167,
    29.1264705657959,
    28.72107458114624,
    29.051819801330566,
]
# np.mean divides by the actual number of samples, so no hard-coded divisor.
oni01_cpu_mean_elapsed_sec = float(np.mean(oni01_cpu_elapsed_times_sec))
oni01_cpu_mean_elapsed_sec