In [1]:
%load_ext autoreload
%autoreload 2

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import pandas as pd
import numpy as np
import random
import torch
import transformers
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
# Load the complete dataset from disk and preview its first rows.
data_path = '../data.csv'
df = pd.read_csv(data_path)
df.head()

Unnamed: 0,item_name,l1,l2,photo,Business,category1_tag_id,category2_tag_id
0,Starbucks Flavored Liquid Coffee Pumpkin Spice...,Drinks,Coffee,http://cdn.doordash.com/media/photos/fb3fd14f-...,Circle K,,
1,Pure Leaf Liquid Herbal Tea Cherry Hibiscus No...,Drinks,Tea,http://cdn.doordash.com/media/photos/18896761-...,Circle K,,
2,Mtn Dew Code Red Soft Drink Citrus Cherry Bott...,Drinks,Soda,http://cdn.doordash.com/media/photos/f90aa531-...,Circle K,,
3,"Anderson Valley Brewing Compny Ale The Kimmie,...",Alcohol,Beer,http://cdn.doordash.com/media/photos/17fc39a2-...,Circle K,,
4,Mtn Dew Citrus Zero Calorie Soft Drink 12 oz C...,Drinks,Soda,http://cdn.doordash.com/media/photos/0b70a92e-...,Circle K,,


In [4]:
# Number of rows per top-level (l1) category.
df["l1"].value_counts()

Personal Care    19394
Pantry           13692
Snacks            8793
Drinks            8582
Household         8162
Alcohol           6163
Vitamins          5967
Frozen            5565
Dairy & Eggs      4311
Medicine          4141
Candy             3990
Bakery            2206
Meat & Fish       1710
Baby & Child      1690
Produce           1253
Pet Care          1214
Baby               921
Beauty             819
Fresh Food         614
Condiments         529
Name: l1, dtype: int64

In [5]:
# Number of rows contributed by each business.
df["Business"].value_counts()

Walgreens               26013
Smart & Final           19991
Fresh Thyme             17708
Cub Foods               10888
Meijer Grocery           9087
CVS                      5801
Circle K                 4119
Kroger                   2827
7-Eleven                 2779
Holiday Stationstore      503
Name: Business, dtype: int64

In [6]:
# l1 category distribution restricted to rows whose Business is CVS.
df.loc[df["Business"].eq("CVS"), "l1"].value_counts()

Personal Care    1430
Medicine          867
Household         821
Drinks            536
Candy             495
Snacks            438
Baby & Child      385
Pantry            317
Vitamins          251
Frozen            124
Dairy & Eggs       75
Pet Care           44
Bakery             15
Meat & Fish         3
Name: l1, dtype: int64

In [7]:
# l1 category distribution across the three candidate hold-out businesses.
df.loc[df["Business"].isin(["Kroger", "7-Eleven", "CVS"]), "l1"].value_counts()

Personal Care    1561
Drinks           1441
Snacks           1093
Household        1052
Medicine          975
Pantry            925
Candy             812
Frozen            713
Dairy & Eggs      595
Alcohol           589
Baby & Child      475
Vitamins          257
Bakery            234
Meat & Fish       231
Produce           209
Fresh Food        133
Pet Care          112
Name: l1, dtype: int64

In [8]:
# Total number of rows belonging to the three candidate hold-out businesses.
# Counting the True entries of the mask avoids materialising the filtered frame.
int(df["Business"].isin(["Kroger", "7-Eleven", "CVS"]).sum())

11407

In [9]:
# Business-level split: every row from these businesses goes to `test`,
# all remaining rows go to `train`.
held_out_businesses = ["Kroger", "7-Eleven", "CVS"]
mask = df["Business"].isin(held_out_businesses)
train, test = df[~mask], df[mask]

In [10]:
# Count rows per (l1, l2) pair on each side of the split; the resulting indexes
# are compared afterwards to find pairs that exist on only one side.
category_levels = ["l1", "l2"]
train_sizes = train.groupby(category_levels).size()
test_sizes = test.groupby(category_levels).size()

In [31]:
# Inspect the per-(l1, l2) row counts computed from the `train` frame.
train_sizes

l1        l2           
Alcohol   Alcohol-Free        2
          Beer             2086
          Liquor            779
          Mixers             46
          Seltzer           197
                           ... 
Vitamins  Liquor              1
          Minerals          299
          Multivitamins     734
          Supplements      2259
          Vitamins A-Z      924
