In [None]:
import os
# Runtime options for XLA/JAX are supplied through environment variables.
# They are assigned here, before `import jax` below, so they are already in
# place when JAX is first imported.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # turn off up-front GPU memory preallocation (see JAX GPU memory docs)
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'  # allocate/free device memory on demand (see JAX GPU memory docs)
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"  # NOTE(review): presumably limits GPU kernel autotuning — confirm against XLA flag docs
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"  # NOTE(review): presumably requests deterministic cuDNN kernels — confirm it is honoured by JAX
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
import pickle

# Base PRNG key for the notebook; later cells overwrite `key` with other seeds.
seed=123
key = jax.random.PRNGKey(seed)
# Every artefact (checkpoints, figures) of this run is written under ./results/<case>.
case = 'case_a1_FiLM'
case_path = os.path.join("./results/", case)
os.makedirs(case_path, exist_ok=True)

In [None]:
# Hyperparameters shared by the training cells further down.
latent_dim, num_epochs = 128, 2000
batch_size, learning_rate = 512, 1e-3

# Data Preprocessing

In [None]:
def compute_stats(dataset):
    """Compute per-column mean and standard deviation of inputs and outputs.

    The "input" (resp. "output") arrays of all samples are concatenated along
    axis 0 and reduced along that axis.

    Args:
        dataset: iterable of mappings, each providing "input" and "output" arrays.

    Returns:
        dict with keys "input_mean", "input_std", "output_mean", "output_std".
    """
    def _pooled(field):
        # Stack every sample's rows for this field into one long array.
        return np.concatenate([item[field] for item in dataset], axis=0)

    pooled_in = _pooled("input")
    pooled_out = _pooled("output")

    summary = {}
    summary["input_mean"] = pooled_in.mean(axis=0)
    summary["input_std"] = pooled_in.std(axis=0)
    summary["output_mean"] = pooled_out.mean(axis=0)
    summary["output_std"] = pooled_out.std(axis=0)
    return summary

def normalize_dataset(dataset, stats):
    """Standardise every sample with the given statistics, without touching the input.

    For each sample, "input" becomes ``(input - input_mean) / input_std`` and
    "output" becomes ``(output - output_mean) / output_std`` with its first
    column dropped (``[:, 1:]``). Both are converted to JAX arrays. All other
    keys of the sample (e.g. "image") are carried over unchanged.

    The samples are shallow-copied rather than overwritten, so calling this
    function again on the same raw dataset (for example when a notebook cell is
    re-run) gives the same result instead of normalising already-normalised data.

    Args:
        dataset: iterable of mappings with at least "input" and "output" arrays.
        stats: mapping with "input_mean", "input_std", "output_mean", "output_std".

    Returns:
        A new list of dicts holding the normalised samples.
    """
    normalized = []
    for sample in dataset:
        scaled = dict(sample)  # shallow copy: keeps the caller's sample intact
        scaled["input"] = jnp.array((sample["input"] - stats["input_mean"]) / stats["input_std"])
        scaled["output"] = jnp.array(
            ((sample["output"] - stats["output_mean"]) / stats["output_std"])[:, 1:]
        )
        normalized.append(scaled)
    return normalized

def reverse_input(arr, stats):
    """Map a standardised input array back to original units using the input mean/std in `stats`."""
    scale = stats["input_std"]
    offset = stats["input_mean"]
    return arr * scale + offset

def reverse_output(arr, stats):
    """Map a standardised output array back to original units.

    The first entry of the output mean/std is skipped (``[1:]``), matching the
    column that `normalize_dataset` drops from the outputs.
    """
    scale = stats["output_std"][1:]
    offset = stats["output_mean"][1:]
    return arr * scale + offset

def flatten_dataset(dataset_jax, batch_size=512):
    """Stack the per-sample "image", "input" and "output" arrays along a new leading axis.

    Args:
        dataset_jax: sequence of mappings with "image", "input" and "output" arrays.
        batch_size: accepted for compatibility; not used by this function.

    Returns:
        dict with keys "image", "input", "output", each a stacked JAX array.
    """
    per_field = {
        field: [sample[field] for sample in dataset_jax]
        for field in ("image", "input", "output")
    }
    return {field: jnp.stack(arrays, axis=0) for field, arrays in per_field.items()}

In [None]:
# Load the two training pickles and merge them into a single list of samples.
# NOTE(review): pickle.load executes arbitrary code from the file — only load
# dataset files from a trusted source.
train_pkl_path1 = "dataset/train_circular_inclusion1.pkl"
train_pkl_path2 = "dataset/train_circular_inclusion2.pkl"

with open(train_pkl_path1, "rb") as f:
        dataset1 = pickle.load(f)
with open(train_pkl_path2, "rb") as f:
        dataset2 = pickle.load(f)
dataset = dataset1 + dataset2

# Normalisation statistics come from the merged training data; `stats` is
# reused later to normalise the generalisation sets and to undo the scaling in plots.
stats = compute_stats(dataset)
dataset_jax = normalize_dataset(dataset, stats)
# From here on `dataset` is the stacked dict of arrays, no longer the list of samples.
dataset = flatten_dataset(dataset_jax)

# Derive sequence length and feature sizes from the first normalised sample.
first_input = dataset_jax[0]['input']
first_output = dataset_jax[0]['output']
seq_len, input_dim = first_input.shape
_, output_dim = first_output.shape
metadata = {
    "seq_len": seq_len,
    "input_dim": input_dim,
    "latent_dim":latent_dim,
    "output_dim": output_dim,
    "data_size": len(dataset_jax)
}
print("Metadata:", metadata)

In [None]:
# Train and test dataset: shuffle all sample indices once with `key`, hold out
# the first `n_test` shuffled samples for testing and train on the remainder.
num_samples = metadata['data_size']
n_test = 4000
# Fail fast instead of silently producing an empty training split.
assert num_samples > n_test, (
    f"need more than {n_test} samples to hold out a test split, got {num_samples}"
)
indices = jax.random.permutation(key, num_samples)

