In [17]:
''' LIBRARIES '''
import torch
from torch import nn
from torchvision.models import densenet, resnet
import torch.nn.functional as F

In [12]:
# SteganoGAN Critic
'''
BasicCritic
Implemented as a subclass of nn.Module
It takes an image and predicts whether it is a cover image or a steganographic image (N, 1).

Provided by SteganoGAN:
  Input: (N, 3, H, W)
  Output: (N, 1)

Input Details
  Takes a batch of images as input
  Parameters - (batch size, num channels, image height, image width)

Output Details
  Outputs a tensor of shape (N, 1)


class BasicCritic(nn.Module):
  """ Conv 2D
  Helper function which takes in a 2D convolutional layer of fixed size 3x3
  It then standardizes the size of the kernel throughout the model

  Parameters:
    in_channels - number of input channels. Since this is using images, it is going to be 3 for R, G, and B values
    out_channels - number of output channels desired

  Details:
    Operation - nn.Conv2d extracts spatial features of the input image
    kernel_size - ensures a fixed size of
  """
  def _conv2d(self, in_channels, out_channels):
    return nn.Conv2d(
      in_channels=in_channels,
      out_channels=out_channels,
      kernel_size=3
    )

  """ Build Models
  Constructs the network by stacking it layer-by-layer using conv2d in terms of:
    Convolutions
    Activations
    Batch normalizations
  """
  def _build_models(self):
    return nn.Sequential( # all returned as one sequential layer
      # 1
      self._conv2d(3, self.hidden_size),
      nn.LeakyReLU(inplace=True),
      nn.BatchNorm2d(self.hidden_size),

      # 2
      self._conv2d(self.hidden_size, self.hidden_size),
      nn.LeakyReLU(inplace=True),
      nn.BatchNorm2d(self.hidden_size),

      # 3
      self._conv2d(self.hidden_size, self.hidden_size),
      nn.LeakyReLU(inplace=True),
      nn.BatchNorm2d(self.hidden_size),

      # 4
      self._conv2d(self.hidden_size, 1)
    )

  def __init__(self, hidden_size):
    super().__init__()
    self.version = '1'
    self.hidden_size = hidden_size
    self._models = self._build_models()

  def upgrade_legacy(self):
    """Transform legacy pretrained models to make them usable with new code versions."""
    # Transform to version 1
    if not hasattr(self, 'version'):
        self._models = self.layers
        self.version = '1'

  def forward(self, x):
    x = self._models(x)
    x = torch.mean(x.view(x.size(0), -1), dim=1)

    return x

'''

'\nBasicCritic\nImplemented as a subclass of nn.Module\nIt takes an image and predicts whether it is a cover image or a steganographic image (N, 1).\n\nProvided by SteganoGAN:\n  Input: (N, 3, H, W)\n  Output: (N, 1)\n\nInput Details\n  Takes a batch of images as input\n  Parameters - (batch size, num channels, image height, image width)\n\nOutput Details\n  Outputs a tensor of shape (N, 1)\n\n\nclass BasicCritic(nn.Module):\n  """ Conv 2D\n  Helper function which takes in a 2D convolutional layer of fixed size 3x3\n  It then standardizes the size of the kernel throughout the model\n\n  Parameters:\n    in_channels - number of input channels. Since this is using images, it is going to be 3 for R, G, and B values\n    out_channels - number of output channels desired\n\n  Details:\n    Operation - nn.Conv2d extracts spatial features of the input image\n    kernel_size - ensures a fixed size of\n  """\n  def _conv2d(self, in_channels, out_channels):\n    return nn.Conv2d(\n      in_channe

In [16]:
class Mish(nn.Module):
  def forward(self, x):
    return x * torch.tanh(F.softplus(x))

In [14]:
''' Mish-based Critics '''

