In [None]:
import os
import pandas as pd
import torch
import numpy as np
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import (
    CanineTokenizer,
    CanineForTokenClassification,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset, Dataset, concatenate_datasets
from tqdm.auto import tqdm
from typing import List, Tuple
import shutil

In [None]:
# =======================
# Constants and setup
# =======================
MODEL_NAME = "google/canine-c"
DATASET_NAME = "dgramus/synth-ecom-search-queries"
AUGMENTED_DATA_FILE = "augmented_test_data.csv"
AUGMENTED_DATA_FILE_TXT = "augmented_add_train.txt"
BEST_MODEL_DIR = "./final_corrected_model"

# Binary per-character labels plus the id <-> name lookup tables.
LABEL_LIST = ["NO_SPACE", "SPACE"]
ID_TO_LABEL = dict(enumerate(LABEL_LIST))
LABEL_TO_ID = {name: idx for idx, name in ID_TO_LABEL.items()}

# Hyper-parameters. CHAR_WINDOW is the number of characters handled per
# sequence: two of the MAX_LENGTH positions are left out of the window.
MAX_LENGTH = 128
CHAR_WINDOW = MAX_LENGTH - 2
BATCH_SIZE = 16
NUM_EPOCHS = 6
LEARNING_RATE = 3e-5
SEED = 42

# Always start from an empty output directory: anything left there by a
# previous run is deleted before the directory is recreated.
if os.path.exists(BEST_MODEL_DIR):
    shutil.rmtree(BEST_MODEL_DIR)
os.makedirs(BEST_MODEL_DIR, exist_ok=True)

In [None]:
# =======================
# Фиксируем сиды и устройство
# =======================
def set_seed(seed: int):
    """Seed every random number generator this notebook touches.

    Seeds PyTorch (``manual_seed`` and ``cuda.manual_seed_all``), Python's
    ``random`` module and NumPy's global generator, exports
    ``PYTHONHASHSEED`` and sets the cuDNN flags ``deterministic=True`` /
    ``benchmark=False``.

    Args:
        seed: The seed value applied to all generators.
    """
    import random  # not imported at module level, so pulled in locally

    # PyTorch generators and cuDNN flags.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Interpreter-level and NumPy generators.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

set_seed(SEED)

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Используется устройство: {device}")

Используется устройство: mps


In [None]:
# =======================
# ЕДИНАЯ И ПРАВИЛЬНАЯ ЛОГИКА РАСЧЕТА ПОЗИЦИЙ
# =======================
def get_true_positions(no_space_text: str, corrected_text: str) -> List[int]:
    """Locate the word boundaries of ``corrected_text`` inside ``no_space_text``.

    Both inputs are lower-cased. All whitespace is removed from
    ``no_space_text``; whitespace runs in ``corrected_text`` are collapsed to
    single spaces (leading/trailing whitespace is dropped). The corrected text
    is then walked character by character alongside a cursor into the
    whitespace-free text: a space records the current cursor value (the index
    of the character a space should precede), any other character advances
    the cursor, whether or not the two characters are equal.

    Args:
        no_space_text: The query without (or with stray) whitespace.
        corrected_text: The reference query with spaces.

    Returns:
        Boundary indices into the whitespace-free text, in order of
        occurrence. Empty if either argument is not a ``str``.
    """
    if not (isinstance(no_space_text, str) and isinstance(corrected_text, str)):
        return []

    compact = "".join(no_space_text.split()).lower()
    spaced = " ".join(corrected_text.split()).lower()
    compact_len = len(compact)

    boundaries: List[int] = []
    cursor = 0  # index of the next unconsumed character in `compact`
    for ch in spaced:
        if cursor >= compact_len:
            # The whitespace-free text is exhausted; nothing more to align.
            break
        if ch != ' ':
            # Matching or not, a non-space character consumes one position.
            cursor += 1
        elif cursor > 0:
            # A space cannot be placed before the very first character.
            boundaries.append(cursor)
    return boundaries

def compute_f1_metric(all_true_positions: List[List[int]], all_pred_positions: List[List[int]]) -> float:
    """Average per-example F1 between true and predicted space positions.

    Each example's positions are compared as sets. An example with no true
    positions and no predicted positions is a perfect prediction and scores
    1.0; previously precision was forced to 0.0 for an empty prediction, which
    made such examples score 0.0 even though recall was special-cased to 1.0.

    Args:
        all_true_positions: One list of reference positions per example.
        all_pred_positions: One list of predicted positions per example,
            aligned with ``all_true_positions`` (pairs are formed with
            ``zip``, so extra entries in the longer list are ignored).

    Returns:
        The mean of the per-example F1 scores, or 0.0 when there are no
        examples.
    """
    f1_scores = []
    for true_pos, pred_pos in zip(all_true_positions, all_pred_positions):
        true_set, pred_set = set(true_pos), set(pred_pos)

        if not true_set and not pred_set:
            # Nothing to find and nothing predicted: exact match.
            f1_scores.append(1.0)
            continue

        tp = len(true_set & pred_set)
        precision = tp / len(pred_set) if pred_set else 0.0
        recall = tp / len(true_set) if true_set else 0.0
        denom = precision + recall
        f1_scores.append((2 * precision * recall) / denom if denom > 0 else 0.0)

    return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0

def restore_spaces(text: str, model, tokenizer, device) -> List[int]:
    """Predict the positions at which spaces should be inserted into ``text``.

    The text is lower-cased and stripped of all whitespace, then processed in
    consecutive chunks of ``CHAR_WINDOW`` characters. Each chunk is tokenized
    (padded/truncated to ``MAX_LENGTH``), run through ``model`` and the
    arg-max label of every character is inspected.

    Args:
        text: Input query; ``None`` or empty input yields an empty result.
        model: Token-classification model whose output exposes ``.logits``.
        tokenizer: Tokenizer callable compatible with ``model``.
        device: Device the encoded inputs are moved to.

    Returns:
        Indices into the whitespace-free text before which a space is
        predicted. Index 0 is never returned.
    """
    model.eval()
    # Remove *all* whitespace (tabs, newlines, NBSP, ...), not only ' ', so the
    # indices refer to the same string that get_true_positions and
    # preprocess_function build with str.split().
    clean_text = "".join((text or "").split()).lower()
    if not clean_text:
        return []

    space_id = LABEL_TO_ID["SPACE"]
    space_positions: List[int] = []
    for offset in range(0, len(clean_text), CHAR_WINDOW):
        chunk = clean_text[offset : offset + CHAR_WINDOW]
        encoding = tokenizer(chunk, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(device)
        with torch.no_grad():
            preds = torch.argmax(model(**encoding).logits, dim=-1).squeeze(0)
        # Sequence slot 0 is skipped (the training labels mark it as ignored
        # too); the following len(chunk) slots are the per-character labels.
        preds_chars = preds[1 : 1 + len(chunk)].cpu().tolist()
        for i, label in enumerate(preds_chars):
            # A space can never precede the very first character of the text.
            if (offset + i > 0) and label == space_id:
                space_positions.append(offset + i)
    return space_positions

In [None]:
# =======================
# Preprocessing and training
# =======================
# Tokenizer for MODEL_NAME. It is a module-level global: preprocess_function
# calls it to encode batches and train_model saves it next to each checkpoint.
tokenizer = CanineTokenizer.from_pretrained(MODEL_NAME)

def preprocess_function(examples):
    """Tokenize a batch and attach per-position SPACE / NO_SPACE labels.

    The input text is taken from the ``"query"`` column, falling back to
    ``"text_no_spaces"``; non-string entries become empty strings. Every text
    is stripped of whitespace and lower-cased before being tokenized with the
    global ``tokenizer`` (padded/truncated to ``MAX_LENGTH``).

    Labels are laid out as ``[-100] + per-character labels + [-100]`` and then
    padded with ``-100`` up to ``MAX_LENGTH``; only the first ``CHAR_WINDOW``
    characters of each text receive a label.

    Args:
        examples: A batch mapping column names to lists of values; must
            contain ``"query_corrected"``.

    Returns:
        The tokenizer output with an added ``"labels"`` entry.
    """
    raw_inputs = examples.get("query", examples.get("text_no_spaces"))
    texts_no_spaces = []
    for raw in raw_inputs:
        texts_no_spaces.append("".join(raw.split()).lower() if isinstance(raw, str) else "")
    encodings = tokenizer(texts_no_spaces, padding="max_length", truncation=True, max_length=MAX_LENGTH)

    no_space_id = LABEL_TO_ID['NO_SPACE']
    space_id = LABEL_TO_ID['SPACE']

    processed_labels = []
    for ns, corr in zip(texts_no_spaces, examples["query_corrected"]):
        positions = get_true_positions(ns, corr)

        # One label per character that fits into the window.
        n_chars = len(ns[:CHAR_WINDOW])
        char_labels = [no_space_id] * n_chars
        for pos in positions:
            if pos < n_chars:
                char_labels[pos] = space_id

        # Ignored slot on each side of the characters, then ignored padding.
        labels = [-100, *char_labels, -100]
        labels += [-100] * (MAX_LENGTH - len(labels))
        processed_labels.append(labels[:MAX_LENGTH])

    encodings["labels"] = processed_labels
    return encodings

def evaluate(model, dataloader, loss_fct, device):
    """Run one evaluation pass over ``dataloader``.

    For every batch the loss is computed with ``loss_fct`` and, per example,
    the positions labelled SPACE are collected for both the reference labels
    and the arg-max predictions. Positions are indices within the non-ignored
    (label != -100) part of each sequence.

    Args:
        model: Token-classification model whose output exposes ``.logits``.
        dataloader: Yields dict batches of tensors including ``"labels"``.
        loss_fct: Loss applied to flattened logits and labels.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean loss per batch, mean per-example F1)``.
    """
    model.eval()
    space_id = LABEL_TO_ID["SPACE"]
    running_loss = 0.0
    true_positions_all, pred_positions_all = [], []

    with torch.no_grad():
        for batch in dataloader:
            batch = {key: tensor.to(device) for key, tensor in batch.items()}
            labels = batch["labels"]
            model_inputs = {key: tensor for key, tensor in batch.items() if key != 'labels'}
            logits = model(**model_inputs).logits

            batch_loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))
            running_loss += batch_loss.item()

            predictions = logits.argmax(dim=-1)
            for row_labels, row_preds in zip(labels, predictions):
                # Keep only the positions that carry a real label.
                keep = row_labels != -100
                gold = row_labels[keep].cpu().numpy()
                guess = row_preds[keep].cpu().numpy()
                true_positions_all.append([idx for idx, lab in enumerate(gold) if lab == space_id])
                pred_positions_all.append([idx for idx, lab in enumerate(guess) if lab == space_id])

    mean_loss = running_loss / max(1, len(dataloader))
    return mean_loss, compute_f1_metric(true_positions_all, pred_positions_all)

