In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filepath in map(lambda name: os.path.join(dirname, name), filenames):
        print(filepath)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Packages

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
!pip install biopython obonet --quiet
!pip install torch torchvision torchaudio --quiet
!pip install transformers biopython --quiet

# 🧩 CAFA6 Data Loading & Parsing (Step 1)

In [None]:
%%time
# =========================================================
# CAFA6 - Step 1: Load & Explore Data
# =========================================================
import pandas as pd

from Bio import SeqIO
import obonet

# ---------------------------------------------------------
# File paths (update if needed)
# ---------------------------------------------------------
BASE_PATH = "/kaggle/input/cafa-6-protein-function-prediction"
TRAIN_PATH = f"{BASE_PATH}/Train"
TEST_PATH = f"{BASE_PATH}/Test"

# ---------------------------------------------------------
# 1. Load TSV files
# ---------------------------------------------------------
# Column names are supplied explicitly, so pandas treats every line of each file
# as data; header=None only spells out that existing behaviour.
train_terms = pd.read_csv(
    f"{TRAIN_PATH}/train_terms.tsv",
    sep="\t",
    header=None,
    names=["protein_id", "go_term", "ontology"],
)
# NOTE(review): if train_terms.tsv ships with a header line, it was read above as
# an ordinary row and would show up as a bogus protein/term/ontology in every
# count and plot. Keeping only rows whose term starts with "GO:" drops such a
# line and is a no-op when the file has no header — confirm against the file.
train_terms = train_terms[
    train_terms["go_term"].astype(str).str.startswith("GO:")
].reset_index(drop=True)

train_taxonomy = pd.read_csv(
    f"{TRAIN_PATH}/train_taxonomy.tsv",
    sep="\t",
    header=None,
    names=["protein_id", "taxon_id"],
)

ia = pd.read_csv(
    f"{BASE_PATH}/IA.tsv",
    sep="\t",
    header=None,
    names=["go_term", "weight"],
)

# ---------------------------------------------------------
# 2. Load ontology graph (go-basic.obo)
# ---------------------------------------------------------
go_graph = obonet.read_obo(f"{TRAIN_PATH}/go-basic.obo")
print(f"Ontology terms loaded: {len(go_graph)}")

# Lookup from ontology abbreviation to the GO identifier used as that
# ontology's root node (no relationships are extracted here).
root_nodes = dict(
    BP='GO:0008150',
    CC='GO:0005575',
    MF='GO:0003674',
)

# ---------------------------------------------------------
# 3. Load FASTA sequences (train & test superset)
# ---------------------------------------------------------
def load_fasta_sequences(file_path):
    """Parse a FASTA file into a ``{protein_id: sequence}`` dictionary.

    The protein id is the second ``|``-separated field of the record id when
    the id contains a ``|``; otherwise the whole record id is used. When two
    records resolve to the same protein id, the later record's sequence wins.
    """
    def _protein_id(record_id):
        # Splitting yields more than one field exactly when "|" is present.
        fields = record_id.split("|")
        return fields[1] if len(fields) > 1 else record_id

    return {
        _protein_id(rec.id): str(rec.seq)
        for rec in SeqIO.parse(file_path, "fasta")
    }

train_sequences = load_fasta_sequences(f"{TRAIN_PATH}/train_sequences.fasta")
test_sequences = load_fasta_sequences(f"{TEST_PATH}/testsuperset.fasta")

print(f"Train sequences: {len(train_sequences)}")
print(f"Test superset sequences: {len(test_sequences)}")

# ---------------------------------------------------------
# 4. Merge annotations and taxonomy
# ---------------------------------------------------------
# Left join: every annotation row is kept, with its taxon id attached when known.
train_df = pd.merge(train_terms, train_taxonomy, on="protein_id", how="left")
print(f"Train dataframe: {train_df.shape}")

