<a href="https://colab.research.google.com/github/am1235le/GY7720_Dissertation_209041686/blob/main/gee_data_export_all_final_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Extract GEE Data into Google Storage Data Bucket

The following code was sourced and adapted from the Google Earth Engine (GEE) data extract and formatting code made available by Huot et al. (2022), which can be found under the following link:

https://github.com/google-research/google-research/tree/master/simulation_research/next_day_wildfire_spread

Huot F., Hu R. L., Goyal N., Sankar T., Ihme M. and Chen Y. F. 2022. Next Day Wildfire Spread: A Machine Learning Dataset to Predict Wildfire Spreading From Remote-Sensing Data. IEEE Transactions on Geoscience and Remote Sensing, vol. 60, pp. 1-13, 2022, Art no. 4412513, doi: 10.1109/TGRS.2022.3192974.

## Load libraries and connect to Google Drive

In [1]:
# load required libraries
import enum
import math
import os
import random
import json
import ee
from ee import Date
from typing import List, Text, Tuple, Dict
from google.colab import files
from absl import app
from absl import flags
from absl import logging
from absl.testing import flagsaver

In [None]:
# Connect to Google Drive: mount it under /content/drive in the Colab runtime.
# `force_remount=True` requests a fresh mount even if one already exists.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Set the working folder on the mounted drive.
# NOTE: `%cd` is an IPython magic; it only works inside a notebook cell, not
# in a plain Python script.
%cd /content/drive/My\ Drive/Colab\ Notebooks/Dissertation/

In [7]:
# Seed Python's global random generator so that the `random.shuffle` call in
# `split_days_into_train_eval_test` gives the same train/eval/test split on
# every run. This is separate from DEFAULT_SEED, which is passed to the
# Earth Engine sampling call.
random.seed(123)

## Define Data Sources for Extract

In [8]:
# Define Source Data Types, including; Source, Bands and Time units
# adapted from Huot et al. (2022)

# define data categories
class DataType(enum.Enum):
  """Identifiers for the Earth Engine datasets used by this export."""
  # Each member is used as a key into DATA_SOURCES and DATA_BANDS; members
  # with a look-back window also appear in DATA_TIME_SAMPLING.
  ELEVATION_SRTM = 1
  VEGETATION_VIIRS = 2
  DROUGHT_GRIDMET = 3
  WEATHER_RTMA = 4
  WEATHER_GRIDMET = 5
  FIRE_MODIS = 6
  POPULATIONDENSITY = 7
  LULC = 8
  MODIS_BANDS = 9

# specify category sources
DATA_SOURCES = {
    # Maps each DataType to the Earth Engine asset ID that the `get_image*`
    # helpers pass to `ee.Image` / `ee.ImageCollection`.
    DataType.ELEVATION_SRTM: 'USGS/SRTMGL1_003',
    DataType.VEGETATION_VIIRS: 'NOAA/VIIRS/001/VNP13A1',
    DataType.DROUGHT_GRIDMET: 'GRIDMET/DROUGHT',
    DataType.WEATHER_RTMA: 'NOAA/NWS/RTMA',
    DataType.WEATHER_GRIDMET: 'IDAHO_EPSCOR/GRIDMET',
    DataType.FIRE_MODIS: 'MODIS/061/MOD14A1',
    DataType.POPULATIONDENSITY: 'CIESIN/GPWv411/GPW_Population_Density',
    DataType.LULC: 'GOOGLE/DYNAMICWORLD/V1',
    DataType.MODIS_BANDS: 'MODIS/061/MOD09GA'
}

