# Binary Early-Exit Network for Adaptive Inference on Low-Resource Devices 
### Interspeech 2022 
----------------------------------- 
## Abstract
Deep neural networks have significantly improved performance
on a range of tasks with the increasing demand for computational resources, leaving deployment on low-resource devices
(with limited memory and battery power) infeasible. Binary
neural networks (BNNs) tackle the issue to an extent with extreme compression and speed-up gains compared to real-valued
models. We propose a simple but effective method to accelerate
inference through unifying BNNs with an early-exiting strategy.
Our approach allows simple instances to exit early based on a
decision threshold and utilizes output layers added to different
intermediate layers to avoid executing the entire binary model.
We extensively evaluate our method on three audio classification tasks and across four BNN architectures. Our method
demonstrates favorable quality-efficiency trade-offs while being
controllable with an entropy-based threshold specified by the
system user. It also results in better speed-ups (latency less than
6ms) with a single model based on existing BNN architectures
without retraining for different efficiency levels. It also provides
a straightforward way to estimate sample difficulty and better
understanding of uncertainty around certain classes within the
dataset.

## Paper
https://arxiv.org/pdf/2206.09029.pdf

----------------------------------- 

**Toy example for training a tiny early-exit BNN model on the SpeechCommands dataset.**

### Imports and Utils

In [None]:
!pip install larq
!pip install pydub
import tensorflow as tf
import tensorflow_datasets as tfds
import larq as lq
import numpy as np
from sklearn.metrics import accuracy_score

In [None]:
def generate_spectrogram(waveform,
                         sample_rate=16000, frame_length=400, frame_step=160,
                         fft_length=1024, n_mels=64, fmin=60.0, fmax=7800.0):
    """Turn a waveform into a log-mel spectrogram.

    Args:
        waveform: float waveform tensor; framing happens along its last axis.
        sample_rate: sampling rate in Hz, used to place the mel filters.
        frame_length: STFT window length in samples.
        frame_step: hop between consecutive STFT frames in samples.
        fft_length: FFT size used by the STFT.
        n_mels: number of mel bands in the output.
        fmin: lower edge of the mel filterbank in Hz.
        fmax: upper edge of the mel filterbank in Hz.

    Returns:
        Tensor of shape ``waveform.shape[:-1] + (num_frames, n_mels)`` holding
        ``log(mel_energy + 1e-6)``.
    """
    # Magnitude of the short-time Fourier transform.
    magnitude = tf.abs(tf.signal.stft(
        waveform,
        frame_length=frame_length,
        frame_step=frame_step,
        fft_length=fft_length,
    ))

    # Project the linear-frequency bins onto a mel filterbank.
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=n_mels,
        num_spectrogram_bins=magnitude.shape[-1],
        sample_rate=sample_rate,
        lower_edge_hertz=fmin,
        upper_edge_hertz=fmax,
    )
    mel = tf.tensordot(magnitude, mel_matrix, axes=1)
    # tensordot loses the static shape; restore it so downstream layers see it.
    mel.set_shape(magnitude.shape[:-1].concatenate(mel_matrix.shape[-1:]))

    # Small offset keeps the log finite on silent frames.
    return tf.math.log(mel + 1e-6)

def prepare_example(waveform, label, sequence_length=16000):
    """Convert one (waveform, label) pair into (log-mel spectrogram, label).

    The waveform is scaled by ``tf.int16.max`` (it is presumably int16 PCM
    audio -- confirm against the dataset), then forced to exactly
    ``sequence_length`` samples: longer clips are truncated and shorter clips
    are zero-padded symmetrically. A fixed length guarantees every example
    yields the same spectrogram shape, which ``Dataset.batch`` requires.

    Args:
        waveform: 1-D tensor of audio samples.
        label: class label, passed through unchanged.
        sequence_length: target number of samples.

    Returns:
        Tuple ``(spectrogram, label)`` where ``spectrogram`` is the output of
        ``generate_spectrogram`` with a trailing channel axis of size 1.
    """
    waveform = tf.cast(waveform, tf.float32) / float(tf.int16.max)
    # Clip over-long clips; without this, batching fails on ragged shapes.
    waveform = waveform[:sequence_length]
    padding = tf.maximum(sequence_length - tf.shape(waveform)[0], 0)
    left_pad = padding // 2
    right_pad = padding - left_pad
    waveform = tf.pad(waveform, paddings=[[left_pad, right_pad]])
    log_mel_spectrogram = generate_spectrogram(waveform)
    # Add a channel axis so the result can feed a Conv2D stack.
    return log_mel_spectrogram[..., tf.newaxis], label

