Copyright 2021 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

This colab showcases training of the Enformer model published in

**"Effective gene expression prediction from sequence by integrating long-range interactions"**

Žiga Avsec, Vikram Agarwal, Daniel Visentin, Joseph R. Ledsam, Agnieszka Grabska-Barwinska, Kyle R. Taylor, Yannis Assael, John Jumper, Pushmeet Kohli, David R. Kelley


## Steps

- Setup tf.data.Dataset by directly accessing the Basenji2 data on GCS: `gs://basenji_barnyard/data`
- Train the model for a few steps, alternating training on human and mouse data batches
- Evaluate the model on human and mouse genomes

## Setup

**Start the colab kernel with GPU**: Runtime -> Change runtime type -> GPU

### Install dependencies

In [5]:
!pip install dm-sonnet tqdm

You should consider upgrading via the '/home/arijit_sehanobish1/enformer_performer/env/bin/python -m pip install --upgrade pip' command.[0m


In [15]:
# Get enformer source code
!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/attention_module.py
!wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/enformer.py

### Import

In [1]:
import tensorflow as tf
# Make sure the GPU is enabled 
# assert tf.config.list_physical_devices('GPU'), 'Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU'

# # Easier debugging of OOM
# %env TF_ENABLE_GPU_GARBAGE_COLLECTION=false

In [2]:
import sonnet as snt
from tqdm import tqdm
from IPython.display import clear_output
import numpy as np
import pandas as pd
import time
import os

In [3]:
import sys
sys.path.insert(1,'/home/arijit_sehanobish1/enformer_performer/performer/tf_version')
import fast_attention_sonnet

In [3]:
assert snt.__version__.startswith('2.0')

In [4]:
tf.__version__

'2.4.1'

In [5]:
# GPU colab has T4 with 16 GiB of memory
!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



### Code

In [4]:
import enformer_performer

In [5]:
model = enformer_performer.Enformer(channels=1536 // 4,  # Use 4x fewer channels to train faster.
                          num_heads=4,
                          num_transformer_layers=4,
                          pooling_type='max')

In [6]:
model

Enformer(
    channels=384,
    num_transformer_layers=4,
    num_heads=4,
    pooling_type='max',
)

In [7]:
# @title `get_targets(organism)`
def get_targets(organism):
    targets_txt = f'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_{organism}.txt'
    return pd.read_csv(targets_txt, sep='\t')

In [8]:
# @title `get_dataset(organism, subset, num_threads=8)`
import glob
import json
import functools


def organism_path(organism):
    return os.path.join('gs://basenji_barnyard/data', organism)


def get_dataset(organism, subset, num_threads=8):
    metadata = get_metadata(organism)
    dataset = tf.data.TFRecordDataset(tfrecord_files(organism, subset),
                                    compression_type='ZLIB',
                                    num_parallel_reads=num_threads)
    dataset = dataset.map(functools.partial(deserialize, metadata=metadata),
                        num_parallel_calls=num_threads)
    return dataset


def get_metadata(organism):
  # Keys:
  # num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
  # pool_width, crop_bp, target_length
    path = os.path.join(organism_path(organism), 'statistics.json')
    with tf.io.gfile.GFile(path, 'r') as f:
        return json.load(f)


def tfrecord_files(organism, subset):
  # Sort the values by int(*).
  return sorted(tf.io.gfile.glob(os.path.join(
      organism_path(organism), 'tfrecords', f'{subset}-*.tfr'
  )), key=lambda x: int(x.split('-')[-1].split('.')[0]))


def deserialize(serialized_example, metadata):
    """Deserialize bytes stored in TFRecordFile."""
    feature_map = {
      'sequence': tf.io.FixedLenFeature([], tf.string),
      'target': tf.io.FixedLenFeature([], tf.string),
  }
    example = tf.io.parse_example(serialized_example, feature_map)
    sequence = tf.io.decode_raw(example['sequence'], tf.bool)
    sequence = tf.reshape(sequence, (metadata['seq_length'], 4))
    sequence = tf.cast(sequence, tf.float32)

    target = tf.io.decode_raw(example['target'], tf.float16)
    target = tf.reshape(target,
                      (metadata['target_length'], metadata['num_targets']))
    target = tf.cast(target, tf.float32)

    return {'sequence': sequence,
          'target': target}


## Load dataset

In [9]:
df_targets_human = get_targets('human')
df_targets_human.head()

Unnamed: 0,index,genome,identifier,file,clip,scale,sum_stat,description
0,0,0,ENCFF833POA,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:cerebellum male adult (27 years) and mal...
1,1,0,ENCFF110QGM,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:frontal cortex male adult (27 years) and...
2,2,0,ENCFF880MKD,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:chorion
3,3,0,ENCFF463ZLQ,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:Ishikawa treated with 0.02% dimethyl sul...
4,4,0,ENCFF890OGQ,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:GM03348


In [9]:
len(df_targets_human)

5313

In [10]:
df_targets_mouse = get_targets('mouse')

In [11]:
len(df_targets_mouse)

1643

In [11]:
human_dataset = get_dataset('human', 'train').batch(1).repeat()
mouse_dataset = get_dataset('mouse', 'train').batch(1).repeat()
human_mouse_dataset = tf.data.Dataset.zip((human_dataset, mouse_dataset)).prefetch(2)

In [13]:
human_data_all = get_dataset('human', 'train')

In [14]:
human_data_all

<ParallelMapDataset shapes: {sequence: (131072, 4), target: (896, 5313)}, types: {sequence: tf.float32, target: tf.float32}>

In [15]:
human_dataset

<RepeatDataset shapes: {sequence: (None, 131072, 4), target: (None, 896, 5313)}, types: {sequence: tf.float32, target: tf.float32}>

In [16]:
it = iter(mouse_dataset)
example = next(it)

In [12]:
# Example input
it = iter(human_mouse_dataset)
example = next(it)
for i in range(len(example)):
  print(['human', 'mouse'][i])
  print({k: (v.shape, v.dtype) for k,v in example[i].items()})

human
{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 5313]), tf.float32)}
mouse
{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 1643]), tf.float32)}


