In [1]:
!pip install geemap -q
!pip install rasterio -q

In [2]:
import ee
from ee.batch import Export
import geemap

import rasterio
from rasterio.windows import Window

import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from datetime import datetime
from time import time, sleep

import numpy as np

import os

from pathlib import Path

In [3]:

ee.Authenticate()
ee.Initialize(project='ee-yvon444')

In [4]:

S1 = ee.ImageCollection('COPERNICUS/S1_GRD')
S2SR_H = ee.ImageCollection('COPERNICUS/S2_HARMONIZED')
S2Clouds = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
FAO = ee.FeatureCollection('FAO/GAUL/2015/level1')


In [5]:

def collectionToList(collection):
    def accumulate(curr, prev):
        return ee.List(prev).add(curr)
    return ee.List(collection.iterate(accumulate, ee.List([])))

In [6]:


class OverlappingDataCollector:
    def __init__(self, area, start_date, end_date, unique_id, path,
                 orbit_pass='DESCENDING',
                 match_pref='CLOUDS',
                 image_size=256,
                 tile_size=2560,
                 max_elapsed_time=2 * 24 * 60 * 60 * 1000,
                 required_S1_bands=['VV', 'VH'],
                 required_S2_bands=['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12'],
                 required_S1_metadata=(['orbitProperties_pass', 'system:index', 'system:time_start'], ['S1_orbit_pass', 'S1_index', 'S1_time_start']),
                 required_S2_metadata=(['CLOUDY_PIXEL_PERCENTAGE', 'system:index', 'system:time_start'], ['S2_cloudy_percentage', 'S2_index', 'S2_time_start']),
                 max_overlap=None
                 ):

        self.area = area
        self.start_date = start_date
        self.end_date = end_date
        self.unique_id = unique_id
        self.path = Path(path)
        self.orbit_pass = orbit_pass # 'DESCENDING' or 'ASCENDING'
        self.match_pref = match_pref # match S1 with less cloudy S2 ('CLOUDS') or closest S2 in time ('TIME')
        self.tile_size = tile_size
        self.image_size = image_size
        self.max_elapsed_time = max_elapsed_time # how much time (ms) can there be between two corresponding S1 and S2 images?
        self.required_S1_bands = required_S1_bands
        self.required_S2_bands = required_S2_bands
        self.required_S1_metadata = required_S1_metadata if len(required_S1_metadata) == 2 else (required_S1_metadata, required_S1_metadata)
        self.required_S2_metadata = required_S2_metadata if len(required_S2_metadata) == 2 else (required_S2_metadata, required_S2_metadata)
        self.max_overlap = max_overlap

        self._calcCollections()
        self._calcOverlap()
        self._calcTiles()
        self._joinS1S2()
        self._prepareDownload()


    def _calcCollections(self):
        self.col_S1 = (self._getFilteredCollection(S1, self.area, self.start_date, self.end_date)
          .filter(ee.Filter.eq('instrumentMode', 'IW'))
          .filter(ee.Filter.eq('resolution', 'H'))
          .filter(ee.Filter.eq('resolution_meters', 10))
          .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
          .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
          )

        self.col_S1_asc = self.col_S1.filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING'))
        self.col_S1_desc = self.col_S1.filter(ee.Filter.eq('orbitProperties_pass', 'DESCENDING'))
        self.col_S1_pass = self.col_S1_asc if ORBIT_PASS == 'ASCENDING' else self.col_S1_desc

        self.col_S2 = self._getFilteredCollection(S2SR_H, self.area, self.start_date, self.end_date)
        self.col_S2 = ee.ImageCollection(ee.Join.saveFirst('cloud_prob').apply(**{
            'primary': self.col_S2,
            'secondary': S2Clouds,
            'condition': ee.Filter.equals(**{'leftField': 'system:index', 'rightField': 'system:index'})
            }))

    def _calcOverlap(self):
        overlap_S1 = self._getCollectionOverlap(self.col_S1)
        overlap_S2 = self._getCollectionOverlap(self.col_S2)
        self.overlap = overlap_S1.intersection(overlap_S2)
        self.overlap = self.overlap if self.max_overlap is None else self.overlap.intersection(self.max_overlap)
        self.overlap = self.overlap if self.overlap.area().getInfo() > 0 else None


    def _calcTiles(self):
        if self.overlap is not None:
            proj = self.col_S1.first().select('VV').projection()
            self.tiles = self.overlap.coveringGrid(proj, self.tile_size).filter(ee.Filter.isContained('.geo', self.overlap))
        else:
            self.tiles = None

    def _joinS1S2(self):
        filter_within_max_elapsed_time = ee.Filter.maxDifference(**{
            'difference': self.max_elapsed_time,
            'leftField': 'system:time_start',
            'rightField': 'system:time_start'})

        if self.match_pref == 'TIME':
            self.col_S1_S2 = ee.ImageCollection(ee.Join.saveBest('S2', 'elapsed_time').apply(**{
                'primary': self.col_S1_pass,
                'secondary': self.col_S2,
                'condition': filter_within_max_elapsed_time
                }))

        if self.match_pref == 'CLOUDS':
            self.col_S1_S2 = ee.ImageCollection(ee.Join.saveAll('S2').apply(**{
                'primary': self.col_S1_pass,
                'secondary': self.col_S2,
                'condition': filter_within_max_elapsed_time
                }))

            def find_least_cloudy(curr, prev):
                return ee.Algorithms.If(
                    ee.Number(ee.Image(curr).get('CLOUDY_PIXEL_PERCENTAGE')).lt(ee.Image(prev).get('CLOUDY_PIXEL_PERCENTAGE')),
                    curr,
                    prev)

            self.col_S1_S2 = self.col_S1_S2.map(
                lambda image: image.set('S2', ee.List(image.get('S2')).iterate(find_least_cloudy, ee.List(image.get('S2')).get(0)))
                )

    def _prepareDownload(self):
        proj = self.col_S1.first().select('VV').projection()
        col_download = self.col_S1_S2.map(
            lambda image:
             (image
              .select(self.required_S1_bands)
              .addBands(ee.Image(image.get('S2')), self.required_S2_bands)
              .addBands(ee.Image(ee.Image(image.get('S2')).get('cloud_prob')).rename(['s2cloudless']))
              .reproject(crs=proj, scale=10)
              .addBands(
                  image.select(['VV', 'VH']).multiply(-1000).toUint16(),
                  overwrite=True)
              ))

        fused_S1 = col_download.map(lambda image: image.select(self.required_S1_bands)).toBands()
        fused_S2 = col_download.map(lambda image: image.select(self.required_S2_bands)).toBands()
        fused_clouds = col_download.map(lambda image: image.select(['s2cloudless'])).toBands()

        self.fused_images = {
            'S1': fused_S1,
            'S2': fused_S2,
            'clouds': fused_clouds
            }

        self.metadata = ee.FeatureCollection(col_download.map(
            lambda image: (
                ee.Feature(image).select(self.required_S1_metadata[0], self.required_S1_metadata[1], retainGeometry=False)
                .copyProperties(ee.Feature(image.get('S2')).select(self.required_S2_metadata[0], self.required_S2_metadata[1], retainGeometry=False))
                .set('S1_bandnames', self.required_S1_bands)
                .set('S2_bandnames', self.required_S2_bands)
                )
            ))

    def pickTilesSubsample(self, tiles_subsample_percentage, seed=None):
        print(self.tiles.size().getInfo())
        self.tiles_subsample = self.tiles.randomColumn('p', seed).filter(ee.Filter.lt('p', tiles_subsample_percentage)).select(self.tiles.propertyNames())

    def exportToDrive(self):

        tiles_list = collectionToList(self.tiles_subsample)
        n_tiles = tiles_list.size().getInfo()

        os.mkdir(self.path / self.unique_id)

        self.tasks = []
        self.tasks.append(
            Export.table.toDrive(**{
                'collection': self.metadata,
                'description': f'{self.unique_id}_metadata',
                'folder': self.unique_id,
                'fileNamePrefix': f'{self.unique_id}_metadata',
                'fileFormat': 'CSV',
                'selectors': ['S1_bandnames', 'S2_bandnames'] + self.required_S1_metadata[1] + self.required_S2_metadata[1]
            }))

        self.tasks.append(
            Export.table.toDrive(**{
                'collection': self.tiles_subsample,
                'description': f'{self.unique_id}_tiles',
                'folder': self.unique_id,
                'fileNamePrefix': f'{self.unique_id}_tiles',
                'fileFormat': 'CSV',
                'selectors': ['.geo']
            })
        )

        for tile_i in range(n_tiles):
            tile = ee.Feature(tiles_list.get(tile_i))
            tile_geometry = tile.geometry()

            for name, image in self.fused_images.items():
                self.tasks.append(
                    Export.image.toDrive(**{
                        'image': image,
                        'description': f'{self.unique_id}_{tile_i}_{name}',
                        'folder': f'{self.unique_id}_{tile_i}',
                        'fileNamePrefix': f'{self.unique_id}_{tile_i}_{name}',
                        'region': tile_geometry,
                        'scale': 10,
                        'fileFormat': 'GeoTIFF'
                        }))

            os.mkdir(self.path / self.unique_id / f'{self.unique_id}_{tile_i}')

        sleep(10)
        for task in self.tasks:
            task.start()

    def postDownloadReshape(self):

        tiles_list = collectionToList(self.tiles_subsample)
        n_tiles = tiles_list.size().getInfo()

        for tile_i in range(n_tiles):
            path = self.path / self.unique_id / f'{self.unique_id}_{tile_i}'
            for raster in os.listdir(path):
                with rasterio.open(path / raster, 'r') as src:
                    profile = src.profile
                    h, w = src.shape
                    if (h < self.image_size) or (w < self.image_size):
                        raise Exception('Image too small')

                    profile['width'] = IMAGE_SIZE
                    profile['height'] = IMAGE_SIZE

                    clipped = src.read(window = rasterio.windows.Window(0, 0, IMAGE_SIZE, IMAGE_SIZE))

                with rasterio.open(path / raster, 'w', **profile) as src:
                    src.write(clipped)

    def plotMap(self):
        Map = geemap.Map()
        Map.addLayer(self.overlap, {'color': 'red'})
        Map.addLayer(self.tiles, {'color': 'black'})
        Map.centerObject(self.overlap, zoom=9)
        return Map

    def plotSummary(self):

        dates_S1 = self.col_S1_S2.aggregate_array('system:time_start').getInfo()
        dates_S2 = self.col_S1_S2.aggregate_array('S2').map(lambda image: ee.Image(image).get('system:time_start')).getInfo()

        vlines_S1 = [
            (mdates.date2num(datetime.fromtimestamp((x-self.max_elapsed_time)/1000)),
             mdates.date2num(datetime.fromtimestamp((x+self.max_elapsed_time)/1000))) for x in dates_S1
            ]

        dates_S1_fmt = [mdates.date2num(datetime.fromtimestamp(x / 1000)) for x in dates_S1]
        dates_S2_fmt = [mdates.date2num(datetime.fromtimestamp(x / 1000)) for x in dates_S2]

        cloud_probs = self.col_S1_S2.map(lambda image: ee.Image(image.get('S2')).get('cloud_prob'))
        cloud_probs = (collectionToList(cloud_probs)
          .map(lambda image: ee.Image(image).reduceRegion(ee.Reducer.mean(), self.tiles.geometry(), scale=150, bestEffort=True))
          .map(lambda x: ee.Dictionary(x).get('probability'))
        ).getInfo()
        cloud_probs = [x/100 for x in cloud_probs]

        fig, ax = plt.subplots(figsize=(15, 4))

        ax.set(title='Timeline of selected paired S1 S2 captures')

        ax.plot(dates_S2_fmt, np.ones_like(dates_S2_fmt) - 0.45, 's', color='k', markerfacecolor='blue', label='S2')
        ax.plot(dates_S1_fmt, np.zeros_like(dates_S1_fmt) + 0.45, 's', color='k', markerfacecolor='gray', label=f'S1 {"↑" if self.orbit_pass=="ASCENDING" else "↓"}')
        ax.plot(dates_S2_fmt, cloud_probs, '--o', color='darkblue', label='clouds', alpha=0.5)


        for x0, x1 in vlines_S1:
            ax.fill_betweenx([0.4, 0.6], x0, x1, color='gray', alpha=0.2)

        ax.set_ylim([-0.1,1.1])

        ax.yaxis.set_visible(False)
        ax.spines[["left", "top", "right"]].set_visible(False)

        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %Y"))
        ax.margins(y=0.1)
        ax.legend()

        return fig, ax


    @staticmethod
    def _getFilteredCollection(collection, geometry, start_date, end_date):
        return (collection
            .filterDate(start_date, end_date)
            .filterBounds(geometry)
            .filter(ee.Filter.contains('.geo', geometry))
            .sort('system:time_start')
           )

    @staticmethod
    def _getCollectionOverlap(collection):
        image = collection.first();
        if image.getInfo() is None:
            return ee.Geometry.MultiPoint([])
        overlap =  collection.geometry().geometries().iterate(
            lambda curr, prev: ee.Geometry(prev).intersection(curr), image.geometry())
        return ee.Geometry(overlap)




