In [None]:
# Install py3Dmol, which this notebook imports to render the predicted structures.
%pip install py3dmol

In [None]:
from abodybuilder3.utils import string_to_input, output_to_pdb, add_atom37_to_output
from abodybuilder3.lightning_module import LitABB3
from abodybuilder3.language.model import ProtT5
import py3Dmol
import torch

# ABodyBuilder3 Example

We demonstrate the model on structure 6yio_H0-L0. Its heavy and light chain sequences are given below.

In [None]:
# Heavy and light chain sequences of the example structure 6yio_H0-L0; both are
# handed to string_to_input in the inference cells further down.
heavy = (
    "QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSLAISWVRQAPGQGLEWMGGIIPIFGTANYAQKFQGRVTITADESTSTAYMELSSLRSEDTAVYYCARGGSVSGTLVDFDIWGQGTMVTVSS"
)
light = (
    "DIQMTQSPSTLSASVGDRVTITCRASQSISSWLAWYQQKPGKAPKLLIYKASSLESGVPSRFSGSGSGTEFTLTISSLQPDDFATYYCQQYNIYPITFGGGTKVEIK"
)

In [None]:
# Restore the trained Lightning module from its checkpoint.
module = LitABB3.load_from_checkpoint("../output/plddt-loss/best_second_stage.ckpt")
# The notebook only runs inference, so put the module (and every submodule) into
# evaluation mode; this turns off train-only behaviour such as dropout, if present.
module.eval()
# The inference cell calls the wrapped network directly rather than the Lightning module.
model = module.model

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model input from the two chain sequences.
ab_input = string_to_input(heavy=heavy, light=light)
# Batched copy on the target device: every entry gains a leading batch dimension
# except "single" and "pair", which are moved to the device unchanged.
ab_input_batch = {
    key: (value.unsqueeze(0).to(device) if key not in ["single", "pair"] else value.to(device))
    for key, value in ab_input.items()
}

model.to(device)

# Inference only: disable autograd so no computation graph is built or retained.
with torch.no_grad():
    output = model(ab_input_batch, ab_input_batch["aatype"])
    output = add_atom37_to_output(output, ab_input["aatype"].to(device))
# Serialise the prediction to a PDB-format string for the viewer cell below.
pdb_string = output_to_pdb(output, ab_input)

In [None]:
# All viewer calls below address the same grid position.
grid_cell = (0, 0)
# Colour the cartoon by the "b" property on a "roygb" gradient between 50 and 100.
# NOTE(review): presumably output_to_pdb stores a per-residue confidence in that
# column — confirm.
cartoon_colours = {"prop": "b", "gradient": "roygb", "min": 50, "max": 100}

view = py3Dmol.view()
view.addModelsAsFrames(pdb_string, viewer=grid_cell)
# {"model": -1} restricts the style to the most recently added model.
view.setStyle(
    {"model": -1},
    {"cartoon": {"colorscheme": cartoon_colours}},
    viewer=grid_cell,
)
view.zoomTo(viewer=grid_cell)
view.render()

# ABodyBuilder3-LM Example

The ProtT5 language model is large and may not fit into memory, so we provide the option of
using a pre-computed embedding instead (this requires running `bash download.sh` first).

In [None]:
# Restore the language-model variant of the trained Lightning module from its checkpoint.
module = LitABB3.load_from_checkpoint("../output/language-loss/best_second_stage.ckpt")
# The notebook only runs inference, so put the module (and every submodule) into
# evaluation mode; this turns off train-only behaviour such as dropout, if present.
module.eval()
# The inference cell calls the wrapped network directly rather than the Lightning module.
model = module.model

In [None]:
# True: read the pre-computed embedding from disk.
# False: instantiate ProtT5 and compute the embedding in this session.
use_precomputed = True

if use_precomputed:
    # NOTE(security): torch.load unpickles the file, so only use files from a trusted source.
    # map_location="cpu" lets the file load on machines that lack the device it was saved from.
    precomputed = torch.load(
        "../data/structures/structures_plm/6yio_H0-L0.pt", map_location="cpu"
    )
    embedding = precomputed["plm_embedding"]
else:
    plm = ProtT5()
    # NOTE(review): get_embeddings receives single-element lists and its result is used
    # as-is; if it returns one embedding per heavy/light pair, the first element would
    # have to be selected before `.shape` and `.unsqueeze` work — confirm.
    embedding = plm.get_embeddings([heavy], [light])

print(f"{embedding.shape=}")

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU. The model and
# every input are moved to this device so they cannot end up on different devices.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ab_input = string_to_input(heavy=heavy, light=light)
ab_input["single"] = embedding.unsqueeze(0)  # use plm for residue feature
# Batched copy on the target device: every entry gains a leading batch dimension
# except "single" and "pair", which are moved to the device unchanged.
ab_input_batch = {
    key: (value.unsqueeze(0).to(device) if key not in ["single", "pair"] else value.to(device))
    for key, value in ab_input.items()
}

model.to(device)

# Inference only: disable autograd so no computation graph is built or retained.
with torch.no_grad():
    output = model(ab_input_batch, ab_input_batch["aatype"])
    output = add_atom37_to_output(output, ab_input["aatype"].to(device))
# Serialise the prediction to a PDB-format string for the viewer cell below.
pdb_string = output_to_pdb(output, ab_input)

In [None]:
# All viewer calls below address the same grid position.
grid_cell = (0, 0)
# Colour the cartoon by the "b" property on a "roygb" gradient between 50 and 100.
# NOTE(review): presumably output_to_pdb stores a per-residue confidence in that
# column — confirm.
cartoon_colours = {"prop": "b", "gradient": "roygb", "min": 50, "max": 100}

view = py3Dmol.view()
view.addModelsAsFrames(pdb_string, viewer=grid_cell)
# {"model": -1} restricts the style to the most recently added model.
view.setStyle(
    {"model": -1},
    {"cartoon": {"colorscheme": cartoon_colours}},
    viewer=grid_cell,
)
view.zoomTo(viewer=grid_cell)
view.render()