# π-Proteoformer: Variant PTM Effect Prediction

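This notebook demonstrates how to use the pre-trained π-Proteoformer model together with a fine-tuned classifier to predict how a sequence variant changes the effect of a post-translational modification (PTM). The wild-type (WT) and mutant (MT) peptides are each embedded with π-Proteoformer. The classifier then scores the embedding difference $\Delta E = E_{MT} - E_{WT}$ as either an *Increase* or a *Decrease* in the PTM effect.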

## 1. Import Libraries and Set Up the Import Path

In [2]:
import sys
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm import tqdm

# Make the local proteoformer package (src layout) importable.
# Set the PROTEOFORMER_ROOT environment variable, or edit the default below,
# to point at the directory that contains `proteoformer_codebase/`.
PROJECT_ROOT = Path(os.environ.get("PROTEOFORMER_ROOT", "your/project/root"))
PROTEOFORMER_SRC = PROJECT_ROOT / "proteoformer_codebase" / "pi-Proteoformer" / "src"
# Guard so that re-running this cell in the same kernel does not keep
# prepending duplicate entries to sys.path.
if str(PROTEOFORMER_SRC) not in sys.path:
    sys.path.insert(0, str(PROTEOFORMER_SRC))

# Import Proteoformer components
from proteoformer.net import ProteoformerForEmbedding
from proteoformer.tokenization import ProteoformerTokenizer
from proteoformer.models import VariantPTMClassifier
print("Libraries imported successfully!")

Libraries imported successfully!


## 2. Configuration

In [3]:
# --- Checkpoints -------------------------------------------------------------
# Pre-trained π-Proteoformer checkpoint; both the tokenizer and the backbone
# are loaded from it with `from_pretrained`.
PRETRAINED_MODEL_PATH = 'checkpoint-197000'
# Fine-tuned VariantPTMClassifier weights (read with torch.load).
VARIANT_PTM_CHECKPOINT_PATH = 'indirect_model.pt'

# --- Tokenization / classifier hyperparameters -------------------------------
MAX_LENGTH = 1024   # tokenizer truncation length
HIDDEN_DIM = 512    # classifier hidden width; must match the fine-tuning run
DROPOUT = 0.3       # classifier dropout rate

# --- Runtime -----------------------------------------------------------------
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
print(f"Using device: {device}")

Using device: cuda


## 3. Define Helper Functions

In [4]:
def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Average token embeddings along the sequence axis, skipping masked positions.

    Args:
        hidden_states: Token embeddings, shape (batch_size, seq_len, hidden_size).
        attention_mask: Mask of shape (batch_size, seq_len); positions where it
            is 0 contribute nothing to the average.

    Returns:
        One pooled vector per sequence, shape (batch_size, hidden_size). A row
        whose mask is entirely 0 yields zeros, because the token count is
        clamped to 1e-9 instead of dividing by zero.
    """
    # (batch, seq, 1) float mask; broadcasting stretches it across hidden_size.
    weights = attention_mask.unsqueeze(-1).float()
    token_totals = (hidden_states * weights).sum(dim=1)
    # Per-sequence count of kept tokens, kept strictly positive to avoid 0/0.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return token_totals / token_counts

def extract_embeddings(
    model: ProteoformerForEmbedding,
    tokenizer: ProteoformerTokenizer,
    peptides: List[str],
    device: torch.device,
    batch_size: int = 8,
    max_length: int = 1024
) -> torch.Tensor:
    """
    Embed peptide sequences with the backbone and mean-pool them to one vector each.

    Peptides are tokenized in mini-batches of ``batch_size``. Within each batch
    they are padded to the longest member and truncated to ``max_length``, and
    no special tokens are added. Each batch is run through ``model``, and its
    ``last_hidden_state`` is averaged over non-padding positions via
    ``mean_pooling``. This function does not call ``model.eval()``; callers are
    expected to do so.

    Args:
        model: Pretrained ProteoformerForEmbedding model
        tokenizer: ProteoformerTokenizer
        peptides: List of peptide sequences
        device: Device to run inference on
        batch_size: Number of peptides per forward pass (must be >= 1)
        max_length: Maximum sequence length for model input

    Returns:
        CPU tensor of shape (num_peptides, hidden_size), rows in the same order
        as ``peptides``. An empty ``peptides`` list yields a (0, hidden_size)
        tensor instead of an error.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if not peptides:
        # torch.cat on an empty list would raise; return a correctly shaped empty tensor.
        return torch.empty(0, model.config.hidden_size)

    all_embeddings: List[torch.Tensor] = []

    with torch.no_grad():
        for start in range(0, len(peptides), batch_size):
            batch_peptides = peptides[start:start + batch_size]

            # add_special_tokens=False presumably mirrors how the classifier's
            # training embeddings were produced -- TODO confirm against training code.
            encoded = tokenizer(
                batch_peptides,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
                add_special_tokens=False
            )

            input_ids = encoded['input_ids'].to(device)
            attention_mask = encoded['attention_mask'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # (batch_size, seq_len, hidden_size)
            hidden_states = outputs.last_hidden_state

            # Pool over real tokens only, then move off the accelerator so the
            # accumulated results do not hold device memory.
            batch_embeddings = mean_pooling(hidden_states, attention_mask)
            all_embeddings.append(batch_embeddings.cpu())

    return torch.cat(all_embeddings, dim=0)

## 4. Define Model Loading Function

In [5]:
def load_models(
    proteoformer_checkpoint: str,
    variant_ptm_checkpoint: str,
    device: torch.device,
    hidden_dim: int = 512,
    dropout: float = 0.3
) -> Tuple[ProteoformerForEmbedding, VariantPTMClassifier, ProteoformerTokenizer]:
    """
    Load the complete Variant PTM effect prediction model.

    This loads:
    1. Proteoformer model for encoding protein sequences
    2. VariantPTMClassifier for predicting PTM effect change
    3. Tokenizer for sequence tokenization

    Both models are moved to ``device`` and switched to eval mode.

    Args:
        proteoformer_checkpoint: Path to Proteoformer model checkpoint
            (used for both the tokenizer and the backbone)
        variant_ptm_checkpoint: Path to the fine-tuned VariantPTM classifier
            checkpoint. Either a dict with a ``'model_state_dict'`` entry or a
            bare state_dict is accepted.
        device: Torch device to load models on
        hidden_dim: Hidden dimension for the classifier (must match training)
        dropout: Dropout rate for the classifier

    Returns:
        Tuple of (proteoformer_model, variant_ptm_classifier, tokenizer)

    Raises:
        RuntimeError: from ``load_state_dict`` (strict by default) if the
            checkpoint's keys or shapes do not match the constructed classifier,
            e.g. when ``hidden_dim`` differs from the fine-tuning run.
    """
    print(f"Loading Proteoformer from {proteoformer_checkpoint}...")

    tokenizer = ProteoformerTokenizer.from_pretrained(proteoformer_checkpoint)

    proteoformer_model = ProteoformerForEmbedding.from_pretrained(proteoformer_checkpoint)
    proteoformer_model = proteoformer_model.to(device)
    proteoformer_model.eval()

    # The classifier's input width must equal the backbone's hidden size.
    embedding_dim = proteoformer_model.config.hidden_size
    print(f"  - Model hidden size: {embedding_dim}")
    print(f"  - Number of layers: {proteoformer_model.config.num_hidden_layers}")

    print(f"Loading VariantPTM classifier from {variant_ptm_checkpoint}...")
    variant_ptm_classifier = VariantPTMClassifier(
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        dropout=dropout
    )

    # NOTE(security): torch.load unpickles the file (depending on the torch
    # version/weights_only default); only load checkpoints from trusted sources.
    checkpoint = torch.load(variant_ptm_checkpoint, map_location=device)
    # Accept both a training checkpoint ({'model_state_dict': ...}) and a bare state_dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    variant_ptm_classifier.load_state_dict(state_dict)
    variant_ptm_classifier = variant_ptm_classifier.to(device)
    variant_ptm_classifier.eval()

    print("Models loaded successfully!")
    return proteoformer_model, variant_ptm_classifier, tokenizer

## 5. Define the Prediction Function

In [6]:
def predict_single_pair(
    proteoformer_model: ProteoformerForEmbedding,
    variant_ptm_classifier: VariantPTMClassifier,
    tokenizer: ProteoformerTokenizer,
    wt_peptide: str,
    mt_peptide: str,
    device: torch.device,
    max_length: int = 1024,
    threshold: float = 0.5
) -> Dict[str, Any]:
    """
    Predict PTM effect change for a single WT/MT peptide pair.

    This function:
    1. Encodes WT and MT sequences using Proteoformer
    2. Computes delta embeddings: ΔE = E_MT - E_WT
    3. Predicts PTM effect change using the classifier

    Both models are put into eval mode as a side effect.

    Args:
        proteoformer_model: Loaded Proteoformer model
        variant_ptm_classifier: Loaded VariantPTM classifier
        tokenizer: Loaded tokenizer
        wt_peptide: Wild-type peptide sequence
        mt_peptide: Mutant peptide sequence
        device: Torch device
        max_length: Maximum sequence length
        threshold: Probability at or above which the pair is labelled Decrease

    Returns:
        Dictionary containing:
        - probability: sigmoid of the classifier logit, i.e. P(Decrease) in [0, 1]
        - prediction: Binary prediction (0=Increase, 1=Decrease)
        - label: Human-readable prediction label ('Increase' or 'Decrease')
    """
    proteoformer_model.eval()
    variant_ptm_classifier.eval()

    with torch.no_grad():
        # WT and MT are embedded in separate calls with batch_size=1, so neither
        # sequence is padded against the other.
        wt_embedding = extract_embeddings(
            model=proteoformer_model,
            tokenizer=tokenizer,
            peptides=[wt_peptide],
            device=device,
            batch_size=1,
            max_length=max_length
        )  # (1, hidden_size), on CPU

        mt_embedding = extract_embeddings(
            model=proteoformer_model,
            tokenizer=tokenizer,
            peptides=[mt_peptide],
            device=device,
            batch_size=1,
            max_length=max_length
        )  # (1, hidden_size), on CPU

        # ΔE = E_MT - E_WT, moved back to the classifier's device.
        delta_embedding = (mt_embedding - wt_embedding).to(device)

        # .item() assumes the classifier emits a single logit per pair -- TODO
        # confirm against VariantPTMClassifier's output shape.
        logits = variant_ptm_classifier(delta_embedding)
        probability = torch.sigmoid(logits).item()
        prediction = 1 if probability >= threshold else 0

    return {
        'probability': probability,
        'prediction': prediction,
        'label': 'Decrease' if prediction == 1 else 'Increase'
    }

## 6. Load Models

In [7]:
# Load the backbone, classifier and tokenizer once; later cells reuse them.
_banner = "=" * 80
print(_banner)
print("Loading π-Proteoformer and VariantPTM Classifier")
print(_banner)

proteoformer_model, variant_ptm_classifier, tokenizer = load_models(
    PRETRAINED_MODEL_PATH,
    VARIANT_PTM_CHECKPOINT_PATH,
    device,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
)

Loading π-Proteoformer and VariantPTM Classifier


You are using a model of type proteoformer2 to instantiate a model of type proteoformer. This is not supported for all configurations of models and can yield errors.


  - Model hidden size: 1280
  - Number of layers: 24
Models loaded successfully!


## 7. Single Pair Prediction Example

In [8]:
# A single annotated variant record: the substitution E278Q near the
# sumoylation site at position 276 of AKT1, with WT and MT peptides.
example = dict(
    PTM="Sumoylation",
    Gene_symbol="AKT1",
    UniProt_ID="P31749",
    AA_Ref="E",
    AA_Pos=278,
    AA_Var="Q",
    Affected_site_pos=276,
    diff_pos=2,                  # equals AA_Pos - Affected_site_pos for this record
    Effect="Decrease",           # ground-truth label
    WT_pep="NVVYRDLKLENLMLD",
    MT_pep="NVVYRDLKLQNLMLD",
    Impact_Type="Indirect",
)

In [10]:
# Score the example WT/MT pair and compare the predicted label with the annotation.
result = predict_single_pair(
    proteoformer_model=proteoformer_model,
    variant_ptm_classifier=variant_ptm_classifier,
    tokenizer=tokenizer,
    wt_peptide=example['WT_pep'],
    mt_peptide=example['MT_pep'],
    device=device,
    max_length=MAX_LENGTH
)

is_correct = result['label'] == example['Effect']
print("\nPrediction Results:")
print(f"  - P(Decrease): {result['probability']:.4f}")
print(f"  - Predicted Effect: {result['label']}")
print(f"  - Ground Truth: {example['Effect']}")
print(f"  - Correct: {'✓' if is_correct else '✗'}")


Prediction Results:
  - Predicted Effect: Decrease
  - Correct: ✓


: 