In [1]:
from ete3 import NCBITaxa
import pandas as pd

In [2]:
# Initialize NCBITaxa object (only needs to be initialized once)
ncbi = NCBITaxa()

def get_genus_and_family(species_name):
    """
    Get the genus and family names for a given species using NCBI taxonomy.
    
    Args:
        species_name (str): The name of the species (e.g., "Salmonella enterica").
    
    Returns:
        tuple: A tuple containing the genus and family names (genus, family).
               If the species is not found, it returns (None, None).
    """
    try:
        # Get the taxid for the species
        taxid = ncbi.get_name_translator([species_name])
        
        if not taxid:
            #print(f"Species '{species_name}' not found in NCBI taxonomy database.")
            return None, None
        
        species_taxid = list(taxid.values())[0][0]  # Get the first TaxID
        
        # Get the lineage for the taxid
        lineage = ncbi.get_lineage(species_taxid)
        
        # Get the taxonomic ranks and names for the lineage
        lineage_ranks = ncbi.get_rank(lineage)
        lineage_names = ncbi.get_taxid_translator(lineage)
        
        # Extract genus and family
        genus = None
        family = None
        for taxid in lineage:
            if lineage_ranks[taxid] == "genus":
                genus = lineage_names[taxid]
            if lineage_ranks[taxid] == "family":
                family = lineage_names[taxid]
        
        return genus, family
    
    except Exception as e:
        #print(f"An error occurred: {e}")
        return None, None

In [4]:
df = pd.read_csv('ProkaryoticVirus_Species_Level_Host.csv')
df.head()

Unnamed: 0,Accession,Host
0,NC_092682.1,Salmonella enterica
1,NC_092683.1,Salmonella enterica
2,NC_092455.1,Klebsiella michiganensis
3,NC_092456.1,Klebsiella quasipneumoniae
4,NC_092457.1,Klebsiella variicola


In [5]:
label = {acc: host for acc, host in zip(df.Accession, df.Host)}

# CHERRY results

In [None]:
cherry = pd.read_csv('Results/cherry_ncbi_results.tsv', sep='\t')
cherry.head()

Unnamed: 0,Accession,Length,Host,CHERRYScore,Method,Host_NCBI_lineage,Host_GTDB_lineage
0,NC_092683.1,113461,species:Pectobacterium parmentieri,0.9,CRISPR-based (DB),d__Bacteria;p__Proteobacteria;c__Gammaproteoba...,d__Bacteria;p__Pseudomonadota;c__Gammaproteoba...
1,NC_092682.1,111322,-,-,,-,-
2,NC_092681.1,121769,species:Acinetobacter sp. WCHA45,0.91,CRISPR-based (DB),d__Bacteria;p__Proteobacteria;c__Gammaproteoba...,d__Bacteria;p__Pseudomonadota;c__Gammaproteoba...
3,NC_092586.1,37696,-,-,,-,-
4,NC_092457.1,40735,genus:Klebsiella,0.81,AAI-based,d__Bacteria;p__Proteobacteria;c__Gammaproteoba...,Not found


In [7]:
cherry_prediction = {acc: host for acc, host in zip(cherry.Accession, cherry.Host)}
cherry_prediction_species = {}
cherry_prediction_genus = {}
for acc, host in cherry_prediction.items():
    if 'species:' in host:
        cherry_prediction_species[acc] = host.split('species:')[-1]
    if 'genus:' in host:
        cherry_prediction_genus[acc] = host.split('genus:')[-1]

In [14]:
total = 0
correct = 0
for acc in label:
    try:
        pred = cherry_prediction[acc]
        if 'species' not in pred:
            continue
        elif cherry_prediction_species[acc] == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Species-level accuracy:", correct / total)
print("Species-level prediction rate:", total / len(label))

Species-level accuracy: 0.7724391057724391
Species-level prediction rate: 0.6379310344827587


