# In [6]:
import numpy as np
from osgeo import gdal
import cv2
import math


def read_img(filename):
    """
    Read a raster image with GDAL and return its georeferencing and pixels.

    Args:
        filename (str): Path to the image file.

    Returns:
        tuple: (projection, geotransform, image data, height, width). The
        image data is the result of ``ReadAsArray`` over the full raster
        extent.

    Raises:
        OSError: If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    # gdal.Open can hand back None instead of raising; fail here with a
    # clear message rather than with an AttributeError on the next access.
    if dataset is None:
        raise OSError("GDAL could not open raster file: {}".format(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # Dropping the last reference releases the GDAL dataset handle.
    del dataset
    return im_proj, im_geotrans, im_data, im_height, im_width


def write_img(filename, im_proj, im_geotrans, im_data):
    """
    Save an array as a GeoTIFF with the given projection and geotransform.

    The GDAL pixel type is chosen from ``im_data.dtype``:

    * ``int16``            -> ``GDT_Int16``
    * ``uint8`` / ``int8`` -> ``GDT_Byte``
    * ``uint16``           -> ``GDT_UInt16``
    * anything else        -> ``GDT_Float32``

    Args:
        filename (str): Output file path.
        im_proj (str): Projection information.
        im_geotrans (tuple): Geotransform information.
        im_data (numpy.ndarray): Image data to save, either 2-D
            (height, width) or 3-D (bands, height, width).

    Raises:
        OSError: If the GTiff driver cannot create ``filename``.
    """
    dtype_name = im_data.dtype.name
    # Signed 16-bit must be tested first: the substring check below would
    # otherwise match 'int16' too and store negative values as unsigned.
    if dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    elif 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    # Create can return None (e.g. unwritable path); report that directly.
    if dataset is None:
        raise OSError("GDAL could not create raster file: {}".format(filename))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # GDAL band indices are 1-based.
        for band_index, band in enumerate(im_data, start=1):
            dataset.GetRasterBand(band_index).WriteArray(band)
    # Dropping the last reference releases the GDAL dataset handle.
    del dataset


def truncated_linear_stretch(image, truncated_value):
    """
    Perform a truncated linear stretch on the image to enhance contrast.

    For each band, the values between the ``truncated_value`` and
    ``100 - truncated_value`` percentiles are linearly mapped onto the
    band's own [min, max] range; values outside are clamped to that range.

    Args:
        image (numpy.ndarray): Input image, either a single 2-D band or a
            3-D stack iterated band-by-band along the first axis.
        truncated_value (float): Percentage to truncate from both ends.

    Returns:
        numpy.ndarray: Stretched image as float32, same shape as ``image``.
    """
    def gray_process(gray):
        lower = np.percentile(gray, truncated_value)
        upper = np.percentile(gray, 100 - truncated_value)
        min_out = np.min(gray)
        max_out = np.max(gray)
        # A constant (or near-constant) band has no range to stretch;
        # dividing by zero would fill the band with NaN/inf, so return the
        # values unchanged instead.
        if upper == lower:
            return np.array(gray, dtype=np.float64)
        stretched = (gray - lower) / (upper - lower) * (max_out - min_out) + min_out
        return np.clip(stretched, min_out, max_out)

    if len(image.shape) == 3:
        image_stretch = np.array([gray_process(band) for band in image])
    else:
        image_stretch = gray_process(image)
    return image_stretch.astype(np.float32)


def calculate_average_patch_area(img, num_clusters=5):
    """
    Calculate the average patch area of an image using k-means clustering.

    The image is averaged along its first axis, the resulting values are
    clustered with k-means, and every 4-connected component of every
    cluster is treated as one patch.

    Args:
        img (numpy.ndarray): Input image; averaged along axis 0 before
            clustering.
        num_clusters (int): Number of k-means clusters. Defaults to 5.

    Returns:
        float: Mean pixel count over all patches.
    """
    band_mean = np.mean(img, 0)
    samples = band_mean.reshape((-1, 1)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 2.0)
    _, labels, _ = cv2.kmeans(
        samples, num_clusters, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )
    label_img = labels.reshape(band_mean.shape)

    patch_areas = []
    for cluster_id in range(num_clusters):
        mask = (label_img == cluster_id).astype(np.uint8)
        _, _, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=4)
        # Row 0 of ``stats`` describes the background (every pixel outside
        # this cluster), which is not a patch; only rows 1.. are components.
        patch_areas.extend(stats[1:, cv2.CC_STAT_AREA])
    return np.mean(patch_areas)


def gaussian_downscale(image, av_area):
    """
    Smooth the image by descending and re-ascending a Gaussian pyramid.

    The image is blurred and decimated by 2 for ``D`` levels, where
    ``D = int(log2(rows / sqrt(av_area)))``, and then expanded back level
    by level (zero insertion followed by a Gaussian filter), so the result
    has the same spatial size as the input.

    Args:
        image (numpy.ndarray): Input image indexed as (bands, rows, cols).
        av_area (float): Average patch area.

    Returns:
        numpy.ndarray: Image of shape (bands, rows, cols). When ``D`` is
        not positive no pyramid level is built and the input values are
        returned unchanged.
    """
    def creat_gauss_kernel(kernel_size, sigma=1):
        # sigma == 0 means "derive sigma from the kernel size".
        if sigma == 0:
            sigma = ((kernel_size - 1) * 0.5 - 1) * 0.3 + 0.8
        sigma_3 = 3 * sigma
        axis = np.linspace(-sigma_3, sigma_3, kernel_size)
        x, y = np.meshgrid(axis, axis)
        gauss = 1 / (2 * np.pi * sigma ** 2) * np.exp(- (x ** 2 + y ** 2) / (2 * sigma ** 2))
        # Normalise so the kernel sums to 1.
        return (1 / gauss.sum()) * gauss

    levels = int(np.log2(image.shape[1] / np.sqrt(av_area)))
    kernel_size = 11

    # OpenCV filters expect channels last.
    pyramid = image.transpose((1, 2, 0))
    # Remember the size of every level so the expansion can restore odd
    # dimensions exactly.
    rows = [image.shape[1]]
    cols = [image.shape[2]]
    for _ in range(levels):
        pyramid = cv2.GaussianBlur(pyramid, (kernel_size, kernel_size), 0)
        pyramid = pyramid[0::2, 0::2]
        rows.append(pyramid.shape[0])
        cols.append(pyramid.shape[1])
    pyramid = pyramid.transpose((2, 0, 1))

    # Three of every four samples are zero after zero insertion, so the
    # kernel is scaled by 4 to keep the overall brightness.
    kernel = 4 * creat_gauss_kernel(kernel_size, sigma=1)
    for level in range(levels, 0, -1):
        expanded = np.zeros((pyramid.shape[0], rows[level - 1], cols[level - 1]))
        # Place each coarse sample on the even grid positions.
        expanded[:, ::2, ::2] = pyramid
        filtered = cv2.filter2D(expanded.transpose((1, 2, 0)), -1, kernel)
        pyramid = filtered.transpose((2, 0, 1))
    return pyramid


def kmeans_clustering(image, num_clusters):
    """
    Cluster the per-pixel mean of the image with k-means.

    Args:
        image (numpy.ndarray): Input image; averaged along axis 0 before
            clustering.
        num_clusters (int): Number of clusters.

    Returns:
        numpy.ndarray: uint8 array of cluster labels with the shape of the
        axis-0 mean of ``image``.
    """
    band_average = np.mean(image, axis=0)
    # cv2.kmeans wants one float32 sample per row.
    samples = band_average.reshape((-1, 1)).astype(np.float32)

    stop_rule = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 2.0)
    attempts = 10
    result = cv2.kmeans(
        samples,
        num_clusters,
        None,
        stop_rule,
        attempts,
        cv2.KMEANS_RANDOM_CENTERS,
    )
    assignments = result[1]
    label_map = assignments.reshape(band_average.shape)
    return label_map.astype(np.uint8)


def seeds_superpixel(image, av_area):
    """
    Segment the image into superpixels with OpenCV's SEEDS algorithm.

    Args:
        image (numpy.ndarray): Input image, 3-D; its axes are reversed
            before being handed to OpenCV.
        av_area (float): Average patch area.

    Returns:
        numpy.ndarray: Superpixel labels, transposed back so their two
        axes follow the last two axes of ``image`` in the original order.
    """
    # The requested superpixel count is 2 ** D, with D derived from the
    # second image dimension and the average patch size.
    depth = int(np.log2(image.shape[1] / np.sqrt(av_area)))
    superpixel_count = 2 ** depth

    swapped = image.transpose((2, 1, 0))
    dim0, dim1, channels = swapped.shape
    # Positional arguments: width, height, channels, number of superpixels,
    # then the fixed settings 15, 7, 5, True used by this pipeline.
    segmenter = cv2.ximgproc.createSuperpixelSEEDS(
        dim1, dim0, channels, superpixel_count, 15, 7, 5, True
    )
    as_bytes = cv2.convertScaleAbs(swapped)
    segmenter.iterate(as_bytes, 10)

    labels = segmenter.getLabels()
    return labels.transpose((1, 0))


def boundary_refinement(initial_boundary, fine_boundary):
    """
    Assign every fine segment the majority label of the initial segmentation.

    For each region id in ``fine_boundary`` the most frequent value of
    ``initial_boundary`` inside that region is looked up (ties go to the
    smallest label) and written to all pixels of the region.

    The result is built in a new array. Relabelling ``fine_boundary`` in
    place would make already-written class labels indistinguishable from
    region ids that have not been processed yet, merging unrelated regions.

    Args:
        initial_boundary (numpy.ndarray): Initial boundary labels
            (non-negative integers).
        fine_boundary (numpy.ndarray): Fine boundary labels (non-negative
            integers), same shape as ``initial_boundary``. Not modified.

    Returns:
        numpy.ndarray: Refined boundary labels with the shape and dtype of
        ``fine_boundary``.

    Raises:
        ValueError: If the shapes differ or any label is negative.
    """
    initial = np.asarray(initial_boundary)
    fine = np.asarray(fine_boundary)
    if initial.shape != fine.shape:
        raise ValueError(
            "initial_boundary and fine_boundary must have the same shape, "
            "got {} and {}".format(initial.shape, fine.shape)
        )
    if fine.size == 0:
        return fine.copy()

    region_ids = fine.ravel().astype(np.int64)
    class_ids = initial.ravel().astype(np.int64)
    if region_ids.min() < 0 or class_ids.min() < 0:
        raise ValueError("boundary labels must be non-negative integers")

    num_regions = int(region_ids.max()) + 1
    num_classes = int(class_ids.max()) + 1

    # votes[r, c] = number of pixels of region r carrying initial label c.
    votes = np.bincount(
        region_ids * num_classes + class_ids,
        minlength=num_regions * num_classes,
    ).reshape(num_regions, num_classes)
    # Region ids that never occur get an all-zero row; they are never
    # looked up below, so their argmax is irrelevant.
    majority = votes.argmax(axis=1)

    refined = majority[region_ids].reshape(fine.shape)
    return refined.astype(fine.dtype)


def morphological_processing(image, kernel_size=7):
    """
    Clean a label image by opening then closing each label's mask.

    Labels are processed in increasing order; where cleaned masks overlap,
    the later (larger) label wins. Pixels covered by no cleaned mask keep
    the value 0.

    Args:
        image (numpy.ndarray): Input label image.
        kernel_size (int): Side length of the square structuring element.

    Returns:
        numpy.ndarray: Processed label image, same shape and dtype as
        ``image``.
    """
    structuring_element = np.ones((kernel_size, kernel_size), np.uint8)
    cleaned_labels = np.zeros_like(image)

    label_stop = np.max(image) + 1
    for label in range(label_stop):
        mask = (image == label).astype(np.uint8)
        # Opening first, then closing.
        for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
            mask = cv2.morphologyEx(mask, operation, structuring_element)
        cleaned_labels[mask == 1] = label
    return cleaned_labels


# Main workflow: run the whole segmentation pipeline on one raster when the
# file is executed as a script. Every name assigned below is a module-level
# global.
if __name__ == "__main__":
    # Input and output locations are hard-coded; edit them before running.
    input_path = r'E:\scene boundary identification\scale test\data\data1.tif'
    output_path = r'E:\scene boundary identification\scale test\results\data1_SBVP.tif'
    # Number of k-means classes for the initial (coarse) segmentation.
    num_clusters = 2

    # Read the image together with its projection and geotransform.
    # `row` and `column` are not used in the rest of this block.
    proj, geotrans, img, row, column = read_img(input_path)

    # Estimate the average patch area; it is passed to both the pyramid
    # step and the superpixel step below.
    av_area = calculate_average_patch_area(img)

    # Smooth the image through the Gaussian pyramid.
    img_downscaled = gaussian_downscale(img, av_area)

    # Contrast-stretch the smoothed image, truncating 2% at each end.
    img_stretched = truncated_linear_stretch(img_downscaled, 2)

    # Initial boundary segmentation: k-means on the stretched image.
    initial_boundary = kmeans_clustering(img_stretched, num_clusters)

    # Fine boundary segmentation: SEEDS superpixels on the original image.
    fine_boundary = seeds_superpixel(img, av_area)

    # Combine both results: relabel the fine segments using the initial labels.
    refined_boundary = boundary_refinement(initial_boundary, fine_boundary)

    # Save the result with the georeferencing of the input raster.
    write_img(output_path, proj, geotrans, refined_boundary)