In [1]:
!pip install torch==2.7.0
!pip install torch-geometric
!pip install biopython
!pip install obonet
!pip install networkx
!pip install pandas
!pip install numpy
!pip install matplotlib
!pip install seaborn
!pip install scipy
!pip install scikit-learn
!pip install fair-esm

Collecting torch==2.7.0
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch==2.7.0)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch==2.7.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch==2.7.0)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch==2.7.0)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch==2.7.0)
  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.6.4.1 (from torch==2.7.0)
  Downloading nvidia_cublas_cu12-12.6.4.1-py3-no

In [2]:
import os
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import Bio
from Bio import SeqIO
import obonet
import gc
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import random
import esm


In [3]:
obo_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo'
fasta_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
term_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv'
taxonomy_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv'

In [4]:
term_df = pd.read_csv(term_path, sep='\t')
term_df.head()

Unnamed: 0,EntryID,term,aspect
0,Q5W0B1,GO:0000785,C
1,Q5W0B1,GO:0004842,F
2,Q5W0B1,GO:0051865,P
3,Q5W0B1,GO:0006275,P
4,Q5W0B1,GO:0006513,P


In [5]:
taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
taxonomy_df.head()

Unnamed: 0,EntryID,taxonomyID
0,A0A0C5B5G6,9606
1,A0JNW5,9606
2,A0JP26,9606
3,A0PK11,9606
4,A1A4S6,9606


In [62]:
def get_processed_fasta_df(fasta_data, term_df):
    fasta_dict_list = []
    term_set = set(term_df.tolist())
    for fasta_seq in fasta_data:
        entry = fasta_seq.id.split('|')[1] if '|' in fasta_seq.id else fasta_seq.id
        if entry in term_set:
            fasta_dict_list.append({
                "EntryID": entry, 
                "fasta_sequence": str(fasta_seq.seq)
            })
        
    return pd.DataFrame(fasta_dict_list)


# def get_processed_fasta_df(fasta_data, entry_ids):
#     records = []
#     entry_ids = set(entry_ids)
#     for record in fasta_data:
#         # Typical Uniprot headers: ">sp|P12345|PROT_HUMAN ..."
#         label = record.id.split('|')[1] if '|' in record.id else record.id
#         if label in entry_ids:
#             records.append({
#                 'EntryID': label,
#                 'fasta_sequence': str(record.seq)
#             })
#     seq_df = pd.DataFrame(records)
#     print(f"Retrieved {len(seq_df)} sequences matching EntryIDs")
#     return seq_df

def sample_tsv(tsv_df, sample_frac=0.05, random_state=42):
    """
    Read a TSV file and sample based on unique EntryID.
    Pulls in all associated rows for sampled EntryIDs.
    """
    df = tsv_df
    unique_ids = df['EntryID'].unique()
    sample_size = max(1, int(len(unique_ids) * sample_frac))
    sampled_ids = random.sample(list(unique_ids), sample_size)
    sampled_df = df[df['EntryID'].isin(sampled_ids)]
    print(f"Sampled {len(sampled_df)} rows from {len(unique_ids)} unique EntryIDs")
    return sampled_df





