# Model Serving for ESM2 Contact Prediction

This notebook demonstrates how to serve the trained ESM2 contact prediction model on new PDB files. We'll cover the complete pipeline from PDB file input to contact prediction output, including:

- Loading trained models using the serving infrastructure
- Processing PDB files and extracting protein chains
- Generating ESM2 embeddings for contact prediction
- Making predictions and analyzing results
- Batch processing multiple PDB files
- REST API serving (optional)

## What You'll Learn

1. **Model Loading**: How to load trained models for inference
2. **PDB Processing**: Extract protein sequences and structural information
3. **Feature Generation**: Create ESM2 embeddings from protein sequences
4. **Contact Prediction**: Use the model to predict protein contact maps
5. **Result Analysis**: Visualize and evaluate predictions
6. **Production Serving**: Deploy models via REST API

## Prerequisites

- Trained model checkpoint (`.pth` file) or MLflow model URI
- PDB files for prediction
- Required dependencies installed (`esm`, `torch`, `mlflow`, etc.)

## 1. Setup and Configuration

In [1]:
# Standard library imports
import os
import sys
import json
import time
import warnings
from pathlib import Path

# Data processing
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch and ML
import torch
import esm
import mlflow
import mlflow.pyfunc

# Progress tracking
from tqdm.notebook import tqdm

# Project modules
sys.path.append('..')
from src.esm2_contact.serving import ContactPredictor, create_pyfunc_model, log_model_to_mlflow
from src.esm2_contact.dataset.processing import extract_chains_from_pdb, compute_contact_map, load_amino_acid_mapping
from src.notebook_utils.esm2_embeddings import compute_esm2_embeddings_batch

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")

print("✅ All imports loaded successfully!")
print(f"   PyTorch version: {torch.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name()}")

✅ All imports loaded successfully!
   PyTorch version: 2.9.0+cu128
   CUDA available: True
   GPU: NVIDIA GeForce RTX 4080 Laptop GPU


In [2]:
# Configuration - Update these paths as needed
CONFIG = {
    # Model paths
    'model_checkpoint': '../experiments/full_dataset_training/model.pth',  # Update with your model path
    'mlflow_model_uri': 'runs:/411149998746302666/m-af9291099ab9441fbf4cb47431763fe7',  # MLflow model URI
    
    # Data paths
    'pdb_dir': '../data/train',  # Directory containing PDB files
    'output_dir': '../predictions',  # Directory to save predictions
    'test_pdb': '../data/train/4H2C.pdb',  # Example PDB file
    
    # Model configuration
    'prediction_threshold': 0.5,  # Threshold for binary contact prediction
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',  # Device for inference
    'esm_model': 'esm2_t33_650M_UR50D',  # ESM2 model variant for embeddings
    'batch_size': 1,  # Batch size for ESM2 embedding computation
    'esm_layer': 33,  # ESM2 layer to extract embeddings from
}

## 2. Model Loading

### 2.1 Load Model from Checkpoint

In [3]:
# Load model using the successful approach from results analysis
import sys
from pathlib import Path

# Resolve paths from notebook location (notebooks/ -> project root)
notebook_dir = Path().absolute()
project_root = notebook_dir.parent
LOCAL_EXPERIMENT_DIR = project_root / "experiments/full_dataset_training"

# Load experiment configuration
config = None
model_path = None
pytorch_model = None

print("🔄 Setting up model loading...")

# Load configuration from local experiment
if LOCAL_EXPERIMENT_DIR.exists():
    config_file = LOCAL_EXPERIMENT_DIR / "config.json"
    model_file = LOCAL_EXPERIMENT_DIR / "model.pth"
    
    if config_file.exists():
        import json
        with open(config_file, "r") as f:
            config = json.load(f)
        print(f"✅ Configuration loaded from: {config_file}")
        print(f"   Architecture: CNN ({config['in_channels']}→{config['base_channels']} channels)")
    
    if model_file.exists():
        model_path = str(model_file)
        print(f"✅ Model file found: {model_path}")
    else:
        print(f"❌ Model file not found: {model_file}")
