In [None]:
import numpy as np
import tensorflow as tf
from pathlib import Path

from data_loader_saver import load_data
from model_mheight import (
    SUPPORTED_COMBOS,
    ColumnWiseSetEncoder,
    RowWiseEncoder,
    MHeightRegressor,
    TransformerBlock,
    symmetric_relative_loss_from_log,
    heteroscedastic_gaussian_nll_from_log,
    symmetric_relative_loss_from_log_params,
    unpack_log_prediction,
    from_log2_height,
)


In [None]:
# Upper bound on how many samples the inference cell draws from the dataset.
inference_sample_num = 1_000

# Directory holding one saved model per (n, k, m) combo.
models_dir = Path("trained_models")

# Pickled evaluation data: model inputs and the matching ground-truth m-heights.
input_n_k_m_datapath = "../data/project/test_samples_n_k_m_G.pkl"
ground_truth_m_h_datapath = "../data/project/test_samples_m_h.pkl"

In [None]:
# Read the evaluation inputs and their ground-truth m-heights from disk.
n_k_m_P, m_heights = load_data(input_n_k_m_datapath, ground_truth_m_h_datapath)

# Name -> object lookup handed to tf.keras.models.load_model so the saved
# models can resolve the project-specific classes and functions they reference.
custom_objects = dict(
    TransformerBlock=TransformerBlock,
    ColumnWiseSetEncoder=ColumnWiseSetEncoder,
    RowWiseEncoder=RowWiseEncoder,
    MHeightRegressor=MHeightRegressor,
    heteroscedastic_gaussian_nll_from_log=heteroscedastic_gaussian_nll_from_log,
    symmetric_relative_loss_from_log_params=symmetric_relative_loss_from_log_params,
    symmetric_relative_loss_from_log=symmetric_relative_loss_from_log,
)

def load_model_cache(models_directory: Path, combos=SUPPORTED_COMBOS) -> dict:
    """Load the saved model for every requested (n, k, m) combo.

    Args:
        models_directory: Directory containing one ``<n>_<k>_<m>_model`` entry
            per combo.
        combos: Iterable of indexable combos; the first three entries of each
            combo build the model's path. Defaults to ``SUPPORTED_COMBOS``.

    Returns:
        Dict mapping each combo to the model returned by
        ``tf.keras.models.load_model``.
    """

    def _load_one(combo):
        # Models are loaded uncompiled; the module-level ``custom_objects``
        # mapping lets Keras resolve the project-specific names.
        return tf.keras.models.load_model(
            models_directory / f"{combo[0]}_{combo[1]}_{combo[2]}_model",
            compile=False,
            custom_objects=custom_objects,
        )

    return {combo: _load_one(combo) for combo in combos}

def predict_m_heights(samples, model_cache) -> list:
    """Infer the m-height of every sample, running each model on whole batches.

    Calling ``model.predict`` once per sample pays Keras' per-call overhead for
    every row. Samples are instead grouped by combo (and matrix shape, so the
    matrices can be stacked) and each group is predicted in a single call.
    Results are written back in the original sample order.

    Args:
        samples: Iterable of indexable samples. The first three entries of a
            sample form the combo key used to pick a model; entry 3 is an
            array-like that is converted to a ``float32`` matrix.
        model_cache: Mapping from combo tuple to a model exposing
            ``predict(batch, verbose=0)``.

    Returns:
        List of Python floats, one per sample, in input order.

    Raises:
        ValueError: If a sample's combo has no entry in ``model_cache``. All
            combos are checked before any inference is run.
    """
    samples = list(samples)

    # (combo, matrix shape) -> (positions in `samples`, matrices to stack).
    groups = {}
    for position, sample in enumerate(samples):
        combo = tuple(sample[:3])
        if combo not in model_cache:
            raise ValueError(f"No trained model for combo {combo}.")
        matrix = np.asarray(sample[3], dtype=np.float32)
        positions, matrices = groups.setdefault((combo, matrix.shape), ([], []))
        positions.append(position)
        matrices.append(matrix)

    predictions = [None] * len(samples)
    for (combo, _shape), (positions, matrices) in groups.items():
        # Leading axis is the batch axis, as with the single-sample
        # `matrix[None, ...]` call this replaces.
        batch = np.stack(matrices, axis=0)
        log_pred_params = model_cache[combo].predict(batch, verbose=0)
        # Only the first unpacked component is used; the second is discarded.
        log_mean, _ = unpack_log_prediction(log_pred_params)
        # Assumes one prediction per stacked matrix along axis 0 (the original
        # relied on the same layout when taking element [0]) - TODO confirm.
        heights = np.asarray(from_log2_height(log_mean)).reshape(len(positions), -1)[:, 0]
        for position, height in zip(positions, heights):
            predictions[position] = float(height)
    return predictions

def _nanmean_or_nan(values) -> float:
    """Mean of the non-NaN entries of ``values``; NaN if there are none."""
    arr = np.asarray(values, dtype=float)
    # np.nanmean warns ("Mean of empty slice") on empty or all-NaN input, so
    # that case is answered directly.
    if arr.size == 0 or np.isnan(arr).all():
        return float('nan')
    return float(np.nanmean(arr))


