Fix multi-gpu sampling (#33)
* Fix multi-gpu sampling for efficiency

* Fix initialization of model_embedding

* Remove model_emb_copy
kdha0727 committed Mar 17, 2023
1 parent bdc8f0a commit bea43e1
Showing 2 changed files with 54 additions and 36 deletions.
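
The core of this change is to shard the test set across ranks instead of letting every GPU decode the full set. A minimal sketch of the round-robin split and the dummy-entry trick, using illustrative names that are not the repository's API:

def split_per_rank(batches, rank, world_size):
    # Rank r keeps batch i whenever i % world_size == r (round-robin split).
    mine = [b for i, b in enumerate(batches) if i % world_size == rank]
    # If the batch count does not divide evenly, ranks past the remainder get an
    # empty dict, so every rank later executes the same number of dist.barrier() calls.
    remainder = len(batches) % world_size
    if remainder and rank >= remainder:
        mine.append({})
    return mine

In the diff below the same idea is folded directly into the loop that reads data_valid and into the per-batch barrier handling.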
4 changes: 4 additions & 0 deletions diffuseq/utils/dist_util.py
@@ -37,6 +37,10 @@ def setup_dist():

dist.init_process_group(backend=backend, init_method="env://")

if th.cuda.is_available(): # This clears remaining caches in GPU 0
th.cuda.set_device(dev())
th.cuda.empty_cache()


def dev():
"""
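
Without an explicit set_device call, every process that touches CUDA first creates its context on GPU 0, which is the situation the new lines guard against. A rough sketch of the intent, assuming dev() resolves to the LOCAL_RANK GPU exported by the launcher (the real helper lives in diffuseq/utils/dist_util.py and may differ):

import os
import torch as th

def dev():
    # Assumption: one process per GPU, with LOCAL_RANK set by the launcher.
    if th.cuda.is_available():
        return th.device(f"cuda:{os.environ.get('LOCAL_RANK', 0)}")
    return th.device("cpu")

if th.cuda.is_available():
    th.cuda.set_device(dev())  # pin this rank to its own GPU before any allocation
    th.cuda.empty_cache()      # per the commit comment, drops caches left on GPU 0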
86 changes: 50 additions & 36 deletions sample_seq2seq.py
@@ -11,7 +11,7 @@
import torch as th
import torch.distributed as dist
from transformers import set_seed
from diffuseq.rounding import denoised_fn_round, get_weights
from diffuseq.rounding import denoised_fn_round
from diffuseq.text_datasets import load_data_text

# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
@@ -24,7 +24,6 @@
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
load_model_emb,
load_tokenizer
)

@@ -38,12 +37,16 @@ def create_argparser():
return parser


@th.no_grad()
def main():
args = create_argparser().parse_args()

dist_util.setup_dist()
logger.configure()

world_size = dist.get_world_size() or 1
rank = dist.get_rank() or 0

# load configurations.
config_path = os.path.join(os.path.split(args.model_path)[0], "training_args.json")
print(config_path)
Expand All @@ -65,14 +68,14 @@ def main():
pytorch_total_params = sum(p.numel() for p in model.parameters())
logger.log(f'### The parameter count is {pytorch_total_params}')

model.to(dist_util.dev())
model.eval()
model.eval().requires_grad_(False).to(dist_util.dev())

tokenizer = load_tokenizer(args)
model_emb, tokenizer = load_model_emb(args, tokenizer)

model_emb.weight = th.nn.Parameter(model.word_embedding.weight.clone().cpu())
model_emb_copy = get_weights(model_emb, args)
model_emb = th.nn.Embedding(
num_embeddings=tokenizer.vocab_size,
embedding_dim=args.hidden_dim,
_weight=model.word_embedding.weight.clone().cpu()
).eval().requires_grad_(False)

set_seed(args.seed2)

@@ -86,12 +89,12 @@ def main():
data_args=args,
split=args.split,
loaded_vocab=tokenizer,
model_emb=model_emb.cpu(), # using the same embedding wight with tranining data
model_emb=model_emb.cpu(), # using the same embedding wight with tranining data
loop=False
)

start_t = time.time()

# batch, cond = next(data_valid)
# print(batch.shape)

@@ -108,18 +111,36 @@ def main():

all_test_data = []

idx = 0

try:
while True:
batch, cond = next(data_valid)
# print(batch.shape)
all_test_data.append(cond)
if idx % world_size == rank: # Split data per nodes
all_test_data.append(cond)
idx += 1

except StopIteration:
print('### End of reading iteration...')

from tqdm import tqdm

for cond in tqdm(all_test_data):
model_emb.to(dist_util.dev())

if idx % world_size and rank >= idx % world_size:
all_test_data.append({}) # Dummy data for Remainder : for dist.barrier()

if rank == 0:
from tqdm import tqdm
iterator = tqdm(all_test_data)
else:
iterator = iter(all_test_data)

for cond in iterator:

if not cond: # Barrier for Remainder
for i in range(world_size):
dist.barrier()
continue

input_ids_x = cond.pop('input_ids').to(dist_util.dev())
x_start = model.get_embeds(input_ids_x)
@@ -128,7 +149,7 @@ def main():

noise = th.randn_like(x_start)
input_ids_mask = th.broadcast_to(input_ids_mask.unsqueeze(dim=-1), x_start.shape).to(dist_util.dev())
x_noised = th.where(input_ids_mask==0, x_start, noise)
x_noised = th.where(input_ids_mask == 0, x_start, noise)

model_kwargs = {}

@@ -150,7 +171,7 @@ def main():
sample_shape,
noise=x_noised,
clip_denoised=args.clip_denoised,
denoised_fn=partial(denoised_fn_round, args, model_emb_copy.cuda()),
denoised_fn=partial(denoised_fn_round, args, model_emb),
model_kwargs=model_kwargs,
top_p=args.top_p,
clamp_step=args.clamp_step,
@@ -160,31 +181,20 @@ def main():
gap=step_gap
)

model_emb_copy.cpu()
# print(samples[0].shape) # samples for each step

sample = samples[-1]
gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
dist.all_gather(gathered_samples, sample)
all_sentence = [sample.cpu().numpy() for sample in gathered_samples]

# print('sampling takes {:.2f}s .....'.format(time.time() - start_t))
# print('decoding for seq2seq', )
# print(sample.shape)

logits = model.get_logits(sample) # bsz, seqlen, vocab
cands = th.topk(logits, k=1, dim=-1)

word_lst_recover = []
word_lst_ref = []
word_lst_source = []


arr = np.concatenate(all_sentence, axis=0)
x_t = th.tensor(arr).cuda()
# print('decoding for seq2seq', )
# print(arr.shape)

reshaped_x_t = x_t
logits = model.get_logits(reshaped_x_t) # bsz, seqlen, vocab

cands = th.topk(logits, k=1, dim=-1)
sample = cands.indices
# tokenizer = load_tokenizer(args)

for seq, input_mask in zip(cands.indices, input_ids_mask_ori):
@@ -198,13 +208,17 @@ def main():
word_lst_source.append(tokenizer.decode_token(seq[:len_x]))
word_lst_ref.append(tokenizer.decode_token(seq[len_x:]))

fout = open(out_path, 'a')
for (recov, ref, src) in zip(word_lst_recover, word_lst_ref, word_lst_source):
print(json.dumps({"recover": recov, "reference": ref, "source": src}), file=fout)
fout.close()
for i in range(world_size):
if i == rank: # Write files sequentially
fout = open(out_path, 'a')
for (recov, ref, src) in zip(word_lst_recover, word_lst_ref, word_lst_source):
print(json.dumps({"recover": recov, "reference": ref, "source": src}), file=fout)
fout.close()
dist.barrier()

print('### Total takes {:.2f}s .....'.format(time.time() - start_t))
print(f'### Written the decoded output to {out_path}')


if __name__ == "__main__":
main()
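
For the second and third bullets of the commit message: the sampler now hands denoised_fn_round a plain frozen nn.Embedding built from model.word_embedding, and the separate model_emb_copy / get_weights path is gone. Conceptually, the rounding step snaps each predicted vector to its nearest vocabulary embedding; a rough sketch of that idea only (the real implementation is in diffuseq/rounding.py and differs in detail):

import torch as th

def nearest_vocab_embedding(x_t, emb_weight):
    # x_t: (bsz, seqlen, hidden); emb_weight: (vocab, hidden)
    flat = x_t.reshape(-1, x_t.size(-1))
    weight = emb_weight.to(flat.device)
    dists = th.cdist(flat, weight)  # (bsz * seqlen, vocab) pairwise L2 distances
    ids = dists.argmin(dim=-1)      # index of the closest vocabulary row
    return weight[ids].view_as(x_t)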
