# Classification Tutorial

This notebook demonstrates how to:
1. Load and visualize data
2. Build a classification model (full training is done with a separate script)
3. Sanity-check the model with a forward pass
4. Generate XAI explanations

In [None]:
import sys
sys.path.append('..')

import torch
import matplotlib.pyplot as plt

from src.models import build_model
from src.datasets import DummyClassificationDataset, get_classification_transforms
from src.utils import set_seed, get_device

## 1. Setup

In [None]:
# Seed every RNG up front (via the project's helper) so the stochastic steps
# later in the notebook are repeatable across Restart & Run All.
SEED = 42
set_seed(SEED)

# Let the project helper choose the compute device used by every later cell.
device = get_device()
print(f"Using device: {device}")

## 2. Create Dataset

In [None]:
# Dummy dataset for demonstration: 100 samples, 10 classes, 32x32 images.
# NOTE(review): this uses the 'train' transform split, which may include random
# augmentation (so repeated dataset[i] lookups could differ) — confirm whether an
# eval-style transform would suit this demo better.
IMAGE_SIZE = 32
transform = get_classification_transforms('train', image_size=IMAGE_SIZE)

dataset = DummyClassificationDataset(
    transform=transform,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    num_classes=10,
    num_samples=100,
)

print(f"Dataset size: {len(dataset)}")
print(f"Number of classes: {len(dataset.classes)}")

In [None]:
# Show the first 10 samples in a 2x5 grid.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    image, label = dataset[i]
    # The transform pipeline may normalize images, leaving values outside [0, 1];
    # imshow would clip those and render a distorted picture. The mean/std used
    # for normalization are not available here, so min-max rescale the whole
    # image (all channels jointly, preserving relative colour) to [0, 1].
    # This is for display only and does not modify the dataset.
    img = image.detach().cpu()
    value_range = (img.max() - img.min()).clamp_min(1e-8)  # avoid /0 on constant images
    img = ((img - img.min()) / value_range).permute(1, 2, 0).numpy()  # CHW -> HWC
    if img.shape[-1] == 1:
        # imshow rejects (H, W, 1); draw single-channel images as grayscale.
        ax.imshow(img[..., 0], cmap='gray')
    else:
        ax.imshow(img)
    ax.set_title(f"Class: {label}")
    ax.axis('off')
plt.tight_layout()
plt.show()

## 3. Build Model

In [None]:
# Build the project's "simple_cnn" classifier with 10 output classes (matching
# the dataset above), then move it straight onto the selected device.
model = build_model(
    task="classification",
    model_name="simple_cnn",
    num_classes=10,
).to(device)

num_params = sum(param.numel() for param in model.parameters())
print(f"Model: {type(model).__name__}")
print(f"Parameters: {num_params:,}")

## 4. Training

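This notebook does not train the model. The next cell only runs a quick forward pass as a sanity check, so its prediction is not expected to be correct. For full training, use the training script: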
```bash
python ../scripts/train_classification.py
```

In [None]:
# Smoke-test the forward path on the first sample (no training happens in this
# notebook, so the prediction is not expected to match the label).
model.eval()
sample = dataset[0]
test_image, test_label = sample
# Add a batch dimension and move to the model's device.
test_image = test_image.unsqueeze(0).to(device)

with torch.no_grad():
    output = model(test_image)
    pred = torch.argmax(output, dim=1).item()

print(f"Prediction: {pred}\nGround truth: {test_label}")

## 5. XAI - Generate Explanations

Generate an attribution map with Integrated Gradients to see which input pixels drove the model's prediction.

In [None]:
from src.xai import AttributionEngine, visualize_attribution

# Initialize attribution engine
attribution_engine = AttributionEngine(model, device, task="classification")

# Fetch the image to explain once, and compute the target class on *that exact
# tensor*. The dataset uses the 'train' transform split, which may be stochastic,
# so a fresh dataset[0] can differ from the one the earlier forward-pass cell saw.
# Reusing that cell's `pred` could explain a class the model never predicted for
# this image, and it also makes this cell depend on hidden state.
test_image, _ = dataset[0]
model.eval()  # explicit, in case cells were run out of order
with torch.no_grad():
    pred = model(test_image.unsqueeze(0).to(device)).argmax(dim=1).item()

# Attribute the model's predicted class back to the input pixels.
attribution = attribution_engine.get_attribution(
    test_image,
    method="integrated_gradients",
    target=pred
)

# Visualize
visualize_attribution(
    test_image,
    attribution,
    title="Integrated Gradients"
)

## Next Steps

- Explore other notebooks for segmentation, concept analysis, and manifold exploration
- Use the training scripts for full model training
- Generate comprehensive XAI reports with `run_xai.py`