In [2]:
import ast

import pandas as pd

from snowflake.ml.registry import Registry
from snowflake.ml.modeling.model_selection import GridSearchCV

In [None]:
def iterate_model_name(df: pd.DataFrame, model_name: str) -> str:
    """
    Determine the next model version name based on the provided model name
    and current versions in the DataFrame.

    Looks up ``model_name`` in the ``name`` column (exact match first, then a
    case-insensitive match, because Snowflake upper-cases unquoted identifiers
    and may list "category_1_model" as "CATEGORY_1_MODEL"). Its ``versions``
    entry is parsed and the version with the highest numeric suffix is
    incremented, keeping that version's prefix
    (e.g. ``["V_9", "V_10"]`` -> ``"V_11"``).

    Returns "V_1" if the DataFrame is empty, the model name is not found, or
    the model has no version of the form ``<prefix>_<integer>``.

    Args:
        df (pd.DataFrame): DataFrame containing model information, with a
            ``name`` column and a ``versions`` column holding either a list
            of version strings or its literal string form
            (e.g. '["V_1", "V_2"]').
        model_name (str): Name of the model to check and increment.

    Returns:
        str: The new model version string.
    """
    default_version = "V_1"

    if df.empty:
        return default_version

    names = df["name"].astype(str)
    matches = df[names == model_name]
    if matches.empty:
        # Fall back to a case-insensitive match for upper-cased identifiers.
        matches = df[names.str.upper() == model_name.upper()]
    if matches.empty:
        return default_version

    versions = matches["versions"].iloc[0]
    if isinstance(versions, str):
        # literal_eval only parses Python literals; it never executes code.
        versions = ast.literal_eval(versions)
    if not isinstance(versions, (list, tuple)):
        return default_version

    # Rank versions by their integer suffix: plain string sorting would put
    # "V_10" before "V_9" and hand out an already-used version name.
    numbered = []
    for version in versions:
        prefix, sep, number = str(version).rpartition("_")
        if sep and number.isdecimal():
            numbered.append((int(number), prefix))
    if not numbered:
        return default_version

    latest_number, prefix = max(numbered)
    return f"{prefix}_{latest_number + 1}"

In [None]:
# Get optimal model: the best estimator found by the grid search,
# presumably as a plain scikit-learn object via to_sklearn().
# NOTE(review): as written this calls to_sklearn() on the imported
# GridSearchCV *class*, not on a fitted search instance. Unless an earlier
# cell rebound the name GridSearchCV to a fitted search, this raises
# TypeError (missing `self`). It should reference the fitted grid-search
# object instead — TODO confirm that object's variable name.
optimal_model = GridSearchCV.to_sklearn().best_estimator_

# Register Model

In [None]:
# Registry name for the tuned model; the version suffix is derived below
# from whatever versions the registry already holds under this name.
model_name = "category_1_model"

# Sample rows (at most 100) passed to reg.log_model as sample_input_data.
# "category_1_pct" (presumably the target) and "id" are dropped so only
# feature columns remain.
X = train_df.drop("category_1_pct", "id").limit(100)

# Open the model registry on the active session and list registered models.
reg = Registry(session=session)
reg_df = reg.show_models()

# Next unused version name for model_name.
model_version = iterate_model_name(reg_df, model_name)

# Log Model, Metrics and Hyperparameters

In [None]:
# Register the tuned estimator as version `model_version` of `model_name`.
# The returned handle is used by the following cells to attach metrics.
category_1_model = reg.log_model(
    model=optimal_model,
    model_name=model_name,
    version_name=model_version,
    sample_input_data=X,
)

In [None]:
# Attach the `mae` score (presumably mean absolute error, computed in an
# earlier cell) to the newly logged model version.
category_1_model.set_metric(metric_name="mae", value=mae)

In [None]:
# Record the estimator's hyperparameters on the model version.
# Only unset (None) parameters are dropped; legitimate falsy settings such as
# 0, 0.0 or False (e.g. a zero regularization strength or random_state=0)
# are kept, whereas a plain truthiness filter would silently discard them.
# "missing" is excluded as before — NOTE(review): presumably because it holds
# a NaN sentinel (XGBoost-style) that may not serialize cleanly; confirm.
hyperparameters = {
    param: value
    for param, value in optimal_model.get_params().items()
    if value is not None and param != "missing"
}
category_1_model.set_metric(
    metric_name="hyperparameters",
    value=hyperparameters,
)