# MLflow Simple Training Tutorial

This notebook demonstrates how to train a semantic segmentation model using Datamint's MLflow integration. You'll learn how to:

- Set up your environment and configure MLflow
- Load and visualize data from Datamint
- Define data transformations for training
- Train a model with automatic MLflow logging
- Test the model and register it in the model registry
- Make predictions with the trained model

## Prerequisites

Before running this notebook, make sure you have:
1. Datamint Python API installed (`pip install git+https://github.com/Sonance/datamint-python-api.git`)
2. Your API key configured (run `datamint-config` in a terminal)
3. Access to a project with segmentation data
4. Basic understanding of PyTorch and/or PyTorch Lightning

## What is MLflow?

MLflow is an open-source platform for managing machine learning workflows. It helps you:
- Track experiments (metrics, parameters, code versions)
- Package and reproduce models
- Deploy models to production
- Manage model versions in a central registry

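Datamint integrates with MLflow directly. Importing `datamint.mlflow` configures the tracking server, authentication, and artifact storage from your Datamint configuration, so no manual MLflow setup is needed.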

In [None]:
# STEP 1: Environment Setup
# ========================

# Import datamint.mlflow to automatically configure MLflow environment
# This sets up MLflow tracking URI and authentication based on your Datamint configuration
import datamint.mlflow
from datamint import APIHandler
import logging
import rich.logging

logging.getLogger().addHandler(rich.logging.RichHandler())
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)


# Initialize Datamint API handler to verify connection
# This will use your configured API key (set via `datamint-config` command)
api = APIHandler()
LOGGER.info("✅ Datamint API connection established successfully!")

In [None]:
# STEP 2: Project Configuration
# =============================

import lightning as L
from lightning.pytorch.loggers import MLFlowLogger
from datamint.mlflow import set_project

# IMPORTANT: Replace 'BoneSeg' with your actual project name
PROJECT_NAME = 'BoneSeg' # you can retrieve project names using `api.get_projects()`

# Set the active project for MLflow tracking
# This ensures all experiments are logged under the correct project
project_info = set_project(PROJECT_NAME)
LOGGER.info(f"✅ Active project set to: {project_info['name']}")
LOGGER.info(f"Description: {project_info.get('description', 'No description')}")

## Data Transformations

Data augmentation is crucial for training robust models. We'll use the Albumentations library to define transformations that:

- **Random resized crop**: Crop a roughly square region (33-100% of the image area) and resize it to a fixed input size
- **Square symmetry**: Randomly apply one of the 8 symmetries of a square (90° rotations and flips) so the model does not depend on image orientation
- **Color jitter**: Vary brightness/contrast to handle different imaging conditions
- **Gaussian noise**: Add robustness to image artifacts

> 💡 **Tip**: Start with simple transformations and gradually add more complex ones. Monitor validation metrics to ensure augmentations help rather than hurt performance.
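`A.SquareSymmetry` picks one of the eight symmetries of a square (the identity, rotations by 90°, 180°, and 270°, and their mirrored versions) and applies it to the image and its mask together. The numpy sketch below is an illustration, not Albumentations' implementation: it enumerates the eight transforms and shows why the *same* randomly chosen transform must be applied to both arrays so that labels stay aligned with pixels. The helper names are hypothetical.

```python
import numpy as np


def square_symmetries(arr):
    """Return the 8 dihedral transforms of an (H, W) or (H, W, C) array.

    np.rot90 rotates over axes (0, 1) and np.fliplr flips axis 1, so a
    trailing channel axis is left untouched.
    """
    out = []
    for flip in (False, True):
        base = np.fliplr(arr) if flip else arr
        for k in range(4):
            out.append(np.rot90(base, k))
    return out


def random_square_symmetry(image, mask, rng):
    """Apply the same randomly chosen symmetry to an image and its mask."""
    idx = rng.integers(8)
    return square_symmetries(image)[idx], square_symmetries(mask)[idx]


# A 3x3 "image" with distinct values and a mask flagging the pixel whose value is 0
image = np.arange(9).reshape(3, 3)
mask = image == 0

rng = np.random.default_rng(0)
img_t, mask_t = random_square_symmetry(image, mask, rng)

# The mask still points at the pixel whose value was 0, wherever it moved
print(img_t[mask_t].tolist())  # [0]
```

Sampling the index once and reusing it is the key design point; drawing separate random transforms for the image and the mask would silently corrupt the training labels.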

In [None]:
# STEP 3: Define Data Transformations
# ===================================

from datamint.utils.visualization import show, draw_masks
import albumentations as A
from datamint import Dataset

# Define the target image size for training
# Smaller sizes train faster but may lose fine detail; larger sizes preserve detail but need more memory and compute
IMAGE_SIZE = (512, 512)

# Create augmentation pipeline using Albumentations
# Each transformation has a probability (p) of being applied
transf = A.Compose([
    # Randomly crop 33%-100% of the image area (aspect ratio 0.9-1.1), then resize to IMAGE_SIZE
    A.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.33, 1.0), ratio=(0.9, 1.1), p=1.0),
    
    # Randomly apply one of the 8 square symmetries (90° rotations and flips) to image and mask
    A.SquareSymmetry(p=0.5),
    
    # Vary image appearance (brightness, contrast) - no hue/saturation for medical images
    A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.0, hue=0.0, p=0.5),
    
    # Add a small amount of Gaussian noise for robustness
    A.GaussNoise(std_range=(0.01, 0.1), per_channel=False, p=0.2),
])

# Load dataset for visualization
LOGGER.info("Loading dataset...")
D = Dataset(
    project_name=PROJECT_NAME,
    return_as_semantic_segmentation=True,    # Convert to pixel-level masks
    semantic_seg_merge_strategy="union",     # Combine overlapping annotations
    return_frame_by_frame=True,              # Individual frames (not videos)
    include_unannotated=False,               # Only annotated data
    auto_update=False,                       # Don't check for updates (faster)
    alb_transform=transf,                    # Apply our transformations
)

LOGGER.info(f"✅ Dataset loaded: {len(D)} samples")
LOGGER.info(f"Available segmentation labels: {D.segmentation_labels_set}")

# Visualize a sample with transformations applied
item = D[0]
LOGGER.info(f"Image shape: {item['image'].shape}")
LOGGER.info(f"Segmentation shape: {item['segmentations'].shape}")

# Display the image with overlay masks (excluding background at index 0)
show(draw_masks(item['image'], item['segmentations'][1:], alpha=0.5))
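
With `semantic_seg_merge_strategy="union"`, overlapping annotations of the same label are combined into one mask. Reading "union" as a pixel-wise logical OR (an assumption based on the parameter name), the numpy sketch below merges two annotators' masks and stacks the result with a background channel at index 0, the layout the `[1:]` slice above relies on. Treating background as "every pixel not covered by any label" is also an assumption of this sketch, and both helpers are hypothetical.

```python
import numpy as np


def union_merge(annotations):
    """Pixel-wise OR of several binary masks for the same label."""
    return np.logical_or.reduce(np.stack(annotations), axis=0)


def to_semantic_stack(label_masks):
    """Stack per-label masks as (num_labels + 1, H, W) with background first.

    Background is assumed to be every pixel not covered by any label.
    """
    fg = np.stack(label_masks)        # (L, H, W)
    background = ~fg.any(axis=0)      # (H, W)
    return np.concatenate([background[None], fg], axis=0)


# Two annotators outlined "femur" slightly differently on a 4x4 image
ann_a = np.zeros((4, 4), dtype=bool)
ann_a[1:3, 1:3] = True
ann_b = np.zeros((4, 4), dtype=bool)
ann_b[1:3, 2:4] = True

femur = union_merge([ann_a, ann_b])
seg = to_semantic_stack([femur])
print(seg.shape)          # (2, 4, 4)
print(int(femur.sum()))   # 6
```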

## Model Setup

We'll now set up the training components:

1. **DataModule**: Handles data loading and train/validation splits
2. **Model**: Custom segmentation model (DeepLabV3 with ResNet50 backbone)
3. **MLflow Integration**: Automatic experiment tracking and model versioning

### Key Components Explained:

- **DatamintDataModule**: Lightning-compatible data loader for Datamint datasets
- **MyModel**: Custom model with multiple loss functions (CrossEntropy + GIoU + Focal)
- **MLFlowModelCheckpoint**: Saves best models and logs them to MLflow automatically
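`MyModel` lives in `my_custom_model.py`, which this notebook does not show, so the exact loss terms and weights are unknown. As a reference point for one of them: focal loss rescales cross-entropy by (1 - p_t)^γ, where p_t is the probability assigned to the true class, so confidently correct pixels (typical of large background regions) contribute less. The stdlib sketch below computes both losses for a single pixel's logits; γ = 2 is a common default, not a value taken from `MyModel`.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    return -math.log(softmax(logits)[target])


def focal_loss(logits, target, gamma=2.0):
    p_t = softmax(logits)[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


# Logits for 5 classes: background + 4 bones; the true class is background (0)
easy = [6.0, 0.0, 0.0, 0.0, 0.0]  # confidently correct
hard = [0.5, 0.4, 0.0, 0.0, 0.0]  # uncertain between background and fibula

for name, logits in [("easy", easy), ("hard", hard)]:
    ce = cross_entropy(logits, 0)
    fl = focal_loss(logits, 0)
    print(f"{name}: CE={ce:.4f}  focal={fl:.4f}  ratio={fl / ce:.4f}")
```

With γ = 0 focal loss reduces exactly to cross-entropy; larger γ shrinks the contribution of the easy pixel far more than that of the hard one.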

In [None]:
# STEP 4: Import Training Components
# =================================

from datamint.lightning import DatamintDataModule  # Lightning integration for Datamint
from my_custom_model import MyModel                # Your custom segmentation model
from datamint.mlflow.lightning.callbacks import MLFlowModelCheckpoint  # MLflow integration

# Import torch for performance optimization
import torch

LOGGER.info("✅ All training components imported successfully!")

In [None]:
# STEP 5: Configure Training Setup
# ===============================

# Define metadata that will be saved with the model
# This helps with model deployment and inference later
model_metadata = {
    "task_type": "semantic_segmentation",
    "labels": ["background"] + D.segmentation_labels_set,  # Include background as first label
    "need_gpu": False,                    # Whether GPU is required for inference
    "automatic_preprocessing": True       # Whether preprocessing is handled automatically
}

LOGGER.info(f"Model will predict {len(model_metadata['labels'])} classes:")
for i, label in enumerate(model_metadata['labels']):
    LOGGER.info(f"  {i}: {label}")

# Configure model checkpointing with MLflow integration
checkcb = MLFlowModelCheckpoint(
    monitor="val/loss",                   # Metric to monitor for best model
    mode="min",                          # Save model when monitored metric decreases
    save_top_k=1,                        # Keep only the best model
    filename="best",                     # Checkpoint filename
    save_weights_only=True,              # Save only model weights (not optimizer state)
    register_model_name=PROJECT_NAME,    # Name for model registry
    register_model_on='test',            # Register model after testing
    code_paths=['my_custom_model.py'],   # Include source code with model
    log_model_at_end_only=True,         # Log to MLflow only at the end (faster)
    additional_metadata=model_metadata,  # Include our metadata
)

# Create MLflow logger for experiment tracking
mlflow_logger = MLFlowLogger(experiment_name=PROJECT_NAME)
LOGGER.info(f"✅ MLflow experiment: {PROJECT_NAME}")

# Configure Lightning Trainer
trainer = L.Trainer(
    max_epochs=10,                       # Number of training epochs
    logger=mlflow_logger,                # MLflow integration
    precision='16-mixed',                # Use mixed precision for faster training
    enable_model_summary=True,           # Show model architecture summary
    enable_progress_bar=True,            # Show training progress
    callbacks=[checkcb],                 # Include our checkpoint callback
    num_sanity_val_steps=0,             # Skip validation sanity check
)

# Initialize the model
# Note: num_classes counts foreground classes only (background excluded); it should match
# the labels selected via `include_segmentation_names` in the data module below
num_classes = len(D.segmentation_labels_set)
model = MyModel(num_classes=num_classes, learning_rate=3e-4)
LOGGER.info(f"✅ Model initialized for {num_classes} classes")

# Create data module with train/validation split
dm = DatamintDataModule(
    PROJECT_NAME,
    batch_size=8,                        # Adjust based on your GPU memory
    alb_transform=transf,                # Apply our transformations
    num_workers=8,                       # Parallel data loading workers
    # enable_video_cache=True,           # Uncomment to cache video frames
    include_segmentation_names=['fibula', 'tibia', 'patella', 'femur']  # Specify which labels to include
)

LOGGER.info("✅ Training setup complete!")
LOGGER.info(f"Batch size: {dm.batch_size}")
LOGGER.info(f"Data workers: {dm.num_workers}")

## Training

Now we'll start the actual training process. This will:

1. Automatically split your data into training and validation sets
2. Train the model for the specified number of epochs
3. Track metrics (loss, IoU, etc.) in MLflow
4. Save the best model checkpoint based on validation loss
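How `DatamintDataModule` divides the data is not shown in this notebook. Purely as an illustration, a reproducible split can be made by shuffling resource IDs with a fixed seed. Splitting by resource rather than by frame matters when scans are served frame by frame (as with `return_frame_by_frame=True` above): frames of the same scan in both sets would inflate validation scores. `split_ids` and the IDs below are hypothetical.

```python
import random


def split_ids(resource_ids, val_fraction=0.2, seed=42):
    """Reproducibly split resource IDs into (train_ids, val_ids)."""
    ids = sorted(resource_ids)          # sort first so the input order does not matter
    random.Random(seed).shuffle(ids)    # private RNG: does not touch the global seed
    n_val = max(1, round(len(ids) * val_fraction))
    return ids[n_val:], ids[:n_val]


# Hypothetical resource IDs; each resource may contain many frames
ids = [f"res-{i:03d}" for i in range(10)]
train_ids, val_ids = split_ids(ids)
print(len(train_ids), len(val_ids))  # 8 2
```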

### What to expect:
- Training progress bar with loss values
- Automatic logging of metrics to MLflow
- Model checkpointing when validation improves

> ⚠️ **Note**: Training time depends on your data size, model complexity, and hardware. Start with fewer epochs for testing.
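The checkpoint callback configured in Step 5 monitors `val/loss` with `mode="min"` and `save_top_k=1`. The bookkeeping this implies can be sketched in a few lines of plain Python. This illustrates the selection rule only; it is not Lightning's `ModelCheckpoint` code, and the per-epoch losses are hypothetical inputs.

```python
import operator


def track_best(values, mode="min"):
    """Return (best_epoch, best_value, epochs_where_a_new_best_was_saved)."""
    better = operator.lt if mode == "min" else operator.gt
    best_epoch, best_value, saved_at = None, None, []
    for epoch, value in enumerate(values):
        if best_value is None or better(value, best_value):
            best_epoch, best_value = epoch, value
            saved_at.append(epoch)  # with save_top_k=1, the previous "best" file is replaced
    return best_epoch, best_value, saved_at


# Hypothetical validation losses for 5 epochs
losses = [0.90, 0.72, 0.75, 0.61, 0.64]
print(track_best(losses))  # (3, 0.61, [0, 1, 3])
```

Because the last epoch (0.64) is worse than epoch 3, the final in-memory weights differ from the saved best checkpoint, which is why the testing step later passes `ckpt_path=checkcb.best_model_path`.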

In [None]:
# STEP 6: Start Training
# =====================

# Trade a little float32 matmul precision for speed (e.g. TensorFloat-32 on supported GPUs; PyTorch 1.12+)
torch.set_float32_matmul_precision('high')

LOGGER.info("🚀 Starting training...")
LOGGER.info("This will automatically:")
LOGGER.info("  - Split data into train/validation sets")
LOGGER.info("  - Track metrics in MLflow")
LOGGER.info("  - Save the best model checkpoint")
LOGGER.info("  - Log model artifacts")

# Start training!
trainer.fit(model, datamodule=dm)

LOGGER.info("✅ Training completed!")
LOGGER.info(f"Best model saved at: {checkcb.best_model_path}")
LOGGER.info(f"MLflow run ID: {mlflow_logger.run_id}")

## Manual Model Logging (Optional)

If you interrupted training or want to log the current model state manually, you can use the cell below. This is useful for:

- Recovering from interrupted training sessions
- Logging intermediate model states
- Testing the logging functionality

> 💡 **Tip**: This is only needed if automatic logging failed or was interrupted.

In [None]:
# STEP 7: Manual Model Logging (if needed)
# =======================================

# Uncomment the following line if you cancelled training but want to log the model anyway:
# checkcb.log_model_to_mlflow(model, mlflow_logger.run_id)

LOGGER.info("Manual logging skipped - model should have been logged automatically during training")

## Update Model Metadata (Optional)

You can add or update metadata after training is complete. This is useful for:

- Adding deployment-specific information
- Updating model descriptions
- Including performance benchmarks
- Specifying hardware requirements

The metadata is stored as a JSON file alongside your model in MLflow.
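Because the metadata is plain JSON, it is easy to sanity-check before logging. The sketch below validates a dictionary shaped like `model_metadata` and round-trips it through `json`, checking that the label list starts with `background` as the rest of this notebook assumes. The `validate_metadata` helper, its checks, and the file name `metadata.json` are hypothetical; Datamint decides the actual file name and location.

```python
import json
import tempfile
from pathlib import Path


def validate_metadata(meta):
    """Hypothetical checks mirroring the assumptions made in this notebook."""
    required = {"task_type", "labels", "need_gpu", "automatic_preprocessing"}
    missing = required - set(meta)
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    labels = meta["labels"]
    if not labels or labels[0] != "background":
        raise ValueError("labels must start with 'background' (class index 0)")
    if len(set(labels)) != len(labels):
        raise ValueError("labels must be unique")


model_metadata = {
    "task_type": "semantic_segmentation",
    "labels": ["background", "fibula", "tibia", "patella", "femur"],
    "need_gpu": False,
    "automatic_preprocessing": True,
}
validate_metadata(model_metadata)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "metadata.json"  # hypothetical file name
    path.write_text(json.dumps(model_metadata, indent=2))
    restored = json.loads(path.read_text())

print(restored == model_metadata)  # True
```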

In [None]:
# STEP 8: Update Model Metadata (Optional)
# ========================================

# Define updated or additional metadata
model_metadata = {
    "task_type": "semantic_segmentation",
    "labels": ["background"] + D.segmentation_labels_set,
    "need_gpu": False,
    "automatic_preprocessing": True,
}

LOGGER.info("Updating model metadata...")
LOGGER.info("New metadata:")
for key, value in model_metadata.items():
    LOGGER.info(f"  {key}: {value}")

# Log the metadata (this will overwrite existing metadata)
checkcb.log_additional_metadata(
    trainer,              # Pass trainer (or mlflow_logger directly)
    model_metadata       # Updated metadata dictionary
)

LOGGER.info("✅ Metadata updated successfully!")

## Model Testing

Now we'll evaluate the trained model on the test set. This will:

1. Load the best model checkpoint (not the final training state)
2. Run inference on the test/validation data
3. Calculate final performance metrics
4. Automatically register the model in MLflow Model Registry

> 📊 **Important**: We test on the best checkpoint (lowest validation loss) rather than the final model state to get the most reliable performance estimates.
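Which metrics get logged depends on `MyModel`'s test step, which is not shown here. For reference, per-class intersection-over-union (IoU) and Dice between a predicted label map and the ground truth can be computed as below. The function name and the choice to report `nan` for a class absent from both maps are assumptions of this sketch.

```python
import numpy as np


def per_class_iou_dice(pred, target, num_classes):
    """pred, target: integer label maps of shape (H, W) with values in [0, num_classes)."""
    ious, dices = [], []
    for c in range(num_classes):
        p, t = pred == c, target == c
        inter = np.logical_and(p, t).sum()
        union = np.logical_or(p, t).sum()
        total = p.sum() + t.sum()
        ious.append(inter / union if union else float("nan"))
        dices.append(2 * inter / total if total else float("nan"))
    return ious, dices


# Toy 4x4 label maps: 0 = background, 1 and 2 = two structures
target = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1],
                   [0, 2, 2, 0],
                   [0, 2, 2, 0]])
pred = np.array([[0, 0, 1, 1],
                 [0, 1, 1, 1],
                 [0, 2, 0, 0],
                 [0, 2, 2, 0]])

ious, dices = per_class_iou_dice(pred, target, num_classes=3)
print([round(float(x), 3) for x in ious])  # [0.778, 0.8, 0.75]
```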

In [None]:
# STEP 9: Test the Model
# =====================

LOGGER.info("🧪 Starting model testing...")
LOGGER.info("This will:")
LOGGER.info("  - Load the best model checkpoint")
LOGGER.info("  - Evaluate on test data") 
LOGGER.info("  - Log final metrics to MLflow")
LOGGER.info("  - Register model in MLflow Model Registry")

# Test using the best model checkpoint
test_results = trainer.test(
    model,
    ckpt_path=checkcb.best_model_path,    # Use best model, not last
    datamodule=dm
)

LOGGER.info("✅ Testing completed!")
LOGGER.info("Final test metrics:")
for metric_name, value in test_results[0].items():
    LOGGER.info(f"  {metric_name}: {value:.4f}")

# Model registration happens automatically due to register_model_on='test'
LOGGER.info(f"🏆 Model registered in MLflow Model Registry as '{PROJECT_NAME}'")

## Making Predictions

Now let's use our trained model to make predictions on new data. This section demonstrates:

1. Loading the trained model
2. Setting up a prediction pipeline
3. Running inference on test data
4. Visualizing the results

This is similar to how you would deploy the model in production.

### Alternative Model Loading

You can load models in several ways:
- From a checkpoint file: `trainer.predict(model, ckpt_path="path/to/checkpoint")`
- From the MLflow Model Registry: `mlflow.pytorch.load_model("models:/<model_name>/<version>")`
- From an MLflow run: `mlflow.pytorch.load_model("runs:/<run_id>/<artifact_path>")`
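The registry and run options differ only in the URI string passed to MLflow. The helper below is a hypothetical convenience, not part of MLflow or Datamint, that assembles MLflow's URI forms: `models:/<name>/<version>` (or `latest`) for a registry version, `models:/<name>@<alias>` for a registry alias on MLflow versions that support aliases, and `runs:/<run_id>/<artifact_path>` for an artifact inside a run. The artifact path under which the checkpoint callback logs the model is not shown in this notebook, so `"model"` below is a placeholder, as are the run ID and alias.

```python
def model_uri(*, name=None, version=None, alias=None, run_id=None, artifact_path=None):
    """Build an MLflow model URI string (hypothetical helper)."""
    if run_id is not None:
        if not artifact_path:
            raise ValueError("runs:/ URIs need an artifact_path")
        return f"runs:/{run_id}/{artifact_path.strip('/')}"
    if name is None:
        raise ValueError("registry URIs need a model name")
    if alias is not None:
        return f"models:/{name}@{alias}"
    return f"models:/{name}/{version if version is not None else 'latest'}"


print(model_uri(name="BoneSeg", version=3))               # models:/BoneSeg/3
print(model_uri(name="BoneSeg"))                          # models:/BoneSeg/latest
print(model_uri(name="BoneSeg", alias="champion"))        # models:/BoneSeg@champion
print(model_uri(run_id="abc123", artifact_path="model"))  # runs:/abc123/model
```

Any of these strings can be passed to `mlflow.pytorch.load_model(...)`, provided the model was logged with MLflow's PyTorch flavor.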

In [None]:
# STEP 10: Make Predictions
# ========================

LOGGER.info("🔮 Setting up prediction pipeline...")

# Create a new trainer for prediction (no training setup needed)
pred_trainer = L.Trainer(
    enable_model_summary=True,
    enable_progress_bar=True,
)

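# Set up a data module for prediction with a deterministic transform
# (random crops/flips would make the predictions hard to compare with the input images)
pred_dm = DatamintDataModule(
    PROJECT_NAME,
    batch_size=8,
    alb_transform=A.Compose([A.Resize(*IMAGE_SIZE)]),  # Resize only, no random augmentation
    include_segmentation_names=['fibula', 'tibia', 'patella', 'femur']
)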

LOGGER.info("Running predictions...")
LOGGER.info("Note: Using the model already in memory")
LOGGER.info("Alternative: Load from MLflow registry with:")
LOGGER.info(f"  model = mlflow.pytorch.load_model('models:/{PROJECT_NAME}/latest')")

# Option 1: Use model already in memory
preds = pred_trainer.predict(
    model,
    # ckpt_path=checkcb.best_model_path,  # Uncomment to load from checkpoint
    datamodule=pred_dm
)

# Option 2: Load from the MLflow Model Registry (commented out)
# import mlflow.pytorch
# registered_model = mlflow.pytorch.load_model(f'models:/{PROJECT_NAME}/latest')
# preds = pred_trainer.predict(registered_model, datamodule=pred_dm)

LOGGER.info("✅ Predictions completed!")
LOGGER.info(f"Generated {len(preds)} batches of predictions")
LOGGER.info(f"First batch shape: {preds[0].shape}")

## Visualizing Results

Let's visualize the model's predictions to see how well it's performing. We'll:

1. Convert model outputs to binary masks
2. Load the corresponding input images
3. Overlay predicted masks on the original images
4. Display the results

> 🎨 **Visualization Notes**: 
> - We exclude the background class (index 0) from visualization
> - Different colors represent different anatomical structures
> - Transparency (alpha) allows you to see both image and predictions
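The visualization cell thresholds the raw outputs at 0. That suits per-class logits trained as independent (sigmoid-style) scores, but with a softmax/cross-entropy head several channels can be positive at the same pixel, or none. A common alternative, sketched here in numpy and assuming outputs shaped (B, C, H, W) with background in channel 0, is to take the argmax over classes and one-hot encode it, which assigns exactly one class to every pixel. The random logits stand in for real model outputs.

```python
import numpy as np


def logits_to_masks(logits):
    """(B, C, H, W) scores -> boolean one-hot masks of the same shape via argmax."""
    labels = logits.argmax(axis=1)                    # (B, H, W)
    classes = np.arange(logits.shape[1])              # (C,)
    return labels[:, None, :, :] == classes[None, :, None, None]


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 5, 4, 4))  # batch of 2, background + 4 classes

masks = logits_to_masks(logits)
thresholded = logits > 0

print(masks.sum(axis=1).min(), masks.sum(axis=1).max())  # 1 1
print(thresholded.sum(axis=1).min(), thresholded.sum(axis=1).max())
```

In PyTorch the equivalent is `torch.nn.functional.one_hot(preds[0].argmax(dim=1), num_classes).permute(0, 3, 1, 2).bool()`.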

In [None]:
# STEP 11: Visualize Predictions
# ==============================
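# Threshold per-class scores at 0 to get binary masks
# (assumes MyModel returns logits shaped (B, C, H, W) with background in channel 0)
predicted_mask = preds[0] > 0

# Fetch the matching input images: the first batch of the same data module used for prediction
imgs = next(iter(pred_dm.predict_dataloader()))['image']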

imgs_with_mask = []
for im, pr in zip(imgs, predicted_mask):
    imgs_with_mask.append(draw_masks(im, pr[1:]))  # pr[0] is the background
show(imgs_with_mask, figsize=(16, 7))