# Transforms

> Composable Snowpark DataFrame transformations for feature engineering: window aggregations, NA filling, and date differences

In [18]:
#| default_exp transforms

In [19]:
#| export
from __future__ import annotations
from typing import Union, List, Callable, Optional, Protocol, Dict
from dataclasses import dataclass
from fastcore.basics import listify
import snowflake.snowpark.functions as F
from snowflake.snowpark import DataFrame, Window

In [20]:
#| export
class Transform(Protocol):
    """Structural type for a feature transformation.

    Any callable that takes a `DataFrame` and returns a `DataFrame`
    matches this protocol.
    """

    def __call__(self, df: DataFrame) -> DataFrame:
        ...

In [21]:
#| hide

from snowflake_feature_store.connection import get_connection, ConnectionConfig

# Open a connection using get_connection()'s no-argument (automatic) configuration.
conn = get_connection()
# Module-level session handle reused by the test cells in this notebook
# (they call session.create_dataframe to build their fixtures).
session = conn.session


In [22]:
#| export
@dataclass
class WindowSpec:
    """Describes the partitioning, ordering and row frame of a window.

    `partition_by` and `order_by` accept either a single column name or a
    list of names; single names are wrapped into one-element lists after
    construction. `window_size` is the number of rows in the frame, counting
    the current row; when it is falsy no explicit frame is set.
    """
    # Column name(s) to partition by
    partition_by: Optional[Union[str, List[str]]] = None
    # Column name(s) to order by
    order_by: Optional[Union[str, List[str]]] = None
    # Number of rows in the frame (current row included)
    window_size: Optional[int] = None

    def __post_init__(self):
        "Wrap bare string column names into single-element lists"
        for attr in ('partition_by', 'order_by'):
            value = getattr(self, attr)
            if isinstance(value, str):
                setattr(self, attr, [value])

    def to_window(self) -> Window:
        "Build the Snowpark window described by this spec"
        partition_cols = self.partition_by or []
        order_cols = self.order_by or []
        window = Window.partition_by(partition_cols).order_by(order_cols)

        # No size given: return the window without an explicit row frame
        if not self.window_size:
            return window

        # Frame covers the (window_size - 1) preceding rows plus the current row
        first_row_offset = -(self.window_size - 1)
        return window.rows_between(first_row_offset, Window.CURRENT_ROW)


In [23]:
#| export
def window_agg(
    agg_cols: Dict[str, List[str]],
    window_spec: WindowSpec
) -> Transform:
    """Build a transform that adds windowed aggregate columns

    Each (column, aggregation) pair produces one output column named
    ``<AGG>_<COLUMN>`` (both upper-cased). An existing column with that
    name is replaced.

    Args:
        agg_cols: Maps each source column to the aggregation function names
            to apply; each name is lower-cased and looked up on
            `snowflake.snowpark.functions`
        window_spec: Window over which every aggregation is evaluated

    Returns:
        Transform function

    Example:
        >>> spec = WindowSpec(partition_by='customer_id', order_by='date')
        >>> aggs = {'amount': ['SUM', 'AVG']}
        >>> window_agg(aggs, spec)(df)
    """
    def _inner(df: DataFrame) -> DataFrame:
        # A single window object is shared by all aggregations
        frame = window_spec.to_window()

        # Flatten the mapping into (source column, aggregation name) pairs
        pairs = (
            (source, agg_name)
            for source, agg_names in agg_cols.items()
            for agg_name in agg_names
        )
        result = df
        for source, agg_name in pairs:
            aggregate = getattr(F, agg_name.lower())
            # Upper-case the output name (Snowflake typically uppercases column names)
            target = f"{agg_name.upper()}_{source.upper()}"
            windowed = aggregate(F.col(source)).over(frame)
            result = result.with_column(target, windowed)
        return result
    return _inner

