In [None]:
# Quote version specifiers: the line is handed to a shell, where an unquoted `>`
# is parsed as output redirection (installing an unpinned package and writing
# pip's output to a file named "=1.1.0").
# %pip (not !pip) installs into the environment of the running kernel.
%pip install "causal-conv1d>=1.1.0"
%pip install triton
%pip install "mamba-ssm==2.2.4"

Collecting mamba-ssm
  Downloading mamba_ssm-2.2.4.tar.gz (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.8/91.8 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: mamba-ssm
  Building wheel for mamba-ssm (pyproject.toml) ... [?25l[?25hdone
  Created wheel for mamba-ssm: filename=mamba_ssm-2.2.4-cp311-cp311-linux_x86_64.whl size=323672993 sha256=8a0be01153fa30727a9e69024fbe061eb92c7ba4416d2049c5fc3107ed91d852
  Stored in directory: /root/.cache/pip/wheels/2a/5e/64/cfcb5dfe4f854944456e031c34953dc872af1ad7c206145d4a
Successfully built mamba-ssm
Installing collected packages: mamba-ssm
Successfully installed mamba-ssm-2.2.4


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_classifier import MambaClassifier
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
import language_grammars as lg
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

from tqdm import tqdm

In [None]:
class NumberDataset(Dataset):
    """Character-level dataset of ``(string, label)`` pairs for binary classification.

    Each string is mapped to integer token ids, truncated to ``max_length`` and
    right-padded with the PAD id. Characters get ids ``1..len(characters)``; id 0
    is reserved for PAD so padding can never be confused with a real character.

    Args:
        data: indexable sequence of ``(string, label)`` pairs.
        characters: the characters that may appear in the strings (duplicates
            are ignored, first occurrence wins).
        max_length: fixed length of every returned ``input_ids`` tensor.

    Attributes:
        char_to_idx: mapping from character (plus the ``'PAD'`` key) to token id.
        VOCAB_SIZE: number of distinct token ids, i.e. ``len(char_to_idx)``.
    """

    PAD_TOKEN = 'PAD'
    PAD_IDX = 0

    def __init__(self, data, characters, max_length):
        self.data = data
        self.max_length = max_length
        # dict.fromkeys de-duplicates while keeping order, so ids stay contiguous.
        unique_chars = dict.fromkeys(characters)
        # Ids start at 1 because 0 is reserved for padding.
        self.char_to_idx = {char: idx + 1 for idx, char in enumerate(unique_chars)}
        # PAD used to be assigned len(char_to_idx), which equals the last
        # character's id (e.g. '1' for alphabet "01") -- a collision.
        self.char_to_idx[self.PAD_TOKEN] = self.PAD_IDX
        # Ids span 0..VOCAB_SIZE-1, so this is the embedding size a model needs.
        self.VOCAB_SIZE = len(self.char_to_idx)

    def __len__(self):
        """Number of ``(string, label)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"input_ids": LongTensor[max_length], "labels": float scalar tensor}``.

        Raises:
            KeyError: if the string contains a character not in ``characters``.
        """
        string, label = self.data[idx]
        # Map every character first so unknown characters are reported even
        # past the truncation point.
        tokenized = [self.char_to_idx[char] for char in string]
        tokenized = tokenized[:self.max_length]
        tokenized += [self.char_to_idx[self.PAD_TOKEN]] * (self.max_length - len(tokenized))
        return {
            "input_ids": torch.tensor(tokenized, dtype=torch.long),
            "labels": torch.tensor(label, dtype=torch.float)
        }

In [None]:
# Seed torch's global RNG before the model is built, so weight initialisation
# and DataLoader shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

EPOCHS = 100
D_MODEL = 30
NUM_CLASSES = 1  # single logit -> binary classification with BCEWithLogitsLoss
NUM_LAYERS = 1

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Token ids come from NumberDataset over the alphabet "01": one id per
# character plus one PAD id, i.e. NumberDataset.VOCAB_SIZE == len("01") + 1.
VOCAB_SIZE = len("01") + 1
LEARNING_RATE = 1e-4

model = MambaClassifier(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    num_classes=NUM_CLASSES,
    num_layers=NUM_LAYERS,
).to(device)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
# The model emits raw logits; BCEWithLogitsLoss applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()

In [None]:
tomita3_examples = lg.make_train_set_for_target(
    lg.tomita_3,
    alphabet="01",
    max_train_samples_per_length=1000,
)
# Keep (string, label) pairs in insertion order, skipping the first one.
# NOTE(review): the skipped entry is presumably the empty string -- confirm.
data = list(tomita3_examples.items())[1:]

made train set of size: 3377 , of which positive examples: 1576


In [None]:
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Pad every sequence to the longest string in the data rather than a fixed 30:
# truncating a string can change whether it belongs to the target language,
# which would silently corrupt its label.
MAX_LENGTH = max(len(string) for string, _ in data)
BATCH_SIZE = 1

train_dataset = NumberDataset(train_data, "01", MAX_LENGTH)
test_dataset = NumberDataset(test_data, "01", MAX_LENGTH)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def evaluate(model, test_loader, device, threshold=0.5):
    """Return the binary-classification accuracy of ``model`` over ``test_loader``.

    Each batch must be a dict with ``"input_ids"`` and ``"labels"`` entries, as
    produced by NumberDataset. The model is expected to return one logit per
    example, shaped like ``labels.unsqueeze(1)``. An example is predicted
    positive when ``sigmoid(logit) > threshold``.

    The model's previous train/eval mode is restored afterwards, so calling this
    in the middle of training leaves no lasting side effect.

    Args:
        model: module mapping ``input_ids`` to logits.
        test_loader: iterable of batch dicts (any split, not only the test set).
        device: device that inputs and labels are moved to.
        threshold: probability cut-off for the positive class.

    Returns:
        Fraction of correctly classified examples, or ``nan`` if the loader
        yields no examples.
    """
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for batch in test_loader:
                inputs = batch['input_ids'].to(device)
                labels = batch['labels'].to(device).unsqueeze(1)

                logits = model(inputs)
                predictions = (torch.sigmoid(logits) > threshold).float()
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
    finally:
        # Undo model.eval() so the caller's mode is left untouched.
        model.train(was_training)

    if total == 0:
        return float('nan')  # accuracy is undefined for an empty loader
    return correct / total

In [None]:
# Per-epoch training curve: (epoch, mean train loss, held-out accuracy).
history = []
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    # leave=False clears each finished bar instead of stacking EPOCHS of them.
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False):
        inputs = batch['input_ids'].to(device)
        # BCEWithLogitsLoss needs targets shaped like the logits; the model's
        # output is presumably (batch, 1) -- hence the extra axis.
        labels = batch['labels'].to(device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    # NOTE: test_loader doubles as the validation set here, so per-epoch
    # accuracies are not an unbiased estimate if used for model selection.
    test_acc = evaluate(model, test_loader, device)
    history.append((epoch + 1, avg_loss, test_acc))
    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f} Accuracy: {test_acc:.4f}")

Epoch 1/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.99it/s]


Epoch 1 - Loss: 0.6797 Accuracy: 0.7352


Epoch 2/100: 100%|██████████| 2700/2700 [00:09<00:00, 280.88it/s]


Epoch 2 - Loss: 0.3775 Accuracy: 0.8580


Epoch 3/100: 100%|██████████| 2700/2700 [00:09<00:00, 286.85it/s]


Epoch 3 - Loss: 0.3035 Accuracy: 0.8772


Epoch 4/100: 100%|██████████| 2700/2700 [00:08<00:00, 300.37it/s]


Epoch 4 - Loss: 0.2849 Accuracy: 0.8861


Epoch 5/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.63it/s]


