In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")
import os
os.chdir("..")
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from tensorflow.keras.models import load_model, Model
from tensorflow.keras import Input
import src.quad_model
from src.vis_data import get_vis_data
from joblib import load

In [2]:
def sum_strength(d):
    if "strength" in d.keys():
        return d["strength"]
    return sum([sum_strength(child) for child in d["children"]])

In [5]:
# Load exons
exons = [
    "GAGUCCCGCUUACCAUUGCAUUUAAGAAAGAGGCCAUACGCCGCUAAGACCCUACUCUUCAGAAUACCAG",
    "GAGUCCCGCUUACCAUUGCAUUUAAGAAAGCGGCCAUACGCCGCUAAGACCCUACUCUUCAGAAUACCAG",
    "GUCUGACAGUACUACGCUAAUACUACGUAAACCAAAGCCAUAAUCCAAUUGACCUCCUUUUCAGGAAUUC",
    "CCUUCCACGCCUCUCCCACUCGUUACACUCAGUUGCAGUAUGGUUAACACUCCACUAGGCCCCAGGAAUC",
    "GAGUCCCGCUUACCAUUGCAUUUAAGAAAGAGGCCAUACGCCUCUAAGACCCUACUCUUCAGAAUACCAG",
    "GAGUCCCGCUUACCAUUGCAUUUAAGAAAGAGGCCAUACGCCGCUAAGACCCUACUCUUCAGAAUACCAG",
]

In [6]:
# Get vis data
json_data = [None]*len(exons)
i = 0
for exon in tqdm(exons):
    json_data[i] = get_vis_data(exon=exon, threshold=0.001)
    i += 1

  0%|          | 0/6 [00:00<?, ?it/s]

 33%|███▎      | 2/6 [00:01<00:02,  1.93it/s]



100%|██████████| 6/6 [00:02<00:00,  2.26it/s]


In [34]:
def get_visualization_statistic(json_data):
    #` Min and max of delta strength
    print("Predicted range:", end=" ")
    print((
        min(exon_data["predicted_psi"] for exon_data in json_data), 
        max(exon_data["predicted_psi"] for exon_data in json_data)
    ))

    #` Min and max of delta strength
    print("Delta strenth range:", end=" ")
    print((
        min(exon_data["delta_force"] for exon_data in json_data), 
        max(exon_data["delta_force"] for exon_data in json_data)
    ))

    # Max of class strength
    print("Max class strength:", end=" ")
    print(max(
        max([exon_data["incl_strength"] for exon_data in json_data]),
        max([exon_data["skip_strength"] for exon_data in json_data])
    ))

    # Max of feature strength
    max_feature_strength = 0
    for exon_data in json_data:
        for class_feature_activations in exon_data["feature_activations"]["children"]:
            for feature_strength in class_feature_activations["children"]:
                max_feature_strength = max(
                    max_feature_strength, sum_strength(feature_strength)
                )
    print("Max feature strength:", max_feature_strength)

    # Max of feature position strength
    max_feature_position_strength = 0
    for exon_data in json_data:
        for class_feature_activations in exon_data["feature_activations"]["children"]:
            for feature_strength in class_feature_activations["children"]:
                if "children" not in feature_strength.keys():
                    continue
                for feature_position_strength in feature_strength["children"]:
                    max_feature_position_strength = max(
                        max_feature_position_strength, feature_position_strength["strength"]
                    )
    print("Max feature position strength:", max_feature_position_strength)

    # Max of nucleotide position strength
    max_nucleotide_position_strength = 0
    for exon_data in json_data:
        for class_nucleotide_activations in exon_data["nucleotide_activations"]["children"]:
            for nucleotide_strength in class_nucleotide_activations["children"]:
                for nucleotide_position_strength in nucleotide_strength["children"]:
                    max_nucleotide_position_strength = max(
                        max_nucleotide_position_strength, sum_strength(nucleotide_position_strength)
                    )
    print("Max nucleotide position strength:", max_nucleotide_position_strength)

    # Max of nucleotide position feature strength
    max_nucleotide_position_feature_strength = 0
    for exon_data in json_data:
        for class_nucleotide_activations in exon_data["nucleotide_activations"]["children"]:
            for nucleotide_strength in class_nucleotide_activations["children"]:
                for nucleotide_position_strength in nucleotide_strength["children"]:
                    for nucleotide_position_feature_strength in nucleotide_position_strength["children"]:
                        max_nucleotide_position_feature_strength = max(
                            max_nucleotide_position_feature_strength, 
                            nucleotide_position_feature_strength["strength"]
                        )
    print("Max nucleotide position feature strength:", max_nucleotide_position_feature_strength)

In [7]:
get_visualization_statistic(json_data)

