In [1]:
import proteinmpnn.run
import proteinmpnn.utils.misc

from proteinmpnn.run import load_protein_mpnn_model, set_seed, nll_score
from proteinmpnn.data import BackboneSample, untokenise_sequence
from proteinmpnn.utils.misc import find_files

import numpy as np
import torch

# Device string handed to the model loader and to every model-input builder below.
DEVICE: str = "cpu"

# Manual mode (more control)

In [2]:
# Load the ProteinMPNN model, requesting the "ca" model type, onto DEVICE.
model = load_protein_mpnn_model(
    device=DEVICE,
    model_type="ca",
)

2024-04-29 15:58:24,116 INFO:
	Number of edges: 48 [in load_protein_mpnn_model at run.py:57]
2024-04-29 15:58:24,117 INFO:
	Training noise level: 0.2A [in load_protein_mpnn_model at run.py:59]
2024-04-29 15:58:24,117 INFO:
	Training noise level: 0.2A [in load_protein_mpnn_model at run.py:59]


In [3]:
# Dedicated, seeded generator for the toy coordinates below, so that the random
# backbones (and everything derived from them) are identical on every
# Restart & Run All. It does not touch NumPy's global random state.
toy_rng = np.random.default_rng(0)

# ... add backbones from PDB files
pdb_files = find_files("./data", ".pdb")  # Or numpy or torch objects
backbones = [BackboneSample.load_any(f, ca_only=True) for f in pdb_files]

# ... add a backbone directly from a numpy array (10 random xyz triplets)
backbones.append(
    BackboneSample(bb_coords=toy_rng.random((10, 3)), ca_only=True)
)

# ... specify a sequence motif to be fixed (via res_mask -- 0's are fixed)
# Here the mask is 0 at positions 0, 4, 5 and 7, which line up with the
# letters M, A, C and G of res_name.
backbones.append(
    BackboneSample(
        bb_coords=toy_rng.random((10, 3)),
        ca_only=True,
        res_name="MXXXACXGXX",
        res_mask=np.array([0, 1, 1, 1, 0, 0, 1, 0, 1, 1]),
    )
)

backbones

2024-04-29 15:58:24,192 INFO:
	Found 2 files with extension .pdb in ./data. [in find_files at misc.py:43]


[BackboneSample(n_atoms=106, n_residues=106, ca_only=True),
 BackboneSample(n_atoms=68, n_residues=68, ca_only=True),
 BackboneSample(n_atoms=10, n_residues=10, ca_only=True),
 BackboneSample(n_atoms=10, n_residues=10, ca_only=True)]

In [6]:
with torch.inference_mode():
    # Seed once so that sampling and the scoring pass below are repeatable.
    set_seed(39)

    # Step 1: draw one sequence sample per backbone.
    samples = [
        model.sample(
            randn=torch.randn(1, backbone.n_residues),
            **backbone.to_protein_mpnn_input("sampling", device=DEVICE)
        )
        for backbone in backbones
    ]

    # Step 2: score every sampled sequence under the model, re-using the
    # decoding order that was used to sample it.
    # `scores` collects the per-sample NLL scores (the original cell created
    # this list but never filled it).
    scores = []
    for sample, backbone in zip(samples, backbones):
        inpt = backbone.to_protein_mpnn_input("scoring", device=DEVICE)
        inpt["decoding_order"] = sample["decoding_order"]
        inpt["S"] = sample["S"]
        log_probs = model(
            randn=torch.randn(1, backbone.n_residues),
            use_input_decoding_order=True,
            **inpt
        )
        # Results are attached to the sample dict itself.
        sample["nll_score"] = nll_score(sample["S"], log_probs, mask=inpt["mask"])
        # prob = exp(-nll_score)
        sample["prob"] = torch.exp(-sample["nll_score"])
        scores.append(sample["nll_score"])


In [7]:
# Sampled sequence and its probability for the first backbone (loaded from a PDB file).
# No res_name / res_mask was passed when this backbone was built, so there is
# no user-specified fixed motif to look for here.
print(untokenise_sequence(samples[0]["S"]))
print(samples[0]["prob"])

TLTLKQTIANQYIKAFERQRSDQCKKCVHPLTIWTVQGWERKREEMVQFVEDMMAKGISWEFQAYERIGVIYDYDAKRQADGVVSFDLYKIEVIEDVIPIIYGNHK
tensor([0.0949])


In [8]:

# Sampled sequence and its probability for the second-to-last backbone: the random
# 10-residue backbone that was built WITHOUT res_name / res_mask, i.e. no fixed motif.
print(untokenise_sequence(samples[-2]["S"]))
print(samples[-2]["prob"])

SSPWRKKQSS
tensor([0.0677])


In [9]:
# Last backbone: the one built with res_name="MXXXACXGXX" and a res_mask that is 0
# at positions 0, 4, 5 and 7 -- NOTE the fixed motif: M, A, C and G should
# reappear at exactly those positions in the printed sequence.
motif_sample = samples[-1]
print(untokenise_sequence(motif_sample["S"]))
print(motif_sample["prob"])
motif_sample

MQSSACAGYG
tensor([0.0589])


{'S': tensor([[10, 13, 15, 15,  0,  1,  0,  5, 19,  5]]),
 'probs': tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0853, 0.0120, 0.0729, 0.0656, 0.0258, 0.0992, 0.0292, 0.0227,
           0.0495, 0.0447, 0.0139, 0.0509, 0.0720, 0.0418, 0.0461, 0.1256,
           0.0719, 0.0351, 0.0120, 0.0238, 0.0000],
          [0.1103, 0.0138, 0.0790, 0.0694, 0.0352, 0.0955, 0.0332, 0.0242,
           0.0458, 0.0536, 0.0212, 0.0503, 0.0472, 0.0475, 0.0474, 0.0885,
           0.0527, 0.0356, 0.0167, 0.0329, 0.0000],
          [0.0868, 0.0227, 0.0710, 0.0560, 0.0419, 0.0959, 0.0282, 0.0391,
           0.0396, 0.0710, 0.0227, 0.0460, 0.0622, 0.0374, 0.0341, 0.0865,
           0.0605, 0.0487, 0.0171, 0.0327, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

# One-function call

In [None]:
# TODO: wrap the manual sampling and scoring steps above into a single function.