# Benchmarking notebook

This notebook contains all of the methods and analyses used for benchmarking.

# Imports and dependencies

In [None]:
import os
import glob
from pathlib import Path
from enum import Enum
import numpy as np
from dataclasses import dataclass
from typing import List, Dict
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import datetime

# Notebook parameters
These parameters are used throughout the notebook for benchmarking. They include file paths, tuning values, and various other settings.

In [None]:
# Output directory for benchmarking figures (machine-specific path; adjust per user).
FIGURE_DIR = Path("/Users/dreyceyalbin/Desktop/Phage-Enrich")
# parents=True creates missing parent directories; exist_ok=True replaces the
# separate isdir() check, avoiding a check-then-create race.
FIGURE_DIR.mkdir(parents=True, exist_ok=True)

# Table whose first tab-separated column lists the NCBI accessions of interest
# (first line is a header; see get_needed_ncbi_ids).
file_with_names = "./abundance.tsv"
# Accession -> taxid mapping file from the Kraken database's taxonomy folder.
accession2taxid_file = "../../database/krakenDB/taxonomy/nucl_gb.accession2taxid"

# Parsers
These methods parse the output of each tool being benchmarked. Each parser is passed to a general parsing function, giving a strategy-pattern-like way of obtaining results from any tool.

In [13]:
def get_needed_ncbi_ids(path=None):
    """
    Read the NCBI accession IDs listed in the abundance table.

    The first line of the file is treated as a header and skipped. The
    accession is taken from the first tab-separated column of every
    remaining non-blank line.

    INPUT:
        path: file to read; defaults to the notebook-level
              ``file_with_names`` so existing calls are unchanged.
    OUTPUT:
        list of accession strings, in file order.
    """
    if path is None:
        path = file_with_names

    ncbi_ids = []
    with open(path) as names_file:
        next(names_file, None)  # skip header row
        for line in names_file:
            line = line.rstrip("\n")
            if not line:
                continue  # ignore blank (e.g. trailing) lines
            ncbi_ids.append(line.split("\t")[0])
    return ncbi_ids

# Load the accession IDs listed in file_with_names and report how many there are.
ncbi_ids = get_needed_ncbi_ids()
print(len(ncbi_ids))

8957


In [None]:
def create_mapping_dictionary(path=None, wanted_ids=None):
    """
    Creates dictionary for mapping NCBI ids to
    taxonomy.

    Reads a nucl_gb.accession2taxid-style file: a header row followed by
    tab-separated lines such as

    NC_033618	NC_033618.1	1931113	1139918407

    Column 2 (accession.version) is the key and column 3 the taxid
    (the 4th column is presumably the gi number; it is not used).
    Only rows whose column-2 value is in ``wanted_ids`` are kept.

    INPUT:
        path: mapping file; defaults to ``accession2taxid_file``.
        wanted_ids: iterable of accession.version strings; defaults to
                    the notebook-level ``ncbi_ids``.
    OUTPUT:
        dict = {accession.version (str): taxid (str)}
    """
    if path is None:
        path = accession2taxid_file
    if wanted_ids is None:
        wanted_ids = ncbi_ids
    # Set lookup is O(1) per line instead of scanning the whole ID list for
    # every row of the (potentially very large) mapping file.
    wanted = set(wanted_ids)

    dictionary = {}
    with open(path) as acc2tax_open:
        next(acc2tax_open, None)  # skip header row
        for line in acc2tax_open:
            fields = line.rstrip("\n").split("\t")
            # Blank/short lines (including the end-of-file case) are skipped
            # rather than raising IndexError.
            if len(fields) < 3:
                continue
            accession, taxid = fields[1], fields[2]
            if accession in wanted:
                dictionary[accession] = taxid
    return dictionary


# Build the accession.version -> taxid lookup restricted to the IDs in ncbi_ids.
accession2taxid = create_mapping_dictionary()

print(accession2taxid)

In [4]:
def parser_enrichseq(debExt_dir, extension="ss3", non_test=False, truth_csv=None):
    """
    Parse every ``*<extension>`` file in ``debExt_dir``.

    Each record starts with a ``>`` header line; the prediction is read
    from the line two below that header.

    INPUT:
        debExt_dir: output directory (str or pathlib.Path).
        extension: filename suffix to match (default "ss3").
        non_test: if True, key records by the file name up to its first
                  "."; otherwise by the header text without ">".
        truth_csv: optional CSV; when given, each identifier's value is
                   taken from ``parse_truth(truth_csv)`` instead of the file.
    OUTPUT:
        dict = {identifier: prediction string}
    """
    prediction_outcomes = {}  # identifier : prediction
    truth_dict = parse_truth(truth_csv) if truth_csv is not None else None
    # os.path.join accepts both str and Path; "+" concatenation fails for Path.
    pattern = os.path.join(debExt_dir, "*" + extension)
    for outputfile in map(Path, glob.glob(pattern)):
        with open(outputfile) as resultsfile:
            resultlines = resultsfile.readlines()
        for line_index, line in enumerate(resultlines):
            if not line.startswith(">"):
                continue
            if non_test:
                identifier = outputfile.name.split(".")[0]  # uses file name
            else:
                identifier = line.strip(">").strip("\n")  # uses header text
            if truth_dict is not None:
                prediction_outcomes[identifier] = truth_dict[identifier]
            else:
                prediction_outcomes[identifier] = resultlines[line_index + 2].strip("\n")
    return prediction_outcomes

