In [145]:
import pandas as pd
import numpy as np
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, precision_score, f1_score

In [146]:
# Name of the reference (training) set; it is used to build both the input CSV path
# and the name of the predictions file.
db = 'urinary_max20_min5'
# Column of the reference set to compare against: 'dna_seq' or 'V3V4'.
# Only the reference sequences are affected; test sequences are always read from 'dna_seq'.
region = 'V3V4'

In [147]:
# Echo the run configuration, one value per line.
print(db, region, sep='\n')

urinary_max20_min5
V3V4


In [None]:
# File paths
# Reference (training) set selected by `db`.
ref_database = "../datasets/train_sets/" + db + ".csv"
# Test set; uses the same "../datasets/" root as the reference set.
test_database = "../datasets/test_sets/test_set_from_refseq_v2.csv"


# The predictions file name records which reference column was used.
if region == 'dna_seq':
    output_file = "../preds/cosine/cosine_" + db + "_fullseq.csv"
elif region == 'V3V4':
    output_file = "../preds/cosine/cosine_" + db + "_V3V4.csv"
else:
    # Fail here instead of leaving output_file undefined (or stale from a previous run).
    raise ValueError("region must be 'dna_seq' or 'V3V4', got {!r}".format(region))

In [149]:
# Load data
ref_set = pd.read_csv(ref_database, sep=',')
test_set = pd.read_csv(test_database, sep=',')

# Taxonomic ranks, ordered from the broadest to the most specific.
taxonomy_levels = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']

# One label per reference row: the seven ranks joined with '_' (missing ranks become "Unknown").
# NOTE(review): this label is split on '_' again further down, so a rank name that itself
# contains '_' would yield extra fields - confirm no such names exist in the data.
ref_set['taxonomy'] = ref_set[taxonomy_levels].fillna("Unknown").agg('_'.join, axis=1)

In [150]:
#######################
# keep only if needed #
#######################

# Remove from the reference set every sequence whose seq_id also appears in the test set
# (presumably to avoid matching a test sequence against itself - confirm intent).
# A single vectorized membership test is used instead of re-filtering the whole frame once
# per test id; it also avoids leaving a global named `id` that shadows the builtin.
# dropna() keeps reference rows with a missing seq_id, exactly as `!=` comparisons would.
ref_set = ref_set.loc[~ref_set['seq_id'].isin(test_set['seq_id'].dropna())]

In [151]:
# Validate DNA sequences
def clean_sequence(sequence):
    """Upper-case `sequence` and drop every character that is not A, T, G or C.

    Nothing is validated or rejected: unsupported characters (N, U, gaps, ...)
    are simply removed, so the result may be shorter than the input.
    """
    # Splitting on the unwanted characters and re-joining the pieces removes them.
    kept_runs = re.split(r'[^ATGC]', sequence.upper())
    return ''.join(kept_runs)

# Normalise every sequence column that is used later on. Values are cast to str first,
# so a missing value becomes the text 'nan' before cleaning.
ref_set['dna_seq'] = ref_set['dna_seq'].astype(str).map(clean_sequence)
ref_set['V3V4'] = ref_set['V3V4'].astype(str).map(clean_sequence)
test_set['dna_seq'] = test_set['dna_seq'].astype(str).map(clean_sequence)

In [152]:
def generate_kmers(sequence, k):
    """Return every overlapping substring of length `k` of `sequence`, in order.

    The result is an empty list when `sequence` is shorter than `k`.
    """
    window_count = len(sequence) - k + 1
    kmers = []
    for start in range(window_count):
        kmers.append(sequence[start:start + k])
    return kmers

def encode_sequences(sequences, k):
    """Fit a TF-IDF model on the k-mers of `sequences`.

    Each sequence is turned into a space-separated string of its k-mers, so that
    every k-mer is treated as one word by the vectorizer.

    Returns:
        (matrix, vectorizer): the TF-IDF matrix for `sequences` and the fitted
        vectorizer, which can be reused to transform other sequences.
    """
    documents = []
    for seq in sequences:
        documents.append(' '.join(generate_kmers(seq, k)))

    vectorizer = TfidfVectorizer()
    return vectorizer.fit_transform(documents), vectorizer

