# Import the required packages

In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, UpSampling2D, Activation, Lambda, Add, Flatten, Dense, LeakyReLU, Layer, Conv2DTranspose,Concatenate
)
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from tensorflow.keras import initializers
from tqdm import tqdm

# Set the dataset paths

In [2]:
image_folder = "celeba_hq_256/train"
masks_folder = "celeba_hq_256/train_masks"
masked_images_folder = "celeba_hq_256/train_masked_images"

# Make sure the mask and masked-image folders exist before create_dataset lists them.
# NOTE(review): no cell in this notebook writes files into these folders; they are
# presumably filled by a separate preprocessing step — confirm before training.
for _output_dir in (masks_folder, masked_images_folder):
    os.makedirs(_output_dir, exist_ok=True)

# Define the model functions

## Coarse Network

In [3]:
def build_coarse_network(input_shape=(256, 256, 5), num_of_channels=32):
    """Build the coarse (first-stage) inpainting generator.

    Layout: encoder with two stride-2 convolutions -> four dilated 3x3
    convolutions (rates 2, 4, 8, 16) -> decoder with two 2x UpSampling2D
    stages -> 3-channel ``tanh`` output.

    Args:
        input_shape: (H, W, C) of the network input. The training loop feeds
            5 channels (3-channel incomplete image, a constant-ones channel and
            the 1-channel mask), which is therefore the default. H and W should
            be divisible by 4 so the two upsamplings restore the input size.
        num_of_channels: Base filter count; deeper layers use 2x / 4x of it.

    Returns:
        A Keras ``Model`` named "coarse_network" whose output lies in [-1, 1].
    """
    c = num_of_channels

    def conv_elu(x, filters, kernel, prefix, strides=1, dilation_rate=1):
        # Layer names encode filters (C), kernel size (K) and stride (S),
        # derived from the real arguments so they can never drift out of sync.
        x = Conv2D(filters, (kernel, kernel), strides=strides, dilation_rate=dilation_rate,
                   padding='same', name=f"{prefix}_C{filters}K{kernel}S{strides}")(x)
        return Activation('elu')(x)

    inputs = Input(shape=input_shape, name="input_layer")

    # Encoder: H x W -> H/4 x W/4
    x = conv_elu(inputs, c, 5, "Encoder_conv2d_layer_1")
    x = conv_elu(x, c * 2, 3, "Encoder_conv2d_layer_2", strides=2)
    x = conv_elu(x, c * 2, 3, "Encoder_conv2d_layer_3")
    x = conv_elu(x, c * 4, 3, "Encoder_conv2d_layer_4", strides=2)
    x = conv_elu(x, c * 4, 3, "Encoder_conv2d_layer_5")
    x = conv_elu(x, c * 4, 3, "Encoder_conv2d_layer_6")

    # Dilated convolutions widen the receptive field without further downsampling.
    for i, rate in enumerate((2, 4, 8, 16), start=1):
        x = conv_elu(x, c * 4, 3, f"Dilation_conv2d_layer_{i}_rate_{rate}", dilation_rate=rate)

    # Decoder: H/4 x W/4 -> H x W
    x = conv_elu(x, c * 4, 3, "Decoder_conv2d_layer_1")
    x = conv_elu(x, c * 4, 3, "Decoder_conv2d_layer_2")
    x = UpSampling2D(size=(2, 2), name="Decoder_UpSampling2D_1")(x)
    x = conv_elu(x, c * 2, 3, "Decoder_conv2d_layer_3")
    x = conv_elu(x, c * 2, 3, "Decoder_conv2d_layer_4")
    x = UpSampling2D(size=(2, 2), name="Decoder_UpSampling2D_2")(x)
    x = conv_elu(x, c, 3, "Decoder_conv2d_layer_5")
    x = conv_elu(x, c // 2, 3, "Decoder_conv2d_layer_6")
    x = Conv2D(3, (3, 3), padding='same', name="Decoder_conv2d_layer_7_C3K3S1")(x)
    outputs = Activation('tanh')(x)
    return Model(inputs, outputs, name="coarse_network")

## Refinement Network

In [4]:

def build_refinement_network(input_shape=(256, 256, 3), num_of_channels=32):
    """Build the refinement (second-stage) inpainting generator.

    Its input is the coarse output pasted into the holes of the incomplete
    image (3 channels). Layout: encoder (two stride-2 convolutions) -> four
    dilated 3x3 convolutions (rates 2, 4, 8, 16) -> decoder with two 2x
    UpSampling2D stages -> 3-channel ``tanh`` output. Only a single
    dilated-convolution pathway is implemented.

    Args:
        input_shape: (H, W, C) of the input; H and W should be divisible by 4.
        num_of_channels: Base filter count; deeper layers use 2x / 4x of it.

    Returns:
        A Keras ``Model`` named "refinement_network" whose output lies in [-1, 1].
    """
    c = num_of_channels

    def conv_elu(x, filters, kernel, prefix, strides=1, dilation_rate=1):
        # Layer names encode filters (C), kernel size (K) and stride (S).
        x = Conv2D(filters, (kernel, kernel), strides=strides, dilation_rate=dilation_rate,
                   padding='same', name=f"{prefix}_C{filters}K{kernel}S{strides}")(x)
        return Activation('elu')(x)

    # The batch dimension is left dynamic: a fixed batch_size would reject the
    # smaller final batch of an epoch and the small batches used at test time.
    inputs = Input(shape=input_shape, name="input_layer")

    # Encoder: H x W -> H/4 x W/4
    x = conv_elu(inputs, c, 5, "Dilation_Encoder_conv2d_layer_1")
    x = conv_elu(x, c, 3, "Dilation_Encoder_conv2d_layer_2", strides=2)
    x = conv_elu(x, c * 2, 3, "Dilation_Encoder_conv2d_layer_3")
    x = conv_elu(x, c * 2, 3, "Dilation_Encoder_conv2d_layer_4", strides=2)
    x = conv_elu(x, c * 4, 3, "Dilation_Encoder_conv2d_layer_5")
    x = conv_elu(x, c * 4, 3, "Dilation_Encoder_conv2d_layer_6")

    # Dilation pathway
    for i, rate in enumerate((2, 4, 8, 16), start=1):
        x = conv_elu(x, c * 4, 3, f"Dilation_conv2d_layer_{i}_rate_{rate}", dilation_rate=rate)

    # Decoder: H/4 x W/4 -> H x W
    x = conv_elu(x, c * 4, 3, "Decoder_conv2d_layer_1")
    x = conv_elu(x, c * 4, 3, "Decoder_conv2d_layer_2")
    x = UpSampling2D(size=(2, 2), name="Decoder_UpSampling2D_1")(x)
    x = conv_elu(x, c * 2, 3, "Decoder_conv2d_layer_3")
    x = conv_elu(x, c * 2, 3, "Decoder_conv2d_layer_4")
    x = UpSampling2D(size=(2, 2), name="Decoder_UpSampling2D_2")(x)
    x = conv_elu(x, c, 3, "Decoder_conv2d_layer_5")
    x = conv_elu(x, c // 2, 3, "Decoder_conv2d_layer_6")
    x = Conv2D(3, (3, 3), padding='same', name="Decoder_conv2d_layer_7_C3K3S1")(x)
    outputs = Activation('tanh')(x)

    return Model(inputs, outputs, name="refinement_network")

## Discriminator

In [5]:
class SpectralNormalization(Layer):
    """Spectral-normalization wrapper for a layer that owns a ``kernel`` variable.

    On every training-mode call, the wrapped layer's kernel is divided in place
    by an estimate of its largest singular value before the forward pass. The
    estimate comes from power iteration on the kernel reshaped to
    ``[fan_in, out_channels]``. Outside training the kernel is used as-is.

    Args:
        layer: The layer to wrap (e.g. ``Conv2D``, ``Dense``); it must expose a
            ``kernel`` attribute once built.
        power_iterations: Power-iteration steps per training call (>= 1).
    """

    def __init__(self, layer, power_iterations=1, **kwargs):
        super(SpectralNormalization, self).__init__(**kwargs)
        if power_iterations < 1:
            raise ValueError(f"power_iterations must be >= 1, got {power_iterations}.")
        self.layer = layer
        self.power_iterations = power_iterations

    def build(self, input_shape):
        if not self.layer.built:
            self.layer.build(input_shape)
        # Hold the live kernel *variable* rather than a numpy copy from
        # get_weights(), so every normalization sees the latest optimizer update.
        self.w = self.layer.kernel
        self.w_shape = tuple(self.w.shape)
        # Persistent singular-vector estimate, refined a little on each training call.
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=initializers.RandomNormal(0, 1),
            trainable=False,
            name="spectral_u",
        )
        super(SpectralNormalization, self).build(input_shape)

    def call(self, inputs, training=None):
        if training:
            self._normalize_kernel()
        return self.layer(inputs)

    def _normalize_kernel(self):
        """Divide the wrapped kernel by its estimated spectral norm, in place."""
        w_mat = tf.reshape(self.w, [-1, self.w_shape[-1]])                  # [fan_in, out]
        u = tf.convert_to_tensor(self.u)                                     # [1, out]
        for _ in range(self.power_iterations):
            v = tf.nn.l2_normalize(tf.matmul(u, w_mat, transpose_b=True))   # [1, fan_in]
            u = tf.nn.l2_normalize(tf.matmul(v, w_mat))                      # [1, out]
        # The singular-vector estimates are constants w.r.t. differentiation.
        u = tf.stop_gradient(u)
        v = tf.stop_gradient(v)
        sigma = tf.matmul(tf.matmul(v, w_mat), u, transpose_b=True)         # [1, 1]
        # Persist state through the variables themselves (never rebind self.u).
        self.u.assign(u)
        self.w.assign(tf.reshape(w_mat / sigma, self.w_shape))

    def compute_output_shape(self, input_shape):
        return self.layer.compute_output_shape(input_shape)

def build_sn_patch_gan_discriminator(input_shape):
    """Build the spectrally-normalized convolutional discriminator.

    Four spectrally-normalized 5x5 stride-2 convolutions (64, 128, 256, 256
    filters), each followed by LeakyReLU(0.2), then Flatten and a Dense(1) head.

    NOTE(review): the Dense(1) head (not spectrally normalized) collapses the
    feature map to one score per image, so despite the name this does not emit
    per-patch scores.

    Args:
        input_shape: (H, W, C) of the images to score.

    Returns:
        A Keras ``Model`` named "sn_patch_gan_discriminator".
    """
    cnum = 64
    net_in = Input(shape=input_shape)
    features = net_in
    for multiplier in (1, 2, 4, 4):
        conv = Conv2D(cnum * multiplier, (5, 5), strides=2, padding='same')
        features = SpectralNormalization(conv)(features)
        features = LeakyReLU(alpha=0.2)(features)
    flat = Flatten()(features)
    score = Dense(1)(flat)
    return Model(net_in, score, name="sn_patch_gan_discriminator")

## Loss function

In [6]:
def gan_hinge_loss(pos, neg):
    """Hinge GAN losses.

    Args:
        pos: Discriminator scores for real samples.
        neg: Discriminator scores for generated samples.

    Returns:
        (g_loss, d_loss) where
        d_loss = mean(relu(1 - pos)) + mean(relu(1 + neg)) and
        g_loss = -mean(neg).
    """
    real_term = tf.reduce_mean(tf.nn.relu(1.0 - pos))
    fake_term = tf.reduce_mean(tf.nn.relu(1.0 + neg))
    generator_loss = -tf.reduce_mean(neg)
    return generator_loss, real_term + fake_term

## Visualize training results after each epoch

In [7]:
def visualize_results(losses, epoch, log_dir):
    """Save one batch of coarse, refined and composited outputs as PNG files.

    Args:
        losses: Dict from the training loop; only the keys "coarse_output",
            "refined_output" and "batch_complete" are read. Each holds a batch
            of images expected to lie in [-1, 1].
        epoch: Zero-based epoch index; file names use ``epoch + 1``.
        log_dir: Images are written to ``<log_dir>/results``.
    """
    results_dir = os.path.join(log_dir, "results")
    os.makedirs(results_dir, exist_ok=True)

    def _to_uint8(img):
        # Map [-1, 1] -> [0, 255]; clip so out-of-range values cannot wrap around.
        arr = np.asarray(img, dtype=np.float32)
        return np.clip((arr + 1.0) * 127.5, 0.0, 255.0).astype(np.uint8)

    triples = zip(losses["coarse_output"], losses["refined_output"], losses["batch_complete"])
    for idx, (coarse, refined, complete) in enumerate(triples):
        for kind, img in (("coarse", coarse), ("refined", refined), ("complete", complete)):
            path = os.path.join(results_dir, f"epoch_{epoch + 1}_{kind}_{idx + 1}.png")
            # scale=False keeps absolute intensities; the default (scale=True)
            # min-max stretches each image, distorting its contrast.
            tf.keras.utils.save_img(path, _to_uint8(img), scale=False)


# Create Dataset

In [8]:
def load_image_and_mask(image_path, mask_path, masked_image_path):
    """Load one example: the clean image, its hole mask and the pre-masked image.

    Intended for ``tf.data.Dataset.map`` (paths arrive as scalar string tensors).
    The ``.npy`` mask is read through ``tf.numpy_function``.

    Returns:
        image: float32 [H, W, 3] in [0, 1].
        mask: float32 [H, W, 1]; the stored ``.npy`` array must be 2-D.
        masked_image: float32 [H, W, 3] in [0, 1].
    """
    def _load_numpy_mask(mask_path):
        # numpy_function passes the path in as bytes.
        mask = np.load(mask_path.decode("utf-8")).astype(np.float32)
        return mask

    def _load_jpeg(path):
        raw = tf.io.read_file(path)
        decoded = tf.image.decode_jpeg(raw, channels=3)
        # uint8 -> float32 conversion rescales pixel values to [0, 1].
        return tf.image.convert_image_dtype(decoded, tf.float32)

    image = _load_jpeg(image_path)

    mask = tf.numpy_function(_load_numpy_mask, [mask_path], tf.float32)
    # numpy_function output has no static shape: pin the rank to 2 so the
    # dataset's element_spec is known, and fail early on a mis-shaped mask file.
    mask = tf.ensure_shape(mask, [None, None])
    mask = tf.expand_dims(mask, axis=-1)  # [H, W] -> [H, W, 1]

    masked_image = _load_jpeg(masked_image_path)

    return image, mask, masked_image


def create_dataset(image_folder, masks_folder, masked_images_folder, batch_size=16, shuffle=True):
    """Build the batched training dataset of ``(image, mask, masked_image)``.

    Files are paired by their position in the sorted listing of each folder
    (``.jpg`` images, ``.npy`` masks, ``.jpg`` masked images).

    Args:
        image_folder, masks_folder, masked_images_folder: Source folders.
        batch_size: Examples per batch (the final batch may be smaller).
        shuffle: Shuffle example order, reshuffled on every pass.

    Returns:
        A prefetched ``tf.data.Dataset`` of batches.

    Raises:
        ValueError: If no images are found or the folders' file counts differ.
    """
    def _list(folder, suffix):
        return sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(suffix))

    image_files = _list(image_folder, ".jpg")
    mask_files = _list(masks_folder, ".npy")
    masked_image_files = _list(masked_images_folder, ".jpg")

    if not image_files:
        raise ValueError(f"No .jpg images found in {image_folder!r}.")
    # NOTE(review): only the counts are checked, not that the file names correspond.
    if not (len(image_files) == len(mask_files) == len(masked_image_files)):
        raise ValueError(
            f"Mismatch in dataset sizes: {len(image_files)} images, "
            f"{len(mask_files)} masks, {len(masked_image_files)} masked images."
        )

    dataset = tf.data.Dataset.from_tensor_slices((image_files, mask_files, masked_image_files))
    if shuffle:
        # Shuffle the cheap path triples *before* decoding: a full-size buffer gives
        # a uniform shuffle while holding only strings, instead of buffering up to
        # 1000 decoded image triples in memory.
        dataset = dataset.shuffle(buffer_size=len(image_files), reshuffle_each_iteration=True)
    dataset = dataset.map(load_image_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Train the model

## Training function

In [9]:
# Training function
def _inpainting_train_step(generator, discriminator, gen_optimizer, disc_optimizer, real_images, masks, FLAGS):
    """Run one joint generator/discriminator update; return losses and outputs.

    ``real_images`` are float images in [0, 1] (as produced by
    ``tf.image.convert_image_dtype`` in the data pipeline). ``masks`` are
    [B, H, W, 1] with 1 marking the pixels to inpaint.
    """
    gen_variables = generator[0].trainable_variables + generator[1].trainable_variables
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Map [0, 1] pixels to the generators' tanh range [-1, 1].
        batch_pos = real_images * 2.0 - 1.0

        # Zero out the hole region.
        batch_incomplete = batch_pos * (1.0 - masks)

        # Coarse input: incomplete image + constant-ones channel + mask (5 channels).
        coarse_inputs = tf.concat([batch_incomplete, tf.ones_like(batch_pos)[:, :, :, 0:1], masks], axis=-1)
        coarse_output = generator[0](coarse_inputs, training=True)

        # Paste the coarse prediction into the holes and refine it.
        refined_inputs = coarse_output * masks + batch_incomplete * (1.0 - masks)
        refined_output = generator[1](refined_inputs, training=True)

        # Final composite: refined pixels inside the holes, known pixels elsewhere.
        batch_complete = refined_output * masks + batch_incomplete * (1.0 - masks)

        # Score real and inpainted images in a single discriminator pass.
        pos_neg = discriminator(tf.concat([batch_pos, batch_complete], axis=0), training=True)
        pos, neg = tf.split(pos_neg, 2)

        ae_loss = FLAGS.l1_loss_alpha * (
            tf.reduce_mean(tf.abs(batch_pos - coarse_output)) + tf.reduce_mean(tf.abs(batch_pos - refined_output))
        )
        g_loss, d_loss = gan_hinge_loss(pos, neg)
        g_loss = FLAGS.gan_loss_alpha * g_loss + ae_loss

    gen_gradients = gen_tape.gradient(g_loss, gen_variables)
    disc_gradients = disc_tape.gradient(d_loss, discriminator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gen_gradients, gen_variables))
    disc_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

    return {
        "g_loss": g_loss,
        "d_loss": d_loss,
        "ae_loss": ae_loss,
        "batch_complete": batch_complete,
        "coarse_output": coarse_output,
        "refined_output": refined_output,
    }


