In [1]:
import rasterio
import os
from glob import glob
import matplotlib.pyplot as plt
%matplotlib inline

from cmixuv.satreaders import l8image, s2image
from inference import Model, save_cloud_mask

In [2]:
# Tweaks. Recommended: TensorFlow >= 2.1 with a CUDA GPU (if available).
# Set CUDA_DEVICE to "CPU" to force CPU execution; any other value uses GPU 0.
CUDA_DEVICE = "GPU"
if CUDA_DEVICE == "CPU":
    # Hide every GPU from CUDA so TensorFlow falls back to the CPU
    # (useful when the GPU cannot hold a full image).
    # NOTE(review): CUDA reads this variable when it initialises; if TensorFlow
    # already created its devices earlier in this kernel, restart the kernel.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
else:
    # Expose only GPU 0. If no GPU is installed, these lines have no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    import warnings

    import tensorflow as tf

    # Allocate GPU memory on demand instead of reserving it all up front,
    # so the process does not run out of memory prematurely.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            # TensorFlow raises RuntimeError once the device is initialised
            # (e.g. when this cell is re-run after the model was loaded).
            # The setting can no longer change, so warn instead of aborting.
            warnings.warn(f"Could not enable memory growth on {gpu.name}: {err}")

In [3]:
# DIRS: databases and weights.
# (Internal) Use the Windows (UNC) share when it is reachable; otherwise
# assume the Linux mount point.
_UNC_DATABASES_ROOT = "//erc.uv.es/databases"
_LINUX_DATABASES_ROOT = "/media/disk/databases"
root_path = _UNC_DATABASES_ROOT if os.path.exists(_UNC_DATABASES_ROOT) else _LINUX_DATABASES_ROOT

weights_path = os.path.join(root_path, "CMIX", "setupgonzalo", "experiments")

# Model name -> weights file inside `weights_path`.
# TODO create checkpoints folder with model weights
_WEIGHT_FILES = {
    "rgbiswir": "landsatbiomeRGBISWIR7.hdf5",
    "rgbi": "landsatbiomeRGBI6.hdf5",
}
CLOUD_DETECTION_WEIGHTS = {
    model_name: os.path.join(weights_path, file_name)
    for model_name, file_name in _WEIGHT_FILES.items()
}

# Available options:
#   namemodels -> FCNN input-band sets
#   satnames   -> supported satellites
namemodels = ["rgbi", "rgbiswir"]
satnames = ["L8", "S2"]

In [4]:
# Run selection: one entry of `satnames` and one entry of `namemodels`.
satname, namemodel = "S2", "rgbiswir"

In [5]:
# Read image: build the sensor reader `satobj` for the selected satellite.
# NOTE(review): the active product paths are hardcoded, author-local Windows
# paths. In this notebook the globbed database listings (L8_prods / S2_prods)
# are only referenced by the "alternative" comments. Point the path at your own
# product, or at the first database product, before running elsewhere.
if satname == "L8":
    # Landsat-8 product
    L8_path = os.path.join(root_path, "LANDSAT8_CLOUDS", "38_CLOUDS")
    L8_prods = sorted(glob(os.path.join(L8_path, "*")))

    # alternative: landsatimage = L8_prods[0]
    landsatimage = "C:/Users/danlo/Documents/UV/prods/LC08_L1TP_002053_20160520_20170324_01_T1"
    satobj = l8image.L8Image(landsatimage)

elif satname == "S2":
    # Sentinel-2 product (.SAFE)
    S2_path = os.path.join(root_path, "SENTINEL2_CLOUDS", "SENTINEL_2_BaetensHagolle")
    S2_prods = sorted(glob(os.path.join(S2_path, "*.SAFE")))

    # alternative: sentinelimage = S2_prods[0]
    sentinelimage = "C:/Users/danlo/Documents/UV/prods/S2A_MSIL1C_20160417T110652_N0201_R137_T29RPQ_20160417T111159.SAFE"
    satobj = s2image.S2Image(sentinelimage)

else:
    # Reject unknown sensors instead of silently treating them as Sentinel-2.
    raise ValueError(f"Unknown satname {satname!r}; expected one of {satnames}")

In [6]:
# LOAD MODEL
# Instantiate the cloud-detection model for the chosen satellite and band set.
model = Model(
    namemodel=namemodel,
    satname=satname,
)

In [None]:
# PREDICTION
# Run the selected model over the loaded product to obtain its cloud mask (CM).
cloud_prob_bin = model.predict(
    satobj
)

In [None]:
# Store the computed CM as a .tif file inside the product folder,
# named after the model that produced it.
cloud_mask_path = os.path.join(satobj.folder, f"cmixuvclouds_{namemodel}.tif")
save_cloud_mask(satobj, cloud_prob_bin, cloud_mask_path)

In [None]:
# VISUALIZATIONS
# Bands requested for the RGB preview.
# NOTE(review): in both Landsat-8 OLI and Sentinel-2 MSI numbering, B2/B3/B4 are
# blue/green/red. If `load_bands` takes those band numbers, this stack is BGR
# and `imshow` will swap red and blue. Confirm how `load_bands` indexes bands
# (and its output layout) before reordering to [4, 3, 2].
rgb_band_ids = [2, 3, 4]
rgb = satobj.load_bands(bands=rgb_band_ids)

In [None]:
# Plot the RGB preview of the scene on an explicit figure/axes pair, with a
# title so the figure stands alone when the notebook is skimmed.
fig_rgb, ax_rgb = plt.subplots()
ax_rgb.imshow(rgb)
ax_rgb.set_title(f"{satname} scene - RGB preview")
ax_rgb.set_axis_off()  # pixel indices carry no meaning here
plt.show()  # render the figure and suppress the AxesImage repr

In [None]:
# Plot the cloud mask computed by the model, with a colour bar so the mask
# values can be read off the figure.
# NOTE(review): imshow needs a 2-D (H, W) array (or H x W x 3/4). If `predict`
# returns a trailing singleton channel, squeeze it first; confirm the shape.
fig_mask, ax_mask = plt.subplots()
mask_img = ax_mask.imshow(cloud_prob_bin)
fig_mask.colorbar(mask_img, ax=ax_mask)
ax_mask.set_title(f"Cloud mask ({namemodel} model)")
ax_mask.set_axis_off()
plt.show()  # render the figure and suppress the AxesImage repr