# **Utils**

In [1]:
import cv2
import numpy as np
import torch

In [2]:
def extract_frames(video_path, max_frames=None, resize_dim=(128, 128)):
    """
    Extract frames from a video file, optionally resize them, and scale to [0, 1].

    Args:
        video_path: Anything accepted by ``cv2.VideoCapture`` (usually a path).
        max_frames: Read at most this many frames; ``None`` reads the whole video.
        resize_dim: Target size passed to ``cv2.resize`` as ``(width, height)``;
            ``None`` keeps the native resolution.

    Returns:
        A float32 tensor of shape (C, T, H, W) with values in [0, 1], channels in
        RGB order, where T is the number of frames read.

    Raises:
        ValueError: if no frame could be read (unreadable/empty video, or
            ``max_frames`` <= 0).
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        # Check the limit before reading so max_frames is honoured exactly
        # (including 0); checking after the append always kept one frame.
        while max_frames is None or len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert so the channel order is RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if resize_dim is not None:
                frame = cv2.resize(frame, resize_dim)
            frames.append(frame)
    finally:
        # Release the decoder even if a conversion above raises.
        cap.release()

    if not frames:
        raise ValueError(f"Could not extract any frames from {video_path}")

    # (T, H, W, C) uint8 -> float32 in [0, 1]; asarray with a dtype avoids an
    # intermediate uint8 copy.
    frames_np = np.asarray(frames, dtype=np.float32) / 255.0
    # Channel-first video layout (C, T, H, W).
    return torch.from_numpy(frames_np).permute(3, 0, 1, 2)

In [None]:
def compile_video(frames_tensor, output_path, fps=30):
    """
    Write a video file from a tensor of shape (C, T, H, W) with values in [0, 1].

    Args:
        frames_tensor: (C, T, H, W) tensor; C must be 3 (RGB) or 1 (grayscale).
            Values are scaled by 255, rounded and clipped to [0, 255].
        output_path: Destination file, encoded with the ``mp4v`` FourCC.
        fps: Frame rate written to the container.

    Raises:
        ValueError: if the tensor is not 4-D or has an unsupported channel count.
        OSError: if OpenCV cannot open a writer for ``output_path``.
    """
    if frames_tensor.dim() != 4:
        raise ValueError(
            f"Expected a (C, T, H, W) tensor, got shape {tuple(frames_tensor.shape)}"
        )
    # detach() is a no-op for tensors that do not require grad.
    frames_tensor = frames_tensor.detach().cpu()

    frames_np = frames_tensor.permute(1, 2, 3, 0).numpy()
    # Round instead of truncating: x / 255 * 255 in float32 can land just
    # below the original integer and would otherwise drop one level.
    frames_np = np.clip(np.rint(frames_np * 255.0), 0, 255).astype(np.uint8)

    T, H, W, C = frames_np.shape
    if C == 3:
        conversion = cv2.COLOR_RGB2BGR
    elif C == 1:
        conversion = cv2.COLOR_GRAY2BGR
    else:
        raise ValueError(f"Expected 1 or 3 channels, got {C}")

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
    try:
        # VideoWriter fails silently (writes nothing) when it cannot open.
        if not out.isOpened():
            raise OSError(f"Could not open video writer for {output_path}")
        for frame in frames_np:
            out.write(cv2.cvtColor(np.ascontiguousarray(frame), conversion))
    finally:
        out.release()

# **Video Autoencoder**


In [None]:
import torch
import torch.nn as nn

In [None]:
class VideoEncoder(nn.Module):
    """
    3-D convolutional encoder that maps a video clip to a flat latent vector.

    Four stride-2 Conv3d stages (channels in_channels -> 32 -> 64 -> 128 -> 256),
    each followed by BatchNorm3d and ReLU, roughly halve T, H and W per stage.
    The final Linear layer expects exactly 4096 = 256 * 1 * 4 * 4 flattened
    features, which is what a (in_channels, 16, 64, 64) clip produces; other
    clip sizes only work if they also flatten to 4096 features.
    """

    def __init__(self, in_channels=3, latent_dim=256):
        super().__init__()
        widths = [in_channels, 32, 64, 128, 256]
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.Conv3d(c_in, c_out, 3, 2, 1),
                nn.BatchNorm3d(c_out),
                nn.ReLU(),
            ]
        self.encoder = nn.Sequential(*stages)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(4096, latent_dim)

    def forward(self, x):
        """Encode a (B, C, T, H, W) batch into (B, latent_dim) latents."""
        features = self.encoder(x)
        flat = self.flatten(features)
        return self.fc(flat)

In [None]:
class VideoDecoder(nn.Module):
    """
    Mirror of ``VideoEncoder``: maps a latent vector back to a video clip.

    A Linear layer expands the latent to 4096 values, reshaped to a
    (256, 1, 4, 4) grid. Four ConvTranspose3d stages (kernel 3, stride 2,
    padding 1, output_padding 1) double T, H and W each, giving
    (out_channels, 16, 64, 64). A final Sigmoid keeps outputs in [0, 1].
    """

    def __init__(self, out_channels=3, latent_dim=256):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 4096)
        layers = []
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            layers.extend(
                (
                    nn.ConvTranspose3d(c_in, c_out, 3, 2, 1, 1),
                    nn.BatchNorm3d(c_out),
                    nn.ReLU(),
                )
            )
        # Last upsampling stage has no BatchNorm; Sigmoid bounds the pixels.
        layers.append(nn.ConvTranspose3d(32, out_channels, 3, 2, 1, 1))
        layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode (B, latent_dim) latents into (B, out_channels, 16, 64, 64) clips."""
        seed = self.fc(x)
        grid = seed.view(-1, 256, 1, 4, 4)
        return self.decoder(grid)

