In [1]:
# Importing necessary libraries

# xarray for working with labeled multi-dimensional arrays and datasets
import xarray as xr

# numpy for numerical computations
import numpy as np
from numpy.random import randint

# train_test_split for splitting the data into training and testing sets
from sklearn.model_selection import train_test_split

# tqdm for creating progress bars
from tqdm import tqdm

import tensorflow as tf
# Importing Keras libraries and modules for building the neural network
from keras import layers, Model
from keras.layers import Input, Conv2D, PReLU, BatchNormalization, Flatten
from keras.layers import UpSampling2D, LeakyReLU, Dense, add, Cropping2D
from keras.models import load_model  # For loading pre-trained models

In [None]:
# Load the NetCDF source file with xarray. It supplies the 'msl' variable that is
# sliced, normalized and split into HR/LR training pairs in the next cells.
DATA_PATH = 'adaptor.mars.internal-1699024719.206968-18137-18-9d14156f-5adf-4abf-880e-6eb4422da3d2.nc'
data = xr.open_dataset(DATA_PATH)

In [None]:
# Mean sea level pressure ('msl'): keep every 10th index along the first axis and
# drop index 0 of the second axis (presumably latitude -- likely so the grid
# divides evenly by the x4 downsampling below; TODO confirm).
press = data['msl'][::10, 1::, ::]


def _minmax_per_sample(field):
    """Rescale ``field`` to [0, 1] using min/max taken over latitude/longitude.

    Because the reduction is only over the two spatial dimensions, every
    remaining index (e.g. each time step) is normalized independently.
    """
    spatial = ('latitude', 'longitude')
    field_min = field.min(dim=spatial)
    field_max = field.max(dim=spatial)
    return (field - field_min) / (field_max - field_min)


# High-resolution targets: full grid, normalized, plus a trailing channel axis
# of size 1 (the models expect NHWC input).
hr_data = _minmax_per_sample(press)
high_res = np.array(hr_data)
high_res_data = high_res[:, :, :, np.newaxis]

# Low-resolution inputs: every 4th point along both spatial axes (4x coarser),
# normalized with their own min/max rather than the HR field's.
lr_data = _minmax_per_sample(press[:, ::4, ::4])
low_res = np.array(lr_data)
low_res_data = low_res[:, :, :, np.newaxis]

# Cell output: shape of the low-resolution array.
low_res_data.shape


In [None]:
# Hold out 20% of all samples for testing, then 20% of what remains for
# validation. Both splits use the same fixed seed, so they are reproducible.
lr_train_full, lr_test, hr_train_full, hr_test = train_test_split(
    low_res_data, high_res_data, test_size=0.2, random_state=42
)
lr_train, lr_val, hr_train, hr_val = train_test_split(
    lr_train_full, hr_train_full, test_size=0.2, random_state=42
)

# Per-sample shapes (height, width, channels): everything but the sample axis.
hr_shape = tuple(hr_train.shape[1:])
lr_shape = tuple(lr_train.shape[1:])

# Symbolic Keras inputs shared by the generator, discriminator and GAN.
lr_ip = Input(shape=lr_shape)
hr_ip = Input(shape=hr_shape)

print(hr_ip.shape, lr_ip.shape)


In [None]:
# Residual block used in the generator trunk
def res_block(ip):
    """Apply two 3x3, 64-filter conv stages and add the block input back.

    Stage 1 is Conv2D -> BatchNormalization(momentum=0.5) -> PReLU. Stage 2 is
    the same without the activation. The skip ``add`` requires ``ip`` to
    already carry 64 channels.
    """
    out = ip
    for apply_activation in (True, False):
        out = Conv2D(64, (3, 3), padding='same')(out)
        out = BatchNormalization(momentum=0.5)(out)
        if apply_activation:
            out = PReLU(shared_axes=[1, 2])(out)
    return add([ip, out])

