# **SaProt Hub: Collaborative Protein Language Modeling with ColabSaProt**

This is the Colab version of [SaProt](https://github.com/westlake-repl/SaProt), a pre-trained protein language model designed for various downstream protein tasks. Our aim is to make SaProt more accessible and user-friendly for biologists, enabling effortless model training and knowledge sharing within the scientific community.

We hope this platform can contribute to advancing biological research, fostering collaboration, and accelerating discoveries in the field. You can access [our paper](https://www.biorxiv.org/content/10.1101/2023.10.01.560349v2) for further details.





## Content

<img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/Outline.png" height="256" align="center" style="height:256px">

<font color=red>**To view the content, please click on the first option in the left sidebar.**</font>


# **1: Installation**


In [None]:
#@title 1.1: Click the run button ▶️ to install SaProt

#@markdown (Installation takes about 2-8 minutes, please wait...)

# Working directories used by the rest of the notebook (see the *_HOME paths below).
!mkdir -p /content/saprot/LMDB
!mkdir -p /content/saprot/bin
# !mkdir -p /content/saprot/tmp/af2_structures/
# !mkdir -p /content/saprot/weights
!mkdir -p /content/saprot/output
!mkdir -p /content/saprot/adapters/classification
!mkdir -p /content/saprot/adapters/regression
!mkdir -p /content/saprot/adapters/token_classification
!mkdir -p /content/saprot/structures
# !mkdir -p /content/saprot/training_monitor

################################################################################
########################### install saprot #####################################
################################################################################
# !pip install lmdb
# !pip install transformers==4.28.0 --quiet
# !pip install loguru --quiet
# !pip install multiprocess --quiet

# gdown is pinned to 4.6.3 (pip accepts the leading "v"); presumably required for
# the `--folder` download below — confirm before upgrading.
!pip install gdown==v4.6.3 --force-reinstall --quiet
# Download the ColabSaProtSetup folder from Google Drive and, only if that
# succeeds (`&&`), install the bundled saprot 0.4.3 wheel from it.
!gdown https://drive.google.com/drive/folders/1ECKe5clJXs4POlScVggRQDrFo5HJpGBN?usp=drive_link -O /content/saprot/ --folder  --quiet && pip install /content/saprot/ColabSaProtSetup/saprot-0.4.3-py3-none-any.whl --quiet
!chmod +x /content/saprot/ColabSaProtSetup/foldseek

# Move the bundled example upload files and datasets into /content/saprot
# (--remove-source-files deletes the copied files from the setup folder) and put
# the foldseek binary where FOLDSEEK_PATH expects it.
!rsync -a --remove-source-files /content/saprot/ColabSaProtSetup/upload_files /content/saprot
!rsync -a --remove-source-files /content/saprot/ColabSaProtSetup/datasets /content/saprot
!mv /content/saprot/ColabSaProtSetup/foldseek /content/saprot/bin/

################################################################################
################################################################################
################################## global ######################################
################################################################################
################################################################################

import ipywidgets
from google.colab import widgets
from pathlib import Path
import pandas as pd
import torch
import copy
import os
from tqdm import tqdm
from datetime import datetime
from google.colab import files
import zipfile
from loguru import logger

import yaml
import argparse

from easydict import EasyDict
from datetime import datetime

# from saprot.utils.others import setup_seed
# from saprot.utils.module_loader import *
# from saprot.model.esm.esm_classification_model import EsmClassificationModel
# from saprot.model.esm.esm_regression_model import EsmRegressionModel
# from saprot.dataset.esm.esm_classification_dataset import EsmClassificationDataset
# from saprot.dataset.esm.esm_regression_dataset import EsmRegressionDataset

# Fixed locations inside the Colab runtime used throughout the notebook.
DATASET_HOME = Path("/content/saprot/datasets")
ADAPTER_HOME = Path("/content/saprot/adapters")
STRUCTURE_HOME = Path("/content/saprot/structures")
LMDB_HOME = Path("/content/saprot/LMDB")
OUTPUT_HOME = Path("/content/saprot/output")
FOLDSEEK_PATH = Path("/content/saprot/bin/foldseek")

# The 20 standard amino acids as one-letter codes.
aa_set = set("ACDEFGHIKLMNPQRSTVWY")
# Structure tokens (presumably foldseek's 3Di alphabet, lower case) followed by
# "#", the token paired with residues whose structure is unknown.
foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"

# Human-readable task description (form option) -> internal task type.
task_type_dict = {
    "Classify protein sequences (classification)": "classification",
    "Classify each Amino Acid (token classification), e.g. Binding site detection": "token_classification",
    "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein": "regression",
}

# Internal task type -> saprot model / dataset module path ("esm/esm_<task>_model|dataset").
model_type_dict = {
    task: f"esm/esm_{task}_model"
    for task in ("classification", "token_classification", "regression")
}
dataset_type_dict = {
    task: f"esm/esm_{task}_dataset"
    for task in ("classification", "token_classification", "regression")
}
class font:
    """ANSI escape codes for coloured / styled text in notebook output.

    Usage: ``print(font.RED + font.BOLD + "text" + font.RESET)``.
    """

    # Bright foreground colours.
    RED, GREEN, YELLOW, BLUE = "\x1b[91m", "\x1b[92m", "\x1b[93m", "\x1b[94m"

    # Text attributes.
    BOLD, UNDERLINE = "\x1b[1m", "\x1b[4m"

    # Back to the terminal's default formatting.
    RESET = "\x1b[0m"

################################################################################
################################################################################
################################## DATASET #####################################
################################################################################
################################################################################

################################################################################
############################# dataset list #####################################
################################################################################
def get_datasets_list(directory_path=None):
    """Return the dataset files stored directly in ``directory_path``.

    Args:
        directory_path: folder to scan; defaults to ``DATASET_HOME``.

    Returns:
        ``Path`` objects of the regular files (sub-folders are ignored),
        sorted by path. ``Path.iterdir`` yields entries in arbitrary order,
        and the numeric IDs shown to the user are positions in this list, so
        sorting keeps those IDs stable between calls.
    """
    directory_path = DATASET_HOME if directory_path is None else Path(directory_path)
    return sorted(
        file_path for file_path in directory_path.iterdir() if file_path.is_file()
    )

def show_datasets_info(datasets_list):
  """Render an ID / name / path table of ``datasets_list`` as a Colab grid widget.

  Row 0 holds the headers; row ``n + 1`` describes ``datasets_list[n]``.
  """
  headers = ("ID", "Dataset", "Dataset Path")
  table = widgets.Grid(len(datasets_list) + 1, len(headers), header_row=True, header_column=True)

  for col, title in enumerate(headers):
    with table.output_to(0, col):
      print(title)

  for row, dataset_path in enumerate(datasets_list, start=1):
    cells = (row - 1, dataset_path.stem, dataset_path)
    for col, value in enumerate(cells):
      with table.output_to(row, col):
        print(value)

def datasets_dropdown(datasets_list):
  """Display and return an unselected 500px-wide dropdown over ``datasets_list``.

  Each option is labelled "<index>. <file stem>", where <index> is the
  position of the file in ``datasets_list``.
  """
  option_labels = []
  for position, dataset_file in enumerate(datasets_list):
    option_labels.append(f"{position}. {dataset_file.stem}")

  selector = ipywidgets.Dropdown(options=option_labels, value=None, disabled=False)
  selector.layout.width = "500px"
  display(selector)
  return selector

def select_dataset():
  """Print the table of existing datasets and return a dropdown to pick one."""
  separator = "=" * 100
  available = get_datasets_list()

  print(f"{font.RED}{font.BOLD}Existing Datasets:{font.RESET}")
  print(separator)
  show_datasets_info(available)
  print(separator)

  return datasets_dropdown(available)

################################################################################
############################# adapter list #####################################
################################################################################

def get_adapters_list(directory_path=None):
    """Return the locally stored adapter folders.

    Adapters live one level below ``directory_path`` in the per-task folders
    ``classification``, ``regression`` and ``token_classification`` (in that
    order). Hidden folders (name starting with ".", e.g.
    ``.ipynb_checkpoints``) and plain files are ignored; a missing task
    folder is skipped instead of raising.

    Args:
        directory_path: adapter root; defaults to ``ADAPTER_HOME``.

    Returns:
        ``Path`` objects of the adapter folders, grouped by task and sorted by
        name within each group so the displayed IDs stay stable between calls
        (``Path.iterdir`` yields entries in arbitrary order).
    """
    directory_path = ADAPTER_HOME if directory_path is None else Path(directory_path)

    adapter_list = []
    for task_folder in ("classification", "regression", "token_classification"):
        task_dir = directory_path / task_folder
        if not task_dir.is_dir():
            continue
        adapter_list.extend(sorted(
            path for path in task_dir.iterdir()
            if path.is_dir() and not path.name.startswith('.')
        ))

    return adapter_list

def show_adapters_info(adapters_list):
  """Render an ID / name / path table of ``adapters_list`` as a Colab grid widget.

  Row 0 holds the headers; row ``n + 1`` describes ``adapters_list[n]``.
  """
  columns = ("ID", "Local Adapter", "Adapter Path")
  table = widgets.Grid(len(adapters_list) + 1, len(columns), header_row=True, header_column=True)

  for col, title in enumerate(columns):
    with table.output_to(0, col):
      print(title)

  for row, adapter_path in enumerate(adapters_list, start=1):
    cells = (row - 1, adapter_path.stem, adapter_path)
    for col, value in enumerate(cells):
      with table.output_to(row, col):
        print(value)

# def adapters_dropdown(adapters_list):
#   dropdown = ipywidgets.Dropdown(
#       options=[f"{index}. {file.stem}" for index, file in enumerate(adapters_list)],
#       value=None,
#       description='Selected:',
#       disabled=False,)
#   dropdown.layout.width = "500px"
#   display(dropdown)

#   return dropdown

def adapters_combobox(adapters_list):
  """Display and return a 500px-wide combobox for choosing an adapter.

  The suggested options are "<index>. <adapter folder name>" for each entry
  of ``adapters_list``. Being a combobox, it also accepts free text; the
  placeholder invites a Hugging Face repo id.
  """
  combobox = ipywidgets.Combobox(
    options=[f"{index}. {adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=None,
    placeholder='Enter a huggingface repo_id or select a local adapter here',
    disabled=False)
  combobox.layout.width = '500px'
  display(combobox)

  return combobox

def select_adapter():
  """Print the table of local adapters and return a combobox to pick one."""
  separator = "=" * 100
  available = get_adapters_list()

  print(f"{font.RED}{font.BOLD}Existing Adapters:{font.RESET}")
  print(separator)
  show_adapters_info(available)
  print(separator)

  return adapters_combobox(available)

################################################################################
########################### download dataset ###################################
################################################################################
def download_dataset(task_name):
  """Download the LMDB archive for ``task_name`` and extract it into ``LMDB_HOME``.

  The archive is fetched from Google Drive with gdown to
  ``LMDB_HOME/<task_name>.tar.gz``. Some tasks share one archive
  (DeepLoc_cls2/DeepLoc_cls10 and GO_BP/GO_CC/GO_MF use the same link).

  Raises:
    RuntimeError: if ``task_name`` is unknown, or if downloading/extracting
      fails (the original error is chained as ``__cause__``).
  """
  import tarfile

  download_links = {
    "ClinVar" : "https://drive.google.com/uc?id=1Le6-v8ddXa1eLJZFo7HPij7NhaBmNUbo",
    "DeepLoc_cls2" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "DeepLoc_cls10" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "EC" : "https://drive.google.com/uc?id=1VFLFA-jK1tkTZBVbMw8YSsjZqAqlVQVQ",
    "GO_BP" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_CC" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_MF" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "HumanPPI" : "https://drive.google.com/uc?id=1ahgj-IQTtv3Ib5iaiXO_ASh2hskEsvoX",
    "MetalIonBinding" : "https://drive.google.com/uc?id=1rwknPWIHrXKQoiYvgQy4Jd-efspY16x3",
    "ProteinGym" : "https://drive.google.com/uc?id=1L-ODrhfeSjDom-kQ2JNDa2nDEpS8EGfD",
    "Thermostability" : "https://drive.google.com/uc?id=1I9GR1stFDHc8W3FCsiykyrkNprDyUzSz",
  }

  # Reject unknown names up front with a precise message instead of letting the
  # KeyError turn into the generic "not prepared" error below.
  if task_name not in download_links:
    raise RuntimeError(
      f"Unknown dataset {task_name!r}; available datasets: {', '.join(download_links)}")

  import gdown

  LMDB_HOME.mkdir(parents=True, exist_ok=True)
  filepath = LMDB_HOME / f"{task_name}.tar.gz"
  try:
    gdown.download(download_links[task_name], str(filepath), quiet=False)
    with tarfile.open(filepath, 'r:gz') as tar:
      # The "data" filter (Python 3.12+ and security backports) refuses members
      # that would be written outside LMDB_HOME; fall back where unavailable.
      if hasattr(tarfile, "data_filter"):
        tar.extractall(path=str(LMDB_HOME), filter="data")
      else:
        tar.extractall(path=str(LMDB_HOME))
      print(f"Extracted: {filepath}")
  except Exception as e:
    raise RuntimeError("The dataset has not prepared.") from e

################################################################################
############################# upload file ######################################
################################################################################
def upload_file(upload_path):
  """Prompt for a browser upload (google.colab ``files.upload``) into ``upload_path``.

  Args:
    upload_path: destination directory; created if it does not exist.

  Returns:
    Path of the first uploaded file inside ``upload_path``. If several files
    are chosen, all of them are moved but only the first is returned.

  Raises:
    RuntimeError: if nothing was uploaded (e.g. the user cancelled).
    Any error from the upload or the move is logged and re-raised unchanged.
  """
  import shutil

  upload_path = Path(upload_path)
  upload_path.mkdir(parents=True, exist_ok=True)
  # The uploaded files are picked up from the current working directory.
  basepath = Path().resolve()
  try:
    uploaded = files.upload()
    if not uploaded:
      logger.info("The uploading process has been interrupted by the user.")
      raise RuntimeError("The uploading process has been interrupted by the user.")
    filenames = list(uploaded)
    for filename in filenames:
      shutil.move(basepath / filename, upload_path / filename)
  except Exception:
    logger.error("Upload file fail! Please click the button to run again.")
    raise

  return upload_path / filenames[0]

################################################################################
############################ upload dataset ####################################
################################################################################
def upload_dataset(data_type):
  """Ask the user to upload a dataset .csv and convert it to SA sequences.

  The uploaded file is stored in /content/saprot/upload_files. The converted
  table is written by ``get_SASequence_by_data_type`` to
  ``DATASET_HOME / "[DATASET]<uploaded file stem>.csv"``.

  Args:
    data_type: how the "Sequence" column is to be interpreted (passed to
      ``get_SASequence_by_data_type``).

  Returns:
    Path of the converted dataset .csv.

  Raises:
    RuntimeError: if the uploaded file is not a .csv file. Errors from the
      upload or the conversion propagate.
  """
  print(font.RED+font.BOLD+"Please upload the .csv file"+font.RESET)

  # Absolute path (same folder as the structure .zip upload) rather than one
  # relative to the current working directory.
  upload_path = Path("/content/saprot/upload_files")
  dataset_csv_path = upload_file(upload_path)
  # Mirror the .zip check for structure uploads: fail early with a clear
  # message instead of a confusing pandas parse error.
  if dataset_csv_path.suffix.lower() != ".csv":
    logger.error("The data type does not match. Please click the run button again to upload a .csv file!")
    raise RuntimeError("The data type does not match.")
  print(font.RED+font.BOLD+"Successfully upload your .csv file!"+font.RESET)
  print("="*100)

  saseq_csv_path = DATASET_HOME / f"[DATASET]{dataset_csv_path.stem}.csv"
  get_SASequence_by_data_type(data_type, dataset_csv_path, saseq_csv_path)
  print()
  print("="*100)
  print(font.RED+font.BOLD+"Successfully upload your dataset!"+font.RESET)

  return saseq_csv_path

################################################################################
########################## Download predicted structures #######################
################################################################################
def uniprot2pdb(uniprot_ids, nprocess=20):
  """Download predicted structures in .pdb format for ``uniprot_ids`` into STRUCTURE_HOME.

  Uses saprot's ``AlphaDBDownloader`` (presumably fetching from the AlphaFold
  DB, as its name suggests — confirm against saprot.utils.downloader).

  Args:
    uniprot_ids: list of UniProt accessions.
    nprocess: number of download processes, forwarded as ``n_process``.
  """
  from saprot.utils.downloader import AlphaDBDownloader

  os.makedirs(STRUCTURE_HOME, exist_ok=True)
  af2_downloader = AlphaDBDownloader(uniprot_ids, "pdb", save_dir=STRUCTURE_HOME, n_process=nprocess)
  af2_downloader.run()


################################################################################
############### Form foldseek sequences by multiple processes ##################
################################################################################
def pdb2sequence(process_id, idx, uniprot_id, writer):
  """Encode one structure with foldseek and write "<uniprot_id>\\t<sequence>\\n".

  Looks for ``STRUCTURE_HOME/<uniprot_id>.pdb`` first and falls back to
  ``STRUCTURE_HOME/<uniprot_id>.cif``. Only chain "A" is encoded, and the
  last element ``get_struc_seq`` returns for it is written (presumably the
  combined structure-aware sequence — confirm in saprot.utils.foldseek_util).

  Errors are printed and swallowed so one bad structure does not abort the
  whole multi-process run; such proteins produce no output line.

  Args:
    process_id: worker id, forwarded to ``get_struc_seq``.
    idx: position of the item in the job list (unused).
    uniprot_id: structure name (UniProt ID or file stem).
    writer: output sink with a ``write`` method (supplied by the runner).
  """
  from saprot.utils.foldseek_util import get_struc_seq

  try:
    structure_path = Path(STRUCTURE_HOME) / f"{uniprot_id}.pdb"
    # exists() must be called: the bare bound method is always truthy, which
    # would make the .cif fallback unreachable.
    if not structure_path.exists():
      structure_path = Path(STRUCTURE_HOME) / f"{uniprot_id}.cif"
      if not structure_path.exists():
        raise FileNotFoundError(f"no .pdb or .cif file for {uniprot_id} in {STRUCTURE_HOME}")

    seq = get_struc_seq(FOLDSEEK_PATH, str(structure_path), ["A"], process_id=process_id)["A"][-1]
    writer.write(f"{uniprot_id}\t{seq}\n")
  except Exception as e:
    print(f"Error: {uniprot_id}, {e}")


################################################################################
################## Form SA sequences by uniprotis/pdb/cif ######################
################################################################################
def _aa_to_sa_sequence(aa_seq):
  """Pair every amino acid with the '#' structure token, e.g. "MK" -> "M#K#"."""
  return "".join(aa + "#" for aa in aa_seq)


def _collect_struc_seqs(protein_list):
  """Run ``pdb2sequence`` over ``protein_list`` and return {protein name: sequence}.

  Proteins that ``pdb2sequence`` failed on (it prints the error and writes
  nothing) are simply absent from the returned dict.
  """
  from saprot.utils.mpr import MultipleProcessRunnerSimplifier

  mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=1, return_results=True)
  seqs_dict = {}
  for line in mprs.run():
    # Each result is "<name>\t<sequence>"; drop the trailing newline if the
    # runner keeps it.
    key, value = line.rstrip("\n").split("\t", 1)
    seqs_dict[key] = value
  return seqs_dict


def _lookup_struc_seqs(seqs_dict, names):
  """Return the sequences for ``names`` in order; raise RuntimeError if any is missing."""
  missing = [name for name in names if name not in seqs_dict]
  if missing:
    raise RuntimeError(
      f"Failed to build structure-aware sequences for: {', '.join(map(str, missing))}")
  return [seqs_dict[name] for name in names]


def _extract_structures(zip_path, target_dir):
  """Extract the .pdb/.cif members of ``zip_path`` flat into ``target_dir``.

  Returns the stems of the extracted files; ``pdb2sequence`` looks them up as
  ``<STRUCTURE_HOME>/<stem>.pdb`` or ``.cif``.
  """
  target_dir = Path(target_dir)
  target_dir.mkdir(parents=True, exist_ok=True)
  stems = []
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    for member in zip_ref.infolist():
      member_path = Path(member.filename)
      suffix = member_path.suffix.lower()
      # Skip folders, non-structure files and macOS metadata (__MACOSX/, ._*).
      if (member.is_dir() or suffix not in (".pdb", ".cif")
          or "__MACOSX" in member_path.parts or member_path.name.startswith("._")):
        continue
      # Writing by base name flattens zipped sub-folders (pdb2sequence only
      # looks directly in the structure folder) and keeps "../" names from
      # escaping target_dir. The suffix is lower-cased to match that lookup.
      (target_dir / f"{member_path.stem}{suffix}").write_bytes(zip_ref.read(member))
      stems.append(member_path.stem)
  return stems


def get_SASequence_by_data_type(data_type, csv_file_path, seq_file_path):
  """Convert the "Sequence" column of a dataset .csv into structure-aware (SA) sequences.

  The .csv at ``csv_file_path`` is read with pandas; the converted table is
  written to ``seq_file_path`` without the index. ``data_type`` selects how
  "Sequence" is interpreted:

  - "Structure Aware Sequence": already SA sequences; the table is copied.
  - "Amino Acid Sequence": each residue is paired with "#" ("MK" -> "M#K#").
  - "UniProt ID": structures are downloaded with ``uniprot2pdb`` and encoded
    with ``pdb2sequence``.
  - "PDB/CIF file": "Sequence" holds structure file names; the user is asked
    to upload a .zip with those .pdb/.cif files, which are then encoded with
    ``pdb2sequence``.

  Raises:
    RuntimeError: for an unknown ``data_type``, a non-.zip structure upload,
      or when a structure could not be converted.
  """
  protein_df = pd.read_csv(csv_file_path)

  if data_type == "Structure Aware Sequence":
    protein_df.to_csv(seq_file_path, index=None)
    return

  if data_type == "Amino Acid Sequence":
    protein_df['Sequence'] = protein_df['Sequence'].map(_aa_to_sa_sequence)
    protein_df.to_csv(seq_file_path, index=None)
    return

  if data_type == "UniProt ID":
    # The documented format keeps the IDs in "Sequence"; fall back to the
    # first column for files without that header.
    id_column = protein_df['Sequence'] if 'Sequence' in protein_df.columns else protein_df.iloc[:, 0]
    protein_list = [str(uid) for uid in id_column]
    uniprot2pdb(protein_list)
    seqs_dict = _collect_struc_seqs(protein_list)
    # Match results by ID, not by position: failed proteins produce no output.
    protein_df['Sequence'] = _lookup_struc_seqs(seqs_dict, protein_list)

  elif data_type == "PDB/CIF file":
    # upload and unzip PDB file
    print(font.RED+font.BOLD+"Please upload your .zip file that contains .pdb/.cif files"+font.RESET)
    pdb_zip_path = upload_file(Path("/content/saprot/upload_files"))
    if pdb_zip_path.suffix != ".zip":
      logger.error("The data type does not match. Please click the run button again to upload a .zip file!")
      raise RuntimeError("The data type does not match.")
    print(font.RED+font.BOLD+"Successfully upload your .zip file!"+font.RESET)
    print("="*100)

    protein_list = _extract_structures(pdb_zip_path, STRUCTURE_HOME)
    seqs_dict = _collect_struc_seqs(protein_list)
    # The .csv lists file names ("1abc.pdb"); results are keyed by stem ("1abc").
    names = [Path(str(value)).stem for value in protein_df['Sequence']]
    protein_df['Sequence'] = _lookup_struc_seqs(seqs_dict, names)

  else:
    raise RuntimeError("Wrong data type!")

  protein_df.to_csv(seq_file_path, index=None)

################################################################################
############### Form Single SA sequences by uniprotis/pdb/cif ##################
################################################################################
def get_single_SASequence_by_data_type(data_type, raw_data):
  """Convert a single protein into a structure-aware (SA) sequence.

  Args:
    data_type: "Structure Aware Sequence", "Amino Acid Sequence",
      "UniProt ID" or "PDB/CIF file".
    raw_data: the SA sequence, the amino-acid sequence, the UniProt ID, or
      (for "PDB/CIF file") the structure name. A single name or file name is
      wrapped into a one-element list; any other value is passed to the
      runner unchanged (presumably a list of names — confirm with callers).
      The structure is expected at STRUCTURE_HOME/<name>.pdb or .cif.

  Returns:
    The SA sequence as a string.

  Raises:
    RuntimeError: for an unknown ``data_type`` or when no sequence could be
      built for the structure.
  """
  if data_type == "Structure Aware Sequence":
    # Already SA-encoded; consistent with get_SASequence_by_data_type.
    return raw_data

  if data_type == "Amino Acid Sequence":
    # No structure information: pair every residue with "#".
    return "".join(aa + "#" for aa in raw_data)

  if data_type == "UniProt ID":
    protein_list = [raw_data]
    uniprot2pdb(protein_list)
  elif data_type == "PDB/CIF file":
    if isinstance(raw_data, (str, Path)):
      # A bare string would otherwise be iterated character by character.
      structure = Path(raw_data)
      name = structure.stem if structure.suffix.lower() in (".pdb", ".cif") else structure.name
      protein_list = [name]
    else:
      protein_list = raw_data
  else:
    raise RuntimeError("wrong data type!")

  # Imported lazily: only the structure-based paths need saprot.
  from saprot.utils.mpr import MultipleProcessRunnerSimplifier

  mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=1, return_results=True)
  seqs = mprs.run()
  if not seqs:
    # pdb2sequence prints its error and writes nothing when encoding fails.
    raise RuntimeError(f"Failed to build the structure-aware sequence for {raw_data}")
  return seqs[0].rstrip("\n").split('\t')[1]

  # print()
  # print("="*100)
  # print(font.RED + font.BOLD + "The Structure-Aware Sequence is here, double click to select and copy it:" + font.RESET)
  # print(seqs[0].split('\t')[1])

# ################################################################################
# ################################################################################
# ################################ FINETUNE ######################################
# ################################################################################
# ################################################################################


# ################################################################################
# ################################ load model ####################################
# ################################################################################
# def load_model(config):
#     model_config = copy.deepcopy(config)
#     model_type = model_config.pop("model_py_path")
#     kwargs = model_config.pop('kwargs')
#     model_config.update(kwargs)
#     print(model_config)
#     if model_type == "esm/esm_classification_model":
#       return EsmClassificationModel(**model_config)
#     if model_type == "esm/esm_regression_model":
#       if 'num_labels' in model_config.keys():
#         model_config.pop("num_labels")
#       return EsmRegressionModel(**model_config)

# ################################################################################
# ################################ load dataset ##################################
# ################################################################################
# def load_dataset(config):
#     dataset_config = copy.deepcopy(config)
#     dataset_type = dataset_config.pop("dataset_py_path")
#     kwargs = dataset_config.pop('kwargs')
#     dataset_config.update(kwargs)

#     if dataset_type == "esm/esm_classification_dataset":
#       return EsmClassificationDataset(**dataset_config)
#     if dataset_type == "esm/esm_regression_dataset":
#       return EsmRegressionDataset(**dataset_config)

# ################################################################################
# ################################## finetune ####################################
# ################################################################################
# def finetune(config, run_mode):
#     if config.setting.seed:
#         setup_seed(config.setting.seed)

#     for k, v in config.setting.os_environ.items():
#         if v is not None and k not in os.environ:
#             os.environ[k] = str(v)

#         elif k in os.environ:
#             config.setting.os_environ[k] = os.environ[k]

#     if config.setting.os_environ.NODE_RANK != 0:
#         config.Trainer.logger = False

#     ############################################################################
#     model = load_model(config.model)
#     data_module = load_dataset(config.dataset)
#     trainer = load_trainer(config)


#     trainer.fit(model=model, datamodule=data_module)

#     ############################################################################
#     if model.save_path is not None:
#         if config.model.kwargs.get("use_lora", False):
#             # Load LoRA model
#             config.model.kwargs.lora_config_path = model.save_path
#             config.model.kwargs.lora_inference = True
#             model = load_model(config.model)
#         else:
#             model.load_checkpoint(model.save_path, load_prev_scheduler=model.load_prev_scheduler)


#     trainer.test(model=model, datamodule=data_module)


################################################################################
################################################################################
################################ INFERENCE #####################################
################################################################################
################################################################################

################################################################################
################################ zeroshot func #################################
################################################################################
# def zeroshot(mutation_task, seq, mut_info):
#   with torch.no_grad():
#     # single-site / multi-site
#     if mutation_task == "Single-site or Multi-site mutagenesis":
#       tokens = tokenizer.tokenize(seq)
#       for single in mut_info.split(":"):
#           pos = int(single[1:-1])
#           tokens[pos - 1] = "#" + tokens[pos - 1][-1]

#       mask_seq = " ".join(tokens)
#       inputs = tokenizer(mask_seq, return_tensors="pt")
#       inputs = {k: v.to(device) for k, v in inputs.items()}

#       outputs = model(**inputs)
#       logits = outputs.logits
#       probs = logits.softmax(dim=-1)

#       score = 0
#       for single in mut_info.split(":"):
#           ori_aa, pos, mut_aa = single[0], int(single[1:-1]), single[-1]
#           ori_st = tokenizer.get_vocab()[ori_aa + foldseek_struc_vocab[0]]
#           mut_st = tokenizer.get_vocab()[mut_aa + foldseek_struc_vocab[0]]

#           ori_prob = probs[0, pos, ori_st: ori_st + len(foldseek_struc_vocab)].sum()
#           mut_prob = probs[0, pos, mut_st: mut_st + len(foldseek_struc_vocab)].sum()

#           score += torch.log(mut_prob / ori_prob)
#       # print(f"The score of mutation {mut_info} is {font.RED}{score.item()}{font.RESET}")

#       return score.item()

#     # Saturation
#     if mutation_task == "Saturation mutagenesis":
#       scores = []

#       ori_seq = [seq[i:i+2] for i in range(0, len(seq), 2)]

#       for pos in tqdm(range(1, len(ori_seq)+1)):
#         mask_seq = ori_seq.copy()
#         mask_seq[pos-1] = "#" + ori_seq[pos-1][-1]
#         mask_seqs = []
#         mask_seqs.append(" ".join(mask_seq))

#         mask_inputs = tokenizer.batch_encode_plus(mask_seqs, return_tensors="pt", padding=True)
#         mask_inputs = {k: v.to(device) for k, v in mask_inputs.items()}
#         mask_outputs = model(**mask_inputs)
#         mask_probs = mask_outputs['logits'].softmax(dim=-1)

#         ori_aa = ori_seq[pos-1][0]
#         ori_st = tokenizer.get_vocab()[ori_aa + foldseek_struc_vocab[0]]
#         ori_prob = mask_probs[0, pos, ori_st: ori_st + len(foldseek_struc_vocab)].sum()

#         for mut_aa in aa_set:
#           pred = 0
#           mut_st = tokenizer.get_vocab()[mut_aa + foldseek_struc_vocab[0]]
#           mut_prob = mask_probs[0, pos, mut_st: mut_st + len(foldseek_struc_vocab)].sum()
#           pred += torch.log(mut_prob / ori_prob)
#           # print(f"The score of mutation {ori_aa}{pos}{mut_aa} is {font.RED}{pred.item()}{font.RESET}")
#           scores.append(pred.item())

#       return scores

# ################################################################################
# ###################### load model for inference ################################
# ################################################################################
# def load_model_inference(model_path):
#   from transformers import EsmTokenizer, EsmForMaskedLM
#   # from saprot.utils.constants import aa_set， foldseek_struc_vocab

#   # model_path = "westlake-repl/SaProt_35M_AF2"
#   tokenizer = EsmTokenizer.from_pretrained(model_path)
#   model = EsmForMaskedLM.from_pretrained(model_path)

#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device)

# Status message printed at the end of the installation cell.
print("Installation finished!")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.6/62.6 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.3/78.3 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m147.9/147.9 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.1/142.1 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.8/66.8 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.1/121.1 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of th

# **2: Dataset**

In [None]:
#@title 2.1: Upload your protein sequences dataset

# #@markdown <font color="red">Before clicking the run button to upload your dataset, please carefully review this section to learn about the format of data you should prepare.</font>

# #@markdown | Sequence | label | stage|
# #@markdown | --- | --- | --- |
# #@markdown | (depends on data type) | (a number) | (train/valid/test) |
#
# #@markdown <br>
#
# #@markdown #### Here are the meanings of these three columns:
# #@markdown - `Sequence`: The content of this column depends on the type of data you have:
# #@markdown  - For `Structure Aware Sequence`: The "Sequence" column should contain **SA(Structure-Aware) sequence**
# #@markdown  - For `AA Sequence`: The "Sequence" column should contain **Amino Acid Sequence**
# #@markdown  - For `UniProt ID`: The "Sequence" column should contain **UniProt ID**
# #@markdown  - For `PDB/CIF file`: The "Sequence" column should contain **the filenames of your .pdb/.cif file** (<font color = "red">**Note that**</font>: If your data type is 'PDB/CIF file', you need to upload an additional .zip file containing your .pdb/.cif file after uploading the .csv file)
# #@markdown - `label`:
# #@markdown  - For **classification task**, the values in this column should represent the index of your categories (integers ranging from zero to the number of categories minus one).
# #@markdown  - For **token classification task**, the values in this column should represent a list of category index for each amino acid.
# #@markdown  - For **regression task**, the values in this column should be numerical.
#
#
# #@markdown <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/LabelFormat.png" height="400" align="center" style="height:256px">
#
# #@markdown - `stage`: The values in this column(**train/valid/test**) determine whether the sample is used for training, validation, or testing. (<font color = "red">**Note that**</font>: Ensure that your dataset includes samples of all three types(train/valid/test).)
# # The dataset uploaded by the user must include samples for validation purposes. The model saves checkpoints based on validation metrics. If the validation set is empty, the checkpoints will not be saved.
# #@markdown <br>
#
# # #@markdown ###  The following illustration displays the specific file format and its contents for three data types:
#
# #@markdown <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/DatasetFormat.png" height="256" align="center" style="height:256px">
# #@markdown  <font color="red">**Note that**</font>: You can find some examples at /content/saprot/upload_files. Download them to review the format and upload them for a try.
#
# #@markdown <br>


#@markdown **Before uploading your dataset, please carefully review this section to understand the required data format.**
#@markdown
#@markdown ### You need to upload **a .csv file** as your dataset, where the columns should be named as `Sequence`, `label`, and `stage`.
##@markdown ### You need to upload a dataset file in the .csv format. Ensure that the columns are named as `Sequence`, `label`, and `stage`.
#@markdown
##@markdown | Sequence | label | stage|
##@markdown | --- | --- | --- |
##@markdown | (depends on data type) | (a number) | (train/valid/test) |
##@markdown
##@markdown #### Explanation of Columns:
#@markdown - `Sequence`: The content of this column depends on your **data type**:
#@markdown   - For `Structure Aware Sequence`: Provide the "Sequence" column with SA (Structure-Aware) sequences.
#@markdown   - For `Amino Acid Sequence`: The "Sequence" column should contain amino acid sequences.
#@markdown   - For `UniProt ID`: Input UniProt IDs into the "Sequence" column.
#@markdown   - For `PDB/CIF file`: Enter filenames of your .pdb/.cif files into the "Sequence" column. (Note: For this data type, upload an additional .zip file containing your .pdb/.cif files after uploading the .csv file)
#@markdown - `label`: The content of this column depends on your **task type**:
#@markdown   - For `classification tasks`: Use integers from 0 to (number of categories - 1) to represent the categories in this column.
#@markdown   - For `token classification tasks`: Provide a list of category indices for each amino acid in this column.
#@markdown   - For `regression tasks`: Input numerical values into this column.
#@markdown
#@markdown
#@markdown - `stage`: This column should indicate whether the sample is for training, validation, or testing (train/valid/test). Ensure your dataset includes samples for all three stages.
#@markdown
# #@markdown ### The illustration below shows the specific file format and contents for three data types:
#@markdown

#@markdown <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/DatasetFormat.png" height="400" width="700px" align="left">
#@markdown <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/LabelFormat.png" height="400" width="550px" align="left">

#@markdown
#@markdown **Note:** Examples are available at /content/saprot/upload_files. Download to review their format, and then upload them for a trial.
#@markdown



################################################################################
################################ input #########################################
################################################################################
data_type = "Amino Acid Sequence" #@param ["Structure Aware Sequence", "Amino Acid Sequence", "UniProt ID", "PDB/CIF file"]
# use_for = "Training (with Label)" # @param ["Training (with Label)", "Prediction (without Lable)"]
# task_objective = "Classify protein sequences (classification)" # @param ["Classify protein sequences (classification)", "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein", "Classify each Amino Acid (token classification), e.g. Binding site detection"]
# automatically_split_dataset = False # @param {type:"boolean"}
# check the format of .csv file
# print(font.RED+font.BOLD+f"Data Type: {data_type}"+font.RESET)
# dataset_type = "Train for classification task (with label)" # @param ["Train for classification task (with label)", "Train for regression task (with label)", "Train for token_classification task (with label)", "Prediction for all three tasks (without label)"]

################################################################################
############################ upload dataset ####################################
################################################################################

# Upload the user's .csv, convert its "Sequence" column to SA sequences and
# store the result as DATASET_HOME/"[DATASET]<name>.csv".
saseq_csv_path = upload_dataset(data_type)
# The converted file stays in the Colab runtime (the download below is
# disabled), so the message must not claim it reached the user's computer.
print(font.RED+font.BOLD +f"Dataset: \"{saseq_csv_path}\" has been saved to the Colab runtime." + font.RESET )

# files.download(saseq_csv_path)

################################################################################
############################## existing dataset ################################
################################################################################

datasets_list = get_datasets_list()
show_datasets_info(datasets_list)

# **3: Train and Share your PLM**

In [None]:
#@title 3.1: Task Config
# Choose a task name and task type for fine-tuning. For (token) classification
# tasks an integer widget asks for the number of categories; cell 3.4 reads
# `num_of_categories.value` from it.

from pathlib import Path
import copy
import os
import subprocess

import torch
from easydict import EasyDict

################################################################################
############################### custom config ##################################
################################################################################

# 1. Task name (also used as the adapter directory / LMDB folder name).
task_name = "demo_cls2" # @param {type:"string"}

# 2. Task type, mapped from the human-readable choice to an internal key.
task_objective = "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein" # @param ["Classify protein sequences (classification)", "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein", "Classify each Amino Acid (token classification), e.g. Binding site detection"]
task_type = task_type_dict[task_objective]

if task_type in {"classification", "token_classification"}:
  highlight_on = font.RED + font.BOLD
  print(highlight_on + 'Enter the number of category in your training dataset here:' + font.RESET)

  num_of_categories = ipywidgets.BoundedIntText(min=2, max=1000000, step=1, disabled=False)
  num_of_categories.layout.width = "100px"
  display(num_of_categories)

  print(highlight_on + "It's normal not to receive feedback once inputting is finished. Let's move on to the next step." + font.RESET)








In [None]:
#@title 3.2: Select Dataset

# Display a dropdown of the stored datasets. Cell 3.4 reads
# `dataset_dropdown.value`, splits it on ". " and uses the second piece as the
# CSV file name (presumably entries look like "<n>. <name>" — confirm in select_dataset).
dataset_dropdown = select_dataset()

In [None]:
#@title 3.3: Select Model

#@markdown We use a Parameter-Efficient Fine-Tuning technique for model training. It enables us to store model weights in a small **adapter** without changing the original model weights during training. After training, you can get an adapter specific to your task.
#@markdown
#@markdown 1. Select a **base model** from the dropdown box `model_path` below.
#@markdown
#@markdown 2. If you want to **train on existing adapters**, check the box `use_adapter` below. By running this cell, you will see an **adapter combobox**. We provide two ways to select your adapter:
#@markdown  - Select a **local adapter** from the combobox.
#@markdown   - Enter a **huggingface repository name** into the combobox. (e.g. "SaProtAdapters/DeepLoc_cls10_35M")
#@markdown
#@markdown You can also find some official adapters [here](https://huggingface.co/SaProtAdapters)
model_path = "westlake-repl/SaProt_35M_AF2" # @param ["westlake-repl/SaProt_35M_AF2", "westlake-repl/SaProt_650M_AF2"]
print(font.RED+font.BOLD+f"Model: {model_path}"+font.RESET)
use_adapter = False # @param {type:"boolean"}

# The combobox only exists when continuing training from an existing adapter;
# cell 3.4 reads `adapter_combobox.value` under the same `use_adapter` condition.
if use_adapter:
  adapter_combobox = select_adapter()


In [None]:
#@title 3.4: Train your Model

################################################################################
############################## advance config ##################################
################################################################################

batch_size = 4 # @param ["1", "2", "4", "8"] {type:"raw", allow-input: true}
max_epochs = 20 # @param ["10", "20", "50"] {type:"raw", allow-input: true}
learning_rate = 1.0e-3 # @param ["1.0e-3", "5.0e-4", "1.0e-4"] {type:"raw", allow-input: true}

# Fraction of each split's batches to use (1.0 = all) and the validation
# frequency; both are copied verbatim into config.Trainer below.
limit_train_batches=1.0
limit_val_batches=1.0
limit_test_batches=1.0

val_check_interval=0.5

use_lora = True
num_workers = 2

#@markdown - <font face="Consolas" size=2 color='gray'> `batch_size` depends on the number of training samples during model training. We recommend using the default value of 4 for GPU T4.

# Recommended batch sizes:
# |  Recommended batch size   | T4  |  A100   |
# | ---                       | --- |  ---    |
# | SaProt_35M_AF2            |  4  |    16   |
# | SaProt_650M_AF2           |  -  |    8    |

#@markdown - <font face="Consolas" size=2 color='gray'>`max_epochs` refers to the maximum number of complete passes through the entire dataset during the training process.
#@markdown You can adjust `max_epochs` to control training duration. (Note that the max running time of colab is 12hrs for unsubscribed user or 24hrs for colab pro+ user) <br>
#@markdown

download_adapter_to_your_computer = True
#@markdown - <font face="Consolas" size=2 color='gray'>`learning_rate` affects the convergence speed of the model.
#@markdown Through experimentation, we have found that `1.0e-3` is a good default value for model `SaProt_35M_AF2`.

################################################################################
################################# DATASET ######################################
################################################################################

from saprot.utils.construct_lmdb import construct_lmdb
# Dropdown entries are split on the FIRST ". " only, so dataset names that
# themselves contain ". " are not truncated.
dataset_path = DATASET_HOME / f"{dataset_dropdown.value.split('. ', 1)[1]}.csv"
construct_lmdb(dataset_path, LMDB_HOME, task_name, task_type)
dataset_task = LMDB_HOME / task_name


################################################################################
############################## CONFIG ##########################################
################################################################################

from saprot.config.config_dict import Default_config
config = copy.deepcopy(Default_config)

################################################################################
config.setting.run_mode = "train"

if task_type in ["classification", "token_classification"]:
  config.model.kwargs.num_labels = num_of_categories.value

config.model.model_py_path = model_type_dict[task_type]
config.model.save_path = str(ADAPTER_HOME / f"{task_type}" / f"{task_name}")
config.model.kwargs.config_path = model_path

config.dataset.dataset_py_path = dataset_type_dict[task_type]
config.dataset.train_lmdb = str(dataset_task / "train")
config.dataset.valid_lmdb = str(dataset_task / "valid")
config.dataset.test_lmdb = str(dataset_task / "test")
config.dataset.kwargs.tokenizer = model_path

config.Trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"

###############################################################################

# epoch, batch size, num_workers
config.Trainer.max_epochs = max_epochs

config.dataset.dataloader_kwargs.batch_size = batch_size
# Accumulate gradients so that roughly 64 samples contribute to each optimizer
# step. Clamp to 1: for batch_size > 64 the quotient would be 0, which is not
# a valid accumulation factor.
config.Trainer.accumulate_grad_batches = max(1, 64 // batch_size)

config.dataset.dataloader_kwargs.num_workers = num_workers

# learning rate
config.model.lr_scheduler_kwargs.init_lr = learning_rate

# lora
config.model.kwargs.use_lora = use_lora

if use_adapter:
  # Local adapters appear as "<n>. <name>" in the combobox; any other value is
  # treated as a HuggingFace repository id.
  if ". " in adapter_combobox.value:
    adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value.split('. ', 1)[1]
  else:
    adapter_path = adapter_combobox.value
  config.model.kwargs.lora_config_path = adapter_path
else:
  config.model.kwargs.lora_config_path = None

# This cell always trains (run_mode is fixed to "train" above), so the LoRA
# weights must stay trainable.
config.model.kwargs.lora_inference = False

# trainer
config.Trainer.limit_train_batches=limit_train_batches
config.Trainer.limit_val_batches=limit_val_batches
config.Trainer.limit_test_batches=limit_test_batches
config.Trainer.val_check_interval=val_check_interval

# strategy: empty dict = framework default. Alternatives that can be enabled:
#   {'class': 'DeepSpeedStrategy', 'stage': 2}
#   {'class': 'DataParallelStrategy'}
#   {'class': 'DDPStrategy', 'find_unused_parameter': True}
strategy = {}
config.Trainer.strategy = strategy


################################################################################
############################## Run the task ####################################
################################################################################

print(font.RED+font.BOLD+f"Training task type: {task_type}"+font.RESET)
print(font.RED+font.BOLD+f"Dataset: {dataset_task}: {dataset_path}"+font.RESET)
print(font.RED+font.BOLD+f"Model: {config.model.kwargs.config_path}"+font.RESET)
if use_adapter:
  print(font.RED+font.BOLD+f"Loaded Adapter: {config.model.kwargs.lora_config_path}"+font.RESET)

from saprot.scripts.training import finetune
print(config)
finetune(config)


################################################################################
############################## Save the adapter ################################
################################################################################

print(font.RED+font.BOLD)
print(f"Adapter is saved to \"{config.model.save_path}\" on the Colab server")
print(font.RESET)

if download_adapter_to_your_computer:
  import zipfile

  adapter_dir = Path(config.model.save_path)
  adapter_zip = adapter_dir / f"{task_name}.adapter.zip"
  # Package only the files needed to reload the adapter. Building the archive
  # in Python (rather than `!cd ... && zip ...`) keeps working when the path
  # contains spaces or shell metacharacters.
  adapter_file_names = (
      "adapter_config.json",
      "adapter_model.safetensors",
      "adapter_model.bin",
      "README.md",
  )
  adapter_files = [adapter_dir / name for name in adapter_file_names if (adapter_dir / name).is_file()]

  if adapter_files:
    with zipfile.ZipFile(adapter_zip, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
      for file_path in adapter_files:
        zipf.write(file_path, arcname=file_path.name)

    print("Downloading adapter to your local computer")
    files.download(str(adapter_zip))
  else:
    print(f"No adapter files found in \"{adapter_dir}\"; skipping download.")

In [None]:
#@title **3.5: Login HuggingFace**
################################################################################
###################### Login HuggingFace #######################################
################################################################################

# Show HuggingFace's interactive notebook login prompt. The upload cell (3.6)
# calls `HfApi().whoami()`, which needs this authenticated session.
from huggingface_hub import notebook_login
notebook_login()


In [None]:
#@title **3.6: Upload Model**

#@markdown Your Huggingface adapter repository names follow the format `<username>/<task_name>`.
################################################################################
########################## Upload Model  #######################################
################################################################################
from pathlib import Path
import subprocess

from huggingface_hub import HfApi

api = HfApi()

# Requires the login from cell 3.5. The repository is named
# "<username>/<task_name>", using the task name from cell 3.1.
user = api.whoami()
repo_name = user['name'] + '/' + task_name

# Create the public repository on the first upload and reuse it afterwards.
# `exist_ok=True` replaces listing the user's models via `ModelFilter`, which
# recent huggingface_hub releases no longer provide.
api.create_repo(repo_name, private=False, exist_ok=True)

# Upload the trained adapter directory in a single commit over the HTTP API.
# This replaces cloning with the deprecated git-based `Repository` helper and
# copying files with a shell `cp` (which broke on paths containing spaces).
api.upload_folder(
    repo_id=repo_name,
    folder_path=str(config.model.save_path),
    commit_message="Upload adapter model",
)
print(f"Adapter uploaded to https://huggingface.co/{repo_name}")

# **4: Predict with PLM**

## 4.1: Classification & Regression

In [None]:
#@title 4.1.1: Task Config
from transformers import EsmTokenizer
import torch
import copy
import datetime

################################################################################
################################# task config ##################################
################################################################################

## @markdown If you are conducting inference on a classification task, please ensure that the `num_of_category` matches the number of categories in the training dataset. Otherwise, you do not need to assign `num_of_category`.
task_objective = "Classify protein sequences (classification)" # @param ["Classify protein sequences (classification)", "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein", "Classify each Amino Acid (token classification), e.g. Binding site detection"]

task_type = task_type_dict[task_objective]

# (Token) classification heads are built with a fixed number of labels, so the
# value entered here must match the dataset the adapter was trained on.
# Cell 4.1.4 reads `num_of_categories.value`.
if task_type in ["classification", 'token_classification']:

  print(font.RED+font.BOLD+'Enter the number of category in your training dataset here:'+font.RESET)
  # BoundedIntText defaults to max=100, which would silently clamp larger label
  # counts; use the same upper bound as the training config cell (3.1).
  num_of_categories = ipywidgets.BoundedIntText(
                                              min=2,
                                              max=1000000,
                                              step=1,
                                              disabled=False)
  num_of_categories.layout.width = "100px"
  display(num_of_categories)


In [None]:
#@title 4.1.2: Select Dataset

#@markdown You have two options to provide your protein sequences:
#@markdown - **Single Sequence: Enter a single SA sequence** into the input box. You can get an SA sequence by clicking <a href="#get_SA_seq">here</a>
#@markdown - **Multiple Sequences: Select a dataset**

mode = "Single Sequence" #@param ['Single Sequence', 'Multiple Sequences']

##@markdown <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/InferenceFileFormat.png" height="256" align="center" style="height:256px">

if mode == "Multiple Sequences":
  # Cell 4.1.4 reads `dataset_dropdown.value` to locate the dataset CSV.
  dataset_dropdown = select_dataset()
else:
  # Cell 4.1.4 reads `input_seq.value` as the single SA sequence to predict on.
  input_seq = ipywidgets.Text(
    value=None,
    placeholder='Paste the Structure-Aware Sequence here',
    description='SA Sequence:',
    disabled=False)
  input_seq.layout.width = '500px'
  display(input_seq)





In [None]:
#@title 4.1.3: Select Model

#@markdown As we use a Parameter-Efficient Fine-Tuning technique, which allows us to store model weights in a small adapter without adjusting the original model weights during training, it's necessary to specify both the original model and the adapter for prediction.
#@markdown
#@markdown 1. Select a **base model**
#@markdown
#@markdown 2. By running this cell, you will see an **adapter combobox**. We provide two ways to select your adapter:
#@markdown  - Select a **local adapter** from the combobox.
#@markdown   - Enter a **huggingface repository name** into the combobox. (e.g. "SaProtAdapters/DeepLoc_cls10_35M")
#@markdown
#@markdown You can also find some official adapters [here](https://huggingface.co/SaProtAdapters)
model_path = "westlake-repl/SaProt_35M_AF2" #@param ['westlake-repl/SaProt_35M_AF2', 'westlake-repl/SaProt_650M_AF2'] {allow-input:true}

use_adapter = True # @param {type:"boolean"}
# Cell 4.1.4 reads `adapter_combobox.value` whenever `use_adapter` is set.
if use_adapter:
  adapter_combobox = select_adapter()

In [None]:
#@title 4.1.4: Get your Result
from transformers import EsmTokenizer
import torch
import copy
import datetime
import sys
from saprot.model.esm.esm_classification_model import EsmClassificationModel
from saprot.model.esm.esm_regression_model import EsmRegressionModel

from saprot.scripts.training import my_load_model


################################################################################
################################# task config ##################################
################################################################################


# @markdown Click the run button to make prediction.

# @markdown <font color="red">**Note that:**</font> When predicting a category, the index of categories starts from zero.

if use_adapter:
  if adapter_combobox.value == '':
    # Stop the cell with a clear message; `sys.exit()` inside IPython only
    # produces a confusing SystemExit report.
    raise ValueError("Please select an adapter!")

  # Local adapters appear as "<n>. <name>" in the combobox; any other value is
  # treated as a HuggingFace repository id.
  if ". " in adapter_combobox.value:
    adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value.split('. ', 1)[1]
  else:
    adapter_path = adapter_combobox.value
################################################################################
##################################### config ###################################
################################################################################
from saprot.config.config_dict import Default_config
config = copy.deepcopy(Default_config)

if task_type in [ "classification", "token_classification"]:
  config.model.kwargs.num_labels = num_of_categories.value

config.model.model_py_path = model_type_dict[task_type]
config.model.kwargs.config_path = model_path
if use_adapter:
  config.model.kwargs.lora_config_path = adapter_path
else:
  config.model.kwargs.lora_config_path = None

config.model.kwargs.use_lora = True
config.model.kwargs.lora_inference = True

################################################################################
################################### inference ##################################
################################################################################
from peft import PeftModelForSequenceClassification

model = my_load_model(config.model)
tokenizer = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Prediction only: switch off training-mode layers (e.g. dropout) so that
# repeated runs give the same result.
model.eval()

print("#"*60)
print(font.RED+font.BOLD+f"Inference task type: {task_type}"+font.RESET)
if mode == "Multiple Sequences":
  dataset_path = DATASET_HOME / f"{dataset_dropdown.value.split('. ', 1)[1]}.csv"
  print(font.RED+font.BOLD+f"Dataset: {dataset_path}"+font.RESET)
else:
  print(font.RED+font.BOLD+f"Dataset: {input_seq.value}"+font.RESET)
print(font.RED+font.BOLD+f"Model: {model_path}"+font.RESET)
if use_adapter:
  print(font.RED+font.BOLD+f"Adapter: {adapter_path}"+font.RESET)

outputs_list=[]

# No gradients are needed for prediction; without no_grad every stored output
# may keep its autograd graph (and activations) alive.
with torch.no_grad():
  if mode == "Multiple Sequences":
    # Predict on the dataset chosen in the dropdown of cell 4.1.2.
    df = pd.read_csv(dataset_path)
    for seq in tqdm(df['Sequence']):
      inputs = tokenizer(seq, return_tensors="pt")
      inputs = {k: v.to(device) for k, v in inputs.items()}
      outputs_list.append(model(inputs))
  else:
    print("You are making inference based on a sequence that you entered")
    seq = input_seq.value
    inputs = tokenizer(seq, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs_list.append(model(inputs))

################################################################################
##################################### output ###################################
################################################################################

print()
print("#"*60)
print(font.RED+font.BOLD+"outputs:"+font.RESET)

if task_type == "classification":
  import torch.nn.functional as F
  softmax_output_list = [F.softmax(output, dim=1).squeeze().tolist() for output in outputs_list]
  for index, output in enumerate(softmax_output_list):
    print(f"For Sequence {index}, Prediction: Category {output.index(max(output))}, Probability: {output}")
elif task_type == "regression":
  for index, output in enumerate(outputs_list):
    print(f"For Sequence {index}, Prediction: Value {output.item()}")
elif task_type == "token_classification":
  import torch.nn.functional as F
  softmax_output_list = [F.softmax(output, dim=-1).squeeze().tolist() for output in outputs_list]
  print("The probability of each category:")
  for seq_index, seq in enumerate(softmax_output_list):
    seq_prob_df = pd.DataFrame(seq)
    print('='*100)
    print(f'Sequence {seq_index + 1}:')
    # Drop the first and last rows — presumably the tokenizer's special
    # start/end tokens (TODO confirm against EsmTokenizer output).
    print(seq_prob_df[1:-1])




## 4.2: Mutational Effect

In [None]:
#@title 4.2.1: Sequences and Mutations

################################################################################
################################################################################
################################################################################


#@markdown ### 1. Input and Output

#@markdown You have four different combinations of **mutation task** and **mode** to choose from:

#@markdown |Combination| Input | Output |
#@markdown | --- | --- | --- |
#@markdown |`Single-site or Multi-site mutagenesis` + `Single Sequence`| Enter **a SA sequence** and **a mutation information**| a score of the mutation |
#@markdown |`Single-site or Multi-site mutagenesis` + `Multiple Sequences`| Select **a dataset** and upload **a .csv file containing mutation information**| a .csv file containing the scores of mutations |
#@markdown |`Saturation mutagenesis` + `Single Sequence`| Enter **a SA sequence**| a .csv file containing the scores of all mutation on every position of the sequence |
#@markdown |`Saturation mutagenesis` + `Multiple Sequences`| Select **a dataset**| a .zip file containing the .csv files of the Saturation mutagenesis on every sequence |

#@markdown <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/mutation_input_output.png" height="500" width="800px" align="center">

mutation_task = "Saturation mutagenesis" #@param ["Single-site or Multi-site mutagenesis", "Saturation mutagenesis"]
mode = "Multiple Sequences" #@param ['Single Sequence', 'Multiple Sequences']

#@markdown You can obtain a single SA sequence from <a href="#get_SA_seq">here</a>

#@markdown Click the run button to provide your **Sequences and Mutations**

################################################################################
################################################################################
################################################################################

#@markdown <br>

#@markdown ### 2. Mutation Information of `Single-site or Multi-site mutagenesis`
#@markdown Here is the detail about the representation of **mutation information**: <a name="mutation info"></a>

#@markdown | mode | mutation information|
#@markdown | --- | --- |
#@markdown | Single-site mutagenesis | H87Y |
#@markdown | Multi-site mutagenesis | H87Y:V162M:P179L:P179R |

#@markdown - For `Single-site mutagenesis`, we use a term like "H87Y" to denote the mutation, where the first letter represents the **original amino acid**, the number in the middle represents the **mutation site**, and the last letter represents the **mutated amino acid**.

#@markdown - For `Multi-site mutagenesis`, we use a colon ":" to connect the single-site mutations, such as "H87Y:V162M:P179L:P179R".

################################################################################
################################################################################
################################################################################

#@markdown <br>

#@markdown ### 3. Format of the uploaded .csv file containing mutation information

#@markdown For Multiple Sequences, you are required to **upload an additional .csv file** as your mutation information.
#@markdown <font color=red>Please ensure that each mutation in the mutation CSV file corresponds to each Sequence in the dataset CSV file.</font>

#@markdown <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/MutationFormat.png" height="256" align="center" style="height:256px">

if mode == "Multiple Sequences":
  # Cell 4.2.2 reads `dataset_dropdown.value` to locate the dataset CSV.
  dataset_dropdown = select_dataset()
  if mutation_task == "Single-site or Multi-site mutagenesis":
    print("="*100)
    print(font.RED+font.BOLD+"please upload a .csv mutation file!"+font.RESET)
    # NOTE(review): resolved relative to the working directory, which is
    # presumably /content on Colab — confirm it matches the other *_HOME paths.
    upload_path = Path().resolve() / "saprot" / "upload_files"
    mutation_csv_path = upload_file(upload_path)

else:
  input_seq = ipywidgets.Text(
    value=None,
    placeholder='Paste the Structure-Aware Sequence here',
    disabled=False)
  input_seq.layout.width = '500px'
  print(font.RED+font.BOLD+"Structure-Aware Sequence:"+font.RESET)
  display(input_seq)
  if mutation_task == "Single-site or Multi-site mutagenesis":
    # Mutation string in the format described in section 2 (e.g. "H87Y").
    input_mut = ipywidgets.Text(
      value=None,
      placeholder='Enter the mutation here, e.g. H87Y or H87Y:V162M',
      disabled=False)
    print(font.RED+font.BOLD+"Mutation:"+font.RESET)
    input_mut.layout.width = '500px'
    display(input_mut)



In [None]:
#@title 4.2.2: Get your Result

# Cells 4.1.1/4.1.4 run `import datetime`, which rebinds `datetime` to the
# module and makes `datetime.now()` raise AttributeError. Bind the class here.
from datetime import datetime

################################################################################
################################# Task Info ####################################
################################################################################
# Mutation scoring always uses the 650M base model, regardless of the model
# selected in cell 4.1.3.
model_path = "westlake-repl/SaProt_650M_AF2"

print(font.RED+font.BOLD)
print(f"Mutation task: {mutation_task}")
print(f"Mode: {mode}")
print(f"Model: {model_path}")
if mode == "Multiple Sequences":
  dataset_path = DATASET_HOME / f"{dataset_dropdown.value.split('. ', 1)[1]}.csv"
  print(font.RED+font.BOLD+f"Dataset: {dataset_path}"+font.RESET)
else:
  print(font.RED+font.BOLD+f"Dataset: {input_seq.value}"+font.RESET)

print(font.RESET)

print("Predicting...")
# One timestamp per run, used to name the output files.
timestamp = datetime.now().strftime("%y%m%d%H%M%S")

################################################################################
################################# load model ###################################
################################################################################

from saprot.model.esm.esm_foldseek_mutation_model import EsmFoldseekMutationModel

config = {
    "foldseek_path": None,
    "config_path": model_path,
    "load_pretrained": True,
}
model = EsmFoldseekMutationModel(**config)
tokenizer = model.tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


################################################################################
########################### Single Sequence ####################################
################################################################################
if mode == "Single Sequence":

  seq = input_seq.value
  if mutation_task == "Single-site or Multi-site mutagenesis":
    mut = input_mut.value
    score = model.predict_mut(seq, mut)

    print()
    print("="*100)
    print(font.RED+font.BOLD+"Output:"+font.RESET)
    print(f"The score of mutation {mut} is {font.RED}{score}{font.RESET}")

  elif mutation_task == "Saturation mutagenesis":
    output_path = OUTPUT_HOME / f'{timestamp}_prediction_output.csv'

    # An SA sequence holds two characters per residue (amino acid + structure
    # token), so it has len(seq) // 2 residues; positions are 1-based.
    mut_dicts = [model.predict_pos_mut(seq, pos) for pos in range(1, len(seq) // 2 + 1)]

    mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]
    df = pd.DataFrame(mut_list)
    df.to_csv(output_path, index=None)

    print()
    print("="*100)
    print(font.RED+font.BOLD+"Output:"+font.RESET)
    files.download(output_path)
    print(f"\n{font.RED+font.BOLD}The result has been saved to {output_path} and your local computer.{font.RESET}")

################################################################################
########################### Multiple Sequences #################################
################################################################################
if mode == "Multiple Sequences":

  dataset_path = DATASET_HOME / f"{dataset_dropdown.value.split('. ')[1]}.csv"
  dataset_df = pd.read_csv(dataset_path)
  results = []

  if mutation_task=="Single-site or Multi-site mutagenesis":
    # merge mutation info into dataset
    mutation_df = pd.read_csv(mutation_csv_path)
    assert(len(dataset_df) == len(mutation_df))
    merged_df = pd.concat([dataset_df, mutation_df], axis=1)

    for index, row in tqdm(merged_df.iterrows(), total=len(merged_df), leave=False, desc=f"Predicting"):
     seq = row['Sequence']
     mut_info = row['mutation']
     results.append(model.predict_mut(seq, mut_info))

    print()
    print("="*100)
    print(font.RED+font.BOLD+"Output:"+font.RESET)

    result_df = pd.DataFrame()
    result_df['Sequence'] = dataset_df['Sequence']
    result_df['mutation'] = mutation_df['mutation']
    result_df['score'] = results

    output_path = OUTPUT_HOME / f"{timestamp}_prediction_output_{Path(dataset_path).stem}.csv"
    result_df.to_csv(output_path, index=None)
    files.download(output_path)
    print(f"{font.RED+font.BOLD}The result has been saved to {output_path} and your local computer {font.RESET}")

  else:
    for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc=f"Predicting"):
      seq = row['Sequence']
      mut_dicts = []
      for pos in range(1, int(len(seq) / 2)+1):
        mut_dict = model.predict_pos_mut(seq, pos)
        mut_dicts.append(mut_dict)
      mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]
      result_df = pd.DataFrame(mut_list)
      results.append(result_df)

    print()
    print("="*100)
    print(font.RED+font.BOLD+"Output:"+font.RESET)

    zip_files = []
    for i in range(len(results)):
      output_path = OUTPUT_HOME / f"{timestamp}_prediction_output_{Path(dataset_path).stem}_Sequence{i+1}.csv"
      results[i].to_csv(output_path, index=None)
      zip_files.append(output_path)

    # zip and download zip to local computer
    zip_path = OUTPUT_HOME / f"{timestamp}_{Path(dataset_path).stem}.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file in zip_files:
            zipf.write(file, os.path.basename(file))
    files.download(zip_path)
    print(f"{font.RED+font.BOLD}The result has been saved to {zip_path} and your local computer{font.RESET}")

# 5: (Optional) Get a Structure-Aware Sequence <a name="get_SA_seq"></a>


In [None]:
#@title 5.1: Input
# @markdown You can obtain a Structure-Aware Sequence by providing your sequence in one of three data types:
# @markdown - `Amino Acid Sequence` : Enter the Amino Acid Sequence into the box.
# @markdown - `UniProt ID` : Enter the UniProt ID into the box.
# @markdown - `PDB/CIF file` : Click the run button to upload a .pdb/.cif file.

data_type = "UniProt ID" # @param ["Amino Acid Sequence", "UniProt ID", "PDB/CIF file"]

# Data types entered through a text box: (placeholder text, printed label).
_text_input_prompts = {
  "Amino Acid Sequence": ('Paste the Amino Acid Sequence here', 'Amino Acid Sequence:'),
  "UniProt ID": ('Paste the UniProt ID here', 'UniProt ID:'),
}

if data_type in _text_input_prompts:
  _placeholder, _label = _text_input_prompts[data_type]
  raw_data = ipywidgets.Text(
    value=None,
    placeholder=_placeholder,
    disabled=False)
  raw_data.layout.width = '500px'
  print(font.RED + font.BOLD + _label + font.RESET)
  display(raw_data)

elif data_type == "PDB/CIF file":
  # Upload the structure file into STRUCTURE_HOME.
  print(font.RED+font.BOLD+"Please upload a .pdb/.cif file"+font.RESET)
  pdb_file_path = upload_file(STRUCTURE_HOME)
  print("="*100)
  print(font.RED+font.BOLD+"Successfully uploaded your .pdb/.cif file!"+font.RESET)

  # Give the result the same `.value` attribute as the text widgets so the
  # output cell can read `raw_data.value` uniformly; here it is a one-element
  # list holding the uploaded file's stem.
  raw_data = EasyDict({})
  raw_data.value = [Path(pdb_file_path).stem]

else:
  # Fail here rather than leaving `raw_data` undefined for the output cell.
  raise ValueError(f"Unsupported data_type: {data_type!r}")





In [None]:
#@title 5.2: Output
#@markdown Click the run button to get the SA Sequence.
# Turn the input chosen in cell 5.1 into a Structure-Aware (SA) sequence.
sa_seq = get_single_SASequence_by_data_type(data_type, raw_data.value)

# The SA sequence interleaves characters: even indices hold the amino acids,
# odd indices hold the structure tokens.
amino_acid_part = sa_seq[::2]
structure_part = sa_seq[1::2]
divider = "=" * 100

print()
print(divider)
print(f"Amino Acid Sequence: {amino_acid_part}")
print(f"Structure Sequence: {structure_part}")
print(divider)
copy_hint = "The Structure-Aware Sequence is here, double click to select and copy it:"
print(font.RED + font.BOLD + copy_hint + font.RESET)
print(sa_seq)


# 6: (Optional) Example tasks

In [None]:
#@title 6.1: Run example tasks directly
################################################################################
#               Advanced Config                #
################################################################################

batch_size = 2
num_workers = 2
max_epochs = 4

# Passed to the Trainer config below; small values keep the example run short.
limit_train_batches=10
limit_val_batches=2
limit_test_batches=100
val_check_interval=0.2

use_lora = True
# LoRA fine-tuning uses a larger initial learning rate than full fine-tuning.
init_lr = 1.0e-3 if use_lora else 1.0e-5

logger = False
WANDB_API_KEY = ""

model_path = "westlake-repl/SaProt_35M_AF2" # @param ["westlake-repl/SaProt_35M_AF2", "westlake-repl/SaProt_650M_AF2"]

################################################################################
#@title **Task Config**
from pathlib import Path
from easydict import EasyDict
import copy
import importlib
import os
import subprocess
import torch
import time
import zipfile

#@markdown ### **Select an example task and Click the button**
task_name = "Thermostability"  #@param ['DeepLoc_cls2', 'DeepLoc_cls10', 'EC', 'GO_BP', 'GO_CC', 'GO_MF', 'MetalIonBinding', 'Thermostability']
example_tasks = ['DeepLoc', 'EC', 'GO', 'HumanPPI', 'MetalIonBinding', 'Thermostability', 'ClinVar', 'ProteinGym']
model_save_path = Path(f"/content/saprot/weights/{task_name}")

################################################################################
#                  Dataset                 #
################################################################################
download_dataset(task_name)

################################################################################
#                 Config                   #
################################################################################
# Each task's default config is the attribute `<task_name>_config` of
# saprot.config.config_dict. Look it up without exec() and deep-copy it so the
# edits below never mutate the shared module-level default.
_config_module = importlib.import_module("saprot.config.config_dict")
config = copy.deepcopy(getattr(_config_module, f"{task_name}_config"))

config.model.save_path = model_save_path

config.Trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"

################################################################################

config.dataset.kwargs.tokenizer = model_path
config.model.kwargs.config_path = model_path

config.Trainer.max_epochs = max_epochs
config.dataset.dataloader_kwargs.batch_size = batch_size
config.dataset.dataloader_kwargs.num_workers = num_workers

config.model.kwargs.use_lora = use_lora
config.model.lr_scheduler_kwargs.init_lr = init_lr

config.Trainer.limit_train_batches=limit_train_batches
config.Trainer.limit_val_batches=limit_val_batches
config.Trainer.limit_test_batches=limit_test_batches
config.Trainer.val_check_interval=val_check_interval

# wandb config
config.Trainer.logger = logger
config.setting.os_environ.WANDB_API_KEY = WANDB_API_KEY

################################################################################
#               Run the task                 #
################################################################################
start_time = time.time()

download_adapter_to_your_computer = True

# No existing adapter config is supplied. It can presumably point at a saved
# adapter directory instead (e.g. "/content/saprot/weights/DeepLoc_cls2") —
# confirm against saprot's finetune before relying on this.
config.model.kwargs.lora_config_path = None

from saprot.scripts.training import finetune
finetune(config)
end_time = time.time()


################################################################################
#              save the adapter                #
################################################################################
if use_lora:
  adapter_dir = Path(config.model.save_path)
  print(f"Adapter is saved to \"{config.model.save_path}\"")
  if download_adapter_to_your_computer:
    adapter_zip = adapter_dir / f"{task_name}.adapter.zip"
    # Store the adapter files flat (by file name only) in the archive. A
    # missing file is skipped with a warning instead of aborting the download.
    with zipfile.ZipFile(adapter_zip, "w") as zipf:
      for file_name in ("adapter_config.json", "adapter_model.safetensors", "README.md"):
        file_path = adapter_dir / file_name
        if file_path.exists():
          zipf.write(file_path, arcname=file_name)
        else:
          print(f"Warning: {file_path} not found; it was not added to {adapter_zip.name}.")

    files.download(adapter_zip)

execution_time = end_time - start_time
print("Training time: ", execution_time, "seconds")

# 7: (Optional) Data Preparation

In [None]:
# @title 7.1: From `.fasta` to `.csv`
# Convert an uploaded FASTA file into a one-column CSV (`Sequence`, one row per
# FASTA record), preview it, and download it.
from Bio import SeqIO
import numpy as np

upload_dir = Path("/content/saprot/upload_files")

aa_seq_dict = {"Sequence": []}

fa_file_path = upload_file(upload_dir)
with fa_file_path.open("r") as fa:
  for record in tqdm(SeqIO.parse(fa, 'fasta'), leave=True):
    aa_seq_dict["Sequence"].append(str(record.seq))

fa_df = pd.DataFrame(aa_seq_dict)
# Preview the first few rows of the converted table.
print(fa_df.head())

csv_file_path = upload_dir / f"{fa_file_path.stem}.csv"
fa_df.to_csv(csv_file_path, index=None)
files.download(csv_file_path)