<a href="https://colab.research.google.com/github/PKGuo/CS433-ML_course-project-2/blob/master/mpnn/examples/proteinmpnn_in_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#ProteinMPNN in Jax!

---

fixbb monomer design:
 - `pdb="6MRR" chains="A"`

fixbb homooligomer design:
 - `pdb="5XZK" chains="A,B,C" homooligomer=True`

binder design:
 - `pdb="1SSC" chains="A,B" fix_pos="A"`

---


In [2]:
#@title Install colabdesign
import os
try:
  import colabdesign
except:
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  os.system("ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign")

from colabdesign.mpnn import mk_mpnn_model, clear_mem
from colabdesign.shared.protein import pdb_to_string

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
import pandas as pd
import tqdm.notebook
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

from google.colab import files
from google.colab import data_table
data_table.enable_dataframe_formatter()

def get_pdb(pdb_code=""):
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"
  elif os.path.isfile(pdb_code):
    return pdb_code
  elif len(pdb_code) == 4:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"
  else:
    os.system(f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb")
    return f"AF-{pdb_code}-F1-model_v3.pdb"

In [2]:
from platform import python_version

print(python_version())

3.10.12


In [3]:
%%time
#@title Run ProteinMPNN to design new sequences for given backbone

import warnings, os, re
warnings.simplefilter(action='ignore', category=FutureWarning)

os.system("mkdir -p output")

# USER OPTIONS
#@markdown #### ProteinMPNN options
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
#@markdown #### Input Options
pdb='6MRR' #@param {type:"string"}
#@markdown - leave blank to get an upload prompt
chains = "A" #@param {type:"string"}
homooligomer = False #@param {type:"boolean"}
#@markdown #### Design constraints
fix_pos = "" #@param {type:"string"}
#@markdown - specify which positions to keep fixed in the sequence (example: `1,2-10`)
#@markdown - you can also specify chain specific constraints (example: `A1-10,B1-20`)
#@markdown - you can also specify to fix entire chain(s) (example: `A`)
inverse = False #@param {type:"boolean"}
#@markdown - inverse the `fix_pos` selection (define position to "free" [or design] instead of "fix")
rm_aa = "" #@param {type:"string"}
#@markdown - specify amino acid(s) to exclude (example: `C,A,T`)

#@markdown #### Design Options
num_seqs = 32 #@param ["32", "64", "128", "256", "512", "1024"] {type:"raw"}
sampling_temp = 0.1 #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5", "1.0"] {type:"raw"}
#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.

#@markdown Note: designed sequences are saved to `design.fasta`

# cleaning user options
chains = re.sub("[^A-Za-z]+",",", chains)
if fix_pos == "": fix_pos = None
rm_aa = ",".join(list(re.sub("[^A-Z]+","",rm_aa.upper())))
if rm_aa == "": rm_aa = None

pdb_path = get_pdb(pdb)
if "mpnn_model" not in dir() or model_name_ != model_name:
  mpnn_model = mk_mpnn_model(model_name)
  model_name_ = model_name

mpnn_model.prep_inputs(pdb_filename=pdb_path,
                       chain=chains, homooligomer=homooligomer,
                       fix_pos=fix_pos, inverse=inverse,
                       rm_aa=rm_aa, verbose=True)
out = mpnn_model.sample(num=num_seqs//32, batch=32,
                        temperature=sampling_temp,
                        rescore=homooligomer)

with open("design.fasta","w") as fasta:
  for n in range(num_seqs):
    line = f'>score:{out["score"][n]:.3f}_seqid:{out["seqid"][n]:.3f}\n{out["seq"][n]}'
    fasta.write(line+"\n")

labels = ["score","seqid","seq"]
data = [[out[k][n] for k in labels] for n in range(num_seqs)]

df = pd.DataFrame(data, columns=labels)
df.to_csv('output/mpnn_results.csv')
data_table.DataTable(df.round(3))

lengths [68]
CPU times: user 9.77 s, sys: 812 ms, total: 10.6 s
Wall time: 15.3 s


Unnamed: 0,score,seqid,seq
0,0.909,0.515,GIDEELEKVVKELKKFLKEKGINNVKIEIKDGVLKIKMKGASEEVK...
1,0.935,0.5,GIDEELEKYVKELKKFLKEKGINNVEIEIKDGELKIKMKGADKETK...
2,0.961,0.485,GSDPELEAVVKELEAFLKEQGITNVEIKVEDGTLTITTNGASEALK...
3,0.962,0.515,GKDEELEKVVKKINEFLKKKGINNVKIKVENGTLTIEVKGASEELK...
4,0.958,0.544,GMDEELEKYVKKLKEFLKKEGVNNVEIKIENGVLTIKMKGASKKVK...
5,0.948,0.471,GMDEELEKYVKKLKEFFKKKGINNIKIEIKDGELTIEMKGASKETI...
6,0.994,0.574,HMDEELEKYVEELKAFLEEKGIDNVEIKIEDGTLTITVKGASEELK...
7,0.904,0.559,GMDPELEKYVEELKAFLKEKGVTNVEIKIENGELTIKMKGASEEVK...
8,0.948,0.544,GKDPELEKYVKKLKEFLKKKGITNVKIEIKDGVLTITTNGASEELK...
9,0.951,0.559,GMDPELEKYVKELKKFLKEKGINNVKIKVENGKLTIETKGASEELK...


In [4]:
#@title ### Get amino acid probabilties from ProteinMPNN (optional)
mode = "unconditional" #@param ["unconditional", "conditional", "conditional_fix_pos"]
#@markdown - `unconditional` - P(sequence | structure)
#@markdown - `conditional` - P(sequence | structure, sequence)
#@markdown - `conditional_fix_pos` - P(sequence[not_fixed] | structure, sequence[fix_pos])
show = "all"
import plotly.express as px
from scipy.special import softmax
from colabdesign.mpnn.model import residue_constants
L = sum(mpnn_model._lengths)
fix_pos = mpnn_model._inputs.get("fix_pos",[])
free_pos = np.delete(np.arange(L),fix_pos)

if mode == "conditional":
  ar_mask = 1-np.eye(L)
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
  pdb_labels = None
elif mode == "conditional_fix_pos":
  assert "fix_pos" in mpnn_model._inputs, "no positions fixed"
  ar_mask = 1-np.eye(L)
  p = np.delete(np.arange(L),mpnn_model._inputs["fix_pos"])
  ar_mask[free_pos[:,None],free_pos[None,:]] = 0
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
  logits = logits[free_pos]
  pdb_labels = np.array([f"{i}_{c}" for c,i in zip(mpnn_model.pdb["idx"]["chain"], mpnn_model.pdb["idx"]["residue"])])
  pdb_labels = pdb_labels[free_pos]
else:
  ar_mask = np.zeros((L,L))
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
  pdb_labels = None

pssm = softmax(logits,-1)
np.savetxt("output/pssm.txt",pssm)

fig = px.imshow(np.array(pssm).T,
               labels=dict(x="positions", y="amino acids", color="probability"),
               y=residue_constants.restypes + ["X"],
               x=pdb_labels,
               zmin=0,
               zmax=1,
               template="simple_white",
              )
fig.update_xaxes(side="top")
fig.show()

In [5]:
# 生成固定的自回归顺序
L = mpnn_model._inputs["X"].shape[0]
decoding_order = np.arange(L)
np.random.seed(42)  # 固定随机种子
np.random.shuffle(decoding_order)
mpnn_model._inputs["decoding_order"] = decoding_order
mpnn_model._inputs["decoding_order"]

array([46, 16,  4,  9, 28, 41, 58,  5, 61, 12, 25, 65, 47,  0, 54, 55, 49,
        7, 42, 31, 36, 19, 45, 33, 48, 30, 13, 63, 40,  3, 17, 34,  8, 44,
        6, 56, 66, 15, 27, 26, 24, 67, 11, 32, 64, 50, 37, 29, 43, 53,  1,
       21,  2, 62, 39, 35, 52, 23, 59, 10, 22, 18, 57, 38, 20, 60, 14, 51])

In [6]:
mpnn_model._inputs

{'X': array([[[-15.11299992,   4.64099979,  12.53299999],
         [-15.45499992,   3.35299993,  11.85400009],
         [-14.52499962,   3.06800008,  10.6960001 ],
         [-14.89700031,   2.51900005,   9.6619997 ]],
 
        [[-13.27499962,   3.42000008,  10.93000031],
         [-12.23900032,   3.52200007,   9.92399979],
         [-11.12800026,   2.58100009,  10.33699989],
         [-10.68000031,   2.63400006,  11.48499966]],
 
        [[-10.74100018,   1.67499995,   9.44499969],
         [ -9.73499966,   0.662     ,   9.73999977],
         [ -8.42300034,   1.05700004,   9.07400036],
         [ -8.3039999 ,   0.991     ,   7.85500002]],
 
        [[ -7.43200016,   1.44799995,   9.87100029],
         [ -6.12799978,   1.79999995,   9.32199955],
         [ -5.59399986,   0.70499998,   8.40100002],
         [ -5.14300013,   0.977     ,   7.27899981]],
 
        [[ -5.64400005,  -0.54799998,   8.85400009],
         [ -5.07399988,  -1.62399995,   8.0539999 ],
         [ -5.88399982,  -1.8

In [12]:
# def custom_sample(mpnn_model, sampling_method='greedy', top_k=None, top_p=None, temperature=1.0):
#     I = mpnn_model._inputs.copy()
#     key = mpnn_model.key()
#     L = I["X"].shape[0]

#     # 使用固定的自回归顺序
#     if "decoding_order" not in I:
#         I["decoding_order"] = np.arange(L)
#         np.random.shuffle(I["decoding_order"])

#     # 调用自定义的 mpnn_sample 函数
#     O = custom_mpnn_sample(mpnn_model, I, key, sampling_method, top_k, top_p, temperature)

#     # 处理输出
#     O.update(mpnn_model._get_seq(O))
#     O.update(mpnn_model._get_score(I, O))
#     return O

def custom_sample(mpnn_model, sampling_method='greedy', temperature=1.0, top_k=None, top_p=None):
    # 复制模型输入
    I = copy.deepcopy(mpnn_model._inputs)

    # 生成随机数密钥
    key = mpnn_model.key()

    # 如果没有指定解码顺序，则生成固定的解码顺序
    if "decoding_order" not in I:
        L = I["X"].shape[0]
        np.random.seed(42)  # 固定随机种子以确保可重复性
        I["decoding_order"] = np.random.permutation(L)

    # 调用自定义的 mpnn_sample 函数
    output = custom_mpnn_sample(mpnn_model, I, key, sampling_method, temperature, top_k, top_p)

    # 处理输出
    output.update(mpnn_model._get_seq(output))
    output.update(mpnn_model._get_score(I, output))

    return output


import jax
import jax.numpy as jnp
import numpy as np
import copy

import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np

from colabdesign.mpnn.utils import cat_neighbors_nodes, get_ar_mask

def custom_mpnn_sample(mpnn_model, I, key, sampling_method='greedy', temperature=1.0, top_k=None, top_p=None):
    """
    自定义的采样函数，支持不同的采样方法。

    参数：
    - mpnn_model: mk_mpnn_model 的实例
    - I: 输入字典，包括 'X', 'mask', 'residue_idx', 'chain_idx', 'decoding_order' 等
    - key: 随机数生成器的密钥
    - sampling_method: 采样方法，'greedy', 'temperature', 'top_k', 'top_p'
    - temperature: 温度参数，默认为 1.0
    - top_k: 用于 Top-K 采样的方法
    - top_p: 用于 Top-P 采样的方法
    """
    # 从 mpnn_model 中获取必要的参数
    model_params = mpnn_model._model.model.params
    encoder_layers = mpnn_model._model.model.encoder_layers
    decoder_layers = mpnn_model._model.model.decoder_layers
    W_e = mpnn_model._model.model.W_e
    W_s = mpnn_model._model.model.W_s
    W_out = mpnn_model._model.model.W_out
    features = mpnn_model._model.model.features

    # 初始化随机数生成器
    safe_key = mpnn_model._model.features.safe_key

    L = I["X"].shape[0]

    # 准备节点和边的嵌入
    E, E_idx = features(I)
    h_V = jnp.zeros((E.shape[0], E.shape[-1]))
    h_E = W_e(E)

    # 编码器部分
    mask_attend = jnp.take_along_axis(I["mask"][:, None] * I["mask"][None, :], E_idx, 1)
    for layer in encoder_layers:
        h_V, h_E = layer(h_V, h_E, E_idx, I["mask"], mask_attend)

    # 获取自回归掩码
    ar_mask = I.get("ar_mask", get_ar_mask(I["decoding_order"]))

    mask_attend = jnp.take_along_axis(ar_mask, E_idx, 1)
    mask_1D = I["mask"][:, None]
    mask_bw = mask_1D * mask_attend
    mask_fw = mask_1D * (1 - mask_attend)

    h_EX_encoder = cat_neighbors_nodes(jnp.zeros_like(h_V), h_E, E_idx)
    h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
    h_EXV_encoder = mask_fw[..., None] * h_EXV_encoder

    # 定义解码器的前向函数
    def fwd(x, t, key):
        h_EXV_encoder_t = h_EXV_encoder[t]
        E_idx_t = E_idx[t]
        mask_t = I["mask"][t]
        mask_bw_t = mask_bw[t]
        h_ES_t = cat_neighbors_nodes(x["h_S"], h_E[t], E_idx_t)

        # 解码器循环
        for l, layer in enumerate(decoder_layers):
            h_V = x["h_V"][l]
            h_ESV_decoder_t = cat_neighbors_nodes(h_V, h_ES_t, E_idx_t)
            h_ESV_t = mask_bw_t[..., None] * h_ESV_decoder_t + h_EXV_encoder_t
            h_V_t = layer(h_V[t], h_ESV_t, mask_V=mask_t)
            # 更新解码器状态
            x["h_V"] = x["h_V"].at[l + 1, t].set(h_V_t)

        # 计算当前步的 logits
        logits_t = W_out(h_V_t)
        x["logits"] = x["logits"].at[t].set(logits_t)

        # 采样步骤
        # 添加偏置
        if "bias" in I:
            logits_t += I["bias"][t]

        # 应用温度
        logits_t = logits_t / temperature

        # 处理同源寡聚体的情况
        if mpnn_model._tied_lengths:
            logits_t = logits_t.mean(0, keepdims=True)

        # 根据采样方法进行采样
        if sampling_method == 'greedy':
            # 贪心搜索：选择概率最大的氨基酸
            S_t = jax.nn.one_hot(jnp.argmax(logits_t[..., :20], axis=-1), 21)
        elif sampling_method == 'temperature':
            # 温度采样：按照概率分布采样
            probs = jax.nn.softmax(logits_t[..., :20], axis=-1)
            S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t[..., :20]), 21)
        elif sampling_method == 'top_k':
            # Top-K 采样
            logits_t_filtered = apply_top_k(logits_t[..., :20], top_k)
            probs = jax.nn.softmax(logits_t_filtered, axis=-1)
            S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t_filtered), 21)
        elif sampling_method == 'top_p':
            # Top-P 采样
            logits_t_filtered = apply_top_p(key, logits_t[..., :20], top_p)
            probs = jax.nn.softmax(logits_t_filtered, axis=-1)
            S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t_filtered), 21)
        else:
            # 默认采样（Gumbel-Max Trick）
            logits_t += jax.random.gumbel(key, logits_t.shape)
            logits_t = logits_t.mean(0, keepdims=True)
            S_t = jax.nn.one_hot(jnp.argmax(logits_t[..., :20], axis=-1), 21)

        # 更新状态
        x["h_S"] = x["h_S"].at[t].set(W_s(S_t))
        x["S"] = x["S"].at[t].set(S_t)
        return x, None

    # 初始化状态
    X = {
        "h_S": jnp.zeros_like(h_V),
        "h_V": jnp.array([h_V] + [jnp.zeros_like(h_V)] * len(decoder_layers)),
        "S": jnp.zeros((L, 21)),
        "logits": jnp.zeros((L, 21)),
    }

    # 扫描解码顺序，逐步生成序列
    t = I["decoding_order"]
    if t.ndim == 1:
        t = t[:, None]
    XS = {"t": t, "key": jax.random.split(key, t.shape[0])}
    X = hk.scan(lambda x, xs: fwd(x, xs["t"], xs["key"]), X, XS)[0]

    return {"S": X["S"], "logits": X["logits"], "decoding_order": t}


