**Enformer Model Training Class: A Guide**

This class is designed for training `Enformer` models. We highly recommend starting with the accompanying report, which introduces the key concepts behind Enformer training.

The `Enformer` class follows the principles of object-oriented programming. It is a Sonnet module (`snt.Module`) whose main purpose is to define the model architecture: a shared convolutional and transformer trunk followed by organism-specific output heads. Helper functions in the same code, such as `one_hot_encode`, prepare DNA sequences as model input.

Organizing the model as a class keeps the trunk, the heads, and their shared settings together, which gives a structured way to build, train, and save the model.

For a deeper understanding of the Enformer model and its training process, please refer to the accompanying report. If you have any questions or need further assistance, feel free to reach out.

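This code includes the following imports and global variables:

Imports:
- `inspect`: A standard-library module for inspecting live objects such as modules, classes, methods, and functions. Here it is used to check whether a module's `__call__()` accepts an `is_training` argument.
- `typing`: A standard-library module that provides support for type hints.
- `attention_module`: The companion module from the Enformer source code that provides `MultiheadAttention`, the multi-head attention layer with relative positional encodings used in the transformer blocks.
- `numpy` (`np` alias): A popular library for numerical computing with Python.
- `sonnet` (`snt` alias): DeepMind's neural network library built on top of TensorFlow.
- `tensorflow` (`tf` alias): The deep learning framework the model is implemented in.

Global Variables:
- `SEQUENCE_LENGTH`: An integer constant with the value `196,608`, the length of the input DNA sequence in base pairs.
- `BIN_SIZE`: An integer constant with the value `128`, the number of base pairs summarized by each output position.
- `TARGET_LENGTH`: An integer constant with the value `896`, the number of output positions (bins) the model predicts.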
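The three constants are related by simple arithmetic. The plain-Python sketch below assumes, consistent with the trunk's seven stride-2 pooling steps but not stated outright in the code, that each output position summarizes one `BIN_SIZE`-wide bin of input. It computes how many bins the input spans, how many the final crop removes from each flank, and how wide the predicted central region is.

```python
# Constants copied from the module below.
SEQUENCE_LENGTH = 196_608  # input length in base pairs
BIN_SIZE = 128             # base pairs per output bin (assumed mapping)
TARGET_LENGTH = 896        # output bins kept after the final crop

num_bins = SEQUENCE_LENGTH // BIN_SIZE              # bins spanned by the whole input
trim_per_side = (num_bins - TARGET_LENGTH) // 2     # bins cropped from each flank
predicted_bp = TARGET_LENGTH * BIN_SIZE             # width of the predicted region
context_bp_per_side = trim_per_side * BIN_SIZE      # input seen but not predicted, per side

print(f'{num_bins=}, {trim_per_side=}, {predicted_bp=}, {context_bp_per_side=}')
```

Under this reading, the model sees 40,960 bp of extra context on each side of the 114,688 bp region it makes predictions for.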


In [None]:
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow implementation of Enformer model.

"Effective gene expression prediction from sequence by integrating long-range
interactions"

Žiga Avsec1, Vikram Agarwal2,4, Daniel Visentin1,4, Joseph R. Ledsam1,3,
Agnieszka Grabska-Barwinska1, Kyle R. Taylor1, Yannis Assael1, John Jumper1,
Pushmeet Kohli1, David R. Kelley2*

1 DeepMind, London, UK
2 Calico Life Sciences, South San Francisco, CA, USA
3 Google, Tokyo, Japan
4 These authors contributed equally.
* correspondence: avsec@google.com, pushmeet@google.com, drk@calicolabs.com
"""
import inspect
from typing import Any, Callable, Dict, Optional, Text, Union, Iterable

import attention_module
import numpy as np
import sonnet as snt
import tensorflow as tf

SEQUENCE_LENGTH = 196_608
BIN_SIZE = 128
TARGET_LENGTH = 896

The code defines a class `Enformer` that inherits from `snt.Module` (a Sonnet module) and represents the main model. Here's a breakdown of the code:

- The `__init__()` method initializes the `Enformer` object with several parameters:
  - `channels`: Number of convolutional filters and the overall width of the model (default `1536`).
  - `num_transformer_layers`: Number of transformer layers (default `11`).
  - `num_heads`: Number of attention heads (default `8`); `channels` must be divisible by it.
  - `pooling_type`: The pooling function to use (`'attention'` or `'max'`).
  - `name`: Name of the Sonnet module.

- Inside `__init__()`, several settings are defined for the model:
  - `heads_channels`: A dictionary mapping each organism (`'human'`, `'mouse'`) to the number of output tracks its head predicts (`5313` and `1643`).
  - `dropout_rate`: The dropout rate used in the model (`0.4`).
  - `whole_attention_kwargs`: A dictionary of settings for the attention layers, such as the attention dropout rate, initializer, key size, positional dropout rate, and the relative positional feature functions.

- The code then opens a name scope with `tf.name_scope('trunk')` by calling `__enter__()` directly; the scope is closed again with `__exit__()` once the trunk has been built.

- A helper function `conv_block` is defined, which constructs a convolutional block consisting of cross-replica batch normalization, a GELU activation, and a 1D convolution.

- The `stem` module is defined as a sequential composition using the custom `Sequential` class (defined later in this notebook). It consists of a width-15 1D convolution with `channels // 2` filters, a residual pointwise convolution block, and a pooling operation with `pool_size=2`.

- The `conv_tower` module is a sequential composition of six convolutional tower blocks whose filter counts grow exponentially from `channels // 2` to `channels`, rounded to multiples of 128 by `exponential_linspace_int`. Each block consists of a width-5 convolutional block, a residual pointwise convolution block, and a pooling operation with `pool_size=2`.
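To make the shapes concrete, here is a framework-free sketch that ports `exponential_linspace_int` to plain Python (as the hypothetical `exponential_linspace_int_py`) and traces the sequence length and channel count after the stem and each conv tower block. It assumes the default `channels=1536`, an input of `SEQUENCE_LENGTH` positions, and that convolutions preserve length (Sonnet's `Conv1D` uses `'SAME'` padding by default). It is an illustration, not part of the model code.

```python
import math
from typing import List, Tuple

SEQUENCE_LENGTH = 196_608
CHANNELS = 1536  # default `channels` argument of Enformer


def exponential_linspace_int_py(start: float, end: float, num: int,
                                divisible_by: int = 1) -> List[int]:
    """Plain-Python port of exponential_linspace_int (geometric spacing)."""
    def _round(x: float) -> int:
        # Python's round() and np.round() both round exact halves to even.
        return int(round(x / divisible_by) * divisible_by)

    base = math.exp(math.log(end / start) / (num - 1))  # common ratio
    return [_round(start * base ** i) for i in range(num)]


def trace_conv_stages(length: int, channels: int) -> List[Tuple[str, int, int]]:
    """(stage, length, channels) after the stem and each conv tower block."""
    length //= 2  # stem: Conv1D(channels // 2, 15) -> Residual -> pool(2)
    stages = [('stem', length, channels // 2)]
    filter_list = exponential_linspace_int_py(channels // 2, channels,
                                              num=6, divisible_by=128)
    for i, num_filters in enumerate(filter_list):
        length //= 2  # each block: conv_block(num_filters, 5) -> Residual -> pool(2)
        stages.append((f'conv_tower_block_{i}', length, num_filters))
    return stages


stages = trace_conv_stages(SEQUENCE_LENGTH, CHANNELS)
for name, length, num_channels in stages:
    print(f'{name:>18}: length={length:>6}  channels={num_channels}')
print('total downsampling:', SEQUENCE_LENGTH // stages[-1][1])
```

The seven pooling steps halve the length seven times, so each remaining position summarizes 2^7 = 128 input positions, which matches `BIN_SIZE`. The filter schedule comes out as 768, 896, 1024, 1152, 1280, 1536.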



In [None]:
class Enformer(snt.Module):
  """Main model."""

  def __init__(self,
               channels: int = 1536,
               num_transformer_layers: int = 11,
               num_heads: int = 8,
               pooling_type: str = 'attention',
               name: str = 'enformer'):
    """Enformer model.

    Args:
      channels: Number of convolutional filters and the overall 'width' of the
        model.
      num_transformer_layers: Number of transformer layers.
      num_heads: Number of attention heads.
      pooling_type: Which pooling function to use. Options: 'attention' or 'max'.
      name: Name of sonnet module.
    """
    super().__init__(name=name)
    # pylint: disable=g-complex-comprehension,g-long-lambda,cell-var-from-loop
    heads_channels = {'human': 5313, 'mouse': 1643}
    dropout_rate = 0.4
    assert channels % num_heads == 0, ('channels needs to be divisible '
                                       f'by {num_heads}')
    whole_attention_kwargs = {
        'attention_dropout_rate': 0.05,
        'initializer': None,
        'key_size': 64,
        'num_heads': num_heads,
        'num_relative_position_features': channels // num_heads,
        'positional_dropout_rate': 0.01,
        'relative_position_functions': [
            'positional_features_exponential',
            'positional_features_central_mask',
            'positional_features_gamma'
        ],
        'relative_positions': True,
        'scaling': True,
        'value_size': channels // num_heads,
        'zero_initialize': True
    }

    trunk_name_scope = tf.name_scope('trunk')
    trunk_name_scope.__enter__()
    # lambda is used in Sequential to construct the module under tf.name_scope.
    def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
      return Sequential(lambda: [
          snt.distribute.CrossReplicaBatchNorm(
              create_scale=True,
              create_offset=True,
              scale_init=snt.initializers.Ones(),
              moving_mean=snt.ExponentialMovingAverage(0.9),
              moving_variance=snt.ExponentialMovingAverage(0.9)),
          gelu,
          snt.Conv1D(filters, width, w_init=w_init, **kwargs)
      ], name=name)

    stem = Sequential(lambda: [
        snt.Conv1D(channels // 2, 15),
        Residual(conv_block(channels // 2, 1, name='pointwise_conv_block')),
        pooling_module(pooling_type, pool_size=2),
    ], name='stem')

    filter_list = exponential_linspace_int(start=channels // 2, end=channels,
                                           num=6, divisible_by=128)
    conv_tower = Sequential(lambda: [
        Sequential(lambda: [
            conv_block(num_filters, 5),
            Residual(conv_block(num_filters, 1, name='pointwise_conv_block')),
            pooling_module(pooling_type, pool_size=2),
            ],
                   name=f'conv_tower_block_{i}')
        for i, num_filters in enumerate(filter_list)], name='conv_tower')

Continuing from the previous part:

- The code defines a helper function `transformer_mlp()` that constructs the multi-layer perceptron (MLP) block of each transformer layer: layer normalization, a linear layer that expands to `channels * 2` features, dropout, a ReLU activation, a linear layer that projects back to `channels`, and a final dropout.

- The `transformer` module is a sequential composition of `num_transformer_layers` transformer blocks. Each block contains two residual sub-blocks: first layer normalization followed by multi-head attention and dropout, then the MLP block.

- A `crop_final` module (a `TargetLengthCrop1D`) crops the sequence to the target length `TARGET_LENGTH`.

- The `final_pointwise` module is a sequential composition of a pointwise convolution block with `channels * 2` filters, dropout at one-eighth of the main dropout rate, and a GELU activation.

- The `_trunk` attribute is a sequential composition of the previously defined modules (`stem`, `conv_tower`, `transformer`, `crop_final`, `final_pointwise`).

- The `_heads` attribute is a dictionary, built under the `heads` name scope, that maps each head name (`'human'`, `'mouse'`) to a sequential module consisting of a linear layer and a softplus activation.

- The `trunk` property returns the `_trunk` attribute.

- The `heads` property returns the `_heads` attribute.

This completes the constructor of the `Enformer` class. The trunk and the heads are exposed as properties so that these components can be accessed and reused separately.
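The following bookkeeping sketch (plain Python, no TensorFlow) records the `(length, channels)` shape after each remaining stage for a single example. It assumes the default `channels=1536` and that 1,536 positions reach the transformer, which is what a `SEQUENCE_LENGTH` input yields after the stem and six conv tower poolings; it only restates layer sizes from the code.

```python
from typing import Dict, Tuple

CHANNELS = 1536           # default `channels`
CONV_TOWER_LENGTH = 1536  # positions entering the transformer (assumed default input)
TARGET_LENGTH = 896
HEADS_CHANNELS = {'human': 5313, 'mouse': 1643}


def trunk_tail_shapes(length: int, channels: int) -> Dict[str, Tuple[int, int]]:
    """(length, channels) after each remaining stage, for one example."""
    shapes = {
        # Both residual sub-blocks add their output to their input,
        # so the transformer cannot change the shape.
        'transformer': (length, channels),
        # Inside transformer_mlp the width is temporarily doubled.
        'mlp_hidden': (length, channels * 2),
        'crop_final': (TARGET_LENGTH, channels),
        # conv_block(channels * 2, 1) is a pointwise convolution to twice the width.
        'final_pointwise': (TARGET_LENGTH, channels * 2),
    }
    for head, num_tracks in HEADS_CHANNELS.items():
        # snt.Linear(num_tracks) acts on the last axis; softplus keeps outputs positive.
        shapes[f'head_{head}'] = (TARGET_LENGTH, num_tracks)
    return shapes


for stage, shape in trunk_tail_shapes(CONV_TOWER_LENGTH, CHANNELS).items():
    print(f'{stage:>16}: {shape}')
```

Because the heads only touch the last axis, every head predicts all of its tracks at each of the 896 output positions.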

In [None]:
    # Transformer.
    def transformer_mlp():
      return Sequential(lambda: [
          snt.LayerNorm(axis=-1, create_scale=True, create_offset=True),
          snt.Linear(channels * 2),
          snt.Dropout(dropout_rate),
          tf.nn.relu,
          snt.Linear(channels),
          snt.Dropout(dropout_rate)], name='mlp')

    transformer = Sequential(lambda: [
        Sequential(lambda: [
            Residual(Sequential(lambda: [
                snt.LayerNorm(axis=-1,
                              create_scale=True, create_offset=True,
                              scale_init=snt.initializers.Ones()),
                attention_module.MultiheadAttention(**whole_attention_kwargs,
                                                    name=f'attention_{i}'),
                snt.Dropout(dropout_rate)], name='mha')),
            Residual(transformer_mlp())], name=f'transformer_block_{i}')
        for i in range(num_transformer_layers)], name='transformer')

    crop_final = TargetLengthCrop1D(TARGET_LENGTH, name='target_input')

    final_pointwise = Sequential(lambda: [
        conv_block(channels * 2, 1),
        snt.Dropout(dropout_rate / 8),
        gelu], name='final_pointwise')

    self._trunk = Sequential([stem,
                              conv_tower,
                              transformer,
                              crop_final,
                              final_pointwise],
                             name='trunk')
    trunk_name_scope.__exit__(None, None, None)

    with tf.name_scope('heads'):
      self._heads = {
          head: Sequential(
              lambda: [snt.Linear(num_channels), tf.nn.softplus],
              name=f'head_{head}')
          for head, num_channels in heads_channels.items()
      }
    # pylint: enable=g-complex-comprehension,g-long-lambda,cell-var-from-loop

  @property
  def trunk(self):
    return self._trunk

  @property
  def heads(self):
    return self._heads

The `Enformer` class also defines the following methods:

- The `__call__()` method is the main entry point for the model. It takes an input tensor (`inputs`) and a boolean flag indicating whether the model is in training mode (`is_training`). It first applies the `trunk` module to the inputs to obtain the trunk embedding. Then it applies each `head` module to that embedding and returns a dictionary whose keys are the head names and whose values are the corresponding output tensors.

- The `predict_on_batch()` method is decorated with `@tf.function` and an `input_signature` of shape `[None, SEQUENCE_LENGTH, 4]` (`tf.float32`), so it is traced into a TensorFlow graph that accepts one-hot encoded batches of any size. It takes a batch of inputs (`x`) and calls the model with `is_training=False`.

These methods let you pass inputs to the Enformer model and obtain predictions from the model's heads. The `predict_on_batch()` method is specifically intended for export with `SavedModel`, TensorFlow's serialization format, because a `tf.function` with a fixed input signature can be saved as a concrete, callable graph.
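The head dispatch in `__call__()` is an ordinary dictionary comprehension over a shared embedding. The plain-Python sketch below uses hypothetical toy functions in place of the Sonnet trunk and heads to show that the trunk runs once and every head receives the same embedding; it does not involve TensorFlow or `tf.function`.

```python
from typing import Callable, Dict, List

Vector = List[float]
Module = Callable[[Vector, bool], Vector]  # hypothetical stand-in for a Sonnet module


def make_model(trunk: Module, heads: Dict[str, Module]) -> Callable[..., Dict[str, Vector]]:
    """Mimics Enformer.__call__: run the shared trunk once, then every head."""
    def call(inputs: Vector, is_training: bool) -> Dict[str, Vector]:
        embedding = trunk(inputs, is_training)  # shared by all heads
        return {name: head(embedding, is_training) for name, head in heads.items()}
    return call


def toy_trunk(x: Vector, is_training: bool) -> Vector:
    return [2.0 * v for v in x]


def toy_human_head(e: Vector, is_training: bool) -> Vector:
    return [v + 1.0 for v in e]


def toy_mouse_head(e: Vector, is_training: bool) -> Vector:
    return [10.0 * v for v in e]


model = make_model(toy_trunk, {'human': toy_human_head, 'mouse': toy_mouse_head})
# predict_on_batch does the equivalent of this call: inference mode.
outputs = model([1.0, 2.0], is_training=False)
print(outputs)  # {'human': [3.0, 5.0], 'mouse': [20.0, 40.0]}
```

Computing the embedding once and reusing it is what makes adding another organism cheap: only a new linear head is needed.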

In [None]:
  def __call__(self, inputs: tf.Tensor,
               is_training: bool) -> Dict[str, tf.Tensor]:
    trunk_embedding = self.trunk(inputs, is_training=is_training)
    return {
        head: head_module(trunk_embedding, is_training=is_training)
        for head, head_module in self.heads.items()
    }

  @tf.function(input_signature=[
      tf.TensorSpec([None, SEQUENCE_LENGTH, 4], tf.float32)])
  def predict_on_batch(self, x):
    """Method for SavedModel."""
    return self(x, is_training=False)

The `TargetLengthCrop1D` class is a module responsible for cropping the sequence (length) axis of its input to a desired target length. Here's a breakdown of the class:

- The `__init__()` method initializes the module. It takes the `target_length` parameter, which specifies the desired length of the cropped sequence; if `target_length` is `None`, no cropping is performed. The `name` parameter assigns a name to the module.

- The `__call__()` method is the main entry point for the module. If `target_length` is `None`, the inputs are returned unchanged. Otherwise, it computes `trim = (length - target_length) // 2`, raises a `ValueError` if the input is shorter than the target, and removes `trim` positions from each end of the second-to-last (length) axis, i.e. `inputs[..., trim:-trim, :]`. The cropped tensor is then returned.

The `TargetLengthCrop1D` module is used in the `Enformer` model to crop the sequence after the `transformer` module and before the final pointwise convolution.
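The cropping rule can be checked without TensorFlow. This plain-Python mirror of `TargetLengthCrop1D.__call__()` (the function name `target_length_crop` is hypothetical) operates on a list of positions instead of a tensor's length axis. Assuming 1,536 positions reach the crop, as with the default input length, it removes 320 positions from each end; it also shows that an odd length difference leaves the output one position longer than requested, because `trim` is rounded down.

```python
from typing import Optional, Sequence


def target_length_crop(rows: Sequence, target_length: Optional[int]) -> list:
    """Plain-Python mirror of TargetLengthCrop1D.__call__ along the length axis."""
    if target_length is None:
        return list(rows)
    trim = (len(rows) - target_length) // 2  # positions removed from each end
    if trim < 0:
        raise ValueError('inputs shorter than target length')
    if trim == 0:
        return list(rows)
    return list(rows[trim:-trim])


positions = list(range(1536))
cropped = target_length_crop(positions, 896)
print(len(cropped), cropped[0], cropped[-1])  # 896 320 1215

# An odd length difference leaves one extra position.
print(len(target_length_crop(list(range(11)), 8)))  # 9
```

For Enformer's default shapes the difference (640) is even, so the crop is exact.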

In [None]:
class TargetLengthCrop1D(snt.Module):
  """Crop sequence to match the desired target length."""

  def __init__(self,
               target_length: Optional[int],
               name: str = 'target_length_crop'):
    super().__init__(name=name)
    self._target_length = target_length

  def __call__(self, inputs):
    if self._target_length is None:
      return inputs
    trim = (inputs.shape[-2] - self._target_length) // 2
    if trim < 0:
      raise ValueError('inputs shorter than target length')
    elif trim == 0:
      return inputs
    else:
      return inputs[..., trim:-trim, :]

The `Sequential` class is a custom module for building sequential modules in which the `is_training` flag is automatically passed to the layers that accept it. Here's a breakdown of the class:

- The `__init__()` method initializes the `Sequential` module. Its `layers` parameter can be either an iterable of callables or a callable (typically a `lambda`) that returns such an iterable; wrapping the layers in a `lambda` lets them be constructed inside the module's own name scope. If `layers` is `None`, `self._layers` is set to an empty list. Otherwise, `layers` is first called if it is callable, and every non-`None` element is stored in `self._layers`.

- The `__call__()` method is the main entry point for the module. It takes the input tensor (`inputs`) and the boolean flag `is_training`, which indicates whether the model is in training mode. It iterates over `self._layers`, applying each layer to the current output. If a layer accepts an `is_training` argument (as determined by `accepts_is_training`), the flag is passed along with the tensor; otherwise only the tensor is passed. The final tensor is returned.

The `Sequential` module provides a convenient way to compose layers while automatically handling the `is_training` flag. The Enformer model uses it for the `stem`, `conv_tower`, `transformer`, and `final_pointwise` modules, for the convolution and MLP blocks inside them, for the complete trunk, and for each output head.
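The `is_training` dispatch can be reproduced in a few lines of plain Python. The sketch below uses a hypothetical `MiniSequential` class and toy layers (`ToyDropout`, `add_one`) rather than Sonnet modules. Its `accepts_is_training` inspects the callable itself rather than its `__call__` attribute, a small departure from the notebook's helper that lets plain functions be inspected too.

```python
import inspect
from typing import Any, Callable, Iterable, List, Optional, Union


def accepts_is_training(layer: Callable) -> bool:
    """True if `layer` can be called with an `is_training` argument."""
    return 'is_training' in inspect.signature(layer).parameters


class MiniSequential:
    """Framework-free sketch of the Sequential pattern (hypothetical class)."""

    def __init__(self,
                 layers: Optional[Union[Callable[[], Iterable[Callable]],
                                        Iterable[Callable]]] = None):
        if layers is None:
            self._layers: List[Callable] = []
        else:
            if callable(layers):  # e.g. a lambda that builds the layer list
                layers = layers()
            self._layers = [layer for layer in layers if layer is not None]

    def __call__(self, inputs: Any, is_training: bool) -> Any:
        outputs = inputs
        for layer in self._layers:
            if accepts_is_training(layer):
                outputs = layer(outputs, is_training=is_training)
            else:
                outputs = layer(outputs)
        return outputs


class ToyDropout:
    """Hypothetical layer that zeroes its input only in training mode."""

    def __call__(self, x: float, is_training: bool) -> float:
        return 0.0 if is_training else x


def add_one(x: float) -> float:
    return x + 1.0


model = MiniSequential(lambda: [add_one, ToyDropout(), None, add_one])
print(model(1.0, is_training=False))  # 3.0
print(model(1.0, is_training=True))   # 1.0
```

As in the notebook's `Sequential`, passing a `lambda` defers layer construction until the container's constructor runs, and `None` entries are dropped.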

The `pooling_module` function is a utility that returns a pooling module based on the specified `kind` and `pool_size`. If `kind` is `'attention'`, it returns a `SoftmaxPooling1D` instance with the given `pool_size`, `per_channel=True`, and `w_init_scale=2.0`. If `kind` is `'max'`, it returns a `tf.keras.layers.MaxPool1D` layer with the given `pool_size` and `padding='same'`. Any other value raises a `ValueError`. The Enformer model uses this function to create the pooling layers of the `stem` and the `conv_tower`.

In [None]:
class Sequential(snt.Module):
  """snt.Sequential automatically passing is_training where it exists."""

  def __init__(self,
               layers: Optional[Union[Callable[[], Iterable[snt.Module]],
                                      Iterable[Callable[..., Any]]]] = None,
               name: Optional[Text] = None):
    super().__init__(name=name)
    if layers is None:
      self._layers = []
    else:
      # layers wrapped in a lambda function to have a common namespace.
      if hasattr(layers, '__call__'):
        layers = layers()
      self._layers = [layer for layer in layers if layer is not None]

  def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
    outputs = inputs
    for _, mod in enumerate(self._layers):
      if accepts_is_training(mod):
        outputs = mod(outputs, is_training=is_training, **kwargs)
      else:
        outputs = mod(outputs, **kwargs)
    return outputs


def pooling_module(kind, pool_size):
  """Pooling module wrapper."""
  if kind == 'attention':
    return SoftmaxPooling1D(pool_size=pool_size, per_channel=True,
                            w_init_scale=2.0)
  elif kind == 'max':
    return tf.keras.layers.MaxPool1D(pool_size=pool_size, padding='same')
  else:
    raise ValueError(f'Invalid pooling kind: {kind}.')


The `SoftmaxPooling1D` class is a custom pooling operation that computes a weighted average over each pooling window, with weights given by a softmax over learned logits. Here's a breakdown of the class:

- The `__init__()` method initializes the `SoftmaxPooling1D` module. It takes several parameters:
  - `pool_size`: The pooling size, as in max or average pooling.
  - `per_channel`: A boolean flag indicating whether the softmax weights are computed for each channel separately (`True`) or shared across all channels (`False`).
  - `w_init_scale`: A scaling factor for initializing the weights. With `w_init_scale` of `0.0` the operation is equivalent to average pooling and, per the docstring, with `w_init_scale` around `2.0` and `per_channel=False` it behaves like max pooling.
  - `name`: The name of the module.

- The private `_initialize()` method creates the internal `_logit_linear` layer, an `snt.Linear` module without bias (a softmax is invariant to shifts). Its output size is `num_features` if `per_channel` is `True` and `1` otherwise, and its weights are initialized with an identity initializer scaled by `w_init_scale`. Because it is decorated with `@snt.once`, it runs only on the first call, when the number of input features is known.

- The `__call__()` method is the main entry point for the module. It reshapes the input from `(batch, length, features)` to `(batch, length // pool_size, pool_size, features)` so that each group of `pool_size` adjacent positions forms one window. It computes logits with `_logit_linear`, applies a softmax over the window axis, multiplies the inputs element-wise by these weights, and sums over the window axis. The resulting tensor is returned.

The `SoftmaxPooling1D` module provides a learnable pooling operation with optional per-channel weights. The `Enformer` model uses it as the `'attention'` pooling option in the `stem` and `conv_tower` modules.
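To see how `w_init_scale` moves the operation between average and max pooling, the plain-Python sketch below pools a single channel at initialization. It assumes `per_channel=True`, where the identity-initialized, bias-free `_logit_linear` layer makes each logit equal to `w_init_scale` times the input value; after training the logits are learned, so this only describes the starting point. The function name `softmax_pool_1d` is hypothetical.

```python
import math
from typing import List


def softmax_pool_1d(xs: List[float], pool_size: int, scale: float) -> List[float]:
    """Single-channel softmax pooling with logits = scale * x (initial state)."""
    assert len(xs) % pool_size == 0, 'length must be divisible by pool_size'
    pooled = []
    for start in range(0, len(xs), pool_size):
        window = xs[start:start + pool_size]
        logits = [scale * x for x in window]
        top = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(logit - top) for logit in logits]
        total = sum(exps)
        # Softmax-weighted sum over the window, as in reduce_sum(..., axis=-2).
        pooled.append(sum(x * e / total for x, e in zip(window, exps)))
    return pooled


xs = [1.0, 3.0, -2.0, 0.0]
print(softmax_pool_1d(xs, 2, scale=0.0))   # average pooling: [2.0, -1.0]
print(softmax_pool_1d(xs, 2, scale=2.0))   # weighted toward the larger value
print(softmax_pool_1d(xs, 2, scale=50.0))  # close to max pooling
```

Larger scales sharpen the softmax toward the largest value in each window, so `pooling_module`'s choice of `w_init_scale=2.0` starts the attention pooling between average and max pooling.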

In [None]:
class SoftmaxPooling1D(snt.Module):
  """Pooling operation with optional weights."""

  def __init__(self,
               pool_size: int = 2,
               per_channel: bool = False,
               w_init_scale: float = 0.0,
               name: str = 'softmax_pooling'):
    """Softmax pooling.

    Args:
      pool_size: Pooling size, same as in Max/AvgPooling.
      per_channel: If True, the logits/softmax weights will be computed for
        each channel separately. If False, same weights will be used across all
        channels.
      w_init_scale: When 0.0 is equivalent to avg pooling, and when
        ~2.0 and per_channel=False it's equivalent to max pooling.
      name: Module name.
    """
    super().__init__(name=name)
    self._pool_size = pool_size
    self._per_channel = per_channel
    self._w_init_scale = w_init_scale
    self._logit_linear = None

  @snt.once
  def _initialize(self, num_features):
    self._logit_linear = snt.Linear(
        output_size=num_features if self._per_channel else 1,
        with_bias=False,  # Softmax is agnostic to shifts.
        w_init=snt.initializers.Identity(self._w_init_scale))

  def __call__(self, inputs):
    _, length, num_features = inputs.shape
    self._initialize(num_features)
    inputs = tf.reshape(
        inputs,
        (-1, length // self._pool_size, self._pool_size, num_features))
    return tf.reduce_sum(
        inputs * tf.nn.softmax(self._logit_linear(inputs), axis=-2),
        axis=-2)

The code also provides the following modules and helper functions:

1. `Residual` class: A module representing a residual block. It wraps another Sonnet module and adds the block's input tensor to that module's output (a skip connection).

2. `gelu()` function: Applies the Gaussian Error Linear Unit (GELU) activation function to the input tensor, using the sigmoid approximation `x * sigmoid(1.702 * x)` from section 2 of the original GELU paper.

3. `one_hot_encode()` function: Performs one-hot encoding of a DNA sequence. It takes the sequence as a string, along with optional parameters for the alphabet, the neutral alphabet (by default `'N'`), the value used for neutral characters, and the data type. It returns the encoding as a NumPy array of shape `(len(sequence), len(alphabet))`. Characters in the neutral alphabet are filled with the neutral value, and any other character outside the alphabet (including lowercase bases) is encoded as all zeros.

4. `exponential_linspace_int()` function: Generates a list of `num` integers spaced geometrically (exponentially increasing) from `start` to `end`, each rounded to the nearest multiple of `divisible_by`.

5. `accepts_is_training()` function: Checks whether a given module accepts an `is_training` argument. It uses `inspect.signature` on the module's `__call__()` method and checks whether an `is_training` parameter is present.

These components provide additional functionality and utility functions used in the Enformer model to handle residual connections, activation functions, one-hot encoding, and other operations.
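The approximation in `gelu()` can be checked numerically against the exact definition GELU(x) = x·Φ(x), where Φ is the standard normal CDF (expressible with `math.erf`). The plain-Python comparison below runs on a grid over [-6, 6]; it is a sanity check with hypothetical helper names, not part of the model.

```python
import math


def gelu_exact(x: float) -> float:
    """GELU(x) = x * Phi(x), with Phi the standard normal CDF."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_sigmoid(x: float) -> float:
    """Approximation used by gelu(): x * sigmoid(1.702 * x)."""
    return x / (1.0 + math.exp(-1.702 * x))


xs = [i / 10 for i in range(-60, 61)]
max_err = max(abs(gelu_exact(x) - gelu_sigmoid(x)) for x in xs)
print(f'max |exact - approx| on [-6, 6]: {max_err:.4f}')
```

On this grid the two curves differ by roughly 0.02 at most, with the largest gap a little above |x| = 2; the sigmoid form avoids evaluating `erf` in the model.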

In [None]:
class Residual(snt.Module):
  """Residual block."""

  def __init__(self, module: snt.Module, name='residual'):
    super().__init__(name=name)
    self._module = module

  def __call__(self, inputs: tf.Tensor, is_training: bool, *args,
               **kwargs) -> tf.Tensor:
    return inputs + self._module(inputs, is_training, *args, **kwargs)


def gelu(x: tf.Tensor) -> tf.Tensor:
  """Applies the Gaussian error linear unit (GELU) activation function.

  Using approximation in section 2 of the original paper:
  https://arxiv.org/abs/1606.08415

  Args:
    x: Input tensor to apply gelu activation.
  Returns:
    Tensor with gelu activation applied to it.
  """
  return tf.nn.sigmoid(1.702 * x) * x


def one_hot_encode(sequence: str,
                   alphabet: str = 'ACGT',
                   neutral_alphabet: str = 'N',
                   neutral_value: Any = 0,
                   dtype=np.float32) -> np.ndarray:
  """One-hot encode sequence."""
  def to_uint8(string):
    return np.frombuffer(string.encode('ascii'), dtype=np.uint8)
  hash_table = np.zeros((np.iinfo(np.uint8).max, len(alphabet)), dtype=dtype)
  hash_table[to_uint8(alphabet)] = np.eye(len(alphabet), dtype=dtype)
  hash_table[to_uint8(neutral_alphabet)] = neutral_value
  hash_table = hash_table.astype(dtype)
  return hash_table[to_uint8(sequence)]


def exponential_linspace_int(start, end, num, divisible_by=1):
  """Exponentially increasing values of integers."""
  def _round(x):
    return int(np.round(x / divisible_by) * divisible_by)

  base = np.exp(np.log(end / start) / (num - 1))
  return [_round(start * base**i) for i in range(num)]


def accepts_is_training(module):
  return 'is_training' in list(inspect.signature(module.__call__).parameters)
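The lookup-table approach in `one_hot_encode` is compact but easy to misread. The plain-Python sketch below (hypothetical name `one_hot_encode_py`, returning lists instead of a NumPy array) mirrors its behavior under the same defaults: alphabet characters become one-hot rows, neutral characters are filled with `neutral_value`, and every other character, including lowercase bases, becomes an all-zero row, as with the zero-initialized table in the original.

```python
from typing import Dict, List


def one_hot_encode_py(sequence: str,
                      alphabet: str = 'ACGT',
                      neutral_alphabet: str = 'N',
                      neutral_value: float = 0.0) -> List[List[float]]:
    """Plain-Python analogue of one_hot_encode (lists instead of an array)."""
    table: Dict[str, List[float]] = {
        base: [1.0 if j == i else 0.0 for j in range(len(alphabet))]
        for i, base in enumerate(alphabet)
    }
    for base in neutral_alphabet:
        table[base] = [float(neutral_value)] * len(alphabet)
    # Any other character maps to an all-zero row, like the zero-filled table.
    zeros = [0.0] * len(alphabet)
    return [list(table.get(base, zeros)) for base in sequence]


for row in one_hot_encode_py('ACGTNa'):
    print(row)
```

A model input is this encoding for a `SEQUENCE_LENGTH`-long sequence, giving `SEQUENCE_LENGTH` rows of 4 values; batching such arrays matches the `[None, SEQUENCE_LENGTH, 4]` signature of `predict_on_batch`. Uppercasing sequences first avoids silently zeroing soft-masked (lowercase) bases.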