In [1]:
!pip install torch==2.7.0
!pip install torch-geometric
!pip install biopython
!pip install obonet
!pip install networkx
!pip install pandas
!pip install numpy
!pip install matplotlib
!pip install seaborn
!pip install scipy
!pip install scikit-learn
!pip install fair-esm

Collecting torch==2.7.0
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch==2.7.0)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch==2.7.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch==2.7.0)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch==2.7.0)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch==2.7.0)
  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.6.4.1 (from torch==2.7.0)
  Downloading nvidia_cublas_cu12-12.6.4.1-py3-no

In [2]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # avoid fragmentation
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import Bio
from Bio import SeqIO
import obonet
import gc
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import random
import esm


In [3]:
obo_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo'
fasta_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
term_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv'
taxonomy_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv'

In [4]:
LARGEST_FASTA_SEQ_LEN = 8922
ESM_EMBEDDING_DIM = 320
PCA_TARGET_DIM = 16

In [5]:
term_df = pd.read_csv(term_path, sep='\t')
term_df.head(), len(term_df["term"].unique())

(  EntryID        term aspect
 0  Q5W0B1  GO:0000785      C
 1  Q5W0B1  GO:0004842      F
 2  Q5W0B1  GO:0051865      P
 3  Q5W0B1  GO:0006275      P
 4  Q5W0B1  GO:0006513      P,
 26125)

In [6]:
taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
taxonomy_df.head()

Unnamed: 0,EntryID,taxonomyID
0,A0A0C5B5G6,9606
1,A0JNW5,9606
2,A0JP26,9606
3,A0PK11,9606
4,A1A4S6,9606


In [7]:
# def get_processed_fasta_df(fasta_data, term_df):
#     fasta_dict_list = []
#     term_set = set(term_df.tolist())
#     for fasta_seq in fasta_data:
#         entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
#         if entry in term_set:
#             fasta_dict_list.append({
#                 "EntryID": entry,
#                 "fasta_sequence": str(fasta_seq.seq)
#             })

#     return pd.DataFrame(fasta_dict_list)


def sample_tsv(tsv_df, sample_frac=0.05, random_state=42):
    """
    Read a TSV file and sample based on unique EntryID.
    Pulls in all associated rows for sampled EntryIDs.
    """
    df = tsv_df
    unique_ids = df['EntryID'].unique()
    sample_size = max(1, int(len(unique_ids) * sample_frac))
    sampled_ids = random.sample(list(unique_ids), sample_size)
    sampled_df = df[df['EntryID'].isin(sampled_ids)]
    print(f"Sampled {len(sampled_df)} rows from {len(unique_ids)} unique EntryIDs")
    return sampled_df


def get_processed_fasta_df(fasta_data):
    fasta_dict_list = []
    for fasta_seq in fasta_data:
        entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
        fasta_dict_list.append({
                "EntryID": entry,
                "fasta_sequence": str(fasta_seq.seq)
            })

    return pd.DataFrame(fasta_dict_list)





# Get Merged DF FULL (includes batching to fit embeddings into GPU and offloads to CPU)

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
model = model.to(device)
model = model.half()
model.eval()
batch_converter = alphabet.get_batch_converter()

def generate_protein_embeddings_esm_batch(seq_df, seq_col='Sequence', entryid_col='EntryID',
                                              target_dim=16, batch_size=1, use_fp16=True):
    """
    Memory-optimized ESM embedding generation for proteins.
    Processes small batches and moves embeddings to CPU immediately.
    """
    sequences = seq_df[seq_col].tolist()
    entry_ids = seq_df[entryid_col].tolist()
    print(f"sequences: {len(sequences)} entry_ids {len(entry_ids)}")

    all_embeddings = []
    # print(len(sequences))
    dtype = torch.float16 if use_fp16 else torch.float32
    for i in range(0, len(sequences), batch_size):
        try:
            batch_seqs = sequences[i:i+batch_size]
            batch_labels = entry_ids[i:i+batch_size]
            if len(batch_seqs[0]) > LARGEST_FASTA_SEQ_LEN:
                print(f"length of batch seqs is: {len(batch_seqs[0])}")
                # Split the sequence to be max length LARGEST_FASTA_SEQ_LEN (will always be len 1 batch), this helps avoid the loss of information
                curr_seq = batch_seqs[0]
                embeddings_list = []
                for i in range(0, len( batch_seqs[0]), LARGEST_FASTA_SEQ_LEN):
                  curr_seq = batch_seqs[0][i:i+LARGEST_FASTA_SEQ_LEN]
                  curr_seq_embedding = get_sequence_embedding_esm(batch_labels, [curr_seq], dtype)
                  embeddings_list.append(curr_seq_embedding)
                # Average the embeddings_list since the long sequences were broken down into smaller chunks
                seq_embeddings = np.mean(embeddings_list, axis=0)

            else:
              seq_embeddings = get_sequence_embedding_esm(batch_labels, batch_seqs, dtype)

            all_embeddings.append(seq_embeddings)
            del seq_embeddings
        except Exception as e:
            print(e)
            print(len(batch_seqs[0]))
            print(batch_seqs)


    raw_embeddings = np.vstack(all_embeddings)
    # Mitigate memory constraints
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.reset_peak_memory_stats()

    col_names = [f"prot_emb_{i}" for i in range(320)]
    emb_df = pd.DataFrame(raw_embeddings, index=entry_ids, columns=col_names)
    emb_df.index.name = entryid_col

    return emb_df


