# Pretrain

1) Препроцессинг данных

In [None]:
import re
import json
import hashlib
from pathlib import Path

In [None]:
# Directory scanned for raw *.txt documents, and the JSONL file the corpus is written to.
DATA_PATH = Path("./data/test")
OUT_PATH = Path("./data/pretrain_corpus.jsonl")

# Character-length bounds applied to each sentence by is_good_sentence.
MIN_SENT_CHARS = 20
MAX_SENT_CHARS = 5000

# Per-chunk budget: two positions are held back from the context length,
# presumably for the BOS/EOS markers wrapped around every chunk.
CONTEXT_LEN = 1024
MAX_TOKENS_PER_CHUNK = CONTEXT_LEN - 2 

# Literal marker strings placed before and after each chunk's text.
BOS = "<bos>"
EOS = "<eos>"

def normalize_text(text):
    """Normalize line endings, quotes, dashes and spacing of a raw document.

    - CRLF and lone CR become LF.
    - Typographic quotes become a plain double quote.
    - Em and en dashes become an em dash surrounded by spaces.
    - Runs of spaces/tabs collapse to one space; three or more newlines
      collapse to a single blank line.

    Returns the result with leading/trailing whitespace removed.
    """
    # CRLF must be handled before lone CR so it does not yield two newlines.
    for line_ending in ("\r\n", "\r"):
        text = text.replace(line_ending, "\n")

    for quote in ("«", "»", "„", "“", "”"):
        text = text.replace(quote, '"')

    # Padding may create double spaces; they are collapsed right below.
    for dash in ("—", "–"):
        text = text.replace(dash, " — ")

    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()

def normalize_punct(s):
    """Clean up punctuation inside a single sentence.

    Steps, in order:
    - four or more dots become "...";
    - repeated "!", "?", ",", ":", ";" collapse to a single mark, and
      repeated "?!" pairs collapse to one "?!";
    - whitespace before a punctuation mark is removed;
    - a space is inserted after a punctuation mark that is glued to the
      following text;
    - runs of whitespace collapse to one space.

    The space-insertion step deliberately leaves two cases untouched:
    consecutive punctuation marks ("...", "?!"), which the earlier steps
    just produced, and ".", "," or ":" sitting between two digits
    ("3.14", "1,5", "12:30").

    Returns the cleaned sentence with surrounding whitespace removed.
    """
    s = s.strip()
    s = re.sub(r"\.{4,}", "...", s)
    s = re.sub(r"!{2,}", "!", s)
    s = re.sub(r"\?{2,}", "?", s)
    s = re.sub(r"(\?!){2,}", "?!", s)
    s = re.sub(r",{2,}", ",", s)
    s = re.sub(r":{2,}", ":", s)
    s = re.sub(r";{2,}", ";", s)
    s = re.sub(r"\s+([,.;:!?])", r"\1", s)
    # First alternative: a separator between digits is kept as is.
    # Second alternative: any other mark followed by a character that is
    # neither whitespace nor punctuation gets a space appended.
    s = re.sub(
        r"(?<=\d)([.,:])(?=\d)|([,.;:!?])(?=[^\s,.;:!?])",
        lambda m: m.group(1) or m.group(2) + " ",
        s,
    )
    s = re.sub(r"\s{2,}", " ", s)
    return s.strip()

def split_sentences(text):
    """Split text into sentences at whitespace that follows ".", "!" or "?".

    The terminating mark stays attached to its sentence. Pieces that are
    empty or whitespace-only are dropped.
    """
    sentences = []
    for piece in re.split(r"(?<=[.!?])\s+", text):
        if piece.strip():
            sentences.append(piece)
    return sentences

# Single-character matchers for Latin and Cyrillic letters.
LATIN_RE = re.compile(r"[A-Za-z]")
CYR_RE = re.compile(r"[А-Яа-яЁё]")


def cyrillic_ratio(s):
    """Return the share of Cyrillic letters among all Latin and Cyrillic letters in s.

    Digits, punctuation and other characters are ignored. Returns 0.0 when
    s contains no Latin or Cyrillic letters at all.
    """
    # Keep only the letters, then count how many of them are Cyrillic.
    letters = "".join(re.findall(r"[A-Za-zА-Яа-яЁё]", s))
    if not letters:
        return 0.0
    return len(CYR_RE.findall(letters)) / len(letters)

def is_good_sentence(s):
    """Decide whether a sentence is kept for the corpus.

    A sentence is accepted only if all of the following hold:
    - it is non-empty and its length lies within
      [MIN_SENT_CHARS, MAX_SENT_CHARS];
    - it contains no Latin letters;
    - its cyrillic_ratio is at least 0.70;
    - it contains at least 5 Cyrillic letters;
    - it uses more than 3 distinct characters.
    """
    if not s or not (MIN_SENT_CHARS <= len(s) <= MAX_SENT_CHARS):
        return False

    if LATIN_RE.search(s) is not None:
        return False

    if cyrillic_ratio(s) < 0.70:
        return False

    cyrillic_letters = CYR_RE.findall(s)
    distinct_chars = set(s)
    return len(cyrillic_letters) >= 5 and len(distinct_chars) > 3

