# Super-Resolution GAN (SRGAN) for Astronomical Images

This notebook provides a corrected and enhanced implementation of a Super-Resolution Generative Adversarial Network (SRGAN) tailored for astronomical data. It includes bug fixes, a proper GAN training loop, and metrics for evaluating image quality.

## 1. Setup and Imports
First, we'll install necessary libraries and import them.

In [None]:
%pip install -r requirements.txt

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, applications
from tensorflow.keras.optimizers import Adam
import cv2
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


## 2. Data Loading and Preprocessing
Here we define functions to load and preprocess the data. Astronomical images are often grayscale, so they are loaded as single-channel images. We also create low-resolution versions of the high-resolution images by downsampling them.

In [None]:
def load_images(data_path, hr_size, lr_size, num_images=100):
    """Load grayscale images and build paired low/high-resolution arrays.

    Each selected file is read as a single-channel image, resized to
    ``hr_size`` x ``hr_size``, and that result is downsampled with area
    interpolation to ``lr_size`` x ``lr_size``. Files that OpenCV cannot
    decode are skipped.

    Args:
        data_path: Directory containing the image files.
        hr_size: Side length, in pixels, of the high-resolution images.
        lr_size: Side length, in pixels, of the low-resolution images.
        num_images: Maximum number of files to consider.

    Returns:
        A ``(low_res_images, high_res_images)`` tuple of float32 arrays whose
        pixel values were divided by 255, with shapes
        ``(N, lr_size, lr_size)`` and ``(N, hr_size, hr_size)``. ``N`` is 0
        when no image could be loaded.
    """
    image_extensions = ('.jpg', '.jpeg', '.png')

    # os.listdir returns entries in arbitrary order; sorting makes the subset
    # selected by num_images reproducible. Lower-casing also accepts names
    # such as "M31.PNG".
    image_files = sorted(
        os.path.join(data_path, f)
        for f in os.listdir(data_path)
        if f.lower().endswith(image_extensions)
    )[:num_images]

    high_res_images = []
    low_res_images = []

    if image_files:
        for image_path in tqdm(image_files, desc='Loading Images'):
            hr_img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if hr_img is None:
                # Unreadable or corrupt file: ignore it and keep going
                continue

            hr_img = cv2.resize(hr_img, (hr_size, hr_size))
            lr_img = cv2.resize(hr_img, (lr_size, lr_size), interpolation=cv2.INTER_AREA)

            high_res_images.append(hr_img)
            low_res_images.append(lr_img)

    if not high_res_images:
        # Keep the documented rank even when nothing was loaded, instead of
        # collapsing to a shape-(0,) array.
        return (
            np.zeros((0, lr_size, lr_size), dtype=np.float32),
            np.zeros((0, hr_size, hr_size), dtype=np.float32),
        )

    high_res_images = np.array(high_res_images).astype(np.float32) / 255.0
    low_res_images = np.array(low_res_images).astype(np.float32) / 255.0

    return low_res_images, high_res_images

# Dataset configuration
DATA_PATH = './astronomical_data/'
HR_SIZE = 128
LR_SIZE = 32


# Placeholder data for demonstration without an actual dataset
def generate_dummy_data(hr_size, lr_size, num_images=50):
    """Create random image pairs for running the notebook without a dataset.

    Args:
        hr_size: Side length, in pixels, of the high-resolution images.
        lr_size: Side length, in pixels, of the low-resolution images.
        num_images: Number of image pairs to generate.

    Returns:
        A ``(lr_train, hr_train)`` tuple of float32 arrays. ``hr_train`` holds
        uniform random values in [0, 1) with shape
        ``(num_images, hr_size, hr_size)``; ``lr_train`` holds the same images
        downsampled with area interpolation to ``lr_size`` x ``lr_size``.
    """
    # float32 matches the dtype produced by load_images
    hr_train = np.random.rand(num_images, hr_size, hr_size).astype(np.float32)
    lr_train = np.array([cv2.resize(img, (lr_size, lr_size), interpolation=cv2.INTER_AREA) for img in hr_train])
    return lr_train, hr_train


