### Set up Snowpark Session

See [Configure Connections](https://docs.snowflake.com/developer-guide/snowflake-cli/connecting/configure-connections#define-connections)
for information on how to define default Snowflake connection(s) in a config.toml
file.

In [1]:
from snowflake.snowpark import Session, Row

# Build (or reuse) a Snowpark session from the default connection.
# Requires a valid ~/.snowflake/config.toml file.
session = Session.builder.getOrCreate()
print(session)

<snowflake.snowpark.session.Session: account="DEMO_ACCOUNT", role="DEMO_RL", database="DEMO_DEV", schema="DEMO_DEV", warehouse="DEMO_WH">


#### Set up Snowflake resources

In [2]:
# Create the demo schema if needed and make it the session's current schema.
schema_name = "HEADLESS_DEMO"
create_schema_sql = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
session.sql(create_schema_sql).collect()
session.use_schema(schema_name)

In [3]:
# Create compute pool
def create_compute_pool(name: str, instance_family: str, min_nodes: int = 1, max_nodes: int = 10) -> list[Row]:
    """Create a compute pool unless one with this name already exists.

    Args:
        name: Identifier of the compute pool.
        instance_family: Value used for INSTANCE_FAMILY (e.g. "CPU_X64_S").
        min_nodes: Value used for MIN_NODES.
        max_nodes: Value used for MAX_NODES.

    Returns:
        The rows produced by collecting the CREATE COMPUTE POOL statement.
    """
    # Arguments are interpolated verbatim into the DDL; pass trusted values only.
    ddl = f"""
        CREATE COMPUTE POOL IF NOT EXISTS {name}
            MIN_NODES = {min_nodes}
            MAX_NODES = {max_nodes}
            INSTANCE_FAMILY = {instance_family}
    """
    statement = session.sql(ddl)
    return statement.collect()

# Name of the compute pool reused by the remote training jobs later on.
compute_pool = "DEMO_POOL_CPU"
create_compute_pool(name=compute_pool, instance_family="CPU_X64_S")

[Row(status='DEMO_POOL_CPU already exists, statement succeeded.')]

In [4]:
# Generate synthetic data
def generate_data(table_name: str, num_rows: int, replace: bool = False) -> list[Row]:
    """Create a table filled with synthetic loan-application rows.

    Args:
        table_name: Name of the table to create; interpolated verbatim into the
            statement, so pass trusted identifiers only.
        num_rows: Number of rows to generate. Whole-number floats such as
            ``1e5`` are accepted and converted to ``int`` before being
            interpolated into the SQL.
        replace: If True, issue CREATE OR REPLACE TABLE; otherwise issue
            CREATE TABLE IF NOT EXISTS.

    Returns:
        The rows produced by collecting the CREATE TABLE statement.

    Raises:
        ValueError: If ``num_rows`` is negative or not a whole number.
    """
    # Render the row count as an integer literal (1e5 would otherwise become "100000.0").
    row_count = int(num_rows)
    if row_count != num_rows or row_count < 0:
        raise ValueError(f"num_rows must be a non-negative whole number, got {num_rows!r}")

    create_clause = "CREATE OR REPLACE TABLE" if replace else "CREATE TABLE IF NOT EXISTS"

    # Notes on the generated columns:
    #   * loan_purpose: array elements are addressed from 0 and UNIFORM bounds are
    #     inclusive, so the index is drawn from 0..7 to cover all eight values.
    #   * is_default: RANDOM() yields a 64-bit integer, so it cannot be compared
    #     with 0.15 directly; a float UNIFORM(0, 1) draw gives the intended ~15% rate.
    query = f"""
        {create_clause} {table_name} AS
        SELECT
            ROW_NUMBER() OVER (ORDER BY RANDOM()) as application_id,
            ROUND(NORMAL(40, 10, RANDOM())) as age,
            ROUND(NORMAL(65000, 20000, RANDOM())) as income,
            ROUND(NORMAL(680, 50, RANDOM())) as credit_score,
            ROUND(NORMAL(5, 2, RANDOM())) as employment_length,
            ROUND(NORMAL(25000, 8000, RANDOM())) as loan_amount,
            ROUND(NORMAL(35, 10, RANDOM()), 2) as debt_to_income,
            ROUND(NORMAL(5, 2, RANDOM())) as number_of_credit_lines,
            GREATEST(0, ROUND(NORMAL(1, 1, RANDOM()))) as previous_defaults,
            ARRAY_CONSTRUCT(
                'home_improvement', 'debt_consolidation', 'business', 'education',
                'major_purchase', 'medical', 'vehicle', 'other'
            )[UNIFORM(0, 7, RANDOM())] as loan_purpose,
            UNIFORM(0::FLOAT, 1::FLOAT, RANDOM()) < 0.15 as is_default,
            TIMEADD("MINUTE", UNIFORM(-525600, 0, RANDOM()), CURRENT_TIMESTAMP()) as created_at
        FROM TABLE(GENERATOR(rowcount => {row_count}))
        ORDER BY created_at;
    """
    return session.sql(query).collect()

table_name = "loan_applications"
# num_rows is interpolated into SQL, so pass an int (1e5 would render as 100000.0).
generate_data(table_name, 100_000)

[Row(status='LOAN_APPLICATIONS already exists, statement succeeded.')]

### Prepare Model Script

In [5]:
import json
import os
import pickle
from time import perf_counter
from typing import Literal, Optional

import pandas as pd
import xgboost as xgb
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler


from snowflake.ml.data.data_connector import DataConnector
from snowflake.ml.registry import Registry as ModelRegistry
from snowflake.snowpark import Session


def create_data_connector(session, table_name: str) -> DataConnector:
    """Build a DataConnector over the feature and label columns of a table.

    Args:
        session: Session used to run the SELECT.
        table_name: Table to read from; interpolated verbatim into the query.

    Returns:
        A DataConnector created from the DataFrame returned by ``session.sql``.
    """
    # Adjust the column list below to match your own schema.
    select_sql = f"""
    SELECT
        age,
        income,
        credit_score,
        employment_length,
        loan_amount,
        debt_to_income,
        number_of_credit_lines,
        previous_defaults,
        loan_purpose,
        is_default
    FROM {table_name}
    """
    return DataConnector.from_dataframe(session.sql(select_sql))


def build_pipeline(**model_params) -> Pipeline:
    """Create the preprocessing + XGBoost classification pipeline.

    Numerical columns are median-imputed and standardized; the categorical
    column is imputed with the constant "missing" and one-hot encoded.

    Args:
        **model_params: Overrides for the classifier's parameters. They are
            merged on top of the defaults defined below, so passing a single
            override (e.g. ``max_depth=3``) keeps the remaining defaults such
            as ``objective`` and ``random_state``.

    Returns:
        An unfitted Pipeline with steps "preprocessor" and "classifier".
    """
    # Column names are upper-case, matching the frame used in train()
    # (which reads "IS_DEFAULT").
    categorical_cols = ["LOAN_PURPOSE"]
    numerical_cols = [
        "AGE",
        "INCOME",
        "CREDIT_SCORE",
        "EMPLOYMENT_LENGTH",
        "LOAN_AMOUNT",
        "DEBT_TO_INCOME",
        "NUMBER_OF_CREDIT_LINES",
        "PREVIOUS_DEFAULTS",
    ]

    # Numerical preprocessing pipeline
    numeric_transformer = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="median")),
            ("scaler", StandardScaler()),
        ]
    )

    # Categorical preprocessing pipeline
    categorical_transformer = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="constant", fill_value="missing")),
            ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False)),
        ]
    )

    # Combine transformers
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_transformer, numerical_cols),
            ("cat", categorical_transformer, categorical_cols),
        ]
    )

    # Default classifier parameters; caller-supplied values take precedence.
    default_params = {
        "objective": "binary:logistic",
        "eval_metric": "auc",
        "max_depth": 6,
        "learning_rate": 0.1,
        "n_estimators": 100,
        "subsample": 0.8,
        "colsample_bytree": 0.8,
        "random_state": 42,
    }
    # Merge rather than replace: `model_params or default_params` dropped every
    # default as soon as one override was supplied.
    classifier_params = {**default_params, **model_params}
    model = xgb.XGBClassifier(**classifier_params)

    return Pipeline([("preprocessor", preprocessor), ("classifier", model)])


