# Evaluation of predicted structures

In [9]:
import pandas as pd
from typing import List, Dict, Optional, Set, Tuple
import pickle

In [10]:
# Root folder of the per-entry data; later cells look up
# "<PROTEIN_DIRECTORY>/<pdb_id>/..." beneath it.
PROTEIN_DIRECTORY = "data/proteins"

# Table of chains to evaluate (columns shown in the output: pdb_id, label, chain_id).
# NOTE(review): the rendered frame contains an "Unnamed: 0" column, which looks
# like an index that was saved into the CSV. It is carried through to the
# exported evaluation file; consider `index_col = 0` once downstream readers
# have been checked.
CHAINS_CSV_PATH = "data/chains.csv"
chains = pd.read_csv(CHAINS_CSV_PATH)
chains

Unnamed: 0,pdb_id,label,chain_id
0,8P0E,monomer,8P0E:A
1,8PX8,monomer,8PX8:A
2,8B2E,monomer,8B2E:A
3,8HOE,monomer,8HOE:A
4,8TCE,monomer,8TCE:A
...,...,...,...
1458,8G9J,synthetic,8G9J:A
1459,8OYV,synthetic,8OYV:A
1460,8TNO,synthetic,8TNO:A
1461,8FJE,synthetic,8FJE:A


Definition of the `Structure` class and the helper classes it is built from (`Atom`, `Residue`, `Chain`):

In [11]:
from Bio.Data.PDBData import protein_letters_3to1_extended

class Atom:
    """Cartesian coordinates of a single atom."""

    def __init__(self, x: float, y: float, z: float):
        """Store the three coordinates as the attributes `x`, `y` and `z`."""
        self.x, self.y, self.z = x, y, z


class Residue:
    """One residue of a chain, optionally carrying its alpha-carbon coordinates."""

    def __init__(self, amino_acid: str, position: int, alpha: str, order: int):
        """Create a residue with no alpha carbon and both flags cleared.

        Args:
            amino_acid: Letter stored for this residue.
            position: First component of the parsed position string.
            alpha: Second component of the parsed position string
                (presumably the insertion code - TODO confirm).
            order: Insertion index assigned by the owning chain.
        """
        # Identity supplied by the caller.
        self.amino_acid, self.position = amino_acid, position
        self.alpha, self.order = alpha, order
        # State that is filled in later while the file is parsed.
        self.ca = None
        self.is_hetatm = self.is_terminal = False

    def add_alpha_carbon(self, x: float, y: float, z: float) -> None:
        """Attach alpha-carbon coordinates to this residue as an `Atom`."""
        alpha_carbon = Atom(x, y, z)
        self.ca = alpha_carbon
        

class Chain:
    """A single chain: its expected sequence plus the residues collected for it.

    Residues are stored in `residues`, keyed by their raw position string, in
    the order in which they were first added.
    """

    def __init__(self, letter: str, expected_sequence: str):
        """Create an empty chain identified by `letter`."""
        self.letter = letter
        self.expected_sequence = expected_sequence
        self.residues = {}
        self.residue_counter = 0
        # Both are filled in by `save_sequence_and_mask`.
        self.sequence = None
        self.mask = None

    def add_residue(self, amino_acid: str, position_string: str) -> Residue:
        """Create a residue, store it under `position_string` and return it.

        An existing residue with the same position string is replaced.
        NOTE(review): `process_position` is not defined in the cells visible
        here, so this method needs it to be provided from elsewhere.
        """
        position, alpha = process_position(position_string)
        new_residue = Residue(amino_acid, position, alpha, self.residue_counter)
        self.residue_counter += 1
        self.residues[position_string] = new_residue
        return new_residue

    def get_residue(self, amino_acid: str, position_string: str) -> Residue:
        """Return the residue stored at `position_string`, creating it if absent."""
        existing = self.residues.get(position_string)
        if existing is not None:
            return existing
        return self.add_residue(amino_acid, position_string)

    def save_sequence_and_mask(self) -> bool:
        """Build `self.sequence` and `self.mask` from the collected residues.

        Residues are visited sorted by (position, alpha, insertion order).
        The mask has one character per sequence letter: "1" if the residue has
        an alpha carbon, "0" otherwise. HETATM residues that sort after the
        terminal residue are left out of both strings.

        Returns:
            True if the built sequence equals the expected sequence once
            leading and trailing "X" characters are stripped from both.
        """
        ordered = sorted(
            self.residues.values(),
            key = lambda residue: (residue.position, residue.alpha, residue.order),
        )
        letters, flags = [], []
        past_terminal = False
        for residue in ordered:
            if past_terminal and residue.is_hetatm:
                continue
            letters.append(residue.amino_acid)
            flags.append("0" if residue.ca is None else "1")
            if residue.is_terminal:
                past_terminal = True
        self.sequence = "".join(letters)
        self.mask = "".join(flags)
        return self.sequence.strip("X") == self.expected_sequence.strip("X")


