# π-Proteoformer: PTM Site Functional Association Prediction

This notebook demonstrates how to use the pre-trained $\pi$-Proteoformer model to predict whether two post-translational modification (PTM) sites within a protein are functionally associated.

## 1. Setup and Imports

In [11]:
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

# Add the proteoformer package to path
PROJECT_ROOT = Path("your/project/root")
sys.path.insert(0, str(PROJECT_ROOT  / "pi-Proteoformer" / "src"))

# Import Proteoformer components
from proteoformer.net import ProteoformerForEmbedding
from proteoformer.tokenization import ProteoformerTokenizer
from proteoformer.models import PTMFucntionalClassifier

print("All imports successful!")

All imports successful!


## 2. Configuration

In [2]:
# --- Checkpoints and data --------------------------------------------------
# Pre-trained π-Proteoformer backbone (tokenizer + encoder are loaded from here)
PRETRAINED_MODEL_PATH = "pretrain_stage2_balance_1126/checkpoint-197000"
# Fine-tuned PTM-pair classifier checkpoint (functional association prediction)
PTM_CHECKPOINT_PATH = "PTM_sites_functional_association_prediction.pt"
# Optional example test split; not used by the single-pair example below.
# NOTE(review): the leading "/" makes this absolute from the filesystem root — confirm intended.
TEST_DATA_PATH = "/fold1/test.json"

# --- PTM-site embedding extraction -----------------------------------------
WINDOW_SIZE = 10   # amino acids pooled on each side of the PTM site
MAX_LENGTH = 1024  # tokenizer truncation limit

# --- Classifier hyperparameters (must match the fine-tuned checkpoint) -----
HIDDEN_DIM = 256
DROPOUT = 0.3

# --- Inference ---------------------------------------------------------------
BATCH_SIZE = 16  # not used by the single-pair example below

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


## 3. Model Architecture

In [3]:
# The PTMFucntionalClassifier is imported from proteoformer.models
# 
# Architecture:
#   - Takes precomputed PTM embeddings (h1 and h2)
#   - Creates enhanced pair-wise features: [h1, h2, |h1 - h2|, h1 ⊙ h2]
#   - Passes through fully connected layers for binary classification

## 4. Helper Functions

In [4]:
def parse_ptm_position(proteoform: str, ptm_marker: str) -> int:
    """
    Parse the 0-indexed position of a modified residue from a proteoform sequence.

    The proteoform sequence contains PTM markers in the format ``[PFMOD:xxx]``,
    each written immediately after the residue it modifies. This function
    locates the first occurrence of ``[ptm_marker]`` and returns the index of
    the residue just before it, counted in the marker-free sequence. All
    ``[PFMOD:<digits>]`` markers preceding it are removed before counting.

    Args:
        proteoform: Sequence with PTM marked (e.g., "MMNKLYIGN...K[PFMOD:535]SGY...")
        ptm_marker: PTM marker to find, without brackets (e.g., "PFMOD:535")

    Returns:
        Position of the PTM residue (0-indexed) in the marker-free sequence.

    Raises:
        ValueError: If ``[ptm_marker]`` does not occur in ``proteoform``, or if
            it is not preceded by any residue (e.g. the sequence starts with it).

    Example:
        >>> proteoform = "MMNKLYIGNLSPAVTADDLRQLFGDRKLPLAGQVLLK[PFMOD:535]SGYAFVD"
        >>> parse_ptm_position(proteoform, "PFMOD:535")
        36
        (K is residue 37 in 1-indexed numbering, 36 in 0-indexed numbering.)
    """
    # Literal substring search; equivalent to searching for the regex-escaped marker.
    marker_start = proteoform.find(f"[{ptm_marker}]")
    if marker_start == -1:
        raise ValueError(f"Could not find PTM marker {ptm_marker} in proteoform")

    # Strip other markers before this one so the index counts residues only.
    # NOTE(review): only markers of the form [PFMOD:<digits>] are stripped.
    clean_seq_before = re.sub(r'\[PFMOD:\d+\]', '', proteoform[:marker_start])
    if not clean_seq_before:
        # Without this guard the result would be -1, a position that downstream
        # window extraction does not reject and would silently mis-handle.
        raise ValueError(
            f"PTM marker {ptm_marker} is not preceded by a residue in proteoform"
        )

    # The modified residue is the last character before the marker.
    return len(clean_seq_before) - 1