else:
    print(f"❌ Experiment directory not found: {LOCAL_EXPERIMENT_DIR}")

# Load PyTorch model using the proven approach
if config and model_path:
    try:
        # Ensure the project root is in the Python path
        if str(project_root) not in sys.path:
            sys.path.insert(0, str(project_root))
        
        # Import the model class
        from src.esm2_contact.training.model import BinaryContactCNN
        
        print(f"📂 Loading model from: {model_path}")
        print(f"🔧 Model config: {config['in_channels']}→{config['base_channels']} channels")
        
        # Create model instance with the same configuration as training
        model = BinaryContactCNN(
            in_channels=config['in_channels'],
            base_channels=config['base_channels'],
            dropout_rate=config['dropout_rate']
        )
        
        # Load trained weights with proper checkpoint handling
        import torch
        checkpoint = torch.load(model_path, map_location=CONFIG['device'])
        
        if 'model_state_dict' in checkpoint:
            # Extract the actual state dict from the checkpoint
            state_dict = checkpoint['model_state_dict']
            model.load_state_dict(state_dict)
            print(f"✅ Model state dict loaded from checkpoint")
        else:
            # Fallback: try loading the checkpoint directly
            model.load_state_dict(checkpoint)
            print(f"✅ Model loaded directly from checkpoint")
        
        model.eval()  # Set to evaluation mode
        model.to(CONFIG['device'])  # Move to target device
        
        print(f"✅ Model loaded successfully!")
        print(f"   Model type: {type(model).__name__}")
        print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"   Input channels: {config['in_channels']}")
        print(f"   Base channels: {config['base_channels']}")
        print(f"   Device: {CONFIG['device']}")
        
        # Store model for later use
        pytorch_model = model
        
        # Test model with dummy data to verify functionality
        print(f"\n🧪 Testing model with dummy input...")
        dummy_input = torch.randn(1, config['in_channels'], 64, 64).to(CONFIG['device'])
        with torch.no_grad():
            dummy_output = model(dummy_input)
            print(f"   Input shape: {dummy_input.shape}")
            print(f"   Output shape: {dummy_output.shape}")
            print(f"   Output type: {type(dummy_output)}")
            print(f"   Output range: [{dummy_output.min():.4f}, {dummy_output.max():.4f}]")
        
        print(f"✅ Model inference test successful!")
        
    except ImportError as ie:
        print(f"❌ Import error: {ie}")
        print(f"   Could not import BinaryContactCNN from src.esm2_contact.training.model")
        pytorch_model = None
    except Exception as e:
        print(f"❌ Failed to load PyTorch model: {e}")
        print(f"   Error type: {type(e).__name__}")
        pytorch_model = None

# Fallback to ContactPredictor if PyTorch loading fails
if pytorch_model is None:
    print(f"\n🔄 Attempting to load using ContactPredictor...")
    model_path = Path(CONFIG['model_checkpoint'])
    
    if model_path.exists():
        try:
            from src.esm2_contact.serving import ContactPredictor
            
            predictor = ContactPredictor(
                model_path=CONFIG['model_checkpoint'],
                threshold=CONFIG['prediction_threshold'],
                device=CONFIG['device']
            )
            
            print(f"✅ ContactPredictor loaded successfully!")
            pytorch_model = predictor
            
        except Exception as e:
            print(f"❌ Failed to load ContactPredictor: {e}")
            pytorch_model = None

# Final fallback to MLflow
if pytorch_model is None and CONFIG['mlflow_model_uri']:
    print(f"\n🔄 Attempting to load MLflow model...")
    try:
        mlflow_model = mlflow.pyfunc.load_model(CONFIG['mlflow_model_uri'])
        
        # Test with dummy input
        dummy_input = np.random.randn(1, config['in_channels'] if config else 68, 64, 64).astype(np.float32)
        test_result = mlflow_model.predict(None, dummy_input)
        
        print(f"✅ MLflow model loaded successfully!")
        pytorch_model = mlflow_model
        
    except Exception as e:
        print(f"❌ Failed to load MLflow model: {e}")
        pytorch_model = None