# Upscaling block used twice at the end of the generator
def upscale_block(ip):
    """Double the spatial size of ``ip``.

    A 3x3 Conv2D to 256 channels is followed by UpSampling2D(size=2)
    (nearest-neighbour by Keras default) and a PReLU whose slope is shared
    across the two spatial axes.
    """
    conv_out = Conv2D(256, (3, 3), padding='same')(ip)
    upsampled = UpSampling2D(size=2)(conv_out)
    return PReLU(shared_axes=[1, 2])(upsampled)

# Generator network
def create_gen(gen_ip, num_res_block):
    """Build the SRGAN-style generator as a Keras ``Model``.

    The layout is:
    - a 9x9 conv (64 filters) + PReLU head;
    - ``num_res_block`` residual blocks;
    - a 3x3 conv + BN, added back to the head output (long skip);
    - two ``upscale_block`` stages, each doubling height and width (x4 total);
    - a final 9x9 conv producing a single channel.

    Args:
        gen_ip: Keras input tensor holding the low-resolution images.
        num_res_block: number of residual blocks in the trunk.

    Returns:
        ``Model(inputs=gen_ip, outputs=<upscaled single-channel image>)``.
    """
    # Head: wide-receptive-field feature extraction.
    x = Conv2D(64, (9, 9), padding='same')(gen_ip)
    x = PReLU(shared_axes=[1, 2])(x)
    head_out = x  # long skip connection around the whole residual trunk

    # Residual trunk.
    for _ in range(num_res_block):
        x = res_block(x)

    # Trunk tail, merged with the head output.
    x = Conv2D(64, (3, 3), padding='same')(x)
    x = BatchNormalization(momentum=0.5)(x)
    x = add([x, head_out])

    # Two x2 upscaling stages.
    for _ in range(2):
        x = upscale_block(x)

    # Project back to one channel.
    op = Conv2D(1, (9, 9), padding='same')(x)
    return Model(inputs=gen_ip, outputs=op)

# Single conv unit of the discriminator
def discriminator_block(ip, filters, strides=1, bn=True):
    """Apply a 3x3 Conv2D, then LeakyReLU(0.2), then optionally BN.

    Args:
        ip: input tensor.
        filters: number of conv filters.
        strides: conv stride ('same' padding, so stride 2 halves H and W).
        bn: when True, append BatchNormalization(momentum=0.8) after the activation.
    """
    x = Conv2D(filters, (3, 3), strides=strides, padding='same')(ip)
    x = LeakyReLU(alpha=0.2)(x)
    return BatchNormalization(momentum=0.8)(x) if bn else x

# Discriminator network
def create_disc(disc_ip):
    """Build the discriminator mapping an HR image to a real/fake probability.

    The network has eight ``discriminator_block`` units. Filters double every
    two blocks (64 -> 512), and every second block uses stride 2. These are
    followed by Flatten -> Dense(1024) -> LeakyReLU(0.2) -> Dense(1, sigmoid).
    """
    df = 64  # base filter count

    # (filters, strides, batch-norm?) for each block, in build order.
    block_specs = [
        (df, 1, False),
        (df, 2, True),
        (df * 2, 1, True),
        (df * 2, 2, True),
        (df * 4, 1, True),
        (df * 4, 2, True),
        (df * 8, 1, True),
        (df * 8, 2, True),
    ]

    x = disc_ip
    for filters, strides, use_bn in block_specs:
        x = discriminator_block(x, filters, strides=strides, bn=use_bn)

    x = Flatten()(x)
    x = Dense(df * 16)(x)
    x = LeakyReLU(alpha=0.2)(x)
    validity = Dense(1, activation='sigmoid')(x)

    return Model(disc_ip, validity)