def parser_fastviromeexplorer(directory, extension="ss3", non_test=False):
    """
    Placeholder parser (per its name, intended for FastViromeExplorer output).

    Not implemented yet: ``directory``, ``extension`` and ``non_test`` are
    accepted so the signature matches parser_enrichseq, but are ignored.
    OUTPUT:
        a new, empty dict
    """
    return dict()

def ncbi2taxid(ncbi, dictionary):
    """
    Look up the taxid for a single NCBI accession.

    INPUT:
        ncbi: accession key, e.g. "NC_033618.1".
        dictionary: accession -> taxid mapping, such as the one built
                    by create_mapping_dictionary().
    OUTPUT:
        the mapped value; raises KeyError when ``ncbi`` is not present.
    """
    taxid = dictionary[ncbi]
    return taxid

# Results Methods and structures
The methods and data structures here create a common data structure for the output of all tools being compared with one another.

In [6]:
class ResultStruct:
    """
    Data structure holding one tool's benchmarking results.

    Attributes:
        taxid2abundance, taxid2length: per-taxid result dictionaries.
        conf_matrix: (num_classes x num_classes) float confusion matrix.
        precision, recall: metric dictionaries.
        accuracy, length, pdb_names, pdb2length, time: per-sequence
            containers that parse_dir() appends to / assigns.
    """

    def __init__(self, num_classes: int = 2):
        self.taxid2abundance: Dict = {}
        self.taxid2length: Dict = {}
        self.conf_matrix = np.zeros((num_classes, num_classes))
        self.precision = {}
        self.recall = {}
        # Containers used by parse_dir(); without them it raises AttributeError.
        self.accuracy: List = []
        self.length: List[int] = []
        self.pdb_names: List = []
        self.pdb2length: Dict = {}
        self.time: Dict[str, Dict[str, float]] = {}

In [7]:
def parse_dir(directory, truth_csv, parser_function):
    """
    Score every prediction that ``parser_function`` extracts from
    ``directory`` against the truth CSV, collecting the results.

    Steps: parse the tool output, parse the truth CSV, record timings via
    get_times(directory), then score each identifier's prediction with
    prediction_rank() and accumulate the outcome.

    NOTE(review): get_times and prediction_rank are not defined in this
    section; the ResultStruct instance must provide accuracy, length,
    pdb2length and pdb_names containers for this to run.

    OUTPUT:
        ResultStruct with time, accuracy, length, pdb2length, conf_matrix
        and pdb_names filled in.
    """
    results: ResultStruct = ResultStruct()
    predictions = parser_function(directory)
    truths = parse_truth(truth_csv)
    results.time = get_times(directory)
    for name, predicted in predictions.items():
        actual = truths[name]
        score, confusion = prediction_rank(predicted, actual)
        # record this sequence's metrics
        results.accuracy.append(score)
        n_residues = len(actual)
        results.length.append(n_residues)
        results.pdb2length[name] = n_residues
        results.conf_matrix += confusion
        results.pdb_names.append(name)
    return results
    
def parse_truth(truth_csv, parsing_index=3):
    """
    parses the truth CSV to get the actual ss3

    Each non-blank line is split on commas. Column ``parsing_index - 2``
    holds the identifier and column ``parsing_index`` the true value;
    surrounding single quotes are stripped from both. No header row is
    skipped.

    INPUT:
        truth_csv: path to the CSV file.
        parsing_index: column index of the truth value (default 3).
    OUTPUT:
        dict = {identifier: truth value}
    """
    truth_dict = {}
    with open(truth_csv) as file_by_lines:
        for line in file_by_lines:
            line = line.rstrip("\n")
            if not line:
                continue  # blank lines (e.g. trailing) would otherwise IndexError
            columns = line.split(",")
            identifier = columns[parsing_index - 2].strip("'")
            truth_dict[identifier] = columns[parsing_index].strip("'")
    return truth_dict

# Plotting Methods
These methods use the common `ResultStruct` data structure (or an array thereof) to plot individual tool metrics.

In [8]:
def plot_confusion_matrix(input_results: ResultStruct, 
                          fig_path=None, 
                          title: str=None,
                          labels=None):
    """
    Plot the row-normalized confusion matrix of ``input_results``.

    Each row is divided by its sum, so cells show fractions per true
    class. Rows that sum to zero are shown as zeros rather than NaN. The
    result is written to a new array, so ``input_results.conf_matrix``
    keeps its raw counts.

    INPUT:
        input_results: ResultStruct whose conf_matrix is square.
        fig_path: if given, the figure is also saved there (300 dpi).
        title: optional plot title.
        labels: axis labels; defaults to ['C', 'E', 'H'] for a 3x3 matrix
                and to 0..n-1 otherwise.
    """
    counts = np.asarray(input_results.conf_matrix, dtype=float)
    row_sums = counts.sum(axis=1, keepdims=True)
    # out= is a fresh array (input untouched); where= skips 0/0 rows.
    total_confusion = np.divide(counts, row_sums,
                                out=np.zeros_like(counts),
                                where=row_sums != 0)
    print(total_confusion)
    if labels is None:
        size = len(total_confusion)
        labels = ['C', 'E', 'H'] if size == 3 else list(range(size))
    # turn array into pandas dataframe
    total_confusion = pd.DataFrame(total_confusion,
                                   columns=labels,
                                   index=labels)
    ## plot confusion matrix
    sns.heatmap(total_confusion, annot=True, linewidths=.5, cmap='binary')
    if title: plt.title(title)
    if fig_path:
        plt.savefig(fig_path, transparent=True, dpi=300, bbox_inches='tight')
    plt.show()

# Benchmarking
This section organizes the benchmarking experiments into clear, concise subsections.

### Number of genomes comparison

### Number of read mutations

### Number of reads