In [None]:
import numpy as np
import pandas as pd

from data import ThermoMutDB
from interfaces import ProteinPredictionTask, ProteinPredictionReturnType
from models import ESM3Model
from utils import ProteinComparator, ProteinComparatorMethod, ProteinAlignment

In [None]:
# Build the long-lived objects shared by the cells below: the ESM3 structure
# predictor, the structure comparator, and the ThermoMutDB evaluation set.
model = ESM3Model()
comparator = ProteinComparator()
test_set = ThermoMutDB(
    filepath="edited_thermomutdb_subset_esm3.json",
    mutation_manifest_path="../data/thermomutdb_alphafold_investigation.csv",
)

In [None]:
# Run mode. "mutant" applies the point mutation to the sequence before structure
# prediction; any other value (e.g. "unaltered") folds the sequence unchanged.
# The value is also embedded in the name of the results CSV written at the end.
SUFFIX = "mutant" #@param ['mutant','unaltered'] {type:"raw"}

In [None]:
# Outcome of the inference loop below. Holds one `(idx, alignments)` tuple per
# scored sample and one `(idx, exception)` tuple per sample that raised; samples
# that are filtered out before inference add no entry at all.
test_set_results = []

print(f"Running through {len(test_set)} samples...")

# Run through all samples in the test set: predict a structure for each sequence
# and score it against the sample's `real_pdb_wild` reference structure.
for idx, sample in enumerate(test_set):
    try:
        # Obtain necessary information.
        # The last 4 characters of the FASTA are dropped before prediction, so
        # variants located in that tail cannot be evaluated (see the range check below).
        original_protein_fasta = sample["real_fasta"].iloc[0][:-4]
        mutation_code = sample["mutation_code"].iloc[0]
        target_mutation_id = sample["target_mutation_id"].iloc[0]
        real_protein_pdb = sample["real_pdb_wild"].iloc[0]

        # Capture the mutation index and amino acid change.
        # The code is parsed as <original residue><1-based position><variant residue>;
        # the position is converted to a 0-based string index.
        original_amino_acid = mutation_code[0]
        variant_position = int(mutation_code[1:-1]) - 1
        variant_amino_acid = mutation_code[-1]

        # Skip variants that fall outside the truncated sequence. Besides variants in
        # the trimmed last 4 amino acids, this also rejects positions <= 0: a negative
        # index would silently wrap around to the end of the string, and the mutant
        # splice below would then build a corrupted, almost doubled sequence.
        if not 0 <= variant_position < len(original_protein_fasta):
            print(f"Mutation ID: {target_mutation_id}")
            print("Variant position is outside the usable sequence. Skipping...")
            continue

        # Make sure the amino acid in the FASTA sequence matches the original amino acid in the variant
        sequence_position_matches_variant_original_amino_acid = original_amino_acid == original_protein_fasta[variant_position]

        print(f"Mutation ID: {target_mutation_id}")
        print(f"Assert: {original_amino_acid} matches {original_protein_fasta[variant_position]} is {sequence_position_matches_variant_original_amino_acid}.", end=" ")

        # Filter if amino acid in the sequence does not match original amino acid in the variant
        if not sequence_position_matches_variant_original_amino_acid:
            print("Skipping...")
            continue
        # Terminate the "Assert: ..." line that was printed with end=" ".
        print()

        if SUFFIX == "mutant":
            # Modify the sequence to be the variant sequence. From here on
            # `original_protein_fasta` holds the sequence that is fed to the model.
            original_protein_fasta = original_protein_fasta[:variant_position] + variant_amino_acid + original_protein_fasta[variant_position + 1:]

        # Run structure prediction on the (possibly mutated) sequence
        resultant_protein = model(
            ProteinPredictionTask.STRUCTURE_PREDICTION,
            protein=original_protein_fasta,
            generation_config_kwargs={ "num_steps": 1, "temperature": 0.0 },
            return_type=ProteinPredictionReturnType.DEFAULT,
        )
        resultant_pdb = resultant_protein.to_pdb_string()

        # Compute the optimal alignment and USalign/TMalign scores
        resulting_alignment = comparator.compute_score_and_alignment(
            resultant_pdb,
            real_protein_pdb,
        )

        # Print the scores of every alignment method that was computed
        for alignment_score in resulting_alignment:
            print(f"{alignment_score.method.value}: {alignment_score.final_score}")

        # Append the current alignment to the results, followed by a synthetic
        # "alignment" entry that carries the model's pLDDT confidence values.
        test_set_results.append((idx, resulting_alignment + [
            ProteinAlignment(
                method=ProteinComparatorMethod.LDDT,
                pdb1=resultant_pdb,
                pdb2=None,
                superimposed_pdb=None,
                score1=resultant_protein.plddt,
                score2=resultant_protein.plddt[variant_position].item(), # The confidence of the atomic coordinates of the amino-acid/residue at the variant position
                final_score=np.average(resultant_protein.plddt),
                auxiliary=None,
            )
        ]))

        # Visualize the alignment
        # comparator.visualize_alignment(resulting_alignment[0])
    except Exception as e:
        # Skip the sample due to the exception, but record the exception so that
        # failed samples can be told apart from filtered ones afterwards.
        print(f"Skipping sample at index {idx} due to `{e}`.")
        test_set_results.append((idx, e))

In [None]:
# print(test_set_results)

In [None]:
# Sample columns copied verbatim into the results table next to the scores.
EXTRA_FIELDS = ["target_id", "pdb_id", "mutation_code"]

# Index the inference outcomes by sample position once, instead of rescanning the
# whole result list for every sample. `setdefault` keeps the first entry recorded
# for an index.
results_by_idx = {}
for result_idx, alignment_list_or_err in test_set_results:
    results_by_idx.setdefault(result_idx, alignment_list_or_err)

# Column-oriented table: one list per output column, filled row by row below.
data = { field_name: [] for field_name in EXTRA_FIELDS }
data.update({
    "US-Align": [],
    "pLDDT_of_variant_residue": [],
    "avg_pLDDT": [],
})

for i, sample in enumerate(test_set):
    # Samples that were filtered out before inference have no entry at all.
    if i not in results_by_idx:
        continue

    alignment_list_or_err = results_by_idx[i]
    # Samples whose inference raised are recorded as the exception itself; skip them.
    if isinstance(alignment_list_or_err, Exception):
        continue

    # Each successful entry is expected to hold exactly two alignments: the
    # structural alignment followed by the pLDDT pseudo-alignment.
    us_alignment, plddt_alignment = alignment_list_or_err

    for field in EXTRA_FIELDS:
        data[field].append(sample[field].iloc[0])

    data["US-Align"].append(us_alignment.final_score)
    data["pLDDT_of_variant_residue"].append(plddt_alignment.score2)
    data["avg_pLDDT"].append(plddt_alignment.final_score)

# Persist the table; the run mode is embedded in the file name.
pd.DataFrame(data).to_csv(f"thermomutdb_prediction_results_{SUFFIX}.csv", index=False)