In [2]:
!pip install torch==2.7.0
!pip install torch-geometric
!pip install biopython
!pip install obonet
!pip install networkx
!pip install pandas
!pip install numpy
!pip install matplotlib
!pip install seaborn
!pip install scipy
!pip install scikit-learn
!pip install fair-esm

Collecting torch==2.7.0
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch==2.7.0)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch==2.7.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch==2.7.0)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch==2.7.0)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch==2.7.0)
  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.6.4.1 (from torch==2.7.0)
  Downloading nvidia_cublas_cu12-12.6.4.1-py3-no

In [3]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # avoid fragmentation
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import Bio
from Bio import SeqIO
import obonet
import gc
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import random
import esm


In [4]:
obo_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo'
fasta_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
term_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv'
taxonomy_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv'

In [5]:
LARGEST_FASTA_SEQ_LEN = 8922
ESM_EMBEDDING_DIM = 320
PCA

sklearn.decomposition._pca.PCA

In [6]:
term_df = pd.read_csv(term_path, sep='\t')
term_df.head(), len(term_df["term"].unique())

(  EntryID        term aspect
 0  Q5W0B1  GO:0000785      C
 1  Q5W0B1  GO:0004842      F
 2  Q5W0B1  GO:0051865      P
 3  Q5W0B1  GO:0006275      P
 4  Q5W0B1  GO:0006513      P,
 26125)

In [7]:
taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
taxonomy_df.head()

Unnamed: 0,EntryID,taxonomyID
0,A0A0C5B5G6,9606
1,A0JNW5,9606
2,A0JP26,9606
3,A0PK11,9606
4,A1A4S6,9606


In [10]:
# def get_processed_fasta_df(fasta_data, term_df):
#     fasta_dict_list = []
#     term_set = set(term_df.tolist())
#     for fasta_seq in fasta_data:
#         entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
#         if entry in term_set:
#             fasta_dict_list.append({
#                 "EntryID": entry,
#                 "fasta_sequence": str(fasta_seq.seq)
#             })

#     return pd.DataFrame(fasta_dict_list)


def sample_tsv(tsv_df, sample_frac=0.05, random_state=42):
    """
    Read a TSV file and sample based on unique EntryID.
    Pulls in all associated rows for sampled EntryIDs.
    """
    df = tsv_df
    unique_ids = df['EntryID'].unique()
    sample_size = max(1, int(len(unique_ids) * sample_frac))
    sampled_ids = random.sample(list(unique_ids), sample_size)
    sampled_df = df[df['EntryID'].isin(sampled_ids)]
    print(f"Sampled {len(sampled_df)} rows from {len(unique_ids)} unique EntryIDs")
    return sampled_df


def get_processed_fasta_df(fasta_data):
    fasta_dict_list = []
    for fasta_seq in fasta_data:
        entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
        fasta_dict_list.append({
                "EntryID": entry,
                "fasta_sequence": str(fasta_seq.seq)
            })

    return pd.DataFrame(fasta_dict_list)





# Get Merged DF FULL (includes batching to fit embeddings into GPU and offloads to CPU)

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
model = model.to(device)
model = model.half()
model.eval()
batch_converter = alphabet.get_batch_converter()

