In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the data, image and log paths used by
# this notebook all live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
!pip install gradio
!pip install nmslib
!pip install sentence_transformers
# !pip install googletrans==4.0.0-rc1
!pip install -U deep-translator

In [None]:
import random
import gradio as gr
import re
import argparse
import pandas as pd
import string
import secrets
import torch
import nmslib

from typing import List
from sentence_transformers import SentenceTransformer
# from googletrans import Translator
from deep_translator import GoogleTranslator

In [None]:
# Locations on the mounted Google Drive.
# CSV that the set-up cell loads into the global `df`.
data_path = '/content/drive/MyDrive/text2rec/top_250_mvp.csv'
# Directory prefix for images; get_imgs appends '<id>.jpg' to it, so the trailing slash matters.
img_path = '/content/drive/MyDrive/text2rec_imgs/'
# Directory handed to gr.CSVLogger.setup for the relevance-feedback log.
logs_path = "/content/drive/MyDrive/logs_text2rec"

In [None]:
def regex_for_query_without_quotes(name, suffix = "ы"):
    """Build a regex source matching an unquoted ``name:value`` filter token.

    Args:
        name: Field keyword expected before the colon (inserted into the
            pattern verbatim, so regex metacharacters in it stay active).
        suffix: Optional ending of the keyword (e.g. a plural ending). When
            non-empty it is made optional in the pattern; pass "" for none.

    Returns:
        Pattern string whose group 1 captures the value: one or more word
        characters, commas or dots (no spaces).
    """
    optional_suffix = f"{suffix}?" if suffix else ""
    # Raw f-string: `\w` must reach the regex engine untouched. In a normal
    # string literal it is an invalid escape sequence (SyntaxWarning on 3.12+).
    return rf"{name}{optional_suffix}:([\w,.]+)"


def regex_for_query_with_quotes(name, suffix = "ы"):
    """Build a regex source matching a single-quoted ``name:'value'`` filter token.

    Args:
        name: Field keyword expected before the colon (inserted into the
            pattern verbatim, so regex metacharacters in it stay active).
        suffix: Optional ending of the keyword (e.g. a plural ending). When
            non-empty it is made optional in the pattern; pass "" for none.

    Returns:
        Pattern string whose group 1 captures the text between the single
        quotes: one or more word characters, commas, dots or spaces.
    """
    optional_suffix = f"{suffix}?" if suffix else ""
    # Raw f-string: `\w` must reach the regex engine untouched. In a normal
    # string literal it is an invalid escape sequence (SyntaxWarning on 3.12+).
    return rf"{name}{optional_suffix}:'([\w,. ]+)'"

In [None]:
class Handler:
    """Base class for a query handler tied to one DataFrame column.

    Attributes:
        column_name: Name of the column the handler works on.
        value_type: Callable used to convert raw query values (``str`` by default).
    """

    def __init__(self, column_name, value_type=str):
        self.column_name, self.value_type = column_name, value_type

    def __call__(self, df: pd.DataFrame, query: str):
        """Process ``query`` against ``df``; subclasses must override this."""
        raise NotImplementedError


class ComplexHandler(Handler):
    """Filter a DataFrame by a ``field:value`` token found in the query text.

    The token may be unquoted (``field:a,b``) or single-quoted
    (``field:'a b, c d'``). Comma-separated values are combined with AND:
    a row is kept only if ``pred`` holds for every value.

    Args:
        field_name: Keyword that introduces the filter in the query.
        column_name: DataFrame column the predicate is evaluated on.
        pred: Callable ``(column, value) -> boolean Series`` aligned with ``column``.
        cast_to_type: Callable converting each raw string value before it is
            passed to ``pred``.
        **kwargs: Forwarded to the regex builders (e.g. ``suffix``).
    """

    def __init__(self, field_name, column_name, pred, cast_to_type=str, **kwargs):
        super().__init__(column_name, cast_to_type)
        self.regex = re.compile(regex_for_query_without_quotes(field_name, **kwargs))
        self.regex_with_quotes = re.compile(regex_for_query_with_quotes(field_name, **kwargs))
        self.pred = pred

    def __call__(self, df: pd.DataFrame, query: str):
        """Apply the filter if its token is present.

        Returns:
            ``(df, query)`` unchanged when the token is absent; otherwise
            ``(filtered_df, query_without_token)``. If a value cannot be cast
            or the predicate fails, the token is still removed from the query
            but ``df`` is returned unfiltered (best effort).
        """
        # The unquoted form wins when both are present; Match objects are always truthy.
        match = self.regex.search(query) or self.regex_with_quotes.search(query)
        if match is None:
            return df, query

        column = df[self.column_name]
        filtered_query = query.replace(match.group(), "")

        # Strip blanks around each value ("'A, B'" -> "A", "B") and skip empty
        # pieces left by stray commas, so one typo does not void the whole filter.
        values = [piece.strip() for piece in match.group(1).split(",")]
        values = [piece for piece in values if piece]

        mask = pd.Series(True, index=column.index, dtype=bool)
        try:
            for raw_value in values:
                mask &= self.pred(column, self.value_type(raw_value))
        except Exception:
            # Deliberate best effort: an unusable filter value must not break
            # the search. Unlike a bare `except:`, this lets KeyboardInterrupt
            # and SystemExit propagate.
            return df, filtered_query
        return df[mask], filtered_query