# VGG19-based feature extractor for the perceptual (content) loss
def build_vgg(hr_shape, input_shape=(144, 320, 1)):
    """Build a feature extractor that feeds single-channel images through VGG19.

    A 3x3 Conv2D first maps the single-channel input to the 3 channels VGG19
    expects. Its output then passes through the first nine layers that follow
    VGG19's own input layer (``vgg.layers[1:10]``).

    NOTE(review): inputs are min-max normalized data, not ImageNet-preprocessed
    RGB, so the pretrained features are used "as is" -- confirm this is intended.

    Args:
        hr_shape: ``(height, width, 3)`` used to instantiate the VGG19 backbone
            (ImageNet weights, no classifier head).
        input_shape: shape of the images this extractor is called on. The
            default ``(144, 320, 1)`` is the value previously hard-coded here,
            so existing callers are unaffected.

    Returns:
        A Keras ``Model`` named ``'vgg'`` from ``input_shape`` images to
        intermediate VGG19 feature maps.
    """
    from keras.applications import VGG19

    vgg = VGG19(weights='imagenet', include_top=False, input_shape=hr_shape)
    vgg_layers = vgg.layers[:10]  # index 0 is VGG's InputLayer, skipped below

    input_tensor = Input(shape=input_shape)
    # 1 -> 3 channel adapter so the pretrained conv kernels can be applied.
    out = Conv2D(3, (3, 3), padding='same')(input_tensor)

    for layer in vgg_layers[1:]:
        out = layer(out)

    return Model(inputs=input_tensor, outputs=out, name='vgg')

# Combined GAN model used for the generator update step
def create_comb(gen_model, disc_model, vgg, lr_ip, hr_ip):
    """Chain generator -> (discriminator, VGG) into one trainable graph.

    The model outputs ``[validity, vgg_features]`` for the super-resolved
    image. ``disc_model.trainable`` is set to False before the discriminator
    is wired in, so compiling the result updates only the generator.

    ``hr_ip`` is declared as a model input but is not connected to either
    output. Callers still pass HR images for it.
    """
    sr_img = gen_model(lr_ip)
    sr_features = vgg(sr_img)

    disc_model.trainable = False

    print(sr_img.shape)

    is_real = disc_model(sr_img)
    return Model(inputs=[lr_ip, hr_ip], outputs=[is_real, sr_features])


In [None]:
import random

# ---- Generator -------------------------------------------------------------
generator = create_gen(lr_ip, num_res_block=16)
generator.summary()

# Step-wise exponential decay: LR is multiplied by 0.1 every 100k optimizer steps.
gen_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.0001,
    decay_steps=100000,
    decay_rate=0.1,
    staircase=True
)
disc_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.0004,
    decay_steps=100000,
    decay_rate=0.1,
    staircase=True
)

# ---- Discriminator ---------------------------------------------------------
# Compiled with metrics=['accuracy'], so train_on_batch/evaluate return [loss, accuracy].
discriminator = create_disc(hr_ip)
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=disc_schedule),
    metrics=['accuracy']
)
discriminator.summary()

# ---- VGG feature extractor -------------------------------------------------
# Backbone built for the actual HR spatial size (previously a hard-coded
# (144, 240, 3) that did not match the data); 3 channels for VGG19.
vgg = build_vgg((hr_shape[0], hr_shape[1], 3))
vgg.summary()  # summary() prints itself and returns None
vgg.trainable = False

# ---- Combined GAN ----------------------------------------------------------
gan_model = create_comb(generator, discriminator, vgg, lr_ip, hr_ip)
gan_model.compile(
    loss=['binary_crossentropy', 'mse'],
    loss_weights=[1e-3, 1],
    optimizer=tf.keras.optimizers.Adam(learning_rate=gen_schedule)
)
gan_model.summary()

# ---- Batching --------------------------------------------------------------
batch_size = 16


def _split_batches(arr, size, drop_remainder):
    """Split ``arr`` along axis 0 into consecutive chunks of ``size`` rows.

    With ``drop_remainder=True`` a trailing chunk shorter than ``size`` is discarded.
    """
    n_full, remainder = divmod(len(arr), size)
    n_batches = n_full + (1 if remainder and not drop_remainder else 0)
    return [arr[i * size:(i + 1) * size] for i in range(n_batches)]