def analyze_prediction_errors(predictions, targets, samples, combos=None) -> dict:
    """Compute absolute and symmetric ratio errors overall and per combo.

    The symmetric ratio error of a pair is ``max(pred / target, target / pred)``
    when both values are positive and NaN otherwise; NaN entries are ignored
    when averaging.

    Args:
        predictions: Predicted values, one per sample.
        targets: Ground-truth values, aligned with ``predictions``.
        samples: Samples aligned with ``predictions``; the first three entries
            of each sample form its combo key.
        combos: Combos to report in the per-combo breakdown. ``None`` (the
            default) means ``SUPPORTED_COMBOS``. Samples whose combo is not
            listed still count towards the overall figures.

    Returns:
        Dict with ``'overall_mean_abs'``, ``'overall_mean_ratio'`` and
        ``'per_combo'``; the latter maps each combo to ``'count'``,
        ``'mean_abs'`` and ``'mean_ratio'``. Means with no contributing values
        are NaN.

    Raises:
        ValueError: If ``predictions``, ``targets`` and ``samples`` differ in
            length.
    """
    if combos is None:
        combos = SUPPORTED_COMBOS
    # Materialise so a one-shot iterable can be traversed more than once.
    combos = list(combos)

    predictions = list(predictions)
    targets = list(targets)
    samples = list(samples)
    # zip() would silently drop the unmatched tail and report on fewer pairs.
    if not (len(predictions) == len(targets) == len(samples)):
        raise ValueError(
            "predictions, targets and samples must have the same length "
            f"(got {len(predictions)}, {len(targets)}, {len(samples)})."
        )

    abs_errors = []
    ratio_errors = []
    combo_abs = {combo: [] for combo in combos}
    combo_ratio = {combo: [] for combo in combos}
    for pred, target, sample in zip(predictions, targets, samples):
        combo = tuple(sample[:3])
        abs_err = float(abs(pred - target))
        if pred > 0 and target > 0:
            ratio_err = float(max(pred / target, target / pred))
        else:
            # The ratio is only defined for positive pairs.
            ratio_err = float('nan')
        abs_errors.append(abs_err)
        ratio_errors.append(ratio_err)
        if combo in combo_abs:
            combo_abs[combo].append(abs_err)
            combo_ratio[combo].append(ratio_err)

    per_combo = {
        combo: {
            'count': len(combo_abs[combo]),
            'mean_abs': float(np.mean(combo_abs[combo])) if combo_abs[combo] else float('nan'),
            'mean_ratio': _nanmean_or_nan(combo_ratio[combo]),
        }
        for combo in combos
    }
    return {
        'overall_mean_abs': float(np.mean(abs_errors)) if abs_errors else float('nan'),
        'overall_mean_ratio': _nanmean_or_nan(ratio_errors),
        'per_combo': per_combo,
    }

# Load the model for every supported combo once; the inference cell reuses this cache.
model_cache = load_model_cache(models_dir)


In [None]:
# Fixed seed so the same random subset is drawn on every run.
rng = np.random.default_rng(42)
# Evaluate at most inference_sample_num samples, capped at the dataset size.
num_samples = min(inference_sample_num, len(n_k_m_P))
# Distinct indices (no replacement); the same indices select features and
# targets so each prediction stays paired with its ground truth.
# NOTE(review): assumes m_heights is index-aligned with n_k_m_P - confirm load_data guarantees this.
sampled_indices = rng.choice(len(n_k_m_P), size=num_samples, replace=False)
sampled_features = [n_k_m_P[idx] for idx in sampled_indices]
sampled_targets = [m_heights[idx] for idx in sampled_indices]

predicted_m_heights = predict_m_heights(sampled_features, model_cache)

print(f"Predicted m-heights for {num_samples} samples:")
# Only the first ten predictions are printed to keep the output short.
print(f"{predicted_m_heights[0:10]} ...")

In [None]:
# Score the sampled predictions against their ground-truth m-heights.
error_metrics = analyze_prediction_errors(
    predicted_m_heights,
    sampled_targets,
    sampled_features,
)

print(f"Random sample size: {num_samples}")
print(f"Overall mean absolute error: {error_metrics['overall_mean_abs']:.6f}")
print(f"Overall symmetric ratio error: {error_metrics['overall_mean_ratio']:.6f}")

# Per-combo breakdown; combos that received no samples are not printed.
for combo, stats in error_metrics['per_combo'].items():
    count = stats['count']
    if count != 0:
        print(
            f"{combo}: mean |pred - gt| = {stats['mean_abs']:.6f}, "
            f"mean ratio = {stats['mean_ratio']:.6f} (samples={count})"
        )

### Legacy result

Output of the error-analysis cell above from an earlier run, kept here for comparison.

Random sample size: 1000

Overall mean absolute error: 9779.482215

Overall symmetric ratio error: 2.196944

(9, 4, 2): mean |pred - gt| = 45.958449, mean ratio = 1.402891 (samples=107)

(9, 4, 3): mean |pred - gt| = 88.956220, mean ratio = 1.405114 (samples=108)

(9, 4, 4): mean |pred - gt| = 678.325993, mean ratio = 1.894131 (samples=100)

(9, 4, 5): mean |pred - gt| = 20602.184997, mean ratio = 2.831138 (samples=118)

(9, 5, 2): mean |pred - gt| = 95.846664, mean ratio = 1.381756 (samples=111)

(9, 5, 3): mean |pred - gt| = 844.654270, mean ratio = 2.024824 (samples=114)

(9, 5, 4): mean |pred - gt| = 44776.815177, mean ratio = 3.338351 (samples=110)

(9, 6, 2): mean |pred - gt| = 532.782312, mean ratio = 2.630890 (samples=114)

(9, 6, 3): mean |pred - gt| = 18414.835091, mean ratio = 2.713982 (samples=118)