In [None]:
from osgeo import gdal
import os
import sys
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import CustomObjectScope, get_custom_objects
import segmentation_models

In [None]:
# When the model uses a custom initializer, loss or metric, register it as a custom object before loading the model (https://github.com/keras-team/keras/issues/3867)
class CustomInitializer:
    """Callable that forwards weight initialisation to ``my_init``.

    Registered under the name 'my_init' as a Keras custom object so that a
    saved model referring to that initializer can be deserialised.
    """

    def __call__(self, shape, dtype=None):
        # NOTE(review): `my_init` is not defined in the visible part of this
        # notebook; calling an instance raises NameError unless it is provided
        # elsewhere — confirm.
        initial_values = my_init(shape, dtype=dtype)
        return initial_values

# Register every custom name the saved model may reference, in a single update.
get_custom_objects().update({
    'my_init': CustomInitializer,
    "binary_focal_loss": segmentation_models.losses.BinaryFocalLoss(),
    "iou_score": segmentation_models.metrics.IOUScore(),
})

In [None]:
### Read a raster file and return its data together with its georeferencing properties
def ReadData_geoinf(path):
    """
    Open a raster with GDAL and return its pixel data plus georeferencing info.

    param path: img path

    return: dict with the keys
        'data'                         - result of ``ReadAsArray()`` on the dataset
        'geoTransform'                 - GDAL geotransform tuple
        'projection'                   - string returned by ``GetProjection()``
        'minX', 'maxX', 'minY', 'maxY' - extent computed from the geotransform
        'Xsize', 'Ysize'               - raster width / height in pixels
        'resolution'                   - geotransform element 1 (pixel width)

    Terminates through ``sys.exit`` when GDAL cannot open the file.
    """
    dataset = gdal.Open(path, 0)
    if dataset is None:
        sys.exit('Could not open {0}.'.format(path))

    transform = dataset.GetGeoTransform()
    projection = dataset.GetProjection()
    width, height = dataset.RasterXSize, dataset.RasterYSize

    # Elements 0 and 3 are the X / Y origin, elements 1 and 5 the pixel
    # width and height.
    # NOTE(review): the extent ignores the rotation terms (elements 2 and 4),
    # so it presumably assumes a north-up raster — confirm.
    origin_x, pixel_width = transform[0], transform[1]
    origin_y, pixel_height = transform[3], transform[5]

    return {
        'data': dataset.ReadAsArray(),
        'geoTransform': transform,
        'projection': projection,
        'minX': origin_x,
        'maxX': origin_x + pixel_width * width,
        'minY': origin_y + pixel_height * height,
        'maxY': origin_y,
        'Xsize': width,
        'Ysize': height,
        'resolution': pixel_width,
    }


In [None]:
def cut_array(data, row, col, row_buffer, col_buffer):
    """
    Cut an array into overlapping rectangular chips along its first two axes.

    Chips are taken on a regular grid with stride ``row - row_buffer`` /
    ``col - col_buffer``. When the grid does not end exactly on the data edge,
    one extra chip is added that is shifted back so it ends on that edge.
    Any axes after the first two (e.g. a trailing band axis) are kept whole.

    Parameters:
    - data: numpy array with at least 2 dimensions (rows, cols, ...).
    - row: Height of each chip.
    - col: Width of each chip.
    - row_buffer: Overlap between consecutive chips in the vertical direction.
    - col_buffer: Overlap between consecutive chips in the horizontal direction.

    Returns:
    - A numpy array of shape (n_chips, row, col, ...). Chips are ordered with
      the row position varying fastest: all row positions of the first column
      position, then all row positions of the next column position, and so on.

    Raises:
    - ValueError: if a buffer is not smaller than the chip size, or the data is
      smaller than a single chip.
    """
    row_stride = row - row_buffer
    col_stride = col - col_buffer
    if row_stride <= 0 or col_stride <= 0:
        raise ValueError(
            'Buffers must be smaller than the chip size: got chip ({0}, {1}) and buffers ({2}, {3}).'.format(
                row, col, row_buffer, col_buffer))

    data_row, data_col = data.shape[0], data.shape[1]
    if data_row < row or data_col < col:
        raise ValueError(
            'Data of shape ({0}, {1}) is smaller than one chip ({2}, {3}).'.format(
                data_row, data_col, row, col))

    def _chip_ends(extent, size, stride):
        # End (exclusive) index of every chip along one axis.
        ends = list(range(size, extent + 1, stride))
        if (extent - size) % stride != 0:
            # The regular grid stops short of the edge: add a chip flush with it.
            ends.append(extent)
        return ends

    row_list = _chip_ends(data_row, row, row_stride)
    col_list = _chip_ends(data_col, col, col_stride)

    # Column positions form the outer loop; chip_index_finder relies on this order.
    return np.array([data[i - row:i, j - col:j] for j in col_list for i in row_list])

