In [None]:
import os

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

import numpy as np
import rasterio

import tensorflow as tf

%env SM_FRAMEWORK=tf.keras

import segmentation_models as sm

sm.set_framework('tf.keras')
sm.framework()

import rasterio
from patchify import patchify, unpatchify

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# One colour per plotted class index, in the same order as labels_text.
cm = ListedColormap(
    [
        "darkturquoise",  # 0: other green plants
        "green",          # 1: Posidonia oceanica
        "brown",          # 2: brown algae & rocks
        "gold",           # 3: sandy bottoms
    ]
)

# Legend entry for each class index; the species name is italicised with mathtext.
labels_text = ["Other green plants", r"$\it{Posidonia \ oceanica}$", "Brown algae & rocks", "Sandy bottoms"]

# Per-band mean and variance used to standardise the input raster (one entry per band).
# Band index 7 gets mean 0 / variance 1 (an identity transform) and is dropped
# before inference anyway, see _MODEL_BANDS.
_BAND_MEAN = np.array([373.74282191, 345.29123999, 356.49541412, 312.82175701, 268.56866458,
                       230.42540939, 233.93813734, 0, -12.42903576])
_BAND_VAR = np.array([2.97375298e+04, 4.01506329e+04, 5.67049771e+04, 6.36354140e+04,
                      3.65447284e+04, 2.98572781e+04, 2.54050505e+04, 1, 4.35067553e+01])
# Indices of the bands fed to the network (every band except index 7).
_MODEL_BANDS = [0, 1, 2, 3, 4, 5, 6, 8]


def _crop_to_multiple(arr, size):
    """Crop the first two axes of ``arr`` to the largest multiple of ``size`` (top-left anchored)."""
    return arr[:arr.shape[0] - arr.shape[0] % size, :arr.shape[1] - arr.shape[1] % size]


def _remap_prediction_classes(predictions, nodata_mask):
    """Map raw network classes to plotting classes and blank out no-data pixels.

    Raw class 1 and every pixel where ``nodata_mask`` is True become NaN. Raw
    classes 2, 3 and 4 become 1, 2 and 3. Any other value (e.g. 0) is kept as is.
    Every comparison is made on the untouched input, so the replacements cannot
    cascade into each other.
    """
    remapped = predictions.astype(float)
    for raw_class, new_class in ((2, 1), (3, 2), (4, 3)):
        remapped[predictions == raw_class] = new_class
    remapped[(predictions == 1) | nodata_mask] = np.nan
    return remapped


