In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os
from glob import glob
import time

In [2]:

# Report the TensorFlow build in use and how many GPUs it can see.
gpu_devices = tf.config.list_physical_devices('GPU')
print("TensorFlow Version:", tf.__version__)
print("Num GPUs Available: ", len(gpu_devices))


TensorFlow Version: 2.10.0
Num GPUs Available:  1


In [None]:
# --- Data locations and training schedule ---
INPUT_DIR = 'input'     # directory holding the source (to-be-enhanced) images
OUTPUT_DIR = 'output'   # directory holding the matching target images
NUM_SAMPLES = 100       # upper bound on the number of image pairs loaded
EPOCHS = 50
BATCH_SIZE = 2          # Keep this low due to 8GB VRAM

# --- Image geometry (input and target share the same shape) ---
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 256, 256, 3
LR_SHAPE = HR_SHAPE = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)


In [None]:
# --- Generator architecture ---
RESIDUAL_BLOCKS = 6  # Number of RRDB blocks (adjust based on VRAM, standard ESRGAN uses ~16-23)
GF = 32              # Generator filters

# --- Weights of the individual terms in the generator loss ---
LAMBDA_L1, LAMBDA_PERCEPTUAL, LAMBDA_GAN = 1e-2, 1.0, 5e-3

# --- Adam hyperparameters (generator and discriminator) ---
LEARNING_RATE_G = LEARNING_RATE_D = 1e-4
BETA_1, BETA_2 = 0.9, 0.999


In [None]:

# Folder that receives the training checkpoints; created up front so later
# saves never fail on a missing directory.
CHECKPOINT_DIR = './training_checkpoints_satellite'
os.makedirs(name=CHECKPOINT_DIR, exist_ok=True)


In [None]:


def load_and_preprocess_image(path):
    """Read one JPEG file and return it as a fixed-size float32 image tensor.

    Args:
        path: Path of a JPEG file (string or scalar string tensor).

    Returns:
        A float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) with
        values in [0, 1].
    """
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=IMG_CHANNELS)
    image = tf.image.convert_image_dtype(image, tf.float32)  # Normalizes to [0, 1]
    # The generator and discriminator are built for a fixed spatial size, and
    # batching requires every element to share one shape. Resizing here keeps
    # files of a different resolution from breaking the pipeline and gives the
    # tensor a fully static shape.
    image = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH))
    return image

