To investigate: how to zoom out.

In [1]:
# %load_ext autoreload

In [1]:
# merge the inputs
# %autoreload
import os
import sys

import numpy as np

sys.path.append("..")

from costum_arild.source import data_processing
from costum_arild.source.data_processing import TrainingImage
from costum_arild.source.utils import image_processing_utils

In [None]:
def divide_image_add_channel(image_filepath, label_filepath, image_size=512, do_overlap=False, do_crop=False, add_filter_channel=False):
    """Cut a georeferenced image/label pair into square grayscale tiles with two extra filter channels.

    The image raster is read band by band with GDAL and converted to a single
    grayscale channel with PIL. It is then cut into ``image_size`` x
    ``image_size`` tiles. Each tile is stacked with
    ``image_processing_utils.qshitf_boundary(tile, ratio=0.9)`` and
    ``image_processing_utils.laplacian_filter(tile, sigma=5)`` along the last
    axis, giving one 3-channel array per tile.

    Args:
        image_filepath: path of the image raster (one or more bands).
        label_filepath: path of the label raster. Only its first band is used,
            and its geo transform must equal the image's.
        image_size: side length of each tile in pixels.
        do_overlap: if True, tiles start every ``image_size // 4`` pixels,
            beginning at ``image_size // 4``. The last four offsets per axis
            are dropped so every tile fits inside the image. If False, tiles
            are laid edge to edge and the last tile on each axis is shifted
            back so it ends exactly at the image border.
        do_crop: if True, the label window is trimmed by ``image_size // 4``
            pixels on every side, so only the centre of each tile is labelled.
        add_filter_channel: currently ignored; the two filter channels are
            always added.

    Returns:
        A list of ``TrainingImage``, one per tile whose labels pass
        ``data_processing.is_quality_image``.

    Raises:
        OSError: if GDAL cannot open either raster.
        Exception: if the image and label geo transforms differ.
        ValueError: if ``image_size`` is not positive, if the image has no
            bands, if the label raster is smaller than the image, or (without
            overlap) if the image is smaller than ``image_size`` along an axis.
    """
    # NOTE(review): `gdal` is not imported in the visible cells; this relies on it being
    # defined in the notebook's global scope (presumably `from osgeo import gdal`) - confirm.
    if image_size <= 0:
        raise ValueError(f"image_size must be positive, got {image_size}")

    image_matrix, geo_transform, projection = _read_grayscale_image(image_filepath)

    # Load label
    label_ds = _open_raster(label_filepath)
    if label_ds.GetGeoTransform() != geo_transform:
        raise Exception(f"The geo transforms of image {image_filepath} and label {label_filepath} did not match")
    label_matrix = label_ds.GetRasterBand(1).ReadAsArray()
    label_ds = None  # release the GDAL dataset
    # With identical geo transforms, pixel (i, j) lines up in both rasters. A smaller label
    # raster would silently produce truncated label windows near the border.
    if label_matrix.shape[0] < image_matrix.shape[0] or label_matrix.shape[1] < image_matrix.shape[1]:
        raise ValueError(f"The label {label_filepath} with shape {label_matrix.shape} is smaller than "
                         f"the image {image_filepath} with shape {image_matrix.shape}")

    shape_0_indices = _tile_offsets(image_matrix.shape[0], image_size, do_overlap)
    shape_1_indices = _tile_offsets(image_matrix.shape[1], image_size, do_overlap)
    margin = image_size // 4
    base_name = _tile_base_name(image_filepath)

    training_data = []
    for shape_0 in shape_0_indices:
        for shape_1 in shape_1_indices:
            if do_crop:
                # Label only the centre of the tile: trim `margin` pixels on every side.
                label_0, label_1 = shape_0 + margin, shape_1 + margin
                labels = label_matrix[label_0:shape_0 + image_size - margin,
                                      label_1:shape_1 + image_size - margin]
            else:
                label_0, label_1 = shape_0, shape_1
                labels = label_matrix[shape_0:shape_0 + image_size, shape_1:shape_1 + image_size]
            # Skip tiles with too much unknown
            if not data_processing.is_quality_image(labels):
                continue

            data = image_matrix[shape_0:shape_0 + image_size, shape_1:shape_1 + image_size]
            filter_channel = image_processing_utils.qshitf_boundary(data, ratio=0.9)
            cluster_channel = image_processing_utils.laplacian_filter(data, sigma=5)
            data_3d = np.stack([data, filter_channel, cluster_channel], axis=-1)

            name = f"{base_name}_n_{shape_0}_e_{shape_1}"
            training_data.append(TrainingImage(data_3d, labels, _offset_geo_transform(geo_transform, shape_0, shape_1),
                                               name=name, projection=projection,
                                               label_geo_transform=_offset_geo_transform(geo_transform, label_0, label_1),
                                               east_offset=shape_1, north_offset=shape_0))
    return training_data


def _open_raster(filepath):
    """Open `filepath` with GDAL, raising OSError instead of returning None on failure."""
    dataset = gdal.Open(filepath)
    if dataset is None:
        raise OSError(f"GDAL could not open {filepath}")
    return dataset


def _read_grayscale_image(image_filepath):
    """Read all bands of a raster; return (2-D grayscale uint8 array, geo transform, projection)."""
    from PIL import Image

    image_ds = _open_raster(image_filepath)
    geo_transform = image_ds.GetGeoTransform()
    projection = image_ds.GetProjection()
    # Stacked as (bands, rows, cols).
    bands = np.array([image_ds.GetRasterBand(band_idx + 1).ReadAsArray() for band_idx in range(image_ds.RasterCount)])
    image_ds = None  # release the GDAL dataset
    if bands.ndim != 3 or bands.shape[0] == 0:
        raise ValueError(f"The image {image_filepath} has no raster bands")

    if bands.shape[0] == 1:
        # A single band is already 2-D; PIL.Image.fromarray rejects (rows, cols, 1) arrays.
        image_matrix = bands[0]
    else:
        # PIL expects channels last: (rows, cols, bands).
        image_matrix = np.transpose(bands, axes=[1, 2, 0])
    # Collapse to one gray channel (PIL mode 'L').
    return np.array(Image.fromarray(image_matrix).convert('L')), geo_transform, projection


def _tile_offsets(length, image_size, do_overlap):
    """Return the start offsets of the tiles along one image axis of `length` pixels."""
    if do_overlap:
        step = image_size // 4
        # Start at one step in; dropping the last four offsets keeps every tile in bounds.
        return list(range(step, length, step))[:-4]
    if length < image_size:
        raise ValueError(f"Image axis of length {length} is smaller than the tile size {image_size}")
    offsets = list(range(0, length, image_size))
    # Make sure the whole axis is covered: shift the last tile back so it ends at the border.
    offsets[-1] = length - image_size
    return offsets


def _offset_geo_transform(geo_transform, row, col):
    """Return a copy of a GDAL geo transform whose origin is moved to pixel (row, col)."""
    shifted = list(geo_transform)
    shifted[0] += col * geo_transform[1]  # East
    shifted[3] += row * geo_transform[5]  # North
    return shifted


def _tile_base_name(image_filepath):
    """File name of `image_filepath` without its directory and without a trailing .tif/.tiff extension."""
    base_name = os.path.basename(image_filepath)
    stem, extension = os.path.splitext(base_name)
    return stem if extension.lower() in (".tif", ".tiff") else base_name