class Structure:
    """A PDB entry: its identifier plus one `Chain` per expected chain letter.

    Individual lines of a PDB file are passed to the `parse_*` methods.
    Records that refer to a chain letter outside `expected_chains` are ignored.
    """

    def __init__(self, pdb_id: str, expected_chains: Dict[str, str]):
        """Create an empty `Chain` for every (chain letter, expected sequence) pair."""
        self.pdb_id = pdb_id
        self.chains = {chain_letter: Chain(chain_letter, expected_sequence) for chain_letter, expected_sequence in expected_chains.items()}

    def parse_ATOM_and_HETATM(self, line: str) -> None:
        """Register the residue named on an ATOM/HETATM line.

        The line is read by fixed column slices. Lines whose residue name has
        no one-letter code in `protein_letters_3to1_extended` are skipped.
        Coordinates are stored only when the atom name is "CA".
        """
        is_hetatm = line.startswith("HETATM")
        atom_name = line[12:16].strip()
        amino_acid = protein_letters_3to1_extended.get(line[17:20])
        chain = self.chains.get(line[21])
        position_string = line[22:27].strip()
        if amino_acid is not None and chain is not None:
            residue = chain.get_residue(amino_acid, position_string)
            if is_hetatm:
                residue.is_hetatm = True
            if atom_name == "CA":
                residue.add_alpha_carbon(float(line[30:38]), float(line[38:46]), float(line[46:54]))

    def parse_TER(self, line: str) -> None:
        """Mark the residue named on a TER line as terminal.

        The residue is created if it does not exist yet; an unknown residue
        name is stored as "X".
        """
        amino_acid = protein_letters_3to1_extended.get(line[17:20], "X")
        chain = self.chains.get(line[21])
        position_string = line[22:27].strip()
        if chain is not None:
            chain.get_residue(amino_acid, position_string).is_terminal = True

    def parse_REMARK_465_and_MODRES(self, line: str, is_REMARK_465_line: bool) -> None:
        """Add the residue named on a REMARK 465 or MODRES line to its chain.

        The line is split on whitespace. The chain letter is token 3 and the
        position is token 4; the residue name is token 2 for REMARK 465 lines
        and token 5 for MODRES lines. Unknown residue names are stored as "X".
        Lines that are too short to contain all required tokens are ignored.
        """
        amino_acid_position = 2 if is_REMARK_465_line else 5
        attributes = line.split()
        # A MODRES line needs six tokens (index 5 is read); checking only for
        # five would raise IndexError on a truncated MODRES record.
        if len(attributes) <= max(amino_acid_position, 4):
            return
        amino_acid = protein_letters_3to1_extended.get(attributes[amino_acid_position], "X")
        chain = self.chains.get(attributes[3])
        position_string = attributes[4]
        if chain is not None:
            chain.add_residue(amino_acid, position_string)

    def save_sequences_and_masks(self) -> bool:
        """Build sequence and mask for each chain, stopping at the first mismatch.

        Returns:
            False as soon as a chain reports that its sequence does not match
            the expected one (later chains are then not processed), else True.
        """
        for chain in self.chains.values():
            if not chain.save_sequence_and_mask():
                return False
        return True

    def write_to_files(self, directory: str, wanted_chain_letters: Set[str]) -> None:
        """Write the sequences of the wanted chains as FASTA files.

        Creates "<directory>/<pdb_id>_inferred.fasta" with all wanted chains
        and, per wanted chain, "<directory>/<pdb_id>:<chain_letter>.fasta".
        """
        sequences = {chain_letter: chain.sequence for chain_letter, chain in self.chains.items()}
        with open(f"{directory}/{self.pdb_id}_inferred.fasta", "w") as fasta_file:
            for chain_letter, sequence in sequences.items():
                if chain_letter in wanted_chain_letters:
                    fasta_file.write(f">{self.pdb_id}:{chain_letter}\n{sequence}\n")
                    with open(f"{directory}/{self.pdb_id}:{chain_letter}.fasta", "w") as chain_fasta_file:
                        chain_fasta_file.write(f">{self.pdb_id}:{chain_letter}\n{sequence}\n")


## pLDDT

In [12]:
from Bio.PDB import PDBParser
from statistics import mean
import os

