In [1]:
!pip install torch==2.4.0
!pip install torch-geometric==2.4.0
!pip install biopython
!pip install obonet
!pip install networkx
!pip install pandas
!pip install numpy
!pip install matplotlib
!pip install seaborn
!pip install scipy
!pip install scikit-learn
!pip install fair-esm
import torch
torch_version = str(torch.__version__)
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
!pip install torch-scatter -f $scatter_src
!pip install torch-sparse -f $sparse_src

Collecting torch==2.4.0
  Downloading torch-2.4.0-cp312-cp312-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.4.0)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.4.0)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.4.0)
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-many

In [16]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # avoid fragmentation
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import Bio
from Bio import SeqIO
import obonet
import gc
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import random
import esm


In [7]:
obo_path = 'go-basic.obo'
fasta_path = 'train_sequences.fasta'
term_path = 'train_terms.tsv'
taxonomy_path = 'train_taxonomy.tsv'

In [8]:
LARGEST_FASTA_SEQ_LEN = 8922
ESM_EMBEDDING_DIM = 320
PCA_TARGET_DIM = 128

In [17]:
term_df = pd.read_csv(term_path, sep='\t')
term_df.head(), len(term_df["term"].unique())

(  EntryID        term aspect
 0  Q5W0B1  GO:0000785      C
 1  Q5W0B1  GO:0004842      F
 2  Q5W0B1  GO:0051865      P
 3  Q5W0B1  GO:0006275      P
 4  Q5W0B1  GO:0006513      P,
 26125)

In [18]:
taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
taxonomy_df.head()

Unnamed: 0,EntryID,taxonomyID
0,A0A0C5B5G6,9606
1,A0JNW5,9606
2,A0JP26,9606
3,A0PK11,9606
4,A1A4S6,9606


In [19]:
# def get_processed_fasta_df(fasta_data, term_df):
#     fasta_dict_list = []
#     term_set = set(term_df.tolist())
#     for fasta_seq in fasta_data:
#         entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
#         if entry in term_set:
#             fasta_dict_list.append({
#                 "EntryID": entry,
#                 "fasta_sequence": str(fasta_seq.seq)
#             })

#     return pd.DataFrame(fasta_dict_list)


def sample_tsv(tsv_df, sample_frac=0.05, random_state=42):
    """
    Read a TSV file and sample based on unique EntryID.
    Pulls in all associated rows for sampled EntryIDs.
    """
    df = tsv_df
    unique_ids = df['EntryID'].unique()
    sample_size = max(1, int(len(unique_ids) * sample_frac))
    sampled_ids = random.sample(list(unique_ids), sample_size)
    sampled_df = df[df['EntryID'].isin(sampled_ids)]
    print(f"Sampled {len(sampled_df)} rows from {len(unique_ids)} unique EntryIDs")
    return sampled_df


def get_processed_fasta_df(fasta_data):
    fasta_dict_list = []
    for fasta_seq in fasta_data:
        entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
        fasta_dict_list.append({
                "EntryID": entry,
                "fasta_sequence": str(fasta_seq.seq)
            })

    return pd.DataFrame(fasta_dict_list)





# Get Merged DF FULL (includes batching to fit embeddings into GPU and offloads to CPU)

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
model = model.to(device)
model = model.half()
model.eval()
batch_converter = alphabet.get_batch_converter()

def generate_protein_embeddings_esm_batch(seq_df, seq_col='Sequence', entryid_col='EntryID',
                                              target_dim=16, batch_size=1, use_fp16=True):
    """
    Memory-optimized ESM embedding generation for proteins.
    Processes small batches and moves embeddings to CPU immediately.
    """
    sequences = seq_df[seq_col].tolist()
    entry_ids = seq_df[entryid_col].tolist()
    print(f"sequences: {len(sequences)} entry_ids {len(entry_ids)}")

    all_embeddings = []
    # print(len(sequences))
    dtype = torch.float16 if use_fp16 else torch.float32
    for i in range(0, len(sequences), batch_size):
        try:
            batch_seqs = sequences[i:i+batch_size]
            batch_labels = entry_ids[i:i+batch_size]
            if len(batch_seqs[0]) > LARGEST_FASTA_SEQ_LEN:
                print(f"length of batch seqs is: {len(batch_seqs[0])}")
                # Split the sequence to be max length LARGEST_FASTA_SEQ_LEN (will always be len 1 batch), this helps avoid the loss of information
                curr_seq = batch_seqs[0]
                embeddings_list = []
                for i in range(0, len( batch_seqs[0]), LARGEST_FASTA_SEQ_LEN):
                  curr_seq = batch_seqs[0][i:i+LARGEST_FASTA_SEQ_LEN]
                  curr_seq_embedding = get_sequence_embedding_esm(batch_labels, [curr_seq], dtype)
                  embeddings_list.append(curr_seq_embedding)
                # Average the embeddings_list since the long sequences were broken down into smaller chunks
                seq_embeddings = np.mean(embeddings_list, axis=0)

            else:
              seq_embeddings = get_sequence_embedding_esm(batch_labels, batch_seqs, dtype)

            all_embeddings.append(seq_embeddings)
            del seq_embeddings
        except Exception as e:
            print(e)
            print(len(batch_seqs[0]))
            print(batch_seqs)


    raw_embeddings = np.vstack(all_embeddings)
    # Mitigate memory constraints
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.reset_peak_memory_stats()

    col_names = [f"prot_emb_{i}" for i in range(320)]
    emb_df = pd.DataFrame(raw_embeddings, index=entry_ids, columns=col_names)
    emb_df.index.name = entryid_col

    return emb_df


def get_sequence_embedding_esm(batch_labels, batch_seqs, dtype):
    batch_data = [(label, seq) for label, seq in zip(batch_labels, batch_seqs)]
    _, _, batch_tokens = batch_converter(batch_data)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
      with torch.autocast(device_type="cuda", dtype=dtype):
          results = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=False)
          token_embeddings = results["representations"][model.num_layers]  # (B, L, D)
          # Mean pool over sequence length
          attention_mask = batch_tokens != alphabet.padding_idx
          masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
          seq_lengths = attention_mask.sum(dim=1).unsqueeze(-1)
          seq_embeddings = (masked_embeddings.sum(dim=1) / seq_lengths).cpu().float().numpy()
          del batch_tokens, token_embeddings, masked_embeddings, results, attention_mask
          return seq_embeddings


def apply_pca_to_esm_embeddings(esm_embeddings_df:pd.DataFrame, target_dim=16):
    '''
    Helper method
    '''
    emb_cols = [c for c in esm_embeddings_df.columns if c.startswith("prot_emb")]

    tensor_list = [
        torch.tensor(row[emb_cols].values.astype("float16"), dtype=torch.float16)
        for _, row in esm_embeddings_df.iterrows()
    ]

    pca = PCA(n_components=PCA_TARGET_DIM)
    tensor_list_transformed = pca.fit_transform(tensor_list)

    # Intialize dictionary where we use EntryID for joining with all the other data
    embeddings_dict = { "EntryID": [] }
    # Initialize empty PCA embeddings for the pd.DataFrame
    for i in range(PCA_TARGET_DIM):
        embeddings_dict[f"emb_{i}"] =  []

    # Append EntryID's and embeddings to the dictionary for the dataframe
    for i, r in enumerate(tensor_list_transformed):
        curr_entry_id = esm_embeddings_df["EntryID"][i]
        embeddings_dict["EntryID"].append(curr_entry_id)
        for j in range(PCA_TARGET_DIM):
            embeddings_dict[f"emb_{j}"].append(r[j])

    pca_embeddings_df = pd.DataFrame(embeddings_dict)
    return pca_embeddings_df




In [21]:
file_name = "fasta_embeddings_final (1).csv"
embeddings_processed = True

def get_merged_df_full(file_name, batch_size=250, embeddings_processed=False):
    """
    Merge term.tsv, fasta data, taxonomy data as well as nodes in the obo graph.
    """
    term_df = pd.read_csv(term_path, sep='\t')
    term_df = term_df
    taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
    fasta_data = list(SeqIO.parse(fasta_path, "fasta"))

    entry_ids = list(term_df['EntryID'])
    all_batches = []
    if not embeddings_processed:
      for i in range(0, len(entry_ids), batch_size):
          total_processed = i
          print(f"Total processed: {total_processed}, {i//batch_size} batch")

          # Batch the EntryIDs
          entry_batch = list(set(entry_ids[i:i+batch_size]))

          curr_term_df = term_df[i:i+batch_size]
          fasta_df_batch = get_processed_fasta_df(fasta_data, curr_term_df['EntryID'])
          fasta_emb_df_batch = generate_protein_embeddings_esm_batch(
              fasta_df_batch,
              "fasta_sequence"
          )


          all_batches.append(fasta_emb_df_batch)

          full_df = pd.concat(all_batches, ignore_index=False)
          full_df.to_csv("fasta_embeddings.csv", index=True)

    else:
        full_df = pd.read_csv(file_name)
        print(full_df.head())

    fasta_emb_df = apply_pca_to_esm_embeddings(full_df)

    # TODO add embeddings in getting merged_df
    merged_df = pd.merge(term_df, fasta_emb_df, on="EntryID", how='left')
    merged_df = pd.merge(merged_df, taxonomy_df, on="EntryID", how="left")
    graph = obonet.read_obo(obo_path)
    edges_list = []
    for node_id, data in graph.nodes(data=True):
        for parent_id in data.get("is_a", []):
            edges_list.append({
                    "term": node_id,
                    "parent": parent_id,
                    "relation": data.get('relation', 'is_a'),
                    "name": data["name"],
                    "namespace": data["namespace"],
                    "def": data["def"],
                    "synonym": data.get("synonym", [])

                })
    edges_df = pd.DataFrame(edges_list)
    return merged_df, edges_df