def train_model(train_dataset, eval_dataset, best_model_dir):
    """Fine-tune the token classifier and checkpoint the best epochs.

    Both datasets are switched to torch format in place. Training uses AdamW
    with a linear warm-up/decay schedule and a class-weighted cross entropy
    whose weights are inversely proportional to the label frequencies among
    the non-ignored (label != -100) training positions. After each epoch the
    model is evaluated and saved, together with the global ``tokenizer``:

    * to ``best_model_dir`` when the eval F1 improves;
    * to ``best_model_dir + "_loss"`` when the eval loss improves.

    Relies on the module-level ``device``, ``tokenizer`` and hyper-parameter
    constants.

    Args:
        train_dataset: Preprocessed training dataset with a ``"labels"`` column.
        eval_dataset: Preprocessed evaluation dataset.
        best_model_dir: Output directory for the best-F1 checkpoint.
    """
    train_dataset.set_format("torch")
    eval_dataset.set_format("torch")
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=BATCH_SIZE)

    model = CanineForTokenClassification.from_pretrained(MODEL_NAME, num_labels=len(LABEL_LIST), id2label=ID_TO_LABEL, label2id=LABEL_TO_ID).to(device)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    num_training_steps = NUM_EPOCHS * len(train_dataloader)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1, num_training_steps // 10), num_training_steps=num_training_steps)

    # Class frequencies, counted chunk-wise with torch.bincount instead of
    # visiting every label as a 0-d tensor in a Python loop. Slicing the
    # dataset (rather than using a DataLoader) leaves the global RNG untouched.
    num_labels = len(LABEL_LIST)
    count_chunk = 4096
    class_counts = torch.zeros(num_labels, dtype=torch.long)
    for start in range(0, len(train_dataset), count_chunk):
        chunk_labels = train_dataset[start:start + count_chunk]["labels"]
        if not isinstance(chunk_labels, torch.Tensor):
            # Defensive: some formatters may hand back a list of row tensors.
            chunk_labels = torch.stack(list(chunk_labels))
        class_counts += torch.bincount(chunk_labels[chunk_labels != -100], minlength=num_labels)
    # weight_c = total / (num_classes * count_c); counts are clamped to >= 1
    # so an absent class does not cause a division by zero.
    class_weights = class_counts.sum().float() / (num_labels * class_counts.clamp(min=1).float())
    loss_fct = CrossEntropyLoss(weight=class_weights.to(device), ignore_index=-100)

    best_eval_f1 = -1.0
    best_eval_loss = float("inf")
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [train]")
        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}
            # Labels are withheld from the model so the weighted loss below is
            # the only loss that is computed.
            outputs = model(**{k: v for k, v in batch.items() if k != 'labels'})
            loss = loss_fct(outputs.logits.view(-1, model.config.num_labels), batch["labels"].view(-1))
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            pbar.set_postfix(loss=loss.item())

        eval_loss, eval_f1 = evaluate(model, eval_dataloader, loss_fct, device)
        print(f"\n--- Эпоха {epoch}/{NUM_EPOCHS} ---")
        print(f"Eval loss: {eval_loss:.4f}, Eval F1: {eval_f1:.4f}")

        if eval_f1 > best_eval_f1:
            print(f"Новый лучший F1. Сохраняем модель в {best_model_dir}...")
            best_eval_f1 = eval_f1
            model.save_pretrained(best_model_dir)
            tokenizer.save_pretrained(best_model_dir)

        if eval_loss < best_eval_loss:
            print(f"Новый лучший Loss eval. Сохраняем модель в {best_model_dir}_loss...")
            best_eval_loss = eval_loss
            model.save_pretrained(best_model_dir + "_loss")
            tokenizer.save_pretrained(best_model_dir + "_loss")

