In [1]:
import os
import shutil

# For downloads
import requests
from concurrent.futures import ThreadPoolExecutor

In [2]:
# Make sure the download target directory is present (no-op when it already is).
os.makedirs('training_datasets', exist_ok=True)

# Every training file lives under the same Zenodo record, so the URL list is
# assembled from the shared prefix plus the individual file names.
urls = [
    f'https://zenodo.org/records/13831403/files/{dataset_file}'
    for dataset_file in (
        'ids_train_casp12nr50_nr70Ig_nr40Others.fasta',
        'sidechainnet_casp12_50.pkl',
        'AbSCSAbDAb_trainnr90_bkandcbcoords_aug2022.h5',
        'ppi_trainset_5032_noabag_aug2022.h5',
        'train_af_paired_nr70.h5',
    )
]

def download_file(url, timeout=(10, 120)):
    """Download ``url`` into ``training_datasets/`` unless it is already there.

    The payload is streamed to ``<name>.part`` and only renamed to its final
    name once the whole body has been written. A download that fails or is
    interrupted therefore never leaves a truncated file under the final name,
    which the "already exists" check would otherwise skip on the next run.

    Args:
        url: Address of the file; its basename becomes the local file name.
        timeout: Passed to ``requests.get`` as ``(connect, read)`` seconds so
            a stalled connection raises instead of hanging indefinitely.

    Returns:
        None.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Exception: Any network or file-system error is re-raised after the
            partial file has been cleaned up.
    """
    local_filename = os.path.join('training_datasets', os.path.basename(url))
    # Skip file if it exists already
    if os.path.exists(local_filename):
        print(f'Skipping download: {local_filename} already exists.')
        return

    partial_filename = local_filename + '.part'
    try:
        # Stream download to handle large files
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            # 'wb' truncates, so a stale .part from an earlier run is overwritten
            with open(partial_filename, 'wb') as f:
                # write in 8 kb chunks
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the finished file under its final name in a single step.
        os.replace(partial_filename, local_filename)
    except BaseException:
        # Best-effort cleanup; never let it mask the original error.
        try:
            os.remove(partial_filename)
        except OSError:
            pass
        raise
    print(f'Downloaded {local_filename}')

# Use ThreadPoolExecutor to download files concurrently (two at a time).
with ThreadPoolExecutor(max_workers=2) as executor:
    # Executor.map only re-raises a worker's exception when its result is
    # retrieved from the returned iterator. Draining it with list() makes a
    # failed download fail this cell instead of passing silently.
    list(executor.map(download_file, urls))


Downloaded training_datasets/ids_train_casp12nr50_nr70Ig_nr40Others.fasta


KeyboardInterrupt: 

In [None]:
! mkdir -p training_datasets

! wget -P training_datasets https://zenodo.org/records/13831403/files/ids_train_casp12nr50_nr70Ig_nr40Others.fasta

! wget -P training_datasets https://zenodo.org/records/13831403/files/sidechainnet_casp12_50.pkl

! wget -P training_datasets https://zenodo.org/records/13831403/files/AbSCSAbDAb_trainnr90_bkandcbcoords_aug2022.h5
! wget -P training_datasets https://zenodo.org/records/13831403/files/ppi_trainset_5032_noabag_aug2022.h5
! wget -P training_datasets https://zenodo.org/records/13831403/files/train_af_paired_nr70.h5

In [2]:
### COPY SIDECHAINNET TO LOCAL TEMPORARY DIRECTORY
# Source pickle inside the downloaded datasets, and the /tmp location it is
# staged to.
sidechainnet = './training_datasets/sidechainnet_casp12_50.pkl'
sidechainnet_temp = '/tmp/sidechainnet_casp12_50.pkl'

# Copy only once: an existing staged file is left untouched.
_already_staged = os.path.exists(sidechainnet_temp)
if not _already_staged:
    shutil.copy(sidechainnet, sidechainnet_temp)

In [4]:
### MODEL AND TRAINING PARAMETERS
# Each value is interpolated into the training command line further down; the
# trailing comment names the flag it feeds.

# Model
LAYERS = 4                      # --layers
HEADS = 8                       # --heads
DIM = 256                       # --model_dim
gmodel = 'egnn-trans-ma'        # --protein_gmodel
atom_types = 'backbone_and_cb'  # --atom_types
NN = 48                         # --max_ag_neighbors

# Data / run
BS = 1                          # --batch_size
SS = 50                         # --scn_sequence_similarity; protein model was trained on 90ss
save_every = 5                  # --save_every
OLD = 0                         # NOTE(review): not referenced in the visible cells; confirm it is still needed

### Training seed and output directory
SEED = 1                                 # --seed
EPOCHS = 10                              # --epochs
MODELS_DIR = f'models_out_dir_seed_{SEED}'  # --output_dir (one directory per seed)

### WANDB ENTITY
WANDB_ENTITY = "fadh-johns-hopkins-university"  # --wandb_entity


In [5]:
# Absolute path (anchored at the current working directory) of the FASTA file
# listing the IDs selected for training.
gd2_dataset_ids = os.path.join(
    os.getcwd(),
    'training_datasets',
    'ids_train_casp12nr50_nr70Ig_nr40Others.fasta',
)

In [None]:
### Get SLURM variables and list of gpus
# Derives the process and GPU counts handed to the training command from the
# SLURM environment, falling back to defaults when a variable is unset.

slurm_ntasks = int(os.getenv('SLURM_NTASKS', 6))  # Default to 6 if not set
slurm_job_num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', 1))  # Default to 1 if not set

# Calculate n_procs (tasks per node, rounded down)
n_procs = slurm_ntasks // slurm_job_num_nodes

# Defaults used when SLURM_STEP_GPUS is not available
num_gpus = 1
gpus_list = []

# os.getenv returns None when the variable is unset, so only split when a value
# is actually present; otherwise keep the defaults above.
slurm_step_gpus = os.getenv('SLURM_STEP_GPUS')
if slurm_step_gpus:
    gpus_list = slurm_step_gpus.split(',')
    num_gpus = len(gpus_list)

print(f"n_proces = {n_procs}")
print(f"num_gpus = {num_gpus}")

In [None]:
# Run training using shell execution in Jupyter.
# IPython replaces every {name} below with the current value of that notebook
# variable before handing the command to the shell, so the parameter cell, the
# dataset-path cell and the SLURM cell must have been executed first.
# Alternative entry point for when the script is run from the MaskedProteinEnT/
# subdirectory (swap it in for the first command line below when final):
#   !python3 ./MaskedProteinEnT/train_masked_model.py
!python3 ./train_masked_model.py \
  --save_every {save_every} --lr 0.00001 --batch_size {BS} \
  --heads {HEADS} --model_dim {DIM} --epochs {EPOCHS} --dropout 0.2 \
  --masking_rate_max 0.25 --topk_metrics 1 --layers {LAYERS} \
  --num_gpus {num_gpus} --crop_sequences \
  --scn_sequence_similarity {SS} --protein_gmodel {gmodel} \
  --lr_patience 350 --lr_cooldown 20 --max_ag_neighbors {NN} \
  --atom_types {atom_types} \
  --file_with_selected_scn_ids_for_training {gd2_dataset_ids} \
  --lightning_save_last_model --use_scn --num_procs {n_procs} \
  --output_dir {MODELS_DIR} --seed {SEED} --wandb_entity {WANDB_ENTITY}
