In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
from PIL import Image
import time
import logging

# Send INFO-and-above records to the notebook output, then create the logger
# that the rest of the notebook writes progress messages to.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

In [2]:
import tensorflow as tf
from tensorflow.keras import layers

def mish(x):
    """Mish activation, ``x * tanh(softplus(x))``, applied element-wise."""
    softplus_x = tf.math.softplus(x)
    return x * tf.math.tanh(softplus_x)

In [3]:
class InstanceNormalization(layers.Layer):
    """Instance normalization with a learnable per-channel scale and offset.

    Every sample is normalized on its own: the mean and variance are taken
    over axes 1 and 2 of the input and the normalized values are multiplied
    by ``scale`` and shifted by ``offset`` (one value per entry of the last
    axis).

    Args:
        epsilon: Constant added to the variance before the square root to
            avoid division by zero.
        **kwargs: Standard ``Layer`` arguments such as ``name`` or ``dtype``.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        # Forward the base-class kwargs so the layer can be named and can be
        # re-created from its config.
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create one scale and one offset weight per channel (last axis)."""
        depth = input_shape[-1]
        self.scale = self.add_weight(
            shape=[depth],
            initializer='ones',
            trainable=True,
            name='scale'
        )
        self.offset = self.add_weight(
            shape=[depth],
            initializer='zeros',
            trainable=True,
            name='offset'
        )

    def call(self, x):
        """Normalize ``x`` over axes 1 and 2, then apply scale and offset."""
        mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        centered = x - mean

        # Biased variance: E[(x - E[x])^2]
        variance = tf.reduce_mean(tf.square(centered), axis=[1, 2], keepdims=True)
        stddev = tf.sqrt(variance + self.epsilon)
        normalized = centered / stddev

        # scale/offset have shape [depth] and broadcast over the leading axes.
        return self.scale * normalized + self.offset

    def compute_output_shape(self, input_shape):
        """The layer is shape-preserving."""
        return input_shape

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({'epsilon': self.epsilon})
        return config


In [4]:
class SpectralNormalization(layers.Layer):
    """Wrapper that rescales a layer's kernel by its estimated spectral norm.

    The largest singular value of the kernel (flattened to a 2-D matrix whose
    columns are the output channels) is estimated with power iteration, and
    the wrapped layer's kernel is divided by that estimate in place before
    the layer is applied.

    The wrapped layer must expose a ``kernel`` weight once built.

    Args:
        layer: The layer whose ``kernel`` is normalized.
        power_iterations: Number of power-iteration steps per call (>= 1).
        **kwargs: Standard ``Layer`` arguments.
    """

    def __init__(self, layer, power_iterations=1, **kwargs):
        super().__init__(**kwargs)
        if power_iterations < 1:
            raise ValueError(
                f"power_iterations must be >= 1, got {power_iterations}"
            )
        self.layer = layer
        self.power_iterations = power_iterations

    def build(self, input_shape):
        """Build the wrapped layer and create the persistent ``u`` vector."""
        self.layer.build(input_shape)
        self.u = self.add_weight(
            shape=(1, self.layer.kernel.shape[-1]),
            initializer="random_normal",
            trainable=False,
            name="spectral_u"
        )
        super().build(input_shape)

    def _normalize_kernel(self):
        """Run power iteration and divide the kernel by the estimated norm."""
        kernel = tf.convert_to_tensor(self.layer.kernel)
        w = tf.reshape(kernel, [-1, kernel.shape[-1]])   # (N, C_out)
        u = tf.convert_to_tensor(self.u)                 # (1, C_out)

        for _ in range(self.power_iterations):
            v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True))  # (1, N)
            u = tf.math.l2_normalize(tf.matmul(v, w))                    # (1, C_out)

        # sigma = v W u^T, a (1, 1) tensor that broadcasts against the kernel.
        sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)

        self.u.assign(u)
        self.layer.kernel.assign(kernel / sigma)

    def call(self, inputs, training=None):
        """Apply the wrapped layer, normalizing its kernel first.

        Normalization is skipped only when ``training`` is explicitly
        ``False`` so that inference does not modify the stored weights.
        """
        if training is None or training:
            self._normalize_kernel()
        return self.layer(inputs, training=training)

    def compute_output_shape(self, input_shape):
        """Delegate to the wrapped layer."""
        return self.layer.compute_output_shape(input_shape)


