# Quick‑start: Hard‑coded Training Script

This notebook trains a tiny BART‑style model **without any external config files**.
All hyper‑parameters are defined inline for clarity.

## 0. Install calt-x

In [None]:
!pip install calt-x

: 

## 1. Imports

In [1]:
from transformers import BartConfig, BartForConditionalGeneration as Transformer
from transformers import TrainingArguments
from calt import (
    PolynomialTrainer,
    data_loader,
)
import torch, random, numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x14bdad998990>

## 2. Dataset (tiny demo)

In [2]:
# Point to any dataset you like; here we assume the toy GCD dataset from the data‑generation notebook.
# TRAIN_PATH = "../samples/train_raw.txt"
# TEST_PATH = "../samples/test_raw.txt"
TRAIN_PATH = "/home/sato/workspace/calt/notebooks/dataset/partial_sum_problem/GF7_n=2/test_raw.txt"
TEST_PATH = "/home/sato/workspace/calt/notebooks/dataset/partial_sum_problem/GF7_n=2/test_raw.txt"
dataset, tokenizer, data_collator = data_loader(
    train_dataset_path=TRAIN_PATH,
    test_dataset_path=TEST_PATH,
    field="GF7",
    num_variables=2,
    max_degree=10,
    max_coeff=10,
    max_length=256,
)

In [3]:
sample1 = data_collator([dataset["train"][0]])

In [4]:
sample1