# specify category bands from source
DATA_BANDS = {
    # Maps each DataType to the band names that `get_image` and
    # `get_image_collection` pass to `.select(...)`.
    DataType.ELEVATION_SRTM: ['elevation'],
    DataType.VEGETATION_VIIRS: ['NDVI'],
    DataType.DROUGHT_GRIDMET: ['pdsi'],
    DataType.WEATHER_RTMA: [
        'PRES',
        'TMP',
        'UGRD',
        'VGRD',
        'SPFH',
        'WDIR',
        'WIND',
        'GUST',
    ],
    DataType.WEATHER_GRIDMET: [
        'pr',
        'sph',
        'th',
        'tmmn',
        'tmmx',
        'vs',
        'erc',
    ],
    DataType.FIRE_MODIS: ['FireMask'],
    # NOTE(review): this entry is not used for band selection;
    # `get_image_collection_pop` does no `.select`, and `_export_dataset`
    # assigns the name 'populationdensity' with `.rename`.
    DataType.POPULATIONDENSITY: ['populationdensity'],
    # NOTE(review): `_get_time_slices` renames this band to 'lulc' before
    # export, so 'label' itself does not appear to reach the exported stack.
    DataType.LULC: ['label'],
    # Surface-reflectance bands that `_get_time_slices` combines into the
    # normalized-difference indices (nbr, dnbr, ndii, ndvid).
    DataType.MODIS_BANDS: [
        'sur_refl_b01',
        'sur_refl_b02',
        'sur_refl_b06',
        'sur_refl_b07'
    ]
}

# specify time units in days
DATA_TIME_SAMPLING = {
    # Look-back window length, in days, per time-varying source.
    # `_get_time_slices` passes these values to `advance(..., 'day')` to build
    # the start of each date filter. Sources used as a single static image
    # (elevation, population density) have no entry.
    DataType.VEGETATION_VIIRS: 8,
    DataType.DROUGHT_GRIDMET: 5,
    DataType.WEATHER_RTMA: 2,
    DataType.WEATHER_GRIDMET: 2,
    DataType.FIRE_MODIS: 1,
    DataType.LULC: 1,
    DataType.MODIS_BANDS: 1,
}

In [9]:
# Sanity check: print the DataType -> Earth Engine asset ID mapping.
print(DATA_SOURCES)

{<DataType.ELEVATION_SRTM: 1>: 'USGS/SRTMGL1_003', <DataType.VEGETATION_VIIRS: 2>: 'NOAA/VIIRS/001/VNP13A1', <DataType.DROUGHT_GRIDMET: 3>: 'GRIDMET/DROUGHT', <DataType.WEATHER_RTMA: 4>: 'NOAA/NWS/RTMA', <DataType.WEATHER_GRIDMET: 5>: 'IDAHO_EPSCOR/GRIDMET', <DataType.FIRE_MODIS: 6>: 'MODIS/061/MOD14A1', <DataType.POPULATIONDENSITY: 7>: 'CIESIN/GPWv411/GPW_Population_Density', <DataType.LULC: 8>: 'GOOGLE/DYNAMICWORLD/V1', <DataType.MODIS_BANDS: 9>: 'MODIS/061/MOD09GA'}


In [10]:
# set extraction defaults
# adapted from Huot et al. (2022)

# default resampling scale: looked up by `_export_dataset` and passed to
# `projection.atScale(...)` for both weather layers in `_get_time_slices`
# (units presumably meters -- TODO confirm)
RESAMPLING_SCALE = {DataType.WEATHER_GRIDMET: 10000}
# fire detection band: default name of the 0/1 band used to count and
# stratify fire pixels
DETECTION_BAND = 'detection'
# kernel size: default tile edge length for `convert_features_to_arrays`
DEFAULT_KERNEL_SIZE = 128
# sampling resolution
DEFAULT_SAMPLING_RESOLUTION = 1000  # Units: meters
# train/eval split: default `split_ratio` of
# `split_days_into_train_eval_test`
DEFAULT_EVAL_SPLIT = 0.2
# limit per GEE call: default cap on points requested per sampling call in
# `extract_samples`
DEFAULT_LIMIT_PER_EE_CALL = 60
# default seed: default `seed` handed to the Earth Engine sampling call in
# `extract_samples`
DEFAULT_SEED = 123

In [12]:
# set spatial domain for image extraction
# `export_ml_datasets` reads COORDINATES['US'], so only one region can be
# active at a time; swap which line is commented to change region.
COORDINATES = {
    # Used as input to ee.Geometry.Rectangle()
    # (two corner points, presumably [longitude, latitude] -- TODO confirm)
    # Region 1
    'US': [[-125, 32], [-102, 42]]
    # Region 2
    #'US': [[-125, 32], [-114, 42]]
}

## Define functions for data extraction code

In [13]:
# adapted from Huot et al. (2022)

