# SteganoGAN in Keras
This notebook attempts to reimplement SteganoGAN in Keras in order to better understand (and scrutinize) it.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

### Modules

In [1]:
import os
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
from callbacks import SaveImages

from resnet_steganogan import SteganoGAN
from models import DenseEncoder, DenseDecoder, Discriminator

2025-02-11 20:08:54.804625: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Constants

In [None]:
# Channel depth of the secret-message tensor (also passed as `data_depth` to SteganoGAN).
MESSAGE_DEPTH = 4
BATCH_SIZE = 4

# Cover images are resized to this fixed channels-last shape.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS = 360, 360, 3
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)

# Files and folders, relative to the notebook's working directory.
MODEL_PATH = 'SteganoGAN.weights.h5'         # saved/restored model weights
LOGS_PATH = 'SteganoGAN.csv'                 # CSVLogger output
CALLBACK_IMAGES_PATH = 'images/callback'     # input images used by the SaveImages callback
CALLBACK_IMAGES_OUTPUT_PATH = 'epoch_images' # where SaveImages writes its results

----

### Build the model for future training

In [3]:
# Input specs in channels-last layout (like IMAGE_SHAPE). Batch and spatial dimensions
# are left unspecified so the networks can be built before any data is seen.
cover_spec = (None, None, None, IMAGE_CHANNELS)
message_spec = (None, None, None, MESSAGE_DEPTH)

encoder = DenseEncoder(MESSAGE_DEPTH)
decoder = DenseDecoder(MESSAGE_DEPTH)
discriminator = Discriminator()

# The encoder consumes a (cover, message) pair; decoder and discriminator are built on images only.
encoder.build(input_shape=[cover_spec, message_spec])
decoder.build(input_shape=cover_spec)
discriminator.build(input_shape=cover_spec)

stegano_gan = SteganoGAN(
  encoder=encoder,
  decoder=decoder,
  discriminator=discriminator,
  data_depth=MESSAGE_DEPTH,
  image_shape=IMAGE_SHAPE
)
stegano_gan.build(input_shape=[cover_spec, message_spec])

# Resume from previously saved weights when the weights file exists.
weights_available = MODEL_PATH is not None and os.path.exists(MODEL_PATH)
if weights_available:
  stegano_gan.load_weights(MODEL_PATH)
  print(f'Model loaded from {MODEL_PATH}')

stegano_gan.compile(
  encoder_decoder_optimizer=Adam(learning_rate=1e-4),
  discriminator_optimizer=Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9),
  similarity_loss_fn=MeanSquaredError(),
  # NOTE(review): the decoder loss treats its predictions as logits, while the
  # discriminator loss expects probabilities; confirm Discriminator's output activation.
  decoder_loss_fn=BinaryCrossentropy(from_logits=True),
  discriminator_loss_fn=BinaryCrossentropy(from_logits=False),
)

# Print the architecture of the full model and of each component.
for component in (stegano_gan, stegano_gan.encoder, stegano_gan.decoder, stegano_gan.discriminator):
  component.summary()

### Download the DIV2K dataset and complement it with a random {0, 1} message dataset

In [None]:
import tensorflow_datasets as tfds

# Number of preprocessed training images held in the shuffle buffer. Each one is a
# float32 IMAGE_SHAPE tensor (~1.5 MB at 360x360x3), so this bounds the extra memory.
SHUFFLE_BUFFER = 100

# Load DIV2K. shuffle_files=True only randomizes the order in which input files are
# read; examples within a file would still arrive in the same order every epoch.
ds_div2k = tfds.load('div2k', shuffle_files=True)

# Extract and preprocess high-resolution images
def preprocess_hr(image):
    """Resize a DIV2K high-resolution image to (IMAGE_HEIGHT, IMAGE_WIDTH) and scale it to [-1, 1].

    Args:
      image: single RGB image tensor, assumed to hold pixel values in [0, 255]
        (the normalization below relies on that range).

    Returns:
      float32 tensor of shape (IMAGE_HEIGHT, IMAGE_WIDTH, channels). The aspect ratio
      of the source image is not preserved.
    """
    image = tf.image.resize(image, (IMAGE_HEIGHT, IMAGE_WIDTH))  # bilinear by default
    image = tf.cast(image, tf.float32)      # resize already returns float32; kept as a dtype guard
    image = (image / 127.5) - 1.0           # Normalize [0, 255] -> [-1, 1]
    return image

