In [1]:
from pathlib import Path

from tqdm import tqdm

import pandas as pd
import numpy as np

from sklearn.model_selection import RepeatedKFold
from pynndescent import NNDescent

## Prepare the results directory

In [2]:
# One CSV of per-protein metrics per annotation; '{}' is filled with the annotation name.
# To evaluate the 3-mer embeddings instead, use 'results_3mers/{}.csv'.
results_path_pattern: str = 'results_ArdiMiPE/{}.csv'

In [3]:
# Create the directory that will hold the per-annotation CSVs. parents=True lets
# the pattern point at a nested directory; exist_ok=True keeps the cell re-runnable.
Path(results_path_pattern).parent.mkdir(parents=True, exist_ok=True)

## Load data

In [4]:
# Bacterial Swiss-Prot table, indexed by entry_name, with one column per annotation source.
swiss_df = pd.read_hdf(path_or_buf='../data/bacterial_swissprot.h5')
swiss_df.head()

Unnamed: 0_level_0,accessions,sequence_length,sequence,description,InterPro,GO,KO,Gene3D,Pfam,KEGG,...,EC number,Superkingdom,Kingdom,Phylum,Class,Order,Family,Subfamily,Genus,Species
entry_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
12AH_CLOS4,P21215,29.0,MIFDGKVAIITGGGKAKSIGYGIAVAYAK,RecName: Full=12-alpha-hydroxysteroid dehydrog...,IPR036291,GO:0047013||GO:0030573||GO:0016042,,,,,...,1.1.1.176,Bacteria,,Firmicutes,Clostridia,Clostridiales,Clostridiaceae,,Clostridium,
12KD_MYCSM,P80438,24.0,MFHVLTLTYLCPLDVVXQTRPAHV,RecName: Full=12 kDa protein; Flags: Fragment;,,,,,,,...,,Bacteria,,Actinobacteria,Actinobacteria,Corynebacteriales,Mycobacteriaceae,,Mycolicibacterium,
12OLP_LISIN,Q92AT0,1086.0,MTMLKEIKKADLSAAFYPSGELAWLKLKDIMLNQVIQNPLENRLSQ...,"RecName: Full=1,2-beta-oligoglucan phosphoryla...",IPR008928||IPR012341||IPR033432,GO:0016740,K21298,1.50.10.10,PF17167,lin:lin1839,...,2.4.1.333,Bacteria,,Firmicutes,Bacilli,Bacillales,Listeriaceae,,Listeria,
12S_PROFR,Q8GBW6||Q05617,611.0,MAENNNLKLASTMEGRVEQLAEQRQVIEAGGGERRVEKQHSQGKQT...,RecName: Full=Methylmalonyl-CoA carboxyltransf...,IPR034733||IPR000438||IPR029045||IPR011763||IP...,GO:0009317||GO:0003989||GO:0047154||GO:0006633,,,PF01039,,...,2.1.3.1,Bacteria,,Actinobacteria,Actinobacteria,Propionibacteriales,Propionibacteriaceae,,Propionibacterium,
14KD_MYCBO,P0A5B8||A0A1R3Y251||P30223||X2BJK6,144.0,MATTLPVQRHPRSLFPEFSELFAAFPSFAGLRPTFDTRLMRLEDEM...,RecName: Full=14 kDa antigen; AltName: Full=16...,IPR002068||IPR008978,GO:0005618||GO:0005576,,2.60.40.790,PF00011,,...,,Bacteria,,Actinobacteria,Actinobacteria,Corynebacteriales,Mycobacteriaceae,,Mycobacterium,


In [5]:
# Reduced embeddings (pca_* columns), indexed by entry_name.
# For the 3-mer baseline read '../data/reduced_3mers_embeddings.h5' instead.
embed_df = pd.read_hdf(path_or_buf='../data/reduced_deep_embeddings.h5')
embedding_columns = embed_df.columns

embed_df.head()

Unnamed: 0_level_0,pca_0,pca_1,pca_2,pca_3,pca_4,pca_5,pca_6,pca_7,pca_8,pca_9,...,pca_40,pca_41,pca_42,pca_43,pca_44,pca_45,pca_46,pca_47,pca_48,pca_49
entry_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
12AH_CLOS4,-1.423302,-0.736753,-0.874685,-0.819151,-0.142118,-0.816148,0.386158,0.061937,0.687497,-0.297038,...,0.118606,-0.178684,0.080721,0.000936,-0.040333,0.036337,-0.045772,0.040853,-0.182722,0.036003
12KD_MYCSM,-2.985445,0.512436,-0.804671,-0.881089,0.023167,-0.840529,0.681208,-0.202686,0.34303,0.302511,...,0.063164,0.068167,-0.055077,0.054813,0.115826,-0.138364,-0.137902,-0.151101,-0.352669,-0.107964
12OLP_LISIN,-3.489646,0.009758,0.592749,-0.249326,-0.399583,0.004498,-0.739661,0.568125,-0.138422,-0.558067,...,0.078285,-0.139291,0.163776,-0.086977,0.097379,-0.009465,-0.043559,-0.039196,-0.102595,-0.001293
12S_PROFR,0.62239,-0.471249,0.868454,0.555548,-0.403828,-0.031662,-0.045285,-0.540816,0.314709,-0.14073,...,-0.212403,0.026798,-0.185727,-0.201333,-0.117253,-0.084229,-0.140449,-0.099275,0.174085,0.006138
14KD_MYCBO,-0.739342,0.083417,-0.164653,-0.817445,-0.376513,-0.484122,-0.727681,0.142234,0.0605,0.264555,...,-0.006761,-0.284457,0.022018,0.025971,0.017194,-0.175638,0.00239,-0.232771,-0.178814,0.0329


In [6]:
# Attach the embedding columns to the Swiss-Prot table (left join on the index;
# proteins without an embedding get NaNs). Dropping any embedding columns first
# keeps this cell re-runnable: a second plain join would raise
# "columns overlap but no suffix specified".
swiss_df = swiss_df.drop(columns=embed_df.columns, errors='ignore').join(embed_df)

## Helpers

In [7]:
def calculate_iou(ground_truth, predictions):
    """Intersection-over-union between true labels and the most frequent predicted labels.

    The predicted label set consists of the most frequent values in
    `predictions`: as many as there are entries in `ground_truth`, plus every
    value tied in count with the last one selected. Ties are kept rather than
    broken arbitrarily.

    Parameters
    ----------
    ground_truth : sequence of hashable
        True labels of one sample.
    predictions : sequence of hashable
        Labels pooled from the sample's neighbours; repeated values act as votes.

    Returns
    -------
    float
        |selected & truth| / |selected | truth|. Returns 0.0 when `predictions`
        is empty but `ground_truth` is not, and NaN when both are empty (IoU is
        undefined there).
    """
    truth = set(ground_truth)
    predictions_counts = pd.Series(predictions, dtype=object).value_counts()

    if predictions_counts.empty:
        # No votes: nothing can overlap the truth, and .iloc below would fail.
        return 0.0 if truth else float('nan')

    # Count of the len(ground_truth)-th most popular value, clamped to the number
    # of distinct predictions. An empty ground_truth gives position -1, i.e. the
    # smallest count, so every predicted value is selected.
    cutoff_position = min(len(ground_truth), len(predictions_counts)) - 1
    threshold = predictions_counts.iloc[cutoff_position]
    most_popular_values = set(predictions_counts[predictions_counts >= threshold].index)

    return len(most_popular_values & truth) / len(most_popular_values | truth)

def calculate_metric(neighbors, train_y, test_y, metric_function):
    """Score every test sample against the labels pooled from its nearest neighbours.

    Parameters
    ----------
    neighbors : array-like of int, shape (n_test, k)
        Row i holds positional indices into `train_y` of test sample i's neighbours.
    train_y : pandas.Series
        Label list per training sample.
    test_y : pandas.Series
        Label list per test sample, in the same row order as `neighbors`.
    metric_function : callable(ground_truth, predictions)
        Applied once per test sample. `predictions` is the concatenation of the
        neighbours' label lists, in neighbour order.

    Returns
    -------
    pandas.Series
        One metric value per test sample, indexed (and ordered) like `test_y`.
    """
    neighbor_label_lists = train_y.to_numpy()[np.asarray(neighbors)]

    # Pair rows positionally. This avoids an index join, which duplicates rows
    # when index labels repeat, and avoids positional Series access, which is
    # deprecated in pandas 2.x.
    scores = [
        metric_function(truth, [label for labels in row for label in labels])
        for truth, row in zip(test_y.to_numpy(), neighbor_label_lists)
    ]
    return pd.Series(scores, index=test_y.index)

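## Cross-validation for all annotations

For each annotation, proteins with both the annotation and an embedding are split with repeated k-fold cross-validation. An approximate nearest-neighbour graph is built on the training embeddings, and each test protein is scored by the IoU between its labels and the most frequent labels among its k nearest training neighbours.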

In [8]:
ks = [3, 15, 51]              # neighbourhood sizes to evaluate
n_repeats, n_folds = 10, 5    # RepeatedKFold settings
random_state = 0

annotations = [
    # taxonomic ranks
    'Phylum', 'Order', 'Family', 'Genus',
    # protein annotation sources
    'SUPFAM', 'Gene3D', 'InterPro', 'KO', 'GO', 'eggNOG', 'Pfam', 'EC number',
]
swiss_df.loc[:, annotations].head()

Unnamed: 0_level_0,Phylum,Order,Family,Genus,SUPFAM,Gene3D,InterPro,KO,GO,eggNOG,Pfam,EC number
entry_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
12AH_CLOS4,Firmicutes,Clostridiales,Clostridiaceae,Clostridium,SSF51735,,IPR036291,,GO:0047013||GO:0030573||GO:0016042,,,1.1.1.176
12KD_MYCSM,Actinobacteria,Corynebacteriales,Mycobacteriaceae,Mycolicibacterium,,,,,,,,
12OLP_LISIN,Firmicutes,Bacillales,Listeriaceae,Listeria,SSF48208,1.50.10.10,IPR008928||IPR012341||IPR033432,K21298,GO:0016740,ENOG4107T40||COG3459,PF17167,2.4.1.333
12S_PROFR,Actinobacteria,Propionibacteriales,Propionibacteriaceae,Propionibacterium,SSF52096,,IPR034733||IPR000438||IPR029045||IPR011763||IP...,,GO:0009317||GO:0003989||GO:0047154||GO:0006633,ENOG4107QX3||COG4799,PF01039,2.1.3.1
14KD_MYCBO,Actinobacteria,Corynebacteriales,Mycobacteriaceae,Mycobacterium,SSF49764,2.60.40.790,IPR002068||IPR008978,,GO:0005618||GO:0005576,,PF00011,


In [None]:
for annotation in tqdm(annotations):
    tqdm.write(annotation)

    # Keep rows that have this annotation and a complete embedding vector.
    annot_df = swiss_df[
        swiss_df[annotation].notnull() & swiss_df[embedding_columns].notnull().all(axis=1)
    ]  # append .sample(1000) here for a quick test run

    # Features and the '||'-separated labels (split into lists) do not depend on
    # the fold, so prepare them once per annotation instead of once per fold.
    annot_X = annot_df[embedding_columns].to_numpy()
    annot_labels = annot_df[annotation].str.split(pat=r'\|\|')

    metric_columns = [f'k={k}' for k in ks]
    max_k = max(ks)

    # RepeatedKFold puts every row in a test fold once per repeat. Accumulate
    # per-row sums and test counts so the saved metric is the mean over all
    # repeats seen so far. Assigning per fold would keep only the last repeat.
    metric_sums = np.zeros((len(annot_df), len(ks)))
    test_counts = np.zeros(len(annot_df))

    kfold = RepeatedKFold(n_splits=n_folds, random_state=random_state, n_repeats=n_repeats)
    for train_ids, test_ids in tqdm(kfold.split(annot_X), total=kfold.get_n_splits(), leave=False):

        # Train-test split (positional indices from the splitter)
        train_X, test_X = annot_X[train_ids], annot_X[test_ids]
        train_y, test_y = annot_labels.iloc[train_ids], annot_labels.iloc[test_ids]

        # Build & query the approximate NN graph once for the largest k; smaller
        # k values use the leading columns of the result. Seeded for reproducibility.
        nn_graph = NNDescent(train_X, n_neighbors=max_k, n_jobs=8, random_state=random_state)
        neighbors, _ = nn_graph.query(test_X, k=max_k)

        for col, k in enumerate(ks):
            metric_values = calculate_metric(neighbors[:, :k], train_y, test_y, calculate_iou)
            # One value per test row, in test_ids order.
            metric_sums[test_ids, col] += metric_values.to_numpy(dtype=float)
            tqdm.write(f'{annotation} for k={k}: {metric_values.mean():.4f}')
        test_counts[test_ids] += 1

        # Checkpoint the running mean after every fold; rows not yet tested stay NaN.
        metric_means = np.divide(
            metric_sums,
            test_counts[:, None],
            out=np.full_like(metric_sums, np.nan),
            where=test_counts[:, None] > 0,
        )
        pd.DataFrame(metric_means, index=annot_df.index, columns=metric_columns).to_csv(
            results_path_pattern.format(annotation)
        )

  0%|          | 0/12 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s][A

Phylum
Phylum for k=3: 0.8215
Phylum for k=15: 0.7574
Phylum for k=51: 0.6823



  2%|▏         | 1/50 [02:40<2:10:55, 160.31s/it][A

Phylum for k=3: 0.8210
Phylum for k=15: 0.7562
Phylum for k=51: 0.6807



  4%|▍         | 2/50 [05:05<2:04:42, 155.88s/it][A

Phylum for k=3: 0.8211
Phylum for k=15: 0.7574
Phylum for k=51: 0.6783



  6%|▌         | 3/50 [07:30<1:59:23, 152.42s/it][A

Phylum for k=3: 0.8197
Phylum for k=15: 0.7564
Phylum for k=51: 0.6798



  8%|▊         | 4/50 [09:51<1:54:22, 149.18s/it][A

Phylum for k=3: 0.8231
Phylum for k=15: 0.7574
Phylum for k=51: 0.6789



 10%|█         | 5/50 [12:16<1:50:48, 147.74s/it][A

Phylum for k=3: 0.8209
Phylum for k=15: 0.7588
Phylum for k=51: 0.6821



 12%|█▏        | 6/50 [14:43<1:48:09, 147.49s/it][A

Phylum for k=3: 0.8225
Phylum for k=15: 0.7539
Phylum for k=51: 0.6770



 14%|█▍        | 7/50 [17:07<1:45:05, 146.64s/it][A

Phylum for k=3: 0.8197
Phylum for k=15: 0.7558
Phylum for k=51: 0.6810



 16%|█▌        | 8/50 [19:28<1:41:28, 144.97s/it][A

Phylum for k=3: 0.8231
Phylum for k=15: 0.7585
Phylum for k=51: 0.6802



 18%|█▊        | 9/50 [21:51<1:38:36, 144.29s/it][A

Phylum for k=3: 0.8197
Phylum for k=15: 0.7581
Phylum for k=51: 0.6789



 20%|██        | 10/50 [24:16<1:36:17, 144.43s/it][A

Phylum for k=3: 0.8187
Phylum for k=15: 0.7537
Phylum for k=51: 0.6756



 22%|██▏       | 11/50 [26:41<1:33:59, 144.61s/it][A

Phylum for k=3: 0.8224
Phylum for k=15: 0.7557
Phylum for k=51: 0.6814



 24%|██▍       | 12/50 [29:03<1:31:03, 143.77s/it][A

Phylum for k=3: 0.8210
Phylum for k=15: 0.7594
Phylum for k=51: 0.6825



 26%|██▌       | 13/50 [31:24<1:28:08, 142.92s/it][A

Phylum for k=3: 0.8203
Phylum for k=15: 0.7569
Phylum for k=51: 0.6794



 28%|██▊       | 14/50 [33:47<1:25:47, 143.00s/it][A

Phylum for k=3: 0.8232
Phylum for k=15: 0.7584
Phylum for k=51: 0.6809



 30%|███       | 15/50 [36:09<1:23:14, 142.71s/it][A

Phylum for k=3: 0.8226
Phylum for k=15: 0.7588
Phylum for k=51: 0.6816



 32%|███▏      | 16/50 [38:31<1:20:48, 142.60s/it][A

Phylum for k=3: 0.8205
Phylum for k=15: 0.7570
Phylum for k=51: 0.6803



 34%|███▍      | 17/50 [40:55<1:18:36, 142.94s/it][A

Phylum for k=3: 0.8210
Phylum for k=15: 0.7582
Phylum for k=51: 0.6802



 36%|███▌      | 18/50 [43:16<1:15:55, 142.36s/it][A

Phylum for k=3: 0.8187
Phylum for k=15: 0.7546
Phylum for k=51: 0.6781



 38%|███▊      | 19/50 [45:38<1:13:34, 142.41s/it][A

Phylum for k=3: 0.8214
Phylum for k=15: 0.7563
Phylum for k=51: 0.6799



 40%|████      | 20/50 [48:01<1:11:15, 142.53s/it][A

Phylum for k=3: 0.8201
Phylum for k=15: 0.7589
Phylum for k=51: 0.6839



 42%|████▏     | 21/50 [50:24<1:08:59, 142.74s/it][A

Phylum for k=3: 0.8224
Phylum for k=15: 0.7575
Phylum for k=51: 0.6787



 44%|████▍     | 22/50 [52:44<1:06:08, 141.72s/it][A

Phylum for k=3: 0.8258
Phylum for k=15: 0.7606
Phylum for k=51: 0.6826



 46%|████▌     | 23/50 [55:02<1:03:21, 140.79s/it][A

Phylum for k=3: 0.8209
Phylum for k=15: 0.7569
Phylum for k=51: 0.6804



 48%|████▊     | 24/50 [57:22<1:00:50, 140.39s/it][A

Phylum for k=3: 0.8191
Phylum for k=15: 0.7528
Phylum for k=51: 0.6732



 50%|█████     | 25/50 [59:40<58:14, 139.79s/it]  [A

Phylum for k=3: 0.8204
Phylum for k=15: 0.7568
Phylum for k=51: 0.6800



 52%|█████▏    | 26/50 [1:02:00<55:54, 139.76s/it][A

Phylum for k=3: 0.8200
Phylum for k=15: 0.7551
Phylum for k=51: 0.6777



 54%|█████▍    | 27/50 [1:04:17<53:18, 139.08s/it][A

Phylum for k=3: 0.8239
Phylum for k=15: 0.7604
Phylum for k=51: 0.6838



 56%|█████▌    | 28/50 [1:06:36<50:55, 138.89s/it][A

Phylum for k=3: 0.8191
Phylum for k=15: 0.7593
Phylum for k=51: 0.6795



 58%|█████▊    | 29/50 [1:08:55<48:37, 138.94s/it][A

Phylum for k=3: 0.8193
Phylum for k=15: 0.7541
Phylum for k=51: 0.6776



 60%|██████    | 30/50 [1:11:14<46:20, 139.02s/it][A

Phylum for k=3: 0.8234
Phylum for k=15: 0.7581
Phylum for k=51: 0.6811



 62%|██████▏   | 31/50 [1:13:31<43:50, 138.45s/it][A

Phylum for k=3: 0.8217
Phylum for k=15: 0.7592
Phylum for k=51: 0.6800



 64%|██████▍   | 32/50 [1:15:50<41:31, 138.43s/it][A

Phylum for k=3: 0.8186
Phylum for k=15: 0.7549
Phylum for k=51: 0.6766



 66%|██████▌   | 33/50 [1:18:09<39:19, 138.77s/it][A

Phylum for k=3: 0.8213
Phylum for k=15: 0.7582
Phylum for k=51: 0.6824



 68%|██████▊   | 34/50 [1:20:30<37:09, 139.35s/it][A

Phylum for k=3: 0.8212
Phylum for k=15: 0.7554
Phylum for k=51: 0.6791



 70%|███████   | 35/50 [1:22:50<34:52, 139.49s/it][A

Phylum for k=3: 0.8209
Phylum for k=15: 0.7547
Phylum for k=51: 0.6772



 72%|███████▏  | 36/50 [1:25:09<32:32, 139.45s/it][A

Phylum for k=3: 0.8197
Phylum for k=15: 0.7541
Phylum for k=51: 0.6765



 74%|███████▍  | 37/50 [1:27:29<30:16, 139.70s/it][A

Phylum for k=3: 0.8207
Phylum for k=15: 0.7590
Phylum for k=51: 0.6813



 76%|███████▌  | 38/50 [1:29:52<28:07, 140.66s/it][A

Phylum for k=3: 0.8208
Phylum for k=15: 0.7580
Phylum for k=51: 0.6810



 78%|███████▊  | 39/50 [1:32:15<25:52, 141.13s/it][A

Phylum for k=3: 0.8226
Phylum for k=15: 0.7603
Phylum for k=51: 0.6814



 80%|████████  | 40/50 [1:34:37<23:36, 141.66s/it][A

Phylum for k=3: 0.8182
Phylum for k=15: 0.7572
Phylum for k=51: 0.6837



 82%|████████▏ | 41/50 [1:37:06<21:33, 143.69s/it][A

Phylum for k=3: 0.8226
Phylum for k=15: 0.7578
Phylum for k=51: 0.6797



 84%|████████▍ | 42/50 [1:39:24<18:55, 141.91s/it][A

Phylum for k=3: 0.8188
Phylum for k=15: 0.7558
Phylum for k=51: 0.6759



 86%|████████▌ | 43/50 [1:41:42<16:26, 140.96s/it][A

Phylum for k=3: 0.8195
Phylum for k=15: 0.7537
Phylum for k=51: 0.6773



 88%|████████▊ | 44/50 [1:44:03<14:04, 140.74s/it][A

Phylum for k=3: 0.8225
Phylum for k=15: 0.7609
Phylum for k=51: 0.6806



 90%|█████████ | 45/50 [1:46:22<11:41, 140.34s/it][A

Phylum for k=3: 0.8210
Phylum for k=15: 0.7573
Phylum for k=51: 0.6807



 92%|█████████▏| 46/50 [1:48:41<09:20, 140.02s/it][A

Phylum for k=3: 0.8219
Phylum for k=15: 0.7546
Phylum for k=51: 0.6805



 94%|█████████▍| 47/50 [1:51:00<06:58, 139.55s/it][A

Phylum for k=3: 0.8209
Phylum for k=15: 0.7568
Phylum for k=51: 0.6801



 96%|█████████▌| 48/50 [1:53:18<04:38, 139.17s/it][A

Phylum for k=3: 0.8215
Phylum for k=15: 0.7590
Phylum for k=51: 0.6800



 98%|█████████▊| 49/50 [1:55:37<02:18, 138.99s/it][A

Phylum for k=3: 0.8195
Phylum for k=15: 0.7557
Phylum for k=51: 0.6788



100%|██████████| 50/50 [1:57:53<00:00, 141.47s/it][A
  8%|▊         | 1/12 [1:57:53<21:36:51, 7073.74s/it]

Order



  0%|          | 0/50 [00:00<?, ?it/s][A

Order for k=3: 0.5626
Order for k=15: 0.4469
Order for k=51: 0.3131



  2%|▏         | 1/50 [02:17<1:52:11, 137.38s/it][A

Order for k=3: 0.5624
Order for k=15: 0.4455
Order for k=51: 0.3091



  4%|▍         | 2/50 [04:36<1:50:26, 138.05s/it][A

Order for k=3: 0.5658
Order for k=15: 0.4460
Order for k=51: 0.3100



  6%|▌         | 3/50 [06:55<1:48:11, 138.11s/it][A

Order for k=3: 0.5640
Order for k=15: 0.4468
Order for k=51: 0.3118



  8%|▊         | 4/50 [09:14<1:46:11, 138.51s/it][A

Order for k=3: 0.5642
Order for k=15: 0.4467
Order for k=51: 0.3142



 10%|█         | 5/50 [11:38<1:45:00, 140.02s/it][A

Order for k=3: 0.5611
Order for k=15: 0.4429
Order for k=51: 0.3129



 12%|█▏        | 6/50 [13:57<1:42:29, 139.77s/it][A

Order for k=3: 0.5690
