# Dependencies

Use the extra UbuntuGIS repository to get GDAL 3.0 or higher, since Colab's native version 2.2.3 is too old for the pipeline.

If the GDAL version printed at the end is still 2.2.3 after installation, restart the runtime.

In [None]:
# Add the UbuntuGIS PPA, which provides GDAL >= 3.0 packages
!add-apt-repository -y ppa:ubuntugis/ubuntugis-unstable
# Pin the Python GDAL bindings to the 3.0.4 build for Ubuntu bionic
!apt install python3-gdal=3.0.4+dfsg-1~bionic0

# Report the GDAL version actually loaded (restart the runtime if it is still 2.2.3)
from osgeo import gdal; print(f"GDAL version {(gdal.__version__)}")

# Google Drive

Mount Google Drive to access SAR images (recommended for storing input and output images). If you're not working in Google Colab, you probably do not need this cell.

In [None]:
# Google Drive

import os

from google.colab import drive

# Attach Google Drive only once: re-running this cell must not remount it
if os.path.exists("/content/drive"):
    print("Google Drive is already mounted")
else:
    drive.mount("/content/drive")

# Functions

Helper functions cell, equivalent to the functions stored in the _utils.py_ file.

In [None]:
import os
import cv2 as cv
import numpy as np

from typing import List, Tuple, Union


def gdal_callback(completed: float, message: str, args: object) -> bool:
    """Print compact textual progress for long GDAL operations.

    Prints the percentage at every multiple of 10, a dot at the other even
    percentages, and ``100 - done.`` at completion. GDAL can invoke the
    callback many times with the same fraction, so a percentage is only acted
    upon when it differs from the one seen on the previous call; a new
    operation starting at a lower percentage is therefore reported normally.

    Args:
        completed: Fraction of work done, in ``[0, 1]``.
        message: Progress message supplied by GDAL (unused).
        args: User data passed through by GDAL (unused).

    Returns:
        Always ``True``; GDAL cancels the operation when a callback returns
        a false value.
    """
    progress = int(completed * 100)
    # Suppress repeated reports of the same integer percentage
    if progress == getattr(gdal_callback, '_last_progress', None):
        return True
    gdal_callback._last_progress = progress
    if progress == 100:
        print(progress, '- done.')
    elif progress % 10 == 0:
        print(progress, end='', flush=True)
    elif progress % 2 == 0:
        print('.', end='', flush=True)
    return True


def adjust_gamma(image: np.ndarray, gamma: float = 1.1,
                 pad: Union[int, Tuple[int, int]] = 1) -> np.ndarray:
    """Gamma-correct a grayscale image through a 256-entry lookup table.

    Input level ``i`` is first mapped linearly onto
    ``[pad_low, 255 - pad_high]``, then raised to ``1 / gamma`` on the 0..1
    scale, so ``gamma > 1`` brightens mid-tones.

    Args:
        image: 2-D image; must be 8-bit, as required by ``cv.LUT``.
        gamma: Gamma value; ``1.0`` applies only the padding remap.
        pad: Either a ``(pad_low, pad_high)`` pair, or a single int used as
            ``pad_low`` with ``pad_high = 0``.

    Returns:
        The corrected image as ``uint8``.
    """
    assert image.ndim == 2, "Input image must be grayscale!"
    # Check against the runtime ``tuple`` type: comparing ``type(pad)`` with the
    # typing alias ``Tuple[int, int]`` is never true.
    if isinstance(pad, (tuple, list)):
        pad_low, pad_high = pad
    else:
        pad_low, pad_high = pad, 0
    levels = np.linspace(pad_low, 255 - pad_high, 256) / 255
    lut = (levels ** (1 / gamma) * 255).round().astype(np.uint8)
    return cv.LUT(image, lut).astype(np.uint8)


