In [2]:
# Use the %pip magic (not !pip) so packages are installed into the environment of the running kernel.
%pip install torch torchvision pillow tqdm numpy matplotlib torchmetrics



In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


# ---------------------------
# CBAM (Channel + Spatial)
# ---------------------------
class CBAM(nn.Module):
    """Channel attention followed by spatial attention, applied multiplicatively.

    The channel gate runs a shared bottleneck MLP (two 1x1 convolutions) over the
    globally average-pooled and max-pooled input, sums the two results and squashes
    them with a sigmoid. The spatial gate convolves the per-pixel channel mean and
    channel max of the channel-gated tensor and squashes the result with a sigmoid.

    Args:
        channels: number of channels of the input feature map.
        reduction: factor by which the MLP bottleneck shrinks ``channels``.
        kernel_size: kernel size of the spatial-attention convolution; must be odd
            so that ``padding = kernel_size // 2`` preserves the spatial size.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        if kernel_size % 2 == 0:
            # An even kernel with padding k//2 grows the map by one pixel, so the
            # attention map could no longer be multiplied with the input.
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")

        # Never let the bottleneck collapse to zero channels when channels < reduction.
        hidden = max(channels // reduction, 1)

        # Channel attention
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigmoid_channel = nn.Sigmoid()

        # Spatial attention: 2 input planes (channel-mean and channel-max) -> 1 gate plane
        self.conv_spatial = nn.Conv2d(2, 1, kernel_size, padding=(kernel_size // 2), bias=False)
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its channel gate and then by its spatial gate (same shape as ``x``)."""
        # Channel: the same MLP scores both pooled descriptors
        avg = self.mlp(self.avg_pool(x))
        max_ = self.mlp(self.max_pool(x))
        ch_attn = self.sigmoid_channel(avg + max_)
        x = x * ch_attn

        # Spatial: statistics are taken across the channel axis (dim=1)
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        spat = torch.cat([avg_pool, max_pool], dim=1)
        spat_attn = self.sigmoid_spatial(self.conv_spatial(spat))

        return x * spat_attn


