In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """Five-level U-Net: stride-2 conv encoder, transposed-conv decoder, skip connections.

    Each encoder stage halves H and W and each decoder stage doubles them, so
    the output has the input's spatial size; H and W must be divisible by 32.
    The output passes through a sigmoid, so values lie in [0, 1].

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output image (default 3).
        base_channels: width of the first encoder stage; stage k (0-based)
            uses ``base_channels * 2**k`` channels (default 64 -> 64..1024).
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super(UNet, self).__init__()
        c = base_channels
        # Encoder: channel count doubles while the resolution halves.
        self.down_conv1 = self.double_conv(in_channels, c)
        self.down_conv2 = self.double_conv(c, 2 * c)
        self.down_conv3 = self.double_conv(2 * c, 4 * c)
        self.down_conv4 = self.double_conv(4 * c, 8 * c)
        self.down_conv5 = self.double_conv(8 * c, 16 * c)

        # Decoder: every stage must UP-sample so its output matches the
        # resolution of the encoder feature it is concatenated with. Input
        # widths are (previous decoder output + skip) channels.
        self.up_conv1 = self.deconv(16 * c, 8 * c)
        self.up_conv2 = self.deconv(8 * c + 8 * c, 4 * c)
        self.up_conv3 = self.deconv(4 * c + 4 * c, 2 * c)
        self.up_conv4 = self.deconv(2 * c + 2 * c, c)
        self.up_conv5 = nn.ConvTranspose2d(c + c, out_channels, kernel_size=4, stride=2, padding=1)

    def double_conv(self, in_channels, out_channels):
        """Encoder stage: one stride-2 4x4 conv (halves H and W) + BatchNorm + LeakyReLU(0.2).

        Despite the name, this stage contains a single convolution.
        """
        conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
        return conv

    def deconv(self, in_channels, out_channels):
        """Decoder stage: stride-2 4x4 transposed conv (doubles H and W) + BatchNorm + ReLU."""
        deconv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        return deconv

    def forward(self, x):
        """Map an (N, in_channels, H, W) tensor to an (N, out_channels, H, W) tensor in [0, 1]."""
        # Encoding path (x1 at H/2 ... x5 at H/32)
        x1 = self.down_conv1(x)
        x2 = self.down_conv2(x1)
        x3 = self.down_conv3(x2)
        x4 = self.down_conv4(x3)
        x5 = self.down_conv5(x4)

        # Decoding path: upsample, then concatenate the same-resolution skip
        x = self.up_conv1(x5)
        x = torch.cat([x, x4], dim=1)
        x = self.up_conv2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.up_conv3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up_conv4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.up_conv5(x)
        x = torch.sigmoid(x)

        return x


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt

class UNet(nn.Module):
    """U-Net that operates in the single-level Haar-wavelet domain.

    ``forward`` splits every input channel into four half-resolution Haar
    subbands (cA, cH, cV, cD) and stacks them along the channel axis. It then
    runs an encoder/decoder U-Net with skip connections on that stack,
    reassembles the predicted subbands with the inverse Haar transform, and
    applies a sigmoid.

    The wavelet step halves H and W, and each of the five encoder stages halves
    them again, so input H and W must be divisible by 64. The output has the
    same shape as the input.

    The transform is written in torch (not pywt) so it accepts tensors and is
    differentiable.

    Args:
        in_channels: image channels (default 3). The convolutional network
            sees ``4 * in_channels`` channels (one group per subband).
        base_channels: width of the first encoder stage; stage k (0-based)
            uses ``base_channels * 2**k`` channels (default 64 -> 64..1024).
    """

    def __init__(self, in_channels=3, base_channels=64):
        super(UNet, self).__init__()
        c = base_channels
        subband_channels = 4 * in_channels

        # Encoder: channel count doubles while the resolution halves.
        self.down_conv1 = self.double_conv(subband_channels, c)
        self.down_conv2 = self.double_conv(c, 2 * c)
        self.down_conv3 = self.double_conv(2 * c, 4 * c)
        self.down_conv4 = self.double_conv(4 * c, 8 * c)
        self.down_conv5 = self.double_conv(8 * c, 16 * c)

        # Decoder: every stage up-samples so it lines up with its skip
        # connection; input widths are (previous output + skip) channels.
        self.up_conv1 = self.deconv(16 * c, 8 * c)
        self.up_conv2 = self.deconv(8 * c + 8 * c, 4 * c)
        self.up_conv3 = self.deconv(4 * c + 4 * c, 2 * c)
        self.up_conv4 = self.deconv(2 * c + 2 * c, c)
        # Predicts one group of channels per subband for the inverse transform.
        self.up_conv5 = nn.ConvTranspose2d(c + c, subband_channels, kernel_size=4, stride=2, padding=1)

    def double_conv(self, in_channels, out_channels):
        """Encoder stage: stride-2 4x4 conv (halves H and W), then a size-preserving 3x3 conv.

        Each conv is followed by BatchNorm and LeakyReLU(0.2).
        """
        conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            # Stride 1 here: a second stride-2 conv would shrink each stage 4x
            # and the decoder could no longer match the skip resolutions.
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
        return conv

    def deconv(self, in_channels, out_channels):
        """Decoder stage: stride-2 4x4 transposed conv (doubles H and W) + BatchNorm + ReLU."""
        deconv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        return deconv

    @staticmethod
    def _haar_dwt2(x):
        """Single-level orthonormal 2-D Haar transform over the last two axes.

        Returns ``(cA, cH, cV, cD)``, each of shape ``(..., H/2, W/2)``.
        H and W must be even.
        """
        x00 = x[..., 0::2, 0::2]
        x01 = x[..., 0::2, 1::2]
        x10 = x[..., 1::2, 0::2]
        x11 = x[..., 1::2, 1::2]
        cA = (x00 + x01 + x10 + x11) / 2
        cH = (x00 + x01 - x10 - x11) / 2
        cV = (x00 - x01 + x10 - x11) / 2
        cD = (x00 - x01 - x10 + x11) / 2
        return cA, cH, cV, cD

    @staticmethod
    def _haar_idwt2(cA, cH, cV, cD):
        """Inverse of ``_haar_dwt2``: rebuild ``(..., 2h, 2w)`` from four ``(..., h, w)`` subbands."""
        x00 = (cA + cH + cV + cD) / 2
        x01 = (cA + cH - cV - cD) / 2
        x10 = (cA - cH + cV - cD) / 2
        x11 = (cA - cH - cV + cD) / 2
        # Interleave columns (even/odd), then rows (even/odd), back to full resolution.
        even_rows = torch.stack([x00, x01], dim=-1).flatten(-2)
        odd_rows = torch.stack([x10, x11], dim=-1).flatten(-2)
        return torch.stack([even_rows, odd_rows], dim=-2).flatten(-3, -2)

    def forward(self, x):
        """Map an (N, C, H, W) tensor to an (N, C, H, W) tensor in [0, 1]."""
        # Wavelet transform of the secret image: stack the four subbands
        # along channels -> (N, 4*C, H/2, W/2).
        cA, cH, cV, cD = self._haar_dwt2(x)
        x = torch.cat([cA, cH, cV, cD], dim=1)

        # Encoding path
        x1 = self.down_conv1(x)
        x2 = self.down_conv2(x1)
        x3 = self.down_conv3(x2)
        x4 = self.down_conv4(x3)
        x5 = self.down_conv5(x4)

        # Decoding path
        x = self.up_conv1(x5)
        x = torch.cat([x, x4], dim=1)
        x = self.up_conv2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.up_conv3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up_conv4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.up_conv5(x)

        # Inverse wavelet transform: split the prediction back into the
        # subbands (same order as stacked above) and reassemble the image.
        cA, cH, cV, cD = torch.chunk(x, 4, dim=1)
        x = self._haar_idwt2(cA, cH, cV, cD)

        x = torch.sigmoid(x)

        return x


In [None]:
import cv2
import numpy as np
import torch
import torch.nn.functional as F

# The U-Nets above halve H and W several times before upsampling again, so the
# input is padded up to a multiple of this value and cropped back afterwards.
SIZE_MULTIPLE = 64

# Load carrier image and secret image (cv2.imread returns None instead of raising)
carrier = cv2.imread('carrier_image.png')
secret = cv2.imread('secret_image.png', cv2.IMREAD_GRAYSCALE)
if carrier is None or secret is None:
    raise FileNotFoundError('Could not read carrier_image.png and/or secret_image.png')

# Convert to PyTorch tensors and normalize to [0, 1]
carrier = torch.from_numpy(carrier).permute(2, 0, 1).float() / 255  # (3, H, W), BGR order
secret = torch.from_numpy(secret).unsqueeze(0).unsqueeze(0).float() / 255  # (1, 1, H, W)

# UNet() defaults to 3 input channels; replicate the grayscale plane.
secret_rgb = secret.repeat(1, 3, 1, 1)

# Pad bottom/right so H and W are multiples of SIZE_MULTIPLE and every skip connection lines up.
height, width = secret_rgb.shape[-2:]
secret_padded = F.pad(
    secret_rgb,
    (0, (-width) % SIZE_MULTIPLE, 0, (-height) % SIZE_MULTIPLE),
    mode='replicate',
)

# Instantiate UNet model and pass secret image through it.
# eval(): BatchNorm uses running statistics (batch size is 1); no_grad(): inference only.
unet = UNet()
unet.eval()
with torch.no_grad():
    output = unet(secret_padded)

# Crop the padding, convert (1, 3, H, W) -> (H, W, 3) uint8 as cv2 expects, and save.
output = output[0, :, :height, :width].permute(1, 2, 0).cpu().numpy()
output = (output * 255).clip(0, 255).astype(np.uint8)
if not cv2.imwrite('output_image.png', output):
    raise OSError('Could not write output_image.png')


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Define the UNet model
class UNet(nn.Module):
    """pix2pix-style U-Net generator: 8 stride-2 encoder convs, 8 stride-2 decoder deconvs.

    Each encoder layer halves H and W and each decoder layer doubles them, so
    the output has the input's spatial size; H and W must be divisible by 256.
    Decoder layers 1-7 are concatenated with the mirrored encoder feature
    (enc7 ... enc1). The output passes through a sigmoid, so values lie in [0, 1].

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels of the output image (default 3).
        base_channels: width of the first encoder layer; deeper layers use
            1, 2, 4, 8, 8, 8, 8, 8 times this value (default 64 -> 64..512).
    """

    def __init__(self, in_channels=1, out_channels=3, base_channels=64):
        super(UNet, self).__init__()
        c1, c2, c4, c8 = base_channels, 2 * base_channels, 4 * base_channels, 8 * base_channels

        # Define encoder layers (each halves H and W).
        # The outermost layer has no normalization, as in pix2pix.
        self.conv1 = nn.Conv2d(in_channels, c1, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(c1, c2, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(c2)
        self.conv3 = nn.Conv2d(c2, c4, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(c4)
        self.conv4 = nn.Conv2d(c4, c8, 4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(c8)
        self.conv5 = nn.Conv2d(c8, c8, 4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(c8)
        self.conv6 = nn.Conv2d(c8, c8, 4, stride=2, padding=1)
        self.bn6 = nn.BatchNorm2d(c8)
        self.conv7 = nn.Conv2d(c8, c8, 4, stride=2, padding=1)
        self.bn7 = nn.BatchNorm2d(c8)
        # Bottleneck. The decoder below consumes 7 skip connections, which needs
        # 8 encoder layers. It has no BatchNorm because the map is 1x1 for a
        # 256x256 input, and BatchNorm needs >1 value per channel in training.
        self.conv8 = nn.Conv2d(c8, c8, 4, stride=2, padding=1)

        # Define decoder layers (each doubles H and W). Input widths after the
        # first layer are (previous deconv output + matching skip) channels.
        self.deconv1 = nn.ConvTranspose2d(c8, c8, 4, stride=2, padding=1)
        self.bn8 = nn.BatchNorm2d(c8)
        self.deconv2 = nn.ConvTranspose2d(c8 + c8, c8, 4, stride=2, padding=1)
        self.bn9 = nn.BatchNorm2d(c8)
        self.deconv3 = nn.ConvTranspose2d(c8 + c8, c8, 4, stride=2, padding=1)
        self.bn10 = nn.BatchNorm2d(c8)
        self.deconv4 = nn.ConvTranspose2d(c8 + c8, c8, 4, stride=2, padding=1)
        self.bn11 = nn.BatchNorm2d(c8)
        self.deconv5 = nn.ConvTranspose2d(c8 + c8, c4, 4, stride=2, padding=1)
        self.bn12 = nn.BatchNorm2d(c4)
        self.deconv6 = nn.ConvTranspose2d(c4 + c4, c2, 4, stride=2, padding=1)
        self.bn13 = nn.BatchNorm2d(c2)
        self.deconv7 = nn.ConvTranspose2d(c2 + c2, c1, 4, stride=2, padding=1)
        self.bn14 = nn.BatchNorm2d(c1)
        self.deconv8 = nn.ConvTranspose2d(c1 + c1, out_channels, 4, stride=2, padding=1)

        # Define activation functions (not in-place: encoder outputs are reused as skips)
        self.relu = nn.ReLU()
        self.lrelu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map an (N, in_channels, H, W) tensor to an (N, out_channels, H, W) tensor in [0, 1]."""
        # Encode: LeakyReLU before every conv except the first
        enc1 = self.conv1(x)
        enc2 = self.bn2(self.conv2(self.lrelu(enc1)))
        enc3 = self.bn3(self.conv3(self.lrelu(enc2)))
        enc4 = self.bn4(self.conv4(self.lrelu(enc3)))
        enc5 = self.bn5(self.conv5(self.lrelu(enc4)))
        enc6 = self.bn6(self.conv6(self.lrelu(enc5)))
        enc7 = self.bn7(self.conv7(self.lrelu(enc6)))
        enc8 = self.conv8(self.lrelu(enc7))

        # Decode: ReLU -> deconv -> BatchNorm (sized for the deconv output),
        # THEN concatenate the same-resolution encoder feature.
        dec1 = torch.cat([self.bn8(self.deconv1(self.relu(enc8))), enc7], dim=1)
        dec2 = torch.cat([self.bn9(self.deconv2(self.relu(dec1))), enc6], dim=1)
        dec3 = torch.cat([self.bn10(self.deconv3(self.relu(dec2))), enc5], dim=1)
        dec4 = torch.cat([self.bn11(self.deconv4(self.relu(dec3))), enc4], dim=1)
        dec5 = torch.cat([self.bn12(self.deconv5(self.relu(dec4))), enc3], dim=1)
        dec6 = torch.cat([self.bn13(self.deconv6(self.relu(dec5))), enc2], dim=1)
        dec7 = torch.cat([self.bn14(self.deconv7(self.relu(dec6))), enc1], dim=1)
        dec8 = self.deconv8(self.relu(dec7))
        out = self.sigmoid(dec8)
        return out



