In [5]:
# Run these lines only the first time you run this notebook

!pip install tensorflow_addons gdown shap -q
!pip install git+https://github.com/katarinagresova/DeepExperiment

!wget https://github.com/ML-Bioinfo-CEITEC/miRBind/raw/main/Models/miRBind.h5
!wget https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/graphs/Datasets/evaluation_set_1_1_CLASH2013_paper.tsv

Collecting git+https://github.com/katarinagresova/DeepExperiment
  Cloning https://github.com/katarinagresova/DeepExperiment to /tmp/pip-req-build-tmes3tda
  Running command git clone --filter=blob:none --quiet https://github.com/katarinagresova/DeepExperiment /tmp/pip-req-build-tmes3tda
  Resolved https://github.com/katarinagresova/DeepExperiment to commit fe5a48e3057cc8635c99bfa7828e6da0c2190bca
  Preparing metadata (setup.py) ... [?25ldone
--2023-06-13 14:08:00--  https://github.com/ML-Bioinfo-CEITEC/miRBind/raw/main/Models/miRBind.h5
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/main/Models/miRBind.h5 [following]
--2023-06-13 14:08:00--  https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/main/Models/miRBind.h5
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:

In [6]:
import tensorflow

In [7]:
tensorflow.__version__

'2.11.0'

### Load and preprocess the transcript data

In [8]:
from Bio import SeqIO
import re

def extract_gene_symbol(record):
    tmp = [re.findall("\((.*?)\)", x) for x in record.description.split()]
    while([] in tmp):
        tmp.remove([])
    return tmp[0][0]

gene_symbol_to_seq = {}
file_path = 'data/GCF_000001405.40/rna.fna'

for record in SeqIO.parse(file_path, "fasta"):
    description = record.description.split()
    gene_id = next((item.split('=')[1] for item in description if item.startswith('gene=')), None)
    
    symbol = extract_gene_symbol(record)
    # Do not need the locus tag (== record.id) for now
    # gene_symbol_to_seq[symbol] = [record.seq, record.id]
    gene_symbol_to_seq[symbol] = record.seq    

FileNotFoundError: [Errno 2] No such file or directory: 'data/GCF_000001405.40/rna.fna'

In [None]:
len(gene_symbol_to_seq.items())

### Get our deep learning model and miRNA data

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import register_keras_serializable
import tensorflow_addons as tfa
import cv2 
from IPython.display import Image, display
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from textwrap import wrap
import random
import shap
from shap.plots.colors import red_transparent_blue

from deepexperiment.alignment import Attrament
from deepexperiment.visualization import plot_alignment, plot_miRNA_importance, plotbar_miRNA_importance
from deepexperiment.utils import one_hot_encoding, one_hot_encoding_batch
from deepexperiment.interpret import DeepShap

In [None]:
random.seed(42)

In [None]:
"""# Loading model and the data"""

model = keras.models.load_model("miRBind.h5")
# model.summary()

samples = pd.read_csv('evaluation_set_1_1_CLASH2013_paper.tsv', sep='\t')
samples.head()

# get the miRNAs
#     CHANGE choose the miRNA
# miRNA = samples['miRNA'].value_counts().index[6]
# miRNA


### Use and evaluate the model

In [None]:
# n = 50 is number of samples used as a background image to compare the input with during the shap method
rand_samples = samples.sample(n=50, replace=False, random_state=42).reset_index(drop=True)
background, _ = one_hot_encoding_batch(rand_samples)
deepShap = DeepShap(model, background)

In [None]:
# from numpy import random

def score_sequence_attribution_minimal(gene, input_miRNA, model, draw_plot=False, step=10, length=50):
    # TODO quick fix: tensor_dim argument in one_hot_encoding_batch() and the whole model expects RNA of len 20
    miRNA = input_miRNA[0:20]
    
    miRNAs = []
    genes = []

    for i in range(0, len(gene) - length, step):
        start = max(i, 0)
        end = min(i+50, len(gene))
        miRNAs.append(miRNA)
        genes.append(gene[start:end])

    labels = np.zeros(len(genes))

    df = pd.DataFrame(
        {'miRNA': miRNAs,
         'gene': genes,
         'label': labels
        })
    data, _ = one_hot_encoding_batch(df, tensor_dim=(50, 20, 1))
    preds = model(data)

    attribution = np.zeros((len(gene), len(miRNA)))
    shap_indices = []
    pred_indices = []
    preds_indices = []
    shap_data = []

  # counter == index to predictions for windows
    counter = 0
    for i in range(0, len(gene) - length, step):
        if preds[counter][1] > 0.5:
            # shap_indices == indices of positively predicted sliding windows (indexes preds array)
            shap_indices.append(counter)
            pred_indices.append(i)
            shap_data.append(data[counter])
        counter+=1
    
    # TODO quick fix: if at this point the 'shap_data' is empty array [], it means no prediction was > 0.5, therefore (for now) let's predict 0 binding affinity  
    if len(shap_data) == 0:
        return []
    
    shap_data = np.stack(shap_data)
    # TODO: batching?
    # using only pos_shap cos we only care about positives, pos_shap and neg_shap are opposites for binary classification
    neg_shap, pos_shap = deepShap(shap_data)

    for i in range(0, len(shap_indices)):

        normalized_shap = pos_shap[i,:,:,0] * preds[shap_indices[i]][1]

        # padding current shap-value-matrix to sum it up to the cumulative attribution matrix
        newrows = np.zeros((pred_indices[i], normalized_shap.shape[1]))
        normalized_shap = np.vstack([newrows, normalized_shap])
        newrows = np.zeros((len(gene) - pred_indices[i] - length, normalized_shap.shape[1]))
        normalized_shap = np.vstack([normalized_shap, newrows])

        attribution += normalized_shap

    attribution = attribution.T.max(axis=0)
    # Fix `counts = ...` for genes shorter than 98 nucleotides
    # pad the `attribution` or truncate `counts`
    counts = list(range(1, 50)) + (len(attribution) - 98) * [50] + list(range(49, 0, -1))
    if len(attribution) < 98:
        counts = counts[:len(attribution)]
        # plt.figure(num = random.randint(0, 1000)) # num is a unique identifier for the figure
        # plt.plot(counts)
        # plt.plot(attribution * 1000)
        # plt.show()
    
    try:            
        normalized_scores = attribution / np.array(counts)
    except Exception as e:
        print(e)
        print(miRNA)
        print(gene)
        print(len(attribution), len(np.array(counts)))
        normalized_scores = 0
    

    if draw_plot:
        #     TODO change printing into a single plot?
        plt.figure(num = random.randint(0, 1000)) # num is a unique identifier for the figure
        plt.plot(normalized_scores);
    
    return normalized_scores

### Get the scores for FC comparison using files from Bartel  

In [None]:
mirna_FCs = pd.read_csv('modules/evaluation/mirna_fcs.csv',index_col=0, header=0, sep=',')

In [None]:
mirna_FCs

In [None]:
def rna_to_dna(rna_sample):
    rna_dic = {
        'A':'A',
        'C':'C',
        'U':'T',
        'G':'G',
    }
    converted = [rna_dic[x.upper()] for x in rna_sample]
    new = ""
    # traverse in the string
    for x in converted:
        new += x
    # return string
    return new

In [None]:
mirna_sequences = ['UAGCAGCACGUAAAUAUUGGCG', 'UAAAGUGCUGACAGUGCAGAU', 'UAACACUGUCUGGUAACGAUGU', 'UAAUACUGCCUGGUAAUGAUGA', 'AUGACCUAUGAAUUGACAGAC', 'UGAGGUAGUAGGUUGUAUGGUU', 'AGCAGCAUUGUACAGGGCUAUGA']
mirna_sequences = [rna_to_dna(x) for x in mirna_sequences]
print(mirna_sequences)

In [None]:
miRNA_names = ['hsa-miR-16-5p', 'hsa-miR-106b-5p', 'hsa-miR-200a-3p', 'hsa-miR-200b-3p', 'hsa-miR-215-5p', 'hsa-let-7c-5p', 'hsa-miR-103a-3p']
miRNA_name_to_seq = {}
for i in range(len(miRNA_names)):
    miRNA_name_to_seq[miRNA_names[i]] = mirna_sequences[i]
miRNA_name_to_seq

#### Only for 1 miRNA now 

In [None]:
mirna_sequences = ['hsa-miR-106b-5p']
mirna_sequences
mirna_name = 'hsa-miR-106b-5p'
# my_miRNA = 'TAGCAGCACGTAAATATTGGCG'
my_miRNA = 'TAAAGTGCTGACAGTGCAGAT'

In [None]:
gene_symbols = mirna_FCs[mirna_FCs[mirna_name].isnull() == False]['Gene symbol'].values
print(len(gene_symbols))

#### Set `threshold_len` to filter out longer genes

In [None]:
# threshold_len = 0

# if threshold_len > 0:
#     count = 0
#     for key,value in gene_symbol_to_seq.items():
#         if(len(value) < threshold_len):
#             count += 1
#     count

In [42]:
# if threshold_len > 0:
#     count_too_long = 0
#     count_kept = 0
#     for key,value in gene_symbol_to_seq.items():
#         if(len(value) > threshold_len):
#             gene_symbol_to_seq[key] = []
#             count_too_long += 1
#         else:
#             count_kept += 1
#     print('count_too_long, len > ', threshold_len, " : ", count_too_long)
#     print('count_kept ', count_kept)

### Scan the transcript

In [None]:
import time

# two objects to collect results for now, will see which one is more handy
score_table = []
miRNA_to_gene_score = {}
explain_errors = []

start = time.time()

i = 0
for miRNA in mirna_sequences[:1]:
    print(miRNA)
    if miRNA not in miRNA_to_gene_score:
        miRNA_to_gene_score[miRNA] = []
    for gene_symbol, gene_sequence in gene_symbol_to_seq.items():
        if len(gene_sequence) > 0:
            # print(len(value), len(miRNA))
            if len(miRNA) > 22:
                print(miRNA)
            try:
                score = score_sequence_attribution_minimal(gene_sequence, miRNA, model, draw_plot=False)
                score_table.append([miRNA, gene_symbol, score])
                miRNA_to_gene_score[miRNA].append([gene_symbol, score])
            except AssertionError as e:
                print(miRNA, gene_symbol, e)
                explain_errors.append([miRNA, gene_symbol, e])
                
        else:
            explain_errors.append(gene_symbol)
        
        i+=1
        if i % 100 == 0:
            print(gene_symbol, " || " , end =" ")
        # if i > 1000:
        #     break
        
            
end = time.time()
print(end - start, ' seconds')

hsa-miR-106b-5p


In [None]:
score_table[:3]

In [None]:
empties = 0
for gene_n_score in miRNA_to_gene_score[my_miRNA]:
    if len(gene_n_score[1]) == 0:
        empties += 1
print(empties, ' / ', len(miRNA_to_gene_score[my_miRNA]))

#### Save the explainability scoring to a file

In [None]:
for key, value in miRNA_name_to_seq.items():
    if value == my_miRNA:
        print(key)
        my_miRNA_name = key

In [None]:
print(type(miRNA_to_gene_score), type(miRNA_to_gene_score[my_miRNA]))

In [None]:
import json

save_scores_path = "explainability_scores_{}.json".format(my_miRNA_name)

with open(save_scores_path, 'w') as file:
    data_to_save = {}
    for key in miRNA_to_gene_score.keys():
        data_to_save[key] = [[sub_key, list(sub_val)] for sub_key, sub_val in miRNA_to_gene_score[key]]

    json.dump(data_to_save, file, indent=4)

#### Load the explainability scoring

In [None]:
load_scores_path = save_scores_path

with open(load_scores_path, 'r') as file:
    miRNA_to_gene_score_loaded = json.load(file)

In [None]:
print(type(miRNA_to_gene_score_loaded), type(miRNA_to_gene_score_loaded[my_miRNA]))

### Plot

In [None]:
# gene_symbol_to_score = {}

# for i in miRNA_to_gene_score_loaded[my_miRNA]:
#     gene_symbol = i[0]
#     score = i[1]
#     gene_symbol_to_score[gene_symbol] = score

In [None]:
from modules.collect_scores.bymean import bymeanfunc
from modules.collect_scores.bysum import bysumfunc
from modules.collect_scores.bymaxvalue import bymaxvaluefunc
import math

# from modules.collect_scores.simple_methods import bymaxvaluefunc, bysumfunc, bymeanfunc

gene_score=[]
# for gene,score in gene_to_score.items():
for gene_, score_signal in miRNA_to_gene_score_loaded[my_miRNA]:
    if len(score_signal) > 0:
        prediction = bymaxvaluefunc(score_signal)
        gene_score.append([gene_symbol, prediction])
    else:
        gene_score.append([gene_symbol, 0])
        
        # "-1 * x" because we compare with fold change where lower == more impact
        # prediction = -1 * bymaxvaluefunc(gene_to_score[gene])
        # prediction = math.log(bymaxvaluefunc(gene_to_score[gene]))
        
        # gene : [real, prediction] 
        # gene_to_comparison[gene] = [gene_to_fc[gene], prediction]
        

In [None]:
results_df = pd.DataFrame(gene_score, columns=["Gene Symbol","score"])
results_df

#### (prediction) for genes
X axis is different genes, Y axis is fold change

In [None]:
import pandas as pd
import plotly.graph_objs as go
import matplotlib.pyplot as plt

###Importing dummy data
#Predictions
targetscan = pd.read_csv('Predicted_Targets_Context_Scores.default_predictions.txt',index_col=0, header=0, sep='\t')
targetscan = targetscan[["context++ score","weighted context++ score","miRNA","Gene Symbol"]]
targetscan = targetscan[targetscan['miRNA'] == "hsa-miR-215-5p"]
#Expression
mirna_FCs = pd.read_csv('mirna_fcs.csv',index_col=0, header=0, sep=',')
mirna = my_miRNA_name


predictions = {}

predictions["context++ score"] = targetscan[["context++ score","miRNA","Gene Symbol"]]
predictions["context++ score"] = predictions["context++ score"].rename(columns={'context++ score': 'score'})
predictions["weighted context++ score"] = targetscan[["weighted context++ score","miRNA","Gene Symbol"]]
predictions["weighted context++ score"] = predictions["weighted context++ score"].rename(columns={'weighted context++ score': 'score'})


In [None]:
predictions["context++ score"]

In [None]:
merged_df = pd.merge(predictions["context++ score"],
                             mirna_FCs[["Gene symbol",mirna]],
                             left_on='Gene Symbol',
                             right_on='Gene symbol',
                             how='left').sort_values("hsa-miR-16-5p",
                                                     ascending=True).dropna()
merged_df

In [None]:
results_df_merged = pd.merge(results_df,
                             mirna_FCs[["Gene symbol",mirna]],
                             left_on='Gene Symbol',
                             right_on='Gene symbol',
                             how='left').sort_values("hsa-miR-16-5p",
                                                     ascending=True)
                                                     # ascending=True).dropna()

results_df_merged

In [None]:
mirna_FCs

In [None]:
results = {}
for algorithm in predictions:
        results[algorithm] = pd.DataFrame(columns=["Targets","Mean_mRNA_FC"])
        merged_df = pd.merge(predictions[algorithm],
                             mirna_FCs[["Gene symbol",mirna]],
                             left_on='Gene Symbol',
                             right_on='Gene symbol',
                             how='left').sort_values("score",
                                                     ascending=True).dropna()
        
        
        for i,targets in enumerate([2**n for n in range(2, 11)]):
            tmp = merged_df.head(targets)
            meanFC = tmp[mirna].mean()
            results[algorithm].loc[i,"Targets"] = targets
            results[algorithm].loc[i,"Mean_mRNA_FC"] = meanFC

In [None]:
fig = go.Figure()

for algorithm in results:
    fig.add_trace(go.Scatter(x=results[algorithm]['Targets'].index.tolist(),
                                 y=results[algorithm]['Mean_mRNA_FC'],
                                 mode='lines+markers',
                                 name=algorithm)
                                 )


    # create the layout for the plot
fig.update_layout(xaxis=dict(tickvals=results[algorithm]['Targets'].index.tolist(),
                             ticktext=results[algorithm]['Targets'],
                             title = 'Top predicted targets'),
                  yaxis=dict(title='Mean_mRNA_FC'),
                  title=f'Benchmark on {mirna} dataset')
    
    # show the plot
fig.show()
# iplot(fig)               # use plotly.offline.iplot for offline plot

### Compare with FC

In [None]:
def get_scores_by_miRNA(miRNA, score_table):
    result = {}
    for row in score_table:
        if row[0] == miRNA:
            result[row[1]] = row[2] 
    return result

In [None]:
miRNA_chosen = score_table[0][0]
gene_to_score = get_scores_by_miRNA(
    miRNA_chosen, 
    score_table
)

miRNA_chosen_name = [i for i in miRNA_name_to_seq if miRNA_name_to_seq[i]==miRNA_chosen]
miRNA_chosen_name = miRNA_chosen_name[0]
print('scores for miRNA: ', miRNA_chosen_name, " : ", miRNA_chosen)

In [None]:
# genes_for_miRNA = mirna_FCs[mirna_FCs[miRNA_chosen_name].isnull() == False]['Gene symbol'].values

In [None]:
fc_for_miRNA = mirna_FCs[mirna_FCs[miRNA_chosen_name].isnull() == False][['Gene symbol', miRNA_chosen_name]].values

In [None]:
fc_for_miRNA

In [None]:
gene_to_fc = {fc_for_miRNA[i][0] : fc_for_miRNA[i][1] for i in range(len(fc_for_miRNA))}

In [None]:
gene_to_comparison

In [None]:
import numpy as np

predictions = [x[1] for x in gene_to_comparison.values()]

#normalize all values to be between 0 and 1
predictions_norm = (predictions-np.min(predictions))/(np.max(predictions)-np.min(predictions))
predictions_norm

In [None]:
pred_min = np.min(predictions)
pred_max = np.max(predictions)

gene_to_comparison_norm = {}
for key,value in gene_to_comparison.items():
    norm_value = (value[1]-pred_min)/(pred_max-pred_min)
    # if norm_value == 0:
    #     norm_value = math.log(0.001)
    # else:
    #     norm_value = math.log(norm_value)
    gene_to_comparison_norm[key] = [value[0], norm_value]

In [None]:
gene_to_comparison_norm

#### (FC and prediction) for genes
X axis is different genes, Y axis is fold change

In [None]:
fc_vis = [x[0] for x in list(gene_to_comparison.values())]
preds_vis = [x[1] for x in list(gene_to_comparison.values())]
print(fc_vis, preds_vis)

In [None]:
plt.figure(num = random.randint(0, 1000)) # num is a unique identifier for the figure
plt.scatter(fc_vis, preds_vis)

In [None]:
predictions = [x[1] for x in gene_to_comparison.values()]


plt.figure(num = random.randint(0, 1000)) # num is a unique identifier for the figure
plt.plot(predictions);