# Функция для поиска

In [87]:
from dataclasses import dataclass
from typing import List

import polars as pl

from transliterate import translit
from spellchecker import SpellChecker

from catboost import CatBoostRanker, Pool

from lemma import LemmaRecommender
from semantic import SemanticRecommender

# Module-level Russian spell checker; search_video() uses it to autocorrect query words.
russian = SpellChecker(language='ru')


@dataclass
class Model:
    """Bundle of the components search_video() needs to answer a query."""
    # Lexical candidate generator; search_video() calls predict_one(query, n) on it.
    lex_rec: LemmaRecommender
    # Semantic candidate generator; search_video() calls predict_one(query, n) on it.
    sem_rec: SemanticRecommender
    # Per-video feature table, left-joined to the candidates on 'video_id'.
    features: pl.DataFrame
    # Trained CatBoost ranker that scores the candidate rows.
    model_cb: CatBoostRanker


def _autocorrect_query(query: str) -> str:
    """Return *query* with each whitespace-separated word spell-corrected.

    Uses the module-level ``russian`` spell checker. A word is kept as typed
    when the checker returns ``None`` for it. Words are re-joined with single
    spaces, so runs of whitespace in *query* are collapsed.
    """
    corrected_words = []
    for word in query.split():
        suggestion = russian.correction(word)
        corrected_words.append(word if suggestion is None else suggestion)
    return ' '.join(corrected_words)


def search_video(
    query: str,
    model: Model,
    *,
    top_k: int = 5,
    n_candidates: int = 100,
    datetime_ind: int = 5064,
) -> List[str]:
    """Return the IDs of the best-ranked videos for a text query.

    Steps:
      1. Spell-correct the query (prints the corrected text if it changed).
      2. Fetch lexical candidates; if there are none, transliterate the query
         to Russian, print it, and retry.
      3. Outer-join semantic candidates, attach ``model.features`` and the
         day-offset feature ``ind_diff``.
      4. Score every candidate with the CatBoost ranker and keep the best.

    Args:
        query: Raw user query.
        model: Recommenders, feature table and ranker to use. Both
            recommenders' ``predict_one`` results are joined on ``video_id``;
            the joined frame must contain ``v_pub_datetime_ind``,
            ``v_channel_type`` and ``v_category``.
        top_k: Number of video IDs to return.
        n_candidates: Number of candidates requested from each recommender.
        datetime_ind: Reference day index subtracted against
            ``v_pub_datetime_ind``. Must use the same scale as that column
            (the notebook builds it as days since 2010-01-01).

    Returns:
        Up to ``top_k`` video IDs, highest score first. An empty list when the
        query is blank or no candidates were found.
    """
    # A blank query carries no search intent; skip the whole pipeline.
    if not query.strip():
        return []

    fixed_query = _autocorrect_query(query)
    if fixed_query != query:
        print('Автоисправление:\n', fixed_query)
        query = fixed_query

    merged = model.lex_rec.predict_one(query, n_candidates)
    if len(merged) == 0:
        # No lexical hits: the query may be Russian typed in Latin letters.
        query = translit(query, 'ru')
        print('Транслитерация:\n', query)
        merged = model.lex_rec.predict_one(query, n_candidates)

    semantic = model.sem_rec.predict_one(query, n_candidates)
    merged = merged.join(semantic, on=['video_id'], how='outer')
    # A candidate found by only one recommender has nulls in the other's columns.
    merged = merged.fill_null(0)

    # Nothing to rank; avoid handing an empty dataset to CatBoost.
    if len(merged) == 0:
        return []

    merged = merged.with_columns(datetime_ind=pl.lit(datetime_ind, pl.Int32))
    merged = merged.join(model.features, on=['video_id'], how='left')
    merged = merged.with_columns(
        ind_diff=pl.col('datetime_ind') - pl.col('v_pub_datetime_ind')
    )
    # Candidates without a feature row get -1 in every feature column.
    merged = merged.fill_null(-1)

    pool = Pool(
        data=merged.to_pandas().drop(['video_id'], axis=1),
        cat_features=['v_channel_type', 'v_category']
    )
    scores = model.model_cb.predict(pool)
    return (
        merged
        .with_columns(score=pl.Series(scores))
        .sort(by='score', descending=True)["video_id"]
        .head(top_k)
        .to_list()
    )


In [88]:
def get_results(search_result: List[str]) -> List[str]:
    """Return the titles of the given videos, ordered like *search_result*.

    Looks the IDs up in the module-level ``videos`` frame, which must have
    ``video_id`` and ``video_title`` columns.

    Args:
        search_result: Video IDs, e.g. the list returned by ``search_video``.

    Returns:
        The ``video_title`` values of the matching rows, in the order of
        *search_result*.

    Raises:
        KeyError: Raised by the pandas ``.loc`` lookup when a requested ID is
            missing from ``videos``.
    """
    # Filter in polars first so only the requested rows are converted to pandas.
    matched = videos.filter(pl.col('video_id').is_in(search_result))
    titles_by_id = matched.to_pandas().set_index('video_id')
    # Label-based selection returns rows in the order of the requested IDs.
    return titles_by_id.loc[search_result, 'video_title'].to_list()

# Дополнительные данные

TODO: код в этом разделе стоит переписать.

In [66]:
import pickle

In [67]:
# Load the pre-built lexical recommender from disk.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load a pickle
# from a trusted source. Unpickling also requires the LemmaRecommender class to be importable.
with open('./lemma_rec_1e6.pickle', 'rb') as f:
    lex_rec: LemmaRecommender = pickle.load(f)

