In [1]:
!pip install torch==2.7.0
!pip install torch-geometric
!pip install biopython
!pip install obonet
!pip install networkx
!pip install pandas
!pip install numpy
!pip install matplotlib
!pip install seaborn
!pip install scipy
!pip install scikit-learn
!pip install fair-esm

Collecting torch==2.7.0
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch==2.7.0)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch==2.7.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch==2.7.0)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch==2.7.0)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch==2.7.0)
  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.6.4.1 (from torch==2.7.0)
  Downloading nvidia_cublas_cu12-12.6.4.1-py3-no

In [4]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # avoid fragmentation
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import Bio
from Bio import SeqIO
import obonet
import gc
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import random
import esm


In [5]:
obo_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo'
fasta_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
term_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv'
taxonomy_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv'

In [6]:
LARGEST_FASTA_SEQ_LEN = 8922
ESM_EMBEDDING_DIM = 320
PCA_TARGET_DIM = 16

In [7]:
term_df = pd.read_csv(term_path, sep='\t')
term_df.head(), len(term_df["term"].unique())

(  EntryID        term aspect
 0  Q5W0B1  GO:0000785      C
 1  Q5W0B1  GO:0004842      F
 2  Q5W0B1  GO:0051865      P
 3  Q5W0B1  GO:0006275      P
 4  Q5W0B1  GO:0006513      P,
 26125)

In [8]:
taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
taxonomy_df.head()

Unnamed: 0,EntryID,taxonomyID
0,A0A0C5B5G6,9606
1,A0JNW5,9606
2,A0JP26,9606
3,A0PK11,9606
4,A1A4S6,9606


In [9]:
# def get_processed_fasta_df(fasta_data, term_df):
#     fasta_dict_list = []
#     term_set = set(term_df.tolist())
#     for fasta_seq in fasta_data:
#         entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
#         if entry in term_set:
#             fasta_dict_list.append({
#                 "EntryID": entry,
#                 "fasta_sequence": str(fasta_seq.seq)
#             })

#     return pd.DataFrame(fasta_dict_list)


def sample_tsv(tsv_df, sample_frac=0.05, random_state=42):
    """
    Read a TSV file and sample based on unique EntryID.
    Pulls in all associated rows for sampled EntryIDs.
    """
    df = tsv_df
    unique_ids = df['EntryID'].unique()
    sample_size = max(1, int(len(unique_ids) * sample_frac))
    sampled_ids = random.sample(list(unique_ids), sample_size)
    sampled_df = df[df['EntryID'].isin(sampled_ids)]
    print(f"Sampled {len(sampled_df)} rows from {len(unique_ids)} unique EntryIDs")
    return sampled_df


def get_processed_fasta_df(fasta_data):
    fasta_dict_list = []
    for fasta_seq in fasta_data:
        entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
        fasta_dict_list.append({
                "EntryID": entry,
                "fasta_sequence": str(fasta_seq.seq)
            })

    return pd.DataFrame(fasta_dict_list)





# Get Merged DF FULL (includes batching to fit embeddings into GPU and offloads to CPU)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
model = model.to(device)
model = model.half()
model.eval()
batch_converter = alphabet.get_batch_converter()

def generate_protein_embeddings_esm_batch(seq_df, seq_col='Sequence', entryid_col='EntryID',
                                              target_dim=16, batch_size=1, use_fp16=True):
    """
    Memory-optimized ESM embedding generation for proteins.
    Processes small batches and moves embeddings to CPU immediately.
    """
    sequences = seq_df[seq_col].tolist()
    entry_ids = seq_df[entryid_col].tolist()
    print(f"sequences: {len(sequences)} entry_ids {len(entry_ids)}")

    all_embeddings = []
    # print(len(sequences))
    dtype = torch.float16 if use_fp16 else torch.float32
    for i in range(0, len(sequences), batch_size):
        try:
            batch_seqs = sequences[i:i+batch_size]
            batch_labels = entry_ids[i:i+batch_size]
            if len(batch_seqs[0]) > LARGEST_FASTA_SEQ_LEN:
                print(f"length of batch seqs is: {len(batch_seqs[0])}")
                # Split the sequence to be max length LARGEST_FASTA_SEQ_LEN (will always be len 1 batch), this helps avoid the loss of information
                curr_seq = batch_seqs[0]
                embeddings_list = []
                for i in range(0, len( batch_seqs[0]), LARGEST_FASTA_SEQ_LEN):
                  curr_seq = batch_seqs[0][i:i+LARGEST_FASTA_SEQ_LEN]
                  curr_seq_embedding = get_sequence_embedding_esm(batch_labels, [curr_seq], dtype)
                  embeddings_list.append(curr_seq_embedding)
                # Average the embeddings_list since the long sequences were broken down into smaller chunks
                seq_embeddings = np.mean(embeddings_list, axis=0)

            else:
              seq_embeddings = get_sequence_embedding_esm(batch_labels, batch_seqs, dtype)

            all_embeddings.append(seq_embeddings)
            del seq_embeddings
        except Exception as e:
            print(e)
            print(len(batch_seqs[0]))
            print(batch_seqs)


    raw_embeddings = np.vstack(all_embeddings)
    # Mitigate memory constraints
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.reset_peak_memory_stats()

    col_names = [f"prot_emb_{i}" for i in range(320)]
    emb_df = pd.DataFrame(raw_embeddings, index=entry_ids, columns=col_names)
    emb_df.index.name = entryid_col

    return emb_df


def get_sequence_embedding_esm(batch_labels, batch_seqs, dtype):
    batch_data = [(label, seq) for label, seq in zip(batch_labels, batch_seqs)]
    _, _, batch_tokens = batch_converter(batch_data)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
      with torch.autocast(device_type="cuda", dtype=dtype):
          results = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=False)
          token_embeddings = results["representations"][model.num_layers]  # (B, L, D)
          # Mean pool over sequence length
          attention_mask = batch_tokens != alphabet.padding_idx
          masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
          seq_lengths = attention_mask.sum(dim=1).unsqueeze(-1)
          seq_embeddings = (masked_embeddings.sum(dim=1) / seq_lengths).cpu().float().numpy()
          del batch_tokens, token_embeddings, masked_embeddings, results, attention_mask
          return seq_embeddings


