In [1]:
import sys

import torch
import numpy as np

In [2]:
from xlstm.xlstm_large.model import xLSTMLargeConfig, xLSTMLarge
from mlstm_kernels.torch import get_available_mlstm_step_kernels, get_available_mlstm_kernels, get_available_mlstm_sequence_kernels

In [3]:
# Show the kernel names reported by `mlstm_kernels`, as one tuple:
# (general kernel list, step kernels, sequence kernels).
# The kernel options in the model config are picked from these names.
(
    get_available_mlstm_kernels(),
    get_available_mlstm_step_kernels(),
    get_available_mlstm_sequence_kernels(),
)

(['chunkwise--native_autograd',
  'chunkwise--native_custbw',
  'chunkwise--triton_limit_chunk',
  'chunkwise--triton_xl_chunk',
  'chunkwise--triton_xl_chunk_siging',
  'parallel--native_autograd',
  'parallel--native_custbw',
  'parallel--native_stablef_autograd',
  'parallel--native_stablef_custbw',
  'parallel--triton_limit_headdim',
  'parallel--native_siging_autograd',
  'parallel--native_siging_custbw'],
 ['native', 'triton'],
 ['native_sequence__native', 'native_sequence__triton'])

In [4]:
# Small model configuration used throughout this notebook.
# The three kernel names below appear in the kernel lists displayed by the
# previous cell.
xlstm_config = xLSTMLargeConfig(
    embedding_dim=512,
    num_heads=4,
    num_blocks=6,
    vocab_size=2048,  # token ids sampled later in the notebook use this as their upper bound
    return_last_states=True,  # later cells unpack the model call into (output, state)
    mode="inference",
    chunkwise_kernel="chunkwise--triton_xl_chunk", # xl_chunk == TFLA kernels
    sequence_kernel="native_sequence__triton",
    step_kernel="triton",
)

In [5]:
# Build the model from the configuration object.
xlstm = xLSTMLarge(xlstm_config)

In [6]:
# Display the module tree to inspect the layer shapes and the backend settings.
xlstm

