# In [None]:
import mlflow
import os
from mlflow.models.signature import infer_signature
import json

def track_mlflow_experiment(
    model,
    X_train,
    y_train,
    X_test,
    y_test,
    params: dict,
    metrics: dict,
    experiment_name: str,
    model_name: str,
    run_name: str = None,
    tags: dict = None,
    artifacts: dict = None,
    log_model: bool = True
):
    """
    Log one training run (params, metrics, tags, model, files) to MLflow.

    Selects the experiment, opens a new run, and records everything inside
    that run. Which tracking server receives the data is decided by the
    ambient MLflow configuration; nothing here sets a tracking URI.

    Args:
        model: Trained model object. It is logged through
            ``mlflow.sklearn.log_model``, so it should be compatible with
            MLflow's scikit-learn flavor and expose ``predict``.
        X_train, y_train, y_test: Accepted for interface compatibility;
            not read by this function.
        X_test: Feature data used to infer the model signature (via
            ``model.predict(X_test)``) and, through ``X_test[:5]``, as the
            logged input example. Only touched when ``log_model`` is true.
        params (dict): Hyperparameters, logged one by one. ``None`` or an
            empty dict logs nothing.
        metrics (dict): Evaluation metrics, logged one by one. ``None`` or
            an empty dict logs nothing.
        experiment_name (str): MLflow experiment to log under.
        model_name (str): Passed as ``registered_model_name`` when the
            model is logged.
        run_name (str): Optional run name.
        tags (dict): Optional tags (e.g. {"env": "dev"}).
        artifacts (dict): Optional files to log, as
            ``{artifact_subdirectory: local_path}``. Paths that do not
            exist are skipped with a printed warning.
        log_model (bool): Whether to log (and register) the model.

    Returns:
        None.
    """
    mlflow.set_experiment(experiment_name)

    with mlflow.start_run(run_name=run_name):

        # Treat a missing mapping the same as an empty one so callers can
        # pass None without special-casing.
        for key, val in (params or {}).items():
            mlflow.log_param(key, val)

        for key, val in (metrics or {}).items():
            mlflow.log_metric(key, val)

        for key, val in (tags or {}).items():
            mlflow.set_tag(key, val)

        if log_model:
            # The signature is only needed for log_model, so the extra
            # prediction pass is skipped entirely when log_model is False.
            signature = infer_signature(X_test, model.predict(X_test))
            mlflow.sklearn.log_model(
                model,
                artifact_path="model",
                signature=signature,
                input_example=X_test[:5],
                registered_model_name=model_name
            )

        # User-supplied artifacts: the dict key becomes the destination
        # sub-directory inside the run's artifact store.
        for key, path in (artifacts or {}).items():
            if os.path.exists(path):
                mlflow.log_artifact(path, artifact_path=key)
            else:
                # Still best-effort, but no longer silent: a mistyped path
                # would otherwise go unnoticed.
                print(f"⚠️ Artifact '{key}' not found at '{path}'; skipped.")

        # Best-effort capture of the environment file and the training
        # source, looked up relative to the current working directory.
        for fname in ("requirements.txt", "train.py", "notebook.ipynb", "main.py"):
            if os.path.exists(fname):
                mlflow.log_artifact(fname)

        print("✅ Logged to MLflow successfully.")


# In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Baseline random forest: fit() returns the estimator itself, so the
# construction and training can be chained into one assignment.
model = RandomForestClassifier(max_depth=5, n_estimators=100).fit(X_train, y_train)

# Score the held-out split.
y_pred = model.predict(X_test)

# Everything the estimator was configured with, plus the single metric.
params = model.get_params()
metrics = dict(accuracy=accuracy_score(y_test, y_pred))

# Record the run (and register the model) in MLflow.
track_mlflow_experiment(
    model=model,
    X_train=X_train, y_train=y_train,
    X_test=X_test, y_test=y_test,
    params=params, metrics=metrics,
    experiment_name="sagemaker-experiment",
    run_name="baseline_rf_sagemaker",
    model_name="rf-prod-model",
    tags=dict(user="prasad", project="churn"),
    artifacts=dict(plots="output/roc_curve.png"),
)