# BASIC: all returned as one sequential layer
class BasicMishCritic(nn.Module):
  def _conv2d(self, in_channels, out_channels): # identical to SteganoGAN
      return nn.Conv2d(
          in_channels=in_channels,
          out_channels=out_channels,
          kernel_size=3
      )


  ''' Build models (modified):
  Instead of directly creating the
  '''
  def _build_models(self):
    self.c1 = nn.Sequential( # 1
      self._conv2d(3, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size)
    )

    self.c2 = nn.Sequential( # 2
      self._conv2d(self.hidden_size, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size)
    )

    self.c3 = nn.Sequential( # 3
      self._conv2d(self.hidden_size, self.hidden_size),
      Mish(),
      nn.BatchNorm2d(self.hidden_size)
    )

    self.c4 = nn.Sequential( # 4
      self._conv2d(self.hidden_size, 1)
    )

    return self.c1, self.c2, self.c3, self.c4 # return the values as a tuple

  def __init__(self, hidden_size):
    super().__init__()
    self.hidden_size = hidden_size
    self._models = self._build_models()

  def forward(self, img): # new feed forward loop manually moves from one layer to the next
    x = self._models[0](img)
    x1 = self._models[1](x)
    x2 = self._models[2](x1)
    x3 = self._models[3](x2)
    final_x = torch.mean(x3.view(x3.size(0), -1), dim=1)

    return final_x

# DENSE
class DenseMishCritic(nn.Module):
  def __init__(self, weights=densenet.DenseNet121_Weights.IMAGENET1K_V1):
    super(DenseMishCritic, self).__init__() # initialize using inheritance
    self._models = densenet.densenet121(weights=weights)
    self._models.train()

  def forward(self, x):
    features = self._models.features(x)
    out = Mish()(features)
    out = F.avg_pool2d(out, kernel_size=7).view(features.size(0),-1)
    out = torch.mean(out.view(out.size(0),-1),dim=1)
    return out

# RESIDUAL
class ResidualMishCritic(nn.Module):
  def __init__(self, num_classes=2):
    super(ResidualMishCritic, self).__init__() # initialize using inheritance
    self._models = resnet.ResNet(resnet.BasicBlock, [2,2,2,2], num_classes=num_classes)
    self.replace_relu_with_mish()
    self._models.train()

  def forward(self, x):
    x = self._models.conv1(x)
    x = self._models.bn1(x)
    x = Mish()(x) # replace ReLU with Mish
    x = self._models.maxpool(x)

    x = self._models.layer1(x)
    x = self._models.layer2(x)
    x = self._models.layer3(x)
    x = self._models.layer4(x)

    x = self._models.avgpool(x)
    x = torch.mean(x.view(x.size(0), -1), dim=1)
    return x

  # Helper method
  def replace_relu_with_mish(self):
    for name, module in self._models.named_children():
      if isinstance(module, nn.ReLU):
        setattr(self._models, name, Mish())
      elif isinstance(module, nn.Sequential):
        for child_name, child_module in module.named_children():
          if isinstance(child_module, nn.ReLU):
            setattr(module, child_name, Mish())

In [15]:
''' TEST '''
hidden_size = 64
dummy_input = torch.randn(1, 3, 224, 224)

# Instantiate critics
basic_critic = BasicMishCritic(hidden_size)
dense_critic = DenseMishCritic()
residual_critic = ResidualMishCritic(num_classes=2)

# Test BasicMishCritic
basic_output = basic_critic(dummy_input)
print("Basic Critic output shape:", basic_output.shape)
assert basic_output.shape == torch.Size([1]), "Basic Critic output shape is incorrect."

# Test DenseMishCritic
dense_output = dense_critic(dummy_input)
print("Dense Critic output shape:", dense_output.shape)
assert dense_output.shape == torch.Size([1]), "Dense Critic output shape is incorrect."

# Test ResidualMishCritic
residual_output = residual_critic(dummy_input)
print("Residual Critic output shape:", residual_output.shape)
assert residual_output.shape == torch.Size([1]), "Residual Critic output shape is incorrect."

# Verify that ReLU has been replaced with Mish in ResidualMishCritic
for name, module in residual_critic._models.named_children():
    if isinstance(module, Mish):
        print(f"Mish correctly placed in layer: {name}")
    elif isinstance(module, nn.Sequential):
        for child_name, child_module in module.named_children():
            if isinstance(child_module, Mish):
                print(f"Mish correctly placed in layer: {name}.{child_name}")

print("All tests passed.")

Basic Critic output shape: torch.Size([1])
Dense Critic output shape: torch.Size([1])
Residual Critic output shape: torch.Size([1])
Mish correctly placed in layer: relu
All tests passed.
