In [None]:
import pandas as pd
from utils.rna_ss_utils import one_idx2arr, sort_pairs, LocalStructureParser, make_target_pixel_bb, one_idx2arr, arr2db
from utils.inference_s1 import Predictor, Evaluator
from utils.util_global_struct import process_bb_old_to_new

In [None]:
import numpy as np
import torch
from utils.inference_s1 import DataEncoder

In [None]:
import dgutils.pandas as dgp

In [None]:
from scipy.special import softmax
from scipy.stats import entropy

In [None]:
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [None]:
def filter_by_n_proposal(df_bb, threshold):
    """Keep bounding boxes whose area-normalised proposal count exceeds ``threshold``.

    For every row, the number of entries in ``prob_other_sm`` and in
    ``prob_other_sl`` is divided by the box area (``siz_x * siz_y``) and stored
    in ``n_proposal_norm_sm`` / ``n_proposal_norm_sl``. A row is kept when
    either normalised count is strictly greater than ``threshold``.

    Args:
        df_bb: DataFrame with ``siz_x`` and ``siz_y`` columns and, optionally,
            ``prob_other_sm`` and/or ``prob_other_sl`` columns holding one sized
            object (e.g. a list) per row.
        threshold: exclusive lower bound on the normalised proposal count.

    Returns:
        ``df_bb`` itself when it has no rows; otherwise a filtered copy with the
        ``n_proposal_norm_*`` columns (and any missing ``prob_other_*`` column)
        added. The caller's frame is not modified.
    """
    if len(df_bb) == 0:
        return df_bb

    df_bb = df_bb.copy()  # never write the helper columns into the caller's frame

    # Handle cases where there's only softmax predicted or scalar predicted:
    # the missing column becomes a column of empty lists (zero proposals).
    # The created column must carry the same name that is checked and later
    # read ('prob_other_*'); creating 'prob_sm'/'prob_sl' left it missing.
    for prob_col in ('prob_other_sm', 'prob_other_sl'):
        if prob_col not in df_bb.columns:
            df_bb[prob_col] = [[] for _ in range(len(df_bb))]

    for prob_col, norm_col in (('prob_other_sm', 'n_proposal_norm_sm'),
                               ('prob_other_sl', 'n_proposal_norm_sl')):
        df_bb[norm_col] = [
            len(probs) / float(siz_x * siz_y)
            for probs, siz_x, siz_y in zip(df_bb[prob_col], df_bb['siz_x'], df_bb['siz_y'])
        ]

    keep = (df_bb['n_proposal_norm_sm'] > threshold) | (df_bb['n_proposal_norm_sl'] > threshold)
    return df_bb[keep]


In [None]:
def pred_threshold_on_n_proposal(seq, predictor, threshold):
    """Predict bounding boxes for ``seq`` and keep those with enough proposals.

    ``predictor.predict_bb`` is called with ``threshold=0``, ``topk=1`` and
    ``perc_cutoff=0``; each of its three results is turned into a DataFrame and
    filtered with ``filter_by_n_proposal``.

    Returns:
        Tuple ``(stems, iloops, hloops)`` of filtered DataFrames.
    """
    raw_stems, raw_iloops, raw_hloops = predictor.predict_bb(seq, threshold=0, topk=1, perc_cutoff=0)
    df_stems, df_iloops, df_hloops = (
        pd.DataFrame(raw_stems),
        pd.DataFrame(raw_iloops),
        pd.DataFrame(raw_hloops),
    )
    return (
        filter_by_n_proposal(df_stems, threshold),
        filter_by_n_proposal(df_iloops, threshold),
        # /2 threshold due to /2 upper bound
        filter_by_n_proposal(df_hloops, threshold / 2),
    )

In [None]:
# Load the test-set DataFrame; later cells read its 'seq', 'one_idx' and
# 'bounding_boxes' columns.
# NOTE: unpickling can execute arbitrary code - only load this file from a trusted source.
df = pd.read_pickle('../2021_03_16/data/human_transcriptome_segment_high_mfe_freq_testing_len64_100.pkl.gz')

