In [None]:
from os import path as osp

from google.colab import drive


# Colab mount point of Google Drive
PATH_DRIVE = osp.join('/', 'content', 'drive')

# Skip mounting when the mount point already exists (e.g. the cell is re-run)
if osp.exists(PATH_DRIVE):
    print("Google Drive has been already mounted!")
else:
    print("Mounting Google Drive...")
    drive.mount(PATH_DRIVE)

## Paths

In [None]:
from os import path as osp


# Sub-folder of "MyDrive" that holds the project data
PATH_STORAGE = osp.join('ods', 'soc')  # arbitrary subpath in Google Drive (if any)

# Data folders live on Google Drive when the mount cell defined PATH_DRIVE,
# otherwise they are expected directly under /content
if 'PATH_DRIVE' not in globals():
    PREFIX_DRIVE = ''
else:
    PREFIX_DRIVE = osp.join(osp.basename(PATH_DRIVE), 'MyDrive', PATH_STORAGE)

# Local (never on Drive) scratch and resource folders
PATH_TEMP = osp.join('/', 'content', 'temp')
PATH_RESOURCES = osp.join('/', 'content', 'resources')

# Folders that follow PREFIX_DRIVE
PATH_INPUT, PATH_OUTPUT, PATH_MODELS, PATH_DATASET = (
    osp.join('/', 'content', PREFIX_DRIVE, folder)
    for folder in ('input', 'output', 'models', 'dataset'))

# FILE_SHAPEFILE = osp.join(PATH_RESOURCES, 'clustering', 'cutline',
#                           'Start_Ice_Map_UTMz40WGS84f_r.shp')

print(*(PATH_STORAGE, PATH_TEMP, PATH_INPUT, PATH_OUTPUT, PATH_MODELS,
        PATH_DATASET, PATH_RESOURCES), sep='\n')

In [None]:
import os
from matplotlib import pyplot as plt


def draw_one_row(*images, size=1024, output=None):
    """Show images side by side in a single row of subplots.

    Args:
        *images: Arrays accepted by ``Axes.imshow``, drawn left to right.
        size: Figure size in pixels; an int gives a square figure, a sequence
            is read as ``(width, height)`` (extra items are ignored).
        output: Optional file path. When given, the figure is also saved
            there, creating parent folders if needed. Saving is best-effort:
            failures are reported, not raised.

    Raises:
        ValueError: if no image is given.
    """
    if not images:
        raise ValueError("draw_one_row() needs at least one image")
    dpi = 72  # figsize is in inches, so pixels / dpi gives the requested size
    try:
        width, height = size[:2]
    except (TypeError, IndexError):  # scalar size (int or numpy scalar)
        width = height = size
    # squeeze=False keeps `axes` 2-D even for a single image
    figure, axes = plt.subplots(1, len(images), dpi=dpi, squeeze=False,
                                figsize=(width / dpi, height / dpi))
    for axis, image in zip(axes[0], images):
        axis.imshow(image)
    if output is not None:
        try:
            folder = os.path.dirname(output)
            if folder:  # a bare file name has no folder to create
                os.makedirs(folder, exist_ok=True)
            figure.savefig(output)
        except Exception as error:
            print(f"Could not save figure to {output!r}: {error!r}")
    plt.show()
    plt.close(figure)  # avoid accumulating figures when called in a loop

## Load the dataset and split it into folds

Exclude corrupted images from the dataset (they are only left out of the train/valid split; no files are deleted).

In [None]:
# Scenes left out of the train/valid split built in the next cell (the notes
# above describe them as corrupted). They are kept as four separate groups;
# TODO(review): document what distinguishes groups a-d.
items_exclude_a = [
    'S1B_EW_GRDM_1SDH_20200203T031613_20200203T031631_020099_0260A6_D03B.tiff',
    'S1B_EW_GRDM_1SDH_20200215T031613_20200215T031630_020274_026647_9E25.tiff',
    'S1B_EW_GRDM_1SDH_20200227T031613_20200227T031630_020449_026BE9_3282.tiff',
    'S1B_EW_GRDM_1SDH_20200310T031613_20200310T031630_020624_027178_1A36.tiff',
    'S1B_EW_GRDM_1SDH_20200322T031613_20200322T031631_020799_027702_664C.tiff',
    'S1B_EW_GRDM_1SDH_20200521T031615_20200521T031633_021674_029249_923C.tiff',
]

items_exclude_b = [
    'S1A_EW_GRDM_1SDH_20191117T031700_20191117T031800_029945_036ADD_32F2.tiff',
    'S1A_EW_GRDM_1SDH_20191129T031659_20191129T031759_030120_0370EF_D07E.tiff',
    'S1A_EW_GRDM_1SDH_20200104T031658_20200104T031758_030645_038306_DDA1.tiff',
    'S1A_EW_GRDM_1SDH_20200328T031656_20200328T031756_031870_03ADB5_D992.tiff',
    'S1A_EW_GRDM_1SDH_20200421T031657_20200421T031757_032220_03BA08_1F43.tiff',
]

items_exclude_c = [
    'S1A_EW_GRDM_1SDH_20191211T031659_20191211T031759_030295_0376F4_BE3E.tiff',
    'S1A_EW_GRDM_1SDH_20191223T031658_20191223T031758_030470_037CFD_AB38.tiff',
    'S1A_EW_GRDM_1SDH_20200221T031656_20200221T031756_031345_039B6D_927B.tiff',
    'S1A_EW_GRDM_1SDH_20200304T031656_20200304T031756_031520_03A17D_08EB.tiff',
    'S1A_EW_GRDM_1SDH_20200316T031656_20200316T031756_031695_03A78C_D3A3.tiff',
    'S1A_EW_GRDM_1SDH_20200409T031657_20200409T031757_032045_03B3E6_7A01.tiff',
    'S1A_EW_GRDM_1SDH_20200503T031658_20200503T031758_032395_03C031_950B.tiff',
]

items_exclude_d = [
    'S1A_EW_GRDM_1SDH_20191107T030034_20191107T030132_029799_0365BB_F7CF.tiff',
    'S1B_EW_GRDM_1SDH_20200601T023525_20200601T023601_021834_02971A_B08C.tiff'
]

In [None]:
from random import sample


NUM_SAMPLES = 6  # draw samples from train/valid datasets
DATA_SPLIT = 0
NUM_FOLDS = 5  # cross-validation folds: fold i validates on items_cross[i::NUM_FOLDS]

# Images/masks paths
path_images_hh = osp.join(PATH_DATASET, 'images', 'hh')
path_images_hv = osp.join(PATH_DATASET, 'images', 'hv')
path_masks = osp.join(PATH_DATASET, 'masks', '2-class')
# os.listdir order is arbitrary; sort it for a reproducible processing order
items = sorted(os.listdir(path_images_hh))

# Train/valid items (images and masks). The exclusion lists are merged into a
# set once instead of being re-concatenated for every candidate file.
items_exclude = set(items_exclude_a + items_exclude_b + items_exclude_c
                    + items_exclude_d)
items_cross = [item for item in sorted(os.listdir(path_masks))
               if item not in items_exclude]
items_split = [{
    'train': sorted(set(items_cross) - set(items_cross[i::NUM_FOLDS])),
    'valid': sorted(items_cross[i::NUM_FOLDS])
} for i in range(NUM_FOLDS)]

In [None]:
# Number of files in the images/hh folder (excluded scenes are not filtered out here)
len(items)

In [None]:
# Train / valid sizes of every fold
for split in items_split:
    print(len(split['train']), '/', len(split['valid']))

# W&B

In [None]:
%pip install --quiet --upgrade wandb

Folds that were used for training the baseline

In [None]:
import os
from getpass import getpass

# Ask for the W&B API key interactively unless WANDB_API_KEY is already set,
# so that the token never has to be written into the notebook.
if os.getenv('WANDB_API_KEY') is None:
    os.environ['WANDB_API_KEY'] = getpass('https://wandb.ai/authorize :')

# W&B is optional: on an import/login failure fall back to `wandb = None`, so
# the next cell does not create an experiment. The reason is printed instead
# of being silently swallowed.
try:
    import wandb
    wandb.login()
except Exception as error:
    print(f"W&B is unavailable, experiment tracking disabled: {error!r}")
    wandb = None

In [None]:
# W&B run identification
WANDB_ENTITY = 'maritimeai'
WANDB_PROJECT = 'sea-ice-segmentation'
WANDB_GROUP = 'Masks'
WANDB_NAME = '/'.join(('Baseline', 'ONNX', 'Masks'))

# Start a run only when the login cell left a usable `wandb` module behind
if globals().get('wandb') is None:
    experiment = None
else:
    experiment = wandb.init(
        entity=WANDB_ENTITY,
        project=WANDB_PROJECT,
        group=WANDB_GROUP,
        name=WANDB_NAME,
        notes='Averaging masks',
        # config=config,
    )

In [None]:
from collections import  OrderedDict


# W&B model artifacts, one per cross-validation fold (Fold1..Fold5); each
# downloaded folder is later expected to contain `model.onnx`.
artifacts_baseline = [
    'Baseline.Fold1.2021-09-13-09-15-27:latest',
    'Baseline.Fold2.2021-10-12-07-54-33:latest',
    'Baseline.Fold3.2021-09-17-10-37-26:latest',
    'Baseline.Fold4.2021-09-24-22-44-24:latest',
    'Baseline.Fold5.2021-09-23-09-45-27:latest'
]

artifacts_fpn = [
    'Adam-2e-4.Fold1.2021-10-01-09-48-28:latest',
    'Adam-2e-4.Fold2.2021-09-30-13-12-35:latest',
    'Adam-2e-4.Fold3.2021-09-30-09-30-32:latest',
    'Adam-2e-4.Fold4.2021-09-29-16-04-34:latest',
    'Adam-2e-4.Fold5.2021-09-29-14-23-53:latest'
]

# Always define `artifacts` (insertion order: baseline folds, then FPN folds),
# so later cells see an empty mapping instead of a NameError when W&B
# tracking is disabled.
artifacts = OrderedDict()
if experiment is not None:
    for artifact_name in artifacts_baseline + artifacts_fpn:
        artifact = experiment.use_artifact(artifact_name)
        artifacts[artifact_name] = {
            'data': artifact,
            'path': osp.abspath(artifact.download())
        }

print(f"Downloaded {len(artifacts)} artifacts")


In [None]:
# Close the W&B run once the artifacts are downloaded; nothing to do when
# tracking is disabled. Closing stays best-effort, but failures are reported.
if experiment is not None:
    try:
        experiment.finish()
    except Exception as error:
        print(f"Failed to finish the W&B run: {error!r}")

artifacts

In [None]:
%ls artifacts/*

In [None]:
%pip install onnxruntime

In [None]:
# All validation items of all folds, flattened in fold order
items_folds = [item for split in items_split for item in split['valid']]
len(items_folds)

In [None]:
from collections import ChainMap


# Fold index of every validation item. The validation slices are disjoint, so
# each item maps to exactly one fold; a plain dict comprehension suffices.
items_foldmap = {
    item: fold
    for fold, split in enumerate(items_split)
    for item in split['valid']
}
# -1 marks images that are in no validation fold (e.g. excluded scenes). A dict
# lookup replaces the O(n) membership scan over the `items_folds` list.
items_images = {item: items_foldmap.get(item, -1) for item in items}
len(items_images)

In [None]:
import onnxruntime

# One ONNX Runtime session per downloaded artifact, keyed by its position in
# `artifacts` (0, 1, 2, ...)
models = {}
for index, data in enumerate(artifacts.values()):
    model_file = osp.join(data['path'], 'model.onnx')
    models[index] = {'session': onnxruntime.InferenceSession(model_file)}

In [None]:
# Show the loaded inference sessions, keyed by model index
models

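## Averaging masks

For every image, the predictions of the five fold models are averaged: the model whose validation fold contains the image gets weight 3, every other model weight 1.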

In [None]:
import cv2 as cv
import numpy as np
import gc


NUM_CLASSES = 3
SIZE = (1024, 1024)  # cv.resize dsize, i.e. (width, height) of the model input
NUM_FOLD_MODELS = 5  # models[0 .. NUM_FOLD_MODELS - 1] take part in the average
WEIGHT_OWN_FOLD = 3  # weight of the model whose index equals the image's fold


def _load_image(item):
    """Read the HH/HV pair of `item` and stack it as (HV, HH, HV) channels / 255.

    Raises FileNotFoundError when either band cannot be read, because
    cv.imread signals failure by returning None instead of raising.
    """
    image_hh = cv.imread(osp.join(path_images_hh, item), cv.IMREAD_LOAD_GDAL)
    image_hv = cv.imread(osp.join(path_images_hv, item), cv.IMREAD_LOAD_GDAL)
    if image_hh is None or image_hv is None:
        raise FileNotFoundError(f"Cannot read the HH/HV pair of {item!r}")
    return np.dstack((image_hv, image_hh, image_hv)) / np.float32(255)


def _average_masks(image, fold):
    """Return the weighted average of the fold models' outputs for one image.

    The model with index `fold` gets WEIGHT_OWN_FOLD and every other model
    weight 1; `fold == -1` (image in no validation fold) gives the plain mean.
    """
    # The input batch is the same for every model, so build it only once.
    # The 3rd positional argument of cv.resize is `dst`, so the interpolation
    # flag must be passed by keyword to take effect.
    resized = cv.resize(image, SIZE, interpolation=cv.INTER_NEAREST)
    batch = np.moveaxis(resized, -1, 0)[None, ...]  # channels first + batch axis
    mask_total = None
    weight_total = 0
    for k in range(NUM_FOLD_MODELS):
        session = models[k]['session']
        mask = session.run(None, {session.get_inputs()[0].name: batch})[0][0]
        weight = WEIGHT_OWN_FOLD if k == fold else 1
        if mask_total is None:
            mask_total = np.zeros_like(mask)
        mask_total += mask * weight
        weight_total += weight
    return mask_total / weight_total


for item, fold in items_images.items():
    image = _load_image(item)
    # argmax over axis 0 of the averaged output (presumably the class axis)
    mask_total = _average_masks(image, fold).argmax(0)
    print(item)
    draw_one_row(image, mask_total)
    gc.collect()  # release the large per-image arrays before the next scene