In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Step 1: Data Preparation
# Identifiers of the dataset and tokenizer, named once so they are easy to spot and change.
DATASET_NAME = "sst2"
TOKENIZER_NAME = "distilbert-base-uncased"

print("Loading SST-2 dataset...")
dataset = load_dataset(DATASET_NAME)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

def tokenize_function(examples):
    """Tokenize the ``"sentence"`` field of a batch of examples.

    Every sequence is truncated and padded to exactly 128 tokens
    (``padding="max_length"``), so all encodings have the same length.

    Args:
        examples: Batch mapping that contains a ``"sentence"`` entry.

    Returns:
        Whatever the module-level ``tokenizer`` returns for that batch.
    """
    sentences = examples["sentence"]
    return tokenizer(
        sentences,
        max_length=128,
        truncation=True,
        padding="max_length",
    )

# Tokenize every split, processing examples in batches.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Using validation set as test set
# (presumably because the public test split has no usable labels -- TODO confirm).
train_dataset, test_dataset = (
    tokenized_datasets[split_name].with_format("torch")
    for split_name in ("train", "validation")
)

BATCH_SIZE = 64
# Only the training loader is shuffled; evaluation order is left untouched.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)

Loading SST-2 dataset...


Downloading readme: 100%|██████████| 5.27k/5.27k [00:00<00:00, 9.55kB/s]
Downloading data: 100%|██████████| 3.11M/3.11M [00:01<00:00, 1.74MB/s]
Downloading data: 100%|██████████| 72.8k/72.8k [00:00<00:00, 119kB/s]
Downloading data: 100%|██████████| 148k/148k [00:00<00:00, 236kB/s]
Generating train split: 100%|██████████| 67349/67349 [00:00<00:00, 2820252.99 examples/s]
Generating validation split: 100%|██████████| 872/872 [00:00<00:00, 436552.05 examples/s]
Generating test split: 100%|██████████| 1821/1821 [00:00<00:00, 825889.66 examples/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Map: 100%|██████████| 67349/67349 [00:02<00:00, 27039.60 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 19572.28 examples/s]
Map: 100%|██████████| 1821/1821 [00:00<00:00, 26460.6

In [3]:
# Step 2: Define Models
class TransformerModel(nn.Module):
    """Transformer-encoder sequence classifier with mean pooling.

    Token ids are embedded, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks, averaged over the sequence axis and
    projected to ``num_classes`` logits.

    NOTE(review): no positional encoding is added to the embeddings, so the
    encoder receives no explicit token-order information.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding / model width (``d_model``).
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads),
            num_layers=num_layers
        )
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x, attention_mask=None):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) tensor where non-zero
                marks a real token and zero marks padding (presumably the
                tokenizer's ``attention_mask`` output -- confirm at the call
                site). When given, padded positions are excluded both as
                attention keys and from the mean pooling. When omitted, the
                behaviour is the original one: every position is attended to
                and pooled.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.embedding(x)
        x = x.permute(1, 0, 2)  # Transformer expects (seq_len, batch, features)

        if attention_mask is None:
            x = self.transformer(x)
            return self.fc(x.mean(dim=0))  # Global average pooling

        # True where the position is padding, as src_key_padding_mask expects.
        pad = attention_mask.to(x.device) == 0          # (batch, seq_len)
        x = self.transformer(x, src_key_padding_mask=pad)

        keep = (~pad).t().unsqueeze(-1)                 # (seq_len, batch, 1)
        # masked_fill (rather than multiplying by 0) also neutralises any NaN
        # that a fully padded row would produce inside attention.
        x = x.masked_fill(~keep, 0.0)
        # clamp avoids 0/0 for a row that is entirely padding.
        token_counts = keep.sum(dim=0).clamp(min=1)     # (batch, 1)
        pooled = x.sum(dim=0) / token_counts
        return self.fc(pooled)

class SimpleSSM(nn.Module):
    """Minimal non-linear state-space layer applied along the sequence axis.

    For every time step ``t`` the layer computes::

        x_t = tanh(A @ x_{t-1} + B @ u_t)
        y_t = C @ x_t + D * u_t

    starting from a zero state ``x_0``.

    Args:
        d_model: Feature width of the input and of the output.
        d_state: Width of the hidden state.
    """

    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.A = nn.Parameter(torch.randn(d_state, d_state))
        self.B = nn.Parameter(torch.randn(d_state, d_model))
        self.C = nn.Parameter(torch.randn(d_model, d_state))
        self.D = nn.Parameter(torch.randn(d_model))

    def forward(self, u):
        """Run the recurrence over a batch of sequences.

        Args:
            u: Input of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_model). A zero-length sequence
            yields an empty (batch, 0, d_model) tensor instead of raising.
        """
        # u: (batch, seq_len, d_model)
        batch_size, seq_len = u.size(0), u.size(1)

        # The input projection does not depend on the state, so it is computed
        # for all time steps in one matmul instead of once per loop iteration.
        drive = u @ self.B.T                             # (batch, seq_len, d_state)

        # Match the input dtype so the layer also works after .double()/.half().
        state = torch.zeros(batch_size, self.d_state, device=u.device, dtype=u.dtype)
        transition = self.A.T                            # row-vector form of A @ x

        states = []
        for t in range(seq_len):
            # Only the genuinely sequential part stays inside the Python loop.
            state = torch.tanh(state @ transition + drive[:, t, :])
            states.append(state)

        if states:
            hidden = torch.stack(states, dim=1)          # (batch, seq_len, d_state)
        else:
            hidden = u.new_zeros(batch_size, 0, self.d_state)

        # Read-out and skip connection, again applied to all steps at once.
        return hidden @ self.C.T + self.D * u

class SimpleMamba(nn.Module):
    """Simplified Mamba-style block: an SSM branch gating a causal conv branch.

    The input is widened to ``expand * d_model`` channels. The widened signal
    goes through a ``SimpleSSM`` (then SiLU) and, in parallel, through a
    depthwise 1-D convolution; the two are multiplied element-wise and
    projected back to ``d_model``.

    Args:
        d_model: Input / output feature width.
        d_state: Hidden-state width handed to ``SimpleSSM``.
        d_conv: Kernel size of the depthwise convolution.
        expand: Multiplier giving the inner width ``expand * d_model``.
    """

    def __init__(self, d_model, d_state, d_conv, expand):
        super().__init__()
        inner_width = expand * d_model
        self.d_inner = inner_width
        self.proj_in = nn.Linear(d_model, inner_width)
        self.proj_out = nn.Linear(inner_width, d_model)
        self.ssm = SimpleSSM(inner_width, d_state)
        # groups == channels makes the convolution depthwise; padding of
        # d_conv - 1 plus the truncation in forward() keeps it causal.
        self.conv = nn.Conv1d(
            inner_width,
            inner_width,
            d_conv,
            padding=d_conv - 1,
            groups=inner_width,
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq_len, d_model)."""
        hidden = self.proj_in(x)
        seq_len = hidden.size(1)

        ssm_out = self.ssm(hidden)

        # Conv1d wants (batch, channels, seq_len); keep only the first seq_len
        # outputs so step t never sees inputs from later steps.
        channels_first = hidden.transpose(1, 2)
        conv_out = self.conv(channels_first)[:, :, :seq_len].transpose(1, 2)

        gated = F.silu(ssm_out) * conv_out
        return self.proj_out(gated)

class MambaModel(nn.Module):
    """Sequence classifier built from an embedding, a ``SimpleMamba`` block
    and a linear head applied to the mean-pooled sequence.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding / model width.
        d_state: Hidden-state width of the SSM inside ``SimpleMamba``.
        d_conv: Kernel size of the convolution inside ``SimpleMamba``.
        expand: Inner-width multiplier of ``SimpleMamba``.
        num_classes: Number of output logits.
    """

    def __init__(self, vocab_size, d_model, d_state, d_conv, expand, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.mamba = SimpleMamba(d_model, d_state, d_conv, expand)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x, attention_mask=None):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) tensor where non-zero
                marks a real token and zero marks padding (presumably the
                tokenizer's ``attention_mask`` output -- confirm at the call
                site). When given, padded positions are left out of the mean
                pooling. When omitted, all positions are pooled, which is the
                original behaviour.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.embedding(x)
        x = self.mamba(x)

        if attention_mask is None:
            x = x.mean(dim=1)  # Global average pooling
        else:
            keep = (attention_mask.to(x.device) != 0).unsqueeze(-1)  # (batch, seq_len, 1)
            # clamp avoids 0/0 for a row that is entirely padding.
            token_counts = keep.sum(dim=1).clamp(min=1)              # (batch, 1)
            x = x.masked_fill(~keep, 0.0).sum(dim=1) / token_counts

        x = self.fc(x)
        return x

