Skip to content

Commit

Permalink
fix: warn user to install mamba_ssm package (#1019)
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed Jan 10, 2024
1 parent 9e3f0cb commit d69ba2b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[deepspeed,flash-attn]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
fi

# So we can test the Docker image
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging
packaging==23.2
peft==0.7.0
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
tokenizers==0.15.0
Expand Down Expand Up @@ -34,6 +34,8 @@ fschat==0.2.34
gradio==3.50.2
tensorboard

mamba-ssm==1.1.1

# remote filesystems
s3fs
gcsfs
Expand Down
14 changes: 7 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ def parse_requirements():
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif (
"flash-attn" not in line
and "flash-attention" not in line
and "deepspeed" not in line
and line
and line[0] != "#"
):
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)

Expand Down
12 changes: 12 additions & 0 deletions src/axolotl/models/mamba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,20 @@
Modeling module for Mamba models
"""

import importlib


def check_mamba_ssm_installed():
mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
if mamba_ssm_spec is None:
raise ImportError(
"MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
)


def fix_mamba_attn_for_loss():
check_mamba_ssm_installed()

from mamba_ssm.models import mixer_seq_simple

from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
Expand Down

0 comments on commit d69ba2b

Please sign in to comment.