<a href="https://www.kaggle.com/code/avilashahaldar/apriori-analysis-of-grocery-baskets?scriptVersionId=191206757" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

Hi! Today, we're going to be looking at apriori analysis of people's grocery baskets to find which items people are most likely to buy together. Information on the dataset can be found [here](https://www.kaggle.com/datasets/heeraldedhia/groceries-dataset/data). Let's get started!

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.preprocessing import LabelEncoder
from itertools import combinations, chain

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/apriori2-img/Apriori2.png
/kaggle/input/groceries-dataset/Groceries_dataset.csv


# Exploring and Processing the Data

In [2]:
groceries_data = pd.read_csv("/kaggle/input/groceries-dataset/Groceries_dataset.csv")
groceries_data.head()

Unnamed: 0,Member_number,Date,itemDescription
0,1808,21-07-2015,tropical fruit
1,2552,05-01-2015,whole milk
2,2300,19-09-2015,pip fruit
3,1187,12-12-2015,other vegetables
4,3037,01-02-2015,whole milk


Some of the checks I did before using this data were checking for NaNs (of which there were none), checking all the dates were valid (yes, they were), and checking the unique itemDescription values to ensure none of them were misspelled or duplicated with both American and British spellings. The date range covers 2 years, from Jan. 1, 2014, to Dec. 30, 2015. My guess is that Dec. 31, 2015 (New Year's Eve, a Thursday) is missing because of holiday closures.

We're going to convert all the item descriptions to integers so they take up less space in memory. Strings are expensive! And because there are 167 unique categories, we're going to use the int8 dtype instead of the default int64, since int64 is overkill. int8 only covers -128 to 127, so we shift the encoded labels (0 to 166) down by 100 to make them fit. This matters more and more as the dataset grows, since unnecessarily large dtypes can cause memory costs to spiral in the worst-case scenarios.
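To get a feel for the memory difference, here is a minimal sketch on made-up data. The toy column, the three item names, and the use of `pd.factorize` (instead of the `LabelEncoder` used in this notebook) are assumptions for illustration only.

```python
import numpy as np
import pandas as pd

# Hypothetical toy column: 10,000 rows drawn from three item names.
rng = np.random.default_rng(0)
items = pd.Series(rng.choice(["whole milk", "soda", "yogurt"], size=10_000))

# Encode each distinct string as a small integer code.
codes, uniques = pd.factorize(items)
as_int64 = pd.Series(codes.astype("int64"))
as_int8 = pd.Series(codes.astype("int8"))

# deep=True measures the string data itself, not just references to it.
string_bytes = items.memory_usage(index=False, deep=True)
int64_bytes = as_int64.memory_usage(index=False)
int8_bytes = as_int8.memory_usage(index=False)

print(string_bytes, int64_bytes, int8_bytes)
```

The int8 column needs one byte per row and the int64 column eight, while the string column is larger again; the exact string figure depends on the pandas version and string storage.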

At the same time, it's important to make sure our datatype won't break things if we have a pipeline where data gets added in the future. If we expected lots of categories to be added in the future, we might use int16 instead of int8. int8 only has 256 distinct values, and with the offset of 100 used below, the encoded labels would overflow once we passed 228 categories. However, since our data here is static, we can use int8.

We're also going to change the dtype of the Member Number. Currently, Member Number goes from 1000 to 5000 inclusive and is int64 dtype. We can set it to int16.
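The dtype reasoning above can be checked with `np.iinfo`. This is a small sketch; the variable names are mine, and the numbers (167 categories, an offset of 100, member numbers from 1000 to 5000) come from the text and the code in this notebook.

```python
import numpy as np

n_categories = 167   # unique itemDescription values, from the text
offset = 100         # shift applied to the encoded labels in this notebook

int8 = np.iinfo(np.int8)
int16 = np.iinfo(np.int16)

# Encoded labels run from 0 to n_categories - 1 before the shift.
lowest, highest = 0 - offset, (n_categories - 1) - offset
labels_fit = int8.min <= lowest and highest <= int8.max

# Largest category count that still fits in int8 with this particular offset.
max_categories = int8.max + offset + 1

# Member numbers run from 1000 to 5000, which is well inside int16.
members_fit = int16.min <= 1000 and 5000 <= int16.max

print(lowest, highest, labels_fit)
print(max_categories, members_fit)
```

Shifting by 128 instead of 100 would use the full int8 range and allow up to 256 categories.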

In [3]:
itemEncoder = LabelEncoder()
itemEncoder.fit(groceries_data["itemDescription"])

groceries_data["itemDescription"] = (itemEncoder.transform(groceries_data["itemDescription"]) - 100).astype('int8')
groceries_data["Date"] = pd.to_datetime(groceries_data["Date"], format="%d-%m-%Y")
groceries_data["Member_number"] = groceries_data["Member_number"].astype('int16')

# We'll need this later to convert back
itemEncoderDict = dict(zip(itemEncoder.transform(itemEncoder.classes_) - 100, itemEncoder.classes_))

Let's now take a look at the transformed data and also its shape. We can see that we have a decent number of rows for a proof of concept: a bit less than 40k rows of data. I doubt a megabusiness (e.g. Walmart) would base a business strategy on this amount of data, but it's definitely good for a smaller business, and definitely works for a proof of concept.

In [4]:
groceries_data.shape

(38765, 3)

In [5]:
groceries_data.head()

Unnamed: 0,Member_number,Date,itemDescription
0,1808,2015-07-21,56
1,2552,2015-01-05,64
2,2300,2015-09-19,9
3,1187,2015-12-12,2
4,3037,2015-02-01,64


Let's check if any people buy multiple things on the same day. For that, we can group by member_number and date and see if any of the counts are > 1.

In [6]:
groceries_data.groupby(["Member_number", "Date"]).count()

Unnamed: 0_level_0,Unnamed: 1_level_0,itemDescription
Member_number,Date,Unnamed: 2_level_1
1000,2014-06-24,3
1000,2015-03-15,4
1000,2015-05-27,2
1000,2015-07-24,2
1000,2015-11-25,2
...,...,...
4999,2015-05-16,2
4999,2015-12-26,2
5000,2014-03-09,2
5000,2014-11-16,2


Yeah, that's expected: lots of people buy more than 1 item on a given date. If that weren't the case, and everyone only bought 1 item a day, I'd think something was wrong with the data, e.g. that it only recorded the first item on each person's receipt. It's good to see that's not the case. There is some other data I wish were included, e.g. whether all of this data is from the same store, what country that store is in (this could allow us to look more into regional holidays like 4th of July or Thanksgiving), and what time of day the purchases were made. There could be some interesting insights to pull out from there, e.g. whether lots of people come in around lunchtime to grab a sandwich, whether we see more folks coming in the evening for a heftier grocery shop, or what kinds of items see a surge in buyers around certain holidays.
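Eyeballing the grouped counts works here, but the multi-item question can also be answered directly. Below is a minimal sketch on a made-up frame with the same three columns as the dataset; the rows themselves are invented for illustration.

```python
import pandas as pd

# Hypothetical toy purchases in the same three-column layout as the dataset.
toy = pd.DataFrame({
    "Member_number": [1000, 1000, 1000, 2000, 3000, 3000],
    "Date": pd.to_datetime(["2014-06-24", "2014-06-24", "2015-03-15",
                            "2014-01-02", "2015-05-16", "2015-05-16"]),
    "itemDescription": ["whole milk", "soda", "yogurt",
                        "pastry", "sausage", "whole milk"],
})

# One entry per (member, date) basket, holding the number of items in it.
basket_sizes = toy.groupby(["Member_number", "Date"]).size()

any_multi_item = bool((basket_sizes > 1).any())
share_multi_item = float((basket_sizes > 1).mean())

print(any_multi_item, share_multi_item)
```

On the real data, the same two lines applied to `groceries_data` would give both a yes/no answer and the share of baskets with more than one item.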

Anyway, I think that does it for data processing and EDA. Let's now talk more about the analysis we're going to do.

# Apriori Algorithm

The **apriori algorithm** is an **unsupervised machine learning algorithm**, meaning there are no right answers or labels. Regression is a **supervised** learning algorithm, meaning it is trained using a set of correct answers, while algorithms designed to play e.g. board or video games use **reinforcement learning**, which seeks to maximize a reward like points or wins. Unsupervised learning algorithms tend to be used for clustering and association tasks. K-means clustering is a common example.

Apriori is a type of **association rule learning**, meaning it identifies frequent patterns and associations. We'll be using apriori to find out which grocery items are often bought together. This knowledge can then be used for designing store layouts. If you want to make things easier for your consumers, you'll put frequently connected items in the same or adjacent aisles. If you want to force your buyer to walk more of the store to get what they need, you'll put these things in different aisles.

There are already libraries in Python that can do apriori for you, but to enhance our own understanding, we're going to compute it from scratch. To do that, let's format our data a bit. Since, at this point, we don't really care who the specific members are or the exact date (only which transactions were made by the same member on the same date), we're going to create a new index where each unique value represents a unique member_number and date combination, just so we can rid ourselves of the cumbersome MultiIndex. We'll also wrap each unique transaction value in a list so that our functions later on can work with any length of unique transaction, as long as the lengths are the same for all rows.

In [7]:
# We don't care if someone buys multiple of the same item on the same day, so we just remove duplicated rows.
groceries_apriori = groceries_data.copy().drop_duplicates()

# Member_number and Date will be our useful index columns, since we want to group transactions by person and date.
groceries_apriori = groceries_apriori.set_index(["Member_number", "Date"])

groceries_apriori.index = groceries_apriori.index.get_level_values("Member_number").astype(str) + "_" + groceries_apriori.index.get_level_values("Date").astype(str)
groceries_apriori.index.names = ["Index"]
groceries_apriori["itemDescription"] = groceries_apriori["itemDescription"].apply(lambda x: [x])
groceries_apriori = groceries_apriori.rename(columns={"itemDescription": "uniqueTransaction"})

# We have this new format where the items are in lists so we can reuse the same functions for different lengths.
# E.g. When we get to considering combinations of 2, each row will have a list of 2 items bought on a given date by a given member.
groceries_apriori

Unnamed: 0_level_0,uniqueTransaction
Index,Unnamed: 1_level_1
1808_2015-07-21,[56]
2552_2015-01-05,[64]
2300_2015-09-19,[9]
1187_2015-12-12,[2]
3037_2015-02-01,[64]
...,...
4471_2014-10-08,[35]
2022_2014-02-23,[-81]
1097_2014-04-16,[-83]
1510_2014-12-03,[-36]


To implement apriori, we start with **frequent itemset mining**, which is just a fancy way of saying we narrow our dataset down to items which are bought at or above a certain frequency threshold. Since we're looking for general trends, we don't want to be focusing on items which have maybe only been bought a few times in the entire two-year period. This is arbitrary, but I'm going to set our threshold at 48 (i.e. items bought at least twice a month on average over the 2-year period our data covers). Since we don't have so much data that our program would be very slow to run, I think this is fine. If we had millions of rows, I'd be a lot more aggressive with this pruning.
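As a rough illustration of this filtering step, here is a sketch on a handful of invented purchases. The item names and the toy threshold of 2 are assumptions; the idea is the same count-then-filter pattern applied with a threshold of 48 on the real data.

```python
import pandas as pd

# Hypothetical toy purchases: one row per item bought.
toy_items = pd.Series(["milk", "milk", "milk", "soda", "soda", "caviar"],
                      name="itemDescription")

# Tiny threshold for the toy data; the notebook uses 48 (2 per month x 24 months).
toy_threshold = 2

counts = toy_items.value_counts()
frequent_items = counts[counts >= toy_threshold].index

# Keep only the rows whose item clears the threshold.
kept = toy_items[toy_items.isin(frequent_items)]

print(sorted(frequent_items))
print(len(kept))
```

Here "caviar" is bought only once, so its single row is dropped and five rows remain.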

If we had data on revenues and costs by item category, it would be interesting to look at the product categories that got removed and see if they are even profitable to sell to begin with.

We are also going to see which combinations of Member_number and Date only yield 1 item, meaning the person only bought a single item on that date. These rows are not useful for determining item associations, so we'll be removing them.
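One way to drop single-item baskets is to broadcast each basket's size back onto its rows and filter on it. This is a sketch on invented data and is not the merge-based approach used in the function further down; the index values and item codes are made up.

```python
import pandas as pd

# Hypothetical toy data: one row per item, indexed by a member_date key.
toy = pd.DataFrame(
    {"item": [5, 28, 64, 8, 30, 64]},
    index=pd.Index(["1000_2014-06-24"] * 3 + ["1000_2015-05-27"]
                   + ["2000_2015-01-05"] * 2, name="Index"),
)

# Broadcast each basket's row count back onto its rows, then keep baskets of 2+.
basket_size = toy.groupby(level="Index")["item"].transform("count")
multi_item = toy[basket_size >= 2]

print(multi_item)
```

The basket `1000_2015-05-27` holds a single item, so it cannot contribute to any pair and is removed.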

From there, we get all combinations of 2 items sharing a member and date index and keep the combinations that occur at least common_combo_threshold times. We then move up to considering 3 items and keep all transactions where 2 of the 3 items form a common combination of 2. This is best explained with the below diagram:

<img src='https://storage.googleapis.com/kagglesdsdata/datasets/5495722/9105794/Apriori2.png?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=databundle-worker-v2%40kaggle-161607.iam.gserviceaccount.com%2F20240804%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240804T230423Z&X-Goog-Expires=345600&X-Goog-SignedHeaders=host&X-Goog-Signature=49232ba444d241a053430c2218018a6eea2b0ab944ce2b887a37b212e121bb0c3721ffde98fe30095d13e79d5c71260e4648900576820338066443007a6abce0f6b07c4a48d879797c99e0f5bf4579034a4b121e593e4a8b6398d5bd743de22bb1ebeb5a8793332c30b5fc678c6c81c681c25cbc9e877c48a7d86ce5eab1754ebccf7c1bb88035541af824e6e2ee5df33238ce7cf24623b46268ffd7f3c54be5356abf62069c59f4ce0e241619fecddcc1a1a8ce7e286e2370ab1184fae36bca75716213b25428d3421a3d2ed6960ac811ff3774e1fb3ce5bd37104ce82a5a644bc0ac95356934bb4c2976b5158eb26ba3462d3aac471698d3a645a24944b10e' width=600>

If an item or combination of items doesn't meet our minimum threshold, it and all its children are excluded.
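The level-by-level pruning idea can be sketched in plain Python on a few invented baskets. This is a simplified illustration, not the pandas implementation in this notebook: it prunes at the item level (only items that survived the previous level may appear in the next), whereas textbook apriori checks every subset of a candidate. The baskets, `min_count`, and the helper name `frequent_itemsets` are all assumptions.

```python
from collections import Counter
from itertools import combinations

# Hypothetical toy baskets; each set is one member/date transaction.
baskets = [
    {"milk", "rolls", "soda"},
    {"milk", "rolls"},
    {"milk", "soda"},
    {"milk", "rolls", "yogurt"},
    {"soda", "yogurt"},
]
min_count = 2  # toy threshold, far lower than the notebook's

def frequent_itemsets(baskets, size, allowed_items, min_count):
    """Count itemsets of a given size built only from allowed_items."""
    counts = Counter()
    for basket in baskets:
        candidates = sorted(basket & allowed_items)
        counts.update(combinations(candidates, size))
    return {combo: n for combo, n in counts.items() if n >= min_count}

all_items = set().union(*baskets)
singles = frequent_itemsets(baskets, 1, all_items, min_count)

# Pruning: only items that were frequent on their own may appear in pairs.
frequent_items = {item for (item,) in singles}
pairs = frequent_itemsets(baskets, 2, frequent_items, min_count)

# Only items that survive in some frequent pair may appear in triples.
pair_items = {item for combo in pairs for item in combo}
triples = frequent_itemsets(baskets, 3, pair_items, min_count)

print(pairs)
print(triples)
```

In this toy run, yogurt is frequent on its own but never lands in a frequent pair, so it is not considered for any triple.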

In [8]:
def select_items_bought_with_high_frequency(groceries_data: pd.DataFrame, value_col: str, item_frequency_threshold: int = 48) -> pd.DataFrame:
    """
    Frequent itemset mining: only keep rows corresponding to items that are bought at least item_frequency_threshold times.
    Each row in groceries_data corresponds to a unique transaction.
    We return the groceries_data dataframe with rows removed corresponding to infrequently bought items.
    """
    if groceries_data.empty:
        return groceries_data
    items_bought_frequently = groceries_data[value_col].value_counts() >= item_frequency_threshold
    groceries_data = groceries_data[groceries_data[value_col].isin(items_bought_frequently[items_bought_frequently].index)]
    return groceries_data

def get_rows_with_num_transactions(groceries_data: pd.DataFrame, value_col: str) -> pd.DataFrame:
    """
    Keep only the rows whose Index appears at least twice.
    Every uniqueTransaction in groceries_data has the same length, so an Index that appears only once
    cannot yield a larger combination and is dropped.
    Each row in groceries_data corresponds to a unique transaction.
    """
    if groceries_data.empty:
        return groceries_data
    grouped_groceries_data = (groceries_data.groupby("Index").count() >= 2).rename(columns={value_col: f"{value_col}_2"})
    groceries_data = grouped_groceries_data.merge(right=groceries_data, on="Index", how="right")
    groceries_data = groceries_data[groceries_data[f"{value_col}_2"] == True].drop(columns=f"{value_col}_2")
    return groceries_data

def get_most_common_combos(groceries_data: pd.DataFrame, min_item_frequency_threshold: int, common_combo_threshold: int, value_col: str = "uniqueTransaction", print_steps: bool = False) -> tuple[pd.DataFrame, pd.Series]:
    """
    We assume that each row of groceries_data is a unique transaction of a given item or set of items.
    Each uniqueTransaction is a list of the same length num_items_per_transaction. So num_items_per_transaction = 2 gives us the following for groceries_data:
    
    index  |  uniqueTransactions
    -----------------------------
    idx1   |  [56, 82]
    idx1   |  [56, 74]
    idx1   |  [74, 82]
    idx2   |  [-4, 34]
    
    We then find all combinations of size num_items_per_transaction+1, and select the most common ones
    (i.e. combos that appear at least common_combo_threshold times). So in the above example, the only combination would be [56, 74, 82].
    
    We also transform groceries_data to have uniqueTransactions with length num_items_per_transaction+1. This would then give us:
    
    index  |  uniqueTransactions
    -----------------------------
    idx1   |  [56, 74, 82]
    """

    # Here, we just select items or sets of items bought more than (or equal to) item_frequency_threshold times.
    groceries_data = select_items_bought_with_high_frequency(groceries_data, value_col, min_item_frequency_threshold)
    
    if groceries_data.empty:
        return groceries_data, pd.Series([], name="count", dtype='int64')
    
    # Transform groceries_data to have unique indices and a long list of all items with that index, rather than one row per unique transaction.
    num_items_per_transaction = len(groceries_data[value_col].iloc[0])
    groceries_data = get_rows_with_num_transactions(groceries_data, value_col)
    groceries_data = pd.DataFrame(groceries_data.groupby("Index")[value_col].apply(lambda x: sorted(list(set(chain(*x)))))).rename(columns={value_col: "allItems"})

    # Get all possible combinations for each index
    groceries_data["combos"] = groceries_data["allItems"].apply(lambda x: list(list(combo) for combo in (combinations(x, num_items_per_transaction+1))))
    if print_steps:
        print(f"\nGetting list of all items and combinations of length {num_items_per_transaction+1} for each unique index")
        print(groceries_data.head())
    
    # Getting the combinations of size num_items_per_transaction + 1 that occur more than or equal to common_combo_threshold times.
    grocery_combos = list(chain(*groceries_data["combos"].values.tolist()))
    combo_counts = pd.Series(grocery_combos).value_counts()
    combo_counts = combo_counts[combo_counts >= common_combo_threshold]
    if print_steps:
        print(f"\nGetting the common combinations of length {num_items_per_transaction+1}")
        print(combo_counts)
    
    # Transform groceries_data to again have uniqueTransactions, but with a higher num_items_per_transaction by 1.
    groceries_data = groceries_data[["combos"]].explode("combos").rename(columns={"combos": value_col})
    
    # Keep only the rows where the combination is one of the common ones.
    groceries_data = groceries_data[groceries_data[value_col].isin(combo_counts.index)]

    return groceries_data, combo_counts

# We get the most common combinations of lengths 2 to 4 inclusive and set the common_combo_threshold lower for larger combination lengths.
# This is arbitrary, but it feels right. I don't expect common combos of 4 to happen as often as common combos of 2.
combinations_dict = dict()
combo_size_and_threshold_dict = {2: 70, 3: 50, 4: 20}

for combination_size, common_combo_threshold in combo_size_and_threshold_dict.items():
    
    # To give an idea of what's going on under the hood, I'm printing out the steps for combination_size 2
    print_steps = False
    if combination_size == 2:
        print_steps = True
    groceries_apriori, combo_counts = get_most_common_combos(groceries_data = groceries_apriori,
                           min_item_frequency_threshold = 48,
                           common_combo_threshold = common_combo_threshold,
                           value_col = "uniqueTransaction",
                           print_steps = print_steps)
    combinations_dict[combination_size] = combo_counts


Getting list of all items and combinations of length 2 for each unique index
                         allItems  \
Index                               
1000_2014-06-24       [5, 28, 64]   
1000_2015-03-15  [30, 32, 64, 65]   
1000_2015-05-27           [8, 38]   
1000_2015-07-24         [-80, -8]   
1000_2015-11-25         [-27, 30]   

                                                            combos  
Index                                                               
1000_2014-06-24                       [[5, 28], [5, 64], [28, 64]]  
1000_2015-03-15  [[30, 32], [30, 64], [30, 65], [32, 64], [32, ...  
1000_2015-05-27                                          [[8, 38]]  
1000_2015-07-24                                        [[-80, -8]]  
1000_2015-11-25                                        [[-27, 30]]  

Getting the common combinations of length 2
[2, 64]      222
[22, 64]     209
[38, 64]     174
[64, 65]     167
[2, 22]      158
[2, 38]      145
[30, 64]     134
[56, 64]     12

Let's take a look at our common combinations. We turn the item numbers back into their original item descriptions, because looking at a bunch of numbers isn't terribly useful in this case.

In [9]:
for combination_size, combos in combinations_dict.items():
    combos_values = combos.values
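    # One column per item in the combination, so this also works for combo sizes above 2.
    item_cols = [f"Item{i + 1}" for i in range(combination_size)]
    combos = pd.DataFrame(combos.index.values.tolist(), columns=item_cols).replace(itemEncoderDict)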
    combos["Count"] = combos_values
    
    print(f"\nCombo size {combination_size}:")
    print(combos)
    
    combinations_dict[combination_size] = combos


Combo size 2:
               Item1             Item2  Count
0   other vegetables        whole milk    222
1         rolls/buns        whole milk    209
2               soda        whole milk    174
3         whole milk            yogurt    167
4   other vegetables        rolls/buns    158
5   other vegetables              soda    145
6            sausage        whole milk    134
7     tropical fruit        whole milk    123
8   other vegetables            yogurt    121
9         rolls/buns              soda    121
10        rolls/buns            yogurt    117
11   root vegetables        whole milk    113
12      bottled beer        whole milk    107
13     bottled water        whole milk    107
14      citrus fruit        whole milk    107
15         pip fruit        whole milk     99
16            pastry        whole milk     97
17     shopping bags        whole milk     95
18  other vegetables    tropical fruit     94
19        rolls/buns    tropical fruit     91
20  other vegetable

Oh, wow. There are no common transactions with more than 2 items. I had a look at the combo_counts at combination_size 3 before the filtering step, and the highest value_count is 18. 18 occurrences of the same combination in 2 years (mind that we started with 14963 unique combinations of date and member in the data) is really nothing.
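To put those counts in perspective, they can be expressed as a share of all baskets. The three numbers below are taken from the text and output above; treating all 14963 member/date combinations as the denominator is my assumption, since some of them are single-item baskets.

```python
n_baskets = 14963       # unique member/date combinations, from the text
top_triple_count = 18   # highest 3-item count before filtering, from the text
top_pair_count = 222    # other vegetables + whole milk, from the output above

# Share of baskets that contain the most common triple and the most common pair.
triple_share = top_triple_count / n_baskets
pair_share = top_pair_count / n_baskets

print(f"{triple_share:.2%}")
print(f"{pair_share:.2%}")
```

The most common triple shows up in roughly a tenth of a percent of baskets, against about one and a half percent for the most common pair.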

From the looks of it, whole milk is incredibly popular and tends to be bought with a bunch of other things, and vegetables or rolls/buns are also fairly common to buy with other things, although buying 3 such common items at once is obviously a rarity.