# ARCA Beverage Demo: Many Model Training (MMT)

## Overview
This notebook demonstrates **parallel model training** using Snowflake ML's Many Model Training (MMT).

## Business Challenge (ARCA Real Scenario):
- **Before**: Sequential training of 16 models = **23 hours**
- **After**: Parallel training with MMT = **~1 hour** (20x faster!)

## What We'll Do:
1. Train **6 models in parallel** (one per customer segment)
2. Test **3 algorithms** per segment (XGBoost, RandomForest, LinearRegression)
3. **Auto-select** best model per segment based on RMSE
4. **Register** all models in Model Registry

## Target Variable:
**WEEKLY_SALES_UNITS** - Predict next week's unit sales per customer

In [None]:
from snowflake.snowpark.context import get_active_session
from snowflake.ml.modeling.distributors.many_model import ManyModelTraining
from snowflake.ml.registry import Registry
from snowflake.ml.model import task
from snowflake.ml.model import target_platform
import time
from datetime import datetime

# Reuse the session Snowsight has already authenticated (no credentials handled here)
session = get_active_session()

# Point the session at the demo warehouse, database and schema, in that order
for context_statement in (
    "USE WAREHOUSE ARCA_DEMO_WH",
    "USE DATABASE ARCA_BEVERAGE_DEMO",
    "USE SCHEMA ML_DATA",
):
    session.sql(context_statement).collect()

print(f"Connected to Snowflake")
for label, read_current in (
    ("Database", session.get_current_database),
    ("Schema", session.get_current_schema),
):
    print(f"   {label}: {read_current()}")

## 1. Setup Model Registry & Staging

In [None]:
# Ensure the registry schema and the stage used for MMT artifacts both exist
for ddl_statement in (
    "CREATE SCHEMA IF NOT EXISTS ARCA_BEVERAGE_DEMO.MODEL_REGISTRY",
    "CREATE STAGE IF NOT EXISTS ARCA_BEVERAGE_DEMO.MODEL_REGISTRY.MMT_MODELS",
):
    session.sql(ddl_statement).collect()

# Registry handle bound to the schema created above
registry_location = {
    "database_name": "ARCA_BEVERAGE_DEMO",
    "schema_name": "MODEL_REGISTRY",
}
registry = Registry(session=session, **registry_location)

for confirmation in (
    "‚úÖ Model Registry initialized",
    "‚úÖ Stage for MMT models created",
):
    print(confirmation)

## 2. Prepare Training Data from Feature Store

We'll use the features we created in the Feature Store notebook

In [None]:
# Lazy reference to the training table; each count below runs its own query
training_df = session.table("ARCA_BEVERAGE_DEMO.ML_DATA.TRAINING_DATA")

print(f"\nüìä Training Data Overview:")
total_records = training_df.count()
print(f"   Total records: {total_records:,}")
distinct_customers = training_df.select('CUSTOMER_ID').distinct()
print(f"   Unique customers: {distinct_customers.count():,}")
print(f"\n   Columns: {training_df.columns}")

# Row count per SEGMENT value, ordered by segment
print("\nüìä Records per Segment:")
segment_counts = training_df.group_by('SEGMENT').count().sort('SEGMENT')
segment_counts.show()

## 3. Define Training Function

This function will be executed **in parallel** for each segment.

It tests 3 algorithms and selects the best one based on RMSE.

In [None]:
# Alternative hyperparameter configurations, keyed by set index (0, 1, 2).
# Each set maps an algorithm name to the keyword arguments for that algorithm;
# LinearRegression has no tunable values in any of the sets.
HYPERPARAMETER_SETS = {
    0: dict(
        XGBoost=dict(n_estimators=100, max_depth=4, learning_rate=0.1),
        RandomForest=dict(n_estimators=100, max_depth=10),
        LinearRegression=dict(),
    ),
    1: dict(
        XGBoost=dict(n_estimators=200, max_depth=2, learning_rate=0.05),
        RandomForest=dict(n_estimators=200, max_depth=6),
        LinearRegression=dict(),
    ),
    2: dict(
        XGBoost=dict(n_estimators=50, max_depth=6, learning_rate=0.2),
        RandomForest=dict(n_estimators=50, max_depth=15),
        LinearRegression=dict(),
    ),
}

