In [1]:
import os
import math
import numpy as np
import pandas as pd
import PIL
from PIL import ImageOps
from IPython.display import Image, display
from osgeo import gdal, gdal_array, ogr, osr
import geopandas as gpd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.preprocessing.image import load_img
from keras.callbacks import History 

In [2]:
# Root folder holding the cropped GeoTIFF tiles and their PNG counterparts.
input_folder = r"C:\Objects"
input_dir_tif = os.path.join(input_folder, "Cropped")
input_dir = os.path.join(input_folder, "Cropped_PNG")
target_dir = os.path.join(input_folder, "Mask_Cropped_PNG")

# Raw string: in a normal literal "\M" and "\O" are invalid escape sequences,
# which Python reports with a warning and plans to turn into an error.
model_file = r"C:\Model\OilTanks_UNet_256_8_v13.h5"
img_size = (256, 256)
num_classes = 2
batch_size = 1

# Sorted so that image i and mask i refer to the same tile
# (assumes both folders use identical base names -- TODO confirm).
val_input_img_paths = sorted(
    os.path.join(input_dir, fname)
    for fname in os.listdir(input_dir)
    if fname.endswith(".png")
)
val_target_img_paths = sorted(
    os.path.join(target_dir, fname)
    for fname in os.listdir(target_dir)
    if fname.endswith(".png")
)

class OilTanks(keras.utils.Sequence):
    """Keras data sequence that serves (image batch, mask batch) pairs from PNG files.

    Args:
        batch_size: number of samples per batch.
        img_size: ``(height, width)`` every image and mask is resized to.
        input_img_paths: paths of the RGB input images.
        target_img_paths: paths of the mask images, in the same order as
            ``input_img_paths``.
        **kwargs: forwarded to the Keras base class (newer Keras versions
            accept options such as ``workers`` there).
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths, **kwargs):
        # Newer Keras versions expect the base-class constructor to run and
        # warn when it does not; on older versions this call is a no-op.
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        """Return the number of full batches (a trailing partial batch is dropped)."""
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x, y)``.

        ``x`` is float32 with shape ``(batch_size, height, width, 3)`` and
        ``y`` is uint8 with shape ``(batch_size, height, width, 1)``.
        """
        start = idx * self.batch_size
        window = slice(start, start + self.batch_size)

        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(self.input_img_paths[window]):
            x[j] = load_img(path, target_size=self.img_size)

        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="uint8")
        for j, path in enumerate(self.target_img_paths[window]):
            mask = load_img(path, target_size=self.img_size, color_mode="grayscale")
            # Grayscale images load as (height, width); add the channel axis.
            y[j] = np.expand_dims(mask, 2)
        return x, y

# Load the trained model and predict on every PNG tile served by the generator.
model = keras.models.load_model(model_file)
val_gen = OilTanks(batch_size, img_size, val_input_img_paths, val_target_img_paths)
val_preds = model.predict(val_gen)

# Output folders, one per processing stage of PredictResults.
folder_preds = os.path.join(input_folder, "Preds")
folder_extracted = os.path.join(input_folder, "Preds_Extracted")
folder_vector = os.path.join(input_folder, "Preds_Vector")
folder_vector_sort = os.path.join(input_folder, "Preds_Vector_Sort")
for folder in [folder_preds, folder_extracted, folder_vector, folder_vector_sort]:
    # exist_ok: re-running this cell must not fail because the folders
    # were already created by a previous run.
    os.makedirs(folder, exist_ok=True)

