# SteganoGAN in Keras
This notebook contains code attempting to reimplement SteganoGAN in Keras, for the purpose of better understanding (and scrutinizing) it.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

### Modules

In [1]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import sys
sys.path.append("..")

import tensorflow as tf
import numpy as np
from keras.optimizers import Adam
from keras.losses import BinaryCrossentropy
from keras.utils import plot_model

from models import (
  steganogan_encoder_basic_model,
  steganogan_encoder_residual_model,
  steganogan_encoder_dense_model,

  steganogan_decoder_basic_model,
  steganogan_decoder_dense_model,
  
  steganogan_critic_model
)
from keras_steganogan_basic import KerasSteganoGAN

2024-07-15 22:29:20.815685: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Constants

In [2]:
# Geometry of the images fed to the networks: (height, width, channels).
IMAGE_HEIGHT, IMAGE_WIDTH = 128, 128
IMAGE_CHANNELS = 3
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)

# Size of the last axis of the binary message tensor paired with each image.
MESSAGE_DEPTH = 2

# Number of images (and messages) per batch.
BATCH_SIZE = 4

# File the trained model is saved to.
MODEL_PATH = '../pretrained_models/steganoGAN_residual_enc_basic_dec.keras'

----

### Build the model for training

In [3]:
# Residual encoder and basic decoder are built from the same geometry arguments.
model_dims = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, MESSAGE_DEPTH)
encoder = steganogan_encoder_residual_model(*model_dims)
decoder = steganogan_decoder_basic_model(*model_dims)

steganoGAN = KerasSteganoGAN(
  encoder=encoder,
  decoder=decoder,
  image_shape=IMAGE_SHAPE,
  data_depth=MESSAGE_DEPTH,
  model_path=MODEL_PATH,
)

# Encoder and decoder each get their own Adam instance, configured identically.
adam_settings = {"learning_rate": 1e-4, "beta_1": 0.5}
steganoGAN.compile(
  encoder_optimizer=Adam(**adam_settings),
  decoder_optimizer=Adam(**adam_settings),
  loss_fn=BinaryCrossentropy(from_logits=False),
)

# Uncomment to inspect the architectures:
#steganoGAN.models_summary()
#steganoGAN.summary()

### Load the DIV2K dataset from disk and pair it with a dataset of random binary ({0, 1}) messages

In [4]:
# Root directory holding the DIV2K `train/` and `val/` image folders.
# Set the DIV2K_DIR environment variable to run this notebook on another machine;
# when it is unset, the original author's local checkout is used.
DIV2K_DIR = os.environ.get(
    'DIV2K_DIR',
    '/Users/dmitryhoma/Projects/phd_dissertation/state_2/SteganoGAN/research/data/div2k'
)
train_dir = os.path.join(DIV2K_DIR, 'train')
val_dir = os.path.join(DIV2K_DIR, 'val')


def load_div2k_images(directory):
    """Load the images under `directory` as an unlabeled, batched RGB dataset.

    Images are resized to (IMAGE_HEIGHT, IMAGE_WIDTH), grouped into batches of
    BATCH_SIZE and shuffled with a fixed seed of 123.

    Args:
        directory: Path of the folder that contains the image files.

    Returns:
        The dataset produced by `image_dataset_from_directory`.

    Raises:
        FileNotFoundError: If `directory` does not exist, so that a wrong data
            location is reported with a hint instead of a loader-internal error.
    """
    if not os.path.isdir(directory):
        raise FileNotFoundError(
            f"Image directory not found: {directory!r}. "
            "Set the DIV2K_DIR environment variable to the DIV2K data root."
        )
    return tf.keras.preprocessing.image_dataset_from_directory(
        directory,
        label_mode=None,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        seed=123,
        image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
        shuffle=True
    )


train_image_ds = load_div2k_images(train_dir)
val_image_ds = load_div2k_images(val_dir)

def normalize_img(img):
    """Linearly rescale pixel values so that 0 maps to -1 and 255 maps to 1.

    Args:
        img: A number or array/tensor of pixel values that supports `/` and `-`.

    Returns:
        `img` divided by 127.5, minus 1.
    """
    half_range = 127.5
    scaled = img / half_range
    return scaled - 1

# Apply the same pixel normalization to both splits, letting tf.data choose the parallelism.
train_image_ds, val_image_ds = (
    image_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    for image_ds in (train_image_ds, val_image_ds)
)

