# Figures for paper

In [None]:
import functools
import glob
import json
import os
import pickle
import posixpath

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
import tensorflow_hub as hub

from s_enformer.utils.modelling import get_shape_list

## Figure 6
To create the `evaluation/correlation_results.pkl` file, first run `evaluation/measure_correlations.py`. Use a GPU with 24GB of memory.

In [None]:
# Load the per-track Pearson correlations written by
# evaluation/measure_correlations.py.
# NOTE: pickle can execute arbitrary code while loading; only open result
# files that you generated yourself.
with open('evaluation/correlation_results.pkl', 'rb') as f:
    corr = pickle.load(f)

# Set the axes defaults *before* the axes exist: matplotlib reads
# `axes.labelsize` when an axis is created, so calling plt.rc() after
# plt.subplots() would leave this figure's x/y labels at the previous size.
plt.rc('axes', titlesize=14, labelsize=14)

fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(14, 9), sharey=True, sharex=True)
plt.subplots_adjust(hspace=0.2)

# One panel per genomic track type, in the order the keys appear in `corr`.
# The titles below are assumed to follow that same order -- TODO confirm
# against evaluation/measure_correlations.py.
genomic_track_types = list(corr)
genomic_track_types_titles = ['DNase & ATAC', 'Histone ChIP', 'TF ChIP', 'CAGE']

n_cols = ax.shape[1]
for i, panel in enumerate(ax.flat):  # ax.flat walks the grid row by row
    row, col = divmod(i, n_cols)

    # The 'bigbird' entry is the one plotted and labelled as S-Enformer.
    enformer_scores = np.array(corr[genomic_track_types[i]]['enformer'])
    s_enformer_scores = np.array(corr[genomic_track_types[i]]['bigbird'])

    panel.scatter(enformer_scores, s_enformer_scores, alpha=0.2)

    # Agreement between the two models' per-track scores.
    correlation = round(np.corrcoef(enformer_scores, s_enformer_scores)[0][1], 3)
    # Percentage of tracks on which S-Enformer is at least as good as Enformer.
    outperformance = round(np.mean(s_enformer_scores >= enformer_scores) * 100, 1)

    panel.annotate(f"S-Enformer outperforms: {outperformance}%", (0, 0.9), size=12)
    panel.set_title(f"{genomic_track_types_titles[i]}: r = {correlation}", fontsize=16)

    # Identity line: points above it are tracks where S-Enformer scores higher.
    panel.axline((0, 0), (1, 1), linewidth=1)

    # Axes are shared, so only the bottom row and left column carry labels.
    if row == 1:
        panel.set_xlabel('Pearson correlation (Enformer)')
    if col == 0:
        panel.set_ylabel('Pearson correlation (S-Enformer)')
        panel.tick_params(axis='y', labelsize=12)

## Figure 7

In [None]:
def _select_tracks(values):
    """Pick one example track per genomic track type from ``values``.

    Args:
        values: 2-D array-like indexed as ``values[:, track_index]``
            (presumably positions x tracks -- confirm against the dataset).

    Returns:
        Dict mapping a track title to the selected column. The CAGE track is
        returned as ``log10(1 + value)``; the other three are returned as-is.
    """
    return {
        'DNASE:CD14-positive monocyte female': values[:, 41],
        'CHIP:CTCF:MCF-7': values[:, 684],
        'CHIP:H3K27ac:keratinocyte female': values[:, 736],
        'CAGE:Keratinocyte - epidermal': np.log10(1 + values[:, 4799]),
    }


