In [None]:
# Databricks notebook source
import mlflow
from mlflow.tracking import MlflowClient
import sys
import os

# ==================== CONFIGURATION ====================

# ⚙️ Automatically detect experiment name
try:
    EXPERIMENT_NAME = dbutils.widgets.get("experiment_name")
    print(f"✓ Experiment name from widget: {EXPERIMENT_NAME}")
except Exception:
    # Any failure to read the widget falls back to the default. This includes
    # `dbutils` being undefined (NameError) when the file runs outside a
    # notebook. `Exception` rather than a bare `except` so that
    # KeyboardInterrupt / SystemExit are not swallowed.
    EXPERIMENT_NAME = "/Shared/House_Price_Prediction_Delta_XGBoost"
    print(f"ℹ️ Using default experiment name: {EXPERIMENT_NAME}")

# Catalog and schema used as the "<catalog>.<schema>." prefix of every
# registered model name built below.
UC_CATALOG_NAME = "workspace"
UC_SCHEMA_NAME = "ml"

# ==================== MODEL CONFIGS ====================
def _model_config(model_name, artifact_path, param_keys, metric_key, keywords):
    """Assemble one MODEL_CONFIGS entry from its five fields."""
    return {
        "model_name": model_name,
        "artifact_path": artifact_path,
        "param_keys": param_keys,
        "metric_key": metric_key,
        "keywords": keywords,
    }


# Per-model-family settings, keyed by model type.
#   model_name    - name suffix used when building the registered model name
#   artifact_path - artifact sub-path used when building the runs:/ model URI
#   param_keys    - run parameters kept for the duplicate comparison
#   metric_key    - run metric compared in the duplicate check
#   keywords      - lowercase substrings searched for in the experiment name
# Entry order matters: detection returns the first entry with a keyword hit.
MODEL_CONFIGS = {
    "xgboost": _model_config(
        model_name="house_price_xgboost_uc",
        artifact_path="xgboost_model",
        param_keys=[
            "best_n_estimators",
            "best_max_depth",
            "best_learning_rate",
            "best_subsample",
            "best_colsample_bytree",
        ],
        metric_key="test_rmse",
        keywords=["xgboost", "xgb"],
    ),
    "randomforest": _model_config(
        model_name="house_price_rf_uc",
        artifact_path="sklearn_rf_model",
        param_keys=[
            "best_n_estimators",
            "best_max_depth",
            "best_min_samples_split",
            "best_min_samples_leaf",
        ],
        metric_key="test_rmse",
        keywords=["rf", "randomforest", "random_forest"],
    ),
    "lightgbm": _model_config(
        model_name="house_price_lightgbm_uc",
        artifact_path="lightgbm_model",
        param_keys=[
            "best_n_estimators",
            "best_max_depth",
            "best_learning_rate",
            "best_num_leaves",
        ],
        metric_key="test_rmse",
        keywords=["lightgbm", "lgbm", "lgb"],
    ),
    "catboost": _model_config(
        model_name="house_price_catboost_uc",
        artifact_path="catboost_model",
        param_keys=[
            "best_iterations",
            "best_depth",
            "best_learning_rate",
            "best_l2_leaf_reg",
        ],
        metric_key="test_rmse",
        keywords=["catboost", "cat"],
    ),
    "gradientboosting": _model_config(
        model_name="house_price_gb_uc",
        artifact_path="sklearn_gb_model",
        param_keys=[
            "best_n_estimators",
            "best_max_depth",
            "best_learning_rate",
            "best_subsample",
        ],
        metric_key="test_rmse",
        keywords=["gradientboosting", "gb", "gradient_boosting"],
    ),
    "linear": _model_config(
        model_name="house_price_linear_uc",
        artifact_path="sklearn_linear_model",
        param_keys=["best_alpha", "best_fit_intercept"],
        metric_key="test_rmse",
        keywords=["linear", "ridge", "lasso", "elasticnet"],
    ),
    "decisiontree": _model_config(
        model_name="house_price_dt_uc",
        artifact_path="sklearn_dt_model",
        param_keys=[
            "best_max_depth",
            "best_min_samples_split",
            "best_min_samples_leaf",
        ],
        metric_key="test_rmse",
        keywords=["decisiontree", "dt", "decision_tree"],
    ),
    "logistic": _model_config(
        model_name="house_price_logreg_uc",
        artifact_path="sklearn_logistic_model",
        param_keys=["best_C", "best_solver", "best_max_iter"],
        metric_key="test_accuracy",
        keywords=["logistic", "logreg", "logisticregression"],
    ),
    "svm": _model_config(
        model_name="house_price_svm_uc",
        artifact_path="sklearn_svm_model",
        param_keys=["best_C", "best_kernel", "best_gamma"],
        metric_key="test_accuracy",
        keywords=["svm", "supportvectormachine"],
    ),
}

