# Comparison of ensemble model predictions: satellite shoreline example

This notebook walks through loading several models each trained for the same task, making ensemble prediction with an image, and inspecting the outputs


by Daniel Buscombe, December 2022

## Contents

* Part 1: load libraries
* Part 2: select and load models
    * we use several models trained for the same task: finding the coastal shoreline
    * we repurpose a snippet of code from the Gym script `seg_images_in_folder.py` to load in and make each model and apply the model weights from h5 files
* Part 3: Application of each model for 2-classes
    * apply all models and aggregate softmax scores for an ensemble model prediction
    * Otsu (adaptive) thresholding versus normal thresholding
    * TTA versus no TTA (test-time augmentation)
* Part 4: Application of each model for 4-classes remapped to 2-classes
    * apply all models and aggregate softmax scores for an ensemble model prediction
    * Otsu (adaptive) thresholding versus normal thresholding
    * TTA versus no TTA (test-time augmentation)

## Part 1: load libraries

In [26]:
import sys,os, json
from tkinter import filedialog
from tkinter import *
import requests
from glob import glob
from tqdm import tqdm

from doodleverse_utils.prediction_imports import seg_file2tensor_3band, seg_file2tensor_ND

def download_url(url, save_path, chunk_size=128):
    r = requests.get(url, stream=True)
    with open(save_path, 'wb') as fd:
        for chunk in r.iter_content(chunk_size=chunk_size):
            fd.write(chunk)


## Part 2: select and load models

In [95]:
choice = 'ENSEMBLE'

#### choose zenodo release
models = ['sat_RGB_2class_7384255', 'sat_5band_2class_7388008', 
            'sat_RGB_4class_6950472', 'sat_5band_4class_7344606']


In [96]:
# Select directory of images (or npzs) to segment
sample_direc_3band = '/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_3band'

sample_direc_5band = '/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band'

In [97]:
for dataset_id in models:
    
    print("Dataset ID : {}".format(dataset_id))

    zenodo_id = dataset_id.split('_')[-1]
    print("Zenodo ID : {}".format(zenodo_id))
    
    try:
        os.mkdir('../downloaded_models')
    except:
        pass

    try:
        os.mkdir('../downloaded_models/'+dataset_id)
    except:
        pass

    model_direc = '../downloaded_models/'+dataset_id

    root_url = 'https://zenodo.org/api/records/'+zenodo_id

    r = requests.get(root_url)

    js = json.loads(r.text)
    files = js['files']
    
    
    # get list of all models
    all_models = [f for f in files if f['key'].endswith('.h5')]

    # download all weights
    for a in all_models:
        outfile = model_direc + os.sep + a['links']['self'].split('/')[-1]
        if not os.path.isfile(outfile):
            print("Downloading file to {}".format(outfile))
            download_url(a['links']['self'], outfile)

    # download all con fig
    for a in all_models:
        outfile = model_direc + os.sep + a['links']['self'].split('/')[-1]
        outfile = outfile.replace('_fullmodel.h5','.json')
        if not os.path.isfile(outfile):
            print("Downloading file to {}".format(outfile))
            download_url(a['links']['self'].replace('_fullmodel.h5','.json'), outfile)
    

Dataset ID : sat_RGB_2class_7384255
Zenodo ID : 7384255
Dataset ID : sat_5band_2class_7388008
Zenodo ID : 7388008
Dataset ID : sat_RGB_4class_6950472
Zenodo ID : 6950472
Dataset ID : sat_5band_4class_7344606
Zenodo ID : 7344606
Downloading file to ../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v1_fullmodel.h5
Downloading file to ../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v2_fullmodel.h5
Downloading file to ../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v3_fullmodel.h5
Downloading file to ../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v4_fullmodel.h5
Downloading file to ../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v1.json
Downloading file to ../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v2.json
Downloading file to ../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v3.json
Downloading file to ../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v4.json


## Part 3: Application of each model for 2-classes

In [64]:

