In [None]:
# Enable IPython's autoreload extension in mode 2 so that edited modules are
# re-imported before each cell execution instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2
import sys 

# Add the directory two levels above the current working directory to the
# import path. Presumably this is the repository root that provides the
# `xlstm` package imported below — confirm if the notebook is moved.
sys.path.append('../..')
from omegaconf import OmegaConf
from pprint import pprint
from dacite import from_dict
from dacite import Config as DaciteConfig
import torch

from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig

In [2]:
# Create a new xLSTM language model from an inline YAML configuration.
#
# The global torch RNG is seeded first so that anything drawn from it afterwards
# (weight initialisation inside the model constructor, if it is random, and the
# random token ids sampled later in this notebook) is identical on every
# Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Only the fields listed here are set explicitly; every other field shown when
# the config is printed comes from the dataclass definitions.
xlstm_cfg = """ 
vocab_size: 50304
context_length: 2048      
num_blocks: 24 #!
embedding_dim: 768 #!
tie_weights: false
weight_decay_on_embedding: false
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
"""

# YAML text -> OmegaConf container -> plain Python dict -> typed dataclass.
# The intermediate container gets its own name so `cfg` always refers to the
# dataclass instance and never changes type within the cell.
# strict=True makes dacite raise on keys that the dataclass does not declare,
# which catches typos in the YAML above.
cfg_raw = OmegaConf.create(xlstm_cfg)
cfg = from_dict(
    data_class=xLSTMLMModelConfig,
    data=OmegaConf.to_container(cfg_raw),
    config=DaciteConfig(strict=True),
)
model_new = xLSTMLMModel(cfg)

In [3]:
# Print the resolved dataclass config, including fields that were not set in
# the YAML above (e.g. proj_factor) and therefore carry their default values.
pprint(cfg)

xLSTMLMModelConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0,
                                                                       round_proj_up_dim_up=True,
                                                                       round_proj_up_to_multiple_of=64,
                                                                       _proj_up_dim=1536,
                                                                       conv1d_kernel_size=4,
                                                                       qkv_proj_blocksize=4,
                                                                       num_heads=4,
                                                                       embedding_dim=768,
                                                                       bias=False,
                                                                       dropout=0.0,
                                                                       context_length=2048,
 

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Random token ids used as the shared input for both forward paths.
# The exclusive upper bound is read from the config instead of repeating the
# literal vocabulary size, so the ids stay inside the model's vocabulary if
# `vocab_size` in the YAML is ever changed.
# The ids are sampled on the CPU and then moved, exactly as before, so the
# drawn values do not depend on which device is selected.
BATCH_SIZE, SEQ_LEN = 1, 32
x_in = torch.randint(0, cfg.vocab_size, (BATCH_SIZE, SEQ_LEN)).to(device=DEVICE)

In [6]:
# Place the model's parameters and buffers on the selected device.
# The result is re-bound so the cell does not display the module repr.
model_new = model_new.to(DEVICE)

In [7]:
# Display the module tree (repr) to inspect which layers were built from the config.
model_new