def find_closest_taxonomy(ref_set, test_set, k=7, threshold=0.8, n_neighbors=3, seq_column=None):
    """Label each test sequence with the taxonomy of its most similar reference sequence.

    Reference and test sequences are encoded as TF-IDF vectors over their k-mers
    (the model is fitted on the reference sequences only) and compared by cosine
    distance.

    Args:
        ref_set: DataFrame with a 'taxonomy' column and the sequence column named
            by `seq_column`.
        test_set: DataFrame with 'seq_id' and 'dna_seq' columns.
        k: k-mer length.
        threshold: minimum cosine similarity to the nearest reference for its
            taxonomy to be accepted; below it the sequence is labelled
            "Unclassified" at all seven ranks.
        n_neighbors: number of neighbours requested from the index. Only the
            nearest one is used to assign the label.
        seq_column: name of the reference sequence column. When None (default),
            the notebook-level `region` setting is used, as before.

    Returns:
        DataFrame with one row per test sequence and the columns 'test_seq_id',
        'closest_taxonomy' and 'similarity'.
    """
    if seq_column is None:
        # Backward-compatible fallback to the notebook-level configuration.
        seq_column = region

    known_sequences = ref_set[seq_column].tolist()
    unknown_sequences = test_set["dna_seq"].tolist()

    # Encode sequences
    known_vectors, vectorizer = encode_sequences(known_sequences, k)
    unknown_vectors = vectorizer.transform([' '.join(generate_kmers(seq, k)) for seq in unknown_sequences])

    # Use Nearest Neighbors for fast lookup
    nn = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine')
    nn.fit(known_vectors)

    distances, indices = nn.kneighbors(unknown_vectors)

    # Pull the needed columns out once; positions match the row order used to fit
    # the index and to build the query matrix.
    ref_taxonomy = ref_set['taxonomy'].tolist()
    test_ids = test_set["seq_id"].tolist()
    unclassified = "Unclassified_Unclassified_Unclassified_Unclassified_Unclassified_Unclassified_Unclassified"

    results = []
    for test_id, dist, idx in zip(test_ids, distances, indices):
        similarity = 1 - dist[0]  # Convert distance to similarity
        closest_taxonomy = ref_taxonomy[idx[0]] if similarity >= threshold else unclassified

        results.append({
            "test_seq_id": test_id,
            "closest_taxonomy": closest_taxonomy,
            "similarity": similarity
        })

    return pd.DataFrame(results)



In [153]:
# Run classification
# threshold=0.1 overrides the function default of 0.8: the nearest reference's taxonomy is
# accepted whenever its cosine similarity is at least 0.1.
# n_neighbors=3 neighbours are retrieved, but only the nearest one is used for the label.
result_df = find_closest_taxonomy(ref_set, test_set, k=7, threshold=0.1, n_neighbors=3)
display(result_df)


Unnamed: 0,test_seq_id,closest_taxonomy,similarity
0,>OK036813.1 Shewanella putrefaciens strain NMC...,Bacteria_Pseudomonadota_Gammaproteobacteria_Al...,1.000000
1,>MW398078.1 Faecalibacterium prausnitzii strai...,Bacteria_Bacillota_Clostridia_Eubacteriales_Os...,0.928726
2,>EU086786.1 Arthrobacter albus strain 1366 16S...,Bacteria_Actinomycetota_Actinomycetes_Micrococ...,0.929986
3,>PQ340191.1 Comamonas testosteroni strain YX-D...,Bacteria_Pseudomonadota_Betaproteobacteria_Bur...,0.962616
4,>OR501373.1 Streptococcus pneumoniae strain AT...,Bacteria_Bacillota_Bacilli_Lactobacillales_Str...,1.000000
...,...,...,...
838,>OP389295.1 Glutamicibacter creatinolyticus st...,Bacteria_Actinomycetota_Actinomycetes_Micrococ...,1.000000
839,>PQ236872.1 Arcanobacterium urinimassiliense s...,Bacteria_Actinomycetota_Actinomycetes_Actinomy...,1.000000
840,>PQ839297.1 Microbacterium testaceum strain _Z...,Bacteria_Actinomycetota_Actinomycetes_Micrococ...,0.974192
841,>MW049073.1 Cupriavidus pauculus strain RSCup0...,Bacteria_Pseudomonadota_Betaproteobacteria_Bur...,0.977086


