# Training machine learning models on pairs of substrates in individual organisms

## Imports

In [1]:
from subpred.util import load_df
from subpred.graph import preprocess_data, get_substrate_matrix
from subpred.pssm import calculate_pssm_feature
from subpred.compositions import calculate_aac, calculate_paac
import pandas as pd
from subpred.cdhit import cd_hit
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.feature_selection import SelectPercentile
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import classification_report
from subpred.custom_transformers import FeatureCombinator, get_feature_type_combinations

from subpred.util import load_df
import networkx as nx


## Functions

### Dataset

In [2]:
def get_classification_task(
    organism_ids: set,
    labels: set,
    clustering_threshold: int = None,
    dataset_folder_path: str = "../data/datasets",
) -> pd.DataFrame:
    """Build a labelled sequence dataset for classifying transporters by substrate.

    Proteins of the given organisms that are linked (through
    ``get_substrate_matrix``) to one of the ChEBI terms named in ``labels``
    are collected, their sequences are looked up, and the dataset is
    optionally reduced to CD-HIT cluster representatives.

    Args:
        organism_ids: organism ids passed to ``preprocess_data``
            (e.g. ``{9606}`` for human).
        labels: ChEBI term names used as class labels.
        clustering_threshold: identity threshold passed to ``cd_hit``.
            A falsy value (``None`` or 0) disables clustering.
        dataset_folder_path: folder with the preprocessed datasets.

    Returns:
        DataFrame indexed by UniProt accession with the columns ``sequence``
        and ``label``. A protein linked to several labels keeps only the
        alphabetically first one.

    Raises:
        KeyError: if a label is not the name of a term in the filtered ChEBI
            graph, or if no proteins are linked to it.
    """
    # TODO handling for multi-substrate
    # TODO ability to use go terms or chebi terms (compare sample count, performance)

    (
        df_uniprot,
        df_uniprot_goa,
        graph_go_filtered,
        graph_chebi_filtered,
    ) = preprocess_data(
        organism_ids=organism_ids, datasets_folder_path=dataset_folder_path
    )
    # TODO go through method code
    df_substrate_overlaps, dict_chebi_to_uniprot = get_substrate_matrix(
        datasets_folder_path=dataset_folder_path,
        graph_chebi=graph_chebi_filtered,
        graph_go=graph_go_filtered,
        df_uniprot_goa=df_uniprot_goa,
        min_overlap=0,
        max_overlap=int(1e6),
    )
    assert df_substrate_overlaps.shape[0] == len(dict_chebi_to_uniprot)
    chebi_name_to_term = {
        name: term for term, name in graph_chebi_filtered.nodes(data="name")
    }
    chebi_term_to_name = {
        term: name for term, name in graph_chebi_filtered.nodes(data="name")
    }
    molecule_counts = {
        chebi_term_to_name[term]: len(proteins)
        for term, proteins in dict_chebi_to_uniprot.items()
    }
    print(sorted(molecule_counts.items(), key=lambda item: item[1], reverse=True))

    protein_to_label = []
    # Iterate the labels in sorted order. The de-duplication below keeps the
    # first label seen for a protein, and the iteration order of a set of
    # strings depends on the interpreter's hash seed. Sorting makes the label
    # assignment reproducible across runs.
    for label in sorted(labels):
        try:
            label_term = chebi_name_to_term[label]
        except KeyError:
            raise KeyError(f"Unknown ChEBI term name: {label!r}") from None
        try:
            label_proteins = dict_chebi_to_uniprot[label_term]
        except KeyError:
            raise KeyError(
                f"No proteins linked to {label!r} ({label_term}) "
                f"for organisms {organism_ids}"
            ) from None
        protein_to_label.extend([protein, label] for protein in label_proteins)

    df_labels = pd.DataFrame.from_records(
        protein_to_label, columns=["Uniprot", "label"], index="Uniprot"
    )

    # Keep one label per protein (the alphabetically first one, see above).
    df_labels = df_labels[~df_labels.index.duplicated()]  # TODO series?
    print(df_labels.label.value_counts())
    df_sequences = df_uniprot.loc[df_labels.index].sequence.to_frame()
    print("number of sequences", df_sequences.shape[0])
    if clustering_threshold:
        cluster_representatives = cd_hit(
            df_sequences.sequence, identity_threshold=clustering_threshold
        )
        print(cluster_representatives)
        df_sequences = df_sequences.loc[cluster_representatives]
        df_labels = df_labels.loc[cluster_representatives]
    return pd.concat([df_sequences, df_labels], axis=1)

