In [None]:
# Enable IPython's autoreload extension in mode 2 (reload modules before each
# cell runs) — presumably so edits to the xlstm package imported below are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../..")
import torch

from xlstm.blocks.slstm.cell import sLSTMCell, sLSTMCellConfig
from xlstm.blocks.slstm.block import sLSTMBlock, sLSTMBlockConfig
from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
from xlstm.components.feedforward import FeedForwardConfig

# Select the "vanilla" sLSTM backend. The CUDA-aware alternative is kept here
# for reference:
#   backend = "cuda" if torch.cuda.is_available() else "vanilla"
backend = "vanilla"
# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Stand-alone sLSTM cell: hidden_size 128, 2 heads, bfloat16, AMP disabled,
# moved to the selected device.
cell = sLSTMCell(
    sLSTMCellConfig(
        hidden_size=128,
        num_heads=2,
        backend=backend,
        dtype="bfloat16",
        function="slstm",
        enable_automatic_mixed_precision=False,
    )
).to(device)

In [3]:
# Run the cell on an all-zero bfloat16 tensor of shape (4, 4, 4*128) and print
# the shapes of the first two items it returns.
res = cell(torch.zeros(4, 4, 4 * 128, dtype=torch.bfloat16, device=device))
print(f"{res[0].shape} {res[1].shape}")

torch.Size([4, 2, 4, 64]) torch.Size([4, 4, 128])


In [4]:
# sLSTM block (embedding_dim 32, 4 heads, "lstm" cell function, "standard"
# bias / recurrent-weight init) with a default feed-forward config, moved to
# the selected device and then cast to bfloat16.
block = sLSTMBlock(
    sLSTMBlockConfig(
        slstm=sLSTMLayerConfig(
            embedding_dim=32,
            num_heads=4,
            backend=backend,
            function="lstm",
            bias_init="standard",
            recurrent_weight_init="standard",
            _block_idx=2,
            _num_blocks=4,
            enable_automatic_mixed_precision=False,
        ),
        feedforward=FeedForwardConfig(),
    )
).to(device).to(dtype=torch.bfloat16)

In [5]:
# All-ones bfloat16 input of shape (3, 4, 32) that records gradients, pushed
# through the block so the gradient checks further down can inspect inp.grad.
inp = torch.ones(3, 4, 32, dtype=torch.bfloat16, device=device).requires_grad_(True)
res = block(inp)

In [6]:
# The block output's shape (same torch.Size value that `.shape` reports)
res.size()

torch.Size([3, 4, 32])

In [7]:
# Back-propagate from a single output position (index [1, 1] of `res`) and
# inspect which input positions receive gradient.
res[1, 1].sum().backward()

# check for causality, batch interconnect
# Expectation: inp.grad is zero at the other first-axis indices (0 and 2) and
# at the later second-axis index (2), and non-zero at [1, 0] and [1, 1].
print(inp.grad[2, 1].sum(), "== 0")
print(inp.grad[0, 1].sum(), "== 0")
print(inp.grad[1, 2].sum(), "== 0")
print(inp.grad[1, 0].sum(), "!= 0")
print(inp.grad[1, 1].sum(), "!= 0")

# Enforce the expectations so a regression fails on "Run All" instead of
# relying on someone reading the printed values. Element-wise comparisons are
# used rather than .sum() so positive and negative entries cannot cancel out.
assert (inp.grad[2, 1] == 0).all(), "gradient reached inp[2, 1] from res[1, 1]"
assert (inp.grad[0, 1] == 0).all(), "gradient reached inp[0, 1] from res[1, 1]"
assert (inp.grad[1, 2] == 0).all(), "gradient reached inp[1, 2] from res[1, 1]"
assert (inp.grad[1, 0] != 0).any(), "no gradient at inp[1, 0] from res[1, 1]"
assert (inp.grad[1, 1] != 0).any(), "no gradient at inp[1, 1] from res[1, 1]"



tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0
tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0
tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0
tensor(1.1250, device='cuda:0', dtype=torch.bfloat16) != 0
tensor(33.7500, device='cuda:0', dtype=torch.bfloat16) != 0


In [8]:
from xlstm.xlstm_block_stack import xLSTMBlockStackConfig, xLSTMBlockStack, mLSTMBlockConfig

In [9]:
# 48-block stack with embedding_dim 1024 in which every block is an sLSTM
# block (slstm_at="all"); the sLSTM layer only overrides the backend.
bs = xLSTMBlockStack(xLSTMBlockStackConfig(
    num_blocks=48,
    embedding_dim=1024,
    slstm_at="all",
    slstm_block=sLSTMBlockConfig(
        slstm=sLSTMLayerConfig(backend=backend),
    ),
))

In [10]:
# Display the module tree of the sLSTM stack constructed in the previous cell
bs

xLSTMBlockStack(
  (blocks): ModuleList(
    (0-47): 48 x sLSTMBlock(
      (xlstm_norm): LayerNorm()
      (xlstm): sLSTMLayer(
        (conv1d): CausalConv1d(
          (conv): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)
        )
        (conv_act_fn): SiLU()
        (fgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (igate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (zgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (ogate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (slstm_cell): sLSTMCell_vanilla(function=slstm, hidden_size=1024, num_heads=4)
        (group_norm): MultiHe

In [11]:
# 48-block stack built from a default mLSTM block config, with embedding_dim
# 1024 and context_length 2048. Rebinds `bs`, replacing the sLSTM stack.
bs = xLSTMBlockStack(
    xLSTMBlockStackConfig(
        mlstm_block=mLSTMBlockConfig(),
        context_length=2048,
        num_blocks=48,
        embedding_dim=1024,
    )
)

In [12]:
# Display the module tree of the mLSTM stack constructed in the previous cell
bs

xLSTMBlockStack(
  (blocks): ModuleList(
    (0-47): 48 x mLSTMBlock(
      (xlstm_norm): LayerNorm()
      (xlstm): mLSTMLayer(
        (proj_up): Linear(in_features=1024, out_features=4096, bias=False)
        (q_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (k_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (v_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (conv1d): CausalConv1d(
          (conv): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
        )
        (conv_act_fn): SiLU()
        (mlstm_cell): mLSTMCell(
          (igate): Linear(in_features=6144, out_features=4, bias=True)
          (fgate): Linear(in_features=6144, out_features=4, bias=True)
    