def train_segment_model(data_connector, context, hyperparameter_set=0):
    """
    Train three candidate regressors for one segment and return the best one.

    Intended to be run by Many Model Training once per partition. The
    partition's rows are loaded with ``data_connector.to_pandas()``, split
    80/20 into train and hold-out sets with a fixed seed, and an XGBoost, a
    RandomForest and a LinearRegression model are fitted on the train split.
    The model with the lowest hold-out RMSE wins.

    Args:
        data_connector: Object exposing ``to_pandas()`` that yields this
            partition's rows (provided by MMT).
        context: Object exposing ``partition_id``, used as the segment name.
        hyperparameter_set: Key into ``HYPERPARAMETER_SETS`` (0, 1, or 2).

    Returns:
        The fitted estimator with the lowest hold-out RMSE. The attributes
        ``best_algorithm``, ``rmse``, ``mae``, ``segment`` and
        ``training_samples`` are attached to it for later reporting.

    Raises:
        KeyError: If ``hyperparameter_set`` is not a key of
            ``HYPERPARAMETER_SETS``; raised before any data is loaded.
    """
    from xgboost import XGBRegressor
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error, mean_absolute_error
    import numpy as np

    # Fail fast with a readable message instead of a bare KeyError raised
    # after the partition data has already been loaded.
    if hyperparameter_set not in HYPERPARAMETER_SETS:
        raise KeyError(
            f"Unknown hyperparameter_set {hyperparameter_set!r}; "
            f"expected one of {sorted(HYPERPARAMETER_SETS)}"
        )
    hp = HYPERPARAMETER_SETS[hyperparameter_set]

    segment_name = context.partition_id
    print(f"\n{'='*80}")
    print(f"üöÄ Training models for {segment_name} (hyperparameter_set={hyperparameter_set})")
    print(f"{'='*80}")

    df = data_connector.to_pandas()
    print(f"üìä Data shape: {df.shape}")

    feature_cols = [
        'CUSTOMER_TOTAL_UNITS_4W',
        'WEEKS_WITH_PURCHASE',
        'VOLUME_QUARTILE',
        'WEEK_OF_YEAR',
        'MONTH',
        'QUARTER',
        'TRANSACTION_COUNT',
        'UNIQUE_PRODUCTS_PURCHASED',
        'AVG_UNITS_PER_TRANSACTION'
    ]
    target_col = 'WEEKLY_SALES_UNITS'

    X = df[feature_cols]
    y = df[target_col]

    # Fixed seed so every algorithm is scored on the same hold-out rows
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    print(f"   Training set: {X_train.shape[0]:,} samples")
    print(f"   Test set: {X_test.shape[0]:,} samples")

    # Every key of the selected hyperparameter set is forwarded to its
    # estimator, so adding a parameter to HYPERPARAMETER_SETS takes effect
    # without editing this function. The seed and n_jobs are defaults that a
    # set may override.
    tree_defaults = {'random_state': 42, 'n_jobs': -1}
    models_to_test = {
        'XGBoost': XGBRegressor(**{**tree_defaults, **hp['XGBoost']}),
        'RandomForest': RandomForestRegressor(**{**tree_defaults, **hp['RandomForest']}),
        'LinearRegression': LinearRegression(**hp['LinearRegression'])
    }

    results = {}

    for model_name, model in models_to_test.items():
        print(f"\n   Training {model_name}...")

        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        mae = mean_absolute_error(y_test, y_pred)

        results[model_name] = {
            'model': model,
            'rmse': rmse,
            'mae': mae
        }

        print(f"      RMSE: {rmse:.2f}")
        print(f"      MAE: {mae:.2f}")

    # Lowest hold-out RMSE wins; on a tie the earlier entry is kept
    best_model_name = min(results, key=lambda k: results[k]['rmse'])
    best = results[best_model_name]
    best_model = best['model']
    best_rmse = best['rmse']
    best_mae = best['mae']

    print(f"\nüèÜ WINNER: {best_model_name}")
    print(f"   RMSE: {best_rmse:.2f}")
    print(f"   MAE: {best_mae:.2f}")
    print(f"{'='*80}\n")

    # Attach the evaluation metadata to the estimator itself so that it
    # travels with the returned model object.
    best_model.best_algorithm = best_model_name
    best_model.rmse = best_rmse
    best_model.mae = best_mae
    best_model.segment = segment_name
    best_model.training_samples = X_train.shape[0]

    return best_model