protein_function_df, graph_df = get_merged_df_full(file_name, embeddings_processed=embeddings_processed)
len(protein_function_df)

  EntryID  prot_emb_0  prot_emb_1  prot_emb_2  prot_emb_3  prot_emb_4  \
0  P86164    0.063601    0.090173    0.272553    0.045446   -0.024562   
1  P84910    0.164426   -0.128475    0.244240    0.068856   -0.031558   
2  P83012    0.126772   -0.093400    0.216583    0.056300   -0.058426   
3  P83246    0.039001   -0.228608    0.207255    0.138270    0.081464   
4  P86133   -0.063519    0.141610    0.168906   -0.003960   -0.064855   

   prot_emb_5  prot_emb_6  prot_emb_7  prot_emb_8  ...  prot_emb_310  \
0   -0.087723   -0.097689   -0.002640   -0.085523  ...      0.121198   
1   -0.102609   -0.108406   -0.071888   -0.022300  ...      0.094619   
2    0.006500   -0.052936   -0.012787   -0.090959  ...      0.024688   
3   -0.122376   -0.169814   -0.090188   -0.169230  ...     -0.015823   
4   -0.166769    0.018469    0.077215    0.058494  ...      0.147194   

   prot_emb_311  prot_emb_312  prot_emb_313  prot_emb_314  prot_emb_315  \
0      0.150969      0.038082     -0.035188      0.10

537027

In [22]:
graph = obonet.read_obo(obo_path)
edges_list = []
for node_id, data in graph.nodes(data=True):
        for parent_id in data.get("is_a", []):
            edges_list.append({
                    "term": node_id,
                    "parent": parent_id,
                    "name": data["name"],
                    "namespace": data["namespace"],
                    "def": data["def"],
                    "synonym": data.get("synonym", [])
                })
edges_df = pd.DataFrame(edges_list)

In [23]:
ALL_SUBONTOLOGIES = graph_df["namespace"].unique()
graph_df.head(
)


Unnamed: 0,term,parent,relation,name,namespace,def,synonym
0,GO:0000001,GO:0048308,is_a,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","[""mitochondrial inheritance"" EXACT []]"
1,GO:0000001,GO:0048311,is_a,mitochondrion inheritance,biological_process,"""The distribution of mitochondria, including t...","[""mitochondrial inheritance"" EXACT []]"
2,GO:0000002,GO:0007005,is_a,mitochondrial genome maintenance,biological_process,"""The maintenance of the structure and integrit...",[]
3,GO:0000006,GO:0005385,is_a,high-affinity zinc transmembrane transporter a...,molecular_function,"""Enables the transfer of zinc ions (Zn2+) from...","[""high affinity zinc uptake transmembrane tran..."
4,GO:0000007,GO:0005385,is_a,low-affinity zinc ion transmembrane transporte...,molecular_function,"""Enables the transfer of a solute or solutes f...",[]


In [24]:
protein_function_df.head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127,taxonomyID
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
1,Q5W0B1,GO:0004842,F,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
2,Q5W0B1,GO:0051865,P,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
3,Q5W0B1,GO:0006275,P,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
4,Q5W0B1,GO:0006513,P,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606


In [25]:
protein_function_df[protein_function_df["EntryID"]=="Q63871"]

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127,taxonomyID
124839,Q63871,GO:0006366,P,-0.59942,-0.373074,-0.409588,0.142399,0.107085,-0.563983,0.066212,...,0.027711,0.069348,0.054291,-0.025748,0.043843,-0.032489,-0.005283,0.030139,0.021205,10090
124840,Q63871,GO:0005654,C,-0.59942,-0.373074,-0.409588,0.142399,0.107085,-0.563983,0.066212,...,0.027711,0.069348,0.054291,-0.025748,0.043843,-0.032489,-0.005283,0.030139,0.021205,10090


In [26]:
protein_function_subontology_dict = {
    'C': protein_function_df[protein_function_df["aspect"]=='C'],
    'F': protein_function_df[protein_function_df["aspect"]=='F'],
    'P': protein_function_df[protein_function_df["aspect"]=='P']
}

protein_function_unique_terms_subontology_dict = {
   k: v["term"].unique() for k,v in protein_function_subontology_dict.items()
}



In [27]:
protein_function_subontology_dict['C'].head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127,taxonomyID
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
7,Q3EC77,GO:0000138,C,-0.124728,-0.263022,0.456481,-0.181721,0.070037,-0.144072,-0.371443,...,0.003247,0.080996,-0.016898,-0.023454,-0.023529,0.002127,-0.002366,0.018679,0.05698,3702
8,Q3EC77,GO:0005794,C,-0.124728,-0.263022,0.456481,-0.181721,0.070037,-0.144072,-0.371443,...,0.003247,0.080996,-0.016898,-0.023454,-0.023529,0.002127,-0.002366,0.018679,0.05698,3702
13,Q8R2Z3,GO:0016323,C,0.394743,-0.708571,0.151442,-0.196877,-0.396845,0.341863,0.126681,...,0.123511,0.04763,-0.025337,0.018063,-0.028023,0.112842,6.1e-05,-0.036657,-0.047793,10090
21,Q8R2Z3,GO:0016020,C,0.394743,-0.708571,0.151442,-0.196877,-0.396845,0.341863,0.126681,...,0.123511,0.04763,-0.025337,0.018063,-0.028023,0.112842,6.1e-05,-0.036657,-0.047793,10090


### Create set of embedding from the graph edges using GCN

In [28]:
class SimpleGCN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(SimpleGCN, self).__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

def init_subontology_GCNs(embed_dim=16, hidden_dim=32, out_dim=16):
    graph_subontology_dict = {
       subontology: SimpleGCN(embed_dim, hidden_dim, out_dim) for subontology in ALL_SUBONTOLOGIES
    }
    return graph_subontology_dict

def get_subontology_graph_dfs(graph_df):
    subontology_graph_dfs = {
       subontology: graph_df[graph_df["namespace"]==subontology] for subontology in ALL_SUBONTOLOGIES
    }
    return subontology_graph_dfs

subontology_GCNs, subontology_graph_dfs = init_subontology_GCNs(), get_subontology_graph_dfs(graph_df)

In [29]:
protein_function_df.head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127,taxonomyID
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
1,Q5W0B1,GO:0004842,F,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
2,Q5W0B1,GO:0051865,P,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
3,Q5W0B1,GO:0006275,P,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606
4,Q5W0B1,GO:0006513,P,-0.270628,1.031287,0.162849,-0.90439,0.347545,0.220612,0.036554,...,-0.012309,-0.024481,-0.027564,0.020984,0.019852,-0.015266,0.007087,-0.021933,-0.012099,9606


In [30]:
for k, subontology_df in protein_function_subontology_dict.items():
  print(f"Subontology: {k} {subontology_df.size}")

Subontology: C 20825640
Subontology: F 16955664
Subontology: P 33106260


In [31]:
N_TERMS_TO_PREDICT = 200
protein_function_unique_metadata_dict = {

}



# Perform a biased term samplings (papers use common terms and terms that are somewhat frequent the most, we do not want to include too many terms that are infrequent)
def group_terms_and_aspects(df, N_TERMS, subontology, random_state=42):
    term_counts = df["term"].value_counts()
    total_terms = len(term_counts)
    top_cut = N_TERMS

    top_terms = term_counts.index[:top_cut]
    # mid_terms = term_counts.index[top_cut:mid_cut]
    # bottom_terms = term_counts.index[mid_cut:]
    n_top = int(N_TERMS * 1)
    print(n_top)
    # n_mid = int(N_TERMS * 0.175)
    # n_bottom = N_TERMS - n_top - n_mid

    rng = np.random.default_rng(random_state)

    sampled_top = rng.choice(top_terms, size=min(n_top, len(top_terms)), replace=False)
    # sampled_mid = rng.choice(mid_terms, size=min(n_mid, len(mid_terms)), replace=False)
    # sampled_bottom = rng.choice(bottom_terms, size=min(n_bottom, len(bottom_terms)), replace=False)
    selected_terms = set(sampled_top)

    df_filtered = df[df["term"].isin(selected_terms)]
    protein_function_unique_metadata_dict[subontology] = df_filtered["term"].unique()
    result = (
        df_filtered
            .groupby("EntryID")
            .agg({
                "term": list,
                **{c: "first" for c in df.columns if c.startswith("emb_")},
                "taxonomyID": "first"
            })
            .rename(columns={"term": "output_terms"})
            .reset_index()
    )
    return result
