In [None]:
#!/usr/bin/env python3

"""
Step 5: Load Pickle and Log to Model Registry
==============================================

This script demonstrates how to:
1. Load a serialized model from .pkl file
2. Log the model to Snowflake Model Registry with COMPREHENSIVE parameter documentation
3. Understand ALL log_model() parameters and options
4. Follow Snowflake best practices for model registration
5. Verify the model was logged successfully

This is the KEY STEP that shows how to package and register models in Snowflake.

COMPREHENSIVE PARAMETER DOCUMENTATION:
--------------------------------------
This script documents EVERY parameter of registry.log_model() with:
- Parameter name and type
- Purpose and use case
- Valid values and constraints
- Best practices from Snowflake documentation
- Practical examples
"""
import json
import pickle  # noqa: F401  (kept from the original script; not used directly)
import sys
from datetime import datetime
from pathlib import Path

import joblib
import pandas as pd
from snowflake.ml.registry import Registry

from connections import SnowflakeConnection

# Performance metrics copied from model_metrics.json into the registry when present.
_REGISTRY_METRIC_KEYS = ('accuracy', 'precision', 'recall', 'f1', 'roc_auc')


def _print_banner(title, leading_newline=True):
    """Print ``title`` framed by two 80-character rules of '='.

    With ``leading_newline`` the first rule is preceded by a blank line, which
    matches the ``print(f"\\n{'=' * 80}")`` section style used throughout.
    """
    print(("\n" if leading_newline else "") + "=" * 80)
    print(title)
    print("=" * 80)


def load_pickle_model(pickle_path='xgboost_model_joblib.pkl'):
    """
    Load a serialized model from .pkl file.

    Parameters:
    -----------
    pickle_path : str
        Path to the pickle file

    Returns:
    --------
    tuple
        (model, metadata) where metadata has the keys 'file_path',
        'file_size_mb', 'model_type' and 'metrics' (the 'test' section of
        model_metrics.json, or {} when that file does not exist).

    Notes:
    ------
    Exits the process with status 1 when ``pickle_path`` does not exist.
    joblib.load unpickles the file, which can execute arbitrary code: only
    load files produced by a trusted step (03_save_pickle.py).
    """
    _print_banner("LOADING SERIALIZED MODEL", leading_newline=False)

    filepath = Path(pickle_path)

    if not filepath.exists():
        print(f"\n✗ Error: Pickle file not found: {pickle_path}")
        print("Please run 03_save_pickle.py first to create the pickle file.")
        sys.exit(1)

    file_size_mb = filepath.stat().st_size / (1024 * 1024)
    print(f"\nLoading model from: {filepath}")
    print(f"File size: {file_size_mb:.2f} MB")

    # Load with joblib (handles both pickle and joblib formats).
    # SECURITY: unpickling runs arbitrary code -- never point this at untrusted files.
    model = joblib.load(filepath)

    print("✓ Model loaded successfully!")
    print(f"  Type: {type(model).__name__}")
    print(f"  Module: {model.__class__.__module__}")

    # Load metrics if available; only the held-out 'test' section is used.
    metrics = {}
    metrics_path = Path('model_metrics.json')
    if metrics_path.exists():
        with metrics_path.open('r', encoding='utf-8') as f:
            metrics = json.load(f).get('test', {})
        print("✓ Loaded model metrics")

    metadata = {
        'file_path': str(filepath),
        'file_size_mb': file_size_mb,
        'model_type': type(model).__name__,
        'metrics': metrics,
    }

    return model, metadata


def load_sample_data(n_samples=5, feature_names=None):
    """
    Load sample data for signature inference.

    Parameters:
    -----------
    n_samples : int
        Number of samples to load
    feature_names : sequence of str, optional
        Column names for the synthetic fallback sample (e.g. the model's
        ``feature_names_in_``). When None or empty, FEATURE_00..FEATURE_19
        are used. Ignored when test_data.csv exists.

    Returns:
    --------
    pd.DataFrame
        Sample input data (every column of test_data.csv except 'TARGET',
        or a synthetic standard-normal sample when that file is missing).
    """
    _print_banner("LOADING SAMPLE DATA FOR SIGNATURE INFERENCE")

    test_data_path = Path('test_data.csv')

    if not test_data_path.exists():
        print("⚠ Test data not found, using synthetic sample")
        import numpy as np

        columns = [str(name) for name in feature_names] if feature_names is not None else []
        if not columns:
            columns = [f"FEATURE_{i:02d}" for i in range(20)]
        # Fixed seed so repeated runs log the same sample data / signature.
        rng = np.random.default_rng(0)
        sample_df = pd.DataFrame(
            rng.standard_normal((n_samples, len(columns))),
            columns=columns,
        )
    else:
        print(f"Loading sample data from: {test_data_path}")
        full_data = pd.read_csv(test_data_path)
        # Remove target column if present: the signature must describe inputs only.
        feature_cols = [col for col in full_data.columns if col != 'TARGET']
        sample_df = full_data[feature_cols].head(n_samples)
        print(f"✓ Loaded {len(sample_df)} sample rows")

    print(f"  Features: {len(sample_df.columns)}")
    print(f"\nSample data shape: {sample_df.shape}")

    return sample_df


