Licensed under the Apache License, Version 2.0

This Python notebook shows how to reproduce and inspect the datasets and splits used in "In-Domain Representation Learning for Remote Sensing"
by Maxim Neumann, André Susano Pinto, Xiaohua Zhai and Neil Houlsby. The pre-print is
available on [arXiv](https://arxiv.org/abs/1911.06721).


### Imports and code to define datasets

In [0]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

import tqdm

import matplotlib.pylab as plt
import seaborn as sns
import numpy as np
import pandas as pd

tf.enable_v2_behavior()

In [0]:
# Percentages used to slice a dataset's "train" data into train / val / test.
TRAIN_SPLIT_PERCENT = 60
VALIDATION_SPLIT_PERCENT = 20
# Whatever is left after train and validation (currently 20).
TEST_SPLIT_PERCENT = 100 - TRAIN_SPLIT_PERCENT - VALIDATION_SPLIT_PERCENT
# Percentage used to sub-split the So2Sat "validation" data.
SO2SAT_VALIDATION_SUBSPLIT_PERCENT = 25

class TfdsDataset:
  """Wraps a TFDS builder and exposes fixed train/val/test splits of it.

  For dataset names starting with "so2sat" the TFDS "train" split is used as
  the train split and the TFDS "validation" split is divided into val and test.
  Every other dataset is sliced out of the TFDS "train" split by percentage.

  Attributes:
    name: the dataset name given to the constructor.
    is_multilabel: True for names starting with "bigearthnet".
    builder: the `tfds.builder` instance for the dataset.
    label_key: feature key of the label ("labels" for bigearthnet).
    image_key: feature key of the image ("image").
    filename_key: feature key identifying an example ("sample_id" for so2sat).
    label_name_fn: `int2str` of the label feature.
    num_classes: `num_classes` of the label feature.
  """

  def __init__(self, dataset_name, params=None, **kwargs):
    """Creates the builder and derives the split definitions and sizes.

    Args:
      dataset_name: TFDS dataset name, forwarded to `tfds.builder`.
      params: optional dict of input pipeline options. Recognized keys are
        "shuffle_buffer_size" (default 10000), "num_preprocessing_threads"
        (default 100), "drop_remainder" (default True), "ignore_errors"
        (default False) and "prefetch" (default 1).
      **kwargs: extra keyword arguments forwarded to `tfds.builder`.
    """
    # Use None rather than a shared mutable `{}` as the default value.
    if params is None:
      params = {}

    is_so2sat = dataset_name.startswith("so2sat")
    is_bigearthnet = dataset_name.startswith("bigearthnet")

    self.name = dataset_name
    self.is_multilabel = is_bigearthnet
    self.builder = tfds.builder(dataset_name, **kwargs)

    self.label_key = "labels" if is_bigearthnet else "label"
    self.image_key = "image"
    self.filename_key = "sample_id" if is_so2sat else "filename"

    self._shuffle_buffer_size = params.get("shuffle_buffer_size", 10000)
    self._num_parallel_calls = params.get("num_preprocessing_threads", 100)
    self._drop_remainder = params.get("drop_remainder", True)
    self._ignore_errors = params.get("ignore_errors", False)
    self._prefetch = params.get("prefetch", 1)

    label_feature = self.builder.info.features[self.label_key]
    self.label_name_fn = label_feature.int2str
    self.num_classes = label_feature.num_classes

    # NOTE(review): the per-split sizes below are computed with floor
    # division; confirm they match how TFDS rounds percent slices if exact
    # counts matter.
    if is_so2sat:
      self._tfds_splits = dict(
          train="train",
          val=f"validation[:{SO2SAT_VALIDATION_SUBSPLIT_PERCENT}%]",
          test=f"validation[{SO2SAT_VALIDATION_SUBSPLIT_PERCENT}%:]")
      val_count = self.builder.info.splits[tfds.Split.VALIDATION].num_examples
      self._num_samples_splits = dict(
          train=self.builder.info.splits[tfds.Split.TRAIN].num_examples,
          val=val_count * SO2SAT_VALIDATION_SUBSPLIT_PERCENT // 100,
          test=val_count * (100-SO2SAT_VALIDATION_SUBSPLIT_PERCENT) // 100)
    else:
      self._tfds_splits = dict(
          train=f"train[:{TRAIN_SPLIT_PERCENT}%]",
          val=f"train[{TRAIN_SPLIT_PERCENT}%:{TRAIN_SPLIT_PERCENT+VALIDATION_SPLIT_PERCENT}%]",
          test=f"train[{TRAIN_SPLIT_PERCENT+VALIDATION_SPLIT_PERCENT}%:]")
      num_examples = self.builder.info.splits[tfds.Split.TRAIN].num_examples
      self._num_samples_splits = dict(
          train=num_examples * TRAIN_SPLIT_PERCENT // 100,
          val=num_examples * VALIDATION_SPLIT_PERCENT // 100,
          test=num_examples * TEST_SPLIT_PERCENT // 100)

  def _get_deterministic_dataset(self, split_name, for_eval, train_examples):
    """Creates a tf.data.Dataset composed of a deterministic set of examples.

    Args:
      split_name: one of "train", "val" or "test".
      for_eval: if True, `train_examples` is ignored.
      train_examples: optional number of leading examples to keep.

    Returns:
      A `(dataset, num_samples)` tuple. Images are left undecoded.
    """
    # Don't shuffle to receive exactly the same split for reproducibility.
    dataset = self.builder.as_dataset(
        split=self._tfds_splits[split_name],
        shuffle_files=False,
        decoders={self.image_key: tfds.decode.SkipDecoding()})
    num_samples = self._num_samples_splits[split_name]

    if not for_eval and train_examples:
      dataset = dataset.take(train_examples)
      num_samples = train_examples

    return dataset, num_samples

  def get_filenames(self, split_name, train_examples=None, for_eval=False):
    """Returns the identifier of every example of a split as a list of bytes.

    Args:
      split_name: one of "train", "val" or "test".
      train_examples: optional number of leading examples to keep (ignored
        when `for_eval` is True).
      for_eval: see `train_examples`.

    Returns:
      A list with one `bytes` entry per example, in dataset order.
    """
    dataset, _ = self._get_deterministic_dataset(
        split_name, for_eval, train_examples)

    def _get(example):
      fname = example[self.filename_key].numpy()
      # Integer identifiers (signed or unsigned) are rendered as decimal text
      # so that every returned entry is a bytes object.
      if np.issubdtype(type(fname), np.integer):
        fname = bytes(str(fname), encoding="utf-8")
      return fname

    return [_get(x) for x in dataset]

  def get_tf_data(self, split_name, batch_size, preprocess_fn=None,
                  for_eval=False, train_examples=None, epochs=None):
    """Creates a tf.data.Dataset with features (label, image, filename).

    Args:
      split_name: one of "train", "val" or "test".
      batch_size: number of examples per batch.
      preprocess_fn: optional callable applied to each example dict after the
        features were renamed to "image", "label" and "filename".
      for_eval: if True, disables caching, shuffling and `train_examples`.
      train_examples: optional number of leading examples to keep.
      epochs: number of repetitions; None repeats indefinitely.

    Returns:
      A batched and prefetched `tf.data.Dataset` of example dicts.
    """
    dataset, num_samples = self._get_deterministic_dataset(
        split_name, for_eval, train_examples)

    # Cache the whole dataset if it's smaller than 150K examples.
    if not for_eval and num_samples <= 150000:
      dataset = dataset.cache()

    # Repeats data `epochs` time or indefinitely if `epochs` is None.
    if epochs is None or epochs > 1:
      dataset = dataset.repeat(epochs)

    if not for_eval and self._shuffle_buffer_size > 1:
      dataset = dataset.shuffle(self._shuffle_buffer_size)

    # Images were loaded with SkipDecoding, so they are decoded here.
    image_decoder = self.builder.info.features[self.image_key].decode_example

    def prepare_example(example):
      # Rename features to common names.
      example = {
          "image": image_decoder(example[self.image_key]),
          "label": example[self.label_key],
          "filename": example[self.filename_key],
      }
      if self.is_multilabel:
        # Turn the list of label ids into a multi-hot vector.
        example["label"] = tf.reduce_max(tf.one_hot(example["label"],
                                                    depth=self.num_classes,
                                                    dtype=tf.int64), axis=0)
      if preprocess_fn:
        example = preprocess_fn(example)
      return example

    dataset = dataset.map(prepare_example,
                          num_parallel_calls=self._num_parallel_calls)
    if self._ignore_errors:  # Ignore images with errors.
      dataset = dataset.apply(tf.data.experimental.ignore_errors())
    dataset = dataset.batch(batch_size, drop_remainder=self._drop_remainder)
    dataset = dataset.prefetch(self._prefetch)
    return dataset

def preprocess_fn(data, size=224, input_range=(0.0, 1.0)):
  """Resizes the image to `size` x `size` and rescales it into `input_range`.

  Args:
    data: example dict with an "image" entry; the entry is overwritten in
      place. Presumably the pixel values are in [0, 255] - verify against the
      dataset being used.
    size: target height and width passed to `tf.image.resize`.
    input_range: (low, high) pair; after division by 255 the image is mapped
      linearly so that 0 becomes `low` and 1 becomes `high`.

  Returns:
    The same `data` dict with the processed float32 image.
  """
  low, high = input_range[0], input_range[1]
  resized = tf.image.resize(data["image"], [size, size])
  unit_scaled = tf.cast(resized, tf.float32) / 255.0
  data["image"] = unit_scaled * (high - low) + low
  return data

In [0]:
def visualize(ds, figsize=(17, 17)):
  """Plots one batch of 16 preprocessed images of the "val" split in a 4x4 grid.

  Also prints the dataset name and summary statistics of the image and label
  tensors of that batch.

  Args:
    ds: a TfdsDataset instance.
    figsize: (width, height) passed to `plt.figure`.
  """
  batch_size = 16  # Must match the 4x4 subplot grid below.
  val_data = ds.get_tf_data("val", batch_size, preprocess_fn=preprocess_fn)
  # `make_one_shot_iterator()` only exists on TF1-style datasets; with v2
  # behavior enabled the dataset is a regular Python iterable.
  batch = next(iter(val_data))
  print(f"Dataset: {ds.name}")
  print("Images: ", batch["image"].shape, stats_str(batch["image"]))
  print("Labels: ", batch["label"].shape, stats_str(batch["label"]))
  plt.figure(figsize=figsize)
  for i in range(batch_size):
    plt.subplot(4, 4, 1+i)
    plt.imshow(batch["image"][i])
    if ds.is_multilabel:
      # Labels are a multi-hot vector: list the name of every active class.
      labels = [ds.label_name_fn(lid)
                for lid, value in enumerate(batch["label"][i]) if value]
      plt.title("\n".join(labels))
    else:
      # Hand a plain Python int (not a scalar tensor) to the name lookup.
      plt.title(ds.label_name_fn(int(batch["label"][i])))
  plt.show()

def stats_str(arr, f=None, with_median=False, with_count=False):
  """Returns a string with main stats info about the given array.

  By default, the string has the form: "mean +/- standard_deviation [min..max]"
  values of the data array. Statistics are computed over all elements.

  Args:
    arr: array-like, an object with a `.numpy()` method (e.g. a TF-2 tensor),
      a scalar, a (possibly ragged) list of lists, or None.
    f: str, format spec used for the float values; defaults to "{:.3f}".
    with_median: boolean, whether to append " median: <value>".
    with_count: boolean, whether to append " n: <number of elements>".
  Returns:
    stats_str: str, or "[empty]" if there is no data.
  """
  if arr is None or (isinstance(arr, (list, tuple)) and not arr):
    return "[empty]"
  if not isinstance(arr, np.ndarray):
    try:
      arr = arr.numpy()  # If arr is a TF-2 tensor.
    except AttributeError:
      pass
    try:
      arr = np.concatenate(arr).ravel()  # to deal with different length lists
    except (ValueError, TypeError):
      # ValueError: flat sequence of scalars; TypeError: a single scalar.
      arr = np.array(arr)
  if arr.size == 0:
    # Empty arrays/tensors would otherwise fail in min()/max().
    return "[empty]"
  if f is None:
    f = "{:.3f}"
  if with_median:
    median = (" median: " + f).format(np.median(arr))
  else:
    median = ""
  # `size` counts every element, consistent with the stats above; it also
  # works for 0-d arrays where len() fails.
  count = " n: {:,}".format(arr.size) if with_count else ""
  pm = "+/-"  # this one doesn't work: u' \u00B1'
  if arr.dtype.kind in ["i", "u"]:
    # Integer data: print min/max without decimals.
    return (f + pm + f + " [{}..{}]").format(arr.mean(), arr.std(), arr.min(),
                                             arr.max()) + median + count
  return (f + pm + f + " [" + f + ".." + f + "]").format(
      arr.mean(), arr.std(), arr.min(), arr.max()) + median + count

## Inspect a specific dataset

Attention: many datasets take a long time to build and may require multiple days to download. So2Sat must be set up manually.

In [0]:
# Select the dataset to inspect.
POSSIBLE_DATASET_NAMES = [
    "bigearthnet",
    "eurosat",
    "resisc45",
    "so2sat",
    "uc_merced",
]
DATASET_NAME = "uc_merced"

# Loading triggers downloading and preparing the dataset; the trailing `;`
# suppresses the cell output.
tfds.load(DATASET_NAME);

# Show a batch of example images with their labels.
dataset = TfdsDataset(DATASET_NAME)
visualize(dataset)

## Dump the splits used in a dataset

In [0]:
# Write one text file per split, listing the identifier of each example
# (one per line) in dataset order.
dataset = TfdsDataset(DATASET_NAME)
for split in ("train", "val", "test"):
  output = f"/tmp/{dataset.name}-{split}.txt"
  filenames = dataset.get_filenames(split)
  with open(output, mode="wb") as f:
    f.write(b"\n".join(filenames))
  print(f"Wrote: {output}")