In [221]:
import pandas as pd
import numpy as np
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import neighbors
from sklearn.metrics import accuracy_score, precision_score, f1_score, classification_report
from sklearn.model_selection import train_test_split

In [222]:
# Name of the training set; used to build the input CSV path and to tag output files.
db = 'urinary_max20_min5'
# Column holding the training sequences: 'dna_seq' or 'V3V4'.
# Only the training sequences follow this setting; the test split always uses 'V3V4'.
region = 'V3V4'

In [None]:
# Path of the training CSV, derived from the dataset name `db`
train_file = "../datasets/train_sets/" + db + ".csv"


In [224]:
# Load the training table from `train_file`
data = pd.read_csv(train_file)

# Taxonomic ranks, listed from the broadest (domain) to the most specific (species).
# Each name is a column of `data` and is later used as a classification target.
taxonomy_levels = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']


In [225]:
# Replace missing labels in every taxonomy column with the explicit 'Unknown' label
data[taxonomy_levels] = data[taxonomy_levels].fillna('Unknown')
    

In [226]:
# Preview the first rows of the loaded table
display(data.head())

Unnamed: 0,txid,seq_id,dna_seq,domain,phylum,class,order,family,genus,species,V3V4
0,33007,>PQ788148.1 Winkia neuii strain som 201 16S ri...,CGAACGCTGGCGGCGTGCTTAACACATGCAAGTCGAACGGGATCCA...,Bacteria,Actinomycetota,Actinomycetes,Actinomycetales,Actinomycetaceae,Winkia,Winkia neuii,CCTACGGGAGGCAGCAGTGGGGGATATTGCACAATGGACGGAAGTC...
1,33007,>OR999579.1 Winkia neuii strain CNSY1 16S ribo...,GGCCTGCGGCGTGCTTACCATGCAAGTCGAACGGGATCCATTAGCG...,Bacteria,Actinomycetota,Actinomycetes,Actinomycetales,Actinomycetaceae,Winkia,Winkia neuii,CCTACGGGAGGCAGCAGTGGGGGATATTGCACAATGGACGAAAGTC...
2,33007,>OR260435.1 Winkia neuii strain 19 16S ribosom...,AACGGGTGAGTAACACGTGAGTAACCTGCCCTTTTCTTTGGGATAA...,Bacteria,Actinomycetota,Actinomycetes,Actinomycetales,Actinomycetaceae,Winkia,Winkia neuii,CCTACGGGAGGCAGCAGTGGGGGATATTGCACAATGGACGNAAGTC...
3,33007,>NR_042428.1 Winkia neuii strain DSM 8576 16S ...,CGGCGTGCTTAACACATGCAAGTCGAACGGGATCCATTGGTGCTTG...,Bacteria,Actinomycetota,Actinomycetes,Actinomycetales,Actinomycetaceae,Winkia,Winkia neuii,CCTACGGGAGGCAGCAGTGGGGGATATTGCACAATGGACGCAAGTC...
4,33007,>MZ452128.1 Winkia neuii strain 14a71 16S ribo...,ATGCAGTCGACGGGATCCATTAGCGCTTTTGTGTTTTTGGTGAGAG...,Bacteria,Actinomycetota,Actinomycetes,Actinomycetales,Actinomycetaceae,Winkia,Winkia neuii,CCTACGGGAGGCAGCAGTGGGGGATATTGCACAATGGACGGAAGTC...


In [227]:
# Clean DNA sequences
def clean_sequence(sequence):
    """Return `sequence` upper-cased with every character other than A, T, G, C removed.

    Parameters
    ----------
    sequence : str
        Raw nucleotide sequence; may be lower-case and may contain other characters.

    Returns
    -------
    str
        The upper-cased sequence restricted to the characters A, T, G and C.
    """
    invalid_chars = re.compile(r'[^ATGC]')
    uppercased = sequence.upper()
    return invalid_chars.sub('', uppercased)

# Strip invalid characters from both sequence columns.
# Missing values are converted to empty strings first: astype(str) alone renders
# NaN as 'nan', which clean_sequence upper-cases and reduces to the bogus
# one-base sequence 'A'.
data['dna_seq'] = data['dna_seq'].fillna('').astype(str).apply(clean_sequence)
data['V3V4'] = data['V3V4'].fillna('').astype(str).apply(clean_sequence)


