### Set up Snowpark Session

See [Configure Connections](https://docs.snowflake.com/developer-guide/snowflake-cli/connecting/configure-connections#define-connections)
for information on how to define default Snowflake connection(s) in a config.toml
file.

In [1]:
from snowflake.snowpark import Session, Row

# Get (or create) a Snowpark session using the default connection settings.
# Requires valid ~/.snowflake/config.toml file
session = Session.builder.getOrCreate()
# Printing the session shows the active account, role, database, schema and warehouse.
print(session)

<snowflake.snowpark.session.Session: account="NOTEBOOK_MLTEST", role="SYSADMIN", database="HEADLESS_STARTER_DB", schema="HEADLESS_STARTER_SCHEMA", warehouse="ST_WH">


#### Set up Snowflake resources

In [3]:
# OPTIONAL: Uncomment below to select a database and schema to use
# session.use_database("temp")
# session.use_schema("public")

In [4]:
# Create compute pool if not exists
def create_compute_pool(name: str, instance_family: str, min_nodes: int = 1, max_nodes: int = 10):
    """Create the compute pool ``name`` unless it already exists.

    Issues a ``CREATE COMPUTE POOL IF NOT EXISTS`` statement through the
    notebook-level ``session``. The arguments are interpolated into the SQL
    text verbatim, so only pass trusted identifiers and values.

    Args:
        name: Identifier of the compute pool.
        instance_family: Value used for ``INSTANCE_FAMILY``.
        min_nodes: Value used for ``MIN_NODES``.
        max_nodes: Value used for ``MAX_NODES``.

    Returns:
        Whatever ``collect()`` returns for the executed statement.
    """
    ddl = f"""
        CREATE COMPUTE POOL IF NOT EXISTS {name}
            MIN_NODES = {min_nodes}
            MAX_NODES = {max_nodes}
            INSTANCE_FAMILY = {instance_family}
    """
    pool_statement = session.sql(ddl)
    return pool_statement.collect()

# Name of the compute pool that the remote training job below runs on.
compute_pool = "DEMO_POOL_CPU"
# Make sure the pool exists: CPU_X64_S instance family, between 1 and 5 nodes.
create_compute_pool(compute_pool, "CPU_X64_S", 1, 5)

[Row(status='DEMO_POOL_CPU already exists, statement succeeded.')]

In [5]:
# Enable multi-node ML jobs for the current session
# Note: If the ENABLE_BATCH_JOB_SERVICES parameter is not visible to you, ask your Snowflake account admin to enable it for your account.
session.sql("alter session set ENABLE_BATCH_JOB_SERVICES = true").collect()

[Row(status='Statement executed successfully.')]

### Approach 1: Train with function

In [6]:
# Generate an arbitrary dataset
def generate_dataset_sql(db: str, schema: str, table_name: str, num_rows: int, num_cols: int) -> str:
    """Build a ``CREATE TABLE IF NOT EXISTS ... AS SELECT`` statement for a synthetic dataset.

    The table gets ``num_cols - 1`` feature columns named ``FEATURE_1`` ..
    ``FEATURE_{num_cols - 1}``, each defined as
    ``uniform(0::FLOAT, 10::FLOAT, random())``, plus a ``TARGET_1`` column
    defined as ``FEATURE_1 + FEATURE_1``. Rows come from
    ``TABLE(generator(rowcount=>(num_rows)))``.

    All arguments are interpolated into the SQL text verbatim, so only pass
    trusted identifiers and values.

    Args:
        db: Database part of the fully qualified table name.
        schema: Schema part of the fully qualified table name.
        table_name: Name of the table to create.
        num_rows: Row count passed to ``generator``.
        num_cols: Upper bound (exclusive) of the feature index; must be at
            least 2 so that ``FEATURE_1`` exists for ``TARGET_1``.

    Returns:
        The SQL statement as a single string.

    Raises:
        ValueError: If ``num_cols`` is smaller than 2.
    """
    if num_cols < 2:
        # With fewer than 2, no FEATURE_1 column is emitted and the TARGET_1
        # expression would reference a column that does not exist.
        raise ValueError(f"num_cols must be at least 2, got {num_cols}")

    parts = [
        f"CREATE TABLE IF NOT EXISTS {db}.{schema}.{table_name} AS \n",
        "SELECT \n",
    ]
    parts.extend(
        f"uniform(0::FLOAT, 10::FLOAT, random()) AS FEATURE_{i}, \n"
        for i in range(1, num_cols)
    )
    # NOTE(review): TARGET_1 is FEATURE_1 added to itself, and the select list
    # ends with a trailing comma before FROM; both are kept exactly as in the
    # original statement -- confirm they are intended.
    parts.append("FEATURE_1 + FEATURE_1 AS TARGET_1, \n")
    parts.append(f"FROM TABLE(generator(rowcount=>({num_rows})));")
    return "".join(parts)
# Dataset size: 1,000,000 rows. num_cols = 100 yields FEATURE_1..FEATURE_99,
# because the feature index range starts at 1 and excludes num_cols.
num_rows = 1000 * 1000
num_cols = 100
table_name = "MULTINODE_CPU_TRAIN_DS"
# Create the table in the session's current database and schema
# (the statement uses CREATE TABLE IF NOT EXISTS).
session.sql(generate_dataset_sql(session.get_current_database(), session.get_current_schema(), 
                                table_name, num_rows, num_cols)).collect()
# Feature column names, using the same index range as the generated table.
feature_list = [f'FEATURE_{num}' for num in range(1, num_cols)]

In [7]:
from snowflake.ml.jobs import remote

# 'target_instances' replaces the deprecated 'num_instances' argument.
@remote(compute_pool, stage_name="payload_stage", target_instances=3)
def xgb(table_name, input_cols, label_col):
    """Fit an XGBoost regressor on a Snowflake table with ``XGBEstimator``.

    Because of the ``@remote`` decorator, calling this function submits it as
    a job on ``compute_pool`` instead of running it locally.

    Args:
        table_name: Name of the table holding the training data.
        input_cols: Names of the feature columns.
        label_col: Name of the label column.

    Returns:
        The object returned by ``XGBEstimator.fit``.
    """
    # Imports are local to the function body -- presumably so they resolve in
    # the environment where the job actually executes.
    from snowflake.snowpark import Session
    from snowflake.ml.modeling.distributors.xgboost import XGBEstimator, XGBScalingConfig
    from snowflake.ml.data.data_connector import DataConnector

    session = Session.builder.getOrCreate()
    cpu_train_df = session.table(table_name)

    # XGBoost training parameters, passed through to the estimator.
    params = {
        "tree_method": "hist",
        "objective": "reg:pseudohubererror",
        "eta": 1e-4,
        "subsample": 0.5,
        "max_depth": 50,
        "max_leaves": 1000,
        "max_bin": 63,
    }
    # CPU-only training (the compute pool above uses a CPU instance family).
    scaling_config = XGBScalingConfig(use_gpu=False)
    estimator = XGBEstimator(
        n_estimators=100,
        params=params,
        scaling_config=scaling_config,
    )
    data_connector = DataConnector.from_dataframe(cpu_train_df)
    xgb_model = estimator.fit(
        data_connector, input_cols=input_cols, label_col=label_col
    )
    return xgb_model

# Function invocation returns a job handle (snowflake.ml.jobs.MLJob),
# not the trained model itself.
job = xgb(table_name, feature_list, "TARGET_1")

'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead.


In [8]:
# Show the job's identifier and its current status.
print(job.id)
print(job.status)

HEADLESS_STARTER_DB.HEADLESS_DEMO.MLJOB_3152B8EE_A391_4340_9152_D54A58365C1B
PENDING


In [9]:
# Wait for the job, then display its logs.
job.wait()
job.show_logs()

2025-06-24 21:08:17,038	INFO job_manager.py:528 -- Runtime env is setting up.
  import pkg_resources

2025-06-24 21:08:20,186	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.

83f38934a778457abfaff63217cfc17e: Received raw arguments: XGBTrainArgs(model_type=<BoostingModelTypes.XGBOOST: 'xgboost'>, dataset=<snowflake.ml.data.data_connector.DataConnector object at 0x7fd3b8f30eb0>, input_cols=['FEATURE_1', 'FEATURE_2', 'FEATURE_3', 'FEATURE_4', 'FEATURE_5', 'FEATURE_6', 'FEATURE_7', 'FEATURE_8', 'FEATURE_9', 'FEATURE_10', 'FEATURE_11', 'FEATURE_12', 'FEATURE_13', 'FEATURE_14', 'FEATURE_15', 'FEATURE_16', 'FEATURE_17', 'FEATURE_18', 'FEATURE_19', 'FEATURE_20', 'FEATURE_21', 'FEATURE_22', 'FEATURE_23', 'FEATURE_24', 'FEATURE_25', 'FEATURE_26', 'FEATURE_27', 'FEATURE_28', 'FEATURE_29', 'FEATURE_30', 'FEATURE_31', 'FEATURE_32', 'FEATURE_33', 'FEATURE_34', 'FEATURE_35', 'FEATURE_36', 'FEATURE_37', 

In [10]:
import xgboost

# Retrieve the trained model from the job execution and use it for prediction
xgb_model = job.result()

# Predict on a sample of the dataset
# Note: This is just a demonstration; in practice you would want to predict on a different dataset
# The label column is dropped so that only the feature columns are passed to the model.
dataset = session.table(table_name).drop("TARGET_1").limit(10).to_pandas()
xgb_model.predict(xgboost.DMatrix(dataset))

configuration generated by an older version of XGBoost, please export the model by calling
`Booster.save_model` from that version first, then load it back in current version. See:

    https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html

for more details about differences between saving model and serializing.



array([11.8159   , 10.611345 ,  9.717881 , 18.790493 ,  7.9805217,
       16.480486 , 15.571457 , 14.789684 , 12.37405  , 12.086709 ],
      dtype=float32)