def extract_ptm_embedding(
    hidden_states: torch.Tensor,
    ptm_position: int,
    attention_mask: torch.Tensor,
    window_size: int
) -> torch.Tensor:
    """
    Extract embedding for a PTM site by mean pooling over the PTM residue and its flanking regions.

    Positions with a zero attention mask are excluded from the mean. If every
    position in the window is masked, the embedding at ``ptm_position`` is
    returned directly.

    Args:
        hidden_states: Hidden states from encoder (seq_len, hidden_dim)
        ptm_position: Position of the PTM residue (0-indexed)
        attention_mask: Attention mask indicating valid positions (seq_len,)
        window_size: Number of positions on each side of the PTM to include

    Returns:
        PTM embedding vector (hidden_dim,)

    Raises:
        IndexError: If ``ptm_position`` is not a valid index into ``hidden_states``.
            This happens, for example, when the site was removed by ``max_length``
            truncation. Without the check, such a site would be pooled from a
            window that does not contain it.
    """
    seq_len = hidden_states.shape[0]
    if not 0 <= ptm_position < seq_len:
        raise IndexError(
            f"PTM position {ptm_position} is outside the encoded sequence "
            f"(length {seq_len}); it may have been removed by truncation"
        )

    # Window [ptm_position - window_size, ptm_position + window_size + 2), clamped
    # to the sequence. It covers window_size positions before the PTM residue, the
    # residue itself, and window_size + 1 positions after it.
    # NOTE(review): the extra right-hand position presumably accounts for the
    # modification-marker token that follows the residue after tokenization.
    # Confirm this against the tokenizer.
    start_pos = max(0, ptm_position - window_size)
    end_pos = min(seq_len, ptm_position + window_size + 2)

    window_embeddings = hidden_states[start_pos:end_pos]  # (window_len, hidden_dim)
    window_mask = attention_mask[start_pos:end_pos]        # (window_len,)

    # Mean pooling over valid (mask != 0) positions only
    valid_count = window_mask.sum()
    if valid_count > 0:
        masked_embeddings = window_embeddings * window_mask.unsqueeze(-1)
        return masked_embeddings.sum(dim=0) / valid_count

    # Fallback: the whole window is masked out, so use the PTM position directly
    return hidden_states[ptm_position]

## 5. Model Loading Functions

In [5]:
def _format_metric(value: object) -> str:
    """Format a checkpoint metric with 4 decimals, or return 'N/A' if it is missing or non-numeric."""
    try:
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        return "N/A"