# extract individual GEE image for specified data source and bands
def get_image(data_type):
  """Builds the single EE image registered for `data_type`.

  Args:
    data_type: A `DataType` member naming the dataset to load.

  Returns:
    The EE image for `data_type`, restricted to the bands listed for it in
    `DATA_BANDS`.
  """
  asset_id = DATA_SOURCES[data_type]
  band_names = DATA_BANDS[data_type]
  return ee.Image(asset_id).select(band_names)

# extract GEE image collection for specified data source and bands
def get_image_collection(data_type):
  """Builds the EE image collection registered for `data_type`.

  Args:
    data_type: A `DataType` member naming the dataset to load.

  Returns:
    The EE image collection for `data_type`, restricted to the bands listed
    for it in `DATA_BANDS`.
  """
  asset_id = DATA_SOURCES[data_type]
  band_names = DATA_BANDS[data_type]
  return ee.ImageCollection(asset_id).select(band_names)

# extract GEE image collection for specified data source; first available image
def get_image_collection_pop(data_type):
  """Returns the first image of the collection registered for `data_type`.

  Unlike `get_image_collection`, no band selection is applied and the result
  is a single element (`.first()`), not a collection.

  Args:
    data_type: A `DataType` member naming the dataset to load.

  Returns:
    The first element of the EE image collection for `data_type`, with all of
    its bands.
  """
  collection = ee.ImageCollection(DATA_SOURCES[data_type])
  return collection.first()

In [14]:
# remove image mask
# sourced from Huot et al. (2022)
def remove_mask(image):
  """Applies a constant all-ones mask to an EE image.

  NOTE(review): the Earth Engine documentation describes `updateMask` as only
  updating positions where the existing mask is non-zero, so pixels that are
  already masked may stay masked. Confirm whether `unmask()` was intended
  before relying on this to "remove" a mask.

  Args:
    image: The input EE image.

  Returns:
    The result of `image.updateMask(ee.Image(1))`.
  """
  return image.updateMask(ee.Image(1))

In [15]:
# extract feature collection from GEE
# sourced from Huot et al. (2022)
def export_feature_collection(
    feature_collection,
    description,
    bucket,
    folder,
    bands,
    file_format = 'TFRecord',
):
  """Starts an EE task that exports `feature_collection` to Cloud Storage.

  Args:
    feature_collection: The EE feature collection to export.
    description: The task description; also used as the file name within
      `folder`.
    bucket: The name of the Google Cloud bucket.
    folder: The folder (object-name prefix) to export to.
    bands: The list of names of the features to export; passed as the
      `selectors` of the export.
    file_format: The output file format, forwarded unchanged to Earth Engine.
      NOTE(review): the table export documents CSV, GeoJSON, KML, KMZ, SHP and
      TFRecord; 'GeoTIFF' is an image format, so confirm before using it here.

  Returns:
    The started EE task associated with the export.
  """
  # Cloud Storage object names always use '/', whatever the local OS is, so
  # build the prefix with posixpath rather than the platform-dependent os.path.
  import posixpath

  task = ee.batch.Export.table.toCloudStorage(
      collection=feature_collection,
      description=description,
      bucket=bucket,
      fileNamePrefix=posixpath.join(folder, description),
      fileFormat=file_format,
      selectors=bands)
  task.start()
  return task

In [16]:
# convert GEE images to arrays of defined shape
# sourced from Huot et al. (2022)
def convert_features_to_arrays(
    image_list,
    kernel_size = DEFAULT_KERNEL_SIZE,
):
  """Stacks EE images and turns each pixel neighborhood into an array.

  Args:
    image_list: The list of EE images to concatenate band-wise.
    kernel_size: Edge length of the square neighborhood
      (kernel_size x kernel_size).

  Returns:
    An EE image whose bands hold (kernel_size x kernel_size) array tiles of
    the float-cast input stack.
  """
  # Square kernel with every weight set to 1.
  row_of_ones = ee.List.repeat(1, kernel_size)  # pytype: disable=attribute-error
  weight_rows = ee.List.repeat(row_of_ones, kernel_size)  # pytype: disable=attribute-error
  square_kernel = ee.Kernel.fixed(kernel_size, kernel_size, weight_rows)

  stacked = ee.Image.cat(image_list).float()
  return stacked.neighborhoodToArray(square_kernel)

