<a href="https://colab.research.google.com/github/SERVIR/flood_mapping_intercomparison/blob/main/hydrafloods/training_materials/oct_2021_hf_training/notebooks/supplementary/cloud_shadow_masking_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the stated intent is to keep
# credentials there between Colab sessions.
# NOTE(review): no cell in this notebook reads or writes /content/drive — confirm it is still needed.
import google.colab.drive as drive
drive.mount('/content/drive')

In [None]:
# install the packages needed
!pip install hydrafloods geemap -q

In [None]:
import ee
import datetime
import hydrafloods as hf
import geemap.eefolium as geemap
import geemap.colormaps as cm

In [None]:
# Initiate the Earth Engine authentication workflow: creating a geemap Map asks
# for authentication if no credentials are available and also initializes the
# ee session. The map object itself is not needed, so it is bound to `_`.
_ = geemap.Map()

In [None]:
# Study area (bounding box of Guatemala) and the time window of interest.
region = hf.country_bbox("Guatemala")
start_time, end_time = "2019-01-01", "2020-01-01"

# Landsat 8 dataset from hydrafloods; this is the "QAed" reference shown in the
# final comparison map.
lc8 = hf.Landsat8(region, start_time, end_time)

In [None]:
# Compute MNDWI for every image of the QA'd Landsat 8 dataset (hf.mndwi applied per image).
mndwi = lc8.apply_func(hf.mndwi)

In [None]:
# False-colour (SWIR2 / NIR / green) display parameters; min/max are in the
# scaled-integer units of the surface-reflectance bands.
optical_vis = dict(
    min=50,
    max=5500,
    bands="swir2,nir,green",
    gamma=1.5,
)

In [None]:
# Compare the annual median composite with the median MNDWI.
Map = geemap.Map(center=(15.6336, -90.1208), zoom=8)

lc8_median = lc8.collection.median()
mndwi_median = mndwi.collection.median()
mndwi_vis = {"min": -0.5, "max": 0.75, "palette": cm.palettes.Blues}

Map.addLayer(lc8_median, optical_vis, 'Landsat 8 mosaic')
Map.addLayer(mndwi_median, mndwi_vis, 'Landsat 8 MNDWI')

Map.addLayerControl()
Map

In [None]:
import math

In [None]:
@hf.decorators.keep_attrs
def cloud_mask(img):
    """Append a 'cloud_mask' band derived from bit 5 of the 'pixel_qa' band.

    The new band is 1 where bit 5 is 0 (treated here as clear) and 0 where it
    is set (cloud). The image itself is NOT masked; later steps read the
    band's values.

    Args:
        img: ee.Image containing a 'pixel_qa' band.

    Returns:
        ee.Image: ``img`` with the extra 'cloud_mask' band.
    """
    cloud_bit = hf.extract_bits(img.select("pixel_qa"), start=5, new_name="cloud_mask")
    is_clear = cloud_bit.eq(0)
    return img.addBands(is_clear)