def load_models(
    proteoformer_checkpoint: str,
    ptm_func_checkpoint: str,
    device: torch.device,
    hidden_dim: int = 256,
    dropout: float = 0.3
) -> Tuple[ProteoformerForEmbedding, PTMFucntionalClassifier, ProteoformerTokenizer]:
    """
    Load the complete PTM functional association prediction model.

    This loads:
    1. Proteoformer2 model for encoding protein sequences
    2. PTM classifier for predicting functional association
    3. Tokenizer for sequence tokenization

    Both models are moved to ``device`` and put in eval mode. If the classifier
    checkpoint contains a ``'metrics'`` dict, its accuracy/precision/recall/F1/AUC
    entries are printed. Missing or non-numeric entries are shown as ``N/A``.

    Args:
        proteoformer_checkpoint: Path to Proteoformer2 model checkpoint
        ptm_func_checkpoint: Path to fine-tuned PTM classifier checkpoint
        device: Torch device to load models on
        hidden_dim: Hidden dimension for the classifier (must match the checkpoint)
        dropout: Dropout rate for the classifier

    Returns:
        Tuple of (proteoformer_model, ptm_classifier, tokenizer)

    Raises:
        KeyError: If the classifier checkpoint has no ``'model_state_dict'`` entry.
    """
    # Tokenizer and backbone come from the same pre-trained checkpoint directory
    tokenizer = ProteoformerTokenizer.from_pretrained(proteoformer_checkpoint)

    proteoformer_model = ProteoformerForEmbedding.from_pretrained(proteoformer_checkpoint)
    proteoformer_model = proteoformer_model.to(device)
    proteoformer_model.eval()

    # The classifier's input size must equal the backbone's hidden size
    embedding_dim = proteoformer_model.config.hidden_size
    print(f"  - Model hidden size: {embedding_dim}")
    print(f"  - Number of layers: {proteoformer_model.config.num_hidden_layers}")

    ptm_classifier = PTMFucntionalClassifier(
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        dropout=dropout
    )

    # SECURITY NOTE: torch.load can unpickle arbitrary objects (weights_only
    # defaults to False before torch 2.6). Only load trusted checkpoints.
    checkpoint = torch.load(ptm_func_checkpoint, map_location=device)
    ptm_classifier.load_state_dict(checkpoint['model_state_dict'])
    ptm_classifier = ptm_classifier.to(device)
    ptm_classifier.eval()

    if 'metrics' in checkpoint:
        metrics = checkpoint['metrics']
        print("Checkpoint metrics:")
        # A missing key previously crashed here: 'N/A' cannot be formatted with ':.4f'.
        for display_name, key in (
            ("Accuracy", "accuracy"),
            ("Precision", "precision"),
            ("Recall", "recall"),
            ("F1 Score", "f1"),
            ("AUC", "auc"),
        ):
            print(f"  - {display_name}: {_format_metric(metrics.get(key))}")

    print("Models loaded successfully!")
    return proteoformer_model, ptm_classifier, tokenizer

## 6. Prediction Functions