# Use the real dataset when the directory exists. Calling load_images on a
# missing directory would raise before the placeholder data could be used.
if os.path.isdir(DATA_PATH):
    lr_train, hr_train = load_images(DATA_PATH, HR_SIZE, LR_SIZE, num_images=200)
else:
    lr_train, hr_train = np.empty(0), np.empty(0)

# Fall back to placeholder data when no image could be loaded
if len(lr_train) == 0:
    lr_train, hr_train = generate_dummy_data(HR_SIZE, LR_SIZE)

# Add the trailing channel axis expected by the convolutional models
hr_train = np.expand_dims(hr_train, axis=-1)
lr_train = np.expand_dims(lr_train, axis=-1)

print(f"Low-resolution training data shape: {lr_train.shape}")
print(f"High-resolution training data shape: {hr_train.shape}")

## 3. Model Architecture
### Generator
The generator uses residual blocks to learn the mapping from low to high resolution images.

In [None]:
def residual_block(x, filters):
    """Apply one Conv-BN-PReLU-Conv-BN residual block with an identity skip.

    Args:
        x: Input feature map tensor. Its channel count must equal ``filters``
            so that it can be added to the block output.
        filters: Number of filters in both 3x3 convolutions.

    Returns:
        The element-wise sum of the input and the transformed features.
    """
    shortcut = x
    out = layers.Conv2D(filters, (3, 3), padding='same')(x)
    out = layers.BatchNormalization()(out)
    # Share the PReLU slope across height and width (one slope per channel).
    # Without shared_axes Keras learns a separate slope for every pixel
    # position, which inflates the parameter count and ties the block to one
    # spatial size.
    out = layers.PReLU(shared_axes=[1, 2])(out)
    out = layers.Conv2D(filters, (3, 3), padding='same')(out)
    out = layers.BatchNormalization()(out)
    return layers.add([shortcut, out])

