# Original code

import tensorflow as tf
# Data-parallel distribution strategy; models and optimizers created inside
# ``strategy.scope()`` further down have their variables mirrored on every
# visible GPU.
strategy = tf.distribute.MirroredStrategy()
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Add
from keras.layers import AveragePooling2D
from keras.layers import MaxPooling2D
from keras.layers import Dropout
from keras.layers import Embedding
from keras.layers import Concatenate
from keras.layers import GaussianNoise
from keras.layers import BatchNormalization
from keras.layers import LayerNormalization
from keras.layers import Conv3D
from keras.layers import ConvLSTM3D
from keras.layers import ConvLSTM2D
from keras.layers import TimeDistributed
from keras.initializers import RandomNormal
import keras.backend as K
from sklearn.utils import shuffle

def define_low_res_generator(low_res=(24, 12, 12,3)):
    """Build the low-resolution "corrector" generator (Conv3D + ConvLSTM2D).

    Two Conv3D(64) -> BatchNormalization -> LeakyReLU stages process the 5-D
    input, Gaussian noise is added, and a ConvLSTM2D layer named
    ``'LSTMLayer'`` reduces the leading (sequence) axis to a 128-channel map;
    ``define_combined_generator`` taps that layer by name. A single-filter
    Conv2D produces the low-resolution output.

    Args:
        low_res: input shape without the batch axis, laid out as
            (steps, rows, cols, channels) -- the layout ConvLSTM2D consumes.

    Returns:
        keras Model named ``"low_res_generator"``.
    """
    # Kernel initialization N(0, 0.02). It used to be created but never passed
    # to any layer; it is now applied to the convolution kernels, as the other
    # builders in this file do.
    init = RandomNormal(mean=0.0, stddev=0.02)
    in_img = Input(shape=low_res)
    # First Conv3D stage
    fir_gen = Conv3D(64, (3,3,3), strides=(1,1,1), padding='same', data_format='channels_last', kernel_initializer=init)(in_img)
    fir_gen = BatchNormalization(synchronized=False)(fir_gen)
    fir_gen = LeakyReLU(alpha=0.2)(fir_gen)
    # Second Conv3D stage
    gen = Conv3D(64, (3,3,3), strides=(1,1,1), padding='same', data_format='channels_last', kernel_initializer=init)(fir_gen)
    gen = BatchNormalization(synchronized=False)(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    # Regularizing noise (only active in training mode)
    gen = GaussianNoise(0.01)(gen)
    # ConvLSTM2D over the leading axis (per the original note: aggregates
    # hourly weather parameters to daily ones). return_sequences keeps its
    # default (False), so a single (rows, cols, 128) map is emitted.
    Conv_layer = ConvLSTM2D(128, (3,3), activation="tanh", padding='same', data_format='channels_last', name='LSTMLayer')(gen)
    # Final Conv2D produces the single-channel low-resolution output.
    out_layer = Conv2D(1, (3,3), strides=(1,1), padding='same', data_format='channels_last', kernel_initializer=init)(Conv_layer)
    model = Model(inputs=in_img, outputs=out_layer, name="low_res_generator")
    return model


def define_discriminator(in_shape=(48,48,1), n_class=5):
    """Build the label-conditioned WGAN critic.

    The integer label is embedded (50 dims), projected by a Dense layer to
    ``in_shape[0] * in_shape[1]`` values and reshaped into one extra image
    channel that is concatenated to the input image. The body is a 128-filter
    stem, two residual pairs, GaussianNoise(0.01), three stride-2 256-filter
    convolutions, then Flatten -> Dropout(0.4) -> Dense(1) (linear score).

    Args:
        in_shape: image shape (rows, cols, channels), batch axis excluded.
        n_class: vocabulary size of the label embedding.

    Returns:
        keras Model with inputs ``[image, label]`` named ``"discriminator"``.
    """
    # Label branch: embed -> project to one image-sized plane -> channel axis.
    in_label = Input(shape=(1,))
    label_plane = Embedding(n_class, 50)(in_label)
    label_plane = Dense(in_shape[0] * in_shape[1])(label_plane)
    label_plane = Reshape((in_shape[0], in_shape[1], 1))(label_plane)

    kernel_init = RandomNormal(mean=0.0, stddev=0.02)

    def conv_norm(tensor, filters=128, stride=1):
        # 3x3 same-padded convolution followed by LayerNormalization.
        tensor = Conv2D(filters, (3,3), strides=(stride, stride), padding='same', kernel_initializer=kernel_init)(tensor)
        return LayerNormalization()(tensor)

    def conv_norm_act(tensor, filters=128, stride=1):
        # conv_norm followed by LeakyReLU(0.2).
        return LeakyReLU(alpha=0.2)(conv_norm(tensor, filters, stride))

    def residual_pair(tensor):
        # Two convolutions whose output is added back onto the block input.
        branch = conv_norm(conv_norm_act(tensor))
        branch = Add()([branch, tensor])
        return LeakyReLU(alpha=0.2)(branch)

    in_img = Input(shape=in_shape)
    features = Concatenate()([in_img, label_plane])
    features = conv_norm_act(features)
    features = residual_pair(features)
    features = residual_pair(features)
    features = GaussianNoise(0.01)(features)
    # Three stride-2 downsampling stages.
    for _ in range(3):
        features = conv_norm_act(features, filters=256, stride=2)
    features = Flatten()(features)
    features = Dropout(0.4)(features)
    score = Dense(1)(features)
    return Model(inputs=[in_img, in_label], outputs=score, name="discriminator")

# Standalone upsampling generator: residual refinement then 2x -> 2x upsampling.
def define_generator(low_res=(12,12,128)):
    """Build the residual upsampling generator.

    Two residual pairs (the first adds directly onto the input, so the input
    must have 128 channels), GaussianNoise(0.01), then two stride-2
    Conv2DTranspose(256) -> BatchNormalization -> LeakyReLU stages (4x per
    spatial axis overall) and a 5x5 tanh Conv2D with a single output channel.

    Args:
        low_res: input shape (rows, cols, channels), batch axis excluded.

    Returns:
        keras Model named ``"generator"``.
    """
    kernel_init = RandomNormal(mean=0.0, stddev=0.02)

    def conv_norm(tensor):
        # 3x3 same-padded 128-filter convolution followed by LayerNormalization.
        tensor = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=kernel_init)(tensor)
        return LayerNormalization()(tensor)

    def residual_pair(tensor):
        # conv_norm -> LeakyReLU -> conv_norm, add the block input, LeakyReLU.
        branch = LeakyReLU(alpha=0.2)(conv_norm(tensor))
        branch = conv_norm(branch)
        branch = Add()([branch, tensor])
        return LeakyReLU(alpha=0.2)(branch)

    def upsample(tensor):
        # Stride-2 transposed convolution doubles rows and cols.
        tensor = Conv2DTranspose(256, (3,3), strides=(2,2), padding='same', kernel_initializer=kernel_init)(tensor)
        tensor = BatchNormalization(synchronized=False)(tensor)
        return LeakyReLU(alpha=0.2)(tensor)

    inputs = Input(shape=low_res)
    hidden = residual_pair(inputs)
    hidden = residual_pair(hidden)
    hidden = GaussianNoise(0.01)(hidden)
    hidden = upsample(hidden)
    hidden = upsample(hidden)
    outputs = Conv2D(1, (5,5), activation='tanh', padding='same', kernel_initializer=kernel_init)(hidden)
    return Model(inputs=inputs, outputs=outputs, name="generator")