def train(generator, discriminator, gen_optimizer, disc_optimizer, dataset, epochs, FLAGS, log_dir="logs", checkpoint_dir="checkpoints"):
    """Train the two-stage generator against the discriminator.

    Per step: logs generator / discriminator / reconstruction losses to
    TensorBoard. Per epoch: prints mean losses, saves a checkpoint (keeping the
    5 newest) and writes the last batch's outputs via ``visualize_results``.

    Args:
        generator: ``[coarse_network, refinement_network]``.
        discriminator: Model scoring 3-channel images.
        gen_optimizer, disc_optimizer: Optimizers for generator / discriminator.
        dataset: Finite batched dataset of ``(image, mask, masked_image)``.
        epochs: Number of passes over ``dataset``.
        FLAGS: Namespace providing ``l1_loss_alpha`` and ``gan_loss_alpha``.
        log_dir: TensorBoard / result-image directory.
        checkpoint_dir: Checkpoint directory.

    Raises:
        ValueError: If ``dataset`` has no batches.
    """
    steps_per_epoch = len(dataset)
    if steps_per_epoch == 0:
        raise ValueError("The training dataset is empty; nothing to train on.")

    summary_writer = tf.summary.create_file_writer(log_dir)
    checkpoint = tf.train.Checkpoint(generator=generator, discriminator=discriminator,
                                     gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        epoch_gen_loss = 0.0
        epoch_disc_loss = 0.0
        epoch_ae_loss = 0.0

        for step, (real_images, masks, _masked_images) in tqdm(enumerate(dataset), total=steps_per_epoch):
            losses = _inpainting_train_step(generator, discriminator, gen_optimizer, disc_optimizer,
                                            real_images, masks, FLAGS)

            epoch_gen_loss += losses["g_loss"].numpy()
            epoch_disc_loss += losses["d_loss"].numpy()
            epoch_ae_loss += losses["ae_loss"].numpy()

            global_step = epoch * steps_per_epoch + step
            with summary_writer.as_default():
                tf.summary.scalar("Generator Loss", losses["g_loss"], step=global_step)
                tf.summary.scalar("Discriminator Loss", losses["d_loss"], step=global_step)
                tf.summary.scalar("Reconstruction Loss (AE)", losses["ae_loss"], step=global_step)

        print(f"Epoch {epoch + 1}: Generator Loss = {epoch_gen_loss / steps_per_epoch:.4f}, "
              f"Discriminator Loss = {epoch_disc_loss / steps_per_epoch:.4f}, "
              f"AE Loss = {epoch_ae_loss / steps_per_epoch:.4f}")

        summary_writer.flush()
        checkpoint_manager.save()

        # Save images from the last batch of the epoch.
        visualize_results(losses, epoch, log_dir)



## Parameters

In [10]:
# Define flags (hyperparameters for training)
class FLAGS:
    """Training hyper-parameters, read directly as class attributes (``FLAGS.x``)."""

    # Loss weighting
    l1_loss_alpha = 10.0            # weight of the L1 reconstruction loss (coarse + refined)
    gan_loss_alpha = 1.0            # weight of the adversarial (hinge) generator loss

    # Optimisation
    batch_size = 16
    epochs = 2
    learning_rate = 1e-4
    beta_1 = 0.5                    # Adam beta1
    beta_2 = 0.9                    # Adam beta2

    # Data and output locations
    image_size = (256, 256)         # (height, width)
    log_dir = "logs"                # TensorBoard summaries and result images
    checkpoint_dir = "checkpoints"  # tf.train.CheckpointManager output

# Generator = [coarse network, refinement network]. The coarse network takes 5
# channels: the 3-channel incomplete image, a constant-ones channel and the mask.
_height, _width = FLAGS.image_size
generator = [
    build_coarse_network(input_shape=(_height, _width, 5)),
    build_refinement_network(input_shape=(_height, _width, 3)),
]

# The discriminator scores 3-channel (real or inpainted) images.
discriminator = build_sn_patch_gan_discriminator(input_shape=(_height, _width, 3))

# Two independent but identically configured Adam optimizers (generator first, then discriminator).
gen_optimizer, disc_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate, beta_1=FLAGS.beta_1, beta_2=FLAGS.beta_2)
    for _ in range(2)
)

