# In [None]:
# Databricks notebook source
import mlflow
from mlflow.tracking import MlflowClient
import sys
import os

# ====================== CONFIGURATION ========================= #
# Resolve the experiment name from the Databricks widget when one is
# available; otherwise fall back to the shared default experiment.
# `except Exception` (instead of a bare `except:`) still covers the expected
# failures -- `dbutils` being undefined outside Databricks, or the widget not
# being declared -- without swallowing SystemExit / KeyboardInterrupt.
try:
    EXPERIMENT_NAME = dbutils.widgets.get("experiment_name")
    print(f"‚úì Experiment Name from widget: {EXPERIMENT_NAME}")
except Exception:
    EXPERIMENT_NAME = "/Shared/House_Price_Prediction_Config_Runs"
    print(f"‚Ñπ Using default experiment: {EXPERIMENT_NAME}")

# Unity Catalog location (catalog and schema) used to build the full model name.
UC_CATALOG = "workspace"
UC_SCHEMA = "ml"

# =================== MODEL CONFIG METADATA ==================== #
# Per-model-type registration settings, keyed by model type.
# Each entry provides:
#   model_name    - model name without the catalog/schema prefix
#   artifact_path - artifact sub-path appended to the run URI when registering
#   param_keys    - run parameters kept when describing/comparing a run
#   metric_key    - metric used to rank runs
#   keywords      - substrings looked for in the lower-cased experiment name
MODEL_CONFIG = {
    "xgboost": {
        "model_name": "house_price_xgboost_uc2",
        "artifact_path": "xgboost_model",
        "param_keys": [
            "n_estimators",
            "max_depth",
            "learning_rate",
            "subsample",
            "colsample_bytree",
        ],
        "metric_key": "test_rmse",
        "keywords": [
            "xgboost",
            "xgb",
        ],
    },
}

# ================== MODEL TYPE DETECTION ====================== #
def detect_model_config(experiment_name: str):
    """Pick the model configuration whose keyword appears in the experiment name.

    The experiment name is lower-cased and each entry of ``MODEL_CONFIG`` is
    checked in order; the first entry with any keyword contained in the name
    is selected and its details are printed.

    Args:
        experiment_name: Name (path) of the MLflow experiment.

    Returns:
        A 4-tuple ``(full_uc_name, artifact_path, param_keys, metric_key)``,
        where ``full_uc_name`` is ``<UC_CATALOG>.<UC_SCHEMA>.<model_name>``.

    Raises:
        ValueError: If no configured keyword occurs in the experiment name.
    """
    haystack = experiment_name.lower()

    for model_type, cfg in MODEL_CONFIG.items():
        # Skip model types none of whose keywords occur in the name.
        if not any(keyword in haystack for keyword in cfg["keywords"]):
            continue

        full_uc_name = "{}.{}.{}".format(UC_CATALOG, UC_SCHEMA, cfg["model_name"])
        print(f"‚úì Detected model type: {model_type.upper()}")
        print(f"‚úì UC Model Name: {full_uc_name}")
        return full_uc_name, cfg["artifact_path"], cfg["param_keys"], cfg["metric_key"]

    raise ValueError("‚ùå No matching model config found based on experiment name!")

# Settings derived from the experiment name; raises ValueError at import time
# when the experiment name matches no configured model type.
(
    REGISTERED_MODEL_NAME,
    ARTIFACT_PATH,
    PARAM_KEYS,
    METRIC_KEY,
) = detect_model_config(EXPERIMENT_NAME)

# Absolute tolerance used when comparing metric values of two runs.
TOL = 1e-6

# ====================== UTILITIES ====================== #
def normalize(val):
    """Coerce a parameter value to ``int``, ``float`` or ``str``.

    The text form of ``val`` is first parsed as an integer (an optional sign
    is accepted, so ``"-3"`` yields ``-3`` rather than ``-3.0``). If that
    fails, ``float(val)`` is attempted. If that fails too, the text form is
    returned unchanged.

    Args:
        val: Value to normalise, typically a parameter string read from a run.

    Returns:
        An ``int``, a ``float``, or ``str(val)`` when the value is not numeric.
    """
    text = str(val)

    try:
        return int(text)
    except ValueError:
        # Not integer-looking text; fall through to the float attempt.
        pass

    try:
        return float(val)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric input (e.g. None or arbitrary text) is kept as a string.
        return text