In [17]:
# identify images containing fire
# sourced from Huot et al. (2022)
def get_detection_count(
    detection_image,
    geometry,
    sampling_scale = DEFAULT_SAMPLING_RESOLUTION,
    detection_band = DETECTION_BAND,
):
  """Sums the `detection_band` of `detection_image` over `geometry`.

  The band is expected to hold zeros and ones, so the sum is the number of
  positive pixels at `sampling_scale`.

  Args:
    detection_image: An EE image with a detection band.
    geometry: The EE geometry over which to count the pixels.
    sampling_scale: The scale passed to `reduceRegion`.
    detection_band: The name of the image band to read from the result.

  Returns:
    The summed value truncated to an int, or -1 if fetching it raises
    `ee.EEException`.
  """
  region_sums = detection_image.reduceRegion(
      reducer=ee.Reducer.sum(), geometry=geometry, scale=sampling_scale)
  try:
    # `getInfo` is the blocking round trip to the EE servers.
    return int(region_sums.get(detection_band).getInfo())
  except ee.EEException:
    # Server-side failure: signal "count unavailable" instead of raising.
    return -1


In [18]:
# extract sample images that satisfy specified detection count criteria
# sourced from Huot et al. (2022)
def extract_samples(
    image,
    detection_count,
    geometry,
    sampling_ratio,
    detection_band = DETECTION_BAND,
    sampling_limit_per_call = DEFAULT_LIMIT_PER_EE_CALL,
    resolution = DEFAULT_SAMPLING_RESOLUTION,
    seed = DEFAULT_SEED,
):
  """Draws positive and negative samples from an EE image in batches.

  Requests roughly `detection_count` positive points (class 1) and
  `sampling_ratio` x that many negative points (class 0), split over several
  `stratifiedSample` calls so that each call asks for at most
  `sampling_limit_per_call` points. Assumes `detection_band` holds zeros and
  ones.

  NOTE(review): every call uses the same `seed` and arguments. Confirm that
  repeated calls really return different points and not duplicates.
  NOTE(review): if `sampling_limit_per_call` is smaller than
  `sampling_ratio + 1`, the per-call size is 0 and the batch-count division
  raises `ZeroDivisionError`.

  Args:
    image: The EE image to extract samples from.
    detection_count: The number of positive samples to extract.
    geometry: The EE geometry over which to sample.
    sampling_ratio: Number of negative samples per positive sample. Set to
      zero to extract only positive examples.
    detection_band: The name of the band that defines the sampling classes.
    sampling_limit_per_call: Cap on the number of points per EE call. To
      disable this limit, set it to `detection_count`.
    resolution: The scale, in meters, at which to sample.
    seed: The seed passed to each `stratifiedSample` call.

  Returns:
    An EE feature collection holding the merged samples of all calls.
  """
  positives_per_call = sampling_limit_per_call // (sampling_ratio + 1)
  negatives_per_call = positives_per_call * sampling_ratio
  num_calls = math.ceil(detection_count / positives_per_call)

  merged = ee.FeatureCollection([])
  for _ in range(num_calls):
    batch = image.stratifiedSample(
        region=geometry,
        numPoints=0,
        classBand=detection_band,
        scale=resolution,
        seed=seed,
        classValues=[0, 1],
        classPoints=[negatives_per_call, positives_per_call],
        dropNulls=True)
    merged = merged.merge(batch)
  return merged


