In [None]:
%load_ext autoreload
%autoreload 2
import sys 

sys.path.append('../..')
from omegaconf import OmegaConf
from pprint import pprint
from dacite import from_dict
from dacite import Config as DaciteConfig
import torch

from xlstm.xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig

# Target device for every example below: the first CUDA GPU when PyTorch can see
# one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# xLSTM Usage Example

In [3]:
# Model configuration written as YAML. The sLSTM backend is interpolated at runtime:
# "cuda" when a GPU is available, "vanilla" otherwise.
# NOTE(review): the inline YAML remark "only vanilla here works" looks stale given
# that the backend is selected dynamically — confirm and remove it if so.
xlstm_cfg = f""" 
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    backend: {'cuda' if torch.cuda.is_available() else 'vanilla'} #! only vanilla here works
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1] #[1] # for [] it also works, so if no sLSTM is in the stack
"""

# YAML text -> OmegaConf node -> plain Python containers -> typed config dataclass.
# strict=True makes dacite fail on keys the dataclasses do not declare.
cfg = from_dict(
    data_class=xLSTMBlockStackConfig,
    data=OmegaConf.to_container(OmegaConf.create(xlstm_cfg)),
    config=DaciteConfig(strict=True),
)

# Instantiate the block stack described by the config.
xlstm_stack = xLSTMBlockStack(cfg)

In [4]:
# Pretty-print the resolved config dataclass; the output also shows fields that were
# not set in the YAML above (e.g. proj_factor, dropout).
pprint(cfg)

xLSTMBlockStackConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0,
                                                                          round_proj_up_dim_up=True,
                                                                          round_proj_up_to_multiple_of=64,
                                                                          _proj_up_dim=256,
                                                                          conv1d_kernel_size=4,
                                                                          qkv_proj_blocksize=4,
                                                                          num_heads=4,
                                                                          embedding_dim=128,
                                                                          bias=False,
                                                                          dropout=0.0,
                                                                

In [5]:
# Bare expression: the notebook renders the module repr with its nested blocks and layers.
xlstm_stack