# Final status
if pytorch_model is not None:
    print(f"\n🎉 Model loading completed successfully!")
    print(f"   Model type: {type(pytorch_model).__name__}")
    predictor = pytorch_model
else:
    print(f"\n❌ No model could be loaded. Please check model paths and dependencies.")
    print(f"   Model checkpoint: {CONFIG['model_checkpoint']}")
    print(f"   MLflow URI: {CONFIG['mlflow_model_uri']}")
    predictor = None

🔄 Setting up model loading...
✅ Configuration loaded from: /home/calmscout/Projects/PythonProjects/esm2-contact-prediction/experiments/full_dataset_training/config.json
   Architecture: CNN (68→32 channels)
✅ Model file found: /home/calmscout/Projects/PythonProjects/esm2-contact-prediction/experiments/full_dataset_training/model.pth
📂 Loading model from: /home/calmscout/Projects/PythonProjects/esm2-contact-prediction/experiments/full_dataset_training/model.pth
🔧 Model config: 68→32 channels
✅ Model state dict loaded from checkpoint
✅ Model loaded successfully!
   Model type: BinaryContactCNN
   Parameters: 380,033
   Input channels: 68
   Base channels: 32
   Device: cuda

🧪 Testing model with dummy input...
   Input shape: torch.Size([1, 68, 64, 64])
   Output shape: torch.Size([1, 64, 64])
   Output type: <class 'torch.Tensor'>
   Output range: [-9.5421, 6.7355]
✅ Model inference test successful!

🎉 Model loading completed successfully!
   Model type: BinaryContactCNN


### 2.2 Model Loading Status

In [4]:
# Model loading was completed successfully in cell 6
print(f"✅ Current model status:")
print(f"   Model type: {type(predictor).__name__}")
print(f"   Model loaded: {predictor is not None}")

if predictor is not None:
    print(f"   Device: {CONFIG['device']}")
    print(f"   Ready for inference: ✅")
else:
    print(f"   Ready for inference: ❌")

✅ Current model status:
   Model type: BinaryContactCNN
   Model loaded: True
   Device: cuda
   Ready for inference: ✅


## 3. PDB File Processing Pipeline

### 3.1 Load Amino Acid Mapping

In [5]:
# Load amino acid 3-letter to 1-letter mapping
aa_mapping_file = '../amino_acid_three_to_one.json'
aa_three_to_one = load_amino_acid_mapping(aa_mapping_file)
print(f"✅ Loaded {len(aa_three_to_one)} amino acid mappings")
print(f"   Examples: {list(aa_three_to_one.items())[:5]}")

✅ Loaded 20 amino acid mappings
   Examples: [('ALA', 'A'), ('ARG', 'R'), ('ASN', 'N'), ('ASP', 'D'), ('CYS', 'C')]


### 3.2 Process Example PDB File

In [6]:
# Find an example PDB file
def find_example_pdb():
    """Find an example PDB file for demonstration."""
    # Try to configured test PDB first
    if CONFIG['test_pdb'] and Path(CONFIG['test_pdb']).exists():
        return Path(CONFIG['test_pdb'])
    
    # Look in PDB directory
    pdb_dir = Path(CONFIG['pdb_dir'])
    if pdb_dir.exists():
        pdb_files = list(pdb_dir.glob('*.pdb'))
        if pdb_files:
            return pdb_files[0]
    
    return None

example_pdb = find_example_pdb()

if example_pdb:
    print(f"📄 Using example PDB file: {example_pdb}")
    
    # Extract chains from PDB
    print("🔄 Extracting protein chains...")
    chains_data = extract_chains_from_pdb(example_pdb, aa_three_to_one)
    
    print(f"✅ Found {len(chains_data)} chains:")
    for chain_id, data in chains_data.items():
        print(f"   Chain {chain_id}: {data['length']} residues")
        print(f"   Sequence: {data['sequence'][:50]}{'...' if len(data['sequence']) > 50 else ''}")
        
