In [54]:
import pandas as pd
from datetime import datetime, timedelta
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder


In [50]:
import warnings
warnings.filterwarnings("ignore")

In [42]:
df = pd.read_parquet('sample_txn_data_2022.parquet', engine='pyarrow')

In [195]:
df.head()

Unnamed: 0,dw_gc_header,business_date,fiscal_week_end,fiscal_week,fiscal_year,daypart_name,lineitem_sequence,lineitem_seq_parent,lineitem_description,parent_product_description,...,actmodqty,actdiscqty,actprodprice,actgrosssales,actnetsales,actpromosales,actdiscsales,acttax,dw_gc_header_2,token_primary_account_identifier
0,42188263563,2021-12-31,2022-01-04,Y2022 Q01 P01 W01,Y2022,EVENING,999,999,TAX-LINE,,...,0.0,0,0.0,0.0,0.0,0.0,0.0,0.28,42188263563,880344659239910
1,42188263563,2021-12-31,2022-01-04,Y2022 Q01 P01 W01,Y2022,EVENING,1,1,NON-COMBO-ITEM,Beefy 5-Layer Burrito,...,0.0,0,2.99,2.99,2.99,0.0,0.0,0.0,42188263563,880344659239910
2,42188263563,2021-12-31,2022-01-04,Y2022 Q01 P01 W01,Y2022,EVENING,2,2,NON-COMBO-ITEM,Cheesy Bean and Rice Burrito,...,0.0,0,1.29,1.29,1.29,0.0,0.0,0.0,42188263563,880344659239910
3,42201144646,2021-12-31,2022-01-04,Y2022 Q01 P01 W01,Y2022,BREAKFAST,999,999,TAX-LINE,,...,0.0,0,0.0,0.0,0.0,0.0,0.0,0.31,42201144646,8819419820104328
4,42201144646,2021-12-31,2022-01-04,Y2022 Q01 P01 W01,Y2022,BREAKFAST,1,1,NON-COMBO-ITEM,Breakfast Crunchwrap - Bacon,...,0.0,0,3.49,3.49,3.49,0.0,0.0,0.0,42201144646,8819419820104328


In [193]:
sample = df[["parent_product_description","child_product_description"]]

In [196]:
sample= sample.dropna()

In [198]:
sample.head(100)

Unnamed: 0,parent_product_description,child_product_description
1,Beefy 5-Layer Burrito,Beefy 5-Layer Burrito
2,Cheesy Bean and Rice Burrito,Cheesy Bean and Rice Burrito
4,Breakfast Crunchwrap - Bacon,Breakfast Crunchwrap - Bacon
6,Cheesy Fiesta Potatoes,Cheesy Fiesta Potatoes
8,Chicken Quesadilla,Chicken Quesadilla
...,...,...
115,Bean Burrito,Bean Burrito
116,Bean Burrito,Bean Burrito
117,Chalupa Supreme® - Chicken,Chalupa Supreme® - Chicken
118,Nachos BellGrande®,Nachos BellGrande®


In [200]:
diff = sample[sample["parent_product_description"]!= sample["child_product_description"]]

In [201]:
len(diff) ### Parent description != Child description

592237

In [150]:
transactions = df[["business_date","token_primary_account_identifier","parent_product_description"]]

In [151]:
transactions.rename(columns = {'business_date':'Date',\
                               'token_primary_account_identifier':'CustomerID',\
                               "parent_product_description":'Product'}, inplace = True)

In [152]:
encoder = LabelEncoder()
transactions['Product'] = encoder.fit_transform(transactions['Product'])



In [153]:
# preprocess data
transactions['Date'] = pd.to_datetime(transactions['Date'])
transactions['NextPurchaseDate'] = transactions.groupby('CustomerID')['Date'].shift(-1)
transactions['NextPurchaseProduct'] = transactions.groupby('CustomerID')['Product'].shift(-1)
transactions = transactions.dropna()



In [154]:
transactions.head()

Unnamed: 0,Date,CustomerID,Product,NextPurchaseDate,NextPurchaseProduct
0,2021-12-31,880344659239910,888,2021-12-31,95.0
1,2021-12-31,880344659239910,95,2021-12-31,169.0
2,2021-12-31,880344659239910,169,2022-01-15,888.0
3,2021-12-31,8819419820104328,888,2021-12-31,120.0
4,2021-12-31,8819419820104328,120,2022-01-21,888.0


In [155]:
transactions.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1949588 entries, 0 to 23907
Data columns (total 5 columns):
 #   Column               Dtype         
---  ------               -----         
 0   Date                 datetime64[ns]
 1   CustomerID           object        
 2   Product              int64         
 3   NextPurchaseDate     datetime64[ns]
 4   NextPurchaseProduct  float64       
dtypes: datetime64[ns](2), float64(1), int64(1), object(1)
memory usage: 89.2+ MB


In [156]:
transactions = transactions[-8000:]

In [157]:
transactions["CustomerID"].unique()

array(['8798908705029142', '8353384846088636', '0497941214069789', ...,
       '8147774746369539', '0763647418838986', '9856133995633966'],
      dtype=object)

In [158]:
# train model
features = ['CustomerID', 'Product']
target = 'NextPurchaseProduct'
X = transactions[features]
y = transactions[target]
model = RandomForestClassifier()
model.fit(X, y)



RandomForestClassifier()

In [181]:
customer_id = '9887648135499821'
customer_transactions = transactions[transactions['CustomerID'] == customer_id]
customer_transactions.head() ### 888 means NA

Unnamed: 0,Date,CustomerID,Product,NextPurchaseDate,NextPurchaseProduct
21839,2022-12-27,9887648135499821,888,2022-12-27,91.0
21840,2022-12-27,9887648135499821,91,2022-12-27,494.0
21841,2022-12-27,9887648135499821,494,2022-12-27,571.0
21842,2022-12-27,9887648135499821,571,2022-12-27,226.0
21843,2022-12-27,9887648135499821,226,2022-12-27,739.0


In [185]:
last_product = customer_transactions['Product'].iloc[-1]
next_date = customer_transactions['Date'].iloc[-1] + timedelta(days=1)
next_product = model.predict([[customer_id, last_product]])[0]

print(f"The next product for customer {customer_id} is {next_product} on {next_date}")

The next product for customer 9887648135499821 is 739.0 on 2022-12-28 00:00:00


In [186]:
next_product = encoder.inverse_transform([739])[0]
next_product

'Side of Mild Sauce'

In [183]:
next_product = encoder.inverse_transform([887])[0]
next_product

'Yard Mountain Dew® Baja Blast? Twisted Freeze - Vodka'

In [184]:
next_product = encoder.inverse_transform([888])[0]
next_product

In [177]:
encoder.inverse_transform([int(next_product)])[0]

'Side of Mild Sauce'

In [179]:
# predict next purchase
customer_id = '9887648135499821'
customer_transactions = transactions[transactions['CustomerID'] == customer_id]

if len(customer_transactions) > 0:
    last_product = customer_transactions['Product'].iloc[-1]
    next_date = customer_transactions['Date'].iloc[-1] + timedelta(days=1)
    next_product = model.predict([[customer_id, last_product]])[0]
    next_product = encoder.inverse_transform([int(next_product)])[0]
    print(f"The next product for customer {customer_id} is {next_product} on {next_date}")
else:
    print(f"No transactions found for customer {customer_id}")

The next product for customer 9887648135499821 is Side of Mild Sauce on 2022-12-28 00:00:00
