# Generate base data

In [1]:
import polars as pl
import numpy as np
from scipy.stats import norm
from scipy.stats import expon

seed = 50

df_base = pl.DataFrame(
    {
        'col1_norm': norm.rvs(loc=0, size=1000, random_state=seed),
        'col2_norm_null': pl.concat([pl.Series(norm.rvs(loc=0, size=950, random_state=seed)), pl.Series(np.full(50, None), dtype=pl.Float64)], how='vertical'),
        'col3_norm_null_default': np.concatenate([norm.rvs(loc=0, size=950, random_state=seed), np.full(50, -1)]),
        'col4_str_abc': ['a'] * 250 + ['b'] * 500 + ['c'] * 250, # categorical column, String type
        'col5_str_abc_null': ['a'] * 490 + ['b'] * 240 + ['c'] * 240 + [None] * 30, # categorical column with nulls
        'col6_binary': [1] * 250 + [0] * 750, # binary categorical
        'col7_binary_null': [1] * 250 + [0] * 715 + [None] * 35, # binary with null
        'col8_stacked_at_0': np.concatenate([[0] * 500, expon.rvs(size=500, random_state=seed)]), # Overlapping bins (highly skewed toward 0)
        'col9_stacked_at_1': np.concatenate([[1] * 500, expon.rvs(size=500, random_state=seed)]), # Overlapping bins (highly skewed toward 0)
        'col10_discrete_numeric': [1] * 400 + [2] * 150 + [3] * 150 + [4] * 250 + [5] * 50, # Numeric that takes on integer values
        'col11_cat_missing_level': [1] * 500 + [0] * 500, # categorical missing levels
        'col12_all_nulls_numeric': [None] * 1000, # all nulls, dtype is f64
        'col13_numeric_nan': pl.concat([pl.Series(norm.rvs(loc=0, size=950, random_state=seed)), pl.Series(np.full(50, float('inf')), dtype=pl.Float64)], how='vertical'),
        'col14_numeric_constant': [1] * 1000, # constant column, no change between base and compare
        'col15_numeric_all_nan': [float('inf')] * 1000,
        'col16_cat_abc': pl.Series(['a']* 250 + ['b'] * 250 + ['c'] * 500, dtype=pl.Categorical), # categorical dtype
        'col17_bool': [True] * 500 + [False] * 500,
        'col18_nan': [float('nan')] * 1000, # nan data
        'col19_constant': [1.5] * 1000, # constant column
        'col20_constant_chng': [1.5] * 1000 # Constant value with distribution change
        # constant categorical column
        # Extremely skewed numeric that shifts. E.g., only 1-2 bins defined for base, compare shifts up
    },
    schema_overrides={'col12_all_nulls_numeric': pl.Float64}
)

df_base.write_csv('./base_data.csv')

In [2]:
df_base.schema

Schema([('col1_norm', Float64),
        ('col2_norm_null', Float64),
        ('col3_norm_null_default', Float64),
        ('col4_str_abc', String),
        ('col5_str_abc_null', String),
        ('col6_binary', Int64),
        ('col7_binary_null', Int64),
        ('col8_stacked_at_0', Float64),
        ('col9_stacked_at_1', Float64),
        ('col10_discrete_numeric', Int64),
        ('col11_cat_missing_level', Int64),
        ('col12_all_nulls_numeric', Float64),
        ('col13_numeric_nan', Float64),
        ('col14_numeric_constant', Int64),
        ('col15_numeric_all_nan', Float64),
        ('col16_cat_abc', Categorical),
        ('col17_bool', Boolean),
        ('col18_nan', Float64),
        ('col19_constant', Float64),
        ('col20_constant_chng', Float64)])

# Generate compare data

In [3]:
seed = 60

