# SteganoGAN in Keras
This notebook contains code attempting to reimplement SteganoGAN in Keras, for the purpose of better understanding (and scrutinizing) it.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

### Modules

In [3]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import numpy as np
from keras.optimizers import Adam
from keras.losses import BinaryCrossentropy
from keras.utils import plot_model

from models import steganogan_encoder_dense_model, steganogan_decoder_dense_model, steganogan_critic_model
from keras_steganogan import KerasSteganoGAN

### Constants

In [4]:
# Image dimensions
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
IMAGE_CHANNELS = 3

IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)
MESSAGE_DEPTH = 2
BATCH_SIZE = 4
MODEL_PATH = 'pretrained_models/steganoGAN.keras'

### Call main encode and decode functions (with creating steganographic image and decoding it)

In [3]:
steganoGAN = KerasSteganoGAN(image_shape=IMAGE_SHAPE, data_depth=MESSAGE_DEPTH, model_path=MODEL_PATH)
steganoGAN.encode('input.png', 'images/output.png', 'Hello, World!')
steganoGAN.decode('images/output.png')

'ìþ_ÿÿÍ\r\x7f\\ÖýÝ1ñÿýÿÿÿ4ïê«\x02\x1fëëÿÿC\nø&4ÄL\x14\x03´\x05$¹\nf±\x89@\nj\x1fòºÿ\x7fõ8\x91UWAÑ\x89\x95Ò\x8aÐ®\x8e«0\x1b\x00B12\x90\x91)µ\x13{Í_\x0cv°!¿¢ëªº\x91\x8d¢\x91æ=´Wº\x07Ä\x93ÇØq\x830³zr¶ì\x0f4ûõu%wl6\x8aÝÆ6Ó\x8d\x99«lwkìì\x0e\x13²[³;\x9aóàÎF\x8cE{ù\x9c&1%\x8e\x9f\x8f\x81Ù¸ó\x8e\x88\x9a\x8c²\x90\xadÈ$>ú0\x1báªîÇ\x88é\x00)<\x02òöÇÊÆÈÑyÃ1Ç\x9e\x16Ç2¦\x8dc\x9bS8\xadñãÌÌ¿m4s\x17s¨"\x93Ø\x8eõ\x99\x14\x0ewÃ\x0bh<\x81y\x963\x9f>x¼#ÁÆä\x13?Êñ6v´ÖÃ\x98í`Ù\x8bh®\x8b;\x01!Ï\x10i\x8d/\x0e\x1aôqòdLÙ3¼ñ<=\x7f\x82\x1dÚôºÜë\x93mÈ.\x07ü\x93Ùò\x13µ\x8f\x9bX1²öæA\x11\x07ñ<è\x02\x02æ\x99FØÐ!½y\x85w\x7f\x1cüT$Õ\x99\xad\x11{ñàNÌë©Ç-\x9eõWwÿÕû|\x06s.¤øã\x8e\x04 h\xadwõ³@#qöä\x0ei\x9eëYÞ\x82ª\xa0\x00\x93¾ð°n\x95\x97-\x86uç\x9e;N´H2:ô\x07ÕÞ\x8cåCñ£6ÿ÷õu\x1e\x12q\'IÅKÇµl1\x16üÓ«ª\x8b@^I\x04Nýï\x9d\x87c\x06\x00\x10P\x80éÛH\\¤ÆCØr\x15P\r\x0b5ÐõSú®ýîÆÖe\x97ï¡\x14UU\x04PÒéKÈô\x9fÌ²UÚû³HÁ\x7fÿ÷!ß÷.¦âª©1¿´àª«£ïS\'\x1c~rÌ¶\x80"R¢sx\x02\x04\x12C\x11\x18"\x81\x13ô\x9dO\x07´\x03ôZ\x12¹J´ 5ÓÍ È¶\x83r61Cm×;°\x90Õ

### SteganoGAN predict random data with metrics

In [8]:
cover_image = tf.random.uniform([1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS], -1, 1, dtype=tf.float32)
message = tf.cast(tf.random.uniform([1, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH], 0, 2, dtype=tf.int32), tf.float32)

