In [1]:
import subprocess
from collections import Counter
from datetime import datetime
from pathlib import Path

import pandas as pd
from joblib import Parallel, delayed
from pynndescent import NNDescent
from sklearn.model_selection import KFold
from tqdm import tqdm

## Load data

In [3]:
# Load the protein table from HDF5; the later cells read annotation columns from it
# and hand slices of it to the neighbour search.
swiss_df = pd.read_hdf('../data/bacterial_swissprot.h5')
# Preview the first rows.
swiss_df.head()

Unnamed: 0_level_0,accessions,sequence_length,sequence,description,InterPro,GO,KO,Gene3D,Pfam,KEGG,...,Superkingdom,Kingdom,Phylum,Class,Order,Family,Subfamily,Genus,Species,Transmembrane
entry_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
12AH_CLOS4,P21215,29.0,MIFDGKVAIITGGGKAKSIGYGIAVAYAK,RecName: Full=12-alpha-hydroxysteroid dehydrog...,IPR036291,GO:0047013||GO:0030573||GO:0016042,,,,,...,Bacteria,,Firmicutes,Clostridia,Clostridiales,Clostridiaceae,,Clostridium,,0.0
12KD_MYCSM,P80438,24.0,MFHVLTLTYLCPLDVVXQTRPAHV,RecName: Full=12 kDa protein; Flags: Fragment;,,,,,,,...,Bacteria,,Actinobacteria,Actinobacteria,Corynebacteriales,Mycobacteriaceae,,Mycolicibacterium,,0.0
12OLP_LISIN,Q92AT0,1086.0,MTMLKEIKKADLSAAFYPSGELAWLKLKDIMLNQVIQNPLENRLSQ...,"RecName: Full=1,2-beta-oligoglucan phosphoryla...",IPR008928||IPR012341||IPR033432,GO:0016740,K21298,1.50.10.10,PF17167,lin:lin1839,...,Bacteria,,Firmicutes,Bacilli,Bacillales,Listeriaceae,,Listeria,,
12S_PROFR,Q8GBW6||Q05617,611.0,MAENNNLKLASTMEGRVEQLAEQRQVIEAGGGERRVEKQHSQGKQT...,RecName: Full=Methylmalonyl-CoA carboxyltransf...,IPR034733||IPR000438||IPR029045||IPR011763||IP...,GO:0009317||GO:0003989||GO:0047154||GO:0006633,,,PF01039,,...,Bacteria,,Actinobacteria,Actinobacteria,Propionibacteriales,Propionibacteriaceae,,Propionibacterium,,0.0
14KD_MYCBO,P0A5B8||A0A1R3Y251||P30223||X2BJK6,144.0,MATTLPVQRHPRSLFPEFSELFAAFPSFAGLRPTFDTRLMRLEDEM...,RecName: Full=14 kDa antigen; AltName: Full=16...,IPR002068||IPR008978,GO:0005618||GO:0005576,,2.60.40.790,PF00011,,...,Bacteria,,Actinobacteria,Actinobacteria,Corynebacteriales,Mycobacteriaceae,,Mycobacterium,,0.0


In [4]:
# Representation name -> HDF5 file holding its reduced embedding table.
embeddings_paths = dict([
    ('deep_embeddings', '../data/reduced_deep_embeddings.h5'),
    ('3mers', '../data/reduced_3mers_embeddings.h5'),
    ('3mers-tfidf', '../data/reduced_3mers-tfidf_embeddings.h5'),
    ('aafreq', '../data/reduced_aafreq_embeddings.h5'),
])

In [5]:
%%time
# Read every embedding table listed in embeddings_paths, keyed by representation name.
embeddings = dict(
    (representation, pd.read_hdf(location))
    for representation, location in embeddings_paths.items()
)

CPU times: user 1.21 s, sys: 411 ms, total: 1.62 s
Wall time: 1.65 s


## Helpers

In [6]:
from collections import defaultdict


