### Test on an arbitrary dataset

We have saved the weights for the generator, discriminator, and repair_generator models, so we can now test the model on any arbitrary test set to see how well it repairs damaged images. The input images should be in 256x256 format (the loading code below resizes them to 128x128 before they reach the models). Additionally, the weights from the 8000-epoch training run are available via the Google Drive link.

In [68]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, Activation, LeakyReLU, Concatenate
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras import backend as K
import os
from PIL import UnidentifiedImageError
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.preprocessing import image
from scipy.linalg import sqrtm
from skimage.transform import resize

# Seed NumPy's global RNG so the noise sampled later in this cell is repeatable.
np.random.seed(1000)

# Image geometry shared by the model builders below.
img_rows = 128
img_cols = 128
channels = 3
img_shape = (img_rows, img_cols, channels)

# Hyper-parameters.
z_dim = 100          # length of the generator's noise vector
epochs = 8000
batch_size = 64
save_interval = 100  # interval at which generated images are saved



def build_repair_generator(z_dim=100):
    """Build the two-input repair network.

    The undamaged and damaged images are each flattened and passed through
    the same dense encoder (shared weights). The two 256-unit codes are
    concatenated and decoded by a dense stack whose last layer uses tanh
    and is reshaped to ``img_shape``.

    Args:
        z_dim: Accepted for signature parity with ``build_generator``;
            not referenced in the body.

    Returns:
        A Keras ``Model`` mapping ``[undamaged_input, damaged_input]`` to
        the repaired image tensor.
    """
    clean_in = Input(shape=img_shape, name='undamaged_input')
    broken_in = Input(shape=img_shape, name='damaged_input')

    # One encoder instance applied to both inputs.
    shared_encoder = Sequential()
    shared_encoder.add(Flatten(input_shape=img_shape))
    shared_encoder.add(Dense(512))
    shared_encoder.add(LeakyReLU(alpha=0.2))
    shared_encoder.add(Dense(256))
    shared_encoder.add(LeakyReLU(alpha=0.2))

    clean_code = shared_encoder(clean_in)
    broken_code = shared_encoder(broken_in)

    # Join the two codes along the feature axis.
    joint_code = Concatenate()([clean_code, broken_code])

    # Decoder: widen back out and emit one value per output pixel/channel.
    image_decoder = Sequential()
    image_decoder.add(Dense(512))
    image_decoder.add(LeakyReLU(alpha=0.2))
    image_decoder.add(Dense(1024))
    image_decoder.add(LeakyReLU(alpha=0.2))
    image_decoder.add(Dense(np.prod(img_shape), activation='tanh'))
    image_decoder.add(Reshape(img_shape))

    repaired = image_decoder(joint_code)

    repair_model = Model(inputs=[clean_in, broken_in], outputs=repaired)
    repair_model.summary()
    return repair_model
# Generator model
def build_generator(z_dim=100):
    """Build the dense noise-to-image generator.

    Three Dense -> LeakyReLU(0.2) -> BatchNormalization(momentum=0.8)
    stages of width 256, 512 and 1024 are followed by a tanh Dense layer
    that is reshaped to ``img_shape``.

    Args:
        z_dim: Length of the input noise vector.

    Returns:
        A Keras ``Model`` mapping a ``(z_dim,)`` noise input to an image.
    """
    net = Sequential()
    for stage, width in enumerate((256, 512, 1024)):
        # Only the first Dense layer declares the input dimension.
        if stage == 0:
            net.add(Dense(width, input_dim=z_dim))
        else:
            net.add(Dense(width))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))
    net.add(Dense(np.prod(img_shape), activation='tanh'))
    net.add(Reshape(img_shape))
    net.summary()

    # Wrap the stack in a functional model with an explicit noise input.
    noise_in = Input(shape=(z_dim,))
    return Model(noise_in, net(noise_in))



# Discriminator model
def build_discriminator(img_shape):
    """Build the dense real/fake classifier.

    Args:
        img_shape: Shape of a single input image; the image is flattened
            before the dense layers.

    Returns:
        A Keras ``Model`` mapping an image to one sigmoid output.
    """
    classifier = Sequential([
        Flatten(input_shape=img_shape),
        Dense(1024),  # Increased capacity
        LeakyReLU(alpha=0.2),
        Dense(512),   # Increased capacity
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid'),
    ])
    classifier.summary()

    # Wrap the stack in a functional model with an explicit image input.
    image_in = Input(shape=img_shape)
    return Model(image_in, classifier(image_in))


# Function to normalize images to the range [-1, 1]
def normalize_images(images):
    """Apply the affine map ``2 * x - 1`` (sends 0 to -1 and 1 to 1)."""
    doubled = 2 * images
    return doubled - 1