def apply_top_k(logits, k):
    """
    对 logits 应用 Top-K 策略，只保留前 K 大的 logits，其余设为负无穷。
    """
    values, indices = jax.lax.top_k(logits, k)
    min_values = values[..., -1, None]
    logits_filtered = jnp.where(logits >= min_values, logits, -1e10)
    return logits_filtered

def apply_top_p(key, logits, p):
    """
    对 logits 应用 Top-P（核）采样策略。
    """
    probs = jax.nn.softmax(logits, axis=-1)
    sorted_probs = jnp.sort(probs, axis=-1)[..., ::-1]
    cumulative_probs = jnp.cumsum(sorted_probs, axis=-1)
    cutoff = cumulative_probs >= p
    indices_to_remove = jnp.roll(cutoff, shift=1, axis=-1)
    indices_to_remove = indices_to_remove.at[..., 0].set(False)
    sorted_indices = jnp.argsort(-probs, axis=-1)
    logits_filtered = logits.at[sorted_indices].set(jnp.where(indices_to_remove, -1e10, logits))
    return logits_filtered



In [15]:
from colabdesign.mpnn.modules import ProteinMPNN

In [17]:
def custom_sample(mpnn_model, model_instance, model_params, sampling_method='greedy', temperature=1.0, top_k=None, top_p=None):
    # 复制模型输入
    I = copy.deepcopy(mpnn_model._inputs)

    # 生成随机数密钥
    key = mpnn_model.key()

    # 如果没有指定解码顺序，则生成固定的解码顺序
    if "decoding_order" not in I:
        L = I["X"].shape[0]
        np.random.seed(42)  # 固定随机种子以确保可重复性
        I["decoding_order"] = np.random.permutation(L)

    # 调用自定义的 mpnn_sample 函数
    output = custom_mpnn_sample(model_instance, model_params, I, key, sampling_method, temperature, top_k, top_p)

    # 处理输出
    output.update(mpnn_model._get_seq(output))
    output.update(mpnn_model._get_score(I, output))

    return output

