# Visual Comparison: SemViT (Mixed) vs. Digital Cliff

This notebook generates a visual comparison grid to demonstrate the **"Digital Cliff" effect**:
1.  **Original**: Ground truth images.
2.  **Proposed (SemViT Mixed)**: Reconstruction shows graceful degradation (blurry but recognizable) at low SNR.
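3.  **Conventional (Digital)**: An idealized, threshold-based model of a digital link. It is error-free at SNR ≥ 15 dB and shows JPEG blocking artifacts between 8 and 15 dB. Below 8 dB it fails completely ("Cliff Effect").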

### Instructions
1.  **Run Setup**: Run the first cell to install dependencies.
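2.  **Upload Weights**: Ensure `weights/experiment_mixed_sat_138.weights.h5` is present, relative to the repository root. If it is missing or fails to load, the model runs with random weights and the "Proposed" results are meaningless.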


In [None]:
# --- SETUP ---
!pip install sionna==0.14.0 tensorflow-compression tensorflow-addons mitsuba==3.2.1

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import io
from PIL import Image

# Move into the repository checkout before importing project code. The first
# existing candidate wins: a checkout next to the notebook, then the Colab
# default location under /content.
# NOTE(review): the `models`/`utils` imports below appear to rely on the repo
# root being the working directory -- confirm if the layout changes.
repo_name = 'Semmantic-Communication-Geo-Leo-Channels'
for _candidate_dir in (repo_name, '/content/' + repo_name):
    if os.path.exists(_candidate_dir):
        os.chdir(_candidate_dir)
        break

from models.model import SemViT
from utils.datasets import dataset_generator

# Run tf.function-decorated code as compiled graphs rather than forcing eager mode.
tf.config.run_functions_eagerly(False)

In [None]:
# --- CONFIGURATION ---

# Trained weights to evaluate (relative to the working directory, i.e. the
# repository root once the setup cell has navigated into it).
MODEL_PATH = "weights/experiment_mixed_sat_138.weights.h5"

# Channel / baseline settings.
# The simulated digital baseline has three regimes:
#   SNR >= 15 dB       -> error-free
#   8 <= SNR < 15 dB   -> JPEG blocking artifacts
#   SNR < 8 dB         -> digital cliff (complete decoding failure)
# 5 dB lies below the cliff, which is the regime this comparison highlights.
SNR_DB = 5
JPEG_QUALITY = 10   # JPEG quality factor, only used in the 8-15 dB regime
BATCH_SIZE = 64     # images per comparison grid (64 -> 8 x 8 tiles)

# SemViT architecture -- presumably must match the configuration the weights
# in MODEL_PATH were trained with.
BLOCK_TYPES = ['C', 'C', 'V', 'V', 'C', 'C']
FILTERS = [256] * 6
NUM_BLOCKS = [1, 1, 3, 3, 1, 1]
HAS_GDN = True
DATA_SIZE = 512     # passed to SemViT as `num_symbols`

In [None]:
# --- HELPER: Image Grid ---
def create_image_grid(images, grid_size=(8, 8)):
    """
    Tile the first ``rows * cols`` images of a batch into a single mosaic.

    images: numpy array of shape (B, H, W, C).
    grid_size: (rows, cols) of the mosaic; images are placed row-major and any
        images beyond ``rows * cols`` are ignored.

    Returns a new array of shape (rows * H, cols * W, C) with ``images.dtype``.
    Raises ValueError if the batch holds fewer than ``rows * cols`` images.
    """
    tile_h, tile_w, channels = images.shape[1:]
    n_rows, n_cols = grid_size
    n_tiles = n_rows * n_cols

    if n_tiles > len(images):
        raise ValueError("Not enough images for grid size")

    # (rows*cols, H, W, C) -> (rows, cols, H, W, C) -> (rows, H, cols, W, C);
    # flattening that row-major yields the (rows*H, cols*W, C) mosaic.
    tiles = np.asarray(images[:n_tiles]).reshape(
        n_rows, n_cols, tile_h, tile_w, channels
    )
    mosaic = tiles.transpose(0, 2, 1, 3, 4).reshape(
        n_rows * tile_h, n_cols * tile_w, channels
    )
    # reshape may return a view of `images` for degenerate grids (e.g. one
    # column); always hand back an independent array.
    return mosaic.copy()

def _jpeg_round_trip(img, quality):
    """Encode a float image in [0, 1] as an in-memory JPEG and decode it back."""
    img_uint8 = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    # PIL cannot build an image from an (H, W, 1) array: encode such inputs as
    # 2-D grayscale and restore the channel axis after decoding.
    single_channel = img_uint8.ndim == 3 and img_uint8.shape[-1] == 1
    if single_channel:
        img_uint8 = img_uint8[..., 0]

    buffer = io.BytesIO()
    Image.fromarray(img_uint8).save(buffer, format="JPEG", quality=quality)
    buffer.seek(0)
    # Close the decoded image once its pixels have been copied out.
    with Image.open(buffer) as decoded:
        out = np.asarray(decoded, dtype=np.float32) / 255.0

    if single_channel:
        out = out[..., np.newaxis]
    return out


def simulate_digital_transmission(images_np, snr, quality=10, *, rng=None,
                                  perfect_snr=15, cliff_snr=8):
    """
    Simulate an idealized digital link with a hard SNR threshold.

    No channel is actually simulated; the outcome depends only on ``snr``:

    - ``snr >= perfect_snr`` (15 dB): perfect reconstruction, images unchanged.
    - ``cliff_snr <= snr < perfect_snr`` (8-15 dB): JPEG round trip at
      ``quality``, producing blocking artifacts.
    - ``snr < cliff_snr`` (< 8 dB): "digital cliff" -- decoding fails and each
      image becomes mid-gray (0.5) plus Gaussian static (std 0.1), clipped
      to [0, 1].

    Args:
        images_np: iterable of float images in [0, 1], typically a numpy array
            of shape (B, H, W, C). Single-channel (H, W, 1) images are supported.
        snr: channel SNR in dB.
        quality: JPEG quality factor used in the medium-SNR regime.
        rng: optional ``numpy.random.Generator`` for reproducible cliff noise.
            When None, numpy's global random state is used.
        perfect_snr: SNR (dB) at and above which transmission is error-free.
        cliff_snr: SNR (dB) below which decoding fails completely.

    Returns:
        float32 numpy array stacking the per-image results (float32 so it can
        be fed to TensorFlow metrics alongside float32 tensors).
    """
    # The regime depends only on `snr`, so decide it once rather than per image.
    if snr >= perfect_snr:
        results = list(images_np)
    elif snr >= cliff_snr:
        results = [_jpeg_round_trip(img, quality) for img in images_np]
    else:
        normal = np.random.normal if rng is None else rng.normal
        results = [
            np.clip(np.ones_like(img) * 0.5 + normal(0, 0.1, img.shape), 0, 1)
            for img in images_np
        ]

    return np.array(results).astype(np.float32)

def calculate_metrics(original, reconstructed):
    """
    Return the batch-mean (PSNR, SSIM) between two image batches.

    Pixels are assumed to lie in [0, 1] (``max_val=1.0``). No special-casing
    is done: ``tf.image.psnr`` yields inf for a pair of identical images, and
    any non-finite per-image score propagates into the mean.
    """
    per_image_psnr = tf.image.psnr(original, reconstructed, max_val=1.0)
    per_image_ssim = tf.image.ssim(original, reconstructed, max_val=1.0)
    mean_psnr, mean_ssim = (
        tf.reduce_mean(scores).numpy()
        for scores in (per_image_psnr, per_image_ssim)
    )
    return mean_psnr, mean_ssim

In [None]:
# --- LOAD MODEL & DATA ---
print("Loading Dataset...")

def normalize(x, y):
    """Map an (image, label) pair to (image, image) with pixels scaled to [0, 1].

    The label is discarded: the reconstruction target is the input image itself.
    """
    x = tf.cast(x, tf.float32) / 255.0
    return (x, x)

try:
    test_ds_raw = dataset_generator('./dataset/CIFAR10/test/', shuffle=False)
except Exception as err:
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit. Report
    # the real failure so a broken local dataset isn't mistaken for a missing one.
    print(f"Local dataset unavailable ({type(err).__name__}: {err}).")
    print("Dataset not found locally. Downloading CIFAR-10...")
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    # Batch by 1 so the `.unbatch()` below applies uniformly; the local
    # pipeline is presumably batched as well -- TODO confirm.
    test_ds_raw = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(1)

# Get a clean batch of BATCH_SIZE images (element [0] = inputs; targets are identical).
ds_batched = test_ds_raw.map(normalize).unbatch().batch(BATCH_SIZE).take(1)
original_batch = next(iter(ds_batched))[0]
print(f"Loaded batch shape: {original_batch.shape}")
if original_batch.shape[0] < BATCH_SIZE:
    print(f"WARNING: only {original_batch.shape[0]} images available, "
          f"fewer than BATCH_SIZE={BATCH_SIZE}.")

print(f"Loading Model from {MODEL_PATH}...")
model = SemViT(
    block_types=BLOCK_TYPES,
    filters=FILTERS,
    num_blocks=NUM_BLOCKS,
    has_gdn=HAS_GDN,
    num_symbols=DATA_SIZE,
    snrdB=SNR_DB,
    channel='AWGN' # Using AWGN to match the standard digital comparison assumption
)
model.compile(optimizer='adam', loss='mse')
model(tf.zeros((1, 32, 32, 3))) # Build: create the variables so weights can be loaded into them

# Every failure path below leaves the randomly initialised model in place.
# Track that explicitly so the "Proposed" results are never silently produced
# by an untrained network.
weights_loaded = False
if os.path.exists(MODEL_PATH):
    try:
        # NOTE(review): in Keras 2 (tf.keras), `skip_mismatch=True` is only
        # accepted together with `by_name=True` and otherwise raises ValueError;
        # Keras 3 accepts it on its own. Confirm against the installed version.
        # With skip_mismatch, mismatching layers are skipped, so a "successful"
        # load may still be partial.
        model.load_weights(MODEL_PATH, skip_mismatch=True)
        weights_loaded = True
        print("Weights loaded successfully.")
    except Exception as e:
        print(f"Error loading weights: {e}")
else:
    print(f"WARNING: Weights file not found at {MODEL_PATH}. Using random weights.")

if not weights_loaded:
    print("WARNING: Proposed (SemViT) results below use RANDOM weights, not the trained model.")

In [None]:
# --- INFERENCE & COMPARISON ---

# 1. Proposed: push the clean batch through SemViT in inference mode. The
#    channel SNR is the one the model was constructed with (snrdB=SNR_DB).
print(f"Running Semantic Model Inference at {SNR_DB}dB...")
proposed_batch = model(original_batch, training=False)
psnr_prop, ssim_prop = calculate_metrics(original_batch, proposed_batch)

# 2. Conventional: threshold-based digital model, applied to host numpy copies.
print(f"Simulating Digital Transmission at {SNR_DB}dB...")
digital_batch = simulate_digital_transmission(
    original_batch.numpy(), snr=SNR_DB, quality=JPEG_QUALITY
)
psnr_dig, ssim_dig = calculate_metrics(original_batch, digital_batch)

# Summary table: one line per scheme.
for _label, _psnr, _ssim in (
    ("Proposed (Semantic)", psnr_prop, ssim_prop),
    ("Conventional (Digital)", psnr_dig, ssim_dig),
):
    print(f"{_label}: PSNR={_psnr:.2f}, SSIM={_ssim:.2f}")

In [None]:
# --- VISUALIZATION ---
# Near-square grid. If BATCH_SIZE is not rows * cols, the leftover images are
# not shown.
rows = int(np.sqrt(BATCH_SIZE))
cols = BATCH_SIZE // rows

# Create Grids
grid_orig = create_image_grid(original_batch.numpy(), (rows, cols))
grid_prop = create_image_grid(proposed_batch.numpy(), (rows, cols))
grid_dig = create_image_grid(digital_batch, (rows, cols))

fig, axes = plt.subplots(1, 3, figsize=(20, 8))

# Original
axes[0].imshow(np.clip(grid_orig, 0, 1))
axes[0].set_title("Original", fontsize=16)
axes[0].axis('off')

# Proposed
axes[1].imshow(np.clip(grid_prop, 0, 1))
axes[1].set_title(f"Proposed (Graceful Degradation)\nPSNR: {psnr_prop:.2f} dB\nSSIM: {ssim_prop:.2f}", fontsize=16)
axes[1].axis('off')

# Conventional
axes[2].imshow(np.clip(grid_dig, 0, 1))
# Label the panel with the regime the digital simulation is actually in:
# < 8 dB cliff, 8-15 dB JPEG artifacts, >= 15 dB untouched images. The old
# two-way title mislabelled the error-free regime as "JPEG".
if SNR_DB < 8:
    dig_title = "Conventional (Digital Cliff)"
elif SNR_DB < 15:
    dig_title = "Conventional (JPEG)"
else:
    dig_title = "Conventional (Error-Free)"
axes[2].set_title(f"{dig_title}\nPSNR: {psnr_dig:.2f} dB\nSSIM: {ssim_dig:.2f}", fontsize=16)
axes[2].axis('off')

plt.tight_layout()
plt.savefig("visual_comparison_cliff.png", dpi=300)
plt.show()