In [None]:
''' LIBRARIES '''
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
class Mish(nn.Module):
  """Mish activation, applied element-wise: ``x * tanh(softplus(x))``.

  Parameter-free, so it can sit anywhere inside an ``nn.Sequential``.
  """

  def forward(self, x):
    smooth = F.softplus(x)
    return torch.tanh(smooth).mul(x)

In [None]:
# Encoder from SteganoGAN, this time re-implemented with Mish() instead of LeakyReLU

class BasicMishEncoder(nn.Module):
  """
  Hide a data tensor inside a cover image (SteganoGAN basic encoder, Mish variant).

  The cover image is first lifted to ``hidden_size`` feature channels. The data
  tensor is concatenated onto those features along the channel axis, and a conv
  stack maps the result back to 3 channels, ending in ``Tanh``.

  Subclasses override ``_build_models`` (and ``add_image``). ``forward`` gives
  every stage after the first the channel-wise concatenation of all earlier
  stage outputs plus ``data``.

  Input: (N, 3, H, W), (N, D, H, W)
  Output: (N, 3, H, W)
  """

  # If True, forward() adds the cover image to the network output.
  add_image = False

  def __init__(self, data_depth, hidden_size):
    super().__init__()
    self.data_depth = data_depth
    self.hidden_size = hidden_size
    self._models = self._build_models()

  def _conv2d(self, in_channels, out_channels):
    # A 3x3 kernel with padding 1 keeps the spatial size unchanged.
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

  def _build_models(self):
    width = self.hidden_size
    self.features = nn.Sequential(
      self._conv2d(3, width), Mish(), nn.BatchNorm2d(width),
    )
    # The first conv sees the image features concatenated with the data tensor.
    self.layers = nn.Sequential(
      self._conv2d(width + self.data_depth, width), Mish(), nn.BatchNorm2d(width),
      self._conv2d(width, width), Mish(), nn.BatchNorm2d(width),
      self._conv2d(width, 3), nn.Tanh(),
    )
    return self.features, self.layers

  def forward(self, image, data):
    head, *tail = self._models
    out = head(image)
    history = [out]
    for stage in tail:
      out = stage(torch.cat([*history, data], dim=1))
      history.append(out)
    return image + out if self.add_image else out



class ResidualMishEncoder(BasicMishEncoder):
  """
  The ResidualMishEncoder module takes a cover image and a data tensor and
  combines them into a steganographic image.

  It uses the same layer layout as BasicMishEncoder, but the final convolution
  has no Tanh. Because ``add_image = True``, the network output is added to the
  cover image, so the network learns a residual on top of the cover.

  Input: (N, 3, H, W), (N, D, H, W)
  Output: (N, 3, H, W)
  """

  # The inherited forward() returns image + network output.
  add_image = True

  def _build_models(self):
    self.features = nn.Sequential(
      self._conv2d(3, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size),
    )
    # The first conv sees the image features concatenated with the data tensor.
    self.layers = nn.Sequential(
      self._conv2d(self.hidden_size + self.data_depth, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size),
      self._conv2d(self.hidden_size, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size),
      # No output activation: this is a residual added to the cover image.
      self._conv2d(self.hidden_size, 3),
    )
    return self.features, self.layers


class DenseMishEncoder(BasicMishEncoder):
  """
  The DenseMishEncoder module takes a cover image and a data tensor and
  combines them into a steganographic image.

  This is the densely connected variant. With the forward() inherited from
  BasicMishEncoder, each of conv2..conv4 receives the channel-wise
  concatenation of every earlier stage's output plus the data tensor. That is
  why their input widths grow by ``hidden_size`` per stage. The final
  convolution's output is added to the cover image (``add_image = True``).

  Input: (N, 3, H, W), (N, D, H, W)
  Output: (N, 3, H, W)
  """

  # The inherited forward() returns image + network output.
  add_image = True

  def _build_models(self):
    self.conv1 = nn.Sequential(
      self._conv2d(3, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size),
    )
    # Input: conv1 output + data.
    self.conv2 = nn.Sequential(
      self._conv2d(self.hidden_size + self.data_depth, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size),
    )
    # Input: conv1 + conv2 outputs + data.
    self.conv3 = nn.Sequential(
      self._conv2d(self.hidden_size * 2 + self.data_depth, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size),
    )
    # Input: conv1 + conv2 + conv3 outputs + data; projects back to 3 channels.
    # It is wrapped in a Sequential, so its state_dict keys are ``conv4.0.*``.
    self.conv4 = nn.Sequential(
      self._conv2d(self.hidden_size * 3 + self.data_depth, 3)
    )

    return self.conv1, self.conv2, self.conv3, self.conv4

