<a href="https://colab.research.google.com/github/PDNALab/AlphaFolding/blob/main/alphafolding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Iterative structure predictions with AlphaFold2 using only sequence
- running iterative predictions with AlphaFold2 (monomer model 1,2) 
- visualization of structure predictions. 
- For predictions that successfully reach the native state, the structures predicted before the native state is reached may resemble protein folding intermediates.


In [1]:
#@title setup {"vertical-output":true,"form-width":"50%","display-mode":"form"}
%%time
import os
# One-time environment setup; skipped entirely once a "params" directory exists.
if not os.path.isdir("params"):
  _setup_commands = (
    # get code
    "pip -q install git+https://github.com/ccccclw/ColabDesign.git",
    # for debugging
    "ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign",
    # download params
    "mkdir params",
    "apt-get install aria2 -qq",
    "aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar",
    "tar -xf alphafold_params_2022-12-06.tar -C params",
  )
  # The commands run strictly in the order listed; their exit codes are not checked.
  for _setup_command in _setup_commands:
    os.system(_setup_command)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from Bio.PDB import *
import os, re
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.shared.protein import _np_get_cb
from colabdesign.shared.plot import plot_pseudo_3D, make_animation, show_pdb
import pickle
import json
from colabdesign import af
from google.colab import files
import numpy as np
from IPython.display import HTML
import jax.numpy as jnp
import jax
from scipy.special import softmax
import sys
import tqdm.notebook
import argparse
import matplotlib.pyplot as plt
import os
import requests
import zipfile
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import shutil
import zipfile
import os
import threading
import plotly.io as pio
pio.renderers.default = 'colab'

from concurrent.futures import ThreadPoolExecutor, as_completed

!pip install mdtraj
import glob
import mdtraj as md
import matplotlib.cm as cm
import matplotlib.colors as mcolors

# Progress-bar layout string using tqdm's `bar_format` placeholders
# (bar, counts, elapsed and remaining time). It is not referenced by the
# cells visible in this notebook.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

##util functions
def get_pdb(pdb_code=""):
  """Resolve ``pdb_code`` to the name of a local PDB file.

  - ``None`` or ``""``: ask for an upload through ``files.upload()`` and
    write the first uploaded file to ``tmp.pdb``.
  - a path to an existing file: returned unchanged.
  - any other 4-character string: fetched from files.rcsb.org with ``wget``.
  - anything else: fetched from alphafold.ebi.ac.uk as
    ``AF-<pdb_code>-F1-model_v3.pdb``.

  The ``wget`` exit status is not checked, so the returned file name is not
  guaranteed to exist after a failed download.
  """
  if pdb_code is None or pdb_code == "":
    uploads = files.upload()
    first_name = list(uploads)[0]
    with open("tmp.pdb", "wb") as handle:
      handle.write(uploads[first_name])
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    download_cmd = f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb"
    local_name = f"{pdb_code}.pdb"
  else:
    download_cmd = f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb"
    local_name = f"AF-{pdb_code}-F1-model_v3.pdb"
  os.system(download_cmd)
  return local_name


def get_dgram(positions, num_bins=39, min_bin=3.25, max_bin=50.75):
  """Bin pairwise CB-CB distances into a one-hot style distogram.

  CB coordinates come from ``_np_get_cb`` applied to the N, CA and C slices of
  ``positions`` (selected through ``residue_constants.atom_order``).
  ``positions`` is presumably an (L, 37, 3) atom37 array - TODO confirm.

  Bin edges are ``num_bins`` evenly spaced values from ``min_bin`` to
  ``max_bin``; comparisons are made on squared distances with strict
  inequalities, and the last bin is open-ended (squared upper edge 1e8).
  Pairs closer than ``min_bin`` or lying exactly on an edge fall into no bin.

  Returns a float array with one trailing axis of length ``num_bins``.
  """
  order = residue_constants.atom_order
  backbone = {name: positions[..., order[name], :] for name in ("N", "CA", "C")}
  cb = _np_get_cb(**backbone, use_jax=False)

  sq_dist = np.square(cb[None,:] - cb[:,None]).sum(-1, keepdims=True)

  sq_lower = np.square(np.linspace(min_bin, max_bin, num_bins))
  open_end = np.array([1e8], dtype=jnp.float32)
  sq_upper = np.concatenate([sq_lower[1:], open_end], axis=-1)

  in_bin = (sq_dist > sq_lower) * (sq_dist < sq_upper)
  return in_bin.astype(float)

