In [1]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from itertools import cycle
from PIL import Image
from torchvision.utils import save_image
import numpy as np
import torchvision.utils as vutils
import pywt
import cv2

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(in_channels),
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, num_res_blocks=6):
        super(Generator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        )
        
         # Create a list of residual blocks with the specified number of blocks
        res_blocks = [ResidualBlock(256) for _ in range(num_res_blocks)]
        # Convert the list of residual blocks to a sequential model
        self.residual_blocks = nn.Sequential(*res_blocks)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder-decoder forward pass
        x = self.encoder(x)
        x = self.residual_blocks(x)
        x = self.decoder(x)
        return x

In [3]:
# PatchGAN Discriminator class
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1),
        )
    # Implement the forward pass by passing the input through the layers and reshaping the output
    def forward(self, x):
        return self.layers(x).view(x.size(0), -1)

In [4]:
def dwt_embed(cover_image, secret_image):
    device = cover_image.device

    # Convert the cover and secret images from PyTorch tensors to NumPy arrays
    cover_image_np = tensor_to_np(cover_image.cpu())
    secret_image_np = tensor_to_np(secret_image.cpu())

    # Compute the DWT coefficients for the cover and secret images using the Haar wavelet
    cover_coeffs = pywt.dwt2(cover_image_np, 'haar')
    secret_coeffs = pywt.dwt2(secret_image_np, 'haar')

    # Keep only the most significant bit of the secret coefficients
    msb_mask = 0b10000000 
    secret_coeffs_msb = tuple([(np.uint8(coeff) & msb_mask) for coeff in secret_coeffs])
    
    
    # Encoding here:
    # Combine the coefficients of the cover and secret images
    stego_coeffs = (cover_coeffs[0] + secret_coeffs_msb[0], cover_coeffs[1])

    # Compute the inverse DWT to obtain the stego image
    stego_image_np = pywt.idwt2(stego_coeffs, 'haar').clip(0, 255).astype(np.uint8)

    # Return tensors
    return np_to_tensor(stego_image_np, device)



def dwt_extract(stego_image, cover_image):
    device = stego_image.device
    
    # Convert the stego and cover images from PyTorch tensors to NumPy arrays
    stego_image_np = tensor_to_np(stego_image.cpu())
    cover_image_np = tensor_to_np(cover_image.cpu())
    
    # Compute the DWT coefficients for the stego and cover images using the Haar wavelet
    stego_coeffs = pywt.dwt2(stego_image_np, 'haar')
    cover_coeffs = pywt.dwt2(cover_image_np, 'haar')
    
    msb_mask = 0b10000000
    # Decoding here:
    # Calculate the hidden image coefficients by subtracting the cover coefficients from the stego coefficients
    secret_coeffs_msb = tuple([(np.uint8(coeff) & msb_mask) for coeff in (stego_coeffs[0] - cover_coeffs[0])])
    
    
    # Compute the inverse DWT to obtain the hidden image
    secret_coeffs = (secret_coeffs_msb, cover_coeffs[1])
    secret_image_np = pywt.idwt2(secret_coeffs, 'haar').clip(0, 255).astype(np.uint8)
    
    # Return tensors
    return np_to_tensor(secret_image_np, device)

def tensor_to_np(tensor):
    return tensor.detach().cpu().numpy()

def np_to_tensor(array, device):
    return torch.tensor(array, device=device, dtype=torch.float32)