In [6]:
def _encode_ptm_site(
    proteoformer_model: ProteoformerForEmbedding,
    tokenizer: ProteoformerTokenizer,
    proteoform: str,
    ptm_position: int,
    device: torch.device,
    window_size: int,
    max_length: int,
) -> torch.Tensor:
    """Tokenize and encode one proteoform, then return its windowed PTM-site embedding of shape (1, hidden_dim)."""
    encoded = tokenizer(
        [proteoform],
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
        add_special_tokens=False
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    outputs = proteoformer_model(
        input_ids=input_ids,
        attention_mask=attention_mask
    )
    hidden_states = outputs.last_hidden_state[0]  # (seq_len, hidden_dim) of the only sequence in the batch

    return extract_ptm_embedding(
        hidden_states, ptm_position, attention_mask[0], window_size
    ).unsqueeze(0)  # (1, hidden_dim)


def predict_single_pair(
    proteoformer_model: ProteoformerForEmbedding,
    ptm_classifier: PTMFucntionalClassifier,
    tokenizer: ProteoformerTokenizer,
    proteoform1: str,
    ptm1_marker: str,
    proteoform2: str,
    ptm2_marker: str,
    device: torch.device,
    window_size: int = 10,
    max_length: int = 1024,
    threshold: float = 0.5
) -> Dict[str, Any]:
    """
    Predict functional association for a single PTM pair.

    This function:
    1. Parses PTM positions from proteoform sequences
    2. Encodes each sequence with Proteoformer and extracts its PTM-site embedding
    3. Predicts functional association using the classifier

    Args:
        proteoformer_model: Loaded Proteoformer model
        ptm_classifier: Loaded PTM classifier
        tokenizer: Loaded tokenizer
        proteoform1: First proteoform sequence with PTM marker
        ptm1_marker: PTM marker for first site (e.g., "PFMOD:535")
        proteoform2: Second proteoform sequence with PTM marker
        ptm2_marker: PTM marker for second site (e.g., "PFMOD:21")
        device: Torch device
        window_size: Window size for embedding extraction
        max_length: Maximum sequence length
        threshold: Classification threshold (default 0.5); probability >= threshold -> 1

    Returns:
        Dictionary containing:
        - probability: Association probability (0-1), sigmoid of the classifier logit
        - prediction: Binary prediction (0 or 1)
        - label: Human-readable prediction label

    Raises:
        ValueError: If a PTM marker cannot be located in its proteoform.
    """
    proteoformer_model.eval()
    ptm_classifier.eval()

    # Parse both sites up front so malformed input fails before any forward pass.
    ptm1_position = parse_ptm_position(proteoform1, ptm1_marker)
    ptm2_position = parse_ptm_position(proteoform2, ptm2_marker)

    with torch.no_grad():
        ptm1_embedding = _encode_ptm_site(
            proteoformer_model, tokenizer, proteoform1, ptm1_position,
            device, window_size, max_length
        )
        ptm2_embedding = _encode_ptm_site(
            proteoformer_model, tokenizer, proteoform2, ptm2_position,
            device, window_size, max_length
        )

        logits = ptm_classifier(ptm1_embedding, ptm2_embedding)
        probability = torch.sigmoid(logits).item()

    prediction = 1 if probability >= threshold else 0
    return {
        'probability': probability,
        'prediction': prediction,
        'label': 'Functionally Associated' if prediction == 1 else 'Not Associated'
    }

## 7. Example Usage

In [7]:
# Banner: separator line, title, separator line (same output as three prints)
print("\n".join([
    "=" * 80,
    "π-Proteoformer PTM Sites Functional Association Prediction",
    "=" * 80,
]))

print("\n[Step 1] Loading models...")

# Load backbone, classifier and tokenizer using the settings from the config cell
proteoformer_model, ptm_classifier, tokenizer = load_models(
    PRETRAINED_MODEL_PATH,
    PTM_CHECKPOINT_PATH,
    DEVICE,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
)


You are using a model of type proteoformer2 to instantiate a model of type proteoformer. This is not supported for all configurations of models and can yield errors.


π-Proteoformer PTM Sites Functional Association Prediction

[Step 1] Loading models...
  - Model hidden size: 1280
  - Number of layers: 24
Checkpoint metrics:
  - Accuracy: 0.7179
  - Precision: 0.7618
  - Recall: 0.6443
  - F1 Score: 0.6981
  - AUC: 0.7957
Models loaded successfully!


In [8]:
print("\n[Step 2] Single pair prediction example...")
print("-" * 80)

# Example from the test data
# This example shows K37 with acetylation (PFMOD:535) and Y40 with phosphorylation (PFMOD:21)
# Each proteoform marks one site with a "[PFMOD:<id>]" tag placed right after the
# modified residue, which is the format parse_ptm_position() expects. Keys:
#   proteoform1 / proteoform2 - marked sequences passed to the tokenizer/encoder
#   Residue1 / Residue2       - human-readable site labels (only printed below)
#   PTM1 / PTM2               - marker IDs located by parse_ptm_position()
#   label                     - ground truth, compared against the 0/1 prediction
example = {
    "proteoform1": "MMNKLYIGNLSPAVTADDLRQLFGDRKLPLAGQVLLK[PFMOD:535]SGYAFVDYPDQNWAIRAIETLSGKVELHGKIMEVDYSVSKKLRSRKIQIRNIPPHLQWEVLDGLLAQYGTVENVEQVNTDTETAVVNVTYATREEAKIAMEKLSGHQFENYSFKISYIPDEEVSSPSPPQRAQRGDHSSREQGHAPGGTSQARQIDFPLRILVPTQFVGAIIGKEGLTIKNITKQTQSRVDIHRKENSGAAEKPVTIHATPEGTSEACRMILEIMQKEADETKLAEEIPLKILAHNGLVGRLIGKEGRNLKKIEHETGTKITISSLQDLSIYNPERTITVKGTVEACASAEIEIMKKLREAFENDMLAVNQQANLIPGLNLSALGIFSTGLSVLSPPAGPRGAPPAAPYHPFTTHSGYFSSLYPHHQFGPFPHHHSYPEQEIVNLFIPTQAVGAIIGKKGAHIKQLARFAGASIKIAPAEGPDVSERMVIITGPPEAQFKAQGRIFGKLKEENFFNPKEEVKLEAHIRVPSSTAGRVIGKGGKTVNELQNLTSAEVIVPRDQTPDENEEVIVRIIGHFFASQTAQRKIREIVQQVKQQEQKYPQGVASQRSK",
    "Residue1": "K37",
    "PTM1": "PFMOD:535",
    "proteoform2": "MMNKLYIGNLSPAVTADDLRQLFGDRKLPLAGQVLLKSGY[PFMOD:21]AFVDYPDQNWAIRAIETLSGKVELHGKIMEVDYSVSKKLRSRKIQIRNIPPHLQWEVLDGLLAQYGTVENVEQVNTDTETAVVNVTYATREEAKIAMEKLSGHQFENYSFKISYIPDEEVSSPSPPQRAQRGDHSSREQGHAPGGTSQARQIDFPLRILVPTQFVGAIIGKEGLTIKNITKQTQSRVDIHRKENSGAAEKPVTIHATPEGTSEACRMILEIMQKEADETKLAEEIPLKILAHNGLVGRLIGKEGRNLKKIEHETGTKITISSLQDLSIYNPERTITVKGTVEACASAEIEIMKKLREAFENDMLAVNQQANLIPGLNLSALGIFSTGLSVLSPPAGPRGAPPAAPYHPFTTHSGYFSSLYPHHQFGPFPHHHSYPEQEIVNLFIPTQAVGAIIGKKGAHIKQLARFAGASIKIAPAEGPDVSERMVIITGPPEAQFKAQGRIFGKLKEENFFNPKEEVKLEAHIRVPSSTAGRVIGKGGKTVNELQNLTSAEVIVPRDQTPDENEEVIVRIIGHFFASQTAQRKIREIVQQVKQQEQKYPQGVASQRSK",
    "Residue2": "Y40",
    "PTM2": "PFMOD:21",
    "label": 1  # Ground truth: these PTM sites are functionally associated
}


[Step 2] Single pair prediction example...
--------------------------------------------------------------------------------


In [9]:
# Summarize the two PTM sites and the ground-truth label of the example pair
print(
    f"PTM Site 1: {example['Residue1']} with modification {example['PTM1']}",
    f"PTM Site 2: {example['Residue2']} with modification {example['PTM2']}",
    "Ground Truth: " + ("Functionally Associated" if example['label'] == 1 else "Not Associated"),
    sep="\n",
)

PTM Site 1: K37 with modification PFMOD:535
PTM Site 2: Y40 with modification PFMOD:21
Ground Truth: Functionally Associated



In [10]:
# Score the example pair with the loaded models and compare against ground truth
result = predict_single_pair(
    proteoformer_model, ptm_classifier, tokenizer,
    proteoform1=example['proteoform1'], ptm1_marker=example['PTM1'],
    proteoform2=example['proteoform2'], ptm2_marker=example['PTM2'],
    device=DEVICE, window_size=WINDOW_SIZE, max_length=MAX_LENGTH,
)

print("\n".join([
    "Prediction Results:",
    f"  - Probability: {result['probability']:.4f}",
    f"  - Prediction: {result['label']}",
    f"  - Correct: {'✓' if result['prediction'] == example['label'] else '✗'}",
]))

Prediction Results:
  - Probability: 0.6270
  - Prediction: Functionally Associated
  - Correct: ✓