df_compare = pl.DataFrame(
    {
        'col1_norm': norm.rvs(loc=0, size=1000, random_state=seed),
        'col2_norm_null': pl.concat([pl.Series(norm.rvs(loc=0, size=950, random_state=seed)), pl.Series(np.full(50, None), dtype=pl.Float64)], how='vertical'),
        'col3_norm_null_default': np.concatenate([norm.rvs(loc=0, size=950, random_state=seed), np.full(50, -1)]),
        'col4_str_abc': ['a'] * 225 + ['b'] * 535 + ['c'] * 240, # categorical column
        'col5_str_abc_null': ['a'] * 480 + ['b'] * 230 + ['c'] * 230 + [None] * 60, # categorical column with nulls
        'col6_binary': [1] * 230 + [0] * 770, # binary categorical
        'col7_binary_null': [1] * 270 + [0] * 715 + [None] * 15, # binary with null
        'col8_stacked_at_0': np.concatenate([[0] * 300, expon.rvs(size=700, random_state=seed)]), # Overlapping bins (highly skewed toward 0)
        'col9_stacked_at_1': np.concatenate([[1] * 300, expon.rvs(size=700, random_state=seed)]), # Overlapping bins (highly skewed toward 0)
        'col10_discrete_numeric': [1] * 300 + [2] * 225 + [3] * 125 + [4] * 300 + [5] * 50, # Numeric that takes on integer values
        'col11_cat_missing_level': [1] * 500 + [None] * 500, # categorical missing levels
        'col12_all_nulls_numeric': [None] * 1000,
        'col13_numeric_nan': pl.concat([pl.Series(norm.rvs(loc=0, size=950, random_state=seed)), pl.Series(np.full(50, float('inf')), dtype=pl.Float64)], how='vertical'),
        'col14_numeric_constant': [1] * 1000, # constant column
        'col15_numeric_all_nan': [float('inf')] * 1000,
        'col16_cat_abc': pl.Series(['a']* 250 + ['b'] * 250 + ['c'] * 500, dtype=pl.Categorical), # categorical dtype
        'col17_bool': [True] * 500 + [False] * 500,
        'col18_nan': [float('nan')] * 1000, # nan data
        'col19_constant': [1.5] * 1000, # constant column
        'col20_constant_chng': [2.] * 1000 # Constant value with distribution change
    },
    schema_overrides={'col12_all_nulls_numeric': pl.Float64}
)
df_compare.write_csv('./compare_data.csv')

# Descriptive stats

