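# manager

> Feature store manager for Snowflake: registers entities and feature views, tracks feature dependencies, monitors feature drift, and generates training datasets.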

In [1]:
#| default_exp manager

In [2]:
#| export
from __future__ import annotations
from typing import List, Optional, Dict, Union, Set
from dataclasses import dataclass
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
import networkx as nx
from pathlib import Path
import json

from snowflake.ml.feature_store import (
    FeatureStore, Entity, FeatureView, CreationMode
)
from snowflake.snowpark import DataFrame
import snowflake.snowpark.functions as F

# Import our modules
from snowflake_feature_store.connection import SnowflakeConnection
from snowflake_feature_store.feature_view import (
    FeatureViewBuilder, create_feature_view, 
    FeatureStats, FeatureMonitor
)
from snowflake_feature_store.transforms import (
    Transform, apply_transforms, TransformConfig,
    moving_agg, fill_na
)
from snowflake_feature_store.config import (
    FeatureViewConfig, FeatureConfig, 
    RefreshConfig, FeatureValidationConfig
)
from snowflake_feature_store.exceptions import (
    FeatureStoreException, EntityError, 
    FeatureViewError, ValidationError
)
from snowflake_feature_store.logging import logger




In [3]:
#| export
class FeatureStoreCallback:
    """Base interface for feature store event hooks.

    Every hook is a no-op that returns ``None``; subclasses override only
    the events they care about.
    """

    def on_feature_view_create(
        self, name: str, df: DataFrame, stats: Dict[str, FeatureStats]
    ) -> None:
        """Hook for a newly registered feature view, its data and per-feature stats."""

    def on_entity_create(self, name: str, keys: List[str]) -> None:
        """Hook for a newly registered entity and its join keys."""

    def on_error(self, error: str) -> None:
        """Hook receiving a human-readable error message."""

    def on_drift_detected(
        self, feature_view: str, feature: str, metrics: Dict[str, float]
    ) -> None:
        """Hook for a feature whose drift metrics were flagged."""

In [4]:
#| export
class MetricsCallback(FeatureStoreCallback):
    """Callback that logs events and optionally persists metrics as JSON.

    Every event is logged via ``logger``. When ``metrics_path`` is set,
    feature-view creation stats and drift metrics are also written to
    ``<metrics_path>/<name>_<UTC YYYYmmdd_HHMMSS>.json``. If that file already
    exists (e.g. two events with the same name in the same second), a
    ``_<n>`` counter is appended rather than overwriting the earlier file.
    """

    def __init__(self, metrics_path: Optional[Union[str, Path]] = None):
        """Create the callback.

        Args:
            metrics_path: Directory for JSON metric files, as ``str`` or
                ``Path``. It is created (with parents) if missing. ``None`` or
                an empty string disables file output.
        """
        # Normalise to Path so a plain string works (str has no .mkdir / "/")
        self.metrics_path = Path(metrics_path) if metrics_path else None
        if self.metrics_path:
            self.metrics_path.mkdir(parents=True, exist_ok=True)

    def _save_metrics(self, name: str, data: Dict) -> None:
        """Write ``data`` as indented JSON if a metrics path is configured.

        The file is opened in exclusive-create mode, so an existing file from
        the same second is never clobbered; a numeric suffix is tried instead.
        """
        if not self.metrics_path:
            return
        timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
        stem = f"{name}_{timestamp}"
        file_path = self.metrics_path / f"{stem}.json"
        counter = 0
        while True:
            try:
                f = open(file_path, 'x', encoding='utf-8')
            except FileExistsError:
                counter += 1
                file_path = self.metrics_path / f"{stem}_{counter}.json"
                continue
            with f:
                json.dump(data, f, indent=2)
            return

    def on_feature_view_create(
        self, name: str, df: DataFrame, stats: Dict[str, FeatureStats]
    ) -> None:
        """Log feature view creation and persist its per-feature statistics.

        Each ``FeatureStats`` is serialised with ``model_dump()``.
        """
        logger.info(f"Created feature view: {name} with {len(df.columns)} features")

        stats_data = {
            'name': name,
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'feature_stats': {
                fname: fstats.model_dump()
                for fname, fstats in stats.items()
            }
        }
        self._save_metrics(f"{name}_creation", stats_data)

    def on_entity_create(self, name: str, keys: List[str]) -> None:
        """Log entity creation"""
        logger.info(f"Created entity: {name} with keys: {keys}")

    def on_error(self, error: str) -> None:
        """Log errors"""
        logger.error(f"Error: {error}")

    def on_drift_detected(
        self, feature_view: str, feature: str, metrics: Dict[str, float]
    ) -> None:
        """Log a drift warning and persist the drift metrics."""
        logger.warning(
            f"Drift detected in {feature_view}.{feature}: {metrics}"
        )
        self._save_metrics(
            f"{feature_view}_{feature}_drift",
            {
                'feature_view': feature_view,
                'feature': feature,
                'timestamp': datetime.now(timezone.utc).isoformat(),
                'metrics': metrics
            }
        )

In [None]:
#| export
class FeatureStoreManager:
    """Manages feature store operations with monitoring and dependency tracking.

    Wraps a Snowflake ``FeatureStore`` opened on the connection's database and
    schema. It keeps in-memory registries of what was created through this
    manager: entities, feature views, configs, baseline statistics and
    transforms. It also keeps a ``networkx`` dependency graph. Registered
    callbacks are notified of entity/feature-view creation, drift and errors.
    """

    def __init__(
        self,
        connection: SnowflakeConnection,
        callbacks: Optional[List[FeatureStoreCallback]] = None,
        metrics_path: Optional[Union[str, Path]] = None,
        overwrite: bool = False
    ):
        """Initialize feature store manager

        Args:
            connection: Snowflake connection. Its session, database, schema
                and warehouse are used to open the feature store, which is
                created if it does not exist.
            callbacks: Optional callbacks for monitoring. The list is copied,
                so the caller's list is never modified.
            metrics_path: Optional path to save metrics. When given, a
                ``MetricsCallback`` writing JSON there is appended.
            overwrite: Forwarded as ``overwrite`` to ``register_feature_view``.
        """
        self.connection = connection
        self.feature_store = FeatureStore(
            session=self.connection.session,
            database=self.connection.database,
            name=self.connection.schema,
            default_warehouse=self.connection.warehouse,
            creation_mode=CreationMode.CREATE_IF_NOT_EXIST
        )

        # Local registries, keyed by entity / feature-view name
        self.entities: Dict[str, Entity] = {}
        self.feature_views: Dict[str, FeatureView] = {}
        self.feature_configs: Dict[str, FeatureViewConfig] = {}
        self.feature_stats: Dict[str, Dict[str, FeatureStats]] = {}
        # Transforms used to build each view; replayed on new data for drift
        self.feature_transforms: Dict[str, List[Transform]] = {}
        # Edges: view -> "<view>.<feature>" -> each declared dependency
        self.dependencies = nx.DiGraph()

        # Copy, so appending the MetricsCallback never mutates caller state
        self.callbacks: List[FeatureStoreCallback] = list(callbacks or [])
        if metrics_path:
            self.callbacks.append(
                MetricsCallback(Path(metrics_path))
            )

        self.overwrite = overwrite
        logger.info("FeatureStoreManager initialized")

    def _notify_error(self, error_msg: str) -> None:
        """Forward an error message to every registered callback."""
        for cb in self.callbacks:
            cb.on_error(error_msg)

    def add_entity(
        self,
        name: str,
        join_keys: List[str],
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> FeatureStoreManager:
        """Add entity to feature store

        Args:
            name: Entity name
            join_keys: Keys used for joining
            description: Optional description (defaults to "Entity <name>")
            tags: Optional metadata tags, applied one by one after registration

        Returns:
            Self for method chaining

        Raises:
            EntityError: Wrapping any failure; callbacks get ``on_error`` first.
        """
        try:
            entity = Entity(
                name=name,
                join_keys=join_keys,
                desc=description or f"Entity {name}"
            )

            self.feature_store.register_entity(entity)
            self.entities[name] = entity

            if tags:
                for key, value in tags.items():
                    self.feature_store.set_tag(entity, key, value)

            for cb in self.callbacks:
                cb.on_entity_create(name, join_keys)

        except Exception as e:
            error_msg = f"Error creating entity {name}: {str(e)}"
            self._notify_error(error_msg)
            raise EntityError(error_msg) from e

        return self

    def add_feature_view(
        self,
        config: FeatureViewConfig,
        df: DataFrame,
        entity_name: str,
        transforms: Optional[List[Transform]] = None,
        collect_stats: bool = True
    ) -> FeatureView:
        """Add feature view to feature store with monitoring

        Args:
            config: Feature view configuration
            df: Source DataFrame
            entity_name: Name of an entity previously added via ``add_entity``
            transforms: Optional transformations to apply before registration
            collect_stats: Forwarded to ``create_feature_view``.
                NOTE(review): baseline stats below are computed regardless of
                this flag; confirm whether ``False`` should skip them.

        Returns:
            The registered feature view.

        Raises:
            FeatureViewError: Wrapping any failure, including an unknown
                entity; callbacks get ``on_error`` first.
        """
        try:
            # Validate schema only (no execution)
            self._validate_schema(df)

            # Fail fast on an unknown entity before doing any other work
            entity = self.entities.get(entity_name)
            if not entity:
                raise EntityError(f"Entity {entity_name} not found")

            if transforms:
                df = apply_transforms(df, transforms)

            feature_view = create_feature_view(
                config=config,
                feature_df=df,
                entities=entity,
                collect_stats=collect_stats
            )

            registered_view = self.feature_store.register_feature_view(
                feature_view=feature_view,
                version=config.version,
                block=True,
                overwrite=self.overwrite
            )

            # Record state only after a successful registration, so a failed
            # or re-registered view never leaves stale transforms behind for
            # check_feature_drift to replay.
            self.feature_views[config.name] = registered_view
            self.feature_configs[config.name] = config
            if transforms:
                self.feature_transforms[config.name] = transforms
            else:
                self.feature_transforms.pop(config.name, None)

            self._update_dependencies(config)

            # Baseline statistics, later used by check_feature_drift
            builder = FeatureViewBuilder(config, df, entity)
            stats = {
                name: monitor.compute_stats(df, name)
                for name, monitor in builder.monitors.items()
            }
            self.feature_stats[config.name] = stats

            for cb in self.callbacks:
                cb.on_feature_view_create(config.name, df, stats)

            return registered_view

        except Exception as e:
            error_msg = f"Error creating feature view {config.name}: {str(e)}"
            self._notify_error(error_msg)
            raise FeatureViewError(error_msg) from e

    def check_feature_drift(
        self,
        feature_view_name: str,
        new_data: DataFrame,
        drift_threshold: float = 0.1
    ) -> Dict[str, Dict[str, float]]:
        """Check for feature drift in new data

        Replays the transforms stored for ``feature_view_name`` on
        ``new_data``. For each feature with baseline stats, it computes current
        stats and compares them with the baseline.

        Args:
            feature_view_name: Name of a view added via ``add_feature_view``.
            new_data: Raw data to compare against the stored baseline.
            drift_threshold: A feature is reported when the absolute value of
                any of its drift metrics exceeds this (default 0.1).

        Returns:
            Feature name -> drift metrics, only for features over the
            threshold. Features whose individual check fails are logged and
            skipped.

        Raises:
            FeatureViewError: If baseline stats or config are missing, or the
                check fails outside the per-feature loop.
        """
        drift_results = {}

        try:
            stored_stats = self.feature_stats.get(feature_view_name)
            if not stored_stats:
                raise FeatureViewError(
                    f"No baseline stats for feature view {feature_view_name}"
                )

            config = self.feature_configs.get(feature_view_name)
            if not config:
                raise FeatureViewError(f"Feature view config {feature_view_name} not found")

            # Derived features must be recomputed on the new data
            transforms = self.feature_transforms.get(feature_view_name, [])
            if transforms:
                logger.info(f"Applying {len(transforms)} transforms to new data")
                new_data_with_features = apply_transforms(new_data, transforms)
            else:
                logger.info("No transforms to apply")
                new_data_with_features = new_data

            for feature_name, baseline_stats in stored_stats.items():
                try:
                    # Reuse the declared feature config, else a throwaway one
                    feature_config = config.features.get(feature_name)
                    monitor = FeatureMonitor(
                        feature_config or FeatureConfig(
                            name=feature_name,
                            description=f"Temporary monitor for {feature_name}"
                        ),
                        collect_detailed_stats=True
                    )

                    current_stats = monitor.compute_stats(new_data_with_features, feature_name)
                    monitor.set_baseline(baseline_stats)
                    drift_metrics = monitor.detect_drift(current_stats)

                    if any(abs(v) > drift_threshold for v in drift_metrics.values()):
                        drift_results[feature_name] = drift_metrics
                        for cb in self.callbacks:
                            cb.on_drift_detected(
                                feature_view_name, feature_name, drift_metrics
                            )

                except Exception as e:
                    # Best effort: one bad feature must not abort the others
                    logger.warning(f"Skipping drift detection for {feature_name}: {str(e)}")
                    continue

            return drift_results

        except Exception as e:
            error_msg = f"Error checking drift for {feature_view_name}: {str(e)}"
            self._notify_error(error_msg)
            raise FeatureViewError(error_msg) from e

    def _update_dependencies(self, config: FeatureViewConfig) -> None:
        """Add a feature view and its features to the dependency graph.

        Adds node ``config.name``, one node ``"<view>.<feature>"`` per feature
        with an edge from the view, and an edge from each feature node to
        every entry of ``feature_config.dependencies`` (used verbatim).
        Failures are logged and swallowed (best effort).
        """
        try:
            self.dependencies.add_node(config.name)

            for feature_name, feature_config in config.features.items():
                feature_node = f"{config.name}.{feature_name}"
                self.dependencies.add_node(feature_node)
                self.dependencies.add_edge(config.name, feature_node)

                if feature_config.dependencies:
                    for dep in feature_config.dependencies:
                        self.dependencies.add_edge(feature_node, dep)

            logger.debug(
                f"Updated dependencies for {config.name}: "
                f"{list(self.dependencies.edges)}"
            )
        except Exception as e:
            logger.error(f"Error updating dependencies: {str(e)}")

    def get_feature_dependencies(self, feature_view_name: str) -> Set[str]:
        """Get dependencies for a feature view

        Args:
            feature_view_name: Name of the feature view

        Returns:
            The view-name prefixes (text before the first '.') of every
            dotted node reachable from ``feature_view_name`` in the graph.
            Undotted dependency names are dropped.
            NOTE(review): the view's own name is included, because its feature
            nodes are named "<view>.<feature>"; confirm this is intended.

        Raises:
            FeatureViewError: If the graph lookup fails (e.g. unknown view).
        """
        try:
            deps = nx.descendants(self.dependencies, feature_view_name)

            feature_deps = {
                dep.split('.')[0] for dep in deps
                if '.' in dep  # Only include actual feature views
            }

            logger.info(
                f"Dependencies for {feature_view_name}: {feature_deps}"
            )
            return feature_deps

        except Exception as e:
            raise FeatureViewError(
                f"Error getting dependencies for {feature_view_name}: {str(e)}"
            ) from e

    def _validate_schema(self, df: DataFrame) -> None:
        """Validate DataFrame schema without execution"""
        if not df.schema:
            raise ValidationError("DataFrame has no schema")

    @staticmethod
    def _parse_feature_view_ref(ref: str) -> tuple[str, str]:
        """Split a ``"name/version"`` reference into ``(name, version)``.

        Raises:
            ValueError: Unless the reference has exactly one '/' with a
                non-empty part on each side.
        """
        parts = ref.split('/')
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"Feature view reference must be 'name/version', got {ref!r}"
            )
        name, version = parts
        return name, version

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Wrap ``name`` in double quotes unless it is already quoted."""
        if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
            return name
        return f'"{name}"'

    def _resolve_feature_view(
        self, fv: Union[str, FeatureView, FeatureViewConfig]
    ) -> FeatureView:
        """Turn a view object, config or "name/version" string into a view."""
        if isinstance(fv, FeatureView):
            return fv
        if isinstance(fv, FeatureViewConfig):
            return self.feature_store.get_feature_view(
                fv.name, version=fv.version
            )
        if isinstance(fv, str):
            name, version = self._parse_feature_view_ref(fv)
            return self.feature_store.get_feature_view(name, version)
        raise ValueError(f"Unsupported feature view type: {type(fv)}")

    def get_features(
        self,
        spine_df: DataFrame,
        feature_views: List[Union[str, FeatureView, FeatureViewConfig]],
        label_cols: Optional[List[str]] = None,
        dataset_name: Optional[str] = None,
        spine_timestamp_col: Optional[str] = None,
        **kwargs
    ) -> DataFrame:
        """Get features for training or inference

        Args:
            spine_df: Spine DataFrame (entity keys, optional timestamp/labels).
            feature_views: ``FeatureView`` objects, ``FeatureViewConfig``
                objects (looked up by name/version) or "name/version" strings.
            label_cols: Label columns; each is double-quoted unless it is
                already quoted.
            dataset_name: Defaults to ``DATASET_<UTC timestamp>_<8 hex chars>``.
            spine_timestamp_col: Timestamp column, quoted like ``label_cols``.
            **kwargs: Forwarded to ``FeatureStore.generate_dataset``.

        Returns:
            The generated dataset read back as a Snowpark DataFrame.

        Raises:
            FeatureStoreException: Wrapping any failure; callbacks get
                ``on_error`` first.
        """
        try:
            logger.info(f"Spine DataFrame columns: {spine_df.columns}")
            logger.info(f"Spine DataFrame schema: {spine_df.schema}")

            views = [self._resolve_feature_view(fv) for fv in feature_views]

            if dataset_name is None:
                timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
                unique_id = str(uuid.uuid4())[:8]
                dataset_name = f"DATASET_{timestamp}_{unique_id}"

            # Quote identifiers exactly once (already-quoted names are kept)
            if label_cols:
                label_cols = [self._quote_identifier(col) for col in label_cols]
            if spine_timestamp_col:
                spine_timestamp_col = self._quote_identifier(spine_timestamp_col)

            logger.info(f"Generating dataset with name: {dataset_name}")
            logger.info(f"Label columns: {label_cols}")
            logger.info(f"Timestamp column: {spine_timestamp_col}")

            dataset = self.feature_store.generate_dataset(
                name=dataset_name,
                spine_df=spine_df,
                features=views,
                spine_label_cols=label_cols,
                spine_timestamp_col=spine_timestamp_col,
                **kwargs
            )

            return dataset.read.to_snowpark_dataframe()

        except Exception as e:
            error_msg = f"Error generating dataset: {str(e)}"
            self._notify_error(error_msg)
            raise FeatureStoreException(error_msg) from e



In [6]:
#| export
@contextmanager
def feature_store_session(
    connection: SnowflakeConnection, 
    *,  # Force keyword arguments
    schema_name: Optional[str] = None,
    metrics_path: Optional[Union[str, Path]] = None,
    cleanup: bool = True
):
    """Yield a ``FeatureStoreManager`` bound to a dedicated schema.

    The schema is ``schema_name`` or, if not given, an auto-generated
    ``FEATURE_STORE_<UTC YYYYmmdd_HHMMSS>_<8 hex chars>``. It is created if
    missing, made current on the session, and stored on
    ``connection.schema``. The yielded manager uses ``overwrite=True``.

    Args:
        connection: Snowflake connection (mutated: ``schema`` attribute).
        schema_name: Optional schema name (keyword only).
        metrics_path: Optional path to save metrics (keyword only).
        cleanup: When true, on exit the schema is dropped with CASCADE and the
            original schema is restored; failures there are only logged.
            NOTE(review): this also drops a pre-existing schema passed in via
            ``schema_name``. With ``cleanup=False`` the connection stays on the
            new schema.

    Yields:
        FeatureStoreManager: Manager operating in the session schema.
    """
    def run_sql(statement: str) -> None:
        connection.session.sql(statement).collect()

    if schema_name:
        schema = schema_name
    else:
        stamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
        schema = f"FEATURE_STORE_{stamp}_{uuid.uuid4().hex[:8]}"
    previous_schema = connection.schema

    def teardown() -> None:
        # Drop and restore are attempted independently so one failure does
        # not prevent the other.
        try:
            run_sql(f"DROP SCHEMA IF EXISTS {connection.database}.{schema} CASCADE")
            logger.info(f"Cleaned up schema {schema}")
        except Exception as exc:
            logger.error(f"Cleanup failed: {str(exc)}")
        try:
            run_sql(f"USE SCHEMA {connection.database}.{previous_schema}")
            connection.schema = previous_schema
        except Exception as exc:
            logger.error(f"Failed to restore original schema: {str(exc)}")

    try:
        run_sql(f"CREATE SCHEMA IF NOT EXISTS {connection.database}.{schema}")
        run_sql(f"USE SCHEMA {connection.database}.{schema}")
        connection.schema = schema
        yield FeatureStoreManager(
            connection=connection,
            metrics_path=metrics_path,
            overwrite=True
        )
    finally:
        if cleanup:
            teardown()


In [7]:
#| hide
# Regenerate the library module(s) from this notebook's `#| export` cells.
import nbdev
nbdev.nbdev_export()

In [8]:
#| eval: false
# Demo cell (kept out of automated notebook runs by `eval: false`): an
# end-to-end walk-through of the manager API. Requires a working Snowflake
# connection from get_connection().
from snowflake_feature_store.connection import get_connection
from snowflake_feature_store.config import (
    FeatureViewConfig, FeatureConfig, FeatureValidationConfig, RefreshConfig
)
from snowflake_feature_store.transforms import TransformConfig, moving_agg, fill_na
from snowflake_feature_store.manager import feature_store_session
import snowflake.snowpark.functions as F
from datetime import datetime  # NOTE(review): unused in this cell
import tempfile
from pathlib import Path

# Create a temporary directory for metrics (JSON files from MetricsCallback)
metrics_dir = Path(tempfile.mkdtemp()) / "feature_store_metrics"

# Get connection
conn = get_connection()

# Use the feature store session context manager. With the default
# cleanup=True the session schema is dropped with CASCADE on exit. The TEMP_*
# tables below are created unqualified, i.e. in that current schema, so they
# are removed even if the explicit DROP TABLEs at the end are never reached.
with feature_store_session(conn, metrics_path=str(metrics_dir)) as manager:
    # Create sample data
    # First, create a regular table to store our data
    conn.session.sql("""
        CREATE OR REPLACE TABLE TEMP_CUSTOMER_DATA (
            CUSTOMER_ID STRING,
            DATE DATE,
            AMOUNT FLOAT,
            TRANSACTIONS INT,
            SESSION_LENGTH FLOAT
        )
    """).collect()
    
    # Insert the data using SQL (NULL session lengths exercise fill_na below)
    conn.session.sql("""
        INSERT INTO TEMP_CUSTOMER_DATA VALUES
        ('C1', '2024-01-01', 100.0, 2, NULL),
        ('C1', '2024-01-02', 150.0, 3, 30.5),
        ('C2', '2024-01-01', 75.0, 1, NULL),
        ('C2', '2024-01-02', 200.0, 4, 45.2)
    """).collect()
    
    # Create feature DataFrame from the table
    df = conn.session.table("TEMP_CUSTOMER_DATA")
    
    print("\nInitial DataFrame Schema:")
    for field in df.schema.fields:
        print(f"{field.name}: {field.datatype}")
    
    print("\nSample Data:")
    df.show()
    
    # 1. Add Customer Entity
    manager.add_entity(
        name="CUSTOMER",
        join_keys=["CUSTOMER_ID"],
        description="Customer entity for retail domain"
    )
    
    # 2. Create Feature Configurations with dependencies. The SUM_/AVG_
    # names are expected to match the columns produced by moving_agg below.
    feature_configs = {
        "AMOUNT": FeatureConfig(
            name="AMOUNT",
            description="Transaction amount",
            validation=FeatureValidationConfig(
                null_threshold=0.1,
                range_check=True,
                min_value=0
            ),
            dependencies=[]  # Base feature, no dependencies
        ),
        "TRANSACTIONS": FeatureConfig(
            name="TRANSACTIONS",
            description="Number of transactions",
            validation=FeatureValidationConfig(
                null_threshold=0.05,
                range_check=True,
                min_value=0
            ),
            dependencies=[]  # Base feature, no dependencies
        ),
        "SESSION_LENGTH": FeatureConfig(
            name="SESSION_LENGTH",
            description="Session length in minutes",
            validation=FeatureValidationConfig(
                null_threshold=0.3,
                range_check=True,
                min_value=0
            ),
            dependencies=[]  # Base feature, no dependencies
        ),
        "SUM_AMOUNT_2": FeatureConfig(
            name="SUM_AMOUNT_2",
            description="2-day rolling sum of amount",
            validation=FeatureValidationConfig(
                null_threshold=0.05,
                range_check=True,
                min_value=0
            ),
            dependencies=["AMOUNT"]  # Depends on AMOUNT
        ),
        "AVG_AMOUNT_2": FeatureConfig(
            name="AVG_AMOUNT_2",
            description="2-day rolling average of amount",
            validation=FeatureValidationConfig(
                null_threshold=0.05,
                range_check=True,
                min_value=0
            ),
            dependencies=["AMOUNT"]  # Depends on AMOUNT
        )
    }

    
    # 3. Create Feature View Config
    config = FeatureViewConfig(
        name="customer_behavior",
        domain="RETAIL",
        entity="CUSTOMER",
        feature_type="BEHAVIOR",
        refresh=RefreshConfig(frequency="1 day"),
        features=feature_configs,
        description="Customer behavior features",
        timestamp_col="DATE"
    )
    
    # 4. Create transforms (applied in list order)
    transform_config = TransformConfig(
        name="amount_metrics",
        null_threshold=0.05,
        expected_types=['DECIMAL', 'DOUBLE', 'NUMBER']
    )
    
    transforms = [
        # Fill NA values in session length
        fill_na(['SESSION_LENGTH'], fill_value=0),
        
        # Calculate moving aggregations for amount
        moving_agg(
            cols='AMOUNT',
            window_sizes=[2],  # 2-day window
            agg_funcs=['SUM', 'AVG'],
            partition_by=['CUSTOMER_ID'],
            order_by=['DATE'],
            config=transform_config
        )
    ]
    
    # 5. Create Feature View (also records baseline stats and the transforms)
    feature_view = manager.add_feature_view(
        config=config,
        df=df,
        entity_name="CUSTOMER",
        transforms=transforms,
        collect_stats=True
    )
    
    # 6. Check for feature drift with new data
    # Create new table for drift detection
    conn.session.sql("""
        CREATE OR REPLACE TABLE TEMP_NEW_DATA (
            CUSTOMER_ID STRING,
            DATE DATE,
            AMOUNT FLOAT,
            TRANSACTIONS INT,
            SESSION_LENGTH FLOAT
        )
    """).collect()
    
    conn.session.sql("""
        INSERT INTO TEMP_NEW_DATA VALUES
        ('C1', '2024-01-03', 300.0, 5, 60.0),
        ('C2', '2024-01-03', 80.0, 1, 15.5)
    """).collect()
    
    new_df = conn.session.table("TEMP_NEW_DATA")
    
    # The manager replays the stored transforms on new_df before comparing
    # against the baseline; only features over the drift threshold come back.
    drift_results = manager.check_feature_drift(config.name, new_df)
    print("\nDrift Detection Results:")
    for feature, metrics in drift_results.items():
        print(f"\n{feature}:")
        for metric, value in metrics.items():
            print(f"  {metric}: {value:.3f}")
    
    # 7. Get Feature Dependencies (view names derived from the
    # "<view>.<feature>" nodes of the manager's dependency graph)
    deps = manager.get_feature_dependencies(config.name)
    print("\nFeature Dependencies:", deps)
    
    # 8. Generate Training Dataset
    print("\nEntity join keys:")
    for entity in feature_view.entities:
        print(f"Entity {entity.name}: {entity.join_keys}")

    # Create spine DataFrame with explicit quoting
    spine_df = df.select([
        F.col('CUSTOMER_ID').alias('"CUSTOMER_ID"'),
        F.col('DATE').alias('"DATE"')
    ])

    print("\nSpine DataFrame columns:")
    print(spine_df.columns)
    print("\nSpine DataFrame schema:")
    for field in spine_df.schema.fields:
        print(f"{field.name}: {field.datatype}")

    # get_features wraps the label and timestamp column names in double quotes.
    # NOTE(review): TRANSACTIONS is not selected into spine_df; it is presumably
    # resolved from the feature view's columns. Confirm that generate_dataset
    # accepts label columns that are absent from the spine.
    training_data = manager.get_features(
        spine_df=spine_df,
        feature_views=[config],
        label_cols=["TRANSACTIONS"],
        spine_timestamp_col="DATE"
    )
    print("\nTraining Data Sample:")
    training_data.show(2)

    # Baseline statistics recorded by add_feature_view in step 5
    print("\nFeature Statistics:")
    for feature_name, stats in manager.feature_stats[config.name].items():
        print(f"\n{feature_name}:")
        print(stats)
    
    print("\nTraining Data Schema:")
    for field in training_data.schema.fields:
        print(f"{field.name}: {field.datatype}")
    
    # Cleanup temporary tables
    conn.session.sql("DROP TABLE IF EXISTS TEMP_CUSTOMER_DATA").collect()
    conn.session.sql("DROP TABLE IF EXISTS TEMP_NEW_DATA").collect()


2025-02-17 18:13:57,583 - snowflake_feature_store - INFO - No active session found, creating new connection from environment
2025-02-17 18:13:58,557 - snowflake_feature_store - INFO - Initialized connection to "DATASCIENCE"."FEATURE_STORE_DEMO"
2025-02-17 18:14:00,880 - snowflake_feature_store - INFO - FeatureStoreManager initialized

Initial DataFrame Schema:
CUSTOMER_ID: StringType()
DATE: DateType()
AMOUNT: DoubleType()
TRANSACTIONS: LongType()
SESSION_LENGTH: DoubleType()

Sample Data:
-----------------------------------------------------------------------------
|"CUSTOMER_ID"  |"DATE"      |"AMOUNT"  |"TRANSACTIONS"  |"SESSION_LENGTH"  |
-----------------------------------------------------------------------------
|C1             |2024-01-01  |100.0     |2               |NULL              |
|C1             |2024-01-02  |150.0     |3               |30.5              |
|C2             |2024-01-01  |75.0      |1               |NULL              |
|C2             |2024-01-02  |200.0  