# In [None]:
import mlflow
from mlflow.tracking import MlflowClient
import os
import sys

# Configuration
# Name of the registered model that new versions are created under.
REGISTERED_MODEL_NAME = "workspace.ml.house_price_model" # Unity Catalog format
# Artifact sub-path inside the training run; combined with the run ID to form
# the "runs:/<run_id>/<path>" source URI of the model being registered.
MODEL_ARTIFACT_PATH = "house_price_model"
R2_THRESHOLD = 0.7 # Minimum R2 score a run must reach for its model to be registered

def setup_mlflow():
    """Configure MLflow URIs when executing on a Databricks runtime.

    If the ``DATABRICKS_RUNTIME_VERSION`` environment variable is present,
    the tracking URI is set to ``"databricks"`` and the registry URI to
    ``"databricks-uc"``. Otherwise MLflow's configuration is left untouched.
    """
    on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION") is not None
    if not on_databricks:
        return

    mlflow.set_tracking_uri("databricks")
    mlflow.set_registry_uri("databricks-uc")

def get_run_params(client, run_id):
    """Fetch the compared hyperparameters and the R2 metric of a run.

    Args:
        client: Object exposing ``get_run(run_id)`` (an ``MlflowClient``).
        run_id: ID of the run to inspect.

    Returns:
        A ``(params, r2)`` tuple. ``params`` holds only the
        ``n_estimators`` and ``random_state`` entries logged on the run;
        ``r2`` is the run's ``r2_score`` metric, or ``0.0`` when the run
        has no such metric.
    """
    # Only these parameters take part in the comparison between runs.
    tracked_keys = ("n_estimators", "random_state")

    run_payload = client.get_run(run_id).data

    selected = {}
    for name, value in run_payload.params.items():
        if name in tracked_keys:
            selected[name] = value

    r2 = run_payload.metrics.get("r2_score", 0.0)
    return selected, r2

def _list_active_versions(client, model_name):
    """Return every non-archived version registered under ``model_name``.

    ``search_model_versions`` is used rather than the stage-based
    ``get_latest_versions`` because the Unity Catalog registry rejects
    stage-based calls. Archived versions are skipped so the result matches
    the 'None' / 'Staging' / 'Production' stages queried previously.
    """
    versions = []
    page_token = None
    while True:
        page = client.search_model_versions(
            f"name='{model_name}'", page_token=page_token
        )
        versions.extend(
            v for v in page if getattr(v, "current_stage", None) != "Archived"
        )
        # Results are paged; keep fetching until no continuation token remains.
        page_token = getattr(page, "token", None)
        if not page_token:
            return versions


def _promote_to_staging(client, model_name, version):
    """Mark ``version`` as the staging candidate in the active registry.

    Unity Catalog has no stages, so a ``staging`` alias is assigned there;
    other registries keep the stage transition.
    """
    registry_uri = mlflow.get_registry_uri() or ""
    if registry_uri.startswith("databricks-uc"):
        # Assigning an alias moves it off whichever version held it before.
        client.set_registered_model_alias(model_name, "staging", version)
        print(f"➡️ Model Version {version} assigned alias 'staging'.")
    else:
        client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage="Staging",
            archive_existing_versions=True,
        )
        print(f"➡️ Model Version {version} moved to 'Staging'.")


def run_model_registration(training_run_id):
    """Compare a training run with the newest registered version and register it if warranted.

    The run is registered (and promoted to staging) only when all of the
    following hold:

    * its R2 score is at least ``R2_THRESHOLD``;
    * no version is registered yet, or its tracked parameters differ from
      those of the newest registered version and its R2 score is strictly
      higher.

    Args:
        training_run_id: MLflow run ID of the training step whose model
            artifact (under ``MODEL_ARTIFACT_PATH``) should be registered.

    Returns:
        None. Progress and any error are reported on stdout; exceptions
        raised while talking to MLflow are caught and printed, not
        propagated.
    """
    setup_mlflow()
    client = MlflowClient()

    try:
        # Step 1: Extract the current run's parameters and metric.
        current_params, current_r2 = get_run_params(client, training_run_id)
        current_r2 = float(current_r2)
        print(f"Current Run ({training_run_id}) Metrics: R² = {current_r2:.4f}, Params: {current_params}")

        # Step 2: Performance gate.
        if current_r2 < R2_THRESHOLD:
            print(f"❌ R² score ({current_r2:.4f}) is below the threshold ({R2_THRESHOLD}). Skipping registration.")
            return

        # Step 3: Compare against the newest registered version.
        registered_versions = _list_active_versions(client, REGISTERED_MODEL_NAME)
        should_register = True

        if registered_versions:
            # Version numbers may arrive as strings; compare numerically so
            # that e.g. version 10 ranks above version 9.
            latest_version = max(registered_versions, key=lambda v: int(v.version))
            latest_params, latest_r2 = get_run_params(client, latest_version.run_id)
            latest_r2 = float(latest_r2)

            print(f"Latest Registered Version {latest_version.version}: R² = {latest_r2:.4f}, Params: {latest_params}")

            # Identical parameters: nothing new to register.
            if latest_params == current_params:
                print("⏭️ Parameters are the same. Skipping model registration.")
                should_register = False

            # Different parameters but no improvement in R2: skip as well.
            elif current_r2 <= latest_r2:
                print(f"⏭️ New R² ({current_r2:.4f}) is not better than the latest registered R² ({latest_r2:.4f}). Skipping registration.")
                should_register = False

        else:
            print(f"📝 First time registration for model: {REGISTERED_MODEL_NAME}.")

        # Step 4: Conditional registration.
        if should_register:
            # The runs:/ URI lets MLflow resolve the artifact location itself.
            source_uri = f"runs:/{training_run_id}/{MODEL_ARTIFACT_PATH}"

            new_version = client.create_model_version(
                name=REGISTERED_MODEL_NAME,
                source=source_uri,
                run_id=training_run_id,
            )

            print(f"✅ Model registered successfully as Version {new_version.version} in {REGISTERED_MODEL_NAME}.")

            # Promote the freshly registered version straight away.
            _promote_to_staging(client, REGISTERED_MODEL_NAME, new_version.version)

        else:
            print("--- Registration/Promotion skipped. ---")

    except Exception as e:
        # Errors are reported, not raised, so the caller sees only the printed message.
        print(f"❌ Error during model registration/promotion: {e}")
        print("⚠️ Ensure the model name exists in Unity Catalog and permissions are correct.")


if __name__ == "__main__":
    # The run ID of the preceding training step must be the first CLI argument.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("❌ Error: Model Training Run ID required as argument.")
        sys.exit(1)

    training_run_id = cli_args[0]
    run_model_registration(training_run_id)