stego_img, recovered_msg = steganoGAN.predict([cover_image, message])

print("stego_img min: {0}, max: {1}".format(tf.reduce_min(stego_img), tf.reduce_max(stego_img)))
print("recovered_msg min: {0}, max: {1}".format(tf.reduce_min(recovered_msg), tf.reduce_max(recovered_msg)))

print("BinaryCrossentropy: {0}".format(BinaryCrossentropy(from_logits=True)(message, recovered_msg)))
print("PSNR: {0}".format(tf.reduce_mean(tf.image.psnr(cover_image, stego_img, 1))))
print("SSIM: {0}".format(tf.reduce_mean(tf.image.ssim(cover_image, stego_img, 1))))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step
stego_img min: -1.0, max: 1.0
recovered_msg min: 0.0, max: 1.0
BinaryCrossentropy: 0.7435410022735596
PSNR: 14.657017707824707
SSIM: 0.7204346656799316


----

### Build model for future train

In [3]:
encoder = steganogan_encoder_dense_model(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, MESSAGE_DEPTH)
decoder = steganogan_decoder_dense_model(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, MESSAGE_DEPTH)
critic  = steganogan_critic_model(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)

steganoGAN = KerasSteganoGAN(
  encoder=encoder,
  decoder=decoder,
  critic=critic,
  image_shape=IMAGE_SHAPE,
  data_depth=MESSAGE_DEPTH,
  model_path=MODEL_PATH
)

steganoGAN.compile(
  encoder_optimizer = Adam(learning_rate=1e-4, beta_1=0.5),
  decoder_optimizer = Adam(learning_rate=1e-4, beta_1=0.5),
  critic_optimizer = Adam(learning_rate=1e-4, beta_1=0.5),
  loss_fn = BinaryCrossentropy(from_logits=True)
)

#steganoGAN.models_summary()
#steganoGAN.summary()
#plot_model(steganoGAN.encoder, to_file='model_images/encoder.png', show_shapes=True, show_layer_names=True)
#plot_model(steganoGAN.decoder, to_file='model_images/decoder.png', show_shapes=True, show_layer_names=True)
#plot_model(steganoGAN.critic, to_file='model_images/critic.png', show_shapes=True, show_layer_names=True)

### Download div2k dataset and complete it with random message dataset of {0, 1}

In [4]:
train_dir = '/Users/dmitryhoma/Projects/phd_dissertation/state_2/SteganoGAN/research/data/div2k/train'
val_dir = '/Users/dmitryhoma/Projects/phd_dissertation/state_2/SteganoGAN/research/data/div2k/val'

train_image_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir, 
    label_mode=None, 
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    seed=123,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    shuffle=True
)

val_image_ds = tf.keras.preprocessing.image_dataset_from_directory(
    val_dir, 
    label_mode=None, 
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    seed=123,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    shuffle=True
)

def normalize_img(img):
    return (img / 127.5) - 1

train_image_ds = train_image_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
val_image_ds = val_image_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)

def create_message_tensor_for_training(batch_size, width, height, data_depth):
    message = tf.random.uniform([batch_size, width, height, data_depth], 0, 2, dtype=tf.int32)
    message = tf.cast(message, tf.float32)
    return message

def create_message_dataset(batch_size, num_batches, width, height, data_depth):
    message_tensors = [create_message_tensor_for_training(batch_size, width, height, data_depth) for _ in range(num_batches)]
    return tf.data.Dataset.from_tensor_slices(tf.concat(message_tensors, axis=0)).batch(batch_size)

train_message_ds = create_message_dataset(BATCH_SIZE, len(train_image_ds), IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)
val_message_ds = create_message_dataset(BATCH_SIZE, len(val_image_ds), IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)

train_ds = tf.data.Dataset.zip((train_image_ds, train_message_ds)).prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.zip((val_image_ds, val_message_ds)).prefetch(buffer_size=tf.data.AUTOTUNE)

Found 800 files.
Found 100 files.


In [5]:
steganoGAN.build([(1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), (1, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)])
steganoGAN.fit(train_ds, epochs=10, validation_data=val_ds)
steganoGAN.save(MODEL_PATH)