In [68]:
# Build `video_ids`, the ID list used below to select the videos given to SemanticRecommender.
# NOTE(review): the resulting list depends on the exact statement order here and presumably
# has to stay aligned with the pre-built 'labse_candidates.index' — verify before refactoring.
features = pl.read_parquet('./features.parquet', columns = ['video_id', 'v_pub_datetime']).sort('v_pub_datetime')
# Drop exact duplicate (video_id, v_pub_datetime) rows.
features = features.unique()
videos = pl.read_parquet('./videos.parquet', columns = ['video_id', 'video_title', 'v_pub_datetime']).sort('v_pub_datetime')
# The left join brings in the features' publication time as 'v_pub_datetime_right'.
videos = videos.join(features, on='video_id', how='left')
# Keep only videos whose publication time is present in both sources.
videos = videos.filter((~pl.col('v_pub_datetime').is_null()) & (~pl.col('v_pub_datetime_right').is_null()))
videos = videos.sort('video_id')
automarkup = pl.read_parquet('./automarkup.parquet', columns=['video_id'])
# Last 1,000,000 IDs in 'video_id' sort order (the frame was just sorted by 'video_id', not by time).
video_ids = videos["video_id"].tail(1_000_000).to_list()
# Add every ID that appears in the automarkup table.
video_ids += automarkup["video_id"].to_list()
# Deduplicate and put the IDs in a deterministic order.
video_ids = sorted(list(set(video_ids)))

import os

# Must be set BEFORE `transformers` is imported: the library resolves its cache
# location at import time, so assigning the variable afterwards has no effect.
os.environ['TRANSFORMERS_CACHE'] = './cache/'

from transformers import AutoTokenizer, AutoModel
import torch
from tqdm.auto import tqdm

# The env var alone is not enough if `transformers` was already imported earlier
# in the session (e.g. by another module), so the cache directory is also passed
# explicitly to every from_pretrained() call.
hf_cache_dir = os.environ['TRANSFORMERS_CACHE']

device = 'cpu'
# LaBSE (English/Russian) tokenizer and encoder; both are handed to SemanticRecommender.
# NOTE: the name `model` is rebound to a `Model` instance further down in this notebook.
tokenizer = AutoTokenizer.from_pretrained("cointegrated/LaBSE-en-ru", cache_dir=hf_cache_dir)
model = AutoModel.from_pretrained("cointegrated/LaBSE-en-ru", cache_dir=hf_cache_dir).to(device)


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/521k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/516M [00:00<?, ?B/s]

In [69]:
import faiss

# Load the pre-built FAISS index from disk.
index = faiss.read_index('labse_candidates.index')

# The first argument is the list of IDs from `videos` that are also in
# `video_ids`, kept in the row order of `videos`.
sem_rec = SemanticRecommender(
    videos.filter(pl.col('video_id').is_in(video_ids)).get_column("video_id").to_list(),
    index,
    tokenizer,
    model,
)

In [71]:
import pandas as pd

In [72]:
# Reference date for the "*_ind" columns: each holds whole days elapsed since this date.
_ind_epoch = pd.Timestamp(year=2010, month=1, day=1)

features = pl.read_parquet('./features_nov.parquet')

# Convert both datetime columns through pandas to get day offsets.
v_channel_reg_datetime = features["v_channel_reg_datetime"].to_pandas()
v_pub_datetime = features["v_pub_datetime"].to_pandas()
v_channel_reg_datetime_ind = (pd.to_datetime(v_channel_reg_datetime) - _ind_epoch).dt.days
v_pub_datetime_ind = (pd.to_datetime(v_pub_datetime) - _ind_epoch).dt.days

features = features.with_columns(
    v_channel_reg_datetime_ind=pl.Series(v_channel_reg_datetime_ind),
    v_pub_datetime_ind=pl.Series(v_pub_datetime_ind),
)
# The raw date columns are replaced by the day offsets above.
features = features.drop(['report_date', 'v_channel_reg_datetime', 'v_pub_datetime'])
# Keep a single row per video.
features = features.unique(subset='video_id', keep='last')

In [73]:
# Create an empty ranker, then load the trained weights into it from disk.
model_cb = CatBoostRanker()
model_cb.load_model('./model_cb.cbm')

In [74]:
# Reload the ID -> title table. This rebinds `videos` (previously the filtered frame
# used to build `video_ids`); get_results() reads this global.
videos = pl.read_parquet('./videos.parquet', columns=['video_id', 'video_title'])

# Тестирование

In [89]:
# Assemble the search components into one object for search_video().
# NOTE: this rebinds `model`, which earlier held the LaBSE encoder; `sem_rec` was
# already constructed with that encoder object.
model = Model(lex_rec=lex_rec, sem_rec=sem_rec, features=features, model_cb=model_cb)

In [124]:
%%time
# Run one example query through the assembled pipeline.
search_result = search_video('как построить карьеру', model)

CPU times: user 4.57 s, sys: 2.14 s, total: 6.71 s
Wall time: 733 ms


In [123]:
# Show the video IDs held in `search_result`.
# NOTE(review): this cell's execution count (123) is lower than the search cell's (124),
# so the output shown below may come from an earlier query — re-run to refresh.
search_result

['video_8823568',
 'video_7914063',
 'video_9636578',
 'video_18727621',
 'video_8731746']

In [114]:
# Resolve the IDs to video titles.
# NOTE(review): this cell's execution count (114) predates the search cell (124), so the
# titles shown below belong to an earlier query — re-run to refresh.
get_results(search_result)

['Собака', 'Собака', 'собака', 'Собачка резвится', 'СОБАКА.']