## Many Model Training (MMT) – Example Walkthrough

This notebook demonstrates how to use Snowflake's Many Model Training (MMT) API to train multiple models in parallel.

We’ll:
- Define a custom training function
- Run MMT on a synthetic dataset
- Optionally scale to multiple nodes to speed up MMT
- Monitor training progress and debug failures
- Inspect model logs and metadata from previous runs (even after the notebook/session is closed)


🛠️ **Callout**:  
There are two ways to invoke the MMT API:
1. **Inside the Snowflake Notebook environment**.

2. **Outside the notebook via ML Jobs**: this option lets you run the MMT API in a headless setup. See the [headless setup guide](../ml_jobs) for details. Note that the latest headless version does not include the MMT feature, so adjustments may be required. For assistance, reach out to Snowflake support.



## Step 1: Basic Setup
Set up the training function and generate synthetic training data.

In [None]:
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Create a stage to store training artifacts such as models and logs.
stage_name = "MY_STAGE"
session.sql(f"CREATE STAGE IF NOT EXISTS {stage_name}").collect()


In [None]:
from typing import Any
import xgboost as xgb
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score

def user_training_func(data_connector, **kwargs) -> Any:
    """
    User-defined function that takes in a DataConnector object and returns a trained model.

    Args:
        data_connector: A Snowflake DataConnector object holding the training data for a single
            partition. The Snowflake data ingestion framework extracts the data from the warehouse
            and passes in a DataConnector that contains only that partition's rows.
        **kwargs: Extra keyword arguments; `partition_id` identifies the partition being trained.
            
    Returns:
        Trained model object.
    """
    partition_id = kwargs.get("partition_id", None)
    assert partition_id is not None

    # Load partitioned data.
    pandas_df: pd.DataFrame = data_connector.to_pandas()
    
    # Define feature and label columns
    NUMERICAL_COLUMNS = ["X1", "X2", "X3"]
    LABEL_COLUMNS = "X4"
    
    # Train the model
    model = xgb.XGBRegressor()
    model.fit(pandas_df[NUMERICAL_COLUMNS], pandas_df[LABEL_COLUMNS])
    
    # Evaluate on training data
    preds = model.predict(pandas_df[NUMERICAL_COLUMNS])
    mse = mean_squared_error(pandas_df[LABEL_COLUMNS], preds)
    r2 = r2_score(pandas_df[LABEL_COLUMNS], preds)
    
    # Print metrics
    print(f"[Partition {partition_id}] Training MSE: {mse:.4f}, R²: {r2:.4f}")
    
    return model


In [None]:
from sklearn.datasets import make_regression
import pandas as pd
import numpy as np
import uuid

def _init_snowpark_df(curr_session, partition_counts=2):
    """
    Initializes and returns a Snowpark DataFrame containing synthetic regression data.

    This function generates four numerical columns with `make_regression` (the target
    array it returns is discarded). The first three columns ("X1", "X2", "X3") are treated
    as input features and the fourth column ("X4") as the target variable. Each row is
    also assigned:
      - A LOCATION_ID (partition key: the row index modulo `partition_counts`)
      - A randomly selected date between 2020-01-01 and 2023-01-01

    The resulting DataFrame is written to Snowflake as a permanent table with a
    unique name in the current database and schema, and a Snowpark DataFrame
    that references that table is returned.

    Args:
        curr_session: A valid Snowpark session object.
        partition_counts (int, optional): Number of unique partition values 
            for the LOCATION_ID column. Defaults to 2.

    Returns:
        Snowpark DataFrame: A reference to the saved table in Snowflake.
    
    """
    # Generate synthetic data
    cols = ["X1", "X2", "X3", "X4"]
    X, _ = make_regression(n_samples=1000, n_features=4, noise=0.1, random_state=0)
    df = pd.DataFrame(X, columns=cols)
    df["LOCATION_ID"] = np.arange(len(df)) % partition_counts

    # Add random dates between 2020-01-01 and 2023-01-01
    date_range = pd.date_range("2020-01-01", "2023-01-01", freq="D")
    df["DATE"] = np.random.choice(date_range, size=len(df))

    # Create Snowpark DataFrame and save to a uniquely named table
    snowpark_df = curr_session.create_dataframe(df)
    table_name = f"{curr_session.get_current_database()}.{curr_session.get_current_schema()}.mmt_test_{uuid.uuid4().hex.upper()}"
    snowpark_df.write.mode("overwrite").save_as_table(table_name)

    return curr_session.table(table_name)


## Step 2: Invoke MMT API & Monitor MMT Training Run

You can optionally scale the cluster up to multiple nodes before running many model training.

In [None]:
# Optional: scale out to multiple nodes to speed up the overall many model training run.
# NOTEBOOK_NAME is a placeholder for the name of this notebook; define it before uncommenting.
# from snowflake.ml.runtime_cluster import cluster_manager
# TOTAL_NODES = 5
# cluster_manager.scale_cluster(expected_cluster_size=TOTAL_NODES, notebook_name=NOTEBOOK_NAME, options={"block_until_min_cluster_size": 2})

In [None]:
from snowflake.ml.modeling.distributors.many_model_training.many_model_trainer import (
    ManyModelTrainer,
)
from snowflake.ml.modeling.distributors.many_model_training.entities import (
    ExecutionOptions,
)

snowpark_df = _init_snowpark_df(session)
model_name = "my_mmt_model"
model_version_v1 = "v1"
trainer = ManyModelTrainer(
    training_func=user_training_func,
    model_name=model_name,
    model_version=model_version_v1,
    stage_name=stage_name,
    # execution_options is optional. In a multi-node setting, setting use_head_node=False is
    # recommended: it keeps the head node from doing actual training, which improves overall
    # MMT training reliability.
    # execution_options=ExecutionOptions(use_head_node=False)
)

trainer.run(snowpark_dataframe=snowpark_df, partition_by="LOCATION_ID")


In [None]:
# Depending on the workload size, an MMT run can take a long time to complete. This call is
# interruptible: you can cancel the cell, run other commands, and return later. Interrupting
# this call does not affect the actual MMT run. The show_progress function always reflects
# the current status of the MMT run.
trainer.show_progress()

In [None]:
# Mapping between training status and corresponding SingleModelTrainingDetails objects.
# trainer.get_progress() returns a dictionary where the keys are training statuses
# and the values are lists of SingleModelTrainingDetails objects associated with that status.
#
# Example output:
# {
#     "PENDING": [SingleModelTrainingDetails],
#     "RUNNING": [SingleModelTrainingDetails],
#     "FAILED": [SingleModelTrainingDetails],
#     "DONE": [SingleModelTrainingDetails, SingleModelTrainingDetails],
#     "INTERNAL_ERROR": [SingleModelTrainingDetails]
# }

trainer.get_progress()

In [None]:
# Mapping between partition_id and corresponding SingleModelTrainingDetails object.
# trainer.model_trainings returns a dictionary where the keys are partition IDs (strings),
# and the values are SingleModelTrainingDetails objects representing the training detail for each partition.
#
# Example output:
# {
#     "partition_id1": SingleModelTrainingDetails,
#     "partition_id2": SingleModelTrainingDetails,
#     "partition_id3": SingleModelTrainingDetails,
#     ...
# }

trainer.model_trainings

In [None]:
for partition_id, training_detail in trainer.model_trainings.items():
    print(training_detail.logs)
    assert isinstance(training_detail.model, xgb.XGBRegressor)

In [None]:
# You normally do not need to interact with the stage directly. This listing is only an FYI
# on how the stage is used to persist the models and other artifacts.
session.sql(f"ls @{stage_name}").collect()
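
Conceptually, the `trainer.run(..., partition_by="LOCATION_ID")` call in this step trains one model per distinct partition value. The sketch below is a sequential, single-process analogy of that idea, not how MMT is implemented: `train_per_partition` is a hypothetical helper, and for simplicity the training function receives the partition's rows directly rather than a DataConnector.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping

Row = Mapping[str, Any]


def train_per_partition(
    rows: Iterable[Row],
    partition_by: str,
    training_func: Callable[..., Any],
) -> Dict[str, Any]:
    """Group rows by the partition column and call training_func once per group.

    Hypothetical illustration only: the real MMT API distributes this work
    across the cluster and hands each call a DataConnector.
    """
    groups: Dict[str, List[Row]] = defaultdict(list)
    for row in rows:
        # Partition IDs are kept as strings, matching the keys of trainer.model_trainings.
        groups[str(row[partition_by])].append(row)

    # One "model" per partition, keyed by partition ID.
    return {
        partition_id: training_func(partition_rows, partition_id=partition_id)
        for partition_id, partition_rows in groups.items()
    }
```

With the synthetic data above and the default `partition_counts=2`, this grouping yields two partitions, so two models are trained.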

## Step 3: Troubleshooting Failed Runs

Training functions can fail for various reasons. Below are some common causes:

- **User Code Errors**  
  Bugs or issues in the user-defined training function can cause failures.

- **Infrastructure Issues**  
  An *Out-of-Memory (OOM)* error occurs when the training function consumes more memory than the node can provide.

- **Unexpected Node Failures**  
  In some cases, a node might crash unexpectedly.
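
User-code errors can often be caught before launching a full run by calling the training function once locally. The sketch below is a hypothetical helper, not part of the MMT API: `FakeDataConnector` and `smoke_test` are illustrative names, and the stub mimics only the `to_pandas()` method and the `partition_id` keyword argument that the training functions in this notebook rely on.

```python
from typing import Any, Callable


class FakeDataConnector:
    """Hypothetical local stand-in for a DataConnector.

    It exposes only to_pandas(), the single method that the training
    functions in this notebook call on the connector.
    """

    def __init__(self, frame: Any) -> None:
        self._frame = frame

    def to_pandas(self) -> Any:
        # Return the object supplied at construction, e.g. a small pandas DataFrame.
        return self._frame


def smoke_test(training_func: Callable[..., Any], frame: Any, partition_id: str = "0") -> Any:
    """Call training_func once, locally, with a stub connector and a partition_id."""
    return training_func(FakeDataConnector(frame), partition_id=partition_id)
```

For example, passing the `user_func_error` function defined later in this step together with a small pandas DataFrame would raise the `AttributeError` immediately, without waiting for a distributed run and its retries.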

---

### Handling OOM and Node Failures

When an OOM error or fatal node failure occurs, the **MMT API does not automatically retry** the training function. Instead, it marks the run for the corresponding partition ID as **`INTERNAL_ERROR`**. If a worker node crashes, its logs might not be captured, which makes debugging more difficult.

For OOM errors and all other failure scenarios, MMT provides:
- A **detailed error message**  
- A **stack trace** to help diagnose and fix the issue
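
One way to spot partitions that ended in `FAILED` or `INTERNAL_ERROR` is to summarize the dictionary returned by `trainer.get_progress()`. The helpers below are a minimal sketch with hypothetical names (`count_by_status`, `has_failures`); they assume only the shape shown earlier, a mapping from status string to a list of training details.

```python
from typing import Any, Dict, List, Mapping

# Status names taken from the example get_progress() output in this notebook.
FAILURE_STATUSES = ("FAILED", "INTERNAL_ERROR")


def count_by_status(progress: Mapping[str, List[Any]]) -> Dict[str, int]:
    """Return the number of partitions currently reported under each status."""
    return {status: len(details) for status, details in progress.items()}


def has_failures(progress: Mapping[str, List[Any]]) -> bool:
    """True if at least one partition is reported under a failure status."""
    return any(progress.get(status) for status in FAILURE_STATUSES)
```

In the notebook these would be called as `count_by_status(trainer.get_progress())` and `has_failures(trainer.get_progress())`.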

---

### Retry Logic for Non-Fatal Errors

If the failure is not considered fatal (e.g., transient issues), MMT will automatically retry the training function with **exponential backoff**. This mechanism allows transient issues to resolve before the function ultimately fails.

**Retry Attempts:**
1. **First retry**: Wait for 2 seconds (`initial_delay`)
2. **Second retry**: Wait for 4 seconds (2 * `initial_delay`)
3. **Third retry**: Wait for 8 seconds (2^2 * `initial_delay`)
4. **Fourth retry**: Wait for 16 seconds (2^3 * `initial_delay`)
5. **Final retry**: No delay — if it fails again, an exception is raised
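
The schedule above can be sketched in plain Python. This is an illustration of exponential backoff under one reading of the list (five retries with waits of 2, 4, 8, 16, and 0 seconds), not MMT's actual implementation; `retry_delays` and `call_with_retries` are hypothetical names.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


def retry_delays(initial_delay: float = 2, max_retries: int = 5) -> List[float]:
    """Wait before each retry: doubling delays, then no delay for the final retry."""
    if max_retries <= 0:
        return []
    doubling = [initial_delay * 2**i for i in range(max_retries - 1)]
    return doubling + [0]


def call_with_retries(
    func: Callable[[], T],
    delays: List[float],
    sleep: Callable[[float], object] = time.sleep,
) -> T:
    """Call func once, then retry once per entry in delays, sleeping before each retry.

    The last exception is re-raised if every attempt fails.
    """
    try:
        return func()
    except Exception as exc:
        last_error = exc

    for delay in delays:
        sleep(delay)
        try:
            return func()
        except Exception as exc:
            last_error = exc

    raise last_error
```

Passing `sleep` as a parameter keeps the sketch easy to test: a list's `append` method can stand in for `time.sleep` to record the waits without actually pausing.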


In [None]:
def user_func_error(data_connector, **kwargs):
    pandas_df = data_connector.to_pandas()

    NUMERICAL_COLUMNS = ["X1", "X2", "X3"]
    LABEL_COLUMNS = ["X4"]
    model = xgb.XGBRegressor()

    # INTENTIONAL USER-CODE FAILURE: fitss function does not exist
    model.fitss(pandas_df[NUMERICAL_COLUMNS], pandas_df[LABEL_COLUMNS])    
    
    return model


model_name = "my_mmt_model"
model_version = "v2"

trainer = ManyModelTrainer(
    training_func=user_func_error,
    model_name=model_name,
    model_version=model_version,
    stage_name=stage_name,
)

trainer.run(
    snowpark_dataframe=snowpark_df,
    partition_by="LOCATION_ID",    
)

In [None]:
# MMT retries the user function up to five times before marking the partition as failed.

# Polling loop, for illustration only: wait until at least one partition is reported as FAILED.
# .get() also covers the case where the "FAILED" key exists but its list is still empty.
import time
while not trainer.get_progress().get("FAILED"):
    time.sleep(1)

# You can optionally cancel the entire MMT run once at least one failed run is detected.
# trainer.cancel()

# Show the logs of the first failed partition.
trainer.get_progress()["FAILED"][0].logs

## Step 4: Inspecting Models and Logs After Notebook/Session Closure

After spending significant time training multiple models, you may want to temporarily shut down your notebook to save costs. However, you might wonder how to recover trained models and logs later. Snowflake provides an API specifically for this purpose.

With the `ReadOnlyManyModelTrainer`, you can restore previously trained models and access their logs and metadata, even after the notebook session has ended. Most APIs remain available, but the `.run()` method does not, because model names and versions are immutable: once a model has been trained, you cannot re-run the same training job for that model/version.

The cell below shows how to retrieve and inspect a previously trained model.



In [None]:

from snowflake.ml.modeling.distributors.many_model_training.read_only_many_model_trainer import (
    ReadOnlyManyModelTrainer,
)

# Restore the trained model using the model name, version, and stage name
read_only_trainer = ReadOnlyManyModelTrainer.restore_from(
    model_name=model_name, 
    model_version="v1", 
    stage_name=stage_name
)

# Check the status of the trained model
model_status = read_only_trainer.status
print(model_status)

# You can also access other APIs to inspect logs and metadata (except for .run())
read_only_trainer.get_progress()