Delta strenth range: (-13.908240915963916, 22.812916158010694)
Max class strength: 86.02241516113281
Max feature strength: 37.52738878960292
Max feature position strength: 11.806382848973044
Max nucleotide position strength: 3.7882057132961697
Max nucleotide position feature strength: 3.7689620256641363


In [4]:
xTe = load(f"data/xTe_ES7_HeLa_ABC.pkl.gz")

In [31]:
nts = ["A", "C", "G", "T"]
xTe_seqs = np.array([
    "".join([nts[np.where(one_hot == 1)[0].item()] for one_hot in row[10:80]]) for row in tqdm(xTe[0])
])
exons = xTe_seqs[np.random.rand(len(xTe[0])) < 0.01]

100%|██████████| 47962/47962 [00:04<00:00, 11778.84it/s]


In [32]:
json_data = [None]*len(exons)
i = 0
for exon in tqdm(exons):
    json_data[i] = get_vis_data(exon=exon, threshold=0.001, use_new_grouping=True)
    i += 1

100%|██████████| 477/477 [04:18<00:00,  1.84it/s]


In [32]:
get_visualization_statistic(json_data)

Predicted range: (6.1262649069249164e-06, 0.9865574836730957)
Delta strenth range: (-99.51547968549517, 91.05603635149214)
Max class strength: 140.66867065429688
Max feature strength: 65.12019474210683
Max feature position strength: 16.47538328600893
Max nucleotide position strength: 8.657210330939556
Max nucleotide position feature strength: 5.491794428669643


In [35]:
get_visualization_statistic(json_data)

Predicted range: (1.0487151484994683e-05, 0.9877551794052124)
Delta strenth range: (-93.7439914957979, 105.08456075029585)
Max class strength: 142.6853297932646
Max feature strength: 71.3665307011663
Max feature position strength: 17.437657018910976
Max nucleotide position strength: 7.456638666000308
Max nucleotide position feature strength: 5.812552339636992


In [38]:
# Max class strength
for exon_data in json_data:
    if exon_data["incl_strength"] > 140:
        print(exon_data["exon"])
        break

for exon_data in json_data:
    if exon_data["skip_strength"] > 140:
        print(exon_data["exon"])
        break

CCGCGACCGGAUUAAGAUGAAGGAACGAAGCAAUUGUCGAAUCUACUCUAAUCUGCUCGAAGAUCAGAAC
AGAUGUCGAUCCCCAUUAAUCAACCCCUCUCCUUAUAUUAUCCCCAUAUUCACAAAACUGUUUGCUAAAA


In [44]:
# Max of feature strength
for exon_data in json_data:
    for class_feature_activations in exon_data["feature_activations"]["children"][:1]:
        for feature_strength in class_feature_activations["children"]:
            if sum_strength(feature_strength) > 61:
                print(exon_data["exon"])
                break

for exon_data in json_data:
    for class_feature_activations in exon_data["feature_activations"]["children"][-1:]:
        for feature_strength in class_feature_activations["children"]:
            if sum_strength(feature_strength) > 65:
                print(exon_data["exon"])
                break

CCGCGACCGGAUUAAGAUGAAGGAACGAAGCAAUUGUCGAAUCUACUCUAAUCUGCUCGAAGAUCAGAAC
CCACUCACCGCCGCCGGUGUCCUGGCAUACUCAUUAUCGCAACCCCGACGCGGCCCACUUGGGUCGCGGC


In [52]:
# Max of feature position strength
for exon_data in json_data:
    for class_feature_activations in exon_data["feature_activations"]["children"][:1]:
        for feature_strength in class_feature_activations["children"]:
            if "children" not in feature_strength.keys():
                continue
            for feature_position_strength in feature_strength["children"]:
                if feature_position_strength["strength"] > 13:
                    print(exon_data["exon"])
for exon_data in json_data:
    for class_feature_activations in exon_data["feature_activations"]["children"][-1:]:
        for feature_strength in class_feature_activations["children"]:
            if "children" not in feature_strength.keys():
                continue
            for feature_position_strength in feature_strength["children"]:
                if feature_position_strength["strength"] > 17:
                    print(exon_data["exon"])

GUCUACUGCGUCACACACAGCGACCCUAAACGAACAACCCCUACGUGAAAGUUCAUCGACGACCGCGCCA
CCCUAACACAACGUACAACAAUCCAAUAACCAUGAAUAUUGGCCUAACACGGUCACCAAGCUCGUCGGUU
AGUCUCUUGGAAUCGCGCCCGACAUCUUACCAGUAAAAUCGGUGCUCCAGGGCCACGAUCUUCGACACCA
GCUCGCAACCAGCCGCCUACCUAUUAAUUGUCUGUGCUCCAAGAAUUACAGCUAGCAAUUUAGGUACCAA