def sample_gumbel(shape, eps=1e-10):
  """Draw Gumbel(0, 1) noise of the given shape from uniform samples.

  ``eps`` is added inside both logarithms to avoid ``log(0)``.
  """
  uniform = np.random.uniform(size=shape)
  neg_log_uniform = -np.log(uniform + eps)
  return -np.log(neg_log_uniform + eps)

def sample_uniform(shape, eps=1e-10):
  """Draw Uniform(0, 1) noise of the given shape, shifted up by ``eps``."""
  return np.random.uniform(size=shape) + eps

from colabdesign.af.alphafold.common import residue_constants
def xyz_atom37(pdb_file):
  """
  Read a PDB file and return its atom coordinates in atom37 layout.

  Parameters:
  - pdb_file: PDB file passed to Bio.PDB ``PDBParser.get_structure``.

  Returns:
  - np.ndarray of shape (num_residues, 37, 3). Row ``i`` holds the ``i``-th
    residue in file order; the atom slot is looked up by atom name in
    ``residue_constants.atom_order``. Slots with no matching atom stay zero.

  Residues are placed by order of appearance rather than by PDB residue
  number, so numbering that does not start at 1 or that has gaps no longer
  raises IndexError or writes into the wrong row. Atoms whose name has no
  atom37 slot (for example hydrogens) are skipped instead of raising KeyError.
  All residues returned by ``structure.get_residues()`` are included.
  """
  atom37_order = residue_constants.atom_order
  parser = PDBParser()
  structure = parser.get_structure("A", pdb_file)
  residues = list(structure.get_residues())
  atom37_coord = np.zeros((len(residues), 37, 3))

  for row, residue in enumerate(residues):
    for atom in residue:
      atom37_index = atom37_order.get(atom.get_name())
      if atom37_index is None:
        # no atom37 slot for this atom name
        continue
      atom37_coord[row, atom37_index] = atom.get_coord()
  return atom37_coord

def sequence_to_one_hot(sequence):
    """
    One-hot encode an amino-acid sequence.

    Parameters:
    - sequence: str of one-letter residue codes (e.g., "ACDE").

    Returns:
    - np.ndarray of shape (N, 20); row i has a single 1 in the column given by
      residue_constants.restype_order for sequence[i].

    Raises:
    - ValueError: if any character is not a key of
      residue_constants.restype_order.
    """
    restype_order = residue_constants.restype_order

    # Unknown letters are mapped to the sentinel -1 and rejected below.
    indices = []
    for letter in sequence:
        indices.append(restype_order.get(letter, -1))

    if -1 in indices:
        raise ValueError("Sequence contains invalid amino acid(s) not present in aa_order.")

    # Picking rows of the identity matrix yields the one-hot rows.
    return np.eye(20)[indices]



UsageError: Line magic function `%%time` not found.


In [None]:
#@title input preparation {"vertical-output":true,"form-width":"50%","display-mode":"form"}
# Query sequence from the Colab form; everything except the letters A-Z is stripped.
starting_seq = "MTYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE" #@param {type:"string"}
starting_seq = re.sub("[^A-Z]", "", starting_seq.upper())
##default sequence is PDB:3GB1 if no sequence is provided
if not starting_seq:
  starting_seq = "MTYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"
length = len(starting_seq)

# "custom" asks for a template upload; "None" runs without a template.
template = "None" #@param ["custom","None"]
if template == "custom":
  custom_template_path = os.path.join(template, "template")
  os.makedirs(custom_template_path, exist_ok=True)
  uploaded = files.upload()
  if not uploaded:
    # Without this check the loop below never binds `fn` and the
    # template_path assignment fails with a NameError.
    raise ValueError("template is set to 'custom' but no template file was uploaded.")
  for fn in uploaded:
    os.rename(fn, os.path.join(custom_template_path, fn))
  # When several files are uploaded, the last one is used as the template.
  template_path = os.path.join(custom_template_path, fn)