In [19]:
# split days in specified period into ranges for train, eval, test split
# sourced from Huot et al. (2022)
def split_days_into_train_eval_test(
    start_date,
    end_date,
    split_ratio = DEFAULT_EVAL_SPLIT,
    window_length_days = 8,
):
  """Splits the days into train / eval / test sets.

  Cuts the interval between `start_date` and `end_date` into chunks of
  `window_length_days` days, shuffles the chunk start days with the global
  `random` generator, and divides them into train / eval / test sets.

  Args:
    start_date: The start date (EE date).
    end_date: The end date (EE date).
    split_ratio: The split ratio for the divide between sets, such that the
      number of eval time chunks and test time chunks are each equal to the
      total number of time chunks x `split_ratio` (rounded down). All the
      remaining time chunks are training time chunks.
    window_length_days: The length of the time chunks (in days).

  Returns:
    A dictionary with keys 'train', 'eval' and 'test', each holding the list
    of start day indices (offsets from `start_date`) of its time chunks.
  """
  num_days = int(ee.Date.difference(end_date, start_date, unit='days').getInfo())  # pytype: disable=attribute-error
  days = list(range(0, num_days, window_length_days))
  random.shuffle(days)
  num_eval = int(len(days) * split_ratio)

  # Use explicit split points instead of negative slices. With num_eval == 0
  # the slices `days[:-0]` / `days[-0:]` would leave the training set empty
  # and put every chunk into the test set.
  eval_begin = max(0, len(days) - 2 * num_eval)
  test_begin = max(0, len(days) - num_eval)
  return {
      'train': days[:eval_begin],
      'eval': days[eval_begin:test_begin],
      'test': days[test_begin:],
  }

In [20]:
# specify the name of bands that will be extracted
# adapted from Huot et al. (2022)
def _get_all_feature_bands():
  """Returns a new list with the names of all feature bands to export.

  NOTE(review): the image stack built in `_export_dataset` does not appear to
  contain the raw MODIS surface-reflectance bands or a band named 'label'
  (only the derived indices and the renamed 'lulc' band are returned by
  `_get_time_slices`). Confirm whether those names are needed here.
  """
  names = []
  names.extend(DATA_BANDS[DataType.ELEVATION_SRTM])
  names.append('populationdensity')
  names.extend(DATA_BANDS[DataType.DROUGHT_GRIDMET])
  names.extend(DATA_BANDS[DataType.VEGETATION_VIIRS])
  names.extend(DATA_BANDS[DataType.WEATHER_GRIDMET])
  names.append('PrevFireMask')
  names.extend(DATA_BANDS[DataType.WEATHER_RTMA])
  names.extend(DATA_BANDS[DataType.MODIS_BANDS])
  # Indices derived from the MODIS bands in `_get_time_slices`.
  names.extend(['nbr', 'dnbr', 'ndii', 'ndvid'])
  names.extend(DATA_BANDS[DataType.LULC])
  names.append('lulc')
  return names

In [21]:
# specify fire band
# sourced from Huot et al. (2022)
def _get_all_response_bands():
  """Returns the band names used as labels (the fire mask bands)."""
  label_bands = DATA_BANDS[DataType.FIRE_MODIS]
  return label_bands

In [22]:
# append an index suffix to each band name
# sourced from Huot et al. (2022)
def _add_index(i, bands):
  """Appends the index number `i` at the end of each element of `bands`."""
  return [f'{band}_{i}' for band in bands]

In [23]:
# gather the GEE image collections and their time-sampling windows (in days)
# adapted from Huot et al. (2022)
def _get_all_image_collections():
  """Gets all the image collections and their time sampling.

  Returns:
    `(image_collections, time_sampling)`: two dictionaries sharing the same
    short keys. The first maps each key to its EE image collection, the
    second to its entry in `DATA_TIME_SAMPLING`.
  """
  # Short key -> DataType; the order fixes the key order of both results.
  sources = (
      ('drought', DataType.DROUGHT_GRIDMET),
      ('vegetation', DataType.VEGETATION_VIIRS),
      ('weather', DataType.WEATHER_GRIDMET),
      ('fire', DataType.FIRE_MODIS),
      ('indexbands', DataType.MODIS_BANDS),
      ('lulc', DataType.LULC),
      ('weatherrtma', DataType.WEATHER_RTMA),
  )
  image_collections = {
      key: get_image_collection(data_type) for key, data_type in sources
  }
  time_sampling = {
      key: DATA_TIME_SAMPLING[data_type] for key, data_type in sources
  }
  return image_collections, time_sampling

In [24]:
# verify feature collection extract
# sourced from Huot et al. (2022)
def _verify_feature_collection(
    feature_collection
):
  """Verifies the feature collection is valid.
  If the feature collection is invalid, resets the feature collection.
  Args:
    feature_collection: An EE feature collection.
  Returns:
    `(feature_collection, size)` a tuple of the verified feature collection and
    its size.
  """
  try:
    size = int(feature_collection.size().getInfo())
  except ee.EEException:
    # Reset the feature collection
    feature_collection = ee.FeatureCollection([])
    size = 0
  return feature_collection, size

