In [None]:
import junodch_utils_read_img as utils

import math
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import tensorflow as tf
from tqdm.keras import TqdmCallback

from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()

import keras
from keras import layers
import rasterio

# Data preparation
### Fetch data from file

In [None]:
# Input rasters for the Sokoto area of interest.
folderName = "img/Sokoto/"
pathSatellite = folderName + "Sentinel-2.tif"
# Alternative night-light raster that was used before: folderName + "Night VIIRS_1.tif"
pathNight = folderName + "lowres_night_1.tif"
pathValidation = folderName + "Population GHSL_1.tif"

# Night-radiance cut-off used to split the tiles: strictly above -> training set,
# at or below -> test set. Values comparing False both ways (e.g. NaN, if present)
# fall into neither set.
RADIANCE_THRESHOLD = 25

# Area derived from the satellite image (presumably its footprint), used to restrict
# which night-raster tiles are fetched.
aoi = utils.getImgBorder(pathSatellite)

# Fetch coords
dataCoords, dataRadiance = utils.getTilesCoordsPerimeter(pathNight, area=aoi)

trainMask = dataRadiance > RADIANCE_THRESHOLD
testMask = dataRadiance <= RADIANCE_THRESHOLD
train = dataCoords[trainMask]
test = dataCoords[testMask]

print('TrainingTile:',len(train))
print('TestTile:',len(test))
assert len(train) > 0, "no tile above RADIANCE_THRESHOLD: nothing to train on"

# Read the satellite pixels for every training tile.
with rasterio.open(pathSatellite) as s:
  data, meta = utils.getEachImgFromCoord(s, train, True)


In [None]:
# Raw shape of one tile as returned by getEachImgFromCoord, before formatting.
print(data[0].shape)

# Format every tile for the network. Judging by its arguments, the helper presumably
# resizes to 64 pixels and converts to float — confirm in junodch_utils_read_img.
dataTrain_formated = utils.formatDataForAutoencoder(data, toFloat=True, res=64)
print(dataTrain_formated.shape)

# Autoencoder

In [None]:
# Symbolic network input: one formatted tile (every axis after the batch axis).
input_shape = keras.Input(shape=dataTrain_formated.shape[1:])

# Adam with its hyper-parameters spelled out (the plain 'adam' string was also tried).
optimizer = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999)

# Plain L2 reconstruction loss. Alternatives tried: MeanAbsoluteError (L1),
# MeanSquaredLogarithmicError, KLDivergence with SUM reduction.
# NOTE: this global is rebound to the VAE loss once the model is built.
lossFunction = keras.losses.MeanSquaredError()

activationFunction = 'relu'  # a relu capped at max_value=255 was also tried

# Stop once the training loss has not improved for 3 consecutive epochs.
earlyStop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3, min_delta=0)

def loss_func(encoder_mu, encoder_log_variance, reconstruction_loss_factor=1000):
  """Build the VAE loss used to compile the autoencoder.

  The returned function has the Keras ``loss(y_true, y_pred)`` signature. Its KL term
  does not use those arguments: it reads the encoder tensors captured here, which relies
  on the model being built in graph mode (eager execution is disabled at import).

  Args:
    encoder_mu: tensor of latent means; the KL term sums over its axis 1.
    encoder_log_variance: tensor of latent log-variances, same shape as ``encoder_mu``.
    reconstruction_loss_factor: weight of the per-sample reconstruction MSE relative to
      the KL term (default 1000, the value previously hard-coded).

  Returns:
    ``vae_loss(y_true, y_predict)`` = factor * mean squared error + KL divergence,
    one value per sample.
  """
  def vae_reconstruction_loss(y_true, y_predict):
    # Mean over axes 1..3 (everything but the batch axis), so inputs must be rank 4.
    reconstruction_loss = keras.backend.mean(keras.backend.square(y_true - y_predict), axis=[1, 2, 3])
    return reconstruction_loss_factor * reconstruction_loss

  def vae_kl_loss():
    # KL(N(mu, exp(log_var)) || N(0, 1)), summed over the latent dimensions.
    return -0.5 * keras.backend.sum(
      1.0 + encoder_log_variance - keras.backend.square(encoder_mu) - keras.backend.exp(encoder_log_variance),
      axis=1)

  def vae_loss(y_true, y_predict):
    return vae_reconstruction_loss(y_true, y_predict) + vae_kl_loss()

  return vae_loss

class Sampling(layers.Layer):
  """Reparameterisation-trick layer.

  Takes ``[mu, log_variance]`` and returns ``mu + exp(log_variance / 2) * eps``, where
  ``eps`` is drawn from a standard normal with the same shape as ``mu``.
  """

  def call(self, inputs):
    mean, logVar = inputs
    backend = tf.keras.backend
    noise = backend.random_normal(shape=backend.shape(mean), mean=0.0, stddev=1.0)
    stdDev = backend.exp(logVar/2)
    return mean + stdDev * noise

# Length of the latent vector (8*8*16 = 1024).
latent_space_dim = 8*8*16

