In [1]:
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

from torch.cuda.amp import custom_bwd, custom_fwd

In [27]:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn

In [2]:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import causal_conv1d_cuda

In [14]:
import selective_scan_cuda
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

In [15]:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

In [5]:
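# d_model = size of each input sequence element (the model / embedding dimension)
# d_state = SSM state size per channel (A has shape d_inner x d_state)
# d_conv = kernel size (width) of the depthwise causal 1D convolution
# expand = block expansion factor (d_inner = expand * d_model)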

In [6]:
# Toy input for quick experiments: one sequence of 64 elements, each of size 16,
# sampled from a standard normal and moved to the GPU (requires a CUDA device).
# Note: `batch`, `length`, `dim` and `x` are all rebound by later cells in this notebook.
batch, length, dim = 1, 64, 16
x = torch.randn(batch, length, dim).to("cuda")

In [7]:
# Scratch cell: grab the weight Parameter of a kernel-size-1 Conv1d with 10 in/out channels.
# NOTE(review): this module-level `A` is not read by any code visible in this notebook
# (MambaInnerFn and Mamba only use their own local `A`) — looks like leftover exploration.
A = nn.Conv1d(10, 10, 1).weight

In [10]:
class MambaInnerFn(torch.autograd.Function):
  """Fused autograd function for the inner part of a Mamba block.

  Forward pipeline (all on already in-projected activations `xz`):
    split xz into (x, z) -> causal conv1d on x -> x_proj -> dt_proj ->
    selective scan gated by z -> out_proj.

  The heavy lifting is delegated to the `causal_conv1d_cuda` and
  `selective_scan_cuda` extension modules, so this only runs where those
  extensions are importable (they are CUDA kernels).
  """

  @staticmethod
  @custom_fwd
  def forward(
    ctx, xz,
    conv1d_weight, conv1d_bias, x_proj_weight,
    delta_proj_weight, out_proj_weight, out_proj_bias,
    A, B = None, C = None, D = None, delta_bias = None,
    B_proj_bias = None, C_proj_bias = None, 
    delta_softplus = True, checkpoint_lvl = 1
  ):
    """
    Args:
      xz: in-projected input; per the original notes (batch, 2 * d_inner, seqlen).
        It is split in half along dim 1 into x (conv/SSM branch) and z (gate branch).
      conv1d_weight: depthwise conv weight, (d_inner, 1, width) as stored by nn.Conv1d.
      conv1d_bias: optional conv bias.
      x_proj_weight: weight mapping d_inner -> dt_rank + 2 * d_state.
      delta_proj_weight: (d_inner, dt_rank) weight mapping the dt slice to d_inner.
      out_proj_weight, out_proj_bias: final output projection.
      A: state matrix, last dim is d_state (may be complex).
      B, C: if None they are input dependent and sliced out of the x_proj output.
      D: optional skip parameter, passed through to the scan kernel.
      delta_bias: optional bias passed to the scan kernel.
      B_proj_bias, C_proj_bias: optional biases added to the input-dependent B / C.
      delta_softplus: forwarded to the scan kernel.
      checkpoint_lvl: 0 keeps conv1d_out and delta for backward; 1 drops them and
        recomputes both in backward.

    Returns:
      Tensor of shape (batch, seqlen, out_features) produced by the output projection.
    """
    assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d"
    assert checkpoint_lvl in [0, 1]
    L = xz.shape[-1]
    # delta_proj_weight is d_inner x dt_rank
    delta_rank = delta_proj_weight.shape[1]

    # A is d_inner x d_state; a complex state counts as two real numbers
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)

    # custom_fwd does not cast the weights for us here, so under autocast they are
    # converted explicitly to the autocast dtype.
    if torch.is_autocast_enabled():
      x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
      delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
      out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
      out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None)

    # xz can arrive as a transposed / strided view (stride(-1) != 1); make the
    # last dimension dense before handing it to the kernels.
    if xz.stride(-1) != 1:
      xz = xz.contiguous()

    # d = d_inner and w = convolution filter size.
    # The middle dimension is 1 because the conv is depthwise
    # (groups = in_channels = out_channels).
    conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")

    # Split along the channel dimension (dim 1): first half is x, second half is z.
    # Possible because in_proj maps d_model to d_inner * 2.
    x, z = xz.chunk(2, dim = 1)

    if conv1d_bias is not None:
      conv1d_bias = conv1d_bias.contiguous()

    # Arguments after the bias are (None, None, None, True). Per the original notes
    # these are presumably seq_idx, initial_states, final_states_out and an
    # "apply activation" flag — TODO confirm against the causal-conv1d extension.
    # conv1d_out is laid out as (batch, d_inner, seqlen) — it is rearranged from
    # "b d l" right below.
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
      x, conv1d_weight, conv1d_bias, None, None, None, True
    )

    # Collapse batch and sequence into one leading dimension and apply x_proj.
    # Result has shape (batch * seqlen, dt_rank + 2 * d_state).
    x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)

    # Take the first dt_rank columns of x_dbl and project them to d_inner,
    # then un-flatten (batch * seqlen). Result has layout "b d l",
    # i.e. (batch, d_inner, seqlen), matching conv1d_out.
    delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)

    ctx.is_variable_B = B is None
    ctx.is_variable_C = C is None
    ctx.B_proj_bias_is_None = B_proj_bias is None
    ctx.C_proj_bias_is_None = C_proj_bias is None

    # If B wasn't passed to forward it is input dependent: slice it from x_dbl.
    if B is None:
      B = x_dbl[:, delta_rank:delta_rank + d_state]
      if B_proj_bias is not None:
        B = B + B_proj_bias.to(dtype = B.dtype)

      # If A is complex then B is complex too; a complex number is stored as two
      # reals, so the dstate columns are regrouped as (dstate, two).
      if not A.is_complex():
        # The singleton axis is an extra size-1 dimension inserted before dstate;
        # presumably a group dimension expected by the scan kernel — TODO confirm.
        B = rearrange(B, "(b l) dstate -> b 1 dstate l", l = L).contiguous()
      else:
        B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)",
                      l = L, two = 2).contiguous()
    else:
      if B.stride(-1) != 1:
        B = B.contiguous()

    # Same treatment for C, which lives in the last d_state columns of x_dbl.
    if C is None:
      C = x_dbl[:, -d_state:]
      if C_proj_bias is not None:
        C = C + C_proj_bias.to(dtype = C.dtype)
      if not A.is_complex():
        C = rearrange(C, "(b l) dstate -> b 1 dstate l", l = L).contiguous()
      else:
        C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l = L, two = 2).contiguous()
    else:
      if C.stride(-1) != 1:
        C = C.contiguous()

    if D is not None:
      D = D.contiguous()

    # mamba_block = linear(ssm(sigma(conv(first_half))) * sigma(second_half))
    # `out` is saved for backward; `out_z` is what gets fed to out_proj below
    # (presumably the scan output after gating with z — TODO confirm in the kernel).
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
      conv1d_out,
      delta,
      A, B, C, D,
      z,
      delta_bias,
      delta_softplus
    )

    ctx.delta_softplus = delta_softplus
    ctx.out_proj_bias_is_None = out_proj_bias is None
    ctx.checkpoint_lvl = checkpoint_lvl

    # With checkpointing, conv1d_out and delta are not stored; backward recomputes them.
    if checkpoint_lvl >= 1:
      conv1d_out, delta = None, None

    ctx.save_for_backward(
      xz,
      conv1d_weight,
      conv1d_bias,
      x_dbl,
      x_proj_weight,
      delta_proj_weight,
      out_proj_weight,
      conv1d_out,
      delta,
      A, B, C, D,
      delta_bias,
      scan_intermediates,
      out
    )

    return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

  @staticmethod
  @custom_bwd
  def backward(ctx, dout):
    """Compute gradients for every argument of `forward`.

    Args:
      dout: gradient of the loss w.r.t. the forward output, (batch, seqlen, dim).

    Returns:
      One entry per forward argument (16 in total); non-differentiable
      arguments (`delta_softplus`, `checkpoint_lvl`) and absent optional
      tensors get None.
    """
    assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
    (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
      conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    if dout.stride(-1) != 1:
      dout = dout.contiguous()
    if ctx.checkpoint_lvl == 1:
      # conv1d_out and delta were dropped in forward; rebuild them exactly as forward did.
      conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
        x, conv1d_weight, conv1d_bias, None, None, None, True
      )
      delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                        "d (b l) -> b d l", l = L)
    # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
    # backward of selective_scan_cuda with the backward of chunk).
    dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
    # dx and dz are views into dxz, so writing them fills the gradient of xz in place.
    dx, dz = dxz.chunk(2, dim=1)
    # From here on dout is 2-D: (out_features, batch * seqlen).
    dout = rearrange(dout, "b l e -> e (b l)")
    dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
    dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
      conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
      ctx.delta_softplus,
      True  # option to recompute out_z
    )
    dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
    # dout is (e, batch * seqlen) here, so the bias gradient sums over dim 1 only.
    # Summing over both dims would collapse it to a scalar that does not match
    # the (e,) shape of out_proj_bias.
    dout_proj_bias = dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
    dD = dD if D is not None else None
    # Gradient w.r.t. the x_proj output; filled slice by slice (dt | B | C) below.
    dx_dbl = torch.empty_like(x_dbl)
    dB_proj_bias = None
    if ctx.is_variable_B:
      if not A.is_complex():
        dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
      else:
        dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
      dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
      dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
      # B was not an input of forward, so there is no gradient to return for it.
      dB = None
    dC_proj_bias = None
    if ctx.is_variable_C:
      if not A.is_complex():
        dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
      else:
        dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
      dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
      dx_dbl[:, -d_state:] = dC  # (bl d)
      dC = None
    ddelta = rearrange(ddelta, "b d l -> d (b l)")
    ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
    dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
    dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
    dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
    # Accumulate the x_proj contribution into dconv1d_out in place.
    dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
    dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
    # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
    # backward of conv1d with the backward of chunk).
    dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
      x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
    )
    dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
    # Restore the (d, 1, w) layout of the nn.Conv1d weight.
    dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
    # One gradient per forward argument. The two trailing Nones are for
    # delta_softplus and checkpoint_lvl; autograd tolerates surplus trailing Nones
    # when fewer positional arguments were given to apply().
    return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
            dout_proj_weight, dout_proj_bias,
            dA, dB, dC, dD,
            ddelta_bias if delta_bias is not None else None,
            dB_proj_bias, dC_proj_bias, None, None)


In [11]:
def mamba_inner_fn(
    xz,
    conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B = None, C = None, D = None,
    delta_bias = None, 
    B_proj_bias = None,
    C_proj_bias = None,
    delta_softplus = True
):
  """Functional wrapper around `MambaInnerFn.apply`.

  Accepts keyword arguments and forwards everything positionally, in the order
  `MambaInnerFn.forward` declares its parameters. `checkpoint_lvl` is not
  forwarded, so `forward` uses its own default for it.

  Returns:
    Whatever `MambaInnerFn.apply` returns for these arguments.
  """
  # autograd.Function.apply only takes positional arguments, so delta_softplus
  # must be passed positionally — and it must be the caller's value, not a
  # hard-coded True.
  return MambaInnerFn.apply(
    xz,
    conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B, C, D,
    delta_bias, B_proj_bias,
    C_proj_bias, delta_softplus
  )

In [17]:
class Mamba(nn.Module):
  """Mamba mixer layer: in_proj -> (causal conv -> selective SSM) gated by z -> out_proj.

  Input and output of `forward` are (batch, seq_len, d_model).

  Args:
    d_model: size of each sequence element.
    d_state: SSM state size per channel (A_log is d_inner x d_state).
    d_conv: kernel size of the depthwise causal convolution.
    expand: d_inner = int(expand * d_model).
    dt_rank: rank of the dt projection; "auto" means ceil(d_model / 16).
    dt_min, dt_max: range in which the initial time steps are sampled (log-uniformly).
    dt_init: "random" (uniform weight init) or "constant" for dt_proj.weight.
    dt_scale: multiplier on the dt_proj weight init scale.
    dt_init_floor: lower clamp applied to the sampled initial time steps.
    conv_bias: whether the convolution has a bias.
    bias: whether in_proj / out_proj have a bias.
    use_fast_path: use the fused `mamba_inner_fn` when no inference cache is involved.
    layer_idx: key of this layer inside `inference_params.key_value_memory_dict`;
      required for cached inference.
    device, dtype: factory kwargs for the created parameters.
  """

  def __init__(
      self,
      # input size will be (batch, seq_len, d_model)
      d_model,
      d_state = 16,
      d_conv = 4,
      expand = 2,
      dt_rank = "auto",
      dt_min = 0.001,
      dt_max = 0.1,
      dt_init = "random",
      dt_scale = 1.0,
      dt_init_floor = 1e-4,
      conv_bias = True,
      bias = False,
      use_fast_path = True,
      layer_idx = None,
      device = None,
      dtype = None
  ):
    factory_kwargs = {
      "device": device,
      "dtype": dtype,
    }
    super().__init__()
    self.d_model = d_model
    self.d_state = d_state
    self.d_conv = d_conv

    # the per-element size is multiplied by `expand` inside the block
    self.expand = expand
    self.d_inner = int(self.expand * self.d_model)

    # "auto" -> ceil(d_model / 16); otherwise the supplied value is used as is
    self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

    self.use_fast_path = use_fast_path
    self.layer_idx = layer_idx

    # maps d_model -> 2 * d_inner; the result is later split into x and z
    self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias = bias, **factory_kwargs)

    # Depthwise conv (groups == channels) over the sequence axis. Conv1d expects
    # channels first, which is why forward() lays activations out as (b, d, l).
    self.conv1d = nn.Conv1d(
      in_channels = self.d_inner,
      out_channels = self.d_inner,
      bias = conv_bias,
      kernel_size = d_conv,
      groups = self.d_inner,
      padding = d_conv -1,
      **factory_kwargs
    )

    self.activation = "silu"
    self.act = nn.SiLU()

    # Maps the conv output to dt_rank + 2 * d_state features, which are split
    # into (dt, B, C) — see step() and the slow path of forward().
    self.x_proj = nn.Linear(
      self.d_inner, self.dt_rank + self.d_state * 2, bias = False,
      **factory_kwargs
    )

    # Maps the dt slice of the x_proj output (size dt_rank) up to d_inner.
    self.dt_proj = nn.Linear(
      self.dt_rank,
      self.d_inner,
      bias = True,
      **factory_kwargs
    )

    # Scale the dt projection init by 1 / sqrt(dt_rank): a dot product of two
    # zero-mean, unit-variance vectors of size d has variance d (std sqrt(d)),
    # so this keeps the output variance independent of dt_rank.
    dt_init_std = self.dt_rank ** -0.5 * dt_scale

    if dt_init == "constant":
      nn.init.constant_(self.dt_proj.weight, dt_init_std)
    elif dt_init == "random":
      nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
    else:
      raise NotImplementedError

    # Choose the bias so that softplus(dt_proj.bias) lies in [dt_min, dt_max]:
    # torch.rand is uniform on [0, 1], so the exponent is uniform on
    # [log(dt_min), log(dt_max)], i.e. dt is log-uniform on [dt_min, dt_max]
    # (then clamped from below at dt_init_floor).
    dt = torch.exp(
      torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
    ).clamp(min = dt_init_floor)

    # Inverse of softplus: y = log(1 + exp(x))  =>  x = log(exp(y) - 1)
    #   = y + log(1 - exp(-y)) = y + log(-expm1(-y))
    inv_dt = dt + torch.log(-torch.expm1(-dt))

    # bias initialized to inv_dt so that softplus(bias) == dt
    with torch.no_grad():
      self.dt_proj.bias.copy_(inv_dt)

    # Marker attribute; presumably read by an external weight-init routine so it
    # does not overwrite this bias — nothing in this notebook reads it.
    self.dt_proj.bias._no_reinit = True

    # S4D-Real initialization (https://arxiv.org/pdf/2206.11893.pdf):
    # each row of A is [1, ..., d_state]; the parameter stores log(A).
    A = repeat(
      torch.arange(1, self.d_state + 1, dtype = torch.float32,
      device = device),
      "n -> d n",
      d = self.d_inner
    ).contiguous()

    A_log = torch.log(A)
    self.A_log = nn.Parameter(A_log)
    # Marker attribute; presumably consumed by optimizer setup code elsewhere.
    self.A_log._no_weight_decay = True

    self.D = nn.Parameter(torch.ones(self.d_inner, device = device))

    self.D._no_weight_decay = True

    self.out_proj = nn.Linear(self.d_inner, self.d_model, bias = bias, **factory_kwargs)

  def _get_states_from_cache(self, inference_params, batch_size, initialize_states = False):
    """Fetch (or lazily create) this layer's (conv_state, ssm_state) cache entry.

    Args:
      inference_params: object exposing a `key_value_memory_dict` mapping
        layer_idx -> (conv_state, ssm_state).
      batch_size: batch size used when the states have to be allocated.
      initialize_states: if True and the entry already exists, zero it in place.

    Returns:
      (conv_state, ssm_state) with shapes (batch_size, d_inner, d_conv) and
      (batch_size, d_inner, d_state).
    """
    assert self.layer_idx is not None
    if self.layer_idx not in inference_params.key_value_memory_dict:
      conv_state = torch.zeros(
        batch_size,
        self.d_inner,
        self.d_conv,
        device = self.conv1d.weight.device,
        dtype = self.conv1d.weight.dtype
      )
      ssm_state = torch.zeros(
        batch_size,
        self.d_inner,
        self.d_state,
        device = self.dt_proj.weight.device,
        dtype = self.dt_proj.weight.dtype,
      )
      inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
    else:
      conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]

      if initialize_states:
        conv_state.zero_()
        ssm_state.zero_()
    
    return conv_state, ssm_state

  # Backward-compatible alias for the name this method was originally defined under.
  _get_state_from_cache = _get_states_from_cache

  def step(self, hidden_states, conv_state, ssm_state):
    """Decode a single token using (and updating) the cached states.

    Args:
      hidden_states: (batch, 1, d_model).
      conv_state, ssm_state: cached states handed to the update kernels
        (presumably updated in place by them — TODO confirm in the kernels).

    Returns:
      (out, conv_state, ssm_state) where out is (batch, 1, d_model).
    """
    assert hidden_states.shape[1] == 1, "only support decoding with 1 token at a time for now"
    xz = self.in_proj(hidden_states.squeeze(1))
    x, z = xz.chunk(2, dim = -1)

    x = causal_conv1d_update(
      x,
      conv_state,
      rearrange(self.conv1d.weight, "d 1 w -> d w"),
      self.conv1d.bias,
      self.activation
    )

    x_db = self.x_proj(x)
    dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim = -1)

    # The dt bias is deliberately not added here; it is passed to the kernel as dt_bias.
    dt = F.linear(dt, self.dt_proj.weight)
    A = -torch.exp(self.A_log.float())

    y = selective_state_update(
      ssm_state, x, dt, A, B, C, self.D, z = z, dt_bias = self.dt_proj.bias, dt_softplus = True
    )

    out = self.out_proj(y)
    return out.unsqueeze(1), conv_state, ssm_state


  def forward(self, hidden_states, inference_params = None):
    """Apply the Mamba mixer.

    Args:
      hidden_states: (B, L, D) input.
      inference_params: optional cache holder; must expose
        `key_value_memory_dict` and `seqlen_offset`. When `seqlen_offset > 0`
        a single-token `step` is performed instead of a full pass.

    Returns:
      Tensor of the same shape as `hidden_states`.
    """
    batch, seqlen, dim = hidden_states.shape

    conv_state, ssm_state = None, None
    if inference_params is not None:
      conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
      if inference_params.seqlen_offset > 0:
        # States already hold the prefix: decode one token and return.
        out, _, _ = self.step(hidden_states, conv_state, ssm_state)
        return out

    # in_proj applied by hand so the result comes out channels-first:
    # [2 * d_inner, d_model] @ [d_model, batch * seqlen] -> [2 * d_inner, batch * seqlen],
    # then un-flattened to (batch, 2 * d_inner, seqlen).
    xz = rearrange(
      self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
      "d (b l) -> b d l",
      l = seqlen
    )

    if self.in_proj.bias is not None:
      xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

    # (d_inner, d_state); at initialization A[i][j] = -(j + 1) (S4D-Real)
    A = -torch.exp(self.A_log.float())

    if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:
      # Fused path: conv, projections, scan and out_proj in one autograd function.
      out = mamba_inner_fn(
        xz,
        self.conv1d.weight,
        self.conv1d.bias,
        self.x_proj.weight,
        self.dt_proj.weight,
        self.out_proj.weight,
        self.out_proj.bias,
        A,
        None,
        None,
        self.D.float(),
        delta_bias = self.dt_proj.bias.float(),
        delta_softplus = True
      )
    else:
      x, z = xz.chunk(2, dim = 1)
      if conv_state is not None:
        # A negative left pad truncates and a positive one zero-pads, so this keeps
        # exactly the last d_conv positions even when seqlen < d_conv.
        conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))

      if causal_conv1d_fn is None:
        # Plain PyTorch fallback: the conv pads d_conv - 1 on both sides, so
        # keeping the first seqlen outputs makes it causal.
        x = self.act(self.conv1d(x)[..., :seqlen])
      else:
        x = causal_conv1d_fn(
          x = x,
          weight = rearrange(self.conv1d.weight, "d 1 w -> d w"),
          bias = self.conv1d.bias,
          activation = self.activation
        )

      # Flatten batch and sequence but keep the channel axis, so the linear
      # layer sees rows of size d_inner.
      x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
      dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim = -1)

      # dt bias is not added here; it is passed to the scan as delta_bias.
      dt = self.dt_proj.weight @ dt.t()
      dt = rearrange(dt, "d (b l) -> b d l", l = seqlen)

      B = rearrange(B, "(b l) dstate -> b dstate l", l = seqlen).contiguous()
      C = rearrange(C, "(b l) dstate -> b dstate l", l = seqlen).contiguous()

      y = selective_scan_fn(
        x,
        dt,
        A,
        B,
        C,
        self.D.float(),
        z = z,
        delta_bias = self.dt_proj.bias.float(),
        delta_softplus = True,
        return_last_state = ssm_state is not None
      )

      if ssm_state is not None:
        # With return_last_state the scan returns (output, last_state).
        y, last_state = y
        ssm_state.copy_(last_state)

      y = rearrange(y, "b d l -> b l d")
      out = self.out_proj(y)
    return out

  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
    """Allocate zeroed (conv_state, ssm_state) tensors for cached inference.

    Args:
      batch_size: leading dimension of both states.
      max_seqlen: accepted for interface compatibility; not used here.
      dtype: dtype for both states; defaults to the conv / dt_proj weight dtypes.

    Returns:
      (conv_state, ssm_state) with shapes (batch_size, d_inner, d_conv) and
      (batch_size, d_inner, d_state), on the device of out_proj.weight.
    """
    device = self.out_proj.weight.device
    conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
    conv_state = torch.zeros(
        batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype
    )
    ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
    ssm_state = torch.zeros(
        batch_size, self.d_inner, self.d_state, device=device, dtype=ssm_dtype
    )
    return conv_state, ssm_state

In [23]:
# Demo configuration: 2 sequences of 64 elements, each of size 16.
batch, length, dim = 2, 64, 16
# Build the layer defined above with d_inner = expand * d_model = 32 and move
# its parameters to the GPU (requires a CUDA device).
model = Mamba(
  d_model = dim,
  d_state = 16,
  d_conv = 4,
  expand = 2
).to("cuda")

In [21]:
# Random (batch, length, dim) input on the GPU; unseeded, so values differ between runs.
x = torch.randn(batch, length, dim).to("cuda")

In [24]:
# Forward pass without inference_params, so with the default use_fast_path=True
# this goes through the fused mamba_inner_fn branch.
y = model(x)

In [25]:
# Sanity check: the output shape should equal the input shape (batch, length, dim).
y.shape

torch.Size([2, 64, 16])

In [26]:
class Block(nn.Module):
    """Residual wrapper: Add -> Norm -> Mixer.

    Unlike a regular pre-norm Transformer block (LN -> MHA/MLP -> Add,
    see https://arxiv.org/abs/2002.04745), the residual addition happens
    first, followed by the norm and then the mixer. Both the mixer output
    and the running residual are returned, which lets the add and the norm
    be fused into one kernel. Every block except the first must be given
    the residual from its predecessor.
    """

    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Args:
            dim: feature size passed to both `mixer_cls` and `norm_cls`.
            mixer_cls: callable building the mixer module from `dim`.
            norm_cls: callable building the norm module from `dim`.
            fused_add_norm: use the fused add + norm kernels in `forward`.
            residual_in_fp32: keep the residual stream in float32.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        if not self.fused_add_norm:
            return
        # The fused kernels only cover these two norm types.
        assert RMSNorm is not None, "RMSNorm import fails"
        supported_norms = (nn.LayerNorm, RMSNorm)
        assert isinstance(
            self.norm, supported_norms
        ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Run one Add -> Norm -> Mixer step.

        Args:
            hidden_states: output of the previous block (or the block input).
            residual: running residual stream; None for the first block.
            inference_params: forwarded untouched to the mixer.

        Returns:
            (mixer output, updated residual), where the mixer output equals
            Mixer(Norm(updated residual)).
        """
        if self.fused_add_norm:
            # Pick the fused kernel that matches the norm type.
            if isinstance(self.norm, RMSNorm):
                add_norm = rms_norm_fn
            else:
                add_norm = layer_norm_fn
            hidden_states, residual = add_norm(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        else:
            # Unfused path: add, then normalize in the norm's parameter dtype.
            if residual is None:
                residual = hidden_states
            else:
                residual = hidden_states + residual
            norm_input = residual.to(dtype=self.norm.weight.dtype)
            hidden_states = self.norm(norm_input)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

        mixed = self.mixer(hidden_states, inference_params=inference_params)
        return mixed, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        mixer = self.mixer
        return mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)