<a href="https://colab.research.google.com/github/JilinMen/WaterQualityQuickView_GEE_Colab/blob/main/WaterQualityMonitoring_GEE_GUI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Quick View of Water Quality with Google Earth Engine




Author: Jilin Men (jmen@ua.edu)
### Version-1, 2025-02-22:
Description:
*   Collection search and image retrieval for Landsat-8/9 and Sentinel-2A/B
*   Support atmospheric correction with ACOLITE
*   Cloud and land masking
*   Support for water quality models
*   Preview RGB and water quality maps
*   2025-02-26: search only images with cloud cover below 50%





      





# Import, install, initialize, and clone

In [None]:
# @title Import packages
#import library
import ee
import geemap
import ipywidgets as widgets
from IPython.display import display
from ipyleaflet import WidgetControl, DrawControl, TileLayer
from geemap import Map
import os
import sys
import datetime

In [None]:
# @title Install other required packages
# netCDF4 is not imported directly anywhere in this notebook; presumably it is
# needed by the ACOLITE code cloned further down — TODO confirm.
# NOTE(review): consider pinning for reproducibility, e.g.
# `%pip install -q netCDF4==1.7.2` (the version recorded in this run's output).
!pip install netCDF4

Collecting netCDF4
  Downloading netCDF4-1.7.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting cftime (from netCDF4)
  Downloading cftime-1.6.4.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Downloading netCDF4-1.7.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.3/9.3 MB[0m [31m57.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cftime-1.6.4.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m65.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: cftime, netCDF4
Successfully installed cftime-1.6.4.post1 netCDF4-1.7.2


In [None]:
# @title GEE initialize and authenticate
Project = "ee-menjilin" # @param {"type":"string","placeholder":"10"}
# Reuse cached credentials when possible; only when initialization fails do we
# run the interactive authentication flow and initialize a second time.
# Both attempts target the high-volume Earth Engine endpoint.
try:
    ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com', project=Project)
except Exception:
    print("GEE not initialized. Authenticating...")
    ee.Authenticate()
    ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com', project=Project)
print("GEE initialized successfully!")

GEE not initialized. Authenticating...
GEE initialized successfully!


In [None]:
# @title Clone acolite from github
# Clone ACOLITE from JilinMen's GitHub repository unless a checkout that
# already contains the `acolite/gee` package is present. A checkout without
# that package is deleted and cloned again.
if not os.path.exists('/content/acolite/acolite/gee'):
  if os.path.exists('/content/acolite'):
    import shutil
    shutil.rmtree('/content/acolite')
  !git clone https://github.com/JilinMen/acolite.git
  print("Acolite clone finished!")
else:
  print('acolite already exists')

# Make the clone importable and load its GEE entry point (gee.agh_run is used
# by ACOLITE_run).
# NOTE(review): re-running this cell appends the same path to sys.path again.
sys.path.append('/content/acolite')
from acolite import gee

Cloning into 'acolite'...
remote: Enumerating objects: 12137, done.[K
remote: Counting objects: 100% (2110/2110), done.[K
remote: Compressing objects: 100% (234/234), done.[K
remote: Total 12137 (delta 2023), reused 1876 (delta 1876), pack-reused 10027 (from 3)[K
Receiving objects: 100% (12137/12137), 711.21 MiB | 27.88 MiB/s, done.
Resolving deltas: 100% (7825/7825), done.
Acolite clone finished!


In [None]:
# # @title OC_3S
# if not os.path.exists('/content/OC_3S'):
#   !git clone https://github.com/JilinMen/OC_3S.git
#   print("OC_3S clone finished!")
# else:
#   print('OC_3S already exists')

# sys.path.append('/content/OC_3S')
# from OC_3S.OC_3S import OC_3S_v1


In [None]:
# @title Mount google drive
# folder_name = "/content/drive" # @param {"type":"string","placeholder":"20"}
# from google.colab import drive
# drive.mount(folder_name)

In [None]:
# @title image search
## written by Quinten Vanhellemont, RBINS
def match_scenes(isodate_start, isodate_end=None, day_range=1,
                surface_reflectance=False,
                limit=None, st_lat=None, st_lon=None, filter_tiles=None,
                sensors=['L4_TM', 'L5_TM', 'L7_ETM', 'L8_OLI', 'L9_OLI', 'S2A_MSI', 'S2B_MSI'],
                max_cloud_cover=50):
    """Search Earth Engine Landsat / Sentinel-2 collections for scenes in a date range.

    Originally written by Quinten Vanhellemont (RBINS).

    Args:
        isodate_start: start date string parseable by ``dateutil`` (required).
        isodate_end: end date string, ``'now'``/``'today'``, or None for a
            single-day search. The end date is inclusive.
        day_range: unused; kept for signature compatibility.
        surface_reflectance: search the Level-2 surface-reflectance
            collections instead of TOA / Level-1C.
        limit: ``[south, west, north, east]`` bounding box in degrees.
        st_lat, st_lon: point location used when ``limit`` is None.
        filter_tiles: a tile-id string or list of them; only product ids
            containing at least one of them are returned.
        sensors: sensor identifiers to search (e.g. 'L8_OLI', 'S2A_MSI').
        max_cloud_cover: keep scenes whose scene-level cloud percentage is
            strictly below this value (default 50).

    Returns:
        ``(images, imColl)``: ``images`` is a list of
        ``(property_name, product_id)`` tuples; ``imColl`` is the merged
        ``ee.ImageCollection``, or None when nothing could be searched.
    """
    ## check the start date before importing anything, so that a missing
    ## date still returns the same (images, collection) shape as a normal call
    if isodate_start is None:
        print('Please provide start date.')
        return [], None

    import ee
    #ee.Authenticate() ## assume ee use is authenticated in current environment
    #ee.Initialize()
    import dateutil.parser, datetime

    if filter_tiles is not None and not isinstance(filter_tiles, list):
        filter_tiles = [filter_tiles]

    dstart = dateutil.parser.parse(isodate_start)
    isodate_start = dstart.isoformat()[0:10]

    ## get date range
    if isodate_start == isodate_end:
        isodate_end = None
    if isodate_end is None:
        dend = dstart
    elif isodate_end in ['now', 'today']:
        dend = datetime.datetime.now()
    else:
        dend = dateutil.parser.parse(isodate_end)
    dend += datetime.timedelta(days=1)  ## add one day so end date is included
    isodate_end = dend.isoformat()[0:10]

    print('Date range {} {}'.format(isodate_start, isodate_end))

    ## identify collections: sensor name -> Landsat mission code
    sensor_to_landsat = [
        ## MultiSpectral Scanners
        ('L1_MSS', 'LM01'), ('L2_MSS', 'LM02'), ('L3_MSS', 'LM03'),
        ('L4_MSS', 'LM04'), ('L5_MSS', 'LM05'),
        ## newer sensors
        ('L4_TM', 'LT04'), ('L5_TM', 'LT05'), ('L7_ETM', 'LE07'),
        ('L8_OLI', 'LC08'), ('L9_OLI', 'LC09'),
    ]
    landsats = [code for name, code in sensor_to_landsat if name in sensors]
    landsat_tiers = ['T1', 'T2']
    landsat_collections = ['C02']

    collections = []
    for landsat in landsats:
        is_mss = landsat[1] == 'M'
        for tier in landsat_tiers:
            for coll in landsat_collections:
                if surface_reflectance:
                    if is_mss:
                        print('No SR for MSS.')
                    else:
                        collections.append('{}/{}/{}/{}_L2'.format('LANDSAT', landsat, coll, tier))
                elif is_mss:
                    collections.append('{}/{}/{}/{}'.format('LANDSAT', landsat, coll, tier))
                else:
                    collections.append('{}/{}/{}/{}_TOA'.format('LANDSAT', landsat, coll, tier))

    if ('S2A_MSI' in sensors) or ('S2B_MSI' in sensors):
        ## COPERNICUS/S2 and COPERNICUS/S2_SR were superseded by the
        ## *_HARMONIZED collections in Jun 2024
        if surface_reflectance:
            collections.append('COPERNICUS/S2_SR_HARMONIZED')
        else:
            collections.append('COPERNICUS/S2_HARMONIZED')

    print('Checking collections {}'.format(' '.join(collections)))
    if not collections:
        ## nothing to search: avoid calling getInfo() on a None collection
        print('No collections to search for sensors {}.'.format(sensors))
        return [], None

    ## set up region
    if limit is not None:
        region = ee.Geometry.BBox(limit[1], limit[0], limit[3], limit[2])
    elif (st_lon is not None) and (st_lat is not None):
        region = ee.Geometry.Point([st_lon, st_lat])
    else:
        print('Warning! No limit or st_lat, st_lon combination specified. Function may return too many images.')
        region = None

    ## set up ee date
    sdate = ee.Date(isodate_start)
    edate = ee.Date(isodate_end)

    ## search ee collections
    imColl = None
    for coll in collections:
        ## the scene-level cloud percentage property differs between missions
        cloud_name = 'CLOUD_COVER' if 'LANDSAT' in coll else 'CLOUDY_PIXEL_PERCENTAGE'
        imC = (ee.ImageCollection(coll)
               .filterDate(sdate, edate)
               .filter(ee.Filter.lt(cloud_name, max_cloud_cover)))
        if region is not None:
            imC = imC.filterBounds(region)
        imColl = imC if imColl is None else imColl.merge(imC)

    ## a single round trip: the collection info already lists every image
    ## together with its properties
    features = imColl.getInfo().get('features', [])
    images = []
    for im in features:
        props = im.get('properties', {})
        if 'PRODUCT_ID' in props:  ## Sentinel-2 image
            fkey = 'PRODUCT_ID'
        elif 'LANDSAT_PRODUCT_ID' in props:  ## Landsat image
            fkey = 'LANDSAT_PRODUCT_ID'
        else:
            continue
        pid = props[fkey]
        if filter_tiles is not None and not any(tile in pid for tile in filter_tiles):
            continue
        images.append((fkey, pid))
    return images, imColl

In [None]:
# @title Atmospheric correction-ACOLITE

# Atmospheric correction: update gee_settings.txt
def update_settings(limit, isodate_start, isodate_end, sensor, output, output_scale,target_scale,glint_correction,
                    store_rhot,store_rhos,store_geom,store_sr,store_st,store_sp,
                    store_output_google_drive,
                    store_output_locally,
                    output_format,
                    old_agh=False,tile_size=606606,
                    settings_path=None):
    """Write the ACOLITE-GEE run parameters into its ``gee_settings.txt`` file.

    Every parameter is written as a ``key=value`` line (values via ``str``).
    Existing lines that start with the same ``key=`` are replaced in place.
    Keys that are not yet in the file are appended at the end, so they are
    never silently dropped. All other lines, such as comments, are kept unchanged.
    ``convert_output``, ``surface_reflectance`` and ``st_crop`` are always
    written as False.

    Args:
        limit: bounding-box sequence, written comma-separated.
        isodate_start, isodate_end, sensor, output, output_scale,
        target_scale, glint_correction, store_*, output_format: written
            verbatim to the matching settings keys (``sensor`` -> ``sensors=``).
        old_agh, tile_size: accepted for compatibility; not written.
        settings_path: settings file to update. Defaults to
            ``/content/acolite/config/gee_settings.txt``.

    Returns:
        None. Errors (e.g. a missing settings file) are printed together with
        the full traceback rather than raised, which keeps the original
        best-effort behaviour.
    """
    params = {
        "limit=": ','.join(map(str, limit)),
        "isodate_start=": isodate_start,
        "isodate_end=": isodate_end,
        "sensors=": sensor,
        "output=": output,
        "convert_output=": False,
        "output_scale=": output_scale,
        "target_scale=": target_scale,
        "glint_correction=": glint_correction,
        "surface_reflectance=": False,
        "store_rhot=": store_rhot,
        "store_rhos=": store_rhos,
        "store_geom=": store_geom,
        "store_sr=": store_sr,
        "store_st=": store_st,
        "store_sp=": store_sp,
        "store_output_google_drive=": store_output_google_drive,
        "store_output_locally=": store_output_locally,
        "output_format=": output_format,
        "st_crop=": False,
    }
    if settings_path is None:
        settings_path = os.path.join('/content/acolite', "config/gee_settings.txt")

    try:
        with open(settings_path, 'r') as file:
            lines = file.readlines()

        written = set()
        for i, line in enumerate(lines):
            for key, value in params.items():
                # keys include the trailing '=', so 'output=' never matches
                # 'output_format=...' or 'output_scale=...'
                if line.startswith(key):
                    lines[i] = f"{key}{value}\n"
                    written.add(key)
                    break

        # append keys the file did not contain yet instead of dropping them
        missing = [f"{key}{value}\n" for key, value in params.items() if key not in written]
        if missing and lines and not lines[-1].endswith('\n'):
            lines[-1] += '\n'
        lines.extend(missing)

        with open(settings_path, 'w') as file:
            file.writelines(lines)
        print('setting updated!')
    except Exception as exc:
        import traceback
        pymsg = ("PYTHON ERRORS:\nTraceback info:\n" + traceback.format_exc() +
                 "\nError Info:\n" + str(exc))
        print(pymsg)
        return

def ACOLITE_run(limit, isodate_start, isodate_end, sensor,
                output="/content/drive/MyDrive/ACOLITE/", output_scale=None,target_scale=None,glint_correction=False,
                store_rhot=False,store_rhos=True,store_geom=False,store_sr=False,store_st=False,store_sp=False,
                store_output_google_drive=False,
                store_output_locally=False,
                output_format=None
                ):
    """Configure ACOLITE-GEE through ``update_settings`` and run it.

    The default ``output`` previously pointed to ``/conetent/...`` (a typo);
    it now points to the mounted Google Drive folder under ``/content/drive``.

    Args:
        limit: bounding box passed to ``update_settings``.
        isodate_start, isodate_end: date range passed to ``update_settings``.
        sensor: ACOLITE sensor identifier.
        output ... output_format: forwarded unchanged to ``update_settings``.

    Returns:
        Whatever ``gee.agh_run(old_agh=False)`` returns (defined in the
        cloned ACOLITE package, not visible here).
    """
    update_settings(limit,
                    isodate_start, isodate_end,
                    sensor,
                    output,
                    output_scale, target_scale,
                    glint_correction,
                    store_rhot, store_rhos, store_geom, store_sr, store_st, store_sp,
                    store_output_google_drive,
                    store_output_locally,
                    output_format)
    return gee.agh_run(old_agh=False)

# Water quality

In [None]:
# @title preview_rgb_image
# RGB preview
def preview_rgb_image(collection,num_images = 10):
    """Add true-colour RGB layers for up to ``num_images`` images to the map ``m``.

    Band names are chosen from the global ``atmospheric_correction`` widget
    and the first selected entry of the global ``sensor`` widget.

    Args:
        collection: ee.ImageCollection to preview. If it is None, a hint is
            printed and the function returns.
        num_images: maximum number of images added as layers.

    Returns:
        None. For an unsupported SR sensor a message is printed and no layer
        is added (previously this raised UnboundLocalError on ``rgb_bands``).
    """
    if collection is None:
        print("No images found. Please search for images first.")
        return

    # select RGB bands
    if atmospheric_correction.value == 'SR':
        selected = sensor.value[0] if sensor.value else ''
        if 'L8_OLI' in selected or 'L9_OLI' in selected:
            rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2']
        elif 'S2A_MSI' in selected or 'S2B_MSI' in selected:
            rgb_bands = ['B4', 'B3', 'B2']
        else:
            print(f"Unsupported sensor for RGB preview: {selected!r}")
            return
    else:
        rgb_bands = ['B4', 'B3', 'B2']

    # Limit first, then ask the server for the size exactly once; limit() is a
    # no-op when the collection already has <= num_images images.
    collection = collection.limit(num_images)
    count = collection.size().getInfo()
    images = collection.toList(count)

    # visualization parameters: reflectance stretched over 0-0.3
    vis_params = {
        'bands': rgb_bands,
        'min': 0,
        'max': 0.3,
        'gamma': 1.4
    }

    for i in range(count):
        image = ee.Image(images.get(i))
        image_date = ee.Date(image.get('system:time_start')).format('YYYY-MM-dd').getInfo()
        print(f"Processing image {i + 1}/{count}: {image_date}")
        m.addLayer(image, vis_params, f"RGB_{image_date}")



In [None]:
# @title show map
def show_map(collect,algorithm,label='Chl mg/L',vis_params=None,num_images = 10):
    '''Apply a water-quality algorithm to a collection and add the results to map ``m``.

    Args:
        collect: ee.ImageCollection of reflectance images.
        algorithm: water quality function mapped over ``collect``
            (ee.Image -> ee.Image).
        label: layer-name prefix and colorbar label. ``"WaterClass"`` is not
            implemented yet; it prints a message and returns None.
        vis_params: visualization parameters (optional). Defaults to a
            0-30 blue-to-red palette.
        num_images: maximum number of result images added as layers.

    Returns:
        The (limited) ee.ImageCollection of algorithm results, or None for
        ``label == "WaterClass"``.
    '''
    if label == "WaterClass":
        # The OC-3S water classification path is not implemented; the old code
        # referenced an undefined `np` and never built `algo_collection`.
        print("WaterClass mapping is not implemented yet.")
        return None

    # Apply the algorithm, keep at most `num_images` results (limit() is a
    # no-op for smaller collections) and fetch the size in one round trip.
    algo_collection = collect.map(algorithm).limit(num_images)
    count = algo_collection.size().getInfo()
    images = algo_collection.toList(count)

    # Set default visualization parameters if not provided
    if vis_params is None:
        vis_params = {
            "min": 0,
            "max": 30,
            "palette": ["blue", "cyan", "green", "yellow", "red"]
        }

    # Iterate through the images and add them to the map
    for i in range(count):
        image = ee.Image(images.get(i))
        # Resolve the date server-side: prefer system:time_start and fall back
        # to the custom time_start property. The old client-side fallback was
        # unreachable because ee.Date(None).getInfo() raises first.
        start_time = ee.Algorithms.If(
            image.propertyNames().contains('system:time_start'),
            image.get('system:time_start'),
            image.get('time_start'))
        image_date = ee.Date(start_time).format('YYYY-MM-dd').getInfo()

        print(f"Processing image {i + 1}/{count}: {image_date}")

        # Add the image to the map
        try:
            print("Add water quality map to layer!")
            m.addLayer(image, vis_params, f"{label}_{image_date}")
        except Exception as e:
            print(f"Error adding image to the map: {e}")

    # Ensure colorbar is added only once per label
    if not hasattr(m, "added_labels") or not isinstance(m.added_labels, set):
        m.added_labels = set()

    if label not in m.added_labels:
        # Ensure 'colorbars' is a list to avoid AttributeError
        if hasattr(m, 'colorbars'):
            if isinstance(m.colorbars, set):
                m.colorbars = list(m.colorbars)
        else:
            m.colorbars = []

        m.add_colorbar(
            vis_params,
            label=label,
            orientation='horizontal',
            transparent_bg=True
        )
        m.added_labels.add(label)

    return algo_collection

In [None]:
# def load_reference_data():
#     """
#     Load reference water type data from an asset (assuming it's uploaded as an ee.FeatureCollection).
#     """
#     import numpy as np
#     import h5py
#     ref_table = h5py.File('/content/OC_3S/Water_classification_system_30c-int.h5','r')
#     up = ee.Array(ref_table['upB'].tolist()).divide(100000)      # Upper bounds for each water type
#     low = ee.Array(ref_table['lowB'].tolist()).divide(100000)      # Lower bounds for each water type
#     ref = ee.Array(ref_table['ref_cluster'].tolist()).divide(100000)   # Reference cluster spectra
#     waves = ee.Array(ref_table['waves'].tolist())       # Reference wavelengths
#     return waves, up, low, ref

# def spectral_angle_mapping(image, water_mask, waves, upB, lowB, ref):
#     """
#     Perform spectral angle mapping on an ImageCollection for water pixels.
#     """
#     def classify_pixel(pixel):
#         Rrs = ee.Array(pixel).toFloat()

#         # Normalize input spectra
#         norm_Rrs = Rrs.divide(Rrs.pow(2).reduce('sum').sqrt())

#         # Normalize reference spectra
#         norm_ref = ref.divide(ref.pow(2).reduce('sum', [1]).sqrt())

#         # Compute spectral angle (cosine similarity)
#         cos_sim = norm_ref.multiply(norm_Rrs).reduce('sum', [1])

#         # Find best match
#         max_cos = cos_sim.reduce('max')
#         cluster_id = cos_sim.argmax().add(1)  # Convert 0-based index to 1-based

#         # Compute classification confidence score
#         upB_corr = upB.arraySlice(0, cluster_id.subtract(1), cluster_id)
#         lowB_corr = lowB.arraySlice(0, cluster_id.subtract(1), cluster_id)

#         upB_diff = upB_corr.subtract(norm_Rrs).gte(0)
#         lowB_diff = norm_Rrs.subtract(lowB_corr).gte(0)
#         confidence = upB_diff.And(lowB_diff).reduce('mean', [0])

#         return ee.Array([cluster_id, max_cos, confidence])

#     classified = image.updateMask(water_mask).expression(
#         'classify_pixel(Rrs)',
#         {'Rrs': image.toArray(), 'classify_pixel': classify_pixel}
#     )

#     return classified.arrayProject([0]).arrayFlatten([['ClusterID', 'Confidence', 'TotScore']])

# def classify_water(collection, water_mask):
#     """
#     Classify water pixels in an ImageCollection using the OC_3S_v1 algorithm.
#     """
#     waves, upB, lowB, ref = load_reference_data()

#     classified_collection = collection.map(lambda img: spectral_angle_mapping(img, water_mask, waves, upB, lowB, ref))

#     return classified_collection

In [None]:
# @title bio-optical algorithms
def Chl_algorithm(image):
    '''
    Chlorophyll-a from a blue/green band-ratio polynomial.

    Reference: John E. O'Reilly. RSE. Chlorophyll algorithms for ocean color
    sensors - OC4, OC5 & OC6. 2019

        X   = log10(max(blue1, blue2) / green)
        chl = 10 ** (c0 + c1*X + c2*X^2 + c3*X^3 + c4*X^4)

    Band names depend on the global ``atmospheric_correction`` and ``sensor``
    widgets. Returns a single-band image named 'Chl-a' that carries
    ``system:time_start`` (or the image's ``time_start`` fallback). Returns
    None for an unsupported SR sensor or when any error occurs.
    '''
    print("Calculating Chlorophyll-a concentration...")
    try:
        bands = ('B1', 'B2', 'B3')
        if atmospheric_correction.value == 'SR':
            first_sensor = sensor.value[0]
            if any(name in first_sensor for name in ('S2A_MSI', 'S2B_MSI')):
                bands = ('B1', 'B2', 'B3')
            elif any(name in first_sensor for name in ('L8_OLI', 'L9_OLI')):
                bands = ('SR_B1', 'SR_B2', 'SR_B3')
            else:
                print("Unsupported sensor for chl calculation.")
                return None

        blue_a, blue_b, green_band = (image.select(name) for name in bands)
        ratio = blue_a.max(blue_b).divide(green_band).log10()

        # polynomial in X, accumulated term by term from c0 upwards
        coefficients = (0.30963, -2.40052, 1.28932, 0.52802, -1.33825)
        exponent = ee.Image.constant(coefficients[0])
        for power, coef in enumerate(coefficients[1:], start=1):
            term = ratio if power == 1 else ratio.pow(power)
            exponent = exponent.add(term.multiply(ee.Image.constant(coef)))
        chl = ee.Image(10).pow(exponent)

        # carry the acquisition time over from the input image
        has_start = image.propertyNames().contains('system:time_start')
        chl = chl.set(
            "system:time_start",
            ee.Algorithms.If(has_start, image.get("system:time_start"), image.get("time_start")),
        )
        return chl.rename('Chl-a')
    except Exception as e:
        print(f"Error calculating Chl-a: {e}")
        return None
def TSS_algorithm(image):
    """
    Total suspended solids from a green/red log-linear model.

        TSS = 10 ** (1.5 * log10(green) - 1.2 * log10(red) + 0.7)

    NOTE(review): the source of these empirical coefficients is not given in
    this notebook — confirm or calibrate them.

    Band names depend on the global ``atmospheric_correction`` and ``sensor``
    widgets. Returns a single-band image named 'TSS' that carries
    ``system:time_start`` (or the image's ``time_start`` fallback). Returns
    None for an unsupported SR sensor or when any error occurs.
    """
    print("Calculating total suspended solid...")
    try:
        # band select (defaults are also the Sentinel-2 SR names)
        green, red = 'B3', 'B4'
        if atmospheric_correction.value == 'SR':
            chosen = sensor.value[0]
            if 'S2A_MSI' in chosen or 'S2B_MSI' in chosen:
                green, red = 'B3', 'B4'
            elif 'L8_OLI' in chosen or 'L9_OLI' in chosen:
                green, red = 'SR_B3', 'SR_B4'
            else:
                print("Unsupported sensor for TSS calculation.")
                return None

        log_green = image.select(green).log10()
        log_red = image.select(red).log10()

        # empirical model: a * log10(G) + b * log10(R) + c
        exponent = (ee.Image.constant(1.5).multiply(log_green)
                    .add(ee.Image.constant(-1.2).multiply(log_red))
                    .add(ee.Image.constant(0.7)))
        tss = ee.Image(10).pow(exponent)

        has_start = image.propertyNames().contains('system:time_start')
        tss = tss.set(
            "system:time_start",
            ee.Algorithms.If(has_start, image.get("system:time_start"), image.get("time_start")),
        )
        return tss.rename('TSS')
    except Exception as e:
        print(f"Error calculating TSS: {e}")
        return None
def CDOM_algorithm(image):
    """
    Coloured dissolved organic matter from a blue/green log-linear model.

        CDOM = 10 ** (1.2 * log10(blue) - 0.8 * log10(green) + 0.5)

    NOTE(review): the source of these coefficients is not given in this
    notebook — confirm or calibrate them.

    Band names depend on the global ``atmospheric_correction`` and ``sensor``
    widgets. Returns a single-band image named 'CDOM' that carries
    ``system:time_start`` (or the image's ``time_start`` fallback). Returns
    None for an unsupported SR sensor or when any error occurs.
    """
    print("Calculating colored dissolved organic matter (CDOM)...")

    def _log_power_law(x_band, y_band, a, b, c):
        # 10 ** (a * log10(x) + b * log10(y) + c)
        return ee.Image(10).pow(
            ee.Image.constant(a).multiply(x_band.log10())
            .add(ee.Image.constant(b).multiply(y_band.log10()))
            .add(ee.Image.constant(c))
        )

    try:
        sr_bands = None
        if atmospheric_correction.value == 'SR':
            current = sensor.value[0]
            if 'S2A_MSI' in current or 'S2B_MSI' in current:
                sr_bands = ('B2', 'B3')
            elif 'L8_OLI' in current or 'L9_OLI' in current:
                sr_bands = ('SR_B2', 'SR_B3')
            else:
                print("Unsupported sensor for CDOM calculation.")
                return None
        blue, green = sr_bands if sr_bands else ('B2', 'B3')

        cdom = _log_power_law(image.select(blue), image.select(green), 1.2, -0.8, 0.5)

        has_start = image.propertyNames().contains('system:time_start')
        cdom = cdom.set(
            "system:time_start",
            ee.Algorithms.If(has_start, image.get("system:time_start"), image.get("time_start")),
        )
        return cdom.rename('CDOM')
    except Exception as e:
        print(f"Error calculating CDOM: {e}")
        return None

def OC_3S_algorithm():
    """Placeholder for the OC-3S water-classification algorithm.

    It only announces the step. No classification is implemented yet, so the
    function always returns None.
    """
    message = "Calculating water classes with OC-3S..."
    print(message)
    return None

In [None]:
# @title Machine learning algorithms
def load_model(model_path, map_location=None):
    """Load a model (or any object) saved with ``torch.save``.

    The previous version called ``torch.load_model``, which does not exist in
    PyTorch. ``torch.load`` is the loader for files written by ``torch.save``.

    Args:
        model_path: path to the saved file.
        map_location: optional device remapping forwarded to ``torch.load``
            (e.g. ``"cpu"`` on a runtime without a GPU).

    Returns:
        The deserialized object.

    Security: ``torch.load`` can unpickle arbitrary objects, so only load
    trusted files. NOTE(review): newer PyTorch releases default to
    ``weights_only=True``, which may reject full pickled ``nn.Module``
    objects — save/load a ``state_dict`` instead if that happens.
    """
    import torch
    return torch.load(model_path, map_location=map_location)

def predict_NN(model,image):
    """Unfinished neural-network prediction stub.

    NOTE(review): incomplete — ``model`` is never used and nothing is
    returned, so the function always returns None. ``getInfo()`` on an
    ee.Image presumably returns image metadata rather than pixel values —
    confirm the intended input extraction. The band names "SR_1"/"SR_2" do
    not match the Landsat names used elsewhere in this notebook (e.g.
    "SR_B1") — TODO confirm the intended bands.
    """
    inputs = image.select("SR_1","SR_2").getInfo()



In [None]:
# @title mask with product flags
def extract_water_landsat(image):
    """
    Build a water mask from the Landsat QA_PIXEL band.

    Returns an ee.Image that is 1 where QA_PIXEL bit 7 (water) is set and
    0 elsewhere.
    """
    water_flag = ee.Number(1).leftShift(ee.Number(7))  # bit 7 is water in QA_PIXEL
    qa_pixel = image.select('QA_PIXEL').toInt()
    return qa_pixel.bitwiseAnd(water_flag).neq(0)

def extract_water_sentinel(image):
    """
    Build a water mask from the Sentinel-2 Scene Classification Layer (SCL).

    Returns an ee.Image that is 1 where SCL equals 6 (water) and 0 elsewhere.
    """
    water_class = 6
    return image.select('SCL').eq(water_class)

def apply_cloud_mask_sentinel(image):
    """
    Build a clear-sky mask from the Sentinel-2 SCL band.

    Returns an ee.Image that is 0 where SCL is 3 (cloud shadow), 8 (cloud,
    medium probability) or 9 (cloud, high probability), and 1 elsewhere.
    """
    scene_class = image.select('SCL')
    cloudy = scene_class.eq(3)
    for code in (8, 9):
        cloudy = cloudy.Or(scene_class.eq(code))
    # invert: 1 for clear pixels, 0 for cloudy/shadowed ones
    return cloudy.Not()

def apply_cloud_mask_landsat(image, flag_bits=(3, 4)):
    """
    Build a clear-sky mask from the Landsat QA_PIXEL band.

    Args:
        image: ee.Image with a QA_PIXEL band.
        flag_bits: QA_PIXEL bit positions that mark a pixel as unusable.
            The default (3, 4) rejects cloud and cloud shadow. Passing e.g.
            (1, 2, 3, 4) also rejects dilated cloud and cirrus.

    Returns:
        ee.Image that is 1 where none of ``flag_bits`` is set and 0 otherwise.
    """
    qa_band = image.select('QA_PIXEL').toInt()
    # Combine all flags into one bit pattern: a pixel is clear only when ALL
    # listed flags are unset. The previous version OR-ed "no cloud" with
    # "no shadow", so it kept cloudy pixels whenever the shadow bit was clear,
    # and the reverse, and therefore masked almost nothing.
    reject_bits = 0
    for bit in flag_bits:
        reject_bits |= 1 << int(bit)
    return qa_band.bitwiseAnd(reject_bits).eq(0)

def mask_water(image):
    """
    Mask everything except cloud/shadow-free water pixels.

    The sensor is detected from the image id: ``system:id`` if present,
    otherwise the ``custom_id`` property (which merge_by_day sets). Ids
    containing 'LANDSAT' use the QA_PIXEL-based helpers; ids containing
    'COPERNICUS' use the Sentinel-2 SCL-based helpers.

    Args:
        image: ee.Image. The helpers read QA_PIXEL (Landsat) or SCL
            (Sentinel-2), so those bands must be present.

    Returns:
        ``image`` with ``updateMask(water_mask AND cloud_mask)`` applied.

    Raises:
        ValueError: if ``image`` is falsy (e.g. None).
    """
    # NOTE(review): ee.Image objects are presumably always truthy, so this
    # only catches None / empty values.
    if not image:
        raise ValueError("Input image is required")
    # try system:id then custom_id
    is_valid = image.propertyNames().contains('system:id')
    system_id = ee.String(ee.Algorithms.If(
        is_valid,
        image.get('system:id'),
        image.get('custom_id')
    ))

    # Landsat or Sentinel: a non-empty regex match list means the id
    # contains the substring
    is_landsat = system_id.match('LANDSAT').length().gt(0)
    is_sentinel = system_id.match('COPERNICUS').length().gt(0)
    # print("is_landsat: ",is_landsat.getInfo())
    # print("is_sentinel: ",is_sentinel.getInfo())

    # water areas. Both helper graphs are built client-side for every image;
    # presumably only the branch selected by ee.Algorithms.If is evaluated on
    # the server. For unknown sensors the fallback is a fully masked copy of
    # the image, so presumably every pixel ends up masked — TODO confirm.
    water_mask = ee.Algorithms.If(
        is_landsat,
        extract_water_landsat(image),
        ee.Algorithms.If(
            is_sentinel,
            extract_water_sentinel(image),
            image.updateMask(ee.Image.constant(0))
        )
    )

    # cloud / cloud-shadow mask (1 = clear), chosen per sensor as above
    cloud_mask = ee.Algorithms.If(
        is_landsat,
        apply_cloud_mask_landsat(image),
        ee.Algorithms.If(
            is_sentinel,
            apply_cloud_mask_sentinel(image),
            image.updateMask(ee.Image.constant(0))
        )
    )

    # keep pixels that are both water and clear
    final_mask = ee.Image(water_mask).And(ee.Image(cloud_mask))
    # print("final_mask: ",final_mask.propertyNames().getInfo())
    # apply mask
    masked_image = image.updateMask(final_mask)
    # print("masked_image: ", masked_image.bandNames().getInfo())
    return masked_image

In [None]:
# @title NDWI mask
def mndwi_mask(image, green='B3', swir='B6', threshold=0):
    """
    Mask non-water pixels using MNDWI (Modified Normalized Difference Water Index).

        MNDWI = (Green - SWIR) / (Green + SWIR)

    Pixels with MNDWI > ``threshold`` are kept; all others are masked.

    Args:
        image: ee.Image containing the green and SWIR-1 bands.
        green: green band name. The default 'B3' matches Landsat 8/9 TOA and
            Sentinel-2.
        swir: SWIR-1 band name. The default 'B6' matches Landsat 8/9 TOA. Use
            'SR_B6' (with green='SR_B3') for Landsat 8/9 Collection-2 Level-2,
            and 'B11' for Sentinel-2.
        threshold: MNDWI threshold (default 0).

    Returns:
        ``image`` with non-water pixels masked.
    """
    water_mask = image.normalizedDifference([green, swir]).gt(threshold)
    return image.updateMask(water_mask)

In [None]:
# @title To reflectance
def scale_reflectance_landsat(image):
    """
    Convert Landsat 8/9 Collection-2 Level-2 SR digital numbers to reflectance.

        reflectance = DN * 0.0000275 - 0.2

    The formula is applied to SR_B1..SR_B7. The scaled bands overwrite the
    originals, and all other bands and the image properties are kept.
    """
    sr_bands = ['SR_B{}'.format(n) for n in range(1, 8)]
    gain, offset = 0.0000275, -0.2
    scaled = image.select(sr_bands).multiply(gain).add(offset)
    scaled = scaled.copyProperties(image, image.propertyNames())
    return image.addBands(scaled, sr_bands, True)
def scale_reflectance_sentinel(image):
    """
    Convert Sentinel-2 digital numbers to reflectance (DN * 0.0001).

    The scaling is applied to the 12 spectral bands B1-B8, B8A, B9, B11 and
    B12. The scaled bands overwrite the originals, and all other bands (e.g.
    SCL) are kept.
    """
    spectral = ['B{}'.format(n) for n in range(1, 9)] + ['B8A', 'B9', 'B11', 'B12']
    scaled = image.select(spectral).multiply(0.0001)
    scaled = scaled.copyProperties(image, image.propertyNames())
    # overwrite the spectral bands, keep everything else
    return image.addBands(scaled, spectral, True)

In [None]:
# @title merge by day
def merge_by_day(collection):
    """
    Merge satellite images acquired on the same calendar day into one image.

    Images are grouped by the date string from
    ``ee.Date(...).format('YYYY-MM-dd')``. No time zone is passed, so Earth
    Engine's default (UTC) applies. A day with a single image keeps that
    image. A day with several images is combined with a pixel-wise
    ``ee.Reducer.mean()``, and the bands are renamed back to the first
    image's band names.

    collection should include properties at least:
        - system:time_start (or a custom ``time_start`` fallback)
        - system:id (or a custom ``custom_id`` fallback)

    NOTE(review): the mean is also applied to any bit-flag / class bands
    that are still present (e.g. QA_PIXEL, SCL), which gives meaningless
    values on multi-image days. Consider masking before merging, or
    mosaicking instead.
    NOTE(review): the date property is chosen from ``collection.first()``
    only, so an empty input collection presumably fails here — TODO confirm.

    Returns:
        ee.ImageCollection: one image per day, each with ``date``,
        ``image_count`` and ``custom_id`` properties set.
    """
    # obtain valid time_start: use 'system:time_start' if the first image has
    # it, otherwise fall back to the custom 'time_start' property
    is_date_valid = collection.first().propertyNames().contains('system:time_start')
    prop_date = ee.String(ee.Algorithms.If(
        is_date_valid,
        'system:time_start',
        'time_start'
    ))
    # get unique dates as 'YYYY-MM-dd' strings
    dates = collection.aggregate_array(prop_date) \
        .removeAll([None]) \
        .map(lambda time: ee.Date(time).format('YYYY-MM-dd')) \
        .distinct()

    # merge for the same day; mapped over the ee.List of date strings below
    def fuse_images_by_date(date):
        date_obj = ee.Date(date)
        start_date = date_obj.millis()
        end_date = date_obj.advance(1, 'day').millis()

        # get images according to date: [start of day, start of next day)
        # daily_images = collection.filterDate(start_date, end_date) #use default date system:time_start
        daily_images = collection.filter(ee.Filter.gte(prop_date, start_date)) \
                    .filter(ee.Filter.lt(prop_date, end_date))
        image_count = daily_images.size()

        # obtain valid id property (system:id, else custom_id)
        is_id_valid = daily_images.first().propertyNames().contains('system:id')
        prop_id = ee.String(ee.Algorithms.If(
            is_id_valid,
            'system:id',
            'custom_id'
        ))

        # get image ID
        image_ids = daily_images.aggregate_array(prop_id)

        # get bandNames of the first image of the day
        band_names = ee.List(ee.Algorithms.If(
            image_count.gt(0),
            ee.Image(daily_images.first()).bandNames(),
            ee.List(["default_band"])  # avoid `None` error
        ))

        # image_count == 0: placeholder image, filtered out at the end
        no_images = ee.Image.constant(0) \
            .rename(band_names) \
            .set('system:time_start', date_obj.millis()) \
            .set('date', date) \
            .set('image_count', 0)\
            .set('custom_id',ee.List([]))

        # image_count == 1: keep the image as is, just tag it
        single_image = ee.Image(daily_images.first()) \
            .set('system:time_start', date_obj.millis()) \
            .set('date', date) \
            .set('image_count', 1)\
            .set('custom_id',image_ids.get(0))

        # image_count > 1: pixel-wise mean of all images of the day
        # (reduce() drops the input properties, so they are set again here)
        fused_image = daily_images.reduce(ee.Reducer.mean()) \
            .rename(band_names) \
            .set('system:time_start', date_obj.millis()) \
            .set('date', date) \
            .set('image_count', image_count)\
            .set('custom_id',image_ids.get(0))  #use the first ID for fused image

        return ee.Algorithms.If(
            image_count.eq(0), no_images,
            ee.Algorithms.If(image_count.eq(1), single_image, fused_image)
        )

    # map
    fused_collection = ee.ImageCollection.fromImages(dates.map(fuse_images_by_date))

    # exclude image with image_count of 0
    return fused_collection.filter(ee.Filter.gt('image_count', 0))


# 7 GUI

In [None]:
# @title GUI parameter settings
################################################################################
#              Image search
################################################################################
# create a map; preview_rgb_image and show_map add their layers to `m`
# NOTE(review): draw_control=False is passed here, yet `m.draw_control` is read
# below and on_draw() is registered on it — confirm geemap still exposes a
# usable draw control in this case (it may be None).
m = Map(center=(35, -95), zoom=4, layout=widgets.Layout(height='1100px', width='100%'), draw_control=False)

# draw control (module-level alias)
draw_control = m.draw_control

# lat and lon bounding-box widgets (degrees); handle_draw fills them from a drawn polygon
min_lon = widgets.FloatText(description='Min Lon:', layout=widgets.Layout(width='100%'))
max_lon = widgets.FloatText(description='Max Lon:', layout=widgets.Layout(width='100%'))
min_lat = widgets.FloatText(description='Min Lat:', layout=widgets.Layout(width='100%'))
max_lat = widgets.FloatText(description='Max Lat:', layout=widgets.Layout(width='100%'))
# Draw-event handler: mirrors the drawn polygon's bounding box into the
# Min/Max Lon/Lat fields.
def handle_draw(target, action, geo_json):
    """Copy the bounding box of a drawn polygon into the lat/lon input fields.

    Registered as a draw-control callback. Both newly created and edited
    polygons update the fields, so the search area always matches the latest
    shape on the map. Other actions (e.g. deletion) and non-polygon geometries
    are ignored.

    Args:
        target: the draw control that fired the event (unused).
        action: name of the draw event, e.g. 'created' or 'edited'.
        geo_json: GeoJSON feature dict describing the drawn shape.
    """
    if action not in ('created', 'edited'):
        return
    geometry = geo_json.get('geometry') or {}
    if geometry.get('type') != 'Polygon':
        return
    rings = geometry.get('coordinates') or [[]]
    # The first ring of a GeoJSON Polygon is its exterior ring.
    exterior = rings[0]
    if not exterior:
        return
    # Positions are [lon, lat] with an optional altitude; index explicitly so
    # a third component cannot break the unpacking.
    lons = [position[0] for position in exterior]
    lats = [position[1] for position in exterior]
    min_lon.value = min(lons)
    max_lon.value = max(lons)
    min_lat.value = min(lats)
    max_lat.value = max(lats)

# Keep the bounding-box fields in sync with shapes drawn on the map.
m.draw_control.on_draw(handle_draw)

# Search period; by default the 30 days ending today.
start_date = widgets.DatePicker(
    description='Start Date:',
    value=datetime.date.today() - datetime.timedelta(days=30),
)
end_date = widgets.DatePicker(
    description='End Date:',
    value=datetime.date.today(),
)

# Sensors to search. Several can be selected; `value` holds the sensor codes.
sensor = widgets.SelectMultiple(
    options=[
        ('Landsat-8', 'L8_OLI'),
        ('Landsat-9', 'L9_OLI'),
        ('Sentinel-2A', 'S2A_MSI'),
        ('Sentinel-2B', 'S2B_MSI'),
        ('Landsat-7', 'L7_ETM'),
        ('Landsat-5', 'L5_TM'),
        ('Landsat-4', 'L4_TM'),
    ],
    value=['L8_OLI'],
    description='Sensor:',
    rows=4,  # 7 options, 4 visible at a time
)

# Reflectance product: surface reflectance ('SR') or ACOLITE correction.
atmospheric_correction = widgets.Dropdown(
    description='Product:',
    options=['SR', 'ACOLITE'],
    value='SR',
)

# Water-quality (bio-optical) parameters to map: Chl-a, TSS, CDOM.
bios = widgets.SelectMultiple(
    options=[('Chl-a', 'Chl-a'), ('TSS', 'TSS'), ('CDOM', 'CDOM')],
    value=['Chl-a'],
    description='Bio-optical:',
    rows=3,
)

# Starts a search/processing run.
button_process = widgets.Button(
    description="Process",
    button_style='primary',
    tooltip='Click to start processing',
)

# Banner shown at the top of the parameter panel.
title_html = widgets.HTML(
    value='''
    <div style="
        color: white;
        background: #2196F3;
        padding: 10px;
        font-family: Arial;
        font-size: 18px;
        border-radius: 5px;
        margin: 10px 0;
        text-align: left;
    ">
        🌍 Quick View of Water Quality
    </div>
    '''
)

# Clear button: removes the layers added to the map (handler:
# clear_non_basemap_layers).
clear_button = widgets.Button(
    description="Clear",
    button_style='warning',
    tooltip='Click to clear current layers',
)

# action for button
def clear_non_basemap_layers(b):
    """Remove every map layer except the 'OpenStreetMap.Mapnik' basemap.

    Click handler for the Clear button; ``b`` is the clicked button (unused).
    A confirmation line is printed to the status output panel.
    """
    keep_name = 'OpenStreetMap.Mapnik'
    # Decide which layers go before removing any of them.
    doomed = [lyr for lyr in m.layers if lyr.name != keep_name]
    for lyr in doomed:
        m.remove_layer(lyr)
    with status_output:
        print("Clear all layers!")

# Wire the Clear button to its handler.
clear_button.on_click(clear_non_basemap_layers)

# TODO: ACOLITE-specific settings (output/target scale, glint correction,
# which products to store, output location) are not exposed in the GUI yet.

# Scrollable log area at the bottom of the parameter panel; the button
# handlers print their progress messages into it.
status_output = widgets.Output()
output_container = widgets.VBox(
    [status_output],
    layout=widgets.Layout(height='400px', overflow_y='auto'),
)

# Left-hand parameter panel, listed top to bottom.
param_widgets = widgets.VBox(
    [
        title_html,
        widgets.Label("Date Selection"),
        start_date,
        end_date,
        widgets.Label("Geographic Boundaries"),
        min_lon,
        max_lon,
        min_lat,
        max_lat,
        widgets.Label("Sensor Selection"),
        sensor,
        widgets.Label("Atmospheric Correction"),
        atmospheric_correction,
        widgets.Label("Water Quality parameter"),
        bios,
        widgets.HBox([button_process, clear_button]),
        widgets.Label("Output information:"),
        output_container,
    ],
    layout=widgets.Layout(width='20%', padding='10px'),
)

In [None]:
# @title show water quality as layers
def show_wq(collection):
    """Add one map layer per water-quality parameter selected in ``bios``.

    Layers are added in the fixed order Chl-a, TSS, CDOM, each via
    ``show_map`` with its algorithm, label and a 0..max colour scale.

    Args:
        collection: image collection passed unchanged to ``show_map``.
    """
    def _add_layer(label, algorithm, upper):
        # A fresh vis-params dict (and palette list) per layer.
        vis = {"min": 0, "max": upper, "palette": ["blue", "cyan", "green", "yellow", "red"]}
        show_map(collection, algorithm, label, vis)

    chosen = bios.value
    # Each algorithm name is only looked up when its parameter is selected.
    if 'Chl-a' in chosen:
        _add_layer("Chl-a", Chl_algorithm, 30)
    if 'TSS' in chosen:
        _add_layer("TSS", TSS_algorithm, 10)
    if 'CDOM' in chosen:
        _add_layer("CDOM", CDOM_algorithm, 2)
    # 'WaterClass' (OC_3S_algorithm) is not offered in the GUI.

In [None]:
# @title AC, RGB, WQ, map
# Sensor codes grouped by the reflectance-scaling routine they use.
_SENTINEL2_SENSORS = ('S2A_MSI', 'S2B_MSI')
_LANDSAT_SENSORS = ('L4_TM', 'L5_TM', 'L7_ETM', 'L8_OLI', 'L9_OLI')


def _validate_gui_inputs():
    """Check the GUI inputs before any Earth Engine request is made.

    Returns:
        A user-facing error message (str) when an input is missing or
        inconsistent, otherwise None.
    """
    if start_date.value is None or end_date.value is None:
        return "Please select both a start date and an end date."
    if start_date.value > end_date.value:
        return "Start date must not be later than end date."
    if not sensor.value:
        return "Please select at least one sensor."
    # The four bound fields start at 0 (an empty box) until a polygon is drawn
    # or values are typed in.
    if min_lon.value >= max_lon.value or min_lat.value >= max_lat.value:
        return ("Invalid region: draw a polygon on the map or enter bounds "
                "with Min < Max.")
    return None


def _scale_to_reflectance(image_collection):
    """Apply the sensor-specific reflectance scaling to ``image_collection``.

    The routine (scale_reflectance_sentinel / scale_reflectance_landsat) is
    chosen from the FIRST selected sensor, so a mixed Sentinel-2 + Landsat
    selection is scaled with one routine; a warning is printed in that case.
    Unknown sensors leave the collection unscaled.
    """
    selected = sensor.value
    first = selected[0]
    chosen = set(selected)
    if chosen & set(_SENTINEL2_SENSORS) and chosen & set(_LANDSAT_SENSORS):
        print("Warning: Sentinel-2 and Landsat sensors are mixed; reflectance "
              "scaling follows the first selected sensor:", first)
    if first in _SENTINEL2_SENSORS:
        print('Input S2')
        return image_collection.map(scale_reflectance_sentinel)
    if first in _LANDSAT_SENSORS:
        print('Input Landsat')
        return image_collection.map(scale_reflectance_landsat)
    print("Unsupported sensor for reflectance conversion.", selected)
    return image_collection


def _preview_results(daily_collection, water_collection):
    """Add the RGB preview and the selected water-quality layers to the map."""
    print('start to map RGB image!')
    preview_rgb_image(daily_collection)
    print('start to map water quality parameters!')
    show_wq(water_collection)
    print("Processing complete!")


def button_ac(b):
    """Click handler for the Process button.

    Validates the GUI inputs and retrieves imagery for the selected dates,
    region and sensors. Imagery comes either from surface reflectance
    ('SR', via match_scenes) or from ACOLITE (via ACOLITE_run). The handler
    then mosaics same-day images (merge_by_day), masks them (mask_water /
    mndwi_mask) and adds RGB and water-quality layers to the map. All
    messages go to ``status_output``.

    Side effects: sets the module-level ``collection`` (retrieved images) and
    ``collection_day`` (same-day mosaics).

    Args:
        b: the clicked button widget (unused).
    """
    global collection
    global collection_day
    # Everything runs inside ONE `with status_output` block. The Output widget
    # displays and then suppresses exceptions raised inside it. With several
    # separate blocks, a failure in one step would let the following steps run
    # on undefined or stale (previous-run) data.
    with status_output:
        problem = _validate_gui_inputs()
        if problem is not None:
            print(problem)
            return

        limit = [min_lat.value, min_lon.value, max_lat.value, max_lon.value]
        start = start_date.value.isoformat()
        end = end_date.value.isoformat()
        sensors = ", ".join(sensor.value)
        method = atmospheric_correction.value

        if method == 'SR':
            print('Retrieve water quality maps with Surface Reflectance!')
            images, imColl = match_scenes(
                start, end, day_range=1,
                surface_reflectance=True,
                limit=limit,
                st_lat=None, st_lon=None, filter_tiles=None,
                sensors=sensors,
            )
            collection = imColl
            print("Total images:", len(images))
            if len(images) == 0:
                print("No images found for the selected dates, region and sensors.")
                return
            print("Image list: ", imColl.aggregate_array('system:index').getInfo())
            # NOTE: 'CLOUD_COVER' is presumably a Landsat-only property;
            # Sentinel-2 scenes may show up as missing here.
            print("Cloud cover: ", imColl.aggregate_array('CLOUD_COVER').getInfo())

            collection_scaled = _scale_to_reflectance(collection)

            # Mosaic images acquired on the same day, then mask clouds and land.
            collection_day = merge_by_day(collection_scaled)
            print("Total images after mosaic:", collection_day.size().getInfo())
            water_extracted_collection = collection_day.map(mask_water)
            print("Property names: ", water_extracted_collection.first().propertyNames().getInfo())
            print("Mosaic image list: ", water_extracted_collection.aggregate_array('custom_id').getInfo())
            print("Band names after masking: ", water_extracted_collection.first().bandNames().getInfo())

            _preview_results(collection_day, water_extracted_collection)

        elif method == 'ACOLITE':
            # TODO: ACOLITE options (output/target scale, glint correction,
            # stored products, output location/format) are not exposed in the
            # GUI yet; ACOLITE_run gets only the region, dates and sensors.
            print("Applying ACOLITE Atmospheric Correction...")
            collection = ACOLITE_run(limit, start, end, sensors)
            print("Atmospheric correction complete!")
            print('out_acolite type: ', type(collection))
            n_images = collection.size().getInfo()
            print('Number of images: ', n_images)
            if n_images == 0:
                print("No images found for the selected dates, region and sensors.")
                return
            print('Bands: ', collection.first().bandNames().getInfo())
            print('Properties: ', collection.first().propertyNames().getInfo())
            print('time_start: ', collection.first().get("time_start").getInfo())
            print('id: ', collection.first().get("custom_id").getInfo())

            collection_day = merge_by_day(collection)
            print("Size:", collection_day.size().getInfo())
            print('Bands: ', collection_day.first().bandNames().getInfo())
            print('time_start: ', collection_day.first().get("time_start").getInfo())
            print('id: ', collection_day.first().get("custom_id").getInfo())
            # Mask clouds and land with mndwi_mask.
            water_extracted_collection = collection_day.map(mndwi_mask)
            print("Band names after masking: ", water_extracted_collection.first().bandNames().getInfo())

            _preview_results(collection_day, water_extracted_collection)

        else:
            print("Unsupported atmospheric correction method.")
# Run button_ac whenever the Process button is clicked.
button_process.on_click(button_ac)

In [None]:
# @title Layout
################################################################################
#            Layout
################################################################################
# Right-hand side of the GUI: the map, 80% of the total width.
top_right_panel = widgets.VBox([m])
right_panel = widgets.VBox(
    [top_right_panel],
    layout=widgets.Layout(width='80%'),
)

# Parameter panel on the left, map on the right; render the whole GUI.
gui = widgets.HBox([param_widgets, right_panel])
display(gui)

HBox(children=(VBox(children=(HTML(value='\n    <div style="\n        color: white;\n        background: #2196…