# Requirements

In [1]:
# One-time environment setup: uncomment these lines, run the cell once, then comment them out again.
# NOTE(review): prefer `%pip` over `!pip` so packages are installed into the kernel's own environment,
# and `import sys; sys.version` over `!python --version` (the `python` on PATH may not be the kernel's).
# NOTE(review): `pip uninstall` asks for confirmation; a notebook cell cannot answer it, so add `-y`.
#!python --version
#!pip install --upgrade pip
#!pip uninstall keras tensorflow
#!pip install -r ../requirements.txt

# Imports

In [2]:
from __future__ import print_function
import json

import torch

from data_loader import load_cifar10, get_class_names
from model import ImageMamba, ModelArgs
from training_utils import train_model, continue_training

# Initialization

In [None]:
# Data-loading knobs, named so they are easy to find and tweak.
BATCH_SIZE = 64
SEED = 42  # forwarded to load_cifar10 for its own randomness

# Loaders for training/evaluation, plus the raw train/test tensors and labels.
(
    train_loader,
    test_loader,
    X_train,
    X_test,
    Y_train,
    Y_test,
) = load_cifar10(batch_size=BATCH_SIZE, seed=SEED)
class_names = get_class_names()

## CUDA

In [None]:
# Report whether PyTorch can see a CUDA device in this kernel.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available? {cuda_available}")

In [None]:
!nvcc --version

In [6]:
# Optional: reinstall PyTorch from the CUDA 12.4 (cu124) wheel index, e.g. if the CUDA check above
# prints False on a machine with an NVIDIA GPU. Uncomment to run, then restart the kernel.
# NOTE(review): prefer `%pip` over `!pip` so the install targets the kernel's environment.
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Train mamba

In [None]:
# Seed torch's global RNG so the model's weight initialisation is repeatable across runs
# (same value as the seed passed to load_cifar10).
torch.manual_seed(42)

model_args = ModelArgs(
    d_model=128,   # sized for the overfitting experiment
    n_layer=8,     # sized for the overfitting experiment
    vocab_size=0,  # NOTE(review): presumably unused for image input — confirm in ModelArgs
)
model = ImageMamba(
    args=model_args,
    num_classes=10,  # CIFAR-10 has 10 classes
    image_size=32,   # CIFAR-10 images are 32x32 pixels
    patch_size=4,    # 4x4-pixel patches -> an 8x8 grid (64 patches per image)
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"Model device: {next(model.parameters()).device}")

In [None]:
# Full training run of the Mamba model; `metrics` holds whatever train_model returns.
training_kwargs = dict(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name='mamba',
    num_epochs=2000,
    device=device,
)
metrics = train_model(**training_kwargs)

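# Resuming after a crash
If training crashes (a real risk over this many epochs), restart the kernel and re-run the Imports, Initialization and model-construction cells, skipping the training cell. Then run the cell below to continue from the checkpoints saved in `mamba_checkpoints/`.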

In [None]:
# Resume an interrupted run. Needs `model`, `device` and the data loaders from the cells above.
# One name for the checkpoint directory, so the metrics path and the directory handed to
# continue_training cannot drift apart.
CHECKPOINT_DIR = 'mamba_checkpoints'

# Progress recorded by the interrupted run (kept separate from `metrics`, which is rebound below).
with open(f'{CHECKPOINT_DIR}/training_metrics.json', 'r', encoding='utf-8') as f:
    saved_metrics = json.load(f)
print(f"Last completed epoch: {saved_metrics['current_epoch']}")

metrics = continue_training(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name='mamba',
    checkpoint_dir=CHECKPOINT_DIR,
    target_epochs=2000,
    device=device,
)