# SteganoGAN in Keras
This notebook contains code attempting to reimplement SteganoGAN in Keras, for the purpose of better understanding (and scrutinizing) it.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

### Modules

In [None]:
import os
# Keras 3 reads KERAS_BACKEND when it is first imported, so this must be set
# before any `keras` import below.
os.environ["KERAS_BACKEND"] = "tensorflow"

import sys
# Make the project-local modules imported below (callbacks, models,
# dataset_utils, keras_steganogan) importable; presumably they live two or
# three directories above this notebook — TODO confirm.
sys.path.append("../..")
sys.path.append("../../..")

import tensorflow as tf
# NOTE(review): this cell mixes `keras.*` and `tensorflow.keras.*` imports.
# They are the same package only when tf.keras resolves to Keras 3
# (TF >= 2.16); confirm, or import CSVLogger from `keras.callbacks` as well.
from keras.optimizers import Adam
from keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import CSVLogger
from callbacks import Checkpoint, SaveImages

from models import (
  steganogan_encoder_basic_model,
  steganogan_decoder_basic_model,
  steganogan_critic_model
)

from dataset_utils import normalize_img, create_message_dataset
from keras_steganogan import KerasSteganoGAN

### Constants

In [2]:
# Cover-image geometry: (height, width, channels). Images are resized to
# (IMAGE_HEIGHT, IMAGE_WIDTH) when the datasets are loaded.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS = IMAGE_SHAPE = (128, 128, 3)

# Number of message channels per pixel (bits per pixel for a {0, 1} message).
MESSAGE_DEPTH = 3
BATCH_SIZE = 4

# Output artifacts; the model, logs and epoch images are all keyed by MESSAGE_DEPTH.
MODEL_PATH = f'../../pretrained_models/{MESSAGE_DEPTH}/steganoGAN_basic.keras'
LOGS_PATH = f'../../logs/basic/{MESSAGE_DEPTH}/steganoGAN_basic.csv'
CALLBACK_IMAGES_OUTPUT_PATH = f'../../epoch_images/basic/{MESSAGE_DEPTH}'

# Input images used by the SaveImages callback.
CALLBACK_IMAGES_PATH = '../../../images/callback'

----

### Build the model for training

In [3]:
# Encoder and decoder both take the image dimensions plus the message depth;
# the critic only needs the image dimensions.
_image_dims = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)
encoder = steganogan_encoder_basic_model(*_image_dims, MESSAGE_DEPTH)
decoder = steganogan_decoder_basic_model(*_image_dims, MESSAGE_DEPTH)
critic = steganogan_critic_model(*_image_dims)

steganoGAN = KerasSteganoGAN(
    encoder=encoder,
    decoder=decoder,
    critic=critic,
    image_shape=IMAGE_SHAPE,
    data_depth=MESSAGE_DEPTH,
    model_path=MODEL_PATH,
)


def _make_adam():
    """Return a new Adam optimizer with the shared settings (lr=1e-4, beta_1=0.5).

    Called once per sub-model so that each one gets its own optimizer instance.
    """
    return Adam(learning_rate=1e-4, beta_1=0.5)


steganoGAN.compile(
    encoder_optimizer=_make_adam(),
    decoder_optimizer=_make_adam(),
    critic_optimizer=_make_adam(),
    loss_fn=BinaryCrossentropy(from_logits=False),
)

# To inspect the architecture, call steganoGAN.models_summary() or
# steganoGAN.summary().

### Load the DIV2K dataset and pair it with a random {0, 1} message dataset

In [None]:
# DIV2K root containing the `train/` and `val/` image folders. Set the DIV2K_DIR
# environment variable to use another location; the default is the original
# machine-specific path, kept so existing runs behave the same.
DIV2K_DIR = os.environ.get(
    'DIV2K_DIR',
    '/Users/dmitryhoma/Projects/phd_dissertation/state_2/SteganoGAN/research/data/div2k',
)
# Seed passed to image_dataset_from_directory for both splits.
SHUFFLE_SEED = 123

train_dir = os.path.join(DIV2K_DIR, 'train')
val_dir = os.path.join(DIV2K_DIR, 'val')


def _load_image_split(directory):
    """Load one image split as shuffled, normalized batches of BATCH_SIZE.

    Labels are dropped (``label_mode=None``), so each element is an image batch
    resized to (IMAGE_HEIGHT, IMAGE_WIDTH) and passed through ``normalize_img``.
    """
    images = tf.keras.preprocessing.image_dataset_from_directory(
        directory,
        label_mode=None,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        seed=SHUFFLE_SEED,
        image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
        shuffle=True,
    )
    return images.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)


def _message_ds_for(image_ds):
    """Create a random message dataset with one batch per batch of ``image_ds``."""
    # NOTE(review): if the image count is not a multiple of BATCH_SIZE, the last
    # image batch is smaller; confirm create_message_dataset yields a matching
    # final batch, otherwise zip pairs tensors with different batch sizes.
    return create_message_dataset(BATCH_SIZE, len(image_ds), IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)


train_image_ds = _load_image_split(train_dir)
val_image_ds = _load_image_split(val_dir)

train_message_ds = _message_ds_for(train_image_ds)
val_message_ds = _message_ds_for(val_image_ds)

# Each element is an (images, messages) pair; prefetch overlaps input prep with training.
train_ds = tf.data.Dataset.zip((train_image_ds, train_message_ds)).prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.zip((val_image_ds, val_message_ds)).prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# CSVLogger opens LOGS_PATH when training starts and does not create missing
# parent folders (nor does saving to MODEL_PATH), so create both up front;
# otherwise a fresh checkout fails only after the model has been built.
for _artifact_dir in (os.path.dirname(LOGS_PATH), os.path.dirname(MODEL_PATH)):
    if _artifact_dir:
        os.makedirs(_artifact_dir, exist_ok=True)

EPOCHS = 40

# Build with a batch dimension of 1 for both inputs, in the same
# (images, messages) order as the elements of train_ds / val_ds.
steganoGAN.build([(1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), (1, IMAGE_HEIGHT, IMAGE_WIDTH, MESSAGE_DEPTH)])

# Keep the History object so per-epoch metrics can be inspected after training.
history = steganoGAN.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[
        SaveImages(MESSAGE_DEPTH, IMAGE_SHAPE, CALLBACK_IMAGES_PATH, CALLBACK_IMAGES_OUTPUT_PATH),
        Checkpoint(MODEL_PATH),
        CSVLogger(LOGS_PATH),
    ],
)