### Features

In [3]:
def get_features(series_sequences: pd.Series):
    """Compute composition and PSSM feature sets for protein sequences.

    Args:
        series_sequences: protein sequences; presumably indexed by UniProt
            accession (TODO confirm against the ``calculate_*`` helpers).

    Returns:
        DataFrame with these feature blocks concatenated column-wise:
        amino-acid composition (AAC), pair amino-acid composition (PAAC), and
        PSSM features from PSI-BLAST against UniRef50 and UniRef90 with 1 and
        3 iterations each (PSSM_50_1, PSSM_50_3, PSSM_90_1, PSSM_90_3).
    """
    # (UniRef clustering level, number of PSI-BLAST iterations), in output order.
    pssm_settings = [("50", 1), ("50", 3), ("90", 1), ("90", 3)]

    feature_frames = [
        calculate_aac(series_sequences),
        calculate_paac(series_sequences),
    ]
    for uniref_level, iterations in pssm_settings:
        feature_frames.append(
            calculate_pssm_feature(
                series_sequences,
                # Every (database, iterations) combination gets its own tmp
                # folder, so results computed with different settings are
                # never shared. The folder is presumably used as a PSSM cache
                # (TODO confirm).
                tmp_folder=(
                    "../data/intermediate/blast/"
                    f"pssm_uniref{uniref_level}_{iterations}it"
                ),
                blast_db=(
                    f"../data/raw/uniref/uniref{uniref_level}/"
                    f"uniref{uniref_level}.fasta"
                ),
                iterations=iterations,
                psiblast_threads=-1,
                verbose=False,
                feature_name=f"PSSM_{uniref_level}_{iterations}",
            )
        )
    return pd.concat(feature_frames, axis=1)

### Eval

In [4]:
# TODO try removing worst sample according to percentages
# TODO feature selection, regularization
# TODO cd-hit
# TODO determinism
# TODO also comparative analysis of features?
# TODO compare to protein embeddings and BLAST
# TODO parameter for using featurecombinator
# TODO separate functions