def initialize_registry(connection_name='legalzoom', database='ML_SHOWCASE', schema='MODELS'):
    """
    Initialize Model Registry connection.

    Parameters:
    -----------
    connection_name : str
        Snowflake connection name (as configured for the Snowflake CLI)
    database : str
        Target database
    schema : str
        Target schema

    Returns:
    --------
    tuple
        (connection, session, registry). The caller owns ``connection`` and
        must close it; if registry setup fails the connection is closed here
        and the exception is re-raised.
    """
    _print_banner("INITIALIZING MODEL REGISTRY")

    print("\nConnecting to Snowflake...")
    connection = SnowflakeConnection.from_snow_cli(connection_name)
    try:
        session = connection.session

        print(f"Initializing registry at {database}.{schema}...")
        registry = Registry(
            session=session,
            database_name=database,
            schema_name=schema,
        )
    except Exception:
        # Don't leak the open connection when registry setup fails.
        connection.close()
        raise

    print(f"✓ Registry initialized: {registry.location}")

    return connection, session, registry


def _print_log_model_parameter_reference():
    """Print the reference documentation for every registry.log_model() parameter."""
    _print_banner("PARAMETER DOCUMENTATION")

    _print_banner("REQUIRED PARAMETERS")

    print("\n1. model (required)")
    print("   Type: object (any pickleable Python object)")
    print("   Purpose: The trained model to register")
    print("   Valid Values: scikit-learn, XGBoost, LightGBM, PyTorch, TensorFlow,")
    print("                 Hugging Face, MLflow, Snowpark ML, custom models")
    print("   Best Practice: Ensure model is trained and validated before logging")
    print("   Example: xgb.XGBClassifier()")

    print("\n2. model_name (required)")
    print("   Type: str")
    print("   Purpose: Unique identifier for the model in the registry")
    print("   Valid Values: Must be valid Snowflake identifier (A-Z, 0-9, _)")
    print("   Best Practice: Use UPPERCASE_WITH_UNDERSCORES naming convention")
    print("   Example: 'XGBOOST_CLASSIFIER', 'CUSTOMER_CHURN_MODEL'")

    _print_banner("OPTIONAL PARAMETERS - VERSIONING")

    print("\n3. version_name (optional)")
    print("   Type: str")
    print("   Purpose: Specific version identifier for this model")
    print("   Valid Values: Must be valid Snowflake identifier")
    print("   Best Practice: Use semantic versioning (v1, v2) or date-based")
    print("                  (v_20250110, v1_production)")
    print("   Default: Auto-generated human-readable version name")
    print("   Example: 'v1_production', 'v_20250110_143022'")

    _print_banner("OPTIONAL PARAMETERS - METADATA")

    print("\n4. comment (optional)")
    print("   Type: str")
    print("   Purpose: Human-readable description of the model")
    print("   Valid Values: Any string")
    print("   Best Practice: Include training date, data source, purpose")
    print("   Example: 'XGBoost classifier trained on 2025-01-10 for churn prediction'")

    print("\n5. metrics (optional)")
    print("   Type: dict")
    print("   Purpose: Store model performance metrics")
    print("   Valid Values: Dictionary with numeric or string values")
    print("   Best Practice: Include accuracy, F1, precision, recall, ROC-AUC")
    print("   Limit: Maximum 100 KB of metadata")
    print("   Example: {'accuracy': 0.95, 'f1_score': 0.94, 'auc': 0.96}")

    _print_banner("OPTIONAL PARAMETERS - DEPENDENCIES")

    print("\n6. conda_dependencies (optional)")
    print("   Type: list of str")
    print("   Purpose: Specify required Conda packages")
    print("   Valid Values: Package names with optional versions")
    print("   Format: '[channel::]package [operator version]'")
    print("   Best Practice: Pin exact versions for reproducibility")
    print("   Default Channel: Snowflake channel (warehouse), conda-forge (SPCS)")
    print("   Example: ['xgboost==2.0.0', 'scikit-learn==1.3.0', 'pandas', 'numpy']")

    print("\n7. pip_requirements (optional)")
    print("   Type: list of str")
    print("   Purpose: Specify required PyPI packages (for SPCS)")
    print("   Valid Values: Package names with optional versions")
    print("   Best Practice: Use for packages not in Conda")
    print("   Note: Only works with SPCS target platform")
    print("   Example: ['custom-package==1.0.0']")

    print("\n8. python_version (optional)")
    print("   Type: str")
    print("   Purpose: Specify Python version for model execution")
    print("   Valid Values: '3.8', '3.9', '3.10', '3.11'")
    print("   Best Practice: Match your training environment")
    print("   Default: Current Python version")
    print("   Example: '3.10'")

    _print_banner("OPTIONAL PARAMETERS - SIGNATURE INFERENCE")

    print("\n9. sample_input_data (optional but recommended)")
    print("   Type: pd.DataFrame or Snowpark DataFrame")
    print("   Purpose: Infer input/output signatures automatically")
    print("   Valid Values: DataFrame with same schema as training data")
    print("   Best Practice: Provide 5-10 representative samples")
    print("   Required: For scikit-learn and XGBoost models")
    print("   Example: X_test.head(5)")

    print("\n10. signatures (optional)")
    print("   Type: dict")
    print("   Purpose: Manually specify method signatures")
    print("   Valid Values: Dict mapping method names to ModelSignature objects")
    print("   Best Practice: Use sample_input_data instead (easier)")
    print("   When to Use: When automatic inference doesn't work")
    print("   Example: {'predict': ModelSignature(...)}")

    _print_banner("OPTIONAL PARAMETERS - DEPLOYMENT")

    print("\n11. target_platforms (optional)")
    print("   Type: list of str")
    print("   Purpose: Specify where model will run")
    print("   Valid Values:")
    print("     - ['WAREHOUSE'] - Run in Snowflake warehouse (default)")
    print("     - ['SNOWPARK_CONTAINER_SERVICES'] - Run in SPCS")
    print("     - ['WAREHOUSE', 'SNOWPARK_CONTAINER_SERVICES'] - Both")
    print("   Best Practice:")
    print("     - WAREHOUSE: Small-medium CPU models, Snowflake packages only")
    print("     - SPCS: Large models, GPU models, custom dependencies")
    print("   Default: ['WAREHOUSE']")
    print("   Example: ['WAREHOUSE']")

    _print_banner("OPTIONAL PARAMETERS - ADVANCED OPTIONS")

    print("\n12. options (optional)")
    print("   Type: dict")
    print("   Purpose: Advanced configuration options")
    print("   Valid Options:")
    print()
    print("   a) enable_explainability (bool, default: True if supported)")
    print("      Purpose: Enable SHAP-based model explanations")
    print("      When to Use: For models that support SHAP")
    print("      When to Disable: For faster inference, unsupported models")
    print("      Example: {'enable_explainability': False}")
    print()
    print("   b) relax_version (bool, default: True)")
    print("      Purpose: Allow flexible dependency versions")
    print("      When to Use: For development and testing")
    print("      When to Disable: For strict production reproducibility")
    print("      Example: {'relax_version': True}")
    print()
    print("   c) embed_local_ml_library (bool, default: False)")
    print("      Purpose: Embed local snowflake-ml-python in model")
    print("      When to Use: When using custom/modified snowflake-ml")
    print("      Best Practice: Keep False to use Snowflake's version")
    print("      Example: {'embed_local_ml_library': False}")
    print()
    print("   d) target_methods (list of str)")
    print("      Purpose: Specify which model methods to expose")
    print("      Valid Values: Method names that exist on the model")
    print("      Default: ['predict', 'predict_proba'] for classifiers")
    print("      Example: {'target_methods': ['predict', 'predict_proba']}")
    print()
    print("   e) method_options (dict)")
    print("      Purpose: Per-method configuration")
    print("      Valid Options: {'case_sensitive': bool}")
    print("      Example: {'method_options': {'predict': {'case_sensitive': True}}}")

    _print_banner("OPTIONAL PARAMETERS - CODE AND FILES")

    print("\n13. code_paths (optional)")
    print("   Type: list of str")
    print("   Purpose: Include custom Python code directories")
    print("   Valid Values: Paths to directories containing Python modules")
    print("   When to Use: When model depends on custom code")
    print("   Example: ['./custom_preprocessing', './model_utils']")

    print("\n14. ext_modules (optional)")
    print("   Type: list of module objects")
    print("   Purpose: External modules to pickle with the model")
    print("   Valid Values: Python module objects")
    print("   Supported: scikit-learn, Snowpark ML, PyTorch, TorchScript, custom")
    print("   When to Use: For custom preprocessing or utility modules")
    print("   Example: [my_custom_module]")

    print("\n15. user_files (optional)")
    print("   Type: list of str")
    print("   Purpose: Include additional files (images, configs, etc.)")
    print("   Valid Values: Paths to files to include")
    print("   When to Use: For config files, images, or other assets")
    print("   Example: ['config.yaml', 'logo.png']")


