In [1]:
import numpy as np
import pandas as pd 
from pathlib import Path
import tokenizers

import tensorflow as tf
import tensorflow.keras
import tensorflow.keras.backend as K
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model 
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras import initializers, regularizers, constraints
from sklearn.preprocessing import LabelEncoder
import sklearn

import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
import seaborn as sns

from utils.att_layer import Attention 
from utils.data_loader import load_and_prepare

from Bio import SeqIO
import os
import logomaker


%matplotlib inline
sns.set()

In [11]:
def sort_based_on_pred_score(pred_list, asc_or_desc):
    """Return a new list of prediction records ordered by their score.

    The score of a record is the element stored at index 2.

    Args:
        pred_list: Iterable of indexable records whose third element is the
            sortable prediction score.
        asc_or_desc: Forwarded unchanged as the ``reverse`` flag of the sort,
            so a truthy value yields descending order (highest score first)
            and a falsy value yields ascending order.

    Returns:
        A new list; ``pred_list`` itself is not modified. Records with equal
        scores keep their original relative order.
    """
    ranked = list(pred_list)
    ranked.sort(key=lambda record: record[2], reverse=asc_or_desc)
    return ranked


def replace_ts_and_same_len(list_of_seqs, min_length, same_len=True):
    """Convert sequences to upper-case RNA and optionally cut a centred window.

    Every sequence is lower-cased, each ``t`` is replaced by ``u`` and the
    result is upper-cased. When ``same_len`` is true, the sequence is first
    reduced to the ``min_length`` characters around its midpoint.

    Args:
        list_of_seqs: Mutable list of strings. It is modified IN PLACE.
        min_length: Width of the centred window kept when ``same_len`` is true.
        same_len: If false, sequences are only converted, never trimmed.

    Returns:
        The same list object that was passed in, with its items replaced.
        Sequences shorter than ``min_length`` are returned whole.
    """
    for idx, sequence in enumerate(list_of_seqs):
        if same_len:
            centre = len(sequence) // 2
            # Clamp at 0: for a sequence shorter than the window the raw start
            # is negative, which Python slicing reads as an offset from the END
            # and which silently dropped the front of short sequences.
            start = max(centre - min_length // 2, 0)
            # Derive stop from start so that an odd min_length still yields
            # exactly min_length characters (identical to the old arithmetic
            # for even widths).
            stop = start + min_length
            sequence = sequence[start:stop]
        list_of_seqs[idx] = sequence.lower().replace('t', 'u').upper()
    return list_of_seqs

def _rescale_attention(scores):
    """Min-max scale attention scores to [0, 1]; a constant vector maps to zeros."""
    lowest = np.min(scores)
    spread = np.max(scores) - lowest
    if spread == 0:
        # All scores equal: dividing would produce NaNs (and a runtime
        # warning). Zeros select the same (first) window downstream.
        return np.zeros_like(scores, dtype=float)
    return (scores - lowest) / spread


def _best_window_start(scores, window):
    """Return the start index of the ``window``-wide slice with the largest sum.

    The first such window wins on ties; 0 is returned when no window fits or
    no window has a positive sum.
    """
    best_start, best_total = 0, 0
    for start in range(len(scores) - window + 1):
        total = sum(scores[start:start + window])
        if total > best_total:
            best_total, best_start = total, start
    return best_start


def get_important_subseqs_and_corresp_seqs(tokenizer_path, seq, cons, sec, model, subseq_len=20, number_of_seqs=50):
    """Extract the most attended sub-sequence of each positively predicted sample.

    For every sample the model is run once. If the predicted probability is
    above 0.5, the per-token scores taken from the ``'att'`` layer are min-max
    scaled, spread over the characters of each decoded token, and the
    ``subseq_len``-character window with the highest summed score is kept.

    Args:
        tokenizer_path: Path given to ``tokenizers.Tokenizer.from_file``.
        seq: Indexable batch of token-id rows; leading/trailing zeros of each
            row are stripped before decoding (presumably 0 is the padding id -
            confirm against the tokenizer).
        cons: Second model input, sliced per sample exactly like ``seq``.
        sec: Third model input, sliced per sample exactly like ``seq``.
        model: Keras model that accepts ``[seq, cons, sec]`` and contains a
            layer named ``'att'``; the last element of that layer's output is
            used as the per-token score (presumably the attention weights -
            confirm against ``utils.att_layer``).
        subseq_len: Width, in characters, of the window to extract.
        number_of_seqs: Maximum number of sub-sequences returned.

    Returns:
        List of at most ``number_of_seqs`` sub-sequences, ordered by predicted
        probability (highest first) and post-processed by
        ``replace_ts_and_same_len(..., subseq_len)``. The full decoded
        sequences are computed but not returned.
    """
    tokenizer = tokenizers.Tokenizer.from_file(tokenizer_path)
    # Invert the vocabulary once; it does not change between samples.
    id2word = {token_id: token for token, token_id in tokenizer.get_vocab().items()}

    model_att = Model(inputs=model.input,
                      outputs=[model.output, model.get_layer('att').output[-1]])

    # Records are [sub-sequence, whole sequence, probability]; the sort helper
    # keys on index 2.
    pos_id_seqs = []

    for sample_idx in range(len(seq)):
        one_sample = slice(sample_idx, sample_idx + 1)
        label_probs, attentions = model_att.predict(
            [seq[one_sample], cons[one_sample], sec[one_sample]])

        # Squeeze first: float() on an array with ndim > 0 is deprecated.
        probability = float(np.squeeze(label_probs))
        if not probability > 0.5:
            # Only positive predictions are ever returned, so skip decoding.
            continue

        tokenized_sample = np.trim_zeros(seq[sample_idx])
        decoded_text = [id2word[token_id] for token_id in tokenized_sample]
        attention_scores = _rescale_attention(attentions[0, :len(tokenized_sample)])

        # Every character of a token inherits the score of that token.
        nucleotides, nt_scores = [], []
        for token, score in zip(decoded_text, attention_scores):
            for nt in token:
                nucleotides.append(nt)
                nt_scores.append(score)

        # Each entry of `nucleotides` is one character, so this slice can
        # never be longer than subseq_len and needs no shrinking pass.
        start = _best_window_start(nt_scores, subseq_len)
        max_seq = ''.join(nucleotides[start:start + subseq_len])
        whole_seq = ''.join(nucleotides)

        pos_id_seqs.append([max_seq, whole_seq, probability])

    pos_id_seqs = sort_based_on_pred_score(pos_id_seqs, True)[:number_of_seqs]
    pos_pred_subseqs = [record[0] for record in pos_id_seqs]

    return replace_ts_and_same_len(pos_pred_subseqs, subseq_len)

def produce_seq_logo(biopython_pwm, bits=True, filename='None'):
    """Render a sequence logo from per-position A/C/G/U values using ``seqlogo``.

    Args:
        biopython_pwm: Object that ``dict()`` turns into a mapping with the
            keys ``'A'``, ``'C'``, ``'G'`` and ``'U'``, each holding one value
            per position (presumably a Biopython motif matrix - confirm).
        bits: Forwarded to ``seqlogo.seqlogo`` as ``ic_scale``.
        filename: Forwarded to ``seqlogo.seqlogo`` as ``filename``.
            NOTE(review): the default is the string ``'None'``, not the
            ``None`` object - confirm this is intended.

    Side effects:
        Prints the position weight matrix twice (object and array form) and
        asks ``seqlogo`` for a medium-sized PNG.
    """
    # `seqlogo` is used below but is not among the notebook's visible imports,
    # so on a fresh kernel this function raised NameError. Importing it here
    # keeps the function self-contained.
    import seqlogo

    pwm_dict = dict(biopython_pwm)
    bases = ('A', 'C', 'G', 'U')
    columns = [list(pwm_dict[base]) for base in bases]

    # One row per position, one column per base, values rounded to 2 decimals.
    n_positions = len(columns[0])
    pwm_arr = np.zeros((n_positions, len(bases)))
    for pos in range(n_positions):
        for col, values in enumerate(columns):
            pwm_arr[pos][col] = round(values[pos], 2)

    ppm = seqlogo.Ppm(pwm_arr, alphabet_type='RNA')
    pwm = seqlogo.ppm2pwm(ppm)
    pwm = seqlogo.Pwm(pwm, alphabet_type="RNA")
    print(pwm)
    print(np.asarray(pwm))
    seqlogo.seqlogo(pwm, ic_scale = bits, format = 'png', size = 'medium', filename=filename)

def mostAbundantKmers(sequences, kmerLen):
    """Count identical items and list them from most to least frequent.

    Args:
        sequences: Iterable of hashable items (the k-mers to tally).
        kmerLen: Accepted for interface compatibility; not used by the
            function.

    Returns:
        List of ``(item, count)`` tuples sorted by descending count. Items
        with equal counts stay in order of first appearance.
    """
    tally = {}
    for kmer in sequences:
        tally[kmer] = tally.get(kmer, 0) + 1

    return sorted(tally.items(), key=lambda pair: pair[1], reverse=True)

def seq_mers(seq, index, kmer_len=6):
    """Build one FASTA record per overlapping k-mer of ``seq``.

    Args:
        seq: Sequence string to slide over.
        index: Identifier placed first in every record header.
        kmer_len: Window width.

    Returns:
        List of strings, each a header line ``>{index}_{offset}`` followed by
        the k-mer starting at ``offset``; both lines end with a newline. The
        list is empty when ``seq`` is shorter than ``kmer_len``.
    """
    window_count = len(seq) - kmer_len + 1
    return [
        f'>{index}_{offset}\n{seq[offset:offset + kmer_len]}\n'
        for offset in range(window_count)
    ]

def getSequenceLogo(sequence, outputPath):
    """Draw a logo for a single sequence with logomaker and save it to disk.

    Args:
        sequence: Sequence string handed to ``logomaker.sequence_to_matrix``.
        outputPath: Destination passed to ``savefig``.

    Side effects:
        Creates a new 4x2 matplotlib figure and writes it to ``outputPath``
        at 250 dpi. The figure is left open; closing it is up to the caller.
    """
    logo_matrix = logomaker.sequence_to_matrix(sequence)

    figure, axis = plt.subplots(nrows=1, ncols=1, figsize=[4, 2])
    axis.set_facecolor('white')

    logo_options = dict(
        ax=axis,
        color_scheme='classic',
        baseline_width=0,
        font_name='FreeSans',
        show_spines=False,
        width=.90,
    )
    sequence_logo = logomaker.Logo(logo_matrix, **logo_options)

    # The logo is shown without any axis ticks.
    for clear_ticks in (axis.set_xticks, axis.set_yticks):
        clear_ticks([])

    sequence_logo.fig.savefig(outputPath, bbox_inches='tight', dpi=250)

## CREATE FASTA FILES WITH THE TOP 50 PREDICTIONS PER CLASS
### IMPORTANT SUBSEQUENCES (BASED ON ATTENTION SCORES)

In [22]:
# --- Analysis parameters -----------------------------------------------------
kmer_len = 6          # width of the k-mers cut from every important sub-sequence
subseq_len = 20       # width of the highest-attention window kept per sample
number_of_seqs = 50   # maximum number of top-scoring positive predictions kept

# --- Paths (placeholders: adjust before running) ------------------------------
TOK_PATH = 'data/tokenizers/chooseYourTokenizer'
# One sub-directory per protein is expected here, each holding *.h5 models.
FINETUNED_MODELS_PATH = Path('PATH/TO/MODELS/DIR')
# Evaluation data is read from <PATH_TO_EVAL_DATA>/<protein>/dev.tsv.
PATH_TO_EVAL_DATA = Path('PATH/TO/DATA')

In [23]:
# Names of the immediate sub-directories of the models directory.
# os.scandir yields entries in arbitrary order, so the names are sorted to make
# the processing order identical on every run.
with os.scandir(FINETUNED_MODELS_PATH) as dir_entries:
    prots_to_train = sorted(entry.name for entry in dir_entries if entry.is_dir())


In [24]:
# For every protein directory and every *.h5 model inside it:
#   1. extract the highest-attention sub-sequences of the top positive predictions,
#   2. write all their overlapping k-mers to a FASTA file next to the model,
#   3. draw a logo for each of the three most frequent k-mers.
sortedCounter = []
for prot in prots_to_train:
    prot_dir_path = FINETUNED_MODELS_PATH / prot
    data = PATH_TO_EVAL_DATA / prot / 'dev.tsv'

    # glob order depends on the filesystem; sort for reproducible runs.
    for model_path in sorted(prot_dir_path.glob('*.h5')):
        model_name = model_path.stem

        # 150 is the third positional argument of load_and_prepare
        # (presumably the maximum sequence length - confirm in utils.data_loader).
        # `y` is returned by the loader but not needed here.
        seqs, cons, secs, y = load_and_prepare(data, TOK_PATH, 150)

        # The model is only used for prediction, so it is loaded uncompiled.
        # The earlier code set a learning rate on the restored optimizer and
        # then replaced that optimizer via compile(), so the value was never
        # used, and the attribute access fails when no optimizer was saved.
        model = load_model(model_path,
                           custom_objects={'Attention': Attention},
                           compile=False)

        imp_pos = get_important_subseqs_and_corresp_seqs(TOK_PATH,
                                                         seqs,
                                                         cons,
                                                         secs,
                                                         model,
                                                         subseq_len=subseq_len,
                                                         number_of_seqs=number_of_seqs)
        outf_p = prot_dir_path / model_name
        outf_p.mkdir(exist_ok=True, parents=True)

        impSubSeqsPath = outf_p / f'{prot}_{number_of_seqs}_{subseq_len}.fa'
        with open(impSubSeqsPath, 'w') as outf:
            for index, el in enumerate(imp_pos):
                outf.writelines(seq_mers(el, index, kmer_len))

        # Read the k-mers back under their own name instead of overwriting
        # `seqs`, which holds the tokenised model input.
        kmers = [str(seq_record.seq) for seq_record in SeqIO.parse(impSubSeqsPath, "fasta")]

        sortedCounter = mostAbundantKmers(kmers, kmer_len)[:3] # most important 3 kmers

        for rank, (kmer, _count) in enumerate(sortedCounter):
            getSequenceLogo(kmer, outf_p / f'logo_{rank}.png')

        # getSequenceLogo leaves its figures open; release them per model.
        plt.close('all')