class PredictResults:
    """Convert one model prediction into georeferenced raster and vector files.

    Constructing an instance runs the whole pipeline for prediction ``i``:

    1. write channel 0 of the prediction as a GeoTIFF that reuses the extent
       and projection of ``input_tif``;
    2. threshold that raster into a 0/1 mask raster;
    3. polygonize the mask into a shapefile;
    4. keep only the polygons whose ``mask`` attribute equals 1.

    Args:
        input_tif: path of the source GeoTIFF the prediction belongs to.
        val_target_img_paths: mask PNG paths (stored, not used by the pipeline).
        val_preds: predictions indexed as ``val_preds[i][row][col][channel]``.
        folder_preds: output folder for the raw prediction rasters.
        folder_extracted: output folder for the thresholded rasters.
        folder_vector: output folder for the polygonized shapefiles.
        folder_vector_sort: output folder for the filtered shapefiles.
        i: index of the prediction to process.
    """

    def __init__(self, input_tif, val_target_img_paths, val_preds, folder_preds, 
                 folder_extracted, folder_vector, folder_vector_sort, i):
        self.input_tif = input_tif
        self.val_target_img_paths = val_target_img_paths
        self.val_preds = val_preds
        self.folder_preds = folder_preds
        self.folder_extracted = folder_extracted
        self.folder_vector = folder_vector
        self.folder_vector_sort = folder_vector_sort
        self.i = i

        # Output files are named after the part of the TIFF name before the first dot.
        index = os.path.basename(self.input_tif).split('.')[0]
        srs = self.GetEPSG(input_tif)
        file_preds = os.path.join(self.folder_preds, f'{index}_preds.tif')
        file_extracted = os.path.join(self.folder_extracted, f'{index}.tif')
        file_vector = os.path.join(self.folder_vector, f'{index}.shp')
        file_vector_sort = os.path.join(self.folder_vector_sort, f'{index}.shp')

        self.CreateRasterFromArray(self.input_tif, index, self.GetValsFromArr(self.val_preds, self.i), srs, file_preds)
        self.ExtractPreds(file_preds, file_extracted)
        self.VectorizePreds(file_extracted, file_vector, srs)
        self.SortVectorized(file_vector, file_vector_sort)

    def ValTIFPaths(self, val_target_img_paths):
        """Map each PNG path to a ``.tif`` path in the folder named without ``_PNG``.

        Every ``_PNG`` occurrence in the directory part is removed and the file
        name is cut at its first dot before ``.tif`` is appended.
        """
        val_tif_paths = []
        for img in val_target_img_paths:
            folder = os.path.dirname(img).replace('_PNG', '')
            file_index = os.path.basename(img).split('.')[0]
            val_tif_paths.append(os.path.join(folder, f'{file_index}.tif'))
        return val_tif_paths

    def GetImageBoundaryBox(self, image_ds):
        """Return ``[xmin, ymin, xmax, ymax]`` of an opened GDAL dataset.

        Returns ``None`` when the dataset carries the identity geotransform
        ``(0, 1, 0, 0, 0, 1)``, i.e. when it has no usable georeferencing.
        """
        image_geotransform = image_ds.GetGeoTransform()
        if image_geotransform == (0.0, 1.0, 0.0, 0.0, 0.0, 1.0):
            return None

        # raster size in pixels
        xsize = image_ds.RasterXSize
        ysize = image_ds.RasterYSize
        # corner coordinates derived from the affine geotransform
        xmin = image_geotransform[0]
        ymin = image_geotransform[3] + xsize * image_geotransform[4] + ysize * image_geotransform[5]
        xmax = image_geotransform[0] + xsize * image_geotransform[1] + ysize * image_geotransform[2]
        ymax = image_geotransform[3]
        return [xmin, ymin, xmax, ymax]

    def GetEPSG(self, image):
        """Return the ``osr.SpatialReference`` built from the raster's projection WKT."""
        ds = gdal.Open(image)
        # Only the projection metadata is needed; the pixel data is not read.
        srs = osr.SpatialReference(wkt=ds.GetProjection())
        ds = None  # release the GDAL dataset handle
        return srs

    def CreateRasterFromArray(self, input_tif, index, arr, srs, filename):
        """Write the 2-D array ``arr`` as a single-band Float32 GeoTIFF.

        The output covers the same bounding box as ``input_tif``; the pixel
        size is derived from that box and the shape of ``arr``.

        Args:
            input_tif: georeferenced raster whose extent is reused.
            index: unused; kept so existing callers keep working.
            arr: 2-D array of values to write.
            srs: spatial reference assigned to the output.
            filename: path of the GeoTIFF to create.

        Raises:
            ValueError: if ``input_tif`` has no georeferencing.
        """
        src_ds = gdal.Open(input_tif, 0)
        bbox = self.GetImageBoundaryBox(src_ds)
        src_ds = None  # release the GDAL dataset handle
        if bbox is None:
            raise ValueError(f'{input_tif} has no georeferencing; cannot place the prediction raster')
        xmin, ymin, xmax, ymax = bbox
        nrows, ncols = np.shape(arr)

        xres = (xmax - xmin) / float(ncols)
        yres = (ymax - ymin) / float(nrows)
        # Origin at the top-left corner, rows running downwards.
        geotransform = (xmin, xres, 0, ymax, 0, -yres)

        output_raster = gdal.GetDriverByName('GTiff').Create(filename, ncols, nrows, 1, gdal.GDT_Float32)
        output_raster.SetGeoTransform(geotransform)
        output_raster.SetProjection(srs.ExportToWkt())
        output_raster.GetRasterBand(1).WriteArray(arr)
        # Flush and close explicitly so the file is complete before it is reopened.
        output_raster.FlushCache()
        output_raster = None

    def GetValsFromArr(self, val_preds, i):
        """Return channel 0 of prediction ``i`` as a new 2-D array.

        NOTE(review): which class channel 0 represents depends on how the
        model was trained -- confirm against the training setup.
        """
        # One vectorised slice replaces a per-pixel Python loop.
        return np.array(np.asarray(val_preds[i])[:, :, 0])

    def ExtractPreds(self, file, out_file, threshold=0.85):
        """Threshold raster ``file`` and save the 0/1 result to ``out_file``.

        Pixels strictly greater than ``threshold`` become 1.0, all others 0.0.
        The georeferencing is copied from ``file``.
        """
        ds = gdal.Open(file)
        arr = ds.ReadAsArray()
        data = arr > threshold
        # Passing ds as prototype copies its geotransform and projection.
        gdal_array.SaveArray(data.astype("float32"), out_file, "GTIFF", ds)
        ds = None  # release the GDAL dataset handle

    def VectorizePreds(self, file, out_file, srs):
        """Polygonize band 1 of raster ``file`` into the shapefile ``out_file``.

        Each polygon stores the pixel value it was built from in the integer
        field ``mask``.
        """
        src_ds = gdal.Open(file)
        srcband = src_ds.GetRasterBand(1)
        dst_layername = 'mask'
        driver = ogr.GetDriverByName("ESRI Shapefile")
        dst_ds = driver.CreateDataSource(out_file)
        dst_layer = dst_ds.CreateLayer(dst_layername, srs=srs)
        fld = ogr.FieldDefn("mask", ogr.OFTInteger)
        dst_layer.CreateField(fld)
        dst_field = dst_layer.GetLayerDefn().GetFieldIndex("mask")
        # No mask band (None): every pixel takes part in the polygonization.
        gdal.Polygonize(srcband, None, dst_layer, dst_field, [], callback=None)
        # Dropping the references flushes and closes both datasets.
        src_ds = None
        dst_ds = None

    def SortVectorized(self, file, out_file):
        """Copy the polygons with ``mask == 1`` from ``file`` to ``out_file``.

        Nothing is written when no polygon matches.
        """
        gdf = gpd.read_file(file)
        gdf_sel = gdf.query('mask==1')
        if not gdf_sel.empty:
            gdf_sel.to_file(out_file)

# Source GeoTIFFs, sorted so that position i matches prediction i.
input_tifs = sorted(
    os.path.join(input_dir_tif, fname)
    for fname in os.listdir(input_dir_tif)
    if fname.endswith(".tif")
)
# Run the full raster -> mask -> vector pipeline once per prediction.
for i, _pred in enumerate(val_preds):
    input_tif = input_tifs[i]
    PredictResults(
        input_tif,
        val_target_img_paths,
        val_preds,
        folder_preds,
        folder_extracted,
        folder_vector,
        folder_vector_sort,
        i,
    )