def _build_registry_metrics(model, metrics, logged_at):
    """Assemble the dict passed as ``log_model(metrics=...)``.

    Only performance metrics actually present in ``metrics`` are recorded:
    writing 0.0 for a missing score would store a fabricated value in the
    registry. Missing keys are reported on stdout instead.
    """
    metrics = metrics or {}
    registry_metrics = {key: metrics[key] for key in _REGISTRY_METRIC_KEYS if key in metrics}
    missing = [key for key in _REGISTRY_METRIC_KEYS if key not in metrics]
    if missing:
        print(f"\n⚠ Metrics not available, not logged: {', '.join(missing)}")

    registry_metrics.update({
        'model_type': type(model).__name__,
        # NOTE: this is the time the model is logged; the real training time
        # is not available to this script.
        'training_date': logged_at.isoformat(),
        'data_source': 'synthetic_classification_dataset',
    })
    return registry_metrics


def log_model_comprehensive(registry, model, sample_data, metrics):
    """
    Log model to registry with COMPREHENSIVE parameter documentation.

    This function demonstrates ALL parameters of registry.log_model() with
    extensive documentation following Snowflake best practices.

    Parameters:
    -----------
    registry : Registry
        Initialized Model Registry
    model : object
        Trained model to log
    sample_data : pd.DataFrame
        Sample input data for signature inference
    metrics : dict
        Model performance metrics; only the keys accuracy/precision/recall/
        f1/roc_auc that are present are logged.

    Returns:
    --------
    ModelVersion
        Logged model version object
    """
    _print_banner("LOGGING MODEL TO REGISTRY - COMPREHENSIVE PARAMETERS")

    print("\nThis demonstrates ALL parameters of registry.log_model()")
    print("with comprehensive documentation and best practices.")

    _print_log_model_parameter_reference()

    # ========================================================================
    # ACTUAL MODEL LOGGING WITH ALL PARAMETERS
    # ========================================================================

    _print_banner("LOGGING MODEL WITH COMPREHENSIVE PARAMETERS")

    # One timestamp for version name, comment and metadata so they always agree.
    logged_at = datetime.now()
    version_name = f"v_{logged_at.strftime('%Y%m%d_%H%M%S')}"

    print("\nModel Name: XGBOOST_CLASSIFIER")
    print(f"Version: {version_name}")
    print("Target Platform: WAREHOUSE")

    registry_metrics = _build_registry_metrics(model, metrics, logged_at)

    print("\nLogging model to registry...")
    print("This may take 1-2 minutes...")

    model_version = registry.log_model(
        # --- Required ---------------------------------------------------
        model=model,
        model_name="XGBOOST_CLASSIFIER",

        # --- Versioning -------------------------------------------------
        version_name=version_name,

        # --- Metadata ---------------------------------------------------
        comment=(
            "XGBoost binary classifier trained on synthetic data. "
            f"Logged on {logged_at.strftime('%Y-%m-%d %H:%M:%S')}. "
            "Model demonstrates comprehensive parameter documentation."
        ),
        metrics=registry_metrics,

        # --- Dependencies -----------------------------------------------
        # NOTE(review): these pins should match the versions the pickle was
        # created with, or unpickling inside Snowflake may fail.
        conda_dependencies=[
            "xgboost==2.0.0",       # Pin exact version for reproducibility
            "scikit-learn==1.3.0",  # Required for metrics and utilities
            "pandas>=2.0.0",        # Flexible version for data handling
            "numpy>=1.24.0",        # Flexible version for numerical operations
        ],
        python_version="3.10",  # Match training environment

        # --- Signature inference ----------------------------------------
        sample_input_data=sample_data,

        # --- Deployment target ------------------------------------------
        target_platforms=["WAREHOUSE"],

        # --- Advanced options -------------------------------------------
        options={
            # Disable explainability for faster inference
            "enable_explainability": False,
            # Allow flexible dependency versions (good for development)
            "relax_version": True,
            # Use Snowflake's snowflake-ml-python (not local version)
            "embed_local_ml_library": False,
            # Expose both predict and predict_proba methods
            "target_methods": ["predict", "predict_proba"],
            # Per-method configuration: preserve column name casing
            "method_options": {
                "predict": {"case_sensitive": True},
                "predict_proba": {"case_sensitive": True},
            },
        },

        # --- Code and files (not used in this example) ------------------
        code_paths=[],   # No custom code directories
        ext_modules=[],  # No external modules to pickle
        # user_files is intentionally omitted: no additional assets to ship.
    )

    print("\n✓ Model logged successfully!")

    return model_version