In [None]:
def simple_TDOM2(collection, zScoreThresh, shadowSumThresh, dilatePixels):
    """Flag dark (likely cloud-shadow) pixels with a temporal z-score test and mask them.

    Simplified Temporal Dark Outlier Mask (TDOM). A pixel is a shadow
    candidate when at least two of the infrared bands are below
    ``zScoreThresh`` standard deviations from the collection mean AND the sum
    of those bands is below ``shadowSumThresh``. The candidate areas are grown
    by ``dilatePixels`` and combined with each image's 'cloud_mask' band.

    Args:
        collection: ee.ImageCollection whose images have 'nir', 'swir1',
            'swir2' and a 'cloud_mask' band (values: 1 = clear, 0 = cloud).
        zScoreThresh: z-score below which a band counts as anomalously dark.
        shadowSumThresh: upper bound on nir + swir1 + swir2 for a shadow
            pixel, in the same units as the bands.
        dilatePixels: focal_min radius used to grow the flagged shadow areas.

    Returns:
        ee.ImageCollection where each image gains 'TDOM_mask' (1 = not shadow)
        and 'cloud_shadow_mask' (1 = clear of cloud and shadow) bands and is
        masked wherever 'cloud_shadow_mask' is 0. The 'cloud_mask' band is
        re-attached unmasked so later steps can still read cloud values.
    """
    shadowSumBands = ['nir', 'swir1', 'swir2']
    # Minimum number of shadowSumBands that must be anomalously dark. The
    # original two-band TDOM used `.eq(2)` (= all bands); with three bands
    # `.eq(2)` would wrongly exclude pixels where all three are dark.
    minDarkBands = 2

    irStdDev = collection.select(shadowSumBands).reduce(ee.Reducer.stdDev())
    irMean = collection.select(shadowSumBands).mean()

    @hf.decorators.keep_attrs
    def darkMask(img):
        irBands = img.select(shadowSumBands)
        zScore = irBands.subtract(irMean).divide(irStdDev)
        irSum = irBands.reduce(ee.Reducer.sum())
        isShadow = (
            zScore.lt(zScoreThresh).reduce(ee.Reducer.sum()).gte(minDarkBands)
            .And(irSum.lt(shadowSumThresh))
        )
        # 1 = not shadow; focal_min grows the 0 (shadow) regions.
        TDOMMask = isShadow.Not().focal_min(dilatePixels)
        img = img.addBands(TDOMMask.rename(['TDOM_mask']))

        # 'cloud_mask' carries its information in its VALUES (1 = clear) and is
        # not itself masked, so `.mask()` would ignore clouds; use the values.
        cloudFree = img.select('cloud_mask')
        combinedMask = cloudFree.And(img.select('TDOM_mask')).rename('cloud_shadow_mask')
        masked = img.addBands(combinedMask).updateMask(combinedMask)
        # Re-attach the unmasked cloud flag so downstream steps that read its
        # values under clouds (e.g. cloud_project) still see them.
        return masked.addBands(cloudFree, overwrite=True)

    return collection.map(darkMask)