def extract_average_pLDDT(pdb_path, multiply_by_100 = False):
    """Return the mean alpha-carbon B-factor of the structure stored at `pdb_path`.

    The B-factor column is treated as the per-residue pLDDT value.

    Args:
        pdb_path: Path of the PDB file to read.
        multiply_by_100: If True, the mean is multiplied by 100 (presumably for
            predictors that store the value on a 0-1 scale - TODO confirm).

    Returns:
        The mean value, or None if the file does not exist or contains no
        residue with a "CA" atom.
    """
    if not os.path.exists(pdb_path):
        return None
    structure = PDBParser(QUIET = True).get_structure("X", pdb_path)
    # Residues without an alpha carbon are skipped; indexing them with
    # residue["CA"] would raise KeyError.
    b_factors = [
        residue["CA"].get_bfactor()
        for residue in structure.get_residues()
        if "CA" in residue
    ]
    if not b_factors:
        # mean() raises StatisticsError on an empty list.
        return None
    average = mean(b_factors)
    return average * 100 if multiply_by_100 else average


# (column prefix, predictor sub-directory, whether the mean must be multiplied by 100)
PLDDT_SOURCES = [
    ("AF", "alphafold", False),
    ("OF", "omegafold", False),
    ("EF", "esmfold", True),
]

# One "<prefix>_average_pLDDT" column per predictor; rows whose prediction file
# is missing receive no value.
for prefix, predictor, scale_to_percent in PLDDT_SOURCES:
    chains[f"{prefix}_average_pLDDT"] = chains.apply(
        lambda row: extract_average_pLDDT(
            f"{PROTEIN_DIRECTORY}/{row['pdb_id']}/{predictor}/{row['chain_id']}.pdb",
            multiply_by_100 = scale_to_percent,
        ),
        axis = 1,
    )

chains

Unnamed: 0,pdb_id,label,chain_id,AF_average_pLDDT,OF_average_pLDDT,EF_average_pLDDT
0,8P0E,monomer,8P0E:A,93.108947,93.535368,89.263158
1,8PX8,monomer,8PX8:A,95.914087,95.609217,95.147826
2,8B2E,monomer,8B2E:A,97.636503,97.136084,94.762238
3,8HOE,monomer,8HOE:A,93.577196,79.554392,89.613757
4,8TCE,monomer,8TCE:A,90.252128,89.159681,86.180851
...,...,...,...,...,...,...
1458,8G9J,synthetic,8G9J:A,90.586771,92.542063,89.071749
1459,8OYV,synthetic,8OYV:A,86.604205,88.043795,82.605128
1460,8TNO,synthetic,8TNO:A,90.392837,90.149504,87.758865
1461,8FJE,synthetic,8FJE:A,93.426345,93.104138,88.586207


## TM score

In [13]:
from tmtools import tm_align
from prody import parsePDB, AtomGroup
from numpy import array

In [14]:
def get_original_coords_and_sequence(structure: Structure, chain_letter: str):
    """Collect alpha-carbon coordinates and letters of one chain of `structure`.

    Residues are visited in the order in which they are stored in the chain.
    A residue contributes only if it has an alpha carbon and its letter is not
    "X". Iteration stops after the first residue flagged as terminal.

    Returns:
        A tuple (array of [x, y, z] rows, sequence string) with one entry per
        contributing residue.
    """
    kept = []
    for residue in structure.chains[chain_letter].residues.values():
        has_alpha_carbon = residue.ca is not None
        if has_alpha_carbon and residue.amino_acid != "X":
            kept.append((residue.amino_acid, [residue.ca.x, residue.ca.y, residue.ca.z]))
        if residue.is_terminal:
            break
    sequence = "".join(letter for letter, _ in kept)
    ca_coords = array([xyz for _, xyz in kept])
    return ca_coords, sequence


def get_predicted_coords_and_sequence(pdb_path: str, mask: str):
    """Read the alpha carbons of chain "A" from a predicted structure, filtered by `mask`.

    Args:
        pdb_path: Path of the predicted PDB file.
        mask: String with one character per alpha carbon of the prediction;
            only atoms whose character is "1" are kept.

    Returns:
        (None, None) if the file does not exist, otherwise a tuple
        (array of coordinates, sequence string) of the kept atoms.

    Raises:
        AssertionError: If the number of alpha carbons differs from len(mask).
        KeyError: If a kept residue name has no one-letter code.
    """
    if not os.path.exists(pdb_path):
        return None, None
    chain = parsePDB(pdb_path, chain = "A", subset = 'calpha')
    # A real message replaces `print(pdb_path)`, which evaluated to None and
    # left the AssertionError without any text.
    assert len(chain) == len(mask), (
        f"{pdb_path}: prediction has {len(chain)} alpha carbons but the mask has length {len(mask)}"
    )
    coords, sequence = [], []
    for atom, flag in zip(chain, mask):
        if flag == "1":
            coords.append(atom.getCoords())
            sequence.append(protein_letters_3to1_extended[atom.getResname()])
    return array(coords), "".join(sequence)