In [None]:
# Checkpoint of the stage-1 model used for every prediction in this notebook.
model_path = '../2021_03_23/s1_training/result/run_7/model_ckpt_ep_17.pth'  # best model

# Build the predictor from the checkpoint. The filter counts/widths describe
# seven conv layers - presumably they must match the architecture the
# checkpoint was trained with (TODO confirm against the training config).
# dropout=0.0 disables dropout.
predictor = Predictor(model_ckpt=model_path,
                     num_filters=[32, 32, 64, 64, 64, 128, 128],
                     filter_width=[9, 9, 9, 9, 9, 9, 9],
                     dropout=0.0)

In [None]:
# Row position (in df) of the example inspected below.
# Other positions looked at previously:
#   70 - "bad" example, even combined didn't work well
#   1  - another one?
#   2, 13, 14
idx = 15

# Pull the sequence, its pairing indices and its bounding boxes from that row,
# then convert the bounding boxes with process_bb_old_to_new.
seq, one_idx, bounding_boxes = (
    df.iloc[idx][field] for field in ('seq', 'one_idx', 'bounding_boxes')
)
df_target = process_bb_old_to_new(bounding_boxes)

In [None]:
# Convert the pairing indices into (pairs, structure_arr) with the lower
# triangle removed, then build the per-pixel targets and masks.
pairs, structure_arr = one_idx2arr(one_idx, len(seq), remove_lower_triangular=True)

# make_target_pixel_bb returns 23 values; they are unpacked positionally, so
# the order below must not change.
(
    # on/off targets
    target_stem_on, target_iloop_on, target_hloop_on,
    # on/off masks
    mask_stem_on, mask_iloop_on, mask_hloop_on,
    # location targets
    target_stem_location_x, target_stem_location_y,
    target_iloop_location_x, target_iloop_location_y,
    target_hloop_location_x, target_hloop_location_y,
    # size targets ("sm" variant)
    target_stem_sm_size, target_iloop_sm_size_x, target_iloop_sm_size_y, target_hloop_sm_size,
    # size targets ("sl" variant)
    target_stem_sl_size, target_iloop_sl_size_x, target_iloop_sl_size_y, target_hloop_sl_size,
    # location/size masks
    mask_stem_location_size, mask_iloop_location_size, mask_hloop_location_size,
) = make_target_pixel_bb(structure_arr, bounding_boxes)

In [None]:
# Show the target bounding boxes of the selected example (output of process_bb_old_to_new).
df_target

In [None]:
# Encode the sequence and run the network once. This is inference only, so
# autograd tracking is switched off for the forward pass instead of recording
# a graph that is discarded straight away.
de = DataEncoder(seq)
with torch.no_grad():
    yp = predictor.model(torch.tensor(de.x_torch))
# Convert every output to a numpy array and take index 0 of the leading axis
# (presumably the batch axis - a single sequence was fed in).
yp = {k: v.detach().cpu().numpy()[0, :, :, :] for k, v in yp.items()}

In [None]:
# List the names of the model outputs available in yp.
yp.keys()

In [None]:
# len(seq) x len(seq) float mask: 1 on and above the diagonal, 0 below it.
# Multiplying a prediction by it blanks the lower triangle.
hard_mask = np.triu(np.ones((len(seq), len(seq))))

In [None]:
# stem_on: target (left) next to the model output, channel 0 with the lower
# triangle zeroed by hard_mask (right).
fig = make_subplots(rows=1, cols=2, shared_yaxes=True)
fig.add_trace(px.imshow(target_stem_on).data[0], row=1, col=1)
fig.add_trace(px.imshow(hard_mask * yp['stem_on'][0]).data[0], row=1, col=2)
# Flip the y axis so row 0 is drawn at the top, matrix-style.
fig.update_yaxes(autorange="reversed").update_layout(title='stem_on')
fig.show()

