Olmo tiny scripts #628
Conversation
Discussed my queries offline with @ananyahjha93:
- How were the model shapes decided? Based on Pythia, then adjusted for parameter count.
- How about the LR? Also a ballpark from Pythia.
Other things to note:
- Global batch size may also require some ablation (see the bookkeeping sketch after this list).
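For that ablation, a minimal sketch of how the config values relate (illustrative only; the variable names, the 64-GPU world size, and the 2048 sequence length are assumptions, not values from this PR):

```python
# Relates global_train_batch_size and device_train_microbatch_size to
# gradient-accumulation steps and tokens per optimizer step.
def batch_accounting(global_batch: int, micro_batch: int, world_size: int, seq_len: int = 2048):
    assert global_batch % (micro_batch * world_size) == 0, "global batch must divide evenly across devices"
    grad_accum_steps = global_batch // (micro_batch * world_size)
    tokens_per_step = global_batch * seq_len
    return grad_accum_steps, tokens_per_step

# e.g. the 300M config discussed below, assuming 64 GPUs:
# 2048 // (8 * 64) = 4 accumulation steps, 2048 * 2048 ≈ 4.2M tokens per step
print(batch_accounting(global_batch=2048, micro_batch=8, world_size=64))
```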
@@ -248,7 +248,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
     )
     cfg.save_interval_unsharded = cfg.save_interval

-    if cfg.save_num_unsharded_checkpoints_to_keep < 1:
+    if cfg.save_num_unsharded_checkpoints_to_keep == 0:
         log.warning(
What if save_num_checkpoints_to_keep is also 0?
it then assumes that you did not want to keep checkpoints at all!
-1 means you want to keep all checkpoints, so I made the check == 0 instead of < 1.
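For context, a minimal sketch of the convention being discussed (not the repo's actual validation code; the messages are made up): -1 is the keep-everything sentinel, so the warning should only fire when the value is exactly 0.

```python
import logging

log = logging.getLogger(__name__)

# Hypothetical config check: warn only on an explicit 0, since -1 (keep all)
# would otherwise also trip a "< 1" comparison.
def check_unsharded_keep_count(save_num_unsharded_checkpoints_to_keep: int) -> None:
    if save_num_unsharded_checkpoints_to_keep == 0:
        log.warning("No unsharded checkpoints will be kept.")
    elif save_num_unsharded_checkpoints_to_keep < 0:
        log.info("Keeping all unsharded checkpoints.")
```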
I assume the configs for the different sizes are all the same, so I didn't look at all of them.
configs/tiny/OLMo-300M.yaml (Outdated)
@@ -9,17 +9,15 @@ wandb:
  model:
    d_model: 1024
    n_heads: 16
-   n_layers: 16
+   n_layers: 24
    mlp_ratio: 8
    weight_tying: false
    alibi: false
    rope: true
    flash_attention: true  # not available on AMD
Is now available on AMD
removed the comment
- label: commonsense_qa
  type: downstream

- label: social_iqa
  type: downstream

- label: basic_arithmetic
  type: downstream
What's wrong with these?
ah, basic_arithmetic should be in; the others don't provide any signal based on my experience.
ah, this was commented out saying "# Doesn't work from cache."
Should work with cache v4
configs/tiny/OLMo-300M.yaml (Outdated)
  stop_at: 100_000
  global_train_batch_size: 2048
  device_train_microbatch_size: 8
  max_duration: 2ep
This means you'll run into this bug: #584
It might not matter. The problem is only that the second epoch will be shuffled the same way as the first.
I'll add a stop_at at 400k steps!
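A minimal sketch of the shuffling issue referenced in #584 (illustrative, not the repo's dataset code): if the permutation seed ignores the epoch, every epoch sees the data in the same order; mixing the epoch into the seed gives each epoch its own shuffle.

```python
import numpy as np

def epoch_order(num_examples: int, seed: int, epoch: int, mix_epoch_into_seed: bool) -> np.ndarray:
    # Mixing the epoch into the seed gives each epoch a distinct permutation.
    rng = np.random.default_rng(seed + epoch if mix_epoch_into_seed else seed)
    return rng.permutation(num_examples)

# Without the epoch in the seed, epoch 1 and epoch 2 are shuffled identically (the bug):
assert np.array_equal(epoch_order(10, 0, 1, False), epoch_order(10, 0, 2, False))
# With it, the orders (almost surely) differ:
assert not np.array_equal(epoch_order(10_000, 0, 1, True), epoch_order(10_000, 0, 2, True))
```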
@@ -9,17 +9,15 @@ wandb:
  model:
No DDP section in this file?
there is!
configs/tiny/OLMo-20M.yaml (Outdated)
  grad_clip_warmup_steps: null
  grad_clip_warmup_factor: 5
Don't have these settings.
Took these from @AkshitaB's llamaish1-normal-weka.yaml.
removed them for now!
paths:
  ######### NON WEB DATA #########
  # ~> GUTENBERG BOOKS (5.256 GT)
  - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/books/gpt-neox-olmo-dolma-v1_5/part-0-00000.npy
Can you read from weka instead?
I was planning to run on pluto, but now I can see free nodes on jupiter, so I'm making the change!
  # Unsharded checkpoints (for ddp)
  save_interval_unsharded: 5000
- save_num_unsharded_checkpoints_to_keep: 3
+ save_num_unsharded_checkpoints_to_keep: -1
What does -1 do?
-1 is for keeping all checkpoints, but I'll double check
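Complementing the validation sketch earlier, here is the pruning behavior implied by the -1 sentinel (a sketch, not the trainer's actual checkpoint code):

```python
def prune_unsharded_checkpoints(checkpoint_dirs: list[str], num_to_keep: int) -> list[str]:
    """Return the checkpoints that should be deleted, oldest first."""
    if num_to_keep < 0:          # -1: keep every checkpoint
        return []
    if num_to_keep == 0:         # explicit 0: keep none
        return list(checkpoint_dirs)
    return list(checkpoint_dirs[:-num_to_keep])  # drop all but the N most recent

# e.g. with save_num_unsharded_checkpoints_to_keep: -1, nothing is ever pruned:
print(prune_unsharded_checkpoints(["step1000", "step2000", "step3000"], num_to_keep=-1))  # []
```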
Approved with a small comment about the long warmup.
units: tokens
t_warmup: 4194304000
t_max: 3e12
t_warmup: 5000
For normal init, this is a lot of warmup? Not a big deal, but unusual?
Smaller models, higher LR, so I did not want to take a chance! A longer warmup is never a bad idea!
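To put the numbers above in perspective, a quick conversion between warmup in tokens and warmup in optimizer steps (assumed sequence length of 2048 and the 2048 global batch from these configs; whether the new t_warmup: 5000 is a step count or a token count depends on the units line, which isn't visible in the hunk above):

```python
seq_len = 2048                              # assumed
global_batch = 2048                         # from the configs in this PR
tokens_per_step = seq_len * global_batch    # 4_194_304 tokens per optimizer step

old_warmup_tokens = 4_194_304_000           # the previous token-based warmup
print(old_warmup_tokens / tokens_per_step)  # -> 1000.0 steps

new_warmup = 5_000
print(new_warmup * tokens_per_step / 1e9)   # -> ~21 B tokens, if 5000 is interpreted as steps
```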
max_duration: 1ep
stop_at: 406_934
Do you need both max_duration and stop_at?
Yes; from what I have observed (and Dave mentioned), the training goes past max_duration if stop_at is not set.
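A minimal sketch of the kind of stop condition implied here (not the trainer's actual loop): checking an absolute stop_at step alongside the epoch-based max_duration keeps the run from overshooting.

```python
from typing import Optional

def should_stop(global_step: int, completed_epochs: int, max_epochs: int,
                stop_at: Optional[int]) -> bool:
    # stop_at is an absolute step count; it wins even if the epoch bookkeeping
    # would otherwise let training continue past max_duration.
    if stop_at is not None and global_step >= stop_at:
        return True
    return completed_epochs >= max_epochs
```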
# Doesn't work from cache.
# - label: basic_arithmetic
#   type: downstream
Even with cache v4?
  --run_name=$TASK_NAME \
  --wandb.name=$TASK_NAME \
  --wandb.group=$TASK_NAME \
- --wandb.project=tiny_olmo \
+ --wandb.project=olmo-tiny \
  --max_grad_norm=2.0 \
Do you want to use this clipping value for all the small models?
ah, let me fix this; the model with clipping value 2.0 does not show any downstream improvement!
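For reference, what --max_grad_norm controls in plain PyTorch terms (the trainer has its own clipping path, e.g. through FSDP; this is just the standard single-device equivalent):

```python
import torch

def clip_and_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                  max_grad_norm: float) -> torch.Tensor:
    # Rescale all gradients so their combined L2 norm is at most max_grad_norm,
    # then take the optimizer step. The returned pre-clip norm is worth logging:
    # frequent spikes above the threshold are a hint the clip value matters.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return total_norm
```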
Co-authored-by: Pete <epwalsh10@gmail.com>
olmo/train.py (Outdated)
  num_fwd_flops=self.model.num_fwd_flops,  # this is per sequence
  num_bck_flops=self.model.num_bck_flops,  # this is per sequence
"this is per sequence" ... it's per-token now, right?
changed
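A back-of-the-envelope version of the per-token vs. per-sequence distinction (assumed approximations, not the model's num_fwd_flops implementation): a common rule of thumb is roughly 2N FLOPs per token for the forward pass and about twice that for the backward pass, where N is the parameter count.

```python
def flops_per_token(n_params: int) -> tuple[int, int]:
    fwd = 2 * n_params  # ~2 FLOPs per parameter per token (matmul rule of thumb)
    bck = 2 * fwd       # backward is roughly twice the forward cost
    return fwd, bck

def flops_per_sequence(n_params: int, seq_len: int) -> tuple[int, int]:
    # The per-sequence numbers are just the per-token numbers times the sequence length.
    fwd, bck = flops_per_token(n_params)
    return fwd * seq_len, bck * seq_len
```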