In [32]:
from importlib import reload
import re
import random
random.seed(1337)
import os
import pickle
import itertools
import functools
import operator

import numpy as np
np.random.seed(1337)
import pandas as pd
import scipy.stats as stats
import ahocorasick

import matplotlib.pyplot as plt
import seaborn as sns

import keras
from keras import backend as K
from keras import losses
from keras.models import load_model
from keras.utils import Sequence
from keras.callbacks import ModelCheckpoint, EarlyStopping

from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

import model
import utils

### Processing functions

In [80]:
class EncodingFunction:

    def __init__(self, name):
        self.name = name
    
    def __call__(self, df):
        pass

class PrecomputeFunction(EncodingFunction):
    
    def __init__(self, new_col, dims, method="stack"):
        self.new_col = new_col
        self.dims = dims
        self.method = method
        super().__init__(new_col)

In [81]:
class SeqLenExtractor(PrecomputeFunction):
    
    def __init__(self, seq_col, new_col):
        self.seq_col = seq_col
        super().__init__(new_col, dims=(1,))
        
    def __call__(self, df):
        return df[self.seq_col].str.len()

class KmerExtractor(PrecomputeFunction):
    
    def __init__(self, seq_col, new_col, k, jump=False, divide_counts=True):
        self.k = k
        self.seq_col = seq_col
        kmers = [''.join(i) for i in itertools.product(["A","C","T","G"], repeat = self.k)]
        self.n = len(kmers)
        self.kmer_dict = {kmers[k]:k for k in range(self.n)}
        self.jump = jump
        self.divide_counts = divide_counts
        super().__init__(new_col, dims=(self.n,))
    
    def extract(self, seq):
        i = 0
        arr = np.zeros(self.n)
        while i < len(seq) - (self.k - 1):
            arr[self.kmer_dict[seq[i:i+self.k]]] = arr[self.kmer_dict[seq[i:i+self.k]]] + 1
            if self.jump:
                i = i + self.k
            else:
                i += 1
        if self.divide_counts:
            arr/np.sum(arr)
        return arr

    def __call__(self, df):
        return df[self.seq_col].apply(self.extract)

class KmerAtPosExtractor(PrecomputeFunction):

    def __init__(self, seq_col, new_col, positions):
        self.seq_col = seq_col
        i = 0
        self.pos_kmer_dict = {}
        for start, stop in positions:
            k = np.abs(stop - start)
            kmers = [''.join(i) for i in itertools.product(["A","C","T","G"], repeat = k)]
            kmer_dict = {kmers[k]:k+i for k in range(len(kmers))}
            i += len(kmers)
            self.pos_kmer_dict[(start, stop)] = kmer_dict
        self.n = i
        self.pos = positions
        super().__init__(new_col, dims=(self.n,))
    
    def extract(self, seq):
        arr = np.zeros(self.n)
        for interval in self.pos:
            kmer_dict = self.pos_kmer_dict[interval]
            start, stop = interval
            arr[kmer_dict[seq[start:stop]]] = arr[kmer_dict[seq[start:stop]]] + 1
        return arr

    def __call__(self, df):
        return df[self.seq_col].apply(self.extract)
    
class GCContentExtractor(PrecomputeFunction):
    
    def __init__(self, seq_col, new_col):
        self.seq_col = seq_col
        super().__init__(new_col, dims=(1,))
        
    def __call__(self, df):
        return (df[self.seq_col].str.count("G") + 
                df[self.seq_col].str.count("C"))/df[self.seq_col].str.len()
    
# Counts specific motifs (e.g. PolyA sites)
class MotifExtractor(PrecomputeFunction):
     
    def __init__(self, seq_col, new_col, motifs):
        self.seq_col = seq_col
        self.n = len(motifs)
        self.motifs = motifs
        self.ahoAutomat = ahocorasick.Automaton()
        for idx, key in enumerate(motifs):
            self.ahoAutomat.add_word(key, idx)
        self.ahoAutomat.make_automaton()
        super().__init__(new_col, dims=(self.n,))
    
    def extract_motives(self, seq):
        # Use ahocorasick
        arr = np.zeros(self.n)
        for end_index, idx in self.ahoAutomat.iter(seq):
            arr[idx] += 1
        return arr
            
    def __call__(self, df):
        return df[self.seq_col].apply(self.extract_motives)