In [24]:
#| hide
def test_window_agg():
    "Check rolling SUM/AVG per customer over a two-row window"
    # Fixture: two consecutive days for C1, a single day for C2
    rows = [
        ['C1', '2024-01-01', 100],
        ['C1', '2024-01-02', 200],
        ['C2', '2024-01-01', 150]
    ]
    source = session.create_dataframe(rows, ['customer_id', 'date', 'amount'])

    # Two-row frame: the previous row plus the current one
    two_row_window = WindowSpec(partition_by='customer_id', order_by='date', window_size=2)
    transform = window_agg({'amount': ['SUM', 'AVG']}, two_row_window)
    result = transform(source)

    # Print the frame so a failure is easy to diagnose
    print("\nResult with window aggregations:")
    result.show()

    # Index rows by (customer, date) so the checks do not depend on row order
    by_key = {}
    for row in result.collect():
        by_key[(row['CUSTOMER_ID'], row['DATE'])] = row

    # Sums for C1
    c1_first = by_key[('C1', '2024-01-01')]
    c1_second = by_key[('C1', '2024-01-02')]
    assert c1_first['SUM_AMOUNT'] == 100, "C1's first day should sum to 100"
    assert c1_second['SUM_AMOUNT'] == 300, "C1's second day should sum to 300"

    # Sum for C2
    c2_only = by_key[('C2', '2024-01-01')]
    assert c2_only['SUM_AMOUNT'] == 150, "C2's only day should sum to 150"

    # Averages for every row
    assert c1_first['AVG_AMOUNT'] == 100, "C1's first day average should be 100"
    assert c1_second['AVG_AMOUNT'] == 150, "C1's second day average should be 150"
    assert c2_only['AVG_AMOUNT'] == 150, "C2's average should be 150"

test_window_agg()


Result with window aggregations:
-----------------------------------------------------------------------
|"CUSTOMER_ID"  |"DATE"      |"AMOUNT"  |"SUM_AMOUNT"  |"AVG_AMOUNT"  |
-----------------------------------------------------------------------
|C2             |2024-01-01  |150       |150           |150.000       |
|C1             |2024-01-01  |100       |100           |100.000       |
|C1             |2024-01-02  |200       |300           |150.000       |
-----------------------------------------------------------------------



In [43]:
# | export
def fill_na(cols: Union[str, List[str]], fill_value: Union[int, float, str] = 0) -> Transform:
    """Fill NA values in specified columns, matching the column type

    Column names are matched case-insensitively against the DataFrame schema.
    The fill value is converted per column before being handed to
    `df.na.fill`: `int` for IntegerType/LongType columns, `float` for
    FloatType/DoubleType columns, and `str` for every other column type.

    Args:
        cols: Column(s) to fill NA values in
        fill_value: Value to use for filling NAs (will be cast to match column type)

    Returns:
        Transform function

    Raises:
        ValueError: (when the transform is applied) if a requested column is
            not present in the DataFrame

    Example:
        >>> df = session.create_dataframe([[1, None], [2, 3]], ['A', 'B'])
        >>> fill_na('B', 0)(df).collect()
        [[1, 0], [2, 3]]
    """
    cols = listify(cols)
    def _inner(df: DataFrame) -> DataFrame:
        # Function-scope import so this block needs no extra module-level import
        from snowflake.snowpark.types import IntegerType, LongType, FloatType, DoubleType

        # Read the schema once; the first field wins when names differ only by case
        fields_by_upper = {}
        for field in df.schema.fields:
            fields_by_upper.setdefault(field.name.upper(), field)

        replacements = {}
        for col in cols:
            field = fields_by_upper.get(col.upper())
            if field is None:
                raise ValueError(f"Column {col} not found in DataFrame")

            # Decide by type, not by the type's repr: a repr-prefix check misses
            # FloatType, whose fill value then became a str that na.fill skips
            col_type = field.datatype
            if isinstance(col_type, (IntegerType, LongType)):
                typed_value = int(fill_value)
            elif isinstance(col_type, (FloatType, DoubleType)):
                typed_value = float(fill_value)
            else:
                typed_value = str(fill_value)
            replacements[field.name] = typed_value

        # One na.fill call covers every requested column
        return df.na.fill(replacements) if replacements else df
    return _inner


In [44]:
#| hide
def test_fill_na():
    "fill_na should resolve the target column regardless of name case"
    from snowflake.snowpark.types import StructType, StructField, LongType

    # Two nullable integer columns; COLUMN_B is missing in the first row
    fields = [StructField(name, LongType()) for name in ("COLUMN_A", "COLUMN_B")]
    frame = session.create_dataframe([[1, None], [2, 3]], StructType(fields))

    # Same column requested in upper and lower case
    upper_rows, lower_rows = [
        fill_na(requested, 0)(frame).collect()
        for requested in ('COLUMN_B', 'column_b')
    ]

    assert upper_rows[0].COLUMN_B == 0
    assert lower_rows[0].COLUMN_B == 0

# Run the fill_na test
test_fill_na()

In [28]:

#| export