# Training batches must be full: the training loop builds its labels with
# exactly `batch_size` rows, so the partial trailing batch is dropped.
train_hr_batches = _split_batches(hr_train, batch_size, drop_remainder=True)
train_lr_batches = _split_batches(lr_train, batch_size, drop_remainder=True)

# Validation batches are sized from the validation set itself. They previously
# reused the training batch count, which produced mostly empty batches. The
# partial last batch is kept; the validation loop sizes its labels per batch.
val_hr_batches = _split_batches(hr_val, batch_size, drop_remainder=False)
val_lr_batches = _split_batches(lr_val, batch_size, drop_remainder=False)


In [None]:
# Number of training epochs; a generator checkpoint is written every `save_every` epochs.
# NOTE: the evaluation cell loads SRGAN_large_region_gen_e_{1..600}.h5, so keep
# save_every = 1 unless that cell is changed as well.
epochs = 600
save_every = 1


def _loss_value(result):
    """Return the scalar loss from a Keras train_on_batch/evaluate result.

    With compiled metrics or several outputs, Keras returns a list whose first
    entry is the (total) loss. Otherwise it returns a bare scalar.
    ``np.ravel`` handles both.
    """
    return float(np.ravel(result)[0])


def _mean_or_nan(values):
    """Mean of ``values``, or NaN (without a RuntimeWarning) when empty."""
    return float(np.mean(values)) if len(values) else float('nan')


# Per-epoch history, consumed by the plotting cell.
loss_d = []
loss_g = []
val_loss_g = []
val_loss_d = []  # combined (real + fake) validation loss for the discriminator

# Every training batch holds exactly `batch_size` samples (the batching cell
# drops the remainder), so the labels can be built once.
fake_label = np.zeros((batch_size, 1))
real_label = np.ones((batch_size, 1))

for e in range(epochs):
    g_losses = []
    d_losses = []

    # ---- Training ----
    for b in tqdm(range(len(train_hr_batches))):
        lr_imgs = train_lr_batches[b]
        hr_imgs = train_hr_batches[b]

        fake_imgs = generator.predict_on_batch(lr_imgs)

        # Discriminator step on generated then real images.
        discriminator.trainable = True
        d_loss_gen = discriminator.train_on_batch(fake_imgs, fake_label)
        d_loss_real = discriminator.train_on_batch(hr_imgs, real_label)
        discriminator.trainable = False

        # Each result is [loss, accuracy] (metrics=['accuracy']). Average only
        # the loss; the old np.add() also mixed accuracy into the D-loss history.
        d_loss = 0.5 * (_loss_value(d_loss_gen) + _loss_value(d_loss_real))

        # Generator step: fool the discriminator and match VGG features of HR.
        image_features = vgg.predict_on_batch(hr_imgs)
        g_loss = _loss_value(
            gan_model.train_on_batch([lr_imgs, hr_imgs], [real_label, image_features])
        )

        d_losses.append(d_loss)
        g_losses.append(g_loss)

    # ---- Validation ----
    val_g_losses = []
    val_d_losses = []
    for val_lr_imgs, val_hr_imgs in zip(val_lr_batches, val_hr_batches):
        if val_lr_imgs.size == 0 or val_hr_imgs.size == 0:
            continue

        val_fake_imgs = generator.predict_on_batch(val_lr_imgs)

        # The last validation batch may be partial, so size the labels per batch.
        current_batch_size = val_lr_imgs.shape[0]
        val_real_label = np.ones((current_batch_size, 1))
        val_fake_label = np.zeros((current_batch_size, 1))

        val_d_loss_real = discriminator.evaluate(val_hr_imgs, val_real_label, verbose=0)
        val_d_loss_fake = discriminator.evaluate(val_fake_imgs, val_fake_label, verbose=0)
        val_d_loss = 0.5 * (_loss_value(val_d_loss_real) + _loss_value(val_d_loss_fake))

        val_image_features = vgg.predict_on_batch(val_hr_imgs)
        val_g_loss = _loss_value(
            gan_model.evaluate(
                [val_lr_imgs, val_hr_imgs],
                [val_real_label, val_image_features],
                verbose=0,
            )
        )

        val_g_losses.append(val_g_loss)
        val_d_losses.append(val_d_loss)

    # ---- Epoch summary ----
    avg_g_loss = _mean_or_nan(g_losses)
    avg_d_loss = _mean_or_nan(d_losses)
    avg_val_g_loss = _mean_or_nan(val_g_losses)
    avg_val_d_loss = _mean_or_nan(val_d_losses)

    loss_g.append(avg_g_loss)
    loss_d.append(avg_d_loss)
    val_loss_g.append(avg_val_g_loss)
    val_loss_d.append(avg_val_d_loss)

    print(f'Epoch: {e + 1}, Training G Loss: {avg_g_loss}, Training D Loss: {avg_d_loss}, Validation G Loss: {avg_val_g_loss}, Validation D Loss: {avg_val_d_loss}')

    if (e + 1) % save_every == 0:
        generator.save(f'SRGAN_large_region_gen_e_{e + 1}.h5')

