In [1]:
from data import CASPTestSet
from interfaces import ProteinPredictionTask
from models import ESM3Model
from utils import ProteinComparator


# Instantiate components
model = ESM3Model()
comparator = ProteinComparator()
test_set = CASPTestSet()

# Grab the first data point from the test set
first_sample = test_set[0]

fasta = first_sample["real_fasta"].iloc[0][:-4]
pdb = first_sample["real_pdb"].iloc[0]
print(f"Original protein FASTA: {fasta}")

# Run inference on the original protein FASTA
resultant_pdb = model(
    ProteinPredictionTask.STRUCTURE_PREDICTION,
    protein=fasta,
    generation_config_kwargs={ "num_steps": 1, "temperature": 0.0 }
)

print(f"Prediction PDB: {resultant_pdb}")

# Compute the optimal alignment and USalign/TMalign scores
resulting_alignment = comparator.compute_score_and_alignment(
    resultant_pdb,
    pdb,
)[0]

# Print and visualize the protein alignment and scores
print(f"TM-Score 1: {resulting_alignment.score1}")
print(f"TM-Score 2: {resulting_alignment.score2}")
print(f"Real TM-Score of Prediction Relative to Ground Truth: {resulting_alignment.final_score}")
comparator.visualize_alignment(resulting_alignment)

  from .autonotebook import tqdm as notebook_tqdm
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.
Fetching 22 files: 100%|██████████| 22/22 [00:00<00:00, 65863.45it/s]


Original protein FASTA: GADSIYVREQQIPILIDRIDNVLYEMRIPAQKGDVLNEITIQIGDNVDLSDIQAIRLFYSGVEAPSRKGEHFSPVTYISSHIPGNTRKALESYSVRQDEVTAPLSRTVKLTSKQPMLKGINYFWVSIQMKPETSLLAKVATTMPNAQINNKPINITWKGKVDERHVGIGVRQAGDDGSAAFRIPGLVTTNNGTLLGVYDIRYNSSVDLQEKIDIGVSRSTDKGQTWEPMRVAMTFKQTDGLPHGQNGVGDPSILVDEKTNTIWVVAAWTHGMGNERAWWNSMPGMTPDETAQLMLVKSEDDGKTWSEPINITSQVKDPSWYFLLQGPGRGITMQDGTLVFPIQFIDATRVPNAGIMYSKDRGKTWHLHNLARTNTTEAQVAEVEPGVLMLNMRDNRGGSRAVATTKDLGKTWTEHPSSRSALQESVCMASLIKVNAKDNITGKDLLLFSNPNTTKGRNHITIKASLDGGLTWPTEHQVLLDEAEGWGYSCLSMIDKETVGIFYESSVAHMTFQAVKLQDL


100%|██████████| 1/1 [00:05<00:00,  5.29s/it]


Prediction PDB: ATOM      1  N   GLY A   1     -43.070   8.011   4.910  1.00  0.51           N  
ATOM      2  CA  GLY A   1     -41.793   7.666   5.527  1.00  0.51           C  
ATOM      3  C   GLY A   1     -41.921   6.417   6.392  1.00  0.51           C  
ATOM      4  O   GLY A   1     -41.908   6.504   7.620  1.00  0.51           O  
ATOM      5  N   ALA A   2     -41.942   6.188   6.018  1.00  0.59           N  
ATOM      6  CA  ALA A   2     -41.817   5.030   6.896  1.00  0.59           C  
ATOM      7  C   ALA A   2     -40.861   5.315   8.050  1.00  0.59           C  
ATOM      8  O   ALA A   2     -39.829   5.959   7.860  1.00  0.59           O  
ATOM      9  N   ASP A   3     -41.545   4.856   8.279  1.00  0.71           N  
ATOM     10  CA  ASP A   3     -40.791   4.826   9.528  1.00  0.71           C  
ATOM     11  C   ASP A   3     -39.716   3.745   9.494  1.00  0.71           C  
ATOM     12  O   ASP A   3     -39.681   2.872  10.361  1.00  0.71           O  
ATOM     13 