In [5]:
class DehazeGAN:
    """Paired image-to-image GAN mapping hazy images to dehazed images.

    The generator is a convolutional encoder, nine residual blocks and a
    transposed-convolution decoder, using instance normalization and Mish
    activations and ending in ``tanh``. The discriminator is a stack of
    convolutions with LeakyReLU that outputs a map of logits.

    Args:
        image_size: ``(height, width)`` of the 3-channel images both networks
            take as input.
        pixel_loss_weight: Multiplier for the mean-absolute-error term of the
            generator loss.
        real_label: Discriminator target for real images.
        fake_label: Discriminator target for generated images.
            NOTE(review): a non-zero fake target smooths both sides; pass
            ``0.0`` for one-sided label smoothing.
    """

    # The generator halves the spatial size twice and doubles it twice, so only
    # sizes divisible by 4 come back out at the input resolution.
    _SIZE_MULTIPLE = 4

    def __init__(self, image_size=(256, 256), pixel_loss_weight=100.0,
                 real_label=0.9, fake_label=0.1):
        if any(dim is not None and dim % self._SIZE_MULTIPLE for dim in image_size):
            logger.warning(
                "image_size %s is not a multiple of %d: the generator output "
                "will differ in size from its input and the pixel loss will "
                "fail on the shape mismatch.",
                tuple(image_size), self._SIZE_MULTIPLE,
            )

        self.image_size = image_size
        self.pixel_loss_weight = pixel_loss_weight
        self.real_label = real_label
        self.fake_label = fake_label

        self.generator = self._build_generator()
        self.discriminator = self._build_discriminator()

        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=1e-4,
            decay_steps=100000,
            decay_rate=0.96,
            staircase=True
        )

        # Initialize optimizers
        self.generator_optimizer = keras.optimizers.Adam(learning_rate=lr_schedule, beta_1=0.5)
        self.discriminator_optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=True)

        # Initialize loss functions (the discriminator outputs raw logits)
        self.gan_loss = keras.losses.BinaryCrossentropy(from_logits=True)
        self.pixel_loss = keras.losses.MeanAbsoluteError()

        # Initialize metrics
        self.gen_loss_tracker = tf.keras.metrics.Mean(name='generator_loss')
        self.disc_loss_tracker = tf.keras.metrics.Mean(name='discriminator_loss')

    def _build_generator(self):
        """Build the generator network with instance normalization."""
        def norm_act(x):
            # Instance normalization followed by the Mish activation.
            x = InstanceNormalization()(x)
            return layers.Activation(mish)(x)

        def residual_block(x, filters, kernel_size=3):
            # conv -> norm -> mish -> conv -> norm, added back onto the input.
            shortcut = x
            x = layers.Conv2D(filters, kernel_size, padding='same')(x)
            x = norm_act(x)
            x = layers.Conv2D(filters, kernel_size, padding='same')(x)
            x = InstanceNormalization()(x)
            return layers.Add()([shortcut, x])

        inputs = layers.Input(shape=(*self.image_size, 3))

        # Initial convolution
        x = layers.Conv2D(64, 7, padding='same')(inputs)
        x = norm_act(x)

        # Encoder: two stride-2 convolutions
        for filters in (128, 256):
            x = layers.Conv2D(filters, 3, strides=2, padding='same')(x)
            x = norm_act(x)

        # Residual blocks
        for _ in range(9):
            x = residual_block(x, 256)

        # Decoder: two stride-2 transposed convolutions
        for filters in (128, 64):
            x = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(x)
            x = norm_act(x)

        # Output layer; tanh keeps values in [-1, 1]
        outputs = layers.Conv2D(3, 7, padding='same', activation='tanh')(x)

        return keras.Model(inputs, outputs, name='generator')

    def _build_discriminator(self):
        """Build the discriminator network."""
        def discriminator_block(x, filters, strides=2):
            x = layers.Conv2D(filters, 4, strides=strides, padding='same',
                              kernel_initializer='orthogonal')(x)
            return layers.LeakyReLU(0.2)(x)

        inputs = layers.Input(shape=(*self.image_size, 3))

        x = inputs
        for filters, strides in ((64, 2), (128, 2), (256, 2), (512, 1)):
            x = discriminator_block(x, filters, strides=strides)

        # One logit per spatial location
        outputs = layers.Conv2D(1, 4, strides=1, padding='same')(x)

        return keras.Model(inputs, outputs, name='discriminator')

    @tf.function
    def train_step(self, hazy_images, clean_images):
        """Run one optimization step for both networks.

        Args:
            hazy_images: Batch of generator inputs.
            clean_images: Batch of matching ground-truth images.

        Returns:
            Dict with the scalar ``'gen_loss'`` and ``'disc_loss'`` for this batch.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            fake_images = self.generator(hazy_images, training=True)

            real_output = self.discriminator(clean_images, training=True)
            fake_output = self.discriminator(fake_images, training=True)

            gen_loss = self._generator_loss(fake_output, fake_images, clean_images)
            disc_loss = self._discriminator_loss(real_output, fake_output)

        gen_gradients = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        disc_gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(
            zip(gen_gradients, self.generator.trainable_variables)
        )
        self.discriminator_optimizer.apply_gradients(
            zip(disc_gradients, self.discriminator.trainable_variables)
        )

        # Accumulate into the running means reported per epoch
        self.gen_loss_tracker.update_state(gen_loss)
        self.disc_loss_tracker.update_state(disc_loss)

        return {
            'gen_loss': gen_loss,
            'disc_loss': disc_loss
        }

    def _generator_loss(self, fake_output, fake_images, clean_images):
        """Adversarial loss plus the weighted L1 distance to the ground truth."""
        adversarial = self.gan_loss(tf.ones_like(fake_output), fake_output)
        l1_loss = self.pixel_loss(clean_images, fake_images) * self.pixel_loss_weight
        return adversarial + l1_loss

    def _discriminator_loss(self, real_output, fake_output):
        """Mean of the real and fake cross-entropy terms, using the configured targets."""
        real_targets = tf.ones_like(real_output) * self.real_label
        fake_targets = tf.zeros_like(fake_output) + self.fake_label
        real_loss = self.gan_loss(real_targets, real_output)
        fake_loss = self.gan_loss(fake_targets, fake_output)
        return (real_loss + fake_loss) * 0.5

In [6]:
class DehazeDataProcessor:
    """Builds a ``tf.data`` pipeline of paired (hazy, ground-truth) images.

    Args:
        hazy_dir: Directory containing the hazy input images.
        gt_dir: Directory containing the ground-truth images.
        image_size: ``(height, width)`` every image is resized to.
        batch_size: Number of image pairs per batch.
    """

    def __init__(self, hazy_dir, gt_dir, image_size=(256, 256), batch_size=4):
        self.hazy_dir = hazy_dir
        self.gt_dir = gt_dir
        self.image_size = image_size
        self.batch_size = batch_size

    def load_and_preprocess_image(self, image_path):
        """Read one image file, resize it and scale it to [-1, 1].

        No augmentation is applied.
        """
        raw = tf.io.read_file(image_path)
        img = tf.image.decode_png(raw, channels=3)
        img = tf.image.resize(img, self.image_size)
        img = tf.cast(img, tf.float32) / 127.5 - 1  # Normalize to [-1, 1]
        return img

    def create_dataset(self):
        """Create the shuffled, batched and prefetched dataset.

        Files are paired by position after sorting each directory listing, so
        both directories must sort into the same order.

        Returns:
            Tuple ``(dataset, num_images)`` where ``dataset`` yields
            ``(hazy_batch, gt_batch)`` and ``num_images`` is the number of pairs.

        Raises:
            AssertionError: If the two directories hold different numbers of files.
            ValueError: If no files are found.
        """
        # Get file paths
        hazy_paths = sorted(glob.glob(os.path.join(self.hazy_dir, '*.*')))
        gt_paths = sorted(glob.glob(os.path.join(self.gt_dir, '*.*')))

        logger.info(f"Found {len(hazy_paths)} hazy images and {len(gt_paths)} ground truth images")
        assert len(hazy_paths) == len(gt_paths), "Number of hazy and ground truth images must match"

        # Two empty directories pass the equality check above; fail here with
        # a clear message instead of an unrelated error further down.
        if not hazy_paths:
            raise ValueError(
                f"No images found in {self.hazy_dir!r} and {self.gt_dir!r}"
            )

        def load_pair(hazy_path, gt_path):
            return (
                self.load_and_preprocess_image(hazy_path),
                self.load_and_preprocess_image(gt_path),
            )

        # Create dataset
        dataset = tf.data.Dataset.from_tensor_slices((hazy_paths, gt_paths))
        dataset = dataset.map(load_pair, num_parallel_calls=tf.data.AUTOTUNE)

        # Configure dataset: the shuffle buffer spans the whole dataset
        dataset = dataset.shuffle(len(hazy_paths))
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset, len(hazy_paths)


In [7]:
class DehazeTrainer:
    """Runs the training loop, checkpointing and sample-image export.

    Args:
        model: Object exposing ``generator``, ``discriminator``, both
            optimizers, both loss trackers and ``train_step``.
        dataset: Iterable yielding ``(hazy_batch, clean_batch)`` pairs.
        num_epochs: Number of passes over ``dataset``.
        checkpoint_dir: Directory checkpoints are written to.
        sample_dir: Directory sample figures are written to.
        save_every: Save a checkpoint and a sample figure every this many epochs.
        log_every: Print batch losses every this many batches.

    Raises:
        ValueError: If ``save_every`` or ``log_every`` is smaller than 1.
    """

    def __init__(self, model, dataset, num_epochs, checkpoint_dir='./checkpoints',
                 sample_dir='samples', save_every=10, log_every=10):
        if save_every < 1 or log_every < 1:
            raise ValueError("save_every and log_every must be >= 1")

        self.model = model
        self.dataset = dataset
        self.num_epochs = num_epochs
        self.checkpoint_dir = checkpoint_dir
        self.sample_dir = sample_dir
        self.save_every = save_every
        self.log_every = log_every

        # Batch reused for every sample figure; filled on first use.
        self._sample_batch = None

        # Create directories
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(sample_dir, exist_ok=True)

        # Setup checkpointing
        self.checkpoint = tf.train.Checkpoint(
            generator=model.generator,
            discriminator=model.discriminator,
            generator_optimizer=model.generator_optimizer,
            discriminator_optimizer=model.discriminator_optimizer
        )
        self.checkpoint_manager = tf.train.CheckpointManager(
            self.checkpoint, checkpoint_dir, max_to_keep=3
        )

    def train(self):
        """Train for ``num_epochs`` epochs, printing losses as it goes.

        A checkpoint and a sample figure are written every ``save_every``
        epochs and after the final epoch, so the last state is always saved.
        """
        start_time = time.time()

        for epoch in range(1, self.num_epochs + 1):
            print(f"\nStarting epoch {epoch}/{self.num_epochs}", flush=True)

            # Reset metrics so the epoch summary covers this epoch only
            self.model.gen_loss_tracker.reset_state()
            self.model.disc_loss_tracker.reset_state()

            # Train on batches
            for batch_idx, (hazy_batch, clean_batch) in enumerate(self.dataset):
                losses = self.model.train_step(hazy_batch, clean_batch)

                if batch_idx % self.log_every == 0:
                    print(
                        f"Batch {batch_idx}: "
                        f"Generator Loss = {losses['gen_loss']:.4f}, "
                        f"Discriminator Loss = {losses['disc_loss']:.4f}",
                        flush=True
                    )

            # Save checkpoint and generate samples
            is_last_epoch = epoch == self.num_epochs
            if epoch % self.save_every == 0 or is_last_epoch:
                self.checkpoint_manager.save()
                self._generate_and_save_samples(epoch)

            # Log epoch results
            print(
                f"Epoch {epoch}: "
                f"Generator Loss = {self.model.gen_loss_tracker.result():.4f}, "
                f"Discriminator Loss = {self.model.disc_loss_tracker.result():.4f}",
                flush=True
            )

        total_time = time.time() - start_time
        print(f"\nTraining completed in {total_time / 60:.2f} minutes", flush=True)

    def _generate_and_save_samples(self, epoch):
        """Save a figure comparing hazy, generated and ground-truth images.

        The same batch is reused on every call so figures from different
        epochs show the same images.
        """
        if self._sample_batch is None:
            self._sample_batch = next(iter(self.dataset))
        hazy_images, clean_images = self._sample_batch

        # Generate dehazed images
        generated_images = self.model.generator(hazy_images, training=False)

        n_rows = min(3, len(hazy_images))
        columns = (
            ('Hazy', hazy_images),
            ('Generated', generated_images),
            ('Ground Truth', clean_images),
        )

        fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
        try:
            for row in range(n_rows):
                for col, (title, images) in enumerate(columns):
                    ax = axes[row][col]
                    # Map [-1, 1] back to [0, 1] for display
                    pixels = tf.clip_by_value(images[row] * 0.5 + 0.5, 0.0, 1.0)
                    ax.imshow(pixels.numpy())
                    ax.axis('off')
                    ax.set_title(title)

            fig.savefig(os.path.join(self.sample_dir, f'epoch_{epoch}.png'))
        finally:
            plt.close(fig)


In [8]:
def main():
    """Build the data pipeline and the model, then run the training loop."""
    # Configuration
    hazy_dir = '/kaggle/input/hazing-images-dataset-cvpr-2019/hazy'
    gt_dir = '/kaggle/input/hazing-images-dataset-cvpr-2019/GT'
    image_size = (256, 256)
    batch_size = 4
    num_epochs = 5000
    checkpoint_dir = './checkpoints'

    # Data pipeline
    processor = DehazeDataProcessor(
        hazy_dir, gt_dir, image_size=image_size, batch_size=batch_size
    )
    dataset, num_images = processor.create_dataset()
    logger.info(f"Dataset created with {num_images} images")

    # Model
    gan = DehazeGAN(image_size=image_size)
    logger.info("Model initialized")

    # Training
    trainer = DehazeTrainer(gan, dataset, num_epochs, checkpoint_dir=checkpoint_dir)
    logger.info("Starting training...")
    trainer.train()



In [9]:
import tensorflow as tf
# Let TensorFlow's own logger emit DEBUG-level records during the run.
tf.get_logger().setLevel(logging.DEBUG)

# Start training when the notebook is executed as the main program.
if __name__ == "__main__":
    main()


Starting epoch 1/5000
Batch 0: Generator Loss = 61.8142, Discriminator Loss = 0.6830
Batch 10: Generator Loss = 32.5270, Discriminator Loss = 0.6993
Epoch 1: Generator Loss = 42.8812, Discriminator Loss = 0.6957

Starting epoch 2/5000
Batch 0: Generator Loss = 29.9445, Discriminator Loss = 0.6911
Batch 10: Generator Loss = 32.0863, Discriminator Loss = 0.7064
Epoch 2: Generator Loss = 34.1445, Discriminator Loss = 0.6973

Starting epoch 3/5000
Batch 0: Generator Loss = 34.2007, Discriminator Loss = 0.6935
Batch 10: Generator Loss = 32.9209, Discriminator Loss = 0.6952
Epoch 3: Generator Loss = 32.7646, Discriminator Loss = 0.6967

Starting epoch 4/5000
Batch 0: Generator Loss = 28.1053, Discriminator Loss = 0.6973
Batch 10: Generator Loss = 29.8769, Discriminator Loss = 0.6872
Epoch 4: Generator Loss = 31.7570, Discriminator Loss = 0.6952

Starting epoch 5/5000
Batch 0: Generator Loss = 29.6275, Discriminator Loss = 0.6981
Batch 10: Generator Loss = 28.7607, Discriminator Loss = 0.691

In [10]:
import os
import shutil
import tempfile

folder_path = '/kaggle/working/'
zip_file_path = '/kaggle/working/sample_images.zip'

# The archive lives inside the folder being archived. Remove a copy left by an
# earlier run so it is not packed into the new archive.
if os.path.exists(zip_file_path):
    os.remove(zip_file_path)

# Build the zip in a scratch directory and move it into place afterwards, so
# the half-written archive is never seen while the folder is being walked.
archive_stem = os.path.splitext(os.path.basename(zip_file_path))[0]
with tempfile.TemporaryDirectory() as staging_dir:
    staged_zip = shutil.make_archive(
        os.path.join(staging_dir, archive_stem), 'zip', folder_path
    )
    shutil.move(staged_zip, zip_file_path)

# Output the path of the zip file for download
print(f"Zip file created at: {zip_file_path}")


Zip file created at: /kaggle/working/sample_images.zip


In [13]:
from tensorflow.keras.models import save_model

def save_generator_model(checkpoint_dir, save_path):
    """Restore the latest checkpoint and save the generator as a Keras model.

    Args:
        checkpoint_dir: Directory the training checkpoints were written to.
        save_path: Destination path for the saved generator.

    Raises:
        FileNotFoundError: If ``checkpoint_dir`` holds no checkpoint. Nothing
            is saved in that case, so an untrained generator is never
            exported by accident.
    """
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is None:
        raise FileNotFoundError(
            f"No checkpoint found in {checkpoint_dir!r}. "
            "Make sure training has been completed."
        )

    # Initialize a DehazeGAN model (with the same configuration as during training)
    model = DehazeGAN(image_size=(256, 256))

    # Only the generator is exported, so only its weights are restored;
    # expect_partial() silences warnings about the checkpoint entries
    # (discriminator, optimizers) that are deliberately left unused.
    checkpoint = tf.train.Checkpoint(generator=model.generator)
    checkpoint.restore(latest_checkpoint).expect_partial()
    print(f"Restored from {latest_checkpoint}")

    # Save the generator model
    save_model(model.generator, save_path)
    print(f"Generator model saved at: {save_path}")

# Export the trained generator from the checkpoint directory used in training.
CHECKPOINT_DIR = './checkpoints'
SAVE_PATH = '/kaggle/working/generator_model.keras'
save_generator_model(checkpoint_dir=CHECKPOINT_DIR, save_path=SAVE_PATH)


Restored from ./checkpoints/ckpt-500
Generator model saved at: /kaggle/working/generator_model.keras