def generate_protein_embeddings_esm_batch(seq_df, seq_col='Sequence', entryid_col='EntryID',
                                              target_dim=16, batch_size=1, use_fp16=True):
    """
    Memory-optimized ESM embedding generation for proteins.
    Processes small batches and moves embeddings to CPU immediately.
    """
    sequences = seq_df[seq_col].tolist()
    entry_ids = seq_df[entryid_col].tolist()
    print(f"sequences: {len(sequences)} entry_ids {len(entry_ids)}")

    all_embeddings = []
    # print(len(sequences))
    dtype = torch.float16 if use_fp16 else torch.float32
    for i in range(0, len(sequences), batch_size):
        try:
            batch_seqs = sequences[i:i+batch_size]
            batch_labels = entry_ids[i:i+batch_size]
            if len(batch_seqs[0]) > LARGEST_FASTA_SEQ_LEN:
                print(f"length of batch seqs is: {len(batch_seqs[0])}")
                # Split the sequence to be max length LARGEST_FASTA_SEQ_LEN (will always be len 1 batch), this helps avoid the loss of information
                curr_seq = batch_seqs[0]
                embeddings_list = []
                for i in range(0, len( batch_seqs[0]), LARGEST_FASTA_SEQ_LEN):
                  curr_seq = batch_seqs[0][i:i+LARGEST_FASTA_SEQ_LEN]
                  curr_seq_embedding = get_sequence_embedding_esm(batch_labels, [curr_seq], dtype)
                  embeddings_list.append(curr_seq_embedding)
                # Average the embeddings_list since the long sequences were broken down into smaller chunks
                seq_embeddings = np.mean(embeddings_list, axis=0)

            else:
              seq_embeddings = get_sequence_embedding_esm(batch_labels, batch_seqs, dtype)

            all_embeddings.append(seq_embeddings)
            del seq_embeddings
        except Exception as e:
            print(e)
            print(len(batch_seqs[0]))
            print(batch_seqs)


    raw_embeddings = np.vstack(all_embeddings)
    # Mitigate memory constraints
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.reset_peak_memory_stats()

    col_names = [f"prot_emb_{i}" for i in range(320)]
    emb_df = pd.DataFrame(raw_embeddings, index=entry_ids, columns=col_names)
    emb_df.index.name = entryid_col

    return emb_df


def get_sequence_embedding_esm(batch_labels, batch_seqs, dtype):
    batch_data = [(label, seq) for label, seq in zip(batch_labels, batch_seqs)]
    _, _, batch_tokens = batch_converter(batch_data)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
      with torch.autocast(device_type="cuda", dtype=dtype):
          results = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=False)
          token_embeddings = results["representations"][model.num_layers]  # (B, L, D)
          # Mean pool over sequence length
          attention_mask = batch_tokens != alphabet.padding_idx
          masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
          seq_lengths = attention_mask.sum(dim=1).unsqueeze(-1)
          seq_embeddings = (masked_embeddings.sum(dim=1) / seq_lengths).cpu().float().numpy()
          del batch_tokens, token_embeddings, masked_embeddings, results, attention_mask
          return seq_embeddings


def apply_pca_to_esm_embeddings(esm_embeddings_df:pd.DataFrame, target_dim=16):
    '''
    Helper method 
    '''
    emb_cols = [c for c in esm_embeddings_df.columns if c.startswith("prot_emb")]
    
    tensor_list = [ 
        torch.tensor(row[emb_cols].values.astype("float16"), dtype=torch.float16)
        for _, row in esm_embeddings_df.iterrows()
    ]
    
    pca = PCA(n_components=PCA_TARGET_DIM)
    tensor_list_transformed = pca.fit_transform(tensor_list)
    
    # Intialize dictionary where we use EntryID for joining with all the other data
    embeddings_dict = { "EntryID": [] }
    # Initialize empty PCA embeddings for the pd.DataFrame
    for i in range(PCA_TARGET_DIM):
        embeddings_dict[f"emb_{i}"] =  []
        
    # Append EntryID's and embeddings to the dictionary for the dataframe
    for i, r in enumerate(tensor_list_transformed):
        curr_entry_id = esm_embeddings_df["EntryID"][i]
        embeddings_dict["EntryID"].append(curr_entry_id)
        for j in range(PCA_TARGET_DIM):
            embeddings_dict[f"emb_{j}"].append(r[j])

    pca_embeddings_df = pd.DataFrame(embeddings_dict)
    return pca_embeddings_df




In [1]:
file_name = "/kaggle/input/fasta-embeddings-final/fasta_embeddings_final.csv"
embeddings_processed = True

