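Implementation of the `DotbComparator` class, which scores predicted RNA secondary structures in dot-bracket notation against a reference structure.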

In [88]:
import math

class DotbComparator:
    """Score predicted RNA secondary structures against a reference one.

    Structures are strings in dot-bracket notation: ``.`` marks an unpaired
    position and ``()``, ``[]``, ``{}`` or ``<>`` enclose a base pair.
    All bracket types share one nesting stack, so crossing pairs written
    with mixed bracket types (e.g. ``"([)]"``) are rejected as unbalanced.

    Args:
        ref_dotb: Reference structure.
        ref_sequence: Nucleotide sequence aligned position by position with
            ``ref_dotb``. Only needed when ``canonical_only`` is true.
        canonical_only: If true, pairs other than AU/UA, GC/CG and GU/UG are
            removed from both reference and prediction before scoring.
    """

    OPEN_BRACKETS = "([{<"
    CLOSE_BRACKETS = ")]}>"
    CANONICAL_PAIRS = frozenset({"AU", "UA", "GC", "CG", "GU", "UG"})

    def __init__(self, ref_dotb, ref_sequence=None, canonical_only=False):
        self.ref_dotb = ref_dotb
        self.ref_sequence = ref_sequence
        self.canonical_only = canonical_only
        # None when ref_dotb is unbalanced; compare_structures reports that
        # as a ValueError.
        self.ref_pairs = self._get_pairs(self.ref_dotb)

    def compare_structures(self, pred_dotb):
        """Compare ``pred_dotb`` with the reference structure.

        Args:
            pred_dotb: Predicted structure, same length as the reference.

        Returns:
            dict with keys:
              ``TP`` / ``FP`` / ``FN``: number of pairs in both structures /
              only in the prediction / only in the reference;
              ``TN``: remaining unordered position pairs paired in neither;
              ``sens``: TP / (TP + FN), truncated to two decimals;
              ``PPV``: TP / (TP + FP), not truncated;
              ``MCC``: sqrt(PPV * sensitivity) from untruncated values,
              truncated to two decimals;
              ``MCC_star``: Matthews correlation coefficient computed from
              the full TP/FP/FN/TN confusion matrix.
            Ratios whose denominator is zero are reported as 0.

        Raises:
            ValueError: if the structures differ in length or are not
                balanced, or (with ``canonical_only``) if the reference
                sequence is missing or does not match the structure length.
        """
        ref_dotb = self.ref_dotb
        if len(pred_dotb) != len(ref_dotb):
            raise ValueError(
                "Reference and prediction structures have different lengths")
        if not self._is_balanced(ref_dotb) or not self._is_balanced(pred_dotb):
            raise ValueError(
                "Reference and prediction structures are not balanced")

        ref_pairs = self.ref_pairs
        if self.canonical_only:
            # Filter local copies only: the stored reference stays intact so
            # repeated calls agree, and the reference pair set is rebuilt
            # from the filtered string instead of reusing the unfiltered one.
            ref_dotb = self.remove_noncanonical_pairs(
                self.ref_sequence, ref_dotb)
            pred_dotb = self.remove_noncanonical_pairs(
                self.ref_sequence, pred_dotb)
            ref_pairs = self._get_pairs(ref_dotb)

        pred_pairs = self._get_pairs(pred_dotb)

        TP = len(ref_pairs & pred_pairs)
        FP = len(pred_pairs - ref_pairs)
        FN = len(ref_pairs - pred_pairs)
        n = len(ref_dotb)
        # Every unordered position pair that is not a TP/FP/FN is a TN.
        TN = n * (n - 1) // 2 - TP - FP - FN

        raw_sens = TP / (TP + FN) if TP + FN > 0 else 0
        PPV = TP / (TP + FP) if TP + FP != 0 else 0
        sens = self.rounded_num(raw_sens)
        # Truncate once, at the end, rather than feeding a truncated
        # sensitivity into the square root.
        MCC = self.rounded_num(math.sqrt(PPV * raw_sens))
        denom = (TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)
        MCC_star = (TP * TN - FP * FN) / math.sqrt(denom) if denom != 0 else 0

        return {
            "TP": TP,
            "FP": FP,
            "FN": FN,
            "TN": TN,
            "sens": sens,
            "PPV": PPV,
            "MCC": MCC,
            "MCC_star": MCC_star,
        }

    def _get_pairs(self, dotb):
        """Return the base pairs of ``dotb``, or None if it is unbalanced.

        Each pair is a ``frozenset({i, j})`` of 0-based positions. Characters
        that are not brackets are treated as unpaired.
        """
        stack = []
        pairs = set()
        opener_for = dict(zip(self.CLOSE_BRACKETS, self.OPEN_BRACKETS))
        for i, c in enumerate(dotb):
            if c in self.OPEN_BRACKETS:
                stack.append(i)
            elif c in opener_for:
                if not stack:
                    return None  # closing bracket with nothing open
                j = stack.pop()
                if dotb[j] != opener_for[c]:
                    return None  # mismatched bracket types, e.g. "(]"
                pairs.add(frozenset((j, i)))
        if stack:
            return None  # opening bracket never closed
        return pairs

    def _is_balanced(self, exp):
        """Return True if ``exp`` holds only ``.`` and properly nested brackets."""
        stack = []
        closer_for = dict(zip(self.OPEN_BRACKETS, self.CLOSE_BRACKETS))
        for char in exp:
            if char in closer_for:
                stack.append(char)
            elif char in self.CLOSE_BRACKETS:
                if not stack or closer_for[stack.pop()] != char:
                    return False
            elif char != ".":
                return False
        return not stack

    def rounded_num(self, num):
        """Truncate ``num`` (round toward -inf) to two decimal places."""
        # round() first absorbs binary representation error: 0.29 * 100 is
        # 28.999999999999996, which would otherwise floor to 28.
        return math.floor(round(num * 100, 9)) / 100

    def remove_noncanonical_pairs(self, sequence, dotbracket):
        """Remove non-canonical base pairs from a dot-bracket structure.

        A pair is kept only if its two nucleotides form AU/UA, GC/CG or
        GU/UG. The sequence is upper-cased and T is treated as U, so
        lowercase and DNA sequences are accepted.

        Args:
            sequence: Nucleotide sequence aligned with ``dotbracket``.
            dotbracket: Structure in dot-bracket notation.

        Returns:
            ``dotbracket`` with both brackets of every non-canonical pair
            replaced by ``"."``.

        Raises:
            ValueError: if ``sequence`` is None, its length differs from
                ``dotbracket``, or ``dotbracket`` is unbalanced.
        """
        if sequence is None:
            raise ValueError(
                "A sequence is required to remove non-canonical pairs")
        if len(sequence) != len(dotbracket):
            raise ValueError("Sequence and structure have different lengths")
        pairs = self._get_pairs(dotbracket)
        if pairs is None:
            raise ValueError("Structure is not balanced")

        seq = sequence.upper().replace("T", "U")
        unpaired = set()
        for pair in pairs:
            i, j = sorted(pair)
            if seq[i] + seq[j] not in self.CANONICAL_PAIRS:
                unpaired.update((i, j))
        return "".join(
            "." if k in unpaired else c for k, c in enumerate(dotbracket))






def compare_predictions(seq, ref, predictions, canonical_only):
    """Score every named prediction against the reference structure ``ref``.

    Args:
        seq: Reference sequence, forwarded to ``DotbComparator``.
        ref: Reference dot-bracket structure.
        predictions: Mapping of name -> predicted dot-bracket structure.
        canonical_only: Forwarded to ``DotbComparator``.

    Returns:
        dict mapping each prediction name (in the input's iteration order) to
        the list [TP, FP, FN, TN, sens, PPV, MCC, MCC_star].
    """
    metric_order = ("TP", "FP", "FN", "TN", "sens", "PPV", "MCC", "MCC_star")
    comparator = DotbComparator(ref, seq, canonical_only)

    def _score_row(structure):
        # Flatten the comparator's metric dict into a fixed column order.
        scores = comparator.compare_structures(structure)
        return [scores[key] for key in metric_order]

    return {name: _score_row(structure)
            for name, structure in predictions.items()}



**Testing**

In [89]:
# test

seq = "GGACUCGGGGUGCCCUUCUGCGUGAAGGCUGAGAAAUACCCGUAUCACCUGAUCUGGAUAAUGCCAGCGUAGGGAAGUUC"

xray = "((((((((((((((.((((...)))))))(...).).))))).(..(((((..((((......)))).).))))))))))"

prog1 = "(((((((((..((.((((.....)))))).........))))....((((((.((((((.)).))))).))))).)))))"
prog2 = ".((..((((..(((.(((.....)))))).........))))..)).((((..((((......))))..))))......."
prog3 = "((((((((((..))))((((((((..(((((.(((((((..))))))))(((((..)))))))))))))))))).)))))"

predictions = {"xray": xray, "prog1": prog1, "prog2": prog2, "prog3": prog3}
results = compare_predictions(seq, xray, predictions, False)

print("        MCC    TP   FP   FN   TN  sens  PPV")
for name, values in results.items():
    print(
        f"{name:6}  {values[6]:.2f}  {values[0]:3}  {values[1]:3}  {values[2]:3}  {values[3]:4}  {values[4]:.2f}  {values[5]:.2f}")

((((((((((((((.((((...)))))))(...).).))))).(..(((((..((((......)))).).))))))))))
(((((((((..((.((((.....)))))).........))))....((((((.((((((.)).))))).))))).)))))
.((..((((..(((.(((.....)))))).........))))..)).((((..((((......))))..)))).......
((((((((((..))))((((((((..(((((.(((((((..))))))))(((((..)))))))))))))))))).)))))
        MCC    TP   FP   FN   TN  sens  PPV
xray    1.00   29    0    0  3131  1.00  1.00
prog1   0.78   22    5    7  3126  0.75  0.81
prog2   0.70   17    3   12  3128  0.58  0.85
prog3   0.15    5   30   24  3101  0.17  0.14
