In [1]:
import os
import sys
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ['XLA_FLAGS'] = '--xla_gpu_strict_conv_algorithm_picker=false'

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import ops
from keras import layers


from PIL import Image
import matplotlib.pyplot as plt

from keras import utils, models
from keras.layers import Rescaling
from keras.applications.vgg19 import VGG19
from keras.applications.vgg19 import preprocess_input
from keras.callbacks import ModelCheckpoint, CSVLogger

# Report the library and interpreter versions in use, then the GPU count.
print(f'keras:  {keras.__version__}')
print(f'tensorflow:  {tf.__version__}')
print(f'python:  {sys.version}')
print(tf.version.GIT_VERSION, tf.version.VERSION)

print(f"Num GPUs Available:  {len(tf.config.list_physical_devices('GPU'))}")

2025-03-18 20:26:52.117436: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-18 20:26:52.134529: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-18 20:26:52.139871: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-18 20:26:52.153306: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


keras:  3.6.0
tensorflow:  2.17.0
python:  3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
unknown 2.17.0
Num GPUs Available:  3


In [2]:
## CONFIGURATION

# Training hyper-parameters
epochs = 100
batch_size = 192
latent_dim = 512

# Image geometry in pixels
height = 128
width = 128

# Number of training images (used below to annotate the result figure)
no_images = 192469

# Image used for the qualitative reconstruction check
test_image_dir = './Dataset/test'
test_image_name = '000206.jpg'


# Distribution strategy used when building and training the model
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices: {strategy.num_replicas_in_sync}')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2')


2025-03-18 20:26:54.753770: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43611 MB memory:  -> device: 0, name: NVIDIA A40, pci bus id: 0000:81:00.0, compute capability: 8.6
2025-03-18 20:26:54.755538: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 43611 MB memory:  -> device: 1, name: NVIDIA A40, pci bus id: 0000:c1:00.0, compute capability: 8.6
2025-03-18 20:26:54.757260: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 43611 MB memory:  -> device: 2, name: NVIDIA A40, pci bus id: 0000:e1:00.0, compute capability: 8.6


Number of devices: 3