def get_merged_df_full(file_name, batch_size=250, embeddings_processed=False):
    """
    Merge term.tsv, fasta data, taxonomy data as well as nodes in the obo graph.
    """
    term_df = pd.read_csv(term_path, sep='\t')
    term_df = term_df
    taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
    fasta_data = list(SeqIO.parse(fasta_path, "fasta"))

    entry_ids = list(term_df['EntryID'])
    all_batches = []
    if not embeddings_processed:
      for i in range(0, len(entry_ids), batch_size):
          total_processed = i
          print(f"Total processed: {total_processed}, {i//batch_size} batch")

          # Batch the EntryIDs
          entry_batch = list(set(entry_ids[i:i+batch_size]))

          curr_term_df = term_df[i:i+batch_size]
          fasta_df_batch = get_processed_fasta_df(fasta_data, curr_term_df['EntryID'])
          fasta_emb_df_batch = generate_protein_embeddings_esm_batch(
              fasta_df_batch,
              "fasta_sequence"
          )


          all_batches.append(fasta_emb_df_batch)

          full_df = pd.concat(all_batches, ignore_index=False)
          full_df.to_csv("fasta_embeddings.csv", index=True)

    else:
        full_df = pd.read_csv(file_name)
        print(full_df.head())
        
    fasta_emb_df = apply_pca_to_esm_embeddings(full_df)
    
    # TODO add embeddings in getting merged_df
    merged_df = pd.merge(term_df, fasta_emb_df, on="EntryID", how='left')
    merged_df = pd.merge(merged_df, taxonomy_df, on="EntryID", how="left")
    graph = obonet.read_obo(obo_path)
    edges_list = []
    for node_id, data in graph.nodes(data=True):
        for parent_id in data.get("is_a", []):
            edges_list.append({
                    "term": node_id,
                    "parent": parent_id,
                    "name": data["name"],
                    "namespace": data["namespace"],
                    "def": data["def"],
                    "synonym": data.get("synonym", [])
                })
    edges_df = pd.DataFrame(edges_list)
    merged_df = merged_df.merge(edges_df, on="term", how="left")
    return merged_df

protein_function_df = get_merged_df_full(file_name, embeddings_processed=embeddings_processed)
len(protein_function_df)

NameError: name 'pd' is not defined

In [22]:
protein_function_df.head()

