In [1]:
import tensorflow as tf
import tensorflow_hub as hub


In [2]:
import tifffile

In [3]:
def get_lulc_class(path):
  splits = path.split('/')
  return splits[-2]

In [4]:
def multispectral_to_rgb(raster, optical_maximum = 2000):

  r = raster[:, :, 3]
  g = raster[:, :, 2]
  b = raster[:, :, 1]

  rgb_raster = np.stack([r, g, b], axis=2)

  #cast to uint and scale to 0/255

  rgb_raster = rgb_raster/optical_maximum
  rgb_raster = np.around(rgb_raster*255)
  rgb_raster = np.clip(rgb_raster, 0, 255).astype(int)
  return rgb_raster

In [5]:
def rescale_image(raster):
  raster = np.nan_to_num(raster)
  max_val = np.nanmax(raster)
  mid_val = max_val/2
  rescaled = np.nan_to_num((raster-mid_val)/(mid_val))
  return np.clip(rescaled, -1, 1)

In [6]:
def rio_to_channels_last(raster):
  return raster.transpose((1, 2, 0))

In [7]:
def get_array(path):
  _r = tifffile.imread(path)
  #arr = rio_to_channels_last(_r)
  return _r

In [8]:
def get_image_paths(top_level_path):
  ecoregion_folders = glob.glob(top_level_path+'/*')
  img_paths = []
  for ec_dir in ecoregion_folders:
    img_paths += glob.glob(ec_dir+'*/*.tif')
  return img_paths