# =========================================================
def do_seg_2class(
    f, M, metadatadict, sample_direc, 
    NCLASSES, N_DATA_BANDS, TARGET_SIZE, TESTTIMEAUG, WRITE_MODELMETADATA,
    OTSU_THRESHOLD
):

    if f.endswith("jpg"):
        segfile = f.replace(".jpg", "_predseg.png")
    elif f.endswith("png"):
        segfile = f.replace(".png", "_predseg.png")
    elif f.endswith("npz"):  # in f:
        segfile = f.replace(".npz", "_predseg.png")

    if WRITE_MODELMETADATA:
        metadatadict["input_file"] = f

    segfile = os.path.normpath(segfile)

    if OTSU_THRESHOLD:
        if TESTTIMEAUG:
            outfolder = "out_otsu_tta"
        else:
            outfolder = "out_otsu_notta"
    else:
        if TESTTIMEAUG:
            outfolder = "out_nootsu_tta"
        else:
            outfolder = "out_nootsu_notta"        
    
    print(outfolder)
    
    segfile = segfile.replace(
        os.path.normpath(sample_direc), os.path.normpath(sample_direc + os.sep + outfolder)
    )

    try:
        os.mkdir(os.path.normpath(sample_direc + os.sep + outfolder))
    except:
        pass

    if WRITE_MODELMETADATA:
        metadatadict["nclasses"] = NCLASSES
        metadatadict["n_data_bands"] = N_DATA_BANDS


    if N_DATA_BANDS <= 3:
        image, w, h, bigimage = seg_file2tensor_3band(f, TARGET_SIZE)
    else:
        image, w, h, bigimage = seg_file2tensor_ND(f, TARGET_SIZE)

    image = standardize(image.numpy()).squeeze()

    E0 = []
    E1 = []

    for counter, model in enumerate(M):

        try:
            est_label = model.predict(tf.expand_dims(image, 0), batch_size=8).squeeze()
        except:
            est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=8).squeeze()

        if TESTTIMEAUG == True:
            # return the flipped prediction
            est_label2 = np.flipud(
                model.predict(
                    tf.expand_dims(np.flipud(image), 0), batch_size=1
                ).squeeze()
            )
            est_label3 = np.fliplr(
                model.predict(
                    tf.expand_dims(np.fliplr(image), 0), batch_size=1
                ).squeeze()
            )
            est_label4 = np.flipud(
                np.fliplr(
                    model.predict(
                        tf.expand_dims(np.flipud(np.fliplr(image)), 0), batch_size=1
                    ).squeeze()
                )
            )

            # soft voting - sum the softmax scores to return the new TTA estimated softmax scores
            est_label = est_label + est_label2 + est_label3 + est_label4
            del est_label2, est_label3, est_label4

        E0.append(
            resize(est_label[:, :, 0], (w, h), preserve_range=True, clip=True)
        )
        E1.append(
            resize(est_label[:, :, 1], (w, h), preserve_range=True, clip=True)
        )
        del est_label

    K.clear_session()

    e0 = np.average(np.dstack(E0), axis=-1)  # , weights=np.array(MW))

    del E0

    e1 = np.average(np.dstack(E1), axis=-1)  # , weights=np.array(MW))
    del E1

    est_label = (e1 + (1 - e0)) / 2

    if WRITE_MODELMETADATA:
        metadatadict["av_prob_stack"] = est_label

    softmax_scores = np.dstack((e0,e1))
    del e0, e1

    if WRITE_MODELMETADATA:
        metadatadict["av_softmax_scores"] = softmax_scores

    if OTSU_THRESHOLD:
        thres = threshold_otsu(est_label)
        # print("Class threshold: %f" % (thres))
        est_label = (est_label > thres).astype("uint8")
        if WRITE_MODELMETADATA:
            metadatadict["otsu_threshold"] = thres

    else:
        # print("Not using Otsu threshold")
        est_label = (est_label > 0.5).astype("uint8")
        if WRITE_MODELMETADATA:
            metadatadict["otsu_threshold"] = 0.5            


    class_label_colormap = [
        "#3366CC",
        "#DC3912",
        "#FF9900",
        "#109618",
        "#990099",
        "#0099C6",
        "#DD4477",
        "#66AA00",
        "#B82E2E",
        "#316395",
    ]
    # add classes for more than 10 classes

    # if NCLASSES > 1:
    class_label_colormap = class_label_colormap[:NCLASSES]

    if WRITE_MODELMETADATA:
        metadatadict["color_segmentation_output"] = segfile

    try:
        color_label = label_to_colors(
            est_label,
            bigimage.numpy()[:, :, 0] == 0,
            alpha=128,
            colormap=class_label_colormap,
            color_class_offset=0,
            do_alpha=False,
        )
    except:
        try:
            color_label = label_to_colors(
                est_label,
                bigimage[:, :, 0] == 0,
                alpha=128,
                colormap=class_label_colormap,
                color_class_offset=0,
                do_alpha=False,
            )
        except:
            color_label = label_to_colors(
                est_label,
                bigimage == 0,
                alpha=128,
                colormap=class_label_colormap,
                color_class_offset=0,
                do_alpha=False,
            )        

    imsave(segfile, (color_label).astype(np.uint8), check_contrast=False)
    
    if WRITE_MODELMETADATA:
        metadatadict["color_segmentation_output"] = segfile

    segfile = segfile.replace("_predseg.png", "_res.npz")

    if WRITE_MODELMETADATA:
        metadatadict["grey_label"] = est_label

        np.savez_compressed(segfile, **metadatadict)

    segfile = segfile.replace("_res.npz", "_overlay.png")

    if N_DATA_BANDS <= 3:
        plt.imshow(bigimage, cmap='gray')
    else:
        plt.imshow(bigimage[:, :, :3])

    plt.imshow(color_label, alpha=0.5)
    plt.axis("off")
    # plt.show()
    plt.savefig(segfile, dpi=200, bbox_inches="tight")
    plt.close("all")

    #### image - overlay side by side
    segfile = segfile.replace("_res.npz", "_image_overlay.png")

    plt.subplot(121)
    if N_DATA_BANDS <= 3:
        plt.imshow(bigimage, cmap='gray')
    else:
        plt.imshow(bigimage[:, :, :3])
    plt.axis("off")

    plt.subplot(122)
    if N_DATA_BANDS <= 3:
        plt.imshow(bigimage, cmap='gray')
    else:
        plt.imshow(bigimage[:, :, :3])
    if NCLASSES>2:
        plt.imshow(color_label, alpha=0.5)
    elif NCLASSES==2:
        cs = plt.contour(est_label, [-99,0,99], colors='r')
    plt.axis("off")
    # plt.show()
    plt.savefig(segfile, dpi=200, bbox_inches="tight")
    plt.close("all")

    if NCLASSES==2:
        segfile = segfile.replace("_overlay.png", "_result.mat")
        p = cs.collections[0].get_paths()[0]
        v = p.vertices
        x = v[:,0]
        y = v[:,1]
        io.savemat(segfile, dict(x=x, y=y))

### 5-band, 2-class

In [75]:
# W : list containing all the weight files fill paths
W=[]
W.append('../downloaded_models/sat_5band_2class_7388008/sat2class_5d_512_v1_fullmodel.h5')
W.append('../downloaded_models/sat_5band_2class_7388008/sat2class_5d_512_v2_fullmodel.h5')
W.append('../downloaded_models/sat_5band_2class_7388008/sat2class_5d_512_v3_fullmodel.h5')
W.append('../downloaded_models/sat_5band_2class_7388008/sat2class_5d_512_v4_fullmodel.h5')
W.append('../downloaded_models/sat_5band_2class_7388008/sat2class_5d_512_v5_fullmodel.h5')

In [76]:
sample_direc = sample_direc_5band


# The following lines prepare the data to be predicted
sample_filenames = sorted(glob(sample_direc+os.sep+'*.*'))
if sample_filenames[0].split('.')[-1]=='npz':
    sample_filenames = sorted(tf.io.gfile.glob(sample_direc+os.sep+'*.npz'))
else:
    sample_filenames = sorted(tf.io.gfile.glob(sample_direc+os.sep+'*.jpg'))
    if len(sample_filenames)==0:
        sample_filenames = sorted(glob(sample_direc+os.sep+'*.png'))

print('Number of samples: %i' % (len(sample_filenames)))

Number of samples: 55


In [77]:

