[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adugnag/deSpeckNet-TF-GEE/blob/main/notebooks/test.ipynb)

# Set up software libraries



In [None]:
# Cloud authentication: authorize this Colab session's Google account so the
# notebook can reach Google Cloud resources (e.g. the GCS bucket that the
# export/prediction helpers read and write).
from google.colab import auth
auth.authenticate_user()

In [None]:
# Import, authenticate and initialize the Earth Engine library.
import ee
# Obtain Earth Engine credentials for this session; must run before Initialize().
ee.Authenticate()
# NOTE(review): recent earthengine-api releases expect a Cloud project here,
# e.g. ee.Initialize(project='...') — confirm for the installed version.
ee.Initialize()

In [None]:
# TensorFlow drives the tf.data input pipeline and the Keras model; NumPy is
# made available for array work.
import tensorflow as tf
import numpy as np

# Report the installed TensorFlow version.
installed_tf = tf.__version__
print(installed_tf)

In [None]:
# Folium setup: folium renders the Earth Engine map tiles in the notebook.
import folium

installed_folium = folium.__version__
print(installed_folium)

In [None]:
#@title Helper functions
def lin_to_db(image):
    """
    Convert every band except 'angle' from linear scale to decibels.

    The converted bands keep their original names and replace the linear
    bands in ``image``; the 'angle' band (if present) is left untouched.

    Parameters
    ----------
    image : ee.Image
        Image whose non-'angle' bands hold linear backscatter values.

    Returns
    -------
    ee.Image
        ``image`` with each non-'angle' band replaced by 10 * log10(value).
    """
    # All bands except 'angle' are converted.
    backscatter_bands = image.bandNames().remove('angle')
    log_values = image.select(backscatter_bands).log10()
    db = ee.Image.constant(10).multiply(log_values).rename(backscatter_bands)
    # Third argument (overwrite=True) replaces the same-named linear bands.
    return image.addBands(db, None, True)


def s1_prep(params):
    """
    Build an analysis-ready Sentinel-1 GRD collection from a parameter dict.

    Filters ``COPERNICUS/S1_GRD_FLOAT`` to IW-mode, 10 m, Sentinel-1A scenes
    that contain a VH channel, inside the date range and ROI.  It then
    optionally restricts the orbit pass / relative orbit, keeps the requested
    polarization bands plus 'angle', clips to the ROI and converts the
    backscatter to dB with ``lin_to_db``.

    Parameters
    ----------
    params : dict
        Required keys: 'START_DATE', 'STOP_DATE', 'ROI'.
        Optional keys (the default is used when the key is missing or None):
          'POLARIZATION'           'VV' | 'VH' | 'VVVH'             (default 'VVVH')
          'FORMAT'                 'LINEAR' | 'DB'                  (default 'DB')
          'ORBIT'                  'ASCENDING' | 'DESCENDING' | 'BOTH'
                                                                    (default 'DESCENDING')
          'RELATIVE_ORBIT_NUMBER'  orbit number, or 'ANY'           (default 'ANY')
          'CLIP_TO_ROI'            bool                             (default False)

    Returns
    -------
    ee.ImageCollection
        The filtered, band-selected (and optionally clipped / dB) collection.

    Raises
    ------
    KeyError
        If a required key is missing.
    ValueError
        If POLARIZATION, ORBIT or FORMAT is not one of the allowed values.
    """

    def _param(key, default):
        # Missing keys and explicit None both fall back to the default.
        value = params.get(key)
        return default if value is None else value

    START_DATE = params['START_DATE']
    STOP_DATE = params['STOP_DATE']
    ROI = params['ROI']
    POLARIZATION = _param('POLARIZATION', 'VVVH')
    FORMAT = _param('FORMAT', 'DB')
    ORBIT = _param('ORBIT', 'DESCENDING')
    RELATIVE_ORBIT_NUMBER = _param('RELATIVE_ORBIT_NUMBER', 'ANY')
    CLIP_TO_ROI = _param('CLIP_TO_ROI', False)

    ###########################################
    # 0. CHECK PARAMETERS (before any Earth Engine call)
    ###########################################

    # Bands kept for each polarization option; 'angle' is always retained.
    pol_bands = {
        'VV': ['VV', 'angle'],
        'VH': ['VH', 'angle'],
        'VVVH': ['VV', 'VH', 'angle'],
    }
    if POLARIZATION not in pol_bands:
        raise ValueError("ERROR!!! Parameter POLARIZATION not correctly defined")

    if ORBIT not in ('ASCENDING', 'DESCENDING', 'BOTH'):
        raise ValueError("ERROR!!! Parameter ORBIT not correctly defined")

    if FORMAT not in ('LINEAR', 'DB'):
        raise ValueError("ERROR!!! FORMAT not correctly defined")

    ###########################################
    # 1. IMPORT COLLECTION
    ###########################################

    s1 = (ee.ImageCollection('COPERNICUS/S1_GRD_FLOAT')
          .filter(ee.Filter.eq('instrumentMode', 'IW'))
          .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
          .filter(ee.Filter.eq('resolution_meters', 10))
          .filter(ee.Filter.eq('platform_number', 'A'))
          .filterDate(START_DATE, STOP_DATE)
          .filterBounds(ROI))

    ###########################################
    # 2. SELECT ORBIT AND POLARIZATION
    ###########################################

    if ORBIT != 'BOTH':
        s1 = s1.filter(ee.Filter.eq('orbitProperties_pass', ORBIT))

    if RELATIVE_ORBIT_NUMBER != 'ANY':
        s1 = s1.filter(
            ee.Filter.eq('relativeOrbitNumber_start', RELATIVE_ORBIT_NUMBER))

    s1 = s1.select(pol_bands[POLARIZATION])

    ###########################################
    # 3. CLIP TO ROI
    ###########################################

    if CLIP_TO_ROI:
        s1 = s1.map(lambda image: image.clip(ROI))

    ###########################################
    # 4. FORMAT CONVERSION (linear -> dB)
    ###########################################

    if FORMAT == 'DB':
        s1 = s1.map(lin_to_db)

    return s1