test_indices = indices[:n_test]
train_indices = indices[n_test:]

train_input = dataset['input'][train_indices]
train_output = dataset['output'][train_indices]
train_image = dataset['image'][train_indices]

test_input = dataset['input'][test_indices]
test_output = dataset['output'][test_indices]
test_image = dataset['image'][test_indices]

train_set = {"input": train_input, "output": train_output, "image": train_image}
test_set = {"input": test_input, "output": test_output, "image": test_image}

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
# ==============================================================
# Global plotting settings
# ==============================================================
# mpl.rcParams and plt.rcParams are the same object, so one update covers all.
mpl.rcParams.update({
    'figure.dpi': 1000,
    'axes.unicode_minus': False,
    'savefig.dpi': 1000,
    'figure.figsize': [6.0, 4.0],
    'font.size': 9.0,
})
legend_font = FontProperties(size=9)

# Training
## VAE Training

In [None]:
import os
import pickle
from typing import Dict, Any, Tuple, List
import jax
import jax.numpy as jnp
from jax import random
import optax

def train_vae(
    vae_model,
    train_set: Dict[str, jnp.ndarray],
    rng: jax.Array,
    num_epochs: int = 50,
    batch_size: int = 64,
    learning_rate: float = 1e-3,
    patience: int = 10,
    min_delta: float = 1e-4,
) -> Tuple[Dict, Dict, Dict[str, List[float]]]:
    """Train the VAE on ``train_set["image"]`` with Adam and plateau-based early stopping.

    The loss is ``recon + kl`` where ``recon`` is the batch mean of the squared
    reconstruction error summed over axes (1, 2, 3) and ``kl`` is the batch mean
    of the Gaussian KL term computed from ``mu`` and ``log_var``.

    Args:
        vae_model: module exposing ``init(rng, x, key)`` and
            ``apply(variables, x, key, mutable=..., use_running_average=...)``
            returning ``(x_recon, mu, log_var)``.
        train_set: mapping that provides the training images under "image".
        rng: PRNG key; split into independent keys for initialisation, shuffling
            and the per-step key handed to the model.
        num_epochs: maximum number of epochs.
        batch_size: mini-batch size (the last batch of an epoch may be smaller).
        learning_rate: Adam learning rate.
        patience: number of consecutive epoch-to-epoch changes inspected for early stopping.
        min_delta: training stops once all of the last ``patience`` changes of the
            total loss are below this value in absolute terms.

    Returns:
        vae_params, vae_batch_stats, loss_history
        loss_history = {'total': [...], 'recon': [...], 'kl': [...]} (per-epoch means)
    """
    images = train_set["image"]
    num_samples = images.shape[0]
    image_shape = images.shape[1:]

    # Three independent keys: the key passed to the model as data during init,
    # the parameter-initialisation key, and the base of the training key stream.
    # (Previously the same key was used for both arguments of `init`.)
    sample_rng, init_rng, step_base = random.split(rng, 3)
    dummy = jnp.ones((1,) + image_shape)
    vae_variables = vae_model.init(init_rng, dummy, sample_rng)  # (x, key)
    vae_params = vae_variables["params"]
    vae_batch_stats = vae_variables.get("batch_stats", {})

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(vae_params)

    def loss_fn(vae_params, vae_batch_stats, batch_x, step_key):
        """Return the total loss and (recon, kl, updated batch statistics)."""
        (x_recon, mu, log_var), mutable_out = vae_model.apply(
            {"params": vae_params, "batch_stats": vae_batch_stats},
            batch_x,
            step_key,
            mutable=["batch_stats"],
            use_running_average=False,
        )
        # Fall back to the incoming statistics if the model has no such collection.
        new_batch_stats = mutable_out.get("batch_stats", vae_batch_stats)
        recon = jnp.mean(jnp.sum((batch_x - x_recon) ** 2, axis=(1, 2, 3)))
        kl_div = -0.5 * jnp.sum(1 + log_var - mu**2 - jnp.exp(log_var), axis=1)
        kl = jnp.mean(kl_div)
        total = recon + kl
        return total, (recon, kl, new_batch_stats)

    @jax.jit
    def update_step(vae_params, vae_batch_stats, opt_state, batch_x, step_key):
        """One Adam update; returns new params, batch stats, optimizer state and losses."""
        (loss, (recon, kl, new_batch_stats)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            vae_params, vae_batch_stats, batch_x, step_key
        )
        updates, opt_state = optimizer.update(grads, opt_state, vae_params)
        vae_params = optax.apply_updates(vae_params, updates)
        return vae_params, new_batch_stats, opt_state, loss, recon, kl

    def _should_early_stop(curve: List[float], patience: int, min_delta: float) -> bool:
        """True when each of the last `patience` epoch-to-epoch changes is below `min_delta`."""
        if len(curve) < patience + 1:
            return False
        recent = curve[-(patience + 1):]
        deltas = [abs(recent[i + 1] - recent[i]) for i in range(patience)]
        return all(d < min_delta for d in deltas)

    loss_history = {"total": [], "recon": [], "kl": []}
    for epoch in range(1, num_epochs + 1):
        # Never consume a key and also split it: derive a dedicated shuffle key.
        step_base, perm_key = random.split(step_base)
        perm = jax.random.permutation(perm_key, num_samples)
        epoch_total = 0.0
        epoch_recon = 0.0
        epoch_kl = 0.0
        steps = 0

        for i in range(0, num_samples, batch_size):
            idx = perm[i:i + batch_size]
            batch_x = images[idx]
            # Advance the key stream so every batch gets its own key (previously
            # all batches of an epoch received the same key).
            step_base, step_key = random.split(step_base)
            # update
            vae_params, vae_batch_stats, opt_state, loss, recon, kl = update_step(
                vae_params, vae_batch_stats, opt_state, batch_x, step_key
            )
            epoch_total += float(loss)
            epoch_recon += float(recon)
            epoch_kl += float(kl)
            steps += 1

        avg_total = epoch_total / max(steps, 1)
        avg_recon = epoch_recon / max(steps, 1)
        avg_kl = epoch_kl / max(steps, 1)
        print(f"[VAE] Epoch {epoch:03d} | total: {avg_total:.6f} | recon: {avg_recon:.6f} | kl: {avg_kl:.6f}")

        loss_history["total"].append(avg_total)
        loss_history["recon"].append(avg_recon)
        loss_history["kl"].append(avg_kl)

        if _should_early_stop(loss_history["total"], patience, min_delta):
            print(f"[VAE] Early stopping at epoch {epoch}")
            break

    return vae_params, vae_batch_stats, loss_history

