# Rewritten Blocks 3, 4, 5 — True DCT Steganography (JSteg)
Replace your existing Blocks 3, 4, and 5 with these cells (Block 5b at the end is an additional JPEG-survival test).
After replacing, run: Block 3 → Block 4 → Block 5 → Block 7 → Block 5b → Block 9

In [None]:
# ============================================================
# BLOCK 3 (JSteg): JPEG quantization table + helpers
# ============================================================
import numpy as np
import cv2
from scipy.fftpack import dct, idct

# Standard JPEG luminance quantization table (ITU-T T.81, Annex K).
# Each value is the quantization step for that DCT position.
# NOTE: encoders use these steps unchanged only at their "quality 50"
# setting; libjpeg-style encoders scale the whole table for other quality
# values, so at those qualities JPEG re-quantizes with different steps than
# the ones the embedder assumes here.
JPEG_QUANT_TABLE = np.array([
    [16, 11, 10, 16, 24,  40,  51,  61],
    [12, 12, 14, 19, 26,  58,  60,  55],
    [14, 13, 16, 24, 40,  57,  69,  56],
    [14, 17, 22, 29, 51,  87,  80,  62],
    [18, 22, 37, 56, 68,  109, 103, 77],
    [24, 35, 55, 64, 81,  104, 113, 92],
    [49, 64, 78, 87, 103, 121, 120, 101],
    [72, 92, 95, 98, 112, 100, 103, 99]
], dtype=np.float64)

# Inclusive bounds on the quantization step of the positions used to carry
# payload bits.
# - EMBED_Q_MAX: the embedder moves a coefficient by up to 2.5 quantization
#   steps (see embed_bit_jsteg), so large steps mean large pixel changes;
#   the coarse high-frequency positions are therefore skipped.
# - EMBED_Q_MIN: equals the smallest entry in the table, so it currently
#   excludes nothing; it is kept so the range can be tightened.
EMBED_Q_MIN = 10
EMBED_Q_MAX = 50

# Positions (row, col) inside each 8x8 DCT block that carry one bit each,
# in row-major order. The DC coefficient (0, 0) is excluded explicitly: it
# sets the block's mean brightness, so changing it is very visible.
EMBED_POSITIONS = [
    (r, c)
    for r in range(8)
    for c in range(8)
    if (r, c) != (0, 0)
    and EMBED_Q_MIN <= JPEG_QUANT_TABLE[r, c] <= EMBED_Q_MAX
]
print(f'Embedding positions per 8x8 block: {len(EMBED_POSITIONS)}')
print(f'Positions: {EMBED_POSITIONS}')


def text_to_bits(text):
    """
    Convert text to a flat list of bits, 8 bits per character, MSB first.

    Each character is encoded as its code point in a single byte, so only
    code points 0-255 can be represented. Anything larger used to be
    silently truncated to its low 8 bits (corrupting the hidden message);
    it now raises instead.

    Parameters
    ----------
    text : str
        Text to encode.

    Returns
    -------
    list[int]
        ``8 * len(text)`` values, each 0 or 1.

    Raises
    ------
    ValueError
        If a character's code point is above 255.
    """
    bits = []
    for char in text:
        code = ord(char)
        if code > 0xFF:
            raise ValueError(
                f'Character {char!r} (U+{code:04X}) does not fit in one byte; '
                'only code points 0-255 are supported.'
            )
        bits.extend((code >> shift) & 1 for shift in range(7, -1, -1))
    return bits


def bits_to_text(bits):
    """
    Decode a flat MSB-first bit sequence into text, one character per byte.

    Every complete group of 8 bits becomes ``chr(value)``. A trailing group
    shorter than 8 bits is ignored.
    """
    whole_bytes_len = len(bits) - len(bits) % 8
    return ''.join(
        chr(int(''.join(map(str, bits[start:start + 8])), 2))
        for start in range(0, whole_bytes_len, 8)
    )


def apply_dct_block(block):
    """Return the orthonormal 2-D DCT-II of a 2-D *block* (columns, then rows)."""
    # dct() works along the last axis, so transposing first transforms columns.
    columns_done = dct(block.T, norm='ortho').T
    return dct(columns_done, norm='ortho')


def apply_idct_block(block):
    """Return the orthonormal 2-D inverse DCT of a 2-D *block* (columns, then rows)."""
    # idct() works along the last axis, so transposing first transforms columns.
    columns_done = idct(block.T, norm='ortho').T
    return idct(columns_done, norm='ortho')


