In [1]:
import time
import random
from tqdm import tqdm

import polars as pl

In [2]:
search_events = (
    pl.read_parquet("search_events.parquet")
    .filter(pl.col("search_query").is_not_null())
)

In [3]:
search_events.head(5)

user_id,timestamp,search_query
i32,datetime[ns],str
10904261,2024-05-15 15:42:19,"""ламбер сыр"""
2799838,2024-05-09 19:12:06,"""кетчуп"""
1156024,2024-05-19 21:19:14,"""турецкий айран"""
1969650,2024-05-30 08:42:05,"""набор для творчества"""
5184152,2024-05-13 15:37:06,"""чеснок маринованный"""


In [4]:
search_events.shape

(15712074, 3)

## Собираем запросы для индекса

In [5]:
search_events_grouped = (
    search_events.group_by("search_query").len()
)

In [6]:
search_events_grouped.head()

search_query,len
str,u32
"""амбушюры для airpods""",1
"""цепт""",1
"""блок игрушка""",1
"""орзо паста рис""",1
"""шнековая соковыжималка геймлюк…",1


In [7]:
search_events_grouped.shape

(1487083, 2)

In [8]:
search_events_grouped.filter(pl.col("len") > 1).shape

(538694, 2)

In [9]:
search_queries_filtered = search_events_grouped.filter(pl.col("len") > 1)

In [10]:
search_queries_filtered.head()

search_query,len
str,u32
"""royal canin gastrointestinal f…",6
"""курага братья ореховы""",18
"""килька по гавайски с овощами""",6
"""подарок маме на юбилей 50""",2
"""памперсы сени m""",5


In [11]:
search_queries_filtered.sort(pl.col("len"), descending=True).head()

search_query,len
str,u32
"""молоко""",166186
"""хлеб""",137364
"""сыр""",108392
"""мороженое""",106857
"""яйца""",75208


## Посчитаем популярность запросов

In [12]:
to_cart_events = pl.read_parquet("to_cart_events.parquet")

In [13]:
to_cart_events.head()

user_id,timestamp,search_query,product_id
i32,datetime[ns],str,i64
1423743,2024-05-11 20:54:13,"""икра красная""",180730425
37569,2024-05-02 08:51:53,"""руккола свежая""",357175617
5240641,2024-05-16 01:47:09,"""сырки дружба""",356541992
8450453,2024-05-28 05:55:01,"""фрутоняня""",141861404
11105338,2024-05-02 16:02:37,"""пельмени""",392485080


In [14]:
queries_popularity = (
    to_cart_events.group_by("search_query").len(name="popularity")
)

In [15]:
queries_popularity.head()

search_query,popularity
str,u32
"""rubiscookies""",1
"""тушь вивен сабо""",1
"""магнитные шарики""",1
"""щармель""",2
"""шт""",6


In [16]:
queries_popularity.sort(pl.col("popularity"), descending=True).head()

search_query,popularity
str,u32
"""мороженое""",173872
"""молоко""",148224
"""хлеб""",138253
"""сыр""",112617
"""творог""",73395


In [17]:
index_queries = (
    search_queries_filtered.join(queries_popularity, on="search_query", how="left")
    .fill_null(value=0)
    .select(
        pl.col("search_query"),
        pl.col("popularity")
    )
)

In [18]:
index_queries.head()