In [25]:
# extract and prepare data for defined time window
# adapted from Huot et al. (2022)
def _get_time_slices(
    window_start,
    window,
    projection,  # Defer calling until called by test code
    resampling_scale,
    lag = 0,
):
  """Extracts the time slice features.

  Feature layers are taken from look-back windows that end `lag` days before
  `window_start`. The fire label and the land-cover layer are taken from
  `[window_start, window_start + window)`.

  Args:
    window_start: Start of the time window over which to extract data (EE
      date).
    window: Length of the window (in days).
    projection: projection to reproject all data into.
    resampling_scale: length scale to resample the weather data to.
    lag: Number of days before the fire to extract the features.

  Returns:
    A list of the extracted EE images, in this order: drought, vegetation,
    GRIDMET weather, RTMA weather, previous fire mask, fire mask, lulc, nbr,
    dnbr, ndii, ndvid, detection. Callers rely on `detection` being last.
  """
  image_collections, time_sampling = _get_all_image_collections()
  window_end = window_start.advance(window, 'day')
  # Common end of every lagged feature window.
  lagged_end = window_start.advance(-lag, 'day')

  # extract drought layer
  drought = image_collections['drought'].filterDate(
      window_start.advance(-lag - time_sampling['drought'], 'day'),
      lagged_end).median().reproject(projection).resample('bicubic')
  # extract vegetation layer
  vegetation = image_collections['vegetation'].filterDate(
      window_start.advance(-lag - time_sampling['vegetation'], 'day'),
      lagged_end).median().reproject(projection).resample('bicubic')
  # extract weather layers (reprojected at the coarser `resampling_scale`)
  weather = image_collections['weather'].filterDate(
      window_start.advance(-lag - time_sampling['weather'], 'day'),
      lagged_end).median().reproject(
          projection.atScale(resampling_scale)).resample('bicubic')
  # extract weather RTMA layers (same scale as the GRIDMET weather layers)
  weatherrtma = image_collections['weatherrtma'].filterDate(
      window_start.advance(-lag - time_sampling['weatherrtma'], 'day'),
      lagged_end).median().reproject(
          projection.atScale(resampling_scale)).resample('bicubic')

  # extract fire layers: previous fire mask (feature) and current one (label)
  prev_fire = image_collections['fire'].filter(ee.Filter.date(
      window_start.advance(-lag - time_sampling['fire'], 'day'),
      lagged_end)).map(remove_mask).max().rename('PrevFireMask')
  fire = image_collections['fire'].filter(
      ee.Filter.date(window_start, window_end)).map(remove_mask).max()
  # clamp(6, 7) - 6 maps values <= 6 to 0 and values >= 7 to 1
  # (presumably FireMask classes >= 7 denote fire -- TODO confirm).
  detection = fire.clamp(6, 7).subtract(6).rename('detection')
  prev_fire = prev_fire.clamp(6, 7).subtract(6)
  fire = fire.clamp(6, 7).subtract(6)

  # extract LULC: 1 where the per-pixel maximum label in the window equals 6
  lulc = image_collections['lulc'].filter(
      ee.Filter.date(window_start, window_end)).max()
  lulc = lulc.eq(6).rename('lulc')

  # extract index bands for two consecutive look-back windows
  indexbands_start = window_start.advance(
      -lag - time_sampling['indexbands'], 'day')
  prev_indbands = image_collections['indexbands'].filter(ee.Filter.date(
      window_start.advance(-lag - time_sampling['indexbands'] - 1, 'day'),
      indexbands_start)).max()
  # The later window ends at `window_start - lag`, like every other lagged
  # feature. It previously ended at `window_start`, which ignored `lag`
  # (identical for the default lag of 0).
  indbands = image_collections['indexbands'].filter(ee.Filter.date(
      indexbands_start, lagged_end)).max()

  # calculate indices
  # NBR T-2
  pnbr = prev_indbands.normalizedDifference(
      ['sur_refl_b02', 'sur_refl_b07']).rename('pnbr')
  # NBR T-1
  nbr = indbands.normalizedDifference(
      ['sur_refl_b02', 'sur_refl_b07']).rename('nbr')
  # DNBR: earlier NBR minus later NBR
  dnbr = pnbr.subtract(nbr).rename('dnbr')
  # NDII
  ndii = indbands.normalizedDifference(
      ['sur_refl_b02', 'sur_refl_b06']).rename('ndii')
  # NDVI
  ndvid = indbands.normalizedDifference(
      ['sur_refl_b02', 'sur_refl_b01']).rename('ndvid')
  return [drought, vegetation, weather, weatherrtma, prev_fire, fire, lulc,
          nbr, dnbr, ndii, ndvid, detection]