# ==================== DETECTION LOGIC ====================
def generate_model_name_from_experiment(experiment_name: str) -> tuple:
    """Map an experiment name to registration settings.

    The lowercased experiment name is searched for each MODEL_CONFIGS
    entry's keywords (plain substring test). The first entry, in
    MODEL_CONFIGS order, with any keyword hit is used.

    Args:
        experiment_name: MLflow experiment name or path.

    Returns:
        A 4-tuple ``(registered_model_name, artifact_path, param_keys,
        metric_key)``. When no keyword matches, a generic name is built from
        the last path segment of the experiment name and the tuple is
        ``(name, "model", [], "test_rmse")``.
    """
    uc_prefix = f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}"
    haystack = experiment_name.lower()

    # First config whose keyword list has at least one substring hit.
    detected = next(
        (
            (kind, spec)
            for kind, spec in MODEL_CONFIGS.items()
            if any(word in haystack for word in spec["keywords"])
        ),
        None,
    )

    if detected is not None:
        kind, spec = detected
        full_name = f"{uc_prefix}.{spec['model_name']}"
        print(f"✓ Detected Model Type: {kind.upper()}")
        print(f"✓ Registered Model Name: {full_name}")
        return full_name, spec["artifact_path"], spec["param_keys"], spec["metric_key"]

    # No predefined type matched: derive a generic name from the last path
    # segment, lowercased, with spaces, hyphens and underscores removed.
    print(f"⚠️ Warning: No predefined model type found for '{experiment_name}'")
    leaf = experiment_name.strip('/').split('/')[-1].lower()
    compact = leaf.translate(str.maketrans('', '', ' -_'))
    full_name = f"{uc_prefix}.house_price_{compact}_uc"
    print(f"✓ Generated Generic Model Name: {full_name}")
    return full_name, "model", [], "test_rmse"

# Resolved once, when this cell/module runs, from the configured experiment
# name; the functions below read these as module-level settings.
(
    REGISTERED_MODEL_NAME,
    MODEL_ARTIFACT_PATH,
    PARAM_KEYS,
    METRIC_KEY,
) = generate_model_name_from_experiment(EXPERIMENT_NAME)

# Largest absolute metric difference still treated as "same metric" by the
# duplicate check.
METRIC_TOLERANCE = 1e-6

# ==================== UTILITY FUNCTIONS ====================
def normalize_param_value(value):
    """Coerce a logged parameter value to a comparable Python value.

    Args:
        value: Raw parameter value (typically a string), or None.

    Returns:
        None if ``value`` is None; an ``int`` when the text form is an
        optionally negative run of digits with no '.'; otherwise a ``float``
        when ``float(value)`` succeeds; otherwise ``str(value)``.
    """
    if value is None:
        return None
    text = str(value)
    try:
        if '.' not in text and text.lstrip('-').isdigit():
            return int(value)
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # Not numeric (e.g. "gbtree", "--5", a list): keep the text form.
        # Only conversion errors are caught, so unrelated exceptions such as
        # KeyboardInterrupt are no longer swallowed.
        return text

def get_latest_run_info(client):
    """Fetch the most recently started FINISHED run of EXPERIMENT_NAME.

    Args:
        client: MLflow tracking client used for the experiment and run lookup.

    Returns:
        ``(run_id, params, metrics)`` for the selected run, where ``params``
        holds the normalized run parameters (restricted to PARAM_KEYS when
        that list is non-empty) and ``metrics`` is the run's metrics mapping.
        ``(None, {}, {})`` when the experiment does not exist or has no
        finished run.
    """
    experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if not experiment:
        print(f"❌ Experiment '{EXPERIMENT_NAME}' not found.")
        return None, {}, {}

    # Only FINISHED runs are eligible: without the status filter the newest
    # run could still be running or have failed, and its model artifact may
    # be missing or incomplete at registration time.
    runs = client.search_runs(
        [experiment.experiment_id],
        filter_string="attributes.status = 'FINISHED'",
        order_by=["start_time DESC"],
        max_results=1,
    )
    if not runs:
        print(f"⚠️ No finished runs found for experiment '{EXPERIMENT_NAME}'.")
        return None, {}, {}

    latest_run = runs[0]
    run_id = latest_run.info.run_id
    # An empty PARAM_KEYS (generic model type) means "keep every parameter".
    run_params = {
        key: normalize_param_value(raw)
        for key, raw in latest_run.data.params.items()
        if not PARAM_KEYS or key in PARAM_KEYS
    }
    run_metrics = latest_run.data.metrics
    print(f"✓ Latest Run ID: {run_id}")
    print(f"  Params: {run_params}")
    print(f"  Metrics: {run_metrics}")
    return run_id, run_params, run_metrics

