# Example: how to score sequences with Enformer

In [None]:
import pandas as pd
import numpy as np
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import gzip
import shutil
#import wget
import swifter
import os

from pyfastaq.sequences import file_reader as fasta_reader
import tensorflow as tf
import tensorflow_hub as hub
from os import listdir

In [None]:
import ssl # to be able to load model
# WARNING: the next line replaces the stdlib's default-HTTPS-context factory
# with one that skips certificate verification, for the rest of this process
# (presumably this is what lets the TF Hub model download succeed here).
# It works around a local certificate problem but leaves every such HTTPS
# download open to tampering. Prefer fixing the CA bundle (e.g. via the
# SSL_CERT_FILE environment variable) and removing this override.
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# Width (bp) of the sequence window fed to the model: 393_216 = 3 * 2**17.
SEQUENCE_LENGTH = 3 * 2**17
# Nucleotide alphabet used when drawing random barcodes.
bases = ["A", "C", "G", "T"]
model_path = "https://tfhub.dev/deepmind/enformer/1"

# The predicted central region is 896 bins of 128 bp each (114_688 bp).
BIN_SIZE = 128
PRED_SQ_SIZE = 896 * BIN_SIZE

In [None]:
# Helper functions for running Enformer and reading out its predictions

def extract_positions(prediction, experiment_ids, annotations, start_position, end_position,
                      full=False,
                      fixed_point=None,
                      sequence_length=None,
                      pred_length=None,
                      bin_size=None):
    """
    Extracts predicted values given the prediction, IDs of the required experiments, and the positions of interest.

    The model input window starts ``to_pad_left`` bp before ``fixed_point`` and is
    ``sequence_length`` bp wide. The predicted region is the central ``pred_length``
    bp of that window, divided into ``bin_size``-bp bins.

    :param prediction: model output with a leading batch axis; ``prediction[0]`` is
        indexed as (bin, experiment). Presumably one head of ``Enformer.predict_on_batch``
        for a batch of one -- TODO confirm.
    :param experiment_ids: column index (or list of indices) of the experiments to keep
    :param annotations: indexable object whose element ``[1]["to_pad_left"]`` is the
        number of bp between the window start and ``fixed_point``
    :param start_position: first coordinate of interest
    :param end_position: last coordinate of interest
    :param full: if True, return every bin instead of only those covering the positions
    :param fixed_point: anchor coordinate of the window; defaults to ``start_position``
    :param sequence_length: window width in bp; defaults to ``SEQUENCE_LENGTH``
    :param pred_length: predicted-region width in bp; defaults to ``PRED_SQ_SIZE``
    :param bin_size: bin width in bp; defaults to ``BIN_SIZE``
    :return: ``(values, start, end)``. With ``full=True``: all bins of the selected
        experiments and the predicted-region bounds. Otherwise: the bins covering
        [start_position, end_position] and the outer edges of those bins.
    :raises ValueError: if start_position or end_position lies outside the predicted region
    """
    if sequence_length is None:
        sequence_length = SEQUENCE_LENGTH
    if pred_length is None:
        pred_length = PRED_SQ_SIZE
    if bin_size is None:
        bin_size = BIN_SIZE
    if fixed_point is None:
        fixed_point = start_position
    to_pad_left = annotations[1]["to_pad_left"]

    start_view = fixed_point - to_pad_left
    end_view = start_view + sequence_length
    middle = (start_view + end_view) // 2
    # The prediction covers only the centre of the input window.
    pred_start = middle - pred_length // 2
    pred_end = middle + pred_length // 2

    if full:  # return full prediction, with start/end annotation
        return prediction[0][:, experiment_ids], pred_start, pred_end

    n_bins = pred_length // bin_size
    # Floor of the start and ceiling of the end, so the slice covers both positions.
    insert_start = item_position(start_position, pred_start, pred_end, n_bins)[0]
    insert_end = item_position(end_position, pred_start, pred_end, n_bins)[1]

    return (
        prediction[0][insert_start:insert_end, experiment_ids],
        insert_start * bin_size + pred_start,
        insert_end * bin_size + pred_start,
    )