def embed_bit_jsteg(coeff, bit, Q):
    """
    Embed one bit into a single DCT coefficient via the parity of its
    quantization index (JSteg-style).

    Steps:
    1. Round ``coeff / Q`` to the nearest integer index.
    2. Replace an index of 0 with 1, so the result is never zero.
    3. If the index parity does not match ``bit`` (odd -> 1, even -> 0),
       move the index one step further from zero. Because the index is
       already non-zero, this can never produce 0.
    4. Return ``index * Q``.

    The coefficient therefore moves by at most 2.5 * Q: 0.5 * Q from
    rounding, plus up to 2 * Q when a near-zero coefficient is pushed to an
    index of 2.

    Parameters
    ----------
    coeff : float
        DCT coefficient to modify.
    bit : int
        Payload bit, 0 or 1.
    Q : float
        Quantization step assumed for this coefficient's position.

    Returns
    -------
    float
        New coefficient: a non-zero integer multiple of ``Q`` whose index
        parity equals ``bit``.

    Note: the parity survives JPEG re-quantization only if the encoder
    quantizes this position with the same step ``Q``. For the standard
    table that means quality 50; other qualities use scaled steps, so
    survival is not guaranteed.
    """
    index = int(round(coeff / Q))

    # Avoid zero — JPEG often discards zero coefficients entirely
    if index == 0:
        index = 1

    if abs(index) % 2 != bit:
        # Flip parity by stepping away from zero (index is non-zero here).
        index += 1 if index > 0 else -1

    return float(index * Q)


def extract_bit_jsteg(coeff, Q):
    """
    Recover the payload bit stored in one DCT coefficient.

    The coefficient is divided by ``Q`` and rounded to the nearest integer
    index. The bit is the parity of that index: odd -> 1, even -> 0. An
    index of 0 is even and therefore reads as 0.
    """
    index = int(round(coeff / Q))
    return abs(index) % 2


print('Block 3 JSteg ready.')

In [None]:
# ============================================================
# BLOCK 4 (JSteg): Embed message
# ============================================================