In [58]:
class NodererScore(PrecomputeFunction):
    
    def __init__(self, noderer_df_aug, noderer_df_nonaug, new_col="noderer",
                utr_col="utr", cds_col="cds",
                seq_col="sequence", score_col="efficiency"):
        self.utr_col, self.cds_col = utr_col, cds_col
        # replace U with T in Noderer dataframe
        noderer_df_aug[seq_col] = noderer_df_aug[seq_col].str.replace("U", "T")
        noderer_df_nonaug[seq_col] = noderer_df_nonaug[seq_col].str.replace("U", "T")
        self.avg_score = noderer_df_nonaug[score_col].median()
        # build dictionary
        self.score_dict_aug = {k:v for k,v in zip(noderer_df_aug[seq_col], 
                                                  noderer_df_aug[score_col])}
        self.score_dict_nonaug = {k:v for k,v in zip(noderer_df_nonaug[seq_col], 
                                              noderer_df_nonaug[score_col])}
        super().__init__(new_col, dims=(1,))
    
    def score(self, tis):
        score = self.score_dict_aug.get(tis)
        if score is None:
            score = self.score_dict_nonaug.get(tis)
            if score is None:
                score = self.avg_score
        return score
    
    def __call__(self, df):
        tis = df[self.utr_col].str[-6:]
        tis = tis.str.cat(df[self.cds_col].str[:5])
        return tis.apply(self.score)

class PrecomputeEmbeddings(PrecomputeFunction):
    
    def __init__(self, new_col, model, layer_name, input_layer_names, 
                 generator_encoding_functions, node=0):
        target_obj = model.get_layer(layer_name).get_output_at(node)
        target = [target_obj]
        self.check_fn = K.function([model.get_layer(x).input for x in input_layer_names], target)
        self.generator_encoding_functions = generator_encoding_functions.copy()
        super().__init__(new_col, dims=target_obj.shape, method="concat")
    
    def __call__(self, df):
        generator = DataSequence(df, encoding_functions=self.generator_encoding_functions, 
                                 shuffle=False)
        l = [self.check_fn(x)[0] for x in generator]
        return functools.reduce(operator.concat, [np.vsplit(x, x.shape[0]) for x in l])

In [82]:
# Encoding functions
class DataFrameExtractor(EncodingFunction):
    
    def __init__(self, col, method="stack"):
        self.col = col
        super().__init__(col)
        self.method = method
        
    def __call__(self, df):
        if self.method == "stack":
            return np.stack(df[self.col], axis = 0)
        else:
            return np.concatenate(df[self.col], axis = 0)
        
class OneHotEncoder(EncodingFunction):
    
    def __init__(self, col):
        self.col = col
        super().__init__(col)
    
    def __call__(self, df):
        max_len = len(max(df[self.col], key=len))
        return np.stack([utils.encode_seq(seq, max_len) for seq in df[self.col]], axis = 0)

class LibraryEncoder(EncodingFunction):
    
    def __init__(self, col, n_libs=6):
        self.col = col
        self.n_libs=6
        super().__init__(col)
    
    def __call__(self, df):
        return utils.encode_experiment(df, col=self.col, n_libs=self.n_libs)

