# Visual Reconstruction Gallery: Graceful Degradation vs. Digital Cliff

This notebook visualizes the core advantage of Semantic Communications: **Graceful Degradation**.

**The Narrative Comparison:**
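- **Row 1 (High SNR, 20 dB)**: Both the digital (JPEG) pipeline and the semantic model reconstruct the image essentially perfectly.
- **Row 2 (Medium SNR, 10 dB)**: The digital pipeline starts to show blocky JPEG artifacts. The semantic reconstruction is slightly soft but still clear.
- **Row 3 (Low SNR, 5 dB)**: **The cliff.** The digital pipeline fails completely (corrupted headers -> decode error). The semantic reconstruction remains intelligible (e.g., you can still identify the object).

*Note:* the digital column is a threshold-based simulation, not an end-to-end JPEG + channel-code run. At 15 dB and above the original image is shown. Between 8 and 15 dB the image is re-encoded as very low-quality JPEG. Below 8 dB the decode is treated as failed and a gray, noisy image is shown.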

In [None]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import io
from PIL import Image
from models.model import SemViT
from utils.datasets import dataset_generator

# Keep tf.function-decorated code running as traced graphs (TensorFlow's
# default). This does not turn off TF2 eager execution of ordinary ops.
tf.config.run_functions_eagerly(False)

In [None]:
# --- CONFIGURATION ---
# SNR (dB) of each gallery row, top to bottom.
SNRS_TO_PLOT = [20, 10, 5]
# Channel model handed to SemViT. 'LEO' presumably denotes a low-Earth-orbit
# satellite channel -- confirm against models.model.
TEST_CHANNEL = 'LEO'
# Forwarded to SemViT as `num_symbols`.
DATA_SIZE = 512

# Weights of the "Mixed Model", which plays the semantic system in the gallery.
SEMANTIC_MODEL_PATH = "weights/experiment_mixed_sat_138.weights.h5"

# Architecture settings passed to SemViT. These should presumably match the
# training run that produced SEMANTIC_MODEL_PATH.
BLOCK_TYPES = ['C', 'C', 'V', 'V', 'C', 'C']
FILTERS = [256] * len(BLOCK_TYPES)  # one 256-filter entry per block
NUM_BLOCKS = [1, 1, 3, 3, 1, 1]
HAS_GDN = True

In [None]:
# --- DATASET ---
print("Loading Dataset...")


def normalize_img(x, y):
    """
    Turn a dataset element into an autoencoder (input, target) pair.

    The second element ``y`` (presumably the class label) is ignored. Both
    returned values are ``x`` cast to float32 and divided by 255 (assumes
    8-bit pixel values), so the model is asked to reconstruct its own input.
    """
    scaled = tf.cast(x, tf.float32) / 255.0
    return scaled, scaled

# Build a small, unshuffled CIFAR-10 test pipeline to pick a showcase image from.
test_ds_raw = dataset_generator('./dataset/CIFAR10/test/', shuffle=False)
test_ds = test_ds_raw.map(normalize_img)
# Cap the pipeline at 20 elements; only the first one is consumed below.
test_ds = test_ds.take(20)

In [None]:
# --- HELPER FUNCTIONS ---

def simulate_digital_jpeg(image_np, snr, *, rng=None, high_snr=15,
                          cliff_snr=8, jpeg_quality=5):
    """
    Simulate the visual behavior of a standard digital system (JPEG + channel code).

    This is a stylized, threshold-based stand-in, not a channel simulation:
    - snr >= high_snr (default 15 dB): perfect reconstruction (input returned).
    - cliff_snr <= snr < high_snr (default 8-15 dB): the image is re-encoded
      as JPEG at ``jpeg_quality`` to produce visible blocking artifacts.
    - snr < cliff_snr (default below 8 dB): the "digital cliff". The decode
      is treated as failed, and a mid-gray image with Gaussian static is
      returned.

    Args:
        image_np: float image with values in [0, 1], shaped (H, W, C) or
            (1, H, W, C). A leading batch axis is dropped so that every branch
            returns the same (H, W, C) shape.
        snr: channel SNR in dB.
        rng: optional numpy ``Generator`` used for the cliff noise. When None,
            the global ``np.random`` state is used, so ``np.random.seed``
            still applies.
        high_snr: SNR (dB) at or above which the image is returned untouched.
        cliff_snr: SNR (dB) below which the decode is treated as failed.
        jpeg_quality: PIL JPEG quality used in the artifact regime.

    Returns:
        A float array of shape (H, W, C) with values in [0, 1].
    """
    image_np = np.asarray(image_np)
    # Drop a leading batch axis up front so all three branches agree on shape.
    if image_np.ndim == 4:
        image_np = image_np[0]

    if snr >= high_snr:
        return image_np

    if snr >= cliff_snr:
        # Round (not truncate) and clip before the uint8 cast. Float values such
        # as (k / 255) * 255 can land just below k and would truncate down a
        # level, and out-of-range values would otherwise wrap around in uint8.
        img_uint8 = np.clip(np.round(image_np * 255.0), 0, 255).astype(np.uint8)
        buffer = io.BytesIO()
        Image.fromarray(img_uint8).save(buffer, format="JPEG", quality=jpeg_quality)
        buffer.seek(0)
        with Image.open(buffer) as decoded:
            return np.array(decoded) / 255.0

    # The cliff: a gray frame with static standing in for a failed decode.
    noise_source = np.random if rng is None else rng
    gray = np.full(image_np.shape, 0.5)
    noise = noise_source.normal(0, 0.1, image_np.shape)
    return np.clip(gray + noise, 0, 1)

def run_semvit_inference(model, image_tensor):
    """
    Run a SemViT model on one image and return its reconstruction.

    A rank-3 input (H, W, C) gets a leading batch axis added before the call.
    Inputs of any other rank are passed to the model unchanged. The model is
    called with training=False, and the first element of its output batch is
    returned.
    """
    needs_batch_axis = len(image_tensor.shape) == 3
    model_input = tf.expand_dims(image_tensor, 0) if needs_batch_axis else image_tensor
    reconstruction_batch = model(model_input, training=False)
    return reconstruction_batch[0]

In [None]:
# --- PREPARE MODELS & IMAGE ---

# Index of the showcase image within the first dataset element. Falls back to
# 0 when that batch is too small.
TARGET_INDEX = 11

# Only the first element of the dataset is needed.
first_element = next(iter(test_ds), None)
if first_element is None:
    raise RuntimeError("test_ds is empty; cannot select an image for the gallery.")
batch_imgs, _ = first_element

if len(batch_imgs.shape) == 3:
    # The element is already a single (H, W, C) image (unbatched dataset;
    # assumes images carry a channel axis). Indexing it would select one row
    # of pixels instead of an image.
    target_image = batch_imgs
elif batch_imgs.shape[0] > TARGET_INDEX:
    target_image = batch_imgs[TARGET_INDEX]
else:
    target_image = batch_imgs[0]

# Guard against an extra leading axis (e.g. a doubly-batched dataset).
if len(target_image.shape) == 4:
    target_image = target_image[0]

print(f"Selected Image Shape: {target_image.shape}")

In [None]:
# --- GENERATE GALLERY ---

# squeeze=False keeps `axes` 2-D even when SNRS_TO_PLOT has a single entry,
# so `axes[row_idx, col_idx]` below is always valid.
fig, axes = plt.subplots(len(SNRS_TO_PLOT), 3, figsize=(12, 12), squeeze=False)
plt.subplots_adjust(wspace=0.1, hspace=0.2)

cols = ["Original", "Digital (JPEG)", "SemViT (Mixed Model)"]

# The reference image does not depend on SNR, so convert it once.
img_original = target_image.numpy()

for row_idx, snr in enumerate(SNRS_TO_PLOT):
    print(f"Processing SNR = {snr} dB...")

    # 1. Instantiate the semantic model for this SNR. snrdB is a constructor
    #    argument, so a fresh model is built for every row.
    sem_model = SemViT(
        block_types=BLOCK_TYPES,
        filters=FILTERS,
        num_blocks=NUM_BLOCKS,
        has_gdn=HAS_GDN,
        num_symbols=DATA_SIZE,
        snrdB=snr,
        channel=TEST_CHANNEL,
    )
    sem_model.compile(optimizer='adam', loss='mse')
    # Build the variables with a dummy batch shaped like the selected image.
    sem_model(tf.zeros((1, *img_original.shape)))

    # Load weights on a best-effort basis, but report failures. Otherwise the
    # SemViT column silently shows an untrained model. Keras 3 no longer
    # accepts `by_name`. skip_mismatch=True skips layers whose saved weights
    # do not match, so some layers may still keep their initial values.
    if not os.path.exists(SEMANTIC_MODEL_PATH):
        print(f"WARNING: weights file not found: {SEMANTIC_MODEL_PATH}; "
              "the SemViT column uses untrained weights.")
    else:
        try:
            sem_model.load_weights(SEMANTIC_MODEL_PATH, skip_mismatch=True)
        except Exception as err:
            print(f"WARNING: failed to load {SEMANTIC_MODEL_PATH} for "
                  f"SNR = {snr} dB ({err}); the SemViT column uses untrained weights.")

    # 2. Produce the three images for this row.
    img_digital = simulate_digital_jpeg(img_original, snr)
    img_semantic = run_semvit_inference(sem_model, target_image).numpy()

    imgs = [img_original, img_digital, img_semantic]

    # 3. Plot the row.
    for col_idx, img in enumerate(imgs):
        ax = axes[row_idx, col_idx]
        ax.imshow(np.clip(img, 0, 1))
        ax.axis('off')

        # Column titles only on the top row.
        if row_idx == 0:
            ax.set_title(cols[col_idx], fontsize=16, pad=10)

        # SNR label to the left of the first column.
        if col_idx == 0:
            ax.text(-0.2, 0.5, f"SNR = {snr} dB", transform=ax.transAxes,
                    fontsize=14, fontweight='bold', va='center', rotation=90)

print("Saving gallery...")
plt.savefig("visual_gallery.png", dpi=300, bbox_inches='tight')
plt.show()