In [1]:
# Use the pre-trained Enformer model to compute pooled and base-level contribution scores for the CD69 locus.
# Code taken/adapted from the Enformer authors: https://github.com/deepmind/deepmind-research/blob/master/enformer/enformer-usage.ipynb
# Corresponds to the bottom track of Figure 1B.

In [2]:
import tensorflow as tf
import tensorflow_hub as hub
import joblib
import gzip
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from deeplift import dinuc_shuffle

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [3]:
# TF-Hub handle that is passed to Enformer(...) when the model is instantiated.
model_path = 'https://tfhub.dev/deepmind/enformer/1'
# Reference genome that FastaStringExtractor opens through pyfaidx.
fasta_file = '../../../reference_files/hg38.fa' ## must point to a valid GRCh38/hg38 FASTA
# Disabled helper for indexing a FASTA with samtools; the path is specific to the original machine.
#!samtools faidx /home/jupyter/reference/align_refs/GRCh38/Homo_sapiens.GRCh38.dna.primary_assembly.chr.fa

In [4]:
# Download the human target-track table from the Basenji2 dataset.
# Cite: Kelley et al. Cross-species regulatory sequence activity prediction. PLoS Comput. Biol. 16, e1008050 (2020).
targets_txt = 'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_human.txt'
df_targets = pd.read_csv(targets_txt, sep='\t')
# Show the tracks whose description mentions Jurkat; the CAGE row (index 4831) is the
# track index passed to importance_scores further down.
df_targets[df_targets['description'].str.contains("Jurkat")]

Unnamed: 0,index,genome,identifier,file,clip,scale,sum_stat,description
120,120,0,ENCFF916JRN,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:Jurkat clone E61
1208,1208,0,ENCFF017AFO,/home/drk/tillage/datasets/human/chip/encode/E...,32,2,mean,CHIP:H3K4me3:Jurkat clone E61
4831,4831,0,CNhs11253,/home/drk/tillage/datasets/human/cage/fantom/C...,384,1,sum,CAGE:acute lymphoblastic leukemia (T-ALL) cell...


In [5]:
# Window length (bp) to which importance_scores resizes the target interval before
# one-hot encoding the sequence and running the model.
SEQUENCE_LENGTH = 393216

class Enformer:
    """Wrapper around the TF-Hub Enformer module for prediction and attribution.

    Exposes batch prediction as NumPy arrays and a gradient-times-input
    contribution score for a masked subset of the model output.
    """

    def __init__(self, tfhub_url):
        """Load the hub module at `tfhub_url` and keep its `.model` attribute."""
        self._model = hub.load(tfhub_url).model

    def predict_on_batch(self, inputs):
        """Run the model on `inputs` and return a dict of NumPy arrays.

        The keys are whatever the underlying model returns; each value is
        converted with `.numpy()`.
        """
        predictions = self._model.predict_on_batch(inputs)
        return {k: v.numpy() for k, v in predictions.items()}

    @tf.function
    def contribution_input_grad(self, input_sequence,
                                target_mask, track_index=None,
                                output_head='human'):
        """Compute gradient-times-input contribution scores for one sequence.

        Args:
            input_sequence: a single un-batched sequence; a leading batch axis
                is added here before it is fed to the model.
            target_mask: array matching one un-batched prediction of
                `output_head`; non-zero entries select (and weight) the output
                values whose mean is differentiated. Track selection is done
                entirely through this mask.
            track_index: unused. Kept, now with a default, so that existing
                calls passing it positionally or by keyword keep working.
            output_head: key of the model output dict to attribute.

        Returns:
            A tuple `(scores, grad)`: `scores` is gradient * input summed over
            the last axis with the batch axis removed; `grad` is the raw
            gradient with respect to the batched input (batch axis kept).
        """
        input_sequence = input_sequence[tf.newaxis]

        target_mask_mass = tf.reduce_sum(target_mask)
        with tf.GradientTape() as tape:
            tape.watch(input_sequence)
            # The previous code sliced the batch axis with `[::track_index]`.
            # For a batch of one that is a no-op for every non-zero step, but a
            # zero step is rejected by TensorFlow, so track 0 could never be
            # attributed. The mask already isolates the track, so no slicing.
            pred = self._model.predict_on_batch(input_sequence)[output_head]
            prediction = tf.reduce_sum(
                target_mask[tf.newaxis] * pred) / target_mask_mass
        grad = tape.gradient(prediction, input_sequence)
        input_grad = tf.squeeze(grad * input_sequence, axis=0)

        return tf.reduce_sum(input_grad, axis=-1), grad

    def vars_return(self):
        """Return the `signatures` attribute of the loaded model."""
        return self._model.signatures

