# Requirements

In [1]:
#!python -m pip uninstall torch torchvision torchaudio -y
#!python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [2]:
!python -m pip uninstall -y torch torchvision torchaudio
!python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Found existing installation: torch 2.6.0+cu118
Uninstalling torch-2.6.0+cu118:
  Successfully uninstalled torch-2.6.0+cu118
Found existing installation: torchvision 0.21.0+cu118
Uninstalling torchvision-0.21.0+cu118:
  Successfully uninstalled torchvision-0.21.0+cu118
Found existing installation: torchaudio 2.6.0+cu118
Uninstalling torchaudio-2.6.0+cu118:
  Successfully uninstalled torchaudio-2.6.0+cu118
Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch
  Using cached https://download.pytorch.org/whl/cu118/torch-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (27 kB)
Collecting torchvision
  Using cached https://download.pytorch.org/whl/cu118/torchvision-0.21.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (6.6 kB)
Using cached https://download.pytorch.org/whl/cu118/torch-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.

In [3]:
!python -m pip install mamba-ssm
#!python --version
#!pip install --upgrade pip
#!pip uninstall keras tensorflow
#!pip install -r ../requirements.txt'



# Imports

In [4]:
from __future__ import print_function
import torch
import json

from data_loader import load_cifar10, get_class_names
from training_utils import train_model, continue_training

## CUDA

In [5]:
print(f"Is CUDA available? {torch.cuda.is_available()}")
!nvcc --version

Is CUDA available? True
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0


In [6]:
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Initialization

In [7]:
# Set the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
train_loader, test_loader, X_train, X_test, Y_train, Y_test = load_cifar10(batch_size=64, seed=42)
class_names = get_class_names()

# Train mamba

In [9]:
import torch.optim as optim # Import optim
import torch.nn as nn



# Test training loop
def test_training(model, train_loader, test_loader, device, num_epochs=10):
    model = model.to(device)
    # Ensure all parameters require gradients - No longer needed
    # for param in model.parameters():
    #     param.requires_grad = True
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    # ... (rest of your code) # Keep the rest of your training loop

    print("Starting test training loop...")
    print("=" * 50)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(train_loader):
            # Debug info
            print(f"Input shape: {inputs.shape}, dtype: {inputs.dtype}")
            print(f"Labels shape: {labels.shape}, dtype: {labels.dtype}")

            # Move to GPU with explicit type casting
            try:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
            except RuntimeError as e:
                print(f"Error moving data to GPU: {e}")
                print(f"Memory stats before error:")
                print(torch.cuda.memory_summary())
                raise

            # Rest of the training loop...

            # Clear gradients
            optimizer.zero_grad()

            # Forward pass
            logits, probs = model(inputs)

            # Check for NaN in outputs
            if torch.isnan(logits).any() or torch.isnan(probs).any():
                print(f"\nNaN detected in model outputs at batch {i}!")
                return False

            # Compute loss
            loss = criterion(logits, labels)

            # Check for NaN in loss
            if torch.isnan(loss):
                print(f"\nNaN detected in loss at batch {i}!")
                return False

            # Backward pass
            loss.backward()

            # Check for NaN in gradients
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"\nNaN detected in gradients for {name}!")
                    return False

            # Update weights
            optimizer.step()

            # Compute batch statistics
            running_loss += loss.item()
            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

            # Print batch progress every 50 batches
            if (i + 1) % 50 == 0:
                batch_loss = running_loss / (i + 1)
                batch_acc = 100 * running_correct / total
                print(f"\rEpoch [{epoch+1}/{num_epochs}] "
                      f"Batch [{i+1}/{len(train_loader)}] "
                      f"Loss: {batch_loss:.4f} "
                      f"Acc: {batch_acc:.2f}%", end="")

        # Compute epoch statistics
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * running_correct / total

        # Evaluate on test set
        model.eval()
        test_correct = 0
        test_total = 0
        test_loss = 0.0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                logits, probs = model(inputs)
                loss = criterion(logits, labels)

                test_loss += loss.item()
                _, predicted = torch.max(logits.data, 1)
                test_total += labels.size(0)
                test_correct += (predicted == labels).sum().item()

        test_loss = test_loss / len(test_loader)
        test_acc = 100 * test_correct / test_total

        print(f"\nEpoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {epoch_loss:.4f} "
              f"Train Acc: {epoch_acc:.2f}% "
              f"Test Loss: {test_loss:.4f} "
              f"Test Acc: {test_acc:.2f}%")
        print("-" * 50)

    print("\nTest training completed successfully!")
    return True