def custom_mpnn_sample(model_instance, model_params, I, key, sampling_method='greedy', temperature=1.0, top_k=None, top_p=None):
    """
    自定义的采样函数，支持不同的采样方法。

    参数：
    - model_instance: ProteinMPNN 的实例
    - model_params: 模型参数
    - I: 输入字典，包括 'X', 'mask', 'residue_idx', 'chain_idx', 'decoding_order' 等
    - key: 随机数生成器的密钥
    - sampling_method: 采样方法，'greedy', 'temperature', 'top_k', 'top_p'
    - temperature: 温度参数，默认为 1.0
    - top_k: 用于 Top-K 采样的方法
    - top_p: 用于 Top-P 采样的方法
    """
    # 从 model_instance 中获取必要的属性
    encoder_layers = model_instance.encoder_layers
    decoder_layers = model_instance.decoder_layers
    W_e = model_instance.W_e
    W_s = model_instance.W_s
    W_out = model_instance.W_out
    features = model_instance.features

    # 初始化随机数生成器
    safe_key = features.safe_key

    L = I["X"].shape[0]

    # 准备节点和边的嵌入
    E, E_idx = features(I)
    h_V = jnp.zeros((E.shape[0], E.shape[-1]))
    h_E = W_e(E)

    # 编码器部分
    mask_attend = jnp.take_along_axis(I["mask"][:, None] * I["mask"][None, :], E_idx, 1)
    for layer in encoder_layers:
        h_V, h_E = layer(h_V, h_E, E_idx, I["mask"], mask_attend)

    # 获取自回归掩码
    ar_mask = I.get("ar_mask", get_ar_mask(I["decoding_order"]))

    mask_attend = jnp.take_along_axis(ar_mask, E_idx, 1)
    mask_1D = I["mask"][:, None]
    mask_bw = mask_1D * mask_attend
    mask_fw = mask_1D * (1 - mask_attend)

    h_EX_encoder = cat_neighbors_nodes(jnp.zeros_like(h_V), h_E, E_idx)
    h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
    h_EXV_encoder = mask_fw[..., None] * h_EXV_encoder

    # 定义解码器的前向函数
    def fwd(x, t, key):
        h_EXV_encoder_t = h_EXV_encoder[t]
        E_idx_t = E_idx[t]
        mask_t = I["mask"][t]
        mask_bw_t = mask_bw[t]
        h_ES_t = cat_neighbors_nodes(x["h_S"], h_E[t], E_idx_t)

        # 解码器循环
        for l, layer in enumerate(decoder_layers):
            h_V = x["h_V"][l]
            h_ESV_decoder_t = cat_neighbors_nodes(h_V, h_ES_t, E_idx_t)
            h_ESV_t = mask_bw_t[..., None] * h_ESV_decoder_t + h_EXV_encoder_t
            h_V_t = layer(h_V[t], h_ESV_t, mask_V=mask_t)
            # 更新解码器状态
            x["h_V"] = x["h_V"].at[l + 1, t].set(h_V_t)

        # 计算当前步的 logits
        logits_t = W_out(h_V_t)
        x["logits"] = x["logits"].at[t].set(logits_t)

        # 采样步骤
        # 添加偏置
        if "bias" in I:
            logits_t += I["bias"][t]

        # 应用温度
        logits_t = logits_t / temperature

        # 根据采样方法进行采样
        if sampling_method == 'greedy':
            # 贪心搜索：选择概率最大的氨基酸
            S_t = jax.nn.one_hot(jnp.argmax(logits_t[..., :20], axis=-1), 21)
        elif sampling_method == 'temperature':
            # 温度采样：按照概率分布采样
            probs = jax.nn.softmax(logits_t[..., :20], axis=-1)
            S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t[..., :20]), 21)
        elif sampling_method == 'top_k':
            # Top-K 采样
            logits_t_filtered = apply_top_k(logits_t[..., :20], top_k)
            S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t_filtered), 21)
        elif sampling_method == 'top_p':
            # Top-P 采样
            logits_t_filtered = apply_top_p(logits_t[..., :20], top_p)
            S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t_filtered), 21)
        else:
            # 默认采样（Gumbel-Max Trick）
            logits_t += jax.random.gumbel(key, logits_t.shape)
            S_t = jax.nn.one_hot(jnp.argmax(logits_t[..., :20], axis=-1), 21)

        # 更新状态
        x["h_S"] = x["h_S"].at[t].set(W_s(S_t))
        x["S"] = x["S"].at[t].set(S_t)
        return x, None

    # 初始化状态
    X = {
        "h_S": jnp.zeros_like(h_V),
        "h_V": jnp.array([h_V] + [jnp.zeros_like(h_V)] * len(decoder_layers)),
        "S": jnp.zeros((L, 21)),
        "logits": jnp.zeros((L, 21)),
    }

    # 扫描解码顺序，逐步生成序列
    t = I["decoding_order"]
    if t.ndim == 1:
        t = t[:, None]
    XS = {"t": t, "key": jax.random.split(key, t.shape[0])}

    # 使用 hk.scan 扫描
    def scan_fn(x, xs):
        return fwd(x, xs["t"], xs["key"])[0], None

    X, _ = hk.scan(scan_fn, X, XS)

    return {"S": X["S"], "logits": X["logits"], "decoding_order": t}


