# CRYSTAL VS MODEL

In [None]:
import os
import pandas as pd
import os
import subprocess
import re
os.chdir("/Users/alexascunceparis/Desktop/TCR/AF3-TCRpMHC")

## ANNOTATION

In [None]:
import os
import subprocess
from pathlib import Path

MEMORY = "5G"

def run_mir(pdb_dir,
            mir_path="./mir-1.0-SNAPSHOT.jar",
            output_dir="mir_output",
            arg="annotate-structures",  
            print_log=True):
    
    pdb_dir = Path(pdb_dir)
    pdb_list = list(pdb_dir.glob("*"))
    pdb_paths = " ".join(str(p) for p in pdb_list)

    # Crear el directorio de salida
    output_dir_path = output_dir
    os.makedirs(output_dir_path, exist_ok=True)

    cmd = f"java -Xmx{MEMORY} -cp {mir_path} com.milaboratory.mir.scripts.Examples {arg} -I {pdb_paths} -O {output_dir_path}/"

    try:
        result = subprocess.run(cmd, shell=True, check=True,
                                stdout=(None if print_log else subprocess.DEVNULL),
                                stderr=(None if print_log else subprocess.DEVNULL))
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to execute '{cmd}'") from e

# Verificar versión de Java
subprocess.run("java -version", shell=True)

pdb_dir = "./structures/crystals_pdb/"
output_dir = "./structures/structures_annotation/"

run_mir(pdb_dir=pdb_dir, output_dir=output_dir, arg="annotate-structures")

In [2]:
def parse_general_file(general_file):
    """
    Parses the general file and creates a dictionary mapping PDB IDs to specific chain information
    such as 'tcra_chain', 'tcrb_chain', 'peptide_chain', and 'mhc_chain'.
    
    :param general_file: Path to the general file.
    :return: A dictionary where keys are PDB IDs and values are dictionaries with chain information
    """
    # Read the general file into a pandas DataFrame
    df = pd.read_csv(general_file, sep='\t')
    pdb_dict = {}

    # Group by PDB ID and iterate through each group
    for pdb_id, group in df.groupby('pdb.id'):
        pdb_id = pdb_id.split('.')[0]  # Extract the PDB ID from the first part of the string
        chains = {
            'tcra_chain': None,
            'tcrb_chain': None,
            'peptide_chain': None,
            'mhc_chain': None
        }

        # Iterate through the rows in the group to assign chain types
        for _, row in group.iterrows():
            chain_id = row['chain.id']
            chain_type = row['chain.type']
            chain_component = row['chain.component']
            chain_supertype = row['chain.supertype']

            # Assign chain IDs based on their component and type
            if chain_component == 'TCR' and chain_type == 'TRA':
                chains['tcra_chain'] = chain_id
            elif chain_component == 'TCR' and chain_type == 'TRB':
                chains['tcrb_chain'] = chain_id
            elif chain_component == 'PEPTIDE':
                chains['peptide_chain'] = chain_id
            elif chain_component == 'MHC' and chain_supertype == 'MHCI' and chain_type == 'MHCa':
                chains['mhc_chain'] = chain_id
            elif chain_component == 'MHC' and chain_supertype == 'MHCI' and chain_type == 'MHCb':
                chains['b2_chain'] = chain_id
        
        # Add the chain information for this pdb_id to the dictionary
        pdb_dict[pdb_id] = chains

    return pdb_dict

chain_dict = parse_general_file('./structures/structures_annotation/general.txt')

# UTILS

- Cif to PDB with pymol
- Cif to PDB with BeEM
- Remove Headers, Hetatm...
- Remove Residues with letter (e.g 116A)
- Merge chains A (TCR), B (pMHC)

In [11]:
import os
import subprocess

def convert_cif_to_pdb_BeEM(input_dir, pdb_dir):
    """
    Converts CIF files located in nested seed folders to PDB format using BeEM.
    
    - input_dir: str, root folder containing 4-letter PDB ID folders
    - pdb_dir: str, output folder for resulting PDB files
    """
    beem_path = "./BeEM/BeEM"
    os.makedirs(pdb_dir, exist_ok=True)

    for folder in os.listdir(input_dir):
        if len(folder) == 4:  # Typical PDB ID
            pdb_id = folder
            folder_path = os.path.join(input_dir, folder)

            for folder2 in os.listdir(folder_path):
                if folder2.startswith("seed"):
                    model_number = folder2.split("-")[-1]
                    cif_path = os.path.join(folder_path, folder2, "model.cif")
                    output_path = os.path.join(pdb_dir, f"{pdb_id}_{model_number}.pdb")

                    if not os.path.exists(cif_path):
                        print(f"Missing CIF file: {cif_path}. Skipping.")
                        continue

                    command = f"{beem_path} -p={output_path} {cif_path}"
                    try:
                        subprocess.run(command, shell=True, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                        print(f"Converted {cif_path} → {output_path}")
                    except subprocess.CalledProcessError as e:
                        print(f"Error converting {cif_path}: {e}")

def convert_cif_to_pdb(input_dir, pdb_dir):
    # Create folders 
    os.makedirs(pdb_dir, exist_ok=True)

    for folder in os.listdir(input_dir):
        if len(folder) == 4:  # PDB ID
            pdb_id = folder
            for folder2 in os.listdir(os.path.join(input_dir, folder)):
                if folder2.startswith("seed"):
                    model_number = folder2.split("-")[-1]
                    cif_file = "model.cif"
                    cif_path = os.path.join(input_dir, folder, folder2, cif_file)
                    output_path = os.path.join(pdb_dir, f"{pdb_id}_{model_number}.pdb")

                    pymol_command = f"pymol -c -d 'load {cif_path}; save {output_path}'"
                    try:
                        subprocess.run(pymol_command, shell=True, check=True)
                        print(f"Converted {cif_path} → {output_path}")
                    except subprocess.CalledProcessError as e:
                        print(f"Error converting {cif_file}: {e}")

def clean_pdb_file(pdb_file_path, output_file_path):
    with open(pdb_file_path, "r") as infile:
        lines = infile.readlines()

    cleaned_lines = []
    for line in lines:
        if line.startswith("ATOM"):
            residue_number = line[22:27].strip()
            # Exclude residue numbers with letter suffix
            if residue_number[-1].isalpha():
                continue
            cleaned_lines.append(line)
        # Ignore HETATM and all other lines

    with open(output_file_path, "w") as outfile:
        outfile.writelines(cleaned_lines)

def clean_all_pdbs(pdb_dir, cleaned_pdb_dir):
    # Create folders
    os.makedirs(cleaned_pdb_dir, exist_ok=True)
    for pdb_file in os.listdir(pdb_dir):
        if pdb_file.endswith(".pdb"):
            pdb_path = os.path.join(pdb_dir, pdb_file)
            cleaned_pdb_path = os.path.join(cleaned_pdb_dir, pdb_file)
            clean_pdb_file(pdb_path, cleaned_pdb_path)
            print(f"Cleaned → {cleaned_pdb_path}")

def merge_chains(input_dir, output_dir, chain_dict):
    """
    Procesa archivos PDB, selecciona y renombra cadenas específicas, elimina encabezados,
    y guarda archivos PDB combinados en el directorio de salida.
    """
    def remove_headers(file_path):
        """Elimina líneas que no empiezan con 'ATOM' del archivo PDB."""
        with open(file_path, 'r') as f:
            return [line for line in f if line.startswith("ATOM")]

    os.makedirs(output_dir, exist_ok=True)

    for pdb_file in os.listdir(input_dir):
        if not pdb_file.endswith(".pdb"):
            continue

        pdb_id = pdb_file.split(".")[0]
        pdb_file_path = os.path.join(input_dir, pdb_file)

        if pdb_id not in chain_dict:
            print(f"Warning: No chain information found for {pdb_id}. Setting default chains.")
            chain_dict[pdb_id] = {
                'tcra_chain': 'D',
                'tcrb_chain': 'E',
                'peptide_chain': 'C',
                'b2_chain': 'B',
                'mhc_chain': 'A'}

        tcra_id = chain_dict[pdb_id]['tcra_chain']
        tcrb_id = chain_dict[pdb_id]['tcrb_chain']
        mhc_id  = chain_dict[pdb_id]['mhc_chain']
        b2_id   = chain_dict[pdb_id]['b2_chain']
        epitope_id = chain_dict[pdb_id]['peptide_chain']

        print(f"Processing: {pdb_file}")
        print(f"tcra: {tcra_id}, tcrb: {tcrb_id}, mhc: {mhc_id}, b2: {b2_id}, epitope: {epitope_id}")

        command_AB = f"pdb_selchain -{tcra_id},{tcrb_id} {pdb_file_path} | pdb_chain -B | pdb_reres -1 | pdb_delhetatm > B.pdb"
        command_MB = f"pdb_selchain -{mhc_id},{b2_id},{epitope_id} {pdb_file_path} | pdb_chain -A | pdb_reres -1 | pdb_delhetatm > A.pdb"

        try:
            subprocess.run(command_MB, shell=True, check=True)
            subprocess.run(command_AB, shell=True, check=True)

            A_lines = remove_headers("A.pdb")
            B_lines = remove_headers("B.pdb")

            output_file_path = os.path.join(output_dir, f"{pdb_id}_merged.pdb")
            with open(output_file_path, 'w') as outfile:
                outfile.writelines(A_lines)
                outfile.writelines(B_lines)

            os.remove("A.pdb")
            os.remove("B.pdb")

            print(f"Processed and saved merged PDB as: {output_file_path}")

        except subprocess.CalledProcessError as e:
            print(f"Error processing {pdb_file}: {e}")
            if os.path.exists(output_file_path):
                os.remove(output_file_path)

In [None]:
input_dir = "./structures/af3_output"

pdb_dir = "./structures/models/models_pdb"
cleaned_pdb_dir = "./structures/models/cleaned_models_pdb"
merged_models_dir = "./structures/models/merged_models"

crystal_pdb_dir = "./structures/crystals/crystals_pdb"
cleaned_crystal_pdb_dir = "./structures/crystals/cleaned_crystals_pdb"
merged_crystals_dir = "./structures/crystals/merged_crystals"

# Convert model CIF files to PDB
convert_cif_to_pdb(input_dir, pdb_dir)

# Clean model PDB files
clean_all_pdbs(pdb_dir, cleaned_pdb_dir)

# Clean crystal PDB files
clean_all_pdbs(crystal_pdb_dir, cleaned_crystal_pdb_dir)

# Merge chains in model PDB files
merge_chains(cleaned_pdb_dir, merged_models_dir, chain_dict)

# Merge chains in crystal PDB files
merge_chains(cleaned_crystal_pdb_dir, merged_crystals_dir, chain_dict)

Processing: 1oga_2.pdb
tcra: D, tcrb: E, mhc: A, b2: B, epitope: C
Processed and saved merged PDB as: ./structures/merged_models/1oga_2_merged.pdb
Processing: 1mwa_2.pdb
tcra: D, tcrb: E, mhc: A, b2: B, epitope: C
Processed and saved merged PDB as: ./structures/merged_models/1mwa_2_merged.pdb
Processing: 1mwa_3.pdb
tcra: D, tcrb: E, mhc: A, b2: B, epitope: C
Processed and saved merged PDB as: ./structures/merged_models/1mwa_3_merged.pdb
Processing: 1oga_3.pdb
tcra: D, tcrb: E, mhc: A, b2: B, epitope: C
Processed and saved merged PDB as: ./structures/merged_models/1oga_3_merged.pdb
Processing: 1oga_1.pdb
tcra: D, tcrb: E, mhc: A, b2: B, epitope: C
Processed and saved merged PDB as: ./structures/merged_models/1oga_1_merged.pdb
Processing: 1mwa_1.pdb
tcra: D, tcrb: E, mhc: A, b2: B, epitope: C
Processed and saved merged PDB as: ./structures/merged_models/1mwa_1_merged.pdb
Processing: 1mwa_0.pdb
tcra: D, tcrb: E, mhc: A, b2: B, epitope: C
Processed and saved merged PDB as: ./structures/mer

# CRYSTAL VS MODEL

In [18]:
from Bio import PDB
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Align import PairwiseAligner
from Bio.SVDSuperimposer import SVDSuperimposer
import numpy as np
import os
import pandas as pd
import subprocess
import re
import argparse

residue_mapping = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D',
    'CYS': 'C', 'GLU': 'E', 'GLN': 'Q', 'GLY': 'G',
    'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K',
    'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',
    'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'}

