In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff

In [None]:
from sklearn.metrics import classification_report

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
from functools import reduce
from itertools import cycle

In [None]:
import dgutils.pandas as dgp

In [None]:
from model_utils.utils_model import Evaluator

In [None]:
import model_utils.utils_s2 as us2 # TODO merge s2 util
from model_utils.utils_nn_s2 import predict_wrapper

In [None]:
# TODO move to util
def stem2db_str(df_stem, seq_len):
    """Render stems as a dot-bracket string of length ``seq_len``.

    Each stem row (columns ``bb_x``, ``bb_y``, ``siz_x``, ``siz_y``) pairs position
    ``bb_x + i`` with ``bb_y - i`` for ``i in range(siz_x)``. Successive stems (in row
    order) cycle through the bracket pairs ``()``, ``[]`` and ``{}``; unpaired positions
    are ``'.'``.

    Base pairs with either position outside ``[0, seq_len)`` are skipped with a message,
    mirroring ``stem2arr``. Without this check a negative position would silently wrap
    around to the end of the string, and a position >= seq_len would raise IndexError.

    Raises:
        AssertionError: if a row has ``siz_x != siz_y``.
    """
    bracket_pairs = cycle([('(', ')'), ('[', ']'), ('{', '}')])

    db_str = ['.'] * seq_len
    for _, row in df_stem.iterrows():
        bb_x = int(row['bb_x'])
        bb_y = int(row['bb_y'])
        siz = int(row['siz_x'])
        siz_y = int(row['siz_y'])
        assert siz == siz_y
        # one bracket type per stem, advanced even if some of its pairs get skipped
        open_char, close_char = next(bracket_pairs)
        for i in range(siz):
            i1 = bb_x + i
            i2 = bb_y - i
            if not (0 <= i1 < seq_len and 0 <= i2 < seq_len):
                print("Skip out-of-range base-pair {}-{}".format(i1, i2))
                continue
            db_str[i1] = open_char
            db_str[i2] = close_char
    return ''.join(db_str)

In [None]:
# TODO move to util
def stem2arr(df_stem, seq_len, flatten_triu=True):
    """Rasterize stems into a symmetric 0/1 base-pairing matrix.

    For every stem row (columns ``bb_x``, ``bb_y``, ``siz_x``, ``siz_y``), position
    ``bb_x + k`` is paired with ``bb_y - k`` for ``k in range(siz_x)``, and both
    ``[a, b]`` and ``[b, a]`` of a ``seq_len x seq_len`` float matrix are set to 1.
    Pairs with a position outside ``[0, seq_len)`` are reported and skipped.

    Args:
        df_stem: stem dataframe.
        seq_len: sequence length (matrix side).
        flatten_triu: if True (default) return the upper triangle, diagonal included,
            as a 1-D array (useful for evaluation); otherwise return the full matrix.

    Raises:
        AssertionError: if a row has ``siz_x != siz_y``.
    """
    pair_mat = np.zeros((seq_len, seq_len))
    for _, stem in df_stem.iterrows():
        left = int(stem['bb_x'])
        right = int(stem['bb_y'])
        size = int(stem['siz_x'])
        size_other = int(stem['siz_y'])
        assert size == size_other
        for offset in range(size):
            pos_a, pos_b = left + offset, right - offset
            # FIXME should not happen! s1 inference should prune these
            in_range = 0 <= pos_a < seq_len and 0 <= pos_b < seq_len
            if not in_range:
                print("Skip out-of-range base-pair {}-{}".format(pos_a, pos_b))
                continue
            pair_mat[pos_a, pos_b] = pair_mat[pos_b, pos_a] = 1
    if not flatten_triu:
        return pair_mat
    # upper triangle (diagonal included), flattened row by row
    return pair_mat[np.triu_indices(seq_len)]