In [None]:
#@title initialize the model with parameters and run {"vertical-output":true,"form-width":"50%","display-mode":"form"}
clear_mem()
model_name = "model_1_ptm" #@param ["model_1_ptm", "model_2_ptm","both"]
use_multimer = False
# "both" is passed on as model_names=None; a single choice becomes a one-element list.
model_name = None if model_name == "both" else [model_name]
af_model = mk_afdesign_model(protocol="hallucination",
                             use_templates=True,
                             debug=True,
                             model_names=model_name,
                             use_multimer=use_multimer)
af_model.prep_inputs(length=length)

mode = "dgram" #@param ["dgram","dgram_retrain"]
if "dgram" in mode:
  if "retrain" in mode and not use_multimer:
    # update distogram head to return all 39 bins
    af_model._cfg.model.heads.distogram.first_break = 3.25
    af_model._cfg.model.heads.distogram.last_break = 50.75
    af_model._cfg.model.heads.distogram.num_bins = 39
    af_model._model = af_model._get_model(af_model._cfg)
    from colabdesign.af.weights import __file__ as af_path
    template_dgram_head = np.load(os.path.join(os.path.dirname(af_path),'template_dgram_head.npy'))
    for model_idx in range(len(af_model._model_params)):
      params = {"weights":jnp.array(template_dgram_head[model_idx]),"bias":jnp.zeros(39)}
      af_model._model_params[model_idx]["alphafold/alphafold_iteration/distogram_head/half_logits"] = params
  else:
    # (64, 39) matrix used below as `dgram_prob @ dgram_map`; its last row is
    # zeroed so the last predicted bin contributes nothing through the map.
    dgram_map = np.eye(39)[np.repeat(np.append(0,np.arange(15)),4)]
    dgram_map[-1,:] = 0

iterations = 50 #@param [50, 100, 200] {type:"raw"}
use_dgram_noise = None #@param ["g","u","None"]
use_dropout = False #@param {type:"boolean"}
seqsep_mask =  0 #@param {type:"integer"}
num_recycles = 2 #@param {type:"integer"}

# NOTE(review): model_name was already converted to None/list above, so this
# comparison is always False and sample_models is always False. Left as-is to
# keep the prediction behavior unchanged - confirm whether sampling was intended
# when "both" is selected.
sample_models = True if model_name == "both" else False
dgram_noise_type = use_dgram_noise
use_dgram_noise = False if use_dgram_noise is None else True


def _perturb_dgram(dgram, noise_type, scale):
  """Return softmax(log(dgram + 1e-8) + scale * noise) over the last axis.

  noise_type "g" draws Gumbel noise, "u" uniform noise; any other value
  returns ``dgram`` unchanged.
  """
  if noise_type == "g":
    noise = sample_gumbel(dgram.shape) * scale
  elif noise_type == 'u':
    noise = sample_uniform(dgram.shape) * scale
  else:
    return dgram
  return softmax(np.log(dgram + 1e-8) + noise, -1)


L = sum(af_model._lengths)
af_model.restart(mode="gumbel")
af_model._inputs["rm_template_seq"] = False
# gather info about inputs
if "offset" in af_model._inputs:
  # Use the "offset" entry itself; assigning the whole inputs dict breaks the
  # np.abs(offset) calls further down.
  offset = af_model._inputs["offset"]
else:
  idx = af_model._inputs["residue_index"]
  offset = idx[:,None] - idx[None,:]

# initialize sequence
if len(starting_seq) > 1:
  af_model.set_seq(seq=starting_seq)
af_model._inputs["bias"] = np.zeros((L,20))

# initialize coordinates/dgram
af_model._inputs["batch"] = {"aatype":np.zeros(L).astype(int),
                             "all_atom_mask":np.zeros((L,37)),
                             "all_atom_positions":np.zeros((L,37,3)),
                             "dgram":np.zeros((L,L,39))}