protein_function_grouped_subontology_dict = {
   k: group_terms_and_aspects(v, N_TERMS_TO_PREDICT, k) for k,v in protein_function_subontology_dict.items()
}

200
200
200


In [32]:
protein_function_grouped_subontology_dict["C"].head()

Unnamed: 0,EntryID,output_terms,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,...,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127,taxonomyID
0,A0A023PZB3,[GO:0005739],-0.892717,0.08732,0.119613,-0.446892,-0.295609,-0.060304,0.074365,0.377238,...,-0.075837,-0.070641,0.023028,-0.048751,0.032029,0.026451,0.004967,0.016969,0.019651,559292
1,A0A024RBG1,[GO:0005829],0.975543,-0.338469,0.000139,0.019703,0.655574,-0.512959,-0.308127,-0.002441,...,-0.067811,0.042249,-0.098216,-0.044835,-0.072517,-0.078007,0.059016,0.10286,-0.02172,9606
2,A0A059TC02,[GO:0005737],1.805741,-0.368449,-0.109841,-0.028463,0.152263,-0.315594,-0.052553,0.193297,...,-0.01584,-0.037659,-0.065485,-0.019915,-0.006994,0.021847,-0.019371,0.00967,-0.001809,4102
3,A0A060A682,[GO:0005911],-0.724893,0.047931,-0.15265,-0.801221,-1.21119,-0.523571,0.394052,-0.33652,...,0.039839,-0.015549,0.046384,-0.103141,0.07483,-0.005365,0.010642,0.082668,0.013116,5911
4,A0A060X6Z0,[GO:0043204],-0.878039,-0.243687,-0.036863,0.042023,0.337301,0.257103,0.026951,-0.078867,...,-0.008244,0.018263,0.008246,0.044083,0.003077,-0.003395,0.00396,-0.014217,-0.001287,8022


In [33]:
protein_function_grouped_subontology_dict["F"].head()

Unnamed: 0,EntryID,output_terms,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,...,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127,taxonomyID
0,A0A024RBG1,[GO:0005515],0.975543,-0.338469,0.000139,0.019703,0.655574,-0.512959,-0.308127,-0.002441,...,-0.067811,0.042249,-0.098216,-0.044835,-0.072517,-0.078007,0.059016,0.10286,-0.02172,9606
1,A0A044RE18,[GO:0004252],0.33807,-0.25307,0.210247,-0.09247,-0.242525,-0.463352,-0.359207,-0.274323,...,-0.045411,0.174499,-0.060748,-0.031365,-0.029767,0.061553,0.06213,-0.006417,0.04848,6282
2,A0A060D764,[GO:0005515],-0.790027,1.056916,0.211493,0.506878,-0.409526,0.396717,-0.750691,-0.10897,...,-0.03243,-0.027425,-0.003981,0.004307,-0.048944,-0.03985,0.018063,-0.023037,0.057817,4577
3,A0A060KY90,[GO:0043565],-0.36261,0.72849,-0.056367,-0.369524,0.031498,0.169136,-0.269536,0.076936,...,-0.028762,-0.004668,0.008622,0.000472,-0.034791,-0.004167,0.025103,-5.1e-05,0.005016,4081
4,A0A061I403,"[GO:0051087, GO:0030544]",-0.508128,-0.23469,0.115922,0.075436,0.298544,0.320581,0.107912,-0.197371,...,-0.001639,0.018303,0.018535,0.028711,-0.028406,-0.06231,-0.008782,-0.002463,0.006375,10029


In [34]:
protein_function_grouped_subontology_dict["P"].head()

Unnamed: 0,EntryID,output_terms,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,...,emb_119,emb_120,emb_121,emb_122,emb_123,emb_124,emb_125,emb_126,emb_127,taxonomyID
0,A0A023PXP4,[GO:0006974],-0.88985,-0.035833,1.037647,-0.215612,0.237554,-0.068306,-0.006166,-0.121834,...,0.09116,-0.049568,0.098255,0.116718,-0.064941,-0.047116,0.063442,0.016602,-0.066141,559292
1,A0A059TC02,[GO:0007623],1.805741,-0.368449,-0.109841,-0.028463,0.152263,-0.315594,-0.052553,0.193297,...,-0.01584,-0.037659,-0.065485,-0.019915,-0.006994,0.021847,-0.019371,0.00967,-0.001809,4102
2,A0A060D764,[GO:0045893],-0.790027,1.056916,0.211493,0.506878,-0.409526,0.396717,-0.750691,-0.10897,...,-0.03243,-0.027425,-0.003981,0.004307,-0.048944,-0.03985,0.018063,-0.023037,0.057817,4577
3,A0A060KY90,[GO:0045893],-0.36261,0.72849,-0.056367,-0.369524,0.031498,0.169136,-0.269536,0.076936,...,-0.028762,-0.004668,0.008622,0.000472,-0.034791,-0.004167,0.025103,-5.1e-05,0.005016,4081
4,A0A061ACU2,[GO:0030317],-0.8079,-0.536004,0.286739,-0.372638,-0.565446,0.436029,-0.144215,0.011856,...,0.037371,0.023713,0.106557,-0.010648,-0.018119,0.034135,-0.011705,0.024707,0.028981,6239


In [35]:
for k, df in protein_function_grouped_subontology_dict.items():
    protein_function_grouped_subontology_dict[k] = df.sample(frac=1, random_state=42).reset_index(drop=True)

In [36]:
from sklearn.model_selection import train_test_split
# Columns for Training and Columns for Testing
PREDICTORS = [f"emb_{i}" for i in range(PCA_TARGET_DIM)]
PREDICTORS.append("EntryID")

OUTPUTS = ['output_terms']
subontology_train_val_test_dic = {}
for k, df in protein_function_grouped_subontology_dict.items():
    X, y = df[PREDICTORS].values, df[OUTPUTS].iloc[:, 0].tolist()
    X_train_all, X_test, y_train_all, y_test = train_test_split(X, y, test_size=0.05, random_state=42)
    # Perform a second split for validation set for finetuning our model
    X_train, X_val, y_train, y_val = train_test_split(X_train_all, y_train_all, test_size=0.1, random_state=42)
    protein_function_grouped_subontology_dict[k] = X_train, X_val, y_train, y_val, X_test, y_test


In [37]:
protein_function_metadata_dict = {
    "C": None,
    "F": None,
    "P": None
}
for k,(X_train, X_val, y_train, y_val, X_test, y_test) in protein_function_grouped_subontology_dict.items():
    X_train_entry_ids = [data_row[-1] for data_row in X_train]
    X_val_entry_ids = [data_row[-1] for data_row in X_val]
    X_test_entry_ids = [data_row[-1] for data_row in X_test]
    X_train = np.array([data_row[:-1] for data_row in X_train])
    X_val = np.array([data_row[:-1] for data_row in X_val])
    X_test = np.array([data_row[:-1] for data_row in X_test])
    protein_function_metadata_dict[k] = {
    "X_train": X_train,
    "X_val": X_val,
    "y_train": y_train,
    "y_val": y_val,
    "X_test": X_test,
    "y_test": y_test,
    "X_train_entry_ids": X_train_entry_ids,
    "X_val_entry_ids": X_val_entry_ids,
    "X_test_entry_ids": X_test_entry_ids
}

In [38]:
protein_function_metadata_dict['C'].keys()

dict_keys(['X_train', 'X_val', 'y_train', 'y_val', 'X_test', 'y_test', 'X_train_entry_ids', 'X_val_entry_ids', 'X_test_entry_ids'])

In [39]:
from sklearn.preprocessing import MultiLabelBinarizer
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    subontology_metadata_dict["unique_terms"] = protein_function_unique_metadata_dict[k]
    print(len(protein_function_unique_metadata_dict[k]))
    term_to_index = {term: i for i, term in enumerate(subontology_metadata_dict["unique_terms"])}

    mlb = MultiLabelBinarizer(classes=subontology_metadata_dict["unique_terms"])
    y_train_transformed = mlb.fit_transform(subontology_metadata_dict["y_train"])
    y_val_transformed = mlb.transform(subontology_metadata_dict["y_val"])
    # print(mlb.classes_)
    from sklearn.impute import SimpleImputer
    imputer = SimpleImputer(strategy="constant", fill_value=0)
    X_train_imputed = imputer.fit_transform(subontology_metadata_dict["X_train"])
    X_test_imputed = imputer.transform(subontology_metadata_dict["X_test"])
    X_val_imputed = imputer.transform(subontology_metadata_dict["X_val"])
    subontology_metadata_dict["y_train_transformed"] = y_train_transformed
    subontology_metadata_dict["y_val_transformed"] = y_val_transformed
    subontology_metadata_dict["X_train_imputed"] = X_train_imputed
    subontology_metadata_dict["X_val_imputed"] = X_val_imputed
    subontology_metadata_dict["mlb"] = mlb
    # print(y_train_transformed.shape)


200
200
200


In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

output_dim = N_TERMS_TO_PREDICT
MODEL_OUTPUT_DIM = 32

