This notebook processes AlphaFold Multimer prediction results for protein–protein interactions. 
For each predicted complex, it extracts key confidence metrics, including:

- **Per-residue pLDDT scores** (prediction confidence per residue)
- **Inter-protein PAE (Predicted Aligned Error)** between two chains
- **Summary statistics** for interaction regions (mean, median, min, max PAE)
- **Visualization of results** via pLDDT line plots and PAE heatmaps

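The notebook is designed to operate on a folder containing multiple AlphaFold prediction jobs, 
each in its own subdirectory. Every job subdirectory must contain a `.fasta` file and a nested output folder holding `features.pkl` and the `result_*.pkl` prediction files.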

Outputs include:
- `.png` images showing pLDDT and PAE results
- An Excel file summarizing interaction PAE statistics (`PAE_statistics.xlsx`) for all interactions in the input folder.

> Use the `display_avg_pae` argument of `main()` to toggle whether inter-protein PAE statistics are computed and annotated on each PAE heatmap.

This analysis is useful for evaluating and comparing predicted interaction confidence across large-scale AFM screens.

In [1]:
# Required libraries for file handling, numerical operations, and plotting
import os
import numpy as np
import pickle
import glob
import pandas as pd
from matplotlib import pyplot as plt

# === Helper Function: Reads a FASTA file into a dictionary ===
def read_fasta(fasta_file):
    """
    Parses a FASTA file and returns a dictionary of {header: sequence}.

    A header is the text following '>' on a header line (the line is stripped
    of surrounding whitespace first). All subsequent non-header lines, up to
    the next header, are concatenated to form that record's sequence. Lines
    that appear before the first header are ignored.

    Parameters:
    - fasta_file (str): Path to the FASTA file.

    Returns:
    - dict: {header: sequence} in file order. If the same header occurs more
      than once, the last record with that header wins.
    """
    # Collect sequence lines per record and join once at the end, instead of
    # growing a string with += for every line.
    chunks = {}
    current_id = None
    with open(fasta_file, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if line.startswith('>'):
                current_id = line[1:]  # Extract header without '>'
                chunks[current_id] = []
            elif current_id is not None and line:
                # `is not None` (rather than truthiness) so that a record with
                # an empty header ('>' alone) still keeps its sequence.
                chunks[current_id].append(line)
    return {header: "".join(parts) for header, parts in chunks.items()}

# === Helper Function: Calculates average inter-protein PAE ===
def calculate_interaction_pae(pae_matrix, len_A, len_B):
    """
    Returns the average inter-protein Predicted Aligned Error (PAE).

    The first `len_A` rows/columns of `pae_matrix` are treated as protein A and
    everything after them as protein B. The two off-diagonal blocks are averaged
    separately, and the mean of those two block averages is returned.

    Parameters:
    - pae_matrix: 2-D NumPy array of PAE values for the whole complex.
    - len_A: Number of residues in protein A.
    - len_B: Number of residues in protein B. Kept in the signature for callers,
             but not used by the calculation.

    Returns:
    - Mean of the two off-diagonal block averages.
    """
    # Upper-right block: rows of A against the columns after A.
    a_rows_vs_rest = pae_matrix[:len_A, len_A:]
    # Lower-left block: rows after A against the columns of A.
    rest_rows_vs_a = pae_matrix[len_A:, :len_A]

    block_means = (np.mean(a_rows_vs_rest), np.mean(rest_rows_vs_a))
    return (block_means[0] + block_means[1]) / 2

# === Helper Function: Extracts pLDDT and PAE from AlphaFold .pkl result files ===
def get_pae_plddt(map_dir):
    """
    Loads AlphaFold prediction output from result_*.pkl files and extracts
    pLDDT scores and PAE matrices per model.

    Parameters:
    - map_dir (str): Directory containing the result_*.pkl files.

    Returns:
    - dict: {model_name: {'plddt': ..., 'pae': ...}}, where model_name is the
      file name without the '.pkl' extension (e.g. 'result_model_1'). Entries
      are ordered by file name so that plots and the model cut-off are
      reproducible between runs. 'pae' is None when a result file has no
      'predicted_aligned_error' entry.

    Raises:
    - KeyError: If a result file has no 'plddt' entry.
    """
    out = {}
    # glob.escape keeps directory names containing '[', ']', '*' or '?' from
    # being interpreted as patterns; sorting removes the OS-dependent order.
    pattern = os.path.join(glob.escape(map_dir), 'result_*.pkl')
    for model_file in sorted(glob.glob(pattern)):
        # NOTE: pickle can execute arbitrary code on load; only point this at
        # result files produced by your own AlphaFold runs.
        with open(model_file, 'rb') as f:
            d = pickle.load(f)
        model_name = os.path.splitext(os.path.basename(model_file))[0]  # e.g. 'result_model_1'
        out[model_name] = {
            'plddt': d['plddt'],
            # .get() so a result without PAE is reported as None and can be
            # skipped by the caller instead of aborting the whole batch.
            'pae': d.get('predicted_aligned_error'),
        }
    return out

# === Main Plotting and Statistics Function ===
def generate_output_images(feature_dict, out_dir, name, pae_plddt_per_model, len_A, len_B, display_avg_pae):
    """
    Generates plots for pLDDT and PAE for each AlphaFold Multimer prediction,
    optionally displaying interaction-specific PAE statistics.

    Two images are written to `out_dir`:
    - "<name>_coverage_LDDT.png": one pLDDT line per model.
    - "<name>_PAE.png": a 5x5 grid of PAE heatmaps (first 25 models only).

    Inputs:
    - feature_dict: AlphaFold input features (accepted for callers; not used here)
    - out_dir: Folder to save plots
    - name: Base name for output files
    - pae_plddt_per_model: Dictionary of model_name → {'plddt': ..., 'pae': ...}
    - len_A, len_B: Lengths of protein A and B (used to locate inter-protein PAE)
    - display_avg_pae: If True, overlay average inter-protein PAE on each heatmap and calculate summary stats

    Returns:
    - Dictionary with keys "Directory Name", "Avg PAE Mean", "Avg PAE Median",
      "Avg PAE Min", "Avg PAE Max". The four statistics are None when
      display_avg_pae is False or when no model had a usable PAE matrix.
    """
    grid_side = 5
    max_panels = grid_side * grid_side

    # --- Predicted LDDT per residue ---
    plt.figure(figsize=(10, 5))
    # Only the left half of the canvas is used; the layout is kept unchanged so
    # newly generated images match previously generated ones.
    plt.subplot(1, 2, 1)
    plt.title("Predicted LDDT per position")
    for model_name, value in pae_plddt_per_model.items():
        plt.plot(value["plddt"], label=model_name)
    plt.ylim(0, 100)
    plt.ylabel("Predicted LDDT")
    plt.xlabel("Positions")
    if pae_plddt_per_model:
        # Avoid matplotlib's "no artists with labels" warning on empty input.
        plt.legend()
    plt.savefig(f"{out_dir}/{name}_coverage_LDDT.png")
    plt.close()

    # --- PAE heatmaps for up to 25 models ---
    fig, axes = plt.subplots(grid_side, grid_side, figsize=(20, 20), dpi=100)
    avg_pae_values = []
    drawn = set()  # grid slots that actually received a heatmap

    shown_models = list(pae_plddt_per_model.items())[:max_panels]
    for n, (model_name, value) in enumerate(shown_models):
        pae = value["pae"]
        if not (isinstance(pae, np.ndarray) and pae.size > 0):
            print(f"Skipping {model_name} due to invalid PAE data.")
            continue

        ax = axes[n // grid_side, n % grid_side]
        drawn.add(n)
        im = ax.imshow(pae, cmap="bwr", vmin=0, vmax=30)
        plt.colorbar(im, ax=ax)

        if display_avg_pae:
            avg_pae_interaction = calculate_interaction_pae(pae, len_A, len_B)
            avg_pae_values.append(avg_pae_interaction)
            ax.text(0.5, 0.9, f'Avg PAE: {avg_pae_interaction:.2f}',
                    horizontalalignment='center', verticalalignment='center',
                    transform=ax.transAxes, fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.6))
        ax.set_title(model_name)

    # Remove every slot without a heatmap. Driving this from `drawn` (instead of
    # the last loop index) also works when there are no models at all and
    # removes the blank frames of skipped models.
    for slot in range(max_panels):
        if slot not in drawn:
            fig.delaxes(axes[slot // grid_side, slot % grid_side])

    # Summary statistics are only defined when at least one value was collected;
    # min()/max() of an empty array would raise.
    pae_mean = pae_median = pae_min = pae_max = None
    if display_avg_pae and avg_pae_values:
        collected = np.asarray(avg_pae_values, dtype=float)
        pae_mean = float(collected.mean())
        pae_median = float(np.median(collected))
        pae_min = float(collected.min())
        pae_max = float(collected.max())
        fig.suptitle(f'PAE Statistics\nMean: {pae_mean:.2f}, '
                     f'Median: {pae_median:.2f}, '
                     f'Min: {pae_min:.2f}, '
                     f'Max: {pae_max:.2f}')
    else:
        fig.suptitle('PAE Statistics')

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    # Save and close this specific figure rather than relying on whichever
    # figure happens to be current.
    fig.savefig(f"{out_dir}/{name}_PAE.png")
    plt.close(fig)

    return {
        "Directory Name": name,
        "Avg PAE Mean": pae_mean,
        "Avg PAE Median": pae_median,
        "Avg PAE Min": pae_min,
        "Avg PAE Max": pae_max
    }

# === MAIN FUNCTION ===
def main(input_dir, display_avg_pae=True):
    """
    Processes AlphaFold Multimer prediction results in a batch of directories.

    Parameters:
    - input_dir (str): Path to the parent directory containing subdirectories for each interaction.
                       Each subdirectory must contain a .fasta file and a nested folder with AFM
                       outputs (features.pkl and result_*.pkl).
    - display_avg_pae (bool): If True, computes and annotates inter-protein PAE statistics on plots
                              and returns summary statistics in a table.

    Output:
    - Saves LDDT and PAE plots to each subdirectory
    - Writes the summary table to "<input_dir>/PAE_statistics.xlsx". The file is written on every
      call; when display_avg_pae is False the statistic columns contain None.
    - Prints the summary table to console

    Subdirectories without a .fasta file, without any sequence in it, or without a nested folder
    containing features.pkl are reported and skipped.
    """

    results = []

    # Sorted so the row order of the summary table is reproducible
    # (glob returns entries in OS-dependent order).
    for intermediate_dir in sorted(glob.glob(f'{input_dir}/*/')):
        # Locate and read the FASTA file for this interaction
        fasta_files = sorted(glob.glob(f'{intermediate_dir}/*.fasta'))
        if not fasta_files:
            print(f"Skipping {intermediate_dir}: no .fasta file found.")
            continue
        sequences = read_fasta(fasta_files[0])

        chain_lengths = [len(sequence) for sequence in sequences.values()]
        if not chain_lengths:
            print(f"Skipping {intermediate_dir}: no sequences in {fasta_files[0]}.")
            continue
        # The first record is protein A and the second is protein B; any
        # further records are ignored. With a single record there is no B.
        len_A = chain_lengths[0]
        len_B = chain_lengths[1] if len(chain_lengths) > 1 else 0

        # Identify the AlphaFold output directory: the first nested folder
        # (by name) that actually holds features.pkl.
        map_dir = next(
            (candidate for candidate in sorted(glob.glob(f'{intermediate_dir}/*/'))
             if os.path.exists(f'{candidate}/features.pkl')),
            None,
        )
        if map_dir is None:
            print(f"Skipping {intermediate_dir}: no output folder with features.pkl found.")
            continue

        # Load AFM feature dictionary and PAE/pLDDT values.
        # NOTE: pickle can execute arbitrary code on load; only use on trusted outputs.
        with open(f'{map_dir}/features.pkl', 'rb') as handle:
            feature_dict = pickle.load(handle)
        pae_plddt_per_model = get_pae_plddt(map_dir)

        name = os.path.basename(os.path.normpath(map_dir))

        # Generate output visualizations and collect statistics
        stats = generate_output_images(feature_dict, intermediate_dir, name, pae_plddt_per_model,
                                       len_A, len_B, display_avg_pae)
        if stats:
            results.append(stats)

    # Compile and export all statistics
    df = pd.DataFrame(results)
    df.to_excel(f"{input_dir}/PAE_statistics.xlsx", index=False)
    print(df.to_string(index=False))


In [None]:
# Run the batch analysis on this folder with inter-protein PAE statistics enabled.
main(r"E:\temp\240425", display_avg_pae=True)

In [2]:
# Run the batch analysis without PAE statistics; the summary columns are then None.
main(r"D:\known_interactors\100225", display_avg_pae=False)

            Directory Name Avg PAE Mean Avg PAE Median Avg PAE Min Avg PAE Max
         lmcd1_cr_pet_dux4         None           None        None        None
  lmcd1_cr_pet_dux4_notail         None           None        None        None
tes_pet_lim1-2_dux4_notail         None           None        None        None
                    P07437         None           None        None        None
                    P08670         None           None        None        None
                    P12814         None           None        None        None
                    P21333         None           None        None        None
                    P46940         None           None        None        None
                    P55072         None           None        None        None
                    P60660         None           None        None        None
                    Q00610         None           None        None        None
                    Q01082         None           No