In [2]:
import numpy as np
import pandas as pd
from pprint import pprint
from collections import namedtuple
from sklearn.metrics import f1_score
from utils.util_global_struct import process_bb_old_to_new
from utils.rna_ss_utils import arr2db, one_idx2arr, compute_fe
from utils.inference_s2 import Predictor, process_row_bb_combo, stem_bbs2arr
from utils.misc import add_column
import torch
import torch.nn as nn

  from tqdm.autonotebook import tqdm


In [3]:
import plotly.express as px

In [4]:
# Load the evaluation dataframe from a gzipped pickle (path is relative to the
# notebook's working directory). Each row is later fed to process_row_bb_combo.
# NOTE(review): unpickling can execute arbitrary code -- only load this file
# from a trusted source.
df = pd.read_pickle('../2021_06_15/data/data_len60_test_1000_s1_stem_bb_combos.pkl.gz')

In [6]:
# Passed as the second argument to process_row_bb_combo for every row.
# NOTE(review): presumably the number of top-ranked bounding-box combos kept
# per sequence -- confirm against utils.inference_s2.
TOPK=100  # debug!

In [7]:
# Per-row F1 scores (prediction = lowest-FE combo among the candidates,
# ground truth = target structure). One entry is appended per evaluated row.
f1s_all = []
for _, row in df.iterrows():
    seq, df_valid_combos, bb_combos, target_bbs, target_bb_inc, target_in_combo, target_in_topk = process_row_bb_combo(row, TOPK)
    
    # for now only check those with target_in_topk == True,
    # debug: sanity check this result in f1=1
    if target_in_topk:

        # extract prediction of target (won't error out since target_in_topk == True)
        # Bounding boxes are compared as sets, so their order within a combo is ignored.
        idx_tgt = next(i for i, bb_combo in enumerate(bb_combos) if set(bb_combo) == set(target_bbs))
        
        # use RNAfold to compute FE
        # arr2db returns a pair; only its first element (the structure string) is used.
        db_str_tgt, _ = arr2db(stem_bbs2arr(bb_combos[idx_tgt], len(seq)))
        db_str_all = [arr2db(stem_bbs2arr(bb_combo, len(seq)))[0] for bb_combo in bb_combos]
        fe_tgt = compute_fe(seq, db_str_tgt)
        fe_all = [compute_fe(seq, db_str) for db_str in db_str_all]
        
        # Index of the lowest free energy, skipping NaN entries.
        idx_best = np.nanargmin(fe_all)
        
        # sanity check for this particular case, FE should equal
        assert fe_tgt == fe_all[idx_best], (
            f"target FE {fe_tgt} != best FE {fe_all[idx_best]} "
            f"(idx_tgt={idx_tgt}, idx_best={idx_best})"
        )
        # but structure might not!
        if idx_tgt != idx_best:
            print("FE equal but structure differ:")
            print("target:")
            pprint(target_bbs)
            print("predicted:")
            pprint(bb_combos[idx_best])
            # Stop at the first mismatch so that seq, bb_combos, idx_tgt and
            # idx_best stay bound to this row for inspection in the cells below.
            break
        
        target_bps = stem_bbs2arr(target_bbs, len(seq))
        pred_bps = stem_bbs2arr(bb_combos[idx_best], len(seq))
        # Score only the upper triangle (including the diagonal) of the L x L arrays.
        idx = np.triu_indices(len(seq))
        f1s = f1_score(y_pred=pred_bps[idx], y_true=target_bps[idx])
        # Record the score; previously f1s_all was created but never filled.
        f1s_all.append(f1s)
        print(f"Target FE {fe_tgt}, best in topk FE {fe_all[idx_best]}, f1={f1s}")
        print("")
   
    
    

Target FE -16.4, best in topk FE -16.4, f1=1.0

FE equal but structure differ:
target:
[BoundingBox(bb_x=4, bb_y=48, siz_x=3, siz_y=3),
 BoundingBox(bb_x=9, bb_y=43, siz_x=7, siz_y=7),
 BoundingBox(bb_x=21, bb_y=36, siz_x=3, siz_y=3)]
