# Proviz 5.0

This Jupyter notebook contains all of the code used to process evolutionary coupling scores generated by EVcouplings (https://github.com/debbiemarkslab/EVcouplings), to calculate evolutionary frustration, and to compare evolutionary frustration with structure-based frustration derived from AlphaFold-predicted structures and/or experimental structures from the RCSB PDB.

Given a directory containing a FASTA file labeled with the PDB ID of a monomeric RCSB PDB entry, download the entry's coordinates from RCSB and filter them down to a single monomer chain.

In [None]:
import os
import requests
import pymolPy3

# Initialize pymolPy3: one module-level PyMOL session that extract_monomer()
# below sends its commands to (it clears the session with "delete all").
pm = pymolPy3.pymolPy3(0)  # Launch PyMOL without GUI

def fetch_pdb_id_from_fasta(fasta_path):
    """Extract the PDB ID from the first header line of a FASTA file.

    The ID is the first ``|``-separated field of the first ``>`` header, with
    the ``>`` removed and everything from the first ``_`` onwards dropped,
    e.g. ``>4AKE_1|Chains A, B|...`` -> ``"4AKE"``.

    Parameters:
    - fasta_path: str, path to the FASTA file.

    Returns:
    - str, the PDB ID; or None if the file is missing or unreadable, has no
      header line, or the header's first field is empty.
    """
    if not os.path.exists(fasta_path):
        print(f"FASTA file not found: {fasta_path}")
        return None

    try:
        with open(fasta_path, 'r') as fasta_file:
            for line in fasta_file:
                if line.startswith('>'):
                    # str.split always yields at least one field, so no length check is needed.
                    first_field = line.strip()[1:].split('|', 1)[0]
                    pdb_id = first_field.split('_', 1)[0]
                    return pdb_id or None
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading FASTA file: {e}")
        return None

    # File contained no header line.
    return None

def download_pdb_file(pdb_id, assembly_directory, timeout=60):
    """Download the PDB-format coordinate file for ``pdb_id`` from RCSB.

    NOTE: ``https://files.rcsb.org/download/{pdb_id}.pdb`` serves the deposited
    entry (asymmetric unit). RCSB's biological-assembly files use a
    ``.pdb1``/``.pdb2``... suffix instead. The output file keeps its historical
    ``*_biological_assembly.pdb`` name so that downstream paths stay unchanged.

    Parameters:
    - pdb_id: str, PDB code.
    - assembly_directory: str, directory to write into (created if needed).
    - timeout: seconds to wait for the server before giving up. Without it, a
      stalled connection would block the whole batch indefinitely.

    Returns:
    - str path of the written file, or None on any HTTP or network failure.
    """
    print(f"Downloading PDB file for PDB ID: {pdb_id}")
    os.makedirs(assembly_directory, exist_ok=True)

    biological_assembly_url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    biological_assembly_path = os.path.join(assembly_directory, f"{pdb_id}_biological_assembly.pdb")
    try:
        response = requests.get(biological_assembly_url, timeout=timeout)
    except requests.RequestException as e:
        # Keep the None-on-failure contract so the batch loop moves on to the next protein.
        print(f"Failed to download biological assembly PDB file for {pdb_id}: {e}")
        return None

    if response.status_code == 200:
        with open(biological_assembly_path, 'wb') as file:
            file.write(response.content)
        print(f"Biological assembly PDB downloaded to {biological_assembly_path}")
        return biological_assembly_path

    print(f"Failed to download biological assembly PDB file for {pdb_id}. HTTP Status code: {response.status_code}")
    return None

def _first_compnd_chain(pdb_path):
    """Return the first chain ID on the first ``COMPND ... CHAIN:`` line, or None.

    Only the header is scanned; reading stops at the first ATOM record.
    ``COMPND   3 CHAIN: A, B;`` yields ``"A"`` (not the whole ``"A, B"`` list).
    """
    with open(pdb_path, 'r') as f:
        for line in f:
            if line.startswith("ATOM"):
                break
            if line.startswith("COMPND") and "CHAIN:" in line:
                chain_list = line.split("CHAIN:", 1)[1].strip().rstrip(';')
                first_chain = chain_list.split(',')[0].strip()
                return first_chain or None
    return None


def _protein_chain_ids(pdb_path):
    """Return the set of non-blank chain IDs (PDB column 22) found on ATOM records."""
    chain_ids = set()
    with open(pdb_path, 'r') as f:
        for line in f:
            # HETATM-only chains (ligands, waters) are excluded. Protein atoms are ATOM
            # records, and PyMOL removes non-protein atoms before the chain is saved.
            if line.startswith("ATOM  ") and len(line) > 21 and line[21].strip():
                chain_ids.add(line[21])
    return chain_ids


def extract_monomer(biological_assembly_path, monomer_output_path):
    """Extract one protein chain from a downloaded PDB file and save it with PyMOL.

    Chain choice:
    1. the first chain listed in the COMPND ``CHAIN:`` record, if it has ATOM records;
    2. otherwise the alphabetically first chain that has ATOM records.

    Uses the module-level ``pm`` PyMOL session. The session is always cleared with
    ``delete all`` afterwards, so a failed or skipped structure cannot leak atoms
    into the next structure's ``chain X`` selection. Errors are printed, not raised.

    Parameters:
    - biological_assembly_path: str, input PDB file.
    - monomer_output_path: str, where the selected chain is saved.
    """
    try:
        print(f"Extracting monomer from biological assembly: {biological_assembly_path}")

        compnd_chain = _first_compnd_chain(biological_assembly_path)
        if compnd_chain:
            print(f"Found chain {compnd_chain} in COMPND record")

        chain_identifiers = _protein_chain_ids(biological_assembly_path)
        print(f"Found chains in structure: {sorted(chain_identifiers)}")

        if not chain_identifiers:
            print("No protein chains found in the structure")
            return

        if compnd_chain in chain_identifiers:
            selected_chain = compnd_chain
        else:
            selected_chain = sorted(chain_identifiers)[0]
        print(f"Selected chain {selected_chain} for monomer extraction")

        # Load into PyMOL, strip everything that is not protein, and save the chosen chain.
        pm(f"load {biological_assembly_path}")
        pm("remove not polymer.protein")
        pm("remove resn HOH")
        pm(f"select monomer, chain {selected_chain}")
        pm(f"save {monomer_output_path}, monomer")
        print(f"Monomer extracted and saved to {monomer_output_path}")

    except Exception as e:
        print(f"Error extracting monomer for {biological_assembly_path}: {e}")
    finally:
        pm("delete all")

# Main execution
if __name__ == "__main__":
    # Directory whose subdirectories each hold one protein's FASTA file
    # (<name>/<name>.fasta). All subdirectories are processed in one run.
    base_directory = ""

    for entry_name in os.listdir(base_directory):
        entry_path = os.path.join(base_directory, entry_name)
        if not os.path.isdir(entry_path):
            continue  # stray files next to the protein folders are ignored

        exp_dir = os.path.join(entry_path, "experimental_data")
        pdb_id = fetch_pdb_id_from_fasta(os.path.join(entry_path, f"{entry_name}.fasta"))
        if not pdb_id:
            continue

        assembly_path = download_pdb_file(pdb_id, exp_dir)
        if assembly_path:
            extract_monomer(assembly_path, os.path.join(exp_dir, "monomer.pdb"))

Calculate the average B-factor per residue for the filtered monomer.

In [None]:
import os
import Bio.PDB
import numpy as np

def calculate_average_b_factors(pdb_path, output_txt_path):
    """
    Calculate the average B-factor of each residue in a PDB structure and write
    them to a tab-separated file with one-letter residue names, numbered from 1.

    Only the first model is used. In multi-model files (e.g. NMR ensembles),
    iterating over every model would list each residue once per model and
    corrupt the 1-based residue numbering.

    Parameters:
    - pdb_path: str, path to the input PDB file.
    - output_txt_path: str, path to the output text file. Its columns are
      Residue, ResidueAA and Average_B_Factor (3 decimals).

    Errors are printed rather than raised.
    """
    # Mapping three-letter residue names to one-letter codes
    one_letter_code = {
        "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
        "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
        "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
        "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
        # Handle uncommon residues with a placeholder
        "UNK": "X"
    }

    parser = Bio.PDB.PDBParser(QUIET=True)
    try:
        structure = parser.get_structure("protein", pdb_path)
        first_model = next(iter(structure), None)

        rows = []  # (one-letter code, mean B-factor) in file order
        if first_model is not None:
            for chain in first_model:
                for residue in chain:
                    mean_b = np.mean([atom.get_bfactor() for atom in residue])
                    # Default to 'X' for residues not in the mapping
                    rows.append((one_letter_code.get(residue.get_resname(), "X"), mean_b))

        # Residues are re-indexed from 1 regardless of their numbering in the PDB file.
        with open(output_txt_path, 'w') as file:
            file.write("Residue\tResidueAA\tAverage_B_Factor\n")
            for idx, (aa, b_factor) in enumerate(rows, start=1):
                file.write(f"{idx}\t{aa}\t{b_factor:.3f}\n")

        print(f"Average B-factors with residue names saved to {output_txt_path}")
    except Exception as e:
        print(f"Error processing PDB file {pdb_path}: {e}")

if __name__ == "__main__":
    # Directory whose subdirectories each correspond to one protein; every
    # subdirectory is processed in a single pass.
    base_directory = ""

    for name in os.listdir(base_directory):
        protein_root = os.path.join(base_directory, name)
        if not os.path.isdir(protein_root):
            continue

        data_dir = os.path.join(protein_root, "experimental_data")
        monomer_pdb = os.path.join(data_dir, "monomer.pdb")

        # A monomer is only present when the extraction cell succeeded for this protein.
        if not os.path.exists(monomer_pdb):
            print(f"Monomer PDB file not found: {monomer_pdb}. Skipping {name}.")
            continue

        calculate_average_b_factors(monomer_pdb, os.path.join(data_dir, "average_b_factors.txt"))

Pull the AlphaFold structure for the UniProt ID corresponding to the PDB ID in the input FASTA file, and trim it to the region covered by the experimental structure using a pairwise sequence alignment.

In [None]:
import os
import re
import requests
import pandas as pd
from Bio import SeqIO, pairwise2
from Bio.PDB import *

def get_pdb_directories(root_path):
    """Collect the subdirectories of ``root_path`` that hold exactly one ``.fasta`` entry.

    Subdirectories with several FASTA files are reported and skipped; those
    with none are skipped silently. Returns the matching paths sorted, or an
    empty list (after printing an error) if scanning fails.
    """
    selected = []
    try:
        for entry in os.scandir(root_path):
            if not entry.is_dir():
                continue
            n_fasta = sum(1 for item in os.scandir(entry.path) if item.name.endswith('.fasta'))
            if n_fasta > 1:
                print(f"[WARNING] Directory {entry.path} contains multiple FASTA files - skipping")
            elif n_fasta == 1:
                selected.append(entry.path)
    except Exception as e:
        print(f"[ERROR] Failed to scan directory {root_path}: {e}")
        return []

    selected.sort()
    return selected

def get_sequence_from_pdb(pdb_path):
    """Return the one-letter sequence of every peptide PPBuilder finds in ``pdb_path``.

    Peptide fragments are concatenated in the order PPBuilder yields them, with
    no separator at chain breaks. Returns None (after printing an error) if
    parsing or peptide building fails.
    """
    parser = PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('structure', pdb_path)
        peptides = PPBuilder().build_peptides(structure)
        return "".join(str(peptide.get_sequence()) for peptide in peptides)
    except Exception as e:
        print(f"[ERROR] Failed to extract sequence from {pdb_path}: {e}")
        return None

def find_matching_residues(exp_seq, af_seq):
    """Map identical residues between two sequences using a global alignment.

    Uses ``pairwise2.align.globalxx`` (match = 1, no mismatch or gap penalties)
    and walks the first alignment returned.

    Parameters:
    - exp_seq: str, experimental-structure sequence.
    - af_seq: str, AlphaFold-model sequence.

    Returns:
    - list of ``(exp_pos, af_pos)`` 1-based index pairs for alignment columns
      where both residues are identical (gap columns excluded). Empty if either
      sequence is empty or no alignment is produced.
    """
    if not exp_seq or not af_seq:
        return []

    # Without gap penalties there are usually many co-optimal alignments.
    # Only the first is used, so don't ask pairwise2 to build the rest.
    alignments = pairwise2.align.globalxx(exp_seq, af_seq, one_alignment_only=True)
    if not alignments:
        return []
    exp_aligned, af_aligned = alignments[0][0], alignments[0][1]

    matching_positions = []
    exp_pos = af_pos = 0  # residues consumed so far in each ungapped sequence
    for exp_char, af_char in zip(exp_aligned, af_aligned):
        if exp_char != '-':
            exp_pos += 1
        if af_char != '-':
            af_pos += 1
        if exp_char == af_char and exp_char != '-':
            matching_positions.append((exp_pos, af_pos))

    return matching_positions

def parse_pdb_and_chain(fasta_file_path):
    """
    Parse PDB code & chain from the first FASTA header.
    Expected formats:
    >4AKE_1|Chains A, B|ADENYLATE KINASE|Escherichia coli (562)
    >pdb|4AKE|A Chain A, molecule 1|ADENYLATE KINASE|Escherichia coli
    >1234_1|Chain A|Description

    The PDB code is read from the first ``|`` field, or from the second field
    when the first is the literal ``pdb`` prefix. The chain is the first
    single-letter ID following ``Chain``/``Chains`` in any field.

    Returns:
    - (pdb_code, chain): pdb_code uppercased; chain uppercased, or None if no
      chain is named.

    Raises:
    - ValueError if the first header has no recognizable PDB code, or the file
      has no header. File and parsing errors are printed before being re-raised.
    """
    try:
        with open(fasta_file_path, 'r') as f:
            for line in f:
                if not line.startswith(">"):
                    continue

                header = line.strip()[1:]
                print(f"[DEBUG] Parsing FASTA header: {header}")

                parts = [p.strip() for p in header.split("|")]

                # '>pdb|4AKE|...' headers carry the code in the second field, not the first.
                if parts[0].lower() == "pdb" and len(parts) > 1:
                    code_field = parts[1]
                else:
                    code_field = parts[0]

                # Try to find PDB code
                pdb_code = None
                if code_field:
                    pdb_match = re.search(r'([0-9][A-Za-z0-9]{3})(?:_|$)', code_field)
                    if pdb_match:
                        pdb_code = pdb_match.group(1).upper()

                # Try to find chain
                chain = None
                for part in parts:
                    chain_match = re.search(r'Chain[s]?\s+([A-Za-z])(?:\s|$|,)', part)
                    if chain_match:
                        chain = chain_match.group(1).upper()
                        break

                if pdb_code:
                    print(f"[INFO] Found PDB: {pdb_code}, Chain: {chain}")
                    return (pdb_code, chain)
                raise ValueError("Could not find valid PDB code in FASTA header")

    except Exception as e:
        print(f"[ERROR] Failed to parse FASTA file {fasta_file_path}: {e}")
        raise

    raise ValueError("No FASTA header found in file")

def get_uniprot_id_for_pdb_chain(pdb_code, chain, timeout=60):
    """Query the PDBe SIFTS API for the UniProt ID mapped to a PDB chain.

    Parameters:
    - pdb_code: str, PDB code (case-insensitive).
    - chain: str or None. When None, the first UniProt entry that has any
      mapping is returned.
    - timeout: seconds to wait for the API. Without a timeout, a stalled
      connection would block the batch indefinitely.

    Returns:
    - str, the UniProt accession (a key of the SIFTS "UniProt" section).

    Raises:
    - requests exceptions on HTTP or network failure; ValueError when no
      mapping is found. Both are printed before being re-raised.
    """
    pdb_key = pdb_code.lower()
    url = f"https://www.ebi.ac.uk/pdbe/api/mappings/uniprot/{pdb_key}"
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()

        data = resp.json()
        if pdb_key not in data:
            raise ValueError(f"No data returned from SIFTS for PDB {pdb_code}")

        sifts_info = data[pdb_key]
        if "UniProt" not in sifts_info or not sifts_info["UniProt"]:
            raise ValueError(f"No UniProt mappings found for {pdb_code}")

        for uniprot_id, uniprot_dict in sifts_info["UniProt"].items():
            for mapping in uniprot_dict.get("mappings", []):
                if chain is None or mapping.get("chain_id", "") == chain:
                    return uniprot_id

        raise ValueError(f"Could not find UniProt ID for {pdb_code} chain {chain}")
    except Exception as e:
        print(f"[ERROR] SIFTS API request failed: {e}")
        raise

def download_alphafold_structure(uniprot_id, output_pdb_path, version="v4", timeout=60):
    """Download the AlphaFold DB model (fragment F1) for ``uniprot_id`` in PDB format.

    Parameters:
    - uniprot_id: str, UniProt accession.
    - output_pdb_path: str, file to write.
    - version: AlphaFold DB model-file version suffix (``model_<version>``).
      It is a parameter because AlphaFold DB publishes new model versions over
      time; the default keeps the previously hard-coded ``v4``.
    - timeout: seconds to wait for the server before giving up.

    Returns:
    - True on success.

    Raises:
    - requests exceptions, or ValueError if the response contains no ATOM
      record. Both are printed before being re-raised.
    """
    url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_{version}.pdb"
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()

        if "ATOM" not in response.text:
            raise ValueError("Downloaded file does not appear to be a valid PDB")

        with open(output_pdb_path, 'w') as f:
            f.write(response.text)

        return True
    except Exception as e:
        print(f"[ERROR] Failed to download AlphaFold structure: {e}")
        raise

def _format_resi_ranges(positions):
    """Collapse residue numbers into a compact PyMOL ``resi`` expression, e.g. [1, 2, 3, 7] -> '1-3+7'."""
    ranges = []
    start = prev = None
    for pos in sorted(set(positions)):
        if start is None:
            start = prev = pos
        elif pos == prev + 1:
            prev = pos
        else:
            ranges.append(str(start) if start == prev else f"{start}-{prev}")
            start = prev = pos
    if start is not None:
        ranges.append(str(start) if start == prev else f"{start}-{prev}")
    return "+".join(ranges)


def trim_alphafold_structure(exp_pdb_path, af_pdb_path, output_path, pymol=None):
    """Trim an AlphaFold model to the residues that align identically to the experimental sequence.

    Parameters:
    - exp_pdb_path: str, experimental structure (e.g. the extracted monomer).
    - af_pdb_path: str, full AlphaFold model.
    - output_path: str, where the trimmed model is saved.
    - pymol: callable PyMOL command sender (e.g. a pymolPy3 instance) or None.

    Returns:
    - True if the trimmed model was saved. False if a sequence could not be
      read, nothing matched, PyMOL is unavailable, or PyMOL raised an error.

    NOTE(review): alignment positions (1-based indices into the PPBuilder
    sequence) are used directly as PyMOL ``resi`` numbers. This assumes the
    AlphaFold model is numbered contiguously from 1 — confirm for non-AFDB
    inputs.
    """
    exp_seq = get_sequence_from_pdb(exp_pdb_path)
    af_seq = get_sequence_from_pdb(af_pdb_path)

    if not exp_seq or not af_seq:
        return False

    matching_positions = find_matching_residues(exp_seq, af_seq)
    if not matching_positions:
        print("[WARNING] No matching residues found between sequences")
        return False

    if not pymol:
        print("[WARNING] PyMOL not available - skipping structure trimming")
        return False

    # Ranges ("1-120+125-300") keep the selection command short even for large proteins.
    resi_expr = _format_resi_ranges(af_pos for _, af_pos in matching_positions)
    try:
        pymol(f"load {af_pdb_path}, af_struct")
        pymol(f"select matching_residues, af_struct and resi {resi_expr}")
        pymol(f"save {output_path}, matching_residues")
        return True
    except Exception as e:
        print(f"[ERROR] PyMOL trimming failed: {e}")
        return False
    finally:
        # Always clear the session so a failed run cannot leak objects into the next protein.
        try:
            pymol("delete all")
        except Exception as cleanup_error:
            print(f"[WARNING] PyMOL cleanup failed: {cleanup_error}")

def _process_pdb_dir(pdb_dir, pm):
    """Fetch, and when possible trim, the AlphaFold model for one protein directory."""
    fasta_files = [f for f in os.scandir(pdb_dir) if f.name.endswith('.fasta')]
    if not fasta_files:  # Should not happen: get_pdb_directories already filtered these
        return

    fasta_path = fasta_files[0].path
    print(f"\n[DEBUG] Processing FASTA file: {fasta_path}")

    pdb_code, chain = parse_pdb_and_chain(fasta_path)
    print(f"[INFO] Working with PDB code: {pdb_code}, chain: {chain}")

    # Set up output directories
    exp_data_dir = os.path.join(pdb_dir, "experimental_data")
    alphafold_dir = os.path.join(pdb_dir, "alphafold_structure")
    frustratometer_dir = os.path.join(pdb_dir, "frustratometer_af")
    os.makedirs(alphafold_dir, exist_ok=True)
    os.makedirs(frustratometer_dir, exist_ok=True)

    # Check for experimental structure
    exp_pdb_path = os.path.join(exp_data_dir, "monomer.pdb")
    has_exp_structure = os.path.exists(exp_pdb_path)

    # Get UniProt ID and download AlphaFold structure
    uniprot_id = get_uniprot_id_for_pdb_chain(pdb_code, chain)
    full_af_path = os.path.join(alphafold_dir, "AF_structure_full.pdb")
    if not download_alphafold_structure(uniprot_id, full_af_path):
        return

    # Prefer the trimmed model when an experimental monomer and PyMOL are available.
    analysis_pdb = full_af_path
    if has_exp_structure and pm:
        trimmed_af_path = os.path.join(alphafold_dir, "AF_structure_trimmed.pdb")
        if trim_alphafold_structure(exp_pdb_path, full_af_path, trimmed_af_path, pm):
            analysis_pdb = trimmed_af_path

    print(f"[INFO] Structure selected for frustratometer analysis: {analysis_pdb}")
    print(f"[INFO] Successfully processed {pdb_dir}")


def main():
    """Interactively fetch AlphaFold models for every protein directory under a root path.

    Prompts for a root directory. For each subdirectory holding exactly one
    FASTA file, it:
    - parses the PDB code and chain from the header;
    - looks up the UniProt ID via SIFTS;
    - downloads ``alphafold_structure/AF_structure_full.pdb``;
    - when PyMOL and ``experimental_data/monomer.pdb`` exist, writes
      ``alphafold_structure/AF_structure_trimmed.pdb``.

    Failures in one directory are printed and the next directory is processed.
    The PyMOL process, if started, is shut down on every exit path.
    """
    # Try to import PyMOL
    try:
        import pymolPy3
        pm = pymolPy3.pymolPy3(0)
        print("[INFO] PyMOL available - will perform structure trimming")
    except ImportError:
        pm = None
        print("[WARNING] PyMOL not available - structure trimming will be skipped")

    try:
        root_directory = input("Enter the root directory path: ")
        if not os.path.isdir(root_directory):
            print(f"[ERROR] Invalid directory path: {root_directory}")
            return

        pdb_dirs = get_pdb_directories(root_directory)
        if not pdb_dirs:
            print("[ERROR] No valid directories with FASTA files found")
            return

        for pdb_dir in pdb_dirs:
            print(f"\n[INFO] Processing: {pdb_dir}")
            try:
                _process_pdb_dir(pdb_dir, pm)
            except Exception as e:
                print(f"[ERROR] Failed to process {pdb_dir}: {e}")
    finally:
        # The early returns above previously skipped this and left PyMOL running.
        if pm:
            pm("quit")

# Run the interactive AlphaFold download/trimming workflow when this cell or
# script is executed directly.
if __name__ == "__main__":
    main()

Calculate the evolutionary-coupling-based energy difference between the native sequence and each single-amino-acid substitution mutant (a parallelized version exists for larger proteins).

In [None]:
import numpy as np
import pandas as pd
from Bio import SeqIO
import os

# Load the protein sequence from a fasta file
def load_fasta(filename):
    """Return the sequence of the first record in a FASTA file as a plain string.

    Returns None if the file contains no FASTA records.
    """
    with open(filename, "r") as handle:
        first_record = next(SeqIO.parse(handle, "fasta"), None)
    return None if first_record is None else str(first_record.seq)

# Load the coupling scores from the provided file
def load_coupling_scores(filename):
    """Load EVcouplings coupling scores and build a symmetric pair -> weight lookup.

    The ``cn`` column is min-max normalized to [0, 1] (``cn_normalized``) and
    multiplied by ``probability`` to give ``weighted_cn``. Both columns are
    added to the returned DataFrame.

    Parameters:
    - filename: path (or any file-like object ``pandas.read_csv`` accepts) to a
      CSV with at least the columns ``i``, ``j``, ``cn`` and ``probability``.

    Returns:
    - (coupling_scores, df): ``coupling_scores`` maps both ``(i, j)`` and
      ``(j, i)`` to ``weighted_cn``; ``df`` is the loaded DataFrame.
    """
    df = pd.read_csv(filename)

    # Normalize cn between 0 and 1
    cn_min = df['cn'].min()
    cn_range = df['cn'].max() - cn_min
    if cn_range > 0:
        df['cn_normalized'] = (df['cn'] - cn_min) / cn_range
    else:
        # All cn values are equal (e.g. a single row), so min-max scaling is undefined.
        # 0/0 would give NaN and poison every downstream MJ score. Follow the common
        # convention (e.g. scikit-learn's MinMaxScaler) of mapping everything to 0.
        df['cn_normalized'] = 0.0

    # Multiply cn by its respective probability
    df['weighted_cn'] = df['cn_normalized'] * df['probability']

    coupling_scores = {}
    for i, j, weight in zip(df['i'].astype(int), df['j'].astype(int), df['weighted_cn']):
        coupling_scores[(i, j)] = weight
        coupling_scores[(j, i)] = weight  # Ensure symmetry
    return coupling_scores, df


# Define the MJ coupling score matrix scaffold
# Residue order of the matrix. The hard-coded entries below fill the diagonal
# and the upper triangle (column at or after row) in this order.
amino_acids = ['C', 'M', 'F', 'I', 'L', 'V', 'W', 'Y', 'A', 'G', 'T', 'S', 'N', 'Q', 'D', 'E', 'H', 'R', 'K', 'P']

# Initialize the MJ matrix: every ordered pair of amino_acids starts at 0.0.
# 0.0 doubles as the "not yet filled" marker that ensure_mj_matrix_symmetric()
# relies on, so no real entry should be exactly 0.0.
mj_matrix = {(aa1, aa2): 0.0 for aa1 in amino_acids for aa2 in amino_acids}

# Fill in the MJ matrix
# NOTE(review): "MJ" presumably refers to Miyazawa-Jernigan residue-residue
# contact energies; the values appear to match Miyazawa & Jernigan (1996).
# Confirm the source table and units. Only the (row, column) entries with the
# column at or after the row in amino_acids order are set here. The mirrored
# (column, row) entries stay 0.0 until ensure_mj_matrix_symmetric() copies them.
mj_matrix[('C', 'C')] = -5.44
mj_matrix[('C', 'M')] = -4.99
mj_matrix[('C', 'F')] = -5.80
mj_matrix[('C', 'I')] = -5.50
mj_matrix[('C', 'L')] = -5.83
mj_matrix[('C', 'V')] = -4.96
mj_matrix[('C', 'W')] = -4.95
mj_matrix[('C', 'Y')] = -4.16
mj_matrix[('C', 'A')] = -3.57
mj_matrix[('C', 'G')] = -3.16
mj_matrix[('C', 'T')] = -3.11
mj_matrix[('C', 'S')] = -2.86
mj_matrix[('C', 'N')] = -2.59
mj_matrix[('C', 'Q')] = -2.85
mj_matrix[('C', 'D')] = -2.41
mj_matrix[('C', 'E')] = -2.27
mj_matrix[('C', 'H')] = -3.60
mj_matrix[('C', 'R')] = -2.57
mj_matrix[('C', 'K')] = -1.95
mj_matrix[('C', 'P')] = -3.07

mj_matrix[('M', 'M')] = -5.46
mj_matrix[('M', 'F')] = -6.56
mj_matrix[('M', 'I')] = -6.02
mj_matrix[('M', 'L')] = -6.41
mj_matrix[('M', 'V')] = -5.32
mj_matrix[('M', 'W')] = -5.55
mj_matrix[('M', 'Y')] = -4.91
mj_matrix[('M', 'A')] = -3.94
mj_matrix[('M', 'G')] = -3.39
mj_matrix[('M', 'T')] = -3.51
mj_matrix[('M', 'S')] = -3.03
mj_matrix[('M', 'N')] = -2.95
mj_matrix[('M', 'Q')] = -3.30
mj_matrix[('M', 'D')] = -2.57
mj_matrix[('M', 'E')] = -2.89
mj_matrix[('M', 'H')] = -3.98
mj_matrix[('M', 'R')] = -3.12
mj_matrix[('M', 'K')] = -2.48
mj_matrix[('M', 'P')] = -3.45

mj_matrix[('F', 'F')] = -7.26
mj_matrix[('F', 'I')] = -6.84
mj_matrix[('F', 'L')] = -7.28
mj_matrix[('F', 'V')] = -6.29
mj_matrix[('F', 'W')] = -6.16
mj_matrix[('F', 'Y')] = -5.66
mj_matrix[('F', 'A')] = -4.81
mj_matrix[('F', 'G')] = -4.13
mj_matrix[('F', 'T')] = -4.28
mj_matrix[('F', 'S')] = -4.02
mj_matrix[('F', 'N')] = -3.75
mj_matrix[('F', 'Q')] = -4.10
mj_matrix[('F', 'D')] = -3.48
mj_matrix[('F', 'E')] = -3.56
mj_matrix[('F', 'H')] = -4.77
mj_matrix[('F', 'R')] = -3.98
mj_matrix[('F', 'K')] = -3.36
mj_matrix[('F', 'P')] = -4.25

mj_matrix[('I', 'I')] = -6.54
mj_matrix[('I', 'L')] = -7.04
mj_matrix[('I', 'V')] = -6.05
mj_matrix[('I', 'W')] = -5.78
mj_matrix[('I', 'Y')] = -5.25
mj_matrix[('I', 'A')] = -4.58
mj_matrix[('I', 'G')] = -3.78
mj_matrix[('I', 'T')] = -4.03
mj_matrix[('I', 'S')] = -3.52
mj_matrix[('I', 'N')] = -3.24
mj_matrix[('I', 'Q')] = -3.67
mj_matrix[('I', 'D')] = -3.17
mj_matrix[('I', 'E')] = -3.27
mj_matrix[('I', 'H')] = -4.14
mj_matrix[('I', 'R')] = -3.63
mj_matrix[('I', 'K')] = -3.01
mj_matrix[('I', 'P')] = -3.76

mj_matrix[('L', 'L')] = -7.37
mj_matrix[('L', 'V')] = -6.48
mj_matrix[('L', 'W')] = -6.14
mj_matrix[('L', 'Y')] = -5.67
mj_matrix[('L', 'A')] = -4.91
mj_matrix[('L', 'G')] = -4.16
mj_matrix[('L', 'T')] = -4.34
mj_matrix[('L', 'S')] = -3.92
mj_matrix[('L', 'N')] = -3.74
mj_matrix[('L', 'Q')] = -4.04
mj_matrix[('L', 'D')] = -3.40
mj_matrix[('L', 'E')] = -3.59
mj_matrix[('L', 'H')] = -4.54
mj_matrix[('L', 'R')] = -4.03
mj_matrix[('L', 'K')] = -3.37
mj_matrix[('L', 'P')] = -4.20

mj_matrix[('V', 'V')] = -5.52
mj_matrix[('V', 'W')] = -5.18
mj_matrix[('V', 'Y')] = -4.62
mj_matrix[('V', 'A')] = -4.04
mj_matrix[('V', 'G')] = -3.38
mj_matrix[('V', 'T')] = -3.46
mj_matrix[('V', 'S')] = -3.05
mj_matrix[('V', 'N')] = -2.83
mj_matrix[('V', 'Q')] = -3.07
mj_matrix[('V', 'D')] = -2.48
mj_matrix[('V', 'E')] = -2.67
mj_matrix[('V', 'H')] = -3.58
mj_matrix[('V', 'R')] = -3.07
mj_matrix[('V', 'K')] = -2.49
mj_matrix[('V', 'P')] = -3.32

mj_matrix[('W', 'W')] = -5.06
mj_matrix[('W', 'Y')] = -4.66
mj_matrix[('W', 'A')] = -3.82
mj_matrix[('W', 'G')] = -3.42
mj_matrix[('W', 'T')] = -3.22
mj_matrix[('W', 'S')] = -2.99
mj_matrix[('W', 'N')] = -3.07
mj_matrix[('W', 'Q')] = -3.11
mj_matrix[('W', 'D')] = -2.84
mj_matrix[('W', 'E')] = -2.99
mj_matrix[('W', 'H')] = -3.98
mj_matrix[('W', 'R')] = -3.41
mj_matrix[('W', 'K')] = -2.69
mj_matrix[('W', 'P')] = -3.73

mj_matrix[('Y', 'Y')] = -4.17
mj_matrix[('Y', 'A')] = -3.36
mj_matrix[('Y', 'G')] = -3.01
mj_matrix[('Y', 'T')] = -3.01
mj_matrix[('Y', 'S')] = -2.78
mj_matrix[('Y', 'N')] = -2.76
mj_matrix[('Y', 'Q')] = -2.97
mj_matrix[('Y', 'D')] = -2.76
mj_matrix[('Y', 'E')] = -2.79
mj_matrix[('Y', 'H')] = -3.52
mj_matrix[('Y', 'R')] = -3.16
mj_matrix[('Y', 'K')] = -2.60
mj_matrix[('Y', 'P')] = -3.19

mj_matrix[('A', 'A')] = -2.72
mj_matrix[('A', 'G')] = -2.31
mj_matrix[('A', 'T')] = -2.32
mj_matrix[('A', 'S')] = -2.01
mj_matrix[('A', 'N')] = -1.84
mj_matrix[('A', 'Q')] = -1.89
mj_matrix[('A', 'D')] = -1.70
mj_matrix[('A', 'E')] = -1.51
mj_matrix[('A', 'H')] = -2.41
mj_matrix[('A', 'R')] = -1.83
mj_matrix[('A', 'K')] = -1.31
mj_matrix[('A', 'P')] = -2.03

mj_matrix[('G', 'G')] = -2.24
mj_matrix[('G', 'T')] = -2.08
mj_matrix[('G', 'S')] = -1.82
mj_matrix[('G', 'N')] = -1.74
mj_matrix[('G', 'Q')] = -1.66
mj_matrix[('G', 'D')] = -1.59
mj_matrix[('G', 'E')] = -1.22
mj_matrix[('G', 'H')] = -2.15
mj_matrix[('G', 'R')] = -1.72
mj_matrix[('G', 'K')] = -1.15
mj_matrix[('G', 'P')] = -1.87

mj_matrix[('T', 'T')] = -2.12
mj_matrix[('T', 'S')] = -1.96
mj_matrix[('T', 'N')] = -1.88
mj_matrix[('T', 'Q')] = -1.90
mj_matrix[('T', 'D')] = -1.80
mj_matrix[('T', 'E')] = -1.74
mj_matrix[('T', 'H')] = -2.42
mj_matrix[('T', 'R')] = -1.90
mj_matrix[('T', 'K')] = -1.31
mj_matrix[('T', 'P')] = -1.90

mj_matrix[('S', 'S')] = -1.67
mj_matrix[('S', 'N')] = -1.58
mj_matrix[('S', 'Q')] = -1.49
mj_matrix[('S', 'D')] = -1.63
mj_matrix[('S', 'E')] = -1.48
mj_matrix[('S', 'H')] = -2.11
mj_matrix[('S', 'R')] = -1.62
mj_matrix[('S', 'K')] = -1.05
mj_matrix[('S', 'P')] = -1.57

mj_matrix[('N', 'N')] = -1.68
mj_matrix[('N', 'Q')] = -1.71
mj_matrix[('N', 'D')] = -1.68
mj_matrix[('N', 'E')] = -1.51
mj_matrix[('N', 'H')] = -2.08
mj_matrix[('N', 'R')] = -1.64
mj_matrix[('N', 'K')] = -1.21
mj_matrix[('N', 'P')] = -1.53

mj_matrix[('Q', 'Q')] = -1.54
mj_matrix[('Q', 'D')] = -1.46
mj_matrix[('Q', 'E')] = -1.42
mj_matrix[('Q', 'H')] = -1.98
mj_matrix[('Q', 'R')] = -1.80
mj_matrix[('Q', 'K')] = -1.29
mj_matrix[('Q', 'P')] = -1.73

mj_matrix[('D', 'D')] = -1.21
mj_matrix[('D', 'E')] = -1.02
mj_matrix[('D', 'H')] = -2.32
mj_matrix[('D', 'R')] = -2.29
mj_matrix[('D', 'K')] = -1.68
mj_matrix[('D', 'P')] = -1.33

mj_matrix[('E', 'E')] = -0.91
mj_matrix[('E', 'H')] = -2.15
mj_matrix[('E', 'R')] = -2.27
mj_matrix[('E', 'K')] = -1.80
mj_matrix[('E', 'P')] = -1.26

mj_matrix[('H', 'H')] = -3.05
mj_matrix[('H', 'R')] = -2.16
mj_matrix[('H', 'K')] = -1.35
mj_matrix[('H', 'P')] = -2.25

mj_matrix[('R', 'R')] = -1.55
mj_matrix[('R', 'K')] = -0.59
mj_matrix[('R', 'P')] = -1.70

mj_matrix[('K', 'K')] = -0.12
mj_matrix[('K', 'P')] = -0.97

mj_matrix[('P', 'P')] = -1.75

# Ensure the MJ matrix is symmetric
def ensure_mj_matrix_symmetric(mj_matrix, amino_acids):
    """Mirror filled entries of ``mj_matrix`` onto their unfilled (0.0) transposes, in place.

    For every ordered pair (a, b) drawn from ``amino_acids``: if exactly one of
    the entries (a, b) and (b, a) is 0.0, it is overwritten with the other.
    Pairs where both entries are non-zero (even if they differ), or both are
    0.0, are left untouched.

    Returns the same dict object that was passed in.
    """
    ordered_pairs = [(row, col) for row in amino_acids for col in amino_acids]
    for row, col in ordered_pairs:
        forward = mj_matrix[(row, col)]
        backward = mj_matrix[(col, row)]
        if forward == 0.0 and backward != 0.0:
            mj_matrix[(row, col)] = backward
        elif backward == 0.0 and forward != 0.0:
            mj_matrix[(col, row)] = forward
    return mj_matrix

# Calculate the MJ score for a given sequence considering coupling scores
def calculate_mj_score(sequence, mj_matrix, coupling_scores):
    """Return the coupling-weighted MJ energy of ``sequence``.

    Sums ``mj_matrix[(sequence[i], sequence[j])] * coupling_scores[(i+1, j+1)]``
    over all index pairs i < j. ``coupling_scores`` uses 1-based positions, and
    pairs that are absent or have a zero weight are skipped.
    """
    total = 0.0
    length = len(sequence)
    for left in range(length):
        for right in range(left + 1, length):
            weight = coupling_scores.get((left + 1, right + 1), 0)
            if weight == 0:
                continue  # uncoupled pair contributes nothing
            total += mj_matrix[(sequence[left], sequence[right])] * weight
    return total

# Generate all possible single amino acid mutations
def generate_mutations(sequence, amino_acids):
    """List every single-residue substitution of ``sequence``.

    Returns ``(wild_type, position_1based, mutant, mutated_sequence)`` tuples,
    ordered by position and then by ``amino_acids`` order. Substitutions that
    reproduce the wild-type residue are skipped.
    """
    return [
        (wild_type, pos + 1, aa, f"{sequence[:pos]}{aa}{sequence[pos + 1:]}")
        for pos, wild_type in enumerate(sequence)
        for aa in amino_acids
        if aa != wild_type
    ]

# Calculate the change in MJ score for each mutation
def calculate_mutation_scores(sequence, mj_matrix, coupling_scores, amino_acids):
    """Score every single-residue substitution with the coupling-weighted MJ energy.

    Gives the same values as calling ``calculate_mj_score`` on every sequence
    from ``generate_mutations``, but computed incrementally. A substitution at
    position p only changes the terms of pairs that contain p, so each mutant
    costs O(number of partners of p) instead of a full O(L^2) re-summation; the
    whole scan drops from O(L^3) to O(L^2). Mutant totals can differ from a full
    re-sum in the last floating-point digits because terms are added in a
    different order. The wild-type score is summed exactly as in
    ``calculate_mj_score``.

    Parameters:
    - sequence: str, wild-type sequence.
    - mj_matrix: dict mapping (aa1, aa2) to an energy.
    - coupling_scores: dict mapping 1-based (i, j) to a weight; missing pairs count as 0.
    - amino_acids: iterable of substitution residues.

    Returns:
    - (mutation_scores, original_score). ``mutation_scores`` has "wt" first,
      then one entry per substitution labelled e.g. "A12G" (wild type, 1-based
      position, mutant), ordered by position and then by ``amino_acids``.
    """
    length = len(sequence)
    original_score = 0.0
    # partners[p]: (q, weight) for every position q coupled to p with a non-zero weight.
    partners = [[] for _ in range(length)]
    for i in range(length):
        for j in range(i + 1, length):
            weight = coupling_scores.get((i + 1, j + 1), 0)
            if weight != 0:
                # Same terms in the same order as calculate_mj_score -> identical wild-type score.
                original_score += mj_matrix[(sequence[i], sequence[j])] * weight
                partners[i].append((j, weight))
                partners[j].append((i, weight))

    mutation_scores = {"wt": original_score}
    for pos, wt_residue in enumerate(sequence):
        for mutant_residue in amino_acids:
            if mutant_residue == wt_residue:
                continue
            delta = 0.0
            for other, weight in partners[pos]:
                other_residue = sequence[other]
                # Keep the (lower index, higher index) key order used by
                # calculate_mj_score, so asymmetric matrices give the same result.
                if other < pos:
                    delta += (mj_matrix[(other_residue, mutant_residue)]
                              - mj_matrix[(other_residue, wt_residue)]) * weight
                else:
                    delta += (mj_matrix[(mutant_residue, other_residue)]
                              - mj_matrix[(wt_residue, other_residue)]) * weight
            mutation_scores[f"{wt_residue}{pos + 1}{mutant_residue}"] = original_score + delta

    return mutation_scores, original_score

# Save the matrix to a file
def save_matrix(matrix, filename):
    """Write each ``key<TAB>value`` pair of ``matrix`` on its own line, in dict order."""
    lines = (f"{key}\t{value}\n" for key, value in matrix.items())
    with open(filename, 'w') as out:
        out.writelines(lines)

# Save the MJ matrix
def save_mj_matrix(mj_matrix, filename):
    """Write the MJ matrix as ``AA1-AA2<TAB>value`` lines, in dict order."""
    with open(filename, 'w') as out:
        out.writelines(f"{first}-{second}\t{energy}\n" for (first, second), energy in mj_matrix.items())

# Save the coupling score matrix as a sorted list
def save_coupling_matrix(coupling_scores, filename):
    """Write coupling scores as ``i-j<TAB>score`` lines, sorted by (i, j)."""
    ordered = sorted(coupling_scores.items())
    with open(filename, 'w') as out:
        out.writelines(f"{pos_i}-{pos_j}\t{weight}\n" for (pos_i, pos_j), weight in ordered)

# Save the weighted MJ scores
def save_weighted_scores(scores, original_score, filename):
    """Write a tab-separated table of Label, Score and Difference.

    Difference is ``original_score - score``, or 0 for the "wt" row.
    """
    rows = ["Label\tScore\tDifference\n"]
    for label, score in scores.items():
        difference = 0 if label == "wt" else original_score - score
        rows.append(f"{label}\t{score}\t{difference}\n")
    with open(filename, 'w') as out:
        out.writelines(rows)

# Main function to run the analysis
def main(fasta_file, coupling_file, output_dir):
    """Run the coupling-weighted MJ stability scan for one protein.

    Reads the first sequence in ``fasta_file`` and the EVcouplings scores in
    ``coupling_file``. Writes ``coupling_scores_matrix.txt``, ``mj_matrix.txt``
    and ``stability_scores.txt`` into ``output_dir``, then prints every score
    and its difference from the wild type.

    Output paths are built with ``os.path.join``. The previous ``f"{output_dir}/..."``
    form pointed at the filesystem root when ``output_dir`` was empty.
    """
    sequence = load_fasta(fasta_file)
    print(f"Original Sequence: {sequence}")

    coupling_scores, _coupling_df = load_coupling_scores(coupling_file)

    # Save the coupling score matrix as a sorted list
    save_coupling_matrix(coupling_scores, os.path.join(output_dir, "coupling_scores_matrix.txt"))

    # Ensure the MJ matrix is symmetric (fills the mirrored entries of the module-level matrix)
    global mj_matrix
    mj_matrix = ensure_mj_matrix_symmetric(mj_matrix, amino_acids)

    # Save the MJ matrix
    save_mj_matrix(mj_matrix, os.path.join(output_dir, "mj_matrix.txt"))

    # Calculate mutation scores
    mutation_scores, original_score = calculate_mutation_scores(sequence, mj_matrix, coupling_scores, amino_acids)

    # Save the weighted MJ scores
    save_weighted_scores(mutation_scores, original_score, os.path.join(output_dir, "stability_scores.txt"))

    # Print results
    for label, score in mutation_scores.items():
        difference = original_score - score if label != "wt" else 0
        print(f"{label}: {score}, Difference: {difference}")

# Run the main function with your fasta file and coupling scores file
fasta_file = ""  # Change this to your fasta file name
coupling_file = ""  # Change this to your EVcouplings coupling scores file name
output_dir = ""  # Change this to your desired output directory

# Fail early with a clear message while the placeholders above are still empty.
# An empty output_dir would otherwise make os.makedirs raise an obscure FileNotFoundError.
if not (fasta_file and coupling_file and output_dir):
    raise ValueError("Set fasta_file, coupling_file and output_dir before running this cell.")

# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)

main(fasta_file, coupling_file, output_dir)

# The following section contains the scripts used to analyze 20 randomly selected, high-quality monomeric proteins (20R)

Summarize (compress) mutational frustration output from the Frustratometer (20R)

This script expects the subdirectories containing the `.rar` Frustratometer output to be named `frustratometer` and `frustratometer_af` (the latter only if you have Frustratometer output based on an AlphaFold structure). It was used for the 20 random (20R) monomers from the PDB.

In [None]:
import os
import shutil
import glob
import patoolib
import pandas as pd

# Function to clean directories by deleting all files except .rar and removing subdirectories
def clean_directory(directory):
    """
    Empty *directory* except for its top-level ``.rar`` archives.

    Regular files whose name does not end in ``.rar`` (case-insensitive) are
    removed, every subdirectory is removed recursively with its contents,
    and each action is reported on stdout. Entries that are neither regular
    files nor directories are left untouched.

    Parameters:
    - directory (str): The path to the directory to clean.
    """
    print(f"\nCleaning directory: {directory}")
    for entry_name in os.listdir(directory):
        entry_path = os.path.join(directory, entry_name)

        if os.path.isdir(entry_path):
            # Subdirectories are removed wholesale.
            print(f"Deleting directory and its contents: {entry_path}")
            shutil.rmtree(entry_path)
            continue

        if not os.path.isfile(entry_path):
            continue

        if entry_name.lower().endswith('.rar'):
            print(f"Keeping .rar file: {entry_path}")
        else:
            print(f"Deleting file: {entry_path}")
            os.remove(entry_path)

# Function to extract .rar files from directories
def extract_rar_from_directory(input_dir):
    """
    Extract one .rar archive from *input_dir* into ``<input_dir>/extracted``.

    Only regular files directly inside *input_dir* whose name ends in
    ``.rar`` (case-insensitive) are considered. If several exist, the
    alphabetically first is extracted and the others are reported.
    ``os.listdir`` order is arbitrary, so sorting makes the choice
    reproducible. Extraction errors are printed, not raised.

    Parameters:
    - input_dir (str): Directory to scan (non-recursively).

    Returns:
    - None
    """
    print(f"Scanning directory for .rar files: {input_dir}")
    rar_files = sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith('.rar') and os.path.isfile(os.path.join(input_dir, name))
    )
    if not rar_files:
        print(f"No .rar files found in {input_dir}.")
        return
    if len(rar_files) > 1:
        print(f"Multiple .rar files found in {input_dir}; using {rar_files[0]} "
              f"and ignoring: {', '.join(rar_files[1:])}")

    rar_file = os.path.join(input_dir, rar_files[0])
    extracted_folder = os.path.join(input_dir, "extracted")

    try:
        patoolib.extract_archive(rar_file, outdir=extracted_folder)
        print(f"Extracted {rar_file} to {extracted_folder}")
    except Exception as e:
        # Best-effort: report and continue with the next directory.
        print(f"Failed to extract {rar_file}: {e}")

# Function to locate the .pdb_mutational file
def locate_pdb_mutational_file(extracted_folder):
    """
    Return the path of a ``.pdb_mutational`` file inside a FrustrationData directory.

    The tree under *extracted_folder* is walked top-down with sibling
    directories visited in sorted order. The first directory whose name
    contains "FrustrationData" is used, and the alphabetically first
    ``.pdb_mutational`` file in it is returned. Sorting makes the result
    independent of the filesystem's listing order.

    Parameters:
    - extracted_folder (str): Root of the extracted Frustratometer archive.

    Returns:
    - str: Path to the selected ``.pdb_mutational`` file.

    Raises:
    - FileNotFoundError: if no FrustrationData directory, or no
      ``.pdb_mutational`` file inside the first one, is found.
    """
    print(f"Searching for FrustrationData directory in: {extracted_folder}")
    frustration_dirs = []
    for root, dirnames, _files in os.walk(extracted_folder):
        # Sorting in place also fixes the order os.walk descends in.
        dirnames.sort()
        frustration_dirs.extend(
            os.path.join(root, name) for name in dirnames if "FrustrationData" in name
        )
    if not frustration_dirs:
        raise FileNotFoundError(f"No FrustrationData directory found in {extracted_folder}")

    data_dir = frustration_dirs[0]
    pdb_files = sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.endswith('.pdb_mutational')
    )
    if not pdb_files:
        raise FileNotFoundError(f"No .pdb_mutational files found in {data_dir}")

    return pdb_files[0]

# Function to process the .pdb_mutational file
def process_frustration_file(input_file, output_file):
    """
    Summarise a Frustratometer ``.pdb_mutational`` table into one value per residue.

    The whitespace-separated table (``#`` lines ignored) is read with the
    14 column names below. For every row, NativeEnergy - DecoyEnergy is
    credited to both residues of the pair, and each residue's value is the
    mean of everything credited to it. The residue letter is taken from the
    first row mentioning that residue.

    Residues are written in ascending order of their original number but
    renumbered 1..N, as ``Residue# ResidueAA Difference`` lines with four
    decimals. Any error is printed rather than raised.
    """
    print(f"Processing .pdb_mutational file: {input_file}")
    col_names = [
        "Res1", "Res2", "ChainRes1", "ChainRes2",
        "DensityRes1", "DensityRes2", "AA1", "AA2",
        "NativeEnergy", "DecoyEnergy", "SDEnergy",
        "FrstIndex", "Welltype", "FrstState"
    ]
    try:
        table = pd.read_csv(input_file, sep=r'\s+', comment="#", names=col_names)

        # residue number -> [amino acid, summed energy difference, contact count]
        totals = {}
        columns = zip(
            table["Res1"], table["Res2"], table["AA1"], table["AA2"],
            table["NativeEnergy"], table["DecoyEnergy"],
        )
        for res1, res2, aa1, aa2, native, decoy in columns:
            delta = native - decoy
            for res, aa in ((res1, aa1), (res2, aa2)):
                entry = totals.setdefault(res, [aa, 0.0, 0])
                entry[1] += delta
                entry[2] += 1

        with open(output_file, "w") as handle:
            handle.write("Residue# ResidueAA Difference\n")
            for new_index, res in enumerate(sorted(totals), start=1):
                aa, total, count = totals[res]
                mean = total / count if count > 0 else 0.0
                handle.write(f"{new_index} {aa} {mean:.4f}\n")
        print(f"Output written to {output_file}.")
    except Exception as e:
        print(f"Error processing file {input_file}: {e}")

# Function to process extracted directories
def process_subdirectory(test_dir, sub_dir, output_filename):
    """
    Write the per-residue frustration summary for ``test_dir/sub_dir``.

    Uses the ``extracted`` folder inside that directory, locates the
    ``.pdb_mutational`` file within it and writes the summary to
    ``test_dir/sub_dir/output_filename``. A missing ``extracted`` folder,
    a missing file, or any other error is reported on stdout rather than
    raised.
    """
    base_dir = os.path.join(test_dir, sub_dir)
    extracted_dir = os.path.join(base_dir, "extracted")
    summary_path = os.path.join(base_dir, output_filename)

    if os.path.isdir(extracted_dir):
        try:
            mutational_path = locate_pdb_mutational_file(extracted_dir)
            process_frustration_file(mutational_path, summary_path)
        except FileNotFoundError as e:
            print(f"Error processing {sub_dir} in {test_dir}: {e}")
        except Exception as e:
            print(f"An unexpected error occurred while processing {sub_dir} in {test_dir}: {e}")
    else:
        print(f"Extracted folder not found: {extracted_dir}. Skipping {sub_dir} in {test_dir}.")

# Main function
def main(root_directory):
    """
    Clean, extract and summarise every Frustratometer folder under *root_directory*.

    The tree is walked looking for directories named ``frustratometer`` or
    ``frustratometer_af`` (case-insensitive). For each one:
      1. everything except its .rar archive(s) is deleted (clean_directory),
      2. one archive is extracted into ``extracted`` (extract_rar_from_directory),
      3. the per-residue summary is written to ``frustration_summary.txt``
         or ``frustration_af_summary.txt`` respectively (process_subdirectory).

    The walk is pruned so it never descends into those target directories.
    Their contents have just been deleted and regenerated. Without pruning,
    a directory with a target name inside an extracted archive would itself
    be subjected to the destructive clean step.
    """
    output_names = {
        'frustratometer_af': "frustration_af_summary.txt",
        'frustratometer': "frustration_summary.txt",
    }
    for dirpath, dirnames, _ in os.walk(root_directory):
        targets = [name for name in dirnames if name.lower() in output_names]
        for dirname in targets:
            target_dir = os.path.join(dirpath, dirname)
            print(f"\nFound target directory: {target_dir}")

            # Clean the target directory before extraction.
            clean_directory(target_dir)
            extract_rar_from_directory(target_dir)
            process_subdirectory(dirpath, dirname, output_names[dirname.lower()])

        # In-place assignment is what tells os.walk (top-down) to skip these.
        dirnames[:] = [name for name in dirnames if name.lower() not in output_names]

if __name__ == "__main__":
    # Root of the directory tree to scan for frustratometer folders; set before running.
    root_directory = ""
    main(root_directory=root_directory)

Combine and align data for each protein (20R)

In [None]:
import os
import re
import numpy as np
import pandas as pd
import logging
from Bio import pairwise2
from Bio.Seq import Seq
from collections import defaultdict
from Bio.PDB.Polypeptide import is_aa
from Bio.PDB import PDBParser, DSSP
from pathlib import Path

# -------------------------------------------------------------------------
# 1) Logging Setup
# -------------------------------------------------------------------------
def setup_logging(debug=False):
    """
    Configure root logging and return this module's logger.

    Parameters:
        debug (bool): Use DEBUG level when True, INFO otherwise.

    Returns:
        logging.Logger: ``logging.getLogger(__name__)``.

    Note: ``logging.basicConfig`` does nothing when the root logger already
    has handlers, so only the first configuration in a session takes effect.
    """
    chosen_level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=chosen_level,
    )
    return logging.getLogger(__name__)


logger = setup_logging(debug=True)

# -------------------------------------------------------------------------
# 2) Directory Validation
# -------------------------------------------------------------------------
def validate_directory_structure(dir_path):
    """
    Report which recognised data sources exist inside *dir_path*.

    A source counts as available when ``dir_path/<source>`` is a directory
    containing every file listed for it.

    Returns:
        tuple: (has_any_data, available_sources). available_sources lists the
        names found, in the order frustratometer, frustratometer_af,
        experimental_data, mj_analysis.
    """
    required = {
        'frustratometer': ['frustration_summary.txt'],
        'frustratometer_af': ['frustration_af_summary.txt'],
        'experimental_data': ['average_b_factors.txt'],
        'mj_analysis': ['stability_scores.txt'],
    }
    base = Path(dir_path)

    def _complete(source):
        folder = base / source
        return folder.is_dir() and all((folder / name).is_file() for name in required[source])

    found = [source for source in required if _complete(source)]
    return bool(found), found

def get_valid_directories(root_path):
    """
    List the immediate subdirectories of *root_path* that hold usable data.

    Each subdirectory is checked with validate_directory_structure. Those
    with at least one source are logged at INFO and kept; the rest are
    logged at WARNING. A missing root is logged as an error.

    Returns:
        list[tuple[Path, list[str]]]: (directory, available_sources) pairs
        sorted by directory path string; empty if the root does not exist.
    """
    root = Path(root_path)
    if not root.is_dir():
        logger.error(f"Root directory does not exist: {root}")
        return []

    matches = []
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        has_data, sources = validate_directory_structure(entry)
        if not has_data:
            logger.warning(f"Directory {entry} has no valid data sources, skipping")
            continue
        matches.append((entry, sources))
        logger.info(f"Found directory {entry} with data sources: {', '.join(sources)}")

    matches.sort(key=lambda item: str(item[0]))
    return matches

# -------------------------------------------------------------------------
# 3) Parsing Functions
# -------------------------------------------------------------------------
def parse_frustration_file(file_path):
    """
    Parses frustration_summary.txt or frustration_af_summary.txt.

    Format:
        Residue# ResidueAA Difference
        1 A -0.1234
        2 G 0.5678
        3 S -1.2345
        ...

    Rows are ordered by residue number and renumbered consecutively from 1.
    The returned positions therefore always run 1..len(sequence), which is
    how the alignment code reads them, even if the file's numbering has
    gaps or does not start at 1. A residue number that appears more than
    once keeps its last row, consistent with parse_b_factor.

    Codes other than the 20 standard one-letter amino acids (including
    multi-letter tokens) are replaced by 'X'.

    Returns:
        sequence_str (str): Amino acid sequence in 1-letter codes.
        pos_values (dict): 1-based sequential position -> frustration difference.
        Both are empty when the file is missing or has no valid rows.
    """
    file_path = Path(file_path)
    if not file_path.is_file():
        logger.debug(f"parse_frustration_file: File not found {file_path}")
        return "", {}

    by_residue = {}  # residue number -> (one-letter code, difference)
    with open(file_path, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line or line.startswith("Residue#"):
                continue
            parts = line.split()
            if len(parts) < 3:
                logger.warning(f"parse_frustration_file: Line {line_number} does not have enough parts: {line}")
                continue
            try:
                res_num = int(parts[0])
                difference = float(parts[2])
            except ValueError:
                logger.warning(f"parse_frustration_file: Skipping line {line_number} due to ValueError: {line}")
                continue

            aa = parts[1].upper()
            # Test single-character membership; a bare substring test would
            # accept tokens such as "DEF" or "ST".
            if len(aa) != 1 or aa not in "ACDEFGHIKLMNPQRSTVWY":
                logger.warning(f"Line {line_number}: Unknown amino acid code '{aa}'. Assigned as 'X'.")
                aa = 'X'

            if res_num in by_residue:
                logger.warning(f"parse_frustration_file: Duplicate residue number {res_num} "
                               f"on line {line_number}; keeping the later value.")
            by_residue[res_num] = (aa, difference)

    if not by_residue:
        logger.debug(f"parse_frustration_file: No valid data found in {file_path}")
        return "", {}

    ordered = sorted(by_residue)
    sequence = ''.join(by_residue[res_num][0] for res_num in ordered)
    pos_values = {pos: by_residue[res_num][1] for pos, res_num in enumerate(ordered, start=1)}
    return sequence, pos_values

def parse_b_factor(file_path):
    """
    Read per-residue average B-factors from average_b_factors.txt.

    Expected layout (lines starting with "Residue" are skipped, as are
    lines with fewer than three fields):
      Residue ResidueAA Average_B_Factor
      1 K 84.683
      ...

    Rows are ordered by residue number and renumbered 1..N. A repeated
    residue number keeps its last row. Residue letters are upper-cased.

    Returns:
        sequence_str (str): residue letters in residue-number order.
        pos_values (dict): 1-based sequential position -> B-factor.
        Both are empty if the file is missing or nothing parses.
    """
    path = Path(file_path)
    if not path.is_file():
        logger.debug(f"parse_b_factor: File not found {path}")
        return "", {}

    records = {}  # residue number -> (letter, B-factor)
    with open(path, 'r') as handle:
        for raw in handle:
            if raw.startswith("Residue"):
                continue
            fields = raw.strip().split()
            if len(fields) < 3:
                continue
            try:
                records[int(fields[0])] = (fields[1].upper(), float(fields[2]))
            except ValueError:
                logger.warning(f"parse_b_factor: Skipping line due to ValueError: {raw}")
                continue

    if not records:
        logger.debug(f"parse_b_factor: No lines parsed in {path}")
        return "", {}

    ordered = sorted(records)
    sequence_str = "".join(records[number][0] for number in ordered)
    pos_values = {rank: records[number][1] for rank, number in enumerate(ordered, start=1)}
    return sequence_str, pos_values

def parse_evolutionary(file_path):
    """
    Collapse stability_scores.txt into one averaged difference per residue.

    Expected layout:
      Label Score Difference
      M1C -456.89159 -0.0
      ...

    Blank lines, the "Label" header, rows with fewer than three fields and
    the "wt" row are ignored. Labels are matched (case-insensitively) as
    <native letter><position><mutant letter>. The Difference values of all
    mutants at one position are averaged. If a position appears with more
    than one native letter, the letter seen first in the file wins.
    Positions are renumbered 1..N in ascending order.

    Returns (sequence_str, pos_values): native letters in position order and
    {1-based sequential position: mean difference}; both empty if nothing
    usable is found.
    """
    path = Path(file_path)
    if not path.is_file():
        logger.debug(f"parse_evolutionary: File not found {path}")
        return "", {}

    mutant_pattern = re.compile(r'([A-Z])(\d+)([A-Z])', re.IGNORECASE)
    grouped = defaultdict(list)  # (native letter, position) -> differences
    with open(path, 'r') as handle:
        for raw in handle:
            text = raw.strip()
            if not text or text.startswith("Label"):
                continue
            fields = text.split()
            if len(fields) < 3:
                continue
            try:
                diff = float(fields[2])
            except ValueError:
                logger.warning(f"parse_evolutionary: Invalid difference value in line: {text}")
                continue

            label = fields[0]
            if label.lower() == "wt":
                continue

            match = mutant_pattern.match(label)
            if match:
                grouped[(match.group(1).upper(), int(match.group(2)))].append(diff)

    if not grouped:
        logger.debug(f"parse_evolutionary: No valid lines in {path}")
        return "", {}

    # Insertion order == order of first appearance, so the first native
    # letter recorded for a position is kept.
    first_by_position = {}
    for (letter, position), diffs in grouped.items():
        if position not in first_by_position:
            mean = sum(diffs) / len(diffs) if diffs else 0.0
            first_by_position[position] = (letter, mean)

    positions = sorted(first_by_position)
    sequence_str = "".join(first_by_position[p][0] for p in positions)
    pos_values = {rank: first_by_position[p][1] for rank, p in enumerate(positions, start=1)}
    return sequence_str, pos_values

def parse_rmsf(file_path):
    """
    Parses experimental_data/rmsf.csv of the form:
      26,A,2.472
      27,A,2.308
      28,A,2.657
      ...
    i.e. residue number, chain, RMSF. The chain column is not used.

    Rows are ordered by residue number and renumbered consecutively: the
    first residue (e.g. 26) -> 1, the next residue present -> 2, etc.
    Numbering gaps are closed rather than carried into the positions, so
    every position 1..len(sequence) has a value. This is what the alignment
    code reads. A residue number that appears more than once keeps its last
    value.

    Returns (sequence_str, pos_values) => {1-based_pos: RMSF}.
    sequence_str is a placeholder of 'A's (one per residue) because the file
    carries no residue identities.
    NOTE(review): a poly-'A' placeholder is positioned by the aligner only
    through gap penalties and where alanines occur in the other sequences;
    its column placement should be checked.
    """
    file_path = Path(file_path)
    if not file_path.is_file():
        logger.debug(f"parse_rmsf: File not found {file_path}")
        return "", {}

    rmsf_by_residue = {}  # original residue number -> RMSF
    with open(file_path, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            parts = line.split(',')
            if len(parts) < 3:
                logger.warning(f"parse_rmsf: Skipping malformed line {line_number}: {line}")
                continue
            try:
                old_idx = int(parts[0])
                rmsf_val = float(parts[2])
            except ValueError:
                logger.warning(f"parse_rmsf: Skipping invalid line {line_number}: {line}")
                continue
            if old_idx in rmsf_by_residue:
                logger.warning(f"parse_rmsf: Duplicate residue number {old_idx} on line "
                               f"{line_number}; keeping the later value.")
            rmsf_by_residue[old_idx] = rmsf_val

    if not rmsf_by_residue:
        logger.debug("parse_rmsf: No valid lines in RMSF file.")
        return "", {}

    ordered = sorted(rmsf_by_residue)
    pos_values = {new_idx: rmsf_by_residue[old_idx]
                  for new_idx, old_idx in enumerate(ordered, start=1)}
    sequence_str = "A" * len(ordered)
    return sequence_str, pos_values

def parse_mutation_scores(file_path):
    """
    Summarise an EVcouplings mutation_scores.txt table per residue position.

    The first line is treated as a header and skipped. Every other non-blank
    line is split on whitespace and needs at least 8 fields, read as:

       segment mutant  pos wt subs    frequency  column_conservation  effect_prediction_epistatic
           NaN   G12A   12  G    A 2.533944e-02             0.319221                    -2.248724

    For each position the reported value is the NEGATED mean of
    effect_prediction_epistatic over all its rows. The wild-type letter
    (upper-cased) comes from the first row seen for that position.
    Positions are renumbered 1..N in ascending order.

    Returns:
       sequence_str (str): wild-type letters ordered by position.
       pos_values (dict): 1-based sequential position -> negative mean effect.
       Both are empty if the file is missing or has no usable rows.
    """
    path = Path(file_path)
    if not path.is_file():
        logger.debug(f"parse_mutation_scores: File not found {path}")
        return "", {}

    effects = defaultdict(list)  # original position -> effect predictions
    wild_type = {}               # original position -> first wild-type letter seen

    with open(path, 'r') as handle:
        handle.readline()  # header row
        for line_number, raw in enumerate(handle, start=2):
            text = raw.strip()
            if not text:
                continue
            fields = text.split()
            if len(fields) < 8:
                logger.warning(f"parse_mutation_scores: Skipping malformed line {line_number}: {text}")
                continue
            try:
                pos = int(fields[2])
                wt_res = fields[3].upper()
                effect = float(fields[7])
            except ValueError:
                logger.warning(f"parse_mutation_scores: Skipping line {line_number} due to conversion error: {text}")
                continue
            effects[pos].append(effect)
            wild_type.setdefault(pos, wt_res)

    if not effects:
        logger.debug(f"parse_mutation_scores: No valid data parsed from {path}")
        return "", {}

    positions = sorted(effects)
    sequence_str = "".join(wild_type.get(p, 'X') for p in positions)
    pos_values = {
        rank: -(sum(effects[p]) / len(effects[p]))
        for rank, p in enumerate(positions, start=1)
    }
    return sequence_str, pos_values

# -------------------------------------------------------------------------
# 4) Alignment Functions
# -------------------------------------------------------------------------
def align_two(seqA, seqB, gap_open=-2, gap_extend=-0.5):
    """
    Globally align two sequences with Biopython pairwise2.

    Scoring is +2 for identical characters, -1 for a mismatch, and the
    given gap-open / gap-extend penalties. Both inputs are upper-cased
    first. If either sequence is empty, the other is returned against an
    all-gap string.

    Only one optimal alignment is requested (``one_alignment_only=True``).
    Just the first is ever used, and enumerating every co-optimal alignment
    can be very costly for long or repetitive inputs, such as the poly-'A'
    RMSF placeholder.
    NOTE: newer Biopython releases recommend Bio.Align.PairwiseAligner
    over pairwise2.

    Returns (alignedA, alignedB): equal-length strings using '-' for gaps.
    """
    seqA = seqA.upper()
    seqB = seqB.upper()

    if len(seqA) == 0 and len(seqB) == 0:
        logger.debug("align_two: Both sequences empty => '' ")
        return "", ""
    if len(seqA) == 0:
        logger.debug(f"align_two: SeqA empty, SeqB length={len(seqB)} => trivial alignment")
        return "-" * len(seqB), seqB
    if len(seqB) == 0:
        logger.debug(f"align_two: SeqB empty, SeqA length={len(seqA)} => trivial alignment")
        return seqA, "-" * len(seqA)

    logger.debug(f"align_two: Attempting global alignment: len(seqA)={len(seqA)}, len(seqB)={len(seqB)}")
    alignments = pairwise2.align.globalms(seqA, seqB, 2, -1, gap_open, gap_extend,
                                          one_alignment_only=True)
    if not alignments:
        logger.warning("align_two: No alignment from Biopython => trivial fallback.")
        # Pad the shorter sequence at the end so both have equal length.
        if len(seqA) >= len(seqB):
            return seqA, seqB + "-" * (len(seqA) - len(seqB))
        return seqA + "-" * (len(seqB) - len(seqA)), seqB

    best = alignments[0]
    return best[0], best[1]

def merge_val_alignment(alnA, alnB, valA, valB):
    """
    Lay the per-residue values of two aligned sequences out column by column.

    One entry is produced per column of ``alnA``; ``alnB`` must be at least
    as long. A gap character '-' yields 'n/a'. Otherwise the next residue's
    value is looked up in valA / valB by its 1-based ungapped position,
    falling back to 'n/a' if that position has no entry.

    Returns (aligned_valsA, aligned_valsB) as lists.
    """
    cols_a, cols_b = [], []
    next_a = next_b = 1  # 1-based position of the next real residue in each sequence

    for col, char_a in enumerate(alnA):
        char_b = alnB[col]

        if char_a != '-':
            cols_a.append(valA.get(next_a, 'n/a'))
            next_a += 1
        else:
            cols_a.append('n/a')

        if char_b != '-':
            cols_b.append(valB.get(next_b, 'n/a'))
            next_b += 1
        else:
            cols_b.append('n/a')

    return cols_a, cols_b

# -------------------------------------------------------------------------
# 5) Multiple Sequence Alignment
# -------------------------------------------------------------------------
def _expand_to_new_master(old_master, new_master, aln_seq, aln_vals):
    """
    Re-express a record aligned to ``old_master`` in the columns of ``new_master``.

    ``new_master`` must be ``old_master`` with extra '-' columns inserted.
    Pairwise alignment only inserts gaps into its inputs. Columns of
    ``new_master`` matched (greedily, left to right) to ``old_master``
    copy the record's entry; the extra '-' columns become '-' / 'n/a'.

    Returns (new_seq, new_vals), or None if the inputs are inconsistent
    (length mismatch, or new_master is not a gapped copy of old_master).
    """
    if len(aln_seq) != len(old_master) or len(aln_vals) != len(old_master):
        return None
    out_seq = []
    out_vals = []
    j = 0
    for ch in new_master:
        if j < len(old_master) and ch == old_master[j]:
            out_seq.append(aln_seq[j])
            out_vals.append(aln_vals[j])
            j += 1
        elif ch == '-':
            out_seq.append('-')
            out_vals.append('n/a')
        else:
            return None
    if j != len(old_master):
        return None
    return "".join(out_seq), out_vals


def multi_align_sequences(seq_list):
    """
    Progressive multiple sequence alignment.

    seq_list: [(name, seq_str, val_dict), ...], where val_dict maps 1-based
    ungapped positions of seq_str to values. The input list is not modified.

    Records are ordered longest-first, with ties keeping input order. The
    first record is the "master", and every other record is aligned
    pairwise to the current master alignment. When that adds gap columns
    to the master, the same columns are inserted into every record aligned
    earlier. A given column index therefore denotes the same alignment
    position in all records.

    Returns a list of (name, final_aln_seq, final_aln_vals), master first.
    All sequences/value lists have the same length; gap columns hold 'n/a'.
    """
    ordered = sorted(seq_list, key=lambda x: len(x[1]), reverse=True)

    # Each record: [name, raw sequence, value dict, aligned sequence, aligned values]
    recs = [
        [nm, s, v, s, [v.get(i, 'n/a') for i in range(1, len(s) + 1)]]
        for (nm, s, v) in ordered
    ]

    if not recs:
        logger.debug("multi_align_sequences: No sequences to align.")
        return []

    for i in range(1, len(recs)):
        nameB, seqB_str, valB = recs[i][0], recs[i][1], recs[i][2]
        logger.debug(f"multi_align: Aligning MASTER({recs[0][0]}) with {nameB}")

        old_master_seq = recs[0][3]
        old_master_vals = recs[0][4]

        # Master values keyed by 1-based position among its non-gap
        # characters, which is how merge_val_alignment reads them back.
        master_dict = {}
        real_pos = 1
        for char, val in zip(old_master_seq, old_master_vals):
            if char != '-':
                master_dict[real_pos] = val
                real_pos += 1

        alnA, alnB = align_two(old_master_seq, seqB_str)
        newA_vals, newB_vals = merge_val_alignment(alnA, alnB, master_dict, valB)

        # Records aligned in earlier iterations still use the old master's
        # columns; give them the newly inserted gap columns as well.
        if alnA != old_master_seq:
            for k in range(1, i):
                expanded = _expand_to_new_master(old_master_seq, alnA, recs[k][3], recs[k][4])
                if expanded is None:
                    logger.warning(f"multi_align: could not re-gap {recs[k][0]} to the new "
                                   f"master alignment; leaving it unchanged.")
                    continue
                recs[k][3], recs[k][4] = expanded

        recs[0][3] = alnA
        recs[0][4] = newA_vals
        recs[i][3] = alnB
        recs[i][4] = newB_vals

    final_len = len(recs[0][3])
    logger.debug(f"multi_align_sequences: final alignment length={final_len}")

    # Safety net: lengths should already agree; pad/truncate if they do not.
    for rec in recs:
        seq_aln = rec[3]
        vals_aln = rec[4]
        diff = final_len - len(seq_aln)
        if diff > 0:
            logger.debug(f"Padding {rec[0]} from length={len(seq_aln)} to {final_len}")
            rec[3] = seq_aln + '-' * diff
            rec[4] = vals_aln + ['n/a'] * diff
        elif diff < 0:
            logger.warning(f"{rec[0]} alignment is longer than master!? Truncating.")
            rec[3] = seq_aln[:final_len]
            rec[4] = vals_aln[:final_len]

    return [(rec[0], rec[3], rec[4]) for rec in recs]

def map_ss_to_alignment(ss_map, exp_seq, residue_seq, aligned_dict):
    """
    Project per-residue secondary structure onto the alignment columns.

    The SS map is consumed in order (positions 1, 2, ...) while walking the
    aligned B_FACTOR track. A column receives the next SS code only when
    that track has a residue there (not '-') with a value (not 'n/a').
    Such columns default to 'o' if the SS map lacks the position, and all
    other columns get 'n/a'.

    Parameters:
        ss_map (dict): 1-based sequential position -> SS code.
        exp_seq (str): Not used; kept for call compatibility.
        residue_seq (str): Master aligned sequence; its length sets the output length.
        aligned_dict (dict): name -> (aligned sequence, aligned values); only
            the 'B_FACTOR' entry is read.

    Returns:
        list: One entry per column of residue_seq; all 'n/a' when ss_map is
        empty or there is no B_FACTOR track.
    """
    column_count = len(residue_seq)
    if not ss_map or 'B_FACTOR' not in aligned_dict:
        return ['n/a'] * column_count

    bf_seq, bf_vals = aligned_dict['B_FACTOR']
    labels = ['n/a'] * column_count
    next_ss = 1

    for col, (_master_char, bf_char) in enumerate(zip(residue_seq, bf_seq)):
        if bf_char != '-' and bf_vals[col] != 'n/a':
            labels[col] = ss_map.get(next_ss, 'o')
            next_ss += 1

    return labels

# -------------------------------------------------------------------------
# 6) PDB Parsing Function
# -------------------------------------------------------------------------
def parse_secondary_structure_from_pdb(pdb_file_path, dssp_executable=None):
    """
    Parses a PDB file to extract per-residue secondary structure assignments using DSSP.

    Only the first model is used. DSSP residues are sorted by residue
    sequence number and numbered 1..N. DSSP codes H/G/I map to 'A'
    (helix), E/B map to 'B' (strand/bridge), and anything else maps to 'o'.

    Parameters:
        pdb_file_path (str | Path): PDB file to analyse.
        dssp_executable (str | None): DSSP binary to run. When None, the
            previously hard-coded '/opt/homebrew/bin/mkdssp' is used if it
            exists; otherwise 'mkdssp' or 'dssp' is looked up on PATH.

    Returns:
        ss_map (dict): Mapping from sequential position (1-based) to 'A', 'B', or 'o'.
        Empty if the file is missing or DSSP fails (the error is logged).
    """
    import shutil  # local import: this cell does not import shutil at the top

    ss_map = {}
    pdb_file = Path(pdb_file_path)

    if not pdb_file.is_file():
        logger.error(f"PDB file not found: {pdb_file_path}")
        return ss_map

    if dssp_executable is None:
        legacy_path = '/opt/homebrew/bin/mkdssp'
        if os.path.isfile(legacy_path):
            dssp_executable = legacy_path
        else:
            dssp_executable = shutil.which('mkdssp') or shutil.which('dssp') or legacy_path
    logger.debug(f"Using DSSP executable: {dssp_executable}")

    parser = PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('protein', pdb_file)
        model = structure[0]
        chains = list(model.get_chains())
        logger.debug(f"Found chains: {[chain.id for chain in chains]}")

        dssp = DSSP(model, str(pdb_file), dssp=dssp_executable)
        logger.debug(f"DSSP successful, found {len(dssp.keys())} residues")

        # Keys are (chain_id, (hetflag, resseq, icode)); sort on resseq.
        # NOTE(review): with more than one chain this interleaves chains.
        keys = sorted(dssp.keys(), key=lambda x: x[1][1])

        for seq_pos, key in enumerate(keys, start=1):
            ss = dssp[key][2]
            if ss in ('H', 'G', 'I'):
                ss_code = 'A'
            elif ss in ('E', 'B'):
                ss_code = 'B'
            else:
                ss_code = 'o'

            ss_map[seq_pos] = ss_code
            logger.debug(f"Assigned SS for sequential position {seq_pos}: {ss} -> {ss_code}")

    except Exception as e:
        logger.error(f"Error processing PDB: {str(e)}")
        return ss_map

    logger.info(f"Successfully parsed {len(ss_map)} residues with SS assignments")
    return ss_map

# -------------------------------------------------------------------------
# 7) Main Processing Function
# -------------------------------------------------------------------------
def _collect_data_sources(dir_path, available_sources):
    """
    Parse every data file present for one protein directory.

    Returns (data_sources, seq_exp). data_sources lists
    (name, sequence, {1-based position: value}) for each source that yielded
    a non-empty sequence, in the fixed order EXP_FRUST, AF_FRUST, B_FACTOR,
    EVOL, RMSF, MUTATION. seq_exp is the experimental frustration sequence
    ("" when absent).
    """
    seq_exp, vals_exp = "", {}
    seq_af, vals_af = "", {}
    seq_bf, vals_bf = "", {}
    seq_ev, vals_ev = "", {}
    seq_rmsf, vals_rmsf = "", {}
    seq_mut, vals_mut = "", {}

    # 1) Experimental frustration
    if 'frustratometer' in available_sources:
        exp_path = dir_path / "frustratometer" / "frustration_summary.txt"
        seq_exp, vals_exp = parse_frustration_file(str(exp_path))
        logger.debug(f"EXP_FRUST: length={len(seq_exp)} from {exp_path}")

    # 2) AlphaFold frustration
    if 'frustratometer_af' in available_sources:
        af_path = dir_path / "frustratometer_af" / "frustration_af_summary.txt"
        seq_af, vals_af = parse_frustration_file(str(af_path))
        logger.debug(f"AF_FRUST: length={len(seq_af)} from {af_path}")

    # 3) Experimental data (B-factor, optional RMSF)
    if 'experimental_data' in available_sources:
        bf_path = dir_path / "experimental_data" / "average_b_factors.txt"
        seq_bf, vals_bf = parse_b_factor(str(bf_path))
        logger.debug(f"B_FACTOR: length={len(seq_bf)} from {bf_path}")

        rmsf_path = dir_path / "experimental_data" / "rmsf.csv"
        if rmsf_path.is_file():
            seq_rmsf, vals_rmsf = parse_rmsf(str(rmsf_path))
            logger.debug(f"RMSF: length={len(seq_rmsf)} from {rmsf_path}")

    # 4) Evolutionary data (MJ analysis)
    if 'mj_analysis' in available_sources:
        evol_path = dir_path / "mj_analysis" / "stability_scores.txt"
        seq_ev, vals_ev = parse_evolutionary(str(evol_path))
        logger.debug(f"EVOL: length={len(seq_ev)} from {evol_path}")

    # 5) Mutation scores (Mutability)
    ms_path = dir_path / "evc_output" / "couplings" / "mutation_scores.txt"
    if ms_path.is_file():
        seq_mut, vals_mut = parse_mutation_scores(str(ms_path))
        logger.debug(f"MUTATION: length={len(seq_mut)} from {ms_path}")
    else:
        logger.warning(f"No mutation_scores.txt found at {ms_path}")

    candidates = [
        ("EXP_FRUST", seq_exp, vals_exp),
        ("AF_FRUST", seq_af, vals_af),
        ("B_FACTOR", seq_bf, vals_bf),
        ("EVOL", seq_ev, vals_ev),
        ("RMSF", seq_rmsf, vals_rmsf),
        ("MUTATION", seq_mut, vals_mut),
    ]
    data_sources = [entry for entry in candidates if len(entry[1]) > 0]
    return data_sources, seq_exp


def _load_secondary_structure(dir_path, protein_id):
    """Return the DSSP map for experimental_data/monomer.pdb, or {} if unavailable."""
    pdb_file_path = dir_path / "experimental_data" / "monomer.pdb"
    if not pdb_file_path.is_file():
        logger.warning(f"No PDB file found at {pdb_file_path}. Assigning 'n/a' to all residues.")
        return {}

    ss_map = parse_secondary_structure_from_pdb(pdb_file_path)
    if ss_map:
        logger.info(f"Secondary structure information extracted for protein {protein_id}.")
    else:
        logger.warning(f"Secondary structure information could not be extracted for protein {protein_id}. Assigning 'n/a' to all residues.")
    return ss_map


def _build_summary_frame(aligned, ss_map, seq_exp):
    """Turn the multiple alignment into the per-column summary DataFrame."""
    # The reference alignment is the first entry (the longest input sequence).
    residue_seq = aligned[0][1]
    final_len = len(residue_seq)
    logger.debug(f"Final alignment length: {final_len}")

    aligned_dict = {name: (aln_seq, aln_vals) for name, aln_seq, aln_vals in aligned}
    logger.debug(f"Available data types in alignment: {list(aligned_dict.keys())}")

    def column(name):
        # Aligned values for a source, or an all-'n/a' column if it is absent.
        return aligned_dict[name][1] if name in aligned_dict else ['n/a'] * final_len

    return pd.DataFrame({
        'AlnIndex': list(range(1, final_len + 1)),
        'Residue': list(residue_seq),
        'SecondaryStructure': map_ss_to_alignment(ss_map, seq_exp, residue_seq, aligned_dict),
        'B_Factor': column("B_FACTOR"),
        'ExpFrust': column("EXP_FRUST"),
        'AFFrust': column("AF_FRUST"),
        'EvolFrust': column("EVOL"),
        'RMSF': column("RMSF"),
        'Mutability': column("MUTATION"),
    })


def process_directory(root_directory, output_dir=None):
    """
    Build a per-residue summary CSV for every protein subdirectory of *root_directory*.

    For each immediate subdirectory with at least one recognised data source,
    the steps are:
      1. parse the available tracks (experimental / AlphaFold frustration,
         B-factor, evolutionary frustration, RMSF, mutability),
      2. align them with multi_align_sequences,
      3. attach DSSP secondary structure,
      4. write the result to ``<subdir>/summary.csv`` and to
         ``<summary dir>/summary_<subdir>.csv``.
    A failure in one subdirectory is logged with its traceback and does not
    stop the others.

    Parameters:
        root_directory (str | Path): Directory whose subdirectories are proteins.
        output_dir (str | Path | None): Destination for the central copies;
            defaults to ``root_directory/summary_data``. Missing parent
            directories are created.

    Returns:
        Path | None: The summary directory, or None if it could not be
        created or no valid subdirectories were found.
    """
    root_directory = Path(root_directory)
    summary_data_dir = Path(output_dir) if output_dir else root_directory / "summary_data"

    try:
        # parents=True so a nested output_dir whose parents do not exist yet works.
        summary_data_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Using summary directory: {summary_data_dir}")
    except Exception as e:
        logger.error(f"Failed to create summary directory: {e}")
        return

    dir_info = get_valid_directories(root_directory)
    if not dir_info:
        logger.error("No directories with valid data sources found.")
        return

    for dir_path, available_sources in dir_info:
        protein_id = dir_path.name
        logger.info(f"Processing: {dir_path}")

        try:
            data_sources, seq_exp = _collect_data_sources(dir_path, available_sources)
            ss_map = _load_secondary_structure(dir_path, protein_id)

            if not data_sources:
                logger.warning(f"No valid sequence data found in {dir_path}, skipping.")
                continue
            logger.debug(f"Found {len(data_sources)} valid data sources for alignment")

            aligned = multi_align_sequences(data_sources)
            if not aligned:
                logger.warning(f"Alignment failed for {dir_path}, skipping.")
                continue
            logger.debug(f"Successfully aligned {len(aligned)} sequences")

            df = _build_summary_frame(aligned, ss_map, seq_exp)

            out_path = dir_path / "summary.csv"
            df.to_csv(out_path, index=False)
            logger.info(f"Wrote CSV summary to {out_path}")

            summary_path = summary_data_dir / f"summary_{dir_path.name}.csv"
            df.to_csv(summary_path, index=False)
            logger.info(f"Wrote central CSV summary to {summary_path}")

        except Exception as e:
            # logger.exception keeps the traceback, which str(e) alone loses.
            logger.exception(f"Error processing {dir_path}: {str(e)}")

    return summary_data_dir

# -------------------------------------------------------------------------
# 8) Example Usage
# -------------------------------------------------------------------------
# Set root_dir to the directory holding one subdirectory per protein.
# An empty string would make process_directory() treat the current working
# directory as the root (Path("") == Path(".")) and create ./summary_data
# there, so the call is skipped until root_dir is filled in.
root_dir = ""
if root_dir:
    summary_dir = process_directory(root_dir)
else:
    summary_dir = None
    logger.warning("root_dir is empty; set it to your data directory to run process_directory().")

Generate violin plots of the Spearman correlation between each frustration type and B-factor across the set of proteins (20R)

In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import spearmanr, levene, bartlett, kruskal
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import combinations
from statsmodels.stats.multitest import multipletests

def compute_spearman(x, y):
    """
    Compute the Spearman rank correlation between two pandas Series.

    Rows where either value is missing are dropped pairwise before the
    correlation is computed.

    Parameters:
    - x, y: pandas Series of numeric values, aligned by index.

    Returns:
    - The Spearman correlation coefficient, or None when it cannot be
      computed (fewer than two paired observations, a constant input, or a
      NaN result).
    """
    mask = x.notna() & y.notna()
    if mask.sum() <= 1:
        return None

    x_valid = x[mask]
    y_valid = y[mask]

    # Spearman's rho is undefined when either input is constant.
    if x_valid.std() == 0 or y_valid.std() == 0:
        return None

    try:
        corr, _ = spearmanr(x_valid, y_valid)
    except (ValueError, TypeError):
        # Only computation errors are treated as "no correlation"; a bare
        # `except:` would also swallow KeyboardInterrupt/SystemExit.
        return None
    return None if np.isnan(corr) else corr

def process_data(data_dir):
    """
    Compute, per file, the Spearman correlation of B-factor with each frustration metric.

    Each ``.txt`` (tab-separated) or ``.csv`` (comma-separated) file in
    ``data_dir`` is treated as one protein. Files without a ``B_Factor``
    column are skipped with a message; files that raise while being read or
    processed are reported and skipped.

    Returns:
    - DataFrame with columns ``Protein`` (file name), ``Frustration_Type`` and
      ``Spearman_Correlation``: one row per file/metric pair for which
      ``compute_spearman`` returned a value.
    """
    metric_columns = ('ExpFrust', 'AFFrust', 'EvolFrust')
    rows = []

    for name in os.listdir(data_dir):
        if not name.endswith(('.txt', '.csv')):
            continue

        path = os.path.join(data_dir, name)
        try:
            delimiter = ',' if not name.endswith('.txt') else '\t'
            table = pd.read_csv(path, sep=delimiter, na_values=['n/a', 'N/A'])

            if 'B_Factor' not in table.columns:
                print(f"Skipping {name}: Missing 'B_Factor' column.")
                continue

            table['B_Factor'] = pd.to_numeric(table['B_Factor'], errors='coerce')

            for metric in metric_columns:
                if metric not in table.columns:
                    continue
                table[metric] = pd.to_numeric(table[metric], errors='coerce')
                rho = compute_spearman(table['B_Factor'], table[metric])
                if rho is None:
                    continue
                rows.append({
                    'Protein': name,
                    'Frustration_Type': metric,
                    'Spearman_Correlation': rho,
                })
        except Exception as err:
            print(f"Error processing {name}: {err}")

    return pd.DataFrame(rows)

def pairwise_levene_corrected(df, group_col, value_col, alpha=0.05, correction='bonferroni'):
    """
    Run Levene's test on every pair of groups and correct the p-values for multiple testing.

    Parameters:
    - df: pandas DataFrame containing the data.
    - group_col: Column name holding the group labels.
    - value_col: Column name holding the numeric values to compare.
    - alpha: Family-wise significance level passed to the correction (default 0.05).
    - correction: ``multipletests`` method name (default 'bonferroni').

    Returns:
    - DataFrame with one row per pair of groups and columns ``Group1``,
      ``Group2``, ``Statistic``, ``p-value``, ``Adjusted p-value`` and
      ``Reject H0``. With fewer than two groups the DataFrame is empty but
      still carries these columns.
    """
    columns = ['Group1', 'Group2', 'Statistic', 'p-value', 'Adjusted p-value', 'Reject H0']
    groups = df[group_col].unique()

    results = []
    for group1, group2 in combinations(groups, 2):
        data1 = df.loc[df[group_col] == group1, value_col].dropna()
        data2 = df.loc[df[group_col] == group2, value_col].dropna()
        stat, p = levene(data1, data2)
        results.append({'Group1': group1, 'Group2': group2, 'Statistic': stat, 'p-value': p})

    # multipletests raises on an empty p-value list, so return early when
    # there is no pair to compare.
    if not results:
        return pd.DataFrame(columns=columns)

    reject, adjusted_pvals = multipletests(
        [row['p-value'] for row in results], alpha=alpha, method=correction
    )[:2]
    for row, p_adj, rej in zip(results, adjusted_pvals, reject):
        row['Adjusted p-value'] = p_adj
        row['Reject H0'] = rej

    return pd.DataFrame(results, columns=columns)

def add_annotations(ax, pairwise_df, order, value_col, significance_level=0.05):
    """
    Draw bracket annotations with vertical p-value labels to the right of a horizontal violin plot.

    Three fixed comparisons are annotated:
      - AFFrust vs ExpFrust and ExpFrust vs EvolFrust: two short brackets at
        the same x position, separated by a small vertical gap where they
        meet at ExpFrust.
      - AFFrust vs EvolFrust: one longer bracket further to the right.

    Parameters:
    - ax: Matplotlib Axes holding the violin plot (categories on the y axis).
    - pairwise_df: DataFrame with ``Group1``, ``Group2`` and
      ``Adjusted p-value`` columns (e.g. from ``pairwise_levene_corrected``).
    - order: Category order on the y axis; category i is assumed to sit at
      y = i. Must contain 'AFFrust', 'ExpFrust' and 'EvolFrust'.
    - value_col: Unused; kept for interface compatibility.
    - significance_level: Adjusted p-values below this get a trailing '*'.
    """
    def get_pvalue(group1, group2, df):
        """Return the adjusted p-value for an unordered pair of groups, or None if absent."""
        mask = ((df['Group1'] == group1) & (df['Group2'] == group2)) | \
               ((df['Group1'] == group2) & (df['Group2'] == group1))
        matched = df[mask]['Adjusted p-value']
        if len(matched) == 0:
            print(f"Warning: No p-value found for comparison between {group1} and {group2}")
            return None
        return matched.values[0]

    def draw_bracket(group1, group2, x_start, x_end, y_first, y_second):
        """Draw a bracket between y_first and y_second and label it with the pair's p-value."""
        ax.plot([x_start, x_end], [y_first, y_first], lw=1.5, c='black')
        ax.plot([x_start, x_end], [y_second, y_second], lw=1.5, c='black')
        ax.plot([x_end, x_end], [y_first, y_second], lw=1.5, c='black')

        p_value = get_pvalue(group1, group2, pairwise_df)
        if p_value is not None:
            asterisk = '*' if p_value < significance_level else ''
            # Centre the label between the two category positions, not the
            # (possibly gap-shortened) bracket ends.
            y_mid = (y_positions[group1] + y_positions[group2]) / 2
            ax.text(x_end + 0.02, y_mid, f'p = {p_value:.3f}{asterisk}',
                    ha='left', va='center', fontsize=10, color='black', rotation=270)

    # y position of each category, following the plotting order
    y_positions = {group: idx for idx, group in enumerate(order)}

    # Brackets sit just beyond the current right x-limit (data coordinates).
    max_x = ax.get_xlim()[1]
    y_spacing = 0.1  # vertical gap where the two short brackets meet
    shared_x_start = max_x
    shared_x_end = max_x + 0.03
    long_x_start = max_x + 0.08
    long_x_end = max_x + 0.11

    # Top short bracket (AFFrust to ExpFrust), stopping short of ExpFrust
    draw_bracket('AFFrust', 'ExpFrust', shared_x_start, shared_x_end,
                 y_positions['AFFrust'], y_positions['ExpFrust'] - y_spacing)
    # Bottom short bracket (ExpFrust to EvolFrust), starting just past ExpFrust
    draw_bracket('ExpFrust', 'EvolFrust', shared_x_start, shared_x_end,
                 y_positions['ExpFrust'] + y_spacing, y_positions['EvolFrust'])
    # Long bracket (AFFrust to EvolFrust)
    draw_bracket('AFFrust', 'EvolFrust', long_x_start, long_x_end,
                 y_positions['AFFrust'], y_positions['EvolFrust'])

    # Leave room on the right of this Axes' figure for the rotated labels.
    ax.figure.subplots_adjust(right=0.85)

def create_violin_plot(df, pairwise_results, kruskal_stat, kruskal_p, group_col='Frustration_Type', value_col='Spearman_Correlation'):
    """
    Create a horizontal violin plot of per-protein Spearman correlations for each frustration type.

    Besides the violins, the figure gets:
      - a dotted vertical line at each group's mean (behind the violins),
      - pairwise p-value brackets via ``add_annotations``,
      - a footnote describing the pairwise Levene/Bonferroni test,
      - a Kruskal-Wallis summary in the top-left corner of the figure.

    Parameters:
    - df: DataFrame with one row per (protein, frustration type), holding
      ``group_col`` and ``value_col``.
    - pairwise_results: DataFrame of pairwise comparisons with
      ``Group1``/``Group2``/``Adjusted p-value`` columns.
    - kruskal_stat, kruskal_p: Kruskal-Wallis H statistic and p-value.
    - group_col: Column holding the frustration type ('AFFrust', 'ExpFrust', 'EvolFrust').
    - value_col: Column holding the Spearman correlation values.

    Returns:
    - The ``matplotlib.pyplot`` module itself (not the Figure); the caller
      invokes ``.show()`` on it.
    """
    # Set up the figure with high DPI for publication quality
    fig, ax = plt.subplots(figsize=(10, 6), dpi=600)
    
    # Set publication-quality style
    # NOTE(review): style and context are applied after fig/ax already exist,
    # so rcParams read at Axes creation may not affect `ax`. Both calls also
    # change global state for later figures. Confirm this ordering is intended.
    plt.style.use('seaborn-v0_8-whitegrid')
    sns.set_context("paper", font_scale=1.5)
    
    # Category order along the y axis (add_annotations assumes category i is
    # drawn at y = i). NOTE(review): seaborn usually inverts a categorical
    # y axis, so the first entry may appear at the top — confirm.
    order = ['AFFrust', 'ExpFrust', 'EvolFrust']
    
    # Fixed colour per frustration type
    palette = {'AFFrust': '#377eb8',    # Blue (AlphaFold)
              'ExpFrust': '#e41a1c',    # Red (Experimental)
              'EvolFrust': '#4daf4a'}   # Green (Evolutionary)
    
    # Human-readable y tick labels
    labels = {'AFFrust': 'AlphaFold',
             'ExpFrust': 'Experimental',
             'EvolFrust': 'Evolutionary'}
    
    # Dashed x-axis gridlines at the lowest zorder so they sit behind everything
    ax.grid(True, axis='x', linestyle='--', alpha=0.7, zorder=0)
    
    # Dotted vertical line at each group's mean (zorder 1, behind the violins)
    for frust_type in order:
        mean_val = df[df[group_col] == frust_type][value_col].mean()
        ax.axvline(x=mean_val, ymin=0, ymax=1, color=palette[frust_type], 
                  linestyle=':', linewidth=2, alpha=0.8, zorder=1)
    
    # Violin plot with inner box plots (zorder 2, in front of the mean lines)
    sns.violinplot(
        data=df,
        x=value_col,
        y=group_col,
        order=order,
        inner='box',
        palette=palette,
        linewidth=1.5,
        zorder=2,
        ax=ax
    )
    
    # Title
    ax.set_title('Distribution of Spearman Correlations\nby Frustration Type', 
                pad=20, fontsize=16, fontweight='bold')
    
    # Axis labels
    ax.set_xlabel('Spearman Correlation Between Frustration and B-factor', labelpad=15, fontsize=14, fontweight='bold')
    ax.set_ylabel('Frustration Metric', labelpad=15, fontsize=14, fontweight='bold')
    
    # Tick label size
    ax.tick_params(axis='both', which='major', labelsize=12)
    
    # Replace internal column names on the y axis with readable labels (same order)
    ax.set_yticklabels([labels[tick] for tick in order])
    
    # Thicker frame
    for spine in ax.spines.values():
        spine.set_linewidth(1.5)
    
    # tight_layout first, then explicit margins; subplots_adjust overrides the
    # margins it sets, leaving room for the annotations outside the plot area.
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15, left=0.15, right=0.85, top=0.85)
    
    # Pairwise p-value brackets to the right of the violins
    add_annotations(ax, pairwise_results, order, value_col)
    
    # Single-line footnote in the bottom-left corner naming the pairwise test
    fig.text(0.01, 0.01, 
                'p-values are from pairwise Levene tests for equality of variance with Bonferroni correction.',
                ha='left', va='bottom', fontsize=8, style='italic')
    
    # Kruskal-Wallis summary with a plain-language verdict at alpha = 0.05
    if kruskal_p < 0.05:
        significance_comment = "Distributions are significantly different."
    else:
        significance_comment = "No significant differences in distributions."
    
    kruskal_text = f'Kruskal-Wallis Test:\nH = {kruskal_stat:.2f}, p = {kruskal_p:.3e}\n{significance_comment}'
    # NOTE(review): alpha=0.0 makes the whole bbox (face and edge) transparent,
    # so linewidth=1.5 has no visible effect — confirm the box is meant to be hidden.
    props = dict(boxstyle='round', facecolor='white', alpha=0.0, linewidth=1.5)
    
    # Kruskal-Wallis summary anchored at the top-left corner of the figure
    fig.text(0.00, 0.95, kruskal_text, fontsize=12,
             verticalalignment='top', bbox=props, ha='left')
    
    return plt

def test_variance_equality(df):
    """
    Perform Levene's test to assess the equality of variances across frustration types.
    """
    groups = [group['Spearman_Correlation'].dropna() for name, group in df.groupby('Frustration_Type')]
    
    stat, p_value = levene(*groups)
    
    print("Levene’s Test for Equality of Variances")
    print(f"Statistic: {stat:.4f}, p-value: {p_value:.4e}")
    
    if p_value < 0.05:
        print("Result: Significant differences in variances (reject H0)")
    else:
        print("Result: No significant differences in variances (fail to reject H0)")
    
def test_bartlett_variance(df):
    """
    Perform Bartlett's test to assess the equality of variances across frustration types.
    """
    groups = [group['Spearman_Correlation'].dropna() for name, group in df.groupby('Frustration_Type')]
    
    stat, p_value = bartlett(*groups)
    
    print("Bartlett’s Test for Equality of Variances")
    print(f"Statistic: {stat:.4f}, p-value: {p_value:.4e}")
    
    if p_value < 0.05:
        print("Result: Significant differences in variances (reject H0)")
    else:
        print("Result: No significant differences in variances (fail to reject H0)")
    
def test_distribution_equality(df):
    """
    Perform Kruskal-Wallis test to assess the equality of distributions across frustration types.
    Returns the test statistic and p-value.
    """
    groups = [group['Spearman_Correlation'].dropna() for name, group in df.groupby('Frustration_Type')]
    
    stat, p_value = kruskal(*groups)
    
    print("Kruskal-Wallis H-Test for Equality of Distributions")
    print(f"Statistic: {stat:.4f}, p-value: {p_value:.4e}")
    
    if p_value < 0.05:
        print("Result: Significant differences in distributions (reject H0)")
    else:
        print("Result: No significant differences in distributions (fail to reject H0)")
    
    return stat, p_value

def main(data_dir=None):
    """
    Run the B-factor vs. frustration correlation analysis end to end.

    Steps:
      1. Compute per-protein Spearman correlations with ``process_data``.
      2. Print summary statistics.
      3. Run Levene, Bartlett and Kruskal-Wallis tests across frustration types.
      4. Run pairwise Levene tests with Bonferroni correction.
      5. Show the annotated violin plot.

    Parameters:
    - data_dir: Optional directory of summary files. When omitted, the
      ``DATA_DIR`` placeholder below is used. If the directory does not exist
      (including the empty placeholder), a message is printed and the
      function returns without running the analysis.
    """
    # **Set your data directory** (used when no data_dir argument is given)
    DATA_DIR = ''
    if data_dir is not None:
        DATA_DIR = data_dir

    # process_data calls os.listdir, which raises FileNotFoundError for '' or a
    # missing path; report the misconfiguration instead of crashing.
    if not DATA_DIR or not os.path.isdir(DATA_DIR):
        print(f"Data directory not found: {DATA_DIR!r}. "
              "Set DATA_DIR (or pass data_dir) to a folder of summary files.")
        return
    
    # Process the data
    results_df = process_data(DATA_DIR)
    
    if results_df.empty:
        print("No valid data found in the specified directory.")
        return
    
    # Print summary statistics
    print("\nSummary Statistics:")
    summary_stats = results_df.groupby('Frustration_Type')['Spearman_Correlation'].agg(['mean', 'std', 'count'])
    print(summary_stats)
    
    # Perform Levene's Test for equality of variances
    print("\n--- Variance Equality Test (Levene's Test) ---")
    test_variance_equality(results_df)
    
    # Perform Bartlett's Test for equality of variances
    print("\n--- Variance Equality Test (Bartlett's Test) ---")
    test_bartlett_variance(results_df)
    
    # Perform Kruskal-Wallis Test for equality of distributions
    print("\n--- Distribution Equality Test (Kruskal-Wallis Test) ---")
    kruskal_stat, kruskal_p = test_distribution_equality(results_df)
    
    # Perform pairwise Levene's Tests with Bonferroni correction
    print("\n--- Pairwise Variance Equality Tests (Levene's) with Bonferroni Correction ---")
    pairwise_results = pairwise_levene_corrected(
        df=results_df,
        group_col='Frustration_Type',
        value_col='Spearman_Correlation',
        alpha=0.05,
        correction='bonferroni'
    )
    print(pairwise_results)
    
    # Create and display the violin plot with annotations
    plt_obj = create_violin_plot(results_df, pairwise_results, kruskal_stat, kruskal_p)
    plt_obj.show()

# Entry point. Inside a Jupyter kernel ``__name__`` is "__main__", so running
# this cell executes the full analysis via main().
if __name__ == "__main__":
    main()

Script to generate the single example plot used in Figure 2.

In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import spearmanr, linregress
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import matplotlib.gridspec as gridspec
from matplotlib.ticker import MaxNLocator 

# Global theme for the Figure 2 panels: seaborn "whitegrid" style in the
# "talk" context with the "deep" palette, followed by explicit rcParams overrides.
sns.set(style="whitegrid", context="talk", palette="deep")

# Canvas: high resolution and a large default figure size.
plt.rcParams.update({
    'figure.dpi': 600,
    'figure.figsize': (20, 15),
})

# Typography: base font size plus axis, tick and legend text sizes.
plt.rcParams.update({
    'font.size': 25,
    'axes.labelsize': 18,
    'axes.titlesize': 18,
    'legend.fontsize': 12,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
})

# Line weights for the axes frame, plotted lines and grid.
plt.rcParams.update({
    'axes.linewidth': 1.5,
    'lines.linewidth': 2.5,
    'grid.linewidth': 1.0,
})


def lowess_smoothing(x, y, frac=0.1, it=3):
    """
    Smooth ``y`` against ``x`` with statsmodels' LOWESS.

    Parameters:
        x (array-like): Independent variable.
        y (array-like): Dependent variable to smooth.
        frac (float): Fraction of the data used for each local fit.
        it (int): Number of robustifying iterations.

    Returns:
        np.ndarray: Smoothed ``y`` values in the same order as the inputs
        (``return_sorted=False``).
    """
    return sm.nonparametric.lowess(y, x, frac=frac, it=it, return_sorted=False)

def parse_summary_file(file_path, window_size=5, frac=0.1, it=3):
    """
    Load a per-protein summary file and prepare it for plotting.

    The separator is chosen from the file extension: ``.csv`` files are read
    comma-separated; anything else (e.g. older ``.txt`` summaries) is read
    tab-separated.

    Parameters:
        file_path (str): Path to the summary file.
        window_size (int): Unused; kept for backward compatibility.
        frac (float): LOWESS fraction passed to ``lowess_smoothing``.
        it (int): LOWESS robustifying iterations passed to ``lowess_smoothing``.

    Returns:
        tuple: ``(df_original, df_for_plot, corrs)``.
          - ``df_original``: the metric columns as floats (non-numeric entries
            become NaN).
          - ``df_for_plot``: LOWESS-smoothed metrics, with B_Factor min-max
            normalised to [0, 1].
          - ``corrs``: maps ``("B_Factor", metric)`` to Spearman
            ``(rho, pval)``. It is computed only on rows where all four
            metrics are present, and is empty if there are none.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file cannot be parsed or required columns are missing.
    """
    required_cols = ["AlnIndex", "Residue", "SecondaryStructure", "B_Factor", "ExpFrust", "AFFrust", "EvolFrust"]
    metric_cols = ["B_Factor", "ExpFrust", "AFFrust", "EvolFrust"]

    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # Summaries may be comma-separated (.csv) or tab-separated (.txt); reading
    # a CSV with sep='\t' yields a single column and a misleading
    # "Missing required columns" error.
    sep = ',' if str(file_path).lower().endswith('.csv') else '\t'
    try:
        df = pd.read_csv(file_path, sep=sep)
    except Exception as e:
        raise ValueError(f"Failed to parse data from {file_path}. Error: {e}") from e

    missing = set(required_cols) - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # Convert 'n/a' (and any other non-numeric entry) to NaN. Cast to float so
    # the smoothed float values can be written back into columns that were
    # read as integers.
    for col in metric_cols:
        df[col] = pd.to_numeric(df[col].replace('n/a', np.nan), errors='coerce').astype(float)

    df_original = df.copy()
    df_for_plot = df.copy()

    # Apply LOWESS smoothing to each metric over the rows where it is defined
    for col in metric_cols:
        x = df_for_plot["AlnIndex"].values
        y = df_for_plot[col].values
        mask = ~np.isnan(x) & ~np.isnan(y)
        if mask.sum() > 0:
            y_smooth = lowess_smoothing(x[mask], y[mask], frac=frac, it=it)
            df_for_plot.loc[mask, col] = y_smooth
        else:
            df_for_plot[col] = np.nan

    # Only B-Factor is min-max normalised (a constant column becomes 0.0)
    valid = ~df_for_plot['B_Factor'].isna()
    if valid.any():
        col_min = df_for_plot.loc[valid, 'B_Factor'].min()
        col_max = df_for_plot.loc[valid, 'B_Factor'].max()
        if col_max > col_min:
            df_for_plot['B_Factor'] = (df_for_plot['B_Factor'] - col_min) / (col_max - col_min)
        else:
            df_for_plot['B_Factor'] = 0.0

    # Spearman correlations on the unsmoothed data, using listwise deletion
    # (rows with all four metrics present)
    corrs = {}
    sub = df_original.dropna(subset=metric_cols)
    if not sub.empty:
        combos = [
            ("B_Factor", "ExpFrust"),
            ("B_Factor", "AFFrust"),
            ("B_Factor", "EvolFrust")
        ]
        for (mA, mB) in combos:
            if sub[mA].nunique() < 2 or sub[mB].nunique() < 2:
                rho, pval = np.nan, np.nan
            else:
                rho, pval = spearmanr(sub[mA], sub[mB])
            corrs[(mA, mB)] = (rho, pval)

    return df_original, df_for_plot, corrs

def create_seaborn_figure(df_original, df_plot, corrs):
    """
    Build the single-protein example figure: a smoothed profile plus rank scatters.

    Layout is a 2 x 3 GridSpec with height ratios 3:2.

    1) Top row (full width): B_Factor, ExpFrust, AFFrust and EvolFrust from
       ``df_plot`` against AlnIndex. Each secondary-structure run gets
       full-height background shading, and helix / arrow / line glyphs are
       drawn in a band at the top of the axis.
    2) Bottom row: one scatter per frustration metric of B_Factor rank vs.
       metric rank, from ``df_original``. Each panel has:
       - a dashed least-squares line through the ranks (only when the
         B-factor ranks vary);
       - the title "B_Factor Rank vs <metric> Rank";
       - the Spearman rho / p-value beneath the title, in the metric's colour.

    Parameters:
      df_original: Unsmoothed per-residue data. Rows missing B_Factor or the
                   metric are dropped for each scatter.
      df_plot: Smoothed data for the top panel. A ``SecondaryStructure``
               column ('A' helix, 'B' strand, anything else "other") is
               optional.
      corrs: Not used by this function (correlations are recomputed per
             scatter with pairwise deletion); kept for interface compatibility.

    Returns:
      The Matplotlib Figure.
    """

    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib import gridspec
    from matplotlib.ticker import MaxNLocator
    from scipy.stats import spearmanr, linregress

    ########################################
    # HELPERS: SECONDARY-STRUCTURE GLYPHS AND RUNS
    ########################################
    def create_helix(x_start, width, height=0.25, frequency=2):
        """Return (x, y) tracing a sine 'helix' from x_start to x_start + width."""
        if width <= 0:
            # A run that begins on the final residue has zero width. The phase
            # term below would be 0/0, giving a NaN path (which draws nothing)
            # plus a RuntimeWarning, so return an empty path instead.
            return np.array([]), np.array([])
        num_points = max(int(width * 20), 2)
        x = np.linspace(x_start, x_start + width, num_points)
        y = height * np.sin(2 * np.pi * frequency * (x - x_start) / width)
        return x, y

    def create_arrow(x_start, width, height=0.25):
        """Return the (x, y) outline of a beta-strand arrow from x_start to x_start + width."""
        x = [
            x_start, x_start,
            x_start, x_start + 0.7 * width,
            x_start + 0.7 * width, x_start + width,
            x_start + 0.7 * width,
            x_start + 0.7 * width, x_start,
            x_start
        ]
        y = [
            -height / 2,  height / 2,
            -height / 2, -height / 2,
            -height / 2,  0,
             height / 2,
             height / 2,  height / 2,
            -height / 2
        ]
        return np.array(x), np.array(y)

    def iter_ss_runs(df_sorted):
        """
        Yield (ss_code, start_x, end_x) for each run of consecutive rows whose
        SecondaryStructure value compares equal, in AlnIndex order.

        A run ends at the AlnIndex where the next run starts; the last run
        ends at the final AlnIndex.
        """
        prev_ss = None
        start_x = None
        ss_values = df_sorted['SecondaryStructure'].to_numpy()
        x_values = df_sorted['AlnIndex'].to_numpy()
        for current_ss, current_x in zip(ss_values, x_values):
            if current_ss != prev_ss:
                if prev_ss is not None and start_x is not None:
                    yield prev_ss, start_x, current_x
                start_x = current_x
                prev_ss = current_ss
        if prev_ss is not None and start_x is not None:
            yield prev_ss, start_x, df_sorted['AlnIndex'].iloc[-1]

    ########################################
    # CREATE FIGURE WITH GRIDSPEC
    ########################################
    fig = plt.figure(constrained_layout=True)
    gs = gridspec.GridSpec(2, 3, figure=fig, height_ratios=[3, 2])

    # Top row (full width) -> main axis
    ax_main = fig.add_subplot(gs[0, :])

    ########################################
    # (A) PLOT SMOOTHED LINES
    ########################################
    metrics = ["B_Factor", "ExpFrust", "AFFrust", "EvolFrust"]
    colors = {
        "B_Factor": "#FF7F00",
        "ExpFrust": "#E41A1C",
        "AFFrust": "#377EB8",
        "EvolFrust": "#4DAF4A"
    }

    for metric in metrics:
        label = metric if metric != "B_Factor" else "B-Factor (Normalized)"
        sns.lineplot(
            x="AlnIndex",
            y=metric,
            data=df_plot,
            label=label,
            color=colors.get(metric, "black"),
            ax=ax_main,
            zorder=3
        )

    # Draw once so Matplotlib settles the autoscaled limits, then add headroom
    # at the top for the secondary-structure band.
    fig.canvas.draw()
    y_min, y_max = ax_main.get_ylim()
    extra_top_space = 0.75
    ax_main.set_ylim(y_min, y_max + extra_top_space)
    # Re-fetch final y-limits after expanding top
    y_min, y_max = ax_main.get_ylim()

    # Secondary-structure runs (computed once, used for shading and glyphs)
    if 'SecondaryStructure' in df_plot.columns:
        df_plot_sorted = df_plot.sort_values(by='AlnIndex').reset_index(drop=True)
        ss_runs = list(iter_ss_runs(df_plot_sorted))
    else:
        ss_runs = []

    ########################################
    # (B) BACKGROUND HIGHLIGHT BY SECONDARY STRUCTURE
    ########################################
    ss_colors = {
        'A': ('#800080', 0.1),  # alpha-helix
        'B': ('#008080', 0.1),  # beta-sheet
        'O': ('#808080', 0.1)   # other
    }
    for ss, run_start, run_end in ss_runs:
        if ss not in ss_colors:
            continue
        color, alpha = ss_colors[ss]
        ax_main.add_patch(
            plt.Rectangle(
                (run_start, y_min),
                run_end - run_start,
                y_max - y_min,  # full vertical extent
                facecolor=color,
                alpha=alpha,
                zorder=1
            )
        )

    ########################################
    # (C) SECONDARY STRUCTURE SHAPES (TOP BAND)
    ########################################
    box_height = 0.6
    box_bottom = y_max - box_height
    x_left, x_right = ax_main.get_xlim()
    ax_main.add_patch(plt.Rectangle(
        (x_left, box_bottom),
        x_right - x_left,
        box_height,
        facecolor='#f5f5f5',
        edgecolor='#d3d3d3',
        alpha=0.5,
        zorder=2
    ))

    y_pos = box_bottom + box_height / 2  # midline for shapes

    for ss, run_start, run_end in ss_runs:
        width = run_end - run_start
        if ss == 'A':  # alpha-helix
            x_helix, y_helix = create_helix(run_start, width)
            ax_main.plot(x_helix, y_helix + y_pos,
                         color='#800080', linewidth=2, zorder=3)
        elif ss == 'B':  # beta-strand
            x_arrow, y_arrow = create_arrow(run_start, width)
            ax_main.plot(x_arrow, y_arrow + y_pos,
                         color='#008080', linewidth=2, zorder=3)
        else:  # "other"
            ax_main.plot([run_start, run_start + width],
                         [y_pos, y_pos],
                         color='#808080', linewidth=2, zorder=3)

    ########################################
    # (D) MAIN PLOT LABELS & LEGENDS
    ########################################
    ax_main.set_title("Smoothed Frustration and B-Factor vs Residue Index", 
                      fontsize=24, fontweight='bold', pad=20)
    ax_main.set_xlabel("Residue Index", fontsize=20, fontweight='bold')
    ax_main.set_ylabel("Frustration & Normalized B-Factor", fontsize=20, fontweight='bold')

    metrics_legend = ax_main.legend(title="Metrics", fontsize=14, title_fontsize=16, 
                                    loc='lower left', frameon=True)
    frame = metrics_legend.get_frame()
    frame.set_facecolor('white')
    frame.set_edgecolor('black')
    frame.set_alpha(1)
    metrics_legend.set_zorder(10)

    ss_legend_elements = [
        plt.Line2D([0], [0], color='#800080', linewidth=2, label='α-helix'),
        plt.Line2D([0], [0], color='#008080', linewidth=2, label='β-sheet'),
        plt.Line2D([0], [0], color='#808080', linewidth=2, label='other')
    ]
    ss_legend = ax_main.legend(handles=ss_legend_elements, loc='lower right',
                               title='Secondary Structure', fontsize=14, title_fontsize=16,
                               frameon=True)
    ss_frame = ss_legend.get_frame()
    ss_frame.set_facecolor('white')
    ss_frame.set_edgecolor('black')
    ss_frame.set_alpha(1)
    ss_legend.set_zorder(10)

    # The second legend() call replaced the first; re-add it so both are visible
    ax_main.add_artist(metrics_legend)

    ########################################
    # (E) SCATTER PLOTS (BOTTOM ROW)
    ########################################
    scatter_metrics = ["ExpFrust", "AFFrust", "EvolFrust"]
    scatter_colors = {
        "ExpFrust": "#E41A1C",
        "AFFrust": "#377EB8",
        "EvolFrust": "#4DAF4A"
    }

    for i, metric in enumerate(scatter_metrics):
        ax = fig.add_subplot(gs[1, i])

        sub = df_original.dropna(subset=["B_Factor", metric])
        if sub.empty:
            ax.text(0.5, 0.5, "No Data Available",
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=ax.transAxes, 
                    fontsize=14, color='red')
            continue

        sub = sub.copy()
        sub['B_Factor_Rank'] = sub['B_Factor'].rank(method='average')
        sub[f'{metric}_Rank'] = sub[metric].rank(method='average')

        # Spearman correlation on the raw values (rows with both present)
        rho, pval = spearmanr(sub['B_Factor'], sub[metric])

        sns.scatterplot(
            x='B_Factor_Rank',
            y=f'{metric}_Rank',
            data=sub,
            color=scatter_colors.get(metric, "black"),
            alpha=0.7,
            ax=ax
        )

        # Dashed least-squares line through the ranks. linregress raises
        # ValueError when every x value is identical (e.g. constant B-factor)
        # and is undefined for a single point, so only fit when the ranks vary.
        if sub['B_Factor_Rank'].nunique() > 1:
            slope, intercept, *_ = linregress(
                sub['B_Factor_Rank'], sub[f'{metric}_Rank']
            )
            x_vals = np.array(ax.get_xlim())
            y_vals = intercept + slope * x_vals
            ax.plot(x_vals, y_vals, '--', color='gray', linewidth=2)

        # One-line title; the large pad leaves room for the correlation text
        main_title = f"B_Factor Rank vs {metric} Rank"
        ax.set_title(main_title, fontsize=18, fontweight='bold', pad=35)

        # Spearman correlation text between the title and the plot area
        corr_text = f"Spearman ρ = {rho:.3f} (p={pval:.2e})"
        ax.text(
            0.5,   # x in axes coords
            1.05,  # y in axes coords (just above the plot area)
            corr_text,
            transform=ax.transAxes,
            ha='center',
            va='top',
            color=scatter_colors[metric],  # match metric color
            fontsize=16,
            fontweight='bold',
            clip_on=False,
            zorder=5
        )

        ax.set_xlabel("B-Factor Rank", fontsize=16, fontweight='bold')
        ax.set_ylabel(f"{metric} Rank", fontsize=16, fontweight='bold')
        ax.xaxis.set_major_locator(MaxNLocator(nbins=5))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=5))

    ########################################
    # (F) FORCE THE X-LIMITS TO THE DATA RANGE
    ########################################
    data_x_min = df_plot["AlnIndex"].min()
    data_x_max = df_plot["AlnIndex"].max()
    ax_main.set_xlim(data_x_min, data_x_max)

    return fig

# Specify the path to your summary file
summary_file_path = ""

# Example:
# summary_file_path = "data/summary_test001.txt"

# **Parse and Process the Data**


# Define the LOWESS parameters
window_size = 5   # Not used in LOWESS but kept for compatibility
frac = 0.1        # The fraction of the data used when estimating each y-value
it = 3            # The number of robustifying iterations

# Reset results before parsing. In a notebook, a failed parse would otherwise
# leave df_original/df_plot from a previous run in scope, and the stale data
# would be plotted below.
df_original, df_plot, corrs = None, None, None

# Parse the summary file
try:
    df_original, df_plot, corrs = parse_summary_file(summary_file_path, window_size=window_size, frac=frac, it=it)
    print("Data parsing and processing completed successfully.")
except Exception as e:
    print(f"An error occurred: {e}")


# Plot only if this run's parse succeeded and produced rows
if df_original is not None and not df_original.empty:
    fig = create_seaborn_figure(df_original, df_plot, corrs)
    plt.show()
else:
    print("No data available to plot.")


#**(Optional) Save the Figure**

# Uncomment the lines below to save the figure as a PNG file
# output_image_path = "path/to/save/figure.png"
# fig.savefig(output_image_path, dpi=300)
# print(f"Figure saved to {output_image_path}")

Script used to generate Supplemental Figures S1-S20, Figure 3, and Figure 5.

In [None]:
import os  
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
from scipy.stats import spearmanr
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
import matplotlib.gridspec as gridspec
from matplotlib.backends.backend_pdf import PdfPages  
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from scipy.stats import wilcoxon

########################################
# GLOBAL MAPPINGS AND HELPER FUNCTIONS #
########################################

# Translate the internal metric keys that appear at the start of legend labels
# into the human-readable names shown on the figures.
DISPLAY_MAP = dict(
    ExpFrust_Experimental='Experimental Frustration',
    ExpFrust_AlphaFold='AlphaFold Frustration',
    EvolFrust='Evolutionary Frustration',
)

def remap_legend(ax, mapping, **legend_kwargs):
    """
    Rebuild the legend of `ax` with its labels translated through `mapping`.

    The first whitespace-delimited word of each label is looked up in
    `mapping`. Words without an entry are kept unchanged. Everything after that
    first word, such as " (AUC=0.82)", is appended verbatim. The existing
    handles are reused, and any extra keyword arguments go to `ax.legend`.
    """
    handles, old_labels = ax.get_legend_handles_labels()

    def _translate(label):
        head = label.split()[0]
        tail = label[len(head):]
        return mapping.get(head, head) + tail

    translated = [_translate(label) for label in old_labels]
    ax.legend(handles, translated, **legend_kwargs)

########################################
# 1) BASIC SETUP AND HELPER FUNCTIONS  #
########################################

def read_frustration_file(filepath, file_type='summary', sep=','):
    """
    Read and process frustration data from a summary file in the new format.

    Summary file format: comma-separated by default; pass ``sep="\\t"`` for
    tab-separated files. The string 'n/a' is read as missing.
        AlnIndex, Residue, SecondaryStructure, B_Factor, ExpFrust, AFFrust, EvolFrust, ...

    For plotting purposes:
      - Experimental Frustration is taken from the ExpFrust column.
      - AlphaFold Frustration is taken from the AFFrust column and renamed to
        ExpFrust, so both frames share one schema.
      - Evolutionary Frustration is taken from the EvolFrust column.

    Both frustration DataFrames carry the same B_Factor column.

    Parameters
    ----------
    filepath : str or path-like
        Path to the summary file.
    file_type : str
        Only 'summary' is supported.
    sep : str
        Field delimiter passed to ``pandas.read_csv`` (default ',').

    Returns
    -------
    (exp_df, af_df, evol_frust)
        ``exp_df`` and ``af_df`` contain the columns AlnIndex, Residue,
        SecondaryStructure, B_Factor and ExpFrust. Columns missing from the
        file are appended after the present ones and filled with NaN.
        B_Factor and ExpFrust are coerced to numeric, with unparsable values
        becoming NaN. ``evol_frust`` is a numeric Series aligned with the
        rows of the file; it is all NaN if the file has no EvolFrust column.

    Raises
    ------
    ValueError
        If ``file_type`` is not 'summary'.
    """
    if file_type != 'summary':
        raise ValueError("Unsupported file type. Only 'summary' is supported.")

    df = pd.read_csv(filepath, sep=sep, na_values=['n/a'])

    def _select(column_map):
        # Keep the source columns that exist (in column_map order) and rename
        # them to the shared schema. Then add any absent target column as NaN.
        present = [src for src in column_map if src in df.columns]
        out = df[present].rename(columns={src: column_map[src] for src in present})
        for target in column_map.values():
            if target not in out.columns:
                out[target] = np.nan
        # Only the B-factor and the frustration value are needed as numbers.
        for col in ('B_Factor', 'ExpFrust'):
            out[col] = pd.to_numeric(out[col], errors='coerce')
        return out

    shared_columns = {
        'AlnIndex': 'AlnIndex',
        'Residue': 'Residue',
        'SecondaryStructure': 'SecondaryStructure',
        'B_Factor': 'B_Factor',
    }
    exp_df = _select({**shared_columns, 'ExpFrust': 'ExpFrust'})
    # AlphaFold frustration is renamed to ExpFrust so both frames share one schema.
    af_df = _select({**shared_columns, 'AFFrust': 'ExpFrust'})

    if 'EvolFrust' in df.columns:
        evol_frust = pd.to_numeric(df['EvolFrust'], errors='coerce')
    else:
        # Same index as the file's rows so it can be attached by label.
        evol_frust = pd.Series(np.nan, index=df.index, dtype=float)

    return exp_df, af_df, evol_frust

def lowess_smoothing(x, y, frac=0.1, it=3):
    """
    Apply LOWESS smoothing to ``y`` as a function of ``x``.

    Pairs where either ``x`` or ``y`` is missing are dropped before fitting.

    Parameters
    ----------
    x, y : pandas Series or numpy arrays of equal length
        Abscissa and the values to smooth.
    frac : float
        Fraction of the data used when estimating each fitted value.
    it : int
        Number of robustifying iterations.

    Returns
    -------
    (x_clean, z)
        ``x_clean`` is ``x`` with the incomplete pairs removed and keeps the
        container type of ``x``, so a Series keeps its index. ``z`` is a numpy
        array of fitted values. Because ``return_sorted=False`` is used, ``z``
        is in the same order as ``x_clean``. If no complete pairs remain,
        ``x_clean`` is the empty filtered ``x`` and ``z`` is an empty array.
    """
    mask = ~(pd.isna(x) | pd.isna(y))
    x_clean = x[mask]
    y_clean = y[mask]
    if len(x_clean) == 0:
        # Return the (empty) filtered x rather than a bare ndarray. That keeps
        # the return type consistent, so callers that use Series attributes
        # such as ``.index`` keep working.
        return x_clean, np.array([])
    z = sm.nonparametric.lowess(y_clean, x_clean, frac=frac, it=it, return_sorted=False)
    return x_clean, z

def create_gradient_line(x, y, values, cmap, linestyle='-', linewidth=3):
    """
    Create a gradient line as a collection of segments coloured by ``values``.

    Segment i, from point i to point i+1, is coloured by ``values[i]``. The
    colour normalisation spans the full ``min(values)`` to ``max(values)``
    range rather than autoscaling from ``values[:-1]``. A colormap built from
    the full data range, with its neutral stop placed from that range, then
    lines up with the data. This matches the dashed variant, which also
    normalises over the full range.

    Parameters
    ----------
    x, y : sequences of point coordinates.
    values : per-point colour values (same length as x).
    cmap : matplotlib colormap.
    linestyle, linewidth : passed to the LineCollection.

    Returns
    -------
    LineCollection, or None when fewer than two points are given.
    """
    if len(x) < 2:
        return None
    v = np.asarray(values, dtype=float)
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    # Pair consecutive points into (start, end) segments: shape (n-1, 2, 2).
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    lc = LineCollection(
        segments,
        cmap=cmap,
        norm=plt.Normalize(v.min(), v.max()),
        linestyle=linestyle,
        linewidth=linewidth,
    )
    lc.set_array(v[:-1])
    return lc

def create_dashed_gradient_line(x, y, values, cmap, linewidth=3, dash_on=10, dash_off=5):
    """
    Create a single dashed gradient line.

    The dash pattern is laid out along the cumulative arc length of the
    polyline (x, y), measured in data coordinates. The pattern draws
    ``dash_on`` units, skips ``dash_off`` units, and repeats from arc length 0.
    Because these lengths are in data units, the on-screen dash size depends on
    the axis scaling. Each drawn dash is a straight solid segment. Its colour
    comes from ``values``, linearly interpolated along arc length at the dash
    midpoint. The colour normalisation spans ``values.min()`` to
    ``values.max()``.

    Parameters
    ----------
    x, y : sequences of point coordinates (converted with np.asarray).
    values : per-point values used for colouring (same length as x).
    cmap : matplotlib colormap.
    linewidth : line width of the dashes.
    dash_on, dash_off : lengths (in data units of arc length) of the drawn and
        skipped parts of the repeating pattern.

    Returns
    -------
    LineCollection or None
        None if fewer than two points are given or no dash is produced.
    """
    if len(x) < 2:
        return None
    x = np.asarray(x)
    y = np.asarray(y)
    v = np.asarray(values)
    dx = np.diff(x)
    dy = np.diff(y)
    seg_lengths = np.sqrt(dx*dx + dy*dy)
    # Cumulative arc length at each vertex (dist[0] == 0).
    dist = np.concatenate(([0], np.cumsum(seg_lengths)))
    # Colour value at an arc-length position, interpolated between vertices.
    def color_at_distance(d_val):
        return np.interp(d_val, dist, v)
    pattern_length = dash_on + dash_off
    # Return the (start, end) arc-length intervals inside [s1, s2] that fall in
    # an "on" phase of the repeating pattern.
    # NOTE(review): with dash_off == 0 each dash ends exactly on a period
    # boundary, and the off-phase skip below then jumps a whole period, so
    # alternate dashes are dropped. Callers in this file pass dash_off > 0.
    def get_on_subsegments(s1, s2):
        segments_on = []
        current = s1
        while current < s2:
            # Position of `current` within its pattern period.
            cycle_pos = (current % pattern_length)
            cycle_on_end = current - cycle_pos + dash_on
            if cycle_on_end <= current:
                # In the off phase: jump to the start of the next period.
                next_cycle_start = current - cycle_pos + pattern_length
                current = next_cycle_start
                continue
            seg_start = current
            seg_end = min(cycle_on_end, s2)
            if seg_end > seg_start:
                segments_on.append((seg_start, seg_end))
            current = seg_end
            # Skip the rest of this period (the off phase), but not past s2.
            cycle_off_end = current - (current % pattern_length) + pattern_length
            # Defensive; appears unreachable since current % pattern_length < pattern_length.
            if cycle_off_end < current:
                cycle_off_end += pattern_length
            current = max(current, min(cycle_off_end, s2))
        return segments_on
    all_on_segments = []
    color_values = []
    for i in range(len(x) - 1):
        s1 = dist[i]
        s2 = dist[i+1]
        # Zero-length polyline segment: nothing to draw.
        if s2 == s1:
            continue
        on_subs = get_on_subsegments(s1, s2)
        if not on_subs:
            continue
        # Map each on-interval back to x/y by linear interpolation within segment i.
        for (s_on_start, s_on_end) in on_subs:
            t1 = (s_on_start - s1) / (s2 - s1)
            x1 = x[i] + t1 * (x[i+1] - x[i])
            y1 = y[i] + t1 * (y[i+1] - y[i])
            t2 = (s_on_end - s1) / (s2 - s1)
            x2 = x[i] + t2 * (x[i+1] - x[i])
            y2 = y[i] + t2 * (y[i+1] - y[i])
            # One colour per dash, sampled at the dash midpoint.
            mid = 0.5*(s_on_start + s_on_end)
            c_mid = color_at_distance(mid)
            all_on_segments.append([[x1, y1], [x2, y2]])
            color_values.append(c_mid)
    if not all_on_segments:
        return None
    # The dashes themselves are drawn solid; the gaps come from the omitted intervals.
    lc = LineCollection(
        all_on_segments,
        cmap=cmap,
        norm=plt.Normalize(v.min(), v.max()),
        linewidth=linewidth,
        linestyles='solid'
    )
    lc.set_array(np.array(color_values))
    return lc

def create_custom_cmap(vmin, vmax):
    """
    Create a blue -> gray -> orange colormap whose gray stop sits at data value 0.

    The colormap is meant to be used with a normalisation spanning
    ``[vmin, vmax]``. The gray stop goes at the relative position of 0 in that
    range, ``(0 - vmin) / (vmax - vmin)``. When ``vmin <= 0 <= vmax`` this is
    ``|vmin| / (|vmin| + |vmax|)``. If 0 lies outside ``[vmin, vmax]``, the
    stop is clamped to the nearer end (0.0 or 1.0) instead of being placed at
    a meaningless interior position. A degenerate, inverted or non-finite
    range falls back to the middle (0.5).

    Colours: '#0c1359' at vmin, '#D0D0D0' at zero, '#f05b05' at vmax.
    """
    span = vmax - vmin
    if np.isfinite(span) and span > 0:
        zero_pos = float(np.clip((0.0 - vmin) / span, 0.0, 1.0))
    else:
        zero_pos = 0.5
    # Stop positions must be non-decreasing from 0 to 1; clamping guarantees it.
    colors = [(0, '#0c1359'), (zero_pos, '#D0D0D0'), (1, '#f05b05')]
    return LinearSegmentedColormap.from_list("custom", colors, N=100)

def create_helix(x_start, width, height=0.5, frequency=2):
    """
    Sample a sine wave used as the glyph for an alpha helix.

    The curve spans [x_start, x_start + width] with ``int(width * 20)``
    evenly spaced samples. It completes ``frequency`` full periods across that
    width, with amplitude ``height``. Returns the (x, y) numpy arrays.
    """
    n_samples = int(width * 20)
    xs = np.linspace(x_start, x_start + width, n_samples)
    angle = 2 * np.pi * frequency * (xs - x_start) / width
    ys = height * np.sin(angle)
    return xs, ys

def create_arrow(x_start, width, height=0.5):
    """
    Build the outline of the arrow glyph used for beta strands.

    The outline is a bar of thickness ``height``, centred on y = 0, running
    from ``x_start`` to 70% of ``width``. It then tapers to a point at
    ``x_start + width``. Returns two lists, the x and y coordinates of the
    outline vertices in drawing order.
    """
    left = x_start
    neck = x_start + 0.7*width
    tip = x_start + width
    top = height/2
    bottom = -height/2
    outline = [
        (left, bottom), (left, top),
        (left, bottom), (neck, bottom),
        (neck, bottom), (tip, 0),
        (neck, top),
        (neck, top), (left, top),
        (left, bottom),
    ]
    xs = [px for px, _ in outline]
    ys = [py for _, py in outline]
    return xs, ys

def create_scatter_subplot(ax, x_data, y_data, color, title, xlabel, ylabel, marker='o'):
    """
    Draw a rank-vs-rank scatter plot on ``ax``, annotated with Spearman's rho.

    Rows where either input is missing are dropped. The remaining values are
    ranked with pandas average ranks and plotted. Spearman's rho and its
    p-value are shown in the upper-left corner. A dashed gray regression line
    is drawn only when both rank vectors have more than one distinct value.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x_data, y_data : pandas Series (``.rank()`` is called on them).
    color : colour of the scatter markers.
    title, xlabel, ylabel : panel title and axis labels.
    marker : marker style for the scatter points.

    With fewer than two complete pairs, or if plotting raises, a message is
    written on the axes instead. Any exception is reported with a printed
    warning. The title is set in every case, so a failed panel can still be
    identified in a multi-panel figure.
    """
    mask = ~(pd.isna(x_data) | pd.isna(y_data))
    x_clean = x_data[mask]
    y_clean = y_data[mask]
    if len(x_clean) < 2:
        ax.text(0.5, 0.5, "Insufficient data", ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title, fontsize=16, pad=20)
        return
    try:
        x_rank = x_clean.rank()
        y_rank = y_clean.rank()
        # spearmanr ranks internally, so rho matches the ranks plotted here.
        rho, pval = spearmanr(x_clean, y_clean)
        sns.scatterplot(x=x_rank, y=y_rank, ax=ax, color=color, alpha=0.6, marker=marker, linewidth=2, s=100)
        # A regression line through a constant vector is meaningless.
        if len(x_rank.unique()) > 1 and len(y_rank.unique()) > 1:
            sns.regplot(x=x_rank, y=y_rank, ax=ax, scatter=False, color='gray',
                        line_kws={'linestyle': '--', 'alpha': 0.8})
        corr_text = f"ρ = {rho:.3f}\np = {pval:.2e}"
        ax.text(0.05, 0.95, corr_text, transform=ax.transAxes, verticalalignment='top', fontsize=12,
                color='black', bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
        ax.set_title(title, fontsize=16, pad=20)
        ax.set_xlabel(xlabel, fontsize=14)
        ax.set_ylabel(ylabel, fontsize=14)
        ax.tick_params(labelsize=12)
    except Exception as e:
        # Best-effort panel: report and keep building the rest of the figure.
        print(f"Warning: Error in scatter plot creation: {e}")
        ax.text(0.5, 0.5, "Error in plot creation", ha='center', va='center', transform=ax.transAxes)
        # Keep the failed panel identifiable, as the insufficient-data path does.
        ax.set_title(title, fontsize=16, pad=20)

########################################
# 2) MAIN PLOTTING FUNCTION           #
########################################

def plot_frustration_comparison(summary_filepath, 
                                box_height_ratio=0.05, 
                                spacing_ratio=0.075, 
                                additional_space_ratio=0.30, 
                                box_padding_ratio=0.02, 
                                legend_separation_ratio=-0.05):
    """
    Create a comprehensive multi-panel figure comparing protein frustration data.

    The summary file is read with ``read_frustration_file``. Experimental and
    AlphaFold rows are merged on AlnIndex. Only rows where Experimental,
    AlphaFold and Evolutionary frustration and the B-Factor are all present
    are plotted.

    Layout (7 x 2 GridSpec; panels span both columns unless noted):
      Row 0: LOWESS-smoothed frustration curves, coloured by value with a
             blue -> gray -> orange colormap:
               - Solid line: Experimental Frustration (from ExpFrust)
               - Dashed line: AlphaFold Frustration (from AFFrust)
               - Dotted line: Evolutionary Frustration (from EvolFrust)
      Rows 1–3: scatter plots of each frustration metric (ranked) vs. the
             B-Factor (ranked).
      Row 4: Spearman correlation of each metric with the B-Factor. A black
             square outline marks p < 0.05.
      Row 5: LOWESS-smoothed B-Factor, min-max normalised to [0, 1].
      Row 6: ROC (left) and Precision–Recall (right) curves. Both use
             quartile‐based binary classifications of the B-Factor (top vs.
             bottom quartile), with the frustration value as the score.

    Parameters
    ----------
    summary_filepath : str
        Path to the per-protein summary file.
    box_height_ratio, spacing_ratio, box_padding_ratio : float
        Accepted for interface compatibility; not used in this function's body.
    additional_space_ratio : float
        Extra head-room above the curves in the main panel, as a fraction of
        the smoothed-data y-range.
    legend_separation_ratio : float
        Further fraction of the y-range. A negative value adds space below the
        curves, where the legends are anchored; a positive value adds space
        above them.

    Returns
    -------
    (fig, (roc_summary, pr_summary))
        ``roc_summary`` is a list of ``{'frustration_metric', 'roc_auc'}``
        dicts. ``pr_summary`` is a list of ``{'frustration_metric', 'pr_ap'}``
        dicts; low-B entries carry a ``'_lowB'`` suffix on the metric key.

    Raises
    ------
    ValueError
        If fewer than 5 complete rows remain after filtering.

    Side effects: sets the seaborn "whitegrid" style and updates the global
    ``plt.rcParams``.
    """
    sns.set_style("whitegrid")
    plt.rcParams.update({
        'figure.figsize': (20, 60),
        'font.size': 14,
        'axes.labelsize': 14,
        'axes.titlesize': 16
    })
    
    # Read frustration data (Experimental and AlphaFold frames plus the EvolFrust series)
    exp_data, af_data, evol_frust = read_frustration_file(summary_filepath)
    
    # Merge Experimental and AlphaFold data on AlnIndex.
    # Overlapping columns get suffixes, e.g. ExpFrust_Experimental,
    # ExpFrust_AlphaFold, B_Factor_Experimental, B_Factor_AlphaFold.
    merged_data = exp_data.merge(af_data, on='AlnIndex', suffixes=('_Experimental', '_AlphaFold'))
    # NOTE(review): this assignment aligns by index label, not by AlnIndex. It
    # lines up only while the merge keeps one row per input row in the original
    # order (e.g. unique AlnIndex values). Confirm for inputs with duplicate
    # AlnIndex values.
    merged_data['EvolFrust'] = evol_frust
    # Use Experimental B_Factor.
    if "B_Factor" not in merged_data.columns and "B_Factor_Experimental" in merged_data.columns:
        merged_data["B_Factor"] = merged_data["B_Factor_Experimental"]
    
    # Create a complete-data mask: every metric and the B-Factor must be present.
    complete_data_mask = (
        ~merged_data['ExpFrust_Experimental'].isna() &
        ~merged_data['ExpFrust_AlphaFold'].isna() &
        ~merged_data['EvolFrust'].isna() &
        ~merged_data['B_Factor'].isna()
    )
    merged_data_filtered = merged_data[complete_data_mask]
    if merged_data_filtered.empty or len(merged_data_filtered) < 5:
        raise ValueError("Insufficient complete data to generate plot.")
    
    # Set up grid: 7 rows x 2 columns.
    nrows = 7
    grid_cols = 2
    height_ratios = [3, 2, 2, 2, 3, 2, 3]  # row6 will hold ROC/PR analysis
    gs = gridspec.GridSpec(nrows, grid_cols, height_ratios=height_ratios, wspace=0.3, hspace=0.4)
    
    fig = plt.figure(figsize=(20, 60))
    # Row 0: Main frustration curves.
    ax_main = fig.add_subplot(gs[0, :])
    # Row 4: Summary correlation plot.
    ax_corr_summary = fig.add_subplot(gs[4, :])
    
    # LOWESS smoothing for frustration metrics and B_Factor.
    exp_x, exp_smooth = lowess_smoothing(merged_data_filtered['AlnIndex'], merged_data_filtered['ExpFrust_Experimental'])
    af_x, af_smooth = lowess_smoothing(merged_data_filtered['AlnIndex'], merged_data_filtered['ExpFrust_AlphaFold'])
    evol_x, evol_smooth = lowess_smoothing(merged_data_filtered['AlnIndex'], merged_data_filtered['EvolFrust'])
    bf_x, bf_smooth = lowess_smoothing(merged_data_filtered['AlnIndex'], merged_data_filtered['B_Factor'])
    
    # Y-limits for the main panel: data range + 5% padding on both sides,
    # plus the extra head-room / legend space requested by the caller.
    # (-2, 2) is the fallback when no finite limits can be computed.
    default_y_min, default_y_max = -2, 2
    all_y = np.concatenate([exp_smooth, af_smooth, evol_smooth])
    finite_mask = np.isfinite(all_y)
    try:
        if np.any(finite_mask):
            y_min = float(np.nanmin(all_y[finite_mask]))
            y_max = float(np.nanmax(all_y[finite_mask]))
            if not (np.isfinite(y_min) and np.isfinite(y_max)):
                y_min, y_max = default_y_min, default_y_max
        else:
            y_min, y_max = default_y_min, default_y_max
        # NOTE(review): perfectly flat curves give y_range == 0 and thus identical limits.
        y_range = y_max - y_min
        y_padding = y_range * 0.05
        plot_y_min = y_min - y_padding
        plot_y_max = y_max + y_padding + additional_space_ratio * y_range
        if legend_separation_ratio < 0:
            plot_y_min += y_range * legend_separation_ratio
        elif legend_separation_ratio > 0:
            plot_y_max += y_range * legend_separation_ratio
        if not (np.isfinite(plot_y_min) and np.isfinite(plot_y_max)):
            plot_y_min, plot_y_max = default_y_min, default_y_max
    except Exception as e:
        print(f"Error calculating plot limits: {e}")
        plot_y_min, plot_y_max = default_y_min, default_y_max

    ax_main.set_ylim(plot_y_min, plot_y_max)
    
    # Create custom colormaps (one per curve, each built from that curve's min/max).
    cmap_exp = create_custom_cmap(exp_smooth.min(), exp_smooth.max())
    cmap_af = create_custom_cmap(af_smooth.min(), af_smooth.max())
    cmap_evol = create_custom_cmap(evol_smooth.min(), evol_smooth.max())
    
    # LineCollections are added manually, so the axis limits are set explicitly below.
    exp_line = create_gradient_line(exp_x, exp_smooth, exp_smooth, cmap_exp, linestyle='-', linewidth=4)
    if exp_line:
        ax_main.add_collection(exp_line)
    # Dash lengths are arc-length units along the curve in data coordinates.
    af_line = create_dashed_gradient_line(af_x, af_smooth, af_smooth, cmap_af, linewidth=4, dash_on=2, dash_off=2)
    if af_line:
        ax_main.add_collection(af_line)
    evol_line = create_gradient_line(evol_x, evol_smooth, evol_smooth, cmap_evol, linestyle=':', linewidth=2)
    if evol_line:
        ax_main.add_collection(evol_line)
    
    x_min = merged_data_filtered['AlnIndex'].min()
    x_max = merged_data_filtered['AlnIndex'].max()
    ax_main.set_xlim(x_min, x_max)
    
    ax_main.set_title('Protein Frustration Comparison', fontsize=24, fontweight='bold', pad=20)
    ax_main.set_xlabel('Residue Number', fontsize=20, fontweight='bold')
    ax_main.set_ylabel('Frustration', fontsize=20, fontweight='bold')
    
    # Add legends for frustration types and levels.
    # Each legend is re-added with add_artist, so the next ax_main.legend call
    # does not replace it; they sit side by side along the bottom of the panel.
    legends = []
    line_style_legend = [
        Line2D([0], [0], color='black', linestyle='-', linewidth=4, label='Experimental Frustration'),
        Line2D([0], [0], color='black', linestyle='--', linewidth=4, label='AlphaFold Frustration'),
        Line2D([0], [0], color='black', linestyle=':', linewidth=2, label='Evolutionary Frustration')
    ]
    legends.append(('Frustration Types', line_style_legend))
    frustration_level_legend = [
        Line2D([0], [0], color='#0c1359', label='Minimally Frustrated', linewidth=3),
        Line2D([0], [0], color='#D0D0D0', label='Neutral', linewidth=3),
        Line2D([0], [0], color='#f05b05', label='Highly Frustrated', linewidth=3)
    ]
    legends.append(('Frustration Level', frustration_level_legend))
    num_legends = len(legends)
    spacing = 1.0/(num_legends+1)
    legend_y = 0.02
    for i, (title, handles) in enumerate(legends):
        x_pos = spacing*(i+1)
        legend = ax_main.legend(handles=handles, title=title, fontsize=14, title_fontsize=16,
                                loc='lower center', bbox_to_anchor=(x_pos, legend_y),
                                frameon=True, ncol=1)
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_edgecolor('black')
        ax_main.add_artist(legend)
    
    # Define colors for scatter plots – using the established Spearman colors.
    category_colors = {
        'Experimental Frustration': '#e41a1c',  # red
        'AlphaFold Frustration': '#377eb8',      # blue
        'Evolutionary Frustration': '#4DAF4A'      # green
    }


    # ---------------------------
    # Scatter Plots for Each Frustration Metric (Rows 1–3)
    # ---------------------------
    bf_rank = merged_data_filtered['B_Factor'].rank()
    # Draw one full-width scatter panel in GridSpec row `row_index`. The inputs
    # are already ranks; ranking them again inside create_scatter_subplot
    # leaves them unchanged.
    def plot_scatter_row(row_index, metric_series, metric_label, color):
        ax_scatter = fig.add_subplot(gs[row_index, :])
        create_scatter_subplot(ax_scatter, bf_rank, metric_series.rank(), color,
                               f'{metric_label} vs B-Factor',
                               'B-Factor Rank', metric_label)
    plot_scatter_row(1, merged_data_filtered['ExpFrust_Experimental'], 'Experimental Frustration', category_colors['Experimental Frustration'])
    plot_scatter_row(2, merged_data_filtered['ExpFrust_AlphaFold'], 'AlphaFold Frustration', category_colors['AlphaFold Frustration'])
    plot_scatter_row(3, merged_data_filtered['EvolFrust'], 'Evolutionary Frustration', category_colors['Evolutionary Frustration'])
    
    # ---------------------------
    # Summary Correlation Plot (Row 4)
    # ---------------------------
    summary_correlations = []
    metrics_dict = {
        'Experimental Frustration': merged_data_filtered['ExpFrust_Experimental'],
        'AlphaFold Frustration': merged_data_filtered['ExpFrust_AlphaFold'],
        'Evolutionary Frustration': merged_data_filtered['EvolFrust']
    }
    for metric_name, metric_series in metrics_dict.items():
        if not metric_series.isna().all() and not merged_data_filtered['B_Factor'].isna().all():
            rho, pval = spearmanr(metric_series, merged_data_filtered['B_Factor'])
            summary_correlations.append({'Metric': metric_name, 'Spearman_rho': rho, 'pval': pval})
    x_positions = {'Experimental Frustration': 1, 'AlphaFold Frustration': 2, 'Evolutionary Frustration': 3}
    for corr in summary_correlations:
        x_val = x_positions[corr['Metric']]
        y_val = corr['Spearman_rho']
        pval = corr['pval']
        ax_corr_summary.scatter(x_val, y_val, c=[category_colors[corr['Metric']]], marker='o', s=200, linewidth=2)
        # Significant correlations get a hollow black square around the point.
        if pval < 0.05:
            ax_corr_summary.scatter(x_val, y_val, facecolors='none', edgecolors='black', linewidth=2, s=500, marker='s', zorder=6)
    ax_corr_summary.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax_corr_summary.grid(True, alpha=0.3)
    ax_corr_summary.set_xlim(0.5, 3.5)
    # Y-padding: 10% of the spread of rho values, or 1 if all rho values are equal.
    if summary_correlations:
        y_min_corr = min(corr['Spearman_rho'] for corr in summary_correlations)
        y_max_corr = max(corr['Spearman_rho'] for corr in summary_correlations)
        y_padding_corr = (y_max_corr - y_min_corr) * 0.1 if (y_max_corr - y_min_corr) != 0 else 1
    else:
        y_min_corr, y_max_corr = -1, 1
        y_padding_corr = 0.1
    ax_corr_summary.set_ylim(y_min_corr - y_padding_corr, y_max_corr + y_padding_corr)
    ax_corr_summary.set_xticks([1, 2, 3])
    ax_corr_summary.set_xticklabels(['Experimental Frustration', 'AlphaFold Frustration', 'Evolutionary Frustration'],
                                    fontsize=12, ha='center')
    ax_corr_summary.set_ylabel("Spearman's ρ", fontsize=16)
    ax_corr_summary.spines['top'].set_visible(False)
    ax_corr_summary.spines['right'].set_visible(False)
    ax_corr_summary.yaxis.set_ticks_position('left')
    ax_corr_summary.set_title('Summary of B-Factor Correlations', fontsize=18, pad=20)
    
    # ---------------------------
    # Normalized Smoothed B-Factor Plot (Row 5)
    # ---------------------------
    ax_bf_norm = fig.add_subplot(gs[5, :])
    # Min-max scale to [0, 1]; a constant series maps to 0.5 everywhere.
    def normalize_series(series):
        min_val = series.min()
        max_val = series.max()
        if max_val - min_val == 0:
            return pd.Series([0.5] * len(series), index=series.index)
        return (series - min_val) / (max_val - min_val)
    # bf_x is the filtered AlnIndex Series from lowess_smoothing; its index labels the smoothed values.
    bf_normalized = normalize_series(pd.Series(bf_smooth, index=bf_x.index))
    ax_bf_norm.plot(bf_x, bf_normalized, label='Normalized B-Factor', color='blue', linewidth=2)
    ax_bf_norm.set_title('Normalized Smoothed B-Factor', fontsize=18, pad=20)
    ax_bf_norm.set_xlabel('Residue Number', fontsize=14)
    ax_bf_norm.set_ylabel('Normalized B-Factor', fontsize=14)
    ax_bf_norm.legend(loc='upper right', fontsize=12)
    ax_bf_norm.grid(True, alpha=0.3)
    
    # ---------------------------
    # ROC and Precision–Recall Analysis (Row 6)
    # ---------------------------
    ax_roc = fig.add_subplot(gs[6, 0])
    ax_pr = fig.add_subplot(gs[6, 1])
    
    # Use quartiles to define extremes. For each frustration metric, binary classification on B_Factor:
    # Label 1 if B_Factor >= top quartile; 0 if B_Factor <= bottom quartile.
    # Use the frustration metric’s continuous value as the score.
    # Residues between the two quartiles are excluded from the analysis.
    roc_summary = []
    pr_summary = []
    roc_metrics = {
        'ExpFrust_Experimental': merged_data_filtered['ExpFrust_Experimental'],
        'ExpFrust_AlphaFold':    merged_data_filtered['ExpFrust_AlphaFold'],
        'EvolFrust':             merged_data_filtered['EvolFrust']
    }
    b_series = merged_data_filtered['B_Factor']
    for fkey, f_series in roc_metrics.items():
        # The quartile cut points do not depend on the metric; they are recomputed each iteration.
        b_low = b_series.quantile(0.25)
        b_high = b_series.quantile(0.75)
        b_mask = (b_series <= b_low) | (b_series >= b_high)
        valid_mask = b_mask
        if valid_mask.sum() < 5:
            continue
        score = f_series[valid_mask]
        truth = np.where(b_series[valid_mask] >= b_high, 1, 0)
        fpr, tpr, _ = roc_curve(truth, score, pos_label=1)
        roc_auc = auc(fpr, tpr)
        color = {'ExpFrust_Experimental': '#e41a1c',
                 'ExpFrust_AlphaFold':    '#377eb8',
                 'EvolFrust':             '#4DAF4A'}[fkey]
        # Legend labels start with the internal key; remap_legend translates it later.
        ax_roc.plot(fpr, tpr, color=color, linestyle='-', linewidth=2,
                    label=f"{fkey} (AUC={roc_auc:.2f})")
        roc_summary.append({'frustration_metric': fkey, 'roc_auc': roc_auc})
        
        # ——— HIGH-B-FACTOR as positive class ———
        truth_high = np.where(b_series[valid_mask] >= b_high, 1, 0)
        score_high = score
        prec_high, rec_high, _ = precision_recall_curve(truth_high, score_high, pos_label=1)
        ap_high = average_precision_score(truth_high, score_high)
        ax_pr.plot(rec_high, prec_high,
                   color=color, linestyle='-', linewidth=2,
                   label=f"{fkey} high-B AP={ap_high:.2f}")
        pr_summary.append({'frustration_metric': fkey, 'pr_ap': ap_high})

        # ——— LOW-B-FACTOR as positive class ———
        truth_low = np.where(b_series[valid_mask] <= b_low, 1, 0)
        score_low = -score_high   # invert so “higher” score → more likely low-B
        prec_low, rec_low, _ = precision_recall_curve(truth_low, score_low, pos_label=1)
        ap_low = average_precision_score(truth_low, score_low)
        ax_pr.plot(rec_low, prec_low,
                   color=color, linestyle='--', linewidth=2,
                   label=f"{fkey} low-B AP={ap_low:.2f}")
        pr_summary.append({'frustration_metric': fkey + '_lowB', 'pr_ap': ap_low})
    
    # Add fixed baseline at 0.5 in the PR plot.
    # This assumes the quartile split gives roughly balanced classes; ties at
    # the cut points can unbalance it.
    ax_pr.axhline(y=0.5, color='grey', linestyle='--', linewidth=2, label='PR Baseline')
    
    # Finalize ROC plot
    ax_roc.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax_roc.set_xlabel('False Positive Rate', fontsize=14)
    ax_roc.set_ylabel('True Positive Rate', fontsize=14)
    ax_roc.set_title('ROC Curves: Frustration predicting B-Factor extremes', fontsize=16)
    ax_roc.legend(fontsize=10)
    # Replaces the legend just created with one using display names.
    remap_legend(ax_roc, DISPLAY_MAP, fontsize=10)
    ax_roc.grid(True, alpha=0.3)
    
    # Finalize PR plot
    ax_pr.set_xlabel('Recall', fontsize=14)
    ax_pr.set_ylabel('Precision', fontsize=14)
    ax_pr.set_title('Precision–Recall Curves: Frustration predicting B-Factor extremes', fontsize=16)
    ax_pr.legend(fontsize=10)
    # Replaces the legend just created with one using display names.
    remap_legend(ax_pr, DISPLAY_MAP, fontsize=10)
    ax_pr.grid(True, alpha=0.3)
    
    # Acts on the current figure, which is `fig` (created via plt.figure above).
    plt.tight_layout()
    
    # Return ROC/PR summaries for final summary figure.
    return fig, (roc_summary, pr_summary)

########################################
# 3) PROCESSING ALL SUBDIRECTORIES     #
########################################

def process_all_subdirectories(root_dir,
                               summary_filename="summary.txt",
                               box_height_ratio=0.05,
                               spacing_ratio=0.15,
                               additional_space_ratio=0.295,
                               box_padding_ratio=0.05,
                               legend_separation_ratio=-0.75):
    """
    Iterate through all immediate subdirectories of 'root_dir'. For each subdirectory
    containing 'summary_filename', generate a frustration-comparison plot and save it
    as '{PDBID}_frustration_comparison.pdf' inside that subdirectory, with a suptitle
    'Figure S#. {PDBID}'. Then aggregate all metrics into a single 'all_plots.pdf'.
    """
    big_figures = []
    all_roc_summary = []
    all_pr_summary = []
    counter = 1

    for entry in os.listdir(root_dir):
        subdir_path = os.path.join(root_dir, entry)
        if not os.path.isdir(subdir_path):
            continue

        # Remove any old 'frustration' PDFs
        for fn in os.listdir(subdir_path):
            if fn.lower().endswith('.pdf') and "frustration" in fn.lower():
                try:
                    os.remove(os.path.join(subdir_path, fn))
                except Exception as err:
                    print(f"Error removing file {fn}: {err}")

        summary_filepath = os.path.join(subdir_path, summary_filename)
        if not os.path.exists(summary_filepath):
            print(f"Skipping '{entry}': summary file not found.")
            continue

        try:
            print(f"Processing '{entry}'...")
            fig, (roc_summary, pr_summary) = plot_frustration_comparison(
                summary_filepath,
                box_height_ratio=box_height_ratio,
                spacing_ratio=spacing_ratio,
                additional_space_ratio=additional_space_ratio,
                box_padding_ratio=box_padding_ratio,
                legend_separation_ratio=legend_separation_ratio
            )

            # Updated suptitle with numbering
            fig.suptitle(f"Figure S{counter}. {entry}",
                         fontsize=20, fontweight='bold', y=0.9)
            counter += 1

            output_filename = f"{entry}_frustration_comparison.pdf"
            output_path = os.path.join(subdir_path, output_filename)
            fig.savefig(output_path, dpi=600, bbox_inches='tight')
            print(f"Plot saved successfully at: {output_path}")

            big_figures.append(fig)
            # tag cluster for summary tables
            for d in roc_summary:
                d['cluster'] = entry
            for d in pr_summary:
                d['cluster'] = entry
            all_roc_summary.extend(roc_summary)
            all_pr_summary.extend(pr_summary)

        except Exception as e:
            print(f"Skipping '{entry}' due to error: {e}")

    # Build the final summary PDF if we have any figures
    if big_figures:
        roc_df = pd.DataFrame(all_roc_summary)
        pr_df  = pd.DataFrame(all_pr_summary)

        # Split PR into high-B and low-B
        pr_high_df = pr_df[~pr_df['frustration_metric'].str.endswith('_lowB')].copy()
        pr_low_df  = pr_df[ pr_df['frustration_metric'].str.endswith('_lowB')].copy()
        pr_low_df['frustration_metric'] = pr_low_df['frustration_metric'].str.replace('_lowB', '', regex=False)

        # Rename codes for plotting
        rename_map = {
            'ExpFrust_Experimental': 'ExpFrust',
            'ExpFrust_AlphaFold':    'AFFrust',
            'EvolFrust':             'EvolFrust'
        }
        roc_df['frustration_metric']      = roc_df['frustration_metric'].replace(rename_map)
        pr_high_df['frustration_metric']  = pr_high_df['frustration_metric'].replace(rename_map)
        pr_low_df['frustration_metric']   = pr_low_df['frustration_metric'].replace(rename_map)

        color_map = {
            'ExpFrust':  '#8B0000',
            'AFFrust':   '#FF4444',
            'EvolFrust': '#4DAF4A'
        }

        # Create 3-row summary figure
        fig_summary = plt.figure(figsize=(20, 15))

        # 1) ROC AUC
        ax_roc = fig_summary.add_subplot(311)
        sns.scatterplot(data=roc_df, x='cluster', y='roc_auc',
                        hue='frustration_metric', palette=color_map,
                        s=100, ax=ax_roc)
        ax_roc.set_title('ROC AUC by Protein and Frustration Metric', fontsize=16)
        ax_roc.set_xlabel('Protein PDB_ID')
        ax_roc.set_ylabel('ROC AUC')
        for metric in roc_df['frustration_metric'].unique():
            avg = roc_df.loc[roc_df['frustration_metric']==metric, 'roc_auc'].mean()
            ax_roc.axhline(avg, color=color_map[metric], linestyle='--', linewidth=2,
                           label=f"{metric} mean AUC")
        ax_roc.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)

        # 2) PR high-B
        ax_pr_high = fig_summary.add_subplot(312)
        sns.scatterplot(data=pr_high_df, x='cluster', y='pr_ap',
                        hue='frustration_metric', palette=color_map,
                        s=100, ax=ax_pr_high, legend='brief')
        ax_pr_high.set_title('PR Average Precision (High B-factor) by Protein and Frustration Metric', fontsize=16)
        ax_pr_high.set_xlabel('Protein PDB_ID')
        ax_pr_high.set_ylabel('PR Average Precision')
        ax_pr_high.axhline(0.5, color='gray', linestyle='--', linewidth=2, label='Baseline 0.5')
        for metric in pr_high_df['frustration_metric'].unique():
            avg = pr_high_df.loc[pr_high_df['frustration_metric']==metric, 'pr_ap'].mean()
            ax_pr_high.axhline(avg, color=color_map[metric], linestyle=':', linewidth=2,
                               label=f"{metric} mean PR")
        # Wilcoxon
        wilcox_texts = ["Wilcoxon p-values (AP vs 0.5):"]
        for metric in pr_high_df['frustration_metric'].unique():
            vals = pr_high_df.loc[pr_high_df['frustration_metric']==metric, 'pr_ap']
            stat, p = wilcoxon(vals - 0.5)
            wilcox_texts.append(f"{metric}: p={p:.3f}")
        ax_pr_high.text(1.05, 0.6, "\n".join(wilcox_texts),
                        transform=ax_pr_high.transAxes, fontsize=10,
                        va='top', ha='left')
        ax_pr_high.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)

        # 3) PR low-B
        ax_pr_low = fig_summary.add_subplot(313)
        sns.scatterplot(data=pr_low_df, x='cluster', y='pr_ap',
                        hue='frustration_metric', palette=color_map,
                        s=100, ax=ax_pr_low, legend='brief')
        ax_pr_low.set_title('PR Average Precision (Low B-factor) by Protein and Frustration Metric', fontsize=16)
        ax_pr_low.set_xlabel('Protein PDB_ID')
        ax_pr_low.set_ylabel('PR Average Precision')
        ax_pr_low.axhline(0.5, color='gray', linestyle='--', linewidth=2, label='Baseline 0.5')
        for metric in pr_low_df['frustration_metric'].unique():
            avg = pr_low_df.loc[pr_low_df['frustration_metric']==metric, 'pr_ap'].mean()
            ax_pr_low.axhline(avg, color=color_map[metric], linestyle=':', linewidth=2,
                              label=f"{metric} mean PR")
        wilcox_texts_low = ["Wilcoxon p-values (AP vs 0.5):"]
        for metric in pr_low_df['frustration_metric'].unique():
            vals = pr_low_df.loc[pr_low_df['frustration_metric']==metric, 'pr_ap']
            stat, p = wilcoxon(vals - 0.5)
            wilcox_texts_low.append(f"{metric}: p={p:.3f}")
        ax_pr_low.text(1.05, 0.6, "\n".join(wilcox_texts_low),
                       transform=ax_pr_low.transAxes, fontsize=10,
                       va='top', ha='left')
        ax_pr_low.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
        # 4) Histogram of Spearman correlations between EvolFrust and ExpFrust
        spearman_vals = []
        for fig, entry in zip(big_figures, [d['cluster'] for d in all_roc_summary if d['frustration_metric'] == 'EvolFrust']):
            subdir_path = os.path.join(root_dir, entry)
            summary_path = os.path.join(subdir_path, summary_filename)
            if not os.path.exists(summary_path):
                continue
            try:
                exp_df, af_df, evol = read_frustration_file(summary_path)
                merged = exp_df.copy()
                merged['EvolFrust'] = evol
                mask = ~merged['ExpFrust'].isna() & ~merged['EvolFrust'].isna()
                if mask.sum() > 2:
                    rho, _ = spearmanr(merged['ExpFrust'][mask], merged['EvolFrust'][mask])
                    if np.isfinite(rho):
                        spearman_vals.append(rho)
            except Exception as e:
                print(f"Skipping Spearman histogram for {entry}: {e}")

        # Save histogram to separate file
        if spearman_vals:
            fig_hist = plt.figure(figsize=(12, 5))
            ax_hist = fig_hist.add_subplot(111)
            sns.histplot(spearman_vals, bins=20, kde=False, color="#4DAF4A", edgecolor='black', ax=ax_hist)
            ax_hist.set_title("Spearman Correlation: Evolutionary vs Experimental Frustration", fontsize=16)
            ax_hist.set_xlabel("Spearman's ρ", fontsize=14)
            ax_hist.set_ylabel("Frequency", fontsize=14)
            mean_rho = np.mean(spearman_vals)
            ax_hist.axvline(mean_rho, color='black', linestyle='--', linewidth=2)
            ax_hist.text(0.95, 0.95,
                         f"Mean = {mean_rho:.2f}",
                         transform=ax_hist.transAxes,
                         ha='right', va='top',
                         fontsize=12,
                         bbox=dict(facecolor='white', edgecolor='none', boxstyle='round,pad=0.3'))
            hist_path = os.path.join(root_dir, "evol_vs_exp_spearman_histogram.pdf")
            fig_hist.savefig(hist_path, dpi=600, bbox_inches='tight')
            plt.close(fig_hist)
            print(f"Spearman histogram saved to: {hist_path}")
        plt.tight_layout()
        big_figures.append(fig_summary)

        # Save all into one PDF
        big_pdf_path = os.path.join(root_dir, "all_plots.pdf")
        with PdfPages(big_pdf_path) as pdf:
            for fig in big_figures:
                pdf.savefig(fig, bbox_inches='tight')
                plt.close(fig)
        print(f"All plots saved in one PDF at: {big_pdf_path}")

    else:
        print("No figures to save in the big PDF.")

########################################
# 4) MAIN EXECUTION BLOCK              #
########################################

if __name__ == "__main__":
    # Entry point for this cell: run process_all_subdirectories on root_directory.
    # root_directory is an empty placeholder and must be set to the real data
    # directory before running.
    root_directory = ""  # Change this path as needed
    process_all_subdirectories(root_directory)

# The following section contains the scripts used to analyze the set of 20 highly flexible proteins (20F).

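Script to calculate per-residue average B-factors for the two separate PDB files (one per conformation). Each file is named by PDB ID plus chain letter (e.g., `1bgxT.pdb`) and must sit in the protein's `experimental_data` folder; results are written next to it as `average_b_factors_<PDBID><chain>.txt`.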

In [None]:
import os
import re
import Bio.PDB
import numpy as np

def calculate_average_b_factors(pdb_path, output_txt_path):
    """
    Calculate the average B-factor of each residue in a PDB structure and save
    them to a tab-separated file with one-letter residue names, re-indexed
    starting from 1.

    Only the first model of the structure is read, and water molecules are
    skipped, so every chain residue appears exactly once. Other non-standard
    residues (ligands, modified amino acids) are still written, using the
    placeholder 'X'.

    Output (tab-separated, with a header line):
    Residue (1-based index), ResidueAA (one-letter code),
    Average_B_Factor (mean over the residue's atoms, 3 decimals).

    Errors are printed rather than raised so batch processing can continue.

    Parameters:
    - pdb_path: str, path to the input PDB file.
    - output_txt_path: str, path to the output text file.
    """
    # Mapping three-letter residue names to one-letter codes
    one_letter_code = {
        "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
        "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
        "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
        "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
        # Handle uncommon residues with a placeholder
        "UNK": "X"
    }

    parser = Bio.PDB.PDBParser(QUIET=True)
    try:
        structure = parser.get_structure("protein", pdb_path)
        first_model = next(iter(structure), None)
        if first_model is None:
            print(f"No models found in {pdb_path}; nothing written.")
            return

        b_factors = []
        residue_names = []

        # Read only the first model: multi-model files (e.g. NMR ensembles)
        # would otherwise list every residue once per model.
        for chain in first_model:
            for residue in chain:
                # Hetero flag 'W' marks water molecules; they are not chain
                # residues and would otherwise appear as extra 'X' entries.
                if residue.get_id()[0] == "W":
                    continue
                b_factor_list = [atom.get_bfactor() for atom in residue]
                b_factors.append(np.mean(b_factor_list))
                # Default to 'X' for residues missing from the table
                residue_names.append(one_letter_code.get(residue.get_resname(), "X"))

        # Write the output file with re-indexed residue numbers starting from 1
        with open(output_txt_path, "w") as file:
            file.write("Residue\tResidueAA\tAverage_B_Factor\n")
            for idx, (aa, b_factor) in enumerate(zip(residue_names, b_factors), start=1):
                file.write(f"{idx}\t{aa}\t{b_factor:.3f}\n")

        print(f"Average B-factors for {os.path.basename(pdb_path)} saved to {output_txt_path}")
    except Exception as e:
        print(f"Error processing PDB file {pdb_path}: {e}")

def has_chain_letter(filename):
    """
    Report whether ``filename`` ends in "<something><UPPERCASE LETTER>.pdb".

    The trailing uppercase letter is taken to be a chain ID appended to the
    PDB ID (e.g. ``1bgxT.pdb``). The ``.pdb`` extension must be lowercase.

    Parameters:
    - filename: str, name of the file.

    Returns:
    - bool: True if a chain letter is present, False otherwise.
    """
    # At least one leading character, one uppercase chain letter, then ".pdb".
    chain_suffix = re.compile(r'^.+[A-Z]\.pdb$')
    return chain_suffix.match(filename) is not None

def main(base_directory):
    """
    Compute per-residue average B-factors for every protein directory found
    directly under ``base_directory``.

    Each protein directory must contain an ``experimental_data`` folder with
    exactly two ``.pdb`` files whose names carry a chain letter (see
    ``has_chain_letter``). For each of those files an
    ``average_b_factors_<stem>.txt`` file is written beside it. Directories
    that do not meet these expectations are reported and skipped.

    Parameters:
    - base_directory: str, path to the base directory containing protein subdirectories.
    """
    for protein_dir in os.listdir(base_directory):
        protein_path = os.path.join(base_directory, protein_dir)
        if not os.path.isdir(protein_path):
            continue

        experimental_data_dir = os.path.join(protein_path, "experimental_data")
        if not os.path.isdir(experimental_data_dir):
            print(f"'experimental_data' directory not found in {protein_path}. Skipping.")
            continue

        # Keep only .pdb files (any case) that carry a chain letter in their name.
        chain_labelled = [
            name for name in os.listdir(experimental_data_dir)
            if name.lower().endswith('.pdb') and has_chain_letter(name)
        ]

        if len(chain_labelled) != 2:
            print(f"Expected 2 PDB files with chain letters in {experimental_data_dir}, found {len(chain_labelled)}. Skipping {protein_dir}.")
            continue

        for pdb_name in chain_labelled:
            stem, _extension = os.path.splitext(pdb_name)
            # e.g. '1bgxT.pdb' -> 'average_b_factors_1bgxT.txt'
            calculate_average_b_factors(
                os.path.join(experimental_data_dir, pdb_name),
                os.path.join(experimental_data_dir, f"average_b_factors_{stem}.txt"),
            )

if __name__ == "__main__":
    # **Specify your base directory here**
    base_directory = ""
    if not base_directory:
        # os.listdir("") raises FileNotFoundError; give a clear hint instead.
        print("Please set base_directory to the folder containing the protein subdirectories.")
    else:
        main(base_directory)

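Summarize the frustratometer outputs for the two structures of each sequence. The directories containing the frustratometer output .rar files are expected to be named frustratometer_1 and frustratometer_2. Warning: before extraction, every file other than the .rar archive, and every subdirectory, inside these directories is deleted. Each directory's per-residue result is written to frustration_summary.txt.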

In [None]:
import os
import shutil
import glob
import patoolib
import pandas as pd

# Function to clean directories by deleting all files except .rar and removing subdirectories
def clean_directory(directory):
    """
    Empty ``directory`` of everything except its .rar archives.

    Every regular file whose name does not end in ``.rar`` (case-insensitive)
    is deleted, and every subdirectory is removed recursively together with
    its contents. Each action is printed.

    Parameters:
    - directory (str): The path to the directory to clean.
    """
    print(f"\nCleaning directory: {directory}")
    for entry_name in os.listdir(directory):
        entry_path = os.path.join(directory, entry_name)

        if os.path.isdir(entry_path):
            print(f"Deleting directory and its contents: {entry_path}")
            shutil.rmtree(entry_path)
        elif os.path.isfile(entry_path):
            if entry_name.lower().endswith('.rar'):
                print(f"Keeping .rar file: {entry_path}")
            else:
                print(f"Deleting file: {entry_path}")
                os.remove(entry_path)

# Function to extract .rar files from directories
def extract_rar_from_directory(input_dir):
    """
    Extract a .rar archive found in ``input_dir`` into ``input_dir/extracted``.

    Archive names are matched case-insensitively. When several .rar files are
    present, the alphabetically first one is used and a note is printed. This
    makes repeated runs pick the same archive regardless of filesystem
    listing order. Extraction errors are printed, not raised, so batch
    processing can continue.

    Parameters:
    - input_dir (str): Directory to scan for .rar files.
    """
    print(f"Scanning directory for .rar files: {input_dir}")
    rar_files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith('.rar'))
    if not rar_files:
        print(f"No .rar files found in {input_dir}.")
        return
    if len(rar_files) > 1:
        print(f"Multiple .rar files found in {input_dir}: {rar_files}. Using {rar_files[0]}.")

    rar_file = os.path.join(input_dir, rar_files[0])
    extracted_folder = os.path.join(input_dir, "extracted")

    try:
        patoolib.extract_archive(rar_file, outdir=extracted_folder)
        print(f"Extracted {rar_file} to {extracted_folder}")
    except Exception as e:
        print(f"Failed to extract {rar_file}: {e}")

# Function to locate the .pdb_mutational file
def locate_pdb_mutational_file(extracted_folder):
    """
    Return the path of a ``.pdb_mutational`` file inside extracted
    frustratometer output.

    The tree under ``extracted_folder`` is walked top-down with directory
    names sorted. The first directory whose name contains "FrustrationData"
    is searched, and the alphabetically first file ending in
    ``.pdb_mutational`` is returned. Sorting makes the result independent of
    filesystem listing order.

    Raises:
        FileNotFoundError: if no FrustrationData directory is found, or it
        contains no ``.pdb_mutational`` file.
    """
    print(f"Searching for FrustrationData directory in: {extracted_folder}")
    frustration_dir = None
    for root, dirnames, _files in os.walk(extracted_folder):
        # Sorting in place also fixes the order in which os.walk descends.
        dirnames.sort()
        matches = [name for name in dirnames if "FrustrationData" in name]
        if matches:
            frustration_dir = os.path.join(root, matches[0])
            break
    if frustration_dir is None:
        raise FileNotFoundError(f"No FrustrationData directory found in {extracted_folder}")

    pdb_files = sorted(
        os.path.join(frustration_dir, f)
        for f in os.listdir(frustration_dir)
        if f.endswith('.pdb_mutational')
    )
    if not pdb_files:
        raise FileNotFoundError(f"No .pdb_mutational files found in {frustration_dir}")

    return pdb_files[0]

# Function to process the .pdb_mutational file
def process_frustration_file(input_file, output_file):
    """
    Summarize a frustratometer ``.pdb_mutational`` contact file per residue.

    For every contact row, NativeEnergy - DecoyEnergy is computed and credited
    to both residues of the pair (Res1 and Res2). Each residue's values are
    averaged, and residues are ordered by their original number. The result
    is written re-indexed from 1 as ``<index> <aa> <difference:.4f>`` lines
    under a ``Residue# ResidueAA Difference`` header.

    Text after '#' is ignored. Rows whose residue numbers or energies are not
    numeric (for example an uncommented column-header line) are skipped and
    counted in a printed note instead of aborting the file. Any other error
    is printed, not raised, so batch processing can continue.

    Parameters:
    - input_file (str): path to the ``.pdb_mutational`` file.
    - output_file (str): path of the summary file to write.
    """
    print(f"Processing .pdb_mutational file: {input_file}")
    col_names = [
        "Res1", "Res2", "ChainRes1", "ChainRes2",
        "DensityRes1", "DensityRes2", "AA1", "AA2",
        "NativeEnergy", "DecoyEnergy", "SDEnergy",
        "FrstIndex", "Welltype", "FrstState"
    ]
    numeric_cols = ["Res1", "Res2", "NativeEnergy", "DecoyEnergy"]
    try:
        df = pd.read_csv(input_file, sep=r'\s+', comment="#", names=col_names)

        # A single non-numeric row would otherwise make these columns
        # object-dtype and the energy subtraction below would fail.
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors="coerce")
        valid = df[numeric_cols].notna().all(axis=1)
        n_skipped = int((~valid).sum())
        if n_skipped:
            print(f"Skipped {n_skipped} non-numeric row(s) in {input_file}.")
        df = df[valid]

        energy_diff = df["NativeEnergy"] - df["DecoyEnergy"]

        # Accumulate per residue; the amino acid recorded is the first one seen.
        residue_data = {}
        for res1, aa1, res2, aa2, diff in zip(df["Res1"], df["AA1"], df["Res2"], df["AA2"], energy_diff):
            for res, aa in ((res1, aa1), (res2, aa2)):
                entry = residue_data.setdefault(res, {"aa": aa, "sum_diff": 0.0, "count": 0})
                entry["sum_diff"] += diff
                entry["count"] += 1

        with open(output_file, "w") as f:
            f.write("Residue# ResidueAA Difference\n")
            # Original residue numbers are dropped; output is numbered 1..N.
            for new_index, res in enumerate(sorted(residue_data), start=1):
                data = residue_data[res]
                average_diff = data["sum_diff"] / data["count"]  # count >= 1 by construction
                f.write(f"{new_index} {data['aa']} {average_diff:.4f}\n")
        print(f"Output written to {output_file}.")
    except Exception as e:
        print(f"Error processing file {input_file}: {e}")

# Function to process extracted directories
def process_subdirectory(test_dir, sub_dir, output_filename):
    """
    Write the per-residue frustration summary for ``test_dir/sub_dir``.

    The archive contents are expected in ``test_dir/sub_dir/extracted``. The
    ``.pdb_mutational`` file found there is summarized into
    ``test_dir/sub_dir/<output_filename>``. A missing folder, a missing file,
    or any other error is printed and the directory is skipped.
    """
    target_path = os.path.join(test_dir, sub_dir)
    extraction_root = os.path.join(target_path, "extracted")

    if not os.path.isdir(extraction_root):
        print(f"Extracted folder not found: {extraction_root}. Skipping {sub_dir} in {test_dir}.")
        return

    summary_path = os.path.join(target_path, output_filename)
    try:
        mutational_path = locate_pdb_mutational_file(extraction_root)
        process_frustration_file(mutational_path, summary_path)
    except FileNotFoundError as err:
        print(f"Error processing {sub_dir} in {test_dir}: {err}")
    except Exception as err:
        print(f"An unexpected error occurred while processing {sub_dir} in {test_dir}: {err}")

# Main function
def main(root_directory):
    """
    Process every ``frustratometer_1`` / ``frustratometer_2`` directory
    (matched case-insensitively) found anywhere under ``root_directory``.

    For each match: delete everything except its .rar archive, extract the
    archive into ``extracted``, and write ``frustration_summary.txt`` from the
    contained ``.pdb_mutational`` file. Directories are visited in sorted
    order. The walk does not descend into processed directories.

    WARNING: clean_directory() removes every non-.rar file and every
    subdirectory inside each matched directory.

    Parameters:
    - root_directory (str): top-level directory to search recursively.
    """
    target_names = {'frustratometer_1', 'frustratometer_2'}
    # Both structure directories use the same summary filename.
    output_filename = "frustration_summary.txt"

    for dirpath, dirnames, _ in os.walk(root_directory):
        dirnames.sort()  # deterministic processing order
        for dirname in [name for name in dirnames if name.lower() in target_names]:
            target_dir = os.path.join(dirpath, dirname)
            print(f"\nFound target directory: {target_dir}")

            # **Clean the target directory before extraction**
            clean_directory(target_dir)

            # Proceed with extraction
            extract_rar_from_directory(target_dir)

            # Process the subdirectory
            process_subdirectory(dirpath, dirname, output_filename)

        # Prune processed directories from the walk: their contents were just
        # regenerated from the archive. A frustratometer_* folder nested inside
        # an archive would otherwise be matched and cleaned, deleting the
        # extracted data.
        dirnames[:] = [name for name in dirnames if name.lower() not in target_names]

if __name__ == "__main__":
    # **Specify your root directory here**
    root_directory = ""
    if not root_directory:
        # os.walk("") yields nothing, so an unset path would silently do no work.
        print("Please set root_directory to the folder containing the protein directories.")
    else:
        main(root_directory)

Collect the data for the 20F set of proteins into a summary file.

In [None]:
import os
import re
import glob
import numpy as np
import pandas as pd
import logging
from itertools import chain
from Bio import pairwise2
from Bio.Seq import Seq
from collections import defaultdict
from Bio.PDB.Polypeptide import is_aa
from Bio.PDB import PDBParser, DSSP
from pathlib import Path

# -------------------------------------------------------------------------
# 1) Logging Setup
# -------------------------------------------------------------------------
def setup_logging(debug=False):
    """
    Configure basic console logging and return this module's logger.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers (e.g. when a notebook cell is re-run, or another cell configured
    logging first). The requested level is therefore also set directly on the
    returned logger, so ``debug`` always takes effect for this module.

    Parameters:
    - debug (bool): log at DEBUG level when True, INFO otherwise.

    Returns:
    - logging.Logger named after this module.
    """
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    module_logger = logging.getLogger(__name__)
    module_logger.setLevel(level)
    return module_logger

logger = setup_logging(debug=True)

# -------------------------------------------------------------------------
# 2) Directory Validation
# -------------------------------------------------------------------------
def validate_directory_structure(dir_path):
    """
    Checks which data sources are available in a protein directory.

    Each detected source is a dict with 'type', 'name' and 'path' keys:
      - frustratometer_* directories other than frustratometer_af (at most
        two, in sorted name order) holding frustration_summary.txt
        -> EXP_FRUST_1 / EXP_FRUST_2 (numbered by position in that order)
      - frustratometer_af/frustration_af_summary.txt -> AF_FRUST_1
      - experimental_data/average_b_factors_<id><CHAIN>.(csv|txt) -> B_FACTOR_k
      - experimental_data/rmsf_<id><CHAIN>.(csv|txt) -> RMSF_k
      - mj_analysis/stability_scores.txt -> EVOL_FRUST
      - experimental_data/<id><CHAIN>.pdb (at most two) -> PDB_k, with an extra
        'frustratometer' key naming EXP_FRUST_k. A PDB file is only added while
        k does not exceed the number of frustratometer directories found.

    Returns a tuple (has_any_data, available_sources).
    """
    available_sources = []
    dir_path = Path(dir_path)
    
    # Detect frustratometer_1 and frustratometer_2 directories.
    # frustratometer_af shares the prefix but holds AlphaFold results
    # (detected below). Counting it here produced a spurious "more than two"
    # warning and let an extra PDB file map to a non-existent EXP_FRUST source.
    frustratometer_dirs = sorted(
        d for d in dir_path.iterdir()
        if d.is_dir()
        and d.name.startswith('frustratometer_')
        and d.name != 'frustratometer_af'
    )
    if len(frustratometer_dirs) > 2:
        logger.warning(f"More than two frustratometer directories found in {dir_path}. Only the first two will be considered.")
        frustratometer_dirs = frustratometer_dirs[:2]
    for i, subdir in enumerate(frustratometer_dirs, start=1):
        required_files = ['frustration_summary.txt']
        if all((subdir / file).is_file() for file in required_files):
            source_name = f'EXP_FRUST_{i}'
            available_sources.append({'type': 'frustratometer', 'name': source_name, 'path': subdir})
            logger.debug(f"Detected frustration source: {source_name} at {subdir}")
        else:
            logger.warning(f"Frustratometer directory {subdir} is missing required files.")
    
    # Detect frustratometer_af
    frustratometer_af = dir_path / 'frustratometer_af'
    if frustratometer_af.is_dir() and (frustratometer_af / 'frustration_af_summary.txt').is_file():
        available_sources.append({'type': 'frustratometer_af', 'name': 'AF_FRUST_1', 'path': frustratometer_af})
        logger.debug(f"Detected frustration AF source: AF_FRUST_1 at {frustratometer_af}")
    
    # Detect average B-factor files in experimental_data
    experimental_data = dir_path / 'experimental_data'
    if experimental_data.is_dir():
        # Use itertools.chain to combine glob results
        b_factor_files = sorted(
            chain(
                experimental_data.glob('average_b_factors*.csv'), 
                experimental_data.glob('average_b_factors*.txt')
            )
        )
        b_factor_counter = 1  # Initialize counter for B-factor sources
        for bf_file in b_factor_files:
            # Extract pdbID and chainID from filename
            match = re.search(r'average_b_factors_(\w+)([A-Z])\.(csv|txt)$', bf_file.name)
            if match:
                source_name = f'B_FACTOR_{b_factor_counter}'
                available_sources.append({'type': 'b_factor', 'name': source_name, 'path': bf_file})
                logger.debug(f"Detected B-factor source: {source_name} at {bf_file}")
                b_factor_counter += 1
            else:
                logger.warning(f"B-factor file {bf_file} does not match the expected naming convention.")
        
        # Detect RMSF files if present
        rmsf_files = sorted(
            chain(
                experimental_data.glob('rmsf*.csv'), 
                experimental_data.glob('rmsf*.txt')
            )
        )
        rmsf_counter = 1  # Initialize counter for RMSF sources
        for rmsf_file in rmsf_files:
            # Extract pdbID and chainID from filename
            match = re.search(r'rmsf_(\w+)([A-Z])\.(csv|txt)$', rmsf_file.name)
            if match:
                source_name = f'RMSF_{rmsf_counter}'
                available_sources.append({'type': 'rmsf', 'name': source_name, 'path': rmsf_file})
                logger.debug(f"Detected RMSF source: {source_name} at {rmsf_file}")
                rmsf_counter += 1
            else:
                logger.warning(f"RMSF file {rmsf_file} does not match the expected naming convention.")
    
    # Detect mj_analysis
    mj_analysis = dir_path / 'mj_analysis'
    if mj_analysis.is_dir() and (mj_analysis / 'stability_scores.txt').is_file():
        available_sources.append({'type': 'mj_analysis', 'name': 'EVOL_FRUST', 'path': mj_analysis / 'stability_scores.txt'})
        logger.debug(f"Detected MJ analysis source: EVOL_FRUST at {mj_analysis / 'stability_scores.txt'}")
    
    # Detect PDB files with chain letters
    pdb_files = sorted([f for f in experimental_data.glob("*.pdb") if re.search(r'\w+[A-Z]\.pdb$', f.name)])
    pdb_counter = 1  # Initialize counter for PDB sources
    if len(pdb_files) < 2:
        logger.warning(f"Less than two PDB files with chain IDs found in {experimental_data}. Expected two.")
    elif len(pdb_files) > 2:
        logger.warning(f"More than two PDB files with chain IDs found in {experimental_data}. Only the first two will be considered.")
        pdb_files = pdb_files[:2]
    
    # PDB_k is paired with EXP_FRUST_k by position.
    for pdb_file in pdb_files:
        if pdb_counter <= len(frustratometer_dirs):
            src_name = f'PDB_{pdb_counter}'
            available_sources.append({'type': 'pdb', 'name': src_name, 'path': pdb_file, 'frustratometer': f'EXP_FRUST_{pdb_counter}'})
            logger.debug(f"Mapped PDB file {pdb_file.name} to frustration source EXP_FRUST_{pdb_counter}")
            pdb_counter += 1
        else:
            logger.warning(f"No corresponding frustratometer directory for PDB file {pdb_file.name}")
    
    has_any_data = len(available_sources) > 0
    return has_any_data, available_sources

def get_valid_directories(root_path):
    """
    List the immediate subdirectories of ``root_path`` that expose at least
    one data source according to ``validate_directory_structure``.

    Returns:
        list of (Path, list-of-source-dicts) tuples, sorted by directory path.
        Returns an empty list (and logs an error) when ``root_path`` is not a
        directory.
    """
    root = Path(root_path)
    if not root.is_dir():
        logger.error(f"Root directory does not exist: {root}")
        return []

    found = []
    for candidate in root.iterdir():
        if not candidate.is_dir():
            continue
        has_data, sources = validate_directory_structure(candidate)
        if not has_data:
            logger.warning(f"Directory {candidate} has no valid data sources, skipping")
            continue
        found.append((candidate, sources))
        listed = ', '.join(src['name'] for src in sources)
        logger.info(f"Found directory {candidate} with data sources: {listed}")

    found.sort(key=lambda item: str(item[0]))
    return found

# -------------------------------------------------------------------------
# 3) Parsing Functions
# -------------------------------------------------------------------------
def parse_frustration_file(file_path):
    """
    Parses frustration_summary.txt or frustration_af_summary.txt.
    
    Format:
        Residue# ResidueAA Difference
        1 A -0.1234
        2 G 0.5678
        3 S -1.2345
        ...

    Blank lines and the "Residue#" header are skipped. Lines with fewer than
    three fields, or with a non-numeric residue number or difference, are
    skipped with a warning. Residue codes are upper-cased; anything that is
    not exactly one of the 20 standard one-letter codes becomes 'X'.
        
    Returns:
        sequence_str (str): Amino acid sequence in 1-letter codes, ordered by
            residue number.
        pos_values (dict): Residue number as it appears in the file (not
            re-indexed) -> frustration difference.
        ("", {}) if the file is missing or has no valid lines.
    """
    file_path = Path(file_path)
    if not file_path.is_file():
        logger.debug(f"parse_frustration_file: File not found {file_path}")
        return "", {}

    # Set membership checks a single code; a substring test against the
    # alphabet string would accept multi-letter tokens such as "DE".
    standard_codes = frozenset("ACDEFGHIKLMNPQRSTVWY")

    lines = []
    with open(file_path, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line or line.startswith("Residue#"):
                continue
            parts = line.split()
            if len(parts) < 3:
                logger.warning(f"parse_frustration_file: Line {line_number} does not have enough parts: {line}")
                continue
            try:
                res_num = int(parts[0])
                aa = parts[1].upper()
                difference = float(parts[2])
            except ValueError:
                logger.warning(f"parse_frustration_file: Skipping line {line_number} due to ValueError: {line}")
                continue

            if aa not in standard_codes:
                logger.warning(f"Line {line_number}: Unknown amino acid code '{aa}'. Assigned as 'X'.")
                aa = 'X'

            lines.append((res_num, aa, difference))

    if not lines:
        logger.debug(f"parse_frustration_file: No valid data found in {file_path}")
        return "", {}

    lines_sorted = sorted(lines, key=lambda x: x[0])
    sequence = ''.join([aa for (_, aa, _) in lines_sorted])
    pos_values = {res_num: difference for (res_num, _, difference) in lines_sorted}

    return sequence, pos_values

def parse_b_factor(file_path):
    """
    Parse an average_b_factors*.csv / average_b_factors*.txt file.

    Expected layout: a header line ``Residue,ResidueAA,Average_B_Factor``
    (comma-separated for .csv; tab-separated for .txt), then one line per
    residue. An unexpected header only triggers a warning. Malformed or
    non-numeric lines are skipped with a warning. If a residue index appears
    more than once, its last line wins.

    Returns:
        sequence_str (str): Residue codes (upper-cased) ordered by index.
        pos_values (dict): {1-based sequential position: B-factor}.
        ("", {}) if the file is missing or has no usable lines.
    """
    path = Path(file_path)
    if not path.is_file():
        logger.debug(f"parse_b_factor: File not found {path}")
        return "", {}

    sep = '\t' if path.suffix.lower() == '.txt' else ','

    records = []
    with open(path, 'r') as handle:
        header_fields = handle.readline().strip().split(sep)
        if header_fields != ['Residue', 'ResidueAA', 'Average_B_Factor']:
            logger.warning(f"parse_b_factor: Unexpected headers in {path}: {header_fields}")
        # Line numbers start at 2 because the header was line 1.
        for line_number, raw in enumerate(handle, start=2):
            text = raw.strip()
            if not text:
                continue
            fields = text.split(sep)
            if len(fields) < 3:
                logger.warning(f"parse_b_factor: Skipping malformed line {line_number}: {text}")
                continue
            try:
                records.append((int(fields[0]), fields[1].upper(), float(fields[2])))
            except ValueError:
                logger.warning(f"parse_b_factor: Skipping line {line_number} due to ValueError: {text}")

    if not records:
        logger.debug(f"parse_b_factor: No lines parsed in {path}")
        return "", {}

    by_index = {index: (code, value) for index, code, value in records}
    ordered = [by_index[index] for index in sorted(by_index)]

    sequence_str = "".join(code for code, _ in ordered)
    pos_values = {position: value for position, (_, value) in enumerate(ordered, start=1)}

    logger.debug(f"parse_b_factor: Parsed sequence length={len(sequence_str)}, B-factors extracted={len(pos_values)}")
    logger.debug(f"parse_b_factor: B-factor values: {list(pos_values.items())[:5]}...")  # Show first 5 for brevity
    return sequence_str, pos_values

def parse_evolutionary(file_path):
    """
    Parse an mj_analysis stability_scores.txt file.

    Layout::

        Label Score Difference
        M1C -456.89159 -0.0
        ...

    Each mutation label ``<native><index><mutant>`` (case-insensitive)
    contributes its Difference to the native residue at that index. The "wt"
    row, the header, short lines and unrecognised labels are ignored. A
    non-numeric Difference is skipped with a warning. If one index is seen
    with different native residues, the first one encountered is kept.

    Returns:
        (sequence_str, pos_values): native residues ordered by index, and
        {1-based sequential position: mean Difference}.
        ("", {}) if the file is missing or has no usable lines.
    """
    path = Path(file_path)
    if not path.is_file():
        logger.debug(f"parse_evolutionary: File not found {path}")
        return "", {}

    label_pattern = re.compile(r'([A-Z])(\d+)([A-Z])', re.IGNORECASE)
    grouped = defaultdict(list)
    with open(path, 'r') as handle:
        for raw in handle:
            text = raw.strip()
            if not text or text.startswith("Label"):
                continue
            fields = text.split()
            if len(fields) < 3:
                continue
            try:
                diff = float(fields[2])
            except ValueError:
                logger.warning(f"parse_evolutionary: Invalid difference value in line: {text}")
                continue
            label = fields[0]
            if label.lower() == "wt":
                continue
            hit = label_pattern.match(label)
            if hit is not None:
                grouped[(hit.group(1).upper(), int(hit.group(2)))].append(diff)

    if not grouped:
        logger.debug(f"parse_evolutionary: No valid lines in {path}")
        return "", {}

    # First (native, index) key inserted for an index wins.
    chosen = {}
    for (native, index), diffs in grouped.items():
        if index not in chosen:
            chosen[index] = (native, sum(diffs)/len(diffs) if diffs else 0.0)

    ordered_indices = sorted(chosen)
    sequence_str = "".join(chosen[index][0] for index in ordered_indices)
    pos_values = {rank: chosen[index][1] for rank, index in enumerate(ordered_indices, start=1)}
    return sequence_str, pos_values

def parse_rmsf(file_path):
    """
    Parse an RMSF file from experimental_data (rmsf*.csv or rmsf*.txt).

    Each data line has at least three fields: residue number, a second column
    (ignored), and the RMSF value, e.g.::

        26,A,2.472
        27,A,2.308
        28,A,2.657

    Lines containing a comma are split on commas. Other lines are split on
    whitespace, which covers tab/space-delimited rmsf*.txt files (also
    accepted by validate_directory_structure). Lines that cannot be parsed,
    such as a header, are skipped with a warning.

    Residue numbers are shifted by a constant so the lowest becomes 1
    (26 -> 1, 27 -> 2, ...). Gaps in the numbering are therefore kept in the
    keys. No residue identities are available, so the sequence is a
    placeholder of one 'A' per parsed line.

    Returns:
        (sequence_str, pos_values) with pos_values = {shifted_pos: RMSF};
        ("", {}) if the file is missing or has no valid lines.
    """
    file_path = Path(file_path)
    if not file_path.is_file():
        logger.debug(f"parse_rmsf: File not found {file_path}")
        return "", {}

    lines = []
    with open(file_path, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            # Comma-separated lines are split exactly as before (empty fields
            # preserved); otherwise fall back to whitespace for .txt exports.
            parts = line.split(',') if ',' in line else line.split()
            if len(parts) < 3:
                logger.warning(f"parse_rmsf: Skipping malformed line {line_number}: {line}")
                continue
            try:
                old_idx = int(parts[0])
                rmsf_val = float(parts[2])
            except ValueError:
                logger.warning(f"parse_rmsf: Skipping invalid line {line_number}: {line}")
                continue
            lines.append((old_idx, rmsf_val))

    if not lines:
        logger.debug("parse_rmsf: No valid lines in RMSF file.")
        return "", {}

    # Sort by the original residue index just in case
    lines_sorted = sorted(lines, key=lambda x: x[0])

    # Constant shift: if the first index is 26, the offset is 25.
    # NOTE(review): with gaps, keys exceed len(sequence_str); confirm that
    # downstream code expects offset-based (not sequential) keys.
    offset = lines_sorted[0][0] - 1
    pos_values = {old_idx - offset: val for (old_idx, val) in lines_sorted}
    # Dummy 'A' per line, used only so the source can take part in alignment.
    sequence_str = "A" * len(lines_sorted)
    return sequence_str, pos_values

def parse_mutation_scores(file_path):
    """
    Turn a mutation_scores.txt table into per-position Mutability values.

    The first line is a header and is skipped. Remaining rows are split on
    whitespace and must have at least these columns:
       segment  mutant  pos  wt  subs  frequency  column_conservation  effect_prediction_epistatic
    Rows with too few fields or unparsable pos/effect values are logged and
    skipped. For every distinct ``pos`` the Mutability is the negated mean of
    its ``effect_prediction_epistatic`` values.

    Returns:
        sequence_str (str): upper-cased wild-type residue of each position, in
            ascending ``pos`` order.
        pos_values (dict): 1-based index into ``sequence_str`` -> negated mean effect.
        ("", {}) when the file is missing or no row could be parsed.
    """
    path = Path(file_path)
    if not path.is_file():
        logger.debug(f"parse_mutation_scores: File not found {path}")
        return "", {}

    effects_by_pos = defaultdict(list)
    wildtype_by_pos = {}
    with open(path, 'r') as handle:
        handle.readline()  # discard the header row
        for row_no, raw in enumerate(handle, start=2):
            row = raw.strip()
            if not row:
                continue
            fields = re.split(r'\s+', row)
            if len(fields) < 8:
                logger.warning(f"parse_mutation_scores: Skipping malformed line {row_no}: {row}")
                continue
            try:
                position = int(fields[2])
                residue = fields[3].upper()
                score = float(fields[7])
            except ValueError:
                logger.warning(f"parse_mutation_scores: Error parsing line {row_no}: {row}")
                continue
            effects_by_pos[position].append(score)
            wildtype_by_pos[position] = residue

    if not effects_by_pos:
        return "", {}

    ordered_positions = sorted(effects_by_pos)
    sequence_str = "".join(wildtype_by_pos[p] for p in ordered_positions)
    # Keys are sequential indices into sequence_str, not the original 'pos' numbers.
    pos_values = {
        rank: -(sum(effects_by_pos[p]) / len(effects_by_pos[p]))
        for rank, p in enumerate(ordered_positions, start=1)
    }
    logger.debug(f"parse_mutation_scores: Parsed mutation score sequence length={len(sequence_str)}")
    return sequence_str, pos_values

# -------------------------------------------------------------------------
# 4) Alignment Functions
# -------------------------------------------------------------------------
def align_two(seqA, seqB, gap_open=-2, gap_extend=-0.5):
    """
    Globally align two sequences with Biopython ``pairwise2``.

    Both inputs are upper-cased first. Scoring is +2 for identical characters,
    -1 for a mismatch, plus the given gap-open / gap-extend penalties.

    Args:
        seqA, seqB: sequences to align.
        gap_open: penalty for opening a gap (negative number).
        gap_extend: penalty for extending a gap (negative number).

    Returns:
        (alignedA, alignedB): equal-length strings with '-' for gaps. If one
        input is empty, the other is returned against an all-gap string; if both
        are empty, ("", "").
    """
    seqA = seqA.upper()
    seqB = seqB.upper()

    if len(seqA) == 0 and len(seqB) == 0:
        logger.debug("align_two: Both sequences empty => '' ")
        return "", ""
    if len(seqA) == 0:
        logger.debug(f"align_two: SeqA empty, SeqB length={len(seqB)} => trivial alignment")
        return "-" * len(seqB), seqB
    if len(seqB) == 0:
        logger.debug(f"align_two: SeqB empty, SeqA length={len(seqA)} => trivial alignment")
        return seqA, "-" * len(seqA)

    logger.debug(f"align_two: Attempting global alignment: len(seqA)={len(seqA)}, len(seqB)={len(seqB)}")
    # Only the top alignment is used, so ask pairwise2 for a single one:
    # enumerating every co-optimal alignment can blow up in time and memory for
    # low-complexity input (e.g. the all-'A' placeholder sequence built for RMSF).
    alignments = pairwise2.align.globalms(
        seqA, seqB, 2, -1, gap_open, gap_extend, one_alignment_only=True
    )
    if not alignments:
        logger.warning("align_two: No alignment from Biopython => trivial fallback.")
        # Pad the shorter sequence at the end so both have equal length.
        if len(seqA) >= len(seqB):
            return seqA, seqB + "-" * (len(seqA) - len(seqB))
        return seqA + "-" * (len(seqB) - len(seqA)), seqB

    best = alignments[0]
    return best[0], best[1]

def merge_val_alignment(alnA, alnB, valA, valB):
    """
    Lay per-residue values onto the columns of a pairwise alignment.

    ``valA`` / ``valB`` map 1-based residue positions of the *ungapped*
    sequences to values. Walking the columns of ``alnA`` left to right, each
    non-gap character consumes the next residue position and takes its value
    ('n/a' when that position has no value); a '-' column gets 'n/a'.

    Returns:
        (aligned_valsA, aligned_valsB): lists with one entry per column of ``alnA``.
    """
    n_cols = len(alnA)

    def _project(aligned, values):
        # Assign values column by column, advancing the residue counter only on residues.
        projected = []
        residue_no = 1
        for col in range(n_cols):
            if aligned[col] == '-':
                projected.append('n/a')
                continue
            projected.append(values.get(residue_no, 'n/a'))
            residue_no += 1
        return projected

    return _project(alnA, valA), _project(alnB, valB)

# -------------------------------------------------------------------------
# 5) Multiple Sequence Alignment
# -------------------------------------------------------------------------
def multi_align_sequences(seq_list):
    """
    Build a multiple alignment by merging pairwise alignments against a master.

    The first entry of ``seq_list`` is the master. Every other sequence is
    aligned with ``align_two`` against the *ungapped* master, and the pairwise
    alignments are merged through the master residues (center-star), so all
    rows share the same columns:

    * each master residue owns one column;
    * residues a sequence inserts between two master residues go into
      insertion columns at that spot, as many as the longest insertion any
      sequence makes there. Shorter insertions are left-aligned and padded
      with '-'; insertions from different sequences are not aligned to each
      other.

    Args:
        seq_list: [(name, seq_str, val_dict), ...]; val_dict maps the 1-based
            position of a residue in seq_str to its value.

    Returns:
        [(name, aln_seq, aln_vals), ...] in input order, master first. All
        aln_seq / aln_vals have the same length; gap columns carry 'n/a'.
    """
    if not seq_list:
        logger.debug("multi_align_sequences: No sequences to align.")
        return []

    master_name, master_seq, master_vals = seq_list[0]
    if len(seq_list) == 1:
        vals = [master_vals.get(i, 'n/a') for i in range(1, len(master_seq) + 1)]
        return [(master_name, master_seq, vals)]

    master_residues = master_seq.upper()  # align_two upper-cases its inputs
    n_master = len(master_residues)

    # Per non-master sequence, derived from its pairwise alignment with the master:
    #   on_master[k] -> (char, value) aligned to master residue k, or None for a gap
    #   inserts[s]   -> [(char, value), ...] falling in slot s, i.e. before master
    #                   residue s (slot n_master = after the last master residue)
    layouts = []
    max_insert = [0] * (n_master + 1)
    for name, seq, vals in seq_list[1:]:
        logger.debug(f"multi_align: Aligning MASTER({master_name}) with {name}")
        aln_master, aln_other = align_two(master_seq, seq)
        on_master = [None] * n_master
        inserts = [[] for _ in range(n_master + 1)]
        master_idx = 0  # index of the next master residue
        other_pos = 1   # 1-based position of the next residue of `seq`
        for cm, co in zip(aln_master, aln_other):
            item = None
            if co != '-':
                item = (co, vals.get(other_pos, 'n/a'))
                other_pos += 1
            if cm != '-':
                on_master[master_idx] = item
                master_idx += 1
            elif item is not None:
                inserts[master_idx].append(item)
        for slot, block in enumerate(inserts):
            max_insert[slot] = max(max_insert[slot], len(block))
        layouts.append((name, on_master, inserts))

    def _pad(chars, values, width):
        # Append `width` gap columns to a row.
        chars.extend('-' * width)
        values.extend(['n/a'] * width)

    master_chars, master_row_vals = [], []
    rows = [([], []) for _ in layouts]
    for slot in range(n_master + 1):
        width = max_insert[slot]
        if width:
            _pad(master_chars, master_row_vals, width)
            for (_, _, inserts), (chars, values) in zip(layouts, rows):
                block = inserts[slot]
                chars.extend(c for c, _ in block)
                values.extend(v for _, v in block)
                _pad(chars, values, width - len(block))
        if slot == n_master:
            break
        master_chars.append(master_residues[slot])
        master_row_vals.append(master_vals.get(slot + 1, 'n/a'))
        for (_, on_master, _), (chars, values) in zip(layouts, rows):
            item = on_master[slot]
            if item is None:
                _pad(chars, values, 1)
            else:
                chars.append(item[0])
                values.append(item[1])

    final_data = [(master_name, ''.join(master_chars), master_row_vals)]
    for (name, _, _), (chars, values) in zip(layouts, rows):
        final_data.append((name, ''.join(chars), values))
    return final_data

def map_ss_to_alignment(ss_maps, residue_seq, aligned_dict):
    """
    Map per-structure secondary-structure assignments onto alignment columns.

    Args:
        ss_maps: {frustratometer_name: ss_map}, where ss_map maps the 1-based
            residue index of that source to 'A', 'B' or 'o'.
        residue_seq: aligned master sequence (defines the number of columns).
        aligned_dict: {source_name: (aln_seq, aln_vals)}.

    Returns:
        {'SecondaryStructure_<frustratometer_name>': [ss per column]}. A column
        gets 'n/a' where the source has a gap or its value is 'n/a'; a residue
        index missing from ss_map gets 'o'. Sources with an empty ss_map, or
        absent from aligned_dict, get 'n/a' for every column.
    """
    ss_result = {}
    n_cols = len(residue_seq)

    for frustratometer_name, ss_map in ss_maps.items():
        column = f'SecondaryStructure_{frustratometer_name}'
        if not ss_map:
            ss_result[column] = ['n/a'] * n_cols
            continue

        if frustratometer_name not in aligned_dict:
            logger.warning(f"Frustratometer source {frustratometer_name} not found in alignment for SS mapping.")
            ss_result[column] = ['n/a'] * n_cols
            continue

        aln_seq, aln_vals = aligned_dict[frustratometer_name]
        ss_list = ['n/a'] * n_cols
        ss_pos = 1  # 1-based index of the next residue of this source
        for i, src_res in enumerate(aln_seq[:n_cols]):
            if src_res == '-':
                continue
            if aln_vals[i] != 'n/a':
                ss_list[i] = ss_map.get(ss_pos, 'o')
            # Advance for every residue of the source, even one without a value,
            # so the following residues keep their own SS assignment.
            ss_pos += 1

        ss_result[column] = ss_list

    return ss_result

# -------------------------------------------------------------------------
# 6) PDB Parsing Function
# -------------------------------------------------------------------------
def parse_secondary_structure_from_pdb(pdb_file_path, dssp_path='/opt/homebrew/bin/mkdssp'):
    """
    Assign a coarse secondary-structure class to each residue DSSP reports for a PDB file.

    Runs DSSP (Biopython ``Bio.PDB.DSSP``) on the first model. Residues are
    ordered by chain ID, residue number and insertion code, then numbered
    1, 2, 3, ... in that order. DSSP codes are collapsed to:
      'A' -- helix (H, G, I)
      'B' -- strand / bridge (E, B)
      'o' -- anything else

    Args:
        pdb_file_path: path to the PDB file.
        dssp_path: mkdssp executable passed to DSSP (defaults to the Homebrew
            location this function previously hard-coded).

    Returns:
        ss_map (dict): 1-based sequential position -> 'A', 'B' or 'o'. Empty if
        the file is missing or parsing/DSSP fails (the error is logged).
    """
    ss_map = {}
    pdb_file = Path(pdb_file_path)

    if not pdb_file.is_file():
        logger.error(f"PDB file not found: {pdb_file_path}")
        return ss_map

    parser = PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('protein', pdb_file)
        model = structure[0]
        chains = list(model.get_chains())
        logger.debug(f"Found chains: {[chain.id for chain in chains]}")

        dssp = DSSP(model, str(pdb_file), dssp=dssp_path)
        logger.debug(f"DSSP successful, found {len(dssp.keys())} residues")

        # DSSP keys are (chain_id, (hetflag, resseq, icode)); order by chain,
        # then residue number, then insertion code so chains are not interleaved.
        sorted_keys = sorted(dssp.keys(), key=lambda k: (k[0], k[1][1], k[1][2]))

        for seq_pos, key in enumerate(sorted_keys, start=1):
            ss = dssp[key][2]  # DSSP secondary-structure code
            if ss in ('H', 'G', 'I'):
                ss_code = 'A'  # Alpha-helix
            elif ss in ('E', 'B'):
                ss_code = 'B'  # Beta-sheet
            else:
                ss_code = 'o'  # Other/loop
            ss_map[seq_pos] = ss_code
            logger.debug(f"Assigned SS for sequential position {seq_pos}: {ss} -> {ss_code}")

    except Exception as e:
        logger.error(f"Error processing PDB: {str(e)}")
        return ss_map

    logger.info(f"Successfully parsed {len(ss_map)} residues with SS assignments from {pdb_file_path}")
    return ss_map

# -------------------------------------------------------------------------
# 7) Main Processing Function
# -------------------------------------------------------------------------
def _align_values_to_master(residue_seq, other_seq, other_vals):
    """Project values of ``other_seq`` onto the columns of the aligned master ``residue_seq``.

    ``other_seq`` is aligned (``align_two``) against the master with its '-'
    columns removed, and the result is mapped back onto the master's columns,
    so the returned list has exactly ``len(residue_seq)`` entries. Master gap
    columns, master residues with no partner, and partners whose position has
    no value get 'n/a'. Residues of ``other_seq`` falling between master
    residues are dropped.
    """
    master_cols = [i for i, c in enumerate(residue_seq) if c != '-']
    ungapped_master = ''.join(residue_seq[i] for i in master_cols)
    aln_master, aln_other = align_two(ungapped_master, other_seq)

    projected = ['n/a'] * len(residue_seq)
    master_idx = 0  # index into master_cols
    other_pos = 1   # 1-based residue position in other_seq
    for cm, co in zip(aln_master, aln_other):
        if cm != '-':
            if co != '-':
                projected[master_cols[master_idx]] = other_vals.get(other_pos, 'n/a')
            master_idx += 1
        if co != '-':
            other_pos += 1
    return projected


def process_directory(root_directory, output_dir=None):
    """
    Build a per-residue summary CSV for every valid subdirectory of ``root_directory``.

    For each (directory, sources) pair returned by ``get_valid_directories``:
      1. parse every frustratometer / frustratometer_af / b_factor / rmsf /
         mj_analysis source into (sequence, {position: value});
      2. extract secondary structure from the PDB linked to each frustratometer;
      3. align all parsed sequences with ``multi_align_sequences`` (the first
         parsed source is the master);
      4. add a "Mutability" column from evc_output/couplings/mutation_scores.txt,
         projected onto the master's alignment columns;
      5. write the table to ``<dir>/summary.csv`` and to
         ``<summary dir>/summary_<dir name>.csv``.

    Args:
        root_directory: directory whose subdirectories hold the per-protein data.
        output_dir: where to collect the summary CSVs; defaults to
            ``root_directory / "summary_data"`` (created, with parents, if needed).

    Returns:
        Path of the summary directory, or None if it could not be created or no
        valid directories were found. An error in one directory is logged and
        that directory is skipped.
    """
    root_directory = Path(root_directory)

    # Set up summary data directory
    if output_dir:
        summary_data_dir = Path(output_dir)
    else:
        summary_data_dir = root_directory / "summary_data"

    try:
        # parents=True so a nested output_dir whose parents do not exist yet works too.
        summary_data_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Using summary directory: {summary_data_dir}")
    except Exception as e:
        logger.error(f"Failed to create summary directory: {e}")
        return

    # Get valid directories and their available sources
    dir_info = get_valid_directories(root_directory)

    if not dir_info:
        logger.error("No directories with valid data sources found.")
        return

    for dir_path, available_sources in dir_info:
        logger.info(f"Processing: {dir_path}")

        data_sources = []
        ss_maps = {}  # Secondary Structure mappings per frustratometer

        try:
            # Map frustratometer name -> name of the PDB source it was computed from.
            pdb_to_frustratometer = {}
            for src in available_sources:
                if src['type'] == 'pdb':
                    pdb_name = src['path'].name
                    frustratometer_name = src.get('frustratometer')
                    if frustratometer_name:
                        pdb_to_frustratometer[frustratometer_name] = src['name']
                        logger.debug(f"Mapping frustratometer {frustratometer_name} to PDB file {pdb_name}")

            # Parse value-bearing sources.
            for src in available_sources:
                src_type = src['type']
                src_name = src['name']
                src_path = src['path']

                if src_type == 'frustratometer':
                    frustration_summary = src_path / "frustration_summary.txt"
                    seq, vals = parse_frustration_file(str(frustration_summary))
                    if seq:
                        data_sources.append((src_name, seq, vals))
                        logger.debug(f"Parsed frustration source {src_name}: length={len(seq)}")

                elif src_type == 'frustratometer_af':
                    frustration_af_summary = src_path / "frustration_af_summary.txt"
                    seq, vals = parse_frustration_file(str(frustration_af_summary))
                    if seq:
                        data_sources.append((src_name, seq, vals))
                        logger.debug(f"Parsed AlphaFold frustration source {src_name}: length={len(seq)}")

                elif src_type == 'b_factor':
                    seq, vals = parse_b_factor(str(src_path))
                    if seq:
                        data_sources.append((src_name, seq, vals))
                        logger.debug(f"Parsed B-factor source {src_name}: length={len(seq)}, B-factors count={len(vals)}")
                    else:
                        logger.warning(f"B-factor source {src_name} was parsed but returned empty data.")

                elif src_type == 'rmsf':
                    seq, vals = parse_rmsf(str(src_path))
                    if seq:
                        data_sources.append((src_name, seq, vals))
                        logger.debug(f"Parsed RMSF source {src_name}: length={len(seq)}")

                elif src_type == 'mj_analysis':
                    seq, vals = parse_evolutionary(str(src_path))
                    if seq:
                        data_sources.append((src_name, seq, vals))
                        logger.debug(f"Parsed Evolutionary source {src_name}: length={len(seq)}")

            # Secondary structure for each frustratometer, from its PDB.
            for frustratometer_name, pdb_source_name in pdb_to_frustratometer.items():
                pdb_src = next((src for src in available_sources if src['name'] == pdb_source_name), None)
                if pdb_src:
                    ss_map = parse_secondary_structure_from_pdb(str(pdb_src['path']))
                    ss_maps[frustratometer_name] = ss_map
                    if ss_map:
                        logger.info(f"Secondary structure information extracted from {pdb_src['path'].name} for {frustratometer_name}.")
                    else:
                        logger.warning(f"Secondary structure information could not be extracted from {pdb_src['path'].name} for {frustratometer_name}. Assigning 'n/a' to all residues.")
                else:
                    logger.warning(f"No PDB source found for frustratometer {frustratometer_name}")

            if not data_sources:
                logger.warning(f"No valid sequence data found in {dir_path}, skipping.")
                continue

            logger.debug(f"Found {len(data_sources)} valid data sources for alignment")

            aligned = multi_align_sequences(data_sources)
            if not aligned:
                logger.warning(f"Alignment failed for {dir_path}, skipping.")
                continue

            logger.debug(f"Successfully aligned {len(aligned)} sequences")

            # The reference (master) alignment is the first entry.
            residue_seq = aligned[0][1]
            final_len = len(residue_seq)
            logger.debug(f"Final alignment length: {final_len}")

            aligned_dict = {x[0]: (x[1], x[2]) for x in aligned}
            logger.debug(f"Available data types in alignment: {list(aligned_dict.keys())}")

            # Mutability from EVcouplings mutation scores.
            aligned_mut_vals = ['n/a'] * final_len
            mut_scores_file = dir_path / "evc_output" / "couplings" / "mutation_scores.txt"
            if mut_scores_file.is_file():
                mut_seq, mut_vals = parse_mutation_scores(str(mut_scores_file))
                if mut_seq:
                    logger.debug(f"Parsed mutation scores: sequence length {len(mut_seq)}")
                    # Project onto the master's columns so the column has exactly
                    # final_len entries; aligning against the gapped master could add
                    # columns and make the DataFrame construction below fail.
                    aligned_mut_vals = _align_values_to_master(residue_seq, mut_seq, mut_vals)
                else:
                    logger.warning("Mutation scores file parsed but returned empty sequence.")
            else:
                logger.warning("Mutation scores file not found.")

            ss_assignments = map_ss_to_alignment(ss_maps, residue_seq, aligned_dict)

            summary_dict = {
                'AlnIndex': list(range(1, final_len + 1)),
                'Residue': list(residue_seq),
            }
            summary_dict.update(ss_assignments)
            for source_name, (_, vals_aln) in aligned_dict.items():
                summary_dict[source_name] = vals_aln
            summary_dict["Mutability"] = aligned_mut_vals

            df_summary = pd.DataFrame(summary_dict)

            # Column order: index, residue, SS columns, Mutability, then data sources.
            columns_order = ['AlnIndex', 'Residue']
            columns_order.extend(sorted(col for col in df_summary.columns if col.startswith('SecondaryStructure_')))
            columns_order.append("Mutability")
            columns_order.extend(sorted(col for col in df_summary.columns if col not in columns_order))
            df_summary = df_summary[columns_order]

            # Write to summary.csv inside each directory
            out_path = dir_path / "summary.csv"
            df_summary.to_csv(out_path, index=False, na_rep='n/a')
            logger.info(f"Wrote summary to {out_path}")

            # Also write to the central summary_data directory
            summary_path = summary_data_dir / f"summary_{dir_path.name}.csv"
            df_summary.to_csv(summary_path, index=False, na_rep='n/a')
            logger.info(f"Wrote summary to {summary_path}")

        except Exception as e:
            # logger.exception keeps the traceback, which the bare message loses.
            logger.exception(f"Error processing {dir_path}: {str(e)}")
            continue

    return summary_data_dir

# -------------------------------------------------------------------------
# 9) Example Usage
# -------------------------------------------------------------------------
if __name__ == "__main__":
    # Point root_dir at the directory whose subdirectories hold the per-protein data.
    root_dir = ""
    if root_dir:
        summary_dir = process_directory(root_dir)
    else:
        # Path("") resolves to the current working directory, so running with the
        # empty placeholder would process the CWD and create ./summary_data there.
        logger.error("root_dir is not set; edit it before running process_directory().")

Generate violin plots of the distributions of Spearman correlation coefficients between frustration and B-factor for 20F (Figure 7).

In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import spearmanr, levene, bartlett, kruskal, mannwhitneyu
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import combinations
from statsmodels.stats.multitest import multipletests

def compute_spearman(x, y):
    """
    Spearman rank correlation of two pandas Series over rows where both are present.

    Returns None when fewer than two complete rows remain, when either side is
    constant (zero standard deviation), when scipy raises, or when the
    coefficient is NaN; otherwise returns the coefficient.
    """
    both_present = x.notna() & y.notna()
    if both_present.sum() < 2:
        return None

    xs, ys = x[both_present], y[both_present]
    if xs.std() == 0 or ys.std() == 0:
        return None  # rank correlation is undefined for a constant series

    try:
        rho = spearmanr(xs, ys)[0]
        return None if np.isnan(rho) else rho
    except Exception:
        return None

def process_data(data_dir):
    """
    Compute B-factor vs frustration Spearman correlations for every table in ``data_dir``.

    Files ending in .txt are read tab-separated and .csv comma-separated, with
    'n/a' / 'N/A' treated as missing. A file must contain B_FACTOR_1,
    B_FACTOR_2, EXP_FRUST_1, EXP_FRUST_2 and EVOL_FRUST (coerced to numeric);
    otherwise it is skipped with a printed message. Read/parse errors are
    printed and the file is skipped.

    Per file, these pairs are correlated in this order:
      B_FACTOR_1 vs EXP_FRUST_1, B_FACTOR_1 vs EXP_FRUST_2,
      B_FACTOR_2 vs EXP_FRUST_1, B_FACTOR_2 vs EXP_FRUST_2,
      B_FACTOR_1 vs EVOL_FRUST,  B_FACTOR_2 vs EVOL_FRUST
    Pairs whose correlation cannot be computed are omitted.

    Returns:
        DataFrame with columns Protein (file name), Pair, Spearman_Correlation.
    """
    needed = ['B_FACTOR_1', 'B_FACTOR_2', 'EXP_FRUST_1', 'EXP_FRUST_2', 'EVOL_FRUST']
    bfactor_cols = ('B_FACTOR_1', 'B_FACTOR_2')
    # Experimental frustration pairs first, then evolutionary ones.
    comparisons = (
        [(b, f) for b in bfactor_cols for f in ('EXP_FRUST_1', 'EXP_FRUST_2')]
        + [(b, 'EVOL_FRUST') for b in bfactor_cols]
    )

    rows = []
    for fname in os.listdir(data_dir):
        if not fname.endswith(('.txt', '.csv')):
            continue
        try:
            delimiter = '\t' if fname.endswith('.txt') else ','
            table = pd.read_csv(os.path.join(data_dir, fname), sep=delimiter, na_values=['n/a', 'N/A'])
            absent = [c for c in needed if c not in table.columns]
            if absent:
                print(f"Skipping {fname}: Missing columns {absent}")
                continue
            for c in needed:
                table[c] = pd.to_numeric(table[c], errors='coerce')
            for b, f in comparisons:
                rho = compute_spearman(table[b], table[f])
                if rho is not None:
                    rows.append({
                        'Protein': fname,
                        'Pair': f"{b} vs {f}",
                        'Spearman_Correlation': rho,
                    })
        except Exception as e:
            print(f"Error processing {fname}: {e}")

    return pd.DataFrame(rows)

def pairwise_mannwhitney_corrected(df, group_col, value_col, alpha=0.05, correction='bonferroni'):
    """
    Two-sided Mann–Whitney U tests for every pair of groups, corrected for multiple testing.

    Args:
        df: DataFrame with one row per observation.
        group_col: column holding the group label.
        value_col: column with the values to compare (NaNs dropped per group).
        alpha: family-wise error rate passed to ``multipletests``.
        correction: ``multipletests`` method name (default 'bonferroni').

    Returns:
        DataFrame with columns Group1, Group2, Statistic, p-value,
        Adjusted p-value, Reject H0 -- one row per pair of groups. With fewer
        than two groups the DataFrame is empty but still has those columns.
    """
    columns = ['Group1', 'Group2', 'Statistic', 'p-value', 'Adjusted p-value', 'Reject H0']
    groups = df[group_col].unique()

    results = []
    for group1, group2 in combinations(groups, 2):
        data1 = df.loc[df[group_col] == group1, value_col].dropna()
        data2 = df.loc[df[group_col] == group2, value_col].dropna()
        stat, p = mannwhitneyu(data1, data2, alternative='two-sided')
        results.append({'Group1': group1, 'Group2': group2, 'Statistic': stat, 'p-value': p})

    if not results:
        # Nothing to compare; multipletests cannot be called with zero p-values.
        return pd.DataFrame(columns=columns)

    adjusted = multipletests([r['p-value'] for r in results], alpha=alpha, method=correction)
    reject, adjusted_pvals = adjusted[0], adjusted[1]
    for res, p_adj, rej in zip(results, adjusted_pvals, reject):
        res['Adjusted p-value'] = p_adj
        res['Reject H0'] = rej
    return pd.DataFrame(results, columns=columns)

def add_all_mw_annotations(ax, mw_df, order, significance_level=0.05):
    """
    Annotate significant pairwise Mann–Whitney U results with brackets beside the violins.

    Every pair of groups in ``order`` is looked up in ``mw_df`` (in either
    Group1/Group2 orientation). A pair whose 'Adjusted p-value' is not below
    ``significance_level`` is skipped; otherwise a dark-red bracket is drawn
    between the two groups' y positions (group i sits at y = i), with a
    rotated "p = ..." label. Pairs are handled shortest span first, and each
    bracket goes to the lowest horizontal level where it does not overlap a
    bracket already placed at that level. Offsets are fractions of the right
    x-limit. Finally the figure's right margin is set to 0.85.
    """
    row_of = {label: i for i, label in enumerate(order)}
    right_edge = ax.get_xlim()[1]
    gap = right_edge * 0.10         # extra shift per level
    arm = right_edge * 0.02         # length of the bracket's horizontal arms
    label_pad = right_edge * 0.015  # space between bracket and its label

    candidate_pairs = sorted(
        ((order[i], order[j]) for i in range(len(order)) for j in range(i + 1, len(order))),
        key=lambda pr: (row_of[pr[1]] - row_of[pr[0]], row_of[pr[0]]),
    )

    placed = []  # (y_low, y_high, level) of brackets already drawn
    for first, second in candidate_pairs:
        hit = mw_df[((mw_df['Group1'] == first) & (mw_df['Group2'] == second)) |
                    ((mw_df['Group1'] == second) & (mw_df['Group2'] == first))]
        if hit.empty:
            continue
        p_val = hit['Adjusted p-value'].values[0]
        if p_val >= significance_level:
            continue

        lo, hi = row_of[first], row_of[second]
        level = 0
        while any(lvl == level and not (hi <= a or lo >= b) for a, b, lvl in placed):
            level += 1
        placed.append((lo, hi, level))

        x_start = right_edge + level * (arm + gap)
        x_end = x_start + arm
        for y in (lo, hi):
            ax.plot([x_start, x_end], [y, y], lw=1.5, c='darkred')
        ax.plot([x_end, x_end], [lo, hi], lw=1.5, c='darkred')
        marker = '*' if p_val < significance_level else ''
        ax.text(x_end + label_pad, (lo + hi) / 2,
                f'p = {p_val:.3f}{marker}',
                ha='left', va='center', fontsize=10, color='darkred', rotation=270)

    plt.subplots_adjust(right=0.85)

def create_violin_plot(df, mw_results, kruskal_stat, kruskal_p,
                       group_col='Pair', value_col='Spearman_Correlation'):
    """
    Horizontal violin plot of Spearman correlations per B-factor/frustration pair (Figure 7).

    Args:
        df: long-format DataFrame with ``group_col`` and ``value_col`` columns.
        mw_results: output of ``pairwise_mannwhitney_corrected``; significant
            pairs are annotated with brackets.
        kruskal_stat, kruskal_p: Kruskal–Wallis result shown in the top-left box.
        group_col, value_col: column names for the category and the value.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can ``.show()`` or ``.savefig()``.
    """
    # Apply style/context before creating the figure: axes-level settings are
    # read when the axes is created, so setting them afterwards made the first
    # run in a session look different from re-runs.
    plt.style.use('seaborn-v0_8-whitegrid')
    sns.set_context("paper", font_scale=1.5)
    fig, ax = plt.subplots(figsize=(10, 6), dpi=600)

    order = [
        "B_FACTOR_1 vs EXP_FRUST_1",
        "B_FACTOR_2 vs EXP_FRUST_1",
        "B_FACTOR_1 vs EXP_FRUST_2",
        "B_FACTOR_2 vs EXP_FRUST_2",
        "B_FACTOR_1 vs EVOL_FRUST",
        "B_FACTOR_2 vs EVOL_FRUST"
    ]

    palette_colors = sns.color_palette("Blues", n_colors=len(order))
    palette = dict(zip(order, palette_colors))

    ax.grid(True, axis='x', linestyle='--', alpha=0.7, zorder=0)

    # Vertical dotted line at each group's mean, drawn beneath the violins.
    for pair in order:
        mean_val = df[df[group_col] == pair][value_col].mean()
        ax.axvline(x=mean_val,
                   color=palette[pair],
                   linestyle=':',
                   linewidth=2,
                   alpha=0.8,
                   zorder=1)

    sns.violinplot(
        data=df,
        x=value_col,
        y=group_col,
        order=order,
        inner='box',
        palette=palette,
        linewidth=1.5,
        zorder=2,
        ax=ax
    )

    ax.set_title('Distribution of Spearman Correlations\n(B-factor vs. Frustration Metrics)',
                 pad=20, fontsize=16, fontweight='bold')
    ax.set_xlabel('Spearman Correlation', labelpad=15, fontsize=14, fontweight='bold')
    ax.set_ylabel('', labelpad=15, fontsize=14, fontweight='bold')
    ax.tick_params(axis='both', which='major', labelsize=12)

    for spine in ax.spines.values():
        spine.set_linewidth(1.5)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15, left=0.15, right=0.85, top=0.85)

    # Annotate significant pairwise MW tests
    add_all_mw_annotations(ax, mw_results, order, significance_level=0.05)

    fig.text(0.01, 0.01,
             'Annotations: Only significant pairwise Mann–Whitney U tests (Bonferroni corrected) are shown as brackets to the right.',
             ha='left', va='bottom', fontsize=8, style='italic')

    # Kruskal–Wallis summary box
    if kruskal_p < 0.05:
        significance_comment = "Distributions are significantly different."
    else:
        significance_comment = "No significant differences in distributions."
    kruskal_text = f'Kruskal–Wallis Test:\nH = {kruskal_stat:.2f}, p = {kruskal_p:.3e}\n{significance_comment}'
    props = dict(boxstyle='round', facecolor='white', alpha=0.0, linewidth=1.5)
    fig.text(0.00, 0.95, kruskal_text, fontsize=12,
             verticalalignment='top', bbox=props, ha='left')

    return plt
    
def test_distribution_equality(df):
    """
    Kruskal–Wallis H-test for equal distributions of 'Spearman_Correlation' across 'Pair' groups.

    NaNs are dropped within each group and empty groups are ignored. Prints the
    result and returns ``(statistic, p_value)``. If fewer than two non-empty
    groups remain, or scipy rejects the data (e.g. all values identical), a
    message is printed and ``(nan, nan)`` is returned instead of raising.
    """
    groups = [group['Spearman_Correlation'].dropna() for _, group in df.groupby('Pair')]
    groups = [g for g in groups if len(g) > 0]
    print("Kruskal–Wallis Test for Equality of Distributions")
    if len(groups) < 2:
        print("Result: Not enough groups to test (need at least two non-empty groups)")
        return float('nan'), float('nan')
    try:
        stat, p_value = kruskal(*groups)
    except ValueError as e:
        print(f"Result: Test could not be computed ({e})")
        return float('nan'), float('nan')
    print(f"Statistic: {stat:.4f}, p-value: {p_value:.4e}")
    if p_value < 0.05:
        print("Result: Significant differences (reject H0)")
    else:
        print("Result: No significant differences (fail to reject H0)")
    return stat, p_value


# The name matches pytest's "test_*" pattern; mark it so pytest never collects it as a test.
test_distribution_equality.__test__ = False

def main():
    """
    Figure 7 pipeline: correlate, test and plot B-factor vs frustration Spearman correlations.

    Reads every table in DATA_DIR (see ``process_data``), prints per-pair
    summary statistics, a Kruskal–Wallis test and Bonferroni-corrected pairwise
    Mann–Whitney U tests, then shows the annotated violin plot.
    """
    # Set your data directory
    DATA_DIR = ''
    # os.listdir('') raises FileNotFoundError; report the misconfiguration clearly instead.
    if not DATA_DIR or not os.path.isdir(DATA_DIR):
        print(f"DATA_DIR is not set or is not a directory: {DATA_DIR!r}")
        return

    results_df = process_data(DATA_DIR)
    if results_df.empty:
        print("No valid data found in the specified directory.")
        return

    print("\nSummary Statistics:")
    summary_stats = results_df.groupby('Pair')['Spearman_Correlation'].agg(['mean', 'std', 'count'])
    print(summary_stats)

    print("\n--- Distribution Equality Test (Kruskal–Wallis Test) ---")
    kruskal_stat, kruskal_p = test_distribution_equality(results_df)

    print("\n--- Pairwise Distribution Comparisons (Mann–Whitney U Tests) with Bonferroni Correction ---")
    mw_results = pairwise_mannwhitney_corrected(
        df=results_df,
        group_col='Pair',
        value_col='Spearman_Correlation',
        alpha=0.05,
        correction='bonferroni'
    )
    print(mw_results)

    plt_obj = create_violin_plot(results_df, mw_results, kruskal_stat, kruskal_p)
    plt_obj.show()

# Run the Figure 7 analysis when this cell/script is executed directly.
if __name__ == "__main__":
    main()

Generate violin plots of pairwise Spearman correlation differences (Figure 6).

In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import spearmanr, wilcoxon
import seaborn as sns
import matplotlib.pyplot as plt

# -----------------------------
# Step 1. Functions to compute Spearman correlations per file
# -----------------------------

def compute_spearman(x, y):
    """
    Spearman rank correlation of two Series over the positions where both are present.

    Returns None instead of a number when fewer than two paired values remain, when
    either side is constant (zero standard deviation), when scipy raises, or when the
    coefficient comes back as NaN.
    """
    both_present = x.notna() & y.notna()
    if both_present.sum() <= 1:
        return None

    xs = x[both_present]
    ys = y[both_present]
    # A constant input has no ranks to correlate; bail out before scipy warns.
    if xs.std() == 0 or ys.std() == 0:
        return None

    try:
        rho, _ = spearmanr(xs, ys)
        return None if np.isnan(rho) else rho
    except Exception:
        return None

def process_data(data_dir):
    """
    Compute six Spearman correlations per protein from every .txt/.csv file in *data_dir*.

    Expects files with at least the following columns:
      - B_FACTOR_1
      - B_FACTOR_2
      - EXP_FRUST_1
      - EXP_FRUST_2
      - EVOL_FRUST

    For each protein (file), the following correlations are computed, in this order:
      1. B_FACTOR_1 vs EXP_FRUST_1
      2. B_FACTOR_1 vs EXP_FRUST_2
      3. B_FACTOR_2 vs EXP_FRUST_1
      4. B_FACTOR_2 vs EXP_FRUST_2
      5. B_FACTOR_1 vs EVOL_FRUST
      6. B_FACTOR_2 vs EVOL_FRUST

    '.txt' files are read tab-separated, '.csv' files comma-separated, with 'n/a'/'N/A'
    treated as missing. The filename is used as the protein identifier. Files are
    visited in sorted order so the output is reproducible (os.listdir order is
    arbitrary). Files missing a required column are skipped with a message; files that
    fail to load are reported and skipped. Correlations for which compute_spearman
    returns None are omitted.

    Returns a DataFrame with columns: Protein, Pair, Spearman_Correlation
    (an empty DataFrame when nothing could be computed).
    """
    required_cols = ['B_FACTOR_1', 'B_FACTOR_2', 'EXP_FRUST_1', 'EXP_FRUST_2', 'EVOL_FRUST']
    b_factors = ['B_FACTOR_1', 'B_FACTOR_2']
    # Experimental pairs first (each B-factor with each experimental frustration),
    # then each B-factor with evolutionary frustration.
    pairs = [(b, f) for b in b_factors for f in ('EXP_FRUST_1', 'EXP_FRUST_2')]
    pairs += [(b, 'EVOL_FRUST') for b in b_factors]

    results = []
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith(('.txt', '.csv')):
            continue
        filepath = os.path.join(data_dir, filename)
        # Best-effort per file: one bad file must not abort the whole batch.
        try:
            sep = '\t' if filename.endswith('.txt') else ','
            df = pd.read_csv(filepath, sep=sep, na_values=['n/a', 'N/A'])
            missing = [col for col in required_cols if col not in df.columns]
            if missing:
                print(f"Skipping {filename}: Missing columns {missing}")
                continue
            for col in required_cols:
                df[col] = pd.to_numeric(df[col], errors='coerce')

            protein_id = filename
            for b, f in pairs:
                corr = compute_spearman(df[b], df[f])
                if corr is not None:
                    results.append({
                        'Protein': protein_id,
                        'Pair': f"{b} vs {f}",
                        'Spearman_Correlation': corr
                    })
        except Exception as e:
            print(f"Error processing {filename}: {e}")

    return pd.DataFrame(results)

# -----------------------------
# Step 2. Reorganize the data and compute differences per protein
# -----------------------------

def compute_difference_df(corr_df):
    """
    Turn per-protein correlations into evolutionary-minus-experimental differences.

    *corr_df* has one row per (Protein, Pair) with a 'Spearman_Correlation' value. It is
    pivoted to one row per protein. Any of the six expected Pair columns that never occur
    are added as NaN. Then, for each B-factor replicate k in {1, 2} and experimental
    replicate j in {1, 2}:

        diff_Bk_EXPj = (B_FACTOR_k vs EVOL_FRUST) - (B_FACTOR_k vs EXP_FRUST_j)

    Returns a long-format DataFrame with columns Protein, Diff_Type, Difference, grouped
    by Diff_Type (order B1_EXP1, B1_EXP2, B2_EXP1, B2_EXP2). Rows whose difference is
    NaN are dropped.
    """
    wide = corr_df.pivot(index='Protein', columns='Pair', values='Spearman_Correlation')

    expected_pairs = [
        "B_FACTOR_1 vs EXP_FRUST_1", "B_FACTOR_1 vs EXP_FRUST_2", "B_FACTOR_1 vs EVOL_FRUST",
        "B_FACTOR_2 vs EXP_FRUST_1", "B_FACTOR_2 vs EXP_FRUST_2", "B_FACTOR_2 vs EVOL_FRUST",
    ]
    absent = [pair for pair in expected_pairs if pair not in wide.columns]
    for pair in absent:
        wide[pair] = np.nan

    # difference name -> (B-factor column, experimental frustration column)
    diff_spec = {
        'diff_B1_EXP1': ('B_FACTOR_1', 'EXP_FRUST_1'),
        'diff_B1_EXP2': ('B_FACTOR_1', 'EXP_FRUST_2'),
        'diff_B2_EXP1': ('B_FACTOR_2', 'EXP_FRUST_1'),
        'diff_B2_EXP2': ('B_FACTOR_2', 'EXP_FRUST_2'),
    }
    for diff_name, (bf, ef) in diff_spec.items():
        wide[diff_name] = wide[f"{bf} vs EVOL_FRUST"] - wide[f"{bf} vs {ef}"]

    diff_names = list(diff_spec)
    long_df = wide[diff_names].reset_index().melt(
        id_vars="Protein",
        value_vars=diff_names,
        var_name="Diff_Type",
        value_name="Difference",
    )
    return long_df.dropna(subset=["Difference"])

# -----------------------------
# Step 3. Plotting and Statistical Testing (Horizontal Violins)
# -----------------------------

def plot_difference_violins(diff_long):
    """
    Plot horizontal violins of the per-protein correlation differences (one per
    Diff_Type) and annotate each with a one-sample Wilcoxon signed-rank p-value
    (H0: the median difference is 0).

    Layering (zorder): zero reference line (0) < per-group mean lines (1) < violins (2).

    Parameters
    ----------
    diff_long : pandas.DataFrame
        Long-format frame with columns 'Diff_Type' and 'Difference'.

    Returns
    -------
    fig : matplotlib Figure
    stat_results : dict
        Maps each Diff_Type to (statistic, p_value). Both are NaN when scipy raised
        ValueError for that group.
    """
    sns.set(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 7), dpi=300)

    # 1) Exact order of the four categories
    order = ["diff_B1_EXP1", "diff_B1_EXP2", "diff_B2_EXP1", "diff_B2_EXP2"]

    # 2) One shade of a blue gradient per category
    palette = dict(zip(order, sns.color_palette("Blues", n_colors=len(order))))

    # 3) Dotted line at each group's mean (zorder=1)
    for diff_type in order:
        mean_val = diff_long.loc[
            diff_long["Diff_Type"] == diff_type, "Difference"
        ].mean()
        ax.axvline(
            x=mean_val,
            color=palette[diff_type],
            linestyle=":",
            linewidth=2,
            alpha=0.8,
            zorder=1
        )

    # 4) Zero reference line at the lowest zorder
    ax.axvline(0, ls="--", color="gray", lw=1, zorder=0)

    # 5) Violins on top (zorder=2)
    sns.violinplot(
        data=diff_long,
        x="Difference",
        y="Diff_Type",
        order=order,
        inner="box",
        palette=palette,
        zorder=2,
        ax=ax
    )

    # 6) Title, labels, ticks
    ax.set_title("Paired Spearman Correlation Differences\n(Evolutionary - Experimental)",
                 fontsize=16, fontweight="bold", pad=15)
    ax.set_xlabel("Difference in Spearman Correlation", fontsize=14, fontweight="bold")
    ax.set_ylabel("", fontsize=14, fontweight="bold")
    ax.tick_params(axis="x", labelsize=12)
    ax.set_yticklabels(order, fontsize=12)

    # 7) Wilcoxon tests and p-value annotations near each group's median
    stat_results = {}
    for y_pos, diff_type in enumerate(order):
        group = diff_long.loc[diff_long["Diff_Type"] == diff_type, "Difference"]
        try:
            stat, p = wilcoxon(group)
        except ValueError:
            # scipy raises ValueError e.g. when every difference is exactly zero;
            # record the test as not computable.
            stat, p = np.nan, np.nan
        stat_results[diff_type] = (stat, p)

        median_val = group.median()
        if np.isnan(median_val):
            # No observations in this group: there is no violin to anchor a label to.
            continue

        if np.isnan(p):
            # NaN means the test could not be run, which is not "not significant".
            p_label = "p = n/a"
        elif p < 0.001:
            # Fixed 3-decimal output would print very small p-values as "p = 0.000".
            p_label = f"p = {p:.1e}"
        else:
            p_label = f"p = {p:.3f}"
        ax.text(
            median_val,
            y_pos + 0.25,
            p_label,
            ha="center",
            va="bottom",
            fontsize=12
        )

    # 8) Caption
    fig.text(
        0.5, 0.02,
        "P-values (from Wilcoxon signed-rank tests) indicate whether the median difference\nis significantly different from 0.",
        ha="center", fontsize=10, style="italic"
    )

    plt.tight_layout(rect=[0, 0.05, 1, 1])
    return fig, stat_results
# -----------------------------
# Main driver function
# -----------------------------

def main():
    """
    Driver for the difference analysis.

    Computes per-protein correlations with process_data, converts them to four
    evolutionary-minus-experimental differences with compute_difference_df, shows the
    violin figure from plot_difference_violins and prints each Wilcoxon result.
    Prints a message and returns early when no data is found.
    """
    # Set your data directory
    data_dir = ''

    correlations = process_data(data_dir)
    if correlations.empty:
        print("No valid data found in the specified directory.")
        return

    print("\nPer-protein Spearman correlations:")
    print(correlations.groupby("Protein").size())

    # Four differences per protein
    differences = compute_difference_df(correlations)
    print("\nDifference data (first few rows):")
    print(differences.head())

    # Violin plots plus per-group significance tests
    _fig, tests = plot_difference_violins(differences)
    plt.show()

    print("\nWilcoxon test results for each difference type:")
    for name, (statistic, p_value) in tests.items():
        print(f"{name}: statistic = {statistic:.3f}, p-value = {p_value:.3e}")


if __name__ == "__main__":
    main()

Plotting for the 20F set of proteins (Supplementary Figures 21–40)

In [None]:
import os  
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
from scipy.stats import spearmanr
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
import matplotlib.gridspec as gridspec
from matplotlib.backends.backend_pdf import PdfPages 

########################################
# 1) BASIC SETUP AND HELPER FUNCTIONS  #
########################################

def _extract_replicate(df, column_map):
    """Select one replicate's columns from *df*, rename them per *column_map*, NaN-fill absent ones and coerce numeric fields."""
    present = [src for src in column_map if src in df.columns]
    rep_df = df[present].rename(columns={src: column_map[src] for src in present})
    # Guarantee the full target schema even when the file lacks some source columns.
    for target in column_map.values():
        if target not in rep_df.columns:
            rep_df[target] = np.nan
    for col in ('B_Factor', 'ExpFrust'):
        rep_df[col] = pd.to_numeric(rep_df[col], errors='coerce')
    return rep_df


def read_frustration_file(filepath, file_type='summary'):
    """
    Read a CSV summary file holding data for both replicates (REP1 and REP2).

    Parameters
    ----------
    filepath : str
        Path to the CSV; 'n/a' entries are read as missing.
    file_type : str
        Only 'summary' is supported.

    Returns
    -------
    (rep1_df, rep2_df, evol_frust)
        rep1_df / rep2_df have the columns AlnIndex, Residue, SecondaryStructure,
        B_Factor and ExpFrust. They are taken from the *_1 / *_2 source columns, and
        source columns absent from the file become all-NaN. B_Factor and ExpFrust are
        numeric, with unparsable entries set to NaN. evol_frust is the numeric
        EVOL_FRUST column, or an all-NaN Series of the same length if the column is
        absent.

    Raises
    ------
    ValueError
        If *file_type* is not 'summary'.
    """
    if file_type != 'summary':
        raise ValueError("Unsupported file type. Only 'summary' is supported.")

    df = pd.read_csv(filepath, na_values=['n/a'])

    rep1_df = _extract_replicate(df, {
        'AlnIndex': 'AlnIndex',
        'Residue': 'Residue',
        'SecondaryStructure_EXP_FRUST_1': 'SecondaryStructure',
        'B_FACTOR_1': 'B_Factor',
        'EXP_FRUST_1': 'ExpFrust'
    })
    rep2_df = _extract_replicate(df, {
        'AlnIndex': 'AlnIndex',
        'Residue': 'Residue',
        'SecondaryStructure_EXP_FRUST_2': 'SecondaryStructure',
        'B_FACTOR_2': 'B_Factor',
        'EXP_FRUST_2': 'ExpFrust'
    })

    if 'EVOL_FRUST' in df.columns:
        evol_frust = df['EVOL_FRUST']
    else:
        evol_frust = pd.Series([np.nan] * len(df))
    evol_frust = pd.to_numeric(evol_frust, errors='coerce')

    return rep1_df, rep2_df, evol_frust

def lowess_smoothing(x, y, frac=0.1, it=3):
    """
    LOWESS-smooth *y* against *x* after discarding positions where either is missing.

    Returns (x_kept, y_smoothed): the retained x values in their input order and the
    smoothed y evaluated at those same positions (return_sorted=False keeps the
    order). Two empty numpy arrays are returned when nothing remains.
    """
    keep = pd.notna(x) & pd.notna(y)
    x_kept = x[keep]
    y_kept = y[keep]

    if not len(x_kept):
        return np.array([]), np.array([])

    smoothed = sm.nonparametric.lowess(y_kept, x_kept, frac=frac, it=it, return_sorted=False)
    return x_kept, smoothed

def create_gradient_line(x, y, values, cmap, linestyle='-', linewidth=3):
    """
    Build a LineCollection for the polyline (x, y) with segments colored by *values*.

    Segment i (point i -> point i+1) takes the color of values[i]. The color
    normalization spans the full min..max of *values*, including the last point,
    which has no segment of its own. Without this, autoscaling over values[:-1] can
    shift the mapping whenever the last point is an extreme. A colormap built from
    that same range (e.g. create_custom_cmap(values.min(), values.max())) then keeps
    its intended color at the intended value, matching create_dashed_gradient_line.

    Returns None when fewer than two points are given.
    """
    if len(x) < 2:
        return None
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    all_values = np.asarray(values)
    lc = LineCollection(
        segments,
        cmap=cmap,
        norm=plt.Normalize(np.nanmin(all_values), np.nanmax(all_values)),
        linestyle=linestyle,
        linewidth=linewidth,
    )
    lc.set_array(all_values[:-1])
    return lc

def create_dashed_gradient_line(x, y, values, cmap, linewidth=3, dash_on=10, dash_off=5):
    """
    Create a single dashed gradient line.

    The dash pattern is generated by hand: every visible dash becomes its own solid
    segment in a LineCollection, so each dash can carry its own color.

    Parameters
    ----------
    x, y : array-like
        Polyline vertices in data coordinates.
    values : array-like
        One value per vertex. Each dash is colored by the value linearly interpolated
        (by cumulative path length) at the dash's midpoint.
    cmap : colormap handed to LineCollection.
    linewidth : width of every dash.
    dash_on, dash_off : lengths of the visible and hidden parts of the repeating
        pattern, in the same units as the path length. The pattern starts at path
        length 0 and continues across vertices.

    Returns
    -------
    LineCollection or None
        None when fewer than two vertices are given or no visible dash results.
        Colors are normalized to values.min()..values.max().

    NOTE(review): path length is the Euclidean distance in raw data units, so x and y
    units are mixed and the on-screen dash length depends on the axes' scaling.
    NOTE(review): with dash_off == 0 the off-phase step below jumps a whole extra
    pattern cycle (leaving gaps), and dash_on + dash_off == 0 would raise
    ZeroDivisionError in the modulo. Use dash_on > 0 and dash_off > 0.
    """
    if len(x) < 2:
        return None
    x = np.asarray(x)
    y = np.asarray(y)
    v = np.asarray(values)
    
    # Cumulative path length at each vertex (dist[0] == 0).
    dx = np.diff(x)
    dy = np.diff(y)
    seg_lengths = np.sqrt(dx*dx + dy*dy)
    dist = np.concatenate(([0], np.cumsum(seg_lengths)))
    
    # Per-vertex values linearly interpolated along the path length.
    def color_at_distance(d_val):
        return np.interp(d_val, dist, v)
    
    pattern_length = dash_on + dash_off
    
    # Return the (start, end) path-length intervals inside [s1, s2] that fall in an
    # "on" phase of the pattern (which repeats every pattern_length from 0).
    def get_on_subsegments(s1, s2):
        segments_on = []
        current = s1
        while current < s2:
            # Offset within the current cycle and where this cycle's "on" phase ends.
            cycle_pos = (current % pattern_length)
            cycle_on_end = current - cycle_pos + dash_on
            
            # Currently in the "off" phase: jump to the start of the next cycle.
            if cycle_on_end <= current:
                next_cycle_start = current - cycle_pos + pattern_length
                current = next_cycle_start
                continue
            
            # Emit the visible part, clipped to the end of this polyline segment.
            seg_start = current
            seg_end = min(cycle_on_end, s2)
            if seg_end > seg_start:
                segments_on.append((seg_start, seg_end))
            
            # Skip over the "off" phase of this cycle (capped at s2).
            current = seg_end
            cycle_off_end = current - (current % pattern_length) + pattern_length
            # Defensive: for pattern_length > 0 this looks unreachable, since
            # current % pattern_length < pattern_length.
            if cycle_off_end < current:
                cycle_off_end += pattern_length
            
            current = max(current, min(cycle_off_end, s2))
        
        return segments_on

    all_on_segments = []
    color_values = []
    
    for i in range(len(x) - 1):
        s1 = dist[i]
        s2 = dist[i+1]
        # Zero-length segment (repeated vertex): nothing to draw.
        if s2 == s1:
            continue
        
        on_subs = get_on_subsegments(s1, s2)
        if not on_subs:
            continue
        
        for (s_on_start, s_on_end) in on_subs:
            # Map path-length positions to fractions t of this segment, then to x/y.
            t1 = (s_on_start - s1) / (s2 - s1)
            x1 = x[i] + t1 * (x[i+1] - x[i])
            y1 = y[i] + t1 * (y[i+1] - y[i])
            
            t2 = (s_on_end - s1) / (s2 - s1)
            x2 = x[i] + t2 * (x[i+1] - x[i])
            y2 = y[i] + t2 * (y[i+1] - y[i])
            
            # Color the whole dash by the interpolated value at its midpoint.
            mid = 0.5*(s_on_start + s_on_end)
            c_mid = color_at_distance(mid)
            
            all_on_segments.append([[x1, y1], [x2, y2]])
            color_values.append(c_mid)
    
    if not all_on_segments:
        return None
    
    lc = LineCollection(
        all_on_segments,
        cmap=cmap,
        norm=plt.Normalize(v.min(), v.max()),
        linewidth=linewidth,
        linestyles='solid'
    )
    lc.set_array(np.array(color_values))
    return lc

def create_custom_cmap(vmin, vmax):
    """
    Create a custom colormap that transitions through gray at zero.

    Stops: '#0c1359' (navy) at vmin, '#D0D0D0' (gray) at data value 0 and '#f05b05'
    (orange) at vmax. This assumes the colormap is applied with a linear normalization
    over [vmin, vmax].

    The gray stop sits at fraction (0 - vmin) / (vmax - vmin). When the range does not
    contain zero, that fraction lies outside [0, 1] and is clamped to the nearer end:
    all-positive data runs gray -> orange and all-negative data runs navy -> gray. The
    previous |vmin| / (|vmin| + |vmax|) formula gave the same result for ranges that
    straddle zero but put gray at a non-zero value otherwise. A degenerate or
    non-finite range puts the gray stop in the middle.
    """
    span = vmax - vmin
    if np.isfinite(span) and span > 0:
        zero_pos = float(np.clip(-vmin / span, 0.0, 1.0))
    else:
        zero_pos = 0.5
    colors = [
        (0, '#0c1359'),
        (zero_pos, '#D0D0D0'),
        (1, '#f05b05')
    ]
    return LinearSegmentedColormap.from_list("custom", colors, N=100)

def read_binding_sites(exp_data_path):
    """
    Read residue numbers from ``binding_sites.txt`` inside *exp_data_path*.

    A line is used when, after stripping surrounding whitespace, it begins with one of
    the 20 standard upper-case three-letter amino-acid codes and its second
    whitespace-separated field is an integer, e.g. ``"ALA 123"`` -> 123. All other
    lines are ignored.

    Parameters
    ----------
    exp_data_path : str
        Directory expected to contain ``binding_sites.txt``.

    Returns
    -------
    set of int
        The residue numbers found (a set, not a dictionary). Empty when the file does
        not exist. If the file cannot be opened or decoded, the error is printed and an
        empty set is returned.
    """
    residue_codes = ('ALA', 'CYS', 'ASP', 'GLU', 'PHE', 'GLY', 'HIS', 'ILE', 'LYS', 'LEU',
                     'MET', 'ASN', 'PRO', 'GLN', 'ARG', 'SER', 'THR', 'VAL', 'TRP', 'TYR')
    binding_sites_file = os.path.join(exp_data_path, "binding_sites.txt")
    if not os.path.exists(binding_sites_file):
        return set()

    binding_residues = set()
    try:
        with open(binding_sites_file, 'r') as f:
            for line in f:
                stripped = line.strip()
                if not stripped.startswith(residue_codes):
                    continue
                parts = stripped.split()
                if len(parts) < 2:
                    continue
                try:
                    binding_residues.add(int(parts[1]))
                except ValueError:
                    # Second field is not a residue number; skip the line.
                    continue
    except (OSError, UnicodeError) as e:
        print(f"Error reading binding sites file: {e}")
        return set()

    return binding_residues
    
def create_helix(x_start, width, height=0.5, frequency=2):
    """
    Sample a sine wave over [x_start, x_start + width] to draw an alpha helix.

    Uses int(width * 20) evenly spaced samples. The curve completes *frequency* full
    periods across the span with amplitude *height* about y = 0.

    Returns (x, y) as numpy arrays.
    """
    n_samples = int(width * 20)
    x_end = x_start + width
    xs = np.linspace(x_start, x_end, n_samples)
    angular = 2 * np.pi * frequency
    ys = height * np.sin(angular * (xs - x_start) / width)
    return xs, ys

def create_arrow(x_start, width, height=0.5):
    """
    Outline an arrow for beta strands as parallel x / y vertex lists.

    The shaft runs from x_start to x_start + 0.7 * width between y = -height/2 and
    y = +height/2; the tip is at (x_start + width, 0). Some vertices are repeated, and
    the path returns to (x_start, -height/2).

    Returns (x, y) as two lists of 10 coordinates each.
    """
    tail = x_start
    shoulder = x_start + 0.7 * width
    tip = x_start + width
    top = height / 2
    bottom = -height / 2
    xs = [tail, tail, tail, shoulder, shoulder, tip, shoulder, shoulder, tail, tail]
    ys = [bottom, top, bottom, bottom, bottom, 0, top, top, top, bottom]
    return xs, ys

def create_scatter_subplot(ax, x_data, y_data, color, title, xlabel, ylabel, marker='o'):
    """
    Plot y against x on *ax* in rank space, annotated with Spearman's rho and p-value.

    Steps:
      1. Wrap both inputs as Series with a fresh 0..n-1 index.
      2. If *title* contains 'Evol' / 'Evolutionary' or *ylabel* contains
         'Evolutionary', drop points whose y value is exactly 0.
      3. Drop pairs where either value is missing.
      4. With fewer than two pairs left, write "Insufficient data" plus the title and stop.
      5. Otherwise scatter the ranks and add a dashed gray regression line when both
         rank vectors vary. Annotate rho / p in the upper-left corner and set the
         title and labels. Any exception here is printed as a warning and an error
         note is drawn on *ax*.
    """
    xs = pd.Series(x_data).reset_index(drop=True)
    ys = pd.Series(y_data).reset_index(drop=True)

    drop_zero_y = ('Evol' in title or 'Evolutionary' in title
                   or 'Evolutionary' in ylabel)
    if drop_zero_y:
        nonzero = ys != 0
        xs = xs[nonzero].reset_index(drop=True)
        ys = ys[nonzero].reset_index(drop=True)

    complete = pd.notna(xs) & pd.notna(ys)
    x_clean = xs[complete]
    y_clean = ys[complete]

    if len(x_clean) < 2:
        ax.text(0.5, 0.5, "Insufficient data",
                ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title, fontsize=16, pad=20)
        return

    try:
        x_rank, y_rank = x_clean.rank(), y_clean.rank()
        rho, pval = spearmanr(x_clean, y_clean)
        sns.scatterplot(x=x_rank, y=y_rank, ax=ax, color=color, alpha=0.6,
                        marker=marker, linewidth=2, s=100)
        both_vary = (x_rank.nunique(dropna=False) > 1
                     and y_rank.nunique(dropna=False) > 1)
        if both_vary:
            sns.regplot(x=x_rank, y=y_rank, ax=ax, scatter=False,
                        color='gray', line_kws={'linestyle': '--', 'alpha': 0.8})
        ax.text(0.05, 0.95, f"ρ = {rho:.3f}\np = {pval:.2e}", transform=ax.transAxes,
                verticalalignment='top', fontsize=12, color='black',
                bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
        ax.set_title(title, fontsize=16, pad=20)
        ax.set_xlabel(xlabel, fontsize=14)
        ax.set_ylabel(ylabel, fontsize=14)
        ax.tick_params(labelsize=12)
    except Exception as e:
        print(f"Warning: Error in scatter plot creation: {e}")
        ax.text(0.5, 0.5, "Error in plot creation",
                ha='center', va='center', transform=ax.transAxes)

########################################
# 2) MAIN PLOTTING FUNCTION           #
########################################

def plot_frustration_comparison(summary_filepath, 
                                box_height_ratio=0.05, 
                                spacing_ratio=0.075, 
                                additional_space_ratio=0.30, 
                                box_padding_ratio=0.02, 
                                legend_separation_ratio=-0.05):
    """
    Create a comprehensive plot comparing protein frustration data for REP1 and REP2.
    """
    sns.set_style("whitegrid")
    plt.rcParams.update({
        'figure.figsize': (20, 50),
        'font.size': 14,
        'axes.labelsize': 14,
        'axes.titlesize': 16
    })
    
    rep1_data, rep2_data, evol_frust = read_frustration_file(summary_filepath)
    fig = plt.figure(figsize=(20, 50))
    # Update the grid to 9 rows (adding two new full-width rows for the Spearman scatter plots)
    gs = gridspec.GridSpec(9, 2, 
                          height_ratios=[3, 2, 2, 2, 3, 2, 2, 2, 2],
                          width_ratios=[1, 1], 
                          hspace=0.4, 
                          wspace=0.3)
                          
    ax_summary = fig.add_subplot(gs[4, :])
    ax_main = fig.add_subplot(gs[0, :])
    
    has_ss = 'SecondaryStructure' in rep1_data.columns and 'SecondaryStructure' in rep2_data.columns

    merged_data = rep1_data.merge(rep2_data, on='AlnIndex', suffixes=('_REP1', '_REP2'))
    merged_data['EvolFrust'] = evol_frust
    # Reset the index so that later boolean masks align properly.
    merged_data = merged_data.reset_index(drop=True)
    
    complete_data_mask = (
        ~merged_data['ExpFrust_REP1'].isna() &
        ~merged_data['ExpFrust_REP2'].isna() &
        ~merged_data['EvolFrust'].isna() &
        ~merged_data['B_Factor_REP1'].isna() &
        ~merged_data['B_Factor_REP2'].isna()
    )
    
    merged_data_filtered = merged_data[complete_data_mask].reset_index(drop=True)
    if merged_data_filtered.empty or len(merged_data_filtered) < 5:
        raise ValueError("Insufficient complete data to generate plot.")
    
    rep1_x_exp, rep1_smooth_exp = lowess_smoothing(merged_data_filtered['AlnIndex'], 
                                                merged_data_filtered['ExpFrust_REP1'])
    rep2_x_exp, rep2_smooth_exp = lowess_smoothing(merged_data_filtered['AlnIndex'], 
                                                merged_data_filtered['ExpFrust_REP2'])
    evol_x, evol_smooth = lowess_smoothing(merged_data_filtered['AlnIndex'], 
                                          merged_data_filtered['EvolFrust'])
    
    rep1_x_bf, rep1_smooth_bf = lowess_smoothing(merged_data_filtered['AlnIndex'], 
                                                merged_data_filtered['B_Factor_REP1'])
    rep2_x_bf, rep2_smooth_bf = lowess_smoothing(merged_data_filtered['AlnIndex'], 
                                                merged_data_filtered['B_Factor_REP2'])
    
    default_y_min, default_y_max = -2, 2
    all_y = np.concatenate([rep1_smooth_exp, rep2_smooth_exp, evol_smooth])
    finite_mask = np.isfinite(all_y)
    
    try:
        if np.any(finite_mask):
            y_min = float(np.nanmin(all_y[finite_mask]))
            y_max = float(np.nanmax(all_y[finite_mask]))
            if not (np.isfinite(y_min) and np.isfinite(y_max)):
                y_min, y_max = default_y_min, default_y_max
        else:
            y_min, y_max = default_y_min, default_y_max
            
        y_range = y_max - y_min
        y_padding = y_range * 0.05
        plot_y_min = y_min - y_padding
        plot_y_max = y_max + y_padding + additional_space_ratio * y_range
        
        if legend_separation_ratio < 0:
            plot_y_min += y_range * legend_separation_ratio
        elif legend_separation_ratio > 0:
            plot_y_max += y_range * legend_separation_ratio
        
        if not (np.isfinite(plot_y_min) and np.isfinite(plot_y_max)):
            plot_y_min, plot_y_max = default_y_min, default_y_max
            
    except Exception as e:
        print(f"Error calculating plot limits: {e}")
        plot_y_min, plot_y_max = default_y_min, default_y_max

    if has_ss:
        y_max_extended = plot_y_max + (y_max - y_min) * additional_space_ratio
        ax_main.set_ylim(plot_y_min, y_max_extended)
        ss_colors = {
            'A': '#800080',
            'B': '#008080',
            'O': '#808080'
        }
        box_height = y_range * box_height_ratio
        spacing = y_range * spacing_ratio
        box_padding = y_range * box_padding_ratio
        
        rep1_box_bottom = y_max_extended - spacing - box_height  
        rep2_box_bottom = rep1_box_bottom - spacing - box_height  
        
        ax_main.add_patch(Rectangle((merged_data_filtered['AlnIndex'].min(), rep1_box_bottom),
                                     merged_data_filtered['AlnIndex'].max() - merged_data_filtered['AlnIndex'].min(),
                                     box_height + box_padding,
                                     facecolor='#f5f5f5',
                                     edgecolor='#d3d3d3',
                                     alpha=1.0,
                                     zorder=2))
        
        ax_main.add_patch(Rectangle((merged_data_filtered['AlnIndex'].min(), rep2_box_bottom),
                                     merged_data_filtered['AlnIndex'].max() - merged_data_filtered['AlnIndex'].min(),
                                     box_height + box_padding,
                                     facecolor='#f5f5f5',
                                     edgecolor='#d3d3d3',
                                     alpha=1.0,
                                     zorder=2))
        
        ss_symbol_height = box_height / 2  
        prev_ss_rep1 = None
        start_idx_rep1 = None
        
        for idx, ss in zip(merged_data_filtered['AlnIndex'], merged_data_filtered['SecondaryStructure_REP1']):
            if pd.isna(ss):
                continue
            if ss != prev_ss_rep1:
                if prev_ss_rep1 is not None:
                    width = idx - start_idx_rep1
                    if width <= 0:
                        width = 1
                    if prev_ss_rep1 == 'A':
                        x_helix, y_helix = create_helix(start_idx_rep1, width, height=ss_symbol_height)
                        ax_main.plot(x_helix, y_helix + rep1_box_bottom + (box_height + box_padding)/2, 
                                    color=ss_colors.get('A', '#800080'), linewidth=2, zorder=4)
                    elif prev_ss_rep1 == 'B':
                        x_arrow, y_arrow = create_arrow(start_idx_rep1, width, height=ss_symbol_height)
                        ax_main.plot(x_arrow, np.array(y_arrow) + rep1_box_bottom + (box_height + box_padding)/2, 
                                    color=ss_colors.get('B', '#008080'), linewidth=2, zorder=4)
                    else:
                        ax_main.plot([start_idx_rep1, idx], 
                                    [rep1_box_bottom + (box_height + box_padding)/2, rep1_box_bottom + (box_height + box_padding)/2], 
                                    color=ss_colors.get('O', '#808080'), linewidth=1, zorder=4)
                start_idx_rep1 = idx
                prev_ss_rep1 = ss
        
        if prev_ss_rep1 is not None:
            width = merged_data_filtered['AlnIndex'].iloc[-1] - start_idx_rep1 + 1
            if width <= 0:
                width = 1
            if prev_ss_rep1 == 'A':
                x_helix, y_helix = create_helix(start_idx_rep1, width, height=ss_symbol_height)
                ax_main.plot(x_helix, y_helix + rep1_box_bottom + (box_height + box_padding)/2, 
                            color=ss_colors.get('A', '#800080'), linewidth=2, zorder=4)
            elif prev_ss_rep1 == 'B':
                x_arrow, y_arrow = create_arrow(start_idx_rep1, width, height=ss_symbol_height)
                ax_main.plot(x_arrow, np.array(y_arrow) + rep1_box_bottom + (box_height + box_padding)/2, 
                            color=ss_colors.get('B', '#008080'), linewidth=2, zorder=4)
            else:
                ax_main.plot([start_idx_rep1, merged_data_filtered['AlnIndex'].iloc[-1]], 
                            [rep1_box_bottom + (box_height + box_padding)/2, rep1_box_bottom + (box_height + box_padding)/2], 
                            color=ss_colors.get('O', '#808080'), linewidth=1, zorder=4)
        
        prev_ss_rep2 = None
        start_idx_rep2 = None
        
        for idx, ss in zip(merged_data_filtered['AlnIndex'], merged_data_filtered['SecondaryStructure_REP2']):
            if pd.isna(ss):
                continue
            if ss != prev_ss_rep2:
                if prev_ss_rep2 is not None:
                    width = idx - start_idx_rep2
                    if width <= 0:
                        width = 1
                    if prev_ss_rep2 == 'A':
                        x_helix, y_helix = create_helix(start_idx_rep2, width, height=ss_symbol_height)
                        ax_main.plot(x_helix, y_helix + rep2_box_bottom + (box_height + box_padding)/2, 
                                    color=ss_colors.get('A', '#800080'), linewidth=2, zorder=3)
                    elif prev_ss_rep2 == 'B':
                        x_arrow, y_arrow = create_arrow(start_idx_rep2, width, height=ss_symbol_height)
                        ax_main.plot(x_arrow, np.array(y_arrow) + rep2_box_bottom + (box_height + box_padding)/2, 
                                    color=ss_colors.get('B', '#008080'), linewidth=2, zorder=3)
                    else:
                        ax_main.plot([start_idx_rep2, idx], 
                                    [rep2_box_bottom + (box_height + box_padding)/2, rep2_box_bottom + (box_height + box_padding)/2], 
                                    color=ss_colors.get('O', '#808080'), linewidth=1, zorder=3)
                start_idx_rep2 = idx
                prev_ss_rep2 = ss
        
        if prev_ss_rep2 is not None:
            width = merged_data_filtered['AlnIndex'].iloc[-1] - start_idx_rep2 + 1
            if width <= 0:
                width = 1
            if prev_ss_rep2 == 'A':
                x_helix, y_helix = create_helix(start_idx_rep2, width, height=ss_symbol_height)
                ax_main.plot(x_helix, y_helix + rep2_box_bottom + (box_height + box_padding)/2, 
                            color=ss_colors.get('A', '#800080'), linewidth=2, zorder=3)
            elif prev_ss_rep2 == 'B':
                x_arrow, y_arrow = create_arrow(start_idx_rep2, width, height=ss_symbol_height)
                ax_main.plot(x_arrow, np.array(y_arrow) + rep2_box_bottom + (box_height + box_padding)/2, 
                            color=ss_colors.get('B', '#008080'), linewidth=2, zorder=3)
            else:
                ax_main.plot([start_idx_rep2, merged_data_filtered['AlnIndex'].iloc[-1]], 
                            [rep2_box_bottom + (box_height + box_padding)/2, rep2_box_bottom + (box_height + box_padding)/2], 
                            color=ss_colors.get('O', '#808080'), linewidth=1, zorder=3)
        
        x_min = merged_data_filtered['AlnIndex'].min()
        x_max = merged_data_filtered['AlnIndex'].max()
        x_label = x_min + 0.02 * (x_max - x_min)
        legend_y = 0.02  
        
        ax_main.text(x_label, rep1_box_bottom + box_height + box_padding, 'REP1', 
                     horizontalalignment='left', verticalalignment='bottom', fontsize=14, color='black', rotation=0, zorder=5)
        ax_main.text(x_label, rep2_box_bottom + box_height + box_padding, 'REP2', 
                     horizontalalignment='left', verticalalignment='bottom', fontsize=14, color='black', rotation=0, zorder=5)
    else:
        ax_main.set_ylim(plot_y_min, plot_y_max)

    cmap_rep1_exp = create_custom_cmap(rep1_smooth_exp.min(), rep1_smooth_exp.max())
    cmap_rep2_exp = create_custom_cmap(rep2_smooth_exp.min(), rep2_smooth_exp.max())
    cmap_evol = create_custom_cmap(evol_smooth.min(), evol_smooth.max())

    rep1_line_exp = create_gradient_line(
        rep1_x_exp, rep1_smooth_exp, rep1_smooth_exp, cmap_rep1_exp,
        linestyle='-', linewidth=4
    )
    if rep1_line_exp:
        ax_main.add_collection(rep1_line_exp)

    rep2_line_exp = create_dashed_gradient_line(
        rep2_x_exp, rep2_smooth_exp, rep2_smooth_exp, cmap_rep2_exp,
        linewidth=4, dash_on=2, dash_off=2
    )
    if rep2_line_exp:
        ax_main.add_collection(rep2_line_exp)

    evol_line = create_gradient_line(
        evol_x, evol_smooth, evol_smooth, cmap_evol,
        linestyle=':', linewidth=2
    )
    if evol_line:
        ax_main.add_collection(evol_line)

    x_min = merged_data_filtered['AlnIndex'].min()
    x_max = merged_data_filtered['AlnIndex'].max()
    ax_main.set_xlim(x_min, x_max)

    ax_main.set_title('Protein Frustration Comparison: REP1 vs REP2', 
                      fontsize=24, fontweight='bold', pad=20)
    ax_main.set_xlabel('Residue Number', fontsize=20, fontweight='bold')
    ax_main.set_ylabel('Frustration', fontsize=20, fontweight='bold')

    legends = []
    line_style_legend = [
        Line2D([0], [0], color='black', linestyle='-', linewidth=4, label='REP1 Experimental'),
        Line2D([0], [0], color='black', linestyle='--', linewidth=4, label='REP2 Experimental'),
        Line2D([0], [0], color='black', linestyle=':', linewidth=2, label='Evolutionary Frustration')
    ]
    legends.append(('Frustration Types', line_style_legend))

    frustration_legend = [
        Line2D([0], [0], color='#0c1359', label='Minimally Frustrated', linewidth=3),
        Line2D([0], [0], color='#D0D0D0', label='Neutral', linewidth=3),
        Line2D([0], [0], color='#f05b05', label='Highly Frustrated', linewidth=3)
    ]
    legends.append(('Frustration Level', frustration_legend))
    
    num_legends = len(legends)
    spacing = 1.0 / (num_legends + 1)
    
    for i, (title, handles) in enumerate(legends):
        x_pos = spacing * (i + 1)
        legend = ax_main.legend(
            handles=handles,
            title=title,
            fontsize=14,
            title_fontsize=16,
            loc='lower center',
            bbox_to_anchor=(x_pos, 0.02),
            frameon=True,
            ncol=1
        )
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_edgecolor('black')
        ax_main.add_artist(legend)

    category_colors = {
        'Experimental Frustration Rep1': '#8B0000',
        'Experimental Frustration Rep2': '#FF4444',
        'Evolutionary Frustration': '#4DAF4A'
    }
    marker_styles = {
        'REP1 B-Factor': 'o',
        'REP2 B-Factor': 'x'
    }

    # ── scatter plots ──
    # left column (REP1 B-Factor)
    mask_evol = merged_data['EvolFrust'] != 0
    temp_evol = merged_data.loc[mask_evol].reset_index(drop=True)

    ax1 = fig.add_subplot(gs[1, 0])
    create_scatter_subplot(
        ax1,
        merged_data['B_Factor_REP1'],
        merged_data['ExpFrust_REP1'],
        category_colors['Experimental Frustration Rep1'],
        'REP1 Experimental vs REP1 B-Factor',
        'REP1 B-Factor Rank',
        'REP1 Exp. Frustration Rank',
        marker=marker_styles['REP1 B-Factor']
    )

    ax2 = fig.add_subplot(gs[2, 0])
    create_scatter_subplot(
        ax2,
        merged_data['B_Factor_REP1'],
        merged_data['ExpFrust_REP2'],
        category_colors['Experimental Frustration Rep2'],
        'REP2 Experimental vs REP1 B-Factor',
        'REP1 B-Factor Rank',
        'REP2 Exp. Frustration Rank',
        marker=marker_styles['REP2 B-Factor']
    )

    ax3 = fig.add_subplot(gs[3, 0])
    create_scatter_subplot(
        ax3,
        temp_evol['B_Factor_REP1'],
        temp_evol['EvolFrust'],
        category_colors['Evolutionary Frustration'],
        'Evolutionary Frustration vs REP1 B-Factor',
        'REP1 B-Factor Rank',
        'Evolutionary Frustration Rank',
        marker='^'
    )

    # right column (REP2 B-Factor)
    ax4 = fig.add_subplot(gs[1, 1])
    create_scatter_subplot(
        ax4,
        merged_data['B_Factor_REP2'],
        merged_data['ExpFrust_REP1'],
        category_colors['Experimental Frustration Rep1'],
        'REP1 Experimental vs REP2 B-Factor',
        'REP2 B-Factor Rank',
        'REP1 Exp. Frustration Rank',
        marker=marker_styles['REP1 B-Factor']
    )

    ax5 = fig.add_subplot(gs[2, 1])
    create_scatter_subplot(
        ax5,
        merged_data['B_Factor_REP2'],
        merged_data['ExpFrust_REP2'],
        category_colors['Experimental Frustration Rep2'],
        'REP2 Experimental vs REP2 B-Factor',
        'REP2 B-Factor Rank',
        'REP2 Exp. Frustration Rank',
        marker=marker_styles['REP2 B-Factor']
    )

    ax6 = fig.add_subplot(gs[3, 1])
    create_scatter_subplot(
        ax6,
        temp_evol['B_Factor_REP2'],
        temp_evol['EvolFrust'],
        category_colors['Evolutionary Frustration'],
        'Evolutionary Frustration vs REP2 B-Factor',
        'REP2 B-Factor Rank',
        'Evolutionary Frustration Rank',
        marker='^'
    )
    # 4. Normalized Smoothed B-Factor Plot (Full Width)
    ax_bfactor_normalized = fig.add_subplot(gs[5, :])
    def normalize_series(series):
        min_val = series.min()
        max_val = series.max()
        if max_val - min_val == 0:
            return pd.Series([0.5] * len(series), index=series.index)
        return (series - min_val) / (max_val - min_val)

    rep1_normalized = normalize_series(pd.Series(rep1_smooth_bf, index=rep1_x_bf.index))
    rep2_normalized = normalize_series(pd.Series(rep2_smooth_bf, index=rep2_x_bf.index))

    ax_bfactor_normalized.plot(rep1_x_bf, rep1_normalized, label='REP1 Normalized B-Factor', color='blue', linewidth=2)
    ax_bfactor_normalized.plot(rep2_x_bf, rep2_normalized, label='REP2 Normalized B-Factor', color='orange', linewidth=2)
    ax_bfactor_normalized.set_title('Normalized Smoothed B-Factor Comparison: REP1 vs REP2', fontsize=18, pad=20)
    ax_bfactor_normalized.set_xlabel('Residue Number', fontsize=14)
    ax_bfactor_normalized.set_ylabel('Normalized B-Factor', fontsize=14)
    ax_bfactor_normalized.legend(loc='upper right', fontsize=12)
    ax_bfactor_normalized.grid(True, alpha=0.3)

    # 5. B-Factor Ranks Scatter Plot (Full Width)
    ax_bfactor_rank_scatter = fig.add_subplot(gs[6, :])
    rep1_bfactor_rank = merged_data_filtered['B_Factor_REP1'].rank()
    rep2_bfactor_rank = merged_data_filtered['B_Factor_REP2'].rank()
    rho_bfactor, pval_bfactor = spearmanr(rep1_bfactor_rank, rep2_bfactor_rank)
    sns.scatterplot(x=rep1_bfactor_rank, y=rep2_bfactor_rank, ax=ax_bfactor_rank_scatter,
                    color='purple', alpha=0.6, marker='D', edgecolor='black', linewidth=0.5, s=100)
    if len(rep1_bfactor_rank.unique()) > 1 and len(rep2_bfactor_rank.unique()) > 1:
        sns.regplot(x=rep1_bfactor_rank, y=rep2_bfactor_rank, ax=ax_bfactor_rank_scatter, scatter=False, 
                   color='gray', line_kws={'linestyle': '--', 'alpha': 0.8})
    corr_text_bfactor = f"ρ = {rho_bfactor:.3f}\np = {pval_bfactor:.2e}"
    ax_bfactor_rank_scatter.text(0.05, 0.95, corr_text_bfactor, transform=ax_bfactor_rank_scatter.transAxes,
                                  verticalalignment='top', fontsize=12, color='black',
                                  bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
    ax_bfactor_rank_scatter.set_title('B-Factor Rank Comparison: REP1 vs REP2', fontsize=18, pad=20)
    ax_bfactor_rank_scatter.set_xlabel('REP1 B-Factor Rank', fontsize=14)
    ax_bfactor_rank_scatter.set_ylabel('REP2 B-Factor Rank', fontsize=14)
    legend_elements_bfactor = [
        Line2D([0], [0], marker='D', color='w', markerfacecolor='purple', markersize=10, label='B-Factor Ranks')
    ]
    ax_bfactor_rank_scatter.legend(handles=legend_elements_bfactor, loc='upper right', fontsize=12)
    ax_bfactor_rank_scatter.grid(True, alpha=0.3)

    # NEW: Two Spearman Scatter Plots comparing Experimental Frustration to Evolutionary Frustration
    mask_evol_filtered = merged_data_filtered['EvolFrust'] != 0
    temp_evol3 = merged_data_filtered.loc[mask_evol_filtered].reset_index(drop=True)
    ax_spearman_rep1 = fig.add_subplot(gs[7, :])
    create_scatter_subplot(ax_spearman_rep1,
                           temp_evol3['ExpFrust_REP1'],
                           temp_evol3['EvolFrust'],
                           category_colors['Experimental Frustration Rep1'],
                           'Evolutionary Frustration vs REP1 Experimental',
                           'REP1 Experimental Frustration Rank',
                           'Evolutionary Frustration Rank',
                           marker='s')  # square marker

    ax_spearman_rep2 = fig.add_subplot(gs[8, :])
    create_scatter_subplot(ax_spearman_rep2,
                           temp_evol3['ExpFrust_REP2'],
                           temp_evol3['EvolFrust'],
                           category_colors['Experimental Frustration Rep2'],
                           'Evolutionary Frustration vs REP2 Experimental',
                           'REP2 Experimental Frustration Rank',
                           'Evolutionary Frustration Rank',
                           marker='^')  # triangle marker

    summary_correlations = []
    metrics = {
        'Experimental Frustration Rep1': merged_data_filtered['ExpFrust_REP1'],
        'Experimental Frustration Rep2': merged_data_filtered['ExpFrust_REP2'],
        'Evolutionary Frustration': merged_data_filtered['EvolFrust']
    }
    
    b_factors = {
        'REP1 B-Factor': merged_data_filtered['B_Factor_REP1'],
        'REP2 B-Factor': merged_data_filtered['B_Factor_REP2']
    }
    
    for metric_name, metric_series in metrics.items():
        for b_factor_name, b_factor_series in b_factors.items():
            # If metric is evolutionary frustration, use less strict filtering (AND INCLUDE ZEROS)
            if metric_name == "Evolutionary Frustration":
                # --- Start Modification ---
                # Use original data, not the pre-filtered series
                evol_series_orig = merged_data['EvolFrust'] # Use the full EvolFrust series including zeros

                # Get the corresponding original B-factor series based on the loop variable 'b_factor_name'
                if b_factor_name == 'REP1 B-Factor':
                    b_factor_series_orig = merged_data['B_Factor_REP1']
                else: # Assumes 'REP2 B-Factor'
                    b_factor_series_orig = merged_data['B_Factor_REP2']

                # ----> NO ZERO FILTERING <----
                # We directly use evol_series_orig and b_factor_series_orig

                # Let spearmanr handle the NaNs within the pair using nan_policy='omit'
                try:
                        # Pass the original series; nan_policy='omit' handles pairs with NaNs
                        rho, pval = spearmanr(evol_series_orig, b_factor_series_orig, nan_policy='omit')

                        # Check if rho is NaN (can happen if < 2 valid pairs after 'omit')
                        if not np.isnan(rho):
                            summary_correlations.append({
                                'Metric': metric_name,
                                'Spearman_rho': rho,
                                'B_Factor': b_factor_name,
                                'pval': pval
                            })
                        # else: spearmanr resulted in NaN (e.g., < 2 valid pairs), so don't append
                except ValueError:
                    # Should ideally be caught by nan_policy, but just in case
                    pass # Do not append if spearmanr fails
                # --- End Modification ---

            else: # Original block for Experimental Frustration 
                # Reset index just in case
                metric_series = metric_series.reset_index(drop=True)
                b_factor_series = b_factor_series.reset_index(drop=True)
                # Check for sufficient non-NaN pairs
                valid_mask = ~metric_series.isna() & ~b_factor_series.isna()
                if valid_mask.sum() >= 2:
                    rho, pval = spearmanr(metric_series[valid_mask], b_factor_series[valid_mask])
                    summary_correlations.append({
                        'Metric': metric_name,
                        'Spearman_rho': rho,
                        'B_Factor': b_factor_name,
                        'pval': pval
                    })

    x_positions = {
        'Experimental Frustration Rep1': 1,
        'Experimental Frustration Rep2': 2,
        'Evolutionary Frustration': 3
    }
    
    for corr in summary_correlations:
        x = x_positions[corr['Metric']]
        y = corr['Spearman_rho']
        pval = corr['pval']
        if corr['B_Factor'] == 'REP1 B-Factor':
            marker = 'o'
            color = category_colors[corr['Metric']]
        else:
            marker = 'x'
            color = category_colors[corr['Metric']]
        ax_summary.scatter(x, y, c=[color], marker=marker, s=200, linewidth=2)
        if pval < 0.05:
            ax_summary.scatter(x, y, facecolors='none', edgecolors='black', linewidth=2, s=500, marker='s', zorder=6)

    ax_summary.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax_summary.grid(True, alpha=0.3)
    ax_summary.set_xlim(0.5, 3.5)
    if summary_correlations:
        y_min_corr = min(corr['Spearman_rho'] for corr in summary_correlations)
        y_max_corr = max(corr['Spearman_rho'] for corr in summary_correlations)
        y_padding_corr = (y_max_corr - y_min_corr) if (y_max_corr - y_min_corr) != 0 else 1
    else:
        y_min_corr, y_max_corr = -1, 1
        y_padding_corr = 0.1
    ax_summary.set_ylim(y_min_corr - y_padding_corr, y_max_corr + y_padding_corr)
    ax_summary.set_xticks([1, 2, 3])
    ax_summary.set_xticklabels(['Experimental Frustration Rep1', 
                                'Experimental Frustration Rep2', 
                                'Evolutionary Frustration'], 
                               fontsize=12, rotation=0, ha='center')
    ax_summary.set_ylabel("Spearman's ρ", fontsize=16)
    ax_summary.spines['top'].set_visible(False)
    ax_summary.spines['right'].set_visible(False)
    ax_summary.yaxis.set_ticks_position('left')
    ax_summary.set_title('Summary of B-Factor Correlations', fontsize=18, pad=20)
    
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', markersize=10, label='Spearman with REP1 B-Factor'),
        Line2D([0], [0], marker='x', color='w', markeredgecolor='gray', markersize=10, label='Spearman with REP2 B-Factor'),
        Line2D([0], [0], marker='s', color='w', markerfacecolor='none', markeredgecolor='black', markersize=12, label='p < 0.05')
    ]
    ax_summary.legend(handles=legend_elements, 
                     loc='lower right',
                     fontsize=12,
                     frameon=True,
                     framealpha=0.9)
    
    for ax in [ax1, ax2, ax3, ax4, ax5, ax6]:
        ax.set_aspect('equal', adjustable='box')
    
    plt.tight_layout()
    return fig

########################################
# 3) PROCESSING ALL SUBDIRECTORIES     #
########################################

# 3) PROCESSING ALL SUBDIRECTORIES
def _remove_stale_frustration_pdfs(subdir_path):
    """Delete PDFs in *subdir_path* whose lowercase file name contains "frustration".

    Best-effort: a file that cannot be removed is reported and left in place
    instead of aborting the whole batch.
    """
    for filename in os.listdir(subdir_path):
        lowered = filename.lower()
        if lowered.endswith('.pdf') and "frustration" in lowered:
            stale_path = os.path.join(subdir_path, filename)
            try:
                os.remove(stale_path)
            except OSError as e:
                print(f"Could not remove stale PDF '{stale_path}': {e}")


def process_all_subdirectories(root_dir,
                               summary_filename="summary.csv",
                               box_height_ratio=0.1,
                               spacing_ratio=0.15,
                               additional_space_ratio=0.295,
                               box_padding_ratio=0.05,
                               legend_separation_ratio=-0.75,
                               first_figure_number=21):
    """
    Render a frustration-comparison figure for every protein subdirectory.

    The immediate subdirectories of *root_dir* are visited in sorted name
    order, so supplemental figure numbers are reproducible across runs and
    filesystems.

    Each subdirectory is handled as follows:

    1. PDFs whose name contains "frustration" are deleted.
    2. If *summary_filename* exists, ``plot_frustration_comparison`` is called
       on it.
    3. The figure is titled "Figure S<n>. <subdir>", with <n> counting up from
       *first_figure_number*.
    4. It is saved as ``<subdir>_frustration_comparison.pdf`` inside the
       subdirectory.
    5. It is also appended as a page of ``all_plots.pdf`` in *root_dir*.

    Parameters
    ----------
    root_dir : str
        Directory whose immediate subdirectories are processed.
    summary_filename : str
        Name of the per-protein summary file looked for in each subdirectory.
    box_height_ratio, spacing_ratio, additional_space_ratio, box_padding_ratio,
    legend_separation_ratio : float
        Layout parameters forwarded unchanged to ``plot_frustration_comparison``.
    first_figure_number : int
        Supplemental-figure number given to the first plotted subdirectory
        (default 21, i.e. "Figure S21").

    Any exception raised while plotting or saving one subdirectory is printed
    and that subdirectory is skipped; the rest of the batch continues.
    """
    # NOTE(review): nothing in this function appends to all_corrs, so the
    # histogram branch at the end is currently unreachable. Populate it with a
    # per-protein Spearman rho (evolutionary vs experimental frustration) to
    # re-enable the histogram.
    all_corrs = []

    figure_number = first_figure_number
    all_plots_path = os.path.join(root_dir, "all_plots.pdf")
    # Opened lazily so no aggregate PDF is created when nothing was plotted.
    aggregate_pdf = None
    n_aggregated = 0

    try:
        # os.listdir order is arbitrary; sort so "Figure S<n>" numbering is stable.
        for entry in sorted(os.listdir(root_dir)):
            subdir_path = os.path.join(root_dir, entry)
            if not os.path.isdir(subdir_path):
                continue

            _remove_stale_frustration_pdfs(subdir_path)

            summary_filepath = os.path.join(subdir_path, summary_filename)
            if not os.path.exists(summary_filepath):
                print(f"Skipping '{entry}': summary file not found.")
                continue

            fig = None
            try:
                # generate the frustration comparison figure
                fig = plot_frustration_comparison(
                    summary_filepath,
                    box_height_ratio=box_height_ratio,
                    spacing_ratio=spacing_ratio,
                    additional_space_ratio=additional_space_ratio,
                    box_padding_ratio=box_padding_ratio,
                    legend_separation_ratio=legend_separation_ratio
                )

                # give it a numbered suptitle, e.g. "Figure S21. 1abcD"
                fig.suptitle(
                    f"Figure S{figure_number}. {entry}",
                    fontsize=20,
                    weight='bold',
                    y=0.9
                )
                figure_number += 1

                # save the individual PDF in the subdirectory
                output_filename = f"{entry}_frustration_comparison.pdf"
                output_path = os.path.join(subdir_path, output_filename)
                fig.savefig(output_path, dpi=600, bbox_inches='tight')
                print(f"Plot saved successfully at: {output_path}")

                # Stream into the aggregate PDF now instead of keeping every
                # figure open until the end (memory + max_open_warning).
                if aggregate_pdf is None:
                    aggregate_pdf = PdfPages(all_plots_path)
                aggregate_pdf.savefig(fig, bbox_inches='tight')
                n_aggregated += 1

            except Exception as e:
                # One bad protein must not stop the whole batch.
                print(f"Skipping '{entry}' due to error: {e}")
            finally:
                # Always release the figure, including when a save step failed.
                if fig is not None:
                    plt.close(fig)
    finally:
        if aggregate_pdf is not None:
            aggregate_pdf.close()

    if n_aggregated:
        print(f"All plots saved in one PDF at: {all_plots_path}")

    # ---- Histogram of the per-protein Spearman correlations ----
    if all_corrs:
        plt.figure(figsize=(8, 6))
        plt.hist(all_corrs, bins=10, edgecolor='black')
        plt.xlabel("Spearman's ρ", fontsize=14)
        plt.ylabel("Frequency", fontsize=14)
        plt.title("Histogram of Spearman Correlation Coefficients\n(Evolutionary vs Experimental Frustration)", fontsize=16)
        histogram_path = os.path.join(root_dir, "correlation_histogram.pdf")
        plt.savefig(histogram_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Histogram saved successfully at: {histogram_path}")
    else:
        print("No valid correlations computed to plot histogram.")

########################################
# 4) MAIN EXECUTION BLOCK              #
########################################

if __name__ == "__main__":
    # Specify the root directory that contains subdirectories to process.
    # Left blank as a placeholder: os.listdir("") raises FileNotFoundError, so
    # guard it with a clear message. Deliberately no fallback to the current
    # directory, because the driver deletes "*frustration*.pdf" files in every
    # subdirectory it visits.
    root_directory = ""
    if not root_directory:
        print("Set 'root_directory' to the folder containing per-protein subdirectories before running.")
    elif not os.path.isdir(root_directory):
        print(f"Root directory not found: {root_directory}")
    else:
        process_all_subdirectories(root_directory)