# Fine-tuning E5 Embeddings on Wikidata

Этот блокнот выполняет дообучение (fine-tuning) модели эмбеддингов E5 на данных Wikidata для улучшения метрик поиска.

In [1]:
import os
os.environ["WANDB_DISABLED"] = "true"  # Disable W&B logging to avoid API key error
import gc
import pickle
import sys

import torch
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from sentence_transformers.datasets import NoDuplicatesDataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Detect the runtime by probing for the Kaggle input mount point.
IS_KAGGLE = os.path.exists('/kaggle/input')
if not IS_KAGGLE:
    DATA_BASE_PATH = '../../wikidata_big'
    print("Running Locally")
else:
    DATA_BASE_PATH = '/kaggle/input/wikidata-big-sber'
    # The CUDA allocator option is only set for Kaggle runs.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    print("Running on Kaggle")

# Pick the accelerator in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

def clear_memory():
    """Run the garbage collector and release cached accelerator memory.

    Empties the CUDA cache when CUDA is available; otherwise empties the
    MPS cache when MPS is available. Always prints "Memory cleared".
    Returns None.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # CUDA takes precedence; MPS is only considered when CUDA is absent.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
    else:
        torch.cuda.empty_cache()
    print("Memory cleared")


# Start the session from a clean slate.
clear_memory()

Using device: mps
Memory cleared


In [2]:
import logging

# Route INFO-level log records to stdout so that the training loss
# shows up in the cell's output stream.
logging.basicConfig(
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)


In [3]:
# All dataset locations are resolved relative to DATA_BASE_PATH:
# the processed-data directory, then the entity and relation label files.
DATA_PATH, ENT_TEXT_PATH, REL_TEXT_PATH = (
    os.path.join(DATA_BASE_PATH, relative_path)
    for relative_path in (
        'kg/tkbc_processed_data/wikidata_big/',
        'kg/wd_id2entity_text.txt',
        'kg/wd_id2relation_text.txt',
    )
)

def load_pickle(name):
    """Unpickle and return the object stored in file `name` under DATA_PATH.

    NOTE(review): pickle can execute arbitrary code while loading, so this
    must only be pointed at trusted dataset files.
    """
    pickle_path = os.path.join(DATA_PATH, name)
    with open(pickle_path, 'rb') as handle:
        payload = pickle.load(handle)
    return payload

# Forward vocabularies loaded from the processed-data directory
# (presumably name -> integer id; the code below only relies on them being dicts).
ent_to_id, rel_to_id, ts_to_id = (
    load_pickle(vocab_file) for vocab_file in ('ent_id', 'rel_id', 'ts_id')
)

# Reverse lookups (value -> key); timestamp keys are stringified.
id_to_ent = dict(zip(ent_to_id.values(), ent_to_id.keys()))
id_to_rel = dict(zip(rel_to_id.values(), rel_to_id.keys()))
id_to_ts = {ts_index: str(ts_key) for ts_key, ts_index in ts_to_id.items()}

train_data = load_pickle('train.pickle')
print(f"Loaded {len(train_data)} training quadruplets")

# Loading of human-readable text labels
def load_text_mapping(file_path):
    """Read a tab-separated UTF-8 file into a dict of first column -> second column.

    Each line is stripped of surrounding whitespace and split on tabs.
    Lines with fewer than two fields are skipped, and columns beyond the
    second are ignored. Returns an empty dict when `file_path` does not exist.
    """
    if not os.path.exists(file_path):
        return {}

    mapping = {}
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            if len(fields) < 2:
                continue
            mapping[fields[0]] = fields[1]
    return mapping

# Load the entity labels first, then the relation labels.
ent_labels, rel_labels = map(load_text_mapping, (ENT_TEXT_PATH, REL_TEXT_PATH))

print(f"Loaded {len(ent_labels)} entity labels and {len(rel_labels)} relation labels")

def get_label(id_str, mapping, default_val):
    """Return the entry for `id_str` in `mapping`, or `default_val` when the key is absent."""
    if id_str in mapping:
        return mapping[id_str]
    return default_val

def format_ts(ts_tuple):
    """Render a (year, month, day) tuple as text.

    Returns just the year when both month and day are 0; otherwise returns
    "year-MM-DD" with month and day zero-padded to two digits.
    """
    year, month, day = ts_tuple
    has_month_or_day = month != 0 or day != 0
    if has_month_or_day:
        return f"{year}-{month:02d}-{day:02d}"
    return str(year)

# Reverse timestamp lookup (value -> original key), keeping keys un-stringified.
ts_to_id_inv = dict(zip(ts_to_id.values(), ts_to_id.keys()))

Loaded 322988 training quadruplets
Loaded 125725 entity labels and 203 relation labels


In [4]:
# Build the training pairs (query, positive answer).
train_samples = []

limit = 50000  # Only the first `limit` rows of train_data are turned into samples
for row in tqdm(train_data[:limit], desc="Preparing samples"):
    # Map the row's integer indices back to their original keys.
    s_id = id_to_ent[row[0]]
    r_id = id_to_rel[row[1]]
    o_id = id_to_ent[row[2]]
    t_tuple = ts_to_id_inv[row[3]]

    # Prefer human-readable labels; fall back to the raw identifier when none exists.
    s_label = ent_labels.get(s_id, s_id)
    r_label = rel_labels.get(r_id, r_id)
    o_label = ent_labels.get(o_id, o_id)
    t_label = format_ts(t_tuple)

    # Task 1: predict the object given time, subject and relation.
    query_obj = f"query: {t_label}: {s_label} {r_label}"
    target_obj = f"passage: {o_label}"

    # Task 2: predict the time given the full (subject, relation, object) fact.
    query_time = f"query: When did {s_label} {r_label} {o_label}?"
    target_time = f"passage: {t_label}"

    # Each row contributes two examples, object task first.
    train_samples.extend([
        InputExample(texts=[query_obj, target_obj]),
        InputExample(texts=[query_time, target_time]),
    ])

print(f"Created {len(train_samples)} InputExamples with robust temporal tasks")

Preparing samples:   0%|          | 0/20000 [00:00<?, ?it/s]

Created 40000 InputExamples (Object + Time tasks)


In [None]:
# Base checkpoint to fine-tune, assembled from explicit modules.
model_name = 'intfloat/multilingual-e5-small'
word_embedding_model = models.Transformer(model_name)
# Pooling is left at the library default mode.
# NOTE(review): confirm the default matches the pooling the base E5 model expects.
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

# MultipleNegativesRankingLoss treats every other passage in a batch as a
# negative for a given query. The training pairs contain many identical
# passages (e.g. the same "passage: <year>" for different facts), so a plain
# shuffled DataLoader would regularly put the correct answer of one query
# among the "negatives" of another. NoDuplicatesDataLoader builds batches in
# which no text occurs twice, which removes these false negatives.
train_dataloader = NoDuplicatesDataLoader(train_samples, batch_size=16)
train_loss = losses.MultipleNegativesRankingLoss(model)

num_epochs = 3  # Number of full passes over the training pairs
# Warm up the learning rate over the first 10% of all training steps.
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

clear_memory()
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    checkpoint_path='./checkpoints',
    checkpoint_save_total_limit=1,
    show_progress_bar=True
)

# Save the final model
OUTPUT_PATH = '/kaggle/working/wikidata_finetuned' if IS_KAGGLE else '../../models/wikidata_finetuned'
model.save(OUTPUT_PATH)
print(f"Training finished! Model saved to {OUTPUT_PATH}")

2026-01-27 17:37:45 - HTTP Request: HEAD https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/adapter_config.json "HTTP/1.1 404 Not Found"
2026-01-27 17:37:45 - HTTP Request: HEAD https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/config.json "HTTP/1.1 307 Temporary Redirect"
2026-01-27 17:37:45 - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/intfloat/multilingual-e5-small/c007d7ef6fd86656326059b28395a7a03a7c5846/config.json "HTTP/1.1 200 OK"


Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: intfloat/multilingual-e5-small
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


2026-01-27 17:37:46 - HTTP Request: HEAD https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/config.json "HTTP/1.1 307 Temporary Redirect"
2026-01-27 17:37:46 - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/intfloat/multilingual-e5-small/c007d7ef6fd86656326059b28395a7a03a7c5846/config.json "HTTP/1.1 200 OK"
2026-01-27 17:37:46 - HTTP Request: HEAD https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/tokenizer_config.json "HTTP/1.1 307 Temporary Redirect"
2026-01-27 17:37:46 - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/intfloat/multilingual-e5-small/c007d7ef6fd86656326059b28395a7a03a7c5846/tokenizer_config.json "HTTP/1.1 200 OK"
2026-01-27 17:37:46 - HTTP Request: HEAD https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/tokenizer_config.json "HTTP/1.1 307 Temporary Redirect"
2026-01-27 17:37:46 - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/intfloat/multilingual-e5-small/c007d7

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]



Step,Training Loss


KeyboardInterrupt: 