def verify_model_logging(registry, model_version):
    """
    Verify the model was logged correctly.

    Each registry query is best-effort: a failure is reported on stdout and
    verification continues with the next check.

    Parameters:
    -----------
    registry : Registry
        Model Registry object
    model_version : ModelVersion
        Logged model version
    """
    _print_banner("VERIFYING MODEL REGISTRATION")

    print("\nModel Version Information:")
    print(f"  Fully Qualified Name: {model_version.fully_qualified_model_name}")
    print(f"  Version Name: {model_version.version_name}")
    print(f"  Model Name: {model_version.model_name}")

    print("\nAvailable Functions:")
    try:
        functions = model_version.show_functions()
    except Exception as e:
        print(f"  Could not retrieve functions: {e}")
    else:
        print(functions)

    # show_metrics() returns a dict of metric name -> value (not a DataFrame).
    print("\nModel Metrics:")
    try:
        logged_metrics = model_version.show_metrics()
    except Exception as e:
        print(f"  Could not retrieve metrics: {e}")
    else:
        if logged_metrics:
            for key, value in logged_metrics.items():
                print(f"  {key}: {value}")
        else:
            print("  No metrics found")

    print("\nAll Models in Registry:")
    try:
        models_df = registry.show_models()
    except Exception as e:
        print(f"  Could not list models: {e}")
    else:
        # Only select the summary columns this registry version actually returns.
        summary_cols = [
            col for col in ('name', 'default_version_name', 'created_on')
            if col in models_df.columns
        ]
        print(models_df[summary_cols] if summary_cols else models_df)

    print("\n✓ Verification complete!")


