# SteganoGAN in Keras
This notebook reimplements SteganoGAN in Keras in order to better understand (and scrutinize) it.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

### Modules

In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
from callbacks import SaveImages

from models import (
  steganogan_encoder_residual_model,
  steganogan_decoder_basic_model,
  steganogan_critic_model
)

from dataset_utils import normalize_img, create_message_dataset
from keras_steganogan import KerasSteganoGAN

2024-12-02 23:40:09.949467: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Constants

In [2]:
# Message / training configuration
MESSAGE_DEPTH = 6   # channels of the message tensor per pixel (passed as the model's data_depth)
BATCH_SIZE = 4      # images (and messages) per batch

# Image dimensions: images are resized to IMAGE_HEIGHT x IMAGE_WIDTH when loaded
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
IMAGE_CHANNELS = 3  # RGB
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)

# Artifacts
MODEL_PATH = 'steganoGAN_residual.keras'   # checkpoint written during fit and reloaded at build time
LOGS_PATH = 'steganoGAN_residual.csv'      # per-epoch metrics appended by CSVLogger
# Presumably: input images used by the SaveImages callback, and where it writes its outputs — confirm in callbacks.py
CALLBACK_IMAGES_PATH = 'images/callback'
CALLBACK_IMAGES_OUTPUT_PATH = 'epoch_images'

----

### Build the model for training

In [None]:
# Sub-networks of the GAN; encoder and decoder are sized by the message depth.
encoder = steganogan_encoder_residual_model(MESSAGE_DEPTH)
decoder = steganogan_decoder_basic_model(MESSAGE_DEPTH)
critic = steganogan_critic_model()

steganoGAN = KerasSteganoGAN(
  encoder=encoder,
  decoder=decoder,
  critic=critic,
  data_depth=MESSAGE_DEPTH
)

# Build on a single (image, message) input pair so the weights exist before any load.
steganoGAN.build([(1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), (1, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)])

# Resume from a previously saved checkpoint, if one exists.
if MODEL_PATH is not None and os.path.exists(MODEL_PATH):
  steganoGAN.load_weights(MODEL_PATH)

# Optimizer hyperparameters shared by encoder, decoder and critic.
LEARNING_RATE = 1e-4
ADAM_BETA_1 = 0.5


def make_optimizer():
  """Return a fresh Adam optimizer configured with the shared hyperparameters.

  A new instance is created per sub-network on purpose: an optimizer keeps
  per-variable state, so one instance must not be shared across models.
  """
  return Adam(learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1)


steganoGAN.compile(
  encoder_optimizer  = make_optimizer(),
  decoder_optimizer  = make_optimizer(),
  critic_optimizer   = make_optimizer(),
  similarity_loss_fn = MeanSquaredError(),
  # from_logits=False: presumably the decoder already ends in a sigmoid — confirm in models.py
  decoder_loss_fn    = BinaryCrossentropy(from_logits=False)
)

### Load the DIV2K dataset and pair it with a random binary ({0, 1}) message dataset

In [None]:
# Root of the DIV2K split directories. Override with the DIV2K_DIR environment
# variable instead of editing a machine-specific absolute path.
DIV2K_DIR = os.environ.get(
  'DIV2K_DIR',
  '/Users/dmitryhoma/Projects/phd_dissertation/state_2/SteganoGAN/research/data/div2k'
)
train_dir = os.path.join(DIV2K_DIR, 'train')
val_dir = os.path.join(DIV2K_DIR, 'val')


def load_image_dataset(directory, shuffle=True):
  """Load unlabeled RGB images from `directory` as a batched, normalized dataset.

  Returns:
    (image_ds, num_images): the mapped dataset and the number of image files found.
    The count comes from the file list, not `len(image_ds) * BATCH_SIZE`, because the
    latter overcounts whenever the file count is not a multiple of BATCH_SIZE.
  """
  image_ds = tf.keras.utils.image_dataset_from_directory(
    directory,
    label_mode=None,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    seed=123,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    shuffle=shuffle
  )
  num_images = len(image_ds.file_paths)
  return image_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE), num_images