# ==================== DUPLICATE CHECK ====================
def check_existing_version(client, current_params, current_metrics):
    """Look for a registered version whose source run matches the current one.

    A version matches when every key in ``current_params`` has an equal
    (normalized) value in the version's source run and the METRIC_KEY metric
    agrees: both runs lack it, or both have it within METRIC_TOLERANCE.

    The check is best-effort and never raises; on failure it prints a
    warning and reports "no duplicate".

    Args:
        client: MLflow client used to list model versions and fetch runs.
        current_params: Normalized parameters of the run being registered.
        current_metrics: Metrics mapping of the run being registered.

    Returns:
        The first matching model version, or None if there is none or the
        version listing failed.
    """
    # None (not a numeric sentinel) marks "metric not logged", so a missing
    # metric can never be mistaken for a real metric value.
    current_metric_value = current_metrics.get(METRIC_KEY)
    try:
        filter_string = f"name = '{REGISTERED_MODEL_NAME}'"
        versions = client.search_model_versions(filter_string=filter_string)
        if not versions:
            return None

        for version in versions:
            version_run_id = version.run_id
            if not version_run_id:
                continue

            # A single unreadable source run must not abort the whole scan;
            # otherwise a duplicate among the remaining versions goes
            # undetected and gets registered again.
            try:
                run = client.get_run(version_run_id)
            except Exception as e:
                print(f"⚠ Warning: skipping version {version.version}, run {version_run_id} unavailable: {e}")
                continue

            version_params = {
                k: normalize_param_value(v)
                for k, v in run.data.params.items()
                if k in current_params
            }
            version_metric_value = run.data.metrics.get(METRIC_KEY)

            # Compare parameters and metrics
            params_match = all(version_params.get(k) == current_params.get(k) for k in current_params)
            if current_metric_value is None or version_metric_value is None:
                # Metric missing on a side: equal only if missing on both.
                metrics_match = current_metric_value is None and version_metric_value is None
            else:
                metrics_match = abs(version_metric_value - current_metric_value) <= METRIC_TOLERANCE

            if params_match and metrics_match:
                print(f"\n⏭️ DUPLICATE DETECTED: Version {version.version} already exists with same params & metric.")
                return version
        return None
    except Exception as e:
        print(f"⚠ Warning checking duplicates: {e}")
        return None

# ==================== MODEL REGISTRATION ====================
def register_model_for_serving(client, run_id, run_params, run_metrics, model_name, artifact_path):
    """Register the run's model under ``model_name`` unless a duplicate exists.

    Args:
        client: MLflow client, passed to the duplicate check.
        run_id: ID of the run whose artifact is registered.
        run_params: Normalized run parameters (duplicate check and log output).
        run_metrics: Run metrics (duplicate check and log output).
        model_name: Registered model name to register under.
        artifact_path: Artifact sub-path of the model inside the run.

    Returns:
        The already-registered matching version when one is found, otherwise
        the newly created model version.

    Exits the process with status 1 (``sys.exit``) if registration fails.
    """
    # Step 1: reuse an equivalent version instead of creating a new one.
    duplicate = check_existing_version(client, run_params, run_metrics)
    if duplicate:
        print(f"🎯 Using existing model version: {duplicate.name} v{duplicate.version}")
        return duplicate

    # Step 2: nothing equivalent registered yet, so create a new version.
    source_uri = f"runs:/{run_id}/{artifact_path}"
    print(f"⏳ Registering model from: {source_uri}")
    try:
        registered = mlflow.register_model(model_uri=source_uri, name=model_name)
        print("\n✅ Model Registered in UC")
        summary = (
            ("Name", registered.name),
            ("Version", registered.version),
            ("Params", run_params),
            ("Metrics", run_metrics),
        )
        for label, detail in summary:
            print(f"   {label}: {detail}")
        return registered
    except Exception as e:
        print(f"❌ Registration failed: {e}")
        sys.exit(1)

# ==================== MAIN ====================
# Script entry point: print a banner, look up the latest run of the configured
# experiment, and register its model (or reuse an existing duplicate version).
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("🚀 FLEXIBLE MLFLOW MODEL REGISTRATION (Duplicate Safe)")
    print("=" * 60 + "\n")

    # Client constructed with no arguments; shared by the lookup and the
    # registration step.
    client = MlflowClient()
    run_id, params, metrics = get_latest_run_info(client)
    
    # run_id is None when the experiment or a run could not be found; in that
    # case get_latest_run_info has already printed the reason and nothing is
    # registered.
    if run_id:
        register_model_for_serving(client, run_id, params, metrics, REGISTERED_MODEL_NAME, MODEL_ARTIFACT_PATH)