In [None]:
from model import VAE, StackedGRUFiLM
# Build the two sub-models. Both use the same latent size; the GRU's input and
# output widths come from the metadata derived from the first training sample.
vae_model = VAE(latent_dim=latent_dim)
gru_model = StackedGRUFiLM(
    input_dim=metadata["input_dim"],
    output_dim=metadata["output_dim"],
    latent_dim=latent_dim,
    hidden_dim=512,
    num_layers=3
)
# Replaces the key created in the first cell; used as the VAE training key.
key = jax.random.PRNGKey(0)
# Early-stopping threshold passed to the joint training cell below.
min_delta = 1e-10

In [None]:
# Train the VAE alone and checkpoint the result so later cells can reload it.
vae_params, vae_batch_stats, vae_hist = train_vae(
    vae_model,
    train_set=train_set,
    rng=key,
    num_epochs=num_epochs,
    batch_size=batch_size,
    # Use the notebook-level setting instead of repeating the literal, so that
    # changing `learning_rate` in the hyperparameter cell actually takes effect.
    learning_rate=learning_rate,
)
vae_state = {
    "vae_params": vae_params,
    "vae_batch_stats": vae_batch_stats,
    "vae_hist": vae_hist,
}
with open(os.path.join(case_path ,"vae_training.pkl"), "wb") as f:
    pickle.dump(vae_state, f)

In [None]:
# Restore the VAE checkpoint written by the training cell.
with open(os.path.join(case_path, "vae_training.pkl"), "rb") as f:
    vae_state = pickle.load(f)
vae_params, vae_batch_stats, vae_hist = (
    vae_state["vae_params"],
    vae_state["vae_batch_stats"],
    vae_state["vae_hist"],
)

In [None]:
# Plot the three per-epoch VAE loss curves on a log scale and save the figure.
keys_order = ["total", "recon", "kl"]
legend_titles = {
    "total": "Total Loss",
    "recon": "Reconstruction Loss",
    "kl": "KL Divergence"
}

epochs = np.arange(1, len(vae_hist[keys_order[0]]) + 1)
curves = [np.asarray(vae_hist[name], dtype=float) for name in keys_order]
loss_data = np.column_stack([epochs] + curves)  # shape: (E, 1+num_curves)

fig, ax = plt.subplots()
epoch_axis = loss_data[:, 0].astype(int)
for column, name in enumerate(keys_order, start=1):
    ax.plot(epoch_axis, loss_data[:, column], '--', linewidth=2, label=legend_titles[name])
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_yscale('log')
ax.set_title("Training Loss")
ax.legend()
fig.savefig(os.path.join(case_path, "vae_loss_curve.png"), dpi=300)
plt.close(fig)

## GRU Training

In [None]:
def train_gru_with_mu(
    gru_model,
    vae_model,
    vae_params: Dict,
    vae_batch_stats: Dict,
    train_set: Dict[str, jnp.ndarray],
    rng: jax.Array,
    num_epochs: int = 100,
    batch_size: int = 64,
    learning_rate: float = 1e-3,
    patience: int = 10,
    min_delta: float = 1e-5,
) -> Tuple[Dict, Dict[str, List[float]]]:
    """Train the GRU on latent codes produced by an already-trained VAE.

    Each batch of images is encoded with ``vae_model.encode_only`` (running
    batch statistics, no mutation) and the returned ``mu`` is fed to
    ``gru_model.forward`` together with the input sequences. Only the GRU
    parameters are differentiated and updated; the loss is the mean squared
    error between predictions and targets.

    Args:
        gru_model: object exposing ``params`` and ``forward(inputs, z, params)``.
        vae_model: VAE module exposing ``apply`` and ``encode_only``.
        vae_params: VAE parameters, used as-is.
        vae_batch_stats: VAE batch statistics, used as-is.
        train_set: mapping with "image", "input" and "output" arrays.
        rng: PRNG key driving shuffling and the per-step key given to the encoder.
        num_epochs: maximum number of epochs.
        batch_size: mini-batch size (the last batch of an epoch may be smaller).
        learning_rate: Adam learning rate.
        patience: number of consecutive epoch-to-epoch changes inspected for early stopping.
        min_delta: training stops once all of the last ``patience`` changes of the
            epoch loss are below this value in absolute terms.

    Returns:
        gru_params, loss_history
        loss_history = {'pred': [...]} (per-epoch mean MSE)
    """
    images = train_set["image"]    # NOTE(review): assumed (N, H, W, C) — confirm
    inputs = train_set["input"]    # NOTE(review): assumed (N, T, input_dim) — confirm
    targets = train_set["output"]  # NOTE(review): assumed (N, T, output_dim) — confirm
    num_samples = inputs.shape[0]

    params = gru_model.params
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    def encode_z(batch_imgs, key):
        """Encode images with the VAE and return the second output (`mu`)."""
        (_, mu, _) = vae_model.apply(
            {"params": vae_params, "batch_stats": vae_batch_stats},
            batch_imgs,
            key,
            method=vae_model.encode_only,
            mutable=False,
            use_running_average=True,
        )
        return mu

    def loss_fn(params, batch_imgs, batch_inp, batch_tgt, key):
        """Mean squared error of the GRU predictions conditioned on the VAE latent."""
        z_latent = encode_z(batch_imgs, key)
        preds = gru_model.forward(batch_inp, z_latent, params)
        pred_mse = jnp.mean((preds - batch_tgt) ** 2)
        return pred_mse

    @jax.jit
    def update_step(params, opt_state, batch_imgs, batch_inp, batch_tgt, key):
        """One Adam update of the GRU parameters."""
        loss, grads = jax.value_and_grad(loss_fn)(
            params, batch_imgs, batch_inp, batch_tgt, key
        )
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    def _should_early_stop(curve: List[float], patience: int, min_delta: float) -> bool:
        """True when each of the last `patience` epoch-to-epoch changes is below `min_delta`."""
        if len(curve) < patience + 1:
            return False
        recent = curve[-(patience + 1):]
        deltas = [abs(recent[i + 1] - recent[i]) for i in range(patience)]
        return all(d < min_delta for d in deltas)

    loss_history = {"pred": []}
    base_key = rng
    for epoch in range(1, num_epochs + 1):
        # Never consume a key and also split it: derive a dedicated shuffle key.
        base_key, perm_key = random.split(base_key)
        perm = jax.random.permutation(perm_key, num_samples)
        epoch_loss = 0.0
        steps = 0

        for i in range(0, num_samples, batch_size):
            idx = perm[i:i + batch_size]
            batch_imgs = images[idx]
            batch_inp = inputs[idx]
            batch_tgt = targets[idx]
            # Advance the key stream so every batch gets its own key (previously
            # all batches of an epoch received the same key).
            base_key, step_key = random.split(base_key)

            params, opt_state, loss = update_step(
                params, opt_state, batch_imgs, batch_inp, batch_tgt, step_key
            )
            epoch_loss += float(loss)
            steps += 1

        avg = epoch_loss / max(steps, 1)
        print(f"[GRU] Epoch {epoch:03d} | pred_mse: {avg:.6e}")
        loss_history["pred"].append(avg)

        if _should_early_stop(loss_history["pred"], patience, min_delta):
            print(f"[GRU] Early stopping at epoch {epoch}")
            break

    return params, loss_history

