Commit

updated the multinode code to use torch ccl
sanchit-misra committed Jun 17, 2024
1 parent dbd44ee commit a791a5f
Showing 4 changed files with 53 additions and 41 deletions.
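
Note: the collapsed hunks below show the launcher scripts and the logging changes but not the torch-ccl initialization itself. As background, switching torch.distributed to Intel's oneCCL backend typically looks like the sketch below; the module name oneccl_bindings_for_pytorch, the PMI_* environment variables, and the default port are assumptions based on common torch-ccl usage, not lines from this commit.

# Minimal sketch of bringing up torch.distributed on the oneCCL ("ccl")
# backend; assumed usage, not code shown in this diff.
import os

import torch.distributed as dist
import oneccl_bindings_for_pytorch  # noqa: F401  # registers the "ccl" backend

# MPI-style launchers (such as the run_dist_ht.sh wrapper used below) export
# rank and world size; fall back to single-process defaults when absent.
rank = int(os.environ.get("PMI_RANK", 0))
world_size = int(os.environ.get("PMI_SIZE", 1))
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # the SLURM scripts export this
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
print(f"rank {dist.get_rank()} of {dist.get_world_size()} initialized")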
50 changes: 30 additions & 20 deletions tutorials/examples/ddp_gfn.batch.small.4.slurm
@@ -6,7 +6,7 @@
#SBATCH --partition=spr
#SBATCH --ntasks=2
#SBATCH --cpus-per-task=112
#SBATCH --time=00:60:00
#SBATCH --time=08:00:00

source /swtools/intel/2024.0/oneapi-vars.sh
export I_MPI_HYDRA_BOOTSTRAP=slurm
@@ -19,22 +19,32 @@ echo $MASTER_ADDR
echo $SLURM_JOB_NUM_NODES
echo $SLURM_NODELIST

./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 32 &> batch.out.4.4.256000.32
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 64 &> batch.out.4.4.256000.64
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 128 &> batch.out.4.4.256000.128
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 256 &> batch.out.4.4.256000.256
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 512 &> batch.out.4.4.256000.512
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 1000 &> batch.out.4.4.256000.1000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 2000 &> batch.out.4.4.256000.2000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 4000 &> batch.out.4.4.256000.4000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 8000 &> batch.out.4.4.256000.8000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 16000 &> batch.out.4.4.256000.16000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 32000 &> batch.out.4.4.256000.32000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 64000 &> batch.out.4.4.256000.64000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 128000 &> batch.out.4.4.256000.128000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 256000 &> batch.out.4.4.256000.256000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 32000 &> batch.out.4.4.512000.32000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 64000 &> batch.out.4.4.512000.64000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 128000 &> batch.out.4.4.512000.128000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256000 &> batch.out.4.4.512000.256000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 512000 &> batch.out.4.4.512000.512000
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 1024 --batch_size 128 &> batch.out.4.4.1024.128
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 512000 --validation_interval 1 &> batch.acc.out.4.4.512000.512000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256000 --validation_interval 1 &> batch.acc.out.4.4.512000.256000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 128000 --validation_interval 1 &> batch.acc.out.4.4.512000.128000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 64000 --validation_interval 1 &> batch.acc.out.4.4.512000.64000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 32000 --validation_interval 1 &> batch.acc.out.4.4.512000.32000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 16000 --validation_interval 1 &> batch.acc.out.4.4.512000.16000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 8000 --validation_interval 1 &> batch.acc.out.4.4.512000.8000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 4000 --validation_interval 1 &> batch.acc.out.4.4.512000.4000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 2000 --validation_interval 1 &> batch.acc.out.4.4.512000.2000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 1024 --validation_interval 1 &> batch.acc.out.4.4.512000.1024
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 512 &> batch.acc.out.4.4.512000.512
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256 &> batch.acc.out.4.4.512000.256
./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 128 &> batch.acc.out.4.4.512000.128
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 32000 --validation_interval 1 &> batch.acc.out.4.4.512000.32000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 32 &> batch.out.4.4.256000.32
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 64 &> batch.out.4.4.256000.64
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 128 &> batch.out.4.4.256000.128
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 256 &> batch.out.4.4.256000.256
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 512 &> batch.out.4.4.256000.512
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 1000 &> batch.out.4.4.256000.1000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 2000 &> batch.out.4.4.256000.2000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 4000 &> batch.out.4.4.256000.4000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 8000 &> batch.out.4.4.256000.8000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 16000 &> batch.out.4.4.256000.16000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 32000 &> batch.out.4.4.256000.32000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 64000 &> batch.out.4.4.256000.64000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 128000 &> batch.out.4.4.256000.128000
#./run_dist_ht.sh -np 4 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 256000 --batch_size 256000 &> batch.out.4.4.256000.256000
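
Each sweep line above fixes --n_trajectories and varies --batch_size, so the batch size controls how many training iterations a run performs. The training script imports ceil (see the diff below), which suggests n_iterations = ceil(n_trajectories / batch_size); the sketch below tabulates that relationship for the 256000-trajectory sweep and is illustrative, not part of the commit.

# Iterations implied by each (n_trajectories, batch_size) pair in the sweep,
# assuming n_iterations = ceil(n_trajectories / batch_size).
from math import ceil

n_trajectories = 256_000
for batch_size in (32, 64, 128, 256, 512, 1000, 2000, 4000,
                   8000, 16000, 32000, 64000, 128000, 256000):
    print(f"batch_size={batch_size:>6} -> {ceil(n_trajectories / batch_size):>5} iterations")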
10 changes: 5 additions & 5 deletions tutorials/examples/ddp_gfn.small.16.slurm
@@ -19,8 +19,8 @@ echo $MASTER_ADDR
echo $SLURM_JOB_NUM_NODES
echo $SLURM_NODELIST

./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 256000 &> scaling.out.16.4.512000.256000
./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 1024000 --batch_size 256000 &> scaling.out.16.4.1024000.256000
./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 2048000 --batch_size 256000 &> scaling.out.16.4.2048000.256000
./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 4096000 --batch_size 256000 &> scaling.out.16.4.4096000.256000
./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 8192000 --batch_size 256000 &> scaling.out.16.4.8192000.256000
./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 512000 --batch_size 128 &> scaling.out.16.4.512000.128
#./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 1024000 --batch_size 256000 &> scaling.out.16.4.1024000.256000
#./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 2048000 --batch_size 256000 &> scaling.out.16.4.2048000.256000
#./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 4096000 --batch_size 256000 &> scaling.out.16.4.4096000.256000
#./run_dist_ht.sh -np 16 -ppn 4 python -u train_hypergrid_multinode.py --ndim 4 --height 64 --R0 0.01 --tied --loss TB --n_trajectories 8192000 --batch_size 256000 &> scaling.out.16.4.8192000.256000
4 changes: 2 additions & 2 deletions tutorials/examples/train_hypergrid.py
@@ -14,7 +14,6 @@
from argparse import ArgumentParser

import torch
import wandb
from tqdm import tqdm, trange

from gfn.containers import ReplayBuffer
@@ -43,6 +42,7 @@ def main(args): # noqa: C901

    use_wandb = len(args.wandb_project) > 0
    if use_wandb:
        import wandb
        wandb.init(project=args.wandb_project)
        wandb.config.update(args)

@@ -247,7 +247,7 @@ def main(args): # noqa: C901
        to_log = {"loss": loss.item(), "states_visited": states_visited}
        if use_wandb:
            wandb.log(to_log, step=iteration)
        if iteration % args.validation_interval == 0:
        if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1):
            validation_info = validate(
                env,
                gflownet,
                args.validation_samples,
                visited_terminating_states,
            )
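
Two changes land in train_hypergrid.py: the module-level import wandb is replaced by a lazy import inside the use_wandb branch, making wandb an optional dependency, and the validation condition gains an (iteration == n_iterations - 1) clause so the final iteration always validates. A self-contained sketch of the lazy-import pattern follows; the parser stub is illustrative, while the flag name and logic mirror the diff.

# Deferred-import pattern from this diff: wandb is imported only when
# logging is actually requested via --wandb_project.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--wandb_project", type=str, default="")
args = parser.parse_args()

use_wandb = len(args.wandb_project) > 0
if use_wandb:
    import wandb  # only needed when a project is named

    wandb.init(project=args.wandb_project)
    wandb.config.update(args)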
30 changes: 16 additions & 14 deletions tutorials/examples/train_hypergrid_multinode.py
@@ -15,7 +15,6 @@

import torch
import time
import wandb
from tqdm import tqdm, trange
from math import ceil
import torch.distributed as dist
@@ -83,6 +82,7 @@ def main(args): # noqa: C901

    use_wandb = len(args.wandb_project) > 0
    if use_wandb:
        import wandb
        wandb.init(project=args.wandb_project)
        wandb.config.update(args)

@@ -317,20 +317,21 @@ def main(args): # noqa: C901

states_visited += len(trajectories)

to_log = {"loss": loss.item(), "states_visited": states_visited}
if use_wandb:
wandb.log(to_log, step=iteration)
if iteration % args.validation_interval == 0:
validation_info = validate(
env,
gflownet,
args.validation_samples,
visited_terminating_states,
)
if my_rank == 0:
to_log = {"loss": loss.item(), "states_visited": states_visited}
if use_wandb:
wandb.log(validation_info, step=iteration)
to_log.update(validation_info)
tqdm.write(f"{iteration}: {to_log}")
wandb.log(to_log, step=iteration)
if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1):
validation_info = validate(
env,
gflownet,
args.validation_samples,
visited_terminating_states,
)
if use_wandb:
wandb.log(validation_info, step=iteration)
to_log.update(validation_info)
tqdm.write(f"{iteration}: {to_log}")

time_end = time.time()
total_time = time_end - time_start
@@ -339,6 +340,7 @@
    if (my_rank == 0):
        print ("total_time, sample_time, to_train_samples_time, loss_time, loss_backward_time, opt_time, rest_time")
        print (total_time, sample_time, to_train_samples_time, loss_time, loss_backward_time, opt_time, rest_time)

    return validation_info["l1_dist"]


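The flattened hunk above interleaves deleted and added lines, which obscures the new control flow: logging and validation now run only on rank 0, and the final iteration always validates, which also guarantees that validation_info is defined when the function returns validation_info["l1_dist"]. One plausible reading of the post-commit loop body, as a sketch (my_rank, validate, env, gflownet, and the args fields are taken from the surrounding context, so this excerpt is not runnable on its own):

# Sketch of the restructured loop body: rank 0 owns all logging and validation.
if my_rank == 0:
    to_log = {"loss": loss.item(), "states_visited": states_visited}
    if use_wandb:
        wandb.log(to_log, step=iteration)
    if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1):
        validation_info = validate(
            env,
            gflownet,
            args.validation_samples,
            visited_terminating_states,
        )
        if use_wandb:
            wandb.log(validation_info, step=iteration)
        to_log.update(validation_info)
        tqdm.write(f"{iteration}: {to_log}")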
