In [1]:
# Run these lines only the first time you run this notebook

# !pip install tensorflow_addons gdown shap -q
# !pip install git+https://github.com/katarinagresova/DeepExperiment

# !wget -P models https://github.com/ML-Bioinfo-CEITEC/miRBind/raw/main/Models/miRBind.h5
# !wget https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/graphs/Datasets/evaluation_set_1_1_CLASH2013_paper.tsv

In [2]:
# Modules will be reloaded every time a cell is executed, so edits to imported
# project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import tensorflow

import json
import random
import cv2 
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import register_keras_serializable
# import tensorflow_addons as tfa
from IPython.display import Image, display
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from textwrap import wrap
import shap
from shap.plots.colors import red_transparent_blue
from Bio import SeqIO

from deepexperiment.alignment import Attrament
from deepexperiment.visualization import plot_alignment, plot_miRNA_importance, plotbar_miRNA_importance
from deepexperiment.utils import one_hot_encoding, one_hot_encoding_batch
from deepexperiment.interpret import DeepShap

from funmirtar.models.constants import HELA_TRANSFACTION_DATA, PSILAC_DATA, TARGETSCAN_COLUMN_TO_SEQUENCE

  from .autonotebook import tqdm as notebook_tqdm


Params

In [4]:
# Dataset variant; interpolated into the input and output file names defined below.
DATASET_NAME = 'refseq_id.mirna_fc'
# DATASET_NAME = 'gene_name.psilac'
# DATASET_NAME = 'refseq_id.HeLa_transfection'

# Column of the sequence table that identifies a transcript; used as the dictionary key
# when mapping identifiers to sequences.
# ID_COLUMN = "refseq_mrna"
ID_COLUMN = 'RefSeq ID'

# When True, the explainability scores and the scanning errors are written to disk.
SAVE_EXPL_SCORES = True

In [5]:
# miRNA to scan for; must be a key of TARGETSCAN_COLUMN_TO_SEQUENCE (looked up below).
mirna_name = 'hsa-miR-16-5p'

In [None]:
# TS_mirnas = list(TARGETSCAN_COLUMN_TO_SEQUENCE.items())
# mirna_index = 0
# mirna_name = TS_mirnas[mirna_index][0]
# my_miRNA = TS_mirnas[mirna_index][1]

# mirna_sequences = [my_miRNA]
# mirna_sequences, mirna_name

In [6]:
# Resolve the chosen miRNA name to its sequence.
my_miRNA = TARGETSCAN_COLUMN_TO_SEQUENCE[mirna_name]

# The scanning cell iterates over this list; it currently holds a single miRNA.
mirna_sequences = [my_miRNA]
print(f"scanning for {mirna_sequences}, {mirna_name}")

['TAGCAGCACGTAAATATTGGCG'] hsa-miR-16-5p


In [None]:
# SEQUENCE_SOURCE_PATH = "data/3utr/3utr.sequences.pkl"
# SEQUENCE_SOURCE_PATH = 'data/3utr/3utr.sequences.refseq_id.mirna_fc.pkl'
# SEQUENCE_SOURCE_PATH = f'data/3utr/3utr.sequences.{DATASET_NAME}.pkl'
# SEQUENCE_SOURCE_PATH = f'data/processed/GRCh37.p13 hg19/3utr.sequences.{DATASET_NAME}.pkl'

# Pickled table of 3'UTR sequences for the selected dataset (path is relative to the notebook's working directory).
SEQUENCE_SOURCE_PATH = f'../../data/processed/GRCh37.p13 hg19/UCSC/3utr.sequences.{DATASET_NAME}.pkl'


# NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source.
# Later cells read the ID_COLUMN and "sequence" columns of this table.
sequence_source_df = pd.read_pickle(SEQUENCE_SOURCE_PATH)

In [None]:
# SAVE_EXPLAINABILITY_SCORES_PATH = "explainability_scores_{}.json".format(mirna_name)
# SAVE_EXPLAINABILITY_SCORES_PATH = "data/3utr.explainability_scores_{}.json".format(mirna_name)
# SAVE_EXPLAINABILITY_SCORES_PATH = f"data/3utr.sequences.refseq_id.mirna_fc.explainability_scores_{mirna_name}.json"
# SAVE_EXPLAINABILITY_SCORES_PATH = f"data/3utr.sequences.refseq_id.mirna_fc.explainability_scores_{mirna_name}.refseq_id.json"
# SAVE_EXPLAINABILITY_SCORES_PATH = f"debug/3utr.sequences.refseq_id.mirna_fc.explainability_scores_{mirna_name}.refseq_id.json"
# SAVE_EXPLAINABILITY_SCORES_PATH = f"data/3utr.sequences.{DATASET_NAME}.explainability_scores_{mirna_name}.json"