# Load dataset
train_dataset = create_dataset(image_folder, masks_folder, masked_images_folder, batch_size=FLAGS.batch_size)

# Output directories for TensorBoard logs and checkpoints
for _directory in (FLAGS.log_dir, FLAGS.checkpoint_dir):
    os.makedirs(_directory, exist_ok=True)




## Model Summary

In [11]:
# Print coarse, refinement and discriminator architectures, in that order.
for _network in (*generator, discriminator):
    _network.summary()

## Train the model

In [12]:
# Train the model; every setting comes from the FLAGS configuration.
train(
    generator,
    discriminator,
    gen_optimizer,
    disc_optimizer,
    train_dataset,
    FLAGS.epochs,
    FLAGS,
    log_dir=FLAGS.log_dir,
    checkpoint_dir=FLAGS.checkpoint_dir,
)

# Test the model

## Test function

In [13]:
import random
import matplotlib.pyplot as plt

def test_generator(generator, dataset, log_dir, samples_per_batch=2, random_batches_num=None):
    """Plot generator results for random samples from randomly chosen batches.

    Each row shows: the real image, the masked (incomplete) image, the
    refinement input (coarse output pasted into the holes) and the final
    composite. The figure is saved to ``<log_dir>/test_generator_results.png``
    and shown.

    Args:
        generator: ``[coarse_network, refinement_network]``.
        dataset: Finite batched dataset of ``(image, mask, masked_image)`` with
            images in [0, 1].
        log_dir: Directory for the saved figure (created if missing).
        samples_per_batch: Samples drawn from each selected batch.
        random_batches_num: Number of batches to sample; ``None`` means all.
            Clamped to the dataset size.

    Raises:
        ValueError: If there is nothing to plot (empty dataset or zero samples).
    """
    dataset_size = len(dataset)
    if random_batches_num is None:
        random_batches_num = dataset_size
    random_batches_num = min(random_batches_num, dataset_size)
    if random_batches_num <= 0 or samples_per_batch <= 0:
        raise ValueError("Nothing to plot: the dataset is empty or no samples were requested.")

    selected_batches = set(random.sample(range(dataset_size), random_batches_num))
    num_rows = random_batches_num * samples_per_batch
    # squeeze=False keeps `axs` 2-D even when there is a single row.
    fig, axs = plt.subplots(num_rows, 4, figsize=(12, num_rows * 3), squeeze=False)
    titles = ("Real Image", "Masked Image", "Refined Inputs", "Batch Complete")

    row = 0
    for step, (real_images, masks, _masked_images) in tqdm(enumerate(dataset), total=dataset_size):
        if step not in selected_batches:
            continue

        # A short final batch may hold fewer than samples_per_batch images.
        batch_len = int(real_images.shape[0])
        picked = random.sample(range(batch_len), min(samples_per_batch, batch_len))

        # Map [0, 1] pixels to [-1, 1], matching the normalization used in training.
        used_real_images = tf.gather(real_images, picked, axis=0) * 2.0 - 1.0
        used_masks = tf.gather(masks, picked, axis=0)

        batch_incomplete = used_real_images * (1.0 - used_masks)
        coarse_inputs = tf.concat([batch_incomplete,
                                   tf.ones_like(used_real_images)[..., 0:1],
                                   used_masks], axis=-1)
        coarse_output = generator[0](coarse_inputs, training=False)
        refined_inputs = coarse_output * used_masks + batch_incomplete * (1.0 - used_masks)
        refined_output = generator[1](refined_inputs, training=False)
        batch_complete = refined_output * used_masks + batch_incomplete * (1.0 - used_masks)

        for images in zip(used_real_images, batch_incomplete, refined_inputs, batch_complete):
            for col, (image, title) in enumerate(zip(images, titles)):
                ax = axs[row, col]
                # Denormalize [-1, 1] -> [0, 1]; clip to avoid imshow range warnings.
                ax.imshow(np.clip((np.asarray(image) + 1.0) / 2.0, 0.0, 1.0))
                ax.set_title(title)
                ax.axis("off")
            row += 1

        selected_batches.discard(step)
        if not selected_batches:
            break  # all selected batches plotted; skip decoding the rest of the dataset

    # Blank out rows left unused because a short batch supplied fewer samples.
    for empty_row in range(row, num_rows):
        for ax in axs[empty_row]:
            ax.axis("off")

    plt.tight_layout()
    os.makedirs(log_dir, exist_ok=True)
    fig.savefig(os.path.join(log_dir, "test_generator_results.png"))
    plt.show()