In [16]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        if acc in cherry_prediction_species:
            pred_g, pred_f = get_genus_and_family(cherry_prediction_species[acc])
            if pred_g == genus:
                correct += 1
            total+=1
        elif acc in cherry_prediction_genus:
            if cherry_prediction_genus[acc] == genus:
                correct += 1
            total+=1
        else:
            total+=1
    except:
        pass

print("Genus-level accuracy:", correct / total)
print("Genus-level prediction rate:", total / len(label))

Genus-level accuracy: 0.8065134099616859
Genus-level prediction rate: 1.0


In [17]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        if acc in cherry_prediction_species:
            pred_g, pred_f = get_genus_and_family(cherry_prediction_species[acc])
            if pred_f == family:
                correct += 1
            total+=1
        elif acc in cherry_prediction_genus:
            pred_g, pred_f = get_genus_and_family(cherry_prediction_genus[acc])
            if pred_f == family:
                correct += 1
            total+=1
        else:
            total+=1
    except:
        pass
print("Family-level accuracy:", correct / total)
print("Family-level prediction rate:", total / len(label))

Family-level accuracy: 0.8716475095785441
Family-level prediction rate: 1.0


# iPHoP results

In [18]:
iphop = pd.read_csv('Results/iphop_ncbi_results.csv')
iphop['Host'] = iphop['Host taxonomy'].apply(lambda x: x.split(';s__')[1].split(';')[0])
iphop.head()

Unnamed: 0,Virus,Host genome,Host taxonomy,Main method,Confidence score,Additional methods,Host
0,NC_048101.1,RS_GCF_900323545.1,d__Bacteria;p__Actinomycetota;c__Actinomycetia...,CRISPR,95.6,,Actinomyces sp900323545
1,NC_048104.1,RS_GCF_000425025.1,d__Bacteria;p__Bacillota;c__Bacilli;o__Lactoba...,iPHoP-RF,93.7,,Streptococcus castoreus
2,NC_048107.1,RS_GCF_002902085.1,d__Bacteria;p__Bacillota;c__Bacilli;o__Staphyl...,iPHoP-RF,96.7,,Staphylococcus simiae
3,NC_048109.1,RS_GCF_000017205.1,d__Bacteria;p__Pseudomonadota;c__Gammaproteoba...,blast,96.9,iPHoP-RF;82.20,Pseudomonas aeruginosa_A
4,NC_048110.1,GB_GCA_009929605.1,d__Bacteria;p__Bacillota_C;c__Negativicutes;o_...,blast,93.4,,Veillonella sp009929605


In [22]:
iphop_prediction = {acc: host for acc, host in zip(iphop.Virus, iphop.Host)}

In [23]:
total = 0
correct = 0
for acc in label:
    try:
        pred = iphop_prediction[acc]
        if pred == '':
            continue
        elif pred == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Species-level accuracy:", correct / total)
print("Species-level prediction rate:", total / len(label))

Species-level accuracy: 0.16418992388546574
Species-level prediction rate: 0.5872711792252022


In [None]:
total = 0
correct = 0
for acc in label:
    try:
        pred = iphop_prediction[acc]
        if pred == '':
            continue
        genus, family = get_genus_and_family(label[acc])
        pred_g, pred_f = get_genus_and_family(pred)
        if genus == pred_g:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Genus-level accuracy:", correct / total)
print("Genus-level prediction rate:", total / len(label))


Genus-level accuracy: 0.47843421529539687
Genus-level prediction rate: 0.5872711792252022


In [28]:
total = 0
correct = 0
for acc in label:
    try:
        pred = iphop_prediction[acc]
        if pred == '':
            continue
        genus, family = get_genus_and_family(label[acc])
        pred_g, pred_f = get_genus_and_family(pred)
        if pred_f == None:
            total += 1
        elif pred_f == family:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Family-level accuracy:", correct / total)
print("Family-level prediction rate:", total / len(label))

Family-level accuracy: 0.5748459586806814
Family-level prediction rate: 0.5872711792252022


# PHERI results

In [29]:
pheri = pd.read_csv('Results/pheri_ncbi_results.csv')
pheri = pheri.fillna('')
pheri.head()

