### GAN

This section imports necessary libraries and modules for the implementation. Notable libraries include NumPy for numerical operations, Matplotlib for plotting, TensorFlow and Keras for deep learning, and others for image processing.

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, Activation, LeakyReLU, Concatenate
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras import backend as K
import os
from PIL import UnidentifiedImageError
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.preprocessing import image
from scipy.linalg import sqrtm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.transform import resize


Setting a random seed ensures that the random initialization of weights and other random processes in the code is reproducible.

In [None]:
# Set the seed for reproducibility
np.random.seed(1000)

Define Image Parameters:

These parameters define the size and shape of the images, the size of the noise vector (z_dim), the number of training epochs, batch size, and the interval at which to save generated images during training.

In [5]:
# Define image parameters
img_rows, img_cols, channels = 128, 128, 3
img_shape = (img_rows, img_cols, channels)
z_dim = 100  # Size of the noise vector
epochs = 1000
batch_size = 128
save_interval = 100  # Set the interval to save generated images

Function to Save Generated Images:

This function takes the current epoch, generator model, paths to undamaged and damaged test images, and saves a set of generated images for visualization.

In [None]:

# Function to save generated images
def save_generated_images(epoch, generator, output_path, X_test_undamaged, X_test_damaged, z_dim=100, r=5, c=3):
    # Resize images to match the repair generator input shape
    X_test_undamaged_resized = np.array([resize(img, (128, 128, 3), mode='reflect', anti_aliasing=True) for img in X_test_undamaged])
    X_test_damaged_resized = np.array([resize(img, (128, 128, 3), mode='reflect', anti_aliasing=True) for img in X_test_damaged])

    # Generate repaired images for damaged parts
    repaired_imgs = generator.predict([X_test_undamaged_resized, X_test_damaged_resized])

    # Identify damaged regions based on pixel differences
    damaged_regions = np.sum(np.abs(X_test_undamaged_resized - X_test_damaged_resized), axis=-1) > 50

    # Copy repaired parts to damaged images
    for i in range(len(X_test_undamaged_resized)):
        X_test_damaged_resized[i][damaged_regions[i]] = repaired_imgs[i][damaged_regions[i]]

    # Rescale images to 0-1
    X_test_undamaged_rescaled = 0.5 * (X_test_undamaged_resized + 1)
    X_test_damaged_rescaled = 0.5 * (X_test_damaged_resized + 1)

    fig, axs = plt.subplots(r, c, figsize=(15, 15))
    cnt = 0
    for i in range(r):
        axs[i, 0].imshow(X_test_undamaged_rescaled[cnt])
        axs[i, 0].axis('off')
        axs[i, 0].set_title("Real (Undamaged)")

        axs[i, 1].imshow(X_test_damaged_rescaled[cnt])
        axs[i, 1].axis('off')
        axs[i, 1].set_title("Damaged")

        axs[i, 2].imshow(repaired_imgs[cnt])
        axs[i, 2].axis('off')
        axs[i, 2].set_title("Repaired")

        cnt += 1

    fig.savefig(output_path + f"/gan_repaired_image_epoch_{epoch}.png")
    plt.close()

This function defines and returns a repair generator model, which takes undamaged and damaged images as input and aims to generate a repaired image.

In [None]:
def build_repair_generator(z_dim=100):
    undamaged_input = Input(shape=img_shape, name='undamaged_input')
    damaged_input = Input(shape=img_shape, name='damaged_input')

    # Encoder part
    encoder = Sequential([
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
    ])

    undamaged_encoded = encoder(undamaged_input)
    damaged_encoded = encoder(damaged_input)

    # Concatenate encoded representations
    merged = Concatenate()([undamaged_encoded, damaged_encoded])

    # Decoder part
    decoder = Sequential([
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        Dense(np.prod(img_shape), activation='tanh'),
        Reshape(img_shape),
    ])

    repaired_output = decoder(merged)

    model = Model(inputs=[undamaged_input, damaged_input], outputs=repaired_output)
    model.summary()

    return model

This function defines and returns a generator model, which generates images from random noise.

In [None]:
# Generator model
def build_generator(z_dim=100):
    model = Sequential()
    model.add(Dense(256, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))
    model.summary()

    noise = Input(shape=(z_dim,))
    img = model(noise)

    return Model(noise, img)



This function defines and returns a discriminator model, which classifies images as real or fake.



In [None]:

# Discriminator model
def build_discriminator(img_shape):
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(1024))  # Increased capacity
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))   # Increased capacity
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    img = Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)


This function normalizes image pixel values to the range [-1, 1].

In [None]:

# Function to normalize images to the range [-1, 1]
def normalize_images(images):
    return 2 * images - 1


This function loads, resizes, and normalizes images from the dataset, identifying damaged regions based on pixel differences.



In [None]:

def load_and_resize_dataset(dataset_path, target_size=(128, 128), pixel_difference_threshold=50):
    damaged_images = []
    for filename in os.listdir(os.path.join(dataset_path, "damaged")):
        try:
            img = load_img(os.path.join(dataset_path, "damaged", filename), target_size=target_size)
            img = img_to_array(img) / 128.0
            damaged_images.append(img)
        except UnidentifiedImageError:
            print(f"Skipping non-image file: {filename}")

    undamaged_images = []
    for filename in os.listdir(os.path.join(dataset_path, "undamaged")):
        try:
            img = load_img(os.path.join(dataset_path, "undamaged", filename), target_size=target_size)
            img = img_to_array(img) / 128.0
            undamaged_images.append(img)
        except UnidentifiedImageError:
            print(f"Skipping non-image file: {filename}")

    min_samples = min(len(damaged_images), len(undamaged_images))
    damaged_images = damaged_images[:min_samples]
    undamaged_images = undamaged_images[:min_samples]

    damaged_images = np.array(damaged_images)
    undamaged_images = np.array(undamaged_images)

    # Identify damaged regions based on pixel differences
    damaged_regions = np.sum(np.abs(undamaged_images - damaged_images), axis=-1) > pixel_difference_threshold

    return damaged_images, undamaged_images, damaged_regions


These paths specify the locations of training and testing datasets, as well as the output directory for saving generated images.



In [None]:

# Set the paths and parameters
dataset_path = "D:/SagharGhaffari/Archive/dataset"
output_path = "D:/SagharGhaffari/Archive/output"
test_dataset_path = "D:/SagharGhaffari/Archive/testset"


This line loads and preprocesses the training dataset.

In [None]:

# Load and preprocess the dataset with reduced resolution
damaged_data, undamaged_data, damaged_regions = load_and_resize_dataset(dataset_path, target_size=(128, 128))