In [14]:
from ssm import FastImageMamba, ModelArgs, FastMambaBlock
import torch.optim as optim # Import optim
import torch.nn as nn
model_args = ModelArgs(
    d_model=128,       # Match the channel size in your initial conv layer
    n_layer=4,         # Number of Mamba blocks
    vocab_size=10,     # Not used for images, but required
    d_state=16,        # State space dimension
    expand=2,          # Expansion factor
    dt_rank=16,        # Rank for delta computation
    d_conv=4,          # Convolution kernel size
    seq_len=256        # Sequence length (16x16 for CIFAR-10)
)

model = FastImageMamba(model_args, num_classes=10)
model = model.to(device)

# **Important:** Make sure parameters require gradients
for name, param in model.named_parameters():
    print(f"Parameter name: {name}, requires_grad: {param.requires_grad}") # Print to check
    param.requires_grad = True

Parameter name: conv1.weight, requires_grad: True
Parameter name: conv1.bias, requires_grad: True
Parameter name: bn1.weight, requires_grad: True
Parameter name: bn1.bias, requires_grad: True
Parameter name: patch_to_seq.0.weight, requires_grad: True
Parameter name: patch_to_seq.0.bias, requires_grad: True
Parameter name: patch_to_seq.1.weight, requires_grad: True
Parameter name: patch_to_seq.1.bias, requires_grad: True
Parameter name: layers.0.A_log, requires_grad: True
Parameter name: layers.0.D, requires_grad: True
Parameter name: layers.0.in_proj.weight, requires_grad: True
Parameter name: layers.0.in_proj.bias, requires_grad: True
Parameter name: layers.0.conv1d.weight, requires_grad: True
Parameter name: layers.0.conv1d.bias, requires_grad: True
Parameter name: layers.0.x_proj.weight, requires_grad: True
Parameter name: layers.0.dt_proj.weight, requires_grad: True
Parameter name: layers.0.dt_proj.bias, requires_grad: True
Parameter name: layers.0.out_proj.weight, requires_grad: T

In [12]:
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name()}")

PyTorch version: 2.6.0+cu118
CUDA available: True
CUDA device count: 1
Current CUDA device: 0
Device name: Tesla T4


In [13]:
from model import ImageMamba, ModelArgs, ProperImageMamba
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Run test training
success = test_training(model, train_loader, test_loader, device)
if success:
    print("Training is stable - no NaNs detected")
else:
    print("Training failed - NaNs detected")

Using device: cuda
Starting test training loop...
Input shape: torch.Size([64, 3, 32, 32]), dtype: torch.float32
Labels shape: torch.Size([64]), dtype: torch.int64
After initial conv: torch.Size([64, 128, 32, 32]) torch.float32 cuda:0 True
After patch_to_seq: torch.Size([64, 128, 16, 16]) torch.float32 cuda:0 True
Before mamba blocks: torch.Size([64, 256, 128]) torch.float32 cuda:0 True
Error in mamba block 0


RuntimeError: The size of tensor a (255) must match the size of tensor b (256) at non-singleton dimension 1

In [None]:
from model import ImageMamba, ModelArgs

# Load your model and data
model_args = ModelArgs(d_model=128, n_layer=8, vocab_size=0)
model = ImageMamba(args=model_args, num_classes=10)

model = model.to(device)

print(f"Model device: {next(model.parameters()).device}")

In [None]:
metrics = train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name='mamba',
    num_epochs=2000,
    device=device
)

# Continue training
If there was a crash. Which can be when dealing with so many epochs, one can continue from here:

In [None]:
checkpoint_dir = 'mamba_checkpoints'
with open(f'{checkpoint_dir}/training_metrics.json', 'r') as f:
    metrics = json.load(f)
print(f"Last completed epoch: {metrics['current_epoch']}")

metrics = continue_training(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name='mamba',
    checkpoint_dir=checkpoint_dir,
    target_epochs=2000,
    device=device
)