In [None]:
# =======================
# MAIN PIPELINE
# =======================
# --- 1. Data preparation ---
print("Загрузка и подготовка данных для обучения...")
raw_dataset = load_dataset(DATASET_NAME, split="train")
# Hold out 10% of the base dataset for evaluation during training.
split_dataset = raw_dataset.train_test_split(test_size=0.1, seed=SEED)
base_train_dataset = split_dataset["train"]
eval_dataset_hf = split_dataset["test"]

if os.path.exists(AUGMENTED_DATA_FILE_TXT):
    # Each non-empty line of the file is a correctly spaced query; the model
    # input is derived from it by removing all whitespace.
    with open(AUGMENTED_DATA_FILE_TXT, 'r', encoding='utf-8') as f:
        query_corrected = [line.strip().lower() for line in f if line.strip()]

    query = ["".join(s.split()) for s in query_corrected]

    dataset_dict = {
        "query" : query,
        "query_corrected" : query_corrected
    }
    # Report the file that was actually read (the .txt file, not the CSV
    # that is used later for the final evaluation).
    print(f"Найден файл с дополнительными данными: {AUGMENTED_DATA_FILE_TXT}")
    lyrics_dataset = Dataset.from_dict(dataset_dict)
    train_raw_dataset = concatenate_datasets([base_train_dataset, lyrics_dataset])
