# Evaluating predictive performance

Compare a prediction given as a GFF file to a curated 'true' annotation (also a GFF file).


In [None]:
# Input files, both in GFF format: the predicted annotation to evaluate and the curated ("true") annotation.
prediction_filename = "september-final-benchmarks/testing_my_ghmm_september16-day2020-09-16_101157.predicted.gff"
truth_filename = "september-final-benchmarks/testing.gff"

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def ranges_overlap(a, A, b, B):
    """Return True if the closed ranges [a, A] and [b, B] share at least one position."""
    first_starts_before_second_ends = a <= B
    second_starts_before_first_ends = b <= A
    return first_starts_before_second_ends and second_starts_before_first_ends

def contains(a, A, b):
    """Return True if position b lies inside the closed range [a, A]."""
    return a <= b <= A

In [None]:
# Per-position class labels used by the evaluation; feature_names[i] is the display name of label i.
feature2index = dict(NCS=0, CDS=1, intron=2, stop_codon=3)
feature_names = ["NCS", "CDS", "intron", "stop"]

# This code is redundant with the code used in some other notebooks, but duplicated to be able to use them on their own
complement = dict(A="T", T="A", G="C", C="G", N="N")
base_order = "TCAG"
# All 64 codons, with the first base varying slowest (so codon2index is a base-4 number over base_order)
codon_list = [first + second + third
              for first in base_order
              for second in base_order
              for third in base_order]

rna_codon_list = [codon.replace("T", "U") for codon in codon_list]

def codon2index(codon):
    """
    Return the position of `codon` in codon_list.

    The first three bases are read as base-4 digits (digit values taken from base_order),
    most significant first. An unknown base raises ValueError and a codon shorter than
    three bases raises IndexError.
    """
    index = 0
    for position in range(3):
        index = index * 4 + base_order.index(codon[position])
    return index


def utr2stop_codon(feature):
    """
    Turn a three_prime_UTR feature into the corresponding stop codon, in place.

    The feature is relabelled "stop_codon" and shrunk to its three 5'-most nucleotides:
    the first three coordinates on the plus strand, the last three on the minus strand.
    NOTE(review): this assumes the curated UTR feature begins with the stop codon — confirm
    against the annotation's conventions. Nothing is returned.
    """
    feature.feature = "stop_codon"
    if not feature.strand:
        feature.start = feature.end - 2
        return
    feature.end = feature.start + 2
    

def rev_comp(sequence: str):
    """
    Return the reverse complement of a DNA sequence.

    Bases are looked up in the module-level `complement` table; a character missing from it
    raises KeyError.
    """
    # Joining once is linear; the previous repeated `result +=` on a str could be quadratic.
    return "".join(complement[base] for base in reversed(sequence))

class Feature:
    """
    One annotated interval of a GFF file: its type (e.g. "CDS", "intron", "stop_codon"),
    its 1-based end-inclusive start/end and its strand (True for "+", False for "-").
    """

    def __init__(self, feature: str, start: int, end: int, strand: bool):
        self.start, self.end, self.strand, self.feature = start, end, strand, feature
        
class Gene:
    """
    A gene model: a run of features on one strand.

    `is_valid` checks that, read 5' to 3', the features form
    CDS, any number of (intron, CDS) pairs, then exactly one three_prime_UTR/stop_codon.
    """

    # Feature types accepted as the 3'-terminal element of a gene.
    _TERMINAL_FEATURES = ("three_prime_UTR", "stop_codon")

    def __init__(self, strand: bool):
        self.strand = strand  # True for "+", False for "-"
        # CDS, introns and one UTR/stop codon in ascending genomic coordinate order, i.e. 5'->3' on
        # the plus strand but 3'->5' on the minus strand (is_valid relies on this).
        # Sorting is trusted to the outside world; is_valid only checks the feature pattern.
        self.features = []

    def append(self, feature: "Feature"):
        """Add `feature` after the current last feature (the caller keeps coordinate order)."""
        self.features.append(feature)

    def __len__(self):
        return len(self.features)

    def __str__(self):
        # "[start type end>" for plus-strand features, "<start type end]" for minus-strand ones
        return " ".join(["%s%s %s %s%s" % ("[" if f.strand else "<",
                                           f.start, f.feature, f.end,
                                           ">" if f.strand else "]")
                         for f in self.features])

    def get_range(self):
        """Return (start of the first feature, end of the last feature); raises IndexError if empty."""
        return self.features[0].start, self.features[-1].end

    def is_valid(self):
        """
        Check the gene structure, printing a "! ..." diagnostic for the first violation found.

        Returns True if, read 5' to 3', the gene is CDS, (intron, CDS)*, then exactly one
        three_prime_UTR/stop_codon, with every feature on the gene's strand; False otherwise.
        """
        if len(self.features) == 0:
            print("! The gene is empty")
            return False

        # Features are in coordinate order: the 3' end is the last feature on "+", the first on "-".
        utr_index = -1 if self.strand else 0
        first_cds_index = 0 if self.strand else -1

        # a Gene must end (3') in a UTR / stop codon
        if self.features[utr_index].feature not in self._TERMINAL_FEATURES:
            print("! 3' Terminal feature is not a UTR, but a %s" % self.features[utr_index].feature)
            return False
        # a Gene must start (5') with a CDS
        if self.features[first_cds_index].feature != "CDS":
            print("! 5' terminal feature is not a CDS, but a %s" % self.features[first_cds_index].feature)
            return False

        # CDS, (intron, CDS)*, UTR always has an even number of features, so it can be checked pairwise
        # (no two CDS adjacent, nor two introns)
        if len(self.features) % 2 != 0:
            print("! uneven number of features")
            return False

        allowed_after_cds = ("intron",) + self._TERMINAL_FEATURES
        num_utrs = 0
        for index in range(0, len(self.features), 2):
            left, right = self.features[index], self.features[index + 1]
            num_utrs += (left.feature in self._TERMINAL_FEATURES) + (right.feature in self._TERMINAL_FEATURES)

            if left.strand != self.strand or right.strand != self.strand:
                print("! Strand disagreement within the gene")
                return False

            # In coordinate order each pair is (CDS, intron/terminal) on "+" and (intron/terminal, CDS) on "-".
            cds, follower = (left, right) if self.strand else (right, left)
            if not (cds.feature == "CDS" and follower.feature in allowed_after_cds):
                print("! CDS-intron-CDS-UTR-pattern violated")
                return False

        if num_utrs != 1:
            print("! Too many utrs: %s" % num_utrs)
            return False

        return True
    
class Annotation:
    """
    The features and genes of one sequence (node) spanning [start, end].

    `add` keeps the features in coordinate order; `compile_genes` groups runs of contiguous
    features into Gene objects.
    """

    def __init__(self, name: str, start: int, end: int):
        self.name = name
        self.start = start
        self.end = end
        self.features = []  # kept in coordinate order by `add`; could also separate this by strand
        self.genes = []  # rebuilt by compile_genes

    def add(self, feature: "Feature"):
        """
        Insert `feature` before the first stored feature that starts after `feature` ends,
        or append it if there is none (keeps non-overlapping features sorted by position).
        """
        for index, existing in enumerate(self.features):
            if existing.start > feature.end:
                self.features.insert(index, feature)
                return  # stop iterating right after mutating the list
        self.features.append(feature)

    def overlaps(self, start: int, end: int, tolerance=300) -> bool:
        """
        Checks whether the given range [start, end] overlaps with any of the genes held by self, extended by
        tolerance nucleotides to both sides
        """
        for gene in self.genes:
            g_start, g_end = gene.get_range()
            if g_start - tolerance <= end and start <= g_end + tolerance:
                return True
        return False

    def compile_genes(self) -> bool:
        """
        Rebuild self.genes by splitting the ordered features wherever a feature does not start
        directly after the previous one ends, then validate every resulting gene.

        Returns True if every gene passes Gene.is_valid (True when there are no features).
        Every gene is validated, so diagnostics are printed for all invalid genes, not just the first.
        """
        self.genes = []
        if not self.features:
            return True

        all_valid = True
        last_end = self.features[0].start - 1
        current_gene = Gene(self.features[0].strand)
        for f in self.features:
            if f.start != last_end + 1:  # there is a break: close the current gene
                self.genes.append(current_gene)
                # Validate first so `and` cannot skip is_valid (and its diagnostics) once a gene failed.
                all_valid = current_gene.is_valid() and all_valid
                current_gene = Gene(f.strand)

            current_gene.append(f)
            last_end = f.end

        if len(current_gene) > 0:
            self.genes.append(current_gene)
            all_valid = current_gene.is_valid() and all_valid

        return all_valid

    def __str__(self):
        header = self.name + ("[%s,%s]:\n\t" % (self.start, self.end))
        rows = ["[%s, %s]%s%s" % (f.start, f.end, f.feature, "+" if f.strand else "-")
                for f in self.features]
        return header + "\n\t".join(rows)

## Collect 'truth'

In [None]:
# Parse the curated GFF: each "##sequence-region" directive opens a new Annotation (one per node),
# and every following feature line is added to it. 3' UTRs are converted to stop codons.
curated_annotation = {}
c = 0  # number of sequence regions read
number_of_stops = 0
faulty = 0  # number of annotations with at least one invalid gene


def _store_curated(annotation):
    """Compile `annotation`'s genes, print them, store it in curated_annotation and return its validity."""
    correct = annotation.compile_genes()
    print(correct)
    print("Genes:\n\t" + "\n\t".join([str(g) for g in annotation.genes]))
    curated_annotation[annotation.name] = annotation
    return correct


with open(truth_filename) as annotation_file:
    current = None
    # Iterating the file reads every line. The former `while line:` readline loop skipped the
    # first line, stopped at the first blank line, and `[:-1]` cut a character off a final line
    # that had no trailing newline.
    for raw_line in annotation_file:
        line = raw_line.rstrip("\r\n")
        if line.startswith("##FASTA"):
            break  # GFF3: everything after this directive is sequence data, not features
        if line.startswith("##sequence-region"):
            print(current)
            if current is not None and not _store_curated(current):
                faulty += 1
            print("-  " * 36)

            # "##sequence-region seqid start end" may be separated by tabs or spaces
            content = line.split()
            current = Annotation(content[1], int(content[2]), int(content[3]))
            c += 1
        elif len(line) > 0 and not line.startswith("#"):
            if current is None:
                raise ValueError("Feature line before the first ##sequence-region: %r" % line)
            content = line.split("\t")
            f = Feature(content[2], int(content[3]), int(content[4]), content[6] == "+")
            if f.feature == "three_prime_UTR":
                utr2stop_codon(f)

            if f.feature == "stop_codon":
                number_of_stops += 1
            current.add(f)

# The last annotation is only closed at the end of the file; count it towards `faulty` as well.
if current is not None and not _store_curated(current):
    faulty += 1
print(c, "=", len(curated_annotation), "of which", faulty, "were faulty")
print("\nThere are %s stop-codons (and thus genes) annotated" % number_of_stops)

### The replacement matrix
To quantify the quality of the prediction, we collect a replacement matrix: its rows correspond to what ought to be (what is found in the 'truth'), while its columns capture what is (what is found in the prediction file). 

The row-labels are:
 - NCS (can be computed from the others + knowledge of length)
 - CDS (frame information? -> could split this into: CDS in correct frame and CDS in wrong frame)
 - intron
 - stop_codon

In [None]:
# Parse the prediction GFF. Feature lines start with the node name ("NODE..."); whenever the node
# name changes, the current Annotation is closed. The node length is read from the 4th "_"-separated
# field of the node name.
predicted_annotation = {}
report = ""
report += prediction_filename + "\n\n"


def _store_predicted(head, annotation):
    """Compile `annotation`'s genes, print its validity and contents, and store it in predicted_annotation."""
    correct = annotation.compile_genes()
    print(head, "is", "valid" if correct else "invalid")
    print(annotation)
    predicted_annotation[head] = annotation


with open(prediction_filename) as prediction:
    current_head = None
    current = None
    for line in prediction:
        if not line.startswith("NODE"):
            continue
        contents = line.split("\t")
        start = int(contents[3])
        end = int(contents[4])
        if contents[0] != current_head:
            if current_head is not None:
                _store_predicted(current_head, current)
            current_head = contents[0]
            current = Annotation(current_head, 0, int(current_head.split("_")[3]))

        # these feature types are not part of the comparison
        if contents[2] in ["gene", "transcript", "start_codon"]:
            continue
        current.add(Feature(contents[2], start, end, contents[6] == "+"))

    # The last node is only closed at end of file; compile its genes like all the others
    # (previously it was stored without compile_genes, and as None when no NODE line existed).
    if current_head is not None:
        _store_predicted(current_head, current)

In [None]:
# Only nodes present in both the curated annotation and the prediction can be compared.
what_can_be_compared = set(curated_annotation) & set(predicted_annotation)

# Create the per-position encoding arrays and fill in the annotation (row `[0,:]`)

In [None]:
# Per-position encoding of every curated node: row 0 holds the true label (feature2index values),
# row 1 the predicted label (filled in the next cell; zeros, i.e. NCS, where nothing is predicted).
# -0.5 in row 0 marks positions without known truth: only positions within `flanking` nt of an
# annotated feature are labelled.
encoded = {}
flanking = 300
for node, annotation in curated_annotation.items():
    encoded[node] = np.zeros((2, int(node.split("_")[3])))
    encoded[node][0, :] = -0.5
    # first run: mark the NCS-areas around the features. Clamp the start at 0: a negative slice
    # start would wrap around to the end of the array and silently mark nothing near the node start.
    for feature in annotation.features:
        encoded[node][0, max(0, feature.start - 1 - flanking):feature.end + flanking] = feature2index["NCS"]

    # second run: the features themselves (GFF coordinates are 1-based and end-inclusive)
    for feature in annotation.features:
        encoded[node][0, feature.start - 1:feature.end] = feature2index[feature.feature]

In [None]:
# Fill in the prediction (row [1,:]) for every node that also has a curated encoding.
for node, annotation in predicted_annotation.items():
    if node not in encoded:
        continue  # no curated annotation for this node
    for feature in annotation.features:
        # GFF coordinates are 1-based and end-inclusive
        encoded[node][1, feature.start - 1:feature.end] = feature2index[feature.feature]

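## Compute sensitivity and specificity (and Stanke's specificity, i.e. precision) from the per-position encoding
(this variant scans every position of every comparable node, not only the windows around annotated genes)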

In [None]:
# Per-position comparison of truth (row 0) and prediction (row 1) on every comparable node.
# replacement_matrix[t, p] counts positions with true label t and predicted label p.
# confusion_matrix treats CDS as the positive class (row = prediction, column = truth):
#        p    n
#   Y | TP | FP |
#   N | FN | TN |
replacement_matrix = np.zeros((4, 4))
confusion_matrix = np.zeros((2, 2))  # [Y-N, p-n]
FP = 0
TP = 0
FN = 0
TN = 0
total_cds = 0  # predicted CDS positions, including positions without known truth
cds_label = feature2index["CDS"]
for node in what_can_be_compared:
    n_positions = predicted_annotation[node].end
    truth_row = encoded[node][0, :n_positions]
    pred_row = encoded[node][1, :n_positions].astype(int)
    total_cds += int(np.count_nonzero(pred_row == cds_label))

    # Vectorised over all positions (the former per-position Python loop was very slow);
    # positions with negative truth (-0.5, "no truth known") are skipped.
    known = truth_row >= 0
    t = truth_row[known].astype(int)
    p = pred_row[known]
    np.add.at(replacement_matrix, (t, p), 1)  # unbuffered, so repeated (t, p) pairs all count

    predicted_cds = p == cds_label
    true_cds = t == cds_label
    TP += int(np.count_nonzero(predicted_cds & true_cds))
    FP += int(np.count_nonzero(predicted_cds & ~true_cds))
    FN += int(np.count_nonzero(~predicted_cds & true_cds))
    TN += int(np.count_nonzero(~predicted_cds & ~true_cds))

confusion_matrix[0, 0] = TP
confusion_matrix[0, 1] = FP
confusion_matrix[1, 0] = FN
confusion_matrix[1, 1] = TN
print(replacement_matrix) # this should be the same across both
print(confusion_matrix)

In [None]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or nan when the denominator is 0 (instead of raising)."""
    return numerator / denominator if denominator else float("nan")


# Position-level metrics with CDS as the positive class
byte_sensitivity = _safe_ratio(TP, TP + FN)
byte_specificity = _safe_ratio(TN, TN + FP)
byte_precision = _safe_ratio(TP, TP + FP)  # Stanke's "specificity"
report += "sensitivity = %s\nspecificity = %s\nstanke/precision = %s\n" % (byte_sensitivity, byte_specificity, byte_precision)
print("sensitivity:", byte_sensitivity)
print("specificity:", byte_specificity)
print("stanke:     ", byte_precision)

## Variant 2: Compute replacement matrix from annotations

This requires both the annotation and the prediction to be sorted, since it reads through each of them only once, from left to right. This assumption holds for Augustus predictions, my predictions, and my annotation.

Then, for every node, the prediction is filtered to the features lying within 300 nt of any gene annotated on that node (collected in `relevant_prediction`), and every annotated gene is compared position by position over its range extended by 300 nt on both sides.

For every gene in the annotation, two indices are held:
 - `pred_index`: against which predicted feature is the current feature being compared?
 - `g_index`: which of the gene's features is currently being compared?
 
Both start at the beginning of their feature lists (index 0). Then all nucleotide positions `i` are visited in order, and as soon as `i` reaches the end of the current predicted (or annotated) feature, the corresponding index is advanced to the next feature (if there is one). **Note** that since GFF coordinates are 1-based and end-inclusive, this advancing is done at the end of the loop body, not at the start (draw it for yourself).

In [None]:
replacement_matrix = np.zeros((4, 4))
total_cds = 0
# Positions compared on either side of each annotated gene; should not exceed the default
# tolerance of Annotation.overlaps (300), which decides what ends up in relevant_prediction.
window = 300

for node in what_can_be_compared:
    truth = curated_annotation[node]
    prediction = predicted_annotation[node]
    relevant_prediction = [f for f in prediction.features if truth.overlaps(f.start, f.end)]
    last_pred = len(relevant_prediction) - 1
    for gene in truth.genes:
        g_start, g_end = gene.get_range()
        first_position = g_start - window

        # Skip predicted features that end before this gene's window (e.g. those of earlier genes).
        # Without this, pred_index only caught up one feature per position, so predicted features
        # inside the window were compared too late and counted as NCS.
        pred_index = 0
        while pred_index < last_pred and relevant_prediction[pred_index].end < first_position:
            pred_index += 1
        g_index = 0

        for i in range(first_position, g_end + window + 1):
            # label predicted at position i
            current_pred = relevant_prediction[pred_index] if pred_index <= last_pred else None
            if current_pred is not None and contains(current_pred.start, current_pred.end, i):
                predicted_feature = feature2index[current_pred.feature]
            else:
                predicted_feature = feature2index["NCS"]

            # label annotated at position i
            current_annotated = gene.features[g_index]
            if contains(current_annotated.start, current_annotated.end, i):
                annotated_feature = feature2index[current_annotated.feature]
            else:
                annotated_feature = feature2index["NCS"]

            replacement_matrix[annotated_feature, predicted_feature] += 1

            # GFF ends are inclusive: a feature ending at i is left behind only after i was counted.
            # `while` (not `if`) so that several features ending at or before i are all skipped.
            while pred_index < last_pred and relevant_prediction[pred_index].end <= i:
                pred_index += 1
            # Gene features are contiguous (compile_genes splits at gaps), so one step suffices.
            if gene.features[g_index].end <= i and g_index + 1 < len(gene.features):
                g_index += 1

print(replacement_matrix)

In [None]:
spacer = " & "
print(prediction_filename)


def _latex_table(title, matrix):
    """Render `matrix` as LaTeX table rows, one per entry of feature_names, preceded by a header line."""
    lines = ["\n" + title + spacer + spacer.join(feature_names) + "\\\\\\hline\n"]
    for name, values in zip(feature_names, matrix):
        lines.append(name + spacer + spacer.join([str(np.round(e, 5)) for e in values]) + "\\\\\n")
    return "".join(lines)


# whither: each row (true label) normalised to sum 1 -> where do the true labels end up?
whither_matrix = replacement_matrix / replacement_matrix.sum(axis=1, keepdims=True)
report += _latex_table("whither", whither_matrix)

# whence: each column (predicted label) normalised to sum 1 -> where do the predictions come from?
whence_matrix = replacement_matrix / replacement_matrix.sum(axis=0, keepdims=True)
report += _latex_table("whence", whence_matrix)

cds = feature2index["CDS"]
sensitivity = replacement_matrix[cds, cds] / np.sum(replacement_matrix[cds, :])
stanke_specificity = replacement_matrix[cds, cds] / np.sum(replacement_matrix[:, cds]) # AUGUSTUS: /total_cds, requires other computation

print("\nsensitivity = %s\nstanke-specificity = %s" % (sensitivity, stanke_specificity))

# Feature-level sensitivity and specificity

**p** is the number of *CDS* in the annotated regions
**TP** is the number of *CDS* perfectly predicted

Since **n** (the number of negatives) is not defined in a meaningful or unique way at the feature level, specificity cannot be assessed; we opt for precision instead:

**Y** is the number of *CDS* predicted in the annotated regions (overlapping with it)

Then have:
*sensitivity* = **TP** / **p**
*precision* = **TP** / **Y**

In [None]:
# Feature-level evaluation: a predicted CDS/intron is a true positive only if a curated feature
# of the same type, strand, start and end exists on the same node.
CDS_TP = 0
CDS_p = 0  # curated CDS on comparable nodes
CDS_Y = 0  # predicted CDS near annotated genes

intron_TP = 0
intron_p = 0
intron_Y = 0

for node in what_can_be_compared:
    truth = curated_annotation[node]
    prediction = predicted_annotation[node]
    relevant_prediction = [f for f in prediction.features if truth.overlaps(f.start, f.end)]
    # (type, strand, start, end) of every curated feature: one set lookup per prediction instead
    # of rescanning all curated features each time.
    truth_keys = {(tf.feature, tf.strand, tf.start, tf.end) for tf in truth.features}

    for feature in truth.features:
        if feature.feature == "CDS":
            CDS_p += 1
        elif feature.feature == "intron":
            intron_p += 1

    for feature in relevant_prediction:
        if feature.feature == "CDS":
            CDS_Y += 1
        elif feature.feature == "intron":
            intron_Y += 1
        else:
            continue  # other feature types are not evaluated here

        if (feature.feature, feature.strand, feature.start, feature.end) not in truth_keys:
            continue
        print("Exact match on %s: %s %s%s%s" % (node, feature.feature, feature.start,
                                                "+" if feature.strand else "-",
                                                feature.end))
        if feature.feature == "CDS":
            CDS_TP += 1
        else:
            intron_TP += 1

print(CDS_TP, CDS_p, CDS_Y, intron_TP, intron_p, intron_Y)

In [None]:
import os


def _ratio(numerator, denominator):
    """Return numerator / denominator, or nan when nothing was annotated/predicted (denominator 0)."""
    return numerator / denominator if denominator else float("nan")


report += "\nExon-level:\nsensitivity = %s\nprecision = %s" % (_ratio(CDS_TP, CDS_p), _ratio(CDS_TP, CDS_Y))
report += "\nIntron-level:\nsensitivity = %s\nprecision = %s" % (_ratio(intron_TP, intron_p), _ratio(intron_TP, intron_Y))

print(report)
# Write next to the prediction as <directory>/<file name up to its first dot>.report.txt. Only the
# file name is searched for the dot, so dots in directory names (e.g. "./") no longer truncate the path.
directory, basename = os.path.split(prediction_filename)
report_filename = os.path.join(directory, basename.split(".", 1)[0] + ".report.txt")
with open(report_filename, "w") as outfile:
    outfile.write(report)