In [154]:
def tax_splitter(tax):
    """Split a '_'-joined taxonomy label (a str) into its list of taxon names."""
    return tax.split("_")



# Replace each joined taxonomy label with the list of its taxon names.
result_df['closest_taxonomy'] = result_df['closest_taxonomy'].apply(tax_splitter)
display(result_df)


Unnamed: 0,test_seq_id,closest_taxonomy,similarity
0,>OK036813.1 Shewanella putrefaciens strain NMC...,"[Bacteria, Pseudomonadota, Gammaproteobacteria...",1.000000
1,>MW398078.1 Faecalibacterium prausnitzii strai...,"[Bacteria, Bacillota, Clostridia, Eubacteriale...",0.928726
2,>EU086786.1 Arthrobacter albus strain 1366 16S...,"[Bacteria, Actinomycetota, Actinomycetes, Micr...",0.929986
3,>PQ340191.1 Comamonas testosteroni strain YX-D...,"[Bacteria, Pseudomonadota, Betaproteobacteria,...",0.962616
4,>OR501373.1 Streptococcus pneumoniae strain AT...,"[Bacteria, Bacillota, Bacilli, Lactobacillales...",1.000000
...,...,...,...
838,>OP389295.1 Glutamicibacter creatinolyticus st...,"[Bacteria, Actinomycetota, Actinomycetes, Micr...",1.000000
839,>PQ236872.1 Arcanobacterium urinimassiliense s...,"[Bacteria, Actinomycetota, Actinomycetes, Acti...",1.000000
840,>PQ839297.1 Microbacterium testaceum strain _Z...,"[Bacteria, Actinomycetota, Actinomycetes, Micr...",0.974192
841,>MW049073.1 Cupriavidus pauculus strain RSCup0...,"[Bacteria, Pseudomonadota, Betaproteobacteria,...",0.977086


In [155]:
# Spread the list of taxon names over one column per taxonomic rank
# (position k in the list corresponds to taxonomy_levels[k]).
for k, level in enumerate(taxonomy_levels):
    result_df[level] = [taxons[k] for taxons in result_df['closest_taxonomy']]

# The list column is no longer needed once every rank has its own column.
result_df = result_df.drop(columns=['closest_taxonomy'])


In [156]:
# Inspect the predictions, now with one column per taxonomic rank.
result_df

Unnamed: 0,test_seq_id,similarity,domain,phylum,class,order,family,genus,species
0,>OK036813.1 Shewanella putrefaciens strain NMC...,1.000000,Bacteria,Pseudomonadota,Gammaproteobacteria,Alteromonadales,Shewanellaceae,Shewanella,Shewanella putrefaciens
1,>MW398078.1 Faecalibacterium prausnitzii strai...,0.928726,Bacteria,Bacillota,Clostridia,Eubacteriales,Oscillospiraceae,Faecalibacterium,Faecalibacterium prausnitzii
2,>EU086786.1 Arthrobacter albus strain 1366 16S...,0.929986,Bacteria,Actinomycetota,Actinomycetes,Micrococcales,Micrococcaceae,Pseudoglutamicibacter,Pseudoglutamicibacter cumminsii
3,>PQ340191.1 Comamonas testosteroni strain YX-D...,0.962616,Bacteria,Pseudomonadota,Betaproteobacteria,Burkholderiales,Comamonadaceae,Comamonas,Comamonas testosteroni
4,>OR501373.1 Streptococcus pneumoniae strain AT...,1.000000,Bacteria,Bacillota,Bacilli,Lactobacillales,Streptococcaceae,Streptococcus,Streptococcus pneumoniae
...,...,...,...,...,...,...,...,...,...
838,>OP389295.1 Glutamicibacter creatinolyticus st...,1.000000,Bacteria,Actinomycetota,Actinomycetes,Micrococcales,Micrococcaceae,Glutamicibacter,Glutamicibacter creatinolyticus
839,>PQ236872.1 Arcanobacterium urinimassiliense s...,1.000000,Bacteria,Actinomycetota,Actinomycetes,Actinomycetales,Actinomycetaceae,Arcanobacterium,Arcanobacterium urinimassiliense
840,>PQ839297.1 Microbacterium testaceum strain _Z...,0.974192,Bacteria,Actinomycetota,Actinomycetes,Micrococcales,Microbacteriaceae,Microbacterium,Microbacterium testaceum
841,>MW049073.1 Cupriavidus pauculus strain RSCup0...,0.977086,Bacteria,Pseudomonadota,Betaproteobacteria,Burkholderiales,Burkholderiaceae,Cupriavidus,Cupriavidus pauculus