def sha1(text):
    """Return the hexadecimal SHA-1 digest of text encoded as UTF-8."""
    digest = hashlib.sha1()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()

def normalize_for_dedup(s):
    """Build a comparison key for duplicate detection.

    The text is lower-cased, whitespace runs are collapsed, and every
    character other than a digit or a lower-case Cyrillic letter is removed,
    so texts differing only in case, spacing or punctuation share a key.
    """
    lowered = s.lower()
    single_spaced = re.sub(r"\s+", " ", lowered)
    key = re.sub(r"[^0-9а-яё]+", "", single_spaced)
    return key.strip()

def count_tokens(text, tokenizer=None):
    """Return the length of text in tokens.

    Without a tokenizer, the count is the number of whitespace-separated
    words. Otherwise it is the length of tokenizer.encode(text).
    """
    if tokenizer is None:
        return len(text.split())
    return len(tokenizer.encode(text))


def chunk_sentences(sentences, max_tokens, tokenizer=None):
    """Pack sentences, in order, into chunks of at most max_tokens tokens.

    Sentences are joined with a single space. A sentence that alone exceeds
    max_tokens is split on whitespace into word groups that each fit the
    budget (a single word larger than the budget still becomes its own
    chunk). Token counts come from count_tokens, summed per sentence/word.

    Args:
        sentences: iterable of sentence strings.
        max_tokens: token budget per chunk.
        tokenizer: optional object with an encode(text) method; when None,
            whitespace-separated words are counted instead.

    Returns:
        List of chunk strings, preserving the order of the input text.
    """
    chunks = []
    current = []
    current_tokens = 0

    for s in sentences:
        s_tokens = count_tokens(s, tokenizer)
        if s_tokens > max_tokens:
            # Emit the pending sentences first; otherwise they would land
            # after the pieces of this later sentence and scramble the order.
            if current:
                chunks.append(" ".join(current))
                current = []
                current_tokens = 0

            buf = []
            buf_tokens = 0

            for w in s.split():
                w_tokens = count_tokens(w, tokenizer)
                if buf_tokens + w_tokens > max_tokens and buf:
                    chunks.append(" ".join(buf))
                    buf = [w]
                    buf_tokens = w_tokens
                else:
                    buf.append(w)
                    buf_tokens += w_tokens

            if buf:
                chunks.append(" ".join(buf))
            continue

        if current_tokens + s_tokens > max_tokens and current:
            chunks.append(" ".join(current))
            current = [s]
            current_tokens = s_tokens
        else:
            current.append(s)
            current_tokens += s_tokens

    if current:
        chunks.append(" ".join(current))

    return chunks

In [None]:
# Build the pretraining corpus: read every .txt file, drop duplicate documents
# and sentences, keep only sentences accepted by is_good_sentence, pack them
# into chunks and write one JSON object per chunk to OUT_PATH.
txt_files = sorted(DATA_PATH.glob("*.txt"))
print("Количество файлов:", len(txt_files))

# Hashes of already-seen documents / sentences (dedup is global across files).
seen_docs = set()
seen_sents = set()
all_chunks = []

# Counters reported after processing.
stats = {
    "docs_total": 0,
    "docs_unique": 0,
    "sents_total": 0,
    "sents_good": 0,
    "sents_unique": 0,
    "chunks_total": 0
}

for fp in txt_files:
    stats["docs_total"] += 1

    # Undecodable bytes are silently dropped rather than raising.
    raw = fp.read_text(encoding="utf-8", errors="ignore")
    raw = normalize_text(raw)

    # Document-level dedup on the normalized (case/punctuation-insensitive) text.
    doc_key = sha1(normalize_for_dedup(raw))
    if doc_key in seen_docs:
        continue

    seen_docs.add(doc_key)
    stats["docs_unique"] += 1

    sents = split_sentences(raw)
    stats["sents_total"] += len(sents)

    cleaned = []
    for s in sents:
        s = normalize_punct(s)
        if not is_good_sentence(s):
            continue

        stats["sents_good"] += 1

        # Sentence-level dedup; only the first occurrence is kept.
        sent_key = sha1(normalize_for_dedup(s))
        if sent_key in seen_sents:
            continue

        seen_sents.add(sent_key)
        stats["sents_unique"] += 1
        cleaned.append(s)

    # With tokenizer=None the budget is measured in whitespace-separated words.
    chunks = chunk_sentences(cleaned, max_tokens=MAX_TOKENS_PER_CHUNK, tokenizer=None)

    for ch in chunks:
        # Each chunk is wrapped in the literal BOS/EOS marker strings.
        text = f"{BOS} {ch.strip()} {EOS}"
        all_chunks.append(text)

stats["chunks_total"] = len(all_chunks)

print("=== STATS ===")
for k, v in stats.items():
    print(f"{k}: {v}")

# Write the corpus as JSON Lines: one {"text": ...} object per chunk.
OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
with OUT_PATH.open("w", encoding="utf-8") as f:
    for t in all_chunks:
        # ensure_ascii=False keeps Cyrillic readable instead of \u escapes.
        f.write(json.dumps({"text": t}, ensure_ascii=False) + "\n")

