<img src="http://bioinfo-pharma.u-strasbg.fr/labwebsite/images/logo_UMR7200.png" width="400" align="right" style="height:240px">

# LIT-AlphaFold-Colab
LIT-AlphaFold-Colab is an adapted version of LIT-AlphaFold for online use. <br>
The input file can be created and edited as with the local version of LIT-AlphaFold. <br>
Calculations are limited to a single monomer or multimer input.

In [None]:
# @title Install dependencies { display-mode: "form" }
import os
# Bootstrap the runtime in three stages. Each stage creates an empty marker
# file in the working directory after its commands have run, and is skipped on
# later executions of this cell when that marker already exists.
# NOTE(review): the exit status of the os.system calls is never checked, so a
# marker is written even if an install command failed — confirm this is intended.
print("Installing LIT-AlphaFold")
# Stage 1: conda (through condacolab) and the conda-forge / bioconda packages.
if not os.path.isfile("CONDA_READY"):
    os.system("pip install --quiet condacolab")
    import condacolab
    condacolab.install()
    os.system("mamba install --quiet -c conda-forge -c bioconda python=3.10 openmm==7.7.0 pdbfixer kalign2=2.04 hhsuite=3.3.0 mmseqs2=14.7e284  polyleven") 
    os.system("touch CONDA_READY")
# Printed unconditionally, also when stage 1 was skipped.
print("conda installed")

# Stage 2: pip packages — alphapulldown (without its dependencies), ColabFold
# v1.5.3, a CUDA 11 jaxlib wheel built for CPython 3.10, and pinned versions of
# jax, chex and biopython.
if not os.path.isfile("COLABFOLD_READY"):
    os.system("pip install --quiet alphapulldown==0.40.4 --no-deps")
    os.system('pip install --quiet --no-warn-conflicts "colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold@v1.5.3"')
    os.system("pip install --quiet https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl")
    os.system("pip install --quiet jax==0.3.25 chex==0.1.6 biopython==1.79")
    os.system("touch COLABFOLD_READY")
print("Environment ready")

# Stage 3: clone the LIT-AlphaFold repository and pip-install its litaf/ folder.
if not os.path.isfile("LITAF_READY"):
    os.system("git clone -q https://github.com/LIT-CCM-lab/LIT-AlphaFold")
    os.system("pip install --quiet LIT-AlphaFold/litaf/ ")
    os.system("touch LITAF_READY")
print("LIT-AlphaFold installed")

In [None]:
# @title Input data { display-mode: "form" }
from pathlib import Path
from google.colab import files
import os
import logging
from absl import logging as absl_logging


# Form fields of the "Input data" cell. The `# @param` annotations turn each
# assignment into a form widget; the values are read further down in this cell
# and by later cells.
#@markdown **Name of the project**
job_name = "LIT-AF-Colab-Test" # @param {type:"string"}
#@markdown **Templates**:
template_mode = "pdb100" # @param ["pdb100"]
max_template_date = "2050-01-01" # @param {type:"date"}
# NOTE(review): the dropdown above only offers "pdb100", although the help text
# below documents a `custom` mode and this cell has a branch for it — confirm
# whether "custom" is meant to be selectable.

#@markdown `pdb100` = Templates are searched in the PDB100 dataset.<br>
#@markdown `custom` = Upload user selected templates (mmCIF or PDB format).

#@markdown Input type
input_type = 'fasta' # @param ["fasta", "pkl"]

#@markdown `fasta` = The input files are *.fasta* files containing the sequences of all the monomer or monomeric units.<br>
#@markdown `pkl` = The input files are *.pkl* files generated using LIT-AlphaFold containing MonomericObject or MultimericObject.<br>


# The project name doubles as the folder that collects every output of the run.
output_dir = job_name
project_dir = Path(output_dir)
try:
    project_dir.mkdir(parents=True, exist_ok=True)
except FileExistsError:
    # Report the clash instead of aborting the cell.
    print("Multiple processes are trying to create the same folder now.")

