Skip to content

Commit

Permalink
Changed the distributed computing backend from MPI to CCL
Browse files: browse the repository at this point in the history
  • Loading branch information
sanchit-misra committed May 28, 2024
1 parent bd7bed0 commit 19aff1b
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tutorials/examples/train_hypergrid_multinode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import time
import wandb
from tqdm import tqdm, trange
from math import ceil
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

Expand All @@ -43,7 +44,7 @@ def dist_init():

global my_rank
global my_size
dist_backend = "mpi"
dist_backend = "ccl"
if int(os.environ.get("PMI_SIZE", "0")) > 1:
if dist_backend == "ccl":
try:
Expand Down Expand Up @@ -268,7 +269,8 @@ def main(args): # noqa: C901
visited_terminating_states = env.states_from_batch_shape((0,))

states_visited = 0
n_iterations = args.n_trajectories // args.batch_size // world_size
n_iterations = ceil(args.n_trajectories / args.batch_size)
my_batch_size = args.batch_size // world_size
validation_info = {"l1_dist": float("inf")}
sample_time = 0
to_train_samples_time = 0
Expand All @@ -277,11 +279,12 @@ def main(args): # noqa: C901
opt_time = 0
rest_time = 0
print ("n_iterations = ", n_iterations)
print ("my_batch_size = ", my_batch_size)
time_start = time.time()
for iteration in trange(n_iterations):
sample_start = time.time()
trajectories = gflownet.sample_trajectories(
env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling
env, n_samples=my_batch_size, sample_off_policy=off_policy_sampling
)
sample_end = time.time()
sample_time += (sample_end - sample_start)
Expand All @@ -292,7 +295,7 @@ def main(args): # noqa: C901
if replay_buffer is not None:
with torch.no_grad():
replay_buffer.add(training_samples)
training_objects = replay_buffer.sample(n_trajectories=args.batch_size)
training_objects = replay_buffer.sample(n_trajectories=my_batch_size)
else:
training_objects = training_samples

Expand Down

0 comments on commit 19aff1b

Please sign in to comment.