Epoch 1/10
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m243s[0m 1s/step - bpp: 2.0000 - critic_loss: -0.0217 - decoding_loss: 0.4530 - encoder_decoder_total_loss: 0.7289 - psnr: 25.3925 - realism_loss: -0.0383 - similarity_loss: 0.3142 - ssim: 0.8950 - val_bpp: 2.0000 - val_critic_loss: -0.1052 - val_decoding_loss: 0.4456 - val_encoder_decoder_total_loss: 0.8875 - val_psnr: 25.7010 - val_realism_loss: 0.1661 - val_similarity_loss: 0.2759 - val_ssim: 0.9073
Epoch 2/10
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m234s[0m 1s/step - bpp: 2.0000 - critic_loss: -0.0175 - decoding_loss: 0.4184 - encoder_decoder_total_loss: 0.6198 - psnr: 26.4359 - realism_loss: -0.0391 - similarity_loss: 0.2405 - ssim: 0.9054 - val_bpp: 2.0000 - val_critic_loss: -0.1195 - val_decoding_loss: 0.4009 - val_encoder_decoder_total_loss: 0.5077 - val_psnr: 26.6778 - val_realism_loss: -0.1132 - val_similarity_loss: 0.2201 - val_ssim: 0.9126
Epoch 3/10
[1m200/200[0m [32m━━━━━━━━━━━━━━━

----

### Testing different functions 

In [9]:
from tensorflow.keras.losses import BinaryCrossentropy

# Create two random tensors
tensor1 = tf.random.uniform((4, 128, 128, 2), minval=0, maxval=1, dtype=tf.int32)
tensor2 = tf.random.uniform((4, 128, 128, 2), minval=0, maxval=1, dtype=tf.float32)

# Calculate binary crossentropy
loss = BinaryCrossentropy(from_logits=True)
loss = loss(tf.constant([2, 15, 100]), tf.constant([2, 15, 100]))
#loss = loss(tensor1, tensor2)

loss.numpy().astype(np.uint8)

214

In [None]:
first_value = next(iter(train_ds.take(1)))
first, second = first_value
#print(first)
#print(second)

tf.Tensor(
[[[[1. 0.]
   [1. 0.]
   [0. 0.]
   ...
   [0. 1.]
   [1. 0.]
   [0. 1.]]

  [[1. 1.]
   [1. 0.]
   [0. 1.]
   ...
   [1. 1.]
   [0. 1.]
   [0. 1.]]

  [[0. 0.]
   [0. 0.]
   [1. 0.]
   ...
   [1. 0.]
   [0. 1.]
   [1. 0.]]

  ...

  [[1. 0.]
   [1. 0.]
   [0. 0.]
   ...
   [1. 1.]
   [0. 1.]
   [0. 0.]]

  [[1. 1.]
   [1. 0.]
   [1. 0.]
   ...
   [0. 1.]
   [0. 1.]
   [0. 0.]]

  [[0. 1.]
   [1. 0.]
   [1. 1.]
   ...
   [1. 0.]
   [0. 0.]
   [0. 1.]]]


 [[[0. 0.]
   [1. 1.]
   [1. 1.]
   ...
   [1. 1.]
   [0. 1.]
   [1. 1.]]

  [[0. 1.]
   [0. 0.]
   [0. 1.]
   ...
   [1. 1.]
   [0. 0.]
   [1. 1.]]

  [[1. 0.]
   [1. 1.]
   [1. 0.]
   ...
   [0. 1.]
   [0. 0.]
   [0. 0.]]

  ...

  [[1. 0.]
   [0. 0.]
   [0. 1.]
   ...
   [0. 0.]
   [1. 0.]
   [1. 0.]]

  [[1. 1.]
   [0. 1.]
   [0. 1.]
   ...
   [0. 1.]
   [0. 1.]
   [0. 1.]]

  [[1. 1.]
   [1. 1.]
   [0. 1.]
   ...
   [1. 1.]
   [1. 1.]
   [1. 1.]]]


 [[[0. 1.]
   [1. 0.]
   [0. 1.]
   ...
   [1. 0.]
   [0. 1.]
   [0. 1.