# --- Encoder: two stride-2 convolutions, each halving the spatial size. ---
cnn = layers.Conv2D(16,(3,3), 2, padding='same', activation=activationFunction)(input_shape)
encoded = layers.Conv2D(32,(3,3), 2, padding='same', activation=activationFunction, name='encoder')(cnn)

# Keep the conv feature-map shape so the decoder can reshape back to it.
shape_before_flatten = keras.backend.int_shape(encoded)[1:]
print(shape_before_flatten)

# Latent distribution parameters, and one sample drawn from them.
cnn = layers.Flatten()(encoded)
encoder_mu = layers.Dense(units=latent_space_dim, name="encoder_mu")(cnn)
encoder_logvar = layers.Dense(units=latent_space_dim, name="encoder_log_variance")(cnn)

encoder_output = Sampling()([encoder_mu, encoder_logvar])
print(encoder_output.shape)

# --- Decoder: the two 2x upsamplings undo the two stride-2 convolutions
# (exactly, when the tile size is a multiple of 4). ---
decoder_dense_layer = layers.Dense(np.prod(shape_before_flatten), name="decoder_dense")(encoder_output)
decoder_reshape = layers.Reshape(target_shape=shape_before_flatten)(decoder_dense_layer)

cnn = layers.Conv2D(32,(3,3), padding='same', activation=activationFunction)(decoder_reshape)
cnn = layers.UpSampling2D((2,2))(cnn)
cnn = layers.Conv2D(16,(3,3), padding='same', activation=activationFunction)(cnn)
cnn = layers.UpSampling2D((2,2))(cnn)
# Sigmoid output assumes the formatted tiles lie in [0, 1] — TODO confirm the helper's scaling.
decoder = layers.Conv2D(3, (3,3), padding='same', activation='sigmoid', name='decoder')(cnn)

autoencoder = keras.Model(input_shape, decoder)
# Rebind the global lossFunction to the VAE loss; it closes over the encoder tensors above.
lossFunction = loss_func(encoder_mu, encoder_logvar)
autoencoder.compile(optimizer=optimizer, loss=lossFunction)

print('Encoder shape:',autoencoder.get_layer('encoder').output_shape)
# NOTE(review): the training data appears to be a graph tensor (it is later evaluated
# with .eval/steps=1), which is presumably why steps_per_epoch is given — confirm how
# batch_size interacts with it.
result = autoencoder.fit(dataTrain_formated, dataTrain_formated,
                          epochs=50,
                          batch_size=10,
                          steps_per_epoch=10,
                          shuffle=True,
                          verbose=0,
                          callbacks=[
                            TqdmCallback(verbose=1), # Concise display progression
                            earlyStop,
                          ],
                        )


In [None]:
# Training loss per epoch (VAE loss: weighted reconstruction MSE + KL term).
plt.plot(result.history['loss'], label='Training')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Autoencoder training loss')
plt.legend()  # without it the 'Training' label is never shown
autoencoder.summary()

In [None]:
def _toNumpy(x):
  """Return ``x`` as a NumPy array.

  A graph-mode tensor (eager execution is disabled in this notebook) is evaluated
  once in a short-lived Session that is closed afterwards.
  """
  if isinstance(x, np.ndarray):
    return x
  if not tf.is_tensor(x):
    return np.asarray(x)
  with tf.compat.v1.Session() as sess:
    return sess.run(x)


def _reconstructionMse(original, reconstruction):
  """Mean of the squared differences over every pixel and channel of one tile."""
  diff = np.asarray(original, dtype=np.float64) - np.asarray(reconstruction, dtype=np.float64)
  return float(np.mean(np.square(diff)))


def displayAutoEncoderResults(autoencoder, dataInput, showDetail=0, precision=0, scoreFunction=None):
  """Plot input tiles, optional intermediate activations, and their reconstructions.

  Args:
    autoencoder: Keras model with a layer named 'encoder'.
    dataInput: batch of tiles indexed along axis 0. It is passed unchanged to
      ``predict(..., steps=1)`` and evaluated once to NumPy for display.
    showDetail: 0 = inputs and reconstructions only; 1 = also the 'encoder' layer
      activations of the first row; 2 = also every Conv2D activation (last layer
      excluded) of the first row.
    precision: decimals of the per-tile score shown as each reconstruction's title.
    scoreFunction: optional ``f(original_tile, reconstructed_tile) -> float`` on NumPy
      tiles; defaults to the per-tile mean squared error. The global VAE loss cannot
      be used here: it reduces axes [1, 2, 3] (rank-4 batches only) and depends on the
      model's input placeholder through the encoder tensors.
  """
  MAX_ON_ROW = 20
  total = dataInput.shape[0]
  if total == 0:
    print("Nothing to display: empty input", dataInput.shape)
    return

  nCol = min(total, MAX_ON_ROW)
  # Ceil division so that exactly 20 or 40 tiles do not get an extra empty row.
  nRow = math.ceil(total / MAX_ON_ROW)
  if scoreFunction is None:
    scoreFunction = _reconstructionMse

  # Evaluate the input once instead of opening (and leaking) one Session per tile.
  inputImgs = _toNumpy(dataInput)

  # Display original
  plt.figure(figsize=(30, nRow*2))
  for i in range(total):
    ax = plt.subplot(nRow, nCol, 1+i)
    plt.imshow(inputImgs[i])
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
  print("Original data:", dataInput.shape)
  plt.show()

  if showDetail == 1:
    # Feature maps of the 'encoder' layer, first row of tiles only.
    encoder = keras.Model(inputs=autoencoder.inputs, outputs=autoencoder.get_layer('encoder').output)
    displayImgCollection(encoder.predict(dataInput[:nCol], steps=1))
  elif showDetail == 2:
    # Every Conv2D activation except the output layer, first row of tiles only.
    for layer in autoencoder.layers[:-1]:
      if 'Conv2D' in layer.__class__.__name__:
        intermediate = keras.Model(inputs=autoencoder.inputs, outputs=layer.output)
        displayImgCollection(intermediate.predict(dataInput[:nCol], steps=1))

  # Display reconstruction, each titled with its score against the original tile.
  decoded_imgs = autoencoder.predict(dataInput, steps=1)
  plt.figure(figsize=(30, nRow*2))
  for i in range(decoded_imgs.shape[0]):
    ax = plt.subplot(nRow, nCol, 1+i)
    score = scoreFunction(inputImgs[i], decoded_imgs[i])
    plt.title(np.round(score, precision))
    plt.imshow(decoded_imgs[i].astype('float32'))  # TODO: revisit display dtype (was uint8)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
  print("Output data:", decoded_imgs.shape)
  plt.show()

