This notebook demonstrates model inference on a validation sample of 10 random queries.

The validation set has the following fields:
1. Product Name - the product name
2. clean_desc - the product description, at least 10 words long
3. search_text - a plain search query such as "soda"; generated for all ~4k rows
4. abstract_search_text - an abstract query such as "Rich baking ingredients for traditional recipes"; generated for 20% of the dataframe, NaN for the rest

Each product has 1 or 2 queries. For validation, the model retrieves the 20 most relevant products for each query from the test set.
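The validation idea above (retrieve the top 20 products per query and check whether the product the query was written for is among them) can be summarized as a hit rate. The sketch below is only an illustration: the function name `hit_rate_at_k` and the toy IDs are assumptions, not part of the original notebook.

```python
from typing import Dict, Sequence


def hit_rate_at_k(predictions: Dict[str, Sequence[str]], k: int = 20) -> float:
    """Fraction of queries whose source product ID appears in the top-k list.

    `predictions` maps the ID of the product a query was written for
    to the ranked list of product IDs returned for that query.
    """
    if not predictions:
        return 0.0
    hits = sum(
        1
        for target_id, ranked_ids in predictions.items()
        # list(...) iterates over values, so this also works for a pandas Series
        if target_id in list(ranked_ids)[:k]
    )
    return hits / len(predictions)


# Toy IDs, purely illustrative (not taken from the dataset)
toy_predictions = {
    "a": ["x", "a", "y"],  # hit at rank 2
    "b": ["x", "y", "z"],  # miss
}
print(hit_rate_at_k(toy_predictions, k=2))  # 0.5
```

A dictionary shaped like the `model_output` built later in this notebook (target `Uniq Id` mapped to the returned IDs) could be passed to such a function directly.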

#### **imports**

In [6]:
import torch
import transformers
import pandas as pd
import numpy as np
from tqdm import tqdm
from scipy.spatial.distance import cdist
from transformers import AutoTokenizer, DistilBertModel


In [20]:
import random

def seed_everything(seed = 42, torch_deterministic = True):
  # seed every random number generator used in the notebook for reproducibility
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.backends.cudnn.deterministic = torch_deterministic

seed_everything(42)

#### **loading and preparing the data**

In [11]:
df_val = pd.read_csv("/content/items_with_search_text.csv")
df_val = df_val[["Uniq Id", "Product Name", "clean_desc", "search_text", "abstract_search_text"]]
df_val.head()

Unnamed: 0,Uniq Id,Product Name,clean_desc,search_text,abstract_search_text
0,019b67ef7f01103d8fb0a53e4c36daa7,"La Costena Chipotle Peppers, 7 OZ (Pack of 12)","la costena chipotle peppers, 7 oz (pack of 12)...",chipotle peppers,Mexican spice essentials
1,e4fab4b6f41eac02d22b421818c8f080,(6 Boxes) Twinings of London Nightly Calm Gree...,enojy one of your favorite tea flavors with th...,green tea,Soothing bedtime teas
2,992c11a1b238eae6a2fabf22c74de5c2,"Tiny White Mighty Mints (16 oz, ZIN: 525424) -...","tiny white mighty mints (16 oz, zin: 525424): ...",mints,Breath freshening essentials
3,290970606fd82ac39534aa32b8b3b149,SweetGourmet Narrow Sesame Sticks | Lightly Sa...,"narrow sesame sticks are fresh, crunchy and li...",sesame sticks,Savory snack ideas
4,21fee4394b9cc53bb6ddbe4235506a5c,"Health Warrior Chia Bar, Acai Berry, 25 G, Pac...","health warrior chia bar, acai berry is packed ...",chia bar,Healthy snack alternatives


#### **model setup**

In [12]:
# the model under evaluation
class Model:
    def __init__(self, product_DB: pd.DataFrame, to_set_up: bool = False, path_to_vector_BD: str = None):
        self.model = None
        self.tokenizer = None

        self.vector_DB = None
        if path_to_vector_BD:
            print("Downloading Vector BD")
            self.vector_DB = np.load(path_to_vector_BD)

        self.products_DB = product_DB.sort_values(by = "Uniq Id")


        if to_set_up:
            self.set_up()

    def set_up(self, max_position_embeddings = 512):
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.model = DistilBertModel.from_pretrained("distilbert-base-uncased", max_position_embeddings = max_position_embeddings)

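    def calc_embeding(self, X: str):
        if self.model is None:
            print("You should call set_up() on the model first!")
            return
        # X[:512] truncates characters, not tokens; truncation=True also caps the token count
        inputs = self.tokenizer(X[:512], return_tensors="pt", truncation=True, max_length=512)
        # use the hidden state of the last token as the embedding; no_grad skips autograd bookkeeping
        with torch.no_grad():
            return self.model(**inputs).last_hidden_state[:,-1,:].numpy()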


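    def form_vector_DB(self):
        if self.vector_DB is not None:
            print("Vector DB already exists!")
            return

        # NOTE: reads the "Description" column, which df_val as loaded in this notebook does not contain
        # np.vstack gives a 2-D (n_products, hidden_size) array, as cdist requires;
        # np.stack on (1, hidden_size) embeddings would produce a 3-D array instead
        # products_DB is already sorted by "Uniq Id" in __init__
        self.vector_DB = np.vstack(
            [self.calc_embeding(row["Description"]) for _, row in tqdm(self.products_DB.iterrows(), total = len(self.products_DB))]
        )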

    def get_n_rank(self, query, n = 10):
        if self.vector_DB is None:
            print("No Vector DB exists!")
            return

        if self.model is None:
            print("You should call set_up() on the model first!")
            return

        query_v = self.calc_embeding(query)
        # indices of the n most similar products (smallest cosine distance first)
        top_indx = cdist(query_v, self.vector_DB, 'cosine')[0].argsort()[:n]
        # products_DB is already sorted by "Uniq Id" in __init__, matching the row order of vector_DB
        out_uniq_id = self.products_DB.iloc[top_indx]["Uniq Id"]
        return out_uniq_id
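To make the ranking step in `get_n_rank` easier to follow, here is a minimal sketch on toy 2-D vectors. It computes the cosine distance by hand with NumPy, which is the same quantity `cdist(..., 'cosine')` returns; the function name `top_n_by_cosine` and the toy vectors are illustrative assumptions, not part of the notebook.

```python
import numpy as np


def top_n_by_cosine(query_v: np.ndarray, vector_db: np.ndarray, n: int = 2) -> np.ndarray:
    """Return row indices of the n vectors closest to the query in cosine distance."""
    # Normalize rows so that the dot product equals the cosine similarity
    q = query_v / np.linalg.norm(query_v, axis=1, keepdims=True)
    db = vector_db / np.linalg.norm(vector_db, axis=1, keepdims=True)
    cosine_distance = 1.0 - (q @ db.T)[0]  # shape: (n_products,)
    # argsort is ascending, so the smallest distances (most similar) come first
    return cosine_distance.argsort()[:n]


# Toy "vector DB" with three products and one query of shape (1, dim)
toy_db = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
toy_query = np.array([[2.0, 0.1]])
print(top_n_by_cosine(toy_query, toy_db, n=2))  # [0 2]
```

The returned positions are only meaningful if the rows of the vector DB are in the same order as the rows of the product table, which is why the class sorts by `Uniq Id` before indexing.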

In [13]:
# initialize the model for inference
inference_model = Model(df_val, path_to_vector_BD = "vector_BD.npy", to_set_up = True)

Downloading Vector BD



#### **inference on test queries**

In [25]:
# randomly pick k standard (non-abstract) test queries from the dataset
seed_everything(42)
k = 10
random_indices = np.random.randint(0, len(df_val), size=k)

test_queries = df_val.loc[random_indices]
#test_products = df_val.loc[random_indices][["Uniq Id", "Product Name", "clean_desc"]]
test_queries

Unnamed: 0,Uniq Id,Product Name,clean_desc,search_text,abstract_search_text
3174,49211ae8276d458310dd06f536653fb8,Yumbutter Superfood Almond Butter - Case Of 6 ...,yumbutter superfood almond butter - case of 6 ...,Almond butter,
3507,48f177144587bb50f2298a14fdc02b86,"Setton Farms Omega-3 Mix, 18 Oz",setton farms mixes are created with unique ble...,omega-3 mix,
860,e84cb9e22be2bdf90e1e01b0a273c263,"Sanpellegrino Orange Sparkling Fruit Beverage,...","inspired by homemade italian aranciata, sanpel...",Sparkling beverage,
1294,325db3ffed5c3f15b53530920433bfa0,Lily Of The Desert Organic Aloe Vera Juice Who...,lily of the desert organic aloe vera juice who...,Aloe vera juice,
1130,cf37eea3b3566ea9e32b3f305875cbb3,Raymundos No Added Sugar Strawberry Fruit & Ge...,fruit & gel 7.5oz raymundos straw frt & gel 7.5oz,Strawberry snack,
1095,5a76983a268315991ec00c07c8033d85,"Bulletproof Glutathione Force, Master Antioxid...",strengthen your body against the toxins and st...,antioxidant capsules,
3772,54386abe067dc221416bfb2e99b791b9,"LaCroix Sparkling Water - PinaFraise ""Strawber...",lacroix sparkling water is the #1 sparkling wa...,Sparkling water,
3092,9a359f5429f285a357ce897f04581950,"Cafe Bustelo Espresso Ground Coffee, 10 Ounce","cafe bustelo for flavor, aroma and quality is ...",Espresso coffee,
1638,236841b9b5a1638b3efdec7773142252,"Peytons Thick Sliced Double Smoked Bacon, 12 Oz.",u.s. inspected and passed by department of agr...,Smoked bacon,
2169,094bb5316592e045a38c0bb51addc056,Vibrant health green vibrance superfood powder...,our vitality is the sum total of the health of...,superfood powder,
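One caveat about the sampling cell: `np.random.randint` draws with replacement, so the same row could be picked twice (the sample shown here happens to contain no duplicates). A possible alternative, sketched under the assumption that unique rows are wanted, is `Generator.choice` with `replace=False`. The helper name `sample_unique_indices` and the row count of 4000 are illustrative, and this would select different rows than the ones shown above.

```python
import numpy as np


def sample_unique_indices(n_rows: int, k: int, seed: int = 42) -> np.ndarray:
    """Draw k distinct row positions from range(n_rows), reproducibly."""
    rng = np.random.default_rng(seed)
    return rng.choice(n_rows, size=k, replace=False)


# 4000 stands in for len(df_val); the real value comes from the dataframe
idx = sample_unique_indices(n_rows=4000, k=10)
print(len(set(idx.tolist())))  # 10
```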


In [89]:
# run the model on the queries
top_n = 20
model_output = dict()
for _, row in test_queries.iterrows():
  product_ids = inference_model.get_n_rank(row["search_text"], top_n)
  model_output[row["Uniq Id"]] = product_ids

In [90]:
# iterate over the model outputs and look at each query and what the model returned for it
for uniq_id, predicted_product_ids in model_output.items():
  row = test_queries[test_queries["Uniq Id"] == uniq_id]

  print(f"query was: {row['search_text'].item()} \ncorrect product is: {row['Product Name'].item()}")
  # note: isin() returns the matches in df_val row order, not in the model's ranking order
  print(f"predicted products are:\n {df_val[df_val['Uniq Id'].isin(predicted_product_ids)]['Product Name'].values}\n\n")

query was: Almond butter 
correct product is: Yumbutter Superfood Almond Butter - Case Of 6 - 6.2 Oz
predicted products are:
 ['Corn Starch (8 oz, ZIN: 526021) - 3-Pack'
 '4 Pack - Humco Cola Syrup 4 oz Each'
 'Two Leaves and a Bud, Inc., Organic Earl Grey Black Tea, 15 Count'
 'Vienna Beef - Smoked Beef Stix (2.5Lbs)'
 '(2 Pack) Great Value 4 in 1 Decorating Sprinkles, All Occasion, 5.2 oz'
 'Folgers Brazilian Blend Ground Coffee, Medium Roast, 24.2 Ounce'
 'Judees Gluten Free Guar Gum Powder, 10 oz'
 'Kind Plus Bars, 1.4 oz bars, Cranberry Almond + Antioxidants 12 bars'
 '(3 Boxes) Bigelow, Red Raspberry, Tea Bags, 20 Ct'
 'Pacific Natural Foods Coconut Original - Non Dairy - Case of 12 - 32 Fl oz.'
 'Boost Kid Oral Supplement Esentials Vanilla 8.25 oz. Ready to Use CS|16 PK|2'
 'Ganoderma Coffe Reyshen Red 200gms. Cafe Organico con Hongo Ganoderma Rojo 200grms'
 '(2 Pack) Smuckers Natural Concord Grape Fruit Spread, 12.75 oz'
 'Adolphs Meat Loaf Mix, 2.11 oz (Pack of 6)'
 'Cafe Rey 

#### **inference conclusions**

On 10 test queries with the top 20 products retrieved for each, we see that BERT without fine-tuning cannot yet retrieve the product the query was created from. The results do contain relevant products, however: for example, the query Espresso coffee returns other kinds of coffee, and the query Strawberry snack returns desserts/tea.

Options for further improving model quality:

1. take a larger model, possibly one pretrained on retail/e-commerce domains
2. try a different dataset
3. add preprocessing of product descriptions/queries
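As one possible reading of option 3, the sketch below normalizes product text before embedding: lowercasing, dropping pack-size boilerplate, removing most punctuation, and collapsing whitespace. The specific rules and the name `preprocess_text` are assumptions made for illustration; which rules actually help would need to be checked on the validation queries.

```python
import re

# Pack-size boilerplate such as "(Pack of 12)"; a hypothetical rule for this dataset
_PACK_RE = re.compile(r"\(\s*pack of \d+\s*\)|\bpack of \d+\b", flags=re.IGNORECASE)


def preprocess_text(text: str) -> str:
    """Lowercase, drop pack-size phrases, strip most punctuation, collapse spaces."""
    text = text.lower()
    text = _PACK_RE.sub(" ", text)
    # Keep letters, digits, whitespace and a few characters common in product names
    text = re.sub(r"[^a-z0-9\s.&%-]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


print(preprocess_text("La Costena Chipotle Peppers, 7 OZ (Pack of 12)"))
# la costena chipotle peppers 7 oz
```

The same function would have to be applied to both the product descriptions and the queries, and the vector DB rebuilt afterwards, so that both sides are embedded from text in the same form.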