def evaluate(df_dataset, df_features):
    """Tune an SVM pipeline with cross-validation and report hold-out results.

    The samples are split into a stratified 80% training set and a 20%
    evaluation set (``random_state=1``). On the training set, a 5-fold grid
    search with F1 scoring selects the SVM ``C`` and the combination of
    feature types used by ``FeatureCombinator``. The best pipeline is then
    applied to the evaluation set. Two things are printed: a classification
    report, and a table of per-sample class probabilities next to the true
    and predicted labels.

    Args:
        df_dataset: DataFrame with a ``label`` column, indexed by sample id.
            It must contain every index value of ``df_features``.
        df_features: numeric feature matrix indexed by sample id.

    Returns:
        The fitted ``GridSearchCV`` instance, for further inspection.
    """
    # Look the labels up by sample id, so rows of X and y always correspond
    # regardless of the row order of the two frames.
    series_labels = df_dataset.label.loc[df_features.index]

    # converting data to numpy
    label_encoder = LabelEncoder()
    label_encoder.fit(sorted(series_labels.unique()))
    sample_names = df_features.index.values
    feature_names = df_features.columns.values
    X = df_features.values
    y = label_encoder.transform(series_labels)
    # train test eval split
    (
        X_train,
        X_eval,
        y_train,
        y_eval,
        _,
        sample_names_eval,
    ) = train_test_split(X, y, sample_names, test_size=0.2, random_state=1, stratify=y)

    feature_type_combinations = get_feature_type_combinations(feature_names=feature_names)
    feature_combinator = FeatureCombinator(feature_names=df_features.columns)
    model = make_pipeline(
        StandardScaler(), feature_combinator, SVC(random_state=1, probability=True)
    )
    param_grid = {
        "svc__C": [0.1, 1, 10],
        # "svc__gamma": ["scale", "auto"],
        # "svc__class_weight": ["balanced", None],
        "featurecombinator__feature_types": feature_type_combinations,
        # "selectpercentile__percentile": list(range(1, 101, 5)),
    }

    # hyperparam optim & crossval
    # "f1" is the binary F1 score of encoded class 1, so this assumes exactly
    # two labels.
    gridsearch = GridSearchCV(
        estimator=model,
        param_grid=param_grid,
        scoring="f1",
        cv=5,
        n_jobs=-1,
        return_train_score=True,
        # verbose=20
    )
    gridsearch.fit(X_train, y_train)
    print("Best train score:", gridsearch.best_score_)
    print("Best train params", gridsearch.best_params_)
    model_optim = gridsearch.best_estimator_

    # eval
    y_pred = model_optim.predict(X_eval)
    # Show the original label names instead of the encoded integers. Passing
    # `labels` explicitly keeps target_names aligned even if a class is absent
    # from this split.
    print(
        classification_report(
            y_true=y_eval,
            y_pred=y_pred,
            labels=list(range(len(label_encoder.classes_))),
            target_names=list(label_encoder.classes_),
        )
    )

    # Per-sample probabilities next to the true and predicted labels.
    df_proba = pd.DataFrame(
        model_optim.predict_proba(X_eval),
        index=sample_names_eval,
        columns=label_encoder.inverse_transform(model_optim.classes_),
    )
    df_proba.insert(0, "true_label", label_encoder.inverse_transform(y_eval))
    df_proba.insert(1, "predicted_label", label_encoder.inverse_transform(y_pred))
    print(df_proba)

    return gridsearch


## New dataset creation

In [None]:
def get_id_update_dict(graph, field="alt_id"):
    """Map every alternative id of a node to the node's own id.

    Args:
        graph: graph whose ``nodes(data=field)`` yields ``(node, ids)`` pairs.
            ``ids`` is an iterable of alternative ids, or a falsy value
            (e.g. ``None``) when the node has none.
        field: name of the node attribute holding the alternative ids.

    Returns:
        Dict ``{alternative_id: node}``. If an alternative id is listed
        under several nodes, the node visited last wins.
    """
    return {
        old_id: primary_id
        for primary_id, old_ids in graph.nodes(data=field)
        if old_ids
        for old_id in old_ids
    }


def get_go_subgraph(root_term_name: str, edge_keys: set = frozenset({"is_a"})):
    """Return a copy of the GO graph below a root term, restricted to some edge types.

    Args:
        root_term_name: name of the GO root term,
            e.g. "transmembrane transporter activity".
        edge_keys: edge keys (relationship types) to keep; by default only
            ``is_a``. Any container supporting ``in`` works. The default is
            an immutable frozenset so it cannot be shared and mutated
            between calls.

    Returns:
        An independent copy of the subgraph induced by the root term's
        descendants (the root itself is not included). Only edges whose key
        is in ``edge_keys`` are kept. A descendant is kept only if it is an
        endpoint of such an edge in the full GO graph.

    Raises:
        KeyError: if ``root_term_name`` is not a GO term name.
    """
    # TODO filtering by aspect
    graph_go = load_df("go_obo")
    go_name_to_id = {name: node for node, name in graph_go.nodes(data="name")}
    root_term_id = go_name_to_id[root_term_name]

    # Edges presumably point child -> parent (obonet convention), so the
    # graph-theoretic ancestors of the root are its GO descendants.
    # NOTE(review): this traversal follows every relationship type, not only
    # edge_keys, so it may cross into other aspects. get_goa_subset, by
    # contrast, traverses is_a edges only.
    root_term_descendants = nx.ancestors(graph_go, root_term_id)
    graph_go_sub = graph_go.subgraph(root_term_descendants)
    # Candidate edges come from the full graph. edge_subgraph on the subgraph
    # view drops the parts outside it, but a node inside the subgraph whose
    # matching edges all leave it (e.g. a leaf child of the root) is still
    # kept as a node.
    graph_go_sub_filtered = graph_go_sub.edge_subgraph(
        [
            (node1, node2, key)
            for node1, node2, key in graph_go.edges(keys=True)
            if key in edge_keys
        ]
    )
    return graph_go_sub_filtered.copy()