# ---------------------------------------------------------
# 5. Quick data summary
# ---------------------------------------------------------
summary_lines = [
    f"- train_terms: {train_terms.shape}",
    f"- train_taxonomy: {train_taxonomy.shape}",
    f"- ia: {ia.shape}",
    f"- train_sequences: {len(train_sequences)}",
    f"- test_sequences: {len(test_sequences)}",
    f"- Ontology terms (GO): {len(go_graph)}",
]
print("\nData Summary:")
for line in summary_lines:
    print(line)

# Preview the first rows of each table
previews = (
    ("\nSample train_terms:", train_terms),
    ("\nSample train_taxonomy:", train_taxonomy),
    ("\nSample IA:", ia),
)
for heading, frame in previews:
    print(heading)
    display(frame.head())

# Example check of GO hierarchy
# obonet stores each relationship as an edge pointing from the child (more
# specific) term to its parent (more general) term, so a term's parents are its
# graph *successors* and its children are its *predecessors*.
# NOTE(review): edges cover every relationship type present in the OBO file,
# not only is_a — confirm if strict is_a parents are required.
example_term = next(iter(go_graph.nodes()))
print(f"\nExample GO term: {example_term}")
print("Parents:", list(go_graph.successors(example_term)))
print("Children:", list(go_graph.predecessors(example_term)))


# 🧭 Step 2: Exploratory Data Analysis (EDA)

This section explores the data loaded in Step 1 and is meant to run right after dataset loading.

It covers:

- Overview of GO ontologies
- Distribution of GO terms & taxonomies
- Sequence length statistics
- Relationship between ontology, taxonomy, and terms
- Visualization of embeddings (optional)

In [None]:
%%time
# =========================================================
# Step 2: Exploratory Data Analysis (EDA)
# =========================================================
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Style setup
plt.style.use("seaborn-v0_8-muted")
sns.set_palette("coolwarm")

# ---------------------------------------------------------
# 1. Ontology distribution
# ---------------------------------------------------------
# One bar per distinct value of the `ontology` column, counting annotation rows.
fig, ax = plt.subplots(figsize=(6, 4))
sns.countplot(data=train_terms, x="ontology", ax=ax)
ax.set_title("Distribution of Ontology Types")
ax.set_xlabel("Ontology (Aspect)")
ax.set_ylabel("Count")
fig.tight_layout()
plt.show()

# ---------------------------------------------------------
# 2. GO term frequency
# ---------------------------------------------------------
# The 20 GO terms with the most annotation rows, drawn as horizontal bars.
go_counts = train_terms["go_term"].value_counts().head(20)
fig, ax = plt.subplots(figsize=(8, 4))
sns.barplot(x=go_counts.values, y=go_counts.index, ax=ax)
ax.set_title("Top 20 Most Frequent GO Terms")
ax.set_xlabel("Count")
ax.set_ylabel("GO Term")
fig.tight_layout()
plt.show()

# ---------------------------------------------------------
# 3. Taxonomy overview
# ---------------------------------------------------------
# The 10 taxon ids with the most rows; ids are cast to str so the axis treats
# them as categories rather than numbers.
tax_counts = train_taxonomy["taxon_id"].value_counts().head(10)
fig, ax = plt.subplots(figsize=(8, 4))
sns.barplot(x=tax_counts.index.astype(str), y=tax_counts.values, ax=ax)
ax.set_title("Top 10 Taxon IDs (Species)")
ax.set_xlabel("Taxon ID")
ax.set_ylabel("Protein Count")
fig.tight_layout()
plt.show()

# ---------------------------------------------------------
# 4. Sequence length distribution
# ---------------------------------------------------------
# Length (number of characters) of every training sequence.
seq_lengths = list(map(len, train_sequences.values()))
fig, ax = plt.subplots(figsize=(8, 4))
sns.histplot(seq_lengths, bins=50, kde=True, ax=ax)
ax.set_title("Distribution of Protein Sequence Lengths")
ax.set_xlabel("Sequence Length (aa)")
ax.set_ylabel("Count")
fig.tight_layout()
plt.show()