Unnamed: 0,phage,host,score
0,NC_021774.1,,
1,NC_028860,Mycolicibacterium,1.0
2,NC_047794,Streptomyces,1.0
3,NC_073092,Pseudomonas,1.0
4,NC_055816,Gordonia,1.0


In [30]:
pheri_prediction = {acc: host for acc, host in zip(pheri.phage, pheri.host)}

In [31]:
total = 0
correct = 0
for acc in label:
    try:
        pred = pheri_prediction[acc.split('.')[0]]
        genus, family = get_genus_and_family(label[acc])
        #if str(pred) == str("nan"):
        #    continue
        pred_g, pred_f = get_genus_and_family(pred)
        if pred_g == genus:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Genus-level accuracy:", correct / total)
print("Genus-level prediction rate:", total / len(label))


Genus-level accuracy: 0.814430314041007
Genus-level prediction rate: 0.8201362281822052


In [32]:
total = 0
correct = 0
for acc in label:
    try:
        pred = pheri_prediction[acc.split('.')[0]]
        genus, family = get_genus_and_family(label[acc])
        pred_g, pred_f = get_genus_and_family(pred)
        if pred_f == family:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Family-level accuracy:", correct / total)
print("Family-level prediction rate:", total / len(label))

Family-level accuracy: 0.8647806903711394
Family-level prediction rate: 0.8201362281822052


# VHM_net results

In [35]:
vhmnet = pd.read_csv('Results/vhmnet_ncbi_results.csv')
vhmnet_prediction = {acc: host for acc, host in zip(vhmnet.Accession, vhmnet.hostSpecies)}
vhmnet.head()

Unnamed: 0,hostNCBIName,hostSuperkingdom,hostPhylum,hostClass,hostOrder,hostFamily,hostGenus,hostSpecies,hostName,score,...,s2star_pct,posSV_pct,negSV_pct,crispr_pct,acc_phylum,acc_class,acc_order,acc_family,acc_genus,Accession
0,GCF_000010465.1,Bacteria,Firmicutes,Bacilli,Bacillales,Staphylococcaceae,Staphylococcus,Staphylococcus aureus,Staphylococcus aureus subsp. aureus str. Newman,0.9991,...,0.9999,1.0,0.9998,,0.9956,0.9956,0.9956,0.9823,0.9646,NC_024391.1
1,GCF_000267765.2,Bacteria,Proteobacteria,Gammaproteobacteria,Enterobacterales,Enterobacteriaceae,Escherichia,Escherichia coli,Escherichia coli PA28,0.9586,...,0.9968,0.9999,0.9996,,0.9735,0.9735,0.943,0.9084,0.8473,NC_027351.1
2,GCF_001281005.1,Bacteria,Proteobacteria,Gammaproteobacteria,Enterobacterales,Enterobacteriaceae,Citrobacter,Citrobacter portucalensis,Citrobacter portucalensis,0.9654,...,0.9619,1.0,0.9999,,0.9731,0.9731,0.9441,0.9089,0.8468,NC_031092.1
3,GCF_000709475.1,Bacteria,Firmicutes,Bacilli,Bacillales,Staphylococcaceae,Staphylococcus,Staphylococcus aureus,Staphylococcus aureus,0.9159,...,0.9734,1.0,0.9996,,0.9671,0.9671,0.9362,0.8956,0.8162,NC_022764.1
4,GCF_001469215.1,Bacteria,Proteobacteria,Gammaproteobacteria,Alteromonadales,Pseudoalteromonadaceae,Pseudoalteromonas,Pseudoalteromonas sp. H105,Pseudoalteromonas sp. H105,0.9957,...,0.9964,1.0,0.9999,,0.9854,0.9854,0.9592,0.9388,0.895,NC_070826.1


In [38]:
total = 0
correct = 0
for acc in label:
    try:
        pred = vhmnet_prediction[acc]
        if pred == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Species-level accuracy:", correct / total)
print("Species-level prediction rate:", total / len(label))