In [None]:
# Decoder from SteganoGAN, this time re-implemented with Mish() instead of LeakyReLU

class BasicMishDecoder(nn.Module):
  """
  Recover the embedded data tensor from a steganographic image
  (SteganoGAN basic decoder, Mish variant).

  This is a plain stack of 3x3 convolutions: three conv -> Mish -> BatchNorm
  stages at ``hidden_size`` channels, then a final convolution projecting to
  ``data_depth`` channels with no output activation.

  Subclasses may return several stages from ``_build_models``. ``forward`` then
  gives each later stage the channel-wise concatenation of all earlier stage
  outputs.

  Input: (N, 3, H, W)
  Output: (N, D, H, W)
  """

  def __init__(self, data_depth, hidden_size):
    super().__init__()
    self.data_depth = data_depth
    self.hidden_size = hidden_size

    self._models = self._build_models()

  def _conv2d(self, in_channels, out_channels):
    # A 3x3 kernel with padding 1 keeps the spatial size unchanged.
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

  def _build_models(self):
    width = self.hidden_size
    stages = []
    for in_channels in (3, width, width):
      stages.extend([self._conv2d(in_channels, width), Mish(), nn.BatchNorm2d(width)])
    stages.append(self._conv2d(width, self.data_depth))
    self.layers = nn.Sequential(*stages)

    # A single-stage list: forward() applies it once and skips the dense loop.
    return [self.layers]

  def forward(self, x):
    first, *rest = self._models
    out = first(x)
    seen = [out]
    for stage in rest:
      out = stage(torch.cat(seen, dim=1))
      seen.append(out)
    return out


class DenseMishDecoder(BasicMishDecoder):
  """
  The DenseMishDecoder module takes a steganographic image and attempts to
  decode the embedded data tensor.

  This is the densely connected variant. With the forward() inherited from
  BasicMishDecoder, conv2..conv4 each receive the channel-wise concatenation of
  every earlier stage's output. That is why their input widths are 1x, 2x and
  3x ``hidden_size``.

  Input: (N, 3, H, W)
  Output: (N, D, H, W)
  """

  def _build_models(self):
    self.conv1 = nn.Sequential(
      self._conv2d(3, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size)
    )

    # Input: conv1 output.
    self.conv2 = nn.Sequential(
      self._conv2d(self.hidden_size, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size)
    )

    # Input: conv1 + conv2 outputs.
    self.conv3 = nn.Sequential(
      self._conv2d(self.hidden_size * 2, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size)
    )

    # Input: conv1 + conv2 + conv3 outputs; projects to data_depth channels.
    self.conv4 = nn.Sequential(self._conv2d(self.hidden_size * 3, self.data_depth))

    return self.conv1, self.conv2, self.conv3, self.conv4

In [None]:
''' TEST '''

import torch
from PIL import Image
from torchvision import transforms
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Set the parameters for testing
hidden_size = 64
data_depth = 3