In [None]:
class Pipeline:
    """Runs a sequence of handlers, feeding each one the previous result."""

    def __init__(self, handlers: List[Handler]):
        self.handlers = handlers

    def __call__(self, df, query):
        """Thread ``(df, query)`` through every handler in order and return the final pair."""
        frame, text = df, query
        for step in self.handlers:
            frame, text = step(frame, text)
        return frame, text


def val_in_column(column: pd.Series, value: str):
    """Return a boolean mask of the rows whose text contains ``value``.

    Args:
        column: String-valued Series to search.
        value: Substring to look for (case-sensitive).

    Returns:
        Boolean Series aligned with ``column``; missing cells count as False.
    """
    # regex=False: the value is user-typed text, so "." must match a literal dot.
    # na=False: NaN cells would otherwise put NaN into the mask and break `&=`.
    return column.str.contains(value, regex=False, na=False)

In [None]:
# Set up
# Film table loaded once and shared as a module-level global by the search functions.
df = pd.read_csv(data_path)

# Use the first CUDA device when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
all_mpnet_base_v2 = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
all_mpnet_base_v2 = all_mpnet_base_v2.to(device)

# One embedding per row of df['description'], in row order.
# NOTE(review): get_similar indexes this array with the DataFrame index, so this
# assumes `df` keeps its default 0..n-1 index — confirm.
default_embs_en = all_mpnet_base_v2.encode(df['description'].to_list())

# HNSW index (cosine similarity space) over the description embeddings.
# NOTE(review): no visible cell queries `index`; get_similar scores with a
# plain dot product against default_embs_en instead — confirm it is still needed.
index = nmslib.init(method='hnsw', space='cosinesimil')
index.addDataPointBatch(default_embs_en)
index.createIndex({'post': 2}, print_progress=True)

# translator = Translator()
# Translator configured with source='ru' and target='en'; get_similar runs
# queries through it before embedding them.
translator = GoogleTranslator(source='ru', target='en')

In [None]:
def prep_data(query: str):
    """Apply every structured filter found in ``query`` to the global ``df``.

    Returns:
        ``(filtered_df, remaining_query)``, where ``remaining_query`` is the
        query text with the recognised ``field:value`` tokens removed.
    """
    # (query keyword, DataFrame column, predicate, value type), listed in the
    # order in which the filters are applied.
    handler_specs = [
        ("рейтинг",  "Rating",    lambda c, v: c >= v, float),
        ("год",      "Year",      lambda c, v: c == v, int),
        ("режиссер", "Director",  val_in_column,       str),
        ("жанр",     "Genres",    val_in_column,       str),
        ("актер",    "Actors",    val_in_column,       str),
        ("страна",   "Countries", val_in_column,       str),
    ]
    pipeline = Pipeline([ComplexHandler(*spec) for spec in handler_specs])

    filtered_df, remaining_query = pipeline(df, query)
    return filtered_df, remaining_query