In [157]:
# Inspect the test set; its rank columns are used as the true labels when scoring.
test_set

Unnamed: 0.1,Unnamed: 0,txid,seq_id,domain,phylum,class,order,family,genus,species,dna_seq
0,3365,24,>OK036813.1 Shewanella putrefaciens strain NMC...,Bacteria,Pseudomonadota,Gammaproteobacteria,Alteromonadales,Shewanellaceae,Shewanella,Shewanella putrefaciens,CCTACGGGAGGCAGCAGTGGGGAATATTGCACAATGGGGGAAACCC...
1,1874,853,>MW398078.1 Faecalibacterium prausnitzii strai...,Bacteria,Bacillota,Clostridia,Eubacteriales,Oscillospiraceae,Faecalibacterium,Faecalibacterium prausnitzii,CCTACGGGAGGCAGCAGTGGGGAATATTGCACAATGGGGGAAACCC...
2,811,98671,>EU086786.1 Arthrobacter albus strain 1366 16S...,Bacteria,Actinomycetota,Actinomycetes,Micrococcales,Micrococcaceae,Pseudoglutamicibacter,Pseudoglutamicibacter albus,CCTACGGGAGGCAGCAGTGGGGAATATTGCACAATGGGCGCAAGCC...
3,1380,285,>PQ340191.1 Comamonas testosteroni strain YX-D...,Bacteria,Pseudomonadota,Betaproteobacteria,Burkholderiales,Comamonadaceae,Comamonas,Comamonas testosteroni,CCTACGGGAGGCAGCAGTGGGGAATTTTGGACAATGGGCGAAAGCC...
4,3736,1313,>OR501373.1 Streptococcus pneumoniae strain AT...,Bacteria,Bacillota,Bacilli,Lactobacillales,Streptococcaceae,Streptococcus,Streptococcus pneumoniae,CCTACGGGAGGCAGCAGTAGGGAATCTTCGGCAATGGACGGAAGTC...
...,...,...,...,...,...,...,...,...,...,...,...
838,825,162496,>OP389295.1 Glutamicibacter creatinolyticus st...,Bacteria,Actinomycetota,Actinomycetes,Micrococcales,Micrococcaceae,Glutamicibacter,Glutamicibacter creatinolyticus,CCTACGGGAGGCAGCAGTGGGGAATATTGCACAATGGGCGAAAGCC...
839,291,1871014,>PQ236872.1 Arcanobacterium urinimassiliense s...,Bacteria,Actinomycetota,Actinomycetes,Actinomycetales,Actinomycetaceae,Arcanobacterium,Arcanobacterium urinimassiliense,CCTACGGGAGGCAGCAGTGGGGAATATTGCACAATGGACGGAAGTC...
840,2349,2033,>PQ839297.1 Microbacterium testaceum strain _Z...,Bacteria,Actinomycetota,Actinomycetes,Micrococcales,Microbacteriaceae,Microbacterium,Microbacterium testaceum,CCTACGGGAGGCAGCAGTGGGGAATATTGCACAATGGGCGAAAGCC...
841,4192,82633,>MW049073.1 Cupriavidus pauculus strain RSCup0...,Bacteria,Pseudomonadota,Betaproteobacteria,Burkholderiales,Burkholderiaceae,Cupriavidus,Cupriavidus pauculus,CCTACGGGAGGCAGCAGTGGGGAATTTTGGACAATGGGGGCAACCC...


