In [1]:
!pip install langchain langchain-community faiss-gpu sentence-transformers

Collecting langchain
  Downloading langchain-0.3.11-py3-none-any.whl.metadata (7.1 kB)
Collecting langchain-community
  Downloading langchain_community-0.3.11-py3-none-any.whl.metadata (2.9 kB)
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting sentence-transformers
  Downloading sentence_transformers-3.3.1-py3-none-any.whl.metadata (10 kB)
Collecting langchain-core<0.4.0,>=0.3.24 (from langchain)
  Downloading langchain_core-0.3.24-py3-none-any.whl.metadata (6.3 kB)
Collecting langchain-text-splitters<0.4.0,>=0.3.0 (from langchain)
  Downloading langchain_text_splitters-0.3.2-py3-none-any.whl.metadata (2.3 kB)
Collecting langsmith<0.3,>=0.1.17 (from langchain)
  Downloading langsmith-0.2.3-py3-none-any.whl.metadata (14 kB)
Collecting httpx-sse<0.5.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from la

In [2]:
import os
# Lazily enumerate every file below the Kaggle input mount and print its
# full path, one per line.
input_paths = (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_paths:
    print(path)

/kaggle/input/ruwiki-tables-and-lists/wiki_tables_and_lists.jsonl
/kaggle/input/ruwiki-valid-tables/wiki_dump.jsonl
/kaggle/input/bert2bert-4b-300e/results/runs/Dec11_01-51-37_075c355daadc/events.out.tfevents.1733881900.075c355daadc.23.0
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/config.json
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/trainer_state.json
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/training_args.bin
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/tokenizer.json
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/tokenizer_config.json
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/scheduler.pt
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/model.safetensors
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/special_tokens_map.json
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/optimizer.pt
/kaggle/input/bert2bert-4b-300e/results/checkpoint-9912/vocab.txt
/kaggle/input/bert2bert-4b-300e/result

In [4]:
import os
import json
from typing import List
from tqdm import tqdm
import torch
from torch import nn

from langchain.docstore.document import Document as LangchainDocument
from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from transformers import AutoTokenizer, AutoModel
from pydantic import PrivateAttr
import re

# Parameters
# Hugging Face model id; used to load both the tokenizer and the encoder.
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# JSONL input file: each non-empty line is parsed as one JSON table record.
INPUT_FILE = "/kaggle/input/ruwiki-valid-tables/wiki_dump.jsonl"
# Directory that receives the FAISS checkpoint and the resume offset file.
OUTPUT_DIR = "/kaggle/working"
# Name of the checkpoint path (joined onto OUTPUT_DIR) for the FAISS index.
CHECKPOINT_NAME = "faiss_tables_index_checkpoint"
# Number of documents buffered before they are added to the index in one call.
BATCH_SIZE = 1000
# A checkpoint (index + offset) is written every CHECKPOINT_INTERVAL batches.
CHECKPOINT_INTERVAL = 10
# When True, an existing checkpoint / offset file is reused instead of starting over.
RESUME_FROM_CHECKPOINT = True

class MultiGPUHuggingFaceEmbeddings(HuggingFaceEmbeddings):
    """Embeddings backend that mean-pools a Hugging Face encoder, optionally on several GPUs.

    The class loads its own ``AutoTokenizer`` / ``AutoModel`` pair and overrides
    ``embed_documents`` / ``embed_query`` so that inference runs through that
    model (wrapped in ``nn.DataParallel`` when more than one GPU is available).

    NOTE(review): ``super().__init__`` is still called with the same model name,
    so the parent class presumably loads its own copy of the model as well —
    confirm whether that duplicate load is acceptable for memory usage.

    Args:
        model_name: Hugging Face model id for the tokenizer and the encoder.
        device: Target device. ``None`` selects CUDA when available, else CPU.
            An explicit non-CUDA device (e.g. ``"cpu"``) is honoured even when
            several GPUs are present.
        half: Convert the model to fp16 when it lives on a CUDA device.
        batch_size: Number of texts encoded per forward pass.
        max_length: Token limit passed to the tokenizer (longer inputs are truncated).
        **kwargs: Forwarded to ``HuggingFaceEmbeddings``. ``encode_kwargs`` is
            additionally inspected for ``normalize_embeddings`` (default ``True``).

    Raises:
        ValueError: If ``batch_size`` or ``max_length`` is not a positive integer.
    """

    # Private (non-pydantic-field) state; populated in __init__.
    _tokenizer: AutoTokenizer = PrivateAttr()
    _model: nn.Module = PrivateAttr()
    _device: str = PrivateAttr()
    _half: bool = PrivateAttr()
    _normalize_embeddings: bool = PrivateAttr(default=False)
    _batch_size: int = PrivateAttr(default=256)
    _max_length: int = PrivateAttr(default=512)

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME, device: str = None, half: bool = True,
                 batch_size: int = 256, max_length: int = 512, **kwargs):
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if max_length < 1:
            raise ValueError(f"max_length must be a positive integer, got {max_length}")

        super().__init__(model_name=model_name, **kwargs)

        # `encode_kwargs` may be absent or explicitly None; normalisation is on by default.
        encode_kwargs = kwargs.get("encode_kwargs") or {}
        self._normalize_embeddings = encode_kwargs.get("normalize_embeddings", True)
        self._batch_size = batch_size
        self._max_length = max_length

        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._model = AutoModel.from_pretrained(model_name)
        self._model.eval()

        # Use every GPU only when the caller did not explicitly ask for a
        # non-CUDA device; otherwise respect the requested device.
        cuda_allowed = (not device) or ("cuda" in str(device))
        if cuda_allowed and torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs for inference...")
            self._model = nn.DataParallel(self._model)
            self._device = 'cuda'
        else:
            self._device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        self._model.to(self._device)

        # fp16 is only applied on CUDA devices.
        if half and 'cuda' in self._device:
            self._model.half()

        # Stores the requested flag (not whether fp16 was actually applied).
        self._half = half

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with attention-masked mean pooling.

        Args:
            texts: Strings to embed.

        Returns:
            One embedding (list of floats) per input text, in input order.
            Vectors are L2-normalised when normalisation is enabled; an all-zero
            vector stays all-zero instead of turning into NaN.
        """
        texts = list(texts)
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self._batch_size):
            batch_texts = texts[start:start + self._batch_size]
            inputs = self._tokenizer(
                batch_texts,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=self._max_length,
            )
            inputs = {k: v.to(self._device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self._model(**inputs)
                last_hidden_state = outputs.last_hidden_state
                attention_mask = inputs['attention_mask']

                # Mean pooling: average token states, ignoring padding positions.
                # The mask is float32, so the sums are accumulated in float32
                # even when the model itself runs in fp16.
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
                sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
                sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                pooled = (sum_embeddings / sum_mask).float()

                if self._normalize_embeddings:
                    # `eps` guards the division so a zero-norm row cannot produce NaN.
                    pooled = torch.nn.functional.normalize(pooled, p=2, dim=1, eps=1e-12)

            embeddings.extend(pooled.cpu().numpy().tolist())
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string using the same pipeline as ``embed_documents``."""
        return self.embed_documents([text])[0]

