In [None]:
import junodch_utils_read_img as utils

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

import keras
from keras import layers, losses
from keras.models import Model
from shapely.geometry import box
#import gdal

import rasterio
from rasterio import plot as rastPlt
from rasterio.merge import merge as rasterMerge


In [None]:
# Derive the area of interest (AOI) from the bounds of the Sentinel-2 raster.
folderName = "img/Sokoto/"
with rasterio.open(folderName + "Sentinel-2.tif") as s:
  # Exterior ring coordinates of the rectangle built from the raster bounds.
  sBox = box(*s.bounds).exterior.coords

# Alternative kept from the original notebook: use the VIIRS raster bounds instead.
#with rasterio.open(folderName + "Night VIIRS_1.tif") as s:
#  sBox = box(*s.bounds).exterior.coords

# Keep the first two components of every ring vertex, then drop the final
# vertex (presumably the ring's closing point repeating the first - TODO confirm).
aoi = [(vertex[0], vertex[1]) for vertex in sBox][:-1]

aoi

In [None]:
# Ask the project helper for the training and test perimeters of the VIIRS
# raster, restricted with area=aoi.
# NOTE(review): the meaning of the positional 200 is defined inside
# utils.getTrainingAndTestPerimeter - document it there or name it here.
folderName = "img/Sokoto/"
train, test = utils.getTrainingAndTestPerimeter(
  folderName + "Night VIIRS_1.tif",
  200,
  area=aoi,
)

# Print the size of each returned collection, training first.
for perimeterSet in (train, test):
  print(len(perimeterSet))

In [None]:
# Read from the Sentinel-2 raster using the training perimeters.
# `data` and `meta` are indexed in parallel by the cells below
# (meta[i] is passed as the transform when plotting data[i]).
folderName = "img/Sokoto/"
sentinelPath = folderName + "Sentinel-2.tif"
with rasterio.open(sentinelPath) as s:
  data, meta = utils.getEachImgFromCoord(s, train, True)


In [None]:
# Draw every tile on one shared axes, then widen the view to the union of all
# the axis limits observed after each draw.
fig, axs = plt.subplots(1, 1)

# The first tile seeds the running bounds.
rastPlt.show(data[0], transform=meta[0], ax=axs)
xMin, xMax = axs.get_xlim()
yMin, yMax = axs.get_ylim()

for i in range(1, len(data)):
  rastPlt.show(data[i], transform=meta[i], ax=axs)
  newXMin, newXMax = axs.get_xlim()
  newYMin, newYMax = axs.get_ylim()

  # Limits are re-read after every draw; keep the widest extent seen so far.
  xMin = min(xMin, newXMin)
  xMax = max(xMax, newXMax)
  yMin = min(yMin, newYMin)
  yMax = max(yMax, newYMax)

# Apply the accumulated bounds.
axs.set_xlim((xMin, xMax))
axs.set_ylim((yMin, yMax))


## Format data for the autoencoder

In [None]:
print(data[0].shape)
#print(meta[0])

# Stack the tiles into one array and move axis 1 to the end
# (presumably bands-first -> bands-last, matching the (32, 32, 3) model input).
stackedTiles = np.asarray(data)
tilesAxisMoved = tf.transpose(stackedTiles, [0, 2, 3, 1])

# Keep every tile, cropped to the leading 32x32 window and the first 3 channels.
# tf.slice is used on purpose: it raises if a tile is smaller than the window.
dataTrain_formated = tf.slice(
  tilesAxisMoved,
  begin=[0, 0, 0, 0],
  size=[len(tilesAxisMoved), 32, 32, 3])
dataTrain_formated.shape

# Autoencoder

In [None]:
# Convolutional autoencoder taking a 32x32x3 input.
input_img = keras.Input(shape=(32, 32, 3))

# Encoder: three identical stages of 3x3 convolution (8 filters, ReLU)
# followed by 2x2 max-pooling.
stageOutput = input_img
for stage in range(3):
  stageOutput = layers.Conv2D(8, (3, 3), padding='same', activation='relu')(stageOutput)
  stageOutput = layers.MaxPool2D((2, 2), padding='same')(stageOutput)
encoded = stageOutput
print('Encoder shape:', encoded.get_shape())
encoder = keras.Model(input_img, encoded)

# Decoder: mirror of the encoder, with 2x2 up-sampling instead of pooling.
for stage in range(3):
  stageOutput = layers.Conv2D(8, (3, 3), padding='same', activation='relu')(stageOutput)
  stageOutput = layers.UpSampling2D((2, 2))(stageOutput)
# Final 3-channel sigmoid layer; note that `decoder` is the output tensor, not a Model.
decoder = layers.Conv2D(3, (3, 3), padding='same', activation='sigmoid')(stageOutput)

autoencoder = keras.Model(input_img, decoder)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# The network learns to reproduce its input, so inputs double as targets.
trainTiles = dataTrain_formated[0:100]
validationTiles = dataTrain_formated[1000:2000]
autoencoder.fit(
  trainTiles, trainTiles,
  epochs=5,
  #batch_size=256,
  shuffle=True,
  validation_data=(validationTiles, validationTiles))

In [None]:
def show_image_row(images, figsize):
    """Display ``images`` side by side in a single row with the axes hidden.

    Args:
        images: sequence of arrays/tensors accepted by ``imshow``.
        figsize: ``(width, height)`` passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    columns = len(images)
    for position, image in enumerate(images, start=1):
        ax = plt.subplot(1, columns, position)
        ax.imshow(image)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()


# Number of tiles to visualise and the index of the first one. The sample
# slice is derived from these so that changing `n` keeps the predictions and
# the plots in sync (the slice bounds were previously hard-coded).
n = 10
firstIndex = 1000
sampleTiles = dataTrain_formated[firstIndex:firstIndex + n]

encoded_imgs = encoder.predict(sampleTiles)
decoded_imgs = autoencoder.predict(sampleTiles)
print(sampleTiles.shape)
print(encoded_imgs.shape)
print(decoded_imgs.shape)

# Display original
show_image_row(sampleTiles, (20, 20))

# Display encoded: flatten each code to (rows, cols * channels) and transpose
# so it can be drawn as a 2-D image; for a (4, 4, 8) code this is reshape(4, 32).T.
show_image_row(
    [code.reshape(code.shape[0], -1).T for code in encoded_imgs],
    (20, 4))

# Display reconstruction
show_image_row(decoded_imgs, (20, 20))