In [18]:
# 假设您已经加载了 mpnn_model，并且准备好了 PDB 文件
# 如果还没有加载，请按照以下步骤加载模型和准备输入

# # 加载 PDB 文件（以 Top7 蛋白为例）
# pdb_path = get_pdb('1QYS')  # Top7 的 PDB 编号是 1QYS

# 准备模型输入
mpnn_model.prep_inputs(pdb_filename=pdb_path)

# 获取模型配置和参数
config = mpnn_model._model.config
model_params = mpnn_model._model.params

# 创建 ProteinMPNN 实例
from colabdesign.mpnn.modules import ProteinMPNN
model_instance = ProteinMPNN(**config)

# 测试不同的采样方法
sampling_methods = ['greedy', 'temperature', 'top_k', 'top_p']
temperature = 0.8
top_k = 5
top_p = 0.9

results = []

for method in sampling_methods:
    if method == 'greedy':
        output = custom_sample(mpnn_model, sampling_method='greedy')
    elif method == 'temperature':
        output = custom_sample(mpnn_model, sampling_method='temperature', temperature=temperature)
    elif method == 'top_k':
        output = custom_sample(mpnn_model, sampling_method='top_k', temperature=temperature, top_k=top_k)
    elif method == 'top_p':
        output = custom_sample(mpnn_model, sampling_method='top_p', temperature=temperature, top_p=top_p)

    seq = output['seq']
    score = output['score']
    seqid = output['seqid']

    print(f"Method: {method}")
    print(f"Sequence: {seq}")
    print(f"Score: {score}")
    print(f"SeqID: {seqid}")
    print("-" * 50)

    # 保存结果
    results.append({
        'method': method,
        'seq': seq,
        'score': score,
        'seqid': seqid
    })


ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