In [68]:
def generate_protein_embeddings_esm_optimized(seq_df, seq_col='Sequence', entryid_col='EntryID',
                                              target_dim=16, batch_size=1, use_fp16=True):
    """
    Memory-optimized ESM embedding generation for proteins.
    Processes small batches and moves embeddings to CPU immediately.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Load pretrained ESM model
    # model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
    model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    model = model.to(device)
    model.eval()

    batch_converter = alphabet.get_batch_converter()

    sequences = seq_df[seq_col].tolist()
    entry_ids = seq_df[entryid_col].tolist()

    all_embeddings = []
    print(len(sequences))
    for i in range(0, len(sequences), batch_size):
        if not i % 20:
            print(i)
        batch_seqs = sequences[i:i+batch_size]
        batch_labels = entry_ids[i:i+batch_size]

        batch_data = [(label, seq) for label, seq in zip(batch_labels, batch_seqs)]
        _, _, batch_tokens = batch_converter(batch_data)
        batch_tokens = batch_tokens.to(device)

        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=False)
            token_embeddings = results["representations"][model.num_layers]  # (B, L, D)

            # Mean pool over sequence length
            attention_mask = batch_tokens != alphabet.padding_idx
            masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
            seq_lengths = attention_mask.sum(dim=1).unsqueeze(-1)
            seq_embeddings = (masked_embeddings.sum(dim=1) / seq_lengths).cpu().float().numpy()

        all_embeddings.append(seq_embeddings)
        del batch_tokens, token_embeddings, masked_embeddings, seq_embeddings
        torch.cuda.empty_cache()

    raw_embeddings = np.vstack(all_embeddings)
    print("Raw embeddings shape:", raw_embeddings.shape)

    if raw_embeddings.shape[1] > target_dim:
        pca = PCA(n_components=target_dim)
        reduced = pca.fit_transform(raw_embeddings)
    else:
        reduced = np.zeros((raw_embeddings.shape[0], target_dim), dtype=np.float32)
        reduced[:, :raw_embeddings.shape[1]] = raw_embeddings

    col_names = [f"prot_emb_{i}" for i in range(target_dim)]
    emb_df = pd.DataFrame(reduced, index=entry_ids, columns=col_names)
    emb_df.index.name = entryid_col
    print("Reduced embeddings shape:", emb_df.shape)
    return emb_df





In [80]:
def get_merged_df(frac=1):
    """
    Merge term.tsv, fasta data, taxonomy data as well as nodes in the obo graph.
    """
    term_df = pd.read_csv(term_path, sep='\t')
    term_df = sample_tsv(term_df, frac)
    term_df = term_df
    taxonomy_df = pd.read_csv(taxonomy_path, sep='\t', names=['EntryID', 'taxonomyID'])
    fasta_data = list(SeqIO.parse(fasta_path, "fasta"))
    fasta_df = get_processed_fasta_df(fasta_data, term_df['EntryID'])
    fasta_emb_df = generate_protein_embeddings_esm_optimized(fasta_df, "fasta_sequence")
    # TODO add embeddings in getting merged_df
    merged_df = pd.merge(term_df, fasta_emb_df, on="EntryID", how='left')
    merged_df = pd.merge(merged_df, taxonomy_df, on="EntryID", how="left")
    graph = obonet.read_obo(obo_path)
    edges_list = []
    for node_id, data in graph.nodes(data=True):
        for parent_id in data.get("is_a", []):
            edges_list.append({
                    "term": node_id,
                    "parent": parent_id,
                    "name": data["name"],
                    "namespace": data["namespace"],
                    "def": data["def"],
                    "synonym": data.get("synonym", [])
                    
                })
    edges_df = pd.DataFrame(edges_list)
    merged_df = merged_df.merge(edges_df, on="term", how="left")
    return merged_df

train_df = get_merged_df(0.005)
len(train_df)
    

Sampled 2987 rows from 82404 unique EntryIDs
Using device: cuda
412
0
20
40
60
80
100
120
140
160
180
200
220
240
260
280
300
320
340
360
380
400
Raw embeddings shape: (412, 320)
Reduced embeddings shape: (412, 16)


4207

In [81]:
train_df.head()

Unnamed: 0,EntryID,term,aspect,prot_emb_0,prot_emb_1,prot_emb_2,prot_emb_3,prot_emb_4,prot_emb_5,prot_emb_6,...,prot_emb_12,prot_emb_13,prot_emb_14,prot_emb_15,taxonomyID,parent,name,namespace,def,synonym
0,Q9H0R8,GO:0005543,F,-0.797306,-0.609369,0.576919,-0.244198,-0.001294,0.519213,0.186692,...,-0.096609,-0.175182,-0.089086,0.416648,9606,GO:0008289,phospholipid binding,molecular_function,"""Binding to a phospholipid, a class of lipids ...",[]
1,Q9H0R8,GO:0005739,C,-0.797306,-0.609369,0.576919,-0.244198,-0.001294,0.519213,0.186692,...,-0.096609,-0.175182,-0.089086,0.416648,9606,GO:0043231,mitochondrion,cellular_component,"""A semiautonomous, self replicating organelle ...","[""mitochondria"" EXACT []]"
2,Q9H0R8,GO:0005515,F,-0.797306,-0.609369,0.576919,-0.244198,-0.001294,0.519213,0.186692,...,-0.096609,-0.175182,-0.089086,0.416648,9606,GO:0005488,protein binding,molecular_function,"""Binding to a protein."" [GOC:go_curators]","[""glycoprotein binding"" NARROW [], ""protein am..."
3,Q9H0R8,GO:0005776,C,-0.797306,-0.609369,0.576919,-0.244198,-0.001294,0.519213,0.186692,...,-0.096609,-0.175182,-0.089086,0.416648,9606,GO:0005773,autophagosome,cellular_component,"""A double-membrane-bounded compartment that en...","[""autophagic vacuole"" EXACT [NIF_Subcellular:s..."
4,Q9H0R8,GO:0030957,F,-0.797306,-0.609369,0.576919,-0.244198,-0.001294,0.519213,0.186692,...,-0.096609,-0.175182,-0.089086,0.416648,9606,GO:0061629,Tat protein binding,molecular_function,"""Binding to Tat, a viral transactivating regul...",[]


In [78]:
gc.collect()                     
torch.cuda.empty_cache()
torch.cuda.ipc_collect()   

In [79]:
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

### Create set of embedding from the graph edges using GCN

In [22]:
def create_go_embeddings_optimized(obo_path, go_terms, embed_dim=16, hidden_dim=32, out_dim=16, epochs=50):

    print(" Loading Gene Ontology...")
    graph = obonet.read_obo(obo_path)


    edges = pd.DataFrame([
        {'source': u, 'target': v, 'relation': data.get('relation', 'is_a')}
        for u, v, data in graph.edges(data=True)
    ])

    relevant_edges = edges[
        edges['source'].isin(go_terms) | edges['target'].isin(go_terms)
    ].reset_index(drop=True)

    nodes = pd.DataFrame({'id': list(set(relevant_edges['source']).union(relevant_edges['target']))})
    nodes['node_idx'] = range(len(nodes))
    node2idx = dict(zip(nodes['id'], nodes['node_idx']))

    edge_index = torch.tensor([
        [node2idx[s] for s in relevant_edges['source']],
        [node2idx[t] for t in relevant_edges['target']]
    ], dtype=torch.long)

    num_nodes = len(nodes)
    print(f"Using {num_nodes} GO terms and {len(relevant_edges)} edges")


    x = torch.randn((num_nodes, embed_dim), dtype=torch.float32)


    class SimpleGCN(nn.Module):
        def __init__(self, in_dim, hidden_dim, out_dim):
            super(SimpleGCN, self).__init__()
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, out_dim)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = self.conv2(x, edge_index)
            return x

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGCN(embed_dim, hidden_dim, out_dim).to(device)


    x = x.to(device)
    edge_index = edge_index.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    data = Data(x=x, edge_index=edge_index)

    print(f"Training on device: {device}")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        embeddings = model(data.x, data.edge_index)
        # Inner product decoder
        recon = torch.sigmoid(torch.matmul(embeddings, embeddings.T))
        adj_true = torch.zeros_like(recon)
        adj_true[data.edge_index[0], data.edge_index[1]] = 1.0

        loss = F.binary_cross_entropy(recon, adj_true)
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch:03d} | Loss: {loss.item():.4f}")


    with torch.no_grad():
        node_embeddings = model(data.x, data.edge_index).cpu().numpy()

    del model, x, data, recon, adj_true
    torch.cuda.empty_cache()
    gc.collect()

    col_names = [f"go_emb_{i}" for i in range(node_embeddings.shape[1])]

    emb_df = pd.DataFrame(node_embeddings, index=nodes['id'], columns=col_names)
    print(f"Created embeddings for {len(emb_df)} GO terms")
    return emb_df





