In [7]:
import numpy as np
import tensorflow as tf
from pathlib import Path

from data_loader_saver import load_data
from model_mheight import (
    SUPPORTED_COMBOS,
    ColumnWiseSetEncoder,
    MHeightRegressor,
    TransformerBlock,
    symmetric_relative_loss,
)

In [8]:
# Pickle inputs handed to load_data() in the next cell.
# NOTE(review): presumably the reference m-heights and the (n, k, m, matrix)
# samples respectively — inferred from the file names, confirm against load_data.
m_h_datapath = "../data/project/test_samples_m_h.pkl"
n_k_m_datapath = "../data/project/test_samples_n_k_m_G.pkl"

# Directory that load_model_cache() searches for "<n>_<k>_<m>_model" entries.
models_dir = Path("trained_models")

# Upper bound on how many samples are randomly drawn for the inference check
# (the actual count is capped by the dataset size).
inference_sample_num = 100

In [9]:
# Load the evaluation samples and their reference targets with the project loader.
# Downstream code indexes each entry of n_k_m_P as sample[:3] (the combo key)
# and sample[3] (the matrix fed to the model).
# NOTE(review): the input file name ends in "_G" while the variable is named
# "_P" — confirm which matrix load_data actually returns.
n_k_m_P, m_heights = load_data(n_k_m_datapath, m_h_datapath)

# Name -> object mapping passed to tf.keras.models.load_model so that saved
# models referencing these project-defined layers/loss can be deserialised.
custom_objects = {
    "TransformerBlock": TransformerBlock,
    "ColumnWiseSetEncoder": ColumnWiseSetEncoder,
    "MHeightRegressor": MHeightRegressor,
    "symmetric_relative_loss": symmetric_relative_loss,
}

def load_model_cache(models_directory: Path, combos=SUPPORTED_COMBOS) -> dict:
    """Build a ``{combo: model}`` mapping from the saved models on disk.

    Args:
        models_directory: Directory containing one saved model per combo,
            each named ``"<combo[0]>_<combo[1]>_<combo[2]>_model"``.
        combos: Iterable of indexable combos to load; each combo is also used
            verbatim as the key of the returned mapping.

    Returns:
        Dict mapping every combo to the object returned by
        ``tf.keras.models.load_model``. Models are loaded with
        ``compile=False`` and the module-level ``custom_objects`` mapping.
    """

    def _saved_model_location(combo):
        # Only the first three fields of the combo take part in the file name.
        first, second, third = combo[0], combo[1], combo[2]
        return models_directory / f"{first}_{second}_{third}_model"

    return {
        combo: tf.keras.models.load_model(
            _saved_model_location(combo),
            compile=False,
            custom_objects=custom_objects,
        )
        for combo in combos
    }

def predict_m_heights(samples, model_cache) -> list:
    """Predict one m-height per sample using the model registered for its combo.

    Args:
        samples: Iterable of indexable records. ``tuple(sample[:3])`` is the
            lookup key into ``model_cache``; ``sample[3]`` is converted to a
            float32 array and passed to the model with a leading batch axis
            of size 1.
        model_cache: Mapping from combo tuple to an object exposing
            ``predict(array, verbose=0)``.

    Returns:
        List of Python floats, one per sample, in input order.

    Raises:
        ValueError: If a sample's combo has no entry in ``model_cache``, or if
            the model returns more than one value for the single input row.
    """
    predictions = []
    for sample in samples:
        combo = tuple(sample[:3])
        if combo not in model_cache:
            raise ValueError(f"No trained model for combo {combo}.")
        matrix = np.asarray(sample[3], dtype=np.float32)[None, ...]
        output = model_cache[combo].predict(matrix, verbose=0)
        # The batch has exactly one row. Use .item() rather than float() on the
        # row: float() on an array with ndim > 0 (e.g. a (1,)-shaped row from a
        # (1, 1) output) is deprecated since NumPy 1.25, while .item() works
        # for both scalar and single-element rows and still rejects rows that
        # hold more than one value.
        predictions.append(float(np.asarray(output[0]).item()))
    return predictions

def _nanmean_or_nan(values) -> float:
    """Return the NaN-ignoring mean of ``values``, or NaN if nothing can be averaged."""
    arr = np.asarray(values, dtype=float)
    # np.nanmean emits "RuntimeWarning: Mean of empty slice" for empty or
    # all-NaN input; short-circuit so the same NaN is returned silently.
    if arr.size == 0 or np.isnan(arr).all():
        return float('nan')
    return float(np.nanmean(arr))