### Preparing Dataset (SpeechCommands)

In [None]:
# Input-pipeline settings.
batch_size = 128
test_batch_size = 1  # EarlyExitModel inference asserts a batch size of one.
shuffle_buffer = 1024
autotune = tf.data.AUTOTUNE

(ds_train, ds_test), ds_info = tfds.load(
    "speech_commands",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
num_classes = ds_info.features["label"].num_classes

# Training pipeline: shuffle raw clips, featurize, batch, prefetch.
ds_train = (
    ds_train
    .shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    .map(prepare_example, num_parallel_calls=autotune)
    .batch(batch_size)
    .prefetch(autotune)
)

# Test pipeline: no shuffling, batches of test_batch_size.
ds_test = (
    ds_test
    .map(prepare_example, num_parallel_calls=autotune)
    .batch(test_batch_size)
    .prefetch(autotune)
)

### Early-Exit Model 

In [None]:
def conv_block(
    input_shape,
    num_features,
    add_max_pool=True,
    name=None,
    input_quantizer=None,
    kernel_quantizer=None,
    kernel_constraint=None):
  """Build one convolutional stage as a standalone Keras model.

  Two parallel branches (a 1x3 and a 3x1 QuantConv2D, each followed by
  BatchNorm + ReLU) are concatenated on the channel axis. The result is fused
  by a 1x1 QuantConv2D, optionally 2x2 max-pooled, then BatchNorm + ReLU.

  NOTE: Larq's QuantConv2D applies no quantization when its quantizer
  arguments are None (the defaults here), so by default these convolutions
  are full precision. Pass e.g. ``input_quantizer="ste_sign"``,
  ``kernel_quantizer="ste_sign"``, ``kernel_constraint="weight_clip"`` to
  binarize them. NOTE(review): the ReLU activations are non-negative, so a
  sign-based input quantizer after them would map almost everything to +1.
  Truly binary input activations likely need the block redesigned.

  Args:
    input_shape: shape of one input example, excluding the batch axis.
    num_features: number of filters in every convolution of the block.
    add_max_pool: if True, apply 2x2 max pooling after the 1x1 fusion conv.
    name: suffix for the model name (``conv_block_<name>``). If None, Keras
      picks a name.
    input_quantizer: forwarded to every QuantConv2D.
    kernel_quantizer: forwarded to every QuantConv2D.
    kernel_constraint: forwarded to every QuantConv2D.

  Returns:
    A ``tf.keras.Model`` mapping the block input to its output.
  """
  quant_kwargs = dict(
      input_quantizer=input_quantizer,
      kernel_quantizer=kernel_quantizer,
      kernel_constraint=kernel_constraint,
  )
  inp = tf.keras.layers.Input(shape=input_shape)

  # Branch with a 1x3 kernel (spans the second spatial axis).
  x_t = lq.layers.QuantConv2D(
      num_features, (1, 3), padding="same", **quant_kwargs)(inp)
  x_t = tf.keras.layers.BatchNormalization()(x_t)
  x_t = tf.keras.layers.Activation("relu")(x_t)
  # Branch with a 3x1 kernel (spans the first spatial axis).
  x_f = lq.layers.QuantConv2D(
      num_features, (3, 1), padding="same", **quant_kwargs)(inp)
  x_f = tf.keras.layers.BatchNormalization()(x_f)
  x_f = tf.keras.layers.Activation("relu")(x_f)

  # Fuse both branches back down to num_features channels.
  x = tf.keras.layers.Concatenate(axis=-1)([x_t, x_f])
  x = lq.layers.QuantConv2D(
      num_features, (1, 1), padding="same", **quant_kwargs)(x)
  if add_max_pool:
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
  x = tf.keras.layers.BatchNormalization()(x)
  out = tf.keras.layers.Activation("relu")(x)

  model_name = f"conv_block_{name}" if name is not None else None
  return tf.keras.Model(inp, out, name=model_name)

class ExitLayer(tf.keras.layers.Layer):
  """Classifier head attached after one conv block (a single early exit).

  Applies global max pooling over the two spatial axes, then a QuantDense
  layer that outputs unnormalized logits. With no quantizer arguments,
  Larq's QuantDense behaves as a full-precision dense layer.
  """

  def __init__(self, num_classes, **kwargs):
    # Forward standard Layer kwargs (e.g. ``name``) to the base class.
    super().__init__(**kwargs)
    self.classifier = tf.keras.Sequential([
      tf.keras.layers.GlobalMaxPooling2D(),
      lq.layers.QuantDense(num_classes)
    ])

  def call(self, x, training=False):
    """Return logits of shape (batch, num_classes) for feature map ``x``."""
    # Propagate the training flag so any training-dependent layer added to
    # the head (BatchNorm, Dropout) behaves correctly.
    return self.classifier(x, training=training)

class EarlyExitModel(tf.keras.models.Model):
  """Stack of conv blocks with a classifier exit after each block.

  Training mode: every block and every exit runs, and the list of per-exit
  logits is returned. Keras then applies the compiled loss to each exit
  output.

  Inference mode (batch size one only): blocks run in order, and the model
  stops at the first exit whose maximum softmax probability is at least
  ``exit_threshold``. If no exit is confident enough, the last exit's
  prediction is used.
  """

  def __init__(self,
    num_classes,
    input_shape=(None, 64, 1),
    exit_threshold=0.85):
    super().__init__()

    # Each block halves both spatial axes (2x2 max pool), so the declared
    # input shapes of later blocks follow from the 64-bin input.
    self.conv_blocks = [
      conv_block(input_shape, 32, name="conv_block1"),
      conv_block((None, 32, 32), 64, name="conv_block2"),
      conv_block((None, 16, 64), 128, name="conv_block3"),
      conv_block((None, 8, 128), 256, name="conv_block4")
    ]
    self.exit_layers = [ExitLayer(num_classes) for _ in range(len(self.conv_blocks))]
    self.exit_threshold = exit_threshold

  def call(self, x, training=False):
    """Run the network.

    Returns:
      Training: list with one logits tensor per exit.
      Inference: tuple ``(predicted_class, exit_id)``. ``exit_id`` is the
      1-based index of the exit that produced the prediction and is at most
      ``len(self.conv_blocks)``.
    """
    if not training:
      assert tf.shape(x)[0] == 1, "Inference mode only supports batch size one."

    if training:
      outputs = []
      for block, exit_layer in zip(self.conv_blocks, self.exit_layers):
        x = block(x, training=training)
        outputs.append(exit_layer(x, training=training))
      return outputs

    for exit_id, (block, exit_layer) in enumerate(
        zip(self.conv_blocks, self.exit_layers), start=1):
      x = block(x, training=training)
      output_sm = tf.nn.softmax(exit_layer(x, training=training))
      score = tf.reduce_max(output_sm)
      output_pred = tf.math.argmax(output_sm, axis=-1)
      if score >= self.exit_threshold:
        return output_pred, exit_id
    # No exit reached the threshold: report the last exit, which is the one
    # whose prediction is returned (previously this was off by one).
    return output_pred, exit_id

### Training and Evaluation

In [None]:
# Build the early-exit model; in training mode it returns every exit's logits.
ee_model = EarlyExitModel(num_classes=num_classes)
ee_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    # Exit heads emit raw logits, hence from_logits=True.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
ee_model.fit(ds_train, epochs=50, verbose=2)

In [None]:
# Evaluate one test example at a time (inference mode requires batch size
# one), recording the prediction and the 1-based exit that produced it.
y_true, y_pred, exits = [], [], []
for x, y in ds_test:
  y_true.append(y.numpy())
  pred, exit_id = ee_model(x, training=False)
  y_pred.append(pred.numpy())
  exits.append(exit_id)
y_true = np.array(y_true).flatten()
y_pred = np.array(y_pred).flatten()
exits = np.array(exits)
accuracy = accuracy_score(y_true, y_pred)
# Mean/std of the exit index: lower means more samples left the network early.
avg_exit = np.mean(exits)
std_exit = np.std(exits)
print(f"Test set accuracy: {accuracy} | Average exit: {avg_exit} ± {std_exit}")

------------------------------------------------------