def displayImgCollection(imgs):
  """Show every channel of each activation map as a small grid of images.

  Args:
    imgs: batch of activation maps, presumably channels-last ``(N, H, W, C)`` as
      returned by ``Model.predict`` on a Conv2D layer. Sample ``i`` occupies column ``i``
      of a 1 x N grid; its C channels are tiled in a sub-grid at most 8 columns wide.
  """
  imgs = np.asarray(imgs)  # predict() already returns NumPy; no Session evaluation needed
  if imgs.shape[0] == 0:
    print("Data:", imgs.shape)
    return

  nChannels = imgs.shape[-1]
  # Halve the columns (doubling the rows) until at most 8 remain; nRow * nCol stays
  # >= nChannels. Identical for every sample, so it is computed once.
  nCol, nRow = nChannels, 1
  while nCol > 8:
    nCol = math.ceil(nCol / 2)
    nRow *= 2

  grid = gridspec.GridSpec(1, imgs.shape[0])
  plt.figure(figsize=(30, nChannels/8))
  for i, sample in enumerate(imgs):
    cell = gridspec.GridSpecFromSubplotSpec(nRow, nCol, subplot_spec=grid[i], wspace=0.1, hspace=0.1)
    # Channel axis first, so each item is an (H, W) map in its original orientation
    # (a plain .T would also swap H and W).
    for index, channel in enumerate(np.moveaxis(sample, -1, 0)):
      ax = plt.subplot(cell[index])
      plt.imshow(channel)
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)
  print("Data:", imgs.shape)
  plt.show()

In [None]:
# Reconstruct the last (up to) N_DISPLAY training tiles. The start index is clamped at 0:
# with fewer than N_DISPLAY tiles, `shape[0] - N_DISPLAY` would be negative and select
# only the last |shape[0] - N_DISPLAY| tiles instead of all of them.
N_DISPLAY = 40
dataInput = dataTrain_formated[max(0, dataTrain_formated.shape[0] - N_DISPLAY):]

displayAutoEncoderResults(autoencoder, dataInput, showDetail=0, precision=5)

In [None]:
# Hand-picked groups of held-out tiles. `test` comes from boolean-mask indexing, so it is
# presumably a NumPy array: `+` between its slices would add coordinates element-wise (or
# fail to broadcast) rather than concatenate, hence np.concatenate along the tile axis.
# Wider selection tried before: add test[2944:2964], test[5000:5020], longer groups.
validationCoords = np.concatenate([test[0:1], test[1104:1123], test[4000:4010], test[10000:10010]])
with rasterio.open(pathSatellite) as s:
  validation, metaValid = utils.getEachImgFromCoord(s, validationCoords, True)

displayAutoEncoderResults(autoencoder, utils.formatDataForAutoencoder(validation,res=64, toFloat=True), showDetail=0, precision=5)

In [None]:
# Persist the trained model. NOTE(review): it contains the custom `Sampling` layer and a
# closure-based loss, so reloading presumably needs `custom_objects` or compile=False — confirm.
autoencoder.save(filepath='model/autoencoder_64_GEN7_V1')

In [None]:
# Display area test: five groups of 10 consecutive held-out tiles, starting 368 tiles apart.
# `test` is presumably a NumPy array (boolean-mask indexing), so the groups are joined with
# np.concatenate; `+` would instead sum the five groups' coordinates element-wise.
areaTestCoords = np.concatenate([test[start:start + 10] for start in (0, 368, 736, 1104, 1472)])
with rasterio.open(pathSatellite) as s:
  validationTest, metaValidTest = utils.getEachImgFromCoord(s, areaTestCoords, True)

utils.displayTiles(validationTest, metaValidTest)