## Define consolidated export function to export GEE data

In [27]:
# consolidated export function that samples GEE images for the specified time windows and spatial extent
# sourced from Huot et al. (2022)
def _export_dataset(
    bucket,
    folder,
    prefix,
    start_date,
    start_days,
    geometry,
    kernel_size,
    sampling_scale,
    num_samples_per_file,
):
  """Exports the dataset TFRecord files for wildfire risk assessment.

  For each of the 7 consecutive days following every entry of `start_days`,
  builds the feature stack, counts fire pixels and, when there are any,
  samples tiles centered on fire pixels. Samples accumulate across days and
  are exported whenever more than `num_samples_per_file` have been collected,
  plus once at the end for the remainder.

  Args:
    bucket: Google Cloud bucket
    folder: Folder to which to export the TFRecords.
    prefix: Export file name prefix.
    start_date: Start date for the EE data to export.
    start_days: Start day of each time chunk to export, as day offsets from
      `start_date`. May be empty, in which case nothing is exported.
    geometry: EE geometry from which to export the data.
    kernel_size: Size of the exported tiles (square).
    sampling_scale: Resolution at which to export the data (in meters).
    num_samples_per_file: Approximate number of samples to save per TFRecord
      file.
  """

  def _verify_and_export_feature_collection(
      num_samples_per_export,
      feature_collection,
      file_count,
      features,
  ):
    """Wraps the verification and export of the feature collection.

    Verifies the size of the feature collection and triggers the export when
    it is larger than `num_samples_per_export`. Resets the feature collection
    and increments the file count at each export.

    Args:
      num_samples_per_export: Approximate number of samples per export.
      feature_collection: The EE feature collection to export.
      file_count: The TFRecord file count for naming the files.
      features: Names of the features to export.

    Returns:
      `(feature_collection, file_count)` tuple of the current feature collection
        and file count.
    """
    feature_collection, size_count = _verify_feature_collection(
        feature_collection)
    if size_count > num_samples_per_export:
      export_feature_collection(
          feature_collection,
          description=prefix + '_{:03d}'.format(file_count),
          bucket=bucket,
          folder=folder,
          bands=features,
      )
      file_count += 1
      feature_collection = ee.FeatureCollection([])
    return feature_collection, file_count

  # Static layers shared by every day.
  elevation = get_image(DataType.ELEVATION_SRTM)
  populationdensity = get_image_collection_pop(DataType.POPULATIONDENSITY)
  populationdensity = populationdensity.rename('populationdensity')
  # All layers are reprojected into the projection of the first GRIDMET image.
  projection = get_image_collection(DataType.WEATHER_GRIDMET)
  projection = projection.first().select(
      DATA_BANDS[DataType.WEATHER_GRIDMET][0]).projection()
  resampling_scale = RESAMPLING_SCALE[DataType.WEATHER_GRIDMET]

  # Expand each chunk start into its 7 consecutive day offsets.
  all_days = [day + offset for day in start_days for offset in range(7)]

  window = 1
  sampling_limit_per_call = DEFAULT_LIMIT_PER_EE_CALL
  features = _get_all_feature_bands() + _get_all_response_bands()

  file_count = 0
  feature_collection = ee.FeatureCollection([])
  for start_day in all_days:
    window_start = start_date.advance(start_day, 'days')
    time_slices = _get_time_slices(window_start, window, projection,
                                   resampling_scale)
    # The last slice is the 0/1 detection band; everything else is tiled.
    image_list = [elevation, populationdensity] + time_slices[:-1]
    detection = time_slices[-1]
    arrays = convert_features_to_arrays(image_list, kernel_size)
    to_sample = detection.addBands(arrays)

    # Fire pixels are counted at 10x the sampling scale, so the count is
    # taken on a coarser grid than the one used for sampling.
    fire_count = get_detection_count(
        detection,
        geometry=geometry,
        sampling_scale=10 * sampling_scale,
    )
    # fire_count is -1 on an EE error and 0 when there is no fire; skip both.
    if fire_count > 0:
      samples = extract_samples(
          to_sample,
          detection_count=fire_count,
          geometry=geometry,
          # RE-ADD IF ERROR
          sampling_ratio=0,  # Only extracting examples with fire.
          sampling_limit_per_call=sampling_limit_per_call,
          resolution=sampling_scale,
      )
      feature_collection = feature_collection.merge(samples)

      feature_collection, file_count = _verify_and_export_feature_collection(
          num_samples_per_file, feature_collection, file_count, features)
  # Export the remaining feature collection (threshold 0: any non-empty rest).
  _verify_and_export_feature_collection(0, feature_collection, file_count,
                                        features)