# RMSD functions -----------------------------------------------------------
def cif_to_pdb(cif_file):
    """
    Converts a single CIF file to PDB format and saves it in the same folder with the same name.
    
    Parameters:
    - cif_file: str, path to the input CIF file.
    """
    # Define the output PDB file path (same folder, same name, .pdb extension)
    pdb_file = os.path.splitext(cif_file)[0]
    beem_path = "../../data_augmentation/BeEM/BeEM"#"/gpfs/projects/bsc72/aascunce/data_augmentation/BeEM/BeEM"
    # Build the PyMOL command to convert CIF to PDB
    command = f"{beem_path} -p={pdb_file} {cif_file}"

    # Suppress the output of the pymol command by redirecting stdout and stderr to os.devnull
    try:
        with open(os.devnull, 'w') as devnull:
            subprocess.run(command, shell=True, stdout=devnull, stderr=devnull, check=True)
        print(f"Successfully converted {cif_file} to {pdb_file}")
    except subprocess.CalledProcessError as e:
        # Handle errors during the conversion process
        print(f"Error converting {cif_file}: {e}")

def merge_pdb(pdb_file):
    """
    Processes a single PDB file, modifies chain IDs, and merges the results into a single PDB file.

    Parameters:
    - pdb_file: str, path to the input PDB file.

    Output:
    - A merged PDB file with the same name as the input, appended with '_merged', saved in the same directory.
    """
    # Define chain IDs
    tcra_id = "D"
    tcrb_id = "E"
    mhc_id = "A"
    b2_id = "B"
    epitope_id = "C"

    # Extract base name and define output file path
    base_name = pdb_file.rsplit(".", 1)[0]  # Name without extension
    output_file_path = f"{base_name}_merged.pdb"  # Add '_merged' to the output file name

    # Preprocess the input file to remove invalid lines and save a temporary cleaned file
    cleaned_pdb_file = f"{base_name}_cleaned.pdb"
    cleaned_lines = remove_headers(pdb_file)
    with open(cleaned_pdb_file, 'w') as cleaned_file:
        cleaned_file.writelines(cleaned_lines)

    # Construct shell commands
    command_AB = (
        f"pdb_selchain -{tcra_id},{tcrb_id} {cleaned_pdb_file} "
        f"| pdb_chain -B | pdb_reres -1 | pdb_delhetatm > B.pdb"
    )
    command_MB = (
        f"pdb_selchain -{mhc_id},{b2_id},{epitope_id} {cleaned_pdb_file} "
        f"| pdb_chain -A | pdb_reres -1 | pdb_delhetatm > A.pdb"
    )

    try:
        # Execute shell commands to generate temporary files
        subprocess.run(command_MB, shell=True, check=True)
        subprocess.run(command_AB, shell=True, check=True)

        # Remove headers and merge the files
        A_lines = remove_headers("A.pdb")
        B_lines = remove_headers("B.pdb")

        # Save the merged PDB file
        with open(output_file_path, 'w') as outfile:
            outfile.writelines(A_lines)
            outfile.writelines(B_lines)

        # Clean up temporary files
        os.remove('A.pdb')
        os.remove('B.pdb')
        os.remove(cleaned_pdb_file)

    except subprocess.CalledProcessError as e:
        print(f"Error processing {pdb_file}: {e}")

def remove_headers(file_path):
    """
    Remove non-ATOM/HETATM lines from a PDB file and validate line length.

    Parameters:
    - file_path: str, path to the PDB file.

    Returns:
    - List of cleaned lines (ATOM/HETATM and properly formatted).
    """
    cleaned_lines = []
    with open(file_path, 'r') as file:
        for line in file:
            # Keep only ATOM/HETATM lines with sufficient length
            if (line.startswith("ATOM") or line.startswith("HETATM")) and len(line) > 21:
                cleaned_lines.append(line)
    return cleaned_lines

def extract_sequences(pdb_file):
    """
    Extract sequences for all chains from a PDB file in two forms:
    - A dictionary of sequences as single-letter residue codes (string).
    - A dictionary of sequences as lists of (resname, resid) tuples.
    
    Returns:
    - sequences_str (dict): A dictionary with chain_id as key and sequence as string of 1-letter codes.
    - sequences_tuples (dict): A dictionary with chain_id as key and sequence as list of tuples (resname, resid).
    """
    if pdb_file.endswith(".pdb"):
        parser = PDB.PDBParser(QUIET=True)
        structure = parser.get_structure('structure', pdb_file)
    else:
        parser=PDB.MMCIFParser(QUIET=True)
        structure = parser.get_structure('structure', pdb_file)
    
    sequences_str = {}
    sequences_tuples = {}

    for model in structure:
        for chain in model.get_chains():
            chain_id = chain.get_id()
            sequence_str = []  # For single-letter sequence
            sequence_tuples = []  # For (resname, resid) tuples
            for residue in chain:
                if PDB.is_aa(residue):  # Ensure the residue is an amino acid
                    res_name = residue.get_resname()  
                    resid = residue.get_id()[1]  # Residue ID
                    # Add the single-letter residue code to the string
                    sequence_str.append(residue_mapping.get(res_name, 'X'))  # 'X' if unknown residue
                    # Store the (resname, resid) tuple for residue identity
                    sequence_tuples.append((res_name, resid))  
            sequences_str[chain_id] = ''.join(sequence_str)  # Join into string
            sequences_tuples[chain_id] = sequence_tuples  # Store the tuples
    return sequences_str, sequences_tuples