predicted:
[BoundingBox(bb_x=2, bb_y=9, siz_x=2, siz_y=2),
 BoundingBox(bb_x=4, bb_y=48, siz_x=3, siz_y=3),
 BoundingBox(bb_x=9, bb_y=43, siz_x=7, siz_y=7),
 BoundingBox(bb_x=21, bb_y=36, siz_x=3, siz_y=3)]


In [8]:
idx_tgt

85

In [9]:
idx_best

12

In [10]:
target_bbs

[BoundingBox(bb_x=4, bb_y=48, siz_x=3, siz_y=3),
 BoundingBox(bb_x=9, bb_y=43, siz_x=7, siz_y=7),
 BoundingBox(bb_x=21, bb_y=36, siz_x=3, siz_y=3)]

In [11]:
bb_combos[idx_tgt]

[BoundingBox(bb_x=4, bb_y=48, siz_x=3, siz_y=3),
 BoundingBox(bb_x=9, bb_y=43, siz_x=7, siz_y=7),
 BoundingBox(bb_x=21, bb_y=36, siz_x=3, siz_y=3)]

In [12]:
bb_combos[idx_best]

[BoundingBox(bb_x=2, bb_y=9, siz_x=2, siz_y=2),
 BoundingBox(bb_x=4, bb_y=48, siz_x=3, siz_y=3),
 BoundingBox(bb_x=9, bb_y=43, siz_x=7, siz_y=7),
 BoundingBox(bb_x=21, bb_y=36, siz_x=3, siz_y=3)]

In [15]:
# Structure string of the target combo. arr2db returns a pair; only the first
# element is kept here.
db_str_tgt, _ = arr2db(stem_bbs2arr(bb_combos[idx_tgt], len(seq)))
print(db_str_tgt)

....(((..(((((((.....(((..........))))))))))..)))...........


In [16]:
# Structure string of the lowest-FE combo (the one that differs from the target).
# NOTE(review): the extra stem is printed with square brackets and only one
# closing bracket appears -- presumably because position 9 is also used by the
# stem starting at bb_x=9; confirm.
db_str_pred, _ = arr2db(stem_bbs2arr(bb_combos[idx_best], len(seq)))
print(db_str_pred)