def get_goa_subset(
    root_term_name: str,
    proteins_filter: set = None,
    qualifiers_filter: set = None,
    aspects_filter: set = None,
    evidence_codes_exclude: set = None,
):
    """Return GO annotations below a root term, expanded to their ancestor terms.

    Annotation GO ids are first mapped from alternative ids (``alt_id``) to
    their primary id. Annotations are then kept if their term lies below
    ``root_term_name`` in the is_a hierarchy (the root itself excluded) and
    they pass the optional filters. Each remaining annotation is repeated once
    per ancestor term that also lies below the root.

    Args:
        root_term_name: name of the GO root term.
        proteins_filter: keep only these UniProt accessions. An empty or
            ``None`` value disables the filter.
        qualifiers_filter: keep only these qualifiers (e.g. ``{"enables"}``).
        aspects_filter: keep only these GO aspects (e.g. ``{"F"}``).
        evidence_codes_exclude: drop annotations with these evidence codes.

    Returns:
        DataFrame with the columns of the ``go`` dataset plus ``go_term``,
        ``ancestors`` (one ancestor GO id per row) and ``ancestor_term``.
        ``ancestors`` is NaN for terms with no ancestor below the root.

    Raises:
        KeyError: if ``root_term_name`` is not a GO term name.
    """
    # creates subset of go annotations below a root node, and filtered by a set of proteins.
    graph_go = load_df("go_obo")

    go_name_to_id = {name: node for node, name in graph_go.nodes(data="name")}
    go_id_to_name = {node: name for node, name in graph_go.nodes(data="name")}
    root_term_id = go_name_to_id[root_term_name]

    # use is_a key for retrieving ancestors and descendants, to avoid crossing into different aspect
    graph_go_isa = graph_go.edge_subgraph(
        [
            (node1, node2, key)
            for node1, node2, key in graph_go.edges(keys=True)
            if key == "is_a"
        ]
    )
    # Edges presumably point child -> parent (obonet convention), so
    # nx.ancestors returns the GO descendants of the root.
    root_term_descendants = nx.ancestors(graph_go_isa, root_term_id)

    df_goa = load_df("go")

    # Map alternative GO ids to their primary id. The mapped column must be
    # assigned back. `assign` returns a new frame, so the frame returned by
    # load_df is not modified in place.
    dict_update_ids = get_id_update_dict(graph_go, field="alt_id")
    df_goa = df_goa.assign(
        go_id=df_goa.go_id.map(lambda go_id: dict_update_ids.get(go_id, go_id))
    )
    # filter annotations
    df_goa = df_goa[df_goa.go_id.isin(root_term_descendants)]
    if proteins_filter:
        df_goa = df_goa[df_goa.Uniprot.isin(proteins_filter)]
    if qualifiers_filter:
        df_goa = df_goa[df_goa.qualifier.isin(qualifiers_filter)]
    if aspects_filter:
        df_goa = df_goa[df_goa.aspect.isin(aspects_filter)]
    if evidence_codes_exclude:
        df_goa = df_goa[~df_goa.evidence_code.isin(evidence_codes_exclude)]

    # Id remapping can create duplicate annotations; drop them here.
    df_goa = df_goa.drop_duplicates().reset_index(drop=True)
    df_goa["go_term"] = df_goa.go_id.map(go_id_to_name)
    # GO ancestors (graph descendants along child -> parent edges) of each
    # annotated term, limited to terms below the root.
    df_goa["ancestors"] = df_goa.go_id.apply(
        lambda go_id: set(nx.descendants(graph_go_isa, go_id) & root_term_descendants)
    )
    df_goa = df_goa.explode("ancestors").reset_index(drop=True)
    df_goa["ancestor_term"] = df_goa.ancestors.map(go_id_to_name)

    return df_goa

