# Setup: GPU selection and imports

In [1]:
import os

# Pin this notebook to a single GPU. Environment variables like these only take
# effect if set before torch initializes CUDA, so keep this in the first code cell.
# CUDA's default device order ("FASTEST_FIRST") can differ from the PCI-bus order
# that nvidia-smi reports; PCI_BUS_ID makes index "1" refer to nvidia-smi's GPU 1.
# setdefault respects an ordering the user has already exported explicitly.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import ConvNextV2ForImageClassification, ConvNextV2Config
from torch.utils.data import DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


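# CIFAR-10 checkpoint evaluation
Evaluate every saved checkpoint of a training run on fixed CIFAR-10 subsets (the first 5,000 train and 1,000 test images by default), then plot accuracy and cross-entropy loss against training step.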

In [3]:
# Root directory holding one sub-directory per training run (each containing
# `checkpoint-<step>` folders). Override via the BASIN_VOLUME_RUNS_DIR environment
# variable when running on a machine where this absolute path does not exist.
RUNS_DIR = os.environ.get("BASIN_VOLUME_RUNS_DIR", "/mnt/ssd-1/adam/basin-volume/runs")

In [4]:
def _checkpoint_step(path):
    """Return the training step encoded in a ``checkpoint-<step>`` directory name."""
    return int(path.name.split("-")[1])


def _build_cifar10_loaders(n_train, n_val, batch_size, num_workers):
    """Build train/val DataLoaders over leading CIFAR-10 subsets; also return the class count.

    Returns:
        tuple[dict[str, DataLoader], int | None]: loaders keyed "train"/"val", and the
        number of classes of the "label" feature (None if it is not a ClassLabel).
    """
    ds = load_dataset("cifar10")
    # None unless "label" is a ClassLabel feature exposing num_classes.
    num_classes = getattr(ds["train"].features["label"], "num_classes", None)

    # Only converts PIL images to [0, 1] float tensors; no normalization is applied.
    # NOTE(review): presumably matches the training-time preprocessing — confirm.
    transform = T.Compose([T.ToTensor()])

    def preprocess(examples):
        return {
            "pixel_values": [transform(image.convert("RGB")) for image in examples["img"]],
            "label": examples["label"],
        }

    loaders = {}
    # "val" is drawn from CIFAR-10's held-out test split.
    for name, split, size in (("train", ds["train"], n_train), ("val", ds["test"], n_val)):
        subset = split.select(range(min(size, len(split))))
        subset = subset.map(preprocess, batched=True, remove_columns=subset.column_names)
        subset.set_format(type="torch")
        loaders[name] = DataLoader(subset, batch_size=batch_size, num_workers=num_workers)
    return loaders, num_classes


def _evaluate_split(model, loader):
    """Return (accuracy, mean cross-entropy) of ``model`` over every batch in ``loader``."""
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for batch in loader:
            pixel_values = batch["pixel_values"].to("cuda", dtype=model.dtype)
            labels = batch["label"].to("cuda")
            logits = model(pixel_values).logits
            correct += (logits.argmax(-1) == labels).sum().item()
            total += labels.size(0)
            # Upcast to float32 so the reported loss is not limited by fp16 precision;
            # reduction="sum" lets us divide by the exact sample count afterwards.
            loss_sum += torch.nn.functional.cross_entropy(
                logits.float(), labels, reduction="sum"
            ).item()
    if total == 0:
        raise ValueError("Evaluation loader yielded no samples")
    return correct / total, loss_sum / total


def _plot_training_curves(title, steps, train_accuracies, val_accuracies, train_losses, val_losses):
    """Plot train/val accuracy (top) and loss (bottom) against step on a log-scaled x axis."""
    fig, (ax_acc, ax_loss) = plt.subplots(2, 1, figsize=(10, 10), sharex=True)
    for ax, train_vals, val_vals, ylabel in (
        (ax_acc, train_accuracies, val_accuracies, "Accuracy"),
        (ax_loss, train_losses, val_losses, "Cross Entropy Loss"),
    ):
        ax.semilogx(steps, train_vals, "-o", label="Train")
        ax.semilogx(steps, val_vals, "-o", label="Validation")
        ax.grid(True)
        ax.set_ylabel(ylabel)
        ax.legend()
    ax_acc.set_title(f"Training Progress for {title}")
    ax_loss.set_xlabel("Training Steps")
    plt.tight_layout()
    plt.show()


def evaluate_checkpoints(model_path, runs_dir=None, n_train=5000, n_val=1000,
                         batch_size=128, num_workers=4):
    """
    Evaluate every checkpoint of a CIFAR-10 run and plot training progress.

    Each ``checkpoint-<step>`` directory under ``runs_dir/model_path`` is loaded in
    float16 on the GPU and scored on the first ``n_train`` CIFAR-10 train images and
    the first ``n_val`` CIFAR-10 test images. Accuracy and mean cross-entropy are
    printed per checkpoint and plotted against step on a log-scaled x axis.

    Args:
        model_path (str): Run directory relative to ``runs_dir``, e.g. "b16".
        runs_dir (str | Path | None): Root directory containing runs. Defaults to the
            notebook-level ``RUNS_DIR``.
        n_train (int): Number of leading train-split images to evaluate on.
        n_val (int): Number of leading test-split images used for validation.
        batch_size (int): Evaluation batch size.
        num_workers (int): DataLoader worker processes.

    Returns:
        tuple[list[float], list[float], list[float], list[float]]:
            (train_accuracies, val_accuracies, train_losses, val_losses), one entry
            per checkpoint in ascending step order.

    Raises:
        FileNotFoundError: If no ``checkpoint-<step>`` directories are found.
        ValueError: If a checkpoint's classifier has fewer outputs than the dataset
            has classes (such labels would otherwise fail inside cross-entropy with an
            opaque CUDA device-side assert).
    """
    if runs_dir is None:
        runs_dir = RUNS_DIR
    run_dir = Path(runs_dir, model_path)

    # Find checkpoints first so a bad path fails fast, before the dataset is processed.
    checkpoints = sorted(
        (
            p for p in run_dir.glob("checkpoint-*")
            if p.is_dir() and p.name.split("-", 1)[1].isdigit()
        ),
        key=_checkpoint_step,
    )
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint-<step> directories found in {run_dir}")

    loaders, num_classes = _build_cifar10_loaders(n_train, n_val, batch_size, num_workers)

    train_accuracies = []
    val_accuracies = []
    train_losses = []
    val_losses = []
    steps = []

    for ckpt in tqdm(checkpoints, desc=str(model_path)):
        config = ConvNextV2Config.from_pretrained(ckpt)
        # Check on the CPU side: labels >= num_labels would otherwise surface only as a
        # CUDA device-side assert, which also leaves the CUDA context unusable.
        if num_classes is not None and config.num_labels < num_classes:
            raise ValueError(
                f"{ckpt} has num_labels={config.num_labels}, but the dataset has "
                f"{num_classes} classes"
            )
        model = ConvNextV2ForImageClassification.from_pretrained(
            ckpt,
            config=config,
            torch_dtype=torch.float16,
        ).cuda()
        model.eval()

        step = _checkpoint_step(ckpt)
        train_acc, train_loss = _evaluate_split(model, loaders["train"])
        val_acc, val_loss = _evaluate_split(model, loaders["val"])

        steps.append(step)
        train_accuracies.append(train_acc)
        train_losses.append(train_loss)
        val_accuracies.append(val_acc)
        val_losses.append(val_loss)
        print(
            f"Step {step}: Train acc: {train_acc:.3f}, loss: {train_loss:.3f}"
            f", Val acc: {val_acc:.3f}, loss: {val_loss:.3f}"
        )

        # Drop the reference so this checkpoint's weights can be freed before the next load.
        del model

    _plot_training_curves(
        model_path, steps, train_accuracies, val_accuracies, train_losses, val_losses
    )

    return train_accuracies, val_accuracies, train_losses, val_losses

In [6]:
# Evaluate every checkpoint of the "b8" run; result is
# (train_accuracies, val_accuracies, train_losses, val_losses).
metrics = evaluate_checkpoints(model_path="b8")

  0%|          | 0/9 [00:05<?, ?it/s]


KeyboardInterrupt: 

In [6]:
# Evaluate every checkpoint of the "b16" run.
# NOTE(review): `metrics6` actually holds the b16 results; consider renaming to
# `metrics_b16` once confirmed nothing later in the notebook reads `metrics6`.
metrics6 = evaluate_checkpoints(model_path="b16")

  0%|          | 0/17 [00:00<?, ?it/s]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