In [228]:
def generate_kmers(sequence, k):
    """Return all overlapping substrings of length `k` of `sequence`, in order.

    Parameters
    ----------
    sequence : str
        Sequence to slice.
    k : int
        Window length.

    Returns
    -------
    list of str
        One k-mer per start position; empty when `sequence` is shorter than `k`.
    """
    kmers = []
    last_start = len(sequence) - k
    for start in range(last_start + 1):
        kmers.append(sequence[start:start + k])
    return kmers

def encode_sequences(sequences, k):
    """Fit a TF-IDF model on the k-mer representation of `sequences`.

    Each sequence is turned into a space-separated string of its k-mers, and a
    `TfidfVectorizer` with default settings is fitted on those strings.

    Parameters
    ----------
    sequences : iterable of str
        Sequences to encode.
    k : int
        k-mer length passed to `generate_kmers`.

    Returns
    -------
    tuple
        `(vectors, vectorizer)`: the TF-IDF matrix of the input sequences and the
        fitted vectorizer, which can be reused to transform other sequences.
    """
    documents = []
    for seq in sequences:
        documents.append(' '.join(generate_kmers(seq, k)))

    tfidf = TfidfVectorizer()
    encoded = tfidf.fit_transform(documents)
    return encoded, tfidf

def find_closest_taxonomy(X_train, X_test, y_train, y_test, level, k=7, threshold=0.8, n_neighbors=3):
    """Predict labels for the test sequences with a cosine-distance KNN classifier.

    Training sequences are encoded as TF-IDF vectors of their k-mers; test
    sequences are projected with the same fitted vectorizer.

    Parameters
    ----------
    X_train, X_test : list of str
        Training and test sequences.
    y_train : array-like
        Labels of the training sequences.
    y_test : array-like
        Labels of the test sequences; returned unchanged.
    level : str
        Accepted but not used by this function.
    k : int, default 7
        k-mer length.
    threshold : float, default 0.8
        Accepted but not used by this function.
    n_neighbors : int, default 3
        Number of neighbours used by the classifier.

    Returns
    -------
    tuple
        `(y_test, y_pred)`, where `y_pred` is the output of `knn.predict`.
    """
    # Fit the TF-IDF vocabulary on the training set only, then reuse it for the test set
    train_vectors, vectorizer = encode_sequences(X_train, k)
    test_documents = [' '.join(generate_kmers(seq, k)) for seq in X_test]
    test_vectors = vectorizer.transform(test_documents)

    # K-nearest-neighbours classifier with cosine distance
    classifier = neighbors.KNeighborsClassifier(n_neighbors=n_neighbors, metric='cosine')
    classifier.fit(train_vectors, y_train)

    return y_test, classifier.predict(test_vectors)
    

In [None]:
# Keep a single row per distinct sequence of the selected region
data = data.drop_duplicates([region])
# DataFrame.info() prints its summary itself and returns None, so it is not wrapped in print()
data.info()

# Train/test split
X = data[['dna_seq', 'V3V4']]
y = data[taxonomy_levels]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 1234)

# Train on the configured region; the test sequences are always taken from 'V3V4'
X_train = X_train[region].tolist()
X_test = X_test['V3V4'].tolist()

# Run classification: one independent KNN model per taxonomy level
true_by_level = {}
pred_by_level = {}

for level in taxonomy_levels:
    level_true, level_pred = find_closest_taxonomy(X_train, X_test, y_train[level], y_test[level], level, k=7, threshold=0.8, n_neighbors=3)
    true_by_level[level] = level_true
    pred_by_level[level] = level_pred

# Build both frames in one go on the test-split index, so that rows of y_true
# and y_pred refer to the same sample both by position and by label.
y_true = pd.DataFrame(true_by_level, index=y_test.index)
y_pred = pd.DataFrame(pred_by_level, index=y_test.index)

# Save the predictions (without the index), tagged by dataset name and training region
y_pred.to_csv("../preds/knn/knn_oneset_" + db + "_" + region + ".csv", index=False)