def make_message_batches(num_images):
  """Return random messages batched exactly like the image dataset.

  Assumes create_message_dataset's first argument is the number of messages to
  generate — TODO confirm in dataset_utils. Matching the image count keeps the final
  (possibly partial) message batch the same size as the final image batch.
  """
  return create_message_dataset(
    num_images, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH
  ).batch(BATCH_SIZE, num_parallel_calls=tf.data.AUTOTUNE)


train_image_ds, num_train_images = load_image_dataset(train_dir)
val_image_ds, num_val_images = load_image_dataset(val_dir)

train_message_ds = make_message_batches(num_train_images)
val_message_ds = make_message_batches(num_val_images)

# Each element is an (image_batch, message_batch) pair.
train_ds = tf.data.Dataset.zip((train_image_ds, train_message_ds)).prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.zip((val_image_ds, val_message_ds)).prefetch(buffer_size=tf.data.AUTOTUNE)

Found 800 files.
Found 100 files.


In [5]:
# Number of training epochs for this run. EarlyStopping(patience=10) can only
# trigger once EPOCHS exceeds its patience.
EPOCHS = 1

# Checkpoint / stop on the *validation* loss: validation data is supplied, and
# selecting the "best" model by training loss would track fit, not generalization.
MONITOR_METRIC = 'val_encoder_decoder_total_loss'

history = steganoGAN.fit(train_ds, epochs=EPOCHS, validation_data=val_ds, callbacks=[
  SaveImages(MESSAGE_DEPTH, IMAGE_SHAPE, CALLBACK_IMAGES_PATH, CALLBACK_IMAGES_OUTPUT_PATH),
  ModelCheckpoint(MODEL_PATH, save_best_only=True, monitor=MONITOR_METRIC, mode='min'),
  EarlyStopping(monitor=MONITOR_METRIC, mode='min', patience=10, min_delta=0.001),
  CSVLogger(LOGS_PATH, append=True)
])

[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m152s[0m 716ms/step - critic_loss: -0.0078 - decoder_accuracy: 0.7602 - decoder_loss: 0.4796 - encoder_decoder_total_loss: 0.5312 - psnr: 14.8808 - realism_loss: 0.0179 - rs_bpp: 3.1229 - similarity_loss: 0.0336 - ssim: 0.5345 - val_critic_loss: -0.0113 - val_decoder_accuracy: 0.7686 - val_decoder_loss: 0.4644 - val_encoder_decoder_total_loss: 0.5325 - val_psnr: 14.3393 - val_realism_loss: 0.0294 - val_rs_bpp: 3.2238 - val_similarity_loss: 0.0386 - val_ssim: 0.5218


<keras.src.callbacks.history.History at 0x1117f6990>

In [6]:
# Ask for a name -> value dict so the report does not depend on the positional
# order in which the model happens to return its metrics.
metrics_by_name = steganoGAN.evaluate(val_ds, return_dict=True)

# Display order for the report.
metrics_names = [
  'encoder_decoder_total_loss',
  'critic_loss',
  'similarity_loss',
  'decoder_loss',
  'decoder_accuracy',
  'realism_loss',
  'psnr',
  'ssim',
  'rs_bpp'
]

# Kept as a list aligned with `metrics_names` for compatibility with the previous
# version of this cell. A missing metric raises KeyError instead of silently mislabeling.
evaluated_metrics = [metrics_by_name[name] for name in metrics_names]

for key, value in zip(metrics_names, evaluated_metrics):
  print(f'{key}: {value}')

[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 185ms/step - critic_loss: -0.0107 - decoder_accuracy: 0.7708 - decoder_loss: 0.4609 - encoder_decoder_total_loss: 0.5317 - psnr: 14.4274 - realism_loss: 0.0330 - rs_bpp: 3.2494 - similarity_loss: 0.0378 - ssim: 0.5306
encoder_decoder_total_loss: 0.5325661301612854
critic_loss: -0.01129884459078312
similarity_loss: 0.03869222104549408
decoder_loss: 0.46444883942604065
decoder_accuracy: 0.7687587141990662
realism_loss: 0.029425162822008133
psnr: 14.334359169006348
ssim: 0.5213987827301025
rs_bpp: 3.2251038551330566
