# Requirements

In [1]:
#!python --version
#!pip install --upgrade pip
#!pip uninstall keras tensorflow
#!pip install -r ../requirements.txt

# Imports

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
import numpy as np
import json
import os
from tqdm import tqdm
from model import ModelArgs, ImageMamba

from sklearn.metrics import accuracy_score
from model import ImageMamba, ModelArgs
from __future__ import print_function
from data_loader import load_cifar10, get_class_names # For consistent data loading

# Initialization

In [3]:
train_loader, test_loader, X_train, X_test, Y_train, Y_test = load_cifar10(batch_size=64, seed=42)
class_names = get_class_names()

Files already downloaded and verified
Files already downloaded and verified


## CUDA

In [4]:
print(f"Is CUDA available? {torch.cuda.is_available()}")

Is CUDA available? True


In [5]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Sep_12_02:55:00_Pacific_Daylight_Time_2024
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0


In [6]:
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Checkpoints & measures

In [7]:
from check_and_measure import evaluate_model, save_checkpoint, load_last_checkpoint

# Train mamba

In [8]:
def train_mamba(model, train_loader, test_loader, num_epochs=2000, device='cuda',
               checkpoint_dir='mamba_checkpoints', checkpoint_freq=100):
   """Train model with comprehensive metrics tracking."""
   
   # Check if directory exists and contains files
   if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
       raise RuntimeError(
           f"Directory {checkpoint_dir} already contains files. Using this directory would overwrite "
           "existing training data. To prevent data loss, please use an empty directory "
           "or use continue_training() to resume from the last checkpoint."
       )
       
   os.makedirs(checkpoint_dir, exist_ok=True)
   criterion = nn.CrossEntropyLoss()
   optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
   
   scheduler = OneCycleLR(
       optimizer,
       max_lr=1e-3,
       epochs=num_epochs,
       steps_per_epoch=len(train_loader),
       pct_start=0.3,
       anneal_strategy='cos'
   )
   
   metrics = {
       'train_losses': [], 'test_losses': [],
       'train_accuracies': [], 'test_accuracies': [],
       'train_confidences': [], 'test_confidences': [],
       'epoch_train_confidences': [], 'epoch_test_confidences': []
   }
   
   for epoch in range(num_epochs):
       model.train()
       running_loss = 0.0
       running_correct = 0
       total_samples = 0
       train_confidences = []
       
       for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
           inputs, labels = inputs.to(device), labels.to(device)
           optimizer.zero_grad()
           
           logits, probabilities = model(inputs)
           loss = criterion(logits, labels)
           loss.backward()
           optimizer.step()
           scheduler.step()
           
           _, predicted = torch.max(logits, 1)
           confidence, _ = torch.max(probabilities, 1)
           
           running_loss += loss.item()
           running_correct += (predicted == labels).sum().item()
           total_samples += labels.size(0)
           train_confidences.extend(confidence.detach().cpu().numpy())
       
       train_loss = running_loss / len(train_loader)
       train_accuracy = 100 * running_correct / total_samples
       train_avg_confidence = np.mean(train_confidences)
       
       test_loss, test_accuracy, test_avg_confidence, test_confidences = evaluate_model(
           model, test_loader, criterion, device)
       
       metrics['train_losses'].append(train_loss)
       metrics['test_losses'].append(test_loss)
       metrics['train_accuracies'].append(train_accuracy)
       metrics['test_accuracies'].append(test_accuracy)
       metrics['train_confidences'].append(train_avg_confidence)
       metrics['test_confidences'].append(test_avg_confidence)
       metrics['epoch_train_confidences'].append(train_confidences)
       metrics['epoch_test_confidences'].append(test_confidences)
       
       print(f'Epoch [{epoch+1}/{num_epochs}]')
       print(f'Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%, Confidence: {train_avg_confidence:.4f}')
       print(f'Test Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%, Confidence: {test_avg_confidence:.4f}')
       
       if (epoch + 1) % checkpoint_freq == 0:
           checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pt')
           save_checkpoint(model, optimizer, scheduler, epoch, metrics, checkpoint_path)
           
           metrics_path = os.path.join(checkpoint_dir, 'training_metrics.json')
           json_metrics = {
               'train_losses': [float(x) for x in metrics['train_losses']],
               'test_losses': [float(x) for x in metrics['test_losses']],
               'train_accuracies': [float(x) for x in metrics['train_accuracies']],
               'test_accuracies': [float(x) for x in metrics['test_accuracies']],
               'train_confidences': [float(x) for x in metrics['train_confidences']],
               'test_confidences': [float(x) for x in metrics['test_confidences']],
               'current_epoch': epoch + 1
           }
           with open(metrics_path, 'w') as f:
               json.dump(json_metrics, f, indent=4)
   
   return metrics