train_image_ds = ds_div2k['train'].map(lambda x: preprocess_hr(x['hr']), num_parallel_calls=tf.data.AUTOTUNE)
val_image_ds = ds_div2k['validation'].map(lambda x: preprocess_hr(x['hr']), num_parallel_calls=tf.data.AUTOTUNE)

# Batch datasets. Training images are shuffled (reshuffled on every epoch by default)
# before batching so batch composition varies between epochs; validation order stays fixed.
train_ds = train_image_ds.shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = val_image_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
# stegano_gan.fit(train_ds, epochs=10, validation_data=val_ds, initial_epoch=0, callbacks=[
#   CSVLogger(LOGS_PATH, append=True),
#   ModelCheckpoint(MODEL_PATH, monitor='encoder_decoder_total_loss', mode='min', save_weights_only=True),
#   SaveImages(MESSAGE_DEPTH, IMAGE_SHAPE, CALLBACK_IMAGES_PATH, CALLBACK_IMAGES_OUTPUT_PATH)
# ])

In [None]:
# Evaluate on the validation set. return_dict=True pairs every value with the metric
# name Keras actually reports. A hand-written name list zipped against a positional
# result silently mislabels or drops values if the model's metric order ever differs.
evaluated_metrics = stegano_gan.evaluate(val_ds, return_dict=True)

# Names in the order Keras returned them (kept for any later use of `metrics_names`).
metrics_names = list(evaluated_metrics)

for key, value in evaluated_metrics.items():
  print(f'{key}: {value}')

In [None]:
#%rm -rf epoch_images/
#%rm -rf SteganoGAN.weights.h5
#%rm -rf SteganoGAN.csv

In [12]:
# Reference point: PSNR between a callback image and a visibly degraded copy of it.
# `tf` and `os` come from the imports cell at the top of the notebook.

# Load image1 as float32 in [0, 1]
image_path = os.path.join(CALLBACK_IMAGES_PATH, 'image1.png')
image1 = tf.io.read_file(image_path)
# The file is a PNG: decode_png states the format explicitly (decode_jpeg also accepts
# PNGs, but reads as if the file were a JPEG).
image1 = tf.image.decode_png(image1, channels=3)
image1 = tf.image.convert_image_dtype(image1, dtype=tf.float32)

# Degrade the image by downsampling to half resolution and upsampling back (bilinear,
# the tf.image.resize default). This is a resolution-loss blur, not a Gaussian blur.
blurred_image = tf.image.resize(image1, (image1.shape[0] // 2, image1.shape[1] // 2))
blurred_image = tf.image.resize(blurred_image, image1.shape[:2])

# Compute PSNR (images are in [0, 1], hence max_val=1.0)
psnr_original_vs_blurred = tf.image.psnr(image1, blurred_image, max_val=1.0)

print(f"PSNR between original and blurred image: {psnr_original_vs_blurred.numpy()} dB")

PSNR between original and blurred image: 22.734779357910156 dB


In [14]:
# Reference point: PSNR between an image and a copy with a single pixel nudged.
# `tf` and `os` come from the imports cell at the top of the notebook.

# Load the original image as float32 in [0, 1]
image_path = os.path.join(CALLBACK_IMAGES_PATH, 'image1.png')
original_image = tf.io.read_file(image_path)
original_image = tf.image.decode_png(original_image, channels=3)
original_image = tf.image.convert_image_dtype(original_image, dtype=tf.float32)

# Brighten one pixel (all channels at row 0, column 0). tf.Tensor.numpy() returns a
# copy, so original_image itself is left untouched.
modified_image = original_image.numpy()
# Clip so the perturbed value stays inside the [0, max_val] range that PSNR assumes.
modified_image[0, 0] = (modified_image[0, 0] + 0.1).clip(0.0, 1.0)
modified_image = tf.convert_to_tensor(modified_image, dtype=tf.float32)  # Convert back to tensor

# Compute PSNR
psnr_value = tf.image.psnr(original_image, modified_image, max_val=1.0)

print(f"PSNR between original and modified image: {psnr_value.numpy()} dB")


PSNR between original and modified image: 74.87420654296875 dB