# SAVE_EXPLAINABILITY_SCORES_PATH = f"data/scanned/GRCh37.p13 hg19/UCSC/3utr.sequences.{DATASET_NAME}.explainability_scores_{mirna_name}.json"

# Output JSON with the per-transcript explainability scores for the selected dataset and miRNA.
SAVE_EXPLAINABILITY_SCORES_PATH = f"../../data/scanned/GRCh37.p13 hg19/UCSC/3utr.sequences.{DATASET_NAME}.explainability_scores_{mirna_name}.json"

# Output file with the transcripts that could not be scanned (written with json.dump despite the .txt suffix).
SAVE_SCANNING_ERRORS_PATH = f"../../data/scanned/GRCh37.p13 hg19/UCSC/3utr.sequences.{DATASET_NAME}.explainability_scores_{mirna_name}.scanning_errors.txt"

In [None]:
# Passed to the scoring function as `prediction_threshold`: a sliding window is explained
# only when its prediction is strictly greater than this value.
# NOTE(review): with 0, presumably (almost) every window is explained -- confirm this is intended.
PREDICTION_THRESHOLD = 0
# Seed Python's `random` module.
random.seed(42)

### Load and preprocess the transcript data

In [None]:
# Sanity check: (number of rows, number of distinct IDs, number of rows with a non-null ID).
sequence_source_df.shape[0], sequence_source_df[ID_COLUMN].nunique(), sequence_source_df[[ID_COLUMN]].dropna().shape[0]

In [None]:
# Map transcript identifier (ID_COLUMN) -> sequence.
# NOTE(review): rows sharing an identifier collapse into one dictionary entry
# (presumably the last row wins) -- confirm duplicates are not expected.
gene_symbol_to_seq = sequence_source_df[[ID_COLUMN, "sequence"]].set_index(ID_COLUMN).to_dict()['sequence']

### Get our deep learning model and miRNA data

In [None]:
"""# Loading model and the data"""

# The model file is read from the "models" directory relative to the notebook's working directory.
model = keras.models.load_model("models/miRBind.h5")   # Old model from miRBind trained on Ago1 data
# miRNA/miRNA/models/model_miRNA.h5
# model = keras.models.load_model("models/model_miRNA.h5")   # from Vasek's paper https://github.com/ML-Bioinfo-CEITEC/HybriDetector/blob/main/ML/Models/model_miRNA.h5
# model.summary()

In [None]:
# Evaluation samples; a random subset of them is used below as the SHAP background.
samples = pd.read_csv('evaluation_set_1_1_CLASH2013_paper.tsv', sep='\t')
samples.head()


### Use and evaluate the model

In [None]:
# Draw n = 50 samples (fixed random_state for reproducibility) and encode them as the
# background set that the SHAP explainer compares each input against.
rand_samples = samples.sample(n=50, replace=False, random_state=42).reset_index(drop=True)
background, _ = one_hot_encoding_batch(rand_samples)
# Notebook-level explainer; the scoring function below uses this global directly.
deepShap = DeepShap(model, background)

#### #DEBUG Set `threshold_len` to filter out longer genes

In [None]:
# Debug switch: a positive value restricts the scan to transcripts shorter than
# `threshold_len`; -1 disables the filter.
threshold_len = -1 # -1
# Upper bound from an earlier experiment; not read by any active code in the visible cells.
threshold_len_high = -1 # -1

if threshold_len > 0:
    # Count transcripts that would pass the filter. Only the "sequence" column is needed,
    # and missing sequences (NaN floats) are skipped instead of crashing `len()`.
    count = sum(
        1
        for sequence in sequence_source_df["sequence"]
        if not isinstance(sequence, float) and len(sequence) < threshold_len
    )
    print(count)

In [None]:
# if threshold_len > 0:
#     count_too_long = 0
#     count_kept = 0
#     for key,value in gene_symbol_to_seq.items():
#         if(len(value) > threshold_len):
#             gene_symbol_to_seq[key] = []
#             count_too_long += 1
#         else:
#             count_kept += 1
#     print('count_too_long, len > ', threshold_len, " : ", count_too_long)
#     print('count_kept ', count_kept)