def load_dataset(input_dir, output_dir, num_samples):
    """Build a shuffled, batched tf.data pipeline of (input, target) image pairs.

    Files are paired by position after sorting the ``*.jpg`` names found in
    each directory, so both directories are expected to use names that sort
    into the same order.

    Args:
        input_dir: Directory containing the input ``.jpg`` images.
        output_dir: Directory containing the target ``.jpg`` images.
        num_samples: Maximum number of pairs to use; ``None`` uses every file.

    Returns:
        A ``tf.data.Dataset`` yielding ``(input_batch, target_batch)`` tuples
        of ``BATCH_SIZE`` images each (the last batch may be smaller).

    Raises:
        ValueError: If either directory has no ``.jpg`` files, or if the
            number of selected input and target files differs.
    """
    all_input_paths = sorted(glob(os.path.join(input_dir, "*.jpg")))
    all_output_paths = sorted(glob(os.path.join(output_dir, "*.jpg")))

    # Truncation below can hide a count mismatch between the two directories;
    # surface it, because position-based pairing may then be misaligned.
    if len(all_input_paths) != len(all_output_paths):
        print(f"Warning: directories hold different numbers of .jpg files "
              f"({len(all_input_paths)} input vs {len(all_output_paths)} output); "
              f"pairs are matched by sorted position.")

    input_paths = all_input_paths[:num_samples]
    output_paths = all_output_paths[:num_samples]

    if len(input_paths) == 0 or len(output_paths) == 0:
        raise ValueError("Input or Output directory is empty or contains no .jpg files.")
    if len(input_paths) != len(output_paths):
         raise ValueError(f"Mismatch in number of files: {len(input_paths)} input vs {len(output_paths)} output.")

    num_pairs = len(input_paths)
    print(f"Found {num_pairs} image pairs.")

    def _load_pair(input_path, output_path):
        # Decode both halves of a pair in one map step so they stay together.
        return load_and_preprocess_image(input_path), load_and_preprocess_image(output_path)

    # One dataset of path pairs keeps the pairing explicit.
    dataset = tf.data.Dataset.from_tensor_slices((input_paths, output_paths))

    # Shuffle the (cheap) path strings rather than decoded images, using the
    # actual number of pairs so the buffer always covers the whole dataset.
    dataset = dataset.shuffle(buffer_size=num_pairs)
    dataset = dataset.map(_load_pair, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset


In [7]:

# --- ESRGAN Generator (RRDBNet adapted for same input/output size) ---

def dense_block(input_tensor, filters):
    """Densely connected five-convolution block used inside an RRDB.

    Each of the first four 3x3 convolutions is followed by LeakyReLU(0.2);
    its output is concatenated along the channel axis with the block input
    and all earlier activated outputs, and that stack feeds the next
    convolution. The fifth convolution has no activation and its output is
    multiplied by 0.2 before being returned. No residual addition is done
    here.
    """
    collected = [input_tensor]  # block input followed by each activated conv output
    current = input_tensor

    for _ in range(4):
        activated = layers.Conv2D(filters, kernel_size=3, padding='same')(current)
        activated = layers.LeakyReLU(alpha=0.2)(activated)
        collected.append(activated)
        # Copy the list so the layer call is not affected by later appends.
        current = layers.concatenate(list(collected), axis=-1)

    # Final convolution: no activation, only the 0.2 residual scaling.
    fused = layers.Conv2D(filters, kernel_size=3, padding='same')(current)
    return layers.Lambda(lambda x: x * 0.2)(fused)


def rrdb(input_tensor, filters):
    """Residual block: the scaled dense-block output added back onto its input."""
    scaled_branch = dense_block(input_tensor, filters)
    return layers.Add()([scaled_branch, input_tensor])


def build_generator(input_shape, num_residual_blocks=RESIDUAL_BLOCKS, gf=GF):
    """Assemble the generator as a Keras functional model named 'generator'.

    Layout: 3x3 conv -> `num_residual_blocks` RRDBs -> 3x3 conv, added to the
    first conv's output (long skip) -> 3x3 conv + LeakyReLU(0.2) -> 3x3 conv
    with sigmoid producing IMG_CHANNELS channels.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        num_residual_blocks: How many RRDBs to stack.
        gf: Number of filters in every intermediate convolution.

    Returns:
        A `keras.Model` mapping an image to an image of the same spatial size.
    """
    inputs = layers.Input(shape=input_shape)

    # Shallow features; kept for the long skip connection around the trunk.
    shallow = layers.Conv2D(gf, kernel_size=3, padding='same')(inputs)

    trunk = shallow
    for _ in range(num_residual_blocks):
        trunk = rrdb(trunk, gf)

    trunk = layers.Conv2D(gf, kernel_size=3, padding='same')(trunk)
    merged = layers.add([trunk, shallow])  # Skip connection over RRDBs

    # No upsampling stage: input and output images have the same size.
    head = layers.Conv2D(gf, kernel_size=3, padding='same')(merged)
    head = layers.LeakyReLU(alpha=0.2)(head)
    outputs = layers.Conv2D(
        IMG_CHANNELS, kernel_size=3, padding='same', activation='sigmoid'
    )(head)  # Sigmoid for [0, 1] output

    return keras.Model(inputs, outputs, name='generator')


In [8]:

# --- Discriminator (VGG-style) ---

def build_discriminator(input_shape):
    """Assemble the VGG-style discriminator as a model named 'discriminator'.

    A stem convolution (64 filters, no batch norm) is followed by seven
    Conv -> BatchNorm -> LeakyReLU(0.2) stages; the stride-2 stages halve the
    spatial size four times in total. The result is flattened and passed
    through Dense(1024) + LeakyReLU(0.2) and a final Dense(1).

    Args:
        input_shape: Shape of one input image, without the batch dimension.

    Returns:
        A `keras.Model` producing one unactivated logit per image.
    """
    inputs = layers.Input(shape=input_shape)

    # Stem: the only convolution that is not followed by batch normalization.
    h = layers.Conv2D(64, kernel_size=3, strides=1, padding='same')(inputs)
    h = layers.LeakyReLU(alpha=0.2)(h)

    # (filters, stride) for each remaining stage; stride 2 downsamples.
    conv_plan = (
        (64, 2),
        (128, 1), (128, 2),
        (256, 1), (256, 2),
        (512, 1), (512, 2),
    )
    for n_filters, stride in conv_plan:
        h = layers.Conv2D(n_filters, kernel_size=3, strides=stride, padding='same')(h)
        h = layers.BatchNormalization()(h)
        h = layers.LeakyReLU(alpha=0.2)(h)

    # Classification head.
    h = layers.Flatten()(h)
    h = layers.Dense(1024)(h)
    h = layers.LeakyReLU(alpha=0.2)(h)
    logits = layers.Dense(1)(h)  # Output logits (no activation)

    return keras.Model(inputs, logits, name='discriminator')


In [9]:
# --- Perceptual Loss (VGG19) ---

# Distance between VGG feature maps: mean absolute error (MSE is an alternative).
vgg_loss_object = keras.losses.MeanAbsoluteError()

def build_vgg():
    """Return a frozen VGG19 feature extractor for the perceptual loss.

    The ImageNet-pretrained VGG19 (without its classifier head) is built for
    HR_SHAPE inputs, marked non-trainable, and truncated so the model outputs
    the activations of layer 'block5_conv4'.
    """
    backbone = keras.applications.VGG19(
        weights='imagenet',
        include_top=False,
        input_shape=HR_SHAPE,
    )
    backbone.trainable = False
    feature_map = backbone.get_layer('block5_conv4').output
    return keras.Model(backbone.input, feature_map)

# Single shared feature extractor, built once and reused by every loss call.
vgg = build_vgg()

def perceptual_loss(hr_true, sr_fake):
    """Mean absolute error between the VGG19 features of two image batches.

    Both batches are scaled by 255 (VGG expects 0-255 input), run through the
    VGG19 preprocessing and the shared `vgg` extractor, and the resulting
    feature maps are compared with `vgg_loss_object`.
    """
    to_vgg_input = tf.keras.applications.vgg19.preprocess_input
    target_features = vgg(to_vgg_input(hr_true * 255.))
    generated_features = vgg(to_vgg_input(sr_fake * 255.))
    return vgg_loss_object(target_features, generated_features)

# --- Other Losses ---
# Adversarial loss; works on raw logits, matching the discriminator's un-activated output.
binary_cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)
# L1 pixel loss.
mae_loss = keras.losses.MeanAbsoluteError()



