# ESM2 Interpretability from CIF
This notebook extracts sequences and DSSP labels from CIF files, embeds with ESM2, and learns biological directions.

In [1]:
import os, torch, esm, numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
from Bio.PDB import MMCIFParser, PPBuilder, calc_dihedral, NeighborSearch
from Bio.PDB.DSSP import DSSP
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GroupKFold, cross_val_score
from sklearn.metrics import classification_report, confusion_matrix
from scipy.spatial.distance import pdist, squareform
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Directory scanned for *.cif structures.
cif_dir = './pdbs'

# Per-peptide feature accumulators; the extraction loop appends one entry per peptide.
all_X = []               # ESM2 residue embeddings
all_ss = []              # DSSP secondary-structure codes
all_rsa = []             # DSSP relative solvent accessibility
all_phi = []             # backbone phi angles
all_psi = []             # backbone psi angles
all_bfactors = []        # CA B-factors
all_contacts = []        # CA contact density
all_hydrophobicity = []  # per-residue hydropathy
all_charge = []          # per-residue charge
all_sequences = []       # peptide sequences
all_coords = []          # NOTE(review): never appended to in this notebook — candidate for removal

# Kyte–Doolittle hydropathy index, keyed by one-letter amino-acid code.
aa_hydrophobicity = dict(
    A=1.8, R=-4.5, N=-3.5, D=-3.5, C=2.5,
    Q=-3.5, E=-3.5, G=-0.4, H=-3.2, I=4.5,
    L=3.8, K=-3.9, M=1.9, F=2.8, P=-1.6,
    S=-0.8, T=-0.7, W=-0.9, Y=-1.3, V=4.2,
)

# Approximate side-chain charge; His gets a small fractional value
# (presumably partial protonation near neutral pH — confirm intent).
aa_charge = dict(
    A=0, R=1, N=0, D=-1, C=0, Q=0, E=-1,
    G=0, H=0.1, I=0, L=0, K=1, M=0, F=0,
    P=0, S=0, T=0, W=0, Y=0, V=0,
)

In [3]:
# Structure tooling: quiet mmCIF parser and the peptide builder used by the extraction loop.
parser = MMCIFParser(QUIET=True)
ppb = PPBuilder()

# ESM2 t33 650M checkpoint; eval() disables dropout so embeddings are deterministic.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()
batch_converter = alphabet.get_batch_converter()

def calculate_phi_psi(residues, missing=0.0):
    """Backbone phi/psi dihedral angles (radians) for a run of consecutive residues.

    phi(i) uses atoms C(i-1), N(i), CA(i), C(i); psi(i) uses N(i), CA(i), C(i), N(i+1).
    The first residue therefore has no phi and the last has no psi. Those angles, and any
    angle whose atoms are absent, are reported as `missing`.

    Args:
        residues: sequence of Bio.PDB residues (anything supporting `atom_name in res` and
            `res[atom_name].get_vector()`). Assumed to be consecutive along the backbone —
            the caller is responsible for not passing chain breaks.
        missing: fill value for angles that cannot be computed. Defaults to 0.0 for backward
            compatibility, even though 0.0 is itself a valid angle.

    Returns:
        (phi_angles, psi_angles): two lists of floats with one entry per residue.
    """
    def dihedral(atoms):
        # atoms: four (residue, atom_name) pairs. Check presence up front instead of
        # catching exceptions, so real errors are not silently turned into `missing`.
        if not all(name in res for res, name in atoms):
            return missing
        return calc_dihedral(*(res[name].get_vector() for res, name in atoms))

    n_residues = len(residues)
    phi_angles, psi_angles = [], []
    for i, residue in enumerate(residues):
        phi = psi = missing
        if i > 0:
            phi = dihedral([(residues[i - 1], 'C'), (residue, 'N'), (residue, 'CA'), (residue, 'C')])
        if i < n_residues - 1:
            psi = dihedral([(residue, 'N'), (residue, 'CA'), (residue, 'C'), (residues[i + 1], 'N')])
        phi_angles.append(phi)
        psi_angles.append(psi)
    return phi_angles, psi_angles

def calculate_contacts(residues, cutoff=8.0):
    """Per-residue CA contact density.

    For each residue, count the residues (itself included) whose CA lies strictly within
    `cutoff` of its own CA, then divide by the total number of residues passed in. The
    cutoff is in coordinate units (Å for PDB/mmCIF).

    Residues without a CA atom get density 0.0 and are not counted as anyone's neighbour.
    (Previously they were placed at the origin, which created spurious contacts with any
    residue that happened to lie near (0, 0, 0).)

    Args:
        residues: sequence of Bio.PDB residues (anything supporting `'CA' in res` and
            `res['CA'].get_coord()`).
        cutoff: CA–CA distance threshold.

    Returns:
        1-D float array of length len(residues); empty input gives an empty array.
    """
    n_residues = len(residues)
    density = np.zeros(n_residues)
    has_ca = np.array(['CA' in res for res in residues], dtype=bool)
    if not has_ca.any():
        return density

    coords = np.array([res['CA'].get_coord() for res in residues if 'CA' in res], dtype=float)
    distances = squareform(pdist(coords))
    # Self-distance is 0, so each residue counts itself (kept for backward compatibility).
    density[has_ca] = (distances < cutoff).sum(axis=1) / n_residues
    return density

In [None]:
# Extraction settings.
MIN_PEPTIDE_LEN = 10          # skip very short peptides
MIN_LABELED_FRACTION = 0.8    # skip peptides where too many residues lack labels
ESM_MAX_RESIDUES = 1022       # ESM2 accepts at most 1024 tokens including BOS/EOS


def _residue_labels(dssp, chain_id, residue):
    """Return (secondary structure, relative ASA, CA B-factor) for one residue, or None.

    Labels are looked up by (chain id, residue id), not by position, so they stay aligned
    with the residue regardless of chain order or gaps. Returns None when DSSP has no record
    for the residue, the residue has no CA atom, or the relative ASA is not numeric
    (DSSP may report 'NA', e.g. for non-standard residues).
    """
    key = (chain_id, residue.id)
    if key not in dssp or 'CA' not in residue:
        return None
    record = dssp[key]
    try:
        rel_asa = float(record[3])
    except (TypeError, ValueError):
        return None
    return record[2], rel_asa, residue['CA'].get_bfactor()


def _process_peptide(chain_id, pep, dssp):
    """Embed one peptide with ESM2 and append its per-residue features to the all_* lists.

    All features come from the peptide's own residues, and the same keep-mask is applied to
    every array, so row k of each stored array describes the same residue.
    Returns True if the peptide was stored, False if it was skipped.
    """
    pep_residues = list(pep)[:ESM_MAX_RESIDUES]
    seq = str(pep.get_sequence())[:ESM_MAX_RESIDUES]
    if len(seq) < MIN_PEPTIDE_LEN:
        return False

    # Structural labels first: peptides that would be skipped never pay for an ESM2 pass.
    labels = [_residue_labels(dssp, chain_id, res) for res in pep_residues]
    keep = np.array([lab is not None for lab in labels])
    if keep.sum() < MIN_LABELED_FRACTION * len(seq):
        return False
    ss, rsa, bfac = zip(*(lab for lab in labels if lab is not None))

    # ESM2 embeddings; drop BOS at position 0 and EOS after the sequence.
    _, _, toks = batch_converter([('p', seq)])
    with torch.no_grad():
        out = model(toks, repr_layers=[33])
    reps = out['representations'][33][0, 1:len(seq) + 1].numpy()

    phi_angles, psi_angles = calculate_phi_psi(pep_residues)
    contacts = calculate_contacts(pep_residues)
    hydrophob = np.array([aa_hydrophobicity.get(aa, 0) for aa in seq])
    charges = np.array([aa_charge.get(aa, 0) for aa in seq])

    all_X.append(reps[keep])
    all_ss.append(np.array(ss))
    all_rsa.append(np.array(rsa, dtype=float))
    all_phi.append(np.asarray(phi_angles)[keep])
    all_psi.append(np.asarray(psi_angles)[keep])
    all_bfactors.append(np.array(bfac, dtype=float))
    all_contacts.append(np.asarray(contacts)[keep])
    all_hydrophobicity.append(hydrophob[keep])
    all_charge.append(charges[keep])
    all_sequences.append(''.join(aa for aa, kept in zip(seq, keep) if kept))
    return True


cif_files = sorted(f for f in os.listdir(cif_dir) if f.endswith('.cif'))
print(f"Found {len(cif_files)} CIF files to process")

for cif in tqdm(cif_files, desc="Processing CIF files"):
    path = os.path.join(cif_dir, cif)
    try:
        # Only the first model is used: DSSP is computed on it, and iterating further
        # models (e.g. NMR ensembles) would duplicate chains with mismatched labels.
        first_model = parser.get_structure('x', path)[0]
        dssp = DSSP(first_model, path)
    except Exception as e:
        print(f"Error processing {cif}: {e}")
        continue

    for chain in first_model:
        for pep in ppb.build_peptides(chain):
            # A failure in one peptide should not discard the rest of the file.
            try:
                _process_peptide(chain.id, pep, dssp)
            except Exception as e:
                print(f"Error processing {cif} chain {chain.id}: {e}")

print(f"Processed {len(all_X)} protein chains")
print(f"Total residues: {sum(len(x) for x in all_X)}")

Found 200 CIF files to process


Processing CIF files:   0%|          | 0/200 [00:00<?, ?it/s]

Processing CIF files:  10%|█         | 21/200 [06:28<35:19, 11.84s/it]  

In [None]:
# Concatenate the per-peptide arrays into residue-level arrays (one row per residue).
assert all_X, "No peptides were extracted - check cif_dir and the extraction log above"
X = np.concatenate(all_X)
Y_ss = np.concatenate(all_ss)
# Cast explicitly: a non-numeric RSA (e.g. 'NA') would otherwise silently turn the array into strings.
Y_rsa = np.concatenate(all_rsa).astype(float)
Y_phi = np.concatenate(all_phi)
Y_psi = np.concatenate(all_psi)
Y_bfactors = np.concatenate(all_bfactors)
Y_contacts = np.concatenate(all_contacts)
Y_hydrophobicity = np.concatenate(all_hydrophobicity)
Y_charge = np.concatenate(all_charge)

# Every label array must line up row-for-row with the embeddings.
for name, values in [('Y_ss', Y_ss), ('Y_rsa', Y_rsa), ('Y_phi', Y_phi), ('Y_psi', Y_psi),
                     ('Y_bfactors', Y_bfactors), ('Y_contacts', Y_contacts),
                     ('Y_hydrophobicity', Y_hydrophobicity), ('Y_charge', Y_charge)]:
    assert len(values) == len(X), f"{name} has {len(values)} rows but X has {len(X)}"

print(f"Feature matrix shape: {X.shape}")
print(f"Unique secondary structures: {np.unique(Y_ss)}")
print(f"RSA range: [{Y_rsa.min():.3f}, {Y_rsa.max():.3f}]")
print(f"Contact density range: [{Y_contacts.min():.3f}, {Y_contacts.max():.3f}]")

In [None]:
# Train one linear probe per biological property.
print("Training classifiers...")

# One group id per residue (= index of the peptide it came from). Residue-level folds would
# put neighbouring residues of the same chain in both train and test and inflate CV scores.
protein_ids = np.repeat(np.arange(len(all_X)), [len(x) for x in all_X])
cv_splitter = GroupKFold(n_splits=min(5, len(all_X)))


def fit_probe(y, label):
    """Fit a class-balanced logistic-regression probe on X, print its chain-grouped CV accuracy.

    Returns (fitted classifier, mean CV accuracy).
    """
    clf = LogisticRegression(max_iter=2000, class_weight='balanced').fit(X, y)
    score = cross_val_score(clf, X, y, cv=cv_splitter, groups=protein_ids).mean()
    print(f"{label} CV accuracy: {score:.3f}")
    return clf, score


def above_chain_median(values):
    """1 where a residue's value exceeds the median of its own chain, else 0."""
    chain_median = pd.Series(values).groupby(protein_ids).transform('median').to_numpy()
    return (values > chain_median).astype(int)


# Secondary structure
clf_ss, ss_score = fit_probe(Y_ss, "Secondary structure")

# Surface accessibility (binary)
Y_rsa_bin = (Y_rsa > 0.25).astype(int)
clf_rsa, rsa_score = fit_probe(Y_rsa_bin, "Surface accessibility")

# B-factor (flexibility): raw B-factors are not comparable across structures (resolution,
# refinement), so high/low is decided relative to each chain's own median.
Y_bfactor_bin = above_chain_median(Y_bfactors)
clf_bfactor, bfactor_score = fit_probe(Y_bfactor_bin, "B-factor (flexibility)")

# Contact density: calculate_contacts divides by chain length, so a global threshold would
# mostly separate short chains from long ones; split per chain instead.
Y_contacts_bin = above_chain_median(Y_contacts)
clf_contacts, contacts_score = fit_probe(Y_contacts_bin, "Contact density")

# Hydrophobicity - hydrophobic vs hydrophilic (a pure function of residue identity)
Y_hydrophob_bin = (Y_hydrophobicity > 0).astype(int)
clf_hydrophob, hydrophob_score = fit_probe(Y_hydrophob_bin, "Hydrophobicity")

# Charge - charged vs neutral (His at 0.1 counts as neutral)
Y_charge_bin = (np.abs(Y_charge) > 0.5).astype(int)
clf_charge, charge_score = fit_probe(Y_charge_bin, "Charge")

In [None]:
# Extract interpretable directions from all classifiers
n_dims = 20  # number of top-|weight| embedding dimensions reported per property


def top_weight_dims(weights, k=n_dims):
    """Indices of the k largest-|weight| dimensions of a 1-D weight vector, strongest first."""
    return np.argsort(np.abs(weights))[::-1][:k]


def ss_class_dims(label):
    """Top dimensions for one secondary-structure class, or [] if the class never occurred."""
    if label not in ss_classes:
        return []
    # W_ss[row] is already 1-D; the previous extra `[0]` reduced it to a scalar and crashed.
    return top_weight_dims(W_ss[list(ss_classes).index(label)])


# Secondary structure directions (one coefficient row per DSSP class)
W_ss = clf_ss.coef_
ss_classes = clf_ss.classes_
helix_dims = ss_class_dims('H')
sheet_dims = ss_class_dims('E')
coil_dims = ss_class_dims('-')

# Other biological property directions (binary probes have a single coefficient row)
surface_dims = top_weight_dims(clf_rsa.coef_[0])
flexibility_dims = top_weight_dims(clf_bfactor.coef_[0])
contact_dims = top_weight_dims(clf_contacts.coef_[0])
hydrophobic_dims = top_weight_dims(clf_hydrophob.coef_[0])
charge_dims = top_weight_dims(clf_charge.coef_[0])

print('=== INTERPRETABLE DIRECTIONS ===')
print(f'Alpha-helix dimensions: {helix_dims}')
print(f'Beta-sheet dimensions: {sheet_dims}')
print(f'Coil/loop dimensions: {coil_dims}')
print(f'Surface exposure dimensions: {surface_dims}')
print(f'Flexibility dimensions: {flexibility_dims}')
print(f'Contact density dimensions: {contact_dims}')
print(f'Hydrophobicity dimensions: {hydrophobic_dims}')
print(f'Charge dimensions: {charge_dims}')

In [None]:
# Analyze dimension overlap and orthogonality
def analyze_dimension_overlap(dims_dict):
    """Pairwise Jaccard overlap between the dimension sets of each property.

    Args:
        dims_dict: mapping of property name -> iterable of embedding-dimension indices.

    Returns:
        (overlap_matrix, properties): a symmetric (P, P) float array whose [i, j] entry is
        |dims_i ∩ dims_j| / |dims_i ∪ dims_j|, and the property names in row/column order.
        A pair whose union is empty (both sets empty) gets 0.0 instead of dividing by zero.
    """
    properties = list(dims_dict.keys())
    dim_sets = [set(dims_dict[prop]) for prop in properties]
    n_props = len(properties)
    overlap_matrix = np.zeros((n_props, n_props))

    # Jaccard is symmetric, so compute the upper triangle once and mirror it.
    for i in range(n_props):
        for j in range(i, n_props):
            union = dim_sets[i] | dim_sets[j]
            if union:
                overlap = len(dim_sets[i] & dim_sets[j]) / len(union)
                overlap_matrix[i, j] = overlap_matrix[j, i] = overlap

    return overlap_matrix, properties

# Include every property that produced dimensions. Coil was previously computed but left out.
# Empty entries (a DSSP class that never occurred) have nothing to compare, so they are dropped.
dims_dict = {
    name: dims
    for name, dims in [
        ('Helix', helix_dims),
        ('Sheet', sheet_dims),
        ('Coil', coil_dims),
        ('Surface', surface_dims),
        ('Flexibility', flexibility_dims),
        ('Contacts', contact_dims),
        ('Hydrophobic', hydrophobic_dims),
        ('Charge', charge_dims),
    ]
    if len(dims) > 0
}

overlap_matrix, properties = analyze_dimension_overlap(dims_dict)

fig, ax = plt.subplots(figsize=(10, 8))
# Jaccard overlap lies in [0, 1]: use a sequential colormap with fixed limits rather than a
# diverging map centred at 0, which would render every cell in the same half of the scale.
sns.heatmap(overlap_matrix, annot=True, fmt='.2f', cmap='Blues', vmin=0, vmax=1,
            xticklabels=properties, yticklabels=properties, ax=ax)
ax.set_title('Dimension Overlap Between Biological Properties')
plt.tight_layout()
plt.show()

In [None]:
# Detailed performance analysis
def detailed_performance_analysis():
    """Plot row-normalised confusion matrices and print classification reports for each probe.

    NOTE: predictions are made on the same residues each probe was trained on, so these are
    training-set (optimistic) metrics. The cross-validated scores are the ones printed when the
    probes were fitted.
    """
    classifiers = [
        (clf_ss, Y_ss, 'Secondary Structure', ss_classes),
        (clf_rsa, Y_rsa_bin, 'Surface Accessibility', ['Buried', 'Exposed']),
        (clf_bfactor, Y_bfactor_bin, 'Flexibility (B-factor)', ['Low', 'High']),
        (clf_contacts, Y_contacts_bin, 'Contact Density', ['Low', 'High']),
        (clf_hydrophob, Y_hydrophob_bin, 'Hydrophobicity', ['Hydrophilic', 'Hydrophobic']),
        (clf_charge, Y_charge_bin, 'Charge', ['Neutral', 'Charged'])
    ]

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    reports = []
    for ax, (clf, y_true, title, labels) in zip(axes.flatten(), classifiers):
        y_pred = clf.predict(X)  # predict once; reused for the plot, accuracy and report
        # Pin the row/column order to clf.classes_ so it matches the tick labels.
        cm = confusion_matrix(y_true, y_pred, labels=clf.classes_)

        # Row-normalise; a class with no true samples gets a zero row instead of NaN.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                            where=row_totals > 0)

        sns.heatmap(cm_norm, annot=True, fmt='.3f', cmap='Blues',
                    xticklabels=labels, yticklabels=labels, ax=ax)
        ax.set_title(f'{title}\nTrain accuracy: {np.mean(y_pred == y_true):.3f}')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        reports.append((title, classification_report(y_true, y_pred, zero_division=0)))

    plt.tight_layout()
    plt.show()

    # Print detailed classification reports
    for title, report in reports:
        print(f"\n=== {title} ===")
        print(report)

detailed_performance_analysis()

In [None]:
# Feature importance and biological interpretation
def analyze_feature_importance():
    """Bar-plot |weight| per ESM2 dimension for each probe, plus raw embedding variance.

    Panels: helix and sheet rows of the secondary-structure probe, each binary probe's
    coefficients, and finally the per-dimension variance of X. In every panel the 20
    largest-magnitude dimensions are over-drawn in translucent black.
    """
    n_features = X.shape[1]

    def ss_row(label):
        # A zero row keeps the panel (all-zero bars) when the DSSP class never appeared.
        if label in ss_classes:
            return W_ss[list(ss_classes).index(label)]
        return np.zeros(n_features)

    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    panels = (
        (ss_row('H'), 'Alpha Helix', 'red'),
        (ss_row('E'), 'Beta Sheet', 'blue'),
        (clf_rsa.coef_[0], 'Surface Exposure', 'green'),
        (clf_bfactor.coef_[0], 'Flexibility', 'orange'),
        (clf_contacts.coef_[0], 'Contact Density', 'purple'),
        (clf_hydrophob.coef_[0], 'Hydrophobicity', 'brown'),
        (clf_charge.coef_[0], 'Charge', 'pink'),
        (np.var(X, axis=0), 'ESM2 Variance', 'gray'),
    )

    for ax, (weights, title, color) in zip(axes.flatten(), panels):
        magnitude = np.abs(weights)
        ax.bar(range(len(magnitude)), magnitude, color=color, alpha=0.7)
        ax.set_title(f'{title}\nMax: {np.max(magnitude):.3f}')
        ax.set_xlabel('ESM2 Dimension')
        # NOTE: for the variance panel the y-axis is variance, not a weight.
        ax.set_ylabel('Absolute Weight')

        for dim in np.argsort(magnitude)[-20:]:
            ax.bar(dim, magnitude[dim], color='black', alpha=0.3)

    plt.tight_layout()
    plt.show()

analyze_feature_importance()

In [None]:
# Ramachandran plot analysis using learned directions
def ramachandran_analysis():
    """Ramachandran plots coloured by DSSP class and by ESM2 helix/sheet probe scores.

    Scores project each residue embedding onto the helix/sheet rows of the secondary-structure
    probe. The intercept is omitted, which only shifts the colour scale.
    Residues whose phi and psi are both exactly 0.0 are dropped. calculate_phi_psi uses 0.0 as
    its fill value for angles it could not compute, and those placeholders would otherwise
    pile up as an artificial cluster at the origin.
    """
    measured = ~((Y_phi == 0.0) & (Y_psi == 0.0))

    # Convert angles to degrees
    phi_deg = np.degrees(Y_phi[measured])
    psi_deg = np.degrees(Y_psi[measured])
    ss_labels = Y_ss[measured]

    def class_scores(label):
        # Project onto one class direction; zeros if the class never occurred.
        if label in ss_classes:
            return (X @ W_ss[list(ss_classes).index(label)])[measured]
        return np.zeros(int(measured.sum()))

    def format_axis(ax, title):
        ax.set_xlim(-180, 180)
        ax.set_ylim(-180, 180)
        ax.set_xlabel('Phi (degrees)')
        ax.set_ylabel('Psi (degrees)')
        ax.set_title(title)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Traditional Ramachandran plot colored by secondary structure
    for ss_type, color, label in [('H', 'red', 'Helix'), ('E', 'blue', 'Sheet'), ('-', 'gray', 'Coil')]:
        mask = ss_labels == ss_type
        if np.any(mask):
            axes[0].scatter(phi_deg[mask], psi_deg[mask], c=color, alpha=0.6, s=1, label=label)
    format_axis(axes[0], 'Traditional Ramachandran Plot')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # ESM2 helix / sheet score projections
    for ax, scores, cmap, title in [
        (axes[1], class_scores('H'), 'Reds', 'ESM2 Helix Score Projection'),
        (axes[2], class_scores('E'), 'Blues', 'ESM2 Sheet Score Projection'),
    ]:
        points = ax.scatter(phi_deg, psi_deg, c=scores, cmap=cmap, s=1, alpha=0.7)
        format_axis(ax, title)
        plt.colorbar(points, ax=ax)

    plt.tight_layout()
    plt.show()

ramachandran_analysis()

In [None]:
# Correlation analysis between biological properties
def correlation_analysis():
    """Heatmap and printout of pairwise Pearson correlations between residue-level properties.

    Phi and psi enter in degrees as plain linear variables, so their correlations ignore
    angle wrap-around (e.g. -179° and 179° are treated as far apart); interpret them with care.
    Pairs with |r| > 0.3 are printed.
    """
    named_columns = [
        ('RSA', Y_rsa),
        ('B-factor', Y_bfactors),
        ('Contacts', Y_contacts),
        ('Hydrophobicity', Y_hydrophobicity),
        ('Charge', Y_charge),
        ('Phi', np.degrees(Y_phi)),
        ('Psi', np.degrees(Y_psi)),
    ]
    property_names = [name for name, _ in named_columns]
    properties_matrix = np.column_stack([values for _, values in named_columns])

    # np.corrcoef treats rows as variables, hence the transpose.
    corr_matrix = np.corrcoef(properties_matrix.T)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(corr_matrix, annot=True, cmap='RdBu_r', center=0,
                xticklabels=property_names, yticklabels=property_names, ax=ax)
    ax.set_title('Correlation Between Biological Properties')
    plt.tight_layout()
    plt.show()

    print("=== NOTABLE CORRELATIONS ===")
    upper_rows, upper_cols = np.triu_indices(len(property_names), k=1)
    for i, j in zip(upper_rows, upper_cols):
        r = corr_matrix[i, j]
        if abs(r) > 0.3:
            print(f"{property_names[i]} vs {property_names[j]}: {r:.3f}")

correlation_analysis()

## Summary of Enhanced Interpretability Pipeline

This enhanced pipeline now captures and analyzes multiple biological features:

### **Structural Features**
- **Secondary Structure**: Alpha-helix, beta-sheet, coil regions
- **Backbone Geometry**: Phi/psi dihedral angles with Ramachandran analysis
- **B-factors**: Cα temperature factors, used as a proxy for flexibility (they also reflect static disorder and refinement quality)
- **Contact Density**: Local packing environment

### **Physicochemical Properties** 
- **Surface Accessibility**: Solvent exposure patterns
- **Hydrophobicity**: Hydrophobic vs hydrophilic regions
- **Electrostatic Charge**: Charged vs neutral residues

### **Key Improvements**
1. **Multi-target Learning**: Trained separate classifiers for each biological property
2. **Cross-validation**: 5-fold cross-validated accuracy estimates for each probe
3. **Dimension Analysis**: Identified which ESM2 dimensions encode each property
4. **Overlap Analysis**: Quantified sharing of dimensions between properties
5. **Correlation Studies**: Explored relationships between biological features
6. **Ramachandran Projections**: Connected sequence embeddings to 3D geometry

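### **Biological Insights**
*Hypotheses to be checked against the cross-validated scores and plots above:*
- ESM2 embeddings contain linearly decodable directions for multiple structural properties
- Different biological features use both shared and distinct embedding dimensions
- The embeddings reflect relationships between sequence, structure, and dynamics
- Caveat: the hydrophobicity and charge labels are fixed functions of residue identity, so high probe accuracy for them mainly shows that ESM2 encodes which amino acid is present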