def calculate_metric(neighbors, train_y, test_y, metric_function):
    """
    Score every test protein against the pooled labels of its nearest neighbours.

    neighbors - DataFrame with one row per test protein (same row order as test_y)
                and one column per neighbour; each cell is a key of train_y's index
    train_y - Series mapping train identifiers to their list of labels
    test_y - Series mapping test identifiers to their list of ground-truth labels
    metric_function - callable (ground_truth, neighbor_labels, k) -> number

    Returns a Series of metric values indexed like test_y.
    """
    label_lookup = train_y.to_dict()
    k = neighbors.shape[1]

    # Concatenate the label lists of all neighbours of each test protein.
    # A plain dict.get keeps the lookup table from growing on every miss.
    pooled_labels = [
        [
            label
            for neighbor_id in neighbor_ids
            # Missing neighbours (NaN) and ids absent from train_y vote 'Unknown'.
            for label in label_lookup.get(neighbor_id, ['Unknown'])
        ]
        for neighbor_ids in neighbors.to_numpy()
    ]

    # Pair rows explicitly instead of indexing a merged row by position, which
    # depended on column order and is deprecated for label-indexed Series.
    scores = [
        metric_function(ground_truth, labels, k)
        for ground_truth, labels in zip(test_y, pooled_labels)
    ]
    return pd.Series(scores, index=test_y.index)


def calculate_iou(ground_truth, neighbor_labels, k):
    """
    Intersection-over-union between the true labels and the voted labels.

    ground_truth - List of ground truth labels for given protein (list of strings)
    neighbor_labels - Labels pooled from the protein's neighbours (list of strings)
    k - Number of neighbours; forwarded to deduplicate_predictions
    """
    predicted = set(deduplicate_predictions(neighbor_labels, k))
    expected = set(ground_truth)

    shared = predicted & expected
    combined = predicted | expected
    return len(shared) / len(combined)

def calculate_precision(ground_truth, neighbor_labels, k):
    """
    Fraction of the voted labels that are true labels.

    ground_truth - List of ground truth labels for given protein (list of strings)
    neighbor_labels - Labels pooled from the protein's neighbours (list of strings)
    k - Number of neighbours; forwarded to deduplicate_predictions

    Returns 0 when the vote produced no label at all.
    """
    voted = deduplicate_predictions(neighbor_labels, k)
    if not voted:
        return 0

    predicted = set(voted)
    correct = predicted & set(ground_truth)
    return len(correct) / len(predicted)


def calculate_recall(ground_truth, neighbor_labels, k):
    """
    Fraction of the true labels that were recovered by the vote.

    ground_truth - List of ground truth labels for given protein (list of strings)
    neighbor_labels - Labels pooled from the protein's neighbours (list of strings)
    k - Number of neighbours; forwarded to deduplicate_predictions
    """
    expected = set(ground_truth)
    predicted = set(deduplicate_predictions(neighbor_labels, k))

    recovered = predicted & expected
    return len(recovered) / len(expected)


def deduplicate_predictions(neighbor_labels, k):
    """
    Majority vote over the pooled neighbour labels.

    neighbor_labels - Labels collected from all neighbours, repeats included
    k - Number of neighbours that were pooled

    Returns the distinct labels that occur more than k/2 times, most frequent first.
    """
    # Counter replaces the per-protein pandas Series the original built here;
    # this function runs once per protein, metric and k.
    label_counts = Counter(
        label
        for label in neighbor_labels
        # value_counts() skipped missing values (None/NaN); keep doing so.
        if label is not None and label == label
    )

    threshold = k / 2
    return [label for label, count in label_counts.most_common() if count > threshold]

## Cross-validation for all annotations 

In [7]:
# Neighbourhood sizes evaluated for every fold (one_cv_repeat slices the first k neighbour columns).
ks = [1, 3, 15, 51, 101, 201]
# ks = [3, 15]
# Number of cross-validation repeats and of folds per repeat.
n_repeats = 1
n_folds = 5
# NOTE(review): not read by any cell visible here - one_cv_repeat seeds KFold
# with the repeat number instead; confirm before relying on this value.
random_state_seed=0

# Columns of swiss_df that are evaluated one at a time.
annotations = [
    'Phylum',
    'Order',
    'Family',
    'Genus',
    'SUPFAM',
    'Gene3D',
    'InterPro',
    'KO',
    'GO',
    'eggNOG',
    'Pfam',
    'EC number'
]
# Preview the selected annotation columns.
swiss_df[annotations].head()