In [19]:
from colabdesign.mpnn.modules import ProteinMPNN

class CustomProteinMPNN(ProteinMPNN):
    def custom_sample(self, I, sampling_method='greedy', temperature=1.0, top_k=None, top_p=None):
        # 与之前的 custom_mpnn_sample 函数类似，但需要适应 Haiku 的要求
        # 从 model_instance 中获取必要的属性
        encoder_layers = self.encoder_layers
        decoder_layers = self.decoder_layers
        W_e = self.W_e
        W_s = self.W_s
        W_out = self.W_out
        features = self.features

        L = I["X"].shape[0]

        # 准备节点和边的嵌入
        E, E_idx = features(I)
        h_V = jnp.zeros((E.shape[0], E.shape[-1]))
        h_E = W_e(E)

        # 编码器部分
        mask_attend = jnp.take_along_axis(I["mask"][:, None] * I["mask"][None, :], E_idx, 1)
        for layer in encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, I["mask"], mask_attend)

        # 获取自回归掩码
        ar_mask = I.get("ar_mask", get_ar_mask(I["decoding_order"]))

        mask_attend = jnp.take_along_axis(ar_mask, E_idx, 1)
        mask_1D = I["mask"][:, None]
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1 - mask_attend)

        h_EX_encoder = cat_neighbors_nodes(jnp.zeros_like(h_V), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder = mask_fw[..., None] * h_EXV_encoder

        # 定义解码器的前向函数
        def fwd(x, t):
            key = hk.next_rng_key()
            h_EXV_encoder_t = h_EXV_encoder[t]
            E_idx_t = E_idx[t]
            mask_t = I["mask"][t]
            mask_bw_t = mask_bw[t]
            h_ES_t = cat_neighbors_nodes(x["h_S"], h_E[t], E_idx_t)

            # 解码器循环
            for l, layer in enumerate(decoder_layers):
                h_V = x["h_V"][l]
                h_ESV_decoder_t = cat_neighbors_nodes(h_V, h_ES_t, E_idx_t)
                h_ESV_t = mask_bw_t[..., None] * h_ESV_decoder_t + h_EXV_encoder_t
                h_V_t = layer(h_V[t], h_ESV_t, mask_V=mask_t)
                # 更新解码器状态
                x["h_V"] = x["h_V"].at[l + 1, t].set(h_V_t)

            # 计算当前步的 logits
            logits_t = W_out(h_V_t)
            x["logits"] = x["logits"].at[t].set(logits_t)

            # 采样步骤
            # 添加偏置
            if "bias" in I:
                logits_t += I["bias"][t]

            # 应用温度
            logits_t = logits_t / temperature

            # 根据采样方法进行采样
            if sampling_method == 'greedy':
                # 贪心搜索：选择概率最大的氨基酸
                S_t = jax.nn.one_hot(jnp.argmax(logits_t[..., :20], axis=-1), 21)
            elif sampling_method == 'temperature':
                # 温度采样：按照概率分布采样
                probs = jax.nn.softmax(logits_t[..., :20], axis=-1)
                S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t[..., :20]), 21)
            elif sampling_method == 'top_k':
                # Top-K 采样
                logits_t_filtered = apply_top_k(logits_t[..., :20], top_k)
                S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t_filtered), 21)
            elif sampling_method == 'top_p':
                # Top-P 采样
                logits_t_filtered = apply_top_p(logits_t[..., :20], top_p)
                S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t_filtered), 21)
            else:
                # 默认采样（Gumbel-Max Trick）
                logits_t += jax.random.gumbel(key, logits_t.shape)
                S_t = jax.nn.one_hot(jnp.argmax(logits_t[..., :20], axis=-1), 21)

            # 更新状态
            x["h_S"] = x["h_S"].at[t].set(W_s(S_t))
            x["S"] = x["S"].at[t].set(S_t)
            return x

        # 初始化状态
        X = {
            "h_S": jnp.zeros_like(h_V),
            "h_V": jnp.array([h_V] + [jnp.zeros_like(h_V)] * len(decoder_layers)),
            "S": jnp.zeros((L, 21)),
            "logits": jnp.zeros((L, 21)),
        }

        # 扫描解码顺序，逐步生成序列
        t_list = I["decoding_order"]
        if t_list.ndim == 1:
            t_list = t_list[:, None]
        for t in t_list:
            X = fwd(X, t[0])

        return {"S": X["S"], "logits": X["logits"], "decoding_order": I["decoding_order"]}

