# SteganoGAN in Keras
This notebook contains code attempting to reimplement SteganoGAN in Keras, for the purpose of better understanding (and scrutinizing) it.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

### Modules

In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import sys

import tensorflow as tf
from keras.losses import BinaryCrossentropy

from models import (
  steganogan_encoder_residual_model,
  steganogan_decoder_basic_model,
  steganogan_critic_model
)

from keras_steganogan import KerasSteganoGAN

### Constants

In [None]:
# Cover-image geometry: height and width in pixels, plus the number of colour channels.
IMAGE_HEIGHT, IMAGE_WIDTH = 128, 128
IMAGE_CHANNELS = 3
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)

# Depth of the message: size of the last axis of the message tensor and the
# value handed to the encoder/decoder builders.
MESSAGE_DEPTH = 6

# Batch size setting.
BATCH_SIZE = 4

# Weights file; it is only loaded when it exists on disk.
MODEL_PATH = 'steganoGAN_residual.keras'

# Folder holding the input test images, and folder receiving the encoded outputs.
IMAGES_TESTING_PATH = 'images/testing'
IMAGES_OUTPUT_PATH = 'images/testing_output'

----

### Encode and decode the test images (create each steganographic image, then decode its message)

In [None]:
# Build the three sub-networks for the configured message depth and combine
# them into a single KerasSteganoGAN model.
encoder = steganogan_encoder_residual_model(MESSAGE_DEPTH)
decoder = steganogan_decoder_basic_model(MESSAGE_DEPTH)
critic = steganogan_critic_model()

steganoGAN = KerasSteganoGAN(
  encoder=encoder,
  decoder=decoder,
  critic=critic,
  data_depth=MESSAGE_DEPTH
)

# Build with two input shapes, each with a batch dimension of 1:
# the cover image and the message tensor.
steganoGAN.build([(1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), (1, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)])

# Load saved weights only when a weights file is configured and present;
# otherwise the model keeps its freshly initialised weights.
if MODEL_PATH is not None and os.path.exists(MODEL_PATH):
  steganoGAN.load_weights(MODEL_PATH)

# NOTE(review): encode() presumably writes the stego image to the output path;
# make sure the target folder exists so a fresh checkout does not fail here.
os.makedirs(IMAGES_OUTPUT_PATH, exist_ok=True)

# Round-trip the four test images: hide a message in input<i>.png, save the
# result as output<i>.png, then decode that file and print what was recovered.
for image_index in range(1, 5):
  cover_path = f'{IMAGES_TESTING_PATH}/input{image_index}.png'
  stego_path = f'{IMAGES_OUTPUT_PATH}/output{image_index}.png'
  # The message ends in the image index repeated four times, e.g. "1111".
  secret_text = f'Hello, World! {str(image_index) * 4}'

  steganoGAN.encode(cover_path, stego_path, secret_text)
  print(steganoGAN.decode(stego_path))

### Run SteganoGAN on random data and report metrics

In [None]:
# Bounds used to sample the random cover image below.
PIXEL_MIN, PIXEL_MAX = -1, 1
# tf.image.psnr / tf.image.ssim expect max_val to be the dynamic range of the
# images (max allowed value minus min allowed value), which is 2 for [-1, 1].
PIXEL_RANGE = float(PIXEL_MAX - PIXEL_MIN)

# Random cover image with values in [-1, 1) and a random binary message
# (integers drawn from {0, 1}, cast to float32), both with batch size 1.
cover_image = tf.random.uniform([1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS], PIXEL_MIN, PIXEL_MAX, dtype=tf.float32)
message = tf.cast(tf.random.uniform([1, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH], 0, 2, dtype=tf.int32), tf.float32)

# The model returns the stego image and the message recovered from it.
stego_img, recovered_msg = steganoGAN.predict([cover_image, message])

# Value ranges of both outputs, as a quick sanity check.
print("stego_img min: {0}, max: {1}".format(tf.reduce_min(stego_img), tf.reduce_max(stego_img)))
print("recovered_msg min: {0}, max: {1}".format(tf.reduce_min(recovered_msg), tf.reduce_max(recovered_msg)))

# Message recovery loss; from_logits=False treats recovered_msg as probabilities.
print("BinaryCrossentropy: {0}".format(BinaryCrossentropy(from_logits=False)(message, recovered_msg)))

# Image-quality metrics between cover and stego image.
# NOTE(review): assumes the stego image is on the same [-1, 1] scale as the
# cover image — confirm against the min/max printed above.
print("PSNR: {0}".format(tf.reduce_mean(tf.image.psnr(cover_image, stego_img, PIXEL_RANGE))))
print("SSIM: {0}".format(tf.reduce_mean(tf.image.ssim(cover_image, stego_img, PIXEL_RANGE))))