In [None]:
import numpy as np
import tensorflow as tf
import os
import time
import tensorflow as tf
from dress.datasetgeneration.preprocessing.utils import tabular_file_to_genomics_df
from dress.datasetgeneration.preprocessing.gtf_cache import preprocessing
from dress.datasetgeneration.dataset import Dataset
import squid
from squid.predictor import BasePredictor, predict_in_batches
import pandas as pd
# Expose only GPU 0 to this process.
# NOTE(review): TensorFlow is already imported above; this setting is presumably only
# honoured if TF has not yet initialised CUDA -- consider moving it before `import tensorflow`.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

#### Auxiliary functions

In [None]:
class SpliceAIPredictor(BasePredictor):
    """Module for handling SpliceAI model predictions.

    Reduces the SpliceAI predictions at the two splice sites of a cassette exon
    (acceptor and donor) to one value per input sequence.

    Parameters
    ----------
    pred_fun : function
        Function returning model predictions for a batch of sequences.
    task_idx : list
        Positions where the splice sites of the cassette exon are located. Must contain
        exactly 2 elements: the acceptor position first, then the donor position.
    batch_size : int
        Number of sequences per prediction batch.
    reduce_fun : function
        Function reducing the acceptor and donor predictions; called as
        ``reduce_fun([acceptor, donor], axis=0)``.
    save_dir : str or None
        Stored as a class attribute of ``BasePredictor`` (see note in ``__init__``).
    save_window : optional
        Accepted for signature compatibility but not used by ``__init__``; the window
        used at prediction time is the one passed to ``__call__``.
    **kwargs
        Extra keyword arguments forwarded to ``predict_in_batches``.

    Raises
    ------
    ValueError
        If ``task_idx`` does not contain exactly two positions.
    """

    def __init__(self, pred_fun, task_idx, batch_size=64, reduce_fun=np.mean, save_dir=None, save_window=None, **kwargs):
        # Enforce the documented (acceptor, donor) contract up front instead of failing
        # later with an opaque indexing error inside __call__.
        if len(task_idx) != 2:
            raise ValueError(
                f"task_idx must contain exactly 2 positions (acceptor, donor), got {len(task_idx)}"
            )
        self.pred_fun = pred_fun
        self.task_idx = task_idx
        self.batch_size = batch_size
        self.reduce_fun = reduce_fun
        # NOTE(review): this sets a *class* attribute shared by every BasePredictor
        # (sub)class instance, not an instance attribute. Kept as-is in case squid reads
        # `BasePredictor.save_dir`; confirm before switching to `self.save_dir`.
        BasePredictor.save_dir = save_dir
        self.kwargs = kwargs

    def __call__(self, x, x_ref, save_window):
        """Predict one splicing score per sequence in ``x``.

        Parameters
        ----------
        x : array-like
            Batch of (mutated) one-hot sequences.
        x_ref : array-like
            Reference sequence, forwarded to ``predict_in_batches``.
        save_window : optional
            Forwarded to ``predict_in_batches``.

        Returns
        -------
        The result of ``reduce_fun`` applied to ``[acceptor, donor]`` over ``axis=0``;
        with the default ``np.mean`` this is one value per sequence in ``x``.
        """
        pred = predict_in_batches(x, x_ref, self.pred_fun, batch_size=self.batch_size, save_window=save_window, **self.kwargs)
        # Channel 1 at the acceptor position and channel 2 at the donor position.
        acceptor = pred[:, self.task_idx[0], 1]
        donor = pred[:, self.task_idx[1], 2]
        return self.reduce_fun([acceptor, donor], axis=0)
  

In [None]:
def predict_on_batch_spliceai():
    """Build a batched SpliceAI ensemble prediction function.

    Loads the five Keras models shipped with the ``spliceai`` package
    (``models/spliceai1.h5`` ... ``models/spliceai5.h5``) and wraps them in a
    ``tf.function`` that averages their outputs.

    Returns
    -------
    callable
        ``predict_batch(batch)`` taking an int32 tensor of shape ``[None, None, 4]``
        and returning the element-wise mean of the five models' outputs.
    """
    from spliceai.utils import load_model, resource_filename

    ensemble = []
    for model_number in range(1, 6):
        weights_path = resource_filename("spliceai", f"models/spliceai{model_number}.h5")
        ensemble.append(load_model(weights_path, compile=False))

    batch_spec = tf.TensorSpec(shape=[None, None, 4], dtype=tf.int32)

    @tf.function(input_signature=(batch_spec,))
    def predict_batch(batch: tf.Tensor) -> tf.Tensor:
        # Average the ensemble members' outputs element-wise.
        member_outputs = [member(batch) for member in ensemble]
        return tf.reduce_mean(member_outputs, axis=0)

    return predict_batch

