## Get a simple data set object together

In [4]:
import tensorflow as tf
from keras.layers import Conv2D, LeakyReLU, Input, Dense, Flatten
from keras.models import Model

In [None]:
#Make the dataset
from cycleganstyletransfer.config import DATA_DIR
data_dir = DATA_DIR / "raw"


my_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=42,
    image_size=(256, 256),
    batch_size=1,
)

print(my_ds.class_names)

Found 8230 files belonging to 2 classes.
Using 6584 files for training.
['Images', 'Monet']


In [3]:
class MyDataObject:
    def __init__(self, monet_data, photo_data):
        self.monet_data = iter(monet_data.shuffle(1000).repeat())
        self.photo_data = iter(photo_data.shuffle(1000).repeat())

    def __len__(self):
        return max(len(self.monet_data), len(self.photo_data))

    def get_new(self):
        monet_data = next(self.monet_data)
        photo_data = next(self.photo_data)
        return monet_data, photo_data


In [4]:
my_data = MyDataObject(monet_ds, images_ds)


## Put the model together

In [5]:
def discriminator_block(x, filters, kernel_size=4, strides=2, padding='same'):
    """Single block of the discriminator"""
    x = Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding
    )(x)
    x = LeakyReLU(0.2)(x)
    return x

In [11]:
def build_discriminator(input_shape=(256, 256, 3)):
    inputs = Input(shape=input_shape)
    
    # First layer doesn't use instance normalization
    x = discriminator_block(inputs, 64, strides=1)
    print(x.shape)
    
    # Downsampling layers with instance normalization
    x = discriminator_block(x, 128)
    print(x.shape)
    
    x = discriminator_block(x, 256) 
    print(x.shape)
    
    x = discriminator_block(x, 512)
    print(x.shape)
    
    # Final layer
    x = Conv2D(
        filters=1,
        kernel_size=4,
        strides=1,
        padding='same',
    )(x)

    print(x.shape)

    x = tf.keras.layers.Flatten()(x)
    x = Dense(1, activation='sigmoid')(x)
    
    return Model(inputs, x, name='discriminator')

In [12]:
my_descrim = build_discriminator()

(None, 256, 256, 64)
(None, 128, 128, 128)
(None, 64, 64, 256)
(None, 32, 32, 512)
(None, 32, 32, 1)


In [15]:
m_photo, i_photo = my_data.get_new()

test_output = my_descrim(m_photo)
print(test_output.shape)
print(test_output)

(1, 1)
tf.Tensor([[0.97778535]], shape=(1, 1), dtype=float32)


## Get a simple training loop together

In [22]:
from keras.losses import BinaryCrossentropy

my_loss = BinaryCrossentropy(from_logits=False)

def discrim_loss(monet_image_output, photo_image_output):

    total_loss = 0.5 * (my_loss(monet_image_output, tf.ones_like(monet_image_output)) + my_loss(photo_image_output, tf.zeros_like(photo_image_output)))
    return total_loss   



In [23]:
from keras.optimizers import Adam

my_optimizer = Adam()


In [26]:
EPOCHS = 1
EPOCH_LENGTH = 10

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    for iteration in range(EPOCH_LENGTH):
        m_photo, i_photo = my_data.get_new()

        with tf.GradientTape() as tape:
            monet_output = my_descrim(m_photo)
            photo_output = my_descrim(i_photo)
            print(monet_output)
            print(photo_output)

            loss = discrim_loss(monet_output, photo_output)

        grads = tape.gradient(loss, my_descrim.trainable_variables)
        my_optimizer.apply_gradients(zip(grads, my_descrim.trainable_variables))
    
        # Calculate accuracy
        monet_accuracy = tf.reduce_mean(tf.cast(monet_output > 0.5, tf.float32))
        photo_accuracy = tf.reduce_mean(tf.cast(photo_output < 0.5, tf.float32))
        total_accuracy = 0.5 * (monet_accuracy + photo_accuracy)

        print(f"Iteration {iteration+1}/{EPOCH_LENGTH}")
        print(f"Loss: {loss:.4f}")
        print(f"Accuracy on Monet images: {monet_accuracy:.2%}")
        print(f"Accuracy on Photo images: {photo_accuracy:.2%}")
        print(f"Total accuracy: {total_accuracy:.2%}\n")

#print(my_descrim.summary())

Epoch 1/1
tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
Iteration 1/10
Loss: 7.7125
Accuracy on Monet images: 100.00%
Accuracy on Photo images: 0.00%
Total accuracy: 50.00%

tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
Iteration 2/10
Loss: 7.7125
Accuracy on Monet images: 100.00%
Accuracy on Photo images: 0.00%
Total accuracy: 50.00%

tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
Iteration 3/10
Loss: 7.7125
Accuracy on Monet images: 100.00%
Accuracy on Photo images: 0.00%
Total accuracy: 50.00%

tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
Iteration 4/10
Loss: 7.7125
Accuracy on Monet images: 100.00%
Accuracy on Photo images: 0.00%
Total accuracy: 50.00%

tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
tf.Tensor([[1.]], shape=(1, 1), dtype=float32)
Iteration 5/10
Loss: 7.7125
Accuracy on Monet im