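# Manager

> Feature store manager: register entities and feature views in a Snowflake feature store and retrieve features for training or inference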

In [10]:
# | default_exp manager

In [None]:
# | export
from __future__ import annotations
from typing import List, Optional, Dict, Union, Protocol, Callable
from dataclasses import dataclass
import uuid
from contextlib import contextmanager
from typing import Optional

from snowflake.ml.feature_store import FeatureStore, Entity, FeatureView, CreationMode
from snowflake.snowpark import DataFrame

from snowflake_feature_store.connection import SnowflakeConnection
from snowflake_feature_store.feature_view import FeatureViewConfig, create_feature_view
from snowflake_feature_store.transforms import Transform, apply_transforms
from datetime import datetime, timezone 




In [3]:

# | export
class FeatureStoreCallback(Protocol):
    """Structural interface for observers of feature store events.

    Any object providing these three methods can be passed in a manager's
    ``callbacks`` list; no inheritance from this class is required.
    """

    def on_feature_view_create(self, name: str, df: DataFrame) -> None:
        """Handle a newly created feature view called `name`, built from `df`."""

    def on_entity_create(self, name: str, keys: List[str]) -> None:
        """Handle a newly created entity called `name` with join keys `keys`."""

    def on_error(self, error: str) -> None:
        """Handle a failure described by the message `error`."""


In [4]:

# | export
class PrintCallback:
    """Callback that reports every feature store event on standard output."""

    def on_feature_view_create(self, name: str, df: DataFrame) -> None:
        """Print the feature view name and the number of columns in `df`."""
        column_count = len(df.columns)
        print(f"Created feature view: {name} with {column_count} features")

    def on_entity_create(self, name: str, keys: List[str]) -> None:
        """Print the entity name together with its join keys."""
        message = f"Created entity: {name} with keys: {keys}"
        print(message)

    def on_error(self, error: str) -> None:
        """Print the error message with an ``Error:`` prefix."""
        message = f"Error: {error}"
        print(message)


In [None]:

# | export
@dataclass
class FeatureStoreManager:
    """Manage entities, feature views and dataset retrieval for a feature store.

    On construction a ``FeatureStore`` is opened from the connection's session,
    database, schema and warehouse. Entities and feature views registered
    through this manager are remembered in ``entities`` and ``feature_views``.

    Args:
        connection: Snowflake connection
        callbacks: Optional callbacks for monitoring (``None`` means none)
        overwrite: Whether to overwrite existing features
    """
    connection: SnowflakeConnection
    callbacks: Optional[List[FeatureStoreCallback]] = None
    overwrite: bool = False

    def __post_init__(self):
        "Initialize feature store"
        self.feature_store = FeatureStore(
            session=self.connection.session,
            database=self.connection.database,
            name=self.connection.schema,
            default_warehouse=self.connection.warehouse,
            creation_mode=CreationMode.CREATE_IF_NOT_EXIST
        )
        # Registered objects, keyed by entity name / feature view full name
        self.entities: Dict[str, Entity] = {}
        self.feature_views: Dict[str, FeatureView] = {}
        # Normalise the ``None`` default so later loops need no special case
        self.callbacks = self.callbacks or []

    def _notify_error(self, message: str) -> None:
        """Send `message` to the ``on_error`` hook of every callback."""
        for cb in self.callbacks:
            cb.on_error(message)

    def add_entity(self,
                  name: str,
                  join_keys: List[str],
                  description: Optional[str] = None) -> FeatureStoreManager:
        """Add entity to feature store

        Args:
            name: Entity name
            join_keys: Keys used for joining
            description: Optional description (defaults to ``"Entity <name>"``)

        Returns:
            Self for method chaining

        Raises:
            Exception: Whatever registration raised; callbacks receive
                ``on_error`` before it is re-raised.
        """
        entity = Entity(
            name=name,
            join_keys=join_keys,
            desc=description or f"Entity {name}"
        )
        try:
            self.feature_store.register_entity(entity)
            self.entities[name] = entity
            for cb in self.callbacks:
                cb.on_entity_create(name, join_keys)
        except Exception as e:
            self._notify_error(f"Error creating entity {name}: {str(e)}")
            raise
        return self

    def add_feature_view(self,
                        config: FeatureViewConfig,
                        df: DataFrame,
                        entity_name: str,
                        transforms: Optional[List[Transform]] = None) -> FeatureView:
        """Add feature view to feature store

        Args:
            config: Feature view configuration (supplies name and version)
            df: Source DataFrame for the feature view
            entity_name: Name of an entity previously added via `add_entity`
            transforms: Optional transforms applied to `df` before registration

        Returns:
            The registered feature view

        Raises:
            ValueError: If the entity is unknown or validation fails
            Exception: Any other error raised while creating or registering
                the view. Callbacks receive exactly one ``on_error`` per failure.
        """
        try:
            # Local dictionary lookup first: fail fast before contacting Snowflake
            entity = self.entities.get(entity_name)
            if entity is None:
                raise ValueError(f"Entity {entity_name} not found")

            # Check the input column types before transforming
            self._validate_schema_only(df)

            # Apply transforms if provided
            if transforms:
                df = apply_transforms(df, transforms)

            # Check that the (possibly transformed) query is accepted
            self._validate_sql(df)

            # Create feature view
            feature_view = create_feature_view(config, df, entity)

            # Register feature view; block=True is passed through to the store
            registered_view = self.feature_store.register_feature_view(
                feature_view=feature_view,
                version=config.version,
                block=True,
                overwrite=self.overwrite
            )

            # Store for later reference
            self.feature_views[config.full_name] = registered_view

            # Notify callbacks
            for cb in self.callbacks:
                cb.on_feature_view_create(config.full_name, df)

            return registered_view

        except Exception as e:
            # Single notification point: the validators only raise
            self._notify_error(
                f"Error creating feature view {config.full_name}: {str(e)}"
            )
            raise

    def _resolve_feature_view(self, fv: Union[str, FeatureViewConfig]) -> FeatureView:
        """Look up one feature view given a ``"name/version"`` string or a config."""
        if isinstance(fv, str):
            parts = fv.split('/')
            if len(parts) != 2 or not all(parts):
                raise ValueError(
                    f"Invalid feature view reference {fv!r}: "
                    "expected the form 'name/version'"
                )
            name, version = parts
            return self.feature_store.get_feature_view(name, version)
        if isinstance(fv, FeatureViewConfig):
            return self.feature_store.get_feature_view(
                fv.full_name,
                version=fv.version
            )
        raise TypeError(
            "feature_views entries must be 'name/version' strings or "
            f"FeatureViewConfig objects, got {type(fv).__name__}"
        )

    def get_features(self,
                    spine_df: DataFrame,
                    feature_views: List[Union[str, FeatureViewConfig]],
                    label_cols: Optional[List[str]] = None,
                    dataset_name: Optional[str] = None,
                    spine_timestamp_col: Optional[str] = None,
                    **kwargs) -> DataFrame:
        """Get features for training or inference

        Args:
            spine_df: Spine DataFrame the features are joined onto
            feature_views: Feature views given as ``"name/version"`` strings
                or as `FeatureViewConfig` objects
            label_cols: Optional label columns present in the spine
            dataset_name: Dataset name; when omitted a unique
                ``DATASET_<UTC timestamp>_<8 hex chars>`` name is generated
            spine_timestamp_col: Optional timestamp column of the spine
            **kwargs: Forwarded unchanged to ``generate_dataset``

        Returns:
            The generated dataset read back as a Snowpark DataFrame

        Raises:
            ValueError: If a string reference is not of the form ``name/version``
            TypeError: If an entry is neither a string nor a `FeatureViewConfig`
        """
        views = [self._resolve_feature_view(fv) for fv in feature_views]

        if dataset_name is None:
            # Timezone-aware timestamp plus a random suffix keeps names unique
            timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
            unique_id = str(uuid.uuid4())[:8]
            dataset_name = f"DATASET_{timestamp}_{unique_id}"

        dataset = self.feature_store.generate_dataset(
            name=dataset_name,
            spine_df=spine_df,
            features=views,
            spine_label_cols=label_cols,
            spine_timestamp_col=spine_timestamp_col,
            **kwargs
        )

        return dataset.read.to_snowpark_dataframe()

    def _validate_sql(self, df: DataFrame) -> None:
        """Validate the SQL that will be generated

        Only the first statement listed in ``df.queries['queries']`` is
        checked, by running it wrapped in a ``LIMIT 0`` select.

        Args:
            df: DataFrame to validate

        Raises:
            ValueError: If SQL validation fails (original error is chained).
                Callbacks are not notified here; the caller reports the error.
        """
        try:
            # Get the SQL query
            sql = df.queries['queries'][0]

            # Try executing the query with LIMIT 0 to validate syntax
            test_sql = f"SELECT * FROM ({sql}) LIMIT 0"
            self.connection.session.sql(test_sql).collect()

        except Exception as e:
            raise ValueError(f"SQL validation failed: {str(e)}") from e

    def _validate_schema_only(self, df: DataFrame) -> None:
        """Validate that every column of `df` has a supported data type

        Only ``df.schema`` is inspected; no rows are collected here.

        Args:
            df: DataFrame to validate

        Raises:
            ValueError: If the schema is missing or a field has a type outside
                the supported set (original error is chained). Callbacks are
                not notified here; the caller reports the error.
        """
        try:
            from snowflake.snowpark.types import (
                StringType, IntegerType, FloatType, LongType,
                DateType, TimestampType, DoubleType, BooleanType
            )

            # Exact classes accepted; subclasses and other types are rejected
            # NOTE(review): DecimalType is not in this list - confirm intended
            supported_types = {
                StringType, IntegerType, FloatType, LongType,
                DateType, TimestampType, DoubleType, BooleanType
            }

            schema = df.schema
            if not schema:
                raise ValueError("DataFrame has no schema")

            for field in schema.fields:
                field_type = type(field.datatype)
                if field_type not in supported_types:
                    raise ValueError(
                        f"Unsupported data type for field {field.name}: "
                        f"{field.datatype} (type: {field_type})"
                    )

        except Exception as e:
            raise ValueError(f"Schema validation failed: {str(e)}") from e



In [None]:
#| export 

@contextmanager
def feature_store_session(connection: SnowflakeConnection,
                         schema_name: Optional[str] = None,
                         cleanup: bool = True):
    """Context manager for feature store operations

    Creates the schema if needed, switches the connection to it, and yields a
    `FeatureStoreManager` (print callback, ``overwrite=True``) bound to it.
    On exit the connection is switched back to the schema it had on entry,
    whether or not `cleanup` is set.

    Args:
        connection: Snowflake connection; its ``schema`` attribute is changed
            for the duration of the block
        schema_name: Schema to use; when omitted a unique
            ``FEATURE_STORE_<UTC timestamp>_<8 hex chars>`` name is generated
        cleanup: When true, the schema is dropped with ``CASCADE`` on exit.
            This also applies to a pre-existing schema passed as `schema_name`.

    Yields:
        A `FeatureStoreManager` working in the session schema

    Notes:
        Failures while dropping or restoring are printed, not raised, so they
        never mask an exception coming from the ``with`` body.
    """
    schema = schema_name or f"FEATURE_STORE_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
    original_schema = connection.schema

    try:
        # Create schema
        connection.session.sql(f"CREATE SCHEMA IF NOT EXISTS {connection.database}.{schema}").collect()

        # Set the schema as current
        connection.session.sql(f"USE SCHEMA {connection.database}.{schema}").collect()
        connection.schema = schema

        # Create and yield manager
        manager = FeatureStoreManager(
            connection=connection,
            callbacks=[PrintCallback()],
            overwrite=True
        )
        yield manager

    finally:
        if cleanup:
            try:
                # Cleanup schema and all objects
                connection.session.sql(f"DROP SCHEMA IF EXISTS {connection.database}.{schema} CASCADE").collect()
            except Exception as e:
                print(f"Cleanup failed: {str(e)}")

        # Restore the original schema even when cleanup is disabled, so the
        # shared connection is not left pointing at the session schema
        if original_schema is None:
            # There was no schema to switch back to; only reset the attribute
            connection.schema = original_schema
        else:
            try:
                connection.session.sql(f"USE SCHEMA {connection.database}.{original_schema}").collect()
                connection.schema = original_schema
            except Exception as e:
                print(f"Failed to restore original schema: {str(e)}")


In [None]:
# | hide
# Test setup
from snowflake_feature_store.connection import get_connection
from datetime import datetime, timezone 

def test_feature_store_manager():
    """Integration test: entity, feature view and feature retrieval round trip.

    Creates a scratch schema and table, registers an entity and a feature
    view through `FeatureStoreManager`, retrieves the features and finally
    drops the scratch schema. Errors are printed and then re-raised.
    """
    try:
        test_schema = "TEST_FEATURE_STORE"
        conn = get_connection()
        session = conn.session

        # Timestamp suffix keeps table and feature view names unique per run
        test_id = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
        test_table = f"customer_test_data_{test_id}"
        qualified_table = f"{test_schema}.{test_table}"

        try:
            # Scratch schema plus a two-row source table
            session.sql(f"CREATE SCHEMA IF NOT EXISTS {test_schema}").collect()
            session.sql(f"""
                CREATE OR REPLACE TABLE {qualified_table} (
                    CUSTOMER_ID STRING,
                    SESSION_LENGTH INT,
                    PURCHASES INT
                )
            """).collect()

            session.sql(f"""
                INSERT INTO {qualified_table} VALUES
                ('C1', 10, 100),
                ('C2', 20, 200)
            """).collect()

            # overwrite=True so objects left over from earlier runs are replaced
            manager = FeatureStoreManager(
                connection=conn,
                callbacks=[PrintCallback()],
                overwrite=True,
            )

            # Entity registration must be recorded on the manager
            manager.add_entity("TEST_CUSTOMER", ["CUSTOMER_ID"])
            assert "TEST_CUSTOMER" in manager.entities

            # Source DataFrame backed by the permanent table
            df = session.table(qualified_table)

            feature_descriptions = {
                "SESSION_LENGTH": "Test session length",
                "PURCHASES": "Test purchases"
            }
            config = FeatureViewConfig(
                name=f"test_behavior_{test_id}",
                domain="TEST",
                entity="TEST_CUSTOMER",
                feature_descriptions=feature_descriptions,
            )

            try:
                feature_view = manager.add_feature_view(config, df, "TEST_CUSTOMER")
                assert feature_view is not None

                # Retrieve the features for every customer in the source table
                spine_df = session.sql(f"SELECT CUSTOMER_ID FROM {qualified_table}")
                view_reference = f"{config.full_name}/{config.version}"

                features_df = manager.get_features(
                    spine_df=spine_df,
                    feature_views=[view_reference],
                )
                assert features_df is not None
                assert features_df.count() == 2

            except Exception as exc:
                print(f"Error in feature view creation/retrieval: {str(exc)}")
                raise

        finally:
            # Always remove the scratch schema and everything in it
            session.sql(f"DROP SCHEMA IF EXISTS {test_schema} CASCADE").collect()

    except Exception as exc:
        print(f"Error in test setup: {str(exc)}")
        raise


# Execute the test right away; a failure is reported on stdout, not raised
try:
    test_feature_store_manager()
except Exception as exc:
    print(f"Test failed: {str(exc)}")


In [16]:
#| hide
# Run nbdev's export step for this project's notebooks.
import nbdev; nbdev.nbdev_export()

In [9]:
# | eval: false
from snowflake_feature_store.manager import *
from snowflake_feature_store.transforms import fill_na, date_diff
from snowflake_feature_store.connection import get_connection

# Open a connection; the session handle is kept for interactive use
conn = get_connection()
session = conn.session

# feature_store_session switches the connection to a session schema and,
# with its default cleanup=True, drops that schema again on exit
with feature_store_session(conn) as manager:
    current_schema = manager.connection.schema  # Schema the manager works in

    # Fully qualified name of the demo source table
    source_table = f"{manager.connection.database}.{current_schema}.CUSTOMER_SOURCE"

    # Register the entity the feature view is keyed on
    manager.add_entity(
        name="CUSTOMER",
        join_keys=["CUSTOMER_ID"],
        description="Customer entity",
    )

    # Create the source table
    manager.connection.session.sql(f"""
        CREATE OR REPLACE TABLE {source_table} (
            CUSTOMER_ID STRING,
            SESSION_LENGTH INT,
            PURCHASES INT,
            SIGNUP_DATE DATE
        )
    """).collect()

    # Load two sample customers
    manager.connection.session.sql(f"""
        INSERT INTO {source_table} VALUES
        ('C1', 10, 100, '2024-01-01'),
        ('C2', 20, 200, '2024-01-02')
    """).collect()

    # DataFrame over the source table
    df = manager.connection.session.table(source_table)

    # Describe the feature view, including the column added by date_diff below
    feature_descriptions = {
        "SESSION_LENGTH": "Session length in minutes",
        "PURCHASES": "Number of purchases",
        "DAYS_SINCE_SIGNUP": "Days since customer signup"
    }
    config = FeatureViewConfig(
        name="customer_behavior",
        domain="RETAIL",
        feature_type="BEHAVIOR",
        feature_descriptions=feature_descriptions,
    )

    # Transforms applied to the source DataFrame before registration
    transforms = [
        fill_na(['SESSION_LENGTH']),
        date_diff('SIGNUP_DATE', 'DAYS_SINCE_SIGNUP'),
    ]

    # Register the feature view
    feature_view = manager.add_feature_view(
        config=config,
        df=df,
        entity_name="CUSTOMER",
        transforms=transforms,
    )

    # Spine: one row per customer, with PURCHASES exposed as the label column
    spine_df = manager.connection.session.sql(f"""
        SELECT CUSTOMER_ID, PURCHASES as target 
        FROM {source_table}
    """)

    # Join the features onto the spine as a named training dataset
    features_df = manager.get_features(
        spine_df=spine_df,
        feature_views=[config],
        label_cols=['target'],
        dataset_name="TRAINING_DATA",
    )

    # Show results
    print("\nFeature view created successfully:")
    print(f"Name: {feature_view.name}")
    print("\nRetrieved features:")
    features_df.show()


FeatureStore.get_entity() is in private preview since 1.0.8. Do not use it in production. 


Created entity: CUSTOMER with keys: ['CUSTOMER_ID']


  self._check_dynamic_table_refresh_mode(feature_view_name)


Created feature view: FV_RETAIL_CUSTOMER_BEHAVIOR with 5 features





Feature view created successfully:
Name: FV_RETAIL_CUSTOMER_BEHAVIOR

Retrieved features:
---------------------------------------------------------------------------------------------------
|"CUSTOMER_ID"  |"TARGET"  |"SESSION_LENGTH"  |"PURCHASES"  |"SIGNUP_DATE"  |"DAYS_SINCE_SIGNUP"  |
---------------------------------------------------------------------------------------------------
|C1             |100       |10                |100          |2024-01-01     |406                  |
|C2             |200       |20                |200          |2024-01-02     |405                  |
---------------------------------------------------------------------------------------------------

