Add UL2 data sampling and pretraining #268

Open · wants to merge 103 commits into main from the ul2 branch
Conversation

janEbert (Contributor)

This adds pretraining using UL2 for encoder-decoder, non-causal decoder-only, and causal decoder-only models.
I have not yet run large-scale tests to see if it yields the desired training improvements, but I wanted to give others the option to take a look at the code already.

I'm also not super sure about the non-causal GPT model, but I can disable (or even remove) that part if desired.


janEbert commented Dec 14, 2022

Previously, I truncated sequences so that the maximum number of duplicated extra_id tokens would still fit and be accepted by the model, losing a bit of data most of the time. I have now changed it so the program simply errors out and asks the user to specify a longer sequence length for the model.

This is probably a worse/undesired solution, so I kept the other code in for now (but commented out).

Note that erroring out is also how the T5Dataset does it.


janEbert commented Jan 3, 2023

There were several issues still remaining in the UL2 implementation, most notably that I had only tested with a micro batch size of 1; increasing it made the decoder-only models fail.
On the UL2 sampling side, there was an issue with the S-denoisers where the mean was not positioned correctly, leading to shorter masks than desired.

The implementation now also follows the seqio implementation referenced in the UL2 paper more closely, which omits the single extra_id token for the Prefix-LM task that we previously added.

Resolved review threads: megatron/tokenizer/tokenizer.py, megatron/data/ul2_dataset.py (outdated), megatron/data/dataset_utils.py (outdated), megatron/data/ul2_dataset.py
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
max_ngrams=max_ngrams, masking_style="t5",
sampling_style=sampling_style, prefix_lm=prefix_lm,


Should we pass do_whole_word_mask=False here?

janEbert (Contributor Author):


Once again I deferred to the T5Dataset on this decision: the T5 paper does not specifically mention masking on word boundaries either, yet the T5Dataset still uses whole-word masking as the default. The UL2 paper does not mention word splitting either, so I assumed the same could be applied here. In short tests with a small dataset, I also did not see seqio splitting words. However, this is likely due to the tokenizer.

I personally think your suggestion makes sense and the T5Dataset should also have supplied this option, although it would be worth discussing. Since real-world prompts will also only split on whole words, whole-word masking provides an (IMO sensible) inductive bias in the pre-training task.

janEbert (Contributor Author):


It was not the tokenizer: the seqio implementation apparently handles word splitting itself, only placing span boundaries between words. I couldn't figure out from a quick look where this is implemented, but while the tokenizer in my tests did sub-tokenize words, words were never split across spans.
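For reference, whole-word masking typically comes down to only allowing span boundaries where a new word starts. A minimal sketch assuming a SentencePiece-style word-start marker (the marker convention and function name are illustrative, not code from this PR):

```python
def word_start_indices(pieces, word_start_marker="\u2581"):
    """Return the indices at which a masked span may start when masking
    whole words only. `pieces` are sub-word tokens; a piece that begins
    with the SentencePiece word-start marker opens a new word.
    Illustrative sketch only."""
    return [i for i, piece in enumerate(pieces)
            if i == 0 or piece.startswith(word_start_marker)]

pieces = ["\u2581unbeliev", "able", "\u2581results", "\u2581to", "day"]
print(word_start_indices(pieces))  # [0, 2, 3] -> spans never cut "able" or "day" in half
```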

Resolved review threads: megatron/data/t5_dataset.py, megatron/data/dataset_utils.py (outdated)
janEbert (Contributor Author):

Thank you so much @RaymondLi0 for the detailed review! I agree with most of what you said here. Some of the issues addressed are sadly part of the original T5Dataset implementation and do not seem to have been a problem in the past. However, I do need to handle things like unbounded probability distributions better.

# Pad and convert to NumPy.
padding_length = max_seq_length - len(tokens)
if padding_length < 0:
    raise LengthExceededError()


Maybe I'm missing something, but I have some trouble understanding how this exception does not occur almost always.
Each masked span adds an extra token in tokens_enc and in tokens_dec_in.
So the concatenation tokens = ( [bos_id] + tokens_enc + [sep_id] + tokens_dec_in ) should be longer than max_seq_length ?
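For illustration, a rough back-of-the-envelope estimate of the lengths involved (the numbers are made up for this example and not taken from the thread):

```python
# T5-style span corruption, illustrative numbers only.
L, r, m = 512, 0.15, 3.0         # input length, masking probability, mean span length

n_spans = L * r / m              # expected number of masked spans (~25.6)
enc_len = L * (1 - r) + n_spans  # unmasked tokens + one sentinel per span (~460.8)
dec_len = L * r + n_spans        # masked tokens + one sentinel per span   (~102.4)

total = 1 + enc_len + 1 + dec_len  # [bos_id] + tokens_enc + [sep_id] + tokens_dec_in
print(total)                       # ~565.2, i.e. larger than max_seq_length = 512
```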


In my decoder-only case, I uncommented

    # if is_decoder_only(model_type):
    #     # Keep space for repeated `extra_id` tokens; not the most data
    #     # efficient since we calculate this based on the maximum number
    #     # of possible `extra_id` tokens.
    #     safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
    #     truncated = len(tokens) > safe_max_seq_len
    #     tokens = tokens[:safe_max_seq_len]
    # else:

and then it worked. Just for reference, the code I'm using is https://github.com/TurkuNLP/Megatron-DeepSpeed/blob/main/megatron/data/ul2_dataset.py

janEbert (Contributor Author):


Each masked span adds an extra token in tokens_dec_in. tokens_enc is actually reduced in size since multiple tokens are replaced by a single mask.
However, you're right, len(tokens_enc + tokens_dec_in) increases by 1 per span mask.

For the decoder-only case, this is indeed too strict and should be handled better than asking the user to specify a sequence length larger than any text in the data set.

While I still dislike my hacky solution that Niklas cited for the reasons I wrote in the comment, it's probably best to keep it (with a few adjustments due to the recent changes, which can still cause the seq length to overflow).

Ideally, there would be a solution that fixes the length excess problems even for the original T5Dataset, since it can still happen in the encoder-decoder case as well.
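One possible direction, sketched here under the assumption that the expected overhead is two sentinel tokens per masked span (this is not what the PR implements, and the helper name is made up):

```python
import math

def safe_input_length(max_seq_length, masked_lm_prob, mean_span_length,
                      num_special_tokens=3):
    """Longest pre-masking sequence whose encoder + decoder concatenation is
    *expected* to fit into `max_seq_length`. Rare outliers can still overflow
    and would need a final check or truncation. Illustrative sketch only."""
    overhead_per_token = 2 * masked_lm_prob / mean_span_length
    return math.floor(
        (max_seq_length - num_special_tokens) / (1 + overhead_per_token))

print(safe_input_length(512, 0.15, 3.0))  # 462
```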

    n = np_rng.choice(ngrams[:len(cand_index_set)])
elif sampling_style is SamplingStyle.NORMAL:
    n = round(np.clip(
        np_rng.normal(loc=normal_mean),


This normal distribution has a standard deviation of 1 by default.
I didn't find any information about that in the UL2 paper, but would it make sense to scale that with the mean instead? For example scale=np.sqrt(normal_mean)

janEbert (Contributor Author) commented Feb 8, 2023


I don't think it's desired for the spans to vary in length that heavily. (Disregard, I skipped over the sqrt.) You did make me look up what's happening in the seqio implementation, though. At first glance, it seems like they're actually using a uniform distribution after all. This was really only a cursory glance, take this with a grain of salt.
See https://github.com/google-research/text-to-text-transfer-transformer/blob/f0cf9e8c51bd48699265763d01c2f8b24ae1098b/t5/data/preprocessors.py#L2958-L2960
and https://github.com/google/seqio/blob/cd5bb07f22d3b36c7c10e906ebed97ab3efd780f/seqio/utils.py#L613-L620.


Ah, that seqio code indeed looks very different from a normal distribution.
In NeMo they are using scale=np.sqrt(normal_mean): https://github.com/NVIDIA/NeMo/blob/989b07a0e87979a73572d4c37d2f19538eac9d16/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py#L465
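To make the difference concrete, a quick comparison of the three span-length samplers discussed here (a sketch with made-up parameters, not code from this PR; reading the seqio behaviour as uniform follows the discussion above):

```python
import numpy as np

rng = np.random.default_rng(0)
mean_span = 8
max_span = 2 * mean_span  # illustrative clipping / uniform upper bound

# Normal with scale=1 (the behaviour questioned above): spans cluster tightly.
scale_one = np.clip(np.round(rng.normal(mean_span, 1.0, 10_000)), 1, max_span)
# Normal with scale=sqrt(mean) (NeMo-style): noticeably wider spread.
scale_sqrt = np.clip(np.round(rng.normal(mean_span, np.sqrt(mean_span), 10_000)), 1, max_span)
# Uniform over [1, 2 * mean]: widest spread.
uniform = rng.integers(1, max_span + 1, size=10_000)

for name, x in [("scale=1", scale_one), ("scale=sqrt", scale_sqrt), ("uniform", uniform)]:
    print(f"{name:>10}: mean={x.mean():.2f}, std={x.std():.2f}")
```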

janEbert (Contributor Author):


Hard to say how much, or whether, this benefits training. I do think it could be an improvement in general, since the model sees more variance in the predictions. However, because the paper does not specifically mention modifications to the normal distribution's standard deviation, I find it hard to argue for changes here: the focus of this PR was reproducing the results of the seqio implementation (while favoring the paper's prose to resolve disagreements, such as the mean noise ratio for the PrefixLM task).


That makes sense, thanks!

janEbert (Contributor Author):


It completely eluded me that UL2R uses very different denoiser settings and also describes its implementation differently from UL2 (e.g. uniform sampling of span lengths for the R- and X-denoisers). I'll adjust the implementation to handle these.


janEbert commented Apr 6, 2023

I can finally report results... Comparing standard T5 training vs. training with UL2 or UL2R, results in lm-eval-harness were almost always better with UL2/UL2R, which should mean this code does improve evaluation results. :)


lvcc2018 commented May 2, 2023

I wonder if you have compared standard GPT training vs. training with UL2 or UL2R. If so, could you share more about the experiment settings and results? If not, perhaps I can help carry out the experiments.


janEbert commented May 2, 2023

Hey @lvcc2018, I have sadly not gotten to that yet. Since this month is abnormally busy, I assume that it will take some time for me to create and compile results for standard GPT training.

T5 Experiments

For the encoder-decoder experiments, I strongly leaned on the T5 paper.
I used xPos embeddings (to have something comparable to the T5 relative attention), SwiGLU layers without bias, Adam, and the same inverse sqrt learning rate schedule as in the T5 paper. Other parameters were all the same as for the T5-base model.

xPos embeddings, SwiGLU, and the LR schedule aren't implemented here.

GPT Experiments

If you want to try your own experiments, I would suggest taking one of the GPT papers and just comparing a run with UL2 vs. without. UL2 always used non-causal decoders, so the GPT should probably also be non-causal for fairness. However, while you can create a non-causal GPT model, the GPTDataset doesn't take this into account; you would have to implement non-causal attention masks for the GPTDataset as well.
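For anyone wanting to try that, a minimal sketch of what a prefix-LM (non-causal decoder) attention mask looks like; this is only an illustration of the idea, not the mask construction the GPTDataset would need:

```python
import numpy as np

def prefix_lm_mask(seq_len, prefix_len):
    """Boolean attention mask (True = attention allowed): bidirectional
    within the prefix, causal everywhere else. Illustrative sketch only."""
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))  # causal base
    mask[:prefix_len, :prefix_len] = True  # prefix tokens attend to each other freely
    return mask

print(prefix_lm_mask(5, 2).astype(int))
# [[1 1 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```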

Focusing on the part that is already implemented, for the appropriate UL2 training you would keep all parameters the same, except:

python pretrain_ul2.py \
    [...] \
    --pack-samples \
    --ul2-like-ul2r \
    --ul2-model-type ND
  • --ul2-model-type ND selects a non-causal decoder as the model (for a causal decoder, use CD),
  • --ul2-like-ul2r switches to more sensible random span-length sampling (uniform instead of normal, as specified in the UL2R paper), and
  • --pack-samples packs different sequences into one sample so that the number of tokens seen is similar to the GPT model (because the GPT dataset always packs samples).

janEbert force-pushed the ul2 branch 2 times, most recently from 6862a5c to a9736ff on July 4, 2023

janEbert commented Jul 4, 2023

I rebased the code on the current master. This entailed some backward-incompatible changes to T5Dataset and BertDataset so that I could support the new caching of splits_string, as done in the updated GPTDataset.


janEbert commented Jul 4, 2023

One thing that could be changed is to include the prefix_lm argument I added to GPTModel in the TransformerConfig. But I wasn't sure if that's the place for it.


pluiez commented Aug 10, 2023

(Quoting janEbert's reply above on the T5 and GPT experiments; the quoted example command additionally includes --ul2-pack-any.)

Hi @janEbert, I would like to benchmark a causal GPT and a non-causal GPT (prefix-LM) using this UL2 branch. The model is already training (1.5B parameters, 50B training tokens, global_bsz=1024, max_length=2048).

Could you please help me with the evaluation procedure? Specifically:

  1. Converting the Megatron checkpoint to Hugging Face Transformers (SwiGLU, rotary embeddings, support for PrefixLM-style attention);
  2. When evaluating the PrefixLM with lm-evaluation-harness, does the format of the model input need to be changed? Since many extra_id tokens are added during training, do I need to add bos, extra_id, sep, etc. special tokens at the appropriate positions during inference, or do any other special processing?

Thank you!

janEbert (Contributor Author):

For 1., you can use the existing Megatron-LM-GPT2 ↔ HF conversion script and add missing features like SwiGLU, RoPE, etc. to it.
For 2., I adapted an eval harness adaptor so that you can supply a mode switching token. It would make way more sense to implement this into the eval harness itself, though. Other than the mode switching token, I don't think you need to add anything else.
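Conceptually, the mode switching just means prepending the chosen denoiser token to the evaluation prompt before tokenization; a hypothetical sketch (the function name and the use of the [S] token here are illustrative, not the adaptor's actual API):

```python
def add_mode_token(prompt, mode_token="[S]"):
    """Prepend a UL2 mode-switching (denoiser) token to an evaluation prompt.
    Hypothetical helper for illustration; the adaptor wires this into the
    harness' request construction instead."""
    return f"{mode_token} {prompt}"

print(add_mode_token("Question: What is the capital of France?\nAnswer:"))
```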


pluiez commented Aug 10, 2023

@janEbert Let me confirm: do you mean that when using lm-eval-harness for inference, we don't need PrefixLM's bidirectional attention at all, and can use only CausalLM's unidirectional attention over the entire input?

janEbert (Contributor Author):

Ah, you're absolutely right, sorry! I had also changed the same adaptor to supply a bidirectional mask for the context; similarly, this would also have to be implemented in the harness.


pluiez commented Aug 15, 2023

@janEbert Hi, I have trained a baseline causal decoder model without UL2, but when training a causal/non-causal decoder that uses UL2, I encountered an out-of-bounds error when creating the blendable dataset from the UL2 datasets:
(screenshot of the out-of-bounds error omitted)

This problem seems to depend on the datasets used. For example, with certain subsets of the full set (I have 20 datasets in total), the error does not occur. However, pretrain_gpt.py can train on the same datasets without any error.

Setting ul2_pack_any=False or switching --ul2-model-type between ND and CD does not eliminate this error.

The detailed training arguments are given below:

torchrun \
--nproc_per_node 1 \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 9999 pretrain_ul2.py \
--num-layers 24 \
--hidden-size 2048 \
--num-attention-heads 16 \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--seq-length 2048 \
--no-position-embedding \
--use-rotary-position-embeddings \
--rotary-percent 1.0 \
--swiglu \
--max-position-embeddings 2048 \
--disable-bias-linear \
--transformer-impl local \
--train-iters 25000 \
--lr 3e-4 \
--lr-warmup-iters 2000 \
--lr-decay-iters 25000 \
--lr-decay-style cosine \
--min-lr 3.00e-05 \
--weight-decay 0.1 \
--clip-grad 1.0 \
--optimizer adam \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--adam-eps 1e-08 \
--micro-batch-size 8 \
--global-batch-size 128 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--sequence-parallel \
--use-distributed-optimizer \
--bf16 \
--attention-softmax-in-fp32 \
--use-flash-attn \
--data-path 0.09160467 /path/to/pile_train \ # the other datasets are not shown here for simplicity
--tokenizer-type SentencePieceTokenizer \
--vocab-file custom_tokenizer.model \
--tokenizer-model custom_tokenizer.model \
--data-impl mmap \
--log-timers-to-tensorboard \
--log-memory-to-tensorboard \
--log-world-size-to-tensorboard \
--log-validation-ppl-to-tensorboard \
--tensorboard-dir checkpoint/tensorboard \
--log-interval 1 \
--timing-log-level 0 \
--ul2-model-type ND \
--ul2-denoisers R R S X X X X \
--ul2-mean-span-lengths 3 8 0.25 3 8 64 64 \
--ul2-mask-ratios 0.15 0.15 0.25 0.5 0.5 0.15 0.5 \
--ul2-r-denoiser-token '[R]' \
--ul2-s-denoiser-token '[S]' \
--ul2-x-denoiser-token '[X]' \
--ul2-pack-any \
--pack-samples \
--ul2-like-ul2r \
--vocab-extra-ids 100 \
--distributed-backend nccl \
--use-checkpoint-opt_param-scheduler \
--save-interval 1000 \
--eval-interval 1000 \
--save checkpoint 

janEbert (Contributor Author):

Thanks for the notification about BlendableDataset! I'll have to see why it doesn't work.


pluiez commented Aug 15, 2023

@janEbert Hi, I have managed to start the training jobs by making a temporary fix in BlendableDataset.__getitem__, but I'm not sure whether this could introduce side effects (possibly duplicate samples in the data-parallel group?). Looking forward to your fix anyway.
(screenshot of the temporary fix omitted)

janEbert (Contributor Author):

Just a shower thought, but it could be the following:
UL2Dataset takes the "max" in max_num_samples literally, because I assumed from the variable name that a dataset may contain fewer samples than this specifies.
BlendableDataset uses the values passed as max_num_samples for each dataset and sums them up for its datasets' sizes, but doesn't check whether the datasets actually have that size. Accordingly, when datasets have fewer entries, it can index into non-existent indices.

I guess your fix is fine for now. Ideally, the lines of code I referenced would be executed after the sub-datasets have been created and would sum up len({train,valid,test}_dataset) instead.

I'll fix that next week, thank you so much for posting about the error. :)
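For reference, the sizing fix described above amounts to something like the following (a simplified sketch, not the real BlendableDataset code, which also builds per-sample dataset/index mappings):

```python
def blended_size(datasets):
    """Total size of a blend computed from what the sub-datasets actually
    contain, rather than from the requested `max_num_samples` values.
    Simplified sketch only."""
    return sum(len(d) for d in datasets)
```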


pluiez commented Aug 19, 2023

OK, I will use this fix for the time being to continue the follow-up experiments. Looking forward to your good news!

`_initalize` → `_initialize`
Handles backward-compatibility, so the rest of the code base does not
need to change.
Namely sampling from uniform and normal distributions.
That is, datasets that use `get_samples_mapping`. File names now include
the splits string (for caching purposes), similar to `GPTDataset`.
Previously, dataset sizes were calculated from `max_num_samples`, which is correct only if datasets always produce exactly `max_num_samples` samples.

However, as the name suggests, datasets can return fewer elements, which
is now handled by instead summing the actual lengths of the created
datasets.

janEbert commented Sep 1, 2023

Hey, sorry I took longer than I suggested. I hope this fixes the issue you're facing! I also rebased the branch on top of main.

janEbert (Contributor Author):

I think there are issues after all; with a decoder-only model, I get much worse evaluation results compared to GPT-style pretraining.
One issue that has already come to mind is that spans can end up directly next to each other, which means that effective spans can become much longer than desired. Other than that, I think I really need to implement some tests to make sure the code does what it should in all cases.
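For illustration, one way the neighboring-spans issue could be checked is to require at least one unmasked token between chosen spans (a sketch, not something in this PR):

```python
def spans_are_separated(spans, min_gap=1):
    """Check that no two masked spans touch; `spans` are (start, end) pairs
    with exclusive ends. Illustrative sketch only."""
    spans = sorted(spans)
    return all(next_start - prev_end >= min_gap
               for (_, prev_end), (next_start, _) in zip(spans, spans[1:]))

print(spans_are_separated([(2, 5), (5, 8)]))  # False: the spans touch and effectively merge
print(spans_are_separated([(2, 5), (6, 8)]))  # True
```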

janEbert (Contributor Author):

While I didn't address the neighboring spans issue, it seems at some point I broke the causal targets. These are now fixed.

This fixes an upstream mistake.
Of course, if we want to hide everything, the masking probability should
not be zero.
github-actions bot commented Nov 27, 2023

Marking as stale. No activity in 60 days.

github-actions bot added the stale label on Nov 27, 2023

eagle705 commented Dec 6, 2023

Are there any plans for additional commits or merges?
I am interested in the UL2 decoder-only part.


janEbert commented Dec 6, 2023

I haven't had time to continue my UL2 experiments, so I cannot verify the correctness of the current code. There are no specific feature plans, but if there are bugs I'll definitely fix them.
I don't think this will be merged, since NVIDIA has its own UL2 implementation in NeMo (albeit with differences from the paper) and this PR is quite a major change to the Megatron-LM code base.

github-actions bot removed the stale label on Dec 6, 2023
rraminen pushed a commit to rraminen/Megatron-LM that referenced this pull request Dec 12, 2023
We have a runtime checker, and it reports:
```
ResourceWarning: unclosed file <_io.TextIOWrapper name='deepspeed_config_13B.json' mode='r' encoding='utf-8'>
  open(args.deepspeed_config, 'r', encoding='utf-8'))
```
This happens because with `json.load(open(...))`, the file is never closed.
That is, the samples dict was created with
`max_seq_length_dec=max_seq_length`, which is obviously wrong.
It's possible to obtain empty sequences during packing. This handles
this edge case by breaking the packing loop; the reason we break is that
it's almost certain that we will not obtain another empty sequence
afterwards.

github-actions bot commented Mar 5, 2024

Marking as stale. No activity in 60 days.

github-actions bot added the stale label on Mar 5, 2024