In [None]:
class VideoAutoencoder(nn.Module):
    """Encoder/decoder pair that reconstructs a clip through a flat latent vector."""

    def __init__(self, in_channels=3, latent_dim=256):
        super().__init__()
        # The decoder emits as many channels as the encoder consumes.
        self.encoder = VideoEncoder(in_channels, latent_dim)
        self.decoder = VideoDecoder(in_channels, latent_dim)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for a batch of clips ``x``."""
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon, z

# **Encryption**

In [None]:
import os
import torch
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

In [None]:
def generate_key():
    """Return a fresh random 32-byte (256-bit) AES key from ``os.urandom``."""
    key_size_bytes = 32
    return os.urandom(key_size_bytes)

In [None]:
def generate_iv():
    """Return a fresh random 12-byte (96-bit) IV, the length commonly used with GCM."""
    iv_size_bytes = 12
    return os.urandom(iv_size_bytes)

In [None]:
def tensor_to_bytes(tensor):
    """
    Serialize a tensor to raw float32 bytes.

    Returns:
        ``(data, shape)``: ``data`` is the C-order float32 byte string and
        ``shape`` is the shape tuple that ``bytes_to_tensor`` needs to rebuild it.

    Any dtype torch can cast to float32 is accepted, including ``bfloat16``.
    Numpy cannot represent bfloat16, so the cast is done on the torch side
    before converting.
    """
    tensor_np = tensor.detach().to(device="cpu", dtype=torch.float32).numpy()
    # tobytes() defaults to C order, so non-contiguous inputs serialize correctly.
    return tensor_np.tobytes(), tensor_np.shape

In [None]:
def bytes_to_tensor(byte_data, shape, device="cpu"):
    """
    Inverse of ``tensor_to_bytes``: rebuild a float32 tensor from raw bytes.

    Args:
        byte_data: C-order float32 bytes.
        shape: Target shape; its element count must match the byte length / 4.
        device: Device the resulting tensor is moved to.
    """
    import numpy as np

    # frombuffer gives a read-only view over ``byte_data``; copy it so the
    # tensor owns writable memory.
    flat = np.frombuffer(byte_data, dtype=np.float32)
    owned = np.array(flat, copy=True).reshape(shape)
    return torch.from_numpy(owned).to(device)

In [None]:
def encrypt_data(data_bytes, key, iv):
    """
    AES-GCM encrypt ``data_bytes`` under ``key`` with nonce ``iv``.

    Returns:
        ``(ciphertext, tag)``; the authentication tag must be passed to
        ``decrypt_data`` together with the same key and IV.
    """
    cipher = Cipher(algorithms.AES(key), modes.GCM(iv), backend=default_backend())
    encryptor = cipher.encryptor()
    body = encryptor.update(data_bytes)
    body += encryptor.finalize()
    # The GCM tag is only available after finalize().
    return body, encryptor.tag



In [None]:
def decrypt_data(ciphertext, tag, key, iv):
    """
    AES-GCM decrypt ``ciphertext`` and verify it against ``tag``.

    ``finalize()`` performs the tag check; with the ``cryptography`` package a
    mismatch raises ``cryptography.exceptions.InvalidTag``.
    """
    cipher = Cipher(algorithms.AES(key), modes.GCM(iv, tag), backend=default_backend())
    decryptor = cipher.decryptor()
    plaintext = decryptor.update(ciphertext)
    plaintext += decryptor.finalize()
    return plaintext

In [None]:
class LatentEncryptor:
    """
    AES-GCM wrapper for latent tensors, holding one key per instance.

    ``encrypt`` draws a fresh IV on every call and returns
    ``(ciphertext, metadata)``. ``metadata`` holds the IV, the GCM tag and the
    tensor shape, all of which ``decrypt`` needs. Keep ``self.key`` to decrypt
    later.
    """

    def __init__(self, key=None):
        """
        Args:
            key: 16-, 24- or 32-byte AES key; ``None`` generates a random
                32-byte key.

        Raises:
            ValueError: if ``key`` has a length AES does not accept.
        """
        # Test for None explicitly. A truthiness check would silently swap an
        # empty key for a random one, making the data undecryptable elsewhere.
        if key is None:
            key = generate_key()
        elif len(key) not in (16, 24, 32):
            raise ValueError(
                f"AES key must be 16, 24 or 32 bytes long, got {len(key)}"
            )
        self.key = key

    def encrypt(self, latent_tensor):
        """Encrypt a tensor; returns ``(ciphertext, {"iv", "tag", "shape"})``."""
        data_bytes, shape = tensor_to_bytes(latent_tensor)
        # Fresh IV per call: reusing a GCM nonce under one key is insecure.
        iv = generate_iv()
        ciphertext, tag = encrypt_data(data_bytes, self.key, iv)
        return ciphertext, {"iv": iv, "tag": tag, "shape": shape}

    def decrypt(self, ciphertext, metadata, device="cpu"):
        """Inverse of ``encrypt``; raises if the GCM tag does not verify."""
        plaintext = decrypt_data(ciphertext, metadata["tag"], self.key, metadata["iv"])
        return bytes_to_tensor(plaintext, metadata["shape"], device=device)

In [None]:
def ciphertext_to_bits(ciphertext, max_len=None):
    """
    Unpack bytes into a float32 tensor of 0/1 bits, most significant bit first.

    Args:
        ciphertext: Bytes-like object to unpack.
        max_len: If given, zero-pad the bit vector to exactly this length.
            It must be at least ``8 * len(ciphertext)``.

    Raises:
        ValueError: if ``max_len`` is smaller than the number of bits, which
            would otherwise truncate the payload.
    """
    import numpy as np

    byte_array = np.frombuffer(ciphertext, dtype=np.uint8)
    bit_array = np.unpackbits(byte_array).astype(np.float32)
    if max_len is not None:
        if max_len < bit_array.size:
            raise ValueError(
                f"max_len={max_len} is too small for {bit_array.size} bits "
                f"({byte_array.size} bytes); the payload would be truncated"
            )
        padded = np.zeros(max_len, dtype=np.float32)
        padded[: bit_array.size] = bit_array
        bit_array = padded
    return torch.from_numpy(bit_array)

In [None]:
def bits_to_ciphertext(bit_tensor, original_byte_len):
    """
    Threshold soft bits at 0.5 and pack the first ``original_byte_len * 8`` into bytes.

    The tensor is detached (so revealer outputs that require grad can be
    passed directly) and flattened in row-major order first, so a (1, N)
    prediction behaves like an (N,) vector. Bits are packed most significant
    bit first. If fewer than ``original_byte_len * 8`` bits are available,
    ``np.packbits`` zero-pads the final byte and the result is shorter.
    """
    import numpy as np

    flat_bits = bit_tensor.detach().cpu().reshape(-1).numpy()
    hard_bits = (flat_bits >= 0.5).astype(np.uint8)[: original_byte_len * 8]
    return np.packbits(hard_bits).tobytes()

# **Image Generation**

In [None]:
import torch
import numpy as np

In [None]:
class ImageGenerator:
    """
    Produces RGB cover images as float32 tensors of shape (3, H, W) in [0, 1].

    Uses smoothed random noise when ``use_dummy=True`` (the default) or when
    the ``diffusers`` import fails. Otherwise Stable Diffusion v1.5 is loaded
    through ``diffusers.StableDiffusionPipeline`` and moved to ``device``.
    """

    def __init__(self, device="cpu", use_dummy=True):
        self.device = device
        self.use_dummy = use_dummy
        self.pipeline = None

        if not self.use_dummy:
            # float16 halves memory on CUDA, but many CPU kernels have no
            # half-precision implementation, so use float32 off CUDA.
            on_cuda = torch.device(self.device).type == "cuda"
            dtype = torch.float16 if on_cuda else torch.float32
            try:
                from diffusers import StableDiffusionPipeline

                self.pipeline = StableDiffusionPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
                )
                self.pipeline = self.pipeline.to(self.device)
            except ImportError:
                print("Diffusers not installed. Falling back to dummy generator.")
                self.use_dummy = True

    def generate_cover(
        self,
        prompt="A beautiful realistic landscape photo, 4k resolution",
        size=(256, 256),
    ):
        """
        Return one cover image as a (3, size[0], size[1]) float tensor on ``self.device``.

        ``size`` is (height, width). ``prompt`` is only used by the diffusion path.
        """
        if self.use_dummy or self.pipeline is None:
            img_tensor = (
                torch.rand((3, size[0], size[1]), dtype=torch.float32)
                .to(self.device)
                .unsqueeze(0)
            )
            import torch.nn.functional as F

            # 5x5 box blur; stride 1 with padding 2 preserves H and W.
            img_tensor = F.avg_pool2d(
                img_tensor, kernel_size=5, stride=1, padding=2
            ).squeeze(0)
            # Min-max rescale to [0, 1]; the epsilon guards a constant image.
            return (img_tensor - img_tensor.min()) / (
                img_tensor.max() - img_tensor.min() + 1e-8
            )
        else:
            image = self.pipeline(
                prompt, height=size[0], width=size[1], num_inference_steps=20
            ).images[0]
            image_np = np.array(image).astype(np.float32) / 255.0
            # HWC image array -> CHW tensor.
            return torch.from_numpy(image_np).permute(2, 0, 1).to(self.device)




# **Stego Networks**

In [None]:
import torch
import torch.nn as nn

In [None]:
class HiderNetwork(nn.Module):
    """
    Embeds a secret plane into a cover image.

    The cover (B, cover_channels, H, W) and the secret (B, secret_channels, H, W)
    are concatenated on the channel axis. They pass through five 3x3 convolutions
    (padding 1, so H and W are preserved) with ReLU between them. A final
    Sigmoid keeps the stego image in [0, 1].
    """

    def __init__(self, cover_channels=3, secret_channels=1, hidden_channels=64):
        super().__init__()
        layers = [
            nn.Conv2d(cover_channels + secret_channels, hidden_channels, 3, padding=1),
            nn.ReLU(),
        ]
        for _ in range(3):
            layers += [
                nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1),
                nn.ReLU(),
            ]
        layers += [
            nn.Conv2d(hidden_channels, cover_channels, 3, padding=1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, cover, secret):
        """Return the stego image with the same shape as ``cover``."""
        stacked = torch.cat((cover, secret), dim=1)
        return self.net(stacked)

In [None]:
class RevealerNetwork(nn.Module):
    """
    Recovers the secret plane from a stego image.

    Five 3x3 convolutions (padding 1, so H and W are preserved) with ReLU
    between them map (B, stego_channels, H, W) to (B, secret_channels, H, W).
    A final Sigmoid yields per-pixel bit probabilities in [0, 1].
    """

    def __init__(self, stego_channels=3, secret_channels=1, hidden_channels=64):
        super().__init__()
        layers = [nn.Conv2d(stego_channels, hidden_channels, 3, padding=1), nn.ReLU()]
        for _ in range(3):
            layers += [
                nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1),
                nn.ReLU(),
            ]
        layers += [
            nn.Conv2d(hidden_channels, secret_channels, 3, padding=1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, stego):
        """Return the predicted secret plane for a batch of stego images."""
        revealed = self.net(stego)
        return revealed

In [None]:
def format_secret_for_hiding(secret_bits, target_shape):
    """
    Lay bit sequences out as a zero-padded spatial tensor of ``target_shape``.

    Args:
        secret_bits: 1-D sequence shared by every batch item, or a 2-D
            (rows, length) tensor where row i fills batch item i.
        target_shape: ``(B, C, H, W)``. Each item holds C*H*W values; longer
            sequences are truncated and shorter ones zero-padded.

    Returns:
        A float32 tensor of shape (B, C, H, W) on ``secret_bits.device``.
    """
    batch, channels, height, width = target_shape
    capacity = channels * height * width
    shared = secret_bits.dim() <= 1
    canvas = torch.zeros(batch, capacity, device=secret_bits.device)
    for row_idx in range(batch):
        row = secret_bits if shared else secret_bits[row_idx]
        n = min(len(row), capacity)
        canvas[row_idx, :n] = row[:n]
    return canvas.view(batch, channels, height, width)

In [None]:
def extract_secret_from_prediction(secret_pred_spatial, original_length):
    """
    Inverse of ``format_secret_for_hiding``: flatten each batch item and keep a prefix.

    Args:
        secret_pred_spatial: (B, C, H, W) revealer output (any layout).
        original_length: Number of leading values to keep per item.

    Returns:
        A (B, min(original_length, C*H*W)) tensor.
    """
    batch = secret_pred_spatial.shape[0]
    # reshape (not view) also accepts non-contiguous inputs, copying only when needed.
    return secret_pred_spatial.reshape(batch, -1)[:, :original_length]

# **Training**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os

In [None]:
from video_autoencoder import VideoAutoencoder
from stego_networks import (
    HiderNetwork,
    RevealerNetwork,
    format_secret_for_hiding,
    extract_secret_from_prediction,
)
from image_generator import ImageGenerator

In [None]:
class DummyVideoDataset(Dataset):
    """Synthetic dataset whose items are fresh uniform-noise clips of shape (3, frames, height, width)."""

    def __init__(self, num_samples=100, frames=16, height=64, width=64):
        self.num_samples, self.frames = num_samples, frames
        self.height, self.width = height, width

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # ``idx`` is ignored: every access draws a new random clip.
        clip_shape = (3, self.frames, self.height, self.width)
        return torch.rand(clip_shape, dtype=torch.float32)

In [None]:
import cv2
import glob
from torchvision import transforms
from PIL import Image

In [None]:
class RealVideoDataset(Dataset):
    """
    Loads .avi / .mp4 videos (searched recursively) for autoencoder training.

    Each item is a float tensor of shape (3, frames, height, width) holding the
    first ``frames`` frames in RGB, resized with torchvision and scaled to
    [0, 1] by ``ToTensor``. Short videos are padded by repeating their last
    frame. Videos that yield no frames at all return an all-zero clip.
    """

    def __init__(self, directory, frames=16, height=64, width=64):
        found = []
        for pattern in ("*.avi", "*.mp4"):
            found.extend(
                glob.glob(os.path.join(directory, "**", pattern), recursive=True)
            )
        # glob's order depends on the filesystem; sort so a given index maps
        # to the same file on every run and machine.
        self.video_paths = sorted(found)
        self.frames = frames
        self.height = height
        self.width = width
        # Per frame: HWC uint8 array -> PIL image -> resized -> CHW float tensor.
        self.transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((self.height, self.width)),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        cap = cv2.VideoCapture(self.video_paths[idx])
        frames = []
        try:
            while len(frames) < self.frames:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))  # (3, H, W)
        finally:
            # Release the decoder even if decoding or transforming a frame raises.
            cap.release()

        # Edge-case fallback: a video that yields no frames becomes all zeros.
        if not frames:
            return torch.zeros(
                (3, self.frames, self.height, self.width), dtype=torch.float32
            )

        # Pad short videos by repeating the last frame up to ``self.frames``.
        frames.extend([frames[-1]] * (self.frames - len(frames)))

        # Stack along a new time axis -> (C, T, H, W).
        return torch.stack(frames, dim=1)

In [None]:
def train_video_autoencoder(
    model,
    dataloader,
    epochs=5,
    device="cpu",
    lr=1e-3,
    save_path="../models/video_autoencoder.pth",
):
    """
    Train ``model`` to reconstruct its inputs (MSE + Adam), then save its state_dict.

    Args:
        model: Module whose forward returns ``(reconstruction, latent)``.
        dataloader: Any iterable of input batches, e.g. a ``DataLoader``.
        epochs: Number of passes over ``dataloader``.
        device: Device the model and batches are moved to.
        lr: Adam learning rate.
        save_path: Where the final state_dict is written; parent directories
            are created if missing.

    Returns:
        List with the mean per-batch loss of each epoch (``nan`` for an epoch
        that produced no batches).
    """
    print("--- Training Video Autoencoder ---")
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            reconstructed, _ = model(batch)
            loss = criterion(reconstructed, batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1
        # Count batches directly: supports loaders without __len__ and avoids
        # ZeroDivisionError on an empty loader.
        mean_loss = epoch_loss / num_batches if num_batches else float("nan")
        history.append(mean_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
    save_dir = os.path.dirname(save_path)
    if save_dir:  # makedirs("") fails for a bare filename
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return history

In [None]:
def train_stego_networks(
    hider,
    revealer,
    image_generator,
    epochs=5,
    device="cpu",
    secret_dim=4096,
    batch_size=4,
    iterations=20,
    cover_size=(256, 256),
    save_dir="../models",
):
    """
    Jointly train the hider and revealer on random bit strings, then save both.

    Each iteration draws ``batch_size`` covers from ``image_generator`` and
    random 0/1 secrets of length ``secret_dim``. The secrets are embedded as
    one extra channel, and the loss minimised is
    ``10 * MSE(stego, cover) + BCE(revealed_bits, secret_bits)``.

    Args:
        hider, revealer: Networks to train (moved to ``device``).
        image_generator: Object with ``generate_cover(size=...)`` returning a
            (3, H, W) tensor.
        epochs: Number of epochs, each with ``iterations`` optimizer steps.
        device: Training device.
        secret_dim: Number of secret bits per cover.
        batch_size: Covers per step.
        iterations: Steps per epoch.
        cover_size: (height, width) of generated covers; it must hold
            ``secret_dim`` bits, i.e. ``height * width >= secret_dim``.
        save_dir: Directory for ``hider.pth`` / ``revealer.pth``; created if missing.

    Raises:
        ValueError: if ``secret_dim`` exceeds the cover's pixel capacity.
    """
    height, width = cover_size
    if secret_dim > height * width:
        raise ValueError(
            f"secret_dim={secret_dim} exceeds cover capacity {height * width} "
            f"for cover_size={cover_size}"
        )
    print("\n--- Training Steganography Networks ---")
    hider.to(device)
    revealer.to(device)
    criterion_mse = nn.MSELoss()
    criterion_bce = nn.BCELoss()
    optimizer = optim.Adam(
        list(hider.parameters()) + list(revealer.parameters()), lr=1e-3
    )
    for epoch in range(epochs):
        hider.train()
        revealer.train()
        img_loss = 0.0
        bit_loss = 0.0
        for _ in range(iterations):
            covers = torch.stack(
                [image_generator.generate_cover(size=cover_size) for _ in range(batch_size)]
            ).to(device)
            secret_bits = (
                torch.randint(0, 2, (batch_size, secret_dim)).float().to(device)
            )
            spatial_secret = format_secret_for_hiding(
                secret_bits, (batch_size, 1, height, width)
            )
            stego = hider(covers, spatial_secret)
            secret_pred = extract_secret_from_prediction(revealer(stego), secret_dim)
            # Keep probabilities strictly inside (0, 1) before BCE.
            secret_pred = torch.clamp(secret_pred, 1e-7, 1.0 - 1e-7)
            l_img = criterion_mse(stego, covers)
            l_bit = criterion_bce(secret_pred, secret_bits)
            # Image fidelity weighted 10x relative to bit recovery.
            loss = (10.0 * l_img) + l_bit
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            img_loss += l_img.item()
            bit_loss += l_bit.item()
        print(
            f"Epoch {epoch+1}/{epochs}, Img Loss: {img_loss/iterations:.4f}, Bit Loss: {bit_loss/iterations:.4f}"
        )
    # Create the directory here too: this function may run without the
    # autoencoder trainer having created it first.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(hider.state_dict(), os.path.join(save_dir, "hider.pth"))
    torch.save(revealer.state_dict(), os.path.join(save_dir, "revealer.pth"))

In [None]:
if __name__ == "__main__":
    # Prefer CUDA, then Apple-Silicon MPS, then CPU. The getattr guard covers
    # torch builds that predate ``torch.backends.mps``.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")  # Uses Apple Silicon GPU!
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    # -------------------------------------------------------------
    # 1. SETUP: CHOOSE DUMMY DATA OR REAL DATA
    # -------------------------------------------------------------
    # True  -> train on the videos under VIDEO_DATASET_PATH and use Stable
    #          Diffusion covers for the stego networks.
    # False -> quick smoke test on random clips and noise covers.
    USE_REAL_DATA = True

    # Root folder searched recursively for .avi / .mp4 files (update this on Kaggle!).
    VIDEO_DATASET_PATH = "dataset"

    if USE_REAL_DATA:
        print("Loading REAL Video Dataset... (This might take a moment)")
        video_dataset = RealVideoDataset(directory=VIDEO_DATASET_PATH, frames=16)
        if len(video_dataset) == 0:
            # Fail fast with an actionable message instead of the DataLoader
            # sampler's opaque "num_samples should be a positive integer" error.
            raise FileNotFoundError(
                f"No .avi or .mp4 videos found under {VIDEO_DATASET_PATH!r}; "
                "fix VIDEO_DATASET_PATH or set USE_REAL_DATA = False."
            )
        video_loader = DataLoader(video_dataset, batch_size=4, shuffle=True)
    else:
        print("Loading DUMMY Video Dataset... (For quick testing)")
        video_dataset = DummyVideoDataset(num_samples=20)
        video_loader = DataLoader(video_dataset, batch_size=4)

    # -------------------------------------------------------------
    # 2. RUN AUTOENCODER TRAINING
    # -------------------------------------------------------------
    ae = VideoAutoencoder(3, 256)
    train_video_autoencoder(ae, video_loader, epochs=2, device=device)

    # -------------------------------------------------------------
    # 3. RUN STEGANOGRAPHY NETWORKS TRAINING
    # -------------------------------------------------------------
    # ImageGenerator(use_dummy=not USE_REAL_DATA) produces Stable Diffusion
    # covers when USE_REAL_DATA is True and random-noise covers otherwise.
    hider = HiderNetwork(3, 1, 32)
    revealer = RevealerNetwork(3, 1, 32)
    img_gen = ImageGenerator(device, use_dummy=not USE_REAL_DATA)

    train_stego_networks(
        hider,
        revealer,
        img_gen,
        epochs=2,
        device=device,
        secret_dim=8416,
    )
    print("Training Complete!")

# **Pipeline**

In [None]:
import torch
import os
import cv2
import numpy as np

In [None]:
from utils import extract_frames, compile_video
from video_autoencoder import VideoAutoencoder
from encryption import LatentEncryptor, ciphertext_to_bits, bits_to_ciphertext
from image_generator import ImageGenerator
from stego_networks import (
    HiderNetwork,
    RevealerNetwork,
    format_secret_for_hiding,
    extract_secret_from_prediction,
)

In [None]:
def save_image(tensor, path):
    """
    Write an image tensor to ``path`` with ``torchvision.utils.save_image``.

    torchvision expects float values in [0, 1]; it scales them to 8-bit on save.
    """
    import torchvision.utils as tv_utils

    tv_utils.save_image(tensor, path)

In [None]:
class SteganoPipeline:
    """
    End-to-end demo of the hide and extract paths.

    Hide path: video -> latent -> AES-GCM ciphertext -> bits hidden in a
    generated cover image. Extract path: stego image -> bits -> ciphertext ->
    latent -> video.

    Networks are freshly initialised unless ``weights_dir`` points at the
    ``video_autoencoder.pth`` / ``hider.pth`` / ``revealer.pth`` files written
    by the training code. That code builds the stego networks with 32 hidden
    channels, so pass ``stego_hidden_channels=32`` when loading its weights.
    All networks are switched to eval mode, so BatchNorm uses (and does not
    update) its running statistics during inference.
    """

    def __init__(self, device="cpu", weights_dir=None, stego_hidden_channels=64):
        self.device = device
        self.autoencoder = VideoAutoencoder(in_channels=3, latent_dim=256).to(device)
        self.encryptor = LatentEncryptor()
        self.generator = ImageGenerator(device=device, use_dummy=True)
        self.hider = HiderNetwork(
            cover_channels=3, secret_channels=1, hidden_channels=stego_hidden_channels
        ).to(device)
        self.revealer = RevealerNetwork(
            stego_channels=3, secret_channels=1, hidden_channels=stego_hidden_channels
        ).to(device)

        if weights_dir is not None:
            self._load_weights(weights_dir)

        # Inference only: freeze BatchNorm statistics (and disable any
        # train-only behaviour) for every network.
        for net in (self.autoencoder, self.hider, self.revealer):
            net.eval()

    def _load_weights(self, weights_dir):
        """Load trained state_dicts from ``weights_dir`` into the three networks."""
        # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
        for net, filename in (
            (self.autoencoder, "video_autoencoder.pth"),
            (self.hider, "hider.pth"),
            (self.revealer, "revealer.pth"),
        ):
            state = torch.load(
                os.path.join(weights_dir, filename), map_location=self.device
            )
            net.load_state_dict(state)

    def hide_video(self, video_path, output_image_path):
        """
        Hide the first 16 frames of ``video_path`` in a cover saved to ``output_image_path``.

        Returns:
            ``(metadata, cipher_len)``: the AES-GCM IV/tag/shape dict and the
            ciphertext length in bytes, both required by ``extract_video``.
        """
        print(f"1. Extracting frames from {video_path}...")
        # 16 frames at 64x64 is the clip size the autoencoder's Linear layers expect.
        frames = (
            extract_frames(video_path, max_frames=16, resize_dim=(64, 64))
            .unsqueeze(0)
            .to(self.device)
        )
        print("2. Compressing video into latent vector...")
        with torch.no_grad():
            _, latent = self.autoencoder(frames)
        print("3. Encrypting latent vector using AES...")
        ciphertext, metadata = self.encryptor.encrypt(latent[0])
        print("4. Generating Cover Image (256x256)...")
        cover_image = (
            self.generator.generate_cover(size=(256, 256)).unsqueeze(0).to(self.device)
        )
        print("5. Packing encrypted data into spatial tensor...")
        # One secret bit per cover pixel (single secret plane), so capacity is H * W.
        height, width = cover_image.shape[-2], cover_image.shape[-1]
        spatial_secret = format_secret_for_hiding(
            ciphertext_to_bits(ciphertext, height * width)
            .unsqueeze(0)
            .to(self.device),
            (1, 1, height, width),
        )
        print("6. Embedding secret into Cover Image...")
        with torch.no_grad():
            stego_image = self.hider(cover_image, spatial_secret)
        print(f"7. Saving Stego Image to {output_image_path}...")
        save_image(stego_image[0], output_image_path)
        return metadata, len(ciphertext)

    def extract_video(self, stego_image_path, metadata, cipher_len, output_video_path):
        """
        Recover the hidden video from a stego image and write it to ``output_video_path``.

        If AES-GCM authentication fails (expected with untrained networks),
        a random latent of the original shape is decoded instead so the rest
        of the pipeline still runs.
        """
        from cryptography.exceptions import InvalidTag
        from torchvision.io import ImageReadMode, read_image

        print(f"1. Loading Stego Image from {stego_image_path}...")
        # Force 3-channel RGB: an image with alpha or a grayscale image would
        # otherwise not match the revealer's 3 input channels.
        stego_image = (
            (read_image(stego_image_path, mode=ImageReadMode.RGB).float() / 255.0)
            .unsqueeze(0)
            .to(self.device)
        )
        print("2. Extracting spatial data...")
        with torch.no_grad():
            secret_pred_spatial = self.revealer(stego_image)
        print("3. Reconstructing bit stream...")
        bit_tensor = extract_secret_from_prediction(secret_pred_spatial, cipher_len * 8)
        print("4. Repacking bits to ciphertext...")
        recovered_ciphertext = bits_to_ciphertext(bit_tensor[0], cipher_len)
        print("5. Decrypting latent vector using AES...")
        try:
            recovered_latent = self.encryptor.decrypt(
                recovered_ciphertext, metadata, self.device
            )
            print("Decryption successful!")
        except InvalidTag:
            # Only an authentication failure is expected here; anything else
            # is a real bug and should propagate.
            print(
                "Decryption failed! Networks untrained (Invalid AES Tag error expected during testing)."
            )
            print(
                "Using a random fallback latent vector to demonstrate the rest of the pipeline..."
            )
            recovered_latent = torch.rand(metadata["shape"]).to(self.device)

        print("6. Reconstructing video frames...")
        with torch.no_grad():
            reconstructed_frames = self.autoencoder.decoder(
                recovered_latent.unsqueeze(0)
            )
        print(f"7. Saving Reconstructed Video to {output_video_path}...")
        compile_video(reconstructed_frames[0], output_video_path, fps=15)
        print("--- Decode Complete ---")

In [None]:
if __name__ == "__main__":
    os.makedirs("../data", exist_ok=True)
    os.makedirs("../models", exist_ok=True)

    # Write a tiny 16-frame, 64x64 random-noise clip to exercise the pipeline.
    dummy_video_path = "../data/dummy_video.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(dummy_video_path, fourcc, 15, (64, 64))
    try:
        # VideoWriter fails silently when it cannot open; surface that.
        if not out.isOpened():
            raise OSError(f"Could not open video writer for {dummy_video_path}")
        for _ in range(16):
            # randint's upper bound is exclusive; 256 covers the full uint8 range.
            out.write(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
    finally:
        out.release()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = SteganoPipeline(device=device)

    stego_img_path = "../data/stego_output.png"
    recon_video_path = "../data/reconstructed_video.mp4"

    metadata, cipher_length = pipeline.hide_video(dummy_video_path, stego_img_path)
    pipeline.extract_video(stego_img_path, metadata, cipher_length, recon_video_path)