def apply_pca_to_esm_embeddings(esm_embeddings_df:pd.DataFrame, target_dim=16):
    '''
    Helper method 
    '''
    emb_cols = [c for c in esm_embeddings_df.columns if c.startswith("prot_emb")]
    
    tensor_list = [ 
        torch.tensor(row[emb_cols].values.astype("float16"), dtype=torch.float16)
        for _, row in esm_embeddings_df.iterrows()
    ]
    
    pca = PCA(n_components=PCA_TARGET_DIM)
    tensor_list_transformed = pca.fit_transform(tensor_list)
    
    # Intialize dictionary where we use EntryID for joining with all the other data
    embeddings_dict = { "EntryID": [] }
    # Initialize empty PCA embeddings for the pd.DataFrame
    for i in range(PCA_TARGET_DIM):
        embeddings_dict[f"emb_{i}"] =  []
        
    # Append EntryID's and embeddings to the dictionary for the dataframe
    for i, r in enumerate(tensor_list_transformed):
        curr_entry_id = esm_embeddings_df["EntryID"][i]
        embeddings_dict["EntryID"].append(curr_entry_id)
        for j in range(PCA_TARGET_DIM):
            embeddings_dict[f"emb_{j}"].append(r[j])

    pca_embeddings_df = pd.DataFrame(embeddings_dict)
    return pca_embeddings_df




Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t6_8M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t6_8M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t6_8M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t6_8M_UR50D-contact-regression.pt


In [90]:
file_name = "/kaggle/input/fasta-embeddings-final/fasta_embeddings_final.csv"
embeddings_processed = True

def get_merged_df_full(file_name, batch_size=250, embeddings_processed=False):
    """
    Merge term.tsv, fasta data, taxonomy data as well as nodes in the obo graph.
    """
    term_df = pd.read_csv(term_path, sep='\t')
    term_df = term_df
    taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
    fasta_data = list(SeqIO.parse(fasta_path, "fasta"))

    entry_ids = list(term_df['EntryID'])
    all_batches = []
    if not embeddings_processed:
      for i in range(0, len(entry_ids), batch_size):
          total_processed = i
          print(f"Total processed: {total_processed}, {i//batch_size} batch")

          # Batch the EntryIDs
          entry_batch = list(set(entry_ids[i:i+batch_size]))

          curr_term_df = term_df[i:i+batch_size]
          fasta_df_batch = get_processed_fasta_df(fasta_data, curr_term_df['EntryID'])
          fasta_emb_df_batch = generate_protein_embeddings_esm_batch(
              fasta_df_batch,
              "fasta_sequence"
          )


          all_batches.append(fasta_emb_df_batch)

          full_df = pd.concat(all_batches, ignore_index=False)
          full_df.to_csv("fasta_embeddings.csv", index=True)

    else:
        full_df = pd.read_csv(file_name)
        print(full_df.head())
        
    fasta_emb_df = apply_pca_to_esm_embeddings(full_df)
    
    # TODO add embeddings in getting merged_df
    merged_df = pd.merge(term_df, fasta_emb_df, on="EntryID", how='left')
    merged_df = pd.merge(merged_df, taxonomy_df, on="EntryID", how="left")
    graph = obonet.read_obo(obo_path)
    edges_list = []
    for node_id, data in graph.nodes(data=True):
        for parent_id in data.get("is_a", []):
            edges_list.append({
                    "term": node_id,
                    "parent": parent_id,
                    "relation": data.get('relation', 'is_a'),
                    "name": data["name"],
                    "namespace": data["namespace"],
                    "def": data["def"],
                    "synonym": data.get("synonym", [])
                    
                })
    edges_df = pd.DataFrame(edges_list)
    return merged_df, edges_df

protein_function_df, graph_df = get_merged_df_full(file_name, embeddings_processed=embeddings_processed)
len(protein_function_df)

  EntryID  prot_emb_0  prot_emb_1  prot_emb_2  prot_emb_3  prot_emb_4  \
0  P86164    0.063601    0.090173    0.272553    0.045446   -0.024562   
1  P84910    0.164426   -0.128475    0.244240    0.068856   -0.031558   
2  P83012    0.126772   -0.093400    0.216583    0.056300   -0.058426   
3  P83246    0.039001   -0.228608    0.207255    0.138270    0.081464   
4  P86133   -0.063519    0.141610    0.168906   -0.003960   -0.064855   

   prot_emb_5  prot_emb_6  prot_emb_7  prot_emb_8  ...  prot_emb_310  \
0   -0.087723   -0.097689   -0.002640   -0.085523  ...      0.121198   
1   -0.102609   -0.108406   -0.071888   -0.022300  ...      0.094619   
2    0.006500   -0.052936   -0.012787   -0.090959  ...      0.024688   
3   -0.122376   -0.169814   -0.090188   -0.169230  ...     -0.015823   
4   -0.166769    0.018469    0.077215    0.058494  ...      0.147194   

   prot_emb_311  prot_emb_312  prot_emb_313  prot_emb_314  prot_emb_315  \
0      0.150969      0.038082     -0.035188      0.10

537027

In [91]:
graph = obonet.read_obo(obo_path)
edges_list = []
for node_id, data in graph.nodes(data=True):
        for parent_id in data.get("is_a", []):
            edges_list.append({
                    "term": node_id,
                    "parent": parent_id,
                    "name": data["name"],
                    "namespace": data["namespace"],
                    "def": data["def"],
                    "synonym": data.get("synonym", [])
                })
edges_df = pd.DataFrame(edges_list)

In [92]:
ALL_SUBONTOLOGIES = graph_df["namespace"].unique()
graph_df.head(
)