In [4]:
df_base.describe(percentiles=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

statistic,col1_norm,col2_norm_null,col3_norm_null_default,col4_str_abc,col5_str_abc_null,col6_binary,col7_binary_null,col8_stacked_at_0,col9_stacked_at_1,col10_discrete_numeric,col11_cat_missing_level,col12_all_nulls_numeric,col13_numeric_nan,col14_numeric_constant,col15_numeric_all_nan,col16_cat_abc,col17_bool,col18_nan,col19_constant,col20_constant_chng
str,f64,f64,f64,str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,f64,f64,f64,f64
"""count""",1000.0,950.0,1000.0,"""1000""","""970""",1000.0,965.0,1000.0,1000.0,1000.0,1000.0,0.0,1000.0,1000.0,1000.0,"""1000""",1000.0,1000.0,1000.0,1000.0
"""null_count""",0.0,50.0,0.0,"""0""","""30""",0.0,35.0,0.0,0.0,0.0,0.0,1000.0,0.0,0.0,0.0,"""0""",0.0,0.0,0.0,0.0
"""mean""",-0.023608,-0.031435,-0.079863,,,0.25,0.259067,0.470319,0.970319,2.4,0.5,,inf,1.0,inf,,0.5,,1.5,1.5
"""std""",1.004011,1.000938,0.998167,,,0.433229,0.43835,0.791218,0.636778,1.357145,0.50025,,,0.0,,,,,0.0,0.0
"""min""",-3.80989,-3.80989,-3.80989,"""a""","""a""",0.0,0.0,0.0,0.002377,1.0,0.0,,-3.80989,1.0,inf,,0.0,,1.5,1.5
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""60%""",0.235721,0.231626,0.16197,,,0.0,0.0,0.239969,1.0,3.0,1.0,,0.322008,1.0,inf,,,,1.5,1.5
"""70%""",0.49399,0.486428,0.438958,,,0.0,0.0,0.482875,1.0,3.0,1.0,,0.598944,1.0,inf,,,,1.5,1.5
"""80%""",0.805189,0.791393,0.746569,,,1.0,1.0,0.876571,1.0,4.0,1.0,,0.920007,1.0,inf,,,,1.5,1.5
"""90%""",1.262733,1.238124,1.222179,,,1.0,1.0,1.522498,1.522498,4.0,1.0,,1.603033,1.0,inf,,,,1.5,1.5


In [5]:
df_compare.describe(percentiles=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

statistic,col1_norm,col2_norm_null,col3_norm_null_default,col4_str_abc,col5_str_abc_null,col6_binary,col7_binary_null,col8_stacked_at_0,col9_stacked_at_1,col10_discrete_numeric,col11_cat_missing_level,col12_all_nulls_numeric,col13_numeric_nan,col14_numeric_constant,col15_numeric_all_nan,col16_cat_abc,col17_bool,col18_nan,col19_constant,col20_constant_chng
str,f64,f64,f64,str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,f64,f64,f64,f64
"""count""",1000.0,950.0,1000.0,"""1000""","""940""",1000.0,985.0,1000.0,1000.0,1000.0,500.0,0.0,1000.0,1000.0,1000.0,"""1000""",1000.0,1000.0,1000.0,1000.0
"""null_count""",0.0,50.0,0.0,"""0""","""60""",0.0,15.0,0.0,0.0,0.0,500.0,1000.0,0.0,0.0,0.0,"""0""",0.0,0.0,0.0,0.0
"""mean""",0.00995,0.009835,-0.040657,,,0.23,0.274112,0.670813,0.970813,2.575,1.0,,inf,1.0,inf,,0.5,,1.5,2.0
"""std""",1.0168,1.024542,1.022564,,,0.421043,0.446292,0.947139,0.839281,1.321409,0.0,,,0.0,,,,,0.0,0.0
"""min""",-2.8425,-2.8425,-2.8425,"""a""","""a""",0.0,0.0,0.0,0.001916,1.0,1.0,,-2.8425,1.0,inf,,0.0,,1.5,2.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""60%""",0.240395,0.248972,0.192405,,,0.0,0.0,0.544848,1.0,3.0,1.0,,0.334884,1.0,inf,,,,1.5,2.0
"""70%""",0.52876,0.533127,0.478714,,,0.0,0.0,0.783957,1.0,4.0,1.0,,0.693443,1.0,inf,,,,1.5,2.0
"""80%""",0.874532,0.875493,0.857939,,,1.0,1.0,1.146046,1.146046,4.0,1.0,,1.07782,1.0,inf,,,,1.5,2.0
"""90%""",1.344515,1.363768,1.316526,,,1.0,1.0,1.828824,1.828824,4.0,1.0,,1.69357,1.0,inf,,,,1.5,2.0


# Interactive testing

In [6]:
import sys

sys.path.append('../')
from psi import psi
import polars as pl

pl.Config.set_tbl_rows(-1)
pl.Config.set_fmt_float('full')

df_base = pl.read_csv('./base_data.csv', schema_overrides=df_base.schema)
df_compare = pl.read_csv('./compare_data.csv', schema_overrides=df_compare.schema)

In [7]:
df_base.schema

Schema([('col1_norm', Float64),
        ('col2_norm_null', Float64),
        ('col3_norm_null_default', Float64),
        ('col4_str_abc', String),
        ('col5_str_abc_null', String),
        ('col6_binary', Int64),
        ('col7_binary_null', Int64),
        ('col8_stacked_at_0', Float64),
        ('col9_stacked_at_1', Float64),
        ('col10_discrete_numeric', Int64),
        ('col11_cat_missing_level', Int64),
        ('col12_all_nulls_numeric', Float64),
        ('col13_numeric_nan', Float64),
        ('col14_numeric_constant', Int64),
        ('col15_numeric_all_nan', Float64),
        ('col16_cat_abc', Categorical),
        ('col17_bool', Boolean),
        ('col18_nan', Float64),
        ('col19_constant', Float64),
        ('col20_constant_chng', Float64)])

In [8]:
numeric_columns = [
    'col1_norm',
    'col2_norm_null',
    'col8_stacked_at_0',
    'col9_stacked_at_1',
    'col10_discrete_numeric',
    'col12_all_nulls_numeric',
    'col13_numeric_nan',
    'col14_numeric_constant',
    'col15_numeric_all_nan',
    'col18_nan',
    'col19_constant',
    'col20_constant_chng'
]

categorical_columns = [
    'col4_str_abc',
    'col5_str_abc_null',
    'col6_binary',
    'col7_binary_null',
    'col11_cat_missing_level',
    'col16_cat_abc',
    'col17_bool'
]

df_psi, df_base_freq, df_compare_freq = psi(
    df_base=df_base,
    df_compare=df_compare,
    bins=10,
    include_nulls=False,
    numeric_columns=numeric_columns,
    categorical_columns=categorical_columns
)

In [9]:
df_psi

attribute,psi
str,f64
"""col1_norm""",0.0084225340421378
"""col2_norm_null""",0.0084478363065453
"""col8_stacked_at_0""",0.1775698373678918
"""col9_stacked_at_1""",0.1807236258182166
"""col10_discrete_numeric""",0.072852207112837
"""col12_all_nulls_numeric""",0.0
"""col13_numeric_nan""",0.0083809042668219
"""col14_numeric_constant""",0.0
"""col15_numeric_all_nan""",0.0
"""col18_nan""",0.0


In [10]:
df_base_freq

col1_norm,col2_norm_null,col8_stacked_at_0,col9_stacked_at_1,col10_discrete_numeric,col12_all_nulls_numeric,col13_numeric_nan,col14_numeric_constant,col15_numeric_all_nan,col18_nan,col19_constant,col20_constant_chng,col4_str_abc,col5_str_abc_null,col6_binary,col7_binary_null,col11_cat_missing_level,col16_cat_abc,col17_bool
struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2]
"{""[-inf, -1.3013769564701396]"",100}","{""[-inf, -1.2895778994315654]"",95}","{""[-inf, 0]"",500}","{""[-inf, 0.24146240139778277]"",100}","{""[-inf, 1]"",400}","{""[-inf, inf]"",0}","{""[-inf, -1.2823503147828819]"",100}","{""[-inf, 1]"",1000}","{""[-inf, inf]"",1000}","{""[-inf, NaN]"",0}","{""[-inf, 1.5]"",1000}","{""[-inf, 1.5]"",1000}","{""a"",250}","{""a"",490}","{""0"",750}","{""0"",715}","{""0"",500}","{""a"",250}","{""false"",500}"
"{""(-1.3013769564701396, -0.8714829034478315]"",100}","{""(-1.2895778994315654, -0.8761616817498066]"",95}","{""(0, 0.0011886445621633556]"",0}","{""(0.24146240139778277, 0.48452393915442477]"",100}","{""(1, 1.6000000000000227]"",0}",,"{""(-1.2823503147828819, -0.8282133365946046]"",100}","{""(1, inf]"",0}",,"{""(NaN, inf]"",0}","{""(1.5, inf]"",0}","{""(1.5, inf]"",0}","{""b"",500}","{""b"",240}","{""1"",250}","{""1"",250}","{""1"",500}","{""b"",250}","{""true"",500}"
"{""(-0.8714829034478315, -0.5408355044403768]"",100}","{""(-0.8761616817498066, -0.5417661688027671]"",95}","{""(0.0011886445621633556, 0.24063260366920644]"",100}","{""(0.48452393915442477, 0.8771514893399184]"",100}","{""(1.6000000000000227, 2]"",150}",,"{""(-0.8282133365946046, -0.5146956176062832]"",100}",,,,,,"{""c"",250}","{""c"",240}",,"{null,35}",,"{""c"",500}",
"{""(-0.5408355044403768, -0.2706955076429412]"",100}","{""(-0.5417661688027671, -0.2748712829586359]"",95}","{""(0.24063260366920644, 0.4834933934050441]"",100}","{""(0.8771514893399184, 1]"",528}","{""(2, 3]"",150}",,"{""(-0.5146956176062832, -0.24397065347664926]"",100}",,,,,,,"{null,30}",,,,,
"{""(-0.2706955076429412, -0.019756045841092827]"",100}","{""(-0.2748712829586359, -0.03150085831307207]"",95}","{""(0.4834933934050441, 0.8767367768967448]"",100}","{""(1, 1.5235002676568767]"",72}","{""(3, 3.300000000000068]"",0}",,"{""(-0.24397065347664926, 0.03382839844449116]"",100}",,,,,,,,,,,,
"{""(-0.019756045841092827, 0.23611511591281759]"",100}","{""(-0.03150085831307207, 0.2326668924514275]"",95}","{""(0.8767367768967448, 1.5235002676568767]"",100}","{""(1.5235002676568767, inf]"",100}","{""(3.300000000000068, 4]"",250}",,"{""(0.03382839844449116, 0.32203795303852567]"",100}",,,,,,,,,,,,
"{""(0.23611511591281759, 0.49443665687066946]"",100}","{""(0.2326668924514275, 0.48720290157795404]"",95}","{""(1.5235002676568767, inf]"",100}",,"{""(4, inf]"",50}",,"{""(0.32203795303852567, 0.602140354859984]"",100}",,,,,,,,,,,,
"{""(0.49443665687066946, 0.8057441309517726]"",100}","{""(0.48720290157795404, 0.7925975400857329]"",95}",,,,,"{""(0.602140354859984, 0.9200824611139095]"",100}",,,,,,,,,,,,
"{""(0.8057441309517726, 1.2633420474920432]"",100}","{""(0.7925975400857329, 1.2383348516011667]"",95}",,,,,"{""(0.9200824611139095, 1.6038225597738998]"",100}",,,,,,,,,,,,
"{""(1.2633420474920432, inf]"",100}","{""(1.2383348516011667, inf]"",95}",,,,,"{""(1.6038225597738998, inf]"",100}",,,,,,,,,,,,