In [25]:
sampled_data = sample_tsv(term_path, sample_frac=0.05)
go_terms = sampled_data['term'].unique()
embeddings_df = create_go_embeddings_optimized(obo_path, go_terms)
seq_df = extract_sequences(fasta_path, sampled_data['EntryID'])

Sampled 26576 rows from 82404 unique EntryIDs
 Loading Gene Ontology...
Using 25961 GO terms and 35352 edges
Training on device: cuda
Epoch 000 | Loss: 8.0363
Epoch 010 | Loss: 0.9643
Epoch 020 | Loss: 0.8855
Epoch 030 | Loss: 0.7805
Epoch 040 | Loss: 0.7235
Epoch 049 | Loss: 0.7165
Created embeddings for 25961 GO terms
Retrieved 4120 sequences matching EntryIDs


### Combine GO embedding and PLM embedding into one dataset

In [None]:
def combine_go_protein_embeddings(sampled_data, go_emb_df, prot_emb_df):

    combined = sampled_data.merge(go_emb_df, how='left', left_on='term', right_index=True)

    combined = combined.merge(prot_emb_df, how='left', left_on='EntryID', right_index=True)

    return combined


multimodal_df = combine_go_protein_embeddings(sampled_data, embeddings_df, prot_emb_df)

print("Multimodal feature dataframe shape:", multimodal_df.shape)
print(multimodal_df.head())