In [48]:
# Data generator
class DataSequence(Sequence):
    
    def __init__(self, df, precomputations=[], 
                 encoding_functions=[],
                 input_order=None,
                 output_encoding_fn=None,
                 batch_size=128, shuffle=True):
        self.df = df.copy()
        self.encoding_functions = encoding_functions.copy()
        self.output_encoding_fn = output_encoding_fn
        self.indices = np.arange(len(self.df))
        self.batch_size = batch_size
        for fn in precomputations:
            print("Doing precomputation: " + fn.new_col)
            self.encoding_functions.append(DataFrameExtractor(fn.new_col, fn.method))
            self.df[fn.new_col] = fn(df)
        if input_order is not None:
            fn_dict = {fn.name:fn for fn in self.encoding_functions}
            self.encoding_functions = [fn_dict[name] for name in input_order]
        self.shuffle = shuffle
        super().__init__()

    def __len__(self):
        return int(np.ceil(len(self.df) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_df = self.df.iloc[self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]]
        # Feed input
        inputs = [fn(batch_df) for fn in self.encoding_functions]
        if self.output_encoding_fn is None:
            return inputs
        else:
            return (inputs, self.output_encoding_fn(batch_df))
            
    def on_epoch_end(self):
        'Updates indices after each epoch'
        self.indices = np.arange(len(self.df))
        if self.shuffle:
            np.random.shuffle(self.indices)

In [60]:
def compute_all_metrics(model, encoding_fn, mpra_df, endo_df, 
                        snv_df=None, ribo_df=None, ptr_df=None, 
                        do_test=False):
    mpra_val = mpra_df[(mpra_df["set"] == "val") & (mpra_df["library"] == "egfp_unmod_1")]
    endo_val = endo_df[endo_df["set"] == "val"]
    data_list = [mpra_val, endo_val]
    if do_test:
        mpra_test = mpra_df[(mpra_df["set"] == "test") & (mpra_df["library"] == "egfp_unmod_1")]
        endo_test = endo_df[endo_df["set"] == "test"]
        data_list += [mpra_test, endo_test]
    for df in data_list:
        pred_gen = DataSequence(df, encoding_fn)
        predictions = model.predict_generator(pred_gen)
        print(utils.rSquared(predictions.reshape(-1), df["rl"]))
        utils.print_corrs(df["rl"], predictions.reshape(-1))

### Data import

In [7]:
with open("../Data/data_dict.pkl", 'rb') as handle:
    data_dict = pickle.load(handle)

data_df = data_dict["data"]
snv_df = data_dict["snv"]
ptr_df = data_dict["ptr"]

with open("../Data/ribo_dict.pkl", 'rb') as handle:
    data_dict = pickle.load(handle)
eichhorn_df = data_dict["eichhorn"]
eichhorn_df["log_load"] = np.log(eichhorn_df["RPF_RPKM"]/eichhorn_df["RNA_RPKM"])

with open("../Data/doudna_polysome_iso_sub.pkl", 'rb') as handle:
    doudna_df = pickle.load(handle)
doudna_df["library"] = "egfp_unmod_1"
doudna_df = doudna_df.rename(index=str, columns={"rl_mean":"rl"})

noderer_df_aug = pd.read_csv("../Data/TIS/tis_efficiencies_aug.tsv", sep="\t")
noderer_df_nonaug = pd.read_csv("../Data/TIS/tis_efficiencies_nonaug.tsv", sep="\t")
noderer_df_nonaug = noderer_df_nonaug.rename(index=str, columns={"TIS_Sequence":"sequence", 
                                             "TIS_Efficiency":"efficiency"})

In [8]:
train_doudna = doudna_df[["utr", "rl", "library", "set", "cds", "3utr"]]
train_doudna = train_doudna[train_doudna["set"] == "train"]

In [9]:
val_doudna = doudna_df[["utr", "rl", "library", "set", "cds", "3utr"]]
val_doudna = val_doudna[val_doudna["set"] == "val"]

In [23]:
base_model.evaluate_generator(base_generator)

0.25135341877937317

### Prepare transfer model

In [11]:
base_model = load_model("../Models/basic_model_scaled.h5", custom_objects={'FrameSliceLayer': model.FrameSliceLayer})

In [83]:
# Encoding functions
one_hot_fn = OneHotEncoder("utr")
library_fn = LibraryEncoder("library")
one_hot_fn_cds = OneHotEncoder("cds")
one_hot_fn_3utr = OneHotEncoder("3utr")

In [84]:
# kmer inputs
# 5utr
utr5_len = SeqLenExtractor("utr", "5utr_len")
utr5_gc = GCContentExtractor("utr", "5utr_gc")
kmer5utr = [utr5_len, utr5_gc]
kmer_inputs_5utr = [model_combined.model_input(fn.dims, "input_"+ fn.new_col) for fn in kmer5utr]
# cds
cds_codon_bias = KmerExtractor("cds", "cds_codon_bias", k=3, jump=True)
cds_len = SeqLenExtractor("cds", "cds_len")
cds_gc = GCContentExtractor("cds", "cds_gc")
kmercds = [cds_codon_bias, cds_len, cds_gc]
kmer_inputs_cds = [model_combined.model_input(fn.dims, "input_"+ fn.new_col) for fn in kmercds]
# 3 utr
utr3_len = SeqLenExtractor("3utr", "3utr_len")
utr3_gc = GCContentExtractor("3utr", "3utr_gc")
kmer3utr = [utr3_len, utr3_gc]
kmer_inputs_3utr = [model_combined.model_input(fn.dims, "input_"+ fn.new_col) for fn in kmer3utr]

In [40]:
# transfer inputs
embeddings_fn = PrecomputeEmbeddings("embeddings", base_model, "fully_connected", 
                                     ["input_seq", "input_experiment"], 
                                     generator_encoding_functions=[one_hot_fn, library_fn], 
                                     node=0)
embedding_input = model_combined.model_input((64 ,), "input_embedding")

In [87]:
encoding_functions = [one_hot_fn, one_hot_fn_cds, one_hot_fn_3utr]
precomputations = kmer5utr + kmercds + kmer3utr
precomputations = [embeddings_fn] + precomputations
input_order=["utr"] + [fn.new_col for fn in kmer5utr] + \
             ["cds"] + [fn.new_col for fn in kmercds] + \
            ["3utr"] + [fn.new_col for fn in kmer3utr]
precomputations = [embeddings_fn] + precomputations
input_order = input_order + ["embeddings"]
output_encoding_fn = DataFrameExtractor("rl")

In [88]:
train_gen = DataSequence(train_doudna, precomputations=precomputations, encoding_functions=encoding_functions,
                        output_encoding_fn=output_encoding_fn,
                        input_order=input_order,
                        batch_size=32)

Doing precomputation: 5utr_len
Doing precomputation: 5utr_gc
Doing precomputation: cds_codon_bias
Doing precomputation: cds_len
Doing precomputation: cds_gc
Doing precomputation: 3utr_len
Doing precomputation: 3utr_gc


In [14]:
val_gen = DataSequence(val_doudna, precomputations=precomputations, encoding_functions=encoding_functions,
                      output_encoding_fn=output_encoding_fn,
                      input_order=input_order,
                      batch_size=32, shuffle=False)

Doing precomputation: embeddings


In [16]:
import model_combined

In [None]:
reload(model_combined)
utr_5kmer

In [28]:
reload(model)
reload(model_combined)
utr5_conv = model_combined.pooled_conv_model(n_conv_layers=1, 
                        kernel_size=[8], n_filters=64, dilations=[1],
                        padding="same", use_batchnorm=False,
                        use_inception=False, skip_connections="",
                        single_output=False,
                        n_dense_layers=0, fc_neurons=[32], fc_drop_rate=0.2,
                        prefix="5utr_")
cds_conv = model_combined.pooled_conv_model(n_conv_layers=1, 
                        kernel_size=[8], n_filters=32, dilations=[1],
                        padding="same", use_batchnorm=False,
                        use_inception=False, skip_connections="", 
                        single_output=False,
                        n_dense_layers=0, fc_neurons=[32], fc_drop_rate=0.2,
                        prefix="cds_")
utr3_conv = model_combined.pooled_conv_model(n_conv_layers=1, 
                        kernel_size=[8], n_filters=32, dilations=[1],
                        padding="same", use_batchnorm=False,
                        use_inception=False, skip_connections="", 
                        single_output=False,
                        n_dense_layers=0, fc_neurons=[32], fc_drop_rate=0.2,
                        prefix="utr3_")

In [29]:
reload(model_combined)
utr5_model = combined_conv_kmer(utr5_conv, utr5_kmer, prefix="5utr")
cds_model = combined_conv_kmer(utr5_conv, utr5_kmer, prefix="cds")
_model = combined_conv_kmer(utr5_conv, utr5_kmer, prefix="cds")
combined_model = model_combined.transfer_model(utr5_conv, cds_conv, utr3_conv,
                  transfer_inputs=[embedding_input], n_transfer_layers=0, 
                  n_combine_layers=1, combine_neurons=[32])

In [30]:
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)
mc = ModelCheckpoint("combined.h5", monitor='val_loss', mode='min', verbose=1, save_best_only=True)
combined_model.fit_generator(generator, epochs=30, verbose=2, validation_data=val_gen,
                            callbacks=[es, mc])

