# Protein function prediction from embeddings

We start by importing the required dependencies: `json`, NumPy, and the scikit-learn utilities used for splitting, label encoding, modelling and evaluation.

In [1]:
import json
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, hamming_loss
from sklearn.multioutput import MultiOutputClassifier

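We want the training labels to include not only each annotated GO term itself but also all of its ancestors in the GO hierarchy. `get_parents` looks up the direct parents of a term, and `extract_go_parents_recursive` follows those parent links upwards to collect the full set.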

In [2]:
def get_parents(term_id, hierarchy):
    """Return the ids of the direct parents of ``term_id`` within a nested hierarchy.

    ``hierarchy`` is a list of node dicts. Each node carries an ``"id"`` and,
    optionally, a ``"parents"`` list whose items are nodes of the same shape.
    The tree is searched depth-first; whenever a node whose id equals
    ``term_id`` is found, the ids of its immediate parents are collected.
    The search does not descend below a matching node.

    Args:
        term_id: Identifier of the term whose parents are wanted.
        hierarchy: List of node dicts. ``None``, an empty list and falsy
            entries inside the list are tolerated and contribute nothing.

    Returns:
        set: Parent ids of ``term_id`` (empty when the term is not found or
        has no parents). Parents without an ``"id"`` are skipped.
    """
    parents_of_term = set()
    for element in hierarchy or []:
        if not element:
            continue
        # A leaf node may omit "parents" or store None; treat both as "no parents"
        # in *both* branches so the lookup never raises on such nodes.
        parents = element.get("parents") or []
        if element.get("id") == term_id:
            for parent in parents:
                if parent:
                    parent_id = parent.get("id")
                    if parent_id is not None:
                        parents_of_term.add(parent_id)
        else:
            # Not this node: keep looking further up its parent chain.
            parents_of_term.update(get_parents(term_id, parents))
    return parents_of_term

def extract_go_parents_recursive(go_hierarchies, term_id, collected_parents=None):
    """Gather ``term_id`` together with every ancestor reachable through ``get_parents``.

    Args:
        go_hierarchies: Nested hierarchy structure handed unchanged to ``get_parents``.
        term_id: Term at which the upward walk starts; it is always part of the result.
        collected_parents: Optional set to accumulate into. It is modified in place,
            and ids already present in it are not expanded again.

    Returns:
        set: The accumulator (a new set when none was supplied) containing the
        start term and all ancestor ids found.
    """
    seen = set() if collected_parents is None else collected_parents

    # The start term is recorded and expanded unconditionally.
    seen.add(term_id)
    pending = [term_id]

    # Explicit work-list instead of call-stack recursion: each id is expanded once.
    while pending:
        current = pending.pop()
        for ancestor_id in get_parents(current, go_hierarchies):
            # Skip empty ids and anything that has already been collected.
            if not ancestor_id or ancestor_id in seen:
                continue
            seen.add(ancestor_id)
            pending.append(ancestor_id)

    return seen

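We now load the embeddings together with their annotations and build the model input: the feature matrix `X` (one embedding per entry) and `y_raw`, the list that will hold the expected output labels for each entry.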

In [3]:
# Load the annotated records. Each entry must provide an "embedding" vector; the GO
# annotation fields are read in a later cell.
# JSON is UTF-8 by specification, so the encoding is pinned instead of relying on the
# platform's default text encoding.
with open(
    "./../embeddings/protein_data_with_embeddings_and_hierarchy.json",
    "r",
    encoding="utf-8",
) as infile:
    data = json.load(infile)

# Feature matrix: one row per entry in `data`, built from its embedding vector.
X = np.array([entry["embedding"] for entry in data])
# Multi-label targets: one set of GO term ids per entry, filled in a later cell.
y_raw = []

In [4]:
# Sanity check: print the embedding matrix (NumPy abbreviates large arrays with "...").
print(X)

[[ 0.09006091  0.03170803 -0.00533702 ... -0.04390722 -0.02989083
   0.0003081 ]
 [ 0.02427036  0.07759918  0.01994916 ...  0.06509287 -0.02460103
   0.0864621 ]
 [-0.00220569  0.00289085  0.01496955 ...  0.00796726  0.01389519
   0.01796642]
 ...
 [-0.02916173 -0.00513361 -0.05528961 ... -0.04723859 -0.00036292
   0.00650729]
 [ 0.04925685  0.00979528  0.00251721 ... -0.04029406 -0.03207516
   0.03731419]
 [ 0.08649088  0.02324473  0.02209109 ... -0.0716107  -0.07024817
  -0.01152768]]


Now add all GO terms from the hierarchy (each annotated term plus its ancestors) to the output labels.

In [5]:
# Rebuild the label list from scratch so that re-running this cell does not append a
# second copy of every label set (which would leave y_raw longer than X).
y_raw = []

for entry in data:
    # Ids of the GO terms the entry is directly annotated with.
    go_terms = {go["id"] for go in entry.get("go_annotations_with_ids") or []}
    go_hierarchies = entry.get("go_hierarchies") or []

    # Expand every annotated term with all of its ancestors from the hierarchy.
    all_terms = set()
    for term_id in go_terms:
        all_terms.update(extract_go_parents_recursive(go_hierarchies, term_id))

    y_raw.append(all_terms)

In [6]:
# Size of the label set of the first entry (annotated GO terms plus collected ancestors).
len(y_raw[0])

17

In [7]:
# Encode the list of GO-term sets as a binary indicator matrix (one column per distinct term).
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(y_raw)

# Hold out 20% of the samples for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# MultiOutputClassifier fits one copy of the base estimator per label column.
base_model = RandomForestClassifier(random_state=42, n_estimators=100)
classifier = MultiOutputClassifier(estimator=base_model)

In [9]:
# mlb.classes_ lists the distinct GO terms found in y_raw, i.e. one entry per output column of y.
print(f"Number of GO terms (labels): {len(mlb.classes_)}")

Number of GO terms (labels): 4975


The code below fits one random forest per GO term (4975 labels), which is computationally expensive, so it needs to be executed on the cluster:

In [None]:
# --- Fit: one estimator per GO-term column of y_train ---
print("Training the classifier...")
classifier.fit(X_train, y_train)

# --- Predict on the held-out split ---
print("Evaluating the classifier...")
y_pred = classifier.predict(X_test)

# --- Per-label precision / recall / F1, labelled with the GO term ids ---
print("Classification Report:")
report_text = classification_report(y_test, y_pred, target_names=mlb.classes_)
print(report_text)

# --- Fraction of individual label assignments that are wrong ---
hamming = hamming_loss(y_test, y_pred)
print("Hamming Loss:", hamming)