def apply_kmeans(image: np.ndarray, num_clusters: int, cycles: int = 10,
                 iters: int = 10, eps: float = 0.9,
                 mask: int = 255) -> np.ndarray:
    """Cluster pixel intensities with k-means and repaint clusters as gray levels.

    Pixels equal to ``mask`` are replaced by 0 before clustering (they still
    take part, as zeros). Every cluster is painted with one of the evenly
    spaced levels ``0, 255/K, ..., 255*(K-1)/K`` chosen by the brightness rank
    of its center, so darker clusters stay darker; level 255 is never used.

    Args:
        image: 2-D intensity image.
        num_clusters: Number of clusters ``K``.
        cycles: Number of k-means attempts with different initial centers
            (OpenCV keeps the most compact result).
        iters: Maximum number of iterations per attempt.
        eps: Convergence threshold for the termination criteria.
        mask: Value treated as no-data and zeroed in the samples.

    Returns:
        ``uint8`` array shaped like ``image`` with one gray level per cluster.
    """
    assert image.ndim == 2, f"Image must be 2D, but {image.ndim}D is given!"
    # Samples are the float32 image with the mask color dropped to zero
    samples = (np.float32(image) *
               (image != mask).astype(np.uint8)).reshape(-1, 1)
    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, iters, eps)
    _, labels, centers = cv.kmeans(samples, num_clusters, None, criteria,
                                   cycles, cv.KMEANS_PP_CENTERS)
    # argsort lists cluster ids from darkest to brightest; a second argsort
    # inverts that permutation into each cluster's brightness rank, which is
    # what the palette must be indexed by.
    ranks = centers.ravel().argsort().argsort()
    palette = np.linspace(0, 255, centers.shape[0] + 1)[ranks] \
        .round().astype(np.uint8)
    return palette[labels.ravel()].reshape(image.shape)


def extend_mask(image: np.ndarray, n: int = -5, color: int = 255) -> np.ndarray:
    """Repaint one end of the 8-bit intensity range with ``color``.

    A positive ``n`` repaints the levels below ``n`` (the darkest ones), a
    negative ``n`` repaints the brightest ``-n`` levels, and ``n == 0`` keeps
    the identity mapping.

    Args:
        image: Image accepted by ``cv.LUT`` (8-bit).
        n: Number of levels to repaint; its sign selects the end of the range.
        color: Replacement level written into the lookup table.

    Returns:
        The remapped image as ``uint8``.
    """
    if n > 0:
        repainted = slice(None, n)
    elif n < 0:
        repainted = slice(n, None)
    else:
        repainted = slice(0, 0)
    table = np.arange(256, dtype=np.uint8)
    table[repainted] = color
    return cv.LUT(image, table).astype(np.uint8)

# Processing

Main processing loop. Change the following variables for yourself:
* `PATH_INPUT` - where input SAR images are stored; the directory layout can be arbitrary as long as the `glob` pattern below matches it;
* `PATH_TEMP` - where to store temporary datasets; locations under _/content_ are recommended when using Google Colab;
* `PATH_OUTPUT` - where output files are written (subdirectories are created automatically); Google Drive is recommended. Existing files in the output directory will be overwritten;
* `GDAL_MAX_RAM` (optional) - cache size limit when accessing GDAL datasets as virtual memory (default is 10 MiB);
* `GDAL_TILE_SIZE` (optional) - tile size used when processing tile-by-tile, e.g. for bilateral filtering or gamma correction (values above 3072 are capped); clustering is done over the whole image.

Also change the `glob` pattern for the `files` variable to match your directory structure.

In [None]:
%%time

import os
import numpy as np
import matplotlib.pyplot as plt

from glob import glob
from matplotlib import cm
from osgeo import ogr, gdal, gdalconst


PATH_TEMP = '/content/temp'
PATH_INPUT = '/content/drive/My Drive/Colab Notebooks/SAR Processing/input'
PATH_OUTPUT = '/content/drive/My Drive/Colab Notebooks/SAR Processing/output'
FILE_SHAPEFILE = PATH_INPUT + '/Start_Ice_Map_UTMz40WGS84f_r.shp'

GDAL_MAX_RAM = 600 * 1024 * 1024  # 600 MiB
GDAL_TILE_SIZE = 2048  # clamped to 3072 below (Google Colab)
NUM_CLUSTERS = 7  # K-means
NUM_CHANNELS = 3  # RGB output
GAMMA = 1.0  # 1.0 - unchanged

# Input must exist; the output directory is created on demand
if not os.path.isdir(PATH_INPUT):
    raise FileNotFoundError(f"Path '{PATH_INPUT}' must exist!")
os.makedirs(PATH_OUTPUT, exist_ok=True)

if not os.path.isfile(FILE_SHAPEFILE):
    raise FileNotFoundError(f"Shapefile '{FILE_SHAPEFILE}' must exist!")

files = glob(os.path.join(PATH_INPUT, 'HH', '*.tif'))
print("Source files -->\n")
print('\n'.join(files))
print(f"Source shape is {FILE_SHAPEFILE}")
# !gdalinfo "{files[0]}"