# Resolve the template source chosen in the form into two values:
#   templates_path : 'mmseqs2', the folder holding user templates, or None
#   use_templates  : whether templates are used at all
if template_mode == "pdb100":
    templates_path = 'mmseqs2'
    use_templates = True
elif template_mode == "custom":
    templates_path = "custom_templates"
    os.makedirs(templates_path, exist_ok=True)
    up_templates = files.upload()
    use_templates = True
    # Move each uploaded file into the template folder. The destination used
    # to be built from an undefined name, which raised NameError here.
    for fn in up_templates:
        os.rename(fn, os.path.join(templates_path, fn))
else:
    templates_path = None
    use_templates = False

# Detach whatever handlers are already installed on the root logger, then
# configure it to write INFO-level messages both to a log file inside the
# project folder and to the stream handler.
for stale_handler in list(logging.root.handlers):
    logging.root.removeHandler(stale_handler)

log_handlers = [
    logging.FileHandler(f"{output_dir}/litaf.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    format="%(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=log_handlers,
)
absl_logging.set_verbosity('error')

In [None]:
# @title Upload input files { display-mode: "form" }
#@markdown Select *.fasta*, *.pkl* or *.pkl.bz2* files
# Start from an empty set of monomer units, upload the input files and make
# sure each file matches the input type selected in the "Input data" cell.
monomers = {}
uploaded = files.upload()
for fn in uploaded:
    # Use the last extension rather than the text after the first dot, so that
    # names containing extra dots are accepted and names without any dot get
    # the explanatory error below instead of an IndexError.
    base, dot, file_type = fn.rpartition('.')
    if not dot:
        file_type = ''
    elif file_type == 'bz2' and base.endswith('.pkl'):
        # compressed pickles ("*.pkl.bz2") count as pkl input
        file_type = 'pkl'
    if file_type != input_type:
        raise Exception(f"Expected file of type {input_type} but uploaded file of type {file_type or 'unknown'}. Please check your inputs.")

In [None]:
# @title Prepare input files { display-mode: "form" }
import matplotlib.pyplot as plt
from colabfold.utils import DEFAULT_API_SERVER
from colabfold.batch import mk_hhsearch_db
from colabfold.plot import plot_msa_v2
from alphapulldown.utils import parse_fasta
from litaf.objects import MonomericObjectMmseqs2, load_monomer_objects
from litaf.filterpdb import load_template_filter

from absl import logging as absl_logging
absl_logging.set_verbosity(absl_logging.ERROR)

def iter_seqs(fasta_fns):
    """Yield ``(sequence, description)`` pairs from each FASTA file in turn.

    Every path in ``fasta_fns`` is opened as text and its full content is
    handed to ``parse_fasta``; the sequences and descriptions it returns are
    paired positionally and yielded one pair at a time.
    """
    for fasta_path in fasta_fns:
        with open(fasta_path, "r") as handle:
            sequences, descriptions = parse_fasta(handle.read())
            yield from zip(sequences, descriptions)

show_msa = True # @param {type: 'boolean'}

# Suffix added to the monomer names when user-supplied templates are used.
template_description = ''

# Monomer units are only (re)built while the `monomers` dict is still empty.
if input_type == 'fasta' and monomers == {}:

    if templates_path == "custom_templates":
        template_description = '_custom_templates'
        ff_appendix = ['_a3m.ffdata',
                       '_a3m.ffindex',
                       '_cs219.ffdata',
                       '_cs219.ffindex']
        ff_files = [os.path.join(templates_path, f'pdb70_{ff}')
                    for ff in ff_appendix]
        # Build the HHsearch database unless all of its files already exist.
        # The paths above already include the template folder (it used to be
        # joined a second time, so the check never succeeded), and the folder
        # is passed directly (it used to be read from an undefined `FLAGS`).
        if not all(os.path.isfile(ff) for ff in ff_files):
            mk_hhsearch_db(templates_path)

    # One monomer unit per sequence found in the uploaded FASTA files; the
    # unit is named after the sequence description with spaces replaced.
    for curr_seq, curr_desc in iter_seqs(list(uploaded)):
        curr_monomer_name = curr_desc.replace(' ', '_') + template_description
        monomers[curr_monomer_name] = MonomericObjectMmseqs2(curr_monomer_name, curr_seq)
        monomers[curr_monomer_name].make_features(DEFAULT_API_SERVER=DEFAULT_API_SERVER,
                                                  output_dir=output_dir,
                                                  templates_path=templates_path,
                                                  max_template_date=max_template_date,)
elif monomers == {}:
    # Pickle input: load the object named by the text before the first dot of
    # each uploaded file name, looking in the current directory.
    for fn in uploaded:
        object_name = fn.split('.')[0]
        m = load_monomer_objects({object_name: '.'}, object_name)
        monomers[m.description] = m

# Report every available monomer unit and, when requested, plot its MSA.
for m_name, m in monomers.items():
    print(f"Monomer unit {m_name} created")
    if show_msa:
        # IPython magic: only valid when the cell runs in a notebook kernel.
        %matplotlib inline
        plot_msa_v2(m.feature_dict)
        plt.show()
        plt.close()

# Registry of template filters; the "Load templates filters" cell adds to it.
filters = {}

In [None]:
# @title Load templates filters { display-mode: "form" }
#@markdown Filter structural templates based on a *.yaml* query file, please refer to the [Wiki](https://github.com/LIT-CCM-lab/LIT-AlphaFold/wiki) for more information.
#@markdown All instances of spaces in the file names will be replaced by the '_' character
template_filter = True # @param {type:"boolean"}
if template_filter:
    up_filters = files.upload()
    # Parse every uploaded query file first; `filters` is only updated once
    # all of them have loaded. Keys are the file names with spaces replaced.
    loaded_filters = {}
    for query_file in up_filters:
        loaded_filters[query_file.replace(' ', '_')] = load_template_filter(query_file)
    filters.update(loaded_filters)

In [None]:
# @title Filter templates { display-mode: "form" }
#@markdown Apply the structure filters on the monomeric units

import copy

monomer_and_filter = '' # @param{type: 'string'}
#@markdown Insert the pairing between a monomer unit and the uploaded filter file as *monomer_name filter_file*, separate the different pairings with *;*<br>
#@markdown Do not add any space between *;* and the other characters

# Filtered copies are collected here and merged into `monomers` at the end, so
# the dict is not modified while the pairings are being processed.
new_monomers = {}

for fp in monomer_and_filter.split(';'):
    if not fp:
        # Empty fragment (blank field, trailing or doubled ';'): skip it but
        # keep processing the remaining pairings.
        continue
    try:
        fk, ff = fp.split(' ')
    except ValueError:
        logging.warning(f"Skipping pairing '{fp}': expected '<monomer unit> <filter file>'")
        continue
    if fk not in monomers:
        logging.warning(f"Skipping pairing '{fp}': unknown monomer unit '{fk}'")
        continue
    if ff not in filters:
        logging.warning(f"Skipping pairing '{fp}': unknown filter file '{ff}'")
        continue
    monomer = monomers[fk]

    # The filtered unit is named after the source unit and the filter file.
    new_m_name = f'{monomer.description}_{Path(ff).stem}'
    if new_m_name in monomers or new_m_name in new_monomers:
        # this combination has already been generated
        continue
    logging.info(f"Filtering using query in file: {Path(ff).stem}")
    # Work on a deep copy so the unfiltered unit stays available.
    new_monomers[new_m_name] = copy.deepcopy(monomer)
    new_monomers[new_m_name].make_template_features(None,
                                                   filter_t=filters[ff],
                                                   inplace=True)
    new_monomers[new_m_name].description = new_m_name
monomers.update(new_monomers)

In [None]:
# @title List and save monomer units { display-mode: "form" }
#@markdown List of the monomer objects that can be used for calculation
import pickle
save_monomers = False # @param {type:"boolean"}
# Print the name of every monomer unit and, when requested, pickle each unit
# into the project folder as "<description>.pkl".
print("Currently available monomer units:")
for mk, mv in monomers.items():
    print(mk)
    if save_monomers:
        # `with` closes the file even if pickling fails; the handle used to be
        # left open.
        with open(f"{os.path.join(output_dir, mv.description)}.pkl", 'wb') as pkl_out:
            pickle.dump(mv, pkl_out)

In [None]:
from litaf.create_input import create_interactors_colab
from litaf.utils import obtain_options

# @title Calculations settings { display-mode: "form" }
#@markdown **Input**
#mode = "custom" # @param ['custom', 'homo-oligomer']
input_line = "" # @param {type: 'string'}
#@markdown For monomer prediction write the name of the monomer unit to use ex. "*monomer_1*" or "*monomer_1;1*"<br>
#@markdown For homo-oligomer multimer prediction write the name of the monomer unit and the number of times the unit is present in the multimer separated by ';' ex. "*monomer_1;2*"<br>
#@markdown For multimer prediction write the name of the monomer units to use separated by *;* ex. "*monomer_1;monomer_2*"<br>
#@markdown For additional input options please visit the project [wiki](https://github.com/LIT-CCM-lab/LIT-AlphaFold/wiki)<br><br>
#@markdown **Prediction parameters**
run_description = ''
num_recycles_mono = 5 # @param {type: 'integer'}
num_recycles_multi = 20 # @param {type: 'integer'}
max_seqs = "none" # @param ["none", "8:16", "16:32"]

num_predictions_per_model = 1 # @param {type: 'integer'}
dropout = True # @param {type: 'boolean'}
disable_cluster_profile = False # @param {type: 'boolean'}
save_all = False # @param {type: 'boolean'}
model_to_relax = 'all' # @param ["all", "best", "none"]

only_template_models = True # @param {type: 'boolean'}
use_templates = True # @param {type: 'boolean'}

pair_msa = True # @param {type: 'boolean'}
remove_msa = False # @param {type: 'boolean'}
remove_template_msa = False # @param {type: 'boolean'}
shuffle_templates = False # @param {type: 'boolean'}

#@markdown **Show MSA plot**
show_msa = True # @param {type: 'boolean'}

# MSA subsampling: "none" disables it, otherwise the field holds
# "<max_seq>:<max_extra_seq>".
if max_seqs == "none":
    max_seq = max_extra_seq = None
else:
    max_seq, max_extra_seq = (int(part) for part in max_seqs.split(':'))

# Every non-default option leaves a tag in the run description, always in the
# order dropout, cluster profile, MSA subsampling.
description_tags = []
if dropout:
    description_tags.append('_dropout')
if disable_cluster_profile:
    description_tags.append('_noclusterprofile')
if max_seqs != "none":
    description_tags.append(f'_MSA-subsampling-{max_seqs}')
run_description = run_description + ''.join(description_tags)

from litaf.utils import read_custom
from litaf.objects import MultimericObject

# Interpret the input line.
#   "<unit>;<count>" with a numeric count -> homo-oligomer made of <count>
#                                            copies of the unit
#   anything else                         -> explicit list of units
# `custom` is now always assigned: on the homo-oligomer path it used to be
# left undefined, or to keep the value from a previous run of the cell.
input_interactors = input_line.rstrip().split(';')
custom = not (len(input_interactors) == 2 and input_interactors[1].isnumeric())

if custom:
    num_units = 1
    data = [obtain_options(l) for l in input_interactors]
else:
    num_units = int(input_interactors[1])
    data = [obtain_options(input_interactors[0])]
    

# Turn the parsed options into the objects to predict. The fifth argument is
# the negation of the `use_templates` form field.
interactors = create_interactors_colab(
    data,
    monomers,
    remove_msa,
    remove_template_msa,
    not use_templates,
    shuffle_templates,
)

# A homo-oligomer request that produced at most one interactor is expanded by
# repeating the interactor list `num_units` times.
replicate_unit = len(interactors) <= 1 and not custom and num_units > 1
if replicate_unit:
    interactors = interactors * num_units

if len(interactors) > 1 or replicate_unit:
    multimer = MultimericObject(interactors, pair_msa)
else:
    # a single unit is predicted as is
    multimer = interactors[0]

# Tag the target with the options chosen for this run.
multimer.description += run_description

if show_msa:
    plot_msa_v2(multimer.feature_dict)
    plt.show()
    plt.close()

from litaf.utils import create_colabfold_runners
from colabfold.download import download_alphafold_params

# Choose the weight set and recycle count that match the prediction type.
if isinstance(multimer, MultimericObject):
    n = 5
    model_suffix = '_multimer_v3'
    num_recycles = num_recycles_multi
else:
    n = 2 if only_template_models else 5
    model_suffix = '_ptm'
    num_recycles = num_recycles_mono

# Download the parameters only when one of the five files is missing. The
# downloader is given Path(".") and the runners "." as data directory, so the
# files are looked up under ./params; the check used to look under the
# filesystem root ("/params"), which is not where this notebook puts them.
# NOTE(review): assumes the downloader stores files in "<data dir>/params" —
# confirm against colabfold.download.
params_present = all(
    os.path.isfile(os.path.join("params", f"params_model_{i}{model_suffix}.npz"))
    for i in range(1, 6)
)
if not params_present:
    download_alphafold_params(f'alphafold2{model_suffix}', Path("."))

model_runners = create_colabfold_runners(
                    model_suffix,
                    n,
                    use_templates,
                    num_recycles,
                    ".",
                    max_seq,
                    max_extra_seq,
                    num_predictions_per_model,
                    dropout,
                    disable_cluster_profile,
                    save_all)

import jax

# Report which kind of device JAX will run on.
try:
    # check if TPU is available
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
except Exception:
    # TPU setup is best effort: any failure means "no TPU", so fall back to
    # the first device JAX reports.
    if jax.local_devices()[0].platform == 'cpu':
        logging.info("WARNING: no GPU detected, will be using CPU")
    else:
        # Kept from the original cell; `tf` itself is not used in this block.
        import tensorflow as tf
        logging.info('Running on GPU')
else:
    # Logged outside the `try` through the `logging` module. The message used
    # to go through an undefined `logger`, whose NameError was swallowed by a
    # bare `except` and sent a working TPU setup down the fallback path.
    logging.info('Running on TPU')

In [None]:
# @title Run predictions { display-mode: "form" }
from litaf.predict_structure import predict, ModelsToRelax

# Results of this target go into "<project folder>/<target description>".
output_path = os.path.join(output_dir, multimer.description)
Path(output_path).mkdir(parents=True, exist_ok=True)

# A single unit has no `input_seqs` set here; give it a one-element list
# holding its own sequence so `predict` receives a list of sequences.
is_single_unit = not isinstance(multimer, MultimericObject)
if is_single_unit:
    multimer.input_seqs = [multimer.sequence]

# NOTE(review): relaxation is always requested as ModelsToRelax.NONE; the
# `model_to_relax` form field is not consulted in this cell — confirm.
predict(
    model_runners,
    output_path,
    multimer.feature_dict,
    0,
    False,
    fasta_name=multimer.description,
    models_to_relax=ModelsToRelax.NONE,
    seqs=multimer.input_seqs,
    allow_resume=True,
)

In [None]:
# @title Display 3D structure { display-mode: "form" }
import py3Dmol
import glob
import matplotlib.pyplot as plt
from colabfold.colabfold import plot_plddt_legend
from colabfold.colabfold import pymol_color_list, alphabet_list
# Form fields controlling which model is shown and how it is rendered.
rank_num = 1 #@param ["1", "2", "3", "4", "5"] {type:"raw"}
color = "lDDT" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

# The form counts ranks from 1, the file names are built with rank - 1.
tag = rank_num - 1
pdb_filename = f"{output_dir}/{multimer.description}/ranked_{tag}.pdb"
# glob returns a (possibly empty) list of matching paths; show_pdb reads the
# first element.
pdb_file = glob.glob(pdb_filename)

# True when the prediction target is a MultimericObject.
is_complex = isinstance(multimer, MultimericObject)

#The function show_pdb is copied from ColabFold
def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color="lDDT"):
    """Build a py3Dmol view of the predicted structure.

    The structure is read from the first path in the module-level ``pdb_file``
    list; ``rank_num`` is accepted for interface compatibility but does not
    select the file.

    Args:
        rank_num: Rank of the model (not used to pick the file, see above).
        show_sidechains: Add stick/sphere styles for the side chains.
        show_mainchains: Add stick styles for the backbone atoms.
        color: "lDDT" (colour by B-factor column), "rainbow" or "chain".

    Returns:
        The configured ``py3Dmol`` view.

    Raises:
        FileNotFoundError: If ``pdb_file`` is empty, i.e. no structure file
            was found.
    """
    if not pdb_file:
        # clearer than the IndexError raised by indexing an empty list
        raise FileNotFoundError(f"No structure file found: {pdb_filename}")

    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
    # `with` closes the file once its content has been read.
    with open(pdb_file[0], 'r') as pdb_handle:
        view.addModel(pdb_handle.read(), 'pdb')

    if color == "lDDT":
        view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
    elif color == "rainbow":
        view.setStyle({'cartoon': {'color':'spectrum'}})
    elif color == "chain":
        chains = len(multimer.input_seqs)+1 if is_complex else 1
        # Separate loop name so the `color` parameter is not overwritten.
        for _, chain, chain_color in zip(range(chains), alphabet_list, pymol_color_list):
            view.setStyle({'chain':chain},{'cartoon': {'color':chain_color}})

    if show_sidechains:
        BB = ['C','O','N']
        view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                      {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                      {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
        view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                      {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    if show_mainchains:
        BB = ['C','O','N','CA']
        view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

    view.zoomTo()
    return view

# Render the selected model, followed by the colour legend when the structure
# is coloured by lDDT.
structure_view = show_pdb(rank_num, show_sidechains, show_mainchains, color)
structure_view.show()
if color == "lDDT":
    plddt_legend = plot_plddt_legend()
    plddt_legend.show()

In [None]:
# @title Plot plDDT & PAE { display-mode: "form" }
save_plots = True #@param {type:"boolean"}
from colabfold.colabfold import plot_plddts, plot_paes
import json
import pickle
from litaf.objects import MonomericObject, MultimericObject

# Folder holding the ranking file and the per-model result pickles.
results_dir = f"{output_dir}/{multimer.description}"

with open(f"{results_dir}/ranking_debug.json", 'r') as f:
    data = json.load(f)

# Collect the per-model pLDDT and PAE arrays following the order listed in the
# ranking file.
plddts, paes = [], []
for file in data['order']:
    with open(f'{results_dir}/result_{file}.pkl', 'rb') as pkl_f:
        model_data = pickle.load(pkl_f)
    plddts.append(model_data['plddt'])
    paes.append(model_data['predicted_aligned_error'])

# Sequence lengths handed to the plotting helpers.
# NOTE(review): for a MultimericObject this iterates `multimer.sequence`,
# while other cells take the per-chain sequences from `multimer.input_seqs` —
# confirm that `sequence` is a list of chains. `len_seqs` stays undefined when
# `multimer` is neither of the two types.
if isinstance(multimer, MonomericObject):
    len_seqs = [len(multimer.sequence)]
elif isinstance(multimer, MultimericObject):
    len_seqs = [len(seq) for seq in multimer.sequence]

plddts_plot = plot_plddts(plddts, len_seqs)
pae_plot = plot_paes(paes, len_seqs)

if save_plots:
    plddts_plot.savefig(f"{results_dir}/plddts_plot.png")
    pae_plot.savefig(f"{results_dir}/pae_plots.png")

In [None]:
# @title Package and Download results { display-mode: "form" }
# Bundle the whole project folder into "<job_name>.result.zip" in the working
# directory and offer it for download.
import shutil

results_zip = f"{job_name}.result.zip"
# Built with shutil instead of a shell command so that project names with
# spaces or shell metacharacters are archived correctly. make_archive appends
# ".zip" to the base name, giving exactly `results_zip`.
shutil.make_archive(f"{job_name}.result", "zip", root_dir=".", base_dir=output_dir)
files.download(results_zip)