## **Trains model to predict whether a bird is a passerine or non-passerine.**

Uses EfficientNetB4 and the Cornell NABirds dataset, previously converted to TFRecord format and stored on Google Cloud Storage.

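Trains models with a fixed fraction of the training data deliberately mislabeled. The percentage of mislabeled data is set by `CUTOFF` and is split proportionally between the two classes. With 100 folds, `CUTOFF = 40` mislabels roughly 40% of the training images.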

In [None]:
# Authenticate this Colab session with Google Cloud; presumably required so the
# later cells can read the gs://dmorton-cornell TFRecord shards.
from google.colab import auth
auth.authenticate_user()

In [None]:
import numpy as np
import pandas as pd

import cv2
import json 
import os
import zipfile
import matplotlib.pyplot as plt

%tensorflow_version 2.x
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.framework import ops
print(tf.__version__)

# Let tf.data pick parallelism and prefetch sizes dynamically.
AUTO = tf.data.experimental.AUTOTUNE

# Pick a distribution strategy: a TPUStrategy when a TPU can be resolved,
# otherwise TensorFlow's default strategy (CPU or a single GPU).
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # No TPU attached to this runtime.
    tpu = None

if tpu is None:
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

2.6.0


**Load train, validation, and metadata.**

`split_lookup` maps each training file name to one of 100 folds; the split is stratified by passerine/non-passerine label. The folds are labeled 0-99. During training, images in folds below `CUTOFF` are mislabeled, while images in the remaining folds keep their correct label.

In [None]:
GCS_PATH = "gs://dmorton-cornell"

# TFRecord shards for the training and validation splits.
GCS_TRAIN = GCS_PATH + '/tf_train/tfrecord*'
GCS_VALID = GCS_PATH + '/tf_val/tfrecord*'
train_files = tf.io.gfile.glob(GCS_TRAIN)
valid_files = tf.io.gfile.glob(GCS_VALID)

# Per-image metadata; the 'file' column is used as the lookup key and 'split'
# holds the fold index for that training image.
train_meta = pd.read_csv("drive/MyDrive/Birds/train_meta.csv", index_col=0)

# Maps a training filename to its fold.
# NOTE(review): filenames missing from train_meta fall back to fold 0, which is
# below any positive CUTOFF, so such images would always be mislabeled.
split_lookup = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=train_meta['file'].values,
        values=train_meta['split'].values,
    ),
    default_value=0,
)

# **Preprocessing**

Decodes TFRecord.

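Assigns the binary label: 1 (true) for `class_id` 22 (perching birds) and 0 (false) for every other class. For training images in folds below `CUTOFF`, the label is flipped: `class_id` 22 becomes false and all other classes become true. Validation images always keep the correct label.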

Training images are randomly cropped and resized; each crop must cover at least part of the bird's bounding box. Validation images are cropped to the bounding box, then resized with padding so that the aspect ratio is preserved.

In [None]:
# Global batch size: 8 examples per replica.
BATCH_SIZE = strategy.num_replicas_in_sync * 8
# NOTE(review): PATH is not referenced anywhere in this notebook section.
PATH = '/content/sample_data/'
# Number of elements held in the training shuffle buffer.
SHUFFLE = 2048
# Training images whose fold index (0-99) is below CUTOFF get flipped labels.
CUTOFF = 40

def decode_jpeg(example, train=False):
  """Parse one serialized NABirds example into (image, label, bbox).

  Args:
    example: scalar string tensor holding one serialized `tf.train.Example`
      with the features listed in `features` below.
    train: if True, apply the label noise controlled by `CUTOFF` and return
      the bounding box in the [1, 1, 4] layout accepted by
      `tf.image.sample_distorted_bounding_box`.

  Returns:
    image: the decoded uint8 JPEG reshaped to (height, width, 3), using the
      height/width stored in the record.
    oh: int8 scalar label, 1 when `class_id` is 22 (perching birds) and 0
      otherwise. For training examples whose fold in `split_lookup` is below
      `CUTOFF`, the label is inverted.
    bbox: [ymin, xmin, ymax, xmax] clipped to [0, 1]. When `train` is True it
      is float32 with shape [1, 1, 4] (normalized coordinates); otherwise it
      is int32 with shape [4] (pixel coordinates).
  """
  features = {
      "filename": tf.io.FixedLenFeature([], tf.string),

      "class_id": tf.io.FixedLenFeature([], tf.int64),
      "class_name": tf.io.FixedLenFeature([], tf.string),

      "name_id": tf.io.FixedLenFeature([], tf.int64),
      "name": tf.io.FixedLenFeature([], tf.string),

      "terminal_id": tf.io.FixedLenFeature([], tf.int64),
      "label_name": tf.io.FixedLenFeature([], tf.string),

      "xmin": tf.io.FixedLenFeature([], tf.float32),
      "ymin": tf.io.FixedLenFeature([], tf.float32),
      "xmax": tf.io.FixedLenFeature([], tf.float32),
      "ymax": tf.io.FixedLenFeature([], tf.float32),

      "image": tf.io.FixedLenFeature([], tf.string),
      "height": tf.io.FixedLenFeature([], tf.int64),
      "width": tf.io.FixedLenFeature([], tf.int64)
  }
  example = tf.io.parse_single_example(example, features)
  height = tf.cast(example['height'], tf.int32)
  width = tf.cast(example['width'], tf.int32)

  decoded = tf.image.decode_jpeg(example['image'], channels=3)
  # NOTE(review): assumes the stored height/width match the encoded JPEG;
  # tf.reshape only errors when the element counts differ.
  image = tf.reshape(decoded, (height, width, 3))

  # Correct label: passerine (class 22) vs. everything else.
  is_passerine = tf.equal(example['class_id'], 22)
  if train:  # Python bool, resolved once at trace time.
    # Folds below CUTOFF receive the opposite label (deliberate label noise).
    # XOR with a boolean flip mask replaces a tensor-valued Python `if`, so
    # the result does not depend on AutoGraph rewriting that branch.
    flip = split_lookup.lookup(example['filename']) < CUTOFF
    is_passerine = tf.math.logical_xor(is_passerine, flip)
  oh = tf.cast(is_passerine, tf.int8)

  bbox = tf.stack([example['ymin'], example['xmin'],
                   example['ymax'], example['xmax']])
  bbox = tf.clip_by_value(bbox, 0.0, 1.0)
  if train:
    # sample_distorted_bounding_box takes boxes shaped [batch, num_boxes, 4].
    return image, oh, tf.reshape(bbox, [1, 1, 4])

  # Validation: convert normalized coordinates to integer pixel offsets.
  scale = tf.cast(tf.stack([height, width, height, width]), tf.float32)
  return image, oh, tf.cast(bbox * scale, tf.int32)

def distorted_bounding_box_crop(image,
                                bbox=tf.constant([0.0, 0.0, 1.0, 1.0],
                                                 dtype=tf.float32,
                                                 shape=[1, 1, 4]),
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=1000):
    """Return a random crop of `image` sampled around `bbox`.

    The crop window comes from `tf.image.sample_distorted_bounding_box` and must
    cover at least `min_object_covered` of the box. Its aspect ratio and area
    fraction are drawn from the given ranges. When no boxes are supplied, the
    whole image is used as the box. The default `bbox` spans the full image.
    """
    crop_offset, crop_size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True,
    )
    return tf.slice(image, crop_offset, crop_size)

def parse_train_tfrecord(example, target_size):
  """Decode a training record into an augmented (image, label) pair.

  The image is randomly cropped around its bounding box, resized to
  `target_size` x `target_size`, and passed through EfficientNet's
  `preprocess_input`. The label comes from `decode_jpeg` and may be flipped
  according to `CUTOFF`.
  """
  raw_image, label, box = decode_jpeg(example, train=True)
  crop = distorted_bounding_box_crop(raw_image, box)
  # The crop size is a runtime tensor, so give the crop a static rank and
  # channel count before resizing.
  crop.set_shape([None, None, 3])
  resized = tf.image.resize(crop, (target_size, target_size))
  return tf.keras.applications.efficientnet.preprocess_input(resized), label

def parse_val_tfrecord(example, target_size):
  """Decode a validation record into a deterministic (image, label) pair.

  The image is cropped to its pixel bounding box and passed through
  EfficientNet's `preprocess_input`. It is then resized to
  `target_size` x `target_size` with padding, which keeps the aspect ratio.
  """
  raw_image, label, box = decode_jpeg(example)
  top, left, bottom, right = tf.unstack(box)
  cropped = tf.image.crop_to_bounding_box(raw_image, top, left,
                                          bottom - top, right - left)
  cropped = tf.keras.applications.efficientnet.preprocess_input(cropped)
  padded = tf.image.resize_with_pad(cropped, target_size, target_size)
  return padded, label
  

def load_dataset(filenames, train=False, shuffle_buffer=None, target_size=380):
  """Build a dataset of (image, label) pairs from TFRecord shards.

  Args:
    filenames: list of TFRecord file paths (local or gs://).
    train: use the training parser (random crop + label noise) instead of the
      validation parser (bounding-box crop + pad).
    shuffle_buffer: if set, shuffle the *serialized* records with this buffer
      size before decoding. Shuffling raw records keeps the buffer small
      (encoded JPEG bytes) instead of holding decoded
      `target_size` x `target_size` x 3 float32 images.
    target_size: side length in pixels of the square output images
      (380 by default).
  """
  # Read from several files at once for throughput.
  records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
  if shuffle_buffer:
    records = records.shuffle(shuffle_buffer)
  parser = parse_train_tfrecord if train else parse_val_tfrecord
  return records.map(lambda record: parser(record, target_size),
                     num_parallel_calls=AUTO)


def get_datasets():
  """Return the (train, validation) input pipelines.

  The training data is shuffled *before* it is repeated. Each pass is
  therefore a fresh permutation, because `shuffle` reshuffles on every
  iteration by default, and no image is seen twice before all others are
  seen once. The stream repeats forever, so `drop_remainder=True` never
  discards data; it only gives batches a static shape.
  """
  train = (load_dataset(train_files, train=True, shuffle_buffer=SHUFFLE)
           .repeat()
           .batch(BATCH_SIZE, drop_remainder=True)
           .prefetch(AUTO))

  val = load_dataset(valid_files).batch(BATCH_SIZE).prefetch(AUTO)
  return train, val


train_ds, val_ds = get_datasets()

**Sets up the learning rate scheduler.**

The learning rate decays exponentially after every batch, so that it is multiplied by 0.94 every four epochs.

In [None]:
LOG_LR = 3
COEFF_LR = 2.56
# Initial learning rate: COEFF_LR * 10**-LOG_LR (2.56e-3).
LR = COEFF_LR * 10**(-LOG_LR)
TRAIN_SIZE = 23929  # number of training images
STEPS_PER_EPOCH = TRAIN_SIZE//BATCH_SIZE


class StepLearningRateScheduler(tf.keras.callbacks.Callback):
  """Decay the optimizer's learning rate slightly before every batch.

  Each batch multiplies the current rate by
  decay_rate ** (1 / (steps_per_epoch * decay_epoch)). After `decay_epoch`
  epochs of `steps_per_epoch` batches, the rate has therefore been multiplied
  by `decay_rate`. This is a smooth exponential decay rather than a staircase.
  """

  def __init__(self, decay_rate=0.94,
                     decay_epoch=4,
                     steps_per_epoch=STEPS_PER_EPOCH,
                     verbose=True):
    # Initialize the base Callback state (model, validation_data and the
    # private flags Keras reads from every callback).
    super().__init__()
    self.decay_rate = decay_rate
    self.decay_epoch = decay_epoch
    self.steps_per_epoch = steps_per_epoch
    self.verbose = verbose

  def schedule(self, batch, lr):
    """Return the learning rate for `batch` given the current rate `lr`."""
    return self.decay_rate ** (1/(self.steps_per_epoch * self.decay_epoch)) * lr

  def on_batch_begin(self, batch, logs=None):
    """Apply one decay step to the optimizer before the batch runs."""
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    try:  # new API
      lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
      lr = self.schedule(batch, lr)
    except TypeError:  # Support for old API for backward compatibility
      # Subclasses may override `schedule(batch)` with a single argument.
      lr = self.schedule(batch)
    if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function '
                       'should be float.')
    if isinstance(lr, tf.Tensor) and not lr.dtype.is_floating:
      raise ValueError('The dtype of Tensor should be float')
    tf.keras.backend.set_value(self.model.optimizer.lr,
                               tf.keras.backend.get_value(lr))

  def on_epoch_end(self, epoch, logs=None):
    """Record the current learning rate in `logs` (and optionally print it)."""
    logs = logs or {}
    logs['lr'] = tf.keras.backend.get_value(self.model.optimizer.lr)
    if self.verbose:
      print('\nEpoch %05d: LearningRateScheduler reducing learning '
            'rate to %s.' % (epoch + 1, logs['lr']))

**Load Model**

Callbacks include CSV logging, model checkpointing (best model only), and the custom learning rate scheduler.

In [None]:
# Checkpoint (and, via the .csv suffix, training log) path for this CUTOFF.
# Uses the same Drive root as the metadata CSV loaded earlier.
best_model_file = f'drive/MyDrive/Birds/enetB4_passerine_{CUTOFF}.h5'
with strategy.scope():
  # ImageNet-pretrained EfficientNetB4 backbone without its classifier head.
  base_model = tf.keras.applications.EfficientNetB4(weights='imagenet',
                                                    include_top=False)
  output = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
  output = tf.keras.layers.Dropout(0.4)(output)
  # Single sigmoid unit: probability that the bird is a passerine.
  output = tf.keras.layers.Dense(1, activation='sigmoid',
                                 name='passerine')(output)
  model = tf.keras.models.Model(base_model.input, outputs=output)
  # LR is the initial learning rate; StepLearningRateScheduler decays it.
  # NOTE(review): epsilon=1 is unusually large; presumably it follows the
  # reference EfficientNet RMSProp setup -- confirm.
  model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=LR,
                                                      momentum=0.9,
                                                      epsilon=1),
                loss=tf.keras.losses.BinaryCrossentropy(),
                metrics=["accuracy"])

  callbacks = [tf.keras.callbacks.CSVLogger(best_model_file.replace('.h5', '.csv')),
               tf.keras.callbacks.ModelCheckpoint(filepath=best_model_file,
                                                  verbose=1,
                                                  save_best_only=True,
                                                  mode="auto"),
               StepLearningRateScheduler(),
               ]

**Train Model**

In [None]:
# Train for 100 epochs of STEPS_PER_EPOCH batches drawn from the infinite
# training stream, evaluating on the full validation set after each epoch.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    initial_epoch=0,
    steps_per_epoch=STEPS_PER_EPOCH,
    callbacks=callbacks,
    verbose=2,
)