In [11]:
df_compare_freq

col1_norm,col2_norm_null,col8_stacked_at_0,col9_stacked_at_1,col10_discrete_numeric,col12_all_nulls_numeric,col13_numeric_nan,col14_numeric_constant,col15_numeric_all_nan,col18_nan,col19_constant,col20_constant_chng,col4_str_abc,col5_str_abc_null,col6_binary,col7_binary_null,col11_cat_missing_level,col16_cat_abc,col17_bool
struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2],struct[2]
"{""[-inf, -1.3013769564701396]"",98}","{""[-inf, -1.2895778994315654]"",95}","{""[-inf, 0]"",300}","{""[-inf, 0.24146240139778277]"",152}","{""[-inf, 1]"",300}","{""[-inf, inf]"",0}","{""[-inf, -1.2823503147828819]"",96}","{""[-inf, 1]"",1000}","{""[-inf, inf]"",1000}","{""[-inf, NaN]"",0}","{""[-inf, 1.5]"",1000}","{""[-inf, 1.5]"",0}","{""a"",225}","{""a"",480}","{""0"",770}","{""0"",715}","{""1"",500}","{""a"",250}","{""false"",500}"
"{""(-1.3013769564701396, -0.8714829034478315]"",97}","{""(-1.2895778994315654, -0.8761616817498066]"",91}","{""(0, 0.0011886445621633556]"",0}","{""(0.24146240139778277, 0.48452393915442477]"",122}","{""(1, 1.6000000000000227]"",0}",,"{""(-1.2823503147828819, -0.8282133365946046]"",106}","{""(1, inf]"",0}",,"{""(NaN, inf]"",0}","{""(1.5, inf]"",0}","{""(1.5, inf]"",1000}","{""b"",535}","{""b"",230}","{""1"",230}","{""1"",270}","{null,500}","{""b"",250}","{""true"",500}"
"{""(-0.8714829034478315, -0.5408355044403768]"",95}","{""(-0.8761616817498066, -0.5417661688027671]"",94}","{""(0.0011886445621633556, 0.24063260366920644]"",152}","{""(0.48452393915442477, 0.8771514893399184]"",164}","{""(1.6000000000000227, 2]"",225}",,"{""(-0.8282133365946046, -0.5146956176062832]"",91}",,,,,,"{""c"",240}","{""c"",230}",,"{null,15}",,"{""c"",500}",
"{""(-0.5408355044403768, -0.2706955076429412]"",103}","{""(-0.5417661688027671, -0.2748712829586359]"",96}","{""(0.24063260366920644, 0.4834933934050441]"",121}","{""(0.8771514893399184, 1]"",323}","{""(2, 3]"",125}",,"{""(-0.5146956176062832, -0.24397065347664926]"",95}",,,,,,,"{null,60}",,,,,
"{""(-0.2706955076429412, -0.019756045841092827]"",105}","{""(-0.2748712829586359, -0.03150085831307207]"",90}","{""(0.4834933934050441, 0.8767367768967448]"",165}","{""(1, 1.5235002676568767]"",103}","{""(3, 3.300000000000068]"",0}",,"{""(-0.24397065347664926, 0.03382839844449116]"",106}",,,,,,,,,,,,
"{""(-0.019756045841092827, 0.23611511591281759]"",100}","{""(-0.03150085831307207, 0.2326668924514275]"",99}","{""(0.8767367768967448, 1.5235002676568767]"",126}","{""(1.5235002676568767, inf]"",136}","{""(3.300000000000068, 4]"",300}",,"{""(0.03382839844449116, 0.32203795303852567]"",102}",,,,,,,,,,,,
"{""(0.23611511591281759, 0.49443665687066946]"",91}","{""(0.2326668924514275, 0.48720290157795404]"",86}","{""(1.5235002676568767, inf]"",136}",,"{""(4, inf]"",50}",,"{""(0.32203795303852567, 0.602140354859984]"",86}",,,,,,,,,,,,
"{""(0.49443665687066946, 0.8057441309517726]"",83}","{""(0.48720290157795404, 0.7925975400857329]"",81}",,,,,"{""(0.602140354859984, 0.9200824611139095]"",92}",,,,,,,,,,,,
"{""(0.8057441309517726, 1.2633420474920432]"",113}","{""(0.7925975400857329, 1.2383348516011667]"",102}",,,,,"{""(0.9200824611139095, 1.6038225597738998]"",118}",,,,,,,,,,,,
"{""(1.2633420474920432, inf]"",115}","{""(1.2383348516011667, inf]"",116}",,,,,"{""(1.6038225597738998, inf]"",108}",,,,,,,,,,,,


