# Imports

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
import numpy as np
import json
import os
from tqdm import tqdm
from data_loader import load_cifar10, get_class_names
from check_and_measure import evaluate_model, save_checkpoint, load_last_checkpoint# For consistent data loading

# CUDA

In [None]:
!nvcc --version

In [3]:
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Initialization

In [4]:
# Set the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
train_loader, test_loader, X_train, X_test, Y_train, Y_test = load_cifar10(batch_size=64, seed=42)
class_names = get_class_names()

In [6]:
from model import SmallerComparableCNN

# Train

In [14]:
def train_cnn(model, train_loader, test_loader, num_epochs=2000, device='cuda',
              checkpoint_dir='cnn_checkpoints', checkpoint_freq=100):
    """Train CNN model with comprehensive metrics tracking."""
    
    # Check if directory exists and contains files
    if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
        raise RuntimeError(
            f"Directory {checkpoint_dir} already contains files. Using this directory would overwrite "
            "existing training data. To prevent data loss, please use an empty directory "
            "or use continue_training() to resume from the last checkpoint."
        )
        
    os.makedirs(checkpoint_dir, exist_ok=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)  # Same as MAMBA
    
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-3,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.3,
        anneal_strategy='cos'
    )
    
    metrics = {
        'train_losses': [], 'test_losses': [],
        'train_accuracies': [], 'test_accuracies': [],
        'train_confidences': [], 'test_confidences': [],
        'epoch_train_confidences': [], 'epoch_test_confidences': []
    }
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_correct = 0
        total_samples = 0
        train_confidences = []

        tqdm_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', disable=(epoch % 100 != 0))
        for inputs, labels in tqdm_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            
            logits, probabilities = model(inputs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            _, predicted = torch.max(logits, 1)
            confidence, _ = torch.max(probabilities, 1)
            
            running_loss += loss.item()
            running_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            train_confidences.extend(confidence.detach().cpu().numpy())
        
        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * running_correct / total_samples
        train_avg_confidence = np.mean(train_confidences)
        
        test_loss, test_accuracy, test_avg_confidence, test_confidences = evaluate_model(
            model, test_loader, criterion, device)
        
        metrics['train_losses'].append(train_loss)
        metrics['test_losses'].append(test_loss)
        metrics['train_accuracies'].append(train_accuracy)
        metrics['test_accuracies'].append(test_accuracy)
        metrics['train_confidences'].append(train_avg_confidence)
        metrics['test_confidences'].append(test_avg_confidence)
        metrics['epoch_train_confidences'].append(train_confidences)
        metrics['epoch_test_confidences'].append(test_confidences)

        # print once every 400 epochs
        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%, Confidence: {train_avg_confidence:.4f}')
            print(f'Test Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%, Confidence: {test_avg_confidence:.4f}')
        
        if (epoch + 1) % checkpoint_freq == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pt')
            save_checkpoint(model, optimizer, scheduler, epoch, metrics, checkpoint_path)
            
            metrics_path = os.path.join(checkpoint_dir, 'training_metrics.json')
            json_metrics = {
                'train_losses': [float(x) for x in metrics['train_losses']],
                'test_losses': [float(x) for x in metrics['test_losses']],
                'train_accuracies': [float(x) for x in metrics['train_accuracies']],
                'test_accuracies': [float(x) for x in metrics['test_accuracies']],
                'train_confidences': [float(x) for x in metrics['train_confidences']],
                'test_confidences': [float(x) for x in metrics['test_confidences']],
                'current_epoch': epoch + 1
            }
            with open(metrics_path, 'w') as f:
                json.dump(json_metrics, f, indent=4)
    
    return metrics

In [15]:
# Initialize model and set device
model = SmallerComparableCNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"Model device: {next(model.parameters()).device}")

# Train model
metrics = train_cnn(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=1200,
    device=device,
    checkpoint_freq=100
)

Model device: cuda:0


Epoch 1/1200: 100%|██████████| 782/782 [00:25<00:00, 30.52it/s]


Epoch [100/1200]
Train Loss: 1.1683, Accuracy: 59.82%, Confidence: 0.5118
Test Loss: 1.2565, Accuracy: 55.50%, Confidence: 0.5101


Epoch 101/1200: 100%|██████████| 782/782 [00:03<00:00, 204.09it/s]


Epoch [200/1200]
Train Loss: 1.1099, Accuracy: 62.29%, Confidence: 0.5300
Test Loss: 1.3378, Accuracy: 52.64%, Confidence: 0.5437


Epoch 201/1200: 100%|██████████| 782/782 [00:04<00:00, 181.06it/s]


Epoch [300/1200]
Train Loss: 1.1047, Accuracy: 62.46%, Confidence: 0.5301
Test Loss: 1.3079, Accuracy: 52.26%, Confidence: 0.5635


Epoch 301/1200: 100%|██████████| 782/782 [00:05<00:00, 131.15it/s]


Epoch [400/1200]
Train Loss: 1.1046, Accuracy: 62.50%, Confidence: 0.5308
Test Loss: 1.2681, Accuracy: 55.30%, Confidence: 0.5291


Epoch 401/1200: 100%|██████████| 782/782 [00:09<00:00, 84.10it/s] 


Epoch [500/1200]
Train Loss: 1.0972, Accuracy: 62.59%, Confidence: 0.5337
Test Loss: 1.2974, Accuracy: 54.68%, Confidence: 0.5523


