# Comprehensive Comparison: SemViT Models vs BPG+LDPC Baseline

This notebook directly compares the four neural SemViT models against the traditional digital baseline (BPG source coding followed by ideal, capacity-achieving channel coding).

**Test Conditions (Matched to `bpg-ldpc.ipynb`):**
- **Channel**: AWGN
- **Bandwidth Ratio**: 1/6 (512 Symbols / 3072 input dimensions)
- **Metrics**: PSNR, SSIM
- **SNR Range**: [0, 2, 5, 7, 10, 12, 15] dB


In [None]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from models.model import SemViT
from models.metrics import psnr
from utils.datasets import dataset_generator

# Keep tf.function graph execution enabled: passing False means functions are
# NOT forced to run eagerly (True would force eager execution, e.g. for debugging).
tf.config.run_functions_eagerly(False)

In [None]:
# --- CONFIGURATION ---
# SNR operating points (dB) at which every model is evaluated; also the x-axis of the plots.
SNR_RANGE = [0, 2, 5, 7, 10, 12, 15]
TEST_CHANNEL = 'AWGN'  # Strict comparison baseline
DATA_SIZE = 512  # Passed to SemViT as num_symbols (512 symbols / 3072 input dims = 1/6 BW ratio)

# Models to Evaluate: display name -> path of the weights file to load
MODELS_TO_TEST = {
    "Mixed (LEO+GEO)": "weights/experiment_mixed_sat_138.weights.h5",
    "LEO Only":        "weights/experiment_leo_sat_128.weights.h5",
    "GEO Only":        "weights/experiment_geo_sat_34.weights.h5",
    "AWGN (Baseline)": "weights/CCVVCC_512_10dB_599.weights.h5"
}

# BPG + Capacity Baseline Data (Extracted from bpg-ldpc.ipynb logs)
# This represents BPG compression scaled to the Channel Capacity limit at each SNR with 1/6 BW.
# Each list holds one value per entry of SNR_RANGE, in the same order.
BPG_BASELINE = {
    "PSNR": [24.13, 25.65, 28.24, 29.72, 31.59, 32.83, 34.54],
    "SSIM": [0.82,  0.87,  0.92,  0.94,  0.96,  0.97,  0.98]
}

# Model Architecture Config (forwarded unchanged to the SemViT constructor)
BLOCK_TYPES = ['C', 'C', 'V', 'V', 'C', 'C']
FILTERS = [256, 256, 256, 256, 256, 256]
NUM_BLOCKS = [1, 1, 3, 3, 1, 1]
HAS_GDN = True

# Fail fast if the hard-coded baseline drifts out of sync with SNR_RANGE; otherwise the
# mismatch only shows up later as a shape error when the curves are plotted.
assert all(len(values) == len(SNR_RANGE) for values in BPG_BASELINE.values()), (
    "Every BPG_BASELINE curve must have exactly one value per SNR_RANGE entry"
)

In [None]:
# --- DATASET ---
print("Loading Dataset...")

# Number of dataset elements pulled from the test pipeline via .take().
# NOTE(review): the evaluation loop treats each element as a batch - confirm against
# dataset_generator. Use a larger subset for a rigorous comparison if metrics are
# crucial, but keep at 200 for CPU speed unless on GPU.
TEST_SUBSET_SIZE = 200


def normalize_img(x, y):
    """Return the image scaled by 1/255 as float32, used as both input and target.

    Args:
        x: Image tensor from the dataset; cast to float32 and divided by 255
            (so 8-bit pixel values end up in [0, 1]).
        y: Second element produced by the dataset; ignored, because the
            reconstruction target is the image itself.

    Returns:
        A ``(scaled, scaled)`` tuple containing the same normalized tensor twice.
    """
    # Cast/scale once and reuse the result instead of computing it twice.
    scaled = tf.cast(x, tf.float32) / 255.0
    return scaled, scaled


