In [207]:
import pandas as pd
import numpy as np
import re
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, precision_score, f1_score

In [208]:
db = "urinary_max20_min5"
region = 'V3V4' # 'dna_seq' or 'V3V4'. Will only apply to reference sequences.

In [None]:
# File paths
input_database = "../datasets/train_sets/" + db + ".csv"



if region == 'dna_seq':
    output_file = "../preds/cosine/cosine_oneset_" + db + "_fullseq.csv"
elif region == 'V3V4':
    output_file = "../preds/cosine/cosine_oneset" + db + "_V3V4.csv"

In [210]:
# Load data
database = pd.read_csv(input_database, sep=',')

taxonomy_levels = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']

database['taxonomy'] = database[taxonomy_levels].fillna("Unknown").agg('_'.join, axis=1)

In [211]:
# Validate DNA sequences
def clean_sequence(sequence):
    return re.sub(r'[^ATGC]', '', sequence.upper())

database['dna_seq'] = database['dna_seq'].astype(str).apply(clean_sequence)
database['V3V4'] = database['V3V4'].astype(str).apply(clean_sequence)

In [212]:
if region == 'V3V4':
    database = database.drop_duplicates('V3V4').reset_index()
    
ref_set, test_set = train_test_split(database, test_size = 0.2, random_state = 1234)

In [213]:
test_set.info()