def align_sequences(seqA, seqB):
    """
    Align two sequences using pairwise sequence alignment and return the aligned sequences with residue info.
    """
    aligner = PairwiseAligner()
    aligner.match = 5
    aligner.mismatch = -1
    aligner.open_gap_score = -4
    aligner.extend_gap_score = -1
    aln = aligner.align(seqA, seqB)[0]
    return aln

def format_alignment(aln, chain_id_cry, chain_id_mod, dict_cry, dict_mode):
    """
    Format the alignment and return the residue names and IDs for each sequence in the alignment.
    Also return the matches as before ('|' for match, '.' for mismatch, ' ' for gap).
    """
    aligned_residues = {'seqA': [], 'seqB': [], 'matches': []}
    seqA_aligned = aln[0, :]
    seqB_aligned = aln[1, :]
    indexA = 0
    indexB = 0
    for i in range(len(seqA_aligned)):
        if seqA_aligned[i] != '-':
            res_nameA, res_idA = dict_cry[chain_id_cry][indexA] 
            aligned_residues['seqA'].append((res_nameA, res_idA))
            indexA += 1
        else:
            aligned_residues['seqA'].append(('-', '-'))

        if seqB_aligned[i] != '-':
            res_nameB, res_idB = dict_mode[chain_id_mod][indexB] 
            aligned_residues['seqB'].append((res_nameB, res_idB))
            indexB += 1
        else:
            aligned_residues['seqB'].append(('-', '-'))

        if seqA_aligned[i] == seqB_aligned[i]:
            aligned_residues['matches'].append('|')  
        elif seqA_aligned[i] == '-' or seqB_aligned[i] == '-':
            aligned_residues['matches'].append(' ')  
        else:
            aligned_residues['matches'].append('.')  
    return aligned_residues

def get_aligned_residues(alignment):
    """
    Map aligned residues from two chains based on a given sequence alignment.
    Only returns residues that are matches, formatted as (resname, resid_crystal, resid_model).
    """
    aligned_residues = []
    seqA_aligned = alignment["seqA"]
    seqB_aligned = alignment["seqB"]
    matches = alignment["matches"]
    indexA = 0
    indexB = 0
    for i in range(len(seqA_aligned)):
        if seqA_aligned[i] != "-" and seqB_aligned[i] != "-" and matches[i] == "|":
            res_nameA, resid_crystal = seqA_aligned[i]
            res_nameB, resid_model = seqB_aligned[i]
            aligned_residues.append((res_nameA, resid_crystal, resid_model))
            indexA += 1
            indexB += 1
    return aligned_residues