In [None]:
@hf.decorators.keep_attrs
def cloud_project(img):
    """Mask clouds and the shadows they are projected to cast.

    Clouds come from the 'cloud_mask' band (1 = clear, so `.Not()` gives
    cloud = 1) and are dilated by `dilatePixels`. The cloud layer is shifted
    along the solar geometry once per height in `cloudHeights`. The maximum of
    those shifted layers, limited to non-cloud pixels that are dark
    (nir + swir1 + swir2 < `shadowSumThresh`) and inside the TDOM flag, is
    treated as shadow.

    Args:
        img: ee.Image with 'cloud_mask', 'TDOM_mask', 'nir', 'swir1', 'swir2'
            bands and 'SOLAR_AZIMUTH_ANGLE' / 'SOLAR_ZENITH_ANGLE' properties
            (assumed to be in degrees, given the deg->rad conversion below).

    Returns:
        ee.Image masked wherever cloud or projected shadow was found, with a
        'cloudShadowMask' band added (1 = kept) and clipped to its geometry.

    NOTE(review): reads the module-level names `dilatePixels`, `cloudHeights`,
    `shadowSumThresh` and `math`, which are defined in other cells, so those
    cells must run before this function is applied.
    NOTE(review): in this notebook the input has already been masked by
    simple_TDOM2, so 'TDOM_mask' is masked wherever it was 0. `TDOMMask` below
    is then 0 wherever it is unmasked, which appears to leave no projected
    shadow pixels. Confirm the intended ordering.
    """

    azimuthField = 'SOLAR_AZIMUTH_ANGLE'
    zenithField = 'SOLAR_ZENITH_ANGLE'

    # Shift the cloud layer by the offset implied by one cloud height
    # (mapped server-side over cloudHeights; uses zenR/azR/nominalScale/cloud/proj
    # from the enclosing scope, which are assigned before the map runs).
    def projectHeights(cloudHeight):
      cloudHeight = ee.Number(cloudHeight);
      # NOTE(review): zenR = 90deg - zenith, so this is height * cot(zenith);
      # simple geometry suggests height * tan(zenith). The sweep over many heights
      # may absorb this. Confirm.
      shadowCastedDistance = zenR.tan().multiply(cloudHeight); #Distance shadow is cast
      # Offsets are converted to pixels by dividing by the nominal scale.
      x = azR.cos().multiply(shadowCastedDistance).divide(nominalScale).round(); #X distance of shadow
      y = azR.sin().multiply(shadowCastedDistance).divide(nominalScale).round(); #Y distance of shadow
      return cloud.changeProj(proj, proj.translate(x, y));

    # Get the cloud mask (1 = cloud after Not), dilate it, and keep only cloud pixels
    cloud = img.select(['cloud_mask']).Not();
    cloud = cloud.focal_max(dilatePixels);
    cloud = cloud.updateMask(cloud);

    # Get TDOM mask (TDOM_mask is 1 = not shadow, so this is 1 = TDOM shadow)
    TDOMMask = img.select(['TDOM_mask']).Not();

    # Project the shadow finding pixels inside the TDOM mask that are dark and
    # inside the expected area given the solar geometry
    # Find dark pixels
    darkPixels = img.select(['nir','swir1','swir2'])\
      .reduce(ee.Reducer.sum()).lt(shadowSumThresh);#.gte(1);

    proj = img.select('cloud_mask').projection()

    # Get scale of image
    nominalScale = proj.nominalScale();

    #Find where cloud shadows should be based on solar geometry
    #Convert to radians
    meanAzimuth = img.get(azimuthField);
    meanZenith = img.get(zenithField);
    # azR = azimuth + 90deg, zenR = 90deg - zenith (both in radians)
    azR = ee.Number(meanAzimuth).multiply(math.pi).divide(180.0)\
      .add(ee.Number(0.5).multiply(math.pi));
    zenR = ee.Number(0.5).multiply(math.pi)\
      .subtract(ee.Number(meanZenith).multiply(math.pi).divide(180.0));

    # Find the shadows (one shifted cloud layer per candidate height)
    shadows = cloudHeights.map(projectHeights);

    shadow = ee.ImageCollection.fromImages(shadows).max();

    # Create shadow mask: outside clouds, dilated, then limited to dark TDOM pixels
    shadow = shadow.updateMask(cloud.mask().Not());
    shadow = shadow.focal_max(dilatePixels);
    shadow = shadow.updateMask(darkPixels.And(TDOMMask));

    # Combine the cloud and shadow masks (1 = neither cloud nor shadow)
    combinedMask = cloud.mask().Or(shadow.mask()).eq(0);

    # Update the image's mask and return the image
    img = img.updateMask(combinedMask);
    img = img.addBands(combinedMask.rename(['cloudShadowMask']));
    return img.clip(img.geometry());


In [None]:
# Parameters for the cloud/shadow masking steps. simple_TDOM2 receives them as
# arguments; cloud_project reads dilatePixels, cloudHeights and shadowSumThresh
# as globals when it is applied, so this cell must run before that.
dilatePixels = 2                                 # radius (pixels) used to grow cloud/shadow masks
cloudHeights = ee.List.sequence(200, 5000, 500)  # candidate cloud heights 200, 700, ..., 4700 (presumably metres)
zScoreThresh = -0.8                              # z-score below which an IR band counts as "dark"
# Upper bound on nir + swir1 + swir2 for a shadow pixel. The usual 0.35 is a
# reflectance value in [0, 1], but these SR bands are stored as scaled integers
# (reflectance x 10000; compare the 50-5500 optical_vis range used on them), so
# an unscaled 0.35 would essentially never be met.
shadowSumThresh = 3500  # 0.35 reflectance x 10000