In [None]:
# Train the GRU on top of the frozen VAE encoder and checkpoint the result.
key = jax.random.PRNGKey(1)
gru_params, gru_hist = train_gru_with_mu(
    gru_model,
    vae_model,
    vae_params,
    vae_batch_stats,
    train_set=train_set,
    rng=key,
    num_epochs=num_epochs,
    batch_size=batch_size,
    # Use the notebook-level setting instead of repeating the literal, so that
    # changing `learning_rate` in the hyperparameter cell actually takes effect.
    learning_rate=learning_rate,
)
gru_state = {
    "gru_params": gru_params,
    "gru_hist": gru_hist,
}
with open(os.path.join(case_path ,"gru_training.pkl"), "wb") as f:
    pickle.dump(gru_state, f)

# Reference from an earlier run (presumably ~68 minutes of wall time):
# [GRU] Epoch 1000 | pred_mse: 9.452973e-05

In [None]:
# Restore the GRU checkpoint written by the training cell.
with open(os.path.join(case_path, "gru_training.pkl"), "rb") as f:
    gru_state = pickle.load(f)
gru_params, gru_hist = gru_state["gru_params"], gru_state["gru_hist"]

In [None]:
# Plot the per-epoch GRU training MSE on a log scale and save the figure.
loss_history = gru_hist["pred"]
epochs = np.arange(1, len(loss_history) + 1)
loss_data = np.column_stack((epochs, loss_history))
epochs = loss_data[:, 0].astype(int)
loss_curve = loss_data[:, 1]
# Start from a fresh figure so the curve can never be drawn onto a figure that
# an earlier cell left open.
plt.figure()
plt.plot(epochs, loss_curve, '--', linewidth=2, label='GRU MSE')
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.yscale('log')
plt.title("Training Loss")
# plt.grid(True)
plt.legend()
# dpi=300 matches the other loss-curve figures; without it the global
# savefig.dpi of 1000 would apply.
plt.savefig(os.path.join(case_path,"gru_loss_curve.png"), dpi=300)
# plt.show()
plt.close()

## Joint Training

In [None]:
from model import JointModel, train_joint_model_mu
# Wrap both sub-models for joint training. The learning rate comes from the
# notebook-level setting so that changing it in the hyperparameter cell takes effect.
joint_model = JointModel(vae_model=vae_model, gru_model=gru_model, learning_rate=learning_rate)
# `rng` returned here is the key used by the evaluation cells further down.
params_dict, opt_states_dict, batch_stats_dict, rng = joint_model.init(key)
total = joint_model.count_params(params_dict)

In [None]:
# Jointly train VAE + GRU, then checkpoint parameters, optimizer states,
# batch statistics and the loss history in a single pickle.
final_params, final_opt_states, final_batch_stats, loss_history = train_joint_model_mu(
    model=joint_model,
    train_set=train_set,
    params_dict=params_dict,
    opt_states_dict=opt_states_dict,
    batch_stats_dict=batch_stats_dict,
    rng=key,
    num_epochs=num_epochs,
    batch_size=batch_size,
    min_delta=min_delta,
)

model_state = dict(
    params=final_params,
    opt_states=final_opt_states,
    batch_stats=final_batch_stats,
    loss_history=loss_history,
)

with open(os.path.join(case_path, "initial_training.pkl"), "wb") as f:
    pickle.dump(model_state, f)

In [None]:
# Restore the joint-training checkpoint and unpack the pieces used below.
with open(os.path.join(case_path, "initial_training.pkl"), "rb") as f:
    model_state = pickle.load(f)

params_dict, batch_stats_dict, loss_history = (
    model_state["params"],
    model_state["batch_stats"],
    model_state["loss_history"],
)