if threshold_len > 0:
    # Rebuild the identifier -> sequence mapping, keeping only transcripts shorter than
    # `threshold_len`. Keys come from ID_COLUMN so the mapping stays consistent with the
    # unfiltered one built earlier.
    gene_symbol_to_seq = {}
    count_too_long = 0
    count_kept = 0
    for gene_id, sequence in zip(sequence_source_df[ID_COLUMN], sequence_source_df["sequence"]):
        if isinstance(sequence, float):
            # Missing sequence (NaN): pass it through so the scan loop can report it.
            gene_symbol_to_seq[gene_id] = sequence
        elif len(sequence) < threshold_len:
            gene_symbol_to_seq[gene_id] = sequence
            count_kept += 1
        else:
            count_too_long += 1
    print('count_too_long, len >= ', threshold_len, " : ", count_too_long)
    print('count_kept ', count_kept)

### Scan the transcript

In [None]:
def score_sequence_attribution_minimal_FIXED(gene, input_miRNA, model, draw_plot=False, step=10, length=50, prediction_threshold=0.5):
    """Score every position of `gene` by its SHAP attribution for binding `input_miRNA`.

    The transcript is cut into sliding windows of `length` nucleotides, one every
    `step` positions. Each window is paired with the miRNA, encoded and passed to
    `model`. Windows whose prediction (`preds[k][1]`) is strictly greater than
    `prediction_threshold` are explained with the notebook-level `deepShap`
    explainer. Their attributions are weighted by the prediction, summed per
    transcript position, reduced with a max over the miRNA axis and divided by the
    number of windows covering each position.

    NOTE(review): the explainer is the global `deepShap`, not built from the
    `model` argument -- presumably both must wrap the same network; confirm
    when swapping models.

    Args:
        gene: transcript sequence.
        input_miRNA: miRNA sequence; only its first 20 nucleotides are used.
        model: callable returning per-window predictions; index 1 of each
            prediction is compared against the threshold.
        draw_plot: when True, plot the normalized scores.
        step: distance between the starts of consecutive windows.
        length: window length. The encoder is called with a fixed (50, 20, 1)
            tensor shape, so presumably only 50 is usable -- TODO confirm.
        prediction_threshold: a window is explained only above this prediction.

    Returns:
        A 1-D numpy array with one score per transcript position, or an empty
        list when the transcript is shorter than one window or no window
        exceeds the threshold.
    """
    # QUICK FIX - miRBind takes only 20 long miRNA
    miRNA = input_miRNA[0:20]

    window_starts = list(range(0, len(gene) - length + 1, step))
    if not window_starts:
        # Transcript shorter than a single window: nothing to predict or explain,
        # so skip encoding and calling the model on an empty batch.
        return []

    # How many windows cover each transcript position.
    counts = np.zeros(len(gene))
    for window_start in window_starts:
        counts[window_start:window_start + length] += 1

    df = pd.DataFrame(
        {'miRNA': [miRNA] * len(window_starts),
         'gene': [gene[window_start:window_start + length] for window_start in window_starts],
         'label': np.zeros(len(window_starts))
        })
    data, _ = one_hot_encoding_batch(df, tensor_dim=(50, 20, 1))
    preds = model(data)

    # (index into preds/data, start position in the transcript) of every window
    # predicted above the threshold.
    positive_windows = [
        (window_index, window_start)
        for window_index, window_start in enumerate(window_starts)
        if preds[window_index][1] > prediction_threshold
    ]

    # No window exceeded prediction_threshold: report no binding signal for this transcript.
    if not positive_windows:
        return []
    shap_data = np.stack([data[window_index] for window_index, _ in positive_windows])

    _, pos_shap = deepShap(shap_data)

    attribution = np.zeros((len(gene), len(miRNA)))
    for shap_row, (window_index, window_start) in enumerate(positive_windows):
        # Weight the window's attribution by the model's confidence in that window.
        window_shap = np.asarray(pos_shap[shap_row, :, :, 0] * preds[window_index][1])
        # Add it at the window's position; same as zero-padding to the transcript
        # length and summing, without allocating a full-length matrix per window.
        attribution[window_start:window_start + length] += window_shap

    # Strongest attribution over the miRNA axis for every transcript position.
    attribution = attribution.max(axis=1)

    counts[counts == 0] = 1 # because when stepping transcript, its len might not be dividable by step and leave a few 0s at the end of counts
    normalized_scores = attribution / counts

    if draw_plot:
        #     TODO change printing into a single plot?
        plt.figure(num = random.randint(0, 1000)) # num is a unique identifier for the figure
        plt.plot(normalized_scores)
        plt.title('normalized_scores => attribution / np.array(counts)')
        plt.show()

    return normalized_scores

