
OOM On Galore Axolotl #1448

Open
m626zNq opened this issue Mar 27, 2024 · 12 comments
Labels
bug Something isn't working

Comments

@m626zNq

m626zNq commented Mar 27, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Training should start without OOM, like it does in LLaMA Factory.

Current behaviour

My config causes an OOM on axolotl. LLaMA Factory behaved fine, but axolotl is hating on me. On LLaMA Factory I was able to train in 16-bit with rank 1024 and 8k context on the same GPU, and it worked fine. Axolotl won't even run with 8-bit, rank 128, and 4k context (out of memory).

I have tried:

  • Enabling gradient checkpointing
  • Disabling sample packing
  • Lowering rank
  • Enabling use_reentrant

Steps to reproduce

  • Install GaLore: pip install galore-torch
  • Run the config posted below
  • Make sure the GPU is an A6000 (48 GB)

Config yaml

base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
  - path: Walmart-the-bag/alpaca-ingen
    type:
      field_instruction: instruction
      field_output: output
      format: "\n### Instruction:\n{instruction}\n### Response:\n"
      no_input_format: "\n### Instruction:\n{instruction}\n### Response:\n"
dataset_prepared_path:
val_set_size: 0.05
output_dir: /notebooks/output

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - q_proj
  - v_proj
  - linear
  - mlp
  - attn
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: galore_adafactor
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 10
evals_per_epoch: 0
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.11

axolotl branch-commit

main

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
m626zNq added the bug label Mar 27, 2024
@jaredquekjz

jaredquekjz commented Mar 28, 2024

Yes, same issue here. Even a 7B model takes a LOT of memory, much higher than the <24 GB promised in the original repo. Is the "activation checkpointing" in that repo equivalent to the "gradient checkpointing" in axolotl? My yaml for Yi (also adapted for Mistral-Hermes):

base_model: NousResearch/Nous-Hermes-2-Yi-34B
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: /workspace/axolotl/runpod/psychicOlierYidup.jsonl
    type: completion
dataset_prepared_path: 
val_set_size: 0.0
output_dir: /workspace/axolotl/model

sequence_len: 1200
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project: huggingface
wandb_entity: singaporespprtsschool
wandb_watch:
wandb_run_id: PsychicYiGalore26Mar
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 7
optimizer: galore_adamw_8bit 
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - mlp
  - attn
lr_scheduler: cosine
learning_rate: 1e-4


train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 30
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_steps: 1000
evals_per_epoch: 
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: 
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<|startoftext|>"
  eos_token: "<|im_end|>"
  pad_token: "<unk>"
  unk_token: "<unk>"

@winglian
Collaborator

Did you try the galore 8bit variants?

@jaredquekjz

jaredquekjz commented Mar 28, 2024

I used the 8-bit optimiser, as seen above. Hermes 7B takes a shocking 36 GB or so at seqlen 1200. And in theory Yi is supposed to fit on an H100 with GaLore, but it OOMs. How can the above yaml be optimised further?

@jaredquekjz

jaredquekjz commented Mar 29, 2024

@m626zNq @winglian I think I found the problem. In the axolotl README, I note that there are a number of layerwise optimisers:

# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise

According to a remark I saw in the original HF PR, these layerwise optimisers are essential to achieve the much higher memory savings, but they come with limitations, such as perhaps not working with multiple GPUs (see huggingface/transformers#29588 and the original GaLore GitHub). In any case, with galore_adamw_8bit_layerwise I can train Hermes-M 7B in 20 GB with a batch size of 2800 tokens. So @m626zNq can try that and probably close this "bug".

But I do find that the loss falls much more slowly (if at all) with the layerwise optimiser compared to the normal GaLore one. I guess things are still quite unstable.
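
For anyone who wants to try this, the change to the optimizer section of the config above would look roughly like the following. This is only a sketch: the optimizer name comes from the README list quoted above, and the GaLore hyperparameters are simply carried over from the yamls in this thread.

optimizer: galore_adamw_8bit_layerwise
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - mlp
  - attn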

@winglian
Collaborator

@jaredquekjz the 24 GB from the paper is for a 7B parameter model. You're using Yi-34B. That's still going to require much more VRAM, probably at least an 80 GB A100, to do a full finetune with GaLore.

@winglian
Collaborator

@m626zNq set flash_attention: true; without flash attention it's going to OOM almost every time, no matter what.
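
That is, in the yaml at the top of this issue (this assumes the flash-attn package is installed so the setting can take effect):

flash_attention: true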

@jaredquekjz

jaredquekjz commented Mar 30, 2024

@winglian thanks for the input. I trialled both Yi and Hermes-Mistral 7B. For both, we need the layerwise optimiser for the memory savings to reach their maximum (20 GB for Hermes, as reported). But as shared, the layerwise optimiser may not be working fully stably yet (no grad norm and no loss decrease), at least when I last trialled it. Yi can load on one H100 with layerwise, but it's very slow:

{'loss': 3.0955, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.099, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0705, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0608, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.1196, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}
{'loss': 3.0307, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.0}

@winglian
Collaborator

Yeah. Layerwise also requires that you use a gradient accumulation steps value of 1.
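
In config terms:

gradient_accumulation_steps: 1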

@m626zNq
Author

m626zNq commented Apr 2, 2024

@winglian I have tried all of those and still get OOM. I have tried every GaLore optimizer, flash attention, deepspeed, etc. Sorry for the late response.

@m626zNq
Author

m626zNq commented Apr 6, 2024

Quoting @jaredquekjz's layerwise-optimiser suggestion above:

Still OOM. No idea what is going wrong.

@nelaturuharsha

nelaturuharsha commented May 10, 2024

Just a thought here (could be wrong), based on a similar discussion I had: as far as I understand, GaLore is run completely in bfloat16 precision, without any automatic mixed precision. My sense is that with accelerate under the hood, AMP is being used, which obviously requires more memory (the SVD is done in float32, IIRC) -- not exactly sure, though. Reference here
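
For illustration, below is a minimal sketch of driving GaLore directly through the galore-torch package with the model held purely in bfloat16 and no autocast. The parameter-group API follows the GaLore README; the name-based selection of attention/MLP matrices and the reuse of the rank/update_proj_gap/scale values from the configs in this thread are my own assumptions, and this is not a claim about what axolotl/accelerate does internally.

import torch
from transformers import AutoModelForCausalLM
from galore_torch import GaLoreAdamW8bit

# Load the model directly in bfloat16 -- no AMP/autocast wrapper involved.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16
).cuda()

# Apply the low-rank projection only to the 2D attention/MLP weight matrices;
# all remaining parameters are optimised normally.
galore_params = [
    p for n, p in model.named_parameters()
    if p.ndim == 2 and ("attn" in n or "mlp" in n)
]
galore_ids = {id(p) for p in galore_params}
regular_params = [p for p in model.parameters() if id(p) not in galore_ids]

optimizer = GaLoreAdamW8bit(
    [
        {"params": regular_params},
        {"params": galore_params, "rank": 128,
         "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
    ],
    lr=2e-4,
)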

@e-p-armstrong

e-p-armstrong commented May 24, 2024

I am encountering similar issues: way too much VRAM is being used for GaLore tuning of Llama 8B, in my case 280 GB on 8x A6000s! Something definitely seems wrong here. If the paper gives 24 GB for a 7B model, presumably it should not take 280 GB for an 8B, even with a larger tokenizer?