def generator_model(lr_shape=(32, 32, 1), num_residual_blocks=16, upscale_factor=4):
    """Build the super-resolution generator.

    The network applies a 9x9 convolution, a stack of residual blocks wrapped
    in a long skip connection, a series of 2x sub-pixel (depth-to-space)
    upsampling stages, and a final 9x9 convolution with a sigmoid activation
    producing a single-channel image.

    Args:
        lr_shape: Shape of one low-resolution input, excluding the batch axis.
        num_residual_blocks: Number of residual blocks in the trunk.
        upscale_factor: Total upscaling factor. Must be a positive power of
            two because every upsampling stage doubles the resolution.

    Returns:
        A Keras ``Model`` named ``'Generator'``.

    Raises:
        ValueError: If ``upscale_factor`` is not a positive power of two.
    """
    # Validate first: int(np.log2(3)) would silently build a 2x model.
    num_upsample_stages = int(round(np.log2(upscale_factor))) if upscale_factor >= 1 else -1
    if num_upsample_stages < 0 or 2 ** num_upsample_stages != upscale_factor:
        raise ValueError(
            f"upscale_factor must be a positive power of two, got {upscale_factor!r}"
        )

    filters = 64

    lr_input = layers.Input(shape=lr_shape)

    x = layers.Conv2D(filters, (9, 9), padding='same')(lr_input)
    # shared_axes=[1, 2] keeps one slope per channel instead of one per pixel
    x = layers.PReLU(shared_axes=[1, 2])(x)

    # Saved for the long skip connection around the residual trunk
    res_input = x

    for _ in range(num_residual_blocks):
        x = residual_block(x, filters)

    x = layers.Conv2D(filters, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.add([res_input, x])

    # Upsampling blocks: 4x the channels, then rearrange them into a 2x larger grid
    for _ in range(num_upsample_stages):
        x = layers.Conv2D(filters * 4, (3, 3), padding='same')(x)
        x = layers.Lambda(lambda z: tf.nn.depth_to_space(z, 2))(x)
        x = layers.PReLU(shared_axes=[1, 2])(x)

    # Sigmoid bounds the output to (0, 1)
    hr_output = layers.Conv2D(1, (9, 9), padding='same', activation='sigmoid')(x)

    return Model(inputs=lr_input, outputs=hr_output, name='Generator')

### Discriminator
The discriminator is a standard convolutional classifier that outputs the probability that an input image is real rather than generated.

In [None]:
def discriminator_model(hr_shape=(128, 128, 1)):
    """Build the discriminator that scores high-resolution images.

    The network is a 3x3 convolution followed by five Conv-BN-LeakyReLU
    blocks (three of them with stride 2), then a 1024-unit dense layer and a
    single sigmoid output.

    Args:
        hr_shape: Shape of one high-resolution input, excluding the batch axis.

    Returns:
        A Keras ``Model`` named ``'Discriminator'`` whose output is a value in
        (0, 1) per image.
    """
    # The slope is passed positionally: the keyword is `alpha` in Keras 2 but
    # `negative_slope` in Keras 3, where `alpha=` emits a deprecation warning.
    leaky_slope = 0.2

    def d_block(x, filters, strides):
        x = layers.Conv2D(filters, (3, 3), strides=strides, padding='same')(x)
        x = layers.BatchNormalization()(x)
        return layers.LeakyReLU(leaky_slope)(x)

    hr_input = layers.Input(shape=hr_shape)
    x = layers.Conv2D(64, (3, 3), padding='same')(hr_input)
    x = layers.LeakyReLU(leaky_slope)(x)

    # (filters, strides) for each block
    block_config = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2))
    for block_filters, block_strides in block_config:
        x = d_block(x, block_filters, block_strides)

    x = layers.Flatten()(x)
    x = layers.Dense(1024)(x)
    x = layers.LeakyReLU(leaky_slope)(x)

    d_output = layers.Dense(1, activation='sigmoid')(x)

    return Model(inputs=hr_input, outputs=d_output, name='Discriminator')

## 4. Loss Functions
We use a combination of content loss (MSE) and adversarial loss for the generator. Binary Cross-Entropy is used for the discriminator.

In [None]:
def get_losses():
    """Create the loss objects used by the generator and discriminator objectives.

    Returns:
        A ``(mse_loss, bce_loss)`` tuple: a mean-squared-error loss and a
        binary cross-entropy loss that expects probabilities, not logits.
    """
    return (
        tf.keras.losses.MeanSquaredError(),
        tf.keras.losses.BinaryCrossentropy(from_logits=False),
    )


# Module-level instances shared by the loss functions below
mse_loss, bce_loss = get_losses()

