In [53]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_addons as tfa


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [32]:
import tifffile

In [2]:
def get_lulc_class(path):
  """Return the name of the directory that directly contains ``path``.

  The path is split on '/' and the second-to-last component is returned;
  the generator class uses that component as the image's class label.
  Raises IndexError when ``path`` contains no '/'.
  """
  # Only the last two separators matter, so split from the right.
  tail_parts = path.rsplit('/', 2)
  return tail_parts[-2]

In [3]:
def multispectral_to_rgb(raster, optical_maximum = 2000):
  """Build an integer RGB image from a channels-last multispectral raster.

  Channels 3, 2 and 1 of the last axis are used as red, green and blue.
  Values are divided by ``optical_maximum``, scaled to 0-255, rounded,
  clipped to [0, 255] and cast to ``int``.
  """
  # One fancy index selects the three bands in R, G, B order.
  rgb_bands = raster[:, :, [3, 2, 1]]

  scaled = np.around((rgb_bands / optical_maximum) * 255)
  return np.clip(scaled, 0, 255).astype(int)

In [4]:
def rescale_image(raster):
  """Rescale ``raster`` to the range [-1, 1] relative to its own maximum.

  The array maximum maps to 1, half the maximum to 0 and zero to -1; the
  result is clipped to [-1, 1].

  An array whose maximum is 0 (e.g. an all-zero tile) would otherwise
  divide by zero and fill the output with NaN. In that case every pixel is
  returned as -1.0, which is the value 0 maps to for any positive maximum.
  """
  max_val = np.max(raster)
  if max_val == 0:
    # Guard against 0/0 -> NaN, which np.clip would pass straight through.
    return np.full(np.shape(raster), -1.0)
  mid_val = max_val/2
  rescaled = (raster-mid_val)/(mid_val)
  return np.clip(rescaled, -1, 1)

In [5]:
def rio_to_channels_last(raster):
  """Move the leading axis of a 3-D array to the end.

  Axis order (0, 1, 2) becomes (1, 2, 0), so an array laid out as
  (channels, rows, cols) comes back as (rows, cols, channels).
  """
  axis_order = (1, 2, 0)
  return raster.transpose(axis_order)

In [6]:
def get_array(path):
  """Read the TIFF at ``path`` and return it with its first axis moved last.

  The file is loaded with ``tifffile.imread`` and passed through
  ``rio_to_channels_last``.
  """
  return rio_to_channels_last(tifffile.imread(path))

In [7]:
def get_image_paths(top_level_path):
  """Collect every ``.tif`` file one directory level below ``top_level_path``.

  ``top_level_path`` may itself contain glob wildcards. Each entry matching
  ``top_level_path/*`` is treated as a folder and its ``*.tif`` files are
  gathered.

  Returns:
    A list of image paths, grouped by folder; folders and the files inside
    each folder are in sorted order so the result is reproducible.
  """
  img_paths = []
  for folder in sorted(glob.glob(top_level_path+'/*')):
    # The '/' matters: without it the pattern becomes '<folder>*/*.tif', which
    # also matches sibling folders sharing the prefix and lists them twice.
    img_paths += sorted(glob.glob(folder+'/*.tif'))
  return img_paths


In [8]:
class latamSatGenerator():
  """Batch generator over a directory tree of image tiles.

  Attributes:
    top_level_path: glob-style root handed to ``get_image_paths``.
    img_paths: shuffled list of every image path that was found.
    batch_size: number of images per yielded batch.
    img_classes: sorted array of the distinct class names, taken from each
      image's parent directory via ``get_lulc_class``.
  """

  def __init__(self, top_level_path, batch_size=32):
    """Index and shuffle the image paths and derive the class list."""
    print('generating paths')
    self.top_level_path = top_level_path
    img_paths = get_image_paths(self.top_level_path)
    # Shuffled once here (unseeded); the generator then walks this fixed order.
    random.shuffle(img_paths)
    self.img_paths = img_paths
    self.batch_size = batch_size
    self.img_classes = np.unique(np.array([get_lulc_class(i) for i in self.img_paths]))

  def random_image_generator(self, supervised=True, seed=1, rgb=True, normalise=False, numpy=False):
    """Yield consecutive batches of images (and one-hot labels) in one pass.

    Args:
      supervised: when True yield ``(images, labels)``, otherwise only images.
      seed: accepted for interface compatibility; currently unused.
      rgb: convert each raster with ``multispectral_to_rgb``.
      normalise: rescale each image with ``rescale_image``.
      numpy: accepted for interface compatibility; currently unused.

    Yields:
      ``images`` stacked along a new leading batch axis, and, when
      ``supervised``, a float32 ``(batch_size, n_classes)`` one-hot array.
      Only complete batches are produced; a trailing partial batch is dropped.
    """
    img_paths = self.img_paths
    batch_size = self.batch_size

    # Class lookup is built once from this instance's own class list rather
    # than from a notebook-level global, and reused for every image.
    class_to_index = {str(name): idx for idx, name in enumerate(self.img_classes)}
    num_classes = len(class_to_index)

    # Every slice below is a full batch, so all len // batch_size are usable.
    num_batches = len(img_paths) // batch_size
    for _b in range(num_batches):
      arrays = []
      classifications = []
      batch = img_paths[_b*batch_size:(_b+1)*batch_size]
      for img in batch:
        _arr = get_array(img)
        if rgb:
          _arr = multispectral_to_rgb(_arr)

        if normalise:
          _arr = rescale_image(_arr)

        arrays.append(_arr)
        if supervised:
          one_hot = np.zeros(num_classes, dtype=np.float32)
          one_hot[class_to_index[get_lulc_class(img)]] = 1.0
          classifications.append(one_hot)

      # np.stack keeps the batch axis even for batch_size == 1 or one class.
      if supervised:
        yield np.stack(arrays), np.stack(classifications)
      else:
        yield np.stack(arrays)

  def make_tf_dataset(self, rgb=True, supervised=True, normalise=False, seed=1):
    """Wrap ``random_image_generator`` in a ``tf.data.Dataset``.

    Each dataset element is already a whole batch, so the signature uses a
    leading ``None`` dimension; do not batch the result again without
    unbatching it first. Images are declared as 64x64 with 3 channels when
    ``rgb`` is set and 13 otherwise, float32 when ``normalise`` is set and
    int32 otherwise.
    """
    num_channels = 3 if rgb else 13
    img_dtype = tf.float32 if normalise else tf.int32
    img_sig = tf.TensorSpec(shape=(None, 64, 64, num_channels), dtype=img_dtype)

    if supervised:
      class_sig = tf.TensorSpec(shape=(None, len(self.img_classes)), dtype=tf.float32)
      output_sig = (img_sig, class_sig)
    else:
      output_sig = img_sig

    # The lambda gives from_generator a fresh generator each time it iterates.
    img_dataset = tf.data.Dataset.from_generator(lambda: self.random_image_generator(rgb=rgb, supervised=supervised, normalise=normalise, seed=seed),
                                                 output_signature=output_sig)
    return img_dataset