else:
    print("❌ No PDB files found. Please update CONFIG['pdb_dir'] or CONFIG['test_pdb']")
    # Exit gracefully if no data available
    raise FileNotFoundError("No PDB files available for processing")

📄 Using example PDB file: ../data/train/4H2C.pdb
🔄 Extracting protein chains...
✅ Found 1 chains:
   Chain A: 555 residues
   Sequence: GAPWWKSAVFYQVYPRSFKDTNGDGIGDFKGLTEKLDYLKGLGIDAIWIN...


### 3.3 Generate ESM2 Embeddings

In [7]:
# Load ESM2 model for embeddings
print(f"🔄 Loading ESM2 model: {CONFIG['esm_model']}")
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()
esm_model.to(CONFIG['device'])

print(f"✅ ESM2 model loaded on {CONFIG['device']}")
print(f"   Model parameters: {sum(p.numel() for p in esm_model.parameters()):,}")
print(f"   Number of layers: {esm_model.num_layers}")

🔄 Loading ESM2 model: esm2_t33_650M_UR50D
✅ ESM2 model loaded on cuda
   Model parameters: 651,043,254
   Number of layers: 33


In [8]:
# Prepare sequences for ESM2
sequences_list = []
for chain_id, data in chains_data.items():
    sequences_list.append((f"{example_pdb.stem}_{chain_id}" if example_pdb else f"synthetic_{chain_id}", 
                          data['sequence']))

print(f"📊 Processing {len(sequences_list)} sequences for ESM2 embeddings")
for name, seq in sequences_list:
    print(f"   {name}: {len(seq)} residues")

# Compute ESM2 embeddings
print("\n🧠 Computing ESM2 embeddings...")
start_time = time.time()

embeddings_dict = compute_esm2_embeddings_batch(
    model=esm_model,
    batch_converter=batch_converter,
    sequences_list=sequences_list,
    device=CONFIG['device'],
    batch_size=CONFIG['batch_size'],
    layer=CONFIG['esm_layer'],
    return_contacts=False,
    verbose=True
)

embedding_time = time.time() - start_time
print(f"✅ ESM2 embeddings computed in {embedding_time:.2f}s")

# Display embedding information
for identifier, data in embeddings_dict.items():
    embedding = data['embedding']
    print(f"   {identifier}: {embedding.shape}")

📊 Processing 1 sequences for ESM2 embeddings
   4H2C_A: 555 residues

🧠 Computing ESM2 embeddings...
🔢 Computing embeddings from layer 33
📦 Batch size: 1
🖥️  Device: cuda


Computing embeddings:   0%|          | 0/1 [00:00<?, ?it/s]

✅ Computed embeddings for 1 sequences
✅ ESM2 embeddings computed in 0.33s
   4H2C_A: (555, 1280)


### 3.4 Create Contact Prediction Features

In [9]:
# Create feature maps from embeddings for contact prediction (matching training pipeline)
print("🔄 Creating contact prediction features...")

def create_template_channels(sequence_length: int) -> np.ndarray:
    """
    Create 4 template channels with sequence-based patterns.
    
    Args:
        sequence_length: Length of the protein sequence
        
    Returns:
        numpy.ndarray: Template channels of shape (4, L, L)
    """
    L = sequence_length
    template_channels = np.zeros((4, L, L), dtype=np.float32)
    
    # Channel 0: Sequential proximity pattern
    for i in range(L):
        for j in range(L):
            if abs(i - j) <= 2:
                template_channels[0, i, j] = 0.8
    
    # Channel 1: Distance-based exponential decay pattern
    for i in range(L):
        for j in range(L):
            dist = abs(i - j)
            if dist <= 8:
                template_channels[1, i, j] = np.exp(-dist / 4.0)
    
    # Channel 2: Helical propensity pattern with periodicity
    for i in range(L):
        for j in range(L):
            if i != j:
                dist = abs(i - j)
                # Helical periodicity pattern
                if 3 <= dist <= 5:
                    template_channels[2, i, j] = 0.3
                elif dist >= 15:
                    template_channels[2, i, j] = 0.1
    
    # Channel 3: Long-range coevolution pattern
    for i in range(L):
        for j in range(L):
            if i != j:
                dist = abs(i - j)
                if dist > 12 and dist < 50:
                    template_channels[3, i, j] = 0.2 * (1 - dist / 50)
    
    # Set diagonal to 1.0 for all template channels
    for i in range(4):
        np.fill_diagonal(template_channels[i], 1.0)
    
    return template_channels