else:
    print(f"Файл {AUGMENTED_DATA_FILE_TXT} не найден. Используйте только базовый датасет.")
    train_raw_dataset = base_train_dataset

print(f"Итоговый размер обучающего набора: {len(train_raw_dataset)}")
# Tokenize and label both splits; the raw text columns are dropped afterwards.
train_dataset = train_raw_dataset.map(preprocess_function, batched=True, remove_columns=train_raw_dataset.column_names)
eval_dataset = eval_dataset_hf.map(preprocess_function, batched=True, remove_columns=eval_dataset_hf.column_names)

Загрузка и подготовка данных для обучения...
Найден файл с дополнительными данными: augmented_test_data.csv
Итоговый размер обучающего набора: 30069


In [None]:
# --- 2. Retrain the model with the corrected position logic ---
train_model(train_dataset, eval_dataset, BEST_MODEL_DIR)

# --- 3. Evaluation on the labelled test data ---
print("\n" + "=" * 50)
print("--- ОЦЕНКА ОБУЧЕННОЙ МОДЕЛИ НА РАЗМЕЧЕННОМ ТЕСТЕ ---")

# Reload the best-F1 checkpoint written by train_model.
print(f"Загрузка лучшей модели из '{BEST_MODEL_DIR}'...")
model = CanineForTokenClassification.from_pretrained(BEST_MODEL_DIR).to(device)
tokenizer = CanineTokenizer.from_pretrained(BEST_MODEL_DIR)

if not os.path.exists(AUGMENTED_DATA_FILE):
    print(f"Файл {AUGMENTED_DATA_FILE} не найден для финальной оценки.")
