# In [None]:
# models_gan.py

import tensorflow as tf
from tensorflow.keras import layers, Model, Input


# Number of channels produced by the generator's final layer (3 = RGB).
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_batchnorm=True):
    """Return a Sequential encoder block that halves the spatial resolution.

    The block is: strided Conv2D -> (optional) BatchNormalization -> LeakyReLU.

    Args:
        filters: number of convolution output channels.
        size: convolution kernel size.
        apply_batchnorm: when True, insert a BatchNormalization layer after
            the convolution.

    Returns:
        A `tf.keras.Sequential` containing the layers listed above.
    """
    weight_init = tf.random_normal_initializer(mean=0., stddev=0.02)

    # The convolution carries no bias term (use_bias=False).
    stages = [
        layers.Conv2D(
            filters=filters,
            kernel_size=size,
            strides=2,
            padding='same',
            kernel_initializer=weight_init,
            use_bias=False,
        ),
    ]
    if apply_batchnorm:
        stages.append(layers.BatchNormalization())
    stages.append(layers.LeakyReLU())

    return tf.keras.Sequential(stages)

def upsample(filters, size, apply_dropout=False):
    """Return a Sequential decoder block that doubles the spatial resolution.

    The block is: strided Conv2DTranspose -> BatchNormalization ->
    (optional) Dropout(0.5) -> ReLU.

    Args:
        filters: number of transposed-convolution output channels.
        size: transposed-convolution kernel size.
        apply_dropout: when True, insert a Dropout layer with rate 0.5
            between the normalization and the activation.

    Returns:
        A `tf.keras.Sequential` containing the layers listed above.
    """
    weight_init = tf.random_normal_initializer(mean=0., stddev=0.02)

    stages = [
        layers.Conv2DTranspose(
            filters=filters,
            kernel_size=size,
            strides=2,
            padding='same',
            kernel_initializer=weight_init,
            use_bias=False,
        ),
        layers.BatchNormalization(),
    ]
    if apply_dropout:
        stages.append(layers.Dropout(0.5))
    stages.append(layers.ReLU())

    return tf.keras.Sequential(stages)