def embed_message(image_path, message, output_path):
    """
    Hide *message* in the luminance (Y) DCT coefficients of an image.

    The image is cropped to a multiple of 8 in both dimensions and converted
    to YCrCb. The 8x8 Y blocks are visited in row-major order. In each block
    the coefficients at EMBED_POSITIONS are rewritten with embed_bit_jsteg
    until every bit of the message plus a trailing NUL terminator has been
    placed. Cr and Cb are left untouched. The result is converted back to BGR,
    written with cv2.imwrite and returned.

    Parameters
    ----------
    image_path : str
        Path of the cover image.
    message : str
        Text to hide. Must not contain NUL, which marks the end of the
        message. Characters are encoded one byte each by text_to_bits.
    output_path : str
        Destination for the stego image. NOTE(review): presumably this
        should be a lossless format such as PNG. A lossy format would
        re-quantize the coefficients before extraction — confirm.

    Returns
    -------
    numpy.ndarray
        The stego image in BGR order as uint8, at the cropped size.

    Raises
    ------
    FileNotFoundError
        If *image_path* cannot be read.
    ValueError
        If *message* contains NUL, or does not fit in the image.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f'Cannot open: {image_path}')

    img = np.clip(img, 0, 255).astype(np.uint8)

    # Crop to multiple of 8 — block alignment is critical
    h, w = img.shape[:2]
    new_h = (h // 8) * 8
    new_w = (w // 8) * 8
    img = img[0:new_h, 0:new_w]
    print(f'Cropped: {w}x{h} -> {new_w}x{new_h}')

    # Embed in the Y (luminance) channel only; chroma is passed through as-is.
    img_ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
    Y, Cr, Cb = cv2.split(img_ycrcb)
    Cr = Cr.astype(np.uint8)
    Cb = Cb.astype(np.uint8)
    Y = Y.astype(np.float64)

    # The extractor stops at the first all-zero byte, so an embedded NUL
    # would silently truncate the recovered message.
    if '\x00' in message:
        raise ValueError('Message must not contain NUL characters: NUL marks the end of the hidden message.')

    # Prepare message bits with null terminator
    full_message = message + '\x00'
    bits = text_to_bits(full_message)
    total_bits = len(bits)

    height, width = Y.shape
    blocks_v = height // 8
    blocks_h = width // 8
    total_capacity = blocks_v * blocks_h * len(EMBED_POSITIONS)

    print(f'Total blocks      : {blocks_v * blocks_h}')
    print(f'Bits per block    : {len(EMBED_POSITIONS)}')
    print(f'Total capacity    : {total_capacity} bits = ~{total_capacity // 8} chars')
    print(f'Bits to embed     : {total_bits}')

    if total_bits > total_capacity:
        raise ValueError(f'Message too long! Max ~{total_capacity // 8} characters.')

    bit_index = 0
    for row in range(blocks_v):
        if bit_index >= total_bits:
            break
        for col in range(blocks_h):
            if bit_index >= total_bits:
                break

            r = row * 8
            c = col * 8
            dct_block = apply_dct_block(Y[r:r+8, c:c+8].copy())

            # Reaching this point means at least one bit remains, so every
            # visited block is modified and must be written back.
            for (pr, pc) in EMBED_POSITIONS:
                if bit_index >= total_bits:
                    break
                Q = JPEG_QUANT_TABLE[pr, pc]
                dct_block[pr, pc] = embed_bit_jsteg(dct_block[pr, pc], bits[bit_index], Q)
                bit_index += 1

            Y[r:r+8, c:c+8] = apply_idct_block(dct_block)

    # Round to the nearest integer before the uint8 cast. A bare astype()
    # truncates, which doubles the per-pixel error (up to 1 instead of 0.5)
    # and makes it more likely that a coefficient's parity flips on re-read.
    Y_uint8 = np.clip(np.rint(Y), 0, 255).astype(np.uint8)
    stego_ycrcb = cv2.merge([Y_uint8, Cr, Cb])
    stego_bgr   = cv2.cvtColor(stego_ycrcb, cv2.COLOR_YCrCb2BGR)

    ok = cv2.imwrite(output_path, stego_bgr)
    print(f'Saved: {output_path} | Success: {ok}')
    return stego_bgr


print('embed_message() JSteg defined.')

In [None]:
# ============================================================
# BLOCK 5 (JSteg): Extract message
# ============================================================

def extract_message(stego_path):
    """
    Recover a hidden message from a stego image.

    The Y channel is split into 8x8 blocks and scanned in row-major order.
    One bit is read from each EMBED_POSITIONS coefficient with
    extract_bit_jsteg. Reading stops at the first byte-aligned all-zero byte
    (the NUL terminator), and the text before it is returned. If no
    terminator is found, every collected bit is decoded. The image may be a
    lossless stego file or a JPEG-recompressed copy; whether the bits survive
    recompression depends on the JPEG quality.

    Raises
    ------
    FileNotFoundError
        If *stego_path* cannot be read.
    """
    img = cv2.imread(stego_path)
    if img is None:
        raise FileNotFoundError(f'Cannot open: {stego_path}')

    luma = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb))[0].astype(np.float64)
    full_rows = (luma.shape[0] // 8) * 8
    full_cols = (luma.shape[1] // 8) * 8

    bits = []
    for top in range(0, full_rows, 8):
        for left in range(0, full_cols, 8):
            coeffs = apply_dct_block(luma[top:top + 8, left:left + 8])
            for pr, pc in EMBED_POSITIONS:
                bits.append(extract_bit_jsteg(coeffs[pr, pc], JPEG_QUANT_TABLE[pr, pc]))
                # Only check complete bytes; an all-zero byte is the terminator.
                if len(bits) % 8:
                    continue
                if not any(bits[-8:]):
                    return bits_to_text(bits[:-8])

    return bits_to_text(bits)


print('extract_message() JSteg defined.')

In [None]:
# ============================================================
# BLOCK 5b: JPEG survival test
# Run this AFTER Block 7 (embedding) to verify JPEG survival
# ============================================================
import os
import tempfile

print('Testing JPEG survival...\n')
print(f'Original: "{SECRET_MESSAGE}"\n')
print(f'{"Quality":<12} {"Survived":<12} {"Corrupted":<15} {"Extracted"}')
print('-' * 80)

# Every quality level recompresses the same stego pixels, so read them once.
img = cv2.imread(STEGO_IMAGE_PATH)
if img is None:
    raise FileNotFoundError(f'Cannot open: {STEGO_IMAGE_PATH}')

# Scratch files live in a private temporary directory. It is removed even if
# extraction raises part-way through the sweep, and it never clobbers files
# in the working directory.
with tempfile.TemporaryDirectory() as temp_dir:
    for quality in [90, 75, 50, 30]:
        temp_jpg = os.path.join(temp_dir, f'test_q{quality}.jpg')
        temp_png = os.path.join(temp_dir, f'test_q{quality}.png')

        # Compress the stego image as JPEG, decode it, and store the decoded
        # pixels losslessly as PNG for extraction.
        if not cv2.imwrite(temp_jpg, img, [cv2.IMWRITE_JPEG_QUALITY, quality]):
            raise OSError(f'Failed to write {temp_jpg}')
        reloaded = cv2.imread(temp_jpg)
        if reloaded is None or not cv2.imwrite(temp_png, reloaded):
            raise OSError(f'Failed to round-trip {temp_jpg} to PNG')

        # Extract and compare character by character; length differences
        # count as corrupted characters too.
        extracted = extract_message(temp_png)
        survived  = (extracted == SECRET_MESSAGE)
        corrupted = sum(1 for a, b in zip(SECRET_MESSAGE, extracted) if a != b)
        corrupted += abs(len(SECRET_MESSAGE) - len(extracted))

        status   = 'YES ✓' if survived else 'NO  ✗'
        corrupt_str = f'{corrupted}/{len(SECRET_MESSAGE)}'
        preview  = extracted[:40] + '...' if len(extracted) > 40 else extracted
        print(f'Q={quality:<9} {status:<12} {corrupt_str:<15} "{preview}"')