In [2]:
%load_ext autoreload
%autoreload 2

import os
import torch
import pickle
import warnings
from sklearn.metrics import f1_score

### utils

In [3]:
import numpy as np

def convert_dotbracket_to_matrix(s):
    m = np.zeros([len(s), len(s)])
    for char_set in [['(', ')'], ['[', ']'], ['{', '}'], ['<', '>']]:
        bp1 = []
        bp2 = []
        for i, char in enumerate(s):
            if char == char_set[0]:
                bp1.append(i)
            if char == char_set[1]:
                bp2.append(i)
        for i in list(reversed(bp1)):
            for j in bp2:
                if j > i:
                    m[i, j] = 1.0
                    bp2.remove(j)
                    break
    return m + m.T


def convert_matrix_to_dotbracket(m):
    bp_list = convert_matrix_to_bp_list(m)
    return convert_bp_list_to_dotbracket(bp_list, len(m))


def convert_matrix_to_bp_list(m):
    bp_list = []  # convert adjacency matrix to adjacency list
    for i, row in enumerate(m):
        for j, is_bp in enumerate(row[i + 1:]):
            if is_bp:
                bp_list.append((i, i + 1 + j))
    return bp_list


def convert_bp_list_to_dotbracket(bp_list,seq_len):
    dotbracket = "."*seq_len
    # group into bps that are not intertwined and can use same brackets!
    groups = group_into_non_conflicting_bp_(bp_list)

    # all bp that are not intertwined get (), but all others are
    # groups to be nonconflicting and then asigned (), [], {}, <> by group
    chars_set = [("(", ")"), ("(", ")"), ("[", "]"), ("{", "}"), ("<", ">")]
    if len(groups) > len(chars_set):
        print(f"WARNING: PK too complex with {len(groups)} groups, not enough brackets to represent it.")

    for group,chars in zip(groups,chars_set):
        for bp in group:
            dotbracket = dotbracket[:bp[0]] + chars[0] + dotbracket[bp[0]+1:bp[1]] + chars[1] + dotbracket[bp[1]+1:]
    return dotbracket


def load_matrix_or_dbn(s):
    num_lines = sum(1 for line in open(s))

    if num_lines > 2:  # heuristic here
        struct = np.loadtxt(s)  # load as base pair matrix
        assert struct.shape[0] == struct.shape[1]
    else:
        try:  # load as dot-bracket string

            dbn_struct = open(s, 'r').read().rstrip()

            struct = convert_dotbracket_to_matrix(dbn_struct)
        except:
            raise ValueError('Unable to parse structure %s' % s)
    return struct


def group_into_non_conflicting_bp_(bp_list):
    ''' given a conflict list from get_list_bp_conflicts_, group basepairs into groups that do not conflict

    Args
        conflict_list: list of pairs of base_pairs that are intertwined basepairs

    Returns:
        groups of baspairs that are not intertwined
    '''
    conflict_list = get_list_bp_conflicts_(bp_list)

    non_redudant_bp_list = get_non_redudant_bp_list_(conflict_list)
    bp_with_no_conflict = [bp for bp in bp_list if bp not in non_redudant_bp_list]
    groups = [bp_with_no_conflict]
    while non_redudant_bp_list != []:
        current_bp = non_redudant_bp_list[0]
        current_bp_conflicts = []
        for conflict in conflict_list:
            if current_bp == conflict[0]:
                current_bp_conflicts.append(conflict[1])
            elif current_bp == conflict[1]:
                current_bp_conflicts.append(conflict[0])
        group = [bp for bp in non_redudant_bp_list if bp not in current_bp_conflicts]
        groups.append(group)
        non_redudant_bp_list = current_bp_conflicts
        conflict_list = [conflict for conflict in conflict_list if
                         conflict[0] not in group and conflict[1] not in group]
    return groups