In [158]:
def print_scores(true_labels, pred_labels, zero_division="warn"):
    """Print accuracy, weighted precision and weighted F1 for one set of labels.

    Args:
        true_labels: expected labels.
        pred_labels: predicted labels, in the same order as `true_labels`.
        zero_division: forwarded to `precision_score` and `f1_score`. The default
            "warn" is scikit-learn's own default, so existing calls behave as
            before. The cell that appends to the recap scores file passes
            `np.nan` instead, which is why its precision values differ from the
            ones printed here; pass `zero_division=np.nan` to get matching
            numbers without the undefined-metric warnings.
    """
    print('Accuracy:', accuracy_score(true_labels, pred_labels))
    print('Precision:', precision_score(true_labels, pred_labels, average = 'weighted', zero_division = zero_division))
    print('f1_score:', f1_score(true_labels, pred_labels, average = 'weighted', zero_division = zero_division))
    print('\n\n')



# Header identifying the run, followed by the scores for every taxonomic rank.
print(db, region, sep='\n')
print('\n\n')

for level in taxonomy_levels:
    # True labels come from the test set, predictions from the classifier output.
    true_labels, pred_labels = test_set[level], result_df[level]
    print("scores au niveau ", level)
    print_scores(true_labels, pred_labels)




urinary_max20_min5
V3V4



scores au niveau  domain
Accuracy: 1.0
Precision: 1.0
f1_score: 1.0



scores au niveau  phylum
Accuracy: 0.99644128113879
Precision: 0.9929241037428861
f1_score: 0.994672368045351



scores au niveau  class
Accuracy: 0.9952550415183867
Precision: 0.9919451840213213
f1_score: 0.9935402492367068



scores au niveau  order
Accuracy: 0.9916963226571768
Precision: 0.987504524611829
f1_score: 0.9892879313228192



scores au niveau  family
Accuracy: 0.9857651245551602
Precision: 0.9778865229932846
f1_score: 0.9806343313327752



scores au niveau  genus
Accuracy: 0.9252669039145908
Precision: 0.9187714481924033
f1_score: 0.9113538139404976



scores au niveau  species
Accuracy: 0.7627520759193357
Precision: 0.7426445440680317
f1_score: 0.7332871359561751





  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [159]:
# Save the per-rank predictions; output_file was chosen earlier from `db` and `region`.
result_df.to_csv(output_file, index=False)


In [None]:
score_file = '../scores/recap_scores.csv'

# Append one line per rank for the last three ranks (family, genus, species).
# The file is opened in append mode, so every run of this cell adds new lines.
with open(score_file, mode = 'a') as file:
    for level in taxonomy_levels[-3:]:
        print(level)
        true_labels, pred_labels = test_set[level], result_df[level]

        accuracy = accuracy_score(true_labels, pred_labels)
        print(accuracy)

        # zero_division is set explicitly here, unlike in print_scores.
        precision = precision_score(true_labels, pred_labels, average = 'weighted', zero_division = np.nan)
        print(precision)

        f1 = f1_score(true_labels, pred_labels, average = 'weighted', zero_division = np.nan)
        print(f1)

        file.write('cosine, {}, {}, {}, {}, {}, {}\n'.format(db, region, level, accuracy, precision, f1))


family
0.9857651245551602
0.988439255255802
0.9806343313327752
genus
0.9252669039145908
0.9480101968496891
0.9113538139404976
species
0.7627520759193357
0.8564286602590295
0.7332871359561751
