# Dragon ML Toolbox - Classification Tutorial

This notebook demonstrates the complete workflow for training, evaluating, and explaining a PyTorch classification model using the `dragon-ml-toolbox`.

## 1. Imports

First, import the necessary components from PyTorch, scikit-learn, and the toolbox.

In [None]:
import torch
from torch import nn
from torch.utils.data import TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import numpy as np
from pathlib import Path

# Import from the dragon-ml-toolbox package (imported as `ml_tools`)
from ml_tools.ML_trainer import MyTrainer
from ml_tools.ML_callbacks import EarlyStopping, ModelCheckpoint
from ml_tools.keys import LogKeys

## 2. Setup Device

We'll automatically select the best available hardware accelerator (CUDA or MPS) or default to the CPU.

In [None]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f'Using device: {device}')

## 3. Prepare the Data

We will generate mock data for a binary classification task and wrap it in PyTorch `TensorDataset` objects.

In [None]:
# Create a synthetic dataset
X, y = make_classification(
    n_samples=500, 
    n_features=15, 
    n_informative=8, 
    n_redundant=2, 
    n_classes=2,
    random_state=42
)

# Create feature names for later use in SHAP plots
feature_names = [f'feature_{i+1}' for i in range(X.shape[1])]

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Convert data to PyTorch tensors with explicit dtypes
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)  # CrossEntropyLoss expects integer (long) class labels
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# Create TensorDatasets
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
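
The `stratify=y` argument in the split above keeps the class proportions the same in the training and test sets. The idea can be sketched in pure Python as follows. This is only an illustration, not scikit-learn's implementation; `stratified_split` is a hypothetical helper and the balanced toy labels are an assumption.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) with per-class proportions preserved."""
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)  # same fraction taken from every class
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return train_idx, test_idx

# Toy labels (assumed perfectly balanced): 500 samples, two classes.
labels = [0] * 250 + [1] * 250
train_idx, test_idx = stratified_split(labels)

train_counts = Counter(labels[i] for i in train_idx)
test_counts = Counter(labels[i] for i in test_idx)
print("train:", dict(train_counts), "test:", dict(test_counts))
```

Both splits end up with the same class ratio as the full label list, which is what makes the test metrics comparable to the training metrics on imbalanced data.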

## 4. Define Model, Criterion, and Optimizer

Next, we define our neural network architecture, choose a loss function, and select an optimizer.

In [None]:
class SimpleClassifier(nn.Module):
    def __init__(self, input_features, num_classes):
        super().__init__()
        self.layer_1 = nn.Linear(input_features, 64)
        self.layer_2 = nn.Linear(64, 32)
        self.layer_3 = nn.Linear(32, num_classes)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.layer_1(x))
        x = self.relu(self.layer_2(x))
        x = self.layer_3(x) # Return raw logits; CrossEntropyLoss applies log-softmax internally
        return x

# Instantiate the components
model = SimpleClassifier(input_features=X_train.shape[1], num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
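
The model returns raw logits because `nn.CrossEntropyLoss` combines a log-softmax with a negative log-likelihood loss. For a single sample, that computation can be sketched in pure Python. The helper names `log_softmax` and `cross_entropy` below are illustrative, not PyTorch functions, and the logits are made-up values.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def cross_entropy(logits, target):
    """Negative log-probability assigned to the target class."""
    return -log_softmax(logits)[target]

logits = [2.0, 0.0]  # hypothetical model output for one sample
probs = [math.exp(v) for v in log_softmax(logits)]

loss_correct = cross_entropy(logits, target=0)  # target is the higher-scoring class
loss_wrong = cross_entropy(logits, target=1)    # target is the lower-scoring class
print(f"probs={probs}, loss_correct={loss_correct:.4f}, loss_wrong={loss_wrong:.4f}")
```

Applying a softmax inside `forward` as well would normalise the outputs twice, which is why the last layer is left without an activation.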

## 5. Configure Callbacks

We'll set up `ModelCheckpoint` to save the best-performing model based on validation loss, and `EarlyStopping` to halt training once the validation loss stops improving, which limits overfitting.

In [None]:
CHECKPOINT_DIR = 'checkpoints'
MONITOR_METRIC = LogKeys.VAL_LOSS

# This callback saves the best model state to a directory
model_checkpoint = ModelCheckpoint(
    save_dir=CHECKPOINT_DIR,
    monitor=MONITOR_METRIC,
    save_best_only=True, 
    mode='min',
    verbose=1
)

# This callback stops training if the validation loss doesn't improve
early_stopping = EarlyStopping(
    monitor=MONITOR_METRIC,
    patience=15, # Wait 15 epochs for improvement
    min_delta=0.001, # A change smaller than this is not considered an improvement
    mode='min',
    verbose=1
)
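
How `patience` and `min_delta` interact in `mode='min'` can be sketched as follows. `SimpleEarlyStopping` is a hypothetical stand-in written for this illustration; it is not the toolbox's `EarlyStopping` class, whose internals may differ, and the loss values are invented.

```python
class SimpleEarlyStopping:
    """Hypothetical stand-in for illustration only (mode='min')."""

    def __init__(self, patience=15, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0  # epochs since the last real improvement

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience

stopper = SimpleEarlyStopping(patience=3, min_delta=0.001)
losses = [0.70, 0.60, 0.5995, 0.61, 0.62, 0.55]  # invented validation losses

stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best loss {stopper.best}")
```

In this run the drop from 0.60 to 0.5995 is smaller than `min_delta`, so it does not reset the counter, and training stops before the later, better value of 0.55 is ever seen. That trade-off is the reason for choosing a generous `patience`.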

## 6. Initialize the Trainer

Now we can instantiate `MyTrainer`, bringing all the pieces together.

In [None]:
trainer = MyTrainer(
    model=model,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    kind='classification',
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    callbacks=[model_checkpoint, early_stopping] # The callbacks configured in section 5
)

## 7. Train the Model

Call the `.fit()` method to start the training process. The trainer will automatically handle the training loop, validation, progress bars, and callbacks.

In [None]:
history = trainer.fit(epochs=150, batch_size=32, shuffle=True)
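
As a quick sanity check on these settings, the number of optimizer steps can be worked out by hand. The calculation below assumes the trainer batches the data the way a standard PyTorch `DataLoader` does with the last partial batch kept, which the source does not confirm.

```python
import math

n_train = 400      # 80% of the 500 generated samples
batch_size = 32
max_epochs = 150

# Assumption: the final, smaller batch is kept rather than dropped.
batches_per_epoch = math.ceil(n_train / batch_size)
last_batch_size = n_train - (batches_per_epoch - 1) * batch_size
max_steps = batches_per_epoch * max_epochs  # upper bound; early stopping may end training sooner

print(batches_per_epoch, last_batch_size, max_steps)
```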

## 8. Evaluate the Model

After training, we first load the weights of the best model saved by `ModelCheckpoint`. Then, we call `.evaluate()` to generate and save a full performance report.

In [None]:
# Path to the best model saved by the callback
best_model_path = model_checkpoint.last_best_filepath

if best_model_path and best_model_path.exists():
    print(f'Loading best model weights from: {best_model_path}')
    # map_location keeps the load working even if the checkpoint was saved from another device
    trainer.model.load_state_dict(torch.load(best_model_path, map_location=device))
else:
    print('Warning: No best model found. Evaluating with the last model state.')

# Define a directory to save all evaluation artifacts
EVAL_DIR = Path('tutorial_results') / 'evaluation_report'

# Evaluate the model (will use the internal test_dataset)
trainer.evaluate(save_dir=EVAL_DIR)
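
The exact contents of the report written by `.evaluate()` are not described here. As a rough guide to the quantities a binary classification report usually contains, the sketch below derives them from a confusion matrix. `binary_report` is a hypothetical helper and the labels are toy values, not results from this tutorial.

```python
def binary_report(y_true, y_pred):
    """Compute confusion-matrix counts and common metrics for labels in {0, 1}."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    # Guard against division by zero when a class is never predicted or never present.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    accuracy = (tp + tn) / len(pairs)
    return {"tp": tp, "tn": tn, "fp": fp, "fn": fn,
            "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Toy labels for illustration only.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]
report = binary_report(y_true, y_pred)
print(report)
```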

## 9. Explain the Model

Finally, we can use the `.explain()` method to generate SHAP plots for model interpretability. This helps us understand which features are most important for the model's predictions.

In [None]:
# Define a directory to save all explanation artifacts
EXPLAIN_DIR = Path('tutorial_results') / 'explanation_report'

# Generate and save SHAP summary plots
trainer.explain(
    explain_dataset=test_dataset, # The data to explain (defaults to test_dataset if None)
    n_samples=50, # Use 50 samples for the explanation
    feature_names=feature_names, # Provide names for the plot
    save_dir=EXPLAIN_DIR
)
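
SHAP attributes a prediction to individual features using Shapley values. To show what those values mean, the sketch below computes them exactly by brute force for a three-feature toy function. This is not how `.explain()` or the SHAP library works internally: practical explainers approximate these values, because the exact computation grows exponentially with the number of features. `shapley_values`, `toy_model`, and the all-zero baseline are assumptions made for this example.

```python
from itertools import combinations
from math import factorial

def shapley_values(predict, x, baseline):
    """Exact Shapley values of predict(x) relative to predict(baseline)."""
    n = len(x)

    def value(subset):
        # Features in `subset` take their real value; the rest take the baseline value.
        mixed = [x[j] if j in subset else baseline[j] for j in range(n)]
        return predict(mixed)

    phis = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            # Shapley weight for coalitions of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi += weight * (value(s | {i}) - value(s))
        phis.append(phi)
    return phis

def toy_model(features):
    a, b, c = features
    return 2.0 * a + b * c  # b and c contribute only through their interaction

x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phis = shapley_values(toy_model, x, baseline)
print(phis)
```

The attributions sum to the difference between the prediction for `x` and the prediction for the baseline. The interaction term `b * c` is split equally between the two features involved.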