In [7]:


class RandomPointsDataCollector:

    def __init__(self, area, n_samples, n_tiles_target, start_date, end_date, unique_id, path,
                 buffer_region=10000,
                 buffer_intersection=20000,
                 seed=12,
                 **kwargs):
        self.area = area
        self.n_samples = n_samples
        self.n_tiles_target = n_tiles_target
        self.start_date = start_date
        self.end_date = end_date
        self.unique_id = unique_id
        self.path = Path(path)

        self.buffer_region = buffer_region
        self.buffer_intersection = buffer_intersection
        self.seed = seed

        self.kwargs = kwargs

        self._samplePoints()
        self._createDataCollectors()
        self._pickTilesSubsample()

    def _samplePoints(self):

        def filter_points(curr, prev):
            curr = ee.Feature(curr)
            prev = ee.FeatureCollection(prev)
            return ee.Algorithms.If(
                curr.buffer(self.buffer_intersection).intersects(prev.geometry().buffer(self.buffer_intersection)),
                prev,
                prev.merge(ee.FeatureCollection([curr]))
                )

        self.points = ee.FeatureCollection.randomPoints(self.area.buffer(-self.buffer_region), self.n_samples, self.seed)
        self.points = ee.FeatureCollection(self.points.iterate(filter_points, ee.FeatureCollection([])))
        self.points = collectionToList(self.points).getInfo()

    def _createDataCollectors(self):

        self.data_collectors = []
        for i, point in enumerate(self.points):
            point = ee.Geometry(point['geometry'])

            data_collector = OverlappingDataCollector(point, self.start_date, self.end_date, f'{self.unique_id}_{i}', self.path / self.unique_id,
                                                      max_overlap=point.buffer(self.buffer_intersection), **self.kwargs)
            self.data_collectors.append(data_collector)

    def _pickTilesSubsample(self):

        n_tiles = sum([dc.tiles.size().getInfo() for dc in self.data_collectors])
        p = min(self.n_tiles_target / n_tiles, 1)

        for data_collector in self.data_collectors:
            # n = data_collector.tiles.size().getInfo()
            # if n == 0:
            #     continue

            data_collector.pickTilesSubsample(p, self.seed)


    def plotMap(self):
        Map = geemap.Map()
        Map.addLayer(self.area, {'color': 'orange'})

        for data_collector in self.data_collectors:
            Map.addLayer(data_collector.tiles_subsample, {'color': 'black'})

        Map.centerObject(self.area, zoom=9)
        return Map

    def plotSummary(self):

        return [dc.plotSummary()[0] for dc in self.data_collectors if dc.tiles.size().getInfo() > 0]


    def exportToDrive(self):

        os.mkdir(self.path / self.unique_id)

        for data_collector in self.data_collectors:
            if data_collector.tiles_subsample.size().getInfo() > 0:
                data_collector.exportToDrive()

    def postDownloadReshape(self):

        for data_collector in self.data_collectors:
            if data_collector.tiles_subsample.size().getInfo() > 0:
                data_collector.postDownloadReshape()