def custom_sample_fn(I, sampling_method='greedy', temperature=1.0, top_k=None, top_p=None):
    # 定义模型
    model = CustomProteinMPNN(**config)

    # 执行自定义采样
    return model.custom_sample(I, sampling_method, temperature, top_k, top_p)

# 使用 hk.transform 转换函数
custom_sample_transformed = hk.transform(custom_sample_fn)


In [23]:
I

{'X': array([[[-4.52199984, 18.30599976, 17.4090004 ],
         [-3.06100011, 18.22800064, 17.12199974],
         [-2.66400003, 16.99300003, 16.3239994 ],
         [-3.5150001 , 16.30599976, 15.75399971]],
 
        [[-1.36000001, 16.7140007 , 16.29700089],
         [-0.82300001, 15.56200027, 15.56799984],
         [-0.72100002, 14.30900002, 16.43300056],
         [ 0.091     , 14.22200012, 17.35499954]],
 
        [[-1.56700003, 13.34200001, 16.10499954],
         [-1.63199997, 12.07800007, 16.81399918],
         [-0.90700001, 10.94999981, 16.06599998],
         [-1.30599999, 10.56700039, 14.96500015]],
 
        ...,
 
        [[ 3.36899996, -1.46899998,  6.89699984],
         [ 4.34499979, -2.18099999,  6.10599995],
         [ 4.66499996, -3.56699991,  6.62900019],
         [ 4.43400002, -3.88599992,  7.79899979]],
 
        [[ 5.18400002, -4.40199995,  5.7329998 ],
         [ 5.60900021, -5.75600004,  6.06899977],
         [ 7.0710001 , -5.89799976,  5.65899992],
         [ 7.49800

In [20]:
# 准备输入数据 I 和随机数密钥 key
I = copy.deepcopy(mpnn_model._inputs)
key = mpnn_model.key()

# 获取模型参数
params = mpnn_model._model.params

# 执行采样
output = custom_sample_transformed.apply(params, key, I, sampling_method='greedy')

# 处理输出
output.update(mpnn_model._get_seq(output))
output.update(mpnn_model._get_score(I, output))

# 显示结果
print("Sequence:", output['seq'])
print("Score:", output['score'])


ValueError: Unable to retrieve parameter 'w' for module 'custom_protein_mpnn/~/protein_features/~/positional_encodings/~/embedding_linear' All parameters must be created as part of `init`.

In [41]:
import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np
import copy
import pandas as pd
# from colabdesign.mpnn.utils import cat_neighbors_nodes, get_ar_mask
# 定义辅助函数
def apply_top_k(logits, k):
    values, indices = jax.lax.top_k(logits, k)
    min_values = values[..., -1, None]
    logits_filtered = jnp.where(logits >= min_values, logits, -1e10)
    return logits_filtered

def apply_top_p(logits, p):
    sorted_logits = jnp.sort(logits, axis=-1)[..., ::-1]
    sorted_indices = jnp.argsort(-logits, axis=-1)
    cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits, axis=-1), axis=-1)
    cutoff = cumulative_probs > p
    indices_to_remove = cutoff & ~jnp.pad(cutoff[..., :-1], ((0,0),(1,0)), constant_values=False)
    indices_to_remove = jnp.take_along_axis(sorted_indices, indices_to_remove, axis=-1)
    logits_filtered = logits.at[indices_to_remove].set(-1e10)
    return logits_filtered

# def get_ar_mask(decoding_order):
#     L = len(decoding_order)
#     ar_mask = np.zeros((L, L))
#     for i in range(L):
#         ar_mask[decoding_order[i], decoding_order[:i+1]] = 1
#     return ar_mask

# def cat_neighbors_nodes(h_nodes, h_edges, neighbor_idx):
#     h_neighbors = jnp.take(h_nodes, neighbor_idx, axis=0)
#     h = jnp.concatenate([h_neighbors, h_edges], axis=-1)
#     return h

def gather_nodes(nodes, neighbor_idx):
  # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
  # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
  neighbors_flat = neighbor_idx.reshape([neighbor_idx[None].shape[0], -1])
  neighbors_flat = jnp.tile(jnp.expand_dims(neighbors_flat, -1),[1, 1, nodes[None].shape[2]])
  # Gather and re-pack
  neighbor_features = jnp.take_along_axis(nodes[None], neighbors_flat, 1)
  neighbor_features = neighbor_features.reshape(list(neighbor_idx[None].shape[:3]) + [-1])
  return neighbor_features[0]

def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
  h_nodes = gather_nodes(h_nodes, E_idx)[None]
  h_nn = jnp.concatenate([h_neighbors[None], h_nodes], -1)
  return h_nn[0]

def get_ar_mask(order):
  '''compute autoregressive mask, given order of positions'''
  order = order.flatten()
  L = order.shape[-1]
  tri = jnp.tri(L, k=-1)
  idx = order.argsort()
  ar_mask = tri[idx,:][:,idx]
  return ar_mask

# 导入模型
from colabdesign.mpnn.modules import ProteinMPNN

# 定义自定义的 ProteinMPNN 类
class CustomProteinMPNN(ProteinMPNN):
    def custom_sample(self, I, sampling_method='greedy', temperature=1.0, top_k=None, top_p=None):
        # 与之前的 custom_mpnn_sample 函数类似，但需要适应 Haiku 的要求
        encoder_layers = self.encoder_layers
        decoder_layers = self.decoder_layers
        W_e = self.W_e
        W_s = self.W_s
        W_out = self.W_out
        features = self.features

        L = I["X"].shape[0]

        if "decoding_order" not in I:
            L = I["X"].shape[0]
            I["decoding_order"] = np.random.permutation(L)
        # print(I)
        # 准备节点和边的嵌入
        E, E_idx = features(I)
        h_V = jnp.zeros((E.shape[0], E.shape[-1]))
        h_E = W_e(E)

        # 编码器部分
        mask_attend = jnp.take_along_axis(I["mask"][:, None] * I["mask"][None, :], E_idx, 1)
        for layer in encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, I["mask"], mask_attend)

        # 获取自回归掩码
        ar_mask = I.get("ar_mask", get_ar_mask(I["decoding_order"]))

        mask_attend = jnp.take_along_axis(ar_mask, E_idx, 1)
        mask_1D = I["mask"][:, None]
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1 - mask_attend)

        h_EX_encoder = cat_neighbors_nodes(jnp.zeros_like(h_V), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder = mask_fw[..., None] * h_EXV_encoder

        # print the shapes of X, h_V, h_E,  h_EX_encoder, h_EXV_encoder

        # 定义解码器的前向函数
        # def fwd(x, t):
        #     key = hk.next_rng_key()
        #     h_EXV_encoder_t = h_EXV_encoder[t]
        #     E_idx_t = E_idx[t]
        #     mask_t = I["mask"][t]
        #     mask_bw_t = mask_bw[t]
        #     h_ES_t = cat_neighbors_nodes(x["h_S"], h_E[t], E_idx_t)

        #     # 解码器循环
        #     for l, layer in enumerate(decoder_layers):
        #         h_V = x["h_V"][l]
        #         h_ESV_decoder_t = cat_neighbors_nodes(h_V, h_ES_t, E_idx_t)
        #         h_ESV_t = mask_bw_t[..., None] * h_ESV_decoder_t + h_EXV_encoder_t
        #         h_V_t = layer(h_V[t], h_ESV_t, mask_V=mask_t)
        #         # 更新解码器状态
        #         x["h_V"] = x["h_V"].at[l + 1, t].set(h_V_t)

        #     # 计算当前步的 logits
        #     logits_t = W_out(h_V_t)
        #     x["logits"] = x["logits"].at[t].set(logits_t)
        def fwd(x, t):
            key = hk.next_rng_key()
            h_EXV_encoder_t = h_EXV_encoder[t]
            E_idx_t = E_idx[t]
            mask_t = I["mask"][t:t+1]  # 扩展维度
            mask_bw_t = mask_bw[t:t+1]  # 扩展维度
            h_E_t = h_E[t:t+1]  # 扩展维度
            # print the shape of input of cat_neighbors_nodes
            print("h_S shape:", x["h_S"].shape)
            print("h_E_t shape:", h_E_t.shape)
            print("E_idx_t shape:", E_idx_t.shape)

            h_ES_t = cat_neighbors_nodes(x["h_S"], h_E_t, E_idx_t)

            # 解码器循环
            for l, layer in enumerate(decoder_layers):
                h_V = x["h_V"][l]
                h_V_t = h_V[t:t+1]  # 扩展维度
                h_ESV_decoder_t = cat_neighbors_nodes(h_V, h_ES_t, E_idx_t)
                h_ESV_t = mask_bw_t[..., None] * h_ESV_decoder_t + h_EXV_encoder_t
                h_V_t = layer(h_V_t, h_ESV_t, mask_V=mask_t)
                # 更新解码器状态
                x["h_V"] = x["h_V"].at[l + 1, t].set(h_V_t[0])
            # 计算当前步的 logits
            logits_t = W_out(h_V_t[0])
            x["logits"] = x["logits"].at[t].set(logits_t)
            # 采样步骤
            # 添加偏置
            if "bias" in I:
                logits_t += I["bias"][t]

            # 应用温度
            logits_t = logits_t / temperature

            # 根据采样方法进行采样
            if sampling_method == 'greedy':
                # 贪心搜索：选择概率最大的氨基酸
                S_t = jax.nn.one_hot(jnp.argmax(logits_t[..., :20], axis=-1), 21)
            elif sampling_method == 'temperature':
                # 温度采样：按照概率分布采样
                probs = jax.nn.softmax(logits_t[..., :20], axis=-1)
                S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t[..., :20]), 21)
            elif sampling_method == 'top_k':
                # Top-K 采样
                logits_t_filtered = apply_top_k(logits_t[..., :20], top_k)
                S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t_filtered), 21)
            elif sampling_method == 'top_p':
                # Top-P 采样
                logits_t_filtered = apply_top_p(logits_t[..., :20], top_p)
                S_t = jax.nn.one_hot(jax.random.categorical(key, logits=logits_t_filtered), 21)
            else:
                # 默认采样（Gumbel-Max Trick）
                logits_t += jax.random.gumbel(key, logits_t.shape)
                S_t = jax.nn.one_hot(jnp.argmax(logits_t[..., :20], axis=-1), 21)

            # 更新状态
            x["h_S"] = x["h_S"].at[t].set(W_s(S_t))
            x["S"] = x["S"].at[t].set(S_t)
            return x

        # 初始化状态
        X = {
            "h_S": jnp.zeros_like(h_V),
            "h_V": jnp.array([h_V] + [jnp.zeros_like(h_V)] * len(decoder_layers)),
            "S": jnp.zeros((L, 21)),
            "logits": jnp.zeros((L, 21)),
        }

        # 扫描解码顺序，逐步生成序列
        t_list = I["decoding_order"]
        if t_list.ndim == 1:
            t_list = t_list[:, None]
        for t in t_list:
            X = fwd(X, t[0])

        return {"S": X["S"], "logits": X["logits"], "decoding_order": I["decoding_order"]}

