**Enformer Attention Model Class: A Comprehensive Guide**

The primary objective of the `Enformer Attention Model` class is to facilitate the creation of attention models for the transformer layers within Enformer. It is specifically designed to handle attention mechanisms and their impact on transformers. To gain a thorough understanding of this topic, we strongly encourage you to begin by reading the accompanying report.

The report provides essential insights into the concepts surrounding attention mechanisms, detailing how they function and the profound effects they have on transformers. By delving into the report, you will gain valuable knowledge about the inner workings of `attention` models and their implications within `Enformers`.

By utilizing this `attention` model class, you can effectively harness the power of attention mechanisms in transformer networks. It enables you to create and manipulate attention models within the Enformer framework, offering greater control over the learning process.

We highly recommend familiarizing yourself with the accompanying report to fully grasp the significance and practical implementation of attention models. If you have any questions or require further assistance, please don't hesitate to reach out.

The file defines a TransformerBlock class, which represents a full transformer module block. Here's a breakdown of the class components:

- `__init__()` method: The constructor of the class. It takes the following arguments:
  - `channels`: The number of output channels of the block. The MLP expands to `channels * 2` units and projects back to `channels`.
  - `dropout_rate`: The dropout rate used in the transformer block.
  - `attention_kwargs`: A dictionary containing keyword arguments for the `MultiheadAttention` module used in the transformer block.
  - `name`: The name of the module.

- `mha_ln`: An instance of `snt.LayerNorm` applied to the block input before the multihead attention layer.

- `mha`: An instance of the `MultiheadAttention` module created with the specified `attention_kwargs`.

- `mha_dropout`: An instance of `snt.Dropout` used for dropout regularization after the multihead attention layer.

- `mlp_ln`: An instance of `snt.LayerNorm` applied to the input of the MLP, that is, to the output of the attention sub-block.

- `mlp_linear1`: An instance of `snt.Linear` representing the first linear layer in the MLP, with `channels * 2` output units.

- `mlp_dropout1`: An instance of `snt.Dropout` used for dropout regularization after the first linear layer in the MLP.

- `mlp_linear2`: An instance of `snt.Linear` representing the second linear layer in the MLP, with `channels` output units.

- `mlp_dropout2`: An instance of `snt.Dropout` used for dropout regularization after the second linear layer in the MLP.

The `TransformerBlock` class encapsulates the components required for a single transformer block, including the multihead attention layer and the feed-forward MLP layer. It applies layer normalization and dropout regularization at appropriate places within the block.

In [None]:
from typing import Any, Dict, List, Optional

import numpy as np
import sonnet as snt
import tensorflow as tf


class TransformerBlock(snt.Module):
  """Full transformer module block."""

  def __init__(
      self,
      channels: int,
      dropout_rate: float,
      attention_kwargs: Dict[str, Any],
      name: str = 'transformer_block',
  ):
    super().__init__(name=name)
    self.mha_ln = snt.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    self.mha = MultiheadAttention(**attention_kwargs)
    self.mha_dropout = snt.Dropout(dropout_rate)

    self.mlp_ln = snt.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    self.mlp_linear1 = snt.Linear(channels * 2)
    self.mlp_dropout1 = snt.Dropout(dropout_rate)
    self.mlp_linear2 = snt.Linear(channels)
    self.mlp_dropout2 = snt.Dropout(dropout_rate)

The `__call__()` method of the `TransformerBlock` class implements the forward pass of the transformer block. It takes the following arguments:

- `inputs`: The input tensor to the transformer block.
- `is_training`: A boolean flag indicating whether the model is in training mode, which controls whether dropout is applied.

Here's a breakdown of the steps performed in the `__call__()` method:

1. Apply layer normalization (`self.mha_ln`) to the input tensor.
2. Pass the normalized tensor through the multihead attention layer (`self.mha`) with the specified `is_training` flag.
3. Apply dropout regularization (`self.mha_dropout`) to the output of the multihead attention layer.
4. Add the residual connection by adding the original input tensor to the result.
5. Store this sum in the variable `mha_output`.
6. Apply layer normalization (`self.mlp_ln`) to `mha_output`.
7. Pass the normalized tensor through the first linear layer in the MLP (`self.mlp_linear1`).
8. Apply dropout regularization (`self.mlp_dropout1`) to the output of the first linear layer.
9. Apply the ReLU activation function to the result.
10. Pass the ReLU-activated tensor through the second linear layer in the MLP (`self.mlp_linear2`).
11. Apply dropout regularization (`self.mlp_dropout2`) to the output of the second linear layer.
12. Add `mha_output` to the output of the MLP (second residual connection).
13. Return the final output tensor.

The `__call__()` method therefore applies layer normalization before each sub-layer (multihead attention and the feed-forward MLP) and wraps each sub-layer in a residual connection.

In [None]:
  def __call__(self, inputs: tf.Tensor, is_training: bool) -> tf.Tensor:
    x = self.mha_ln(inputs)
    x = self.mha(x, is_training=is_training)
    x = self.mha_dropout(x, is_training=is_training)
    x += inputs  # Residual
    mha_output = x

    # MLP.
    x = self.mlp_ln(mha_output)
    x = self.mlp_linear1(x)
    x = self.mlp_dropout1(x, is_training=is_training)
    x = tf.nn.relu(x)
    x = self.mlp_linear2(x)
    x = self.mlp_dropout2(x, is_training=is_training)
    return x + mha_output
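
The data flow of the block can be illustrated without Sonnet. The following is a minimal NumPy sketch, not the Enformer implementation: it omits dropout, biases, and the learned layer-norm scale and offset, and the names `layer_norm`, `pre_ln_block`, and `attention_fn` are hypothetical stand-ins. It shows that each sub-layer sees a normalized input and that its output is added back to the unnormalized stream.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
  """Normalizes the last axis to zero mean and unit variance (no learned parameters)."""
  mean = x.mean(axis=-1, keepdims=True)
  var = x.var(axis=-1, keepdims=True)
  return (x - mean) / np.sqrt(var + eps)


def pre_ln_block(inputs, attention_fn, w1, w2):
  """Pre-layer-norm transformer block with dropout omitted (inference mode)."""
  # Attention sub-block: normalize, attend, add the residual.
  mha_output = attention_fn(layer_norm(inputs)) + inputs
  # MLP sub-block: normalize, expand, ReLU, project back, add the residual.
  hidden = np.maximum(layer_norm(mha_output) @ w1, 0.0)
  return hidden @ w2 + mha_output


rng = np.random.default_rng(0)
channels = 8
x = rng.normal(size=(2, 5, channels))            # [batch, time, channels]
w1 = rng.normal(size=(channels, 2 * channels))   # expands to channels * 2
w2 = np.zeros((2 * channels, channels))          # zero weights silence the MLP branch

# With both branches contributing zero, only the residual path remains.
out = pre_ln_block(x, lambda h: np.zeros_like(h), w1, w2)
print(out.shape, np.allclose(out, x))
```

Because both residual branches are silenced here, the block reduces to the identity, which makes the two skip connections easy to see.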

The `MultiheadAttention` class is a module that implements the multi-head attention mechanism used in transformers. It has the following key features:

- It can optionally use TransformerXL-style relative positional encodings (`relative_positions=True`); otherwise it computes plain content-based attention.
- It allows for different types of relative positional biases through the `relative_position_functions` argument.
- It provides options for scaling the attention logits, applying dropout to the attention weights, and zero initialization of the final linear layer.
- The module consists of linear projection layers for queries, keys, and values, as well as a final output projection (`_embedding_layer`).
- If relative positions are used, an additional linear projection layer and two biases are created.

Here's a breakdown of the main components and their purposes:

- `value_size`: The size of each value embedding per head.
- `key_size`: The size of each key and query embedding per head.
- `num_heads`: The number of independent queries per timestep.
- `scaling`: Whether to scale the attention logits.
- `attention_dropout_rate`: Dropout rate for attention logits.
- `relative_positions`: Whether to use TransformerXL style relative attention.
- `relative_position_symmetric`: If `True`, only the symmetric version of the basis functions will be used. If `False`, both symmetric and asymmetric versions will be used.
- `relative_position_functions`: List of function names used for relative positional biases.
- `num_relative_position_features`: Number of relative positional features to compute. The docstring states that `value_size * num_heads` is used when this is `None`, but the code in this file uses `value_size` rounded down to a multiple of `2 * len(relative_position_functions)`, which requires `relative_position_functions` to be a list.
- `positional_dropout_rate`: Dropout rate for the positional encodings if relative positions are used.
- `zero_initialize`: If `True`, the final linear layer will be initialized with zeros.
- `initializer`: Initializer for the projection layers. If not specified, `VarianceScaling` is used with a scale of `2.0`.

The module creates bias-free linear projection layers for queries, keys, and values (`_q_layer`, `_k_layer`, `_v_layer`) and an output projection (`_embedding_layer`). These layers project the input tensor into the per-head dimensions used in the attention computation and map the concatenated heads back to a single embedding.

If relative positions are used (`relative_positions=True`), an additional linear projection layer (`_r_k_layer`) and two biases (`_r_w_bias`, `_r_r_bias`) are created to handle relative positional encodings.

The module implements the `__call__()` method, which performs the forward pass of the multi-head attention mechanism. It takes an input tensor and returns the output tensor after applying the multi-head attention operations.
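
When `num_relative_position_features` is left as `None`, the constructor rounds `value_size` down so that the features can be split evenly across the basis functions and their symmetric and asymmetric versions. The arithmetic is sketched below in plain Python; the helper name `default_num_relative_position_features` is hypothetical and only mirrors the two lines of integer arithmetic in the constructor.

```python
def default_num_relative_position_features(value_size, num_functions):
  """Rounds value_size down to a multiple of 2 * num_functions."""
  divisible_by = 2 * num_functions
  return (value_size // divisible_by) * divisible_by


print(default_num_relative_position_features(64, 3))   # 60
print(default_num_relative_position_features(192, 3))  # 192
```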


In [None]:
class MultiheadAttention(snt.Module):
  """Multi-head attention."""

  def __init__(self,
               value_size: int,
               key_size: int,
               num_heads: int,
               scaling: bool = True,
               attention_dropout_rate: float = 0.1,
               relative_positions: bool = False,
               relative_position_symmetric: bool = False,
               relative_position_functions: Optional[List[str]] = None,
               num_relative_position_features: Optional[int] = None,
               positional_dropout_rate: float = 0.1,
               zero_initialize: bool = True,
               initializer: Optional[snt.initializers.Initializer] = None,
               name: Optional[str] = None):
    """Creates a MultiheadAttention module.

    Args:
      value_size: The size of each value embedding per head.
      key_size: The size of each key and query embedding per head.
      num_heads: The number of independent queries per timestep.
      scaling: Whether to scale the attention logits.
      attention_dropout_rate: Dropout rate for attention logits.
      relative_positions: Whether to use TransformerXL style relative attention.
      relative_position_symmetric: If True, the symmetric version of basis
        functions will be used. If False, both the symmetric and asymmetric
        versions will be used.
      relative_position_functions: List of function names used for relative
        positional biases.
      num_relative_position_features: Number of relative positional features
        to compute. If None, `value_size` rounded down to a multiple of
        `2 * len(relative_position_functions)` is used.
      positional_dropout_rate: Dropout rate for the positional encodings if
        relative positions are used.
      zero_initialize: if True, the final linear layer will be 0 initialized.
      initializer: Initializer for the projection layers. If unspecified,
        VarianceScaling is used with scale = 2.0.
      name: Name of module.
    """
    super().__init__(name=name)
    self._value_size = value_size
    self._key_size = key_size
    self._num_heads = num_heads
    self._attention_dropout_rate = attention_dropout_rate
    self._scaling = scaling
    self._relative_positions = relative_positions
    self._relative_position_symmetric = relative_position_symmetric
    self._relative_position_functions = relative_position_functions
    if num_relative_position_features is None:
      # num_relative_position_features needs to be divisible by the number of
      # relative positional functions *2 (for symmetric & asymmetric version).
      divisible_by = 2 * len(self._relative_position_functions)
      self._num_relative_position_features = (
          (self._value_size // divisible_by) * divisible_by)
    else:
      self._num_relative_position_features = num_relative_position_features
    self._positional_dropout_rate = positional_dropout_rate

    self._initializer = initializer
    if self._initializer is None:
      self._initializer = snt.initializers.VarianceScaling(scale=2.0)

    key_proj_size = self._key_size * self._num_heads
    embedding_size = self._value_size * self._num_heads

    self._q_layer = snt.Linear(
        key_proj_size,
        name='q_layer',
        with_bias=False,
        w_init=self._initializer)
    self._k_layer = snt.Linear(
        key_proj_size,
        name='k_layer',
        with_bias=False,
        w_init=self._initializer)
    self._v_layer = snt.Linear(
        embedding_size,
        name='v_layer',
        with_bias=False,
        w_init=self._initializer)
    w_init = snt.initializers.Zeros() if zero_initialize else self._initializer
    self._embedding_layer = snt.Linear(
        embedding_size,
        name='embedding_layer',
        w_init=w_init)

    # Create additional layers if using relative positions.
    if self._relative_positions:
      self._r_k_layer = snt.Linear(
          key_proj_size,
          name='r_k_layer',
          with_bias=False,
          w_init=self._initializer)
      self._r_w_bias = tf.Variable(
          self._initializer([1, self._num_heads, 1, self._key_size],
                            dtype=tf.float32),
          name='r_w_bias')
      self._r_r_bias = tf.Variable(
          self._initializer([1, self._num_heads, 1, self._key_size],
                            dtype=tf.float32),
          name='r_r_bias')

  def _multihead_output(self, linear, inputs):
    """Applies a standard linear to inputs and returns multihead output."""

    output = snt.BatchApply(linear)(inputs)  # [B, T, H * KV]
    num_kv_channels = output.shape[-1] // self._num_heads
    # Split H * Channels into separate axes.
    output = snt.reshape(output,
                         output_shape=[-1, self._num_heads, num_kv_channels])
    # [B, T, H, KV] -> [B, H, T, KV]
    return tf.transpose(output, [0, 2, 1, 3])
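
The reshape and transpose in `_multihead_output` can be reproduced in NumPy to see which channels end up in which head. This is a small illustration with made-up sizes, assuming the projected tensor has shape `[B, T, H * KV]` as in the comments above.

```python
import numpy as np

batch, time, num_heads, per_head = 2, 5, 4, 3
projected = np.arange(batch * time * num_heads * per_head, dtype=np.float32)
projected = projected.reshape(batch, time, num_heads * per_head)  # [B, T, H * KV]

# Split the channel axis into (heads, per-head channels), then move heads forward.
split = projected.reshape(batch, time, num_heads, per_head)       # [B, T, H, KV]
heads = split.transpose(0, 2, 1, 3)                               # [B, H, T, KV]

print(heads.shape)
# Head h at timestep t holds channels [h * KV, (h + 1) * KV) of the projection.
print(np.array_equal(heads[1, 2, 3], projected[1, 3, 2 * per_head:3 * per_head]))
```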

The `__call__()` method of the `MultiheadAttention` class implements the forward pass of the multi-head attention mechanism. Here's a breakdown of the steps performed in the method:

1. Compute the output embedding size (`value_size * num_heads`) and read the sequence length from the input shape.
2. Compute the queries, keys, and values by applying multi-headed projections of the inputs using the `_multihead_output` method.
3. If scaling is enabled, multiply the queries by `key_size ** -0.5`, that is, divide them by the square root of the key size.
4. If relative positions are enabled, compute the positional encodings for the distances `-(T - 1)` to `T - 1`, apply dropout to them in training mode, and project them to form relative keys (`r_k`).
5. Compute the logits by performing matrix multiplication between queries and keys.
   - If relative positions are enabled, the biases `_r_w_bias` and `_r_r_bias` are added to the queries, and the shifted relative logits are added to the content logits.
6. Apply softmax to obtain attention weights.
7. Apply dropout to the attention weights if in training mode.
8. Compute the attended values by multiplying the attention weights with the values.
9. Transpose and reshape the output to `[B, T, H * V]`.
10. Apply the final linear layer (`_embedding_layer`) to obtain the output.


In [None]:
  def __call__(self, inputs, is_training=False):
    # Output size of the concatenated heads, and the input sequence length.
    embedding_size = self._value_size * self._num_heads
    seq_len = inputs.shape[1]

    # Compute q, k and v as multi-headed projections of the inputs.
    q = self._multihead_output(self._q_layer, inputs)  # [B, H, T, K]
    k = self._multihead_output(self._k_layer, inputs)  # [B, H, T, K]
    v = self._multihead_output(self._v_layer, inputs)  # [B, H, T, V]

    # Scale the query by the square-root of key size.
    if self._scaling:
      q *= self._key_size**-0.5

    if self._relative_positions:
      # For relative positions, we project positions to form relative keys.
      distances = tf.range(-seq_len + 1, seq_len, dtype=tf.float32)[tf.newaxis]
      positional_encodings = positional_features_all(
          positions=distances,
          feature_size=self._num_relative_position_features,
          seq_length=seq_len,
          feature_functions=self._relative_position_functions,
          symmetric=self._relative_position_symmetric)
      # [1, 2T-1, Cr]

      if is_training:
        positional_encodings = tf.nn.dropout(
            positional_encodings, rate=self._positional_dropout_rate)

      # [1, H, 2T-1, K]
      r_k = self._multihead_output(self._r_k_layer, positional_encodings)

      # Add shifted relative logits to content logits.
      # [B, H, T', T]
      content_logits = tf.matmul(q + self._r_w_bias, k, transpose_b=True)
      # [B, H, T', 2T-1]
      relative_logits = tf.matmul(
          q + self._r_r_bias, r_k, transpose_b=True)
      #  [B, H, T', T]
      relative_logits = relative_shift(relative_logits)
      logits = content_logits + relative_logits
    else:
      # [B, H, T', T]
      logits = tf.matmul(q, k, transpose_b=True)

    weights = tf.nn.softmax(logits)

    # Dropout on the attention weights.
    if is_training:
      weights = tf.nn.dropout(weights, rate=self._attention_dropout_rate)

    # Transpose and reshape the output.
    output = tf.matmul(weights, v)  # [B, H, T', V]
    output_transpose = tf.transpose(output, [0, 2, 1, 3])  # [B, T', H, V]
    # Final linear layer.
    attended_inputs = snt.reshape(
        output_transpose, output_shape=[embedding_size], preserve_dims=2)
    output = self._embedding_layer(attended_inputs)

    return output
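
For the case without relative positions, the core of the forward pass is scaled dot-product attention. The following NumPy sketch covers only that path, assuming `q`, `k`, and `v` are already split into heads; it leaves out dropout and the final `_embedding_layer` projection, and the function names are hypothetical.

```python
import numpy as np


def softmax(x):
  """Numerically stable softmax over the last axis."""
  shifted = x - x.max(axis=-1, keepdims=True)
  e = np.exp(shifted)
  return e / e.sum(axis=-1, keepdims=True)


def dot_product_attention(q, k, v, scaling=True):
  """q, k: [B, H, T, K]; v: [B, H, T, V]. Returns ([B, T, H * V], weights)."""
  if scaling:
    q = q * q.shape[-1] ** -0.5                  # scale by 1 / sqrt(key_size)
  logits = q @ np.swapaxes(k, -1, -2)            # [B, H, T, T]
  weights = softmax(logits)
  attended = weights @ v                         # [B, H, T, V]
  attended = attended.transpose(0, 2, 1, 3)      # [B, T, H, V]
  b, t, h, d = attended.shape
  return attended.reshape(b, t, h * d), weights


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 4, 5, 3))
k = rng.normal(size=(2, 4, 5, 3))
v = rng.normal(size=(2, 4, 5, 6))
out, weights = dot_product_attention(q, k, v)
print(out.shape)                                 # (2, 5, 24)
print(np.allclose(weights.sum(axis=-1), 1.0))    # each query's weights sum to 1
```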

The file also includes helper functions for handling relative positions and computing positional encodings/features. Here's an explanation of these functions:

1. `relative_shift()`: This function shifts the relative logits, following the approach used in TransformerXL. It prepends a column of zeros to the last dimension of the input `x`, then reshapes and slices so that, for each query position, only the logits for the relative distances to the `T` key positions remain. The input has shape `[B, H, T, 2T - 1]` and the output has shape `[B, H, T, T]`.

2. `get_positional_feature_function()`: This function returns the positional feature function registered under the provided name. It maintains a dictionary of available feature functions and raises a `ValueError` if an invalid function name is provided.

3. `positional_features_all()`: This function computes the relative positional encodings/features. It takes `positions` (tensor of relative positions), `feature_size` (total number of basis functions), `seq_length` (characteristic length of individual positional features), `bin_size` (bin size used for partitioning the sequence), `feature_functions` (list of feature function names), and `symmetric` (boolean indicating whether the resulting features should be symmetric).

   The function looks up each name in `feature_functions` with `get_positional_feature_function()` and applies every function to the absolute positions, with each function producing an equal share of `feature_size`. The results are concatenated along the feature axis. If `symmetric` is `True`, these features are returned as they are, so they depend only on the absolute value of the position. If `symmetric` is `False`, a second copy multiplied by `sign(positions)` is concatenated, so the output contains both a symmetric and an antisymmetric version of each feature. The final tensor shape is checked for compatibility before the embeddings are returned.
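
The effect of the `symmetric` flag can be illustrated with a single toy basis. This is a NumPy sketch under stated assumptions: `toy_feature` is a hypothetical basis function invented for the example, and `combine` mirrors only the concatenation logic, not the full `positional_features_all()` function.

```python
import numpy as np


def toy_feature(abs_positions, num_basis):
  """Hypothetical basis: indicator |position| < width, for widths 1..num_basis."""
  widths = np.arange(1, num_basis + 1, dtype=np.float32)
  return (abs_positions[..., np.newaxis] < widths).astype(np.float32)


def combine(positions, feature_size, symmetric):
  """Mirrors the symmetric / asymmetric concatenation for a single basis."""
  num_components = 1 if symmetric else 2
  if feature_size % num_components != 0:
    raise ValueError(f'feature_size has to be divisible by {num_components}')
  embeddings = toy_feature(np.abs(positions), feature_size // num_components)
  if not symmetric:
    signed = np.sign(positions)[..., np.newaxis] * embeddings
    embeddings = np.concatenate([embeddings, signed], axis=-1)
  return embeddings


positions = np.arange(-2, 3, dtype=np.float32)[np.newaxis]   # [1, 5]
sym = combine(positions, feature_size=4, symmetric=True)
asym = combine(positions, feature_size=4, symmetric=False)
print(sym.shape, asym.shape)                                 # (1, 5, 4) (1, 5, 4)
print(asym[0, 1], asym[0, 3])                                # positions -1 and +1
```

In the asymmetric case the first half of the features is identical at `-1` and `+1`, while the second half flips sign, which lets the attention logits distinguish upstream from downstream positions.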



In [None]:
def relative_shift(x):
  """Shift the relative logits like in TransformerXL."""
  # We prepend zeros on the final timescale dimension.
  to_pad = tf.zeros_like(x[..., :1])
  x = tf.concat([to_pad, x], -1)
  _, num_heads, t1, t2 = x.shape
  x = tf.reshape(x, [-1, num_heads, t2, t1])
  x = tf.slice(x, [0, 0, 1, 0], [-1, -1, -1, -1])
  x = tf.reshape(x, [-1, num_heads, t1, t2 - 1])
  x = tf.slice(x, [0, 0, 0, 0], [-1, -1, -1, (t2 + 1) // 2])
  return x


# Available feature functions:
def get_positional_feature_function(name):
  """Returns positional feature functions."""
  available = {
      'positional_features_exponential': positional_features_exponential,
      'positional_features_central_mask': positional_features_central_mask,
      'positional_features_gamma': positional_features_gamma,
      'positional_features_cosine': positional_features_cosine,
      'positional_features_linear_masks': positional_features_linear_masks,
      'positional_features_sin_cos': positional_features_sin_cos,
  }
  if name not in available:
    raise ValueError(f'Function {name} not available in {available.keys()}')
  return available[name]


def positional_features_all(positions: tf.Tensor,
                            feature_size: int,
                            seq_length: Optional[int] = None,
                            bin_size: Optional[int] = None,
                            feature_functions: Optional[List[str]] = None,
                            symmetric=False):
  """Compute relative positional encodings/features.

  Each positional feature function will compute/provide the same fraction of
  features, making up the total of feature_size.
  Args:
    positions: Tensor of relative positions of arbitrary shape.
    feature_size: Total number of basis functions.
    seq_length: Sequence length denoting the characteristic length that
      the individual positional features can use. This is required since the
      parametrization of the input features should be independent of positions
      while it could still require to use the total number of features.
    bin_size: Bin size used to partition the sequence. This can be used to
      compute features on the absolute scale relative to the genome.
    feature_functions: List of different feature functions to use. Each function
      will take as argument: positions, sequence length and number of features
      to compute.
    symmetric: If True, the resulting features will be symmetric across the
      relative position of 0 (i.e. only absolute value of positions will
      matter). If false, then both the symmetric and asymmetric version
      (symmetric multiplied by sign(positions)) of the features will be used.

  Returns:
    Tensor of shape: positions.shape + (feature_size,).
  """
  if feature_functions is None:
    feature_functions = ['positional_features_exponential',
                         'positional_features_central_mask',
                         'positional_features_gamma']
  num_components = len(feature_functions)  # 1 per each basis function
  if not symmetric:
    num_components = 2 * num_components

  # For now, we do not allow odd sized embeddings.
  if feature_size % num_components != 0:
    raise ValueError(
        f'feature_size has to be divisible by {num_components}')

  feature_functions = [get_positional_feature_function(f)
                       for f in feature_functions]
  num_basis_per_class = feature_size // num_components
  embeddings = tf.concat([f(tf.abs(positions), num_basis_per_class,
                            seq_length, bin_size)
                          for f in feature_functions],
                         axis=-1)
  if not symmetric:
    embeddings = tf.concat([embeddings,
                            tf.sign(positions)[..., tf.newaxis] * embeddings],
                           axis=-1)
  tf.TensorShape(embeddings.shape).assert_is_compatible_with(
      positions.shape + [feature_size])
  return embeddings
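
To see what `relative_shift()` does, the same sequence of operations can be ported to NumPy and applied to an input whose values are the relative distances themselves. This is an illustrative port (the name `relative_shift_np` is hypothetical), assuming the input has shape `[B, H, T, 2T - 1]` with the last axis ordered from distance `-(T - 1)` to `T - 1`.

```python
import numpy as np


def relative_shift_np(x):
  """NumPy port of relative_shift for x of shape [B, H, T, 2T - 1]."""
  x = np.concatenate([np.zeros_like(x[..., :1]), x], axis=-1)  # [B, H, T, 2T]
  _, num_heads, t1, t2 = x.shape
  x = x.reshape(-1, num_heads, t2, t1)
  x = x[:, :, 1:, :]                                           # drop the first row
  x = x.reshape(-1, num_heads, t1, t2 - 1)
  return x[..., :(t2 + 1) // 2]                                # [B, H, T, T]


seq_len = 4
distances = np.arange(-seq_len + 1, seq_len, dtype=np.float32)  # -3, ..., 3
# Every query row carries the same per-distance values, so the shift is easy to read.
x = np.broadcast_to(distances, (1, 1, seq_len, 2 * seq_len - 1)).copy()
shifted = relative_shift_np(x)

i, j = np.indices((seq_len, seq_len))
print(shifted[0, 0])
print(np.array_equal(shifted[0, 0], (j - i).astype(np.float32)))
```

After the shift, entry `[i, j]` holds the value that belonged to relative distance `j - i`, so the relative logits line up with the `[T, T]` content logits.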



These are the additional positional feature functions:

1. `_prepend_dims()`: This utility function reshapes a tensor so that `num_dims` dimensions of size `1` are added at the beginning.

2. `positional_features_exponential()`: This function creates exponentially decaying positional weights. It takes `positions` (position tensor), `feature_size` (number of basis functions), `seq_length` (sequence length), `bin_size` (unused), and `min_half_life` (controls the smallest half-life in the grid) as arguments. If `seq_length` is not provided, it is computed as the maximum absolute value of positions plus 1.

   The function builds `feature_size` exponents spaced linearly from `min_half_life` to `max_range = log2(seq_length)` and raises 2 to each of them, which yields half-lives spaced on a logarithmic scale from `2 ** min_half_life` up to `seq_length`. Note that the code uses `min_half_life` as a base-2 exponent, so the default of `3.0` gives a smallest half-life of 8. Each feature is `exp(-ln(2) * |position| / half_life)`, so it equals 1 at position 0 and 0.5 at a distance of one half-life. The outputs have shape `positions.shape + [feature_size]`; for the `[1, 2T - 1]` positions passed in by `MultiheadAttention`, this is `[1, 2T - 1, feature_size]`.

3. `positional_features_central_mask()`: This function creates positional features using a central mask, allowing only central features. It takes `positions` (position tensor), `feature_size` (number of basis functions), `seq_length` (unused), and `bin_size` (unused) as arguments.

   The function generates center widths `2 ** i - 1` for `i = 1, ..., feature_size` (that is, 1, 3, 7, and so on). Each feature is 1.0 where the absolute position is strictly smaller than the corresponding width and 0.0 otherwise. The outputs have shape `positions.shape + [feature_size]`.

4. `gamma_pdf()`: This function computes the probability density function (PDF) of the gamma distribution. It takes `x` (input values), `concentration`, and `rate` as arguments and returns the PDF values. It computes the logarithm of the unnormalized density and the logarithm of the normalization constant, then exponentiates their difference.

These functions provide different strategies for generating positional encodings/features based on the positions and other parameters.

In [None]:
def _prepend_dims(x, num_dims):
  return tf.reshape(x, shape=[1] * num_dims + x.shape)


def positional_features_exponential(positions: tf.Tensor,
                                    feature_size: int,
                                    seq_length: Optional[int] = None,
                                    bin_size: Optional[int] = None,
                                    min_half_life: Optional[float] = 3.0):
  """Create exponentially decaying positional weights.

  Args:
    positions: Position tensor (arbitrary shape).
    feature_size: Number of basis functions to use.
    seq_length: Sequence length.
    bin_size: (unused). See positional_features_all.
    min_half_life: Smallest exponential half life in the grid of half lives.

  Returns:
    A Tensor with shape [2 * seq_length - 1, feature_size].
  """
  del bin_size  # Unused.
  if seq_length is None:
    seq_length = tf.reduce_max(tf.abs(positions)) + 1
  # Grid of half-lives 2**e, with feature_size exponents e spaced linearly
  # from min_half_life to log2(seq_length).
  seq_length = tf.cast(seq_length, dtype=tf.float32)
  max_range = tf.math.log(seq_length) / tf.math.log(2.0)
  half_life = tf.pow(2.0, tf.linspace(min_half_life, max_range, feature_size))
  half_life = _prepend_dims(half_life, positions.shape.rank)
  positions = tf.abs(positions)
  outputs = tf.exp(-tf.math.log(2.0) / half_life * positions[..., tf.newaxis])
  tf.TensorShape(outputs.shape).assert_is_compatible_with(
      positions.shape + [feature_size])
  return outputs


def positional_features_central_mask(positions: tf.Tensor,
                                     feature_size: int,
                                     seq_length: Optional[int] = None,
                                     bin_size: Optional[int] = None):
  """Positional features using a central mask (allow only central features)."""
  del seq_length  # Unused.
  del bin_size  # Unused.
  center_widths = tf.pow(2.0, tf.range(1, feature_size + 1, dtype=tf.float32))
  center_widths = center_widths - 1
  center_widths = _prepend_dims(center_widths, positions.shape.rank)
  outputs = tf.cast(center_widths > tf.abs(positions)[..., tf.newaxis],
                    tf.float32)
  tf.TensorShape(outputs.shape).assert_is_compatible_with(
      positions.shape + [feature_size])
  return outputs


def gamma_pdf(x, concentration, rate):
  """Gamma probability distribution function: p(x|concentration, rate)."""
  log_unnormalized_prob = tf.math.xlogy(concentration - 1., x) - rate * x
  log_normalization = (tf.math.lgamma(concentration) -
                       concentration * tf.math.log(rate))
  return tf.exp(log_unnormalized_prob - log_normalization)
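
The exponential and central-mask features are easy to check numerically. The following NumPy sketch re-implements only their arithmetic (the function names are hypothetical and `seq_length` is passed explicitly), so the values can be inspected without TensorFlow.

```python
import numpy as np


def exponential_features(positions, feature_size, seq_length, min_half_life=3.0):
  """NumPy sketch: exp(-ln(2) * |position| / half_life) for each half-life."""
  max_range = np.log2(seq_length)
  half_life = 2.0 ** np.linspace(min_half_life, max_range, feature_size)
  return np.exp(-np.log(2.0) / half_life * np.abs(positions)[..., np.newaxis])


def central_mask_features(positions, feature_size):
  """NumPy sketch: 1.0 where |position| < 2**i - 1 for i = 1..feature_size."""
  widths = 2.0 ** np.arange(1, feature_size + 1) - 1
  return (widths > np.abs(positions)[..., np.newaxis]).astype(np.float32)


positions = np.array([0.0, 1.0, 3.0, 8.0])
exp_feats = exponential_features(positions, feature_size=3, seq_length=128)
mask_feats = central_mask_features(positions, feature_size=3)

print(exp_feats.shape)   # (4, 3)
print(exp_feats[3, 0])   # half-life 2**3 = 8, so position 8 decays to about 0.5
print(mask_feats)        # widths are 1, 3, 7
```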


Here are two more positional feature functions:

1. `positional_features_gamma()`: This function computes positional features using gamma distributions. It takes `positions` (position tensor), `feature_size` (number of basis functions), `seq_length` (sequence length), `bin_size` (unused), `stddev` (standard deviation), and `start_mean` (starting mean) as arguments. If not given, `stddev` defaults to `seq_length / (2 * feature_size)` and `start_mean` to `seq_length / feature_size`.

   The function first computes the sequence length if it is not provided by taking the maximum absolute value of positions and adding 1. It then places `feature_size` means on a linear scale from `start_mean` to `seq_length`. The parameters of each gamma distribution are set to `concentration = (mean / stddev) ** 2` and `rate = mean / stddev ** 2`, which gives the distribution the requested mean and standard deviation. The probabilities are obtained by evaluating the gamma PDF at the absolute positions. A small constant (`1e-8`) is added for numerical stability, and the probabilities are normalized by dividing them by their maximum along axis 1, which is the position axis for the `[1, 2T - 1]` positions used in `MultiheadAttention`, so each feature peaks at 1. The resulting outputs have shape `positions.shape + [feature_size]`.

2. `positional_features_cosine()`: This function generates cosine positional features. It takes `positions` (position tensor), `feature_size` (number of basis functions), `seq_length` (unused), and `bin_size` (unused) as arguments.

   The function defines a periodicity of `1.25 * 2 ** i` for `i = 0, ..., feature_size - 1`, a geometric series. Each feature is `cos(2 * pi * position / periodicity)`. The outputs have shape `positions.shape + [feature_size]`.

These functions provide additional options for generating positional encodings/features based on gamma distributions and cosine functions, respectively.
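
The choice of `concentration` and `rate` in the gamma features is a moment-matching step: a gamma distribution has mean `concentration / rate` and variance `concentration / rate ** 2`. The following standard-library sketch checks this for one example pair of values; `gamma_pdf_scalar` is a hypothetical scalar version of the density that is valid only for `x > 0`.

```python
import math


def gamma_pdf_scalar(x, concentration, rate):
  """Gamma density rate**c * x**(c - 1) * exp(-rate * x) / Gamma(c), for x > 0."""
  log_unnormalized = (concentration - 1.0) * math.log(x) - rate * x
  log_normalization = math.lgamma(concentration) - concentration * math.log(rate)
  return math.exp(log_unnormalized - log_normalization)


mean, stddev = 10.0, 2.0
concentration = (mean / stddev) ** 2      # 25.0
rate = mean / stddev ** 2                 # 2.5

# Moment matching: the mean is c / rate and the standard deviation is sqrt(c) / rate.
print(concentration / rate)               # 10.0
print(math.sqrt(concentration) / rate)    # 2.0

# With concentration 1 the density reduces to the exponential density rate * exp(-rate * x).
print(math.isclose(gamma_pdf_scalar(0.7, 1.0, 3.0), 3.0 * math.exp(-2.1)))
```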

In [None]:
def positional_features_gamma(positions: tf.Tensor,
                              feature_size: int,
                              seq_length: Optional[int] = None,
                              bin_size: Optional[int] = None,
                              stddev=None,
                              start_mean=None):
  """Positional features computed using the gamma distributions."""
  del bin_size  # Unused.
  if seq_length is None:
    seq_length = tf.reduce_max(tf.abs(positions)) + 1
  if stddev is None:
    stddev = seq_length / (2 * feature_size)
  if start_mean is None:
    start_mean = seq_length / feature_size
  mean = tf.linspace(start_mean, seq_length, num=feature_size)
  mean = _prepend_dims(mean, positions.shape.rank)
  concentration = (mean / stddev)**2
  rate = mean / stddev**2
  probabilities = gamma_pdf(
      tf.abs(tf.cast(positions, dtype=tf.float32))[..., tf.newaxis],
      concentration, rate)
  probabilities += 1e-8  # To ensure numerical stability.
  outputs = probabilities / tf.reduce_max(probabilities,
                                          axis=1, keepdims=True)
  tf.TensorShape(outputs.shape).assert_is_compatible_with(
      positions.shape + [feature_size])
  return outputs


def positional_features_cosine(positions: tf.Tensor,
                               feature_size: int,
                               seq_length: Optional[int] = None,
                               bin_size: Optional[int] = None):
  """Cosine positional features."""
  del bin_size  # Unused.
  del seq_length  # Unused.
  periodicity = 1.25 * tf.pow(2.0, tf.range(0, feature_size, dtype=tf.float32))
  periodicity = _prepend_dims(periodicity, positions.shape.rank)

  outputs = tf.math.cos(2 * np.pi * positions[..., tf.newaxis] / periodicity)
  tf.TensorShape(outputs.shape).assert_is_compatible_with(
      positions.shape + [feature_size])
  return outputs
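
The cosine features can be checked at positions that are simple fractions of the periods. This is a NumPy sketch of the same formula, with a hypothetical function name and example positions chosen for readability.

```python
import numpy as np


def cosine_features(positions, feature_size):
  """NumPy sketch: cos(2 * pi * position / period), with periods 1.25 * 2**i."""
  periodicity = 1.25 * 2.0 ** np.arange(feature_size)
  return np.cos(2 * np.pi * positions[..., np.newaxis] / periodicity)


positions = np.array([0.0, 1.25, 2.5])
feats = cosine_features(positions, feature_size=3)   # periods 1.25, 2.5, 5.0
print(np.round(feats, 6))
```

Position `1.25` is one full period of the first feature, half a period of the second, and a quarter period of the third, so its row is approximately `[1, -1, 0]`.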

The final two positional feature functions are:

1. `positional_features_linear_masks()`: This function generates point masks at integer distances. The docstring describes them as "exponentially increasing point focuses", but the distances in the code increase linearly. It takes `positions` (position tensor), `feature_size` (number of basis functions), `seq_length` (unused), and `bin_size` (unused) as arguments.

   The function creates a range of distances from 0 to `feature_size - 1`. It then compares each distance with the absolute positions and generates a binary mask whose value is 1 if the distance equals the absolute position and 0 otherwise. The outputs have shape `positions.shape + [feature_size]`.

2. `positional_features_sin_cos()`: This function generates sine/cosine positional encodings. It takes `positions` (position tensor), `feature_size` (number of basis functions), `seq_length` (unused), `bin_size` (unused), and `max_time` (maximum time value) as arguments.

   The function first checks that `feature_size` is divisible by 2. It then creates a range of values `i` from 0 to `feature_size - 2` in steps of 2. The sine and cosine encodings are computed from the positions divided by `max_time` raised to the power of `i / feature_size`. The sine values and then the cosine values are concatenated along the last dimension to form the outputs, which have shape `positions.shape + [feature_size]`.

These functions provide additional options for generating positional encodings/features based on linear masks and sine/cosine functions, respectively.

In [None]:
def positional_features_linear_masks(positions: tf.Tensor,
                                     feature_size: int,
                                     seq_length: Optional[int] = None,
                                     bin_size: Optional[int] = None):
  """Exponentially increasing point focuses."""
  del bin_size  # Unused.
  del seq_length  # Unused.
  distances = tf.range(0, feature_size, dtype=tf.float32)
  distances = _prepend_dims(distances, positions.shape.rank)
  outputs = tf.cast(distances == tf.abs(positions[..., tf.newaxis]),
                    dtype=tf.float32)

  tf.TensorShape(outputs.shape).assert_is_compatible_with(
      positions.shape + [feature_size])
  return outputs


def positional_features_sin_cos(positions: tf.Tensor,
                                feature_size: int,
                                seq_length: Optional[int] = None,
                                bin_size: Optional[int] = None,
                                max_time=10000.0):
  """Sine/cosine positional encodings."""
  del bin_size  # Unused.
  del seq_length  # Unused.
  if feature_size % 2 != 0:
    raise ValueError('feature_size needs to be divisible by 2.')
  i = tf.range(0, feature_size, 2, dtype=tf.float32)
  i = _prepend_dims(i, positions.shape.rank)

  # Concat sines and cosines and return.
  outputs = tf.concat([
      tf.sin(positions[..., tf.newaxis] / max_time**(i / feature_size)),
      tf.cos(positions[..., tf.newaxis] / max_time**(i / feature_size))], -1)

  tf.TensorShape(outputs.shape).assert_is_compatible_with(
      positions.shape + [feature_size])
  return outputs
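
As a quick check of the sine/cosine layout, the same formula can be evaluated in NumPy. This is an illustrative sketch with a hypothetical function name; it shows that the first half of the feature axis holds the sines and the second half the cosines.

```python
import numpy as np


def sin_cos_features(positions, feature_size, max_time=10000.0):
  """NumPy sketch: sines then cosines at feature_size // 2 geometric frequencies."""
  if feature_size % 2 != 0:
    raise ValueError('feature_size needs to be divisible by 2.')
  i = np.arange(0, feature_size, 2, dtype=np.float64)
  angles = positions[..., np.newaxis] / max_time ** (i / feature_size)
  return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


positions = np.arange(-3, 4, dtype=np.float64)[np.newaxis]   # [1, 7]
feats = sin_cos_features(positions, feature_size=8)
print(feats.shape)     # (1, 7, 8)
print(feats[0, 3])     # position 0: four sines equal 0, four cosines equal 1
```

The sine half is odd in the position and the cosine half is even, so these encodings already distinguish the sign of the relative distance.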