In [None]:
def compute_metrics(seq_len, df_pred, bounding_boxes, convert_tl_to_tr=True):
    """Base-pair-level sensitivity and PPV of predicted stems against the target.

    Predicted stems (rows of ``df_pred`` with ``bb_type == 'stem'``) and target stems
    (from ``evaluator.make_target_bb_df``) are both rasterized with ``stem2arr`` into the
    flattened upper triangle of a ``seq_len x seq_len`` pairing matrix and compared
    element-wise.

    In binary classification, recall of the positive class is also known as
    "sensitivity" and precision as "positive predictive value" (PPV).

    Args:
        seq_len: sequence length.
        df_pred: prediction dataframe with a ``bb_type`` column.
        bounding_boxes: target bounding boxes, passed to ``evaluator.make_target_bb_df``.
        convert_tl_to_tr: forwarded to ``evaluator.make_target_bb_df``.

    Returns:
        ``(sensitivity, ppv)`` where sensitivity = TP / (TP + FN) and
        ppv = TP / (TP + FP). Either value is NaN when its denominator is zero
        (no target pairs, or no predicted pairs) instead of raising ZeroDivisionError.

    Raises:
        ValueError: if a pairing array contains values other than 0 and 1.
    """

    def perf_measure(y_actual, y_hat):
        """Return ``(TP, FP, TN, FN)`` counts for two equal-length 0/1 arrays."""
        y_actual = np.asarray(y_actual)
        y_hat = np.asarray(y_hat)
        if not (np.isin(y_actual, (0, 1)).all() and np.isin(y_hat, (0, 1)).all()):
            raise ValueError('perf_measure expects binary (0/1) arrays')
        actual_pos = y_actual == 1
        hat_pos = y_hat == 1
        TP = int(np.sum(actual_pos & hat_pos))
        FP = int(np.sum(~actual_pos & hat_pos))
        TN = int(np.sum(~actual_pos & ~hat_pos))
        FN = int(np.sum(actual_pos & ~hat_pos))
        return TP, FP, TN, FN

    # honour the caller's convert_tl_to_tr (it used to be ignored and hard-coded to True)
    df_target_stem, _, _ = evaluator.make_target_bb_df(bounding_boxes, convert_tl_to_tr=convert_tl_to_tr)
    x_pred = stem2arr(df_pred[df_pred['bb_type'] == 'stem'], seq_len, flatten_triu=True)
    x_target = stem2arr(df_target_stem, seq_len, flatten_triu=True)

    TP, FP, _, FN = perf_measure(x_target, x_pred)

    # undefined ratios become NaN so one degenerate row does not abort the whole column
    sensitivity = TP / (TP + FN) if TP + FN > 0 else float('nan')
    ppv = TP / (TP + FP) if TP + FP > 0 else float('nan')
    return sensitivity, ppv

In [None]:
def display_ss_graph(df_stem, seq):
    """Build an undirected networkx graph of a secondary structure.

    Nodes are positions ``0 .. len(seq) - 1``, each with a ``label`` attribute holding
    the base. Edges are the backbone ``(i, i + 1)`` plus one edge per base pair of each
    stem: ``bb_x + i`` -- ``bb_y - i`` for ``i in range(siz_x)``.

    Base pairs with a position outside ``[0, len(seq))`` are skipped with a message,
    mirroring ``stem2arr``; otherwise ``add_edge`` would silently add extra, unlabeled
    nodes to the graph. Despite the name, nothing is drawn here; pass the returned graph
    to ``nx.draw``.

    Returns:
        networkx.Graph

    Raises:
        AssertionError: if a row has ``siz_x != siz_y``.
    """
    seq_len = len(seq)
    G = nx.Graph()
    G.add_nodes_from((i, {"label": base}) for i, base in enumerate(seq))
    # backbone
    G.add_edges_from((i, i + 1) for i in range(seq_len - 1))
    # hydrogen bonds
    for _, row in df_stem.iterrows():
        bb_x = int(row['bb_x'])
        bb_y = int(row['bb_y'])
        siz = int(row['siz_x'])
        siz_y = int(row['siz_y'])
        assert siz == siz_y
        for i in range(siz):
            i1 = bb_x + i
            i2 = bb_y - i
            if not (0 <= i1 < seq_len and 0 <= i2 < seq_len):
                print("Skip out-of-range base-pair {}-{}".format(i1, i2))
                continue
            G.add_edge(i1, i2)
    return G
#     G.add_edge(1,2)
#     G.add_edge(1,3)
#     nx.draw(G, with_labels=True)
#     plt.show()

In [None]:
# Stage-2 predictor, used by `pred_row` below.
# Previously tried: us2.Predictor('v0.2') and
# us2.Predictor('s2_training/result/synthetic_s2_5000/model_ckpt_ep_34.pth').

predictor_s2 = us2.Predictor('v0.3')

In [None]:
# Input dataset (presumably stage-1 predictions on synthetic data, per the file name).
# Columns used below: seq, len, bb_stem, bb_iloop, bb_hloop, bounding_boxes.
# NOTE: read_pickle unpickles arbitrary objects -- only load files from a trusted source.
df = pd.read_pickle('../2021_01_12/data/synthetic_s1_pred_1000_t0p1_k1.pkl.gz')

In [None]:
# Module-level Evaluator read by compute_metrics and print_db_str (make_target_bb_df);
# no predictor is attached.
evaluator = Evaluator(predictor=None)   # using static utils

In [None]:
def pred_row(seq, bb_stem, bb_iloop, bb_hloop):
    """Run the stage-2 predictor on one row's stem / interior-loop / hairpin-loop boxes.

    Each bounding-box collection is wrapped in a DataFrame and passed to
    ``predict_wrapper`` together with the sequence and the module-level
    ``predictor_s2``. The result is returned unchanged; it is later filtered on its
    ``bb_type`` column.
    """
    df_stem, df_iloop, df_hloop = (pd.DataFrame(boxes) for boxes in (bb_stem, bb_iloop, bb_hloop))
    return predict_wrapper(
        df_stem, df_iloop, df_hloop,
        discard_ns_stem=True,
        min_hloop_size=2,
        seq=seq,
        m_factor=1,
        predictor=predictor_s2,
    )