xLSTMBlockStack(
  (blocks): ModuleList(
    (0): mLSTMBlock(
      (xlstm_norm): LayerNorm()
      (xlstm): mLSTMLayer(
        (proj_up): Linear(in_features=128, out_features=512, bias=False)
        (q_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (k_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (v_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (conv1d): CausalConv1d(
          (conv): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
        )
        (conv_act_fn): SiLU()
        (mlstm_cell): mLSTMCell(
          (igate): Linear(in_features=768, out_features=4, bias=True)
          (fgate): Linear(in_features=768, out_features=4, bias=True)
          (outnorm): Mult

In [6]:
# Random input batch of shape (4, 256, 128); 256 and 128 match `context_length` and
# `embedding_dim` in the config. The tensor is allocated directly on `device` instead
# of being created on the CPU and copied over afterwards.
x = torch.randn(4, 256, 128, device=device)

In [7]:
# Put the model on the same device as the input tensor before running it.
xlstm_stack = xlstm_stack.to(device)

In [8]:
# Forward pass of the input batch through the block stack.
y = xlstm_stack(x)

In [9]:
# Show the output shape; it is expected to equal the input shape (4, 256, 128).
y.shape

torch.Size([4, 256, 128])

### README example

In [11]:
import torch
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
import sys 
sys.path.append('..')
from xlstm import xLSTMBlockStack, xLSTMBlockStackConfig

# README example (YAML variant): describe the block stack in YAML, convert it into the
# typed config with OmegaConf + dacite, build the model and run one forward pass.
# The sLSTM backend is interpolated at runtime: "cuda" with a GPU, "vanilla" without.
xlstm_cfg = f""" 
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    backend: {'cuda' if torch.cuda.is_available() else 'vanilla'}
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""
cfg = OmegaConf.create(xlstm_cfg)
# strict=True makes dacite fail on keys the config dataclasses do not declare.
cfg = from_dict(data_class=xLSTMBlockStackConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))
xlstm_stack = xLSTMBlockStack(cfg)

# `device` is defined in the setup cell at the top of the notebook. The input is
# allocated directly on that device rather than on the CPU followed by a copy.
x = torch.randn(4, 256, 128, device=device)
xlstm_stack = xlstm_stack.to(device=device)
y = xlstm_stack(x)
# Sanity check: the stack preserves the input shape.
y.shape == x.shape

True

In [None]:
import torch
import sys

sys.path.append("..")
from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)

# README example (dataclass variant): the same settings as the YAML example, passed
# directly to the config dataclasses.
cfg = xLSTMBlockStackConfig(
    mlstm_block=mLSTMBlockConfig(
        mlstm=mLSTMLayerConfig(
            conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
        )
    ),
    slstm_block=sLSTMBlockConfig(
        slstm=sLSTMLayerConfig(
            # Use the CUDA backend only when a GPU is actually available.
            backend="cuda" if torch.cuda.is_available() else "vanilla",
            num_heads=4,
            conv1d_kernel_size=4,
            bias_init="powerlaw_blockdependent",
        ),
        feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
    ),
    context_length=256,
    num_blocks=7,
    embedding_dim=128,
    slstm_at=[1],
)

xlstm_stack = xLSTMBlockStack(cfg)

# `device` is defined in the setup cell at the top of the notebook. The input is
# allocated directly on that device rather than on the CPU followed by a copy.
x = torch.randn(4, 256, 128, device=device)
xlstm_stack = xlstm_stack.to(device=device)
y = xlstm_stack(x)
# Sanity check: the stack preserves the input shape.
y.shape == x.shape

True

In [12]:
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
import sys 
sys.path.append('..')
from xlstm.xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig

# Same stack as the README example, but with an empty `slstm_at` list, i.e. no block
# position is assigned to an sLSTM block.
xlstm_cfg = f""" 
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    backend: {'cuda' if torch.cuda.is_available() else 'vanilla'}
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [] #[1]
"""
cfg = OmegaConf.create(xlstm_cfg)
# strict=True makes dacite fail on keys the config dataclasses do not declare.
cfg = from_dict(data_class=xLSTMBlockStackConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))
xlstm_stack = xLSTMBlockStack(cfg)

# `device` is defined in the setup cell at the top of the notebook. The input is
# allocated directly on that device rather than on the CPU followed by a copy.
x = torch.randn(4, 256, 128, device=device)
xlstm_stack = xlstm_stack.to(device=device)
y = xlstm_stack(x)
# Sanity check: the stack preserves the input shape.
y.shape == x.shape

True

In [14]:
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
import sys 
sys.path.append('..')
from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig

# Language-model variant: the same block settings plus a `vocab_size`, parsed into
# xLSTMLMModelConfig and used to build an xLSTMLMModel.
xlstm_cfg = f""" 
vocab_size: 50304
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    backend: {'cuda' if torch.cuda.is_available() else 'vanilla'}
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [] #[1]
"""
cfg = OmegaConf.create(xlstm_cfg)
# strict=True makes dacite fail on keys the config dataclasses do not declare.
cfg = from_dict(data_class=xLSTMLMModelConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))
xlstm_stack = xLSTMLMModel(cfg)

# Random token ids in [0, vocab_size), created directly on the target device
# (`device` is defined in the setup cell at the top of the notebook). The bound is read
# from the config so it cannot drift from the `vocab_size` declared in the YAML.
x = torch.randint(0, cfg.vocab_size, size=(4, 256), device=device)
xlstm_stack = xlstm_stack.to(device)
y = xlstm_stack(x)
# Sanity check: one score vector over the vocabulary per sequence position
# (presumably logits — confirm against the model's forward()).
y.shape[1:] == (256, cfg.vocab_size)

True