In [None]:
import numpy as np
import pandas as pd
from pprint import pprint
from collections import namedtuple
from sklearn.metrics import f1_score
from utils.util_global_struct import process_bb_old_to_new
from utils.rna_ss_utils import arr2db, one_idx2arr
from utils.inference_s2 import Predictor, process_row_bb_combo, stem_bbs2arr
from utils.misc import add_column
import torch
import torch.nn as nn

In [None]:
import plotly.express as px

In [None]:
# Input table for the evaluation loops; every row is handed to process_row_bb_combo.
# NOTE(review): read_pickle unpickles the file -- only point this at trusted data.
DATA_PATH = 'data/data_len60_test_1000_s1_stem_bb_combos.pkl.gz'
df = pd.read_pickle(DATA_PATH)

In [None]:
# Instantiate the Predictor used by the evaluation loops. The first argument is
# a checkpoint path; the three five-element lists are forwarded as keyword
# arguments unchanged.
CKPT_PATH = 'result/run_1/model_ckpt_ep_49.pth'
predictor_kwargs = dict(
    num_filters=[16, 16, 32, 32, 64],
    filter_width=[3, 3, 3, 3, 3],
    pooling_size=[1, 1, 2, 2, 2],
)
model = Predictor(CKPT_PATH, **predictor_kwargs)

In [None]:
# Passed as the second argument to process_row_bb_combo in the evaluation loops.
TOPK = 50

In [None]:
# Per-row F1 between the target and the candidate with the highest model score,
# restricted to rows for which process_row_bb_combo reports target_in_topk.
f1s_all = []
for _, row in df.iterrows():
    (seq, df_valid_combos, bb_combos, target_bbs,
     target_bb_inc, target_in_combo, target_in_topk) = process_row_bb_combo(row, TOPK)

    # for now only check those with target_in_topk == True
    if not target_in_topk:
        continue

    yp = model.predict_bb_combos(seq, bb_combos)
    best_combo = bb_combos[yp.argmax()]

    print("target:")
    pprint(target_bbs)
    print("prediction (best score):")
    pprint(best_combo)

    seq_len = len(seq)
    # Compare only the upper-triangle entries (diagonal included) of the two
    # arrays returned by stem_bbs2arr.
    upper = np.triu_indices(seq_len)
    y_true = stem_bbs2arr(target_bbs, seq_len)[upper]
    y_pred = stem_bbs2arr(best_combo, seq_len)[upper]
    f1s = f1_score(y_true=y_true, y_pred=y_pred)

    print(f"f1 score: {f1s}")
    print('')

    f1s_all.append(f1s)
        
    
    

In [None]:
# Histogram of the collected F1 scores; the bare name on the last line renders it.
f1_hist = px.histogram(data_frame=f1s_all)
f1_hist

In [None]:
# Wrap the collected scores in a one-column DataFrame. The name is rebound to
# its own transformation, so guard against re-running this cell: on a second
# run f1s_all is already the frame, and passing it back in as the column value
# would no longer be the list of scores the constructor call was written for.
if not isinstance(f1s_all, pd.DataFrame):
    f1s_all = pd.DataFrame({'f1_score': f1s_all})

In [None]:
# Summary statistics of the F1 table; the bare name on the last line displays them.
f1_summary = f1s_all.describe()
f1_summary

In [None]:

# For rows with target_in_topk == False and target_in_combo == True, record the
# model score of the target next to the maximum score over the candidate combos.
MAX_DEBUG_ROWS = 200  # debug cap: stop once more than this many records are collected
df_pred_max_tgt = []

for _, row in df.iterrows():
    (seq, df_valid_combos, bb_combos, target_bbs,
     target_bb_inc, target_in_combo, target_in_topk) = process_row_bb_combo(row, TOPK)

    # check those with target_in_topk == False but target_in_combo == True
    if target_in_topk or not target_in_combo:
        continue

    yp = model.predict_bb_combos(seq, bb_combos)
    # Score the target on its own by passing it as a single-element combo list.
    ypt = model.predict_bb_combos(seq, [target_bbs])
    df_pred_max_tgt.append({
        'yp_target': float(ypt),
        # Converted to a plain float like yp_target, so both columns hold the
        # same scalar type regardless of what predict_bb_combos returns.
        'yp_topk_max': float(yp.max()),
    })

    # debug
    if len(df_pred_max_tgt) > MAX_DEBUG_ROWS:
        break

In [None]:
# Turn the collected records into a table, one row per record.
df_pred_max_tgt = pd.DataFrame(data=df_pred_max_tgt)

In [None]:
def _first_exceeds_second(a, b):
    """Return ``a > b``.

    NOTE(review): presumably add_column calls this with the values of the
    listed columns in order (yp_target, yp_topk_max) -- confirm against
    utils.misc.add_column.
    """
    return a > b


# Derive the 'target > topk_max' column from the two score columns.
df_pred_max_tgt = add_column(
    df_pred_max_tgt,
    'target > topk_max',
    ['yp_target', 'yp_topk_max'],
    _first_exceeds_second,
)

In [None]:
# Scatter of target score against the maximum candidate score, coloured by the
# comparison column; the bare name on the last line renders the figure.
scatter_fig = px.scatter(
    data_frame=df_pred_max_tgt,
    x='yp_target',
    y='yp_topk_max',
    color='target > topk_max',
)
scatter_fig