class WeightedMultiLabelNN(nn.Module):
    def __init__(self, input_dim=PCA_TARGET_DIM, output_dim=output_dim, hidden_dims=[256, 512, 1024], pos_weights=None, return_embeddings=False):
        super().__init__()
        self.pos_weights = pos_weights
        layers = []
        prev_dim = input_dim
        self.return_embeddings = return_embeddings

        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.1))
            prev_dim = hidden_dim

        self.features_sequential = nn.Sequential(*layers)
        self.logits_linear = nn.Linear(prev_dim, output_dim)

    def forward(self, x):
        features = self.features_sequential(x)
        if self.return_embeddings:
            return features
        logits = self.logits_linear(features)
        return logits


class WeightedMacroSoftF1Loss(nn.Module):
    def __init__(self, class_weights=None, epsilon=1e-7):
        super().__init__()
        self.epsilon = epsilon
        if class_weights is not None:
            self.register_buffer('weights', torch.tensor(class_weights))
        else:
            self.weights = None

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)

        tp = (probs * targets).sum(dim=0)
        fp = (probs * (1 - targets)).sum(dim=0)
        fn = ((1 - probs) * targets).sum(dim=0)

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        f1 = 2 * precision * recall / (precision + recall + self.epsilon)
        f1 = torch.clamp(f1, 0, 1)

        if self.weights is not None:
            weights = self.weights.to(logits.device)
            macro_f1 = (f1 * weights).sum() / weights.sum()
        else:
            macro_f1 = f1.mean()

        return 1 - macro_f1



In [41]:
class ProteinEmbeddingsDataset(Dataset):
    def __init__(self, features, labels, total_labels=output_dim, label_metadata=None):
        self.features = features.astype(np.float32)
        self.labels = labels[:, :total_labels].astype(np.float32)
        self.label_metadata = label_metadata

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = torch.FloatTensor(self.features[idx])
        y = torch.FloatTensor(self.labels[idx])
        return x, y

In [42]:
for k, subontology_metadat_dict in protein_function_metadata_dict.items():
    protein_embeddings_dataset_train = ProteinEmbeddingsDataset(subontology_metadat_dict["X_train_imputed"],subontology_metadat_dict["y_train_transformed"], total_labels=output_dim)
    protein_embeddings_dataset_val = ProteinEmbeddingsDataset(subontology_metadat_dict["X_val_imputed"], subontology_metadat_dict["y_val_transformed"], total_labels=output_dim)
    subontology_metadat_dict["protein_embeddings_dataset_train"] = protein_embeddings_dataset_train
    subontology_metadat_dict["protein_embeddings_dataset_val"] = protein_embeddings_dataset_val

In [43]:
from sklearn.metrics import f1_score

def evaluate_f1(model, val_loader, device, threshold=0.5):
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for features, targets in val_loader:
            features = features.to(device)
            targets = targets.to(device)

            logits = model(features)
            probs = torch.sigmoid(logits)

            preds = (probs > threshold).float()

            all_preds.append(preds.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    num_labels = all_targets.shape[1]
    f1_scores = []

    for i in range(num_labels):
        try:
            f1 = f1_score(all_targets[:, i], all_preds[:, i], zero_division=0   )
        except ValueError:
            f1 = 0.0
        f1_scores.append(f1)

    f1_scores = np.array(f1_scores)
    best_idx = np.argsort(f1_scores)[-3:]
    worst_idx = np.argsort(f1_scores)[:3]

    return f1_scores, best_idx, worst_idx


In [44]:

class FocalLossLogits(nn.Module):
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets, pos_weights=None):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean', pos_weight=pos_weights)
        preds = torch.sigmoid(logits)
        p_t = targets * preds + (1 - targets) * (1 - preds)
        focal_weight = (1 - p_t) **self.gamma

        if self.alpha is not None:
            if isinstance(self.alpha, torch.Tensor):
                alpha_t = self.alpha.view(1, -1).to(logits.device)
            else:
                alpha_t = torch.tensor(self.alpha, device=logits.device).view(1, -1)

            alpha_factor = targets * alpha_t + (1 - targets) * (1 - alpha_t)
            focal_weight = alpha_factor * focal_weight

        loss = focal_weight * bce

        if self.reduction == 'mean':
            return loss.mean()
        return loss

In [243]:
def obtain_label_metadata(curr_y, output_dim):
    label_metadata = {
    "pos_weights": []
    }
    print(len(curr_y[0]))
    for i in range(output_dim):
        if not i % 1000: print(i)
        y_i = curr_y[:, i]
        pos = np.sum(y_i == 1)
        neg = np.sum(y_i == 0)
        scale_pos_weight = neg / pos if pos > 0 else 1.0
        label_metadata["pos_weights"].append(scale_pos_weight)
    return label_metadata


def soft_f1_loss(y_pred, y_true, epsilon=1e-7):
    y_pred = torch.sigmoid(y_pred)
    y_pred = y_pred.view(-1)
    y_true = y_true.view(-1)
    tp = (y_pred * y_true).sum()
    fp = (y_pred * (1 - y_true)).sum()
    fn = ((1 - y_pred) * y_true).sum()

    soft_f1 = 2 * tp / (2 * tp + fp + fn + epsilon)
    return 1 - soft_f1


def combined_bce_soft_f1_loss(logits, targets, pos_weights=None, alpha=0.1):
    bce = F.binary_cross_entropy_with_logits(
        logits, targets, pos_weight=pos_weights, reduction="mean"
    )

    f1 = soft_f1_loss(logits, targets)
    # loss = alpha * bce + (1 - alpha) * f1
    f1_soft_weighted = WeightedMacroSoftF1Loss(pos_weights)
    focal_loss = FocalLossLogits()
    loss = f1_soft_weighted(logits, targets)

    return loss


In [328]:
import warnings
warnings.filterwarnings("ignore")
LR = 1e-5
EPOCHS = 25

def train_model(model, train_loader, val_loader, pos_weights, num_epochs=EPOCHS):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    pos_weights = pos_weights.to(device)

    optimizer = optim.Adam(model.parameters(), lr=LR)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for features, targets in train_loader:
            features = features.to(device)
            targets = targets.to(device)
            logits = model(features)

            loss = combined_bce_soft_f1_loss(logits, targets, pos_weights)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()



            total_loss += loss.item()

        total_val_loss = 0.0

        with torch.no_grad():
            for val_features, val_targets in val_loader:
                val_features = val_features.to(device)
                val_targets = val_targets.to(device)
                val_logits = model(val_features)
                val_loss = combined_bce_soft_f1_loss(val_logits, val_targets, pos_weights)
                total_val_loss += val_loss.item()

        avg_loss = total_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)


        # Validate using the validation dataloader
        f1_scores, best_idx, worst_idx = evaluate_f1(model, val_loader, device)
        print(f"curr epoch {epoch+1}/{num_epochs}] loss: {avg_loss} val_loss {avg_val_loss}")
        print("Validation op 3 labels (best F1):")
        for idx in reversed(best_idx):
            print(f"Label {idx}: F1 = {f1_scores[idx]}")

        print("Validation Bottom 3 labels (worst F1):")
        for idx in worst_idx:
            print(f"Label {idx}: F1 = {f1_scores[idx]}")

        print(f"Validation average f1 score: {np.mean(f1_scores)}")

    return model

In [245]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
      # print(len(subontology_metadata_dict["y_train_transformed"]))
      # print(subontology_metadata_dict["y_train_transformed"])
      # print(subontology_metadata_dict["y_train_transformed"].shape)
      label_metadata = obtain_label_metadata(subontology_metadata_dict["y_train_transformed"], output_dim)
      subontology_metadata_dict["label_metadata"] = label_metadata
      curr_model = WeightedMultiLabelNN()
      data_loader = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_train"], batch_size=128)
      data_loader_val = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_val"], batch_size=len(subontology_metadat_dict["X_val_imputed"]))
      subontology_metadata_dict["curr_model"] = curr_model
      subontology_metadata_dict["data_loader"] = data_loader
      subontology_metadata_dict["data_loader_val"] = data_loader_val
      print(f"\n\nTRAINING: {k} model \n\n")
      trained_model_nn = train_model(curr_model, data_loader, data_loader_val, torch.tensor(label_metadata["pos_weights"][:output_dim]))
      subontology_metadata_dict["trained_model_nn"] = trained_model_nn


200
0


TRAINING: C model 


curr epoch 1/25] loss: 0.9780083396743914 val_loss 0.9387632720521063
Validation op 3 labels (best F1):
Label 158: F1 = 0.6415094339622641
Label 17: F1 = 0.5546786922209695
Label 7: F1 = 0.5508707607699358
Validation Bottom 3 labels (worst F1):
Label 111: F1 = 0.0
Label 123: F1 = 0.0
Label 187: F1 = 0.0
Validation average f1 score: 0.11817551989925264
curr epoch 2/25] loss: 0.9480771757821164 val_loss 0.9094714894637064
Validation op 3 labels (best F1):
Label 158: F1 = 0.6415094339622641
Label 164: F1 = 0.6119402985074627
Label 17: F1 = 0.5711662075298439
Validation Bottom 3 labels (worst F1):
Label 111: F1 = 0.0
Label 123: F1 = 0.0
Label 182: F1 = 0.0
Validation average f1 score: 0.14470395494449115
curr epoch 3/25] loss: 0.9347204440132673 val_loss 0.8911447963109518
Validation op 3 labels (best F1):
Label 158: F1 = 0.6551724137931034
Label 164: F1 = 0.6323529411764706
Label 17: F1 = 0.5660377358490566
Validation Bottom 3 labels (worst F1):
Label 111: F1 