print("‚úÖ Training function defined")

## 4. Execute Many Model Training (MMT)

### ⏱️ Performance Comparison:
- **Sequential Training** (one after another): ~30-45 minutes
- **Parallel Training** (MMT): ~5-10 minutes
- **Real ARCA Scenario**: 23 hours → 1 hour (20x faster!)

In [None]:
HYPERPARAMETER_SET = 0  # Choose 0, 1, or 2

print("\n" + "="*80)
print("üöÄ STARTING MANY MODEL TRAINING (MMT)")
print("="*80)

print(f"\n‚ö° Mode: STANDARD (Hyperparameter Set {HYPERPARAMETER_SET})")
from functools import partial
# Bind the chosen hyperparameter set into the per-partition training function
training_func = partial(train_segment_model, hyperparameter_set=HYPERPARAMETER_SET)

print(f"\nTraining 6 models in PARALLEL (one per segment)\n")

start_time = time.time()

trainer = ManyModelTraining(
    training_func,
    "ARCA_BEVERAGE_DEMO.MODEL_REGISTRY.MMT_MODELS"
)

training_run = trainer.run(
    partition_by="SEGMENT",
    snowpark_dataframe=training_df,
    run_id=f"arca_weekly_sales_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
)

print("\n‚è≥ Training in progress... Monitoring completion...\n")

# Poll the partition statuses until every partition is DONE or FAILED,
# or until max_wait seconds have passed.
max_wait = 180
check_interval = 5
elapsed = 0
completed = False
done_count = 0
total_count = 0
failed_partitions = []

while elapsed < max_wait:
    time.sleep(check_interval)
    elapsed += check_interval

    try:
        status_by_partition = {
            partition_id: details.status.name
            for partition_id, details in training_run.partition_details.items()
        }
    except Exception:
        # Best effort: a status read may fail transiently, so keep polling.
        # Catching Exception (not a bare except) keeps the cell interruptible.
        print(f"‚è±Ô∏è  {elapsed}s elapsed - Waiting for status update...", end='\r')
        continue

    total_count = len(status_by_partition)
    failed_partitions = [
        partition_id
        for partition_id, status_name in status_by_partition.items()
        if status_name == 'FAILED'
    ]
    done_count = len(failed_partitions) + sum(
        status_name == 'DONE' for status_name in status_by_partition.values()
    )

    print(f"‚è±Ô∏è  {elapsed}s elapsed - Progress: {done_count}/{total_count} models completed", end='\r')

    # Require at least one partition: an empty mapping would otherwise
    # satisfy 0 == 0 and be reported as a finished run.
    if total_count > 0 and done_count == total_count:
        if failed_partitions:
            print(f"\n‚ö†Ô∏è  Training finished with {len(failed_partitions)} failed partition(s): {failed_partitions}" + " "*20)
        else:
            print("\n‚úÖ All models completed!" + " "*50)
        completed = True
        break