# ================== FIND BEST RUN ====================== #
def get_best_run(client):
    """Locate the run with the smallest ``METRIC_KEY`` value in the experiment.

    Looks up ``EXPERIMENT_NAME``, requests up to 1000 runs ordered by the
    metric in ascending order, prints every run that logged the metric and
    keeps the one with the lowest value.

    Args:
        client: MLflow tracking client used to query the experiment and runs.

    Returns:
        ``(run_id, params, metrics)`` for the best run, where ``params`` holds
        only the keys listed in ``PARAM_KEYS`` (normalised) and ``metrics`` is
        the run's metrics mapping. ``(None, {}, {})`` is returned when the
        experiment is missing, has no runs, or no run logged the metric.
    """
    experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if not experiment:
        print("‚ùå Experiment not found.")
        return None, {}, {}

    # Ascending order on the metric puts the lowest values first in the result.
    candidate_runs = client.search_runs(
        [experiment.experiment_id],
        order_by=[f"metrics.{METRIC_KEY} ASC"],
        max_results=1000,
    )
    if not candidate_runs:
        print("‚ö† No runs found in experiment.")
        return None, {}, {}

    winner = None
    lowest = float("inf")
    for run in candidate_runs:
        score = run.data.metrics.get(METRIC_KEY)
        if score is None:
            # Runs that never logged the metric cannot be ranked.
            continue

        label = run.info.run_name or run.info.run_id[:8]
        print(f"  üìä Run: {label} | {METRIC_KEY}: {score:.4f}")

        if score < lowest:
            winner, lowest = run, score

    if not winner:
        print("‚ö† No valid runs with metric found.")
        return None, {}, {}

    # Keep only the parameters relevant to this model type.
    params = {}
    for key, raw_value in winner.data.params.items():
        if key in PARAM_KEYS:
            params[key] = normalize(raw_value)
    metrics = winner.data.metrics

    rule = "=" * 70
    print(f"\n{rule}")
    print("üèÜ BEST RUN IDENTIFIED:")
    print(rule)
    print(f"   Run Name: {winner.info.run_name or 'N/A'}")
    print(f"   Run ID: {winner.info.run_id}")
    print(f"   {METRIC_KEY}: {lowest:.6f}")
    print(f"   Parameters: {params}")
    print(f"{rule}\n")

    return winner.info.run_id, params, metrics

# ================ DUPLICATE VERSION CHECK ===================== #
def check_duplicate(client, new_params, new_metrics):
    """Look for a registered version equivalent to the candidate run.

    Every version of ``REGISTERED_MODEL_NAME`` is compared with the candidate:
    the version's source run must have equal (normalised) values for every
    key in ``new_params`` and a ``METRIC_KEY`` value within ``TOL`` of the
    candidate's.

    Args:
        client: MLflow client used to list model versions and fetch runs.
        new_params: Normalised parameters of the candidate run.
        new_metrics: Metrics mapping of the candidate run.

    Returns:
        The first matching model version, or ``None`` when there is no match,
        no versions exist, or the version search itself fails.
    """
    try:
        versions = client.search_model_versions(f"name = '{REGISTERED_MODEL_NAME}'")
    except Exception as e:
        # Best effort: a failed lookup is treated as "nothing registered yet".
        print(f"‚Ñπ No existing model versions found (this may be first registration): {e}")
        return None

    if not versions:
        return None

    candidate_metric = new_metrics.get(METRIC_KEY, None)

    for mv in versions:
        try:
            source_run = client.get_run(mv.run_id)
        except Exception as e:
            # A version whose run cannot be loaded is skipped, not fatal.
            print(f"‚ö† Could not fetch run {mv.run_id}: {e}")
            continue

        registered_params = {}
        for key, raw_value in source_run.data.params.items():
            if key in new_params:
                registered_params[key] = normalize(raw_value)
        registered_metric = source_run.data.metrics.get(METRIC_KEY, None)

        params_match = all(
            registered_params.get(key) == new_params.get(key) for key in new_params
        )
        # Both metrics must be present and agree within the tolerance.
        metric_match = (
            registered_metric is not None
            and candidate_metric is not None
            and abs(registered_metric - candidate_metric) <= TOL
        )
        if not (params_match and metric_match):
            continue

        print("\n‚è≠Ô∏è DUPLICATE DETECTED!")
        print(f"   Existing Version: {mv.version}")
        print(f"   Run ID: {mv.run_id}")
        print("   This model is already registered with same params & performance.")
        return mv

    return None