In [None]:
# Plot every recorded joint-training loss component on a log scale and save it.
epochs = np.arange(1, len(loss_history['total']) + 1)
fig, ax = plt.subplots(figsize=(8, 6))
for series_key, line_style, series_label in (
    ('total', '--', 'Total Loss'),
    ('recon', '-', 'Recon Loss'),
    ('kl', '-', 'KL Loss'),
    ('pde', '-', 'PDE Loss'),
):
    ax.plot(epochs, loss_history[series_key], line_style, linewidth=2, label=series_label)
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss")
ax.set_yscale('log')
ax.set_title("Training Loss Curve")
ax.legend()
ax.grid(True)
fig.savefig(os.path.join(case_path, "joint_train_loss.png"), dpi=300)
plt.close(fig)

# Validation

In [None]:
def plot_reconstruction(x_real, x_recon, n_rows=2, n_cols=5, save=False, path="reconstruction.png"):
    """Show randomly chosen originals (top row) above their reconstructions (second row).

    Args:
        x_real: array of original images; must have the same shape as `x_recon`.
        x_recon: array of reconstructed images.
        n_rows: number of subplot rows to create. Only rows 0 and 1 are drawn
            into, so it must be at least 2.
        n_cols: number of image pairs to show; drawn without replacement, so it
            must not exceed the number of images.
        save: when True, the figure is written to `path`.
        path: output file used when `save` is True.

    Raises:
        AssertionError: if the two arrays differ in shape.
        ValueError: from ``np.random.choice`` if `n_cols` exceeds the number of images.
    """
    assert x_real.shape == x_recon.shape, "Input shapes must match"
    num_samples = x_real.shape[0]
    indices = np.random.choice(num_samples, size=n_cols, replace=False)

    # squeeze=False keeps `axs` two-dimensional even for a single column, so the
    # axs[row, col] indexing below also works when n_cols == 1.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 5), squeeze=False)

    for i, idx in enumerate(indices):
        real_img = np.squeeze(x_real[idx])
        recon_img = np.squeeze(x_recon[idx])

        axs[0, i].imshow(real_img, cmap='gray')
        axs[0, i].axis('off')
        axs[0, i].set_title("Original")

        axs[1, i].imshow(recon_img, cmap='gray')
        axs[1, i].axis('off')
        axs[1, i].set_title("Reconstruction")
    fig.tight_layout()
    if save:
        fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.show()
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig)

def visualize_vaemodel_reconstruction(vae_model, test_set, params_dict, batch_stats_dict, rng, path):
    """Reconstruct five randomly chosen test images with the VAE and save the comparison plot.

    The figure is written to `path` inside the notebook's `case_path` directory.
    """
    test_images = test_set["image"]
    chosen = np.random.choice(test_images.shape[0], size=5, replace=False)
    originals = test_images[chosen]

    variables = dict(params=params_dict, batch_stats=batch_stats_dict)
    reconstructions, _, _ = vae_model.apply(
        variables, originals, rng, mutable=False, use_running_average=True
    )

    plot_reconstruction(
        originals,
        reconstructions,
        n_rows=2,
        n_cols=5,
        save=True,
        path=os.path.join(case_path, path),
    )

def evaluate_vae_reconstruction(vae_model, test_set, params_dict, batch_stats_dict, rng_key):
    """
    Evaluate the reconstruction quality of a stand-alone VAE.

    All images are reconstructed in one forward pass (running batch statistics,
    no mutation) and compared element-wise with the originals.

    Args:
        vae_model: VAE module whose ``apply`` returns ``(x_recon, ..., ...)``
        test_set: dict, must contain key "image"
        params_dict: the VAE parameters (passed as the "params" collection)
        batch_stats_dict: the VAE batch statistics (passed as "batch_stats")
        rng_key: jax.random.PRNGKey handed to the model

    Returns:
        (mse, mae, psnr); psnr is computed for a peak value of 1.0 and is
        ``inf`` when the MSE is exactly zero. The metrics are also printed.
    """
    images = test_set["image"]
    vae_vars = {"params": params_dict, "batch_stats": batch_stats_dict}
    # Run forward pass: full batch inference
    x_recon, _, _ = vae_model.apply(
        vae_vars,
        images,
        rng_key,
        mutable=False,
        use_running_average=True
    )

    # Flatten to vectors
    x_true = np.array(images).reshape((images.shape[0], -1))
    x_pred = np.array(x_recon).reshape((x_recon.shape[0], -1))

    # === Metrics ===
    diff = x_true - x_pred
    mse = np.mean(diff ** 2)
    mae = np.mean(np.abs(diff))
    if mse == 0:
        psnr = float("inf")
    else:
        psnr = 20 * np.log10(1.0) - 10 * np.log10(mse)

    print("[VAE Reconstruction on Test Set]")
    print(f"  ➤ MSE  : {mse:.6f}")
    print(f"  ➤ MAE  : {mae:.6f}")
    print(f"  ➤ PSNR : {psnr:.2f} dB")

    # Hand the metrics back as documented (previously nothing was returned).
    return mse, mae, psnr

In [None]:
# Qualitative and quantitative check of the separately trained VAE on the test split.
visualize_vaemodel_reconstruction(
    vae_model,
    test_set,
    vae_params,
    vae_batch_stats,
    key,
    "validation_reconstruction_vae.png",
)
evaluate_vae_reconstruction(
    vae_model,
    test_set,
    vae_params,
    vae_batch_stats,
    key,
)

In [None]:
def visualize_jointmodel_reconstruction(joint_model, test_set, params_dict, batch_stats_dict, rng, path):
    """Reconstruct five randomly chosen test images with the joint model's VAE and save the plot.

    The VAE variables are taken from the "vae" entries of `params_dict` and
    `batch_stats_dict`; the figure is written to `path` inside `case_path`.
    """
    test_images = test_set["image"]
    chosen = np.random.choice(test_images.shape[0], size=5, replace=False)
    originals = test_images[chosen]

    variables = dict(params=params_dict["vae"], batch_stats=batch_stats_dict["vae"])
    reconstructions, _, _ = joint_model.vae.apply(
        variables, originals, rng, mutable=False, use_running_average=True
    )

    plot_reconstruction(
        originals,
        reconstructions,
        n_rows=2,
        n_cols=5,
        save=True,
        path=os.path.join(case_path, path),
    )