..[[(((.](((((((.....(((..........))))))))))..)))...........


In [17]:
compute_fe(seq, db_str_tgt)

-7.7

In [18]:
compute_fe(seq, db_str_pred)

-7.7

In [19]:
seq

'CCCAAGGATGATTGTACCCTCAGCACTCAAGACCGCTTGCGGTTCCCCTACACACTTTTT'

In [21]:
# Keep a handle on the lowest-FE combo (which differs from the target) so its
# individual bounding boxes can be inspected and checked for conflicts.
bb_combo_debug = bb_combos[idx_best]

In [22]:
bb_combo_debug

[BoundingBox(bb_x=2, bb_y=9, siz_x=2, siz_y=2),
 BoundingBox(bb_x=4, bb_y=48, siz_x=3, siz_y=3),
 BoundingBox(bb_x=9, bb_y=43, siz_x=7, siz_y=7),
 BoundingBox(bb_x=21, bb_y=36, siz_x=3, siz_y=3)]

In [28]:
bb_combo_debug[0]

BoundingBox(bb_x=2, bb_y=9, siz_x=2, siz_y=2)

In [29]:
bb_combo_debug[2]

BoundingBox(bb_x=9, bb_y=43, siz_x=7, siz_y=7)

In [26]:
# old implementation, from https://github.com/PSI-Lab/alice-sandbox/blob/cb56ce55d65db772375aa3e8d5304afcf9aaefbe/meetings/2021_06_15/utils_s2_tree_search.py#L17

def range_overlap(r1, r2):
    """Tell whether two (start, end) ranges overlap.

    The ranges are considered overlapping when the end of one of them lies
    strictly after the start of the other and at or before the other's end.
    Ranges that merely touch, e.g. (0, 5) and (5, 10), do not overlap.

    Args:
        r1: indexable pair (start, end).
        r2: indexable pair (start, end).

    Returns:
        True if the ranges overlap, False otherwise.
    """
    start_a, end_a = r1[0], r1[1]
    start_b, end_b = r2[0], r2[1]
    # End of r1 falls inside (start_b, end_b].
    a_ends_in_b = start_b < end_a <= end_b
    # End of r2 falls inside (start_a, end_a].
    b_ends_in_a = start_a < end_b <= end_a
    return bool(a_ends_in_b or b_ends_in_a)

def bb_conflict_old(bb1, bb2):
    """Old conflict check between two bounding boxes (kept for comparison).

    Each box is reduced to two (start, end) ranges: an x-range
    (bb_x, bb_x + siz_x) and a y-range (bb_y - siz_y, bb_y). The boxes are
    reported as conflicting if any range of bb1 overlaps any range of bb2
    according to range_overlap. The y-range here is one lower on both ends
    than the one used by bb_conflict_new.

    Args:
        bb1: object with bb_x, bb_y, siz_x, siz_y attributes.
        bb2: object with bb_x, bb_y, siz_x, siz_y attributes.

    Returns:
        True if any pair of ranges overlaps, False otherwise.
    """
    x_range_1 = (bb1.bb_x, bb1.bb_x + bb1.siz_x)
    y_range_1 = (bb1.bb_y - bb1.siz_y, bb1.bb_y)
    x_range_2 = (bb2.bb_x, bb2.bb_x + bb2.siz_x)
    y_range_2 = (bb2.bb_y - bb2.siz_y, bb2.bb_y)
    # Debug trace of the four ranges being compared.
    print(x_range_1, y_range_1, x_range_2, y_range_2)
    # Test every range of bb1 against every range of bb2, stopping at the first hit.
    return any(
        range_overlap(side_1, side_2)
        for side_1 in (x_range_1, y_range_1)
        for side_2 in (x_range_2, y_range_2)
    )

In [30]:
def bb_conflict_new(bb1, bb2):
    """Conflict check between two bounding boxes with the shifted y-range.

    Each box is reduced to two (start, end) ranges: an x-range
    (bb_x, bb_x + siz_x) and a y-range (bb_y - siz_y + 1, bb_y + 1). The boxes
    are reported as conflicting if any range of bb1 overlaps any range of bb2
    according to range_overlap.

    Args:
        bb1: object with bb_x, bb_y, siz_x, siz_y attributes.
        bb2: object with bb_x, bb_y, siz_x, siz_y attributes.

    Returns:
        True if any pair of ranges overlaps, False otherwise.
    """
    x_range_1 = (bb1.bb_x, bb1.bb_x + bb1.siz_x)
    y_range_1 = (bb1.bb_y - bb1.siz_y + 1, bb1.bb_y + 1)
    x_range_2 = (bb2.bb_x, bb2.bb_x + bb2.siz_x)
    y_range_2 = (bb2.bb_y - bb2.siz_y + 1, bb2.bb_y + 1)
    # Debug trace of the four ranges being compared.
    print(x_range_1, y_range_1, x_range_2, y_range_2)
    # Test every range of bb1 against every range of bb2, stopping at the first hit.
    return any(
        range_overlap(side_1, side_2)
        for side_1 in (x_range_1, y_range_1)
        for side_2 in (x_range_2, y_range_2)
    )

In [27]:
# Old check on the extra stem (bb_x=2, bb_y=9) against the stem at bb_x=9:
# its y-range comes out as (7, 9), which does not overlap (9, 16).
bb_conflict_old(bb_combo_debug[0], bb_combo_debug[2])

(2, 4) (7, 9) (9, 16) (36, 43)


False

In [31]:
# New check on the same pair: the y-range becomes (8, 10), which overlaps
# (9, 16), so the two boxes are reported as conflicting.
bb_conflict_new(bb_combo_debug[0], bb_combo_debug[2])

(2, 4) (8, 10) (9, 16) (37, 44)


True