#@title Helper functions for export and testing

def export(image, params):
  """Export ``image`` as TFRecord patches to Cloud Storage; block until done.

  Args:
    image: ee.Image to export; only the ``params['BANDS']`` bands are written.
    params: parameter dict. Reads BANDS, IMAGE_PREFIX, BUCKET, FOLDER,
      KERNEL_SHAPE (``patchDimensions``) and KERNEL_BUFFER (``kernelSize``).
      The export region is GEOMETRY when present, otherwise ROI.

  Returns:
    True if the task ended in the COMPLETED state, False otherwise.
  """
  import time

  # GEOMETRY is an optional override; fall back to the query ROI.
  region = params.get('GEOMETRY')
  if region is None:
    region = params['ROI']

  task = ee.batch.Export.image.toCloudStorage(
    image=image.select(params['BANDS']),
    description=params['IMAGE_PREFIX'],
    bucket=params['BUCKET'],
    fileNamePrefix=params['FOLDER'] + '/' + params['IMAGE_PREFIX'],
    region=region.getInfo()['coordinates'],
    scale=10,
    fileFormat='TFRecord',
    maxPixels=1e13,
    formatOptions={
      'patchDimensions': params['KERNEL_SHAPE'],
      'kernelSize': params['KERNEL_BUFFER'],
      'compressed': True,
      'maxFileSize': 104857600,
    },
  )
  task.start()

  # Block until the task completes, polling every 30 seconds.
  print('Running image export to Cloud Storage...')
  while task.active():
    time.sleep(30)

  # Error condition: surface the task's own error message when it has one.
  status = task.status()
  if status['state'] != 'COMPLETED':
    print('Error with image export.')
    if status.get('error_message'):
      print(status['error_message'])
    return False

  print('Image export completed.')
  return True

