In [None]:
import os
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
from tqdm import tqdm

# Output directories (relative paths, resolved against the working directory)
EMBED_OUTPUT_DIR = 'output/stegnographed'
DECODE_OUTPUT_DIR = 'output/stegnograph_decoded'

# Device selection: CUDA → MPS → CPU
def get_device():
    """Pick the compute device, preferring CUDA, then Apple MPS, then CPU.

    ``torch.backends.mps`` is looked up defensively: it does not exist in
    older PyTorch builds, where accessing it directly would raise
    AttributeError at import time instead of falling back to the CPU.

    Returns:
        A ``torch.device`` of type ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()
print(f"🖥️ Using device: {device}")

def message_to_bits(message):
    """Encode *message* as a string of '0'/'1' characters, 8 bits per UTF-8 byte.

    UTF-8 is used so every Unicode character (including emoji) round-trips.
    A per-character ``ord(c):08b`` encoding emits more than 8 bits for code
    points above 255, which a fixed 8-bit decoder cannot split back apart.
    ASCII text encodes to exactly the same bits as before.
    """
    return ''.join(f'{byte:08b}' for byte in message.encode('utf-8'))

def bits_to_message(bits):
    """Decode a '0'/'1' string produced by ``message_to_bits``.

    Bits are grouped into bytes, 8 at a time. A trailing group shorter than
    8 bits is interpreted as a plain binary number, as before. The bytes are
    decoded as UTF-8. If that fails, they are decoded as Latin-1, so payloads
    written by the older one-character-per-byte encoder (code points <= 255)
    still decode to the original text.
    """
    payload = bytes(int(bits[start:start + 8], 2) for start in range(0, len(bits), 8))
    try:
        return payload.decode('utf-8')
    except UnicodeDecodeError:
        return payload.decode('latin-1')

def get_texture_map(gray_tensor):
    """Return a per-pixel local-texture score, flattened in row-major order.

    The score is |4-neighbour Laplacian| of the grayscale image (zero
    padding), averaged over a 7x7 window (zero padding, always divided by
    49), on the 0..1 intensity scale. This is the same quantity as a
    conv2d(padding=1) followed by avg_pool2d(kernel_size=7, padding=3).

    Every intermediate is an integer held in float32 and built with plain
    element-wise adds, so all sums are exact (< 2**24). They cannot depend
    on which backend or convolution algorithm runs them. Callers sort this
    map to choose pixels, so backend-dependent rounding would otherwise
    reorder pixels between embedding (e.g. on CUDA) and extraction (e.g. on
    MPS). Only the final rescale rounds, and it maps equal sums to equal
    scores.

    Args:
        gray_tensor: 2-D (H, W) tensor of integer intensities — presumably
            uint8 in 0..255 as built by the callers in this file.

    Returns:
        1-D float32 tensor of length H*W on ``device``.
    """
    F = torch.nn.functional
    gray = gray_tensor.to(device).float()
    height, width = gray.shape

    # 4-neighbour Laplacian with zero padding (same as conv2d with padding=1).
    padded = F.pad(gray, (1, 1, 1, 1))
    laplacian = (padded[:-2, 1:-1] + padded[2:, 1:-1]
                 + padded[1:-1, :-2] + padded[1:-1, 2:] - 4 * gray)
    texture = laplacian.abs()

    # 7x7 box sum with zero padding, done separably: vertical sums first,
    # then horizontal sums over those.
    radius = 3
    window = 2 * radius + 1
    padded = F.pad(texture, (radius, radius, radius, radius))
    vertical = sum(padded[i:i + height, :] for i in range(window))
    box_sum = sum(vertical[:, j:j + width] for j in range(window))

    # Single rescale to the 0..1 intensity scale and a 7x7 average.
    return (box_sum / (255.0 * window * window)).flatten()

_HEADER_BITS = 32  # payload length in bits, written most significant bit first

def _embedding_order(img_np):
    """Return flat pixel indices (1-D int64 numpy array) in bit-storage order.

    Both embedding and extraction call this, so it must give the same order
    for an image before and after embedding. It guarantees this in two ways:

    * The blue channel's least significant bit is cleared before computing
      luminance. That is the only bit embedding changes. Without the mask, a
      flipped LSB can change a pixel's gray value, and through it the texture
      map and the sort order.
    * Luminance uses integer arithmetic, and the sort is a stable CPU sort.
      Ties, which are very common in flat regions, are then broken by pixel
      index instead of by a device-specific sort algorithm.
    """
    rgb = np.asarray(img_np, dtype=np.int32)
    blue = rgb[..., 2] & 0xFE
    # Integer form of 0.2989 R + 0.5870 G + 0.1140 B, truncated (max 254).
    gray = (2989 * rgb[..., 0] + 5870 * rgb[..., 1] + 1140 * blue) // 10000
    texture_map = get_texture_map(torch.from_numpy(gray.astype(np.uint8)))
    # NOTE(review): ascending order fills the smoothest pixels first; kept as
    # in the original implementation.
    return torch.sort(texture_map.cpu(), stable=True).indices.numpy()

def embed_message(img_np, message):
    """Hide *message* in the blue-channel LSBs of an RGB image.

    The payload is a 32-bit header holding the message length in bits,
    followed by the bits from ``message_to_bits``. Payload bit k is written
    into the blue LSB of the k-th pixel returned by ``_embedding_order``.

    Args:
        img_np: array of shape (H, W, 3); values are cast to uint8.
        message: text to hide.

    Returns:
        A new uint8 array of the same shape with the message embedded. Returns
        ``None`` if the image has fewer pixels than payload bits, or if the
        message is too long for the 32-bit header. ``img_np`` is not modified.
    """
    binary_message = message_to_bits(message)
    if len(binary_message) >= 1 << _HEADER_BITS:
        return None  # length would not fit in the fixed-width header
    final_bits = f'{len(binary_message):0{_HEADER_BITS}b}' + binary_message

    stego = np.array(img_np, dtype=np.uint8)  # copy; the caller's array stays intact
    flat = stego.reshape(-1, 3)
    if len(final_bits) > flat.shape[0]:
        return None

    targets = _embedding_order(img_np)[:len(final_bits)]
    bits = np.frombuffer(final_bits.encode('ascii'), dtype=np.uint8) - ord('0')
    # One vectorized write instead of a per-bit loop over device tensors.
    flat[targets, 2] = (flat[targets, 2] & 0xFE) | bits
    return stego

def extract_message(img_np):
    """Recover the message hidden by ``embed_message``.

    Args:
        img_np: array of shape (H, W, 3) holding the stego image.

    Returns:
        The decoded text.

    Raises:
        ValueError: the image is too small for a header, or the header claims
            a length that cannot be valid. This is typical for images that do
            not carry a message, and lets the caller fail fast instead of
            scanning the whole image.
    """
    pixels = np.asarray(img_np, dtype=np.uint8).reshape(-1, 3)
    order = _embedding_order(img_np)
    capacity = order.size - _HEADER_BITS
    if capacity < 0:
        raise ValueError("image too small to hold a message header")

    header = pixels[order[:_HEADER_BITS], 2] & 1
    message_len = int(''.join(map(str, header.tolist())), 2)
    # Embedding always writes whole bytes that fit in the image.
    if message_len > capacity or message_len % 8:
        raise ValueError(
            f"no valid message (header claims {message_len} bits, capacity {capacity})"
        )

    payload = pixels[order[_HEADER_BITS:_HEADER_BITS + message_len], 2] & 1
    bits = (payload + ord('0')).astype(np.uint8).tobytes().decode('ascii')
    return bits_to_message(bits)

def batch_embed(input_folder, message):
    """Embed *message* into every supported image in *input_folder*.

    Stego images are written to ``EMBED_OUTPUT_DIR``. JPEG inputs are saved
    as ``<stem>.png``, because JPEG's lossy re-compression scrambles the
    least-significant bits that carry the message. PNG and BMP inputs keep
    their file name. Images that cannot be processed are counted as failed,
    the reason is printed, and the batch continues. Ctrl-C still interrupts
    the run.
    """
    os.makedirs(EMBED_OUTPUT_DIR, exist_ok=True)
    supported_ext = ('.png', '.jpg', '.jpeg', '.bmp')
    lossy_ext = ('.jpg', '.jpeg')
    files = [f for f in os.listdir(input_folder) if f.lower().endswith(supported_ext)]

    total, success, failed = 0, 0, 0

    print(f"\n🚀 Embedding message into {len(files)} images...")
    for filename in tqdm(files, desc="Embedding"):
        total += 1
        input_path = os.path.join(input_folder, filename)
        stem, ext = os.path.splitext(filename)
        out_name = f"{stem}.png" if ext.lower() in lossy_ext else filename
        try:
            with Image.open(input_path) as img:
                img_np = np.array(img.convert('RGB'))
            stego_np = embed_message(img_np, message)
            if stego_np is None:
                failed += 1
                tqdm.write(f"⚠️ {filename}: image too small for this message")
                continue
            Image.fromarray(stego_np.astype(np.uint8)).save(os.path.join(EMBED_OUTPUT_DIR, out_name))
        except Exception as exc:  # best-effort batch: report and move on
            failed += 1
            tqdm.write(f"⚠️ {filename}: {exc}")
            continue
        success += 1

    print("\n📊 Embed Summary")
    print(f"  Total:   {total}")
    print(f"  Success: {success}")
    print(f"  Failed:  {failed}")
    print(f"  Output:  {EMBED_OUTPUT_DIR}")

def batch_decode(input_folder):
    """Extract hidden messages from every supported image in *input_folder*.

    Each decoded message is written as UTF-8 to
    ``DECODE_OUTPUT_DIR/<stem>.txt``. An image that fails (unreadable, or
    ``extract_message`` raises) is counted as failed, the reason is printed,
    and the batch continues. Ctrl-C still interrupts the run.
    """
    os.makedirs(DECODE_OUTPUT_DIR, exist_ok=True)
    supported_ext = ('.png', '.jpg', '.jpeg', '.bmp')
    files = [f for f in os.listdir(input_folder) if f.lower().endswith(supported_ext)]

    total, success, failed = 0, 0, 0

    print(f"\n🔍 Decoding messages from {len(files)} images...")
    for filename in tqdm(files, desc="Decoding"):
        total += 1
        input_path = os.path.join(input_folder, filename)
        try:
            with Image.open(input_path) as img:
                img_np = np.array(img.convert('RGB'))
            message = extract_message(img_np)
            out_txt = os.path.join(DECODE_OUTPUT_DIR, f"{os.path.splitext(filename)[0]}.txt")
            with open(out_txt, 'w', encoding='utf-8') as f:
                f.write(message)
        except Exception as exc:  # best-effort batch: report and move on
            failed += 1
            tqdm.write(f"⚠️ {filename}: {exc}")
            continue
        success += 1

    print("\n📊 Decode Summary")
    print(f"  Total:   {total}")
    print(f"  Success: {success}")
    print(f"  Failed:  {failed}")
    print(f"  Output:  {DECODE_OUTPUT_DIR}")

if __name__ == "__main__":
    # Interactive entry point: ask for a mode, then run the matching batch job.
    print("🧠 GPU Steganography Tool")

    choice = input("Do you want to (e)mbed or (d)ecode? ").strip().lower()
    if choice not in ('e', 'd'):
        print("❌ Invalid option. Please enter 'e' or 'd'.")
    elif choice == 'e':
        source_dir = input("Enter folder path containing original images: ").strip()
        secret = input("Enter the message to embed (emojis supported): ").strip()
        batch_embed(source_dir, secret)
    else:
        source_dir = input("Enter folder path containing stego images: ").strip()
        batch_decode(source_dir)

🖥️ Using device: mps
🧠 GPU Steganography Tool

🔍 Decoding messages from 3289 images...


Decoding:   0%|          | 3/3289 [01:30<24:03:47, 26.36s/it]