search_query,popularity
str,u32
"""royal canin gastrointestinal f…",6
"""курага братья ореховы""",17
"""килька по гавайски с овощами""",0
"""подарок маме на юбилей 50""",0
"""памперсы сени m""",2


## Напишем префиксное дерево поиска

In [27]:
class TrieNode:
    def __init__(self):
        self._children: dict[str, "TrieNode"] = {}
        self._is_end = False
        self._popularity = 0

    def get_popularity(self) -> int:
        return self._popularity

    def set_popularity(self, popularity: int) -> None:
        self._popularity = popularity

    def is_end(self) -> bool:
        return self._is_end

    def mark_as_end(self) -> None:
        self._is_end = True

    def add_child(self, char: str) -> None:
        self._children[char] = TrieNode()

    def get_children(self) -> dict[str, "TrieNode"]:
        return self._children


In [36]:
class Trie:
    def __init__(self):
        self._root = TrieNode()

    def insert(self, query: str, popularity: int) -> None:
        node = self._root
        for char in query:
            children = node.get_children()
            if char not in children:
                node.add_child(char)
            node = children[char]
        node.mark_as_end()
        node.set_popularity(popularity)

    def get_completions(self, prefix: str, limit: int = -1) -> list[tuple[str, int]]:
        completions = []
        node_starts_with_prefix = self._get_node_starts_with_prefix(prefix)
        if node_starts_with_prefix is None:
            return completions

        self._collect_completions(node_starts_with_prefix, prefix, completions)
        return sorted(completions, key=lambda x: -x[1])[:limit if limit > 0 else None]

    def _get_node_starts_with_prefix(self, prefix: str) -> TrieNode:
        node = self._root
        for char in prefix:
            children = node.get_children()
            if char not in children:
                return None
            node = children[char]
        return node

    def _collect_completions(self, node: TrieNode, current_prefix, completions) -> None:
        if node.is_end():
            completions.append((current_prefix, node.get_popularity()))
        for char, child_node in node.get_children().items():
            self._collect_completions(child_node, current_prefix + char, completions)


In [37]:
trie = Trie()

In [38]:
trie.insert("молоко", 100)
trie.insert("машина", 6)
trie.insert("молоток", 9)
trie.insert("молоко пастеризованное", 13)


In [41]:
trie.get_completions("мо")

[('молоко', 100), ('молоко пастеризованное', 13), ('молоток', 9)]

## Добавим в дерево наши запросы

In [42]:
trie = Trie()

In [43]:
for query, popularity in tqdm(index_queries.iter_rows()):
    trie.insert(query, popularity)

538694it [00:05, 96217.86it/s] 


In [46]:
trie.get_completions("ма", limit=100)

[('макароны', 50646),
 ('масло сливочное', 42951),
 ('майонез', 42640),
 ('масло подсолнечное', 16380),
 ('мармелад', 15554),
 ('масло сливочное 82,5', 10178),
 ('масло', 9277),
 ('макароны barilla', 5658),
 ('масло растительное', 5352),
 ('маслины', 4368),
 ('масло оливковое', 4213),
 ('манго', 4064),
 ('мандарины', 3935),
 ('майонез слобода', 3698),
 ('мармелад жевательный', 3172),
 ('маслины без косточки', 2424),
 ('маскарпоне', 2168),
 ('малина', 1807),
 ('манка', 1767),
 ('магнат', 1660),
 ('манго сушеный', 1610),
 ('майонез провансаль', 1573),
 ('мандарины свежие', 1514),
 ('маринованные огурцы', 1475),
 ('мажитэль', 1437),
 ('масло подсолнечное рафинированное', 1375),
 ('мацони', 1366),
 ('масло гхи', 1284),
 ('маска для волос', 1241),
 ('макароны без глютена', 1211),
 ('маасдам', 1210),
 ('макароны макфа', 1205),
 ('манная крупа', 1112),
 ('маска для лица', 1064),
 ('макароны из твердых сортов пшеницы', 1059),
 ('мармеладки', 1009),
 ('майонез легкий', 1006),
 ('магги', 997),
 

## Добавим возможность нечеткого (fuzzy) поиска

In [47]:
class FuzzyTrie:
    def __init__(self):
        self._root = TrieNode()

    def insert(self, query: str, popularity: int) -> None:
        node = self._root
        for char in query:
            children = node.get_children()
            if char not in children:
                node.add_child(char)
            node = children[char]
        node.mark_as_end()
        node.set_popularity(popularity)

    def get_completions(self, prefix: str, fuzziness: int = 0, limit: int = -1) -> list[tuple[str, int]]:
        completions = []
        
        def dfs(node: TrieNode, current_prefix: str, i: int, edits: int):
            if edits > fuzziness:
                return

            if i == len(prefix):
                self._collect_completions(node, current_prefix, completions)
                return

            current_char = prefix[i]

            children = node.get_children()

            # точный поиск
            if current_char in children:
                dfs(children[current_char], current_prefix + current_char, i + 1, edits)

            # замена последнего символа
            for char, child_node in children.items():
                if current_char != char:
                    dfs(child_node, current_prefix + char, i + 1, edits + 1)

            # вставка символа (например, "молко" вместо "молоко")
            for char, child_node in children.items():
                dfs(child_node, current_prefix + char, i, edits + 1)

            # удаление символа
            dfs(node, current_prefix, i + 1, edits + 1)
            

        dfs(self._root, "", 0, 0)
        unique_completions = {}
        for query, popularity in completions:
            if query not in unique_completions:
                unique_completions[query] = popularity

        return sorted(unique_completions.items(), key=lambda x: -x[1])

    def _collect_completions(self, node: TrieNode, current_prefix, completions) -> None:
        if node.is_end():
            completions.append((current_prefix, node.get_popularity()))
        for char, child_node in node.get_children().items():
            self._collect_completions(child_node, current_prefix + char, completions)


In [48]:
fuzzy_trie = FuzzyTrie()

In [49]:
fuzzy_trie.insert("молоко", 100)
fuzzy_trie.insert("машина", 6)
fuzzy_trie.insert("молоток", 9)
fuzzy_trie.insert("молоко пастеризованное", 13)


In [53]:
fuzzy_trie.get_completions("мош", fuzziness=1, limit=100)

[('молоко', 100),
 ('молоко пастеризованное', 13),
 ('молоток', 9),
 ('машина', 6)]

In [54]:
fuzzy_trie = FuzzyTrie()

In [55]:
for query, popularity in tqdm(index_queries.iter_rows()):
    fuzzy_trie.insert(query, popularity)

538694it [00:05, 90979.86it/s] 


In [56]:
fuzzy_trie.get_completions("мало", fuzziness=1, limit=100)

[('молоко', 148224),
 ('масло сливочное', 42951),
 ('майонез', 42640),
 ('молоко 3,2', 17658),
 ('масло подсолнечное', 16380),
 ('молоко безлактозное', 14275),
 ('масло сливочное 82,5', 10178),
 ('масло', 9277),
 ('масло растительное', 5352),
 ('молоко ультрапастеризованное', 4436),
 ('масло оливковое', 4213),
 ('майонез слобода', 3698),
 ('мыло жидкое для рук', 3692),
 ('молоко 2,5', 3444),
 ('мыло', 3275),
 ('молоко детское', 2942),
 ('молоко кокосовое', 2739),
 ('молоко растительное', 2543),
 ('молочный коктейль', 2352),
 ('молоко агуша', 2241),
 ('малина', 1807),
 ('сало', 1786),
 ('мыло жидкое', 1604),
 ('молоко топленое', 1592),
 ('майонез провансаль', 1573),
 ('масло подсолнечное рафинированное', 1375),
 ('мацони', 1366),
 ('масло гхи', 1284),
 ('молоко миндальное', 1257),
 ('молоко эконива', 1218),
 ('майонез легкий', 1006),
 ('мыло твердое', 976),
 ('мыло для рук', 957),
 ('молоко parmalat', 930),
 ('молоко пармалат', 909),
 ('молочный ломтик', 870),
 ('масло оливковое extra v

# Генеративные подсказки

In [57]:
user_sequences = (
    search_events.group_by("user_id", maintain_order=True)
    .agg(
        pl.struct(pl.col("timestamp"), pl.col("search_query")).alias("user_actions")
    )
)

In [58]:
user_sequences.head()

user_id,user_actions
i32,list[struct[2]]
10904261,"[{2024-05-15 15:42:19,""ламбер сыр""}, {2024-05-17 14:09:09,""експонента exponenta напиток""}, … {2024-05-17 14:10:38,""лечо""}]"
2799838,"[{2024-05-09 19:12:06,""кетчуп""}, {2024-05-09 19:13:01,""яйца""}, … {2024-05-26 13:15:35,""энергетический напиток""}]"
1156024,"[{2024-05-19 21:19:14,""турецкий айран""}, {2024-05-19 21:18:51,""мёд турецкий""}, … {2024-05-19 21:21:32,""возбуждающий мёд""}]"
1969650,"[{2024-05-30 08:42:05,""набор для творчества""}, {2024-05-31 06:24:47,""дарц""}, … {2024-05-30 08:41:40,""пазлы""}]"
5184152,"[{2024-05-13 15:37:06,""чеснок маринованный""}, {2024-05-13 15:39:22,""армянские продукты бутень""}, … {2024-05-13 15:38:02,""чеснок маринованный зелений""}]"


In [59]:
user_sequences_list = user_sequences[["user_actions"]].rows()

In [60]:
user_sequences_list_sorted = [sorted(x[0], key=lambda x: x["timestamp"]) for x in user_sequences_list]

In [61]:
user_sequences_list_sorted[0]

[{'timestamp': datetime.datetime(2024, 5, 15, 15, 40, 1),
  'search_query': 'шницель куриный'},
 {'timestamp': datetime.datetime(2024, 5, 15, 15, 40, 35),
  'search_query': 'експонента exponenta напиток'},
 {'timestamp': datetime.datetime(2024, 5, 15, 15, 41, 5),
  'search_query': 'кофе в капсулах'},
 {'timestamp': datetime.datetime(2024, 5, 15, 15, 42, 19),
  'search_query': 'ламбер сыр'},
 {'timestamp': datetime.datetime(2024, 5, 17, 14, 8, 43),
  'search_query': 'шницель куриный'},
 {'timestamp': datetime.datetime(2024, 5, 17, 14, 9, 9),
  'search_query': 'експонента exponenta напиток'},
 {'timestamp': datetime.datetime(2024, 5, 17, 14, 10, 38),
  'search_query': 'лечо'},
 {'timestamp': datetime.datetime(2024, 5, 27, 13, 18, 4),
  'search_query': 'гранола настин сластин'},
 {'timestamp': datetime.datetime(2024, 5, 27, 13, 18, 13),
  'search_query': 'гранола'},
 {'timestamp': datetime.datetime(2024, 5, 27, 13, 20, 15),
  'search_query': 'гранола настин сластин'}]

In [62]:
user_texts = [[y["search_query"] for y in x] for x in user_sequences_list_sorted]

## Train tokenizer

In [None]:
user_texts_joined = [" ".join(x) for x in user_texts]

with open("tokenizer_train_data.txt", "w") as file:
    file.write("\n".join(user_texts_joined) + "\n")

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace


tokenizer = Tokenizer(WordPiece())

trainer = WordPieceTrainer(
    vocab_size=8_192,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "[SEARCH]", "[PREFIX]", "[EOQ]"]
)

tokenizer.pre_tokenizer = Whitespace()

tokenizer.train(files=["tokenizer_train_data.txt"], trainer=trainer)

In [None]:
tokenizer.save("WPC_tokenizer.json")

In [63]:
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("WPC_tokenizer.json")

encoded = tokenizer.encode(
    "машина на пульте управления"
)
encoded.tokens

['машина', 'на', 'пульт', '##е', 'уп', '##ра', '##в', '##ления']

In [64]:
tokenizer.encode("[SEARCH]").ids

[5]

## Построим обучающую выборку для языковой модели

In [65]:
MAX_LM_LENGTH = 256

In [None]:
[PREFIX] мол [SEARCH] молоко [EOQ] [PREFIX] с [SEARCH] сыр плавленный [EOQ]

In [66]:
def select_random_prefix(text: str) -> str:
    rnd = random.randint(1, len(text))
    return text[:rnd]


def make_training_example(search_queries: list[str]) -> tuple[list[int], list[str], list[bool]]:
    prompt_token_ids = []
    label_masks = []

    for q in search_queries:

        prefix = select_random_prefix(q)
        tokenized_prefix = tokenizer.encode(prefix)
        tokenized_query = tokenizer.encode(q)
        if len(prompt_token_ids) + len(tokenized_prefix.ids) + len(tokenized_query.ids) + 3 > MAX_LM_LENGTH:
            break

        prompt_token_ids.extend(tokenizer.encode("[PREFIX]").ids)
        label_masks.append(False)

        prompt_token_ids.extend(tokenized_prefix.ids)
        label_masks.extend([False for _ in tokenized_prefix.ids])
        
        prompt_token_ids.extend(tokenizer.encode("[SEARCH]").ids)
        label_masks.append(False)

        prompt_token_ids.extend(tokenized_query.ids)
        label_masks.extend([True for _ in tokenized_query.ids])

        prompt_token_ids.extend(tokenizer.encode("[EOQ]").ids)
        label_masks.append(True)

    prompt_token_ids = [tokenizer.encode("[PAD]").ids[0] for _ in range(MAX_LM_LENGTH - len(prompt_token_ids))] + prompt_token_ids
    label_masks = [False for _ in range(MAX_LM_LENGTH - len(label_masks))] + label_masks

    return prompt_token_ids, label_masks
    
    

In [67]:
training_example = make_training_example(user_texts[0])

In [70]:
training_example[1]

[False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,


In [68]:
len(training_example[0])

256

In [None]:
lm_training_data = []
for user_chain in tqdm(user_texts):
    training_sample = make_training_example(user_chain)
    lm_training_data.append((training_sample[0], training_sample[2]))

In [71]:
import pickle

with open("lm_training_data.pickle", "rb") as f:
    lm_training_data = pickle.load(f)

In [72]:
lm_training_data[0]

([3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  6,
  6756,
  4391,
  5,
  6756,
  4391,
  7,
  6,
  173,
  3408,
  4691,
  5,
  173,
  3408,
  4691,
  3827,
  3812,
  7,
  6,
 

In [74]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import LlamaForCausalLM, LlamaConfig, AdamW

In [75]:
class CustomDataset(Dataset):
    def __init__(self, data: tuple[list[int], list[bool]]):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[torch.LongTensor, torch.BoolTensor]:
        token_ids, label_masks = self.data[idx]
        return (
            torch.LongTensor(token_ids).unsqueeze(0),
            torch.BoolTensor(label_masks).unsqueeze(0),
        )


def collate_fn(data):
    token_ids, label_masks = zip(*data)
    return (
        torch.concat(token_ids, dim=0),
        torch.concat(label_masks, dim=0),
    )


In [76]:
dataset = CustomDataset(lm_training_data)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

In [77]:
model = LlamaForCausalLM(
    config=LlamaConfig(
        vocab_size=8192,
        hidden_size=256,
        intermediate_size=768,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=512,
        pad_token_id=3,
        initializer_range=0.02,
        rope_theta=1000.,
    )
)

In [78]:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = AdamW(model.parameters(), lr=1e-4)

model.train()



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(8192, 256, padding_idx=3)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=256, out_features=256, bias=False)
          (k_proj): Linear(in_features=256, out_features=128, bias=False)
          (v_proj): Linear(in_features=256, out_features=128, bias=False)
          (o_proj): Linear(in_features=256, out_features=256, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=256, out_features=768, bias=False)
          (up_proj): Linear(in_features=256, out_features=768, bias=False)
          (down_proj): Linear(in_features=768, out_features=256, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((256,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((256,), eps=1e-06)
      )
    )
    (norm): Lla

In [79]:
start_time = time.time()
losses = []


for i, batch in enumerate(dataloader, 1):
    input_ids, label_masks = batch
    labels = input_ids.clone()

    outputs = model(input_ids=input_ids, labels=labels)

    logits = outputs.logits
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    shift_label_masks = label_masks[:, 1:]

    active_loss = shift_label_masks.reshape(-1)
    active_logits = shift_logits.reshape(-1, shift_logits.size(-1))[active_loss]
    active_labels = shift_labels.reshape(-1)[active_loss]
    
    loss = loss_fct(active_logits, active_labels)
    losses.append(loss.item())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 100 == 0:
        print(f"Iteration {i}, time taken = {time.time() - start_time}: loss={sum(losses) / len(losses)}")
        start_time = time.time()
        losses = []

Iteration 100, time taken = 8.334563970565796: loss=7.4609723567962645
Iteration 200, time taken = 8.098393201828003: loss=6.544859580993652
Iteration 300, time taken = 8.050681829452515: loss=6.211402144432068
Iteration 400, time taken = 8.130712032318115: loss=6.052275166511536
Iteration 500, time taken = 8.404176235198975: loss=6.075403165817261
Iteration 600, time taken = 8.28242301940918: loss=5.804515557289124
Iteration 700, time taken = 8.225300073623657: loss=5.781954045295715
Iteration 800, time taken = 8.615727186203003: loss=5.748478360176087
Iteration 900, time taken = 8.221994161605835: loss=5.484619045257569
Iteration 1000, time taken = 8.181087970733643: loss=5.54943244934082
Iteration 1100, time taken = 8.307843208312988: loss=5.46842031955719


KeyboardInterrupt: 

In [None]:
model.eval()

In [80]:
model_loaded = LlamaForCausalLM.from_pretrained("small_suggest_lm")
model_loaded.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(8192, 256, padding_idx=3)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=256, out_features=256, bias=False)
          (k_proj): Linear(in_features=256, out_features=128, bias=False)
          (v_proj): Linear(in_features=256, out_features=128, bias=False)
          (o_proj): Linear(in_features=256, out_features=256, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=256, out_features=768, bias=False)
          (up_proj): Linear(in_features=256, out_features=768, bias=False)
          (down_proj): Linear(in_features=768, out_features=256, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((256,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((256,), eps=1e-06)
      )
    )
    (norm): Lla

In [110]:
prompt = "[PREFIX] мо [SEARCH] молоко [PREFIX] вод [SEARCH] вода 5л [PREFIX] сыр плав [SEARCH] сыр плавленный [PREFIX] бр [SEARCH]"

prompt_token_ids = tokenizer.encode(prompt).ids
input_ids = torch.LongTensor(prompt_token_ids).unsqueeze(0)

In [111]:
num_beams = 5
max_length = 5

In [112]:
with torch.no_grad():
    output_ids = model_loaded.generate(
        input_ids,
        max_length=len(prompt_token_ids) + max_length,
        num_beams=num_beams,
        early_stopping=True,
        pad_token_id=3,
        eos_token_id=7,
        return_dict_in_generate=True,
        output_scores=True,
        num_return_sequences=5,
    ).sequences

In [113]:
for i, generated_ids in enumerate(output_ids.tolist()):
    text = tokenizer.decode(generated_ids[-max_length:], skip_special_tokens=False)
    print(f"{text}")

бр [SEARCH] брокколи [EOQ] [PAD]
бр [SEARCH] брынза [EOQ] [PAD]
бр [SEARCH] брокколи замороженные [EOQ]
бр [SEARCH] брокколи свежая [EOQ]
бр [SEARCH] брокколи пюре [EOQ]