# 定义并转换采样函数
def custom_sample_fn(I, sampling_method='greedy', temperature=1.0, top_k=None, top_p=None):
    # 定义模型
    model = CustomProteinMPNN(**config)
    # 执行自定义采样
    return model.custom_sample(I, sampling_method, temperature, top_k, top_p)

custom_sample_transformed = hk.transform(custom_sample_fn)

# 初始化参数
init_key = jax.random.PRNGKey(42)

# 生成固定的自回归顺序
mpnn_model.prep_inputs(pdb_filename=pdb_path)
L = mpnn_model._inputs["X"].shape[0]
I = copy.deepcopy(mpnn_model._inputs)
params = custom_sample_transformed.init(init_key, I)

# 获取原始模型的参数
original_params = mpnn_model._model.params

# 更新参数
def update_params(target, source):
    for key in target:
        if key in source:
            if isinstance(target[key], dict):
                update_params(target[key], source[key])
            else:
                target[key] = source[key]
        else:
            print(f"Key {key} not found in source parameters.")

update_params(params, original_params)

# 执行采样
apply_key = jax.random.PRNGKey(0)
output = custom_sample_transformed.apply(params, apply_key, I, sampling_method='greedy')

# 处理输出
output.update(mpnn_model._get_seq(output))
output.update(mpnn_model._get_score(I, output))

