# Importing the Necessary Libraries

In [None]:
import tensorflow as tf 
from tensorflow  import keras

#Download Model: pre-programmed model from github
import segmentation_models as sm 
from segmentation_models import Unet
from segmentation_models import get_preprocessing
from keras import metrics
from keras.models import load_model

# Import other libraries necessary for preparing and analyzing the data
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageEnhance
import rasterio
import numpy as np
import patchify
from patchify import unpatchify 
from sklearn.model_selection import train_test_split
import math
import cv2
import os

# save model results
from io import BytesIO
from keras.utils  import img_to_array
from keras.utils  import array_to_img
from keras.utils import save_img

# Configure TensorFlow environment
# In TensorFlow 2.x eager mode is normally already on; the call is kept so the
# notebook behaves the same under a TF1-compatibility setup.
tf.compat.v1.enable_eager_execution()
# Environment variables must be set through os.environ -- the previous
# `TF_ENABLE_ONEDNN_OPTS=0` only created an unused Python variable.
# NOTE(review): TensorFlow most likely reads this flag when it initializes, so
# it may only take effect if exported before `import tensorflow` (e.g. in the
# shell) -- confirm for the installed TF version.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# (libiomp) abort; it hides the conflict rather than fixing it -- confirm.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


# Import and Prepare Training Data 



In [None]:
# Patch geometry for the training data: square tiles of `ps` x `ps` pixels,
# taken with a stride of a quarter tile, so neighbouring tiles overlap.
ps = 64       # patch size in pixels
s = ps // 4   # step (stride) between consecutive patches



## Import Training Image

In [None]:
img_path = "Image_Training.tif"  # path to the training image

# Read all raster bands into a (bands, rows, cols) array.
with rasterio.open(img_path, 'r') as ds:
    arr = ds.read()

# Move the band axis last for the model: (bands, rows, cols) -> (rows, cols, bands).
# First swap bands<->rows, then the (now middle) bands axis with cols.
arr1 = arr.swapaxes(0, 1)
arr2_image = arr1.swapaxes(1, 2)
print(arr2_image.shape)

# Cut the image into overlapping ps x ps x 3 tiles, moving `s` pixels per step.
patch2 = patchify.patchify(arr2_image, (ps, ps, 3), step=s)
print(patch2.shape)

# Flatten the patch grid into one sample axis: (n_samples, ps, ps, 3).
# This assumes the raster has exactly 3 bands (the reshape fails otherwise).
# NOTE(review): pixel values go into the network unscaled (get_preprocessing
# is imported but never used); if scaling is added here, apply the same
# scaling to the evaluation image.
n_samples = patch2.shape[0] * patch2.shape[1]
patch_X = patch2.reshape((n_samples, ps, ps, 3))
print(patch_X.shape)

## Import Labels

In [None]:
target_path = "Label_Training.tif"  # path to the training label raster

# Read the label raster and one-hot encode it into 7 class channels.
# Assumes a single-band raster of class ids -- TODO confirm.
with rasterio.open(target_path, 'r') as ds:
    arr = ds.read().astype('uint8')
    # (1, rows, cols) -> (1, rows, cols, 7); ids outside 0..6 become all-zero rows.
    arr = tf.one_hot(arr, 7)

# Channels-last layout: (1, rows, cols, 7) -> (rows, cols, 1, 7).
# np.swapaxes also converts the TF tensor into a NumPy array.
arr1 = np.swapaxes(arr, 0, 1)
arr2_label = arr1.swapaxes(1, 2)
print(arr2_label.shape)

# Overlapping ps x ps tiles with the same stride as the image patches, so
# label patch i lines up with image patch i (given equal raster sizes).
patch2 = patchify.patchify(arr2_label, (ps, ps, 1, 7), step=s)
print(patch2.shape)

# (patch_row, patch_col, 1, 1, ps, ps, 1, 7) -> (n_samples, ps, ps, 7)
patch_Y = patch2.reshape((patch2.shape[0] * patch2.shape[1], ps, ps, 7))
print(patch_Y.shape)