def get_list_bp_conflicts_(bp_list):
    '''given a bp_list gives the list of conflicts bp-s which indicate PK structure
    Args:
        bp_list: of list of base pairs where the base pairs are list of indeces of the bp in increasing order (bp[0]<bp[1])
    returns:
        List of conflicting basepairs, where conflicting is pairs of base pairs that are intertwined.
    '''
    if len(bp_list) <= 1:
        return []
    else:
        current_bp = bp_list[0]
        conflicts = []
        for bp in bp_list[1:]:
            if (bp[0] < current_bp[1] and current_bp[1] < bp[1]):
                conflicts.append([current_bp, bp])
        return conflicts + get_list_bp_conflicts_(bp_list[1:])


def get_non_redudant_bp_list_(conflict_list):
    ''' given a conflict list get the list of nonredundant basepairs this list has

    Args:
        conflict_list: list of pairs of base_pairs that are intertwined basepairs
    returns:
        list of basepairs in conflict list without repeats
    '''
    non_redudant_bp_list = []
    for conflict in conflict_list:
        if conflict[0] not in non_redudant_bp_list:
            non_redudant_bp_list.append(conflict[0])
        if conflict[1] not in non_redudant_bp_list:
            non_redudant_bp_list.append(conflict[1])
    return non_redudant_bp_list

import re
import math
import numpy as np

def adjacency_matrix_to_bpseq(adj_matrix):
    L = len(adj_matrix)
    bpseq = [-1] * L  # 初始化bpseq列表，假设所有位置最初都没有配对

    for i in range(L):
        for j in range(L):
            if adj_matrix[i][j] == 1:
                bpseq[i] = j
                break  # 当找到配对时，跳出内层循环

    return bpseq

def compute_expected_accuracy(etp, efp, efn):
    sen = ppv = f = 0.
    if etp + efn != 0:
        sen = etp / (etp + efn)
    else:
        sen = 0.
        
    if etp + efp != 0:
        ppv = etp / (etp + efp)
    else:
        ppv = 0.
        
    if sen + ppv != 0:
        f = 2 * sen * ppv / (sen + ppv)
    else:
        f = 0.

    return (sen, ppv, f)

def compute_expected_accuracy_pk(pred, label):
    
    # L = len(label)
    # L2 = L * (L - 1) // 2
    N = 0
    pk_flag = False
    
    sump = 0.0
    etp = 0.0
    
    for i in range(len(label)):
        j = label[i]
        if i < j:
            for k in range(i + 1, j):
                l = label[k]
                lp = pred[k]
                if j < lp:
                    N += 1 # TP+FP
                if j < l:
                    sump += 1 # TP+FN
                    if pred[i] == j and pred[k] == l:
                        etp += 1 # TP
    
    efp = N - etp
    efn = sump - etp
    
    if sump > 0:
        pk_flag = True

    return compute_expected_accuracy(etp, efn, efp), pk_flag



def apc(x):
    "Perform average product correct, used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized

## bpRNA

### Uni-RNA

In [4]:
ckpt_fm = "/home/wangxi/data_workspace/pseudoknots/log/rnafm-bpRNA-TR0-Auto-_home_wangxi_develop_rnafm_transformers_RNA-FM_pretrained-fusion-null-c226a2e9-6b2f86/C3B5X2fZt/checkpoints/epoch-13.pth"
ckpt_unirna = "/home/wangxi/data_workspace/pseudoknots/log/new-model-bpRNA-TR0-Auto-_home_wangxi_develop_unirna_light_unirna_unirna_L16_E1024_DPRNA500M_STEP400K-fusion-null-a943f3b4-dbb1bd/C2LWU25Sc/checkpoints/epoch-24.pth"
data_list = "/home/wangxi/data_workspace/pseudoknots/bpRNA-PK-TS0-1K.pkl"
torch.cuda.set_device(1)

from deepprotein.runners.inferencer import LazyInferencer

if not os.path.exists("result_unirna_bpRNA.pkl"):
    infer = LazyInferencer(ckpt_unirna, batch_size=1, sequence_pretrained="/home/wangxi/develop/unirna_light/weights/unirna_L16_E1024_DPRNA500M_STEP400K")
    result_unirna = infer.run(data_list)
    with open("result_unirna_bpRNA.pkl", "wb+") as f:
        pickle.dump(result_unirna, f)
else:
    with open("result_unirna_bpRNA.pkl", "rb") as f:
        result_unirna = pickle.load(f)
        
