**Tutorial: Training Enformer - A Step-by-Step Guide**

Welcome to our tutorial page on training Enformer! Here, we provide a detailed walkthrough of how the `Enformer` model is trained. You will be able to explore the code for each part and observe the training and evaluation results of this method.

Before diving into the tutorial, here is an overview of the steps needed to train the model:

**Steps:**
1. Set up a `tf.data.Dataset` that reads the Basenji2 data from Google Cloud Storage (GCS) at `gs://basenji_barnyard/data`. GCS is Google's cloud object-storage service. TensorFlow can read `gs://` paths directly, and the data can also be copied locally with the `gsutil` command-line tool (see its documentation).

2. Begin training the model by alternating between training on human and mouse data batches.

3. Evaluate the model's performance on human and mouse genomes.

The Enformer model utilizes a state-of-the-art architecture to predict genomic tracks from one-hot-encoded DNA sequences. This architecture consists of three main components: convolutional blocks with pooling, transformer blocks, and cropping layers. These components are followed by final pointwise convolutions that branch into two organism-specific network heads.

The input to the Enformer model is a one-hot-encoded DNA sequence of 196,608 bp. For the human genome, the model predicts 5,313 genomic tracks, and for the mouse genome it predicts 1,643 tracks. Each track consists of 896 bins of 128 bp, covering the central 114,688 bp (896 × 128) of the input sequence.

The convolutional blocks with pooling reduce the spatial dimension from 196,608 bp to 1,536 positions (196,608 / 128 = 1,536), so that each position vector represents a 128 bp segment. This reduction allows the rest of the model, including the transformer blocks, to process the sequence efficiently.

To gain more insight into the data and how it is fed to this model, we recommend reading the report before working through this page.
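The length numbers above can be checked with a little arithmetic. The sketch below assumes, as in DeepMind's `enformer.py`, seven pooling layers with pool size 2 (one in the stem and six in the convolutional tower) followed by a symmetric crop to 896 bins; `enformer_lengths` is a hypothetical helper, not part of that module. It also covers 131,072 bp, the length of the Basenji2 training sequences used in this notebook.

```python
# Length arithmetic for an Enformer-style model (a sketch, not the library code).
# Assumption: 7 pooling layers of size 2 (1 in the stem + 6 in the conv tower),
# followed by a symmetric crop down to TARGET_LENGTH bins.

NUM_POOLING_LAYERS = 7
POOL_SIZE = 2
TARGET_LENGTH = 896  # bins kept after cropping


def enformer_lengths(seq_len_bp):
    """Return (bin_size_bp, bins_before_crop, bins_cropped_per_side, covered_bp)."""
    bin_size = POOL_SIZE ** NUM_POOLING_LAYERS      # 2**7 = 128 bp per bin
    if seq_len_bp % bin_size:
        raise ValueError("sequence length must be a multiple of the bin size")
    bins = seq_len_bp // bin_size                    # positions after pooling
    crop_each_side = (bins - TARGET_LENGTH) // 2     # bins trimmed from each end
    covered_bp = TARGET_LENGTH * bin_size            # bp covered by the output
    return bin_size, bins, crop_each_side, covered_bp


for length in (196_608, 131_072):
    print(length, enformer_lengths(length))
# 196608 (128, 1536, 320, 114688)
# 131072 (128, 1024, 64, 114688)
```

For a 196,608 bp input, 320 bins are trimmed from each end of the 1,536 pooled positions; for a 131,072 bp input, only 64 are trimmed from each end of 1,024. Either way, the output covers 896 bins of 128 bp.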

We hope this tutorial gives you a comprehensive understanding of Enformer training. If you have any questions or need further assistance, please do not hesitate to reach out.

The first code cell uses the `pip` package manager to install two Python packages, `dm-sonnet` and `tqdm`.

1- `dm-sonnet` (Sonnet) is a deep learning library built on `TensorFlow`. It simplifies building and training neural networks by providing a collection of modules and abstractions.

2- `tqdm` is a Python library for displaying progress bars for loops and other long-running tasks in the command line or a notebook.

3- The exclamation mark (**!**) at the start of the line indicates that the code runs in a Jupyter Notebook or IPython environment, where `!` executes the rest of the line as a shell command. Here it runs `pip install` to install the required packages.

In [1]:
!pip install dm-sonnet tqdm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dm-sonnet
  Downloading dm_sonnet-2.0.1-py3-none-any.whl (268 kB)
Installing collected packages: dm-sonnet
Successfully installed dm-sonnet-2.0.1


Two `wget` commands download Python source files from the `enformer` directory of DeepMind's `deepmind-research` GitHub repository:

1- The first command downloads `attention_module.py` from https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/attention_module.py.

2- The second command downloads `enformer.py` from https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/enformer.py.

The `-q` (quiet) option suppresses `wget`'s output and progress information, which keeps the notebook output clean.

After these commands run, both files are in the working directory and can be imported as regular Python modules, as `import enformer` does later in this notebook.

In [None]:
# Get enformer source code
!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/attention_module.py
!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/enformer.py

**Note:** The `attention_module.py` and `enformer.py` files are commented and explained in separate Jupyter notebooks in this GitHub repository.

**Libraries**

The code begins by importing TensorFlow with `import tensorflow as tf`.

It then checks that a GPU is available: `tf.config.list_physical_devices('GPU')` returns the list of physical GPU devices, and an `assert` statement raises an `AssertionError` with the message `Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU` if that list is empty. In the cell below, this assertion is commented out, so the cell runs even when no GPU is attached.

Next, the `%env` magic sets the environment variable `TF_ENABLE_GPU_GARBAGE_COLLECTION` to `false`. This variable controls TensorFlow's garbage-collection behavior for GPU memory; setting it to `false` disables that optimization, which the code comment describes as making out-of-memory (OOM) errors easier to debug.

In summary, this cell (optionally) checks for a GPU and adjusts a TensorFlow memory-management setting to make debugging easier.

In [None]:
import tensorflow as tf
# Make sure the GPU is enabled
#assert tf.config.list_physical_devices('GPU'), 'Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU'

# Easier debugging of OOM
%env TF_ENABLE_GPU_GARBAGE_COLLECTION=false

env: TF_ENABLE_GPU_GARBAGE_COLLECTION=false


This code imports several Python libraries and modules:

1. `sonnet`, imported as `snt`: Sonnet is a deep learning library built on top of TensorFlow, and `snt` is simply the conventional alias used to refer to it. It provides modules and abstractions for building neural networks.

2. `tqdm`: This library is used for creating progress bars and visualizing the progress of iterations or tasks in the command line interface.

3. `clear_output` from `IPython.display`: the `IPython.display` module provides functions for controlling output in IPython environments, and `clear_output` clears the output of the current cell.

4. `numpy` as `np`: NumPy is a fundamental library for numerical computing in Python. It provides support for large, multi-dimensional arrays and matrices, along with a collection of mathematical functions to operate on these arrays efficiently.

5. `pandas` as `pd`: Pandas is a powerful library for data manipulation and analysis. It provides data structures and functions to efficiently work with structured data, such as tables or CSV files.

6. `time`: This module provides functions for working with time-related operations, such as measuring elapsed time or introducing delays in the code.

7. `os`: The os module provides a way to use operating system-dependent functionality, such as interacting with the file system, working with environment variables, and executing system commands.

By importing these libraries and modules, the code gains access to their respective functionalities, allowing for easier and more efficient development of the subsequent code.


In [None]:
import sonnet as snt
from tqdm import tqdm
from IPython.display import clear_output
import numpy as np
import pandas as pd
import time
import os

This code uses an assert statement to check the version number of the `sonnet` library (`snt`). Specifically, it verifies that the version number starts with `2.0`.

The assert statement checks if the condition provided is `True`. If the condition is `False`, it raises an `AssertionError` with an optional error message.

Here, the assertion checks that the installed Sonnet version starts with `2.0`; otherwise an `AssertionError` is raised. Such checks guard against API differences between library versions that the rest of the code may depend on. Note that the line is commented out in this notebook, so no version check is actually performed.

In [None]:
# Uncomment to require Sonnet 2.0.x (Sonnet exposes its version as snt.__version__).
#assert snt.__version__.startswith('2.0')
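`str.startswith('2.0')` is a loose check: it would also accept a version string such as `'2.05'`. A slightly stricter sketch compares the parsed (major, minor) pair instead. The helper names are hypothetical, and the parser assumes a plain `X.Y[.Z]` version string (it does not handle suffixes such as `2.0rc1`).

```python
def major_minor(version):
    """Parse an 'X.Y[.Z...]' version string into an (X, Y) tuple of ints."""
    parts = version.split(".")
    return int(parts[0]), int(parts[1])


def is_sonnet_2_0(version):
    """True only for 2.0.x releases."""
    return major_minor(version) == (2, 0)


print(is_sonnet_2_0("2.0.1"))  # True (the version pip installed above)
print(is_sonnet_2_0("2.05"))   # False, although "2.05".startswith("2.0") is True
```

Assuming the installed Sonnet exposes `__version__`, the check could be written as `assert is_sonnet_2_0(snt.__version__)`.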

The `!nvidia-smi` command runs NVIDIA's `nvidia-smi` tool, which reports the NVIDIA GPU(s) available in the environment together with their memory usage and running processes. In the output below, `nvidia-smi: command not found` indicates that no NVIDIA GPU (or driver) was available in the runtime used for this run.

In [None]:
# GPU colab has T4 with 16 GiB of memory
!nvidia-smi

/bin/bash: nvidia-smi: command not found


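**Code**

Import the `enformer` module (the `enformer.py` file downloaded above), which defines the `Enformer` model class written by the Enformer authors at DeepMind. The module is explained in detail in another notebook in this GitHub repository.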


In [None]:
import enformer

The `get_targets()` function builds the `targets_txt` variable, an f-string holding the URL of the target-information text file for the given organism (`human` or `mouse`).

It then reads that file into a DataFrame with `pd.read_csv()` from the `pandas` library, treating it as tab-separated (`sep='\t'`).

The resulting DataFrame, with one row of metadata per target track, is returned as the output of the function.

In [None]:
# @title get_targets(organism)
def get_targets(organism):
  targets_txt = f'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_{organism}.txt'
  return pd.read_csv(targets_txt, sep='\t')
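Under the hood, `pd.read_csv(..., sep='\t')` parses an ordinary tab-separated table. The stdlib sketch below shows the same parsing on a made-up three-row excerpt (two rows reuse identifiers and descriptions from the `head()` output further down; the `CAGE` row is invented), plus one common use of the table: counting tracks per assay from the prefix of the `description` column. `read_targets` and `count_assays` are hypothetical helpers.

```python
import csv
import io
from collections import Counter

# Toy stand-in for targets_{organism}.txt: a tab-separated table with a header.
# The real file has more columns (index, genome, file, clip, scale, ...); only
# two are kept here, and the CAGE row is made up for illustration.
TOY_TARGETS_TSV = (
    "identifier\tdescription\n"
    "ENCFF880MKD\tDNASE:chorion\n"
    "ENCFF890OGQ\tDNASE:GM03348\n"
    "TOY0001\tCAGE:toy example\n"
)


def read_targets(text):
    """Parse tab-separated target metadata into a list of dicts (one per track)."""
    return list(csv.DictReader(io.StringIO(text), delimiter="\t"))


def count_assays(rows):
    """Count tracks per assay, using the text before ':' in 'description'."""
    return Counter(row["description"].split(":", 1)[0] for row in rows)


rows = read_targets(TOY_TARGETS_TSV)
print(len(rows))           # 3
print(count_assays(rows))  # Counter({'DNASE': 2, 'CAGE': 1})
```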

Here, the author presents a function called `get_dataset()`, which retrieves the dataset for a given organism and subset. The data is read from TFRecord files and processed with TensorFlow.

Here is the explanation of the code:

1. `organism_path()` returns the path of the directory that holds the data for a given organism (`gs://basenji_barnyard/data/{organism}`).

2. `get_dataset()` takes three arguments: `organism` (the organism, `'human'` or `'mouse'`), `subset` (the data split, such as `'train'`, `'valid'`, or `'test'`), and `num_threads` (the number of parallel threads used to read and parse the `TFRecord` files, 8 by default).

3. Inside `get_dataset()`, `get_metadata()` loads the organism's `statistics.json` file. Its keys include the number of targets, the number of training, validation, and test sequences, the sequence length, the pooling width, the number of cropped base pairs, and the target length.

4. `tfrecord_files()` lists the TFRecord files for the organism and subset (files named `{subset}-*.tfr`) and sorts them by their integer shard index.

5. A `tf.data.TFRecordDataset` is created from these files; it reads the ZLIB-compressed files using `num_threads` parallel reads.

6. The dataset is mapped with `deserialize()`, which parses each serialized example and decodes its raw bytes into tensors: the sequence is decoded as booleans, reshaped to `(seq_length, 4)`, and cast to `float32`; the target is decoded as `float16`, reshaped to `(target_length, num_targets)`, and cast to `float32`.

7. Finally, `get_dataset()` returns the processed dataset.

In addition to `get_dataset()`, the auxiliary functions `get_metadata()`, `tfrecord_files()`, and `deserialize()` handle loading metadata, listing file paths, and deserializing `TFRecord` examples, respectively.

Overall, this code retrieves the dataset for a given organism and subset from TFRecord files, ready for further processing with TensorFlow.

In [None]:
# @title get_dataset(organism, subset, num_threads=8)
import glob
import json
import functools


def organism_path(organism):
  return os.path.join('gs://basenji_barnyard/data', organism)


def get_dataset(organism, subset, num_threads=8):
  metadata = get_metadata(organism)
  dataset = tf.data.TFRecordDataset(tfrecord_files(organism, subset),
                                    compression_type='ZLIB',
                                    num_parallel_reads=num_threads)
  dataset = dataset.map(functools.partial(deserialize, metadata=metadata),
                        num_parallel_calls=num_threads)
  return dataset


def get_metadata(organism):
  # Keys:
  # num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
  # pool_width, crop_bp, target_length
  path = os.path.join(organism_path(organism), 'statistics.json')
  with tf.io.gfile.GFile(path, 'r') as f:
    return json.load(f)


def tfrecord_files(organism, subset):
  # Sort the values by int(*).
  return sorted(tf.io.gfile.glob(os.path.join(
      organism_path(organism), 'tfrecords', f'{subset}-*.tfr'
  )), key=lambda x: int(x.split('-')[-1].split('.')[0]))


def deserialize(serialized_example, metadata):
  """Deserialize bytes stored in TFRecordFile."""
  feature_map = {
      'sequence': tf.io.FixedLenFeature([], tf.string),
      'target': tf.io.FixedLenFeature([], tf.string),
  }
  example = tf.io.parse_example(serialized_example, feature_map)
  # One-hot DNA sequence stored as booleans: (seq_length, 4).
  sequence = tf.io.decode_raw(example['sequence'], tf.bool)
  sequence = tf.reshape(sequence, (metadata['seq_length'], 4))
  sequence = tf.cast(sequence, tf.float32)

  # Target tracks stored as float16: (target_length, num_targets).
  target = tf.io.decode_raw(example['target'], tf.float16)
  target = tf.reshape(target,
                      (metadata['target_length'], metadata['num_targets']))
  target = tf.cast(target, tf.float32)

  return {'sequence': sequence,
          'target': target}
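To see what `deserialize()` and `tfrecord_files()` do without TensorFlow or GCS access, the NumPy sketch below decodes raw bytes in the layout that `deserialize()` assumes (one-hot sequence stored as one byte per boolean, targets stored as `float16`) and sorts shard names with the same key as `tfrecord_files()`. The tiny sizes in `toy_metadata` and the file names are made up; real records use the values from `statistics.json`, and `decode` and `shard_index` are hypothetical helpers.

```python
import numpy as np

# Toy sizes, not the real metadata.
toy_metadata = {"seq_length": 8, "target_length": 2, "num_targets": 3}

# Build fake raw bytes in the assumed storage layout:
# one-hot sequence as bool (1 byte per value), targets as float16.
rng = np.random.default_rng(0)
one_hot = np.eye(4, dtype=bool)[rng.integers(0, 4, size=toy_metadata["seq_length"])]
targets = rng.random((toy_metadata["target_length"],
                      toy_metadata["num_targets"])).astype(np.float16)
sequence_bytes, target_bytes = one_hot.tobytes(), targets.tobytes()


def decode(sequence_bytes, target_bytes, metadata):
    """NumPy analogue of the decode_raw / reshape / cast steps in deserialize()."""
    sequence = np.frombuffer(sequence_bytes, dtype=bool)
    sequence = sequence.reshape(metadata["seq_length"], 4).astype(np.float32)
    target = np.frombuffer(target_bytes, dtype=np.float16)
    target = target.reshape(metadata["target_length"],
                            metadata["num_targets"]).astype(np.float32)
    return {"sequence": sequence, "target": target}


def shard_index(path):
    """Same sort key as tfrecord_files(): 'train-12.tfr' -> 12."""
    return int(path.split("-")[-1].split(".")[0])


example = decode(sequence_bytes, target_bytes, toy_metadata)
print(example["sequence"].shape, example["target"].shape)  # (8, 4) (2, 3)
print(sorted(["train-10.tfr", "train-2.tfr", "train-1.tfr"], key=shard_index))
# ['train-1.tfr', 'train-2.tfr', 'train-10.tfr']
```

Sorting by the integer index matters because a plain string sort would place `train-10.tfr` before `train-2.tfr`.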

### Load the data set

This cell calls `get_targets()` to retrieve the target information for the organism `'human'`, assigns the returned DataFrame to the variable `df_targets_human`, and displays its first five rows with the `head()` method.

The columns shown below include the track `identifier`, the source `file`, the `clip`, `scale`, and `sum_stat` settings, and a `description` of the assay and biosample (for example `DNASE:chorion`).

In [None]:
df_targets_human = get_targets('human')
df_targets_human.head()

| Unnamed: 0 | index | genome | identifier | file | clip | scale | sum_stat | description |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | ENCFF833POA | /home/drk/tillage/datasets/human/dnase/encode/... | 32 | 2 | mean | DNASE:cerebellum male adult (27 years) and mal... |
| 1 | 1 | 0 | ENCFF110QGM | /home/drk/tillage/datasets/human/dnase/encode/... | 32 | 2 | mean | DNASE:frontal cortex male adult (27 years) and... |
| 2 | 2 | 0 | ENCFF880MKD | /home/drk/tillage/datasets/human/dnase/encode/... | 32 | 2 | mean | DNASE:chorion |
| 3 | 3 | 0 | ENCFF463ZLQ | /home/drk/tillage/datasets/human/dnase/encode/... | 32 | 2 | mean | DNASE:Ishikawa treated with 0.02% dimethyl sul... |
| 4 | 4 | 0 | ENCFF890OGQ | /home/drk/tillage/datasets/human/dnase/encode/... | 32 | 2 | mean | DNASE:GM03348 |


Next, the author builds three datasets: `human_dataset`, `mouse_dataset`, and `human_mouse_dataset`. The first two are obtained with `get_dataset()` by specifying the organism (`'human'` or `'mouse'`) and the subset (`'train'`).

Here is the explanation of the code:

1. The `human_dataset` is obtained by invoking `get_dataset()` with the arguments `'human'` and `'train'`. This call retrieves the dataset for the `'human'` organism and the `'train'` subset.

2. Similarly, the `mouse_dataset` is obtained by invoking `get_dataset()` with the arguments `'mouse'` and `'train'`. This call retrieves the dataset for the `'mouse'` organism and the `'train'` subset.

3. `.batch(1)` is applied to both `human_dataset` and `mouse_dataset`, grouping the elements into batches of size one. This adds a leading batch dimension of size 1 to every tensor, as the shapes printed below show (for example `[1, 131072, 4]`).

4. The `.repeat()` method is additionally applied to both datasets. This method repeats the dataset indefinitely, allowing for multiple iterations during training or evaluation.

5. `tf.data.Dataset.zip()` combines `human_dataset` and `mouse_dataset` into a new dataset, `human_mouse_dataset`. Each of its elements is a pair `(human_batch, mouse_batch)`: one batch from the human dataset and one from the mouse dataset.

6. Lastly, the `.prefetch(2)` method is invoked on `human_mouse_dataset`. This operation prefetches and buffers up to 2 elements, enhancing training performance by overlapping data preprocessing and model training.

In summary, this code builds training datasets for the `'human'` and `'mouse'` organisms and combines them into a single dataset (`human_mouse_dataset`) whose elements are (human, mouse) pairs, ready for the alternating training loop.

In [None]:
human_dataset = get_dataset('human', 'train').batch(1).repeat()
mouse_dataset = get_dataset('mouse', 'train').batch(1).repeat()
human_mouse_dataset = tf.data.Dataset.zip((human_dataset, mouse_dataset)).prefetch(2)
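The pairing produced by `.batch(1).repeat()` and `tf.data.Dataset.zip()` can be mimicked in plain Python with `itertools`, using toy lists in place of the real datasets. This is only an analogy: `itertools.cycle` caches the first pass in memory, whereas `repeat()` re-reads the TFRecord files, but the pairing behavior is the same. `batch1` is a hypothetical helper.

```python
from itertools import cycle, islice

# Toy stand-ins for the human and mouse training datasets (different sizes).
human_examples = ["h0", "h1", "h2"]
mouse_examples = ["m0", "m1"]


def batch1(items):
    """Analogue of .batch(1): wrap each element in a batch of size one."""
    return ([x] for x in items)


human_stream = cycle(batch1(human_examples))  # analogue of .repeat()
mouse_stream = cycle(batch1(mouse_examples))
paired = zip(human_stream, mouse_stream)     # analogue of Dataset.zip(): (human, mouse) pairs

print(list(islice(paired, 4)))
# [(['h0'], ['m0']), (['h1'], ['m1']), (['h2'], ['m0']), (['h0'], ['m1'])]
```

Because each input repeats on its own, the shorter dataset simply starts over while the longer one continues, so the zipped stream never runs out.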

This cell creates an iterator over `mouse_dataset` with `iter()`, assigns it to the variable `it`, and then retrieves one element with `next()`, storing it in the variable `example`.

Here's an explanation of the code:

- `iter(mouse_dataset)` creates an iterator object, which yields the elements of the dataset one at a time.
- `next(it)` retrieves the next element from the iterator; each call advances the iterator by one element.
- The retrieved element is assigned to `example`.

After this cell runs, `example` holds the first batch of `mouse_dataset`: a dictionary with `'sequence'` and `'target'` tensors that can be inspected or used for further processing.

In [None]:
it = iter(mouse_dataset)
example = next(it)

This cell takes one element from `human_mouse_dataset` and prints the structure of its contents.

Here's an explanation of the code:

- `iter(human_mouse_dataset)` creates an iterator, which is assigned to the variable `it`.
- `next(it)` retrieves one element, `example`. Because the dataset was built with `tf.data.Dataset.zip()`, `example` is a tuple of two batches: `example[0]` comes from the human dataset and `example[1]` from the mouse dataset.
- A `for` loop iterates over `range(len(example))`, that is, over the two positions of this tuple.
- In each iteration, the organism name is printed by indexing the list `['human', 'mouse']` with `i`.
- `example[i]` is a dictionary with the keys `'sequence'` and `'target'`.
- The dictionary comprehension `{k: (v.shape, v.dtype) for k, v in example[i].items()}` maps each key to the shape and data type of its tensor.

The output shows that each sequence has shape `[1, 131072, 4]` (batch size 1, 131,072 bp, one-hot over 4 bases), while the targets have shape `[1, 896, 5313]` for human and `[1, 896, 1643]` for mouse (896 bins per track). Note that these Basenji2 training sequences are 131,072 bp long, shorter than the 196,608 bp input described in the introduction.

In [None]:
# Example input
it = iter(human_mouse_dataset)
example = next(it)
for i in range(len(example)):
  print(['human', 'mouse'][i])
  print({k: (v.shape, v.dtype) for k,v in example[i].items()})

human
{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 5313]), tf.float32)}
mouse
{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 1643]), tf.float32)}


**Model Training**

In this step the author introduces a function called `create_step_function()`, which is responsible for producing and returning a training step function tailored to a given model and optimizer.

Here's the explanation of the code:

The `create_step_function()` function accepts two arguments: `model` (representing the model utilized for training) and `optimizer` (representing the optimizer employed to update the model's trainable variables).

Within the function, a nested function named `train_step()` is defined. This function is decorated with `@tf.function`, enabling it to be compiled and optimized using `TensorFlow's AutoGraph` functionality, leading to improved performance.

The `train_step()` function takes `batch` (a dictionary with the input sequences and targets), `head` (the name of the output head to train, `'human'` or `'mouse'`), and `optimizer_clip_norm_global` (default `0.2`). Despite its name, `optimizer_clip_norm_global` is not used in the function body, so the gradients are applied without clipping.

Inside `train_step()`, a gradient tape (`tf.GradientTape`) records the operations for automatic differentiation. The forward pass calls `model` on the input sequences (`batch['sequence']`) with `is_training=True`, and the requested head (`head`) is selected from the model's output dictionary.

The loss compares the predictions (`outputs`) with the targets (`batch['target']`) using the Poisson loss `tf.keras.losses.poisson`, and `tf.reduce_mean()` averages the result into a single scalar.

`tape.gradient` then computes the gradients of the loss with respect to the model's trainable variables.

The optimizer's `apply` method applies these gradients to the trainable variables, updating the model's parameters.

The `loss` is then returned as the output of the `train_step` function.

Finally, the `create_step_function()` function returns the `train_step()` function. This `train_step()` function can be utilized to execute a single training step on the model using the specified `optimizer`. It offers a convenient and efficient approach to train the model by encapsulating the necessary operations within a TensorFlow function.

In [None]:
def create_step_function(model, optimizer):

  @tf.function
  def train_step(batch, head, optimizer_clip_norm_global=0.2):
    with tf.GradientTape() as tape:
      outputs = model(batch['sequence'], is_training=True)[head]
      loss = tf.reduce_mean(
          tf.keras.losses.poisson(batch['target'], outputs))

    gradients = tape.gradient(loss, model.trainable_variables)
    # Note: optimizer_clip_norm_global is not used, so gradients are not clipped.
    optimizer.apply(gradients, model.trainable_variables)

    return loss
  return train_step
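To make the two numerical pieces of `train_step()` concrete, here is a small NumPy sketch (not the TensorFlow code path). `poisson_loss` mirrors `tf.keras.losses.poisson`, which averages `y_pred - y_true * log(y_pred + epsilon)` over the last axis, followed by `tf.reduce_mean`; the epsilon of `1e-7` is assumed to match the Keras default. `clip_by_global_norm` shows the joint rescaling that `tf.clip_by_global_norm` performs. Since `train_step()` accepts `optimizer_clip_norm_global=0.2` but never applies it, this only illustrates what clipping would do if it were added, for example with `gradients, _ = tf.clip_by_global_norm(gradients, optimizer_clip_norm_global)` before `optimizer.apply`. Both helper names are hypothetical.

```python
import numpy as np

EPSILON = 1e-7  # assumed to match the Keras backend epsilon


def poisson_loss(y_true, y_pred):
    """NumPy sketch of tf.reduce_mean(tf.keras.losses.poisson(y_true, y_pred))."""
    # Per (batch, position): average over the last axis (tracks).
    per_position = np.mean(y_pred - y_true * np.log(y_pred + EPSILON), axis=-1)
    # Then average over batch and positions to get a scalar.
    return float(np.mean(per_position))


def clip_by_global_norm(grads, clip_norm):
    """Rescale all gradients jointly so that their global L2 norm is at most clip_norm."""
    global_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    scale = clip_norm / max(global_norm, clip_norm)  # 1.0 if already within the limit
    return [g * scale for g in grads], global_norm


# Toy targets with shape (batch=1, positions=2, tracks=2).
y_true = np.array([[[0.0, 2.0], [1.0, 3.0]]])
print(poisson_loss(y_true, y_true))              # predictions equal to the targets
print(poisson_loss(y_true, y_true * 1.5 + 0.1))  # worse predictions give a higher loss

# Toy gradients for two variables; their global norm is sqrt(3**2 + 4**2) = 5.
grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]
clipped, norm = clip_by_global_norm(grads, clip_norm=0.2)
print(norm)  # 5.0
```

Global-norm clipping scales every gradient by the same factor, so the direction of the overall update is preserved and only its length is capped.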

This cell sets up the learning rate, the optimizer, the model, and the `train_step()` function for training the `Enformer` model.

Here's an explanation of the code:

- A `learning_rate` variable is created using `tf.Variable` and initialized to 0. It is set as non-trainable by setting `trainable=False`. This variable is used to control the learning rate during training.
- An `Adam optimizer` is created using `snt.optimizers.Adam`, and the `learning_rate` is set to the previously created variable `learning_rate`.
- The `num_warmup_steps` variable is set to `5000`, which represents the number of `warm-up` steps for the learning rate schedule.
- The `target_learning_rate` is set to `0.0005`, which represents the desired learning rate after the warm-up period.
- An `Enformer` model is created with `channels=1536 // 4` (384 channels, four times fewer than the default of 1,536, to train faster), `num_heads=8`, `num_transformer_layers=11`, and `pooling_type='max'`.
- `create_step_function()` is called with `model` and `optimizer`, and the returned function is assigned to the variable `train_step`.

After this cell runs, `train_step()` performs a single training step with the specified optimizer. Because the optimizer reads the `learning_rate` variable, the learning rate can be changed at any time by assigning a new value to that variable, which is how the warm-up schedule is implemented in the training loop. This setup provides everything needed to train the `Enformer` model with the specified configuration.

In [None]:
learning_rate = tf.Variable(0., trainable=False, name='learning_rate')
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
num_warmup_steps = 5000
target_learning_rate = 0.0005

model = enformer.Enformer(channels=1536 // 4,  # Use 4x fewer channels to train faster.
                          num_heads=8,
                          num_transformer_layers=11,
                          pooling_type='max')

train_step = create_step_function(model, optimizer)

Next, the code trains the model with this configuration and prints the loss and learning rate at the end of each epoch. When reading the human and mouse losses, keep the learning rate in mind: training is still in the warm-up phase, so the learning rate stays far below its target value throughout this short run.

Here's the explanation of the code:

1. The `steps_per_epoch` variable is initialized to `20`, indicating the number of steps to iterate through the dataset in each epoch.

2. The `num_epochs` variable is set to `5`, representing the total number of epochs during which the model will be trained.

3. The `data_it` variable is created by invoking `iter(human_mouse_dataset)`, generating an iterator for the `human_mouse_dataset`.

4. A global step counter, `global_step`, is initialized to `0`. This variable keeps track of the overall number of training steps across all epochs.

5. The training process is executed using nested loops. The outer loop iterates over the range of `num_epochs`, while the inner loop iterates over the range of `steps_per_epoch`.

6. Within the inner loop, the `global_step` is incremented by 1 to monitor the progress.

7. From the second step onward (`global_step > 1`), the learning rate follows a linear warm-up: the fraction `global_step / max(1, num_warmup_steps)`, capped at 1.0, is multiplied by `target_learning_rate` and assigned to the `learning_rate` variable. During the very first step, the learning rate keeps its initial value of 0.

8. `next(data_it)` retrieves the next element from the iterator: a tuple containing one `'human'` batch and one `'mouse'` batch.

9. The `train_step()` function is invoked twice to execute the training step on both the `'human'` and `'mouse'` batches individually. The losses `(loss_human and loss_mouse)` are obtained.

10. At the conclusion of each epoch, the loss values for both organisms and the current learning rate are printed.

By executing this code, the model undergoes training for the specified number of epochs, and the `loss` and `learning rate` are displayed at the conclusion of each epoch. The training process entails iterating over the dataset, updating the model parameters, and adjusting the learning rate based on the global step and warm-up schedule.

In [None]:
# Train the model
steps_per_epoch = 20
num_epochs = 5

data_it = iter(human_mouse_dataset)
global_step = 0
for epoch_i in range(num_epochs):
  for i in tqdm(range(steps_per_epoch)):
    global_step += 1

    if global_step > 1:
      learning_rate_frac = tf.math.minimum(
          1.0, global_step / tf.math.maximum(1.0, num_warmup_steps))
      learning_rate.assign(target_learning_rate * learning_rate_frac)

    batch_human, batch_mouse = next(data_it)

    loss_human = train_step(batch=batch_human, head='human')
    loss_mouse = train_step(batch=batch_mouse, head='mouse')

  # End of epoch.
  print('')
  print('loss_human', loss_human.numpy(),
        'loss_mouse', loss_mouse.numpy(),
        'learning_rate', optimizer.learning_rate.numpy()
        )

100%|██████████| 20/20 [00:24<00:00,  1.25s/it]

loss_human 1.774059 loss_mouse 0.94303024 learning_rate 2.0000002e-06

100%|██████████| 20/20 [00:17<00:00,  1.13it/s]

loss_human 1.0067647 loss_mouse 0.8752468 learning_rate 4.0000004e-06

100%|██████████| 20/20 [00:17<00:00,  1.13it/s]

loss_human 1.0471998 loss_mouse 0.89318746 learning_rate 6e-06

100%|██████████| 20/20 [00:17<00:00,  1.14it/s]

loss_human 1.010262 loss_mouse 1.02991 learning_rate 8.000001e-06

100%|██████████| 20/20 [00:17<00:00,  1.14it/s]

loss_human 1.111991 loss_mouse 0.84773445 learning_rate 1.0000001e-05
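The warm-up rule described in step 7 can be checked against these printed learning rates with a small pure-Python sketch. `warmup_learning_rate` is a hypothetical helper that reproduces the formula from the training loop (`target_learning_rate * min(1, global_step / max(1, num_warmup_steps))`); it ignores the very first step, where the loop leaves the rate at 0.

```python
def warmup_learning_rate(global_step, target_learning_rate=0.0005, num_warmup_steps=5000):
    """Linear warm-up: rises from 0 to target_learning_rate, then stays constant."""
    fraction = min(1.0, global_step / max(1.0, num_warmup_steps))
    return target_learning_rate * fraction


# global_step at the end of each of the 5 epochs (20 steps per epoch).
for step in (20, 40, 60, 80, 100):
    print(step, warmup_learning_rate(step))

# The rate only reaches its target after 5,000 steps and stays there afterwards.
print(warmup_learning_rate(5000), warmup_learning_rate(20000))  # 0.0005 0.0005
```

The values for steps 20 to 100 (2e-06 up to 1e-05, up to float rounding) match the `learning_rate` values printed above, which shows that the 100 training steps in this run used at most 2% of the target learning rate.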





**Evaluation (pearsonR)**

This cell defines a function, `evaluate_model()`, that evaluates a model on a given dataset for a specific output head.

Here's an explanation of the code:

- `evaluate_model()` takes `model` (the model to evaluate), `dataset` (the evaluation dataset), `head` (the output head to evaluate), and an optional `max_steps` (the maximum batch index to evaluate).
- Inside the function, a `MetricDict` is created holding a single metric, `'PearsonR'`, computed with `PearsonR(reduce_axis=(0, 1))`, i.e. reduced over the batch and position axes so that one correlation is computed per track. `MetricDict` and `PearsonR` are not defined in this notebook; they must be defined before this cell is run, otherwise it fails with a `NameError`.
- A nested function, `predict()`, is decorated with `@tf.function`. It takes an input `x` and returns the model's predictions for the specified `head`, passing `is_training=False` to run the model in inference mode.
- The loop iterates over `enumerate(dataset)`, so `i` holds the batch index and `batch` holds the batch data.
- If `max_steps` is given, the loop stops once `i` exceeds it. Because the check is `i > max_steps`, `max_steps=100` evaluates 101 batches (indices 0 to 100), which matches the `101it` shown in the progress output.
- Inside the loop, the `metric.update_state` method is called to update the metric with the ground truth targets `(batch['target'])` and the predicted outputs obtained by calling the predict function on the input sequences `(batch['sequence'])`.
- After iterating over the dataset, the evaluation result is obtained by calling `metric.result()`, which returns the computed evaluation metric.

The `evaluate_model()` function allows for evaluating the performance of a model on a given dataset and specific output head. It computes the specified evaluation metric by comparing the model's predictions with the ground truth targets.

In [None]:
def evaluate_model(model, dataset, head, max_steps=None):
  metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0,1))})
  @tf.function
  def predict(x):
    return model(x, is_training=False)[head]

  for i, batch in tqdm(enumerate(dataset)):
    if max_steps is not None and i > max_steps:
      break
    metric.update_state(batch['target'], predict(batch['sequence']))

  return metric.result()
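`PearsonR(reduce_axis=(0, 1))` is not defined in this notebook. As a hedged illustration of what such a streaming metric computes, the NumPy sketch below (the class name `StreamingPearsonR` is hypothetical, and this is not the original implementation) accumulates running sums over the batch and position axes across calls to `update_state()`, then returns one Pearson correlation per track from `result()`. Averaging that vector with `.mean()` gives the single number printed by the evaluation cells.

```python
import numpy as np


class StreamingPearsonR:
    """Hypothetical NumPy stand-in for a streaming PearsonR(reduce_axis=(0, 1)).

    Accumulates sums over the reduced axes (batch and position) across batches,
    then returns one Pearson correlation per remaining entry (per track).
    """

    def __init__(self, reduce_axis=(0, 1)):
        self.reduce_axis = reduce_axis
        self.count = 0
        self.sums = None  # [sum_x, sum_y, sum_xx, sum_yy, sum_xy], one value per track

    def update_state(self, y_true, y_pred):
        x = np.asarray(y_true, dtype=np.float64)
        y = np.asarray(y_pred, dtype=np.float64)
        batch_sums = [np.sum(v, axis=self.reduce_axis) for v in (x, y, x * x, y * y, x * y)]
        if self.sums is None:
            self.sums = batch_sums
        else:
            self.sums = [old + new for old, new in zip(self.sums, batch_sums)]
        self.count += int(np.prod([x.shape[a] for a in self.reduce_axis]))

    def result(self):
        n = self.count
        sx, sy, sxx, syy, sxy = self.sums
        covariance = sxy / n - (sx / n) * (sy / n)
        var_x = sxx / n - (sx / n) ** 2
        var_y = syy / n - (sy / n) ** 2
        return covariance / np.sqrt(var_x * var_y)


# Toy usage: three batches shaped (batch=1, positions=50, tracks=2).
rng = np.random.default_rng(0)
metric = StreamingPearsonR()
all_true, all_pred = [], []
for _ in range(3):
    y_true = rng.random((1, 50, 2))
    y_pred = y_true + 0.1 * rng.standard_normal((1, 50, 2))  # noisy "predictions"
    metric.update_state(y_true, y_pred)
    all_true.append(y_true)
    all_pred.append(y_pred)

r = metric.result()
print(r.shape)   # (2,): one correlation per track
print(r.mean())  # single summary number, like .numpy().mean() in the notebook
```

Keeping only running sums makes each update cheap and lets the metric cover many batches without storing predictions; accumulating in float64 keeps the single-pass variance formula numerically well-behaved for this kind of data.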

The next cell evaluates the model on the `'human'` validation set using the `'human'` output head. Because PearsonR is computed per track, `.numpy().mean()` averages the per-track correlations over all human tracks.

Here's an explanation of the code:

- The `evaluate_model()` function is called with the following arguments:
  - `model`: The model to be evaluated.
  - `dataset`: The dataset for evaluation, obtained by calling `get_dataset('human', 'valid').batch(1).prefetch(2)`. It batches the data with a batch size of 1 and prefetches 2 batches for improved performance.
  - `head`: The specific output head of the model to evaluate, which is set to 'human'.
  - `max_steps`: The maximum number of steps to perform evaluation, which is set to `100`.
- The evaluation metrics for the `'human'` dataset are stored in the `metrics_human` variable.
- The computed mean values of the evaluation metrics are printed using a dictionary comprehension, where the key is the metric name and the value is the mean value obtained by calling `.numpy().mean()` on each metric.

By running this code, the model will be evaluated on the 'human' dataset using the specified output head, and the mean values of the evaluation metrics will be displayed.

In [None]:
metrics_human = evaluate_model(model,
                               dataset=get_dataset('human', 'valid').batch(1).prefetch(2),
                               head='human',
                               max_steps=100)
print('')
print({k: v.numpy().mean() for k, v in metrics_human.items()})

Similarly, the next cell evaluates the model on the `'mouse'` validation set using the `'mouse'` output head and averages the resulting per-track correlations.

Here's an explanation of the code:

- The `evaluate_model()` function is called with the following arguments:
  - `model`: The model to be evaluated.
  - `dataset`: The dataset for evaluation, obtained by calling `get_dataset('mouse', 'valid').batch(1).prefetch(2)`. It batches the data with a batch size of 1 and prefetches 2 batches for improved performance.
  - `head`: The specific output head of the model to evaluate, which is set to `'mouse'`.
  - `max_steps`: The maximum number of steps to perform evaluation, which is set to `100`.
- The evaluation metrics for the `'mouse'` dataset are stored in the `metrics_mouse` variable.
- The computed mean values of the evaluation metrics are printed using a dictionary comprehension, where the key is the metric name and the value is the mean value obtained by calling `.numpy().mean()` on each metric.

By running this code, the model will be evaluated on the `'mouse'` dataset using the specified output head, and the mean values of the evaluation metrics will be displayed.

In [None]:
metrics_mouse = evaluate_model(model,
                               dataset=get_dataset('mouse', 'valid').batch(1).prefetch(2),
                               head='mouse',
                               max_steps=100)
print('')
print({k: v.numpy().mean() for k, v in metrics_mouse.items()})

101it [00:21,  6.54it/s]


{'PearsonR': 0.005183698}


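Re-running the `human` evaluation cell gives the following result. Both mean correlations (about 0.005 for mouse and 0.003 for human) are close to zero, which is consistent with the model having been trained for only 100 steps while the learning rate was still at 2% of its target value.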

In [None]:
metrics_human = evaluate_model(model,
                               dataset=get_dataset('human', 'valid').batch(1).prefetch(2),
                               head='human',
                               max_steps=100)
print('')
print({k: v.numpy().mean() for k, v in metrics_human.items()})

101it [00:23,  6.27it/s]


{'PearsonR': 0.0028573992}