In [None]:
def chip_index_finder(index, row, col, data_row, data_col, row_buffer, col_buffer):
    """
    Locate a chip inside the full array from its flat index.

    The flat index is interpreted with the row position varying fastest
    (index = col_index * row_num + row_index).

    Parameters:
    - index: The index of the chip in a flattened array of all chips.
    - row: Height of each chip.
    - col: Width of each chip.
    - data_row: Total number of rows in the original data array.
    - data_col: Total number of columns in the original data array.
    - row_buffer: Overlap between consecutive chips in the vertical direction.
    - col_buffer: Overlap between consecutive chips in the horizontal direction.

    Returns:
    - A tuple (row_start, row_end, col_start, col_end) of array indices, with
      the end indices exclusive.
    """

    def _chip_count(extent, size, stride):
        # A remainder means one extra chip, flush with the edge, is needed.
        span = extent - size
        extra = 2 if span % stride > 0 else 1
        return int(span / stride) + extra

    def _chip_end(position, count, extent, size, stride):
        # The last chip along an axis always ends on the data edge.
        if position == count - 1:
            return extent
        return size + position * stride

    row_stride = row - row_buffer
    col_stride = col - col_buffer

    row_num = _chip_count(data_row, row, row_stride)
    col_num = _chip_count(data_col, col, col_stride)

    col_index, row_index = divmod(index, row_num)

    row_end = _chip_end(row_index, row_num, data_row, row, row_stride)
    col_end = _chip_end(col_index, col_num, data_col, col, col_stride)

    return (row_end - row, row_end, col_end - col, col_end)


In [None]:
def mosaic_chips(data_array, index_list, weight_array, data_row, data_col, row, col, row_buffer, col_buffer):
    """
    Blend weighted chips back into one (data_row, data_col) array.

    Parameters:
    - data_array: chips, indexed as data_array[k, :, :].
    - index_list: flat chip index of each entry of data_array, as understood
      by chip_index_finder.
    - weight_array: per-chip weights, indexed as weight_array[k, :, :].
    - data_row, data_col: size of the output mosaic.
    - row, col: chip height and width.
    - row_buffer, col_buffer: overlap between neighbouring chips.

    Returns:
    - A (data_row, data_col) array holding the sum of weight * chip over all
      chips, each added at its own position.
    """
    mosaic = np.zeros((data_row, data_col))
    for position, chip_index in enumerate(index_list):
        r0, r1, c0, c1 = chip_index_finder(chip_index, row, col, data_row, data_col, row_buffer, col_buffer)
        window = (slice(r0, r1), slice(c0, c1))
        weighted_chip = weight_array[position, :, :] * data_array[position, :, :]
        # Accumulate: overlapping chips contribute to the same pixels.
        mosaic[window] = mosaic[window] + weighted_chip
    return mosaic