def get_track_results():
    """Collect measured and predicted example tracks for the first test sequence.

    Loads Enformer from TF Hub and S-Enformer from ``models/s_enformer``, runs
    both on the first example of the human test set and extracts the same four
    tracks from the experimental targets and from each model's prediction.

    Returns:
        Tuple ``(tracks_experiment, tracks_enformer, tracks_s_enformer)`` of
        dicts keyed by track title.

    Raises:
        RuntimeError: If the human test dataset yields no examples.
    """
    # Load the models
    enformer = hub.load("https://tfhub.dev/deepmind/enformer/1").model
    s_enformer = tf.saved_model.load("models/s_enformer")

    # Use the first testing sequence for the plot
    dataset = get_dataset('human', 'test').batch(1).prefetch(1)
    example = next(iter(dataset), None)
    if example is None:
        raise RuntimeError("The human test dataset did not yield any example")

    # Each model gets the sequence padded to its own input length;
    # S-Enformer is fed half as many positions as Enformer.
    enformer_input_length = 393216
    s_enformer_input_length = 393216 // 2

    # Make the predictions with the two models ([0] drops the batch axis)
    prediction_enformer = predict(example['sequence'], enformer, enformer_input_length)[0]
    prediction_s_enformer = predict(example['sequence'], s_enformer, s_enformer_input_length)[0]

    # Get the results from the experiment (i.e. the "correct answers")
    experiment_results = example['target'].numpy()[0]

    # Get the results from a track from each of the four genomic track types
    tracks_experiment = _select_tracks(experiment_results)
    tracks_enformer = _select_tracks(prediction_enformer)
    tracks_s_enformer = _select_tracks(prediction_s_enformer)

    return tracks_experiment, tracks_enformer, tracks_s_enformer


def get_dataset(organism, subset, num_threads=8):
    """Build a ``tf.data`` pipeline of deserialized examples.

    Args:
        organism: Name passed on to ``get_metadata`` and ``tfrecord_files``.
        subset: Subset name used to select the TFRecord shards.
        num_threads: Parallelism used both for reading the shards and for
            deserializing the records.

    Returns:
        A dataset whose elements are the dicts produced by ``deserialize``.
    """
    metadata = get_metadata(organism)
    shard_paths = tfrecord_files(organism, subset)

    # The shards are read as ZLIB-compressed TFRecord files.
    records = tf.data.TFRecordDataset(
        shard_paths,
        compression_type='ZLIB',
        num_parallel_reads=num_threads,
    )

    parse_record = functools.partial(deserialize, metadata=metadata)
    return records.map(parse_record, num_parallel_calls=num_threads)


def get_metadata(organism):
    """Load the ``statistics.json`` metadata for ``organism``.

    Args:
        organism: Name appended to the dataset root by ``organism_path``.

    Returns:
        The parsed JSON content. According to the original author's note the
        keys are: num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
        pool_width, crop_bp, target_length.
    """
    # posixpath keeps '/' separators on every OS; os.path.join would put a
    # backslash into the gs:// URI on Windows.
    path = posixpath.join(organism_path(organism), 'statistics.json')
    with tf.io.gfile.GFile(path, 'r') as f:
        return json.load(f)


def organism_path(organism):
    """Return the ``gs://`` URI of the directory holding ``organism``'s data.

    The URI is joined with ``posixpath`` rather than ``os.path`` so the result
    always uses forward slashes, regardless of the operating system.
    """
    return posixpath.join('gs://basenji_barnyard/data', organism)


def tfrecord_files(organism, subset):
    """List the TFRecord shards of ``subset`` for ``organism`` in numeric order.

    Args:
        organism: Name appended to the dataset root by ``organism_path``.
        subset: Prefix of the shard file names (``<subset>-<index>.tfr``).

    Returns:
        Matching paths sorted by their integer shard index.
    """
    # posixpath keeps '/' separators on every OS; os.path.join would put
    # backslashes into the gs:// URI on Windows.
    pattern = posixpath.join(organism_path(organism), 'tfrecords', f'{subset}-*.tfr')

    def shard_index(path):
        # '.../test-12.tfr' -> 12. Sorting by the integer keeps shard 2 ahead
        # of shard 10, which a plain string sort would not.
        return int(path.split('-')[-1].split('.')[0])

    return sorted(tf.io.gfile.glob(pattern), key=shard_index)


