In [29]:
# ============================================
# Итоговый пересинтез (v2.1) — от начала до конца
# Вход:  /tf/konokhova/project/processed/all_problems_ready_for_synth.parquet
# Выход: /tf/konokhova/project/processed/synth_v21/all_problems_balanced_synth_v21.parquet (+ .csv)
# ============================================

# ---------- Cell 0: env (рекомендуется запускать первой) ----------
import os
# Environment switches, applied before the heavier imports further down.
# NOTE(review): presumably these turn off the TensorFlow backend and some
# Hugging Face hub features — confirm against the libraries actually used.
os.environ.update(
    {
        "USE_TF": "0",
        "TRANSFORMERS_NO_TF": "1",
        "HF_HUB_DISABLE_XET": "1",
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
)

# ---------- Cell 1: imports ----------
from __future__ import annotations

import re
import random
import difflib
from pathlib import Path
from typing import Optional, Tuple, List, Dict

import numpy as np
import pandas as pd

# ---------- Cell 2: load ----------
# Source parquet, resolved relative to the current working directory.
SRC_PATH = Path("./all_problems_ready_for_synth.parquet")
if not SRC_PATH.exists():
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would silently skip this check.
    raise FileNotFoundError(f"Not found: {SRC_PATH}")

df = pd.read_parquet(SRC_PATH)
print("Loaded:", df.shape)
print("Columns:", list(df.columns))

def pick_col(candidates: List[str]) -> str:
    """Return the first name in ``candidates`` that is a column of the global ``df``.

    Raises:
        KeyError: if none of the candidate names is present in ``df.columns``.
    """
    present = [name for name in candidates if name in df.columns]
    if not present:
        raise KeyError(f"None of columns exist: {candidates}")
    return present[0]

# Resolve the actual column names: for each role, the first candidate that
# exists in `df` wins (see `pick_col`).
COL_UID   = pick_col(["row_uid", "uid", "key", "id"])
COL_COND  = pick_col(["condition_for_train", "condition_rus_clean", "condition_rus", "condition"])
COL_SOL   = pick_col(["solution_rus_clean", "solution_rus", "solution"])
COL_ANS   = pick_col(["answer_norm", "answer_ref", "answer_rus", "answer_tex"])
COL_SPLIT = pick_col(["split"])

print("Using:", dict(uid=COL_UID, cond=COL_COND, sol=COL_SOL, ans=COL_ANS, split=COL_SPLIT))
# Sanity checks: solution and condition must have no missing values, and the
# split column may only contain "train" / "val".
# Note that the answer column is NOT checked for missing values here.
assert df[COL_SOL].notna().all()
assert df[COL_COND].notna().all()
assert df[COL_SPLIT].isin(["train", "val"]).all()

# ---------- Cell 3: regex + helpers ----------
# "Ответ: <answer>" followed by a terminator (period, newline or end of text).
# Groups: 1 = the "Ответ:" prefix, 2 = the answer text, 3 = the terminator.
# A dot directly followed by a digit is a decimal separator, not a terminator;
# without the `(?!\d)` lookahead "Ответ: 3.5." was captured as just "3".
RE_ANSWER = re.compile(r"(Ответ\s*:\s*)(.+?)(\s*(?:\.(?!\d)|\n|$))", re.IGNORECASE | re.DOTALL)
# Standalone (optionally negative) integer or decimal with "." or "," separator,
# not glued to a Latin/Cyrillic letter or underscore on either side.
RE_NUMBER = re.compile(r"(?<![A-Za-zА-Яа-я_])(-?\d+(?:[.,]\d+)?)(?![A-Za-zА-Яа-я_])")
# Trigonometric function names in Latin notation and as Russian words.
RE_TRIG = re.compile(r"\b(sin|cos|tan|ctg|tg|cot)\b", re.IGNORECASE)
RE_TRIG_RU = re.compile(r"\b(синус|косинус|тангенс|котангенс)\b", re.IGNORECASE)
# Inequality operators; two-character forms are listed first so they win over "<" / ">".
RE_INEQ = re.compile(r"(<=|>=|<|>|≤|≥)")

def ensure_diff(a: str, b: str) -> bool:
    """Return True when the two texts differ after trimming surrounding whitespace.

    Falsy inputs (``None``, empty string) are treated as the empty string.
    """
    left = a.strip() if a else ""
    right = b.strip() if b else ""
    return left != right

def normalize_decimal(s: str) -> str:
    """Trim the string and replace every comma with a dot.

    Falsy input (``None`` or empty string) yields an empty string.
    """
    if not s:
        return ""
    trimmed = s.strip()
    return trimmed.replace(",", ".")

def safe_float(s: str) -> Optional[float]:
    """Parse ``s`` as a float after decimal normalisation; return None on any failure."""
    try:
        parsed = float(normalize_decimal(s))
    except Exception:
        # Deliberately best-effort: anything unparseable counts as "not a number".
        parsed = None
    return parsed

def extract_answer_from_text(solution: str) -> str:
    """Pull the final answer out of a solution text.

    Prefers the text captured after an "Ответ:" marker (``RE_ANSWER`` group 2).
    If there is no such marker, falls back to the last number matched by
    ``RE_NUMBER``; returns an empty string when neither is found.
    """
    text = solution or ""
    marked = RE_ANSWER.search(text)
    if marked:
        return (marked.group(2) or "").strip()
    # Fallback: the last numeric constant that occurs in the text.
    numbers = RE_NUMBER.findall(text)
    if not numbers:
        return ""
    return numbers[-1].strip()

def ensure_answer_line(solution: str, answer_norm: str) -> str:
    """Make sure the solution ends with an "Ответ: ..." line.

    The text is returned unchanged when ``answer_norm`` is blank or when the
    text already contains an answer marker (``RE_ANSWER``). Otherwise
    "Ответ: <answer>." is appended, separated from the body by a blank line.
    """
    text = solution or ""
    answer = (answer_norm or "").strip()
    # Nothing to append, or an answer line is already present.
    if not answer or RE_ANSWER.search(text):
        return text
    joiner = "\n\n"
    if text.endswith("\n"):
        joiner = "\n"
    return f"{text}{joiner}Ответ: {answer}."

def propose_wrong_answer(answer_norm: str, solution_fallback: str, rng: random.Random) -> str:
    """Produce a deliberately wrong answer derived from the reference answer.

    Args:
        answer_norm: reference answer; when blank, the answer is extracted
            from ``solution_fallback`` instead.
        solution_fallback: solution text used only when ``answer_norm`` is blank.
        rng: random generator; exactly one ``choice`` call is made for numeric answers.

    Returns:
        For a finite numeric answer: one of value+1, value-1, -value, value*2,
        value/2 (or 1 when the value is zero), formatted in plain decimal
        notation and using a comma when the reference used a comma. The result
        is guaranteed to differ numerically from the reference.
        For a non-numeric answer: the answer with its trailing "0" replaced by
        "1", or with "0" appended. "1" when there is no answer at all.
    """
    from decimal import Decimal, InvalidOperation

    ans = (answer_norm or "").strip()
    if not ans:
        ans = extract_answer_from_text(solution_fallback)

    # Decimal arithmetic keeps the original formatting ("5" -> "6", not "6.0")
    # and avoids binary-float noise in the generated answers.
    try:
        value = Decimal(ans.replace(",", "."))
    except (InvalidOperation, ValueError):
        value = None

    # NaN / infinity are not usable answers; treat them as non-numeric text.
    if value is not None and value.is_finite():
        candidates = [
            value + 1,
            value - 1,
            -value,
            value * 2,
            value / 2 if value != 0 else Decimal(1),
        ]
        new = rng.choice(candidates)
        # For a zero answer, "-0" and "0 * 2" are still numerically correct:
        # compare by value (not by string) and nudge the candidate away.
        if new == value:
            new = new + 1
        out = format(new, "f")
        if "," in ans:
            out = out.replace(".", ",")
        return out

    # String / non-numeric answer.
    if not ans:
        # Nothing to anchor on at all: fall back to "1".
        return "1"
    if ans.endswith("0"):
        return ans[:-1] + "1"
    return ans + "0"

def first_diff_window(a: str, b: str, window: int = 160) -> Tuple[str, str]:
    """Return the context around the first difference between two texts.

    Args:
        a: first text (falsy values are treated as "").
        b: second text (falsy values are treated as "").
        window: number of characters of context kept on each side of the
            first differing region.

    Returns:
        A pair ``(snippet_from_a, snippet_from_b)``; both are empty strings
        when the texts are identical.
    """
    a = a or ""
    b = b or ""
    # Identical inputs need no matching at all; this skips the expensive
    # SequenceMatcher pass for every unchanged pair.
    if a == b:
        return "", ""
    matcher = difflib.SequenceMatcher(None, a, b)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            continue
        # Clamp the window to the bounds of each string.
        a_lo, a_hi = max(0, i1 - window), min(len(a), i2 + window)
        b_lo, b_hi = max(0, j1 - window), min(len(b), j2 + window)
        return a[a_lo:a_hi], b[b_lo:b_hi]
    return "", ""

# ---------- Cell 4: синтетика (улучшенная) ----------
def tweak_number_str(num_str: str, rng: random.Random) -> str:
    """Shift a numeric string by a small random amount, keeping its formatting.

    Values with magnitude >= 1 are shifted by one of -2, -1, 1, 2; smaller
    values by one of -0.1, -0.05, 0.05, 0.1. Input that is not a finite
    number is returned unchanged. A comma decimal separator in the input is
    preserved in the output.

    Decimal arithmetic is used so the result carries no float artifacts:
    "5" becomes "6" (not "6.0") and "0,1" + 0.05 becomes "0,15"
    (not "0,15000000000000002").
    """
    from decimal import Decimal, InvalidOperation

    try:
        val = Decimal((num_str or "").strip().replace(",", "."))
    except (InvalidOperation, ValueError):
        return num_str
    if not val.is_finite():
        return num_str

    if abs(val) >= 1:
        delta = Decimal(rng.choice([-2, -1, 1, 2]))
    else:
        # String literals keep the deltas exact in decimal.
        delta = Decimal(rng.choice(["-0.1", "-0.05", "0.05", "0.1"]))

    # "f" formatting avoids scientific notation in the output.
    out = format(val + delta, "f")
    if "," in num_str:
        out = out.replace(".", ",")
    return out

def synth_wrong_final_answer_replace(solution: str, answer_norm: str, rng: random.Random) -> Tuple[str, bool]:
    """Corrupt only the final answer of a solution.

    Returns ``(text, changed)``: the solution with the text after "Ответ:"
    replaced by a wrong answer, and a flag telling whether the result differs
    from the input ``solution``. When no reference answer can be determined,
    the (answer-line-normalised) solution is returned with ``False``.
    """
    text = ensure_answer_line(solution, answer_norm)

    # With an empty answer_norm, take the answer from the text (or its last number).
    reference = (answer_norm or "").strip() or extract_answer_from_text(text)
    if not reference:
        return text, False

    wrong = propose_wrong_answer(reference, text, rng)

    found = RE_ANSWER.search(text)
    if found is not None:
        lo, hi = found.span(2)
        corrupted = "".join((text[:lo], wrong, text[hi:]))
    else:
        # Should not normally happen, kept as a safety net: append an answer line.
        tail_sep = "\n" if text.endswith("\n") else "\n\n"
        corrupted = text + tail_sep + f"Ответ: {wrong}."
    return corrupted, ensure_diff(solution, corrupted)

def synth_intermediate_local_inconsistency(solution: str, answer_norm: str, rng: random.Random, max_tries: int = 12) -> Tuple[str, bool]:
    """Alter one number inside the solution body, leaving the answer line intact.

    Numbers overlapping the "Ответ: ..." match are never touched. Up to
    ``max_tries`` random picks are made; returns ``(text, True)`` on the first
    successful alteration, otherwise the unmodified text and ``False``.
    """
    text = ensure_answer_line(solution, answer_norm)

    answer_match = RE_ANSWER.search(text)
    protected = answer_match.span(0) if answer_match else None

    def _outside_answer(span: Tuple[int, int]) -> bool:
        # True when the span does not overlap the protected answer region.
        if not protected:
            return True
        return span[1] <= protected[0] or span[0] >= protected[1]

    # The text never changes between attempts, so the scan is done once.
    editable = [m for m in RE_NUMBER.finditer(text) if _outside_answer(m.span(1))]
    if not editable:
        return text, False

    for _attempt in range(max_tries):
        target = rng.choice(editable)
        original_num = target.group(1)
        replacement = tweak_number_str(original_num, rng)
        if replacement == original_num:
            continue

        mutated = text[:target.start(1)] + replacement + text[target.end(1):]
        if ensure_diff(text, mutated):
            return mutated, True

    return text, False

def synth_intermediate_propagated(solution: str, answer_norm: str, rng: random.Random) -> Tuple[str, bool]:
    """Alter an intermediate number and then also replace the final answer.

    If the intermediate step cannot be altered, returns the solution (with its
    answer line ensured) and ``False``. Otherwise returns the result of the
    final-answer replacement applied to the altered text.
    """
    local_out, local_ok = synth_intermediate_local_inconsistency(solution, answer_norm, rng)
    if local_ok:
        final_out, final_ok = synth_wrong_final_answer_replace(local_out, answer_norm, rng)
        return final_out, final_ok
    return ensure_answer_line(solution, answer_norm), False

def synth_flip_sign(solution: str, answer_norm: str, rng: random.Random) -> Tuple[str, bool]:
    """Flip the first "+" or "-" that directly precedes a number.

    Returns ``(text, changed)``; the text is unmodified and ``changed`` is
    False when no signed number is found. ``rng`` is accepted for a uniform
    signature and is not used.
    """
    text = ensure_answer_line(solution, answer_norm)
    hit = re.search(r"(\+|-)\s*(\d+(?:[.,]\d+)?)", text)
    if hit is None:
        return text, False
    flipped = "-" if hit.group(1) == "+" else "+"
    pos, stop = hit.span(1)
    mutated = text[:pos] + flipped + text[stop:]
    return mutated, ensure_diff(text, mutated)

def synth_swap_trig(solution: str, answer_norm: str, rng: random.Random) -> Tuple[str, bool]:
    """Swap the first trigonometric function name for its counterpart.

    Latin notation (sin/cos/tan/...) is tried first, then Russian words.
    The replacement is always lower-case. Returns ``(text, changed)``;
    ``changed`` is False when no trigonometric name is present.
    ``rng`` is accepted for a uniform signature and is not used.
    """
    text = ensure_answer_line(solution, answer_norm)

    # (pattern, swap table, default replacement), in priority order.
    rules = (
        (
            RE_TRIG,
            {"sin": "cos", "cos": "sin", "tan": "cot", "tg": "ctg", "ctg": "tg", "cot": "tan"},
            "cos",
        ),
        (
            RE_TRIG_RU,
            {"синус": "косинус", "косинус": "синус", "тангенс": "котангенс", "котангенс": "тангенс"},
            "косинус",
        ),
    )
    for pattern, swaps, default in rules:
        hit = pattern.search(text)
        if not hit:
            continue
        replacement = swaps.get(hit.group(1).lower(), default)
        mutated = text[:hit.start(1)] + replacement + text[hit.end(1):]
        return mutated, ensure_diff(text, mutated)

    return text, False

def synth_swap_inequality(solution: str, answer_norm: str, rng: random.Random) -> Tuple[str, bool]:
    """Reverse the direction of the first inequality operator in the solution.

    Returns ``(text, changed)``; ``changed`` is False when the text has no
    inequality operator. ``rng`` is accepted for a uniform signature and is
    not used.
    """
    text = ensure_answer_line(solution, answer_norm)
    hit = RE_INEQ.search(text)
    if hit is None:
        return text, False
    reversed_ops = {"<": ">", ">": "<", "<=": ">=", ">=": "<=", "≤": "≥", "≥": "≤"}
    operator = hit.group(1)
    swapped = reversed_ops.get(operator, operator)
    left, right = text[:hit.start(1)], text[hit.end(1):]
    mutated = left + swapped + right
    return mutated, ensure_diff(text, mutated)

def synth_drop_odz(solution: str, answer_norm: str, rng: random.Random) -> Tuple[str, bool]:
    """Remove the first line that mentions "одз" or "допуст" (case-insensitive).

    The remaining lines are re-joined with "\\n". Returns ``(text, changed)``;
    ``changed`` is False when no such line exists. ``rng`` is accepted for a
    uniform signature and is not used.
    """
    text = ensure_answer_line(solution, answer_norm)
    rows = text.splitlines()
    drop_at = next(
        (pos for pos, row in enumerate(rows)
         if "одз" in row.lower() or "допуст" in row.lower()),
        None,
    )
    if drop_at is None:
        return text, False
    kept = rows[:drop_at] + rows[drop_at + 1:]
    mutated = "\n".join(kept)
    return mutated, ensure_diff(text, mutated)

# Error tags that `synth_wrong_by_tag` knows how to produce. The same list is
# used for round-robin target assignment and for picking alternative tags
# when the desired one cannot be applied.
TAGS = [
    "intermediate_arithmetic_slip_local",
    "intermediate_arithmetic_slip_propagated",
    "wrong_final_answer_replace",
    "flip_sign",
    "swap_trig",
    "swap_inequality",
    "drop_odz",
]

def synth_wrong_by_tag(solution: str, answer_norm: str, tag: str, rng: random.Random) -> Tuple[str, bool, str]:
    """Apply the corruption strategy named by ``tag``.

    Returns ``(text, ok, used_tag)``. For a known tag, ``used_tag`` is the tag
    itself and ``ok`` is the strategy's own success flag. For an unknown tag
    the solution (with its answer line ensured) is returned with ``False``
    and the tag "FAILED".
    """
    handlers = {
        "wrong_final_answer_replace": synth_wrong_final_answer_replace,
        "intermediate_arithmetic_slip_local": synth_intermediate_local_inconsistency,
        "intermediate_arithmetic_slip_propagated": synth_intermediate_propagated,
        "flip_sign": synth_flip_sign,
        "swap_trig": synth_swap_trig,
        "swap_inequality": synth_swap_inequality,
        "drop_odz": synth_drop_odz,
    }
    handler = handlers.get(tag)
    if handler is None:
        return ensure_answer_line(solution, answer_norm), False, "FAILED"
    out, ok = handler(solution, answer_norm, rng)
    return out, ok, tag

def synth_wrong_with_retries(solution: str, answer_norm: str, desired_tag: str, rng: random.Random,
                            max_tries_desired: int = 6) -> Tuple[str, str]:
    """Produce a corrupted solution, preferring ``desired_tag``.

    Order of attempts:
      1. ``desired_tag``, up to ``max_tries_desired`` times;
      2. every other tag from ``TAGS`` in shuffled order, 3 tries each;
      3. a final-answer replacement as the last resort.

    Returns ``(text, used_tag)``. ``used_tag`` is
    "wrong_final_answer_replace_fallback" when step 3 succeeded and
    "FAILED_FALLBACK" (with the uncorrupted text) when everything failed.
    """
    base = ensure_answer_line(solution, answer_norm)

    def _attempt(tag: str, times: int) -> Optional[Tuple[str, str]]:
        # Try one tag repeatedly; None means no attempt produced a real change.
        for _ in range(times):
            out, ok, used = synth_wrong_by_tag(base, answer_norm, tag, rng)
            if ok and ensure_diff(base, out):
                return out, used
        return None

    # 1) the desired tag
    result = _attempt(desired_tag, max_tries_desired)
    if result is not None:
        return result

    # 2) the remaining tags; shuffled only now so RNG use follows the attempts above
    alternatives = [t for t in TAGS if t != desired_tag]
    rng.shuffle(alternatives)
    for tag in alternatives:
        result = _attempt(tag, 3)
        if result is not None:
            return result

    # 3) final fallback (succeeds in almost all cases)
    out, ok = synth_wrong_final_answer_replace(base, answer_norm, rng)
    if ok:
        return out, "wrong_final_answer_replace_fallback"
    return base, "FAILED_FALLBACK"

# ---------- Cell 5: assign target tags per split ----------
def round_robin_tags(n: int, tags: List[str]) -> List[str]:
    """Return ``n`` tags obtained by cycling through ``tags`` in order."""
    cycled: List[str] = []
    for position in range(n):
        cycled.append(tags[position % len(tags)])
    return cycled

def assign_tags_per_split(df_in: pd.DataFrame, split_col: str, seed: int = 42) -> pd.Series:
    """Assign a target error tag to every row, balanced within each split.

    For "train" and for "val" separately, a round-robin list of ``TAGS`` of
    the right length is built and shuffled with a generator seeded by
    ``seed``. Rows then draw tags in order: rows whose split is "train" take
    from the train pool, all other rows from the val pool.

    Returns a Series named "target_tag" aligned with ``df_in.index``.
    """
    shuffler = random.Random(seed)

    pools = {}
    for split_name in ("train", "val"):
        row_count = int((df_in[split_col] == split_name).sum())
        pools[split_name] = round_robin_tags(row_count, TAGS)

    # Train is shuffled before val; the order matters for reproducibility.
    shuffler.shuffle(pools["train"])
    shuffler.shuffle(pools["val"])

    train_iter = iter(pools["train"])
    val_iter = iter(pools["val"])

    assigned = [
        next(train_iter if split_value == "train" else val_iter)
        for split_value in df_in[split_col].tolist()
    ]
    return pd.Series(assigned, index=df_in.index, name="target_tag")

# ---------- Cell 6: build balanced ----------
SEED = 42
rng = random.Random(SEED)

df2 = df.copy()
df2["target_tag"] = assign_tags_per_split(df2, COL_SPLIT, seed=SEED)

records: List[Dict[str, object]] = []
failed_fallback = 0


def _answer_as_text(value: object) -> str:
    """Convert an answer cell to text, mapping missing values to ""."""
    # A plain str() would turn None / NaN / pd.NA into "None" / "nan" / "<NA>",
    # which would then be appended to solutions as "Ответ: None." and fed to
    # the synthesizers as if it were a real answer.
    if value is None:
        return ""
    try:
        if pd.isna(value):
            return ""
    except (TypeError, ValueError):
        # Array-like cells have no single missing/not-missing answer; keep them.
        pass
    return str(value)


# GOLD: one correct example per source row (student solution == reference).
for row in df2.itertuples(index=False):
    uid = getattr(row, COL_UID)
    cond = getattr(row, COL_COND)
    sol = getattr(row, COL_SOL)
    ans = getattr(row, COL_ANS)
    sp  = getattr(row, COL_SPLIT)

    ans_text = _answer_as_text(ans)
    sol_ref = ensure_answer_line(str(sol), ans_text)

    records.append({
        "row_uid": str(uid),
        "split": str(sp),
        "label": 1,
        "error_tag": "gold",
        "condition_for_train": str(cond),
        "student_solution": sol_ref,
        "solution_ref": sol_ref,
        "answer_norm": ans_text,
    })

# WRONG: one corrupted example per source row, generated for its target tag.
for row in df2.itertuples(index=False):
    uid = getattr(row, COL_UID)
    cond = getattr(row, COL_COND)
    sol = getattr(row, COL_SOL)
    ans = getattr(row, COL_ANS)
    sp  = getattr(row, COL_SPLIT)
    desired = getattr(row, "target_tag")

    ans_text = _answer_as_text(ans)
    sol_ref = ensure_answer_line(str(sol), ans_text)
    wrong_text, used_tag = synth_wrong_with_retries(sol_ref, ans_text, str(desired), rng)

    # Count rows for which no strategy managed to change the solution.
    if used_tag == "FAILED_FALLBACK":
        failed_fallback += 1

    records.append({
        "row_uid": str(uid),
        "split": str(sp),
        "label": 0,
        "error_tag": used_tag,
        "condition_for_train": str(cond),
        "student_solution": str(wrong_text),
        "solution_ref": sol_ref,
        "answer_norm": ans_text,
    })

balanced = pd.DataFrame.from_records(records)

print("Balanced shape:", balanced.shape)
print("label:\n", balanced["label"].value_counts())
print("split:\n", balanced["split"].value_counts())
print("Top error_tag (wrong):\n", balanced[balanced["label"]==0]["error_tag"].value_counts().head(25))
print("FAILED_FALLBACK:", failed_fallback)

# ---------- Cell 7: diff snippets ----------
# Context windows around the first difference between reference and student text.
diff_correct = []
diff_student = []
for ref, stud in zip(balanced["solution_ref"].astype(str), balanced["student_solution"].astype(str)):
    a, b = first_diff_window(ref, stud, window=160)
    diff_correct.append(a)
    diff_student.append(b)

balanced["diff_correct"] = diff_correct
balanced["diff_student"] = diff_student

# For gold rows the diff is empty (expected: reference == student solution).
gold_empty = float((balanced.loc[balanced["label"]==1, "diff_student"].str.len() == 0).mean())
# `gold_empty` is a fraction in [0, 1]; scale it so the printed value matches
# the "%" label, consistent with the wrong-rows metric printed below.
print("Gold diff empty %:", gold_empty * 100)

# Sanity check: WRONG rows really differ from the reference.
wrong_diff_ratio = float(np.mean(
    (balanced.loc[balanced["label"]==0, "student_solution"].astype(str).values !=
     balanced.loc[balanced["label"]==0, "solution_ref"].astype(str).values)
))
print("Wrong differs from correct %:", wrong_diff_ratio * 100)

# ---------- Cell 8: save ----------
# Output directory, relative to the current working directory; created if missing.
OUT_DIR = Path("./synth_v21")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# The same table is written twice: as parquet and as a UTF-8 CSV.
OUT_PARQUET = OUT_DIR / "all_problems_balanced_synth_v21.parquet"
OUT_CSV     = OUT_DIR / "all_problems_balanced_synth_v21.csv"

balanced.to_parquet(OUT_PARQUET, index=False)
balanced.to_csv(OUT_CSV, index=False, encoding="utf-8")

print("Saved:", OUT_PARQUET)
print("Saved:", OUT_CSV)

# ---------- Cell 9: quick preview ----------
# несколько примеров WRONG разных тегов
# One WRONG row per error tag (the first in table order), then at most 8 of
# those picked with a fixed seed.
sample = (
    balanced.loc[balanced["label"].eq(0)]
    .groupby("error_tag", as_index=False)
    .head(1)
    .sample(
        n=min(8, balanced.loc[balanced["label"].eq(0), "error_tag"].nunique()),
        random_state=SEED,
    )
)

for r in sample.itertuples(index=False):
    print(f"\n=== UID: {r.row_uid} | tag: {r.error_tag} ===")
    # Snippets are capped at 300 characters for readability.
    print("[DIFF student]\n", (r.diff_student or "")[:300])
    print("[DIFF correct]\n", (r.diff_correct or "")[:300])


Loaded: (10178, 39)
Columns: ['id', 'condition', 'images_condition', 'solution', 'condition_rus', 'solution_rus', 'images_solution', 'category', 'subcategory', 'link', 'profile', 'file', 'images_condition_list', 'images_condition_count', 'images_condition_has', 'images_solution_list', 'images_solution_count', 'images_solution_has', 'condition_clean', 'condition_rus_clean', 'solution_clean', 'solution_rus_clean', 'answer_rus', 'answer_tex', 'answer_ref', 'condition_final', 'solution_ref_final', 'condition_chars', 'solution_chars', 'content_hash', 'row_uid', 'is_long_solution', 'is_long_condition', 'condition_for_train', 'solution_for_train', 'split', 'task_hash', 'answer_norm', 'answer_num']
Using: {'uid': 'row_uid', 'cond': 'condition_for_train', 'sol': 'solution_rus_clean', 'ans': 'answer_norm', 'split': 'split'}
Balanced shape: (20356, 8)
label:
 label
1    10178
0    10178
Name: count, dtype: int64
split:
 split
train    20036
val        320
Name: count, dtype: int64
Top error_tag (