In [None]:
import os
import sys

import numpy as np
import torch

sys.path.append("../")

from score_ligandmpnn import LigandMPNNBatch, score_complex

# Run on the GPU when torch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Deserialize the checkpoint, remapping its tensors onto the chosen device.
checkpoint = torch.load(
    "../model_params/ligandmpnn_v_32_020_25.pt",
    map_location=device,
)

# Build the model with hyper-parameters that presumably mirror the checkpoint
# file name (32 neighbours, 25 context atoms) -- TODO confirm against training.
ligand_mpnn = LigandMPNNBatch(
    device=device,
    model_type="ligand_mpnn",
    k_neighbors=32,
    atom_context_num=25,
    ligand_mpnn_use_side_chain_context=True,
)

# Restore the trained weights stored under the "model_state_dict" key.
ligand_mpnn.load_state_dict(checkpoint["model_state_dict"])
ligand_mpnn.to(device)

# Kept as the final expression so the notebook displays its return value.
ligand_mpnn.eval()

In [None]:
# Score a set of candidate sequences against one structure with the model
# loaded in the previous cell (`ligand_mpnn` must already exist in the kernel).

# Structure file to score. The commented-out path is an alternative input kept
# for quick switching.
# pdbfile = "../results/sample/backbones/NbALFA_ALFAtag_AF3_0.pdb"
pdbfile = "../pdbs/NbALFA_ALFAtag_AF3.pdb"
# NOTE(review): `chains_to_design` is currently unused -- the matching keyword
# argument in the score_complex call below is commented out.
chains_to_design = "A"
# Space-separated tokens of the form A<number>; presumably chain ID followed by
# residue number -- verify against score_complex.
redesigned_residues = "A1 A3 A4 A5 A7 A8 A9 A13 A14 A15 A19 A20 A21 A23 A24 A25 A26 A27 A39 A41 A44 A45 A46 A48 A50 A52 A53 A67 A68 A69 A72 A73 A74 A75 A76 A77 A78 A79 A80 A81 A82 A83 A84 A85 A86 A88 A89 A91 A92 A93 A95 A97 A99 A100 A102 A114 A116 A118 A119 A120 A121 A123 A124"
# Candidate sequences to score. Each entry has two segments separated by ":"
# (presumably one per chain -- TODO confirm); the segment after the colon is
# identical in all four entries.
target_seqs_list = [
    "SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS:PSRLEEELRRRLTEP",
    "GGTVVLTESGGGTVAPGGSATLTATASGVTISALNAMAWGWYRQRPGERPVAVAAVSERGNAMYREDVRGRWTVTADRANKTVSLEMRDLQPEDTATYYPHVLEDRVDSFHDYWGAGVPLTVVP:PSRLEEELRRRLTEP",
    "GQVQLQQSAELARPGASVKMSCKASGYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKYPYYGSHWYFDVWGAGTTVTVS:PSRLEEELRRRLTEP",
    "PGLRAEDTAVYYCAKYPYELARPGYTFTSQAPGKGLGSHWYFDVWWYFDLYQMNSLRATIRDNSKNTWVSEVWGAGTASKMSCKASGGSVKMEDTAVYYCAKYPYYGSHGAGTDNSKNAVVTVS:PSRLEEELRRRLTEP",
]

# score_complex returns three values, unpacked here as entropy / loss /
# perplexity. Their exact shapes and semantics are defined in
# score_ligandmpnn, not in this notebook -- check there before relying on them.
entropy, loss, perplexity = score_complex(
    ligand_mpnn,
    pdbfile,
    # Chain-level selection is disabled; only the explicit residue list is passed.
    # chains_to_design=chains_to_design,
    redesigned_residues=redesigned_residues,
    seqs_list=target_seqs_list,
    use_side_chain_context=True,
)
# Final expression of the cell: displayed as the cell output.
entropy.shape, loss, perplexity