In [16]:
def train(generator_H2S, generator_S2H,
          discriminator_H, discriminator_S,
          cover_loader, stego_loader, hidden_loader,
          val_cover_loader, val_stego_loader, val_hidden_loader, device, 
          optimizer_G,
          optimizer_H, optimizer_S, 
          num_epochs, flag):
    
    """
    Train function for the cycle_GAN model.

    Parameters:
    - generator_H2S (nn.Module): The generator model that transforms hidden images to stego images.
    - generator_S2H (nn.Module): The generator model that transforms stego images to hidden images.
    - discriminator_H (nn.Module): The discriminator model that discriminates between real and fake hidden images.
    - discriminator_S (nn.Module): The discriminator model that discriminates between real and fake stego images.
    - cover_loader (DataLoader): DataLoader for cover images.
    - stego_loader (DataLoader): DataLoader for stego images.
    - hidden_loader (DataLoader): DataLoader for hidden images.
    - val_cover_loader (DataLoader): DataLoader for validation cover images.
    - val_stego_loader (DataLoader): DataLoader for validation stego images.
    - val_hidden_loader (DataLoader): DataLoader for validation hidden images.
    - device (torch.device): Device to perform computations on (CPU or GPU).
    - optimizer_G (torch.optim.Optimizer): Optimizer for both generator models (generator_H2S and generator_S2H).
    - optimizer_H (torch.optim.Optimizer): Optimizer for the discriminator_H model.
    - optimizer_S (torch.optim.Optimizer): Optimizer for the discriminator_S model.
    - num_epochs (int): The number of epochs to train the model.
    - flag (int): Flag to check where the save should start (flag+epoch).
    """
    
    for epoch in range(num_epochs):
        for cover, stego, hidden in zip(cover_loader, stego_loader, hidden_loader):
            
            # Unpack images and labels from the data loaders
            cover_img, _ = cover
            stego_img, _ = stego
            hidden_img, _ = hidden
            
            # Move images to device
            cover_img = cover_img.to(device)
            stego_img = stego_img.to(device)
            hidden_img = hidden_img.to(device)

            # --------------------
            # Train the generators
            # --------------------
            
            """
            # Identity loss
            same_hidden = generator_S2H(hidden_img)
            loss_identity_H = criterion_cycle(same_hidden, hidden_img) * 5.0
            
            same_stego = generator_H2S(stego_img)
            loss_identity_S = criterion_cycle(same_stego, stego_img) * 5.0
            """

            # GAN loss
            fake_hidden = generator_S2H(stego_img)
            pred_fake_hidden = discriminator_H(fake_hidden)
            loss_gan_S2H = criterion_gan(pred_fake_hidden, torch.ones_like(pred_fake_hidden))

            fake_stego = generator_H2S(hidden_img)
            pred_fake_stego = discriminator_S(fake_stego)
            loss_gan_H2S = criterion_gan(pred_fake_stego, torch.ones_like(pred_fake_stego))

            # Cycle loss
            
            # Use dwt_embed to recover stego image
            recovered_stego = generator_H2S(fake_hidden)
            dwt_recovered_stego = dwt_embed(cover_img, fake_hidden)
            loss_cycle_S = criterion_cycle(recovered_stego, dwt_recovered_stego) * 10.0
            
            # Use dwt_extract to recover hidden image
            recovered_hidden = generator_S2H(fake_stego)
            dwt_recovered_hidden = dwt_extract(fake_stego, cover_img)
            loss_cycle_H = criterion_cycle(recovered_hidden, dwt_recovered_hidden) * 10.0

            # Total generator loss
            total_g_loss = loss_gan_S2H + loss_gan_H2S + loss_cycle_S + loss_cycle_H
            total_g_loss.backward()
            
            optimizer_G.step()
            
            # ------------------------
            # Train the discriminators
            # ------------------------
            
            optimizer_H.zero_grad()
            optimizer_S.zero_grad()
            
            # Discriminator S
            
            # Real loss
            pred_real_stego = discriminator_S(stego_img)
            loss_d_real_stego = criterion_gan(pred_real_stego, torch.ones_like(pred_real_stego))

            # Fake loss
            fake_stego.detach_()
            pred_fake_stego = discriminator_S(fake_stego)
            loss_d_fake_stego = criterion_gan(pred_fake_stego, torch.zeros_like(pred_fake_stego))

            # Total Discriminator S loss
            total_d_s_loss = (loss_d_real_stego + loss_d_fake_stego) * 0.5
            total_d_s_loss.backward()

            optimizer_S.step()
            
            # Discriminator H
            
            # Real loss
            pred_real_hidden = discriminator_H(hidden_img)
            loss_d_real_hidden = criterion_gan(pred_real_hidden, torch.ones_like(pred_real_hidden))

            # Fake loss
            fake_hidden.detach_()
            pred_fake_hidden = discriminator_H(fake_hidden)
            loss_d_fake_hidden = criterion_gan(pred_fake_hidden, torch.zeros_like(pred_fake_hidden))

            # Total Discriminator H loss
            total_d_h_loss = (loss_d_real_hidden + loss_d_fake_hidden) * 0.5
            total_d_h_loss.backward()
            
            optimizer_H.step()


        print(f"Epoch [{epoch+1}/{num_epochs}] | Generator Loss: {total_g_loss.item():.4f} | Discriminator S Loss: {total_d_s_loss.item():.4f} | Discriminator H Loss: {total_d_h_loss.item():.4f}")

        # Save the model checkpoints every 50 epochs
        if (epoch + 1) % 50 == 0:
            torch.save(generator_H2S.state_dict(), f'./model_params/DWT/H2S/generator_H2S_epoch_{flag+epoch+1}.pth')
            torch.save(generator_S2H.state_dict(), f'./model_params/DWT/S2H/generator_S2H_epoch_{flag+epoch+1}.pth')
            torch.save(discriminator_H.state_dict(), f'./model_params/DWT/DH/discriminator_H_epoch_{flag+epoch+1}.pth')
            torch.save(discriminator_S.state_dict(), f'./model_params/DWT/DS/discriminator_S_epoch_{flag+epoch+1}.pth')
            torch.save(optimizer_G.state_dict(), f'./model_params/DWT/optimizer/G/optimizer_G_{flag+epoch+1}.pth')
            torch.save(optimizer_H.state_dict(), f'./model_params/DWT/optimizer/H/optimizer_H_{flag+epoch+1}.pth')
            torch.save(optimizer_S.state_dict(), f'./model_params/DWT/optimizer/S/optimizer_S_{flag+epoch+1}.pth')
        # Valid the model every 10 epochs
        if (epoch + 1) % 10 == 0:
            generator_H2S.eval()
            generator_S2H.eval()
            discriminator_H.eval()
            discriminator_S.eval()
            
            with torch.no_grad():
                for val_cover, val_stego, val_hidden in zip(val_cover_loader, val_stego_loader, val_hidden_loader):
                    val_cover_img, _ = val_cover
                    val_stego_img, _ = val_stego
                    val_hidden_img, _ = val_hidden

                    val_cover_img = val_cover_img.to(device)
                    val_stego_img = val_stego_img.to(device)
                    val_hidden_img = val_hidden_img.to(device)

                    fake_hidden = generator_S2H(val_stego_img)
                    fake_stego = generator_H2S(val_hidden_img)
                    recovered_hidden = generator_S2H(fake_stego)
                    recovered_stego = generator_H2S(fake_hidden)

                    val_g_loss = criterion_cycle(recovered_hidden, val_hidden_img) + criterion_cycle(recovered_stego, val_stego_img)
                    val_d_h_loss = criterion_gan(discriminator_H(val_hidden_img), torch.ones_like(discriminator_H(val_hidden_img))) + criterion_gan(discriminator_H(fake_hidden.detach()), torch.zeros_like(discriminator_H(fake_hidden.detach())))
                    val_d_s_loss = criterion_gan(discriminator_S(val_stego_img), torch.ones_like(discriminator_S(val_stego_img))) + criterion_gan(discriminator_S(fake_stego.detach()), torch.zeros_like(discriminator_S(fake_stego.detach())))
                    
            generator_H2S.train()
            generator_S2H.train()
            discriminator_H.train()
            discriminator_S.train()
            print(f"Validation | Generator Loss: {val_g_loss.item():.4f} | Discriminator S Loss: {val_d_s_loss.item():.4f} | Discriminator H Loss: {val_d_h_loss.item():.4f}")


