The Multi-GPU training actually duplicates data in each GPU? #37
Another piece of information: the recently updated multi-GPU sampling script "sample_seq2seq.py" (Lines 120 to 121 in bea43e1) has an operation to "split data per gpu". However, the training scripts ("diffuseq/text_datasets.py" and "train_util.py") have no such operation to split data per GPU, so I conjecture that the data is actually duplicated in each GPU in the existing multi-GPU training script.
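(For context, the per-rank split used at sampling time can be sketched roughly as below. This is an illustrative sketch, not the repo's exact code; split_per_gpu is a hypothetical helper, and rank/world_size come from torch.distributed.)

import torch.distributed as dist

def split_per_gpu(test_data):
    # keep only the test cases assigned to this rank (illustrative sketch)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # each rank takes a disjoint, interleaved slice of the test set
    return [x for i, x in enumerate(test_data) if i % world_size == rank]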
Hi, good question! We follow the training script in Diffusion-LM's repo script/run_train.py. The train_run.py you mentioned uses run_clm.py to train the classifier instead of the LM itself. It's true that we "split data per gpu" when we do sampling; that's because we only want to iterate over each test case once and in order. However, when training, we set
Thank you for your reply! But I am still confused.
To verify my point, we can turn
But with the existing multi-GPU training script, the data is duplicated in each GPU, and it will still run 800 iterations for 4-GPU training.
Another piece of information: we can observe that in OpenAI's improved-diffusion, the dataloader is constructed in
(refer to openai/improved-diffusion/improved_diffusion/image_datasets.py). In contrast, in Diffusion-LM's
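(As a rough illustration of the sharding idea in improved-diffusion's image_datasets.py, sketched from memory rather than quoted verbatim: each MPI rank keeps only its own interleaved slice of the file list.)

from mpi4py import MPI

def shard_file_list(all_files):
    # each rank keeps an interleaved shard of the dataset (illustrative sketch)
    shard = MPI.COMM_WORLD.Get_rank()
    num_shards = MPI.COMM_WORLD.Get_size()
    return all_files[shard:][::num_shards]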
@summmeer
@Dawn-LX I experienced the same problem, and solved it by implementing a custom seeding function.

def seed_all(seed: "Any", deterministic: "bool" = False) -> "None":
    import os
    import random
    import numpy as np
    import torch
    if deterministic:  # False in training, True in sampling
        seed = hash(seed)
        torch.backends.cudnn.deterministic = True  # NOQA
        torch.backends.cudnn.benchmark = False  # NOQA
    else:
        seed = hash(seed) + int(os.environ.get("LOCAL_RANK", "0"))  # make the seed differ by node rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # contains torch.cuda.manual_seed_all

Even though the seed differs per node, the model parameters are synchronized in the initialization process, refer to this (by this line).
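(A possible way to call it, hypothetical usage rather than part of the original comment: seed per rank during training so each GPU draws different batches, and identically and deterministically during sampling.)

# training: the effective seed differs per rank, so each GPU shuffles differently
seed_all(args.seed, deterministic=False)

# sampling: identical, reproducible seed on every rank
seed_all(args.seed, deterministic=True)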
If you want to use the same seed on every node, you can consider alternative code like the one below.

Step 1. Change this function (DiffuSeq/diffuseq/text_datasets.py, Line 11 in bea43e1) to the code below:
def load_data_text(
    batch_size,
    seq_len,
    deterministic=False,
    data_args=None,
    model_emb=None,
    split='train',
    loaded_vocab=None,
    loop=True,
    seed=None,  # ADD THIS
):
    training_data = get_corpus(data_args, seq_len, split=split, loaded_vocab=loaded_vocab)

    dataset = TextDataset(
        training_data,
        data_args,
        model_emb=model_emb
    )

    # requires `import os` and `import torch` at module level
    if seed is not None:
        batch_generator = torch.Generator()
        batch_generator.manual_seed(hash(seed) + int(os.environ.get("LOCAL_RANK", "0")))
    else:
        batch_generator = None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,  # 20,
        # drop_last=True,
        shuffle=not deterministic,
        num_workers=0,
        generator=batch_generator,  # ADDED
    )
    if loop:
        return infinite_loader(data_loader)
    else:
        # print(data_loader)
        return iter(data_loader)

Step 2. In run_train.py (lines 44~63, Line 44 in bea43e1), pass the seed to load_data_text:
data = load_data_text(
    batch_size=args.batch_size,
    seq_len=args.seq_len,
    data_args=args,
    loaded_vocab=tokenizer,
    model_emb=model_weight,  # use model's weights as init
    seed=args.seed  # ADDED
)
next(data)

data_valid = load_data_text(
    batch_size=args.batch_size,
    seq_len=args.seq_len,
    data_args=args,
    split='valid',
    deterministic=True,
    loaded_vocab=tokenizer,
    model_emb=model_weight,  # using the same embedding weight as the training data
    seed=args.seed  # ADDED
)
Hi, @Dawn-LX In our code, if each GPU loads different batches, then for the current gradient update it is the same as using a larger batch size (the benefit of multi-GPU).
If you want to train a total of 200 iters on 4 GPUs, you can set iters to 50 so that the 4 GPUs will load 50*4 iters of data. It may be true that this way is not strictly the same as
@kdha0727 Thanks for your clarification, but in the code, each node loading a different batch of data is what we wanted, and this function is already implemented. I'm a little bit confused by your PR.
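(A back-of-the-envelope check of this reasoning, using my own example numbers and assuming a per-GPU batch of 64:)

# effective batch size grows with the number of GPUs once each rank sees different data
per_gpu_batch = 64
num_gpus = 4
effective_batch = per_gpu_batch * num_gpus        # 256 samples per gradient update

# to consume roughly the same number of samples as 200 single-GPU iterations:
single_gpu_iters = 200
multi_gpu_iters = single_gpu_iters // num_gpus    # 50 iterations on 4 GPUs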
I understood that you implemented a different data loader within
My PR intended to fix this issue. By setting only the generator's seed differently, the other random operations will go on with the same seed in each process!
I tested the outputs, and checked that the current code returns the same data outputs.

python -m torch.distributed.launch --nproc_per_node=4 --master_port=12233 --use_env run_train.py --diff_steps 2000 --lr 0.0001 --learning_steps 140000 --save_interval 20000 --seed 102 --noise_schedule sqrt --hidden_dim 128 --bsz 2048 --microbatch 64 --dataset dialogue --data_dir datasets/CommonsenseConversation --vocab bert --seq_len 128 --schedule_sampler lossaware --notes dialogue
"""
Train a diffusion model on images.
"""
import argparse
import json, torch, os
import numpy as np
from diffuseq.utils import dist_util, logger
from diffuseq.text_datasets import load_data_text
from diffuseq.step_sample import create_named_schedule_sampler
from basic_utils import (
load_defaults_config,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
load_model_emb,
load_tokenizer
)
from train_util import TrainLoop
from transformers import set_seed
import wandb
### custom your wandb setting here ###
# os.environ["WANDB_API_KEY"] = ""
os.environ["WANDB_MODE"] = "offline"
def create_argparser():
defaults = dict()
defaults.update(load_defaults_config())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults) # update latest args according to argparse
return parser
def main():
args = create_argparser().parse_args()
set_seed(args.seed)
dist_util.setup_dist()
logger.configure()
logger.log("### Creating data loader...")
tokenizer = load_tokenizer(args)
model_weight, tokenizer = load_model_emb(args, tokenizer)
data = load_data_text(
batch_size=args.batch_size,
seq_len=args.seq_len,
data_args = args,
loaded_vocab=tokenizer,
model_emb=model_weight # use model's weights as init
)
from torch.distributed import barrier, get_rank
barrier()
print(get_rank(), next(data)[1]['input_ids'][0])
import sys
sys.exit(0)
|
Hi @kdha0727 You're right. I just tested the case when
Yes, my method is fine for infinite loops; however, considering more general cases, DistributedSampler would be a more compact solution. Thank you for reviewing!
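(For reference, a minimal sketch of the DistributedSampler approach, my own illustration rather than code from the PR; it assumes the default process group is already initialized, and set_epoch must be called every epoch so the shuffling differs across epochs.)

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_distributed_loader(dataset, batch_size, seed=0):
    # each rank gets a disjoint shard of the dataset; requires torch.distributed to be initialized
    sampler = DistributedSampler(dataset, shuffle=True, seed=seed)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=0)
    return loader, sampler

# usage:
# loader, sampler = make_distributed_loader(dataset, batch_size=64)
# for epoch in range(num_epochs):
#     sampler.set_epoch(epoch)   # reshuffle differently each epoch
#     for batch in loader:
#         ...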
@summmeer @kdha0727
Thank you both again for your contributions, from which I learned a lot about torch's DataLoader and random seeds.
Another small question (although it might not be necessary now): in
@Dawn-LX
Thank you very much!
Wait, I'm still confused.
a) I understand that we construct
b) On the other hand, process 0 calls
So why does process 0 get a different data batch in the dataloader's iteration (i.e., use a different random seed)? In my view, all processes use the same seed, according to the description in a). Why does process 0 doing the things in b) make its dataloader use another seed? I mean, why does the seed in
@Dawn-LX It doesn't mean that process 0 uses another seed. In the original code all processes use the same seed; however, process 0 has ONLY 1 MORE RANDOM OPERATION before
@kdha0727 That means that, in contrast to processes 1, 2, and 3, process 0 gets a different result for this line
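(To make the "one more random operation" point concrete, a tiny self-contained example of my own: both runs start from the same seed, but one extra random call shifts the global RNG state, so the subsequent shuffle order differs.)

import torch

# what processes 1-3 effectively do: seed, then shuffle
torch.manual_seed(102)
perm_other_ranks = torch.randperm(8)

# what process 0 effectively does: seed, one extra random operation, then shuffle
torch.manual_seed(102)
_ = torch.rand(1)                 # the "one more random operation"
perm_rank0 = torch.randperm(8)

print(perm_other_ranks)           # some permutation of 0..7
print(perm_rank0)                 # a different permutation, despite the same seed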
Hello.
I find that the DataLoader constructed in diffuseq/text_datasets.py (DiffuSeq/diffuseq/text_datasets.py, Line 47 in bea43e1) does not use PyTorch's DistributedSampler, which means the data is actually duplicated in each GPU, e.g., in the function forward_backward in train_util.py (DiffuSeq/train_util.py, Line 235 in bea43e1).
I.e., each GPU is processing the same data, which makes distributed training pointless.
Is my conjecture correct?
Just FYI, the training script in Diffusion-LM's repo, train_run.py, uses transformers' training script run_clm.py, in which DistributedSampler is used in the Trainer.