<class 'pandas.core.frame.DataFrame'>
Index: 2773 entries, 0 to 7062
Data columns (total 11 columns):
 #   Column   Non-Null Count  Dtype 
---  ------   --------------  ----- 
 0   txid     2773 non-null   int64 
 1   seq_id   2773 non-null   object
 2   dna_seq  2773 non-null   object
 3   domain   2773 non-null   object
 4   phylum   2773 non-null   object
 5   class    2773 non-null   object
 6   order    2773 non-null   object
 7   family   2773 non-null   object
 8   genus    2773 non-null   object
 9   species  2773 non-null   object
 10  V3V4     2773 non-null   object
dtypes: int64(1), object(10)
memory usage: 260.0+ KB
None


In [230]:
def print_scores(y_true, y_pred, levels_list):
    """Print accuracy, weighted precision, weighted F1 and a classification report per level.

    Parameters
    ----------
    y_true, y_pred : mapping of str to array-like
        True and predicted labels, indexed by level name (e.g. DataFrames with
        one column per level).
    levels_list : iterable of str
        Levels to report, in printing order.
    """
    # Options shared by the weighted metrics
    weighted_kwargs = {'average': 'weighted', 'zero_division': np.nan}

    for level in levels_list:
        # Header: level name underlined with dashes
        print(level)
        print('-' * len(level))

        truth = y_true[level]
        guess = y_pred[level]

        accuracy = accuracy_score(truth, guess)
        print('   accuracy : ', accuracy)

        precision = precision_score(truth, guess, **weighted_kwargs)
        print('   precision :', precision)

        f1 = f1_score(truth, guess, **weighted_kwargs)
        print('   score f1 :', f1)

        print('\n')
        report = classification_report(truth, guess, zero_division = np.nan)
        print(report)

# Report the metrics of every taxonomy level for the test-split predictions
print_scores(y_true, y_pred, taxonomy_levels)

domain
------
   accuracy :  1.0
   precision : 1.0
   score f1 : 1.0


              precision    recall  f1-score   support

    Bacteria       1.00      1.00      1.00       555

    accuracy                           1.00       555
   macro avg       1.00      1.00      1.00       555
weighted avg       1.00      1.00      1.00       555

phylum
------
   accuracy :  0.9891891891891892
   precision : 0.9893714244173805
   score f1 : 0.9860825061967876


                         precision    recall  f1-score   support

         Actinomycetota       0.99      1.00      1.00       112
              Bacillota       0.97      1.00      0.98       129
           Bacteroidota       1.00      0.97      0.98        32
       Campylobacterota       1.00      1.00      1.00         6
            Chlamydiota       1.00      1.00      1.00         3
         Fusobacteriota       1.00      1.00      1.00         7
         Pseudomonadota       1.00      1.00      1.00       257
          Spiroch

  type_true = type_of_target(y_true, input_name="y_true")
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) f

In [231]:
# NOTE(review): absolute, machine-specific path; consider making it relative or configurable.
score_file = '/home/marthe/Documents/DS/projet/local/frogsdays/scores/recap_scores.csv'

# Append one summary line per level for the last three taxonomy levels.
# The file is opened in append mode, so each run of this cell adds new lines.
with open(score_file, mode = 'a') as score_handle:
    for level in taxonomy_levels[-3:]:
        print(level)
        truth = y_true[level]
        guess = y_pred[level]

        accuracy = accuracy_score(truth, guess)
        print(accuracy)

        precision = precision_score(truth, guess, average = 'weighted', zero_division = np.nan)
        print(precision)

        f1 = f1_score(truth, guess, average = 'weighted', zero_division = np.nan)
        print(f1)

        # Line layout: model tag, dataset, training region, level, accuracy, precision, f1
        row = 'knn oneset, {}, {}, {}, {}, {}, {}\n'.format(db, region, level, accuracy, precision, f1)
        score_handle.write(row)


family
0.9783783783783784
0.9833276617948151
0.9734784665622379
genus
0.8972972972972973
0.9397605145860379
0.8888646711701788
species
0.6936936936936937
0.8453587161736942
0.670992160992161


  type_true = type_of_target(y_true, input_name="y_true")
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
  type_true = type_of_target(y_true, input_name="y_true")
  ys_types = set(type_of_target(x) for x in ys)
