# SLURM, pytorch distributed and Multiple Nodes
The key thing to know is that `srun` is like a super-`ssh`: running `srun cmd` effectively does something like `ssh node cmd` on every node allocated to the job.
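For example (a hypothetical two-node interactive allocation; the node names are made up), running `hostname` through `srun` executes it once per node, much like `ssh`-ing into each node would:

```bash
$ salloc --nodes=2 --ntasks-per-node=1   # hypothetical interactive allocation
$ srun hostname                          # runs once per task, i.e. once per node
trixie-cn101
trixie-cn102
```

The sbatch script below puts this to work: it requests two nodes with one task per node, then uses `srun` to run `task.sh` once on each node.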
```bash
#!/bin/bash
#SBATCH --partition=TrixieMain
#SBATCH --account=dt-mtp
#SBATCH --time=00:20:00
#SBATCH --job-name=pytorch.distributed
#SBATCH --comment="Helping Harry with pytorch distributed on multiple nodes."
#SBATCH --gres=gpu:4
##SBATCH --ntasks=2
#SBATCH --wait-all-nodes=1
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=6
#SBATCH --exclusive
#SBATCH --output=%x-%j.out

# USEFUL Bookmarks
# [Run PyTorch Data Parallel training on ParallelCluster](https://www.hpcworkshops.com/08-ml-on-parallelcluster/03-distributed-data-parallel.html)
# [slurm SBATCH - Multiple Nodes, Same SLURMD_NODENAME](https://stackoverflow.com/a/51356947)

# The batch script runs on the first allocated node, so that node's hostname
# can serve as the rendezvous master address for torch.distributed.
readonly MASTER_ADDR_JOB=$SLURMD_NODENAME
readonly MASTER_PORT_JOB="12234"

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK

# %t is the task id, so each node writes its own log file.
readonly srun='srun --output=%x-%j.%t.out'

env

# Launch task.sh once per node and wait for all tasks to finish.
$srun bash \
   task.sh \
      $MASTER_ADDR_JOB \
      $MASTER_PORT_JOB &

wait
```
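To submit the job, save the script above (the filename here is arbitrary) and run:

```bash
sbatch pytorch.distributed.sh
```

With the `--output` patterns above, SLURM writes one `pytorch.distributed-<jobid>.out` file for the batch script itself and one `pytorch.distributed-<jobid>.<taskid>.out` file per node.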
The following script, `task.sh`, is executed once on each node. Note that the conda environment is activated inside this script, rather than in the sbatch script, so that each node/worker gets a properly set-up environment.
```bash
#!/bin/bash
# task.sh: runs once per node and launches one training process per GPU.

# USEFUL Bookmarks
# [Run PyTorch Data Parallel training on ParallelCluster](https://www.hpcworkshops.com/08-ml-on-parallelcluster/03-distributed-data-parallel.html)
# [slurm SBATCH - Multiple Nodes, Same SLURMD_NODENAME](https://stackoverflow.com/a/51356947)

# Activate the conda environment on this node.
#module load miniconda3-4.7.12.1-gcc-9.2.0-j2idqxp
#source activate molecule
source /gpfs/projects/DT/mtp/WMT20/opt/miniconda3/bin/activate
conda activate pytorch-1.7.1

# Master address and port are passed in by the sbatch script.
readonly MASTER_ADDR_JOB=$1
readonly MASTER_PORT_JOB=$2

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK

env

# With one task per node, SLURM_NTASKS equals the number of nodes and
# SLURM_NODEID is this node's rank.  torch.distributed.launch then spawns
# --nproc_per_node processes (one per GPU) on this node.
python \
   -m torch.distributed.launch \
      --nproc_per_node=4 \
      --nnodes=$SLURM_NTASKS \
      --node_rank=$SLURM_NODEID \
      --master_addr=$MASTER_ADDR_JOB \
      --master_port=$MASTER_PORT_JOB \
   main.py \
      --batch_size 128 \
      --learning_rate 5e-5 &

wait
```
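`main.py` itself is not shown on this page. For reference, here is a minimal sketch of what the training-script side looks like with `torch.distributed.launch` (PyTorch 1.7.x): the launcher passes `--local_rank` to each process and exports `MASTER_ADDR`, `MASTER_PORT`, `RANK`, and `WORLD_SIZE`, so the script only has to parse `--local_rank`, pin its GPU, and initialize the process group. The model, data, and training step below are placeholders; only `--batch_size` and `--learning_rate` mirror the flags used above.

```python
# Minimal sketch of main.py for use with torch.distributed.launch.
# The model, dataset, and training step are placeholders.
import argparse

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)  # injected by the launcher
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    args = parser.parse_args()

    # Pin this process to its GPU, then join the process group.  The launcher
    # already exported MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")

    # Placeholder model wrapped in DDP; gradients are synchronized across
    # all 8 processes (2 nodes x 4 GPUs).
    model = torch.nn.Linear(10, 10).cuda(args.local_rank)
    model = DDP(model, device_ids=[args.local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Dummy training step to show the DDP flow.
    x = torch.randn(args.batch_size, 10).cuda(args.local_rank)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

With `--nnodes=2` and `--nproc_per_node=4`, the world size is 8; each process sees its own `--local_rank` in 0..3, while the global rank is derived from `--node_rank` and the local rank.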