def date_diff(
    col: str,
    new_col: str,
    reference_date: Optional[str] = None,
    date_part: str = 'day'
) -> Transform:
    """Calculate date difference between column and reference

    The source column is matched case-insensitively against `df.columns`.
    The result is written to a column named `new_col.upper()`.

    Args:
        col: Date column to calculate difference from
        new_col: Name for the new difference column
        reference_date: Reference date string (default: current_date)
        date_part: Part to calculate difference in ('day', 'month', etc.)

    Returns:
        Transform function

    Raises:
        ValueError: (when the transform is applied) if `col` is not present
            in the DataFrame

    Example:
        >>> df = session.create_dataframe([['2024-01-01']], ['date'])
        >>> date_diff('date', 'days_ago', '2024-02-01')(df).collect()
        [['2024-01-01', 31]]
    """
    def _inner(df: DataFrame) -> DataFrame:
        # Resolve the column case-insensitively; fail with a clear message
        # instead of an opaque "list index out of range"
        wanted = col.upper()
        col_actual = next((c for c in df.columns if c.upper() == wanted), None)
        if col_actual is None:
            raise ValueError(f"Column {col} not found in DataFrame")

        # A string reference is parsed as a date; anything else means "today"
        if isinstance(reference_date, str):
            reference = F.to_date(F.lit(reference_date))
        else:
            reference = F.current_date()

        return df.with_column(
            new_col.upper(),  # Standardize output column name
            F.datediff(date_part, F.col(col_actual), reference)
        )
    return _inner

In [29]:
#| hide
def test_date_diff():
    "date_diff should return the day count up to the reference date"
    from snowflake.snowpark.types import StructType, StructField, DateType

    # Single DATE column declared with an explicit schema
    date_schema = StructType([StructField("DATE", DateType())])
    frame = session.create_dataframe([['2024-01-01'], ['2024-02-01']], date_schema)

    with_diff = date_diff('DATE', 'DAYS_AGO', '2024-03-01')(frame)

    # Print schema and data so a failure is easy to diagnose
    print("Transformed DataFrame Schema:")
    print(with_diff.schema)
    print("\nTransformed DataFrame Data:")
    with_diff.show()

    rows = with_diff.collect()

    # Locate the output column without assuming its exact case
    matching_keys = [key for key in rows[0].asDict().keys() if key.upper() == 'DAYS_AGO']
    diff_key = matching_keys[0]

    # 2024-01-01 -> 2024-03-01 is 60 days; 2024-02-01 -> 2024-03-01 is 29 days
    assert rows[0][diff_key] == 60
    assert rows[1][diff_key] == 29

# Run the date_diff test
test_date_diff()

Transformed DataFrame Schema:
StructType([StructField('DATE', DateType(), nullable=True), StructField('DAYS_AGO', LongType(), nullable=True)])

Transformed DataFrame Data:
---------------------------
|"DATE"      |"DAYS_AGO"  |
---------------------------
|2024-01-01  |60          |
|2024-02-01  |29          |
---------------------------



In [31]:
#| export
def moving_agg(
    cols: Union[str, List[str]],
    window_sizes: List[Optional[int]],
    agg_funcs: List[str] = ['SUM', 'AVG'],
    partition_by: Optional[List[str]] = None,
    order_by: Optional[List[str]] = None
) -> Transform:
    """Calculate moving window aggregations

    Output naming:
        * One distinct window size: columns are named ``<AGG>_<COLUMN>``.
        * Several distinct window sizes: columns are named
          ``<AGG>_<COLUMN>_<SIZE>`` so each size keeps its own result
          (otherwise every size would write to the same column and only the
          last one would survive). A falsy size (``None``) has no frame and
          keeps the unsuffixed name.

    Args:
        cols: Columns to aggregate
        window_sizes: List of window sizes (rows, current row included);
            ``None`` means no explicit frame
        agg_funcs: List of aggregation functions
        partition_by: Columns to partition by
        order_by: Columns to order by

    Returns:
        Transform function

    Example:
        >>> moving_agg('amount', [3, 7], ['SUM'], ['customer_id'], ['date'])(df)
    """
    cols = listify(cols)
    def _inner(df: DataFrame) -> DataFrame:
        # Suffix names only when they would otherwise collide
        needs_suffix = len(set(window_sizes)) > 1
        for col in cols:
            for size in window_sizes:
                spec = WindowSpec(
                    partition_by=partition_by,
                    order_by=order_by,
                    window_size=size
                )
                # WindowSpec.to_window sets no frame for a falsy size, so there
                # is no size to put in the name; reuse window_agg's naming
                if not needs_suffix or not size:
                    df = window_agg({col: agg_funcs}, spec)(df)
                    continue

                window = spec.to_window()
                for agg in agg_funcs:
                    agg_func = getattr(F, agg.lower())
                    df = df.with_column(
                        f"{agg.upper()}_{col.upper()}_{size}",
                        agg_func(F.col(col)).over(window)
                    )
        return df
    return _inner