In [9]:
def prepare_for_training(ds, cache=True, batch_size=64, shuffle_buffer_size=1000):
  """Cache (optionally), batch and prefetch a dataset.

  Args:
    ds: dataset object exposing ``cache``, ``batch`` and ``prefetch``.
    cache: falsy to skip caching, a string to cache to that file name, any
      other truthy value to cache with no argument.
    batch_size: number of elements per batch.
    shuffle_buffer_size: currently unused; shuffling and repeating are
      disabled in this function.

  Returns:
    The dataset after caching, batching and prefetching.
  """
  if cache:
    # A string selects a file-backed cache; anything else uses the default.
    cache_args = (cache,) if isinstance(cache, str) else ()
    ds = ds.cache(*cache_args)

  batched = ds.batch(batch_size)
  # Prefetch so batches are prepared in the background while the model trains.
  return batched.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [15]:
import glob
import random
import numpy as np

In [28]:
# Glob-style root of the image tree. get_image_paths appends '/*' to this
# value, so the folders it scans are the matches of
# 'latamSatData/DownloadedDataset/*/*'.
ecoregion_paths = 'latamSatData/DownloadedDataset/*'


In [35]:

# Index the image paths under ecoregion_paths and derive the class list.
dataset = latamSatGenerator(ecoregion_paths, batch_size=64)
# One-shot Python generator of (images, one-hot labels) batches; with the
# defaults plus normalise=True each image is converted to RGB and then
# rescaled to [-1, 1].
img_generator = dataset.random_image_generator(normalise=True)
#img_generator = dataset.make_tf_dataset(normalise=True, supervised=True)

generating paths


In [36]:
# Echo the configured glob root to confirm its value.
ecoregion_paths

'latamSatData/DownloadedDataset/*'

In [37]:
# Pull one batch for inspection. This advances img_generator, so the batch
# taken here is not produced again for later consumers of the same generator.
testxy = next(img_generator)

2023-09-15 12:08:42.243389: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Pro
2023-09-15 12:08:42.243442: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2023-09-15 12:08:42.243456: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2023-09-15 12:08:42.243543: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-09-15 12:08:42.243588: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [43]:
#train_ds = prepare_for_training(img_generator, batch_size=64)

In [49]:
# TF Hub handle of the feature-vector module used as the model backbone.
model_url = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/feature_vector/2"

# download & load the layer as a feature vector
# trainable=True lets the backbone weights be updated during fit.
# NOTE(review): the generator rescales pixels to [-1, 1] and the model is built
# for 64x64 inputs; confirm both against the input range and image size this
# hub module documents.
keras_layer = hub.KerasLayer(model_url, output_shape=[1280], trainable=True)

In [55]:
# Classifier: hub feature extractor followed by a softmax head with one unit
# per class found by the dataset object.
m = tf.keras.Sequential([
  keras_layer,
  tf.keras.layers.Dense(len(dataset.img_classes), activation="softmax")
])
# build the model with input image shape as (64, 64, 3)
m.build([None, 64, 64, 3])
# categorical_crossentropy matches the one-hot labels yielded by the generator.
m.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
    run_eagerly=False
)

In [57]:
import os

In [58]:
model_name = "satellite-classification"
# Checkpoint file, written to the current working directory.
model_path = model_name + ".h5"
# save_best_only compares the monitored metric between epochs. The default
# monitor is "val_loss", which is never logged because fit() is called without
# validation data, so no checkpoint would be written; monitor the training
# loss instead.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="loss", save_best_only=True, verbose=1)

In [None]:
# train the model
# train the model
# img_generator is a plain Python generator that makes a single pass over the
# image paths. NOTE(review): 5 epochs x 500 steps asks for 2500 batches;
# confirm the dataset is large enough, otherwise training stops early when
# the generator is exhausted.
history = m.fit(
    img_generator,
    verbose=1, epochs=5,
    steps_per_epoch=500,
    callbacks=[model_checkpoint]
)

Epoch 1/5


2023-09-15 12:19:08.908994: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


 14/500 [..............................] - ETA: 18:08 - loss: 2.7851 - accuracy: 0.1652