Length: 230, dtype: int64

In [11]:
# Distinct (l1, l2) pairs observed on each side of the split.
train_categories = set(train_sizes.index)
test_categories = set(test_sizes.index)

# Pairs that appear in `test` but never in `train`.
test_categories.difference(train_categories)

{('Baby & Child', 'Toddler Snacks'),
 ('Dairy & Eggs', 'Dough & Crust'),
 ('Fresh Food', 'Fresh Pizza'),
 ('Fresh Food', 'Wings'),
 ('Household', 'Dog Treats & Toys'),
 ('Household', 'Floral'),
 ('Household', 'Seasonal'),
 ('Meat & Fish', 'Lamb')}

In [12]:
# Pairs that appear in `train` but never in `test`.
train_categories.difference(test_categories)

{('Alcohol', 'Alcohol-Free'),
 ('Alcohol', 'Mixers'),
 ('Alcohol', 'Seltzer'),
 ('Baby', 'Baby Food'),
 ('Baby', 'Bath & Skin'),
 ('Baby', 'Diapers'),
 ('Baby', 'Feeding'),
 ('Baby', 'Formula'),
 ('Baby', 'Toys'),
 ('Baby', 'Wipes'),
 ('Bakery', 'Break & Bake'),
 ('Beauty', 'Bath & Body'),
 ('Beauty', 'Facial Care'),
 ('Beauty', 'Makeup'),
 ('Beauty', 'Nails'),
 ('Beauty', 'Skin Care'),
 ('Beauty', 'Sun Care'),
 ('Condiments', 'Dressing'),
 ('Condiments', 'Oil & Vinegar'),
 ('Condiments', 'Sauces'),
 ('Condiments', 'Spices & Seasoning'),
 ('Condiments', 'Syrups'),
 ('Drinks', 'Ice'),
 ('Fresh Food', 'Health'),
 ('Frozen', 'Poultry'),
 ('Frozen', 'Seafood'),
 ('Frozen', 'Sides'),
 ('Household', 'Automotive'),
 ('Household', 'Hand Soap'),
 ('Household', 'Linens & Bedding'),
 ('Household', 'Liquor'),
 ('Personal Care', 'Sun care'),
 ('Pet Care', 'Cat Toys & Treats'),
 ('Produce', 'Canned Specialty'),
 ('Snacks', 'Wings'),
 ('Vitamins', 'Aromatherapy'),
 ('Vitamins', 'Liquor'),
 ('Vitamins

In [13]:
# (l1, l2) pairs present on exactly one side of the split, i.e. the symmetric
# difference of the two category sets.
trouble_categories = train_categories.symmetric_difference(test_categories)
trouble_categories

{('Alcohol', 'Alcohol-Free'),
 ('Alcohol', 'Mixers'),
 ('Alcohol', 'Seltzer'),
 ('Baby', 'Baby Food'),
 ('Baby', 'Bath & Skin'),
 ('Baby', 'Diapers'),
 ('Baby', 'Feeding'),
 ('Baby', 'Formula'),
 ('Baby', 'Toys'),
 ('Baby', 'Wipes'),
 ('Baby & Child', 'Toddler Snacks'),
 ('Bakery', 'Break & Bake'),
 ('Beauty', 'Bath & Body'),
 ('Beauty', 'Facial Care'),
 ('Beauty', 'Makeup'),
 ('Beauty', 'Nails'),
 ('Beauty', 'Skin Care'),
 ('Beauty', 'Sun Care'),
 ('Condiments', 'Dressing'),
 ('Condiments', 'Oil & Vinegar'),
 ('Condiments', 'Sauces'),
 ('Condiments', 'Spices & Seasoning'),
 ('Condiments', 'Syrups'),
 ('Dairy & Eggs', 'Dough & Crust'),
 ('Drinks', 'Ice'),
 ('Fresh Food', 'Fresh Pizza'),
 ('Fresh Food', 'Health'),
 ('Fresh Food', 'Wings'),
 ('Frozen', 'Poultry'),
 ('Frozen', 'Seafood'),
 ('Frozen', 'Sides'),
 ('Household', 'Automotive'),
 ('Household', 'Dog Treats & Toys'),
 ('Household', 'Floral'),
 ('Household', 'Hand Soap'),
 ('Household', 'Linens & Bedding'),
 ('Household', 'Liquor'

In [51]:
import math

def get_train_size(size, train_frac=0.9):
    """Return how many of ``size`` rows should be assigned to the training split.

    The count is ``ceil(train_frac * size)``, capped at ``size - 1`` so that any
    group with at least two rows leaves at least one row for the test split.
    A single-row group is assigned entirely to train, and an empty group
    yields 0.

    Parameters
    ----------
    size : int
        Number of rows in the group being split. Must be non-negative.
    train_frac : float, default 0.9
        Target fraction of rows to place in the training split.

    Returns
    -------
    int
        Number of rows to place in the training split.

    Raises
    ------
    ValueError
        If ``size`` is negative.
    """
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")

    if size <= 1:
        # 0 rows: nothing to assign (previously returned -1).
        # 1 row: it cannot be shared, so it goes to train.
        return size

    # Cap at size - 1 so the test split is never empty for groups of 2+ rows.
    return min(math.ceil(train_frac * size), size - 1)

In [52]:
# Add a boolean "Train" column set to False for every row.
# assign() builds a new frame, so re-running this cell resets every flag to False.
df = df.assign(Train=False)
# Display the frame to confirm the new column is present.
df

Unnamed: 0,item_name,l1,l2,photo,Business,category1_tag_id,category2_tag_id,Train
0,Starbucks Flavored Liquid Coffee Pumpkin Spice...,Drinks,Coffee,http://cdn.doordash.com/media/photos/fb3fd14f-...,Circle K,,,False
1,Pure Leaf Liquid Herbal Tea Cherry Hibiscus No...,Drinks,Tea,http://cdn.doordash.com/media/photos/18896761-...,Circle K,,,False
2,Mtn Dew Code Red Soft Drink Citrus Cherry Bott...,Drinks,Soda,http://cdn.doordash.com/media/photos/f90aa531-...,Circle K,,,False
3,"Anderson Valley Brewing Compny Ale The Kimmie,...",Alcohol,Beer,http://cdn.doordash.com/media/photos/17fc39a2-...,Circle K,,,False
4,Mtn Dew Citrus Zero Calorie Soft Drink 12 oz C...,Drinks,Soda,http://cdn.doordash.com/media/photos/0b70a92e-...,Circle K,,,False
...,...,...,...,...,...,...,...,...
99711,Chobani Pie Yogurt Key Lime Crumble (5.3 oz),Dairy & Eggs,Yogurt,http://cdn.doordash.com/media/photos/f942e709-...,Walgreens,960.0,1008.0,False
99712,Chobani Greek Yogurt Peach (6 oz),Dairy & Eggs,Yogurt,http://cdn.doordash.com/media/photos/246dcb0d-...,Walgreens,960.0,1008.0,False
99713,Yoplait Light Fat Free Yogurt Strawberries 'n ...,Dairy & Eggs,Yogurt,http://cdn.doordash.com/media/photos/5ccb97f2-...,Walgreens,960.0,1008.0,False
99714,Fage Total 2% Lowfat Greek Strained Yogurt Hon...,Dairy & Eggs,Yogurt,http://cdn.doordash.com/media/photos/c2fd4b5f-...,Walgreens,960.0,1008.0,False


In [53]:
# Re-split every "trouble" (l1, l2) pair at the row level and record the result
# in the "Train" column: rows chosen for training get True, the rest get False.
SPLIT_SEED = 42  # fixed seed so a fresh Restart & Run All reproduces the same assignment

for l1, l2 in trouble_categories:
    category_mask = (df["l1"] == l1) & (df["l2"] == l2)

    # Shuffle only this category's rows. Seeding each call with the same integer
    # makes the outcome independent of the arbitrary iteration order of the set.
    shuffled = df[category_mask].sample(frac=1, random_state=SPLIT_SEED)

    # Slice the shuffled index positionally instead of calling np.split on a
    # DataFrame, which goes through the deprecated DataFrame.swapaxes.
    n_train = get_train_size(len(shuffled))
    train_index = shuffled.index[:n_train]
    test_index = shuffled.index[n_train:]

    df.loc[train_index, "Train"] = True
    # Explicitly clear the remaining rows so re-running the cell stays consistent.
    df.loc[test_index, "Train"] = False
    

In [48]:
# Number of rows currently flagged for training (vectorised count of True values).
int(df["Train"].sum())

3875

1