In [None]:
import matplotlib.pyplot as plt

# X-axis from the recorded history (1-based, matching the "Epoch: N" log lines).
# This avoids the global `epochs`, which the evaluation cell rebinds to a list.
epoch_axis = range(1, len(loss_d) + 1)

fig, (ax_d, ax_g) = plt.subplots(1, 2, figsize=(12, 6))

# Discriminator loss
ax_d.plot(epoch_axis, loss_d, label='Training D Loss')
ax_d.plot(epoch_axis, val_loss_d, label='Validation D Loss')
ax_d.set_title('Discriminator Loss Over Epochs')
ax_d.set_xlabel('Epoch')
ax_d.set_ylabel('Loss')
ax_d.legend()

# Generator loss
ax_g.plot(epoch_axis, loss_g, label='Training G Loss')
ax_g.plot(epoch_axis, val_loss_g, label='Validation G Loss')
ax_g.set_title('Generator Loss Over Epochs')
ax_g.set_xlabel('Epoch')
ax_g.set_ylabel('Loss')
ax_g.legend()

fig.tight_layout()
# Fixed: the string literal 'tight' was previously unterminated (SyntaxError).
fig.savefig('SRGAN_4_times_large_region_loss_plot.png', bbox_inches='tight')
plt.show()


In [None]:
import tensorflow as tf

# SSIM (Structural Similarity Index) metric
def ssim_loss(gen_image, tar_image):
    """Return the batch-mean SSIM of ``gen_image`` against ``tar_image``.

    Uses an 11x11 Gaussian window (sigma 1.5) with k1=0.01 and k2=0.03.
    ``max_val=1`` matches the [0, 1] min-max normalization of the data.
    Despite the name, higher is better: this is a quality score, not a loss.
    """
    ssim_options = dict(
        max_val=1,
        filter_size=11,
        filter_sigma=1.5,
        k1=0.01,
        k2=0.03,
    )
    per_image_ssim = tf.image.ssim(gen_image, tar_image, **ssim_options)
    return tf.reduce_mean(per_image_ssim)

# PSNR (Peak Signal-to-Noise Ratio) metric
def psnr_loss(gen_image, tar_image):
    """Return the batch-mean PSNR of ``gen_image`` against ``tar_image``.

    ``max_val=1`` matches the [0, 1] min-max normalization of the data.
    Despite the name, higher is better: this is a quality score, not a loss.
    """
    return tf.reduce_mean(tf.image.psnr(gen_image, tar_image, max_val=1))

# MAE (Mean Absolute Error) metric
def mae_loss(gen_image, tar_image):
    """Return the mean over the batch of each image's mean absolute error.

    Both arguments are indexed as 4-D ``(batch, height, width, channels)``
    arrays. The per-image MAE is computed first, then averaged across the batch.
    """
    n_images = len(gen_image[:, 0, 0, 0])
    per_image_mae = [
        np.mean(np.abs(gen_image[idx, :, :, :] - tar_image[idx, :, :, :]))
        for idx in range(n_images)
    ]
    return np.array(per_image_mae).mean()