if template == "custom":
  xyz = xyz_atom37(pdb_file=template_path)
  af_model._inputs["batch"]["all_atom_positions"] = xyz
  dgram = get_dgram(xyz)
  if use_dgram_noise:
    # The loop index does not exist yet at this point; the template corresponds
    # to iteration 0, where the (1 - k/iterations) schedule equals 1. The noise
    # is applied before the distogram is stored so that it takes effect.
    dgram = _perturb_dgram(dgram, dgram_noise_type, 1.0)
  mask = np.abs(offset) > seqsep_mask
  af_model._inputs["batch"]["dgram"] = dgram * mask[:,:,None]

plddts = []
print(f"running seq {starting_seq} with model: {'both' if model_name is None else model_name} for {iterations} steps")
for k in range(iterations):
  # From the second iteration on, feed the previous prediction back in as the
  # template distogram.
  if k > 0:
    dgram_xyz = get_dgram(xyz)
    dgram_prob = softmax(dgram_logits,-1)

    if mode == "xyz":
      dgram = dgram_xyz
    if mode == "dgram":
      dgram = dgram_prob @ dgram_map
      # bins 14 and up: coordinate-derived bins weighted by the probability of
      # the last predicted bin
      dgram[...,14:] = dgram_xyz[...,14:] * dgram_prob[...,-1:]
    if mode == "dgram_retrain":
      dgram = dgram_prob

    # noise, scaled down linearly over the iterations
    if use_dgram_noise:
      dgram = _perturb_dgram(dgram, dgram_noise_type, 1 - k/iterations)

    # add mask to avoid local contacts being fixed (otherwise there is a bias toward helix)
    mask = np.abs(offset) > seqsep_mask
    af_model._inputs["batch"]["dgram"] = dgram * mask[:,:,None]

  # prediction
  aux = af_model.predict(return_aux=True, verbose=False,
                        sample_models=sample_models,
                        dropout=use_dropout, num_recycles=num_recycles)
  plddt = aux["plddt"]
  plddts.append(np.average(plddt))
  seq = aux["seq"]["hard"][0].argmax(-1)
  xyz = aux["atom_positions"].copy()
  dgram_logits = aux["debug"]["outputs"]["distogram"]["logits"]

  # update inputs
  af_model._inputs["batch"]["aatype"] = seq
  af_model._inputs["batch"]["all_atom_mask"][:,:4] = np.sqrt(plddt)[:,None]
  af_model._inputs["batch"]["all_atom_positions"] = xyz

  # save results
  af_model._save_results(aux)
  af_model._k += 1
  af_model.save_pdb(f"iter_{k}.pdb")

In [None]:
#@title visualization {"vertical-output":true,"form-width":"50%","display-mode":"form"}
# Plot the mean pLDDT recorded at each prediction iteration of the run cell
# (values in `plddts` are multiplied by 100 for display).
fig,ax=plt.subplots(1,1,figsize=(7.4,2))
ax.scatter(range(len(plddts)),np.array(plddts)*100,s=12, color='grey', zorder=1)
ax.plot(np.array(plddts)*100,'darkorange',zorder=0)
ax.set_xlabel("Prediction iteration")
ax.set_ylabel("pLDDT")
# Annotate the recycle count at 85% of the x-range and 5% of the y-range.
ax.text(ax.get_xlim()[0]+(ax.get_xlim()[1]-ax.get_xlim()[0])*0.85,\
        ax.get_ylim()[0]+(ax.get_ylim()[1]-ax.get_ylim()[0])*0.05,f"recycle# {num_recycles}")
# Must stay the last expression of the cell so the notebook renders the animation.
HTML(af_model.animate(dpi=80, interval=300))

In [None]:
#@title Visualize precalculated iterative structure predictions from PDB {"vertical-output":true,"form-width":"50%","display-mode":"form"}
!pip install plotly
import plotly
import plotly.express as px
#import nglview as nv
import plotly.graph_objects as go
import pandas as pd

In [None]:
#@title visualize embeddings {"vertical-output":true,"form-width":"50%","display-mode":"form"}
visualize_embeddings = True #@param {type:"boolean"}
seq_length_min = 0 #@param {type:"integer"}
seq_length_max = 100 #@param {type:"integer"}
best_rmsd_min = 0 #@param {type:"integer"}
best_rmsd_max = 3 #@param {type:"integer"}