In [10]:

# --- Optimizers ---
# One Adam instance per network; they share the beta values but have their own learning rates.
generator_optimizer = keras.optimizers.Adam(
    learning_rate=LEARNING_RATE_G,
    beta_1=BETA_1,
    beta_2=BETA_2,
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=LEARNING_RATE_D,
    beta_1=BETA_1,
    beta_2=BETA_2,
)

# --- Build Models ---
generator = build_generator(input_shape=LR_SHAPE)
discriminator = build_discriminator(input_shape=HR_SHAPE)




In [11]:
# Print the layer-by-layer summary of the generator model.
generator.summary()


Model: "generator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 32  896         ['input_2[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 256, 256, 32  9248        ['conv2d[0][0]']                 
                                )                                                         

In [12]:
# Print the layer-by-layer summary of the discriminator model.
discriminator.summary()


Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 256, 256, 3)]     0         
                                                                 
 conv2d_34 (Conv2D)          (None, 256, 256, 64)      1792      
                                                                 
 leaky_re_lu_25 (LeakyReLU)  (None, 256, 256, 64)      0         
                                                                 
 conv2d_35 (Conv2D)          (None, 128, 128, 64)      36928     
                                                                 
 batch_normalization (BatchN  (None, 128, 128, 64)     256       
 ormalization)                                                   
                                                                 
 leaky_re_lu_26 (LeakyReLU)  (None, 128, 128, 64)      0         
                                                     

In [13]:

# --- Checkpointing ---
# Checkpoint files are written as "<CHECKPOINT_DIR>/ckpt-<n>".
checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "ckpt")

# Track both networks and both optimizers so training can resume from a save.
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
)


In [14]:

# --- Training Step ---
@tf.function
def train_step(input_img, target_img):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
        input_img: Batch of input images fed to the generator.
        target_img: Batch of corresponding target images.

    Returns:
        Tuple ``(total_gen_loss, total_disc_loss, l1_loss, perc_loss,
        gen_gan_loss)`` of scalar loss tensors for this batch.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Forward passes: generator first, then the discriminator on the
        # real batch followed by the generated batch.
        sr_img = generator(input_img, training=True)
        real_output = discriminator(target_img, training=True)
        fake_output = discriminator(sr_img, training=True)

        # Generator terms: fool the discriminator, match VGG features, match pixels.
        gen_gan_loss = binary_cross_entropy(tf.ones_like(fake_output), fake_output)
        perc_loss = perceptual_loss(target_img, sr_img)
        l1_loss = mae_loss(target_img, sr_img)
        total_gen_loss = (
            LAMBDA_GAN * gen_gan_loss
            + LAMBDA_PERCEPTUAL * perc_loss
            + LAMBDA_L1 * l1_loss
        )

        # Discriminator terms: real images labelled 1, generated images labelled 0.
        real_loss = binary_cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = binary_cross_entropy(tf.zeros_like(fake_output), fake_output)
        total_disc_loss = real_loss + fake_loss

    # Each tape differentiates its own network's loss w.r.t. that network's weights.
    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables
    gen_grads = gen_tape.gradient(total_gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(total_disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return total_gen_loss, total_disc_loss, l1_loss, perc_loss, gen_gan_loss


In [15]:

# --- Image Generation for Visualization ---
def generate_and_save_images(model, test_input_batch, epoch,
                             output_dir="./generated_images", max_display=4):
    """Save a side-by-side figure of input images and the model's outputs.

    Args:
        model: Callable model; invoked as ``model(batch, training=False)``.
        test_input_batch: Batch of input images with values in [0, 1].
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
        output_dir: Directory the PNG is written to (created if missing).
        max_display: Maximum number of images from the batch to show.

    Returns:
        None. Writes ``image_at_epoch_XXXX.png`` into ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    num_display = min(int(test_input_batch.shape[0]), max_display)
    if num_display < 1:
        print("No images available for visualization; skipping.")
        return

    # Run the model only on the images that are actually drawn.
    shown_inputs = test_input_batch[:num_display]
    predictions = model(shown_inputs, training=False)

    # Size the figure by the number of rows drawn, not by the full batch size.
    fig, axes = plt.subplots(num_display, 2,
                             figsize=(10, 5 * num_display), squeeze=False)
    try:
        for row in range(num_display):
            panels = (
                (axes[row, 0], shown_inputs[row], "Input Image"),      # Assumes input is [0, 1]
                (axes[row, 1], predictions[row], "Generated Image"),   # Assumes prediction is [0, 1]
            )
            for ax, image, title in panels:
                ax.set_title(title)
                ax.imshow(image)
                ax.axis("off")

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, f'image_at_epoch_{epoch+1:04d}.png'))
        print(f"Generated sample image saved for epoch {epoch+1}")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

In [16]:

# --- Training Loop ---
def train(dataset, epochs, checkpoint_every=5, log_every=10):
    """Train the generator and discriminator on `dataset` for `epochs` epochs.

    Restores the latest checkpoint found in CHECKPOINT_DIR (if any), then for
    every epoch runs `train_step` on each batch, prints periodic and
    per-epoch loss summaries, and periodically saves a checkpoint together
    with a sample image.

    Args:
        dataset: Iterable of ``(input_batch, target_batch)`` tuples.
        epochs: Number of passes over `dataset`.
        checkpoint_every: Save a checkpoint and a sample image every this many
            epochs; ``0`` or ``None`` disables both.
        log_every: Print batch losses every this many batches; ``0`` or
            ``None`` disables batch logging.

    Raises:
        ValueError: If `dataset` yields no batches during an epoch.
    """
    # len() raises TypeError for datasets of unknown or infinite cardinality;
    # fall back to a placeholder so progress messages still work.
    try:
        batches_per_epoch = len(dataset)
    except TypeError:
        batches_per_epoch = "?"

    print(f"\n--- Starting Training for {epochs} Epochs ---")
    print(f"Dataset size: {NUM_SAMPLES} images")  # configured limit, not a measured count
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Batches per epoch: {batches_per_epoch}")

    # Look the checkpoint up once and restore only if one exists.
    latest_ckpt = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    if latest_ckpt:
        checkpoint.restore(latest_ckpt)
        print(f"Checkpoint restored from {latest_ckpt}")
    else:
        print("Initializing from scratch.")
    # NOTE: epoch numbering restarts at 1 even after a restore.

    # Fetched lazily on first use and then reused, so the sample images of
    # different epochs show the same inputs and can be compared.
    sample_inputs = None

    for epoch in range(epochs):
        start_time = time.time()
        epoch_gen_loss = 0
        epoch_disc_loss = 0
        batch_count = 0

        for batch_idx, (input_batch, target_batch) in enumerate(dataset):
            gen_loss, disc_loss, l1, perc, gan = train_step(input_batch, target_batch)

            epoch_gen_loss += gen_loss
            epoch_disc_loss += disc_loss
            batch_count += 1

            if log_every and batch_idx % log_every == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx+1}/{batches_per_epoch}, "
                      f"Gen Loss: {gen_loss:.4f} (L1: {l1:.4f}, Perc: {perc:.4f}, GAN: {gan:.4f}), "
                      f"Disc Loss: {disc_loss:.4f}")

        if batch_count == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError below.
            raise ValueError("Dataset yielded no batches; nothing to train on.")

        # End of Epoch
        avg_gen_loss = epoch_gen_loss / batch_count
        avg_disc_loss = epoch_disc_loss / batch_count
        epoch_time = time.time() - start_time

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"Time: {epoch_time:.2f}s")
        print(f"Average Generator Loss: {avg_gen_loss:.4f}")
        print(f"Average Discriminator Loss: {avg_disc_loss:.4f}")

        # Periodic checkpoint plus a sample image from the same model state.
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            print(f"Checkpoint saved for epoch {epoch+1}")

            if sample_inputs is None:
                sample_inputs = next(iter(dataset))[0]
            generate_and_save_images(generator, sample_inputs, epoch)

    print("\n--- Training Finished ---")


In [17]:
# Build the paired (input, target) training pipeline from the configured directories.
train_dataset = load_dataset(
    input_dir=INPUT_DIR,
    output_dir=OUTPUT_DIR,
    num_samples=NUM_SAMPLES,
)

Found 100 image pairs.


In [18]:
# Start (or resume) training for the configured number of epochs.
train(dataset=train_dataset, epochs=EPOCHS)


--- Starting Training for 50 Epochs ---
Dataset size: 100 images
Batch size: 2
Batches per epoch: 50
Initializing from scratch.
Epoch 1/50, Batch 1/50, Gen Loss: 0.7291 (L1: 0.2474, Perc: 0.7218, GAN: 0.9531), Disc Loss: 1.4868
Epoch 1/50, Batch 11/50, Gen Loss: 1.1974 (L1: 0.2812, Perc: 0.7200, GAN: 94.9178), Disc Loss: 0.0000
Epoch 1/50, Batch 21/50, Gen Loss: 1.6395 (L1: 0.3781, Perc: 0.7835, GAN: 170.4381), Disc Loss: 0.0000
Epoch 1/50, Batch 31/50, Gen Loss: 1.1698 (L1: 0.3719, Perc: 0.5640, GAN: 120.4117), Disc Loss: 0.0000
Epoch 1/50, Batch 41/50, Gen Loss: 0.8428 (L1: 0.2447, Perc: 0.8403, GAN: 0.0000), Disc Loss: 27.8888

Epoch 1 Summary:
Time: 22.78s
Average Generator Loss: 1.0961
Average Discriminator Loss: 5.1689
Epoch 2/50, Batch 1/50, Gen Loss: 0.8708 (L1: 0.2661, Perc: 0.7771, GAN: 18.2128), Disc Loss: 13.5819
Epoch 2/50, Batch 11/50, Gen Loss: 0.7134 (L1: 0.3688, Perc: 0.5723, GAN: 27.4855), Disc Loss: 5.8020
Epoch 2/50, Batch 21/50, Gen Loss: 1.0176 (L1: 0.3137, Perc: