# SteganoGAN in Keras
This notebook attempts to reimplement SteganoGAN in Keras in order to better understand (and scrutinize) it.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

### Modules

In [1]:
import os
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
from callbacks import SaveImages

from resnet_steganogan_gp import ResnetSteganoGAN
from models import ResidualEncoder, BasicDecoder, Critic
from dataset_utils import create_message_dataset

2025-01-27 23:35:38.697045: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Constants

In [2]:
# --- Message -------------------------------------------------------------
# Number of binary message channels per pixel (the message tensor's last axis).
MESSAGE_DEPTH = 6

# --- Training ------------------------------------------------------------
BATCH_SIZE = 4

# --- Images (every image is resized to this fixed shape) -----------------
IMAGE_HEIGHT = IMAGE_WIDTH = 128
IMAGE_CHANNELS = 3
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)

# --- Artifacts -----------------------------------------------------------
MODEL_PATH = 'ResnetSteganoGAN.weights.h5'  # weights checkpoint (loaded on start if present)
LOGS_PATH = 'ResnetSteganoGAN.csv'          # CSVLogger output
# Directories handed to the SaveImages callback (presumably its input images
# and per-epoch output; see callbacks.SaveImages).
CALLBACK_IMAGES_PATH = 'images/callback'
CALLBACK_IMAGES_OUTPUT_PATH = 'epoch_images'

----

### Build the model for training

In [3]:
# Input shapes leave batch, height and width unspecified (None); only the
# channel counts are fixed.
image_input_shape = (None, None, None, IMAGE_CHANNELS)
message_input_shape = (None, None, None, MESSAGE_DEPTH)

encoder = ResidualEncoder(MESSAGE_DEPTH)
decoder = BasicDecoder(MESSAGE_DEPTH)
critic = Critic()

# The encoder consumes [image, message]; decoder and critic consume images only.
encoder.build(input_shape=[image_input_shape, message_input_shape])
for image_only_model in (decoder, critic):
  image_only_model.build(input_shape=image_input_shape)

resnetSteganoGAN = ResnetSteganoGAN(
  encoder=encoder,
  decoder=decoder,
  critic=critic,
  data_depth=MESSAGE_DEPTH
)
resnetSteganoGAN.build(input_shape=[image_input_shape, message_input_shape])

# Resume from a previous run's weights when the checkpoint file exists.
# compile() below always creates new optimizer instances.
if MODEL_PATH is not None and os.path.exists(MODEL_PATH):
  resnetSteganoGAN.load_weights(MODEL_PATH)
  print(f'Model loaded from {MODEL_PATH}')

resnetSteganoGAN.compile(
  encoder_optimizer=Adam(learning_rate=1e-4),
  decoder_optimizer=Adam(learning_rate=1e-4),
  # The critic uses lower Adam momentum terms than the encoder/decoder.
  critic_optimizer=Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9),
  similarity_loss_fn=MeanSquaredError(),
  # from_logits=False: the decoder's outputs are treated as probabilities in
  # [0, 1], not raw logits. NOTE(review): this assumes BasicDecoder ends in a
  # sigmoid; confirm in models.py.
  decoder_loss_fn=BinaryCrossentropy(from_logits=False)
)

# Uncomment to inspect the architectures:
# resnetSteganoGAN.summary()
# resnetSteganoGAN.encoder.summary()
# resnetSteganoGAN.decoder.summary()
# resnetSteganoGAN.critic.summary()

### Download the DIV2K dataset and pair it with a random binary message dataset (values in {0, 1})

In [5]:
import tensorflow_datasets as tfds

# Load DIV2K; each example is a dict whose 'hr' entry is used below as the image.
ds_div2k = tfds.load('div2k', shuffle_files=True)


def preprocess_hr(image):
    """Resize an image to (IMAGE_HEIGHT, IMAGE_WIDTH) and scale it to [-1, 1].

    Assumes 8-bit pixel values in [0, 255] (TODO confirm for the 'hr' feature).
    The resize does not preserve the original aspect ratio.
    """
    image = tf.image.resize(image, (IMAGE_HEIGHT, IMAGE_WIDTH))  # bilinear -> float32
    image = tf.cast(image, tf.float32)  # explicit, in case the resize method changes
    return (image / 127.5) - 1.0  # [0, 255] -> [-1, 1]


def make_image_message_dataset(split):
    """Build a batched, prefetched dataset of (images, messages) for a DIV2K split.

    Images and messages are zipped *before* batching, so every batch holds the
    same number of images and messages even when the split size is not a
    multiple of BATCH_SIZE. Batching each side separately would pair a partial
    final image batch with a full message batch.

    Assumes create_message_dataset(n, H, W, D) yields n unbatched messages,
    matching how the messages were previously batched after creation.

    Raises:
        ValueError: if the split's length is not known and finite.
    """
    images = ds_div2k[split].map(
        lambda example: preprocess_hr(example['hr']),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    num_images = int(images.cardinality())
    if num_images < 0:  # tf.data.UNKNOWN_CARDINALITY (-2) or INFINITE_CARDINALITY (-1)
        raise ValueError(f'Cannot size the message dataset: split {split!r} has no known finite length')
    messages = create_message_dataset(num_images, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)
    return (
        tf.data.Dataset.zip((images, messages))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)  # prefetch last, after all other transformations
    )


train_ds = make_image_message_dataset('train')
val_ds = make_image_message_dataset('validation')

# # Example usage
# for images, messages in train_ds.take(1):
#     print("Image batch shape:", images.shape)
#     print("Message batch shape:", messages.shape)

In [6]:
# Train the model and keep the History object (avoids echoing its repr).
history = resnetSteganoGAN.fit(train_ds, epochs=1, validation_data=val_ds, initial_epoch=0, callbacks=[
  # append=True keeps rows from earlier runs in the same CSV file.
  CSVLogger(LOGS_PATH, append=True),
  # save_best_only=True is required for `monitor`/`mode` to have any effect;
  # without it Keras overwrites MODEL_PATH at the end of every epoch.
  ModelCheckpoint(
    MODEL_PATH,
    monitor='encoder_decoder_total_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
  ),
  SaveImages(MESSAGE_DEPTH, IMAGE_SHAPE, CALLBACK_IMAGES_PATH, CALLBACK_IMAGES_OUTPUT_PATH)
])

[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m464s[0m 2s/step - critic_loss: 330.9911 - decoder_accuracy: 0.5001 - decoder_loss: 0.7385 - encoder_decoder_total_loss: 16821.7793 - psnr: -41.7834 - realism_loss: -0.0104 - similarity_loss: 16821.0488 - ssim: inf - val_critic_loss: -1.5182 - val_decoder_accuracy: 0.5002 - val_decoder_loss: 0.6942 - val_encoder_decoder_total_loss: 16539.3887 - val_psnr: -41.4581 - val_realism_loss: -0.0046 - val_similarity_loss: 16538.6992 - val_ssim: inf


<keras.src.callbacks.history.History at 0x14d3838f0>

In [7]:
# return_dict=True makes Keras return {metric_name: value}. The labels then come
# from the model itself instead of a hand-maintained list whose order had to
# match the model's internal metric ordering.
evaluated_metrics = resnetSteganoGAN.evaluate(val_ds, return_dict=True)

for key, value in evaluated_metrics.items():
  print(f'{key}: {value}')

[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 265ms/step - critic_loss: -1.4506 - decoder_accuracy: 0.5000 - decoder_loss: 0.6943 - encoder_decoder_total_loss: 15586.0791 - psnr: -41.1314 - realism_loss: -0.0046 - similarity_loss: 15585.3887 - ssim: inf
encoder_decoder_total_loss: 16539.38671875
critic_loss: -1.5171455144882202
similarity_loss: 16538.6953125
decoder_loss: 0.6942342519760132
decoder_accuracy: 0.5002461671829224
realism_loss: -0.004601673223078251
psnr: -41.45809555053711
ssim: inf
