In [14]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")
import torch
print(torch.backends.cudnn.version())


from xlstm.blocks.slstm.cell import sLSTMCell, sLSTMCellConfig
from xlstm.blocks.slstm.block import sLSTMBlock, sLSTMBlockConfig
from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
from xlstm.components.feedforward import FeedForwardConfig

backend = "cuda" if torch.cuda.is_available() else "vanilla"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
import os
import shutil

# Clean the torch extensions cache
extension_cache = "/home/lucy/.cache/torch_extensions/"
if os.path.exists(extension_cache):
    shutil.rmtree(extension_cache)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
8902


In [2]:
cell = sLSTMCell(sLSTMCellConfig(hidden_size=128, num_heads=2, backend="vanilla", dtype="bfloat16", function="slstm", enable_automatic_mixed_precision=False)).to(device)

In [3]:
res = cell(torch.zeros([4, 4, 4*128], dtype=torch.bfloat16).to(device))
print(res[0].shape, res[1].shape)

torch.Size([4, 2, 4, 64]) torch.Size([4, 4, 128])


In [19]:
block = sLSTMBlock(sLSTMBlockConfig(slstm=sLSTMLayerConfig(embedding_dim=32,
    num_heads=4, backend="vanilla", function="lstm", bias_init="standard", recurrent_weight_init="standard",
    _block_idx=2, _num_blocks=4, enable_automatic_mixed_precision=True), feedforward=FeedForwardConfig())).to(device).to(dtype=torch.bfloat16)

In [20]:
inp = torch.ones((3, 4, 32), dtype=torch.bfloat16).to(device)
inp.requires_grad = True
res = block(inp)

In [21]:
res.shape

torch.Size([3, 4, 32])

In [22]:
res[1, 1].sum().backward()
# check for causality, batch interconnect
print(inp.grad[2, 1].sum(), "== 0")
print(inp.grad[0, 1].sum(), "== 0")
print(inp.grad[1, 2].sum(), "== 0")
print(inp.grad[1, 0].sum(), "!= 0")
print(inp.grad[1, 1].sum(), "!= 0")



RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasGemmStridedBatchedEx(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF, (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`

- autocast: Automatically casts operations to bfloat16 where appropriate, falling back to float32 when necessary.
- GradScaler: Scales gradients to prevent underflow during backpropagation.

In [16]:
import torch
import shutil

# Clean the torch extensions cache
extension_cache = "/home/lucy/.cache/torch_extensions/"
if os.path.exists(extension_cache):
    shutil.rmtree(extension_cache)

# Example of rebuilding or running your project
# Ensure the model is on the correct device and using compatible data types
block = sLSTMBlock(sLSTMBlockConfig(
    slstm=sLSTMLayerConfig(
        embedding_dim=32,
        num_heads=4,
        backend="vanilla",  # Use the correct backend
        function="lstm",
        bias_init="standard",
        recurrent_weight_init="standard",
        _block_idx=2,
        _num_blocks=4,
        enable_automatic_mixed_precision=True  # Enable AMP
    ),
    feedforward=FeedForwardConfig()
)).to(device).to(dtype=torch.bfloat16)

# Ensure input tensor is on the correct device and using compatible data type
inp = torch.ones((3, 4, 32), dtype=torch.bfloat16).to(device)
inp.requires_grad = True

# Use mixed precision
scaler = torch.cuda.amp.GradScaler()

# Run the forward pass
with torch.cuda.amp.autocast():
    res = block(inp)
    print(res.shape)  # Verify the output shape

# Run the backward pass
scaler.scale(res[1, 1].sum()).backward()

# Check gradients
print(inp.grad[2, 1].sum(), "== 0")
print(inp.grad[0, 1].sum(), "== 0")
print(inp.grad[1, 2].sum(), "== 0")
print(inp.grad[1, 0].sum(), "!= 0")
print(inp.grad[1, 1].sum(), "!= 0")


torch.Size([3, 4, 32])
tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0
tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0
tensor(0., device='cuda:0', dtype=torch.bfloat16) == 0
tensor(nan, device='cuda:0', dtype=torch.bfloat16) != 0
tensor(nan, device='cuda:0', dtype=torch.bfloat16) != 0


In [17]:
from xlstm.xlstm_block_stack import xLSTMBlockStackConfig, xLSTMBlockStack, mLSTMBlockConfig

In [18]:
block = sLSTMBlock(sLSTMBlockConfig(
    slstm=sLSTMLayerConfig(
        embedding_dim=32,
        num_heads=4,
        backend="vanilla",
        function="lstm",
        bias_init="standard",
        recurrent_weight_init="standard",
        _block_idx=2,
        _num_blocks=4,
        enable_automatic_mixed_precision=False
    ),
    feedforward=FeedForwardConfig()
)).to(device).to(dtype=torch.bfloat16)

inp = torch.ones((3, 4, 32), dtype=torch.bfloat16).to(device)
inp.requires_grad = True
res = block(inp)

print(f"Result shape: {res.shape}")


Result shape: torch.Size([3, 4, 32])


In [19]:
bs = xLSTMBlockStack(
    xLSTMBlockStackConfig(
        slstm_block=sLSTMBlockConfig(slstm=sLSTMLayerConfig(backend="vanilla")),
        slstm_at="all",
        num_blocks=48,
        embedding_dim=1024,
    )
)

In [20]:
bs

xLSTMBlockStack(
  (blocks): ModuleList(
    (0-47): 48 x sLSTMBlock(
      (xlstm_norm): LayerNorm()
      (xlstm): sLSTMLayer(
        (conv1d): CausalConv1d(
          (conv): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)
        )
        (conv_act_fn): SiLU()
        (fgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (igate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (zgate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (ogate): LinearHeadwiseExpand(in_features=1024, num_heads=4, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (slstm_cell): sLSTMCell_vanilla(function=slstm, hidden_size=1024, num_heads=4)
        (group_norm): MultiHe

In [21]:
bs = xLSTMBlockStack(xLSTMBlockStackConfig(mlstm_block=mLSTMBlockConfig(), context_length=2048, num_blocks=48, embedding_dim=1024))

In [8]:
bs

xLSTMBlockStack(
  (blocks): ModuleList(
    (0-47): 48 x mLSTMBlock(
      (xlstm_norm): LayerNorm()
      (xlstm): mLSTMLayer(
        (proj_up): Linear(in_features=1024, out_features=4096, bias=False)
        (q_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (k_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (v_proj): LinearHeadwiseExpand(in_features=2048, num_heads=512, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (conv1d): CausalConv1d(
          (conv): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
        )
        (conv_act_fn): SiLU()
        (mlstm_cell): mLSTMCell(
          (igate): Linear(in_features=6144, out_features=4, bias=True)
          (fgate): Linear(in_features=6144, out_features=4, bias=True)
    

In [5]:
import os
import torch
import subprocess

# Set environment variables
os.environ['PATH'] = '/usr/local/cuda-12.1/bin:' + os.environ['PATH']
os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda-12.1/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')

# Clean previous builds
extension_cache = "/home/lucy/.cache/torch_extensions/"
if os.path.exists(extension_cache):
    subprocess.run(['rm', '-rf', extension_cache])

# Verify PyTorch configuration
print("Torch version:", torch.__version__)
print("Is CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("cuDNN version:", torch.backends.cudnn.version())
else:
    print("CUDA not available")

# Rebuild the extension
try:
    subprocess.run([
        '/usr/local/cuda-12.1/bin/nvcc',
        '--generate-dependencies-with-compile',
        '--dependency-output', 'lstm_pointwise.cuda.o.d',
        '-ccbin', '/home/lucy/anaconda3/envs/deep/bin/x86_64-conda-linux-gnu-cc',
        '-DTORCH_EXTENSION_NAME=lstm_HS32',
        # Add other necessary flags and arguments
    ], check=True)
except subprocess.CalledProcessError as e:
    print("Compilation failed:", e)


Torch version: 2.2.0
Is CUDA available: True
CUDA version: 12.1
cuDNN version: 8902
Compilation failed: Command '['/usr/local/cuda-12.1/bin/nvcc', '--generate-dependencies-with-compile', '--dependency-output', 'lstm_pointwise.cuda.o.d', '-ccbin', '/home/lucy/anaconda3/envs/deep/bin/x86_64-conda-linux-gnu-cc', '-DTORCH_EXTENSION_NAME=lstm_HS32']' returned non-zero exit status 1.


nvcc fatal   : No input files specified; use option --help for more information


In [7]:
import os
import subprocess

# Set environment variables
os.environ['PATH'] = '/usr/local/cuda-12.1/bin:' + os.environ['PATH']
os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda-12.1/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')

# Clean previous builds
extension_cache = "/home/lucy/.cache/torch_extensions/"
if os.path.exists(extension_cache):
    subprocess.run(['rm', '-rf', extension_cache])

# Verify PyTorch configuration
import torch
print("Torch version:", torch.__version__)
print("Is CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("cuDNN version:", torch.backends.cudnn.version())
else:
    print("CUDA not available")

# Rebuild the extension
try:
    subprocess.run([
        '/usr/local/cuda-12.1/bin/nvcc',
        '--generate-dependencies-with-compile',
        '--dependency-output', 'lstm_pointwise.cuda.o.d',
        '-ccbin', '/home/lucy/anaconda3/envs/deep/bin/x86_64-conda-linux-gnu-cc',
        '-DTORCH_EXTENSION_NAME=lstm_HS32',
        '-I', '/home/lucy/anaconda3/envs/deep/lib/python3.11/site-packages/torch/include',
        '-I', '/home/lucy/anaconda3/envs/deep/lib/python3.11/site-packages/torch/include/torch/csrc/api/include',
        '-I', '/home/lucy/anaconda3/envs/deep/lib/python3.11/site-packages/torch/include/TH',
        '-I', '/home/lucy/anaconda3/envs/deep/lib/python3.11/site-packages/torch/include/THC',
        '-I', '/usr/local/cuda-12.1/include',
        '-I', '/home/lucy/anaconda3/envs/deep/include/python3.11',
        '-D_GLIBCXX_USE_CXX11_ABI=0',
        '-D__CUDA_NO_HALF_OPERATORS__',
        '-D__CUDA_NO_HALF_CONVERSIONS__',
        '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
        '-D__CUDA_NO_HALF2_OPERATORS__',
        '--expt-relaxed-constexpr',
        '-gencode=arch=compute_52,code=compute_52',
        '-gencode=arch=compute_52,code=sm_52',
        '--compiler-options', '-fPIC',
        '-Xptxas="-v"',
        '-gencode', 'arch=compute_80,code=compute_80',
        '-res-usage',
        '--use_fast_math',
        '-O3',
        '-Xptxas', '-O3',
        '--extra-device-vectorization',
        '-DSLSTM_HIDDEN_SIZE=32',
        '-DSLSTM_BATCH_SIZE=8',
        '-DSLSTM_NUM_HEADS=4',
        '-DSLSTM_NUM_STATES=2',
        '-DSLSTM_DTYPE_B=float',
        '-DSLSTM_DTYPE_R=__nv_bfloat16',
        '-DSLSTM_DTYPE_W=__nv_bfloat16',
        '-DSLSTM_DTYPE_G=__nv_bfloat16',
        '-DSLSTM_DTYPE_S=__nv_bfloat16',
        '-DSLSTM_DTYPE_A=float',
        '-DSLSTM_NUM_GATES=4',
        '-DSLSTM_SIMPLE_AGG=true',
        '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL_VALID=false',
        '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL=0.0',
        '-DSLSTM_FORWARD_CLIPVAL_VALID=false',
        '-DSLSTM_FORWARD_CLIPVAL=0.0',
        '-U__CUDA_NO_HALF_OPERATORS__',
        '-U__CUDA_NO_HALF_CONVERSIONS__',
        '-U__CUDA_NO_BFLOAT16_OPERATORS__',
        '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
        '-U__CUDA_NO_BFLOAT162_OPERATORS__',
        '-U__CUDA_NO_BFLOAT162_CONVERSIONS__',
        '-std=c++17',
        '-c', './xlstm/xlstm/blocks/slstm/src/cuda/lstm_pointwise.cu',
        '-o', 'lstm_pointwise.cuda.o'
    ], check=True)
except subprocess.CalledProcessError as e:
    print("Compilation failed:", e)


Torch version: 2.2.0
Is CUDA available: True
CUDA version: 12.1
cuDNN version: 8902
Compilation failed: Command '['/usr/local/cuda-12.1/bin/nvcc', '--generate-dependencies-with-compile', '--dependency-output', 'lstm_pointwise.cuda.o.d', '-ccbin', '/home/lucy/anaconda3/envs/deep/bin/x86_64-conda-linux-gnu-cc', '-DTORCH_EXTENSION_NAME=lstm_HS32', '-I', '/home/lucy/anaconda3/envs/deep/lib/python3.11/site-packages/torch/include', '-I', '/home/lucy/anaconda3/envs/deep/lib/python3.11/site-packages/torch/include/torch/csrc/api/include', '-I', '/home/lucy/anaconda3/envs/deep/lib/python3.11/site-packages/torch/include/TH', '-I', '/home/lucy/anaconda3/envs/deep/lib/python3.11/site-packages/torch/include/THC', '-I', '/usr/local/cuda-12.1/include', '-I', '/home/lucy/anaconda3/envs/deep/include/python3.11', '-D_GLIBCXX_USE_CXX11_ABI=0', '-D__CUDA_NO_HALF_OPERATORS__', '-D__CUDA_NO_HALF_CONVERSIONS__', '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__', '--expt-relaxed-constexpr',

cc1plus: fatal error: ./xlstm/xlstm/blocks/slstm/src/cuda/lstm_pointwise.cu: No such file or directory
compilation terminated.