In [12]:
df_manual_psi = pl.read_csv('./manual_psi_results.csv')
df_manual_psi = df_manual_psi.rename({'psi': 'psi_manual'})

df_psi = df_psi.rename({'psi': 'psi_fn'}, strict=False)

df_manual_psi = df_manual_psi.join(df_psi, on=['attribute'], suffix='_fn', how='left')
df_manual_psi = df_manual_psi.with_columns(
    (((pl.col('psi_manual') - pl.col('psi_fn').abs()) <= 1e-8) | (pl.col('psi_manual').is_null() & pl.col('psi_fn').is_null())).alias('is_equal')
)
df_manual_psi

attribute,psi_manual,psi_fn,is_equal
str,f64,f64,bool
"""col1_norm""",0.008422534042,0.0084225340421378,True
"""col2_norm_null""",0.008447836307,0.0084478363065453,True
"""col8_stacked_at_0""",0.1775698374,0.1775698373678918,True
"""col9_stacked_at_1""",0.1807236258,0.1807236258182166,True
"""col10_discrete_numeric""",0.07285220711,0.072852207112837,True
"""col4_str_abc""",0.005410285533,0.0054102855332317,True
"""col5_str_abc_null""",0.02185180058,0.0218518005772016,True
"""col6_binary""",0.002193978345,0.0021939783451284,True
"""col7_binary_null""",0.01848517803,0.0184851780304666,True
"""col11_cat_missing_level""",,,True


