In [None]:
#| default_exp layers.skip_connection

In [None]:
#| export
import tensorflow as tf
from tensorflow.keras import layers

# Skip connection

> Easily composable skip connection layer. Skip connections help gradients flow through the network during training and are used in many modern architectures.

Introducing skip connections in a Keras model usually implies moving away from the `Sequential` model, but a custom `SkipConnection` layer lets us add them while still using the easy-to-use `Sequential` model.

In [None]:
#| exporti 
class Identity(layers.Layer):
    """Pass-through layer: `call` hands back exactly what it receives."""

    def __init__(self,
                 **kwargs, # Key-word arguments forwarded untouched to `layers.Layer`.
                 ):
        super().__init__(**kwargs)

    def build(self,
              input_shape, # Input shape (unused, the layer owns no weights).
              ):
        """Nothing to create for an identity mapping."""

    def call(self,
             inputs, # Layer inputs.
             ):
        """Return `inputs` unchanged."""
        return inputs

In [None]:
#| export
class SkipConnection(layers.Layer):
    """Skip connection layer to easily introduce this architecture without moving away from the `Sequential` model.

    The same input is fed to `main_path` and to `skip_path`; the two results are then merged
    with `layers.Add` (`how="add"`) or `layers.Concatenate` (`how="concat"`).
    """

    def __init__(self,
                 main_path, # Layer (or set of layers) to apply to the input through the main path.
                 skip_path=None, # Layer (or set of layers) to apply to the input through the skip path. `None` means an identity mapping.
                 how="add", # How to combine the two paths. Can be either `"add"` or `"concat"`.
                 **kwargs, # Key-word arguments to be passed to the base constructor.
                 ):
        # Fail fast on an unknown mode: falling back to concatenation for anything
        # that is not exactly "add" would hide typos such as "Add" or "sum".
        if how not in ("add", "concat"):
            raise ValueError(f'`how` must be either "add" or "concat", got {how!r}.')
        super(SkipConnection, self).__init__(**kwargs)
        self.main_path = main_path
        self.skip_path = Identity() if skip_path is None else skip_path
        self.how = how
        self.combine = layers.Add() if how == "add" else layers.Concatenate()

    def build(self,
              input_shape, # Input shape.
              ):
        # Both branches receive the very same input, so they are built with the same shape.
        self.main_path.build(input_shape)
        self.skip_path.build(input_shape)
        # Let the base class register this layer as built.
        super(SkipConnection, self).build(input_shape)

    def call(self,
             inputs, # Layer inputs.
             ):
        main_path = self.main_path(inputs)
        skip_path = self.skip_path(inputs)
        # Merge order is [main, skip]; this fixes the channel order when concatenating.
        return self.combine([main_path, skip_path])

In [None]:
# "add" mode: the main path must end with as many units as its input (30)
# so that both branches can be summed element-wise.
input_projection = layers.Dense(30, input_shape=(50,))
residual_block = SkipConnection(
    main_path=tf.keras.Sequential([layers.Dense(15), layers.Dense(30)]),
)
model = tf.keras.Sequential([input_projection, residual_block])
assert model.output_shape[-1] == 30
model.summary()

Model: "sequential_21"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_29 (Dense)            (None, 30)                1530      
                                                                 
 skip_connection_9 (SkipConn  (None, 30)               945       
 ection)                                                         
                                                                 
Total params: 2,475
Trainable params: 2,475
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Run a random batch through the model; only the output shape is checked.
sample_input = tf.random.normal(shape=(32, 50))
sample_output = model.predict(x=sample_input, verbose=0)
assert sample_output.shape == (32, 30)

In [None]:
# "concat" mode: the 30 features of the main path are stacked with the
# 30 features of the skip path, giving 60 output features.
input_projection = layers.Dense(30, input_shape=(50,))
concat_block = SkipConnection(
    main_path=tf.keras.Sequential([layers.Dense(15), layers.Dense(30)]),
    how="concat",
)
model = tf.keras.Sequential([input_projection, concat_block])
assert model.output_shape[-1] == 60
model.summary()

Model: "sequential_23"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_32 (Dense)            (None, 30)                1530      
                                                                 
 skip_connection_10 (SkipCon  (None, 60)               945       
 nection)                                                        
                                                                 
Total params: 2,475
Trainable params: 2,475
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Run a random batch through the concatenating model; only the output shape is checked.
sample_input = tf.random.normal(shape=(32, 50))
sample_output = model.predict(x=sample_input, verbose=0)
assert sample_output.shape == (32, 60)