def define_combined_generator(low_res_gen, generator, layer_name='LSTMLayer'):
    """Chain an intermediate layer of ``low_res_gen`` into ``generator``.

    The combined model takes ``low_res_gen``'s input, reads the output of the
    layer called ``layer_name`` inside ``low_res_gen`` and feeds it to
    ``generator``. Training the combined model therefore updates the tapped
    part of ``low_res_gen`` as well as ``generator``.

    Args:
        low_res_gen: functional Keras model providing input and features.
        generator: Keras model applied to the tapped feature map.
        layer_name: name of the layer in ``low_res_gen`` to tap. Defaults to
            ``'LSTMLayer'``, the ConvLSTM2D layer of ``define_low_res_generator``.

    Returns:
        keras Model named ``'Combined_generator'``.
    """
    in_img = low_res_gen.input
    features = low_res_gen.get_layer(layer_name).output
    out_layer = generator(features)
    return Model(inputs=in_img, outputs=out_layer, name='Combined_generator')

class WGAN(keras.Model):
    """Label-conditioned WGAN-GP trainer wrapping a generator and a critic.

    Each ``train_step`` performs ``d_steps`` critic updates (Wasserstein loss
    plus ``gp_weight`` times a gradient penalty), then one generator update.
    Training data is expected as ``(lowres_images, (real_images, labels))``.
    """

    def __init__(self, discriminator, generator, Dsteps=5, gp_weight=10.0):
        """Store the sub-models and training hyper-parameters.

        Args:
            discriminator: critic called as ``discriminator([images, labels])``.
            generator: model mapping low-resolution inputs to images.
            Dsteps: critic updates per generator update; must be >= 1.
            gp_weight: multiplier applied to the gradient penalty.

        Raises:
            ValueError: if ``Dsteps`` < 1 (``train_step`` would otherwise
                return an undefined ``d_loss``).
        """
        super(WGAN, self).__init__()
        if Dsteps < 1:
            raise ValueError("Dsteps must be >= 1, got %r" % (Dsteps,))

        self.discriminator = discriminator
        self.generator = generator
        self.d_steps = Dsteps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Attach the optimizers and loss callables used by ``train_step``."""
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images, labels):
        """Return ``mean((||grad_x D([x_hat, labels])||_2 - 1) ** 2)``.

        ``x_hat`` is a per-sample random interpolation between real and fake
        images. Images must be 4-D (batch, rows, cols, channels): ``alpha``
        broadcasts over axes 1-3 and the norm reduces over them.
        """
        alpha = tf.random.uniform([batch_size, 1, 1, 1], minval=0., maxval=1.)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.discriminator([interpolated, labels], training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        # The epsilon keeps sqrt differentiable when a gradient is exactly zero
        # (d/dx sqrt(x) is infinite at 0, which would produce NaN updates).
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, data):
        """Run ``d_steps`` critic updates then one generator update.

        Returns:
            dict with the last critic loss ``d_loss``, the generator loss
            ``g_loss`` and the (per-replica) ``batch_size``.
        """
        lowres_images = data[0]
        real_images = data[1][0]
        labels = data[1][1]
        batch_size = tf.shape(real_images)[0]

        for _ in range(self.d_steps):
            with tf.GradientTape() as tape:
                fake_images = self.generator(lowres_images, training=True)
                fake_logits = self.discriminator([fake_images, labels], training=True)
                real_logits = self.discriminator([real_images, labels], training=True)
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # The penalty must be computed under this tape: its own inner
                # gradient has to be differentiated w.r.t. the critic weights.
                gp = self.gradient_penalty(batch_size, real_images, fake_images, labels)
                d_loss = d_cost + gp * self.gp_weight
            # Differentiate and update outside the tape context so the
            # gradient computation and variable updates are not recorded.
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(zip(d_gradient, self.discriminator.trainable_variables))

        # Generator update
        with tf.GradientTape() as tape:
            generated_images = self.generator(lowres_images, training=True)
            gen_img_logits = self.discriminator([generated_images, labels], training=True)
            g_loss = self.g_loss_fn(gen_img_logits)
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))
        return {"d_loss": d_loss, "g_loss": g_loss, 'batch_size': batch_size}

    

# Callback that checkpoints the combined generator and records losses per epoch.
class GANMonitor(keras.callbacks.Callback):
    """Save the module-level ``g_model`` and record WGAN losses after each epoch.

    Args:
        gen_loss: list that receives the epoch's ``g_loss`` log value.
        critic_loss: list that receives the epoch's ``d_loss`` log value.
    """

    def __init__(self, gen_loss, critic_loss):
        # Callback.__init__ initialises the attributes Keras sets/reads later.
        super().__init__()
        self.gen_loss = gen_loss
        self.critic_loss = critic_loss

    def on_epoch_end(self, epoch, logs=None):
        # A None default avoids sharing one mutable dict between calls.
        logs = logs or {}
        # NOTE: relies on the global ``g_model`` existing when training runs.
        g_model.save('Pretraining_2_3Channels%.1f.keras' % epoch)
        self.gen_loss.append(logs.get('g_loss'))
        self.critic_loss.append(logs.get('d_loss'))
        
# Lists that GANMonitor appends the per-epoch WGAN-GP losses to.
gen_loss = []
critic_loss = []
# Build and compile every model inside the MirroredStrategy scope so their
# variables are mirrored across all visible GPUs.
with strategy.scope():
    # Model architecture
    cbk = GANMonitor(gen_loss, critic_loss)
    pretraining = define_low_res_generator()
    generator = define_generator()
    g_model = define_combined_generator(pretraining, generator)
    d_model = define_discriminator()

    # One Adam optimizer per model / role.
    pretraining_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.2, beta_2=0.9)
    g_model_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.2, beta_2=0.9)
    generator_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.2, beta_2=0.9)
    discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.2, beta_2=0.9)

    def MSE_weighted(y_true, y_pred):
        """Mean of ``y_true**2 * (y_pred - y_true)**2`` over all elements."""
        return tf.reduce_mean(tf.multiply(tf.square(y_true), tf.square(tf.subtract(y_pred, y_true))))

    def _plain_mse(y_true, y_pred):
        """Unweighted mean squared error over all elements."""
        return tf.reduce_mean(tf.square(tf.subtract(y_pred, y_true)))

    def _soft_fss_huber_loss(y_true, y_pred, gamma, mse_fn):
        """Huber(delta=0.1) plus ``gamma`` times a soft-binarized FSS error ratio.

        Both fields are soft-binarized as ``sigmoid(10 * (x - 0.5))`` giving O
        (truth) and M (prediction); the ratio is
        ``mse_fn(O, M) / (mean(O**2 + M**2) + eps)``. Everything stays a
        tensor, so the loss works inside ``model.fit``'s ``tf.function``; the
        previous ``float(tensor)`` conversions fail there.
        """
        c = 10
        cutoff = 0.5
        eps = K.epsilon()
        y_true = tf.cast(y_true, y_pred.dtype)
        y_true_bi = tf.math.sigmoid(c * (y_true - cutoff))
        y_pred_bi = tf.math.sigmoid(c * (y_pred - cutoff))
        mse_n = tf.cast(mse_fn(y_true_bi, y_pred_bi), y_pred.dtype)
        # Equivalent to the old Multiply/Flatten layers, without creating new
        # layer objects on every loss call.
        mse_ref = tf.reduce_mean(tf.square(y_true_bi) + tf.square(y_pred_bi))
        huber = tf.keras.losses.huber(y_true, y_pred, delta=0.1)
        return tf.math.reduce_mean(huber + gamma * (mse_n / (mse_ref + eps)))

    def corrector_loss(y_true, y_pred):
        """Pre-training loss: Huber + 0.1 * soft-FSS ratio using plain MSE."""
        return _soft_fss_huber_loss(y_true, y_pred, gamma=0.1, mse_fn=_plain_mse)

    def gen_fss(y_true, y_pred):
        """Combined-generator loss: Huber + 10 * soft-FSS ratio using weighted MSE."""
        return _soft_fss_huber_loss(y_true, y_pred, gamma=10, mse_fn=MSE_weighted)

    def discriminator_loss(real_img, fake_img):
        """Critic loss without the gradient-penalty term: mean(fake) - mean(real)."""
        real_loss = tf.reduce_mean(real_img)
        fake_loss = tf.reduce_mean(fake_img)
        return fake_loss - real_loss

    def generator_loss(fake_img):
        """Generator loss: negative mean critic score on generated images."""
        fake_loss = -tf.reduce_mean(fake_img)
        return fake_loss

    # Compile the pre-training and combined-generator models.
    pretraining.compile(optimizer=pretraining_optimizer, loss=corrector_loss)
    g_model.compile(optimizer=g_model_optimizer, loss=gen_fss)
    # WGAN-GP wrapper around the combined generator and the critic.
    wgan = WGAN(discriminator=d_model, generator=g_model, Dsteps=5)
    wgan.compile(d_optimizer=discriminator_optimizer, g_optimizer=generator_optimizer, g_loss_fn=generator_loss, d_loss_fn=discriminator_loss)

# Coarsen the observational data into the low-resolution pre-training target.
def Pooling(High_Res_Data):
    """Downsample 4x per spatial axis with two successive 2x2 max-pools.

    One MaxPooling2D layer (pool 2, stride 2, 'same' padding) is applied
    twice to ``High_Res_Data`` and the pooled result is returned.
    """
    max_pool = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')
    pooled = High_Res_Data
    for _ in range(2):
        pooled = max_pool(pooled)
    return pooled

# Load inputs, high-resolution targets and labels, then shuffle all three
# arrays with one shared permutation.
Input_data = np.load("/kaggle/input/threechannelinput/Inputfile3channels.npy")
Highres_data = np.load("/kaggle/input/threechannelinput/Highresfile3channels.npy")
labels_data = np.load("/kaggle/input/threechannelinput/labelfile3channels.npy")
Low_Res_Data, High_Res_Data, labels = shuffle(Input_data, Highres_data, labels_data)
print(Low_Res_Data.shape, High_Res_Data.shape, labels_data.shape)

# Epoch counts and batch sizes for pre-training and WGAN training.
pre_epoch, pre_batch_size = 10, 128
training_epoch, training_batch_size = 32, 64
total_samples = High_Res_Data.shape[0]

# Steps per epoch: whole batches in one pass over the data (remainder dropped).
stpe_wgan = total_samples // training_batch_size
stpe_pretrain = total_samples // pre_batch_size

# 4x max-pooled copy of the targets for the low-res pre-training phase.
Maxpooled_data = Pooling(High_Res_Data)

# Phase 1: low-res generator; phase 2: combined generator; phase 3: WGAN-GP.
pretraining.fit(x=Low_Res_Data, y=Maxpooled_data, batch_size=pre_batch_size, epochs=pre_epoch, steps_per_epoch=stpe_pretrain, verbose=2)
g_model.fit(x=Low_Res_Data, y=High_Res_Data, batch_size=pre_batch_size, epochs=pre_epoch, steps_per_epoch=stpe_pretrain, verbose=2)
wgan.fit(x=Low_Res_Data, y=[High_Res_Data, labels], batch_size=training_batch_size, epochs=training_epoch, callbacks=[cbk], steps_per_epoch=stpe_wgan, verbose=2)

# Edit in progress below

# In [None]:  (notebook cell marker)
import tensorflow as tf
# Data-parallel distribution strategy; models and optimizers created inside
# ``strategy.scope()`` further down have their variables mirrored on every
# visible GPU.
strategy = tf.distribute.MirroredStrategy()
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Add
from keras.layers import AveragePooling2D
from keras.layers import MaxPooling2D
from keras.layers import Dropout
from keras.layers import Embedding
from keras.layers import Concatenate
from keras.layers import GaussianNoise
from keras.layers import BatchNormalization
from keras.layers import LayerNormalization
from keras.layers import Conv3D
from keras.layers import ConvLSTM3D
from keras.layers import ConvLSTM2D
from keras.layers import TimeDistributed
from keras.initializers import RandomNormal
import keras.backend as K
from sklearn.utils import shuffle

def define_low_res_generator(low_res=(12, 12, 1)):
    """Build the residual low-resolution generator used for pre-training.

    Layout: 128-filter stem, residual pair, GaussianNoise(0.02), and a second
    residual pair whose final activation is named ``"gen_layer"``. That layer
    is the feature map ``define_combined_generator`` feeds to the upsampling
    generator. A single-filter tanh Conv2D produces the output.

    Args:
        low_res: input shape (rows, cols, channels), batch axis excluded.

    Returns:
        keras Model named ``"low_res_generator"``.
    """
    kernel_init = RandomNormal(mean=0.0, stddev=0.02)

    def conv_norm(tensor):
        # 3x3 same-padded 128-filter convolution followed by LayerNormalization.
        tensor = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=kernel_init)(tensor)
        return LayerNormalization()(tensor)

    def conv_norm_act(tensor):
        # conv_norm followed by LeakyReLU(0.2).
        return LeakyReLU(alpha=0.2)(conv_norm(tensor))

    inputs = Input(shape=low_res)
    stem = conv_norm_act(inputs)

    # Residual pair 1 (skip from the stem), followed by noise.
    block = conv_norm(conv_norm_act(stem))
    block = Add()([block, stem])
    block = LeakyReLU(alpha=0.2)(block)
    block = GaussianNoise(0.02)(block)

    # Residual pair 2 (skip from the noisy pair-1 output); named activation.
    skip = block
    block = conv_norm(conv_norm_act(skip))
    block = Add()([block, skip])
    features = LeakyReLU(alpha=0.2, name="gen_layer")(block)

    outputs = Conv2D(1, (3,3), strides=(1,1), activation='tanh', padding='same')(features)
    return Model(inputs=inputs, outputs=outputs, name="low_res_generator")


def define_discriminator(in_shape=(48,48,1)):
    """Build the unconditional WGAN critic.

    A 128-filter stem, two residual pairs, GaussianNoise(0.01), three
    stride-2 256-filter convolutions, then Flatten -> Dropout(0.3) ->
    Dense(1) (linear score).

    Args:
        in_shape: image shape (rows, cols, channels), batch axis excluded.

    Returns:
        keras Model named ``"discriminator"``.
    """
    kernel_init = RandomNormal(mean=0.0, stddev=0.02)

    def conv_norm(tensor, filters=128, stride=1):
        # 3x3 same-padded convolution followed by LayerNormalization.
        tensor = Conv2D(filters, (3,3), strides=(stride, stride), padding='same', kernel_initializer=kernel_init)(tensor)
        return LayerNormalization()(tensor)

    def conv_norm_act(tensor, filters=128, stride=1):
        # conv_norm followed by LeakyReLU(0.2).
        return LeakyReLU(alpha=0.2)(conv_norm(tensor, filters, stride))

    def residual_pair(tensor):
        # Two convolutions whose output is added back onto the block input.
        branch = conv_norm(conv_norm_act(tensor))
        branch = Add()([branch, tensor])
        return LeakyReLU(alpha=0.2)(branch)

    in_image = Input(shape=in_shape)
    features = conv_norm_act(in_image)
    features = residual_pair(features)
    features = residual_pair(features)
    features = GaussianNoise(0.01)(features)
    # Three stride-2 downsampling stages.
    for _ in range(3):
        features = conv_norm_act(features, filters=256, stride=2)
    features = Flatten()(features)
    features = Dropout(0.3)(features)
    score = Dense(1)(features)
    return Model(inputs=in_image, outputs=score, name="discriminator")

# Standalone upsampling generator: residual refinement then 2x -> 2x upsampling.
def define_generator(low_res=(12,12,128)):
    """Build the residual upsampling generator.

    Two residual pairs (the first adds directly onto the input, so the input
    must have 128 channels), GaussianNoise(0.01), then two stride-2
    Conv2DTranspose(256) -> LayerNormalization -> LeakyReLU stages (4x per
    spatial axis overall) and a 5x5 tanh Conv2D with a single output channel.

    Args:
        low_res: input shape (rows, cols, channels), batch axis excluded.

    Returns:
        keras Model named ``"generator"``.
    """
    kernel_init = RandomNormal(mean=0.0, stddev=0.02)

    def conv_norm(tensor):
        # 3x3 same-padded 128-filter convolution followed by LayerNormalization.
        tensor = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=kernel_init)(tensor)
        return LayerNormalization()(tensor)

    def residual_pair(tensor):
        # conv_norm -> LeakyReLU -> conv_norm, add the block input, LeakyReLU.
        branch = LeakyReLU(alpha=0.2)(conv_norm(tensor))
        branch = conv_norm(branch)
        branch = Add()([branch, tensor])
        return LeakyReLU(alpha=0.2)(branch)

    def upsample(tensor):
        # Stride-2 transposed convolution doubles rows and cols.
        tensor = Conv2DTranspose(256, (3,3), strides=(2,2), padding='same', kernel_initializer=kernel_init)(tensor)
        tensor = LayerNormalization()(tensor)
        return LeakyReLU(alpha=0.2)(tensor)

    inputs = Input(shape=low_res)
    hidden = residual_pair(inputs)
    hidden = residual_pair(hidden)
    hidden = GaussianNoise(0.01)(hidden)
    hidden = upsample(hidden)
    hidden = upsample(hidden)
    outputs = Conv2D(1, (5,5), activation='tanh', padding='same', kernel_initializer=kernel_init)(hidden)
    return Model(inputs=inputs, outputs=outputs, name="generator")

def define_combined_generator(low_res_gen, generator, layer_name="gen_layer"):
    """Chain an intermediate layer of ``low_res_gen`` into ``generator``.

    The combined model takes ``low_res_gen``'s input, reads the output of the
    layer called ``layer_name`` inside ``low_res_gen`` and feeds it to
    ``generator``. Training the combined model therefore updates the tapped
    part of ``low_res_gen`` as well as ``generator``.

    Args:
        low_res_gen: functional Keras model providing input and features.
        generator: Keras model applied to the tapped feature map.
        layer_name: name of the layer in ``low_res_gen`` to tap. Defaults to
            ``"gen_layer"``, the named activation of ``define_low_res_generator``.

    Returns:
        keras Model named ``'Combined_generator'``.
    """
    in_img = low_res_gen.input
    features = low_res_gen.get_layer(layer_name).output
    out_layer = generator(features)
    return Model(inputs=in_img, outputs=out_layer, name='Combined_generator')

class WGAN(keras.Model):
    """Unconditional WGAN-GP trainer wrapping a generator and a critic.

    Each ``train_step`` performs ``d_steps`` critic updates (Wasserstein loss
    plus ``gp_weight`` times a gradient penalty), then one generator update.
    Training data is expected as ``(lowres_images, real_images)``.
    """

    def __init__(self, discriminator, generator, Dsteps=5, gp_weight=10.0):
        """Store the sub-models and training hyper-parameters.

        Args:
            discriminator: critic model mapping images to scores.
            generator: model mapping low-resolution inputs to images.
            Dsteps: critic updates per generator update; must be >= 1.
            gp_weight: multiplier applied to the gradient penalty.

        Raises:
            ValueError: if ``Dsteps`` < 1 (``train_step`` would otherwise
                return an undefined ``d_loss``).
        """
        super(WGAN, self).__init__()
        if Dsteps < 1:
            raise ValueError("Dsteps must be >= 1, got %r" % (Dsteps,))

        self.discriminator = discriminator
        self.generator = generator
        self.d_steps = Dsteps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Attach the optimizers and loss callables used by ``train_step``."""
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Return ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``.

        ``x_hat`` is a per-sample random interpolation between real and fake
        images. Images must be 4-D (batch, rows, cols, channels): ``alpha``
        broadcasts over axes 1-3 and the norm reduces over them.
        """
        alpha = tf.random.uniform([batch_size, 1, 1, 1], minval=0., maxval=1.)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.discriminator(interpolated, training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        # The epsilon keeps sqrt differentiable when a gradient is exactly zero
        # (d/dx sqrt(x) is infinite at 0, which would produce NaN updates).
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, data):
        """Run ``d_steps`` critic updates then one generator update.

        Returns:
            dict with the last critic loss ``d_loss``, the generator loss
            ``g_loss`` and the (per-replica) ``batch_size``.
        """
        lowres_images = data[0]
        real_images = data[1]
        batch_size = tf.shape(real_images)[0]

        for _ in range(self.d_steps):
            with tf.GradientTape() as tape:
                fake_images = self.generator(lowres_images, training=True)
                fake_logits = self.discriminator(fake_images, training=True)
                real_logits = self.discriminator(real_images, training=True)
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # The penalty must be computed under this tape: its own inner
                # gradient has to be differentiated w.r.t. the critic weights.
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                d_loss = d_cost + gp * self.gp_weight
            # Differentiate and update outside the tape context so the
            # gradient computation and variable updates are not recorded.
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(zip(d_gradient, self.discriminator.trainable_variables))

        # Generator update
        with tf.GradientTape() as tape:
            generated_images = self.generator(lowres_images, training=True)
            gen_img_logits = self.discriminator(generated_images, training=True)
            g_loss = self.g_loss_fn(gen_img_logits)
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))
        return {"d_loss": d_loss, "g_loss": g_loss, 'batch_size': batch_size}

    

# Callback that checkpoints all three models and records losses per epoch.
class GANMonitor(keras.callbacks.Callback):
    """Save ``g_model``, ``d_model`` and ``pretraining`` and record WGAN losses.

    The three models are module-level globals looked up at epoch end.

    Args:
        gen_loss: list that receives the epoch's ``g_loss`` log value.
        critic_loss: list that receives the epoch's ``d_loss`` log value.
    """

    def __init__(self, gen_loss, critic_loss):
        # Callback.__init__ initialises the attributes Keras sets/reads later.
        super().__init__()
        self.gen_loss = gen_loss
        self.critic_loss = critic_loss

    def on_epoch_end(self, epoch, logs=None):
        # A None default avoids sharing one mutable dict between calls.
        logs = logs or {}
        g_model.save('TempSRgencomb%.1f.keras' % epoch)
        d_model.save('TempSRdisc%.1f.keras' % epoch)
        pretraining.save('TempSRpretrain%.1f.keras' % epoch)
        self.gen_loss.append(logs.get('g_loss'))
        self.critic_loss.append(logs.get('d_loss'))
        
# Lists that GANMonitor appends the per-epoch WGAN-GP losses to.
gen_loss = []
critic_loss = []
# Build and compile every model inside the MirroredStrategy scope so their
# variables are mirrored across all visible GPUs.
with strategy.scope():
    # Model architecture
    cbk = GANMonitor(gen_loss, critic_loss)
    pretraining = define_low_res_generator()
    generator = define_generator()
    g_model = define_combined_generator(pretraining, generator)
    d_model = define_discriminator()

    # One Adam optimizer per model / role.
    pretraining_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.2, beta_2=0.9)
    generator_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.2, beta_2=0.9)
    discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.2, beta_2=0.9)

    def MSE_weighted(y_true, y_pred):
        """Mean of ``y_true**2 * (y_pred - y_true)**2`` over all elements."""
        return tf.reduce_mean(tf.multiply(tf.square(y_true), tf.square(tf.subtract(y_pred, y_true))))

    def _soft_fss_huber_loss(y_true, y_pred, gamma, mse_fn):
        """Huber(delta=0.1) plus ``gamma`` times a soft-binarized FSS error ratio.

        Both fields are soft-binarized as ``sigmoid(10 * (x - 0.5))`` giving O
        (truth) and M (prediction); the ratio is
        ``mse_fn(O, M) / (mean(O**2 + M**2) + eps)``. Everything stays a
        tensor, so the loss works inside ``model.fit``'s ``tf.function``; the
        previous ``float(tensor)`` conversions fail there.
        """
        c = 10
        cutoff = 0.5
        eps = K.epsilon()
        y_true = tf.cast(y_true, y_pred.dtype)
        y_true_bi = tf.math.sigmoid(c * (y_true - cutoff))
        y_pred_bi = tf.math.sigmoid(c * (y_pred - cutoff))
        mse_n = tf.cast(mse_fn(y_true_bi, y_pred_bi), y_pred.dtype)
        # Equivalent to the old Multiply/Flatten layers, without creating new
        # layer objects on every loss call.
        mse_ref = tf.reduce_mean(tf.square(y_true_bi) + tf.square(y_pred_bi))
        huber = tf.keras.losses.huber(y_true, y_pred, delta=0.1)
        return tf.math.reduce_mean(huber + gamma * (mse_n / (mse_ref + eps)))

    def corrector_loss(y_true, y_pred):
        """Pre-training loss: Huber + 0.1 * soft-FSS ratio using weighted MSE."""
        return _soft_fss_huber_loss(y_true, y_pred, gamma=0.1, mse_fn=MSE_weighted)

    def discriminator_loss(real_img, fake_img):
        """Critic loss without the gradient-penalty term: mean(fake) - mean(real)."""
        real_loss = tf.reduce_mean(real_img)
        fake_loss = tf.reduce_mean(fake_img)
        return fake_loss - real_loss

    def generator_loss(fake_img):
        """Generator loss: negative mean critic score on generated images."""
        fake_loss = -tf.reduce_mean(fake_img)
        return fake_loss

    # Compile the pre-training model.
    pretraining.compile(optimizer=pretraining_optimizer, loss=corrector_loss)
    # WGAN-GP wrapper around the combined generator and the critic.
    wgan = WGAN(discriminator=d_model, generator=g_model, Dsteps=5)
    wgan.compile(d_optimizer=discriminator_optimizer, g_optimizer=generator_optimizer, g_loss_fn=generator_loss, d_loss_fn=discriminator_loss)

# Coarsen the observational data into the low-resolution pre-training target.
def Pooling(High_Res_Data):
    """Downsample 4x per spatial axis with two successive 2x2 max-pools.

    One MaxPooling2D layer (pool 2, stride 2, 'same' padding) is applied
    twice to ``High_Res_Data`` and the pooled result is returned.
    """
    max_pool = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')
    pooled = High_Res_Data
    for _ in range(2):
        pooled = max_pool(pooled)
    return pooled

# Load the paired training arrays and shuffle them consistently (same permutation).
Input_data = np.load("/kaggle/input/tempimdera/Inputfiletemp.npy")
Highres_data = np.load("/kaggle/input/tempimdera/Highresfiletemp.npy")
Low_Res_Data, High_Res_Data = shuffle(Input_data, Highres_data)
print(Low_Res_Data.shape, High_Res_Data.shape)

# Epochs and batch sizes for the pre-training and WGAN phases.
pre_epoch, pre_batch_size = 10, 128
training_epoch, training_batch_size = 20, 128
total_samples = High_Res_Data.shape[0]

# Steps per epoch = number of full batches, so every epoch sees the same number of samples.
stpe_wgan = total_samples // training_batch_size
stpe_pretrain = total_samples // pre_batch_size

# Coarsened targets used as labels for pre-training.
Maxpooled_data = Pooling(High_Res_Data)

# Phase 1: pre-training. Phase 2: adversarial (WGAN) training.
pretraining.fit(
    Low_Res_Data,
    Maxpooled_data,
    batch_size=pre_batch_size,
    epochs=pre_epoch,
    steps_per_epoch=stpe_pretrain,
    verbose=2,
)
wgan.fit(
    Low_Res_Data,
    High_Res_Data,
    batch_size=training_batch_size,
    epochs=training_epoch,
    callbacks=[cbk],
    steps_per_epoch=stpe_wgan,
    verbose=2,
)

import tensorflow as tf
import pandas as pd
import numpy as np
from numpy import expand_dims
import numpy.ma as ma
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.cm as mtpltcm
from mpl_toolkits.mplot3d import Axes3D
import random
import folium
from folium.plugins import HeatMap, HeatMapWithTime
from folium import plugins
from netCDF4 import Dataset
import cartopy.crs as ccrs
from tensorflow import keras
from keras.layers import AveragePooling2D
from keras.layers import MaxPooling2D
from keras.layers import Add
from keras.models import load_model
import numpy.ma as ma
from warnings import filterwarnings
import glob
import os

# Evaluation data: coarse test inputs, matching high-resolution targets, and the trained model.
INPUT = np.load("/kaggle/input/testfile/Inputfiletemptest.npy")
OUTPUT = np.load("/kaggle/input/testfile/Highresfiletemptest.npy")
model = load_model("/kaggle/input/model1/TempSR11.0.keras")
IMAGE = 4567  # index of the test sample to visualise

# Ground truth: mask values above 50 (presumably fill/no-data values; confirm), then shape to 48x48.
era = np.reshape(np.ma.masked_greater(OUTPUT[IMAGE], 50), (48, 48))

# Model input: a batch of one 12x12 single-channel sample.
gen = np.reshape(INPUT[IMAGE], (1, 12, 12, 1))
OUT = np.ma.masked_greater(np.reshape(model.predict(gen), (48, 48)), 50)
IN = np.ma.masked_greater(np.reshape(gen, (12, 12)), 50)

# Cell-centre coordinates of the coarse input grid (1-unit spacing, presumably degrees).
imdlat = np.arange(7.5, 39.5, 1)
imdlong = np.arange(67.5, 99.5, 1)

# Each region's (longitude, latitude) slice of the coarse grid: 12 x 12 points.
(iln5, ilt5) = (imdlong[5:17], imdlat[0:12])
(iln2, ilt2) = (imdlong[0:12], imdlat[10:22])
(iln3, ilt3) = (imdlong[10:22], imdlat[10:22])
(iln4, ilt4) = (imdlong[20:32], imdlat[12:24])
(iln1, ilt1) = (imdlong[4:16], imdlat[20:32])

# Cell-centre coordinates of the fine target grid (0.25-unit spacing, 4x finer).
eralat = np.arange(7, 39, 0.25)
eralong = np.arange(67.25, 99.25, 0.25)

# The same regions on the fine grid: 48 x 48 points.
(ln5, lt5) = (eralong[20:68], eralat[0:48])
(ln2, lt2) = (eralong[0:48], eralat[40:88])
(ln3, lt3) = (eralong[40:88], eralat[40:88])
(ln4, lt4) = (eralong[80:128], eralat[48:96])
(ln1, lt1) = (eralong[16:64], eralat[80:128])

# Region lookup, indexed by IMAGE // 1096 (presumably 1096 consecutive samples
# per region; confirm). Any quotient outside 0..3 falls back to region 5.
_IMD_REGIONS = [(iln1, ilt1), (iln2, ilt2), (iln3, ilt3), (iln4, ilt4), (iln5, ilt5)]
_ERA_REGIONS = [(ln1, lt1), (ln2, lt2), (ln3, lt3), (ln4, lt4), (ln5, lt5)]
region_idx = IMAGE // 1096
if not 0 <= region_idx <= 3:
    region_idx = 4
# Unpack as (longitude, latitude): longitude is the map's x-axis, latitude its y-axis.
(ilong, ilat) = _IMD_REGIONS[region_idx]
(elong, elat) = _ERA_REGIONS[region_idx]

# 50 evenly spaced contour levels over the model output's (masked) value range.
levelout = np.linspace(np.amin(OUT), np.amax(OUT), 50)
# Same over the ground truth's range. Not used below: all panels share levelout.
levelera = np.linspace(np.amin(era), np.amax(era), 50)


def _plot_on_map(lon, lat, field, levels):
    """Draw ``field`` as filled contours on a PlateCarree map and show it.

    ``lon`` and ``lat`` are the 1-D x/y coordinates of ``field``'s columns and
    rows. Values outside ``levels`` get the end colours (``extend='both'``)
    instead of being left blank.
    """
    plt.style.use("seaborn-v0_8-bright")
    plt.figure(figsize=(10, 10))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.coastlines(resolution="10m", linewidth=1)
    ax.gridlines(linestyle='--', color='black', linewidth=2)
    # Passing only levels=: a positional level count alongside levels= is silently ignored.
    filled = plt.contourf(lon, lat, field, levels=levels, extend='both',
                          transform=ccrs.PlateCarree(), cmap='terrain')
    cbar = plt.colorbar(filled, ax=ax, orientation="vertical", pad=0.05, aspect=16, shrink=.8)
    cbar.ax.tick_params(labelsize=12)
    plt.tight_layout()
    plt.show()


# Ground truth, model output and model input, all on the model output's colour
# scale so the panels can be compared directly.
_plot_on_map(elong, elat, era, levelout)
_plot_on_map(elong, elat, OUT, levelout)
_plot_on_map(ilong, ilat, IN, levelout)