if not completed:
    # Fallback: look for artifacts of this run on the stage. Finding files only
    # shows that something was written, not that every partition succeeded.
    print("\n‚è±Ô∏è  Timeout reached - Verifying completion via stage..." + " "*30)
    stage_files = session.sql(f"LIST @ARCA_BEVERAGE_DEMO.MODEL_REGISTRY.MMT_MODELS PATTERN='.*{training_run.run_id}.*'").collect()
    if stage_files:
        print(f"‚úÖ Found {len(stage_files)} model files in stage - Training completed successfully!")
        completed = True
    else:
        print("‚ö†Ô∏è  No model files found - Training may have failed")

end_time = time.time()
elapsed_minutes = (end_time - start_time) / 60

if not completed:
    final_status = "UNKNOWN"
elif failed_partitions:
    final_status = "COMPLETED_WITH_FAILURES"
else:
    final_status = "COMPLETED"

print("\n" + "="*80)
if final_status == "COMPLETED":
    print(f"‚úÖ TRAINING COMPLETE! Status: {final_status}")
else:
    print(f"TRAINING FINISHED. Status: {final_status}")
print("="*80)
print(f"\n‚è±Ô∏è  Total training time: {elapsed_minutes:.2f} minutes")
print(f"\nüìä Performance Improvement:")
# The sequential figure is an estimate, not a measurement: it assumes each
# model would take as long as the whole parallel run. Use the number of
# partitions actually observed, falling back to the expected 6.
model_count = total_count or 6
sequential_estimate = elapsed_minutes * model_count
speedup = sequential_estimate / elapsed_minutes if elapsed_minutes > 0 else float(model_count)
print(f"   Sequential (estimated): {sequential_estimate:.2f} minutes")
print(f"   Parallel (actual): {elapsed_minutes:.2f} minutes")
print(f"   Speedup: {speedup:.1f}x faster! üöÄ")

In [None]:
# Check the status without interrupting the run
print("Checking training status...")
print(f"\nPartition Details:")
# Snapshot the statuses once, then report one line per partition
status_by_partition = {
    partition_id: details.status
    for partition_id, details in training_run.partition_details.items()
}
for partition_id, partition_status in status_by_partition.items():
    print(f"  {partition_id}: {partition_status}")

## 5. Review Training Results

In [None]:
print("\nüìä Training Results by Segment:\n")

for partition_id in training_run.partition_details:
    details = training_run.partition_details[partition_id]
    
    if details.status == "DONE":
        model = training_run.get_model(partition_id)
        
        print(f"\n{partition_id}:")
        print(f"   Algorithm: {model.best_algorithm}")
        
        # Display RMSE with CV standard deviation if available
        if hasattr(model, 'rmse_std'):
            print(f"   CV RMSE: {model.rmse:.2f} (+/- {model.rmse_std:.2f})")
        else:
            print(f"   RMSE: {model.rmse:.2f}")
            
        print(f"   MAE: {model.mae:.2f}")
        print(f"   Training samples: {model.training_samples:,}")
        
        # Display HPO and CV-specific info if available
        if hasattr(model, 'hpo_trials'):
            print(f"   HPO Trials: {model.hpo_trials}")
        if hasattr(model, 'cv_folds'):
            print(f"   CV Folds: {model.cv_folds}")
        if hasattr(model, 'best_hyperparameters'):
            print(f"   Best Hyperparameters:")
            for param, value in model.best_hyperparameters.items():
                if isinstance(value, float):
                    print(f"      {param}: {value:.4f}")
                else:
                    print(f"      {param}: {value}")
    else:
        print(f"\n‚ùå {partition_id}: Training failed")
        print(f"   Status: {details.status}")

## 6. Register Models in Model Registry

Register each segment's model with metadata and metrics

In [None]:
print("\nüìù Registering models in Model Registry...\n")

version_date = datetime.now().strftime('%Y%m%d_%H%M')  # Include hour:minute for uniqueness
registered_models = {}

