# Dependencies

Use the extra UbuntuGIS repository to get _GDAL_ version 3.0.4 or higher, since _Colab_'s native version 2.2.x is too old for the pipeline.

> If the reported _GDAL_ version is still 2.2.x after installation, restart the runtime.
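The version requirement above can also be checked programmatically. Below is a minimal sketch that compares a dotted version string (such as the one printed from `gdal.__version__`) with the 3.0.4 minimum; the helper names `parse_version` and `is_recent_enough` are made up for illustration and are not part of _GDAL_.

```python
from itertools import takewhile

GDAL_MINIMUM = (3, 0, 4)  # minimum version named in the text above


def parse_version(text):
    """Turn a dotted version string such as '3.0.4' into a tuple of ints."""
    parts = []
    for chunk in text.split('.')[:3]:
        # Keep only the leading digits of each chunk ('4+dfsg' -> '4')
        digits = ''.join(takewhile(str.isdigit, chunk))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def is_recent_enough(version, minimum=GDAL_MINIMUM):
    """Hypothetical helper: True if `version` is at least `minimum`."""
    return parse_version(version) >= minimum


print(is_recent_enough('2.2.3'))  # Colab's native GDAL series
print(is_recent_enough('3.0.4'))  # version installed by this notebook
```

If the check fails right after installation, the runtime most likely still has the old bindings loaded, hence the restart advice.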

In [None]:
# Check container OS version (for correct UbuntuGIS package version)
!lsb_release -a

## GDAL

> If _GDAL 3.0.4_ (library and _Python_ bindings) or above is already installed in the system, then just skip or comment out the cell below, 'cause it's intended for _Google Colab_, which has _GDAL 2.2.x_ only.

In [None]:
# Dark magic happens here: installing dependencies for GDAL 3.0.4
# build process via APT and install GDAL itself via PyPI
!add-apt-repository -y ppa:ubuntugis/ubuntugis-unstable
!apt install python3-gdal=3.0.4+dfsg-1~bionic0
!apt purge --autoremove python3-gdal
!pip install gdal==3.0.4
!apt install gdal-bin=3.0.4+dfsg-1~bionic0

from osgeo import gdal; print(f"GDAL version {(gdal.__version__)}")

## Segmentation models

![Segmentation models](https://camo.githubusercontent.com/51eea85ed59f27be0485cc5774d09b522ea8e77cd3f0753c085cacd18d4a41a0/68747470733a2f2f692e6962622e636f2f4774784753386d2f5365676d656e746174696f6e2d4d6f64656c732d56312d536964652d332d312e706e67)

The _PyTorch_ variant of _Segmentation Models_ is used in this pipeline. _Catalyst_, _PyTorch Lightning_ or any other deep learning framework may also be used, but this example sticks to _Segmentation Models_.


In [None]:
# %pip install git+https://github.com/qubvel/segmentation_models.pytorch.git@v0.2.0
%pip install segmentation-models-pytorch==0.2.0

# Common

Common part is a prerequisite for all further parts (data loading, training, inference).

## Google Drive

Mount _Google Drive_ for SAR images (recommended for storing input and output images).

> If _Google Colab_ is not used, then this cell may be commented out.

In [None]:
from os import path as osp

from google.colab import drive


PATH_DRIVE = osp.join('/', 'content', 'drive')

# Do not mount if it is already attached
if not osp.exists(PATH_DRIVE):
    print("Mounting Google Drive...")
    drive.mount(PATH_DRIVE)
else:
    print("Google Drive has been already mounted!")

## Paths

Paths to be used by preprocessing steps. Use `PATH_STORAGE` as a subdirectory hierarchy to store processed GeoTIFFs/Shapefiles right in _Google Drive_ (an empty string saves into the `input`/`output` folders in the root of _Google Drive_). `PATH_STORAGE` is used with the `PATH_DRIVE` variable only.

`PATH_TEMP` is used to store intermediate GeoTIFFs while processing.

`PATH_INPUT` is used as a source of GeoTIFFs (differs from the _dataset_ directory in that it has **no masks**).

`PATH_OUTPUT` is used to save images (for example, inference on a test subset).

`PATH_MODELS` is used to save model weights for later inference.

`PATH_DATASET` is a source of GeoTIFF **images** and **masks** for training and inference (for example, test subset).

`PATH_RESOURCES` is used as a source of auxiliary files such as GeoJSON search area or Shapefile cutline.
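To see how `PATH_STORAGE` affects the resulting locations, here is a small sketch that mirrors the path composition of the next cell. It uses `posixpath` so the result is the same on any host OS (_Colab_ itself runs Linux); the function `build_paths` and its `mounted` flag are hypothetical and only illustrate the idea.

```python
import posixpath


def build_paths(path_storage, mounted=True):
    """Hypothetical helper mirroring the path composition described above."""
    # Prefix inside /content: drive/MyDrive/<PATH_STORAGE> when Drive is mounted
    if mounted:
        prefix = posixpath.join('drive', 'MyDrive', path_storage)
    else:
        prefix = ''
    root = posixpath.join('/', 'content')
    names = ('input', 'output', 'models', 'dataset')
    return {name: posixpath.join(root, prefix, name) for name in names}


print(build_paths('ods/soc')['input'])         # nested storage subpath
print(build_paths('')['input'])                # empty string: Drive root
print(build_paths('', mounted=False)['input']) # no Drive at all
```

An empty `path_storage` collapses cleanly because joining with an empty component only adds a separator, so `input` and `output` end up directly under `MyDrive`.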

In [None]:
from os import path as osp


PATH_STORAGE = osp.join('ods', 'soc')  # arbitrary subpath in Google Drive (if any)
if 'PATH_DRIVE' in locals():
    PREFIX_DRIVE = osp.join(osp.basename(PATH_DRIVE), 'MyDrive', PATH_STORAGE)
else:
    PREFIX_DRIVE = ''

PATH_TEMP = osp.join('/', 'content', 'temp')
PATH_INPUT = osp.join('/', 'content', PREFIX_DRIVE, 'input')
PATH_OUTPUT = osp.join('/', 'content', PREFIX_DRIVE, 'output')
PATH_MODELS = osp.join('/', 'content', PREFIX_DRIVE, 'models')
PATH_DATASET = osp.join('/', 'content', PREFIX_DRIVE, 'dataset')
PATH_RESOURCES = osp.join('/', 'content', 'resources')

# FILE_SHAPEFILE = osp.join(PATH_RESOURCES, 'clustering', 'cutline',
#                           'Start_Ice_Map_UTMz40WGS84f_r.shp')

print('\n'.join((PATH_STORAGE, PATH_TEMP, PATH_INPUT, PATH_OUTPUT, PATH_MODELS,
                 PATH_DATASET, PATH_RESOURCES)))

## Functions

Auxiliary functions (for example, drawing/plotting data).

In [None]:
from matplotlib import pyplot as plt


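def draw_one_row(*images, size=1024):
    # Accept either a single integer or a (width, height) pair in pixels
    try:
        size = size[:2]
    except TypeError:
        size = (size, size)
    count = len(images)
    # squeeze=False keeps axes two-dimensional even for a single image
    figure, axes = plt.subplots(1, count, dpi=72, squeeze=False,
                                figsize=(size[0] / 72, size[1] / 72))
    for i in range(count):
        axes[0][i].imshow(images[i])
    plt.show()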

# Data Loading


## Dataset class

The most interesting part of the pipeline, 'cause it's the interface between the data and the neural network models. Here is an example of a dataset class (`DatasetSAR`) which loads a pair of HH+HV images and their corresponding masks from different directories, stacks them together, and applies geometric transformations and color augmentations (the general term _augmentations_ is split here into geometric and color parts).

Parameters for the `DatasetSAR` class constructor:

* `paths_images` is a list of strings (or a single string if only one polarization is used) where each string points to an existing directory with HH or HV polarized images;
* `paths_masks` is a list of strings / single string that point(s) to where masks are located (only the first directory from the list is used);
* `classes` is a list of classes used, which is important for multiclass masks (where they should be **expanded** into a one-hot encoded tensor);
* `items` is a list of filenames to be used as the dataset items (WARNING: the passed list of filenames is accepted as-is with no checks), which is useful for creating subsets (for example, train/valid) from some superset (list of all items of the dataset);
* `transformations` is a [torchvision transforms](https://pytorch.org/vision/stable/transforms.html) class (or any other compatible callable class) **without** things like `ToTensor` (it is added separately to the images only) that is applied to both images and masks;
* `augmentations` is a callable like `transformations`, but is applied to images only;
* `size` is an integer or a 2-tuple of integers for the final image/mask output size (for example, 1024x1024 as the input image for Unet with an EfficientNet-B2 encoder);
* `expand` is a boolean flag whose true value enables expanding a mask into a one-hot encoded tensor (required for some loss functions and their modes; default is `False`, do not expand);
* `flat` is a boolean flag whose true value produces a squeezed mask (`(H, W)` instead of `(C, H, W)`, for example, for a Focal Loss target); it takes precedence over the `expand` flag (if `flat` is set, the mask is never `expand`ed).
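The three mask output formats controlled by `expand` and `flat` can be illustrated without any dependencies. The sketch below uses nested lists instead of tensors (the real class works with _PyTorch_ tensors and `torch.nn.functional.one_hot`); the helpers `one_hot`, `format_mask` and `shape` are hypothetical names used only here.

```python
def one_hot(mask, num_classes):
    """Expand an (H, W) index mask into a (C, H, W) one-hot nested list."""
    return [[[1 if value == c else 0 for value in row] for row in mask]
            for c in range(num_classes)]


def format_mask(mask, num_classes, expand=False, flat=False):
    """Hypothetical helper reproducing the flag precedence described above."""
    if flat:
        return mask                        # (H, W), flat wins over expand
    if expand:
        return one_hot(mask, num_classes)  # (C, H, W)
    return [mask]                          # (1, H, W)


def shape(nested):
    """Shape of a regular nested list, similar to a tensor shape."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


mask = [[0, 1], [2, 1]]  # tiny 2x2 mask with classes 0, 1 and 2
print(shape(format_mask(mask, 3, flat=True, expand=True)))
print(shape(format_mask(mask, 3, expand=True)))
print(shape(format_mask(mask, 3)))
```

The first call shows the precedence rule: with both flags set, the mask stays two-dimensional.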

> DEBUG: commented line with `np.clip` is for testing binary segmentation.

> TODO: add loading and geometric transformations for RGB images.

In [None]:
from collections import OrderedDict
from glob import glob

import cv2 as cv
import numpy as np
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Lambda, RandomResizedCrop, Resize, ToTensor
from torchvision.transforms import InterpolationMode


class DatasetSAR(Dataset):
    def __init__(self, paths_images, paths_masks, classes, items=None,
                 transformations=None, augmentations=None, size=None,
                 expand=False, flat=False):
        # Assert images paths
        if isinstance(paths_images, str):
            paths_images = (paths_images, paths_images)
        elif not isinstance(paths_images, (tuple, list, set)):
            raise TypeError("first argument must be of type str or list!")
        else:
            paths_images = tuple(paths_images)
        # Assert masks paths
        if isinstance(paths_masks, str):
            paths_masks = (paths_masks,)
        elif not isinstance(paths_masks, (tuple, list, set)):
            raise TypeError("second argument must be of type str or list!")
        else:
            paths_masks = tuple(paths_masks)
        # Default output image/mask size
        if size is None:
            size = (1024, 1024)
        # Default transformations (geometric - image + mask)
        mode_interpolation = InterpolationMode.NEAREST
        if not isinstance(transformations, (Compose, torch.nn.Module)):
            self.transformations = Compose([
                Resize(size, mode_interpolation),
            ])
        else:
            self.transformations = transformations
        # Default augmentations (color - image only)
        if not isinstance(augmentations, (Compose, torch.nn.Module)):
            self.augmentations = Compose([
                Lambda(lambda x: x),
            ])
        else:
            self.augmentations = augmentations
        # Class attributes
        self.to_tensor = ToTensor()
        self.paths_images = paths_images
        self.paths_masks = paths_masks
        if items is None:
            items = []
            for path in paths_images[:2] + paths_masks[:1]:
                items.append({osp.basename(item) for item \
                            in glob(osp.join(path, '*.tiff'))})
            self.items = sorted(items[0].intersection(*items))
        else:
            self.items = items
        self.classes = len(classes)
        self.items_class = OrderedDict({c: i for c, i \
                                        in zip(classes,
                                               range(1, self.classes + 1))})
        self.expand = expand
        self.flat = flat
        return None

    def __getitem__(self, item):
        # TODO: load RGBs
        image_hh = cv.imread(osp.join(self.paths_images[0], self.items[item]),
                             cv.IMREAD_LOAD_GDAL | cv.IMREAD_GRAYSCALE)
        h, w = image_hh.shape[:2]
        image_hv = cv.imread(osp.join(self.paths_images[1], self.items[item]),
                             cv.IMREAD_LOAD_GDAL | cv.IMREAD_GRAYSCALE)
        h, w = tuple(map(min, zip(image_hv.shape[:2], (h, w))))
        if self.paths_masks:
            mask = cv.imread(osp.join(self.paths_masks[0], self.items[item]),
                            cv.IMREAD_LOAD_GDAL | cv.IMREAD_GRAYSCALE)
        else:
            mask = image_hh
        h, w = tuple(map(min, zip(mask.shape[:2], (h, w))))
        # Use minimum height and width 'cause image/mask dimension may not
        # always fit (difference during mask conversion)
        image = Image.fromarray(np.dstack(
            [image_hh[:h, :w], image_hv[:h, :w], mask[:h, :w]]
            # [image_hh[:h, :w], image_hv[:h, :w], np.clip(mask[:h, :w], 0, 1)]
        ))
        del image_hh, image_hv, mask
        image = self.transformations(image)
        image = np.array(image)
        image, mask = image[..., :2], image[..., 2]
        image = self.to_tensor(np.dstack((image[..., 1], image)))
        image = self.augmentations(image)
        # Convert to int64 'cause OHE requires index tensor
        if self.flat:
            mask = torch.tensor(mask)
        elif self.expand:
            mask = torch.nn.functional.one_hot(torch.tensor(mask,
                                                            dtype=torch.int64),
                                               self.classes).to(torch.int8)\
                                               .permute(2, 0, 1)
        else:
            mask = torch.tensor(mask).unsqueeze_(0).to(torch.int64)
        return image, mask

    def __len__(self):
        return len(self.items)

    def debug(self):
        # Debug function to show some dataset stats
        print(f"DEBUG: paths images = {self.paths_images}")
        print(f"DEBUG: paths masks = {self.paths_masks}")
        print(f"DEBUG: items ({len(self.items)}):")
        print('\n'.join(self.items))
        print(f"DEBUG: classes ({self.classes}):")
        print(self.items_class)

## Datasets and visualization

This part produces datasets to be used in data loaders for training and validation. Here are three datasets: `dataset` — grand superset with all annotated images; `dataset_train` — training subset of grand dataset to be used during training (calculating gradients); `dataset_valid` — validation subset of grand dataset to be used during validation (calculating metrics).

> Training/validation items (`items_train`/`items_valid`) are selected evenly across the grand dataset items (feel free to split it in any other way).
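The even split can be tried in isolation. This is a minimal sketch of the same index rule as in the cell below (roughly 15% for validation, so every 7th item except the very first one); the function name `split_items` and the `scene_XXX.tiff` filenames are made up for the example.

```python
def split_items(items, percent_valid=15):
    """Hypothetical helper: spread validation items evenly over the list."""
    step = int(round(100 / percent_valid))  # 100 / 15 rounds to 7
    # Item 0 always goes to training; every step-th item after it to validation
    train = [item for i, item in enumerate(items) if i == 0 or i % step]
    valid = [item for i, item in enumerate(items) if i and i % step == 0]
    return train, valid


# Placeholder filenames, not the real dataset items
items = [f"scene_{i:03d}.tiff" for i in range(124)]
train, valid = split_items(items)
print(len(train), len(valid))
print(valid[:3])
```

Both subsets are disjoint and together cover the whole list, so no annotated image is lost or used twice.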

Transformations are also in two variants (train/valid): **random crop and resize** for training in order to learn features from different scales and simple **resize** for validation (for simplicity and closer approach to later inference scenarios).

> All transformations assume the output shape of images/masks to be 1024x1024 (set by the `size` variable).

Images and masks from train/valid dataset classes are visualized in the last two loops. For _MaritimeAI Sentinel-1 Dataset 1920_, 124 images and their corresponding masks are displayed (neither too many nor too few).

Variables `EXPAND` and `FLAT` are used for top-level control of the mask output format. For example, for Dice loss from the _Segmentation Models_ _utils_ module set `EXPAND=True` and `FLAT=False`, for Dice loss from the _Segmentation Models_ _losses_ module set both variables to `False`, for Focal loss from the _losses_ module set `FLAT=True`.
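The combinations listed above can be kept in a small lookup table so the flags are not set by hand. This is only a sketch: the dictionary `MASK_FLAGS`, its string keys and the function `mask_flags` are hypothetical and not part of _Segmentation Models_.

```python
# Hypothetical lookup of mask format flags per loss function
MASK_FLAGS = {
    'utils.losses.DiceLoss': {'expand': True, 'flat': False},
    'losses.DiceLoss': {'expand': False, 'flat': False},
    # flat takes precedence, so the expand value does not matter here
    'losses.FocalLoss': {'expand': False, 'flat': True},
}


def mask_flags(loss_name):
    """Return the EXPAND/FLAT settings for one of the losses named above."""
    try:
        return MASK_FLAGS[loss_name]
    except KeyError:
        raise ValueError(f"no mask format known for loss '{loss_name}'")


flags = mask_flags('losses.FocalLoss')
EXPAND, FLAT = flags['expand'], flags['flat']
print(EXPAND, FLAT)
```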

Variable `DEBUG` sets debug mode for overfitting on one image.

Class list used for two-class segmentation (from _MaritimeAI Sentinel-1 Dataset 1920_ `2-class` directory) includes: `nodata` (zero values), `water` (ones) and `ice` (twos).

> Formally, there are only `water` and `ice` classes of interest, but technically, all three classes must be specified.

> DEBUG: for true binary segmentation there should be one class (for example, `data`).

> Yet another variable, `NUM_SAMPLES`, has been added to control the number of output samples from the training and validation datasets (to preview data loading transformations and augmentations).

In [None]:
from random import sample


CLASSES = ['nodata', 'water', 'ice']
# CLASSES = ['data']  # do np.clip(mask, 0, 1) for binary mode
NUM_SAMPLES = 6  # draw samples from train/valid datasets

DEBUG = False
EXPAND = True
FLAT = True

# Nearest interpolation mode is mandatory for masks
# and optional (but recommended) for images
mode_interpolation = InterpolationMode.NEAREST

# Separate transformations for train and valid
size = (1024, 1024)  # size for network input
transformations_train = Compose([
    RandomResizedCrop(size, (0.25, 0.95), (3 / 4, 4 / 3),
                      mode_interpolation),
])

transformations_valid = Compose([
    Resize(size, mode_interpolation),
])

# Images/masks paths
path_images_hh = osp.join(PATH_DATASET, 'images', 'hh')
path_images_hv = osp.join(PATH_DATASET, 'images', 'hv')
path_masks = osp.join(PATH_DATASET, 'masks', '2-class')

# Grand dataset
dataset = DatasetSAR((path_images_hh, path_images_hv), path_masks,
                     classes=CLASSES)

# Split the dataset into train/valid datasets
items_total = len(dataset)
fraction = 100 / 15

items_train = tuple(dataset.items[i] for i in range(items_total) \
                    if not i or i % int(round(fraction)))

items_valid = tuple(dataset.items[i] for i in range(items_total) \
                    if i and not i % int(round(fraction)))

# Debug: set default transformations for training and validation
if DEBUG:
    transformations_train = None
    transformations_valid = None

dataset_train = DatasetSAR((path_images_hh, path_images_hv), path_masks,
                           items=items_train, classes=CLASSES, expand=EXPAND,
                           flat=FLAT, transformations=transformations_train)

dataset_valid = DatasetSAR((path_images_hh, path_images_hv), path_masks,
                           items=items_valid, classes=CLASSES, expand=EXPAND,
                           flat=FLAT, transformations=transformations_valid)

# Debug: make datasets of one item
if DEBUG:
    dataset_train.items = dataset.items[:1]
    dataset_valid.items = dataset.items[:1]

nomasks = []
size_train = len(dataset_train)
print(f"Training dataset ({size_train}):")
for i in sample(range(size_train), k=min(NUM_SAMPLES, size_train)):
    try:
        image, mask = dataset_train[i]
        draw_one_row(image.permute(1, 2, 0), mask if FLAT else mask.argmax(0) \
                     if EXPAND else mask.squeeze(0))
    except AttributeError:
        nomask = dataset_train.items[i]
        nomasks.append(nomask)
        print(f"ERROR: failed to load {nomask}")
print(f"Read errors = {len(nomasks)}")

nomasks = []
size_valid = len(dataset_valid)
print(f"Validation dataset ({size_valid}):")
for i in sample(range(size_valid), k=min(NUM_SAMPLES, size_valid)):
    try:
        image, mask = dataset_valid[i]
        draw_one_row(image.permute(1, 2, 0), mask if FLAT else mask.argmax(0) \
                     if EXPAND else mask.squeeze(0))
    except AttributeError:
        nomask = dataset_valid.items[i]
        nomasks.append(nomask)
        print(f"ERROR: failed to load {nomask}")
print(f"Read errors = {len(nomasks)}")

The `DatasetSAR` class has an optional `debug` method that shows the filenames of items and some other info.

In [None]:
# dataset.debug()

## Dataloaders

_Google Colab_ settings. The batch size and the number of workers can be increased if necessary.

In [None]:
from torch.utils.data import DataLoader


loader_train = DataLoader(dataset_train, batch_size=1, shuffle=True,
                          num_workers=2)

loader_valid = DataLoader(dataset_valid, batch_size=1, shuffle=False,
                          num_workers=2)

# Training

## Model

**Unet** with **EfficientNet-B2** encoder is used as an example 'cause it does fit into _Google Colab_ GPU RAM. EfficientNet-B3 may also fit (larger encoders like EfficientNet-B7 were tested on local devboxes and gave the best results).

Some losses from _Segmentation Models_ may require a different mask output format (squeezed or one-hot expanded, see the `EXPAND` and `FLAT` variables for datasets).

Experimenting with **loss functions** may give interesting results. Compare the same scenario with different loss functions:

* DiceLoss (with classes 1 and 2)

![DiceLoss](https://raw.githubusercontent.com/MaritimeAI/ODS-SoC/1b066dd63aedd364fda4ad0d08ac3ae6286f2289/assets/images/dice_12.png)

* FocalLoss

![FocalLoss](https://raw.githubusercontent.com/MaritimeAI/ODS-SoC/1b066dd63aedd364fda4ad0d08ac3ae6286f2289/assets/images/focal_all.png)

This is a very basic example, so a classic **metric** (**IoU**) and a classic **optimizer** (**Adam**) are used.

Training and validation **runners** are a killer feature of _Segmentation Models_.

In [None]:
import segmentation_models_pytorch as smp


device = ['cpu', 'cuda'][torch.cuda.is_available()]

# Model
model = smp.Unet(encoder_name='efficientnet-b2', encoder_weights='imagenet',
                 in_channels=3, classes=len(CLASSES))

# Loss functions
# loss = smp.utils.losses.DiceLoss()
mode = smp.losses.MULTICLASS_MODE
# loss = smp.losses.DiceLoss(mode=mode, classes=[1, 2])
# setattr(loss, '__name__', 'dice_loss')
loss = smp.losses.FocalLoss(mode=mode)
setattr(loss, '__name__', 'focal_loss')

# Metrics
metrics = [
           smp.utils.metrics.IoU(threshold=0.5)
]

# Optimizer
optimizer = torch.optim.Adam([{
    'params': model.parameters(),
    'lr': 2e-4
}])

# Runners
train_one_epoch = smp.utils.train.TrainEpoch(model, loss=loss, metrics=metrics,
                                             optimizer=optimizer, device=device,
                                             verbose=True)
valid_one_epoch = smp.utils.train.ValidEpoch(model, loss=loss, metrics=metrics,
                                             device=device, verbose=True)

## Train loop

All the models are saved to `PATH_MODELS` under a separate subdirectory with timestamp in its name.

The best score and the corresponding **best weights** are saved only once `EPOCHS_MIN` epochs (10 by default) have passed. The **last weights** are saved after every epoch, so they are still available if training is interrupted.
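The rule for saving the best weights can be simulated on a plain list of validation scores. This is a minimal sketch of the same condition as in the train loop below; the function `epochs_with_best_save` and the sample scores are invented for illustration.

```python
def epochs_with_best_save(scores, epochs_min=10):
    """Hypothetical helper: epochs (0-based) where 'best' would be saved."""
    score_max, saved = 0.0, []
    for epoch, score in enumerate(scores):
        # The maximum is tracked only after the warm-up epochs have passed
        if score_max < score and epoch >= epochs_min:
            score_max = score
            saved.append(epoch)
    return saved


# Invented IoU values: a lucky early epoch, then slow improvement
scores = [0.9] + [0.5] * 9 + [0.6, 0.55, 0.7]
print(epochs_with_best_save(scores))
```

Note that a high score during the first `epochs_min` epochs is ignored entirely, so it does not block later, lower scores from being saved as the best ones.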

**Learning rate** is decreased by 5% every epoch.
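A 5% decrease per epoch means the learning rate follows a geometric sequence. The sketch below computes it in closed form, assuming the initial value of `2e-4` from the optimizer cell; the function name `learning_rate` is hypothetical.

```python
def learning_rate(epoch, lr_initial=2e-4, factor=0.95):
    """Learning rate at the start of a 0-based epoch (closed form)."""
    return lr_initial * factor ** epoch


# Print a few values in the same format as the train loop
for epoch in (0, 1, 10, 49):
    print(f"Epoch = {epoch:3d}, learning rate = {learning_rate(epoch):0.8f}")
```

The same schedule is what repeated multiplication by `0.95` inside the loop produces, up to floating point rounding.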

> `TRAIN=False` skips training and just loads a saved model (named by `NAME_PRELOAD`). To start training from scratch, set `TRAIN=True` and `NAME_PRELOAD=''` (the two variables must not both evaluate to `False`).

In [None]:
import os

from datetime import datetime


NAME_PRELOAD = ''
FORMAT_DATE = '%Y-%m-%d-%H-%M-%S'
EPOCHS_MIN = 10
EPOCHS = 50
TRAIN = True

timestamp = datetime.utcnow().strftime(FORMAT_DATE)
path_model_best = osp.join(PATH_MODELS, timestamp, 'best.pth')
path_model_last = osp.join(PATH_MODELS, timestamp, 'last.pth')
path_model_onnx = osp.join(PATH_MODELS, timestamp, 'model.onnx')
score_max = 0

if NAME_PRELOAD:
    try:
        path_model_resume = path_model_last.replace(timestamp, NAME_PRELOAD)
        model = torch.load(path_model_resume, map_location=device)
        print(f"The model has been successfully loaded from {NAME_PRELOAD}!")
    finally:
        pass

if TRAIN:
    os.makedirs(osp.dirname(path_model_best), exist_ok=True)
    os.makedirs(osp.dirname(path_model_last), exist_ok=True)

    try:
        for i in range(EPOCHS):
            print(f"Epoch = {i:3d}, learning rate =",
                f"{optimizer.param_groups[0]['lr']:0.8f}")

            logs_train = train_one_epoch.run(loader_train)
            logs_valid = valid_one_epoch.run(loader_valid)

            if score_max < logs_valid['iou_score'] and i >= EPOCHS_MIN:
                # Save only after 10 epochs
                score_max = logs_valid['iou_score']
                torch.save(model, path_model_best)
                print(f"Saved best model with score = {score_max:0.4f}!")

            # Unconditional model saving
            torch.save(model, path_model_last)
            print(f"Saved latest model with score =",
                  f"{logs_valid['iou_score']:0.4f}!")

            optimizer.param_groups[0]['lr'] *= 0.95  # step down
            print()
    except KeyboardInterrupt:
        pass

## Export to ONNX

Exporting a model to _ONNX_ format enables it to be run on any platform with a runtime optimized for CPU and/or GPU.

> Exporting a model to _ONNX_ format with _PyTorch_ does not require _ONNX_ dependencies.
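The export cell below passes an empty `dynamic_axes` dictionary, so the exported model gets fixed dimensions taken from the sample input. As a sketch of what a variable batch size could look like, the `dynamic_axes` argument of `torch.onnx.export` maps each input/output name to a dictionary of axis index to axis label. The keys below assume the names `input` and `output` used in the export call; whether the rest of this pipeline works with a dynamic batch has not been verified here.

```python
# Hypothetical replacement for the empty dynamic_axes dictionary:
# axis 0 (batch) of both tensors is marked as variable
dynamic_axes = {
    'input': {0: 'batch_size'},
    'output': {0: 'batch_size'},
}

for name, axes in dynamic_axes.items():
    print(name, axes)
```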

In [None]:
if 'model' in locals() and hasattr(model, 'predict'):
    os.makedirs(osp.dirname(path_model_onnx), exist_ok=True)
    model = model.cpu()  # inference on CPU

    # Images to use during export to ONNX
    images, masks_true = [], []
    for image, mask_true in loader_valid:
        images.append(image)
        masks_true.append(mask_true)
        break

    # PyTorch prediction for later comparison with ONNX model
    masks_predict = []
    for image in images:
        # masks_predict.append(model.predict(image.to(device)).cpu())
        masks_predict.append(model.predict(image.cpu()))

    # Make EfficientNet TorchScript-friendly
    model.encoder.set_swish(memory_efficient=False)

    # Script and export model to ONNX
    # without dynamic_axes argument resulting model will have
    # the same dimension size as input during export
    # The model was moved to CPU above, so the inputs must be on CPU too
    torch.onnx.export(model, tuple(image.cpu() for image in images),
                    path_model_onnx, export_params=True,
                    opset_version=11, #do_constant_folding=True,
                    input_names=['input'], output_names=['output'],
                    dynamic_axes={
                        #   'input': {0: 'batch_size'},
                        #   'output': {0: 'batch_size'}
                    }, # operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
                    )

It may be worth freeing memory after each training run (if the same model is not going to be trained again).

In [None]:
try:
    del model
except:
    pass

try:
    torch.cuda.empty_cache()
except:
    pass

# Inference


## Validation subset

Inference on validation subset makes sense in combination with visualization.

### Inference and visualization (PyTorch)

Use the last or the best model. Memory is freed after each inference run.

In [None]:
VALID = True

if VALID and osp.exists(path_model_last):
    model_eval = torch.load(path_model_last)

    for image, mask_true in loader_valid:
        mask_predict = model_eval.predict(image.to(device)).cpu()
        for i, p, t in zip(image, mask_predict, mask_true):
            draw_one_row(i.permute(1, 2, 0), p.argmax(0), t.squeeze(0))

try:
    del model_eval
except:
    pass

try:
    torch.cuda.empty_cache()
except:
    pass

## Test subset (using ONNX runtime)

The test subset consists of images without annotations, so no one knows the ground truth masks.

### Prepare ONNX

Install from _PyPI_.

In [None]:
%pip install onnx onnxruntime

Check the model with _ONNX checker_.

In [None]:
import onnx

if osp.exists(path_model_onnx):
    model_onnx = onnx.load(path_model_onnx)
    onnx.checker.check_model(model_onnx)

Check the model with _ONNX Runtime_.
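This check compares the _PyTorch_ prediction with the _ONNX Runtime_ output using `np.testing.assert_allclose`, which accepts a pair of values when `|actual - desired| <= atol + rtol * |desired|`. Here is a dependency-free sketch of that criterion with the tolerances from the cell below; the function `within_tolerance` is a hypothetical stand-in, not the _NumPy_ implementation.

```python
def within_tolerance(actual, desired, rtol=1e-3, atol=9e-2):
    """Hypothetical stand-in for the elementwise allclose criterion."""
    return all(abs(a - d) <= atol + rtol * abs(d)
               for a, d in zip(actual, desired))


# Invented values: a small and a large deviation from the reference
print(within_tolerance([1.0, 2.0], [1.05, 2.0]))
print(within_tolerance([1.0, 2.0], [1.2, 2.0]))
```

The absolute tolerance of `9e-2` is fairly loose, so small numerical differences between the two runtimes do not fail the check.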

In [None]:
import onnxruntime

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad \
           else tensor.cpu().numpy()

if osp.exists(path_model_onnx):
    session_ort = onnxruntime.InferenceSession(path_model_onnx)

    inputs_ort = {session_ort.get_inputs()[0].name: to_numpy(images[0])}
    outputs_ort = session_ort.run(None, inputs_ort)

    np.testing.assert_allclose(to_numpy(masks_predict[0]), outputs_ort[0],
                            rtol=1e-3, atol=9e-2)

    print("OK: ONNX model verification complete!")

Draw the inference result for the exported model.

In [None]:
if (len(set(('images', 'masks_predict', 'outputs_ort'))\
        .intersection(set(locals())))) == 3:
    draw_one_row(images[0][0].permute(1, 2, 0),
                masks_predict[0][0].argmax(0),
                outputs_ort[0][0].argmax(0))

### Data loading

Reuse variables `path_images_hh`, `path_images_hv` and `path_masks` from previous cells.

In [None]:
items_images = set(os.listdir(path_images_hh))
items_masks = set(os.listdir(path_masks))
items_test = sorted(items_images - items_masks)
dict(enumerate(items_test))

### Inference and visualization (ONNX)

Use the vectorized cutline from the repository to crop the `NoData` area.

In [None]:
!git clone https://github.com/MaritimeAI/resources.git

> Cutline shape may also be checked with `ogrinfo` utility from _GDAL_ library.

In [None]:
FILE_SHAPEFILE = osp.join(PATH_RESOURCES, 'clustering', 'cutline',
                          'Start_Ice_Map_UTMz40WGS84f_r.shp')

try:
    if osp.isfile(FILE_SHAPEFILE):
        shape = osp.abspath(osp.realpath(FILE_SHAPEFILE))
    else:
        raise FileNotFoundError
except (TypeError, FileNotFoundError) as e:
    print(f"Shapefile '{FILE_SHAPEFILE}' does not exist!")
    shape = None
print(f"Available shape is {shape}")
# !ogrinfo "{shape}"

The inference part itself does not depend on _PyTorch_ or any model code: it needs just the exported _ONNX_ model and _ONNX Runtime_.

> HH + HV polarizations are being combined during inference, just like in `DatasetSAR` class.

> _Python_'s garbage collector is being used intensively here, 'cause images are really large.

In [None]:
import gc

import onnxruntime

from tempfile import TemporaryDirectory

from osgeo import gdal

PLOT = True

# Target directory for test inference
path_test = osp.join(PATH_OUTPUT, timestamp, 'test')
os.makedirs(path_test, exist_ok=True)

# Batch size is supposed to be 1
for i, item_test in enumerate(items_test):
    source = osp.join(path_images_hh, item_test)
    target = osp.join(path_test, item_test)
    with TemporaryDirectory() as path_temp:
        temp = osp.join(path_temp, item_test)
        gdal.Translate(temp, source,
                    options=['-b', '1', '-colorinterp', 'gray',
                                '-co', 'COMPRESS=DEFLATE'])

        # Input images must be 8-bit GeoTIFFs
        image_hh = cv.imread(osp.join(path_images_hh, item_test), cv.IMREAD_LOAD_GDAL)
        image_hv = cv.imread(osp.join(path_images_hv, item_test), cv.IMREAD_LOAD_GDAL)

        # Make HH and HV sizes match (sizes mismatch should never happen)
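        # image_hv = cv.resize(image_hv, image_hh.shape[::-1],
        #                      interpolation=cv.INTER_NEAREST)
        image = np.dstack((image_hv, image_hh, image_hv)) / np.float32(255)
        del image_hh, image_hv
        gc.collect()

        # Interpolation must be passed by keyword: the third positional
        # argument of cv.resize is dst, not the interpolation flag
        image = cv.resize(image, (1024, 1024), interpolation=cv.INTER_LINEAR)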
        # Output image (GeoTIFF copy of image_hh)
        dataset = gdal.Open(temp, gdal.GA_Update)
        band = dataset.GetRasterBand(1)
        h, w = band.ReadAsArray().shape
        inputs_ort = {
            session_ort.get_inputs()[0].name: (np.moveaxis(image, -1, 0)[None, ...])
        }
        outputs_ort = session_ort.run(None, inputs_ort)
        # WARNING: this resize part works just because there are 3 classes
        mask = cv.resize(np.moveaxis(outputs_ort[0][0], 0, -1),
                        (w, h), interpolation=cv.INTER_NEAREST)
        mask[..., 2] *= 2  # change 'ice' class weight
        mask = mask.argmax(-1).clip(0, 255).astype('uint8')
        band.WriteArray(mask)
        dataset.FlushCache()
        del band, dataset
        gc.collect()

        # Assume NoData value is always zero
        # Pass None (not the string "None") when no cutline is available
        gdal.Warp(target, temp, dstNodata=0, xRes=40, yRes=40,
                  cutlineDSName=shape,
                  cropToCutline=bool(shape),
                  creationOptions=['COMPRESS=DEFLATE'])
        if PLOT:
            mask = cv.imread(target, cv.IMREAD_LOAD_GDAL)
            draw_one_row(cv.resize(image, (w, h),
                                   interpolation=cv.INTER_LINEAR), mask)
        del image, mask
    gc.collect()