In [None]:
def weights_generator(weight_type, data_row, data_col, row, col, row_buffer, col_buffer):
    """
    Build one (row, col) weight mask per chip for blending chips into a mosaic.

    The chip layout is the one produced by cut_array / chip_index_finder: the
    flat chip index is col_index * row_num + row_index, and when the regular
    grid does not end on the data edge the last chip along that axis is shifted
    back so that it ends exactly on the edge.

    Parameters:
    - weight_type: one of
        'no_buffer'             - weight 1 everywhere; the shifted last chips
                                  only keep the pixels not covered before
                                  (uses data size modulo chip size, i.e. the
                                  buffers are not taken into account).
        'buffer_average'        - 1 / (number of chips covering the pixel).
        'buffer_gauss_average'  - Gaussian template divided by the sum of all
                                  templates covering the pixel.
        'buffer_linear_average' - template rising linearly from 0 on the chip
                                  edge to 1 at `row_buffer` pixels inside,
                                  divided by the summed templates; requires
                                  row_buffer == col_buffer.
        'half_buffer'           - binary masks: each chip keeps its centre and
                                  drops half of the buffer on every side that
                                  faces a neighbouring chip.
    - data_row, data_col: size of the full array.
    - row, col: chip height and width.
    - row_buffer, col_buffer: overlap between neighbouring chips.

    Returns:
    - numpy array of shape (row_num * col_num, row, col).

    Raises:
    - ValueError: if weight_type is not one of the names listed above.
    """
    supported_types = ('no_buffer', 'buffer_average', 'buffer_gauss_average',
                       'buffer_linear_average', 'half_buffer')
    if weight_type not in supported_types:
        raise ValueError('Unknown weight_type {0!r}; expected one of {1}.'.format(weight_type, supported_types))

    def _chip_count(extent, size, stride):
        # Same counting rule as chip_index_finder: a remainder adds one chip.
        if (extent - size) % stride > 0:
            return int((extent - size) / stride) + 2
        return int((extent - size) / stride) + 1

    row_num = _chip_count(data_row, row, row - row_buffer)
    col_num = _chip_count(data_col, col, col - col_buffer)

    def _normalise_by_coverage(template_weights):
        # Paste the template at every chip position, then divide each chip's
        # template by the accumulated sum so that weights add up to 1 per pixel.
        coverage = np.zeros((data_row, data_col))
        windows = []
        for i in range(col_num * row_num):
            coors = chip_index_finder(i, row, col, data_row, data_col, row_buffer, col_buffer)
            window = (slice(coors[0], coors[1]), slice(coors[2], coors[3]))
            coverage[window] = coverage[window] + template_weights
            windows.append(window)
        return np.array([template_weights / coverage[window] for window in windows])

    if weight_type == 'no_buffer':
        weight_array = [np.ones((row, col)) for _ in range(row_num * col_num)]
        # "right" refers to the last chip along the row axis, "down" to the
        # last chip along the column axis.
        right_margin = np.zeros((row, col))
        right_cut = (row if (data_row % row == 0) else data_row % row)
        right_margin[row - right_cut:, :] = 1
        down_margin = np.zeros((row, col))
        down_cut = (col if (data_col % col == 0) else data_col % col)
        down_margin[:, col - down_cut:] = 1
        corner_margin = right_margin * down_margin
        for i in range(row_num - 1, row_num * col_num, row_num):
            weight_array[i] = right_margin
        for i in range(row_num * col_num - row_num, row_num * col_num):
            weight_array[i] = down_margin
        weight_array[row_num * col_num - 1] = corner_margin
        weight_array = np.array(weight_array)

    elif weight_type == 'buffer_average':
        # A template of ones makes the coverage a plain chip count per pixel.
        weight_array = _normalise_by_coverage(np.ones((row, col)))

    elif weight_type == 'buffer_gauss_average':
        def gaus2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
            return 1. / (2. * np.pi * sx * sy) * np.exp(
                -((x - mx) ** 2. / (2. * sx ** 2.) + (y - my) ** 2. / (2. * sy ** 2.)))

        x = np.linspace(-5, 5, row)
        y = np.linspace(-5, 5, col)
        # 'ij' indexing yields (row, col) grids; the default 'xy' indexing
        # yields (col, row), which only fits square chips.
        x, y = np.meshgrid(x, y, indexing='ij')
        weight_array = _normalise_by_coverage(gaus2d(x, y))

    elif weight_type == 'buffer_linear_average':
        assert row_buffer == col_buffer
        template_weights = np.ones((row, col))
        # Concentric rings: ring i (counted from the chip edge) gets i / row_buffer.
        # NOTE(review): ring 0 has weight 0, so pixels on the outer border of
        # the full array (covered only by ring 0) are normalised as 0 / 0 and
        # come out as NaN — confirm whether that is intended.
        for i in range(row_buffer):
            pixel_int = 1. / row_buffer
            template_weights[i, i:(col - i)] = pixel_int * i
            template_weights[i:(row - i), i] = pixel_int * i
            template_weights[row - i - 1, i:(col - i)] = pixel_int * i
            template_weights[i:(row - i), col - i - 1] = pixel_int * i
        weight_array = _normalise_by_coverage(template_weights)

    elif weight_type == 'half_buffer':
        # Naming used below: "up"/"down" are the first/last chips along the
        # COLUMN axis, "left"/"right" the first/last chips along the ROW axis.
        # NOTE(review): the index arithmetic was checked for grids with at
        # least three chips per axis; smaller grids are not verified.
        half_row_buffer = int(row_buffer / 2)
        half_col_buffer = int(col_buffer / 2)
        template_weights = np.zeros((row, col))
        template_weights[half_row_buffer:(row - half_row_buffer), half_col_buffer:(col - half_col_buffer)] = 1
        weight_array = [template_weights for _ in range(row_num * col_num)]

        up_border = np.zeros((row, col))
        up_border[half_row_buffer:(row - half_row_buffer), 0:(col - half_col_buffer)] = 1
        down_border = np.zeros((row, col))
        down_border[half_row_buffer:(row - half_row_buffer), half_col_buffer:] = 1
        left_border = np.zeros((row, col))
        left_border[0:(row - half_row_buffer), half_col_buffer:(col - half_col_buffer)] = 1
        right_border = np.zeros((row, col))
        right_border[half_row_buffer:, half_col_buffer:(col - half_col_buffer)] = 1

        up_left_border = np.zeros((row, col))
        up_left_border[0:(row - half_row_buffer), 0:(col - half_col_buffer)] = 1

        up_right_border = np.zeros((row, col))
        up_right_border[half_row_buffer:, 0:(col - half_col_buffer)] = 1
        down_left_border = np.zeros((row, col))
        down_left_border[0:(row - half_row_buffer), half_col_buffer:] = 1
        down_right_border = np.zeros((row, col))
        down_right_border[half_row_buffer:, half_col_buffer:] = 1

        for i in range(0, row_num):
            weight_array[i] = up_border
        for i in range(0, row_num * col_num - 1, row_num):
            weight_array[i] = left_border

        # Number of new rows / columns contributed by the last chip along each
        # axis: a full stride when the grid fits exactly, else the remainder.
        right_cut = ((row - row_buffer) if ((data_row - row) % (row - row_buffer) == 0) else (data_row - row) % (
                    row - row_buffer))
        down_cut = ((col - col_buffer) if ((data_col - col) % (col - col_buffer) == 0) else (data_col - col) % (
                    col - col_buffer))

        if (right_cut == (row - row_buffer)):
            # Rows fit exactly.
            for i in range(row_num - 1, row_num * col_num - 1, row_num):
                weight_array[i] = right_border
            if (down_cut == (col - col_buffer)):
                # Columns fit exactly as well.
                for i in range(row_num * col_num - row_num, row_num * col_num - 1):
                    weight_array[i] = down_border
                weight_array[0] = up_left_border
                weight_array[row_num - 1] = up_right_border
                weight_array[row_num * (col_num - 1)] = down_left_border
                weight_array[row_num * col_num - 1] = down_right_border
            else:
                # Last column of chips is shifted back: it keeps only its new
                # columns, the column of chips before it keeps up to its edge.
                down_margin = np.zeros((row, col))
                down_margin[half_row_buffer:(row - half_row_buffer), col - down_cut:] = 1
                left_down_margin = np.zeros((row, col))
                left_down_margin[0:(row - half_row_buffer), col - down_cut:] = 1
                right_down_margin_2 = np.zeros((row, col))
                right_down_margin_2[half_row_buffer:(row), col - down_cut:] = 1
                for i in range(row_num * col_num - row_num, row_num * col_num - 1):
                    weight_array[i] = down_margin
                    weight_array[i - row_num] = down_border
                weight_array[0] = up_left_border
                weight_array[row_num - 1] = up_right_border
                weight_array[row_num * (col_num - 2)] = down_left_border
                weight_array[row_num * (col_num - 1)] = left_down_margin
                weight_array[row_num * col_num - row_num - 1] = down_right_border
                # Fix: with rows fitting exactly, the chip that must extend to
                # the bottom edge is the LAST chip (index - 1), not index - 2.
                weight_array[row_num * col_num - 1] = right_down_margin_2
        else:
            # Last row of chips is shifted back: it keeps only its new rows,
            # the row of chips before it keeps up to its edge.
            right_margin = np.zeros((row, col))
            right_margin[row - right_cut:, half_col_buffer:(col - half_col_buffer)] = 1
            right_up_margin = np.zeros((row, col))
            right_up_margin[row - right_cut:, 0:(col - half_col_buffer)] = 1
            right_down_margin = np.zeros((row, col))
            right_down_margin[row - right_cut:, half_col_buffer:col] = 1
            for i in range(row_num - 1, row_num * col_num - 1, row_num):
                weight_array[i] = right_margin
                weight_array[i - 1] = right_border

            if (down_cut == (col - col_buffer)):
                # Columns fit exactly.
                for i in range(row_num * col_num - row_num, row_num * col_num - 1):
                    weight_array[i] = down_border
                weight_array[0] = up_left_border
                weight_array[row_num - 2] = up_right_border
                weight_array[row_num - 1] = right_up_margin
                # Fix: the mask reaching the right-hand edge belongs to the
                # LAST chip; the last-row chip of the previous column keeps
                # the right_margin assigned in the loop above.
                weight_array[row_num * col_num - 1] = right_down_margin
                weight_array[row_num * (col_num - 1)] = down_left_border
                weight_array[row_num * col_num - 2] = down_right_border
            else:
                # Neither rows nor columns fit exactly.
                down_margin = np.zeros((row, col))
                down_margin[half_row_buffer:(row - half_row_buffer), col - down_cut:] = 1
                left_down_margin = np.zeros((row, col))
                left_down_margin[0:(row - half_row_buffer), col - down_cut:] = 1
                right_down_margin_2 = np.zeros((row, col))
                right_down_margin_2[half_row_buffer:(row), col - down_cut:] = 1
                for i in range(row_num * col_num - row_num, row_num * col_num - 1):
                    weight_array[i] = down_margin
                    weight_array[i - row_num] = down_border
                weight_array[0] = up_left_border
                weight_array[row_num - 2] = up_right_border
                weight_array[row_num - 1] = right_up_margin
                weight_array[row_num * col_num - row_num - 1] = right_down_margin
                weight_array[row_num * (col_num - 2)] = down_left_border
                weight_array[row_num * (col_num - 1)] = left_down_margin
                weight_array[row_num * col_num - row_num - 2] = down_right_border
                corner_margin = np.zeros((row, col))
                corner_margin[row - right_cut:, col - down_cut:] = 1
                weight_array[row_num * col_num - 2] = right_down_margin_2
                weight_array[row_num * col_num - 1] = corner_margin
        weight_array = np.array(weight_array)

    return weight_array