Unnamed: 0_level_0,Phylum,Order,Family,Genus,SUPFAM,Gene3D,InterPro,KO,GO,eggNOG,Pfam,EC number
entry_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
12AH_CLOS4,Firmicutes,Clostridiales,Clostridiaceae,Clostridium,SSF51735,,IPR036291,,GO:0047013||GO:0030573||GO:0016042,,,1.1.1.176
12KD_MYCSM,Actinobacteria,Corynebacteriales,Mycobacteriaceae,Mycolicibacterium,,,,,,,,
12OLP_LISIN,Firmicutes,Bacillales,Listeriaceae,Listeria,SSF48208,1.50.10.10,IPR008928||IPR012341||IPR033432,K21298,GO:0016740,ENOG4107T40||COG3459,PF17167,2.4.1.333
12S_PROFR,Actinobacteria,Propionibacteriales,Propionibacteriaceae,Propionibacterium,SSF52096,,IPR034733||IPR000438||IPR029045||IPR011763||IP...,,GO:0009317||GO:0003989||GO:0047154||GO:0006633,ENOG4107QX3||COG4799,PF01039,2.1.3.1
14KD_MYCBO,Actinobacteria,Corynebacteriales,Mycobacteriaceae,Mycobacterium,SSF49764,2.60.40.790,IPR002068||IPR008978,,GO:0005618||GO:0005576,,PF00011,


In [8]:
import os
import tempfile
from itertools import cycle, islice

import pastas as pst


def mmseqs2_knn(train_df, test_df, verbose=2, threads=64, max_neighbors=201):
    """
    Find, for every test sequence, its nearest train sequences with MMseqs2.

    train_df - frame written out as the search targets via its .bio.to_fasta accessor
    test_df - frame written out as the queries via its .bio.to_fasta accessor;
              its index defines the rows of the result
    verbose - verbosity level passed to every mmseqs call as -v
    threads - thread count passed to `mmseqs search`
    max_neighbors - maximum number of hits kept per query
                    (201 is the cap that used to be hard-coded)

    Returns a DataFrame indexed like test_df whose column n holds the target
    identifier of the n-th hit listed for that query; queries with fewer hits
    are padded with NaN.

    Raises subprocess.CalledProcessError when an mmseqs step exits with a
    non-zero status (previously such failures were silently ignored).
    """
    # Kept from the original; presumably importing it registers the `.bio`
    # DataFrame accessor used below - TODO confirm.
    import pastas as pst  # noqa: F401

    verbose = str(verbose)
    threads = str(threads)

    headers = [
        'Query',
        'Target',
        'Score',
        'Seq.Id.',
        'E-value',
        'qStartPos',
        'qEndPos',
        'qLen',
        'tStartPos',
        'tEndPos',
        'tLen'
    ]

    # Everything (FASTA inputs, databases, index, result table) lives in one
    # scratch directory that is removed on exit, so nothing leaks into the
    # system temp folder and the notebook's working directory is never changed.
    with tempfile.TemporaryDirectory() as scratch:
        work_dir = Path(scratch)
        train_fasta = str(work_dir / 'train.fasta')
        test_fasta = str(work_dir / 'test.fasta')
        results_file = str(work_dir / 'results.tsv')

        # save DataFrames to Fastas
        train_df.bio.to_fasta(train_fasta)
        test_df.bio.to_fasta(test_fasta)

        # mmseqs2 search; database names are relative to the scratch directory
        commands = [
            ['createdb', train_fasta, 'targetDB'],
            ['createdb', test_fasta, 'queryDB'],
            ['createindex', 'targetDB', 'tmp', '-k', '5'],
            ['search', 'queryDB', 'targetDB', 'resultDB', 'tmp',
             '-e', 'inf', '-s', '9.0', '--threads', threads],
            ['createtsv', 'queryDB', 'targetDB', 'resultDB', results_file],
        ]
        for arguments in commands:
            # Argument list (no shell) + check=True: a failing step stops here
            # instead of surfacing later as a confusing parse error.
            subprocess.run(
                ['mmseqs', *arguments, '-v', verbose],
                cwd=str(work_dir),
                check=True,
            )

        # parse results before the scratch directory disappears
        results = pd.read_csv(results_file, sep='\t', names=headers)

    # Rank hits per query in file order and keep the first max_neighbors of them.
    ranked = results.assign(n=results.groupby('Query').cumcount())
    ranked = ranked[ranked['n'] < max_neighbors]

    # One row per query, one column per rank.
    neighbors = ranked.set_index(['Query', 'n'])['Target'].unstack(-1)

    return neighbors.reindex(test_df.index)

In [9]:
# annot_df = swiss_df[
#     swiss_df['GO'].notnull()
# ]

In [10]:
# neighbors = mmseqs2_knn(annot_df.iloc[:-5000], swiss_df[-5000:])
# neighbors.isnull().sum()