Epoch 5 - Loss: 0.2672 Accuracy: 0.8743


Epoch 6/100: 100%|██████████| 2700/2700 [00:09<00:00, 283.12it/s]


Epoch 6 - Loss: 0.2522 Accuracy: 0.9038


Epoch 7/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.42it/s]


Epoch 7 - Loss: 0.2445 Accuracy: 0.9068


Epoch 8/100: 100%|██████████| 2700/2700 [00:09<00:00, 280.87it/s]


Epoch 8 - Loss: 0.2373 Accuracy: 0.9068


Epoch 9/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.47it/s]


Epoch 9 - Loss: 0.2342 Accuracy: 0.8861


Epoch 10/100: 100%|██████████| 2700/2700 [00:08<00:00, 300.08it/s]


Epoch 10 - Loss: 0.2337 Accuracy: 0.8935


Epoch 11/100: 100%|██████████| 2700/2700 [00:09<00:00, 292.78it/s]


Epoch 11 - Loss: 0.2296 Accuracy: 0.9142


Epoch 12/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.31it/s]


Epoch 12 - Loss: 0.2281 Accuracy: 0.9112


Epoch 13/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.36it/s]


Epoch 13 - Loss: 0.2202 Accuracy: 0.9068


Epoch 14/100: 100%|██████████| 2700/2700 [00:09<00:00, 283.75it/s]