def Generator(input_shape=None):
    """
    Builds the Generator model based on a modified U-Net architecture.
    It takes a label map (3 channels by default) and outputs an
    OUTPUT_CHANNELS-channel image with values in [-1, 1] (tanh).

    The encoder is eight stride-2 downsampling blocks; the decoder is seven
    stride-2 upsampling blocks, each concatenated with the matching encoder
    output (skip connection), followed by a final stride-2 transposed
    convolution that restores the input resolution.

    Args:
        input_shape: (height, width, channels) of the input. When None
            (the default) it resolves to (TILE_SIZE, TILE_SIZE, 3), where
            TILE_SIZE is read from the module globals at call time and
            falls back to 512 if it is not defined.

    Returns:
        A `Model` mapping the input tensor to the generated image.

    Raises:
        ValueError: if a known height or width is not a multiple of 256.
    """
    # Resolve the default at call time. A default written directly in the
    # signature is evaluated when the `def` statement runs, which raises
    # NameError if TILE_SIZE has not been defined by then.
    if input_shape is None:
        tile_size = globals().get('TILE_SIZE', 512)
        input_shape = (tile_size, tile_size, 3)

    # Eight stride-2 stages divide each side by 2**8; any other size makes an
    # upsampled tensor disagree with its skip tensor, so fail early with a
    # clear message. Unknown (None) dimensions are left unchecked.
    for dim in input_shape[:2]:
        if dim is not None and dim % 256 != 0:
            raise ValueError(
                "Generator input height and width must be multiples of 256 "
                "(8 stride-2 downsampling stages); got input_shape={!r}"
                .format(tuple(input_shape)))

    # The input is a label map, but it's represented as a multi-channel image
    inputs = Input(shape=input_shape)

    # Encoder (Downsampling path). Shape comments assume a 512x512 input.
    down_stack = [
        downsample(64, 4, apply_batchnorm=False),  # (bs, 256, 256, 64)
        downsample(128, 4),  # (bs, 128, 128, 128)
        downsample(256, 4),  # (bs, 64, 64, 256)
        downsample(512, 4),  # (bs, 32, 32, 512)
        downsample(512, 4),  # (bs, 16, 16, 512)
        downsample(512, 4),  # (bs, 8, 8, 512)
        downsample(512, 4),  # (bs, 4, 4, 512)
        downsample(512, 4),  # (bs, 2, 2, 512)
    ]

    # Decoder (Upsampling path). Shapes are given after the skip
    # concatenation, which is why the channel counts are doubled.
    up_stack = [
        upsample(512, 4, apply_dropout=True),  # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True),  # (bs, 8, 8, 1024)
        upsample(512, 4, apply_dropout=True),  # (bs, 16, 16, 1024)
        upsample(512, 4),  # (bs, 32, 32, 1024)
        upsample(256, 4),  # (bs, 64, 64, 512)
        upsample(128, 4),  # (bs, 128, 128, 256)
        upsample(64, 4),   # (bs, 256, 256, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh')  # tanh activation outputs in [-1, 1] range

    # Run the encoder, remembering every stage's output for the skips.
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck (last encoder output) feeds the decoder directly, so it
    # is dropped from the skips; the rest are consumed deepest-first.
    skips = reversed(skips[:-1])

    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return Model(inputs=inputs, outputs=x)


def Discriminator(input_shape=None, target_shape=None):
    """
    Builds the PatchGAN Discriminator model.
    It takes both the input label map and the real/fake image as input.

    The two inputs are concatenated along the channel axis, reduced by three
    stride-2 downsampling blocks, and then passed through two zero-padded
    stride-1 convolutions. The last convolution has a single output channel
    and no activation, so the model emits one raw score per patch.

    Args:
        input_shape: (height, width, channels) of the label map.
        target_shape: (height, width, channels) of the real/generated image.
            Either argument left as None (the default) resolves to
            (TILE_SIZE, TILE_SIZE, 3), where TILE_SIZE is read from the
            module globals at call time and falls back to 512 if undefined.

    Returns:
        A `Model` with inputs `[input_image, target_image]` and the patch
        score map as output.

    Raises:
        ValueError: if the two shapes differ in rank or in a known spatial
            dimension (they could not be concatenated channel-wise).
    """
    # Resolve defaults at call time. Defaults written directly in the
    # signature are evaluated when the `def` statement runs, which raises
    # NameError if TILE_SIZE has not been defined by then.
    if input_shape is None or target_shape is None:
        tile_size = globals().get('TILE_SIZE', 512)
        default_shape = (tile_size, tile_size, 3)
        if input_shape is None:
            input_shape = default_shape
        if target_shape is None:
            target_shape = default_shape

    # Channel-wise concatenation needs identical non-channel dimensions.
    # Check before building any layer so the error names the real cause;
    # unknown (None) dimensions are left unchecked.
    spatial_mismatch = any(
        a is not None and b is not None and a != b
        for a, b in zip(input_shape[:-1], target_shape[:-1]))
    if len(input_shape) != len(target_shape) or spatial_mismatch:
        raise ValueError(
            "Discriminator input_shape and target_shape must agree on every "
            "dimension except channels; got {!r} and {!r}"
            .format(tuple(input_shape), tuple(target_shape)))

    initializer = tf.random_normal_initializer(0., 0.02)

    # The discriminator receives two inputs
    inp = Input(shape=input_shape, name='input_image')
    tar = Input(shape=target_shape, name='target_image')

    # Shape comments below assume 512x512 inputs.
    x = layers.concatenate([inp, tar])  # (bs, 512, 512, channels*2)

    down1 = downsample(64, 4, False)(x)  # (bs, 256, 256, 64)
    down2 = downsample(128, 4)(down1)  # (bs, 128, 128, 128)
    down3 = downsample(256, 4)(down2)  # (bs, 64, 64, 256)

    # Pad by one pixel per side, then a 4x4 stride-1 convolution with the
    # default (valid) padding: net effect is one pixel lost per dimension.
    zero_pad1 = layers.ZeroPadding2D()(down3)  # (bs, 66, 66, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1)  # (bs, 63, 63, 512)

    batchnorm1 = layers.BatchNormalization()(conv)
    leaky_relu = layers.LeakyReLU()(batchnorm1)
    zero_pad2 = layers.ZeroPadding2D()(leaky_relu)  # (bs, 65, 65, 512)

    # This is the final output patch (no activation: raw per-patch scores)
    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2)  # (bs, 62, 62, 1)

    return Model(inputs=[inp, tar], outputs=last)

# --- Smoke test: build both networks and print their layer summaries ---
if __name__ == '__main__':
    # Runs only when this file is the entry point, never on import.
    # Provide a tile size of 512 unless the module namespace already has one.
    globals().setdefault('TILE_SIZE', 512)

    print("--- Building Generator ---")
    generator = Generator()
    generator.summary()

    # Divider line between the two summaries.
    print("\n" + "="*40 + "\n")

    print("--- Building Discriminator ---")
    discriminator = Discriminator()
    discriminator.summary()