def evaluate_joint_vae_reconstruction(joint_model, test_set, params_dict, batch_stats_dict, rng_key):
    """
    Evaluate VAE reconstruction quality inside a JointModel.

    All images are reconstructed in one forward pass (running batch statistics,
    no mutation) and compared element-wise with the originals.

    Args:
        joint_model: JointModel instance (contains a .vae submodule)
        test_set: dict, must contain key "image"
        params_dict: dict, must contain "vae"
        batch_stats_dict: dict, must contain "vae"
        rng_key: jax.random.PRNGKey handed to the model

    Returns:
        (mse, mae, psnr); psnr is computed for a peak value of 1.0 and is
        ``inf`` when the MSE is exactly zero. The metrics are also printed.
    """
    images = test_set["image"]
    vae_vars = {"params": params_dict["vae"], "batch_stats": batch_stats_dict["vae"]}
    # Run forward pass: full batch inference
    x_recon, _, _ = joint_model.vae.apply(
        vae_vars,
        images,
        rng_key,
        mutable=False,
        use_running_average=True
    )

    # Flatten to vectors
    x_true = np.array(images).reshape((images.shape[0], -1))
    x_pred = np.array(x_recon).reshape((x_recon.shape[0], -1))

    # === Metrics ===
    diff = x_true - x_pred
    mse = np.mean(diff ** 2)
    mae = np.mean(np.abs(diff))
    if mse == 0:
        psnr = float("inf")
    else:
        psnr = 20 * np.log10(1.0) - 10 * np.log10(mse)

    print("[VAE Reconstruction on Test Set]")
    print(f"  ➤ MSE  : {mse:.6f}")
    print(f"  ➤ MAE  : {mae:.6f}")
    print(f"  ➤ PSNR : {psnr:.2f} dB")

    # Hand the metrics back as documented (previously nothing was returned).
    return mse, mae, psnr

In [None]:
# Use the jointly trained weights for the reconstruction checks on the test split.
params_dict, batch_stats_dict = model_state["params"], model_state["batch_stats"]
visualize_jointmodel_reconstruction(
    joint_model,
    test_set,
    params_dict,
    batch_stats_dict,
    rng,
    "validation_reconstruction_joint.png",
)
evaluate_joint_vae_reconstruction(
    joint_model,
    test_set,
    params_dict,
    batch_stats_dict,
    rng,
)

In [None]:
from typing import Dict, Any, Tuple 
def evaluate_stress_prediction(model, test_set: Dict[str, jnp.ndarray], params_dict: Dict[str, Any], rng) -> Dict[str, Any]:
    """Evaluate sequence predictions of the VAE-encoder + GRU pipeline on a data set.

    Images are encoded with ``model.vae.encode_only`` (running batch statistics,
    no mutation); the returned ``mu`` conditions ``model.gru.forward``. MSE/MAE
    are reported over all outputs, over output channel 0 ("Energy") and over
    output channels 1..3 ("Stress"), together with the worst sample by
    per-sample MAE and MSE. A summary is printed as well.

    Args:
        model: object exposing ``vae`` (with ``apply`` / ``encode_only``) and
            ``gru`` (with ``forward``).
        test_set: mapping with "input", "output" and "image" arrays that share
            their leading (sample) axis.
        params_dict: mapping with "vae" and "gru" parameters. Note that the VAE
            batch statistics are read from ``params_dict["batch_stats"]``.
        rng: PRNG key handed to the VAE.

    Returns:
        dict with keys "mse_all", "mae_all", "mse_energy", "mae_energy",
        "mse_stress", "mae_stress", "worst_mae_idx", "worst_mae",
        "worst_mse_idx", "worst_mse".
    """
    input_path = test_set["input"]    # NOTE(review): assumed (N, T, input_dim) — confirm
    output_gt = test_set["output"]    # NOTE(review): assumed (N, T, output_dim) — confirm
    micro_image = test_set["image"]

    _, mu_latent, _ = model.vae.apply(
        {"params": params_dict["vae"], "batch_stats": params_dict["batch_stats"]},
        micro_image,
        rng,
        method=model.vae.encode_only,
        mutable=False,
        use_running_average=True
    )

    preds = model.gru.forward(input_path, mu_latent, params=params_dict["gru"])

    preds_np = np.array(preds)
    output_np = np.array(output_gt)

    # Compute the element-wise errors once and reuse them for every metric.
    abs_err = np.abs(preds_np - output_np)
    sq_err = (preds_np - output_np) ** 2

    mse_all = np.mean(sq_err)
    mae_all = np.mean(abs_err)

    mse_energy = np.mean(sq_err[:, :, 0])
    mae_energy = np.mean(abs_err[:, :, 0])

    mse_stress = np.mean(sq_err[:, :, 1:4])
    mae_stress = np.mean(abs_err[:, :, 1:4])

    per_sample_mae = np.mean(abs_err, axis=(1, 2))  # shape: (N,)
    per_sample_mse = np.mean(sq_err, axis=(1, 2))   # shape: (N,)

    worst_mae_idx = int(np.argmax(per_sample_mae))
    worst_mse_idx = int(np.argmax(per_sample_mse))
    worst_mae = float(np.max(per_sample_mae))
    worst_mse = float(np.max(per_sample_mse))

    print("===== Stress Prediction Summary =====")
    print(f"Total     → MSE: {mse_all:.6f} | MAE: {mae_all:.6f}")
    print(f"Energy    → MSE: {mse_energy:.6f} | MAE: {mae_energy:.6f}")
    print(f"Stress    → MSE: {mse_stress:.6f} | MAE: {mae_stress:.6f}")
    print(f"Worst MAE sample index: {worst_mae_idx} → MAE: {worst_mae:.6f}")
    print(f"Worst MSE sample index: {worst_mse_idx} → MSE: {worst_mse:.6f}")

    return {
        "mse_all": mse_all,
        "mae_all": mae_all,
        "mse_energy": mse_energy,
        "mae_energy": mae_energy,
        "mse_stress": mse_stress,
        "mae_stress": mae_stress,
        "worst_mae_idx": worst_mae_idx,
        "worst_mae": worst_mae,
        "worst_mse_idx": worst_mse_idx,
        "worst_mse": worst_mse
    }