def evaluate_model(model: Pipeline, X_test: pd.DataFrame, y_test: pd.DataFrame):
    """Score a fitted pipeline on held-out data.

    Args:
        model: Fitted pipeline exposing ``predict`` and ``predict_proba``.
        X_test: Held-out feature rows.
        y_test: True labels aligned with ``X_test``.

    Returns:
        Dict with the keys "accuracy", "roc_auc" and "classification_report".
    """
    predicted_labels = model.predict(X_test)
    # The second predict_proba column is used as the score for ROC AUC.
    positive_scores = model.predict_proba(X_test)[:, 1]

    return {
        "accuracy": accuracy_score(y_test, predicted_labels),
        "roc_auc": roc_auc_score(y_test, positive_scores),
        "classification_report": classification_report(y_test, predicted_labels),
    }


def save_to_registry(
    session: Session,
    model: Pipeline,
    model_name: str,
    metrics: dict,
    sample_input_data: pd.DataFrame,
):
    """Log a trained pipeline and its metrics to the Snowflake Model Registry.

    Args:
        session: Session used to open the registry.
        model: Pipeline to log.
        model_name: Name under which the model is logged.
        metrics: Metrics passed along with the model.
        sample_input_data: Example inputs; only the first five rows are sent.
    """
    model_registry = ModelRegistry(session)

    log_kwargs = {
        "model": model,
        "model_name": model_name,
        "metrics": metrics,
        "sample_input_data": sample_input_data[:5],
        "conda_dependencies": ["xgboost"],
    }
    model_registry.log_model(**log_kwargs)