def _ss_fractions(ss):
    """Return five composition fractions for one secondary-structure string.

    The values are [E/(E+H), H/(E+H), E/len, H/len, inner_C/len], where
    inner_C counts 'C' characters after stripping leading and trailing 'C'.
    Counts are NumPy integers, so a zero denominator yields nan/inf with a
    RuntimeWarning rather than an exception. 'E', 'H' and 'C' are presumably
    strand, helix and coil codes - TODO confirm.
    """
    chars = np.array([*ss])
    n_e = (chars == 'E').sum()
    n_h = (chars == 'H').sum()
    n_inner_c = (np.array([*(ss.strip('C'))]) == 'C').sum()
    return [n_e/(n_e+n_h),
            n_h/(n_e+n_h),
            n_e/len(ss),
            n_h/len(ss),
            n_inner_c/len(ss)]


# Loading and plotting are gated by the form flag, so switching it off skips
# the whole cell body instead of failing on files that were never downloaded.
if visualize_embeddings:
    # os.path.isfile also covers the case where 'zenodo_downloads' does not
    # exist yet; os.listdir would raise FileNotFoundError there.
    if not os.path.isfile(os.path.join('zenodo_downloads', 'all_ss_gap0.npy')):
        download_file = ['all_ss_gap0.npy','all_ss_gap6.npy',
                         'all_pdbs_gap0.npy','all_seq_length_gap0.npy',
                         'all_rmsd_model1_gap0.npy','all_rmsd_model2_gap0.npy']
        record_id = record_ids['18_and_embeddings']
        download_zenodo(record_id=record_id,download_file=download_file)

    all_ss = np.load("./zenodo_downloads/all_ss_gap0.npy")
    all_rmsd = np.load("./zenodo_downloads/all_rmsd_model1_gap0.npy")
    all_rmsd2 = np.load("./zenodo_downloads/all_rmsd_model2_gap0.npy")
    all_seq_length = np.load("./zenodo_downloads/all_seq_length_gap0.npy")
    all_rmsd = np.array(all_rmsd)
    all_rmsd2 = np.array(all_rmsd2)
    all_tmfile_pd=np.load("./zenodo_downloads/all_pdbs_gap0.npy")
    all_EH = np.array([_ss_fractions(i) for i in all_ss])
    # H/(E+H) per entry, identical to column 1 of all_EH
    all_H = [row[1] for row in all_EH]
    import plotly.io as pio
    pio.renderers.default = 'colab'
    # Keep entries whose sequence length lies strictly between the two bounds.
    long_seq_index = np.where((all_seq_length>seq_length_min) & (all_seq_length<seq_length_max))[0]
    all_X_embedded = np.load('./zenodo_downloads/rmsds_plddts_embeddings/TSE_embedding_gap0.npy')
    # NOTE(review): index 9 selects one of the stored embeddings; what it
    # corresponds to is not visible here - confirm.
    X_embedded = all_X_embedded[9]
    plot_embedding(X_embedded[long_seq_index],all_EH[long_seq_index],all_rmsd2[long_seq_index],np.array(all_tmfile_pd)[long_seq_index],selection='RMSD',selection_min=best_rmsd_min,selection_max=best_rmsd_max)


In [None]:
#@title visualize individual pdb {"vertical-output":true,"form-width":"50%","display-mode":"form"}
# PDB entry to visualize, taken from the Colab form and lower-cased for the
# path and key lookups below.
pdb_id = "3gb1" #@param {type:"string"}
pdb_id = pdb_id.lower()
# Index file name -> Zenodo record id. Not referenced again in the visible
# cells; the download below passes the same file name and id as literals.
sub_traj_record_ids = {'zenodo_id_subtrajs.json':'13857269'}

from concurrent.futures import ThreadPoolExecutor, as_completed

!pip install mdtraj
import glob
import mdtraj as md
import matplotlib.cm as cm
import matplotlib.colors as mcolors