In [None]:
# Write an array to a new raster that copies the template file's size, geotransform and projection
def output_same(data, template_file_name, output_name, gdal_type):
    """
    Write `data` as a single-band raster georeferenced like a template raster.

    Parameters:
    - data: array written to band 1 of the new raster.
    - template_file_name: existing raster whose size, geotransform and
      projection are copied to the output.
    - output_name: path of the raster to create. The GDAL driver is chosen from
      the extension: 'tif' -> GTiff, 'rst' -> RST; any other extension is
      passed to GDAL as the driver name, upper-cased.
    - gdal_type: GDAL data type of the output band (e.g. gdal.GDT_Float32).

    Raises:
    - RuntimeError: if the template cannot be opened, no driver matches the
      extension, or the output cannot be created.
    """
    gtif = gdal.Open(template_file_name)
    if gtif is None:
        raise RuntimeError('Could not open template raster {0}.'.format(template_file_name))

    # Output dimensions are taken from the template, not from `data`.
    rows = gtif.RasterYSize
    cols = gtif.RasterXSize

    output_format = output_name.split('.')[-1].upper()
    if (output_format == 'TIF'):
        output_format = 'GTIFF'
    elif (output_format == 'RST'):
        output_format = 'rst'
    driver = gdal.GetDriverByName(output_format)
    if driver is None:
        raise RuntimeError('No GDAL driver found for output format {0}.'.format(output_format))

    outDs = driver.Create(output_name, cols, rows, 1, gdal_type)
    if outDs is None:
        raise RuntimeError('Could not create output raster {0}.'.format(output_name))
    outBand = outDs.GetRasterBand(1)
    outBand.WriteArray(data)
    # georeference the image and set the projection
    outDs.SetGeoTransform(gtif.GetGeoTransform())
    outDs.SetProjection(gtif.GetProjection())
    outDs.FlushCache()
    outBand.SetNoDataValue(-99)
    # Dropping the dataset reference is what closes and finalises the file.
    del outDs

