2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
@@ -33,7 +33,7 @@ jobs:
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]"
- name: Build the documentation
run: mkdocs build

11 changes: 9 additions & 2 deletions Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1.7-labs
FROM nvcr.io/nvidia/pytorch:24.11-py3
FROM nvcr.io/nvidia/pytorch:25.05-py3

# Install dependencies.
RUN apt-get update \
@@ -24,13 +24,20 @@ RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/to
/usr/local/lib/python3.12/dist-packages \
/usr/local/lib/python3.12/dist-packages/__pycache__

# The base image enforces versions for things like pytest for no good reason.
ENV PIP_CONSTRAINT=""
# There is no pre-built mamba wheel for pytorch 2.8, so we build it before the rest to avoid rebuilds.
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
# We set the number of workers to avoid OOM when compiling on a laptop. (TODO: Can we make it configurable?)
RUN MAX_JOBS=4 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.0.post8"
RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@v2.2.4"
# Copy dependency files with universal write permissions for all users.
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,DEV]"
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]"

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM
25 changes: 12 additions & 13 deletions fast_llm/functional/triton/mlp.py
@@ -25,9 +25,6 @@
from fast_llm.functional.triton.sparse_linear import output_sparse_matmul
from fast_llm.tensor import param_get_and_unset_is_zero

# Triton requires global variables to be annotated with `constexpr`.
_TritonActivationType: tl_constexpr = ActivationType


@triton_jit()
def triton_mlp_activation_forward_kernel(
@@ -50,18 +47,19 @@ def triton_mlp_activation_forward_kernel(

input_ = tl.load(input_ptr, mask=mask).to(tl.float32)

if activation_type == _TritonActivationType.gelu:
# Triton doesn't like enums, so we use str instead of ActivationType.
if activation_type == "gelu":
tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
out = input_ * 0.5 * (1.0 + tanh)
elif activation_type == _TritonActivationType.silu:
elif activation_type == "silu":
out = input_ / (1 + tl.exp(-input_))
elif activation_type == _TritonActivationType.relu:
elif activation_type == "relu":
out = tl.where(input_ > 0, input_, 0)
elif activation_type == _TritonActivationType.squared_relu:
elif activation_type == "squared_relu":
relu_out = tl.where(input_ > 0, input_, 0)
out = relu_out * relu_out
elif activation_type == _TritonActivationType.identity:
elif activation_type == "identity":
out = input_
else:
tl.static_assert(False, activation_type)
@@ -100,28 +98,29 @@ def triton_mlp_activation_backward_kernel(
input_ = tl.load(input_ptr, mask=mask).to(tl.float32)
output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32)

if activation_type == _TritonActivationType.gelu:
# Triton doesn't like enums, so we use str instead of ActivationType.
if activation_type == "gelu":
tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh)
if gated or recompute:
out = input_ * 0.5 * (1.0 + tanh)
elif activation_type == _TritonActivationType.silu:
elif activation_type == "silu":
exp = tl.exp(-input_)
sigma = 1 / (1 + exp)
grad = sigma * sigma + (1 + input_) / (2 + exp + 1 / exp)
if gated or recompute:
out = input_ * sigma
elif activation_type == _TritonActivationType.relu:
elif activation_type == "relu":
grad = tl.where(input_ > 0, 1, 0)
if gated or recompute:
out = tl.where(input_ > 0, input_, 0)
elif activation_type == _TritonActivationType.squared_relu:
elif activation_type == "squared_relu":
relu_out = tl.where(input_ > 0, input_, 0)
grad = 2 * relu_out
if gated or recompute:
out = relu_out * relu_out
elif activation_type == _TritonActivationType.identity:
elif activation_type == "identity":
grad = 1
if gated or recompute:
out = input_
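The kernels above now branch on plain strings because Triton can specialize a kernel on `tl.constexpr` string values and fold the dead branches at compile time, while Python enum members do not lower into the kernel. A minimal, self-contained sketch of the same pattern (illustrative only; the kernel and its names are made up here, not the Fast-LLM wrapper, and it assumes the caller passes the enum's string value, e.g. `ActivationType.silu.value`):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def activation_kernel(input_ptr, output_ptr, n_elements, activation: tl.constexpr, block_size: tl.constexpr):
    # `activation` is a compile-time string; each distinct value compiles a
    # specialized kernel and the other branches are eliminated.
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
    mask = offsets < n_elements
    x = tl.load(input_ptr + offsets, mask=mask).to(tl.float32)
    if activation == "relu":
        out = tl.where(x > 0, x, 0.0)
    elif activation == "silu":
        out = x / (1 + tl.exp(-x))
    else:
        tl.static_assert(False, "unsupported activation")
    tl.store(output_ptr + offsets, out, mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
# Only the plain string reaches the kernel, so an enum caller would pass e.g. `enum_member.value`.
activation_kernel[grid](x, y, x.numel(), activation="silu", block_size=1024)
```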
41 changes: 26 additions & 15 deletions fast_llm/layers/ssm/discrete_mamba2.py
@@ -2,7 +2,6 @@
import math

import einops
import mamba_ssm.ops.triton.ssd_combined
import torch

from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
@@ -13,12 +12,22 @@

logger = logging.getLogger(__name__)


try:
import causal_conv1d
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as _mamba_chunk_scan_combined # noqa

_mamba_available = True
except ImportError:
# this is needed since we cannot use causal_conv1d on B200 GPUs for now
logger.warning("Note, causal_conv1d not found, will use torch.nn.functional.conv1d instead")
causal_conv1d = None
_mamba_available = False


try:
from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa

_causal_conv1d_available = True
except ImportError:
_causal_conv1d_available = False


"""
This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py
@@ -148,6 +157,8 @@ def forward(self, hidden_states, kwargs):
outputs["hidden_states"]: (B, L, D).
outputs["state"]: inference cache.
"""

assert _mamba_available
input_ = hidden_states
outputs = {}
# assert state is None
@@ -201,7 +212,7 @@ def forward(self, hidden_states, kwargs):
C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads)

# SSM forward
result = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined(
result = _mamba_chunk_scan_combined(
x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1),
dt=A_log,
dt_softplus=True,
@@ -234,11 +245,18 @@

def convolutional_forward(self, xBC, padded_len):
"""Convolutional layer forward pass for the full sequence."""
if causal_conv1d is None or self.activation_name not in [
if _causal_conv1d_available and self.activation_name in (
"silu",
"swish",
"identity",
]:
):
xBC = _causal_conv1d_fn(
xBC.transpose(1, 2),
einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
self.conv1d_bias,
activation=None if self.activation_name == "identity" else self.activation_name,
).transpose(1, 2)
else:
xBC = self.act(
torch.nn.functional.conv1d(
xBC.transpose(1, 2),
@@ -248,11 +266,4 @@ def convolutional_forward(self, xBC, padded_len):
padding=self.conv_kernel_size - 1,
)[..., :padded_len].transpose(1, 2)
)
else:
xBC = causal_conv1d.causal_conv1d_fn(
xBC.transpose(1, 2),
einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
self.conv1d_bias,
activation=None if self.activation_name == "identity" else self.activation_name,
).transpose(1, 2)
return xBC
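`convolutional_forward` now prefers the fused `causal_conv1d_fn` kernel when it is available and the activation is one it supports, and otherwise falls back to a plain depthwise `torch.nn.functional.conv1d` padded by `kernel_size - 1` and trimmed back to the sequence length. A small standalone check of why that pad-and-trim is a causal convolution (a sketch with made-up shapes, not Fast-LLM code):

```python
import torch
import torch.nn.functional as F

batch, channels, seqlen, kernel_size = 2, 8, 16, 4
x = torch.randn(batch, channels, seqlen)        # (B, C, L), as after transpose(1, 2)
weight = torch.randn(channels, 1, kernel_size)  # depthwise: one filter per channel
bias = torch.randn(channels)

# conv1d with symmetric padding of (K - 1) produces L + K - 1 outputs; keeping only the
# first L means output position t sees inputs t-K+1 .. t, i.e. a causal convolution.
out = F.conv1d(x, weight, bias, groups=channels, padding=kernel_size - 1)[..., :seqlen]

# Reference with explicit left-only padding gives the same result.
ref = F.conv1d(F.pad(x, (kernel_size - 1, 0)), weight, bias, groups=channels)
assert torch.allclose(out, ref)
```

The fused kernel also applies the activation in the same call, which is why the fallback branch wraps its result in `self.act` instead.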
11 changes: 9 additions & 2 deletions fast_llm/layers/ssm/mamba_layer.py
@@ -2,7 +2,6 @@
from typing import Callable

import einops
import mamba_ssm.ops.selective_scan_interface
import torch

from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
@@ -11,6 +10,13 @@
from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_
from fast_llm.utils import get_lr_scale

try:
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa

_mamba_available = True
except ImportError:
_mamba_available = False

"""
Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba.
For now it only supports training and not inference.
@@ -153,6 +159,7 @@ def __init__(
self._return_input = return_input

def forward(self, hidden_states, kwargs):
assert _mamba_available
batch, seqlen, dim = hidden_states.shape

# We do matmul and transpose BLH -> HBL at the same time
@@ -167,7 +174,7 @@
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# In the backward pass we write dx and dz next to each other to avoid torch.cat
# Note: if we want to support inference, we would need to implement the slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s
out = mamba_ssm.ops.selective_scan_interface.mamba_inner_fn(
out = _mamba_inner_fn(
xz,
self.conv1d_weight,
self.conv1d_bias,
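Both SSM layers now follow the same optional-dependency pattern: attempt the import once at module load, remember the result in a module-level flag, and assert on it only when a layer that needs the kernel actually runs. A condensed sketch of the idea (hypothetical class, not the real `MambaLayer`):

```python
import torch

try:
    # Optional fused CUDA kernel; present only when the SSM extra is installed.
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn  # noqa
    _mamba_available = True
except ImportError:
    _mamba_available = False


class OptionalKernelLayer(torch.nn.Module):
    """Illustrative layer: importable everywhere, usable only with the kernel installed."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Fail at use time, not import time, so configs without SSM blocks still load.
        assert _mamba_available, "mamba_ssm is required for this layer; install the SSM extra"
        # ... call _mamba_inner_fn(...) here, as in MambaLayer.forward above ...
        return hidden_states
```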
52 changes: 28 additions & 24 deletions setup.cfg
@@ -6,54 +6,58 @@ packages = find_namespace:
include_package_data = True
python_requires = >=3.12
install_requires =
requests>=2.32.3
PyYAML>=6.0.1
pybind11>=2.5.0
packaging>=24.1
requests>=2.32.4
PyYAML>=6.0.2
pybind11>=2.13.6
packaging>=25.0

[options.extras_require]
# Required to use the main functionality of Fast-LLM
# To install on cpu environment (ex. for IDE support):
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation
CORE =
# Available through the nvidia base image
torch>=2.5.0
torch>=2.7.0
# Numpy major needs to match torch
numpy>=1.24.4,<2.0.0
numpy>=1.26.4,<2.0.0
# Used for checkpoints
safetensors>=0.4.4
safetensors>=0.5.3
# Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
flash-attn==2.7.2.post1
mamba_ssm==2.2.4
flash-attn==2.7.3


# Required for some optional features and tools.
# Small packages required for some optional features and tools.
OPTIONAL =
# Huggingface tools
transformers>=4.44.2
hf-transfer>=0.1.8
datasets>=3.1.0
huggingface-hub>=0.28.1
# Weights and biases
wandb>=0.17.7
wandb>=0.20.1
# Hydra
hydra-core>=1.3.2
omegaconf>=2.3.0
# Miscellaneous
requests>=2.32.3
tqdm>=4.66.3
# For causal_conv1d
causal_conv1d>=1.4.0
tqdm>=4.67.1

# Huggingface tools
HUGGINGFACE =
transformers>=4.52.4
hf-transfer>=0.1.9
datasets>=3.6.0
huggingface-hub>=0.32.6

# Required to run SSMs
# To install on cpu environment (ex. for IDE support):
# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation
SSM =
mamba_ssm[causal-conv1d]==2.2.4

DEV =
# Pre-commit git hook
pre-commit>=4.0.1
pre-commit>=4.2.0
# Required for testing
pytest>=8.3.2
pytest>=8.4.0
pytest-depends>=1.0.1
pytest-xdist>=3.6.1
pytest-xdist>=3.7.0
# Somehow needed for Megatron to work with base image 24.11
setuptools>=75.6.0
setuptools>=80.9.0

# Required for building the documentation
DOCS =
17 changes: 4 additions & 13 deletions tests/test_ssms.py
@@ -14,24 +14,15 @@
from fast_llm.engine.schedule.schedule import Schedule
from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames
from fast_llm.layers.ssm.config import SSMBlockType
from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
from fast_llm.layers.ssm.llamba_block import LlambaBlock
from fast_llm.layers.ssm.mamba_layer import MambaLayer
from fast_llm.layers.transformer.config import TransformerKwargs
from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat
from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat
from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel
from tests.common import get_hybrid_config, materialize_meta_tensors

try:
from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
from fast_llm.layers.ssm.llamba_block import LlambaBlock
from fast_llm.layers.ssm.mamba_layer import MambaLayer
from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel
except Exception:
MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = (
None,
None,
None,
None,
)

try:
from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel
except ImportError: