In [1]:
!pip install -q mamba-ssm
!pip install -q causal-conv1d

In [2]:
import torch
import time
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

In [3]:
# Configuration shared by every later cell: a fixed RNG seed so batch sampling
# and weight initialisation are reproducible, and the target device.
SEED = 11
device = "cuda"
torch.manual_seed(SEED)

In [4]:
# Load the raw training corpus as one string and preview its first 200 chars.
# The encoding is explicit so decoding does not depend on the kernel's locale.
DATA_PATH = "../input/for-use/shakespeare.txt"
with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = f.read()
print(data[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [5]:
# Character-level tokenizer: the vocabulary is every distinct character in the
# corpus, sorted so that id assignment is deterministic across runs.
chars = sorted(set(data))
vocab_size = len(chars)

input_size = 1024  # context length: characters per training window
batch_size = 4     # windows per batch

c2i = {c: i for i, c in enumerate(chars)}
i2c = {i: c for i, c in enumerate(chars)}


def encode(s):
    """Map a string to a list of integer token ids, one per character.

    Raises KeyError for characters that do not occur in the corpus.
    """
    return [c2i[c] for c in s]


def decode(ids):
    """Map an iterable of integer token ids back to the corresponding string."""
    return "".join(i2c[i] for i in ids)


# Whole corpus as a 1-D tensor of ids, split 90% / 10% into train / val.
data_ = torch.tensor(encode(data), dtype=torch.long)
n = int(len(data_) * 0.9)

train_data = data_[:n]
val_data = data_[n:]

# get_batch draws a window of input_size + 1 tokens (inputs plus the shifted
# target), so each split must be strictly longer than input_size.
assert min(len(train_data), len(val_data)) > input_size, "split shorter than one training window"

In [6]:
def get_batch(split):
    """Sample a random batch of (input, next-character target) windows.

    Args:
        split: "train" selects ``train_data``; any other value selects ``val_data``.

    Returns:
        (x, y): tensors of shape (batch_size, input_size) on ``device``. Row j of
        ``y`` is the window that starts one character after row j of ``x``, i.e.
        the next-token targets for ``x``.
    """
    source = train_data if split == "train" else val_data
    # Exclusive upper bound leaves room for the one-character-shifted target.
    starts = torch.randint(len(source) - input_size, (batch_size,))
    x = torch.stack([source[i : i + input_size] for i in starts])
    y = torch.stack([source[i + 1 : i + input_size + 1] for i in starts])
    if torch.device(device).type == "cuda":
        # Pinned host memory is only meaningful (and only supported) with CUDA;
        # it allows the non_blocking host-to-device copy.
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

# Build the Mamba language model over the character vocabulary, count its
# trainable parameters, and create the optimiser and next-character loss.
model = MambaLMHeadModel(
    d_model=768,
    n_layer=12,
    vocab_size=vocab_size,
    device=device,
)
n_params = sum(
    param.numel()
    for param in filter(lambda param: param.requires_grad, model.parameters())
)
loss_fct = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

In [7]:
# training: next-character prediction over random windows from train_data.
t_start = time.time()
max_iters = 1000
print_interval = 100
# `step` rather than `iter`: a loop variable named `iter` would shadow the
# builtin for the rest of the kernel session.
for step in range(max_iters):
    xb, yb = get_batch("train")
    logits = model(xb).logits
    B, T, C = logits.shape
    # Flatten (batch, time) so CrossEntropyLoss sees one row per position.
    loss = loss_fct(logits.reshape(B * T, C), yb.reshape(B * T))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log the first step, every print_interval steps, and the final step
    # (1-based iteration numbers).
    if step % print_interval == 0 or step == max_iters - 1:
        print(f"Iteration {step + 1} loss: {loss.item()}")

print("\n")
print(f"Number of parameters: {n_params}")
print(f"Training time:        {(time.time() - t_start)/60:.2f} min")

TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

Invoked with: tensor([[[-0.4288, -0.7415,  0.2338,  ...,  0.4696, -0.7415,  0.1994],
         [-0.2644,  0.9622, -0.7267,  ..., -0.2105,  0.9622, -0.5146],
         [-1.2545, -0.0087,  0.3420,  ...,  0.3833, -0.0087,  0.5425],
         ...,
         [-0.0883, -0.8732,  0.4167,  ...,  0.7703, -0.8732,  0.5065],
         [-0.3103,  0.1674, -0.9188,  ..., -0.4987,  0.1674, -0.6746],
         [ 0.0817,  0.1187,  0.4305,  ...,  0.2771,  0.1187, -0.6790]],

        [[ 0.4696,  0.3384,  0.5967,  ...,  0.2312, -0.6446, -0.5635],
         [-0.2105, -0.9218,  0.6288,  ..., -0.8121,  0.4196, -0.1247],
         [ 0.3833, -0.1912, -0.0106,  ...,  0.5447,  0.8804, -0.5389],
         ...,
         [ 0.7703, -0.0258, -0.0229,  ...,  0.5308, -0.0322, -0.6238],
         [-0.4987,  1.0386, -0.5600,  ..., -0.3423, -0.3097,  0.6895],
         [ 0.2771, -0.1353,  0.4125,  ...,  0.3521, -0.4873, -0.0586]],

        [[ 0.3327, -0.5635,  0.4696,  ..., -0.9914, -0.7415,  0.3327],
         [-0.4432, -0.1247, -0.2105,  ..., -0.6084,  0.9622, -0.4432],
         [ 0.0035, -0.5389,  0.3833,  ..., -0.5454, -0.0087,  0.0035],
         ...,
         [-0.3613, -0.6238,  0.7703,  ...,  0.2984, -0.8732, -0.3613],
         [ 0.2868,  0.6895, -0.4987,  ..., -0.9830,  0.1674,  0.2868],
         [-0.2470, -0.0586,  0.2771,  ..., -0.0096,  0.1187, -0.2470]],

        [[ 0.4696,  0.3327, -0.6446,  ..., -0.7415, -0.4288, -0.9914],
         [-0.2105, -0.4432,  0.4196,  ...,  0.9622, -0.2644, -0.6084],
         [ 0.3833,  0.0035,  0.8804,  ..., -0.0087, -1.2545, -0.5454],
         ...,
         [ 0.7703, -0.3613, -0.0322,  ..., -0.8732, -0.0883,  0.2984],
         [-0.4987,  0.2868, -0.3097,  ...,  0.1674, -0.3103, -0.9830],
         [ 0.2771, -0.2470, -0.4873,  ...,  0.1187,  0.0817, -0.0096]]],
       device='cuda:0', requires_grad=True), tensor([[ 0.1220,  0.0901,  0.0736,  0.1439],
        [-0.0095, -0.1856,  0.4502,  0.1997],
        [-0.2036, -0.0495,  0.3747, -0.0074],
        ...,
        [ 0.2265,  0.2769,  0.0143,  0.1760],
        [ 0.0952,  0.1286,  0.1418, -0.4398],
        [ 0.1948, -0.0843, -0.3106, -0.2859]], device='cuda:0',
       requires_grad=True), Parameter containing:
tensor([ 0.4892,  0.4341,  0.1342,  ...,  0.0411, -0.0503, -0.2879],
       device='cuda:0', requires_grad=True), True

In [8]:
# Sample a continuation of a fixed prompt from the trained model.
# unsqueeze(0) builds the (1, prompt_len) batch directly as a contiguous tensor
# instead of transposing a (prompt_len, 1) view.
prompt_tokens = torch.tensor(
    encode("Shall I compare thee to a summer's "), dtype=torch.long, device=device
).unsqueeze(0)
# NOTE(review): the "no kernel image is available" failure recorded for this
# cell is raised inside causal-conv1d's update kernel (GPU-architecture support),
# not by these sampling arguments.
res_tokens = model.generate(
    prompt_tokens, max_length=200, top_k=10, top_p=1.0, temperature=1.1, cg=True
)
list_chars = res_tokens[0].tolist()
print(decode(list_chars))

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /usr/local/src/pytorch/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xae (0x7c12e295300e in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7c12e29191cf in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3f2 (0x7c12e29de722 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: void causal_conv1d_update_launch<64, 4, float, float>(ConvParamsBase&, CUstream_st*) + 0x88 (0x7c124eece408 in /opt/conda/lib/python3.10/site-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so)
frame #4: void causal_conv1d_update_cuda<float, float>(ConvParamsBase&, CUstream_st*) + 0x1c5 (0x7c124eece6e5 in /opt/conda/lib/python3.10/site-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so)
frame #5: causal_conv1d_update(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, bool) + 0x64f (0x7c124ee905af in /opt/conda/lib/python3.10/site-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x2ce0d (0x7c124ee96e0d in /opt/conda/lib/python3.10/site-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x3863a (0x7c124eea263a in /opt/conda/lib/python3.10/site-packages/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x144516 (0x5d20bc8ca516 in /opt/conda/bin/python3.10)
frame #9: _PyObject_MakeTpCall + 0x26b (0x5d20bc8c3a6b in /opt/conda/bin/python3.10)
frame #10: _PyEval_EvalFrameDefault + 0x54a6 (0x5d20bc8bf9d6 in /opt/conda/bin/python3.10)
frame #11: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #12: _PyEval_EvalFrameDefault + 0x320 (0x5d20bc8ba850 in /opt/conda/bin/python3.10)
frame #13: <unknown function> + 0x1504f2 (0x5d20bc8d64f2 in /opt/conda/bin/python3.10)
frame #14: _PyEval_EvalFrameDefault + 0x4c12 (0x5d20bc8bf142 in /opt/conda/bin/python3.10)
frame #15: <unknown function> + 0x1504f2 (0x5d20bc8d64f2 in /opt/conda/bin/python3.10)
frame #16: PyObject_Call + 0xbc (0x5d20bc8d6e8c in /opt/conda/bin/python3.10)
frame #17: _PyEval_EvalFrameDefault + 0x2d80 (0x5d20bc8bd2b0 in /opt/conda/bin/python3.10)
frame #18: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #19: _PyObject_FastCallDictTstate + 0x187 (0x5d20bc8c3007 in /opt/conda/bin/python3.10)
frame #20: _PyObject_Call_Prepend + 0x69 (0x5d20bc8d4ba9 in /opt/conda/bin/python3.10)
frame #21: <unknown function> + 0x2114c9 (0x5d20bc9974c9 in /opt/conda/bin/python3.10)
frame #22: _PyObject_MakeTpCall + 0x26b (0x5d20bc8c3a6b in /opt/conda/bin/python3.10)
frame #23: _PyEval_EvalFrameDefault + 0x5709 (0x5d20bc8bfc39 in /opt/conda/bin/python3.10)
frame #24: <unknown function> + 0x1504f2 (0x5d20bc8d64f2 in /opt/conda/bin/python3.10)
frame #25: PyObject_Call + 0xbc (0x5d20bc8d6e8c in /opt/conda/bin/python3.10)
frame #26: _PyEval_EvalFrameDefault + 0x2d80 (0x5d20bc8bd2b0 in /opt/conda/bin/python3.10)
frame #27: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #28: _PyObject_FastCallDictTstate + 0x187 (0x5d20bc8c3007 in /opt/conda/bin/python3.10)
frame #29: _PyObject_Call_Prepend + 0x69 (0x5d20bc8d4ba9 in /opt/conda/bin/python3.10)
frame #30: <unknown function> + 0x2114c9 (0x5d20bc9974c9 in /opt/conda/bin/python3.10)
frame #31: _PyObject_MakeTpCall + 0x26b (0x5d20bc8c3a6b in /opt/conda/bin/python3.10)
frame #32: _PyEval_EvalFrameDefault + 0x5709 (0x5d20bc8bfc39 in /opt/conda/bin/python3.10)
frame #33: <unknown function> + 0x1504f2 (0x5d20bc8d64f2 in /opt/conda/bin/python3.10)
frame #34: PyObject_Call + 0xbc (0x5d20bc8d6e8c in /opt/conda/bin/python3.10)
frame #35: _PyEval_EvalFrameDefault + 0x2d80 (0x5d20bc8bd2b0 in /opt/conda/bin/python3.10)
frame #36: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #37: _PyObject_FastCallDictTstate + 0x187 (0x5d20bc8c3007 in /opt/conda/bin/python3.10)
frame #38: _PyObject_Call_Prepend + 0x69 (0x5d20bc8d4ba9 in /opt/conda/bin/python3.10)
frame #39: <unknown function> + 0x2114c9 (0x5d20bc9974c9 in /opt/conda/bin/python3.10)
frame #40: _PyObject_MakeTpCall + 0x26b (0x5d20bc8c3a6b in /opt/conda/bin/python3.10)
frame #41: _PyEval_EvalFrameDefault + 0x5709 (0x5d20bc8bfc39 in /opt/conda/bin/python3.10)
frame #42: <unknown function> + 0x1504f2 (0x5d20bc8d64f2 in /opt/conda/bin/python3.10)
frame #43: PyObject_Call + 0xbc (0x5d20bc8d6e8c in /opt/conda/bin/python3.10)
frame #44: _PyEval_EvalFrameDefault + 0x2d80 (0x5d20bc8bd2b0 in /opt/conda/bin/python3.10)
frame #45: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #46: _PyObject_FastCallDictTstate + 0x187 (0x5d20bc8c3007 in /opt/conda/bin/python3.10)
frame #47: _PyObject_Call_Prepend + 0x69 (0x5d20bc8d4ba9 in /opt/conda/bin/python3.10)
frame #48: <unknown function> + 0x2114c9 (0x5d20bc9974c9 in /opt/conda/bin/python3.10)
frame #49: _PyObject_MakeTpCall + 0x26b (0x5d20bc8c3a6b in /opt/conda/bin/python3.10)
frame #50: _PyEval_EvalFrameDefault + 0x5709 (0x5d20bc8bfc39 in /opt/conda/bin/python3.10)
frame #51: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #52: _PyEval_EvalFrameDefault + 0x13ca (0x5d20bc8bb8fa in /opt/conda/bin/python3.10)
frame #53: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #54: PyObject_Call + 0xbc (0x5d20bc8d6e8c in /opt/conda/bin/python3.10)
frame #55: _PyEval_EvalFrameDefault + 0x2d80 (0x5d20bc8bd2b0 in /opt/conda/bin/python3.10)
frame #56: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #57: _PyEval_EvalFrameDefault + 0x13ca (0x5d20bc8bb8fa in /opt/conda/bin/python3.10)
frame #58: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #59: PyObject_Call + 0xbc (0x5d20bc8d6e8c in /opt/conda/bin/python3.10)
frame #60: _PyEval_EvalFrameDefault + 0x2d80 (0x5d20bc8bd2b0 in /opt/conda/bin/python3.10)
frame #61: _PyFunction_Vectorcall + 0x6c (0x5d20bc8ca99c in /opt/conda/bin/python3.10)
frame #62: PyObject_Call + 0xbc (0x5d20bc8d6e8c in /opt/conda/bin/python3.10)
frame #63: _PyEval_EvalFrameDefault + 0x2d80 (0x5d20bc8bd2b0 in /opt/conda/bin/python3.10)