In [None]:
from tqdm import tqdm
from multiprocessing import Pool
    
# One-hot row (A, C, G, T order, matching the ['A','C','G','T'] alphabet used for
# encoding) -> nucleotide letter. An all-zero row decodes to "N".
ltrdict = {
    (1, 0, 0, 0): "A",
    (0, 1, 0, 0): "C",
    (0, 0, 1, 0): "G",
    (0, 0, 0, 1): "T",
    (0, 0, 0, 0): "N",
}


def process_sequence(sequence):
    """Decode one one-hot encoded sequence into a nucleotide string.

    Leading and trailing all-zero rows (such as the zero padding added around the
    sequence before prediction) are trimmed. All-zero rows *inside* the sequence are
    kept and decoded as "N", so the decoded string stays position-aligned with the
    original sequence.

    Parameters
    ----------
    sequence : array-like
        Array whose first axis is the sequence position; each row must flatten to a
        key of ``ltrdict``.

    Returns
    -------
    str
        The decoded sequence ('' if every row is all-zero).

    Raises
    ------
    KeyError
        If a row is not a valid one-hot (or all-zero) encoding.
    """
    rows = np.asarray(sequence)
    occupied = [pos for pos, nuc in enumerate(rows) if np.any(nuc)]
    if not occupied:
        return ''
    # Trim only the flanking padding. Dropping interior all-zero rows (as a blanket
    # `np.any` filter would) shortens the sequence and shifts every SNV coordinate.
    core = rows[occupied[0]:occupied[-1] + 1]
    return ''.join(ltrdict[tuple(nuc.flatten().tolist())] for nuc in core)


def get_fasta(x: np.ndarray):
    """Decode a batch of one-hot sequences into nucleotide strings in parallel.

    Parameters
    ----------
    x : np.ndarray
        Batch of one-hot sequences; the first axis indexes sequences.

    Returns
    -------
    list of str
        One decoded string per sequence, in input order.
    """
    # NOTE(review): Pool workers must be able to unpickle `process_sequence` from the
    # notebook's __main__, which presumably relies on the "fork" start method; forking
    # after TensorFlow has started may also be fragile -- confirm on the target platform.
    with Pool() as pool:
        decoded = list(tqdm(pool.imap(process_sequence, x), total=x.shape[0]))

    return decoded

In [None]:
def get_phenotypes(mut_seqs_fasta: list) -> tuple:
    """Describe each mutated sequence by the SNVs that distinguish it from the original.

    Parameters
    ----------
    mut_seqs_fasta : list of str
        Decoded sequences; the first element is the original sequence and the rest are
        its mutants.

    Returns
    -------
    tuple of (list of str, list of str)
        ``mut_seqs_fasta[1:]`` and, for each mutant, a '|'-joined string of
        ``SNV[<0-based position>,<mutant nucleotide>]`` entries ('' when the mutant is
        identical to the original). Both lists are empty when fewer than two sequences
        are given.

    Notes
    -----
    Positions are compared with ``zip``, so if lengths differ only the common prefix is
    compared.
    """
    if len(mut_seqs_fasta) == 0:
        return [], []

    original_seq = mut_seqs_fasta[0]
    mutants = mut_seqs_fasta[1:]
    phenotypes = [
        '|'.join(
            f"SNV[{pos},{mut_nuc}]"
            for pos, (orig_nuc, mut_nuc) in enumerate(zip(original_seq, mut_seq))
            if orig_nuc != mut_nuc
        )
        for mut_seq in mutants
    ]
    return mutants, phenotypes