In [None]:
# Landsat 8 SR collection with hydrafloods' QA disabled (use_qa=False), so the
# 'pixel_qa' band is available for the custom cloud/shadow masking below.
lc8_noqa_raw = hf.Dataset(
    region,
    start_time,
    end_time,
    asset_id="LANDSAT/LC08/C01/T1_SR",
    use_qa=False,
)

# Rename SR band IDs to the names used elsewhere in this notebook ('nir', 'swir1', ...).
band_renames = {
    "B2": "blue",
    "B3": "green",
    "B4": "red",
    "B5": "nir",
    "B6": "swir1",
    "B7": "swir2",
    "pixel_qa": "pixel_qa",
}
lc8_noqa = lc8_noqa_raw.select(list(band_renames), list(band_renames.values()))

In [None]:
# Add a 'cloud_mask' band (1 = clear, 0 = cloud) to every image; nothing is masked at this step.
lc8_cloud_masked = lc8_noqa.apply_func(cloud_mask)

In [None]:
# Run the temporal dark-outlier (TDOM) shadow test on the cloud-flagged
# collection and wrap the result in a copy of the dataset object.
tdom_collection = simple_TDOM2(
    lc8_cloud_masked.collection,
    zScoreThresh,
    shadowSumThresh,
    dilatePixels,
)
lc8_cloud_shadow_masked = lc8_cloud_masked.copy()
lc8_cloud_shadow_masked.collection = tdom_collection

In [None]:
# Project clouds along the solar geometry to locate their shadows and mask both;
# cloud_project needs the 'cloud_mask' and 'TDOM_mask' bands added in the previous steps.
lc8_projected_clouds_masked = lc8_cloud_shadow_masked.apply_func(cloud_project)

In [None]:
# First image of each dataset, for a visual check of the masking stages.
# NOTE(review): `lc8` comes from hf.Landsat8 while the others share the no-QA
# chain, so its first image is not guaranteed to be the same scene. Confirm.
_stage_datasets = (
    lc8,
    lc8_noqa,
    lc8_cloud_masked,
    lc8_cloud_shadow_masked,
    lc8_projected_clouds_masked,
)
(
    lc8_first,
    lc8_noqa_first,
    lc8_cloud_masked_first,
    lc8_cloud_shadow_masked_first,
    lc8_projected_clouds_masked_first,
) = [dataset.collection.first() for dataset in _stage_datasets]

In [None]:
# Compare the first scene at each masking stage.
# cloud_mask() only ADDS a 'cloud_mask' band (1 = clear) and never masks the
# image, so the "Clouds Masked" layer applies that band explicitly; otherwise
# it would render identically to the "No QA" layer.
Map = geemap.Map(center=(15.6336, -90.1208), zoom=8)

lc8_cloud_masked_first_display = lc8_cloud_masked_first.updateMask(
    lc8_cloud_masked_first.select('cloud_mask')
)

Map.addLayer(lc8_first, optical_vis, 'Landsat 8 QAed')
Map.addLayer(lc8_noqa_first, optical_vis, 'Landsat 8 No QA')
Map.addLayer(lc8_cloud_masked_first_display, optical_vis, 'Landsat 8 Clouds Masked')
Map.addLayer(lc8_cloud_shadow_masked_first, optical_vis, 'Landsat 8 Clouds+Shadows Masked')
# TDOM_mask: 1 = not shadow. NOTE(review): this image was already masked by
# simple_TDOM2, so shadow pixels (0) render as transparent rather than dark.
Map.addLayer(lc8_cloud_shadow_masked_first, {"bands": "TDOM_mask"}, 'Landsat 8 Shadows Mask')
Map.addLayer(lc8_projected_clouds_masked_first, optical_vis, 'Landsat 8 Full QA')

Map.addLayerControl()
Map