# Debug

In [14]:
df_base = pl.read_csv('./base_data.csv', schema_overrides=df_base.schema)
df_compare = pl.read_csv('./compare_data.csv', schema_overrides=df_compare.schema)

numeric_columns = [
    'col1_norm',
    'col2_norm_null',
    'col8_stacked_at_0',
    'col9_stacked_at_1',
    'col10_discrete_numeric',
    'col12_all_nulls_numeric'
]

categorical_columns = [
    'col4_str_abc',
     'col5_str_abc_null',
     'col6_binary',
     'col7_binary_null',
    'col11_cat_missing_level'
]
bins = 10

# Initialize frequency tables and Lazy dataframes
df_base_num_count = pl.DataFrame()
df_compare_num_count = pl.DataFrame()
df_base_cat_freq = pl.DataFrame()
df_compare_cat_freq = pl.DataFrame()
ldf_base = df_base.lazy()
ldf_compare = df_compare.lazy()

In [15]:
# Get bins from base using quantiles. Returning as a Series ensures a consistent datatype of the elements in the Sequence.
quantiles = pl.linear_space(0, 1, bins + 1, eager=True).to_list()
dict_cols_edges = (
    pl.concat(
        [ldf_base.select(pl.col(numeric_columns)).quantile(q, interpolation='linear') for q in quantiles], how='vertical'
    )
    .collect()
    .to_dict(as_series=True)
)

# Preprocess bins
for col, edges in dict_cols_edges.items():

    # Edge case: if first edge is an integer, .hist casts the result column as u64, resulting in a TypeError if the other edges are floats
    # Explicitly cast all values to float
    # TODO: issue casting int to float? e.g., 1.000000001

    # What properties of the bins does .hist require?
        # edges must be strictly monotonic
        # edges must be numeric
        # edges cannot be null
            # when will edges be null? E.g., when col is all null
            # Is it possible for there to be a single null? Or if there's a null, will the entire series be null?
        # How many edges?
        # hist can handle empty sequences

    # .hist requires numeric type
    dict_cols_edges[col] = dict_cols_edges[col].cast(pl.Float64)

    # Slightly redundant since first and last edge is defined earlier
    dict_cols_edges[col][0] = float('-inf')
    dict_cols_edges[col][-1] = float('inf')

    # Ensure strict monotonicity. E.g., [0, 0, 0, 1.5, 2.1, 4.]
    dict_cols_edges[col] = dict_cols_edges[col].unique(maintain_order=True)

    # Remove nulls if they exist
    dict_cols_edges[col] = dict_cols_edges[col].drop_nulls()