In [None]:
# stem_location_x: target (left) next to the arg-max over the first axis of
# the prediction, lower triangle zeroed by hard_mask (right).
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=True)
fig.add_trace(px.imshow(target_stem_location_x).data[0], row=1, col=1)
fig.add_trace(px.imshow(hard_mask * np.argmax(yp['stem_location_x'], axis=0)).data[0], row=1, col=2)
# Flip the y axis so row 0 is drawn at the top, matrix-style.
fig.update_yaxes(autorange="reversed").update_layout(title='stem_location_x')
fig.show()

In [None]:
# softmax(yp['stem_location_x'][:, 41, 58])

In [None]:
# softmax(yp['stem_location_x'][:, 40, 58])

In [None]:
# stem_location_y: target (left) next to the arg-max over the first axis of
# the prediction, lower triangle zeroed by hard_mask (right).
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=True)
fig.add_trace(px.imshow(target_stem_location_y).data[0], row=1, col=1)
fig.add_trace(px.imshow(hard_mask * np.argmax(yp['stem_location_y'], axis=0)).data[0], row=1, col=2)
# Flip the y axis so row 0 is drawn at the top, matrix-style.
fig.update_yaxes(autorange="reversed").update_layout(title='stem_location_y')
fig.show()

In [None]:
# stem_sm_size: target (left) next to the arg-max over the first axis of the
# prediction, lower triangle zeroed by hard_mask (right).
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=True)
fig.add_trace(px.imshow(target_stem_sm_size).data[0], row=1, col=1)
fig.add_trace(px.imshow(hard_mask * np.argmax(yp['stem_sm_size'], axis=0)).data[0], row=1, col=2)
# Flip the y axis so row 0 is drawn at the top, matrix-style.
fig.update_yaxes(autorange="reversed").update_layout(title='stem_sm_size')
fig.show()

In [None]:
# iloop_on: target (left) next to the model output, channel 0 with the lower
# triangle zeroed by hard_mask (right).
fig = make_subplots(rows=1, cols=2, shared_yaxes=True)
fig.add_trace(px.imshow(target_iloop_on).data[0], row=1, col=1)
fig.add_trace(px.imshow(hard_mask * yp['iloop_on'][0]).data[0], row=1, col=2)
# Flip the y axis so row 0 is drawn at the top, matrix-style.
fig.update_yaxes(autorange="reversed").update_layout(title='iloop_on')
fig.show()

In [None]:
# iloop_location_x: target (left) next to the arg-max over the first axis of
# the prediction, lower triangle zeroed by hard_mask (right).
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=True)
fig.add_trace(px.imshow(target_iloop_location_x).data[0], row=1, col=1)
fig.add_trace(px.imshow(hard_mask * np.argmax(yp['iloop_location_x'], axis=0)).data[0], row=1, col=2)
# Flip the y axis so row 0 is drawn at the top, matrix-style.
fig.update_yaxes(autorange="reversed").update_layout(title='iloop_location_x')
fig.show()

In [None]:
# hloop_on: target (left) next to the model output, channel 0 (right).
# Both panels have the lower triangle zeroed by hard_mask.
fig = make_subplots(rows=1, cols=2, shared_yaxes=True)
fig.add_trace(px.imshow(hard_mask * target_hloop_on).data[0], row=1, col=1)
fig.add_trace(px.imshow(hard_mask * yp['hloop_on'][0]).data[0], row=1, col=2)
# Flip the y axis so row 0 is drawn at the top, matrix-style.
fig.update_yaxes(autorange="reversed").update_layout(title='hloop_on')
fig.show()

In [None]:
# hloop_location_x: target (left) next to the entropy of the soft-maxed
# prediction (right). Both panels have the lower triangle zeroed by hard_mask.
# The softmax is taken over the first axis of the prediction.
fig = make_subplots(rows=1, cols=2, shared_yaxes=True)
fig.add_trace(px.imshow(hard_mask * target_hloop_location_x).data[0], row=1, col=1)
fig.add_trace(
    px.imshow(hard_mask * entropy(softmax(yp['hloop_location_x'], axis=0))).data[0],
    row=1, col=2,
)
# Flip the y axis so row 0 is drawn at the top, matrix-style.
fig.update_yaxes(autorange="reversed").update_layout(title='hloop_location_x (left) entropy (right)')
fig.show()

