In [None]:
import polars as pl
import numpy as np
import altair as alt
import pandas as pd

In [None]:
# Lazily scan the header-less transaction log; column names and dtypes are
# supplied explicitly because the file carries no header row.
kiwi_lf = pl.scan_csv(
    source="data/kiwibubbles/kiwibubbles_tran.csv",
    has_header=False,
    separator=",",
    schema={
        'ID': pl.UInt16,
        'Market': pl.UInt8,
        'Week': pl.Int16,
        'Day': pl.Int16,
        'Units': pl.Int16,
    },
)

# NOTE(review): hard-coded and not derived from the data in this cell;
# presumably the Market 2 panel size - confirm against the data description.
num_panellists_m2 = 1499

# Market 2 transactions only, with a 0-based per-panellist purchase index "DoR"
# (presumably "depth of repeat": 0 = first purchase, 1 = first repeat, ...).
kiwi_lf_m2 = (
    kiwi_lf
    .filter(pl.col('Market') == 2)
    .drop('Market')
    # Sort chronologically within each panellist. Sorting on 'ID' alone gives no
    # guarantee about the relative order of a panellist's rows, and that order is
    # exactly what the purchase index below is derived from.
    .sort('ID', 'Week', 'Day', maintain_order=True)
    # cum_count() numbers the rows of each ID starting at 1; subtract 1 so the
    # index starts at 0.
    .with_columns((pl.col("ID").cum_count().over("ID") - 1).cast(pl.UInt16).alias("DoR"))
)

In [None]:
def shift_week(group_df: "pl.DataFrame") -> "pl.DataFrame":
    """Add a strictly increasing ``shWeek`` column to one panellist's purchases.

    The rows are ordered by ``Week``; whenever a week is not greater than the
    (already shifted) week of the previous row, it is pushed to the previous
    value plus one, so no two rows of the group share a ``shWeek``.

    Parameters
    ----------
    group_df:
        Rows of a single group; must contain a ``Week`` column.

    Returns
    -------
    pl.DataFrame
        The same rows, ordered by ``Week`` (ties keep their incoming order),
        with the extra ``shWeek`` column.
    """
    # Sort the rows themselves, not just a copy of the column: the shifted weeks
    # are computed in Week order, so they must be attached to rows in that same
    # order or they end up on the wrong transaction.
    ordered = group_df.sort("Week", maintain_order=True)

    # Work on a private, writable copy of the week values.
    week_arr = ordered["Week"].to_numpy().copy()
    for i in range(1, len(week_arr)):
        # Compare against the previous *shifted* value so runs of duplicates
        # cascade: e.g. 5, 5, 5 becomes 5, 6, 7.
        if week_arr[i] <= week_arr[i - 1]:
            week_arr[i] = week_arr[i - 1] + 1
    return ordered.with_columns(pl.Series("shWeek", week_arr))

# Apply shift_week to each panellist's rows. A lazy map_groups cannot infer what
# the Python function returns, so the output schema is declared by hand; it is
# listed in the order the columns actually come back (the input columns followed
# by the appended 'shWeek') so the planned schema matches the real frame.
shifted_lf = (
    kiwi_lf_m2
    .group_by('ID')
    .map_groups(
        shift_week,
        schema={
            'ID': pl.UInt16,
            'Week': pl.Int16,
            'Day': pl.Int16,
            'Units': pl.Int16,
            'DoR': pl.UInt16,
            'shWeek': pl.Int16,
        },
    )
)

In [None]:
# Complete (shWeek, DoR) grid: weeks 1..52 crossed with DoR 0..11, so that
# combinations with no transactions still get a row further down.
week_range, dor_range = np.meshgrid(
    np.arange(1, 53, dtype='int16'),
    np.arange(0, 12, dtype='uint16'),
)
dummy_lf = pl.DataFrame({'shWeek': week_range.ravel(), 'DoR': dor_range.ravel()})

# Number of transactions for every observed (shWeek, DoR) pair.
sh_agg_trans = (
    shifted_lf
    .collect()
    .group_by(['shWeek', 'DoR'])
    .agg(Count=pl.len())
)

# Number of transactions per shWeek, summed across all DoR levels.
shweek_total_trans = (
    sh_agg_trans
    .group_by(['shWeek'])
    .agg(Total=pl.col('Count').sum())
)

# Long-form table on the full grid; pairs and weeks without data become 0.
sh_agg_trans_longform = (
    dummy_lf
    .join(sh_agg_trans, how='left', on=['shWeek', 'DoR'])
    .join(shweek_total_trans, how='left', on='shWeek')
    .fill_null(0)
)

In [None]:
# Wide-form view: one row per shWeek, one column per DoR level, plus the weekly
# total as the last column.
sh_agg_trans_wideform = (
    sh_agg_trans_longform
    .pivot(on='DoR', index='shWeek', values='Count')
    .join(shweek_total_trans, on='shWeek', how='left')
    # Weeks that are absent from shweek_total_trans would otherwise show a null
    # Total here while the long-form table reports 0 for the same weeks.
    .fill_null(0)
)