In [None]:
def get_scores(y: np.ndarray) -> tuple:
    '''
    Return the score of each mutated sequence and its difference to the original.

    Parameters
    ----------
    y : np.ndarray
        Predictions; ``y[0]`` belongs to the original sequence and ``y[1:]`` to the
        mutated sequences.

    Returns
    -------
    tuple of (list, list)
        ``scores``: each mutant's prediction rounded to 2 decimals;
        ``deltascores``: each mutant's prediction minus the original's, rounded to
        2 decimals. Both are empty if ``y`` has fewer than two entries.
    '''
    if len(y) == 0:
        return [], []

    y_original = y[0]
    mutant_preds = y[1:]
    scores = [round(pred, 2) for pred in mutant_preds]
    deltascores = [round(pred - y_original, 2) for pred in mutant_preds]
    return scores, deltascores

In [None]:
def get_ss_idx(mut_seqs_fasta: list, original_ss_idx: list) -> list:
    '''
    Encode the original splice-site indexes once and repeat them for every mutant.

    Only SNVs are generated, so every mutated sequence keeps the splice-site
    positions of the original sequence.

    Parameters
    ----------
    mut_seqs_fasta : list
        Mutated sequences; only their count is used.
    original_ss_idx : list of lists
        Splice-site indexes of the original sequence, grouped in sub-lists.

    Returns
    -------
    list of str
        ``len(mut_seqs_fasta)`` copies of the indexes flattened into a single
        ';'-separated string.
    '''
    per_group = [";".join(str(pos) for pos in group) for group in original_ss_idx]
    encoded = ";".join(per_group)
    return [encoded for _ in range(len(mut_seqs_fasta))]

#### Load data

In [None]:
# Read the event table (first row is the header; coordinates are not 0-based).
data = tabular_file_to_genomics_df("0_RBFOX2_ES_events.tsv", header=0, is_0_based=False)

# Extract the sequence and splice-site indexes of every event, using the local
# genome/cache and writing into the run's output directory.
seqs, ss_idx = preprocessing(
    data,
    genome="../data/cache/GRCh38.primary_assembly.genome.fa",
    cache_dir="../data/cache",
    use_full_sequence=False,
    outdir="RBFOX2_knockdown_squid",
)

#### Generate datasets with SQUID

In [None]:
def write_output(dataset: pd.DataFrame, run_id: str, seq_id: str, seed: int, time_elapsed: float,
                 outdir: str = "RBFOX2_knockdown_squid"):
    '''
    Write DRESS compatible output files.

    Two gzip-compressed CSVs are written to ``outdir``:

    * ``{run_id}_seed_{seed}_dataset.csv.gz``: ``dataset`` itself;
    * ``{run_id}_seed_{seed}_archive_logger.csv.gz``: a single row of summary
      metrics computed by ``Dataset(dataset)``.

    Parameters
    ----------
    dataset : pd.DataFrame
        Generated sequences and their scores.
    run_id : str
        Run identifier used in the file names and the logger row.
    seq_id : str
        Identifier of the original sequence.
    seed : int
        Seed used to generate ``dataset``.
    time_elapsed : float
        Generation time, logged rounded to 4 decimals.
    outdir : str, optional
        Output directory, created if it does not exist. The default keeps the
        original notebook's output location.
    '''
    os.makedirs(outdir, exist_ok=True)
    prefix = f"{outdir}/{run_id}_seed_{seed}"

    # Dataset
    dataset_obj = Dataset(dataset)
    metrics = dataset_obj.metrics
    dataset.to_csv(f"{prefix}_dataset.csv.gz", compression='gzip', index=False)

    # Archive logger: column names paired with their values, so the two can never
    # drift out of alignment (dict insertion order fixes the column order).
    # Generation is always logged as 0 here.
    archive_row = {
        "Run_id": run_id,
        "Seed": seed,
        "Seq_id": seq_id,
        "Generation": 0,
        "Execution_time": round(time_elapsed, 4),
        "Archive_quality": round(dataset_obj.quality, 4),
        "Archive_size": len(dataset_obj),
        "Archive_diversity": metrics['Diversity'],
        "Archive_avg_diversity_per_bin": metrics['Avg_Diversity_per_bin'],
        "Archive_empty_bin_ratio": metrics['Empty_bin_ratio'],
        "Archive_low_count_bin_ratio": metrics['Low_count_bin_ratio'],
        "Archive_avg_number_diff_units": metrics['Avg_number_diff_units'],
        "Archive_avg_edit_distance": metrics['Avg_edit_distance'],
    }
    pd.DataFrame([archive_row]).to_csv(f"{prefix}_archive_logger.csv.gz", compression='gzip', index=False)