def get_interface(pdb_file, reference_chain, chain_ids, distance_cutoff=10.0, select_heavy_atoms=False):
    """
    Select Cα atoms from chains A, B, and C that are within a certain distance (default: 10 Å)
    from any atom of chain C, and also include all Cα atoms from chain C. Optionally, select heavy atoms.

    Args:
    - pdb_file (str): Path to the PDB file.
    - reference_chain (str): The chain ID of the reference chain (usually chain C).
    - chain_ids (list of str): List of chain IDs to select atoms from, including 'A', 'B', and 'C'.
    - distance_cutoff (float): Distance cutoff in Å (default 10 Å).
    - select_heavy_atoms (bool): If True, select all heavy atoms (default: False).

    Returns:
    - selected_atoms (list of tuples): A list of atoms (atom name, residue ID, residue name, and chain ID)
      that are within the distance cutoff from atoms in chain C, plus all Cα atoms from chain C (or all heavy atoms if specified).
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure('structure', pdb_file)
    selected_atoms = []
    chain_ref_atoms = []
    chain_others_atoms = []
    for model in structure:
        for chain in model.get_chains():
            chain_id = chain.get_id()
            if chain_id == reference_chain:
                for residue in chain:
                    for atom in residue:
                        chain_ref_atoms.append(atom)
            if chain_id in chain_ids:
                for residue in chain:
                    for atom in residue:
                        if select_heavy_atoms:
                            if atom.element != 'H':
                                chain_others_atoms.append((atom, residue.get_id()[1], residue.get_resname(), chain_id))
                        else:
                            if atom.get_name() == "CA":
                                chain_others_atoms.append((atom, residue.get_id()[1], residue.get_resname(), chain_id))
    for atom, resid, resname, chain_id in chain_others_atoms:
        for ref_atom in chain_ref_atoms:
            distance = atom - ref_atom
            if distance <= distance_cutoff:
                selected_atoms.append((atom.get_name(), resid, resname, chain_id))
    for atom in chain_others_atoms:
        atom_obj, resid, resname, chain_id = atom
        if chain_id == reference_chain:  # Only add atoms from chain C
            selected_atoms.append((atom_obj.get_name(), resid, resname, chain_id))
    return selected_atoms

def get_atom_coordinates(pdb_file, selected_atoms):
    if pdb_file.endswith(".pdb"):
        parser = PDB.PDBParser(QUIET=True)
    else:
        parser = PDB.MMCIFParser(QUIET=True)
    
    structure = parser.get_structure('structure', pdb_file)
    coordinates = []

    for atom_name, resid, resname, chain_id in selected_atoms:
        try:
            chain = structure[0][chain_id]
            residue = chain[resid]
            found = False
            for atom in residue:
                if atom.get_name() == atom_name:
                    coordinates.append(atom.get_coord())
                    found = True
                    break
            if not found:
                print(f"Warning: Atom {atom_name} in residue {resid} of chain {chain_id} not found.")
        except KeyError:
            print(f"Warning: Chain {chain_id} or residue {resid} not found in the structure.")

    if not coordinates:
        raise ValueError(f"No coordinates found for the selected atoms: {selected_atoms}")

    coordinates_array = np.array(coordinates, dtype='f')
    return coordinates_array

# ANARCI and parsing functions -----------------------------------------------------------
def extract_residues_and_resids(pdb_file, chain_id):
    """
    Extract the residue IDs and residues (in one-letter code) from a specific chain in a PDB file.
    
    Args:
        pdb_file (str): Path to the PDB file.
        chain_id (str): Chain ID to extract residues from.
    
    Returns:
        list of tuples: List of tuples where each tuple contains (resid, residue_one_letter).
    """
    if pdb_file.endswith(".pdb"):
        parser = PDB.PDBParser(QUIET=True)
        structure = parser.get_structure('structure', pdb_file)
    else:
        parser=PDB.MMCIFParser(QUIET=True)
        structure = parser.get_structure('structure', pdb_file)
        
    residues = []
    for model in structure:
        for chain in model:
            if chain.id == chain_id:
                for residue in chain:
                    resid = residue.get_id()[1]
                    resname = residue.get_resname()
                    residue_one_letter = PDB.Polypeptide.protein_letters_3to1.get(resname, 'X')  # Use 'X' for unknown residues
                    residues.append((resid, residue_one_letter))
    
    return residues

def run_anarci(sequence):
    """
    Execute ANARCI to assign IMGT numbering to a TCR sequence.
    """
    try:
        command=f"ANARCI -i {sequence} --scheme imgt"
        result = subprocess.run(command, shell=True, capture_output=True, text=True, check=True)
        return result.stdout
    except subprocess.CalledProcessError as e:
        return f"Error: {e.stderr}"
    
import re

def parse_anarci_output(anarci_output):
    """
    Parse the output of ANARCI to extract IMGT numbering and residues, ensuring uniqueness of IMGT numbers.
    
    Args:
        anarci_output (str): Output from ANARCI as a string.
    
    Returns:
        list of tuples: A list where each tuple contains (IMGT_number, residue) with unique IMGT numbers.
    """
    pattern = r'^([A-Z])\s+(\d+)\s+([A-Z\-])'
    matches = re.findall(pattern, anarci_output, re.MULTILINE)
    
    imgt_numbered_seq = []
    seen_imgt_numbers = set()
    
    for match in matches:
        try:
            chain_letter, imgt_num, residue = match
            imgt_num = int(imgt_num)
            
            # Ensure uniqueness of IMGT numbers
            if imgt_num not in seen_imgt_numbers:
                imgt_numbered_seq.append((imgt_num, residue))
                seen_imgt_numbers.add(imgt_num)
        except ValueError as e:
            print(f"Error processing match: {match}. Error: {e}")
    
    return imgt_numbered_seq

def map_imgt_to_original(imgt_numbered_seq, pdb_resids):
    """
    Map the original numbering of a sequence from the PDB 'resids' to the IMGT numbering.
    
    Args:
        imgt_numbered_seq (list of tuples): The IMGT numbered sequence as tuples (IMGT_number, residue).
        pdb_resids (list of tuples): The original residue numbers from the PDB file as tuples (resid, residue_one_letter).
    
    Returns:
        list of tuples: A list where each tuple contains (original_resid, IMGT_number, residue).
    """
    mapping = []
    pdb_resid_index = 0  # Index for PDB residues
    
    for imgt_pos, residue in imgt_numbered_seq:
        if residue != "-":  # Only process non-gap residues in IMGT
            for original_resid, residue1 in pdb_resids[pdb_resid_index:]:
                if residue1 == residue:
                    mapping.append((original_resid, imgt_pos, residue))
                    pdb_resid_index += 1
                    break
                else:
                    pdb_resid_index += 1
            else:
                mapping.append((None, imgt_pos, residue))
        else:
            mapping.append((None, imgt_pos, residue))
    return mapping

def parse_CDR3 (mapping):
    cdr3_tuples = [tupple for tupple in mapping if 105 <= tupple[1] <= 117 and tupple[2] != "-"]
    return cdr3_tuples

def parse_CDR2 (mapping):
    cdr2_tuples = [tupple for tupple in mapping if 56 <= tupple[1] <= 65 and tupple[2] != "-"]
    return cdr2_tuples

def parse_CDR1 (mapping):
    cdr1_tuples = [tupple for tupple in mapping if 27 <= tupple[1] <= 38 and tupple[2] != "-"]
    return cdr1_tuples

def extract_atoms_for_cdr(cdr_list, pdb_file, chain_id):
    """
    Extract atoms of the residues of a CDR from a PDB file.
    
    :param cdr_list: List of tuples in the format (resid, imgtid, resname)
    :param pdb_file: Path to the PDB file of the structure
    :param chain_id: Chain ID (e.g., 'A', 'B') to extract the atoms from
    :return: List of tuples with the format (atomname, resid, resname, chain_id)
    """
    if pdb_file.endswith(".pdb"):
        parser = PDB.PDBParser(QUIET=True)
    else:
        parser=PDB.MMCIFParser(QUIET=True)
    structure = parser.get_structure('structure', pdb_file)
    atom_list = [] 
    for model in structure:
        for chain in model:
            if chain.id == chain_id:  
                for residue in chain:
                    resid = residue.get_id()[1]  
                    resname_3 = residue.get_resname() 
                    resname_1 = residue_mapping.get(resname_3, 'X')
                    for cdr_resid, cdr_imgtid, cdr_resname in cdr_list:
                        if resid == cdr_resid and resname_1 == cdr_resname:
                            for atom in residue:
                                atom_list.append((atom.get_name(), resid, resname_3, chain.id))
    return atom_list

def calculate_rmsd(crystal_pdb, model_pdb, pdb_id, chain_dict, distance_cutoff=10.0):
    """
    Calculate RMSD between a crystal and a model structure for all chains.

    Args:
        crystal_pdb (str): Path to the crystal structure PDB file.
        model_pdb (str): Path to the model structure PDB file.
        chain_ids (list): List of chain IDs to consider for the interface (default ['A', 'B', 'C']).
        distance_cutoff (float): Distance cutoff for selecting interface residues (default 10.0 Å).

    Returns:
        dict: Overall RMSD and chain-specific RMSD details.
    """
    
    # Step 1: Parse structures and extract sequences
    if model_pdb.endswith(".pdb"):
        parser = PDBParser(QUIET=True)
    else:
        parser = MMCIFParser(QUIET=True)
    model_structure = parser.get_structure("model", model_pdb)

    # Extract sequences and map residues
    crystal_sequences, dict_cry = extract_sequences(crystal_pdb)
    model_sequences, dict_mod = extract_sequences(model_pdb)
    chain_dict=chain_dict[pdb_id]
    mapping = []

    if len(model_sequences) == 5:
        model_chain_mapping = {
            'mhc_chain':'A',
            'b2_chain':'B',
            'peptide_chain':'C',
            'tcra_chain':'D',
            'tcrb_chain':'E'}
        
    elif len(model_sequences) == 4:
        model_chain_mapping = {
            'mhc_chain':'A',
            'peptide_chain':'B',
            'tcra_chain':'C',
            'tcrb_chain':'D'}
    else:
        print(f"Error: Model structure does not have 4 or 5 chains.")

    # Step 2: Loop over chains to align sequences
    for chain_crystal, seq_crystal in crystal_sequences.items():
        key = next((key for key, value in chain_dict.items() if value == chain_crystal), None)
        if key is not None:
            model_id = model_chain_mapping.get(key)
            if model_id in model_sequences:
                seq_model = model_sequences[model_id]
                alignment = align_sequences(seq_crystal, seq_model)
                formatted_alignment = format_alignment(alignment, chain_crystal, model_id, dict_cry, dict_mod)
                aligned_residues = get_aligned_residues(formatted_alignment)
                aligned_residues = [res + (f'{chain_crystal}',) for res in aligned_residues]
                mapping.extend(aligned_residues)
            else:
                print(f"Error: {model_id} not in model_sequences.")
        else:
            print(f"Error: {chain_crystal} not in chain_dict.")

    # Step 3: Select interface atoms
    reference_chain = chain_dict['peptide_chain']
    if len(model_sequences) == 5:
        chain_ids = [chain_dict['tcra_chain'], chain_dict['tcrb_chain'], chain_dict['peptide_chain'], chain_dict['mhc_chain'], chain_dict['b2_chain']]
    elif len(model_sequences) == 4:
        chain_ids = [chain_dict['tcra_chain'], chain_dict['tcrb_chain'], chain_dict['peptide_chain'], chain_dict['mhc_chain']]
 
    selected_atoms_crystal = sorted(set(get_interface(crystal_pdb, reference_chain, chain_ids=chain_ids, distance_cutoff=distance_cutoff, select_heavy_atoms=True)),key=lambda x: (x[3], x[1]))
    selected_atoms_model = []
    atoms_to_remove_i = []
    for atom_crystal in selected_atoms_crystal:
        atom_name_crystal, resid_crystal, resname_crystal, chain_id_crystal = atom_crystal
        found_match = False  
        for resname, resid_crystal_mapping, resid_model_mapping, chain_id_mapping in mapping:
            if resid_crystal == resid_crystal_mapping and resname_crystal == resname and chain_id_crystal == chain_id_mapping:
                chain_ident = next((key for key, value in chain_dict.items() if value == chain_id_mapping), None)
                chain_id_model = model_chain_mapping.get(chain_ident)
                chain_model = model_structure[0][chain_id_model]
                for residue_model in chain_model:
                    if residue_model.get_id()[1] == resid_model_mapping:
                        for atom_model in residue_model:
                            if atom_model.get_name() == atom_name_crystal:
                                selected_atoms_model.append((atom_model.get_name(), resid_model_mapping, resname, chain_id_model))
                                found_match = True  # Marcar como encontrado
                                break
                    if found_match:  
                        break
            if found_match:
                break
        if not found_match:
            atoms_to_remove_i.append(atom_crystal)
    selected_atoms_crystal = [atom for atom in selected_atoms_crystal if atom not in atoms_to_remove_i]

    # Step 4: Get atom coordinates
    coordinates_crystal = get_atom_coordinates(crystal_pdb, selected_atoms_crystal)
    coordinates_model = get_atom_coordinates(model_pdb, selected_atoms_model)

    # Step5: Get CDR3 residues and atoms
    residues_crystal_A=extract_residues_and_resids(crystal_pdb, chain_dict['tcra_chain'])
    residues_crystal_B=extract_residues_and_resids(crystal_pdb, chain_dict['tcrb_chain'])

    anarci_A_cry=run_anarci(crystal_sequences[chain_dict['tcra_chain']])
    anarci_B_cry=run_anarci(crystal_sequences[chain_dict['tcrb_chain']])
    
    parsed_cry_A=parse_anarci_output(anarci_A_cry)
    parsed_cry_B=parse_anarci_output(anarci_B_cry)

    map_cry_A=map_imgt_to_original(parsed_cry_A, residues_crystal_A)
    map_cry_B=map_imgt_to_original(parsed_cry_B, residues_crystal_B)
    
    cdr3_cry_A=parse_CDR3(map_cry_A)
    cdr3_cry_B=parse_CDR3(map_cry_B)
    
    cdr_atoms_cry_A = extract_atoms_for_cdr(cdr3_cry_A, crystal_pdb, chain_dict['tcra_chain'])
    cdr_atoms_cry_B = extract_atoms_for_cdr(cdr3_cry_B, crystal_pdb, chain_dict['tcrb_chain'])
    cdr_atoms_crystal = cdr_atoms_cry_A + cdr_atoms_cry_B

    cdr_atoms_model = []
    atoms_to_remove = []

    for cdr_atom_crystal in cdr_atoms_crystal:
        cdr_atom_name_crystal, cdr_resid_crystal, cdr_resname_crystal, cdr_chain_id_crystal = cdr_atom_crystal
        found_match = False 
        for cdr_resname, cdr_resid_crystal_mapping, cdr_resid_model_mapping, cdr_chain_id_mapping in mapping:
            if cdr_resid_crystal == cdr_resid_crystal_mapping and cdr_resname_crystal == cdr_resname and cdr_chain_id_crystal == cdr_chain_id_mapping:
                cdr_chain_ident = next((key for key, value in chain_dict.items() if value == cdr_chain_id_mapping), None)
                cdr_chain_id_model = model_chain_mapping.get(cdr_chain_ident)
                chain_model = model_structure[0][cdr_chain_id_model]
                for residue_model in chain_model:
                    if residue_model.get_id()[1] == cdr_resid_model_mapping:  # Match the residue ID
                        for atom_model in residue_model:
                            if atom_model.get_name() == cdr_atom_name_crystal:  # Match the atom name
                                cdr_atoms_model.append((atom_model.get_name(), cdr_resid_model_mapping, cdr_resname_crystal, cdr_chain_id_model))
                                found_match = True  # Mark as found
                                break
                        if found_match:  # Break outer loops if a match is found
                            break
                if found_match:
                    break
        if not found_match:
            atoms_to_remove.append(cdr_atom_crystal)

    cdr_atoms_crystal = [atom for atom in cdr_atoms_crystal if atom not in atoms_to_remove]

    # Perform superposition and calculate overall RMSD
    if len(coordinates_crystal) == len(coordinates_model) and len(coordinates_crystal) > 0:
        sup = SVDSuperimposer()
        sup.set(coordinates_crystal, coordinates_model)
        sup.run()
        y_on_x = sup.get_transformed()
        overall_rmsd = sup.get_rms()
    else:
        return "Error: Interface atoms mismatch. Cannot superimpose."

    # Step 6: Categorize chains and calculate RMSD
    if len(model_sequences) == 5:
        categories_crystal = {
            'TCRA/TCRB': [chain_dict['tcra_chain'], chain_dict['tcrb_chain']],
            'Peptide': [chain_dict['peptide_chain']],
            'MHC/B2M': [chain_dict['mhc_chain'], chain_dict['b2_chain']]}

        categories_model = {
            'TCRA/TCRB': [model_chain_mapping['tcra_chain'], model_chain_mapping['tcrb_chain']],
            'Peptide': [model_chain_mapping['peptide_chain']],
            'MHC/B2M': [model_chain_mapping['mhc_chain'], model_chain_mapping['b2_chain']]}
    
    elif len(model_sequences) == 4:
        categories_crystal = {
            'TCRA/TCRB': [chain_dict['tcra_chain'], chain_dict['tcrb_chain']],
            'Peptide': [chain_dict['peptide_chain']],
            'MHC': [chain_dict['mhc_chain']]}
        
        categories_model = {
            'TCRA/TCRB': [model_chain_mapping['tcra_chain'], model_chain_mapping['tcrb_chain']],
            'Peptide': [model_chain_mapping['peptide_chain']],
            'MHC': [model_chain_mapping['mhc_chain']]}
    else:
        return "Error: Model structure does not have 4 or 5 chains."

    category_crystal_coords = {category: [] for category in categories_crystal}
    category_model_coords = {category: [] for category in categories_model}
    
    for atom, coord in zip(selected_atoms_crystal, coordinates_crystal):
        chain_id = atom[3]
        for category, chains in categories_crystal.items():
            if chain_id in chains:
                category_crystal_coords[category].append(coord)
                break 

    for atom, coord in zip(selected_atoms_model, y_on_x):
        chain_id = atom[3]
        for category, chains in categories_model.items():
            if chain_id in chains:
                category_model_coords[category].append(coord)
                break  

    # Step 6: Calculate RMSD for each category
    category_rmsd_results = {}
    for category, crystal_coords in category_crystal_coords.items():
        model_coords = category_model_coords[category]
        
        # Check if coordinates are valid and have the same length
        if len(crystal_coords) > 0 and len(crystal_coords) == len(model_coords):
            # Convert to numpy arrays for easy manipulation
            crystal_coords = np.array(crystal_coords)
            model_coords = np.array(model_coords)
            
            # Calculate the difference between the coordinates
            diff = crystal_coords - model_coords
            
            # Calculate RMSD: sqrt(sum((crystal - model)^2) / N)
            rmsd = np.sqrt(np.sum(np.square(diff)) / len(crystal_coords))
            category_rmsd_results[category] = rmsd
        else:
            category_rmsd_results[category] = None

    # Prepare result string
    result_string = f"Number of CA in interface: {len(selected_atoms_crystal)}, Overall iRMSD: {overall_rmsd:.2f} angstroms\n"
    for category, rmsd in category_rmsd_results.items():
        num_atoms = len(category_crystal_coords[category])
        
        if rmsd is not None:
            result_string += f"{category}: Number of CA: {num_atoms}, iRMSD: {rmsd:.2f} angstroms\n"
        else:
            result_string += f"Category {category}: Insufficient data for RMSD calculation.\n"
    #Step 7: calculate CDR RMSD
    cdr_coords_crystal_A = get_atom_coordinates(crystal_pdb, cdr_atoms_cry_A)
    cdr_coords_crystal_B = get_atom_coordinates(crystal_pdb, cdr_atoms_cry_B)
    
    cdr_coords_model_A = []
    cdr_coords_model_B = []

    indices_to_remove_A = []
    indices_to_remove_B = []

    for idx, cdr_coord in enumerate(cdr_coords_crystal_A):
        idx_crystal = np.where(np.all(coordinates_crystal == cdr_coord, axis=1))[0]
        if len(idx_crystal) > 0:
            cdr_coords_model_A.append(y_on_x[idx_crystal[0]])
        else:
            indices_to_remove_A.append(idx)

    for idx, cdr_coord in enumerate(cdr_coords_crystal_B):
        idx_crystal = np.where(np.all(coordinates_crystal == cdr_coord, axis=1))[0]
        if len(idx_crystal) > 0:
            cdr_coords_model_B.append(y_on_x[idx_crystal[0]])
        else:
            indices_to_remove_B.append(idx)

    cdr_coords_crystal_A = np.delete(cdr_coords_crystal_A, indices_to_remove_A, axis=0)
    cdr_coords_crystal_B = np.delete(cdr_coords_crystal_B, indices_to_remove_B, axis=0)

    cdr_coords_model_A = np.array(cdr_coords_model_A)
    cdr_coords_model_B = np.array(cdr_coords_model_B)

    rmsd_cdrs = {}
    if len(cdr_coords_crystal_A) > 0 and len(cdr_coords_model_A) == len(cdr_coords_crystal_A):
        diff = cdr_coords_crystal_A - cdr_coords_model_A
        rmsd = np.sqrt(np.sum(np.square(diff)) / len(cdr_coords_crystal_A))
        rmsd_cdrs['TCRA'] = rmsd
    else:
        rmsd_cdrs['TCRA'] = None
        print("Difference in number of atoms for CDR3 in chain A")
    if len(cdr_coords_crystal_B) > 0 and len(cdr_coords_model_B) == len(cdr_coords_crystal_B):
        diff = cdr_coords_crystal_B - cdr_coords_model_B
        rmsd = np.sqrt(np.sum(np.square(diff)) / len(cdr_coords_crystal_B))
        rmsd_cdrs['TCRB'] = rmsd
    else:
        rmsd_cdrs['TCRB'] = None
        print("Difference in number of atoms for CDR3 in chain B")
    
    result_string+=f"CDR3 TCRA: {rmsd_cdrs['TCRA']:.2f} angstroms\n" if rmsd_cdrs['TCRA'] is not None else "CDR3 TCRA: Insufficient data for RMSD calculation.\n"
    result_string+=f"CDR3 TCRB: {rmsd_cdrs['TCRB']:.2f} angstroms\n" if rmsd_cdrs['TCRB'] is not None else "CDR3 TCRB: Insufficient data for RMSD calculation.\n"
    if len(model_sequences) == 5:
        return (result_string, overall_rmsd, category_rmsd_results["TCRA/TCRB"], category_rmsd_results["Peptide"], category_rmsd_results["MHC/B2M"], rmsd_cdrs.get('TCRA', None), rmsd_cdrs.get('TCRB', None))
    elif len(model_sequences) == 4:
        return (result_string, overall_rmsd, category_rmsd_results["TCRA/TCRB"], category_rmsd_results["Peptide"], category_rmsd_results["MHC"], rmsd_cdrs.get('TCRA', None), rmsd_cdrs.get('TCRB', None))

def run_dockq(model_path, native_path):
    # Ejecutar el comando DockQ
    dockq_command = f"DockQ {model_path} {native_path}"
    result = subprocess.run(dockq_command, shell=True, capture_output=True, text=True, check=True)
    
    # Obtener el texto de salida
    output = result.stdout

    # Usamos expresiones regulares para extraer las métricas
    dockq_score = re.search(r'DockQ:\s*([0-9\.]+)', output)
    irmsd = re.search(r'iRMSD:\s*([0-9\.]+)', output)
    lrmsd = re.search(r'LRMSD:\s*([0-9\.]+)', output)
    fnat = re.search(r'fnat:\s*([0-9\.]+)', output)
    clashes = re.search(r'clashes:\s*(\d+)', output)

    # Parsear los resultados
    dockq_score_val = float(dockq_score.group(1)) if dockq_score else None
    irmsd_val = float(irmsd.group(1)) if irmsd else None
    lrmsd_val = float(lrmsd.group(1)) if lrmsd else None
    fnat_val = float(fnat.group(1)) if fnat else None
    clashes_val = int(clashes.group(1)) if clashes else None

    # Retornar los valores directamente
    return dockq_score_val, irmsd_val, lrmsd_val, fnat_val, clashes_val

In [20]:
# Use case 

pdb_id= '1ao7'
model_number='1'

crystal_pdb = f'./structures/crystals/cleaned_crystals_pdb/{pdb_id}.pdb'
model_pdb = f'./structures/models/cleaned_models_pdb/{pdb_id}_{model_number}.pdb'

dockq_native = f'./structures/crystals/merged_crystals/{pdb_id}_merged.pdb'
dockq_model = f'./structures/models/merged_models/{pdb_id}_{model_number}_merged.pdb'


result_string, overall_rmsd, rmsd_TCRA_TCRB, rmsd_Peptide, rmsd_MHC_B2M, rmsd_CDR_TCRA, rmsd_CDR_TCRB = calculate_rmsd(crystal_pdb, model_pdb, pdb_id, chain_dict, distance_cutoff=10.0)
dockq_score, irmsd, lrmsd, fnat, clashes = run_dockq(dockq_model, dockq_native)
print(result_string)
print("Dockq score", dockq_score)
print("iRMSD", irmsd)
print("lRMSD", lrmsd)
print("fnat",fnat)
print("Clashes",clashes)

Number of CA in interface: 1020, Overall iRMSD: 0.96 angstroms
TCRA/TCRB: Number of CA: 268, iRMSD: 1.70 angstroms
Peptide: Number of CA: 77, iRMSD: 0.67 angstroms
MHC/B2M: Number of CA: 675, iRMSD: 0.45 angstroms
CDR3 TCRA: 0.64 angstroms
CDR3 TCRB: 1.22 angstroms

Dockq score 0.632
iRMSD 0.736
lRMSD 14.784
fnat 0.843
Clashes 0


In [37]:
#pdbs_4chains=['4mvb', '4ms8', '3e3q', '4n5e', '6l9l', '3tfk', '3tjh', '3e2h', '2e7l', '8d5q', '3tpu', '2oi9', '4mxq', '4n0c']

def process_pdbs(crystal_folder, model_folder):
    dataframe = pd.DataFrame()
    model_folder_pdb = os.path.join(model_folder, "cleaned_models_pdb")
    crystal_folder_pdb = os.path.join(crystal_folder, "cleaned_crystals_pdb")
    for pdb_file in os.listdir(crystal_folder_pdb):
        if pdb_file.endswith(".pdb"):  
            pdb_id = pdb_file.split(".")[0]
            #if pdb_id not in pdbs_4chains:
                #continue  
            pdb_path = os.path.join(crystal_folder_pdb,pdb_file)
            print(f"Processing PDB file: {pdb_path}...")
            print(f"Processing crystal structure for pdb_id: {pdb_id}...")
            dockq_native = f'./{crystal_folder}/merged_crystals/{pdb_id}_merged.pdb'
            model_files = [f for f in os.listdir(model_folder_pdb) if f.startswith(pdb_id) and f.endswith(".pdb")]
            if len(model_files) == 5:
                print(f"Found 5 model files for {pdb_id}, processing each model...")
                for model_file in model_files:
                    model_number = model_file.split("_")[1].split(".")[0]
                    model_path = os.path.join(model_folder_pdb, model_file)
                    dockq_model = f'./{model_folder}/merged_models/{pdb_id}_{model_number}_merged.pdb'
                    print(f"  Processing model {model_number}...")
                    result_string, overall_rmsd, rmsd_TCRA_TCRB, rmsd_Peptide, rmsd_MHC_B2M, rmsd_CDR_TCRA, rmsd_CDR_TCRB = calculate_rmsd(
                        pdb_path, model_path, pdb_id, chain_dict, distance_cutoff=10.0)
                    dockq_score, irmsd, lrmsd, fnat, clashes = run_dockq(dockq_model, dockq_native)
                    row = {
                        "pdb_id": pdb_id,
                        "model_number": model_number,
                        "overall_rmsd": overall_rmsd,
                        "rmsd_TCRA_TCRB": rmsd_TCRA_TCRB,
                        "rmsd_Peptide": rmsd_Peptide,
                        "rmsd_MHC_B2M": rmsd_MHC_B2M,
                        "rmsd_CDR_TCRA": rmsd_CDR_TCRA,
                        "rmsd_CDR_TCRB": rmsd_CDR_TCRB,
                        "dockq_score": dockq_score,
                        "irmsd": irmsd,
                        "lrmsd": lrmsd,
                        "fnat": fnat,
                        "clashes": clashes}
                    dataframe = pd.concat([dataframe, pd.DataFrame([row])], ignore_index=True)
                    print(f"Processed: {pdb_id} model {model_number}")
                    print(result_string)
    return dataframe

In [38]:
df=process_pdbs("./structures/crystals/", "./structures/models/")
df.to_csv("./structures/crystal_vs_models.csv", index=False)

Processing PDB file: ./structures/crystals/cleaned_crystals_pdb/1lp9.pdb...
Processing crystal structure for pdb_id: 1lp9...
Found 5 model files for 1lp9, processing each model...
  Processing model 2...
Processed: 1lp9 model 2
Number of CA in interface: 955, Overall iRMSD: 1.80 angstroms
TCRA/TCRB: Number of CA: 198, iRMSD: 3.42 angstroms
Peptide: Number of CA: 76, iRMSD: 0.92 angstroms
MHC/B2M: Number of CA: 681, iRMSD: 1.02 angstroms
CDR3 TCRA: 2.70 angstroms
CDR3 TCRB: 3.40 angstroms

  Processing model 3...
Processed: 1lp9 model 3
Number of CA in interface: 955, Overall iRMSD: 1.83 angstroms
TCRA/TCRB: Number of CA: 198, iRMSD: 3.44 angstroms
Peptide: Number of CA: 76, iRMSD: 0.99 angstroms
MHC/B2M: Number of CA: 681, iRMSD: 1.07 angstroms
CDR3 TCRA: 2.99 angstroms
CDR3 TCRB: 3.49 angstroms

  Processing model 1...
Processed: 1lp9 model 1
Number of CA in interface: 955, Overall iRMSD: 1.86 angstroms
TCRA/TCRB: Number of CA: 198, iRMSD: 3.48 angstroms
Peptide: Number of CA: 76, iRM

# MODEL QUALITY METRICS

In [22]:
def calculate_global_plddt(json_file_path):
    """
    Calculate the mean of the values in the `atom_plddts` key from a JSON file.

    Args:
        json_file_path (str): Path to the JSON file.

    Returns:
        float: The mean of the values in `atom_plddts`, or None if the key is empty.
    """
    try:
        with open(json_file_path, 'r') as file:
            data = json.load(file)
        
        atom_plddts = data.get('atom_plddts', [])
        if not atom_plddts:
            print("No data found in 'atom_plddts'.")
            return None
        
        mean_plddt = sum(atom_plddts) / len(atom_plddts)
        return mean_plddt

    except FileNotFoundError:
        print(f"Error: The file {json_file_path} was not found.")
        return None
    except json.JSONDecodeError:
        print("Error: Failed to decode the JSON file. Please check its format.")
        return None
    
def extract_b_factors(cdr_atoms, chain):
    """
    Extract B-factors for a given list of atoms from a specific chain.
    
    Args:
        cdr_atoms: List of tuples containing (atom_name, residue_id, residue_name, chain_id).
        chain: Bio.PDB.Chain object corresponding to the chain to extract B-factors from.
        
    Returns:
        A list of B-factors for the specified atoms.
    """
    b_factors = []
    for atomname, resid, resname, chainid in cdr_atoms:
        if chainid == chain.id:
            try:
                residue = chain[resid]  # Access the residue using its ID
                if atomname in residue:  # Check if the atom exists in the residue
                    atom = residue[atomname]
                    b_factors.append(atom.get_bfactor())
                else:
                    print(f"Atom {atomname} not found in residue {resid} ({resname}) of chain {chain.id}")
            except KeyError:
                print(f"Residue {resid} ({resname}) not found in chain {chain.id}")
    return b_factors

def cdr_plddts(model_file, alpha_chain, beta_chain):
    model_sequences, model_dict = extract_sequences(model_file)
    residues_A = extract_residues_and_resids(model_file, alpha_chain)
    residues_B = extract_residues_and_resids(model_file, beta_chain)
    anarci_A = run_anarci(model_sequences[alpha_chain])
    anarci_B = run_anarci(model_sequences[beta_chain])
    parsed_A = parse_anarci_output(anarci_A)
    parsed_B = parse_anarci_output(anarci_B)
    map_A = map_imgt_to_original(parsed_A, residues_A)
    map_B = map_imgt_to_original(parsed_B, residues_B)

    # Parse CDR regions from the maps
    cdr3_A = parse_CDR3(map_A)
    cdr3_B = parse_CDR3(map_B)
    cdr2_A = parse_CDR2(map_A)
    cdr2_B = parse_CDR2(map_B)
    cdr1_A = parse_CDR1(map_A)
    cdr1_B = parse_CDR1(map_B)

    # Extract atom information for each CDR
    cdr3_atoms_A = extract_atoms_for_cdr(cdr3_A, model_file, alpha_chain)
    cdr3_atoms_B = extract_atoms_for_cdr(cdr3_B, model_file, beta_chain)
    cdr2_atoms_A = extract_atoms_for_cdr(cdr2_A, model_file, alpha_chain)
    cdr2_atoms_B = extract_atoms_for_cdr(cdr2_B, model_file, beta_chain)
    cdr1_atoms_A = extract_atoms_for_cdr(cdr1_A, model_file, alpha_chain)
    cdr1_atoms_B = extract_atoms_for_cdr(cdr1_B, model_file, beta_chain)

    
    # Parse the structure based on file type (PDB or MMCIF)
    if model_file.endswith(".pdb"):
        parser = PDB.PDBParser(QUIET=True)
    else:
        parser = PDB.MMCIFParser(QUIET=True)
        
    structure = parser.get_structure("Model", model_file)
    chain_A = structure[0][alpha_chain]
    chain_B = structure[0][beta_chain]
    
    # Extract B-factors for each CDR region separately
    b_factors_cdr1_A = extract_b_factors(cdr1_atoms_A, chain_A)
    b_factors_cdr2_A = extract_b_factors(cdr2_atoms_A, chain_A)
    b_factors_cdr3_A = extract_b_factors(cdr3_atoms_A, chain_A)
    
    b_factors_cdr1_B = extract_b_factors(cdr1_atoms_B, chain_B)
    b_factors_cdr2_B = extract_b_factors(cdr2_atoms_B, chain_B)
    b_factors_cdr3_B = extract_b_factors(cdr3_atoms_B, chain_B)

    # Mean 
    mean_cdr1_A = np.mean(b_factors_cdr1_A)
    mean_cdr2_A = np.mean(b_factors_cdr2_A)
    mean_cdr3_A = np.mean(b_factors_cdr3_A)
    mean_cdr1_B = np.mean(b_factors_cdr1_B)
    mean_cdr2_B = np.mean(b_factors_cdr2_B)
    mean_cdr3_B = np.mean(b_factors_cdr3_B)

    # Return the B-factors for each CDR region separately
    return mean_cdr1_A, mean_cdr2_A, mean_cdr3_A, mean_cdr1_B, mean_cdr2_B, mean_cdr3_B

def calculate_iptms(json_file_path, length=5):
    """
    Calculates the mean of `chain_iptm` and the mean of interface TCR-pMHC iPTMs
    using fixed chain mappings.

    Args:
        json_file_path (str): Path to the JSON file.

    Returns:
        dict: A dictionary with the calculated means.
    """
    try:
        # Load the JSON data from the file
        with open(json_file_path, 'r') as file:
            data = json.load(file)
        
        # Calculate the mean of chain_iptm
        chain_iptm = data.get('chain_iptm', [])
        if not chain_iptm:
            print("No data found in 'chain_iptm'.")
            chain_iptm_mean = None
        else:
            chain_iptm_mean = sum(chain_iptm) / len(chain_iptm)
        
        # Calculate the mean for interface TCR-pMHC
        chain_pair_iptm = data.get('chain_pair_iptm', [])
        if not chain_pair_iptm:
            print("No data found in 'chain_pair_iptm'.")
            tcr_pmch_mean = None
        else:
            # Fixed indices for TCR-pMHC interactions
            # A (MHC) = 0, B (B2M) = 1, C (peptide) = 2, D (TCRa) = 3, E (TCRb) = 4
            if length == 5:
                tcr_pmch_pairs = [
                    chain_pair_iptm[0][3],  # MHC-TCRa 
                    chain_pair_iptm[0][4],  # MHC-TCRb 
                    chain_pair_iptm[2][3],  # pep-TCRa
                    chain_pair_iptm[2][4]]  # pep-TCRb 
                tcr_pmch_iptm = sum(tcr_pmch_pairs) / len(tcr_pmch_pairs)
            elif length == 4:
                # A (MHC) = 0, C (peptide) = 1, D (TCRa) = 2, E (TCRb) = 3
                tcr_pmch_pairs = [
                    chain_pair_iptm[0][2],  # MHC-TCRa 
                    chain_pair_iptm[0][3],  # MHC-TCRb 
                    chain_pair_iptm[1][2],  # pep-TCRa
                    chain_pair_iptm[1][3]]  # pep-TCRb 
                tcr_pmch_iptm = sum(tcr_pmch_pairs) / len(tcr_pmch_pairs)
        return chain_iptm_mean, tcr_pmch_iptm
    
    except FileNotFoundError:
        print(f"File not found: {json_file_path}")
        return None
    except KeyError as e:
        print(f"Missing key in JSON data: {e}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
    
import json

def calculate_pdockq (model_file):
    command=f"python ./scripts_py/pdockq.py --pdbfile {model_file}"
    result = subprocess.run(command, shell=True, capture_output=True, text=True, check=True)
    # Output is displayed as pDockQ = 0.609 for ./pre/merged_models_AB/1ao7_0_merged.pdb This corresponds to a PPV of at least 0.9400192
    # Capture pDockq
    pdockq = float(result.stdout.split('=')[1].split(' ')[1])
    return result.stdout, pdockq

def calculate_pdockq2_json (model_file, json_file):
    command=f"python ./scripts_py/pdockq2_json.py -json {json_file} -pdb {model_file}"
    result=subprocess.run(command, shell=True, capture_output=True, text=True, check=True)
    ipae_A = float(result.stdout.split('\n')[1].split(' ')[1])
    ipae_B = float(result.stdout.split('\n')[2].split(' ')[1])
    pdockq2_A= float(result.stdout.split('\n')[4].split(' ')[1])
    pdockq2_B= float(result.stdout.split('\n')[5].split(' ')[1])
    return result.stdout, ipae_A, ipae_B, pdockq2_A, pdockq2_B

In [10]:
all_data='./af3_output/1ao7/1ao7_full_data_0.json'
summary_conf= './models/1ao7/1ao7_summary_confidences_0.json'
model_pdb='./af3_output/cleaned_models_pdb/1ao7_model_0.pdb'
model_merged='./af3_output/merged_models_AB/1ao7_0_merged.pdb'

mean= calculate_global_plddt(all_data)
b_factors_cdr1_A, b_factors_cdr2_A, b_factors_cdr3_A, b_factors_cdr1_B, b_factors_cdr2_B, b_factors_cdr3_B = cdr_plddts(model_pdb, "D", "E")
iptm_mean, iptm_tcrpmhc = calculate_iptms(summary_conf)
stdout, pdockq = calculate_pdockq(model_merged)
stdout2, pdockq2_A, pdockq2_B = calculate_pdockq2_json(model_merged, all_data)

# Print the results
print("Global pLDDT:", mean)
print("\nCDR pLDDTs:")
print("CDR1 A:", b_factors_cdr1_A, "CDR1 B:", b_factors_cdr1_B)
print("CDR2 A:", b_factors_cdr2_A, "CDR2 B:", b_factors_cdr2_B)
print("CDR3 A:", b_factors_cdr3_A, "CDR3 B:", b_factors_cdr3_B)

print("\nipTM complex: ", iptm_mean, "TCR-pMHC ipTM", iptm_tcrpmhc)

print("\nPDockQ: ", pdockq)
print("\nPDockQ2: ", pdockq2_A, pdockq2_B)

Error: The file ./af3_output/1ao7/1ao7_full_data_0.json was not found.


FileNotFoundError: [Errno 2] No such file or directory: './af3_output/cleaned_models_pdb/1ao7_model_0.pdb'

In [None]:
for folder in os.listdir("./af_predictions_all/af3_output"):
    dataframe = pd.DataFrame()
    if len(folder) == 4:
        pdb_id = folder
        for folder2 in os.listdir(f"./af_predictions_all/af3_output/{folder}"):
            if folder2.startswith("seed-"):
                model_number = folder2.split("-")[-1]
                model_path = f"./af_predictions_all/models/cleaned_models/{pdb_id}_{model_number}.pdb"
                seqs, _ = extract_sequences(model_path)
                length = len(seqs)
                print(length)
                merged_model_AB=f"./af_predictions_all/models/merged_models_AB/{pdb_id}_{model_number}_merged.pdb"
                merged_model_BA=f"./af_predictions_all/models/merged_models_BA/{pdb_id}_{model_number}_merged.pdb"
                pdb_model=f"./af_predictions_all/models/cleaned_models/{pdb_id}_{model_number}.pdb"
                folder2_path=f"./af_predictions_all/af3_output/{folder}/{folder2}"
                print(f"Processing model {model_number} for pdb_id: {pdb_id}...")
                for file in os.listdir(folder2_path):
                    if file.endswith(".cif"):
                        cif_file = os.path.join(folder2_path, file)
                    if file.endswith(".json") and file.startswith("summary"):
                        summary_json = os.path.join(folder2_path, file)
                    if file.endswith(".json") and not file.startswith("summary"):
                        all_data_json = os.path.join(folder2_path, file)

                        mean = calculate_global_plddt(all_data_json)
                        
                        if length == 5:
                            b_factors_cdr1_A, b_factors_cdr2_A, b_factors_cdr3_A, b_factors_cdr1_B, b_factors_cdr2_B, b_factors_cdr3_B = cdr_plddts(pdb_model, "D", "E")
                            iptm_mean, iptm_tcrpmhc = calculate_iptms(summary_json)
                            
                        if length == 4:
                            b_factors_cdr1_A, b_factors_cdr2_A, b_factors_cdr3_A, b_factors_cdr1_B, b_factors_cdr2_B, b_factors_cdr3_B = cdr_plddts(pdb_model, "C", "D")
                            iptm_mean, iptm_tcrpmhc = calculate_iptms(summary_json, length=4)
                        _, pdockq_AB = calculate_pdockq(merged_model_AB)
                        _, pdockq_BA = calculate_pdockq(merged_model_BA)
                        _, avgipae_A, avgipae_B, pdockq2_A, pdockq2_B = calculate_pdockq2_json(merged_model_AB, all_data_json)
                        _, avgipae_A2, avgipae_B2, pdockq2_A2, pdockq2_B2 = calculate_pdockq2_json(merged_model_BA, all_data_json)
                        row = {
                                "pdb_id": pdb_id,
                                "model_number": model_number,
                                "global_plddt": mean,
                                "cdr1_A": b_factors_cdr1_A,
                                "cdr2_A": b_factors_cdr2_A,
                                "cdr3_A": b_factors_cdr3_A,
                                "cdr1_B": b_factors_cdr1_B,
                                "cdr2_B": b_factors_cdr2_B,
                                "cdr3_B": b_factors_cdr3_B,
                                "iptm_mean": iptm_mean,
                                "iptm_tcrpmhc": iptm_tcrpmhc,
                                "pdockq_AB": pdockq_AB,
                                "pdockq_BA": pdockq_BA,
                                "avgipae_A_AB": avgipae_A,
                                "avgipae_B_AB": avgipae_B,
                                "pdockq2_A_AB": pdockq2_A,
                                "pdockq2_A_BA": pdockq2_B,
                                "avgipae_A_BA": avgipae_A2,
                                "avgipae_B_BA": avgipae_B2,
                                "pdockq2_A_BA": pdockq2_A2,
                                "pdockq2_B_BA": pdockq2_B2}
                        print(row)
                        dataframe = pd.concat([dataframe, pd.DataFrame([row])], ignore_index=True)
df.to_csv("./af3_predictions_all/model.csv", index=False)

In [17]:
import os
import pandas as pd

# Path to the CSV file
csv_path = "./af3_predictions_all/model.csv"

# Ensure that the directory exists, if not, create it
csv_dir = os.path.dirname(csv_path)
if not os.path.exists(csv_dir):
    os.makedirs(csv_dir)

# Initialize or read the existing CSV file
if os.path.exists(csv_path) and os.stat(csv_path).st_size > 0:
    df_existing = pd.read_csv(csv_path)
    # Ensure that the CSV has the necessary columns
    if 'pdb_id' in df_existing.columns and 'model_number' in df_existing.columns:
        # Create a set of existing "pdb_id_model_number" to avoid reprocessing
        existing_combinations = set(zip(df_existing['pdb_id'], df_existing['model_number']))
    else:
        # If the necessary columns are missing, initialize an empty set
        existing_combinations = set()
else:
    # If the CSV doesn't exist or is empty, initialize an empty set
    df_existing = pd.DataFrame()
    existing_combinations = set()

# Iterate through the folders and files
for folder in os.listdir("./af_predictions_all/af3_output"):
    if len(folder) == 4:  # Assuming pdb_id length is 4
        pdb_id = folder
        for folder2 in os.listdir(f"./af_predictions_all/af3_output/{folder}"):
            if folder2.startswith("seed-"):
                model_number = folder2.split("-")[-1]
                # Skip this model if it's already in the CSV
                if (pdb_id, model_number) in existing_combinations:
                    print(f"Skipping {pdb_id}_{model_number} as it has already been processed.")
                    continue
                
                model_path = f"./af_predictions_all/models/cleaned_models/{pdb_id}_{model_number}.pdb"
                seqs, _ = extract_sequences(model_path)
                length = len(seqs)
                print(f"Processing model {model_number} for pdb_id: {pdb_id}...")
                
                # Construct paths
                merged_model_AB = f"./af_predictions_all/models/merged_models_AB/{pdb_id}_{model_number}_merged.pdb"
                merged_model_BA = f"./af_predictions_all/models/merged_models_BA/{pdb_id}_{model_number}_merged.pdb"
                pdb_model = f"./af_predictions_all/models/cleaned_models/{pdb_id}_{model_number}.pdb"
                folder2_path = f"./af_predictions_all/af3_output/{folder}/{folder2}"
                
                for file in os.listdir(folder2_path):
                    if file.endswith(".cif"):
                        cif_file = os.path.join(folder2_path, file)
                    if file.endswith(".json") and file.startswith("summary"):
                        summary_json = os.path.join(folder2_path, file)
                    if file.endswith(".json") and not file.startswith("summary"):
                        all_data_json = os.path.join(folder2_path, file)

                        # Calculate global pLDDT and iptm
                        mean = calculate_global_plddt(all_data_json)
                        
                        if length == 5:
                            b_factors_cdr1_A, b_factors_cdr2_A, b_factors_cdr3_A, b_factors_cdr1_B, b_factors_cdr2_B, b_factors_cdr3_B = cdr_plddts(pdb_model, "D", "E")
                            iptm_mean, iptm_tcrpmhc = calculate_iptms(summary_json)
                            
                        if length == 4:
                            b_factors_cdr1_A, b_factors_cdr2_A, b_factors_cdr3_A, b_factors_cdr1_B, b_factors_cdr2_B, b_factors_cdr3_B = cdr_plddts(pdb_model, "C", "D")
                            iptm_mean, iptm_tcrpmhc = calculate_iptms(summary_json, length=4)
                        
                        # Calculate docking parameters
                        _, pdockq_AB = calculate_pdockq(merged_model_AB)
                        _, pdockq_BA = calculate_pdockq(merged_model_BA)
                        _, avgipae_A, avgipae_B, pdockq2_A, pdockq2_B = calculate_pdockq2_json(merged_model_AB, all_data_json)
                        _, avgipae_A2, avgipae_B2, pdockq2_A2, pdockq2_B2 = calculate_pdockq2_json(merged_model_BA, all_data_json)
                        
                        # Create row data
                        row = {
                            "pdb_id": pdb_id,
                            "model_number": model_number,
                            "global_plddt": mean,
                            "cdr1_A": b_factors_cdr1_A,
                            "cdr2_A": b_factors_cdr2_A,
                            "cdr3_A": b_factors_cdr3_A,
                            "cdr1_B": b_factors_cdr1_B,
                            "cdr2_B": b_factors_cdr2_B,
                            "cdr3_B": b_factors_cdr3_B,
                            "iptm_mean": iptm_mean,
                            "iptm_tcrpmhc": iptm_tcrpmhc,
                            "pdockq_AB": pdockq_AB,
                            "pdockq_BA": pdockq_BA,
                            "avgipae_A_AB": avgipae_A,
                            "avgipae_B_AB": avgipae_B,
                            "pdockq2_A_AB": pdockq2_A,
                            "pdockq2_A_BA": pdockq2_B,
                            "avgipae_A_BA": avgipae_A2,
                            "avgipae_B_BA": avgipae_B2,
                            "pdockq2_A_BA": pdockq2_A2,
                            "pdockq2_B_BA": pdockq2_B2
                        }
                        print(row)
                        
                        # Save the row to the CSV immediately
                        row_df = pd.DataFrame([row])
                        row_df.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path), index=False)

                # Add the current combination to the set of processed models
                existing_combinations.add((pdb_id, model_number))

Processing model 0 for pdb_id: 5wlg...
{'pdb_id': '5wlg', 'model_number': '0', 'global_plddt': 89.97140990159902, 'cdr1_A': 83.16725, 'cdr2_A': 84.91209677419354, 'cdr3_A': 74.1575, 'cdr1_B': 86.90913043478261, 'cdr2_B': 86.6615909090909, 'cdr3_B': 81.1245054945055, 'iptm_mean': 0.792, 'iptm_tcrpmhc': 0.8, 'pdockq_AB': 0.401, 'pdockq_BA': 0.401, 'avgipae_A_AB': 0.3894232748161533, 'avgipae_B_AB': 0.4695454977934307, 'pdockq2_A_AB': 0.028794491384602106, 'pdockq2_A_BA': 0.669402399363523, 'avgipae_A_BA': 0.9221867984918537, 'avgipae_B_BA': 0.8459526847966568, 'pdockq2_B_BA': 0.29726644727714774}
Processing model 1 for pdb_id: 5wlg...
{'pdb_id': '5wlg', 'model_number': '1', 'global_plddt': 90.02234009840099, 'cdr1_A': 82.97375, 'cdr2_A': 85.01, 'cdr3_A': 73.95906249999999, 'cdr1_B': 86.555, 'cdr2_B': 86.10545454545455, 'cdr3_B': 80.59483516483516, 'iptm_mean': 0.796, 'iptm_tcrpmhc': 0.8025, 'pdockq_AB': 0.393, 'pdockq_BA': 0.393, 'avgipae_A_AB': 0.37492175745994655, 'avgipae_B_AB': 0.470