def assemble_68_channel_features(esm2_embedding: np.ndarray, template_channels: np.ndarray) -> np.ndarray:
    """
    Assemble 68-channel tensor following the training pipeline approach.
    
    Args:
        esm2_embedding: ESM2 embeddings of shape (L, 1280)
        template_channels: Template channels of shape (4, L, L)
        
    Returns:
        numpy.ndarray: 68-channel tensor of shape (68, L, L)
    """
    L = esm2_embedding.shape[0]
    channels = 68
    tensor = np.zeros((channels, L, L), dtype=np.float32)
    
    # Channels 0-3: Template channels
    tensor[0:4] = template_channels
    
    # Channels 4-67: ESM2 channels (use first 64 dimensions)
    esm2_64_channels = esm2_embedding[:, :64]  # Shape: (L, 64)
    
    for i in range(64):
        # Tile each ESM2 dimension across the sequence to create 2D map
        tensor[4 + i] = np.tile(esm2_64_channels[i:i+1, :], (L, 1))
    
    return tensor

# Process each chain
feature_maps = []
chain_info = []

for i, (identifier, data) in enumerate(embeddings_dict.items()):
    embedding = data['embedding']  # Shape: (L, 1280)
    chain_data = chains_data[list(chains_data.keys())[i]]
    
    print(f"   Processing {identifier}:")
    print(f"     ESM2 embedding shape: {embedding.shape}")
    
    # Create template channels
    template_channels = create_template_channels(embedding.shape[0])
    print(f"     Template channels shape: {template_channels.shape}")
    
    # Assemble 68-channel tensor
    feature_map = assemble_68_channel_features(embedding, template_channels)
    print(f"     Final feature map shape: {feature_map.shape}")
    
    feature_maps.append(feature_map)
    chain_info.append({
        'identifier': identifier,
        'sequence_length': embedding.shape[0],
        'original_sequence': chain_data['sequence']
    })

# Convert to numpy array for model input
features_array = np.stack(feature_maps)
print(f"\n✅ 68-channel feature maps created: {features_array.shape}")

# Resize to match model input size (64x64) if needed
target_size = 64
processed_features = []