def item_position(position, pred_start, pred_end, pred_size=896):
    """
    Locate ``position`` on the grid of bin edges of the predicted region.

    [pred_start, pred_end] is divided into ``pred_size`` equal bins, which have
    ``pred_size + 1`` edges at ``pred_start + i * (pred_end - pred_start) / pred_size``.

    :param position: coordinate to locate (same coordinate system as pred_start/pred_end)
    :param pred_start: left edge of the predicted region
    :param pred_end: right edge of the predicted region
    :param pred_size: number of bins in the region (default 896)
    :return: ``(lower, upper)``: the index of the last edge <= position and of the
        first edge >= position. The two are equal when position falls exactly on an edge.
    :raises ValueError: if position lies outside [pred_start, pred_end]
    """
    if not pred_start <= position <= pred_end:
        raise ValueError(
            f"position {position} lies outside the predicted region [{pred_start}, {pred_end}]"
        )
    # pred_size bins have pred_size + 1 edges. Using only pred_size points would
    # make the spacing (pred_end - pred_start) / (pred_size - 1), which does not
    # match the bin width the caller multiplies the returned index by.
    edges = np.linspace(pred_start, pred_end, pred_size + 1)
    lower = np.where(edges <= position)[0].max()
    upper = np.where(edges >= position)[0].min()
    return (int(lower), int(upper))



class FastaStringExtractor:
    """Fetch reference sequence for an interval from a FASTA file, filling off-chromosome positions with 'N'."""

    def __init__(self, fasta_file):
        self.fasta = pyfaidx.Fasta(fasta_file)
        # chromosome name -> length, used to clip requested intervals
        self._chromosome_sizes = {name: len(record) for name, record in self.fasta.items()}

    def extract(self, interval: Interval, **kwargs) -> str:
        """
        Return the upper-cased sequence of ``interval``.

        The part of the interval before position 0 or past the chromosome end is
        filled with 'N'. Extra keyword arguments are accepted and ignored.
        """
        chrom_len = self._chromosome_sizes[interval.chrom]
        clipped = Interval(interval.chrom,
                           max(interval.start, 0),
                           min(interval.end, chrom_len))
        # pyfaidx's get_seq uses 1-based coordinates, hence start + 1.
        record = self.fasta.get_seq(clipped.chrom, clipped.start + 1, clipped.stop)
        body = str(record.seq).upper()
        left_pad = 'N' * max(-interval.start, 0)
        right_pad = 'N' * max(interval.end - chrom_len, 0)
        return left_pad + body + right_pad

    def close(self):
        """Close the underlying pyfaidx handle."""
        return self.fasta.close()



class Enformer:
    """
    Thin wrapper around the Enformer model loaded from a TF Hub URL.
    """
    def __init__(self, tfhub_url):
        # The object used for prediction is the `.model` attribute of what hub.load() returns.
        self._model = hub.load(tfhub_url).model
    def predict_on_batch(self, inputs):
        """
        Run the model on a batch of inputs and return its outputs as NumPy arrays.

        :param inputs: batch of one-hot encoded sequences with a leading batch axis
                       (presumably (batch, length, 4) float32 -- TODO confirm)
        :return: dict mapping each output-head name (e.g. 'human') to a NumPy array
        """
        predictions = self._model.predict_on_batch(inputs)
        # Convert every output tensor to NumPy so callers can index it directly.
        return {k: v.numpy() for k, v in predictions.items()}
    @tf.function
    def contribution_input_grad(self, input_sequence,
                              target_mask, output_head='human'):
        """
        Gradient x input contribution scores for a masked mean of the model output.

        :param input_sequence: a single one-hot sequence without a batch axis (one is added here)
        :param target_mask: weights broadcastable against one example's output for
                            ``output_head`` (presumably (bins, tracks) -- TODO confirm)
        :param output_head: name of the model output to score, default 'human'
        :return: per-position scores: gradient x input summed over the last axis
        """
        input_sequence = input_sequence[tf.newaxis]  # add batch axis
        target_mask_mass = tf.reduce_sum(target_mask)
        with tf.GradientTape() as tape:
            # Watch the (non-Variable) input so its gradient is recorded.
            tape.watch(input_sequence)
            # Mask-weighted sum of the chosen output, normalised by the mask's total weight.
            prediction = tf.reduce_sum(
              target_mask[tf.newaxis] *
              self._model.predict_on_batch(input_sequence)[output_head]) / target_mask_mass
        # Gradient x input: only the channels actually present in the one-hot input contribute.
        input_grad = tape.gradient(prediction, input_sequence) * input_sequence
        input_grad = tf.squeeze(input_grad, axis=0)  # drop batch axis
        return tf.reduce_sum(input_grad, axis=-1)