def predict_image(NAME, nn, arch="Linknet", im_width=256, verbose=True, save=True):
    """Predict the benthic habitat class of every pixel of a satellite image.

    The raster ``Data/{NAME}.tif`` is standardised band by band and cropped to a
    whole number of ``im_width`` x ``im_width`` tiles. It is then classified tile
    by tile with the Keras model ``Models/{arch}_{nn}_{im_width}_.h5``.

    Args:
        NAME: image name; the raster is read from ``Data/{NAME}.tif``.
        nn: encoder backbone name used in the model file name (e.g. "resnet34").
        arch: segmentation architecture used in the model file name.
        im_width: tile side in pixels; must match the model's input size.
        verbose: print progress messages.
        save: also write ``Predictions/{NAME}_predictions_{arch}_{nn}.tif``
            (single band, float32, LZW-compressed, NaN as nodata).

    Returns:
        2-D float array (rows, cols), cropped to multiples of ``im_width``.
        Assuming the model uses the same coding as the label rasters, the values are:
        0 = other green plants, 1 = Posidonia oceanica, 2 = brown algae & rocks,
        3 = sandy bottoms, NaN = no data. Any other raw class value is passed
        through unchanged. The shape no longer depends on ``save``; previously
        ``save=True`` returned a (1, rows, cols) array.

    Raises:
        ValueError: the raster does not have 9 bands, or is smaller than one tile.
    """
    satellite_image_filename = "Data/%s.tif" % NAME

    if verbose:
        print("Loading %s data" % NAME)
    # Read the pixels and the georeferencing profile once, then close the file.
    with rasterio.open(satellite_image_filename) as src:
        bands = src.read()
        profile = src.profile

    if verbose:
        print("Processing data...")
    # (band, row, col) -> (row, col, band) so the band number is the last axis.
    bands = np.transpose(bands, (1, 2, 0))
    if bands.shape[2] != _BAND_MEAN.size:
        raise ValueError("Expected %d bands in %s, found %d"
                         % (_BAND_MEAN.size, satellite_image_filename, bands.shape[2]))

    bands_scaled = (bands - _BAND_MEAN) / np.sqrt(_BAND_VAR)
    # Where bands 0-7 are all zero the pixel is empty: feed zeros to the model there.
    bands_scaled[np.all(bands[:, :, 0:8] == 0, axis=2)] = 0
    bands_scaled = bands_scaled[:, :, _MODEL_BANDS].astype(np.float16)

    # Drop the right/bottom remainder so the image splits into whole tiles. The
    # top-left origin is kept, which is why the source profile is reused below.
    bands = _crop_to_multiple(bands, im_width)
    bands_scaled = _crop_to_multiple(bands_scaled, im_width)
    if bands_scaled.shape[0] == 0 or bands_scaled.shape[1] == 0:
        raise ValueError("%s is smaller than one %dx%d tile"
                         % (satellite_image_filename, im_width, im_width))

    n_channels = bands_scaled.shape[2]
    tiles = patchify(bands_scaled, (im_width, im_width, n_channels), step=im_width)
    n_tile_rows, n_tile_cols = tiles.shape[0], tiles.shape[1]
    n_tiles = n_tile_rows * n_tile_cols
    tiles = tiles.reshape(n_tiles, im_width, im_width, n_channels)
    # Half of the tiles per batch, but at least 1 (a single-tile image used to give 0).
    batch_size = max(1, n_tiles // 2)

    if verbose:
        print("Loading model %s..." % nn)
    model = tf.keras.models.load_model(
        "Models/%s_%s_%s_.h5" % (arch, nn, im_width),
        # The segmentation_models loss/metrics the saved model refers to by name.
        custom_objects={
            "dice_loss": sm.losses.DiceLoss(),
            "f1-score": sm.metrics.FScore(),
            "iou_score": sm.metrics.IOUScore(),
        },
    )

    if verbose:
        print("Predicting...")
    class_scores = model.predict(tiles, batch_size=batch_size)
    tile_classes = np.argmax(class_scores, axis=-1).reshape(n_tile_rows, n_tile_cols, im_width, im_width)
    raw_classes = unpatchify(tile_classes, bands_scaled.shape[:2])

    # NOTE(review): the input is zeroed only where *all* of bands 0-7 are 0 (np.all
    # above), but the output is masked where *any* of them is 0 or NaN. This behaviour
    # is kept unchanged; confirm the asymmetry is intended.
    first_bands = bands[:, :, 0:8]
    nodata_mask = np.any(first_bands == 0, axis=2) | np.any(np.isnan(first_bands), axis=2)
    prediction_map = _remap_prediction_classes(raw_classes, nodata_mask)

    if save:
        os.makedirs("Predictions", exist_ok=True)
        if verbose:
            print("Saving predictions...")
        profile.update(
            width=prediction_map.shape[1],
            height=prediction_map.shape[0],
            dtype=rasterio.float32,
            count=1,
            compress='lzw',
            nodata=np.nan,
        )
        output_filename = "Predictions/%s_predictions_%s_%s.tif" % (NAME, arch, nn)
        with rasterio.open(output_filename, 'w', **profile) as dst:
            dst.write(prediction_map.astype(rasterio.float32), 1)

    if verbose:
        print("Done")

    return prediction_map

def voting_prediction(NAME, verbose=True, save=True, n_models=10):
    """Combine the saved per-model predictions of an image by majority vote.

    Reads every ``Predictions/{NAME}_predictions*.tif`` raster (as written by
    ``predict_image(..., save=True)``). For each pixel it picks the value
    predicted by the most models. NaN (no data) counts as a vote of its own.
    Ties go to the smallest value, so a real class wins a tie against no data.

    Args:
        NAME: image name used as the prediction file prefix.
        verbose: print progress messages.
        save: also write ``Predictions/{NAME}_voting_predictions.tif`` using the
            profile of the first (alphabetically) prediction raster.
        n_models: minimum number of prediction rasters required (default 10).

    Returns:
        2-D float array with the voted class per pixel and NaN for no data.

    Raises:
        FileNotFoundError: the ``Predictions`` directory does not exist.
        ValueError: too few prediction rasters, or rasters of different shapes.
    """
    predictions_dir = "Predictions"
    if not os.path.exists(predictions_dir):
        raise FileNotFoundError("To use the voting prediction you must first generate and save the predictions using all the models")

    # Only .tif files: GDAL may leave *.tif.aux.xml sidecars with the same prefix.
    # Sorting makes the profile template (first file) deterministic.
    predictions_filenames = sorted(
        f for f in os.listdir(predictions_dir)
        if f.startswith(NAME + "_predictions") and f.endswith(".tif")
    )

    required = max(1, n_models)
    if len(predictions_filenames) < required:
        raise ValueError("Only %d predictions found, %d are needed" % (len(predictions_filenames), required))

    if verbose:
        print("Loading predictions...")
    rasters = []
    for filename in predictions_filenames:
        with rasterio.open(os.path.join(predictions_dir, filename)) as src:
            rasters.append(src.read(1))
    shapes = {raster.shape for raster in rasters}
    if len(shapes) > 1:
        raise ValueError("Prediction rasters of %s have different shapes: %s" % (NAME, sorted(shapes)))
    stack = np.stack(rasters)

    if verbose:
        print("Computing voting prediction...")
    # NaN cannot be counted, so no-data votes are recorded as this placeholder class.
    nodata_vote = 9
    votes = np.where(np.isnan(stack), nodata_vote, stack).astype(int)

    # Vectorised equivalent of np.bincount(votes[:, i, j]).argmax() for every pixel.
    # A strict '>' keeps the smallest value on ties, matching argmax's first-max rule.
    best_value = np.zeros(votes.shape[1:], dtype=int)
    best_count = np.zeros(votes.shape[1:], dtype=int)
    for value in range(int(votes.max()) + 1):
        count = (votes == value).sum(axis=0)
        better = count > best_count
        best_value[better] = value
        best_count[better] = count[better]

    voting_predictions = best_value.astype(float)
    voting_predictions[voting_predictions == nodata_vote] = np.nan

    if save:
        if verbose:
            print("Saving voting predictions...")
        with rasterio.open(os.path.join(predictions_dir, predictions_filenames[0])) as src:
            with rasterio.open("Predictions/%s_voting_predictions.tif" % NAME, 'w', **src.profile) as dst:
                dst.write(voting_predictions.astype(rasterio.float32), 1)

    return voting_predictions

# Benthic habitat prediction (single model)

Run the next cell to predict the underlying benthic habitat distribution of the selected image with the selected model.

You can then visualise the predictions by running the cells at the bottom of the notebook.

In [None]:
# Choose the image and the backbone of the (default Linknet) model to use.
# Images:    Pollença_21_july_2022, Formentera_24_july_2022
# Backbones: densenet201, efficientnetb7, inceptionresnetv2, inceptionv3, mobilenetv2,
#            resnet34, resnet152, resnext101, seresnet152, seresnext101
NAME = "Pollença_21_july_2022"
nn = "seresnext101"

# Predict without writing a GeoTIFF to disk.
prediction = predict_image(NAME=NAME, nn=nn, save=False)

To generate and save the predictions of all models (required by the voting method), run the next cell.

In [None]:
# Every image / backbone combination. voting_prediction() needs the saved
# predictions of all ten backbones for an image before it can be used.
IMAGE_NAMES = ["Pollença_21_july_2022", "Formentera_24_july_2022"]
BACKBONES = ["densenet201", "efficientnetb7", "inceptionresnetv2", "inceptionv3",
             "mobilenetv2", "resnet34", "resnet152", "resnext101", "seresnet152",
             "seresnext101"]

# Distinct loop names so this cell does not silently overwrite the NAME / nn /
# prediction globals that the single-model and plotting cells rely on.
for image_name in IMAGE_NAMES:
    for backbone in BACKBONES:
        predict_image(image_name, backbone, save=True)

# Benthic habitat prediction (voting)

To predict the benthic habitats of a given image with the voting method, you must first generate and save the predictions of all models (i.e. call `predict_image` with `save=True`).

In [None]:
# Image to classify by majority vote (Pollença_21_july_2022 or Formentera_24_july_2022).
# All of its per-model predictions must already be saved in Predictions/.
NAME = "Pollença_21_july_2022"
prediction = voting_prediction(NAME=NAME, verbose=True, save=True)

# Plot predictions

In [None]:
# Raw class codes present in the ground-truth raster, before any remapping.
# Read directly from disk: `labels` is only created in the next cell, so
# referencing it here raised NameError on a fresh kernel (Restart & Run All).
with rasterio.open("Data/%s_labels.tif" % NAME) as labels_src:
    raw_label_codes = np.unique(labels_src.read(1))
raw_label_codes

In [None]:
# Load the ground-truth labels of the same image for comparison.
labels_filename = "Data/%s_labels.tif" % NAME
with rasterio.open(labels_filename) as labels_src:
    # Cast to float so NaN can be stored: assigning NaN into an integer array
    # raises ValueError (the label raster is presumably integer-coded).
    raw_labels = labels_src.read(1).astype(float)

# Raw label codes: 0 other green plants, 1 no data, 2 Posidonia oceanica,
# 3 brown algae & rocks, 4 sandy bottoms, 5 other (i.e. no data).
# They are remapped to the 4 plotted classes (0 other green plants, 1 Posidonia
# oceanica, 2 brown algae & rocks, 3 sandy bottoms), with NaN for no data.
# Comparisons use the untouched raw codes, so the replacements cannot cascade.
labels = raw_labels.copy()
for raw_code, class_index in ((2, 1), (3, 2), (4, 3)):
    labels[raw_labels == raw_code] = class_index
labels[(raw_labels == 1) | (raw_labels == 5)] = np.nan

# Fix the colour scale to the class range. imshow autoscales to each array's own
# min/max, so an image missing a class would otherwise be drawn in shifted colours.
class_range = dict(vmin=0, vmax=len(cm.colors) - 1)

fig, ax = plt.subplot_mosaic("AB", figsize=(20, 20))
for key, image, title in (("A", labels, "Ground truth"), ("B", prediction, "Predictions")):
    ax[key].imshow(image, cmap=cm, interpolation="none", **class_range)
    ax[key].axis("off")
    ax[key].set_title(title, fontsize=30)

# Custom legend above the panels, one coloured patch per class. It is anchored to
# panel B, the last axes created, which plt.legend previously used implicitly.
ax["B"].legend(
    handles=[plt.Rectangle((0, 0), 1, 1, color=color) for color in cm.colors],
    labels=labels_text,
    loc="upper center",
    bbox_to_anchor=(-0.1, 1.2),
    ncol=len(labels_text),
    fontsize=20,
);