# Preprocess the cover image: resize, convert to a [0, 1] tensor, then
# normalize each channel to [-1, 1].
transform = transforms.Compose([
  transforms.Resize((224, 224)),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load a cover image from Google Drive. The context manager closes the
# underlying file handle; convert() returns an in-memory copy that stays usable.
cover_image_path = '/content/drive/My Drive/DU/COMP 3432 Machine Learning/Imgs/Cover/100000.jpg'
with Image.open(cover_image_path) as raw_cover:
  cover_image = raw_cover.convert('RGB')
cover_image_tensor = transform(cover_image).unsqueeze(0)  # Add batch dimension

# Create a dummy data tensor whose batch and spatial size match the cover image,
# so the encoder can concatenate them along the channel axis.
dummy_data_tensor = torch.randn(cover_image_tensor.size(0), data_depth, *cover_image_tensor.shape[-2:])

print("Cover Image Tensor shape:", cover_image_tensor.shape)
print("Data Tensor shape:", dummy_data_tensor.shape)


def test_basic_encoder_decoder_with_images(cover_image_tensor, dummy_data_tensor):
  """Round-trip a cover image through the Mish basic encoder/decoder and check shapes.

  Only shapes are checked. The networks are freshly initialised (untrained), so
  the decoded values are not expected to match ``dummy_data_tensor``.

  Args:
    cover_image_tensor: cover image batch, expected (N, 3, H, W).
    dummy_data_tensor: payload batch, expected (N, D, H, W).
  """
  # Take the payload depth from the tensor itself so the models always match it.
  # hidden_size comes from the notebook-level setting.
  depth = dummy_data_tensor.size(1)
  basic_encoder = BasicMishEncoder(data_depth=depth, hidden_size=hidden_size)
  basic_decoder = BasicMishDecoder(data_depth=depth, hidden_size=hidden_size)

  # This is a shape check only, so skip building the autograd graph.
  with torch.no_grad():
    # Encode the data into the cover image
    stego_image = basic_encoder(cover_image_tensor, dummy_data_tensor)
    print("Stego Image shape:", stego_image.shape)

    # Decode the stego image to retrieve the data
    retrieved_data = basic_decoder(stego_image)
    print("Retrieved Data shape:", retrieved_data.shape)

  assert retrieved_data.shape == dummy_data_tensor.shape, "Basic Encoder-Decoder test failed!"
  print("Basic Encoder-Decoder test passed.")

def test_dense_encoder_decoder_with_images(cover_image_tensor, dummy_data_tensor):
  """Round-trip a cover image through the Mish dense encoder/decoder and check shapes.

  Only shapes are checked. The networks are freshly initialised (untrained), so
  the decoded values are not expected to match ``dummy_data_tensor``.

  Args:
    cover_image_tensor: cover image batch, expected (N, 3, H, W).
    dummy_data_tensor: payload batch, expected (N, D, H, W).
  """
  # Take the payload depth from the tensor itself so the models always match it.
  # hidden_size comes from the notebook-level setting.
  depth = dummy_data_tensor.size(1)
  dense_encoder = DenseMishEncoder(data_depth=depth, hidden_size=hidden_size)
  dense_decoder = DenseMishDecoder(data_depth=depth, hidden_size=hidden_size)

  # This is a shape check only, so skip building the autograd graph.
  with torch.no_grad():
    # Encode the data into the cover image
    dense_stego_image = dense_encoder(cover_image_tensor, dummy_data_tensor)
    print("Dense Stego Image shape:", dense_stego_image.shape)

    # Decode the stego image to retrieve the data
    dense_retrieved_data = dense_decoder(dense_stego_image)
    print("Dense Retrieved Data shape:", dense_retrieved_data.shape)

  assert dense_retrieved_data.shape == dummy_data_tensor.shape, "Dense Encoder-Decoder test failed!"
  print("Dense Encoder-Decoder test passed.")

# Run both shape tests, in order, against the real cover image and dummy payload
for run_shape_test in (
  test_basic_encoder_decoder_with_images,
  test_dense_encoder_decoder_with_images,
):
  run_shape_test(cover_image_tensor, dummy_data_tensor)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Cover Image Tensor shape: torch.Size([1, 3, 224, 224])
Data Tensor shape: torch.Size([1, 3, 224, 224])
Stego Image shape: torch.Size([1, 3, 224, 224])
Retrieved Data shape: torch.Size([1, 3, 224, 224])
Basic Encoder-Decoder test passed.
Dense Stego Image shape: torch.Size([1, 3, 224, 224])
Dense Retrieved Data shape: torch.Size([1, 3, 224, 224])
Dense Encoder-Decoder test passed.