In [None]:
# Smoke test: restrict the run to a single event.
# NOTE: this rebinds `seqs` and `ss_idx`; re-run the "Load data" cell to get every event back.
seqs = {key: seqs[key] for key in ['chr1:972761-973740(+)_ENST00000379410']}
ss_idx = {key: ss_idx[key] for key in ['chr1:972761-973740(+)_ENST00000379410']}

In [None]:
alphabet = ['A','C','G','T']

# Zero-padding added on each side of every sequence before prediction
# (presumably the 5,000-nt flanking context SpliceAI expects -- confirm).
SPLICEAI_FLANK = 5000
# Target mean number of mutations per generated sequence
# (mut_rate = TARGET_MEAN_MUTATIONS / seq_len), chosen so the mean edit distance of the
# generated sequences is similar to the best GP configuration.
TARGET_MEAN_MUTATIONS = 11.3
# Sequences generated per seed; the original sequence is the first one returned.
NUM_SIM = 5001
NUM_SEEDS = 5

predict_spliceai = predict_on_batch_spliceai()

for seq_id, original_seq in seqs.items():

    print(f"Processing {seq_id}")
    run_id = seq_id.replace(':', '_').replace('(+)', '').replace('(-)', '').replace('-', '_')

    _ss_idx = ss_idx[seq_id]
    seq_len = len(original_seq)

    # One-hot encode, then zero-pad both flanks.
    x = squid.utils.seq2oh(original_seq, alphabet)
    x = np.pad(x, pad_width=((SPLICEAI_FLANK, SPLICEAI_FLANK), (0, 0)))

    # Cassette exon positions (acceptor, donor). They are NOT shifted by the padding,
    # presumably because SpliceAI's output excludes the flanking context -- confirm.
    task_idx = [_ss_idx[1][0], _ss_idx[1][1]]

    # Mutation window covering the original sequence, i.e. shifted past the left padding.
    mut_window = [SPLICEAI_FLANK, SPLICEAI_FLANK + seq_len]

    # Predictor reducing SpliceAI output at the cassette exon boundaries to one score.
    pred_generator = SpliceAIPredictor(
        pred_fun=predict_spliceai, task_idx=task_idx, batch_size=64, save_dir=None
    )

    # Random mutagenesis at the rate derived from TARGET_MEAN_MUTATIONS.
    mut_generator = squid.mutagenizer.RandomMutagenesis(
        mut_rate=TARGET_MEAN_MUTATIONS / seq_len, uniform=False
    )

    # Generate in silico MAVE
    mave = squid.mave.InSilicoMAVE(mut_generator, pred_generator, seq_len, mut_window=mut_window)

    for seed in range(NUM_SEEDS):
        print(f"..Seed {seed}")
        start = time.perf_counter()
        x_mut, y_mut = mave.generate(x, num_sim=NUM_SIM, seed=seed)
        time_elapsed = time.perf_counter() - start
        # True division: report fractional minutes (floor division would always print N.00).
        print(f"..Time elapsed: {time_elapsed / 60:.2f} minutes")

        print("Getting fasta")
        mut_seqs_fasta = get_fasta(x_mut)
        mut_seqs_fasta, phenotypes = get_phenotypes(mut_seqs_fasta)
        scores, deltascores = get_scores(y_mut)
        all_ss_idx = get_ss_idx(mut_seqs_fasta, _ss_idx)

        n_rows = len(mut_seqs_fasta)
        dataset = pd.DataFrame({
            'Run_id': [run_id] * n_rows,
            'Seed': [seed] * n_rows,
            'Seq_id': [seq_id] * n_rows,
            'Phenotype': phenotypes,
            'Sequence': mut_seqs_fasta,
            'Splice_site_positions': all_ss_idx,
            'Score': scores,
            'Delta_score': deltascores
        })

        write_output(dataset, run_id, seq_id, seed, time_elapsed)