# Model Robustness Comparison

This notebook evaluates user-specified models on the **LEO Satellite Channel** across a range of SNR values.
It assumes all model weights are available in the `weights/` directory, and that the AWGN baseline checkpoint has already been converted to the `.h5` weights format.

In [None]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from models.model import SemViT
from models.metrics import psnr
from utils.datasets import dataset_generator

# Do NOT force eager execution: with False, tf.function-wrapped code keeps running
# as compiled graphs (passing True would run it eagerly, which is mainly a debugging aid).
tf.config.run_functions_eagerly(False)

In [None]:
# --- CONFIGURATION ---
# SNR points (in dB) swept during evaluation; these become the plot's x-axis.
SNR_RANGE = [0, 5, 10, 15, 20]

# Models under comparison, as {legend label: path to the .h5 weights file}.
# Insertion order is the evaluation order.
MODELS_TO_TEST = {
    "Mixed (LEO+GEO)": "weights/experiment_mixed_sat_138.weights.h5",
    "LEO Only": "weights/experiment_leo_sat_128.weights.h5",
    "GEO Only": "weights/experiment_geo_sat_34.weights.h5",
    "AWGN (Baseline)": "weights/CCVVCC_512_10dB_599.weights.h5",
}

# Every model is evaluated on this one channel, regardless of what it was trained on.
# 'LEO' is chosen because it has fading + noise (hardest test).
TEST_CHANNEL = 'LEO'

# Handed to SemViT as `num_symbols`; 512 matches the "_512_" in the AWGN weights filename.
DATA_SIZE = 512

# Model architecture settings shared by all checkpoints.
BLOCK_TYPES = list("CCVVCC")
FILTERS = [256] * 6
NUM_BLOCKS = [1, 1, 3, 3, 1, 1]
HAS_GDN = True

In [None]:
def evaluate_model(model_weights_path, snr_db):
    """
    Evaluate one set of weights at one test SNR and return the PSNR metric.

    A fresh SemViT is built for every call (the SNR and channel are constructor
    arguments), the weights are loaded into it, and it is evaluated on the
    module-level `test_ds`.

    Args:
        model_weights_path: Path to the weights file to load.
        snr_db: Test SNR, passed to the model as `snrdB`.

    Returns:
        float: The entry at index 1 of `model.evaluate(...)`, i.e. the `psnr`
        metric that follows the loss. If the weights file is missing or cannot
        be loaded, a warning is printed and NaN is returned, so the point shows
        up as a gap in the plot instead of a fake 0 dB measurement.
    """
    # Check the path first: a missing file should not cost a full model build.
    if not os.path.exists(model_weights_path):
        print(f"Warning: Weight file not found: {model_weights_path}")
        return float("nan")

    try:
        # 1. Instantiate the model with the requested test SNR and the test channel.
        model = SemViT(
            block_types=BLOCK_TYPES,
            filters=FILTERS,
            num_blocks=NUM_BLOCKS,
            has_gdn=HAS_GDN,
            num_symbols=DATA_SIZE,
            snrdB=snr_db,
            channel=TEST_CHANNEL,
        )

        # 2. Compile so that `evaluate` reports [loss, psnr].
        model.compile(optimizer='adam', loss='mse', metrics=[psnr])

        # 3. Run a dummy batch so all variables exist before loading weights.
        model(tf.zeros((1, 32, 32, 3)))  # Assumes CIFAR input shape

        # 4. Load weights. The catch is deliberately broad: one bad checkpoint
        #    should be reported and skipped, not abort the whole sweep.
        try:
            model.load_weights(model_weights_path)
        except Exception as e:
            print(f"Error loading {model_weights_path}: {e}")
            return float("nan")

        # 5. Evaluate on the global test dataset; index 1 is the PSNR metric.
        results = model.evaluate(test_ds, verbose=0)
        return results[1]
    finally:
        # Free Keras state between iterations on every exit path, not only on success.
        tf.keras.backend.clear_session()

In [None]:
# --- PREPARE DATASET ---
print("Loading Dataset...")
# The test set is loaded exactly once, further down in this cell (with shuffle=False).
# An extra dataset_generator(...) call here would be overwritten before anything used it.
# NOTE(review): the value range returned by dataset_generator is not visible from this
# notebook. The normalization below assumes 0-255 pixel values; confirm against
# utils/datasets.py (train_dist.py applies its own `normalize` instead).
# Rescale by 1/255, assuming the generator yields 0-255 pixel values (TODO confirm).
def normalize_img(x, y):
    """Return `(image, image)` with the image cast to float32 and divided by 255.

    The incoming `y` is ignored; the scaled image serves as both input and target.
    """
    scaled = tf.cast(x, tf.float32) / 255.0
    return (scaled, scaled)

# Load the test split through the project's dataset_generator, unshuffled.
test_ds_raw = dataset_generator('./dataset/CIFAR10/test/', shuffle=False)

# Normalize, keep only the first 200 elements, then prefetch.
test_ds = (
    test_ds_raw
    .map(normalize_img)
    .take(200)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# --- MAIN LOOP ---
def _sweep_snr(label, weights_file):
    """Evaluate one model at every SNR in SNR_RANGE and return the PSNRs in that order."""
    print(f"--- Evaluating {label} ---")
    scores = []

    for snr_db in SNR_RANGE:
        print(f"  Testing at SNR = {snr_db} dB...")
        psnr_db = evaluate_model(weights_file, snr_db)
        scores.append(psnr_db)
        print(f"    -> PSNR: {psnr_db:.2f} dB")

    return scores


# {legend label: [PSNR at each SNR in SNR_RANGE]}, filled in one model at a time.
results_data = {}

for model_name, weights_path in MODELS_TO_TEST.items():
    results_data[model_name] = _sweep_snr(model_name, weights_path)

In [None]:
# --- PLOTTING ---
fig, ax = plt.subplots(figsize=(10, 6))

# Per-model line styling, keyed by legend label; labels without an entry fall back
# to matplotlib's defaults.
styles = {
    "Mixed (LEO+GEO)": {"color": "green", "marker": "o", "linewidth": 2.5},
    "LEO Only": {"color": "blue", "marker": "s", "linewidth": 1.5, "linestyle": "--"},
    "GEO Only": {"color": "orange", "marker": "^", "linewidth": 1.5, "linestyle": "--"},
    "AWGN (Baseline)": {"color": "red", "marker": "x", "linewidth": 1.5, "linestyle": ":"},
}

# One curve per evaluated model: PSNR against the swept SNR values.
for name, psnrs in results_data.items():
    ax.plot(SNR_RANGE, psnrs, label=name, **styles.get(name, {}))

ax.set_title(f"Model Robustness Comparison (Tested on {TEST_CHANNEL} Channel)", fontsize=14)
ax.set_xlabel("Signal-to-Noise Ratio (dB)", fontsize=12)
ax.set_ylabel("PSNR (dB)", fontsize=12)
ax.grid(True, which='both', linestyle='--', alpha=0.7)
ax.legend(fontsize=11)
ax.set_xticks(SNR_RANGE)

# Write the image to disk before showing the figure.
fig.savefig("psnr_vs_snr_comparison.png", dpi=300)
print("\nSuccess! Plot saved to 'psnr_vs_snr_comparison.png'")
plt.show()