# Load documents and model

In [None]:
import pandas as pd
train_documents = pd.read_csv("../datasets/documents_train.csv", index_col=0)

from sentence_transformers import SentenceTransformer
sentences = ["This is an example sentence", "Each sentence is converted"]

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embeddings = model.encode(sentences)
print(embeddings)


  from .autonotebook import tqdm as notebook_tqdm


[[ 6.76568747e-02  6.34958595e-02  4.87131178e-02  7.93049559e-02
   3.74480337e-02  2.65277526e-03  3.93749252e-02 -7.09843542e-03
   5.93614802e-02  3.15370075e-02  6.00980483e-02 -5.29051833e-02
   4.06067744e-02 -2.59308163e-02  2.98427958e-02  1.12689717e-03
   7.35149533e-02 -5.03819808e-02 -1.22386612e-01  2.37028506e-02
   2.97265295e-02  4.24768887e-02  2.56337821e-02  1.99517887e-03
  -5.69191203e-02 -2.71598324e-02 -3.29035446e-02  6.60248548e-02
   1.19007103e-01 -4.58791442e-02 -7.26214871e-02 -3.25839780e-02
   5.23413755e-02  4.50552665e-02  8.25300161e-03  3.67023759e-02
  -1.39415069e-02  6.53919056e-02 -2.64272504e-02  2.06405064e-04
  -1.36643583e-02 -3.62809747e-02 -1.95043441e-02 -2.89738420e-02
   3.94270532e-02 -8.84090886e-02  2.62426329e-03  1.36713851e-02
   4.83063050e-02 -3.11565753e-02 -1.17329188e-01 -5.11689894e-02
  -8.85287374e-02 -2.18962021e-02  1.42986281e-02  4.44167629e-02
  -1.34814661e-02  7.43392557e-02  2.66382862e-02 -1.98762454e-02
   1.79190

# Simple embedding of entire document

In [None]:
from tqdm import tqdm

for i in tqdm(range(100)):
    embeddings = model.encode(sentences)

100%|██████████| 100/100 [00:00<00:00, 115.29it/s]


In [19]:
tqdm.pandas()
train_documents["embedding"] = train_documents.body.progress_apply(lambda x: model.encode(x))  

100%|██████████| 35885/35885 [09:36<00:00, 62.22it/s]


In [20]:
train_documents.to_csv("../datasets/documents_train_embedded.csv")

# Cut docs to chunks and build FAISS index

In [None]:
import pandas as pd
import numpy as np
import torch
import faiss
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm
import pickle
import json
from typing import List, Dict, Tuple



class VectorSearchPreprocessor:
    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        """
        Инициализация модели и токенизатора
        """
        print(f"Загрузка модели {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()
        
        self.max_tokens = 256
        self.passage_tokens = 128
        self.overlap_tokens = 32
        
    def split_into_passages(self, text: str, doc_id: str, title: str = "") -> List[Dict]:
        """
        Разбивает текст на перекрывающиеся пассажи
        """
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        
        passages = []
        start = 0
        
        title_tokens = []
        if title:
            title_tokens = self.tokenizer.encode(title, add_special_tokens=False)
        
        while start < len(tokens):
            end = min(start + self.passage_tokens, len(tokens))
            
            if start == 0 and title_tokens:
                available_tokens = self.passage_tokens - len(title_tokens) - 2  # -2 для [CLS] и [SEP]
                passage_tokens = title_tokens + tokens[start:start + available_tokens]
                start += available_tokens
            else:
                passage_tokens = tokens[start:end]
                start = end - self.overlap_tokens  # перекрытие
            
            passage_text = self.tokenizer.decode(passage_tokens)
            char_start = self._find_char_position(text, tokens, start if start > 0 else 0)
            char_end = self._find_char_position(text, tokens, end)
            
            passages.append({
                'passage_tokens': passage_tokens,
                'passage_text': passage_text,
                'char_start': char_start,
                'char_end': char_end,
                'doc_id': doc_id,
                'title': title,
                'token_start': start if start > 0 else 0,
                'token_end': end
            })
            
            if end == len(tokens):
                break
        
        return passages
    
    def _find_char_position(self, text: str, tokens: List[int], token_idx: int) -> int:
        """
        Находит позицию символа по индексу токена
        """
        if token_idx >= len(tokens):
            return len(text)
        
        decoded = self.tokenizer.decode(tokens[:token_idx])
        return len(decoded)
    
    def generate_embeddings(self, passages: List[Dict], batch_size: int = 32) -> np.ndarray:
        """
        Генерирует эмбеддинги для пассажей
        Возвращает нормализованные векторы для косинусного сходства
        """
        all_embeddings = []
        
        print(f"Генерация эмбеддингов для {len(passages)} пассажей...")
        
        for i in tqdm(range(0, len(passages), batch_size)):
            batch = passages[i:i + batch_size]
            batch_texts = [p['passage_text'] for p in batch]
            
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.passage_tokens + 10,  # небольшой запас
                return_tensors='pt'
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                attention_mask = inputs['attention_mask']
                token_embeddings = outputs.last_hidden_state
                
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                embeddings = sum_embeddings / sum_mask
                
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                
                all_embeddings.append(embeddings.cpu().numpy())
        
        return np.vstack(all_embeddings)
    
    def process_dataframe(self, df: pd.DataFrame, text_column: str = 'body', 
                         title_column: str = 'title', id_column: str = 'doc_id') -> Tuple[pd.DataFrame, np.ndarray]:
        """
        Обрабатывает весь DataFrame
        Возвращает DataFrame с пассажами и матрицу эмбеддингов
        """
        all_passages = []
        
        print("Разбиение документов на пассажи...")
        for idx, row in tqdm(df.iterrows(), total=len(df)):
            doc_id = row[id_column]
            text = row[text_column]
            title = str(row[title_column])
            
            passages = self.split_into_passages(text, doc_id, title)
            all_passages.extend(passages)
        
        passages_df = pd.DataFrame(all_passages)
        
        embeddings = self.generate_embeddings(all_passages)
        
        passages_df['passage_id'] = range(len(passages_df))
        
        print(f"\nСтатистика:")
        print(f"Всего документов: {len(df)}")
        print(f"Всего пассажей: {len(passages_df)}")
        print(f"Среднее пассажей на документ: {len(passages_df) / len(df):.1f}")
        print(f"Размер эмбеддингов: {embeddings.shape}")
        
        return passages_df, embeddings
    
    def create_faiss_index(self, embeddings: np.ndarray) -> faiss.Index:
        """
        Создает FAISS индекс для поиска
        """
        dimension = embeddings.shape[1]
        
        index = faiss.IndexFlatIP(dimension)
        
        embeddings = embeddings.astype('float32')
        
        index.add(embeddings)
        
        print(f"Создан FAISS индекс с {index.ntotal} векторами")
        return index
    
    def save_data(self, passages_df: pd.DataFrame, embeddings: np.ndarray, 
                 index: faiss.Index, output_dir: str = './vector_search_data'):
        """
        Сохраняет все данные на диск
        """
        import os
        os.makedirs(output_dir, exist_ok=True)
        
        passages_df.to_parquet(f'{output_dir}/passages.parquet')
        np.save(f'{output_dir}/embeddings.npy', embeddings)
        faiss.write_index(index, f'{output_dir}/faiss_index.bin')
        
        config = {
            'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
            'passage_tokens': self.passage_tokens,
            'overlap_tokens': self.overlap_tokens,
            'max_tokens': self.max_tokens,
            'total_passages': len(passages_df),
            'embedding_dim': embeddings.shape[1]
        }
        
        with open(f'{output_dir}/config.json', 'w') as f:
            json.dump(config, f, indent=2)
        
        print(f"Данные сохранены в {output_dir}")
        
        return output_dir



In [None]:
df = pd.DataFrame(train_documents)

preprocessor = VectorSearchPreprocessor()

passages_df, embeddings = preprocessor.process_dataframe(df)

index = preprocessor.create_faiss_index(embeddings)

output_dir = preprocessor.save_data(passages_df, embeddings, index)

print(passages_df[['doc_id', 'passage_text', 'char_start', 'char_end']].head())

# Test

In [3]:
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [7]:
index = faiss.read_index(f"../datasets/vector_search_data/faiss_index.bin")

# Загружаем метаданные пассажей
passages_df = pd.read_parquet(f"../datasets/vector_search_data/passages.parquet")

In [30]:
embeddings = model.encode(["test test test"])

In [60]:
index.search(embeddings, k=2)[1][0].tolist()

[530727, 530721]

In [45]:
passages_df.iloc[530721]

passage_tokens    [1997, 1006, 1037, 9415, 1007, 3231, 1006, 120...
passage_text      of ( a substance ) test ( verb ) undergo a tes...
char_start                                                     2293
char_end                                                       2392
doc_id                                                      D391809
title                                     Definitions &Translations
token_start                                                     507
token_end                                                       539
passage_id                                                   530721
Name: 530721, dtype: object