## Split Training Patches into Training and Validation Sets

In [None]:
# Training / validation split: hold out 10 % of the patches for validation.
# random_state fixes the shuffle so the split (and therefore the training
# run) is reproducible.
# NOTE(review): training patches are cut with a stride of a quarter patch, so
# neighbouring patches share most of their pixels. A random split puts those
# shared pixels into both sets, so validation scores are likely optimistic;
# a spatially separate validation area would give a more honest estimate.
X_tr, X_va, Y_tr, Y_va = train_test_split(
    patch_X, patch_Y, test_size=0.1, random_state=42
)
print(X_tr.shape, Y_tr.shape, X_va.shape, Y_va.shape)

# Prepare U-Net Model 

In [None]:
# Model hyper-parameters
model_name = 'landcover_new.h5'  # file the trained model is saved to
backbone = 'resnet34'  # encoder backbone; segmentation_models provides several
pretrained_weights = None  # no pretrained encoder weights: train from scratch
nclass = 7  # number of classes in the training labels
patchsize = ps
nbands = 3  # number of bands of the training image
activation_func = 'softmax'  # per-pixel probability distribution over classes

# Metrics shown during training to follow progress and model quality.
# The targets are one-hot encoded, so IoU must use the one-hot variant with
# the real class count. The previous `metrics.MeanIoU(8)` expects integer
# class ids and assumed 8 classes, so on one-hot targets / softmax outputs
# it did not measure per-class IoU.
# NOTE(review): OneHotMeanIoU only exists in newer Keras releases -- confirm
# it is available in the installed version.
metrics_list = [
    metrics.CategoricalAccuracy(),
    metrics.CategoricalCrossentropy(),
    metrics.OneHotMeanIoU(num_classes=nclass),
    metrics.Precision(),
    metrics.Recall(),
]

epochs = 200  # upper bound on training epochs; EarlyStopping may end sooner

# Same per-step decay as the legacy `decay=0.01/epochs` argument
# (lr_t = 0.01 / (1 + decay * step)), expressed as a schedule because newer
# Keras optimizers reject the `decay` keyword.
lr_schedule = keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.01,
    decay_steps=1,
    decay_rate=0.01 / epochs,
)
OPT = keras.optimizers.Adam(learning_rate=lr_schedule)

# Samples per gradient step. Previously this was `ps`, which tied the batch
# size to the patch size by accident; 64 matches the old value for ps == 64.
batchsize = 64

# Stop once the validation loss has not improved for 5 epochs and roll back
# to the best weights seen, so the model saved afterwards is the best epoch
# rather than the last one.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=5, restore_best_weights=True
    )
]

# Build the U-Net
model = Unet(backbone,
             encoder_weights=pretrained_weights,
             classes=nclass,
             input_shape=(patchsize, patchsize, nbands),
             activation=activation_func)

model.compile(loss='categorical_crossentropy', metrics=metrics_list, optimizer=OPT)

print("Starting training")

# Train the model; `history` keeps the per-epoch loss and metric values.
history = model.fit(
    x=X_tr,
    y=Y_tr,
    validation_data=(X_va, Y_va),
    batch_size=batchsize,
    epochs=epochs,
    verbose=2,
    callbacks=callbacks,
    shuffle=True,
)

In [None]:
# Write the trained network to disk under `model_name`
# (Keras picks the HDF5 format from the .h5 extension).
model.save(filepath=model_name)

# Evaluation

## Import Evaluation Image

In [None]:
test_img = "Image_Prediction.tif"  # path to the evaluation image

# Raster -> (bands, rows, cols) array
with rasterio.open(test_img, 'r') as ds:
    arr_ti = ds.read()
print(arr_ti.shape)

# Channels-last layout for the model: (bands, rows, cols) -> (rows, cols, bands)
arr1_ti = arr_ti.swapaxes(0, 1)
print(arr1_ti.shape)
arr2_ti = arr1_ti.swapaxes(1, 2)
print(arr2_ti.shape)