Epoch 14 - Loss: 0.2204 Accuracy: 0.8994


Epoch 15/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.85it/s]


Epoch 15 - Loss: 0.2161 Accuracy: 0.9186


Epoch 16/100: 100%|██████████| 2700/2700 [00:09<00:00, 290.21it/s]


Epoch 16 - Loss: 0.2141 Accuracy: 0.8979


Epoch 17/100: 100%|██████████| 2700/2700 [00:09<00:00, 299.19it/s]


Epoch 17 - Loss: 0.2116 Accuracy: 0.9260


Epoch 18/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.15it/s]


Epoch 18 - Loss: 0.2089 Accuracy: 0.9098


Epoch 19/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.35it/s]


Epoch 19 - Loss: 0.2080 Accuracy: 0.9290


Epoch 20/100: 100%|██████████| 2700/2700 [00:09<00:00, 282.68it/s]


Epoch 20 - Loss: 0.2064 Accuracy: 0.9290


Epoch 21/100: 100%|██████████| 2700/2700 [00:09<00:00, 286.02it/s]


Epoch 21 - Loss: 0.2039 Accuracy: 0.9186


Epoch 22/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.16it/s]


Epoch 22 - Loss: 0.2023 Accuracy: 0.9068


Epoch 23/100: 100%|██████████| 2700/2700 [00:09<00:00, 298.99it/s]


Epoch 23 - Loss: 0.1992 Accuracy: 0.9127


Epoch 24/100: 100%|██████████| 2700/2700 [00:09<00:00, 291.26it/s]


Epoch 24 - Loss: 0.1974 Accuracy: 0.9246


Epoch 25/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.91it/s]


Epoch 25 - Loss: 0.1978 Accuracy: 0.9290


Epoch 26/100: 100%|██████████| 2700/2700 [00:09<00:00, 282.02it/s]


Epoch 26 - Loss: 0.1963 Accuracy: 0.9334


Epoch 27/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.30it/s]


Epoch 27 - Loss: 0.1950 Accuracy: 0.9334


Epoch 28/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.98it/s]


Epoch 28 - Loss: 0.1939 Accuracy: 0.9260


Epoch 29/100: 100%|██████████| 2700/2700 [00:09<00:00, 293.55it/s]


Epoch 29 - Loss: 0.1941 Accuracy: 0.9231


Epoch 30/100: 100%|██████████| 2700/2700 [00:08<00:00, 301.43it/s]


Epoch 30 - Loss: 0.1910 Accuracy: 0.9320


Epoch 31/100: 100%|██████████| 2700/2700 [00:09<00:00, 286.22it/s]


Epoch 31 - Loss: 0.1914 Accuracy: 0.9334


Epoch 32/100: 100%|██████████| 2700/2700 [00:09<00:00, 282.67it/s]


Epoch 32 - Loss: 0.1885 Accuracy: 0.9260


Epoch 33/100: 100%|██████████| 2700/2700 [00:09<00:00, 279.89it/s]


Epoch 33 - Loss: 0.1915 Accuracy: 0.9290


Epoch 34/100: 100%|██████████| 2700/2700 [00:09<00:00, 282.31it/s]


Epoch 34 - Loss: 0.1883 Accuracy: 0.9305


Epoch 35/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.30it/s]


Epoch 35 - Loss: 0.1885 Accuracy: 0.9438


Epoch 36/100: 100%|██████████| 2700/2700 [00:08<00:00, 300.04it/s]


Epoch 36 - Loss: 0.1891 Accuracy: 0.9157


Epoch 37/100: 100%|██████████| 2700/2700 [00:09<00:00, 293.23it/s]


Epoch 37 - Loss: 0.1886 Accuracy: 0.9364


Epoch 38/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.09it/s]


Epoch 38 - Loss: 0.1868 Accuracy: 0.9201


Epoch 39/100: 100%|██████████| 2700/2700 [00:09<00:00, 286.01it/s]


Epoch 39 - Loss: 0.1863 Accuracy: 0.9320


Epoch 40/100: 100%|██████████| 2700/2700 [00:09<00:00, 283.72it/s]


Epoch 40 - Loss: 0.1873 Accuracy: 0.9275


Epoch 41/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.12it/s]


Epoch 41 - Loss: 0.1839 Accuracy: 0.9467


Epoch 42/100: 100%|██████████| 2700/2700 [00:09<00:00, 293.39it/s]


