## Generate labels



In [49]:
import json
import pickle
import os
import pandas as pd
# Collect the GO classes and proteins that every model has predictions for,
# and persist them so the preprocessing cells below can filter on them.

def _load_pickle(path):
    """Load one pickle file and close its handle before returning.

    NOTE(review): pickle can execute arbitrary code while loading; only point
    this at files produced by trusted pipelines.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# DeepFri
with open("model_outputs/DeepFri/bp/_BP_pred_scores.json", 'r') as f:
    deepfri_bp_dict = json.load(f)
deepfri_bp_classes = deepfri_bp_dict['goterms']
deepfri_proteins = deepfri_bp_dict['pdb_chains']

with open("model_outputs/DeepFri/mf/_MF_pred_scores.json", 'r') as f:
    deepfri_mf_dict = json.load(f)
deepfri_mf_classes = deepfri_mf_dict['goterms']

with open("model_outputs/DeepFri/cc/_CC_pred_scores.json", 'r') as f:
    deepfri_cc_dict = json.load(f)
deepfri_cc_classes = deepfri_cc_dict['goterms']


# HEAL
# One per-protein file is used as the reference for the GO terms HEAL predicts.
heal_bp_classes = list(_load_pickle("model_outputs/HEAL/bp/1A0P-A.pkl").keys())
heal_mf_classes = list(_load_pickle("model_outputs/HEAL/mf/1A0P-A.pkl").keys())
heal_cc_classes = list(_load_pickle("model_outputs/HEAL/cc/1A0P-A.pkl").keys())

# Only .pkl entries are protein outputs; this matches what process_heal_outputs reads.
heal_proteins = [
    name.split(".")[0]
    for name in os.listdir('model_outputs/HEAL/bp')
    if name.endswith('.pkl')
]

# PFresGO
# The BP results are loaded once and reused for both the GO terms and the proteins.
pfresgo_bp_results = _load_pickle("model_outputs/PFresGO/BP_PFresGO_results.pckl")
pfresgo_bp_classes = list(pfresgo_bp_results['goterms'])
pfresgo_mf_classes = list(_load_pickle("model_outputs/PFresGO/MF_PFresGO_results.pckl")['goterms'])
pfresgo_cc_classes = list(_load_pickle("model_outputs/PFresGO/CC_PFresGO_results.pckl")['goterms'])

pfresgo_proteins = list(pfresgo_bp_results['proteins'])


# Sorting makes the saved lists identical across runs; plain list(set(...)) order
# depends on string hashing, which changes between interpreter sessions.
valid_bp_classes = sorted(set(deepfri_bp_classes) & set(heal_bp_classes) & set(pfresgo_bp_classes))
valid_mf_classes = sorted(set(deepfri_mf_classes) & set(heal_mf_classes) & set(pfresgo_mf_classes))
valid_cc_classes = sorted(set(deepfri_cc_classes) & set(heal_cc_classes) & set(pfresgo_cc_classes))

#GoBERT
gobert_proteins = list(_load_pickle("model_outputs/GoBERT/processed/GoBERT_BP_logits.pkl").keys())
# Find overlapping proteins
valid_proteins = sorted(set(deepfri_proteins) & set(heal_proteins) & set(pfresgo_proteins) & set(gobert_proteins))

# Summary of results
print(f"Number of valid BP classes: {len(valid_bp_classes)}")
print(f"Number of valid MF classes: {len(valid_mf_classes)}")
print(f"Number of valid CC classes: {len(valid_cc_classes)}")
print(f"Number of valid proteins: {len(valid_proteins)}")

valid_dict = {
    "valid_proteins": valid_proteins,
    "valid_bp_classes": valid_bp_classes,
    "valid_mf_classes": valid_mf_classes,
    "valid_cc_classes": valid_cc_classes
}

output_path = "PDB_test_set/valid_classes_and_proteins.pkl"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as pkl_file:
    pickle.dump(valid_dict, pkl_file)

print(f"Valid entities saved to: {output_path}")

# Build binary ground-truth label tables (protein -> GO term -> 0/1) from the
# annotation TSV, restricted to the valid proteins and GO classes, and save them.
data_path = "PDB_test_set/nrPDB-GO_2019.06.18_annot.tsv"
output_dir = "PDB_test_set"

# Initialize dictionaries with all zeros
bp_labels = {protein: {go: 0 for go in valid_bp_classes} for protein in valid_proteins}
cc_labels = {protein: {go: 0 for go in valid_cc_classes} for protein in valid_proteins}
mf_labels = {protein: {go: 0 for go in valid_mf_classes} for protein in valid_proteins}

# Load the TSV annotation file
annotations = pd.read_csv(data_path, sep='\t', header=None,
                           names=["PDB-chain", "GO-terms (molecular_function)", 
                                  "GO-terms (biological_process)", "GO-terms (cellular_component)"])

# Hash-based lookups, built once: the original tested membership against lists
# for every annotation row.
valid_protein_set = set(valid_proteins)

# (annotation column, allowed GO terms, label table to update) per ontology
label_targets = (
    ("GO-terms (molecular_function)", set(valid_mf_classes), mf_labels),
    ("GO-terms (biological_process)", set(valid_bp_classes), bp_labels),
    ("GO-terms (cellular_component)", set(valid_cc_classes), cc_labels),
)

# Update the labels dictionaries based on the annotation file
annotated_valid = annotations[annotations["PDB-chain"].isin(valid_protein_set)]
for _, row in annotated_valid.iterrows():
    protein = row["PDB-chain"]
    for column, allowed_terms, labels in label_targets:
        cell = row[column]
        # Missing cells are not strings (pandas reads them as NaN), so skip them.
        if isinstance(cell, str):
            for go in allowed_terms.intersection(cell.split(',')):
                labels[protein][go] = 1

# Save the label dictionaries
os.makedirs(output_dir, exist_ok=True)
for file_name, labels in (
    ("BP_labels.pkl", bp_labels),
    ("CC_labels.pkl", cc_labels),
    ("MF_labels.pkl", mf_labels),
):
    with open(os.path.join(output_dir, file_name), 'wb') as label_file:
        pickle.dump(labels, label_file)

print(f"Label dictionaries saved to {output_dir}.")




Number of valid BP classes: 1927
Number of valid MF classes: 478
Number of valid CC classes: 316
Number of valid proteins: 3409
Valid entities saved to: PDB_test_set/valid_classes_and_proteins.pkl
Label dictionaries saved to PDB_test_set.


## Preprocess DeepFri output


In [50]:
import json
import pickle
import os
from tqdm import tqdm 

def process_deepfri_scores(input_json_path, output_pkl_path, valid_proteins, valid_classes):
    """
    Processes the DeepFri JSON file and saves the results as a dictionary in a .pkl file.
    Filters based on valid proteins and classes.

    The saved object maps each kept PDB chain to a ``{GO term: score}`` dict,
    where the score is ``Y_hat[row][column]`` for that chain and GO term.

    Args:
        input_json_path (str): Path to the input JSON file. It must contain the
            keys 'pdb_chains', 'Y_hat' and 'goterms'.
        output_pkl_path (str): Path to save the processed .pkl file. Missing
            parent directories are created.
        valid_proteins (list): List of valid proteins to include.
        valid_classes (list): List of valid GO classes to include.
    """
    with open(input_json_path, 'r') as f:
        deepfri_dict = json.load(f)

    # Extract relevant keys
    pdb_chains = deepfri_dict['pdb_chains']
    y_hat = deepfri_dict['Y_hat']
    goterms = deepfri_dict['goterms']

    # Sets give constant-time membership tests; scanning the lists for every
    # (protein, GO term) pair made the filter needlessly slow.
    valid_protein_set = set(valid_proteins)
    valid_class_set = set(valid_classes)

    # The kept columns are the same for every protein, so resolve them once.
    kept_columns = [(gidx, goterm) for gidx, goterm in enumerate(goterms) if goterm in valid_class_set]

    # Filter proteins and classes
    filtered_dict = {
        chain: {goterm: y_hat[idx][gidx] for gidx, goterm in kept_columns}
        for idx, chain in tqdm(enumerate(pdb_chains), total=len(pdb_chains))
        if chain in valid_protein_set
    }

    # Save the processed dictionary as a .pkl file
    output_dir = os.path.dirname(output_pkl_path)
    if output_dir:  # a bare filename has no directory part to create
        os.makedirs(output_dir, exist_ok=True)
    with open(output_pkl_path, 'wb') as pkl_file:
        pickle.dump(filtered_dict, pkl_file)

    # Report 0 classes instead of raising IndexError when nothing passed the filter.
    n_classes = len(next(iter(filtered_dict.values()))) if filtered_dict else 0
    print(f"Processed data saved to: {output_pkl_path} with {len(filtered_dict)} proteins and {n_classes} GO classes.")

# Paths for DeepFri files
file_paths = {
    "bp": {
        "input": "model_outputs/DeepFri/bp/_BP_pred_scores.json",
        "output": "model_outputs/DeepFri/processed/DeepFri_BP_logits.pkl"
    },
    "cc": {
        "input": "model_outputs/DeepFri/cc/_CC_pred_scores.json",
        "output": "model_outputs/DeepFri/processed/DeepFri_CC_logits.pkl"
    },
    "mf": {
        "input": "model_outputs/DeepFri/mf/_MF_pred_scores.json",
        "output": "model_outputs/DeepFri/processed/DeepFri_MF_logits.pkl"
    }
}

# Load valid classes and proteins (context manager so the handle is closed)
with open("PDB_test_set/valid_classes_and_proteins.pkl", "rb") as valid_file:
    valid_dict = pickle.load(valid_file)
valid_proteins = valid_dict["valid_proteins"]
valid_bp_classes = valid_dict["valid_bp_classes"]
valid_mf_classes = valid_dict["valid_mf_classes"]
valid_cc_classes = valid_dict["valid_cc_classes"]

# Process each file with constraints, in the order BP, CC, MF
for ontology, ontology_classes in (
    ("bp", valid_bp_classes),
    ("cc", valid_cc_classes),
    ("mf", valid_mf_classes),
):
    process_deepfri_scores(
        file_paths[ontology]['input'],
        file_paths[ontology]['output'],
        valid_proteins,
        ontology_classes,
    )

print("Processing with constraints completed.")


100%|██████████| 3410/3410 [01:33<00:00, 36.57it/s]


Processed data saved to: model_outputs/DeepFri/processed/DeepFri_BP_logits.pkl with 3409 proteins and 1927 GO classes.


100%|██████████| 3410/3410 [00:02<00:00, 1321.25it/s]


Processed data saved to: model_outputs/DeepFri/processed/DeepFri_CC_logits.pkl with 3409 proteins and 316 GO classes.


100%|██████████| 3410/3410 [00:06<00:00, 504.96it/s]


Processed data saved to: model_outputs/DeepFri/processed/DeepFri_MF_logits.pkl with 3409 proteins and 478 GO classes.
Processing with constraints completed.


## Preprocess HEAL output

In [51]:
import os
import pickle
from tqdm import tqdm
def process_heal_outputs(input_folder, output_pkl_path, valid_proteins, valid_classes):
    """
    Processes the HEAL outputs stored as separate .pkl files and combines them
    into a dictionary where the key is the pdb_chain and the value is a nested
    dictionary with GO terms and logits. Filters based on valid proteins and classes.

    The pdb_chain is the file name up to its first '.', and only files ending
    in '.pkl' are read. A protein whose file has no valid GO class is left out.

    Args:
        input_folder (str): Path to the folder containing HEAL .pkl files.
        output_pkl_path (str): Path to save the combined dictionary as a .pkl file.
            Missing parent directories are created.
        valid_proteins (list): List of valid proteins to include.
        valid_classes (list): List of valid GO classes to include.
    """
    # Sets give constant-time membership tests instead of list scans.
    valid_protein_set = set(valid_proteins)
    valid_class_set = set(valid_classes)

    combined_dict = {}

    # Iterate over all .pkl files in the input folder; sorted so the output
    # order does not depend on the order os.listdir happens to return.
    for filename in tqdm(sorted(os.listdir(input_folder))):
        if not filename.endswith('.pkl'):
            continue

        pdb_chain = filename.split('.')[0]  # Extract pdb_chain (e.g., '1A0P-A')
        if pdb_chain not in valid_protein_set:
            continue

        # Load the .pkl file, closing each handle rather than leaking one per protein.
        # NOTE(review): pickle executes arbitrary code on load; inputs must be trusted.
        with open(os.path.join(input_folder, filename), "rb") as protein_file:
            data = pickle.load(protein_file)

        # Filter data for valid classes
        filtered_data = {goterm: logit for goterm, logit in data.items() if goterm in valid_class_set}

        if filtered_data:  # Add only if there are valid classes
            combined_dict[pdb_chain] = filtered_data

    # Save the combined dictionary as a .pkl file
    output_dir = os.path.dirname(output_pkl_path)
    if output_dir:  # a bare filename has no directory part to create
        os.makedirs(output_dir, exist_ok=True)
    with open(output_pkl_path, 'wb') as pkl_file:
        pickle.dump(combined_dict, pkl_file)

    # Report 0 classes instead of raising IndexError when nothing passed the filter.
    n_classes = len(next(iter(combined_dict.values()))) if combined_dict else 0
    print(f"Processed HEAL data saved to: {output_pkl_path} with {len(combined_dict)} proteins and {n_classes} GO classes.")

# Define input folders and output paths: ontology -> (input folder, output .pkl)
heal_paths = {
    "bp": ("model_outputs/HEAL/bp", "model_outputs/HEAL/processed/HEAL_BP_logits.pkl"),
    "cc": ("model_outputs/HEAL/cc", "model_outputs/HEAL/processed/HEAL_CC_logits.pkl"),
    "mf": ("model_outputs/HEAL/mf", "model_outputs/HEAL/processed/HEAL_MF_logits.pkl")
}

# Load valid classes and proteins (context manager so the handle is closed)
with open("PDB_test_set/valid_classes_and_proteins.pkl", "rb") as valid_file:
    valid_dict = pickle.load(valid_file)
valid_proteins = valid_dict["valid_proteins"]
valid_bp_classes = valid_dict["valid_bp_classes"]
valid_mf_classes = valid_dict["valid_mf_classes"]
valid_cc_classes = valid_dict["valid_cc_classes"]

# Process HEAL outputs for each category, in the order BP, CC, MF
for ontology, ontology_classes in (
    ("bp", valid_bp_classes),
    ("cc", valid_cc_classes),
    ("mf", valid_mf_classes),
):
    print(f"Processing HEAL {ontology.upper()} outputs...")
    input_folder, output_pkl_path = heal_paths[ontology]
    process_heal_outputs(input_folder, output_pkl_path, valid_proteins, ontology_classes)



Processing HEAL BP outputs...


100%|██████████| 3410/3410 [00:38<00:00, 89.73it/s]


Processed HEAL data saved to: model_outputs/HEAL/processed/HEAL_BP_logits.pkl with 3409 proteins and 1927 GO classes.
Processing HEAL CC outputs...


100%|██████████| 3410/3410 [00:01<00:00, 2751.84it/s]


Processed HEAL data saved to: model_outputs/HEAL/processed/HEAL_CC_logits.pkl with 3409 proteins and 316 GO classes.
Processing HEAL MF outputs...


100%|██████████| 3410/3410 [00:02<00:00, 1270.51it/s]


Processed HEAL data saved to: model_outputs/HEAL/processed/HEAL_MF_logits.pkl with 3409 proteins and 478 GO classes.


## Preprocess PFresGO output

In [52]:
import os
import pickle

def process_pfresgo_scores(input_pkl_path, output_pkl_path, valid_proteins, valid_classes):
    """
    Processes the PFresGO .pckl file and saves the results as a dictionary in a .pkl file.
    Filters based on valid proteins and classes.

    The saved object maps each kept protein to a ``{GO term: score}`` dict,
    where the score is ``Y_pred[row][column]`` for that protein and GO term.
    A protein is left out when none of its GO terms is valid.

    Args:
        input_pkl_path (str): Path to the input .pckl file. The pickled object
            must provide the keys 'proteins', 'Y_pred' and 'goterms'.
        output_pkl_path (str): Path to save the processed .pkl file. Missing
            parent directories are created.
        valid_proteins (list): List of valid proteins to include.
        valid_classes (list): List of valid GO classes to include.
    """
    # Load the PFresGO data, closing the handle afterwards.
    # NOTE(review): pickle executes arbitrary code on load; inputs must be trusted.
    with open(input_pkl_path, "rb") as results_file:
        data = pickle.load(results_file)

    # Extract relevant keys
    proteins = data['proteins']
    y_hat = data['Y_pred']
    goterms = data['goterms']

    # Sets give constant-time membership tests instead of list scans.
    valid_protein_set = set(valid_proteins)
    valid_class_set = set(valid_classes)

    # The kept columns are the same for every protein, so resolve them once.
    kept_columns = [(gidx, goterm) for gidx, goterm in enumerate(goterms) if goterm in valid_class_set]

    # Create the nested dictionary
    processed_dict = {}
    for idx, protein in tqdm(enumerate(proteins), total=len(proteins)):
        if protein not in valid_protein_set:
            continue

        filtered_data = {goterm: y_hat[idx][gidx] for gidx, goterm in kept_columns}

        if filtered_data:  # Add only if there are valid classes
            processed_dict[protein] = filtered_data

    # Save the processed dictionary as a .pkl file
    output_dir = os.path.dirname(output_pkl_path)
    if output_dir:  # a bare filename has no directory part to create
        os.makedirs(output_dir, exist_ok=True)
    with open(output_pkl_path, 'wb') as pkl_file:
        pickle.dump(processed_dict, pkl_file)

    # Report 0 classes instead of raising IndexError when nothing passed the filter.
    n_classes = len(next(iter(processed_dict.values()))) if processed_dict else 0
    print(f"Processed data saved to: {output_pkl_path} with {len(processed_dict)} proteins and {n_classes} GO classes.")

# Paths for PFresGO files
file_paths = {
    "bp": {
        "input": "model_outputs/PFresGO/BP_PFresGO_results.pckl",
        "output": "model_outputs/PFresGO/processed/PFresGO_BP_logits.pkl"
    },
    "cc": {
        "input": "model_outputs/PFresGO/CC_PFresGO_results.pckl",
        "output": "model_outputs/PFresGO/processed/PFresGO_CC_logits.pkl"
    },
    "mf": {
        "input": "model_outputs/PFresGO/MF_PFresGO_results.pckl",
        "output": "model_outputs/PFresGO/processed/PFresGO_MF_logits.pkl"
    }
}

# Load valid classes and proteins (context manager so the handle is closed)
with open("PDB_test_set/valid_classes_and_proteins.pkl", "rb") as valid_file:
    valid_dict = pickle.load(valid_file)
valid_proteins = valid_dict["valid_proteins"]
valid_bp_classes = valid_dict["valid_bp_classes"]
valid_mf_classes = valid_dict["valid_mf_classes"]
valid_cc_classes = valid_dict["valid_cc_classes"]

# Process each file, in the order BP, CC, MF
for ontology, ontology_classes in (
    ("bp", valid_bp_classes),
    ("cc", valid_cc_classes),
    ("mf", valid_mf_classes),
):
    print(f"Processing PFresGO {ontology.upper()} outputs...")
    process_pfresgo_scores(
        file_paths[ontology]["input"],
        file_paths[ontology]["output"],
        valid_proteins,
        ontology_classes,
    )

print("Processing with constraints completed.")


Processing PFresGO BP outputs...


100%|██████████| 3416/3416 [00:37<00:00, 90.86it/s]


Processed data saved to: model_outputs/PFresGO/processed/PFresGO_BP_logits.pkl with 3409 proteins and 1927 GO classes.
Processing PFresGO CC outputs...


100%|██████████| 3416/3416 [00:01<00:00, 2275.96it/s]


Processed data saved to: model_outputs/PFresGO/processed/PFresGO_CC_logits.pkl with 3409 proteins and 316 GO classes.
Processing PFresGO MF outputs...


100%|██████████| 3416/3416 [00:02<00:00, 1175.84it/s]


Processed data saved to: model_outputs/PFresGO/processed/PFresGO_MF_logits.pkl with 3409 proteins and 478 GO classes.
Processing with constraints completed.


## Check GoBERT output

In [53]:
# Sanity check: report how many proteins and GO classes each processed GoBERT file holds.
for i in ['BP', 'MF', 'CC']:
    # Context manager so the file handle is closed after each load.
    with open(f"model_outputs/GoBERT/processed/GoBERT_{i}_logits.pkl", "rb") as logits_file:
        data = pickle.load(logits_file)
    # Class count is taken from the first protein; an empty file reports 0 instead of raising.
    n_classes = len(next(iter(data.values()))) if data else 0
    print(f"Checking GoBERT logits for {i}, there are {len(data)} proteins and {n_classes} GO classes")


Checking GoBERT logits for BP, there are 3409 proteins and 193 GO classes
Checking GoBERT logits for MF, there are 3409 proteins and 48 GO classes
Checking GoBERT logits for CC, there are 3409 proteins and 33 GO classes