def one_hot_encode(sequence):
    """Return the kipoiseq one-hot DNA encoding of ``sequence``, cast to float32."""
    encoded = kipoiseq.transforms.functional.one_hot_dna(sequence)
    return encoded.astype(np.float32)



def list_experiments(organism):
    """
    Returns a pandas DataFrame with all available experiments for prediction.

    Not implemented yet. It raises NotImplementedError instead of returning a
    placeholder object that a caller could mistake for a result.

    :param organism: "human" or "mouse"
    :return: pandas DataFrame with all available experiments for prediction
    :raises ValueError: if organism is not "human" or "mouse"
    :raises NotImplementedError: for a valid organism, until this is implemented
    """
    if organism not in ("human", "mouse"):
        raise ValueError(f'organism must be "human" or "mouse", got {organism!r}')
    # TODO: build the table of available experiments/tracks for `organism`.
    raise NotImplementedError("list_experiments is not implemented yet")





In [None]:
# Load the Enformer model from the TF Hub URL in `model_path`.
model = Enformer(tfhub_url=model_path)

In [None]:
# The input FASTA file contains, for each modified enhancer, 10 slightly shifted versions of the same sequence in its larger genomic context.

file = '/input_path/5p7_library_seqeunces_syn_enh_lib_12bp_for_enformer_left_align_200kb_homL_10_different_shifts_seed1_of128shiftbins_collapsed_2024_09_27.fasta'

df = pd.DataFrame(
    [(entry.id, entry.seq) for entry in fasta_reader(file)], columns=["id", "seq"]
)
df.head()

import random

random.seed(1)  # seeds Python's `random` only; the barcodes below use NumPy
# Barcode placeholders: depending on the library, barcodes are 10, 12 or 14 bp.
to_replace10 = "NNNNNNNNNN"
to_replace12 = "NNNNNNNNNNNN"
to_replace14 = "NNNNNNNNNNNNNN"

barcode_replicate_n = 3  # three random barcodes for each genotype and shift

# Barcodes are drawn with NumPy, so seed a NumPy generator. random.seed() does
# not affect NumPy's RNG, so the previous barcodes (and the "seed1" outputs)
# were not reproducible.
rng = np.random.default_rng(1)

dfs = []
for _ in range(barcode_replicate_n):
    bdf = df.copy()
    random_barcode = "".join(rng.choice(bases, size=len(to_replace12)))
    # Literal (non-regex) replacement of the placeholder.
    # NOTE(review): every occurrence is replaced, including any run of >= 12 N's
    # in the genomic context; confirm the context sequences contain none.
    bdf["seq"] = bdf["seq"].str.replace(to_replace12, random_barcode, regex=False)
    bdf["id"] = bdf["id"] + f"_bc_{random_barcode}"
    dfs.append(bdf)

fdfs = pd.concat(dfs)
fdfs.to_csv('outpath/Info_df_Enformerseqs_lib_5p7_3BCs_seed1_10postional_variations_seed1_2024_10_21.csv')


In [None]:
## Predict
# Track indices (columns of the 'human' output) written to the result file;
# all GM12878 according to the original note.
exp_ind = [12, 69, 5110, 688]

# Encode and predict one sequence at a time, so only one one-hot array is held
# in memory at once. Previously an encoding was built for every row of fdfs
# before any prediction ran.
pred = []
for seq in fdfs["seq"]:
    encoded = one_hot_encode(seq)[np.newaxis]  # add the batch axis
    human_out = model.predict_on_batch(encoded)["human"]
    pred.append(pd.DataFrame(human_out[0][:, exp_ind]))

# Rows are stacked in fdfs order; each sequence's block keeps its own 0-based bin index.
predk = pd.concat(pred)

predk.to_csv('outpath/Pred_Enf_dnase2x_h3k27ac_CAGE_12_69_5110_688_lib_5p7_3BCs_seed1_10postional_variations_seed1_2024_10_21.csv')
