#!/usr/bin/env python3
"""
Step 6a: Inference on Snowflake Warehouse
==========================================

This script demonstrates how to:
1. Load a model from Snowflake Model Registry
2. Run inference using Python API (model_version.run())
3. Run inference using SQL API (SELECT model!predict())
4. Compare performance and results
5. Handle different input data formats
6. Best practices for production inference

WAREHOUSE INFERENCE:
--------------------
- Default deployment target for Model Registry
- Runs on Snowflake virtual warehouse (CPU-based)
- Best for: Small-medium models, batch inference
- Limitations: Snowflake Conda channel packages only
- Advantages: Simple, no container management, integrated with Snowflake
"""

import sys
import time
from pathlib import Path

import pandas as pd
from snowflake.ml.registry import Registry

from connections import SnowflakeConnection


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or infinity when denominator is not positive.

    Used for throughput and speed-up figures so that a zero elapsed time
    (possible on coarse clocks) cannot abort the run with ZeroDivisionError.
    """
    if denominator > 0:
        return numerator / denominator
    return float("inf")


def initialize_registry(connection_name='legalzoom', database='ML_SHOWCASE', schema='MODELS'):
    """
    Initialize Model Registry connection.

    Parameters:
    -----------
    connection_name : str
        Snowflake connection name
    database : str
        Target database
    schema : str
        Target schema

    Returns:
    --------
    tuple
        (connection, session, registry)
    """
    print("=" * 80)
    print("INITIALIZING MODEL REGISTRY")
    print("=" * 80)

    print("\nConnecting to Snowflake...")
    connection = SnowflakeConnection.from_snow_cli(connection_name)
    session = connection.session

    print(f"Initializing registry at {database}.{schema}...")
    registry = Registry(
        session=session,
        database_name=database,
        schema_name=schema,
    )

    print(f"✓ Registry initialized: {registry.location}")

    return connection, session, registry


def load_model_from_registry(registry, model_name='XGBOOST_CLASSIFIER'):
    """
    Load a model from the registry.

    Uses the model's default version; when no default is set, falls back to
    the last entry returned by ``model.versions()``. Exits the process with
    status 1 if the model has no versions at all.

    Parameters:
    -----------
    registry : Registry
        Model Registry object
    model_name : str
        Name of the model to load

    Returns:
    --------
    tuple
        (Model, ModelVersion)
    """
    print(f"\n{'=' * 80}")
    print("LOADING MODEL FROM REGISTRY")
    print("=" * 80)

    print(f"\nModel name: {model_name}")

    # Get the model object
    print("Retrieving model from registry...")
    model = registry.get_model(model_name)

    print(f"✓ Model retrieved: {model.fully_qualified_name}")

    # Get the default version
    print("\nGetting default model version...")
    model_version = model.default

    if model_version is None:
        print("⚠ No default version set, using latest version...")
        versions = model.versions()
        if not versions:
            print(f"✗ No versions found for model {model_name}")
            sys.exit(1)
        # NOTE(review): assumes versions() is ordered oldest -> newest — confirm.
        model_version = versions[-1]

    print(f"✓ Model version: {model_version.version_name}")
    print(f"  Fully qualified: {model_version.fully_qualified_model_name}")

    # Display model information (best effort: a failure here is only reported)
    print("\nModel Information:")
    try:
        functions = model_version.show_functions()
        print(f"  Available functions: {len(functions)}")
        for _, func in functions.iterrows():
            print(f"    - {func['name']}")
    except Exception as e:
        print(f"  Could not retrieve functions: {e}")

    return model, model_version


def load_test_data(n_samples=100):
    """
    Load test data for inference.

    Reads ``test_data.csv`` from the current working directory and splits off
    the ``TARGET`` column when it is present. Exits the process with status 1
    if the file does not exist.

    Parameters:
    -----------
    n_samples : int
        Number of samples to load

    Returns:
    --------
    tuple
        (X, y_true): feature DataFrame and the TARGET Series, or None for
        y_true when the file has no TARGET column
    """
    print(f"\n{'=' * 80}")
    print("LOADING TEST DATA")
    print("=" * 80)

    test_data_path = Path('test_data.csv')

    if not test_data_path.exists():
        print(f"✗ Test data not found: {test_data_path}")
        print("Please run 02_train_model.py first to generate test data.")
        sys.exit(1)

    print(f"\nLoading data from: {test_data_path}")
    df = pd.read_csv(test_data_path)

    # Remove target column if present
    if 'TARGET' in df.columns:
        y_true = df['TARGET']
        X = df.drop(columns=['TARGET'])
    else:
        y_true = None
        X = df

    # Limit to n_samples
    X = X.head(n_samples)
    if y_true is not None:
        y_true = y_true.head(n_samples)

    print(f"✓ Loaded {len(X)} samples")
    print(f"  Features: {len(X.columns)}")

    return X, y_true


def python_api_inference(model_version, X_test, y_test=None):
    """
    Run inference using Python API (model_version.run()).

    This is the recommended method for programmatic inference.

    Parameters:
    -----------
    model_version : ModelVersion
        Model version object
    X_test : pd.DataFrame
        Test features
    y_test : pd.Series, optional
        True labels for evaluation

    Returns:
    --------
    tuple
        (predictions, probabilities, inference_time) where inference_time is
        the elapsed time of the predict() call only
    """
    print(f"\n{'=' * 80}")
    print("PYTHON API INFERENCE")
    print("=" * 80)

    print("\nMethod: model_version.run()")
    print(f"Input: pandas DataFrame with {len(X_test)} rows")

    # ========================================================================
    # METHOD 1: predict() - Get class predictions
    # ========================================================================

    print("\n1. Running predict() method...")
    # perf_counter is monotonic, so the measured duration cannot go negative
    # if the wall clock is adjusted mid-call.
    start_time = time.perf_counter()

    predictions_df = model_version.run(
        X_test,
        function_name='"predict"',
    )

    predict_time = time.perf_counter() - start_time

    # Extract predictions (handle different output formats)
    if 'OUTPUT_0' in predictions_df.columns:
        predictions = predictions_df['OUTPUT_0'].values
    elif 'PREDICT' in predictions_df.columns:
        predictions = predictions_df['PREDICT'].values
    else:
        # Take the last column as predictions
        predictions = predictions_df.iloc[:, -1].values

    print(f"✓ Predictions completed in {predict_time:.4f} seconds")
    print(f"  Throughput: {_safe_ratio(len(X_test), predict_time):.2f} samples/second")
    print(f"  Output shape: {predictions.shape}")
    print(f"  Sample predictions: {predictions[:5]}")

    # ========================================================================
    # METHOD 2: predict_proba() - Get class probabilities
    # ========================================================================

    print("\n2. Running predict_proba() method...")
    start_time = time.perf_counter()

    probabilities_df = model_version.run(
        X_test,
        function_name='"predict_proba"',
    )

    proba_time = time.perf_counter() - start_time

    print(f"✓ Probabilities completed in {proba_time:.4f} seconds")
    print(f"  Throughput: {_safe_ratio(len(X_test), proba_time):.2f} samples/second")
    print(f"  Output shape: {probabilities_df.shape}")
    print(f"  Output columns: {list(probabilities_df.columns)}")
    print("\n  Sample probabilities:")
    print(probabilities_df.head())

    # ========================================================================
    # EVALUATION (if true labels provided)
    # ========================================================================

    if y_test is not None:
        print("\n3. Evaluating predictions...")
        from sklearn.metrics import accuracy_score, f1_score, classification_report
        from sklearn.utils.multiclass import unique_labels

        accuracy = accuracy_score(y_test, predictions)
        f1 = f1_score(y_test, predictions, average='macro')

        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  F1 Score: {f1:.4f}")

        # Derive the display names from the labels actually present.
        # A fixed two-name list makes classification_report raise when the
        # sample holds one class only, or more than two.
        labels = unique_labels(y_test, predictions)
        target_names = [f"Class {label}" for label in labels]

        print("\n  Classification Report:")
        print(classification_report(
            y_test, predictions, labels=labels, target_names=target_names
        ))

    return predictions, probabilities_df, predict_time


def sql_api_inference(session, model_version, table_name='ML_SHOWCASE.DATA.SYNTHETIC_DATA', n_samples=100):
    """
    Run inference using SQL API (SELECT model!predict()).

    This method is useful for:
    - Ad-hoc queries in Snowflake UI
    - Integration with existing SQL workflows
    - Batch processing of large tables

    Each of the three queries is best effort: a failure is printed and the
    function carries on with the next query.

    Parameters:
    -----------
    session : Session
        Snowpark session
    model_version : ModelVersion
        Model version object
    table_name : str
        Fully qualified table name
    n_samples : int
        Number of samples to process

    Returns:
    --------
    tuple
        (result_df, sql_time) for the DEFAULT-version query; both are None
        when that query failed
    """
    print(f"\n{'=' * 80}")
    print("SQL API INFERENCE")
    print("=" * 80)

    # NOTE(review): the model name is used unqualified in the SQL below, so
    # the queries presumably rely on the session's current database/schema
    # matching the registry location — confirm.
    model_name = model_version.model_name
    version_name = model_version.version_name

    print(f"\nModel: {model_name}")
    print(f"Version: {version_name}")
    print(f"Table: {table_name}")

    # ========================================================================
    # METHOD 1: Use DEFAULT version
    # ========================================================================

    print("\n1. Using DEFAULT version (simplest method):")

    sql_default = f"""
    SELECT 
        ID,
        {model_name}!predict(*) AS PREDICTION
    FROM {table_name}
    LIMIT {n_samples}
    """

    print("\nSQL Query:")
    print(f"{sql_default}")

    print("\nExecuting query...")
    start_time = time.perf_counter()

    # Only the query itself sits in the try body, so an error while printing
    # the results is not misreported as a failed query.
    try:
        result_df = session.sql(sql_default).to_pandas()
    except Exception as e:
        print(f"✗ Query failed: {e}")
        result_df = None
        sql_time = None
    else:
        sql_time = time.perf_counter() - start_time

        print(f"✓ Query completed in {sql_time:.4f} seconds")
        print(f"  Throughput: {_safe_ratio(len(result_df), sql_time):.2f} samples/second")
        print(f"  Results: {len(result_df)} rows")
        print("\nSample results:")
        print(result_df.head())

    # ========================================================================
    # METHOD 2: Use SPECIFIC version
    # ========================================================================

    print("\n2. Using SPECIFIC version (for version control):")

    sql_specific = f"""
    WITH model_version AS MODEL {model_name} VERSION {version_name}
    SELECT 
        ID,
        model_version!predict(*) AS PREDICTION
    FROM {table_name}
    LIMIT {n_samples}
    """

    print("\nSQL Query:")
    print(f"{sql_specific}")

    print("\nExecuting query...")
    start_time = time.perf_counter()

    try:
        result_specific_df = session.sql(sql_specific).to_pandas()
    except Exception as e:
        print(f"✗ Query failed: {e}")
    else:
        sql_specific_time = time.perf_counter() - start_time

        print(f"✓ Query completed in {sql_specific_time:.4f} seconds")
        print(f"  Results: {len(result_specific_df)} rows")

    # ========================================================================
    # METHOD 3: With all features and probabilities
    # ========================================================================

    print("\n3. Getting predictions AND probabilities:")

    sql_full = f"""
    SELECT 
        ID,
        {model_name}!predict(*) AS PREDICTION,
        {model_name}!predict_proba(*) AS PROBABILITIES
    FROM {table_name}
    LIMIT 10
    """

    print("\nSQL Query:")
    print(f"{sql_full}")

    print("\nExecuting query...")

    try:
        result_full_df = session.sql(sql_full).to_pandas()
    except Exception as e:
        print(f"✗ Query failed: {e}")
    else:
        print("✓ Query completed")
        print("\nFull results with probabilities:")
        print(result_full_df.head())

    return result_df, sql_time


def compare_python_vs_sql(python_time, sql_time, n_samples):
    """
    Compare Python API vs SQL API performance.

    Parameters:
    -----------
    python_time : float
        Python API inference time
    sql_time : float
        SQL API inference time
    n_samples : int
        Number of samples processed
    """
    print(f"\n{'=' * 80}")
    print("PYTHON API VS SQL API COMPARISON")
    print("=" * 80)

    print(f"\n{'Metric':<30} {'Python API':<20} {'SQL API':<20}")
    print("-" * 70)

    python_throughput = _safe_ratio(n_samples, python_time)
    sql_throughput = _safe_ratio(n_samples, sql_time)

    print(f"{'Total Time (seconds)':<30} {python_time:<20.4f} {sql_time:<20.4f}")
    print(f"{'Throughput (samples/sec)':<30} {python_throughput:<20.2f} {sql_throughput:<20.2f}")

    faster = 'Python API' if python_time < sql_time else 'SQL API'
    slower_time = max(python_time, sql_time)
    faster_time = min(python_time, sql_time)
    # Equal times (including both zero) are a 1.00x speed-up, not a division.
    speedup = 1.0 if slower_time == faster_time else _safe_ratio(slower_time, faster_time)
    print(f"{'Faster Method':<30} {faster:<20} ({speedup:.2f}x)")

    print("\nRecommendations:")
    print("  • Python API: Best for programmatic access, notebooks, applications")
    print("  • SQL API: Best for ad-hoc queries, SQL workflows, BI tools")
    print("  • Both methods use the same underlying model and produce identical results")


def display_best_practices():
    """Display best practices for warehouse inference."""
    print(f"\n{'=' * 80}")
    print("BEST PRACTICES FOR WAREHOUSE INFERENCE")
    print("=" * 80)

    print("\n1. Model Size:")
    print("   • Keep models < 100 MB for best performance")
    print("   • Larger models may have longer cold-start times")
    print("   • Consider SPCS for models > 500 MB")

    print("\n2. Dependencies:")
    print("   • Use packages from Snowflake Conda channel")
    print("   • Pin versions for reproducibility")
    print("   • Test dependencies before production deployment")

    print("\n3. Batch Size:")
    print("   • Process 100-10,000 rows per batch for optimal performance")
    print("   • Larger batches amortize overhead costs")
    print("   • Monitor warehouse credit usage")

    print("\n4. Warehouse Sizing:")
    print("   • Start with X-Small or Small warehouse")
    print("   • Scale up if inference is slow")
    print("   • Use auto-suspend to minimize costs")

    print("\n5. Caching:")
    print("   • First inference has cold-start overhead")
    print("   • Subsequent inferences are faster (warm cache)")
    print("   • Keep warehouse running for real-time use cases")

    print("\n6. Monitoring:")
    print("   • Track inference latency and throughput")
    print("   • Monitor warehouse credit usage")
    print("   • Set up alerts for failures")


def main():
    """Main execution function"""
    print("\n" + "=" * 80)
    print("STEP 6A: INFERENCE ON SNOWFLAKE WAREHOUSE")
    print("=" * 80)
    print("\nThis script demonstrates model inference on Snowflake Warehouse")
    print("using both Python API and SQL API.")

    # Configuration
    CONNECTION_NAME = 'legalzoom'
    DATABASE = 'ML_SHOWCASE'
    SCHEMA = 'MODELS'
    MODEL_NAME = 'XGBOOST_CLASSIFIER'
    N_SAMPLES = 100

    # Step 1: Initialize registry
    connection, session, registry = initialize_registry(
        CONNECTION_NAME, DATABASE, SCHEMA
    )

    # Everything after the connection is opened runs under try/finally so the
    # connection is closed even when a step raises or calls sys.exit().
    try:
        # Step 2: Load model from registry
        model, model_version = load_model_from_registry(registry, MODEL_NAME)

        # Step 3: Load test data
        X_test, y_test = load_test_data(n_samples=N_SAMPLES)

        # Step 4: Python API inference
        predictions, probabilities, python_time = python_api_inference(
            model_version, X_test, y_test
        )

        # Step 5: SQL API inference
        sql_results, sql_time = sql_api_inference(
            session, model_version, n_samples=N_SAMPLES
        )

        # Step 6: Compare methods (sql_time is None only when the query failed;
        # a measured 0.0 is still a valid result)
        if sql_time is not None:
            compare_python_vs_sql(python_time, sql_time, N_SAMPLES)

        # Step 7: Best practices
        display_best_practices()

        # Summary
        print(f"\n{'=' * 80}")
        print("SUMMARY")
        print("=" * 80)
        print(f"✓ Model loaded from registry: {MODEL_NAME}")
        print(f"✓ Python API inference: {N_SAMPLES} samples in {python_time:.4f}s")
        if sql_time is not None:
            print(f"✓ SQL API inference: {N_SAMPLES} samples in {sql_time:.4f}s")
        print("✓ Both methods produce identical results")
        print("✓ Ready for production deployment!")
    finally:
        # Close connection
        connection.close()

    print(f"\n{'=' * 80}")
    print("NEXT STEPS")
    print("=" * 80)
    print("Optional: Run the next script to see SPCS deployment:")
    print("  python 06b_inference_spcs.py")
    print("\nOr explore the complete workflow in the Jupyter notebook:")
    print("  jupyter notebook 07_complete_notebook.ipynb")
    print("=" * 80)


if __name__ == "__main__":
    main()