## Main

Unnamed: 0,Uniprot,qualifier,go_id,evidence_code,aspect,go_term,ancestors,ancestor_term
0,B4E2Q0,enables,GO:0005388,IEA,F,P-type calcium transporter activity,GO:0015662,P-type ion transporter activity
1,B4E2Q0,enables,GO:0005388,IEA,F,P-type calcium transporter activity,GO:0140358,P-type transmembrane transporter activity
2,B4E2Q0,enables,GO:0005388,IEA,F,P-type calcium transporter activity,GO:0022890,inorganic cation transmembrane transporter act...
3,B4E2Q0,enables,GO:0005388,IEA,F,P-type calcium transporter activity,GO:0015085,calcium ion transmembrane transporter activity
4,B4E2Q0,enables,GO:0005388,IEA,F,P-type calcium transporter activity,GO:0008324,monoatomic cation transmembrane transporter ac...
...,...,...,...,...,...,...,...,...
1124,Q9P0L9,enables,GO:0015269,IDA,F,calcium-activated potassium channel activity,GO:0015075,monoatomic ion transmembrane transporter activity
1125,Q9P0L9,enables,GO:0015269,IDA,F,calcium-activated potassium channel activity,GO:0022839,monoatomic ion gated channel activity
1126,Q9P0L9,enables,GO:0015269,IDA,F,calcium-activated potassium channel activity,GO:0046873,metal ion transmembrane transporter activity
1127,Q9P0L9,enables,GO:0015269,IDA,F,calcium-activated potassium channel activity,GO:1901702,salt transmembrane transporter activity


In [6]:
def main(organism_ids: set, labels: set, clustering_threshold: int = 70):
    """Build, featurize and evaluate one substrate-classification dataset.

    Args:
        organism_ids: organism ids forwarded to ``get_classification_task``.
        labels: ChEBI term names used as class labels.
        clustering_threshold: CD-HIT identity threshold forwarded to
            ``get_classification_task`` (default 70).
    """
    # TODO get rid of unnecessary prints
    df_dataset = get_classification_task(
        organism_ids=organism_ids,
        labels=labels,
        clustering_threshold=clustering_threshold,
    )

    # TODO this is a quickfix, redo pipeline
    # Keep only proteins with an "enables" molecular-function (aspect "F")
    # annotation below "transmembrane transporter activity".
    tmtp_proteins = set(
        get_goa_subset(
            root_term_name="transmembrane transporter activity",
            qualifiers_filter=["enables"],
            aspects_filter=["F"],
            proteins_filter=df_dataset.index.tolist(),
        ).Uniprot.unique()
    )
    df_dataset = df_dataset[df_dataset.index.isin(tmtp_proteins)]
    # TODO quickfix end

    df_features = get_features(df_dataset.sequence)

    # Sort the samples by id and give the dataset the same row order as the
    # feature matrix.
    df_features = df_features.loc[df_features.index.sort_values()]
    df_dataset = df_dataset.loc[df_features.index]

    print(df_dataset.shape, df_features.shape)

    evaluate(df_dataset=df_dataset, df_features=df_features)


In [7]:
labels = {"potassium(1+)", "calcium(2+)"}

# Dataset name -> set of organism ids.
dataset_name_to_organism_ids = {
    "human": {9606},
    "athaliana": {3702},
    "ecoli": {83333},
    "yeast": {559292},
}
# "all" combines the organisms of every single-organism dataset above.
dataset_name_to_organism_ids["all"] = {
    list(s)[0] for s in dataset_name_to_organism_ids.values() if len(s) == 1
}


