
Config for training Mamba breaks #975

Closed · itsanderz opened this issue Dec 18, 2023 · 2 comments · Fixed by #1019
Labels
bug Something isn't working

Comments

@itsanderz
itsanderz commented Dec 18, 2023

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Completion of training

Current behaviour

[2023-12-17 21:35:56,680] [ERROR] [axolotl.load_model:461] [PID:676] [RANK:0] No module named 'mamba_ssm'
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
File "/workspace/axolotl/src/axolotl/models/mamba/__init__.py", line 7, in fix_mamba_attn_for_loss
from mamba_ssm.models import mixer_seq_simple
ModuleNotFoundError: No module named 'mamba_ssm'
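
The missing-module error above means the mamba_ssm package is not present in the environment. A sketch of a setup that pulls it in (assuming axolotl defines a mamba-ssm extra, as the maintainer's reply in this thread indicates; exact package versions are not pinned here):

```shell
# Assumed fix for: ModuleNotFoundError: No module named 'mamba_ssm'.
# Install axolotl with the mamba-ssm extra alongside flash-attn,
# instead of the flash-attn extra alone:
pip3 install -e '.[mamba-ssm,flash-attn]'

# Alternatively, install the packages directly (compatibility between
# these two distributions is an assumption, not verified here):
pip3 install mamba-ssm causal-conv1d
```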

Trace after installing mamba-ssm 1.0.1

config.json: 100%|_________________________________________________________________________________________________| 200/200 [00:00<00:00, 841kB/s]
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py", line 149, in forward
out = mamba_inner_fn(
return fwd(*args, **kwargs)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 181, in forward
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

Invoked with: tensor([[[ 0.1060, 0.2676, -0.2178, ..., -0.6016, -0.5977, -0.4922],
[-0.8672, 0.0108, -0.1123, ..., 0.5312, -0.9609, -3.0469],
[-0.1621, -0.5586, -0.1602, ..., -1.1797, 0.7812, -0.9062],
...,
[ 1.5938, 0.5312, 0.5742, ..., -0.6094, -0.2490, -0.8867],
[-0.4668, 1.3203, 1.5234, ..., 1.3047, 0.9727, 0.9375],
[ 0.8711, 0.6719, 0.6562, ..., -0.7070, 1.4453, -3.0625]]],
device='cuda:0', dtype=torch.bfloat16, requires_grad=True), tensor([[ 0.0151, 0.0280, 0.1030, 0.0012],
[ 0.0133, 0.0221, 0.1216, -0.0250],
[-0.0023, -0.0020, -0.0801, 0.0032],
...,
[-0.0023, -0.0046, -0.0703, 0.0045],
[ 0.0050, -0.0001, -0.1050, -0.0410],
[ 0.0337, 0.0698, 0.0483, 0.0001]], device='cuda:0',
dtype=torch.bfloat16, requires_grad=True), Parameter containing:
tensor([ 0.0664, 0.0192, 0.0220, ..., -0.0058, 0.2334, -0.0757],
device='cuda:0', dtype=torch.bfloat16, requires_grad=True), True
0%| | 0/256766 [00:04<?, ?it/s]
Traceback (most recent call last):
File "/root/miniconda3/envs/py3.10/bin/accelerate", line 8, in
sys.exit(main())
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
args.func(args)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/launch.py", line 994, in launch_command
simple_launcher(args)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/commands/launch.py", line 636, in simple_launcher
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/root/miniconda3/envs/py3.10/bin/python3', 'scripts/finetune.py', 'SlimOrcaMamba.yaml']' returned non-zero exit status 1.
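
The TypeError above is an arity mismatch between two compiled extensions: the installed causal-conv1d build exposes causal_conv1d_fwd with five parameters (per the supported-signature line in the error), while this call site passes only four. A minimal pure-Python sketch of the same failure mode; the names mirror the trace, but the bodies are placeholders, not the real CUDA kernels, and the version-mismatch explanation is an assumption:

```python
# Stand-in for the newer causal-conv1d binding: five parameters,
# matching the supported signature shown in the TypeError message.
def causal_conv1d_fwd(x, weight, bias, seq_idx, activation):
    return "ok"


def old_call_site():
    # mamba-ssm 1.0.1-style call: four positional arguments
    # (x, weight, bias, True), which no longer matches the binding.
    x, weight, bias = "x", "w", "b"
    return causal_conv1d_fwd(x, weight, bias, True)  # seq_idx is missing


try:
    old_call_site()
except TypeError as exc:
    # Analogous to the pybind "incompatible function arguments" error.
    print(f"TypeError: {exc}")
```

In the real packages this arises when the installed causal-conv1d is newer (or older) than what the installed mamba-ssm expects, so pinning compatible versions of the two is the usual remedy.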

Steps to reproduce

I am using a RunPod instance.

git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl

pip3 install -e .[flash-attn]
pip3 install -U git+https://github.com/huggingface/peft.git

wget [gist]

accelerate launch scripts/finetune.py [gist]

Config yaml

base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca

dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5

train_on_inputs: false
group_by_length: true

bf16: true
fp16: false
tf32: true

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

main/d25c34c

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
@itsanderz itsanderz added the bug Something isn't working label Dec 18, 2023
@winglian
Collaborator

Install axolotl with pip install -e .[mamba-ssm,flash-attn]

@ncoop57
Contributor

ncoop57 commented Jan 25, 2024

I'm still getting this error even after installing the mamba_ssm package.
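
When the error persists after installing mamba_ssm, a likely culprit (an assumption, not confirmed in this thread) is a version mismatch between the mamba-ssm and causal-conv1d distributions. A small helper to report what is actually installed before debugging further:

```python
# Report installed versions of the two packages whose compiled
# interfaces must match; "not installed" means the import-time or
# call-time errors above are expected. Names are PyPI distribution
# names, not import names.
from importlib.metadata import PackageNotFoundError, version


def report(dist: str) -> str:
    try:
        return f"{dist}=={version(dist)}"
    except PackageNotFoundError:
        return f"{dist}: not installed"


for dist in ("mamba-ssm", "causal-conv1d"):
    print(report(dist))
```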
