# SteganoGAN in Keras
This notebook reimplements SteganoGAN in Keras in order to understand (and scrutinize) the original implementation more thoroughly.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

### Modules

In [1]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import sys
sys.path.append("../..")
sys.path.append("../../..")

import tensorflow as tf
from keras.optimizers import Adam
from keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import CSVLogger
from callbacks import Checkpoint, SaveImages

from models import (
  steganogan_encoder_basic_model,
  steganogan_decoder_basic_model
)

from dataset_utils import normalize_img, create_message_dataset
from keras_steganogan import KerasSteganoGAN

2024-09-22 17:35:15.903011: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Constants

In [2]:
# Cover-image geometry: height x width x colour channels (RGB).
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS = 128, 128, 3
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)

# Training batch size and depth of the binary message tensor
# (messages are shaped height x width x MESSAGE_DEPTH).
BATCH_SIZE = 4
MESSAGE_DEPTH = 6

# Artifact locations, relative to this notebook. The model, log and epoch-image
# paths are keyed by MESSAGE_DEPTH so runs with different depths don't collide.
MODEL_PATH = f'../../pretrained_models/{MESSAGE_DEPTH}/steganoGAN_basic.keras'
LOGS_PATH = f'../../logs/basic/{MESSAGE_DEPTH}/steganoGAN_basic.csv'
CALLBACK_IMAGES_OUTPUT_PATH = f'../../epoch_images/basic/{MESSAGE_DEPTH}'
CALLBACK_IMAGES_PATH = '../../../images/callback'

----

### Build and compile the model for training

In [3]:
# Basic SteganoGAN encoder/decoder pair, both sized for IMAGE_SHAPE cover
# images and MESSAGE_DEPTH-deep messages.
encoder = steganogan_encoder_basic_model(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, MESSAGE_DEPTH)
decoder = steganogan_decoder_basic_model(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, MESSAGE_DEPTH)

steganoGAN = KerasSteganoGAN(
  encoder=encoder,
  decoder=decoder,
  image_shape=IMAGE_SHAPE,
  data_depth=MESSAGE_DEPTH,
  model_path=MODEL_PATH
)

# Encoder and decoder each get their own Adam instance (optimizers hold
# per-variable state), but share the same hyperparameters.
LEARNING_RATE = 1e-4
ADAM_BETA_1 = 0.5

steganoGAN.compile(
  encoder_optimizer=Adam(learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1),
  decoder_optimizer=Adam(learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1),
  # from_logits=False: assumes the decoder already outputs probabilities
  # (e.g. a sigmoid head) — TODO confirm against models.py.
  loss_fn=BinaryCrossentropy(from_logits=False)
)

# To inspect the architectures, call steganoGAN.models_summary() or
# steganoGAN.summary() in a scratch cell.

### Load the DIV2K dataset and pair it with a random binary ({0, 1}) message dataset

In [4]:
# Root of the DIV2K split directories. Set the DIV2K_DIR environment variable
# to point elsewhere instead of editing this machine-specific default.
DIV2K_DIR = os.environ.get(
    'DIV2K_DIR',
    '/Users/dmitryhoma/Projects/phd_dissertation/state_2/SteganoGAN/research/data/div2k',
)
train_dir = os.path.join(DIV2K_DIR, 'train')
val_dir = os.path.join(DIV2K_DIR, 'val')

# Seed for the file shuffling done by image_dataset_from_directory.
DATASET_SEED = 123


def load_image_message_split(split_dir):
    """Load one DIV2K split as normalized image batches plus a message dataset.

    Args:
        split_dir: directory containing the (unlabelled) images of this split.

    Returns:
        (image_ds, message_ds):
        - image_ds: RGB batches of BATCH_SIZE images resized to
          IMAGE_HEIGHT x IMAGE_WIDTH, shuffled with DATASET_SEED and mapped
          through normalize_img.
        - message_ds: built by create_message_dataset, sized from len(image_ds)
          (the number of image batches) so the two can be zipped.
    """
    image_ds = tf.keras.preprocessing.image_dataset_from_directory(
        split_dir,
        label_mode=None,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        seed=DATASET_SEED,
        image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
        shuffle=True,
    )
    image_ds = image_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    message_ds = create_message_dataset(BATCH_SIZE, len(image_ds), IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)
    return image_ds, message_ds


train_image_ds, train_message_ds = load_image_message_split(train_dir)
val_image_ds, val_message_ds = load_image_message_split(val_dir)

# Each element is (image_batch, message_batch); zip stops at the shorter dataset.
train_ds = tf.data.Dataset.zip((train_image_ds, train_message_ds)).prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.zip((val_image_ds, val_message_ds)).prefetch(buffer_size=tf.data.AUTOTUNE)

Found 800 files.
Found 100 files.


In [5]:
EPOCHS = 5

# Create the output directories up front so CSVLogger (and the custom
# callbacks) do not fail on a fresh checkout. exist_ok makes this idempotent.
for output_dir in (
    os.path.dirname(MODEL_PATH),
    os.path.dirname(LOGS_PATH),
    CALLBACK_IMAGES_OUTPUT_PATH,
):
    if output_dir:  # os.makedirs('') raises, so skip bare filenames
        os.makedirs(output_dir, exist_ok=True)

# Build with batch-size-1 input shapes: [cover image, message].
steganoGAN.build([(1, *IMAGE_SHAPE), (1, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)])

# Keep the History object for later inspection instead of echoing its repr;
# per-epoch metrics are also written to LOGS_PATH by CSVLogger.
history = steganoGAN.fit(
  train_ds,
  epochs=EPOCHS,
  validation_data=val_ds,
  callbacks=[
    SaveImages(MESSAGE_DEPTH, IMAGE_SHAPE, CALLBACK_IMAGES_PATH, CALLBACK_IMAGES_OUTPUT_PATH),
    Checkpoint(MODEL_PATH),
    CSVLogger(LOGS_PATH),
  ],
)

Epoch 1/5
200/200 ━━━━━━━━━━━━━━━━━━━━ 127s 600ms/step - decoder_accuracy: 0.5269 - decoder_loss: 0.7544 - encoder_decoder_total_loss: 1.1221 - psnr: 5.3423 - rs_bpp: 0.3226 - similarity_loss: 0.3677 - ssim: 0.1909 - val_decoder_accuracy: 0.5830 - val_decoder_loss: 0.6799 - val_encoder_decoder_total_loss: 0.8050 - val_psnr: 9.1432 - val_rs_bpp: 0.9956 - val_similarity_loss: 0.1251 - val_ssim: 0.2534
Epoch 2/5
200/200 ━━━━━━━━━━━━━━━━━━━━ 125s 620ms/step - decoder_accuracy: 0.5887 - decoder_loss: 0.6722 - encoder_decoder_total_loss: 0.7560 - psnr: 11.0597 - rs_bpp: 1.0641 - similarity_loss: 0.0838 - ssim: 0.4382 - val_decoder_accuracy: 0.6520 - val_decoder_loss: 0.6402 - val_encoder_decoder_total_loss: 0.6960 - val_psnr: 12.6377 - val_rs_bpp: 1.8236 - val_similarity_loss: 0.0559 - val_ssim: 0.4984
Epoch 3/5
200/200 ━━━━━━━━━━━━━━━━━━━━ 119s 590ms/step - decoder_accuracy: 0.6611 - decode

<keras.src.callbacks.history.History at 0x156074e60>