# π-Proteoformer Isoform Embedding Shift 

This notebook demonstrates how to use the pre-trained π-Proteoformer model
to compute the embedding shift between protein isoforms and their canonical sequences.

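The embedding shift is defined as the cosine distance between the mean-pooled embeddings of an isoform and its canonical sequence:

$$\text{shift} = 1 - \cos\left(\mathbf{e}_{\text{isoform}}, \mathbf{e}_{\text{canonical}}\right)$$

where $\cos(\cdot,\cdot)$ denotes cosine similarity. The shift ranges from 0 (same direction) to 2 (opposite directions).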

## 1: Import Libraries and Set Up Path

In [2]:
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Make the in-repo `proteoformer` package importable.
# PROJECT_ROOT must point at the directory that contains `pi-Proteoformer/`.
# The literal default is a placeholder: either edit it, or set the
# PROTEOFORMER_PROJECT_ROOT environment variable before starting the kernel.
PROJECT_ROOT = Path(os.environ.get("PROTEOFORMER_PROJECT_ROOT", "your/project/root"))
PROTEOFORMER_SRC = str(PROJECT_ROOT / "pi-Proteoformer" / "src")

# Prepend only when it is not already first, so re-running this cell does not
# keep stacking duplicate sys.path entries.
if not sys.path or sys.path[0] != PROTEOFORMER_SRC:
    sys.path.insert(0, PROTEOFORMER_SRC)

# Import Proteoformer components (resolved via the path added above)
from proteoformer.net import ProteoformerForEmbedding
from proteoformer.tokenization import ProteoformerTokenizer

print("Libraries imported successfully!")

Libraries imported successfully!


## 2: Configuration

In [2]:
# ---- Paths -------------------------------------------------------------
# Pre-trained π-Proteoformer checkpoint (relative to the working directory)
PRETRAINED_MODEL_PATH = 'pretrain_stage2_balance_1126/checkpoint-197000'
# FASTA file holding both canonical and isoform sequences
FASTA_FILE = "isoform_seq.fasta"

# ---- Inference settings ------------------------------------------------
MAX_LENGTH = 1024  # tokenizer truncation length, in tokens
BATCH_SIZE = 8     # inference batch size; NOTE(review): not referenced by the cells shown here

# ---- Example pair ------------------------------------------------------
# FOXP2_HUMAN (Forkhead box protein P2) and its isoform 8
CANONICAL_ID, ISOFORM_ID = "O15409", "O15409-8"

# ---- Device ------------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


## 3: Define Helper Functions

In [3]:
def _fasta_header_id(header: str, line_no: int) -> str:
    """Return the protein ID encoded in a FASTA header line (which starts with '>')."""
    tokens = header[1:].split()
    if not tokens:
        raise ValueError(f"Empty FASTA header at line {line_no}")
    first_token = tokens[0]
    # UniProt style: db|ACCESSION|ENTRY_NAME -> take ACCESSION. Only the first
    # token is inspected, so pipes inside the free-text description are ignored.
    fields = first_token.split('|')
    return fields[1] if len(fields) >= 2 else first_token


def parse_fasta(fasta_file: str) -> Dict[str, str]:
    """
    Parse a FASTA file and return a dictionary mapping protein IDs to sequences.

    The ID is taken from the first whitespace-delimited token of each header:
    UniProt-style tokens such as ``sp|A0FGR8-2|ESYT2_HUMAN`` yield the accession
    between the first two pipes; any other token is used verbatim.

    Blank lines are skipped, and sequence lines appearing before the first
    header are discarded. If an ID occurs more than once, the last record wins.

    Args:
        fasta_file: Path to FASTA file (``str`` or path-like)

    Returns:
        Dictionary mapping protein IDs (e.g., 'A0FGR8', 'A0FGR8-2') to sequences

    Raises:
        ValueError: If a header line carries no identifier (e.g. a bare '>').
    """
    sequences: Dict[str, str] = {}
    current_id = None
    current_seq = []

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(fasta_file, 'r', encoding='utf-8') as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank lines between or inside records
            if line.startswith('>'):
                # Save the previous record before starting a new one
                if current_id is not None:
                    sequences[current_id] = ''.join(current_seq)
                current_id = _fasta_header_id(line, line_no)
                current_seq = []
            else:
                current_seq.append(line)

    # Save the last record
    if current_id is not None:
        sequences[current_id] = ''.join(current_seq)

    return sequences


def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Average token representations over the sequence axis, skipping padding.

    Args:
        hidden_states: Token-level hidden states, shape (batch_size, seq_len, hidden_size)
        attention_mask: Mask of shape (batch_size, seq_len); positions with 0 (padding)
            contribute nothing to the average

    Returns:
        Per-sequence mean embedding, shape (batch_size, hidden_size)
    """
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    # Clamp the denominator so an all-padding row yields 0 instead of NaN (0/0).
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    masked_total = (hidden_states * weights).sum(dim=1)
    return masked_total / token_count


def cosine_distance(x: torch.Tensor, y: torch.Tensor) -> float:
    """
    Return ``1 - cosine_similarity(x, y)`` as a Python float.

    Both inputs are flattened into single row vectors first, so tensors of any
    shape are accepted as long as they hold the same number of elements.

    Args:
        x: First embedding tensor
        y: Second embedding tensor

    Returns:
        Cosine distance, a float between 0 (same direction) and 2 (opposite)
    """
    similarity = F.cosine_similarity(x.reshape(1, -1), y.reshape(1, -1), dim=1)
    return 1.0 - similarity.item()


def extract_embedding(
    model: ProteoformerForEmbedding,
    tokenizer: ProteoformerTokenizer,
    sequence: str,
    device: torch.device,
    max_length: int = 1024
) -> torch.Tensor:
    """
    Embed one protein sequence as the masked mean of its final hidden states.

    The sequence is tokenized without special tokens and truncated to
    ``max_length`` tokens. It is then run through ``model`` with gradients
    disabled, and ``last_hidden_state`` is mean-pooled over the attention mask.
    Note: sequences longer than ``max_length`` tokens are silently truncated.

    Args:
        model: Pretrained ProteoformerForEmbedding model
        tokenizer: ProteoformerTokenizer
        sequence: Protein sequence (amino acid string)
        device: Device to run inference on
        max_length: Maximum sequence length for model input

    Returns:
        Embedding tensor (hidden_size,), on CPU
    """
    batch = tokenizer(
        [sequence],
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)

    with torch.no_grad():
        token_states = model(input_ids=input_ids, attention_mask=mask).last_hidden_state
        pooled = mean_pooling(token_states, mask)  # (1, hidden_size)

    return pooled[0].cpu()  # drop the singleton batch axis -> (hidden_size,)

def get_canonical_id(isoform_id: str) -> str:
    """
    Strip the isoform suffix from a UniProt-style isoform accession.

    Everything from the first '-' onward is dropped; IDs without a '-' are
    returned unchanged.

    Args:
        isoform_id: Isoform ID (e.g., 'A0FGR8-2')

    Returns:
        Canonical protein ID (e.g., 'A0FGR8')
    """
    canonical_id, _separator, _suffix = isoform_id.partition('-')
    return canonical_id

## 4: Define Model Loading Function

In [4]:
def load_model(
    checkpoint_path: str,
    device: torch.device
) -> Tuple[ProteoformerForEmbedding, ProteoformerTokenizer]:
    """
    Load the π-Proteoformer tokenizer and embedding model from one checkpoint.

    The model is moved to ``device`` and switched to eval mode. A short summary
    of its configuration is then printed.

    Args:
        checkpoint_path: Path to model checkpoint
        device: Torch device to load model on

    Returns:
        Tuple of (model, tokenizer)
    """
    tokenizer = ProteoformerTokenizer.from_pretrained(checkpoint_path)

    model = ProteoformerForEmbedding.from_pretrained(checkpoint_path).to(device)
    model.eval()

    # Report key architecture sizes, one line per config attribute
    config = model.config
    for label, attr_name in (
        ("Model hidden size", "hidden_size"),
        ("Number of layers", "num_hidden_layers"),
        ("Number of attention heads", "num_attention_heads"),
    ):
        print(f"  - {label}: {getattr(config, attr_name)}")

    print("Model loaded successfully!")
    return model, tokenizer


## 5: Define Embedding Shift Computation Functions

In [5]:
def compute_embedding_shift(
    model: ProteoformerForEmbedding,
    tokenizer: ProteoformerTokenizer,
    canonical_seq: str,
    isoform_seq: str,
    device: torch.device,
    max_length: int = 1024
) -> Dict[str, Any]:
    """
    Compute embedding shift between a canonical protein and its isoform.

    The embedding shift is defined as the Cosine distance:
        shift = 1 - cosine_similarity(isoform_embedding, canonical_embedding)

    Args:
        model: Loaded π-Proteoformer model (switched to eval mode here)
        tokenizer: Loaded tokenizer
        canonical_seq: Canonical protein sequence
        isoform_seq: Isoform sequence
        device: Torch device
        max_length: Maximum sequence length (longer inputs are truncated)

    Returns:
        Dictionary containing:
        - canonical_embedding: Embedding vector for canonical sequence
        - isoform_embedding: Embedding vector for isoform sequence
        - cosine_similarity: Cosine similarity between embeddings
        - cosine_distance: Cosine distance (embedding shift)
        - euclidean_distance: Euclidean distance between embeddings

    Raises:
        ValueError: If either sequence is None or empty.
    """
    # Fail fast with a clear message; a None here (e.g. from a failed dict
    # lookup) would otherwise surface as an opaque error inside the tokenizer.
    for arg_name, seq in (("canonical_seq", canonical_seq), ("isoform_seq", isoform_seq)):
        if not seq:
            raise ValueError(f"{arg_name} must be a non-empty protein sequence, got {seq!r}")

    model.eval()

    print("Extracting canonical embedding...")
    canonical_emb = extract_embedding(
        model=model,
        tokenizer=tokenizer,
        sequence=canonical_seq,
        device=device,
        max_length=max_length
    )

    print("Extracting isoform embedding...")
    isoform_emb = extract_embedding(
        model=model,
        tokenizer=tokenizer,
        sequence=isoform_seq,
        device=device,
        max_length=max_length
    )

    # Cosine similarity of the two 1-D embeddings (treated as a batch of one)
    cos_sim = F.cosine_similarity(
        canonical_emb.unsqueeze(0),
        isoform_emb.unsqueeze(0),
        dim=1
    ).item()

    # Cosine distance = embedding shift
    cos_dist = 1.0 - cos_sim

    # Euclidean (L2) distance, reported alongside for reference
    euclidean_dist = torch.norm(canonical_emb - isoform_emb).item()

    return {
        'canonical_embedding': canonical_emb,
        'isoform_embedding': isoform_emb,
        'cosine_similarity': cos_sim,
        'cosine_distance': cos_dist,
        'euclidean_distance': euclidean_dist
    }

## 6: Load Model

In [6]:
# Banner, then load the checkpoint configured above onto the selected device
print("=" * 80, "Loading π-Proteoformer Model", "=" * 80, sep="\n")

model, tokenizer = load_model(PRETRAINED_MODEL_PATH, device)

Loading π-Proteoformer Model


You are using a model of type proteoformer2 to instantiate a model of type proteoformer. This is not supported for all configurations of models and can yield errors.


  - Model hidden size: 1280
  - Number of layers: 24
  - Number of attention heads: 16
Model loaded successfully!


## 7: Load Sequences from FASTA

In [8]:
print("=" * 80)
print("Loading Sequences from FASTA")
print("=" * 80)

# Parse FASTA file
sequences = parse_fasta(FASTA_FILE)

# Fail fast if either example ID is absent: `.get` would silently return None
# and the failure would only surface later, inside the tokenizer.
missing_ids = [pid for pid in (CANONICAL_ID, ISOFORM_ID) if pid not in sequences]
if missing_ids:
    raise KeyError(f"IDs not found in {FASTA_FILE}: {missing_ids}")

# Sanity check: the configured isoform must belong to the configured canonical entry
if get_canonical_id(ISOFORM_ID) != CANONICAL_ID:
    raise ValueError(f"{ISOFORM_ID} is not an isoform of {CANONICAL_ID}")

# Get our example sequences
canonical_seq = sequences[CANONICAL_ID]
isoform_seq = sequences[ISOFORM_ID]

print(f"Loaded {len(sequences)} sequences; "
      f"{CANONICAL_ID}: {len(canonical_seq)} aa, {ISOFORM_ID}: {len(isoform_seq)} aa")

Loading Sequences from FASTA


## 8: Single Pair Embedding Shift Example

In [9]:
# Embedding shift for the example canonical/isoform pair configured above
result = compute_embedding_shift(
    model, tokenizer, canonical_seq, isoform_seq, device, max_length=MAX_LENGTH
)

Extracting canonical embedding...
Extracting isoform embedding...


In [10]:
# One line per metric; same text as four consecutive print() calls
print(
    "\nEmbedding Shift Results:",
    f"  - Embedding dimension: {result['canonical_embedding'].shape[0]}",
    f"  - Cosine Similarity: {result['cosine_similarity']:.6f}",
    f"  - Cosine Distance (Embedding Shift): {result['cosine_distance']:.6f}",
    sep="\n",
)


Embedding Shift Results:
  - Embedding dimension: 1280
  - Cosine Similarity: 0.605798
  - Cosine Distance (Embedding Shift): 0.394202