In [11]:
def one_cv_repeat(embeddings_name, repeat, annotation, output_path):
    """
    Run one K-fold cross-validation repeat for a single annotation and pickle the scores.

    embeddings_name - representation name; only used in the progress messages
    repeat - repeat number, also used as the KFold shuffling seed
    annotation - column of swiss_df holding '||'-separated labels
    output_path - pickle file that receives the long-format metric table

    For every fold, neighbours of the test proteins are searched among the train
    proteins with mmseqs2_knn, then IoU, Precision and Recall are computed for
    each k in ks.
    """
    print(f"[{datetime.now()}] Started {annotation}, repeat {repeat}, embedding: {embeddings_name}")

    # Only proteins that carry this annotation take part in the evaluation.
    labelled_df = swiss_df[swiss_df[annotation].notnull()]

    metric_functions = {
        'IoU': calculate_iou,
        'Precision': calculate_precision,
        'Recall': calculate_recall,
    }

    # Seeding with the repeat number gives every repeat a different shuffle.
    splitter = KFold(n_splits=n_folds, random_state=repeat, shuffle=True)

    fold_tables = []
    for train_ids, test_ids in splitter.split(labelled_df):
        train_fold = labelled_df.iloc[train_ids]
        test_fold = labelled_df.iloc[test_ids]

        # '||'-separated annotation strings -> lists of labels
        train_y = train_fold[annotation].str.split(pat=r'\|\|')
        test_y = test_fold[annotation].str.split(pat=r'\|\|')

        neighbors = mmseqs2_knn(train_fold, test_fold)

        for k in ks:
            nearest = neighbors.iloc[:, :k]

            per_metric = {
                metric_name: calculate_metric(nearest, train_y, test_y, metric_function)
                for metric_name, metric_function in metric_functions.items()
            }

            # Wide (one column per metric) -> long (one row per protein and metric);
            # relies on the index being named 'entry_name'.
            long_table = (
                pd.concat(per_metric, axis=1)
                .reset_index()
                .melt(id_vars='entry_name', var_name='Metric', value_name='Value')
            )
            long_table['k'] = k
            fold_tables.append(long_table)

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    pd.concat(fold_tables).to_pickle(str(output_path))
    print(f"[{datetime.now()}] Finished {annotation}, repeat {repeat}, embedding: {embeddings_name}")

In [12]:
# Run one_cv_repeat for every (repeat, annotation) pair with the 'mmseqs2' neighbour search.
# Each task pickles its scores to half_<emb_name>/repeat_<r>/<annotation>.pkl.
# NOTE(review): the merge cell further down reads from '<variant>_mmseqs2/' with
# variant = "half2", i.e. a different prefix than the 'half_' written here - confirm
# which directory is intended.
Parallel(n_jobs=1)(
    delayed(one_cv_repeat)(emb_name, r, annotation, f'half_{emb_name}/repeat_{r}/{annotation}.pkl')
    for r in range(n_repeats)
    for annotation in annotations
    for emb_name in ['mmseqs2']
    # for emb_name, emb_df in embeddings.items()
    
)

[2022-04-23 10:33:42.414086] Started Phylum, repeat 0, embedding: mmseqs2
[2022-04-23 12:36:40.270744] Finished Phylum, repeat 0, embedding: mmseqs2
[2022-04-23 12:36:41.376315] Started Order, repeat 0, embedding: mmseqs2
[2022-04-23 14:45:59.053868] Finished Order, repeat 0, embedding: mmseqs2
[2022-04-23 14:45:59.868335] Started Family, repeat 0, embedding: mmseqs2
[2022-04-23 16:55:13.928953] Finished Family, repeat 0, embedding: mmseqs2
[2022-04-23 16:55:14.552026] Started Genus, repeat 0, embedding: mmseqs2
[2022-04-23 19:10:13.957344] Finished Genus, repeat 0, embedding: mmseqs2
[2022-04-23 19:10:14.897479] Started SUPFAM, repeat 0, embedding: mmseqs2
[2022-04-23 20:29:31.176854] Finished SUPFAM, repeat 0, embedding: mmseqs2
[2022-04-23 20:29:31.667280] Started Gene3D, repeat 0, embedding: mmseqs2
[2022-04-23 21:30:12.131884] Finished Gene3D, repeat 0, embedding: mmseqs2
[2022-04-23 21:30:12.544060] Started InterPro, repeat 0, embedding: mmseqs2
[2022-04-23 23:37:05.598965] Finis

[None, None, None, None, None, None, None, None, None, None, None, None]

### Merge all results

