In [None]:
import torch
from bubbleformer.models import get_model
from bubbleformer.data import BubblemlForecast

In [None]:
# Evaluation file for the rollout (file name suggests the Twall-92 saturated
# pool-boiling case).
test_path = ["/share/crsp/lab/ai4ts/share/BubbleML_f32/PoolBoiling-Saturated-FC72-2D-0.1/Twall-92.hdf5"]
dataset_kwargs = dict(
    fields=["dfun", "temperature", "velx", "vely"],
    # No normalization at construction; `test_dataset.normalize(...)` is called
    # explicitly with the checkpoint's constants before the rollout.
    norm="none",
    time_window=5,
    start_time=95,
)
test_dataset = BubblemlForecast(filenames=test_path, **dataset_kwargs)

In [None]:
# Architecture hyperparameters. These presumably have to match the training
# run that produced the checkpoint, otherwise the strict state-dict load fails.
model_name = "unet_modern"
model_kwargs = dict(
    hidden_channels=32,
    ch_mults=[1, 2, 2, 4, 4],
    norm=True,
)
model = get_model(model_name, **model_kwargs)

In [None]:
from collections import OrderedDict

weights_path = "/pub/sheikhh1/bubbleformer_logs/unet_modern_poolboiling_saturated_36055598/hpc_ckpt_8.ckpt"
# NOTE(review): weights_only=False fully unpickles the checkpoint (presumably
# required because it stores non-tensor hyper_parameters). Only use this on
# checkpoints from a trusted source.
# map_location="cpu": nothing in this notebook moves the model to a GPU, so load
# every tensor onto the CPU. This also lets a GPU-saved checkpoint load on a
# CPU-only machine.
model_data = torch.load(weights_path, map_location="cpu", weights_only=False)
print(model_data.keys())

# Normalization constants stored with the training run. They are used to
# normalize the test data for the rollout. as_tensor avoids a copy/warning
# if they were saved as tensors already.
diff_term, div_term = model_data["hyper_parameters"]["normalization_constants"]
diff_term = torch.as_tensor(diff_term)
div_term = torch.as_tensor(div_term)

# Checkpoint keys carry the wrapper module's attribute prefix (presumably
# "model."). Drop everything up to the first "." rather than a fixed 6
# characters, so the result does not silently depend on the prefix length.
# Keys without a "." are kept unchanged.
weight_state_dict = OrderedDict(
    (key.split(".", 1)[-1], val) for key, val in model_data["state_dict"].items()
)
del model_data  # free the rest of the checkpoint (optimizer state, etc.)

In [None]:
# Strict load (the default): raises if any parameter name or shape differs.
model.load_state_dict(state_dict=weight_state_dict, strict=True)

In [None]:
from bubbleformer.utils.losses import LpLoss

# Apply the training-run normalization constants to the test data.
# NOTE(review): if normalize() rescales the stored fields in place, re-running
# this cell normalizes twice. Confirm, or rebuild test_dataset before a re-run.
_, _ = test_dataset.normalize(diff_term, div_term)
criterion = LpLoss(d=2, p=2, reduce_dims=[0, 1], reductions=["mean", "mean"])
model.eval()

start_time = test_dataset.start_time
skip_itrs = test_dataset.time_window
# Upper bound (exclusive) on the dataset index used for the rollout.
# Assumes test_dataset can be indexed up to rollout_itrs - skip_itrs.
# TODO confirm against len(test_dataset).
rollout_itrs = 500
model_preds = []
model_targets = []
timesteps = []
# Pure inference: no_grad avoids building an autograd graph (and keeping
# activations alive) on every forward pass.
with torch.no_grad():
    for itr in range(0, rollout_itrs, skip_itrs):
        inp, tgt = test_dataset[itr]
        print(f"Autoreg pred {itr}, inp tw [{start_time+itr}, {start_time+itr+skip_itrs}], tgt tw [{start_time+itr+skip_itrs}, {start_time+itr+2*skip_itrs}]")
        # Autoregressive feedback: after the first step, the previous prediction
        # covers exactly this step's input window, so it replaces the
        # ground-truth input.
        if model_preds:
            inp = model_preds[-1]  # T, C, H, W
        inp = inp.float().unsqueeze(0)  # add batch dimension
        pred = model(inp)
        pred = pred.squeeze(0).detach().cpu()
        tgt = tgt.detach().cpu()

        model_preds.append(pred)
        model_targets.append(tgt)
        timesteps.append(torch.arange(start_time + itr + skip_itrs, start_time + itr + 2 * skip_itrs))
        print(criterion(pred, tgt))

In [None]:
from bubbleformer.utils.plot_utils import plot_bubbleml

# Stack the per-window rollout chunks along time. The stacked tensors get new
# names so that re-running this cell does not try to re-concatenate them.
pred_seq = torch.cat(model_preds, dim=0)       # T, C, H, W
target_seq = torch.cat(model_targets, dim=0)   # T, C, H, W
# `timesteps` keeps its name because the save/plot cell reads it. The guard
# makes a re-run (when it is already a tensor) a no-op.
if isinstance(timesteps, list):
    timesteps = torch.cat(timesteps, dim=0)    # T,
num_var = pred_seq.shape[1]                    # C

# Undo the normalization: x = x_norm * div_term + diff_term, broadcast per channel.
# Assumes normalize() computed (x - diff_term) / div_term with one constant per
# channel. TODO confirm.
scale = div_term.view(1, num_var, 1, 1)
shift = diff_term.view(1, num_var, 1, 1)
preds = pred_seq * scale + shift
targets = target_seq * scale + shift



In [None]:
import os

# Output directory for this rollout's tensors and plots.
save_dir = "/pub/sheikhh1/bubbleformer_logs/unet_modern_poolboiling_saturated_36055598/epoch_187_outputs/sat_92"
os.makedirs(save_dir, exist_ok=True)

# Persist predictions, targets and timesteps so they can be reloaded without
# re-running the model, then render plots into the same directory.
rollout_outputs = {"preds": preds, "targets": targets, "timesteps": timesteps}
save_path = os.path.join(save_dir, "predictions.pt")
torch.save(rollout_outputs, save_path)
plot_bubbleml(preds, targets, timesteps, save_dir)

In [None]:
def test_eikonal_loss(phi):
    """
    phi = predicted sdf torch.Tensor(T,H,W)
    """
    dx = 1/32
    grad_x = (phi[:, :, 2:] - phi[:, :, :-2]) / (2 * dx)
    grad_y = (phi[:, 2:, :] - phi[:, :-2, :]) / (2 * dx)

    grad_x = torch.nn.functional.pad(grad_x, (1, 1), mode="replicate")
    grad_y = torch.nn.functional.pad(grad_y, (0, 0, 1, 1), mode="replicate")

    grad_magnitude = torch.sqrt(grad_x**2 + grad_y**2)
    loss_map = torch.abs(grad_magnitude - 1)
    mean_loss = torch.mean(loss_map, dim=(1, 2))
    return mean_loss