def train(session: Session, source_data: str, save_mode: Literal["local", "registry"] = "local", output_dir: Optional[str] = None, **kwargs):
    """Train, evaluate and persist the loan-default model.

    Args:
        session: Session used to read ``source_data`` and, in "registry" mode,
            to log the model.
        source_data: Name of the table to train on.
        save_mode: "local" pickles the model and writes metrics.json under
            ``output_dir``; "registry" logs the model to the Model Registry.
        output_dir: Base directory for "local" mode; defaults to the current
            directory.
        **kwargs: Forwarded to ``build_pipeline`` as classifier parameters.

    Raises:
        ValueError: If ``save_mode`` is neither "local" nor "registry".
    """
    # Fail fast: an unknown mode would otherwise train and then silently save nothing.
    if save_mode not in ("local", "registry"):
        raise ValueError(f"Unsupported save_mode {save_mode!r}; expected 'local' or 'registry'")

    # Load data
    dc = create_data_connector(session, table_name=source_data)
    print("Loading data...", end="", flush=True)
    start = perf_counter()
    df = dc.to_pandas()
    elapsed = perf_counter() - start
    print(f" done! Loaded {len(df)} rows, elapsed={elapsed:.3f}s")

    # Split data
    X = df.drop("IS_DEFAULT", axis=1)
    y = df["IS_DEFAULT"]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Train model (caller-supplied keyword arguments configure the classifier)
    model = build_pipeline(**kwargs)
    print("Training model...", end="", flush=True)
    start = perf_counter()
    model.fit(X_train, y_train)
    elapsed = perf_counter() - start
    print(f" done! Elapsed={elapsed:.3f}s")

    # Evaluate model
    print("Evaluating model...", end="", flush=True)
    start = perf_counter()
    metrics = evaluate_model(
        model,
        X_test,
        y_test,
    )
    elapsed = perf_counter() - start
    print(f" done! Elapsed={elapsed:.3f}s")

    # Print evaluation results
    print("\nModel Performance Metrics:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"ROC AUC: {metrics['roc_auc']:.4f}")
    # Uncomment below for full classification report
    # print("\nClassification Report:")
    # print(metrics["classification_report"])

    start = perf_counter()
    if save_mode == "local":
        # Save model locally
        print("Saving model to disk...", end="", flush=True)
        output_dir = output_dir or '.'
        # Sub-directory is the SNOWFLAKE_SERVICE_NAME env var when set, else "output".
        model_subdir = os.environ.get("SNOWFLAKE_SERVICE_NAME", "output")
        # Avoid nesting the sub-directory twice when output_dir already ends with it.
        model_dir = os.path.join(output_dir, model_subdir) if not output_dir.endswith(model_subdir) else output_dir
        os.makedirs(model_dir, exist_ok=True)
        with open(os.path.join(model_dir, "model.pkl"), "wb") as f:
            pickle.dump(model, f)
        with open(os.path.join(model_dir, "metrics.json"), "w") as f:
            json.dump(metrics, f, indent=2)
    else:
        # save_mode == "registry" (validated above): save model to registry
        print("Logging model to Model Registry...", end="", flush=True)
        save_to_registry(
            session,
            model=model,
            model_name="loan_default_predictor",
            metrics=metrics,
            sample_input_data=X_train,
        )
    elapsed = perf_counter() - start
    print(f" done! Elapsed={elapsed:.3f}s")

### Run training locally

In [6]:
# Train in the local kernel with the default save_mode ("local").
train(session, source_data=table_name)

DataConnector.from_dataframe() is in private preview since 1.6.0. Do not use it in production. 
DataConnector.from_sql() is in private preview since 1.7.3. Do not use it in production. 


Loading data... done! Loaded 10000000 rows, elapsed=21.933s
Training model... done! Elapsed=16.225s
Evaluating model... done! Elapsed=3.842s

Model Performance Metrics:
Accuracy: 0.5003
ROC AUC: 0.5000
Saving model to disk... done! Elapsed=0.004s


### Train with remote SPCS instance


In [7]:
from snowflake.ml.jobs import remote

@remote(compute_pool, stage_name="payload_stage")
def train_remote(source_data: str, save_mode: str = "local", output_dir: str = None):
    """Run ``train`` with a session obtained inside the remote job.

    Args:
        source_data: Name of the table to train on.
        save_mode: Passed through to ``train``.
        output_dir: Passed through to ``train``.
    """
    # Retrieve session from SPCS service context
    job_session = Session.builder.getOrCreate()

    # Run training script
    train(job_session, source_data, save_mode, output_dir)

# The call returns a job handle whose id and status are inspected in the next cell.
train_job = train_remote(table_name)

remote() is in private preview since 1.7.4. Do not use it in production. 


In [8]:
# Show the job's identifier and its current status, one per line.
print(train_job.id, train_job.status, sep="\n")

MLJOB_D686BEB4_91ED_4DE7_BFD6_3FB38DCAF972
PENDING


In [9]:
# Wait on the job first, then print the logs it produced.
train_job.wait()
train_job.show_logs()

MLJob.wait() is in private preview since 1.7.4. Do not use it in production. 
MLJob.show_logs() is in private preview since 1.7.4. Do not use it in production. 
MLJob.get_logs() is in private preview since 1.7.4. Do not use it in production. 



'micromamba' is running as a subprocess and can't modify the parent shell.
Thus you must initialize your shell before using activate and deactivate.

To initialize the current bash shell, run:
    $ eval "$(micromamba shell hook --shell bash)"
and then activate or deactivate with:
    $ micromamba activate
To automatically initialize all future (bash) shells, run:
    $ micromamba shell init --shell bash --root-prefix=~/micromamba
If your shell was already initialized, reinitialize your shell with:
    $ micromamba shell reinit --shell bash
Otherwise, this may be an issue. In the meantime you can run commands. See:
    $ micromamba run --help

Supported shells are {bash, zsh, csh, xonsh, cmd.exe, powershell, fish}.
Creating log directories...
 * Starting periodic command scheduler cron
   ...done.
2025-04-24 23:53:18,741	INFO usage_lib.py:441 -- Usage stats collection is disabled.
2025-04-24 23:53:18,742	INFO scripts.py:767 -- [37mLocal node IP[39m: [1m10.244.64.74[22m
2025-04-24 

### Run concurrent training jobs on SPCS

Suppose we want to train multiple models, each on a different dataset. The cell below generates ten datasets and submits one remote training job per dataset.

In [10]:
print("Generating datasets")
datasets = []
for i in range(10):
    dataset = f"loan_applications_{i}"
    # num_rows is interpolated into SQL, so pass an int (1e6 would render as 1000000.0).
    generate_data(dataset, 1_000_000)
    datasets.append(dataset)
print(f"Generated datasets: {datasets}")

# Submit one remote training job per dataset and keep the returned handles.
print("Starting training jobs")
train_jobs = [train_remote(ds) for ds in datasets]
print(f"Started {len(train_jobs)} training jobs")

Generating datasets
Generated datasets: ['loan_applications_0', 'loan_applications_1', 'loan_applications_2', 'loan_applications_3', 'loan_applications_4', 'loan_applications_5', 'loan_applications_6', 'loan_applications_7', 'loan_applications_8', 'loan_applications_9']
Starting training jobs
Started 10 training jobs


In [11]:
from snowflake.ml.jobs import list_jobs

# Display the ML jobs as a table (id, owner, status, created_on, compute_pool).
list_jobs().show()



------------------------------------------------------------------------------------------------------------------------
|"id"                                        |"owner"   |"status"  |"created_on"                      |"compute_pool"  |
------------------------------------------------------------------------------------------------------------------------
|MLJOB_2B691E83_2D52_49D1_ACC9_DB0C951017FB  |ENGINEER  |RUNNING   |2025-04-24 16:55:48.851000-07:00  |DEMO_POOL_CPU   |
|MLJOB_B6B11A30_EAC9_4BBB_A41F_B47827C8013A  |ENGINEER  |RUNNING   |2025-04-24 16:55:43.606000-07:00  |DEMO_POOL_CPU   |
|MLJOB_75B596F9_5674_4636_B9F9_67E036E67DEA  |ENGINEER  |RUNNING   |2025-04-24 16:55:32.928000-07:00  |DEMO_POOL_CPU   |
|MLJOB_56239D4C_F6FA_4691_8ABC_FCA5BA8E58B6  |ENGINEER  |DONE      |2025-04-24 16:55:27.958000-07:00  |DEMO_POOL_CPU   |
|MLJOB_6A0DB4D4_29F8_4CFB_BD46_9FB0966CBC5F  |ENGINEER  |DONE      |2025-04-24 16:55:17.525000-07:00  |DEMO_POOL_CPU   |
|MLJOB_BF15DA6D_8192_4831_A92F_A

In [12]:
# Optional cleanup: uncomment to drop the demo schema created above.
# session.sql(f"DROP SCHEMA {schema_name}").collect()