KeyboardInterrupt: 

In [48]:
def predict(model, dataloader, device):
    model.eval()
    preds = []
    labels = []

    with torch.no_grad():
        for batch in dataloader:
            X, y = batch
            X = X.to(device)
            y = y.to(device)

            logits = model(X)
            out = torch.sigmoid(logits)
            preds.append(out.cpu())
            labels.append(y.cpu())

    preds = torch.cat(preds)
    labels = torch.cat(labels)
    return preds, labels

In [49]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    try:
      y_preds, y_labels = predict(subontology_metadata_dict["trained_model_nn"], subontology_metadata_dict["data_loader_val"], 'cuda')
      subontology_metadata_dict["y_preds"] = y_preds
      subontology_metadata_dict["y_labels"] = y_labels
    except:
      print(k)
      continue


C
F
P


In [50]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
      pred_entry_ids_df = pd.DataFrame({"EntryID": subontology_metadata_dict["X_val_entry_ids"], "prediction": [row.tolist() for row in subontology_metadata_dict["y_preds"]]})
      label_names = subontology_metadata_dict["mlb"].classes_[:output_dim]
      num_labels = len(label_names)

      pred_matrix = np.vstack(pred_entry_ids_df["prediction"].values)

      pred_df = pd.DataFrame(pred_matrix, columns=label_names)

      prediction_df = pd.concat([pred_entry_ids_df[["EntryID"]], pred_df], axis=1)
      go_cols = [c for c in prediction_df.columns if c != "EntryID"]
      to_score_df = prediction_df.melt(id_vars="EntryID", value_vars=go_cols, var_name="term", value_name="score")

      to_score_df = to_score_df.sort_values(["EntryID", "term"]).reset_index(drop=True)
      to_score_df["score"] = to_score_df["score"].round(3)
      to_score_df = to_score_df[to_score_df["score"] != 0]
      subontology_metadata_dict["to_score_df"] = to_score_df
      to_score_df.to_csv(f"to_score_{k}.tsv", header=False, index=False, sep="\t")



KeyError: 'y_preds'

In [168]:
# per insturctions we must submit a file with the EntryID, GO and score while droping 0 scores after rounding to 3 decimal point


In [169]:
len(protein_function_subontology_dict['P'])

250805

In [170]:
len(protein_function_metadata_dict['F']["X_train_entry_ids"])

39000

In [171]:
len(protein_function_metadata_dict[k]["X_val_entry_ids"])

2872

In [172]:
for k, curr_protein_function_df in protein_function_subontology_dict.items():
    ground_truth_score_df = curr_protein_function_df[curr_protein_function_df["EntryID"].isin(protein_function_metadata_dict[k]["X_val_entry_ids"])][["EntryID", "term"]]
    print(f"{k}, {len(ground_truth_score_df)}")
    protein_function_metadata_dict[k]["ground_truth_score_df"] = ground_truth_score_df

C, 14701
F, 10640
P, 17865


In [174]:
def predict(model, dataloader, device):
    model.eval()
    preds = []
    labels = []

    with torch.no_grad():
        for batch in dataloader:
            X, y = batch
            X = X.to(device)
            y = y.to(device)

            logits = model(X)
            out = torch.sigmoid(logits)
            preds.append(out.cpu())
            labels.append(y.cpu())

    preds = torch.cat(preds)
    labels = torch.cat(labels)
    return preds, labels

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    try:
      y_preds, y_labels = predict(subontology_metadata_dict["trained_model_nn"], subontology_metadata_dict["data_loader_val"], 'cuda')
      subontology_metadata_dict["y_preds"] = y_preds
      subontology_metadata_dict["y_labels"] = y_labels
    except:
      print(k)
      continue

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
      pred_entry_ids_df = pd.DataFrame({"EntryID": subontology_metadata_dict["X_val_entry_ids"], "prediction": [row.tolist() for row in subontology_metadata_dict["y_preds"]]})
      label_names = subontology_metadata_dict["mlb"].classes_[:output_dim]
      num_labels = len(label_names)

      pred_matrix = np.vstack(pred_entry_ids_df["prediction"].values)

      pred_df = pd.DataFrame(pred_matrix, columns=label_names)

      prediction_df = pd.concat([pred_entry_ids_df[["EntryID"]], pred_df], axis=1)
      go_cols = [c for c in prediction_df.columns if c != "EntryID"]
      to_score_df = prediction_df.melt(id_vars="EntryID", value_vars=go_cols, var_name="term", value_name="score")

      to_score_df = to_score_df.sort_values(["EntryID", "term"]).reset_index(drop=True)
      to_score_df["score"] = to_score_df["score"].round(3)
      to_score_df = to_score_df[to_score_df["score"] != 0]
      subontology_metadata_dict["to_score_df"] = to_score_df
      to_score_df.to_csv(f"to_score_{k}.tsv", header=False, index=False, sep="\t")

for k, curr_protein_function_df in protein_function_subontology_dict.items():
    ground_truth_score_df = curr_protein_function_df[curr_protein_function_df["EntryID"].isin(protein_function_metadata_dict[k]["X_val_entry_ids"])][["EntryID", "term"]]
    print(f"{k}, {len(ground_truth_score_df)}")
    protein_function_metadata_dict[k]["ground_truth_score_df"] = ground_truth_score_df

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    subontology_metadata_dict["ground_truth_score_df"].to_csv(f"ground_truth_score_{k}.tsv", header=False, index=False, sep="\t")

ground_truth_score_df_merged = pd.concat([
    protein_function_metadata_dict["C"]["ground_truth_score_df"],
    protein_function_metadata_dict["F"]["ground_truth_score_df"],
    protein_function_metadata_dict["P"]["ground_truth_score_df"]
], ignore_index=True)

score_df_merged = pd.concat([
    protein_function_metadata_dict["C"]["to_score_df"],
    protein_function_metadata_dict["F"]["to_score_df"],
    protein_function_metadata_dict["P"]["to_score_df"]
], ignore_index=True)


In [175]:
len(protein_function_metadata_dict['C']["ground_truth_score_df"])

14701

In [176]:
score_df_merged.size, ground_truth_score_df_merged.size

(863703, 86412)

In [177]:
ground_truth_score_df_merged.to_csv("ground_truth_score2.tsv", header=False, index=False, sep="\t")
score_df_merged.to_csv("to_score2.tsv", header=False, index=False, sep="\t")

In [51]:
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx
graph = obonet.read_obo(obo_path)
dataset = Planetoid(root='/tmp/cora', name='Cora')



In [52]:
for n in dataset:
    print(n)
    print(type(n))


Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
<class 'torch_geometric.data.data.Data'>


In [53]:
loader2 = torch_geometric.data.DataLoader(dataset, batch_size=64, shuffle=False)


In [54]:
for b in loader2:
    print(b)

DataBatch(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], batch=[2708], ptr=[2])


In [202]:
    G = obonet.read_obo(obo_path)
    nodes_F = [n for n in G.nodes if G.nodes[n]["namespace"] == "molecular_function"]
    nodes_C = [n for n in G.nodes if G.nodes[n]["namespace"] == "cellular_component"]
    nodes_P = [n for n in G.nodes if G.nodes[n]["namespace"] == "biological_process"]
    G.nodes.data
    G_dict = {
        "F": G.subgraph(nodes_F),
        "C": G.subgraph(nodes_C),
        "P": G.subgraph(nodes_P)
    }

In [264]:
import torch
from torch_geometric.data.data import Data
from torch_geometric.data import Dataset as GDataset
def obo_to_pyg(obo_path):
    G = obonet.read_obo(obo_path)
    nodes_F = [n for n in G.nodes if G.nodes[n]["namespace"] == "molecular_function"]
    nodes_C = [n for n in G.nodes if G.nodes[n]["namespace"] == "cellular_component"]
    nodes_P = [n for n in G.nodes if G.nodes[n]["namespace"] == "biological_process"]

    G_dict = {
        "F": G.subgraph(nodes_F),
        "C": G.subgraph(nodes_C),
        "P": G.subgraph(nodes_P)
    }
    G_data_dict = {}
    for k, curr_G in G_dict.items():
        go_list = sorted(curr_G.nodes())
        print(len(go_list))
        idx = {g:i for i,g in enumerate(go_list)}
        x = torch.arange(len(go_list), dtype=torch.long)
        e = [[idx[edge[0]], idx[edge[1]]] for edge in curr_G.edges if edge[0] in idx and edge[1] in idx]
        edge_index = torch.tensor(e, dtype=torch.long).t().contiguous()
        G_data_dict[k] = Data(x=x, edge_index=edge_index, go_ids=go_list)
    return G_data_dict

In [265]:
class GODataset(GDataset):
    def __init__(self, graph, transform=None, pre_transform=None):
        super().__init__(None, transform, pre_transform)
        self.graph = graph
        self.num_nodes = self.graph.num_nodes

    def len(self):
        return self.num_nodes

    def get(self, idx):
        return Data(
            x=self.graph.x[idx].unsqueeze(0),
            edge_index=self.graph.edge_index,
            node_idx=torch.tensor([idx]),
            go_id=self.graph.go_ids[idx]
        )

