In [None]:
# Install segmentation-models-pytorch pinned to 0.2.0; the `smp.utils`
# runners/metrics used later in this notebook come from this package.
# Alternative: install the same tag straight from GitHub.
# %pip install git+https://github.com/qubvel/segmentation_models.pytorch.git@v0.2.0
%pip install segmentation-models-pytorch==0.2.0

In [None]:
# Remove packages this notebook does not use.
# NOTE(review): presumably to avoid dependency conflicts with the torchvision
# pin in the next cell — confirm.
!pip uninstall torchtext torchaudio fastai -y

In [None]:
# Pin torchvision to 0.10.0 (the installed version is checked in the next cell).
!pip install -U torchvision=='0.10.0'

In [None]:
import torchvision

# Check the installed torchvision release. Only the public version is compared:
# the local build suffix ('+cu102', '+cu111', '+cpu', ...) depends on how the
# wheel was built and does not change the torchvision API.
torchvision.__version__.split('+')[0] == '0.10.0'

# Common

The common part is a prerequisite for all further parts (data loading, training, inference).

## Add dataset and save model weights
File -> Add or upload data

## Paths

`PATH_OUTPUT` is used to save images (for example, inference on a test subset).

`PATH_MODELS` is used to save model weights for later inference.

`PATH_UPLOAD` is where previously uploaded model weights are read from; they are copied to `PATH_MODELS` for later inference.

`PATH_DATASET` is a source of GeoTIFF **images** and **masks** for training and inference (for example, test subset).

In [None]:
from os import path as osp
import os
import shutil

# Name of the Kaggle input that holds weights from a previous run; an empty
# name makes PATH_UPLOAD resolve to /kaggle/input/models.
NOTEBOOK_NAME = ''

# Read-only inputs (dataset and previously uploaded weights)
PATH_DATASET = osp.join('/', 'kaggle', 'input', 'sentinel1', 'dataset')
PATH_UPLOAD = osp.join('/', 'kaggle', 'input', NOTEBOOK_NAME, 'models')
# Writable outputs (weights of this run and rendered images)
PATH_MODELS = osp.join('/', 'kaggle', 'working', 'models')
PATH_OUTPUT  = osp.join('/', 'kaggle', 'working', 'output')

print(NOTEBOOK_NAME, PATH_MODELS, PATH_DATASET, PATH_UPLOAD, PATH_OUTPUT,
      sep='\n')

In [43]:
# Prepare PATH_MODELS: reuse previously uploaded weights when they exist,
# otherwise start from an empty directory. An existing PATH_MODELS is kept.
if not osp.exists(PATH_MODELS):
    if osp.exists(PATH_UPLOAD):
        # Copy into PATH_MODELS itself (not a cwd-relative 'models'
        # directory) so the weights land where later cells look for them.
        shutil.copytree(PATH_UPLOAD, PATH_MODELS)
    else:
        os.makedirs(PATH_MODELS, exist_ok=True)

## Functions

Auxiliary functions (for example, drawing/plotting data).

In [None]:
import os

from matplotlib import pyplot as plt


def draw_one_row(*images, size=1024, output=None):
    """Draw images side by side in one row, optionally save, then show.

    Args:
        *images: Anything accepted by ``Axes.imshow`` (e.g. HxW or HxWxC
            arrays/tensors), one subplot per image.
        size: Figure size in pixels: a single number (square figure) or a
            sequence whose first two items are (width, height).
        output: Optional file path. If given, the figure is also saved there
            (missing parent directories are created). Saving is best-effort:
            failures are reported and the figure is still shown.
    """
    try:
        size = tuple(size[:2])
    except (TypeError, IndexError):
        # Scalar size (int, float or numpy scalar) -> square figure
        size = (size, size)
    dpi = 72
    # squeeze=False always returns a 2-D array of axes, so a single image
    # works the same way as several
    figure, axes = plt.subplots(1, len(images), dpi=dpi, squeeze=False,
                                figsize=(size[0] / dpi, size[1] / dpi))
    for axis, image in zip(axes[0], images):
        axis.imshow(image)
    if output is not None:
        try:
            directory = osp.dirname(output)
            # A bare file name has no directory to create
            if directory:
                os.makedirs(directory, exist_ok=True)
            figure.savefig(output)
        except Exception as error:
            print(f"WARNING: failed to save figure to {output}: {error}")
    plt.show()
    # Release the figure; many figures are drawn in loops
    plt.close(figure)

# Config

Configuration dictionary to store training parameters. This config dictionary will be updated in subsequent cells.