{'input_ids': tensor([[28,  9, 15, 16, 11, 15, 15, 26, 11, 17, 23,  9, 22, 15, 26, 14, 15, 16,
          26,  9, 18, 15, 14, 16, 17, 11, 15, 15, 26, 10, 18, 19,  9, 17, 19, 12,
          16, 18, 11, 15, 17, 29]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'decoder_input_ids': tensor([[28,  9, 15, 16, 11, 15, 15, 26, 11, 17, 23,  9, 22, 15,  9, 15, 16, 11,
          15, 15, 26, 11, 17, 23,  9, 22, 15, 11, 15, 15, 26, 11, 17, 23,  9, 22,
          15,  9, 18, 15, 14, 16, 17, 14, 15, 15, 26, 11, 17, 23,  9, 22, 15, 10,
          18, 19,  9, 17, 19, 12, 16, 18,  9, 18, 15, 14, 16, 17, 11, 15, 17, 14,
          15, 15]]),
 'decoder_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [5]:
print(tokenizer.decode(sample1["input_ids"][0]))
print(tokenizer.decode(sample1["decoder_input_ids"][0]))
print(tokenizer.decode(sample1["labels"][0]))

<s> C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 [SEP] C6 E0 E1 [SEP] C1 E3 E0 C6 E1 E2 C3 E0 E0 [SEP] C2 E3 E4 C1 E2 E4 C4 E1 E3 C3 E0 E2 </s>
<s> C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C1 E3 E0 C6 E1 E2 C6 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C2 E3 E4 C1 E2 E4 C4 E1 E3 C1 E3 E0 C6 E1 E2 C3 E0 E2 C6 E0 E0
C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C1 E3 E0 C6 E1 E2 C6 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C2 E3 E4 C1 E2 E4 C4 E1 E3 C1 E3 E0 C6 E1 E2 C3 E0 E2 C6 E0 E0 </s>


In [6]:
sample2 = data_collator([dataset["test"][0]])

In [7]:
sample2

{'input_ids': tensor([[28,  9, 15, 16, 11, 15, 15, 26, 11, 17, 23,  9, 22, 15, 26, 14, 15, 16,
          26,  9, 18, 15, 14, 16, 17, 11, 15, 15, 26, 10, 18, 19,  9, 17, 19, 12,
          16, 18, 11, 15, 17, 29]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'decoder_input_ids': tensor([[28,  9, 15, 16, 11, 15, 15, 26, 11, 17, 23,  9, 22, 15,  9, 15, 16, 11,
          15, 15, 26, 11, 17, 23,  9, 22, 15, 11, 15, 15, 26, 11, 17, 23,  9, 22,
          15,  9, 18, 15, 14, 16, 17, 14, 15, 15, 26, 11, 17, 23,  9, 22, 15, 10,
          18, 19,  9, 17, 19, 12, 16, 18,  9, 18, 15, 14, 16, 17, 11, 15, 17, 14,
          15, 15]]),
 'decoder_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [8]:
print(tokenizer.decode(sample2["input_ids"][0]))
print(tokenizer.decode(sample2["decoder_input_ids"][0]))
print(tokenizer.decode(sample2["labels"][0]))

<s> C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 [SEP] C6 E0 E1 [SEP] C1 E3 E0 C6 E1 E2 C3 E0 E0 [SEP] C2 E3 E4 C1 E2 E4 C4 E1 E3 C3 E0 E2 </s>
<s> C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C1 E3 E0 C6 E1 E2 C6 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C2 E3 E4 C1 E2 E4 C4 E1 E3 C1 E3 E0 C6 E1 E2 C3 E0 E2 C6 E0 E0
C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C1 E0 E1 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C3 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C1 E3 E0 C6 E1 E2 C6 E0 E0 [SEP] C3 E2 E8 C1 E7 E0 C2 E3 E4 C1 E2 E4 C4 E1 E3 C1 E3 E0 C6 E1 E2 C3 E0 E2 C6 E0 E0 </s>


## 3. Model

In [9]:
tokenizer.bos_token_id

28

In [10]:
# Minimal architecture — only overriding d_model for speed.
model_cfg = BartConfig(
    d_model=256,
    vocab_size=len(tokenizer.vocab),
    encoder_layers=6,
    decoder_layers=6,
    max_position_embeddings=256,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    max_length=256,
)
model = Transformer(config=model_cfg)

## 4. TrainingArguments (defaults + a few essentials)

In [11]:
args = TrainingArguments(
    output_dir="results/demo",
    num_train_epochs=10,
    logging_steps=50,
    per_device_train_batch_size=int(32),
    per_device_eval_batch_size=int(32),
    save_strategy="no",  # skip checkpoints for the quick demo
    seed=SEED,
    remove_unused_columns=False,
    label_names=["labels"],
    report_to="wandb",
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


## 5. Trainer & run

In [12]:
trainer = PolynomialTrainer(
    args=args,
    model=model,
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=dataset["train"],  # slice for speed
    eval_dataset=dataset["test"],
)

# train
results = trainer.train()
trainer.save_model()
metrics = results.metrics

# eval
eval_metrics = trainer.evaluate()
metrics.update(eval_metrics)
acc = trainer.generate_evaluation(max_length=128)
metrics["test_accuracy"] = acc

# save metrics
trainer.save_metrics("all", metrics)

[34m[1mwandb[0m: Currently logged in as: [33msugarl[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss




In [14]:
from sympy import symbols
from comparison_vis import display_with_diff

# --- sample tokens ---
gold_tokens = "C1 E1 E1 C-1 E2 E1 C-3 E1 E8 C-1 E0 E7"
pred_tokens = "C1 E1 E1 C-2 E2 E1 C-3 E1 E9 C-2 E0 E7"   # 係数・指数をわざと誤る

x, y = symbols("x y")
gold_expr = parse_poly(gold_tokens, [x, y])
pred_expr = parse_poly(pred_tokens, [x, y])

display_with_diff(gold_expr, pred_expr, [x, y])


NameError: name 'parse_poly' is not defined

The code above is all you need for a first experiment.  
Increase `num_train_epochs`, remove the slicing, and enable checkpointing/WandB when you move from a demo to full‑scale training.

In [2]:
from typing import Sequence, Union
from sympy import symbols, Symbol, Integer


def parse_poly(tokens: str, var_names: Sequence[Union[str, Symbol]] | None = None):
    """
    Convert an internal token sequence (e.g. ``"C1 E1 E1 C-3 E0 E7"``)
    into a SymPy polynomial.

    Parameters
    ----------
    tokens : str
        Whitespace-separated string where a token starting with ``C`` indicates
        a coefficient and the following ``E`` tokens indicate exponents.
    var_names : Sequence[str | sympy.Symbol] | None, optional
        Variable names (either strings or pre-created Symbol objects).  If
        ``None`` (default), variables are auto-generated as x0, x1, …

    Returns
    -------
    sympy.Expr
        A SymPy expression corresponding to the polynomial.

    Raises
    ------
    ValueError
        If the token sequence is malformed or the number of variables does not
        match ``var_names``.
    """
    parts = tokens.strip().split()
    if not parts or not parts[0].startswith("C"):
        raise ValueError("Token sequence must start with a 'C' coefficient token.")

    # --- Infer the number of variables from the first term ------------------ #
    try:
        next_c_idx = parts.index(next(p for p in parts[1:] if p.startswith("C")))
        n_vars = next_c_idx - 1
    except StopIteration:  # single-term polynomial
        n_vars = len([p for p in parts[1:] if p.startswith("E")])

    # --- Prepare SymPy symbols --------------------------------------------- #
    if var_names is None:
        vars_ = symbols(" ".join(f"x{i}" for i in range(n_vars)))
    else:
        if len(var_names) != n_vars:
            raise ValueError(
                f"Expected {n_vars} variable name(s), got {len(var_names)}."
            )
        # If elements are strings, create Symbols; if they are already Symbols, reuse them.
        if all(isinstance(v, str) for v in var_names):
            vars_ = symbols(" ".join(var_names))  # -> tuple(Symbol, …)
        elif all(isinstance(v, Symbol) for v in var_names):
            vars_ = tuple(var_names)
        else:
            raise TypeError("var_names must be all str or all sympy.Symbol.")

    expr = Integer(0)
    i = 0
    while i < len(parts):
        # Read coefficient
        coeff_token = parts[i]
        if not coeff_token.startswith("C"):
            raise ValueError(f"Expected 'C' token at position {i}, got {coeff_token}.")
        coeff = Integer(coeff_token[1:])
        i += 1

        # Read exponents
        exps = []
        for _ in range(n_vars):
            if i >= len(parts) or not parts[i].startswith("E"):
                raise ValueError(f"Missing 'E' token at position {i}.")
            exps.append(Integer(parts[i][1:]))
            i += 1

        # Build term
        term = coeff
        for v, e in zip(vars_, exps):
            term *= v**e
        expr += term

    return expr

In [2]:
from sympy import symbols, expand, init_printing

init_printing(use_latex="mathjax")

tokens = "C1 E1 E1 C-1 E2 E1 C-3 E1 E8 C-1 E0 E7"
x, y = symbols("x y")

# Case 1: pass SymPy symbols
poly1 = parse_poly(tokens, [x, y])
print(expand(poly1))
# x*y - x**2*y - 3*x*y**8 - y**7

# Case 2: pass variable names as strings
poly2 = parse_poly(tokens, ["x", "y"])
print(expand(poly2))
# x*y - x**2*y - 3*x*y**8 - y**7

-x**2*y - 3*x*y**8 + x*y - y**7
-x**2*y - 3*x*y**8 + x*y - y**7


In [3]:
poly2

   2          8          7
- x ⋅y - 3⋅x⋅y  + x⋅y - y 

In [4]:
from typing import Mapping
import re
from sympy import Expr, Symbol, latex
from IPython.display import display, Math


def display_poly_highlight(
    expr: Expr,
    highlight_syms: Mapping[Symbol, str] | None = None,
    exp_cmd: str | None = None,
):
    """
    Render a SymPy expression in Colab with optional coloring/formatting.

    Parameters
    ----------
    expr : sympy.Expr
        The polynomial (or any expression) to display.
    highlight_syms : Mapping[Symbol, str] | None
        Mapping from a Symbol to its LaTeX replacement string
        (e.g. {x: r"\\color{red}{x}"}).
    exp_cmd : str | None
        LaTeX command (without backslash) to wrap every **numeric exponent**.
        Examples: "color{blue}", "mathbf", "underline".
        Pass ``None`` to leave exponents unchanged.

    Notes
    -----
    - Requires MathJax (enabled by default in Google Colab).
    - ``highlight_syms`` uses the `symbol_names` feature of SymPy's
      ``latex()`` printer, so the expression is re-LaTeXed only once.
    - Exponent styling is done via a regex post-processing step because
      SymPy does not expose each exponent token separately.
    """
    # 1) custom names for selected symbols ---------------------------------- #
    sym_names = {s: rep for s, rep in (highlight_syms or {}).items()}

    # 2) base LaTeX string
    tex = latex(expr, symbol_names=sym_names)

    # 3) post-process numeric exponents if requested ------------------------ #
    if exp_cmd:
        # Match ^{<digits>} but ignore already formatted ones
        tex = re.sub(
            r"\^\{(\d+)\}",
            lambda m: f"^{{\\{exp_cmd}{{{m.group(1)}}}}}",
            tex,
        )

    display(Math(tex))

In [5]:
from sympy import symbols, Poly

# サンプル多項式
x, y = symbols("x y")
expr = x * y - x**2 * y - 3 * x * y**8 - y**7

# 変数 x を赤、指数を青太字に
display_poly_highlight(
    expr,
    highlight_syms={x: r"\color{red}{x}"},
    exp_cmd="color{blue}\\mathbf",
)

<IPython.core.display.Math object>

In [5]:
from sympy import symbols
from comparison_vis import display_with_diff

# --- sample tokens ---
gold_tokens = "C1 E1 E1 C-1 E2 E1 C-3 E1 E8 C-1 E0 E7"
pred_tokens = "C1 E1 E1 C-2 E2 E1 C-3 E1 E9 C-2 E0 E7"  # 係数・指数をわざと誤る

x, y = symbols("x y")
gold_expr = parse_poly(gold_tokens, [x, y])
pred_expr = parse_poly(pred_tokens, [x, y])

display_with_diff(gold_expr, pred_expr, [x, y])

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [9]:
import re
from sympy import symbols, sympify


def strip_mod(expr_str: str) -> str:
    """
    Remove all substrings like ' mod 7', ' mod 23', … from a string.

    Parameters
    ----------
    expr_str : str
        String representation of a SymPy polynomial such as
        '6 mod 7*x**2*y + 4 mod 7*x*y**8 + x*y + 6 mod 7*y**7'.

    Returns
    -------
    str
        Cleaned string, e.g.
        '6*x**2*y + 4*x*y**8 + x*y + 6*y**7'.
    """
    # pattern: optional spaces + 'mod' + optional spaces + digits
    pattern = r"\s*mod\s*\d+"
    cleaned = re.sub(pattern, "", expr_str)
    # collapse multiple spaces that may remain
    cleaned = re.sub(r"\s{2,}", " ", cleaned)
    return cleaned.strip()

In [10]:
from sympy.polys.domains import GF
from sympy.polys import PolynomialRing

# --- GF(7) 多項式を例として生成 ----------------------------------- #
x, y = symbols("x y")
R = PolynomialRing(GF(7), x, y)
f = R.from_dict({(2, 1): 6, (1, 8): 4, (1, 1): 1, (0, 7): 6})

# ① PolyElement → 文字列
s = str(f)  # '6 mod 7*x**2*y + 4 mod 7*x*y**8 + x*y + 6 mod 7*y**7'

# ② ' mod n' 部分を削除
clean_s = strip_mod(s)  # '6*x**2*y + 4*x*y**8 + x*y + 6*y**7'
print(clean_s)

# ③ 必要なら再度 SymPy 式へ
clean_expr = sympify(clean_s)
print(clean_expr.expand())
# 6*x**2*y + 4*x*y**8 + x*y + 6*y**7

ValueError: supported monomial orderings are 'lex', 'grlex' and 'grevlex', got 'y'