Unnamed: 0,term,parent,relation,name,namespace,def,synonym
0,GO:0000001,GO:0048308,is_a,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","[""mitochondrial inheritance"" EXACT []]"
1,GO:0000001,GO:0048311,is_a,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","[""mitochondrial inheritance"" EXACT []]"
2,GO:0000002,GO:0007005,is_a,mitochondrial genome maintenance,biological_process,"""The maintenance of the structure and integrit...",[]
3,GO:0000006,GO:0005385,is_a,high-affinity zinc transmembrane transporter a...,molecular_function,"""Enables the transfer of zinc ions (Zn2+) from...","[""high affinity zinc uptake transmembrane tran..."
4,GO:0000007,GO:0005385,is_a,low-affinity zinc ion transmembrane transporte...,molecular_function,"""Enables the transfer of a solute or solutes f...",[]


In [152]:
protein_function_df.head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,-0.162849,-0.904390,-0.347545,-0.220612,0.036552,-0.221079,0.191750,0.372088,0.575290,-0.20096,-0.117028,-0.248705,-0.036299,0.176752,9606
1,Q5W0B1,GO:0004842,F,-0.270628,1.031287,-0.162849,-0.904390,-0.347545,-0.220612,0.036552,-0.221079,0.191750,0.372088,0.575290,-0.20096,-0.117028,-0.248705,-0.036299,0.176752,9606
2,Q5W0B1,GO:0051865,P,-0.270628,1.031287,-0.162849,-0.904390,-0.347545,-0.220612,0.036552,-0.221079,0.191750,0.372088,0.575290,-0.20096,-0.117028,-0.248705,-0.036299,0.176752,9606
3,Q5W0B1,GO:0006275,P,-0.270628,1.031287,-0.162849,-0.904390,-0.347545,-0.220612,0.036552,-0.221079,0.191750,0.372088,0.575290,-0.20096,-0.117028,-0.248705,-0.036299,0.176752,9606
4,Q5W0B1,GO:0006513,P,-0.270628,1.031287,-0.162849,-0.904390,-0.347545,-0.220612,0.036552,-0.221079,0.191750,0.372088,0.575290,-0.20096,-0.117028,-0.248705,-0.036299,0.176752,9606
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
537022,Q06667,GO:0070481,P,-1.083548,-0.679863,0.137228,-1.287429,0.274794,0.360896,-0.161794,0.540822,0.211911,0.154250,0.227489,0.19708,-0.123450,0.196853,0.115528,-0.117391,559292
537023,B1NF19,GO:0033075,P,1.666123,-0.441586,0.090926,-0.512658,0.176430,-0.082721,-0.082259,-0.454386,0.093465,-0.149323,-0.205518,0.09290,0.123291,0.451366,0.273499,0.302869,54796
537024,B1NF19,GO:0047052,F,1.666123,-0.441586,0.090926,-0.512658,0.176430,-0.082721,-0.082259,-0.454386,0.093465,-0.149323,-0.205518,0.09290,0.123291,0.451366,0.273499,0.302869,54796
537025,B1NF19,GO:0047056,F,1.666123,-0.441586,0.090926,-0.512658,0.176430,-0.082721,-0.082259,-0.454386,0.093465,-0.149323,-0.205518,0.09290,0.123291,0.451366,0.273499,0.302869,54796


In [153]:
protein_function_df[protein_function_df["EntryID"]=="Q63871"]

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
124839,Q63871,GO:0006366,P,-0.59942,-0.373074,0.409588,0.142399,-0.107085,0.563983,0.066211,0.061919,0.225195,-0.134774,-0.471171,0.503097,0.021821,-0.148401,-0.141495,0.121833,10090
124840,Q63871,GO:0005654,C,-0.59942,-0.373074,0.409588,0.142399,-0.107085,0.563983,0.066211,0.061919,0.225195,-0.134774,-0.471171,0.503097,0.021821,-0.148401,-0.141495,0.121833,10090


In [113]:
protein_function_subontology_dict = {
    'C': protein_function_df[protein_function_df["aspect"]=='C'],
    'F': protein_function_df[protein_function_df["aspect"]=='F'],
    'P': protein_function_df[protein_function_df["aspect"]=='P']
}

protein_function_unique_terms_subontology_dict = {
   k: v["term"].unique() for k,v in protein_function_subontology_dict.items()
}



In [114]:
protein_function_subontology_dict['C'].head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220612,0.036552,-0.221079,0.19175,0.372088,0.57529,-0.20096,-0.117028,-0.248705,-0.036299,0.176752,9606
7,Q3EC77,GO:0000138,C,-0.124728,-0.263022,-0.456481,-0.18172,-0.070037,0.144072,-0.371444,-0.067311,0.075183,0.052979,-0.05539,-0.271132,-0.08985,0.252874,0.024556,0.242461,3702
8,Q3EC77,GO:0005794,C,-0.124728,-0.263022,-0.456481,-0.18172,-0.070037,0.144072,-0.371444,-0.067311,0.075183,0.052979,-0.05539,-0.271132,-0.08985,0.252874,0.024556,0.242461,3702
13,Q8R2Z3,GO:0016323,C,0.394743,-0.708571,-0.151442,-0.196877,0.396844,-0.341863,0.126681,-0.512221,0.131078,0.023888,0.307004,-0.242611,0.396452,-0.326244,0.078056,-0.312577,10090
21,Q8R2Z3,GO:0016020,C,0.394743,-0.708571,-0.151442,-0.196877,0.396844,-0.341863,0.126681,-0.512221,0.131078,0.023888,0.307004,-0.242611,0.396452,-0.326244,0.078056,-0.312577,10090


In [17]:
N_LARGEST = 10
top_terms = protein_function_df['term'].value_counts().nlargest(N_LARGEST).index
protein_function_top_terms_df = protein_function_df[protein_function_df['term'].isin(top_terms)]

In [18]:
len(protein_function_top_terms_df), len(protein_function_df), len(graph_df)

(100851, 537027, 62410)

### Create set of embedding from the graph edges using GCN

In [19]:
class SimpleGCN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(SimpleGCN, self).__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x
  
def init_subontology_GCNs(embed_dim=16, hidden_dim=32, out_dim=16):
    graph_subontology_dict = {
       subontology: SimpleGCN(embed_dim, hidden_dim, out_dim) for subontology in ALL_SUBONTOLOGIES
    }
    return graph_subontology_dict

def get_subontology_graph_dfs(graph_df):
    subontology_graph_dfs = {
       subontology: graph_df[graph_df["namespace"]==subontology] for subontology in ALL_SUBONTOLOGIES
    }
    return subontology_graph_dfs
  
subontology_GCNs, subontology_graph_dfs = init_subontology_GCNs(), get_subontology_graph_dfs(graph_df)

In [20]:
protein_function_df.head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220612,0.036551,-0.221058,0.191759,0.3721,0.575317,-0.200968,-0.117173,-0.248961,-0.038595,0.183119,9606
1,Q5W0B1,GO:0004842,F,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220612,0.036551,-0.221058,0.191759,0.3721,0.575317,-0.200968,-0.117173,-0.248961,-0.038595,0.183119,9606
2,Q5W0B1,GO:0051865,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220612,0.036551,-0.221058,0.191759,0.3721,0.575317,-0.200968,-0.117173,-0.248961,-0.038595,0.183119,9606
3,Q5W0B1,GO:0006275,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220612,0.036551,-0.221058,0.191759,0.3721,0.575317,-0.200968,-0.117173,-0.248961,-0.038595,0.183119,9606
4,Q5W0B1,GO:0006513,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347545,-0.220612,0.036551,-0.221058,0.191759,0.3721,0.575317,-0.200968,-0.117173,-0.248961,-0.038595,0.183119,9606


In [145]:
def group_terms_and_aspects(protein_function_df):
    protein_function_grouped_df = (
        protein_function_df
            .groupby("EntryID")
            .agg({
                "term": list,                     
                "aspect": list,                   
                **{c: "first" for c in protein_function_top_terms_df.columns 
                   if c.startswith("emb_")},      
                "taxonomyID": "first"            
            })
            .rename(columns={"term": "output_terms"})
            .reset_index()
    )
    return protein_function_grouped_df

protein_function_grouped_subontology_dict = {
   k: group_terms_and_aspects(v) for k,v in protein_function_subontology_dict.items()
}
# protein_function_grouped_df = group_terms_and_aspects(protein_function_df)

In [146]:
protein_function_grouped_subontology_dict["C"].head()

Unnamed: 0,EntryID,output_terms,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,A0A023PZB3,[GO:0005739],[C],-0.892717,0.08732,-0.119613,-0.446892,0.295609,0.060304,0.074364,-0.377235,0.089221,0.112587,-0.141796,0.050659,-0.415592,-0.085689,-0.113286,0.196419,559292
1,A0A024RBG1,[GO:0005829],[C],0.975543,-0.338469,-0.000139,0.019703,-0.655573,0.512959,-0.308134,0.00241,0.023895,0.04314,0.041124,0.353098,-0.16968,0.171263,0.017887,0.020749,9606
2,A0A059TC02,[GO:0005737],[C],1.805741,-0.368449,0.109841,-0.028463,-0.152263,0.315593,-0.052549,-0.193249,0.150809,-0.054232,0.24263,-0.021289,0.184021,0.292629,0.054679,0.356749,4102
3,A0A060A682,[GO:0005911],[C],-0.724893,0.047931,0.15265,-0.801221,1.21119,0.523572,0.394047,0.33649,0.116497,0.064541,0.881014,0.12348,-0.108504,0.265298,0.341598,-0.73927,5911
4,A0A060L102,[GO:0012511],[C],-0.435772,-0.80578,0.243521,1.009033,0.856131,-0.61473,0.340572,-0.0711,0.007405,0.376591,-0.233788,0.249008,0.170998,-0.003208,0.182627,0.214605,88730


In [147]:
protein_function_grouped_subontology_dict["F"].head()

Unnamed: 0,EntryID,output_terms,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,A0A023FBW4,[GO:0019958],[F],-0.733679,-0.278833,-0.73913,-0.129297,0.640625,0.608324,0.276895,0.059644,0.456897,-0.223989,0.269743,-0.28764,-0.389554,-0.086319,0.312944,0.207338,34607
1,A0A023FBW7,[GO:0019957],[F],-0.677593,-0.401437,-0.250645,0.016632,0.365204,1.118802,0.57457,-0.099005,0.702556,-0.231644,0.155429,-0.124291,0.05237,-0.020351,0.254775,0.206514,34607
2,A0A023FDY8,[GO:0019957],[F],-0.652475,-0.402101,-0.241203,0.076115,0.370869,1.122665,0.586035,-0.069801,0.735847,-0.26158,0.230738,-0.100794,0.027561,-0.012186,0.261414,0.206666,34607
3,A0A023FF81,[GO:0019958],[F],-0.550702,-0.327903,-0.61398,-0.264536,0.529165,0.62579,0.384382,-0.124661,0.377992,-0.150903,0.159281,-0.375155,-0.409352,-0.076471,0.313823,0.129724,34607
4,A0A023FFB5,[GO:0019957],[F],-0.633638,-0.376555,-0.319808,-0.28403,0.645544,1.141826,0.553378,-0.213993,0.667963,-0.399956,0.136289,-0.303193,-0.022595,-0.022912,0.3043,0.095832,34607


In [148]:
protein_function_grouped_subontology_dict["P"].head()

Unnamed: 0,EntryID,output_terms,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,taxonomyID
0,A0A023FFD0,[GO:1900137],[P],-0.5557,-0.180274,-0.647727,0.090899,0.79779,0.881821,0.393482,0.06978,0.352916,-0.534025,0.192397,-0.107591,-0.289423,0.016901,0.436988,0.2546,34607
1,A0A023I7E1,[GO:0000272],[P],-0.692904,-0.624924,0.328031,0.124243,0.410377,-0.050196,-0.319525,0.269531,-0.792961,-0.184239,0.055391,0.291116,-0.307184,0.485986,0.071023,0.294006,4839
2,A0A023PXP4,[GO:0006974],[P],-0.88985,-0.035833,-1.037647,-0.215612,-0.237554,0.068305,-0.006168,0.121818,0.380766,0.819004,-0.118286,0.190439,-0.441868,0.359809,0.07112,0.231049,559292
3,A0A026W182,"[GO:0035176, GO:0042048, GO:0043695, GO:001923...","[P, P, P, P, P, P]",-0.08961,-0.761104,-0.370754,-0.268787,0.556936,-0.538747,0.367187,0.01839,0.276713,-0.357124,-0.317369,-0.068336,0.079929,0.000447,0.050457,0.044899,2015173
4,A0A044RE18,"[GO:0031638, GO:0090472]","[P, P]",0.33807,-0.25307,-0.210247,-0.09247,0.242525,0.463352,-0.359205,0.274327,-0.349316,-0.035748,-0.13508,-0.055286,-0.351821,0.129811,-0.065735,-0.037453,6282


In [149]:
for k, df in protein_function_grouped_subontology_dict.items():
    protein_function_grouped_subontology_dict[k] = df.sample(frac=1, random_state=42).reset_index(drop=True)

In [150]:
from sklearn.model_selection import train_test_split
# Columns for Training and Columns for Testing
PREDICTORS = [f"emb_{i}" for i in range(PCA_TARGET_DIM)]
PREDICTORS.append("EntryID")
PREDICTORS.append("aspect")
OUTPUTS = ['output_terms']
subontology_train_val_test_dic = {}
for k, df in protein_function_grouped_subontology_dict.items():
    X, y = df[PREDICTORS].values, df[OUTPUTS].iloc[:, 0].tolist()
    X_train_all, X_test, y_train_all, y_test = train_test_split(X, y, test_size=0.05, random_state=42)
    # Perform a second split for validation set for finetuning our model
    X_train, X_val, y_train, y_val = train_test_split(X_train_all, y_train_all, test_size=0.1, random_state=42)
    protein_function_grouped_subontology_dict[k] = X_train, X_val, y_train, y_val, X_test, y_test


In [151]:
protein_function_grouped_subontology_dict["C"]

(array([[-0.5994196447925982, -0.3730742624220832, 0.409588462246922, ...,
         0.12183256085556757, 'Q63871', list(['C'])],
        [-1.0215088474568375, -0.519317542514266, -0.16285983766189538,
         ..., -0.07147352870546446, 'Q9VYS5', list(['C', 'C'])],
        [0.05422578592291629, -0.06722377976756616, 0.15800841652212663,
         ..., -0.20549474833914394, 'F4I4E1', list(['C', 'C'])],
        ...,
        [-0.6270777543235305, -0.2293104002049112, -0.9908981174642683,
         ..., 0.16843782647846015, 'P15328',
         list(['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C'])],
        [0.8468180444792296, -0.38749349651039294, -0.04081622706635616,
         ..., -0.03283858334029386, 'Q9CY64', list(['C', 'C', 'C', 'C'])],
        [-0.7875850291273915, 0.24632715070875835, -0.9907900251341156,
         ..., -0.17760960567759884, 'P45646', list(['C'])]], dtype=object),
 array([[0.5005845095277185, 0.13534886713438044, -0.35312733426680776,
         ..., -0.131

In [121]:
protein_function_metadata_dict = {
    "C": None,
    "F": None,
    "P": None
}
for k,(X_train, X_val, y_train, y_val, X_test, y_test) in protein_function_grouped_subontology_dict.items():
    X_train_entry_ids = [data_row[-1] for data_row in X_train]
    X_val_entry_ids = [data_row[-1] for data_row in X_val]
    X_test_entry_ids = [data_row[-1] for data_row in X_test]
    X_train = np.array([data_row[:-1] for data_row in X_train])
    X_val = np.array([data_row[:-1] for data_row in X_val])
    X_test = np.array([data_row[:-1] for data_row in X_test])
    protein_function_metadata_dict[k] = {
    "X_train": X_train,
    "X_val": X_val,
    "y_train": y_train,
    "y_val": y_val,
    "X_test": X_test,
    "y_test": y_test,
    "X_train_entry_ids": X_train_entry_ids,
    "X_val_entry_ids": X_val_entry_ids,
    "X_test_entry_ids": X_test_entry_ids
}

In [122]:
protein_function_metadata_dict['C'].keys()

dict_keys(['X_train', 'X_val', 'y_train', 'y_val', 'X_test', 'y_test', 'X_train_entry_ids', 'X_val_entry_ids', 'X_test_entry_ids'])

In [123]:
from sklearn.preprocessing import MultiLabelBinarizer
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    subontology_metadata_dict["unique_terms"] = protein_function_unique_terms_subontology_dict[k]
    term_to_index = {term: i for i, term in enumerate(subontology_metadata_dict["unique_terms"])}
    
    mlb = MultiLabelBinarizer(classes=subontology_metadata_dict["unique_terms"])
    y_train_transformed = mlb.fit_transform(subontology_metadata_dict["y_train"])
    y_val_transformed = mlb.transform(subontology_metadata_dict["y_val"])
    
    from sklearn.impute import SimpleImputer
    imputer = SimpleImputer(strategy="constant", fill_value=0)
    X_train_imputed = imputer.fit_transform(subontology_metadata_dict["X_train"])
    X_test_imputed = imputer.transform(subontology_metadata_dict["X_test"])
    X_val_imputed = imputer.transform(subontology_metadata_dict["X_val"])
    subontology_metadata_dict["y_train_transformed"] = y_train_transformed
    subontology_metadata_dict["y_val_transformed"] = y_val_transformed
    subontology_metadata_dict["X_train_imputed"] = X_train_imputed
    subontology_metadata_dict["X_val_imputed"] = X_val_imputed
    subontology_metadata_dict["mlb"] = mlb
    

In [211]:
import xgboost as xgb
import numpy as np
from sklearn.linear_model import LogisticRegression

num_models = 100

models_dict = {}
def train_xgb_models():
    models = []
    for i in range(num_models):
        print(f"\nTraining model for label {i}...")
    
        y_i = y_train_transformed[:, i]
    
        pos = np.sum(y_i == 1)
        neg = np.sum(y_i == 0)
        scale_pos_weight = neg / pos if pos > 0 else 1.0
    
        print(f"Label {i} imbalance pos={pos}, neg={neg}, scaling weight={scale_pos_weight}")
    
        dtrain = xgb.DMatrix(X_train_imputed, label=y_i)
    
        params = {
            "objective": "binary:logistic",
            "eval_metric": "aucpr",
            "tree_method": "hist",
            "max_depth": 7,
            "eta": 0.05,
            "lambda": 1.5,
            "alpha": 0.8,
            "min_child_weight": 2,
            "subsample": 0.8,
            "colsample_bytree": 0.8,
            "scale_pos_weight": scale_pos_weight
        }
    
        model = xgb.train(
            params=params,
            dtrain=dtrain,
            num_boost_round=300,
            evals=[(dtrain, "train")],
            early_stopping_rounds=25,
            verbose_eval=50
        )
    
        models.append(model)
    return models

def train_logistic_regression_models(i):

    print(f"\nTraining model for label {i}...")
    
    y_i = y_train_transformed[:, i]
    # Sample not good
    if len(np.unique(y_i)) < 2:
        models_dict[f"lr_{i}"] = None
        return

    pos = np.sum(y_i == 1)
    neg = np.sum(y_i == 0)
    scale_pos_weight = neg / pos if pos > 0 else 1.0
    
    print(f"Label {i}")
    
    model = LogisticRegression(
        penalty="l2",
        solver="liblinear",
        class_weight={0: 1.0, 1: scale_pos_weight},
        max_iter=200,
    )
    
    model.fit(X_train_imputed, y_i)
    models_dict[f"lr_{i}"] = model



In [212]:
# dval = xgb.DMatrix(X_val)


# xgb_pred_list = []
# for model in xgboost_models:
#     print(f"model: {model}")
#     xgb_pred_list.append(model.predict(dval))

# xgb_preds = np.column_stack(xgb_pred_list)

# lr_pred_list = []
# for model in logistic_regression_models:
#     print(f"model: {model}")
#     lr_pred_list.append(model.predict(X_val_imputed))

# lr_preds = np.column_stack(lr_pred_list)

# model_predictions_dict = {
#     "xgb_preds": xgb_preds,
#     "lr_preds": lr_preds
# }


In [91]:
from sklearn.metrics import accuracy_score, f1_score

for k, v in model_predictions_dict.items():
    pred_binary = (v >= 0.5).astype(int)
    
    y_true = y_val_transformed[:, :num_models]
    
    accuracy_per_label = []
    f1_per_label = []

    for i in range(len(pred_binary[0])):
        acc = accuracy_score(y_true[:, i], pred_binary[:, i])
        f1  = f1_score(y_true[:, i], pred_binary[:, i], zero_division=0)
    
        accuracy_per_label.append(acc)
        f1_per_label.append(f1)
    
        print(f"Model: {k} Label {i}:  Accuracy = {acc:.4f},   F1 = {f1:.4f}")

Model: xgb_preds Label 0:  Accuracy = 0.9599,   F1 = 0.0977
Model: xgb_preds Label 1:  Accuracy = 0.9894,   F1 = 0.3025
Model: xgb_preds Label 2:  Accuracy = 0.9954,   F1 = 0.0000
Model: xgb_preds Label 3:  Accuracy = 0.9986,   F1 = 0.0000
Model: xgb_preds Label 4:  Accuracy = 0.9987,   F1 = 0.1667
Model: xgb_preds Label 5:  Accuracy = 0.9670,   F1 = 0.1400
Model: xgb_preds Label 6:  Accuracy = 0.6647,   F1 = 0.6331
Model: xgb_preds Label 7:  Accuracy = 0.9981,   F1 = 0.1176
Model: xgb_preds Label 8:  Accuracy = 0.9157,   F1 = 0.1771
Model: xgb_preds Label 9:  Accuracy = 0.9991,   F1 = 0.4615
Model: xgb_preds Label 10:  Accuracy = 0.9990,   F1 = 0.6000
Model: xgb_preds Label 11:  Accuracy = 1.0000,   F1 = 0.0000
Model: xgb_preds Label 12:  Accuracy = 0.9891,   F1 = 0.2478
Model: xgb_preds Label 13:  Accuracy = 0.9997,   F1 = 0.5000
Model: xgb_preds Label 14:  Accuracy = 0.9940,   F1 = 0.0784
Model: xgb_preds Label 15:  Accuracy = 0.9997,   F1 = 0.0000
Model: xgb_preds Label 16:  Accura

In [213]:
from joblib import Parallel, delayed
import psutil
n_cores = psutil.cpu_count()
print(f"Available cores: {n_cores}")

Available cores: 4


In [214]:
results = Parallel(n_jobs=n_cores, verbose=10)(
    delayed(train_logistic_regression_models)(i) for i in range(200)
)

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   5 tasks      | elapsed:   16.1s
[Parallel(n_jobs=4)]: Done  10 tasks      | elapsed:   18.2s
[Parallel(n_jobs=4)]: Done  17 tasks      | elapsed:   20.2s
[Parallel(n_jobs=4)]: Done  24 tasks      | elapsed:   23.2s
[Parallel(n_jobs=4)]: Done  33 tasks      | elapsed:   26.1s
[Parallel(n_jobs=4)]: Done  42 tasks      | elapsed:   29.7s
[Parallel(n_jobs=4)]: Done  53 tasks      | elapsed:   33.1s
[Parallel(n_jobs=4)]: Done  64 tasks      | elapsed:   37.0s
[Parallel(n_jobs=4)]: Done  77 tasks      | elapsed:   41.6s
[Parallel(n_jobs=4)]: Done  90 tasks      | elapsed:   46.3s
[Parallel(n_jobs=4)]: Done 105 tasks      | elapsed:   51.2s
[Parallel(n_jobs=4)]: Done 120 tasks      | elapsed:   56.7s
[Parallel(n_jobs=4)]: Done 137 tasks      | elapsed:  1.0min
[Parallel(n_jobs=4)]: Done 154 tasks      | elapsed:  1.1min
[Parallel(n_jobs=4)]: Done 173 tasks      | elapsed:  1.3min
[Parallel(

In [215]:
# b

In [126]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

output_dim = 100

class WeightedMultiLabelNN(nn.Module):
    def __init__(self, input_dim=16, output_dim=output_dim, hidden_dims=[64, 128, 256], pos_weights=None):
        super().__init__()
        self.pos_weights = pos_weights
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.1))
            prev_dim = hidden_dim
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_dim, output_dim)
    
    def forward(self, x):
        features = self.features(x)
        logits = self.classifier(features)
        return logits


class PositionalWeightedLoss(nn.Module):
    def __init__(self, pos_weights=None, label_weights=None, reduction='mean'):
        super().__init__()
        self.pos_weights = pos_weights
        self.label_weights = label_weights
        self.reduction = reduction
        
    def forward(self, logits, targets):
        loss = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            reduction='none'
        )
        
        if self.pos_weights is not None:
            pos_mask = (targets > 0.5).float()
            pos_weight_factor = 1 + (self.pos_weights.unsqueeze(0) - 1) * pos_mask
            loss = loss * pos_weight_factor
            
        if self.label_weights is not None:
            loss = loss * self.label_weights.unsqueeze(0)
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

In [127]:
class ProteinEmbeddingsDataset(Dataset):
    def __init__(self, features, labels, total_labels=output_dim, label_metadata=None):
        self.features = features.astype(np.float32)
        self.labels = labels[:, :total_labels].astype(np.float32)
        self.label_metadata = label_metadata
        
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        x = torch.FloatTensor(self.features[idx])
        y = torch.FloatTensor(self.labels[idx])
        return x, y

In [128]:
for k, subontology_metadat_dict in protein_function_metadata_dict.items():
    protein_embeddings_dataset_train = ProteinEmbeddingsDataset(subontology_metadat_dict["X_train_imputed"],subontology_metadat_dict["y_train_transformed"], total_labels=output_dim)
    protein_embeddings_dataset_val = ProteinEmbeddingsDataset(subontology_metadat_dict["X_val_imputed"], subontology_metadat_dict["y_val_transformed"], total_labels=output_dim)
    subontology_metadat_dict["protein_embeddings_dataset_train"] = protein_embeddings_dataset_train
    subontology_metadat_dict["protein_embeddings_dataset_val"] = protein_embeddings_dataset_val

In [129]:
from sklearn.metrics import f1_score

def evaluate_f1(model, val_loader, device, threshold=0.5):
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for features, targets in val_loader:
            features = features.to(device)
            targets = targets.to(device)

            logits = model(features)
            probs = torch.sigmoid(logits)

            preds = (probs > threshold).float()

            all_preds.append(preds.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    num_labels = all_targets.shape[1]
    f1_scores = []

    for i in range(num_labels):
        try:
            f1 = f1_score(all_targets[:, i], all_preds[:, i], zero_division=0   )
        except ValueError:
            f1 = 0.0 
        f1_scores.append(f1)

    f1_scores = np.array(f1_scores)
    best_idx = np.argsort(f1_scores)[-3:]
    worst_idx = np.argsort(f1_scores)[:3]

    return f1_scores, best_idx, worst_idx
    

In [165]:
def obtain_label_metadata(curr_y, output_dim):
    label_metadata = {
    "pos_weights": []
    }
    print(len(curr_y[0]))
    for i in range(output_dim):
        if not i % 1000: print(i)
        y_i = curr_y[:, i]
        pos = np.sum(y_i == 1)
        neg = np.sum(y_i == 0)
        scale_pos_weight = neg / pos if pos > 0 else 1.0
        label_metadata["pos_weights"].append(scale_pos_weight)
    return label_metadata


def train_model(model, train_loader, pos_weights, num_epochs=50):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    pos_weights = pos_weights.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for features, targets in train_loader:
            features = features.to(device)
            targets = targets.to(device)
            
            logits = model(features)
            loss = F.binary_cross_entropy_with_logits(
                logits,
                targets,
                pos_weight=pos_weights,
                reduction='mean'
            )
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"current eppoch {epoch+1}/{num_epochs} current loss: {total_loss/len(train_loader):.4f}")
    
    return model

In [166]:
import warnings
warnings.filterwarnings("ignore")

def train_model(model, train_loader, val_loader, pos_weights, num_epochs=25):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    pos_weights = pos_weights.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for features, targets in train_loader:
            features = features.to(device)
            targets = targets.to(device)

            logits = model(features)

            loss = F.binary_cross_entropy_with_logits(
                logits,
                targets,
                pos_weight=pos_weights,
                reduction="mean"
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
       

        # Validate using the validation dataloader
        f1_scores, best_idx, worst_idx = evaluate_f1(model, val_loader, device)
        print(f"curr epoch {epoch+1}/{num_epochs}] loss: {avg_loss:.4f}")
        print("Validation op 3 labels (best F1):")
        for idx in reversed(best_idx):
            print(f"Label {idx}: F1 = {f1_scores[idx]:.4f}")

        print("Validation Bottom 3 labels (worst F1):")
        for idx in worst_idx:
            print(f"Label {idx}: F1 = {f1_scores[idx]}")

        print(f"Validation average f1 score: {np.mean(f1_scores)}")

    return model

In [None]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    label_metadata = obtain_label_metadata(subontology_metadata_dict["y_train_transformed"], output_dim)
    subontology_metadata_dict["label_metadata"] = label_metadata
    curr_model = WeightedMultiLabelNN()
    data_loader = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_train"], batch_size=32)
    data_loader_val = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_val"])
    subontology_metadata_dict["curr_model"] = curr_model
    subontology_metadata_dict["data_loader"] = data_loader
    subontology_metadata_dict["data_loader_val"] = data_loader_val
    print(f"\n\nTRAINING: {k} model \n\n")
    trained_model_nn = train_model(curr_model, data_loader, data_loader_val, torch.tensor(label_metadata["pos_weights"][:output_dim]))
    subontology_metadata_dict["trained_model_nn"] = trained_model_nn


2651
0


TRAINING: C model 


curr epoch 1/25] loss: 1.4975
Validation op 3 labels (best F1):
Label 24: F1 = 0.4921
Label 9: F1 = 0.4478
Label 13: F1 = 0.4321
Validation Bottom 3 labels (worst F1):
Label 99: F1 = 0.0
Label 10: F1 = 0.0
Label 11: F1 = 0.0
Validation average f1 score: 0.05367597656411689
curr epoch 2/25] loss: 1.1665
Validation op 3 labels (best F1):
Label 24: F1 = 0.4999
Label 9: F1 = 0.4736
Label 13: F1 = 0.4469
Validation Bottom 3 labels (worst F1):
Label 99: F1 = 0.0
Label 10: F1 = 0.0
Label 11: F1 = 0.0
Validation average f1 score: 0.05718110232149605
curr epoch 3/25] loss: 1.0841
Validation op 3 labels (best F1):
Label 24: F1 = 0.5118
Label 9: F1 = 0.4730
Label 13: F1 = 0.4676
Validation Bottom 3 labels (worst F1):
Label 99: F1 = 0.0
Label 10: F1 = 0.0
Label 11: F1 = 0.0
Validation average f1 score: 0.061444455351773335
curr epoch 4/25] loss: 1.0221
Validation op 3 labels (best F1):
Label 24: F1 = 0.5040
Label 9: F1 = 0.4927
Label 13: F1 = 0.4684
Validation Bottom 

In [None]:
def predict(model, dataloader, device):
    model.eval()
    preds = []
    labels = []

    with torch.no_grad():
        for batch in dataloader:
            X, y = batch
            X = X.to(device)
            y = y.to(device)

            logits = model(X)
            out = torch.sigmoid(logits)
            preds.append(out.cpu())
            labels.append(y.cpu())

    preds = torch.cat(preds)
    labels = torch.cat(labels)
    return preds, labels

In [None]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    
    y_preds, y_labels = predict(subontology_metadata_dict["trained_model_nn"], subontology_metadata_dict["data_loader_val"], 'cuda')
    subontology_metadata_dict["y_preds"] = y_preds
    subontology_metadata_dict["y_labels"] = y_labels


In [None]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    
    pred_entry_ids_df = pd.DataFrame({"EntryID": subontology_metadata_dict["X_val_entry_ids"], "prediction": [row.tolist() for row in subontology_metadata_dict["y_preds"]]})
    label_names = subontology_metadata_dict["mlb"].classes_[:output_dim]
    num_labels = len(label_names)
    
    pred_matrix = np.vstack(pred_entry_ids_df["prediction"].values)
    
    pred_df = pd.DataFrame(pred_matrix, columns=label_names)
    
    prediction_df = pd.concat([pred_entry_ids_df[["EntryID"]], pred_df], axis=1)
    go_cols = [c for c in prediction_df.columns if c != "EntryID"]
    to_score_df = prediction_df.melt(id_vars="EntryID", value_vars=go_cols, var_name="term", value_name="score")
    
    to_score_df = to_score_df.sort_values(["EntryID", "term"]).reset_index(drop=True)
    to_score_df["score"] = to_score_df["score"].round(3)
    to_score_df = to_score_df[to_score_df["score"] != 0]
    subontology_metadata_dict["to_score_df"] = to_score_df
    to_score_df.to_csv(f"to_score_{k}.tsv", header=False, index=False, sep="\t")
    


In [None]:
# per insturctions we must submit a file with the EntryID, GO and score while droping 0 scores after rounding to 3 decimal point


In [None]:
for k, curr_protein_function_df in protein_function_subontology_dict.items():
    ground_truth_score_df = curr_protein_function_df[curr_protein_function_df["EntryID"].isin(X_val_entry_ids)][["EntryID", "term"]]
    protein_function_metadata_dict[k]["ground_truth_score_df"] = ground_truth_score_df

In [None]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    subontology_metadata_dict["ground_truth_score_df"].to_csv(f"ground_truth_score_{k}.tsv", header=False, index=False, sep="\t")

In [None]:
ground_truth_score_df_merged = pd.concat([
    protein_function_metadata_dict["C"]["ground_truth_score_df"],
    protein_function_metadata_dict["F"]["ground_truth_score_df"],
    protein_function_metadata_dict["P"]["ground_truth_score_df"]
], ignore_index=True)

score_df_merged = pd.concat([
    protein_function_metadata_dict["C"]["to_score_df"],
    protein_function_metadata_dict["F"]["to_score_df"],
    protein_function_metadata_dict["P"]["to_score_df"]
], ignore_index=True)


In [None]:
score_df_merged.size, ground_truth_score_df_merged.size

In [None]:
ground_truth_score_df_merged.to_csv("ground_truth_score.tsv", header=False, index=False, sep="\t")

In [None]:
score_df_merged.to_csv("to_score.tsv", header=False, index=False, sep="\t")

In [None]:
score_df_merged.head(50)

In [None]:
protein_function_df.merge(protein_function_metadata_dict["C"]["to_score_df"], on=["EntryID", "term"]).head()

In [None]:
protein_function_df.merge(protein_function_metadata_dict["C"]["to_score_df"], on=["EntryID", "term"])

In [None]:
merged_df = protein_function_df.merge(score_df_merged, on="EntryID", how="right")

In [None]:
"A0A068FIK2" in protein_function_df["EntryID"].unique()

In [None]:
protein_function_metadata_dict["P"].keys()

In [108]:
merged_df.to_csv("test_csv.csv", index=False, header=True)

KeyboardInterrupt: 

In [83]:
protein_function_df.head(), score_df_merged.head()

(   EntryID        term aspect     emb_0     emb_1     emb_2     emb_3  \
 2   Q5W0B1  GO:0051865      P -0.270628  1.031287 -0.162849 -0.904390   
 3   Q5W0B1  GO:0006275      P -0.270628  1.031287 -0.162849 -0.904390   
 4   Q5W0B1  GO:0006513      P -0.270628  1.031287 -0.162849 -0.904390   
 12  Q8R2Z3  GO:0035429      P  0.394743 -0.708571 -0.151442 -0.196877   
 14  Q8R2Z3  GO:1902358      P  0.394743 -0.708571 -0.151442 -0.196877   
 
        emb_4     emb_5     emb_6     emb_7     emb_8     emb_9    emb_10  \
 2  -0.347545 -0.220612  0.036551 -0.221058  0.191759  0.372100  0.575317   
 3  -0.347545 -0.220612  0.036551 -0.221058  0.191759  0.372100  0.575317   
 4  -0.347545 -0.220612  0.036551 -0.221058  0.191759  0.372100  0.575317   
 12  0.396844 -0.341860  0.126673 -0.512127  0.131130  0.023955  0.306699   
 14  0.396844 -0.341860  0.126673 -0.512127  0.131130  0.023955  0.306699   
 
       emb_11    emb_12    emb_13    emb_14    emb_15  taxonomyID  
 2  -0.200968 -0.11717

In [None]:
def create_go_embeddings_optimized(obo_path, go_terms, embed_dim=16, hidden_dim=32, out_dim=16, epochs=50):

    print(" Loading Gene Ontology...")
    graph = obonet.read_obo(obo_path)


    edges = pd.DataFrame([
        {'source': u, 'target': v, 'relation': data.get('relation', 'is_a')}
        for u, v, data in graph.edges(data=True)
    ])

    relevant_edges = edges[
        edges['source'].isin(go_terms) | edges['target'].isin(go_terms)
    ].reset_index(drop=True)

    nodes = pd.DataFrame({'id': list(set(relevant_edges['source']).union(relevant_edges['target']))})
    nodes['node_idx'] = range(len(nodes))
    node2idx = dict(zip(nodes['id'], nodes['node_idx']))

    edge_index = torch.tensor([
        [node2idx[s] for s in relevant_edges['source']],
        [node2idx[t] for t in relevant_edges['target']]
    ], dtype=torch.long)

    num_nodes = len(nodes)
    print(f"Using {num_nodes} GO terms and {len(relevant_edges)} edges")


    x = torch.randn((num_nodes, embed_dim), dtype=torch.float32)


    class SimpleGCN(nn.Module):
        def __init__(self, in_dim, hidden_dim, out_dim):
            super(SimpleGCN, self).__init__()
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, out_dim)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = self.conv2(x, edge_index)
            return x

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGCN(embed_dim, hidden_dim, out_dim).to(device)


    x = x.to(device)
    edge_index = edge_index.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    data = Data(x=x, edge_index=edge_index)

    print(f"Training on device: {device}")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        embeddings = model(data.x, data.edge_index)
        # Inner product decoder
        recon = torch.sigmoid(torch.matmul(embeddings, embeddings.T))
        adj_true = torch.zeros_like(recon)
        adj_true[data.edge_index[0], data.edge_index[1]] = 1.0

        loss = F.binary_cross_entropy(recon, adj_true)
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch:03d} | Loss: {loss.item():.4f}")


    with torch.no_grad():
        node_embeddings = model(data.x, data.edge_index).cpu().numpy()

    del model, x, data, recon, adj_true
    torch.cuda.empty_cache()
    gc.collect()

    col_names = [f"go_emb_{i}" for i in range(node_embeddings.shape[1])]

    emb_df = pd.DataFrame(node_embeddings, index=nodes['id'], columns=col_names)
    print(f"Created embeddings for {len(emb_df)} GO terms")
    return emb_df





In [None]:
sampled_data = sample_tsv(term_path, sample_frac=0.05)
go_terms = sampled_data['term'].unique()
embeddings_df = create_go_embeddings_optimized(obo_path, go_terms)
seq_df = extract_sequences(fasta_path, sampled_data['EntryID'])

### Combine GO embedding and PLM embedding into one dataset

In [None]:
def combine_go_protein_embeddings(sampled_data, go_emb_df, prot_emb_df):

    combined = sampled_data.merge(go_emb_df, how='left', left_on='term', right_index=True)

    combined = combined.merge(prot_emb_df, how='left', left_on='EntryID', right_index=True)

    return combined


multimodal_df = combine_go_protein_embeddings(sampled_data, embeddings_df, prot_emb_df)

print("Multimodal feature dataframe shape:", multimodal_df.shape)
print(multimodal_df.head())