In [None]:
import pandas as pd

# Number of generator checkpoints written by the training cell (one per epoch).
n_checkpoints = 600
metric_columns = ['Epochs', 'PSNR', 'SSIM', 'MAE']


def _score_split(model, lr_images, hr_images):
    """Return ``(psnr, ssim, mae)`` as Python floats for one dataset split."""
    sr_images = model.predict(lr_images, verbose=0)
    return (
        float(psnr_loss(sr_images, hr_images)),
        float(ssim_loss(sr_images, hr_images)),
        float(mae_loss(sr_images, hr_images)),
    )


# Load each checkpoint once and score it on both splits. The previous version
# loaded every checkpoint twice. It also rebound the globals `epochs` (breaking
# the loss-plot cell on re-run) and `generator` (the model wired into `gan_model`).
test_records = []
train_records = []
for epoch in range(1, n_checkpoints + 1):
    ckpt_generator = load_model(f'SRGAN_large_region_gen_e_{epoch}.h5', compile=False)
    for records, lr_split, hr_split in (
        (test_records, lr_test, hr_test),
        (train_records, lr_train, hr_train),
    ):
        psnr_val, ssim_val, mae_val = _score_split(ckpt_generator, lr_split, hr_split)
        records.append({'Epochs': epoch, 'PSNR': psnr_val, 'SSIM': ssim_val, 'MAE': mae_val})

test_loss_df = pd.DataFrame(test_records, columns=metric_columns)
test_loss_df.to_csv('SRGAN_4_times_test_loss_data.csv', index=False)

train_loss_df = pd.DataFrame(train_records, columns=metric_columns)
train_loss_df.to_csv('SRGAN_4_times_train_loss_data.csv', index=False)


In [None]:
import matplotlib.pyplot as plt
from numpy.random import randint

# Load the final (epoch-600) checkpoint under its own name so the in-memory
# `generator` that `gan_model` is built on is not replaced.
final_generator = load_model('SRGAN_large_region_gen_e_600.h5', compile=False)

X1, X2 = lr_train, hr_train
panel_titles = ('LR Image', 'Superresolved Image', 'Original HR Image')

# LR input / SR output / HR target for 10 randomly chosen training samples.
for _ in range(10):
    ix = randint(0, len(X1), 1)  # 1-element index array keeps the batch axis
    src_image, tar_image = X1[ix], X2[ix]

    gen_image = final_generator.predict(src_image, verbose=0)

    # Share the HR target's colour range across panels. Otherwise imshow
    # rescales each panel independently and SR/HR cannot be compared by eye.
    vmin, vmax = float(tar_image.min()), float(tar_image.max())

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    for ax, title, img in zip(axes, panel_titles, (src_image, gen_image, tar_image)):
        ax.imshow(img[0, :, :, 0], cmap='jet', vmin=vmin, vmax=vmax)  # single-channel data
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
from numpy.random import randint

# Load the final (epoch-600) checkpoint under its own name so the in-memory
# `generator` that `gan_model` is built on is not replaced.
final_generator = load_model('SRGAN_large_region_gen_e_600.h5', compile=False)

X1, X2 = lr_test, hr_test
panel_titles = ('LR Image', 'Superresolved Image', 'Original HR Image')

# LR input / SR output / HR target for 10 randomly chosen test samples.
for _ in range(10):
    ix = randint(0, len(X1), 1)  # 1-element index array keeps the batch axis
    src_image, tar_image = X1[ix], X2[ix]

    gen_image = final_generator.predict(src_image, verbose=0)

    # Share the HR target's colour range across panels. Otherwise imshow
    # rescales each panel independently and SR/HR cannot be compared by eye.
    vmin, vmax = float(tar_image.min()), float(tar_image.max())

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    for ax, title, img in zip(axes, panel_titles, (src_image, gen_image, tar_image)):
        ax.imshow(img[0, :, :, 0], cmap='jet', vmin=vmin, vmax=vmax)  # single-channel data
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    plt.close(fig)
