In [1]:
import string
from dataclasses import dataclass
from tqdm import tqdm

import numpy as np
import pandas as pd

import hnswlib
import faiss

import torch
from torch import nn

from torch.utils.data import Dataset, DataLoader

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from transformers import AutoTokenizer, AutoModel


In [2]:
dataset = pd.read_parquet("products_with_names.parquet")

In [3]:
@dataclass
class Document:
    document_id: int
    document_name: str

documents = [Document(document_id=doc[1]["product_id"], document_name=doc[1]["name"]) for doc in dataset.iterrows() if doc[1]["name"] != ""]


In [4]:
documents[:5]

[Document(document_id=4036767, document_name='Модуль сменный фильтрующий Аквафор КН, 208731'),
 Document(document_id=4050873, document_name='Водоочиститель Аквафор модель Кристалл Н, 205963 //с краном'),
 Document(document_id=4226160, document_name='Развиваем мышление (2-3 года) | Земцова Ольга'),
 Document(document_id=4644911, document_name='Lacoste Вода парфюмерная Pour Femme 50 мл'),
 Document(document_id=4788809, document_name='Сменные Кассеты Для Мужской Бритвы Gillette Mach3, с 3 лезвиями, прочнее, чем сталь, для точного бритья, 2 шт')]

In [5]:
class SimpleTextProcessor:
    def __init__(self):
        self.symbols_to_replace = {"ё": "е"}

    def lowercase_text(self, text: str) -> str:
        return text.lower()

    def replace_symbols(self, text: str) -> str:
        for old, new in self.symbols_to_replace.items():
            text = text.replace(old, new)
        return text

    def process_punctuation_simple(self, text: str) -> str:
        translation_table = str.maketrans(string.punctuation, ' ' * len(string.punctuation))
        text_without_punc = text.translate(translation_table)
        text_without_double_spaces = ' '.join(text_without_punc.split())
        return text_without_double_spaces

    def process_text(self, text: str) -> str:
        text = self.lowercase_text(text)
        text = self.replace_symbols(text)
        text = self.process_punctuation_simple(text)
        return text
        

In [6]:
text_processor = SimpleTextProcessor()

In [7]:
documents_processed = [
    Document(document.document_id, text_processor.process_text(document.document_name)) 
    for document in tqdm(documents)
]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 238428/238428 [00:03<00:00, 72133.85it/s]


In [8]:
documents_processed[:5]

[Document(document_id=4036767, document_name='модуль сменный фильтрующий аквафор кн 208731'),
 Document(document_id=4050873, document_name='водоочиститель аквафор модель кристалл н 205963 с краном'),
 Document(document_id=4226160, document_name='развиваем мышление 2 3 года земцова ольга'),
 Document(document_id=4644911, document_name='lacoste вода парфюмерная pour femme 50 мл'),
 Document(document_id=4788809, document_name='сменные кассеты для мужской бритвы gillette mach3 с 3 лезвиями прочнее чем сталь для точного бритья 2 шт')]

## Tokenization

In [9]:
corpus = [doc.document_name for doc in documents_processed]

In [10]:
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))

tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    vocab_size=10_000,
    min_frequency=2,
    show_progress=True,
    continuing_subword_prefix="##",
)

tokenizer.train_from_iterator(corpus, trainer)

tokenizer.save("bpe_tokenizer.json")






In [11]:
tokenizer = Tokenizer.from_file("bpe_tokenizer.json")

In [12]:
encoding = tokenizer.encode("мороженое для собак")
print(f"Tokens: {encoding.tokens}")
print(f"Token ids: {encoding.ids}")

Tokens: ['мороженое', 'для', 'собак']
Token ids: [2326, 195, 566]


In [13]:
encoding = tokenizer.encode("бритва gillette")
print(f"Tokens: {encoding.tokens}")
print(f"Token ids: {encoding.ids}")

Tokens: ['бритва', 'gillette']
Token ids: [3905, 4069]


In [14]:
encoding = tokenizer.encode("шампунь мужской nivea men")
print(f"Tokens: {encoding.tokens}")
print(f"Token ids: {encoding.ids}")

Tokens: ['шампунь', 'мужской', 'nivea', 'men']
Token ids: [815, 1278, 3090, 2877]


In [15]:
encoding = tokenizer.encode("шампунь мужской nivea men охлаждающий")
print(f"Tokens: {encoding.tokens}")
print(f"Token ids: {encoding.ids}")

Tokens: ['шампунь', 'мужской', 'nivea', 'men', 'охлажда', '##ющий']
Token ids: [815, 1278, 3090, 2877, 7492, 300]


In [16]:
tokenizer.decode(encoding.ids)