In [39]:
#| export
def cumulative_agg(
    cols: Union[str, List[str]],
    agg_funcs: List[str] = ['SUM'],
    partition_by: Optional[List[str]] = None,
    order_by: Optional[List[str]] = None
) -> Transform:
    """Build a transform for cumulative aggregations

    Delegates to `moving_agg` with a single window size of ``None``, i.e. a
    window without an explicit row frame.

    Args:
        cols: Columns to aggregate
        agg_funcs: List of aggregation functions
        partition_by: Columns to partition by
        order_by: Columns to order by

    Returns:
        Transform function

    Example:
        >>> cumulative_agg('amount', ['SUM'], ['customer_id'], ['date'])(df)
    """
    # A lone None size stands for "unbounded"
    unbounded_only = [None]
    return moving_agg(
        cols,
        unbounded_only,
        agg_funcs,
        partition_by=partition_by,
        order_by=order_by,
    )

In [40]:
#| export
def apply_transforms(df: DataFrame, transforms: List[Transform]) -> DataFrame:
    """Run a DataFrame through a sequence of transforms, in order

    Each transform receives the output of the previous one; an empty list
    returns the input unchanged.

    Args:
        df: Input DataFrame
        transforms: Transform functions, applied first to last

    Returns:
        Transformed DataFrame

    Example:
        >>> transforms = [
        ...     fill_na(['score']),
        ...     date_diff('date', 'days_ago')
        ... ]
        >>> apply_transforms(df, transforms)
    """
    current = df
    for step in transforms:
        current = step(current)
    return current


In [41]:
#| hide
def test_apply_transforms():
    "apply_transforms should run fill_na and date_diff in sequence"
    # Fixture: the first row has a missing score
    rows = [
        ['C1', '2024-01-01', 100, None],
        ['C1', '2024-01-02', 200, 3.5]
    ]
    source = session.create_dataframe(rows, ['customer_id', 'date', 'amount', 'score'])

    # Pipeline under test: fill missing scores, then add the day difference
    pipeline = [
        fill_na(['score']),
        date_diff('date', 'days_ago', '2024-02-01')
    ]

    output_rows = apply_transforms(source, pipeline).collect()

    first_row = output_rows[0]
    assert first_row['SCORE'] == 0  # missing score replaced by the default fill value
    assert first_row['DAYS_AGO'] == 31  # 2024-01-01 -> 2024-02-01

test_apply_transforms()


In [45]:
#| hide
# Run nbdev's export step (nbdev.nbdev_export) from inside the notebook.
import nbdev; nbdev.nbdev_export()

In [38]:
#| eval: false
# Example usage of the exported module (this cell is marked `eval: false`)
from snowflake_feature_store.transforms import *

# Sample frame: two rows for C1, one for C2; two of the scores are missing
df = session.create_dataframe([
    ['C1', '2024-01-01', 100, None],
    ['C1', '2024-01-02', 200, 3.5],
    ['C2', '2024-01-01', 150, None]
], ['customer_id', 'date', 'amount', 'score'])

# Apply the transforms in list order.
# NOTE: moving_agg and cumulative_agg both name their outputs <AGG>_<COLUMN>,
# so cumulative_agg's SUM_AMOUNT replaces the SUM_AMOUNT written by moving_agg;
# the printed result therefore shows a single SUM_AMOUNT column.
df_transformed = apply_transforms(df, [
    fill_na(['score']),
    date_diff('date', 'days_ago', '2024-02-01'),
    moving_agg(
        'amount',
        window_sizes=[3],  # one window size keeps the example small
        partition_by=['customer_id'],
        order_by=['date']
    ),
    cumulative_agg(
        'amount',
        partition_by=['customer_id'],
        order_by=['date']
    )
])

# Print the resulting frame
print("\nTransformed DataFrame:")
df_transformed.show()




Transformed DataFrame:
----------------------------------------------------------------------------------------------
|"CUSTOMER_ID"  |"DATE"      |"AMOUNT"  |"SCORE"  |"DAYS_AGO"  |"AVG_AMOUNT"  |"SUM_AMOUNT"  |
----------------------------------------------------------------------------------------------
|C1             |2024-01-01  |100       |0.0      |31          |100.000       |100           |
|C1             |2024-01-02  |200       |3.5      |30          |150.000       |300           |
|C2             |2024-01-01  |150       |0.0      |31          |150.000       |150           |
----------------------------------------------------------------------------------------------