# @title `variant_centered_sequences`
#with strategy.scope():
class FastaStringExtractor:
    """Pull upper-case sequence strings out of a FASTA file via pyfaidx.

    Intervals that run past either end of a chromosome are clipped for the
    lookup and the missing flank is filled with 'N'.
    """

    def __init__(self, fasta_file):
        """Open `fasta_file` and record the length of every sequence in it."""
        self.fasta = pyfaidx.Fasta(fasta_file)
        self._chromosome_sizes = {
            name: len(record) for name, record in self.fasta.items()
        }

    def extract(self, interval: Interval, **kwargs) -> str:
        """Return the sequence for `interval`, N-padded where it leaves the chromosome."""
        chrom_len = self._chromosome_sizes[interval.chrom]

        # Clip the request to the part that actually exists on the chromosome.
        clipped = Interval(interval.chrom,
                           max(interval.start, 0),
                           min(interval.end, chrom_len))

        # pyfaidx uses 1-based coordinates, hence the +1 on the start.
        record = self.fasta.get_seq(clipped.chrom, clipped.start + 1, clipped.stop)
        bases = str(record.seq).upper()

        # Restore whatever was clipped off as runs of 'N'.
        left_pad = 'N' * max(-interval.start, 0)
        right_pad = 'N' * max(interval.end - chrom_len, 0)
        return left_pad + bases + right_pad

    def close(self):
        """Close the underlying pyfaidx handle."""
        return self.fasta.close()
    
def one_hot_encode(sequence):
    """One-hot encode a DNA string with kipoiseq and cast the result to float32."""
    encoded = kipoiseq.transforms.functional.one_hot_dna(sequence)
    return encoded.astype(np.float32)


def importance_scores(chrom, start, stop, target_index, mask_indices, rng=None):
    """Compute gradient-times-input contribution scores around a target interval.

    The interval is resized to SEQUENCE_LENGTH, its sequence is one-hot
    encoded, and contributions are computed for output track `target_index`
    averaged over the output rows listed in `mask_indices`. The same scores
    are computed for one dinucleotide-shuffled copy of the sequence.

    Relies on the module-level `model` and `fasta_extractor` objects, which
    must exist before this function is called.

    Args:
        chrom: chromosome name understood by the FASTA file.
        start, stop: interval bounds; converted with int().
        target_index: column of the 'human' prediction to attribute.
        mask_indices: iterable of row indices of the 'human' prediction
            that are included in the attributed mean.
        rng: optional random state forwarded to deeplift's dinuc_shuffle so
            the shuffled control is reproducible. When None (the default) the
            call is made exactly as before, without an `rng` argument.

    Returns:
        A 7-tuple: (resized_interval, contribution_scores,
        pooled_contribution_scores, base_scores, grad, sequence_one_hot,
        base_scores_scram), where `grad` has its size-1 axes squeezed out and
        the `base_scores*` arrays are the one-hot matrix scaled row-wise by
        the per-position contribution score.
    """
    target_interval = kipoiseq.Interval(chrom, int(start), int(stop))
    resized_interval = target_interval.resize(SEQUENCE_LENGTH)
    sequence_one_hot = one_hot_encode(fasta_extractor.extract(resized_interval))

    # The forward pass is only used to obtain the shape/dtype of the mask.
    predictions = model.predict_on_batch(sequence_one_hot[np.newaxis])['human'][0]
    target_mask = np.zeros_like(predictions)
    for idx in mask_indices:
        target_mask[idx, target_index] = 1

    # This will take some time since tf.function needs to get compiled.
    contribution_scores, grad = model.contribution_input_grad(
        sequence_one_hot.astype(np.float32), target_mask, target_index)
    contribution_scores = contribution_scores.numpy()

    # Mean absolute contribution in non-overlapping windows of 128 positions.
    pooled_contribution_scores = tf.nn.avg_pool1d(
        np.abs(contribution_scores)[np.newaxis, :, np.newaxis],
        128, 128, 'VALID')[0, :, 0].numpy()

    # Broadcast the per-position score across the one-hot columns.
    base_scores = sequence_one_hot * contribution_scores[:, np.newaxis]

    # Dinucleotide-shuffled control; the mask is unchanged, so it is reused.
    shuffle_kwargs = {} if rng is None else {'rng': rng}
    seq_shuffled = dinuc_shuffle.dinuc_shuffle(sequence_one_hot, 1, **shuffle_kwargs)[0]
    contribution_scores_scram, _ = model.contribution_input_grad(
        seq_shuffled, target_mask, target_index)
    contribution_scores_scram = contribution_scores_scram.numpy()
    base_scores_scram = seq_shuffled * contribution_scores_scram[:, np.newaxis]

    return (resized_interval, contribution_scores, pooled_contribution_scores,
            base_scores, np.squeeze(grad), sequence_one_hot, base_scores_scram)

    