for partition_id in training_run.partition_details:
    details = training_run.partition_details[partition_id]
    
    if details.status.name == "DONE":
        model = training_run.get_model(partition_id)
        
        model_name = f"weekly_sales_forecast_{partition_id.lower()}"
        
        sample_input = training_df.filter(
            training_df['SEGMENT'] == partition_id
        ).select([
            'CUSTOMER_TOTAL_UNITS_4W',
            'WEEKS_WITH_PURCHASE',
            'VOLUME_QUARTILE',
            'WEEK_OF_YEAR',
            'MONTH',
            'QUARTER',
            'TRANSACTION_COUNT',
            'UNIQUE_PRODUCTS_PURCHASED',
            'AVG_UNITS_PER_TRANSACTION'
        ]).limit(5)
        
        print(f"Registering {partition_id}...")
        
        mv = registry.log_model(
            model,
            model_name=model_name,
            version_name=f"v_{version_date}",
            comment=f"Weekly sales forecast model for {partition_id} - Algorithm: {model.best_algorithm}",
            metrics={
                "rmse": float(model.rmse),
                "mae": float(model.mae),
                "training_samples": int(model.training_samples),
                "algorithm": model.best_algorithm,
                "segment": model.segment
            },
            sample_input_data=sample_input,
            task=task.Task.TABULAR_REGRESSION,
            target_platforms=["WAREHOUSE"]
        )
        
        registered_models[partition_id] = {
            'model_name': model_name,
            'version': f"v_{version_date}",
            'model_version': mv
        }
        
        print(f"‚úÖ {partition_id}: {model_name} v_{version_date}")
        print(f"   Algorithm: {model.best_algorithm}, RMSE: {model.rmse:.2f}")

print(f"\n‚úÖ All {len(registered_models)} models registered successfully!")
print("\nüí° Models registered for WAREHOUSE and SPCS inference")

## 7. Test Quick Prediction

In [None]:
print("\nüß™ Quick Model Validation Test\n")

test_segment = 'SEGMENT_1'
model_info = registered_models[test_segment]
model_name = model_info['model_name']
version_name = model_info['version']

print(f"Testing {model_name}@{version_name}...")

model = registry.get_model(model_name)
mv = model.version(version_name)

sample_data = training_df.select(
    'CUSTOMER_TOTAL_UNITS_4W', 'WEEKS_WITH_PURCHASE', 'VOLUME_QUARTILE',
    'WEEK_OF_YEAR', 'MONTH', 'QUARTER', 'TRANSACTION_COUNT',
    'UNIQUE_PRODUCTS_PURCHASED', 'AVG_UNITS_PER_TRANSACTION'
).filter(f"SEGMENT = '{test_segment}'").limit(5)

predictions = mv.run(sample_data, function_name="predict")
print(f"\nüìä Sample Predictions:")
print(predictions.to_pandas().to_string(index=False))

print("\n‚úÖ Model validation completed!")

print("\nüí° Inference Options:")
print("   ‚Ä¢ Warehouse (default): mv.run(data)")
print("   ‚Ä¢ SPCS Service: mv.create_service(...) then mv.run(data, service_name='...')")

## 8. Train alternative models

Train new versions of the models using different hyperparameter configurations. These versions will be registered but not aliased, making them available for validation in notebook 05.

### 8a. Hyperparameter Set 1

In [None]:
print("\n" + "="*80)
print("üîÑ TRAINING V2 MODELS (Hyperparameter Set 1)")
print("="*80)

from functools import partial
# Same training function as the first run, bound to hyperparameter set 1
training_func_v2 = partial(train_segment_model, hyperparameter_set=1)

start_time_v2 = time.time()

trainer_v2 = ManyModelTraining(
    training_func_v2,
    "ARCA_BEVERAGE_DEMO.MODEL_REGISTRY.MMT_MODELS"
)

training_run_v2 = trainer_v2.run(
    partition_by="SEGMENT",
    snowpark_dataframe=training_df,
    run_id=f"arca_weekly_sales_v2_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
)

print(f"\n‚è≥ Training V2 models in progress...")

# Poll until every partition is DONE or FAILED, or max_wait seconds pass
max_wait = 180
check_interval = 5
elapsed = 0
completed_v2 = False