def main():
    """Load the pickled model, log it to the registry, verify it, and summarize."""
    _print_banner("STEP 5: LOAD PICKLE AND LOG TO MODEL REGISTRY")
    print("\nThis script demonstrates comprehensive parameter documentation")
    print("for registry.log_model() following Snowflake best practices.")

    # Configuration
    PICKLE_PATH = 'xgboost_model_joblib.pkl'
    CONNECTION_NAME = 'legalzoom'
    DATABASE = 'ML_SHOWCASE'
    SCHEMA = 'MODELS'

    # Step 1: Load pickle model
    model, metadata = load_pickle_model(PICKLE_PATH)

    # Step 2: Load sample data. The synthetic fallback reuses the model's own
    # feature names when it has them, so the inferred signature matches.
    sample_data = load_sample_data(
        n_samples=5,
        feature_names=getattr(model, 'feature_names_in_', None),
    )

    # Step 3: Initialize registry
    connection, _session, registry = initialize_registry(
        CONNECTION_NAME, DATABASE, SCHEMA
    )

    try:
        # Step 4: Log model with comprehensive parameters
        model_version = log_model_comprehensive(
            registry, model, sample_data, metadata['metrics']
        )

        # Step 5: Verify logging
        verify_model_logging(registry, model_version)

        # Summary
        _print_banner("SUMMARY")
        print(f"✓ Model loaded from: {metadata['file_path']}")
        print(f"✓ Model logged to: {model_version.fully_qualified_model_name}")
        print(f"✓ Version: {model_version.version_name}")
        print(f"✓ Registry: {registry.location}")
        print("✓ All parameters documented and demonstrated!")
    finally:
        # Always release the Snowflake connection, even if logging fails.
        connection.close()

    _print_banner("NEXT STEPS")
    print("Run the next script to perform inference on Snowflake Warehouse:")
    print("  python 06a_inference_warehouse.py")
    print("=" * 80)


if __name__ == "__main__":
    main()