In [None]:
#| default_exp utils

In [None]:
#| export
import pickle
import torch
from pathlib import Path
import mlflow
import mlflow.pytorch
import os
import tempfile
import json
from fastai.callback.core import Callback
from fastcore.foundation import L
from typing import Any

# Utils
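> Helper utilities for saving and loading inference settings, checking GPU availability, and tracking experiments with MLflow.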

In [None]:
#| export
def store_variables(pkl_fn: str | Path, size: list, reorder: bool, resample: int | list):
    """Save variable values in a pickle file."""
    
    var_vals = [size, reorder, resample]
    
    with open(pkl_fn, 'wb') as f:
        pickle.dump(var_vals, f)

In [None]:
#| export
def load_variables(pkl_fn: (str, Path)):
    """Loads stored variable values from a pickle file.

    Args:
        pkl_fn: File path of the pickle file to be loaded.

    Returns:
        The deserialized value of the pickled data.
    """
    with open(pkl_fn, 'rb') as f:
        return pickle.load(f)

In [None]:
#| export
def print_colab_gpu_info():
    """Report whether CUDA is available; otherwise print Colab GPU setup steps."""
    if torch.cuda.is_available():
        print('GPU attached.')
        return

    # Instructions are framed above and below by an 80-character rule.
    rule = '#' * 80
    instructions = (
        "Remember to attach a GPU to your Colab Runtime:",
        "1. From the **Runtime** menu select **Change Runtime Type**",
        "2. Choose **GPU** from the drop-down menu",
        "3. Click **'SAVE'**",
    )
    print("\n".join((rule, *instructions, rule)))

In [None]:
#| export
class ModelTrackingCallback(Callback):
    """
    A FastAI callback for comprehensive MLflow experiment tracking.

    Hyperparameters are logged in `before_fit`, metrics in `after_epoch`, and
    model artifacts (weights, exported learner, inference settings) together
    with a registered PyTorch model in `after_fit`.
    """

    def __init__(
        self,
        model_name: str,
        loss_function: str,
        item_tfms: list[Any],
        size: list[int],
        resample: list[float],
        reorder: bool
    ):
        """
        Initialize the MLflow tracking callback.

        Args:
            model_name: Name of the model architecture for registration
            loss_function: Name of the loss function being used
            item_tfms: Transform objects whose public attributes are logged
            size: Model input dimensions
            resample: Resampling dimensions
            reorder: Whether reordering augmentation is applied
        """
        self.model_name = model_name
        self.loss_function = loss_function
        self.item_tfms = item_tfms
        self.size = size
        self.resample = resample
        self.reorder = reorder

        self.config = self._build_config()

    def extract_all_params(self, tfm):
        """
        Extract all parameters from a transform object for detailed logging.

        Only public attributes (names not starting with '_') are inspected.
        Attributes that are themselves objects are flattened one level: their
        `target_shape` is recorded if present, otherwise their public
        attributes holding simple values are merged into the result.

        Args:
            tfm: Transform object to extract parameters from

        Returns:
            dict: Dictionary with 'name' and 'params' keys containing transform details
        """
        simple_types = (int, float, str, bool, tuple, list)
        params = {}

        # Objects without an instance __dict__ contribute no parameters
        # instead of raising AttributeError.
        for key, value in getattr(tfm, '__dict__', {}).items():
            if key.startswith('_'):
                continue

            if hasattr(value, '__dict__'):
                if hasattr(value, 'target_shape'):
                    params['target_shape'] = value.target_shape
                else:
                    params.update({
                        k: v for k, v in value.__dict__.items()
                        if not k.startswith('_') and isinstance(v, simple_types)
                    })
            elif isinstance(value, simple_types):
                params[key] = value

        return {
            'name': tfm.__class__.__name__,
            'params': params
        }

    def _build_config(self) -> dict[str, Any]:
        """Build configuration dictionary from initialization parameters."""
        # Extract detailed transform information
        transform_details = [self.extract_all_params(tfm) for tfm in self.item_tfms]

        return {
            "model_name": self.model_name,
            "loss_function": self.loss_function,
            "transform_details": transform_details,
            "size": self.size,
            "resample": self.resample,
            "reorder": self.reorder,
        }

    @staticmethod
    def _callable_name(func: Any) -> str:
        """Return a readable name for `func`, unwrapping `functools.partial`-style wrappers."""
        # partial objects have no __name__ but expose the wrapped callable as .func
        while not hasattr(func, '__name__') and hasattr(func, 'func'):
            func = func.func
        return getattr(func, '__name__', None) or type(func).__name__

    def _extract_training_params(self) -> dict[str, Any]:
        """Extract training hyperparameters from the learner."""
        params = {}

        params["epochs"] = self.learn.n_epoch
        params["learning_rate"] = float(self.learn.lr)
        params["optimizer"] = self._callable_name(self.learn.opt_func)
        params["batch_size"] = self.learn.dls.bs

        params["loss_function"] = self.config["loss_function"]
        params["size"] = self.config["size"]
        params["resample"] = self.config["resample"]
        params["reorder"] = self.config["reorder"]

        # default=str keeps logging best-effort: values such as `target_shape`
        # are collected without a type check and may not be JSON-serializable.
        params["transformations"] = json.dumps(
            self.config["transform_details"],
            indent=2,
            separators=(',', ': '),
            default=str
        )

        return params

    def _extract_epoch_metrics(self) -> dict[str, float]:
        """
        Extract metrics from the current epoch.

        Scalar values are logged under their recorder name. Multi-element
        tensors are logged per element as `<name>_class_<i>` (1-based) plus
        their mean as `<name>_mean`.
        """
        recorder = self.learn.recorder
        log = recorder.log

        metrics = {}

        # The first two recorder entries are skipped here; entry 1 is logged as
        # 'train_loss' below. zip() stops at the shorter sequence, so trailing
        # names without a value yet are ignored.
        # NOTE(review): in fastai the skipped entries are presumably 'epoch'
        # and 'train_loss' - confirm against the Recorder in use.
        for name, val in zip(recorder.metric_names[2:], log[2:]):
            if val is None:
                continue  # Skip None values during inference
            if isinstance(val, torch.Tensor) and val.numel() != 1:
                # Multi-element tensor (like multiclass dice scores). Flatten so
                # that tensors of any rank yield plain numbers, and cast to
                # float because torch.mean rejects integer dtypes.
                flat = val.detach().flatten()
                for i, class_score in enumerate(flat.tolist(), start=1):
                    metrics[f"{name}_class_{i}"] = float(class_score)
                metrics[f"{name}_mean"] = float(flat.float().mean())
            else:
                # Plain number or single value tensor (like binary dice score)
                metrics[name] = float(val)

        # Handle loss values
        if len(log) >= 2:
            if log[1] is not None:
                metrics['train_loss'] = float(log[1])
            if len(log) >= 3 and log[2] is not None:
                metrics['valid_loss'] = float(log[2])

        return metrics

    def _save_model_artifacts(self, temp_dir: Path) -> None:
        """
        Save model weights, learner, and configuration as artifacts.

        Args:
            temp_dir: Existing directory used to stage files before upload
        """
        weights_path = temp_dir / "weights"
        self.learn.save(str(weights_path))

        weights_file = f"{weights_path}.pth"
        if os.path.exists(weights_file):
            mlflow.log_artifact(weights_file, "model")

        # Remove MLflow callbacks before exporting learner for inference
        # This prevents the callback from being triggered during inference
        original_cbs = self.learn.cbs.copy()  # Save original callbacks
        self.learn.cbs = L([cb for cb in self.learn.cbs if not isinstance(cb, ModelTrackingCallback)])

        learner_path = temp_dir / "learner.pkl"
        try:
            # Export clean learner without MLflow callbacks
            self.learn.export(str(learner_path))
        finally:
            # Restore original callbacks for current session, even if export fails
            self.learn.cbs = original_cbs
        mlflow.log_artifact(str(learner_path), "model")

        config_path = temp_dir / "inference_settings.pkl"
        store_variables(config_path, self.size, self.reorder, self.resample)
        mlflow.log_artifact(str(config_path), "config")

    def _register_pytorch_model(self) -> None:
        """Register the PyTorch model with MLflow."""
        mlflow.pytorch.log_model(
            pytorch_model=self.learn.model,
            registered_model_name=self.model_name
        )

    def before_fit(self) -> None:
        """Log hyperparameters before training starts."""
        params = self._extract_training_params()
        mlflow.log_params(params)

    def after_epoch(self) -> None:
        """Log metrics after each epoch."""
        metrics = self._extract_epoch_metrics()
        if metrics:
            mlflow.log_metrics(metrics, step=self.learn.epoch)

    def after_fit(self) -> None:
        """Log model artifacts after training completion."""
        print("\nTraining finished. Logging model artifacts to MLflow...")

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            self._save_model_artifacts(temp_path)

            self._register_pytorch_model()

        # active_run() returns None when no run is active; avoid an
        # AttributeError at the very end of an otherwise successful fit.
        active_run = mlflow.active_run()
        if active_run is not None:
            print(f"MLflow run completed. Run ID: {active_run.info.run_id}")
        else:
            print("MLflow logging completed (no active run found).")

In [None]:
#| export

import subprocess
import threading
import time
import socket
import os
from IPython.display import display, HTML, clear_output
from IPython.core.magic import register_line_magic
from IPython import get_ipython
import requests
import shutil

class MLflowUIManager:
    """Start, stop, and inspect a local `mlflow ui` server from a notebook.

    User feedback is rendered as HTML through IPython's `display`.
    """

    def __init__(self):
        self.process = None  # Popen handle of the server started by this manager
        self.thread = None  # Daemon thread that waits on (reaps) the server process
        self.port = 5001  # First port to try; updated to the port actually used
        self.host = '0.0.0.0'
        self.backend_store_uri = './mlruns'

    def is_port_available(self, port):
        """Check if a port is available."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('localhost', port))
                return True
            except OSError:
                return False

    def is_mlflow_running(self):
        """Check if MLflow UI is actually responding."""
        try:
            response = requests.get(f'http://localhost:{self.port}', timeout=2)
        except requests.RequestException:
            # Connection refused, timeout, etc.: the server is not reachable
            return False
        return response.status_code == 200

    def find_available_port(self, start_port=5001):
        """Find an available port among the 10 ports starting from start_port.

        Returns:
            The first available port, or None if all 10 are taken.
        """
        for port in range(start_port, start_port + 10):
            if self.is_port_available(port):
                return port
        return None

    def check_mlflow_installed(self):
        """Check if MLflow is installed."""
        return shutil.which('mlflow') is not None

    def _display_running_link(self, quiet):
        """Render the link to the running MLflow UI (compact when `quiet`)."""
        if quiet:
            # Bright, visible link for quiet mode
            display(HTML(f'''
                        <a href="http://localhost:{self.port}" target="_blank" 
                           style="color: #1976d2; font-weight: bold; font-size: 16px; text-decoration: underline;">
                           🔗 MLflow UI (Port {self.port})
                        </a>
                    '''))
        else:
            # Success message with high contrast colors
            display(HTML(f'''
                        <div style="background-color: #c8e6c9; border: 2px solid #388e3c; padding: 15px; border-radius: 8px; margin: 10px 0;">
                            <div style="color: #1b5e20; font-weight: bold; font-size: 16px; margin-bottom: 10px;">
                                ✅ MLflow UI is running successfully!
                            </div>
                            <a href="http://localhost:{self.port}" target="_blank" 
                               style="background-color: #1976d2; color: white; padding: 12px 24px; text-decoration: none; border-radius: 6px; font-weight: bold; font-size: 14px; display: inline-block; margin: 5px 0;">
                                🔗 Open MLflow UI
                            </a>
                            <div style="margin-top: 10px;">
                                <div style="color: #424242; font-size: 13px;">URL: http://localhost:{self.port}</div>
                            </div>
                        </div>
                    '''))

    def start_ui(self, auto_open=True, quiet=False):
        """Start MLflow UI with better error handling and user feedback.

        If this manager already started a server that is still alive and
        responding, that server is reused instead of launching another one.

        Args:
            auto_open: Currently unused; accepted for backward compatibility.
            quiet: If True, suppress error messages and show only a compact
                link on success.

        Returns:
            True if the UI is reachable, False otherwise.
        """
        # Do not orphan a server we already started
        if self.process is not None and self.process.poll() is None:
            if self.is_mlflow_running():
                self._display_running_link(quiet)
                return True
            # Alive but unresponsive: replace it rather than leak it
            self.process.terminate()
            self.process = None

        # Check if MLflow is installed
        if not self.check_mlflow_installed():
            if not quiet:
                display(HTML('<div style="color: #d32f2f; font-weight: bold; font-size: 14px;">❌ MLflow not installed. Run: pip install mlflow</div>'))
            return False

        # Find available port
        first_port = self.port
        available_port = self.find_available_port(first_port)
        if available_port is None:
            if not quiet:
                # Report the range that was actually scanned (10 ports from self.port)
                display(HTML(f'<div style="color: #d32f2f; font-weight: bold; font-size: 14px;">❌ No available ports found ({first_port}-{first_port + 9})</div>'))
            return False

        self.port = available_port

        # Popen does not block, so launch here: errors surface immediately and
        # self.process is set before it is polled below.
        try:
            self.process = subprocess.Popen([
                'mlflow', 'ui',
                '--host', self.host,
                '--port', str(self.port),
                '--backend-store-uri', self.backend_store_uri
            ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except (OSError, ValueError) as e:
            self.process = None
            if not quiet:
                display(HTML(f'<div style="color: #d32f2f; font-weight: bold; font-size: 14px;">❌ Error: {str(e)}</div>'))
            return False

        # Background thread only waits on the child so it gets reaped on exit
        self.thread = threading.Thread(target=self.process.wait, daemon=True)
        self.thread.start()

        # Wait and check if server started successfully
        max_wait = 10
        for _ in range(max_wait):
            time.sleep(1)
            if self.process.poll() is not None:
                break  # Server process exited; no point in waiting further
            if self.is_mlflow_running():
                self._display_running_link(quiet)
                return True

        # If we get here, server didn't start properly: don't leave it behind
        if self.process.poll() is None:
            self.process.terminate()
        self.process = None
        if not quiet:
            display(HTML('<div style="color: #d32f2f; font-weight: bold; font-size: 14px;">❌ Failed to start MLflow UI</div>'))
        return False

    def stop(self):
        """Stop the MLflow UI server started by this manager, if any."""
        if self.process:
            self.process.terminate()
            self.process = None
            display(HTML('''
                <div style="background-color: #ffecb3; border: 2px solid #f57c00; padding: 10px; border-radius: 6px;">
                    <span style="color: #e65100; font-weight: bold; font-size: 14px;">🛑 MLflow UI stopped</span>
                </div>
            '''))
        else:
            display(HTML('''
                <div style="background-color: #f0f0f0; border: 2px solid #757575; padding: 10px; border-radius: 6px;">
                    <span style="color: #424242; font-weight: bold; font-size: 14px;">ℹ️ MLflow UI is not currently running</span>
                </div>
            '''))

    def status(self):
        """Check MLflow UI status."""
        if self.is_mlflow_running():
            display(HTML(f'''
                <div style="background-color: #c8e6c9; border: 2px solid #388e3c; padding: 10px; border-radius: 6px;">
                    <div style="color: #1b5e20; font-weight: bold; font-size: 14px;">✅ MLflow UI is running</div>
                    <a href="http://localhost:{self.port}" target="_blank" 
                       style="color: #1976d2; font-weight: bold; text-decoration: underline;">
                       http://localhost:{self.port}
                    </a>
                </div>
            '''))
        else:
            display(HTML('''
                <div style="background-color: #ffcdd2; border: 2px solid #d32f2f; padding: 10px; border-radius: 6px;">
                    <div style="color: #b71c1c; font-weight: bold; font-size: 14px;">❌ MLflow UI is not running</div>
                    <div style="color: #424242; font-size: 13px; margin-top: 5px;">
                        Run <code style="background-color: #f5f5f5; padding: 2px 4px; border-radius: 3px;">mlflow_ui.start_ui()</code> to start it.
                    </div>
                </div>
            '''))