def discriminator_loss(real_output, fake_output):
    """Compute the discriminator's binary cross-entropy loss.

    Args:
        real_output: Discriminator predictions for real images.
        fake_output: Discriminator predictions for generated images.

    Returns:
        The sum of the loss on real images (target 1) and the loss on
        generated images (target 0).
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return bce_loss(real_targets, real_output) + bce_loss(fake_targets, fake_output)

def generator_loss(fake_output, generated_images, hr_images,
                   content_loss_weight=1.0, adversarial_loss_weight=1e-3):
    """Compute the generator loss as a weighted sum of content and adversarial terms.

    The defaults follow the SRGAN formulation, where the pixel-wise content
    loss dominates and the adversarial term is scaled by 1e-3. Scaling the
    content term by 1e-3 instead makes it negligible next to the adversarial
    loss, so the generator would barely be tied to the ground truth.

    Args:
        fake_output: Discriminator predictions for the generated images.
        generated_images: Images produced by the generator.
        hr_images: Ground-truth high-resolution images.
        content_loss_weight: Multiplier applied to the MSE content loss.
        adversarial_loss_weight: Multiplier applied to the adversarial loss.

    Returns:
        ``content_loss_weight * content_loss
        + adversarial_loss_weight * adversarial_loss``.
    """
    # The generator is rewarded when the discriminator labels its output as real
    adversarial_loss = bce_loss(tf.ones_like(fake_output), fake_output)
    content_loss = mse_loss(hr_images, generated_images)
    total_loss = (content_loss_weight * content_loss
                  + adversarial_loss_weight * adversarial_loss)
    return total_loss


## 5. Optimizers and Training
We'll use separate optimizers for the generator and discriminator and define a custom training loop.

In [None]:
# Build both networks for the configured image sizes
generator = generator_model(upscale_factor=4, lr_shape=(LR_SIZE, LR_SIZE, 1))
discriminator = discriminator_model(hr_shape=(HR_SIZE, HR_SIZE, 1))

# Each network gets its own Adam instance so their optimizer state stays separate
generator_optimizer = Adam(learning_rate=1e-4)
discriminator_optimizer = Adam(learning_rate=1e-4)

@tf.function
def train_step(lr_images, hr_images):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
        lr_images: Batch of low-resolution inputs.
        hr_images: Batch of matching high-resolution targets.

    Returns:
        A ``(gen_loss, disc_loss)`` tuple for this batch.
    """
    # Two tapes record the same forward pass so each loss can be
    # differentiated with respect to its own network's variables.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        sr_images = generator(lr_images, training=True)

        real_output = discriminator(hr_images, training=True)
        fake_output = discriminator(sr_images, training=True)

        gen_loss = generator_loss(fake_output, sr_images, hr_images)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return gen_loss, disc_loss

def train(dataset, epochs):
    """Train the GAN for a number of epochs and report average losses.

    Args:
        dataset: Iterable yielding ``(lr_batch, hr_batch)`` pairs.
        epochs: Number of passes over ``dataset``.

    Returns:
        A dict with keys ``'generator_loss'`` and ``'discriminator_loss'``,
        each mapping to a list holding the average loss of every epoch.

    Raises:
        ValueError: If an epoch yields no batches, since no average loss
            can be computed.
    """
    history = {'generator_loss': [], 'discriminator_loss': []}

    for epoch in range(epochs):
        gen_total_loss = 0.0
        disc_total_loss = 0.0
        num_batches = 0
        for lr_batch, hr_batch in tqdm(dataset, desc=f'Epoch {epoch+1}/{epochs}'):
            g_loss, d_loss = train_step(lr_batch, hr_batch)
            # Convert to Python floats so the running totals stay plain numbers
            gen_total_loss += float(g_loss)
            disc_total_loss += float(d_loss)
            num_batches += 1

        if num_batches == 0:
            # Fail with a clear message instead of a ZeroDivisionError below
            raise ValueError('Cannot train on an empty dataset: no batches were produced.')

        avg_gen_loss = gen_total_loss / num_batches
        avg_disc_loss = disc_total_loss / num_batches
        history['generator_loss'].append(avg_gen_loss)
        history['discriminator_loss'].append(avg_disc_loss)
        print(f'Epoch {epoch+1} - Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}')

    return history


## 6. Evaluation Metrics
We'll add PSNR and SSIM to measure the quality of the super-resolved images. These are essential for quantitative analysis. We will also visualize the metrics using seaborn and a few image comparisons.

In [None]:
def _compute_quality_metrics(reference_images, generated_images):
    """Return per-image PSNR and SSIM lists for paired reference/generated images."""
    psnr_scores = []
    ssim_scores = []
    for reference, generated in zip(reference_images, generated_images):
        psnr_scores.append(psnr(reference, generated, data_range=1.0))
        ssim_scores.append(ssim(reference, generated, data_range=1.0, channel_axis=-1))
    return psnr_scores, ssim_scores