In [4]:
def train(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    NOTE: a function with this same name is defined again in a later cell of
    this notebook; when cells run in order, that later definition is the one
    in effect.

    Args:
        model: Module called as ``model(inputs)``.
        dataloader: Iterable of batches. A batch is either an
            ``(inputs, labels)`` pair or a mapping with ``'input_ids'`` and
            ``'label'`` entries.
        criterion: Loss callable ``criterion(outputs, labels)``. The
            per-sample averaging below assumes it returns the batch mean.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Mean loss per sample over the pass, or ``0.0`` if ``dataloader``
        yields no samples.
    """
    from collections.abc import Mapping

    model.train()
    loss_sum = 0.0
    samples_seen = 0
    for batch in dataloader:
        # Accept both dict-style batches and plain (inputs, labels) pairs.
        if isinstance(batch, Mapping):
            texts, labels = batch['input_ids'], batch['label']
        else:
            texts, labels = batch
        texts, labels = texts.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(texts)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a short final batch is not
        # over-represented in the reported average.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size

    return loss_sum / samples_seen if samples_seen else 0.0

def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    NOTE: a function with this same name is defined again in a later cell of
    this notebook; when cells run in order, that later definition is the one
    in effect.

    Args:
        model: Module called as ``model(inputs)``; returns per-class scores
            of shape (batch, num_classes).
        dataloader: Iterable of batches. A batch is either an
            ``(inputs, labels)`` pair or a mapping with ``'input_ids'`` and
            ``'label'`` entries.
        criterion: Loss callable ``criterion(outputs, labels)``. The
            per-sample averaging below assumes it returns the batch mean.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss_per_sample, accuracy)``; ``(0.0, 0.0)`` if
        ``dataloader`` yields no samples.
    """
    from collections.abc import Mapping

    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            # Accept both dict-style batches and plain (inputs, labels) pairs.
            if isinstance(batch, Mapping):
                texts, labels = batch['input_ids'], batch['label']
            else:
                texts, labels = batch
            texts, labels = texts.to(device), labels.to(device)

            outputs = model(texts)
            batch_size = labels.size(0)
            # Weight by batch size so the result is a true per-sample mean.
            loss_sum += criterion(outputs, labels).item() * batch_size
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_size

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total

In [5]:
# Step 3: Training and Evaluation Functions
def train(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module called as ``model(input_ids)``.
        dataloader: Iterable of batches; each batch is indexed with
            ``'input_ids'`` and ``'label'``.
        criterion: Loss callable ``criterion(outputs, labels)``. The
            per-sample averaging below assumes it returns the batch mean.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Mean loss per sample over the pass, or ``0.0`` if ``dataloader``
        yields no samples.
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size: averaging batch means would
        # over-weight a shorter final batch.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size

    # Guard against an empty dataloader instead of dividing by zero.
    return loss_sum / samples_seen if samples_seen else 0.0

def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Module called as ``model(input_ids)``; returns per-class
            scores of shape (batch, num_classes).
        dataloader: Iterable of batches; each batch is indexed with
            ``'input_ids'`` and ``'label'``.
        criterion: Loss callable ``criterion(outputs, labels)``. The
            per-sample averaging below assumes it returns the batch mean.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss_per_sample, accuracy)``; ``(0.0, 0.0)`` if
        ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids)
            batch_size = labels.size(0)
            # Weight by batch size so a shorter final batch does not skew
            # the reported loss.
            loss_sum += criterion(outputs, labels).item() * batch_size
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_size

    # Guard against an empty dataloader instead of dividing by zero.
    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total

In [6]:
# Step 4: Main Comparison
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
SEED = 42
VOCAB_SIZE = tokenizer.vocab_size
EMBED_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 2
NUM_CLASSES = 2  # Binary classification
NUM_EPOCHS = 5
D_STATE = 16
D_CONV = 4
EXPAND = 2

# Seed before any weights are created so that re-running the notebook
# reproduces the same initialisation.
torch.manual_seed(SEED)


def _timestamp():
    """Return a monotonic timestamp, after pending GPU work has finished.

    CUDA kernels are launched asynchronously, so the device is synchronised
    first; otherwise an interval could end before the work it measures.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter()


# Initialize models
transformer_model = TransformerModel(VOCAB_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS, NUM_CLASSES).to(device)
mamba_model = MambaModel(VOCAB_SIZE, EMBED_DIM, D_STATE, D_CONV, EXPAND, NUM_CLASSES).to(device)

# Training loop
criterion = nn.CrossEntropyLoss()
transformer_optimizer = optim.Adam(transformer_model.parameters())
mamba_optimizer = optim.Adam(mamba_model.parameters())

# Per-model metrics, filled in below and read by the final comparison cell.
results = {
    "transformer": {"train_time": 0, "inference_time": 0, "accuracy": 0},
    "mamba": {"train_time": 0, "inference_time": 0, "accuracy": 0}
}

for model_name, model, optimizer in [("transformer", transformer_model, transformer_optimizer),
                                     ("mamba", mamba_model, mamba_optimizer)]:
    print(f"\nTraining {model_name.capitalize()} model:")
    train_start = _timestamp()
    for epoch in range(NUM_EPOCHS):
        loss = train(model, train_dataloader, criterion, optimizer, device)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {loss:.4f}")
    train_time = _timestamp() - train_start
    results[model_name]["train_time"] = train_time

    # "Inference time" is one full evaluation pass over the test dataloader.
    inference_start = _timestamp()
    test_loss, accuracy = evaluate(model, test_dataloader, criterion, device)
    inference_time = _timestamp() - inference_start
    results[model_name]["inference_time"] = inference_time
    results[model_name]["accuracy"] = accuracy

    print(f"{model_name.capitalize()} Results:")
    print(f"  Training Time: {train_time:.2f} seconds")
    print(f"  Inference Time: {inference_time:.2f} seconds")
    print(f"  Test Accuracy: {accuracy:.4f}")

Using device: cuda





Training Transformer model:


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1/5, Loss: 0.5545
Epoch 2/5, Loss: 0.3434
Epoch 3/5, Loss: 0.2603
Epoch 4/5, Loss: 0.2138
Epoch 5/5, Loss: 0.1862
Transformer Results:
  Training Time: 41.68 seconds
  Inference Time: 0.08 seconds
  Test Accuracy: 0.7683

Training Mamba model:
Epoch 1/5, Loss: 0.5823
Epoch 2/5, Loss: 0.3648
Epoch 3/5, Loss: 0.2821
Epoch 4/5, Loss: 0.2380
Epoch 5/5, Loss: 0.2078
Mamba Results:
  Training Time: 556.68 seconds
  Inference Time: 0.35 seconds
  Test Accuracy: 0.7901


In [7]:
# Final Comparison
# Pull each model's metrics out once so the report lines stay short.
transformer_stats = results['transformer']
mamba_stats = results['mamba']

print("\nFinal Comparison:")
print("Transformer vs Mamba:")
print(f"  Training Time: {transformer_stats['train_time']:.2f}s vs {mamba_stats['train_time']:.2f}s")
print(f"  Inference Time: {transformer_stats['inference_time']:.2f}s vs {mamba_stats['inference_time']:.2f}s")
print(f"  Accuracy: {transformer_stats['accuracy']:.4f} vs {mamba_stats['accuracy']:.4f}")


Final Comparison:
Transformer vs Mamba:
  Training Time: 41.68s vs 556.68s
  Inference Time: 0.08s vs 0.35s
  Accuracy: 0.7683 vs 0.7901
