# Pseq2Sites Binding Site Embeddings

This notebook demonstrates how to use the `Pseq2SitesEmbeddings` class to extract binding site embeddings from per-residue protein features and their amino acid sequences.

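The `Pseq2Sites` model predicts, for every residue in a protein sequence, the probability that the residue belongs to a ligand binding site. The same model can also return its internal per-residue representations as embeddings that encode this binding site information.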

## Setup and Imports

In [None]:
import sys
import os
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add modules to path
sys.path.append("../")

from modules.pocket_modules.pseq2sites_embeddings import Pseq2SitesEmbeddings, extract_binding_site_embeddings

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")

print("✓ All imports successful")

## Load Example Data

We load the CASF2016 protein features and sequences if they exist locally; otherwise we fall back to randomly generated dummy data so the rest of the notebook still runs.

In [None]:
def load_example_data():
    """Load example protein data."""
    try:
        # Try to load real CASF2016 data
        feature_path = "../input_data/PDB/BA/CASF2016_protein_features.pkl"
        seq_path = "../input_data/PDB/BA/CASF2016_BA_data.tsv"
        
        if os.path.exists(feature_path) and os.path.exists(seq_path):
            # Load features
            with open(feature_path, 'rb') as f:
                protein_features = pickle.load(f)
                
            # Load sequences
            df = pd.read_csv(seq_path, sep='\t')
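            # Column 1 holds the protein ID and column 4 the sequence; this positional
            # indexing depends on the TSV column order, so check it if the file changes.
            protein_sequences = dict(zip(df.iloc[:, 1].values, df.iloc[:, 4].values))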
            
            print(f"✓ Loaded {len(protein_features)} protein features")
            print(f"✓ Loaded {len(protein_sequences)} protein sequences")
            
            return protein_features, protein_sequences
        else:
            print("Real data not found, creating dummy data...")
            return create_dummy_data()
            
    except Exception as e:
        print(f"Error loading real data: {e}")
        print("Creating dummy data...")
        return create_dummy_data()

def create_dummy_data(seed=0):
    """Create random dummy data for demonstration.

    Random features only exercise the pipeline; predictions on them carry no
    biological meaning.
    """
    rng = np.random.default_rng(seed)  # fixed seed so reruns give the same data
    protein_ids = ["P12345", "Q67890", "R54321", "S98765", "T11111"]
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    
    protein_features = {}
    protein_sequences = {}
    
    for pid in protein_ids:
        # Random sequence length between 100 and 399 residues
        seq_len = int(rng.integers(100, 400))
        
        # ProtBERT-like features: one 1024-dim vector per residue
        features = rng.standard_normal((seq_len, 1024), dtype=np.float32)
        protein_features[pid] = features
        
        # Matching random amino acid sequence of the same length
        sequence = ''.join(rng.choice(list(amino_acids), seq_len))
        protein_sequences[pid] = sequence
        
    print(f"✓ Created dummy data for {len(protein_ids)} proteins")
    return protein_features, protein_sequences

# Load the data
protein_features, protein_sequences = load_example_data()

# Display basic info
print("\nDataset Summary:")
print(f"Number of proteins: {len(protein_features)}")
if protein_features:
    example_pid = next(iter(protein_features))
    print(f"Example protein: {example_pid}")
    print(f"Feature shape: {protein_features[example_pid].shape}")
    print(f"Sequence length: {len(protein_sequences[example_pid])}")
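
Each feature matrix should have one row per residue of the matching sequence. The hypothetical helper below (not part of `Pseq2SitesEmbeddings`) reports proteins whose row count and sequence length disagree, or whose sequence is missing. Whether the real CASF2016 features include extra special-token rows (e.g. `[CLS]`/`[SEP]`) is not stated here, so `special_tokens` is an assumption you may need to adjust.

```python
import numpy as np

def check_feature_sequence_lengths(protein_features, protein_sequences, special_tokens=0):
    """Return {pid: (n_feature_rows, seq_len)} for proteins whose lengths disagree.

    `special_tokens` is the number of extra rows a feature extractor may add
    per protein; the default of 0 is an assumption.
    """
    mismatches = {}
    for pid, feats in protein_features.items():
        seq = protein_sequences.get(pid)
        if seq is None:
            # Feature matrix without a sequence
            mismatches[pid] = (feats.shape[0], None)
        elif feats.shape[0] != len(seq) + special_tokens:
            mismatches[pid] = (feats.shape[0], len(seq))
    return mismatches
```

Calling `check_feature_sequence_lengths(protein_features, protein_sequences)` returns `{}` when every pair lines up, which is always the case for the dummy data.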

## Initialize the Pseq2Sites Embedding Model

In [None]:
# Initialize the embedding model
print("Initializing Pseq2Sites embedding model...")

try:
    embedder = Pseq2SitesEmbeddings(
        device="auto"  # Automatically choose GPU if available
    )
    print("✓ Model initialized successfully!")
except Exception as e:
    print(f"❌ Error initializing model: {e}")
    print("\nThis might be due to:")
    print("1. Missing model checkpoint file")
    print("2. Missing configuration file")
    print("3. CUDA/device issues")
    raise
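
The comment on `device="auto"` says the model picks a GPU when one is available. A minimal sketch of that rule, assuming it is based on PyTorch's `torch.cuda.is_available()` (the class's actual logic is not shown here), looks like this:

```python
def resolve_device(device="auto"):
    """Map "auto" to "cuda" when PyTorch reports a GPU, otherwise "cpu".

    Explicit values such as "cpu" or "cuda:1" are passed through unchanged.
    This is a guess at the behavior, not the Pseq2SitesEmbeddings implementation.
    """
    if device != "auto":
        return device
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no GPU backend to use
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"
```

If initialization fails with a CUDA error, passing an explicit `device="cpu"` (assuming the constructor accepts plain device strings) is a quick way to rule out cause 3.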

## Extract Binding Site Embeddings

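Now we extract embeddings for the first three proteins. With `return_predictions=True`, each result also includes per-residue binding site probabilities; `return_attention=True` requests the model's attention weights as well.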

In [None]:
# Select a subset of proteins for demonstration
demo_proteins = list(protein_features.keys())[:3]
demo_features = {pid: protein_features[pid] for pid in demo_proteins}
demo_sequences = {pid: protein_sequences[pid] for pid in demo_proteins}

print(f"Extracting embeddings for {len(demo_proteins)} proteins: {demo_proteins}")

# Extract embeddings
results = embedder.extract_embeddings(
    protein_features=demo_features,
    protein_sequences=demo_sequences,
    batch_size=2,
    return_predictions=True,
    return_attention=True
)

print(f"\n✓ Successfully extracted embeddings for {len(results)} proteins")
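
With `batch_size=2`, proteins are processed in groups, but their sequence lengths differ. A common way to batch variable-length inputs is to pad every matrix to the longest one in the batch and keep a boolean mask of real residues. The helpers below are hypothetical; whether `Pseq2SitesEmbeddings` pads this way, or instead truncates to a fixed maximum length, is not shown in this notebook.

```python
import numpy as np

def iter_batches(items, batch_size):
    """Yield consecutive chunks of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def pad_batch(feature_list):
    """Stack (L_i, D) arrays into a zero-padded (B, L_max, D) array plus a (B, L_max) mask."""
    max_len = max(f.shape[0] for f in feature_list)
    dim = feature_list[0].shape[1]
    batch = np.zeros((len(feature_list), max_len, dim), dtype=np.float32)
    mask = np.zeros((len(feature_list), max_len), dtype=bool)
    for i, feats in enumerate(feature_list):
        n = feats.shape[0]
        batch[i, :n] = feats
        mask[i, :n] = True  # True marks real residues, False marks padding
    return batch, mask
```

The mask matters downstream: any per-residue output computed on the padded batch should be cut back to each protein's true length before it is analyzed.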

## Analyze the Results

For each protein, we print the embedding shapes and, when predictions are available, summary statistics and the five highest-scoring residues.

In [None]:
# Display detailed results for each protein
for pid, result in results.items():
    print(f"\n{'='*50}")
    print(f"PROTEIN: {pid}")
    print(f"{'='*50}")
    
    seq = result['sequence']
    print(f"Sequence length: {result['sequence_length']}")
    print(f"Sequence: {seq[:50]}{'...' if len(seq) > 50 else ''}")
    
    print(f"\nEmbedding shapes:")
    print(f"  Sequence embeddings: {result['sequence_embeddings'].shape}")
    print(f"  Protein embeddings: {result['protein_embeddings'].shape}")
    
    if 'binding_site_probabilities' in result:
        probs = result['binding_site_probabilities']
        predicted_sites = result['predicted_binding_sites']
        
        print(f"\nBinding site predictions:")
        print(f"  Number of predicted sites (>0.5): {np.sum(predicted_sites)}")
        print(f"  Max probability: {np.max(probs):.3f}")
        print(f"  Mean probability: {np.mean(probs):.3f}")
        print(f"  Std probability: {np.std(probs):.3f}")
        
        # Top binding sites
        top_indices = np.argsort(probs)[-5:][::-1]
        print(f"\n  Top 5 binding sites:")
        for i, idx in enumerate(top_indices, 1):
            aa = result['sequence'][idx] if idx < len(result['sequence']) else 'X'
            print(f"    {i}. Position {idx+1} ({aa}): {probs[idx]:.3f}")
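
The loop above guards `idx < len(sequence)` and prints `'X'` otherwise, which suggests the probability array might be longer than the sequence. This hypothetical helper trims the probabilities to the sequence length first, so padded positions can never appear among the top residues, and returns 1-based positions to match the printout.

```python
import numpy as np

def top_k_sites(probs, sequence, k=5):
    """Return [(1-based position, residue, probability), ...] for the k highest-scoring residues."""
    probs = np.asarray(probs)[: len(sequence)]  # drop anything beyond the real sequence
    k = min(k, len(probs))
    order = np.argsort(probs)[::-1][:k]  # indices sorted from highest to lowest probability
    return [(int(i) + 1, sequence[i], float(probs[i])) for i in order]
```

For example, `top_k_sites(result['binding_site_probabilities'], result['sequence'])` reproduces the "Top 5 binding sites" list without the `'X'` fallback.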

## Generate Binding Site Summary

In [None]:
# Generate a summary of binding site predictions
summary = embedder.get_binding_site_summary(results, threshold=0.5)

print("Binding Site Prediction Summary:")
print("=" * 60)
display(summary)
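
To see what such a summary contains, here is a minimal sketch that builds one row per protein from the per-residue probabilities. The column names `sequence_length`, `num_predicted_binding_sites`, and `binding_site_percentage` mirror those read from `batch_summary` later in this notebook; `protein_id` and `max_probability` are hypothetical, and how `get_binding_site_summary` actually computes each column is an assumption.

```python
import numpy as np

def summarize_predictions(results, threshold=0.5):
    """Return a list of per-protein summary dicts (one row per protein)."""
    rows = []
    for pid, result in results.items():
        seq_len = result["sequence_length"]
        # Only score real residues, in case the probability array is padded
        probs = np.asarray(result["binding_site_probabilities"])[:seq_len]
        n_sites = int(np.sum(probs > threshold))  # strict ">" as in the printout above
        rows.append({
            "protein_id": pid,
            "sequence_length": seq_len,
            "num_predicted_binding_sites": n_sites,
            "binding_site_percentage": 100.0 * n_sites / seq_len,
            "max_probability": float(probs.max()),
        })
    return rows
```

Wrapping the output in `pd.DataFrame(summarize_predictions(results))` gives a table that can be compared column by column with `summary`.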

## Visualize Binding Site Predictions

In [None]:
# Create visualizations; squeeze=False keeps `axes` 2-D even for a single protein
fig, axes = plt.subplots(len(results), 2, figsize=(15, 4 * len(results)), squeeze=False)

for i, (pid, result) in enumerate(results.items()):
    if 'binding_site_probabilities' not in result:
        continue
        
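    probs = result['binding_site_probabilities']
    # Derive positions from the probability array itself so x and y always have
    # the same length, even if the array is longer than the sequence (e.g. padding)
    positions = np.arange(1, len(probs) + 1)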
    
    # Plot 1: Binding site probabilities along sequence
    axes[i, 0].plot(positions, probs, 'b-', alpha=0.7, linewidth=1)
    axes[i, 0].axhline(y=0.5, color='r', linestyle='--', alpha=0.7, label='Threshold (0.5)')
    axes[i, 0].fill_between(positions, probs, alpha=0.3)
    axes[i, 0].set_title(f'Binding Site Probabilities - {pid}')
    axes[i, 0].set_xlabel('Residue Position')
    axes[i, 0].set_ylabel('Binding Probability')
    axes[i, 0].legend()
    axes[i, 0].grid(True, alpha=0.3)
    
    # Plot 2: Histogram of probabilities
    axes[i, 1].hist(probs, bins=50, alpha=0.7, edgecolor='black')
    axes[i, 1].axvline(x=0.5, color='r', linestyle='--', alpha=0.7, label='Threshold (0.5)')
    axes[i, 1].set_title(f'Probability Distribution - {pid}')
    axes[i, 1].set_xlabel('Binding Probability')
    axes[i, 1].set_ylabel('Frequency')
    axes[i, 1].legend()
    axes[i, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
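
Binding sites usually show up as runs of consecutive residues above the threshold rather than isolated positions. The hypothetical helper below extracts those runs as 1-based, inclusive `(start, end)` ranges, which can then be shaded on the probability plot.

```python
import numpy as np

def find_site_segments(probs, threshold=0.5):
    """Return 1-based inclusive (start, end) runs where probs > threshold."""
    above = np.asarray(probs) > threshold
    # Pad with False on both ends so every run has a rising and a falling edge
    edges = np.diff(np.concatenate(([False], above, [False])).astype(int))
    starts = np.flatnonzero(edges == 1)   # 0-based index where each run begins
    ends = np.flatnonzero(edges == -1)    # 0-based index one past where each run ends
    return [(int(s) + 1, int(e)) for s, e in zip(starts, ends)]
```

Inside the plotting loop, `for start, end in find_site_segments(probs): axes[i, 0].axvspan(start, end, color='orange', alpha=0.2)` highlights each predicted segment.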

## Analyze Embeddings

In [None]:
# Analyze the embedding spaces
print("Embedding Analysis:")
print("=" * 50)

for pid, result in results.items():
    seq_emb = result['sequence_embeddings']
    prot_emb = result['protein_embeddings']
    
    print(f"\nProtein {pid}:")
    print(f"  Sequence embeddings - Shape: {seq_emb.shape}")
    print(f"    Mean: {np.mean(seq_emb):.3f}, Std: {np.std(seq_emb):.3f}")
    print(f"    Min: {np.min(seq_emb):.3f}, Max: {np.max(seq_emb):.3f}")
    
    print(f"  Protein embeddings - Shape: {prot_emb.shape}")
    print(f"    Mean: {np.mean(prot_emb):.3f}, Std: {np.std(prot_emb):.3f}")
    print(f"    Min: {np.min(prot_emb):.3f}, Max: {np.max(prot_emb):.3f}")
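
Both embedding types are per-residue, but many downstream tasks need one fixed-size vector per protein. One common option, not something this notebook or Pseq2Sites prescribes, is to average the residue vectors weighted by their binding probability so that likely binding residues dominate. The sketch assumes the first axis of the embedding array indexes residues (squeeze out any leading batch axis first).

```python
import numpy as np

def pool_binding_site_embedding(residue_embeddings, probs, threshold=None):
    """Collapse (L, D) per-residue embeddings into one D-dim vector.

    Residues are weighted by binding probability; if `threshold` is given,
    residues at or below it get zero weight. Falls back to a plain mean
    when no residue carries any weight.
    """
    emb = np.asarray(residue_embeddings, dtype=np.float64)
    w = np.asarray(probs, dtype=np.float64)
    n = min(emb.shape[0], w.shape[0])  # ignore padding beyond the shorter array
    emb, w = emb[:n], w[:n]
    if threshold is not None:
        w = np.where(w > threshold, w, 0.0)
    if w.sum() == 0:
        return emb.mean(axis=0)
    return (w[:, None] * emb).sum(axis=0) / w.sum()
```

For example, `pool_binding_site_embedding(result['protein_embeddings'], result['binding_site_probabilities'])` yields one vector per protein.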

## Save Results

In [None]:
# Create output directory
output_dir = "../results/notebook_embeddings/"
os.makedirs(output_dir, exist_ok=True)

# Save embeddings
embedder.save_embeddings(
    results, 
    os.path.join(output_dir, "pseq2sites_embeddings.pkl")
)

# Save summary
summary.to_csv(os.path.join(output_dir, "binding_site_summary.csv"), index=False)

print(f"✓ Results saved to {output_dir}")
print(f"  - Embeddings: pseq2sites_embeddings.pkl")
print(f"  - Summary: binding_site_summary.csv")
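
To reuse the saved embeddings in a later session, a minimal loader is sketched below. It assumes `save_embeddings` writes a plain pickle of the `results` dict, which the `.pkl` extension suggests but this notebook does not confirm; check the class source before relying on it. Only unpickle files you trust.

```python
import pickle

def load_embeddings(path):
    """Reload a results dict from a pickle file (assumed format, see lead-in)."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

For example, `load_embeddings(os.path.join(output_dir, "pseq2sites_embeddings.pkl"))`, while the summary can be read back with `pd.read_csv`.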

## Advanced Usage: Process Multiple Proteins

Next, we process up to 10 proteins with a larger batch size, skipping attention weights to save time.

In [None]:
# Process a larger subset
if len(protein_features) > 3:
    print("Processing larger batch of proteins...")
    
    # Take up to 10 proteins (slicing stops at the end of the list)
    batch_proteins = list(protein_features.keys())[:10]
    batch_features = {pid: protein_features[pid] for pid in batch_proteins}
    batch_sequences = {pid: protein_sequences[pid] for pid in batch_proteins}
    
    # Extract embeddings
    batch_results = embedder.extract_embeddings(
        protein_features=batch_features,
        protein_sequences=batch_sequences,
        batch_size=4,
        return_predictions=True,
        return_attention=False  # Skip attention for speed
    )
    
    # Generate summary
    batch_summary = embedder.get_binding_site_summary(batch_results)
    
    print(f"\n✓ Processed {len(batch_results)} proteins")
    print("\nBatch Summary Statistics:")
    print(f"Average sequence length: {batch_summary['sequence_length'].mean():.1f}")
    print(f"Average binding sites per protein: {batch_summary['num_predicted_binding_sites'].mean():.1f}")
    print(f"Average binding site percentage: {batch_summary['binding_site_percentage'].mean():.1f}%")
    
    # Plot summary statistics
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    axes[0].hist(batch_summary['sequence_length'], bins=10, alpha=0.7, edgecolor='black')
    axes[0].set_title('Sequence Length Distribution')
    axes[0].set_xlabel('Sequence Length')
    axes[0].set_ylabel('Count')
    
    axes[1].hist(batch_summary['num_predicted_binding_sites'], bins=10, alpha=0.7, edgecolor='black')
    axes[1].set_title('Number of Binding Sites')
    axes[1].set_xlabel('Number of Sites')
    axes[1].set_ylabel('Count')
    
    axes[2].hist(batch_summary['binding_site_percentage'], bins=10, alpha=0.7, edgecolor='black')
    axes[2].set_title('Binding Site Percentage')
    axes[2].set_xlabel('Percentage (%)')
    axes[2].set_ylabel('Count')
    
    plt.tight_layout()
    plt.show()
else:
    print("Not enough proteins for batch processing demo")

## Summary

This notebook demonstrated how to:

1. **Initialize** the Pseq2Sites embedding model
2. **Extract embeddings** from protein features and sequences, including:
   - Sequence-level embeddings (256-dim per residue)
   - Protein-level embeddings (256-dim per residue) 
   - Binding site probability predictions
   - Attention weights
3. **Analyze results** with summary statistics and visualizations
4. **Save** embeddings and the binding site summary for future use

### Key Applications:

- **Drug Discovery**: Identify potential binding sites for drug design
- **Protein Analysis**: Understand protein-ligand interaction patterns
- **Machine Learning**: Use embeddings as features for downstream tasks
- **Comparative Studies**: Compare binding site patterns across proteins

### Next Steps:

- Use embeddings for clustering similar binding sites
- Train downstream models for specific drug discovery tasks
- Analyze attention patterns to understand model decision-making
- Compare predictions with experimental binding site data
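
Clustering similar binding sites typically starts from pairwise similarity between fixed-size vectors, for example one vector per protein obtained by averaging `protein_embeddings` over the residue axis. The sketch below computes a cosine similarity matrix; the choice of clustering algorithm that consumes it (e.g. hierarchical clustering) is left open.

```python
import numpy as np

def cosine_similarity_matrix(vectors):
    """Pairwise cosine similarity for an (N, D) array of per-protein vectors."""
    x = np.asarray(vectors, dtype=np.float64)
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    x = x / np.clip(norms, 1e-12, None)  # avoid division by zero for all-zero vectors
    return x @ x.T
```

Stacking one vector per protein with `np.stack([...])` and passing the result in gives an N×N matrix whose diagonal is 1.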


## Additional Context: Ligand Encoding in BlendNet

While this notebook focuses on the `Pseq2Sites` model for protein pocket embeddings, it's useful to understand how ligands are handled in the broader BlendNet architecture.

In BlendNet:

*   **Ligands are Encoded as Vectors**: Ligands (compounds) are processed and converted into vector representations.
*   **Graph Neural Networks (GNNs)**: Ligands are typically represented as molecular graphs (atoms as nodes, bonds as edges). A GNN (such as PNA or Net3D, as seen in BlendNet's `compound_modules`) is used to learn representations from these graphs.
*   **Types of Vector Representations**:
    *   **Atom Embeddings (`node_representations`)**: The GNN outputs a vector for each atom in the ligand, capturing its chemical context within the molecule.
    *   **Graph-Level Embeddings (`graph_representations`)**: Often, the atom embeddings are aggregated (e.g., through sum, mean, or max pooling) to produce a single vector representing the entire ligand molecule.
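
The readout step that turns atom embeddings into a graph-level embedding can be sketched in NumPy for a single molecule. This is an illustration of sum, mean, and max pooling only; which readout BlendNet uses is not specified here, and GNN libraries apply the same operation per graph across a batch of molecules.

```python
import numpy as np

def graph_readout(node_representations, mode="mean"):
    """Aggregate (num_atoms, D) atom embeddings into one D-dim molecule vector."""
    h = np.asarray(node_representations, dtype=np.float64)
    if mode == "sum":
        return h.sum(axis=0)
    if mode == "mean":
        return h.mean(axis=0)
    if mode == "max":
        return h.max(axis=0)
    raise ValueError(f"unknown readout mode: {mode!r}")
```

Sum pooling preserves information about molecule size, while mean and max pooling give vectors that do not grow with atom count.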

These ligand vector representations are then combined with protein pocket representations (like those from Pseq2Sites or similar models) in subsequent layers of BlendNet (e.g., cross-attention mechanisms) to predict protein-ligand binding affinity and interaction patterns.
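
To make the cross-attention idea concrete, here is a generic single-head scaled dot-product attention in which ligand atoms query pocket residues. It is not BlendNet's layer: the weight matrices stand in for learned parameters, and the real model may use multiple heads, attention in both directions, or padding masks.

```python
import numpy as np

def cross_attention(ligand_atoms, pocket_residues, w_q, w_k, w_v):
    """Single-head scaled dot-product attention: ligand atoms attend over pocket residues.

    ligand_atoms: (A, D_l), pocket_residues: (R, D_p)
    w_q: (D_l, d), w_k: (D_p, d), w_v: (D_p, d_v)
    Returns (attended, weights) with shapes (A, d_v) and (A, R).
    """
    q = ligand_atoms @ w_q
    k = pocket_residues @ w_k
    v = pocket_residues @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # softmax over residues
    return weights @ v, weights
```

Each row of `weights` shows how strongly one ligand atom attends to each pocket residue, which is the kind of signal interaction-pattern analyses inspect.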
