# TerminatorModel End-to-End Demo

This notebook demonstrates KMR's `TerminatorModel`, which performs hierarchical feature processing conditioned on a separate set of context features. It covers:

- Data generation with input and context features
- Model creation, training, and evaluation
- Plotly visualizations
- Model serialization and loading validation

## 1. Setup and Imports


In [1]:
import os
import tempfile

import numpy as np
import tensorflow as tf
import keras
from keras.optimizers import Adam
from keras.losses import MeanSquaredError
from keras.metrics import MeanAbsoluteError

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# KMR imports
from kmr.models.TerminatorModel import TerminatorModel

print("✅ All imports successful!")
print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")


✅ All imports successful!
TensorFlow version: 2.18.0
Keras version: 3.8.0


## 2. Generate Synthetic Data

We'll create a synthetic dataset with:
- Input features (primary features)
- Context features (additional contextual information)
- A continuous target that depends on both inputs
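Concretely, the next cell builds the target as follows. This is read directly from its code; $x_j$ and $c_k$ denote input and context features.

$$
y_{\text{raw}} = 2\sin(x_0) + 0.5\,x_1^2 - 1.2\,c_0 + 0.5\,\mathbf{w}_x^\top\mathbf{x} + 0.4\,\mathbf{w}_c^\top\mathbf{c} + 0.3\,x_2 c_1 + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0,\, 0.4^2)
$$

$$
y = \frac{y_{\text{raw}} - \min y_{\text{raw}}}{\max y_{\text{raw}} - \min y_{\text{raw}} + 10^{-8}} \in [0, 1]
$$

Here $\mathbf{w}_x$ holds 16 evenly spaced weights from 1.0 down to 0.3 and $\mathbf{w}_c$ holds 8 from 0.8 down to 0.2. Two correlations are also built in: $x_1$ is replaced by $0.3\,x_0 + 0.7\,x_1$, and $c_0 = 0.5\,x_0 + 0.3\,\eta$ with $\eta \sim \mathcal{N}(0, 1)$.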


In [2]:
print("📦 Generating synthetic data...")

# Reproducibility
np.random.seed(42)

num_samples = 2000
input_dim = 16  # Primary features
context_dim = 8  # Context features

# Generate input features
X_input = np.random.randn(num_samples, input_dim).astype(np.float32)
X_input[:, 1] = 0.3 * X_input[:, 0] + 0.7 * X_input[:, 1]  # introduce correlation

# Generate context features
X_context = np.random.randn(num_samples, context_dim).astype(np.float32)
X_context[:, 0] = 0.5 * X_input[:, 0] + 0.3 * np.random.randn(num_samples).astype(np.float32)

# Create target that depends on both input and context
input_weights = np.linspace(1.0, 0.3, input_dim)
context_weights = np.linspace(0.8, 0.2, context_dim)

y_raw = (
    2.0 * np.sin(X_input[:, 0])
    + 0.5 * X_input[:, 1] ** 2
    - 1.2 * X_context[:, 0]
    + (X_input @ input_weights) * 0.5
    + (X_context @ context_weights) * 0.4
    + 0.3 * (X_input[:, 2] * X_context[:, 1])  # interaction term
    + 0.4 * np.random.randn(num_samples)
)

# Scale to [0, 1] range for sigmoid output (TerminatorModel uses sigmoid)
y_min, y_max = y_raw.min(), y_raw.max()
y = ((y_raw - y_min) / (y_max - y_min + 1e-8)).astype(np.float32)

print(f"Target range: [{y.min():.4f}, {y.max():.4f}] (scaled for sigmoid output)")

# Train/Val/Test split
train_size = int(0.7 * num_samples)
val_size = int(0.15 * num_samples)

X_input_train = X_input[:train_size]
X_context_train = X_context[:train_size]
y_train = y[:train_size]

X_input_val = X_input[train_size:train_size + val_size]
X_context_val = X_context[train_size:train_size + val_size]
y_val = y[train_size:train_size + val_size]

X_input_test = X_input[train_size + val_size:]
X_context_test = X_context[train_size + val_size:]
y_test = y[train_size + val_size:]

print(f"✅ Data shapes:")
print(f"  Input features -> Train: {X_input_train.shape}, Val: {X_input_val.shape}, Test: {X_input_test.shape}")
print(f"  Context features -> Train: {X_context_train.shape}, Val: {X_context_val.shape}, Test: {X_context_test.shape}")
print(f"  Targets -> Train: {y_train.shape}, Val: {y_val.shape}, Test: {y_test.shape}")


📦 Generating synthetic data...
Target range: [0.0000, 1.0000] (scaled for sigmoid output)
✅ Data shapes:
  Input features -> Train: (1400, 16), Val: (300, 16), Test: (300, 16)
  Context features -> Train: (1400, 8), Val: (300, 8), Test: (300, 8)
  Targets -> Train: (1400,), Val: (300,), Test: (300,)
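The `[0, 1]` scaling above is an ordinary min-max transform, and the visualization cell later inverts it for plotting. A minimal sketch of the forward and inverse maps as a pair of helpers makes the round trip explicit, including the `1e-8` guard against a zero range. `minmax_scale` and `minmax_unscale` are hypothetical names for illustration, not KMR functions.

```python
import numpy as np

def minmax_scale(values, eps=1e-8):
    """Scale values into [0, 1]; also return the parameters needed to invert the map."""
    v_min = float(np.min(values))
    v_range = float(np.max(values)) - v_min + eps  # eps guards against a constant array
    return (values - v_min) / v_range, v_min, v_range

def minmax_unscale(scaled, v_min, v_range):
    """Invert minmax_scale with the same v_min and v_range."""
    return scaled * v_range + v_min

# Usage: round-trip an illustrative array (demo_* names avoid clobbering notebook variables)
rng = np.random.default_rng(0)
demo_raw = rng.normal(loc=5.0, scale=3.0, size=100)
demo_scaled, demo_min, demo_range = minmax_scale(demo_raw)
demo_back = minmax_unscale(demo_scaled, demo_min, demo_range)
print(demo_scaled.min(), demo_scaled.max(), np.max(np.abs(demo_back - demo_raw)))
```

Because of the epsilon, the largest scaled value lands just below 1.0 rather than exactly on it, which is harmless for a sigmoid output.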


## 3. Build TerminatorModel

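TerminatorModel stacks `num_blocks` SFNE blocks for hierarchical feature processing and uses the context features to condition that processing. Internally, the context is fed to a `SlowNetwork`, and a `HyperZZWOperator` combines it with the input features.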


In [3]:
# Create TerminatorModel
model = TerminatorModel(
    input_dim=input_dim,
    context_dim=context_dim,
    output_dim=1,
    hidden_dim=64,
    num_layers=2,
    num_blocks=3,
    slow_network_layers=3,
    slow_network_units=128,
    name='terminator_demo'
)

# Compile model
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=MeanSquaredError(),
    metrics=[MeanAbsoluteError()],
)

print("✅ TerminatorModel created and compiled!")
print(f"Input dimension: {model.input_dim}")
print(f"Context dimension: {model.context_dim}")
print(f"Output dimension: {model.output_dim}")
print(f"Number of SFNE blocks: {model.num_blocks}")


[32m2025-10-30 17:21:48.347[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized SlowNetwork with parameters: {'name': 'slow_network', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'input_dim': 8, 'num_layers': 3, 'units': 128}[0m
[32m2025-10-30 17:21:48.351[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized HyperZZWOperator with parameters: {'name': 'hyper_zzw_operator', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'input_dim': 16, 'context_dim': 8}[0m
[32m2025-10-30 17:21:48.354[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized SlowNetwork with parameters: {'name': 'slow_network_1', 't

✅ TerminatorModel created and compiled!
Input dimension: 16
Context dimension: 8
Output dimension: 1
Number of SFNE blocks: 3


## 4. Train and Evaluate

TerminatorModel expects inputs as a list: `[input_features, context_features]`.
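Because the model takes two positionally ordered arrays, a quick shape check before `fit` catches swapped or misaligned inputs early. The sketch below is a hypothetical helper (not part of KMR) that assumes 2-D NumPy arrays and checks only the feature dimensions and row counts.

```python
import numpy as np

def check_terminator_inputs(inputs, input_dim, context_dim):
    """Validate a [input_features, context_features] list (hypothetical helper, not a KMR API)."""
    if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
        raise ValueError("expected a list [input_features, context_features]")
    x_in, x_ctx = (np.asarray(a) for a in inputs)
    if x_in.ndim != 2 or x_in.shape[1] != input_dim:
        raise ValueError(f"input features: expected (n, {input_dim}), got {x_in.shape}")
    if x_ctx.ndim != 2 or x_ctx.shape[1] != context_dim:
        raise ValueError(f"context features: expected (n, {context_dim}), got {x_ctx.shape}")
    if x_in.shape[0] != x_ctx.shape[0]:
        raise ValueError(f"row count mismatch: {x_in.shape[0]} vs {x_ctx.shape[0]}")
    return x_in.shape[0]

# Usage with arrays shaped like the training split above
n_rows = check_terminator_inputs(
    [np.zeros((1400, 16)), np.zeros((1400, 8))], input_dim=16, context_dim=8
)
print(n_rows)  # 1400
```

Passing the context array first, for example, raises a `ValueError` instead of failing later inside the model.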


In [4]:
print("🚀 Starting training...")

callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3)
]

history = model.fit(
    [X_input_train, X_context_train], y_train,
    validation_data=([X_input_val, X_context_val], y_val),
    epochs=30,
    batch_size=32,
    callbacks=callbacks,
    verbose=1
)

print("✅ Training completed!")

# Evaluate
val_loss, val_mae = model.evaluate([X_input_val, X_context_val], y_val, verbose=0)
print(f"Validation - Loss: {val_loss:.4f}, MAE: {val_mae:.4f}")

test_loss, test_mae = model.evaluate([X_input_test, X_context_test], y_test, verbose=0)
print(f"Test - Loss: {test_loss:.4f}, MAE: {test_mae:.4f}")

# Predictions for later visualizations
y_pred_val = model.predict([X_input_val, X_context_val], verbose=0).squeeze()
y_pred_test = model.predict([X_input_test, X_context_test], verbose=0).squeeze()


🚀 Starting training...
Epoch 1/30


[32m2025-10-30 17:21:48.479[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (None, 16), context shape: (None, 8)[0m
[32m2025-10-30 17:21:48.491[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (None, 16)[0m
[32m2025-10-30 17:21:48.494[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mbuild[0m:[36m106[0m - [34m[1mSlowNetwork built with input_dim=8, num_layers=3, units=128[0m
[32m2025-10-30 17:21:48.495[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (None, 8)[0m
[32m2025-10-30 17:21:48.504[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (None, 128)[0m
[32m2025-10-30 17:21:48.513[0m | [34m[1mDEBUG   [0m | [36mkm

[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.0456 - mean_absolute_error: 0.1735

[32m2025-10-30 17:21:51.767[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (None, 16), context shape: (None, 8)[0m
[32m2025-10-30 17:21:51.770[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (None, 16)[0m
[32m2025-10-30 17:21:51.770[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (None, 8)[0m
[32m2025-10-30 17:21:51.773[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (None, 128)[0m
[32m2025-10-30 17:21:51.775[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 1 output shape: (None, 128)[0m
[32m2025-10-30 17:21:51.777[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNe

[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step - loss: 0.0453 - mean_absolute_error: 0.1729 - val_loss: 0.0208 - val_mean_absolute_error: 0.1165 - learning_rate: 0.0010
Epoch 2/30
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0145 - mean_absolute_error: 0.0965 - val_loss: 0.0141 - val_mean_absolute_error: 0.0944 - learning_rate: 0.0010
Epoch 3/30
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0110 - mean_absolute_error: 0.0836 - val_loss: 0.0112 - val_mean_absolute_error: 0.0834 - learning_rate: 0.0010
Epoch 4/30
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0082 - mean_absolute_error: 0.0723 - val_loss: 0.0092 - val_mean_absolute_error: 0.0752 - learning_rate: 0.0010
Epoch 5/30
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0064 - mean_absolute_error: 0.0638 - val_loss: 0.0079 - val_mean_absolute_error: 0.0697 - le

[32m2025-10-30 17:21:57.127[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (32, 16), context shape: (32, 8)[0m
[32m2025-10-30 17:21:57.129[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (32, 16)[0m
[32m2025-10-30 17:21:57.130[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (32, 8)[0m
[32m2025-10-30 17:21:57.132[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (32, 128)[0m
[32m2025-10-30 17:21:57.134[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 1 output shape: (32, 128)[0m
[32m2025-10-30 17:21:57.135[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[

Test - Loss: 0.0044, MAE: 0.0515


[32m2025-10-30 17:21:57.157[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m202[0m - [34m[1mSFNEBlock global_output shape: (32, 64, 16)[0m
[32m2025-10-30 17:21:57.160[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m206[0m - [34m[1mSFNEBlock local_output shape: (32, 64, 16)[0m
[32m2025-10-30 17:21:57.160[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m210[0m - [34m[1mSFNEBlock combined_output shape: (32, 64, 32)[0m
[32m2025-10-30 17:21:57.161[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m217[0m - [34m[1mSFNEBlock combined_output_flat shape: (32, 2048)[0m
[32m2025-10-30 17:21:57.162[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m223[0m - [34m[1mSFNEBlock bottleneck_output shape: (32, 16)[0m
[32m2025-10-30 17:21:57.164[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m227[0m - 

## 5. Visualizations

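We plot the training/validation loss curves, predictions vs. ground truth, the residual distribution, and the MAE per target bin. Apart from the loss curves, all panels use the original (unscaled) target units.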
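Since the target was min-max scaled with range $r = y_{\max} - y_{\min}$ (plus the `1e-8` guard), every raw-unit residual is exactly $r$ times the scaled one, so $\text{MAE}_{\text{raw}} = r \cdot \text{MAE}_{\text{scaled}}$ and $\text{MSE}_{\text{raw}} = r^2 \cdot \text{MSE}_{\text{scaled}}$. The sketch below checks this numerically; the range, offset, and arrays are made-up illustrative values, not the notebook's data.

```python
import numpy as np

rng = np.random.default_rng(1)
r, offset = 12.5, -4.0                                   # illustrative range and minimum
y_true_s = rng.uniform(0.0, 1.0, 300)                    # scaled targets
y_pred_s = y_true_s + rng.normal(0.0, 0.05, 300)         # scaled predictions with some error

# Unscale both with the same affine map, as the plotting cell does
y_true_r = y_true_s * r + offset
y_pred_r = y_pred_s * r + offset

mae_s = np.mean(np.abs(y_true_s - y_pred_s))
mse_s = np.mean((y_true_s - y_pred_s) ** 2)
mae_r = np.mean(np.abs(y_true_r - y_pred_r))
mse_r = np.mean((y_true_r - y_pred_r) ** 2)

print(np.isclose(mae_r, r * mae_s), np.isclose(mse_r, r**2 * mse_s))
```

Applied to this notebook, multiplying the reported test MAE by `y_max - y_min` gives the error in the original target units.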


In [5]:
print("📊 Creating visualizations...")

# Loss curves
hist_loss = history.history.get("loss", [])
hist_val_loss = history.history.get("val_loss", [])

# Unscale targets and predictions back to the original range for plotting,
# reusing y_min / y_max (and the 1e-8 guard) from the data-generation cell
y_range = y_max - y_min + 1e-8
y_test_raw = y_test * y_range + y_min
y_pred_test_raw = y_pred_test * y_range + y_min

# Residuals (test, unscaled)
residuals = (y_test_raw - y_pred_test_raw)

fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=(
        "Training/Validation Loss",
        "Predictions vs Ground Truth (Test)",
        "Residuals Histogram (Test)",
        "MAE over Bins (Test)"
    ),
    specs=[[{"type": "scatter"}, {"type": "scatter"}],
           [{"type": "histogram"}, {"type": "bar"}]]
)

# Plot 1: Loss curves
fig.add_scatter(y=hist_loss, mode="lines", name="loss", row=1, col=1)
fig.add_scatter(y=hist_val_loss, mode="lines", name="val_loss", row=1, col=1)

# Plot 2: Predictions vs truth (unscaled)
idx = np.random.choice(len(y_test_raw), size=min(400, len(y_test_raw)), replace=False)
fig.add_scatter(
    x=y_test_raw[idx], y=y_pred_test_raw[idx], mode="markers", name="pred vs true",
    row=1, col=2
)
# Reference line (unscaled)
min_v = float(min(y_test_raw.min(), y_pred_test_raw.min()))
max_v = float(max(y_test_raw.max(), y_pred_test_raw.max()))
fig.add_scatter(x=[min_v, max_v], y=[min_v, max_v], mode="lines", name="ideal", row=1, col=2)

# Plot 3: Residuals histogram
fig.add_histogram(x=residuals, nbinsx=40, name="residuals", row=2, col=1)

# Plot 4: MAE per bin of the true target (empty bins are skipped, not plotted as 0)
bins = np.linspace(min_v, max_v, 20)
# Clip so values equal to max_v fall into the last bin instead of being dropped
bin_indices = np.clip(np.digitize(y_test_raw, bins), 1, len(bins) - 1)
mae_per_bin = []
centers = []
for b in range(1, len(bins)):
    mask = bin_indices == b
    if np.any(mask):
        mae_per_bin.append(np.mean(np.abs(y_test_raw[mask] - y_pred_test_raw[mask])))
        centers.append((bins[b] + bins[b - 1]) / 2)

fig.add_bar(x=centers, y=mae_per_bin, name="MAE per bin", row=2, col=2)

fig.update_layout(height=800, title_text="TerminatorModel Regression Results", showlegend=True)
fig.update_xaxes(title_text="Epoch", row=1, col=1)
fig.update_yaxes(title_text="Loss", row=1, col=1)
fig.update_xaxes(title_text="True", row=1, col=2)
fig.update_yaxes(title_text="Predicted", row=1, col=2)
fig.update_xaxes(title_text="Residual", row=2, col=1)
fig.update_yaxes(title_text="Frequency", row=2, col=1)
fig.update_xaxes(title_text="Target (bin center)", row=2, col=2)
fig.update_yaxes(title_text="MAE", row=2, col=2)

fig.show()
print("✅ Visualizations created successfully!")


📊 Creating visualizations...


✅ Visualizations created successfully!


## 6. Model Serialization and Loading

We save the model in the native Keras format (`.keras`), reload it with `keras.models.load_model`, and check that the reloaded model reproduces the original predictions.
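The check in the next cell reports a mean absolute difference. A stricter variant, sketched below as a hypothetical helper, also reports the worst-case difference and applies an explicit tolerance via `np.allclose`; the default `atol` is an assumption chosen for float32 outputs.

```python
import numpy as np

def compare_predictions(preds_a, preds_b, atol=1e-6):
    """Return (all_close, mean_abs_diff, max_abs_diff) for two prediction arrays."""
    a = np.asarray(preds_a, dtype=np.float64)
    b = np.asarray(preds_b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    return bool(np.allclose(a, b, rtol=0.0, atol=atol)), float(diff.mean()), float(diff.max())

# Usage: identical arrays match; a single perturbed entry is caught by the max difference
demo_p = np.linspace(0.0, 1.0, 64, dtype=np.float32).reshape(-1, 1)
demo_q = demo_p.copy()
demo_q[10, 0] += 1e-3
print(compare_predictions(demo_p, demo_p.copy()))  # (True, 0.0, 0.0)
print(compare_predictions(demo_p, demo_q))
```

In the next cell it could be called as `compare_predictions(preds_orig, preds_loaded)`.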


In [6]:
print("💾 Testing Keras format serialization...")

with tempfile.TemporaryDirectory() as temp_dir:
    keras_path = os.path.join(temp_dir, "terminator_demo.keras")

    # Save
    model.save(keras_path)
    print(f"✅ Model saved to: {keras_path}")

    # Load
    loaded_model = keras.models.load_model(keras_path)
    print("✅ Model loaded successfully!")

    # Compare predictions on a small slice
    sl = slice(0, 64)
    preds_orig = model.predict([X_input_test[sl], X_context_test[sl]], verbose=0)
    preds_loaded = loaded_model.predict([X_input_test[sl], X_context_test[sl]], verbose=0)

    # Report similarity
    diff = np.mean(np.abs(preds_orig - preds_loaded))
    print(f"🔍 Mean absolute difference between original and loaded predictions: {float(diff):.6f}")
    print("✅ Loaded model prediction check completed!")


[32m2025-10-30 17:21:57.587[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized SlowNetwork with parameters: {'name': 'slow_network_4', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'input_dim': 8, 'num_layers': 3, 'units': 128}[0m
[32m2025-10-30 17:21:57.587[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized HyperZZWOperator with parameters: {'name': 'hyper_zzw_operator_4', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None, 'shared_object_id': 13822903840}, 'input_dim': 16, 'context_dim': 8}[0m
[32m2025-10-30 17:21:57.589[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized SlowNetwork with para

💾 Testing Keras format serialization...
✅ Model saved to: /var/folders/v8/4l9cyywn1x970gdc1v67r5480000gn/T/tmpi8mapngd/terminator_demo.keras


[32m2025-10-30 17:21:57.689[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m223[0m - [34m[1mSFNEBlock bottleneck_output shape: (None, 16)[0m
[32m2025-10-30 17:21:57.691[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m227[0m - [34m[1mSFNEBlock output_layer output shape: (None, 16)[0m
[32m2025-10-30 17:21:57.691[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m227[0m - [34m[1mTerminatorModel SFNEBlock 0 output shape: (None, 16)[0m
[32m2025-10-30 17:21:57.691[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m176[0m - [34m[1mSFNEBlock input shape: (None, 16)[0m
[32m2025-10-30 17:21:57.699[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mbuild[0m:[36m106[0m - [34m[1mSlowNetwork built with input_dim=16, num_layers=3, units=128[0m
[32m2025-10-30 17:21:57.699[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[3

✅ Model loaded successfully!
🔍 Mean absolute difference between original and loaded predictions: 0.000000
✅ Loaded model prediction check completed!


## 7. Summary

- Trained `TerminatorModel` on synthetic regression data with input and context features.
- Reported MSE loss and MAE on the [0, 1]-scaled target for the validation and test sets (test MSE 0.0044, test MAE 0.0515 in the run shown above).
- Visualized loss curves, predictions vs. truth, and residuals.
- Saved and loaded model in Keras format; verified prediction consistency.

This notebook mirrors the style of the other KMR demos and showcases TerminatorModel's context-aware hierarchical feature processing.


## 8. Alternative: Binary Classification Training (a Natural Fit for the Sigmoid Output)

Because `TerminatorModel` ends in a sigmoid output, it is naturally suited to binary classification. Below we derive a balanced binary target from the same features and train a fresh model with binary cross-entropy, tracking accuracy, precision, recall, and AUC.
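The next cell maps `logits` through a sigmoid to `probs` and then splits at the median. Because the sigmoid is strictly increasing, thresholding `probs` at their median labels the same samples as thresholding `logits` at their median, so the median split alone is what balances the classes. A quick numerical check on illustrative random logits (not the notebook's values):

```python
import numpy as np

rng = np.random.default_rng(42)
logits_demo = rng.normal(0.0, 2.0, size=2000)     # illustrative logits
probs_demo = 1.0 / (1.0 + np.exp(-logits_demo))   # sigmoid, as in the cell below

labels_from_probs = probs_demo > np.median(probs_demo)
labels_from_logits = logits_demo > np.median(logits_demo)

print(np.array_equal(labels_from_probs, labels_from_logits))
print(np.bincount(labels_from_probs.astype(int)))
```

With an even number of distinct values, exactly half the samples lie above the median, which matches the `[1000 1000]` class distribution printed below.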


In [7]:
import keras
from keras import layers
from kmr.utils import KMRPlotter

print("🔧 Preparing binary classification targets and model...")

# Build probabilities from input + context for a non-trivial decision boundary
w_inp = np.linspace(0.7, 0.2, input_dim)
w_ctx = np.linspace(0.5, 0.1, context_dim)
logits = (
    1.5 * np.sin(X_input[:, 0])
    + 0.8 * (X_input[:, 1] ** 2)
    - 1.1 * X_context[:, 0]
    + 0.5 * (X_input @ w_inp)
    + 0.4 * (X_context @ w_ctx)
    + 0.35 * (X_input[:, 2] * X_context[:, 1])
    + 0.2 * (X_input[:, 3] * X_context[:, 2])
)
probs = 1 / (1 + np.exp(-logits))

# Ensure balanced classes
y_cls = (probs > np.median(probs)).astype(np.float32)
print(f"Class distribution: {np.bincount(y_cls.astype(int))}")

y_train_cls = y_cls[:train_size]
y_val_cls = y_cls[train_size:train_size + val_size]
y_test_cls = y_cls[train_size + val_size:]

# Fresh model for classification
model_cls = TerminatorModel(
    input_dim=input_dim,
    context_dim=context_dim,
    output_dim=1,
    hidden_dim=64,
    num_layers=2,
    num_blocks=3,
    slow_network_layers=3,
    slow_network_units=128,
    name='terminator_demo_cls'
)

model_cls.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss='binary_crossentropy',
    metrics=[
        keras.metrics.BinaryAccuracy(name='accuracy'),
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall'),
        keras.metrics.AUC(name='auc')
    ]
)

print("🚀 Training classification model...")
cls_callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
]

history_cls = model_cls.fit(
    [X_input_train, X_context_train], y_train_cls,
    validation_data=([X_input_val, X_context_val], y_val_cls),
    epochs=30,
    batch_size=64,
    callbacks=cls_callbacks,
    verbose=1
)

print("✅ Classification training completed!")

# Evaluate and predict
loss_t, acc_t, prec_t, rec_t, auc_t = model_cls.evaluate([X_input_test, X_context_test], y_test_cls, verbose=0)
print(f"Test -> loss {loss_t:.4f}, acc {acc_t:.4f}, prec {prec_t:.4f}, recall {rec_t:.4f}, auc {auc_t:.4f}")

y_pred_proba_cls = model_cls.predict([X_input_test, X_context_test], verbose=0).flatten()
y_pred_cls = (y_pred_proba_cls > 0.5).astype(np.int32)


🔧 Preparing binary classification targets and model...


[32m2025-10-30 17:21:58.164[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized SlowNetwork with parameters: {'name': 'slow_network_8', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'input_dim': 8, 'num_layers': 3, 'units': 128}[0m
[32m2025-10-30 17:21:58.165[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized HyperZZWOperator with parameters: {'name': 'hyper_zzw_operator_8', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'input_dim': 16, 'context_dim': 8}[0m
[32m2025-10-30 17:21:58.169[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized SlowNetwork with parameters: {'name': 'slow_network_9'

Class distribution: [1000 1000]
🚀 Training classification model...
Epoch 1/30


[32m2025-10-30 17:21:58.246[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (None, 16), context shape: (None, 8)[0m
[32m2025-10-30 17:21:58.256[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (None, 16)[0m
[32m2025-10-30 17:21:58.258[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mbuild[0m:[36m106[0m - [34m[1mSlowNetwork built with input_dim=8, num_layers=3, units=128[0m
[32m2025-10-30 17:21:58.259[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (None, 8)[0m
[32m2025-10-30 17:21:58.269[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (None, 128)[0m
[32m2025-10-30 17:21:58.277[0m | [34m[1mDEBUG   [0m | [36mkm

[1m12/22[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m0s[0m 5ms/step - accuracy: 0.4961 - auc: 0.5060 - loss: 0.7606 - precision: 0.4967 - recall: 0.7665 

[32m2025-10-30 17:22:01.717[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (None, 16), context shape: (None, 8)[0m
[32m2025-10-30 17:22:01.719[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (None, 16)[0m
[32m2025-10-30 17:22:01.719[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (None, 8)[0m
[32m2025-10-30 17:22:01.721[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (None, 128)[0m
[32m2025-10-30 17:22:01.723[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 1 output shape: (None, 128)[0m
[32m2025-10-30 17:22:01.725[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNe

[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 25ms/step - accuracy: 0.5046 - auc: 0.5093 - loss: 0.7555 - precision: 0.5044 - recall: 0.7735 - val_accuracy: 0.5633 - val_auc: 0.5814 - val_loss: 0.6914 - val_precision: 0.5513 - val_recall: 0.5850 - learning_rate: 0.0010
Epoch 2/30
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.5954 - auc: 0.6647 - loss: 0.6519 - precision: 0.6383 - recall: 0.5157 - val_accuracy: 0.6533 - val_auc: 0.7098 - val_loss: 0.6220 - val_precision: 0.6387 - val_recall: 0.6735 - learning_rate: 0.0010
Epoch 3/30
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.7039 - auc: 0.7657 - loss: 0.5859 - precision: 0.6907 - recall: 0.7415 - val_accuracy: 0.7000 - val_auc: 0.7648 - val_loss: 0.5768 - val_precision: 0.6667 - val_recall: 0.7755 - learning_rate: 0.0010
Epoch 4/30
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.7581 - auc: 0.8293

[32m2025-10-30 17:22:03.932[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (32, 16), context shape: (32, 8)[0m
[32m2025-10-30 17:22:03.935[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (32, 16)[0m
[32m2025-10-30 17:22:03.935[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (32, 8)[0m
[32m2025-10-30 17:22:03.937[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (32, 128)[0m
[32m2025-10-30 17:22:03.939[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 1 output shape: (32, 128)[0m
[32m2025-10-30 17:22:03.940[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[

Test -> loss 0.4314, acc 0.8200, prec 0.8201, recall 0.7972, auc 0.8955


[32m2025-10-30 17:22:03.950[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (32, 64)[0m
[32m2025-10-30 17:22:03.952[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (32, 128)[0m
[32m2025-10-30 17:22:03.953[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 1 output shape: (32, 128)[0m
[32m2025-10-30 17:22:03.955[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 2 output shape: (32, 128)[0m
[32m2025-10-30 17:22:03.956[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m129[0m - [34m[1mSlowNetwork output shape: (32, 16)[0m
[32m2025-10-30 17:22:03.956[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m188[0m - [34m[1mSFNEBlo

In [8]:
# Classification visualizations
print("📊 Creating classification visualizations...")

fig_cls = KMRPlotter.create_comprehensive_plot(
    'classification',
    y_true=y_test_cls,
    y_pred=y_pred_cls,
    y_scores=y_pred_proba_cls,
    title="TerminatorModel Classification Results"
)
fig_cls.show()
print("✅ Classification visualizations created!")


📊 Creating classification visualizations...


✅ Classification visualizations created!


## 9. Scenario 2: With Preprocessing Model

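We now add a preprocessing model that consumes the input and context features, projects each into a hidden space, concatenates them, and outputs a processed representation of size `input_dim` that `TerminatorModel` then uses. The preprocessing model declares dictionary inputs named `input_0` and `input_1`, matching the format `TerminatorModel` converts list inputs into, so we still pass `[input_features, context_features]` as a list to `fit`, `evaluate`, and `predict`.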
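The code comment in the next cell states that `TerminatorModel` turns a list `[input, context]` into `{'input_0': input, 'input_1': context}`. The pure-Python sketch below illustrates that positional naming convention. It is a hypothetical helper, not KMR's actual implementation, and it shows why the preprocessing model's `Input` layers are named `input_0` and `input_1`.

```python
def list_to_indexed_dict(inputs, prefix="input_"):
    """Map positional inputs to {'input_0': ..., 'input_1': ...} (hypothetical helper)."""
    if isinstance(inputs, dict):
        return dict(inputs)  # already keyed by name; pass through unchanged
    return {f"{prefix}{i}": tensor for i, tensor in enumerate(inputs)}

# Usage: the keys must match the `name=` of the preprocessing model's Input layers
mapped = list_to_indexed_dict(["features_batch", "context_batch"])
print(sorted(mapped))  # ['input_0', 'input_1']
```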


In [9]:
# Build a preprocessing model that works with TerminatorModel's standardized input format
# TerminatorModel converts list inputs to {'input_0': input, 'input_1': context}
from keras import Input, Model

print("🧩 Building preprocessing model...")

# Input layers - using the standardized format that TerminatorModel provides
inp_0 = Input(shape=(input_dim,), name='input_0')  # This will be the input features
inp_1 = Input(shape=(context_dim,), name='input_1')  # This will be the context features

# Process each input
h1 = layers.Dense(64, activation='relu', name='preproc_dense1')(inp_0)
h1 = layers.Dropout(0.1)(h1)
h2 = layers.Dense(32, activation='relu', name='preproc_dense2')(inp_1)
h2 = layers.Dropout(0.1)(h2)

# Combine and process
combined = layers.Concatenate(name='preproc_concat')([h1, h2])
combined = layers.Dense(64, activation='relu', name='preproc_dense3')(combined)
combined = layers.Dropout(0.1)(combined)
processed = layers.Dense(input_dim, activation='linear', name='processed')(combined)

# Create preprocessing model with the standardized input format
preproc_model = Model(
    inputs={'input_0': inp_0, 'input_1': inp_1}, 
    outputs=processed, 
    name='terminator_preproc'
)

# Test the preprocessing model with the standardized format
print("Testing preprocessing model...")
test_inp = np.random.randn(1, input_dim).astype(np.float32)
test_ctx = np.random.randn(1, context_dim).astype(np.float32)
test_out = preproc_model({'input_0': test_inp, 'input_1': test_ctx})
print(f"Preprocessing model output shape: {test_out.shape}, expected: (1, {input_dim})")
assert test_out.shape[1] == input_dim, f"Preprocessing output shape mismatch: got {test_out.shape}"

# Model with preprocessing
model_cls_preproc = TerminatorModel(
    input_dim=input_dim,
    context_dim=context_dim,
    output_dim=1,
    hidden_dim=64,
    num_layers=2,
    num_blocks=3,
    slow_network_layers=3,
    slow_network_units=128,
    preprocessing_model=preproc_model,
    name='terminator_demo_cls_preproc'
)

model_cls_preproc.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss='binary_crossentropy',
    metrics=[
        keras.metrics.BinaryAccuracy(name='accuracy'),
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall'),
        keras.metrics.AUC(name='auc')
    ]
)

print("🚀 Training classification model with preprocessing...")

# Test the model with preprocessing first
print("Testing model with preprocessing...")
# Use list format since that's what TerminatorModel expects
test_output = model_cls_preproc([X_input_train[:1], X_context_train[:1]])
print(f"Model with preprocessing test output shape: {test_output.shape}")

# Train with list format (which works with the preprocessing model)
print("Training with list inputs...")
history_cls_preproc = model_cls_preproc.fit(
    [X_input_train, X_context_train], y_train_cls,
    validation_data=([X_input_val, X_context_val], y_val_cls),
    epochs=30,
    batch_size=64,
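    # Fresh callback instances for this run instead of reusing cls_callbacks
    callbacks=[
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    ],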
    verbose=1
)

print("✅ Training with preprocessing completed!")

# Evaluate and visualize (use list format consistently)
loss_p, acc_p, prec_p, rec_p, auc_p = model_cls_preproc.evaluate(
    [X_input_test, X_context_test], y_test_cls, verbose=0
)
y_pred_proba_cls_p = model_cls_preproc.predict(
    [X_input_test, X_context_test], verbose=0
).flatten()

print(f"Test (preproc) -> loss {loss_p:.4f}, acc {acc_p:.4f}, prec {prec_p:.4f}, recall {rec_p:.4f}, auc {auc_p:.4f}")

y_pred_cls_p = (y_pred_proba_cls_p > 0.5).astype(np.int32)

fig_cls_p = KMRPlotter.create_comprehensive_plot(
    'classification',
    y_true=y_test_cls,
    y_pred=y_pred_cls_p,
    y_scores=y_pred_proba_cls_p,
    title="TerminatorModel Classification Results (With Preprocessing)"
)
fig_cls_p.show()
print("✅ Classification visualizations with preprocessing created!")


[32m2025-10-30 17:22:04.381[0m | [34m[1mDEBUG   [0m | [36mkmr.models._base[0m:[36m_setup_preprocessing_model[0m:[36m294[0m - [34m[1mSetting up preprocessing model integration[0m
[32m2025-10-30 17:22:04.382[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized SlowNetwork with parameters: {'name': 'slow_network_12', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'input_dim': 8, 'num_layers': 3, 'units': 128}[0m
[32m2025-10-30 17:22:04.382[0m | [34m[1mDEBUG   [0m | [36mkmr.layers._base_layer[0m:[36m_log_initialization[0m:[36m73[0m - [34m[1mInitialized HyperZZWOperator with parameters: {'name': 'hyper_zzw_operator_12', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'input_dim': 16, 'context_dim': 8}[0m
[32m2025-10-30

🧩 Building preprocessing model...
Testing preprocessing model...
Preprocessing model output shape: (1, 16), expected: (1, 16)
🚀 Training classification model with preprocessing...
Testing model with preprocessing...


[32m2025-10-30 17:22:04.450[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m129[0m - [34m[1mSlowNetwork output shape: (1, 16)[0m
[32m2025-10-30 17:22:04.450[0m | [34m[1mDEBUG   [0m | [36mkmr.models.SFNEBlock[0m:[36mcall[0m:[36m188[0m - [34m[1mSFNEBlock hyper_kernels shape: (1, 16)[0m
[32m2025-10-30 17:22:04.451[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.HyperZZWOperator[0m:[36mbuild[0m:[36m112[0m - [34m[1mHyperZZWOperator built with input_dim=64, context_dim=16[0m
[32m2025-10-30 17:22:04.451[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.HyperZZWOperator[0m:[36mcall[0m:[36m132[0m - [34m[1mHyperZZWOperator input_tensor shape: (1, 64)[0m
[32m2025-10-30 17:22:04.451[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.HyperZZWOperator[0m:[36mcall[0m:[36m133[0m - [34m[1mHyperZZWOperator context_tensor shape: (1, 16)[0m
[32m2025-10-30 17:22:04.454[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.HyperZZWOperator[0m:[3

Model with preprocessing test output shape: (1, 1)
Training with list inputs...
Epoch 1/30



The structure of `inputs` doesn't match the expected structure.
Expected: {'input_0': 'input_0', 'input_1': 'input_1'}
Received: inputs=['Tensor(shape=(None, 16))', 'Tensor(shape=(None, 8))']

[32m2025-10-30 17:22:04.808[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (None, 16), context shape: (None, 8)[0m
[32m2025-10-30 17:22:04.810[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (None, 16)[0m
[32m2025-10-30 17:22:04.810[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (None, 8)[0m
[32m2025-10-30 17:22:04.812[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (None, 128)[0m
[32m2025-10-30 17:22:04.814[0m | [34m[1mDEBUG   [0m | [36mkm

[1m12/22[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m0s[0m 5ms/step - accuracy: 0.5801 - auc: 0.6367 - loss: 0.6649 - precision: 0.5552 - recall: 0.7326 

[32m2025-10-30 17:22:08.115[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (None, 16), context shape: (None, 8)[0m
[32m2025-10-30 17:22:08.117[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (None, 16)[0m
[32m2025-10-30 17:22:08.117[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (None, 8)[0m
[32m2025-10-30 17:22:08.119[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (None, 128)[0m
[32m2025-10-30 17:22:08.121[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 1 output shape: (None, 128)[0m
[32m2025-10-30 17:22:08.123[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNe

[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 29ms/step - accuracy: 0.5749 - auc: 0.6234 - loss: 0.6729 - precision: 0.5590 - recall: 0.7099 - val_accuracy: 0.6333 - val_auc: 0.7051 - val_loss: 0.6361 - val_precision: 0.6581 - val_recall: 0.5238 - learning_rate: 0.0010
Epoch 2/30
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.6591 - auc: 0.7623 - loss: 0.5904 - precision: 0.7132 - recall: 0.5409 - val_accuracy: 0.7767 - val_auc: 0.8590 - val_loss: 0.5141 - val_precision: 0.7778 - val_recall: 0.7619 - learning_rate: 0.0010
Epoch 3/30
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.7978 - auc: 0.8741 - loss: 0.4541 - precision: 0.8004 - recall: 0.8060 - val_accuracy: 0.8167 - val_auc: 0.8818 - val_loss: 0.4506 - val_precision: 0.8108 - val_recall: 0.8163 - learning_rate: 0.0010
Epoch 4/30
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.8187 - auc: 0.9052


The structure of `inputs` doesn't match the expected structure.
Expected: {'input_0': 'input_0', 'input_1': 'input_1'}
Received: inputs=['Tensor(shape=(32, 16))', 'Tensor(shape=(32, 8))']

[32m2025-10-30 17:22:10.271[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m205[0m - [34m[1mTerminatorModel input shape: (32, 16), context shape: (32, 8)[0m
[32m2025-10-30 17:22:10.273[0m | [34m[1mDEBUG   [0m | [36mkmr.models.TerminatorModel[0m:[36mcall[0m:[36m211[0m - [34m[1mTerminatorModel input_layer output shape: (32, 16)[0m
[32m2025-10-30 17:22:10.273[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m122[0m - [34m[1mSlowNetwork input shape: (32, 8)[0m
[32m2025-10-30 17:22:10.275[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowNetwork[0m:[36mcall[0m:[36m126[0m - [34m[1mSlowNetwork layer 0 output shape: (32, 128)[0m
[32m2025-10-30 17:22:10.277[0m | [34m[1mDEBUG   [0m | [36mkmr.layers.SlowN

Test (preproc) -> loss 0.4059, acc 0.8233, prec 0.8214, recall 0.8042, auc 0.9012


✅ Classification visualizations with preprocessing created!
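As a sanity check on the `evaluate()` numbers, the threshold metrics and AUC can be recomputed directly from the predicted probabilities. The sketch below uses plain NumPy; `binary_metrics` is a hypothetical helper, not part of KMR or Keras. It applies the same `> 0.5` cut-off used for `y_pred_cls_p` and computes ROC AUC with the rank (Mann–Whitney) formulation: the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counted as one half. Keras's `AUC` metric approximates the curve over a fixed set of thresholds (200 by default), so small differences from the reported `auc` are expected.

```python
import numpy as np


def binary_metrics(y_true, y_scores, threshold=0.5):
    """Hypothetical helper: accuracy/precision/recall at a threshold plus rank-based ROC AUC."""
    y_true = np.asarray(y_true).astype(np.int32).ravel()
    y_scores = np.asarray(y_scores, dtype=np.float64).ravel()

    # Hard predictions with the same strict "> threshold" rule as the notebook cell
    y_pred = (y_scores > threshold).astype(np.int32)

    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))

    accuracy = float(np.mean(y_pred == y_true))
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0

    # Rank-based AUC: compare every positive score with every negative score.
    # Wins count 1, ties count 0.5; the mean over all pairs is the ROC AUC.
    pos = y_scores[y_true == 1]
    neg = y_scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    auc = float(np.mean((diff > 0) + 0.5 * (diff == 0)))

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "auc": auc}


# Small worked example: one true positive, one false positive, one false negative
print(binary_metrics(np.array([0, 0, 1, 1]), np.array([0.1, 0.6, 0.4, 0.9])))
# → {'accuracy': 0.5, 'precision': 0.5, 'recall': 0.5, 'auc': 0.75}
```

In the notebook, calling `binary_metrics(y_test_cls, y_pred_proba_cls_p)` would give values comparable to the `Test (preproc)` line above. The pairwise comparison uses O(positives × negatives) memory, which is fine for a test split of this size; a sort-based rank computation scales better for large datasets.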