In [None]:
import os

import torch
from torch.utils.data import Dataset


class CarrierSecretDataset(Dataset):
    """Pairs the i-th carrier image with the i-th secret image (files matched by sorted name).

    Each item is ``(carrier, secret)`` as float tensors scaled to [0, 1]:
    carrier ``(3, H, W)`` in RGB order and secret ``(1, H, W)`` grayscale, so
    the default DataLoader collate can batch them (all images must share one size).

    Args:
        carrier_dir: directory containing the carrier images.
        secret_dir: directory containing the secret images; must hold the
            same number of files as ``carrier_dir``.
        secret_transform: optional callable applied to each secret tensor
            (e.g. a wavelet transform); ``None`` leaves it unchanged.

    Raises:
        ValueError: if the two directories hold different numbers of files.
    """

    def __init__(self, carrier_dir, secret_dir, secret_transform=None):
        self.carrier_files = sorted(os.listdir(carrier_dir))
        self.secret_files = sorted(os.listdir(secret_dir))
        # Pairing is by index, so a count mismatch would silently misalign
        # pairs or raise IndexError deep inside a DataLoader worker.
        if len(self.carrier_files) != len(self.secret_files):
            raise ValueError(
                'carrier_dir has {} files but secret_dir has {}'.format(
                    len(self.carrier_files), len(self.secret_files)))
        self.carrier_dir = carrier_dir
        self.secret_dir = secret_dir
        self.secret_transform = secret_transform

    def __len__(self):
        return len(self.carrier_files)

    def __getitem__(self, idx):
        # OpenCV (used elsewhere in this notebook for image I/O) is imported
        # here so the class can be defined without it.
        import cv2

        carrier_file = os.path.join(self.carrier_dir, self.carrier_files[idx])
        secret_file = os.path.join(self.secret_dir, self.secret_files[idx])
        carrier_bgr = cv2.imread(carrier_file, cv2.IMREAD_COLOR)
        secret_gray = cv2.imread(secret_file, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if carrier_bgr is None or secret_gray is None:
            raise FileNotFoundError('Could not read {} or {}'.format(carrier_file, secret_file))

        carrier_rgb = cv2.cvtColor(carrier_bgr, cv2.COLOR_BGR2RGB)
        carrier_img = torch.from_numpy(carrier_rgb).permute(2, 0, 1).float() / 255
        secret_img = torch.from_numpy(secret_gray).unsqueeze(0).float() / 255
        if self.secret_transform is not None:
            secret_img = self.secret_transform(secret_img)
        return carrier_img, secret_img


In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Configuration: these names were previously undefined in the notebook.
# Adjust paths and hyperparameters to your setup.
train_carrier_dir = 'data/train/carrier'
train_secret_dir = 'data/train/secret'
test_carrier_dir = 'data/test/carrier'
test_secret_dir = 'data/test/secret'
model_dir = 'checkpoints'
batch_size = 8
learning_rate = 2e-4
num_epochs = 50
print_freq = 10   # log every N training iterations
save_freq = 5     # checkpoint every N epochs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE: the default collate needs all images in a split to share one size, and
# the U-Net needs H and W divisible by its total down-sampling factor.
train_dataset = CarrierSecretDataset(train_carrier_dir, train_secret_dir)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = CarrierSecretDataset(test_carrier_dir, test_secret_dir)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Fresh model from the most recently defined UNet class (grayscale secret in,
# carrier-shaped RGB image out), rather than the stale instance from an
# earlier cell. Moved to the device before the optimizer captures its parameters.
unet = UNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(unet.parameters(), lr=learning_rate)
os.makedirs(model_dir, exist_ok=True)

for epoch in range(num_epochs):
    # Train
    unet.train()
    for i, (carriers, secrets) in enumerate(train_loader):
        carriers = carriers.to(device)
        secrets = secrets.to(device)
        optimizer.zero_grad()
        # The model consumes the 1-channel secret and is trained to produce the carrier.
        outputs = unet(secrets)
        loss = criterion(outputs, carriers)
        loss.backward()
        optimizer.step()

        if i % print_freq == 0:
            print('Epoch [{}/{}], Iter [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))

    # Test (once per epoch, after the training pass)
    unet.eval()
    with torch.no_grad():
        test_loss = 0.0
        for carriers, secrets in test_loader:
            carriers = carriers.to(device)
            secrets = secrets.to(device)

            outputs = unet(secrets)
            # Weight by batch size so the average is per sample, not per batch.
            test_loss += criterion(outputs, carriers).item() * carriers.size(0)

        test_loss /= len(test_loader.dataset)
    print('Test Loss: {:.4f}'.format(test_loss))

    # Save the model
    if (epoch+1) % save_freq == 0:
        model_path = os.path.join(model_dir, 'unet_epoch{}.pth'.format(epoch+1))
        torch.save(unet.state_dict(), model_path)

print('Finished training the U-Net model.')