# Planned experiments: (dataset name, substrate 1, substrate 2). Substrate
# names must match ChEBI term names exactly.
test_cases = [
    ("athaliana", "potassium(1+)", "calcium(2+)"),
    ("athaliana", "inorganic anion", "inorganic cation"),
    ("athaliana", "carboxylic acid anion", "inorganic anion"),
    ("ecoli", "carbohydrate derivative", "monosaccharide"),
    ("ecoli", "monocarboxylic acid", "amino acid"),
    ("human", "calcium(2+)", "sodium(1+)"),
    ("human", "calcium(2+)", "potassium(1+)"),
    ("human", "sodium(1+)", "potassium(1+)"),
    ("human", "inorganic anion", "inorganic cation"),
    ("yeast", "amide", "amino acid derivative"),
]

# - athaliana
#   - Ca2+ and K+
#   - inorganic anion/cation
#   - carboxylic acid anion/inorganic anion
# - ecoli
#   - carbohydrate derivative / monosaccharide
#   - monocarboxylic acid / amino acid
# - human
#   - Ca2+ / Na1+
#   - Ca2+ / K+
#   - Na+ / K+
#   - inorganic anion/cation
# - yeast
#   - amide / amino acid derivative

# TODO P05556 protein does not belong in there?

# NOTE: this loop evaluates the fixed `labels` pair on every dataset;
# `test_cases` is not used here yet.
for dataset_name, organism_ids in dataset_name_to_organism_ids.items():
    print(dataset_name)
    main(organism_ids=organism_ids, labels=labels)


human


43248
164519
60547
1995
477
474
[('monoatomic ion', 1455), ('monoatomic cation', 1138), ('inorganic cation', 1061), ('metal cation', 852), ('calcium(2+)', 356), ('potassium(1+)', 289), ('organic anion', 266), ('sodium(1+)', 239), ('monoatomic anion', 238), ('inorganic anion', 234), ('proton', 197), ('chloride', 187), ('organic acid', 186), ('carboxylic acid anion', 183), ('chemical entity', 144), ('amino acid', 111), ('L-alpha-amino acid zwitterion', 86), ('carbohydrate derivative', 78), ('sulfur molecular entity', 66), ('ion', 58), ('monocarboxylic acid', 58), ('amide', 55), ('transition element cation', 49), ('carbohydrate', 40), ('organic cation', 40), ('amino acid derivative', 38), ('organic phosphate', 35), ('nucleotide-sugar', 32), ('biomacromolecule', 32), ('nucleotide', 31), ('pyrimidine nucleotide-sugar', 30), ('dicarboxylic acid', 28), ('purine nucleotide', 27), ('monosaccharide', 27), ('glucose', 25), ('hexose', 25), ('hydrogencarbonate', 24), ('adenyl nucleotide', 24), ('pu

KeyboardInterrupt: 

## PSSM generation script

In [None]:
# # TODO delete

# dataset_name_to_organism_ids = {
#     "human": {9606},
#     "athaliana": {3702},
#     "ecoli": {83333},
#     "yeast": {559292},
# }

# test_cases = [
#     ("athaliana", "potassium(1+)", "calcium(2+)"),
#     ("athaliana", "inorganic anion", "inorganic cation"),
#     ("athaliana", "carboxylic acid anion", "inorganic anion"),
#     ("ecoli", "carbohydrate derivate", "monosaccharide"),
#     ("ecoli", "monocarboxylic acid", "amino acid"),
#     ("human", "calcium(2+)", "sodium(1+)"),
#     ("human", "calcium(2+)", "potassium(1+)"),
#     ("human", "sodium(1+)", "potassium(1+)"),
#     ("human", "inorganic anion", "inorganic cation"),
#     ("yeast", "amide", "amino acid derivative"),
# ]

# dataset_name_to_organism_ids["all"] = {
#     list(s)[0] for s in dataset_name_to_organism_ids.values() if len(s) == 1
# }

# for dataset_name, substrate1, substrate2 in test_cases:
#     organism_ids = dataset_name_to_organism_ids[dataset_name]
#     df_dataset = get_classification_task(
#         organism_ids=organism_ids,
#         labels={substrate1, substrate2},
#         clustering_threshold=70,
#     )

# df_features = get_features(df_dataset.sequence)

## Comparisons

Compare training results with: 

- Average sequence similarity
- GO term similarity
  - How many proteins do they have in common?
  - Semantic similarity?