xLSTMLarge(
  (embedding): Embedding(2048, 512)
  (backbone): xLSTMLargeBlockStack(
    (blocks): ModuleList(
      (0-5): 6 x mLSTMBlock(
        (norm_mlstm): RMSNorm()
        (mlstm_layer): mLSTMLayer(
          (q): Linear(in_features=512, out_features=256, bias=False)
          (k): Linear(in_features=512, out_features=256, bias=False)
          (v): Linear(in_features=512, out_features=512, bias=False)
          (ogate_preact): Linear(in_features=512, out_features=512, bias=False)
          (igate_preact): Linear(in_features=512, out_features=4, bias=True)
          (fgate_preact): Linear(in_features=512, out_features=4, bias=True)
          (ogate_act_fn): Sigmoid()
          (mlstm_backend): mLSTMBackend(mLSTMBackendConfig(chunkwise_kernel='chunkwise--triton_xl_chunk', sequence_kernel='native_sequence__triton', step_kernel='triton', mode='inference', chunk_size=64, return_last_states=True, autocast_kernel_dtype='bfloat16', eps=1e-06, inference_state_dtype='float32', normalize_

In [7]:
# Move the model to the CUDA device; the input batch created below is placed
# on the same device.
xlstm = xlstm.to("cuda")

In [8]:
# Random token ids for a batch of 3 sequences of length 256, moved to the GPU.
# The exclusive upper bound is read from the config rather than repeated as a
# literal, so the ids stay within the vocabulary if `vocab_size` is changed.
# NOTE: the name `input` shadows the Python builtin; it is kept because the
# following cells refer to it.
input = torch.randint(0, xlstm_config.vocab_size, (3, 256)).to("cuda")
input.shape

torch.Size([3, 256])

In [9]:
# Forward pass over the whole batch of 256-token sequences in a single call.
# The result is split into output and state in the next cell.
out = xlstm(input)

In [10]:
# Split the forward result into the output tensor and the returned state.
# The container type is checked in addition to the length: a bare tensor also
# has a length (its first dimension), so with a batch of two the length check
# alone would unpack the tensor along the batch dimension by mistake.
if isinstance(out, tuple) and len(out) == 2:
    out, state = out

In [11]:
# Sanity check on the trailing dimensions: sequence length 256 (as in the
# input) and 2048 in the last dimension (the configured vocab_size).
out.shape[1:] == (256, 2048)

True

In [12]:
# Inspect how the returned state is keyed.
# The integer keys 0..5 match num_blocks=6 — presumably one entry per block.
state.keys()

dict_keys([0, 1, 2, 3, 4, 5])

In [13]:
# Number of entries in the state, and number of items stored under key 0.
len(state), len(state[0])

(6, 3)

In [14]:
# Slicing with 0:1 (instead of indexing with 0) keeps the sequence dimension,
# so the first token of every sequence has shape (batch, 1).
input[:, 0:1].shape, input.shape

(torch.Size([3, 1]), torch.Size([3, 256]))

In [15]:
# Single-token call: pass one token per sequence together with the state
# obtained from the full-sequence pass; returns the output for that token and
# a new state.
step_out, step_state = xlstm(input[:, 0:1], state)

In [16]:
# The single-step output keeps a sequence dimension of length 1.
step_out.shape

torch.Size([3, 1, 2048])

In [17]:
# Reference result: process the whole sequence in one call. Stored under
# separate names for the comparison with token-by-token processing.
out_chunkwise, last_state_chunkwise = xlstm(input)

In [18]:
# Recurrent pass: feed the sequence one token at a time, passing the state
# returned by each call into the next one (starting with no state).
# Gradients are not needed for this comparison, so the loop runs under
# `torch.no_grad()`; otherwise every step output would keep its autograd graph
# alive for the entire loop.
out_steps = []
state = None
with torch.no_grad():
    for i in range(input.shape[1]):
        out_step, state = xlstm(input[:, i:i + 1], state)
        out_steps.append(out_step)

In [19]:
# Join the per-token outputs along the sequence dimension (dim=1).
# NOTE: this rebinds `out_steps` from a list to a tensor, so re-run the loop
# cell before running this cell a second time.
out_steps = torch.cat(out_steps, dim=1)

In [20]:
# Both results must have the same shape before they can be compared element-wise.
out_steps.shape, out_chunkwise.shape

(torch.Size([3, 256, 2048]), torch.Size([3, 256, 2048]))

In [21]:
# Largest element-wise absolute difference between the full-sequence result
# and the token-by-token result.
(out_chunkwise - out_steps).abs().max()

tensor(0.0056, device='cuda:0', grad_fn=<MaxBackward1>)

In [22]:
# Tolerance-based equality check of the two results.
# NOTE(review): atol=7e-2 is more than ten times the maximum difference
# reported by the previous cell (0.0056) — consider tightening it.
torch.allclose(out_chunkwise, out_steps, atol=7e-2, rtol=1e-3)

True

In [23]:
# List the names of all entries in the model's state dict
# (iterating the mapping yields its keys in order).
list(xlstm.state_dict())

['embedding.weight',
 'backbone.blocks.0.norm_mlstm.weight',
 'backbone.blocks.0.mlstm_layer.q.weight',
 'backbone.blocks.0.mlstm_layer.k.weight',
 'backbone.blocks.0.mlstm_layer.v.weight',
 'backbone.blocks.0.mlstm_layer.ogate_preact.weight',
 'backbone.blocks.0.mlstm_layer.igate_preact.weight',
 'backbone.blocks.0.mlstm_layer.igate_preact.bias',
 'backbone.blocks.0.mlstm_layer.fgate_preact.weight',
 'backbone.blocks.0.mlstm_layer.fgate_preact.bias',
 'backbone.blocks.0.mlstm_layer.multihead_norm.weight',
 'backbone.blocks.0.mlstm_layer.out_proj.weight',
 'backbone.blocks.0.norm_ffn.weight',
 'backbone.blocks.0.ffn.proj_up_gate.weight',
 'backbone.blocks.0.ffn.proj_up.weight',
 'backbone.blocks.0.ffn.proj_down.weight',
 'backbone.blocks.1.norm_mlstm.weight',
 'backbone.blocks.1.mlstm_layer.q.weight',
 'backbone.blocks.1.mlstm_layer.k.weight',
 'backbone.blocks.1.mlstm_layer.v.weight',
 'backbone.blocks.1.mlstm_layer.ogate_preact.weight',
 'backbone.blocks.1.mlstm_layer.igate_preact.we