def create_message_tensor_for_training(batch_size, width, height, data_depth):
    """Draw one batch of random binary messages.

    Note the axis order: the result is laid out as
    [batch_size, width, height, data_depth], i.e. `width` is the first spatial
    axis and `height` the second.

    Args:
        batch_size: Number of messages in the batch.
        width: Size of the first spatial axis.
        height: Size of the second spatial axis.
        data_depth: Size of the last axis.

    Returns:
        A float32 tensor of that shape, sampled as int32 from [0, 2) and then cast.
    """
    shape = [batch_size, width, height, data_depth]
    bits = tf.random.uniform(shape, 0, 2, dtype=tf.int32)
    return tf.cast(bits, tf.float32)

def create_message_dataset(batch_size, num_batches, width, height, data_depth):
    """Build a batched dataset of `num_batches` random message batches.

    Every batch is generated up front with `create_message_tensor_for_training`,
    concatenated along axis 0, sliced into individual messages and re-batched.

    Args:
        batch_size: Number of messages per batch.
        num_batches: How many batches to generate.
        width, height, data_depth: Forwarded unchanged to
            `create_message_tensor_for_training`.

    Returns:
        A `tf.data.Dataset` batched with `batch_size`.
    """
    per_batch_messages = []
    for _ in range(num_batches):
        per_batch_messages.append(
            create_message_tensor_for_training(batch_size, width, height, data_depth)
        )
    all_messages = tf.concat(per_batch_messages, axis=0)
    message_ds = tf.data.Dataset.from_tensor_slices(all_messages)
    return message_ds.batch(batch_size)

# One message batch is generated per image batch in each split.
# IMAGE_HEIGHT is passed in the `width` slot and IMAGE_WIDTH in the `height` slot, so the
# message tensor's spatial axes come out in (height, width) order.
# NOTE(review): message batches always hold BATCH_SIZE items; this presumably lines up with
# the image batches only when the image count is a multiple of BATCH_SIZE - confirm.
message_dims = (IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)
train_message_ds = create_message_dataset(BATCH_SIZE, len(train_image_ds), *message_dims)
val_message_ds = create_message_dataset(BATCH_SIZE, len(val_image_ds), *message_dims)

# Pair each image batch with a message batch, then prefetch.
train_pairs = tf.data.Dataset.zip((train_image_ds, train_message_ds))
train_ds = train_pairs.prefetch(buffer_size=tf.data.AUTOTUNE)

val_pairs = tf.data.Dataset.zip((val_image_ds, val_message_ds))
val_ds = val_pairs.prefetch(buffer_size=tf.data.AUTOTUNE)

Found 800 files.
Found 100 files.


In [5]:
# Build with one image-shaped and one message-shaped input, each with a leading dimension of 1.
image_input_shape = (1, *IMAGE_SHAPE)
message_input_shape = (1, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)
steganoGAN.build([image_input_shape, message_input_shape])

# Train for 10 epochs with validation, then write the model to MODEL_PATH.
steganoGAN.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds,
)
steganoGAN.save(MODEL_PATH)

Epoch 1/10
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m120s[0m 568ms/step - bpp: 2.0000 - decoding_loss: 0.7477 - encoder_decoder_total_loss: 0.9698 - psnr: 7.3749 - similarity_loss: 0.2220 - ssim: 0.3177 - val_bpp: 2.0000 - val_decoding_loss: 0.6375 - val_encoder_decoder_total_loss: 0.6806 - val_psnr: 13.7184 - val_similarity_loss: 0.0432 - val_ssim: 0.5247
Epoch 2/10
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m121s[0m 599ms/step - bpp: 2.0000 - decoding_loss: 0.4944 - encoder_decoder_total_loss: 0.5559 - psnr: 12.2361 - similarity_loss: 0.0614 - ssim: 0.4537 - val_bpp: 2.0000 - val_decoding_loss: 0.3106 - val_encoder_decoder_total_loss: 0.3788 - val_psnr: 11.7189 - val_similarity_loss: 0.0682 - val_ssim: 0.4015
Epoch 3/10
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m119s[0m 590ms/step - bpp: 2.0000 - decoding_loss: 0.1910 - encoder_decoder_total_loss: 0.2492 - psnr: 12.4366 - similarity_loss: 0.0583 - ssim: 0.4285 - val_bpp: 2.0000 - val