In [3]:
class Sampling(layers.Layer):
    """Reparameterization layer: samples z from (z_mean, z_log_var).

    The layer receives ``[z_mean, z_log_var]`` and returns
    ``z_mean + exp(0.5 * z_log_var) * noise`` where ``noise`` is drawn from
    a standard normal distribution with the same (batch, dim) shape.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Seeded generator used for every noise draw of this layer.
        self.seed_generator = keras.random.SeedGenerator(280602)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        mean_shape = ops.shape(z_mean)
        noise = keras.random.normal(
            shape=(mean_shape[0], mean_shape[1]),
            seed=self.seed_generator,
        )
        # exp(0.5 * log_var) is the standard deviation.
        std = ops.exp(0.5 * z_log_var)
        return z_mean + std * noise

In [4]:
def create_encoder(input_shape=(128, 128, 3), latent_dim=256):
    """Build the convolutional encoder.

    Args:
        input_shape: shape of one input image, without the batch axis.
        latent_dim: size of the latent vector.

    Returns:
        A ``keras.Model`` named ``improved_encoder`` whose outputs are
        ``[z_mean, z_log_var, z]``.
    """
    inputs = keras.Input(shape=input_shape)

    # Four stride-2 stages: Conv -> BatchNorm -> Dropout, doubling the filters.
    features = inputs
    for n_filters in (64, 128, 256, 512):
        features = layers.Conv2D(
            n_filters,
            (4, 4),
            strides=2,
            padding='same',
            activation='leaky_relu',
        )(features)
        features = layers.BatchNormalization()(features)
        features = layers.Dropout(0.2)(features)

    features = layers.Flatten()(features)

    # Dense bottleneck ahead of the latent heads.
    features = layers.Dense(512, activation='leaky_relu')(features)
    features = layers.BatchNormalization()(features)
    features = layers.Dropout(0.2)(features)

    # Two parallel heads parameterize the latent distribution.
    z_mean = layers.Dense(latent_dim, name='z_mean')(features)
    z_log_var = layers.Dense(latent_dim, name='z_log_var')(features)
    z = Sampling()([z_mean, z_log_var])

    return keras.Model(inputs, [z_mean, z_log_var, z], name='improved_encoder')

In [5]:
def create_decoder(latent_dim=256):
    """Build the transposed-convolution decoder.

    Args:
        latent_dim: size of the latent vector fed to the decoder.

    Returns:
        A ``keras.Model`` named ``improved_decoder`` mapping a latent vector
        to a 3-channel image with sigmoid-activated values.
    """
    latent_inputs = keras.Input(shape=(latent_dim,))

    # Project the latent vector onto an 8x8x512 feature map.
    features = layers.Dense(8 * 8 * 512, activation='leaky_relu')(latent_inputs)
    features = layers.Reshape((8, 8, 512))(features)

    # Four stride-2 upsampling stages, each followed by batch normalization.
    for n_filters in (256, 128, 64, 32):
        features = layers.Conv2DTranspose(
            n_filters,
            (4, 4),
            strides=2,
            padding='same',
            activation='leaky_relu',
        )(features)
        features = layers.BatchNormalization()(features)

    # Final stride-1 layer produces the 3 output channels.
    decoder_outputs = layers.Conv2DTranspose(
        3, (4, 4), padding='same', activation='sigmoid'
    )(features)

    return keras.Model(latent_inputs, decoder_outputs, name='improved_decoder')

In [6]:
class VAE(keras.Model):
    """Variational autoencoder trained with three loss terms.

    The total loss is the sum of a pixel-wise binary cross-entropy
    reconstruction loss, a perceptual loss computed on frozen VGG19 features,
    and the KL divergence of the latent distribution.

    Args:
        encoder: model returning ``[z_mean, z_log_var, z]`` for a batch of images.
        decoder: model mapping latent vectors ``z`` back to images.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

        # VGG19 with pre-trained ImageNet weights, used only as a feature extractor.
        vgg = VGG19(include_top=False, weights='imagenet', input_shape=(None, None, 3))

        # Intermediate layers whose activations are compared, shallow to deep.
        self.feature_layers = [
            'block1_conv2',
            'block2_conv2',
            'block3_conv2',
            'block4_conv2',
        ]

        # Model that returns the activations of the layers listed above.
        self.feature_extractor = keras.Model(
            inputs=vgg.input,
            outputs=[vgg.get_layer(name).output for name in self.feature_layers],
        )

        # The VGG19 weights must not be updated by the VAE optimizer.
        self.feature_extractor.trainable = False

        # Running means of each loss term, reported during training.
        self.total_loss_tracker = keras.metrics.Mean(name='total_loss')
        self.reconstruction_loss_tracker = keras.metrics.Mean(name='reconstruction_loss')
        self.perceptual_loss_tracker = keras.metrics.Mean(name='perceptual_loss')
        self.kl_loss_tracker = keras.metrics.Mean(name='kl_loss')

    def call(self, inputs):
        """Encode ``inputs``, then decode the sampled latent vector."""
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return reconstruction

    @property
    def metrics(self):
        """Loss trackers exposed to Keras as this model's metrics."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.perceptual_loss_tracker,
            self.kl_loss_tracker,
        ]

    def compute_perceptual_loss(self, original, reconstructed):
        """Mean squared error between VGG19 features of the two image batches.

        Both batches are multiplied by 255 before ``preprocess_input`` is
        applied, so they are expected in the [0, 1] range (the decoder ends
        in a sigmoid and the training data is rescaled by 1/255).

        Returns:
            Scalar loss averaged over the selected feature layers.
        """
        original_processed = preprocess_input(original * 255.0)
        reconstructed_processed = preprocess_input(reconstructed * 255.0)

        # The extractor is frozen; always run it in inference mode.
        original_features = self.feature_extractor(original_processed, training=False)
        reconstructed_features = self.feature_extractor(
            reconstructed_processed, training=False
        )

        perceptual_loss = 0
        for orig_feat, recon_feat in zip(original_features, reconstructed_features):
            perceptual_loss += ops.mean(ops.square(orig_feat - recon_feat))

        # Average over the feature layers.
        perceptual_loss /= len(self.feature_layers)

        return perceptual_loss

    def compute_reconstruction_loss(self, original, reconstructed):
        """Binary cross-entropy summed over the spatial axes, averaged over the batch."""
        per_pixel = keras.losses.binary_crossentropy(original, reconstructed)
        return ops.mean(ops.sum(per_pixel, axis=(1, 2)))

    def train_step(self, data):
        """Run one optimization step on a batch and return the running losses.

        Args:
            data: a batch of images, or a tuple/list whose first element is
                the image batch (any further elements are ignored).

        Returns:
            Dict with the running means of the total, perceptual, KL and
            reconstruction losses.
        """
        # Accept (x,) / (x, y) batches as well as bare image tensors.
        if isinstance(data, (tuple, list)):
            data = data[0]

        with tf.GradientTape() as tape:
            # The sub-models are called directly rather than through
            # self(..., training=True), so the flag must be passed explicitly.
            # Otherwise Dropout is skipped and BatchNormalization uses its
            # moving statistics while training.
            z_mean, z_log_var, z = self.encoder(data, training=True)
            reconstruction = self.decoder(z, training=True)

            # Closed-form KL divergence to a standard normal prior,
            # summed over latent dimensions and averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))

            perceptual_loss = self.compute_perceptual_loss(data, reconstruction)
            reconstruction_loss = self.compute_reconstruction_loss(data, reconstruction)

            # Unweighted sum of the three terms.
            total_loss = perceptual_loss + kl_loss + reconstruction_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.perceptual_loss_tracker.update_state(perceptual_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)

        # Key names are kept unchanged: they become the names shown in the
        # training logs.
        return {
            "loss": self.total_loss_tracker.result(),
            "perceptual_loss": self.perceptual_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "recreation_loss": self.reconstruction_loss_tracker.result()
        }

In [7]:
# Unlabelled image dataset: only the images are yielded.
# label_mode is None to match labels=None, and the image size comes from the
# configuration cell instead of being repeated here.
train_data = utils.image_dataset_from_directory(
    './Dataset/train',
    labels=None,
    label_mode=None,
    batch_size=batch_size,
    image_size=(height, width),
    shuffle=True,
    seed=123,
    interpolation='bilinear',
)

# Normalize the images by dividing by 255.0, and prefetch so that input
# preparation overlaps with training.
normalization_layer = Rescaling(1./255)
train_data = (
    train_data
    .map(lambda x: normalization_layer(x), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

Found 192469 files.


In [None]:


class CustomModelCheckpoint(tf.keras.callbacks.Callback):
    """Callback that saves the model weights every ``save_freq`` epochs.

    Args:
        save_freq: number of epochs between two saves.
        filepath: path template for the weights file; ``{epoch}`` is replaced
            by the 1-based epoch number.
    """

    def __init__(self, save_freq, filepath):
        super().__init__()
        self.save_freq = save_freq
        self.filepath = filepath

    def on_epoch_end(self, epoch, logs=None):
        """Save the weights when the 1-based epoch is a multiple of ``save_freq``."""
        if (epoch + 1) % self.save_freq != 0:
            return

        target = self.filepath.format(epoch=epoch + 1)

        # Create the target directory first so that a missing folder does not
        # abort training at the first checkpoint.
        parent_dir = os.path.dirname(target)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        self.model.save_weights(target)

# Build, compile and (optionally) train the VAE under the distribution strategy.
with strategy.scope():
    encoder = create_encoder(latent_dim=latent_dim)
    decoder = create_decoder(latent_dim=latent_dim)

    vae = VAE(encoder, decoder)
    vae.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4))

    vae.encoder.summary()
    vae.decoder.summary()
    vae.summary()

    # load=True restores previously saved weights; load=False trains from scratch.
    load = False
    run = True
    if run:
        if load:
            vae.load_weights('results_encode_decode/1/vae_weights.weights.h5')
        else:
            # Keras 3 only accepts weight files ending in `.weights.h5`; a
            # plain `.h5` suffix raises at the first checkpoint.
            checkpoint_path = 'results_encode_decode/last/vae_weights_epoch_{epoch}.weights.h5'
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

            checkpoint_callback = CustomModelCheckpoint(save_freq=10, filepath=checkpoint_path)
            csv_logger = CSVLogger('training_log.csv', append=True)

            vae.fit(train_data, epochs=epochs, callbacks=[checkpoint_callback, csv_logger])


In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt



# Load the test image as an RGB float array scaled to [0, 1].
# convert('RGB') guarantees three channels for grayscale or RGBA files, and the
# context manager closes the file handle once the pixels are copied.
with Image.open(os.path.join(test_image_dir, test_image_name)) as test_image:
    # PIL's resize expects (width, height).
    img = np.array(test_image.convert('RGB').resize((width, height)))
img = img.astype('float32') / 255.

# Encode a batch of one image and decode the sampled latent vector.
z_mean, z_log_var, z = encoder.predict(np.expand_dims(img, axis=0))
reconstructed_img = decoder.predict(z)

# Original and reconstruction side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(img)
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(reconstructed_img[0])
axes[1].set_title('Reconstructed Image')
axes[1].axis('off')

# Add information about the run to the plot
info_text = f"Latent Space: {latent_dim} Epochs:{epochs}\nImage Size: {height}x{width}\nNumber of Images for Training: {no_images}\n"
fig.text(0.5, 0.01, info_text, ha='center',fontsize=15, bbox=dict(facecolor='red', alpha=0.5))

SAVE = True
if SAVE:
    # Each run is stored in its own numbered sub-folder; the last used number
    # is persisted in result_counter.txt.
    output_dir = 'results_encode_decode'
    os.makedirs(output_dir, exist_ok=True)
    result_counter_path = os.path.join(output_dir, 'result_counter.txt')

    if os.path.exists(result_counter_path):
        with open(result_counter_path, 'r') as file:
            result_no = int(file.read().strip()) + 1
    else:
        result_no = 0

    with open(result_counter_path, 'w') as file:
        file.write(str(result_no))

    result_dir = os.path.join(output_dir, str(result_no))
    os.makedirs(result_dir, exist_ok=True)

    # Save weights in result dir
    vae.save_weights(os.path.join(result_dir, 'vae_weights.weights.h5'))
    vae.encoder.save_weights(os.path.join(result_dir, 'encoder_weights.weights.h5'))
    vae.decoder.save_weights(os.path.join(result_dir, 'decoder_weights.weights.h5'))

    # Save the figure through its own handle rather than pyplot's current figure.
    output_path = os.path.join(result_dir, 'result.jpg')
    fig.savefig(output_path, format='jpg')

plt.show()