if not os.path.exists("result_fm_bpRNA.pkl"):
    infer = LazyInferencer(ckpt_fm, batch_size=1)
    result_fm = infer.run(data_list)
    with open("result_fm_bpRNA.pkl", "wb+") as f:
        pickle.dump(result_fm, f)
else:
    with open("result_fm_bpRNA.pkl", "rb") as f:
        result_fm = pickle.load(f)

warnings.filterwarnings("ignore")
data_label = pickle.load(open(data_list, "rb"))

all_pred_unirna = []
unirna_score = []
unirna_score_cb = []
unirna_score_pk = []

all_pred_fm = []
fm_score = []
fm_score_cb = []
fm_score_pk = []


all_label = []

methods_pred = []
methods_score = []
methods_score_pk = []
methods_score_pk_cb = []

unirna_pk_index = []

unirna_length_list = []
unirna_length_list_pk = []

for i in range(len(data_label)):
    id = data_label[i]['id']
    label = data_label[i]['label']
    all_label.extend(label.reshape(-1).tolist())
    unirna_length_list.append(len(data_label[i]['seq']))
    pred = np.array(result_unirna["label"][i])
    all_pred_unirna.extend(pred.reshape(-1).tolist())
    pred = np.where(pred > 0.5, 1, 0)
    f1 = f1_score(label.reshape(-1), pred.reshape(-1))
    
    unirna_score.append(f1)
    pk_metrics = compute_expected_accuracy_pk(adjacency_matrix_to_bpseq(pred), adjacency_matrix_to_bpseq(label))
                                              
    if pk_metrics[-1]:
        unirna_length_list_pk.append(len(data_label[i]['seq']))
        unirna_score_pk.append(f1)
        unirna_score_cb.append(pk_metrics[0][-1])
        unirna_pk_index.append(i)

    pred_fm = np.array(result_fm["label"][i])
    pred_fm = np.where(pred_fm > 0.5, 1, 0)
    f1 = f1_score(label.reshape(-1), pred_fm.reshape(-1))
    fm_score.append(f1)
    pk_metrics = compute_expected_accuracy_pk(adjacency_matrix_to_bpseq(pred_fm), adjacency_matrix_to_bpseq(label))
    
    if pk_metrics[-1]:
        fm_score_pk.append(f1)
        fm_score_cb.append(pk_metrics[0][-1])

  self.experiment_id = get_git_hash() or defaults.DEFAULT_EXPERIMENT_ID
  return torch.tensor(vals).squeeze(1).int()
100%|██████████| 2914/2914 [00:25<00:00, 115.04it/s]


### IPknot

In [None]:
data_list = "/home/siduanmiao/ipknot_run/ipknot_wx_test/sorted_file/bpRNA-PK-TS0_sorted.pkl"
pred_ipknot = "/home/siduanmiao/ipknot_run/ipknot_wx_test/sorted_file/bpRNA-PK-TS0_predict_sorted.pkl"

data_label = pickle.load(open(data_list, "rb"))
pred_all = pickle.load(open(pred_ipknot, "rb"))

score = []
score_pk = []
score_cb = []
length_list = []
length_list_pk = []

for i in range(len(data_label)):
    id = data_label[i]['id']
    label = data_label[i]['label']
    pred = pred_all[i]["label"]
    length_list.append(len(data_label[i]['seq']))
    pred = np.where(pred > 0.5, 1, 0)
    f1 = f1_score(label.reshape(-1), pred.reshape(-1))
    
    score.append(f1)
    pk_metrics = compute_expected_accuracy_pk(adjacency_matrix_to_bpseq(pred), adjacency_matrix_to_bpseq(label))
                                              
    if pk_metrics[-1]:
        length_list_pk.append(len(data_label[i]['seq']))
        score_pk.append(f1)
        score_cb.append(pk_metrics[0][-1])
        
with open("result_ipknot_bpRNA.pkl", "wb+") as f:
    pickle.dump({
        "score": score,
        "score_pk": score_pk,
        "score_cb": score_cb,
        "length_list": length_list,
        "length_list_pk": length_list_pk
    }, f)