def create_empty_faiss_index(embedding_model, distance_strategy=DistanceStrategy.COSINE):
    """Build an empty LangChain ``FAISS`` store sized for ``embedding_model``.

    Args:
        embedding_model: Object exposing ``embed_documents``; it is called once
            on a throwaway string to discover the embedding dimensionality.
        distance_strategy: ``DistanceStrategy.COSINE`` selects a flat
            inner-product index; any other value selects a flat L2 index.

    Returns:
        A ``FAISS`` vector store with no documents in it.
    """
    import faiss
    from langchain.docstore.in_memory import InMemoryDocstore

    # Probe the model with a dummy text purely to learn the vector size.
    probe_vector = embedding_model.embed_documents(["hello"])[0]
    vector_dim = len(probe_vector)

    # Inner product equals cosine similarity only for unit-length vectors.
    wants_inner_product = distance_strategy == DistanceStrategy.COSINE
    flat_index = faiss.IndexFlatIP(vector_dim) if wants_inner_product else faiss.IndexFlatL2(vector_dim)

    return FAISS(
        index=flat_index,
        docstore=InMemoryDocstore({}),
        index_to_docstore_id={},
        embedding_function=embedding_model,
        distance_strategy=distance_strategy,
    )

def create_or_load_faiss_index(embedding_model):
    """Return the FAISS store to index into, resuming from a checkpoint if allowed.

    When ``RESUME_FROM_CHECKPOINT`` is set and the checkpoint path under
    ``OUTPUT_DIR`` exists, the store is loaded from disk; otherwise an empty
    store is created.

    Both branches use ``DistanceStrategy.COSINE``. The strategy is passed
    explicitly on load so that a resumed store is configured the same way as a
    freshly created one, instead of falling back to the library default.

    Args:
        embedding_model: Embeddings object attached to the returned store.

    Returns:
        A LangChain ``FAISS`` vector store.
    """
    checkpoint_path = os.path.join(OUTPUT_DIR, CHECKPOINT_NAME)
    if RESUME_FROM_CHECKPOINT and os.path.exists(checkpoint_path):
        print(f"Loading existing FAISS index from {checkpoint_path}")
        # SECURITY: load_local unpickles the stored docstore. That is acceptable
        # only because the checkpoint is written by this notebook itself; never
        # point this at a checkpoint obtained from an untrusted source.
        return FAISS.load_local(
            checkpoint_path,
            embedding_model,
            allow_dangerous_deserialization=True,
            distance_strategy=DistanceStrategy.COSINE,
        )
    return create_empty_faiss_index(embedding_model, distance_strategy=DistanceStrategy.COSINE)