In [None]:
def get_films(data, query: str, films, k: int):
    """Return the titles of the ``k`` best-rated films in ``data``.

    Rows are grouped by ``FilmId`` and ordered by ``Rating`` (descending).
    If ``data`` holds fewer than ``k`` films, the remainder is filled by
    ``get_similar`` from the rest of the global catalogue ``df``.

    Args:
        data: DataFrame with at least ``FilmId``, ``Title`` and ``Rating`` columns.
        query: Query text used only for the similarity fallback.
        films: List extended in place with the selected film ids.
        k: Number of titles wanted.

    Returns:
        List of film titles.
    """
    top = data.groupby(by=['FilmId']).max().sort_values('Rating', ascending=False)
    top = top.iloc[:k].reset_index()
    ids = top.FilmId.values.tolist()
    films.extend(ids)
    titles = top.Title.values.tolist()
    if len(titles) < k:
        # Fill up from the rest of the catalogue only; ranking the whole `df`
        # again could return films that are already in `titles`.
        remaining = df[~df['FilmId'].isin(ids)]
        titles.extend(get_similar(remaining, query, films, k - len(titles)))
    return titles

In [None]:
def _rank_by_similarity(data, query_emb, k: int):
    """Return ``(film_ids, titles)`` of the ``k`` films in ``data`` scoring highest against ``query_emb``."""
    # assign() scores a copy, so neither the caller's frame nor the global
    # `df` gets a 'sim' column written into it.
    scored = data.assign(sim=(default_embs_en[data.index] @ query_emb.T).flatten('F'))
    ranked = scored.groupby(by=['FilmId']).max().sort_values('sim', ascending=False)
    top = ranked.iloc[:k].reset_index()
    return top.FilmId.values.tolist(), top.Title.values.tolist()


def get_similar(data, query: str, films, k: int):
    """Return the titles of the ``k`` films whose descriptions best match ``query``.

    The query is translated with ``translator``, embedded with
    ``all_mpnet_base_v2`` and scored by dot product against the description
    embeddings of the rows in ``data``. If ``data`` holds fewer than ``k``
    films, the remainder is taken from the rest of the global catalogue ``df``.

    Args:
        data: DataFrame with ``FilmId`` and ``Title`` columns whose index
            selects the matching rows of ``default_embs_en``.
        query: Free-text query.
        films: List extended in place with the selected film ids.
        k: Number of titles wanted.

    Returns:
        List of film titles; shorter than ``k`` only when the whole catalogue
        has fewer than ``k`` films.
    """
    # Translate and embed once; the fallback below reuses the same embedding.
    query_emb_en = all_mpnet_base_v2.encode(translator.translate(query))

    ids, titles = _rank_by_similarity(data, query_emb_en, k)
    missing = k - len(titles)
    if missing > 0:
        # Top up from films not selected yet. This is a single pass, so it
        # cannot repeat a title or recurse without end on a small catalogue.
        rest = df[~df['FilmId'].isin(ids)]
        extra_ids, extra_titles = _rank_by_similarity(rest, query_emb_en, missing)
        ids.extend(extra_ids)
        titles.extend(extra_titles)

    films.extend(ids)
    return titles

In [None]:
def get_recs(full_query: str, img_paths, k: int = 10) -> List[str]:
    """Return ``k`` recommended film titles for a user query.

    Structured filters (``field:value`` tokens) are applied first. If no free
    text is left afterwards, the best-rated films among the filtered rows are
    returned; otherwise the filtered rows are ranked by similarity to the
    remaining text.

    Args:
        full_query: Raw query as typed by the user.
        img_paths: List extended in place with the ids of the selected films.
        k: Number of titles to return.

    Returns:
        List of film titles.
    """
    filtered_df, query = prep_data(full_query)
    # Removing the filter tokens can leave "", " ", "  " and so on, depending
    # on how many filters were given, so test for any blank remainder rather
    # than exactly one space.
    if not query.strip():
        return get_films(filtered_df, full_query, img_paths, k)
    return get_similar(filtered_df, query, img_paths, k)

def get_imgs(films):
    """Build the image file paths for the last ten entries of ``films``.

    Each path is ``img_path`` followed by ``<entry>.jpg``.
    """
    recent_ids = films[-10:]
    return [img_path + f'{film_id}.jpg' for film_id in recent_ids]