xLSTMLMModel(
  (xlstm_block_stack): xLSTMBlockStack(
    (blocks): ModuleList(
      (0-23): 24 x mLSTMBlock(
        (xlstm_norm): LayerNorm()
        (xlstm): mLSTMLayer(
          (proj_up): Linear(in_features=768, out_features=3072, bias=False)
          (q_proj): LinearHeadwiseExpand(in_features=1536, num_heads=384, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (k_proj): LinearHeadwiseExpand(in_features=1536, num_heads=384, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (v_proj): LinearHeadwiseExpand(in_features=1536, num_heads=384, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (conv1d): CausalConv1d(
            (conv): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          )
          (conv_act_fn): SiLU()
          (mlstm_cell): mLSTMCell(
            (igate): Linear(in_features=4608, out_features=4, bias=True)
           

In [8]:
# Process the whole token sequence in a single forward call. This output is
# what the step-by-step path is compared against further down.
y_new = model_new(x_in)

In [9]:
# Shape of the full-sequence output. For the (1, 32) input above the recorded
# result is (1, 32, 50304), i.e. the last dimension equals the vocab size.
y_new.shape

torch.Size([1, 32, 50304])

In [10]:
# Run the same sequence one token at a time through the `step` interface,
# feeding the state returned by each call into the next one.
#
# Per-step outputs are collected in their own list so that `y_new_step` is only
# ever bound to the final concatenated tensor. Autograd is disabled because no
# backward pass is run on this comparison, which avoids retaining a graph that
# spans every recurrent step.
step_inputs = x_in.split(1, dim=1)  # chunks of width 1 along dim 1
step_outputs = []
state = None  # nothing to carry over before the first step
with torch.no_grad():
    for x in step_inputs:
        y, state = model_new.step(x, state)
        step_outputs.append(y)

# Concatenate along dim 1, the same dimension the input was split on.
y_new_step = torch.cat(step_outputs, dim=1)

# Show the shape of a single step input, taken from the chunk list rather than
# from the loop variable that leaks out of the `for` statement.
print(step_inputs[-1].shape)

torch.Size([1, 1])


In [11]:
# Shape of the concatenated step-by-step output; it has to equal the shape of
# `y_new` for the element-wise comparison below to be meaningful.
y_new_step.shape

torch.Size([1, 32, 50304])

In [12]:
# Element-wise difference between the full-sequence output and the
# step-by-step output. In the recorded run the entries are around 1e-7.
y_new - y_new_step

tensor([[[-3.5763e-07,  1.1921e-07, -5.9605e-08,  ..., -2.9802e-08,
          -2.9802e-07, -4.9174e-07],
         [ 0.0000e+00,  1.7881e-07, -1.7136e-07,  ...,  5.9605e-08,
           1.7881e-07,  1.7881e-07],
         [-4.4703e-08, -2.3842e-07, -2.6822e-07,  ...,  2.2352e-07,
          -1.1921e-07, -9.9652e-08],
         ...,
         [-1.0431e-07, -2.3842e-07, -1.1921e-07,  ...,  1.1921e-07,
          -8.9407e-08, -3.5763e-07],
         [ 0.0000e+00,  1.7881e-07, -1.1921e-07,  ..., -2.3842e-07,
           1.1921e-07,  2.9802e-07],
         [-1.1921e-07, -4.4703e-08, -1.7881e-07,  ...,  0.0000e+00,
           2.9802e-08, -1.2107e-07]]], device='cuda:0', grad_fn=<SubBackward0>)

In [13]:
# Check that the full-sequence forward pass and the step-by-step pass agree
# within an absolute tolerance of 1e-5.
# The result is asserted as well as displayed, so Restart & Run All stops here
# if the two code paths ever diverge instead of silently showing `False`.
outputs_match = torch.allclose(y_new, y_new_step, atol=1e-5)
assert outputs_match, (
    "step-wise output deviates from full-sequence output: "
    f"max abs diff = {(y_new - y_new_step).abs().max().item():.3e}"
)
outputs_match

True

## Verify that config fields are passed correctly

In [14]:
# Config object that was handed to the model constructor, printed again here as
# the reference for the comparisons in the following cells.
pprint(cfg)

xLSTMLMModelConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0,
                                                                       round_proj_up_dim_up=True,
                                                                       round_proj_up_to_multiple_of=64,
                                                                       _proj_up_dim=1536,
                                                                       conv1d_kernel_size=4,
                                                                       qkv_proj_blocksize=4,
                                                                       num_heads=4,
                                                                       embedding_dim=768,
                                                                       bias=False,
                                                                       dropout=0.0,
                                                                       context_length=2048,
 

In [15]:
# Config as stored on the model instance; it is expected to print the same
# values as `cfg`.
pprint(model_new.config)

xLSTMLMModelConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0,
                                                                       round_proj_up_dim_up=True,
                                                                       round_proj_up_to_multiple_of=64,
                                                                       _proj_up_dim=1536,
                                                                       conv1d_kernel_size=4,
                                                                       qkv_proj_blocksize=4,
                                                                       num_heads=4,
                                                                       embedding_dim=768,
                                                                       bias=False,
                                                                       dropout=0.0,
                                                                       context_length=2048,
 

In [16]:
# Here the _num_blocks field must match the number of blocks in the model.
# That requirement is asserted rather than only inspected by eye, so a mismatch
# stops the notebook; the block config is still printed for reference.
blocks = model_new.xlstm_block_stack.blocks
first_block_cfg = blocks[0].config
assert first_block_cfg._num_blocks == len(blocks), (
    f"block config reports _num_blocks={first_block_cfg._num_blocks}, "
    f"but the block stack holds {len(blocks)} blocks"
)
pprint(first_block_cfg)

xLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0,
                                        round_proj_up_dim_up=True,
                                        round_proj_up_to_multiple_of=64,
                                        _proj_up_dim=1536,
                                        conv1d_kernel_size=4,
                                        qkv_proj_blocksize=4,
                                        num_heads=4,
                                        embedding_dim=768,
                                        bias=False,
                                        dropout=0.0,
                                        context_length=2048,
                                        _num_blocks=24,
                                        _inner_embedding_dim=1536),
                 slstm=None,
                 feedforward=None,
                 _num_blocks=24,
                 _block_idx=0)