# Existence was verified above, so only resolve the canonical absolute path
shape = os.path.realpath(FILE_SHAPEFILE)
print(f"Available shape is {shape}")
# !ogrinfo "{shape}"

assert int(gdal.__version__.split('.')[0]) >= 3, "Required GDAL version >=3.0!"
gdal.UseExceptions()

# Clamp the tile size once instead of re-checking it for every file
if type(GDAL_TILE_SIZE) is not int or GDAL_TILE_SIZE > 3072:
    GDAL_TILE_SIZE = 3072

# Matplotlib colormap sampled at 256 levels -> (256, 4) uint8 RGBA lookup table
colormap = (cm.terrain(range(256)) * 255).round().astype(np.uint8)
# Color interpretation for output bands 1..NUM_CHANNELS (NUM_CHANNELS <= 4)
band_colors = [gdal.GCI_RedBand, gdal.GCI_GreenBand,
               gdal.GCI_BlueBand, gdal.GCI_AlphaBand]

print("\nProcessing files -->")
for filename in files:
    # Input images are assumed to be grayscale: only band 1 is processed
    print(f"\nInput file is {filename}")
    output, _ = os.path.splitext(filename.replace(PATH_INPUT, PATH_OUTPUT))
    print(f"Output is {output}")
    os.makedirs(os.path.dirname(output), exist_ok=True)
    temp = filename.replace(PATH_INPUT, PATH_TEMP)
    print(f"Temporary filename is {temp}")
    os.makedirs(os.path.dirname(temp), exist_ok=True)

    # --- Build an in-memory NUM_CHANNELS-band copy of the source ---
    dataset = gdal.Open(filename, gdal.GA_ReadOnly)
    if dataset.RasterCount < 1:
        print(f"ERROR: dataset {filename} has no rasters! Skipping...")
        del dataset
        continue
    if dataset.RasterCount < NUM_CHANNELS:
        tempset = gdal.GetDriverByName('MEM').CreateCopy('', dataset, 0)
        # Fill the missing bands with copies of band 1
        layer = tempset.GetRasterBand(1).ReadAsArray()
        for i in range(tempset.RasterCount + 1, NUM_CHANNELS + 1):
            tempset.AddBand()
            tempset.GetRasterBand(i).WriteArray(layer)
        del layer
    else:
        # Keep only the first NUM_CHANNELS bands so that every input, not just
        # those with fewer bands, produces the temporary GeoTIFF opened below
        tempset = gdal.Translate('', dataset, format='MEM',
                                 bandList=list(range(1, NUM_CHANNELS + 1)))
    del dataset
    for i in range(1, NUM_CHANNELS + 1):
        tempset.GetRasterBand(i).SetColorInterpretation(band_colors[i - 1])

    # --- Reproject to EPSG:32640 at 40x40 resolution, cropped to the shapefile ---
    # Start from a fresh file so a leftover `temp` from a previous run cannot
    # interfere with gdal.Warp creating its destination
    if os.path.exists(temp):
        os.remove(temp)
    options = gdal.WarpOptions(format='GTiff', dstSRS='EPSG:32640',
                               srcNodata=255, dstNodata=255,
                               xRes=40, yRes=40, cutlineDSName=shape,
                               cropToCutline=bool(shape),
                               callback=gdal_callback)
    print(f"Warping source file into {temp}...")
    gdal.Warp(temp, tempset, options=options)
    del tempset

    # --- Process the temporary RGB GeoTIFF in place ---
    dataset = gdal.Open(temp, gdal.GA_Update)
    try:
        if dataset.RasterCount < 1:
            raise AttributeError("Temp dataset has no rasters!")
        print(f"GeoTIFF rasters = {dataset.RasterCount}")
        print(f"Dataset raster size = ({dataset.RasterYSize}, {dataset.RasterXSize})")

        # Tile-by-tile preprocessing through a writable tiled virtual memory view
        tiles = dataset.GetTiledVirtualMemArray(eAccess=gdalconst.GF_Write,
                                                tilexsize=GDAL_TILE_SIZE,
                                                tileysize=GDAL_TILE_SIZE,
                                                cache_size=GDAL_MAX_RAM,
                                                tile_organization=gdalconst.GTO_TIP)
        try:
            print(f"Tiles array shape = {tiles.shape}",
                  "(tilesY, tilesX, Y, X, channels)")
            # Kernels: rect, cross, ellipse (morphology)
            # kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (9, 9))
            print("Preprocessing...")
            for i in range(tiles.shape[0]):
                for j in range(tiles.shape[1]):
                    # Treat the brightest 32 levels as mask (255), then smooth
                    filtered = extend_mask(tiles[i, j, ..., 0], -32)
                    if GAMMA != 1.0:
                        filtered = adjust_gamma(filtered, gamma=GAMMA, pad=0)
                    filtered = cv.bilateralFilter(filtered, 31, 63, 63)
                    # _, filtered = cv.threshold(filtered, 1, 255, cv.THRESH_BINARY)
                    # filtered = cv.morphologyEx(filtered, cv.MORPH_CLOSE, kernel)
                    # Broadcast the filtered band into every channel of the tile
                    tiles[i, j, ...] = filtered[..., None]
                    print('.', end='', flush=True)
                print()
        finally:
            del tiles

        # Whole-image clustering (k-means needs all pixels at once)
        image = dataset.GetVirtualMemArray(eAccess=gdalconst.GF_Write,
                                           cache_size=GDAL_MAX_RAM,
                                           band_sequential=False)
        try:
            print(f"Clustering {image.shape} image...")
            # Copy band 0: a view would keep the virtual memory mapping alive
            # after `del image`, so modified pages might not be written back
            # before the dataset is translated below.
            correction = np.array(image[..., 0])
            plt.subplots(1, 1, figsize=(25, 10))[1].imshow(correction, cmap='gray')
            histograms = plt.subplots(1, 1, figsize=(25, 10))[1]
            # Matplotlib takes a bin count and a (low, high) range; the
            # calcHist-style `[127]` list would be read as a single bin edge
            histograms.hist(correction.ravel(), bins=127, range=(1, 255))
            histograms.set_xlim([0, 127])
            clustered = apply_kmeans(correction, NUM_CLUSTERS, eps=0.95)
            print("Done!")

            # Colorize clusters; pixels that were 255 become 255 in every channel
            nodata = (correction == 255).astype(np.uint8) * 255
            clustered = np.stack([cv.LUT(clustered, colormap[..., c]) | nodata
                                  for c in range(NUM_CHANNELS)], axis=2)
            image[...] = clustered
            plt.subplots(1, 1, figsize=(25, 10))[1].imshow(clustered)
        finally:
            del image

        # Save temporary dataset to destination (output raster file)
        destination = os.path.join(output, 'image')
        os.makedirs(destination, exist_ok=True)
        destination = os.path.join(destination, os.path.basename(temp))
        print(f"Saving image to {destination}...")
        gdal.Translate(destination, dataset)

        # Vectorize clusters (create output shapefile set for clusters)
        # WARNING: shapefile may be quite large
        destination = os.path.join(output, 'shape')
        os.makedirs(destination, exist_ok=True)
        destination = os.path.join(destination, os.path.basename(temp))
        destination = os.path.splitext(destination)[0] + '.shp'

        vectorset = ogr.GetDriverByName('ESRI Shapefile').CreateDataSource(destination)
        band_source = band_mask = layer = None
        try:
            band_source = dataset.GetRasterBand(1)
            band_mask = band_source.GetMaskBand()
            layer = vectorset.CreateLayer('out', geom_type=ogr.wkbPolygon,
                                          srs=dataset.GetSpatialRef())
            layer.CreateField(ogr.FieldDefn('DN', ogr.OFTInteger))
            print(f"Saving shape to {destination}...")
            # Band 1 pixel values are written into field index 0 ('DN')
            gdal.Polygonize(band_source, band_mask, layer, 0, [],
                            callback=gdal_callback)
        finally:
            # Drop layer/band references before closing the data source
            band_source = band_mask = layer = None
            vectorset = None

        # Calculate final histogram (optional) - shall be less than
        # NUM_CLUSTERS peaks for each channel; `clustered` was built from an
        # RGBA colormap, so its channel order is R, G, B
        print("Building histograms...")
        histograms = plt.subplots(1, 1, figsize=(25, 10))[1]
        for i, color in enumerate(['r', 'g', 'b']):
            histogram = cv.calcHist([clustered], [i], None, [127], [1, 255])
            histograms.plot(histogram, color=color)
        histograms.set_xlim([0, 127])
        plt.show()
    finally:
        del dataset

In [None]:
# !7z a -mx=9 -ms=on "/content/{os.path.basename(output)}.7z" "{output}"
# !cp /content/*.7z "{output}"

In [None]:
# !ogrinfo "{os.path.splitext(destination)[0] + '.shp'}"

In [None]:
# !uname -a
# !getconf PAGE_SIZE
# !df -h /dev/shm
# !free -h