def _load_image_folder(folder, target_size):
    """Load every readable image in *folder* as an array divided by 128.0."""
    images = []
    # os.listdir makes no ordering promise; sorting keeps index i of the
    # damaged set lined up with index i of the undamaged set from run to run
    # (presumably the two folders use matching filenames - verify).
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        # Sub-directories (e.g. notebook checkpoint folders) are not images.
        if not os.path.isfile(path):
            print(f"Skipping non-image file: {filename}")
            continue
        try:
            img = load_img(path, target_size=target_size)
        except UnidentifiedImageError:
            print(f"Skipping non-image file: {filename}")
            continue
        images.append(img_to_array(img) / 128.0)
    return images


def _stack_images(images, target_size):
    """Stack images into one array, keeping image dimensions when empty."""
    if images or target_size is None:
        return np.array(images)
    height, width = target_size
    # Three channels: load_img is called with its default colour mode.
    return np.zeros((0, height, width, 3), dtype=np.float32)


def load_and_resize_dataset(dataset_path, target_size=(128, 128), pixel_difference_threshold=50):
    """Load paired damaged/undamaged images and a per-pixel damage mask.

    Images are read from ``<dataset_path>/damaged`` and
    ``<dataset_path>/undamaged`` in sorted filename order, resized to
    ``target_size`` by ``load_img`` and divided by 128.0. Entries that are
    not files, or that PIL cannot identify, are reported and skipped. Both
    sets are truncated to the length of the shorter one.

    Args:
        dataset_path: Directory containing ``damaged`` and ``undamaged``.
        target_size: ``(height, width)`` forwarded to ``load_img``.
        pixel_difference_threshold: A pixel is marked damaged when the sum
            over channels of the absolute difference exceeds this value.

    Returns:
        ``(damaged_images, undamaged_images, damaged_regions)`` where the
        mask is boolean with one entry per pixel of each image pair.
    """
    damaged = _load_image_folder(os.path.join(dataset_path, "damaged"), target_size)
    undamaged = _load_image_folder(os.path.join(dataset_path, "undamaged"), target_size)

    # Keep only as many pairs as both folders can supply.
    pair_count = min(len(damaged), len(undamaged))
    damaged_images = _stack_images(damaged[:pair_count], target_size)
    undamaged_images = _stack_images(undamaged[:pair_count], target_size)

    # Identify damaged regions based on pixel differences.
    # NOTE(review): assuming 8-bit inputs, values after the /128.0 scaling
    # are below 2, so the channel-summed difference stays below 6 and the
    # default threshold of 50 is never exceeded - confirm the intended scale.
    channel_diff = np.abs(undamaged_images - damaged_images).sum(axis=-1)
    damaged_regions = channel_diff > pixel_difference_threshold

    return damaged_images, undamaged_images, damaged_regions

# Function to save generated images
def save_generated_images(epoch, generator, output_path, X_test_undamaged, X_test_damaged, z_dim=100, r=5, c=3, pixel_difference_threshold=50):
    """Run the repair model on test pairs and save a comparison figure.

    Each figure row shows the undamaged image, the damaged image (with any
    masked pixels overwritten by the model output) and the model output.
    The figure is written to
    ``<output_path>/gan_repaired_image_epoch_<epoch>.png``.

    Args:
        epoch: Number embedded in the output filename.
        generator: Model whose ``predict`` accepts
            ``[undamaged_batch, damaged_batch]``.
        output_path: Directory for the PNG; created if it does not exist.
        X_test_undamaged: Iterable of undamaged images.
        X_test_damaged: Iterable of damaged images, paired by index.
        z_dim: Not referenced in the body; kept for interface compatibility.
        r: Maximum number of rows (image pairs) to plot.
        c: Number of subplot columns; the three panels use columns 0-2, so
            it must be at least 3.
        pixel_difference_threshold: A pixel is treated as damaged when the
            channel-summed absolute difference exceeds this value.

    Raises:
        ValueError: If no undamaged images are supplied.
    """
    target_shape = (128, 128, 3)

    # Resize images to match the repair generator input shape
    undamaged = np.array([resize(img, target_shape, mode='reflect', anti_aliasing=True) for img in X_test_undamaged])
    damaged = np.array([resize(img, target_shape, mode='reflect', anti_aliasing=True) for img in X_test_damaged])

    if len(undamaged) == 0:
        raise ValueError("save_generated_images needs at least one image pair")

    # Generate repaired images for damaged parts
    repaired_imgs = generator.predict([undamaged, damaged])

    # Identify damaged regions based on pixel differences
    damaged_regions = np.sum(np.abs(undamaged - damaged), axis=-1) > pixel_difference_threshold

    # Copy repaired pixels into the (local, resized) damaged images.
    damaged[damaged_regions] = repaired_imgs[damaged_regions]

    # Rescale images to 0-1
    undamaged_rescaled = 0.5 * (undamaged + 1)
    damaged_rescaled = 0.5 * (damaged + 1)

    # NOTE(review): repaired_imgs is plotted as returned by predict, without
    # the 0.5 * (x + 1) rescale used for the other two panels - confirm.
    panels = (
        (undamaged_rescaled, "Real (Undamaged)"),
        (damaged_rescaled, "Damaged"),
        (repaired_imgs, "Repaired"),
    )

    # Never ask for more rows than there are image pairs; squeeze=False keeps
    # axs two-dimensional even when only one row is plotted.
    rows = min(r, len(undamaged))
    fig, axs = plt.subplots(rows, c, figsize=(15, 15), squeeze=False)
    for row in range(rows):
        for col, (images, title) in enumerate(panels):
            ax = axs[row, col]
            ax.imshow(images[row])
            ax.axis('off')
            ax.set_title(title)

    os.makedirs(output_path, exist_ok=True)
    fig.savefig(os.path.join(output_path, f"gan_repaired_image_epoch_{epoch}.png"))
    plt.close(fig)