### Prediction

In [None]:
# Scene to predict; used as the key into `scene_dict` and in the output file name.
predict_ID = '116060'
# Project name: names the results sub-folder and is the stem of the model / output files.
project_name = 'test'
# NOTE(review): the two paths below are absolute and machine-specific — adjust before running elsewhere.
res_folder = '/home/vishal/DLrepo/results/'
# Maps a scene ID to the folder holding that scene's preprocessed rasters.
scene_dict = {'116060':'/workspace/_libs/116060_preprocessed'}

In [None]:
# Folder with the preprocessed rasters of the selected scene.
sourcefolder = scene_dict[predict_ID]

# Results of this project live in <res_folder>/<project_name>; the trained
# model is expected there as <project_name>.h5.
project_folder = os.path.join(res_folder, project_name)
model_path = os.path.join(project_folder, project_name + ".h5")

In [None]:
# Load the trained Keras model from `model_path`. The custom initializer, loss and
# metric names are registered through get_custom_objects() in an earlier cell, which
# therefore has to run before this one.
model_final = load_model(model_path)

In [None]:
# Collect the band rasters of the scene: every file in `sourcefolder` whose name
# ends with 'rst' and contains "_B". Sorting fixes the band (channel) order.
# NOTE(review): an earlier comment mentioned "bands 2~7", but the filter itself does
# not restrict band numbers — confirm the folder only holds the intended bands.
folderlist = sorted(
    os.path.join(sourcefolder, file_name)
    for file_name in os.listdir(sourcefolder)
    if file_name.endswith('rst') and "_B" in file_name
)
if not folderlist:
    raise FileNotFoundError('No band rasters ("*_B*rst") found in {0}.'.format(sourcefolder))