test_ds_raw = dataset_generator('./dataset/CIFAR10/test/', shuffle=False)
test_ds = test_ds_raw.map(normalize_img).take(TEST_SUBSET_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
def evaluate_metrics(weights_path, snr_db):
    """Evaluate one SemViT checkpoint at one SNR and return its mean PSNR and SSIM.

    A fresh SemViT is constructed from the notebook-level architecture constants
    with ``snrdB=snr_db`` and ``channel=TEST_CHANNEL``, the weights are loaded,
    and the model is run over the global ``test_ds``.

    Args:
        weights_path: Path to the weights file to load.
        snr_db: Channel SNR in dB, forwarded to the model constructor.

    Returns:
        ``(mean_psnr, mean_ssim)`` as Python floats, averaged over every image
        seen in ``test_ds``. ``(nan, nan)`` is returned when the weights file is
        missing, loading fails, or the dataset is empty, so that a failed run is
        not mistaken for a real 0.0 measurement (matplotlib leaves NaN points
        out of the plotted curve).
    """
    failed = (float("nan"), float("nan"))

    # Check the checkpoint first: building and tracing the model is pointless
    # if there is nothing to load.
    if not os.path.exists(weights_path):
        print(f"  [ERROR] Weights file NOT FOUND at: {weights_path}")
        return failed

    try:
        # Instantiate
        model = SemViT(
            block_types=BLOCK_TYPES,
            filters=FILTERS,
            num_blocks=NUM_BLOCKS,
            has_gdn=HAS_GDN,
            num_symbols=DATA_SIZE,
            snrdB=snr_db,
            channel=TEST_CHANNEL
        )
        # Dummy build so the variables exist before the weights are loaded
        model.compile(optimizer='adam', loss='mse')
        model(tf.zeros((1, 32, 32, 3)))

        # Load Weights
        try:
            # Fix for Keras 3: 'by_name' is not supported in load_weights anymore
            # NOTE(review): skip_mismatch=True presumably leaves mismatching layers at
            # their initial values instead of raising - watch the load warnings.
            model.load_weights(weights_path, skip_mismatch=True)
        except Exception as e:
            print(f"  [ERROR] Exception loading weights: {e}")
            return failed

        # Eval Loop: accumulate per-image sums so a smaller final batch is not
        # over-weighted the way a mean of batch means would be.
        psnr_sum = 0.0
        ssim_sum = 0.0
        num_images = 0

        for x, _ in test_ds:
            recon = model(x, training=False)

            # Images are [0,1] float
            psnr_sum += float(tf.reduce_sum(tf.image.psnr(x, recon, max_val=1.0)))
            ssim_sum += float(tf.reduce_sum(tf.image.ssim(x, recon, max_val=1.0)))
            num_images += int(tf.shape(x)[0])
    finally:
        # Release Keras global state on every exit path once a model may have been built.
        tf.keras.backend.clear_session()

    if num_images == 0:
        print("  [ERROR] Test Dataset yielded 0 batches. Check dataset path.")
        return failed
    return (psnr_sum / num_images), (ssim_sum / num_images)

In [None]:
# --- MAIN EVALUATION LOOP ---
# Pre-create one empty PSNR/SSIM curve per model so every model has an entry up front.
results = {}
for model_name in MODELS_TO_TEST:
    results[model_name] = {"PSNR": [], "SSIM": []}

# Sweep every model over the full SNR range, extending its curves point by point.
for model_name, path in MODELS_TO_TEST.items():
    print(f"Evaluating {model_name}...")
    curves = results[model_name]
    for snr in SNR_RANGE:
        psnr_val, ssim_val = evaluate_metrics(path, snr)
        curves["PSNR"].append(psnr_val)
        curves["SSIM"].append(ssim_val)
        print(f"  SNR={snr}: PSNR={psnr_val:.2f}, SSIM={ssim_val:.3f}")

In [None]:
# --- PLOTTING ---
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# One panel per metric: (key into the result dicts, panel title, y-axis label).
panels = [
    ("PSNR", "PSNR Comparison (AWGN Channel)", "PSNR (dB)"),
    ("SSIM", "SSIM Comparison (AWGN Channel)", "SSIM"),
]

for ax, (metric_key, panel_title, y_label) in zip(axes, panels):
    # Reference curve first so it heads the legend.
    ax.plot(SNR_RANGE, BPG_BASELINE[metric_key], label="BPG + Capacity (Baseline)",
            color='black', linestyle='--', linewidth=2.5, marker='*')

    # One curve per evaluated model.
    for name, metrics in results.items():
        ax.plot(SNR_RANGE, metrics[metric_key], label=name, marker='o')

    ax.set_title(panel_title)
    ax.set_xlabel("SNR (dB)")
    ax.set_ylabel(y_label)
    ax.grid(True, alpha=0.5)
    ax.legend()

# Write the figure to disk before showing it.
plt.savefig("full_comparison_awgn.png", dpi=300)
plt.show()