In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, TensorBoard, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.models import Sequential

import tifffile as tiff
import matplotlib.pyplot as plt
import numpy as np

import gc
import math
import random
import time
import os

In [2]:
# constants

DATA_PATH_PREFIX = '../input/satellite-data/'

N_BANDS = 8    # input channel shape
N_CLASSES = 5  # buildings, roads, trees, crops and water
CLASS_WEIGHTS = [0.2, 0.3, 0.1, 0.1, 0.3]  # per-class loss weights, same class order as above

UPCONV = True    # True to use Up-Convolution (=TransposedConvolution), False to use Up-Sampling
PATCH_SZ = 160   # must be divisible by 32: unet_model pools 5 times (unet_crops pools 4 times -> 16)

N_EPOCHS = 150
BATCH_SIZE = 128

TRAIN_SZ = 32*BATCH_SIZE  # train size (for one epoch)
VAL_SZ = 1024    # validation size

# Seed the RNGs used for patch sampling/augmentation (random, numpy) and for Keras
# weight initialization (tensorflow) so that runs are repeatable. GPU kernels may
# still introduce some nondeterminism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

Create a function that builds the U-Net with the given architecture.

In [3]:
def unet_model(n_classes=5, im_sz=160, n_channels=8, n_filters_start=32, growth_factor=2, upconv=True,
               class_weights=(0.2, 0.3, 0.1, 0.1, 0.3), droprate=0.25):
    """Build and compile a U-Net for multi-label per-pixel segmentation.

    The encoder has five 2x2 max-pooling stages, the decoder mirrors it with skip
    connections, and a final 1x1 sigmoid convolution outputs one probability map per class.

    :param n_classes: number of output channels (one sigmoid map per class)
    :param im_sz: side length of the square input patch; must be divisible by 32
        (one halving per pooling stage), or None for variable-size input
    :param n_channels: number of input bands
    :param n_filters_start: number of filters in the first encoder level
    :param growth_factor: factor by which the filter count grows per encoder level
        (and shrinks per decoder level)
    :param upconv: True to upsample with Conv2DTranspose, False to use UpSampling2D
    :param class_weights: per-class weights of the loss; length must equal n_classes
    :param droprate: rate of every Dropout layer (0.25 keeps the original behaviour)
    :return: compiled keras Model, (im_sz, im_sz, n_channels) -> (im_sz, im_sz, n_classes)
    :raises ValueError: if im_sz is not divisible by 32 or len(class_weights) != n_classes
    """
    if im_sz is not None and im_sz % 32:
        raise ValueError(f'im_sz must be divisible by 32 (5 pooling stages), got {im_sz}')
    class_weights = list(class_weights)
    if len(class_weights) != n_classes:
        # a length-1 output would otherwise broadcast silently against all weights
        raise ValueError(f'expected {n_classes} class weights, got {len(class_weights)}')

    def up_and_merge(x, skip, filters):
        # Double the spatial size of x, then concatenate the encoder skip tensor.
        # Layers are created in the same order as before (upsampler, then concatenate),
        # so automatically generated layer names/ordering are unchanged.
        if upconv:
            x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        else:
            x = UpSampling2D(size=(2, 2))(x)
        return concatenate([x, skip])

    n_filters = n_filters_start
    inputs = Input((im_sz, im_sz, n_channels))
    conv1 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    n_filters *= growth_factor
    pool1 = BatchNormalization()(pool1)
    conv2 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    pool2 = Dropout(droprate)(pool2)

    n_filters *= growth_factor
    pool2 = BatchNormalization()(pool2)
    conv3 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    pool3 = Dropout(droprate)(pool3)

    n_filters *= growth_factor
    pool3 = BatchNormalization()(pool3)
    conv4_0 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(pool3)
    conv4_0 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv4_0)
    pool4_1 = MaxPooling2D(pool_size=(2, 2))(conv4_0)
    pool4_1 = Dropout(droprate)(pool4_1)

    n_filters *= growth_factor
    pool4_1 = BatchNormalization()(pool4_1)
    conv4_1 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(pool4_1)
    conv4_1 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv4_1)
    pool4_2 = MaxPooling2D(pool_size=(2, 2))(conv4_1)
    pool4_2 = Dropout(droprate)(pool4_2)

    # bottleneck
    n_filters *= growth_factor
    conv5 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(pool4_2)
    conv5 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv5)

    n_filters //= growth_factor
    up6_1 = BatchNormalization()(up_and_merge(conv5, conv4_1, n_filters))
    conv6_1 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(up6_1)
    conv6_1 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv6_1)
    conv6_1 = Dropout(droprate)(conv6_1)

    n_filters //= growth_factor
    up6_2 = BatchNormalization()(up_and_merge(conv6_1, conv4_0, n_filters))
    conv6_2 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(up6_2)
    conv6_2 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv6_2)
    conv6_2 = Dropout(droprate)(conv6_2)

    n_filters //= growth_factor
    up7 = BatchNormalization()(up_and_merge(conv6_2, conv3, n_filters))
    conv7 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv7)
    conv7 = Dropout(droprate)(conv7)

    n_filters //= growth_factor
    up8 = BatchNormalization()(up_and_merge(conv7, conv2, n_filters))
    conv8 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv8)
    conv8 = Dropout(droprate)(conv8)

    n_filters //= growth_factor
    up9 = up_and_merge(conv8, conv1, n_filters)  # no BatchNormalization on the last level
    conv9 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv9)

    # one independent sigmoid per class: classes are not mutually exclusive
    conv10 = Conv2D(n_classes, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    def weighted_binary_crossentropy(y_true, y_pred):
        # Mean BCE per class over axes 0-2 (batch, rows, columns), then a weighted sum over classes.
        class_loglosses = K.mean(K.binary_crossentropy(y_true, y_pred), axis=[0, 1, 2])
        return K.sum(class_loglosses * K.constant(class_weights))

    model.compile(optimizer=Adam(), loss=weighted_binary_crossentropy)
    return model

It is recommended not to feed whole satellite images to the network. <br>
Instead, resample small patches of size $160 \times 160$ and train the model with such samples. <br>
This is a common practice in image segmentation.

Create U-Net model using function `unet_model()` and see its summary:

In [4]:
# Build the U-Net with the default arguments of unet_model() and print its layer table.
model = unet_model()
model.summary()

The output is created with 5 channels - one per segmentation class. Each channel will contain probabilities of pixel belonging to the corresponding class.

Plot the neural network architecture:

In [5]:
# show_shapes=True annotates every layer with its input/output tensor shape,
# which makes the encoder/decoder resolutions readable from the diagram.
plot_model(model, show_shapes=True)

Before training the network, normalize the images so that all values lie in [-1.0, 1.0]. <br>
Create a function that does that.

In [6]:
def normalize(img):
    """Linearly rescale all values of img (all bands together) to the range [-1.0, 1.0].

    :param img: numeric ndarray of any shape
    :return: float array of the same shape; the minimum maps to -1.0 and the maximum to 1.0.
        A constant image has no range to stretch and is mapped to all zeros,
        instead of dividing by zero and producing NaNs.
    """
    lo = img.min()
    hi = img.max()
    if hi == lo:
        return np.zeros_like(img, dtype=float)
    return 2.0 * (img - lo) / (hi - lo) - 1.0

Read the training data, normalize it and make it channels-last. The top 75% of each image's rows is used for training and the bottom 25% is held out for validation.

In [7]:
# all available image ids: "01" ... "24"
trainIds = [f'{i:02d}' for i in range(1, 25)]

X_DICT_TRAIN, Y_DICT_TRAIN = {}, {}
X_DICT_VALIDATION, Y_DICT_VALIDATION = {}, {}

print('Reading images')
for img_id in trainIds:
    # multiband image: transposed to channels-last, then scaled to [-1, 1]
    img_m = normalize(tiff.imread(f'{DATA_PATH_PREFIX}/mband/{img_id}.tif').transpose([1, 2, 0]))
    # ground-truth masks, channels-last, divided by 255 (presumably stored as 0/255 - confirm)
    mask = tiff.imread(f'{DATA_PATH_PREFIX}/gt_mband/{img_id}.tif').transpose([1, 2, 0]) / 255
    # top 75% of the rows -> train, bottom 25% -> validation
    train_xsz = int(0.75 * img_m.shape[0])
    X_DICT_TRAIN[img_id], X_DICT_VALIDATION[img_id] = img_m[:train_xsz], img_m[train_xsz:]
    Y_DICT_TRAIN[img_id], Y_DICT_VALIDATION[img_id] = mask[:train_xsz], mask[train_xsz:]
    print(img_id + ' read')
print('Images were read')

gc.collect();

We need to prepare 160x160 patches from both train and validation data to fit the model.

The following function `get_rand_patch()` picks a random patch from an image and the corresponding mask, and applies the same random flip/rotation to both.
Then the function `get_patches()` samples patches from the train or validation set.

In [8]:
def get_rand_patch(img, mask, sz=160):
    """Cut a random sz x sz window from an image and its mask and apply a random symmetry.

    The same window position and the same flip/rotation are applied to img and mask,
    so they stay pixel-aligned.

    :param img: ndarray with shape (x_sz, y_sz, num_channels)
    :param mask: ndarray with shape (x_sz, y_sz, num_classes), spatially aligned with img
    :param sz: side length of the square patch; must not exceed either spatial dimension
    :return: tuple (patch_img, patch_mask) with shapes (sz, sz, num_channels) and
        (sz, sz, num_classes)
    :raises ValueError: if img is not 3-D, is smaller than sz, or its spatial shape
        differs from mask's
    """
    if img.ndim != 3 or img.shape[0] < sz or img.shape[1] < sz or img.shape[0:2] != mask.shape[0:2]:
        raise ValueError(f'cannot cut a {sz}x{sz} patch from image {img.shape} with mask {mask.shape}')
    # randint is inclusive, so an image exactly sz wide has the single offset 0
    xc = random.randint(0, img.shape[0] - sz)
    yc = random.randint(0, img.shape[1] - sz)
    patch_img = img[xc:xc + sz, yc:yc + sz]
    patch_mask = mask[xc:xc + sz, yc:yc + sz]

    # Pick one of the 8 symmetries of the square uniformly: optionally flip rows,
    # then rotate by 0/90/180/270 degrees. (The previous version drew only 7 of them.)
    transform = np.random.randint(8)
    if transform >= 4:
        patch_img = patch_img[::-1]
        patch_mask = patch_mask[::-1]
    quarter_turns = transform % 4
    return np.rot90(patch_img, quarter_turns), np.rot90(patch_mask, quarter_turns)


def get_patches(x_dict, y_dict, n_patches, sz=160):
    """Sample n_patches random (image, mask) patch pairs.

    Each patch comes from an image id drawn uniformly (with replacement) from x_dict,
    cut and augmented by get_rand_patch().

    :param x_dict: mapping image id -> image array (x_sz, y_sz, num_channels)
    :param y_dict: mapping image id -> mask array with the same ids and spatial shapes
    :param n_patches: number of patches to generate
    :param sz: side length of every patch
    :return: tuple (x, y) of stacked arrays (n_patches, sz, sz, num_channels) and
        (n_patches, sz, sz, num_classes)
    :raises ValueError: if patches are requested from an empty x_dict
    """
    # random.sample() rejects dict views/sets on Python 3.11+, so draw from a list of ids.
    img_ids = list(x_dict)
    if n_patches > 0 and not img_ids:
        raise ValueError('x_dict is empty, cannot sample patches')
    x = []
    y = []
    for _ in range(n_patches):
        img_id = random.choice(img_ids)
        img_patch, mask_patch = get_rand_patch(x_dict[img_id], y_dict[img_id], sz)
        x.append(img_patch)
        y.append(mask_patch)
    print('Generated {} patches'.format(len(x)))
    return np.array(x), np.array(y)

# generator
def gen_patches(x_dict, y_dict, n_patches, sz=PATCH_SZ):
    """Yield an endless stream of (x, y) batches for model.fit.

    Every item is a fresh call to get_patches(), i.e. n_patches random patches of side sz.
    (Previously `sz` was ignored and PATCH_SZ was always used.)
    """
    while True:
        gc.collect()  # GC pass between batches, presumably to keep memory flat over long training
        yield get_patches(x_dict, y_dict, n_patches, sz=sz)

Create the weights directory and checkpoint path; if a checkpoint from a previous run exists, load it into `model`.

In [None]:
# Checkpoint location for the first U-Net; resume from it if a previous run saved one.
weights_dir = 'weights'
os.makedirs(weights_dir, exist_ok=True)  # no exists()/makedirs() race, no error if present
weights_path = weights_dir + '/unet_weights.hdf5'

if os.path.isfile(weights_path):
    model.load_weights(weights_path)

Train the network.

In [None]:
print("start training net")

# Fixed validation set, sampled once from the held-out rows of every image.
x_val, y_val = get_patches(X_DICT_VALIDATION, Y_DICT_VALIDATION, n_patches=VAL_SZ, sz=PATCH_SZ)

# Keep only the checkpoint with the lowest validation loss.
model_checkpoint = ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True)

# Each generator item is already one batch of BATCH_SIZE patches. Keras documents that
# `batch_size` should not be given for generators and that `shuffle` is ignored for them,
# so only the in-memory validation arrays get an explicit batch size.
train_gen = gen_patches(X_DICT_TRAIN, Y_DICT_TRAIN, n_patches=BATCH_SIZE, sz=PATCH_SZ)
model.fit(train_gen,
          steps_per_epoch=TRAIN_SZ // BATCH_SIZE,
          epochs=N_EPOCHS,
          validation_data=(x_val, y_val),
          validation_batch_size=BATCH_SIZE,
          callbacks=[model_checkpoint],
          verbose=2)

In [9]:
def _crops_conv_block(x, n_filters, droprate):
    """Two 3x3 ReLU convolutions (he_normal init, 'same' padding) with Dropout in between."""
    x = Conv2D(n_filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = Dropout(droprate)(x)
    return Conv2D(n_filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)


# Second model: trains for less time and recognizes more fields (crops).
def unet_crops(n_classes=N_CLASSES, im_sz=PATCH_SZ, n_channels=N_BANDS,
               n_filters_start=32, depth=4, growth_factor=2, upconv=UPCONV,
               class_weights=CLASS_WEIGHTS):
    """Build and compile the second, lighter U-Net (used as `model_crops`).

    Encoder level i uses 16 * growth_factor**i filters and ends in 2x2 max pooling;
    the bottleneck uses 16 * growth_factor**depth filters with dropout 0.3, every other
    conv block dropout 0.2. With the default arguments this is exactly the previous
    hard-coded 16-32-64-128-256 architecture with transposed-convolution upsampling.

    :param n_classes: number of output channels (one sigmoid map per class)
    :param im_sz: side length of the square input; must be divisible by 2**depth
        (or None for variable-size input)
    :param n_channels: number of input bands
    :param n_filters_start: NOT used - this model always starts at 16 filters; kept so the
        signature mirrors unet_model
    :param depth: number of pooling stages
    :param growth_factor: factor by which the filter count grows per level
    :param upconv: True to upsample with Conv2DTranspose, False to use UpSampling2D
        (previously ignored: transposed convolution was always used)
    :param class_weights: per-class weights of the loss, in output-channel order
    :return: compiled keras Model, (im_sz, im_sz, n_channels) -> (im_sz, im_sz, n_classes)
    :raises ValueError: if im_sz is not divisible by 2**depth
    """
    if im_sz is not None and im_sz % (2 ** depth):
        raise ValueError(f'im_sz must be divisible by {2 ** depth} ({depth} pooling stages), got {im_sz}')
    base_filters = 16  # fixed width of the first level for this model

    inputs = Input((im_sz, im_sz, n_channels))
    x = inputs

    # Encoder: conv block -> remember for the skip connection -> downsample.
    skips = []
    for level in range(depth):
        x = _crops_conv_block(x, base_filters * growth_factor ** level, 0.2)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck uses a slightly higher dropout rate.
    x = _crops_conv_block(x, base_filters * growth_factor ** depth, 0.3)

    # Decoder: upsample -> concatenate matching encoder output -> conv block.
    for level in reversed(range(depth)):
        n_filters = base_filters * growth_factor ** level
        if upconv:
            x = Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same')(x)
        else:
            x = UpSampling2D(size=(2, 2))(x)
        x = concatenate([x, skips[level]])
        x = _crops_conv_block(x, n_filters, 0.2)

    outputs = Conv2D(n_classes, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=outputs)

    # use weighted binary crossentropy as loss func: mean BCE per class over
    # axes 0-2 (batch, rows, columns), then a weighted sum over classes
    def weighted_binary_crossentropy(y_true, y_pred):
        class_loglosses = K.mean(K.binary_crossentropy(y_true, y_pred), axis=[0, 1, 2])
        return K.sum(class_loglosses * K.constant(class_weights))

    model.compile(optimizer=Adam(), loss=weighted_binary_crossentropy)
    return model

In [10]:
# Training schedule for the crops model. NOTE: these assignments overwrite the global
# constants from the configuration cell, so cells run afterwards see these values.
N_EPOCHS, BATCH_SIZE = 100, 128
TRAIN_SZ = BATCH_SIZE * 8   # samples per epoch
VAL_SZ = 1024               # validation patches

model_crops = unet_crops()

In [11]:
# Print the layer table of the crops model built by unet_crops().
model_crops.summary()

In [12]:
# show_shapes=True annotates every layer with its input/output tensor shape.
plot_model(model_crops, show_shapes=True)

In [None]:
print("start training net")

x_val, y_val = get_patches(X_DICT_VALIDATION, Y_DICT_VALIDATION, n_patches=VAL_SZ, sz=PATCH_SZ)

# Separate checkpoint file for this model. Reusing `weights_path` would overwrite the
# first U-Net's best weights with this differently shaped model, and the cell that loads
# `weights_path` into `model` would then fail on the next run.
crops_weights_path = 'weights/unet_crops_weights.hdf5'
os.makedirs(os.path.dirname(crops_weights_path), exist_ok=True)

callbacks = [
    ModelCheckpoint(crops_weights_path, monitor='val_loss', save_best_only=True),
    CSVLogger('log_unet.csv', append=True, separator=';'),
    # restore the best epoch's weights so the in-memory model used for prediction
    # below is not simply the last (possibly worse) epoch
    EarlyStopping(patience=5, verbose=1, restore_best_weights=True),
    ReduceLROnPlateau(patience=3, verbose=1),
]

# Generator items are complete batches: `batch_size`/`shuffle` do not apply to them,
# only the in-memory validation arrays get an explicit batch size.
train_gen = gen_patches(X_DICT_TRAIN, Y_DICT_TRAIN, n_patches=BATCH_SIZE, sz=PATCH_SZ)
model_crops.fit(train_gen,
                steps_per_epoch=TRAIN_SZ // BATCH_SIZE,
                epochs=N_EPOCHS,
                validation_data=(x_val, y_val),
                validation_batch_size=BATCH_SIZE,
                callbacks=callbacks,
                verbose=2)

## Prediction

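The following function `predict_again()` takes: 
- image *'x'* 
- trained model 
- patch size 
- number of classes

It returns the predicted probability of each class for every pixel of *'x'*, as an array of shape `(height, width, n_classes)` with the same spatial size as `x`. Internally, `x` is first mirror-padded to `extended_height × extended_width`: the smallest multiples of the patch size that contain the image, so that it splits into a whole number of patches. The model then predicts on overlapping patches under all 8 flips/rotations of the image. These predictions are combined and cropped back to the original size.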

In [13]:
def start_points(size, split_size, overlap=0):
    """Start offsets of windows of length split_size that together cover [0, size).

    Consecutive windows are int(split_size * (1 - overlap)) apart; the last window is
    aligned to the end (start = size - split_size), so it may overlap its neighbour more.

    :param size: length of the axis to cover
    :param split_size: window length
    :param overlap: fraction of split_size shared by neighbouring windows, in [0, 1)
    :return: strictly increasing list of start offsets; [0] when one window is enough
    :raises ValueError: if the resulting stride is not positive (e.g. overlap >= 1),
        which previously looped forever
    """
    if size <= split_size:
        # a single window starting at 0 covers the axis (no duplicate / negative start)
        return [0]
    stride = int(split_size * (1 - overlap))
    if stride <= 0:
        raise ValueError(f'overlap={overlap} gives a non-positive stride for split_size={split_size}')
    points = [0]
    pt = stride
    while pt + split_size < size:
        points.append(pt)
        pt += stride
    points.append(size - split_size)
    return points

def _predict_sliding(img, model, patch_sz, n_classes, batch_size):
    """Average model predictions over half-overlapping patch_sz windows covering img.

    img must be channels-last with both sides >= patch_sz. Every pixel's result is the
    plain mean over all windows that contain it (the old sequential 0.5 blending gave
    each window a weight that depended on the order in which windows were written).
    """
    height, width = img.shape[:2]  # taken from img itself, so rotated/non-square arrays work
    origins = [(r, c)
               for r in start_points(height, patch_sz, 0.5)
               for c in start_points(width, patch_sz, 0.5)]
    windows = np.asarray([img[r:r + patch_sz, c:c + patch_sz] for r, c in origins])
    window_preds = model.predict(windows, batch_size=batch_size)

    summed = np.zeros((height, width, n_classes), dtype=np.float32)
    counts = np.zeros((height, width, 1), dtype=np.float32)
    for (r, c), pred in zip(origins, window_preds):
        summed[r:r + patch_sz, c:c + patch_sz] += pred
        counts[r:r + patch_sz, c:c + patch_sz] += 1.0
    return summed / counts


def predict_again(x, model, patch_sz=PATCH_SZ, n_classes=N_CLASSES, batch_size=4):
    """Predict per-pixel class probabilities for a whole image with test-time augmentation.

    The image is mirror-padded at the bottom/right to a whole number of patch_sz patches.
    For each of the 8 flips/rotations of the padded image, the model is run on
    half-overlapping patch_sz windows (_predict_sliding). The 8 de-augmented predictions
    are averaged and cropped back to the input size.

    Fixes over the previous version: the patch column index used the vertical patch
    count, the overlap pass had the same row/column mix-up and assumed 2n-1 windows per
    axis, and rotated non-square images were tiled with the unrotated height/width.

    :param x: channels-last image, shape (height, width, n_channels)
    :param model: model mapping (N, patch_sz, patch_sz, n_channels) -> (N, patch_sz, patch_sz, n_classes)
    :param patch_sz: side length of the model's square input
    :param n_classes: number of output channels of the model
    :param batch_size: batch size passed to model.predict (4, as before)
    :return: float32 array of shape (height, width, n_classes)
    """
    img_height, img_width = x.shape[0], x.shape[1]
    extended_height = patch_sz * math.ceil(img_height / patch_sz)
    extended_width = patch_sz * math.ceil(img_width / patch_sz)
    # 'symmetric' mirrors including the edge pixel (row h is a copy of row h-1), and
    # keeps reflecting correctly even if the padding is larger than the image.
    ext_x = np.pad(x,
                   ((0, extended_height - img_height), (0, extended_width - img_width), (0, 0)),
                   mode='symmetric').astype(np.float32)

    predictions = []
    for flip in (1, -1):
        for rot in range(4):
            augmented = np.rot90(ext_x[::flip], k=rot, axes=(0, 1))
            pred = _predict_sliding(augmented, model, patch_sz, n_classes, batch_size)
            # undo the augmentation in reverse order: rotate back, then flip back
            predictions.append(np.rot90(pred, k=-rot, axes=(0, 1))[::flip])
    prediction = np.mean(predictions, axis=0)
    return prediction[:img_height, :img_width, :]

Note that `model.predict()` in the cell above is called with `batch_size`. Batching is used when the data to be predicted is too large to fit in memory at once; prediction is then done batch by batch.

Show image of the created mask. <br>
On this image use color codes of the first 5 colors for the 5 classes. <br>
Create a function that takes a channels-first mask (the output of *'predict_again()'* transposed to `(classes, height, width)`) and a threshold, and returns a channels-first RGB image that can be shown by *'imshow()'*. <br>

In function *'picture_from_mask()'* created below:
- Dictionary variable *'colors'* contains first 5 colors corresponding to the 5 classes of objects. Color of each class is defined as combination of 3 basic colors
- Dictionary *'z_order'* creates special order of classes in which the mask-image is created. If the same pixel has high enough probability of belonging to several classes then the pixel is marked as highest of them in *'z_order'*. Basically, this means that in the loop over *'z_order'* color of the next significant class replaces the color of the previous one.
- A class of a pixel is considered "significant" if probability of that class is greater than "threshold".

In [14]:
def picture_from_mask(mask, threshold=0):
    """Render a channels-first class-probability mask as a channels-first RGB image.

    :param mask: array of shape (n_classes, height, width) with at least the 5 classes below
    :param threshold: a class is painted on pixels whose probability exceeds this value
    :return: uint8 array of shape (3, height, width); white where no class passes
    """
    palette = {
        0: [150, 150, 150],  # Buildings
        1: [223, 194, 125],  # Roads & Tracks
        2: [27, 120, 55],    # Trees
        3: [166, 219, 160],  # Crops
        4: [116, 173, 209],  # Water
    }
    # Classes are painted in this order, so a later class wins on pixels where
    # several classes pass the threshold: crops, water, buildings, roads, trees.
    paint_order = (3, 4, 0, 1, 2)

    picture = np.full((3, mask.shape[1], mask.shape[2]), 255, dtype=np.uint8)
    for cls in paint_order:
        selected = mask[cls, :, :] > threshold
        for channel, value in enumerate(palette[cls]):
            picture[channel][selected] = value
    return picture

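Read the test image, normalize it, and predict with both models; the saved probability map and color map use the average of the two models' probabilities.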

In [34]:
test_id = 'test'
# channels-last image scaled to [-1, 1]
img = normalize(tiff.imread(f'{DATA_PATH_PREFIX}/mband/{test_id}.tif').transpose([1, 2, 0]))

# Per-model probabilities, transposed to channels-first: (classes, height, width).
my_mask = predict_again(img, model, patch_sz=PATCH_SZ, n_classes=N_CLASSES).transpose([2, 0, 1])
mask = predict_again(img, model_crops, patch_sz=PATCH_SZ, n_classes=N_CLASSES).transpose([2, 0, 1])

# Ensemble: plain average of both models, computed once and reused below.
ensemble_mask = (mask + my_mask) / 2
map_ = picture_from_mask(ensemble_mask, threshold=0.3)

# tifffile.imsave is a deprecated alias of imwrite.
tiff.imwrite('result.tif', (255 * ensemble_mask).astype('uint8'))
tiff.imwrite('map.tif', map_.astype('uint8'))
tiff.imshow(map_)