def deserialize(serialized_example, metadata):
    """Decode one serialized TFRecord example into float32 tensors.

    Args:
        serialized_example: Serialized example holding the byte-string
            features 'sequence' and 'target'.
        metadata: Mapping providing 'seq_length', 'target_length' and
            'num_targets', which define the output shapes.

    Returns:
        Dict with 'sequence' of shape (seq_length, 4) and 'target' of shape
        (target_length, num_targets), both cast to float32.
    """
    # Both features are stored as a single raw byte string.
    feature_spec = {
        name: tf.io.FixedLenFeature([], tf.string)
        for name in ('sequence', 'target')
    }
    parsed = tf.io.parse_example(serialized_example, feature_spec)

    # The sequence bytes are decoded as booleans, 4 values per position
    # (presumably one channel per nucleotide -- confirm against the dataset).
    sequence_bits = tf.io.decode_raw(parsed['sequence'], tf.bool)
    sequence_bits = tf.reshape(sequence_bits, (metadata['seq_length'], 4))

    # The targets are stored in half precision.
    target_values = tf.io.decode_raw(parsed['target'], tf.float16)
    target_shape = (metadata['target_length'], metadata['num_targets'])
    target_values = tf.reshape(target_values, target_shape)

    return {
        'sequence': tf.cast(sequence_bits, tf.float32),
        'target': tf.cast(target_values, tf.float32),
    }