def get_sequence_embedding_esm(batch_labels, batch_seqs, dtype):
    batch_data = [(label, seq) for label, seq in zip(batch_labels, batch_seqs)]
    _, _, batch_tokens = batch_converter(batch_data)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
      with torch.autocast(device_type="cuda", dtype=dtype):
          results = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=False)
          token_embeddings = results["representations"][model.num_layers]  # (B, L, D)
          # Mean pool over sequence length
          attention_mask = batch_tokens != alphabet.padding_idx
          masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
          seq_lengths = attention_mask.sum(dim=1).unsqueeze(-1)
          seq_embeddings = (masked_embeddings.sum(dim=1) / seq_lengths).cpu().float().numpy()
          del batch_tokens, token_embeddings, masked_embeddings, results, attention_mask
          return seq_embeddings


def apply_pca_to_esm_embeddings(esm_embeddings_df:pd.DataFrame, target_dim=16):
    '''
    Helper method 
    '''
    emb_cols = [c for c in esm_embeddings_df.columns if c.startswith("prot_emb")]
    
    tensor_list = [ 
        torch.tensor(row[emb_cols].values.astype("float16"), dtype=torch.float16)
        for _, row in esm_embeddings_df.iterrows()
    ]
    
    pca = PCA(n_components=PCA_TARGET_DIM)
    tensor_list_transformed = pca.fit_transform(tensor_list)
    
    # Intialize dictionary where we use EntryID for joining with all the other data
    embeddings_dict = { "EntryID": [] }
    # Initialize empty PCA embeddings for the pd.DataFrame
    for i in range(PCA_TARGET_DIM):
        embeddings_dict[f"emb_{i}"] =  []
        
    # Append EntryID's and embeddings to the dictionary for the dataframe
    for i, r in enumerate(tensor_list_transformed):
        curr_entry_id = esm_embeddings_df["EntryID"][i]
        embeddings_dict["EntryID"].append(curr_entry_id)
        for j in range(PCA_TARGET_DIM):
            embeddings_dict[f"emb_{j}"].append(r[j])

    pca_embeddings_df = pd.DataFrame(embeddings_dict)
    return pca_embeddings_df




Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t6_8M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t6_8M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t6_8M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t6_8M_UR50D-contact-regression.pt


In [11]:
file_name = "/kaggle/input/fasta-embeddings-final/fasta_embeddings_final.csv"
embeddings_processed = True

def get_merged_df_full(file_name, batch_size=250, embeddings_processed=False):
    """
    Merge term.tsv, fasta data, taxonomy data as well as nodes in the obo graph.
    """
    term_df = pd.read_csv(term_path, sep='\t')
    term_df = term_df
    taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
    fasta_data = list(SeqIO.parse(fasta_path, "fasta"))

    entry_ids = list(term_df['EntryID'])
    all_batches = []
    if not embeddings_processed:
      for i in range(0, len(entry_ids), batch_size):
          total_processed = i
          print(f"Total processed: {total_processed}, {i//batch_size} batch")

          # Batch the EntryIDs
          entry_batch = list(set(entry_ids[i:i+batch_size]))

          curr_term_df = term_df[i:i+batch_size]
          fasta_df_batch = get_processed_fasta_df(fasta_data, curr_term_df['EntryID'])
          fasta_emb_df_batch = generate_protein_embeddings_esm_batch(
              fasta_df_batch,
              "fasta_sequence"
          )


          all_batches.append(fasta_emb_df_batch)

          full_df = pd.concat(all_batches, ignore_index=False)
          full_df.to_csv("fasta_embeddings.csv", index=True)

    else:
        full_df = pd.read_csv(file_name)
        print(full_df.head())
        
    fasta_emb_df = apply_pca_to_esm_embeddings(full_df)
    
    # TODO add embeddings in getting merged_df
    merged_df = pd.merge(term_df, fasta_emb_df, on="EntryID", how='left')
    merged_df = pd.merge(merged_df, taxonomy_df, on="EntryID", how="left")
    graph = obonet.read_obo(obo_path)
    edges_list = []
    for node_id, data in graph.nodes(data=True):
        for parent_id in data.get("is_a", []):
            edges_list.append({
                    "term": node_id,
                    "parent": parent_id,
                    "relation": data.get('relation', 'is_a'),
                    "name": data["name"],
                    "namespace": data["namespace"],
                    "def": data["def"],
                    "synonym": data.get("synonym", [])
                    
                })
    edges_df = pd.DataFrame(edges_list)
    return merged_df, edges_df

protein_function_df, graph_df = get_merged_df_full(file_name, embeddings_processed=embeddings_processed)
len(protein_function_df)

  EntryID  prot_emb_0  prot_emb_1  prot_emb_2  prot_emb_3  prot_emb_4  \