Epoch 1/30
 - 197s - loss: 1.9373 - val_loss: 1.8187

Epoch 00001: val_loss improved from inf to 1.81874, saving model to combined.h5
Epoch 2/30
 - 184s - loss: 1.8137 - val_loss: 1.7925

Epoch 00002: val_loss improved from 1.81874 to 1.79253, saving model to combined.h5
Epoch 3/30
 - 186s - loss: 1.7543 - val_loss: 1.8784

Epoch 00003: val_loss did not improve from 1.79253
Epoch 4/30
 - 182s - loss: 1.7027 - val_loss: 2.0157

Epoch 00004: val_loss did not improve from 1.79253
Epoch 5/30
 - 187s - loss: 1.6338 - val_loss: 1.7989

Epoch 00005: val_loss did not improve from 1.79253
Epoch 6/30
 - 191s - loss: 1.5622 - val_loss: 1.8171

Epoch 00006: val_loss did not improve from 1.79253
Epoch 7/30
 - 183s - loss: 1.4939 - val_loss: 1.7766

Epoch 00007: val_loss improved from 1.79253 to 1.77661, saving model to combined.h5
Epoch 8/30
 - 190s - loss: 1.4067 - val_loss: 1.8533

Epoch 00008: val_loss did not improve from 1.77661
Epoch 9/30
 - 183s - loss: 1.3518 - val_loss: 1.8817

