# MaskedProteinEnT: Sample Training Notebook

This notebook demonstrates how to train a **MaskedProteinEnT** model with PyTorch. The model is trained with a masked language modeling objective on protein sequences.

**Repository**: [MaskedProteinEnT](https://github.com/Graylab/MaskedProteinEnT)
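To make "masked language modeling" concrete, the sketch below hides a random fraction of residues; a model would then be trained to predict them. It is a generic illustration and not the repository's implementation. Sampling the per-sequence rate uniformly up to a maximum is one plausible reading of the `--masking_rate_max 0.25` flag used later; the mask symbol `_` and the function name `mask_sequence` are hypothetical.

```python
import random
from typing import List, Optional, Tuple

MASK = "_"  # hypothetical mask symbol, chosen for illustration only


def mask_sequence(seq: str, masking_rate_max: float = 0.25,
                  rng: Optional[random.Random] = None) -> Tuple[str, List[int]]:
    """Mask a random fraction (at most masking_rate_max) of residues.

    Returns the masked sequence and the sorted masked positions, i.e. the
    positions a model would be asked to recover.
    """
    if not seq:
        return seq, []
    rng = rng or random.Random()
    rate = rng.uniform(0.0, masking_rate_max)  # sample a masking rate per sequence
    n_mask = max(1, int(rate * len(seq)))      # always mask at least one residue
    positions = sorted(rng.sample(range(len(seq)), n_mask))
    masked = list(seq)
    for i in positions:
        masked[i] = MASK
    return "".join(masked), positions


masked, positions = mask_sequence("ACDEFGHIKLMNPQRSTVWY", 0.25, random.Random(0))
print(masked, positions)
```

For very short sequences, the "at least one residue" rule can exceed the nominal maximum rate.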

# Installation
- Requires a GPU with **CUDA 11.1** support; the cell below installs PyTorch 1.9.1 wheels built for CUDA 11.1 (`cu111`)

In [None]:
!git clone https://github.com/Graylab/MaskedProteinEnT.git
!pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
!pip install -r MaskedProteinEnT/requirements_torch191.txt
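After installation, a quick sanity check can confirm that the expected PyTorch build is active. This is a minimal sketch: it relies only on `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()`, and the helper names `parse_version` and `check_install` are hypothetical.

```python
import re
from typing import List, Optional, Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings such as '1.9.1+cu111'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def check_install(torch_version: str, cuda_version: Optional[str],
                  expected_torch: str = "1.9.1",
                  expected_cuda: str = "11.1") -> List[str]:
    """Return a list of mismatches; an empty list means the versions match."""
    problems = []
    if parse_version(torch_version) != parse_version(expected_torch):
        problems.append(f"torch {torch_version} installed, expected {expected_torch}")
    if cuda_version != expected_cuda:
        problems.append(f"torch built for CUDA {cuda_version}, expected {expected_cuda}")
    return problems


try:
    import torch
except ImportError:
    print("torch is not installed yet")
else:
    print(check_install(torch.__version__, torch.version.cuda) or "versions OK")
    print("CUDA available:", torch.cuda.is_available())
```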

# Basic imports

In [2]:
import os
import shutil

# Downloading training datasets
- Run this cell to download the datasets required for training or fine-tuning
- The URLs for the datasets required for initial training are already uncommented; uncomment the fine-tuning URLs only if you need them

In [None]:
# Create the directory if it doesn't exist
! mkdir -p training_datasets

urls_content = """\
## Uncomment the URLs you want to download and comment out those you want to skip
## Required for initial training
https://zenodo.org/record/13831403/files/ids_train_casp12nr50_nr70Ig_nr40Others.fasta
https://zenodo.org/record/13831403/files/sidechainnet_casp12_50.pkl
## Fine-tuning datasets
# https://zenodo.org/record/13831403/files/AbSCSAbDAb_trainnr90_bkandcbcoords_aug2022.h5
# https://zenodo.org/record/13831403/files/ppi_trainset_5032_noabag_aug2022.h5
# https://zenodo.org/record/13831403/files/train_af_paired_nr70.h5
"""

# Write the content to urls.txt
with open('urls.txt', 'w') as f:
    f.write(urls_content)

# Download files in parallel using wget and xargs, ignoring commented lines
! grep -v '^#' urls.txt | xargs -n 1 -P 5 wget -nc -P training_datasets

# Clean up
! rm urls.txt
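The shell pipeline above can also be written in plain Python, which may help where `wget` or `xargs` is unavailable. This is a sketch under stated assumptions: the helper names are hypothetical, comment filtering mirrors `grep -v '^#'` (and additionally skips blank lines), skipping existing files mirrors `wget -nc`, and `max_workers=5` mirrors `xargs -P 5`.

```python
import os
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from typing import List


def select_urls(text: str) -> List[str]:
    """Keep non-blank lines that are not comments (cf. grep -v '^#')."""
    urls = []
    for line in text.splitlines():
        stripped = line.strip()
        if stripped and not stripped.startswith("#"):
            urls.append(stripped)
    return urls


def download(url: str, dest_dir: str) -> str:
    """Download url into dest_dir unless the file already exists (cf. wget -nc)."""
    os.makedirs(dest_dir, exist_ok=True)
    path = os.path.join(dest_dir, url.rsplit("/", 1)[-1])
    if not os.path.exists(path):
        partial = path + ".part"
        urllib.request.urlretrieve(url, partial)  # write to a temporary name first
        os.replace(partial, path)  # rename only after the download completes
    return path


def download_all(urls: List[str], dest_dir: str = "training_datasets",
                 max_workers: int = 5) -> List[str]:
    """Download several files concurrently (cf. xargs -P 5)."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(lambda u: download(u, dest_dir), urls))
```

In the notebook this would be called as `download_all(select_urls(urls_content))`. Writing to a `.part` file first means an interrupted download is not mistaken for a complete file on the next run.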

# Adjust model and training parameters

In [4]:
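LAYERS=4                      # number of model layers (--layers)
HEADS=8                       # attention heads (--heads)
DIM=256                       # model dimension (--model_dim)
OLD=0                         # not used by the training command below
BS=1                          # batch size (--batch_size)
SS=50                         # SidechainNet sequence-similarity level (--scn_sequence_similarity); the protein model was trained with 90
save_every=5                  # checkpoint interval (--save_every)
gmodel='egnn-trans-ma'        # protein graph model (--protein_gmodel)
atom_types='backbone_and_cb'  # atom types per residue (--atom_types)
NN=48                         # maximum number of neighbors (--max_ag_neighbors)
EPOCHS=10                     # number of training epochs (--epochs)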

# Set training seed, output directory and wandb entity

In [None]:
SEED=1
MODELS_DIR=f'models_out_dir_seed_{SEED}'
WANDB_ENTITY="YOUR_WANDB_ENTITY"

# SLURM Job and GPU Setup

The following code retrieves environment variables set by SLURM, a workload manager commonly used for scheduling jobs on HPC systems. It performs the following steps:

- **SLURM Environment Variables**: Retrieves the total number of tasks (`SLURM_NTASKS`) and the number of nodes (`SLURM_JOB_NUM_NODES`) for the current job, using default values of 6 and 1 if these are not set.
- **Process Calculation**: Calculates the number of processes per node (`n_procs`) by dividing the total number of tasks by the number of nodes.
- **GPU Information**: Retrieves the list of GPUs assigned to the job (`SLURM_STEP_GPUS`) and counts the number of GPUs being used.

In [None]:
# Get the number of tasks (SLURM_NTASKS) or default to 6
slurm_ntasks = int(os.getenv('SLURM_NTASKS', 6))  # Default to 6 if not set

# Get the number of nodes (SLURM_JOB_NUM_NODES) or default to 1
slurm_job_num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', 1))  # Default to 1 if not set

# Calculate n_procs (number of processes per node)
n_procs = slurm_ntasks // slurm_job_num_nodes

# Get the list of GPUs (SLURM_STEP_GPUS) or default to a single GPU
slurm_step_gpus = os.getenv('SLURM_STEP_GPUS', "0")  # Default to GPU 0 if not set
gpus_list = slurm_step_gpus.split(',')
num_gpus = len(gpus_list)

print(f"n_procs = {n_procs}")
print(f"num_gpus = {num_gpus}")
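The same logic can be wrapped in a pure function that takes the environment as a mapping, so the defaults can be checked without a SLURM allocation. This is a sketch; the name `slurm_resources` is hypothetical, and the defaults (6 tasks, 1 node, GPU `"0"`) are the ones used in the cell above.

```python
import os
from typing import Mapping, Tuple


def slurm_resources(env: Mapping[str, str], default_ntasks: int = 6,
                    default_nodes: int = 1) -> Tuple[int, int]:
    """Return (n_procs per node, num_gpus) from SLURM-style variables."""
    ntasks = int(env.get("SLURM_NTASKS", default_ntasks))
    nodes = int(env.get("SLURM_JOB_NUM_NODES", default_nodes))
    n_procs = ntasks // nodes
    gpu_ids = [g for g in env.get("SLURM_STEP_GPUS", "0").split(",") if g.strip()]
    num_gpus = max(1, len(gpu_ids))  # an empty list still counts as one GPU, as above
    return n_procs, num_gpus


print(slurm_resources(os.environ))
```

Passing `os.environ` reproduces the notebook cell; passing a plain dict makes the behavior easy to inspect, for example four GPUs listed as `"0,1,2,3"` yield `num_gpus = 4`.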

# Final preparation before running training

In [2]:
# Copy the SidechainNet pickle to a local temporary directory
sidechainnet = './training_datasets/sidechainnet_casp12_50.pkl'
sidechainnet_temp = '/tmp/sidechainnet_casp12_50.pkl'

if not os.path.exists(sidechainnet_temp):
    shutil.copy(sidechainnet, sidechainnet_temp)

# FASTA file listing the SidechainNet IDs selected for training (--file_with_selected_scn_ids_for_training)
gd2_dataset_ids = os.path.join(os.getcwd(), 'training_datasets', 'ids_train_casp12nr50_nr70Ig_nr40Others.fasta')
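Before launching a long job, it can be worth failing fast if an input file is missing or the W&B entity was never filled in. This is a minimal sketch; the function name `preflight_problems` is hypothetical, and the placeholder string is the one set in the seed cell.

```python
import os
from typing import Iterable, List

WANDB_PLACEHOLDER = "YOUR_WANDB_ENTITY"


def preflight_problems(paths: Iterable[str], wandb_entity: str) -> List[str]:
    """List missing input files and an unset W&B entity; empty means ready."""
    problems = [f"missing file: {p}" for p in paths if not os.path.isfile(p)]
    if not wandb_entity or wandb_entity == WANDB_PLACEHOLDER:
        problems.append("WANDB_ENTITY is unset or still the placeholder")
    return problems


for problem in preflight_problems(
        ["/tmp/sidechainnet_casp12_50.pkl",
         os.path.join("training_datasets", "ids_train_casp12nr50_nr70Ig_nr40Others.fasta")],
        WANDB_PLACEHOLDER):
    print("WARNING:", problem)
```

In the notebook the call would be `preflight_problems([sidechainnet_temp, gd2_dataset_ids], WANDB_ENTITY)`.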

# Run training

In [None]:
# Run training using shell execution in Jupyter
!python3 ./MaskedProteinEnT/train_masked_model.py \
  --save_every {save_every} --lr 0.00001 --batch_size {BS} \
  --heads {HEADS} --model_dim {DIM} --epochs {EPOCHS} --dropout 0.2 \
  --masking_rate_max 0.25 --topk_metrics 1 --layers {LAYERS} \
  --num_gpus {num_gpus} --crop_sequences \
  --scn_sequence_similarity {SS} --protein_gmodel {gmodel} \
  --lr_patience 350 --lr_cooldown 20 --max_ag_neighbors {NN} \
  --atom_types {atom_types} \
  --file_with_selected_scn_ids_for_training {gd2_dataset_ids} \
  --lightning_save_last_model --use_scn --num_procs {n_procs} \
  --output_dir {MODELS_DIR} --seed {SEED} --wandb_entity {WANDB_ENTITY}
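Outside Jupyter, where `!` and `{...}` interpolation are unavailable, the same invocation can be assembled as an argument list and run with `subprocess`. This is a sketch: the helper names are hypothetical, the values mirror the notebook defaults (SEED=1, `n_procs=6`, `num_gpus=1`), and placing the boolean flags at the end assumes the script's option parser is order-independent, as `argparse` is. Using `sys.executable` runs the script with the same interpreter as the notebook rather than whatever `python3` is on `PATH`.

```python
import os
import shlex
import subprocess
import sys
from typing import Dict, List, Sequence


def build_train_command(script: str, options: Dict[str, object],
                        flags: Sequence[str]) -> List[str]:
    """Turn {name: value} options and boolean flags into an argv list."""
    cmd = [sys.executable, script]
    for name, value in options.items():
        cmd += [f"--{name}", str(value)]
    cmd += [f"--{flag}" for flag in flags]
    return cmd


def run_training(cmd: List[str]) -> None:
    """Launch training; raises CalledProcessError if the script exits non-zero."""
    subprocess.run(cmd, check=True)


TRAIN_OPTIONS = {
    "save_every": 5, "lr": "0.00001", "batch_size": 1,
    "heads": 8, "model_dim": 256, "epochs": 10, "dropout": 0.2,
    "masking_rate_max": 0.25, "topk_metrics": 1, "layers": 4,
    "num_gpus": 1, "scn_sequence_similarity": 50,
    "protein_gmodel": "egnn-trans-ma", "lr_patience": 350, "lr_cooldown": 20,
    "max_ag_neighbors": 48, "atom_types": "backbone_and_cb",
    "file_with_selected_scn_ids_for_training": os.path.join(
        os.getcwd(), "training_datasets", "ids_train_casp12nr50_nr70Ig_nr40Others.fasta"),
    "num_procs": 6, "output_dir": "models_out_dir_seed_1", "seed": 1,
    "wandb_entity": "YOUR_WANDB_ENTITY",
}
TRAIN_FLAGS = ["crop_sequences", "lightning_save_last_model", "use_scn"]

cmd = build_train_command("./MaskedProteinEnT/train_masked_model.py",
                          TRAIN_OPTIONS, TRAIN_FLAGS)
print(shlex.join(cmd))
# run_training(cmd)  # uncomment to actually launch training
```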
