In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
!pip install torch torchvision pillow tqdm numpy matplotlib torchmetrics

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ====================================
# Encoder
# ====================================
class StegoEncoder(nn.Module):
    def __init__(self):
        super(StegoEncoder, self).__init__()

        # secret buffering: project secret (128x256) -> (256x256)
        self.secret_up = nn.Upsample(size=(256, 256), mode="bilinear", align_corners=False)

        # conv layers for mixing cover + secret
        self.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1)  # cover+secret
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)  # output stego

        self.relu = nn.ReLU(inplace=True)

    def forward(self, cover, secret_half):
        # Upsample secret to match cover size
        secret_resized = self.secret_up(secret_half)

        # Concatenate along channel dimension
        x = torch.cat([cover, secret_resized], dim=1)  # [B,6,256,256]

        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        stego = torch.sigmoid(self.conv4(x))  # constrain to [0,1]

        return stego


# ====================================
# Decoder
# ====================================
class StegoDecoder(nn.Module):
    def __init__(self):
        super(StegoDecoder, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  # downsample
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # upsample
        self.conv5 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, stego):
        x = self.relu(self.conv1(stego))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))  # back to 256x256
        secret_recon = torch.sigmoid(self.conv5(x))

        # crop back to half-height (128x256)
        secret_half = secret_recon[:, :, 0:128, :]
        return secret_half



In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.models import vgg16
import torch.nn.functional as F
from torchmetrics.functional import structural_similarity_index_measure as ssim

class VGGPerceptualLoss(nn.Module):
    def __init__(self, device="cuda"):
        super().__init__()
        vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # up to relu3_3
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg

    def forward(self, x, y):
        return F.mse_loss(self.vgg(x), self.vgg(y))

# ====================================
# Training Loop
# ====================================
def train_steganography(
    encoder,
    decoder,
    dataloader,
    start=0,
    num_epochs=20,
    device="cuda",
    save_dir="./checkpoints"
):
    os.makedirs(save_dir, exist_ok=True)

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    mse_loss = nn.MSELoss()
    perceptual_loss = VGGPerceptualLoss(device)

    best_loss = float("inf")
    best_file = os.path.join(save_dir, "best_loss.pth")
    if os.path.exists(best_file):
        best_loss = torch.load(best_file)  # reload previous best loss
        print(f"Resuming training. Previous best loss = {best_loss:.6f}")

    # Optimizer
    optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=1e-4
    )

    for epoch in range(start,num_epochs):
        encoder.train()
        decoder.train()
        epoch_loss = 0.0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

        for (covers, secret) in pbar:
            cover1, cover2 = covers
            cover1, cover2, secret = cover1.to(device), cover2.to(device), secret.to(device)
            secret_top = secret[:, :, 0:128, :]     # [B,3,128,256]
            secret_bottom = secret[:, :, 128:256, :]  # [B,3,128,256]

            assert secret_top.size(2) == 128 and secret_bottom.size(2) == 128, \
                f"Split failed: {secret_top.shape}, {secret_bottom.shape}"
            # Forward pass
            stego1 = encoder(cover1, secret_top)      # [B,3,256,256]
            stego2 = encoder(cover2, secret_bottom)   # [B,3,256,256]

            rec_top= decoder(stego1)       # [B,3,128,256]
            rec_bottom= decoder(stego2)                # [B,3,128,256]
            rec_secret = torch.cat([rec_top, rec_bottom], dim=2)

            # Loss terms
            loss_cover_mse = mse_loss(stego1, cover1) + mse_loss(stego2, cover2)
            loss_cover_perc = perceptual_loss(stego1, cover1) + perceptual_loss(stego2, cover2)
            loss_secret_perc = perceptual_loss(rec_secret, secret)
            loss_secret_mse = mse_loss(rec_secret, secret)
            loss_secret_ssim = 1 - ssim(rec_secret, secret)

            loss_cover = loss_cover_mse + 0.2 * loss_cover_perc
            loss_secret = loss_secret_mse + 0.5 * loss_secret_ssim + 0.2 * loss_secret_perc
            loss = loss_cover + 2 * loss_secret

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            pbar.set_postfix({"loss": loss.item()})

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {avg_loss:.6f}")
        torch.save(encoder.state_dict(), os.path.join(save_dir, f"encoder_epoch{epoch+1}.pth"))
        torch.save(decoder.state_dict(), os.path.join(save_dir, f"decoder_epoch{epoch+1}.pth"))

        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(encoder.state_dict(), os.path.join(save_dir, "encoder_best.pth"))
            torch.save(decoder.state_dict(), os.path.join(save_dir, "decoder_best.pth"))
            torch.save(best_loss, best_file)
            print(f"✅ Saved best model at epoch {epoch+1}")

    print("🎉 Training complete!")

In [4]:
import os
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms

class StegoDataset(Dataset):
    def __init__(self, dataset_dir, image_size=256):
        self.dataset_dir = dataset_dir
        self.images = sorted(os.listdir(dataset_dir))

        # image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # pick secret image
        secret_path = os.path.join(self.dataset_dir, self.images[idx])
        secret = Image.open(secret_path).convert("RGB")

        # pick two random cover images (not equal to secret)
        cover_choices = list(range(len(self.images)))
        cover_choices.remove(idx)
        cover1_idx, cover2_idx = random.sample(cover_choices, 2)

        cover1_path = os.path.join(self.dataset_dir, self.images[cover1_idx])
        cover2_path = os.path.join(self.dataset_dir, self.images[cover2_idx])

        cover1 = Image.open(cover1_path).convert("RGB")
        cover2 = Image.open(cover2_path).convert("RGB")

        # apply transforms
        cover1 = self.transform(cover1)
        cover2 = self.transform(cover2)
        secret = self.transform(secret)

        return (cover1, cover2), secret


def get_dataloader(dataset_dir, batch_size=8, image_size=256, shuffle=True):
    dataset = StegoDataset(dataset_dir, image_size=image_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataloader

def get_half_dataloader(dataset_dir, batch_size=8, image_size=256, shuffle=True, first_half=True):
    """
    Returns a dataloader for either the first half or second half of the dataset.
    
    Args:
        dataset_dir (str): Path to dataset
        batch_size (int): Batch size
        image_size (int): Image size
        shuffle (bool): Whether to shuffle
        first_half (bool): If True, use first half; else use second half
    
    Returns:
        DataLoader
    """
    dataset = StegoDataset(dataset_dir, image_size=image_size)
    half_len = len(dataset) // 25

    if first_half:
        indices = list(range(half_len))
    else:
        indices = list(range(half_len, len(dataset)))

    subset = Subset(dataset, indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=shuffle)
    return dataloader

In [5]:
dataset_path = "/kaggle/input/pimagenet/AllData"   # your dataset folder path
train_loader = get_half_dataloader(dataset_path, batch_size=4, image_size=256)
encoder_model=StegoEncoder()
decoder_model=StegoDecoder()

In [7]:
train_steganography(encoder_model, decoder_model, train_loader,start=0, num_epochs=20, device="cpu",save_dir="/kaggle/working")

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:04<00:00, 123MB/s] 
Epoch 1/20:   0%|          | 2/5399 [00:59<44:21:05, 29.58s/it, loss=3.83]


KeyboardInterrupt: 

In [None]:
encoder_model.load_state_dict(torch.load("/kaggle/working/encoder_epoch1.pth"))
decoder_model.load_state_dict(torch.load("/kaggle/working/decoder_epoch1.pth"))