In [None]:
# make prediction
df = dgp.add_column(
    df,
    'df_pred',                                   # new column holding each row's prediction
    ['seq', 'bb_stem', 'bb_iloop', 'bb_hloop'],  # input columns, passed to pred_row in order
    pred_row,
    pbar=True,
)

In [None]:
# compute metrics
def _metrics_for_row(seq_len, df_pred, bounding_boxes):
    """Per-row (sensitivity, ppv) for this dataset."""
    # setting convert_tl_to_tr to True since this particular dataset's ground truth is in old format
    return compute_metrics(seq_len, df_pred, bounding_boxes, convert_tl_to_tr=True)


df = dgp.add_columns(df, ['sensitivity', 'ppv'], ['len', 'df_pred', 'bounding_boxes'], _metrics_for_row)

In [None]:

# Per-sequence base-pair sensitivity vs. PPV, with violin marginals for each metric.
# Rows where a metric is undefined (NaN) are not plotted.
px.scatter(df, x='sensitivity', y='ppv',
           marginal_x='violin', marginal_y='violin',
           title='Per-sequence base-pair sensitivity vs. PPV',
           labels={'sensitivity': 'sensitivity (recall)', 'ppv': 'PPV (precision)'})

In [None]:
# for convenience

def print_db_str(seq, bounding_boxes, df_pred, convert_tl_to_tr=True):
    """Print target and predicted stems as dot-bracket records.

    Prints two records, each made of a header line, the sequence, and a dot-bracket
    string (from ``stem2db_str``): ``>s1`` for the target stems built from
    ``bounding_boxes`` and ``>s2`` for the stem rows (``bb_type == 'stem'``) of
    ``df_pred``.

    NOTE(review): the ``>s1`` header labels the target structure; confirm that naming
    is intended.

    Args:
        seq: the sequence.
        bounding_boxes: target bounding boxes, passed to ``evaluator.make_target_bb_df``.
        df_pred: prediction dataframe with a ``bb_type`` column.
        convert_tl_to_tr: forwarded to ``evaluator.make_target_bb_df``; the default
            True is what this notebook's dataset needs.
    """
    df_target_stem, _, _ = evaluator.make_target_bb_df(bounding_boxes, convert_tl_to_tr=convert_tl_to_tr)
    seq_len = len(seq)

    print('>s1')
    print(seq)
    print(stem2db_str(df_target_stem, seq_len))

    print('>s2')
    print(seq)
    print(stem2db_str(df_pred[df_pred['bb_type'] == 'stem'], seq_len))


In [None]:
# show one random example from each metric range

# top: sensitivity == ppv == 1
candidates = df[(df['sensitivity'] == 1) & (df['ppv'] == 1)]
# .sample() raises on an empty selection, so report it instead
if candidates.empty:
    print('No example with sensitivity == ppv == 1')
else:
    row = candidates.sample().iloc[0]
    print_db_str(row['seq'], row['bounding_boxes'], row['df_pred'])


In [None]:
# high: sensitivity >= 0.9, ppv >= 0.9 (but != 1)
candidates = df[(df['sensitivity'] >= 0.9) & (df['sensitivity'] != 1) & (df['ppv'] >= 0.9) & (df['ppv'] != 1)]
# .sample() raises on an empty selection, so report it instead
if candidates.empty:
    print('No example with 0.9 <= sensitivity, ppv < 1')
else:
    row = candidates.sample().iloc[0]
    print_db_str(row['seq'], row['bounding_boxes'], row['df_pred'])

In [None]:
# mid: 0.4 - 0.6
candidates = df[(df['sensitivity'] >= 0.4) & (df['sensitivity'] < 0.6) & (df['ppv'] >= 0.4) & (df['ppv'] < 0.6)]
# .sample() raises on an empty selection, so report it instead
if candidates.empty:
    print('No example with 0.4 <= sensitivity, ppv < 0.6')
else:
    row = candidates.sample().iloc[0]
    print_db_str(row['seq'], row['bounding_boxes'], row['df_pred'])

In [None]:
# low: < 0.1
candidates = df[(df['sensitivity'] < 0.1) & (df['ppv'] < 0.1)]
# .sample() raises on an empty selection, so report it instead
if candidates.empty:
    print('No example with sensitivity, ppv < 0.1')
else:
    row = candidates.sample().iloc[0]
    print_db_str(row['seq'], row['bounding_boxes'], row['df_pred'])

In [None]:
# Draw the predicted structure of the most recently sampled example (`row`, set by the
# example cells above). The original cell used undefined top-level `df_pred` / `seq`.
row_pred = row['df_pred']
G = display_ss_graph(row_pred[row_pred['bb_type'] == 'stem'], row['seq'])
nx.draw(G, with_labels=True)
plt.show()