Epoch 00009

<keras.callbacks.History at 0x7f8ad4574198>

In [302]:
combined_model.evaluate_generator(generator)

1.1153790229742846

## Evaluate

In [19]:
combined_model = load_model("combined2.h5",  custom_objects={'FrameSliceLayer': model.FrameSliceLayer,
                                                           "correlation_coefficient_loss":
                                                            model_combined.correlation_coefficient_loss})

In [31]:
generator = DataSequence(train_doudna, precomputations=precomputations, encoding_functions=encoding_functions,
                        output_encoding_fn=output_encoding_fn,
                        batch_size=32)
predictions = combined_model.predict_generator(generator)
print(utils.rSquared(predictions.reshape(-1), train_doudna["rl"]))
utils.print_corrs(train_doudna["rl"], predictions.reshape(-1))

Doing precomputation: embeddings
0.5895364006084826
Pearson: 0.792, p-val: 0.000, squared: 0.627, Spearman: 0.789, p-val: 0.000


In [21]:
predictions = combined_model.predict_generator(val_gen)
print(utils.rSquared(predictions.reshape(-1), val_doudna["rl"]))
utils.print_corrs(val_doudna["rl"], predictions.reshape(-1))

0.05038705961306167
Pearson: 0.240, p-val: 0.000, squared: 0.058, Spearman: 0.231, p-val: 0.000


In [22]:
test_doudna = doudna_df[["utr", "rl", "library", "set", "cds", "3utr"]]
test_doudna = test_doudna[test_doudna["set"] == "test"]
test_gen = DataSequence(test_doudna, precomputations=precomputations, encoding_functions=encoding_functions,
                      output_encoding_fn=output_encoding_fn,
                      batch_size=32, shuffle=False)
predictions = combined_model.predict_generator(test_gen)
print(utils.rSquared(predictions.reshape(-1), test_doudna["rl"]))
utils.print_corrs(test_doudna["rl"], predictions.reshape(-1))

Doing precomputation: embeddings
0.06082193895161725
Pearson: 0.269, p-val: 0.000, squared: 0.072, Spearman: 0.264, p-val: 0.000


In [26]:
combined_model.get_layer("transfer_dense_0").get_weights()[6]

(65, 64)

### Prepare model

In [55]:
import model_combined

In [125]:
reload(model)
reload(model_combined)
utr5_conv = model_combined.framed_pooled_conv_model(n_conv_layers=5, 
                        kernel_size=[11,3,3,3,3], n_filters=128, dilations=[1, 2, 4, 8, 16],
                        padding="same", use_batchnorm=False,
                        use_inception=False, skip_connections="residual", 
                        n_dense_layers=1, fc_neurons=[64], fc_drop_rate=0.2)
utr3_conv = model_combined.kmer_linear_model(prefix="utr3")
cds_conv = model_combined.kmer_linear_model(prefix="cds")

In [126]:
reload(model_combined)
combined_model = model_combined.combined_model_noshortcut(utr5_conv, cds_conv, utr3_conv,
                  loss=model_combined.correlation_coefficient_loss)

In [105]:
encoding_fn = gen_encodingfn(kmer_gen(), kmer_gen(), col='utr', libcol="library", n_libs=7,
                  output_col="rl", cds_col="cds", utr3_col="3utr")
train_gen = BalancedDataSequence(train_doudna, data_df[(data_df["library"] == "egfp_unmod_1") & (data_df["set"] == "train")], encoding_fn)

In [106]:
val_gen = DataSequence(val_doudna, encoding_fn)

In [127]:
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=5)
mc = ModelCheckpoint("combined2.h5", monitor='val_loss', mode='min', verbose=1, save_best_only=True)
combined_model.fit_generator(train_gen, epochs=30, verbose=2, validation_data=val_gen,
                            callbacks=[es, mc])

Epoch 1/30
 - 302s - loss: 0.5566 - val_loss: 0.9903

Epoch 00001: val_loss improved from inf to 0.99033, saving model to combined2.h5
Epoch 2/30
 - 314s - loss: 0.3522 - val_loss: 0.9907

Epoch 00002: val_loss did not improve from 0.99033
Epoch 3/30
 - 309s - loss: 0.3343 - val_loss: 0.9670