# For each set of weights in W load them in
M= []; C=[]; T = []
for counter,weights in enumerate(W):

    try:
        # "fullmodel" is for serving on zoo they are smaller and more portable between systems than traditional h5 files
        # gym makes a h5 file, then you use gym to make a "fullmodel" version then zoo can read "fullmodel" version
        configfile = weights.replace('_fullmodel.h5','.json').replace('weights', 'config')
        with open(configfile) as f:
            config = json.load(f)
    except:
        # Turn the .h5 file into a json so that the data can be loaded into dynamic variables        
        configfile = weights.replace('.h5','.json').replace('weights', 'config')
        with open(configfile) as f:
            config = json.load(f)
    # Dynamically creates all variables from config dict.
    # For example configs's {'TARGET_SIZE': [768, 768]} will be created as TARGET_SIZE=[768, 768]
    # This is how the program is able to use variables that have never been explicitly defined
    for k in config.keys():
        exec(k+'=config["'+k+'"]')


    if counter==0:
        #####################################
        #### hardware
        ####################################

        SET_GPU = str(SET_GPU)

        if SET_GPU != '-1':
            USE_GPU = True
            print('Using GPU')
        else:
            USE_GPU = False
            print('Using CPU')

        if len(SET_GPU.split(','))>1:
            USE_MULTI_GPU = True 
            print('Using multiple GPUs')
        else:
            USE_MULTI_GPU = False
            if USE_GPU:
                print('Using single GPU device')
            else:
                print('Using single CPU device')

        #suppress tensorflow warnings
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        if USE_GPU == True:
            os.environ['CUDA_VISIBLE_DEVICES'] = SET_GPU

            from doodleverse_utils.prediction_imports import *
            from tensorflow.python.client import device_lib
            physical_devices = tf.config.experimental.list_physical_devices('GPU')
            print(physical_devices)

            if physical_devices:
                # Restrict TensorFlow to only use the first GPU
                try:
                    tf.config.experimental.set_visible_devices(physical_devices, 'GPU')
                except RuntimeError as e:
                    # Visible devices must be set at program startup
                    print(e)
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

            from doodleverse_utils.prediction_imports import *
            from tensorflow.python.client import device_lib
            physical_devices = tf.config.experimental.list_physical_devices('GPU')
            print(physical_devices)

        ### mixed precision
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy('mixed_float16')
        # tf.debugging.set_log_device_placement(True)

        for i in physical_devices:
            tf.config.experimental.set_memory_growth(i, True)
        print(tf.config.get_visible_devices())

        if USE_MULTI_GPU:
            # Create a MirroredStrategy.
            strategy = tf.distribute.MirroredStrategy([p.name.split('/physical_device:')[-1] for p in physical_devices], cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
            print("Number of distributed devices: {}".format(strategy.num_replicas_in_sync))


    #from imports import *
    from doodleverse_utils.imports import *
    from doodleverse_utils.model_imports import *

    #---------------------------------------------------

    #=======================================================
    # Import the architectures for following models from doodleverse_utils
    # 1. custom_resunet
    # 2. custom_unet
    # 3. simple_resunet
    # 4. simple_unet
    # 5. satunet
    # 6. custom_resunet
    # 7. custom_satunet

    # Get the selected model based on the weights file's MODEL key provided
    # create the model with the data loaded in from the weights file
    print('.....................................')
    print('Creating and compiling model {}...'.format(counter))

    if MODEL =='resunet':
        model =  custom_resunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                        FILTERS,
                        nclasses=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                        kernel_size=(KERNEL,KERNEL),
                        strides=STRIDE,
                        dropout=DROPOUT,
                        dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                        dropout_type=DROPOUT_TYPE,
                        use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                        )
    elif MODEL=='unet':
        model =  custom_unet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                        FILTERS,
                        nclasses=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                        kernel_size=(KERNEL,KERNEL),
                        strides=STRIDE,
                        dropout=DROPOUT,
                        dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                        dropout_type=DROPOUT_TYPE,
                        use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                        )

    elif MODEL =='simple_resunet':

        model = simple_resunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    elif MODEL=='simple_unet':
        model = simple_unet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    elif MODEL=='satunet':

        model = custom_satunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    else:
        print("Model must be one of 'unet', 'resunet', or 'satunet'")
        sys.exit(2)

    try:
        # Load in the model from the weights which is the location of the weights file        
        model = tf.keras.models.load_model(weights)

        M.append(model)
        C.append(configfile)
        T.append(MODEL)
        
    except:
        # Load the metrics mean_iou, dice_coef from doodleverse_utils
        # Load in the custom loss function from doodleverse_utils        
        model.compile(optimizer = 'adam', loss = dice_coef_loss(NCLASSES))#, metrics = [iou_multi(NCLASSES), dice_multi(NCLASSES)])

        model.load_weights(weights)

        M.append(model)
        C.append(configfile)
        T.append(MODEL)

# metadatadict contains the model name (T) the config file(C) and the model weights(W)
metadatadict = {}
metadatadict['model_weights'] = W
metadatadict['config_files'] = C
metadatadict['model_types'] = T


Using GPU
Using single GPU device
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
.....................................
Creating and compiling model 0...
.....................................
Creating and compiling model 1...
.....................................
Creating and compiling model 2...
.....................................
Creating and compiling model 3...
.....................................
Creating and compiling model 4...


### No TTA, No OTsu

In [78]:
TESTTIMEAUG = False
OTSU_THRESHOLD = False

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_nootsu_notta


  cs = plt.contour(est_label, [-99,0,99], colors='r')


out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2004-10-31-15-23-44_L7_DUCK_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000163.npz failed. Check config file, and check the path provided contains valid imagery
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2013-06-07-18-24-34_L8_TORREYPINES_SDS_BENCHMARK_rgb_pan_noaug_nd_data_0000004.npz failed. Check config file, and check the path provided contains valid imagery
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta


out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta


### TTA, No Otsu

In [79]:
TESTTIMEAUG = True
OTSU_THRESHOLD = False

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_nootsu_tta


  cs = plt.contour(est_label, [-99,0,99], colors='r')


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2004-10-31-15-23-44_L7_DUCK_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000163.npz failed. Check config file, and check the path provided contains valid imagery
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2013-06-07-18-24-34_L8_TORREYPINES_SDS_BENCHMARK_rgb_pan_noaug_nd_data_0000004.npz failed. Check config file, and check the path provided contains valid imagery
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2020-12-04-10-48-33_L8_TRUCVERT_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000240.npz failed. Check config file, and check the path provided contains valid imagery
out_nootsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2020-12-10-23-44-04_L8_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000301.npz failed. Check config file, and check the path provided contains valid imagery
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2021-07-21-22-52-11_L7_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000447.npz failed. Check config file, and check the path provided contains valid imagery
out_nootsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2021-09-01-18-22-46_L8_TORREYPINES_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000188.npz failed. Check config file, and check the path provided contains valid imagery
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