In [None]:
# Markdown text passed to gr.Markdown in the UI cell (user-facing, kept in Russian).
# The trailing backslashes continue the single-quoted literal across source lines.
description = '# Проект Text2Rec\n \
Сервис предоставляет возможность поиска фильмов по произвольному запросу. \
Запрос можно уточнить информацией о фильме: год, режиссер, жанр, актер, страна, рейтинг. \n\n\
Примеры запросов:  \n \
"жанр:приключения страна:США актер:\'Джонни Депп\'";  \n \
"жанр:приключения фильмы с животными";  \n \
"фильм где у мужика плохо росла кукуруза и он полетел в черную дыру";  \n \
"Фильмы про путешествия во времени" \n\n \
Пожалуйста, оцените релевантность выдачи(Релевантно/Не релевантно), \
это поможет улучшить работу алгоритма.'

In [None]:
def get_id():
    """Store a fresh random 10-character alphanumeric id in the global ``session_id``."""
    global session_id
    charset = string.ascii_letters + string.digits
    picked = [secrets.choice(charset) for _ in range(10)]
    session_id = ''.join(picked)

In [None]:
def reset_radio():
    """Return ten copies of one Radio update that restores the two choices and clears the value."""
    cleared = gr.Radio.update(choices=['Релевантно', 'Не релевантно'], value=[])
    return [cleared] * 10

In [None]:
session_id = None
callback = gr.CSVLogger()

with gr.Blocks() as demo:
    films = gr.State([])
    gr.Markdown(description)
    query = gr.Textbox(label="Запрос")
    search_btn = gr.Button("Поиск")
    with gr.Row():
        with gr.Column():
            img1 = gr.Image(show_label=False, shape=(150, 210))
            name1 = gr.Text(show_label=False, interactive=False)
            like1 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
        with gr.Column():
            img2 = gr.Image(show_label=False, shape=(150, 210))
            name2 = gr.Text(show_label=False, interactive=False)
            like2 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
        with gr.Column():
            img3 = gr.Image(show_label=False, shape=(150, 210))
            name3 = gr.Text(show_label=False, interactive=False)
            like3 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
        with gr.Column():
            img4 = gr.Image(show_label=False, shape=(150, 210))
            name4 = gr.Text(show_label=False, interactive=False)
            like4 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
    with gr.Row():
        with gr.Column():
            img5 = gr.Image(show_label=False, shape=(150, 210))
            name5 = gr.Text(show_label=False, interactive=False)
            like5 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
        with gr.Column():
            img6 = gr.Image(show_label=False, shape=(150, 210))
            name6 = gr.Text(show_label=False, interactive=False)
            like6 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
        with gr.Column():
            img7 = gr.Image(show_label=False, shape=(150, 210))
            name7 = gr.Text(show_label=False, interactive=False)
            like7 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
        with gr.Column():
            img8 = gr.Image(show_label=False, shape=(150, 210))
            name8 = gr.Text(show_label=False, interactive=False)
            like8 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
    with gr.Row():
        with gr.Column():
            pass
        with gr.Column():
            img9 = gr.Image(show_label=False, shape=(150, 210))
            name9 = gr.Text(show_label=False, interactive=False)
            like9 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
        with gr.Column():
            img10 = gr.Image(show_label=False, shape=(150, 210))
            name10 = gr.Text(show_label=False, interactive=False)
            like10 = gr.Radio(show_label=False, value=None, choices=['Релевантно', 'Не релевантно'])
        with gr.Column():
            pass

    name_list = [name1, name2, name3, name4, name5, name6, name7, name8, name9, name10]
    img_list = [img1, img2, img3, img4, img5, img6, img7, img8, img9, img10]
    like_list = [like1, like2, like3, like4, like5, like6, like7, like8, like9, like10]
    callback.setup([query, *name_list, *like_list], logs_path)

    search_btn.click(fn=get_id, inputs=None, outputs=None)
    search_btn.click(lambda *args: callback.flag(args, username=session_id), [query, *name_list, *like_list], None, preprocess=False)
    search_btn.click(fn=get_recs, inputs=[query, films], outputs=name_list)
    search_btn.click(fn=get_imgs, inputs=films, outputs=img_list)
    search_btn.click(fn=reset_radio, inputs=None, outputs=like_list)

In [None]:
# Turn on the request queue (concurrency_count=4), then start the app with share=True.
demo.queue(concurrency_count=4)
demo.launch(share=True)

In [27]:
# Stop the running demo server.
demo.close()

Closing server running on port: 7860


In [None]:
# Load log.csv from the logging directory for inspection.
# NOTE(review): presumably this is the file written by the CSVLogger set up in the UI cell — confirm.
data = pd.read_csv(logs_path + '/log.csv')
data