## Descrim Model

In [61]:
import tensorflow as tf
from keras.models import Sequential, Model
from keras.layers import Conv2D, LeakyReLU, Input 
from keras.layers import Layer
from keras.optimizers import Adam
from keras.initializers import RandomNormal
import keras

In [62]:
weight_initializer = RandomNormal(stddev=0.02)

In [63]:

class InstanceNormalization(Layer):
    def __init__(self, epsilon=1e-5):
        super(InstanceNormalization, self).__init__()
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(
            name='scale',
            shape=input_shape[-1:],
            initializer='ones',
            trainable=True)
        self.offset = self.add_weight(
            name='offset',
            shape=input_shape[-1:],
            initializer='zeros',
            trainable=True)

    def call(self, x):
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        inv = tf.math.rsqrt(variance + self.epsilon)
        normalized = (x - mean) * inv
        return self.scale * normalized + self.offset

In [64]:
def discriminator_block(x, filters, kernel_size=4, strides=2, padding='same'):
    """Single block of the discriminator"""
    x = Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=weight_initializer
    )(x)
    x = LeakyReLU(0.2)(x)
    return x

In [65]:
def build_discriminator(input_shape=(256, 256, 3)):
    inputs = Input(shape=input_shape)
    
    # First layer doesn't use instance normalization
    x = discriminator_block(inputs, 64, strides=1)
    
    # Downsampling layers with instance normalization
    x = discriminator_block(x, 128)
    x = InstanceNormalization()(x)
    
    x = discriminator_block(x, 256) 
    x = InstanceNormalization()(x)
    
    x = discriminator_block(x, 512)
    x = InstanceNormalization()(x)
    
    # Final layer
    x = Conv2D(
        filters=1,
        kernel_size=4,
        strides=1,
        padding='same',
        kernel_initializer=weight_initializer
    )(x)
    
    return Model(inputs, x, name='discriminator')

In [66]:
#Make the dataset
from cycleganstyletransfer.config import DATA_DIR
data_dir = DATA_DIR / "raw"


BATCH_SIZE = 1
IMG_HEIGHT = 256
IMG_WIDTH = 256


monet_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir / "Monet",
    validation_split=0.2,
    subset="training",
    seed=42,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    label_mode=None
)

images_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir / "Images" ,
    validation_split=0.2,
    subset="training",
    seed=42,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    label_mode=None
)


Found 1193 files belonging to 1 classes.
Using 955 files for training.
Found 7037 files belonging to 1 classes.
Using 5630 files for training.


In [67]:
epoch_length = 10#max(len(monet_ds), len(images_ds))


In [71]:
def disciminator_loss(real_monet, image):
    real_loss = tf.reduce_mean(tf.math.squared_difference(real_monet, tf.ones_like(real_monet)))
    fake_loss = tf.reduce_mean(tf.math.squared_difference(image, tf.zeros_like(image)))
    total_loss = 0.5 * (real_loss + fake_loss)
    return total_loss


def domain_discrimination_loss(domain1_output, domain2_output):
    """Try to label domain1 as 1 and domain2 as 0"""
    real_loss = tf.reduce_mean(tf.math.squared_difference(domain1_output, tf.ones_like(domain1_output)))
    fake_loss = tf.reduce_mean(tf.math.squared_difference(domain2_output, tf.zeros_like(domain2_output)))
    return 0.5 * (real_loss + fake_loss)


In [72]:
def preprocess_image(image):
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return image

In [75]:
NUM_EPOCHS = 3

my_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

training_monet_ds = iter(monet_ds.shuffle(1000).repeat())
training_images_ds = iter(images_ds.shuffle(100).repeat())

my_descrim = build_discriminator()


# Calculate accuracy on test batches
test_monet = next(training_monet_ds) 
test_images = next(training_images_ds)

monet_preds = my_descrim(test_monet)
images_preds = my_descrim(test_images)

# Calculate accuracy (% correct classifications)
monet_accuracy = tf.reduce_mean(tf.cast(monet_preds > 0.5, tf.float32))
images_accuracy = tf.reduce_mean(tf.cast(images_preds < 0.5, tf.float32))
total_accuracy = (monet_accuracy + images_accuracy) / 2

print(f"Monet Classification Accuracy: {monet_accuracy:.2%}")
print(f"Photo Classification Accuracy: {images_accuracy:.2%}") 
print(f"Total Accuracy: {total_accuracy:.2%}\n")


for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for iteration in range(epoch_length):
        #if iteration % 10 == 0:
        print(f"\rProgress: {(iteration+1)/epoch_length*100:.1f}%", end="")
            
        monet_images = preprocess_image(next(training_monet_ds))
        photo_images = preprocess_image(next(training_images_ds))
        

        with tf.GradientTape() as tape:
            monet_output = my_descrim(monet_images)
            photo_output = my_descrim(photo_images)

            loss = domain_discrimination_loss(monet_output, photo_output)

        grads = tape.gradient(loss, my_descrim.trainable_variables)


        print(f"\nWeight before:", my_descrim.layers[1].weights[0][0,0,0,0].numpy())

        # apply gradients
        my_optimizer.apply_gradients(zip(grads, my_descrim.trainable_variables))

        print(f"\nWeight after: ", my_descrim.layers[1].weights[0][0,0,0,0].numpy())           

        # Print loss and accuracy metrics after each epoch
        if iteration == epoch_length - 1:
            print(f"\nEpoch {epoch+1} Loss: {loss:.4f}")

            # Calculate accuracy on test batches
            test_monet = next(training_monet_ds) 
            test_images = next(training_images_ds)

            monet_preds = my_descrim(test_monet)
            photo_preds = my_descrim(test_images)

            monet_score = tf.reduce_mean(monet_preds)
            photo_score = tf.reduce_mean(photo_preds)

            print(f"Avg Monet Score: {monet_score:.3f} (should be ~1)")
            print(f"Avg Photo Score: {photo_score:.3f} (should be ~0)")


Monet Classification Accuracy: 37.40%
Photo Classification Accuracy: 64.36%
Total Accuracy: 50.88%

Epoch 1/3
Progress: 10.0%
Weight before: 0.018524071

Weight after:  0.018724069
Progress: 20.0%
Weight before: 0.018724069

Weight after:  0.01892801
Progress: 30.0%
Weight before: 0.01892801

Weight after:  0.019147845
Progress: 40.0%
Weight before: 0.019147845

Weight after:  0.019317245
Progress: 50.0%
Weight before: 0.019317245

Weight after:  0.01941621
Progress: 60.0%
Weight before: 0.01941621

Weight after:  0.019544443
Progress: 70.0%
Weight before: 0.019544443

Weight after:  0.019574102
Progress: 80.0%
Weight before: 0.019574102

Weight after:  0.019536678
Progress: 90.0%
Weight before: 0.019536678

Weight after:  0.019570908
Progress: 100.0%
Weight before: 0.019570908

Weight after:  0.019571759

Epoch 1 Loss: 1.8651
Avg Monet Score: 0.002 (should be ~1)
Avg Photo Score: 0.011 (should be ~0)
Epoch 2/3
Progress: 10.0%
Weight before: 0.019571759

Weight after:  0.01957579
Progr