# **SaprotHub: Making Protein Modeling Accessible to All Biologists**

This is the Colab version of [SaProt](https://github.com/westlake-repl/SaProt), a pre-trained protein language model designed for various downstream protein tasks.

**ColabSaprot** is a platform that makes **Protein Language Models (PLMs)** more accessible and user-friendly for biologists, enabling effortless model training and sharing within the scientific community.

We've established the [SaprotHub](https://huggingface.co/SaProtHub) for storing and sharing models and datasets, where you can explore extensive collections for specific protein prediction tasks.

We hope ColabSaprot and SaprotHub can contribute to advancing biological research, fostering collaboration, and accelerating discoveries in the field. You can access [our paper](https://www.biorxiv.org/content/10.1101/2024.05.24.595648v2) for further details.

For detailed steps of each section, please refer to the <a href="#manual">manual</a>.

Check this [video](https://www.youtube.com/watch?v=r42z1hvYKfw) to see how to train your model using ColabSaprot.









## SaprotHub

Find awesome models and datasets for specific protein tasks on [SaprotHub](https://huggingface.co/SaprotHub)!

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/SaProtHub.png?raw=true" height="500" width="800px" align="center">

## ColabSaprot Content

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/content.png?raw=true" height="400" width="600px" align="center">

<font color=red>**To view the content, please click on the first option in the left sidebar.**</font>



# 0: Instruction

Before you begin training and utilizing your model, here are some important details about **Task**, **Dataset** and **Basic Colab knowledge** that you need to be aware of.

<br>





## 0.1: Task

Different models are designed for different tasks, so it's essential to understand **which type your task belongs to**.

You can recognize your task type based on your task description and objectives.

<br>

<!-- ### Task Type

- **Classification Task**: classify protein sequences.
- **Regression Task**: predict the value of some property of a protein sequence.
- **Amino Acid Classification Task**: classify the amino acids in a protein sequence.  -->


| Task Description                                                                        | Task Type                          | Use SaProt to predict                         |
| --------------------------------------------------------------------------------------- | ---------------------------------- | --------------------------------------------- |
| Classify protein sequences.                                                             | **Classification Task**            | <a href="#classification_regression">here</a> |
| Predict the value of some property of a protein sequence.                               | **Regression Task**                | <a href="#classification_regression">here</a> |
| Classify the amino acids in a protein sequence.                                         | **Amino Acid Classification Task** | <a href="#classification_regression">here</a> |
| Predict the mutational effect based on the wild type sequence and mutation information. | **Mutational Effect Prediction**   | <a href="#mutational_effect">here</a>         |
| Predict the residue sequence given the structure backbone.                              | **Inverse Folding Prediction**     | <a href="#inverse_folding">here</a>           |
| Predict whether the two proteins interact.                                              | **Pair Classification Task**       | <a href="#classification_regression">here</a> |
| Predict the strength of the interaction between two proteins.                           | **Pair Regression Task**           | <a href="#classification_regression">here</a> |

<br>

Here are some example tasks and their task type:

| Task Type | Example |
| --- | --- |
| **Classification Task** | **Subcellular Location Prediction**: predict which location category the protein belongs to. |
| **Classification Task** | **Metal Ion Binding Detection**: predict whether there are metal ion–binding sites in the protein. |
| **Regression Task** | **Thermostability Prediction**: predict the thermostability value of a protein. |
| **Amino Acid Classification Task** | **Binding Site Detection**: predict whether the amino acid is a binding site or not. |

<!-- <br>

### Use your models or shared models on SaProtHub

You can use

- your trained model
- or shared models on SaProtHub
- pre-trained protein language model

to make some prediction -->


<br>

To view the full list of tasks supported by ColabSaprot, please refer to [task_list.md](https://github.com/westlake-repl/SaProtHub/blob/main/task_list.md).
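
The difference between the trainable task types is easiest to see from what goes in and what comes out. The following is a toy sketch only: the keys match the internal task names used later in this notebook (`classification`, `regression`, `token_classification`, `pair_classification`, `pair_regression`), but the sequences and label values are made up, `TOY_EXAMPLES` and `describe` are hypothetical names, and this is not the dataset format ColabSaprot expects. Mutational effect prediction and inverse folding are not covered here.

```python
# Toy illustration only: sequences and labels are invented for this sketch.
TOY_EXAMPLES = {
    "classification":       {"inputs": ["MEVQL"],          "label": 3},
    "regression":           {"inputs": ["MEVQL"],          "label": 52.7},
    "token_classification": {"inputs": ["MEVQL"],          "label": [0, 0, 1, 1, 0]},
    "pair_classification":  {"inputs": ["MEVQL", "MKTAY"], "label": 1},
    "pair_regression":      {"inputs": ["MEVQL", "MKTAY"], "label": 0.81},
}


def describe(task_type):
    """Summarize how many proteins a task takes and how its labels are shaped."""
    example = TOY_EXAMPLES[task_type]
    n_inputs = len(example["inputs"])
    # A list-valued label means one label per residue (amino acid classification).
    per_residue = isinstance(example["label"], list)
    granularity = "one label per residue" if per_residue else "one label per example"
    return f"{task_type}: {n_inputs} protein(s) in, {granularity}"


for task in TOY_EXAMPLES:
    print(describe(task))
```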

## 0.2: Dataset <a name="data_format"></a>

You can use your private data to train and predict. Below are the various data formats corresponding to different **data types**.

<br>


### SA (Structure-aware) Sequence

We combine the residue and structure tokens at each residue site to create a **Structure-aware sequence** (SA sequence), merging both residue and structural information.

The structure tokens are generated by encoding the 3D structure of proteins using Foldseek.

<a href="#get_sa">Here</a> you can **convert your data into SA Sequence** format.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/SA_Sequence.png?raw=true" height="300" width="600px" align="center">
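
As a minimal sketch of how the two token streams interleave, assume the structure tokens are already available as a string with one token per residue (in ColabSaprot they come from Foldseek). The function name `build_sa_sequence` is hypothetical. The `#` token stands for a missing or masked structure token, which is how this notebook represents AA-only input.

```python
def build_sa_sequence(aa_seq, struc_seq=None):
    """Interleave residue tokens with structure tokens, one pair per residue site."""
    if struc_seq is None:
        # No structure available: pair every residue with the mask token '#'.
        struc_seq = "#" * len(aa_seq)
    if len(aa_seq) != len(struc_seq):
        raise ValueError("aa_seq and struc_seq must have the same length")
    # Residues are written in upper case, structure tokens in lower case.
    return "".join(aa.upper() + st.lower() for aa, st in zip(aa_seq, struc_seq))


print(build_sa_sequence("MEV", "dvq"))  # MdEvVq
print(build_sa_sequence("MEV"))         # M#E#V#
```

An SA sequence is therefore always twice as long as the underlying amino acid sequence.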


<br>
<br>

### Data Type

1. Single AA Sequence
2. Single SA Sequence
3. Single UniProt ID
4. Single PDB/CIF Structure
5. Multiple AA Sequences
6. Multiple SA Sequences
7. Multiple UniProt IDs
8. Multiple PDB/CIF Structures
9. SaprotHub Dataset

For tasks that require **two protein sequences as input** (pair classification and pair regression):

10. A pair of AA Sequences
11. A pair of SA Sequences
12. A pair of UniProt IDs
13. A pair of PDB/CIF Structures
14. Multiple pairs of AA Sequences
15. Multiple pairs of SA Sequences
16. Multiple pairs of UniProt IDs
17. Multiple pairs of PDB/CIF Structures

<br>

### Data Format <a name='data_format'></a>

#### For `Single AA Sequence`, `Single SA Sequence`, and `Single UniProt ID` (first three data types)
An input box will appear after running the cell. Please enter the protein sequence in the required format.
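
Before pasting a sequence into the input box, it can help to check it locally. The sketch below is an assumption about what a reasonable check looks like, not something the notebook is known to enforce: it uses the 20 standard amino acids and the Foldseek structure vocabulary that the installation cell defines (`aa_set` and `foldseek_struc_vocab`), and the function names are hypothetical.

```python
# Same alphabets as aa_set and foldseek_struc_vocab in the installation cell.
AA_SET = set("ACDEFGHIKLMNPQRSTVWY")
STRUC_VOCAB = set("pynwrqhgdlvtmfsaeikc#")


def is_valid_aa_sequence(seq):
    """True if seq is non-empty and contains only the 20 standard amino acids."""
    return len(seq) > 0 and all(ch in AA_SET for ch in seq)


def is_valid_sa_sequence(seq):
    """True if seq alternates residue tokens with structure tokens."""
    if len(seq) == 0 or len(seq) % 2 != 0:
        return False
    residues, struc_tokens = seq[0::2], seq[1::2]
    return all(r in AA_SET for r in residues) and all(s in STRUC_VOCAB for s in struc_tokens)


print(is_valid_aa_sequence("MEVQ"))    # True
print(is_valid_sa_sequence("MdEvVq"))  # True
print(is_valid_sa_sequence("MdE"))     # False
```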

<br>

####  For `Single PDB/CIF Structure` (fourth data type)
A file upload button will appear after running the cell. Please upload a .pdb or .cif file.



<br>

#### For `Multiple AA Sequences`, `Multiple SA Sequences`, `Multiple UniProt IDs` (fifth to seventh data types)
A file upload button will appear after running the cell. Please upload a .csv file and ensure that the column name in the .csv file is `Sequence`.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_Sequences_data_format.png?raw=true" height="200" width="800px" align="center">
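
A file in this shape can be produced with Python's standard `csv` module. This is a minimal sketch: the sequences are placeholders, the helper name `write_sequence_csv` and the file name are hypothetical, and any additional columns your task needs (for example labels) are not shown.

```python
import csv
import tempfile
from pathlib import Path


def write_sequence_csv(sequences, csv_path):
    """Write one entry per row under a single `Sequence` header."""
    with open(csv_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["Sequence"])
        for seq in sequences:
            writer.writerow([seq])
    return csv_path


# Placeholder sequences for illustration only.
demo_path = Path(tempfile.mkdtemp()) / "my_sequences.csv"
write_sequence_csv(["MEVQLVESGG", "MKTAYIAKQR"], demo_path)
print(demo_path.read_text())
```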

<br>
<br>

#### For `Multiple PDB/CIF Structures`
A file upload button will appear after running the cell. Please upload a .csv file containing three columns: `Sequence`, `type` and `chain`:

- `type`: Indicate whether the structure file is a real PDB structure or an AlphaFold 2 predicted structure. For AF2 (AlphaFold 2) structures, we will apply pLDDT masking. The value must be either "PDB" or "AF2".
- `chain`: For real PDB structures, since multiple chains may exist in one .pdb file, it is necessary to specify which chain is used. For AF2 structures, the chain is assumed to be A by default.

After successfully uploading the .csv file, a second file upload button will appear. Please upload a zip file containing all corresponding pdb/cif files.
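
The two uploads can be prepared together. The sketch below assumes that the `Sequence` column holds the structure file name (the figure that follows shows the expected layout, so check it against your own file). The helper name `write_structure_inputs`, the output file names, and the empty placeholder files are all hypothetical.

```python
import csv
import tempfile
import zipfile
from pathlib import Path

ALLOWED_TYPES = {"PDB", "AF2"}


def write_structure_inputs(entries, structure_dir, out_dir):
    """Write the .csv and the .zip for a list of (file_name, type, chain) tuples."""
    out_dir = Path(out_dir)
    csv_path = out_dir / "structures.csv"
    zip_path = out_dir / "structures.zip"

    with open(csv_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["Sequence", "type", "chain"])
        for file_name, struc_type, chain in entries:
            if struc_type not in ALLOWED_TYPES:
                raise ValueError(f"type must be 'PDB' or 'AF2', got {struc_type!r}")
            writer.writerow([file_name, struc_type, chain])

    # Store every structure file at the top level of the archive.
    with zipfile.ZipFile(zip_path, "w") as archive:
        for file_name, _, _ in entries:
            archive.write(Path(structure_dir) / file_name, arcname=file_name)

    return csv_path, zip_path


# Demo with empty placeholder files standing in for real structures.
work_dir = Path(tempfile.mkdtemp())
for name in ("example_a.pdb", "example_b.cif"):
    (work_dir / name).write_text("")

csv_path, zip_path = write_structure_inputs(
    [("example_a.pdb", "PDB", "A"), ("example_b.cif", "AF2", "A")],
    structure_dir=work_dir,
    out_dir=work_dir,
)
print(csv_path.read_text())
```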


<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_PDB_CIF_Structures_data_format.png?raw=true" height="200" width="500px" align="center">
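
The pLDDT masking applied to AF2 structures can be pictured as replacing the structure token of every low-confidence residue with `#`. The sketch below is illustrative only: the cutoff of 70, the per-residue pLDDT list, and the function name `mask_low_plddt` are assumptions rather than values taken from this notebook.

```python
def mask_low_plddt(sa_seq, plddts, threshold=70.0):
    """Replace the structure token with '#' wherever pLDDT is below the threshold."""
    if len(sa_seq) != 2 * len(plddts):
        raise ValueError("expected one pLDDT value per residue")
    masked = []
    for i, plddt in enumerate(plddts):
        residue, struc_token = sa_seq[2 * i], sa_seq[2 * i + 1]
        masked.append(residue + (struc_token if plddt >= threshold else "#"))
    return "".join(masked)


# The second residue falls below the assumed cutoff, so its structure token is masked.
print(mask_low_plddt("MdEvVq", [91.2, 55.0, 78.4]))  # MdE#Vq
```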

<br>
<br>

#### For `SaprotHub Dataset`
An input box will appear after running the cell. Please enter the ID of the SaprotHub Dataset. You can find datasets in the [Official SaProtHub Repository](https://huggingface.co/SaProtHub).

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Huggingface_ID.png?raw=true" height="200" width="700px" align="center">
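
A dataset ID has the form `owner/name`. As a rough sketch of what happens with such an ID, the snippet below checks the format and then fetches the repository with `snapshot_download` from `huggingface_hub`, which this notebook imports. Whether ColabSaprot calls it with exactly these arguments is an assumption. The ID `SaProtHub/Example-Dataset` is a made-up placeholder, and both function names are hypothetical.

```python
def split_repo_id(repo_id):
    """Split 'owner/name' into its two parts, rejecting anything else."""
    parts = repo_id.strip().split("/")
    if len(parts) != 2 or not all(parts):
        raise ValueError(f"expected an ID of the form 'owner/name', got {repo_id!r}")
    return parts[0], parts[1]


def fetch_dataset(repo_id, target_dir):
    """Download a dataset repository from the Hugging Face Hub (needs network access)."""
    # Imported lazily so the format check above works without the package installed.
    from huggingface_hub import snapshot_download

    split_repo_id(repo_id)
    return snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=target_dir)


# Placeholder ID for illustration; replace it with a real SaprotHub dataset ID.
print(split_repo_id("SaProtHub/Example-Dataset"))
```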

<br>
<br>



## 0.3: Basic Colab Knowledge

<br>

### Cell Running status

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Cell_status.png?raw=true" height="300" width="500px" align="center">

# **1: Installation**


In [None]:
#@title 1.1: ⚠️ Switch your Runtime type to <font color=red>**GPU!!!**</font>

#@markdown You can check the current runtime type in <font color=red>**the upper right corner of the page**</font>. If the current runtime type is CPU, you need to <font color=red>**switch it to GPU (either the free T4 or the paid A100)**</font> for a better training experience.

#@markdown <img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Runtime.png?raw=true" height="300" width="400px" align="center">

#@markdown #### Please follow the steps below to switch the runtime to GPU:

#@markdown 1. Click the dropdown button
#@markdown 2. Select option "Change runtime type"
#@markdown 3. Select a GPU
#@markdown 4. Click "Save" button
#@markdown 5. <font color=red>Each time you switch the runtime, all code blocks need to be **re-executed**.</font>


#@markdown <img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Switch_Runtime.png?raw=true" height="400" width="800px" align="center">

In [None]:
#@title 1.2: ▶️ Click the run button to install SaProt

#@markdown (Please wait 2-8 minutes for the installation to finish...)
################################################################################
########################### install saprot #####################################
################################################################################
%load_ext autoreload
%autoreload 2

import os
# Use the current working directory as the install root, so the notebook works both locally and on Google Colab
root_dir = os.getcwd()

from google.colab import output
output.enable_custom_widget_manager()

try:
  import sys
  sys.path.append(f"{root_dir}/SaprotHub")
  import saprot
  print("SaProt is installed successfully!")
  os.system(f"chmod +x {root_dir}/SaprotHub/bin/*")

except ImportError:
  print("Installing SaProt...")
  os.system(f"rm -rf {root_dir}/SaprotHub")
  # !rm -rf /content/SaprotHub/

  !git clone https://github.com/westlake-repl/SaprotHub.git

  # !pip install /content/SaprotHub/saprot-0.4.7-py3-none-any.whl
  os.system(f"pip install -r {root_dir}/SaprotHub/requirements.txt")
  # !pip install -r /content/SaprotHub/requirements.txt

  os.system(f"pip install {root_dir}/SaprotHub")


  os.system(f"mkdir -p {root_dir}/SaprotHub/LMDB")
  os.system(f"mkdir -p {root_dir}/SaprotHub/bin")
  os.system(f"mkdir -p {root_dir}/SaprotHub/output")
  os.system(f"mkdir -p {root_dir}/SaprotHub/datasets")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/regression/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/token_classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/pair_classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/pair_regression/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/structures")
  # !mkdir -p /content/SaprotHub/LMDB
  # !mkdir -p /content/SaprotHub/bin
  # !mkdir -p /content/SaprotHub/output
  # !mkdir -p /content/SaprotHub/datasets
  # !mkdir -p /content/SaprotHub/adapters/classification/Local
  # !mkdir -p /content/SaprotHub/adapters/regression/Local
  # !mkdir -p /content/SaprotHub/adapters/token_classification/Local
  # !mkdir -p /content/SaprotHub/adapters/pair_classification/Local
  # !mkdir -p /content/SaprotHub/adapters/pair_regression/Local
  # !mkdir -p /content/SaprotHub/structures

  # !pip install gdown==v4.6.3 --force-reinstall --quiet
  # os.system(
  #   f"wget 'https://drive.usercontent.google.com/download?id=1B_9t3n_nlj8Y3Kpc_mMjtMdY0OPYa7Re&export=download&authuser=0' -O {root_dir}/SaprotHub/bin/foldseek"
  # )

  os.system(f"chmod +x {root_dir}/SaprotHub/bin/*")
  # !chmod +x /content/SaprotHub/bin/foldseek
  import sys
  sys.path.append(f"{root_dir}/SaprotHub")

  # !mv /content/SaprotHub/ColabSaprotSetup/foldseek /content/SaprotHub/bin/

################################################################################
################################################################################
################################## global ######################################
################################################################################
################################################################################

import ipywidgets
import pandas as pd
import torch
import numpy as np
import lmdb
import base64
import copy
import os
import json
import zipfile
import yaml
import argparse
import pprint
import subprocess
import py3Dmol
import matplotlib.pyplot as plt


from loguru import logger
from easydict import EasyDict
from colorama import init, Fore, Back, Style
from IPython.display import clear_output
from saprot.utils.mpr import MultipleProcessRunnerSimplifier
from huggingface_hub import snapshot_download
from ipywidgets import HTML
from IPython.display import display
from google.colab import widgets
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from google.colab import files
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from string import ascii_uppercase,ascii_lowercase

DATASET_HOME = Path(f'{root_dir}/SaprotHub/datasets')
ADAPTER_HOME = Path(f'{root_dir}/SaprotHub/adapters')
STRUCTURE_HOME = Path(f"{root_dir}/SaprotHub/structures")
LMDB_HOME = Path(f'{root_dir}/SaprotHub/LMDB')
OUTPUT_HOME = Path(f'{root_dir}/SaprotHub/output')
UPLOAD_FILE_HOME = Path(f'{root_dir}/SaprotHub/upload_files')
FOLDSEEK_PATH = Path(f"{root_dir}/SaprotHub/bin/foldseek")
aa_set = {"A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"}
foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"

data_type_list = ["Single AA Sequence",
                  "Single SA Sequence",
                  "Single UniProt ID",
                  "Single PDB/CIF Structure",
                  "Multiple AA Sequences",
                  "Multiple SA Sequences",
                  "Multiple UniProt IDs",
                  "Multiple PDB/CIF Structures",
                  "SaprotHub Dataset",
                  "A pair of AA Sequences",
                  "A pair of SA Sequences",
                  "A pair of UniProt IDs",
                  "A pair of PDB/CIF Structures",
                  "Multiple pairs of AA Sequences",
                  "Multiple pairs of SA Sequences",
                  "Multiple pairs of UniProt IDs",
                  "Multiple pairs of PDB/CIF Structures",]

task_type_dict = {
  "Classify protein sequences (classification)" : "classification",
  "Classify each Amino Acid (amino acid classification), e.g. Binding site detection" : "token_classification",
  "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein" : "regression",
  "Predict protein-protein interaction (pair classification)":"pair_classification",
  "Predict protein-protein interaction (pair regression)":"pair_regression",
}
model_type_dict = {
  "classification" : "saprot/saprot_classification_model",
  "token_classification" : "saprot/saprot_token_classification_model",
  "regression" : "saprot/saprot_regression_model",
  "pair_classification" : "saprot/saprot_pair_classification_model",
  "pair_regression" : "saprot/saprot_pair_regression_model",
}
dataset_type_dict = {
  "classification": "saprot/saprot_classification_dataset",
  "token_classification" : "saprot/saprot_token_classification_dataset",
  "regression": "saprot/saprot_regression_dataset",
  "pair_classification" : "saprot/saprot_pair_classification_dataset",
  "pair_regression" : "saprot/saprot_pair_regression_dataset",
}
training_data_type_dict = {
  "Single AA Sequence": "AA",
  "Single SA Sequence": "SA",
  "Single UniProt ID": "SA",
  "Single PDB/CIF Structure": "SA",
  "Multiple AA Sequences": "AA",
  "Multiple SA Sequences": "SA",
  "Multiple UniProt IDs": "SA",
  "Multiple PDB/CIF Structures": "SA",
  "SaprotHub Dataset": "SA",
  "A pair of AA Sequences": "AA",
  "A pair of SA Sequences": "SA",
  "A pair of UniProt IDs": "SA",
  "A pair of PDB/CIF Structures": "SA",
  "Multiple pairs of AA Sequences": "AA",
  "Multiple pairs of SA Sequences": "SA",
  "Multiple pairs of UniProt IDs": "SA",
  "Multiple pairs of PDB/CIF Structures": "SA",
}


class font:
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'

    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    RESET = '\033[0m'


################################################################################
############################### adapters #######################################
################################################################################
def get_adapters_list(task_type=None):

    adapters_list = []

    if task_type:
      for file_path in (ADAPTER_HOME / task_type).glob('**/adapter_config.json'):
        adapters_list.append(file_path.parent)
    else:
      for file_path in ADAPTER_HOME.glob('**/adapter_config.json'):
        adapters_list.append(file_path.parent)

    return adapters_list


def show_adapters_info(adapters_list):
  grid = widgets.Grid(len(adapters_list)+1, 2, header_row=True, header_column=True)

  with grid.output_to(0, 0):
    print("ID")

  with grid.output_to(0, 1):
    print("Local Model")

  # with grid.output_to(0, 2):
  #   print("Adapter Path")

  for i in range(len(adapters_list)):
    with grid.output_to(i+1, 0):
      print(i)
    with grid.output_to(i+1, 1):
      print(adapters_list[i].stem)
    # with grid.output_to(i+1, 2):
    #   print(adapters_list[i])

def adapters_text(adapters_list):
  # Named text_input to avoid shadowing the built-in input().
  text_input = ipywidgets.Text(
    value=None,
    placeholder='Enter SaprotHub Model ID',
    # description='Selected:',
    disabled=False)
  text_input.layout.width = '500px'
  display(text_input)

  return text_input

def adapters_dropdown(adapters_list):
  dropdown = ipywidgets.Dropdown(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=None,
    placeholder='Select a Local Model here',
    # description='Selected:',
    disabled=False)
  dropdown.layout.width = '500px'
  display(dropdown)

  return dropdown

def adapters_combobox(adapters_list):
  combobox = ipywidgets.Combobox(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=None,
    placeholder='Enter SaprotHub Model repository id or select a Local Model here',
    # description='Selected:',
    disabled=False)
  combobox.layout.width = '500px'
  display(combobox)

  return combobox

def select_adapter():
  adapters_list = get_adapters_list()
  print(Fore.BLUE+"Existing Models:"+Style.RESET_ALL)
  # print("="*100)
  # show_adapters_info(adapters_list)
  # print("="*100)
  return adapters_combobox(adapters_list)

def adapters_selectmultiple(adapters_list):
  select_multiple = ipywidgets.SelectMultiple(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for adapter_path in adapters_list],
    value=[],
    #rows=10,
    placeholder='Select multiple models',
    # description='Fruits',
    disabled=False,
    layout={'width': '500px'})
  display(select_multiple)

  return select_multiple

def adapters_textmultiple(adapters_list):
  textmultiple = ipywidgets.Text(
  value=None,
  placeholder='Enter multiple SaprotHub Model IDs, separated by commas.',
  # description='Fruits',
  disabled=False,
  layout={'width': '500px'})
  display(textmultiple)

  return textmultiple

# def select_adapter_from(use_model_from):
#   adapters_list = get_adapters_list()

#   if use_model_from == 'Trained by yourself on ColabSaprot':
#     print(Fore.BLUE+"Local Model:"+Style.RESET_ALL)
#     return adapters_dropdown(adapters_list)
#   elif use_model_from == 'Shared by peers on SaprotHub':
#     print(Fore.BLUE+"SaprotHub Model:"+Style.RESET_ALL)
#     return adapters_text(adapters_list)



def select_adapter_from(task_type, use_model_from):
  adapters_list = get_adapters_list(task_type)

  if use_model_from == 'Trained by yourself on ColabSaprot':
    print(Fore.BLUE+f"Local Model ({task_type}):"+Style.RESET_ALL)
    return adapters_dropdown(adapters_list)

  elif use_model_from == 'Shared by peers on SaprotHub':
    print(Fore.BLUE+"SaprotHub Model:"+Style.RESET_ALL)
    return adapters_text(adapters_list)

  elif use_model_from == "Saved in your local computer":
    print(Fore.BLUE+"Click the button to upload the \"Model-<task_name>-<model_size>.zip\" file of your Model:"+Style.RESET_ALL)
    # 1. upload model.zip
    adapter_upload_path = ADAPTER_HOME / task_type / "Local"
    adapter_zip_path = upload_file(adapter_upload_path)
    adapter_path = adapter_upload_path / adapter_zip_path.stem
    # 2. unzip model.zip
    with zipfile.ZipFile(adapter_zip_path, 'r') as zip_ref:
        zip_ref.extractall(adapter_path)
    os.remove(adapter_zip_path)
    # 3. check adapter_config.json
    adapter_config_path = adapter_path / "adapter_config.json"
    assert adapter_config_path.exists(), f"Can't find {adapter_config_path}"

    return EasyDict({"value":  f"Local/{adapter_zip_path.stem}"})

  elif use_model_from == "Multi-models on ColabSaprot":
    # 1. select the list of adapters
    print(Fore.BLUE+f"Local Model ({task_type}):"+Style.RESET_ALL)
    print(Fore.BLUE+f"Multiple values can be selected with \"shift\" and/or \"ctrl\" (or \"command\") pressed and mouse clicks or arrow keys."+Style.RESET_ALL)
    return adapters_selectmultiple(adapters_list)

  elif use_model_from == "Multi-models on SaprotHub":
    # 1. enter the list of adapters
    print(Fore.BLUE+f"SaprotHub Model IDs, separated by commas ({task_type}):"+Style.RESET_ALL)
    return adapters_textmultiple(adapters_list)



################################################################################
########################### download dataset ###################################
################################################################################
def download_dataset(task_name):
  import gdown
  import tarfile

  filepath = LMDB_HOME / f"{task_name}.tar.gz"
  download_links = {
    "ClinVar" : "https://drive.google.com/uc?id=1Le6-v8ddXa1eLJZFo7HPij7NhaBmNUbo",
    "DeepLoc_cls2" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "DeepLoc_cls10" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "EC" : "https://drive.google.com/uc?id=1VFLFA-jK1tkTZBVbMw8YSsjZqAqlVQVQ",
    "GO_BP" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_CC" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_MF" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "HumanPPI" : "https://drive.google.com/uc?id=1ahgj-IQTtv3Ib5iaiXO_ASh2hskEsvoX",
    "MetalIonBinding" : "https://drive.google.com/uc?id=1rwknPWIHrXKQoiYvgQy4Jd-efspY16x3",
    "ProteinGym" : "https://drive.google.com/uc?id=1L-ODrhfeSjDom-kQ2JNDa2nDEpS8EGfD",
    "Thermostability" : "https://drive.google.com/uc?id=1I9GR1stFDHc8W3FCsiykyrkNprDyUzSz",
  }

  try:
    gdown.download(download_links[task_name], str(filepath), quiet=False)
    with tarfile.open(filepath, 'r:gz') as tar:
      tar.extractall(path=str(LMDB_HOME))
      print(f"Extracted: {filepath}")
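  except Exception as e:
    # Chain the original exception so the real cause (network, unknown task name, ...) stays visible.
    raise RuntimeError("The dataset has not been prepared.") from e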

################################################################################
############################# upload file ######################################
################################################################################
def upload_file(upload_path):
  import shutil
  import os
  from pathlib import Path
  import sys

  upload_path = Path(upload_path)
  upload_path.mkdir(parents=True, exist_ok=True)
  basepath = Path().resolve()
  try:
    uploaded = files.upload()
    filenames = []
    for filename in uploaded.keys():
      filenames.append(filename)
      shutil.move(basepath / filename, upload_path / filename)
    if len(filenames) == 0:
      logger.info("The uploading process has been interrupted by the user.")
      raise RuntimeError("The uploading process has been interrupted by the user.")
  except Exception:
    logger.error("File upload failed! Please click the button to run again.")
    # Re-raise the original exception with its traceback intact.
    raise

  return upload_path / filenames[0]

################################################################################
############################ upload dataset ####################################
################################################################################


def input_raw_data_by_data_type(data_type):
  print(Fore.BLUE+"Dataset: "+Style.RESET_ALL, end='')

  # 0-2. 0. Single AA Sequence, 1. Single SA Sequence, 2. Single UniProt ID
  if data_type in data_type_list[:3]:
    input_seq = ipywidgets.Text(
      value=None,
      placeholder=f'Enter {data_type} here',
      disabled=False)
    input_seq.layout.width = '500px'
    print(Fore.BLUE+f"{data_type}"+Style.RESET_ALL)
    display(input_seq)
    return input_seq

  # 3. Single PDB/CIF Structure
  elif data_type == data_type_list[3]:
    print("Please provide the structure type, chain and your structure file.")

    dropdown_type = ipywidgets.Dropdown(
      value="PDB",
      options=["PDB", "AF2"],
      disabled=False)
    dropdown_type.layout.width = '500px'
    print(Fore.BLUE+"Structure type:"+Style.RESET_ALL)
    display(dropdown_type)

    input_chain = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain here',
      disabled=False)
    input_chain.layout.width = '500px'
    print(Fore.BLUE+"Chain:"+Style.RESET_ALL)
    display(input_chain)

    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path = upload_file(STRUCTURE_HOME)
    return pdb_file_path.stem, dropdown_type, input_chain

  # 4-7 & 13-16. Multiple Sequences
  elif data_type in data_type_list[4:8] or data_type in data_type_list[13:17]:
    print(Fore.BLUE+f"Please upload the .csv file which contains {data_type}"+Style.RESET_ALL)
    uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)
    print(Fore.BLUE+"Successfully uploaded your .csv file!"+Style.RESET_ALL)
    print("="*100)

    if data_type == data_type_list[7] or data_type == data_type_list[16]:
      # upload and unzip PDB files
      print(Fore.BLUE+f"Please upload your .zip file which contains {data_type} files"+Style.RESET_ALL)
      pdb_zip_path = upload_file(UPLOAD_FILE_HOME)
      if pdb_zip_path.suffix != ".zip":
        logger.error("The data type does not match. Please click the run button again to upload a .zip file!")
        raise RuntimeError("The data type does not match.")
      print(Fore.BLUE+"Successfully uploaded your .zip file!"+Style.RESET_ALL)
      print("="*100)

      import zipfile
      with zipfile.ZipFile(pdb_zip_path, 'r') as zip_ref:
        zip_ref.extractall(STRUCTURE_HOME)

    return uploaded_csv_path

  # 8. SaprotHub Dataset
  elif data_type == data_type_list[8]:
    input_repo_id = ipywidgets.Text(
      value=None,
      placeholder=f'Copy and paste the SaprotHub Dataset ID here',
      disabled=False)
    input_repo_id.layout.width = '500px'
    print(Fore.BLUE+f"{data_type}"+Style.RESET_ALL)
    display(input_repo_id)
    return input_repo_id

  # 9-11. A pair of seq
  elif data_type in ["A pair of AA Sequences", "A pair of SA Sequences", "A pair of UniProt IDs"]:
    print()

    seq_type = data_type[len("A pair of "):-1]

    input_seq1 = ipywidgets.Text(
      value=None,
      placeholder=f'Enter the {seq_type} of Sequence 1 here',
      disabled=False)
    input_seq1.layout.width = '500px'
    print(Fore.BLUE+f"Sequence 1:"+Style.RESET_ALL)
    display(input_seq1)

    input_seq2 = ipywidgets.Text(
      value=None,
      placeholder=f'Enter the {seq_type} of Sequence 2 here',
      disabled=False)
    input_seq2.layout.width = '500px'
    print(Fore.BLUE+f"Sequence 2:"+Style.RESET_ALL)
    display(input_seq2)

    return (input_seq1, input_seq2)

  # 12. Pair Single PDB/CIF Structure
  elif data_type == data_type_list[12]:
    print("Please provide the structure type, chain and your structure file.")

    dropdown_type1 = ipywidgets.Dropdown(
      value="PDB",
      options=["PDB", "AF2"],
      disabled=False)
    dropdown_type1.layout.width = '500px'
    print(Fore.BLUE+"The first structure type:"+Style.RESET_ALL)
    display(dropdown_type1)

    input_chain1 = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain of the first structure here',
      disabled=False)
    input_chain1.layout.width = '500px'
    print(Fore.BLUE+"Chain of the first structure:"+Style.RESET_ALL)
    display(input_chain1)

    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path1 = upload_file(STRUCTURE_HOME)


    dropdown_type2 = ipywidgets.Dropdown(
      value="PDB",
      options=["PDB", "AF2"],
      disabled=False)
    dropdown_type2.layout.width = '500px'
    print(Fore.BLUE+"The second structure type:"+Style.RESET_ALL)
    display(dropdown_type2)

    input_chain2 = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain of the second structure here',
      disabled=False)
    input_chain2.layout.width = '500px'
    print(Fore.BLUE+"Chain of the second structure:"+Style.RESET_ALL)
    display(input_chain2)

    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path2 = upload_file(STRUCTURE_HOME)
    return (pdb_file_path1.stem, dropdown_type1, input_chain1, pdb_file_path2.stem, dropdown_type2, input_chain2)


  # elif data_type == "Multiple pairs of PDB/CIF Structures":
  #   print(Fore.BLUE+f"Please upload the .csv file which contains {data_type}"+Style.RESET_ALL)
  #   uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)
  #   print(Fore.BLUE+"Successfully upload your .csv file!"+Style.RESET_ALL)
  #   print("="*100)

  #   if data_type == data_type_list[7]:
  #     # upload and unzip PDB files
  #     print(Fore.BLUE+f"Please upload your .zip file which contains {data_type} files"+Style.RESET_ALL)
  #     pdb_zip_path = upload_file(UPLOAD_FILE_HOME)
  #     if pdb_zip_path.suffix != ".zip":
  #       logger.error("The data type does not match. Please click the run button again to upload a .zip file!")
  #       raise RuntimeError("The data type does not match.")
  #     print(Fore.BLUE+"Successfully upload your .zip file!"+Style.RESET_ALL)
  #     print("="*100)

  #     import zipfile
  #     with zipfile.ZipFile(pdb_zip_path, 'r') as zip_ref:
  #       zip_ref.extractall(STRUCTURE_HOME)

  #   return uploaded_csv_path




def get_SA_sequence_by_data_type(data_type, raw_data):

  # 0. Single AA Sequence
  if data_type == data_type_list[0]:
    input_seq = raw_data
    aa_seq = input_seq.value

    # No structure is available, so pair every residue with the mask token '#'.
    sa_seq = ''.join(aa + '#' for aa in aa_seq)
    return sa_seq

  # 1. Single SA Sequence
  if data_type == data_type_list[1]:
    input_seq = raw_data
    sa_seq = input_seq.value

    return sa_seq

  # 2. Single UniProt ID
  if data_type == data_type_list[2]:
    input_seq = raw_data
    uniprot_id = input_seq.value


    protein_list = [(uniprot_id, "AF2", "A")]
    uniprot2pdb([protein_list[0][0]])
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
    seqs = mprs.run()
    sa_seq = seqs[0].split('\t')[1]
    return sa_seq

  # 3. Single PDB/CIF Structure
  if data_type == data_type_list[3]:
    uniprot_id = raw_data[0]
    struc_type = raw_data[1].value
    chain = raw_data[2].value

    protein_list = [(uniprot_id, struc_type, chain)]
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
    seqs = mprs.run()
    sa_seq = seqs[0].split('\t')[1]
    return sa_seq

  # Multiple sequences
  # raw_data = upload_files/xxx.csv
  if data_type in data_type_list[4:8] or data_type in data_type_list[13:17]:
    uploaded_csv_path = raw_data
    csv_dataset_path = DATASET_HOME / uploaded_csv_path.name

  # 4. Multiple AA Sequences
  if data_type == data_type_list[4]:
    protein_df = pd.read_csv(uploaded_csv_path)
    for index, value in protein_df['Sequence'].items():
      sa_seq = ''
      for aa in value:
        sa_seq += aa + '#'
      protein_df.at[index, 'Sequence'] = sa_seq

    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 5. Multiple SA Sequences
  if data_type == data_type_list[5]:
    protein_df = pd.read_csv(uploaded_csv_path)

    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 6. Multiple UniProt IDs
  if data_type == data_type_list[6]:
    protein_df = pd.read_csv(uploaded_csv_path)
    protein_list = protein_df.iloc[:, 0].tolist()
    uniprot2pdb(protein_list)
    protein_list = [(uniprot_id, "AF2", "A") for uniprot_id in protein_list]
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
    outputs = mprs.run()

    protein_df['Sequence'] = [output.split("\t")[1] for output in outputs]
    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 7. Multiple PDB/CIF Structures
  if data_type == data_type_list[7]:
    protein_df = pd.read_csv(uploaded_csv_path)
    # protein_list = [(uniprot_id, type, chain), ...]
    # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]
    # uniprot2pdb(protein_list)
    protein_list = []
    for row_tuple in protein_df.itertuples(index=False):
      assert row_tuple.type in ['PDB', 'AF2'],  "The type of structure must be either \"PDB\" or \"AF2\"!"
      protein_list.append(row_tuple)
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
    outputs = mprs.run()

    protein_df['Sequence'] = [output.split("\t")[1] for output in outputs]
    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 8. SaprotHub Dataset
  elif data_type == data_type_list[8]:
    input_repo_id = raw_data
    REPO_ID = input_repo_id.value

    if REPO_ID.startswith('/'):
      return Path(REPO_ID)

    snapshot_download(repo_id=REPO_ID, repo_type="dataset", local_dir=LMDB_HOME/REPO_ID)

    return LMDB_HOME/REPO_ID

  # 9. Pair Single AA Sequences
  elif data_type == "A pair of AA Sequences":
    input_seq_1, input_seq_2 = raw_data
    sa_seq1 = get_SA_sequence_by_data_type(data_type_list[0], input_seq_1)
    sa_seq2 = get_SA_sequence_by_data_type(data_type_list[0], input_seq_2)

    return (sa_seq1, sa_seq2)

  # 10. Pair Single SA Sequences
  elif data_type ==  "A pair of SA Sequences":
    input_seq_1, input_seq_2 = raw_data
    sa_seq1 = get_SA_sequence_by_data_type(data_type_list[1], input_seq_1)
    sa_seq2 = get_SA_sequence_by_data_type(data_type_list[1], input_seq_2)

    return (sa_seq1, sa_seq2)

  # 11. Pair Single UniProt IDs
  elif data_type ==  "A pair of UniProt IDs":
    input_seq_1, input_seq_2 = raw_data
    sa_seq1 = get_SA_sequence_by_data_type(data_type_list[2], input_seq_1)
    sa_seq2 = get_SA_sequence_by_data_type(data_type_list[2], input_seq_2)

    return (sa_seq1, sa_seq2)

  # 12. Pair Single PDB/CIF Structure
  if data_type == "A pair of PDB/CIF Structures":
    uniprot_id1 = raw_data[0]
    struc_type1 = raw_data[1].value
    chain1 = raw_data[2].value

    protein_list1 = [(uniprot_id1, struc_type1, chain1)]
    mprs1 = MultipleProcessRunnerSimplifier(protein_list1, pdb2sequence, n_process=2, return_results=True)
    seqs1 = mprs1.run()
    sa_seq1 = seqs1[0].split('\t')[1]

    uniprot_id2 = raw_data[3]
    struc_type2 = raw_data[4].value
    chain2 = raw_data[5].value

    protein_list2 = [(uniprot_id2, struc_type2, chain2)]
    mprs2 = MultipleProcessRunnerSimplifier(protein_list2, pdb2sequence, n_process=2, return_results=True)
    seqs2 = mprs2.run()
    sa_seq2 = seqs2[0].split('\t')[1]
    return sa_seq1, sa_seq2

  # # Pair raw_data = upload_files/xxx.csv
  # if data_type in data_type_list[12:16]:
  #   uploaded_csv_path = raw_data
  #   csv_dataset_path = DATASET_HOME / uploaded_csv_path.name

  # 13. Pair Multiple AA Sequences
  if data_type == "Multiple pairs of AA Sequences":
    protein_df = pd.read_csv(uploaded_csv_path)
    for index, value in protein_df['seq_1'].items():
      sa_seq1 = ''
      for aa in value:
        sa_seq1 += aa + '#'
      protein_df.at[index, 'seq_1'] = sa_seq1

    protein_df['name_1'] = 'name_1'
    protein_df['chain_1'] = 'A'

    for index, value in protein_df['seq_2'].items():
      sa_seq2 = ''
      for aa in value:
        sa_seq2 += aa + '#'
      protein_df.at[index, 'seq_2'] = sa_seq2

    protein_df['name_2'] = 'name_2'
    protein_df['chain_2'] = 'A'

    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 14. Pair Multiple SA Sequences
  if data_type == "Multiple pairs of SA Sequences":
    protein_df = pd.read_csv(uploaded_csv_path)

    protein_df['name_1'] = 'name_1'
    protein_df['chain_1'] = 'A'


    protein_df['name_2'] = 'name_2'
    protein_df['chain_2'] = 'A'


    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 15. Pair Multiple UniProt IDs
  if data_type == "Multiple pairs of UniProt IDs":
    protein_df = pd.read_csv(uploaded_csv_path)
    protein_list1 = protein_df.loc[:, "seq_1"].tolist()
    uniprot2pdb(protein_list1)
    protein_df['name_1'] = protein_list1
    protein_list1 = [(uniprot_id, "AF2", "A") for uniprot_id in protein_list1]
    mprs1 = MultipleProcessRunnerSimplifier(protein_list1, pdb2sequence, n_process=2, return_results=True)
    outputs1 = mprs1.run()

    protein_df['seq_1'] = [output.split("\t")[1] for output in outputs1]
    protein_df['chain_1'] = 'A'

    protein_list2 = protein_df.loc[:, "seq_2"].tolist()
    uniprot2pdb(protein_list2)
    protein_df['name_2'] = protein_list2
    protein_list2 = [(uniprot_id, "AF2", "A") for uniprot_id in protein_list2]
    mprs2 = MultipleProcessRunnerSimplifier(protein_list2, pdb2sequence, n_process=2, return_results=True)
    outputs2 = mprs2.run()

    protein_df['seq_2'] = [output.split("\t")[1] for output in outputs2]
    protein_df['chain_2'] = 'A'

    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path


  # # 13-16. Pair Multiple Sequences
  # elif data_type in data_type_list[12:16]:
  #   print(Fore.BLUE+f"Please upload the .csv file which contains {data_type}"+Style.RESET_ALL)
  #   uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)
  #   print(Fore.BLUE+"Successfully upload your .csv file!"+Style.RESET_ALL)
  #   print("="*100)

  elif data_type ==  "Multiple pairs of PDB/CIF Structures":
    protein_df = pd.read_csv(uploaded_csv_path)
    # columns: seq_1, seq_2, type_1, type_2, chain_1, chain_2, label, stage

    # protein_list = [(uniprot_id, type, chain), ...]
    # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]
    # uniprot2pdb(protein_list)

    for i in range(1, 3):
      protein_list = []
      for index, row in protein_df.iterrows():
        assert row[f"type_{i}"] in ['PDB', 'AF2'],  "The type of structure must be either \"PDB\" or \"AF2\"!"
        row_tuple = (row[f"seq_{i}"], row[f"type_{i}"], row[f"chain_{i}"])
        protein_list.append(row_tuple)
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      outputs = mprs.run()

      # add name column, del type column
      protein_df[f'name_{i}'] = protein_df[f'seq_{i}'].apply(lambda x: x.split('.')[0])
      protein_df.drop(f"type_{i}", axis=1, inplace=True)
      print(outputs)
      protein_df[f'seq_{i}'] = [output.split("\t")[1] for output in outputs]

    # columns: name_1, name_2, chain_1, chain_2, seq_1, seq_2, label, stage
    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path


'''
  elif data_type == "A pair of AA Sequences",
  elif data_type == "A pair of SA Sequences",
  elif data_type == "A pair of UniProt IDs",
  elif data_type == "A pair of PDB/CIF Structures",
  elif data_type == "Multiple pairs of AA Sequences",
  elif data_type == "Multiple pairs of SA Sequences",
  elif data_type == "Multiple pairs of UniProt IDs",
  elif data_type == "Multiple pairs of PDB/CIF Structures",
'''


# # return a SA Sequence or a csv dataset path
# def get_raw_dataset(data_type, raw_data):
#   if data_type in data_type_list[:3]:
#     raw_dataset = get_SA_sequence_by_data_type(data_type, raw_data.value)
#   elif data_type == data_type_list[3]:
#     raw_dataset = get_SA_sequence_by_data_type(data_type, raw_data)
#   elif data_type in data_type_list[4:8]:
#     raw_dataset = get_SA_sequence_by_data_type(data_type, raw_data)
#   elif data_type in data_type_list[8]:
#     raw_dataset = get_SA_sequence_by_data_type(data_type, raw_data.value)

#   return raw_dataset

# def upload_dataset(data_type):
#   print(Fore.BLUE+f"Please upload the .csv file which contains {data_type}"+Style.RESET_ALL)
#   uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)
#   print(Fore.BLUE+"Successfully upload your .csv file!"+Style.RESET_ALL)
#   print("="*100)

#   # selected_csv_dataset = DATASET_HOME / f"[DATASET]{Path(uploaded_csv_path).stem}.csv"
#   # get_SASequence_by_data_type(data_type, uploaded_csv_path, selected_csv_dataset)
#   # get_SA_sequence_by_data_type(data_type, uploaded_csv_path)
#   # print()
#   # print("="*100)
#   # print(Fore.BLUE+"Successfully upload your dataset!"+Style.RESET_ALL)

#   return uploaded_csv_path



################################################################################
########################## Download predicted structures #######################
################################################################################
def uniprot2pdb(uniprot_ids, nprocess=20):
  from saprot.utils.downloader import AlphaDBDownloader

  os.makedirs(STRUCTURE_HOME, exist_ok=True)
  af2_downloader = AlphaDBDownloader(uniprot_ids, "pdb", save_dir=STRUCTURE_HOME, n_process=nprocess)
  af2_downloader.run()



################################################################################
############### Form foldseek sequences by multiple processes ##################
################################################################################
# def pdb2sequence(process_id, idx, uniprot_id, writer):
#   from saprot.utils.foldseek_util import get_struc_seq

#   try:
#     pdb_path = f"{STRUCTURE_HOME}/{uniprot_id}.pdb"
#     cif_path = f"{STRUCTURE_HOME}/{uniprot_id}.cif"
#     if Path(pdb_path).exists():
#       seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, ["A"], process_id=process_id)["A"][-1]
#     if Path(cif_path).exists():
#       seq = get_struc_seq(FOLDSEEK_PATH, cif_path, ["A"], process_id=process_id)["A"][-1]

#     writer.write(f"{uniprot_id}\t{seq}\n")
#   except Exception as e:
#     print(f"Error: {uniprot_id}, {e}")

# clear_output(wait=True)
# print("Installation finished!")

def pdb2sequence(process_id, idx, row_tuple, writer):

  # print("="*100)
  # print(row_tuple)
  # print("="*100)
  uniprot_id = row_tuple[0].split('.')[0]     # ID or file name without its extension
  struc_type = row_tuple[1]                   # PDB or AF2
  chain = row_tuple[2]

  if struc_type=="AF2":
    plddt_mask= True
    chain = 'A'
  else:
    plddt_mask= False

  from saprot.utils.foldseek_util import get_struc_seq

  try:
    pdb_path = f"{STRUCTURE_HOME}/{uniprot_id}.pdb"
    cif_path = f"{STRUCTURE_HOME}/{uniprot_id}.cif"
    if Path(pdb_path).exists():
      seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, [chain], process_id=process_id, plddt_mask=plddt_mask)[chain][-1]
    elif Path(cif_path).exists():
      seq = get_struc_seq(FOLDSEEK_PATH, cif_path, [chain], process_id=process_id, plddt_mask=plddt_mask)[chain][-1]
    else:
      raise BaseException(f"The {uniprot_id}.pdb/{uniprot_id}.cif file doesn't exist!")
    writer.write(f"{uniprot_id}\t{seq}\n")

  except Exception as e:
    print(f"Error: {uniprot_id}, {e}")


pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
          "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
          "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
          "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
          "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]

alphabet_list = list(ascii_uppercase+ascii_lowercase)


def convert_outputs_to_pdb(outputs):
	final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
	outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
	final_atom_positions = final_atom_positions.cpu().numpy()
	final_atom_mask = outputs["atom37_atom_exists"]
	pdbs = []
	outputs["plddt"] *= 100

	for i in range(outputs["aatype"].shape[0]):
		aa = outputs["aatype"][i]
		pred_pos = final_atom_positions[i]
		mask = final_atom_mask[i]
		resid = outputs["residue_index"][i] + 1
		pred = OFProtein(
		    aatype=aa,
		    atom_positions=pred_pos,
		    atom_mask=mask,
		    residue_index=resid,
		    b_factors=outputs["plddt"][i],
		    chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
		)
		pdbs.append(to_pdb(pred))
	return pdbs


# This function is copied from ColabFold!
def show_pdb(path, show_sidechains=False, show_mainchains=False, color="lDDT"):
  file_type = str(path).split(".")[-1]
  if file_type == "cif":
    file_type = "mmcif"

  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  view.addModel(open(path,'r').read(),file_type)

  if color == "lDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    chains = 1
    for n,chain,color in zip(range(chains),alphabet_list,pymol_color_list):
       view.setStyle({'chain':chain},{'cartoon': {'color':color}})

  if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view


def plot_plddt_legend(dpi=100):
  thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']
  plt.figure(figsize=(1,0.1),dpi=dpi)
  ########################################
  for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]:
    plt.bar(0, 0, color=c)
  plt.legend(thresh, frameon=False,
             loc='center', ncol=6,
             handletextpad=1,
             columnspacing=1,
             markerscale=0.5,)
  plt.axis(False)
  return plt


################################################################################
###############   Download file to local computer   ##################
################################################################################
def file_download(path: str):
  with open(path, "rb") as r:
    res = r.read()

  #FILE
  filename = os.path.basename(path)
  b64 = base64.b64encode(res)
  payload = b64.decode()

  #BUTTONS
  html_buttons = '''<html>
  <head>
  <meta name="viewport" content="width=device-width, initial-scale=1">
  </head>
  <body>
  <a download="{filename}" href="data:text/csv;base64,{payload}">
  <button class="p-Widget jupyter-widgets jupyter-button widget-button mod-warning">Download File</button>
  </a>
  </body>
  </html>
  '''

  html_button = html_buttons.format(payload=payload,filename=filename)
  display(HTML(html_button))


clear_output(wait=True)
print("Installation finished!")

# **2: Train and Share your Protein Model**

## Training Dataset

For the training dataset, **two additional columns** are required in the CSV file: `label` and `stage`.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_AA_Sequences_data_format_training.png?raw=true" height="200" width="400px" align="center">

### Column `label`

The content of column `label` depends on your **task type**:

| Task Type                         | Content in the Column                          |
|-----------------------------------|------------------------------------------------|
| Classification tasks              | Category index starting from zero              |
| Amino Acid Classification tasks   | A list of category indices for each amino acid |
| Regression tasks                  | Numerical values                               |

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/label_format.png?raw=true" height="300" width="800px" align="center">
<br>
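
As a quick sanity check on classification labels, the sketch below verifies that the labels are integer category indices starting from zero and reports how many categories there are. The helper name `check_class_labels` is hypothetical (it is not part of ColabSaprot), and the check assumes every category appears at least once in your data.

```python
def check_class_labels(labels):
    """Return the number of categories if labels are exactly the integers 0..K-1."""
    observed = sorted(set(labels))
    # bool is a subclass of int in Python, so exclude it explicitly
    if not all(isinstance(x, int) and not isinstance(x, bool) for x in observed):
        raise ValueError("classification labels must be integers")
    expected = list(range(len(observed)))
    if observed != expected:
        raise ValueError(f"labels should cover 0..{len(observed) - 1}, got {observed}")
    return len(observed)


# Toy labels for a three-category task (illustrative values only)
labels = [0, 2, 1, 1, 0]
num_categories = check_class_labels(labels)
print(num_categories)
```

The returned count is the value you would enter as the number of categories in the Task Config cell.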


### Column `stage`

The column `stage` indicates whether the sample is used for training, validation, or testing. Ensure your dataset includes samples for all three stages. The values are: `train`, `valid`, `test`.
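
To confirm that a training CSV covers all three stages before uploading, a minimal sketch like the following may help. It uses only the standard library; the helper `count_stages`, the toy sequences, and the toy label values are all made up for illustration, while the column names `Sequence`, `label` and `stage` follow the format described above.

```python
import csv
import io
from collections import Counter

REQUIRED_STAGES = {"train", "valid", "test"}


def count_stages(csv_text):
    """Count rows per stage; raise if a stage is missing or misspelled."""
    reader = csv.DictReader(io.StringIO(csv_text))
    counts = Counter(row["stage"] for row in reader)
    unknown = set(counts) - REQUIRED_STAGES
    missing = REQUIRED_STAGES - set(counts)
    if unknown:
        raise ValueError(f"unexpected stage values: {sorted(unknown)}")
    if missing:
        raise ValueError(f"no samples for stages: {sorted(missing)}")
    return counts


# Toy regression dataset (sequences and labels are placeholders)
example_csv = """Sequence,label,stage
MKTAYIAK,0.71,train
MEEPQSDP,0.35,train
MVLSPADK,0.52,valid
MGSSHHHH,0.18,test
"""

stage_counts = count_stages(example_csv)
print(dict(stage_counts))
```

For a file on disk, you could pass `open(path).read()` to the same function.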

<br>

### **Note:**

1. **Examples are available** at /content/SaprotHub/upload_files (if you connect to a local server, the path is /SaprotHub/upload_files). Download them to review the format, then upload them for a trial run.

2.  <a href="#get_sa">Here</a> you can **convert your data into SA Sequence** format.

3. <a href="#fa2csv">Here</a> you can **convert your .fa/.fasta file to a .csv file**, which corresponds to the data format for Multiple AA Sequences.

4. <a href="#split_dataset">Here</a> you can **randomly split your .csv dataset**, which adds a `stage` column with a `train`:`valid`:`test` ratio of 8:1:1.
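
If you prefer to split a dataset yourself, the random 8:1:1 split mentioned in the last note can be sketched as follows. This is an illustration under stated assumptions, not the notebook's own splitting code: the function `assign_stages` and its `ratios` and `seed` parameters are hypothetical names introduced here.

```python
import random


def assign_stages(n_rows, ratios=(8, 1, 1), seed=0):
    """Return n_rows stage labels in shuffled order, split roughly train:valid:test = 8:1:1."""
    total = sum(ratios)
    n_train = n_rows * ratios[0] // total
    n_valid = n_rows * ratios[1] // total
    # Give any rounding remainder to the test split so the counts sum to n_rows
    n_test = n_rows - n_train - n_valid
    stages = ["train"] * n_train + ["valid"] * n_valid + ["test"] * n_test
    # A seeded generator keeps the split reproducible
    random.Random(seed).shuffle(stages)
    return stages


stages = assign_stages(20)
print(stages.count("train"), stages.count("valid"), stages.count("test"))
```

With pandas, the result can be attached to a DataFrame via `df["stage"] = assign_stages(len(df))`. For very small datasets, check that `valid` and `test` each receive at least one sample.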

<!-- 4. The maximum input length of the model is 1024, and protein sequences exceeding this length will only retain the first 1024 amino acids. -->


In [None]:
#@title 2.1: Task Config

################################################################################
################################## TASK CONFIG #################################
################################################################################
#@markdown # 1. Task
task_name = "demo" # @param {type:"string"}
task_objective = "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein" # @param ["Classify protein sequences (classification)", "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein", "Classify each Amino Acid (amino acid classification), e.g. Binding site detection", "Predict protein-protein interaction (pair classification)", "Predict protein-protein interaction (pair regression)"]
task_type = task_type_dict[task_objective]

if task_type in ["classification", 'token_classification', 'pair_classification']:

  print(Fore.BLUE+'Enter the number of categories in your training dataset here:'+Style.RESET_ALL)
  num_of_categories = ipywidgets.BoundedIntText(
                                              # value=7,
                                              min=2,
                                              max=1000000,
                                              step=1,
                                              # description='num_of_category: \n',
                                              disabled=False)
  num_of_categories.layout.width = "100px"
  display(num_of_categories)

#@markdown <br>



################################################################################
#################################### MODEL #####################################
################################################################################
#@markdown # 2. Model

##@markdown We use Parameter-Efficient Fine-Tuning Technique for model training. It enables us to store model weights in a small **adapter** without changing the original model weights during training. After training, you can get an adapter specific to your task.
##@markdown Because parameter-efficient fine-tuning stores the trained weights in a small adapter and leaves the original model weights unchanged, you need to specify both the original model and the adapter for prediction.
##@markdown
##@markdown 1. Select a **base model** from the dropdown box `model_path` below.
##@markdown
##@markdown 2. If you want to **train on existing adapters**, check the box `continue_learning` below. By running this cell, you will see an **adapter combobox**. We provide two ways to select your adapter:
##@markdown  - Select a model **Trained by yourself on ColabSaprot** from the combobox.
##@markdown   - Enter a **Hugging Face repository name** in the combobox. (e.g. "SaProtAdapters/DeepLoc_cls10_35M")
##@markdown
##@markdown You can also find some official adapters [here](https://huggingface.co/SaProtAdapters).
base_model = "Official pretrained SaProt (35M)" # @param ["Official pretrained SaProt (35M)", "Official pretrained SaProt (650M)", "Trained by yourself on ColabSaprot", "Shared by peers on SaprotHub", "Saved in your local computer"]
# base_model = "westlake-repl/SaProt_35M_AF2" # @param ["westlake-repl/SaProt_35M_AF2", "westlake-repl/SaProt_650M_AF2", "Trained by yourself on ColabSaprot", "Shared by peers on SaprotHub"]
# print(Fore.BLUE+f"Model: {base_model}"+Style.RESET_ALL)

# continue_learning = True # @param {type:"boolean"}

# base_model
if base_model == "Official pretrained SaProt (35M)":
  base_model = "westlake-repl/SaProt_35M_AF2"
if base_model == "Official pretrained SaProt (650M)":
  base_model = "westlake-repl/SaProt_650M_AF2"

# continue learning
if base_model in ["Trained by yourself on ColabSaprot", "Shared by peers on SaprotHub", "Saved in your local computer"]:
  continue_learning = True
else:
  continue_learning = False

def upload_local_adapter(task_type):
    print(Fore.BLUE+"Click the button to upload the \"Model-<task_name>-<model_size>.zip\" file of your Model:"+Style.RESET_ALL)
    # 1. upload model.zip
    adapter_upload_path = ADAPTER_HOME / task_type / "Local"
    adapter_zip_path = upload_file(adapter_upload_path)
    adapter_path = adapter_upload_path / adapter_zip_path.stem
    # 2. unzip model.zip
    with zipfile.ZipFile(adapter_zip_path, 'r') as zip_ref:
        zip_ref.extractall(adapter_path)
    os.remove(adapter_zip_path)
    # 3. check adapter_config.json
    adapter_config_path = adapter_path / "adapter_config.json"
    assert adapter_config_path.exists(), f"Can't find {adapter_config_path}"
    adapter_combobox = {"value":  f"Local/{adapter_zip_path.stem}"}

    return adapter_combobox

if continue_learning:
  if base_model == "Trained by yourself on ColabSaprot":
    adapter_combobox = select_adapter_from(task_type, use_model_from='Trained by yourself on ColabSaprot')

  elif base_model == "Shared by peers on SaprotHub":
    adapter_combobox = select_adapter_from(task_type, use_model_from='Shared by peers on SaprotHub')

  elif base_model == 'Saved in your local computer':
    adapter_combobox = select_adapter_from(task_type, use_model_from='Saved in your local computer')

#@markdown <br>

################################################################################
################################### DATASET ####################################
################################################################################
#@markdown # 3. Dataset

data_type = "Multiple AA Sequences" # @param ["Multiple AA Sequences", "Multiple SA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures", "SaprotHub Dataset", "Multiple pairs of AA Sequences", "Multiple pairs of SA Sequences", "Multiple pairs of UniProt IDs", "Multiple pairs of PDB/CIF Structures"]
mode = "Multiple Sequences" if data_type in data_type_list[4:8] else "Single Sequence"

raw_data = input_raw_data_by_data_type(data_type)
# lmdb_dataset_path=''

# if mode == "Multiple Sequences":
#   csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)

#   from saprot.utils.construct_lmdb import construct_lmdb
#   construct_lmdb(csv_dataset_path, LMDB_HOME, task_name, task_type)
#   lmdb_dataset_path = LMDB_HOME / task_name

# Hub Dataset
if data_type == data_type_list[8]:
  def apply(button):
    global lmdb_dataset_path
    button.disabled = True
    button.description = 'Clicked'
    button.button_style = ''
    lmdb_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)

  button_apply = ipywidgets.Button(
      description='Click this button to download SaprotHub Dataset',
      disabled=False,
      button_style='success', # 'success', 'info', 'warning', 'danger' or ''
      tooltip='Apply',
      icon='check' # (FontAwesome names without the `fa-` prefix)
      )
  button_apply.on_click(apply)
  button_apply.layout.width = '500px'
  display(button_apply)
else:
  csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)
  from saprot.utils.construct_lmdb import construct_lmdb
  construct_lmdb(csv_dataset_path, LMDB_HOME, task_name, task_type)

  lmdb_dataset_path = LMDB_HOME / task_name



#@markdown <br>

################################################################################
################################################################################
################################################################################
################################################################################
################################################################################





##@markdown Complete some task configs and run this cell to Finetune SaProt on your dataset. <br>

# def get_num_of_labels(selected_csv_dataset):
#   df = pd.read_csv(selected_csv_dataset)
#   num_of_labels = len(df['label'].unique())

#   return num_of_labels


##@markdown <br>

################################################################################

################################################################################
############################### custom config ##################################
################################################################################

##@markdown ---
##@markdown # <center>Training Task Config</center>


# num_of_categories = 10 # @param {type:"number"}
# #@markdown <font face="Consolas" size=2 color='gray'>(Ignoring `num_of_categories` if predicting a value)


  # print(Fore.BLUE+'It\'s normal not to receive feedback once inputting is finished. Let\'s move on to the next step.'+Style.RESET_ALL)


################################################################################
#@title 2.3: Select Model
################################################################################

# #@markdown We use **LoRA** (a parameter-efficient fine-tuning technique), which stores the trained weights in a small adapter without changing the original model weights during training.
# #@markdown

# #@markdown After training, you can obtain an adapter for your task.

##@markdown ---
##@markdown # <center>Model</center>



# if continue_learning:
#   print(Fore.BLUE+f"Loaded Adapter: {adapter_combobox.value}"+Style.RESET_ALL)




In [None]:
#@title 2.2: Train your Model

batch_size = "Adaptive" # @param ["Adaptive", "1", "2", "4", "8", "16", "32", "64", "128", "256"]
max_epochs = 2 # @param ["10", "20", "50"] {type:"raw", allow-input: true}
learning_rate = 1.0e-3 # @param ["1.0e-3", "5.0e-4", "1.0e-4"] {type:"raw", allow-input: true}

################################################################################
############################## advanced config #################################
################################################################################

limit_train_batches=1.0
limit_val_batches=1.0
limit_test_batches=1.0

val_check_interval=0.5
seed = 20000812

# use_lora = True
num_workers = 2

mask_struc_ratio=None
# mask_struc_ratio=1.0

download_adapter_to_your_computer = True

################################################################################
################################# MARKDOWN #####################################
################################################################################

#@markdown - <font face="Consolas" size=2 color='gray'> `batch_size` depends on the number of training samples. "Adaptive" (the default) chooses a batch size automatically from your dataset size. If your training set is large enough, you can use 32, 64, 128 or 256; otherwise use 8, 4 or 2. Note that larger batch sizes are not usable on the default Colab T4 GPU, so we strongly suggest subscribing to Colab Pro for an A100 GPU.
# #@markdown |  Recommended batch size   | T4  |  A100   |
# #@markdown | ---                       | --- |  ---    |
# #@markdown | SaProt_35M_AF2            |  4  |    16   |
# #@markdown | SaProt_650M_AF2           |  -  |    8    |


#@markdown - <font face="Consolas" size=2 color='gray'>`max_epochs` is the maximum number of training epochs (full passes over the training data). A larger value needs more training time. The best model will be saved after each epoch.
#@markdown You can adjust `max_epochs` to control training duration. (Note that the maximum Colab running time is 12 hours for unsubscribed users and 24 hours for Colab Pro+ users.) <br>
#@markdown

# download_adapter_to_your_computer = True #@param {type:"boolean"}
#@markdown - <font face="Consolas" size=2 color='gray'>`learning_rate` affects the convergence speed of the model.
#@markdown Through experimentation, we have found that `5.0e-4` is a good default value for base model `Official pretrained SaProt (650M)` and `1.0e-3` for `Official pretrained SaProt (35M)`.

################################################################################
################################# CONFIG #######################################
################################################################################

from saprot.config.config_dict import Default_config
config = copy.deepcopy(Default_config)

config.setting.run_mode = "train"
config.setting.seed = seed

################################################################################
################################# ADAPTER ######################################
################################################################################
# config.model.kwargs.use_lora = use_lora

# base model and lora path
if continue_learning:
  adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value
  print(f"Training on an existing model: {adapter_path}")

  if base_model == "Shared by peers on SaprotHub":
    if not adapter_path.exists():
      snapshot_download(repo_id=adapter_combobox.value, repo_type="model", local_dir=adapter_path)

  adapter_config_path = Path(adapter_path) / "adapter_config.json"
  assert adapter_config_path.exists(), f"Can't find {adapter_config_path}"
  with open(adapter_config_path, 'r') as f:
    adapter_config = json.load(f)
    base_model = adapter_config['base_model_name_or_path']

  # config.model.kwargs.lora_config_path = adapter_path
  config.model.kwargs.lora_kwargs = EasyDict({
    "is_trainable": True,
    "num_lora": 1,
    "config_list": [{"lora_config_path": adapter_path}]})

else:
  # config.model.kwargs.lora_config_path = None
  config.model.kwargs.lora_kwargs = EasyDict({
    "num_lora": 1,
    "config_list": []})


################################################################################
################################# MODEL ########################################
################################################################################

if task_type in ["classification", "token_classification", "pair_classification"]:
  # config.model.kwargs.num_labels = get_num_of_labels(selected_csv_dataset)
  config.model.kwargs.num_labels = num_of_categories.value

config.model.model_py_path = model_type_dict[task_type]

config.model.kwargs.config_path = base_model
config.dataset.kwargs.tokenizer = base_model

if base_model == "westlake-repl/SaProt_650M_AF2":
  model_size = "650M"
  model_name = f"Model-{task_name}-{model_size}"
elif base_model == "westlake-repl/SaProt_35M_AF2":
  model_size = "35M"
  model_name = f"Model-{task_name}-{model_size}"

config.model.save_path = str(ADAPTER_HOME / f"{task_type}" / "Local" / model_name)

if task_type in ["regression", "pair_regression", "pair_classification"]:
  config.model.kwargs.extra_config = {}
  config.model.kwargs.extra_config.attention_probs_dropout_prob=0
  config.model.kwargs.extra_config.hidden_dropout_prob=0

# config.model.kwargs.lora_kwargs = {
#     "num_lora": 6,
#     "config_list": [
#         { "lora_config_path": "/content/subcellular/SaProt_650M_AF2_lora_splitNum5_rank0",},
#         { "lora_config_path": "/content/subcellular/SaProt_650M_AF2_lora_splitNum5_rank1",},
#         { "lora_config_path": "/content/subcellular/SaProt_650M_AF2_lora_splitNum5_rank2",},
#         { "lora_config_path": "/content/subcellular/SaProt_650M_AF2_lora_splitNum5_rank3",},
#         { "lora_config_path": "/content/subcellular/SaProt_650M_AF2_lora_splitNum5_rank4",},
#         { "lora_config_path": "/content/SaprotHub/adapters/classification/SaProtHub/Model-Subcellular_Localization-650M",},

#         ]
# }

################################################################################
################################# DATASET ######################################
################################################################################

config.dataset.dataset_py_path = dataset_type_dict[task_type]

config.dataset.train_lmdb = str(lmdb_dataset_path / "train")
config.dataset.valid_lmdb = str(lmdb_dataset_path / "valid")
config.dataset.test_lmdb = str(lmdb_dataset_path / "test")

# num_workers
config.dataset.dataloader_kwargs.num_workers = num_workers

# mask_struc
# config.dataset.kwargs.mask_struc_ratio= mask_struc_ratio

################################################################################
######################## batch size ############################################
################################################################################
def get_accumulate_grad_samples(num_samples):
    if num_samples > 3200:
        return 64
    elif 1600 < num_samples <= 3200:
        return 32
    elif 800 < num_samples <= 1600:
        return 16
    elif 400 < num_samples <= 800:
        return 8
    elif 200 < num_samples <= 400:
        return 4
    elif 100 < num_samples <= 200:
        return 2
    else:
        return 1

# batch_size
GPU_batch_size_dict = {
    "Tesla T4": 2,
    "NVIDIA L4": 2,
    "NVIDIA A100-SXM4-40GB": 4,
}
if torch.cuda.is_available():
  GPU_name = torch.cuda.get_device_name(0)
  if base_model == "westlake-repl/SaProt_650M_AF2" and root_dir == "/content":
    assert GPU_name == "NVIDIA A100-SXM4-40GB", "If you want to train on SaProt 650M, please refer to Section 1.1 to switch your Runtime to GPU A100."
  GPU_batch_size = GPU_batch_size_dict[GPU_name] if GPU_name in GPU_batch_size_dict else 2
  if task_type in ["pair_classification", "pair_regression"]:
    GPU_batch_size = int(max(GPU_batch_size / 2, 1))
else:
  raise RuntimeError("Please refer to Section 1.1 to switch your Runtime to a GPU!")
config.dataset.dataloader_kwargs.batch_size = GPU_batch_size

# accumulate_grad_batches
if batch_size == "Adaptive":

  env = lmdb.open(config.dataset.train_lmdb, readonly=True)

  with env.begin() as txn:
    stat = txn.stat()
    num_samples = stat['entries']

  accumulate_grad_samples = get_accumulate_grad_samples(num_samples)

else:
  accumulate_grad_samples = int(batch_size)

config.Trainer.accumulate_grad_batches= max(int(accumulate_grad_samples / GPU_batch_size), 1)

# config.dataset.dataloader_kwargs.batch_size = 2
# config.Trainer.accumulate_grad_batches= 1

################################################################################
############################## TRAINER #########################################
################################################################################

config.Trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"

# epoch
config.Trainer.max_epochs = max_epochs
# test only: load the existing model
if config.Trainer.max_epochs == 0:
  config.model.save_path = config.model.kwargs.lora_kwargs.config_list[0].lora_config_path

# learning rate
config.model.lr_scheduler_kwargs.init_lr = learning_rate

# trainer
config.Trainer.limit_train_batches=limit_train_batches
config.Trainer.limit_val_batches=limit_val_batches
config.Trainer.limit_test_batches=limit_test_batches
config.Trainer.val_check_interval=val_check_interval

# strategy
strategy = {
    # - deepspeed
    # 'class': 'DeepSpeedStrategy',
    # 'stage': 2

    # - None
    # 'class': None,

    # - DP
    # 'class': 'DataParallelStrategy',

    # - DDP
    # 'class': 'DDPStrategy',
    # 'find_unused_parameter': True
}
config.Trainer.strategy = strategy

################################################################################
############################## CONFIG ##########################################
################################################################################


################################################################################
############################## Run the task ####################################
################################################################################

print('='*100)
print(Fore.BLUE+f"Training task type: {task_type}"+Style.RESET_ALL)
print(Fore.BLUE+f"Dataset: {lmdb_dataset_path}"+Style.RESET_ALL)
print(Fore.BLUE+f"Base Model: {config.model.kwargs.config_path}"+Style.RESET_ALL)
if continue_learning:
  print(Fore.BLUE+f"Existing model: {config.model.kwargs.lora_kwargs.config_list[0].lora_config_path}"+Style.RESET_ALL)
print('='*100)
pprint.pprint(config)
print('='*100)

from saprot.scripts.training import finetune
finetune(config)


################################################################################
############################## Save the adapter ################################
################################################################################

def add_training_data_type_to_config(metadata_path, training_data_type):
  if not metadata_path.exists():
    config_data = {
        'training_data_type': training_data_type
        }
    with open(metadata_path, 'w') as file:
        json.dump(config_data, file, indent=4)

  else:
    with open(metadata_path, 'r') as file:
        config_data = json.load(file)

    config_data['training_data_type'] = training_data_type

    with open(metadata_path, 'w') as file:
        json.dump(config_data, file, indent=4)


metadata_path = Path(config.model.save_path) / "metadata.json"
training_data_type = training_data_type_dict[data_type]

add_training_data_type_to_config(metadata_path, training_data_type)

print(Fore.BLUE)
print(f"Model is saved to \"{config.model.save_path}\" on Colab Server")
print(Style.RESET_ALL)

if download_adapter_to_your_computer:
  adapter_zip = Path(config.model.save_path) / f"{model_name}.zip"
  !cd $config.model.save_path && zip -r $adapter_zip "adapter_config.json" "adapter_model.safetensors" "adapter_model.bin" "README.md" "metadata.json"
  # with zipfile.ZipFile(adapter_zip, 'w') as zipf:
  #   zip_files = [str(file_path) for file_path in Path(config.model.save_path).glob("*")]
  #   print(zip_files)
  #   for file in zip_files:
  #     zipf.write(file, Path(file).name)

  print("Click to download the model to your local computer")
  if adapter_zip.exists():
    # files.download(adapter_zip)
    file_download(adapter_zip)

In [None]:
#@title **2.3: Login HuggingFace to upload your model (Optional)**
################################################################################
###################### Login HuggingFace #######################################
################################################################################

from huggingface_hub import notebook_login
notebook_login()


In [None]:
#@title **2.4: Upload Model (Optional)**

# #@markdown Your Huggingface adapter repository names follow the format `<username>/<task_name>`.

################################################################################
########################## Metadata  ###########################################
################################################################################
#@markdown You can add some description to your model.
name = "demo_cls" # @param {type:"string"}
description = "This model is used for a demo classification task" # @param {type:"string"}

#@markdown For the classification model, please provide detailed information about the meanings of all labels.

#@markdown For example, in a Subcellular Localization Classification Task with 10 categories, label=0 means the protein is located in the Nucleus, label=1 means the protein is located in the Cytoplasm, and so on. The information should be provided as follows:

#@markdown `Nucleus, Cytoplasm, Extracellular, Mitochondrion, Cell.membrane, Endoplasmic.reticulum, Plastid, Golgi.apparatus, Lysosome/Vacuole, Peroxisome`


# #@markdown > 0: Nucleus <br>
# #@markdown > 1: Cytoplasm <br>
# #@markdown > 2: Extracellular <br>
# #@markdown > ... <br>
# #@markdown > 9: Peroxisome <br>

label_meanings = "A, B" #@param {type:"string"}



################################################################################
########################### Move Files  ########################################
################################################################################

from huggingface_hub import HfApi, Repository, ModelFilter

api = HfApi()

user = api.whoami()

if name == "":
  name = model_name
repo_name = user['name'] + '/' + name
local_dir = Path("/content/SaprotHub/model_to_push") / repo_name
local_dir.mkdir(parents=True, exist_ok=True)

repo_list = [repo.id for repo in api.list_models(filter=ModelFilter(author=user['name']))]
if repo_name not in repo_list:
  api.create_repo(repo_name, private=False)

repo = Repository(local_dir=local_dir, clone_from=repo_name)

command = f"cp {config.model.save_path}/* {local_dir}/"
subprocess.run(command, shell=True)

################################################################################
########################## Modify README  ######################################
################################################################################
import json

md_path = local_dir / "README.md"

if task_type in ["classification", "token_classification", "pair_classification"]:
  label_meanings_md = ''
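  # Split on commas and strip whitespace so "A,B" and "A, B" are handled the same way.
  for index, label in enumerate(label.strip() for label in label_meanings.split(',')):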
    label_meanings_md += f"{index}: {label} <br> "

  # print(label_meanings_md)
  description = description + "<br><br> The digital label means: <br>" + label_meanings_md

replace_data = {
    "<!-- Provide a quick summary of what the model is/does. -->": description
}

with open(md_path, "r") as file:
    content = file.read()

for key, value in replace_data.items():
    if value != "":
        content = content.replace(key, value)

# new_md_path = "README.md"
with open(md_path, "w") as file:
    file.write(content)

################################################################################
########################## Upload Model  #######################################
################################################################################


repo.push_to_hub(commit_message="Upload adapter model")

# **3: Use SaProt to Predict**

## 3.1: Classification & Regression Prediction <a name="classification_regression"></a>

<br>


### Dataset

For the prediction dataset, **only** the `Sequence` column is required in the CSV file.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_Sequences_data_format.png?raw=true" height="200" width="800px" align="center">

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_PDB_CIF_Structures_data_format.png?raw=true" height="200" width="500px" align="center">

You can refer to the <a href='#data_format'>instruction</a> for detailed data formats.
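
As a minimal sketch of the requirement above, the following builds a prediction CSV that contains only a `Sequence` column and checks the header before upload. The helper name `has_sequence_column` and the example sequences are illustrative assumptions, not part of ColabSaprot.

```python
import csv
import io

def has_sequence_column(csv_text):
    """Return True if the CSV header contains a `Sequence` column."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return "Sequence" in (reader.fieldnames or [])

# Hypothetical amino-acid sequences, used only for illustration.
sequences = ["MEETMKLATM", "MKVLAAGIVG"]

# Write the header and one sequence per row into an in-memory CSV.
buffer = io.StringIO()
writer = csv.writer(buffer, lineterminator="\n")
writer.writerow(["Sequence"])
for seq in sequences:
    writer.writerow([seq])
csv_text = buffer.getvalue()

print(csv_text)
print(has_sequence_column(csv_text))
```

The same check can be run on an existing file by reading its text first; extra columns do not hurt, but a missing `Sequence` header would.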

<br>


In [None]:
#@title 3.1.1: Task Config

from transformers import EsmTokenizer
import torch
import copy

################################################################################
################################# TASK #########################################
################################################################################
#@markdown # 1. Task

task_objective = "Classify protein sequences (classification)" # @param ["Classify protein sequences (classification)", "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein", "Classify each Amino Acid (amino acid classification), e.g. Binding site detection", "Predict protein-protein interaction (pair classification)", "Predict protein-protein interaction (pair regression)"]
task_type = task_type_dict[task_objective]

if task_type in ["classification", 'token_classification', 'pair_classification']:

  print(Fore.BLUE+'The number of categories in your classification task:'+Style.RESET_ALL)
  num_of_categories = ipywidgets.BoundedIntText(
                                              # value=7,
                                              min=2,
                                              # max=10,
                                              step=1,
                                              # description='num_of_category: \n',
                                              disabled=False)
  num_of_categories.layout.width = "100px"
  display(num_of_categories)

#@markdown <br>


################################################################################
################################## MODEL #######################################
################################################################################
#@markdown # 2. Model

##@markdown As we use a Parameter-Efficient Fine-Tuning technique, which stores the trained weights in a small adapter without adjusting the original model weights during training, it's necessary to specify both the original model and the adapter for prediction.
##@markdown
##@markdown 1. Select a **base model**
##@markdown
##@markdown 2. By running this cell, you will see a **model combobox**. We provide two ways to select your adapter:
##@markdown  - Select a **local model** from the combobox.
##@markdown   - Enter a **Hugging Face repository name** into the combobox. (e.g. "SaProtHub/DeepLoc_cls10_35M")
##@markdown
##@markdown You can also find some official adapters [here](https://huggingface.co/SaProtHub)
# base_model = "westlake-repl/SaProt_35M_AF2" #@param ['westlake-repl/SaProt_35M_AF2', 'westlake-repl/SaProt_650M_AF2'] {allow-input:true}
use_model_from = "Shared by peers on SaprotHub" # @param ["Trained by yourself on ColabSaprot", "Shared by peers on SaprotHub", "Saved in your local computer", "Multi-models on SaprotHub"]
# Both multi-model options enable multi-LoRA prediction (see the note at the end of this cell).
multi_lora = use_model_from in ["Multi-models on ColabSaprot", "Multi-models on SaprotHub"]

# use_existing_model = True # @param {type:"boolean"}
# use_existing_model = True
# if use_existing_model:
#   adapter_combobox = select_adapter()

adapter_input = select_adapter_from(task_type, use_model_from)
#@markdown <br>

################################################################################
################################################################################
################################################################################

# # @markdown Please ensure that the selected task type aligns with the training task type of the model you intend to utilize.

## @markdown If you are conducting inference on a classification task, please ensure that the `num_of_category` matches the number of categories in the training dataset. Otherwise, you do not need to assign `num_of_category`.


##@markdown You have two options to provide your protein sequences:
##@markdown - **Single Sequence: Enter a single SA sequence** into the input box, you can get a SA Sequence by clicking <a href="#get_SA_seq">here</a>
##@markdown - **Multiple Sequences: Select a dataset**, you can upload a dataset from <a href="#upload_dataset">here</a>




# print(Fore.BLUE+f"Data type: {data_type}"+Style.RESET_ALL)


################################################################################
################################ DATASET #######################################
################################################################################
#@markdown # 3. Dataset
data_type = "Single AA Sequence" # @param ["Single AA Sequence", "Single SA Sequence", "Single UniProt ID", "Single PDB/CIF Structure", "Multiple AA Sequences", "Multiple SA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures", "A pair of AA Sequences", "A pair of SA Sequences", "A pair of UniProt IDs", "A pair of PDB/CIF Structures", "Multiple pairs of AA Sequences", "Multiple pairs of SA Sequences", "Multiple pairs of UniProt IDs", "Multiple pairs of PDB/CIF Structures"]

mode = "Multiple Sequences" if (data_type in data_type_list[4:8] or data_type in data_type_list[13:17]) else "Single Sequence"

raw_data = input_raw_data_by_data_type(data_type)

# ["Single AA Sequence","Single SA Sequence","Single UniProt ID","Single PDB/CIF Structure","Multiple AA Sequences","Multiple SA Sequences","Multiple UniProt IDs","Multiple PDB/CIF Structures","SaprotHub Dataset","A pair of AA Sequences","A pair of SA Sequences","A pair of UniProt IDs","A pair of PDB/CIF Structures","Multiple pairs of AA Sequences","Multiple pairs of SA Sequences","Multiple pairs of UniProt IDs","Multiple pairs of PDB/CIF Structures"]
if mode == "Multiple Sequences":
  csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)
else:
  def apply(button):
    global single_sa_seq
    # button.disabled = True
    # button.description = 'Clicked'
    # button.button_style = ''

    # print(Fore.BLUE+'Construct dataset...'+Style.RESET_ALL)
    single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)
    print()
    print('='*100)
    print(Fore.BLUE+f'Current Model ({use_model_from}): {adapter_input.value}'+Style.RESET_ALL)
    if data_type == "A pair of PDB/CIF Structures":
      print(Fore.BLUE+f'Current Dataset ({data_type}): Sequence 1: {raw_data[0]}, Sequence 2: {raw_data[1]}'+Style.RESET_ALL)
    elif data_type in ["A pair of AA Sequences","A pair of SA Sequences","A pair of UniProt IDs"]:
      print(Fore.BLUE+f'Current Dataset ({data_type}): Sequence 1: {raw_data[0].value}, Sequence 2: {raw_data[1].value}'+Style.RESET_ALL)
    elif data_type == "Single PDB/CIF Structure":
      print(Fore.BLUE+f'Current Dataset ({data_type}): {raw_data[0]}' +Style.RESET_ALL)
    else:
      print(Fore.BLUE+f'Current Dataset ({data_type}): {raw_data.value}'+Style.RESET_ALL)

  button_apply = ipywidgets.Button(
      description='Apply',
      disabled=False,
      button_style='success', # 'success', 'info', 'warning', 'danger' or ''
      tooltip='Apply',
      icon='check' # (FontAwesome names without the `fa-` prefix)
      )
  button_apply.on_click(apply)
  # button_apply.layout.width = '500px'
  display(button_apply)


# def apply(button):
#     global single_sa_seq
#     # button.disabled = True
#     # button.description = 'Clicked'
#     # button.button_style = ''

#     # print(Fore.BLUE+'Construct dataset...'+Style.RESET_ALL)
#     if mode == "Multiple Sequences":
#       raw_data = input_raw_data_by_data_type(data_type)
#       csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)

#     elif mode == "Single Sequence":
#       single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)


#     print()
#     print('='*100)
#     print(Fore.BLUE+f'Current Model ({use_model_from}): {adapter_input.value}'+Style.RESET_ALL)
#     print(Fore.BLUE+f'Current Dataset ({data_type}): {raw_data.value}'+Style.RESET_ALL)

# button_apply = ipywidgets.Button(
#     description='Apply',
#     disabled=False,
#     button_style='success', # 'success', 'info', 'warning', 'danger' or ''
#     tooltip='Apply',
#     icon='check' # (FontAwesome names without the `fa-` prefix)
#     )
# button_apply.on_click(apply)
# # button_apply.layout.width = '500px'
# if
# display(button_apply)


#@markdown <br>

#@markdown  <font color="red"> **Note that:** </font> If `use_model_from` is set to `Multi-models on SaprotHub`, each sample will be predicted using multiple models. For classification tasks, voting will be used to determine the final predicted category; for regression tasks, the predicted values from each model will be averaged.

In [None]:
#@title 3.1.2: Get your Result
from transformers import EsmTokenizer
import torch
import copy
import sys
from saprot.scripts.training import my_load_model


################################################################################
################################# 0. MARKDOWN ##################################
################################################################################


# @markdown Click the run button to make a prediction.

# @markdown <font color="red">**Note that:**</font> When predicting a category, the index of categories starts from zero.

################################################################################
################################# 1. DATASET ##################################
################################################################################

# if mode == "Multiple Sequences":
#   csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)
# else:
#   single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)


################################################################################
################################# 2. MODEL ##################################
################################################################################
def get_base_model(adapter_path):
  adapter_config = Path(adapter_path) / "adapter_config.json"
  with open(adapter_config, 'r') as f:
    adapter_config_dict = json.load(f)
    base_model = adapter_config_dict['base_model_name_or_path']
    if 'SaProt_650M_AF2' in base_model:
      base_model = "westlake-repl/SaProt_650M_AF2"
    elif 'SaProt_35M_AF2' in base_model:
      base_model = "westlake-repl/SaProt_35M_AF2"
    else:
      raise RuntimeError("Please ensure the base model is \"SaProt_650M_AF2\" or \"SaProt_35M_AF2\"")
  return base_model

def check_training_data_type(adapter_path, data_type):
  metadata_path = Path(adapter_path) / "metadata.json"
  if metadata_path.exists():
    with open(metadata_path, 'r') as f:
      metadata = json.load(f)
      required_training_data_type = metadata['training_data_type']
  else:
    required_training_data_type = "SA"
  assert required_training_data_type == training_data_type_dict[data_type], f"This model ({adapter_path}) is trained on {required_training_data_type} sequences. Please ensure your data type is also {required_training_data_type} sequences for accurate predictions."


# base_model = "westlake-repl/SaProt_35M_AF2"
if multi_lora:
  if use_model_from == "Multi-models on ColabSaprot":
    config_list = [EasyDict({'lora_config_path': ADAPTER_HOME / task_type / lora_config_path}) for lora_config_path in list(adapter_input.value)]
  elif use_model_from == "Multi-models on SaprotHub":
    #1. get adapter_list
    repo_id_list = adapter_input.value.replace(" ", "").split(',')
    #2. download adapters
    for repo_id in repo_id_list:
      snapshot_download(repo_id=repo_id, repo_type="model", local_dir=ADAPTER_HOME / task_type / repo_id)
    config_list = [EasyDict({'lora_config_path': ADAPTER_HOME / task_type / repo_id}) for repo_id in repo_id_list]

  assert len(config_list) > 0, "Please select your models from the dropdown menu on the output of 3.1.1!"
  # lora_config_path already includes ADAPTER_HOME / task_type, so use it directly.
  base_model = get_base_model(config_list[0].lora_config_path)

  for lora_config in config_list:
    check_training_data_type(lora_config.lora_config_path, data_type)

  lora_kwargs = EasyDict({
    "num_lora": len(config_list),
    "config_list": config_list
  })


else:
  if use_model_from == "Shared by peers on SaprotHub":
    snapshot_download(repo_id=adapter_input.value, repo_type="model", local_dir=ADAPTER_HOME / task_type / adapter_input.value)

  adapter_path = ADAPTER_HOME / task_type / adapter_input.value
  base_model = get_base_model(adapter_path)
  check_training_data_type(adapter_path, data_type)
  lora_kwargs = {
      "num_lora": 1,
      "config_list": [{"lora_config_path": adapter_path}]
  }



# if use_existing_model:
#   if adapter_combobox.value =='':
#     print("Please select a model!")
#     sys.exit()

#   if ". " in adapter_combobox.value:
#     adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value
#   else:
#     adapter_path = adapter_combobox.value

################################################################################
##################################### config ###################################
################################################################################
from saprot.config.config_dict import Default_config
config = copy.deepcopy(Default_config)

# task
if task_type in ["classification", "token_classification", "pair_classification"]:
  # config.model.kwargs.num_labels = num_of_categories.value
  config.model.kwargs.num_labels = num_of_categories.value

# base model
config.model.model_py_path = model_type_dict[task_type]
# config.model.save_path = model_save_path
config.model.kwargs.config_path = base_model

# lora
# config.model.kwargs.lora_config_path = adapter_path
# config.model.kwargs.use_lora = True
# config.model.kwargs.lora_inference = True
config.model.kwargs.lora_kwargs = lora_kwargs

################################################################################
################################### inference ##################################
################################################################################
from peft import PeftModelForSequenceClassification

model = my_load_model(config.model)
tokenizer = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


clear_output(wait=True)

# print("#"*100)
print(Fore.BLUE+f"Inference task type: {task_type}"+Style.RESET_ALL)
if mode == "Multiple Sequences":
  print(Fore.BLUE+f"Dataset: {csv_dataset_path}"+Style.RESET_ALL)
else:
  if data_type == "Single PDB/CIF Structure":
    print(Fore.BLUE+f'Dataset ({data_type}): {raw_data[0]}' +Style.RESET_ALL)
  else:
    print(Fore.BLUE+f"Dataset: {raw_data.value}"+Style.RESET_ALL)

if multi_lora:
  print(Fore.BLUE+f"Model: {base_model} - {[str(lora_config.lora_config_path) for lora_config in lora_kwargs.config_list]}"+Style.RESET_ALL)
else:
  print(Fore.BLUE+f"Model: {base_model} - {adapter_path}"+Style.RESET_ALL)
# if use_existing_model:
#   print(Fore.BLUE+f"Adapter: {adapter_path}"+Style.RESET_ALL)

outputs_list=[]

if mode == "Multiple Sequences":
  timestamp = str(datetime.now().strftime("%Y%m%d%H%M%S"))
  output_file = OUTPUT_HOME / f'output_{timestamp}.csv'
  df = pd.read_csv(csv_dataset_path)

  if task_type in ["pair_classification", "pair_regression"]:
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
    # for index in tqdm(range(len(df))):
      # seq = df['Sequence'].iloc[index]

      input_1 = tokenizer(row["seq_1"], return_tensors="pt")
      input_1 = {k: v.to(device) for k, v in input_1.items()}

      input_2 = tokenizer(row["seq_2"], return_tensors="pt")
      input_2 = {k: v.to(device) for k, v in input_2.items()}

      with torch.no_grad(): outputs = model(input_1, input_2)
      outputs_list.append(outputs)

    df['score'] = [output.cpu().tolist() for output in outputs_list]
    df.to_csv(output_file, index=False)
    # files.download(output_file)
    file_download(output_file)

    print(Fore.BLUE+f"\nThe prediction result is saved to {output_file} and your local computer."+Style.RESET_ALL)

  else:
    for index in tqdm(range(len(df))):
      seq = df['Sequence'].iloc[index]
      inputs = tokenizer(seq, return_tensors="pt")
      inputs = {k: v.to(device) for k, v in inputs.items()}
      with torch.no_grad(): outputs = model(inputs)
      outputs_list.append(outputs)

    df['score'] = [output.cpu().tolist() for output in outputs_list]
    df.to_csv(output_file, index=False)
    # files.download(output_file)
    file_download(output_file)

    print(Fore.BLUE+f"\nThe prediction result is saved to {output_file} and your local computer."+Style.RESET_ALL)

else:
  if task_type in ["pair_classification", "pair_regression"]:
      # print("You are making inference based on a sequence that you entered")
    input_1 = tokenizer(single_sa_seq[0], return_tensors="pt")
    input_1 = {k: v.to(device) for k, v in input_1.items()}

    input_2 = tokenizer(single_sa_seq[1], return_tensors="pt")
    input_2 = {k: v.to(device) for k, v in input_2.items()}

    with torch.no_grad(): outputs = model(input_1, input_2)
    outputs_list.append(outputs)
  else:
    # print("You are making inference based on a sequence that you entered")
    inputs = tokenizer(single_sa_seq, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad(): outputs = model(inputs)
    outputs_list.append(outputs)

################################################################################
##################################### output ###################################
################################################################################

print()
print('='*100)
print(Fore.BLUE+"outputs:"+Style.RESET_ALL)

if task_type == "classification":
  import torch.nn.functional as F
  softmax_output_list = [F.softmax(output, dim=1).squeeze().tolist() for output in outputs_list]
  for index, output in enumerate(softmax_output_list):
    print(f"For Sequence {index}, Prediction: Category {output.index(max(output))}, Probability: {output}")
elif task_type == "regression":
  output_list = [output.squeeze().tolist() for output in outputs_list]
  for index, output in enumerate(outputs_list):
    print(f"For Sequence {index}, Prediction: Value {output.item()}")
elif task_type == "token_classification":
  import torch.nn.functional as F
  softmax_output_list = [F.softmax(output, dim=-1).squeeze().tolist() for output in outputs_list]
  # print(softmax_output_list)
  print("The probability of each category:")
  for seq_index, seq in enumerate(softmax_output_list):
    seq_prob_df = pd.DataFrame(seq)
    print('='*100)
    print(f'Sequence {seq_index + 1}:')
    print(seq_prob_df[1:-1])
elif task_type == "pair_classification":
  import torch.nn.functional as F
  # Each pair yields one vector of category scores, so report it like sequence-level classification.
  softmax_output_list = [F.softmax(output, dim=-1).squeeze().tolist() for output in outputs_list]
  for index, output in enumerate(softmax_output_list):
    print(f"For Sequence Pair {index}, Prediction: Category {output.index(max(output))}, Probability: {output}")
elif task_type == "pair_regression":
  output_list = [output.squeeze().tolist() for output in outputs_list]
  for index, output in enumerate(outputs_list):
    print(f"For Sequence {index}, Prediction: Value {output.item()}")
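
The note in 3.1.1 on `Multi-models on SaprotHub` describes voting for classification tasks and averaging for regression tasks. A minimal sketch of that aggregation follows, assuming each model has already produced a predicted category or value for one sample. `majority_vote` and `average_prediction` are hypothetical helper names, and the tie-breaking rule is an assumption rather than a description of how SaProt combines its adapters.

```python
from collections import Counter
from statistics import mean

def majority_vote(predicted_categories):
    """Return the category predicted by the most models (ties go to the first one seen)."""
    return Counter(predicted_categories).most_common(1)[0][0]

def average_prediction(predicted_values):
    """Return the mean of the values predicted by each model."""
    return mean(predicted_values)

# Hypothetical per-model predictions for a single sample.
category_votes = [1, 0, 1]       # three classifiers
value_votes = [0.5, 1.0, 1.5]    # three regressors

print(majority_vote(category_votes))
print(average_prediction(value_votes))
```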




## 3.2: Mutational Effect Prediction <a name="mutational_effect"></a>

<br>

### Mutation Task
- Single-site or Multi-site mutagenesis
- Saturation mutagenesis

<br>

### Mutation Dataset

For `Single-site or Multi-site mutagenesis`, **one additional column** is required in the CSV file: `mutation`.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_SA_Sequences_data_format_mutation.png?raw=true" height="200" width="500px" align="center">

- `mutation` column contains the **mutation information**.

<br>

### Mutation Information

Here are the details of how **mutation information** is represented: <a name="mutation info"></a>

| mode | mutation information|
| --- | --- |
| Single-site mutagenesis | H87Y |
| Multi-site mutagenesis | H87Y:V162M:P179L:P179R |

- For `Single-site mutagenesis`, we use a term like "H87Y" to denote the mutation, where the first letter represents the **original amino acid**, the number in the middle represents the **mutation site** (indexed starting from 1), and the last letter represents the **mutated amino acid**.
- For `Multi-site mutagenesis`, we use a colon ":" to join the single-site mutations, such as "H87Y:V162M:P179L:P179R".
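
The notation above can be checked before running a prediction. The following is a minimal sketch, not part of ColabSaprot: the helper name `parse_mutation_info` and the uppercase-letter pattern are assumptions made for illustration.

```python
import re
from typing import List, Tuple

# One substitution: original amino acid, 1-based position, mutated amino acid.
# Restricting to uppercase letters is an assumption for this sketch.
_MUTATION_RE = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def parse_mutation_info(mutation_info: str) -> List[Tuple[str, int, str]]:
    """Split a string such as 'H87Y:V162M' into (original, position, mutated) tuples."""
    parsed = []
    for term in mutation_info.split(":"):
        match = _MUTATION_RE.match(term.strip())
        if match is None:
            raise ValueError(f"Malformed mutation term: {term!r}")
        ori_aa, pos, mut_aa = match.groups()
        parsed.append((ori_aa, int(pos), mut_aa))
    return parsed


print(parse_mutation_info("H87Y"))              # [('H', 87, 'Y')]
print(parse_mutation_info("H87Y:V162M:P179L"))  # [('H', 87, 'Y'), ('V', 162, 'M'), ('P', 179, 'L')]
```

A malformed term such as "87Y" raises a `ValueError`, which is easier to act on than a failure inside the model call.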

<!-- ### Prediction Result -->


<!-- ### How to use your model for Mutational Effect Prediction -->

<!--## 1. Input and Output

 You have four different combinations of **mutation task** and **mode** to choose from: -->

<!--
 |Combination| Input | Output |
 | --- | --- | --- |
 |`Single-site or Multi-site mutagenesis` + `Single Sequence`| Enter **a SA sequence** and **a mutation information**| a score of the mutation |
 |`Single-site or Multi-site mutagenesis` + `Multiple Sequences`| Select **a dataset** and upload **a .csv file containing mutation information**| a .csv file containing the scores of mutations |
 |`Saturation mutagenesis` + `Single Sequence`| Enter **a SA sequence**| a .csv file containing the scores of all mutation on every position of the sequence |
 |`Saturation mutagenesis` + `Multiple Sequences`| Select **a dataset**| a .zip file containing the .csv files of the Saturation mutagenesis on every sequence |
  -->


<!-- ### 2. Format of the uploaded .csv file containing mutation information -->

For Multiple Sequences, you are required to **upload an additional .csv file** containing your mutation information.
<font color=red>Please ensure that each mutation in the mutation CSV file corresponds to a Sequence in the dataset CSV file.</font>
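
For multiple sequences, it can help to confirm that every mutation matches its sequence before uploading. The sketch below is illustrative only: it assumes the `Sequence` and `mutation` column names used by the prediction cell and relies on amino acids occupying the even positions of an SA sequence. The helper `check_mutations` and the toy data are hypothetical.

```python
import pandas as pd


def check_mutations(df: pd.DataFrame) -> list:
    """Return (row index, message) for each mutation that does not match its sequence."""
    problems = []
    for idx, row in df.iterrows():
        # Amino acids sit at the even positions of an SA sequence.
        aa_seq = row["Sequence"][0::2]
        for term in str(row["mutation"]).split(":"):
            ori_aa, pos = term[0], int(term[1:-1])
            if not 1 <= pos <= len(aa_seq):
                problems.append((idx, f"{term}: position {pos} is outside the sequence"))
            elif aa_seq[pos - 1] != ori_aa:
                problems.append((idx, f"{term}: expected {ori_aa}, found {aa_seq[pos - 1]}"))
    return problems


# Toy SA sequences (amino acid + 3Di letter pairs), invented for illustration.
toy_df = pd.DataFrame({
    "Sequence": ["MdEvVpLk", "AdCpDv"],
    "mutation": ["M1A:V3L", "C1A"],
})
print(check_mutations(toy_df))
```

Here the second row is reported because position 1 of its sequence is `A`, not `C`.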

In [None]:
#@title 3.2.1: Task Config

mutation_task = "Saturation mutagenesis" #@param ["Single-site or Multi-site mutagenesis", "Saturation mutagenesis"]

# data_type = "Single AA Sequence" # @param ["Single AA Sequence", "Single SA Sequence", "Single UniProt ID", "Single PDB/CIF Structure", "Multiple AA Sequences", "Multiple SA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures"]
data_type = "Single SA Sequence" # @param ["Single SA Sequence", "Single UniProt ID", "Single PDB/CIF Structure", "Multiple SA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures"]
raw_data = input_raw_data_by_data_type(data_type)

mode = "Multiple Sequences" if data_type in data_type_list[4:8] else "Single Sequence"

if mutation_task == "Single-site or Multi-site mutagenesis":
  if mode == "Single Sequence":
    input_mut = ipywidgets.Text(
      value=None,
      placeholder='Enter Single Mutation Information here',
      # description='SA Sequence:',
      disabled=False)
    print(Fore.BLUE+"Mutation:"+Style.RESET_ALL)
    input_mut.layout.width = '500px'
    display(input_mut)


In [None]:
#@title 3.2.2: Get your Result

################################################################################
################################# DATASET ###################################
################################################################################
if mode == "Single Sequence":
  seq = get_SA_sequence_by_data_type(data_type, raw_data)
else:
  dataset_csv_path = get_SA_sequence_by_data_type(data_type, raw_data)

################################################################################
################################# Task Info ####################################
################################################################################
base_model = "westlake-repl/SaProt_650M_AF2"

clear_output(wait=True)

print(Fore.BLUE)
print(f"Mutation task: {mutation_task}")
print(f"Mode: {mode}")
print(f"Model: {base_model}")
if mode == "Multiple Sequences":
  print(Fore.BLUE+f"Dataset: {dataset_csv_path}"+Style.RESET_ALL)
else:
  print(Fore.BLUE+f"Dataset: {seq}"+Style.RESET_ALL)

print(Style.RESET_ALL)

print(f"Predicting...")
timestamp = datetime.now().strftime("%y%m%d%H%M%S")

################################################################################
################################# load model ###################################
################################################################################

from saprot.model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel

config = {
    "foldseek_path": None,
    "config_path": base_model,
    "load_pretrained": True,
}

try:
  zero_shot_model
except Exception:
  zero_shot_model = SaprotFoldseekMutationModel(**config)
  device = "cuda" if torch.cuda.is_available() else "cpu"
  zero_shot_model.to(device)

################################################################################
########################### Single Sequence ####################################
################################################################################
if mode == "Single Sequence":

  if mutation_task == "Single-site or Multi-site mutagenesis":
    mut = input_mut.value
    # validate mut
    aa_seq = seq[0::2]
    for m in mut.split(':'):
      ori_aa = m[0]
      pos = int(m[1:-1])
      mut_aa = m[-1]
      assert aa_seq[pos-1] == ori_aa, f"The provided mutation information contains an error ({m}): the original amino acid at position {pos} ({ori_aa}) does not match your sequence ({aa_seq[pos-1]})."

    score = zero_shot_model.predict_mut(seq, mut)

    print()
    print("="*100)
    print(Fore.BLUE+"Output:"+Style.RESET_ALL)
    print(f"The score of mutation {mut} is {Fore.BLUE}{score}{Style.RESET_ALL}")

  if mutation_task=="Saturation mutagenesis":
    timestamp = datetime.now().strftime("%y%m%d%H%M%S")
    output_path = OUTPUT_HOME / f'{timestamp}_prediction_output.csv'

    mut_dicts = []
    for pos in range(1, len(seq) // 2 + 1):
      mut_dict = zero_shot_model.predict_pos_mut(seq, pos)
      mut_dicts.append(mut_dict)

    mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]
    df = pd.DataFrame(mut_list)
    df.to_csv(output_path, index=None)

    print()
    print("="*100)
    print(Fore.BLUE+"Output:"+Style.RESET_ALL)
    # files.download(output_path)
    file_download(output_path)
    print(f"\n{Fore.BLUE}The result has been saved to {output_path} and your local computer.{Style.RESET_ALL}")

################################################################################
########################### Multiple Sequences #################################
################################################################################
if mode == "Multiple Sequences":

  dataset_df = pd.read_csv(dataset_csv_path)
  results = []

  if mutation_task=="Single-site or Multi-site mutagenesis":
    for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc="Predicting"):
      seq = row['Sequence']
      mut_info = row['mutation']
      results.append(zero_shot_model.predict_mut(seq, mut_info).cpu().item())

    print()
    print("="*100)
    print(Fore.BLUE+"Output:"+Style.RESET_ALL)

    # result_df = pd.DataFrame()
    # result_df['Sequence'] = dataset_df['Sequence']
    # result_df['mutation'] = dataset_df['mutation']
    dataset_df['score'] = results

    output_path = OUTPUT_HOME / f"{timestamp}_prediction_output_{Path(dataset_csv_path).stem}.csv"
    dataset_df.to_csv(output_path, index=None)
    # files.download(output_path)
    file_download(output_path)
    print(f"{Fore.BLUE}The result has been saved to {output_path} and your local computer {Style.RESET_ALL}")

  else:
    for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc=f"Predicting"):
      seq = row['Sequence']
      mut_dicts = []
      for pos in range(1, len(seq) // 2 + 1):
        mut_dict = zero_shot_model.predict_pos_mut(seq, pos)
        mut_dicts.append(mut_dict)
      mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]
      result_df = pd.DataFrame(mut_list)
      results.append(result_df)

    print()
    print("="*100)
    print(Fore.BLUE+"Output:"+Style.RESET_ALL)

    zip_files = []
    for i in range(len(results)):
      output_path = OUTPUT_HOME / f"{timestamp}_prediction_output_{Path(dataset_csv_path).stem}_Sequence{i+1}.csv"
      results[i].to_csv(output_path, index=None)
      zip_files.append(output_path)

    # zip and download zip to local computer
    zip_path = OUTPUT_HOME / f"{timestamp}_{Path(dataset_csv_path).stem}.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file in zip_files:
            zipf.write(file, os.path.basename(file))
    # files.download(zip_path)
    file_download(zip_path)
    print(f"{Fore.BLUE}The result has been saved to {zip_path} and your local computer{Style.RESET_ALL}")

## 3.3: Inverse Folding Prediction <a name="inverse_folding"></a>

Predict the amino acid sequence from protein backbone structure.

<br>

### Dataset

The protein backbone structure should be provided in .pdb/.cif file format.
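
Once a structure is uploaded, the notebook converts it to a structure-aware (SA) sequence in which amino acid letters and structure (3Di) letters alternate, and `#` marks an amino acid to be predicted. The sketch below illustrates that layout on a made-up sequence. The helper names are hypothetical and are not part of ColabSaprot.

```python
def split_sa_sequence(sa_seq: str):
    """Return (amino acid sequence, structure sequence) from an interleaved SA sequence."""
    return sa_seq[0::2], sa_seq[1::2]


def mask_positions(aa_seq: str, positions) -> str:
    """Replace the amino acids at the given 1-based positions with '#'."""
    chars = list(aa_seq)
    for pos in positions:
        chars[pos - 1] = "#"
    return "".join(chars)


def merge_sa_sequence(aa_seq: str, struc_seq: str) -> str:
    """Interleave amino acid and structure letters back into one SA sequence."""
    if len(aa_seq) != len(struc_seq):
        raise ValueError("Amino acid and structure sequences must have the same length.")
    return "".join(a + s for a, s in zip(aa_seq, struc_seq))


# Toy SA sequence, invented for illustration.
aa_seq, struc_seq = split_sa_sequence("MdEvVpLk")
masked_aa_seq = mask_positions(aa_seq, [2, 3])
print(aa_seq, struc_seq)                            # MEVL dvpk
print(masked_aa_seq)                                # M##L
print(merge_sa_sequence(masked_aa_seq, struc_seq))  # Md#v#pLk
```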

<br>

<!-- Predict the residue sequence of a structure-aware sequence with masked amino acids (which could be all masked or partially masked).

<br>

### Dataset

Enter a **SA sequence with masked amino acids** into the `sa_seq` input box.

<br>

For example,
**input** is a SA Sequence with masked amino acids:

`#d#v#v#v#p#p#p#p#a#p#a#q#k#k#k#k#w`

and the **output** predicted by model is an AA Sequence:

`MEELGLPDLPPGGVVVV`.

<br> -->


In [None]:
#@title 3.3.1: Upload .pdb/.cif structure file

#@markdown After clicking the run button, an upload button will appear for you to upload your .pdb/.cif structure file.

#@markdown <font face="Consolas" size=2 color='gray'>Note: since you may not know the AA type, you can simply populate your .pdb/.cif file with any random AA. If you want to predict only some positions given accurate AA information at the others, enter the accurate AA at the known positions and any random AA at the unknown positions.</font>

#@markdown After uploading is finished, the .pdb/.cif structure will be transformed into the corresponding AA Sequence and Structure (3Di) Sequence.

#@markdown You can **mask partial or all amino acids** in the AA sequence with '#' at certain positions, allowing the model to make predictions for those masked amino acids.

data_type = "Single PDB/CIF Structure"
# raw_data = input_raw_data_by_data_type(data_type)

def get_structure_file():
  print("Please provide the structure type, chain and your structure file.")

  dropdown_type = ipywidgets.Dropdown(
    value="PDB",
    options=["PDB", "AF2"],
    disabled=False)
  dropdown_type.layout.width = '500px'
  print(Fore.BLUE+"Structure type:"+Style.RESET_ALL)
  display(dropdown_type)

  input_chain = ipywidgets.Text(
    value="A",
    placeholder=f'Enter the name of chain here',
    disabled=False)
  input_chain.layout.width = '500px'
  print(Fore.BLUE+"Chain:"+Style.RESET_ALL)
  display(input_chain)

  print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
  pdb_file_path = upload_file(STRUCTURE_HOME)
  return pdb_file_path, pdb_file_path.stem, dropdown_type, input_chain


backbone_path, stem, dropdown_type, input_chain = get_structure_file()
raw_data = (stem, dropdown_type, input_chain)

sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)

aa_seq = sa_seq[0::2]
struc_seq = sa_seq[1::2]

# masked_sa_seq = ''
# for s in sa_seq[1::2]:
#   masked_sa_seq += '#' + s

clear_output(wait=True)

################################################################################
################################################################################
################################################################################

input_aa_seq = ipywidgets.Text(
      value=aa_seq,
      placeholder='Enter Amino Acid Sequence here',
      disabled=False)
print(Fore.BLUE+"Amino Acid Sequence:"+Style.RESET_ALL)
input_aa_seq.layout.width = '500px'
display(input_aa_seq)

input_struc_seq = ipywidgets.Text(
  value=struc_seq,
  placeholder='Enter Structure Sequence here',
  disabled=False)
print(Fore.BLUE+"Structure Sequence:"+Style.RESET_ALL)
input_struc_seq.layout.width = '500px'
display(input_struc_seq)

print(Fore.RED+"If you want to mask all amino acids and make prediction, simply clear the 'Amino Acid Sequence' box.")

backbone_name = os.path.basename(backbone_path)
show_pdb(backbone_path, color="rainbow").show()
print(f"Backbone visualization of {backbone_name} ({len(struc_seq)} amino acids)")

In [None]:
#@title 3.3.2: Predict Amino Acid Sequence

#@markdown Click the run button to get the predicted Amino Acid Sequence

method = "multinomial" # @param ["argmax", "multinomial"]
num_samples = 1 # @param {type:"integer"}

#@markdown - `method` refers to the prediction method. It could be either "argmax" or "multinomial".
#@markdown   - `argmax` selects the amino acid with the highest probability.
#@markdown   - `multinomial` samples an amino acid from the multinomial distribution.


#@markdown - `num_samples` refers to the number of output amino acid sequences.

save_name = "predicted_seq" # @param {type:"string"}



################################################################################
############################### Dataset ########################################
################################################################################

masked_aa_seq = input_aa_seq.value
if masked_aa_seq.strip() == "":
  # An empty box means every amino acid is masked.
  masked_aa_seq = "#" * len(input_struc_seq.value)

masked_struc_seq = input_struc_seq.value

# assert len(masked_aa_seq) == len(masked_struc_seq), f"Please make sure that the amino acid sequence ({len(masked_aa_seq)}) and the structure sequence ({len(masked_struc_seq)}) have the same length."
# masked_sa_seq = ''.join(a + b for a, b in zip(masked_aa_seq, masked_struc_seq))


# if num_samples == 1:
#   method = "argmax"
# elif num_samples > 1:
#   method = "multinomial"
# else:
#   raise BaseException("\"num_samples\" should be an integer greater than or equal to 1.")

################################################################################
############################### Model ##########################################
################################################################################
# base_model = "westlake-repl/SaProt_650M_AF2"
base_model = "westlake-repl/SaProt_650M_AF2_inverse_folding"

config = {
    "config_path": base_model,
    "load_pretrained": True,
}
from saprot.model.saprot.saprot_if_model import SaProtIFModel
try:
  saprot_if_model
except Exception:
  saprot_if_model = SaProtIFModel(**config)
  tokenizer = saprot_if_model.tokenizer
  device = "cuda" if torch.cuda.is_available() else "cpu"
  saprot_if_model.to(device)

################################################################################
############################### Predict ########################################
################################################################################

pred_aa_seqs = saprot_if_model.predict(masked_aa_seq, masked_struc_seq, method=method, num_samples=num_samples)

print("#"*100)
print(Fore.BLUE+"outputs:"+Style.RESET_ALL)
save_path = f"{root_dir}/SaprotHub/output/{save_name}.fasta"
with open(save_path, "w") as w:
  for i, aa_seq in enumerate(pred_aa_seqs):
    print(aa_seq)
    w.write(f">predicted_seq_{i}\n{aa_seq}\n")

file_download(save_path)


In [None]:
#@title 3.3.3: Predict the structure of generated sequence

#@markdown  <font color="red"> **Warning: Please do not run this cell if you only have 12GB RAM!!! This will cause
#@markdown the out of memory error and you will have to restart the notebook. We recommend you connect to a runtime
#@markdown with more RAM to run the cell properly.** </font>

#@markdown Click the run button to predict the structure of generated sequence using ESMFold

protein_sequence = "" # @param {type:"string"}
save_name = "predicted_structure" # @param {type:"string"}

#@markdown Visualization settings
color = "lDDT" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}


################################################################################
############################### LOAD ESMFOLD ################################
################################################################################
try:
  esmfold
except Exception:
  tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
  esmfold = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
  esmfold.esm = esmfold.esm.half()
  esmfold.trunk.set_chunk_size(64)

  device = "cuda" if torch.cuda.is_available() else "cpu"
  esmfold.to(device)

################################################################################
################################## PREDICT ###################################
################################################################################
tokenized_input = tokenizer(
    [protein_sequence],
    return_tensors="pt",
    add_special_tokens=False,
    max_length=1024,
    truncation=True,
    )['input_ids']

tokenized_input = tokenized_input.to(esmfold.device)
with torch.no_grad():
  output = esmfold(tokenized_input)

################################################################################
#################################### SAVE ####################################
################################################################################
save_path = f"{root_dir}/SaprotHub/output/{save_name}.pdb"
pdb = convert_outputs_to_pdb(output)
with open(save_path, "w") as f:
  f.write("".join(pdb))

################################################################################
################################# VISUALIZE ##################################
################################################################################
show_pdb(save_path, show_sidechains, show_mainchains, color).show()
if color == "lDDT":
  plot_plddt_legend().show()

print("Predicted structure")
file_download(save_path)

In [None]:
#@title 3.3.4: Align proteins using TMalign

#@markdown You can find the **uploaded proteins** from /content/SaprotHub/structures (if you connect to your local server, then the path is /SaprotHub/structures).

#@markdown You can find the **predicted proteins** from /content/SaprotHub/output (if you connect to your local server, then the path is /SaprotHub/output).

#@markdown Right click the pdb file to copy the path and then paste it into the box:
pdb_path_1 = "" # @param {type:"string"}
pdb_path_2 = "" # @param {type:"string"}

pdb_path_1 = f"{root_dir}{pdb_path_1}"
pdb_path_2 = f"{root_dir}{pdb_path_2}"

assert os.path.exists(pdb_path_1) and os.path.exists(pdb_path_2), "Input proteins do not exist!"

cmd = f"{root_dir}/SaprotHub/bin/TMalign {pdb_path_1} {pdb_path_2}"
print(os.popen(cmd).read())

# **4: (Optional) Data Preparation**

## 4.1: Get Structure-Aware Sequence <a name="get_sa"></a>

In [None]:
# @title 4.1.1: AA Sequence, UniProt ID, PDB/CIF file -> SA Sequence

################################################################################
################################ input #########################################
################################################################################

data_type = "Single PDB/CIF Structure"  # @param ["Single AA Sequence", "Single UniProt ID", "Single PDB/CIF Structure", "Multiple AA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures"]

if data_type == data_type_list[7]:
    # upload and unzip PDB files
    print(Fore.BLUE + f"Please upload your .zip file which contains {data_type} files" + Style.RESET_ALL)
    pdb_zip_path = upload_file(UPLOAD_FILE_HOME)
    if pdb_zip_path.suffix != ".zip":
        logger.error("The data type does not match. Please click the run button again to upload a .zip file!")
        raise RuntimeError("The data type does not match.")
    print(Fore.BLUE + "Successfully uploaded your .zip file!" + Style.RESET_ALL)
    print("="*100)

    import zipfile
    with zipfile.ZipFile(pdb_zip_path, 'r') as zip_ref:
        file_names = zip_ref.namelist()
        zip_ref.extractall(STRUCTURE_HOME)

    uploaded_csv_path = UPLOAD_FILE_HOME / f"{pdb_zip_path.stem}.csv"
    df = pd.DataFrame(file_names, columns=['Sequence'])
    df.to_csv(uploaded_csv_path, index=False)
    raw_data = uploaded_csv_path

else:
    raw_data = input_raw_data_by_data_type(data_type)

################################################################################
############################### output #########################################
################################################################################

if data_type in ["Single AA Sequence", "Single UniProt ID", "Single PDB/CIF Structure"]:
    def apply(button):
        button.disabled = True
        button.description = 'Clicked'
        button.button_style = ''
        sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)
        
        print("="*100)
        print(f"Amino Acid Sequence: {sa_seq[0::2]}")
        print(f"Structure Sequence: {sa_seq[1::2]}")
        print("="*100)
        print(Fore.BLUE + "The Structure-Aware Sequence is here, double click to select and copy it:" + Style.RESET_ALL)
        print(sa_seq)

    button_apply = ipywidgets.Button(
        description='Apply',
        disabled=False,
        button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Apply',
        icon='check'  # (FontAwesome names without the `fa-` prefix)
    )
    button_apply.on_click(apply)
    button_apply.layout.width = '500px'
    display(button_apply)
else:
    csv_dataset = get_SA_sequence_by_data_type(data_type, raw_data)
    print(Fore.BLUE + "The Structure-Aware Sequences are saved in a .csv file here:" + Style.RESET_ALL)
    print(csv_dataset)
    file_download(csv_dataset)


## 4.2: Convert `.fa/.fasta` file to `.csv` file in the data format of "Multiple AA Sequences"

In [None]:
#@title 4.2.1: `.fa/.fasta` -> Multiple AA Sequences `.csv` <a name="fa2csv"></a>
from Bio import SeqIO
import numpy as np

aa_seq_dict = { "Sequence": [],
                # "label": [],
                # "stage":[]
                }

fa_file_path = upload_file(UPLOAD_FILE_HOME)
assert Path(fa_file_path).suffix.lower() in ['.fa', '.fasta'], "Please upload a .fa or .fasta file."
with fa_file_path.open("r") as fa:
  for record in tqdm(SeqIO.parse(fa, 'fasta'), leave=True):
      aa_seq_dict["Sequence"].append(str(record.seq))

fa_df = pd.DataFrame(aa_seq_dict)
print(fa_df[:5])  # preview the first five sequences

csv_file_path = UPLOAD_FILE_HOME / f'{fa_file_path.stem}.csv'
fa_df.to_csv(csv_file_path, index=None)
# files.download(csv_file_path)
file_download(csv_file_path)

################################################################################
############################ .fa 2 .csv and split ##############################
################################################################################

# automatically_split_dataset = False # @param {type:"boolean"}
# split = ['train', 'valid', 'test']

# aa_seq_dict = { "Sequence": [],
#                 "label": [],
#                 "stage":[]}



# if automatically_split_dataset:

#   fa_file_path = upload_file(UPLOAD_FILE_HOME)
#   label = fa_file_path.stem

#   with fa_file_path.open("r") as fa:
#       for record in tqdm(SeqIO.parse(fa, 'fasta'), leave=True):
#           aa_seq_dict["Sequence"].append(str(record.seq))
#           aa_seq_dict["label"].append(label)
#   weights = [0.8, 0.1, 0.1]
#   aa_seq_dict["stage"] = np.random.choice(split, size=len(aa_seq_dict["Sequence"]), p=weights).tolist()

# else:
#   for i in range(3):
#     print(Fore.BLUE+f"Please upload a .fa file as your {split[i]} dataset")
#     fa_file_path = upload_file(UPLOAD_FILE_HOME)
#     label = fa_file_path.stem

#     with fa_file_path.open("r") as fa:
#         for record in tqdm(SeqIO.parse(fa, 'fasta')):
#             aa_seq_dict["Sequence"].append(str(record.seq))
#             aa_seq_dict["label"].append(label)
#             aa_seq_dict["stage"].append(split[i])

#     print()
#     print("="*100)

# fa_df = pd.DataFrame(aa_seq_dict)
# timestamp = datetime.now().strftime("%y%m%d%H%M%S")
# fa_df.to_csv(f'/content/SaprotHub/upload_files/{timestamp}.csv', index=None)
# files.download(f'/content/SaprotHub/upload_files/{timestamp}.csv')
# print(fa_df[5:])

## 4.3: Dataset Split

Please click the run button to upload your .csv dataset
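
The cell below assigns each row to `train`, `valid` or `test` by independent random draws with probabilities 0.8, 0.1 and 0.1, so the resulting proportions are approximate rather than exact. A minimal sketch of the same idea on toy data follows. The helper name `assign_stages` and the fixed seed are assumptions added here for reproducibility and are not used by the notebook.

```python
import numpy as np
import pandas as pd


def assign_stages(df: pd.DataFrame, ratios=(0.8, 0.1, 0.1), seed: int = 0) -> pd.DataFrame:
    """Return a copy of df with a randomly drawn 'stage' column."""
    rng = np.random.default_rng(seed)  # seeded so the split can be reproduced
    out = df.copy()
    out["stage"] = rng.choice(["train", "valid", "test"], size=len(out), p=list(ratios))
    return out


# Toy dataset with placeholder sequences, invented for illustration.
toy_df = pd.DataFrame({"Sequence": [f"seq{i}" for i in range(1000)]})
split_df = assign_stages(toy_df)
print(split_df["stage"].value_counts(normalize=True))
```

For small datasets, a draw like this can leave a stage with very few rows, so it is worth checking the counts before training.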

In [None]:
#@title 4.3.1: Randomly split your .csv dataset <a name="split_dataset"></a>

import numpy as np  # needed for np.random.choice below when this cell is run on its own

csv_dataset_path = upload_file(UPLOAD_FILE_HOME)
dataset_df = pd.read_csv(csv_dataset_path)

split = ['train', 'valid', 'test']
split_ratio = [0.8, 0.1, 0.1]

if ('stage' not in dataset_df.columns) or (dataset_df["stage"].nunique()<3):
  dataset_df["stage"] = np.random.choice(split, size=len(dataset_df), p=split_ratio).tolist()

dataset_df.to_csv(csv_dataset_path, index=None)


# Manual <a name="manual"></a>





## How to train and share your model





### Train your Model

#### Step 1

Complete the input and selection of Task Configs

- `task_name` is the name of the training task you're working on.
- `task_objective` describes the goal of your task, like sorting protein sequences into categories or predicting the values of some protein properties.
- `base_model` is the base model you use for training. By default, it is set to the officially pretrained SaProt, but you can also start from a model you retrained yourself with ColabSaprot or from one shared on [SaprotHub](https://huggingface.co/SaProtHub). For example, choose `Trained-by-peers` to continue training a SaProt model shared by others on your own data. A wide range of retrained models is available on [SaprotHub](https://huggingface.co/SaProtHub).
- `data_type` indicates the kind of data you're using, which is determined by the dataset file you upload. You can find more details about the formats for different types of data in the provided <a href="#data_format">instruction</a>.


#### Step 2

Click the run button to apply the configs.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/train-1.png?raw=true" height="300" width="600px" align="center">

#### Step 3

After clicking the "Run" button, additional input boxes will appear.

Complete the input of additional information and upload files.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/train-2.png?raw=true" height="300" width="400px" align="center">


<!-- If you want to train on an existing model, choose "Existing Models with your data" as `base_model` at step 1, and then "Existing model" input box will appear. Enter a huggingface model id or select a local model. -->

(Note: Do not click the "Run" button of the next cell before completing the input and upload.)

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/train-3.png?raw=true" height="300" width="300px" align="center">

#### Step 4

Complete the input of training configs

- `batch_size` depends on the number of training samples. If your training dataset is large enough, we recommend 32, 64, 128, 256, and so on; otherwise, use 8, 4, or 2. (Note that you cannot use a larger batch size on the default Colab T4 GPU. We strongly suggest subscribing to Colab Pro for an A100 GPU.)
- `max_epochs` is the maximum number of training epochs. A larger value needs more training time. The best model is saved after each epoch. You can adjust `max_epochs` to control the training duration. (Note that the maximum running time of Colab is 12 hours for unsubscribed users and 24 hours for Colab Pro+ users.)
- `learning_rate` affects the convergence speed of the model. Through experimentation, we have found that `5.0e-4` is a good default value for base model `Official pretrained SaProt (650M)` and `1.0e-3` for `Official pretrained SaProt (35M)`.

#### Step 5

Click the "Run" button to start training.


<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/train-4.png?raw=true" height="300" width="400px" align="center">

You can monitor the training process with these plots. After training, check the training results and the saved model.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/train-5.png?raw=true" height="300" width="400px" align="center">


### Advanced usage: Continual learning

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/train-6.png?raw=true" height="300" width="400px" align="center">





### (Optional) Upload your Model to Huggingface:

#### Step 1

Click the "Run" button and the Hugging Face login interface will appear.

#### Step 2

Find your token by clicking the link.

#### Step 3

Enter the token and click the "Login" button.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/upload-1.png?raw=true" height="300" width="500px" align="center">

<!-- ![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/upload-1-2.png?raw=true) -->

#### Step 4

Enter the model name, model description and other information.

#### Step 5

Click the button to upload the model. You can check your model via the link.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/upload-2.png?raw=true" height="300" width="700px" align="center">

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/upload-3.png?raw=true" height="300" width="500px" align="center">

## How to use your model for prediction





### Classification & Regression prediction task

#### Step 1

Complete the input and selection of Task Configs.

- `task_objective` describes the goal of your task, like sorting protein sequences into categories or predicting the values of some protein properties.
- `use_model_from` depends on whether you want to use a local model or a Huggingface model. If you choose `Shared by peers on SaprotHub`, please enter the Hugging Face model ID in the input box. If you choose `Local Model`, simply select your local model from the options. Additionally, there's a wide range of models available on SaprotHub.
- `data_type` indicates the kind of data you're using, which determines the dataset file you should upload. You can find more details about the formats for different types of data in the provided <a href="#data_format">instruction</a>.

<!-- ![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/cls_regr-1-1.png?raw=true) -->

#### Step 2

Click the run button to apply the configs.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/cls_regr-1.png?raw=true" height="300" width="500px" align="center">

#### Step 3

After clicking the "Run" button, additional input boxes and an upload button will appear.

Complete the input of additional information and upload files.

(Note: Do not click the "Run" button of the next cell before completing the input and upload.)

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/cls_regr-2.png?raw=true" height="300" width="400px" align="center">

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/cls_regr-3.png?raw=true" height="300" width="400px" align="center">

#### Step 4

Click the run button to start predicting, then check your results once the prediction finishes.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/cls_regr-4.png?raw=true" height="300" width="500px" align="center">

### Mutational effect prediction task

#### Step 1

Complete the selection of Task Configs.

- `mutation_task` indicates the type of mutation task. You can choose from `Single-site or Multi-site mutagenesis` and `Saturation mutagenesis`.
- `data_type` indicates the kind of data you're using, which determines the dataset file you should upload. You can find more details about the formats for different types of data in the provided <a href="#data_format">instruction</a>.
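
To make the difference between the two `mutation_task` options concrete: saturation mutagenesis covers every single-site substitution of a sequence, so the variants can be enumerated from the sequence alone. The sketch below lists them for a toy sequence. The notation used (wild-type residue, 1-based position, mutant residue) is a common convention and is an assumption here, not a confirmed ColabSaprot input format.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids


def saturation_mutations(sequence):
    """List every single-site substitution as wild type + 1-based position + mutant."""
    mutations = []
    for position, wild_type in enumerate(sequence, start=1):
        for mutant in AMINO_ACIDS:
            if mutant != wild_type:
                mutations.append(f"{wild_type}{position}{mutant}")
    return mutations


variants = saturation_mutations("MKT")  # toy sequence for illustration
print(len(variants))   # 57 = 3 positions x 19 substitutions
print(variants[:3])    # ['M1A', 'M1C', 'M1D']
```

A sequence of length N gives 19 × N variants, so long proteins produce large result sets.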


#### Step 2

Click the run button to apply the configs.



<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/mep-1.png?raw=true" height="300" width="800px" align="center">



#### Step 3

After clicking the "Run" button, additional input boxes and an upload button will appear.

For a single sequence, enter the sequence and the mutation information into the corresponding input fields. (Note that for Saturation mutagenesis, you won't see the Mutation input box.)

For multiple sequences, click the upload button to upload your dataset. (Note that for Saturation mutagenesis, you don’t need to provide mutation information in your dataset, which means only the `sequence` column is required in the .csv dataset.)

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/mep-2.png?raw=true" height="300" width="800px" align="center">
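
For the multi-sequence Saturation mutagenesis case, the dataset can be as small as a one-column .csv file. The sketch below writes such a file with Python's standard `csv` module. The file name and the two sequences are placeholders chosen for this example. The sequences you upload must still match the `data_type` you selected.

```python
import csv

# Placeholder sequences for illustration only.
sequences = ["MKTAYIAKQR", "MEEPQSDPSV"]

# Hypothetical file name; any .csv name works for the upload.
with open("saturation_dataset.csv", "w", newline="") as handle:
    writer = csv.writer(handle)
    writer.writerow(["sequence"])                 # the only required column
    writer.writerows([seq] for seq in sequences)  # one sequence per row

# Read the file back to confirm the layout before uploading it.
with open("saturation_dataset.csv", newline="") as handle:
    rows = list(csv.DictReader(handle))
print(rows[0]["sequence"])
```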



#### Step 4

Click the run button to start predicting, then check your results once the prediction finishes.

- For a single sequence, the predicted score will be shown in the output.

- For multiple sequences, the predicted scores will be saved in a .csv file.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/mep-3.png?raw=true" height="300" width="600px" align="center">






<!-- ![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-3-2.png?raw=true) -->


<!-- ![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-3-3.png?raw=true)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-3-4.png?raw=true) -->

### Inverse folding task

#### Step 1

Click the run button to upload the structure file, which can be in .pdb or .cif format.
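
If you prepare many structure files, a quick local check of the file extension can catch mistakes before uploading. This is a minimal sketch of such a check and not part of ColabSaprot. The function name is hypothetical, and treating the extension as case-insensitive is an assumption.

```python
from pathlib import Path

ACCEPTED_SUFFIXES = {".pdb", ".cif"}  # formats named in this step


def is_accepted_structure_file(filename):
    """Return True when the file name ends in .pdb or .cif (case-insensitive)."""
    return Path(filename).suffix.lower() in ACCEPTED_SUFFIXES


print(is_accepted_structure_file("protein.pdb"))
print(is_accepted_structure_file("protein.pdb.gz"))
```

Compressed files such as `protein.pdb.gz` fail this check because only the final extension is inspected.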



#### Step 2

After clicking the "Run" button, additional input boxes and an upload button will appear.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/if-1.png?raw=true" height="300" width="500px" align="center">



#### Step 3

After you upload the structure file, it will be converted into an amino acid (AA) sequence and a structure sequence.

Use '#' to mask some amino acids for prediction.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/if-2.png?raw=true" height="300" width="800px" align="center">
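
Masking by hand is error-prone for long sequences, so a small helper can be useful. The sketch below replaces the residues at chosen positions with '#'. The function name and the use of 1-based positions are assumptions made for this example. In the notebook itself you simply edit the AA sequence in the input box.

```python
def mask_residues(aa_sequence, positions):
    """Replace residues at the given 1-based positions with '#'."""
    residues = list(aa_sequence)
    for position in positions:
        if not 1 <= position <= len(residues):
            raise ValueError(f"position {position} is outside the sequence")
        residues[position - 1] = "#"
    return "".join(residues)


print(mask_residues("MKTAYIAK", [2, 5, 6]))  # M#TA##AK
```

The masked string keeps the original length, so each '#' stays aligned with its position in the structure sequence.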

#### Step 4

Choose the prediction method.

#### Step 5

Click the run button to make the prediction.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/if-3.png?raw=true" height="300" width="1000px" align="center">

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/v1/if-4.png?raw=true" height="300" width="600px" align="center">