print(f"üß¨ Average sequence length: {np.mean(seq_lengths):.1f} aa")
print(f"Longest sequence: {np.max(seq_lengths)} aa")
print(f"Shortest sequence: {np.min(seq_lengths)} aa")

# ---------------------------------------------------------
# 5. Ontology × Taxonomy heatmap (co-occurrence)
# ---------------------------------------------------------
merged = pd.merge(train_terms, train_taxonomy, on="protein_id", how="left")
# Count rows per (ontology, taxon) pair, then spread taxon ids across columns;
# pairs that never occur become 0.
ont_tax_counts = (
    merged.groupby(["ontology", "taxon_id"])
    .size()
    .unstack("taxon_id")
    .fillna(0)
)
fig, ax = plt.subplots(figsize=(10, 4))
# log1p compresses the range of counts while keeping zero counts at zero.
sns.heatmap(np.log1p(ont_tax_counts), cmap="viridis", ax=ax)
ax.set_title("Ontology √ó Taxonomy Co-occurrence (log scale)")
ax.set_xlabel("Taxon ID")
ax.set_ylabel("Ontology")
fig.tight_layout()
plt.show()

# ---------------------------------------------------------
# 6. IA weight distribution (Information Content)
# ---------------------------------------------------------
ia_weights = ia["weight"]
fig, ax = plt.subplots(figsize=(8, 4))
sns.histplot(ia_weights, bins=50, kde=True, ax=ax)
ax.set_title("GO Term Information Content (IA Weights)")
ax.set_xlabel("Weight")
ax.set_ylabel("Frequency")
fig.tight_layout()
plt.show()

print(f"üß† IA weights: mean={ia['weight'].mean():.3f}, max={ia['weight'].max():.3f}")

# 🧬 Step 2 – Protein Embedding Extraction (ESM2 baseline)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# ---------------------------------------------------------
# Load model + tokenizer
# ---------------------------------------------------------
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"  # small and fast baseline
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Move the weights to the chosen device and switch the model to evaluation mode.
model = AutoModel.from_pretrained(MODEL_NAME).to(device).eval()

print(f"‚úÖ Loaded {MODEL_NAME} on {device}")

# ---------------------------------------------------------
# Function to compute mean-pooled embeddings
# ---------------------------------------------------------
def get_protein_embedding(sequence: str):
    """Embed one protein sequence as the mean of the model's last hidden states.

    The sequence is tokenized (truncated to at most 1022 tokens), run through
    the module-level ``model`` on ``device`` without gradient tracking, and the
    last hidden state is averaged over the token axis.

    Returns:
        NumPy array holding the pooled embedding with size-1 axes squeezed out.
    """
    encoded = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1022)
    on_device = {}
    for key, tensor in encoded.items():
        on_device[key] = tensor.to(device)
    with torch.no_grad():
        hidden_states = model(**on_device).last_hidden_state
    # NOTE(review): the mean runs over every token position the tokenizer
    # returned, so any special tokens it adds are included in the average.
    pooled = hidden_states.mean(dim=1)
    return pooled.squeeze().cpu().numpy()

# ---------------------------------------------------------
# Compute embeddings for the first N test sequences (N = None embeds all)
# ---------------------------------------------------------
test_fasta_path = '/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta'  # Path to the test FASTA file
# NOTE(review): this rebinds `test_sequences`, keyed by the raw record id, whereas
# load_fasta_sequences() keeps only the second "|"-separated field when the id
# contains "|" — confirm both keyings agree for this file.
test_sequences = {record.id: str(record.seq) for record in SeqIO.parse(test_fasta_path, "fasta")}

# You can limit to first N for quick testing
N = 1000  # change to None or len(test_sequences) for full run

seq_items = list(test_sequences.items())[:N]
test_embeddings = []
test_protein_ids = []