def visualize_jointmodel_prediction(joint_model, test_set, params_dict, batch_stats_dict, index, rng, stats=None, path=None):
    """Plot inputs, predicted vs. target stresses and dissipation for one sample.

    The image of sample `index` is encoded with the joint model's VAE
    (``encode_only``, running batch statistics) and the resulting ``mu``
    conditions the GRU prediction. Three panels are drawn: the three input
    channels, output channels 1..3 (prediction vs. target) and output channel 0
    (prediction vs. target).

    Args:
        joint_model: object exposing ``vae`` and ``gru`` sub-models.
        test_set: mapping with "input", "image" and "output" arrays.
        params_dict: mapping with "vae" and "gru" parameters.
        batch_stats_dict: mapping with the VAE batch statistics under "vae".
        index: position of the sample to plot.
        rng: PRNG key handed to the VAE.
        stats: optional normalisation statistics; when given (and non-empty),
            inputs, predictions and targets are mapped back to original units.
        path: optional file name; when given, the figure is saved under `case_path`.
    """
    input_path = test_set["input"][index:index+1]    # keep a leading batch axis of size 1
    micro_image = test_set["image"][index:index+1]   # keep a leading batch axis of size 1
    target_output = test_set["output"][index]

    _, mu_latent, _ = joint_model.vae.apply(
        {"params": params_dict["vae"], "batch_stats": batch_stats_dict["vae"]},
        micro_image,
        rng,
        method=joint_model.vae.encode_only,
        use_running_average=True,
        mutable=False
    )

    pred = joint_model.gru.forward(input_path, mu_latent, params=params_dict["gru"])

    if stats:
        input_path = reverse_input(input_path[0], stats)
        pred_array = reverse_output(pred[0], stats)
        target_output = reverse_output(target_output, stats)
    else:
        input_path = input_path[0]
        pred_array = pred[0]

    fig, axes = plt.subplots(1, 3, figsize=(12, 3))

    # 1. Strain Path
    for idx, (color, label) in enumerate(zip(['r', 'g', 'b'], ['e11', 'e22', 'e12'])):
        axes[0].plot(range(len(input_path)), input_path[:, idx], '--', color=color, label=label)
    axes[0].set_xlabel("Time step")
    axes[0].set_ylabel("Strain")
    axes[0].set_title("Strain Path"); axes[0].legend(); axes[0].grid(True)

    # 2. Stress Path
    for idx, color in enumerate(['r', 'g', 'b']):
        axes[1].plot(pred_array[:, idx+1], 'o', ms=3, color=color, label=f'Predicted s{idx+1}')
    for idx, color in enumerate(['r', 'g', 'b']):
        axes[1].plot(target_output[:, idx+1], '--', linewidth=1.5, color=color, label=f'Target s{idx+1}')
    axes[1].set_xlabel("Time step")
    axes[1].set_ylabel("Stress (MPa)")
    axes[1].set_title("Stress Path"); axes[1].legend(); axes[1].grid(True)

    # 3. Plastic Dissipation
    axes[2].plot(target_output[:, 0], '--', color='tab:orange', label='ALLPD_GT')
    axes[2].plot(pred_array[:, 0], 'o', ms=3, label='ALLPD_pred')
    axes[2].set_xlabel("Time step")
    axes[2].set_ylabel("ALLPD (mJ)")
    axes[2].set_title("Plastic Dissipation"); axes[2].legend(); axes[2].grid(True)

    fig.tight_layout()
    if path is not None:
        fig.savefig(os.path.join(case_path, path), dpi=300, bbox_inches='tight')
    plt.show()
    # This function is called many times; release the figure so open figures
    # do not accumulate (same clean-up as plot_reconstruction).
    plt.close(fig)

## Separate Training
### Training Set

In [None]:
# Evaluate with the separately trained VAE and GRU weights.
params_dict = dict(vae=vae_params, gru=gru_params)
batch_stats_dict = dict(vae=vae_batch_stats)

In [None]:
# train set
metrics_validation = evaluate_stress_prediction(
    model=joint_model,
    test_set=train_set,
    params_dict=dict(
        vae=params_dict["vae"],
        gru=params_dict["gru"],
        batch_stats=batch_stats_dict["vae"],
    ),
    rng=rng,
)

# Plot a fixed reference sample first, then the sample with the largest MAE.
for sample_index, figure_name in (
    (35, "train_predict_seperate.png"),
    (metrics_validation["worst_mae_idx"], "train_worst_seperate.png"),
):
    visualize_jointmodel_prediction(
        joint_model,
        test_set=train_set,
        params_dict=params_dict,
        batch_stats_dict=batch_stats_dict,
        index=sample_index,
        rng=rng,
        stats=stats,
        path=figure_name,
    )

### Validation set

In [None]:
# Validation set
metrics_validation = evaluate_stress_prediction(
    model=joint_model,
    test_set=test_set,
    params_dict=dict(
        vae=params_dict["vae"],
        gru=params_dict["gru"],
        batch_stats=batch_stats_dict["vae"],
    ),
    rng=rng,
)

# Plot a fixed reference sample first, then the sample with the largest MAE.
for sample_index, figure_name in (
    (35, "validation_predict_seperate.png"),
    (metrics_validation["worst_mae_idx"], "validation_worst_seperate.png"),
):
    visualize_jointmodel_prediction(
        joint_model,
        test_set=test_set,
        params_dict=params_dict,
        batch_stats_dict=batch_stats_dict,
        index=sample_index,
        rng=rng,
        stats=stats,
        path=figure_name,
    )