def _plot_metric_distributions(psnr_scores, ssim_scores):
    """Show side-by-side histograms (with KDE) of the PSNR and SSIM scores."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    panels = (
        (axes[0], psnr_scores, 'PSNR Distribution', 'PSNR Score'),
        (axes[1], ssim_scores, 'SSIM Distribution', 'SSIM Score'),
    )
    for ax, scores, title, xlabel in panels:
        sns.histplot(scores, kde=True, ax=ax)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')

    plt.tight_layout()
    plt.show()


def evaluate_model(generator, lr_images, hr_images):
    """Score the generator with PSNR and SSIM and plot the score distributions.

    Args:
        generator: Model whose ``predict`` maps low-resolution images to
            super-resolved images.
        lr_images: Array of low-resolution inputs, indexed by image along the
            first axis.
        hr_images: Array of ground-truth images aligned with ``lr_images``.
            Pixel values are assumed to span a range of 1.0, which is the
            ``data_range`` passed to both metrics.

    Returns:
        A ``(psnr_scores, ssim_scores)`` tuple of per-image score lists.

    Raises:
        ValueError: If ``lr_images`` is empty or ``hr_images`` has fewer
            images than ``lr_images``.
    """
    num_samples = len(lr_images)
    if num_samples == 0:
        raise ValueError('evaluate_model requires at least one image.')
    if len(hr_images) < num_samples:
        raise ValueError(
            f'hr_images has {len(hr_images)} images but lr_images has {num_samples}.'
        )

    generated_images = generator.predict(lr_images)

    psnr_scores, ssim_scores = _compute_quality_metrics(
        hr_images[:num_samples], generated_images
    )

    avg_psnr = np.mean(psnr_scores)
    avg_ssim = np.mean(ssim_scores)

    print(f"\nAverage PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")

    # Visualize metrics using seaborn
    _plot_metric_distributions(psnr_scores, ssim_scores)

    return psnr_scores, ssim_scores

def plot_comparisons(generator, lr_images, hr_images, num_images=3):
    """Plot low-resolution, super-resolved and ground-truth images side by side.

    One figure with three panels is shown per image.

    Args:
        generator: Model whose ``predict`` maps low-resolution images to
            super-resolved images.
        lr_images: Array of low-resolution inputs.
        hr_images: Array of ground-truth images aligned with ``lr_images``.
        num_images: Maximum number of comparisons to plot. It is capped at the
            number of available image pairs.
    """
    # Cap the count so asking for more images than exist cannot raise IndexError
    num_images = min(num_images, len(lr_images), len(hr_images))
    if num_images <= 0:
        return

    lr_subset = lr_images[:num_images]
    hr_subset = hr_images[:num_images]
    generated_images = generator.predict(lr_subset)

    for lr_sample, hr_sample, gen_sample in zip(lr_subset, hr_subset, generated_images):
        # squeeze() drops size-1 axes so imshow receives a 2-D array
        panels = (
            (lr_sample.squeeze(), 'Low Resolution'),
            (gen_sample.squeeze(), 'Super-Resolved'),
            (hr_sample.squeeze(), 'High Resolution (Ground Truth)'),
        )

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        plt.show()


## 7. Main Execution
Here we'll tie everything together and run the training process.

In [None]:
BATCH_SIZE = 4
EPOCHS = 10

# Pair each low-resolution image with its high-resolution target, then shuffle and batch
dataset = (
    tf.data.Dataset.from_tensor_slices((lr_train, hr_train))
    .shuffle(buffer_size=1024)
    .batch(BATCH_SIZE)
)

print("Starting training...")
train(dataset, EPOCHS)

print("\nStarting evaluation...")
# Fresh placeholder images for evaluation, each given a trailing channel axis
lr_test, hr_test = (
    np.expand_dims(images, axis=-1)
    for images in generate_dummy_data(HR_SIZE, LR_SIZE, num_images=10)
)

psnr_scores, ssim_scores = evaluate_model(generator, lr_test, hr_test)

print("\nPlotting image comparisons...")
plot_comparisons(generator, lr_test, hr_test, num_images=3)

print("\nDone!")