In [None]:
from einops import rearrange
import yaml
import os
from pathlib import Path

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
import torch.nn.functional as F

from safetensors import safe_open

from gphyt.model.transformer.model import get_model as get_gphyt_model
from gphyt.model.unet import get_model as get_unet_model
from gphyt.data.phys_dataset import PhysicsDataset as GPhyTDataset

from scOT.model import ScOT, ScOTConfig
from scOT.problems.well_ds import PhysicsDataset as PoseidonDataset

from dpot.models.dpot import DPOTNet
from dpot.well_ds import PhysicsDataset as DPOTDataset

# Expose only GPU index 1 to this process. CUDA reads this variable when it
# initialises, so it must be set before the first CUDA call in this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

def load_yaml(file_path: Path):
    """Parse the YAML file at ``file_path`` and return the loaded object.

    The file is decoded explicitly as UTF-8 so the result does not depend on the
    platform's locale encoding.

    NOTE(review): ``yaml.FullLoader`` is less restrictive than ``safe_load``; it is
    only intended for trusted, locally written config files such as the model
    configs loaded in this notebook.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.load(file, Loader=yaml.FullLoader)

# Root folder containing one sub-directory per trained model (checkpoints/configs).
base_path = Path("/hpcwork/rwth1802/coding/General-Physics-Transformer") / "results"

def unet_model(path: Path) -> torch.nn.Module:
    """Build the medium-size UNet baseline and load ``path / "best_model.pth"``.

    The checkpoint's ``"model_state_dict"`` entry is used after stripping the
    ``module.`` and ``_orig_mod.`` key prefixes (as left behind by
    DistributedDataParallel / torch.compile wrappers); loading is strict.
    The model is returned on CPU (the checkpoint is mapped to CPU).
    """
    net = get_unet_model({"model_size": "UNet_M"})

    checkpoint = torch.load(path / "best_model.pth", map_location='cpu')
    state_dict = checkpoint["model_state_dict"]
    for prefix in ("module.", "_orig_mod."):
        consume_prefix_in_state_dict_if_present(state_dict, prefix)

    net.load_state_dict(state_dict, strict=True)
    return net

def gphyt_model(path: Path) -> torch.nn.Module:
    """Build a GPhyT transformer from ``path / "config_eval.yaml"`` and load its weights.

    The architecture comes from the ``"model"`` section of the YAML config; the
    weights come from the ``"model_state_dict"`` entry of ``path / "best_model.pth"``
    with the ``module.`` / ``_orig_mod.`` key prefixes stripped. Loading is strict.
    """
    model_section = load_yaml(path / "config_eval.yaml")["model"]
    net = get_gphyt_model(model_section)

    checkpoint = torch.load(path / "best_model.pth", map_location='cpu')
    state_dict = checkpoint["model_state_dict"]
    for prefix in ("module.", "_orig_mod."):
        consume_prefix_in_state_dict_if_present(state_dict, prefix)

    net.load_state_dict(state_dict, strict=True)
    return net

def poseidon_model(path: Path) -> ScOT:
    """Build the Poseidon (ScOT) model and load weights from ``path / "model.safetensors"``.

    The architecture is fixed here: 128x128 inputs with 5 input/output channels,
    patch size 4, embedding dim 96, four stages of depth 8, and a ConvNeXt residual
    model with conditioning enabled. All tensors in the safetensors file are read
    on CPU. The ``module.`` / ``_orig_mod.`` key prefixes are stripped, and the
    weights are loaded strictly.
    """
    config = ScOTConfig(
        image_size=128,
        patch_size=4,
        num_channels=5,
        num_out_channels=5,
        embed_dim=96,
        depths=[8, 8, 8, 8],
        num_heads=[3, 6, 12, 24],
        skip_connections=[2, 2, 2, 0],
        window_size=16,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,  # default
        attention_probs_dropout_prob=0.0,  # default
        drop_path_rate=0.0,
        hidden_act="gelu",
        use_absolute_embeddings=False,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        p=1,
        channel_slice_list_normalized_loss=None,
        residual_model="convnext",
        use_conditioning=True,
        learn_residual=False,
    )
    model = ScOT(config)

    with safe_open(path / "model.safetensors", framework="pt", device="cpu") as handle:
        weights = {key: handle.get_tensor(key) for key in handle.keys()}

    for prefix in ("module.", "_orig_mod."):
        consume_prefix_in_state_dict_if_present(weights, prefix)
    model.load_state_dict(weights, strict=True)
    return model

def dpot_model(
    path: Path,
    config_path: Path = Path("/hpcwork/rwth1802/coding/DPOT/configs/eval_medium.yaml"),
    checkpoint_name: str = "model_6.pth",
) -> torch.nn.Module:
    """Build a DPOT network from a YAML config and load a checkpoint from ``path``.

    Args:
        path: Directory containing the checkpoint file.
        config_path: YAML file with the architecture hyper-parameters. These are
            ``res``, ``patch_size``, ``num_channels``, ``T_in``, ``normalize``,
            ``width``, ``n_layers``, ``n_blocks``, ``mlp_ratio``,
            ``out_layer_dim`` and ``act``. Defaults to the medium eval config
            that was previously hard-coded.
        checkpoint_name: Checkpoint file inside ``path``. Its ``"model"`` entry
            holds the state dict. Defaults to ``"model_6.pth"``.

    Returns:
        The DPOT network with weights loaded strictly. It is not moved to any
        device here; the checkpoint is mapped to CPU.
    """
    config = load_yaml(config_path)
    model = DPOTNet(
        img_size=config["res"],
        patch_size=config["patch_size"],
        in_channels=config["num_channels"],
        in_timesteps=config["T_in"],
        out_timesteps=1,  # one-step predictor, rolled out autoregressively
        out_channels=config["num_channels"],
        normalize=config["normalize"],
        embed_dim=config["width"],
        depth=config["n_layers"],
        n_blocks=config["n_blocks"],
        mlp_ratio=config["mlp_ratio"],
        out_layer_dim=config["out_layer_dim"],
        act=config["act"],
        n_cls=12,
    )

    # weights_only=False unpickles arbitrary Python objects: only use this with
    # trusted, locally produced checkpoints.
    data = torch.load(path / checkpoint_name, map_location="cpu", weights_only=False)
    state_dict = data["model"]
    for prefix in ("module.", "_orig_mod."):
        consume_prefix_in_state_dict_if_present(state_dict, prefix)
    model.load_state_dict(state_dict, strict=True)
    return model

@torch.inference_mode()
def gphyt_forward(model, sample, device) -> torch.Tensor:
    """Autoregressively roll out ``model`` for as many steps as ``sample`` has targets.

    Args:
        model: Callable mapping a batched input window to a prediction. The
            prediction is concatenated to the window along dim 1 (time).
        sample: ``(xx, target)`` pair. ``xx`` is the initial input window (its
            leading dim is time). Only the length of ``target``'s leading (time)
            dim is used, to set the number of rollout steps.
        device: Device on which the rollout runs.

    Returns:
        All predictions concatenated along time, with the batch dim removed, on
        CPU (T, H, W, C per the dataset layout used here).
    """
    xx, target = sample

    # Only the number of target steps is needed, so the ground truth is not
    # copied to the device.
    ar_steps = target.shape[0]

    x = xx.to(device).unsqueeze(0)  # add batch dim
    predictions = []
    for step in range(ar_steps):
        if step > 0:
            # Slide the window: drop the oldest input step, append the last prediction.
            x = torch.cat((x[:, 1:, ...], predictions[-1]), dim=1)
        predictions.append(model(x))

    return torch.cat(predictions, dim=1).squeeze(0).cpu()  # T, H, W, C

@torch.inference_mode()
def poseidon_forward(model, sample, device) -> torch.Tensor:
    """Autoregressively roll out a Poseidon/ScOT model on one sample.

    The initial state is bilinearly resampled to 128x128. After that, the model's
    own 128x128 outputs are fed back as inputs. Each output is resampled back to
    the sample's native spatial size before it is stored. The same ``time``
    conditioning is passed at every step.

    Args:
        model: Callable taking ``pixel_values=`` and ``time=`` keyword arguments
            and returning an object with an ``.output`` tensor (B, C, H, W).
        sample: Dict with ``"pixel_values"`` (C, H, W), ``"labels"`` (T, C, H, W)
            and ``"time"``. Only the length of the labels' leading (time) dim
            is used.
        device: Device on which the rollout runs.

    Returns:
        Predictions stacked along time, on CPU, shaped (T, C, H, W).
    """
    xx = sample["pixel_values"].to(device).unsqueeze(0)  # (1, C, H, W)
    times = sample["time"].to(device).unsqueeze(0)  # add batch dim
    # Only the number of target steps is needed; the labels stay where they are.
    ar_steps = sample["labels"].shape[0]
    native_size = xx.shape[-2:]

    x = F.interpolate(xx, size=(128, 128), mode="bilinear", align_corners=False)
    predictions = []
    for _ in range(ar_steps):
        x = model(pixel_values=x, time=times).output  # (B, C, 128, 128)
        predictions.append(
            F.interpolate(x, size=native_size, mode="bilinear", align_corners=False)
        )

    return torch.stack(predictions, dim=1).squeeze(0).cpu()  # T, C, H, W


@torch.inference_mode()
def dpot_forward(model, sample, device) -> torch.Tensor:
    """Autoregressively roll out a DPOT model on one sample.

    The input window is bilinearly resampled to 128x128. At each step the oldest
    time step is dropped and the newest prediction is appended (time is dim -2).
    All predictions are then resampled back to the sample's native H x W.

    Args:
        model: Callable returning ``(output, aux)``. ``output`` is one predicted
            step shaped (B, H, W, 1, C).
        sample: ``(xx, target, _)``, where ``xx`` and ``target`` are (H, W, T, C).
            Only the length of ``target``'s time dim is used.
        device: Device on which the rollout runs.

    Returns:
        Predictions on CPU, shaped (H, W, T, C).
    """
    xx, target, _ = sample  # h, w, t, c
    # Only the number of target steps is needed; the target is not copied to the device.
    ar_steps = target.shape[-2]

    xx = xx.to(device).unsqueeze(0)  # (1, H, W, T_in, C)
    height, width = xx.shape[1], xx.shape[2]

    # Resample the input window once to the model's 128x128 grid.
    x = rearrange(xx, "B H W T C -> (B T) C H W")  # (B*T, C, H, W)
    x = F.interpolate(x, size=(128, 128), mode="bilinear", align_corners=False)
    x = rearrange(x, "(B T) C H W -> B H W T C", B=1)  # (B, 128, 128, T, C)

    predictions = []
    for step in range(ar_steps):
        if step > 0:
            # Slide the window: drop the oldest step, append the last prediction.
            x = torch.cat((x[..., 1:, :], predictions[-1]), dim=-2)
        output, _ = model(x)  # (B, H, W, 1, C)
        predictions.append(output)

    predictions = torch.cat(predictions, dim=-2)  # concat along time dimension
    # Reverse the input interpolation back to the native resolution.
    predictions = rearrange(predictions, "B H W T C -> (B T) C H W")  # (B*T, C, H, W)
    predictions = F.interpolate(
        predictions,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )
    predictions = rearrange(predictions, "(B T) C H W -> B H W T C", B=1)  # (B, H, W, T, C)

    return predictions.squeeze(0).cpu()  # H, W, T, C

In [None]:
# Evaluation settings shared by every model below.
data_dir = Path("/hpcwork/rwth1802/coding/General-Physics-Transformer/data/datasets")
name = "euler_multi_quadrants_periodicBC"  # dataset folder under data_dir
sample_idx = 0  # index of the sample drawn from each test dataset
ar_steps = 24  # number of autoregressive rollout steps
stride = 1  # passed to every dataset as dt_stride
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
%%capture
# UNet baseline: load the checkpoint, switch to eval mode, move to the device.
unet = unet_model(base_path / "unet-m-04")
unet = unet.eval().to(device)

In [None]:
# GPhyT: load the checkpoint (eval mode, on device), then its test dataset.
gphyt = gphyt_model(base_path / "m-main-03").eval().to(device)

dataset_gphyt = GPhyTDataset(
    data_dir / name / "data/test",
    n_steps_input=4,
    n_steps_output=ar_steps,
    dt_stride=stride,
)

In [None]:
# Poseidon: load the checkpoint (eval mode, on device), then its test dataset.
poseidon = poseidon_model(base_path / "poseidon").eval().to(device)

dataset_poseidon = PoseidonDataset(
    data_dir / name / "data/test",
    n_output_steps=ar_steps,
    dt_stride=stride,
    train=False,
)

In [None]:
%%capture
# DPOT: load the checkpoint (eval mode, on device), then its normalised test dataset.
dpot = dpot_model(base_path / "dpot")
dpot = dpot.eval().to(device)

dataset_dpot = DPOTDataset(
    data_dir / name / "data/test",
    T_in=4,
    T_out=ar_steps,
    dt_stride=stride,
    train=False,
    use_normalization=True,
)

In [None]:
# Take the same index from each model's own test dataset.
# NOTE(review): assumes all three datasets enumerate samples in the same order — TODO confirm.
sample_gphyt, sample_poseidon, sample_dpot = (
    dataset[sample_idx] for dataset in (dataset_gphyt, dataset_poseidon, dataset_dpot)
)
# sample_dpot unpacks as: xx (H, W, T, C), target, _

In [None]:
# Autoregressive rollouts, each in its model's native layout.
# The UNet consumes the GPhyT sample format, so it reuses gphyt_forward.
pred_unet, pred_gphyt = (
    gphyt_forward(net, sample_gphyt, device) for net in (unet, gphyt)
)  # both T, H, W, C
pred_poseidon = poseidon_forward(poseidon, sample_poseidon, device)  # T, C, H, W
pred_dpot = dpot_forward(dpot, sample_dpot, device)  # H, W, T, C

In [None]:
# Bring every prediction into the ground truth's (T, H, W, C) layout.
pred_poseidon = rearrange(pred_poseidon, "T C H W -> T H W C")
pred_dpot = rearrange(pred_dpot, "H W T C -> T H W C")
# NOTE(review): dataset_dpot was built with use_normalization=True; confirm that
# pred_dpot is on the same scale as this ground truth before comparing them.
ground_truth = sample_gphyt[1]  # T, H, W, C

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def _save_colormapped(values: np.ndarray, vmin: float, scale: float, cmap, path: Path) -> None:
    """Colour a 2-D field with ``cmap`` over the range [vmin, vmin + scale] and save it as a PNG."""
    normalized = (values - vmin) / scale
    rgb = (cmap(normalized)[..., :3] * 255).astype(np.uint8)  # drop alpha, 8-bit RGB
    Image.fromarray(rgb).save(path)


def visualize_rollout(
    gt: torch.Tensor,
    pred: torch.Tensor,
    save_path: Path,
) -> None:
    """Save per-frame PNG images of a predicted rollout and its ground truth.

    Both tensors are expected in (T, H, W, C) layout. The channel order is
    assumed to be pressure, density, temperature, velocity_x, velocity_y, as
    listed in ``field_names`` below (TODO confirm against the dataset). A
    velocity-magnitude field, the norm of the last two channels, is appended.

    For each field, prediction and ground truth share one colour scale (the
    joint NaN-ignoring min/max over all frames), so their images are directly
    comparable. If a field has zero range, every pixel gets the lowest colour
    of the colormap instead of producing NaNs.

    Files are written as ``{field}_pred_t{t}.png`` and ``{field}_gt_t{t}.png``
    inside ``save_path``, which is created if needed.

    Args:
        gt: Ground-truth rollout, (T, H, W, C).
        pred: Predicted rollout, (T, H, W, C).
        save_path: Output directory.
    """
    # Swap the spatial axes, (T, H, W, C) -> (T, W, H, C), for the saved images.
    predictions = pred.cpu().numpy().transpose(0, 2, 1, 3)
    ground_truth = gt.cpu().numpy().transpose(0, 2, 1, 3)

    # Append the velocity magnitude (norm of the last two channels) as an extra channel.
    predictions = np.concatenate(
        [predictions, np.linalg.norm(predictions[..., -2:], axis=-1)[..., None]], axis=-1
    )
    ground_truth = np.concatenate(
        [ground_truth, np.linalg.norm(ground_truth[..., -2:], axis=-1)[..., None]], axis=-1
    )

    # Field names and colormaps, in channel order.
    field_names = [
        ("pressure", "inferno"),
        ("density", "viridis"),
        ("temperature", "magma"),
        ("velocity_x", "viridis"),
        ("velocity_y", "viridis"),
        ("velocity_mag", "viridis"),
    ]

    save_path.mkdir(parents=True, exist_ok=True)

    for i, (field, cmap_name) in enumerate(field_names):
        # One colour scale per field, shared by prediction and ground truth.
        vmin = min(np.nanmin(predictions[..., i]), np.nanmin(ground_truth[..., i]))
        vmax = max(np.nanmax(predictions[..., i]), np.nanmax(ground_truth[..., i]))
        scale = vmax - vmin
        if not scale > 0:
            # Zero (or NaN) range: avoid 0/0 -> NaN images.
            scale = 1.0
        cmap = plt.get_cmap(cmap_name)  # looked up once per field, not per frame

        for t in range(predictions.shape[0]):
            _save_colormapped(
                predictions[t, ..., i], vmin, scale, cmap, save_path / f"{field}_pred_t{t}.png"
            )
            _save_colormapped(
                ground_truth[t, ..., i], vmin, scale, cmap, save_path / f"{field}_gt_t{t}.png"
            )


img_path = Path(
    "/hpcwork/rwth1802/coding/General-Physics-Transformer/results/01_new_plots/visualizations"
)

# One output folder per model; the ground-truth frames are saved next to each model's frames.
rollouts = {
    "gphyt": pred_gphyt,
    "poseidon": pred_poseidon,
    "dpot": pred_dpot,
    "unet": pred_unet,
}
for model_name, prediction in rollouts.items():
    visualize_rollout(
        ground_truth,
        prediction,
        save_path=img_path / name / model_name,
    )