Epoch 42 - Loss: 0.1843 Accuracy: 0.9393


Epoch 43/100: 100%|██████████| 2700/2700 [00:09<00:00, 298.49it/s]


Epoch 43 - Loss: 0.1830 Accuracy: 0.9201


Epoch 44/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.37it/s]


Epoch 44 - Loss: 0.1822 Accuracy: 0.9260


Epoch 45/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.87it/s]


Epoch 45 - Loss: 0.1836 Accuracy: 0.9423


Epoch 46/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.48it/s]


Epoch 46 - Loss: 0.1815 Accuracy: 0.9320


Epoch 47/100: 100%|██████████| 2700/2700 [00:09<00:00, 283.67it/s]


Epoch 47 - Loss: 0.1811 Accuracy: 0.9393


Epoch 48/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.68it/s]


Epoch 48 - Loss: 0.1792 Accuracy: 0.9349


Epoch 49/100: 100%|██████████| 2700/2700 [00:08<00:00, 300.63it/s]


Epoch 49 - Loss: 0.1768 Accuracy: 0.9453


Epoch 50/100: 100%|██████████| 2700/2700 [00:09<00:00, 290.02it/s]


Epoch 50 - Loss: 0.1771 Accuracy: 0.9393


Epoch 51/100: 100%|██████████| 2700/2700 [00:09<00:00, 287.34it/s]


Epoch 51 - Loss: 0.1766 Accuracy: 0.9453


Epoch 52/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.19it/s]


Epoch 52 - Loss: 0.1759 Accuracy: 0.9305


Epoch 53/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.89it/s]


Epoch 53 - Loss: 0.1728 Accuracy: 0.9379


Epoch 54/100: 100%|██████████| 2700/2700 [00:09<00:00, 286.36it/s]


Epoch 54 - Loss: 0.1696 Accuracy: 0.9320


Epoch 55/100: 100%|██████████| 2700/2700 [00:09<00:00, 295.09it/s]


Epoch 55 - Loss: 0.1703 Accuracy: 0.9320


Epoch 56/100: 100%|██████████| 2700/2700 [00:09<00:00, 295.46it/s]


Epoch 56 - Loss: 0.1674 Accuracy: 0.9438


Epoch 57/100: 100%|██████████| 2700/2700 [00:09<00:00, 286.86it/s]


Epoch 57 - Loss: 0.1655 Accuracy: 0.9512


Epoch 58/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.30it/s]


Epoch 58 - Loss: 0.1626 Accuracy: 0.9512


Epoch 59/100: 100%|██████████| 2700/2700 [00:09<00:00, 283.79it/s]


Epoch 59 - Loss: 0.1609 Accuracy: 0.9349


Epoch 60/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.72it/s]


Epoch 60 - Loss: 0.1612 Accuracy: 0.9364


Epoch 61/100: 100%|██████████| 2700/2700 [00:09<00:00, 287.59it/s]


Epoch 61 - Loss: 0.1553 Accuracy: 0.9379


Epoch 62/100: 100%|██████████| 2700/2700 [00:08<00:00, 300.09it/s]


Epoch 62 - Loss: 0.1552 Accuracy: 0.9438


Epoch 63/100: 100%|██████████| 2700/2700 [00:09<00:00, 287.35it/s]


Epoch 63 - Loss: 0.1519 Accuracy: 0.9571


Epoch 64/100: 100%|██████████| 2700/2700 [00:09<00:00, 281.00it/s]


Epoch 64 - Loss: 0.1508 Accuracy: 0.9482


Epoch 65/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.20it/s]


Epoch 65 - Loss: 0.1490 Accuracy: 0.9408


Epoch 66/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.24it/s]


Epoch 66 - Loss: 0.1488 Accuracy: 0.9541


Epoch 67/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.54it/s]


Epoch 67 - Loss: 0.1456 Accuracy: 0.9556


Epoch 68/100: 100%|██████████| 2700/2700 [00:09<00:00, 293.31it/s]


Epoch 68 - Loss: 0.1479 Accuracy: 0.9393


Epoch 69/100: 100%|██████████| 2700/2700 [00:09<00:00, 296.36it/s]


Epoch 69 - Loss: 0.1437 Accuracy: 0.9556


Epoch 70/100: 100%|██████████| 2700/2700 [00:09<00:00, 286.87it/s]


Epoch 70 - Loss: 0.1440 Accuracy: 0.9408


Epoch 71/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.51it/s]


Epoch 71 - Loss: 0.1431 Accuracy: 0.9467