In [45]:
# Max of nucleotide position strength
for exon_data in json_data:
    for class_nucleotide_activations in exon_data["nucleotide_activations"]["children"]:
        for nucleotide_strength in class_nucleotide_activations["children"]:
            for nucleotide_position_strength in nucleotide_strength["children"]:
                if sum_strength(nucleotide_position_strength) > 17:
                    print(exon_data["exon"])

Delta strenth range: (-71.34475481672075, 101.65916478471968)       # (-120, 120)
Max class strength: 136.1594737751982                               # 160
Max feature strength: 61.950183892971836                            # 70
Max feature position strength: 15.874034881907232                   # 20
Max nucleotide position strength: 7.857531867787869                 # 10
Max nucleotide position feature strength: 5.3906272675376385        # 6

In [4]:
# Load pretrained model
model_file_name = f"model/custom_adjacency_regularizer_20210731_124_step3.h5"
model = load_model(model_file_name)

output_model = Model(outputs=model.outputs, inputs=[
    model.get_layer('activity_regularization').input,
    model.get_layer('activity_regularization_1').input,
    model.get_layer('activity_regularization_2').input,
    model.get_layer('activity_regularization_3').input
])

In [5]:
def delta_force_2_predicted_psi(delta_force):
    act = np.zeros((1, 85, 20))
    act_1 = np.zeros((1, 85, 20))
    incl_act = np.zeros((1, 85, 28))
    skip_act = np.zeros((1, 85, 28))
    incl_act[0,0,0] = delta_force - 21.80809438066695
    yTr_hat = np.reshape(output_model([act, act_1, incl_act, skip_act]), (-1,))
    return yTr_hat.item()

In [8]:
for delta_force in [-100, -20, -10, -5, 0, 5, 10, 20, 100]:
    print(f"Detla force = {delta_force: 4} | Predicted PSI = {delta_force_2_predicted_psi(delta_force):.6f}")

Detla force = -100 | Predicted PSI = 0.000006
Detla force =  -20 | Predicted PSI = 0.071174
Detla force =  -10 | Predicted PSI = 0.232515
Detla force =   -5 | Predicted PSI = 0.361840
Detla force =    0 | Predicted PSI = 0.499993
Detla force =    5 | Predicted PSI = 0.638148
Detla force =   10 | Predicted PSI = 0.756705
Detla force =   20 | Predicted PSI = 0.892360
Detla force =  100 | Predicted PSI = 0.987334


In [21]:
def get_link_midpoint(link_function, midpoint=0.5, epsilon=1e-5, lb=-100, ub=100, max_iters=50):
    """
    Assumes monotonicity and smoothness of link function
    """
    iters = 0
    while iters < max_iters:
        xx = np.linspace(lb, ub, 1000)
        yy = link_function(xx[:, None]).numpy().flatten()

        if min(np.abs(yy - midpoint)) < epsilon:
            return xx[np.abs(yy - midpoint) < epsilon][0]
        lb_idx = np.where((yy - midpoint) < 0)[0][-1]
        ub_idx = np.where((yy - midpoint) > 0)[0][0]

        lb = xx[lb_idx]
        ub = xx[ub_idx]

        iters += 1
    raise RuntimeError(f"Max iterations ({max_iters}) reached without solution...")

def get_model_midpoint(model, midpoint=0.5):    
    """ 
    Compute the midpoint using the model"s link function. This is the negation of the basal strength. 
    I.e., positive value corresponds to a skipping basal strength. 
    """
    link_input = Input(shape=(1,))
    w = model.get_layer("energy_seq_struct").w.numpy()
    b = model.get_layer("energy_seq_struct").b.numpy()
    link_output = model.get_layer("output_activation")(model.get_layer("gen_func")(w*link_input + b))
    link_function = Model(inputs=link_input, outputs=link_output)
    return get_link_midpoint(link_function, midpoint)

def predicted_psi_2_delta_force(predicted_psi):
    return get_model_midpoint(model, midpoint=predicted_psi) + 21.80809438066695

In [30]:
for predicted_psi in np.linspace(20/240, 1-20/240, 5):
    print(f"Predicted PSI = {predicted_psi:.4f} | Detla force = {predicted_psi_2_delta_force(predicted_psi): 4}")

Predicted PSI = 0.0833 | Detla force = -18.757496234973715
Predicted PSI = 0.2917 | Detla force = -7.767928088248407
Predicted PSI = 0.5000 | Detla force =  0.0
Predicted PSI = 0.7083 | Detla force =  7.819431042654276
Predicted PSI = 0.9167 | Detla force =  23.07392477562648