def compute_tm_score(chain_letter: str, pickle_path: str, prediction_path: str) -> Optional[float]:
    """Compute the TM-score between an original chain and its predicted structure.

    Args:
        chain_letter: Letter of the chain inside the pickled structure.
        pickle_path: Path of the pickled `Structure` holding the original chain.
        prediction_path: Path of the predicted PDB file.

    Returns:
        `tm_norm_chain1` of the alignment (original chain passed first), or
        None if the prediction file is missing or at most two residues remain
        after masking.

    Raises:
        AssertionError: If original and predicted sequences differ in length.
    """
    # NOTE: pickle.load can execute arbitrary code; only load pickles that were
    # produced by this project.
    with open(pickle_path, "rb") as pickle_file:
        structure = pickle.load(pickle_file)

    original_coords, original_sequence = get_original_coords_and_sequence(structure, chain_letter)
    predicted_coords, predicted_sequence = get_predicted_coords_and_sequence(
        prediction_path, structure.chains[chain_letter].mask
    )
    if predicted_coords is None:
        return None
    # A real message replaces `print(prediction_path)`, which evaluated to None.
    assert len(original_sequence) == len(predicted_sequence), (
        f"{prediction_path}: original has {len(original_sequence)} residues "
        f"but prediction has {len(predicted_sequence)}"
    )
    # Very short chains are not scored (presumably too few points for a
    # meaningful alignment - TODO confirm).
    if len(predicted_sequence) <= 2:
        return None
    result = tm_align(original_coords, predicted_coords, original_sequence, predicted_sequence)
    return result.tm_norm_chain1



# Column prefix -> sub-directory holding that predictor's output.
PREDICTOR_DIRECTORIES = {"AF": "alphafold", "OF": "omegafold", "EF": "esmfold"}

# One "<prefix>_TM_score" column per predictor. The chain letter is the part of
# chain_id after the colon (e.g. "A" for "8P0E:A").
for prefix, predictor in PREDICTOR_DIRECTORIES.items():
    chains[f"{prefix}_TM_score"] = chains.apply(
        lambda row: compute_tm_score(
            row["chain_id"].split(":")[1],
            f"{PROTEIN_DIRECTORY}/{row['pdb_id']}/{row['pdb_id']}.pkl",
            f"{PROTEIN_DIRECTORY}/{row['pdb_id']}/{predictor}/{row['chain_id']}.pdb",
        ),
        axis = 1,
    )

chains

Unnamed: 0,pdb_id,label,chain_id,AF_average_pLDDT,OF_average_pLDDT,EF_average_pLDDT,AF_TM_score,OF_TM_score,EF_TM_score
0,8P0E,monomer,8P0E:A,93.108947,93.535368,89.263158,0.966687,0.967459,0.982825
1,8PX8,monomer,8PX8:A,95.914087,95.609217,95.147826,0.995011,0.995740,0.989534
2,8B2E,monomer,8B2E:A,97.636503,97.136084,94.762238,0.991946,0.979414,0.990990
3,8HOE,monomer,8HOE:A,93.577196,79.554392,89.613757,0.988360,0.863854,0.976008
4,8TCE,monomer,8TCE:A,90.252128,89.159681,86.180851,0.938542,0.920785,0.920808
...,...,...,...,...,...,...,...,...,...
1458,8G9J,synthetic,8G9J:A,90.586771,92.542063,89.071749,0.637726,0.984365,0.984592
1459,8OYV,synthetic,8OYV:A,86.604205,88.043795,82.605128,0.981296,0.979309,0.980517
1460,8TNO,synthetic,8TNO:A,90.392837,90.149504,87.758865,0.966437,0.966444,0.958189
1461,8FJE,synthetic,8FJE:A,93.426345,93.104138,88.586207,0.970529,0.965841,0.972681


## RMSD

In [15]:
from Bio.SVDSuperimposer import SVDSuperimposer
import Bio.PDB
from numpy import save