Epoch 00003: val_loss improved from 0.99033 to 0.96705, saving model to combined2.h5
Epoch 4/30
 - 315s - loss: 0.3262 - val_loss: 0.9567

Epoch 00004: val_loss improved from 0.96705 to 0.95671, saving model to combined2.h5
Epoch 5/30
 - 314s - loss: 0.3255 - val_loss: 0.9669

Epoch 00005: val_loss did not improve from 0.95671
Epoch 6/30
 - 300s - loss: 0.3175 - val_loss: 0.9602

Epoch 00006: val_loss did not improve from 0.95671
Epoch 7/30
 - 309s - loss: 0.3142 - val_loss: 0.9640

Epoch 00007: val_loss did not improve from 0.95671
Epoch 8/30
 - 308s - loss: 0.3098 - val_loss: 0.9535

Epoch 00008: val_loss improved from 0.95671 to 0.95354, saving model to combined2.h5
Epoch 9/30
 - 294s - loss: 0

<keras.callbacks.History at 0x7fd015ebc710>

In [131]:
combined_model = load_model("combined2.h5",  custom_objects={'FrameSliceLayer': model.FrameSliceLayer,
                                                           "correlation_coefficient_loss":
                                                            model_combined.correlation_coefficient_loss})

In [94]:
doudna_train = doudna_df[doudna_df["set"] == "train"]     
pred_gen = DataSequence(doudna_train, encoding_fn)
predictions = combined_model.predict_generator(pred_gen)
print(utils.rSquared(predictions.reshape(-1), doudna_train["rl"]))
utils.print_corrs(doudna_train["rl"], predictions.reshape(-1))

-5.563278239517302
Pearson: -0.485, p-val: 0.000, squared: 0.235, Spearman: -0.502, p-val: 0.000


In [132]:
compute_all_metrics(combined_model, encoding_fn, data_df, doudna_df, do_test=True)

-1214.2984697729644
Pearson: -0.914, p-val: 0.000, squared: 0.835, Spearman: -0.889, p-val: 0.000
-333.29901437130246
Pearson: -0.197, p-val: 0.000, squared: 0.039, Spearman: -0.232, p-val: 0.000
-819.6939772577493
Pearson: -0.922, p-val: 0.000, squared: 0.850, Spearman: -0.901, p-val: 0.000
-360.0016793565845
Pearson: -0.213, p-val: 0.000, squared: 0.045, Spearman: -0.237, p-val: 0.000


In [None]:
class BalancedDataSequence(Sequence):
    
    def __init__(self, df_exo, df_endo, encoding_fn, batch_size=64, extra_keys=[], shuffle=True):
        self.df_exo, self.df_endo = df_exo, df_endo
        self.encoding_fn = encoding_fn
        self.indices_exo = np.arange(len(self.df_exo))
        self.indices_endo = np.arange(len(self.df_endo))
        self.extra_keys = extra_keys
        self.batch_size = batch_size
        self.shuffle = shuffle
        super().__init__()

    def __len__(self):
        return int(np.ceil(len(self.indices_exo) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Get some exogenous data
        batch_df_exo = self.df_exo.iloc[
            self.indices_exo[idx * self.batch_size:(idx + 1) * self.batch_size]]
        # Get a matching amount of endogenous data
        idx_endo = idx
        if (idx_endo + 1) * self.batch_size >= len(self.df_endo):
            idx_endo = 0
            np.random.shuffle(self.indices_endo)
        batch_df_endo = self.df_endo.iloc[
            self.indices_endo[idx_endo * self.batch_size:(idx_endo + 1) * self.batch_size]]
        # Concatenate
        batch_df = pd.concat([batch_df_exo, batch_df_endo])
        # Prepare input data
        encoded_data = self.encoding_fn(batch_df)
        # Feed input
        inputs = [encoded_data["seq"], encoded_data["cds_seq"], encoded_data["utr3_seq"],
                  encoded_data["seqtype"]]
        for key in self.extra_keys:
            inputs.append(encoded_data[key])
        if encoded_data.get("rl") is None:
            return inputs
        else:
            return (inputs, encoded_data["rl"])
            
    def on_epoch_end(self):
        'Updates indices after each epoch'
        self.indices_exo = np.arange(len(self.df_exo))
        self.indices_endo = np.arange(len(self.df_endo))
        if self.shuffle:
            np.random.shuffle(self.indices_exo)
            np.random.shuffle(self.indices_endo)