# Function to extract pLDDT values from a PDB file
def get_plddt(pdb):
    """Read per-residue pLDDT values from the B-factor column of a PDB file.

    For every residue the B-factor of its first ``ATOM`` record is taken and
    divided by 100. Residues are identified by chain ID, residue number and
    insertion code (columns 22-27), so only the first occurrence of each
    residue is used. ``HETATM`` records are ignored.

    The B-factor is read from its full fixed-width field, columns 61-66
    (``line[60:66]``). The narrower slice ``[61:65]`` dropped the last digit
    and read a value such as ``100.00`` as ``00.0``.

    Parameters:
    - pdb: path to a PDB file.

    Returns:
    - np.ndarray of dtype float32 with one value per residue, in file order.
    """
    bfactors = []
    seen_residues = set()
    with open(pdb, 'r') as fh:
        for line in fh:
            if not line.startswith('ATOM'):
                continue
            residue_key = line[21:27]
            if residue_key in seen_residues:
                continue
            seen_residues.add(residue_key)
            bfactors.append(float(line[60:66]))

    return (np.array(bfactors, dtype="float64") / 100).astype("float32")

# Function to run get_plddt concurrently for multiple PDB files
def process_plddts_in_parallel(pdb_files, max_workers=4):
    """Run ``get_plddt`` over ``pdb_files`` on a thread pool.

    Parameters:
    - pdb_files: iterable of PDB file paths.
    - max_workers: number of worker threads.

    Returns:
    - dict mapping each path to its ``get_plddt`` result. A path whose
      processing raised is reported with ``print`` and left out of the dict.
    """
    results = {}

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Remember which path each future belongs to.
        path_by_future = {}
        for path in pdb_files:
            path_by_future[pool.submit(get_plddt, path)] = path

        # Results are collected in completion order, not submission order.
        for finished in as_completed(path_by_future):
            path = path_by_future[finished]
            try:
                results[path] = finished.result()
            except Exception as e:
                print(f"Error processing {path}: {e}")
    return results

# Download and extract the precalculated predictions for `pdb_id` into
# 'zenodo_downloads'.
if pdb_id is not None:
    if pdb_id in fig_1_6_pdbs:
        # These entries are shipped as a single <pdb_id>.zip in one record.
        download_file = f"{pdb_id.lower()}.zip"
        record_id = record_ids['18_and_embeddings']
        download_zenodo(record_id,download_file=download_file)
        download_file = os.path.join('zenodo_downloads', download_file)
        extract_zip_multithreaded(download_file, 'zenodo_downloads', num_threads=4)
    else:
        download_zenodo('13857269',download_file='zenodo_id_subtrajs.json')
        # Index structure as used below: record id -> zip name -> member file names.
        with open('zenodo_downloads/zenodo_id_subtrajs.json') as index_file:
            sub_traj_zenodo_id_index = json.load(index_file)
        # First character of the PDB id -> Zenodo record ids to search.
        sub_traj_zenodo_id_index_zip = {'1':['13841237','13836304'],'2':['13836288','13836290'],
                                        '3':['13836306','13836308'],'4':['13836319','13836313'],
                                        '5':['13836321'],'6':['13836295'],'7':['13841213'],'8':['13836304']}
        if pdb_id[:1] not in sub_traj_zenodo_id_index_zip:
            raise ValueError(f"no precalculated trajectories are indexed for pdb_id {pdb_id!r}")
        sub_traj_zenodo_id = sub_traj_zenodo_id_index_zip[pdb_id[0]]

        # Collect (record id, zip name, member) for every .xtc member whose
        # name contains pdb_id.
        download_sub_traj_zenodo_id = []
        for zenodo_id in sub_traj_zenodo_id:
            for zip_key, sub_trajs in sub_traj_zenodo_id_index[zenodo_id].items():
                for sub_traj in sub_trajs:
                    if sub_traj.endswith('.xtc') and pdb_id in sub_traj:
                        download_sub_traj_zenodo_id.append((zenodo_id, zip_key, sub_traj))

        for (zenodo_id, zip_key, sub_traj) in download_sub_traj_zenodo_id:
            # The download is skipped when the trajectory is already on disk;
            # the extraction below runs either way.
            if not os.path.exists(os.path.join('zenodo_downloads', sub_traj)):
                download_zenodo(zenodo_id, download_file=zip_key)
            zip_key = os.path.join('zenodo_downloads', zip_key)
            extract_zip_multithreaded(zip_key, 'zenodo_downloads', file_name=[pdb_id], num_threads=4)
        # NOTE(review): these two archives are not downloaded anywhere in the
        # visible cells - presumably fetched elsewhere; confirm.
        extract_zip_multithreaded('zenodo_downloads/model_1_2_gap_0_6_per_residue_plddts.zip', 'zenodo_downloads', file_name=[pdb_id], num_threads=4)
        extract_zip_multithreaded('zenodo_downloads/0_0_iter_0_pdb.zip', 'zenodo_downloads', file_name=[pdb_id], num_threads=4)