def prediction(params):
  """Perform inference on exported imagery, upload to Earth Engine.
  """

  print('Looking for TFRecord files...')

  # Get a list of all the files in the output bucket.
  filesList = !gsutil ls 'gs://'{params['BUCKET']}'/'{params['FOLDER']}

  # Get only the files generated by the image export.
  exportFilesList = [s for s in filesList if params['IMAGE_PREFIX'] in s]

  # Get the list of image files and the JSON mixer file.
  imageFilesList = []
  jsonFile = None
  for f in exportFilesList:
    if f.endswith('.tfrecord.gz'):
      imageFilesList.append(f)
    elif f.endswith('.json'):
      jsonFile = f

  # Make sure the files are in the right order.
  imageFilesList.sort()

  from pprint import pprint
  pprint(imageFilesList)
  print(jsonFile)

  import json
  # Load the contents of the mixer file to a JSON object.
  jsonText = !gsutil cat {jsonFile}
  # Get a single string w/ newlines from the IPython.utils.text.SList
  mixer = json.loads(jsonText.nlstr)
  pprint(mixer)
  patches = mixer['totalPatches']

  # Get set up for prediction.
  x_buffer = int(params['KERNEL_BUFFER'][0] / 2)
  y_buffer = int(params['KERNEL_BUFFER'][1] / 2)

  buffered_shape = [
      params['KERNEL_SHAPE'][0] + params['KERNEL_BUFFER'][0],
      params['KERNEL_SHAPE'][1] + params['KERNEL_BUFFER'][1]]

  imageColumns = [
    tf.io.FixedLenFeature(shape=buffered_shape, dtype=tf.float32) 
      for k in BANDS
  ]

  imageFeaturesDict = dict(zip(params['BANDS'], imageColumns))

  def parse_image(example_proto):
    return tf.io.parse_single_example(example_proto, imageFeaturesDict)

  def toTupleImage(inputs):
    inputsList = [inputs.get(key) for key in params['BANDS']]
    stacked = tf.stack(inputsList, axis=0)
    stacked = tf.transpose(stacked, [1, 2, 0])
    #stacked = tf.reshape(tensor = stacked , shape = [NR_IMAGES, 32 , 32 ,len(BAND_MODE)])
    return stacked

   # Create a dataset from the TFRecord file(s) in Cloud Storage.
  imageDataset = tf.data.TFRecordDataset(imageFilesList, compression_type='GZIP')
  imageDataset = imageDataset.map(parse_image, num_parallel_calls=5)
  imageDataset = imageDataset.map(toTupleImage).batch(1)

  # Perform inference.
  print('Running predictions...')
  predictions = model.predict(imageDataset, steps=patches, verbose=1)
  predictions = predictions[0]
  #predictions = predictions.argmax(axis=3)
  print(len(predictions))
  print(predictions[0].shape)
 

  print('Writing predictions...')
  out_image_file = 'gs://' + params['BUCKET'] + '/' + params['FOLDER'] + '/' + params['MODEL_NAME'] + '.TFRecord'
  writer = tf.io.TFRecordWriter(out_image_file)
  patches = 0

  for predictionPatch in predictions:
    print('Writing patch ' + str(patches) + '...')
    predictionPatch = predictionPatch[
        x_buffer:x_buffer+params['KERNEL_SIZE'], y_buffer:y_buffer+params['KERNEL_SIZE'],:]

    if params['POLARIZATION'] == 'VVVH':
    # Create an example.
      example = tf.train.Example(
        features=tf.train.Features(
          feature={
            'VV': tf.train.Feature(
                float_list=tf.train.FloatList(
                    value=predictionPatch[:,:,0].flatten())),
            'VH': tf.train.Feature(
                float_list=tf.train.FloatList(
                    value=predictionPatch[:,:,1].flatten()))
          }
        )
      )
    else:
      example = tf.train.Example(
        features=tf.train.Features(
          feature={
            params['POLARIZATION']: tf.train.Feature(
                float_list=tf.train.FloatList(
                    value=predictionPatch.flatten()))
          }
        )
      )
    # Write the example.
    writer.write(example.SerializeToString())
    patches += 1

  writer.close()

  # Start the upload.
  out_image_asset = USER_ID + '/' + params['MODEL_NAME']
  !earthengine upload image --asset_id={out_image_asset} {out_image_file} {jsonFile}

# Data Prep



In [None]:
# Test image area.
# `geometry` is the region of interest passed to s1_prep() as params['ROI'];
# `geometry2` is an alternative test area that this cell does not use.

# ROI: rectangle given as [lon, lat] vertices.
geometry =  ee.Geometry.Polygon(
        [[[103.08000490033993, -2.8225068747308946],
          [103.08000490033993, -2.9521181019620673],
         [103.29217836225399, -2.9521181019620673],
         [103.29217836225399, -2.8225068747308946]]])

geometry2 =     ee.Geometry.Polygon(
        [[[103.28423388261817, -2.666639235594898],
          [103.28423388261817, -2.7983252476718885],
          [103.47786791582129, -2.7983252476718885],
          [103.47786791582129, -2.666639235594898]]])