## Joint Model Validation

In [None]:
# Switch back to the jointly trained weights.
params_dict, batch_stats_dict = model_state["params"], model_state["batch_stats"]

In [None]:
# Validation set
metrics_validation = evaluate_stress_prediction(
    model=joint_model,
    test_set=test_set,
    params_dict=dict(
        vae=params_dict["vae"],
        gru=params_dict["gru"],
        batch_stats=batch_stats_dict["vae"],
    ),
    rng=rng,
)

# Plot a fixed reference sample first, then the sample with the largest MAE.
for sample_index, figure_name in (
    (35, "validation_predict_joint.png"),
    (metrics_validation["worst_mae_idx"], "validation_worst_joint.png"),
):
    visualize_jointmodel_prediction(
        joint_model,
        test_set=test_set,
        params_dict=params_dict,
        batch_stats_dict=batch_stats_dict,
        index=sample_index,
        rng=rng,
        stats=stats,
        path=figure_name,
    )

# Generalization

In [None]:
# Load the two generalisation test sets.
# NOTE(review): pickle.load executes arbitrary code from the file — only load
# dataset files from a trusted source.
test_pkl_path1 = "dataset/test_random_inclusion.pkl"
test_pkl_path2 = "dataset/test_unseen_path.pkl"

with open(test_pkl_path1, "rb") as f:
        dataset_test1 = pickle.load(f)

with open(test_pkl_path2, "rb") as f:
        dataset_test2 = pickle.load(f)
# Both sets are normalised with `stats`, the statistics of the training data,
# and then stacked into dicts of arrays. `dataset_OOD` is only a temporary alias.
dataset_OOD= dataset_test1
dataset_jax_OOD1 = normalize_dataset(dataset_OOD, stats)
dataset_OOD1 = flatten_dataset(dataset_jax_OOD1)

dataset_OOD= dataset_test2
dataset_jax_OOD2 = normalize_dataset(dataset_OOD, stats)
dataset_OOD2 = flatten_dataset(dataset_jax_OOD2)


In [None]:
# First generalisation set, evaluated with the separately trained weights.
params_dict = dict(vae=vae_params, gru=gru_params)
batch_stats_dict = dict(vae=vae_batch_stats)

metrics_test = evaluate_stress_prediction(
    model=joint_model,
    test_set=dataset_OOD1,
    params_dict=dict(
        vae=params_dict["vae"],
        gru=params_dict["gru"],
        batch_stats=batch_stats_dict["vae"],
    ),
    rng=rng,
)

# Plot a fixed reference sample first, then the sample with the largest MAE.
for sample_index, figure_name in (
    (35, "test1_predict_seperate.png"),
    (metrics_test["worst_mae_idx"], "test1_worst_seperate.png"),
):
    visualize_jointmodel_prediction(
        joint_model,
        test_set=dataset_OOD1,
        params_dict=params_dict,
        batch_stats_dict=batch_stats_dict,
        index=sample_index,
        rng=rng,
        stats=stats,
        path=figure_name,
    )

In [None]:
# Second generalisation set, evaluated with whichever weights `params_dict` currently holds.
metrics_test = evaluate_stress_prediction(
    model=joint_model,
    test_set=dataset_OOD2,
    params_dict=dict(
        vae=params_dict["vae"],
        gru=params_dict["gru"],
        batch_stats=batch_stats_dict["vae"],
    ),
    rng=rng,
)

# Plot a fixed reference sample first, then the sample with the largest MAE.
for sample_index, figure_name in (
    (35, "test2_predict_seperate.png"),
    (metrics_test["worst_mae_idx"], "test2_worst_seperate.png"),
):
    visualize_jointmodel_prediction(
        joint_model,
        test_set=dataset_OOD2,
        params_dict=params_dict,
        batch_stats_dict=batch_stats_dict,
        index=sample_index,
        rng=rng,
        stats=stats,
        path=figure_name,
    )

In [None]:
# First generalisation set, evaluated with the jointly trained weights.
params_dict, batch_stats_dict = model_state["params"], model_state["batch_stats"]

metrics_test = evaluate_stress_prediction(
    model=joint_model,
    test_set=dataset_OOD1,
    params_dict=dict(
        vae=params_dict["vae"],
        gru=params_dict["gru"],
        batch_stats=batch_stats_dict["vae"],
    ),
    rng=rng,
)

# Plot a fixed reference sample first, then the sample with the largest MAE.
for sample_index, figure_name in (
    (35, "test1_predict_joint.png"),
    (metrics_test["worst_mae_idx"], "test1_worst_joint.png"),
):
    visualize_jointmodel_prediction(
        joint_model,
        test_set=dataset_OOD1,
        params_dict=params_dict,
        batch_stats_dict=batch_stats_dict,
        index=sample_index,
        rng=rng,
        stats=stats,
        path=figure_name,
    )

In [None]:
# Second generalisation set, evaluated with whichever weights `params_dict` currently holds.
metrics_test = evaluate_stress_prediction(
    model=joint_model,
    test_set=dataset_OOD2,
    params_dict=dict(
        vae=params_dict["vae"],
        gru=params_dict["gru"],
        batch_stats=batch_stats_dict["vae"],
    ),
    rng=rng,
)

# Plot a fixed reference sample first, then the sample with the largest MAE.
for sample_index, figure_name in (
    (35, "test2_predict_joint.png"),
    (metrics_test["worst_mae_idx"], "test2_worst_joint.png"),
):
    visualize_jointmodel_prediction(
        joint_model,
        test_set=dataset_OOD2,
        params_dict=params_dict,
        batch_stats_dict=batch_stats_dict,
        index=sample_index,
        rng=rng,
        stats=stats,
        path=figure_name,
    )