# Attack Simulation Blocks
## Add these blocks AFTER Block 9 in your existing notebook
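These blocks simulate real-world attacks on the stego image and measure how well the hidden message survives. They rely on `STEGO_IMAGE_PATH`, `SECRET_MESSAGE`, and `extract_message` being defined by the earlier blocks of the notebook.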

**Run order:** Block 10 → Block 11 → Block 12

In [None]:
# ============================================================
# BLOCK 10: Attack functions
# Each function takes the stego image path, applies an attack,
# saves the attacked image, and returns the saved path.
# ============================================================
import cv2
import numpy as np
import os

# --- Helper: calculate PSNR between original stego and attacked version ---
def calc_psnr(img1_path, img2_path):
    """
    Peak Signal-to-Noise Ratio (dB) between two image files.

    Both files are loaded with cv2.imread's default flags (8-bit, 3-channel
    BGR). If the sizes differ, both are cropped to their common top-left
    region before comparison.

    Higher = less visual distortion; > 40 dB is commonly treated as
    imperceptible to humans.

    Returns float('inf') when the compared pixels are identical.
    Raises FileNotFoundError if either file cannot be read.
    """
    a = cv2.imread(img1_path)
    b = cv2.imread(img2_path)
    # cv2.imread signals failure by returning None rather than raising
    if a is None:
        raise FileNotFoundError(f'Could not read image: {img1_path}')
    if b is None:
        raise FileNotFoundError(f'Could not read image: {img2_path}')
    a = a.astype(np.float64)
    b = b.astype(np.float64)

    # Crop both to same size in case of size mismatch after attack
    h = min(a.shape[0], b.shape[0])
    w = min(a.shape[1], b.shape[1])
    a, b = a[:h, :w], b[:h, :w]

    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    # 255 is the peak value of the 8-bit images imread returns
    return 10 * np.log10((255 ** 2) / mse)


# --- Helper: compare extracted message to original ---
def evaluate(extracted, original):
    """
    Compare an extracted message with the original, character by character.

    Returns (survived, corrupted_chars, total_chars)
    survived       : True if message extracted perfectly
    corrupted_chars: number of positions that differ plus the difference in
                     length, so truncated or over-long extractions are
                     penalised; this can exceed total_chars when the
                     extraction is longer than the original
    total_chars    : length of original message

    A falsy `extracted` (None or '') means nothing was recovered. In that case
    every original character counts as corrupted, and the message only
    "survives" if the original itself is empty.
    """
    total = len(original)
    if not extracted:
        return total == 0, total, total
    corrupted = sum(a != b for a, b in zip(original, extracted))
    corrupted += abs(len(original) - len(extracted))  # length difference counts too
    return corrupted == 0, corrupted, total


# ============================================================
# ATTACK 1: JPEG Compression
# Encodes the stego image as JPEG at the given quality level,
# decodes it again, and saves the result losslessly (e.g. PNG)
# for extraction. Typically one of the most destructive attacks
# for DCT steganography because JPEG re-quantizes DCT coefficients.
# ============================================================
def attack_jpeg(stego_path, quality, output_path):
    """
    Round-trip the stego image through JPEG at `quality` and save the result.

    The JPEG encode/decode happens in memory (cv2.imencode / cv2.imdecode),
    so no temporary file is created next to `output_path`, whatever its name.
    Give `output_path` a lossless extension (e.g. .png) so the JPEG round trip
    is the only lossy step.

    quality: JPEG quality, 100 = best, 1 = worst.
    Returns output_path.
    Raises FileNotFoundError if the stego image cannot be read, and OSError if
    JPEG encoding or writing the output fails.
    """
    img = cv2.imread(stego_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {stego_path}')
    ok, encoded = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, int(quality)])
    if not ok:
        raise OSError(f'JPEG encoding failed (quality={quality})')
    # Decode the JPEG bytes back to a BGR array, as cv2.imread would
    reloaded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    # cv2.imwrite reports failure by returning False, not by raising
    if not cv2.imwrite(output_path, reloaded):
        raise OSError(f'Could not write image: {output_path}')
    return output_path


# ============================================================
# ATTACK 2: Gaussian Noise
# Adds random pixel noise sampled from a Gaussian distribution.
# Simulates a noisy transmission channel.
# ============================================================
def attack_gaussian_noise(stego_path, std, output_path, rng=None):
    """
    Add zero-mean Gaussian noise (standard deviation `std`, in 8-bit intensity
    units) independently to every pixel channel and save the result.

    rng: optional NumPy random source with a `.normal(loc, scale, size)`
         method (e.g. np.random.default_rng(0)) for reproducible runs.
         When None, the global np.random state is used.
    Returns output_path.
    Raises FileNotFoundError if the stego image cannot be read, and OSError if
    the output cannot be written.
    """
    img = cv2.imread(stego_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {stego_path}')
    sampler = np.random if rng is None else rng
    # Generate noise with mean=0 and the given standard deviation
    noise = sampler.normal(0, std, img.shape)
    # Round before casting: astype(np.uint8) alone truncates the fractional
    # part, which would bias every noisy pixel downward by ~0.5 on average.
    noisy = np.clip(np.rint(img.astype(np.float64) + noise), 0, 255).astype(np.uint8)
    if not cv2.imwrite(output_path, noisy):
        raise OSError(f'Could not write image: {output_path}')
    return output_path


# ============================================================
# ATTACK 3: Resizing
# Scales the image down then back up to original size.
# Interpolation during resize changes pixel values.
# ============================================================
def attack_resize(stego_path, scale_factor, output_path):
    """
    Shrink the image by `scale_factor`, scale it back to the original size,
    and save the result. Both steps use bilinear interpolation.

    scale_factor: must be > 0. Each shrunk dimension is truncated to an int
                  and clamped to at least 1 pixel.
    Returns output_path.
    Raises ValueError for a non-positive scale_factor, FileNotFoundError if
    the stego image cannot be read, and OSError if the output cannot be
    written.
    """
    if scale_factor <= 0:
        raise ValueError(f'scale_factor must be > 0, got {scale_factor}')
    img = cv2.imread(stego_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {stego_path}')
    h, w = img.shape[:2]
    # Scale down; clamp so tiny factors never request a 0-pixel image
    small_w = max(1, int(w * scale_factor))
    small_h = max(1, int(h * scale_factor))
    # cv2.resize takes dsize as (width, height)
    small = cv2.resize(img, (small_w, small_h), interpolation=cv2.INTER_LINEAR)
    # Scale back up to original size
    restored = cv2.resize(small, (w, h), interpolation=cv2.INTER_LINEAR)
    if not cv2.imwrite(output_path, restored):
        raise OSError(f'Could not write image: {output_path}')
    return output_path


# ============================================================
# ATTACK 4: Rotation
# Rotates the image by `angle` degrees and then back by -angle,
# so the result is realigned with the original pixel grid. The
# damage comes from two rounds of bilinear resampling and from
# corners that were rotated out of frame and refilled by reflection.
# ============================================================
def attack_rotation(stego_path, angle, output_path):
    """
    Rotate the stego image by +angle degrees and back by -angle degrees
    about its centre, keeping the original canvas size, and save the result.

    Returns output_path.
    Raises FileNotFoundError if the stego image cannot be read, and OSError if
    the output cannot be written.
    """
    img = cv2.imread(stego_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {stego_path}')
    h, w = img.shape[:2]
    center = (w // 2, h // 2)
    # Rotate by +angle; areas that fall outside the canvas are filled by reflection
    M_fwd = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img, M_fwd, (w, h),
                             flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REFLECT)
    # Rotate back by -angle about the same centre
    M_back = cv2.getRotationMatrix2D(center, -angle, 1.0)
    restored = cv2.warpAffine(rotated, M_back, (w, h),
                              flags=cv2.INTER_LINEAR,
                              borderMode=cv2.BORDER_REFLECT)
    if not cv2.imwrite(output_path, restored):
        raise OSError(f'Could not write image: {output_path}')
    return output_path


# ============================================================
# ATTACK 5: Histogram Equalization
# Redistributes pixel intensity values to improve contrast.
# Drastically changes the Y channel — very destructive.
# ============================================================
def attack_histogram_eq(stego_path, output_path):
    """
    Equalize the luma histogram of the stego image and save the result.

    The image is converted BGR -> YCrCb, and only the Y channel is equalized
    with cv2.equalizeHist. Cr and Cb are passed through unchanged before the
    image is converted back to BGR.

    Returns output_path.
    Raises FileNotFoundError if the stego image cannot be read, and OSError if
    the output cannot be written.
    """
    img = cv2.imread(stego_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {stego_path}')
    img_ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
    Y, Cr, Cb = cv2.split(img_ycrcb)
    # Equalize only the Y channel (same channel we embed into)
    Y_eq = cv2.equalizeHist(Y)
    merged = cv2.merge([Y_eq, Cr, Cb])
    result = cv2.cvtColor(merged, cv2.COLOR_YCrCb2BGR)
    if not cv2.imwrite(output_path, result):
        raise OSError(f'Could not write image: {output_path}')
    return output_path


# ============================================================
# ATTACK 6: Median Filter
# Replaces each pixel with the median of its neighbors.
# Commonly used to remove noise — also removes hidden data.
# ============================================================
def attack_median_filter(stego_path, kernel_size, output_path):
    """
    Apply a kernel_size x kernel_size median filter and save the result.

    kernel_size: a positive odd integer (3, 5, 7, ...). It is checked up
                 front so a bad value fails with a clear message instead of
                 an OpenCV assertion.
    Returns output_path.
    Raises ValueError for an invalid kernel_size, FileNotFoundError if the
    stego image cannot be read, and OSError if the output cannot be written.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size}')
    img = cv2.imread(stego_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {stego_path}')
    filtered = cv2.medianBlur(img, kernel_size)
    if not cv2.imwrite(output_path, filtered):
        raise OSError(f'Could not write image: {output_path}')
    return output_path


# Signal that this cell ran to completion and the attack helpers are available.
print('All attack functions defined.', end='\n')

In [None]:
# ============================================================
# BLOCK 11: Run all attacks and collect results
# ============================================================
import os

# Make a folder to store attacked images
ATTACK_DIR = 'attacked_images'
os.makedirs(ATTACK_DIR, exist_ok=True)

# All attacks to run, as (display_name, zero-argument callable) pairs.
# Each callable applies one attack to STEGO_IMAGE_PATH, writes the result
# under ATTACK_DIR, and returns the written path. STEGO_IMAGE_PATH and
# ATTACK_DIR are looked up when a lambda is called, not when this list is built.
attacks = [
    # JPEG at different quality levels
    ('JPEG Quality=90',       lambda: attack_jpeg(STEGO_IMAGE_PATH, 90, f'{ATTACK_DIR}/jpeg_q90.png')),
    ('JPEG Quality=75',       lambda: attack_jpeg(STEGO_IMAGE_PATH, 75, f'{ATTACK_DIR}/jpeg_q75.png')),
    ('JPEG Quality=50',       lambda: attack_jpeg(STEGO_IMAGE_PATH, 50, f'{ATTACK_DIR}/jpeg_q50.png')),

    # Gaussian noise at different intensities
    ('Gaussian Noise std=5',  lambda: attack_gaussian_noise(STEGO_IMAGE_PATH, 5, f'{ATTACK_DIR}/noise_5.png')),
    ('Gaussian Noise std=15', lambda: attack_gaussian_noise(STEGO_IMAGE_PATH, 15, f'{ATTACK_DIR}/noise_15.png')),
    ('Gaussian Noise std=30', lambda: attack_gaussian_noise(STEGO_IMAGE_PATH, 30, f'{ATTACK_DIR}/noise_30.png')),

    # Resizing at different scale factors
    ('Resize scale=0.9',      lambda: attack_resize(STEGO_IMAGE_PATH, 0.9, f'{ATTACK_DIR}/resize_90.png')),
    ('Resize scale=0.75',     lambda: attack_resize(STEGO_IMAGE_PATH, 0.75, f'{ATTACK_DIR}/resize_75.png')),
    ('Resize scale=0.5',      lambda: attack_resize(STEGO_IMAGE_PATH, 0.5, f'{ATTACK_DIR}/resize_50.png')),

    # Rotation at different angles
    ('Rotation 1 deg',        lambda: attack_rotation(STEGO_IMAGE_PATH, 1, f'{ATTACK_DIR}/rot_1.png')),
    ('Rotation 5 deg',        lambda: attack_rotation(STEGO_IMAGE_PATH, 5, f'{ATTACK_DIR}/rot_5.png')),
    ('Rotation 10 deg',       lambda: attack_rotation(STEGO_IMAGE_PATH, 10, f'{ATTACK_DIR}/rot_10.png')),

    # Histogram equalization
    ('Histogram EQ',          lambda: attack_histogram_eq(STEGO_IMAGE_PATH, f'{ATTACK_DIR}/histeq.png')),

    # Median filter at different kernel sizes
    ('Median Filter k=3',     lambda: attack_median_filter(STEGO_IMAGE_PATH, 3, f'{ATTACK_DIR}/median_3.png')),
    ('Median Filter k=5',     lambda: attack_median_filter(STEGO_IMAGE_PATH, 5, f'{ATTACK_DIR}/median_5.png')),
]

# Run each attack and collect results
results = []

print(f'Running {len(attacks)} attacks...\n')
print(f'{"Attack":<25} {"Survived":<10} {"Corrupted":<15} {"PSNR (dB)":<12}')
print('-' * 65)

for attack_name, attack_fn in attacks:
    try:
        # Apply the attack; the callable returns the path it wrote
        attacked_path = attack_fn()

        # Try to extract the message from the attacked image.
        # NOTE(review): extract_message is defined in an earlier block. If it
        # signals failure with None, normalise that to '' so every stored
        # 'extracted' value is a string (Block 12 slices it and calls len()).
        extracted = extract_message(attacked_path) or ''

        # Evaluate how well the message survived
        survived, corrupted, total = evaluate(extracted, SECRET_MESSAGE)

        # Calculate PSNR between original stego and attacked version
        psnr = calc_psnr(STEGO_IMAGE_PATH, attacked_path)
        psnr_str = f'{psnr:.2f}' if psnr != float('inf') else 'inf'

        # Store result
        results.append({
            'name': attack_name,
            'survived': survived,
            'corrupted': corrupted,
            'total': total,
            'psnr': psnr,
            'psnr_str': psnr_str,
            'extracted': extracted,
            'path': attacked_path,
        })

        status = 'YES ✓' if survived else 'NO  ✗'
        corrupted_str = f'{corrupted}/{total} chars'
        print(f'{attack_name:<25} {status:<10} {corrupted_str:<15} {psnr_str:<12}')

    except Exception as e:
        # Per-attack boundary: report the failure and continue with the rest.
        # corrupted=-1 and psnr_str='N/A' mark this row as an error downstream.
        print(f'{attack_name:<25} ERROR: {type(e).__name__}: {e}')
        results.append({'name': attack_name, 'survived': False,
                        'corrupted': -1, 'total': len(SECRET_MESSAGE),
                        'psnr': 0, 'psnr_str': 'N/A', 'extracted': '', 'path': ''})

print('-' * 65)
survived_count = sum(1 for r in results if r['survived'])
print(f'\nMessage survived {survived_count}/{len(results)} attacks.')

In [None]:
# ============================================================
# BLOCK 12: Visualize results
# ============================================================
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# --- Chart 1: Corrupted characters per attack ---
names      = [r['name'] for r in results]
corrupted  = [r['corrupted'] for r in results]
survived   = [r['survived'] for r in results]
# Identical images have infinite PSNR; cap at 60 dB so the bar is drawable
psnr_vals  = [r['psnr'] if r['psnr'] != float('inf') else 60 for r in results]
colors     = ['#2ecc71' if s else '#e74c3c' for s in survived]

# Attacks that raised an exception in Block 11 are stored with corrupted == -1.
# Draw them at full height (every character lost) so they do not disappear
# below the y=0 axis limit. Their bars are labelled 'ERR' below.
bar_heights = [r['total'] if r['corrupted'] < 0 else r['corrupted'] for r in results]

fig, axes = plt.subplots(2, 1, figsize=(14, 10))

# Bar chart — corrupted characters
bars = axes[0].bar(names, bar_heights, color=colors, edgecolor='black', linewidth=0.5)
axes[0].set_title('Characters Corrupted per Attack', fontsize=14, fontweight='bold')
axes[0].set_ylabel('Corrupted Characters')
axes[0].set_xticks(range(len(names)))
axes[0].set_xticklabels(names, rotation=45, ha='right', fontsize=9)
axes[0].axhline(y=0, color='black', linewidth=0.8)
# evaluate() also counts surplus extracted characters, so a count can exceed
# the message length. Size the axis to the tallest bar so none is clipped.
axes[0].set_ylim(0, max([len(SECRET_MESSAGE), *bar_heights]) + 5)

# Add value labels on bars ('ERR' for attacks that raised, nothing for survivors)
for bar, r in zip(bars, results):
    if r['corrupted'] < 0:
        label = 'ERR'
    elif r['corrupted'] > 0:
        label = str(r['corrupted'])
    else:
        continue
    axes[0].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,
                 label, ha='center', va='bottom', fontsize=8)

# Legend
green_patch = mpatches.Patch(color='#2ecc71', label='Message survived')
red_patch   = mpatches.Patch(color='#e74c3c', label='Message corrupted')
axes[0].legend(handles=[green_patch, red_patch])

# Bar chart — PSNR values
psnr_colors = ['#3498db' if p >= 30 else '#e67e22' if p >= 20 else '#e74c3c'
               for p in psnr_vals]
bars2 = axes[1].bar(names, psnr_vals, color=psnr_colors, edgecolor='black', linewidth=0.5)
axes[1].set_title('PSNR After Each Attack (higher = less visual damage)', fontsize=14, fontweight='bold')
axes[1].set_ylabel('PSNR (dB)')
axes[1].set_xticks(range(len(names)))
axes[1].set_xticklabels(names, rotation=45, ha='right', fontsize=9)
axes[1].axhline(y=40, color='green', linewidth=1.5, linestyle='--', label='40 dB (imperceptible threshold)')
axes[1].axhline(y=30, color='orange', linewidth=1.5, linestyle='--', label='30 dB (slight distortion)')
axes[1].legend(fontsize=9)

# Add PSNR value labels (the stored string, so capped/errored bars read 'inf'/'N/A')
for bar, r in zip(bars2, results):
    axes[1].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,
                 r['psnr_str'], ha='center', va='bottom', fontsize=7)

plt.tight_layout(pad=3.0)
plt.savefig('attack_results.png', dpi=150, bbox_inches='tight')
plt.show()
print('Chart saved as attack_results.png')

# --- Summary table ---
print('\n===== FULL RESULTS SUMMARY =====')
print(f'{"Attack":<25} {"Survived":<10} {"Corrupted":<15} {"PSNR":<10} {"Extracted message"}')
print('=' * 100)
for r in results:
    status  = 'YES' if r['survived'] else 'NO'
    corrupt = f"{r['corrupted']}/{r['total']}" if r['corrupted'] >= 0 else 'ERROR'
    # Treat a missing extraction (None) as empty so len()/slicing cannot fail
    text    = r['extracted'] or ''
    preview = text[:40] + '...' if len(text) > 40 else text
    print(f"{r['name']:<25} {status:<10} {corrupt:<15} {r['psnr_str']:<10} {repr(preview)}")

# --- What each attack means ---
# The Resizing and Rotation notes describe the attacks as implemented in
# Block 10: both restore the original size/orientation, so the 8x8 grid stays
# aligned and the damage comes from resampling rather than misalignment.
print('''
===== WHAT THE RESULTS MEAN =====

JPEG Compression:
  JPEG re-quantizes DCT blocks — the most destructive attack for DCT steganography.
  Our method embeds using coefficient COMPARISON (A vs B), not exact values,
  so it has some resistance but JPEG at low quality will likely still break it.

Gaussian Noise:
  Random pixel changes. Low noise (std=5) may survive because our DELTA gap
  is larger than the noise. High noise (std=30) flips too many coefficients.

Resizing:
  The image is shrunk and scaled back to its original size, so the 8x8 block
  grid stays aligned, but the two interpolation passes smooth the pixel values
  (and with them the DCT coefficients). Smaller scales (e.g. 0.5) discard more
  detail and are more likely to break the message.

Rotation:
  The image is rotated and then rotated back, so block alignment is restored,
  but two rounds of bilinear resampling blur the content and the corners that
  left the frame are refilled by reflection. Larger angles lose more corner content.

Histogram Equalization:
  Stretches pixel intensities — massively changes the Y channel we embed into.
  Almost always destroys the hidden message.

Median Filter:
  Smoothing filter removes high and mid frequency changes — our embedded
  coefficients are mid-frequency so small kernels (3x3) may partially survive.
''')