In [None]:
# Predict bounding boxes with threshold=0.1 and wrap each of the three
# results (stems, internal loops, hairpin loops) in a DataFrame.
pred_bb_stem, pred_bb_iloop, pred_bb_hloop = (
    pd.DataFrame(bbs)
    for bbs in predictor.predict_bb(seq=seq, threshold=0.1, topk=1, perc_cutoff=0)
)

In [None]:
# Predict bounding boxes by thresholding on n_proposal (the proposal count
# normalised by box area) instead. pred_threshold_on_n_proposal applies 0.5
# to stems and internal loops and half of it to hairpin loops.
pred_bb_stem_2, pred_bb_iloop_2, pred_bb_hloop_2 = pred_threshold_on_n_proposal(seq, predictor, threshold=0.5)

In [None]:
# Let pandas render up to 100 rows so the prediction tables below are not truncated.
pd.set_option('display.max_rows', 100)

In [None]:
# Hairpin-loop boxes from predict_bb with threshold=0.1.
pred_bb_hloop

In [None]:
# Hairpin-loop boxes kept by the n_proposal thresholding.
pred_bb_hloop_2

In [None]:
# Rebuild the pairing array for the example (same call as used for the targets
# above) and convert it with arr2db - presumably to a dot-bracket string, going
# by the names; confirm in utils.rna_ss_utils.
pairs, bp_arr = one_idx2arr(one_idx, len(seq), remove_lower_triangular=True)
db_str, amb = arr2db(bp_arr)
# Second return value of arr2db; looks like an ambiguity flag - TODO confirm.
print(amb)

In [None]:
# Print a three-line record: a '>test_seq' header, the sequence, then db_str.
print(">test_seq\n{}\n{}".format(seq, db_str))

In [None]:
# Find target stems that are missed by both methods (the thresholded
# predict_bb call and the n_proposal thresholding).
#
# reindex(columns=...) is used instead of frame[[...]] so that a method which
# predicted nothing (an empty frame without columns) contributes zero rows
# instead of raising a KeyError.
df_tmp = pd.concat([
    bbs.reindex(columns=['bb_x', 'bb_y', 'siz_x', 'siz_y'])
    for bbs in (pred_bb_stem, pred_bb_stem_2)
]).drop_duplicates()
# Marker column: after the left merge it is NaN exactly for target boxes that
# neither method predicted.
df_tmp['pred'] = 1
df_tmp = pd.merge(df_target[df_target['bb_type'] == 'stem'], df_tmp, how='left')
print(df_tmp[df_tmp['pred'].isna()])

In [None]:
# Find target internal loops that are missed by both prediction methods.
#
# reindex(columns=...) is used instead of frame[[...]] so that a method which
# predicted nothing (an empty frame without columns) contributes zero rows
# instead of raising a KeyError.
df_tmp = pd.concat([
    bbs.reindex(columns=['bb_x', 'bb_y', 'siz_x', 'siz_y'])
    for bbs in (pred_bb_iloop, pred_bb_iloop_2)
]).drop_duplicates()
# Marker column: after the left merge it is NaN exactly for target boxes that
# neither method predicted.
df_tmp['pred'] = 1
df_tmp = pd.merge(df_target[df_target['bb_type'] == 'iloop'], df_tmp, how='left')
print(df_tmp[df_tmp['pred'].isna()])

In [None]:
# Find target hairpin loops that are missed by both prediction methods.
#
# reindex(columns=...) is used instead of frame[[...]] so that a method which
# predicted nothing (an empty frame without columns) contributes zero rows
# instead of raising a KeyError.
df_tmp = pd.concat([
    bbs.reindex(columns=['bb_x', 'bb_y', 'siz_x', 'siz_y'])
    for bbs in (pred_bb_hloop, pred_bb_hloop_2)
]).drop_duplicates()
# Marker column: after the left merge it is NaN exactly for target boxes that
# neither method predicted.
df_tmp['pred'] = 1
df_tmp = pd.merge(df_target[df_target['bb_type'] == 'hloop'], df_tmp, how='left')
print(df_tmp[df_tmp['pred'].isna()])