def analyze_prediction_errors(predictions, targets, samples, combos=SUPPORTED_COMBOS) -> dict:
    """Compute absolute and symmetric ratio errors overall and per combo.

    For every (prediction, target, sample) triple the absolute error is
    ``|pred - target|`` and the symmetric ratio error is
    ``max(pred / target, target / pred)``. The ratio is only defined when both
    values are strictly positive; otherwise it is recorded as NaN and ignored
    by the ratio means.

    Args:
        predictions: Iterable of predicted values.
        targets: Iterable of reference values, aligned with ``predictions``.
        samples: Iterable of indexable records aligned with ``predictions``;
            ``tuple(sample[:3])`` identifies the combo of each record.
        combos: Combos to report in the per-combo breakdown. Samples whose
            combo is not listed still count towards the overall means.

    Returns:
        Dict with keys ``'overall_mean_abs'``, ``'overall_mean_ratio'`` and
        ``'per_combo'``; the latter maps each combo to a dict with ``'count'``,
        ``'mean_abs'`` and ``'mean_ratio'``. Means are NaN when there is
        nothing to average (no samples, or no defined ratio).
    """
    abs_errors = []
    ratio_errors = []
    combo_abs = {combo: [] for combo in combos}
    combo_ratio = {combo: [] for combo in combos}
    for pred, target, sample in zip(predictions, targets, samples):
        combo = tuple(sample[:3])
        abs_err = float(abs(pred - target))
        abs_errors.append(abs_err)
        if combo in combo_abs:
            combo_abs[combo].append(abs_err)
        if pred > 0 and target > 0:
            ratio_err = float(max(pred / target, target / pred))
        else:
            # Ratio is undefined for non-positive values; mark it so the
            # NaN-aware mean skips it instead of skewing the average.
            ratio_err = float('nan')
        ratio_errors.append(ratio_err)
        if combo in combo_ratio:
            combo_ratio[combo].append(ratio_err)
    overall_abs = float(np.mean(abs_errors)) if abs_errors else float('nan')
    overall_ratio = _nanmean_or_nan(ratio_errors)
    per_combo = {}
    for combo in combos:
        abs_vals = combo_abs[combo]
        per_combo[combo] = {
            'count': len(abs_vals),
            'mean_abs': float(np.mean(abs_vals)) if abs_vals else float('nan'),
            'mean_ratio': _nanmean_or_nan(combo_ratio[combo]),
        }
    return {
        'overall_mean_abs': overall_abs,
        'overall_mean_ratio': overall_ratio,
        'per_combo': per_combo,
    }

# Load the model for every default combo once, so later cells can reuse the
# cache instead of re-reading models from disk.
model_cache = load_model_cache(models_dir)



In [13]:
# Seeded generator so the random evaluation subset is the same on every run.
rng = np.random.default_rng(seed=42)

# Never request more samples than the dataset holds.
num_samples = min(len(n_k_m_P), inference_sample_num)

# Draw distinct positions, then pull the matching inputs and reference values.
sampled_indices = rng.choice(len(n_k_m_P), replace=False, size=num_samples)
sampled_features = [n_k_m_P[position] for position in sampled_indices]
sampled_targets = [m_heights[position] for position in sampled_indices]

predicted_m_heights = predict_m_heights(sampled_features, model_cache)

# Show only the first ten predictions to keep the cell output short.
print(f"Predicted m-heights for {num_samples} samples:")
print(f"{predicted_m_heights[:10]} ...")

Predicted m-heights for 100 samples:
[4908.98876953125, 8153.56982421875, 359.34722900390625, 110.38915252685547, 4913.275390625, 8046.72900390625, 62.6585807800293, 269.9351501464844, 56.98933029174805, 517.0590209960938] ...


In [14]:
error_metrics = analyze_prediction_errors(
    predicted_m_heights,
    sampled_targets,
    sampled_features,
)

# Headline numbers across the whole random sample.
print(f"Random sample size: {num_samples}")
print(f"Overall mean absolute error: {error_metrics['overall_mean_abs']:.6f}")
print(f"Overall symmetric ratio error: {error_metrics['overall_mean_ratio']:.6f}")

# Per-combo breakdown; combos that received no samples are not printed.
for combo, stats in error_metrics['per_combo'].items():
    count = stats['count']
    if count != 0:
        print(
            f"{combo}: mean |pred - gt| = {stats['mean_abs']:.6f}, "
            f"mean ratio = {stats['mean_ratio']:.6f} (samples={count})"
        )

Random sample size: 100
Overall mean absolute error: 25992.666347
Overall symmetric ratio error: 5.666596
(9, 4, 2): mean |pred - gt| = 106.577901, mean ratio = 4.178997 (samples=9)
(9, 4, 3): mean |pred - gt| = 193.558822, mean ratio = 4.408251 (samples=11)
(9, 4, 4): mean |pred - gt| = 571.337664, mean ratio = 3.021044 (samples=7)
(9, 4, 5): mean |pred - gt| = 18369.599148, mean ratio = 4.739743 (samples=16)
(9, 5, 2): mean |pred - gt| = 203.017941, mean ratio = 2.553555 (samples=11)
(9, 5, 3): mean |pred - gt| = 684.752366, mean ratio = 2.753337 (samples=11)
(9, 5, 4): mean |pred - gt| = 149482.311848, mean ratio = 19.617391 (samples=15)
(9, 6, 2): mean |pred - gt| = 230.287249, mean ratio = 1.391927 (samples=12)
(9, 6, 3): mean |pred - gt| = 5437.717310, mean ratio = 1.779362 (samples=8)