In [266]:
G_data_dict = obo_to_pyg(obo_path)

10131
4041
25950


In [249]:
G_data_dict

{'F': Data(x=[10131], edge_index=[2, 25138], go_ids=[10131]),
 'C': Data(x=[4041], edge_index=[2, 12866], go_ids=[4041]),
 'P': Data(x=[25950], edge_index=[2, 116432], go_ids=[25950])}

In [3]:
from torch_geometric.loader import DataLoader as GDataLoader
loader = GDataLoader(GODataset(G_data_dict['F']), batch_size=8, shuffle=False)

NameError: name 'GODataset' is not defined

In [250]:
for b in loader:
    print(b)
    print(b.edge_index)
    print(b.x)
    break

NameError: name 'loader' is not defined

In [251]:
G_data_dict['F'].edge_index

tensor([[    0,     1,     2,  ..., 10130, 10130, 10130],
        [ 1418,  1418,     7,  ...,  5191,  8365,  8387]])

In [252]:
len(loader)

NameError: name 'loader' is not defined

In [253]:
for b in loader:
    print(b)
    break

NameError: name 'loader' is not defined

In [254]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    label_metadata = obtain_label_metadata(subontology_metadata_dict["y_train_transformed"], output_dim)
    subontology_metadata_dict["label_metadata"] = label_metadata
    curr_model = WeightedMultiLabelNN()
    data_loader = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_train"], batch_size=64)
    data_loader_val = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_val"], batch_size=len(subontology_metadat_dict["X_val_imputed"]))
    subontology_metadata_dict["curr_model"] = curr_model
    subontology_metadata_dict["data_loader"] = data_loader
    subontology_metadata_dict["data_loader_val"] = data_loader_val
    print(f"\n\nTRAINING: {k} model \n\n")
    trained_model_nn = train_model(curr_model, data_loader, data_loader_val, torch.tensor(label_metadata["pos_weights"][:output_dim]))
    subontology_metadata_dict["trained_model_nn"] = trained_model_nn


200
0


TRAINING: C model 


curr epoch 1/25] loss: 0.9772897357706877 val_loss 0.936767575995377
Validation op 3 labels (best F1):
Label 17: F1 = 0.5617461229178633
Label 7: F1 = 0.5446888160973327
Label 158: F1 = 0.5063291139240507
Validation Bottom 3 labels (worst F1):
Label 182: F1 = 0.0
Label 141: F1 = 0.0
Label 176: F1 = 0.006309148264984227
Validation average f1 score: 0.11828255482834496
curr epoch 2/25] loss: 0.9575086110489229 val_loss 0.9161936173864584
Validation op 3 labels (best F1):
Label 17: F1 = 0.5719757421002234
Label 7: F1 = 0.5521759475900796
Label 158: F1 = 0.5507246376811594
Validation Bottom 3 labels (worst F1):
Label 111: F1 = 0.0
Label 141: F1 = 0.0
Label 182: F1 = 0.0
Validation average f1 score: 0.1332268491670524
curr epoch 3/25] loss: 0.9497656388864812 val_loss 0.9073638601884766
Validation op 3 labels (best F1):
Label 158: F1 = 0.631578947368421
Label 17: F1 = 0.5819191607528541
Label 164: F1 = 0.5660377358490566
Validation Bottom 3 labels (worst F1):
La

KeyboardInterrupt: 

In [361]:
# Graph Sage from colab 3
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.utils import add_self_loops

class GNNStack(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, args, n_terms, emb=False):
        # input_dim // 8
        super(GNNStack, self).__init__()
        conv_model = self.build_conv_model(args["model_type"])
        self.embeddings = nn.Embedding(n_terms, input_dim)
        self.convs = nn.ModuleList()
        self.convs.append(conv_model(input_dim, hidden_dim))
        assert (args["num_layers"] >= 1), 'Number of layers is not >=1'
        for l in range(args["num_layers"]-1):
            self.convs.append(conv_model(args["heads"] * hidden_dim, hidden_dim))

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(args["heads"] * hidden_dim, hidden_dim), nn.Dropout(args["dropout"]),
            nn.Linear(hidden_dim, output_dim))

        self.dropout = args["dropout"]
        self.num_layers = args["num_layers"]

        self.emb = emb

    def build_conv_model(self, model_type):
        if model_type == 'GraphSage':
            return GraphSage
        elif model_type == 'GAT':
            return GAT

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x_emb = self.embeddings(x)
        for i in range(self.num_layers):
            x_emb = self.convs[i](x_emb, edge_index)
            x_emb = F.relu(x_emb)
            x_emb = F.dropout(x_emb, p=self.dropout,training=self.training)

        x_emb = self.post_mp(x_emb)


        return x_emb

    def loss(self, pred, label):
        return F.nll_loss(pred, label)


class GraphSage(MessagePassing):

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_l = None
        self.lin_r = None

        self.lin_l = nn.Linear(in_channels, out_channels, bias=bias)
        self.lin_r = nn.Linear(in_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):
        out = None
        num_nodes = x.size()[0]
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        out = self.propagate(x=x, edge_index=edge_index)

        out = self.lin_l(x) + self.lin_r(out)
        if self.normalize: out = torch.nn.functional.normalize(out)

        return out

    def message(self, x_j):
        out = None
        out = x_j

        return out

    def aggregate(self, inputs, index, dim_size = None):
        out = None
        # The axis along which to index number of nodes.
        node_dim = self.node_dim
        out = torch_scatter.scatter(inputs, index, dim=0, reduce='mean')
        return out

class GAT(MessagePassing):

    def __init__(self, in_channels, out_channels, heads = 2,
                 negative_slope = 0.2, dropout = 0., **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.lin_l = None
        self.lin_r = None
        self.att_l = None
        self.att_r = None

        self.lin_l, self.lin_r = nn.Linear(in_channels, out_channels * self.heads), nn.Linear(in_channels, out_channels * self.heads)
        self.att_l, self.att_r = nn.Parameter(torch.empty(self.heads, out_channels)), nn.Parameter(torch.empty(self.heads, out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index, size = None):

        H, C = self.heads, self.out_channels
        # print(f"{H} {C}")
        # print(x.shape)
        W_l_transformed_x, W_r_transformed_x = self.lin_l(x), self.lin_r(x)
        # print(f"W_l_tr_x: {W_l_transformed_x.shape}, W_r_tr_x:{W_r_transformed_x.shape}")
        N = x.size(0)
        x_l = W_l_transformed_x.view(N, H, C)
        x_r = W_r_transformed_x.view(N, H, C)
        # print(f"x_l {x_l.shape} x_r {x_r.shape} {self.att_l.shape} {self.att_r.shape}")
        alpha_central = (x_l * self.att_l.unsqueeze(0)).sum(dim=-1)
        alpha_neighbors = (x_r * self.att_r.unsqueeze(0)).sum(dim=-1)
        num_nodes = x.size()[0]
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_central, alpha_neighbors), size=size)
        out = out.view(N, H * C)
        # print(f"forward out.shape {out.shape}")

        return out


    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        alpha = alpha_i + alpha_j
        alpha = F.leaky_relu(alpha, negative_slope=.2)
        alpha = softmax(alpha,index,num_nodes=size_i, ptr=ptr, )
        alpha = F.dropout(alpha, training=self.training, p=self.dropout)
        out = x_j * alpha.unsqueeze(-1)

        return out


    def aggregate(self, inputs, index, dim_size = None):
        out = torch_scatter.scatter(inputs, index, dim=0)
        # print(f"out aggregate: {out.shape}")
        return out


In [366]:
import math
args_dict = {
    "heads": 1,
    "dropout": 0,
    "num_layers": 5,
    "model_type": "GraphSage"
}
class TwoTowerGNNAndNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, logits_dim, output_dim, graph_edge_dim, args_dict):
        super().__init__()
        self.NN = WeightedMultiLabelNN(input_dim=input_dim, output_dim=logits_dim)
        self.GNN = GNNStack(input_dim, hidden_dim, logits_dim, args_dict, graph_edge_dim)
        self.graph_edge_dim = graph_edge_dim
        self.W_map_graph = nn.Linear(graph_edge_dim, hidden_dim)
        self.final = nn.Linear(logits_dim, output_dim)


    def forward(self, x_protein_embeddings, x_nodes, size=None):
        x_p = self.NN(x_protein_embeddings)
        x_n = self.GNN(x_nodes)
        x_n = self.W_map_graph(x_n.T)
        x_n = x_n.mean(dim=1)
        x = x_p + x_n
        logits = self.final(x)
        return logits

class TwoTowerGNNAndNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, logits_dim, output_dim, graph_edge_dim, args_dict):
        super().__init__()
        hidden_dims=[256, 512, 1024, 2048]
        self.NN = WeightedMultiLabelNN(input_dim=input_dim, output_dim=logits_dim, hidden_dims=hidden_dims)
        self.GNN = GNNStack(input_dim, hidden_dim, logits_dim, args_dict, graph_edge_dim)
        self.final = nn.Linear(logits_dim, output_dim)

    def forward(self, x_protein_embeddings, x_nodes, size=None):
        x_p = self.NN(x_protein_embeddings)
        x_n = self.GNN(x_nodes)
        attention_scores = torch.matmul(x_p, x_n.T)
        attention_weights = F.softmax(attention_scores / math.sqrt(x_p.size(-1)), dim=-1)
        x_n_weighted = torch.matmul(attention_weights, x_n)

        x_add = x_p + x_n_weighted
        x_mult = x_p * x_n_weighted
        x_cat = torch.cat([x_p, x_n_weighted], dim=-1)
        # x = torch.cat([x_add, x_mult, x_cat], dim=-1)
        print(x_cat.shape)
        logits = self.final(x_cat)
        return logits



In [367]:
def evaluate_f1_multi_tower(model, val_loader, device, graph_data, threshold=0.5):
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for features, targets in val_loader:
            features = features.to(device)
            targets = targets.to(device)

            logits = model(features, graph_data)
            probs = torch.sigmoid(logits)

            preds = (probs > threshold).float()

            all_preds.append(preds.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    num_labels = all_targets.shape[1]
    f1_scores = []

    for i in range(num_labels):
        try:
            f1 = f1_score(all_targets[:, i], all_preds[:, i], zero_division=0   )
        except ValueError:
            f1 = 0.0
        f1_scores.append(f1)

    f1_scores = np.array(f1_scores)
    best_idx = np.argsort(f1_scores)[-3:]
    worst_idx = np.argsort(f1_scores)[:3]

    return f1_scores, best_idx, worst_idx


In [368]:
import warnings
warnings.filterwarnings("ignore")

def train_model_two_tower(model, train_loader, val_loader, graph_data_loader, pos_weights, num_epochs=25):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    pos_weights = pos_weights.to(device)
    graph_data_loader = graph_data_loader.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # graph_iter = iter(graph_data_loader)
    for epoch in range(num_epochs):

        model.train()
        total_loss = 0.0

        for features, targets in train_loader:
            features = features.to(device)
            targets = targets.to(device)
            logits = model(features, graph_data_loader)


            loss = combined_bce_soft_f1_loss(logits, targets, pos_weights)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()



            total_loss += loss.item()

        total_val_loss = 0.0

        with torch.no_grad():
            for val_features, val_targets in val_loader:
                val_features = val_features.to(device)
                val_targets = val_targets.to(device)
                val_logits = model(val_features, graph_data_loader)
                val_loss = combined_bce_soft_f1_loss(val_logits, val_targets, pos_weights)
                total_val_loss += val_loss.item()

        avg_loss = total_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)


        # Validate using the validation dataloader
        f1_scores, best_idx, worst_idx = evaluate_f1_multi_tower(model, val_loader, device, graph_data_loader)
        print(f"curr epoch {epoch+1}/{num_epochs}] loss: {avg_loss} val_loss {avg_val_loss}")
        print("Validation op 3 labels (best F1):")
        for idx in reversed(best_idx):
            print(f"Label {idx}: F1 = {f1_scores[idx]}")

        print("Validation Bottom 3 labels (worst F1):")
        for idx in worst_idx:
            print(f"Label {idx}: F1 = {f1_scores[idx]}")

        print(f"Validation average f1 score: {np.mean(f1_scores)}")

    return model

In [344]:
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    subontology_metadata_dict["graph_data_loader"] = G_data_dict[k]
    # C_loader = DataLoader(datasets["C"], batch_size=64, shuffle=True)
    # P_loader = DataLoader(datasets["P"], batch_size=64, shuffle=True)

In [326]:
current_val_test = protein_function_metadata_dict['C']["graph_data_loader"]
current_val_test.x.size()[0]

4041

In [365]:
# G_data_dict = obo_to_pyg(obo_path) # Get dictionaries
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    label_metadata = obtain_label_metadata(subontology_metadata_dict["y_train_transformed"], output_dim)
    subontology_metadata_dict["label_metadata"] = label_metadata

    data_loader = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_train"], batch_size=128)
    data_loader_val = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_val"], batch_size=len(subontology_metadat_dict["X_val_imputed"]))
    subontology_metadata_dict["curr_model_two_tower"] = curr_model
    subontology_metadata_dict["data_loader"] = data_loader
    subontology_metadata_dict["data_loader_val"] = data_loader_val
    graph_data_loader = subontology_metadata_dict["graph_data_loader"]
    print(f"\n\nTRAINING: {k} model \n\n")
    print(subontology_metadata_dict["graph_data_loader"].x.size()[0])
    curr_model = TwoTowerGNNAndNN(PCA_TARGET_DIM, 256, 512, output_dim, subontology_metadata_dict["graph_data_loader"].x.size()[0], args_dict)
    trained_model_nn = train_model_two_tower(curr_model, data_loader, data_loader_val, graph_data_loader, torch.tensor(label_metadata["pos_weights"][:output_dim]))
    subontology_metadata_dict["trained_model_two_tower"] = trained_model_nn

200
0


TRAINING: C model 


4041


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x1024 and 512x200)

In [None]:
# 0.189 0.69, 0.65, 0.55 with bce

In [329]:
def predict_two_tower(model, dataloader, graph_data, device):
    model.eval()
    preds = []
    labels = []

    with torch.no_grad():
        for batch in dataloader:
            X, y = batch
            X = X.to(device)
            y = y.to(device)

            logits = model(X, graph_data)
            out = torch.sigmoid(logits)
            preds.append(out.cpu())
            labels.append(y.cpu())

    preds = torch.cat(preds)
    labels = torch.cat(labels)
    return preds, labels

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
      y_preds, y_labels = predict_two_tower(subontology_metadata_dict["trained_model_two_tower"], subontology_metadata_dict["data_loader_val"],  subontology_metadata_dict["graph_data_loader"], 'cuda')
      subontology_metadata_dict["y_preds_two_tower"] = y_preds
      subontology_metadata_dict["y_labels_two_tower"] = y_labels

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
      pred_entry_ids_df = pd.DataFrame({"EntryID": subontology_metadata_dict["X_val_entry_ids"], "prediction": [row.tolist() for row in subontology_metadata_dict["y_preds_two_tower"]]})
      label_names = subontology_metadata_dict["mlb"].classes_[:output_dim]
      num_labels = len(label_names)

      pred_matrix = np.vstack(pred_entry_ids_df["prediction"].values)

      pred_df = pd.DataFrame(pred_matrix, columns=label_names)

      prediction_df = pd.concat([pred_entry_ids_df[["EntryID"]], pred_df], axis=1)
      go_cols = [c for c in prediction_df.columns if c != "EntryID"]
      to_score_df = prediction_df.melt(id_vars="EntryID", value_vars=go_cols, var_name="term", value_name="score")

      to_score_df = to_score_df.sort_values(["EntryID", "term"]).reset_index(drop=True)
      to_score_df["score"] = to_score_df["score"].round(3)
      to_score_df = to_score_df[to_score_df["score"] != 0]
      subontology_metadata_dict["to_score_df"] = to_score_df
      to_score_df.to_csv(f"to_score_{k}.tsv", header=False, index=False, sep="\t")

for k, curr_protein_function_df in protein_function_subontology_dict.items():
    ground_truth_score_df = curr_protein_function_df[curr_protein_function_df["EntryID"].isin(protein_function_metadata_dict[k]["X_val_entry_ids"])][["EntryID", "term"]]
    print(f"{k}, {len(ground_truth_score_df)}")
    protein_function_metadata_dict[k]["ground_truth_score_df"] = ground_truth_score_df

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    subontology_metadata_dict["ground_truth_score_df"].to_csv(f"ground_truth_score_{k}.tsv", header=False, index=False, sep="\t")

ground_truth_score_df_merged = pd.concat([
    protein_function_metadata_dict["C"]["ground_truth_score_df"],
    protein_function_metadata_dict["F"]["ground_truth_score_df"],
    protein_function_metadata_dict["P"]["ground_truth_score_df"]
], ignore_index=True)

score_df_merged = pd.concat([
    protein_function_metadata_dict["C"]["to_score_df"],
    protein_function_metadata_dict["F"]["to_score_df"],
    protein_function_metadata_dict["P"]["to_score_df"]
], ignore_index=True)


C, 14701
F, 10640
P, 17865


In [330]:
ground_truth_score_df_merged.to_csv("ground_truth_score4.tsv", header=False, index=False, sep="\t")
score_df_merged.to_csv("to_score4.tsv", header=False, index=False, sep="\t")

In [350]:
args_dict = {
    "heads": 2,
    "dropout": 0,
    "num_layers": 5,
    "model_type": "GAT"
}