def save_faiss_index(vectorstore):
    """Persist ``vectorstore`` to the checkpoint path under ``OUTPUT_DIR`` and report it.

    Args:
        vectorstore: Store exposing ``save_local(path)``.
    """
    target_path = os.path.join(OUTPUT_DIR, CHECKPOINT_NAME)
    vectorstore.save_local(target_path)
    print(f"Checkpoint saved at {target_path}")

def _skip_processed_entries(handle, count):
    """Advance ``handle`` past ``count`` non-blank lines; return how many were skipped.

    Blank lines are not counted, mirroring the indexing loop, which ignores
    them without incrementing the processed counter.
    """
    skipped = 0
    while skipped < count:
        raw_line = handle.readline()
        if not raw_line:  # EOF reached before `count` entries were seen
            break
        if raw_line.strip():
            skipped += 1
    return skipped


def _context_to_text(context_before):
    """Flatten ``context_before`` into newline-separated text.

    Handles a plain string, or an iterable whose items are strings or
    iterables of strings.
    """
    if isinstance(context_before, str):
        return context_before
    return '\n'.join(
        part if isinstance(part, str) else '\n'.join(part)
        for part in context_before
    )


def _row_to_text(row):
    """Render one table row as ``" | "``-separated cell text."""
    row = row or []
    # Rows whose first cell is a list take the second item of every cell.
    # presumably cells are [meta, text] pairs — verify against the dataset schema.
    values = [cell[1] for cell in row] if row and isinstance(row[0], list) else row
    # Coerce to str so numeric or null cells cannot break the join.
    return " | ".join("" if value is None else str(value) for value in values)


def _table_entry_to_document(entry):
    """Convert one parsed JSONL record into a ``LangchainDocument``."""
    raw_uuid = entry.get("uuid", "")
    context_before = entry.get("context_before", "") or ""
    caption = entry.get("caption", "") or ""
    header = entry.get("header", []) or []
    data = entry.get("data", []) or []

    # Keep only Cyrillic letters and underscores from the uuid, then turn the
    # underscores into spaces.
    uuid_text = re.sub(r'[^А-Яа-яЁё_]+', '', raw_uuid or "")
    uuid_text = uuid_text.replace('_', ' ')

    context_text = _context_to_text(context_before)
    header_names = [col["name"] for col in header]
    header_line = " | ".join("" if name is None else str(name) for name in header_names)
    table_text = "\n".join(_row_to_text(row) for row in data)

    full_text = f"{uuid_text}\n{context_text}\n{caption}\n{header_line}\n{table_text}".strip()

    return LangchainDocument(
        page_content=full_text,
        metadata={
            "uuid": raw_uuid,
            "header": header_names
        }
    )


def _write_offset(offset_file, processed_tables_count):
    """Record how many table entries have been indexed so far."""
    with open(offset_file, "w") as f_off:
        f_off.write(str(processed_tables_count))