# fig_1_6_trajs = {'3gb1':'run_2nd_nosamplemodel_model_1_gap0/2_0/','1mi0':'run_2nd_nosamplemodel_model_1_gap0/1_0/',
#                  '1hz5':'run_2nd_nosamplemodel_model_1_gap0/0_0/','1kh0':'run_2nd_nosamplemodel_model_2_gap0/0_0/',
#                  '1ubq':'run_2nd_nosamplemodel_model_1_gap0/2_0/','2hda':'run_2nd_nosamplemodel_model_1_gap0/2_0/',
#                  '1fnf':'run_2nd_nosamplemodel_model_2_gap0/1_0/'}
# Run directory holding the trajectory of each fig-1-6 entry.
fig_1_6_trajs = {'3gb1':'run_2nd_nosamplemodel_model_1_gap0','1mi0':'run_2nd_nosamplemodel_model_1_gap0',
                 '1hz5':'run_2nd_nosamplemodel_model_1_gap0','1kh0':'run_2nd_nosamplemodel_model_2_gap0',
                 '1ubq':'run_2nd_nosamplemodel_model_1_gap0','2hda':'run_2nd_nosamplemodel_model_1_gap0',
                 '1fnf':'run_2nd_nosamplemodel_model_2_gap0'}

# One run directory per model/gap combination. The fourth entry is the
# model 2 / gap 6 run; it previously repeated the gap 0 directory, so that run
# was loaded twice under the 'model_2_gap_6' label.
models = ['run_2nd_nosamplemodel_model_1_gap0','run_2nd_nosamplemodel_model_1_gap6','run_2nd_nosamplemodel_model_2_gap0','run_2nd_nosamplemodel_model_2_gap6']
if pdb_id in fig_1_6_pdbs:
  # 11 recycle settings x 50 iterations, one PDB file per frame
  pdbs = [f"zenodo_downloads/{pdb_id}/{fig_1_6_trajs[pdb_id]}/{recycle}_0/iter_{step}.pdb" for recycle in range(11) for step in range(50)]
  plddts = process_plddts_in_parallel(pdbs)
  plddts = [plddts[pdb] for pdb in pdbs]
  traj = md.load(pdbs)
  # The code below works on `traj_all` in both branches.
  traj_all = traj
else:
  traj = [glob.glob(f"zenodo_downloads/*/{pdb_id}/{model}/*xtc")[0] for model in models]

  top = glob.glob(f"zenodo_downloads/*/{pdb_id}/{pdb_id}*pdb")[0]
  # Keep the first 100 frames of every 500-frame segment of each trajectory.
  total_frames = 2500
  all_frames = list(range(total_frames))
  traj_list = [x for i in range(0, total_frames, 500) for x in all_frames[i:i+100]]
  traj_all = md.load(traj[0],top=top)[traj_list]
  for i in range(1, len(traj)):
    traj_all += md.load(traj[i],top=top)[traj_list]
  plddts = []
  for model in models:
    plddts_json = glob.glob(f"zenodo_downloads/*seq*/{pdb_id}/{model}/*plddt*json")[0]
    with open(plddts_json) as plddts_file:
      plddts_json = json.load(plddts_file)
    plddts.append([plddts_json[f'{recycle}_0/iter_{step}.pdb'] for recycle in [0,1,3,5,8] for step in range(100)])
  plddts = np.concatenate(plddts)/100