# ---------------------------
# Residual Block (small)
# ---------------------------
class ResidualBlock(nn.Module):
    """Conv-BN-ReLU-Conv-BN branch added back onto its input, followed by a ReLU.

    The channel count and the spatial size are preserved (3x3 kernels, padding 1).
    """

    def __init__(self, channels):
        super().__init__()
        branch_layers = [
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        ]
        # Attribute names are kept so saved state dicts still line up.
        self.conv = nn.Sequential(*branch_layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(x + conv(x))``."""
        branch = self.conv(x)
        summed = x + branch
        return self.relu(summed)
        

class ResNetCBAMEncoder(nn.Module):
    """Hiding network: embeds a secret image (half) into a cover image.

    The cover and the resized secret are concatenated along the channel axis and fed
    through a 6-channel stem, the ResNet-34 stages (optionally each followed by CBAM
    from ``layer2`` onwards) and a convolution/upsample head that ends in a sigmoid,
    producing a 3-channel stego image.

    Args:
        pretrained: load ImageNet weights for the ResNet-34 stages. The 6-channel
            stem ``conv6`` is always freshly Kaiming-initialised.
        use_cbam: insert CBAM attention after ``layer2``, ``layer3`` and ``layer4``.
    """

    def __init__(self, pretrained=True, use_cbam=True):
        super().__init__()
        # `pretrained=` is deprecated in torchvision; `weights=` is the supported spelling
        # and "IMAGENET1K_V1" is the checkpoint that `pretrained=True` used to select.
        resnet = models.resnet34(weights="IMAGENET1K_V1" if pretrained else None)

        # Replacement for resnet.conv1 that accepts cover (3) + secret (3) channels.
        self.conv6 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        nn.init.kaiming_normal_(self.conv6.weight, mode='fan_out', nonlinearity='relu')
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.use_cbam = use_cbam
        if use_cbam:
            self.cbam2 = CBAM(128)
            self.cbam3 = CBAM(256)
            self.cbam4 = CBAM(512)

        # Upsample head: five (conv3x3 -> ReLU -> x2 bilinear upsample) stages, i.e. x32 overall
        # (8 -> 256 for the 256-pixel inputs used in this notebook), then a 1x1 projection to
        # 3 channels. Layer order is unchanged, so the Sequential indices (0, 3, 6, 9, 12, 15
        # for the convolutions) still match existing checkpoints.
        head_layers = []
        for in_ch, out_ch in [(512, 256), (256, 128), (128, 64), (64, 32), (32, 16)]:
            head_layers += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ]
        head_layers += [nn.Conv2d(16, 3, 1), nn.Sigmoid()]
        self.stego_head = nn.Sequential(*head_layers)

    def forward(self, cover, secret_half):
        """Return the stego image for ``cover`` carrying ``secret_half``.

        Args:
            cover: cover image batch; together with the secret it must provide the
                6 channels expected by ``conv6`` (3 + 3 in this notebook).
            secret_half: secret image batch of any spatial size; it is bilinearly
                resized to the cover's height and width before concatenation.

        Returns:
            A 3-channel tensor with values in (0, 1) (sigmoid output).
        """
        secret_resized = F.interpolate(secret_half, size=cover.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([cover, secret_resized], dim=1)
        x = self.conv6(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        if self.use_cbam:
            x = self.cbam2(x)
        x = self.layer3(x)
        if self.use_cbam:
            x = self.cbam3(x)
        x = self.layer4(x)
        if self.use_cbam:
            x = self.cbam4(x)
        stego = self.stego_head(x)
        return stego


class ResNetCBAMDecoder(nn.Module):
    """Reveal network: recovers the hidden secret half from a stego image.

    A 3x3 stem lifts the 3-channel stego image to 64 channels, three residual blocks
    (optionally followed by CBAM) refine it, and the head resizes the features to
    ``secret_size`` and projects them to a 3-channel image through a sigmoid.

    Args:
        use_cbam: apply CBAM attention after the residual blocks.
        secret_size: ``(height, width)`` of the recovered secret half. The default
            ``(128, 256)`` matches the top/bottom halves of a 256x256 secret.
    """

    def __init__(self, use_cbam=True, secret_size=(128, 256)):
        super().__init__()
        self.use_cbam = use_cbam
        self.conv_in = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.res1 = ResidualBlock(64)
        self.res2 = ResidualBlock(64)
        self.res3 = ResidualBlock(64)
        if use_cbam:
            self.cbam = CBAM(64)

        self.secret_head = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            # Resizes to a fixed output size regardless of the stego resolution.
            nn.Upsample(size=tuple(secret_size), mode='bilinear', align_corners=False),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 3, 1),
            nn.Sigmoid()
        )

    def forward(self, stego):
        """Return the recovered secret half (3 channels, ``secret_size`` pixels, values in (0, 1))."""
        x = self.conv_in(stego)
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        if self.use_cbam:
            x = self.cbam(x)
        secret_half = self.secret_head(x)
        return secret_half



# ---------------------------
# Quick utility: count params
# ---------------------------
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` is False (frozen weights) are not counted.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics.functional import structural_similarity_index_measure as ssim
from tqdm import tqdm

# Optional perceptual loss (VGG-based)
from torchvision.models import vgg16
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG-16 feature maps of two image batches.

    Only the first 16 layers of ``vgg16().features`` are used ("up to conv3_3"
    according to the original note); they are put in eval mode and excluded from
    gradient updates.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg16(weights="IMAGENET1K_V1").features[:16]
        backbone.eval()
        # Freeze the extractor: gradients flow through it to the inputs, never into it.
        for weight in backbone.parameters():
            weight.requires_grad = False
        self.vgg = backbone

    def forward(self, pred, target):
        """Return the mean absolute difference between the features of ``pred`` and ``target``."""
        pred_features = self.vgg(pred)
        target_features = self.vgg(target)
        return nn.functional.l1_loss(pred_features, target_features)


def _resize_to_match(image, reference):
    """Bilinearly resize ``image`` to ``reference``'s height/width; return it unchanged if they already agree."""
    if image.shape[2:] == reference.shape[2:]:
        return image
    return F.interpolate(image, size=reference.shape[2:], mode="bilinear", align_corners=False)


def train(encoder, decoder, dataloader, device, start=0, epochs=20, lr=1e-4, save_dir="/kaggle/working"):
    """Jointly train the hiding ``encoder`` and the revealing ``decoder``.

    For every batch the secret is cut into a top half (rows 0-127) and a bottom half
    (rows 128-255). Each half is hidden in its own cover image, recovered by the
    decoder, and the two recovered halves are stacked back into one image.

    Loss::

        cover  = MSE(stego, cover)          + 0.2 * perceptual(stego, cover)   (summed over both covers)
        secret = MSE(recovered, secret)     + 0.5 * (1 - SSIM) + 0.2 * perceptual(recovered, secret)
        total  = cover + 2 * secret

    Args:
        encoder: module called as ``encoder(cover, secret_half)`` returning a stego batch.
        decoder: module called as ``decoder(stego)`` returning the recovered secret half.
        dataloader: sized iterable yielding ``((cover1, cover2), secret)`` tensor batches.
        device: device the models and batches are moved to.
        start: zero-based index of the first epoch to run. Epochs are reported one-based,
            so ``start=10`` continues with "Epoch 11".
        epochs: exclusive upper bound of the epoch index.
        lr: Adam learning rate.
        save_dir: directory for checkpoints; created if missing.

    Side effects:
        After every epoch writes ``encoder_epoch{n}.pth`` / ``decoder_epoch{n}.pth``.
        Whenever the mean epoch loss improves on the best so far, also writes
        ``encoder_best.pth``, ``decoder_best.pth`` and ``best_loss.pth``. An existing
        ``best_loss.pth`` in ``save_dir`` is read at start-up to seed the best loss.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Note:
        Only model weights and the best loss are persisted. The Adam optimizer state
        is rebuilt from scratch on every call, including when resuming.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError at the epoch average.
        raise ValueError("dataloader is empty; nothing to train on")

    os.makedirs(save_dir, exist_ok=True)

    # Optimizer over both networks so they are updated from the same combined loss
    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)
    encoder.to(device)
    decoder.to(device)

    # Losses
    mse_loss = nn.MSELoss()
    perc_loss = PerceptualLoss().to(device)

    best_loss = float("inf")
    best_file = os.path.join(save_dir, "best_loss.pth")
    if os.path.exists(best_file):
        best_loss = torch.load(best_file)  # reload previous best loss
        print(f"Resuming training. Previous best loss = {best_loss:.6f}")

    size_notice_shown = False

    for epoch in range(start, epochs):
        encoder.train()
        decoder.train()

        total_loss = 0.0
        total_perc = 0.0
        total_ssim = 0.0

        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for covers, secret in loop:
            cover1, cover2 = covers
            cover1, cover2, secret = cover1.to(device), cover2.to(device), secret.to(device)

            # Split secret into halves
            secret_top = secret[:, :, 0:128, :]
            secret_bottom = secret[:, :, 128:256, :]

            # Forward pass
            stego1 = encoder(cover1, secret_top)
            stego2 = encoder(cover2, secret_bottom)

            rec_top = decoder(stego1)
            rec_bottom = decoder(stego2)
            rec_secret = torch.cat([rec_top, rec_bottom], dim=2)  # stack halves along height

            # The decoder above saw the raw encoder output; the cover loss below needs
            # stego and cover at the same resolution.
            if not size_notice_shown and (
                stego1.shape[2:] != cover1.shape[2:] or stego2.shape[2:] != cover2.shape[2:]
            ):
                print("Stego/cover size mismatch: resizing stego to the cover size for the cover loss")
                size_notice_shown = True
            stego1 = _resize_to_match(stego1, cover1)
            stego2 = _resize_to_match(stego2, cover2)

            # Cover imperceptibility
            loss_cover_mse = mse_loss(stego1, cover1) + mse_loss(stego2, cover2)
            loss_cover_perc = perc_loss(stego1, cover1) + perc_loss(stego2, cover2)

            # Secret recoverability
            # NOTE(review): ssim is called without data_range, so torchmetrics derives the
            # range from the batch itself; consider data_range=1.0 for [0, 1] images.
            secret_ssim = ssim(rec_secret, secret)
            loss_secret_mse = mse_loss(rec_secret, secret)
            loss_secret_ssim = 1 - secret_ssim
            loss_secret_perc = perc_loss(rec_secret, secret)

            # Hybrid Loss
            loss_cover = loss_cover_mse + 0.2 * loss_cover_perc
            loss_secret = loss_secret_mse + 0.5 * loss_secret_ssim + 0.2 * loss_secret_perc
            loss = loss_cover + 2 * loss_secret

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics: reuse the values already computed for the loss instead of running
            # the VGG and SSIM forward passes a second time on the same tensors.
            batch_loss = loss.item()
            batch_perc = loss_secret_perc.item()
            batch_ssim = secret_ssim.item()

            total_loss += batch_loss
            total_perc += batch_perc
            total_ssim += batch_ssim

            loop.set_postfix(loss=batch_loss,
                             perc=batch_perc,
                             ssim=batch_ssim)

        # Epoch averages
        avg_loss = total_loss / num_batches
        avg_perc = total_perc / num_batches
        avg_ssim = total_ssim / num_batches

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, PERC={avg_perc:.2f}, SSIM={avg_ssim:.4f}")

        torch.save(encoder.state_dict(), os.path.join(save_dir, f"encoder_epoch{epoch+1}.pth"))
        torch.save(decoder.state_dict(), os.path.join(save_dir, f"decoder_epoch{epoch+1}.pth"))

        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(encoder.state_dict(), os.path.join(save_dir, "encoder_best.pth"))
            torch.save(decoder.state_dict(), os.path.join(save_dir, "decoder_best.pth"))
            torch.save(best_loss, best_file)
            print("✅ Saved best model")

    print("Training finished!")


In [5]:
import os
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms

class StegoDataset(Dataset):
    """Dataset yielding ``((cover1, cover2), secret)`` triples drawn from one image folder.

    Item ``idx`` uses the ``idx``-th file (in sorted name order) as the secret and two
    distinct, randomly chosen other files of the same folder as covers. All three are
    converted to RGB, resized to ``image_size`` x ``image_size`` and passed through
    ``ToTensor``.

    Args:
        dataset_dir: directory whose entries are all treated as image files.
        image_size: side length the images are resized to.
    """

    def __init__(self, dataset_dir, image_size=256):
        self.dataset_dir = dataset_dir
        self.images = sorted(os.listdir(dataset_dir))

        # image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.images)

    def _pick_cover_indices(self, idx):
        """Return two distinct random indices, both different from ``idx``.

        Samples two positions from ``range(n - 1)`` and shifts those at or beyond
        ``idx`` up by one. For the same RNG state this picks the same indices as
        sampling from an explicit list with ``idx`` removed, without building an
        O(n) list for every item.

        Raises:
            ValueError: if the dataset holds fewer than three images.
        """
        n = len(self.images)
        if n < 3:
            raise ValueError(f"need at least 3 images (1 secret + 2 covers), found {n}")
        if idx < 0:
            idx += n
        first, second = random.sample(range(n - 1), 2)
        return (first + 1 if first >= idx else first,
                second + 1 if second >= idx else second)

    def _load(self, filename):
        """Open ``filename`` from the dataset directory, convert it to RGB and apply the transform."""
        path = os.path.join(self.dataset_dir, filename)
        # The context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        # pick two random cover images (not equal to secret)
        cover1_idx, cover2_idx = self._pick_cover_indices(idx)

        secret = self._load(self.images[idx])
        cover1 = self._load(self.images[cover1_idx])
        cover2 = self._load(self.images[cover2_idx])

        return (cover1, cover2), secret


def get_dataloader(dataset_dir, batch_size=8, image_size=256, shuffle=True):
    """Build a DataLoader over a ``StegoDataset`` covering every file in ``dataset_dir``."""
    return DataLoader(
        StegoDataset(dataset_dir, image_size=image_size),
        batch_size=batch_size,
        shuffle=shuffle,
    )

def get_half_dataloader(dataset_dir, batch_size=8, image_size=256, shuffle=True, first_half=True,
                        split_divisor=25):
    """
    Returns a dataloader over one side of a two-way split of the dataset.

    Despite the name, the split is NOT at the midpoint by default: the boundary is
    ``len(dataset) // split_divisor``, so with the default ``split_divisor=25`` the
    "first half" is the first 1/25 of the samples and the "second half" is the
    remaining 24/25. Pass ``split_divisor=2`` for a true half/half split.

    Only the secret index is restricted to the chosen part; ``StegoDataset`` still
    draws the cover images from the whole directory.

    Args:
        dataset_dir (str): Path to dataset
        batch_size (int): Batch size
        image_size (int): Image size
        shuffle (bool): Whether to shuffle
        first_half (bool): If True, use the part before the boundary; else the part after it
        split_divisor (int): The boundary index is ``len(dataset) // split_divisor``

    Returns:
        DataLoader

    Raises:
        ValueError: if ``split_divisor`` is smaller than 1.
    """
    if split_divisor < 1:
        raise ValueError(f"split_divisor must be >= 1, got {split_divisor}")

    dataset = StegoDataset(dataset_dir, image_size=image_size)
    boundary = len(dataset) // split_divisor

    if first_half:
        indices = list(range(boundary))
    else:
        indices = list(range(boundary, len(dataset)))

    subset = Subset(dataset, indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=shuffle)
    return dataloader

In [6]:
# Dataset folder path (Kaggle input directory); its files serve as both secrets and covers.
dataset_path = "/kaggle/input/pimagenet/AllData"
train_loader = get_half_dataloader(
    dataset_path,
    batch_size=3,
    image_size=256,
)
encoder_model = ResNetCBAMEncoder()
decoder_model = ResNetCBAMDecoder()

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 212MB/s]


In [7]:
# Epoch number of the checkpoint to resume from; also the first epoch index handed to train().
start = 10

In [8]:
# map_location="cpu" lets checkpoints that were saved from a GPU run load on a CPU-only
# runtime as well; load_state_dict copies the values into the models' own parameters.
encoder_model.load_state_dict(torch.load(f"/kaggle/working/encoder_epoch{start}.pth", map_location="cpu"))
decoder_model.load_state_dict(torch.load(f"/kaggle/working/decoder_epoch{start}.pth", map_location="cpu"))

<All keys matched successfully>

In [10]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Continue from epoch index `start` up to epoch 100, checkpointing into the Kaggle working directory.
train(
    encoder_model,
    decoder_model,
    train_loader,
    device=device,
    start=start,
    epochs=100,
    save_dir="/kaggle/working",
)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 242MB/s] 


Resuming training. Previous best loss = 0.840534


Epoch 11/100: 100%|██████████| 7198/7198 [44:23<00:00,  2.70it/s, loss=1.09, perc=0.678, ssim=0.483] 


Epoch 11: Loss=0.8301, PERC=0.46, SSIM=0.6328
✅ Saved best model


Epoch 12/100:   3%|▎         | 240/7198 [01:24<41:00,  2.83it/s, loss=0.697, perc=0.383, ssim=0.691]