In [None]:
import pickle
import pandas as pd
from collections import namedtuple

In [None]:
from util_s2_gnn import make_dataset, roc_prc

In [None]:
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score

In [None]:

from s2_train_gnn_10 import Net, make_target, kmer_int

In [None]:
import torch
import numpy as np

In [None]:
import plotly.express as px

In [None]:
# Build the network and restore its weights from a saved checkpoint,
# mapping all tensors onto the CPU.
# NOTE(review): the list is Net's first positional argument — presumably
# per-layer sizes; confirm against s2_train_gnn_10.Net.
layer_sizes = [20] * 2 + [40] * 5 + [50] * 7 + [100] * 5
checkpoint_path = 'result/s2_gnn_run_21.model_ckpt_ep_59.pth'

model = Net(layer_sizes, k=3, embed_dim=50)
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
# Switch to evaluation mode; left as the cell's last expression, as before.
model.eval()

In [None]:
# Load the pickled test-set DataFrame used by every cell below.
test_set_path = 'data/s2_test_len20_200_1000_pred_stem.pkl.gz'
df = pd.read_pickle(test_set_path)

In [None]:
# Display the first row to see which columns were loaded.
df.head(1)

In [None]:
def _encode_3mer(x):
    """Encode ``x`` with ``kmer_int`` using k=3."""
    return kmer_int(x, k=3)


# Convert the test DataFrame into per-example inputs for the model.
data_test = make_dataset(
    df,
    make_target,
    _encode_3mer,
    include_s1_feature=True,
    s1_feature_dim=6,
)

In [None]:
# Per-example evaluation over the whole test set.
# For every example we compute AUROC, average precision, and the F1 score at
# each of 50 evenly spaced thresholds in [0, 1]. Results are collected in
# `aurocs`, `auprcs` (one value per example) and `f1_scores`
# (threshold -> list of per-example F1 values).
aurocs = []
auprcs = []

thresholds = np.linspace(0, 1, 50)
f1_scores = {x: [] for x in thresholds}

# Inference only: disable autograd so no computation graph is built and kept
# alive for each forward pass.
with torch.no_grad():
    for data_idx, data in enumerate(data_test):
        print(data_idx)
        pred = model(data)
        seq = df.iloc[data_idx].seq

        # Indices of the upper triangle (diagonal included) of a
        # len(seq) x len(seq) matrix.
        idx_triu = np.triu_indices(len(seq))
        # Multiply by data['m'] (presumably a validity mask — TODO confirm)
        # before taking the upper-triangle entries as flat vectors.
        yp_triu = (pred.detach().numpy() * data['m'])[idx_triu]
        yt_triu = (data['y'] * data['m'])[idx_triu]

        aurocs.append(roc_auc_score(y_score=yp_triu, y_true=yt_triu))
        auprcs.append(average_precision_score(y_score=yp_triu, y_true=yt_triu))

        for threshold in thresholds:
            f1s = f1_score(y_pred=(yp_triu > threshold), y_true=yt_triu)
            f1_scores[threshold].append(f1s)
    

In [None]:
# Scatter of per-example AUROC (x) against per-example AUPRC (y).
axis_titles = {
    "xaxis_title": "auroc_overall",
    "yaxis_title": "auprc_overall",
}
fig = px.scatter(x=aurocs, y=auprcs)
# update_layout's return value is the cell output, as in the original.
fig.update_layout(**axis_titles)

In [None]:
# Summarise the per-example F1 scores: one row per threshold holding the
# mean and standard deviation across examples.
df_f1_scores = pd.DataFrame([
    {
        'threshold': thr,
        'mean': np.mean(scores),
        'std': np.std(scores),
    }
    for thr, scores in f1_scores.items()
])

In [None]:
# Mean F1 per threshold, with the standard deviation drawn as error bars.
fig = px.scatter(
    df_f1_scores,
    x="threshold",
    y="mean",
    error_y="std",
).update_layout(yaxis_title="F1 score")
fig.show()

In [None]:
# Pick one test example (by position in data_test) to inspect in the cells below.
data_idx = 456
data = data_test[data_idx]

In [None]:
# data

In [None]:
# TmpData = namedtuple("TmpData", ["x", "edge_index", "edge_attr"])

# tmp_data = TmpData(x=torch.LongTensor(data['x']),
#                   edge_index=torch.from_numpy(data['edge_index']).long(),
#                   edge_attr=torch.from_numpy(data['edge_attr']).float())

# Forward pass for the selected example. This is inference only, so autograd
# is disabled to avoid building a computation graph.
with torch.no_grad():
    pred = model(data)

In [None]:
# Heatmap of the model output for the selected example, multiplied by data['m'].
pred_masked = pred.detach().numpy() * data['m']
px.imshow(pred_masked)

In [None]:
# Heatmap of the target for the selected example, multiplied by data['m'].
target_masked = data['y'] * data['m']
px.imshow(target_masked)

In [None]:
# Sequence of the selected example, taken from the matching DataFrame row.
selected_row = df.iloc[data_idx]
seq = selected_row['seq']

In [None]:
# Take the upper-triangle entries (diagonal included) of the prediction and
# target, each multiplied by data['m'], as flat vectors.
idx_triu = np.triu_indices(len(seq))
yp_triu = (pred.detach().numpy() * data['m'])[idx_triu]
yt_triu = (data['y'] * data['m'])[idx_triu]

In [None]:
# AUROC of the selected example over its upper-triangle entries.
roc_auc_score(yt_triu, yp_triu)

In [None]:
# Print the selected example's F1 score at 50 evenly spaced thresholds in [0, 1].
for threshold in np.linspace(0, 1, 50):
    predicted_positive = yp_triu > threshold
    f1s = f1_score(y_true=yt_triu, y_pred=predicted_positive)
    print(threshold, f1s)