In [1]:
from pathlib import Path

from tqdm import tqdm

import pandas as pd
import numpy as np

from sklearn.model_selection import KFold
from pynndescent import NNDescent

## Configure input and output paths

**Run this notebook once per embedding type; select the embedding in the configuration cell below.**

In [2]:
# Embedding evaluated in this run -- one of EMBEDDING_NAMES.
# The notebook is run once per embedding; change only this value between runs.
EMBEDDING_NAMES = ('deep', '3mers', '3mers-tfidf', 'aafreq')
embedding_name = 'aafreq'
if embedding_name not in EMBEDDING_NAMES:
    raise ValueError(f'Unknown embedding {embedding_name!r}; expected one of {EMBEDDING_NAMES}')

# Input embeddings file, and the output CSV pattern whose two `{}` placeholders
# are later filled with the repeat number and the annotation name.
embeddings_path = f'../data/reduced_{embedding_name}_embeddings.h5'
results_path_pattern = f'results_{embedding_name}_embeddings/repeat_{{}}/{{}}.csv'

## Load data

In [3]:
# Bacterial Swiss-Prot entries and their annotation columns, indexed by entry_name.
swiss_df = pd.read_hdf(path_or_buf='../data/bacterial_swissprot.h5')
swiss_df.head(n=5)

Unnamed: 0_level_0,accessions,sequence_length,sequence,description,InterPro,GO,KO,Gene3D,Pfam,KEGG,...,Superkingdom,Kingdom,Phylum,Class,Order,Family,Subfamily,Genus,Species,Transmembrane
entry_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
12AH_CLOS4,P21215,29.0,MIFDGKVAIITGGGKAKSIGYGIAVAYAK,RecName: Full=12-alpha-hydroxysteroid dehydrog...,IPR036291,GO:0047013||GO:0030573||GO:0016042,,,,,...,Bacteria,,Firmicutes,Clostridia,Clostridiales,Clostridiaceae,,Clostridium,,0.0
12KD_MYCSM,P80438,24.0,MFHVLTLTYLCPLDVVXQTRPAHV,RecName: Full=12 kDa protein; Flags: Fragment;,,,,,,,...,Bacteria,,Actinobacteria,Actinobacteria,Corynebacteriales,Mycobacteriaceae,,Mycolicibacterium,,0.0
12OLP_LISIN,Q92AT0,1086.0,MTMLKEIKKADLSAAFYPSGELAWLKLKDIMLNQVIQNPLENRLSQ...,"RecName: Full=1,2-beta-oligoglucan phosphoryla...",IPR008928||IPR012341||IPR033432,GO:0016740,K21298,1.50.10.10,PF17167,lin:lin1839,...,Bacteria,,Firmicutes,Bacilli,Bacillales,Listeriaceae,,Listeria,,
12S_PROFR,Q8GBW6||Q05617,611.0,MAENNNLKLASTMEGRVEQLAEQRQVIEAGGGERRVEKQHSQGKQT...,RecName: Full=Methylmalonyl-CoA carboxyltransf...,IPR034733||IPR000438||IPR029045||IPR011763||IP...,GO:0009317||GO:0003989||GO:0047154||GO:0006633,,,PF01039,,...,Bacteria,,Actinobacteria,Actinobacteria,Propionibacteriales,Propionibacteriaceae,,Propionibacterium,,0.0
14KD_MYCBO,P0A5B8||A0A1R3Y251||P30223||X2BJK6,144.0,MATTLPVQRHPRSLFPEFSELFAAFPSFAGLRPTFDTRLMRLEDEM...,RecName: Full=14 kDa antigen; AltName: Full=16...,IPR002068||IPR008978,GO:0005618||GO:0005576,,2.60.40.790,PF00011,,...,Bacteria,,Actinobacteria,Actinobacteria,Corynebacteriales,Mycobacteriaceae,,Mycobacterium,,0.0


In [4]:
# Embedding vectors indexed by entry_name; every column is used as a feature.
embed_df = pd.read_hdf(path_or_buf=embeddings_path)
embedding_columns = embed_df.keys()
embed_df.head(n=5)

Unnamed: 0_level_0,M,A,F,S,E,D,V,L,K,Y,R,P,N,W,Q,C,G,I,H,T
entry_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
001R_FRG3G,0.023438,0.050781,0.03125,0.050781,0.058594,0.066406,0.082031,0.097656,0.113281,0.054688,0.058594,0.042969,0.03125,0.015625,0.035156,0.015625,0.058594,0.046875,0.035156,0.03125
002L_FRG3G,0.021875,0.08125,0.028125,0.06875,0.01875,0.075,0.071875,0.04375,0.053125,0.034375,0.05,0.115625,0.028125,0.01875,0.040625,0.05625,0.071875,0.03125,0.009375,0.08125
002R_IIV3,0.024017,0.032751,0.048035,0.080786,0.098253,0.10262,0.045852,0.074236,0.045852,0.045852,0.048035,0.054585,0.048035,0.032751,0.034934,0.0131,0.034934,0.052402,0.021834,0.061135
003L_IIV3,0.019231,0.064103,0.019231,0.096154,0.025641,0.032051,0.025641,0.076923,0.025641,0.051282,0.057692,0.128205,0.057692,0.012821,0.025641,0.051282,0.057692,0.038462,0.019231,0.115385
003R_FRG3G,0.061644,0.06621,0.045662,0.084475,0.031963,0.06621,0.121005,0.107306,0.043379,0.015982,0.047945,0.045662,0.038813,0.011416,0.02968,0.009132,0.059361,0.027397,0.022831,0.063927


In [5]:
# Attach the embedding features to the annotation table. The join is a left join on
# the index, so proteins without an embedding get NaN features (filtered out later).
# Guarded so re-running this cell does not fail with "columns overlap".
if not embedding_columns.isin(swiss_df.columns).all():
    swiss_df = swiss_df.join(embed_df)

## Helpers

In [6]:
def calculate_iou(ground_truth, predictions):
    """Intersection-over-union between true labels and the most popular predicted labels.

    The predicted set consists of the ``len(ground_truth)`` most frequent values in
    ``predictions``. Every value tied with the last selected one is included as
    well, so ties are resolved by inclusion rather than arbitrarily.

    Parameters
    ----------
    ground_truth : sequence of hashable
        True labels of one sample.
    predictions : sequence of hashable
        Pooled (possibly repeated) labels predicted for that sample.

    Returns
    -------
    float
        ``|selected & truth| / |selected | truth|``. Returns 0.0 when there are no
        predictions and NaN when both inputs are empty (IoU undefined).
    """
    ground_truth_set = set(ground_truth)
    predictions_counts = pd.Series(predictions, dtype=object).value_counts()

    if predictions_counts.empty:
        selected = set()
    else:
        # value_counts() sorts by descending count; take the count of the
        # len(ground_truth)-th most popular value, capped at the least popular one.
        cutoff = min(len(ground_truth), len(predictions_counts)) - 1
        threshold = predictions_counts.iloc[max(cutoff, 0)]
        selected = set(predictions_counts[predictions_counts >= threshold].index)

    union = selected | ground_truth_set
    if not union:
        return float('nan')
    return len(selected & ground_truth_set) / len(union)

def calculate_metric(neighbors, train_y, test_y, metric_function):
    """Score every test sample against the pooled labels of its nearest training neighbours.

    Parameters
    ----------
    neighbors : array-like of shape (n_test, k)
        Positional indices into ``train_y``; row ``i`` holds the neighbours of the
        ``i``-th element of ``test_y``.
    train_y : pd.Series
        Label lists of the training samples.
    test_y : pd.Series
        Label lists of the test samples (ground truth).
    metric_function : callable
        Called as ``metric_function(ground_truth, predictions)``, where
        ``predictions`` concatenates the label lists of all of the sample's
        neighbours in neighbour order (duplicates kept). Must return a number.

    Returns
    -------
    pd.Series
        One float score per test sample, indexed like ``test_y``.

    Raises
    ------
    ValueError
        If ``neighbors`` does not have one row per element of ``test_y``.
    """
    if len(neighbors) != len(test_y):
        raise ValueError(
            f'neighbors has {len(neighbors)} rows but test_y has {len(test_y)} elements'
        )

    train_labels = train_y.tolist()
    # Positional pairing with test_y avoids the index merge and the deprecated
    # positional Series indexing (row[0] / row[1]) of a row-wise apply.
    scores = [
        metric_function(
            ground_truth,
            [label for neighbor_id in neighbor_row for label in train_labels[neighbor_id]],
        )
        for ground_truth, neighbor_row in zip(test_y.tolist(), neighbors)
    ]
    return pd.Series(scores, index=test_y.index, dtype=float)

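## Cross-validation for all annotations

For every annotation, each test protein's labels are predicted from the labels of its k nearest training neighbours in embedding space and scored with IoU, using repeated k-fold cross-validation.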

In [7]:
# Neighbourhood sizes to evaluate; the NN graph is queried once with max(ks)
# and the smaller neighbourhoods are taken from that result.
ks = [3, 15, 51]

# Repeated k-fold cross-validation settings.
n_repeats = 10
n_folds = 5
random_state_seed = 0

# Annotation columns of swiss_df to evaluate: taxonomic ranks first,
# then protein-level annotation databases.
annotations = [
    'Phylum', 'Order', 'Family', 'Genus',
    'SUPFAM', 'Gene3D', 'InterPro', 'KO', 'GO', 'eggNOG', 'Pfam', 'EC number',
]
swiss_df[annotations].head()

Unnamed: 0_level_0,Phylum,Order,Family,Genus,SUPFAM,Gene3D,InterPro,KO,GO,eggNOG,Pfam,EC number
entry_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
12AH_CLOS4,Firmicutes,Clostridiales,Clostridiaceae,Clostridium,SSF51735,,IPR036291,,GO:0047013||GO:0030573||GO:0016042,,,1.1.1.176
12KD_MYCSM,Actinobacteria,Corynebacteriales,Mycobacteriaceae,Mycolicibacterium,,,,,,,,
12OLP_LISIN,Firmicutes,Bacillales,Listeriaceae,Listeria,SSF48208,1.50.10.10,IPR008928||IPR012341||IPR033432,K21298,GO:0016740,ENOG4107T40||COG3459,PF17167,2.4.1.333
12S_PROFR,Actinobacteria,Propionibacteriales,Propionibacteriaceae,Propionibacterium,SSF52096,,IPR034733||IPR000438||IPR029045||IPR011763||IP...,,GO:0009317||GO:0003989||GO:0047154||GO:0006633,ENOG4107QX3||COG4799,PF01039,2.1.3.1
14KD_MYCBO,Actinobacteria,Corynebacteriales,Mycobacteriaceae,Mycobacterium,SSF49764,2.60.40.790,IPR002068||IPR008978,,GO:0005618||GO:0005576,,PF00011,


In [None]:
for annotation in tqdm(annotations):
    print(annotation)
    # Re-seeded per annotation so each annotation's splits are reproducible on their own.
    # The same RandomState is consumed by every repeat below, so repeats shuffle differently.
    random_state = np.random.RandomState(random_state_seed)

    # Proteins that carry this annotation and have a complete embedding vector.
    annot_df = swiss_df[
        swiss_df[annotation].notnull() & swiss_df[embedding_columns].notnull().all(axis=1)
    ]  # append .sample(1000) here for a quick test run

    # Fold-invariant inputs, computed once per annotation instead of once per fold:
    # the feature columns and the '||'-separated labels expanded to lists.
    annot_X = annot_df[embedding_columns]
    annot_y = annot_df[annotation].str.split(pat=r'\|\|')

    for repeat in range(n_repeats):
        # Fresh frame per repeat: KFold puts every protein in exactly one test fold, so
        # after the fold loop each row holds this repeat's score and nothing older.
        metrics_df = pd.DataFrame(
            data=np.nan,
            index=annot_df.index,
            columns=[f'k={k}' for k in ks]
        )

        kfold = KFold(n_splits=n_folds, random_state=random_state, shuffle=True)
        for train_ids, test_ids in tqdm(kfold.split(annot_df), total=kfold.n_splits):

            # Train-test split (KFold yields positional indices)
            train_X, test_X = annot_X.iloc[train_ids], annot_X.iloc[test_ids]
            train_y, test_y = annot_y.iloc[train_ids], annot_y.iloc[test_ids]

            # Build & query the NN graph once with the largest k; smaller ks reuse the
            # first k columns (assumes query returns neighbours nearest-first).
            nn_graph = NNDescent(train_X, n_neighbors=max(ks), n_jobs=4)
            neighbors, _distances = nn_graph.query(test_X, k=max(ks))

            for k in ks:
                k_nearest_neighbors = neighbors[:, :k]
                metric_values = calculate_metric(k_nearest_neighbors, train_y, test_y, calculate_iou)

                metrics_df.loc[metric_values.index, f'k={k}'] = metric_values.values
                print(f'{annotation} for k={k}: {metric_values.mean():.4f}')

        # Written once per repeat, after every fold has filled its rows.
        results_path = Path(results_path_pattern.format(repeat, annotation))
        results_path.parent.mkdir(parents=True, exist_ok=True)
        metrics_df.to_csv(str(results_path))

  0%|                                                                            | 0/12 [00:00<?, ?it/s]

Phylum



  0%|                                                                             | 0/5 [00:00<?, ?it/s][A

Phylum for k=3: 0.7626
Phylum for k=15: 0.7343
Phylum for k=51: 0.6880



 20%|█████████████▌                                                      | 1/5 [02:33<10:12, 153.09s/it][A

Phylum for k=3: 0.7624
Phylum for k=15: 0.7319
Phylum for k=51: 0.6861



 40%|███████████████████████████▏                                        | 2/5 [04:59<07:26, 148.97s/it][A

Phylum for k=3: 0.7649