# Parameters passed to s1_prep(), export() and prediction().
params = {   # Sentinel-1 query settings (read by s1_prep)
           'START_DATE': '2021-12-01', 
            'STOP_DATE': '2021-12-31',        
            'ORBIT': 'DESCENDING',
            'RELATIVE_ORBIT_NUMBER':18, 
            'POLARIZATION': 'VVVH',
            'ROI':    geometry,
            'FORMAT': 'DB',
            'CLIP_TO_ROI': True,
            # Storage locations. EXPORT and DRIVE are not read by the code in
            # this notebook; BUCKET/FOLDER locate the exported patches and the model.
            'EXPORT': 'GCS',
            'BUCKET' : 'senalerts_dl3',
            'DRIVE' : '/content/drive',
            'FOLDER' : 'deSpeckNet',
            # Earth Engine asset root for the uploaded prediction.
            'USER_ID' : 'users/adugnagirma',
            'IMAGE_PREFIX' : 'deSpeckNet_TEST_PATCH_v3_',
          # Should be the same bands selected during data prep
            'BANDS': ['VV', 'VH'],
            # NOTE(review): RESPONSE_TR/RESPONSE_TU/MASK are not read here;
            # presumably training-time settings — confirm before removing.
            'RESPONSE_TR' : ['VV_median', 'VH_median'],
            'RESPONSE_TU' : ['VV', 'VH'],
            'MASK' : ['VV_mask', 'VH_mask'],
            # Patch geometry: exported patch size plus the extra buffer
            # (EE kernelSize) added around each patch.
            'KERNEL_SIZE' : 256,
            'KERNEL_SHAPE' : [256, 256],
            'KERNEL_BUFFER' : [128, 128],
            # Model folder name in FOLDER, also used as the output asset name.
            'MODEL_NAME': 'tune'
            }

# Process the Sentinel-1 image collection.
s1_processed = s1_prep(params)
# Drop the 'angle' band so only the backscatter bands remain
# (band names are taken from the first image).
bandNames = s1_processed.first().bandNames().remove('angle')
s1_processed = s1_processed.select(bandNames)
print('Number of images in the collection: ', s1_processed.size().getInfo())

# The first image of the collection is the one exported and despeckled below.
image = s1_processed.first()
# Model input bands (Sentinel-1 backscatter), read from the image itself.
BANDS = image.bandNames().getInfo()

# Visualize

In [None]:

# Show the first band of the preprocessed Sentinel-1 image on a folium map.
vis_params = {'bands': BANDS[0], 'min': -20, 'max': 0}
mapid = image.getMapId(vis_params)

map = folium.Map(location=[-2.6145179357243027, 103.46795961225435])

ee_attribution = ('Map Data &copy; <a href="https://earthengine.google.com/">'
                  'Google Earth Engine</a>')
input_layer = folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr=ee_attribution,
    overlay=True,
    name='Sentinel-1 12 day mosaic composite',
)
input_layer.add_to(map)

map.add_child(folium.LayerControl())
map

# Model

In [None]:
# Load the trained model saved at gs://<BUCKET>/<FOLDER>/<MODEL_NAME>.
# Pass custom_objects=... to load_model if the model uses custom layers.
# NOTE(review): loading a SavedModel directory this way depends on the installed
# TF/Keras version — confirm it matches the version used to save the model.
MODEL_DIR = 'gs://' + '/'.join(
    [params['BUCKET'], params['FOLDER'], params['MODEL_NAME']])
model = tf.keras.models.load_model(MODEL_DIR)
model.summary()

# Export and inference

In [None]:
# Run the export. It only needs to run once per image: prediction() reads the
# already-exported patch files back from Cloud Storage.
export(image,params)
# Run inference on the exported patches and upload the result to Earth Engine.
prediction(params)

# Visualize

In [None]:
# Compare the original image and the despeckled (uploaded) image on one map.
out_image = ee.Image(params['USER_ID'] + '/' + params['MODEL_NAME'])

# Both layers show the first input band with the same -20..0 display range.
# The uploaded prediction may hold two bands (VV, VH). A band must be chosen
# before it can be rendered, just as for the input image.
mapid = out_image.getMapId({'bands': BANDS[0], 'min': -20, 'max': 0})
mapid_2 = image.getMapId({'bands': BANDS[0], 'min': -20, 'max': 0})

ee_attribution = ('Map Data &copy; <a href="https://earthengine.google.com/">'
                  'Google Earth Engine</a>')
map = folium.Map(location=[-2.6145179357243027, 103.46795961225435])
folium.TileLayer(
    tiles=mapid_2['tile_fetcher'].url_format,
    attr=ee_attribution,
    overlay=True,
    name='original image',
  ).add_to(map)
folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr=ee_attribution,
    overlay=True,
    name='Filtered image',
  ).add_to(map)
map.add_child(folium.LayerControl())
map