# Protein function prediction from embeddings

We will start with installing all necessary dependencies from sklearn etc.

In [71]:
import json
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, hamming_loss
from sklearn.multioutput import MultiOutputClassifier

We want training labels that include not only the GO term itself but also its hierarchy/ parents which is why we `extract_go_parents`

In [72]:
def extract_go_parents_recursive(go_hierarchies, term_id, collected_parents=None):
    if collected_parents is None:
        collected_parents = set()

        # Add the current term to the collected parents
    collected_parents.add(term_id)

    # Find the term in the hierarchy by its ID
    term_hierarchy = next((term for term in go_hierarchies if term["id"] == term_id), None)
    if term_hierarchy:
        parents = term_hierarchy.get("parents", [])

        # Recurse into each parent
        for parent in parents:
            parent_id = parent.get("id")
            if parent_id and parent_id not in collected_parents:
                extract_go_parents_recursive(go_hierarchies, parent_id, collected_parents)

    #print(collected_parents)
    return collected_parents

We now take the embeddings with the annotations and create an input and expected output value

In [73]:
#with open("./../embeddings/protein_data_with_embeddings_and_hierarchy.json", "r") as infile:
with open("./../embeddings/temp.json", "r") as infile:

    data = json.load(infile)


X = np.array([entry["embedding"] for entry in data])  # Embeddings
y_raw = []  # Multi-label targets

In [74]:
print(X)

[[ 0.09006091  0.03170803 -0.00533702 -0.07386952  0.0730283  -0.00972234
   0.01985308  0.05804221  0.00130264 -0.04566936]]


Now add all go-terms from the hierarchy to the output labels

In [75]:
for entry in data:
    go_terms = set(go["id"] for go in entry.get("go_annotations_with_ids", []))


    go_hierarchies = entry.get("go_hierarchies", [])
    print(go_hierarchies)
    all_terms = set()
    for term_id in go_terms:
        print(term_id)
        temp = extract_go_parents_recursive(go_hierarchies, term_id)
        #print(temp)
        all_terms.update(temp)
        #print(all_terms)

    y_raw.append(all_terms)
    print(y_raw)

[{'id': 'GO:0070330', 'name': 'aromatase activity', 'namespace': 'molecular_function', 'parents': [{'id': 'GO:0016712', 'name': 'oxidoreductase activity, acting on paired donors, with incorporation or reduction of molecular oxygen, reduced flavin or flavoprotein as one donor, and incorporation of one atom of oxygen', 'namespace': 'molecular_function', 'parents': [{'id': 'GO:0004497', 'name': 'monooxygenase activity', 'namespace': 'molecular_function', 'parents': [{'id': 'GO:0016491', 'name': 'oxidoreductase activity', 'namespace': 'molecular_function', 'parents': [{'id': 'GO:0003824', 'name': 'catalytic activity', 'namespace': 'molecular_function', 'parents': [{'id': 'GO:0003674', 'name': 'molecular_function', 'namespace': 'molecular_function', 'parents': []}]}]}]}, {'id': 'GO:0016705', 'name': 'oxidoreductase activity, acting on paired donors, with incorporation or reduction of molecular oxygen', 'namespace': 'molecular_function', 'parents': [None]}]}]}, {'id': 'GO:0020037', 'name': '

In [76]:
len(y_raw[0])

8