In [None]:
import numpy as np
import pandas as pd

from data import CASPTestSet
from interfaces import ProteinPredictionTask, ProteinPredictionReturnType
from models import ESM3Model
from utils import ProteinComparator, ProteinComparatorMethod, ProteinAlignment

In [None]:
# Instantiate components
# - `model` is called in the inference loop below to run
#   ProteinPredictionTask.STRUCTURE_PREDICTION on each sample's sequence.
# - `comparator` supplies compute_score_and_alignment() and visualize_alignment().
# - `test_set` is consumed via len() and enumerate(); each sample is indexed by
#   column name (e.g. "real_fasta", "real_pdb") followed by `.iloc[0]`.
model = ESM3Model()
comparator = ProteinComparator()
test_set = CASPTestSet()

In [None]:
# Per-sample outcomes, one entry per test-set index:
#   (idx, [comparator alignments..., pLDDT pseudo-alignment])  on success
#   (idx, exception)                                           on failure
test_set_results = []

print(f"Running through {len(test_set)} samples...")

# Run through all samples in the test set
for idx, sample in enumerate(test_set):
    try:
        # NOTE(review): the last 4 characters of the FASTA entry are dropped here;
        # presumably a fixed suffix that is not part of the sequence -- confirm
        # against what CASPTestSet stores in "real_fasta".
        original_protein_fasta = sample["real_fasta"].iloc[0][:-4]
        real_protein_pdb = sample["real_pdb"].iloc[0]

        # Run inference on the original protein FASTA
        resultant_protein = model(
            ProteinPredictionTask.STRUCTURE_PREDICTION,
            protein=original_protein_fasta,
            generation_config_kwargs={ "num_steps": 1, "temperature": 0.0 },
            return_type=ProteinPredictionReturnType.DEFAULT,
        )
        resultant_pdb = resultant_protein.to_pdb_string()

        # Compute the optimal alignment and USalign/TMalign scores
        resulting_alignment = comparator.compute_score_and_alignment(
            resultant_pdb,
            real_protein_pdb,
        )

        # Print the scores reported by the comparator
        for alignment_score in resulting_alignment:
            print(f"{alignment_score.method.value}: {alignment_score.final_score}")

        # Extend the comparator output with a pseudo-alignment that carries the
        # model's own pLDDT values (score1) and their average (final_score).
        sample_alignments = resulting_alignment + [
            ProteinAlignment(
                method=ProteinComparatorMethod.LDDT,
                pdb1=resultant_pdb,
                pdb2=None,
                superimposed_pdb=None,
                score1=resultant_protein.plddt,
                score2=None,
                final_score=np.average(resultant_protein.plddt),
                auxiliary=None,
            )
        ]
    except Exception as e:
        # Skip the sample due to the exception
        print(f"Skipping sample at index {idx} due to `{e}`.")
        test_set_results.append((idx, e))
    else:
        # Record the result exactly once, before any optional rendering.
        test_set_results.append((idx, sample_alignments))

        # Visualization is best-effort: a rendering failure must not add a
        # second (idx, error) entry for a sample whose scores are already stored.
        try:
            comparator.visualize_alignment(resulting_alignment[0])
        except Exception as e:
            print(f"Could not visualize sample at index {idx} due to `{e}`.")

In [None]:
# Dump the raw per-sample outcomes. Each entry is a tuple of
# (sample index, list of alignments) or (sample index, caught exception).
print(test_set_results)

In [None]:
# Sample metadata columns copied verbatim from the test set into the CSV.
EXTRA_FIELDS = ["subset", "target_id", "pdb_id"]

# Column name -> list of values; one row per successfully scored sample.
data = { field_name: [] for field_name in EXTRA_FIELDS }
data.update({
    "US-Align": [],
    "avg_pLDDT": [],
})

# Index the outcomes by sample position once instead of rescanning the whole
# result list for every sample. `setdefault` keeps the FIRST entry recorded for
# an index, which matches a first-match lookup if an index appears twice.
results_by_idx = {}
for result_idx, alignment_list_or_err in test_set_results:
    results_by_idx.setdefault(result_idx, alignment_list_or_err)

for i, sample in enumerate(test_set):
    if i not in results_by_idx:
        # No outcome recorded for this sample.
        continue
    alignment_list_or_err = results_by_idx[i]
    if isinstance(alignment_list_or_err, Exception):
        # Failed samples are left out of the summary table.
        continue
    # Expects exactly two entries: the comparator alignment followed by the
    # pLDDT pseudo-alignment; any other length raises ValueError here.
    us_alignment, plddt_alignment = alignment_list_or_err
    for field in EXTRA_FIELDS:
        data[field].append(sample[field].iloc[0])
    data["US-Align"].append(us_alignment.final_score)
    data["avg_pLDDT"].append(plddt_alignment.final_score)

pd.DataFrame(data).to_csv("casp_structure_prediction_results.csv", index=False)