def compute_RMSD(chain_letter: str, pickle_path: str, prediction_path: str, rotran_path: str) -> Optional[float]:
    """Superimpose a predicted chain onto the original one and return the RMSD.

    The rotation and translation found by the superimposer are saved as
    "rotation.npy" and "translation.npy" inside `rotran_path`. The file names
    are fixed, so calls sharing the same `rotran_path` overwrite each other.

    Args:
        chain_letter: Letter of the chain inside the pickled structure.
        pickle_path: Path of the pickled `Structure` holding the original chain.
        prediction_path: Path of the predicted PDB file.
        rotran_path: Existing directory that receives the two .npy files.

    Returns:
        The RMSD after superposition, or None if the prediction file is missing
        or no alpha carbon remains after masking.

    Raises:
        AssertionError: If original and predicted coordinate counts differ.
    """
    # NOTE: pickle.load can execute arbitrary code; only load pickles that were
    # produced by this project.
    with open(pickle_path, "rb") as pickle_file:
        structure = pickle.load(pickle_file)

    original_coords, _ = get_original_coords_and_sequence(structure, chain_letter)
    predicted_coords, _ = get_predicted_coords_and_sequence(prediction_path, structure.chains[chain_letter].mask)
    if predicted_coords is None:
        return None
    assert len(original_coords) == len(predicted_coords), (
        f"{prediction_path}: original has {len(original_coords)} alpha carbons "
        f"but prediction has {len(predicted_coords)}"
    )
    if len(predicted_coords) == 0:
        # Nothing to superimpose; the superimposer cannot handle empty input.
        return None
    sup = SVDSuperimposer()
    # The original coordinates are the reference; the prediction is fitted onto them.
    sup.set(original_coords, predicted_coords)
    sup.run()
    rotation, translation = sup.get_rotran()
    save(f"{rotran_path}/rotation.npy", rotation)
    save(f"{rotran_path}/translation.npy", translation)
    return sup.get_rms()



# Column prefix -> sub-directory holding that predictor's output.
PREDICTOR_DIRECTORIES = {"AF": "alphafold", "OF": "omegafold", "EF": "esmfold"}

# One "<prefix>_RMSD" column per predictor. The predictor's directory also
# receives the rotation/translation files written by compute_RMSD.
for prefix, predictor in PREDICTOR_DIRECTORIES.items():
    chains[f"{prefix}_RMSD"] = chains.apply(
        lambda row: compute_RMSD(
            row["chain_id"].split(":")[1],
            f"{PROTEIN_DIRECTORY}/{row['pdb_id']}/{row['pdb_id']}.pkl",
            f"{PROTEIN_DIRECTORY}/{row['pdb_id']}/{predictor}/{row['chain_id']}.pdb",
            f"{PROTEIN_DIRECTORY}/{row['pdb_id']}/{predictor}",
        ),
        axis = 1,
    )

chains

Unnamed: 0,pdb_id,label,chain_id,AF_average_pLDDT,OF_average_pLDDT,EF_average_pLDDT,AF_TM_score,OF_TM_score,EF_TM_score,AF_RMSD,OF_RMSD,EF_RMSD
0,8P0E,monomer,8P0E:A,93.108947,93.535368,89.263158,0.966687,0.967459,0.982825,1.757680,1.738276,0.703625
1,8PX8,monomer,8PX8:A,95.914087,95.609217,95.147826,0.995011,0.995740,0.989534,0.275844,0.253193,0.706370
2,8B2E,monomer,8B2E:A,97.636503,97.136084,94.762238,0.991946,0.979414,0.990990,0.400699,0.668645,0.425292
3,8HOE,monomer,8HOE:A,93.577196,79.554392,89.613757,0.988360,0.863854,0.976008,0.564870,2.354328,0.834659
4,8TCE,monomer,8TCE:A,90.252128,89.159681,86.180851,0.938542,0.920785,0.920808,1.805852,2.571247,2.617893
...,...,...,...,...,...,...,...,...,...,...,...,...
1458,8G9J,synthetic,8G9J:A,90.586771,92.542063,89.071749,0.637726,0.984365,0.984592,6.550639,0.678841,0.674661
1459,8OYV,synthetic,8OYV:A,86.604205,88.043795,82.605128,0.981296,0.979309,0.980517,0.725727,0.761958,0.734274
1460,8TNO,synthetic,8TNO:A,90.392837,90.149504,87.758865,0.966437,0.966444,0.958189,1.621622,1.482032,1.525321
1461,8FJE,synthetic,8FJE:A,93.426345,93.104138,88.586207,0.970529,0.965841,0.972681,0.845441,0.909988,0.804459


## Save `chains` to a .csv file

In [16]:
# Persist the table with all metric columns; the row index is not written
# (the comma separator is pandas' default).
chains.to_csv("data/chains_evaluation.csv", index = False)