# Get bin counts for each df
# TODO: .hist() can fail if all values are NaN or identical, None — needs explicit handling?
list_ldfs_counts = [
    pl.concat(
    [ldf.select(pl.col(col).hist(bins=edges, include_category=True)) for col, edges in dict_cols_edges.items()],
    how='horizontal'
)
    for ldf in [ldf_base, ldf_compare]
]

df_base_num_count, df_compare_num_count = pl.collect_all(list_ldfs_counts)

In [21]:
# Get null counts
list_dfs_null_counts = [
    df.select(
        [pl.struct(category=pl.lit('missing', dtype=pl.Categorical), count=pl.col(col).null_count()).alias(col)
         for col in numeric_columns]
    )
    for df in [df_base, df_compare]
]

In [22]:
list_dfs_null_counts

[shape: (1, 6)
 ┌───────────────┬────────────────┬────────────────┬────────────────┬───────────────┬───────────────┐
 │ col1_norm     ┆ col2_norm_null ┆ col8_stacked_a ┆ col9_stacked_a ┆ col10_discret ┆ col12_all_nul │
 │ ---           ┆ ---            ┆ t_0            ┆ t_1            ┆ e_numeric     ┆ ls_numeric    │
 │ struct[2]     ┆ struct[2]      ┆ ---            ┆ ---            ┆ ---           ┆ ---           │
 │               ┆                ┆ struct[2]      ┆ struct[2]      ┆ struct[2]     ┆ struct[2]     │
 ╞═══════════════╪════════════════╪════════════════╪════════════════╪═══════════════╪═══════════════╡
 │ {"missing",0} ┆ {"missing",50} ┆ {"missing",0}  ┆ {"missing",0}  ┆ {"missing",0} ┆ {"missing",10 │
 │               ┆                ┆                ┆                ┆               ┆ 00}           │
 └───────────────┴────────────────┴────────────────┴────────────────┴───────────────┴───────────────┘,
 shape: (1, 6)
 ┌───────────────┬────────────────┬────────────────

In [23]:
list_dfs_null_counts[0].schema

Schema([('col1_norm', Struct({'category': Categorical, 'count': UInt32})),
        ('col2_norm_null', Struct({'category': Categorical, 'count': UInt32})),
        ('col8_stacked_at_0',
         Struct({'category': Categorical, 'count': UInt32})),
        ('col9_stacked_at_1',
         Struct({'category': Categorical, 'count': UInt32})),
        ('col10_discrete_numeric',
         Struct({'category': Categorical, 'count': UInt32})),
        ('col12_all_nulls_numeric',
         Struct({'category': Categorical, 'count': UInt32}))])

In [26]:
# Get null counts
list_dfs_null_counts = [
    df.select(
        pl.col(numeric_columns).null_count()
        
    )
    for df in [df_base, df_compare]
]
list_dfs_null_counts

[shape: (1, 6)
 ┌───────────┬────────────────┬─────────────────┬─────────────────┬────────────────┬────────────────┐
 │ col1_norm ┆ col2_norm_null ┆ col8_stacked_at ┆ col9_stacked_at ┆ col10_discrete ┆ col12_all_null │
 │ ---       ┆ ---            ┆ _0              ┆ _1              ┆ _numeric       ┆ s_numeric      │
 │ u32       ┆ u32            ┆ ---             ┆ ---             ┆ ---            ┆ ---            │
 │           ┆                ┆ u32             ┆ u32             ┆ u32            ┆ u32            │
 ╞═══════════╪════════════════╪═════════════════╪═════════════════╪════════════════╪════════════════╡
 │ 0         ┆ 50             ┆ 0               ┆ 0               ┆ 0              ┆ 1000           │
 └───────────┴────────────────┴─────────────────┴─────────────────┴────────────────┴────────────────┘,
 shape: (1, 6)
 ┌───────────┬────────────────┬─────────────────┬─────────────────┬────────────────┬────────────────┐
 │ col1_norm ┆ col2_norm_null ┆ col8_stacked_at ┆ c

# Parameter testing

In [15]:
import sys

sys.path.append('../')
from psi import psi
import polars as pl
import itertools


pl.Config.set_tbl_rows(-1)
pl.Config.set_fmt_float('full')

df_base = pl.read_csv('./base_data.csv')
df_compare = pl.read_csv('./compare_data.csv')

In [16]:
numeric_columns = [
    'col1_norm',
    'col2_norm_null',
    'col8_stacked_at_0',
    'col9_stacked_at_1',
    'col10_discrete_numeric',
    'col12_all_nulls_numeric',
    'col13_numeric_nan',
    'col14_numeric_constant',
    'col15_numeric_all_nan'
]

categorical_columns = [
    'col4_str_abc',
    'col5_str_abc_null',
    'col6_binary',
    'col7_binary_null',
    'col11_cat_missing_level',
    'col16_cat_abc'
]

In [19]:
params_numeric_columns = [
    None,
    ['col1_norm'],
    ['col1_norm', 'col2_norm_null']
]

params_categorical_columns = [
    None,
    ['col4_str_abc'],
    ['col7_binary_null']
]

bins = [1, 10, 20, 100]

params_include_nulls = [True, False]

params_combinations = itertools.product(
    params_numeric_columns,
    params_categorical_columns,
    bins,
    params_include_nulls
)

for num_cols, cat_cols, bins, include_nulls in params_combinations:
    print(num_cols, cat_cols, bins)
    try:
        psi(
            df_base=df_base,
            df_compare=df_compare,
            numeric_columns=num_cols,
            categorical_columns=cat_cols,
            include_nulls=include_nulls,
            bins=bins
        )
    except ValueError:
        print('value error:', num_cols, cat_cols, bins, include_nulls)
        pass

None None 1
value error: None None 1 True
None None 1
value error: None None 1 False
None None 10
value error: None None 10 True
None None 10
value error: None None 10 False
None None 20
value error: None None 20 True
None None 20
value error: None None 20 False
None None 100
value error: None None 100 True
None None 100
value error: None None 100 False
None ['col4_str_abc'] 1
None ['col4_str_abc'] 1
None ['col4_str_abc'] 10
None ['col4_str_abc'] 10
None ['col4_str_abc'] 20
None ['col4_str_abc'] 20
None ['col4_str_abc'] 100
None ['col4_str_abc'] 100
None ['col7_binary_null'] 1
None ['col7_binary_null'] 1
None ['col7_binary_null'] 10
None ['col7_binary_null'] 10
None ['col7_binary_null'] 20
None ['col7_binary_null'] 20
None ['col7_binary_null'] 100
None ['col7_binary_null'] 100
['col1_norm'] None 1
['col1_norm'] None 1
['col1_norm'] None 10
['col1_norm'] None 10
['col1_norm'] None 20
['col1_norm'] None 20
['col1_norm'] None 100
['col1_norm'] None 100
['col1_norm'] ['col4_str_abc'] 1
['c

# Top capture

In [1]:
from scipy.stats import uniform
from scipy.stats import bernoulli
import polars as pl

# No relationship between scores and y
scores = uniform.rvs(size=100)
y = bernoulli.rvs(p=0.25, size=100)

In [2]:
df = pl.DataFrame({'y': y, 'score': scores})

In [3]:
df

y,score
i64,f64
0,0.477905
0,0.45107
0,0.949075
0,0.171476
0,0.913155
…,…
0,0.612581
1,0.318257
1,0.010509
1,0.176536


In [4]:
from perfmetrics import gains_table

In [5]:
gains_table(
    data=df,
    bins=10,
    score_column='score',
    dependent_column='y'
)

breakpoint,category,score_avg,score_perc,y_count,y_avg,y_perc,score_perc_cum_sum,y_perc_cum_sum
f64,cat,f64,f64,u32,f64,f64,f64,f64
0.121432,"""(-inf, 0.12143245061419651]""",0.046169,0.1,10,0.5,0.1,0.1,0.1
0.226116,"""(0.12143245061419651, 0.226116…",0.161663,0.1,10,0.2,0.1,0.2,0.2
0.278269,"""(0.22611603749459938, 0.278268…",0.248913,0.1,10,0.3,0.1,0.3,0.3
0.318117,"""(0.27826891626998146, 0.318116…",0.301239,0.1,10,0.3,0.1,0.4,0.4
0.397606,"""(0.3181167995359986, 0.3976055…",0.358122,0.1,10,0.3,0.1,0.5,0.5
0.47917,"""(0.39760559046030597, 0.479169…",0.438989,0.1,10,0.3,0.1,0.6,0.6
0.615151,"""(0.47916982872491404, 0.615151…",0.545875,0.1,10,0.2,0.1,0.7,0.7
0.740755,"""(0.6151513142865404, 0.7407551…",0.669373,0.1,10,0.2,0.1,0.8,0.8
0.85936,"""(0.7407551902184744, 0.8593595…",0.808192,0.1,10,0.1,0.1,0.9,0.9
inf,"""(0.8593595114947364, inf]""",0.933583,0.1,10,0.3,0.1,1.0,1.0