'шампунь мужской nivea men охлажда ##ющий'

In [17]:
VOCAB_SIZE = tokenizer.get_vocab_size()
print(f"Tokenizer vocab size: {VOCAB_SIZE}")

Tokenizer vocab size: 10000


## Building DSSM

![title](dssm.png)

In [28]:
class DSSM(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int = 256, hidden_dims: list[int] = [512, 256, 128], padding_idx: int = 3):
        super().__init__()

        self.padding_idx = padding_idx

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        layers = []
        input_dim = embedding_dim
        for dim in hidden_dims:
            layers.append(nn.LayerNorm(input_dim))
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.LeakyReLU())
            input_dim = dim

        self.mlp = nn.Sequential(*layers)

        self.ln = nn.LayerNorm(input_dim)
        
    def forward(self, input_ids: torch.LongTensor):
        input_embeddings = self.embedding(input_ids)

        input_embeddings_pooled = torch.sum(input_embeddings, dim=1) / torch.sum(input_ids!=self.padding_idx, dim=1, keepdim=True)

        mlp_embeddings = self.mlp(input_embeddings_pooled)

        return self.ln(mlp_embeddings)


## Prepare data for training

In [19]:
query_product_positive_interactions = pd.read_parquet("query_product_positive_interactions.parquet")

In [20]:
query_product_positive_interactions.shape

(18121972, 4)

In [21]:
query_product_positive_interactions.head()

Unnamed: 0,user_id,timestamp,search_query,product_name
0,9897711,2024-03-16 11:28:10,линзы,"ACUVUE Контактные линзы, -3.75, 8.4, 2 недели"
1,3666669,2024-04-15 13:40:39,линзы -4,"ACUVUE Контактные линзы, -3.75, 8.4, 2 недели"
2,4951147,2024-04-28 06:23:20,линзы acuvue,"ACUVUE Контактные линзы, -3.75, 8.4, 2 недели"
3,972605,2024-04-25 13:00:18,линзы acuvue,"ACUVUE Контактные линзы, -3.75, 8.4, 2 недели"
4,972605,2024-03-10 07:08:38,линзы acuvue,"ACUVUE Контактные линзы, -3.75, 8.4, 2 недели"


In [22]:
data = []

In [23]:
for row in tqdm(query_product_positive_interactions.head(1_000_000).iterrows()):
    data.append(
        (
            text_processor.process_text(row[1].search_query), 
            text_processor.process_text(row[1].product_name),
        )
    )

1000000it [01:04, 15442.35it/s]


In [24]:
data[-5:]

[('зелень свежая', 'укроп 75 г'),
 ('укроп', 'укроп 75 г'),
 ('зелень свежая', 'укроп 75 г'),
 ('зелень свежая', 'укроп 75 г'),
 ('укроп свежий', 'укроп 75 г')]

In [25]:
class QueryDocDataset(Dataset):
    def __init__(self, data: list[tuple[str, str]], tokenizer: Tokenizer, max_len: int = 32):
        self.queries = [row[0] for row in data]
        self.docs = [row[1] for row in data]
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, idx: int):
        query = self.queries[idx]
        doc = self.docs[idx]

        query_ids = self.tokenizer.encode(query).ids[:self.max_len]
        doc_ids = self.tokenizer.encode(doc).ids[:self.max_len]

        return {
            "query_ids": torch.LongTensor(query_ids),
            "doc_ids": torch.LongTensor(doc_ids),
        }

    def __len__(self):
        return len(self.queries)


def collate_fn(batch):
    query_ids = torch.nn.utils.rnn.pad_sequence(
        [x["query_ids"] for x in batch], batch_first=True, padding_value=3
    )
    doc_ids = torch.nn.utils.rnn.pad_sequence(
        [x["doc_ids"] for x in batch], batch_first=True, padding_value=3
    )
    return {"query_ids": query_ids, "doc_ids": doc_ids}


In [26]:
DEVICE = "cuda:1"

In [29]:
model = DSSM(vocab_size=tokenizer.get_vocab_size())
model.to(DEVICE)

