# Generalization Matrix (Cross-Channel Validation)

This notebook generates a **4×3 matrix** (4 trained models × 3 test channels) of average reconstruction PSNR at a fixed SNR, showing how robust each model is across different channel environments.

**Hypothesis to Test:**
> "While the specialist models win in their own domains, the Mixed model provides the best trade-off, losing only ~0.5 dB while working across all orbit types. The AWGN model fails catastrophically in real fading channels."

In [None]:
import os
import tensorflow as tf
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from models.model import SemViT
from models.metrics import psnr
from utils.datasets import dataset_generator

# Run tf.function-decorated code as compiled graphs (TF's default behaviour),
# NOT op-by-op eagerly. Flip to True only when debugging inside tf.functions.
tf.config.run_functions_eagerly(False)

In [None]:
# --- CONFIGURATION ---
# SNR (dB) at which every (model, channel) pair is evaluated.
FIXED_SNR_DB = 10
# Channel environments to test; each becomes one column of the result matrix.
TEST_CHANNELS = ['LEO', 'GEO', 'AWGN']
# Passed to SemViT as `num_symbols`.
DATA_SIZE = 512

# Seed Python, NumPy and TF RNGs so that stochastic ops (presumably including
# the simulated channel noise inside SemViT — confirm) give a reproducible
# PSNR table under Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Display name -> weights file; each entry becomes one row of the result matrix.
MODELS_TO_TEST = {
    "Mixed (LEO+GEO)": "weights/experiment_mixed_sat_138.weights.h5",
    "LEO Only":        "weights/experiment_leo_sat_128.weights.h5",
    "GEO Only":        "weights/experiment_geo_sat_34.weights.h5",
    "AWGN (Baseline)": "weights/CCVVCC_512_10dB_599.weights.h5"
}

# Model Architecture Config — must match the architecture the weights above
# were trained with, otherwise load_weights cannot restore them.
BLOCK_TYPES = ['C', 'C', 'V', 'V', 'C', 'C']
FILTERS = [256, 256, 256, 256, 256, 256]
NUM_BLOCKS = [1, 1, 3, 3, 1, 1]
HAS_GDN = True

In [None]:
# --- DATASET ---
# Number of elements taken from the test pipeline (kept small for speed on CPU).
# NOTE(review): model.evaluate needs a leading batch axis, so dataset_generator
# very likely yields *batches*; if so this counts batches, not images — confirm
# against utils.datasets.
NUM_TEST_ELEMENTS = 200

print("Loading Dataset...")

def normalize_img(x, y):
    """Scale pixels by 1/255 and use the image as both input and target.

    Evaluation is image reconstruction (MSE loss / PSNR metric), so the
    source-dataset label ``y`` is deliberately discarded.
    Assumes 8-bit pixel values, so the result lies in [0, 1] — TODO confirm.
    """
    img = tf.cast(x, tf.float32) / 255.0
    return img, img

test_ds_raw = dataset_generator('./dataset/CIFAR10/test/', shuffle=False)
test_ds = (
    test_ds_raw
    .take(NUM_TEST_ELEMENTS)  # take first so discarded elements are never mapped
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
def evaluate_single_run(weights_path, channel_type, snr_db, dataset=None):
    """Evaluate one set of trained weights on one channel type at a fixed SNR.

    A fresh ``SemViT`` is built with the requested channel (the channel is a
    constructor argument, so every channel needs a new model). The weights are
    then loaded and the model is scored with ``model.evaluate``.

    Args:
        weights_path: Path to a ``.weights.h5`` file.
        channel_type: Channel name forwarded to ``SemViT(channel=...)``.
        snr_db: SNR in dB forwarded to ``SemViT(snrdB=...)``.
        dataset: Dataset to evaluate on. Defaults to the notebook-level
            ``test_ds``, so existing calls keep working.

    Returns:
        The PSNR metric from ``model.evaluate`` (``results[1]``), or ``None``
        when no score could be produced (weights file missing, or loading
        failed). Treat ``None`` as "no result", never as 0 dB.
    """
    # Check the cheap precondition before building/compiling a model.
    if not os.path.exists(weights_path):
        print(f"Weights not found, skipping: {weights_path}")
        return None

    if dataset is None:
        dataset = test_ds

    try:
        # Instantiate Model with specific Channel
        model = SemViT(
            block_types=BLOCK_TYPES,
            filters=FILTERS,
            num_blocks=NUM_BLOCKS,
            has_gdn=HAS_GDN,
            num_symbols=DATA_SIZE,
            snrdB=snr_db,
            channel=channel_type
        )
        model.compile(optimizer='adam', loss='mse', metrics=[psnr])
        # One forward pass builds the model's variables so load_weights has
        # something to restore into.
        model(tf.zeros((1, 32, 32, 3)))

        try:
            # skip_mismatch=True was added for Keras 3 compatibility.
            # NOTE(review): it silently leaves any layer whose saved weights do
            # not match at its fresh initialisation. That lowers PSNR without
            # raising, so confirm no layers are skipped for each
            # (weights, channel) pair before trusting a low score.
            model.load_weights(weights_path, skip_mismatch=True)
        except Exception as e:
            print(f"Error loading {weights_path}: {e}")
            return None

        # evaluate() returns [loss, *metrics]; index 1 is the psnr metric
        # registered in compile() above.
        results = model.evaluate(dataset, verbose=0)
        return results[1]  # PSNR
    finally:
        # Release Keras global state on every exit path (including the
        # load-error return) so repeated model builds don't accumulate.
        tf.keras.backend.clear_session()

In [None]:
# --- MAIN LOOP ---
# Rows = trained models, columns = test channels. A float dtype lets missing
# results be stored as NaN; seaborn leaves NaN cells blank instead of drawing
# them as a misleading 0 dB.
results_matrix = pd.DataFrame(
    index=list(MODELS_TO_TEST), columns=TEST_CHANNELS, dtype=float
)

print(f"Starting Generalization Test at SNR={FIXED_SNR_DB}dB...")

for model_name, weights_path in MODELS_TO_TEST.items():
    print(f"Evaluating Model: {model_name}")
    for channel in TEST_CHANNELS:
        print(f"  -> on Channel: {channel}")
        score = evaluate_single_run(weights_path, channel, FIXED_SNR_DB)
        if score is None:
            # No result (e.g. weights file missing). Record NaN rather than
            # 0.0 dB so a missing model cannot masquerade as a catastrophic
            # failure in the matrix.
            results_matrix.loc[model_name, channel] = float("nan")
            print("     PSNR: n/a (no result)")
            continue
        results_matrix.loc[model_name, channel] = float(score)
        print(f"     PSNR: {score:.2f} dB")

print("\nDone!")

In [None]:
# --- VISUALIZATION ---
# Cast once up front: the matrix may hold object dtype, and both the printout
# and the heatmap need numeric values (NaN = no result, drawn as a blank cell).
results_matrix = results_matrix.astype(float)

print(f"\n--- Generalization Matrix (Average PSNR @ {FIXED_SNR_DB}dB) ---")
print(results_matrix.round(2))

fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(
    results_matrix,
    annot=True,
    fmt=".1f",
    cmap="RdYlGn",
    linewidths=.5,
    cbar_kws={'label': 'PSNR (dB)'},
    ax=ax,
)
ax.set_title(f"Model Generalization Matrix (SNR={FIXED_SNR_DB}dB)")
ax.set_ylabel("Trained Model")
ax.set_xlabel("Test Channel Environment")
ax.tick_params(axis="y", labelrotation=0)

fig.savefig("generalization_matrix.png", dpi=300, bbox_inches='tight')
plt.show()