## Many Model Training (MMT) – Example Walkthrough

This notebook demonstrates how to use Snowflake's Many Model Training (MMT) API to train multiple models in parallel.

We’ll:
- Define a custom training function
- Run MMT on a synthetic dataset
- Optionally scale to multiple nodes to speed up MMT
- Monitor training progress and debug failures
- Inspect model logs and metadata from previous runs (even after the notebook/session is closed)
- Run inference on the trained models


🛠️ **Callout**:  
There are two ways to invoke the MMT API:
1. **Inside the Snowflake Notebook environment**.

2. **Outside the notebook via ML Jobs**: run the MMT API in a headless setup. See the [headless setup guide](../ml_jobs) for details.



## Step 1: Basic Setup
Create a stage for training artifacts, define the training function, and generate synthetic training data.

In [None]:
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Create a stage to store training artifacts such as models and logs.
stage_name = "MY_STAGE" 
session.sql(f"CREATE STAGE IF NOT EXISTS {stage_name}").collect()


In [None]:
from typing import Any
import xgboost as xgb
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score
from snowflake.ml.data import DataConnector
from snowflake.ml.modeling.distributors.distributed_partition_function.partition_context import (
    PartitionContext,
)

def user_training_func(data_connector: DataConnector, context: PartitionContext) -> Any:
    """
    User-defined function that trains a model on a single partition and returns it.

    Args:
        data_connector: A Snowflake DataConnector containing only this partition's training data. It is passed
            in by the Snowflake data ingestion framework, which extracts the data from the warehouse and
            converts it into a DataConnector.
        context: PartitionContext for the current partition; `context.partition_id` identifies the partition.

    Returns:
        Trained model object.
    """
    partition_id = context.partition_id
    assert partition_id is not None

    # Load partitioned data.
    pandas_df: pd.DataFrame = data_connector.to_pandas()
    
    # Define feature and label columns
    NUMERICAL_COLUMNS = ["X1", "X2", "X3"]
    LABEL_COLUMNS = "X4"
    
    # Train the model
    model = xgb.XGBRegressor()
    model.fit(pandas_df[NUMERICAL_COLUMNS], pandas_df[LABEL_COLUMNS])
    
    # Evaluate on training data
    preds = model.predict(pandas_df[NUMERICAL_COLUMNS])
    mse = mean_squared_error(pandas_df[LABEL_COLUMNS], preds)
    r2 = r2_score(pandas_df[LABEL_COLUMNS], preds)
    
    # Print metrics
    print(f"[Partition {partition_id}] Training MSE: {mse:.4f}, R²: {r2:.4f}")
    
    return model
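
Before launching a distributed run, it can help to exercise the training logic locally, one partition at a time. The plain-Python sketch below mimics the contract described above: rows are grouped by the partition key and a training function is called once per group with only that group's rows. `run_partitions_locally` and `LocalPartitionContext` are hypothetical helpers, not part of the MMT API; they pass plain row lists rather than a real `DataConnector`, and they use string partition IDs to mirror the `str(...)` model lookup used later in `PartitionedModel`.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping


@dataclass
class LocalPartitionContext:
    """Hypothetical stand-in for PartitionContext; it only carries the partition ID."""
    partition_id: str


def run_partitions_locally(
    rows: Iterable[Mapping[str, Any]],
    partition_by: str,
    train_fn: Callable[[List[Mapping[str, Any]], LocalPartitionContext], Any],
) -> Dict[str, Any]:
    """Group rows by `partition_by` and call `train_fn` once per group.

    Returns a dict mapping each partition ID (as a string) to train_fn's return value.
    """
    groups: Dict[Hashable, List[Mapping[str, Any]]] = defaultdict(list)
    for row in rows:
        groups[row[partition_by]].append(row)

    results: Dict[str, Any] = {}
    for key, group_rows in groups.items():
        ctx = LocalPartitionContext(partition_id=str(key))
        results[ctx.partition_id] = train_fn(group_rows, ctx)
    return results


# Toy data: six rows split across two partitions by LOCATION_ID.
rows = [{"LOCATION_ID": i % 2, "X4": float(i)} for i in range(6)]


def mean_label(partition_rows, context):
    # Toy "model": the mean of the label column within the partition.
    return sum(r["X4"] for r in partition_rows) / len(partition_rows)


local_results = run_partitions_locally(rows, "LOCATION_ID", mean_label)
print(local_results)  # {'0': 2.0, '1': 3.0}
```

Swapping `mean_label` for a thin wrapper around your real training code is a quick way to catch per-partition bugs (such as a partition with too few rows) before paying for a cluster run.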


In [None]:
from sklearn.datasets import make_regression
import pandas as pd
import numpy as np
import uuid

def _init_snowpark_df(curr_session, partition_counts=2):
    """
    Initializes and returns a Snowpark DataFrame containing synthetic regression data.

    This function uses `make_regression` to generate three numerical features
    ("X1", "X2", "X3") and a target variable ("X4") that depends linearly on them.
    Each row is also assigned:
      - A LOCATION_ID (partition key: the row index modulo `partition_counts`)
      - A randomly selected date between 2020-01-01 and 2023-01-01

    The resulting DataFrame is uploaded to Snowflake as a permanent table with a
    unique name in the current database and schema, and a Snowpark DataFrame
    referencing that table is returned.

    Args:
        curr_session: A valid Snowpark session object.
        partition_counts (int, optional): Number of unique partition values 
            for the LOCATION_ID column. Defaults to 2.

    Returns:
        Snowpark DataFrame: A reference to the saved table in Snowflake.
    
    """
    # Generate synthetic data: three input features and a regression target
    X, y = make_regression(n_samples=1000, n_features=3, noise=0.1, random_state=0)
    df = pd.DataFrame(X, columns=["X1", "X2", "X3"])
    df["X4"] = y  # target column
    df["LOCATION_ID"] = np.arange(len(df)) % partition_counts

    # Add random dates between 2020-01-01 and 2023-01-01
    date_range = pd.date_range("2020-01-01", "2023-01-01", freq="D")
    df["DATE"] = np.random.choice(date_range, size=len(df))

    # Create Snowpark DataFrame and save to a uniquely named table
    snowpark_df = curr_session.create_dataframe(df)
    table_name = f"{curr_session.get_current_database()}.{curr_session.get_current_schema()}.mmt_test_{uuid.uuid4().hex.upper()}"
    snowpark_df.write.mode("overwrite").save_as_table(table_name)

    return curr_session.table(table_name)


snowpark_df = _init_snowpark_df(session)
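
Because `LOCATION_ID` is assigned round-robin (row index modulo `partition_counts`), `partition_counts` determines how many models MMT trains and how many rows each one sees. The stdlib-only snippet below reproduces that assignment to show the resulting partition sizes; `round_robin_partition_sizes` is an illustrative helper, not part of the notebook or of `snowflake.ml`.

```python
from collections import Counter
from typing import Dict


def round_robin_partition_sizes(n_rows: int, partition_counts: int) -> Dict[int, int]:
    """Count rows per partition when row i is assigned to partition i % partition_counts.

    This mirrors `np.arange(len(df)) % partition_counts` in `_init_snowpark_df`.
    """
    return dict(Counter(i % partition_counts for i in range(n_rows)))


# 1000 synthetic rows with the default partition_counts=2: two partitions of 500 rows.
print(round_robin_partition_sizes(1000, 2))  # {0: 500, 1: 500}
# Uneven splits put the extra rows in the lowest partition IDs.
print(round_robin_partition_sizes(1000, 3))  # {0: 334, 1: 333, 2: 333}
```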

## Step 2: Invoke MMT API & Monitor MMT Training Run

Optionally, scale the cluster up to multiple nodes before running many model training.

In [None]:
# Optional: scale to multiple nodes to speed up many model training.
# NOTEBOOK_NAME is not defined in this notebook; set it to this notebook's name before uncommenting.
# from snowflake.ml.runtime_cluster import cluster_manager
# TOTAL_NODES = 5
# cluster_manager.scale_cluster(expected_cluster_size=TOTAL_NODES, notebook_name=NOTEBOOK_NAME, options={"block_until_min_cluster_size": 2})

In [None]:
from snowflake.ml.modeling.distributors.many_model import ManyModelTraining
from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (
    ExecutionOptions,
    RunStatus,
)

trainer = ManyModelTraining(
    user_training_func,    
    stage_name=stage_name,    
)

run_id="my_mmt_model_v1"
training_run = trainer.run(
    snowpark_dataframe=snowpark_df,
    partition_by="LOCATION_ID",
    run_id=run_id,
    on_existing_artifacts="overwrite", # or "error"
    # execution_options is optional. In a multi-node setting, use_head_node=False is recommended: it excludes the head node from training, which improves overall MMT reliability.
    # execution_options=ExecutionOptions(use_head_node=False)
)

In [None]:
assert training_run.wait() == RunStatus.SUCCESS

In [None]:
training_run.get_progress()["DONE"][0].logs # inspect result

# To inspect failures
# training_run.get_progress()["FAILED"][0].logs 
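
Indexing `get_progress()["FAILED"]` directly raises a `KeyError` when no partition has failed. The small helpers below are a sketch that assumes, as the indexing in the cell above suggests, that `get_progress()` returns a mapping from a status name to a list of per-partition entries that expose `.logs`; `summarize_progress` and `first_logs` are hypothetical names, and `SimpleNamespace` objects stand in for real progress entries.

```python
from types import SimpleNamespace
from typing import Any, Dict, Mapping, Optional, Sequence


def summarize_progress(progress: Mapping[str, Sequence[Any]]) -> Dict[str, int]:
    """Count partitions per status, e.g. {"DONE": 2}."""
    return {status: len(entries) for status, entries in progress.items()}


def first_logs(progress: Mapping[str, Sequence[Any]], status: str = "FAILED") -> Optional[Any]:
    """Return `.logs` of the first entry with `status`, or None if there is no such entry."""
    entries = progress.get(status) or []
    return entries[0].logs if entries else None


# Stand-in progress data; in the notebook you would pass training_run.get_progress().
fake_progress = {"DONE": [SimpleNamespace(logs="ok-0"), SimpleNamespace(logs="ok-1")]}
print(summarize_progress(fake_progress))  # {'DONE': 2}
print(first_logs(fake_progress))          # None (no FAILED entries)
```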

In [None]:
# Retrieve the model trained on each partition
for partition_id in training_run.partition_details.keys():
    model = training_run.get_model(partition_id)
    assert isinstance(model, xgb.XGBRegressor)

In [None]:
# You normally do not need to interact with the stage directly. This listing shows, for reference, how the
# stage persists models and other artifacts.
session.sql(f"ls @{stage_name}").collect()

## Step 3: Running Inference on Trained Models


### Step 3.1: Register Models in the Snowflake Model Registry and Run Inference (Warehouse Execution) — GA Feature


In [None]:
models = {
    partition_id: training_run.get_model(partition_id)
    for partition_id in training_run.partition_details
}

In [None]:
from typing import Optional
from snowflake.ml.model import custom_model
from snowflake.ml.registry import registry
import pandas as pd


# Define a partitioned custom model that routes each partition's rows to the model trained on that partition
class PartitionedModel(custom_model.CustomModel):
    def __init__(self, context: Optional[custom_model.ModelContext] = None) -> None:
        super().__init__(context)

    @custom_model.partitioned_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        NUMERICAL_COLUMNS = ["X1", "X2", "X3"]

        # All rows in `input` share one LOCATION_ID; the model keys are strings.
        model_id = str(input["LOCATION_ID"].iloc[0])
        model = self.context.model_ref(model_id)

        model_output = model.predict(input[NUMERICAL_COLUMNS])
        res = pd.DataFrame(model_output)
        return res
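
The partitioned `predict` above relies on each call receiving rows from exactly one partition and on looking up the model with a string key. The plain-Python sketch below spells out that routing step with toy models and adds an explicit check that a batch really is single-partition. `predict_one_partition` and `toy_models` are illustrative only and are not part of `snowflake.ml`.

```python
from typing import Callable, Dict, List, Mapping, Sequence

RowDict = Mapping[str, float]
ToyModel = Callable[[List[List[float]]], List[float]]


def predict_one_partition(
    rows: Sequence[RowDict],
    models: Dict[str, ToyModel],
    partition_column: str = "LOCATION_ID",
    feature_columns: Sequence[str] = ("X1", "X2", "X3"),
) -> List[float]:
    """Route a batch that belongs to a single partition to that partition's model."""
    partition_values = {row[partition_column] for row in rows}
    if len(partition_values) != 1:
        raise ValueError(f"expected exactly one partition per batch, got {partition_values}")
    # Model keys are strings, so normalize the partition value the same way.
    model = models[str(partition_values.pop())]
    features = [[row[c] for c in feature_columns] for row in rows]
    return model(features)


toy_models: Dict[str, ToyModel] = {
    "0": lambda X: [sum(x) for x in X],   # toy model for partition 0
    "1": lambda X: [-sum(x) for x in X],  # toy model for partition 1
}
batch = [{"LOCATION_ID": 1, "X1": 1.0, "X2": 2.0, "X3": 3.0}]
print(predict_one_partition(batch, toy_models))  # [-6.0]
```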



In [None]:
from snowflake.ml.model import custom_model

# Models have been fit, and they can now be retrieved and registered to the model registry.
model_context = custom_model.ModelContext(
    models=models
)

my_stateful_model = PartitionedModel(context=model_context)
reg = registry.Registry(session=session)
options = {
    "function_type": "TABLE_FUNCTION",
    "relax_version": False
}
NUMERICAL_COLUMNS = ["X1", "X2", "X3"]
mv = reg.log_model(
    my_stateful_model,
    model_name="partitioned_model",
    options=options,
    conda_dependencies=["pandas", "xgboost"],
    sample_input_data=snowpark_df.limit(1).to_pandas()[NUMERICAL_COLUMNS + ["LOCATION_ID"]],    
)

In [None]:
service_prediction = mv.run(
    snowpark_df,
    partition_column="LOCATION_ID",
)

### Step 3.2: Alternative ManyModelInference Method (Container Execution) — Preview Feature


In [None]:
from snowflake.ml.modeling.distributors.many_model import ManyModelInference

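def xgb_inference_func(data_connector: DataConnector, model, context: PartitionContext):
    """Predict with the model trained on this partition and return the predictions."""
    df = data_connector.to_pandas()
    # Pass a DataFrame with the same feature columns that were used for training.
    X = df[["X1", "X2", "X3"]]
    predictions = model.predict(X)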

    # Write prediction results to persistent storage
    results = df.copy()
    results['predictions'] = predictions
    
    # Two persistence strategies (choose one or both based on your needs):

    # Strategy 1: Stage artifacts - for framework management and debugging
    # context.upload_to_stage(results, "predictions.csv",
    #     write_function=lambda df, path: df.to_csv(path, index=False))

    # Strategy 2: Snowflake table - for immediate downstream consumption
    # predictions_df = context.session.create_dataframe(results)
    # predictions_df.write.mode("append").save_as_table("sales_predictions")
    
    return predictions

mmi = ManyModelInference(
    inference_func=xgb_inference_func,
    stage_name=stage_name,
    training_run_id=run_id, # run_id from previous many model training run at step 2
)


inference_run = mmi.run(
    partition_by="LOCATION_ID",
    snowpark_dataframe=snowpark_df, # running inference on the same training data mainly for illustration purposes.
    run_id="basic_inference_run",
    on_existing_artifacts="overwrite",
)

In [None]:
assert inference_run.wait() == RunStatus.SUCCESS

In [None]:
inference_run.get_progress()
# inference_run.get_progress()["FAILED"][0].logs

## Step 4: Troubleshooting Failed Runs

Training functions can fail for various reasons. Below are some common causes:

- **User Code Errors**  
  Bugs or issues in the user-defined training function can cause failures.

- **Infrastructure Issues**  
  An *Out-of-Memory (OOM)* error occurs when the training function consumes more memory than the node can provide.

- **Unexpected Node Failures**  
  In some cases, a node might crash unexpectedly.

---

### Handling OOM and Node Failures

When an OOM error or fatal node failure occurs, the **MMT API does not automatically retry** the training function. Instead, it marks the run for the corresponding partition ID as **`INTERNAL_ERROR`**. If a worker node crashes, logs might not be captured, which makes debugging more difficult.

In every other failure scenario, including OOM errors, MMT provides:
- A **detailed error message**  
- A **stack trace** to help diagnose and fix the issue

---

### Retry Logic for Non-Fatal Errors

If the failure is not considered fatal (e.g., a transient issue), MMT automatically retries the training function with **exponential backoff**, giving transient issues time to resolve before the partition is marked as failed.

**Retry Attempts:**
1. **First retry**: Wait for 2 seconds (`initial_delay`)
2. **Second retry**: Wait for 4 seconds (2 * `initial_delay`)
3. **Third retry**: Wait for 8 seconds (2^2 * `initial_delay`)
4. **Fourth retry**: Wait for 16 seconds (2^3 * `initial_delay`)
5. **Final retry**: No delay — if it fails again, an exception is raised
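
The schedule above follows `initial_delay * 2**k`. The sketch below is a generic exponential-backoff wrapper, not MMT's internal implementation. It assumes each numbered entry is one attempt whose listed wait is applied after that attempt fails, so the fifth attempt has no delay and its exception propagates. `run_with_backoff` and `backoff_delays` are hypothetical names; the injectable `sleep` exists only so the schedule can be checked without actually waiting.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


def backoff_delays(initial_delay: float = 2.0, max_attempts: int = 5) -> List[float]:
    """Waits (in seconds) after each failed attempt except the last: initial_delay * 2**k."""
    return [initial_delay * 2 ** k for k in range(max_attempts - 1)]


def run_with_backoff(
    func: Callable[[], T],
    initial_delay: float = 2.0,
    max_attempts: int = 5,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call `func` until it succeeds, sleeping with exponential backoff between failures."""
    delays = backoff_delays(initial_delay, max_attempts)
    for attempt in range(max_attempts):
        try:
            return func()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # final attempt: no delay, surface the error
            sleep(delays[attempt])
    raise RuntimeError("unreachable: the loop always returns or raises")


# A function that fails twice with a transient error, then succeeds.
waits: List[float] = []
attempts = {"n": 0}


def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"


outcome = run_with_backoff(flaky, sleep=waits.append)
print(outcome, waits)  # ok [2.0, 4.0]
```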


In [None]:
def user_func_error(data_connector: DataConnector, context: PartitionContext):
    pandas_df = data_connector.to_pandas()

    NUMERICAL_COLUMNS = ["X1", "X2", "X3"]
    LABEL_COLUMNS = ["X4"]
    model = xgb.XGBRegressor()

    # INTENTIONAL USER-CODE FAILURE: fitss function does not exist
    model.fitss(pandas_df[NUMERICAL_COLUMNS], pandas_df[LABEL_COLUMNS])    
    
    return model


model_name="my_mmt_model"
model_version = "v2"
run_id=f"{model_name}_{model_version}"

trainer = ManyModelTraining(
    user_func_error,    
    stage_name=stage_name,
)

failed_trainer_run = trainer.run(
    snowpark_dataframe=snowpark_df,
    partition_by="LOCATION_ID",    
    run_id=run_id,
    on_existing_artifacts="overwrite", # or "error"
)


In [None]:
# MMT retries the user function up to five times and then fails.

# Helper loop (for illustration only) that waits until at least one partition has failed.
import time
while True:
    if "FAILED" in failed_trainer_run.get_progress():
        break
    time.sleep(1)

# Optionally, cancel the entire MMT run as soon as at least one partition has failed.
failed_trainer_run.cancel()

# Show first failed partition logs
failed_trainer_run.get_progress()["FAILED"][0].logs
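
The `while True` loop above never exits if no partition ever reaches `FAILED` (for example, when every partition succeeds). A bounded alternative is sketched below; `poll_until` is a hypothetical helper, and the commented last line shows how it could wrap the same `"FAILED" in ... get_progress()` check. The `clock` and `sleep` parameters are injectable only so the timeout path can be demonstrated instantly.

```python
import time
from typing import Callable


def poll_until(
    predicate: Callable[[], bool],
    timeout_s: float = 600.0,
    interval_s: float = 1.0,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Check `predicate` every `interval_s` seconds until it is True or `timeout_s` elapses.

    Returns True if the predicate became true, False on timeout.
    """
    deadline = clock() + timeout_s
    while True:
        if predicate():
            return True
        if clock() >= deadline:
            return False
        sleep(interval_s)


# Succeeds on the third check; interval_s=0 keeps the demo instant.
checks = iter([False, False, True])
found_first = poll_until(lambda: next(checks), interval_s=0.0)

# A fake clock shows the timeout path without real waiting.
fake_now = [0.0]


def fake_clock() -> float:
    return fake_now[0]


def fake_sleep(seconds: float) -> None:
    fake_now[0] += seconds


found_second = poll_until(lambda: False, timeout_s=3.0, interval_s=1.0, clock=fake_clock, sleep=fake_sleep)
print(found_first, found_second)  # True False

# In the notebook: wait at most 10 minutes for a failed partition.
# poll_until(lambda: "FAILED" in failed_trainer_run.get_progress(), timeout_s=600)
```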

## Step 5: Inspecting Models and Logs After Notebook/Session Closure


After investing considerable time training multiple models, you might want to shut down your notebook temporarily to save resources. How can you recover the trained models and review their logs later?

The `DPFRun` API is designed for this use case: it restores a previous run from its `run_id` and stage, giving you access to its models, logs, and metadata even after the notebook session has ended.

Below, we restore a previous run and inspect its artifacts.

In [None]:
from snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run import (
    DPFRun,
)

restored_run = DPFRun.restore_from(
    run_id="my_mmt_model_v1",  # the successful training run from Step 2
    stage_name=stage_name,
)

# Check the status of the restored run
model_status = restored_run.status
print(model_status)

# You can also access other APIs to inspect logs and metadata
restored_run.get_progress()