In [5]:
with open("result_ipknot_bpRNA.pkl", "rb") as f:
    result_ipknot = pickle.load(f)

In [6]:
def filter_by_length(data, max_length=1000):
    # 提取长度小于等于1000的点
    filtered_score = [s for s, l in zip(data['score'], data['length_list']) if l <= max_length]
    filtered_score_pk = [s for s, l in zip(data['score_pk'], data['length_list_pk']) if l <= max_length]
    filtered_score_cb = [s for s, l in zip(data['score_cb'], data['length_list_pk']) if l <= max_length]

    # 返回新的字典
    return {
        "score": filtered_score,
        "score_pk": filtered_score_pk,
        "score_cb": filtered_score_cb,
        "length_list": [l for l in data['length_list'] if l <= max_length],
        "length_list_pk": [l for l in data['length_list_pk'] if l <= max_length]
    }

filter_ipknot = filter_by_length(result_ipknot)

### Alternative methods

In [None]:
data_dir = "/home/siduanmiao/benchmark/changeformatdata/"
benchmark_name = "benchmark_8.pkl"
label_path = "/mnt/siduanmiao/RNAstru_benchmark_ipknot/benchmark_8/test.pkl"
all_method = os.listdir(data_dir)
methods_list = []

for method in all_method:
    if os.path.exists(os.path.join(data_dir, method, benchmark_name)):
        methods_list.append(method)
        
methods_list

In [None]:
import os
import pickle 
import numpy as np
import warnings
from sklearn.metrics import f1_score

warnings.filterwarnings("ignore")
data_label = pickle.load(open(label_path, "rb"))

methods_pred = []
methods_score = []
methods_score_pk = []
methods_score_pk_cb = []

length_list = []
length_list_pk = []

for method in methods_list:
    with open(os.path.join(data_dir, method, benchmark_name), "rb") as f:
        pred_all = pickle.load(f)
        
    print(f"processing {method}")
    single_method_score = []
    single_method_score_pk = []
    single_method_score_cb = []
    
    for i in range(len(data_label)):
        id = data_label[i]['id']
        label = data_label[i]['label']
        length_list.append(len(data_label[i]['seq']))
        pred = np.array(pred_all[i]["label"])
        pred = np.where(pred > 0.5, 1, 0)
        
        f1 = f1_score(label.reshape(-1), pred.reshape(-1))
        
        single_method_score.append(f1)
        pk_metrics = compute_expected_accuracy_pk(adjacency_matrix_to_bpseq(pred), adjacency_matrix_to_bpseq(label))
                                                
        if pk_metrics[-1]:
            length_list_pk.append(len(data_label[i]['seq']))
            single_method_score_pk.append(f1)
            single_method_score_cb.append(pk_metrics[0][-1])
            
    methods_score.append(single_method_score)
    methods_score_pk.append(single_method_score_pk)
    methods_score_pk_cb.append(single_method_score_cb)
    
with open("data-for-bpRNA.pkl", "wb+") as f:
    pickle.dump({
        "methods_list": methods_list,
        "methods_score": methods_score,
        "methods_score_pk": methods_score_pk,
        "methods_score_pk_cb": methods_score_pk_cb,
        "length_list": length_list,
        "length_list_pk": length_list_pk
    }, f)

In [9]:
with open("data-for-bpRNA.pkl", "rb") as f:
    methods_result = pickle.load(f)

In [11]:
def filter_dict_by_length(data, max_length=1000):
    # 提取长度小于等于1000的点
    methods_score_filter = []
    methods_score_pk_filter = []
    methods_score_pk_cb_filter = []
    methods_list = data["methods_list"]
    for index in range(len(methods_list)):
        single_method_score = []
        single_method_score_pk = []
        single_method_score_pk_cb = []
        
        single_method_score = [s for s, l in zip(data["methods_score"][index], data["length_list"]) if l <= max_length]
        single_method_score_pk = [s for s, l in zip(data["methods_score_pk"][index], data["length_list_pk"]) if l <= max_length]
        single_method_score_pk_cb = [s for s, l in zip(data["methods_score_pk_cb"][index], data["length_list_pk"]) if l <= max_length]
        
        
        methods_score_filter.append(single_method_score)
        methods_score_pk_filter.append(single_method_score_pk)
        methods_score_pk_cb_filter.append(single_method_score_pk_cb)
        
    # 返回新的字典
    return {
        "methods_list": data["methods_list"],
        "methods_score": methods_score_filter,
        "methods_score_pk": methods_score_pk_filter,
        "methods_score_pk_cb": methods_score_pk_cb_filter,
        "length_list": [l for l in data['length_list'] if l <= max_length],
        "length_list_pk": [l for l in data['length_list_pk'] if l <= max_length]
    }