# G_data_dict = obo_to_pyg(obo_path) # Get dictionaries
for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    label_metadata = obtain_label_metadata(subontology_metadata_dict["y_train_transformed"], output_dim)
    subontology_metadata_dict["label_metadata"] = label_metadata

    data_loader = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_train"], batch_size=128)
    data_loader_val = DataLoader(subontology_metadata_dict["protein_embeddings_dataset_val"], batch_size=len(subontology_metadat_dict["X_val_imputed"]))
    subontology_metadata_dict["curr_model_two_tower"] = curr_model
    subontology_metadata_dict["data_loader"] = data_loader
    subontology_metadata_dict["data_loader_val"] = data_loader_val
    graph_data_loader = subontology_metadata_dict["graph_data_loader"]
    print(f"\n\nTRAINING: {k} model \n\n")
    print(subontology_metadata_dict["graph_data_loader"].x.size()[0])
    curr_model = TwoTowerGNNAndNN(PCA_TARGET_DIM, 256, 512, output_dim, subontology_metadata_dict["graph_data_loader"].x.size()[0], args_dict)
    trained_model_nn = train_model_two_tower(curr_model, data_loader, data_loader_val, graph_data_loader, torch.tensor(label_metadata["pos_weights"][:output_dim]))
    subontology_metadata_dict["trained_model_two_tower"] = trained_model_nn

200
0


TRAINING: C model 


4041
curr epoch 1/25] loss: 0.9767776515595437 val_loss 0.9360862153950864
Validation op 3 labels (best F1):
Label 158: F1 = 0.6206896551724138
Label 164: F1 = 0.5925925925925926
Label 5: F1 = 0.5714285714285714
Validation Bottom 3 labels (worst F1):
Label 51: F1 = 0.0
Label 53: F1 = 0.0
Label 109: F1 = 0.0
Validation average f1 score: 0.0965192570110381
curr epoch 2/25] loss: 0.9551621003612936 val_loss 0.915682756444702
Validation op 3 labels (best F1):
Label 158: F1 = 0.6031746031746031
Label 164: F1 = 0.5806451612903226
Label 21: F1 = 0.5384615384615384
Validation Bottom 3 labels (worst F1):
Label 4: F1 = 0.0
Label 53: F1 = 0.0
Label 109: F1 = 0.0
Validation average f1 score: 0.11954365299822584
curr epoch 3/25] loss: 0.9460154279399453 val_loss 0.9034804253113475
Validation op 3 labels (best F1):
Label 158: F1 = 0.6666666666666666
Label 164: F1 = 0.6013986013986014
Label 21: F1 = 0.5882352941176471
Validation Bottom 3 labels (worst F1):
Label 53: F1 = 

In [351]:
def predict_two_tower(model, dataloader, graph_data, device):
    model.eval()
    preds = []
    labels = []

    with torch.no_grad():
        for batch in dataloader:
            X, y = batch
            X = X.to(device)
            y = y.to(device)

            logits = model(X, graph_data)
            out = torch.sigmoid(logits)
            preds.append(out.cpu())
            labels.append(y.cpu())

    preds = torch.cat(preds)
    labels = torch.cat(labels)
    return preds, labels

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
      y_preds, y_labels = predict_two_tower(subontology_metadata_dict["trained_model_two_tower"], subontology_metadata_dict["data_loader_val"],  subontology_metadata_dict["graph_data_loader"], 'cuda')
      subontology_metadata_dict["y_preds_two_tower"] = y_preds
      subontology_metadata_dict["y_labels_two_tower"] = y_labels

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
      pred_entry_ids_df = pd.DataFrame({"EntryID": subontology_metadata_dict["X_val_entry_ids"], "prediction": [row.tolist() for row in subontology_metadata_dict["y_preds_two_tower"]]})
      label_names = subontology_metadata_dict["mlb"].classes_[:output_dim]
      num_labels = len(label_names)

      pred_matrix = np.vstack(pred_entry_ids_df["prediction"].values)

      pred_df = pd.DataFrame(pred_matrix, columns=label_names)

      prediction_df = pd.concat([pred_entry_ids_df[["EntryID"]], pred_df], axis=1)
      go_cols = [c for c in prediction_df.columns if c != "EntryID"]
      to_score_df = prediction_df.melt(id_vars="EntryID", value_vars=go_cols, var_name="term", value_name="score")

      to_score_df = to_score_df.sort_values(["EntryID", "term"]).reset_index(drop=True)
      to_score_df["score"] = to_score_df["score"].round(3)
      to_score_df = to_score_df[to_score_df["score"] != 0]
      subontology_metadata_dict["to_score_df"] = to_score_df
      to_score_df.to_csv(f"to_score_{k}.tsv", header=False, index=False, sep="\t")

for k, curr_protein_function_df in protein_function_subontology_dict.items():
    ground_truth_score_df = curr_protein_function_df[curr_protein_function_df["EntryID"].isin(protein_function_metadata_dict[k]["X_val_entry_ids"])][["EntryID", "term"]]
    print(f"{k}, {len(ground_truth_score_df)}")
    protein_function_metadata_dict[k]["ground_truth_score_df"] = ground_truth_score_df

for k, subontology_metadata_dict in protein_function_metadata_dict.items():
    subontology_metadata_dict["ground_truth_score_df"].to_csv(f"ground_truth_score_{k}.tsv", header=False, index=False, sep="\t")

ground_truth_score_df_merged = pd.concat([
    protein_function_metadata_dict["C"]["ground_truth_score_df"],
    protein_function_metadata_dict["F"]["ground_truth_score_df"],
    protein_function_metadata_dict["P"]["ground_truth_score_df"]
], ignore_index=True)

score_df_merged = pd.concat([
    protein_function_metadata_dict["C"]["to_score_df"],
    protein_function_metadata_dict["F"]["to_score_df"],
    protein_function_metadata_dict["P"]["to_score_df"]
], ignore_index=True)


C, 14701
F, 10640
P, 17865


In [352]:
ground_truth_score_df_merged.to_csv("ground_truth_score5.tsv", header=False, index=False, sep="\t")
score_df_merged.to_csv("to_score5.tsv", header=False, index=False, sep="\t")

In [304]:
import torch_geometric as torch_g
data_loader_cora = torch_g.data.DataLoader(dataset, batch_size=64, shuffle=False)
for batch in data_loader_cora:
    print(batch)
    break

DataBatch(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], batch=[2708], ptr=[2])


In [None]:
graph = obonet.read_obo(obo_path)
graph

<networkx.classes.multidigraph.MultiDiGraph at 0x79c8fb10c610>

In [None]:
def create_go_embeddings_optimized(obo_path, go_terms, embed_dim=16, hidden_dim=32, out_dim=16, epochs=50):

    print(" Loading Gene Ontology...")
    graph = obonet.read_obo(obo_path)


    edges = pd.DataFrame([
        {'source': u, 'target': v, 'relation': data.get('relation', 'is_a')}
        for u, v, data in graph.edges(data=True)
    ])

    relevant_edges = edges[
        edges['source'].isin(go_terms) | edges['target'].isin(go_terms)
    ].reset_index(drop=True)

    nodes = pd.DataFrame({'id': list(set(relevant_edges['source']).union(relevant_edges['target']))})
    nodes['node_idx'] = range(len(nodes))
    node2idx = dict(zip(nodes['id'], nodes['node_idx']))

    edge_index = torch.tensor([
        [node2idx[s] for s in relevant_edges['source']],
        [node2idx[t] for t in relevant_edges['target']]
    ], dtype=torch.long)

    num_nodes = len(nodes)
    print(f"Using {num_nodes} GO terms and {len(relevant_edges)} edges")


    x = torch.randn((num_nodes, embed_dim), dtype=torch.float32)


    class SimpleGCN(nn.Module):
        def __init__(self, in_dim, hidden_dim, out_dim):
            super(SimpleGCN, self).__init__()
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, out_dim)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = self.conv2(x, edge_index)
            return x

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGCN(embed_dim, hidden_dim, out_dim).to(device)


    x = x.to(device)
    edge_index = edge_index.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    data = Data(x=x, edge_index=edge_index)

    print(f"Training on device: {device}")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        embeddings = model(data.x, data.edge_index)
        # Inner product decoder
        recon = torch.sigmoid(torch.matmul(embeddings, embeddings.T))
        adj_true = torch.zeros_like(recon)
        adj_true[data.edge_index[0], data.edge_index[1]] = 1.0

        loss = F.binary_cross_entropy(recon, adj_true)
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch:03d} | Loss: {loss.item():.4f}")


    with torch.no_grad():
        node_embeddings = model(data.x, data.edge_index).cpu().numpy()

    del model, x, data, recon, adj_true
    torch.cuda.empty_cache()
    gc.collect()

    col_names = [f"go_emb_{i}" for i in range(node_embeddings.shape[1])]

    emb_df = pd.DataFrame(node_embeddings, index=nodes['id'], columns=col_names)
    print(f"Created embeddings for {len(emb_df)} GO terms")
    return emb_df





In [None]:
sampled_data = sample_tsv(term_path, sample_frac=0.05)
go_terms = sampled_data['term'].unique()
embeddings_df = create_go_embeddings_optimized(obo_path, go_terms)
seq_df = extract_sequences(fasta_path, sampled_data['EntryID'])

### Combine GO embedding and PLM embedding into one dataset

In [None]:
def combine_go_protein_embeddings(sampled_data, go_emb_df, prot_emb_df):

    combined = sampled_data.merge(go_emb_df, how='left', left_on='term', right_index=True)

    combined = combined.merge(prot_emb_df, how='left', left_on='EntryID', right_index=True)

    return combined


multimodal_df = combine_go_protein_embeddings(sampled_data, embeddings_df, prot_emb_df)

print("Multimodal feature dataframe shape:", multimodal_df.shape)
print(multimodal_df.head())