Species-level accuracy: 0.3673903788846318
Species-level prediction rate: 1.0


In [39]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = vhmnet_prediction[acc]
        pred_g, pred_f = get_genus_and_family(pred)
        if pred_g == genus:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Genus-level accuracy:", correct / total)
print("Genus-level prediction rate:", total / len(label))

Genus-level accuracy: 0.4821200510855683
Genus-level prediction rate: 1.0


In [40]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = vhmnet_prediction[acc]
        pred_g, pred_f = get_genus_and_family(pred)
        if family == pred_f:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Family-level accuracy:", correct / total)
print("Family-level prediction rate:", total / len(label))

Family-level accuracy: 0.6055768412090251
Family-level prediction rate: 1.0


# vHULK results

In [42]:
vhulk = pd.read_csv('Results/vhulk_ncbi_results.csv')
vhulk = vhulk.fillna('')
vhulk_prediction = {acc.split('.csv')[0]: host for acc, host in zip(vhulk.Accession, vhulk.pred_species)}
vhulk.head()

Unnamed: 0,pred_genus,score_genus,entropy_genus,pred_species,score_species,entropy_species,Accession
0,Salmonella,0.989508,0.199786,Salmonella_enterica,0.99975,1.68254,NC_048785.1.csv
1,Mycobacterium,0.99982,0.005806,Mycobacterium_smegmatis,0.999998,5.399976,NC_021061.1.csv
2,Streptococcus,0.99982,0.005261,Streptococcus_thermophilus,1.0,5.553392,NC_070700.1.csv
3,Lactococcus,1.0,2e-06,Lactococcus_lactis,1.0,5.92404,NC_049487.1.csv
4,,0.0,0.0,,0.0,0.0,NC_077159.1.csv


In [43]:
total = 0
correct = 0
for acc in label:
    try:
        pred = vhulk_prediction[acc]
        if pred == '':
            continue
        elif pred.replace('_', ' ') == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Species-level accuracy:", correct / total)
print("Species-level prediction rate:", total / len(label))

Species-level accuracy: 0.5818436733263231
Species-level prediction rate: 0.60727969348659


In [44]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = vhulk_prediction[acc]
        if pred == '':
            continue
        pred_g, pred_f = get_genus_and_family(pred.replace('_', ' '))
        if pred_g == genus:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Genus-level accuracy:", correct / total)
print("Genus-level prediction rate:", total / len(label))

Genus-level accuracy: 0.7206449351559762
Genus-level prediction rate: 0.60727969348659


In [45]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = vhulk_prediction[acc]
        if pred == '':
            continue
        pred_g, pred_f = get_genus_and_family(pred.replace('_', ' '))
        if pred_f == family:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Family-level accuracy:", correct / total)
print("Family-level prediction rate:", total / len(label))

Family-level accuracy: 0.7490361023484052
Family-level prediction rate: 0.60727969348659


# DeepHost results

In [47]:
deep = pd.read_csv('Results/DeepHost_ncbi_results.tsv', sep='\t', header=None)
deep_prediction = {acc.split(' |')[0]: host for acc, host in zip(deep[0], deep[1])}
deep.head()

Unnamed: 0,0,1
0,"NC_092681.1 |Salmonella phage BD13, complete g...",Salmonella enterica
1,"NC_092682.1 |Salmonella phage beppo, complete ...",Salmonella enterica
2,"NC_092683.1 |Salmonella phage bobsandoy, compl...",Salmonella enterica
3,NC_092586.1 |Caudoviricetes sp. vir215 genomic...,Pseudomonas aeruginosa
4,"NC_092455.1 |Klebsiella phage Oda, complete ge...",Klebsiella pneumoniae


In [50]:
total = 0
correct = 0
for acc in label:
    try:
        pred = deep_prediction[acc]
        if pred.replace('_', ' ') == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Species-level accuracy:", correct / total)
print("Species-level prediction rate:", total / len(label))

Species-level accuracy: 0.5951468710089399
Species-level prediction rate: 1.0


