Skip to content

Commit

Permalink
bround function added with tests (#16)
Browse files Browse the repository at this point in the history
* bround function added with tests

* Function f.bround included in the Readme.md
  • Loading branch information
jroseroMN committed Jan 25, 2023
1 parent f44efbd commit 7a0009a
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ df.group_by("ID").applyInPandas(
| functions.sort_array | sorts the input array in ascending or descending order according to the natural ordering of the array elements. Null elements will be placed at the beginning of the returned array in ascending order or at the end of the returned array in descending order
| functions.map_values | Returns an unordered array containing the values of the map. |
| functions.struct | Returns an object built with the given columns |
| functions.bround | Receives a numeric column and rounds it to `scale` decimal places using the HALF_EVEN rounding mode, often called "banker's rounding". This means that when the number is exactly halfway between two candidates, it is rounded to the even one. |


### Examples:
Expand Down
13 changes: 12 additions & 1 deletion snowpark_extensions/functions_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ def _struct(*cols):
new_cols.append(c)
return object_construct_keep_null(*new_cols)

def _bround(col: snowflake.snowpark.Column, scale: int = 0):
    """Round ``col`` to ``scale`` decimal places using HALF_EVEN rounding.

    The value is multiplied by ``10 ** scale`` so the digit to round sits
    just left of the decimal point, rounded there, and divided back.

    Args:
        col: Column holding the numbers to round.
        scale: Number of decimal places to keep; wrapped in ``F.lit`` so it
            takes part in the generated column expression.

    Returns:
        A column expression with the rounded value.
    """
    scale_is_zero = F.lit(0) == F.lit(scale)
    factor = F.lit(10) ** F.lit(scale)

    # Skip the multiplication and the final division when scale is 0.
    shifted = F.when(scale_is_zero, col).otherwise(col * factor)
    divisor = F.when(F.lit(0) == F.lit(scale), F.lit(1)).otherwise(factor)

    lower = F.floor(shifted)
    # A tie is a value sitting exactly halfway between two integers.
    is_tie = shifted - lower == F.lit(0.5)
    # On a tie pick whichever neighbour (floor or floor + 1) is even.
    nearest_even = F.when(lower % F.lit(2) == F.lit(0), lower).otherwise(
        lower + F.lit(1)
    )

    rounded = F.when(is_tie, nearest_even).otherwise(F.round(shifted))
    return rounded / divisor


# Expose the extension implementations as attributes of F so callers can
# reach them as F.array, F.array_max and F.array_min.
F.array = _array
F.array_max = _array_max
F.array_min = _array_min
Expand All @@ -291,4 +301,5 @@ def _struct(*cols):
# Expose the extension implementations as attributes of F.
# desc_nulls_first has to build a descending, NULLs-first ordering; it
# previously ended in .asc(), which produced an ascending sort instead.
F.desc_nulls_first = lambda col: _to_col_if_str(col, "desc_nulls_first").desc_nulls_first()
F.sort_array = _sort_array
F.array_sort = _array_sort
F.struct = _struct
F.bround = _bround
75 changes: 73 additions & 2 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import snowpark_extensions
from snowflake.snowpark import Session
from snowflake.snowpark.types import *
from snowflake.snowpark.functions import col,lit, array_sort,sort_array, array_max, array_min, map_values, struct,object_construct, array_agg
from snowflake.snowpark.functions import col,lit, array_sort,sort_array, array_max, array_min, map_values, struct,object_construct, array_agg, bround
from snowflake.snowpark import functions as F
import re

Expand Down Expand Up @@ -222,4 +222,75 @@ def test_struct():
res = df.select(struct(df.age.alias("A"), df.name.alias("B")).alias("struct")).collect()
assert len(res)==2
assert re.sub(r"\s","",res[0].STRUCT) == '{"A":80,"B":"Bob"}'
assert re.sub(r"\s","",res[1].STRUCT) == '{"A":null,"B":"Alice"}'
assert re.sub(r"\s","",res[1].STRUCT) == '{"A":null,"B":"Alice"}'

def test_bround():
    """Check bround (HALF_EVEN rounding) with the default scale, scale 1 and a None scale.

    The scale is always passed to bround as a Python literal; the ``scale``
    column stored in the test data is not read by the function under test.
    """
    session = Session.builder.from_snowsql().getOrCreate()

    data0 = [
        (1.5, 0),
        (2.5, 0),
        (0.00, 0),
        (0.5, 0),
        (-1.5, 0),
        (-2.5, 0),
    ]

    data1 = [
        (2.25, 1),
        (2.65, 1),
        (0.00, 1),
        (1.05, 1),
        (1.15, 1),
        (-2.25, 1),
        (-2.35, 1),
        (None, 1),
        (1.5, 1),
        (1.5, -1),
    ]

    data_null = [
        (0.5, None),
        (1.5, None),
        (2.5, None),
        (-1.5, None),
        (-2.5, None),
        (None, None),
    ]

    schema_df = StructType([
        StructField('value', FloatType(), True),
        StructField('scale', IntegerType(), True)
    ])

    df_0 = session.createDataFrame(data0, schema_df)
    df_1 = session.createDataFrame(data1, schema_df)
    df_null = session.createDataFrame(data_null, schema_df)

    # Default scale (0): ties go to the even integer, for negatives as well.
    res0 = df_0.withColumn("rounding", bround(F.col('value'))).collect()
    assert len(res0) == 6
    assert [row.ROUNDING for row in res0] == [2.0, 2.0, 0.0, 0.0, -2.0, -2.0]

    # Scale 1: ties at the second decimal go to the even first decimal;
    # a None input value stays None.
    res1 = df_1.withColumn("rounding", bround(F.col('value'), 1)).collect()
    assert len(res1) == 10
    assert [row.ROUNDING for row in res1] == [
        2.2, 2.6, 0.0, 1.0, 1.2, -2.2, -2.4, None, 1.5, 1.5,
    ]

    # A None scale yields None for every row.
    resNull = df_null.withColumn("rounding", bround(F.col('value'), None)).collect()
    assert len(resNull) == 6
    assert all(row.ROUNDING is None for row in resNull)

0 comments on commit 7a0009a

Please sign in to comment.