DSSM(
  (embedding): Embedding(10000, 256, padding_idx=3)
  (mlp): Sequential(
    (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=256, out_features=512, bias=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (7): Linear(in_features=256, out_features=128, bias=True)
    (8): LeakyReLU(negative_slope=0.01)
  )
  (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)

In [30]:
dataset = QueryDocDataset(data, tokenizer)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)

In [31]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [32]:
loss_fn = nn.CrossEntropyLoss()

In [33]:
def train_step(batch):
    optimizer.zero_grad()
    
    query_ids, doc_ids = batch["query_ids"], batch["doc_ids"]

    query_vectors = model(query_ids.to(DEVICE))
    doc_vectors = model(doc_ids.to(DEVICE))

    query_vectors = torch.nn.functional.normalize(query_vectors, p=2, dim=1)
    doc_vectors = torch.nn.functional.normalize(doc_vectors, p=2, dim=1)

    scores = torch.matmul(query_vectors, doc_vectors.T)  # [batch_size, batch_size]

    labels = torch.arange(len(scores), device=scores.device)

    loss = loss_fn(scores, labels)

    loss.backward()
    optimizer.step()

    return loss.item()


for epoch in range(10):
    train_losses = []
    for i, batch in enumerate(dataloader):
        loss = train_step(batch)
        train_losses.append(loss)
        if i % 1000 == 0:
            print(f"Epoch: {epoch}, Iteration: {i}, Loss: {sum(train_losses) / len(train_losses):.4f}")
            train_losses = []

Epoch: 0, Iteration: 0, Loss: 4.6828
Epoch: 0, Iteration: 1000, Loss: 3.9814
Epoch: 0, Iteration: 2000, Loss: 3.9370
Epoch: 0, Iteration: 3000, Loss: 3.9312
Epoch: 0, Iteration: 4000, Loss: 3.9281
Epoch: 0, Iteration: 5000, Loss: 3.9263
Epoch: 0, Iteration: 6000, Loss: 3.9248
Epoch: 0, Iteration: 7000, Loss: 3.9233
Epoch: 1, Iteration: 0, Loss: 3.9764
Epoch: 1, Iteration: 1000, Loss: 3.9199
Epoch: 1, Iteration: 2000, Loss: 3.9189
Epoch: 1, Iteration: 3000, Loss: 3.9188
Epoch: 1, Iteration: 4000, Loss: 3.9181
Epoch: 1, Iteration: 5000, Loss: 3.9185
Epoch: 1, Iteration: 6000, Loss: 3.9192
Epoch: 1, Iteration: 7000, Loss: 3.9176
Epoch: 2, Iteration: 0, Loss: 3.9177
Epoch: 2, Iteration: 1000, Loss: 3.9154
Epoch: 2, Iteration: 2000, Loss: 3.9162
Epoch: 2, Iteration: 3000, Loss: 3.9164
Epoch: 2, Iteration: 4000, Loss: 3.9155
Epoch: 2, Iteration: 5000, Loss: 3.9156
Epoch: 2, Iteration: 6000, Loss: 3.9144
Epoch: 2, Iteration: 7000, Loss: 3.9152
Epoch: 3, Iteration: 0, Loss: 3.9257


KeyboardInterrupt: 

In [34]:
def cosine_similarity(vec_1, vec_2):
    vec_1_normalized = torch.nn.functional.normalize(vec_1, p=2, dim=1)
    vec_2_normalized = torch.nn.functional.normalize(vec_2, p=2, dim=1)
    scores = torch.sum(vec_1_normalized * vec_2_normalized, dim=1)

    return scores.item()
    

In [39]:
query = "творожный сыр"
doc = "молоко"

In [40]:
query_embedding = model(torch.LongTensor([tokenizer.encode(query).ids]).to(DEVICE))

In [41]:
document_embedding = model(torch.LongTensor([tokenizer.encode(doc).ids]).to(DEVICE))

In [42]:
cosine_similarity(query_embedding, document_embedding)

-0.06171335279941559

## ANN index

In [43]:
document_names = [x.document_name for x in documents_processed]

In [44]:
document_names[:3]

['модуль сменный фильтрующий аквафор кн 208731',
 'водоочиститель аквафор модель кристалл н 205963 с краном',
 'развиваем мышление 2 3 года земцова ольга']

In [46]:
def embed_texts(texts, model, tokenizer, max_len=32):
    model.eval()
    embeddings = []
    with torch.no_grad():
        for text in tqdm(texts):
            text_ids = tokenizer.encode(text).ids[:max_len]
            text_ids = torch.tensor(text_ids, dtype=torch.long).unsqueeze(0).to(DEVICE)
            text_embedding = model(text_ids)
            embeddings.append(text_embedding.squeeze(0).cpu().numpy())
    return np.stack(embeddings)


doc_embeddings = embed_texts(document_names, model, tokenizer)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 238428/238428 [01:32<00:00, 2566.70it/s]


In [47]:
dim = doc_embeddings.shape[1]
index = hnswlib.Index(space="cosine", dim=dim)

index.init_index(max_elements=len(documents), ef_construction=200, M=16)
index.add_items(doc_embeddings)

In [48]:
query = "сыр сливочный"
query_embedding = embed_texts([query], model, tokenizer)[0]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 719.06it/s]