In [17]:
# Set device
device = torch.device('cuda')

# Load the datasets
transform = transforms.Compose([
    transforms.Resize((96, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Train set
dataset_cover =ImageFolder('./train/cover', transform=transform)
dataset_stego = ImageFolder('./train/stego', transform=transform)
dataset_hidden = ImageFolder('./train/DWT_hidden', transform=transform)

cover_loader = DataLoader(dataset_cover, batch_size=1)
stego_loader = DataLoader(dataset_stego, batch_size=1)
hidden_loader = DataLoader(dataset_hidden, batch_size=1)

# Val set
dataset_cover_val = ImageFolder('./valid/cover', transform=transform)
dataset_stego_val = ImageFolder('./valid/stego', transform=transform)
dataset_hidden_val = ImageFolder('./valid/DWT_hidden', transform=transform)

val_cover_loader = DataLoader(dataset_cover_val, batch_size=1)
val_stego_loader = DataLoader(dataset_stego_val, batch_size=1)
val_hidden_loader = DataLoader(dataset_hidden_val, batch_size=1)

# Test set
dataset_cover_test = ImageFolder('./test/cover', transform=transform)
dataset_stego_test = ImageFolder('./test/stego', transform=transform)
dataset_hidden_test = ImageFolder('./test/DWT_hidden', transform=transform)

test_cover_loader = DataLoader(dataset_cover_test, batch_size=1)
test_stego_loader = DataLoader(dataset_stego_test, batch_size=1)
test_hidden_loader = DataLoader(dataset_hidden_test, batch_size=1)

In [18]:
# Instantiate the generators and discriminators
generator_H2S = Generator()
generator_S2H = Generator()

discriminator_H = Discriminator()
discriminator_S = Discriminator()


# Set device
device = torch.device("cuda")

generator_H2S.to(device)
generator_S2H.to(device)

discriminator_H.to(device)
discriminator_S.to(device)


# Define the optimizers
optimizer_G = optim.Adam(list(generator_H2S.parameters()) + list(generator_S2H.parameters()), lr=0.001, betas=(0.5, 0.999))

optimizer_H = optim.Adam(discriminator_H.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_S = optim.Adam(discriminator_S.parameters(), lr=0.0001, betas=(0.5, 0.999))


# Define the loss functions
criterion_gan = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()

In [19]:
train(generator_H2S, generator_S2H, discriminator_H, discriminator_S,
      cover_loader, stego_loader, hidden_loader, val_cover_loader, val_stego_loader, val_hidden_loader, device, optimizer_G, optimizer_H, optimizer_S, num_epochs=500, flag=0)


Epoch [1/500] | Generator Loss: 97.5961 | Discriminator S Loss: 0.0175 | Discriminator H Loss: 0.0119
Epoch [2/500] | Generator Loss: 69.4685 | Discriminator S Loss: 0.0070 | Discriminator H Loss: 0.0042
Epoch [3/500] | Generator Loss: 48.1921 | Discriminator S Loss: 0.0057 | Discriminator H Loss: 0.0026
Epoch [4/500] | Generator Loss: 62.7466 | Discriminator S Loss: 0.0042 | Discriminator H Loss: 0.0015
Epoch [5/500] | Generator Loss: 65.3141 | Discriminator S Loss: 0.0042 | Discriminator H Loss: 0.0022
Epoch [6/500] | Generator Loss: 74.8980 | Discriminator S Loss: 0.0022 | Discriminator H Loss: 0.0013
Epoch [7/500] | Generator Loss: 75.1669 | Discriminator S Loss: 0.0016 | Discriminator H Loss: 0.0019
Epoch [8/500] | Generator Loss: 74.7388 | Discriminator S Loss: 0.0022 | Discriminator H Loss: 0.0015
Epoch [9/500] | Generator Loss: 68.3367 | Discriminator S Loss: 0.0015 | Discriminator H Loss: 0.0012
Epoch [10/500] | Generator Loss: 46.0432 | Discriminator S Loss: 0.0020 | Discrimi

In [7]:
def reconstruct_hidden_image(generator_H2S, generator_S2H, cover_image, hidden_image):
    device = cover_image.device
    # Real stego image
    stego_image = dwt_embed(cover_image, hidden_image)

    # Reconstructed hidden image
    reconstructed_hidden_image = generator_S2H(stego_image.to(device)).detach()
    
    return reconstructed_hidden_image.cpu()

def load_saved_dicts(generator_H2S, generator_S2H, discriminator_S, discriminator_H, optimizer_G, optimizer_S, optimizer_H, flag):
    # Save the model state_dict
    generator_H2S.load_state_dict(torch.load(f'./model_params/DWT/H2S/generator_H2S_epoch_{flag}.pth'))
    generator_S2H.load_state_dict(torch.load(f'./model_params/DWT/S2H/generator_S2H_epoch_{flag}.pth'))
    discriminator_S.load_state_dict(torch.load(f'./model_params/DWT/DS/discriminator_S_epoch_{flag}.pth'))
    discriminator_H.load_state_dict(torch.load(f'./model_params/DWT/DH/discriminator_H_epoch_{flag}.pth'))
    # Also optimizers
    optimizer_G.load_state_dict(torch.load(f'./model_params/DWT/optimizer/G/optimizer_G_{flag}.pth'))
    optimizer_S.load_state_dict(torch.load(f'./model_params/DWT/optimizer/S/optimizer_S_{flag}.pth'))
    optimizer_H.load_state_dict(torch.load(f'./model_params/DWT/optimizer/H/optimizer_H_{flag}.pth'))
    

In [22]:
def reconstruct_hidden_images(generator_H2S, generator_S2H, test_cover_loader, test_stego_loader, output_dir, flag):
    to_pil_image = transforms.ToPILImage()
    
    for i, (cover_data, stego_data) in enumerate(zip(test_cover_loader, test_stego_loader)):
        cover_image_tensor, _ = cover_data
        stego_image_tensor, _ = stego_data
        
        # Move tensors to the appropriate device
        device = torch.device("cuda")
        cover_image_tensor = cover_image_tensor.to(device)
        stego_image_tensor = stego_image_tensor.to(device)
        
        # Reconstruct the hidden image
        reconstructed_hidden_image = reconstruct_hidden_image(generator_H2S, generator_S2H, cover_image_tensor, stego_image_tensor)
        
        # Convert the reconstructed hidden image back to a PIL Image
        reconstructed_hidden_image_pil = to_pil_image(reconstructed_hidden_image.squeeze(0))
        
        # Save the reconstructed hidden image
        output_path = os.path.join(output_dir, f"{i}.png")
        reconstructed_hidden_image_pil.save(output_path)

In [9]:
# This area is for initialing and load the saved model

# Instantiate the generators and discriminators
generator_H2S = Generator()
generator_S2H = Generator()


discriminator_S = Discriminator()
discriminator_H = Discriminator()


# Set device
device = torch.device("cuda")

generator_H2S.to(device)
generator_S2H.to(device)

discriminator_H.to(device)
discriminator_S.to(device)


# Define the optimizers
optimizer_G = optim.Adam(list(generator_H2S.parameters()) + list(generator_S2H.parameters()), lr=0.0005, betas=(0.5, 0.999))


optimizer_S = optim.Adam(discriminator_S.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_H = optim.Adam(discriminator_H.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Define the loss functions
criterion_gan = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()

In [15]:
def change_lr(optimizer_G, optimizer_H, optimizer_S, flag):
    # change learning rate to specific value
    optimizer_G_state_dict = torch.load(f'./model_params/DWT/optimizer/G/optimizer_G_{flag}.pth')
    optimizer_H_state_dict = torch.load(f'./model_params/DWT/optimizer/H/optimizer_H_{flag}.pth')
    optimizer_S_state_dict = torch.load(f'./model_params/DWT/optimizer/S/optimizer_S_{flag}.pth')
    
    optimizer_G_state_dict['param_groups'][0]['lr'] = 0.001
    optimizer_H_state_dict['param_groups'][0]['lr'] = 0.00001
    optimizer_S_state_dict['param_groups'][0]['lr'] = 0.00001

    optimizer_G.load_state_dict(optimizer_G_state_dict)
    optimizer_H.load_state_dict(optimizer_H_state_dict)
    optimizer_S.load_state_dict(optimizer_S_state_dict)

In [25]:
# reconstruct
output_dir = './reconstructed_hidden/TEST/'

load_saved_dicts(generator_H2S, generator_S2H, discriminator_S, discriminator_H, optimizer_G, optimizer_S, optimizer_H, flag=flag)

# change lr at epoch 1000
# change_lr(optimizer_G, optimizer_C, optimizer_S, flag=flag)

flag = 300
# reconstruct at epoch flag
reconstruct_hidden_images(generator_H2S, generator_S2H, test_cover_loader, test_stego_loader, output_dir, flag)


In [26]:
import cv2
import os

# Specify the directory containing the images
directory = './reconstructed_hidden/TEST'

# Loop through each file in the directory
for filename in os.listdir(directory):

    # Check if the file is an image
    if filename.endswith('.jpg') or filename.endswith('.png'):

        # Read the image
        img = cv2.imread(os.path.join(directory, filename))

        # Convert the image to grayscale
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Write the grayscale image back to the same file
        cv2.imwrite(os.path.join(directory, filename), gray_img)
