In [0]:
import mlflow
import mlflow.sklearn
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

In [0]:
# Location of the Delta table and the columns used by this experiment
EVENTS_PATH = "/Volumes/workspace/ecommerce/silver/events_delta"
FEATURE_COLS = ["user_id"]
TARGET_COL = "price"

# Load the Delta table as a Spark DataFrame
df = spark.read.format("delta").load(EVENTS_PATH)

# Bring at most 10,000 rows of the two needed columns into pandas.
# NOTE(review): limit() is applied without an explicit ordering, so the rows
# returned may not be the same on every run -- confirm this is acceptable.
pdf = df.select(TARGET_COL, *FEATURE_COLS).limit(10000).toPandas()

# Feature matrix (kept as a one-column DataFrame) and target series
X = pdf[FEATURE_COLS]
y = pdf[TARGET_COL]

# Hold out 20% of the rows for evaluation; the fixed seed makes the split
# repeatable for a given `pdf`
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [0]:
# Select the MLflow experiment that the runs in this notebook are recorded under
mlflow.set_experiment("/Shared/simple_regression_experiment")

# First run: fit a linear regression on the current train/test split and
# record its parameters, RMSE and the fitted model in MLflow.
with mlflow.start_run():
    # Fit the model on the training split
    model = LinearRegression()
    model.fit(X_train, y_train)

    # Predict on the held-out split
    preds = model.predict(X_test)

    # Compute RMSE manually since the 'squared' argument is not supported
    mse = mean_squared_error(y_test, preds)
    rmse = mse ** 0.5

    # Log parameters
    mlflow.log_param("model", "LinearRegression")
    mlflow.log_param("feature", "user_id")
    # Record the split sizes so runs that differ only in their train/test
    # split can be told apart in the MLflow UI
    mlflow.log_param("n_train", len(X_train))
    mlflow.log_param("n_test", len(X_test))

    # Log metric
    mlflow.log_metric("rmse", rmse)

    # Log the model together with an inferred input/output signature so the
    # expected schema is stored alongside the artifact
    signature = mlflow.models.infer_signature(X_test, preds)
    mlflow.sklearn.log_model(model, "regression_model", signature=signature)

    # Display root mean square error value
    print("RMSE:", rmse)



RMSE: 360.9539860195898


In [0]:
# Re-split the same X / y with a larger hold-out (30%) for the next run.
# This rebinds the split variables used by the run cell that follows.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

In [0]:
# Second run: same model and feature as the first run, evaluated on the
# current (re-split) train/test data and tagged as a separate variant.
with mlflow.start_run():
    # Fit the model on the training split
    model = LinearRegression()
    model.fit(X_train, y_train)

    # Predict on the held-out split
    preds = model.predict(X_test)

    # Compute RMSE manually since the 'squared' argument is not supported
    mse = mean_squared_error(y_test, preds)
    rmse = mse ** 0.5

    # Log parameters
    mlflow.log_param("model", "LinearRegression")
    mlflow.log_param("feature", "user_id")
    # Mark this as the second experimental run
    mlflow.log_param("run_variant", "second_run")
    # Record the split sizes so runs that differ only in their train/test
    # split can be told apart in the MLflow UI
    mlflow.log_param("n_train", len(X_train))
    mlflow.log_param("n_test", len(X_test))

    # Log metric
    mlflow.log_metric("rmse", rmse)

    # Log the model together with an inferred input/output signature so the
    # expected schema is stored alongside the artifact
    signature = mlflow.models.infer_signature(X_test, preds)
    mlflow.sklearn.log_model(model, "regression_model", signature=signature)

    # Display root mean square error value
    print("RMSE:", rmse)



RMSE: 352.89316096883294