for feature_map in feature_maps:
    channels, h, w = feature_map.shape
    
    if h == target_size and w == target_size:
        processed_features.append(feature_map)
    else:
        # Resize to target size
        if h > target_size or w > target_size:
            # Take center crop
            start_h = max(0, (h - target_size) // 2)
            start_w = max(0, (w - target_size) // 2)
            resized_map = feature_map[:, start_h:start_h+target_size, start_w:start_w+target_size]
        else:
            # Pad with zeros
            padded_map = np.zeros((channels, target_size, target_size), dtype=np.float32)
            padded_map[:, :h, :w] = feature_map
            resized_map = padded_map
        
        processed_features.append(resized_map)
        print(f"   Resized from {feature_map.shape} to {resized_map.shape}")

features_array = np.stack(processed_features)
print(f"✅ Final features ready for model input: {features_array.shape}")
print(f"   Expected by model: (batch, 68, 64, 64)")

🔄 Creating contact prediction features...
   Processing 4H2C_A:
     ESM2 embedding shape: (555, 1280)
     Template channels shape: (4, 555, 555)


ValueError: could not broadcast input array from shape (555,64) into shape (555,555)

## 4. Contact Prediction

### 4.1 Make Predictions

In [None]:
# Helper function for direct PyTorch model inference
def direct_model_inference(model, features_array, config):
    """
    Handle inference for direct PyTorch models (BinaryContactCNN).
    
    Args:
        model: PyTorch model (BinaryContactCNN)
        features_array: numpy array of shape (batch, channels, height, width)
        config: configuration dictionary
        
    Returns:
        dict: Predictions in the same format as ContactPredictor
    """
    import torch
    
    # Convert numpy array to torch tensor
    if isinstance(features_array, np.ndarray):
        features_tensor = torch.from_numpy(features_array).float().to(config['device'])
    else:
        features_tensor = features_array.to(config['device'])
    
    # Ensure correct input shape (batch, channels, height, width)
    if features_tensor.dim() == 4:
        # Expected shape: (batch_size, channels, height, width)
        pass
    elif features_tensor.dim() == 3:
        # Add batch dimension if missing
        features_tensor = features_tensor.unsqueeze(0)
    
    # Run inference
    model.eval()
    with torch.no_grad():
        raw_outputs = model(features_tensor)  # Shape: (batch, height, width)
        
        # Apply sigmoid to get probabilities
        probabilities = torch.sigmoid(raw_outputs)
        
        # Apply threshold to get binary predictions
        threshold = config['prediction_threshold']
        binary_predictions = (probabilities >= threshold).float()
        
        # Calculate confidence scores (distance from threshold)
        confidence_scores = torch.abs(probabilities - 0.5) * 2  # Maps to [0, 1]
    
    # Convert back to numpy and format results
    batch_size = features_tensor.shape[0]
    
    results = {
        'batch_size': batch_size,
        'threshold': threshold,
        'predictions': [],
        'probabilities': [],
        'confidence_scores': []
    }
    
    for i in range(batch_size):
        results['predictions'].append(binary_predictions[i].cpu().numpy())
        results['probabilities'].append(probabilities[i].cpu().numpy())
        results['confidence_scores'].append(confidence_scores[i].cpu().numpy())
    
    return results

In [None]:
if predictor is not None:
    print("🔮 Making contact predictions...")
    
    start_time = time.time()
    
    # Make predictions based on model type
    if hasattr(predictor, '_predict_batch'):
        # ContactPredictor class
        predictions = predictor._predict_batch(features_array)
    elif hasattr(predictor, 'forward'):
        # Direct PyTorch model (BinaryContactCNN)
        predictions = direct_model_inference(predictor, features_array, CONFIG)
    elif hasattr(predictor, 'predict'):
        # MLflow pyfunc model
        predictions = predictor.predict(None, features_array)
    else:
        raise ValueError(f"Unknown model type: {type(predictor)}")
    
    prediction_time = time.time() - start_time
    
    print(f"✅ Predictions completed in {prediction_time:.4f}s")
    print(f"   Batch size: {predictions['batch_size']}")
    print(f"   Prediction shapes:")
    
    for i, (chain_info_item, pred_shape) in enumerate(zip(chain_info, predictions['predictions'])):
        print(f"     {chain_info_item['identifier']}: {np.array(pred_shape).shape}")
    
else:
    print("❌ No model loaded. Please load a model first.")
    raise RuntimeError("Model loading failed - cannot proceed with predictions")

### 4.2 Process and Analyze Results

In [None]:
# Process predictions and create results summary
results_summary = []

for i, chain_info_item in enumerate(chain_info):
    identifier = chain_info_item['identifier']
    seq_len = chain_info_item['sequence_length']
    
    # Get predictions for this chain
    pred_contacts = np.array(predictions['predictions'][i])
    pred_probabilities = np.array(predictions['probabilities'][i])
    confidence_scores = np.array(predictions['confidence_scores'][i])
    
    # Crop to actual sequence length
    if seq_len < 64:
        pred_contacts = pred_contacts[:seq_len, :seq_len]
        pred_probabilities = pred_probabilities[:seq_len, :seq_len]
        confidence_scores = confidence_scores[:seq_len, :seq_len]
    else:
        pred_contacts = pred_contacts[:seq_len, :seq_len]
        pred_probabilities = pred_probabilities[:seq_len, :seq_len]
        confidence_scores = confidence_scores[:seq_len, :seq_len]
    
    # Calculate metrics
    total_contacts = np.sum(pred_contacts)
    contact_density = total_contacts / (seq_len * (seq_len - 1) / 2) if seq_len > 1 else 0
    avg_confidence = np.mean(confidence_scores)
    
    # Store results
    result = {
        'identifier': identifier,
        'sequence_length': seq_len,
        'predicted_contacts': pred_contacts,
        'probabilities': pred_probabilities,
        'confidence_scores': confidence_scores,
        'total_contacts': int(total_contacts),
        'contact_density': float(contact_density),
        'avg_confidence': float(avg_confidence),
        'threshold': predictions['threshold']
    }
    
    results_summary.append(result)
    
    print(f"\n📊 Results for {identifier}:")
    print(f"   Sequence length: {seq_len}")
    print(f"   Predicted contacts: {total_contacts:,}")
    print(f"   Contact density: {contact_density:.4f}")
    print(f"   Average confidence: {avg_confidence:.4f}")
    print(f"   Contact map shape: {pred_contacts.shape}")

## 5. Results Analysis and Visualization

### 5.1 Visualize Contact Maps

In [None]:
# Create visualization of contact predictions
n_chains = len(results_summary)
if n_chains > 0:
    fig, axes = plt.subplots(2, 2*n_chains, figsize=(6*n_chains, 12))
    
    if n_chains == 1:
        axes = axes.reshape(1, -1)  # Handle single chain case
    
    for i, result in enumerate(results_summary):
        identifier = result['identifier']
        contacts = result['predicted_contacts']
        probabilities = result['probabilities']
        confidence = result['confidence_scores']
        
        # Binary contacts
        ax1 = axes[0, 2*i] if n_chains > 1 else axes[0, i]
        im1 = ax1.imshow(contacts, cmap='Blues', interpolation='nearest')
        ax1.set_title(f'{identifier}\nBinary Contacts\n({result["total_contacts"]:,} contacts)')
        ax1.set_xlabel('Residue index')
        ax1.set_ylabel('Residue index')
        plt.colorbar(im1, ax=ax1, fraction=0.046)
        
        # Probabilities
        ax2 = axes[0, 2*i+1] if n_chains > 1 else axes[1, i]
        im2 = ax2.imshow(probabilities, cmap='viridis', interpolation='nearest')
        ax2.set_title(f'{identifier}\nContact Probabilities\n(threshold={result["threshold"]})')
        ax2.set_xlabel('Residue index')
        ax2.set_ylabel('Residue index')
        plt.colorbar(im2, ax=ax2, fraction=0.046)
        
        # Confidence scores
        ax3 = axes[1, 2*i] if n_chains > 1 else axes[2, i]
        im3 = ax3.imshow(confidence, cmap='plasma', interpolation='nearest')
        ax3.set_title(f'{identifier}\nConfidence Scores\n(avg={result["avg_confidence"]:.3f})')
        ax3.set_xlabel('Residue index')
        ax3.set_ylabel('Residue index')
        plt.colorbar(im3, ax=ax3, fraction=0.046)
        
        # Contact distribution
        ax4 = axes[1, 2*i+1] if n_chains > 1 else axes[3, i]
        seq_len = result['sequence_length']
        
        # Calculate contact distribution by sequence separation
        separations = range(1, min(seq_len, 50))
        contact_rates = []
        
        for sep in separations:
            sep_contacts = 0
            total_pairs = 0
            for i_pos in range(seq_len - sep):
                sep_contacts += contacts[i_pos, i_pos + sep]
                total_pairs += 1
            contact_rates.append(sep_contacts / total_pairs if total_pairs > 0 else 0)
        
        ax4.plot(separations, contact_rates, 'b-', linewidth=2)
        ax4.set_title(f'{identifier}\nContact Rate vs Separation')
        ax4.set_xlabel('Sequence separation')
        ax4.set_ylabel('Contact rate')
        ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Save figure
    output_path = Path(CONFIG['output_dir']) / 'contact_predictions.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"📊 Contact map visualization saved to: {output_path}")
else:
    print("❌ No results to visualize")

### 5.2 Summary Statistics

In [None]:
# Create summary statistics table
if results_summary:
    summary_df = pd.DataFrame([
        {
            'Chain': result['identifier'],
            'Length': result['sequence_length'],
            'Contacts': result['total_contacts'],
            'Contact Density': f"{result['contact_density']:.4f}",
            'Avg Confidence': f"{result['avg_confidence']:.4f}",
            'Max Probability': f"{np.max(result['probabilities']):.4f}",
            'Threshold': result['threshold']
        }
        for result in results_summary
    ])
    
    print("📊 Prediction Summary:")
    display(summary_df.style.set_caption("Contact Prediction Results"))
    
    # Save summary to file
    summary_path = Path(CONFIG['output_dir']) / 'prediction_summary.csv'
    summary_df.to_csv(summary_path, index=False)
    print(f"\n💾 Summary saved to: {summary_path}")
else:
    print("❌ No results to summarize")

## 6. Save Results and Export

### 7.1 Save Complete Results

In [None]:
# Save complete results to JSON
if results_summary:
    complete_results = {
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'config': CONFIG,
        'model_info': {
            'model_type': type(predictor).__name__ if predictor else None,
            'threshold': predictions.get('threshold', CONFIG['prediction_threshold'])
        },
        'results': []
    }
    
    # Add individual chain results
    for result in results_summary:
        complete_results['results'].append({
            'identifier': result['identifier'],
            'sequence_length': result['sequence_length'],
            'total_contacts': result['total_contacts'],
            'contact_density': result['contact_density'],
            'avg_confidence': result['avg_confidence'],
            'predicted_contacts': result['predicted_contacts'].tolist(),
            'probabilities': result['probabilities'].tolist(),
            'confidence_scores': result['confidence_scores'].tolist()
        })
    
    # Save complete results
    complete_results_path = Path(CONFIG['output_dir']) / 'complete_results.json'
    with open(complete_results_path, 'w') as f:
        json.dump(complete_results, f, indent=2)
    
    print(f"💾 Complete results saved to: {complete_results_path}")
    print(f"   File size: {complete_results_path.stat().st_size / 1024:.1f} KB")
else:
    print("❌ No results to save")

## 8. Summary and Next Steps

### 8.1 What We Accomplished

In [None]:
print("🎉 ESM2 Contact Prediction Serving - Summary")
print("=" * 50)

print(f"\n📊 Processing Results:")
print(f"   • Protein chains analyzed: {len(results_summary)}")
print(f"   • Total contacts predicted: {sum(r['total_contacts'] for r in results_summary):,}")

if results_summary:
    avg_density = np.mean([r['contact_density'] for r in results_summary])
    avg_confidence = np.mean([r['avg_confidence'] for r in results_summary])
    print(f"   • Average contact density: {avg_density:.4f}")
    print(f"   • Average confidence: {avg_confidence:.4f}")

print(f"\n🔧 Technical Infrastructure:")
print(f"   • Model loading: {'✅' if predictor else '❌'}")
print(f"   • ESM2 embeddings: {'✅' if 'embeddings_dict' in locals() else '❌'}")
print(f"   • Contact prediction: {'✅' if predictions else '❌'}")
print(f"   • Visualization: {'✅' if 'plt' in locals() else '❌'}")

print(f"\n📁 Output Files Created:")
output_path = Path(CONFIG['output_dir'])
if output_path.exists():
    created_files = list(output_path.rglob('*'))
    print(f"   • Total files created: {len(created_files)}")
    for file_path in created_files[:10]:  # Show first 10
        rel_path = file_path.relative_to(output_path)
        size_kb = file_path.stat().st_size / 1024
        print(f"     • {rel_path} ({size_kb:.1f} KB)")
    if len(created_files) > 10:
        print(f"     • ... and {len(created_files) - 10} more files")

print(f"\n✨ Successfully completed ESM2 contact prediction serving!")
print(f"   Model worked with real PDB files and generated contact predictions.")