In [28]:
# define function to specify export parameters
# sourced from Huot et al. (2022)
def export_ml_datasets(
    bucket,
    folder,
    start_date,
    end_date,
    prefix = '',
    kernel_size = 128,
    sampling_scale = 1000,
    eval_split_ratio = 0.125,
    num_samples_per_file = 1000,
):
  """Exports the train / eval / test TFRecord files to Google Cloud Storage.

  Splits the period into 8-day chunks, assigns each chunk to one of the three
  sets, and runs one export per set over the rectangle `COORDINATES['US']`.
  Output files are named `<mode>_<prefix>_<nnn>`.

  Args:
    bucket: Google Cloud bucket
    folder: Folder to which to export the TFRecords.
    start_date: Start date for the EE data to export.
    end_date: End date for the EE data to export.
    prefix: File name prefix to use.
    kernel_size: Size of the exported tiles (square).
    sampling_scale: Resolution at which to export the data (in meters).
    eval_split_ratio: Fraction of the time chunks assigned to the eval set and,
      separately, to the test set; the rest is used for training.
    num_samples_per_file: Approximate number of samples to save per TFRecord
      file.
  """
  split_days = split_days_into_train_eval_test(
      start_date, end_date, split_ratio=eval_split_ratio, window_length_days=8)

  # The export region is the same for all three sets.
  region = ee.Geometry.Rectangle(COORDINATES['US'])

  for mode in ('train', 'eval', 'test'):
    _export_dataset(
        bucket=bucket,
        folder=folder,
        prefix=f'{mode}_{prefix}',
        start_date=start_date,
        start_days=split_days[mode],
        geometry=region,
        kernel_size=kernel_size,
        sampling_scale=sampling_scale,
        num_samples_per_file=num_samples_per_file)

## Connect to Earth Engine

Extracting data with this code requires a Google Earth Engine (GEE) account.

In [None]:
# once an account has been created to access data
# run this cell, select account under which access was registered and generate token
!earthengine authenticate

In [30]:
# Initialize the Earth Engine client library. Run this after the
# authentication cell and before any cell that calls Earth Engine.
ee.Initialize()

## Specify time window and extract data

In [31]:
# define the time window for the data extract (ISO date strings; they are
# converted with ee.Date(...) when passed to export_ml_datasets)
start_date = "2018-06-01"
end_date = "2018-12-31"

In [32]:
# define parameters and extract data to Google storage bucket
# bucket = name of Google storage bucket
# folder = folder name to be created in Google bucket
# prefix = tfrecord filename prefix
# kernel_size = image kernel size (edge length of the square tiles)
# sampling_scale = data sampling scale in metres
# eval_split_ratio = fraction of time chunks used for eval and, separately,
#   for test; the remainder is used for training
# num_samples_per_file = approximate number of samples per tfrecord
# NOTE: 'bucket_name', 'folder_name' and 'file_prefix' are placeholders to be
# replaced before running.
export_ml_datasets(bucket='bucket_name',
                   folder='folder_name',
                   start_date=ee.Date(start_date),
                   end_date=ee.Date(end_date),
                   prefix='file_prefix',
                   kernel_size=64,
                   sampling_scale=1000,
                   eval_split_ratio=0.1,
                   num_samples_per_file=1000)