print("Saved:", OUT_PATH, "chunks:", len(all_chunks))

3) Токенизатор

In [None]:
import json
from pathlib import Path

# Convert the JSONL corpus into a plain-text file with one chunk per line,
# which is later passed to the tokenizer trainer.
JSONL_PATH = Path("./data/pretrain_corpus.jsonl")
TXT_TRAIN_PATH = Path("./data/tokenizer_train.txt")

TXT_TRAIN_PATH.parent.mkdir(parents=True, exist_ok=True)

# Number of non-empty lines written.
count = 0
with open(JSONL_PATH, "r", encoding="utf-8") as f_in, open(TXT_TRAIN_PATH, "w", encoding="utf-8") as f_out:
    for line in f_in:
        obj = json.loads(line)
        text = obj["text"].strip()
        if not text:
            continue
        # Flatten internal newlines so every chunk occupies exactly one line.
        f_out.write(text.replace("\n", " ") + "\n")
        count += 1

print("Готово. Строк для обучения токенизатора:", count)
print("Файл:", TXT_TRAIN_PATH)

In [None]:
from tokenizers import ByteLevelBPETokenizer
from pathlib import Path

# Train a byte-level BPE tokenizer on the exported text file and save it.
VOCAB_SIZE = 3000
MIN_FREQUENCY = 2

# Special tokens passed to the trainer; "<bos>"/"<eos>" match the marker
# strings written around each corpus chunk.
special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]

tokenizer = ByteLevelBPETokenizer()

tokenizer.train(
    files=[str(TXT_TRAIN_PATH)],
    vocab_size=VOCAB_SIZE,
    min_frequency=MIN_FREQUENCY,
    special_tokens=special_tokens
)


OUT_DIR = Path("./tokenizer_bpe_3k")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): save_model presumably writes the vocab/merges files into
# OUT_DIR — confirm against the tokenizers documentation.
tokenizer.save_model(str(OUT_DIR))

# Single-file serialization; this is the file loaded later via tokenizer_file=.
tokenizer.save(str(OUT_DIR / "tokenizer.json"))
print("Saved tokenizer.json:", OUT_DIR / "tokenizer.json")

In [None]:
from transformers import PreTrainedTokenizerFast
from datasets import load_dataset

# Load the trained tokenizer from its JSON file and declare which strings
# act as the unk/pad/bos/eos special tokens.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tokenizer_bpe_3k/tokenizer.json",
    unk_token="<unk>",
    pad_token="<pad>",
    bos_token="<bos>",
    eos_token="<eos>",
)

# Load the JSONL corpus as a dataset with a single "train" split.
ds = load_dataset("json", data_files={"train": "./data/pretrain_corpus.jsonl"})

def tokenize_function(examples):
    """Tokenize the "text" column of a batch with the module-level tokenizer.

    No special tokens are added by the tokenizer call and token type ids
    are not returned.
    """
    options = {
        "add_special_tokens": False,
        "return_token_type_ids": False,
    }
    return tokenizer(examples["text"], **options)

# Tokenize the train split in batches; the raw "text" column is removed
# from the resulting dataset.
tokenized = ds["train"].map(
    tokenize_function,
    batched=True,
    remove_columns=["text"]
)

# Number of tokens in every training example produced by group_texts.
BLOCK_SIZE = 512


def group_texts(examples):
    """Concatenate a batch of tokenized examples and cut it into fixed blocks.

    All "input_ids" (and, in parallel, all "attention_mask" values) of the
    batch are flattened into one sequence, which is sliced into consecutive
    blocks of BLOCK_SIZE items. A trailing remainder shorter than BLOCK_SIZE
    is discarded.

    Returns a dict with "input_ids", "attention_mask" and "labels", where
    each labels entry is an independent copy of the matching input_ids block.
    """
    flat_ids = [token for seq in examples["input_ids"] for token in seq]
    flat_mask = [flag for seq in examples["attention_mask"] for flag in seq]

    # Largest multiple of BLOCK_SIZE that fits into the flattened ids.
    usable = len(flat_ids) - len(flat_ids) % BLOCK_SIZE
    starts = range(0, usable, BLOCK_SIZE)

    id_blocks = [flat_ids[lo : lo + BLOCK_SIZE] for lo in starts]
    mask_blocks = [flat_mask[lo : lo + BLOCK_SIZE] for lo in starts]

    return {
        "input_ids": id_blocks,
        "attention_mask": mask_blocks,
        "labels": [block.copy() for block in id_blocks],
    }

# Regroup the tokenized examples into fixed-size blocks; all original columns
# are replaced by the ones returned from group_texts.
lm_dataset = tokenized.map(
    group_texts,
    batched=True,
    remove_columns=tokenized.column_names
)

print(lm_dataset)
print("Примеров:", len(lm_dataset))
# NOTE(review): indexing [0] presumably fails if no full block was produced — confirm.
print("Длина блока:", len(lm_dataset[0]["input_ids"]))