In [1]:
# Install the notebook's dependencies into the environment of the running kernel.
# `%pip` (unlike `!pip`) always targets the kernel's own interpreter, so the
# packages end up where the imports of this notebook will look for them.
import sys

%pip install spacy torch
# Quoted so that shells which expand square brackets (e.g. zsh) do not mangle the extra.
%pip install "keybert[spacy]"
%pip install -U transformers
# langchain_openai / langchain_core are imported by this notebook but were never installed here.
%pip install langchain-openai langchain-core
# Download the Russian spaCy pipeline with the same interpreter the kernel runs on.
!{sys.executable} -m spacy download ru_core_news_md

Collecting keybert[spacy]
  Downloading keybert-0.8.5-py3-none-any.whl.metadata (15 kB)
Collecting sentence-transformers>=0.3.8 (from keybert[spacy])
  Downloading sentence_transformers-3.1.0-py3-none-any.whl.metadata (23 kB)
Downloading sentence_transformers-3.1.0-py3-none-any.whl (249 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m249.1/249.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading keybert-0.8.5-py3-none-any.whl (37 kB)
^C
[31mERROR: Operation cancelled by user[0m[31m
Collecting transformers
  Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m738.0 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
Downloading transformers-4.44.2-py3-none-any.whl (9.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: t

In [None]:
import re
import spacy
from keybert import KeyBERT
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from spacy.lang.ru.stop_words import STOP_WORDS
from collections import Counter
from string import punctuation
# Russian spaCy pipeline, loaded once and shared by preprocess_text and extract_metadata.
nlp = spacy.load('ru_core_news_md')
# KeyBERT keyword extractor used by extract_keywords_with_keybert.
# NOTE(review): "vlt5-base-keywords" is passed as a bare model name — confirm it resolves to a
# model KeyBERT can load (a local path or a fully qualified hub id may be required).
kw_model = KeyBERT(model="vlt5-base-keywords")

In [None]:
def preprocess_text(text):
    """Return the text reduced to the space-separated lemmas of its alphabetic, non-stop-word tokens."""
    kept_lemmas = []
    for tok in nlp(text):
        # Skip punctuation/numbers/mixed tokens and stop words.
        if not tok.is_alpha or tok.is_stop:
            continue
        kept_lemmas.append(tok.lemma_)
    return " ".join(kept_lemmas)

def extract_names(doc):
    """Extract the unique person names found in a parsed document.

    Args:
        doc: Parsed document exposing ``ents``; every entity must provide
            ``label_`` and ``text`` (e.g. the result of ``nlp(text)``).

    Returns:
        list: Texts of the entities labelled ``'PER'``, de-duplicated and kept
        in order of first appearance.
    """
    person_mentions = (ent.text for ent in doc.ents if ent.label_ == 'PER')
    # dict.fromkeys removes duplicates while preserving insertion order; the
    # previous list(set(...)) returned the names in an arbitrary order that
    # could differ between runs because of string hash randomisation.
    return list(dict.fromkeys(person_mentions))

def extract_locations(doc):
    """Extract the unique location mentions found in a parsed document.

    Args:
        doc: Parsed document exposing ``ents``; every entity must provide
            ``label_`` and ``text`` (e.g. the result of ``nlp(text)``).

    Returns:
        list: Texts of the entities labelled ``'LOC'``, de-duplicated and kept
        in order of first appearance.
    """
    location_mentions = (ent.text for ent in doc.ents if ent.label_ == 'LOC')
    # dict.fromkeys removes duplicates while preserving insertion order; the
    # previous list(set(...)) returned the locations in an arbitrary order that
    # could differ between runs because of string hash randomisation.
    return list(dict.fromkeys(location_mentions))

def extract_dates(text):
    """Extract date mentions from raw text.

    Two kinds of dates are recognised:
      * numeric dates such as ``01.10.2021``, ``1/10/21`` or ``2021-10-01``
        (``.``, ``/`` or ``-`` as separators);
      * dates whose month is spelled out in Russian (genitive case) between
        two numbers, such as ``9 сентября 2021``.

    Args:
        text (str): Text to scan.

    Returns:
        list: Unique date strings. Numeric dates come first, followed by the
        spelled-out ones; within each group the order of first appearance in
        ``text`` is kept.
    """
    date_pattern = r'\b(\d{1,2}[\.\/-]\d{1,2}[\.\/-]\d{2,4}|\d{4}[\.\/-]\d{1,2}[\.\/-]\d{1,2})\b'
    date_with_text = r'(\d+\s?(января|февраля|марта|апреля|мая|июня|июля|августа|сентября|октября|ноября|декабря)\s?\d+)'
    numeric_dates = re.findall(date_pattern, text)
    # date_with_text has two groups, so findall yields (whole_date, month)
    # tuples; only the whole date is kept.
    textual_dates = [match[0] for match in re.findall(date_with_text, text)]
    # dict.fromkeys removes duplicates while preserving insertion order; the
    # previous list(set(...)) returned the dates in an arbitrary order that
    # could differ between runs because of string hash randomisation.
    return list(dict.fromkeys(numeric_dates + textual_dates))

def extract_organisations(doc):
    """Extract the unique organisation mentions found in a parsed document.

    Args:
        doc: Parsed document exposing ``ents``; every entity must provide
            ``label_`` and ``text`` (e.g. the result of ``nlp(text)``).

    Returns:
        list: Texts of the entities labelled ``'ORG'``, de-duplicated and kept
        in order of first appearance.
    """
    organisation_mentions = (ent.text for ent in doc.ents if ent.label_ == 'ORG')
    # dict.fromkeys removes duplicates while preserving insertion order; the
    # previous list(set(...)) returned the organisations in an arbitrary order
    # that could differ between runs because of string hash randomisation.
    return list(dict.fromkeys(organisation_mentions))

def extract_keywords_with_keybert(text, top_n=10):
    """Return up to ``top_n`` keywords that KeyBERT finds in the preprocessed text."""
    cleaned = preprocess_text(text)
    scored_keywords = kw_model.extract_keywords(cleaned, top_n=top_n)
    # Each entry's first element is the keyword itself; the rest is dropped.
    return [entry[0] for entry in scored_keywords]

In [None]:
def extract_metadata(text):
    """Collect the names, dates, locations, organisations and keywords found in a text.

    The text is parsed once with ``nlp`` for the entity-based fields; dates and
    keywords are derived from the raw string.

    Returns:
        dict: Keys ``names``, ``dates``, ``locations``, ``organisations`` and
        ``keywords``, each holding the result of the matching extractor.
    """
    parsed = nlp(text)
    return {
        "names": extract_names(parsed),
        "dates": extract_dates(text),
        "locations": extract_locations(parsed),
        "organisations": extract_organisations(parsed),
        "keywords": extract_keywords_with_keybert(text),
    }

# Schema for the metadata extracted from a user query. It is handed to
# model.with_structured_output, and its five list fields use the same names as
# the keys of the dict built by extract_metadata.
# The Field descriptions are presumably surfaced to the LLM as the extraction
# instructions for each field — treat them as prompt text, not as documentation.
# NOTE(review): documented with `#` comments rather than a class docstring, since
# pydantic may include a model docstring in the generated schema — confirm before adding one.
class QueryMetadata(BaseModel):
    # People / product names / measurement units mentioned in the query.
    names: list = Field(description="Extract and list all personal names, product names, and measurement units mentioned in the user's question. Include both full names and abbreviations.")
    # Full or partial dates mentioned in the query.
    dates: list = Field(description="Extract and list all dates mentioned in the user's question. Include full dates, partial dates, and any date-related numbers (e.g., '01.10.2021', '9 сентября 2021').")
    # Places mentioned in the query, in their original spelling.
    locations: list = Field(description="Extract and list all locations mentioned in the user's question, including cities, districts, regions, and any repeated mentions. For Russian locations, maintain the original spelling and format (e.g., 'ханты-манскийский округ', 'ХМАО').")
    # Organisations and official bodies mentioned in the query.
    organisations: list = Field(description="Extract and list all organizations, government bodies, institutions, and official entities mentioned in the user's question. Include both full names and abbreviations.")
    # Key terms describing the topic of the query.
    keywords: list = Field(description="Extract and list key terms that represent the main themes, topics, or concepts in the user's question. Focus on nouns and significant words that capture the essence of the query.")

        

# Chat model used to extract metadata from user queries (temperature fixed at 0).
# NOTE(review): api_key and base_url are empty strings here — presumably redacted placeholders;
# supply real values (ideally read from the environment rather than hardcoded) before running.
model = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key="", base_url="")
# Bind the QueryMetadata schema to the model's output; used by get_query_metadata.
structured_llm = model.with_structured_output(QueryMetadata)

def get_query_metadata(query):
    """Extract structured metadata from a user query with the LLM.

    Args:
        query (str): The user's question.

    Returns:
        dict: The QueryMetadata fields (``names``, ``dates``, ``locations``,
        ``organisations``, ``keywords``) mapped to the values extracted from
        the query.
    """
    result = structured_llm.invoke(query)
    # Return a plain dict rather than result.json(): the JSON text is a str,
    # and the code consuming this function iterates over `.items()` to build
    # its filters, which a string does not support.
    return result.dict()


## Intended use: filter retrieval by the metadata extracted from the query

In [None]:
# Example flow: turn the metadata extracted from a user query into a `where`
# filter and run a filtered similarity search against `collection`.
# NOTE(review): `embedding_function`, `collection` and `k` are not defined in the
# visible part of the notebook — they must be created by an earlier setup step; confirm.
query = "вопрос от пользователя"
metainfo = get_query_metadata(query)
# NOTE(review): `.items()` below requires `metainfo` to be a mapping — confirm that
# get_query_metadata returns a dict rather than a JSON string.
query_filters = {}
for key, value in metainfo.items():
    # Only non-empty metadata fields become filter conditions.
    if value:
        # NOTE(review): `value` is presumably a list here; verify that the store's `$eq`
        # operator accepts list operands and that several top-level keys are allowed in one `where`.
        query_filters[key] = {"$eq": value}

# NOTE(review): assumes embedding_function accepts a single string and returns one
# embedding (it is wrapped in a list below) — TODO confirm.
query_embeddings = embedding_function(query)

# Earlier variant without a query embedding, kept commented out:
# documents = collection.query(
#     n_results=k,
#     where=query_filters
# )

# Retrieve the k nearest documents that also satisfy the metadata filters.
result = collection.query(
    query_embeddings=[query_embeddings],
    n_results=k,
    where=query_filters
)