In [None]:
import time

# Scan every transcript for every selected miRNA and collect the per-position scores.
# two objects to collect results for now, will see which one is more handy
score_table = []            # flat rows: [miRNA, gene_symbol, score]
miRNA_to_gene_score = {}    # miRNA -> list of [gene_symbol, score]
# Every entry has the same shape: [mirna_name, miRNA, gene_symbol, reason].
explain_errors = []

start = time.time()

i = 0
for miRNA in mirna_sequences[:1]:
    miRNA_to_gene_score.setdefault(miRNA, [])
    for gene_symbol, gene_sequence in gene_symbol_to_seq.items():
        if isinstance(gene_sequence, float) or len(gene_sequence) == 0:
            # Missing (NaN) or empty sequence: nothing to scan, record it like any other failure.
            explain_errors.append([mirna_name, miRNA, gene_symbol, "missing or empty sequence"])
        else:
            try:
                score = score_sequence_attribution_minimal_FIXED(
                    str(gene_sequence),
                    miRNA,
                    model,
                    draw_plot=False,
                    step=10,
                    length=50,
                    prediction_threshold=PREDICTION_THRESHOLD
                )
                score_table.append([miRNA, gene_symbol, score])
                miRNA_to_gene_score[miRNA].append([gene_symbol, score])
            except (AssertionError, ValueError) as e:
                # Best effort: keep scanning, but say which transcript failed.
                print(gene_symbol, e)
                explain_errors.append([mirna_name, miRNA, gene_symbol, str(e)])

        # Lightweight progress report every 500 transcripts.
        i += 1
        if i % 500 == 0:
            print(gene_symbol, " |", i, "| ", end=" ")

end = time.time()

In [None]:
# Wall-clock duration of the scan cell, in seconds and in hours.
print(f'{round(end - start, 2)} seconds, {round((end - start) / 3600, 2)} hours')

In [None]:
# Number of transcripts that were offered to the scan.
len(gene_symbol_to_seq)

In [None]:
# Transcripts that could not be scored, as collected by the scan cell.
explain_errors

In [None]:
# How many scanned transcripts came back with an empty score (no window was explained).
scores_for_mirna = miRNA_to_gene_score[my_miRNA]
empties = sum(1 for _, score in scores_for_mirna if len(score) == 0)
print(empties, ' / ', len(scores_for_mirna))

#### Save the explainability scoring to a file

In [None]:
# Container types of the result mapping and of the entry for the scanned miRNA.
print(type(miRNA_to_gene_score), type(miRNA_to_gene_score[my_miRNA]))

In [None]:
# Destination of the scores JSON.
SAVE_EXPLAINABILITY_SCORES_PATH

In [None]:
# Build the JSON payload first: score arrays become plain Python lists, so numpy scalars
# never reach json.dump. This sits outside the SAVE_EXPL_SCORES guard because a later cell
# inspects `data_to_save`, and before open() so a conversion error cannot leave a truncated file.
data_to_save = {
    key: [[sub_key, np.asarray(sub_val).tolist()] for sub_key, sub_val in gene_scores]
    for key, gene_scores in miRNA_to_gene_score.items()
}

if SAVE_EXPL_SCORES:
    with open(SAVE_EXPLAINABILITY_SCORES_PATH, 'w') as file:
        json.dump(data_to_save, file, indent=4)

In [None]:
# miRNA sequences present in the payload prepared for saving.
data_to_save.keys()

In [None]:
# Persist the list of transcripts that could not be scanned, serialised as JSON.
if SAVE_EXPL_SCORES:
    with open(SAVE_SCANNING_ERRORS_PATH, 'w') as error_file:
        error_file.write(json.dumps(explain_errors))

In [None]:
# Destination of the scanning-errors file.
SAVE_SCANNING_ERRORS_PATH