# Read every band. `ds` is deliberately left bound to the LAST band's info dict:
# later cells read ds['Ysize'] / ds['Xsize'] to size the mosaic.
band_arrays = []
for band_path in folderlist:
    ds = ReadData_geoinf(band_path)
    band_arrays.append(ds['data'])

# Stack the bands along a new last axis (one channel per band file).
predict_data_source = np.stack(band_arrays, axis=-1)

In [None]:
# Cut the band stack into 256 x 256 chips that overlap their neighbours by 128 pixels.
predict_chips = cut_array(
    predict_data_source,
    row=256,
    col=256,
    row_buffer=128,
    col_buffer=128,
)

In [None]:
# Run the model on all chips in one call. The result is indexed further down as
# predict_label[:, :, :, 0], i.e. it is used as a 4-D array whose first axis
# follows the chip order.
# NOTE(review): batch size is left at the Keras default — confirm memory suffices for large scenes.
predict_label = model_final.predict(predict_chips)

In [None]:
# One weight mask per chip; chip size and overlap must match the values used when
# cutting the chips. The mosaic size comes from the last band read into `ds`.
mosaic_weights = weights_generator(
    weight_type='half_buffer',
    data_row=ds['Ysize'],
    data_col=ds['Xsize'],
    row=256,
    col=256,
    row_buffer=128,
    col_buffer=128,
)

In [None]:
# Blend channel 0 of every chip prediction into one full-size array, using the
# chips' own order as their flat indices.
mosaic_res = mosaic_chips(
    data_array=predict_label[:, :, :, 0],
    index_list=range(len(predict_label)),
    weight_array=mosaic_weights,
    data_row=ds['Ysize'],
    data_col=ds['Xsize'],
    row=256,
    col=256,
    row_buffer=128,
    col_buffer=128,
)

In [None]:
# Output raster: <project_folder>/<project_name>_prob_<predict_ID>.rst
prediction_path = os.path.join(project_folder, project_name + '_prob_' + predict_ID + '.rst')
# The entries of `folderlist` were already joined with `sourcefolder` when the list
# was built, so the first band is passed as-is as the georeferencing template;
# joining it with `sourcefolder` again would duplicate the folder for relative paths.
output_same(mosaic_res, folderlist[0], prediction_path, gdal.GDT_Float32)