In [9]:
class latamSatGenerator():
  def __init__(self, top_level_path, batch_size=32, seed=1):
    random.seed(seed)
    print('generating paths')
    self.top_level_path = top_level_path
    img_paths = get_image_paths(self.top_level_path)
    random.shuffle(img_paths)
    self.img_paths = img_paths
    self.batch_size = batch_size
    self.num_batches = (len(img_paths) // batch_size ) - 1
    self.img_classes = np.unique(np.array([get_lulc_class(i) for i in self.img_paths]))

  def random_image_generator(self, supervised=True, rgb=True, normalise=False, numpy=False, split=[0, 100], one_hot=True):
    #get image paths
    img_path_len = len(self.img_paths)
    img_path_pct = np.floor(img_path_len/100)
    
    img_paths = self.img_paths[int(np.floor(split[0]*img_path_pct)) : int(np.floor(split[1]*img_path_pct))]
    
    batch_size = self.batch_size

    #randomly sample a batch
    num_batches = self.num_batches#(len(img_paths) // batch_size) - 1
    #get arrays according to params
    for _b in range(num_batches):
      arrays = []
      classifications = []
      batch = img_paths[_b*batch_size:(_b+1)*batch_size]
      for img in batch:
        _arr = get_array(img)
        if rgb:
          _arr = multispectral_to_rgb(_arr)

        if normalise:
          _arr = rescale_image(_arr)


        arrays.append(_arr)
        if supervised:
          class_idx = get_lulc_class(img)
          if one_hot:
              one_hot_class = tf.one_hot(np.where(dataset.img_classes == class_idx)[0][0], len(dataset.img_classes))
              classifications.append(one_hot_class)
          else:
              classifications.append(class_idx)
              


      if supervised:
        yield np.squeeze(np.array(arrays)), np.squeeze(np.array(classifications))
      else:
        yield np.squeeze(np.array(arrays))


  def make_tf_dataset(self, rgb=True, supervised=True, normalise=False, split=[0, 100]):
    if rgb:
      if normalise:
        img_sig = tf.TensorSpec(shape=(64,64,3), dtype=tf.float32)
      else:
        img_sig = tf.TensorSpec(shape=(64,64,3), dtype=tf.int32)

    else:
      if normalise:
        img_sig = tf.TensorSpec(shape=(64,64,13), dtype=tf.float32)
      else:
        img_sig = tf.TensorSpec(shape=(64,64,13), dtype=tf.int32)


    if supervised:
      class_sig = tf.TensorSpec(shape=(len(self.img_classes)), dtype=tf.float32)
      output_sig = (img_sig, class_sig)
    else:
      output_sig = (img_sig)


    img_dataset = tf.data.Dataset.from_generator(lambda: self.random_image_generator(rgb=rgb, supervised=supervised, normalise=normalise, split=split),
                                                 output_signature=output_sig)
    return img_dataset










In [10]:
def prepare_for_training(ds, cache=True, batch_size=32, shuffle_buffer_size=1000):

  # shuffle the dataset
  #ds = ds.shuffle(buffer_size=shuffle_buffer_size)
  # Repeat forever
  #ds = ds.repeat()
  # split to batches
  ds = ds.batch(batch_size)
  # `prefetch` lets the dataset fetch batches in the background while the model
  # is training.
  ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  return ds

In [11]:
import glob
import random
import numpy as np

In [12]:
ecoregion_paths = 'latamSatData/DownloadedDataset/*'


In [13]:

dataset = latamSatGenerator(ecoregion_paths, batch_size=1)
#train_img_generator = dataset.random_image_generator(normalise=True, rgb=True, split=[0,60])
#test_img_generator = dataset.random_image_generator(normalise=True, rgb=True, split=[61,90])
#val_img_generator = dataset.random_image_generator(normalise=True, rgb=True,split=[91,100])
train_img_generator = dataset.make_tf_dataset(normalise=True, supervised=True, split=[0,60])
test_img_generator = dataset.make_tf_dataset(normalise=True, split=[61,70])
val_img_generator = dataset.make_tf_dataset(normalise=True, split=[71,99])

generating paths


2023-09-18 17:20:01.706199: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Pro
2023-09-18 17:20:01.706287: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2023-09-18 17:20:01.706300: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2023-09-18 17:20:01.706724: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-09-18 17:20:01.706942: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [14]:
ds_batch_size = 32

In [15]:
def prepare_for_training(dataset, batch_size=32, cache='CachedDataset.cache'):
    dataset = dataset.batch(batch_size)


    #dataset = dataset.shuffle(64)
    dataset = dataset.repeat()

    #dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

In [16]:
val_img_generator = prepare_for_training(val_img_generator, cache='valCahce.cache')


In [17]:
train_img_generator = prepare_for_training(train_img_generator, cache='trainCache.cache')


In [18]:
test_img_generator = prepare_for_training(test_img_generator, cache='testCache.cache')

testreturn = next(val_img_generator.as_numpy_iterator())
testreturn[1].shape

(32, 19)

In [19]:
num_channels = testreturn[0].shape[-1]

In [20]:
import matplotlib.pyplot as plt

In [21]:
num_channels

3

In [22]:
img_path_len = len(dataset.img_paths)
img_path_pct = np.floor(img_path_len/100)


In [23]:
num_epochs = 20

In [24]:
steps_epoch = (img_path_pct*60 // ds_batch_size) - 1 

In [25]:
steps_epoch

5882.0

In [26]:
steps_validation = (img_path_pct*10 // ds_batch_size) -1

In [27]:
steps_validation

979.0

In [28]:
#train_ds = prepare_for_training(img_generator, batch_size=64)

In [29]:
model_url = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/feature_vector/2"

# download & load the layer as a feature vector
keras_layer = hub.KerasLayer(model_url, output_shape=[1280], trainable=True)

In [30]:
m = tf.keras.Sequential([
  keras_layer,
  tf.keras.layers.Dense(len(dataset.img_classes), activation="softmax")
])
# build the model with input image shape as (64, 64, 13)
m.build([None, 64, 64, num_channels])
m.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
    run_eagerly=False
)

In [31]:
import os

In [35]:
model_name = "satellite-classification"
model_path = os.path.join( model_name + ".h5")
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=True, verbose=1)

In [36]:
m.load_weights(model_path)

In [40]:
# train the model
m_history = m.fit(
    train_img_generator,
    validation_data = test_img_generator,
    verbose=1, epochs=num_epochs,
    steps_per_epoch=steps_epoch,
    validation_steps = 25,

    callbacks=[model_checkpoint]
)

Epoch 1/20
   2/5882 [..............................] - ETA: 3:17:27 - loss: 1.2323 - accuracy: 0.5781

  rescaled = np.nan_to_num((raster-mid_val)/(mid_val))




KeyboardInterrupt: 

In [41]:
m.save_weights(model_path)