This loads and preprocesses the test dataset.



In [None]:

# Load and preprocess the test dataset
X_test_damaged_testset, X_test_undamaged_testset, damaged_regions_testset = load_and_resize_dataset(test_dataset_path, target_size=(128, 128))

# Combine the test sets for evaluation
X_test_testset = np.concatenate([X_test_damaged_testset, X_test_undamaged_testset])


These are labels used to indicate whether an image is real (valid) or fake.

In [None]:

# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))


This initializes and compiles the generator model.

In [None]:

# Build and compile the generator
generator = build_generator()
generator_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
generator.compile(loss='mse', optimizer=generator_optimizer)


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 256)               25856     
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 256)               0         
                                                                 
 batch_normalization (Batch  (None, 256)               1024      
 Normalization)                                                  
                                                                 
 dense_1 (Dense)             (None, 512)               131584    
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 512)               0         
                                                                 
 batch_normalization_1 (Bat  (None, 512)               2048      
 chNormalization)                                       

This initializes and compiles the discriminator model.



In [None]:

# Build and compile the discriminator
discriminator = build_discriminator(img_shape)
discriminator_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
discriminator.compile(loss='binary_crossentropy', optimizer=discriminator_optimizer, metrics=['accuracy'])


Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 49152)             0         
                                                                 
 dense_4 (Dense)             (None, 1024)              50332672  
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 1024)              0         
                                                                 
 dense_5 (Dense)             (None, 512)               524800    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 512)               0         
                                                                 
 dense_6 (Dense)             (None, 1)                 513       
                                                                 
Total params: 50857985 (194.01 MB)
Trainable params: 5

This creates and compiles the combined model, where only the generator is trained.

In [None]:

# The generator takes noise as input and generates images
z = Input(shape=(z_dim,))
img = generator(z)

# For the combined model, only train the generator
discriminator.trainable = False

# The discriminator takes generated images as input and determines validity
validity = discriminator(img)

# Build and compile the combined model
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer=generator_optimizer)


This initializes and compiles the repair generator model.

In [None]:

# Build and compile the repair generator
repair_generator = build_repair_generator()
repair_generator_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
repair_generator.compile(loss='mse', optimizer=repair_generator_optimizer)


Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 undamaged_input (InputLaye  [(None, 128, 128, 3)]        0         []                            
 r)                                                                                               
                                                                                                  
 damaged_input (InputLayer)  [(None, 128, 128, 3)]        0         []                            
                                                                                                  
 sequential_2 (Sequential)   (None, 256)                  2529766   ['undamaged_input[0][0]',     
                                                          4          'damaged_input[0][0]']       
                                                                                            

This resizes the test images to the expected input shape of the repair generator.

In [None]:

# Resize the test images to the expected input shape of the repair generator
X_test_undamaged_testset_resized = np.array([resize(img, (128, 128, 3), mode='reflect', anti_aliasing=True) for img in X_test_undamaged_testset])
X_test_damaged_testset_resized = np.array([resize(img, (128, 128, 3), mode='reflect', anti_aliasing=True) for img in X_test_damaged_testset])

In [None]:
import numpy as np

def calculate_metrics(original, repaired, win_size=3):
    psnr_values = []
    ssim_values = []

    # Convert TensorFlow tensors to NumPy arrays
    original_np = original.numpy() if isinstance(original, tf.Tensor) else original
    repaired_np = repaired.numpy() if isinstance(repaired, tf.Tensor) else repaired

    for i in range(len(original_np)):
        # Explicitly convert to float32 to avoid the TypeError
        psnr = peak_signal_noise_ratio(original_np[i].astype(np.float32), repaired_np[i].astype(np.float32), data_range=1.0)
        ssim = structural_similarity(original_np[i].astype(np.float32), repaired_np[i].astype(np.float32), win_size=win_size, multichannel=True, data_range=1.0)
        psnr_values.append(psnr)
        ssim_values.append(ssim)

    return np.mean(psnr_values), np.mean(ssim_values)


### Training the GAN and Repair Generator:


This loop trains the GAN and repair generator for the specified number of epochs, periodically saving generated images and model weights.

In [None]:

# Training the GAN and repair generator
for epoch in range(epochs):
    # Sample noise and generate a batch of new images
    noise = np.random.normal(0, 1, (batch_size, z_dim))
    gen_imgs = generator(noise, training=False)  # Set training to False

    # Create labels for training the discriminator
    valid_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    # Randomly select a batch of undamaged and damaged images from the train set
    idx = np.random.randint(0, undamaged_data.shape[0], batch_size // 2)
    X_batch_undamaged = undamaged_data[idx]
    X_batch_damaged = damaged_data[idx]

    # Concatenate undamaged and damaged batches for training the discriminator
    X_batch = np.concatenate([X_batch_undamaged, X_batch_damaged])

    # Train the discriminator
    d_loss_real = discriminator.train_on_batch(X_batch_undamaged, valid_labels[:batch_size // 2])
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator
    g_loss = combined.train_on_batch(noise, valid_labels)

    # Print progress
    print(f"{epoch}/{epochs} [D loss: {d_loss[0]} | D accuracy: {100 * d_loss[1]}] [G loss: {g_loss}]")

    # Sample noise and generate a batch of repaired images
    # Rescale images to 0-1
    X_test_undamaged_rescaled = 0.5 * (X_test_undamaged_testset_resized + 1)
    X_test_damaged_rescaled = 0.5 * (X_test_damaged_testset_resized + 1)

    repair_noise = np.random.normal(0, 1, (batch_size, z_dim))
    repaired_imgs = repair_generator([X_test_undamaged_testset_resized, X_test_damaged_testset_resized], training=False)
    repaired_imgs_rescaled = 0.5 * (repaired_imgs + 1)

    # Train the repair generator
    repair_loss = repair_generator.train_on_batch([X_test_undamaged_testset_resized, X_test_damaged_testset_resized], X_test_undamaged_testset_resized)

    # Print repair progress
    print(f"[Repair Epoch {epoch}/{epochs}] [Repair loss: {repair_loss}]")

    # Calculate and print metrics
    psnr_value, ssim_value = calculate_metrics(X_test_undamaged_rescaled, repaired_imgs_rescaled, win_size=3)
    print(f"[Quantitative Evaluation] [PSNR: {psnr_value}, SSIM: {ssim_value}]")

    # Save generated images and update TensorBoard logs at specified intervals
    if epoch % save_interval == 0:
        save_generated_images(epoch, repair_generator, output_path, X_test_undamaged_testset, X_test_damaged_testset)

        # Save generator weights
        generator.save_weights(f"generator_weights_epoch_{epoch}.h5")

        # Save discriminator weights
        discriminator.save_weights(f"discriminator_weights_epoch_{epoch}.h5")

        # Save repair generator weights
        repair_generator.save_weights(f"repair_generator_weights_epoch_{epoch}.h5")



0/1000 [D loss: 0.5870512574911118 | D accuracy: 60.9375] [G loss: 0.686246931552887]
[Repair Epoch 0/1000] [Repair loss: 1.1405812501907349]
[Quantitative Evaluation] [PSNR: 5.837541539825861, SSIM: 0.11469845377361548]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

1/1000 [D loss: 0.4406650960465262 | D accuracy: 50.78125] [G loss: 0.4516961872577667]
[Repair Epoch 1/1000] [Repair loss: 1.1449185609817505]
[Quantitative Evaluation] [PSNR: 5.819048596483551, SSIM: 0.03603614499069508]
2/1000 [D loss: 0.3421957492828369 | D accuracy: 76.171875] [G loss: 0.48926490545272827]
[Repair Epoch 2/1000] [Repair loss: 0.8485788702964783]
[Quantitative Evaluation] [PSNR: 7.089226220080101, SSIM: 0.025214073617349547]
3/1000 [D loss: 0.4497961699962616 | D accuracy: 63.671875] [G loss: 0.4191523492336273]
[Repair Epoch 3/1000] [Repair loss: 0.4645453095436096]
[Quantitative Evaluation] [PSNR: 9.606522609089899, SSIM: 0.025954874737445646]
4/1000 [D loss: 0.5203969478607178 | D accuracy: 62.5] [G loss: 0.49807974696159363]
[Repair Epoch 4/1000] [Repair loss: 0.31018194556236267]
[Quantitative Evaluation] [PSNR: 11.326662264701964, SSIM: 0.09471518837969777]
5/1000 [D loss: 0.4473330378532421 | D accuracy: 72.65625] [G loss: 0.597773015499115]
[Repair Epoch 5/1

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

101/1000 [D loss: 10.364853858947754 | D accuracy: 62.890625] [G loss: 10.574719429016113]
[Repair Epoch 101/1000] [Repair loss: 0.21454648673534393]
[Quantitative Evaluation] [PSNR: 13.109581181931443, SSIM: 0.2486094026433976]
102/1000 [D loss: 9.884761810302734 | D accuracy: 62.890625] [G loss: 10.99178695678711]
[Repair Epoch 102/1000] [Repair loss: 0.21657459437847137]
[Quantitative Evaluation] [PSNR: 13.086447406801927, SSIM: 0.23410192898885926]
103/1000 [D loss: 11.997733116149902 | D accuracy: 60.9375] [G loss: 10.691733360290527]
[Repair Epoch 103/1000] [Repair loss: 0.21566252410411835]
[Quantitative Evaluation] [PSNR: 13.063734549181227, SSIM: 0.25918259531367066]
104/1000 [D loss: 9.39077091217041 | D accuracy: 60.15625] [G loss: 6.5106658935546875]
[Repair Epoch 104/1000] [Repair loss: 0.21381054818630219]
[Quantitative Evaluation] [PSNR: 13.160236100443065, SSIM: 0.2929012217392165]
105/1000 [D loss: 9.236194610595703 | D accuracy: 62.890625] [G loss: 7.946686744689941]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

201/1000 [D loss: 3.5742883682250977 | D accuracy: 51.953125] [G loss: 0.1854114979505539]
[Repair Epoch 201/1000] [Repair loss: 0.19039615988731384]
[Quantitative Evaluation] [PSNR: 13.74624893841564, SSIM: 0.3720697961265195]
202/1000 [D loss: 3.4623847007751465 | D accuracy: 50.78125] [G loss: 0.0829712525010109]
[Repair Epoch 202/1000] [Repair loss: 0.1941586583852768]
[Quantitative Evaluation] [PSNR: 13.602207465703607, SSIM: 0.34802247947404763]
203/1000 [D loss: 3.030073642730713 | D accuracy: 51.171875] [G loss: 0.11351388692855835]
[Repair Epoch 203/1000] [Repair loss: 0.1899615228176117]
[Quantitative Evaluation] [PSNR: 13.749282735936198, SSIM: 0.3645888747991851]
204/1000 [D loss: 2.5752434730529785 | D accuracy: 50.78125] [G loss: 0.10250328481197357]
[Repair Epoch 204/1000] [Repair loss: 0.1899847388267517]
[Quantitative Evaluation] [PSNR: 13.73549567034097, SSIM: 0.3616713981015536]
205/1000 [D loss: 2.2570629119873047 | D accuracy: 51.171875] [G loss: 0.1659432053565979

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

301/1000 [D loss: 0.3065880015492439 | D accuracy: 93.75] [G loss: 0.9531166553497314]
[Repair Epoch 301/1000] [Repair loss: 0.18209540843963623]
[Quantitative Evaluation] [PSNR: 13.965659735340882, SSIM: 0.387712592880981]
302/1000 [D loss: 0.26742764938148866 | D accuracy: 97.65625] [G loss: 1.0644633769989014]
[Repair Epoch 302/1000] [Repair loss: 0.18213307857513428]
[Quantitative Evaluation] [PSNR: 13.956675467049132, SSIM: 0.38287086776621493]
303/1000 [D loss: 0.21318830968812108 | D accuracy: 100.0] [G loss: 1.2697796821594238]
[Repair Epoch 303/1000] [Repair loss: 0.18204987049102783]
[Quantitative Evaluation] [PSNR: 13.969220061261336, SSIM: 0.3877661898460327]
304/1000 [D loss: 0.21077902242541313 | D accuracy: 100.0] [G loss: 1.2443783283233643]
[Repair Epoch 304/1000] [Repair loss: 0.18257053196430206]
[Quantitative Evaluation] [PSNR: 13.935834416028296, SSIM: 0.37859834582065915]
305/1000 [D loss: 0.20355496928095818 | D accuracy: 100.0] [G loss: 1.254408597946167]
[Repai

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

401/1000 [D loss: 0.2204515310440911 | D accuracy: 100.0] [G loss: 1.2047371864318848]
[Repair Epoch 401/1000] [Repair loss: 0.17817245423793793]
[Quantitative Evaluation] [PSNR: 14.078313349599947, SSIM: 0.4010253001394074]
402/1000 [D loss: 0.17096151727673714 | D accuracy: 100.0] [G loss: 1.41206693649292]
[Repair Epoch 402/1000] [Repair loss: 0.17843051254749298]
[Quantitative Evaluation] [PSNR: 14.06037219756206, SSIM: 0.39136836993986107]
403/1000 [D loss: 0.20862369798123837 | D accuracy: 100.0] [G loss: 1.2259019613265991]
[Repair Epoch 403/1000] [Repair loss: 0.17838990688323975]
[Quantitative Evaluation] [PSNR: 14.073499710501792, SSIM: 0.4045982375030262]
404/1000 [D loss: 0.18617017311044037 | D accuracy: 100.0] [G loss: 1.3064446449279785]
[Repair Epoch 404/1000] [Repair loss: 0.17986634373664856]
[Quantitative Evaluation] [PSNR: 14.005858366565585, SSIM: 0.3853238983395669]
405/1000 [D loss: 0.1650169426575303 | D accuracy: 100.0] [G loss: 1.4329055547714233]
[Repair Epoc

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

501/1000 [D loss: 0.9036667347682571 | D accuracy: 50.0] [G loss: 0.29444217681884766]
[Repair Epoch 501/1000] [Repair loss: 0.17537441849708557]
[Quantitative Evaluation] [PSNR: 14.143491688928462, SSIM: 0.40014872713329375]
502/1000 [D loss: 0.7012205719948111 | D accuracy: 51.5625] [G loss: 0.45941802859306335]
[Repair Epoch 502/1000] [Repair loss: 0.175773486495018]
[Quantitative Evaluation] [PSNR: 14.156396237468304, SSIM: 0.4092629643035534]
503/1000 [D loss: 0.47567943287489056 | D accuracy: 58.984375] [G loss: 0.730699896812439]
[Repair Epoch 503/1000] [Repair loss: 0.17758190631866455]
[Quantitative Evaluation] [PSNR: 14.06879239210111, SSIM: 0.3871773093513133]
504/1000 [D loss: 0.3027204495981124 | D accuracy: 88.28125] [G loss: 1.0620934963226318]
[Repair Epoch 504/1000] [Repair loss: 0.1755540668964386]
[Quantitative Evaluation] [PSNR: 14.163637116308665, SSIM: 0.4134517878459021]
505/1000 [D loss: 0.20009892776801053 | D accuracy: 99.609375] [G loss: 1.3772106170654297]
[

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

601/1000 [D loss: 0.11945528647629544 | D accuracy: 100.0] [G loss: 1.8197284936904907]
[Repair Epoch 601/1000] [Repair loss: 0.4285498261451721]
[Quantitative Evaluation] [PSNR: 9.81859944446774, SSIM: 0.15273171564365073]
602/1000 [D loss: 0.0921343705849722 | D accuracy: 100.0] [G loss: 2.048018217086792]
[Repair Epoch 602/1000] [Repair loss: 0.3082367479801178]
[Quantitative Evaluation] [PSNR: 11.385302073105366, SSIM: 0.19933665878324675]
603/1000 [D loss: 1.1272167265415192 | D accuracy: 47.265625] [G loss: 0.09005682170391083]
[Repair Epoch 603/1000] [Repair loss: 0.3382275104522705]
[Quantitative Evaluation] [PSNR: 10.924063442462238, SSIM: 0.1747993983494603]
604/1000 [D loss: 1.6167380809784655 | D accuracy: 50.390625] [G loss: 0.08843661844730377]
[Repair Epoch 604/1000] [Repair loss: 0.34189894795417786]
[Quantitative Evaluation] [PSNR: 10.874145763228725, SSIM: 0.17039835635955342]
605/1000 [D loss: 1.4635226728089608 | D accuracy: 50.0] [G loss: 0.16034452617168427]
[Repa

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

701/1000 [D loss: 0.7700525008463863 | D accuracy: 51.5625] [G loss: 0.4466492235660553]
[Repair Epoch 701/1000] [Repair loss: 0.33754774928092957]
[Quantitative Evaluation] [PSNR: 10.943959231640605, SSIM: 0.17938839489260172]
702/1000 [D loss: 0.47952523343765563 | D accuracy: 61.71875] [G loss: 0.7607197761535645]
[Repair Epoch 702/1000] [Repair loss: 0.3373506963253021]
[Quantitative Evaluation] [PSNR: 10.946822102667076, SSIM: 0.17963939074203752]
703/1000 [D loss: 0.28474306120551773 | D accuracy: 88.671875] [G loss: 1.2343679666519165]
[Repair Epoch 703/1000] [Repair loss: 0.3375486135482788]
[Quantitative Evaluation] [PSNR: 10.943904377668469, SSIM: 0.17964939857367254]
704/1000 [D loss: 0.15056506113614887 | D accuracy: 100.0] [G loss: 1.6669588088989258]
[Repair Epoch 704/1000] [Repair loss: 0.338137149810791]
[Quantitative Evaluation] [PSNR: 10.935319061168805, SSIM: 0.17904449319463878]
705/1000 [D loss: 0.14673009887337685 | D accuracy: 99.21875] [G loss: 1.625872850418090

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

801/1000 [D loss: 0.1314757843501866 | D accuracy: 100.0] [G loss: 1.69723641872406]
[Repair Epoch 801/1000] [Repair loss: 0.35061293840408325]
[Quantitative Evaluation] [PSNR: 10.757781176152177, SSIM: 0.17222750670968828]
802/1000 [D loss: 0.15385777130723 | D accuracy: 100.0] [G loss: 1.5394995212554932]
[Repair Epoch 802/1000] [Repair loss: 0.3504500985145569]
[Quantitative Evaluation] [PSNR: 10.759993059531915, SSIM: 0.1726903810794688]
803/1000 [D loss: 0.14548687916249037 | D accuracy: 100.0] [G loss: 1.670264720916748]
[Repair Epoch 803/1000] [Repair loss: 0.35083842277526855]
[Quantitative Evaluation] [PSNR: 10.754675489280306, SSIM: 0.1723113135773801]
804/1000 [D loss: 0.16377553343772888 | D accuracy: 98.4375] [G loss: 1.5384129285812378]
[Repair Epoch 804/1000] [Repair loss: 0.3507492244243622]
[Quantitative Evaluation] [PSNR: 10.755927045655646, SSIM: 0.17238830556093204]
805/1000 [D loss: 0.1503724167123437 | D accuracy: 99.21875] [G loss: 1.658695101737976]
[Repair Epoc

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

901/1000 [D loss: 0.09323233552277088 | D accuracy: 100.0] [G loss: 2.0335867404937744]
[Repair Epoch 901/1000] [Repair loss: 0.35658732056617737]
[Quantitative Evaluation] [PSNR: 10.676029529784694, SSIM: 0.1701459559634385]
902/1000 [D loss: 0.1844823844730854 | D accuracy: 99.21875] [G loss: 1.41977858543396]
[Repair Epoch 902/1000] [Repair loss: 0.35685306787490845]
[Quantitative Evaluation] [PSNR: 10.67245552991602, SSIM: 0.1701136808515592]
903/1000 [D loss: 0.1548828991362825 | D accuracy: 100.0] [G loss: 1.5863505601882935]
[Repair Epoch 903/1000] [Repair loss: 0.3567183315753937]
[Quantitative Evaluation] [PSNR: 10.674263033304856, SSIM: 0.17026057203179773]
904/1000 [D loss: 0.10574155968788546 | D accuracy: 100.0] [G loss: 1.936658263206482]
[Repair Epoch 904/1000] [Repair loss: 0.3565840423107147]
[Quantitative Evaluation] [PSNR: 10.676063401996911, SSIM: 0.17030876819883317]
905/1000 [D loss: 0.13178734853863716 | D accuracy: 98.4375] [G loss: 1.790856122970581]
[Repair Ep

### Repairing images with Diffusion

In [10]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, Activation, LeakyReLU, Concatenate
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras import backend as K
import os
from PIL import UnidentifiedImageError
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.preprocessing import image
from scipy.linalg import sqrtm
from skimage.transform import resize
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers





calculate_metrics is a function that takes two sets of images (original and repaired) and an optional parameter win_size.

It initializes empty lists to store PSNR and SSIM values for each image pair.

It then iterates through the images, calculates PSNR and SSIM using the specified window size (win_size), and appends the values to the respective lists.

Finally, it returns the mean PSNR and mean SSIM values across all images.


In [None]:

def calculate_metrics(original, repaired, win_size=5):
    psnr_values = []
    ssim_values = []

    for i in range(len(original)):
        psnr = peak_signal_noise_ratio(original[i], repaired[i], data_range=1.0)
        ssim = structural_similarity(original[i], repaired[i], win_size=win_size, multichannel=True, data_range=1.0)
        psnr_values.append(psnr)
        ssim_values.append(ssim)

    return np.mean(psnr_values), np.mean(ssim_values)




This code implements a diffusion model for repairing damaged images. The diffusion model is a generative model that simulates the process of diffusing noise into a damaged image to generate a repaired version.

Import necessary libraries and set parameters, including image dimensions, noise vector size, and training parameters.

img_rows, img_cols, channels: Dimensions of the images (128x128 pixels with 3 color channels).

z_dim: Size of the noise vector (100).

epochs: Number of training epochs (5000).

batch_size: Number of samples in each batch (64).

save_interval: Interval for saving generated images (every 100 epochs).

Learning rate and beta1 for the Adam optimizer in the repair model: learning_rate=0.0002, beta_1=0.5.

In [7]:

# Set the seed for reproducibility
np.random.seed(1000)

# Define image parameters
img_rows, img_cols, channels = 128, 128, 3
img_shape = (img_rows, img_cols, channels)
z_dim = 100  # Size of the noise vector
epochs = 1000
batch_size = 64
save_interval = 100  # Set the interval to save generated images


save_generated_images:

Visualizes and saves a comparison of undamaged, damaged, and repaired images.

Parameters:

epoch: Current training epoch.

repair_model: The diffusion model.

output_path: Path to save the generated images.

X_test_undamaged, X_test_damaged: Test sets for visualization.

r, c: Number of rows and columns in the visualization grid.

In [None]:

# Function to save generated images
def save_generated_images(epoch, repair_model, output_path, X_test_undamaged, X_test_damaged, r=5, c=3):
    # Resize images to match the repair model input shape
    X_test_undamaged_resized = np.array([resize(img, (128, 128, 3), mode='reflect', anti_aliasing=True) for img in X_test_undamaged])
    X_test_damaged_resized = np.array([resize(img, (128, 128, 3), mode='reflect', anti_aliasing=True) for img in X_test_damaged])

    # Generate repaired images for damaged parts
    repaired_imgs = repair_model.predict([X_test_undamaged_resized, X_test_damaged_resized])

    # Rescale images to 0-1
    X_test_undamaged_rescaled = 0.5 * (X_test_undamaged_resized + 1)
    X_test_damaged_rescaled = 0.5 * (X_test_damaged_resized + 1)
    repaired_imgs_rescaled = 0.5 * (repaired_imgs + 1)

    fig, axs = plt.subplots(r, c, figsize=(15, 15))
    cnt = 0
    for i in range(r):
        axs[i, 0].imshow(X_test_undamaged_rescaled[cnt])
        axs[i, 0].axis('off')
        axs[i, 0].set_title("Real (Undamaged)")

        axs[i, 1].imshow(X_test_damaged_rescaled[cnt])
        axs[i, 1].axis('off')
        axs[i, 1].set_title("Damaged")

        axs[i, 2].imshow(repaired_imgs_rescaled[cnt])
        axs[i, 2].axis('off')
        axs[i, 2].set_title("Repaired")

        cnt += 1

    fig.savefig(output_path + f"/repaired_image_epoch_{epoch}.png")
    plt.close()


Builds the diffusion model using TensorFlow Probability.

Input: Undamaged and damaged images.

Output: Repaired output based on the diffusion process.

Uses a simple Gaussian diffusion process.

diffusion_function:

Defines the diffusion process using TensorFlow Probability.

Parameters:

inputs: Tuple of undamaged and damaged images.

Returns the output after simulating the diffusion process.

normalize_images:

Normalizes images to the range [-1, 1].

In [None]:

# Function to normalize images to the range [-1, 1]
def normalize_images(images):
    return 2 * images - 1


load_and_resize_dataset:

Loads and resizes the dataset of damaged and undamaged images.

Parameters:

dataset_path: Path to the dataset.

target_size: Target size for resizing.

Returns the normalized damaged and undamaged images.

In [None]:

def load_and_resize_dataset(dataset_path, target_size=(128, 128), pixel_difference_threshold=50):
    damaged_images = []
    for filename in os.listdir(os.path.join(dataset_path, "damaged")):
        try:
            img = load_img(os.path.join(dataset_path, "damaged", filename), target_size=target_size)
            img = img_to_array(img) / 128.0
            damaged_images.append(img)
        except UnidentifiedImageError:
            print(f"Skipping non-image file: {filename}")

    undamaged_images = []
    for filename in os.listdir(os.path.join(dataset_path, "undamaged")):
        try:
            img = load_img(os.path.join(dataset_path, "undamaged", filename), target_size=target_size)
            img = img_to_array(img) / 128.0
            undamaged_images.append(img)
        except UnidentifiedImageError:
            print(f"Skipping non-image file: {filename}")

    damaged_images = normalize_images(np.array(damaged_images))
    undamaged_images = normalize_images(np.array(undamaged_images))

    return damaged_images, undamaged_images


Initialization:

          Sets the random seed for reproducibility.

          Defines image parameters, hyperparameters, and paths.

Load and Preprocess Dataset:

          Loads and resizes the training and test datasets.

Build and Compile Repair Model:

          Builds the diffusion repair model.
          Compiles the model with Mean Squared Error (MSE) loss and the Adam optimizer.

In [None]:

# Set the paths and parameters

dataset_path = r"D:\SagharGhaffari\Archive\dataset"
output_path = r"D:\SagharGhaffari\Archive\diffusionoutput"
test_dataset_path = r"D:\SagharGhaffari\Archive\testset"


# Load and preprocess the dataset with reduced resolution
damaged_data, undamaged_data = load_and_resize_dataset(dataset_path, target_size=(128, 128))


# Load and preprocess the test dataset
X_test_damaged_testset, X_test_undamaged_testset = load_and_resize_dataset(test_dataset_path, target_size=(128, 128))

# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))



# Resize the test images to the expected input shape of the repair model
undamaged_resized = np.array([resize(img, (128, 128, 3), mode='reflect', anti_aliasing=True) for img in undamaged_data])
damaged_resized = np.array([resize(img, (128, 128, 3), mode='reflect', anti_aliasing=True) for img in damaged_data])


Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 damaged_input (InputLayer)  [(None, 128, 128, 3)]        0         []                            
                                                                                                  
 tf.compat.v1.shape (TFOpLa  (4,)                         0         ['damaged_input[0][0]']       
 mbda)                                                                                            
                                                                                                  
 tf.cast (TFOpLambda)        (4,)                         0         ['tf.compat.v1.shape[0][0]']  
                                                                                                  
 tf.convert_to_tensor (TFOp  (4,)                         0         ['tf.cast[0][0]']       

define the encoder part of inpainting model. The encoder_input is the input to model, and must two convolutional layers (encoder_conv1 and encoder_conv2) that be processed this input.




add a new input for the mask (mask_input). The shape of mask_input is assumed to be the same as the input shape, except for the last dimension, which is set to 1. This assumes a grayscale mask. If your mask has three channels, adjust the shape accordingly. The masked_encoder_input is created by concatenating the original encoder_input and the mask_input along the channel axis using layers.Concatenate.


define the decoder. It takes the masked_encoder_input and performs a transposed convolution (Conv2DTranspose) to upsample the features. Finally, a convolutional layer produces the final output (decoder_output) with 3 channels using the sigmoid activation function.


The inpainting_model is created using the keras.Model class, taking both encoder_input and mask_input as inputs and producing decoder_output. The model summary is printed for inspection.

In [13]:

# Define the diffusion function using TensorFlow Probability
def diffusion_function(inputs):
    undamaged, damaged = inputs

    x = tf.concat([undamaged, damaged], axis=-1)
    x = BatchNormalization()(x)

    # Output layer
    diffusion_output = Dense(img_shape[-1], activation='tanh')(x)

    return diffusion_output

def build_diffusion_model():
    undamaged_input = Input(shape=img_shape, name='undamaged_input')
    damaged_input = Input(shape=img_shape, name='damaged_input')

    diffusion_output = diffusion_function([undamaged_input, damaged_input])

    repaired_output = undamaged_input + diffusion_output

    model = Model(inputs=[undamaged_input, damaged_input], outputs=repaired_output)
    model.summary()

    return model

# Clear TensorFlow session before training
tf.keras.backend.clear_session()

# Build and compile the repair model
repair_model = build_diffusion_model()
repair_model_optimizer = Adam(learning_rate=0.0002, beta_1=0.5, clipvalue=1.0)

repair_model.compile(loss='mse', optimizer=repair_model_optimizer)


Model: "inpainting_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 conv2d_6 (Conv2D)           (None, 128, 128, 64)      1792      
                                                                 
 conv2d_7 (Conv2D)           (None, 64, 64, 128)       73856     
                                                                 
 conv2d_transpose_2 (Conv2D  (None, 128, 128, 64)      73792     
 Transpose)                                                      
                                                                 
 conv2d_8 (Conv2D)           (None, 128, 128, 3)       1731      
                                                                 
Total params: 151171 (590.51 KB)
Trainable params: 151171 (590.51 KB)
Non-trainable params: 0 (0.00 Byte)
__________

Training Loop:

Loops through epochs.

Samples noise and generates repaired images using the repair model.

Trains the repair model on the batch of undamaged and damaged images.

Calculates quantitative metrics (PSNR and SSIM) between undamaged and repaired images.

Prints repair progress, metrics, and saves generated images at specified intervals.

In [None]:

# Training loop for the repair model
for epoch in range(epochs):
    # Sample noise and generate a batch of repaired images using diffusion
    repaired_imgs = repair_model.predict([X_test_undamaged_testset_resized, X_test_damaged_testset_resized])

    # Train the repair model
    K.clear_session()  # Clear TensorFlow session to release memory
    repair_loss = repair_model.train_on_batch([X_test_undamaged_testset_resized, X_test_damaged_testset_resized], X_test_undamaged_testset_resized)

    # Rescale images to 0-1
    X_test_undamaged_rescaled = 0.5 * (X_test_undamaged_testset_resized + 1)
    X_test_damaged_rescaled = 0.5 * (X_test_damaged_testset_resized + 1)
    repaired_imgs_rescaled = 0.5 * (repaired_imgs + 1)


    # Print repair progress and metrics
    print(f"[Repair Epoch {epoch}/{epochs}] [Repair loss: {repair_loss}]")

    # Calculate and print metrics
    psnr_value, ssim_value = calculate_metrics(X_test_undamaged_rescaled, repaired_imgs_rescaled, win_size=3)
    print(f"[Quantitative Evaluation] [PSNR: {psnr_value}, SSIM: {ssim_value}]")

    # Save generated images at a specified interval
    if epoch % save_interval == 0:
        save_generated_images(epoch, repair_model, output_path, X_test_undamaged_testset, X_test_damaged_testset)

[Repair Epoch 0/1000] [Repair loss: 0.43653714656829834]
[Quantitative Evaluation] [PSNR: 9.386246597028837, SSIM: 0.5821332004619051]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 1/1000] [Repair loss: 0.4360705316066742]
[Quantitative Evaluation] [PSNR: 9.396606056401476, SSIM: 0.5812405898659244]
[Repair Epoch 2/1000] [Repair loss: 0.4356038570404053]
[Quantitative Evaluation] [PSNR: 9.406835630671177, SSIM: 0.5803626849793121]
[Repair Epoch 3/1000] [Repair loss: 0.4351373314857483]
[Quantitative Evaluation] [PSNR: 9.416939108895974, SSIM: 0.5795007970384566]
[Repair Epoch 4/1000] [Repair loss: 0.43467092514038086]
[Quantitative Evaluation] [PSNR: 9.426921102584942, SSIM: 0.5786525917575577]
[Repair Epoch 5/1000] [Repair loss: 0.434204638004303]
[Quantitative Evaluation] [PSNR: 9.436785709767054, SSIM: 0.57781953882037]
[Repair Epoch 6/1000] [Repair loss: 0.4337383210659027]
[Quantitative Evaluation] [PSNR: 9.446536457868994, SSIM: 0.5770012398033937]
[Repair Epoch 7/1000] [Repair loss: 0.43327224254608154]
[Quantitative Evaluation] [PSNR: 9.456177302162601, SSIM: 0.5761975154620885]
[Repair Epoch 8/1000] [Repair loss: 0.4328063428401947]
[Quanti

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 101/1000] [Repair loss: 0.39016515016555786]
[Quantitative Evaluation] [PSNR: 10.134563763068376, SSIM: 0.5430049737950159]
[Repair Epoch 102/1000] [Repair loss: 0.3897145092487335]
[Quantitative Evaluation] [PSNR: 10.14076203550663, SSIM: 0.5429159759735686]
[Repair Epoch 103/1000] [Repair loss: 0.3892640471458435]
[Quantitative Evaluation] [PSNR: 10.14695353706252, SSIM: 0.5428301727711616]
[Repair Epoch 104/1000] [Repair loss: 0.3888138234615326]
[Quantitative Evaluation] [PSNR: 10.153138326489506, SSIM: 0.5427473806650087]
[Repair Epoch 105/1000] [Repair loss: 0.38836368918418884]
[Quantitative Evaluation] [PSNR: 10.159317096913856, SSIM: 0.5426676569426323]
[Repair Epoch 106/1000] [Repair loss: 0.38791364431381226]
[Quantitative Evaluation] [PSNR: 10.165489846706908, SSIM: 0.5425907780433084]
[Repair Epoch 107/1000] [Repair loss: 0.38746383786201477]
[Quantitative Evaluation] [PSNR: 10.1716567941173, SSIM: 0.5425168480286751]
[Repair Epoch 108/1000] [Repair loss: 0.3

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 201/1000] [Repair loss: 0.3455140292644501]
[Quantitative Evaluation] [PSNR: 10.748819844050459, SSIM: 0.5453935674961284]
[Repair Epoch 202/1000] [Repair loss: 0.34506818652153015]
[Quantitative Evaluation] [PSNR: 10.755123178473978, SSIM: 0.5455084284839233]
[Repair Epoch 203/1000] [Repair loss: 0.34462234377861023]
[Quantitative Evaluation] [PSNR: 10.761433295054522, SSIM: 0.545624771463017]
[Repair Epoch 204/1000] [Repair loss: 0.34417635202407837]
[Quantitative Evaluation] [PSNR: 10.767750053108873, SSIM: 0.5457425308403238]
[Repair Epoch 205/1000] [Repair loss: 0.3437304198741913]
[Quantitative Evaluation] [PSNR: 10.774073209103024, SSIM: 0.5458617875644705]
[Repair Epoch 206/1000] [Repair loss: 0.3432844281196594]
[Quantitative Evaluation] [PSNR: 10.780403740641317, SSIM: 0.5459826146267476]
[Repair Epoch 207/1000] [Repair loss: 0.3428383469581604]
[Quantitative Evaluation] [PSNR: 10.786740677027526, SSIM: 0.5461049000361337]
[Repair Epoch 208/1000] [Repair loss: 0

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 301/1000] [Repair loss: 0.3006264269351959]
[Quantitative Evaluation] [PSNR: 11.4217016958464, SSIM: 0.5637472565105833]
[Repair Epoch 302/1000] [Repair loss: 0.3001749515533447]
[Quantitative Evaluation] [PSNR: 11.428945294213118, SSIM: 0.5639955276084313]
[Repair Epoch 303/1000] [Repair loss: 0.29972338676452637]
[Quantitative Evaluation] [PSNR: 11.43620012036068, SSIM: 0.564244989524722]
[Repair Epoch 304/1000] [Repair loss: 0.2992718815803528]
[Quantitative Evaluation] [PSNR: 11.443466463174504, SSIM: 0.564495562742913]
[Repair Epoch 305/1000] [Repair loss: 0.2988203763961792]
[Quantitative Evaluation] [PSNR: 11.450744228320811, SSIM: 0.5647473218554488]
[Repair Epoch 306/1000] [Repair loss: 0.29836878180503845]
[Quantitative Evaluation] [PSNR: 11.458033569634614, SSIM: 0.5650002592898317]
[Repair Epoch 307/1000] [Repair loss: 0.2979171872138977]
[Quantitative Evaluation] [PSNR: 11.46533422418768, SSIM: 0.5652543694823159]
[Repair Epoch 308/1000] [Repair loss: 0.29746

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 401/1000] [Repair loss: 0.2556001842021942]
[Quantitative Evaluation] [PSNR: 12.206508745684905, SSIM: 0.5937110527310494]
[Repair Epoch 402/1000] [Repair loss: 0.2551540434360504]
[Quantitative Evaluation] [PSNR: 12.215004582928856, SSIM: 0.59405340464022]
[Repair Epoch 403/1000] [Repair loss: 0.25470811128616333]
[Quantitative Evaluation] [PSNR: 12.223514095981145, SSIM: 0.5943964037275681]
[Repair Epoch 404/1000] [Repair loss: 0.2542622685432434]
[Quantitative Evaluation] [PSNR: 12.23203680502869, SSIM: 0.5947399793360229]
[Repair Epoch 405/1000] [Repair loss: 0.2538166344165802]
[Quantitative Evaluation] [PSNR: 12.240572702361742, SSIM: 0.5950841770501464]
[Repair Epoch 406/1000] [Repair loss: 0.25337114930152893]
[Quantitative Evaluation] [PSNR: 12.249121418056983, SSIM: 0.5954289837355217]
[Repair Epoch 407/1000] [Repair loss: 0.2529257833957672]
[Quantitative Evaluation] [PSNR: 12.257683574177051, SSIM: 0.595774362432064]
[Repair Epoch 408/1000] [Repair loss: 0.252

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 501/1000] [Repair loss: 0.21226775646209717]
[Quantitative Evaluation] [PSNR: 13.119597317724535, SSIM: 0.6300981122629111]
[Repair Epoch 502/1000] [Repair loss: 0.21185392141342163]
[Quantitative Evaluation] [PSNR: 13.129326455443609, SSIM: 0.630478558948878]
[Repair Epoch 503/1000] [Repair loss: 0.2114405781030655]
[Quantitative Evaluation] [PSNR: 13.139065702712628, SSIM: 0.6308593385852691]
[Repair Epoch 504/1000] [Repair loss: 0.21102775633335114]
[Quantitative Evaluation] [PSNR: 13.148815601628934, SSIM: 0.6312404812103309]
[Repair Epoch 505/1000] [Repair loss: 0.21061548590660095]
[Quantitative Evaluation] [PSNR: 13.158575313868955, SSIM: 0.6316218284672599]
[Repair Epoch 506/1000] [Repair loss: 0.21020373702049255]
[Quantitative Evaluation] [PSNR: 13.168345976349663, SSIM: 0.6320035586023941]
[Repair Epoch 507/1000] [Repair loss: 0.2097925841808319]
[Quantitative Evaluation] [PSNR: 13.178125774151761, SSIM: 0.6323856111175842]
[Repair Epoch 508/1000] [Repair loss:

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 601/1000] [Repair loss: 0.17413027584552765]
[Quantitative Evaluation] [PSNR: 14.130982475127876, SSIM: 0.6700151651344728]
[Repair Epoch 602/1000] [Repair loss: 0.17378787696361542]
[Quantitative Evaluation] [PSNR: 14.141337986256088, SSIM: 0.6704356727159031]
[Repair Epoch 603/1000] [Repair loss: 0.17344631254673004]
[Quantitative Evaluation] [PSNR: 14.151694367463138, SSIM: 0.6708566224185855]
[Repair Epoch 604/1000] [Repair loss: 0.17310568690299988]
[Quantitative Evaluation] [PSNR: 14.162051904892632, SSIM: 0.6712780114482241]
[Repair Epoch 605/1000] [Repair loss: 0.1727658361196518]
[Quantitative Evaluation] [PSNR: 14.17241091636888, SSIM: 0.6716997688574771]
[Repair Epoch 606/1000] [Repair loss: 0.17242689430713654]
[Quantitative Evaluation] [PSNR: 14.18277038711514, SSIM: 0.6721219836757171]
[Repair Epoch 607/1000] [Repair loss: 0.1720888614654541]
[Quantitative Evaluation] [PSNR: 14.193131212620631, SSIM: 0.6725444925681925]
[Repair Epoch 608/1000] [Repair loss: 

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 701/1000] [Repair loss: 0.1443941593170166]
[Quantitative Evaluation] [PSNR: 15.153003176849563, SSIM: 0.7135147911696939]
[Repair Epoch 702/1000] [Repair loss: 0.14414353668689728]
[Quantitative Evaluation] [PSNR: 15.16287193738091, SSIM: 0.7139560858535824]
[Repair Epoch 703/1000] [Repair loss: 0.14389385282993317]
[Quantitative Evaluation] [PSNR: 15.172729882629035, SSIM: 0.714397342352567]
[Repair Epoch 704/1000] [Repair loss: 0.14364509284496307]
[Quantitative Evaluation] [PSNR: 15.182575976587149, SSIM: 0.7148385419155378]
[Repair Epoch 705/1000] [Repair loss: 0.14339719712734222]
[Quantitative Evaluation] [PSNR: 15.192411299715976, SSIM: 0.7152796329945744]
[Repair Epoch 706/1000] [Repair loss: 0.14315025508403778]
[Quantitative Evaluation] [PSNR: 15.202234620433465, SSIM: 0.7157207693461772]
[Repair Epoch 707/1000] [Repair loss: 0.14290420711040497]
[Quantitative Evaluation] [PSNR: 15.212046111738083, SSIM: 0.7161616509254307]
[Repair Epoch 708/1000] [Repair loss:

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 801/1000] [Repair loss: 0.12352073192596436]
[Quantitative Evaluation] [PSNR: 16.068969913936044, SSIM: 0.7563542536543219]
[Repair Epoch 802/1000] [Repair loss: 0.12335033714771271]
[Quantitative Evaluation] [PSNR: 16.07728640607197, SSIM: 0.7567603572321361]
[Repair Epoch 803/1000] [Repair loss: 0.12318062782287598]
[Quantitative Evaluation] [PSNR: 16.085584281148083, SSIM: 0.7571664560969472]
[Repair Epoch 804/1000] [Repair loss: 0.1230115219950676]
[Quantitative Evaluation] [PSNR: 16.093864124046725, SSIM: 0.7575713645895462]
[Repair Epoch 805/1000] [Repair loss: 0.12284307926893234]
[Quantitative Evaluation] [PSNR: 16.10212544130188, SSIM: 0.757975992258623]
[Repair Epoch 806/1000] [Repair loss: 0.12267529219388962]
[Quantitative Evaluation] [PSNR: 16.110368673481922, SSIM: 0.7583799933788222]
[Repair Epoch 807/1000] [Repair loss: 0.12250816822052002]
[Quantitative Evaluation] [PSNR: 16.118593468091653, SSIM: 0.7587833771733747]
[Repair Epoch 808/1000] [Repair loss: 

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

[Repair Epoch 901/1000] [Repair loss: 0.10922405123710632]
[Quantitative Evaluation] [PSNR: 16.811198661752304, SSIM: 0.7935974309880751]
[Repair Epoch 902/1000] [Repair loss: 0.10910405963659286]
[Quantitative Evaluation] [PSNR: 16.81774901928493, SSIM: 0.7939315422211967]
[Repair Epoch 903/1000] [Repair loss: 0.10898442566394806]
[Quantitative Evaluation] [PSNR: 16.824284122439483, SSIM: 0.794264953412753]
[Repair Epoch 904/1000] [Repair loss: 0.1088651493191719]
[Quantitative Evaluation] [PSNR: 16.830802464754598, SSIM: 0.7945977667115897]
[Repair Epoch 905/1000] [Repair loss: 0.1087462455034256]
[Quantitative Evaluation] [PSNR: 16.837306026621114, SSIM: 0.7949296478956042]
[Repair Epoch 906/1000] [Repair loss: 0.10862767696380615]
[Quantitative Evaluation] [PSNR: 16.843793849499722, SSIM: 0.7952602211132714]
[Repair Epoch 907/1000] [Repair loss: 0.10850948095321655]
[Quantitative Evaluation] [PSNR: 16.850265941142634, SSIM: 0.7955909739237009]
[Repair Epoch 908/1000] [Repair loss: 