Unnamed: 0,EntryID,term,aspect,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_12,emb_13,emb_14,emb_15,taxonomyID,parent,name,namespace,def,synonym
0,Q5W0B1,GO:0000785,C,-0.270628,1.031287,-0.162849,-0.90439,-0.347544,-0.220612,0.036554,...,-0.117032,-0.249009,-0.037782,0.177799,9606,GO:0110165,chromatin,cellular_component,"""The ordered and organized complex of DNA, pro...","[""chromosome scaffold"" RELATED [], ""cytoplasmi..."
1,Q5W0B1,GO:0004842,F,-0.270628,1.031287,-0.162849,-0.90439,-0.347544,-0.220612,0.036554,...,-0.117032,-0.249009,-0.037782,0.177799,9606,GO:0019787,ubiquitin-protein transferase activity,molecular_function,"""Catalysis of the transfer of ubiquitin from o...","[""E2"" BROAD [], ""E3"" BROAD [], ""ubiquitin conj..."
2,Q5W0B1,GO:0051865,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347544,-0.220612,0.036554,...,-0.117032,-0.249009,-0.037782,0.177799,9606,GO:0016567,protein autoubiquitination,biological_process,"""The ubiquitination by a protein of one or mor...","[""protein auto-ubiquitination"" EXACT [], ""prot..."
3,Q5W0B1,GO:0006275,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347544,-0.220612,0.036554,...,-0.117032,-0.249009,-0.037782,0.177799,9606,GO:0051052,regulation of DNA replication,biological_process,"""Any process that modulates the frequency, rat...",[]
4,Q5W0B1,GO:0006513,P,-0.270628,1.031287,-0.162849,-0.90439,-0.347544,-0.220612,0.036554,...,-0.117032,-0.249009,-0.037782,0.177799,9606,GO:0016567,protein monoubiquitination,biological_process,"""Addition of a single ubiquitin group to a pro...","[""protein monoubiquitinylation"" EXACT [], ""pro..."


In [1]:
N_LARGEST = 30
top_terms = protein_function_df['term'].value_counts().nlargest(N_LARGEST).index
protein_function_top_terms_df = protein_function_df[protein_function_df['term'].isin(top_terms)]

NameError: name 'protein_function_df' is not defined

### Create set of embedding from the graph edges using GCN

In [20]:
def create_go_embeddings_optimized(obo_path, go_terms, embed_dim=16, hidden_dim=32, out_dim=16, epochs=50):

    print(" Loading Gene Ontology...")
    graph = obonet.read_obo(obo_path)


    edges = pd.DataFrame([
        {'source': u, 'target': v, 'relation': data.get('relation', 'is_a')}
        for u, v, data in graph.edges(data=True)
    ])

    relevant_edges = edges[
        edges['source'].isin(go_terms) | edges['target'].isin(go_terms)
    ].reset_index(drop=True)

    nodes = pd.DataFrame({'id': list(set(relevant_edges['source']).union(relevant_edges['target']))})
    nodes['node_idx'] = range(len(nodes))
    node2idx = dict(zip(nodes['id'], nodes['node_idx']))

    edge_index = torch.tensor([
        [node2idx[s] for s in relevant_edges['source']],
        [node2idx[t] for t in relevant_edges['target']]
    ], dtype=torch.long)

    num_nodes = len(nodes)
    print(f"Using {num_nodes} GO terms and {len(relevant_edges)} edges")


    x = torch.randn((num_nodes, embed_dim), dtype=torch.float32)


    class SimpleGCN(nn.Module):
        def __init__(self, in_dim, hidden_dim, out_dim):
            super(SimpleGCN, self).__init__()
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, out_dim)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = self.conv2(x, edge_index)
            return x

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGCN(embed_dim, hidden_dim, out_dim).to(device)


    x = x.to(device)
    edge_index = edge_index.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    data = Data(x=x, edge_index=edge_index)

    print(f"Training on device: {device}")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        embeddings = model(data.x, data.edge_index)
        # Inner product decoder
        recon = torch.sigmoid(torch.matmul(embeddings, embeddings.T))
        adj_true = torch.zeros_like(recon)
        adj_true[data.edge_index[0], data.edge_index[1]] = 1.0

        loss = F.binary_cross_entropy(recon, adj_true)
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch:03d} | Loss: {loss.item():.4f}")


    with torch.no_grad():
        node_embeddings = model(data.x, data.edge_index).cpu().numpy()

    del model, x, data, recon, adj_true
    torch.cuda.empty_cache()
    gc.collect()

    col_names = [f"go_emb_{i}" for i in range(node_embeddings.shape[1])]

    emb_df = pd.DataFrame(node_embeddings, index=nodes['id'], columns=col_names)
    print(f"Created embeddings for {len(emb_df)} GO terms")
    return emb_df





In [None]:
sampled_data = sample_tsv(term_path, sample_frac=0.05)
go_terms = sampled_data['term'].unique()
embeddings_df = create_go_embeddings_optimized(obo_path, go_terms)
seq_df = extract_sequences(fasta_path, sampled_data['EntryID'])

### Combine GO embedding and PLM embedding into one dataset

In [None]:
def combine_go_protein_embeddings(sampled_data, go_emb_df, prot_emb_df):

    combined = sampled_data.merge(go_emb_df, how='left', left_on='term', right_index=True)

    combined = combined.merge(prot_emb_df, how='left', left_on='EntryID', right_index=True)

    return combined


multimodal_df = combine_go_protein_embeddings(sampled_data, embeddings_df, prot_emb_df)

print("Multimodal feature dataframe shape:", multimodal_df.shape)
print(multimodal_df.head())