# Reduce to CA atoms; coordinates are multiplied by 10 for the animation helper.
CA_atoms=traj_all.top.select(f"name CA")
traj_all=traj_all.atom_slice(CA_atoms)
xyz = [traj_all.xyz[i]*10 for i in range(len(traj_all))]
seq = [sequence_to_one_hot(traj_all.top.to_fasta()[0])[None,:,:]]*len(xyz)
if pdb_id in fig_1_6_pdbs:
  #recycling plot
  colors = plt.cm.plasma_r(np.linspace(0, 1, 12))[-11:]
  fig,ax = plt.subplots(1,1,figsize=(10,2))
  plot_steps = 20
  plddts = np.array(plddts).reshape(11,50,len(plddts[-1]))
  average_plddts = np.average(plddts, axis=2)
  for i in range(11):
    tmp_plddts = average_plddts[i]
    ax.plot(tmp_plddts[:plot_steps]*100,label=str(i),color=colors[i])
    ax.scatter(list(range(plot_steps)),tmp_plddts[:plot_steps]*100,s=20,color=colors[i])
  ax.set_yticks(list(range(60,100,10)))
  ax.set_xticks(list(range(0,plot_steps)))
  plt.legend()
  plt.show()
else:
  #plot plddts for four models
  fig,ax=plt.subplots(4,1,figsize=(8.4,12))
  average_plddts = plddts.mean(axis=1)

  recycles = [0,1,3,5,8]
  # The colour count comes from len(recycles) rather than a `num_recycles`
  # variable, so the run cell's num_recycles setting is not overwritten.
  # plt.get_cmap replaces cm.get_cmap, which recent Matplotlib no longer provides.
  cmap = plt.get_cmap('viridis', len(recycles))
  norm = mcolors.Normalize(vmin=0, vmax=len(recycles)-1)
  models = ['model_1_gap_0','model_1_gap_6','model_2_gap_0','model_2_gap_6']
  plot_steps = 50
  # Plot the data using the colormap
  for model in range(4):
      for recycle in range(len(recycles)):
          color = cmap(recycle)  # Get the color from the colormap
          # frames are stored as 500 per model, 100 per recycle setting
          start_step = model*500+recycle*100
          tmp_plddts = average_plddts[start_step:start_step+plot_steps]*100
          ax[model].plot(tmp_plddts,
                        color=color,
                        label=f'recycle {recycles[recycle]}')
          ax[model].scatter(list(range(len(tmp_plddts))),tmp_plddts,
                        color=color,
                        s=15)

      ax[model].set_xlabel("Prediction iteration")
      ax[model].set_ylabel("pLDDT")
      ax[model].set_title(f"Model {models[model]}",fontsize=9)
  plt.tight_layout()
  # Add a colorbar to indicate recycle steps
  sm = cm.ScalarMappable(cmap=cmap, norm=norm)
  sm.set_array([])  # Needed to avoid a warning
  cbar = fig.colorbar(sm, ax=ax,shrink=0.8, aspect=30) #, orientation='vertical', pad=0.02)
  cbar.set_label('Recycle steps', rotation=270, labelpad=20)
  cbar.set_ticks(list(range(len(recycles))))
  cbar.set_ticklabels([str(r) for r in recycles])


In [None]:
#@title structural visualization with selected model and recycle {"vertical-output":true,"form-width":"50%","display-mode":"form"}
# Animate the frames of one model / recycle combination chosen in the form.
model = 'model_1_gap_0' #@param ['model_1_gap_0','model_1_gap_6','model_2_gap_0','model_2_gap_6']
recycle = "0" #@param ['0','1','3','5','8']

models = ['model_1_gap_0','model_1_gap_6','model_2_gap_0','model_2_gap_6']
recycles = ['0','1','3','5','8']
models_index = models.index(model)
recycle_index = recycles.index(recycle)
# Frames are assumed to be laid out as 500 per model and 100 per recycle
# setting, as built by the non-fig-1-6 branch of the previous cell.
# NOTE(review): the fig-1-6 branch of that cell stores 11 x 50 frames of a
# single model and reshapes `plddts` to 3-D, so these offsets do not match
# that layout - confirm before using this cell with those entries.
start_index = models_index*500+recycle_index*100
# `plot_steps` is set by the previous cell.
end_index = start_index+plot_steps
sub_xyz = xyz[start_index:end_index]
sub_plddts = plddts[start_index:end_index]
sub_seq = seq[start_index:end_index]
# Must stay the last expression of the cell so the notebook renders the animation.
HTML(make_animation(sub_seq, xyz=sub_xyz, pae=None, plddt=sub_plddts, dpi=80, interval=300))