### No TTA, Otsu

In [80]:
TESTTIMEAUG = False
OTSU_THRESHOLD = True

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_otsu_notta


  cs = plt.contour(est_label, [-99,0,99], colors='r')


out_otsu_notta
out_otsu_notta
out_otsu_notta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2000-04-15-23-36-00_L7_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_00000012.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2013-03-20-15-41-59_L8_DUCK_SDS_BENCHMARK_rgb_pan_noaug_nd_data_0000000.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2013-04-30-15-42-50_L8_DUCK_SDS_BENCHMARK_rgb_pan_noaug_nd_data_0000002.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_notta
out_otsu_notta
out_otsu_notta
/med

out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2015-11-27-23-44-16_L8_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_00000091.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2020-12-04-10-48-33_L8_TRUCVERT_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000240.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_notta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2020-12-10-23-44-04_L8_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000301.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
/media/marda/TWOTB1/U

### TTA, Otsu

In [81]:
TESTTIMEAUG = True
OTSU_THRESHOLD = True

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_otsu_tta


  cs = plt.contour(est_label, [-99,0,99], colors='r')


out_otsu_tta
out_otsu_tta
out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2000-04-15-23-36-00_L7_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_00000012.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2013-03-20-15-41-59_L8_DUCK_SDS_BENCHMARK_rgb_pan_noaug_nd_data_0000000.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2013-04-30-15-42-50_L8_DUCK_SDS_BENCHMARK_rgb_pan_noaug_nd_data_0000002.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta
out_otsu_tta
out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2013-06-07-18-24-34_L8_TORREYPINES_SDS_BENCHMARK_rgb_pan_noaug_nd_data_0000004.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2015-11-27-23-44-16_L8_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_00000091.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2020-12-04-10-48-33_L8_TRUCVERT_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000240.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2020-12-10-23-44-04_L8_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000301.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2021-07-21-22-52-11_L7_NARRABEEN_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000447.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta
/media/marda/TWOTB1/USGS/Doodleverse/github/sample_data/sat_5band/2021-09-01-18-22-46_L8_TORREYPINES_SDS_BENCHMARK_rgb_pan_noaug_nd_data_000000188.npz failed. Check config file, and check the path provided contains valid imagery
out_otsu_tta
out_otsu_tta
out_otsu_tta




### RGB / 2-class

In [47]:
# W : list containing all the weight files fill paths
W=[]
W.append('../downloaded_models/sat_RGB_2class_7384255/sat2class_rgb_512_v3_fullmodel.h5')
W.append('../downloaded_models/sat_RGB_2class_7384255/sat2class_rgb_512_v4_fullmodel.h5')
W.append('../downloaded_models/sat_RGB_2class_7384255/sat2class_rgb_512_v5_fullmodel.h5')
W.append('../downloaded_models/sat_RGB_2class_7384255/sat2class_rgb_512_v6_fullmodel.h5')
W.append('../downloaded_models/sat_RGB_2class_7384255/sat2class_rgb_512_v7_fullmodel.h5')

In [63]:
sample_direc = sample_direc_3band


# The following lines prepare the data to be predicted
sample_filenames = sorted(glob(sample_direc+os.sep+'*.*'))
if sample_filenames[0].split('.')[-1]=='npz':
    sample_filenames = sorted(tf.io.gfile.glob(sample_direc+os.sep+'*.npz'))
else:
    sample_filenames = sorted(tf.io.gfile.glob(sample_direc+os.sep+'*.jpg'))
    if len(sample_filenames)==0:
        sample_filenames = sorted(glob(sample_direc+os.sep+'*.png'))

print('Number of samples: %i' % (len(sample_filenames)))

Number of samples: 44


In [49]:

# For each set of weights in W load them in
M= []; C=[]; T = []
for counter,weights in enumerate(W):

    try:
        # "fullmodel" is for serving on zoo they are smaller and more portable between systems than traditional h5 files
        # gym makes a h5 file, then you use gym to make a "fullmodel" version then zoo can read "fullmodel" version
        configfile = weights.replace('_fullmodel.h5','.json').replace('weights', 'config')
        with open(configfile) as f:
            config = json.load(f)
    except:
        # Turn the .h5 file into a json so that the data can be loaded into dynamic variables        
        configfile = weights.replace('.h5','.json').replace('weights', 'config')
        with open(configfile) as f:
            config = json.load(f)
    # Dynamically creates all variables from config dict.
    # For example configs's {'TARGET_SIZE': [768, 768]} will be created as TARGET_SIZE=[768, 768]
    # This is how the program is able to use variables that have never been explicitly defined
    for k in config.keys():
        exec(k+'=config["'+k+'"]')


    if counter==0:
        #####################################
        #### hardware
        ####################################

        SET_GPU = str(SET_GPU)

        if SET_GPU != '-1':
            USE_GPU = True
            print('Using GPU')
        else:
            USE_GPU = False
            print('Using CPU')

        if len(SET_GPU.split(','))>1:
            USE_MULTI_GPU = True 
            print('Using multiple GPUs')
        else:
            USE_MULTI_GPU = False
            if USE_GPU:
                print('Using single GPU device')
            else:
                print('Using single CPU device')

        #suppress tensorflow warnings
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        if USE_GPU == True:
            os.environ['CUDA_VISIBLE_DEVICES'] = SET_GPU

            from doodleverse_utils.prediction_imports import *
            from tensorflow.python.client import device_lib
            physical_devices = tf.config.experimental.list_physical_devices('GPU')
            print(physical_devices)

            if physical_devices:
                # Restrict TensorFlow to only use the first GPU
                try:
                    tf.config.experimental.set_visible_devices(physical_devices, 'GPU')
                except RuntimeError as e:
                    # Visible devices must be set at program startup
                    print(e)
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

            from doodleverse_utils.prediction_imports import *
            from tensorflow.python.client import device_lib
            physical_devices = tf.config.experimental.list_physical_devices('GPU')
            print(physical_devices)

        ### mixed precision
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy('mixed_float16')
        # tf.debugging.set_log_device_placement(True)

        for i in physical_devices:
            tf.config.experimental.set_memory_growth(i, True)
        print(tf.config.get_visible_devices())

        if USE_MULTI_GPU:
            # Create a MirroredStrategy.
            strategy = tf.distribute.MirroredStrategy([p.name.split('/physical_device:')[-1] for p in physical_devices], cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
            print("Number of distributed devices: {}".format(strategy.num_replicas_in_sync))


    #from imports import *
    from doodleverse_utils.imports import *
    from doodleverse_utils.model_imports import *

    #---------------------------------------------------

    #=======================================================
    # Import the architectures for following models from doodleverse_utils
    # 1. custom_resunet
    # 2. custom_unet
    # 3. simple_resunet
    # 4. simple_unet
    # 5. satunet
    # 6. custom_resunet
    # 7. custom_satunet

    # Get the selected model based on the weights file's MODEL key provided
    # create the model with the data loaded in from the weights file
    print('.....................................')
    print('Creating and compiling model {}...'.format(counter))

    if MODEL =='resunet':
        model =  custom_resunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                        FILTERS,
                        nclasses=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                        kernel_size=(KERNEL,KERNEL),
                        strides=STRIDE,
                        dropout=DROPOUT,
                        dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                        dropout_type=DROPOUT_TYPE,
                        use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                        )
    elif MODEL=='unet':
        model =  custom_unet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                        FILTERS,
                        nclasses=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                        kernel_size=(KERNEL,KERNEL),
                        strides=STRIDE,
                        dropout=DROPOUT,
                        dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                        dropout_type=DROPOUT_TYPE,
                        use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                        )

    elif MODEL =='simple_resunet':

        model = simple_resunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    elif MODEL=='simple_unet':
        model = simple_unet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    elif MODEL=='satunet':

        model = custom_satunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    else:
        print("Model must be one of 'unet', 'resunet', or 'satunet'")
        sys.exit(2)

    try:
        # Load in the model from the weights which is the location of the weights file        
        model = tf.keras.models.load_model(weights)

        M.append(model)
        C.append(configfile)
        T.append(MODEL)
        
    except:
        # Load the metrics mean_iou, dice_coef from doodleverse_utils
        # Load in the custom loss function from doodleverse_utils        
        model.compile(optimizer = 'adam', loss = dice_coef_loss(NCLASSES))#, metrics = [iou_multi(NCLASSES), dice_multi(NCLASSES)])

        model.load_weights(weights)

        M.append(model)
        C.append(configfile)
        T.append(MODEL)

# metadatadict contains the model name (T) the config file(C) and the model weights(W)
metadatadict = {}
metadatadict['model_weights'] = W
metadatadict['config_files'] = C
metadatadict['model_types'] = T


Using GPU
Using single GPU device
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
.....................................
Creating and compiling model 0...
.....................................
Creating and compiling model 1...
.....................................
Creating and compiling model 2...
.....................................
Creating and compiling model 3...
.....................................
Creating and compiling model 4...


### no TTA, no Otsu

In [65]:
TESTTIMEAUG = False
OTSU_THRESHOLD = False

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_nootsu_notta


  cs = plt.contour(est_label, [-99,0,99], colors='r')


out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta


out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta


### TTA, no Otsu

In [66]:
TESTTIMEAUG = True
OTSU_THRESHOLD = False

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_nootsu_tta


  cs = plt.contour(est_label, [-99,0,99], colors='r')


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


### no TTA, Otsu

In [67]:
TESTTIMEAUG = False
OTSU_THRESHOLD = True

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_otsu_notta


  cs = plt.contour(est_label, [-99,0,99], colors='r')


out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta


out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta


### TTA, Otsu

In [68]:
TESTTIMEAUG = True
OTSU_THRESHOLD = True

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_otsu_tta


  cs = plt.contour(est_label, [-99,0,99], colors='r')


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


## Part 4: Application of each model for 4-classes remapped to 2-classes

In [90]:

# =========================================================
def do_seg_4class(
    f, M, metadatadict, sample_direc, 
    NCLASSES, N_DATA_BANDS, TARGET_SIZE, TESTTIMEAUG, WRITE_MODELMETADATA,
    OTSU_THRESHOLD
):

    if f.endswith("jpg"):
        segfile = f.replace(".jpg", "_predseg.png")
    elif f.endswith("png"):
        segfile = f.replace(".png", "_predseg.png")
    elif f.endswith("npz"):  # in f:
        segfile = f.replace(".npz", "_predseg.png")

    if WRITE_MODELMETADATA:
        metadatadict["input_file"] = f

    segfile = os.path.normpath(segfile)

    if OTSU_THRESHOLD:
        if TESTTIMEAUG:
            outfolder = "out_otsu_tta"
        else:
            outfolder = "out_otsu_notta"
    else:
        if TESTTIMEAUG:
            outfolder = "out_nootsu_tta"
        else:
            outfolder = "out_nootsu_notta"        
    
    print(outfolder)
    
    segfile = segfile.replace(
        os.path.normpath(sample_direc), os.path.normpath(sample_direc + os.sep + outfolder)
    )

    try:
        os.mkdir(os.path.normpath(sample_direc + os.sep + outfolder))
    except:
        pass

    if WRITE_MODELMETADATA:
        metadatadict["nclasses"] = NCLASSES
        metadatadict["n_data_bands"] = N_DATA_BANDS


    if N_DATA_BANDS <= 3:
        image, w, h, bigimage = seg_file2tensor_3band(
            f, TARGET_SIZE
        )  # , resize=True)
        w = w.numpy()
        h = h.numpy()
    else:
        image, w, h, bigimage = seg_file2tensor_ND(f, TARGET_SIZE)

    image = standardize(image.numpy())
    
    # return the base prediction
    if N_DATA_BANDS == 1:
        image = image[:, :, 0]
        bigimage = np.dstack((bigimage, bigimage, bigimage))

    est_label = np.zeros((TARGET_SIZE[0], TARGET_SIZE[1], NCLASSES))
    for counter, model in enumerate(M):
        # heatmap = make_gradcam_heatmap(tf.expand_dims(image, 0) , model)

        est_label = model.predict(tf.expand_dims(image, 0), batch_size=1).squeeze()

        if TESTTIMEAUG == True:
            # return the flipped prediction
            est_label2 = np.flipud(
                model.predict(
                    tf.expand_dims(np.flipud(image), 0), batch_size=1
                ).squeeze()
            )
            est_label3 = np.fliplr(
                model.predict(
                    tf.expand_dims(np.fliplr(image), 0), batch_size=1
                ).squeeze()
            )
            est_label4 = np.flipud(
                np.fliplr(
                    model.predict(
                        tf.expand_dims(np.flipud(np.fliplr(image)), 0), batch_size=1
                    ).squeeze()
                )
            )

            # soft voting - sum the softmax scores to return the new TTA estimated softmax scores
            est_label = est_label + est_label2 + est_label3 + est_label4
            del est_label2, est_label3, est_label4

        K.clear_session()

    est_label /= counter + 1
    est_label = resize(est_label, (w, h))
    if WRITE_MODELMETADATA:
        metadatadict["av_prob_stack"] = est_label

    softmax_scores = est_label.copy() #np.dstack((e0,e1))

    if WRITE_MODELMETADATA:
        metadatadict["av_softmax_scores"] = softmax_scores

    est_label = np.argmax(softmax_scores, -1)


    class_label_colormap = [
        "#3366CC",
        "#DC3912",
        "#FF9900",
        "#109618",
        "#990099",
        "#0099C6",
        "#DD4477",
        "#66AA00",
        "#B82E2E",
        "#316395",
    ]
    # add classes for more than 10 classes

    # if NCLASSES > 1:
    class_label_colormap = class_label_colormap[:NCLASSES]

    if WRITE_MODELMETADATA:
        metadatadict["color_segmentation_output"] = segfile

    try:
        color_label = label_to_colors(
            est_label,
            bigimage.numpy()[:, :, 0] == 0,
            alpha=128,
            colormap=class_label_colormap,
            color_class_offset=0,
            do_alpha=False,
        )
    except:
        try:
            color_label = label_to_colors(
                est_label,
                bigimage[:, :, 0] == 0,
                alpha=128,
                colormap=class_label_colormap,
                color_class_offset=0,
                do_alpha=False,
            )
        except:
            color_label = label_to_colors(
                est_label,
                bigimage == 0,
                alpha=128,
                colormap=class_label_colormap,
                color_class_offset=0,
                do_alpha=False,
            )        

    imsave(segfile, (color_label).astype(np.uint8), check_contrast=False)
    
    if WRITE_MODELMETADATA:
        metadatadict["color_segmentation_output"] = segfile

    segfile = segfile.replace("_predseg.png", "_res.npz")

    if WRITE_MODELMETADATA:
        metadatadict["grey_label"] = est_label

        np.savez_compressed(segfile, **metadatadict)

    segfile = segfile.replace("_res.npz", "_overlay.png")

    if N_DATA_BANDS <= 3:
        plt.imshow(bigimage, cmap='gray')
    else:
        plt.imshow(bigimage[:, :, :3])

    plt.imshow(color_label, alpha=0.5)
    plt.axis("off")
    # plt.show()
    plt.savefig(segfile, dpi=200, bbox_inches="tight")
    plt.close("all")

    #### image - overlay side by side
    segfile = segfile.replace("_res.npz", "_image_overlay.png")

    if N_DATA_BANDS <= 3:
        plt.imshow(bigimage, cmap='gray')
    else:
        plt.imshow(bigimage[:, :, :3])
    plt.imshow(color_label, alpha=0.5)

    plt.axis("off")
    # plt.show()
    plt.savefig(segfile, dpi=200, bbox_inches="tight")
    plt.close("all")


### RGB / 4-class

In [86]:
# W : list containing all the weight files fill paths
W=[]
W.append('../downloaded_models/sat_RGB_4class_6950472/sat4class_rgb_512_v1_fullmodel.h5')
W.append('../downloaded_models/sat_RGB_4class_6950472/sat4class_rgb_512_v2_fullmodel.h5')
W.append('../downloaded_models/sat_RGB_4class_6950472/sat4class_rgb_512_v3_fullmodel.h5')
W.append('../downloaded_models/sat_RGB_4class_6950472/sat4class_rgb_512_v4_fullmodel.h5')

In [87]:

# For each set of weights in W load them in
M= []; C=[]; T = []
for counter,weights in enumerate(W):

    try:
        # "fullmodel" is for serving on zoo they are smaller and more portable between systems than traditional h5 files
        # gym makes a h5 file, then you use gym to make a "fullmodel" version then zoo can read "fullmodel" version
        configfile = weights.replace('_fullmodel.h5','.json').replace('weights', 'config')
        with open(configfile) as f:
            config = json.load(f)
    except:
        # Turn the .h5 file into a json so that the data can be loaded into dynamic variables        
        configfile = weights.replace('.h5','.json').replace('weights', 'config')
        with open(configfile) as f:
            config = json.load(f)
    # Dynamically creates all variables from config dict.
    # For example configs's {'TARGET_SIZE': [768, 768]} will be created as TARGET_SIZE=[768, 768]
    # This is how the program is able to use variables that have never been explicitly defined
    for k in config.keys():
        exec(k+'=config["'+k+'"]')


    if counter==0:
        #####################################
        #### hardware
        ####################################

        SET_GPU = str(SET_GPU)

        if SET_GPU != '-1':
            USE_GPU = True
            print('Using GPU')
        else:
            USE_GPU = False
            print('Using CPU')

        if len(SET_GPU.split(','))>1:
            USE_MULTI_GPU = True 
            print('Using multiple GPUs')
        else:
            USE_MULTI_GPU = False
            if USE_GPU:
                print('Using single GPU device')
            else:
                print('Using single CPU device')

        #suppress tensorflow warnings
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        if USE_GPU == True:
            os.environ['CUDA_VISIBLE_DEVICES'] = SET_GPU

            from doodleverse_utils.prediction_imports import *
            from tensorflow.python.client import device_lib
            physical_devices = tf.config.experimental.list_physical_devices('GPU')
            print(physical_devices)

            if physical_devices:
                # Restrict TensorFlow to only use the first GPU
                try:
                    tf.config.experimental.set_visible_devices(physical_devices, 'GPU')
                except RuntimeError as e:
                    # Visible devices must be set at program startup
                    print(e)
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

            from doodleverse_utils.prediction_imports import *
            from tensorflow.python.client import device_lib
            physical_devices = tf.config.experimental.list_physical_devices('GPU')
            print(physical_devices)

        ### mixed precision
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy('mixed_float16')
        # tf.debugging.set_log_device_placement(True)

        for i in physical_devices:
            tf.config.experimental.set_memory_growth(i, True)
        print(tf.config.get_visible_devices())

        if USE_MULTI_GPU:
            # Create a MirroredStrategy.
            strategy = tf.distribute.MirroredStrategy([p.name.split('/physical_device:')[-1] for p in physical_devices], cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
            print("Number of distributed devices: {}".format(strategy.num_replicas_in_sync))


    #from imports import *
    from doodleverse_utils.imports import *
    from doodleverse_utils.model_imports import *

    #---------------------------------------------------

    #=======================================================
    # Import the architectures for following models from doodleverse_utils
    # 1. custom_resunet
    # 2. custom_unet
    # 3. simple_resunet
    # 4. simple_unet
    # 5. satunet
    # 6. custom_resunet
    # 7. custom_satunet

    # Get the selected model based on the weights file's MODEL key provided
    # create the model with the data loaded in from the weights file
    print('.....................................')
    print('Creating and compiling model {}...'.format(counter))

    if MODEL =='resunet':
        model =  custom_resunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                        FILTERS,
                        nclasses=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                        kernel_size=(KERNEL,KERNEL),
                        strides=STRIDE,
                        dropout=DROPOUT,
                        dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                        dropout_type=DROPOUT_TYPE,
                        use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                        )
    elif MODEL=='unet':
        model =  custom_unet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                        FILTERS,
                        nclasses=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                        kernel_size=(KERNEL,KERNEL),
                        strides=STRIDE,
                        dropout=DROPOUT,
                        dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                        dropout_type=DROPOUT_TYPE,
                        use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                        )

    elif MODEL =='simple_resunet':

        model = simple_resunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    elif MODEL=='simple_unet':
        model = simple_unet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    elif MODEL=='satunet':

        model = custom_satunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    else:
        print("Model must be one of 'unet', 'resunet', or 'satunet'")
        sys.exit(2)

    try:
        # Load in the model from the weights which is the location of the weights file        
        model = tf.keras.models.load_model(weights)

        M.append(model)
        C.append(configfile)
        T.append(MODEL)
        
    except:
        # Load the metrics mean_iou, dice_coef from doodleverse_utils
        # Load in the custom loss function from doodleverse_utils        
        model.compile(optimizer = 'adam', loss = dice_coef_loss(NCLASSES))#, metrics = [iou_multi(NCLASSES), dice_multi(NCLASSES)])

        model.load_weights(weights)

        M.append(model)
        C.append(configfile)
        T.append(MODEL)

# metadatadict contains the model name (T) the config file(C) and the model weights(W)
metadatadict = {}
metadatadict['model_weights'] = W
metadatadict['config_files'] = C
metadatadict['model_types'] = T


Using GPU
Using single GPU device
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
.....................................
Creating and compiling model 0...
.....................................
Creating and compiling model 1...
.....................................
Creating and compiling model 2...
.....................................
Creating and compiling model 3...


In [88]:
sample_direc = sample_direc_3band


# The following lines prepare the data to be predicted
sample_filenames = sorted(glob(sample_direc+os.sep+'*.*'))
if sample_filenames[0].split('.')[-1]=='npz':
    sample_filenames = sorted(tf.io.gfile.glob(sample_direc+os.sep+'*.npz'))
else:
    sample_filenames = sorted(tf.io.gfile.glob(sample_direc+os.sep+'*.jpg'))
    if len(sample_filenames)==0:
        sample_filenames = sorted(glob(sample_direc+os.sep+'*.png'))

print('Number of samples: %i' % (len(sample_filenames)))

Number of samples: 44


### no TTA, no Otsu

In [91]:
TESTTIMEAUG = False
OTSU_THRESHOLD = False

for f in sample_filenames:
    try:
        do_seg_4class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta


out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta


### TTA, no Otsu

In [92]:
TESTTIMEAUG = True
OTSU_THRESHOLD = False

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


### no TTA, Otsu

In [93]:
TESTTIMEAUG = False
OTSU_THRESHOLD = True

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta


out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta


### TTA, Otsu

In [94]:
TESTTIMEAUG = True
OTSU_THRESHOLD = True

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


### 5-band / 4-class

In [101]:
# W : list containing all the weight files fill paths
W=[]
W.append('../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v1_fullmodel.h5')
W.append('../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v2_fullmodel.h5')
W.append('../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v3_fullmodel.h5')
W.append('../downloaded_models/sat_5band_4class_7344606/sat4class_5d_512_v4_fullmodel.h5')

In [102]:
sample_direc = sample_direc_5band


# The following lines prepare the data to be predicted
sample_filenames = sorted(glob(sample_direc+os.sep+'*.*'))
if sample_filenames[0].split('.')[-1]=='npz':
    sample_filenames = sorted(tf.io.gfile.glob(sample_direc+os.sep+'*.npz'))
else:
    sample_filenames = sorted(tf.io.gfile.glob(sample_direc+os.sep+'*.jpg'))
    if len(sample_filenames)==0:
        sample_filenames = sorted(glob(sample_direc+os.sep+'*.png'))

print('Number of samples: %i' % (len(sample_filenames)))

Number of samples: 46


In [104]:

# For each set of weights in W load them in
M= []; C=[]; T = []
for counter,weights in enumerate(W):

    try:
        # "fullmodel" is for serving on zoo they are smaller and more portable between systems than traditional h5 files
        # gym makes a h5 file, then you use gym to make a "fullmodel" version then zoo can read "fullmodel" version
        configfile = weights.replace('_fullmodel.h5','.json').replace('weights', 'config')
        with open(configfile) as f:
            config = json.load(f)
    except:
        # Turn the .h5 file into a json so that the data can be loaded into dynamic variables        
        configfile = weights.replace('.h5','.json').replace('weights', 'config')
        with open(configfile) as f:
            config = json.load(f)
    # Dynamically creates all variables from config dict.
    # For example configs's {'TARGET_SIZE': [768, 768]} will be created as TARGET_SIZE=[768, 768]
    # This is how the program is able to use variables that have never been explicitly defined
    for k in config.keys():
        exec(k+'=config["'+k+'"]')


    if counter==0:
        #####################################
        #### hardware
        ####################################

        SET_GPU = str(SET_GPU)

        if SET_GPU != '-1':
            USE_GPU = True
            print('Using GPU')
        else:
            USE_GPU = False
            print('Using CPU')

        if len(SET_GPU.split(','))>1:
            USE_MULTI_GPU = True 
            print('Using multiple GPUs')
        else:
            USE_MULTI_GPU = False
            if USE_GPU:
                print('Using single GPU device')
            else:
                print('Using single CPU device')

        #suppress tensorflow warnings
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        if USE_GPU == True:
            os.environ['CUDA_VISIBLE_DEVICES'] = SET_GPU

            from doodleverse_utils.prediction_imports import *
            from tensorflow.python.client import device_lib
            physical_devices = tf.config.experimental.list_physical_devices('GPU')
            print(physical_devices)

            if physical_devices:
                # Restrict TensorFlow to only use the first GPU
                try:
                    tf.config.experimental.set_visible_devices(physical_devices, 'GPU')
                except RuntimeError as e:
                    # Visible devices must be set at program startup
                    print(e)
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

            from doodleverse_utils.prediction_imports import *
            from tensorflow.python.client import device_lib
            physical_devices = tf.config.experimental.list_physical_devices('GPU')
            print(physical_devices)

        ### mixed precision
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy('mixed_float16')
        # tf.debugging.set_log_device_placement(True)

        for i in physical_devices:
            tf.config.experimental.set_memory_growth(i, True)
        print(tf.config.get_visible_devices())

        if USE_MULTI_GPU:
            # Create a MirroredStrategy.
            strategy = tf.distribute.MirroredStrategy([p.name.split('/physical_device:')[-1] for p in physical_devices], cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
            print("Number of distributed devices: {}".format(strategy.num_replicas_in_sync))


    #from imports import *
    from doodleverse_utils.imports import *
    from doodleverse_utils.model_imports import *

    #---------------------------------------------------

    #=======================================================
    # Import the architectures for following models from doodleverse_utils
    # 1. custom_resunet
    # 2. custom_unet
    # 3. simple_resunet
    # 4. simple_unet
    # 5. satunet
    # 6. custom_resunet
    # 7. custom_satunet

    # Get the selected model based on the weights file's MODEL key provided
    # create the model with the data loaded in from the weights file
    print('.....................................')
    print('Creating and compiling model {}...'.format(counter))

    if MODEL =='resunet':
        model =  custom_resunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                        FILTERS,
                        nclasses=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                        kernel_size=(KERNEL,KERNEL),
                        strides=STRIDE,
                        dropout=DROPOUT,
                        dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                        dropout_type=DROPOUT_TYPE,
                        use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                        )
    elif MODEL=='unet':
        model =  custom_unet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                        FILTERS,
                        nclasses=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                        kernel_size=(KERNEL,KERNEL),
                        strides=STRIDE,
                        dropout=DROPOUT,
                        dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                        dropout_type=DROPOUT_TYPE,
                        use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                        )

    elif MODEL =='simple_resunet':

        model = simple_resunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    elif MODEL=='simple_unet':
        model = simple_unet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    elif MODEL=='satunet':

        model = custom_satunet((TARGET_SIZE[0], TARGET_SIZE[1], N_DATA_BANDS),
                    kernel = (2, 2),
                    num_classes=NCLASSES, #[NCLASSES+1 if NCLASSES==1 else NCLASSES][0],
                    activation="relu",
                    use_batch_norm=True,
                    dropout=DROPOUT,
                    dropout_change_per_layer=DROPOUT_CHANGE_PER_LAYER,
                    dropout_type=DROPOUT_TYPE,
                    use_dropout_on_upsampling=USE_DROPOUT_ON_UPSAMPLING,
                    filters=FILTERS,
                    num_layers=4,
                    strides=(1,1))

    else:
        print("Model must be one of 'unet', 'resunet', or 'satunet'")
        sys.exit(2)

    try:
        # Load in the model from the weights which is the location of the weights file        
        model = tf.keras.models.load_model(weights)

        M.append(model)
        C.append(configfile)
        T.append(MODEL)
        
    except:
        # Load the metrics mean_iou, dice_coef from doodleverse_utils
        # Load in the custom loss function from doodleverse_utils        
        model.compile(optimizer = 'adam', loss = dice_coef_loss(NCLASSES))#, metrics = [iou_multi(NCLASSES), dice_multi(NCLASSES)])

        model.load_weights(weights)

        M.append(model)
        C.append(configfile)
        T.append(MODEL)

# metadatadict contains the model name (T) the config file(C) and the model weights(W)
metadatadict = {}
metadatadict['model_weights'] = W
metadatadict['config_files'] = C
metadatadict['model_types'] = T


Using GPU
Using single GPU device
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
.....................................
Creating and compiling model 0...
.....................................
Creating and compiling model 1...
.....................................
Creating and compiling model 2...
.....................................
Creating and compiling model 3...


### no TTA, no Otsu

In [105]:
TESTTIMEAUG = False
OTSU_THRESHOLD = False

for f in sample_filenames:
    try:
        do_seg_4class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta


out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta
out_nootsu_notta


### TTA, no Otsu

In [106]:
TESTTIMEAUG = True
OTSU_THRESHOLD = False

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta
out_nootsu_tta


### no TTA, Otsu

In [107]:
TESTTIMEAUG = False
OTSU_THRESHOLD = True

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta


out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta
out_otsu_notta


### TTA, Otsu

In [108]:
TESTTIMEAUG = True
OTSU_THRESHOLD = True

for f in sample_filenames:
    try:
        do_seg_2class(f, M, metadatadict, sample_direc,
               NCLASSES,N_DATA_BANDS,TARGET_SIZE,TESTTIMEAUG, WRITE_MODELMETADATA,OTSU_THRESHOLD)
    except:
        print("{} failed. Check config file, and check the path provided contains valid imagery".format(f))


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta


out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
out_otsu_tta