In [13]:
def read_results(results_path, repeats=None):
    """
    Load the per-annotation score pickles of one representation and average them.

    results_path - directory containing repeat_<r>/<annotation>.pkl files
    repeats - number of repeat_<r> sub-directories to read; defaults to the
              notebook-level n_repeats

    Returns a DataFrame with a single 'Value' column holding the mean metric
    value, indexed by (Repeat, Label, k, Metric), where Label is the pickle's
    file stem.

    Raises ValueError (from pd.concat) when no pickle is found.
    """
    results_path = Path(results_path)
    if repeats is None:
        repeats = n_repeats

    # Only read pickles this notebook wrote itself - unpickling untrusted
    # files can execute arbitrary code.
    frames = [
        pd.read_pickle(path).assign(Repeat=repeat, Label=path.stem)
        for repeat in range(repeats)
        for path in sorted(results_path.glob(f'repeat_{repeat}/*.pkl'))
    ]

    # Average 'Value' only: the frames also carry the string column
    # 'entry_name', which a blanket .mean() rejects on pandas >= 2.0.
    return (
        pd.concat(frames, ignore_index=True)
        .groupby(['Repeat', 'Label', 'k', 'Metric'])[['Value']]
        .mean()
    )

In [14]:
# df = read_results('fixed_deep_embeddings/')

In [15]:
# def bootstrap(group, p=0.01, n=10):
#     return pd.Series({
#         f'bootstrap_{i}': group.sample(frac=p)['Value'].mean()
#         for i in range(n)
#     })

# df.query('k == 3 and Label == "KO" and Metric == "Precision"').groupby(['Repeat', 'Label', 'k', 'Metric']).apply(bootstrap)

In [17]:
# Directory prefix of the result sets to merge; each representation is read
# from '<variant>_<name>/'.
# NOTE(review): the cross-validation cell above writes to 'half_<emb_name>/',
# not 'half2_...' - confirm this prefix points at the intended run.
variant = "half2"
# Stack the averaged scores of all representations into one long table with
# a 'Representation' column.
results = pd.concat({
    'Deep Embeddings': read_results(f'{variant}_deep_embeddings/'),
    '3-mers': read_results(f'{variant}_3mers/'),
    '3-mers TF-IDF': read_results(f'{variant}_3mers-tfidf/'),
    'AA freq.': read_results(f'{variant}_aafreq/'),
    'mmseqs2': read_results(f'{variant}_mmseqs2/'),
}, names=['Representation']).reset_index()

results.head()

Unnamed: 0,Representation,Repeat,Label,k,Metric,Value
0,Deep Embeddings,0,EC number,1,IoU,0.932389
1,Deep Embeddings,0,EC number,1,Precision,0.933024
2,Deep Embeddings,0,EC number,1,Recall,0.932958
3,Deep Embeddings,0,EC number,3,IoU,0.909968
4,Deep Embeddings,0,EC number,3,Precision,0.910627


In [18]:
def f1_score(group):
    """
    Harmonic mean of the 'Precision' and 'Recall' rows of one result group.

    group - DataFrame with a 'Metric' column (containing 'Precision' and
            'Recall') and a 'Value' column

    Returns 0.0 when precision and recall are both zero, instead of the NaN
    (plus RuntimeWarning) that the bare 0/0 division used to produce.
    """
    metrics = group.set_index('Metric')['Value']
    precision, recall = metrics['Precision'], metrics['Recall']

    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return 2 * precision * recall / denominator

# F1 for every (representation, repeat, label, k), computed from the averaged
# precision and recall. Only the columns f1_score reads are handed to apply,
# so it never operates on the grouping columns.
f1_scores = (
    results
    .groupby(['Representation', 'Repeat', 'Label', 'k'])[['Metric', 'Value']]
    .apply(f1_score)
    .rename('Value')
    .reset_index()
)
f1_scores['Metric'] = 'F1 Score'

# Drop F1 rows left by an earlier run of this cell so that re-running it
# does not append duplicates.
results = pd.concat([results[results['Metric'] != 'F1 Score'], f1_scores])
results.head()

  This is separate from the ipykernel package so we can avoid doing imports until


Unnamed: 0,Representation,Repeat,Label,k,Metric,Value
0,Deep Embeddings,0,EC number,1,IoU,0.932389
1,Deep Embeddings,0,EC number,1,Precision,0.933024
2,Deep Embeddings,0,EC number,1,Recall,0.932958
3,Deep Embeddings,0,EC number,3,IoU,0.909968
4,Deep Embeddings,0,EC number,3,Precision,0.910627


In [19]:
# Persist the merged metric table; `key` is passed by keyword because passing
# it positionally is deprecated in pandas 2.x.
results.to_hdf('results_half_mmseq2.h5', key='results')