In [8]:

IMAGE_SIZE = 256
TILE_SIZE = 2560
START_DATE = ee.Date('2022-01-01')
END_DATE = ee.Date('2023-01-01')
MAX_ELAPSED_TIME = 2 * 24 * 60 * 60 * 1000 # how much time (ms) can there be between a S1 and S2 image?
ORBIT_PASS = 'DESCENDING'
MATCH_S1_S2 = 'CLOUDS' # match S1 and S2 on least 'CLOUDS' or least 'TIME' elapsed
UNIQUE_ID = 'BIHAR_2022_2023_TEST'
PATH = './drive/MyDrive/deCloud/deCloudData/'

REQUIRED_S1_BANDS = ['VV', 'VH']
REQUIRED_S2_BANDS = ['B2', 'B3', 'B4', 'B8']

bihar = FAO.filter(ee.Filter.eq('ADM1_NAME', 'Bihar')).first().geometry()

n_samples = 20
n_tiles_target = 200

dc = RandomPointsDataCollector(
    area=bihar,
    n_samples=n_samples,
    n_tiles_target=n_tiles_target,
    start_date=START_DATE,
    end_date=END_DATE,
    unique_id=UNIQUE_ID,
    path=PATH,
    buffer_region=1000,
    buffer_intersection=2000,
    seed=12,
    orbit_pass=ORBIT_PASS,
    match_pref=MATCH_S1_S2,
    image_size=256,
    tile_size=2560,
    max_elapsed_time=2 * 24 * 60 * 60 * 1000,
    required_S1_bands=REQUIRED_S1_BANDS,
    required_S2_bands=REQUIRED_S2_BANDS,
    required_S1_metadata=(['orbitProperties_pass', 'system:index', 'system:time_start'], ['S1_orbit_pass', 'S1_index', 'S1_time_start']),
    required_S2_metadata=(['CLOUDY_PIXEL_PERCENTAGE', 'system:index', 'system:time_start'], ['S2_cloudy_percentage', 'S2_index', 'S2_time_start']),
)


ZeroDivisionError: division by zero

In [None]:

# visualize what we'
x = dc.plotSummary()

In [None]:
dc.plotMap()