0  P86164    0.063601    0.090173    0.272553    0.045446   -0.024562   
1  P84910    0.164426   -0.128475    0.244240    0.068856   -0.031558   
2  P83012    0.126772   -0.093400    0.216583    0.056300   -0.058426   
3  P83246    0.039001   -0.228608    0.207255    0.138270    0.081464   
4  P86133   -0.063519    0.141610    0.168906   -0.003960   -0.064855   

   prot_emb_5  prot_emb_6  prot_emb_7  prot_emb_8  ...  prot_emb_310  \
0   -0.087723   -0.097689   -0.002640   -0.085523  ...      0.121198   
1   -0.102609   -0.108406   -0.071888   -0.022300  ...      0.094619   
2    0.006500   -0.052936   -0.012787   -0.090959  ...      0.024688   
3   -0.122376   -0.169814   -0.090188   -0.169230  ...     -0.015823   
4   -0.166769    0.018469    0.077215    0.058494  ...      0.147194   

   prot_emb_311  prot_emb_312  prot_emb_313  prot_emb_314  prot_emb_315  \
0      0.150969      0.038082     -0.035188      0.10

537027

In [12]:
UNIQUE_TERMS = list(protein_function_df["term"].unique())
len(UNIQUE_TERMS) 

26125

In [13]:
 graph = obonet.read_obo(obo_path)
edges_list = []
for node_id, data in graph.nodes(data=True):
        for parent_id in data.get("is_a", []):
            edges_list.append({
                    "term": node_id,
                    "parent": parent_id,
                    "name": data["name"],
                    "namespace": data["namespace"],
                    "def": data["def"],
                    "synonym": data.get("synonym", [])
                })
edges_df = pd.DataFrame(edges_list)

In [14]:
ALL_SUBONTOLOGIES = graph_df["namespace"].unique()
graph_df.head(
)