patchsize = ps
nbands = 3
# Crop so the image has the same shape as the label data (hard-coded extent).
new_row, new_col = 207, 173
arr2_ti = arr2_ti[:new_row, :new_col, :]
print(arr2_ti.shape)

# Non-overlapping tiles (step == patchsize); pixels beyond the last full tile
# on the bottom/right edges are not covered.
patch1_ti = patchify.patchify(arr2_ti, (patchsize, patchsize, nbands), step=patchsize)
num_patch_row = arr2_ti.shape[0] // patchsize
num_patch_col = arr2_ti.shape[1] // patchsize
num_total = num_patch_row * num_patch_col

# Patch grid -> (n_patches, ps, ps, bands), in row-major patch order.
test_img_patch = patch1_ti.reshape((num_total, patchsize, patchsize, nbands))
print(test_img_patch.shape)

## Import Evaluation Labels

In [None]:
test_label = "Label_Prediction_new_classes.tif"  # path to the evaluation labels

# Read the label raster and one-hot encode it into 7 class channels.
with rasterio.open(test_label, 'r') as ds:
    arr_tl = tf.one_hot(ds.read().astype('uint8'), 7)

# Tensor -> NumPy, then drop the leading (band) axis.
arr_tl = arr_tl.numpy()[0]
print(arr_tl.shape)

patchsize = ps
nbands = 7  # here: number of one-hot class channels, not image bands
new_row, new_col = 207, 173  # same crop as the evaluation image
arr_tl = arr_tl[:new_row, :new_col, :]
print(arr_tl.shape)

# Non-overlapping tiles on the same grid as the evaluation image patches.
patch1_tl = patchify.patchify(arr_tl, (patchsize, patchsize, nbands), step=patchsize)
num_patch_row = arr_tl.shape[0] // patchsize
num_patch_col = arr_tl.shape[1] // patchsize
num_total = num_patch_row * num_patch_col
test_label_patch = patch1_tl.reshape((num_total, patchsize, patchsize, nbands))
print(test_label_patch.shape)

## Import Trained Model 

In [None]:
# Reload the trained network from disk so evaluation does not depend on the
# training cells having run in this session.
weights_name = 'landcover_new.h5'
model_pred = load_model(filepath=weights_name)

In [None]:
# Performance of the reloaded model on the evaluation patches (loss plus the
# compiled metrics), then the per-pixel class probabilities for every patch.
evl_test = model_pred.evaluate(x=test_img_patch, y=test_label_patch)

prediction_label = model_pred.predict(x=test_img_patch)

## Reconstruct Predicted Patches Into A Complete Testing Area

In [None]:
# reconstruct predicted patches into a complete testing area
def reshape_prediction_by_unpatchify(prediction, patchsize, nclass, lab_array):
    """Stitch non-overlapping predicted patches back into one mosaic.

    The patches are assumed to come from a patchify call with
    ``step == patchsize``, flattened in row-major patch order. For such a
    layout the reconstruction is an exact reshape/transpose, so this does
    the work in one vectorized step instead of unpatchify's per-patch loop.

    Parameters
    ----------
    prediction : array-like
        ``num_row * num_col`` patches of shape ``(patchsize, patchsize, nclass)``,
        e.g. an array of shape ``(num_row * num_col, patchsize, patchsize, nclass)``.
    patchsize : int
        Edge length of each square patch.
    nclass : int
        Number of channels per pixel.
    lab_array : array
        Reference array; only ``shape[0]`` and ``shape[1]`` are used, to find
        how many full patches fit per row and per column.

    Returns
    -------
    numpy.ndarray
        Shape ``(num_row * patchsize, num_col * patchsize, nclass)``, same
        dtype as ``prediction``; always a new array.

    Raises
    ------
    ValueError
        If ``prediction`` does not contain exactly ``num_row * num_col``
        patches of the given size.
    """
    num_row = lab_array.shape[0] // patchsize
    num_col = lab_array.shape[1] // patchsize

    prediction = np.asarray(prediction)
    expected_size = num_row * num_col * patchsize * patchsize * nclass
    if prediction.size != expected_size:
        raise ValueError(
            f"expected {num_row * num_col} patches of shape "
            f"({patchsize}, {patchsize}, {nclass}) for a {num_row}x{num_col} "
            f"grid, got an array of shape {prediction.shape}"
        )

    # (row, col, y, x, c) -> (row, y, col, x, c): pixel (y, x) of patch (row, col)
    # ends up at mosaic position (row * patchsize + y, col * patchsize + x).
    tiles = prediction.reshape((num_row, num_col, patchsize, patchsize, nclass))
    mosaic = tiles.swapaxes(1, 2).reshape(
        (num_row * patchsize, num_col * patchsize, nclass)
    )
    # reshape can return a view when a grid dimension is 1; copy so callers
    # never alias the input.
    return mosaic.copy()

