In [1]:
import pywt
import torch
import torch.nn as nn
import cv2
import numpy as np
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import os

In [2]:
class Generator(nn.Module):
    def __init__(self, num_residual_blocks=6):
        super(Generator, self).__init__()
        
        # Define encoder layers
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Define residual layers
        self.residuals = nn.Sequential(
            *[ResidualBlock(256) for _ in range(num_residual_blocks)]
        )
        
        # Define decoder layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x):
        # Apply encoder layers
        x = self.encoder(x)
        
        # Apply residual layers
        x = self.residuals(x)
        
        # Apply decoder layers
        x = self.decoder(x)
        
        return x
    
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(channels, channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        
    def forward(self, x):
        identity = x
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        
        x += identity
        x = self.relu(x)
        
        return x

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # Define discriminator layers
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, stride=1, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.model(x)

In [24]:
def lswt_steganography_encode(cover_freq, secret_freq, alpha=0.1):
    coeffs_cover = cover_freq.reshape((1, 1, *cover_freq.shape))
    coeffs_secret = secret_freq.reshape((1, 1, *secret_freq.shape))
    
    secret_hf = coeffs_secret[0][0][1]
    secret_mask = np.mean(secret_hf, axis=0)
    
    stego_hf = coeffs_cover[0][0][1] + alpha * secret_mask
    stego_freq = np.zeros_like(coeffs_cover)
    stego_freq[0][0][1] = stego_hf
    
    return stego_freq.reshape(cover_freq.shape)

def lswt_steganography_decode(stego_freq, cover_freq, alpha=0.1):
    coeffs_stego = stego_freq.reshape((1, 1, *stego_freq.shape))
    coeffs_cover = cover_freq.reshape((1, 1, *cover_freq.shape))
    
    secret_hf = [(coeffs_stego[0][0][1][i] - coeffs_cover[0][0][1][i]) / alpha for i in range(3)]
    secret_mask = np.mean(secret_hf, axis=0)
    secret_freq = np.zeros_like(coeffs_cover)
    secret_freq[0][0][1] = secret_mask
    
    return secret_freq.reshape(stego_freq.shape)

In [27]:
def train_cycle_gan(g_AB, g_BA, d_A, d_B, opt_g, opt_d, criterion_gan, criterion_cycle, criterion_identity, 
                    dataloader_cover, dataloader_hidden, dataloader_stego, 
                    device, num_epochs):
    alpha = 0.1
    lambda_cycle = 10

    for epoch in range(num_epochs):
        for i, (cover_img, hidden_img, stego_img) in enumerate(zip(dataloader_cover, dataloader_hidden, dataloader_stego)):
            # Set generator and discriminator gradients to zero
            opt_g.zero_grad()
            opt_d.zero_grad()
            
            # Move input data to device
            cover, _ = cover_img
            hidden, _ = hidden_img
            stego, _ = stego_img
            
            cover = cover.to(device)
            hidden = hidden.to(device)
            stego = stego.to(device)
            
            # Generate fake images
            fake_B = g_AB(stego)
            fake_A = g_BA(hidden)
            
            # Compute cycle consistency loss
            fake_stego = g_BA(hidden)
            fake_hidden_freq = lswt_steganography_decode(fake_stego.detach().cpu().numpy()[0, 0], cover.cpu().numpy()[0, 0], alpha=alpha)
            fake_hidden_freq = torch.from_numpy(fake_hidden_freq).unsqueeze(0).unsqueeze(0).float().to(device)
            
            fake_hidden = g_AB(stego)
            fake_stego_freq = lswt_steganography_encode(fake_hidden.detach().cpu().numpy()[0, 0], cover.cpu().numpy()[0, 0], alpha=alpha)
            fake_stego_freq = torch.from_numpy(fake_stego_freq).unsqueeze(0).unsqueeze(0).float().to(device)
            
            cycle_loss_B = criterion_cycle(fake_hidden_freq, fake_stego)
            cycle_loss_A = criterion_cycle(fake_stego_freq, fake_hidden)
            
            # Train discriminators
            real_A = d_A(hidden)
            real_B = d_B(stego)
            fake_A_ = d_A(fake_A.detach())
            fake_B_ = d_B(fake_B.detach())
            loss_d_A = criterion_gan(real_A, torch.ones_like(real_A)) + criterion_gan(fake_A_, torch.zeros_like(fake_A_))
            loss_d_B = criterion_gan(real_B, torch.ones_like(real_B)) + criterion_gan(fake_B_, torch.zeros_like(fake_B_))
            loss_d = (loss_d_A + loss_d_B) / 2
            loss_d.backward()
            opt_d.step()
            
            # Train generators
            fake_A_ = d_A(fake_A)
            fake_B_ = d_B(fake_B)
            loss_gan_A = criterion_gan(fake_A_, torch.ones_like(fake_A_))
            loss_gan_B = criterion_gan(fake_B_, torch.ones_like(fake_B_))
            loss_g = (loss_gan_A + loss_gan_B) / 2 + lambda_cycle * (cycle_loss_B + cycle_loss_A) * 0.5
            loss_g.backward()
            opt_g.step()
            
        # Print progress
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss G: {loss_g:.4f} Loss D_A: {loss_d_A:.4f} Loss D_B: {loss_d_B:.4f}")

        if (epoch+1) % 10 == 0:

            torch.save({
                'epoch': epoch+1,
                'g_AB_state_dict': g_AB.state_dict(),
                'g_BA_state_dict': g_BA.state_dict(),
                'd_A_state_dict': d_A.state_dict(),
                'd_B_state_dict': d_B.state_dict(),
                'opt_g_state_dict': opt_g.state_dict(),
                'opt_d_state_dict': opt_d.state_dict(),
                'loss': loss_g,
            }, f'./model_params/SWT/checkpoint_{num_epochs}.pth')
            

In [29]:
# Set device
device = torch.device('cuda')

# Load the datasets

transform_frequency = transforms.Compose([
    transforms.Resize((96, 128)),
    transforms.ToTensor(),
    lambda x: pywt.swt2(x.numpy(), 'bior1.1', level=1)[0][0],
])

# Train set
dataset_cover = ImageFolder('./train/cover', transform=transform_frequency)
dataset_stego = ImageFolder('./train/LSWT_stego', transform=transform_frequency)
dataset_hidden = ImageFolder('./train/LSWT_hidden', transform=transform_frequency)

cover_loader = DataLoader(dataset_cover, batch_size=1)
stego_loader = DataLoader(dataset_stego, batch_size=1)
hidden_loader = DataLoader(dataset_hidden, batch_size=1)

# Test set
dataset_cover_test = ImageFolder('./test/cover', transform=transform_frequency)
dataset_stego_test = ImageFolder('./test/LSWT_stego', transform=transform_frequency)
dataset_hidden_test = ImageFolder('./test/LSWT_hidden', transform=transform_frequency)

test_cover_loader = DataLoader(dataset_cover_test, batch_size=1)
test_stego_loader = DataLoader(dataset_stego_test, batch_size=1)
test_hidden_loader = DataLoader(dataset_hidden_test, batch_size=1)


In [7]:
# Define the generators and discriminators
g_AB = Generator().to(device)
g_BA = Generator().to(device)
d_A = Discriminator().to(device)
d_B = Discriminator().to(device)

# Define the optimizers
opt_g = optim.Adam(list(g_AB.parameters()) + list(g_BA.parameters()), lr=0.0005, betas=(0.5, 0.999))
opt_d = optim.Adam(list(d_A.parameters()) + list(d_B.parameters()), lr=0.00001, betas=(0.5, 0.999))

# Define the loss functions
criterion_gan = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

In [None]:
# Train the model
train_cycle_gan(g_AB, g_BA, d_A, d_B, opt_g, opt_d, criterion_gan, criterion_cycle, criterion_identity,
                cover_loader, hidden_loader, stego_loader,
                device, num_epochs=500)

In [8]:
def load_model(checkpoint_path, g_AB, g_BA, d_A, d_B, opt_g, opt_d):
    checkpoint = torch.load(checkpoint_path)

    g_AB.load_state_dict(checkpoint['g_AB_state_dict'])
    g_BA.load_state_dict(checkpoint['g_BA_state_dict'])
    d_A.load_state_dict(checkpoint['d_A_state_dict'])
    d_B.load_state_dict(checkpoint['d_B_state_dict'])
    opt_g.load_state_dict(checkpoint['opt_g_state_dict'])
    opt_d.load_state_dict(checkpoint['opt_d_state_dict'])

    return checkpoint['epoch'], checkpoint['loss']

In [9]:
epoch, loss = load_model(f'./model_params/SWT/checkpoint_500.pth', g_AB, g_BA, d_A, d_B, opt_g, opt_d)

In [20]:
from torchvision.utils import save_image
from PIL import Image

In [39]:
import numpy as np
from PIL import Image
from torchvision.transforms.functional import to_tensor

def tensor_to_image(tensor):
    # Perform inverse stationary wavelet transform (iswt2)
    img_iswt = pywt.iswt2([(tensor.numpy(), None, None, None)], 'bior1.1')
    
    # Convert the image back to the PIL format
    img_pil = Image.fromarray(np.uint8(img_iswt * 255), mode='L')
    
    return img_pil


In [49]:
transform_frequency = transforms.Compose([
    transforms.Resize((96, 128)),
    lambda x: pywt.swt2(to_tensor(x).numpy(), 'bior1.1', level=1)[0],
])

def reconstruct_images(input_dir, output_dir, g_BA, device):
    # Ensure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Read and process images from the input directory
    for file in os.listdir(input_dir):
        if file.endswith(('.jpg', '.png', '.jpeg')):
            file_path = os.path.join(input_dir, file)
            
            # Load the image and apply the same transformation as in the dataloader
            img = Image.open(file_path)
            img_frequency = torch.tensor(transform_frequency(img), dtype=torch.float32).to(device)
            
            # Generate hidden images using G_BA
            with torch.no_grad():
                img_hidden = g_BA(Variable(img_frequency)).cpu()
            
            # Clip back to the image using inverse SWT
            img_reconstructed = pywt.iswt2([(img_hidden.squeeze().numpy(), None, None, None)], 'bior1.1')
            
            # Save the reconstructed image to the output directory
            output_path = os.path.join(output_dir, file)
            to_pil_image(img_reconstructed, mode='L').save(output_path)


# Usage example:
input_dir = './test/LSWT_stego/class1'
output_dir = './output_images/'
reconstruct_images(input_dir, output_dir, g_BA, device)


ValueError: expected sequence of length 1 at dim 1 (got 3)