# 显示结果
print("Sequence:", output['seq'])
print("Score:", output['score'])


h_S shape: (92, 128)
h_E_t shape: (1, 48, 128)
E_idx_t shape: (48,)


TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 1, 48, 128), (1, 48, 128).

In [36]:
I.keys()

dict_keys(['X', 'mask', 'S', 'residue_idx', 'chain_idx', 'lengths', 'bias', 'decoding_order', 'offset'])

In [37]:
I['decoding_order'].shape

(92,)

In [None]:
#@title Run AlphaFold Prediction on ProteinMPNN sequences (optional)
#@markdown ###AlphaFold Options
num_models = 1 #@param ["1","2","3","4","5"] {type:"raw"}
num_recycles = 1 #@param ["0","1","2","3"] {type:"raw"}
use_multimer = False #@param {type:"boolean"}
use_templates = False #@param {type:"boolean"}
rm_template_interchain = False #@param {type:"boolean"}
if not os.path.isdir("params"):
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

# where pdb files will be save:
if not os.path.isdir("output/all_pdb"):
  os.system("mkdir output/all_pdb")
else:
  os.system("rm output/all_pdb/*")

from colabdesign.af import mk_af_model
af_args = [pdb_path, chains, homooligomer,
           use_multimer, use_templates]
if "af_arg_current" not in dir() or af_args != af_arg_current:
  af_model = mk_af_model(use_multimer=use_multimer,
                         use_templates=use_templates,
                         best_metric="dgram_cce")
  af_model.prep_inputs(pdb_path,chains,homooligomer=homooligomer)
  af_arg_current = [x for x in af_args]

af_model.restart()
af_model.set_opt("template", rm_ic=rm_template_interchain)

with tqdm.notebook.tqdm(total=out["S"].shape[0], bar_format=TQDM_BAR_FORMAT) as pbar:
  for n,S in enumerate(out["S"]):
    seq = S[:af_model._len].argmax(-1)
    af_model.predict(seq=seq,
                    num_recycles=num_recycles,
                    num_models=num_models,
                    verbose=False)
    (rmsd, ptm, plddt) = (af_model.aux["log"][k] for k in ["rmsd","ptm","plddt"])
    af_model.aux["log"]["composite"] = ptm * plddt
    af_model._save_results(save_best=True, verbose=False)
    af_model.save_current_pdb(f"output/all_pdb/n{n}.pdb")
    af_model._k += 1
    pbar.update(1)

af_model.save_pdb(f"output/best.pdb")

data = []
labels = ["dgram_cce","plddt","ptm","i_ptm","rmsd","composite","mpnn","seqid","seq"]
for n,af in enumerate(af_model._tmp["log"]):
  data.append([af["dgram_cce"],
               af["plddt"],
               af["ptm"],
               af["i_ptm"],
               af["rmsd"],
               af["composite"],
               out["score"][n],
               out["seqid"][n],
               out["seq"][n]])

df = pd.DataFrame(data, columns=labels)
df.to_csv('output/alphafold_results.csv')
data_table.DataTable(df.sort_values("dgram_cce").round(3))
#@markdown Note: designed pdbs are saved to `output/all_pdb/`

In [None]:
#@title download predictions (optional)
from google.colab import files
os.system(f"zip -r output.zip output/")
files.download(f'output.zip')

In [None]:
#@title display protein (optional) {run: "auto"}
show_best = True #@param {type:"boolean"}
show_idx = 0 #@param {type:"integer"}
#@markdown - Enter index of protein to show, if `show_best` is disabled.
#@markdown - Note: these are NOT sorted and correspond to
#@markdown the index in pandas dataframe above.
color = "pLDDT" #@param ["chain", "pLDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
color_HP = False #@param {type:"boolean"}
animate = True #@param {type:"boolean"}
#@markdown - if `num_models` > 1, will iterate through the models when `animate` is enabled.
if not show_best:
  pdb_str = pdb_to_string(f"output/all_pdb/n{show_idx}.pdb")
else:
  pdb_str = None
af_model.plot_pdb(show_sidechains=show_sidechains,
                  show_mainchains=show_mainchains,
                  color=color, color_HP=color_HP,
                  animate=animate, pdb_str=pdb_str)

In [None]:
#@title animate (optional)
#@markdown Note: animation frames are sorted worst to best design
def sort_traj(self, metric="dgram_cce"):
  if metric in ["plddt","ptm","i_ptm","seqid","composite"]:
    metric_higher_better = True
  else:
    metric_higher_better = False
  num = len(self._tmp["traj"]["seq"])
  log = self._tmp["log"][-num:]
  if metric in log[0]:
    n = np.array([x[metric] for x in log]).argsort()
    if metric_higher_better: n = n[::-1]
    sub_traj = {k:[v[m] for m in n] for k,v in self._tmp["traj"].items()}
    return sub_traj
  else:
    return None

sub_traj= sort_traj(af_model)

color_by = "plddt" #@param ["chain", "plddt", "rainbow"]
dpi = 100 #@param {type:"integer"}
HTML(af_model.animate(traj={k:v[::-1] for k,v in sub_traj.items()}, color_by=color_by, dpi=dpi))