Unnamed: 0,term,parent,relation,name,namespace,def,synonym
0,GO:0000001,GO:0048308,is_a,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","[""mitochondrial inheritance"" EXACT []]"
1,GO:0000001,GO:0048311,is_a,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","[""mitochondrial inheritance"" EXACT []]"
2,GO:0000002,GO:0007005,is_a,mitochondrial genome maintenance,biological_process,"""The maintenance of the structure and integrit...",[]
3,GO:0000006,GO:0005385,is_a,high-affinity zinc transmembrane transporter a...,molecular_function,"""Enables the transfer of zinc ions (Zn2+) from...","[""high affinity zinc uptake transmembrane tran..."
4,GO:0000007,GO:0005385,is_a,low-affinity zinc ion transmembrane transporte...,molecular_function,"""Enables the transfer of a solute or solutes f...",[]


In [15]:
protein_function_df.head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606
1,Q5W0B1,GO:0004842,F,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606
2,Q5W0B1,GO:0051865,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606
3,Q5W0B1,GO:0006275,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606
4,Q5W0B1,GO:0006513,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606


In [16]:
N_LARGEST = 10
top_terms = protein_function_df['term'].value_counts().nlargest(N_LARGEST).index
protein_function_top_terms_df = protein_function_df[protein_function_df['term'].isin(top_terms)]

In [17]:
len(protein_function_top_terms_df), len(protein_function_df), len(graph_df)

(100851, 537027, 62410)

### Create set of embedding from the graph edges using GCN

In [18]:
class SimpleGCN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(SimpleGCN, self).__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x
  
def init_subontology_GCNs(embed_dim=16, hidden_dim=32, out_dim=16):
    graph_subontology_dict = {
       subontology: SimpleGCN(embed_dim, hidden_dim, out_dim) for subontology in ALL_SUBONTOLOGIES
    }
    return graph_subontology_dict

def get_subontology_graph_dfs(graph_df):
    subontology_graph_dfs = {
       subontology: graph_df[graph_df["namespace"]==subontology] for subontology in ALL_SUBONTOLOGIES
    }
    return subontology_graph_dfs
  
subontology_GCNs, subontology_graph_dfs = init_subontology_GCNs(), get_subontology_graph_dfs(graph_df)

In [19]:
protein_function_df.head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606
1,Q5W0B1,GO:0004842,F,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606
2,Q5W0B1,GO:0051865,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606
3,Q5W0B1,GO:0006275,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606
4,Q5W0B1,GO:0006513,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606


In [20]:
def group_terms_and_aspects(protein_function_df):
    protein_function_grouped_df = (
        protein_function_df
            .groupby("EntryID")
            .agg({
                "term": list,                     
                "aspect": list,                   
                **{c: "first" for c in protein_function_top_terms_df.columns 
                   if c.startswith("emb_")},      
                "taxonomyID": "first"            
            })
            .rename(columns={"term": "output_terms"})
            .reset_index()
    )
    return protein_function_grouped_df

protein_function_grouped_df = group_terms_and_aspects(protein_function_df)

In [21]:
protein_function_grouped_df.head()

Unnamed: 0,EntryID,output_terms,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,A0A023FBW4,[GO:0019958],[F],-0.733679,-0.278833,-0.73913,-0.129298,0.640625,0.608326,0.276888,0.059629,0.456924,-0.223931,0.269701,-0.287853,-0.389044,-0.08632,0.318142,0.208756,34607
1,A0A023FBW7,[GO:0019957],[F],-0.677593,-0.401437,-0.250645,0.016632,0.365204,1.118803,0.574566,-0.098997,0.70253,-0.231559,0.155171,-0.124747,0.052744,-0.021441,0.258956,0.208436,34607
2,A0A023FDY8,[GO:0019957],[F],-0.652475,-0.402101,-0.241203,0.076115,0.370869,1.122666,0.586032,-0.069793,0.735825,-0.261511,0.230517,-0.101076,0.027873,-0.013094,0.264486,0.208656,34607
3,A0A023FF81,[GO:0019958],[F],-0.550702,-0.327903,-0.61398,-0.264536,0.529165,0.625794,0.384368,-0.124693,0.377998,-0.150738,0.158927,-0.375993,-0.408175,-0.077548,0.325447,0.130539,34607
4,A0A023FFB5,[GO:0019957],[F],-0.633638,-0.376555,-0.319808,-0.28403,0.645544,1.141829,0.553369,-0.213995,0.667949,-0.399849,0.135954,-0.303759,-0.021841,-0.024281,0.311488,0.099055,34607


In [22]:
protein_function_grouped_df[protein_function_grouped_df["EntryID"] == "Q5W0B1"]

Unnamed: 0,EntryID,output_terms,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
48432,Q5W0B1,"[GO:0000785, GO:0004842, GO:0051865, GO:000627...","[C, F, P, P, P, F, F]",-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220611,0.036556,-0.221058,0.19175,0.372038,0.575382,-0.200636,-0.117404,-0.2487,-0.039927,0.179385,9606


In [23]:
protein_function_grouped_df = protein_function_grouped_df.sample(frac=1, random_state=42).reset_index(drop=True)

In [24]:
protein_function_grouped_df.head()

Unnamed: 0,EntryID,output_terms,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,P04268,"[GO:0005515, GO:0005886, GO:0008092]","[F, C, F]",0.828323,0.703457,1.160203,-0.054023,-0.939162,-1.08581,1.780389,0.795923,0.36177,0.317409,-0.095173,-0.350046,-0.512594,-0.313681,0.097373,0.833666,9031
1,Q8LAP6,"[GO:0005634, GO:0009507, GO:0009534, GO:000953...","[C, C, C, C, C]",-0.793946,-0.121416,0.264455,-0.091416,0.171004,0.079607,0.149902,-0.03493,-0.256049,0.294488,0.015121,-0.285039,-0.168807,0.164656,-0.089438,-0.23967,3702
2,Q03489,[GO:0005515],[F],-0.825974,0.277098,-0.046827,-0.155165,-0.192515,-0.353548,0.237032,0.520949,0.298848,-0.230937,-0.166167,0.181591,-0.110926,-0.025534,0.032579,-0.221545,4102
3,Q96PQ6,"[GO:0005515, GO:0005634]","[F, C]",1.056239,1.366326,-0.094833,-0.112145,-0.218979,0.265883,0.452629,0.542162,0.305502,0.255235,-0.147662,0.492993,0.423559,-0.004679,-0.289698,0.479508,9606
4,Q59RK9,"[GO:0008233, GO:0005739, GO:0042775, GO:004401...","[F, C, P, P, P]",-0.249164,-0.473922,0.326476,-0.376346,-0.044497,-0.19228,-0.052922,-0.022,-0.123031,-0.181675,0.263037,0.056715,-0.134725,0.34887,0.097481,-0.156761,237561


In [48]:
protein_function_grouped_df.columns

Index(['EntryID', 'output_terms', 'aspect', 'emb_0', 'emb_1', 'emb_2', 'emb_3',
       'emb_4', 'emb_5', 'emb_6', 'emb_7', 'emb_8', 'emb_9', 'emb_10',
       'emb_11', 'emb_12', 'emb_13', 'emb_14', 'emb_15', 'taxonomyID'],
      dtype='object')

In [207]:
from sklearn.model_selection import train_test_split
# Columns for Training and Columns for Testing
PREDICTORS = [f"emb_{i}" for i in range(PCA_TARGET_DIM)]
PREDICTORS.append("EntryID")
OUTPUTS = ['output_terms']

X, y = protein_function_grouped_df[PREDICTORS].values, protein_function_grouped_df[OUTPUTS].iloc[:, 0].tolist()

X_train_all, X_test, y_train_all, y_test = train_test_split(X, y, test_size=0.05, random_state=42)
# Perform a second split for validation set for finetuning our model
X_train, X_val, y_train, y_val = train_test_split(X_train_all, y_train_all, test_size=0.1, random_state=42)


In [208]:
X_train_entry_ids = [data_row[-1] for data_row in X_train]
X_val_entry_ids = [data_row[-1] for data_row in X_val]
X_test_entry_ids = [data_row[-1] for data_row in X_test]
X_train = np.array([data_row[:-1] for data_row in X_train])
X_val = np.array([data_row[:-1] for data_row in X_val])
X_test = np.array([data_row[:-1] for data_row in X_test])

In [209]:
y_train[:3]

[['GO:0030514',
  'GO:0045668',
  'GO:0005886',
  'GO:0005515',
  'GO:0005737',
  'GO:0016477',
  'GO:0030512',
  'GO:0045893',
  'GO:0090263',
  'GO:0005114',
  'GO:0010718',
  'GO:0008284',
  'GO:0008360',
  'GO:0005109'],
 ['GO:0005829'],
 ['GO:0005829',
  'GO:0005515',
  'GO:1905821',
  'GO:0000796',
  'GO:0007076',
  'GO:0016020']]

In [210]:
from sklearn.preprocessing import MultiLabelBinarizer
term_to_index = {term: i for i, term in enumerate(UNIQUE_TERMS)}

mlb = MultiLabelBinarizer(classes=UNIQUE_TERMS)
y_train_transformed = mlb.fit_transform(y_train)
y_val_transformed = mlb.transform(y_val)

from sklearn.impute import SimpleImputer
imputer = SimpleImputer(strategy="constant", fill_value=0)
X_train_imputed = imputer.fit_transform(X_train)
X_test_imputed = imputer.transform(X_test)
X_val_imputed = imputer.transform(X_val)

In [110]:
import xgboost as xgb
import numpy as np
from sklearn.linear_model import LogisticRegression

num_models = 100

models_dict = {}
def train_xgb_models():
    models = []
    for i in range(num_models):
        print(f"\nTraining model for label {i}...")
    
        y_i = y_train_transformed[:, i]
    
        pos = np.sum(y_i == 1)
        neg = np.sum(y_i == 0)
        scale_pos_weight = neg / pos if pos > 0 else 1.0
    
        print(f"Label {i} imbalance pos={pos}, neg={neg}, scaling weight={scale_pos_weight}")
    
        dtrain = xgb.DMatrix(X_train_imputed, label=y_i)
    
        params = {
            "objective": "binary:logistic",
            "eval_metric": "aucpr",
            "tree_method": "hist",
            "max_depth": 7,
            "eta": 0.05,
            "lambda": 1.5,
            "alpha": 0.8,
            "min_child_weight": 2,
            "subsample": 0.8,
            "colsample_bytree": 0.8,
            "scale_pos_weight": scale_pos_weight
        }
    
        model = xgb.train(
            params=params,
            dtrain=dtrain,
            num_boost_round=300,
            evals=[(dtrain, "train")],
            early_stopping_rounds=25,
            verbose_eval=50
        )
    
        models.append(model)
    return models

def train_logistic_regression_models(i):

    print(f"\nTraining model for label {i}...")
    
    y_i = y_train_transformed[:, i]
    # Sample not good
    if len(np.unique(y_i)) < 2:
        models_dict[f"lr_{i}"] = None
        return

    pos = np.sum(y_i == 1)
    neg = np.sum(y_i == 0)
    scale_pos_weight = neg / pos if pos > 0 else 1.0
    
    print(f"Label {i}")
    
    model = LogisticRegression(
        penalty="l2",
        solver="liblinear",
        class_weight={0: 1.0, 1: scale_pos_weight},
        max_iter=200,
    )
    
    model.fit(X_train_imputed, y_i)
    models_dict[f"lr_{i}"] = model



In [80]:
dval = xgb.DMatrix(X_val)


xgb_pred_list = []
for model in xgboost_models:
    print(f"model: {model}")
    xgb_pred_list.append(model.predict(dval))

xgb_preds = np.column_stack(xgb_pred_list)

lr_pred_list = []
for model in logistic_regression_models:
    print(f"model: {model}")
    lr_pred_list.append(model.predict(X_val_imputed))

lr_preds = np.column_stack(lr_pred_list)

model_predictions_dict = {
    "xgb_preds": xgb_preds,
    "lr_preds": lr_preds
}


model: <xgboost.core.Booster object at 0x7f5b9a954b90>
model: <xgboost.core.Booster object at 0x7f5b972cc550>
model: <xgboost.core.Booster object at 0x7f5b9a9cd910>
model: <xgboost.core.Booster object at 0x7f5b971b6610>
model: <xgboost.core.Booster object at 0x7f5bbae9b710>
model: <xgboost.core.Booster object at 0x7f5b972cf790>
model: <xgboost.core.Booster object at 0x7f5b9a948850>
model: <xgboost.core.Booster object at 0x7f5b981adf90>
model: <xgboost.core.Booster object at 0x7f5b974a0dd0>
model: <xgboost.core.Booster object at 0x7f5bbac25e10>
model: <xgboost.core.Booster object at 0x7f5bb9d09450>
model: <xgboost.core.Booster object at 0x7f5bec32dfd0>
model: <xgboost.core.Booster object at 0x7f5b972cf9d0>
model: <xgboost.core.Booster object at 0x7f5b972d7610>
model: <xgboost.core.Booster object at 0x7f5b972cc510>
model: <xgboost.core.Booster object at 0x7f5b9a9cd790>
model: <xgboost.core.Booster object at 0x7f5bec34bb10>
model: <xgboost.core.Booster object at 0x7f5b98cc1990>
model: <xg

In [91]:
from sklearn.metrics import accuracy_score, f1_score

for k, v in model_predictions_dict.items():
    pred_binary = (v >= 0.5).astype(int)
    
    y_true = y_val_transformed[:, :num_models]
    
    accuracy_per_label = []
    f1_per_label = []

    for i in range(len(pred_binary[0])):
        acc = accuracy_score(y_true[:, i], pred_binary[:, i])
        f1  = f1_score(y_true[:, i], pred_binary[:, i], zero_division=0)
    
        accuracy_per_label.append(acc)
        f1_per_label.append(f1)
    
        print(f"Model: {k} Label {i}:  Accuracy = {acc:.4f},   F1 = {f1:.4f}")

Model: xgb_preds Label 0:  Accuracy = 0.9599,   F1 = 0.0977
Model: xgb_preds Label 1:  Accuracy = 0.9894,   F1 = 0.3025
Model: xgb_preds Label 2:  Accuracy = 0.9954,   F1 = 0.0000
Model: xgb_preds Label 3:  Accuracy = 0.9986,   F1 = 0.0000
Model: xgb_preds Label 4:  Accuracy = 0.9987,   F1 = 0.1667
Model: xgb_preds Label 5:  Accuracy = 0.9670,   F1 = 0.1400
Model: xgb_preds Label 6:  Accuracy = 0.6647,   F1 = 0.6331
Model: xgb_preds Label 7:  Accuracy = 0.9981,   F1 = 0.1176
Model: xgb_preds Label 8:  Accuracy = 0.9157,   F1 = 0.1771
Model: xgb_preds Label 9:  Accuracy = 0.9991,   F1 = 0.4615
Model: xgb_preds Label 10:  Accuracy = 0.9990,   F1 = 0.6000
Model: xgb_preds Label 11:  Accuracy = 1.0000,   F1 = 0.0000
Model: xgb_preds Label 12:  Accuracy = 0.9891,   F1 = 0.2478
Model: xgb_preds Label 13:  Accuracy = 0.9997,   F1 = 0.5000
Model: xgb_preds Label 14:  Accuracy = 0.9940,   F1 = 0.0784
Model: xgb_preds Label 15:  Accuracy = 0.9997,   F1 = 0.0000
Model: xgb_preds Label 16:  Accura

In [102]:
from joblib import Parallel, delayed
import psutil
n_cores = psutil.cpu_count()
print(f"Available cores: {n_cores}")

Available cores: 4


In [111]:
results = Parallel(n_jobs=n_cores, verbose=10)(
    delayed(train_logistic_regression_models)(i) for i in range(200)
)

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   5 tasks      | elapsed:   10.5s
[Parallel(n_jobs=4)]: Done  10 tasks      | elapsed:   11.5s
[Parallel(n_jobs=4)]: Done  17 tasks      | elapsed:   12.8s
[Parallel(n_jobs=4)]: Done  24 tasks      | elapsed:   14.7s
[Parallel(n_jobs=4)]: Done  33 tasks      | elapsed:   16.5s
[Parallel(n_jobs=4)]: Done  42 tasks      | elapsed:   18.7s
[Parallel(n_jobs=4)]: Done  53 tasks      | elapsed:   20.9s
[Parallel(n_jobs=4)]: Done  64 tasks      | elapsed:   23.3s
[Parallel(n_jobs=4)]: Done  77 tasks      | elapsed:   25.8s
[Parallel(n_jobs=4)]: Done  90 tasks      | elapsed:   29.0s
[Parallel(n_jobs=4)]: Done 105 tasks      | elapsed:   32.1s
[Parallel(n_jobs=4)]: Done 120 tasks      | elapsed:   35.3s
[Parallel(n_jobs=4)]: Done 137 tasks      | elapsed:   39.5s
[Parallel(n_jobs=4)]: Done 154 tasks      | elapsed:   42.9s
[Parallel(n_jobs=4)]: Done 173 tasks      | elapsed:   46.8s
[Parallel(

In [113]:
# b

In [114]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

output_dim = 200

class WeightedMultiLabelNN(nn.Module):
    def __init__(self, input_dim=16, output_dim=output_dim, hidden_dims=[64, 128, 256], pos_weights=None):
        super().__init__()
        self.pos_weights = pos_weights
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.1))
            prev_dim = hidden_dim
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_dim, output_dim)
    
    def forward(self, x):
        features = self.features(x)
        logits = self.classifier(features)
        return logits


class PositionalWeightedLoss(nn.Module):
    def __init__(self, pos_weights=None, label_weights=None, reduction='mean'):
        super().__init__()
        self.pos_weights = pos_weights
        self.label_weights = label_weights
        self.reduction = reduction
        
    def forward(self, logits, targets):
        loss = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            reduction='none'
        )
        
        if self.pos_weights is not None:
            pos_mask = (targets > 0.5).float()
            pos_weight_factor = 1 + (self.pos_weights.unsqueeze(0) - 1) * pos_mask
            loss = loss * pos_weight_factor
            
        if self.label_weights is not None:
            loss = loss * self.label_weights.unsqueeze(0)
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

In [136]:
class ProteinEmbeddingsDataset(Dataset):
    def __init__(self, features, labels, total_labels=output_dim, label_metadata=None):
        self.features = features.astype(np.float32)
        self.labels = labels[:, :total_labels].astype(np.float32)
        self.label_metadata = label_metadata
        
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        x = torch.FloatTensor(self.features[idx])
        y = torch.FloatTensor(self.labels[idx])
        return x, y

In [158]:
protein_embeddings_dataset_train = ProteinEmbeddingsDataset(X_train_imputed, y_train_transformed, total_labels=output_dim)
protein_embeddings_dataset_val = ProteinEmbeddingsDataset(X_val_imputed, y_val_transformed, total_labels=output_dim)

In [167]:
from sklearn.metrics import f1_score

def evaluate_f1(model, val_loader, device, threshold=0.5):
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for features, targets in val_loader:
            features = features.to(device)
            targets = targets.to(device)

            logits = model(features)
            probs = torch.sigmoid(logits)

            preds = (probs > threshold).float()

            all_preds.append(preds.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    num_labels = all_targets.shape[1]
    f1_scores = []

    for i in range(num_labels):
        try:
            f1 = f1_score(all_targets[:, i], all_preds[:, i], zero_division=0   )
        except ValueError:
            f1 = 0.0 
        f1_scores.append(f1)

    f1_scores = np.array(f1_scores)
    best_idx = np.argsort(f1_scores)[-3:]
    worst_idx = np.argsort(f1_scores)[:3]

    return f1_scores, best_idx, worst_idx
    

In [168]:
def obtain_label_metadata(curr_y):
    label_metadata = {
    "pos_weights": []
    }
    print(len(curr_y[0]))
    for i in range(len(curr_y[0])):
        if not i % 1000: print(i)
        y_i = curr_y[:, i]
        pos = np.sum(y_i == 1)
        neg = np.sum(y_i == 0)
        scale_pos_weight = neg / pos if pos > 0 else 1.0
        label_metadata["pos_weights"].append(scale_pos_weight)
    return label_metadata


def train_model(model, train_loader, pos_weights, num_epochs=50):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    pos_weights = pos_weights.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for features, targets in train_loader:
            features = features.to(device)
            targets = targets.to(device)
            
            logits = model(features)
            loss = F.binary_cross_entropy_with_logits(
                logits,
                targets,
                pos_weight=pos_weights,
                reduction='mean'
            )
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"current eppoch {epoch+1}/{num_epochs} current loss: {total_loss/len(train_loader):.4f}")
    
    return model

In [172]:
import warnings
warnings.filterwarnings("ignore")

def train_model(model, train_loader, val_loader, pos_weights, num_epochs=50):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    pos_weights = pos_weights.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for features, targets in train_loader:
            features = features.to(device)
            targets = targets.to(device)

            logits = model(features)

            loss = F.binary_cross_entropy_with_logits(
                logits,
                targets,
                pos_weight=pos_weights,
                reduction="mean"
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}")

        # --- VALIDATION ---
        f1_scores, best_idx, worst_idx = evaluate_f1(model, val_loader, device)

        print("Top 3 labels (best F1):")
        for idx in reversed(best_idx):
            print(f"  Label {idx}: F1 = {f1_scores[idx]:.4f}")

        print("Bottom 3 labels (worst F1):")
        for idx in worst_idx:
            print(f"  Label {idx}: F1 = {f1_scores[idx]:.4f}")

        print(f"Average f1 score: {np.mean(f1_scores)}")

    return model

In [126]:
label_metadata = obtain_label_metadata(y_train_transformed)

26125
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000


In [173]:
from torch.utils.data import Dataset, DataLoader
curr_model = WeightedMultiLabelNN()
data_loader = DataLoader(protein_embeddings_dataset_train, batch_size=32)
data_loader_val = DataLoader(protein_embeddings_dataset_val)



In [174]:
trained_model_nn = train_model(curr_model, data_loader, data_loader_val, torch.tensor(label_metadata["pos_weights"][:output_dim]))

[Epoch 1/50] Loss: 2.0301
Top 3 labels (best F1):
  Label 6: F1 = 0.6066
  Label 53: F1 = 0.3716
  Label 37: F1 = 0.3145
Bottom 3 labels (worst F1):
  Label 194: F1 = 0.0000
  Label 179: F1 = 0.0000
  Label 109: F1 = 0.0000
Average f1 score: 0.02036426189012476
[Epoch 2/50] Loss: 1.4014
Top 3 labels (best F1):
  Label 6: F1 = 0.6068
  Label 53: F1 = 0.3652
  Label 37: F1 = 0.3132
Bottom 3 labels (worst F1):
  Label 17: F1 = 0.0000
  Label 80: F1 = 0.0000
  Label 81: F1 = 0.0000
Average f1 score: 0.01969134649147814
[Epoch 3/50] Loss: 1.2975
Top 3 labels (best F1):
  Label 6: F1 = 0.6132
  Label 53: F1 = 0.3634
  Label 37: F1 = 0.3266
Bottom 3 labels (worst F1):
  Label 194: F1 = 0.0000
  Label 17: F1 = 0.0000
  Label 114: F1 = 0.0000
Average f1 score: 0.020471092338956373
[Epoch 4/50] Loss: 1.2272
Top 3 labels (best F1):
  Label 6: F1 = 0.6038
  Label 53: F1 = 0.3725
  Label 37: F1 = 0.3223
Bottom 3 labels (worst F1):
  Label 11: F1 = 0.0000
  Label 194: F1 = 0.0000
  Label 182: F1 = 0

In [112]:
subontology_GCNs, subontology_graph_dfs

({'biological_process': SimpleGCN(
    (conv1): GCNConv(16, 32)
    (conv2): GCNConv(32, 16)
  ),
  'molecular_function': SimpleGCN(
    (conv1): GCNConv(16, 32)
    (conv2): GCNConv(32, 16)
  ),
  'cellular_component': SimpleGCN(
    (conv1): GCNConv(16, 32)
    (conv2): GCNConv(32, 16)
  )},
 {'biological_process':              term      parent relation                              name  \
  0      GO:0000001  GO:0048308     is_a         mitochondrion inheritance   
  1      GO:0000001  GO:0048311     is_a         mitochondrion inheritance   
  2      GO:0000002  GO:0007005     is_a  mitochondrial genome maintenance   
  7      GO:0000011  GO:0007033     is_a               vacuole inheritance   
  8      GO:0000011  GO:0048308     is_a               vacuole inheritance   
  ...           ...         ...      ...                               ...   
  62403  GO:2001316  GO:0120254     is_a      kojic acid metabolic process   
  62404  GO:2001317  GO:0034309     is_a   kojic acid biosy


Training model for label 3...
Label 3

Training model for label 6...
Label 6

Training model for label 8...
Label 8

Training model for label 12...
Label 12

Training model for label 16...
Label 16

Training model for label 19...
Label 19

Training model for label 21...
Label 21

Training model for label 26...
Label 26

Training model for label 30...
Label 30

Training model for label 34...
Label 34

Training model for label 38...
Label 38

Training model for label 42...
Label 42

Training model for label 47...
Label 47

Training model for label 51...
Label 51

Training model for label 56...
Label 56

Training model for label 60...
Label 60

Training model for label 64...
Label 64

Training model for label 68...
Label 68

Training model for label 73...
Label 73

Training model for label 77...
Label 77

Training model for label 81...
Label 81

Training model for label 85...
Label 85

Training model for label 89...
Label 89

Training model for label 93...
Label 93

Training model for la

In [None]:
def create_go_embeddings_optimized(obo_path, go_terms, embed_dim=16, hidden_dim=32, out_dim=16, epochs=50):

    print(" Loading Gene Ontology...")
    graph = obonet.read_obo(obo_path)


    edges = pd.DataFrame([
        {'source': u, 'target': v, 'relation': data.get('relation', 'is_a')}
        for u, v, data in graph.edges(data=True)
    ])

    relevant_edges = edges[
        edges['source'].isin(go_terms) | edges['target'].isin(go_terms)
    ].reset_index(drop=True)

    nodes = pd.DataFrame({'id': list(set(relevant_edges['source']).union(relevant_edges['target']))})
    nodes['node_idx'] = range(len(nodes))
    node2idx = dict(zip(nodes['id'], nodes['node_idx']))

    edge_index = torch.tensor([
        [node2idx[s] for s in relevant_edges['source']],
        [node2idx[t] for t in relevant_edges['target']]
    ], dtype=torch.long)

    num_nodes = len(nodes)
    print(f"Using {num_nodes} GO terms and {len(relevant_edges)} edges")


    x = torch.randn((num_nodes, embed_dim), dtype=torch.float32)


    class SimpleGCN(nn.Module):
        def __init__(self, in_dim, hidden_dim, out_dim):
            super(SimpleGCN, self).__init__()
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, out_dim)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = self.conv2(x, edge_index)
            return x

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGCN(embed_dim, hidden_dim, out_dim).to(device)


    x = x.to(device)
    edge_index = edge_index.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    data = Data(x=x, edge_index=edge_index)

    print(f"Training on device: {device}")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        embeddings = model(data.x, data.edge_index)
        # Inner product decoder
        recon = torch.sigmoid(torch.matmul(embeddings, embeddings.T))
        adj_true = torch.zeros_like(recon)
        adj_true[data.edge_index[0], data.edge_index[1]] = 1.0

        loss = F.binary_cross_entropy(recon, adj_true)
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch:03d} | Loss: {loss.item():.4f}")


    with torch.no_grad():
        node_embeddings = model(data.x, data.edge_index).cpu().numpy()

    del model, x, data, recon, adj_true
    torch.cuda.empty_cache()
    gc.collect()

    col_names = [f"go_emb_{i}" for i in range(node_embeddings.shape[1])]

    emb_df = pd.DataFrame(node_embeddings, index=nodes['id'], columns=col_names)
    print(f"Created embeddings for {len(emb_df)} GO terms")
    return emb_df





In [None]:
sampled_data = sample_tsv(term_path, sample_frac=0.05)
go_terms = sampled_data['term'].unique()
embeddings_df = create_go_embeddings_optimized(obo_path, go_terms)
seq_df = extract_sequences(fasta_path, sampled_data['EntryID'])

### Combine GO embedding and PLM embedding into one dataset

In [None]:
def combine_go_protein_embeddings(sampled_data, go_emb_df, prot_emb_df):

    combined = sampled_data.merge(go_emb_df, how='left', left_on='term', right_index=True)

    combined = combined.merge(prot_emb_df, how='left', left_on='EntryID', right_index=True)

    return combined


multimodal_df = combine_go_protein_embeddings(sampled_data, embeddings_df, prot_emb_df)

print("Multimodal feature dataframe shape:", multimodal_df.shape)
print(multimodal_df.head())