def write_out_bedgraph_pooled(pooled_contribution_scores, interval, filename_base, bin_size=128):
    """Write pooled scores as a bedGraph file, one fixed-width bin per line.

    Args:
        pooled_contribution_scores: iterable of per-bin values, in order.
        interval: object with `chrom` and `start` attributes; bin k covers
            [start + k * bin_size, start + (k + 1) * bin_size).
        filename_base: output path without suffix; '.pooled.bedGraph' is
            appended here, so do not include it in the argument.
        bin_size: width of each bin in bp (default 128).
    """
    chrom = str(interval.chrom)
    start = interval.start

    # The context manager closes the file even if a write fails part-way.
    with open(filename_base + '.pooled.bedGraph', 'w') as out_file:
        for k, value in enumerate(pooled_contribution_scores):
            bin_start = start + k * bin_size
            fields = [chrom, str(bin_start), str(bin_start + bin_size), str(value)]
            out_file.write('\t'.join(fields) + '\n')
    
def write_out_bedgraph_all(contribution_scores, interval, filename_base):
    """Write per-position scores as a bedGraph file, one 1-bp record per line.

    Args:
        contribution_scores: iterable of per-position values, in order.
        interval: object with `chrom` and `start` attributes; value k is
            written for [start + k, start + k + 1).
        filename_base: output path without suffix; '.all.bedGraph' is
            appended here, so do not include it in the argument.
    """
    chrom = str(interval.chrom)
    start = interval.start

    # The context manager closes the file even if a write fails part-way.
    with open(filename_base + '.all.bedGraph', 'w') as out_file:
        for k, value in enumerate(contribution_scores):
            pos = start + k
            fields = [chrom, str(pos), str(pos + 1), str(value)]
            out_file.write('\t'.join(fields) + '\n')


In [6]:
# Load the Enformer model from the TF-Hub handle configured above.
# importance_scores reads `model` as a module-level global.
model = Enformer(model_path)
# Open the reference FASTA; importance_scores reads `fasta_extractor` as a module-level global.
fasta_extractor = FastaStringExtractor(fasta_file)

In [7]:
# @title Compute contribution scores
# Target interval is chr12:9760820-9760903; importance_scores resizes it to
# SEQUENCE_LENGTH around its centre before running the model.
# Output rows of the prediction that are averaged for the attribution target
# (presumably the bins around the centre of the window; confirm against the model's output layout).
mask_indices = [446, 447, 448, 449, 450]

# 4831 is the CAGE track listed in the Jurkat rows of the targets table.
out = importance_scores("chr12", 9760820, 9760903, 4831, mask_indices)
resized_int, scores, pooled, base_scores, grad, seq_one_hot, base_scores_shuff = out

# The writers append '.pooled.bedGraph' / '.all.bedGraph' themselves, so only the
# base name is passed. Passing the full file name produced doubled suffixes
# such as 'CD69.pooled.bedGraph.pooled.bedGraph'.
write_out_bedgraph_pooled(pooled, resized_int, 'CD69')
write_out_bedgraph_all(scores, resized_int, 'CD69')