# **Import libraries**

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, Activation
from tensorflow.keras.models import Model

2023-08-15 15:13:12.789016: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-15 15:13:12.814600: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Model-building functions

In [2]:
def build_generator(input_shape, filters=64, kernel_size=7, output_channels=None):
    """Build a (currently shallow) encoder/decoder image-to-image generator.

    Architecture: one encoder block and one decoder block, each
    conv -> BatchNormalization -> ReLU, followed by a tanh output convolution.
    Every convolution uses stride 1 and 'same' padding, so the output has the
    same spatial size as the input.

    Args:
        input_shape: Shape of one image without the batch dimension,
            e.g. (256, 256, 3).
        filters: Number of feature maps in the encoder and decoder blocks
            (default 64).
        kernel_size: Kernel size used by every convolution (default 7).
        output_channels: Channels of the generated image. Defaults to
            ``input_shape[-1]`` so the output has exactly the input's shape.
            CycleGAN needs this because it feeds one generator's output into
            another generator built with the same ``input_shape`` (cycle
            reconstruction); a hard-coded 3 would break that for non-RGB data.

    Returns:
        An uncompiled Keras ``Model``; outputs lie in [-1, 1] (tanh).
    """
    if output_channels is None:
        # Match the input channel count so generators can be chained.
        output_channels = input_shape[-1]

    inputs = Input(shape=input_shape)

    # Encoder: each block is conv2d -> batchnorm -> relu.
    x = Conv2D(filters, kernel_size=kernel_size, strides=1, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # TODO: add more convolutional blocks to the encoder (e.g. strided
    # downsampling blocks).

    # Decoder. NOTE: with strides=1 this Conv2DTranspose does not upsample;
    # it only becomes a true upsampling layer once the encoder downsamples.
    x = Conv2DTranspose(filters, kernel_size=kernel_size, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # TODO: add more transpose-convolution blocks to the decoder, mirroring
    # the encoder.

    # Output layer: project back to image channels. tanh bounds values to
    # [-1, 1]; presumably training images must be scaled to that range too
    # (TODO: confirm in the data pipeline).
    outputs = Conv2D(output_channels, kernel_size=kernel_size, strides=1,
                     padding='same', activation='tanh')(x)

    return Model(inputs, outputs)

In [None]:
def build_discriminator(input_shape):
    """Build the (placeholder) discriminator for one image domain.

    At the moment this is a single 4x4 convolution applied directly to the
    image. With one filter, stride 1 and 'same' padding it maps an (H, W, C)
    image to an (H, W, 1) map of sigmoid scores in [0, 1].

    TODO: insert the real discriminator layers before the scoring conv.
    """
    image = Input(shape=input_shape)

    # TODO: build the discriminator architecture here.

    realness = Conv2D(
        filters=1,
        kernel_size=4,
        strides=1,
        padding='same',
        activation='sigmoid',
    )(image)

    return Model(inputs=image, outputs=realness)

In [3]:
def build_cycle_gan(input_shape, lambda_cycle=10.0, lambda_identity=5.0,
                    learning_rate=0.0002, beta_1=0.5):
    """Build the combined CycleGAN model used to train the two generators.

    Creates two generators (A->B and B->A) and one discriminator per domain.
    Each discriminator is compiled with binary cross-entropy and then frozen,
    so inside the combined model only the generators' weights are trainable.
    The total generator loss is attached with ``add_loss``, which is why the
    combined model is compiled without a ``loss`` argument.

    Total generator loss =
        adversarial (both directions)
        + lambda_cycle * cycle-consistency L1
        + lambda_identity * identity L1

    Args:
        input_shape: Shape of one image (no batch dim), shared by domains A
            and B, e.g. (256, 256, 3).
        lambda_cycle: Weight of the cycle-consistency loss (default 10).
        lambda_identity: Weight of the identity loss (default 5).
        learning_rate: Adam learning rate for all three optimizers.
        beta_1: Adam beta_1 for all three optimizers.

    Returns:
        A compiled Keras ``Model`` with inputs ``[real_A, real_B]`` and
        outputs ``[fake_A, fake_B, reconstructed_A, reconstructed_B]``.

    NOTE(review): the compiled discriminators are not returned, so callers
    cannot train them from outside this function - consider returning them.
    NOTE(review): applying ``tf.*`` ops to symbolic Keras tensors and
    ``Model.add_loss`` rely on Keras 2 (tf.keras); under Keras 3 this needs
    a custom ``train_step`` instead.
    """
    def _adam():
        # Each model needs its own optimizer instance.
        return tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=beta_1)

    def _l1(target, prediction):
        # Mean absolute error built from plain TF ops (safe on symbolic tensors).
        return tf.reduce_mean(tf.abs(target - prediction))

    def _fool_discriminator(disc_output):
        # Generator adversarial loss: push D's score on generated images
        # toward 1 ("real"), using the same BCE objective the discriminators
        # are compiled with.
        return tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(tf.ones_like(disc_output), disc_output))

    # Build the generators
    generator_A_to_B = build_generator(input_shape)
    generator_B_to_A = build_generator(input_shape)

    # Build the discriminators
    discriminator_A = build_discriminator(input_shape)
    discriminator_B = build_discriminator(input_shape)

    # Compile the discriminators before freezing them: in Keras 2 the
    # trainable state is captured at compile time, so their own training
    # still updates them while the combined model below treats them as fixed.
    discriminator_A.compile(optimizer=_adam(), loss='binary_crossentropy')
    discriminator_B.compile(optimizer=_adam(), loss='binary_crossentropy')
    discriminator_A.trainable = False
    discriminator_B.trainable = False

    # Inputs for CycleGAN
    real_A = Input(shape=input_shape)
    real_B = Input(shape=input_shape)

    # Translations
    fake_B = generator_A_to_B(real_A)
    fake_A = generator_B_to_A(real_B)

    # Cycle reconstructions: A -> B -> A and B -> A -> B
    reconstructed_A = generator_B_to_A(fake_B)
    reconstructed_B = generator_A_to_B(fake_A)

    # Identity mapping: a generator fed an image already in its target
    # domain should leave it unchanged.
    id_A = generator_B_to_A(real_A)
    id_B = generator_A_to_B(real_B)

    # Adversarial loss must cover BOTH directions: D_A judges fake_A
    # (trains B->A) and D_B judges fake_B (trains A->B). A term on
    # D_A(real_A) would not depend on any generator and give no gradient.
    adversarial_loss = (_fool_discriminator(discriminator_A(fake_A))
                        + _fool_discriminator(discriminator_B(fake_B)))

    # Cycle consistency loss
    cycle_loss = _l1(real_A, reconstructed_A) + _l1(real_B, reconstructed_B)

    # Identity loss
    id_loss = _l1(real_A, id_A) + _l1(real_B, id_B)

    # Total generator loss
    total_gen_loss = (adversarial_loss
                      + lambda_cycle * cycle_loss
                      + lambda_identity * id_loss)

    cycle_gan_model = Model(inputs=[real_A, real_B],
                            outputs=[fake_A, fake_B, reconstructed_A, reconstructed_B])
    cycle_gan_model.add_loss(total_gen_loss)

    # No `loss=` argument: the only objective is the add_loss term above.
    cycle_gan_model.compile(optimizer=_adam())

    return cycle_gan_model

In [None]:
# Both image domains share one shape: 256x256 pixels, 3 (RGB) channels,
# given as (height, width, channels) without the batch dimension.
input_shape = (256, 256, 3)

# Assemble generators, discriminators and the combined generator-training
# model, then print its layer-by-layer summary.
cycle_gan = build_cycle_gan(input_shape=input_shape)
cycle_gan.summary()