In [None]:
# Enable autoreload so that edits to imported project modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Import libraries

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Configure the root logger to emit INFO-and-above records.
# A stdout handler is attached only when no handler exists yet, so re-running
# this cell does not produce duplicated log lines.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    logger.addHandler(StreamHandler(sys.stdout))

In [None]:
import io
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from PIL import Image
from src.four_dim_srda.config.experiment_config import CFDConfig
from src.qg_model.qg_model import QGModel
from tqdm.notebook import tqdm

# Use a serif font family for every matplotlib figure created in this notebook.
plt.rcParams.update({"font.family": "serif"})

# Define constants

In [None]:
# Derive the repository root from PYTHONPATH.
# PYTHONPATH may hold several directories joined by os.pathsep; feeding the raw
# value to pathlib.Path would then produce a meaningless parent directory, so
# only the first non-empty entry is used. A single-entry PYTHONPATH behaves
# exactly as before.
# NOTE(review): assumes that first entry is the project's source directory and
# that its parent is the repository root — confirm against the container setup.
if "PYTHONPATH" not in os.environ:
    raise KeyError(
        "PYTHONPATH is not set; it must point to the project's source directory "
        "so that ROOT_DIR can be derived from it."
    )
_pythonpath_entries = [
    entry for entry in os.environ["PYTHONPATH"].split(os.pathsep) if entry
]
ROOT_DIR = pathlib.Path(
    _pythonpath_entries[0] if _pythonpath_entries else ""
).parent.resolve()

In [None]:
# Experiment identifier; it selects both the config directory and the data
# directory that the following cells read from.
experiment_name = "experiment7"

In [None]:
# Locate and load the CFD simulation config of the selected experiment.
CFG_DIR = f"{ROOT_DIR}/python/configs/four_dim_srda/{experiment_name}"
CFG_CFD_PATH = f"{CFG_DIR}/cfd_simulation/qg_model/gpu_evaluation_config.yml"
cfg_cfd = CFDConfig.load(pathlib.Path(CFG_CFD_PATH))

# Override the device of each base config (LR, HR, UHR) so everything runs on CPU.
DEVICE_CPU = "cpu"
cfg_cfd.lr_base_config.device = DEVICE_CPU
cfg_cfd.hr_base_config.device = DEVICE_CPU
cfg_cfd.uhr_base_config.device = DEVICE_CPU

In [None]:
# Input data directory of the experiment and output directory for the animation.
DATA_DIR = f"{ROOT_DIR}/data/four_dim_srda/{experiment_name}/cfd_simulation/qg_model"
FIG_DIR = f"{DATA_DIR}/animation/hr_pv"

# Create the output directory (and missing parents); no error if it already exists.
pathlib.Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Select which simulation output to visualise: seed index and batch index.
i_seed, i_b = 0, 0

# Load the stored HR potential-vorticity array for that seed / batch.
_hr_pv_path = f"{DATA_DIR}/hr_pv/seed{i_seed:05}/seed{i_seed:05}_start000_end800_hr_pv_{i_b:02}.npy"
hr_pv = np.load(_hr_pv_path)

hr_pv.shape

In [None]:
# Build the time axis of the HR output: one sample every output_hr_dt, starting
# at start_time and covering the interval up to end_time.
#
# np.arange derives its length from ceil((stop - start) / step); with a
# floating-point step, round-off in that quotient can add or drop the last
# sample, which would leave `time` out of sync with the loaded snapshots.
# The sample count is therefore computed explicitly, with the quotient rounded
# to strip float noise, and the axis is generated from integer indices.
# For exactly representable inputs the result matches the previous np.arange call.
_t_start = cfg_cfd.time_config.start_time
_t_end = cfg_cfd.time_config.end_time
_t_step = cfg_cfd.time_config.output_hr_dt

_n_times = int(np.ceil(np.round((_t_end - _t_start) / _t_step, 9))) + 1
time = _t_start + _t_step * np.arange(_n_times)
time.shape

In [None]:
# Keep only every t_slice-th snapshot (and its time stamp) to limit the number
# of animation frames.
# NOTE: hr_pv and time are overwritten here, so running this cell twice
# subsamples them twice; re-run the loading cells before re-running this one.
t_slice = 8
_every_t_slice = slice(None, None, t_slice)

hr_pv, time = hr_pv[_every_t_slice], time[_every_t_slice]

hr_pv.shape, time.shape

