In [1]:
!pip install torchinfo
!pip install xLSTM
!pip install ninja
!pip install xlstm --force-reinstall --no-cache-dir


Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0
Collecting xLSTM
  Downloading xlstm-1.0.8-py3-none-any.whl.metadata (19 kB)
Downloading xlstm-1.0.8-py3-none-any.whl (79 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xLSTM
Successfully installed xLSTM-1.0.8
Collecting ninja
  Downloading ninja-1.11.1.2-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Downloading ninja-1.11.1.2-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.9/422.9 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.11.1.2
Collecting xlstm
  Downloading xlstm-1.0.8-py3-

In [2]:
from torch import nn
import torch
from torchinfo import summary

# LSTM

In [3]:
class SimpleModel(nn.Module):
    def __init__(self,
                 embedding_dim, # Number of features - we have 1 because music!
                 hidden_size, # Hidden size of LSTM layer
                 num_layers, # Number of LSTM layers
                 ):

        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=95,
                                      embedding_dim=embedding_dim)

        self.LSTM = nn.LSTM(input_size=embedding_dim,
                          hidden_size=hidden_size,
                          num_layers=num_layers,
                          batch_first=True)

        self.dp1 = nn.Dropout(0.5)
        self.dp2 = nn.Dropout(0.5)

        self.fc1 = nn.Linear(hidden_size,
                             hidden_size)

        self.fc2 = nn.Linear(hidden_size, 95) # hidden size x vocab size

    def forward(self, x):
        x = self.embedding(x)
        _, (h_n, _) = self.LSTM(x) # Needs N x seq_L x FEATURE_DIM
        x = h_n[-1]
        x = self.dp1(x)
        x = self.fc1(x) # See w8_forecast solution for an explanation of slicing
        x = nn.functional.relu(x)
        x = self.dp2(x)
        x = self.fc2(x)
        return x

In [4]:
model = SimpleModel(128, 256, 4)
summary(model, input_size=(1, 128), dtypes=[torch.long])

Layer (type:depth-idx)                   Output Shape              Param #
SimpleModel                              [1, 95]                   --
├─Embedding: 1-1                         [1, 128, 128]             12,160
├─LSTM: 1-2                              [1, 128, 256]             1,974,272
├─Dropout: 1-3                           [1, 256]                  --
├─Linear: 1-4                            [1, 256]                  65,792
├─Dropout: 1-5                           [1, 256]                  --
├─Linear: 1-6                            [1, 95]                   24,415
Total params: 2,076,639
Trainable params: 2,076,639
Non-trainable params: 0
Total mult-adds (M): 252.81
Input size (MB): 0.00
Forward/backward pass size (MB): 0.40
Params size (MB): 8.31
Estimated Total Size (MB): 8.70

# xLSTM

In [5]:
import torch
from torch import nn
from xlstm import xLSTMBlockStack, xLSTMBlockStackConfig, mLSTMBlockConfig, mLSTMLayerConfig, sLSTMBlockConfig, sLSTMLayerConfig, FeedForwardConfig

class SimpleModelWithxLSTM(nn.Module):
    def __init__(self,
                 vocab_size,
                 embedding_dim,
                 hidden_size,
                 context_length,
                 num_blocks,
                 slstm_at,
                 dropout_prob = 0.5):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim
        )


        # xLSTM configuration
        self.xLSTM_cfg = xLSTMBlockStackConfig(
            mlstm_block=mLSTMBlockConfig(
                mlstm=mLSTMLayerConfig(
                    conv1d_kernel_size=4,
                    qkv_proj_blocksize=4,
                    num_heads=4
                )
            ),
            slstm_block=sLSTMBlockConfig(
                slstm=sLSTMLayerConfig(
                    backend="vanilla",
                    num_heads=4,
                    conv1d_kernel_size=4,
                    bias_init="powerlaw_blockdependent",
                ),
                feedforward=FeedForwardConfig(
                    proj_factor=1.3,
                    act_fn="gelu"
                ),
            ),
            context_length=context_length,
            num_blocks=num_blocks,
            embedding_dim=embedding_dim,
            slstm_at=slstm_at,
        )

        # Initialize xLSTM stack
        self.xLSTM = xLSTMBlockStack(self.xLSTM_cfg)

        self.dropout_1 = nn.Dropout(dropout_prob)


        # Fully connected layers
        self.fc1 = nn.Linear(embedding_dim, hidden_size)
        self.dropout_2 = nn.Dropout(dropout_prob)
        self.fc2 = nn.Linear(hidden_size, vocab_size)


    def forward(self, x):
        # Embed the input
        x = self.embedding(x)  # Shape: [batch_size, seq_length, embedding_dim]

        # Pass through the xLSTM stack
        x = self.xLSTM(x)  # Shape: [batch_size, seq_length, embedding_dim]

        x = self.dropout_1(x)
        # Take the last sequence step (e.g., for classification tasks)
        x = x[:, -1, :]  # Shape: [batch_size, embedding_dim]

        # Fully connected layers
        x = self.fc1(x)  # Shape: [batch_size, hidden_size]
        x = nn.functional.relu(x)
        x = self.dropout_2(x)
        x = self.fc2(x)  # Shape: [batch_size, vocab_size]

        return x

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


In [6]:
model = SimpleModelWithxLSTM(
    vocab_size=95,
    embedding_dim=128,
    hidden_size=256,
    context_length=128,
    num_blocks=4,
    slstm_at=[1]
)
summary(model, input_size = (1, 128), dtypes=[torch.long])

Layer (type:depth-idx)                                  Output Shape              Param #
SimpleModelWithxLSTM                                    [1, 95]                   --
├─Embedding: 1-1                                        [1, 128, 128]             12,160
├─xLSTMBlockStack: 1-2                                  [1, 128, 128]             --
│    └─ModuleList: 2-1                                  --                        --
│    │    └─mLSTMBlock: 3-1                             [1, 128, 128]             109,448
│    │    └─sLSTMBlock: 3-2                             [1, 128, 128]             108,032
│    │    └─mLSTMBlock: 3-3                             [1, 128, 128]             109,448
│    │    └─mLSTMBlock: 3-4                             [1, 128, 128]             109,448
│    └─LayerNorm: 2-2                                   [1, 128, 128]             128
├─Dropout: 1-3                                          [1, 128, 128]             --
├─Linear: 1-4                      