# Best-effort loop: a sequence that fails to embed is reported and skipped, so
# the two lists stay aligned and contain successful proteins only.
for prot_id, seq in tqdm(seq_items, desc="Embedding test proteins"):
    try:
        emb = get_protein_embedding(seq)
        test_embeddings.append(emb)
        test_protein_ids.append(prot_id)
    except Exception as e:
        print(f"‚ö†Ô∏è Skipped {prot_id}: {e}")

# ---------------------------------------------------------
# Save test embeddings
# ---------------------------------------------------------
test_emb_df = pd.DataFrame(test_embeddings)
# The embedding columns get integer labels (0, 1, ...). Cast them to strings
# before adding "protein_id": older pandas releases and the fastparquet engine
# reject non-string / mixed-type column labels in to_parquet.
test_emb_df.columns = test_emb_df.columns.astype(str)
test_emb_df.insert(0, "protein_id", test_protein_ids)
test_emb_df.to_parquet("test_esm2_embeddings.parquet", index=False)

print(f"\n‚úÖ Saved {len(test_emb_df)} test protein embeddings with shape {test_emb_df.shape}")
display(test_emb_df.head())

In [None]:
%%time
# =========================================================
# Step 2: Extract Protein Embeddings using ESM2 (facebook/esm2_t6_8M_UR50D)
# =========================================================


import torch
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# ---------------------------------------------------------
# Load model + tokenizer
# ---------------------------------------------------------
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"  # small and fast baseline
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Move the weights to the chosen device and switch the model to evaluation mode.
model = AutoModel.from_pretrained(MODEL_NAME).to(device).eval()

print(f"‚úÖ Loaded {MODEL_NAME} on {device}")

# ---------------------------------------------------------
# Function to compute mean-pooled embeddings
# ---------------------------------------------------------
def get_protein_embedding(sequence: str):
    """Embed one protein sequence as the mean of the model's last hidden states.

    The sequence is tokenized (truncated to at most 1022 tokens), run through
    the module-level ``model`` on ``device`` without gradient tracking, and the
    last hidden state is averaged over the token axis.

    Returns:
        NumPy array holding the pooled embedding with size-1 axes squeezed out.
    """
    encoded = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1022)
    on_device = {}
    for key, tensor in encoded.items():
        on_device[key] = tensor.to(device)
    with torch.no_grad():
        hidden_states = model(**on_device).last_hidden_state
    # NOTE(review): the mean runs over every token position the tokenizer
    # returned, so any special tokens it adds are included in the average.
    pooled = hidden_states.mean(dim=1)
    return pooled.squeeze().cpu().numpy()

# ---------------------------------------------------------
# Compute embeddings for all training sequences
# ---------------------------------------------------------
# Every training sequence is embedded; slice `seq_items` (e.g. `[:1000]`) for a
# quick smoke test.

seq_items = list(train_sequences.items())
embeddings = []
protein_ids = []

# Best-effort loop: a sequence that fails to embed is reported and skipped, so
# the two lists stay aligned and contain successful proteins only.
for prot_id, seq in tqdm(seq_items, desc="Embedding proteins"):
    try:
        emb = get_protein_embedding(seq)
        embeddings.append(emb)
        protein_ids.append(prot_id)
    except Exception as e:
        print(f"‚ö†Ô∏è Skipped {prot_id}: {e}")

# ---------------------------------------------------------
# Save embeddings
# ---------------------------------------------------------
emb_df = pd.DataFrame(embeddings)
# The embedding columns get integer labels (0, 1, ...). Cast them to strings
# before adding "protein_id": older pandas releases and the fastparquet engine
# reject non-string / mixed-type column labels in to_parquet.
emb_df.columns = emb_df.columns.astype(str)
emb_df.insert(0, "protein_id", protein_ids)
emb_df.to_parquet("train_esm2_embeddings.parquet", index=False)

print(f"\n‚úÖ Saved {len(emb_df)} protein embeddings with shape {emb_df.shape}")
display(emb_df.head())