methods_result_filter = filter_dict_by_length(methods_result)

In [None]:
color_dict = {
        "UniRNA": "#D95F02",
        "mxfold2": "#C06FA9",
        "RNA-FM": "#B78FB2",
        "mxfold": "#97B3C5",
        "linearFold": "#B1A471",
        "RNAfold": "#70A56E",
        "RNAstructure": "#E5ADA8",
        "ProbKnot": "#E96692",
        "contrafold": "#E4B488",
    }

## Comparsion

In [None]:
import pandas as pd
# ipknot
df = pd.DataFrame({"length": result_ipknot["length_list"], "f1": result_ipknot["score"]})
df_pk = pd.DataFrame({"length": result_ipknot["length_list_pk"], "f1": result_ipknot["score_pk"]})
df_cb = pd.DataFrame({"length": result_ipknot["length_list_pk"], "f1": result_ipknot["score_cb"]})

# 根据“length”列划分为三组
def group_by_df(df):
    bins = [0, 150, 500, 1500]
    labels = ['<=150', '151-500', '500-1500']
    df['length_group'] = pd.cut(df['length'], bins=bins, labels=labels)

    # 计算每组“f1”列的平均值
    grouped_df = df.groupby('length_group')['f1'].mean()
    return grouped_df

# Uni-RNA
df = pd.DataFrame({"length": length_list, "f1": unirna_score})
df_pk = pd.DataFrame({"length": length_list_pk, "f1": unirna_score_pk})
df_cb = pd.DataFrame({"length": length_list_pk, "f1": unirna_score_cb})

# 根据“length”列划分为三组
def group_by_df(df):
    bins = [0, 150, 500, 1000]
    labels = ['<=150', '151-500', '>500']
    df['length_group'] = pd.cut(df['length'], bins=bins, labels=labels)

    # 计算每组“f1”列的平均值
    grouped_df = df.groupby('length_group')['f1'].mean()
    return grouped_df

In [18]:
len(unirna_length_list)

2914

In [12]:
print("Uni-RNA: ", np.mean(unirna_score), np.mean(unirna_score_pk), np.mean(unirna_score_cb))
print("RNA-FM: ", np.mean(fm_score), np.mean(fm_score_pk), np.mean(fm_score_cb))
print("ipknot", np.mean(filter_ipknot["score"]), np.mean(filter_ipknot["score_pk"]), np.mean(filter_ipknot["score_cb"]))

for i in range(len(methods_result_filter["methods_list"])):
    print(methods_result_filter["methods_list"][i], np.mean(methods_result_filter["methods_score"][i]), np.mean(methods_result_filter["methods_score_pk"][i]), np.mean(methods_result_filter["methods_score_pk_cb"][i]))

Uni-RNA:  0.6437240421502223 0.5299536279544587 0.27277472472767167
RNA-FM:  0.557692691257335 0.4093522664761831 0.22080616411663256
ipknot 0.5181428947763839 0.4504619597344123 0.05438228276374786
ProbKnot 0.4899316384006009 0.42580240467625496 0.0210065269189357
mxfold 0.5183789697411939 0.4723626099453287 0.0
RNAstructure 0.4866155359567672 0.40278444067066416 0.0
SPOTRNA 0.582678901620983 0.4852052571285208 0.13525588484395434
linearFold 0.49596426509693303 0.3974546280853714 0.0
RNAfold 0.4910309607991334 0.40404207837945516 0.0
contrafold 0.4789496909254144 0.3933388544445596 0.0