In [None]:
# Stitch the per-patch class probabilities into a single mosaic covering the
# evaluated area (grid size taken from the cropped evaluation labels).
prediction = prediction_label
nclass = 7
lab_array = arr_tl
prediction_label_complete = reshape_prediction_by_unpatchify(
    prediction=prediction,
    patchsize=patchsize,
    nclass=nclass,
    lab_array=lab_array,
)
prediction_label_complete.shape

In [None]:
# Turn per-pixel class probabilities into a class index per pixel (argmax
# over the channel axis), then report the highest class index predicted.
prediction_label_final = prediction_label_complete.argmax(axis=2)
print(prediction_label_final.shape)
t = tf.reduce_max(input_tensor=prediction_label_final)
print(t)

## Visualize Prediction Results

In [None]:
# Show the evaluation image, its reference labels and the prediction side by
# side, all cropped to the area covered by the stitched prediction.
# (The previous version labelled the ground truth as 'Testing Image', drew the
# prediction over the 'Testing Label' panel, re-ran predict() needlessly and
# kept only the first 64x64 label patch in `img`.)
pred_rows, pred_cols = prediction_label_final.shape

# Reference labels as class indices for the same area as the prediction.
# Pixels whose label was outside 0..6 were one-hot encoded as all zeros,
# so argmax maps them to class 0.
img = np.argmax(arr_tl, axis=2)[:pred_rows, :pred_cols]

# Min-max scale the RGB bands to [0, 1]: imshow expects RGB floats in that
# range, and the raw band values may exceed it.
rgb = arr2_ti[:pred_rows, :pred_cols, :3].astype('float32')
rgb_min, rgb_max = rgb.min(), rgb.max()
if rgb_max > rgb_min:
    rgb = (rgb - rgb_min) / (rgb_max - rgb_min)
else:
    rgb = np.zeros_like(rgb)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].set_title('Testing Image')
axes[0].imshow(rgb)
# Same color range on both label maps so equal classes get equal colors.
axes[1].set_title('Testing Label')
axes[1].imshow(img, cmap='jet', vmin=0, vmax=nclass - 1)
axes[2].set_title('Prediction on test image')
axes[2].imshow(prediction_label_final, cmap='jet', vmin=0, vmax=nclass - 1)
for ax in axes:
    ax.axis('off')
plt.show()

## Save Prediction Results

In [None]:
# Save the reference labels and the prediction as single-band 8-bit TIFFs of
# raw class indices, so the two files can be compared pixel by pixel.
# The previous keras save_img call rescaled the label values to 0..255
# (scale=True by default), which broke that correspondence, and the
# prediction was handed to cv2 as int64 without an explicit cast.
# NOTE(review): cv2 writes plain TIFFs without the georeferencing of the
# input rasters; use rasterio with the source profile if that is needed.
img_array = np.asarray(img, dtype=np.uint8)
if not cv2.imwrite('Testing_label.tiff', img_array):
    raise OSError('Could not write Testing_label.tiff')
img_pred = np.asarray(prediction_label_final, dtype=np.uint8)
save = 'Prediction.tif'
if not cv2.imwrite(save, img_pred):
    raise OSError(f'Could not write {save}')
print('Prediction and Testing labels saved:'+' '+save)