In [None]:
def extract_from_score(output_path):
    """Compute per-position scores from a saved scoring output file.

    The file is read with ``torch.load`` and must contain a mapping with:

    * ``"logits"`` -- array/tensor indexed as ``[:, :, :20]``; the first axis
      is averaged out after the softmax (presumably repeated scoring passes
      -- TODO confirm against the script that writes the file).
    * ``"native_sequence"`` -- integer indices of shape ``(L,)`` selecting one
      of the 20 retained channels per position.

    Args:
        output_path: Path to the saved ``.pt`` file.

    Returns:
        A tuple ``(entropy, loss, perplexity)``:

        * ``entropy`` -- ``(L, 20)`` tensor holding the negative log of the
          softmax probabilities averaged over the first axis.
        * ``loss`` -- ``(L,)`` tensor with the ``entropy`` value at each
          position's native index. The shape stays ``(L,)`` even when
          ``L == 1``.
        * ``perplexity`` -- Python float, ``exp(mean(loss))``.
    """
    # NOTE(security): torch.load unpickles the file; only open trusted outputs.
    with open(output_path, "rb") as f:
        output = torch.load(f)

    # as_tensor accepts NumPy arrays and tensors alike and avoids the
    # copy-construct warning torch.tensor() emits for tensor inputs.
    logits = torch.as_tensor(output["logits"][:, :, :20])
    entropy = -logits.softmax(dim=-1).mean(dim=0).log()  # (L, 20)
    target = torch.as_tensor(output["native_sequence"], dtype=torch.long)  # (L,)

    # Squeeze only the gathered axis: a bare squeeze() would also drop the
    # length axis when L == 1 and return a 0-d tensor.
    loss = torch.gather(entropy, 1, target.unsqueeze(1)).squeeze(1)  # (L,)
    perplexity = torch.exp(loss.mean()).item()  # scalar

    return entropy, loss, perplexity


def extract_from_sample(output_path):
    """Compute per-sample scores from a saved sampling output file.

    The file is read with ``torch.load`` and must contain ``"log_probs"``
    (``(B, L, 20)``) and ``"generated_sequences"`` (``(B, L)`` integer
    indices into the last axis of ``log_probs``).

    Returns:
        ``(entropy, loss, perplexity)`` where ``entropy`` is the negated
        ``log_probs`` ``(B, L, 20)``, ``loss`` is the ``entropy`` value at each
        generated index ``(B, L)``, and ``perplexity`` is ``exp`` of the
        per-sample mean loss ``(B,)``.
    """
    with open(output_path, "rb") as handle:
        payload = torch.load(handle)

    sampled = payload["generated_sequences"]  # (B, L)
    neg_log_probs = -payload["log_probs"]  # (B, L, 20)

    # Select, at every position, the value belonging to the sampled index.
    picked = neg_log_probs.gather(2, sampled[:, :, None])  # (B, L, 1)
    per_position = picked.squeeze(2)  # (B, L)
    per_sample = per_position.mean(dim=-1).exp()  # (B,)

    # Optional confidence restricted to redesigned positions (left disabled):
    # redesigned = payload["chain_mask"] == 1
    # confidence = torch.exp(-per_position[:, redesigned].mean(dim=-1))

    return neg_log_probs, per_position, per_sample

In [None]:
# !sh "./score_complex.sh" "../pdbs/NbALFA_ALFAtag_AF3.pdb" "../results_/score/"
# extract_from_score("../results_/score/NbALFA_ALFAtag_AF3.pt")

In [None]:
# !sh "./score_wt.sh" "../pdbs/NbALFA_ALFAtag_AF3.pdb" "../results/score"
# extract_from_score("../results/score/NbALFA_ALFAtag_AF3.pt")

In [None]:
# !sh "./sample_complex.sh" "../pdbs/NbALFA_AF3.pdb" "../results/sample"
# extract_from_sample("../results/sample/stats/NbALFA_AF3.pt")