# Concurrent Model Training

This notebook trains multiple models at the same time, with each model trained on a separate Snowpark Container Services instance. Distributing the training across the available compute resources can reduce the overall elapsed time when training many models.

## Background

The Ray framework is used to create training tasks, where each task is an ML training workflow. A Ray cluster is created and the tasks are submitted to it. Each compute instance is assigned a task to execute, and waiting tasks queue for the next available compute instance.
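
The queueing behaviour described above can be illustrated without a cluster. The sketch below is not Ray: it uses the standard-library `concurrent.futures` thread pool as a stand-in, with two worker threads playing the role of compute instances. The function `fake_train` and the list of model names are illustrative assumptions only.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

def fake_train(model):
    """Hypothetical stand-in for a training task; it sleeps instead of training."""
    time.sleep(0.05)
    return model, threading.current_thread().name

models = ["GBM", "CAT", "XGB", "RF", "KNN"]

# Two workers and five tasks: the remaining tasks wait in the queue
# until one of the workers becomes free.
with ThreadPoolExecutor(max_workers=2, thread_name_prefix="worker") as pool:
    results = list(pool.map(fake_train, models))

workers_used = {worker for _, worker in results}
for model, worker in results:
    print(f"{model} ran on {worker}")
```

With more tasks than workers, every task still completes; the extra tasks simply start later, which is the same trade-off as setting `number_of_workers` below the number of models.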

## Usage Notes

Although this might sound complicated to set up and use, it is simple within the Snowflake environment.

1. The parameters to change are grouped into a single cell; they define the values for your environment
2. Training results for each model are saved to a Snowflake stage
3. This training example uses AutoGluon as the training package, but other packages can be used
4. The resulting models can be saved to the model registry and deployed


UNSUPPORTED BY SNOWFLAKE - CUSTOMER SUPPORTED ONLY

Copyright (c) 2025 Snowflake Inc. All rights reserved.

In [None]:
# These are the settings that should be reviewed for your environment

# name of snowflake table to use for training.
table_name = 'DEMO_BOSTON_HOUSING_GENERATED_DATA_100000'

# target column name that we will train on 
label = 'MEDV'

# unique / key column name to drop or [] to indicate no drop columns
drop_cols = ['ID']

# list of models in AutoGluon to train
models_to_train = ['NN_TORCH','GBM','CAT','XGB','FASTAI','RF','XT','KNN']

# if training the same model, but with different parameters add the suffix _<number> to the model name
# make sure the hpo dict specifies the parameters. 
#models_to_train = ['CAT','XGB','CAT_1','XGB_1','FASTAI','RF']

 
# model hyperparameter tuning options
hpo={'NN_TORCH': {},
     'GBM': {},
     'CAT': {'iterations': 10000, 'learning_rate': 0.05, 'random_seed': 0, 'allow_writing_files': False, 'eval_metric': 'Accuracy', 'thread_count': 6},
     'XGB': {},
     'CAT_1': {'iterations': 20000, 'learning_rate': 0.07, 'random_seed': 42, },
     'XGB_1': {'n_estimators': 100,'learning_rate':  0.1,'max_depth': 5},
     'FASTAI': {},
     'RF': {},
     'XT': {},
     'KNN': {}
    }

# autogluon training preset:
# Available Presets: ['best_quality', 'high_quality', 'good_quality', 'medium_quality', 
#    'experimental_quality', 'optimize_for_deployment', 'interpretable', 'ignore_text']
preset = 'medium_quality'

# max number of seconds to train model
time_limit = 3600 * 24

# stage where the training output should be saved
result_stage = 'NOTEBOOK_FILES/AutoGluon'

# number of SPCS notebook containers in the cluster (not including this instance )
# either set the number of workers you would like to run concurrently or have one per model
#number_of_workers = 3
number_of_workers = len(models_to_train) - 1

# the resources per SPCS notebook container. Set to 'auto' to configure them based on the instance family,
# or set to the number of cpus / gpus:
#number_of_cpus = 6
number_of_cpus = 'auto'
number_of_gpus = 0

# if true then notebook will check on training status and wait until training completes
wait_until_completed = True

In [None]:
!pip install autogluon==1.3.1 bokeh==2.0.1 numpy==2.1.3 --quiet

In [None]:
# Import python packages (standard to all container notebooks)
import streamlit as st
import pandas as pd

# Used to scale the cluster
from snowflake.ml.runtime_cluster import scale_cluster

from autogluon.tabular import TabularDataset, TabularPredictor
from autogluon.features.generators import AutoMLPipelineFeatureGenerator

# used to create train and test datasets
from sklearn.model_selection import train_test_split
#from sklearn.pipeline import Pipeline
#from sklearn.compose import ColumnTransformer
    
# ray cluster package
import ray

# used to set up the environment, parse Ray output, log errors and poll for completion
import os
import json
import logging
import time
import psutil
import shutil
from datetime import datetime

In [None]:
# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
session = get_active_session()


In [None]:
notebook_name = os.environ.get('OBJECT_NAME', 'NOTEBOOK')

In [None]:
os.environ["AG_DISTRIBUTED_MODE"] = "True"
os.environ["AG_FORCE_PARALLEL"] = "True"

In [None]:
try:
    result = session.sql("ls @"+result_stage).collect()
    print(f"Training results will be saved to the Snowflake stage @{result_stage}/<model-name>")
except Exception as e:
    print(f"The Snowflake stage @{result_stage} used to save model training results is not accessible.")
    print(f"{e}")
    


In [None]:
try:
    # Snowpark DataFrames are lazy, so run the query to confirm the table can be read
    result = session.table(table_name).limit(1)
    result.collect()
    print(f"The table {table_name} is available.")
except Exception as e:
    print(f"The table {table_name} is not accessible.")
    print(f"{e}")

try:
    result.select(label).collect()
    print(f"The column label (target) {label} is available.")
except Exception as e:
    print(f"The column label (target) {label} is not accessible.")
    print(f"{e}")


In [None]:
# create an estimated memory size for training. this is a rough guide so that the job does not fail due to memory limits.

rows = session.table(table_name).count()
columns = len(session.table(table_name).columns)
print(f"Rows: {rows}, Columns: {columns}")
dataset_size = rows * columns * 8 # Assuming each value is a float64 (8 bytes)

if rows < 100000 and columns < 100:
    memory_factor = 5
else:
    if rows <= 1000000 and columns <= 1000:
        memory_factor = 10
    else:
        memory_factor = 20

dataset_size = dataset_size * memory_factor / (1024 ** 3)
print(f"The estimated memory size per worker for the data is: {dataset_size:.2f} GB and a total memory requirement: {dataset_size  *(1+number_of_workers):.2f} GB")

memory = psutil.virtual_memory().available / (1024 **3)
print(f"Each worker has {memory:.0f} GB available" )

if dataset_size >= (memory*0.75):
    print("Warning: the memory available to the service might be too small for training. Consider changing to a larger compute pool instance family.")
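
The memory heuristic above can be worked through as a small function. This is a sketch that mirrors the thresholds in the cell; the function name `estimate_memory_gb` and the table shape of 100,000 rows by 15 columns are assumptions for illustration, not values read from the real table.

```python
def estimate_memory_gb(rows, columns, bytes_per_value=8):
    """Rough per-worker memory estimate in GB, assuming every value is a float64."""
    if rows < 100_000 and columns < 100:
        memory_factor = 5
    elif rows <= 1_000_000 and columns <= 1000:
        memory_factor = 10
    else:
        memory_factor = 20
    return rows * columns * bytes_per_value * memory_factor / (1024 ** 3)

# Hypothetical table shape: 100,000 rows and 15 columns.
per_worker = estimate_memory_gb(100_000, 15)

# Seven workers plus the head node, i.e. eight models with one node per model.
workers = 7
total = per_worker * (1 + workers)
print(f"per worker: {per_worker:.2f} GB, total: {total:.2f} GB")
```

Note that exactly 100,000 rows falls into the middle tier, because the first test uses a strict `<`.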

In [None]:
result = session.sql("describe service "+os.environ["SNOWFLAKE_SERVICE_NAME"]).collect()
compute_poolname = pd.DataFrame(result).loc[0, "compute_pool"]

result = session.sql("describe compute pool "+compute_poolname).collect()
compute_maxnodes = pd.DataFrame(result).loc[0, "max_nodes"]

if (number_of_workers > compute_maxnodes ):
    print(f"The number of worker instances is larger than the SPCS compute pool {compute_poolname}")
    print(f"Increasing the max_nodes to {number_of_workers} from {compute_maxnodes}")
    session.sql("alter compute pool "+compute_poolname+" set max_nodes = "+str(number_of_workers)+";").collect()
else:
    print(f"The compute pool {compute_poolname} has enough instances to execute the workers.")

In [None]:
# 'auto' (or any other string) means: use everything that is available
if isinstance(number_of_gpus, str):
    number_of_gpus = -1

if isinstance(number_of_cpus, str):
    number_of_cpus = -1

# 2 cores are reserved for internal use
notebook_cpus = os.cpu_count()-2

if number_of_cpus == -1:
    number_of_cpus = notebook_cpus
    print(f"CPU resource set to auto, assigning all available CPUs ({number_of_cpus})")
else:
    if (number_of_cpus > notebook_cpus):
        print(f"The compute pool {compute_poolname} only has {notebook_cpus} CPUs available but the variable number_of_cpus is set to {number_of_cpus}.")
        print("Training will not be able to run. Either restart the notebook on a larger SPCS instance family or reduce the number_of_cpus setting.")
    else:
        print(f"The compute pool {compute_poolname} has enough CPUs ({number_of_cpus})")

# check for gpus
ncmd = !nvidia-smi --list-gpus
# nvidia-smi prints one "GPU <n>: ..." line per device; any other output
# (for example "command not found") means that no GPUs are available
notebook_gpus = len([line for line in ncmd if line.startswith("GPU ")])

if number_of_gpus == -1:
    if notebook_gpus == 0:
        print("No GPUs are available")
        number_of_gpus = 0
    else:
        number_of_gpus = notebook_gpus
        print(f"GPU resource set to auto, assigning all available GPUs ({number_of_gpus})")
else:
    if (number_of_gpus > notebook_gpus):
        print(f"The compute pool {compute_poolname} only has {notebook_gpus} GPUs available but the variable number_of_gpus is set to {number_of_gpus}.")
        print("Training will not be able to run. Either restart the notebook on a larger SPCS instance family or reduce the number_of_gpus setting.")
    else:
        print(f"The compute pool {compute_poolname} has enough GPUs ({number_of_gpus})")
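
The resource checks above follow one pattern: count what the node has, then resolve the requested value against it. A minimal sketch of that pattern is shown below. The helper names `count_gpus` and `resolve_resource` and the sample output lines are hypothetical. Unlike the cell above, which only prints a warning, this sketch raises an error when the request cannot be met.

```python
def count_gpus(lines):
    """Count devices in `nvidia-smi --list-gpus` style output (one 'GPU <n>: ...' line each)."""
    return sum(1 for line in lines if line.startswith("GPU "))

def resolve_resource(requested, available):
    """Return the amount to assign; a string such as 'auto' means everything available."""
    if isinstance(requested, str):
        return available
    if requested > available:
        raise ValueError(f"requested {requested} but only {available} available")
    return requested

# Hypothetical command output for a node with two GPUs and for a node without the tool.
with_gpus = ["GPU 0: Example Device (UUID: GPU-0000)", "GPU 1: Example Device (UUID: GPU-0001)"]
without_gpus = ["/bin/bash: nvidia-smi: command not found"]

print(resolve_resource("auto", count_gpus(with_gpus)))
print(resolve_resource("auto", count_gpus(without_gpus)))
print(resolve_resource(4, 6))
```

Failing early keeps a task that can never be scheduled from sitting in the queue indefinitely.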


In [None]:
if (scale_cluster(number_of_workers)) == True:
    print(f"Ray cluster is ready with 1 head node and {number_of_workers} worker nodes in compute pool {compute_poolname}")
else:
    print(f"Error: Unable to scale the compute pool {compute_poolname} see logs for additional details.")

In [None]:
runtime_env = {"pip": ["autogluon==1.3.1","numpy==1.26.4"], 
               "log_to_driver":False,
               "env_vars": {"AG_DISTRIBUTED_MODE" :"True",
                            "AG_FORCE_PARALLEL":"True"
                           }
              }

In [None]:
ray.shutdown()

In [None]:
ray_cluster = ray.init(runtime_env=runtime_env )
!ray list nodes

In [None]:
# this defines the remote function that will execute on the ray cluster

@ray.remote(scheduling_strategy="SPREAD")
def train_model(model, preset, table_name, label, hpo, time_limit, result_stage):
    
    # imports that each worker needs in the cluster
    import shutil
    import os
    import logging
    from datetime import datetime
    import time
    import pandas
    import random

    from snowflake.snowpark import Session

    # record some info for this execution
    tid = ray.get_runtime_context().get_task_id()
    ip = ray.util.get_node_ip_address()
    ts_s = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # logging.info() does not accept a flush argument, only print() does
    logging.info(f'Queued {ip} Model {model} TS {ts_s} tid {tid} ')
    print(f'Queued {ip} Model {model} TS {ts_s} tid {tid} ', flush=True)
    print(f'HPO {type(hpo)} {hpo}')
    print(f'preset {preset}')
    
    # read the SPCS token for this session
    with open('/snowflake/session/token', 'r') as f:
                token = f.read()

    # set up the connection to Snowflake
    connection_parameters = {
        "host": os.getenv('SNOWFLAKE_HOST'),
        "account": os.getenv('SNOWFLAKE_ACCOUNT'),
        "token": token,
        "authenticator": 'oauth',
        "warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'),
        "database": os.getenv('SNOWFLAKE_DATABASE'),
        "schema": os.getenv('SNOWFLAKE_SCHEMA')
    }

    # create session from the ray worker to Snowflake
    session = Session.builder.configs(connection_parameters).getOrCreate()
    print("Connection to Snowflake successful")

    # retrieve the data from Snowflake
    data = session.table(table_name).limit(100000).to_pandas()
    print("Retrieved data successfully")

    # identify the target that we will be predicting and drop the key columns;
    # the label column stays in the inputs because AutoGluon expects it in the training frame
    target = data[label]
    inputs = data
    inputs.drop(columns=drop_cols, inplace=True)

    # create a train and test dataset
    x_train, x_test, y_train, y_test = train_test_split(inputs, target, test_size=0.2, random_state=42)
    print("Split of data successful")
    
    # train one model
    ts_s = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f'Starting {ip} Model {model} TS {ts_s} tid {tid} ', flush=True)
    model_path = '/tmp/autogluon/'+model+'/'
    predictor = TabularPredictor(label=label, path=model_path).fit(x_train, hyperparameters=hpo, presets=preset, time_limit=time_limit)
    print("Fit finished")

    m = TabularPredictor.load("/tmp/autogluon/"+model)
    print(m.predict(data))
    # , ag_args_fit={"ag.max_memory_usage_ratio": 1.5}
    # verbosity=3

    # show the model results on the training data
    #predictor.evaluate(x_train)
    
    #y_pred = predictor.predict(train_data.drop(columns=[label]))
    #predictor.evaluate(train_data, silent=True)

    predictor.save(model_path)
    #predictor.fit_summary()

    # create archive of the training output
    print("Creating archive")
    local_file = '/tmp/'+model
    shutil.make_archive(local_file, 'zip', model_path )
    
    # re-read the token, as the training step could take long enough for the original one to expire
    with open('/snowflake/session/token', 'r') as f:
        token = f.read()
    connection_parameters["token"] = token

    # create a new session from the ray worker to Snowflake that uses the current token
    session = Session.builder.configs(connection_parameters).create()
    print("upload session created")
    
    # stage location
    stage_location = "@"+os.getenv('SNOWFLAKE_DATABASE')+"."+os.getenv('SNOWFLAKE_SCHEMA')+"."+result_stage+"/"+model+"/"
    
    try:
        print("uploading artifacts")
        session.file.put(local_file+'.zip', stage_location, auto_compress=False, overwrite=True)
        session.file.put(model_path+'models/*/model.pkl', stage_location, auto_compress=False, overwrite=True)
        logging.info(f"File '{local_file}.zip' successfully uploaded to stage '{stage_location}'.")
        print("upload successful")
        shutil.rmtree(model_path)
        os.remove(local_file+'.zip')
    except Exception as e:
        logging.error(f"Error uploading file: {e}")
    finally:
        if session:
           session.close()
    
  
    ts_e = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    logging.info(f'Completed {ip} Model {model} TS {ts_e} tid {tid} ')
    print(f'Completed {ip} Model {model} TS {ts_e} tid {tid} ', flush=True)
    
    return 
    

In [None]:
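def ray_output(opt, output, fields):
    """Format the JSON from `ray list tasks`: 'list' prints the fields,
    'value' returns them as a string and 'parse' prints a table of tasks."""

    tasks = json.loads(output.s)
    result = ''
    task_add = False
    for i, item in enumerate(tasks):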
        for column_name, column_value in item.items():
            if column_name in fields.split():
                if opt == 'list':
                    if fields.index(column_name) == 0:
                        print()
                
                    print(f"{column_name}: {column_value} ", end="")
                if opt == 'value':
                    if len(result) < 2:
                        result = '{"'+column_name+'":"'+str(column_value)+'"'
                    else:
                         result += ',"'+column_name+'":"'+str(column_value)+'"'
                if opt == 'parse':
                    if column_name == "task_id":
                        task_add = True
                        if len(result) < 2:
                            result = '{"'+str(column_value)+'"'
                        else:
                             result += ',"'+str(column_value)+'"' 
                    else:
                        if task_add:
                            result += ':{"'+column_name+'":"'+str(column_value)+'"'
                            task_add = False
                        else:
                             result += ',"'+column_name+'":"'+str(column_value)+'"' 
        if opt == 'parse':
            result += '}'
            
    result += '}'
    if opt == 'list':
        return
        
    if opt == 'parse':
        task = json.loads(result)
        print("{:<50} {:<9} {:<12} {:<14} {:<10} {:<8} {:<10} {:<8}".format("Task id","Status", "Model", "Service", "Start", "Time", "End", "Time"))
        for k, value in task.items():
            
            print(f"{k:<50} {value['state']:<9} " ,end="")
            
            os.system('ray logs task --id '+k+' --tail -1 > /tmp/task.txt ')
            try:
                with open('/tmp/task.txt', 'r') as f:
                    content = f.read()
                    lines = content.splitlines()
                    for line in lines:
                       # print(line)
                        if len(line) >= 8:
                            words = line.split()
                            #print(words[0])
                            if words[0] == "Starting":
                                print(f"{words[3]:<12} {words[1]:<10} {words[5]:<10} {words[6]:<8} ", end="")
                            if words[0] == "Completed":
                                print(f"{words[5]:<10} {words[6]:<10}  ", end="")
                    print()
            except Exception as e:
                print(f"Cannot find log for task {k}  {e}")

    else:
        return result
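
For the status counts alone, the string building above is not needed. The sketch below counts task states directly from the JSON with `collections.Counter`. The function name `count_task_states` and the sample records are hypothetical; only the `name` and `state` fields are taken from the cells in this notebook. Collapsing every state that starts with `PENDING` into one bucket is an assumption that matches the substring test used by the wait loop.

```python
import json
from collections import Counter

def count_task_states(raw_json):
    """Return a Counter of task states from `ray list tasks --format json` style output."""
    counts = Counter()
    for task in json.loads(raw_json):
        state = task.get("state", "UNKNOWN")
        # Treat every PENDING_* variant as PENDING.
        counts["PENDING" if state.startswith("PENDING") else state] += 1
    return counts

# Hypothetical sample shaped like the fields this notebook reads.
sample = json.dumps([
    {"task_id": "a1", "name": "GBM", "state": "RUNNING"},
    {"task_id": "b2", "name": "CAT", "state": "PENDING_NODE_ASSIGNMENT"},
    {"task_id": "c3", "name": "XGB", "state": "PENDING_ARGS_AVAIL"},
    {"task_id": "d4", "name": "RF", "state": "FINISHED"},
])

counts = count_task_states(sample)
print(dict(counts))
```

Counting parsed records rather than substrings also avoids miscounting when a model name happens to contain a state word.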

In [None]:
print("Training starting "+datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
for model in models_to_train:
    # split on the last underscore so that names such as NN_TORCH_1 map to NN_TORCH
    suffix = model.rsplit("_",1)
    if len(suffix) == 2 and suffix[1].isdigit():
        hpo_model = {suffix[0]:hpo[model]}
    else:
        hpo_model = {model:hpo[model]}
    print(f'Starting for model {model} with hyperparameters {str(hpo_model)}')
    tid = train_model.options(num_cpus=number_of_cpus, num_gpus=number_of_gpus, name=model, scheduling_strategy="SPREAD").remote(model, preset, table_name, label, hpo_model, time_limit, result_stage) 
    print(tid)
    print()
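
The `_<number>` suffix convention used when submitting the tasks can be sketched as a small helper. The names `split_model_name` and `hyperparameters_for` are hypothetical. Splitting on the last underscore is a deliberate choice so that model names that already contain an underscore, such as `NN_TORCH`, are handled.

```python
def split_model_name(name):
    """Split 'CAT_1' into ('CAT', 1); names without a numeric suffix come back unchanged."""
    base, sep, suffix = name.rpartition("_")
    if sep and suffix.isdigit():
        return base, int(suffix)
    return name, None

def hyperparameters_for(name, hpo):
    """Build the single-entry dict keyed by the base model name."""
    base, _ = split_model_name(name)
    return {base: hpo[name]}

# Illustrative subset of an hpo dict.
hpo_example = {"CAT": {}, "CAT_1": {"learning_rate": 0.07}, "NN_TORCH": {}}
for name in hpo_example:
    print(name, "->", hyperparameters_for(name, hpo_example))
```

The suffixed name is still used for the task name and the output path, so two variants of the same model do not overwrite each other's results.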

In [None]:
!ray summary tasks

In [None]:
!ray status 

In [None]:
output = !ray list tasks --format json 
ray_output('list', output, 'task_id node_id name state error_type')

In [None]:
while wait_until_completed == True:
    output = !ray list tasks --format json
    waiting = ray_output('value', output, 'name state')
    
    if "RUNNING" not in waiting and "PENDING" not in waiting:
        print("Training completed "+datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        break
    else:
        print("Pending: "+str(waiting.count("PENDING"))+" RUNNING: "+str(waiting.count("RUNNING"))+" FAILED: "+str(waiting.count("FAILED"))+" FINISHED: "+str(waiting.count("FINISHED")))
        time.sleep(60)
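
The wait loop above runs until no task is pending or running. If an upper bound on the wait is wanted, the same loop can be written with a timeout. This is a sketch under assumptions: `wait_for_tasks` is a hypothetical helper, and `get_counts` stands for any callable that returns a dict of state counts (for example one built from `ray list tasks`).

```python
import time

def wait_for_tasks(get_counts, poll_seconds=60, timeout_seconds=None, sleep=time.sleep):
    """Poll get_counts() until no task is RUNNING or PENDING, then return the final counts."""
    deadline = None if timeout_seconds is None else time.monotonic() + timeout_seconds
    while True:
        counts = get_counts()
        if counts.get("RUNNING", 0) == 0 and counts.get("PENDING", 0) == 0:
            return counts
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"tasks still active after {timeout_seconds} seconds: {counts}")
        print(f"PENDING: {counts.get('PENDING', 0)} RUNNING: {counts.get('RUNNING', 0)}")
        sleep(poll_seconds)

# Hypothetical sequence of snapshots standing in for repeated status queries.
snapshots = iter([
    {"PENDING": 2, "RUNNING": 1},
    {"RUNNING": 2, "FINISHED": 1},
    {"FINISHED": 3},
])
final = wait_for_tasks(lambda: next(snapshots), poll_seconds=0)
print(final)
```

Passing the status source in as a callable keeps the loop independent of how the counts are obtained and makes it easy to exercise without a cluster.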

In [None]:
# output files for each model
session.sql('ls @'+result_stage).collect()

In [None]:
output = !ray list tasks --format json 
ray_output('parse', output, 'task_id name state error_type')

In [None]:
# replace the task ids below with ids taken from the task list printed above
# worker output
task_progress = !ray logs task --id 75b8161497c41428ffffffffffffffffffffffff07000000 --tail -1  

# training output
task_output =   !ray logs task --id 975b61ee4287345bffffffffffffffffffffffff07000000  --err   --tail -1  

try:
    if "Traceback" in task_progress[0]:
        print("Task id is invalid, check taskid in the Execution_summary cell")
    else:
        print("[Task Progress]")
        for line in task_progress:
            print(line)
    
    if "Traceback" in task_output[0]:
        print("Task id is invalid, check taskid in the Execution_summary cell")
    else:
        print("[Task Output]")
        for line in task_output:
            print(line)
except Exception:
    print("No logs were available, check the task status.")

# What's Next?
The model pickle file is saved into the Snowflake stage that was defined in the **Change_as_needed** cell in this Notebook.

The model can be registered in the Model Registry and then deployed. A Notebook that performs those steps is [available](https://docs.google.com/presentation/u/0/d/1JTFTH2a1RgQnubebpz3_oaHYxS3Irunjvktcep4lowc/edit); refer to the **Model Registry** section of that Notebook.

The trained models could also be ensembled into one model if needed. This would require loading each model into a fit() and saving the resulting model pickle.

# Get Predictions
The most scalable way to get predictions and persist them into a Snowflake table is to register the model (see the Notebook referenced in the *Whats Next* cell). You can also execute the model in a cell, which is handy for testing the model before deployment.

In [None]:
# we are just going to use one of the models, but this could be a loop to use them all
model =  models_to_train[models_to_train.index('RF')]

try:
    session.file.get(f"@{result_stage}/{model}/{model}.zip", f"/tmp/{model}/")
    print(f"Model {model} downloaded to Notebook")
except Exception as e:
        logging.error(f"Error downloading file: {e}")

In [None]:
# unzip the model files; these are the files from the distributed trainer
try:
    shutil.unpack_archive(f"/tmp/{model}/{model}.zip", f"/tmp/{model}/", "zip")
    print(f"Model {model} file unpacked")
except Exception as e:
        logging.error(f"Error unpacking model file: {e}")   
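
The download and unpack steps rely on the archive holding the contents of the model directory rather than the directory itself. That is why the predictor can later be loaded straight from `/tmp/{model}`. The round trip can be checked locally with the standard library alone. In the sketch below the directory layout and the placeholder files are assumptions that stand in for real training output.

```python
import shutil
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)

    # Placeholder files standing in for a trained model directory.
    model_dir = tmp / "autogluon" / "RF"
    (model_dir / "models" / "RF").mkdir(parents=True)
    (model_dir / "predictor.pkl").write_text("placeholder")
    (model_dir / "models" / "RF" / "model.pkl").write_text("placeholder")

    # Archive the contents of model_dir, as the training task does.
    archive = shutil.make_archive(str(tmp / "RF"), "zip", str(model_dir))

    # Unpack into a separate location, as the notebook does after the download.
    target = tmp / "download" / "RF"
    shutil.unpack_archive(archive, str(target), "zip")
    restored = sorted(p.relative_to(target).as_posix() for p in target.rglob("*"))

print(restored)
```

Because the paths inside the archive are relative to the model directory, the unpacked files sit directly under the target folder with no extra nesting.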

In [None]:
data = session.table(table_name).limit(100000).to_pandas()

In [None]:
# load the model and make the prediction
m = TabularPredictor.load(f"/tmp/{model}")
m.predict(data)

In [None]:
m.fit_summary()

In [None]:
# show the details about this model
m.info()

# End of Notebook