In [None]:
# Run configuration. Later cells add more keys (device, model, loss,
# optimizer, item lists, ...) and the whole dict is passed to W&B.
config = dict(
    classes=['nodata', 'water', 'ice'],  # mask values 0, 1, 2
    batch_size_train=1,
    batch_size_valid=1,
    num_workers_train=1,
    num_workers_valid=1,
    model_encoder='EfficientNet-B2',
    model_pretrain='ImageNet',
    model_channels=3,
    data_split=2,  # 1-based cross-validation fold
    expand=True,
    debug=False,
    flat=True,  # takes precedence over `expand` in DatasetSAR
)

# Data Loading


## Dataset class

The most interesting part of the pipeline, because it is the interface between the data and the neural network models. Here is an example of a dataset class (`DatasetSAR`) which loads a pair of HH+HV images and their corresponding masks from different directories, stacks them together, and applies geometric transformations and color augmentations (here the general term _augmentations_ is split into geometric and color ones).

Parameters for the `DatasetSAR` class constructor:

* `paths_images` is a list of strings (or a single string in case only one polarization is used) where each string points to an existing directory with HH or HV polarized images;
* `paths_masks` is a list of strings / single string that point(s) to where masks are located (only first directory from the list is used);
* `classes` is a list of classes used — important for multiclass masks (where they should be **expanded** into one-hot encoded tensor);
* `items` is a list of filenames to be used as the dataset items (WARNING: the passed list of filenames is accepted as-is with no checks) — useful for creating subsets (for example, train/valid) from some superset (list of all items of the dataset);
* `transformations` is a [torchvision transforms](https://pytorch.org/vision/stable/transforms.html) class (or any other compatible callable class) **without** things like `ToTensor` (it is added separately to the images only) that is applied to both images and masks;
* `augmentations` is a callable like `transformations`, but is applied to images only;
* `size` is an integer or 2x tuple of integers for final image/mask output size (for example, 1024x1024 as input image for Unet with EfficientNet-B2 encoder);
* `expand` is a boolean flag which true value enables expanding a mask into one-hot encoded tensor (that is required for some loss functions and their modes, default is `False` — do not expand);
* `flat` is a boolean flag which true value produces squeezed mask (`(H, W)` instead of `(C, H, W)`, for example, for Focal Loss target) — takes precedence over `expand` flag (if `flat` then it's never `expand`ed).

> DEBUG: commented line with `np.clip` is for testing binary segmentation.

> TODO: add loading and geometric transformations for RGB images.

In [None]:
from collections import OrderedDict
from glob import glob

import cv2 as cv
import numpy as np
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Lambda, RandomResizedCrop, Resize, ToTensor
from torchvision.transforms import InterpolationMode


class DatasetSAR(Dataset):
    """Dataset of dual-polarization (HH + HV) SAR images with their masks.

    Items are file names shared by the HH/HV image directories and the first
    mask directory. ``__getitem__`` returns ``(image, mask)``: the image is a
    3-channel tensor produced by ``ToTensor`` from the (HV, HH, HV) planes,
    and the mask layout depends on ``flat`` / ``expand``.
    """

    def __init__(self, paths_images, paths_masks, classes, items=None,
                 transformations=None, augmentations=None, size=None,
                 expand=False, flat=False):
        """Collect paths and items and set up the transformations.

        Args:
            paths_images: Directories of HH and HV images (tuple/list, in that
                order), or a single directory (str) used for both channels.
            paths_masks: Mask directory (str) or tuple/list of directories;
                only the first one is read.
            classes: List of class names; its length is the number of
                classes used for one-hot expansion.
            items: File names used as-is (no checks). If None, the sorted
                intersection of ``*.tiff`` names found in the first two image
                directories and the first mask directory.
            transformations: Callable applied jointly to the stacked image and
                mask (geometric). Any callable is accepted; otherwise a
                NEAREST ``Resize`` to ``size`` is used.
            augmentations: Callable applied to the image tensor only (color);
                otherwise identity.
            size: Output size of the default transformations
                (default ``(1024, 1024)``).
            expand: Return a one-hot ``(C, H, W)`` int8 mask.
            flat: Return an ``(H, W)`` mask; takes precedence over ``expand``.

        Raises:
            TypeError: If a path argument is neither a str nor a
                tuple/list/set.
        """
        # Image paths: a single directory serves both polarizations
        if isinstance(paths_images, str):
            paths_images = (paths_images, paths_images)
        elif not isinstance(paths_images, (tuple, list, set)):
            raise TypeError("first argument must be of type str or list!")
        else:
            # NOTE(review): a set has no defined order, so HH/HV may end up
            # swapped; pass a tuple/list to be safe
            paths_images = tuple(paths_images)
        # Mask paths
        if isinstance(paths_masks, str):
            paths_masks = (paths_masks,)
        elif not isinstance(paths_masks, (tuple, list, set)):
            raise TypeError("second argument must be of type str or list!")
        else:
            paths_masks = tuple(paths_masks)
        # Default output image/mask size
        if size is None:
            size = (1024, 1024)
        # Geometric transformations (image + mask). NEAREST interpolation
        # keeps mask values valid class indices.
        if callable(transformations):
            self.transformations = transformations
        else:
            self.transformations = Compose([
                Resize(size, InterpolationMode.NEAREST),
            ])
        # Color augmentations (image only); identity by default
        if callable(augmentations):
            self.augmentations = augmentations
        else:
            self.augmentations = Compose([
                Lambda(lambda x: x),
            ])
        # Class attributes
        self.to_tensor = ToTensor()
        self.paths_images = paths_images
        self.paths_masks = paths_masks
        if items is None:
            # Keep only names present in every directory (HH, HV, mask)
            names = [
                {osp.basename(path_item)
                 for path_item in glob(osp.join(path, '*.tiff'))}
                for path in paths_images[:2] + paths_masks[:1]
            ]
            self.items = sorted(names[0].intersection(*names))
        else:
            self.items = items
        self.classes = len(classes)
        # Class name -> index mapping, only used by debug().
        # NOTE(review): indices start at 1, while the 2-class masks described
        # in this notebook use 0 for 'nodata' — confirm the intended mapping.
        self.items_class = OrderedDict({c: i for c, i
                                        in zip(classes,
                                               range(1, self.classes + 1))})
        self.expand = expand
        self.flat = flat

    def __getitem__(self, item):
        """Load item ``item`` and return ``(image, mask)`` tensors.

        ``cv.imread`` returns None for missing/unreadable files, which
        surfaces here as an AttributeError; the preview loops of this
        notebook catch exactly that error.
        """
        name = self.items[item]
        # TODO: load RGBs
        image_hh = cv.imread(osp.join(self.paths_images[0], name),
                             cv.IMREAD_LOAD_GDAL | cv.IMREAD_GRAYSCALE)
        h, w = image_hh.shape[:2]
        image_hv = cv.imread(osp.join(self.paths_images[1], name),
                             cv.IMREAD_LOAD_GDAL | cv.IMREAD_GRAYSCALE)
        h, w = tuple(map(min, zip(image_hv.shape[:2], (h, w))))
        if self.paths_masks:
            mask = cv.imread(osp.join(self.paths_masks[0], name),
                             cv.IMREAD_LOAD_GDAL | cv.IMREAD_GRAYSCALE)
        else:
            mask = image_hh
        h, w = tuple(map(min, zip(mask.shape[:2], (h, w))))
        # Use minimum height and width because image/mask dimensions may not
        # always match (difference introduced during mask conversion).
        # Stacking the mask as a third channel lets the same geometric
        # transformation be applied to image and mask together.
        image = Image.fromarray(np.dstack(
            [image_hh[:h, :w], image_hv[:h, :w], mask[:h, :w]]
            # [image_hh[:h, :w], image_hv[:h, :w], np.clip(mask[:h, :w], 0, 1)]
        ))
        del image_hh, image_hv, mask
        image = self.transformations(image)
        image = np.array(image)
        image, mask = image[..., :2], image[..., 2]
        # 3 channels for the network input: (HV, HH, HV)
        image = self.to_tensor(np.dstack((image[..., 1], image)))
        image = self.augmentations(image)
        if self.flat:
            # (H, W) mask, e.g. for Focal Loss targets
            mask = torch.tensor(mask)
        elif self.expand:
            # One-hot requires an int64 index tensor; result is (C, H, W)
            mask = torch.nn.functional.one_hot(torch.tensor(mask,
                                                            dtype=torch.int64),
                                               self.classes).to(torch.int8)\
                                               .permute(2, 0, 1)
        else:
            # (1, H, W) int64 index mask
            mask = torch.tensor(mask).unsqueeze_(0).to(torch.int64)
        return image, mask

    def __len__(self):
        """Number of items in the dataset."""
        return len(self.items)

    def debug(self):
        """Print paths, item names and the class mapping."""
        print(f"DEBUG: paths images = {self.paths_images}")
        print(f"DEBUG: paths masks = {self.paths_masks}")
        print(f"DEBUG: items ({len(self.items)}):")
        print('\n'.join(self.items))
        print(f"DEBUG: classes ({self.classes}):")
        print(self.items_class)

## Datasets and visualization

This part produces datasets to be used in data loaders for training and validation. Here are three datasets: `dataset` — grand superset with all annotated images; `dataset_train` — training subset of grand dataset to be used during training (calculating gradients); `dataset_valid` — validation subset of grand dataset to be used during validation (calculating metrics).

> Training/validation items (`items_train`/`items_valid`) are selected evenly across the grand dataset items (feel free to split it in any other way).

Transformations are also in two variants (train/valid): **random crop and resize** for training in order to learn features from different scales and simple **resize** for validation (for simplicity and closer approach to later inference scenarios).

> All transformations assume the output shape of images/masks to be 1024x1024 (set by the `size` variable).

Images and masks from the train/valid datasets are visualized in the last two loops. For _MaritimeAI Sentinel-1 Dataset 1920_ that would be 124 images and their corresponding masks (neither too many nor too few); the number of previewed samples per dataset is limited by `NUM_SAMPLES`.

Variables `EXPAND` and `FLAT` are used for top-level control of the mask output format. For example, for Dice loss from the _Segmentation Models_ _utils_ module set `EXPAND=True` and `FLAT=False`; for Dice loss from the _losses_ module set both variables to `False`; for Focal loss from the _losses_ module set `FLAT=True`.

Variable `DEBUG` sets debug mode for overfitting on one image.

Class list used for two-class segmentation (from _MaritimeAI Sentinel-1 Dataset 1920_ `2-class` directory) includes: `nodata` (zero values), `water` (ones) and `ice` (twos).

> Formally, there are only `water` and `ice` classes of interest, but technically, all three classes must be specified.

> DEBUG: for true binary segmentation there should be one class (for example, `data`).

> Yet another variable, `NUM_SAMPLES`, has been added to control the number of samples drawn from the training and validation datasets (to preview the data loading transformations and augmentations).

In [None]:
# Items removed from the dataset before building the cross-validation folds
# (used only when DATA_SPLIT is set in the next cell).
# NOTE(review): the reason for each group (a-d) is not recorded here —
# presumably problematic scenes or masks; confirm with the dataset authors.
items_exclude_a = [
    'S1B_EW_GRDM_1SDH_20200203T031613_20200203T031631_020099_0260A6_D03B.tiff',
    'S1B_EW_GRDM_1SDH_20200215T031613_20200215T031630_020274_026647_9E25.tiff',
    'S1B_EW_GRDM_1SDH_20200227T031613_20200227T031630_020449_026BE9_3282.tiff',
    'S1B_EW_GRDM_1SDH_20200310T031613_20200310T031630_020624_027178_1A36.tiff',
    'S1B_EW_GRDM_1SDH_20200322T031613_20200322T031631_020799_027702_664C.tiff',
    'S1B_EW_GRDM_1SDH_20200521T031615_20200521T031633_021674_029249_923C.tiff',
]

items_exclude_b = [
    'S1A_EW_GRDM_1SDH_20191117T031700_20191117T031800_029945_036ADD_32F2.tiff',
    'S1A_EW_GRDM_1SDH_20191129T031659_20191129T031759_030120_0370EF_D07E.tiff',
    'S1A_EW_GRDM_1SDH_20200104T031658_20200104T031758_030645_038306_DDA1.tiff',
    'S1A_EW_GRDM_1SDH_20200328T031656_20200328T031756_031870_03ADB5_D992.tiff',
    'S1A_EW_GRDM_1SDH_20200421T031657_20200421T031757_032220_03BA08_1F43.tiff',
]

items_exclude_c = [
    'S1A_EW_GRDM_1SDH_20191211T031659_20191211T031759_030295_0376F4_BE3E.tiff',
    'S1A_EW_GRDM_1SDH_20191223T031658_20191223T031758_030470_037CFD_AB38.tiff',
    'S1A_EW_GRDM_1SDH_20200221T031656_20200221T031756_031345_039B6D_927B.tiff',
    'S1A_EW_GRDM_1SDH_20200304T031656_20200304T031756_031520_03A17D_08EB.tiff',
    'S1A_EW_GRDM_1SDH_20200316T031656_20200316T031756_031695_03A78C_D3A3.tiff',
    'S1A_EW_GRDM_1SDH_20200409T031657_20200409T031757_032045_03B3E6_7A01.tiff',
    'S1A_EW_GRDM_1SDH_20200503T031658_20200503T031758_032395_03C031_950B.tiff',
]

items_exclude_d = [
    'S1A_EW_GRDM_1SDH_20191107T030034_20191107T030132_029799_0365BB_F7CF.tiff',
    'S1B_EW_GRDM_1SDH_20200601T023525_20200601T023601_021834_02971A_B08C.tiff'
]

In [None]:
from random import sample


CLASSES = config['classes']
# CLASSES = ['data']  # do np.clip(mask, 0, 1) for binary mode
NUM_SAMPLES = 6  # draw samples from train/valid datasets
# Zero-based fold index (config stores it 1-based); None selects the simple
# "every 7th item" split below
DATA_SPLIT = (config['data_split'] - 1
              if config.get('data_split') is not None else None)

DEBUG = config['debug']
EXPAND = config['expand']
FLAT = config['flat']

# Nearest interpolation mode is mandatory for masks
# and optional (but recommended) for images
mode_interpolation = InterpolationMode.NEAREST

# Separate transformations for train and valid
size = (1024, 1024)  # size for network input
transformations_train = Compose([
    # Crop 25-95% of the area with aspect ratio in [3/4, 4/3], then resize
    RandomResizedCrop(size, (0.25, 0.95), (3 / 4, 4 / 3),
                      mode_interpolation),
])

transformations_valid = Compose([
    Resize(size, mode_interpolation),
])

# Images/masks paths
path_images_hh = osp.join(PATH_DATASET, 'images', 'hh')
path_images_hv = osp.join(PATH_DATASET, 'images', 'hv')
path_masks = osp.join(PATH_DATASET, 'masks', '2-class')

# Grand dataset
dataset = DatasetSAR((path_images_hh, path_images_hv), path_masks,
                     classes=CLASSES)

# Split the dataset into train/valid datasets
items_total = len(dataset)
fraction = 100 / 15  # rounded to 7: every 7th item goes to validation

# Train/valid items (images and masks)
if DATA_SPLIT is None:
    step = int(round(fraction))
    items_train = tuple(dataset.items[i] for i in range(items_total)
                        if not i or i % step)

    items_valid = tuple(dataset.items[i] for i in range(items_total)
                        if i and not i % step)
else:
    # Drop excluded items, then build 5 folds: fold k validates on every
    # 5th item starting at k and trains on the rest
    items_exclude = set(items_exclude_a + items_exclude_b +
                        items_exclude_c + items_exclude_d)
    items_all = [item for item in sorted(dataset.items)
                 if item not in items_exclude]
    items_split = [{
        'train': sorted(set(items_all) - set(items_all[i::5])),
        'valid': sorted(items_all[i::5])
    } for i in range(5)]
    items_train = items_split[DATA_SPLIT]['train']
    items_valid = items_split[DATA_SPLIT]['valid']

# Debug: set default transformations for training and validation
if DEBUG:
    transformations_train = None
    transformations_valid = None

dataset_train = DatasetSAR((path_images_hh, path_images_hv), path_masks,
                           items=items_train, classes=CLASSES, expand=EXPAND,
                           flat=FLAT, transformations=transformations_train)

dataset_valid = DatasetSAR((path_images_hh, path_images_hv), path_masks,
                           items=items_valid, classes=CLASSES, expand=EXPAND,
                           flat=FLAT, transformations=transformations_valid)

# Debug: make datasets of one item
if DEBUG:
    dataset_train.items = dataset.items[:1]
    dataset_valid.items = dataset.items[:1]


def _mask_to_2d(mask):
    """Return an (H, W) mask for display, whatever the FLAT/EXPAND format."""
    if FLAT:
        return mask
    if EXPAND:
        return mask.argmax(0)
    return mask.squeeze(0)


def _preview_dataset(dataset_preview, title):
    """Draw up to NUM_SAMPLES random items; return names that failed to load."""
    failed = []
    size_preview = len(dataset_preview)
    print(f"{title} ({size_preview}):")
    for i in sample(range(size_preview), k=min(NUM_SAMPLES, size_preview)):
        try:
            image, mask = dataset_preview[i]
            draw_one_row(image.permute(1, 2, 0), _mask_to_2d(mask))
        except AttributeError:
            # DatasetSAR raises AttributeError when cv.imread returns None
            nomask = dataset_preview.items[i]
            failed.append(nomask)
            print(f"ERROR: failed to load {nomask}")
    print(f"Read errors = {len(failed)}")
    return failed


nomasks = _preview_dataset(dataset_train, "Training dataset")
size_train = len(dataset_train)

config['items_train'] = dataset_train.items
config['errors_train'] = nomasks
config['size_train'] = size_train

nomasks = _preview_dataset(dataset_valid, "Validation dataset")
size_valid = len(dataset_valid)

config['items_valid'] = dataset_valid.items
config['errors_valid'] = nomasks
config['size_valid'] = size_valid

`DatasetSAR` class has optional `debug` method to show filenames of items and some other info.

In [None]:
# Uncomment to print the grand dataset's paths, item names and class mapping.
# dataset.debug()

## Dataloaders

_Google Colab_ settings. Batch size and number of workers can be increased if necessary.

In [None]:
from torch.utils.data import DataLoader


# Only the training data is shuffled; the validation loader keeps dataset
# order, which the validation inference cell relies on when pairing
# predictions with dataset_valid.items.
loader_train = DataLoader(
    dataset_train,
    batch_size=config['batch_size_train'],
    num_workers=config['num_workers_train'],
    shuffle=True,
)

loader_valid = DataLoader(
    dataset_valid,
    batch_size=config['batch_size_valid'],
    num_workers=config['num_workers_valid'],
    shuffle=False,
)

# Training

## W&B

An API token is required in order to authorize in WandB:
https://wandb.ai/authorize

In [None]:
# Install/upgrade the W&B client used for experiment tracking.
# NOTE(review): unpinned — pin a known-good version for reproducible runs.
%pip install --upgrade wandb

Add-ons -> Secrets -> add a new secret with the label `wandb` and your API key (from https://wandb.ai/authorize) as its value.

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient

# Read the W&B API key from Kaggle secrets (label 'wandb') and hand it
# straight to wandb.login. The key is deliberately not stored in a notebook
# global, where it could be echoed into saved outputs.
wandb.login(key=UserSecretsClient().get_secret("wandb"))

## Model

**Unet** with an **EfficientNet-B2** encoder is used as an example because it fits into _Google Colab_ GPU RAM. EfficientNet-B3 may also fit (larger encoders like EfficientNet-B7 were tested on local devboxes and gave the best results).

Some losses from _Segmentation Models_ may require a different mask output format (squeezed or one-hot expanded, see the `EXPAND` and `FLAT` variables for datasets).

Experimenting with **loss functions** may give interesting results. Compare the same scenario with different loss functions:

* DiceLoss (with classes 1 and 2)

![DiceLoss](https://raw.githubusercontent.com/MaritimeAI/ODS-SoC/1b066dd63aedd364fda4ad0d08ac3ae6286f2289/assets/images/dice_12.png)

* FocalLoss

![FocalLoss](https://raw.githubusercontent.com/MaritimeAI/ODS-SoC/1b066dd63aedd364fda4ad0d08ac3ae6286f2289/assets/images/focal_all.png)

This is a very basic example, so classic **metrics** such as **IoU** and **optimizer** such as **Adam** are being used.

Training and validation **runners** are a killer feature of _Segmentation Models_.

In [None]:
import segmentation_models_pytorch as smp


# Train on the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config['device'] = device

# Model: Unet with encoder/pretraining taken from config (smp expects
# lower-case identifiers, e.g. 'efficientnet-b2' / 'imagenet')
model = smp.Unet(
    encoder_name=config['model_encoder'].lower(),
    encoder_weights=config['model_pretrain'].lower(),
    in_channels=config['model_channels'],
    classes=len(CLASSES),
)

# Training stage counter. It is not part of state_dict, but it is pickled
# together with the model by torch.save(model, ...).
model.stage = 0

# Loss functions (alternatives kept for experiments)
# loss = smp.utils.losses.DiceLoss()
mode = smp.losses.MULTICLASS_MODE
# loss = smp.losses.DiceLoss(mode=mode, classes=[1, 2])
# loss.__name__ = 'dice_loss'
loss = smp.losses.FocalLoss(mode=mode)
# Name under which the smp.utils runners presumably report the loss
loss.__name__ = 'focal_loss'

# Record model/loss in the run configuration
config.update(model=type(model).__name__, loss=loss.__name__, loss_mode=mode)

# Metrics
# NOTE(review): with FLAT masks the target is (B, H, W) while predictions are
# (B, C, H, W) — confirm that IoU broadcasting gives the intended score.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Optimizer
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 2e-4}])

# Record optimizer settings in the run configuration
config.update(optimizer=type(optimizer).__name__,
              learning_rate=optimizer.param_groups[0]['lr'])

# Runners: one call to .run(loader) performs a full train/valid epoch
train_one_epoch = smp.utils.train.TrainEpoch(
    model, loss=loss, metrics=metrics, optimizer=optimizer,
    device=device, verbose=False,
)
valid_one_epoch = smp.utils.train.ValidEpoch(
    model, loss=loss, metrics=metrics, device=device, verbose=False,
)

## Train loop

All the models are saved to `PATH_MODELS` under a separate subdirectory with timestamp in its name.

The best score and the corresponding **best weights** are tracked only from epoch `EPOCHS_MIN` on (10 by default). The **last weights** are saved unconditionally after every epoch, so they are kept even if training is interrupted.

**Learning rate** is decreased by 5% every epoch.

> `TRAIN=False` skips training (a saved model is still loaded when `NAME_PRELOAD` is set). To resume from a previous run, set `NAME_PRELOAD` to that run's timestamp; its `last.pth` weights are then loaded. To start training from scratch, set `TRAIN=True` and `NAME_PRELOAD=''`.

In [None]:
%%time

import os

from datetime import datetime, timezone


NAME_PRELOAD = ''  # timestamp directory of a previous run to resume ('' = none)
FORMAT_DATE = '%Y-%m-%d-%H-%M-%S'
EPOCHS_MIN = 10  # best weights are only tracked from this epoch on
EPOCHS = 30
TRAIN = True

# UTC timestamp (timezone-aware; datetime.utcnow() is deprecated since
# Python 3.12). The formatted string is the same.
timestamp = datetime.now(timezone.utc).strftime(FORMAT_DATE)
path_model_best = osp.join(PATH_MODELS, timestamp, 'best.pth')
path_model_last = osp.join(PATH_MODELS, timestamp, 'last.pth')
path_model_onnx = osp.join(PATH_MODELS, timestamp, 'model.onnx')
score_max = 0

if NAME_PRELOAD:
    path_model_resume = path_model_last.replace(timestamp, NAME_PRELOAD)
    # NOTE: torch.load of a fully pickled model can execute arbitrary code;
    # only load files produced by this notebook.
    model_resume = torch.load(path_model_resume, map_location=device)
    # Copy the weights into the existing model instead of rebinding `model`:
    # the optimizer and the train/valid runners hold references to this
    # object, so rebinding would keep training the old model while saving
    # the untouched loaded one.
    model.load_state_dict(model_resume.state_dict())
    if hasattr(model_resume, 'stage'):
        model.stage = model_resume.stage + 1
    del model_resume
    print(f"The model has been successfully loaded from {NAME_PRELOAD}!")

# Update config
config['epochs'] = EPOCHS
config['epochs_min'] = EPOCHS_MIN
config['timestamp'] = timestamp
config['name_preload'] = NAME_PRELOAD

# W&B setup
if hasattr(model, 'stage') and model.stage > 0:
    WANDB_STAGE = f"Stage{model.stage}"
else:
    WANDB_STAGE = config['model_pretrain'] if config['model_pretrain'] else 'Zero'
if 'data_split' in config:
    name_split = f"Fold{config['data_split']}"
else:
    name_split = 'Nofold'

WANDB_ENTITY = 'maritimeai'
WANDB_PROJECT = 'sea-ice-segmentation'
WANDB_GROUP = '/'.join([config['model'], config['model_encoder'], WANDB_STAGE])
WANDB_NAME = '/'.join(['Baseline', name_split, timestamp] +
                      (['Debug'] if DEBUG else []))

if 'wandb' in locals() and wandb is not None:
    experiment = wandb.init(entity=WANDB_ENTITY, config=config,
                            project=WANDB_PROJECT, group=WANDB_GROUP,
                            name=WANDB_NAME, notes=timestamp)
    artifact = wandb.Artifact(WANDB_NAME.replace('/', '.'), type='model',
                              metadata={'items': dataset_train.items})
else:
    experiment = None
    artifact = None


if TRAIN:
    # best.pth and last.pth share the per-run directory
    os.makedirs(osp.dirname(path_model_last), exist_ok=True)

    try:
        if experiment is not None:
            experiment.watch(model)
        for i in range(EPOCHS):
            print(f"Epoch = {i:3d}, learning rate =",
                  f"{optimizer.param_groups[0]['lr']:0.8f}")

            logs_train = train_one_epoch.run(loader_train)
            logs_valid = valid_one_epoch.run(loader_valid)

            if score_max < logs_valid['iou_score'] and i >= EPOCHS_MIN:
                # Track the best model only after EPOCHS_MIN epochs
                score_max = logs_valid['iou_score']
                torch.save(model, path_model_best)
                print(f"Saved best model with score = {score_max:0.4f}!")

            # Unconditional model saving
            torch.save(model, path_model_last)
            print(f"Saved latest model with score =",
                  f"{logs_valid['iou_score']:0.4f}!")

            if experiment is not None:
                experiment.log({'learning_rate': optimizer.param_groups[0]['lr']},
                               step=i)
                experiment.log({f"{k}/train": v for k, v in logs_train.items()},
                               step=i)
                experiment.log({f"{k}/valid": v for k, v in logs_valid.items()},
                               step=i)
            optimizer.param_groups[0]['lr'] *= 0.95  # step down
            print()
    except KeyboardInterrupt:
        # Manual interruption ends training early; saved weights are kept
        pass
    finally:
        # Attach each weight file that exists. best.pth is only written after
        # EPOCHS_MIN epochs, and its absence must not prevent last.pth from
        # being attached.
        if artifact is not None:
            for path_weights in (path_model_best, path_model_last):
                if osp.exists(path_weights):
                    try:
                        artifact.add_file(path_weights)
                    except Exception as error:
                        print(f"WARNING: failed to add {path_weights} "
                              f"to the artifact: {error}")

## Export to ONNX

Exporting a model to _ONNX_ format enables it to be run on any platform, optimized for CPU and/or GPU.

> Exporting a model to _ONNX_ format with _PyTorch_ does not require _ONNX_ dependencies.

In [None]:
if 'model' in locals() and hasattr(model, 'predict'):
    os.makedirs(osp.dirname(path_model_onnx), exist_ok=True)
    model = model.cpu()  # reference inference and export run on CPU

    # Use a single validation batch as the example input for the export
    images, masks_true = [], []
    for image, mask_true in loader_valid:
        images.append(image)
        masks_true.append(mask_true)
        break
    if not images:
        raise RuntimeError("ONNX export needs at least one validation batch, "
                           "but loader_valid is empty")

    # PyTorch predictions, kept for a later comparison with the ONNX model
    masks_predict = [model.predict(image.cpu()) for image in images]

    # Make EfficientNet TorchScript-friendly (plain Swish instead of the
    # memory-efficient one); other encoders have no set_swish method
    if hasattr(model.encoder, 'set_swish'):
        model.encoder.set_swish(memory_efficient=False)

    # Export to ONNX. Without dynamic_axes the resulting model has the same
    # input dimensions as the example batch used here.
    torch.onnx.export(model, tuple(image.cpu() for image in images),
                      path_model_onnx, export_params=True,
                      opset_version=11,  # do_constant_folding=True,
                      input_names=['input'], output_names=['output'],
                      dynamic_axes={}
                     )
    if 'artifact' in locals() and artifact is not None:
        try:
            artifact.add_file(path_model_onnx)
        except Exception as error:
            print(f"WARNING: failed to add {path_model_onnx} "
                  f"to the artifact: {error}")

It may be worth freeing memory after training (if the same model is not going to be trained again).

In [None]:
import gc

# Free the training model and cached GPU memory (only if the same model is not
# going to be trained again in this session). The optimizer and the
# train/valid runners also reference the model, so they are released too;
# otherwise deleting `model` alone would free nothing.
for _name in ('model', 'optimizer', 'train_one_epoch', 'valid_one_epoch'):
    # pop() tolerates names that are already gone (e.g. the cell is re-run)
    globals().pop(_name, None)

gc.collect()  # collect unreachable objects before releasing the CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Inference


## Validation subset

Inference on validation subset makes sense in combination with visualization.

### Inference and visualization (PyTorch)

Use last or best model. After each inference memory is going to be cleaned.

In [None]:
import gc

VALID = True

if VALID and osp.exists(path_model_last):
    # map_location allows CUDA-trained weights to be loaded without a GPU
    model_eval = torch.load(path_model_last, map_location=device)

    path_predictions = osp.join(PATH_OUTPUT, timestamp, 'predictions')
    # loader_valid does not shuffle, so samples arrive in the order of
    # dataset_valid.items; names are paired per sample (not per batch) so
    # they stay aligned for any batch size
    names = iter(dataset_valid.items)
    for batch_images, batch_masks_true in loader_valid:
        batch_masks_predict = model_eval.predict(batch_images.to(device)).cpu()
        for image, mask_predict, mask_true in zip(batch_images,
                                                  batch_masks_predict,
                                                  batch_masks_true):
            name = osp.splitext(next(names))[0] + '.png'
            # Reduce the true mask to (H, W) for every dataset mask format
            if FLAT:
                mask_show = mask_true
            elif EXPAND:
                mask_show = mask_true.argmax(0)
            else:
                mask_show = mask_true.squeeze(0)
            draw_one_row(image.permute(1, 2, 0), mask_predict.argmax(0),
                         mask_show, output=osp.join(path_predictions, name))
    if ('artifact' in locals() and artifact is not None
            and osp.isdir(path_predictions)):
        try:
            artifact.add_dir(path_predictions, name='predictions')
        except Exception as error:
            print(f"WARNING: failed to add {path_predictions} "
                  f"to the artifact: {error}")

# Free the evaluation model and cached GPU memory
try:
    del model_eval
except NameError:
    # Not loaded (no saved weights or VALID is False)
    pass

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

> IMPORTANT: Finalize W&B initialized process.

In [None]:
# Upload the collected artifact (weights, ONNX model, predictions) and close
# the W&B run. Both steps are best-effort: failures are reported, and the run
# is finished even if the upload fails. Nothing happens when W&B was not
# initialized.
if globals().get('experiment') is not None:
    if globals().get('artifact') is not None:
        try:
            experiment.log_artifact(artifact)
        except Exception as error:
            print(f"WARNING: failed to log the W&B artifact: {error}")
    try:
        experiment.finish()
    except Exception as error:
        print(f"WARNING: failed to finish the W&B run: {error}")