<class 'pandas.core.frame.DataFrame'>
Index: 555 entries, 2553 to 2137
Data columns (total 13 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   index     555 non-null    int64 
 1   txid      555 non-null    int64 
 2   seq_id    555 non-null    object
 3   dna_seq   555 non-null    object
 4   domain    555 non-null    object
 5   phylum    555 non-null    object
 6   class     555 non-null    object
 7   order     555 non-null    object
 8   family    555 non-null    object
 9   genus     555 non-null    object
 10  species   555 non-null    object
 11  V3V4      555 non-null    object
 12  taxonomy  555 non-null    object
dtypes: int64(2), object(11)
memory usage: 60.7+ KB


In [214]:
len(test_set[region].unique())

555

In [215]:
def generate_kmers(sequence, k):
    return [sequence[i:i+k] for i in range(len(sequence) - k + 1)]

def encode_sequences(sequences, k):
    kmer_list = [' '.join(generate_kmers(seq, k)) for seq in sequences]
    vectorizer = TfidfVectorizer()
    kmer_vectors = vectorizer.fit_transform(kmer_list)
    return kmer_vectors, vectorizer

def find_closest_taxonomy(ref_set, test_set, k=7, threshold=0.8, n_neighbors=3):
    ref_sequences = ref_set[region].tolist()
    test_sequences = test_set["V3V4"].tolist()

    
    # Encode sequences
    known_vectors, vectorizer = encode_sequences(ref_sequences, k)
    unknown_vectors = vectorizer.transform([' '.join(generate_kmers(seq, k)) for seq in test_sequences])
    
    # Use Nearest Neighbors for fast lookup
    nn = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine')
    nn.fit(known_vectors)
    
    distances, indices = nn.kneighbors(unknown_vectors)
    
    results = []
    for i, (dist, idx) in enumerate(zip(distances, indices)):
        similarity = 1 - dist[0]  # Convert distance to similarity
        closest_taxonomy = ref_set.iloc[idx[0]]['taxonomy'] if similarity >= threshold else "Unclassified_Unclassified_Unclassified_Unclassified_Unclassified_Unclassified_Unclassified"
        
        results.append({
            "test_seq_id": test_set.iloc[i]["seq_id"],
            "closest_taxonomy": closest_taxonomy,
            "similarity": similarity
        })
    
    return pd.DataFrame(results)



In [216]:
# Run classification
result_df = find_closest_taxonomy(ref_set, test_set, k=7, threshold=0.1, n_neighbors=3)
display(result_df)


Unnamed: 0,test_seq_id,closest_taxonomy,similarity
0,>PX410416.1 Vibrio vulnificus strain V4WW 16S ...,Bacteria_Pseudomonadota_Gammaproteobacteria_Vi...,0.990304
1,>KC201361.1 Cedecea davisae strain HME8588 16S...,Bacteria_Pseudomonadota_Gammaproteobacteria_En...,0.942034
2,>PQ345526.1 Myroides odoratimimus strain 44RR ...,Bacteria_Bacteroidota_Flavobacteriia_Flavobact...,0.937237
3,>MH972167.1 Pseudescherichia vulneris strain D...,Bacteria_Pseudomonadota_Gammaproteobacteria_En...,0.878942
4,>KC764965.1 Methylobacterium zatmanii strain T...,Bacteria_Pseudomonadota_Alphaproteobacteria_Hy...,0.961541
...,...,...,...
550,>OK355362.1 Cutibacterium avidum strain HMCB2 ...,Bacteria_Actinomycetota_Actinomycetes_Propioni...,0.964104
551,>ON864055.1 Fusobacterium nucleatum strain SSF...,Bacteria_Fusobacteriota_Fusobacteriia_Fusobact...,0.929809
552,>MW398086.1 Megasphaera elsdenii strain PC439 ...,Bacteria_Bacillota_Negativicutes_Veillonellale...,0.872961
553,>PV960334.1 Bacteroides fragilis strain 27831 ...,Bacteria_Bacteroidota_Bacteroidia_Bacteroidale...,0.961635


In [217]:
def tax_splitter(tax):
    taxons = re.split(r"_", tax)
    return(taxons)



result_df['closest_taxonomy'] = result_df['closest_taxonomy'].apply(lambda x : tax_splitter(x))
display(result_df)


Unnamed: 0,test_seq_id,closest_taxonomy,similarity
0,>PX410416.1 Vibrio vulnificus strain V4WW 16S ...,"[Bacteria, Pseudomonadota, Gammaproteobacteria...",0.990304
1,>KC201361.1 Cedecea davisae strain HME8588 16S...,"[Bacteria, Pseudomonadota, Gammaproteobacteria...",0.942034
2,>PQ345526.1 Myroides odoratimimus strain 44RR ...,"[Bacteria, Bacteroidota, Flavobacteriia, Flavo...",0.937237
3,>MH972167.1 Pseudescherichia vulneris strain D...,"[Bacteria, Pseudomonadota, Gammaproteobacteria...",0.878942
4,>KC764965.1 Methylobacterium zatmanii strain T...,"[Bacteria, Pseudomonadota, Alphaproteobacteria...",0.961541
...,...,...,...
550,>OK355362.1 Cutibacterium avidum strain HMCB2 ...,"[Bacteria, Actinomycetota, Actinomycetes, Prop...",0.964104
551,>ON864055.1 Fusobacterium nucleatum strain SSF...,"[Bacteria, Fusobacteriota, Fusobacteriia, Fuso...",0.929809
552,>MW398086.1 Megasphaera elsdenii strain PC439 ...,"[Bacteria, Bacillota, Negativicutes, Veillonel...",0.872961
553,>PV960334.1 Bacteroides fragilis strain 27831 ...,"[Bacteria, Bacteroidota, Bacteroidia, Bacteroi...",0.961635


In [218]:
for k, level in enumerate(taxonomy_levels):
    result_df[level] = result_df['closest_taxonomy'].apply(lambda x : x[k])
    

result_df = result_df.drop(['closest_taxonomy'], axis = 1)
display(result_df)

Unnamed: 0,test_seq_id,similarity,domain,phylum,class,order,family,genus,species
0,>PX410416.1 Vibrio vulnificus strain V4WW 16S ...,0.990304,Bacteria,Pseudomonadota,Gammaproteobacteria,Vibrionales,Vibrionaceae,Vibrio,Vibrio vulnificus
1,>KC201361.1 Cedecea davisae strain HME8588 16S...,0.942034,Bacteria,Pseudomonadota,Gammaproteobacteria,Enterobacterales,Enterobacteriaceae,Raoultella,Raoultella ornithinolytica
2,>PQ345526.1 Myroides odoratimimus strain 44RR ...,0.937237,Bacteria,Bacteroidota,Flavobacteriia,Flavobacteriales,Flavobacteriaceae,Myroides,Myroides odoratimimus
3,>MH972167.1 Pseudescherichia vulneris strain D...,0.878942,Bacteria,Pseudomonadota,Gammaproteobacteria,Enterobacterales,Enterobacteriaceae,Salmonella,Salmonella enterica
4,>KC764965.1 Methylobacterium zatmanii strain T...,0.961541,Bacteria,Pseudomonadota,Alphaproteobacteria,Hyphomicrobiales,Methylobacteriaceae,Methylorubrum,Methylorubrum zatmanii
...,...,...,...,...,...,...,...,...,...
550,>OK355362.1 Cutibacterium avidum strain HMCB2 ...,0.964104,Bacteria,Actinomycetota,Actinomycetes,Propionibacteriales,Propionibacteriaceae,Cutibacterium,Cutibacterium avidum
551,>ON864055.1 Fusobacterium nucleatum strain SSF...,0.929809,Bacteria,Fusobacteriota,Fusobacteriia,Fusobacteriales,Fusobacteriaceae,Fusobacterium,Fusobacterium nucleatum
552,>MW398086.1 Megasphaera elsdenii strain PC439 ...,0.872961,Bacteria,Bacillota,Negativicutes,Veillonellales,Veillonellaceae,Megasphaera,Megasphaera elsdenii
553,>PV960334.1 Bacteroides fragilis strain 27831 ...,0.961635,Bacteria,Bacteroidota,Bacteroidia,Bacteroidales,Bacteroidaceae,Bacteroides,Bacteroides fragilis


In [219]:
def print_scores(true_labels, pred_labels):
    print('Accuracy:', accuracy_score(true_labels, pred_labels))
    print('Precision:', precision_score(true_labels, pred_labels, average = 'weighted', zero_division = np.nan))
    print('f1_score:', f1_score(true_labels, pred_labels, average = 'weighted', zero_division = np.nan))
    print('\n\n')



print(db)
print(region)
print('\n\n')

for level in taxonomy_levels:
    true_labels = test_set[level].to_list()
    pred_labels = result_df[level].to_list()
    print("scores au niveau ", level)
    print_scores(true_labels, pred_labels)




urinary_max20_min5
V3V4



scores au niveau  domain
Accuracy: 1.0
Precision: 1.0
f1_score: 1.0



scores au niveau  phylum
Accuracy: 0.9927927927927928
Precision: 0.9928369173307772
f1_score: 0.9900966338175641



scores au niveau  class
Accuracy: 0.990990990990991
Precision: 0.9915305322484957
f1_score: 0.9884171395376983



scores au niveau  order
Accuracy: 0.9891891891891892
Precision: 0.9908422091037814
f1_score: 0.9869643659184651



scores au niveau  family
Accuracy: 0.9855855855855856
Precision: 0.9877359154533066
f1_score: 0.9833614917516798



scores au niveau  genus
Accuracy: 0.9171171171171171
Precision: 0.9399767001601864
f1_score: 0.9127105177105178



scores au niveau  species
Accuracy: 0.7711711711711712
Precision: 0.8748251413241354
f1_score: 0.7586888186888187





  type_true = type_of_target(y_true, input_name="y_true")
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)


In [220]:
result_df.to_csv(output_file, index=False)


In [None]:
score_file = '../scores/recap_scores.csv'

with open(score_file, mode = 'a') as file:
    for level in taxonomy_levels[-3:]:
        true_labels = test_set[level].to_list()
        pred_labels = result_df[level].to_list()
        print(level)
        accuracy = accuracy_score(true_labels, pred_labels)
        print(accuracy)
        precision = precision_score(true_labels, pred_labels, average = 'weighted', zero_division = np.nan)
        print(precision)
        f1 = f1_score(true_labels, pred_labels, average = 'weighted', zero_division = np.nan)
        print(f1)
        
        file.write('cosine oneset, {}, {}, {}, {}, {}, {}\n'.format(db, region, level, accuracy, precision, f1))


family
0.9855855855855856
0.9877359154533066
0.9833614917516798
genus
0.9171171171171171
0.9399767001601864
0.9127105177105178
species
0.7711711711711712
0.8748251413241354
0.7586888186888187


  type_true = type_of_target(y_true, input_name="y_true")
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