In [9]:
# Define model parameters for overfitting
d_model = 128
n_layer = 8
num_classes = 10

# Initialize model
model_args = ModelArgs(d_model=d_model, n_layer=n_layer, vocab_size=0)
model = ImageMamba(model_args, num_classes=num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"Model device: {next(model.parameters()).device}")

Model device: cuda:0


In [10]:
# Train model
metrics = train_mamba(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=2000,
    device=device,
    checkpoint_freq=100
)

RuntimeError: Directory mamba_checkpoints already contains files. Using this directory would overwrite existing training data. To prevent data loss, please use an empty directory or use continue_training() to resume from the last checkpoint.

# Crash
If there was a crash. Which can be when dealing with so many epochs, one can continue from here:

In [11]:
def continue_training(model, train_loader, test_loader, checkpoint_dir, target_epochs=2000, device='cuda'):
    """Continue training from last checkpoint."""
    checkpoint, last_epoch = load_last_checkpoint(checkpoint_dir)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-3,
        epochs=target_epochs - last_epoch,
        steps_per_epoch=len(train_loader),
        pct_start=0.3,
        anneal_strategy='cos'
    )
    
    if checkpoint['scheduler_state_dict']:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Load existing metrics
    with open(os.path.join(checkpoint_dir, 'training_metrics.json'), 'r') as f:
        metrics = json.load(f)
    
    complete_metrics = {
        'train_losses': metrics['train_losses'],
        'test_losses': metrics['test_losses'],
        'train_accuracies': metrics['train_accuracies'],
        'test_accuracies': metrics['test_accuracies'],
        'train_confidences': metrics['train_confidences'],
        'test_confidences': metrics['test_confidences'],
        'epoch_train_confidences': checkpoint['metrics']['epoch_train_confidences'],
        'epoch_test_confidences': checkpoint['metrics']['epoch_test_confidences']
    }
    
    print(f"Continuing training from epoch {last_epoch} to {target_epochs}")
    
    for epoch in range(last_epoch, target_epochs):
        model.train()
        running_loss = 0.0
        running_correct = 0
        total_samples = 0
        train_confidences = []
        
        for inputs, labels in tqdm(train_loader, desc=f'Epoch [{epoch+1}/{target_epochs}]'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            
            logits, probabilities = model(inputs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            _, predicted = torch.max(logits, 1)
            confidence, _ = torch.max(probabilities, 1)
            
            running_loss += loss.item()
            running_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            train_confidences.extend(confidence.detach().cpu().numpy())
        
        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * running_correct / total_samples
        train_avg_confidence = np.mean(train_confidences)
        
        test_loss, test_accuracy, test_avg_confidence, test_confidences = evaluate_model(
            model, test_loader, criterion, device)
        
        complete_metrics['train_losses'].append(train_loss)
        complete_metrics['test_losses'].append(test_loss)
        complete_metrics['train_accuracies'].append(train_accuracy)
        complete_metrics['test_accuracies'].append(test_accuracy)
        complete_metrics['train_confidences'].append(train_avg_confidence)
        complete_metrics['test_confidences'].append(test_avg_confidence)
        complete_metrics['epoch_train_confidences'].append(train_confidences)
        complete_metrics['epoch_test_confidences'].append(test_confidences)
        
        print(f'Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%, Confidence: {train_avg_confidence:.4f}')
        print(f'Test Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%, Confidence: {test_avg_confidence:.4f}')
        
        if (epoch + 1) % 100 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pt')
            save_checkpoint(model, optimizer, scheduler, epoch, complete_metrics, checkpoint_path)
            
            json_metrics = {
                'train_losses': [float(x) for x in complete_metrics['train_losses']],
                'test_losses': [float(x) for x in complete_metrics['test_losses']],
                'train_accuracies': [float(x) for x in complete_metrics['train_accuracies']],
                'test_accuracies': [float(x) for x in complete_metrics['test_accuracies']],
                'train_confidences': [float(x) for x in complete_metrics['train_confidences']],
                'test_confidences': [float(x) for x in complete_metrics['test_confidences']],
                'current_epoch': epoch + 1
            }
            with open(os.path.join(checkpoint_dir, 'training_metrics.json'), 'w') as f:
                json.dump(json_metrics, f, indent=4)
    
    return complete_metrics

In [None]:
with open('mamba_checkpoints/training_metrics.json', 'r') as f:
   metrics = json.load(f)
print(f"Last completed epoch: {metrics['current_epoch']}")

metrics = continue_training(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    checkpoint_dir='mamba_checkpoints',
    target_epochs=2000,
    device=device
)

Last completed epoch: 900


  checkpoint = torch.load(checkpoint_path)


Continuing training from epoch 900 to 2000


Epoch [901/2000]:  69%|██████▉   | 543/782 [00:53<00:20, 11.87it/s]