## Restore checkpoint function

In [14]:
import tensorflow as tf

def restore_and_test_generator(generator, discriminator, gen_optimizer, disc_optimizer, dataset, checkpoint_dir, log_dir, num_samples=20, samples_per_batch=2):
    """Restore the newest checkpoint and plot generator results on ``dataset``.

    Args:
        generator, discriminator, gen_optimizer, disc_optimizer: Objects to restore
            (same structure as used when saving in ``train``).
        dataset: Batched dataset passed to ``test_generator``.
        checkpoint_dir: Directory managed by ``tf.train.CheckpointManager``.
        log_dir: Where ``test_generator`` saves its figure.
        num_samples: Approximate number of samples to plot; rounded up to a
            whole number of batches of ``samples_per_batch``.
        samples_per_batch: Samples taken from each randomly chosen batch.
    """
    checkpoint = tf.train.Checkpoint(generator=generator, discriminator=discriminator,
                                     gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

    latest = checkpoint_manager.latest_checkpoint
    if not latest:
        print("No checkpoint found. Ensure the checkpoint directory is correct.")
        return

    # In a fresh kernel the optimizers' slot variables do not exist yet, so part of
    # the checkpoint stays unmatched; expect_partial() silences those warnings.
    checkpoint.restore(latest).expect_partial()
    print(f"Checkpoint restored from {latest}")

    # Ceil-divide so that at least num_samples samples are plotted, within dataset bounds.
    random_batches_num = max(1, -(-num_samples // samples_per_batch))
    random_batches_num = min(random_batches_num, len(dataset))

    test_generator(generator, dataset, log_dir,
                   samples_per_batch=samples_per_batch, random_batches_num=random_batches_num)


## Test the model

In [None]:
# Restore the newest checkpoint and plot results for 20 random samples.
# NOTE: this evaluates on the training split, not on held-out data.
restore_and_test_generator(
    generator, discriminator, gen_optimizer, disc_optimizer,
    dataset=train_dataset,
    checkpoint_dir=FLAGS.checkpoint_dir,
    log_dir=FLAGS.log_dir,
    num_samples=20,
)