## Model training

In [13]:
def create_step_function(model, optimizer):

  @tf.function
  def train_step(batch, head, optimizer_clip_norm_global=0.2):
    with tf.GradientTape() as tape:
        outputs = model(batch['sequence'], is_training=True)[head]
        loss = tf.reduce_mean(
          tf.keras.losses.poisson(batch['target'], outputs))

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply(gradients, model.trainable_variables)

    return loss
  return train_step

In [14]:
learning_rate = tf.Variable(0., trainable=False, name='learning_rate')
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
num_warmup_steps = 5000
target_learning_rate = 0.0005

# model = enformer_performer.Enformer(channels=1536 // 4,  # Use 4x fewer channels to train faster.
#                           num_heads=4,
#                           num_transformer_layers=4,
#                           pooling_type='max')

train_step = create_step_function(model, optimizer)

In [15]:
# Train the model
steps_per_epoch = 20
num_epochs = 5

data_it = iter(human_mouse_dataset)
global_step = 0
for epoch_i in range(num_epochs):
    for i in tqdm(range(steps_per_epoch)):
        global_step += 1

    if global_step > 1:
        learning_rate_frac = tf.math.minimum(
          1.0, global_step / tf.math.maximum(1.0, num_warmup_steps))      
        learning_rate.assign(target_learning_rate * learning_rate_frac)

    batch_human, batch_mouse = next(data_it)

    loss_human = train_step(batch=batch_human, head='human')
    loss_mouse = train_step(batch=batch_mouse, head='mouse')

  # End of epoch.
    print('')
    print('loss_human', loss_human.numpy(),
        'loss_mouse', loss_mouse.numpy(),
        'learning_rate', optimizer.learning_rate.numpy()
        )

100%|██████████| 20/20 [00:00<00:00, 222509.50it/s]


