In [3]:
from snowflake.snowpark import Session, types as T
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.metrics import mean_absolute_error
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.impute import SimpleImputer

In [None]:
# Open a Snowpark session from the named connection "mlconnection"
session_builder = Session.builder.config("connection_name", "mlconnection")
session = session_builder.create()

In [None]:
# Reference the table written by the feature-engineering step
category_1_df = session.table(name="category_1_feats")

In [None]:
# Split the feature table 80/20 into train and test sets; the fixed seed
# keeps the split the same between runs.
split_weights = [0.8, 0.2]
train_df, test_df = category_1_df.random_split(weights=split_weights, seed=8)

In [None]:
# Collect the names of all string-typed columns in the feature table
cat_cols = []
for field in category_1_df.schema:
    if isinstance(field.datatype, T.StringType):
        cat_cols.append(field.name)

In [None]:
# Hyper-parameter grid for the grid search: 5 values each, 25 combinations
parameters = dict(
    n_estimators=list(range(100, 501, 100)),
    learning_rate=[0.1, 0.2, 0.3, 0.4, 0.5],
)

In [None]:
# Get current warehouse
# get_current_warehouse() returns None when the session has no active
# warehouse, and otherwise may wrap the identifier in double quotes.
# Strip only real quote characters (a blind [1:-1] slice would cut off the
# first and last letters of an unquoted name) and fail early with a clear
# message when there is nothing to resize.
raw_warehouse = session.get_current_warehouse()
if raw_warehouse is None:
    raise RuntimeError(
        "No active warehouse in the current session; "
        "select a warehouse before running the training cells."
    )
current_warehouse = raw_warehouse.strip('"')

In [None]:
# Resize the current warehouse to LARGE ahead of training
scale_up_sql = f"ALTER WAREHOUSE {current_warehouse} SET WAREHOUSE_SIZE=LARGE;"
session.sql(scale_up_sql).collect()

In [None]:
# Step 1: replace missing values in the string columns with the most
# frequent value, writing the result back to the same column names.
imputer_step = SimpleImputer(
    input_cols=cat_cols,
    output_cols=cat_cols,
    strategy="most_frequent",
    drop_input_cols=True,
)

# Step 2: grid search over `parameters` with an XGBoost regressor, scored by
# negative mean absolute error. Every column except the label and "id" is
# used as a feature.
feature_cols = train_df.drop("category_1_pct", "id").columns
grid_search_step = GridSearchCV(
    estimator=XGBRegressor(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="neg_mean_absolute_error",
    input_cols=feature_cols,
    label_cols="category_1_pct",
    output_cols="pred_category_1_pct",
)

# Assemble the two steps into one pipeline
pipeline = Pipeline(
    steps=[
        ("SimpleImputer", imputer_step),
        ("GridSearchCV", grid_search_step),
    ],
)

# Fit the whole pipeline on the training split
pipeline.fit(train_df)

In [None]:
# Score the test split. The pipeline's estimator is an XGBRegressor, which
# has no predict_proba (that is a classifier-only method), so use predict.
# predict writes its output to the configured output column
# "pred_category_1_pct", which the MAE calculation reads.
pred_result = pipeline.predict(test_df)

In [None]:
# Resize the warehouse back to SMALL now that training and scoring are done
scale_down_sql = f"ALTER WAREHOUSE {current_warehouse} SET WAREHOUSE_SIZE=SMALL;"
session.sql(scale_down_sql).collect()

In [None]:
# Mean Absolute Error (MAE) between the actual label column and the
# predicted column in the scored test set
actual_col = "category_1_pct"
predicted_col = "pred_category_1_pct"
mae = mean_absolute_error(
    df=pred_result,
    y_true_col_names=actual_col,
    y_pred_col_names=predicted_col,
)

# Continued in MLOPS notebook...