# ================== REGISTER MODEL LOGIC ======================= #
def register_model(client, run_id, params, metrics):
    """Register the given run's model under ``REGISTERED_MODEL_NAME``.

    If ``check_duplicate`` reports an equivalent version, that version is
    returned and nothing is registered. Otherwise the artifact at
    ``runs:/<run_id>/<ARTIFACT_PATH>`` is registered and the new version is
    tagged with the source run id, experiment name, metric value and one
    ``param_<name>`` tag per parameter.

    Args:
        client: MLflow client used for the duplicate check and for tagging.
        run_id: ID of the run whose model artifact is registered.
        params: Normalised parameters of the run (written as tags).
        metrics: Metrics mapping of the run.

    Returns:
        The existing duplicate version, or the newly created model version.

    Note:
        Any exception raised while registering or tagging is printed with its
        traceback and the process exits with status 1.
    """
    existing = check_duplicate(client, params, metrics)
    if existing:
        print(f"‚úÖ Using existing registered version: {existing.version}")
        return existing

    model_uri = f"runs:/{run_id}/{ARTIFACT_PATH}"
    print("\n‚è≥ Registering new model version...")
    print(f"   Model URI: {model_uri}")
    print(f"   Target: {REGISTERED_MODEL_NAME}")

    rule = "=" * 70
    try:
        registered = mlflow.register_model(model_uri, REGISTERED_MODEL_NAME)

        print(f"\n{rule}")
        print("‚úÖ MODEL REGISTERED SUCCESSFULLY!")
        print(rule)
        print(f"   Model Name: {REGISTERED_MODEL_NAME}")
        print(f"   Version: {registered.version}")
        print(f"   Source Run ID: {run_id}")
        print(f"   {METRIC_KEY}: {metrics.get(METRIC_KEY, 'N/A')}")
        print(f"{rule}\n")

        # Provenance tags first, then one tag per parameter, in that order.
        tags = [
            ("source_run_id", run_id),
            ("experiment_name", EXPERIMENT_NAME),
            ("metric_rmse", str(metrics.get(METRIC_KEY, ""))),
        ]
        tags.extend((f"param_{name}", str(value)) for name, value in params.items())

        for tag_key, tag_value in tags:
            client.set_model_version_tag(
                REGISTERED_MODEL_NAME,
                registered.version,
                tag_key,
                tag_value,
            )

        return registered

    except Exception as e:
        print(f"‚ùå Registration Failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

# ============================ MAIN ============================ #
if __name__ == "__main__":
    # Script entry point: print the configuration, find the best run of the
    # experiment and register its model; exit with status 1 if no run qualifies.
    print(f"\n{'=' * 70}")
    print("üöÄ MODEL REGISTRATION - BEST RUN SELECTION (MLflow + UC)")
    print(f"{'=' * 70}\n")

    client = MlflowClient()

    print(
        "üìã Configuration:",
        f"   Experiment: {EXPERIMENT_NAME}",
        f"   Target Model: {REGISTERED_MODEL_NAME}",
        f"   Metric to optimize: {METRIC_KEY} (lower is better)",
        f"   Artifact Path: {ARTIFACT_PATH}\n",
        sep="\n",
    )

    print("üîç Searching for best run...")
    run_id, params, metrics = get_best_run(client)

    # Guard clause: nothing to register without a usable run.
    if not run_id:
        print("‚ùå No valid best run found. Exiting.")
        sys.exit(1)

    register_model(client, run_id, params, metrics)
    print("\n‚ú® Registration process completed successfully!")