In [None]:
# Instantiate the QG model from the HR base config; in this notebook it is only
# used to obtain the coordinate grids.
model = QGModel(cfg_cfd.hr_base_config, show_input_cfg_info=False)
# get_grids() returns three objects, unpacked as the x, y and z grids.
# NOTE(review): presumably torch tensors (they are converted with .numpy()
# in the settings cell) — confirm.
xs, ys, zs = model.get_grids()
xs.shape

# Make 3D animation

## Settings

In [None]:
# Data-related
# Stride applied to the third and fourth axes of hr_pv and to the last two axes
# of the grids, to thin out the points handed to plotly.
val_slice = 3

lower_percentile = 20
upper_percentile = 80

# Subsample once and reuse the result instead of re-slicing hr_pv for every step.
_hr_pv_sub = hr_pv[:, :, ::val_slice, ::val_slice]

# Retrieve clipping range from hr_pv (both percentiles in a single pass)
clip_min, clip_max = np.percentile(
    _hr_pv_sub, [lower_percentile, upper_percentile]
)

# Apply clipping
clipped_data = np.clip(_hr_pv_sub, clip_min, clip_max)

# Scaling after clipping: map [min, max] of the clipped data linearly onto [-1, 1].
_clipped_min = np.min(clipped_data)
_clipped_range = np.max(clipped_data) - _clipped_min
if _clipped_range == 0:
    # A constant field has no range to normalise; place it at the centre of the
    # scale instead of dividing by zero and filling the array with NaN.
    scaled_data = np.zeros_like(clipped_data)
else:
    scaled_data = 2 * (clipped_data - _clipped_min) / _clipped_range - 1

# Flattened coordinates of the subsampled grid, in the same order as the
# flattened per-snapshot values passed to go.Volume.
data = scaled_data.flatten()
x = xs[:, ::val_slice, ::val_slice].numpy().flatten()
y = ys[:, ::val_slice, ::val_slice].numpy().flatten()
z = zs[:, ::val_slice, ::val_slice].numpy().flatten()

# Symmetric value range around zero.
# NOTE(review): vmin / vmax are not referenced by the plotting cells visible
# in this notebook; kept in case other code relies on them.
vmin = np.min(data)
vmax = np.max(data)
_max = min(abs(vmin), abs(vmax))
vmin, vmax = -_max, _max

# Diverging blue -> cream -> red colour stops as (position in [0, 1], colour).
# The Volume traces in this notebook set reversescale=True, so the scale is
# applied flipped when drawn (the 0.0 stop ends up at the high end of the range).
_COLOR_STOPS = (
    (0.0, "rgb(30, 60, 150)"),  # Deep blue
    (0.3, "rgb(70, 130, 220)"),  # Medium blue
    (0.4, "rgb(170, 210, 240)"),  # Light blue
    (0.5, "rgb(255, 245, 200)"),  # Soft cream (center)
    (0.6, "rgb(255, 220, 160)"),  # Light orange
    (0.8, "rgb(220, 90, 90)"),  # Vivid red
    (1.0, "rgb(120, 30, 30)"),  # Deep red
)

# plotly expects a list of [position, colour] pairs.
colorscale = [[position, rgb] for position, rgb in _COLOR_STOPS]

# Layout
def _hidden_tick_axis(label):
    """Return 3D-scene axis settings: tick labels hidden, Times New Roman title.

    The title font is given through ``title.font``; the older ``titlefont``
    shortcut is deprecated and rejected by newer plotly releases. This is the
    same nesting that the ``*_title_font_family`` keys of
    ``axeslabel_3d_standard`` expand to.

    Args:
        label: Text shown as the axis title.

    Returns:
        dict of axis properties usable as ``scene.xaxis`` / ``yaxis`` / ``zaxis``.
    """
    return dict(
        showticklabels=False,  # Hide tick labels on this axis
        title=dict(
            text=label,  # Label for the axis
            font=dict(family="Times New Roman"),
        ),
    )


layout = go.Layout(
    width=512,
    height=300,
    margin_b=0,
    margin_t=0,
    margin_r=0,
    margin_l=0,
    font_size=16,
    scene=dict(
        aspectmode="manual",  # Set the aspect ratio mode
        aspectratio=dict(x=2, y=1, z=0.8),  # Specify the ratio for each axis
        xaxis=_hidden_tick_axis("x"),
        yaxis=_hidden_tick_axis("y"),
        zaxis=_hidden_tick_axis("z"),
    ),
)

