
nanoT5 initializes lm_head weights with 768x too much variance, probably #25

Birch-san opened this issue Dec 26, 2023 · 19 comments

Birch-san commented Dec 26, 2023

it looks like your lm_head weight init is the same as HF's, which has a known problem: it doesn't match the original Mesh TensorFlow weight init:
huggingface/transformers#26441

I believe that when training T5 with an untied lm_head, you want to initialize the lm_head weights with std=hidden_dim**-0.5, i.e. about std=0.036 for hidden_dim=768.

currently the lm_head is initialized with std=1, which is about 27.7x too much std, or 768x too much variance.
https://github.com/PiotrNawrot/nanoT5/blob/1c82d67bf8dea635be68a3b2a68a43b68b665193/nanoT5/utils/t5_model.py#L505C26-L505C26
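
For concreteness, here is a minimal sketch of the proposed fix; the function and argument names are illustrative, not the exact nanoT5 code at that line:

    import torch.nn as nn

    def init_lm_head(lm_head: nn.Linear, d_model: int = 768) -> None:
        # Proposed init for an untied lm_head: std = d_model ** -0.5
        # (about 0.036 for d_model = 768) instead of the current std = 1.0.
        lm_head.weight.data.normal_(mean=0.0, std=d_model ** -0.5)
        if lm_head.bias is not None:
            lm_head.bias.data.zero_()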

@PiotrNawrot (Owner)

I'll take a closer look at it and its effects soon, thanks for pointing this out!

@Birch-san (Author)

thanks! yeah, I wonder if this will have an impact on the hyperparam search. maybe fixing the init will mean a different lr schedule can be used.


w11wo commented Feb 1, 2024

Hi. I have been trying to pre-train a T5 model on Indonesian via nanoT5, and I saw this issue. Specifically, I am training on the Indonesian subset of CulturaX and trained my own tokenizer using this script. I wanted to share my experiment results for further investigation.

  • I ran my first experiment without any changes to the lm_head initialization (std=1.0), and I noticed that the training loss converges more slowly than the English results reported here (I also aim to get < 10 ppl on this corpus).
  • In my second experiment, I implemented the changes proposed and discussed here (lm_head init with std=dim**-0.5). As expected, the loss starts much lower and converges faster (closer to the reported results), which is great! But somewhere after 20k steps the loss increases, and the final loss is still rather high (~2.60). It was also around this step that the grad_l2 became rather large again (~80-120 near the end).

These are my loss graphs:
[image: loss graphs]

In both experiments, I ran:

python -m nanoT5.main \
    optim.name=adamwscale \
    optim.lr_scheduler=cosine \
    model.compile=true

I would be interested to know:

  • whether this high-loss issue is specific to my dataset, or
  • whether there are potential fixes related to the model weight init.

Thanks!

@PiotrNawrot (Owner)

@w11wo Thanks a lot for this comment and the extra observations. I've been very busy lately because of the ICML submission (the deadline is tomorrow). Right after ICML, and some rest over the weekend, I plan to come back and investigate what's been going on here.

Also, if you @w11wo, @Birch-san, or anyone else are interested in researching this initialisation issue and its effect on training as a small project, with the potential of submitting it together to a workshop at a top-tier conference, please reach out :)


w11wo commented Feb 1, 2024

No worries @PiotrNawrot. I am running some more experiments at the moment, and I'll be sharing them once they're done.

I'm very much interested in researching this issue and potentially submitting it to a workshop!

@PiotrNawrot (Owner)

Take a look also at #15 - we've already had some findings with regard to this. Generally, the whole init bug could be the reason behind the RMS scaling trick I had to incorporate from Adafactor to make this thing work.
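
(For context, a rough sketch of what Adafactor-style RMS scaling does; this is illustrative only, not the exact AdamWScale implementation:)

    import torch

    def rms(t: torch.Tensor) -> torch.Tensor:
        # Root-mean-square of a tensor, as in Adafactor's parameter scaling.
        return t.norm(2) / (t.numel() ** 0.5)

    @torch.no_grad()
    def scaled_update(param: torch.Tensor, adam_update: torch.Tensor,
                      lr: float, eps: float = 1e-3) -> None:
        # Scale the step size by the RMS of the parameter being updated, so
        # tensors with small weights take proportionally smaller steps.
        param.add_(adam_update, alpha=-lr * max(eps, rms(param).item()))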

Also, when it comes to your training curves, I think the LR could be too high. I ran extensive HP searches to find the best (largest) LR that doesn't crash training on C4, and that value could simply be too high for your dataset.

Trying a smaller LR would be one of the first things I'd do if you want your training working asap.
For a better analysis we would have to run some more experiments with:

{AdamW, AdamWScaled} x {Different inits} x {Different LRs}

We can set up a call and some working space next week to discuss it further :)


w11wo commented Feb 1, 2024

Hi @PiotrNawrot. Yeah, you're right about the LR being too high. I decreased the LR to 5e-3 and the model converged nicely to a loss of about 2.3.

I am now looking into replacing the optimizer with regular AdamW and am trying to find the right LR for it. My current training losses are as follows:

[image: loss graphs]

So far I have not been able to get AdamW to converge. It looks like the LR is still too high for regular AdamW. I have another experiment currently running with lr=5e-5. I'll continue to report the results.

I'm cool with a call next week. Cheers.


Birch-san commented Feb 1, 2024

happy to join in investigating this. can do a call next week.

there are a few things I was wondering:

  • is NAdamW a better choice than AdamW? it's built into torch now.
  • does 8-bit AdamW perform just as well as 32-bit AdamW?
  • how about Shampoo and its successor, CASPR? they seem to be popular
  • if we really want to reproduce the T5 paper, shouldn't we use the same training objective? HF's run_mlm script (which I assume you based nanoT5 on, because it's the only example they provided for torch) doesn't implement MLM anything like the paper describes. they have a more faithful version, run_t5_mlm_flax, which I think has the same bugs as the original Google code. I've fixed Google's bugs in my implementation and ported run_t5_mlm_flax's T5-style MLM objective to torch, but I never tried a long training run on a standard dataset or model.
  • fixing our weight init may enable us to consider different (simpler) lr schedules. but cosine should already be pretty good. I asked Stella about this, and it sounds like it's pretty standard to do "linear warm-up to a peak value, then cosine decay to 10% of that peak" (a minimal sketch of that schedule follows below). so maybe we don't need to change much here.
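
As a reference, here is one way to express that schedule in plain PyTorch; warmup_steps, total_steps, and the 10% floor are placeholders rather than nanoT5's actual settings:

    import math
    from torch.optim.lr_scheduler import LambdaLR

    def warmup_then_cosine(optimizer, warmup_steps: int, total_steps: int,
                           final_frac: float = 0.1) -> LambdaLR:
        # Linear warm-up to the peak LR, then cosine decay down to
        # final_frac (here 10%) of that peak.
        def lr_lambda(step: int) -> float:
            if step < warmup_steps:
                return step / max(1, warmup_steps)
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            return final_frac + (1.0 - final_frac) * 0.5 * (1.0 + math.cos(math.pi * progress))
        return LambdaLR(optimizer, lr_lambda)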


w11wo commented Feb 2, 2024

Interesting ideas @Birch-san!

Regarding the training objective used in nanoT5, I think it already uses HF's T5 objective from Flax, as found here. I believe it's a direct port of HF's FlaxDataCollatorForT5MLM, which, as you mentioned, still has those bugs.

In the meantime, I'm still struggling to get AdamW to converge even with the fixed lm_head initialization. I've tried different magnitudes of LRs as follows:

[image: loss graphs]

AdamW with lr=5e-3 seems to be too high, while lr=5e-4 starts off very promisingly and similarly to AdamWScaled (orange), but fails to do a double descent.

@PiotrNawrot do you happen to recall what LR you used when you tried regular Adam (I'm assuming it's not AdamW?) in #15 (comment)?

Otherwise, any suggestions as to what to try next? I'd be very curious!

Thanks!

w11wo added a commit to LazarusNLP/nanoT5 that referenced this issue Feb 2, 2024

w11wo commented Feb 5, 2024

Hey everyone. We (@DavidSamuell, @stevenlimcorn, and I) did some more experiments and eventually found out that our model config was actually incorrect (there was a mismatch between the model config's vocab size and the tokenizer's vocab size, leading to very poor accuracies). That means all of our loss graphs above are likely irrelevant, so please ignore them.

Because of that, we reverted to AdamWScaled (with fixed init) just to get a baseline result first. Like the previous incorrect runs, lr=2e-2 diverges, while lr=5e-3 converges pretty nicely to an eval loss of about 2.0, which is roughly similar to the English results!

We then took that model and fine-tuned it on a summarization task (IndoNLG/indosum) and to our surprise, this baseline model achieved state-of-the-art results! It outperformed mT5, mBART, and IndoBART. We have yet to evaluate on other downstream tasks, but we think this result is pretty strong for a baseline experiment. The pre-trained Indonesian NanoT5 is here and the fine-tuned IndoSum model is here. We compared against the benchmark results reported here.

Since we resolved the core issue (the model config), we will re-launch experiments with regular AdamW (and fixed init), which will hopefully lead to a lower eval loss. We will continue to report the results here.

Cheers!

@PiotrNawrot (Owner)

@Birch-san

fixing our weight init may enable us to consider different (simpler) lr schedules

As you write later, I think that linear warm-up + cosine annealing is already pretty standard and "simple".

if we really want to reproduce the T5 paper, shouldn't we use the same training objective?

I based my code on HF's Flax implementation. I'm actually impressed that you found this bug in their implementation, but as you say, it's more crucial for short sequences and when you use padding, which isn't the case for nanoT5. Anyway, it's definitely something I'd double-check and re-run the training with, as it could boost sample efficiency even further. If you'd be interested in running this experiment and providing the MR, that'd be amazing.

In the long run, I would ideally aspire to implement the entire UL2 objective in PyTorch and re-run the nanoT5 experiments. It seems that UL2 is actually the superior way to pre-train T5 models. I think that once we have this, we could look for compute sponsorship on Twitter and start a project like TinyLlama #29.

Re other optimizers (NAdamW, 8-bit AdamW, etc.):

I would need to double-check how these differ from regular Adam, but for now I'm not very convinced about running experiments with other optimizers; for example, I'm pretty sure that 8-bit Adam would be no different from regular Adam. Please note that I ran experiments with Lion and Sophia for this paper, and both of them diverged without the LR-scaling trick and were worse than AdamWScale after adding LR-scaling.

TLDR: I would love to incorporate the fixed version of the MLM objective and check if it helps pre-training. Then I would try to fix the initialisation and find the reason why regular AdamW doesn't work - imo we're very close to finding the cause.

@PiotrNawrot (Owner)

@w11wo

It's amazing and very impressive. Congrats to you and to the entire team working on it!

I would love to look at the LM-loss graphs and discuss them in detail. Tomorrow I will also take a look at the loss curves mentioned in #15 (comment), and I will be able to run some experiments with AdamW + different inits on the regular C4 dataset. Please also take a look at my comment above.

@PiotrNawrot (Owner)

@w11wo @Birch-san @DavidSamuell @stevenlimcorn

I think it'd be more convenient to continue further discussion somewhere other than GitHub.
I created this Slack workspace. If Slack doesn't work for you, please suggest an alternative - I'm flexible.


w11wo commented Feb 6, 2024

@PiotrNawrot

Thanks! I have run a few more experiments with AdamW and NAdamW, but none of them converges nicely, unlike AdamWScaled.

To avoid cluttering this thread further, let's move to Slack as you suggested. I have also documented the loss graphs in a separate repo here.


Taytay commented Feb 7, 2024

@PiotrNawrot I am having a hard time opening that Slack link. Is it available for anyone to join?


PiotrNawrot commented Feb 7, 2024

@Taytay It should be available to anyone; maybe it expired. Try this new one - it says it will last for 29 days.

@PiotrNawrot (Owner)

@Taytay Let me know if it works!


fuxiang-chen commented Apr 18, 2024

@PiotrNawrot Could you share a new slack link? I would like to join too.

@PiotrNawrot (Owner)

@fuxiang-chen New invite link. We discussed different ideas, including specific experiments, in the #fixing-t5-init channel, but afaik no one has really tested them because people were busy / didn't have compute :/. So this problem is yet to be solved.