@tf.function
def predict(x, model, sequence_length):
    """Pad ``x`` along axis 1 to ``sequence_length`` and run ``model`` on it.

    Args:
        x: Rank-3 input; only its second axis is padded.
        model: Object exposing ``predict_on_batch`` that returns a mapping
            with a 'human' entry.
        sequence_length: Length of axis 1 after padding.

    Returns:
        The 'human' entry of the model output for the padded input.
    """
    current_length = get_shape_list(x)[1]

    # Split the missing positions between both ends; when the amount is odd
    # the extra position goes to the right.
    missing = sequence_length - current_length
    pad_left = int(missing // 2)
    pad_right = missing - pad_left

    # No padding on the first and last axes.
    paddings = tf.constant([[0, 0], [pad_left, pad_right], [0, 0]])
    padded = tf.pad(x, paddings, "CONSTANT")

    return model.predict_on_batch(padded)['human']


def plot_tracks(tracks1, tracks2, tracks3, height=1.5, y_limits=None):
    """Plot measured and predicted tracks as stacked filled curves.

    Every track gets a group of three panels: ``tracks1`` (green, labelled
    'Experiment'), ``tracks2`` (default colour, 'Enformer') and ``tracks3``
    (orange, 'S-Enformer'). Groups are separated by a blank band.

    Args:
        tracks1: Dict mapping a track title to its values; the titles are
            used as group titles.
        tracks2: Dict of values paired with ``tracks1`` by iteration order.
        tracks3: Dict of values paired with ``tracks1`` by iteration order.
        height: Unused; kept so existing calls keep working.
        y_limits: Upper y-axis limit for each group, in track order. Defaults
            to the limits used for the paper figure. Only as many groups are
            drawn as there are limits and tracks.

    Raises:
        ValueError: If there is nothing to draw.
    """
    if y_limits is None:
        y_limits = [0.7, 4.6, 9.9, 0.7]

    groups = list(zip(y_limits, tracks1.items(), tracks2.values(), tracks3.values()))
    if not groups:
        raise ValueError("plot_tracks needs at least one track and one y-limit")

    plt.figure(figsize=(18., 12.))
    plt.rc('ytick', labelsize=16)

    panels_per_group = 3
    num_rows = panels_per_group * len(groups)
    num_cols = 1

    row_height = 4
    space_height = 4

    def bands_above(row):
        # Number of blank separator bands above panel `row`.
        return row // panels_per_group

    grid = (row_height * num_rows + space_height * bands_above(num_rows - 1), num_cols)

    ax = [
        plt.subplot2grid(
            grid,
            (row_height * row + space_height * bands_above(row), 0),
            rowspan=row_height,
        )
        for row in range(num_rows)
    ]

    plt.subplots_adjust(bottom=.05, top=.95, hspace=.1)

    # NOTE(review): every title/label below passes an explicit fontsize, so
    # this call looks like it only changes the defaults for later figures.
    plt.rc('axes', titlesize=14, labelsize=13)

    for group, (y_limit, (title, y), y2, y3) in enumerate(groups):
        first = panels_per_group * group
        top, middle, bottom = ax[first:first + panels_per_group]

        p1 = top.fill_between(np.linspace(0, len(y), num=len(y)), y, color='green')
        p2 = middle.fill_between(np.linspace(0, len(y2), num=len(y2)), y2)
        p3 = bottom.fill_between(np.linspace(0, len(y3), num=len(y3)), y3, color='orange')

        # The three panels of a group share one y-range so they can be compared.
        for panel in (top, middle, bottom):
            panel.set_ylim([0, y_limit])

        top.set_title(title, fontsize=22)

    # despine() acts on every axes of the current figure, so once is enough.
    sns.despine(top=True, right=True, bottom=True)

    ax[0].legend(
        [p1, p2, p3],
        ['Experiment', 'Enformer', 'S-Enformer'],
        bbox_to_anchor=(0.64, 1.4, 0.5, 0.5),
        fontsize=18
    )

    # The x label and ticks go on the current axes, i.e. the last panel created.
    # Tick labels are the bin index times 128 (presumably bp per bin -- confirm).
    num_bins = len(y)
    plt.xlabel("Target sequence (bp)", fontsize=20)
    plt.xticks(ticks=np.arange(0, num_bins, 100),
               labels=np.arange(0, num_bins * 128, 128 * 100),
               fontsize=16)

In [None]:
# Run both models on the first human test sequence and collect the example
# tracks (loads the two models and reads the dataset, so this cell is slow).
tracks_experiment, tracks_enformer, tracks_s_enformer = get_track_results()

In [None]:
# Draw Figure 7 from the tracks computed in the previous cell.
plot_tracks(tracks_experiment, tracks_enformer, tracks_s_enformer)

## Figure 8
To create the `evaluation/receptive_field_results.pkl` file, first run `evaluation/receptive_field.py`. Use a GPU with 24GB of memory.

In [None]:
# Load the receptive-field results written by evaluation/receptive_field.py.
# NOTE: pickle can execute arbitrary code while loading; only open result
# files that you generated yourself.
with open('evaluation/receptive_field_results.pkl', 'rb') as f:
    data = pickle.load(f)

fig, ax = plt.subplots(ncols=3, nrows=3, figsize=(14, 9), sharey=True, sharex=True)

plt.rc('axes', titlesize=15, labelsize=14)
# The y-axis is shared between the panels, so setting the scale on the
# current axes switches all of them to a log scale.
plt.yscale('log')

# One panel per entry, in the order the keys appear in `data`.
indices = list(data)
n_cols = ax.shape[1]
for i, panel in enumerate(ax.flat):  # ax.flat walks the grid row by row
    row, col = divmod(i, n_cols)
    location = indices[i]
    enformer_effect = data[location]['enformer']

    # Label the curves here so the legend does not depend on plotting order.
    # The 'bigbird' entry is the one shown as S-Enformer.
    panel.plot(enformer_effect, alpha=0.6, label='Enformer')
    panel.plot(data[location]['bigbird'], alpha=0.6, label='S-Enformer')
    panel.set_title(f"\nMutation location: {location:,}")

    # Tick labels are the bin index times 128 (presumably bp per bin -- confirm).
    num_bins = len(enformer_effect)
    panel.set_xticks(np.arange(0, num_bins, 100))
    panel.set_xticklabels(
        np.arange(0, num_bins * 128, 128 * 100),
        rotation=45,
        fontsize=12
    )

    # Axes are shared, so only the bottom row and left column carry labels.
    if row == 2:
        panel.set_xlabel('Target sequence (bp)', fontsize=16)
    if col == 0:
        panel.set_ylabel('Effect size', fontsize=16)

# Put the legend to the right of the top-right panel.
ax[0, 2].legend(
        bbox_to_anchor=(1.05, 0.5, 0.5, 0.5),
        fontsize=14
    );