TypeError: in user code:

    <ipython-input-13-c3553c570c78>:6 train_step  *
        outputs = model(batch['sequence'], is_training=True)[head]
    /opt/conda/lib/python3.7/site-packages/sonnet/src/utils.py:151 _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope  *
        return method(*args, **kwargs)
    /home/arijit_sehanobish1/enformer_performer/enformer_performer.py:150 __call__  *
        trunk_embedding = self.trunk(inputs, is_training=is_training)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/utils.py:151 _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope  *
        return method(*args, **kwargs)
    /home/arijit_sehanobish1/enformer_performer/enformer_performer.py:201 __call__  *
        outputs = mod(outputs, **kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/utils.py:151 _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope  *
        return method(*args, **kwargs)
    /home/arijit_sehanobish1/enformer_performer/enformer_performer.py:201 __call__  *
        outputs = mod(outputs, **kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/utils.py:151 _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope  *
        return method(*args, **kwargs)
    /home/arijit_sehanobish1/enformer_performer/enformer_performer.py:201 __call__  *
        outputs = mod(outputs, **kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/utils.py:151 _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope  *
        return method(*args, **kwargs)
    /home/arijit_sehanobish1/enformer_performer/enformer_performer.py:268 __call__  *
        return inputs + self._module(inputs, is_training, *args, **kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/utils.py:151 _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope  *
        return method(*args, **kwargs)
    /home/arijit_sehanobish1/enformer_performer/enformer_performer.py:201 __call__  *
        outputs = mod(outputs, **kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/utils.py:151 _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    /opt/conda/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope  *
        return method(*args, **kwargs)

    TypeError: tf____call__() missing 1 required positional argument: 'training'


## Evaluate

In [60]:
# @title `PearsonR` and `R2` metrics

def _reduced_shape(shape, axis):
  if axis is None:
    return tf.TensorShape([])
  return tf.TensorShape([d for i, d in enumerate(shape) if i not in axis])


class CorrelationStats(tf.keras.metrics.Metric):
  """Contains shared code for PearsonR and R2."""

  def __init__(self, reduce_axis=None, name='pearsonr'):
    """Pearson correlation coefficient.

    Args:
      reduce_axis: Specifies over which axis to compute the correlation (say
        (0, 1). If not specified, it will compute the correlation across the
        whole tensor.
      name: Metric name.
    """
    super(CorrelationStats, self).__init__(name=name)
    self._reduce_axis = reduce_axis
    self._shape = None  # Specified in _initialize.

  def _initialize(self, input_shape):
    # Remaining dimensions after reducing over self._reduce_axis.
    self._shape = _reduced_shape(input_shape, self._reduce_axis)

    weight_kwargs = dict(shape=self._shape, initializer='zeros')
    self._count = self.add_weight(name='count', **weight_kwargs)
    self._product_sum = self.add_weight(name='product_sum', **weight_kwargs)
    self._true_sum = self.add_weight(name='true_sum', **weight_kwargs)
    self._true_squared_sum = self.add_weight(name='true_squared_sum',
                                             **weight_kwargs)
    self._pred_sum = self.add_weight(name='pred_sum', **weight_kwargs)
    self._pred_squared_sum = self.add_weight(name='pred_squared_sum',
                                             **weight_kwargs)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Update the metric state.

    Args:
      y_true: Multi-dimensional float tensor [batch, ...] containing the ground
        truth values.
      y_pred: float tensor with the same shape as y_true containing predicted
        values.
      sample_weight: 1D tensor aligned with y_true batch dimension specifying
        the weight of individual observations.
    """
    if self._shape is None:
      # Explicit initialization check.
      self._initialize(y_true.shape)
    y_true.shape.assert_is_compatible_with(y_pred.shape)
    y_true = tf.cast(y_true, 'float32')
    y_pred = tf.cast(y_pred, 'float32')

    self._product_sum.assign_add(
        tf.reduce_sum(y_true * y_pred, axis=self._reduce_axis))

    self._true_sum.assign_add(
        tf.reduce_sum(y_true, axis=self._reduce_axis))

    self._true_squared_sum.assign_add(
        tf.reduce_sum(tf.math.square(y_true), axis=self._reduce_axis))

    self._pred_sum.assign_add(
        tf.reduce_sum(y_pred, axis=self._reduce_axis))

    self._pred_squared_sum.assign_add(
        tf.reduce_sum(tf.math.square(y_pred), axis=self._reduce_axis))

    self._count.assign_add(
        tf.reduce_sum(tf.ones_like(y_true), axis=self._reduce_axis))

  def result(self):
    raise NotImplementedError('Must be implemented in subclasses.')

  def reset_states(self):
    if self._shape is not None:
      tf.keras.backend.batch_set_value([(v, np.zeros(self._shape))
                                        for v in self.variables])


class PearsonR(CorrelationStats):
  """Pearson correlation coefficient.

  Computed as:
  ((x - x_avg) * (y - y_avg) / sqrt(Var[x] * Var[y])
  """

  def __init__(self, reduce_axis=(0,), name='pearsonr'):
    """Pearson correlation coefficient.

    Args:
      reduce_axis: Specifies over which axis to compute the correlation.
      name: Metric name.
    """
    super(PearsonR, self).__init__(reduce_axis=reduce_axis,
                                   name=name)

  def result(self):
    true_mean = self._true_sum / self._count
    pred_mean = self._pred_sum / self._count

    covariance = (self._product_sum
                  - true_mean * self._pred_sum
                  - pred_mean * self._true_sum
                  + self._count * true_mean * pred_mean)

    true_var = self._true_squared_sum - self._count * tf.math.square(true_mean)
    pred_var = self._pred_squared_sum - self._count * tf.math.square(pred_mean)
    tp_var = tf.math.sqrt(true_var) * tf.math.sqrt(pred_var)
    correlation = covariance / tp_var

    return correlation


class R2(CorrelationStats):
  """R-squared  (fraction of explained variance)."""

  def __init__(self, reduce_axis=None, name='R2'):
    """R-squared metric.

    Args:
      reduce_axis: Specifies over which axis to compute the correlation.
      name: Metric name.
    """
    super(R2, self).__init__(reduce_axis=reduce_axis,
                             name=name)

  def result(self):
    true_mean = self._true_sum / self._count
    total = self._true_squared_sum - self._count * tf.math.square(true_mean)
    residuals = (self._pred_squared_sum - 2 * self._product_sum
                 + self._true_squared_sum)

    return tf.ones_like(residuals) - residuals / total


class MetricDict:
  def __init__(self, metrics):
    self._metrics = metrics

  def update_state(self, y_true, y_pred):
    for k, metric in self._metrics.items():
      metric.update_state(y_true, y_pred)

  def result(self):
    return {k: metric.result() for k, metric in self._metrics.items()}

In [61]:
def evaluate_model(model, dataset, head, max_steps=None):
  metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0,1))})
  @tf.function
  def predict(x):
    return model(x, is_training=False)[head]

  for i, batch in tqdm(enumerate(dataset)):
    if max_steps is not None and i > max_steps:
      break
    metric.update_state(batch['target'], predict(batch['sequence']))

  return metric.result()

In [66]:
metrics_human = evaluate_model(model,
                               dataset=get_dataset('human', 'valid').batch(1).prefetch(2),
                               head='human',
                               max_steps=100)
print('')
print({k: v.numpy().mean() for k, v in metrics_human.items()})

101it [00:23,  6.27it/s]


{'PearsonR': 0.0028573992}


In [63]:
metrics_mouse = evaluate_model(model,
                               dataset=get_dataset('mouse', 'valid').batch(1).prefetch(2),
                               head='mouse',
                               max_steps=100)
print('')
print({k: v.numpy().mean() for k, v in metrics_mouse.items()})

101it [00:21,  6.54it/s]


{'PearsonR': 0.005183698}
