In [None]:
import pandas as pd

In [None]:
import os
from pathlib import Path

# Directory holding the Dunnhumby CSVs. Set the DUNNHUMBY_DIR environment
# variable instead of editing a machine-specific absolute path; the original
# location is kept as the fallback so existing runs are unaffected.
DATA_DIR = Path(os.environ.get(
    "DUNNHUMBY_DIR",
    "/Users/gursanjjam/Documents/basket-transformer/dunnhumby",
))
tx = pd.read_csv(DATA_DIR / "transaction_data.csv")
prod = pd.read_csv(DATA_DIR / "product.csv")

In [None]:
# Raw transaction rows, then how many distinct products and households they cover.
print("TRANSACTION\n", tx.head())
print("\nUnique products:", tx.PRODUCT_ID.nunique())
print("\nUnique households:", tx.household_key.nunique())
# Raw product table.
print('\n\n\n\nPRODUCT\n', prod.head())

TRANSACTION
    household_key    BASKET_ID  DAY  PRODUCT_ID  QUANTITY  SALES_VALUE  \
0           2375  26984851472    1     1004906         1         1.39   
1           2375  26984851472    1     1033142         1         0.82   
2           2375  26984851472    1     1036325         1         0.99   
3           2375  26984851472    1     1082185         1         1.21   
4           2375  26984851472    1     8160430         1         1.50   

   STORE_ID  RETAIL_DISC  TRANS_TIME  WEEK_NO  COUPON_DISC  COUPON_MATCH_DISC  
0       364        -0.60        1631        1          0.0                0.0  
1       364         0.00        1631        1          0.0                0.0  
2       364        -0.30        1631        1          0.0                0.0  
3       364         0.00        1631        1          0.0                0.0  
4       364        -0.39        1631        1          0.0                0.0  

Unique products: 92339

Unique households: 2500




PRODUCT
    PRO

In [None]:
# Attach the product.csv attributes to every transaction row (for embedding
# features later). validate="many_to_one" makes pandas raise MergeError if
# PRODUCT_ID is not unique in `prod`; without it duplicates would silently
# multiply transaction rows and inflate every basket.
# NOTE: not idempotent - re-running this cell merges again and yields suffixed
# (_x/_y) columns; re-run from the load cell instead.
tx = tx.merge(prod, on="PRODUCT_ID", how="left", validate="many_to_one")

## sessionizing data into baskets

In [None]:
# Order rows chronologically within each household. TRANS_TIME breaks ties
# between baskets bought on the same DAY; BASKET_ID is the final tie-breaker.
tx = tx.sort_values(['household_key', 'DAY', 'TRANS_TIME', 'BASKET_ID'])

# One row per basket holding that basket's PRODUCT_IDs as a list.
# sort=False is essential: the default (sort=True) re-orders groups by
# (household_key, BASKET_ID) and throws away the DAY/TRANS_TIME order set
# above. With sort=False groups keep their order of first appearance.
baskets = (
    tx.groupby(['household_key', 'BASKET_ID'], sort=False)['PRODUCT_ID']
      .apply(list)                        # basket -> list of product IDs
      .reset_index(name='products')       # MultiIndex Series -> DataFrame
)

In [None]:
print(baskets.iloc[:5])  # first five baskets

   household_key    BASKET_ID  \
0              1  27601281299   
1              1  27774192959   
2              1  28024266849   
3              1  28106322445   
4              1  28235481967   

                                            products  