In [51]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = deep_prediction[acc]
        pred_g, pred_f = get_genus_and_family(pred.replace('_', ' '))
        if pred_g == genus:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Genus-level accuracy:", correct / total)
print("Genus-level prediction rate:", total / len(label))

Genus-level accuracy: 0.6758194976585781
Genus-level prediction rate: 1.0


In [52]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = deep_prediction[acc]
        pred_g, pred_f = get_genus_and_family(pred.replace('_', ' '))
        if pred_f == family:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Family-level accuracy:", correct / total)
print("Family-level prediction rate:", total / len(label))

Family-level accuracy: 0.7188165176670924
Family-level prediction rate: 1.0


# RaFAH results

In [54]:
rafah = pd.read_csv('Results/RaFAH_ncbi_results.tsv', sep='\t')
rafah_prediction = {acc: host for acc, host in zip(rafah['Variable'], rafah['Predicted_Host'])}
rafah.head()

Unnamed: 0,Variable,CDS_Count,Description,Length,Original_File,Predicted_Host,Predicted_Host_Score
0,NC_000866.4,265,"|Enterobacteria phage T4, complete genome",168903,NC_000866.4.fa,Escherichia,0.805
1,NC_000867.1,19,"|Pseudoalteromonas phage PM2, complete genome",10079,NC_000867.1.fa,Pseudoalteromonas,0.876333
2,NC_000871.1,46,"|Streptococcus phage Sfi19, complete genome",37370,NC_000871.1.fa,Streptococcus,0.987
3,NC_000872.1,54,"|Streptococcus phage Sfi21, complete genome",40739,NC_000872.1.fa,Streptococcus,0.989
4,NC_000896.1,66,"|Lactobacillus prophage phiadh, complete genome",43785,NC_000896.1.fa,Lactobacillus,0.981


In [55]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = rafah_prediction[acc]
        pred_g, pred_f = get_genus_and_family(pred.replace('_', ' '))
        if pred_g == genus:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Genus-level accuracy:", correct / total)
print("Genus-level prediction rate:", total / len(label))

Genus-level accuracy: 0.8391256830601093
Genus-level prediction rate: 0.9738186462324393


In [56]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = rafah_prediction[acc]
        pred_g, pred_f = get_genus_and_family(pred.replace('_', ' '))
        if pred_f == family:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Family-level accuracy:", correct / total)
print("Family-level prediction rate:", total / len(label))

Family-level accuracy: 0.9127868852459017
Family-level prediction rate: 0.9738186462324393


# PHP results

In [59]:
php = pd.read_csv('Results/PHP_ncbi_results.csv')
php_prediction = {acc.split('.fa')[0]: host for acc, host in zip(php['queryVirus'], php['species'])}
php.head()

Unnamed: 0,queryVirus,score,_maxScoreHost,species
0,NC_021774.1.fa,1448.805684,1560201,Erwinia iniecta
1,NC_028860.1.fa,1454.858953,1490483,Mycobacterium sp. SWH-M5
2,NC_047794.1.fa,1365.426114,2218666,Streptomyces sp. AC1-42W
3,NC_073092.1.fa,1430.137377,1298875,Nocardioides sp. JGI 0001009-J09
4,NC_055816.1.fa,1455.347689,2055,Gordonia terrae


In [60]:
total = 0
correct = 0
for acc in label:
    try:
        pred = php_prediction[acc]
        if pred == '':
            continue
        elif pred.replace('_', ' ') == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Species-level accuracy:", correct / total)
print("Species-level prediction rate:", total / len(label))

Species-level accuracy: 0.19625372498935717
Species-level prediction rate: 1.0


In [61]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = php_prediction[acc]
        if pred == '':
            continue
        pred_g, pred_f = get_genus_and_family(pred.replace('_', ' '))
        if pred_g == genus:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Genus-level accuracy:", correct / total)
print("Genus-level prediction rate:", total / len(label))

Genus-level accuracy: 0.40911025968497233
Genus-level prediction rate: 1.0