while elapsed < max_wait:
    time.sleep(check_interval)
    elapsed += check_interval

    try:
        status_names = [
            details.status.name
            for _, details in training_run_v2.partition_details.items()
        ]
    except Exception:
        # Best effort: keep polling on a failed status read. Catching
        # Exception (not a bare except) keeps the cell interruptible.
        print(f"‚è±Ô∏è  {elapsed}s elapsed - Waiting for status update...", end='\r')
        continue

    total_count = len(status_names)
    done_count = sum(name in ('DONE', 'FAILED') for name in status_names)

    print(f"‚è±Ô∏è  {elapsed}s elapsed - Progress: {done_count}/{total_count} models completed", end='\r')

    # An empty mapping must not count as "all done" (0 == 0)
    if total_count > 0 and done_count == total_count:
        print("\n‚úÖ All V2 models completed!" + " "*50)
        completed_v2 = True
        break

if not completed_v2:
    # Make the timeout visible instead of falling through silently
    print(f"\nWARNING: V2 training did not report completion within {max_wait}s - check partition statuses before registering")

end_time_v2 = time.time()
elapsed_minutes_v2 = (end_time_v2 - start_time_v2) / 60
print(f"\n‚è±Ô∏è  V2 Training time: {elapsed_minutes_v2:.2f} minutes")

In [None]:
print("\nüìù Registering V2 models in Model Registry...\n")

version_date_v2 = datetime.now().strftime('%Y%m%d_%H%M')
registered_models_v2 = {}

for partition_id in training_run_v2.partition_details:
    details = training_run_v2.partition_details[partition_id]
    
    if details.status.name == "DONE":
        model = training_run_v2.get_model(partition_id)
        model_name = f"weekly_sales_forecast_{partition_id.lower()}"
        version_name = f"v_{version_date_v2}"
        
        sample_input = training_df.filter(
            training_df['SEGMENT'] == partition_id
        ).select([
            'CUSTOMER_TOTAL_UNITS_4W',
            'WEEKS_WITH_PURCHASE',
            'VOLUME_QUARTILE',
            'WEEK_OF_YEAR',
            'MONTH',
            'QUARTER',
            'TRANSACTION_COUNT',
            'UNIQUE_PRODUCTS_PURCHASED',
            'AVG_UNITS_PER_TRANSACTION'
        ]).limit(5)
        
        print(f"Registering {partition_id}...")
        
        mv = registry.log_model(
            model,
            model_name=model_name,
            version_name=version_name,
            comment=f"Weekly sales forecast model for {partition_id} - Algorithm: {model.best_algorithm}",
            metrics={
                "rmse": float(model.rmse),
                "mae": float(model.mae),
                "training_samples": int(model.training_samples),
                "algorithm": model.best_algorithm,
                "segment": model.segment
            },
            sample_input_data=sample_input,
            task=task.Task.TABULAR_REGRESSION,
            target_platforms=["WAREHOUSE"]
        )

        
        registered_models_v2[partition_id] = {
            'model_name': model_name,
            'version': f"v_{version_date}",
            'model_version': mv
        }
        
        print(f"‚úÖ {partition_id}: {model_name} v_{version_date}")
        print(f"   Algorithm: {model.best_algorithm}, RMSE: {model.rmse:.2f}")

print(f"\n‚úÖ All {len(registered_models)} models registered successfully!")
print("\nüí° Models registered for WAREHOUSE and SPCS inference")

### 8b. Hyperparameter Set 2

In [None]:
print("\n" + "="*80)
print("üîÑ TRAINING V3 MODELS (Hyperparameter Set 2)")
print("="*80)

from functools import partial
# Same training function as the first run, bound to hyperparameter set 2
training_func_v3 = partial(train_segment_model, hyperparameter_set=2)

start_time_v3 = time.time()

trainer_v3 = ManyModelTraining(
    training_func_v3,
    "ARCA_BEVERAGE_DEMO.MODEL_REGISTRY.MMT_MODELS"
)