# Adversarial ground truths
# (batch_size, 1) arrays of ones and zeros; neither is referenced again in
# the visible part of this cell.
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))


# Build and compile the generator
# The same optimizer instance is reused below when compiling `combined`.
generator = build_generator()
generator_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
generator.compile(loss='mse', optimizer=generator_optimizer)


# Build and compile the discriminator
# Compiled here, before `trainable` is switched off further down.
discriminator = build_discriminator(img_shape)
discriminator_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
discriminator.compile(loss='binary_crossentropy', optimizer=discriminator_optimizer, metrics=['accuracy'])


# The generator takes noise as input and generates images
z = Input(shape=(z_dim,))
img = generator(z)

# For the combined model, only train the generator
# The flag is changed after the discriminator's own compile call and before
# `combined` is compiled.
discriminator.trainable = False

# The discriminator takes generated images as input and determines validity
validity = discriminator(img)

# Build and compile the combined model
# Maps noise -> generator -> discriminator output.
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer=generator_optimizer)


# Build and compile the repair generator
# Uses its own Adam optimizer with the same settings as the other two.
repair_generator = build_repair_generator()
repair_generator_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
repair_generator.compile(loss='mse', optimizer=repair_generator_optimizer)


# Test on new damaged images
# NOTE: these are machine-specific absolute paths; edit them before running
# on another machine.
test_folder_path = r"D:\SagharGhaffari\new dataset2\new dataset\mine"
output_path = r"D:\SagharGhaffari\new dataset2\new dataset\arbitaryoutput"
weights_directory = r"D:\SagharGhaffari\ANN_Project"

# Epoch number embedded in the saved weight filenames and the output figure.
checkpoint_epoch = 8000

# Load and preprocess the test dataset at the resolution the models were
# built with (taken from the shared constants rather than repeated literals).
X_test_damaged_testset, X_test_undamaged_testset, damaged_regions_testset = load_and_resize_dataset(
    test_folder_path, target_size=(img_rows, img_cols)
)

# Combine the test sets for evaluation
X_test_testset = np.concatenate([X_test_damaged_testset, X_test_undamaged_testset])

# Resize the test images to the expected input shape of the repair generator
X_test_undamaged_testset_resized = np.array([resize(img, img_shape, mode='reflect', anti_aliasing=True) for img in X_test_undamaged_testset])
X_test_damaged_testset_resized = np.array([resize(img, img_shape, mode='reflect', anti_aliasing=True) for img in X_test_damaged_testset])

# Load the weights into the models with the full path
generator.load_weights(os.path.join(weights_directory, f"generator_weights_epoch_{checkpoint_epoch}.h5"))
discriminator.load_weights(os.path.join(weights_directory, f"discriminator_weights_epoch_{checkpoint_epoch}.h5"))
repair_generator.load_weights(os.path.join(weights_directory, f"repair_generator_weights_epoch_{checkpoint_epoch}.h5"))

# Sample noise and generate images using the generator.
# The noise width comes from z_dim so it stays in step with the generator's
# input layer.
num_samples = 50  # Adjust as needed
noise = np.random.normal(0, 1, (num_samples, z_dim))
generated_images = generator.predict(noise)

# Evaluate the generated images using the discriminator, labelling them all
# as fake (zeros). The discriminator was compiled with an accuracy metric,
# so the printed value contains more than the loss alone.
discriminator_loss = discriminator.evaluate(generated_images, np.zeros((num_samples, 1)))
print(f"Discriminator Loss on Generated Images: {discriminator_loss}")

# Test the repair generator on your test set, using the undamaged images as
# the regression target.
test_loss = repair_generator.evaluate([X_test_undamaged_testset_resized, X_test_damaged_testset_resized], X_test_undamaged_testset_resized)
print(f"Repair Generator Test Loss: {test_loss}")

# Visualize and save the generated images
save_generated_images(checkpoint_epoch, repair_generator, output_path, X_test_undamaged_testset, X_test_damaged_testset)


Model: "sequential_50"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_147 (Dense)           (None, 256)               25856     
                                                                 
 leaky_re_lu_110 (LeakyReLU  (None, 256)               0         
 )                                                               
                                                                 
 batch_normalization_42 (Ba  (None, 256)               1024      
 tchNormalization)                                               
                                                                 
 dense_148 (Dense)           (None, 512)               131584    
                                                                 
 leaky_re_lu_111 (LeakyReLU  (None, 512)               0         
 )                                                               
                                                     

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i