# Column sums over all weeks (every column except the week index itself).
col_total = sh_agg_trans_wideform.select(pl.all().exclude('shWeek').sum())

display(sh_agg_trans_wideform)
display(col_total)

In [None]:
# Running total of transactions per DoR level across weeks.
sh_cum_trans_longform = (
    sh_agg_trans_longform
    # A cumulative sum depends on row order; sort explicitly so it accumulates in
    # week order within each DoR instead of relying on whatever order the
    # upstream joins happened to produce.
    .sort('DoR', 'shWeek')
    .with_columns(pl.col('Count').cum_sum().over('DoR').alias('Cum DoR'))
)

# Wide-form view: one row per shWeek, one column per DoR level.
sh_cum_trans_wideform = sh_cum_trans_longform.pivot(on='DoR', index='shWeek', values='Cum DoR')

display(sh_cum_trans_wideform)

In [None]:
# Materialise the shifted transactions once and reuse the result for both tables
# below, rather than running the per-panellist map_groups twice.
shifted_df = shifted_lf.collect()

# Calculate Time Since Last Purchase (TSLP, in weeks), then re-key every row to
# the week of the panellist's previous purchase. Both shifts are scoped with
# over('ID'): without it, the first purchase of each panellist would inherit the
# last week of the preceding panellist. First purchases end up with null for
# both TSLP and shWeek.
test = (
    shifted_df
    .sort('ID', 'shWeek')
    .with_columns((pl.col('shWeek') - pl.col('shWeek').shift(1)).over('ID').alias('TSLP'))
    .with_columns(pl.col('shWeek').shift(1).over('ID').alias('shWeek'))
)

# Aggregate purchases by depth and week: distinct panellists per (DoR, shWeek),
# where shWeek is the week of the purchase itself.
purchase_counts = (
    shifted_df
    .group_by(["DoR", "shWeek"])
    .agg(pl.col("ID").n_unique().alias("Count"))
    .sort(["DoR", "shWeek"])
)

# Distinct panellists per (DoR, week of previous purchase, TSLP).
eligibility = (
    test
    .group_by("DoR", 'shWeek', 'TSLP')
    .agg(pl.col("ID").n_unique().alias("Count"))
    # TSLP is part of the sort key so the row order, and with it the shifted
    # column below, is deterministic.
    .sort("DoR", 'shWeek', 'TSLP')
    # NOTE(review): "Eligible" is just the Count of the next row in this sort
    # order (0 for the last row); it is not restricted to the same DoR/shWeek.
    # Confirm this is the intended definition.
    .with_columns(pl.col("Count").shift(-1).fill_null(0).alias("Eligible"))
)

# Full DoR x shWeek x TSLP grid (DoR 0..11, weeks 1..52, TSLP 1..52) so that
# combinations with no observations are present with a count of 0.
week_range, dor_range, tslp_range = np.meshgrid(
    np.arange(1, 53, dtype='int16'),
    np.arange(0, 12, dtype='uint16'),
    np.arange(1, 53, dtype='int16'),
)
full_dor_week = pl.DataFrame({
    'DoR': dor_range.reshape(-1),
    'shWeek': week_range.reshape(-1),
    'TSLP': tslp_range.reshape(-1),
})

eligibility = (
    # how='align' joins on the columns the frames share: DoR, shWeek and TSLP.
    pl.concat([full_dor_week, eligibility], how='align')
    .fill_null(0)
    # Explicit order so the running total accumulates along increasing TSLP.
    .sort('DoR', 'shWeek', 'TSLP')
    .with_columns(pl.col("Count").cum_sum().over('DoR', 'shWeek').alias('Cum Count by Week'))
)

eligibility = (
    eligibility
    # Attach the weekly purchase counts as their own column. Aligning the two
    # frames directly would treat the shared 'Count' column as a join key, which
    # adds no information and only appends unmatched rows with a null TSLP.
    # NOTE(review): 'Purchases' is matched on the same DoR and on shWeek, which
    # here is the week of the *previous* purchase - confirm this is the intended
    # pairing (e.g. it may need to be DoR - 1).
    .join(purchase_counts.rename({'Count': 'Purchases'}), on=['DoR', 'shWeek'], how='left')
    .filter(pl.col('DoR') == 1)
    .sort('TSLP')
)

# For DoR 1: rows are the week of the previous purchase, columns are TSLP, and
# cells are the cumulative panellist count up to that TSLP.
eligibility.pivot(on='TSLP', index='shWeek', values='Cum Count by Week').sort('shWeek')

DoR,shWeek,TSLP,Count,Eligible,Cum Count by Week
u16,i16,i16,u32,u32,u32
0,0,0,1,3,1
0,1,0,3,2,3
0,2,0,2,1,2
0,3,0,1,6,1
0,4,0,6,5,6
…,…,…,…,…,…
9,49,1,1,1,1
9,50,2,1,1,1
10,30,15,1,1,1
10,37,13,1,1,1
