# In [None]:
import os
import pickle
import sys
import tempfile

import mlflow
import mlflow.sklearn
from mlflow.models.signature import infer_signature
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score

# --- Configuration -----------------------------------------------------------
# Artifact sub-directory (within the data-prep run) holding the pickled splits.
DATA_ARTIFACT_PATH: str = "prepared_data"
# Artifact path under which the trained model is logged in the training run.
MODEL_ARTIFACT_PATH: str = "house_price_model"
# Preferred experiment location; setup_mlflow falls back to a plain name if it fails.
MLFLOW_EXPERIMENT_PATH: str = "/Shared/mlops_house_price_prediction_experiment"

def setup_mlflow():
    """Configure the MLflow backend and select this pipeline's experiment.

    When running on Databricks (detected via the ``DATABRICKS_RUNTIME_VERSION``
    environment variable) the tracking URI is set to ``databricks`` and the
    model registry URI to ``databricks-uc``; elsewhere the existing tracking
    configuration is left untouched.

    The experiment ``MLFLOW_EXPERIMENT_PATH`` is selected first. If that fails
    (for example because the ``/Shared`` workspace path is not usable on the
    current backend), a plain experiment name is used instead.
    """
    if "DATABRICKS_RUNTIME_VERSION" in os.environ:
        mlflow.set_tracking_uri("databricks")
        mlflow.set_registry_uri("databricks-uc")
    try:
        mlflow.set_experiment(MLFLOW_EXPERIMENT_PATH)
    except Exception as exc:
        # Catch Exception (not a bare except) so Ctrl-C / sys.exit still
        # propagate, and surface the reason instead of silently falling back.
        print(
            f"⚠️ Could not use experiment '{MLFLOW_EXPERIMENT_PATH}' ({exc}); "
            "falling back to 'mlops_house_price_prediction_experiment'."
        )
        mlflow.set_experiment("mlops_house_price_prediction_experiment")

def _load_prepared_data(data_dir):
    """Unpickle the train/test splits stored in *data_dir*.

    Returns ``(X_train, X_test, y_train, y_test)``, read from the files
    ``X_train.pkl``, ``X_test.pkl``, ``y_train.pkl`` and ``y_test.pkl``.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the
    file. This is acceptable only because the files are artifacts of our own
    data-prep run; never point this at untrusted data.
    """
    splits = []
    for name in ("X_train", "X_test", "y_train", "y_test"):
        with open(os.path.join(data_dir, f"{name}.pkl"), "rb") as f:
            splits.append(pickle.load(f))
    return tuple(splits)


def run_model_training(data_prep_run_id):
    """Train, evaluate and log a RandomForest model in a new MLflow run.

    Args:
        data_prep_run_id: Run ID of the data-preparation step. Its
            ``DATA_ARTIFACT_PATH`` artifacts provide the pickled splits, and the
            ID is recorded as the ``data_prep_run_id`` tag on the new run.

    Returns:
        ``(training_run_id, r2)``: the ID of the new training run and the R²
        score of the model on the test split.
    """
    setup_mlflow()

    # Tag the run with the data-prep run ID so lineage is traceable.
    with mlflow.start_run(run_name="02_Model_Training", tags={"data_prep_run_id": data_prep_run_id}) as run:
        print(f"🚀 MLflow Training run started: {run.info.run_id}")
        print(f"🔗 Using data from Run ID: {data_prep_run_id}")

        # 1. Load data from the data-prep run's artifacts.
        # A private temporary directory (instead of a fixed ./temp_data) avoids
        # clashes between concurrent runs and is removed even if loading fails.
        client = mlflow.tracking.MlflowClient()
        with tempfile.TemporaryDirectory() as temp_dir:
            # download_artifacts returns the local path of the downloaded directory.
            local_data_dir = client.download_artifacts(
                run_id=data_prep_run_id,
                path=DATA_ARTIFACT_PATH,
                dst_path=temp_dir,
            )
            X_train, X_test, y_train, y_test = _load_prepared_data(local_data_dir)

        print(f"✅ Data loaded successfully. Training on {len(X_train)} samples.")

        # 2. Train the model.
        n_estimators = 150  # changed from the default for demo purposes
        random_state = 42

        model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state)
        model.fit(X_train, y_train)

        # 3. Evaluate on the held-out split.
        predictions = model.predict(X_test)
        mse = mean_squared_error(y_test, predictions)
        r2 = r2_score(y_test, predictions)

        # 4. Log params, metrics and the model artifact.
        mlflow.log_params({"n_estimators": n_estimators, "random_state": random_state})
        mlflow.log_metrics({"mse": mse, "r2_score": r2})
        print(f"📊 Metrics - MSE: {mse:.2f}, R²: {r2:.4f}")

        # Only the column/type schema is inferred, so which split's rows are
        # passed does not matter.
        signature = infer_signature(X_train, predictions)

        # Log the model only as a run artifact; registration is deliberately
        # left to a later step, so no registered_model_name is given.
        model_info = mlflow.sklearn.log_model(
            sk_model=model,
            artifact_path=MODEL_ARTIFACT_PATH,
            signature=signature,
        )
        print(f"✅ Model logged successfully as a run artifact: {model_info.model_uri}")

        print("\n--- Model Training Complete ---")
        print(f"Training Run ID: {run.info.run_id}")
        return run.info.run_id, r2

if __name__ == "__main__":
    # The Run ID of the previous (data-prep) step must be passed as the first
    # argument. For a quick local test you may instead hard-code a known
    # data-prep run ID for data_run_id below.
    if len(sys.argv) < 2:
        print("❌ Error: Data Preparation Run ID required as argument.", file=sys.stderr)
        sys.exit(1)

    data_run_id = sys.argv[1]

    # Do NOT unpack into `r2_score`: at module level that would rebind sklearn's
    # r2_score function to a float and break any later run_model_training call
    # in the same session (e.g. from another notebook cell).
    training_run_id, _ = run_model_training(data_run_id)

    # The next pipeline step needs this ID.
    print(f"💡 The Training Run ID for next step is: {training_run_id}")