0  [825123, 831447, 840361, 845307, 852014, 85498...  
1  [852662, 856942, 997025, 1030547, 1049998, 105...  
2  [841266, 865178, 953561, 991024, 995242, 99590...  
3  [827656, 831447, 845896, 852662, 856942, 85754...  
4  [852662, 856942, 887375, 909472, 922417, 93113...  


### creating household histories using sessionized baskets

In [None]:
# For each household, collect its baskets (in the row order of `baskets`)
# into a list of lists: one inner list of product IDs per basket.
household_histories = baskets.groupby('household_key')['products'].apply(list).reset_index(
    name='basket_sequence'
)

In [None]:
print(household_histories.iloc[:5])  # first five households

   household_key                                    basket_sequence
0              1  [[825123, 831447, 840361, 845307, 852014, 8549...
1              2  [[854852, 930118, 1077555, 1098066, 5567388, 5...
2              3  [[866211, 878996, 882830, 904360, 921345, 9319...
3              4  [[836163, 857849, 877523, 878909, 883932, 8914...
4              5  [[938983, 5980822], [1012352], [825538, 100249...


## Mapping (tokenization of product IDs)

In [None]:
import json, os
from itertools import chain

# Vocabulary = every distinct product ID seen in any household's basket history.
all_products = {
    pid
    for basket_seq in household_histories['basket_sequence']
    for basket in basket_seq
    for pid in basket
}

# Token 0 is reserved for padding, so products get tokens 1..N in ascending-ID order.
PAD_ID = 0  # Reserved for padding
product2id = {pid: token for token, pid in enumerate(sorted(all_products), start=1)}
id2product = {token: pid for pid, token in product2id.items()}

# Persist both directions. JSON object keys are always strings, so the int
# keys come back as str when these files are reloaded - convert on load.
os.makedirs("data/mappings", exist_ok=True)
with open("data/mappings/product2id.json", "w") as f:
    json.dump(product2id, f)
with open("data/mappings/id2product.json", "w") as f:
    json.dump(id2product, f)

#Applying tokenization to basket sequences
def tokenize_baskets(baskets, mapping, unk_id=None):
    """
    Map every product ID in a household's basket history to its integer token.

    Args:
        baskets: list of baskets, each a list of raw product IDs.
        mapping: dict from raw product ID to token ID (e.g. ``product2id``).
        unk_id: token to use for product IDs missing from ``mapping``. When
            None (the default), an unknown ID raises KeyError - the right
            behaviour when the vocabulary was built from the same data.

    Returns:
        A new list of baskets with the same nesting, holding token IDs.

    Raises:
        KeyError: if a product ID is not in ``mapping`` and ``unk_id`` is None.
    """
    if unk_id is None:
        return [[mapping[pid] for pid in basket] for basket in baskets]
    # Tolerant path, e.g. for data tokenized against a previously built vocab.
    return [[mapping.get(pid, unk_id) for pid in basket] for basket in baskets]

# Replace every raw PRODUCT_ID with its integer token (Series.apply forwards
# the keyword argument to tokenize_baskets).
household_histories['basket_sequence_tok'] = household_histories['basket_sequence'].apply(
    tokenize_baskets, mapping=product2id
)

# Quick check: first household's tokenized baskets and the vocabulary size.
print(household_histories['basket_sequence_tok'].iloc[0])
print(f"Vocab size (including PAD): {len(product2id) + 1}")

[[6533, 7270, 8279, 8837, 9648, 9975, 10188, 16475, 19753, 20314, 21664, 23826, 25298, 27034, 29038, 31180, 32531, 32547, 32681, 32714, 34677, 35571, 40964, 52618, 53643, 60255, 60427, 60714, 60905, 62314], [9720, 10188, 25951, 29738, 31958, 32611, 35571, 58531, 60255, 62441, 65627, 67673], [8372, 11131, 21176, 25298, 25749, 25826, 30026, 34726, 36246, 41401, 67570, 67673], [6835, 7270, 8911, 9720, 10188, 10255, 12382, 12606, 15635, 18858, 18926, 20314, 25361, 25749, 27279, 29038, 34677, 35142, 53643, 55050, 56999, 60255, 68513], [9720, 10188, 13575, 16108, 17608, 18610, 18816, 23840, 25749, 38561, 52606, 53643, 60255, 61100, 63463, 67673, 71656], [10188, 10609, 18036, 18973, 21176, 24633, 25749, 32681, 60255, 65627], [9281, 11259, 21963, 22520, 25749, 34519, 58291, 60714, 68513], [7189, 8961, 9152, 9720, 10188, 10447, 13732, 16336, 17867, 18610, 19720, 19753, 22520, 23197, 25749, 30026, 30683, 31958, 32611, 33203, 34677, 34742, 35010, 35137, 35571, 35586, 36170, 37602, 40167, 53643, 5

In [None]:
print(household_histories.head(n=5))  # raw and tokenized histories side by side

   household_key                                    basket_sequence  \
0              1  [[825123, 831447, 840361, 845307, 852014, 8549...   
1              2  [[854852, 930118, 1077555, 1098066, 5567388, 5...   
2              3  [[866211, 878996, 882830, 904360, 921345, 9319...   
3              4  [[836163, 857849, 877523, 878909, 883932, 8914...   
4              5  [[938983, 5980822], [1012352], [825538, 100249...   

                                 basket_sequence_tok  
0  [[6533, 7270, 8279, 8837, 9648, 9975, 10188, 1...  
1  [[9957, 18504, 35032, 37354, 52800, 52837, 529...  
2  [[11256, 12682, 13094, 15500, 17484, 18697, 19...  
3  [[7808, 10290, 12527, 12673, 13226, 14014, 146...  
4  [[19533, 55164], [27688], [6586, 26588, 58732]...  


## Flattening household histories: for now we predict the next product, not the next basket

In [None]:
from itertools import chain

# Flatten basket_sequence_tok into a single token stream per household.
def flatten_baskets(baskets_tok):
    """
    Concatenate a list of token-ID baskets into one flat list.

    baskets_tok: list of baskets, each basket a list of token IDs.
    Returns: a new flat list with every token, baskets kept in the order given.
    """
    return [token for basket in baskets_tok for token in basket]

# Derive each household's flat token stream from its tokenized baskets.
household_histories['flat_sequence'] = (
    household_histories['basket_sequence_tok'].apply(flatten_baskets)
)

# Sanity check: first three households and the length of the first stream.
print(household_histories.loc[:, ['household_key', 'flat_sequence']].head(3))
print(f"Example sequence length: {len(household_histories['flat_sequence'].iloc[0])}")


   household_key                                      flat_sequence
0              1  [6533, 7270, 8279, 8837, 9648, 9975, 10188, 16...
1              2  [9957, 18504, 35032, 37354, 52800, 52837, 5297...
2              3  [11256, 12682, 13094, 15500, 17484, 18697, 193...
Example sequence length: 1727


In [None]:
from itertools import chain

# flatten_baskets and the flat_sequence column are already created by the
# flattening cell above. Redefining the helper here only shadowed the
# documented version and recomputed an identical column, so this cell now
# just verifies the precondition for the steps that follow.
assert 'flat_sequence' in household_histories.columns, (
    "flat_sequence missing - run the flattening cell first"
)
print(household_histories.columns)  # should now include 'flat_sequence'

Index(['household_key', 'basket_sequence', 'basket_sequence_tok',
       'flat_sequence'],
      dtype='object')


## Input/target pairs for next-product prediction

In [None]:
import pandas as pd

PAD_ID = 0
MAX_LEN = 50  # maximum length of input sequence (context window, in tokens)

def make_input_target_pairs(flat_sequence, max_len=MAX_LEN, pad_id=PAD_ID):
    """
    Build next-token prediction examples from one household's flat token sequence.

    For every position i >= 1 the target is ``flat_sequence[i]``, and the input
    is the (at most) ``max_len`` tokens immediately before it, left-padded with
    ``pad_id`` to exactly ``max_len`` entries.

    Args:
        flat_sequence: ordered sequence of token IDs (list, tuple or array).
        max_len: length of every returned input sequence; must be >= 1.
        pad_id: token used for left padding.

    Returns:
        List of ``(input_seq, target, seq_len)`` tuples. ``input_seq`` is a new
        list; ``seq_len`` is the number of real (non-padding) tokens in it.
        Sequences with fewer than two tokens yield an empty list.

    Raises:
        ValueError: if ``max_len`` < 1.
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")

    examples = []
    for i in range(1, len(flat_sequence)):
        # Context = up to max_len tokens strictly before the target position.
        # list() so padding concatenates correctly for tuple/array inputs too.
        window = list(flat_sequence[max(0, i - max_len):i])
        # Real-token count comes from the window itself rather than re-scanning
        # the padded list, which would miscount a real token equal to pad_id.
        seq_len = len(window)
        input_seq = [pad_id] * (max_len - seq_len) + window
        examples.append((input_seq, flat_sequence[i], seq_len))
    return examples

# Build training examples: one record per (household, target position).
# Zipping the two needed columns avoids DataFrame.iterrows, which builds a
# full Series for every row and is much slower.
all_examples = [
    {
        "household_key": hh,
        "seq": seq,
        "target": target,
        "seq_len": seq_len,
    }
    for hh, flat_seq in zip(household_histories['household_key'],
                            household_histories['flat_sequence'])
    for seq, target, seq_len in make_input_target_pairs(flat_seq)
]

examples_df = pd.DataFrame(all_examples)

# Invariant: a household with L tokens contributes max(L - 1, 0) examples.
expected_n_examples = sum(max(len(s) - 1, 0) for s in household_histories['flat_sequence'])
assert len(examples_df) == expected_n_examples, (len(examples_df), expected_n_examples)

# Quick check
print(examples_df.head())
print(f"Total examples: {len(examples_df)}")


   household_key                                                seq  target  \
0              1  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...    7270   
1              1  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...    8279   
2              1  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...    8837   
3              1  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...    9648   
4              1  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...    9975   

   seq_len  
0        1  
1        2  
2        3  
3        4  
4        5  
Total examples: 2593232


## temporal split

In [None]:
import pandas as pd

def temporal_split(df, train_ratio=0.8, val_ratio=0.1):
    """
    Split each household's examples into train/val/test by chronological order.

    Within every household of n rows, the first ``int(n * train_ratio)`` rows go
    to train, rows up to ``int(n * (train_ratio + val_ratio))`` go to val, and
    the rest go to test. Row order inside a household is taken as chronological,
    so ``df`` must already be ordered by sequence position within each household.

    Notes:
        - Cut points are per household (by example count), not one global date.
        - Because of ``int`` truncation, households with very few rows can end
          up with nothing in train and/or val.
        - Rows whose ``household_key`` is NaN are dropped (groupby default).

    Args:
        df: DataFrame with a ``household_key`` column.
        train_ratio: fraction of each household's rows for training, in [0, 1].
        val_ratio: fraction for validation, in [0, 1];
            ``train_ratio + val_ratio`` must not exceed 1.

    Returns:
        ``(train_df, val_df, test_df)``, each with a fresh RangeIndex and with
        households in ascending ``household_key`` order. An empty ``df`` yields
        three empty frames with ``df``'s columns.

    Raises:
        ValueError: if the ratios are out of range.
    """
    if (not (0 <= train_ratio <= 1 and 0 <= val_ratio <= 1)
            or train_ratio + val_ratio > 1 + 1e-9):
        raise ValueError(
            f"invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    train_parts, val_parts, test_parts = [], [], []

    for _, group in df.groupby("household_key"):
        n = len(group)
        train_end = int(n * train_ratio)
        val_end = int(n * (train_ratio + val_ratio))

        train_parts.append(group.iloc[:train_end])
        val_parts.append(group.iloc[train_end:val_end])
        test_parts.append(group.iloc[val_end:])

    if not train_parts:
        # pd.concat([]) raises; return empty frames that keep df's columns.
        empty = df.iloc[0:0].reset_index(drop=True)
        return empty.copy(), empty.copy(), empty.copy()

    train_df = pd.concat(train_parts).reset_index(drop=True)
    val_df   = pd.concat(val_parts).reset_index(drop=True)
    test_df  = pd.concat(test_parts).reset_index(drop=True)

    return train_df, val_df, test_df

# Apply split
train_df, val_df, test_df = temporal_split(examples_df)

# Every example must land in exactly one split. This guards against rows being
# silently dropped by the per-household groupby (e.g. a NaN household_key).
assert len(train_df) + len(val_df) + len(test_df) == len(examples_df)

# Save to parquet
import os
os.makedirs("data/splits", exist_ok=True)
train_df.to_parquet("data/splits/train.parquet", index=False)
val_df.to_parquet("data/splits/val.parquet", index=False)
test_df.to_parquet("data/splits/test.parquet", index=False)

print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")


Train: 2073557, Val: 259208, Test: 260467
