In [None]:
import junodch_utils_read_img as utils

import matplotlib.pyplot as matPlt
import numpy as np
import tensorflow as tf
import keras
from keras import layers

from tqdm.keras import TqdmCallback
from sklearn.metrics import confusion_matrix

import rasterio
from rasterio import plot as rastPlt

# Data preparation
### Fetch tile coordinates and night-light radiance from file

In [None]:
# Input rasters for the Sokoto study area.
# Alternatives tried earlier (swap in to experiment):
#   satellite:    folderName + "Landsat-8.tif"
#   night lights: folderName + "Night VIIRS_1.tif"
folderName = "img/Sokoto/"
pathSatellite = f"{folderName}Sentinel-2.tif"
pathNight = f"{folderName}lowres_night_1.tif"
pathValidation = f"{folderName}Population GHSL_1.tif"

# Area of interest: the border of the satellite image (as computed by utils.getImgBorder).
aoi = utils.getImgBorder(pathSatellite)

# Tile coordinates and their night-light radiance, restricted to the AOI.
# NOTE(review): exact meaning of validThreshold=1 is defined in utils — confirm.
dataCoords, dataRadiance = utils.getTilesCoordsPerimeter(pathNight, validThreshold=1, area=aoi)

tileCount = dataCoords.shape[0]
print('Tiles:', tileCount)


In [None]:
# Fetch the satellite image of every training tile, split by night-light radiance:
#   light tiles: radiance > LIGHT_THRESHOLD
#   dark tiles:  DARK_MIN < radiance <= LIGHT_THRESHOLD
# Tiles at or below DARK_MIN fall in neither mask and are not used for training.
LIGHT_THRESHOLD = 25
DARK_MIN = 1
lightMask = dataRadiance > LIGHT_THRESHOLD
darkMask = (dataRadiance <= LIGHT_THRESHOLD) & (dataRadiance > DARK_MIN)
with rasterio.open(pathSatellite) as f:
  # NOTE(review): meaning of the positional `True` flag is defined in utils — confirm.
  lightData, _ = utils.getEachImgFromCoord(f, dataCoords[lightMask], True)
  darkData, _ = utils.getEachImgFromCoord(f, dataCoords[darkMask], True)

# Light tiles first, then dark ones: the training labels are built from the masks in the
# same order, so the two must stay aligned. list() guarantees concatenation even if utils
# returns numpy arrays (where `+` would add element-wise instead).
trainTiles = list(lightData) + list(darkData)
# Fail early with a clear message if utils skipped any requested tile (labels would misalign).
assert len(trainTiles) == int(lightMask.sum() + darkMask.sum()), \
  'getEachImgFromCoord returned a different number of tiles than requested'
trainData = utils.formatData(trainTiles, res=64, toFloat=True)
print('Light Tile:',len(lightData))
print('Dark Tile:',len(darkData))
print('Total train',trainData.shape)

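## CNN
Train a small convolutional network that regresses each training tile's night-light radiance (divided by 255) from its satellite image.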

In [None]:
# Reproducibility: seed Python, NumPy and the backend RNGs (weight init, shuffling).
SEED = 42
keras.utils.set_random_seed(SEED)

# Model input. Despite its name this is the symbolic Input tensor, shaped like one training tile.
input_shape = keras.Input(shape=trainData.shape[1:])

optimizer = keras.optimizers.Adam(
  learning_rate=0.001,
  beta_1=0.9,
  beta_2=0.999,
)
# Alternatives tried: MeanAbsoluteError (L1), MeanSquaredLogarithmicError, KLDivergence.
lossFunction = keras.losses.MeanSquaredError() # L2

# Alternative tried: relu capped at 255.
activationFunction = 'relu'

# No validation data is passed to fit(), so early stopping can only watch the training loss.
# keras.callbacks (not tf.keras.callbacks) keeps every Keras object here from the same package.
earlyStop = keras.callbacks.EarlyStopping(monitor='loss', min_delta=0, patience=5)

# Three stride-2 'same'-padded convolutions; each roughly halves the spatial resolution.
cnn = input_shape
for _ in range(3):
  cnn = layers.Conv2D(16, (3, 3), 2, padding='same', activation=activationFunction)(cnn)

cnn = layers.Flatten()(cnn)
# Single sigmoid unit: predictions lie in (0, 1), matching labels scaled by 1/255.
cnn = layers.Dense(1, activation='sigmoid')(cnn)

modelCNN = keras.Model(input_shape, cnn)
modelCNN.compile(optimizer=optimizer, loss=lossFunction)

# Labels: radiance of the light tiles, then of the dark tiles — the same order as trainData.
# NOTE(review): radiance above 255 would give targets > 1 that the sigmoid cannot reach — confirm the raster's range.
trainLabels = np.concatenate((dataRadiance[lightMask], dataRadiance[darkMask]), axis=0).astype("float32") / 255

result = modelCNN.fit(
  x=trainData,
  y=trainLabels,
  epochs=20,
  batch_size=10,
  shuffle=True,
  verbose=0,
  callbacks=[
    TqdmCallback(verbose=1), # Concise display progression
    earlyStop,
  ],
)

In [None]:
# Training curve: one point per epoch actually run (early stopping may end before epoch 20).
lossFig, lossAx = matPlt.subplots()
lossAx.plot(result.history['loss'], label='Training')
lossAx.set_xlabel('Epoch')
lossAx.set_ylabel('Loss')
lossAx.set_title('CNN training loss')
lossAx.legend()  # the 'Training' label was previously never displayed
modelCNN.summary()

In [None]:
print('Process validation...')


def getValid(data, threshold=250):
  """Flag each tile whose first band has at least one pixel above ``threshold``.

  Args:
    data: iterable of tiles; ``img[0]`` of each tile is treated as a 2-D pixel grid
      (presumably band 0 of a (bands, h, w) read — TODO confirm against utils.scanSatellite).
    threshold: a pixel must be strictly greater than this to count (default 250).

  Returns:
    list of Python ints, one per tile: 1 if any pixel of ``img[0]`` exceeds
    ``threshold``, else 0.
  """
  # Vectorised comparison instead of a Python-level loop over every pixel; same result.
  return [int((np.asarray(img[0]) > threshold).any()) for img in data]
# Ground-truth flags from the GHSL population raster, computed by getValid
# (presumably one entry per tile in dataCoords — confirm in utils.scanSatellite).
resultValid = utils.scanSatellite(
  pathValidation,
  dataCoords,
  getValid,
  batch=1000,
)

In [None]:
print('Process score...')


def getScore(data):
  """Format one batch of raw tiles like the training data and return the CNN's predictions."""
  formatted = utils.formatData(data, res=64, toFloat=True)
  return modelCNN.predict(formatted, verbose=0)
result = utils.scanSatellite(pathSatellite, dataCoords, getScore, batch=1000)

In [None]:
print('Process confustion matrix...')
print(len(result))
resultTest = (np.asarray(result) > 0.9).astype(int)
confusionMatrix = confusion_matrix(resultValid, resultTest)
print(confusionMatrix)
print((confusionMatrix[0][0] + confusionMatrix[1][1]) / (confusionMatrix[0][0] + confusionMatrix[0][1] + confusionMatrix[1][0] + confusionMatrix[1][1]),"%")

In [None]:
# Map the per-tile predictions (resultTest) and validation flags (resultValid) onto the night-light grid.
resultImg, resultMeta = utils.mapResultOnImg(pathNight, dataCoords, resultTest, resultValid)

fig, axs = matPlt.subplots(1,3, dpi=240)

# Panel 0: the satellite image. Its view limits are reused for the other panels.
with rasterio.open(pathSatellite) as s:
  utils.displayTiles([s.read()], [s.transform], axs[0])
xlim, ylim = axs[0].get_xlim(), axs[0].get_ylim()

# Panel 2: validation (GHSL population) raster. Limits are set before drawing so all panels show the same area.
axs[2].set_xlim(xlim)
axs[2].set_ylim(ylim)
with rasterio.open(pathValidation) as p:
  rastPlt.show(p, ax=axs[2])

# Panel 1: the mapped prediction/validation result.
axs[1].set_xlim(xlim)
axs[1].set_ylim(ylim)
utils.displayTiles([resultImg], [resultMeta], axs[1])

# Titles are set last so they are not overridden by the drawing helpers.
for ax, title in zip(axs, ('Satellite', 'Prediction vs validation', 'Validation (GHSL)')):
  ax.set_title(title, fontsize=6)