In [1]:
# pix2pix-style conditional GAN for satellite cloud removal (cloudy image -> cloud-free image translation)
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import *
from keras.layers import *
# Explicit TensorFlow import: psnr/ssim and define_gan use `tf`, which was otherwise
# only reachable because the keras.layers star import happens to leak it.
import tensorflow as tf

# Configuration and image-quality metrics

In [2]:
# Side length of the square RGB images fed to the generator / discriminator.
input_size = 512  # 512x512x3 images (training_set)
# Number of image pairs per training step.
batch_size = 1

In [3]:
# Image-quality metrics tracked when compiling the composite GAN (see define_gan).
# Both expect pixel values in [0, max_val]; load_real_samples scales images to [0, 1].
def psnr(y_true, y_pred, max_val=1.0):
    """Peak signal-to-noise ratio between `y_true` and `y_pred`, per image (dB).

    Args:
        y_true, y_pred: image batches with values in [0, max_val].
        max_val: dynamic range of the pixel values (1.0 for [0, 1] images).
    """
    return tf.image.psnr(y_true, y_pred, max_val)


def ssim(y_true, y_pred, max_val=1.0):
    """Structural similarity index between `y_true` and `y_pred`, per image.

    Args:
        y_true, y_pred: image batches with values in [0, max_val].
        max_val: dynamic range of the pixel values (1.0 for [0, 1] images).
    """
    return tf.image.ssim(y_true, y_pred, max_val)

In [4]:
# load and prepare training images
def _load_rgb01(path):
    """Read one image file as a float32 RGB array scaled to [0, 1]."""
    # `with` closes the file handle; convert("RGB") also copes with grayscale /
    # palette images and drops an alpha channel, like the old [:, :, :3] slice did.
    with Image.open(path) as img:
        return np.asarray(img.convert("RGB"), dtype="float32") / 255


def load_real_samples(root_dir):
    """Load paired cloudy / cloud-free images from `root_dir`.

    `root_dir` must contain two sub-folders, ``cloudy_image`` and ``ground_truth``.
    Files are paired by sorted filename order, so the two folders must use
    names that sort identically (e.g. the same filenames in both).

    Args:
        root_dir: dataset split directory.

    Returns:
        [X1, X2]: lists of float32 RGB arrays in [0, 1]; X1 holds the cloudy
        images and X2 the matching ground-truth images.

    Raises:
        ValueError: if the two folders hold a different number of files.
    """
    cloud_imgdir = os.path.join(root_dir, "cloudy_image")
    clear_imgdir = os.path.join(root_dir, "ground_truth")
    # os.listdir order is arbitrary; sorting both listings makes the i-th cloudy
    # image pair with the i-th ground-truth image deterministically.
    cloud_names = sorted(os.listdir(cloud_imgdir))
    clear_names = sorted(os.listdir(clear_imgdir))
    if len(cloud_names) != len(clear_names):
        raise ValueError(
            f"{cloud_imgdir} has {len(cloud_names)} files but {clear_imgdir} has "
            f"{len(clear_names)}; cannot pair them"
        )
    # Absolute paths instead of os.chdir: the working directory is left untouched.
    X1 = [_load_rgb01(os.path.join(cloud_imgdir, name)) for name in cloud_names]
    X2 = [_load_rgb01(os.path.join(clear_imgdir, name)) for name in clear_names]
    return [X1, X2]

In [None]:
# Training split of the RICE1 thin-cloud dataset (cloudy_image/ and ground_truth/ sub-folders).
# Set the RICE1_TRAIN_DIR environment variable to load it from another location;
# the default is the original author's local path.
root_dir = os.environ.get(
    "RICE1_TRAIN_DIR",
    "C:/Users/ArrunPersonal/Codes/ISRO_GAN/RICE Dataset/RICE1_ThinCloud/RICE1Train",
)
# dataset[0] = cloudy images, dataset[1] = matching ground-truth images.
dataset = np.array(load_real_samples(root_dir))
print("Loaded", dataset[0].shape, dataset[1].shape)

## PatchGAN Discriminator

In [None]:
def define_discriminator(image_shape, learning_rate=0.0002, beta_1=0.5):
    """Build and compile the PatchGAN discriminator.

    The source and target images are concatenated channel-wise and passed
    through four stride-2 4x4 convolutions (64-128-256-512 filters), a stride-1
    512-filter convolution and a 1-filter sigmoid convolution. For a 512x512
    input this yields a 32x32x1 map of real/fake probabilities, one per patch.

    Args:
        image_shape: shape of a single image, e.g. (512, 512, 3).
        learning_rate: Adam learning rate.
        beta_1: Adam beta_1.

    Returns:
        A compiled Keras Model taking [source_image, target_image] and trained
        with binary cross-entropy, with the loss weighted by 0.5.
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source and target image inputs, concatenated channel-wise
    in_src_image = Input(shape=image_shape)
    in_target_image = Input(shape=image_shape)
    merged = Concatenate()([in_src_image, in_target_image])
    # C64 (no batch norm on the first layer)
    d = Conv2D(64, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer (stride 1: spatial size unchanged)
    d = Conv2D(512, (4, 4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    model = Model([in_src_image, in_target_image], patch_out)
    # `learning_rate`, not the deprecated `lr` alias: newer Keras versions either
    # reject `lr` or warn and ignore it (falling back to the default rate).
    opt = Adam(learning_rate=learning_rate, beta_1=beta_1)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model
 
# Stand-alone discriminator for the configured image size, built only to inspect its layers.
image_shape = (input_size, input_size, 3)
model = define_discriminator(image_shape)
model.summary()

## RS_Cloud GAN proposed generator: encoder, decoder and attention-bottleneck blocks


In [9]:
def encoder_block(input_features, num_filters, filter_size=(3, 5, 7)):
    """One RS_Cloud GAN encoder stage: multi-scale convs, pairwise fusion, 2x downsampling.

    Three parallel ReLU convolutions with kernel sizes filter_size[0..2] are
    concatenated pairwise ((a, b), (a, c), (b, c)). Each pair is fused by a 3x3
    conv, and the three fused maps are concatenated and reduced by a final 3x3
    conv before MaxPool2D(2).

    Args:
        input_features: rank-4 Keras tensor (channels-last assumed).
        num_filters: number of filters in every convolution of the stage.
        filter_size: kernel sizes of the three parallel branches (only the first
            three entries are used).

    Returns:
        Tensor with half the spatial size of `input_features` and `num_filters` channels.
    """
    branch_a = Conv2D(num_filters, filter_size[0], padding='same', activation='relu')(input_features)
    branch_b = Conv2D(num_filters, filter_size[1], padding='same', activation='relu')(input_features)
    branch_c = Conv2D(num_filters, filter_size[2], padding='same', activation='relu')(input_features)

    # Layers are created in the same order as before, so auto-generated layer
    # names (and hence saved-weight ordering) are unchanged.
    pair_ab = Concatenate()([branch_a, branch_b])
    pair_bc = Concatenate()([branch_b, branch_c])
    pair_ac = Concatenate()([branch_a, branch_c])

    fused_ab = Conv2D(num_filters, 3, padding='same', activation='relu')(pair_ab)
    fused_ac = Conv2D(num_filters, 3, padding='same', activation='relu')(pair_ac)
    fused_bc = Conv2D(num_filters, 3, padding='same', activation='relu')(pair_bc)

    merged = Concatenate()([fused_ab, fused_ac, fused_bc])
    conv_fin = Conv2D(num_filters, 3, padding='same', activation='relu')(merged)
    return MaxPool2D(2)(conv_fin)

In [10]:
def decoder_block(input_layer, skip_connection, num_filters, filter_size=(3, 5, 7)):
    """One RS_Cloud GAN decoder stage: multi-scale transposed convs, skip fusion, 2x upsampling.

    Three parallel stride-1 ReLU Conv2DTranspose layers are concatenated
    pairwise ((a, b), (a, c), (b, c)), and each pair is fused by a 3x3 conv.
    The three fused maps are concatenated with `skip_connection`, reduced by a
    3x3 conv and upsampled with UpSampling2D(2).

    Args:
        input_layer: rank-4 Keras tensor from the previous stage (channels-last assumed).
        skip_connection: encoder output with the same spatial size as `input_layer`.
        num_filters: number of filters in every convolution of the stage.
        filter_size: kernel sizes of the three parallel branches (only the first
            three entries are used).

    Returns:
        Tensor with twice the spatial size of `input_layer` and `num_filters` channels.
    """
    branch_a = Conv2DTranspose(num_filters, filter_size[0], padding='same', activation='relu')(input_layer)
    branch_b = Conv2DTranspose(num_filters, filter_size[1], padding='same', activation='relu')(input_layer)
    branch_c = Conv2DTranspose(num_filters, filter_size[2], padding='same', activation='relu')(input_layer)

    # Same layer creation order as before (keeps auto-generated names / weight order).
    pair_ab = Concatenate()([branch_a, branch_b])
    pair_bc = Concatenate()([branch_b, branch_c])
    pair_ac = Concatenate()([branch_a, branch_c])

    fused_ab = Conv2D(num_filters, 3, padding='same', activation='relu')(pair_ab)
    fused_ac = Conv2D(num_filters, 3, padding='same', activation='relu')(pair_ac)
    fused_bc = Conv2D(num_filters, 3, padding='same', activation='relu')(pair_bc)

    merged = Concatenate()([fused_ab, fused_ac, fused_bc, skip_connection])
    conv_fin = Conv2D(num_filters, 3, padding='same', activation='relu')(merged)
    return UpSampling2D(2)(conv_fin)

In [11]:
def bottleneck(input_layer, drop=0.2, attention_axes=(1, 2)):
    """Attention bottleneck between the encoder and the decoder.

    A linear 3x3 conv (512 filters) followed by a sigmoid 3x3 conv produces a
    gating map. That map is used as key/value for multi-head attention, with
    `input_layer` as the query. The result is batch-normalised and, when `drop`
    is truthy, passed through Dropout(drop).

    Args:
        input_layer: rank-4 feature map (batch, H, W, C) from the last encoder stage.
        drop: dropout rate; 0 / None / False disables dropout.
        attention_axes: axes MultiHeadAttention attends over. The default (1, 2)
            is the spatial H and W axes of a rank-4 input. The previous hard-coded
            (2, 3) matches Keras' rank-5 example; on a rank-4 input, axis 3 of the
            projected query is the attention-head axis, so attention ran over
            (W, heads) instead of (H, W). Pass (2, 3) to reproduce models trained
            with the old setting. Memory grows with (H*W)^2, which matters for the
            shallower variants whose bottleneck is larger.

    Returns:
        Tensor with the same shape as `input_layer`.
    """
    feature_layer = Conv2D(512, 3, padding='same', activation='linear')(input_layer)
    attention_layer = Conv2D(512, 3, padding='same', activation='sigmoid')(feature_layer)
    attended = MultiHeadAttention(num_heads=3, key_dim=3, attention_axes=attention_axes)(
        input_layer, attention_layer
    )
    normalized = BatchNormalization()(attended)
    if drop:
        return Dropout(drop)(normalized)
    return normalized

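## 3-layer generator variant (each variant cell rebinds `CR_Net`; the last one executed becomes the GAN generator)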

In [12]:
# 3-level variant: three encoder stages (512 -> 256 -> 128 -> 64 px), the attention
# bottleneck, then three decoder stages back to a 512x512x3 output.
input_layer = Input(shape=(512, 512, 3))
e1 = encoder_block(input_layer, num_filters=32)
e2 = encoder_block(e1, num_filters=64)
e3 = encoder_block(e2, num_filters=128)
b = bottleneck(e3)
# Every decoder stage fuses the encoder output of matching resolution before upsampling.
d1 = decoder_block(b, e3, num_filters=64)
d2 = decoder_block(d1, e2, num_filters=32)
d3 = decoder_block(d2, e1, num_filters=3)

CR_Net = Model(inputs=input_layer, outputs=d3)
CR_Net.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_6 (Conv2D)              (None, 512, 512, 32  4736        ['input_3[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_7 (Conv2D)              (None, 512, 512, 32  7808        ['input_3[0][0]']                
                                )                                                           

## 4-layer generator variant

In [13]:
# 4-level variant: four encoder stages (512 -> 256 -> 128 -> 64 -> 32 px), the attention
# bottleneck, then four decoder stages back to a 512x512x3 output.
input_layer = Input(shape=(512, 512, 3))
e1 = encoder_block(input_layer, num_filters=32)
e2 = encoder_block(e1, num_filters=64)
e3 = encoder_block(e2, num_filters=128)
e4 = encoder_block(e3, num_filters=256)
b = bottleneck(e4)
# Every decoder stage fuses the encoder output of matching resolution before upsampling.
d1 = decoder_block(b, e4, num_filters=128)
d2 = decoder_block(d1, e3, num_filters=64)
d3 = decoder_block(d2, e2, num_filters=32)
d4 = decoder_block(d3, e1, num_filters=3)

CR_Net = Model(inputs=input_layer, outputs=d4)
CR_Net.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_40 (Conv2D)             (None, 512, 512, 32  4736        ['input_4[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_41 (Conv2D)             (None, 512, 512, 32  7808        ['input_4[0][0]']                
                                )                                                           

## 5-layer generator variant

In [12]:
# 5-level variant: five encoder stages (512 -> 256 -> 128 -> 64 -> 32 -> 16 px), the
# attention bottleneck, then five decoder stages back to a 512x512x3 output.
input_layer = Input(shape=(512, 512, 3))
e1 = encoder_block(input_layer, num_filters=32)
e2 = encoder_block(e1, num_filters=64)
e3 = encoder_block(e2, num_filters=128)
e4 = encoder_block(e3, num_filters=256)
e5 = encoder_block(e4, num_filters=512)
b = bottleneck(e5)
# Every decoder stage fuses the encoder output of matching resolution before upsampling.
d1 = decoder_block(b, e5, num_filters=256)
d2 = decoder_block(d1, e4, num_filters=256)
d3 = decoder_block(d2, e3, num_filters=128)
d4 = decoder_block(d3, e2, num_filters=64)
d5 = decoder_block(d4, e1, num_filters=3)

CR_Net = Model(inputs=input_layer, outputs=d5)
CR_Net.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_6 (Conv2D)              (None, 512, 512, 32  896         ['input_3[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_7 (Conv2D)              (None, 512, 512, 32  2432        ['input_3[0][0]']                
                                )                                                           

# GAN Definition

In [5]:
# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape, learning_rate=0.0002, beta_1=0.5):
    """Build and compile the composite model used to update the generator.

    The source image goes through `g_model`, and the (source, generated) pair
    goes through `d_model`. The model outputs [discriminator patch map,
    generated image]. It is trained with binary cross-entropy on the first
    output and MAE on the second, weighted 1 : 100.

    Args:
        g_model: generator mapping a source image to a generated image.
        d_model: discriminator taking [source, target] images.
        image_shape: shape of one source image, e.g. (512, 512, 3).
        learning_rate: Adam learning rate.
        beta_1: Adam beta_1.

    Returns:
        The compiled composite Model. psnr, ssim and MSE are tracked on both
        outputs, so train_on_batch / test_on_batch return the total loss, two
        per-output losses and 3 metrics per output (the 9 values train() unpacks).
    """
    # Freeze the discriminator inside the composite model; BatchNormalization
    # layers are deliberately left trainable.
    # NOTE(review): d_model was already compiled in define_discriminator. Depending
    # on the Keras version, flipping `trainable` afterwards may also affect
    # d_model.train_on_batch (tf.keras builds its train function lazily on first
    # use) - verify the discriminator's weights still change during training.
    for layer in d_model.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False
    in_src = Input(shape=image_shape)
    gen_out = g_model(in_src)
    dis_out = d_model([in_src, gen_out])
    model = Model(in_src, [dis_out, gen_out])
    # `learning_rate`, not the deprecated `lr` alias: newer Keras versions either
    # reject `lr` or warn and ignore it (falling back to the default rate).
    opt = Adam(learning_rate=learning_rate, beta_1=beta_1)
    model.compile(
        loss=['binary_crossentropy', 'mae'],
        optimizer=opt,
        loss_weights=[1, 100],
        metrics=[psnr, ssim, tf.keras.losses.mean_squared_error],
    )
    return model

In [None]:
# Wire a fresh discriminator and the generator (whichever CR_Net variant cell ran
# last) into the composite model that is used for generator updates.
image_shape = (512, 512, 3)
d_model = define_discriminator(image_shape)
g_model = CR_Net
# Optional warm start: g_model.load_weights('C:/Users/ArrunPersonal/Codes/ISRO_GAN/RUNet.h5')
gan_model = define_gan(g_model, d_model, image_shape)
gan_model.summary()
# A diagram of gan_model can be produced with keras.utils.plot_model(..., show_shapes=True).
# plot the model
#K.utils.plot_model(gan_model, to_file='gan_model_plot.png', show_shapes=True, show_layer_names=True)

# Helper functions for sampling real and generated batches

In [15]:
def generate_real_samples(dataset, n_samples, patch_shape):
    """Draw a random batch of (source, target) pairs labelled as real.

    Args:
        dataset: pair (sources, targets) of arrays indexed along axis 0.
        n_samples: batch size; indices are drawn with replacement via np.random.
        patch_shape: side length of the discriminator's square patch output.

    Returns:
        ([sources_batch, targets_batch], labels) where labels has shape
        (n_samples, patch_shape, patch_shape, 1) and is filled with 1.
    """
    sources, targets = dataset
    picks = np.random.randint(0, sources.shape[0], n_samples)
    labels = np.ones((n_samples, patch_shape, patch_shape, 1))
    return [sources[picks], targets[picks]], labels

In [16]:
def generate_fake_samples(g_model, samples, targets, patch_shape):
    """Run the generator on `samples` and label its output as fake.

    Args:
        g_model: generator exposing a Keras-style predict().
        samples: batch of source (cloudy) images.
        targets: unused; kept so existing call sites keep working.
        patch_shape: side length of the discriminator's square patch output.

    Returns:
        (X, y): generated images and labels of shape
        (len(X), patch_shape, patch_shape, 1) filled with 0.
    """
    # verbose=0: newer Keras versions default predict() to showing a progress bar,
    # which would print once per training step.
    generated = g_model.predict(samples, verbose=0)
    labels = np.zeros((len(generated), patch_shape, patch_shape, 1))
    return generated, labels

# Train GAN

In [18]:
# train pix2pix models

# Default checkpoint files for the best-validation-PSNR / best-validation-SSIM
# generator weights (RICE 1 run). For RICE 2 the notebook previously used
# .../PostComments_NewWeights/RICE2_weights/RS_CloudGAN_l3_k7911_max_{psnr,ssim}_rice2.h5
# NOTE(review): machine-specific absolute paths; pass other paths to train(...).
PSNR_WEIGHTS_PATH = 'C:/Users/ArrunPersonal/Codes/ISRO_GAN/PostComments_NewWeights/RICE1_weights/RS_CloudGAN_l5_k357_selfattn_bce_max_psnr_rice1.h5'
SSIM_WEIGHTS_PATH = 'C:/Users/ArrunPersonal/Codes/ISRO_GAN/PostComments_NewWeights/RICE1_weights/RS_CloudGAN_l5_k357_selfattn_bce_max_ssim_rice1.h5'

# Expected order of the values returned by gan_model.train_on_batch / test_on_batch
# for the model built by define_gan: total loss, the two per-output losses, then
# [psnr, ssim, mse] on the discriminator output (q1-q3; logged but not meaningful as
# image metrics) and [psnr, ssim, mse] on the generated image.
GAN_RESULT_NAMES = ("loss", "d_out_loss", "g_out_loss", "q1", "q2", "q3", "psnr", "ssim", "mse")


def _result_vector(result):
    """Return a train/test_on_batch result as a float vector ordered like GAN_RESULT_NAMES."""
    values = np.asarray(result, dtype=float).ravel()
    if values.size != len(GAN_RESULT_NAMES):
        raise ValueError(
            f"gan_model returned {values.size} values, expected {len(GAN_RESULT_NAMES)} "
            f"({', '.join(GAN_RESULT_NAMES)})"
        )
    return values


def _print_epoch_summary(title, means):
    """Print one split's per-epoch mean loss and metrics (`means` ordered like GAN_RESULT_NAMES)."""
    m = dict(zip(GAN_RESULT_NAMES, means))
    print(f"\n\t\t{title}\n\tLoss: {m['loss']} PSNR: {m['psnr']} SSIM: {m['ssim']} "
          f"MSE: {m['mse']}  q1: {m['q1']}, q2:{m['q2']}, q3:{m['q3']}\n")


def train(d_model, g_model, gan_model, train_dataset, test_dataset, max_psnr=20, max_ssim=0.80,
          n_epochs=60, n_batch=batch_size, n_patch=32, steps_per_epoch=None,
          psnr_weights_path=PSNR_WEIGHTS_PATH, ssim_weights_path=SSIM_WEIGHTS_PATH):
    """Alternately train the PatchGAN discriminator and the generator.

    Each step updates d_model on one real and one generated batch, then updates
    the generator through gan_model. It then evaluates gan_model on a random
    batch drawn from test_dataset. After every epoch, the MEAN validation PSNR
    and SSIM are compared with the best values so far, and g_model's weights
    are saved whenever either improves.

    Args:
        d_model: discriminator from define_discriminator.
        g_model: generator.
        gan_model: composite model from define_gan.
        train_dataset, test_dataset: [source_images, target_images] array pairs.
        max_psnr, max_ssim: mean validation PSNR / SSIM a checkpoint must beat.
        n_epochs: number of epochs.
        n_batch: image pairs per step.
        n_patch: side of the discriminator's square patch output (32 for 512x512 inputs).
        steps_per_epoch: steps per epoch. Defaults to len(train_dataset[0]) // n_batch
            (at least 1); this was previously hard-coded to 400 // n_batch.
        psnr_weights_path, ssim_weights_path: files for the best-PSNR / best-SSIM weights.

    Returns:
        (max_psnr, max_ssim): best mean validation PSNR and SSIM, or the given
        thresholds if they were never exceeded.
    """
    if steps_per_epoch is None:
        steps_per_epoch = max(1, len(train_dataset[0]) // n_batch)
    psnr_idx = GAN_RESULT_NAMES.index("psnr")
    ssim_idx = GAN_RESULT_NAMES.index("ssim")

    for epoch in range(n_epochs):
        print("Epoch: ", epoch + 1)
        train_sums = np.zeros(len(GAN_RESULT_NAMES))
        val_sums = np.zeros(len(GAN_RESULT_NAMES))

        for step in range(steps_per_epoch):
            # a random real batch and the generator's output for it
            [X_realA, X_realB], y_real = generate_real_samples(train_dataset, n_batch, n_patch)
            X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, X_realB, n_patch)
            # discriminator: real pairs -> 1, generated pairs -> 0
            d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
            d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
            # generator, through the composite model: fool D and match the target image
            train_values = _result_vector(gan_model.train_on_batch(X_realA, [y_real, X_realB]))
            # evaluate on one random validation batch
            [X_valA, X_valB], y_val = generate_real_samples(test_dataset, n_batch, n_patch)
            val_values = _result_vector(gan_model.test_on_batch(X_valA, [y_val, X_valB]))

            train_sums += train_values
            val_sums += val_values
            print('\t >%d, d1[%.3f] d2[%.3f] g[%.3f]' % (step + 1, d_loss1, d_loss2, train_values[0]))

        train_means = train_sums / steps_per_epoch
        val_means = val_sums / steps_per_epoch
        _print_epoch_summary("Train", train_means)
        _print_epoch_summary("Test", val_means)

        # Compare per-image MEANS with the thresholds. Comparing the per-epoch sums,
        # as before, always beat max_psnr=20 / max_ssim=0.8 in the first epoch and
        # returned sums rather than metric values.
        if val_means[psnr_idx] > max_psnr:
            g_model.save_weights(psnr_weights_path)
            max_psnr = val_means[psnr_idx]
            print("\n ------------\nSaved psnr model \n------------------")

        if val_means[ssim_idx] > max_ssim:
            g_model.save_weights(ssim_weights_path)
            max_ssim = val_means[ssim_idx]
            print("\n ------------\nSaved ssim model \n------------------")
    return max_psnr, max_ssim

In [None]:
# NOTE(review): `train_data` / `val_data` were not defined anywhere in this notebook
# (probably left over from a deleted cell), so a fresh "Restart & Run All" raised a
# NameError here. They are built from `dataset` with a seeded hold-out split instead;
# replace this with a dedicated validation directory if one is available.
VAL_FRACTION = 0.1
_split_order = np.random.default_rng(0).permutation(len(dataset[0]))
_n_val = max(1, int(round(len(_split_order) * VAL_FRACTION)))
_val_idx, _train_idx = _split_order[:_n_val], _split_order[_n_val:]
train_data = [dataset[0][_train_idx], dataset[1][_train_idx]]
val_data = [dataset[0][_val_idx], dataset[1][_val_idx]]
max_psnr_g, max_ssim_g = train(d_model, g_model, gan_model, train_data, val_data)