In [49]:
k = 5
labels, distances = index.knn_query(query_embedding, k=k)

for i, (label, dist) in enumerate(zip(labels[0], distances[0])):
    print(f"{documents[label]} (Score: {1 - dist})")

Document(document_id=149171036, document_name='Сыр творожный сливочный 60% 150 г, Almette') (Score: 0.9999899864196777)
Document(document_id=291544604, document_name='Сыр творожный сливочный 150 г, Almette') (Score: 0.9999895095825195)
Document(document_id=170339265, document_name='Сыр творожный Violette, сливочный, 70 %, 140 г') (Score: 0.9999889135360718)
Document(document_id=194534557, document_name='Сыр Ламбер "Гауда", 45%, кусок, 180 г') (Score: 0.9999889135360718)
Document(document_id=142583602, document_name='Сыр полутвердый Сливочный 250 г, Laplandia, нарезка') (Score: 0.9999886751174927)


In [None]:
class DSSMRetriever:
    def __init__(self, model, tokenizer, documents, document_names):
        self.model = model
        self.tokenizer = tokenizer
        self.documents = documents
        self.document_names = document_names
        
        self.dim = model.output_dim
        self.index = hnswlib.Index(space="cosine", dim=self.dim)
        self.index.init_index(max_elements=len(self.document_names), ef_construction=200, M=16)
        
        doc_embeddings = embed_texts(self.document_names, model, tokenizer)
        self.index.add_items(doc_embeddings)
    
    def search(self, query, k=5):
        query_embedding = embed_texts([query], self.model, self.tokenizer)[0]
        labels, _ = self.index.knn_query(query_embedding, k=k)
        return [self.documents[label] for label in labels[0]]


retriever = DSSMRetriever(model, tokenizer, documents, document_names)


In [None]:
results = retriever.search("корм для собак", k=10)
print(results)

## ColBERT

![title](colbert.png)

In [50]:
class ColBERT:
    def __init__(self, model_name: str = "cointegrated/rubert-tiny"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(DEVICE)
        self.dim = self.model.config.hidden_size
        self.model.eval()
        
    def encode(self, texts):
        inputs = self.tokenizer(
            texts, 
            padding=True, 
            truncation=True, 
            return_tensors="pt",
            max_length=64
        ).to(DEVICE)
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state.cpu()
    
    def compute_scores(self, query_emb, doc_emb):
        query_emb = torch.nn.functional.normalize(query_emb, p=2, dim=-1)
        doc_emb = torch.nn.functional.normalize(doc_emb, p=2, dim=-1)
        
        sim = torch.bmm(query_emb, doc_emb.transpose(1, 2))
        
        return torch.max(sim, dim=-1).values.sum(dim=-1)


In [51]:
class FaissIndex:
    def __init__(self, dim):
        self.index = faiss.IndexFlatIP(dim)
        
    def add_docs(self, embeddings):
        doc_vectors = embeddings.mean(dim=1).numpy()
        self.index.add(doc_vectors)
        
    def search(self, query_emb, k=5):
        query_vector = query_emb.mean(dim=1).numpy()
        distances, labels = self.index.search(query_vector, k=k)
        return labels, distances


In [52]:
colbert = ColBERT()
index = FaissIndex(dim=colbert.dim)



In [53]:
document_names = [x.document_name for x in documents]

In [54]:
doc_embeddings = []
for i in tqdm(range(0, 240_000, 10_000)):
    doc_embeddings.append(colbert.encode(document_names[i:i+10000]))
doc_embeddings = torch.cat(doc_embeddings, dim=0)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:30<00:00,  1.28s/it]


In [55]:
index.add_docs(doc_embeddings)

In [56]:
query = "молоко"
query_embedding = colbert.encode([query])

In [57]:
labels, scores = index.search(query_embedding, k=5)

for i, (label, score) in enumerate(zip(labels[0], scores[0])):
    print(f"{documents[label]} (Score: {score:.3f})")

Document(document_id=441379467, document_name='Авокадо') (Score: 144.330)
Document(document_id=196062760, document_name='Кукла') (Score: 142.064)
Document(document_id=20365946, document_name='Репка') (Score: 141.719)
Document(document_id=1539615715, document_name='Чулки') (Score: 141.320)
Document(document_id=567427685, document_name='Зеркало') (Score: 141.117)


In [58]:
model_name: str = "cointegrated/rubert-tiny"

In [59]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [60]:
tokenizer(
    ["молоко"], 
    padding=True, 
    truncation=True, 
    return_tensors="pt",
    max_length=64
)

{'input_ids': tensor([[    2,   324, 20879,  1597,     3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}