In [None]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from itertools import cycle
from PIL import Image
from torchvision.utils import save_image
import numpy as np
import torchvision.utils as vutils
import cv2

In [2]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + InstanceNorm layers wrapped in an identity skip connection.

    Both convolutions use stride 1 and padding 1 and keep the channel count, so the
    block output has the same shape as its input and can be added to it.

    Args:
        in_channels: number of input (and output) feature channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        # The layer order fixes the state_dict keys (block.0 / block.3 hold the conv
        # parameters); keep it stable so previously saved checkpoints still load.
        layers = [
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(in_channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Encoder -> residual stack -> decoder image-to-image generator.

    The encoder applies a stride-1 conv followed by two stride-2 convs (64 -> 128 ->
    256 channels), the residual stack refines features at that resolution, and the
    decoder upsamples with two transposed convs and ends in ``Tanh`` so the output
    lies in [-1, 1]. For input heights/widths divisible by 4 the output has the same
    spatial size as the input.

    Args:
        num_res_blocks: number of ``ResidualBlock(256)`` modules to stack (default 6).
    """

    def __init__(self, num_res_blocks=6):
        super().__init__()

        # NOTE: attribute names and flat layer indices define the state_dict keys used
        # by saved checkpoints; keep encoder / residual_blocks / decoder in this order.
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        for in_ch, out_ch in ((64, 128), (128, 256)):
            encoder_layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(256) for _ in range(num_res_blocks))
        )

        decoder_layers = []
        for in_ch, out_ch in ((256, 128), (128, 64)):
            decoder_layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        decoder_layers += [
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        features = self.encoder(x)
        features = self.residual_blocks(features)
        return self.decoder(features)

In [3]:
# PatchGAN Discriminator class
class Discriminator(nn.Module):
    """Fully convolutional patch discriminator.

    Three stride-2 and two stride-1 3x3 convolutions map a 3-channel image to a
    single-channel score map. ``forward`` flattens that map to
    ``(batch, num_patches)``. No sigmoid is applied, so the scores are raw values.
    """

    def __init__(self):
        super().__init__()
        # First conv is not normalised; the following three are Conv + InstanceNorm +
        # LeakyReLU. Layer indices (and thus state_dict keys) match saved checkpoints.
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        batch_size = x.size(0)
        scores = self.layers(x)
        return scores.view(batch_size, -1)

In [4]:
def pvd_embed(cover_image, secret_image):
    """Embed the top bit of ``secret_image`` into the least significant bit of ``cover_image``.

    Despite the name this is 1-bit LSB substitution, applied per 8-bit pixel value:
    ``stego = (cover & 0xFE) | (secret >> 7)``.

    The inputs are image tensors in the [-1, 1] range, as produced by this notebook's
    ``Normalize((0.5,)*3, (0.5,)*3)`` transform and the generators' ``Tanh``. They
    are converted back to 8-bit pixel values before the bit operations, because bit
    operations on the normalised floats would only ever see -1, 0 or 1. The result
    is then mapped back to [-1, 1].

    Args:
        cover_image: tensor of shape (batch, channels, H, W) with values in [-1, 1].
        secret_image: tensor with the same shape as ``cover_image``, values in [-1, 1].

    Returns:
        float32 tensor shaped like ``cover_image``, values in [-1, 1], on
        ``cover_image.device``. Inputs are detached, so no gradient flows through.

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    device = cover_image.device

    def to_uint8(image):
        # Inverse of ToTensor + Normalize(0.5, 0.5): x -> (x + 1) * 127.5, in [0, 255].
        pixels = (image.detach().cpu().float().clamp(-1.0, 1.0) + 1.0) * 127.5
        return pixels.round().to(torch.uint8)

    cover_u8 = to_uint8(cover_image)
    secret_u8 = to_uint8(secret_image)
    if cover_u8.shape != secret_u8.shape:
        raise ValueError(
            f"pvd_embed: cover shape {tuple(cover_u8.shape)} != secret shape {tuple(secret_u8.shape)}"
        )

    # Clear the cover's LSB and store the secret's MSB there (vectorised over all pixels).
    stego_u8 = (cover_u8 & 0xFE) | (secret_u8 >> 7)
    return (stego_u8.float() / 127.5 - 1.0).to(device)


def pvd_extract(stego_image):
    """Recover the 1-bit secret hidden by ``pvd_embed`` from ``stego_image``.

    Each 8-bit pixel's least significant bit is moved back to bit 7, giving secret
    pixel values of 0 or 128. The input is assumed to be in the [-1, 1] range used
    throughout this notebook (ToTensor + Normalize(0.5, 0.5), or a generator's
    ``Tanh`` output). It is converted to 8-bit values first, and the result is
    mapped back to [-1, 1] so it can be compared directly with generator outputs.

    Args:
        stego_image: tensor of shape (batch, channels, H, W) with values in [-1, 1].

    Returns:
        float32 tensor shaped like ``stego_image`` on ``stego_image.device``. Each
        value is -1.0 (secret pixel 0) or 128/127.5 - 1 (secret pixel 128). The input
        is detached, so no gradient flows through.
    """
    device = stego_image.device

    # Inverse of Normalize(0.5, 0.5) after ToTensor: [-1, 1] -> [0, 255].
    pixels = (stego_image.detach().cpu().float().clamp(-1.0, 1.0) + 1.0) * 127.5
    stego_u8 = pixels.round().to(torch.uint8)

    # The LSB carries the secret's MSB; shift it back into bit 7 (0 or 128).
    secret_u8 = (stego_u8 & 0x01) << 7
    return (secret_u8.float() / 127.5 - 1.0).to(device)

def tensor_to_np(tensor):
    """Return ``tensor`` as a NumPy array, detached from autograd and moved to the CPU."""
    host_tensor = tensor.detach().cpu()
    return host_tensor.numpy()

def np_to_tensor(array, device):
    """Copy ``array`` into a new float32 tensor placed on ``device``."""
    as_float = torch.tensor(array, dtype=torch.float32)
    return as_float.to(device)

In [5]:
def _save_pvd_checkpoints(tag, generator_H2S, generator_S2H, discriminator_H, discriminator_S,
                          optimizer_G, optimizer_H, optimizer_S):
    """Save every model/optimizer state_dict under ./model_params/PVD/ with suffix ``tag``."""
    targets = (
        (generator_H2S, './model_params/PVD/H2S', f'generator_H2S_epoch_{tag}.pth'),
        (generator_S2H, './model_params/PVD/S2H', f'generator_S2H_epoch_{tag}.pth'),
        (discriminator_H, './model_params/PVD/DH', f'discriminator_H_epoch_{tag}.pth'),
        (discriminator_S, './model_params/PVD/DS', f'discriminator_S_epoch_{tag}.pth'),
        (optimizer_G, './model_params/PVD/optimizer/G', f'optimizer_G_{tag}.pth'),
        (optimizer_H, './model_params/PVD/optimizer/H', f'optimizer_H_{tag}.pth'),
        (optimizer_S, './model_params/PVD/optimizer/S', f'optimizer_S_{tag}.pth'),
    )
    for obj, directory, filename in targets:
        # torch.save does not create missing parent directories.
        os.makedirs(directory, exist_ok=True)
        torch.save(obj.state_dict(), os.path.join(directory, filename))


def _validate_pvd(generator_H2S, generator_S2H, discriminator_H, discriminator_S,
                  val_cover_loader, val_stego_loader, val_hidden_loader, device):
    """Return mean (generator, D_S, D_H) validation losses, or None if there were no batches.

    All four models are put in eval mode for the pass and always restored to train mode.
    """
    models = (generator_H2S, generator_S2H, discriminator_H, discriminator_S)
    for net in models:
        net.eval()

    sum_g = sum_d_s = sum_d_h = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            # val_cover_loader is zipped only so iteration stops at the shortest loader,
            # as in training; cover images are not used in any validation loss.
            for _val_cover, (val_stego_img, _), (val_hidden_img, _) in zip(
                    val_cover_loader, val_stego_loader, val_hidden_loader):
                val_stego_img = val_stego_img.to(device)
                val_hidden_img = val_hidden_img.to(device)

                fake_hidden = generator_S2H(val_stego_img)
                fake_stego = generator_H2S(val_hidden_img)
                recovered_hidden = generator_S2H(fake_stego)
                recovered_stego = generator_H2S(fake_hidden)
                g_loss = (criterion_cycle(recovered_hidden, val_hidden_img)
                          + criterion_cycle(recovered_stego, val_stego_img))

                # Each discriminator prediction is computed once and reused for its target.
                pred_real_h = discriminator_H(val_hidden_img)
                pred_fake_h = discriminator_H(fake_hidden)
                d_h_loss = (criterion_gan(pred_real_h, torch.ones_like(pred_real_h))
                            + criterion_gan(pred_fake_h, torch.zeros_like(pred_fake_h)))

                pred_real_s = discriminator_S(val_stego_img)
                pred_fake_s = discriminator_S(fake_stego)
                d_s_loss = (criterion_gan(pred_real_s, torch.ones_like(pred_real_s))
                            + criterion_gan(pred_fake_s, torch.zeros_like(pred_fake_s)))

                sum_g += g_loss.item()
                sum_d_s += d_s_loss.item()
                sum_d_h += d_h_loss.item()
                num_batches += 1
    finally:
        for net in models:
            net.train()

    if num_batches == 0:
        return None
    return sum_g / num_batches, sum_d_s / num_batches, sum_d_h / num_batches


def train(generator_H2S, generator_S2H,
          discriminator_H, discriminator_S,
          cover_loader, stego_loader, hidden_loader,
          val_cover_loader, val_stego_loader, val_hidden_loader, device,
          optimizer_G,
          optimizer_H, optimizer_S,
          num_epochs, flag):
    """Train the hidden<->stego generator pair and their two discriminators.

    Per training batch:
      * generators (one joint optimizer, ``optimizer_G``):
        - identity loss x5;
        - adversarial loss against all-ones targets;
        - a one-directional cycle loss x10 that pushes
          ``generator_S2H(generator_H2S(hidden))`` towards ``pvd_extract`` of the
          generated stego image.
      * discriminator S, then discriminator H: each is trained on real images vs.
        detached generated images, with loss 0.5 * (real + fake).

    Relies on the module-level ``criterion_gan``, ``criterion_cycle`` and ``pvd_extract``.

    Every 50 epochs all model/optimizer state_dicts are saved under
    ./model_params/PVD/ (the directories are created if missing). Every 10 epochs
    the mean validation losses over all validation batches are printed.

    Args:
        cover_loader / stego_loader / hidden_loader: training loaders, consumed with
            ``zip`` so they are paired by position.
        val_*_loader: validation loaders, paired the same way.
        device: device the batches are moved to.
        num_epochs: number of epochs to run.
        flag: epoch offset added to checkpoint file names (e.g. when resuming).

    Raises:
        ValueError: if the training loaders yield no batches in an epoch.
    """
    for epoch in range(num_epochs):
        total_g_loss = total_d_s_loss = total_d_h_loss = None

        # cover_loader is zipped only to keep the three loaders aligned; cover images
        # would be needed only by the disabled stego-side cycle loss below.
        for _cover, (stego_img, _), (hidden_img, _) in zip(cover_loader, stego_loader, hidden_loader):
            stego_img = stego_img.to(device)
            hidden_img = hidden_img.to(device)

            # --------------------
            # Train the generators
            # --------------------
            # Clear the previous step's gradients; without this they accumulate over
            # every iteration of training.
            optimizer_G.zero_grad()

            # Identity loss
            same_hidden = generator_S2H(hidden_img)
            loss_identity_H = criterion_cycle(same_hidden, hidden_img) * 5.0

            same_stego = generator_H2S(stego_img)
            loss_identity_S = criterion_cycle(same_stego, stego_img) * 5.0

            # GAN loss
            fake_hidden = generator_S2H(stego_img)
            pred_fake_hidden = discriminator_H(fake_hidden)
            loss_gan_S2H = criterion_gan(pred_fake_hidden, torch.ones_like(pred_fake_hidden))

            fake_stego = generator_H2S(hidden_img)
            pred_fake_stego = discriminator_S(fake_stego)
            loss_gan_H2S = criterion_gan(pred_fake_stego, torch.ones_like(pred_fake_stego))

            # Cycle loss (hidden -> stego -> hidden only). The stego -> hidden -> stego
            # direction, which would compare against pvd_embed(cover, fake_hidden), is
            # intentionally disabled: the embedding is lossy, so only the lossy
            # extraction target is learned.
            recovered_hidden = generator_S2H(fake_stego)
            pvd_recovered_hidden = pvd_extract(fake_stego)
            loss_cycle_H = criterion_cycle(recovered_hidden, pvd_recovered_hidden) * 10.0

            # Total generator loss
            total_g_loss = loss_identity_H + loss_identity_S + loss_gan_S2H + loss_gan_H2S + loss_cycle_H
            total_g_loss.backward()
            optimizer_G.step()

            # ------------------------
            # Train the discriminators
            # ------------------------
            # zero_grad also discards the gradients the generator backward pass left
            # on the discriminator parameters.
            optimizer_H.zero_grad()
            optimizer_S.zero_grad()

            # Discriminator S: real stego vs. detached generated stego
            pred_real_stego = discriminator_S(stego_img)
            loss_d_real_stego = criterion_gan(pred_real_stego, torch.ones_like(pred_real_stego))
            pred_fake_stego = discriminator_S(fake_stego.detach())
            loss_d_fake_stego = criterion_gan(pred_fake_stego, torch.zeros_like(pred_fake_stego))

            total_d_s_loss = (loss_d_real_stego + loss_d_fake_stego) * 0.5
            total_d_s_loss.backward()
            optimizer_S.step()

            # Discriminator H: real hidden vs. detached generated hidden
            pred_real_hidden = discriminator_H(hidden_img)
            loss_d_real_hidden = criterion_gan(pred_real_hidden, torch.ones_like(pred_real_hidden))
            pred_fake_hidden = discriminator_H(fake_hidden.detach())
            loss_d_fake_hidden = criterion_gan(pred_fake_hidden, torch.zeros_like(pred_fake_hidden))

            total_d_h_loss = (loss_d_real_hidden + loss_d_fake_hidden) * 0.5
            total_d_h_loss.backward()
            optimizer_H.step()

        if total_g_loss is None:
            raise ValueError("train(): the training loaders yielded no batches")

        # Losses of the last training batch of the epoch.
        print(f"Epoch [{epoch+1}/{num_epochs}] | Generator Loss: {total_g_loss.item():.4f} | Discriminator S Loss: {total_d_s_loss.item():.4f} | Discriminator H Loss: {total_d_h_loss.item():.4f}")

        if (epoch + 1) % 50 == 0:
            _save_pvd_checkpoints(flag + epoch + 1, generator_H2S, generator_S2H,
                                  discriminator_H, discriminator_S,
                                  optimizer_G, optimizer_H, optimizer_S)

        if (epoch + 1) % 10 == 0:
            val_losses = _validate_pvd(generator_H2S, generator_S2H, discriminator_H, discriminator_S,
                                       val_cover_loader, val_stego_loader, val_hidden_loader, device)
            if val_losses is None:
                print("Validation | skipped: validation loaders yielded no batches")
            else:
                val_g_loss, val_d_s_loss, val_d_h_loss = val_losses
                print(f"Validation | Generator Loss: {val_g_loss:.4f} | Discriminator S Loss: {val_d_s_loss:.4f} | Discriminator H Loss: {val_d_h_loss:.4f}")


In [6]:
# Set device: use the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the datasets.
# Images are resized to 96x128 (H x W) and normalised to [-1, 1]
# (ToTensor -> [0, 1], then (x - 0.5) / 0.5).
transform = transforms.Compose([
    transforms.Resize((96, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def _pvd_image_folder(split, kind):
    """ImageFolder over ./<split>/PVD_<kind> using the shared ``transform``."""
    return ImageFolder(f'./{split}/PVD_{kind}', transform=transform)


def _paired_loader(dataset):
    """Batch-size-1 loader without shuffling.

    The cover/stego/hidden loaders are consumed together with ``zip`` and are
    therefore paired by position, so shuffling must stay off. This assumes the three
    folders list corresponding files in the same sorted order — TODO confirm.
    """
    return DataLoader(dataset, batch_size=1, shuffle=False)


# Train set
dataset_cover = _pvd_image_folder('train', 'cover')
dataset_stego = _pvd_image_folder('train', 'stego')
dataset_hidden = _pvd_image_folder('train', 'hidden')

cover_loader = _paired_loader(dataset_cover)
stego_loader = _paired_loader(dataset_stego)
hidden_loader = _paired_loader(dataset_hidden)

# Valid set
dataset_cover_val = _pvd_image_folder('valid', 'cover')
dataset_stego_val = _pvd_image_folder('valid', 'stego')
dataset_hidden_val = _pvd_image_folder('valid', 'hidden')

val_cover_loader = _paired_loader(dataset_cover_val)
val_stego_loader = _paired_loader(dataset_stego_val)
val_hidden_loader = _paired_loader(dataset_hidden_val)

# Test set
dataset_cover_test = _pvd_image_folder('test', 'cover')
dataset_stego_test = _pvd_image_folder('test', 'stego')
dataset_hidden_test = _pvd_image_folder('test', 'hidden')

test_cover_loader = _paired_loader(dataset_cover_test)
test_stego_loader = _paired_loader(dataset_stego_test)
test_hidden_loader = _paired_loader(dataset_hidden_test)

In [7]:
# Instantiate the generators and discriminators (order kept so random init is unchanged).
generator_H2S = Generator()
generator_S2H = Generator()

discriminator_H = Discriminator()
discriminator_S = Discriminator()

# Set device: use the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for _net in (generator_H2S, generator_S2H, discriminator_H, discriminator_S):
    _net.to(device)

# Optimizer hyper-parameters: the generators use a 5x higher learning rate than
# the discriminators.
LR_G = 0.0005
LR_D = 0.0001
ADAM_BETAS = (0.5, 0.999)

# A single optimizer updates both generators jointly; the discriminators each get their own.
optimizer_G = optim.Adam(list(generator_H2S.parameters()) + list(generator_S2H.parameters()), lr=LR_G, betas=ADAM_BETAS)

optimizer_H = optim.Adam(discriminator_H.parameters(), lr=LR_D, betas=ADAM_BETAS)
optimizer_S = optim.Adam(discriminator_S.parameters(), lr=LR_D, betas=ADAM_BETAS)

# Define the loss functions: MSE on discriminator scores, L1 for identity/cycle reconstruction.
criterion_gan = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()

In [None]:
# Full training run from scratch (flag=0 -> checkpoints are numbered from epoch 1).
train(
    generator_H2S=generator_H2S,
    generator_S2H=generator_S2H,
    discriminator_H=discriminator_H,
    discriminator_S=discriminator_S,
    cover_loader=cover_loader,
    stego_loader=stego_loader,
    hidden_loader=hidden_loader,
    val_cover_loader=val_cover_loader,
    val_stego_loader=val_stego_loader,
    val_hidden_loader=val_hidden_loader,
    device=device,
    optimizer_G=optimizer_G,
    optimizer_H=optimizer_H,
    optimizer_S=optimizer_S,
    num_epochs=1000,
    flag=0,
)


In [12]:
def reconstruct_hidden_image(generator_S2H, stego_img):
    """Run ``generator_S2H`` on ``stego_img`` and return the result as a CPU tensor.

    The input is moved to the device holding the generator's parameters, so this
    works on GPU and CPU-only machines alike. For a module without parameters, the
    input's current device is used.

    Args:
        generator_S2H: stego -> hidden generator (any ``nn.Module``).
        stego_img: input batch tensor.

    Returns:
        The generator output on the CPU, with no autograd history.
    """
    try:
        device = next(generator_S2H.parameters()).device
    except StopIteration:
        device = stego_img.device

    # Inference only: skip building an autograd graph (also saves memory).
    with torch.no_grad():
        reconstructed_hidden_image = generator_S2H(stego_img.to(device))

    return reconstructed_hidden_image.cpu()

def load_saved_dicts(generator_H2S, generator_S2H, discriminator_S, discriminator_H,
                     optimizer_G, optimizer_S, optimizer_H, flag, map_location='cpu'):
    """Restore model and optimizer states from the checkpoints numbered ``flag``.

    Files are read from ./model_params/PVD/{H2S,S2H,DS,DH}/ and
    ./model_params/PVD/optimizer/{G,S,H}/, using the same names ``train`` writes.

    Args:
        generator_H2S, generator_S2H, discriminator_S, discriminator_H: modules to load into.
        optimizer_G, optimizer_S, optimizer_H: optimizers to load into.
        flag: epoch number embedded in the checkpoint file names.
        map_location: forwarded to ``torch.load``. The default 'cpu' lets checkpoints
            saved on a GPU load on CPU-only machines; ``load_state_dict`` then copies
            the values onto the devices of the existing parameters.

    Raises:
        FileNotFoundError: if a checkpoint file is missing.

    Note:
        ``torch.load`` unpickles data; only load checkpoint files you trust.
    """
    targets = (
        (generator_H2S, f'./model_params/PVD/H2S/generator_H2S_epoch_{flag}.pth'),
        (generator_S2H, f'./model_params/PVD/S2H/generator_S2H_epoch_{flag}.pth'),
        (discriminator_S, f'./model_params/PVD/DS/discriminator_S_epoch_{flag}.pth'),
        (discriminator_H, f'./model_params/PVD/DH/discriminator_H_epoch_{flag}.pth'),
        (optimizer_G, f'./model_params/PVD/optimizer/G/optimizer_G_{flag}.pth'),
        (optimizer_S, f'./model_params/PVD/optimizer/S/optimizer_S_{flag}.pth'),
        (optimizer_H, f'./model_params/PVD/optimizer/H/optimizer_H_{flag}.pth'),
    )
    for obj, path in targets:
        obj.load_state_dict(torch.load(path, map_location=map_location))
    

In [9]:
def reconstruct_hidden_images(generator_H2S, generator_S2H, test_cover_loader, test_stego_loader, output_dir, flag):
    """Decode every test stego image with ``generator_S2H`` and save the results as PNGs.

    Images are written to ``<output_dir>/<flag>/<i>.png``, where ``i`` counts images
    in loader order across all batches. The target directory is created if needed.

    Args:
        generator_H2S: unused; kept so existing calls keep working.
        generator_S2H: stego -> hidden generator.
        test_cover_loader: zipped with ``test_stego_loader`` only to keep iteration
            aligned; cover images are not used.
        test_stego_loader: yields ``(images, labels)`` batches normalised to [-1, 1].
        output_dir: root directory for the reconstructions.
        flag: checkpoint/epoch tag used as the sub-directory name.
    """
    to_pil_image = transforms.ToPILImage()
    target_dir = os.path.join(output_dir, str(flag))
    os.makedirs(target_dir, exist_ok=True)

    image_index = 0
    for _cover_data, (stego_image_tensor, _) in zip(test_cover_loader, test_stego_loader):
        # reconstruct_hidden_image handles the device transfer and returns a CPU tensor.
        reconstructed = reconstruct_hidden_image(generator_S2H, stego_image_tensor)

        # The generator ends in Tanh ([-1, 1]); ToPILImage expects float images in [0, 1].
        reconstructed = ((reconstructed + 1.0) / 2.0).clamp(0.0, 1.0)

        for image in reconstructed:
            to_pil_image(image).save(os.path.join(target_dir, f"{image_index}.png"))
            image_index += 1

In [10]:
def change_lr(optimizer_G, optimizer_H, optimizer_S, flag,
              lr_G=0.001, lr_H=0.00001, lr_S=0.00001, map_location='cpu'):
    """Load the optimizer checkpoints numbered ``flag`` and override their learning rates.

    All three checkpoint files are read before any optimizer is modified, so a
    missing file leaves every optimizer untouched. The new rate is applied to every
    parameter group of each optimizer.

    Args:
        optimizer_G, optimizer_H, optimizer_S: optimizers to load into.
        flag: epoch number embedded in the checkpoint file names.
        lr_G, lr_H, lr_S: new learning rates (defaults: 0.001, 1e-5, 1e-5).
        map_location: forwarded to ``torch.load``; 'cpu' lets GPU-saved files load anywhere.

    Note:
        ``torch.load`` unpickles data; only load checkpoint files you trust.
    """
    specs = (
        (optimizer_G, 'G', lr_G),
        (optimizer_H, 'H', lr_H),
        (optimizer_S, 'S', lr_S),
    )
    state_dicts = [
        torch.load(f'./model_params/PVD/optimizer/{name}/optimizer_{name}_{flag}.pth',
                   map_location=map_location)
        for _, name, _ in specs
    ]

    for (optimizer, _, new_lr), state_dict in zip(specs, state_dicts):
        for group in state_dict['param_groups']:
            group['lr'] = new_lr
        optimizer.load_state_dict(state_dict)

In [14]:
# Reconstruct hidden images from the test stego set using the checkpoint saved at epoch `flag`.
flag = 500
output_dir = './reconstructed_hidden/PVD/'

load_saved_dicts(generator_H2S, generator_S2H, discriminator_S, discriminator_H,
                 optimizer_G, optimizer_S, optimizer_H, flag=flag)

# To resume training with new learning rates (previously done at epoch 1000), run:
# change_lr(optimizer_G, optimizer_H, optimizer_S, flag=flag)

# Images are written to <output_dir>/<flag>/; make sure that directory exists.
os.makedirs(os.path.join(output_dir, str(flag)), exist_ok=True)
reconstruct_hidden_images(generator_H2S, generator_S2H, test_cover_loader, test_stego_loader, output_dir, flag)

