### Set up Snowpark Session

See [Configure Connections](https://docs.snowflake.com/developer-guide/snowflake-cli/connecting/configure-connections#define-connections)
for information on how to define default Snowflake connection(s) in a config.toml
file.

In [None]:
# from snowflake.snowpark import Session, Row

# # Requires valid ~/.snowflake/config.toml file
# session = Session.builder.getOrCreate()
# print(session)

import pandas as pd

from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session

# Create the Snowpark session used by every cell below, configured from the
# options that SnowflakeLoginOptions resolves for the "preprod8" connection.
# NOTE(review): "preprod8" looks environment-specific — substitute your own
# connection name from config.toml.
session = Session.builder.configs(SnowflakeLoginOptions("preprod8")).create()

#### Set up Snowflake resources

In [None]:
# Use any database and schema you want. The dataset cell further down reads
# session.get_current_database() / session.get_current_schema(), so the table
# it creates lands in whatever is selected here.
session.sql("use database temp").collect()
session.sql("use schema public").collect()

In [None]:
# Create compute pool if not exists
def create_compute_pool(name: str, instance_family: str, min_nodes: int = 1, max_nodes: int = 10):
    """Run ``CREATE COMPUTE POOL IF NOT EXISTS`` through the notebook's ``session``.

    The arguments are interpolated into the statement text as-is, so they must
    already be valid SQL identifiers / values.

    Args:
        name: Name of the compute pool.
        instance_family: Value for ``INSTANCE_FAMILY``.
        min_nodes: Value for ``MIN_NODES``.
        max_nodes: Value for ``MAX_NODES``.

    Returns:
        Whatever ``session.sql(...).collect()`` returns for the statement.
    """
    ddl = f"""
        CREATE COMPUTE POOL IF NOT EXISTS {name}
            MIN_NODES = {min_nodes}
            MAX_NODES = {max_nodes}
            INSTANCE_FAMILY = {instance_family}
    """
    statement = session.sql(ddl)
    return statement.collect()

# Pool name reused by the @remote training job below. Created (if missing)
# with the CPU_X64_S instance family, MIN_NODES = 1 and MAX_NODES = 5.
compute_pool = "E2E_CPU_POOL"
create_compute_pool(compute_pool, "CPU_X64_S", 1, 5)

### Approach 1: Train with function

In [None]:
# Generate an arbitrary dataset
def generate_dataset_sql(db, schema, table_name, num_rows, num_cols) -> str:
    """Build the CREATE TABLE ... AS SELECT text for a synthetic dataset.

    The statement targets ``{db}.{schema}.{table_name}`` and selects
    ``FEATURE_1`` .. ``FEATURE_{num_cols - 1}`` (each a
    ``uniform(0::FLOAT, 10::FLOAT, random())`` expression) followed by
    ``TARGET_1``, defined as ``FEATURE_1 + FEATURE_1``, from
    ``generator(rowcount=>(num_rows))``.

    Args:
        db: Database part of the table name.
        schema: Schema part of the table name.
        table_name: Table to create if it does not exist.
        num_rows: Row count passed to ``generator``.
        num_cols: Upper bound (exclusive) of the feature index range, so
            ``num_cols - 1`` feature columns are emitted.

    Returns:
        The SQL statement as a string; nothing is executed here.
    """
    feature_lines = [
        f"uniform(0::FLOAT, 10::FLOAT, random()) AS FEATURE_{idx}, \n"
        for idx in range(1, num_cols)
    ]
    pieces = [
        f"CREATE TABLE IF NOT EXISTS {db}.{schema}.{table_name} AS \n",
        "SELECT \n",
        *feature_lines,
        # The select list deliberately keeps its trailing comma before FROM,
        # matching the statement text this notebook has always submitted.
        "FEATURE_1 + FEATURE_1 AS TARGET_1, \n",
        f"FROM TABLE(generator(rowcount=>({num_rows})));",
    ]
    return "".join(pieces)
# Dataset dimensions: one million rows; num_cols = 100 yields FEATURE_1..FEATURE_99
# plus the TARGET_1 column.
num_rows = 1_000_000
num_cols = 100
table_name = "MULTINODE_CPU_TRAIN_DS"

# Create the table (if it does not exist) in the session's current database and schema.
session.sql(
    generate_dataset_sql(
        session.get_current_database(),
        session.get_current_schema(),
        table_name,
        num_rows,
        num_cols,
    )
).collect()

# Feature column names, in the same index range the generator SQL uses.
feature_list = [f"FEATURE_{col_idx}" for col_idx in range(1, num_cols)]

In [None]:
from snowflake.ml.jobs import remote

@remote(compute_pool, stage_name="payload_stage", num_instances=3)
def xgb(table_name, input_cols, label_col):
    """Fit an XGBoost model on ``table_name`` under three worker layouts.

    Submitted through ``@remote`` against ``compute_pool`` with
    ``stage_name="payload_stage"`` and ``num_instances=3``. All library imports
    live inside the body — presumably so they resolve in the job's runtime
    rather than in the notebook kernel (TODO confirm).

    Args:
        table_name: Table loaded with ``session.table`` as training data.
        input_cols: Forwarded to ``XGBEstimator.fit`` as ``input_cols``.
        label_col: Forwarded to ``XGBEstimator.fit`` as ``label_col``.

    Returns:
        ``1`` once every fit has returned a non-None model.

    Raises:
        AssertionError: If any of the fits returns ``None``.
    """
    from snowflake.snowpark.context import get_active_session
    from snowflake.ml.modeling.distributors.xgboost import XGBEstimator, XGBScalingConfig
    from snowflake.ml.data.data_connector import DataConnector
    from implementations.ray_data_ingester import RayDataIngester

    def xgb_train(num_workers, num_cpu_per_worker):
        """Run a single fit with the given scaling settings and return the fitted model."""
        train_table = get_active_session().table(table_name)

        booster_params = {
            "tree_method": "hist",
            "objective": "reg:pseudohubererror",
            "eta": 1e-4,
            "subsample": 0.5,
            "max_depth": 50,
            "max_leaves": 1000,
            "max_bin": 63,
        }
        scaling = XGBScalingConfig(
            num_workers=num_workers,
            num_cpu_per_worker=num_cpu_per_worker,
            use_gpu=False,
        )
        estimator = XGBEstimator(
            n_estimators=100,
            params=booster_params,
            scaling_config=scaling,
        )
        connector = DataConnector.from_dataframe(
            train_table,
            ingestor_class=RayDataIngester,
        )
        return estimator.fit(connector, input_cols=input_cols, label_col=label_col)

    # (num_workers, num_cpu_per_worker) pairs, tried in this order.
    # NOTE(review): (-1, -1) presumably lets the library choose the layout — confirm.
    worker_layouts = ((-1, -1), (3, 4), (6, 2))
    for workers, cpus_per_worker in worker_layouts:
        assert xgb_train(workers, cpus_per_worker) is not None
    return 1

# Function invocation returns a job handle (snowflake.ml.jobs.MLJob).
# Arguments map to xgb(table_name, input_cols, label_col): the generated table,
# the feature column names, and "TARGET_1" as the label column.
job = xgb(table_name, feature_list, "TARGET_1")

In [None]:
# Print the handle's id and its status at the time this cell runs
print(job.id)
print(job.status)

In [None]:
# Wait on the job, then show its logs
job.wait()
job.show_logs()

### Approach 2: Train with file

In [None]:
from snowflake.ml.jobs import submit_file

# Submit a script file as the job payload instead of a decorated function.
# This rebinds `job`, replacing the handle from the function-based approach.
# NOTE(review): "E2E_CPU_POOL" repeats the value assigned to `compute_pool`
# earlier in the notebook — keep the two in sync.
job = submit_file(
    "../src/main.py",
    "E2E_CPU_POOL",
    stage_name="multi_node_payload_stage",
    num_instances=3  # Specify multiple instances
)

In [None]:
# Print the id and current status of the file-based job
print(job.id)
print(job.status)

In [None]:
# Wait on the file-based job, then show its logs
job.wait()
job.show_logs()