# Axis titles and title fonts for the 3D scene, written with plotly's
# underscore ("magic") property paths; passed as `scene=` to update_layout.
axeslabel_3d_standard = {
    "xaxis_title": "x",
    "xaxis_title_font_family": "Times New Roman",
    "yaxis_title": "y",
    "yaxis_title_font_family": "Times New Roman",
    "zaxis_title": "z",
    "zaxis_title_font_family": "Times New Roman",
}

# Camera placement for the 3D scene (passed as `scene_camera=` to update_layout):
# `up` is the upward direction, `center` the point looked at, `eye` the viewpoint.
camera_3d_standard = {
    "up": {"x": 0, "y": 0, "z": 1},
    "center": {"x": 0, "y": 0, "z": -0.18},
    "eye": {"x": -2.0, "y": -1.15, "z": 0.8},
}

## Plot

In [None]:
# Volume trace for the first retained snapshot. Values were scaled to [-1, 1],
# so the iso-surface range spans the whole data range while the colour range
# (cmin / cmax) is slightly wider.
_first_volume = go.Volume(
    x=x,
    y=y,
    z=z,
    value=scaled_data[0].flatten(),
    isomin=-1.0,
    isomax=1.0,
    opacity=0.5,
    surface_count=15,
    colorbar=dict(thickness=5, x=0.8),
    colorscale=colorscale,
    cmin=-1.1,
    cmax=1.1,
    reversescale=True,
)

# Figure title showing the time stamp of the snapshot.
_first_title = dict(
    text=f"Time = {time[0]:.1f}",
    font=dict(size=16),
    xref="paper",
    x=0.45,
    y=0.92,
    xanchor="center",
)

fig = go.Figure(data=_first_volume)

# Apply the shared layout, then the axis labels, camera and title on top of it.
fig.update_layout(
    layout,
    scene=axeslabel_3d_standard,
    scene_camera=camera_3d_standard,
    title=_first_title,
)

fig.show()

## Animation

In [None]:
# This code needs kaleido (used by fig.to_image for static image export).
# This singularity container doesn't include kaleido.
# Therefore, please run this code in an environment where kaleido exists.


def _make_pv_volume_figure(it):
    """Build the volume-rendering figure for the snapshot at time index ``it``.

    Reads the notebook-level ``x``, ``y``, ``z``, ``scaled_data``, ``time``,
    ``colorscale``, ``layout``, ``axeslabel_3d_standard`` and
    ``camera_3d_standard``.

    Args:
        it: Index along the first axis of ``scaled_data`` / ``time``.

    Returns:
        plotly ``go.Figure`` containing one ``go.Volume`` trace and the title
        ``"Time = <time[it]>"``.
    """
    figure = go.Figure(
        data=go.Volume(
            x=x,
            y=y,
            z=z,
            value=scaled_data[it].flatten(),
            isomin=-1.0,
            isomax=1.0,
            opacity=0.5,
            surface_count=15,
            colorbar=dict(
                thickness=5,
                x=0.8,
            ),
            colorscale=colorscale,
            cmin=-1.1,
            cmax=1.1,
            reversescale=True,
        )
    )

    # Update layout with additional settings
    figure.update_layout(
        layout,
        scene=axeslabel_3d_standard,
        scene_camera=camera_3d_standard,
        title=dict(
            text=f"Time = {time[it]:.1f}",
            font=dict(size=16),
            xref="paper",
            x=0.45,
            y=0.92,
            xanchor="center",
        ),
    )
    return figure


# Render only indices that exist in BOTH arrays. Looping over len(time) alone
# would raise IndexError on scaled_data at the very end of a long render if the
# time axis were longer than the data.
n_frames = min(len(time), len(scaled_data))
if len(time) != len(scaled_data):
    logger.warning(
        "time has %d entries but scaled_data has %d snapshots; rendering %d frames.",
        len(time),
        len(scaled_data),
        n_frames,
    )

frames = []
for it in tqdm(range(n_frames)):
    fig = _make_pv_volume_figure(it)

    # Export the figure as PNG bytes and wrap them in a PIL image.
    img = fig.to_image(format="png")
    frames.append(Image.open(io.BytesIO(img)))

if not frames:
    raise ValueError("No frames were rendered; cannot write the animation GIF.")

# Write all frames into one endlessly looping GIF (150 ms per frame).
frames[0].save(
    f"{FIG_DIR}/pv_3d_animation.gif",
    save_all=True,
    append_images=frames[1:],
    duration=150,
    loop=0,
)