In [62]:
total = 0
correct = 0
for acc in label:
    try:
        genus, family = get_genus_and_family(label[acc])
        pred = php_prediction[acc]
        if pred == '':
            continue
        pred_g, pred_f = get_genus_and_family(pred.replace('_', ' '))
        if pred_f == family:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass
print("Family-level accuracy:", correct / total)
print("Family-level prediction rate:", total / len(label))

Family-level accuracy: 0.6198382290336313
Family-level prediction rate: 1.0


# Combine results

In [63]:
union = list(set(deep_prediction) | set(cherry_prediction_species) | set(vhmnet_prediction))

## Union + Majority vote

In [64]:
# generate a majority vote for union prediction
union_prediction = {}
for item in union:
    if item in deep_prediction and item in cherry_prediction_species and item in vhmnet_prediction:
        votes = [deep_prediction[item], cherry_prediction_species[item], vhmnet_prediction[item]]
        union_prediction[item] = max(set(votes), key=votes.count)
    elif item in deep_prediction and item in cherry_prediction_species:
        votes = [deep_prediction[item], cherry_prediction_species[item]]
        union_prediction[item] = max(set(votes), key=votes.count)
    elif item in deep_prediction and item in vhmnet_prediction:
        votes = [deep_prediction[item], vhmnet_prediction[item]]
        union_prediction[item] = max(set(votes), key=votes.count)
    elif item in cherry_prediction_species and item in vhmnet_prediction:
        votes = [cherry_prediction_species[item], vhmnet_prediction[item]]
        union_prediction[item] = max(set(votes), key=votes.count)
    elif item in deep_prediction:
        votes = [deep_prediction[item]]
        union_prediction[item] = max(set(votes), key=votes.count)
    elif item in cherry_prediction_species:
        votes = [cherry_prediction_species[item]]
        union_prediction[item] = max(set(votes), key=votes.count)
    elif item in vhmnet_prediction:
        votes = [vhmnet_prediction[item]]
        union_prediction[item] = max(set(votes), key=votes.count)


In [66]:
total = 0
correct = 0
for acc in label:
    try:
        pred = union_prediction[acc]
        if pred == '':
            continue
        elif pred.replace('_', ' ') == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass

print("Union + majority vote accuracy:", correct / total)
print("Union + majority vote prediction rate:", total / len(label))

Union + majority vote accuracy: 0.5796083439761601
Union + majority vote prediction rate: 1.0


## Joint + majority vote

In [67]:
joint_prediction = {}
for item in union:
    if item in deep_prediction and item in cherry_prediction_species and item in vhmnet_prediction:
        votes = [deep_prediction[item], cherry_prediction_species[item], vhmnet_prediction[item]]
        joint_prediction[item] = max(set(votes), key=votes.count)

In [68]:
total = 0
correct = 0
for acc in label:
    try:
        pred = joint_prediction[acc]
        if pred == '':
            continue
        elif pred.replace('_', ' ') == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass

print("Joint + majority vote accuracy:", correct / total)
print("Joint + majority vote prediction rate:", total / len(label))

Joint + majority vote accuracy: 0.7030363697030364
Joint + majority vote prediction rate: 0.6379310344827587


## Joint + consensus

In [69]:
joint_prediction = {}
for item in union:
    if item in deep_prediction and item in cherry_prediction_species and item in vhmnet_prediction:
        if deep_prediction[item] == cherry_prediction_species[item] and  deep_prediction[item] == vhmnet_prediction[item]:
            joint_prediction[item] = deep_prediction[item]

In [70]:
total = 0
correct = 0
for acc in label:
    try:
        pred = joint_prediction[acc]
        if pred == '':
            continue
        elif pred.replace('_', ' ') == label[acc]:
            correct += 1
            total += 1
        else:
            total += 1
    except:
        pass

print("Joint + consensus accuracy:", correct / total)
print("Joint + consensus prediction rate:", total / len(label))

Joint + consensus accuracy: 0.9974424552429667
Joint + consensus prediction rate: 0.08322690506598553