else:
    df_test = pd.read_csv(AUGMENTED_DATA_FILE)
    all_true_positions, all_pred_positions = [], []

    for _, row in tqdm(df_test.iterrows(), total=len(df_test), desc="Финальная оценка"):
        compact_query = str(row['text_no_spaces'])
        pred_positions = restore_spaces(compact_query, model, tokenizer, device)
        true_positions = get_true_positions(str(row['text_no_spaces']), str(row['query_corrected']))
        all_pred_positions.append(pred_positions)
        all_true_positions.append(true_positions)

    final_f1_score = compute_f1_metric(all_true_positions, all_pred_positions)
    print("\n" + "=" * 50)
    print(f" ✅ Итоговый средний F1-score: {final_f1_score:.4f}")
    print("=" * 50)


In [None]:
# --- Evaluation of the trained model on the local test file ---
print("\n" + "=" * 50)
print("--- ОЦЕНКА ОБУЧЕННОЙ МОДЕЛИ НА ЛОКАЛЬНОМ ТЕСТЕ ---")
model = CanineForTokenClassification.from_pretrained(BEST_MODEL_DIR).to(device)
tokenizer = CanineTokenizer.from_pretrained(BEST_MODEL_DIR)

if os.path.exists(AUGMENTED_DATA_FILE):
    df_test = pd.read_csv(AUGMENTED_DATA_FILE)
    all_true_pos, all_pred_pos = [], []
    for _, row in tqdm(df_test.iterrows(), total=len(df_test), desc="Локальная оценка"):
        # Predicted positions first, then the reference positions for the same row.
        predicted = restore_spaces(str(row['text_no_spaces']), model, tokenizer, device)
        all_pred_pos.append(predicted)
        reference = get_true_positions(str(row['text_no_spaces']), str(row['query_corrected']))
        all_true_pos.append(reference)
    final_f1_score = compute_f1_metric(all_true_pos, all_pred_pos)
    print(f" ✅ Итоговый средний F1-score на локальном тесте: {final_f1_score:.4f}")

# --- Contest submission: input and output file names ---
print("\n" + "=" * 50)
print("--- ГЕНЕРАЦИЯ ФАЙЛА ДЛЯ КОНКУРСА ---")

CONTEST_FILE = "test_data.txt"
SUBMISSION_FILE = "submission.csv"

def read_contest_file(path: str) -> pd.DataFrame:
    """Load the contest input table, tolerating unquoted commas in the text.

    The file is first parsed with ``pandas.read_csv`` using delimiter
    sniffing. If that fails for any reason, a manual fallback is used: the
    first line is the comma-separated header and every following non-empty
    line is split on its *first* comma only, so commas inside the text column
    survive. Rows with fewer fields than the widest row are padded with
    ``None``; a file without data rows yields an empty frame with the header
    columns instead of raising ``IndexError``.

    Args:
        path: Path to the contest file.

    Returns:
        The parsed table with whitespace-stripped column names.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the table has no ``text_no_spaces`` column.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Файл не найден: {path}")
    try:
        df = pd.read_csv(path, sep=None, engine="python", encoding="utf-8-sig")
    except Exception:
        # Deliberately broad: any parser failure switches to the manual,
        # best-effort "split on the first comma" reader below.
        with open(path, 'r', encoding='utf-8-sig') as f:
            header = [h.strip() for h in f.readline().strip().split(',')]
            data_rows = [line.strip().split(',', 1) for line in f if line.strip()]
        if data_rows:
            width = max(len(fields) for fields in data_rows)
            padded_rows = [fields + [None] * (width - len(fields)) for fields in data_rows]
            df = pd.DataFrame(padded_rows, columns=header[:width])
        else:
            df = pd.DataFrame(columns=header)
    df.columns = [c.strip() for c in df.columns]
    if "text_no_spaces" not in df.columns:
        raise ValueError("Нужна колонка 'text_no_spaces'.")
    return df

task_data = read_contest_file(CONTEST_FILE)

all_positions_str = []
for text in tqdm(task_data["text_no_spaces"].astype(str), desc="Генерация сабмита"):
    positions = restore_spaces(text, model, tokenizer, device)
    # Serialize the list [5, 13] as the string "[5, 13]".
    all_positions_str.append(str(positions))

submission_df = task_data.copy()
submission_df["predicted_positions"] = all_positions_str
submission_df = submission_df[["id","predicted_positions"]]
# index=False: the submission must contain only the two selected columns;
# without it pandas prepends the row index as an extra unnamed column.
submission_df.to_csv(SUBMISSION_FILE, index=False)
print(f"✅ Файл для отправки сохранен: {SUBMISSION_FILE}")
print("Пример первых 5 строк:")
print(submission_df.head())