Epoch 501/1200: 100%|██████████| 782/782 [00:09<00:00, 86.11it/s] 


Epoch [600/1200]
Train Loss: 1.0886, Accuracy: 63.09%, Confidence: 0.5364
Test Loss: 1.1685, Accuracy: 59.14%, Confidence: 0.5376


Epoch 601/1200: 100%|██████████| 782/782 [00:09<00:00, 80.53it/s] 


Epoch [700/1200]
Train Loss: 1.0735, Accuracy: 63.70%, Confidence: 0.5421
Test Loss: 1.2357, Accuracy: 56.04%, Confidence: 0.5318


Epoch 701/1200: 100%|██████████| 782/782 [00:08<00:00, 88.39it/s] 


Epoch [800/1200]
Train Loss: 1.0502, Accuracy: 64.51%, Confidence: 0.5469
Test Loss: 1.1048, Accuracy: 62.11%, Confidence: 0.5550


Epoch 801/1200: 100%|██████████| 782/782 [00:08<00:00, 94.38it/s] 


Epoch [900/1200]
Train Loss: 1.0248, Accuracy: 65.63%, Confidence: 0.5549
Test Loss: 1.0619, Accuracy: 63.32%, Confidence: 0.5532


RuntimeError: [enforce fail at inline_container.cc:603] . unexpected pos 64 vs 0

# Continue training

In [9]:
def continue_cnn_training(model, train_loader, test_loader, checkpoint_dir, target_epochs=2000, device='cuda'):
    """Continue training from last checkpoint."""
    checkpoint, last_epoch = load_last_checkpoint(checkpoint_dir)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-3,
        epochs=target_epochs - last_epoch,
        steps_per_epoch=len(train_loader),
        pct_start=0.3,
        anneal_strategy='cos'
    )
    
    if checkpoint['scheduler_state_dict']:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Load existing metrics
    with open(os.path.join(checkpoint_dir, 'training_metrics.json'), 'r') as f:
        metrics = json.load(f)
    
    complete_metrics = {
        'train_losses': metrics['train_losses'],
        'test_losses': metrics['test_losses'],
        'train_accuracies': metrics['train_accuracies'],
        'test_accuracies': metrics['test_accuracies'],
        'train_confidences': metrics['train_confidences'],
        'test_confidences': metrics['test_confidences'],
        'epoch_train_confidences': checkpoint['metrics']['epoch_train_confidences'],
        'epoch_test_confidences': checkpoint['metrics']['epoch_test_confidences']
    }
    
    print(f"Continuing training from epoch {last_epoch} to {target_epochs}")
    
    for epoch in range(last_epoch, target_epochs):
        model.train()
        running_loss = 0.0
        running_correct = 0
        total_samples = 0
        train_confidences = []
        
        for inputs, labels in tqdm(train_loader, desc=f'Epoch [{epoch+1}/{target_epochs}]'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            
            logits, probabilities = model(inputs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            _, predicted = torch.max(logits, 1)
            confidence, _ = torch.max(probabilities, 1)
            
            running_loss += loss.item()
            running_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            train_confidences.extend(confidence.detach().cpu().numpy())
        
        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * running_correct / total_samples
        train_avg_confidence = np.mean(train_confidences)
        
        test_loss, test_accuracy, test_avg_confidence, test_confidences = evaluate_model(
            model, test_loader, criterion, device)
        
        complete_metrics['train_losses'].append(train_loss)
        complete_metrics['test_losses'].append(test_loss)
        complete_metrics['train_accuracies'].append(train_accuracy)
        complete_metrics['test_accuracies'].append(test_accuracy)
        complete_metrics['train_confidences'].append(train_avg_confidence)
        complete_metrics['test_confidences'].append(test_avg_confidence)
        complete_metrics['epoch_train_confidences'].append(train_confidences)
        complete_metrics['epoch_test_confidences'].append(test_confidences)
        
        print(f'Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%, Confidence: {train_avg_confidence:.4f}')
        print(f'Test Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%, Confidence: {test_avg_confidence:.4f}')
        
        if (epoch + 1) % 100 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'cnn_model_epoch_{epoch+1}.pt')
            save_checkpoint(model, optimizer, scheduler, epoch, complete_metrics, checkpoint_path)
            
            json_metrics = {
                'train_losses': [float(x) for x in complete_metrics['train_losses']],
                'test_losses': [float(x) for x in complete_metrics['test_losses']],
                'train_accuracies': [float(x) for x in complete_metrics['train_accuracies']],
                'test_accuracies': [float(x) for x in complete_metrics['test_accuracies']],
                'train_confidences': [float(x) for x in complete_metrics['train_confidences']],
                'test_confidences': [float(x) for x in complete_metrics['test_confidences']],
                'current_epoch': epoch + 1
            }
            with open(os.path.join(checkpoint_dir, 'training_metrics.json'), 'w') as f:
                json.dump(json_metrics, f, indent=4)
    
    return complete_metrics

In [None]:
# Create model instance
model = SmallerComparableCNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Find the last checkpoint
checkpoint_dir = 'cnn_checkpoints'  
with open(f'{checkpoint_dir}/training_metrics.json', 'r') as f:
    metrics = json.load(f)
print(f"Last completed epoch: {metrics['current_epoch']}")

# Continue training
metrics = continue_cnn_training(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    checkpoint_dir=checkpoint_dir,
    target_epochs=1200,  
    device=device
)