Epoch 72/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.21it/s]


Epoch 72 - Loss: 0.1448 Accuracy: 0.9571


Epoch 73/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.39it/s]


Epoch 73 - Loss: 0.1404 Accuracy: 0.9334


Epoch 74/100: 100%|██████████| 2700/2700 [00:09<00:00, 287.77it/s]


Epoch 74 - Loss: 0.1429 Accuracy: 0.9349


Epoch 75/100: 100%|██████████| 2700/2700 [00:09<00:00, 299.75it/s]


Epoch 75 - Loss: 0.1420 Accuracy: 0.9556


Epoch 76/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.55it/s]


Epoch 76 - Loss: 0.1408 Accuracy: 0.9497


Epoch 77/100: 100%|██████████| 2700/2700 [00:09<00:00, 283.91it/s]


Epoch 77 - Loss: 0.1376 Accuracy: 0.9497


Epoch 78/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.56it/s]


Epoch 78 - Loss: 0.1385 Accuracy: 0.9512


Epoch 79/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.59it/s]


Epoch 79 - Loss: 0.1347 Accuracy: 0.9601


Epoch 80/100: 100%|██████████| 2700/2700 [00:09<00:00, 283.71it/s]


Epoch 80 - Loss: 0.1359 Accuracy: 0.9660


Epoch 81/100: 100%|██████████| 2700/2700 [00:08<00:00, 301.00it/s]


Epoch 81 - Loss: 0.1344 Accuracy: 0.9615


Epoch 82/100: 100%|██████████| 2700/2700 [00:09<00:00, 291.42it/s]


Epoch 82 - Loss: 0.1360 Accuracy: 0.9571


Epoch 83/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.22it/s]


Epoch 83 - Loss: 0.1357 Accuracy: 0.9615


Epoch 84/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.93it/s]


Epoch 84 - Loss: 0.1348 Accuracy: 0.9615


Epoch 85/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.68it/s]


Epoch 85 - Loss: 0.1337 Accuracy: 0.9482


Epoch 86/100: 100%|██████████| 2700/2700 [00:09<00:00, 283.02it/s]


Epoch 86 - Loss: 0.1291 Accuracy: 0.9586


Epoch 87/100: 100%|██████████| 2700/2700 [00:09<00:00, 292.98it/s]


Epoch 87 - Loss: 0.1320 Accuracy: 0.9689


Epoch 88/100: 100%|██████████| 2700/2700 [00:09<00:00, 299.30it/s]


Epoch 88 - Loss: 0.1281 Accuracy: 0.9630


Epoch 89/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.78it/s]


Epoch 89 - Loss: 0.1269 Accuracy: 0.9675


Epoch 90/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.13it/s]


Epoch 90 - Loss: 0.1257 Accuracy: 0.9275


Epoch 91/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.17it/s]


Epoch 91 - Loss: 0.1274 Accuracy: 0.9660


Epoch 92/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.07it/s]


Epoch 92 - Loss: 0.1266 Accuracy: 0.9467


Epoch 93/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.63it/s]


Epoch 93 - Loss: 0.1277 Accuracy: 0.9704


Epoch 94/100: 100%|██████████| 2700/2700 [00:09<00:00, 295.37it/s]


Epoch 94 - Loss: 0.1254 Accuracy: 0.9689


Epoch 95/100: 100%|██████████| 2700/2700 [00:09<00:00, 286.98it/s]


Epoch 95 - Loss: 0.1246 Accuracy: 0.9675


Epoch 96/100: 100%|██████████| 2700/2700 [00:09<00:00, 285.69it/s]


Epoch 96 - Loss: 0.1214 Accuracy: 0.9305


Epoch 97/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.24it/s]


Epoch 97 - Loss: 0.1232 Accuracy: 0.9586


Epoch 98/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.23it/s]


Epoch 98 - Loss: 0.1236 Accuracy: 0.9512


Epoch 99/100: 100%|██████████| 2700/2700 [00:09<00:00, 284.02it/s]


Epoch 99 - Loss: 0.1208 Accuracy: 0.9556


Epoch 100/100: 100%|██████████| 2700/2700 [00:09<00:00, 291.77it/s]


Epoch 100 - Loss: 0.1166 Accuracy: 0.9660


In [None]:
# Final accuracy on both splits; a large train/test gap would point to overfitting.
train_acc = evaluate(model, train_loader, device)
test_acc = evaluate(model, test_loader, device)
print(f"Train accuracy: {train_acc:.4f} | Test accuracy: {test_acc:.4f}")