training_run_v3 = trainer_v3.run(
    partition_by="SEGMENT",
    snowpark_dataframe=training_df,
    run_id=f"arca_weekly_sales_v3_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
)

print(f"\n‚è≥ Training V3 models in progress...")

# Poll until every partition is DONE or FAILED, or max_wait seconds pass
max_wait = 180
check_interval = 5
elapsed = 0
# Initialise the V3 flag: the loop below sets completed_v3, so resetting
# completed_v2 here would leave completed_v3 undefined on a timeout.
completed_v3 = False

while elapsed < max_wait:
    time.sleep(check_interval)
    elapsed += check_interval

    try:
        status_names = [
            details.status.name
            for _, details in training_run_v3.partition_details.items()
        ]
    except Exception:
        # Best effort: keep polling on a failed status read. Catching
        # Exception (not a bare except) keeps the cell interruptible.
        print(f"‚è±Ô∏è  {elapsed}s elapsed - Waiting for status update...", end='\r')
        continue

    total_count = len(status_names)
    done_count = sum(name in ('DONE', 'FAILED') for name in status_names)

    print(f"‚è±Ô∏è  {elapsed}s elapsed - Progress: {done_count}/{total_count} models completed", end='\r')

    # An empty mapping must not count as "all done" (0 == 0)
    if total_count > 0 and done_count == total_count:
        print("\n‚úÖ All V3 models completed!" + " "*50)
        completed_v3 = True
        break

if not completed_v3:
    # Make the timeout visible instead of falling through silently
    print(f"\nWARNING: V3 training did not report completion within {max_wait}s - check partition statuses before registering")

end_time_v3 = time.time()
elapsed_minutes_v3 = (end_time_v3 - start_time_v3) / 60
print(f"\n‚è±Ô∏è  V3 Training time: {elapsed_minutes_v3:.2f} minutes")

In [None]:
print("\nüìù Registering V3 models in Model Registry...\n")

version_date_v3 = datetime.now().strftime('%Y%m%d_%H%M')
registered_models_v3 = {}

for partition_id in training_run_v3.partition_details:
    details = training_run_v3.partition_details[partition_id]
    
    if details.status.name == "DONE":
        model = training_run_v3.get_model(partition_id)
        model_name = f"weekly_sales_forecast_{partition_id.lower()}"
        version_name = f"v_{version_date_v3}"
        
        sample_input = training_df.filter(
            training_df['SEGMENT'] == partition_id
        ).select([
            'CUSTOMER_TOTAL_UNITS_4W',
            'WEEKS_WITH_PURCHASE',
            'VOLUME_QUARTILE',
            'WEEK_OF_YEAR',
            'MONTH',
            'QUARTER',
            'TRANSACTION_COUNT',
            'UNIQUE_PRODUCTS_PURCHASED',
            'AVG_UNITS_PER_TRANSACTION'
        ]).limit(5)
        
        print(f"Registering {partition_id}...")
        
        mv = registry.log_model(
            model,
            model_name=model_name,
            version_name=version_name,
            comment=f"Weekly sales forecast model for {partition_id} - Algorithm: {model.best_algorithm}",
            metrics={
                "rmse": float(model.rmse),
                "mae": float(model.mae),
                "training_samples": int(model.training_samples),
                "algorithm": model.best_algorithm,
                "segment": model.segment
            },
            sample_input_data=sample_input,
            task=task.Task.TABULAR_REGRESSION,
            target_platforms=["WAREHOUSE"]
        )

        
        registered_models_v3[partition_id] = {
            'model_name': model_name,
            'version': f"v_{version_date}",
            'model_version': mv
        }
        
        print(f"‚úÖ {partition_id}: {model_name} v_{version_date}")
        print(f"   Algorithm: {model.best_algorithm}, RMSE: {model.rmse:.2f}")

print(f"\n‚úÖ All {len(registered_models)} models registered successfully!")
print("\nüí° Models registered for WAREHOUSE and SPCS inference")