def main():
    """Embed every table in ``INPUT_FILE`` into a FAISS index, with resumable checkpoints.

    Tables are buffered in batches of ``BATCH_SIZE``. Every
    ``CHECKPOINT_INTERVAL`` batches the index and the processed-entry counter
    (``offset.txt`` in ``OUTPUT_DIR``) are saved. With
    ``RESUME_FROM_CHECKPOINT`` set, a later run skips the already processed
    entries and continues from there.
    """
    embedding_model = MultiGPUHuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        half=True,
        model_kwargs={"device": "cuda"},
        encode_kwargs={"normalize_embeddings": True},
        multi_process=False
    )

    vectorstore = create_or_load_faiss_index(embedding_model)

    processed_tables_count = 0
    offset_file = os.path.join(OUTPUT_DIR, "offset.txt")
    if RESUME_FROM_CHECKPOINT and os.path.exists(offset_file):
        with open(offset_file, "r") as f:
            processed_tables_count = int(f.read().strip())
        print(f"Resuming from table #{processed_tables_count}")

    # Each table contributes one vector, so the index size should equal the
    # stored offset. A mismatch means the index and offset.txt came from
    # different points in time, and resuming would duplicate or drop tables.
    indexed_vectors = vectorstore.index.ntotal
    if indexed_vectors != processed_tables_count:
        print(
            f"WARNING: index holds {indexed_vectors} vectors but offset is "
            f"{processed_tables_count}; checkpoint and offset file are out of sync."
        )

    docs_batch = []
    batch_counter = 0

    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        # The offset counts non-blank entries, so skip by entries rather than raw lines.
        skipped = _skip_processed_entries(f, processed_tables_count)
        if skipped < processed_tables_count:
            print(
                f"WARNING: input ended after {skipped} entries while skipping "
                f"{processed_tables_count}; nothing left to process."
            )

        for line in tqdm(f, desc="Processing tables"):
            line = line.strip()
            if not line:
                continue

            docs_batch.append(_table_entry_to_document(json.loads(line)))
            processed_tables_count += 1

            if len(docs_batch) >= BATCH_SIZE:
                vectorstore.add_documents(docs_batch)
                docs_batch = []
                batch_counter += 1

                if batch_counter % CHECKPOINT_INTERVAL == 0:
                    save_faiss_index(vectorstore)
                    _write_offset(offset_file, processed_tables_count)
                    print(f"Processed {processed_tables_count} tables so far.")

    # Flush the final, partially filled batch.
    if docs_batch:
        vectorstore.add_documents(docs_batch)

    save_faiss_index(vectorstore)
    _write_offset(offset_file, processed_tables_count)
    print(f"Final processed tables: {processed_tables_count}")

# Run the indexing pipeline when this cell/script is executed as the main module.
if __name__ == "__main__":
    main()


Using 2 GPUs for inference...
Loading existing FAISS index from /kaggle/working/faiss_tables_index_checkpoint
Resuming from table #30000


Processing tables: 10905it [00:29, 414.06it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 40000 tables so far.


Processing tables: 20927it [00:58, 415.64it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 50000 tables so far.


Processing tables: 30804it [01:27, 403.46it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 60000 tables so far.


Processing tables: 40678it [01:57, 352.71it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 70000 tables so far.


Processing tables: 50914it [02:28, 341.48it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 80000 tables so far.


Processing tables: 60581it [02:59, 341.33it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 90000 tables so far.


Processing tables: 70665it [03:30, 332.62it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 100000 tables so far.


Processing tables: 80749it [04:01, 346.95it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 110000 tables so far.


Processing tables: 90907it [04:34, 324.05it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 120000 tables so far.


Processing tables: 100858it [05:05, 313.33it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 130000 tables so far.


Processing tables: 110354it [05:38, 198.95it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 140000 tables so far.


Processing tables: 120339it [06:09, 227.51it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 150000 tables so far.


Processing tables: 130598it [06:42, 292.96it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 160000 tables so far.


Processing tables: 140794it [07:14, 283.28it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 170000 tables so far.


Processing tables: 150632it [07:48, 254.95it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 180000 tables so far.


Processing tables: 160689it [08:20, 265.65it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 190000 tables so far.


Processing tables: 170850it [08:52, 272.36it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 200000 tables so far.


Processing tables: 180611it [09:24, 261.26it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 210000 tables so far.


Processing tables: 190463it [09:58, 230.15it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 220000 tables so far.


Processing tables: 200147it [10:32, 182.36it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 230000 tables so far.


Processing tables: 210638it [11:06, 251.12it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 240000 tables so far.


Processing tables: 220345it [11:41, 149.98it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 250000 tables so far.


Processing tables: 230541it [12:15, 213.97it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 260000 tables so far.


Processing tables: 240614it [12:48, 241.59it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 270000 tables so far.


Processing tables: 250860it [13:25, 224.82it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 280000 tables so far.


Processing tables: 260706it [14:02, 235.40it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 290000 tables so far.


Processing tables: 270638it [14:40, 207.28it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 300000 tables so far.


Processing tables: 280682it [15:16, 212.32it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 310000 tables so far.


Processing tables: 290514it [15:53, 186.60it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 320000 tables so far.


Processing tables: 300308it [16:30, 118.34it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 330000 tables so far.


Processing tables: 310835it [17:06, 210.30it/s]

Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Processed 340000 tables so far.


Processing tables: 312658it [17:12, 302.77it/s]


Checkpoint saved at /kaggle/working/faiss_tables_index_checkpoint
Final processed tables: 342658
