# **calt Demo Notebook**

This notebook shows a minimal end‑to‑end workflow for the **calt** library:

1. **Install and import** the library  
2. **Generate** a dataset of *polynomial‑sum* examples  
3. **Configure** the tokenizer and model  
4. **Train** the Transformer  
5. **Visualize** the training results  


## 1  – Installation & Imports  
Run the next cell to ensure **calt** and its dependencies are installed, then import the required Python packages.  


In [None]:
!pip install calt-x

In [3]:
import random
from typing import List, Tuple

import numpy as np
import torch
from sympy import ZZ
from sympy.polys.rings import ring, PolyElement
from transformers import BartConfig, BartForConditionalGeneration as Transformer
from transformers import TrainingArguments
from calt import (
    PolynomialTrainer,
    data_loader,
)
from calt.generate import (
    PolynomialSampler,
    DatasetGenerator,
    DatasetWriter,
)
from calt.data_loader.utils import (
    load_eval_results,
    parse_poly,
    display_with_diff,
)

# Fix the global random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

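
Global seeding covers library-level randomness, but the problem generator defined in Section 2 also reseeds `random` at the start of every `__call__`. The sketch below illustrates why per-sample seeding is useful: each sample depends only on its own seed, so it can be regenerated in any order or on any worker. The helper `make_sample` and the seed derivation `root_seed + i` are hypothetical stand-ins; how `DatasetGenerator` actually derives per-sample seeds from `root_seed` is not shown in this notebook.

```python
import random
from typing import List


def make_sample(seed: int) -> List[int]:
    """Hypothetical stand-in for a problem generator: draws a few integers."""
    rng = random.Random(seed)  # private RNG seeded per sample
    return [rng.randint(-10, 10) for _ in range(3)]


root_seed = 100
seeds = [root_seed + i for i in range(5)]  # hypothetical per-sample seeds

forward = [make_sample(s) for s in seeds]
backward = [make_sample(s) for s in reversed(seeds)][::-1]

# The same seeds give the same samples regardless of generation order.
print(forward == backward)  # True
```

Using a private `random.Random` instance, rather than reseeding the global module as the notebook's generator does, keeps each sample reproducible without disturbing other code that shares the global generator.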

## 2  – Dataset Generation *(Polynomial Addition)*  
This cell uses the `calt.generate` utilities to create a synthetic dataset of polynomial-addition problems: each input is a list of polynomials and the target is their sum.

In [4]:
class SumProblemGenerator:
    """
    Problem generator for sum problems involving polynomials.

    This generator produces problems whose input is a list of polynomials F = [f_1, f_2, ..., f_n]
    and whose output is their sum g = f_1 + f_2 + ... + f_n.
    """

    def __init__(
        self, sampler: PolynomialSampler, max_polynomials: int, min_polynomials: int
    ):
        """
        Initialize the polynomial-sum problem generator.

        Args:
            sampler: Polynomial sampler
            max_polynomials: Maximum number of polynomials in F
            min_polynomials: Minimum number of polynomials in F
        """

        self.sampler = sampler
        self.max_polynomials = max_polynomials
        self.min_polynomials = min_polynomials

    def __call__(self, seed: int) -> Tuple[List[PolyElement], PolyElement]:
        """
        Generate a single sample.

        Each sample consists of:
        - Input polynomial list F
        - Output polynomial g, the sum of all polynomials in F

        Args:
            seed: Seed for random number generator

        Returns:
            Tuple containing (F, g)
        """

        # Set random seed
        random.seed(seed)

        # Choose number of polynomials for this sample
        num_polys = random.randint(self.min_polynomials, self.max_polynomials)

        # Generate input polynomials using sampler
        F = self.sampler.sample(num_samples=num_polys)

        # Sum the input polynomials to obtain the target g
        current_sum = 0
        for f in F:
            current_sum += f

        return F, current_sum
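
The target `g` above is computed with sympy's ring arithmetic (`current_sum += f`). The sketch below shows what that addition does at the term level: like monomials combine, and terms that cancel disappear, so the target can have fewer terms than the inputs. The `{(e0, e1): coefficient}` dictionary representation is an assumption chosen so the example runs without sympy; it is not calt's internal format.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical representation of a polynomial in x0, x1: {(e0, e1): coefficient}
Poly = Dict[Tuple[int, int], int]


def poly_sum(polys: List[Poly]) -> Poly:
    """Add polynomials by summing the coefficients of matching monomials."""
    total: Dict[Tuple[int, int], int] = defaultdict(int)
    for p in polys:
        for monomial, coeff in p.items():
            total[monomial] += coeff
    # Drop monomials whose coefficients cancelled to zero.
    return {m: c for m, c in total.items() if c != 0}


# F = [3*x0**2 - 5*x1, 5*x1 + 7]  ->  g = 3*x0**2 + 7
F = [{(2, 0): 3, (0, 1): -5}, {(0, 1): 5, (0, 0): 7}]
g = poly_sum(F)
print(g)  # {(2, 0): 3, (0, 0): 7}
```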


In [15]:
save_dir = "."

# set up polynomial ring
R, *gens = ring("x0,x1", ZZ, order="grevlex")
# Initialize polynomial sampler
sampler = PolynomialSampler(
    ring=R,
    max_num_terms=2,
    max_degree=2,
    min_degree=1,
    degree_sampling="uniform",  # "uniform" or "fixed"
    term_sampling="uniform",  # "uniform" or "fixed"
    max_coeff=10,  # Used for RR and ZZ
    num_bound=None,  # Used for QQ
    strictly_conditioned=False,
    nonzero_instance=True,
)
# Initialize problem generator
problem_generator = SumProblemGenerator(
    sampler=sampler,
    max_polynomials=2,
    min_polynomials=2,
)
# Initialize dataset generator
dataset_generator = DatasetGenerator(
    backend="multiprocessing",
    n_jobs=1,  # note: values other than 1 may raise an error in the current version
    verbose=True,
    root_seed=100,
)
# Generate training set
train_samples, _ = dataset_generator.run(
    train=True,
    num_samples=10000,
    problem_generator=problem_generator,
)
# Generate test set
test_samples, _ = dataset_generator.run(
    train=False,
    num_samples=1000,
    problem_generator=problem_generator,
)
# Initialize writer
dataset_writer = DatasetWriter(save_dir)
# Save datasets
dataset_writer.save_dataset(train_samples, tag="train")
dataset_writer.save_dataset(test_samples, tag="test")

[Parallel(n_jobs=1)]: Done  49 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 199 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 449 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 799 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 1249 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 1799 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 2449 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 3199 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 4999 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 6049 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 7199 tasks      | elapsed:    0.6s
[Parallel(n_jobs=1)]: Done 8449 tasks      | elapsed:    0.6s
[Parallel(n_jobs=1)]: Done 9799 tasks      | elapsed:    0.7s
[Parallel(n_jobs=1)]: Done 10000 out of 10000 | elapsed:    0.7s finished
[Parallel(n_jobs=1)]: Done  49 tasks      | elapsed:    0.0s
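
The `PolynomialSampler` arguments bound the shape of each input polynomial. The sketch below is a hypothetical reading of those bounds, not calt's implementation: it draws a total degree uniformly from `[min_degree, max_degree]`, a term count uniformly from `[1, max_num_terms]`, and nonzero integer coefficients from `[-max_coeff, max_coeff]`. One consequence holds regardless of the sampler's details: if both inputs have coefficients bounded by `max_coeff` in absolute value, their sum can have coefficients up to `2 * max_coeff`.

```python
import random
from typing import Dict, Tuple

# Hypothetical representation: {(e0, e1): coefficient} for a polynomial in x0, x1.
Poly = Dict[Tuple[int, int], int]


def sample_poly(rng: random.Random, max_num_terms: int = 2, max_degree: int = 2,
                min_degree: int = 1, max_coeff: int = 10) -> Poly:
    """Hypothetical uniform sampler mirroring the notebook's arguments (not calt's code)."""
    degree = rng.randint(min_degree, max_degree)  # uniform total degree
    num_terms = rng.randint(1, max_num_terms)     # uniform number of terms
    # Force one monomial of exactly `degree`; draw the rest from all monomials of degree <= `degree`.
    lead = rng.choice([(a, degree - a) for a in range(degree + 1)])
    pool = [(a, d - a) for d in range(degree + 1) for a in range(d + 1)]
    pool.remove(lead)
    monomials = [lead] + rng.sample(pool, min(num_terms - 1, len(pool)))
    # Nonzero integer coefficients drawn uniformly from [-max_coeff, max_coeff].
    coeffs = [c for c in range(-max_coeff, max_coeff + 1) if c != 0]
    return {m: rng.choice(coeffs) for m in monomials}


rng = random.Random(0)
f1, f2 = sample_poly(rng), sample_poly(rng)
# The target adds matching monomials, so a coefficient can reach 2 * max_coeff.
g = {m: f1.get(m, 0) + f2.get(m, 0) for m in f1.keys() | f2.keys()}
print(f1, f2, g)
```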

## 3  – Model Configuration  
Here we load the saved datasets together with a matching tokenizer and data collator, then define the Transformer (BART) architecture.  


In [16]:
# Point to any dataset you like; here we use the polynomial-sum dataset generated in Section 2.
TRAIN_PATH = "train_raw.txt"
TEST_PATH = "test_raw.txt"
dataset, tokenizer, data_collator = data_loader(
    train_dataset_path=TRAIN_PATH,
    test_dataset_path=TEST_PATH,
    field="ZZ",
    num_variables=2,
    max_degree=10,
    max_coeff=10,
    max_length=256,
)
train_dataset = dataset["train"]
test_dataset = dataset["test"]
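
`data_loader` receives `field`, `num_variables`, `max_degree`, and `max_coeff` alongside the dataset paths and returns a tokenizer, which suggests a vocabulary built from coefficient and exponent tokens. The sketch below illustrates one such scheme. The token names (`C<coeff>`, `E<exp>`, `+`, `<s>`, `</s>`, `[PAD]`) and the term layout are hypothetical and may not match calt's actual tokenizer. Under this scheme the vocabulary grows linearly with `max_coeff` and `max_degree`, `num_variables` fixes how many exponent tokens each term carries, and a coefficient or exponent outside the configured bounds has no token at all.

```python
from typing import Dict, List, Tuple

# Hypothetical polynomial form: {(e0, e1, ...): coefficient}
Poly = Dict[Tuple[int, ...], int]

SPECIAL = ["[PAD]", "<s>", "</s>"]  # hypothetical special-token names


def build_vocab(max_degree: int, max_coeff: int) -> Dict[str, int]:
    """Hypothetical vocabulary: one token per coefficient value and per exponent value."""
    tokens = list(SPECIAL)
    tokens += [f"C{c}" for c in range(-max_coeff, max_coeff + 1)]
    tokens += [f"E{e}" for e in range(max_degree + 1)]
    tokens.append("+")  # separator between terms
    return {tok: i for i, tok in enumerate(tokens)}


def encode(poly: Poly, vocab: Dict[str, int]) -> List[int]:
    """Encode each term as 'C<coeff> E<e0> E<e1> ...'; raises KeyError if out of range."""
    words: List[str] = []
    for i, (monomial, coeff) in enumerate(sorted(poly.items(), reverse=True)):
        if i > 0:
            words.append("+")
        words.append(f"C{coeff}")
        words += [f"E{e}" for e in monomial]  # one exponent token per variable
    return [vocab["<s>"]] + [vocab[w] for w in words] + [vocab["</s>"]]


vocab = build_vocab(max_degree=10, max_coeff=10)
print(len(vocab))  # 36
print(encode({(2, 0): 3, (0, 0): 7}, vocab))
```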

In [17]:
# Minimal BART encoder-decoder: 2 encoder layers, 2 decoder layers, hidden size 256.
model_cfg = BartConfig(
    d_model=256,
    vocab_size=len(tokenizer.vocab),
    encoder_layers=2,
    decoder_layers=2,
    max_position_embeddings=256,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    max_length=256,
)
model = Transformer(config=model_cfg)
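
`decoder_start_token_id` determines how BART builds decoder inputs from the labels during training when only `labels` are passed: the labels are shifted one position to the right, the start token is placed first, and any `-100` (the ignore index used to mask the loss) is replaced by the pad token. Hugging Face implements this as `shift_tokens_right` in the BART modeling code; the sketch below mirrors that logic on plain Python lists so it can be inspected without a model. The token ids are made up, and whether calt's data collator supplies its own decoder inputs is not shown here.

```python
from typing import List


def shift_tokens_right(labels: List[List[int]], pad_token_id: int,
                       decoder_start_token_id: int) -> List[List[int]]:
    """List-based mirror of BART's shift_tokens_right (teacher-forcing inputs)."""
    shifted = []
    for row in labels:
        # Prepend the start token and drop the last label.
        new_row = [decoder_start_token_id] + row[:-1]
        # -100 marks positions ignored by the loss; feed pad tokens there instead.
        shifted.append([pad_token_id if t == -100 else t for t in new_row])
    return shifted


# Made-up ids: bos=0, pad=1, eos=2; -100 pads the label sequence.
print(shift_tokens_right([[5, 6, 7, 2, -100]], pad_token_id=1, decoder_start_token_id=0))
# [[0, 5, 6, 7, 2]]
print(shift_tokens_right([[5, 2, -100, -100]], pad_token_id=1, decoder_start_token_id=0))
# [[0, 5, 2, 1]]
```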

## 4  – Training Hyper‑parameters  
Batch size, number of epochs, logging frequency, and other trainer options are declared in this cell; the learning rate and its schedule are left at the `TrainingArguments` defaults.  


In [18]:
args = TrainingArguments(
    output_dir="results/",
    num_train_epochs=20,
    logging_steps=50,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    save_strategy="no",  # skip checkpoints for the quick demo
    seed=SEED,
    remove_unused_columns=False,
    label_names=["labels"],
    report_to="none",
)
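
With these settings, the number of optimizer steps follows from simple arithmetic. The sketch assumes a single device, that `PolynomialTrainer` keeps the standard `transformers.Trainer` loop, no gradient accumulation (`gradient_accumulation_steps` defaults to 1), and the default `dataloader_drop_last=False`, so the final partial batch of each epoch still counts as a step.

```python
import math

num_train = 10_000   # training samples generated in Section 2
batch_size = 128     # per_device_train_batch_size (single device assumed)
epochs = 20          # num_train_epochs
logging_steps = 50

steps_per_epoch = math.ceil(num_train / batch_size)
total_steps = steps_per_epoch * epochs
num_log_points = total_steps // logging_steps

print(steps_per_epoch, total_steps, num_log_points)  # 79 1580 31
```

For comparison, the loss log printed in Section 5 lists entries only up to step 200.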

## 5  – Model Training  
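Launch the training loop. Training loss is reported every 50 steps (`logging_steps=50`); because `report_to="none"`, no external tracker such as Weights & Biases is used. After training, the model is evaluated on the test set and all metrics are saved.  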


In [19]:
trainer = PolynomialTrainer(
    args=args,
    model=model,
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# train
results = trainer.train()
trainer.save_model()
metrics = results.metrics

# eval
eval_metrics = trainer.evaluate()
metrics.update(eval_metrics)
success_rate = trainer.generate_evaluation(max_length=128)
metrics["success_rate"] = success_rate

# save metrics
trainer.save_metrics("all", metrics)



| Step | Training Loss |
|-----:|--------------:|
| 50 | 2.0546 |
| 100 | 1.2525 |
| 150 | 1.1075 |
| 200 | 1.016 |
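
`generate_evaluation` returns a success rate that the notebook stores under `metrics["success_rate"]`. A natural reading is exact-match accuracy: the fraction of test inputs whose generated output equals the reference. The sketch below computes that quantity from paired generated and reference strings. Treating whitespace-normalized string equality as the success criterion is an assumption, since calt's exact definition is not shown here, and the token strings are made-up examples.

```python
from typing import List


def exact_match_rate(generated: List[str], references: List[str]) -> float:
    """Fraction of predictions that match their reference after whitespace normalization."""
    if len(generated) != len(references):
        raise ValueError("generated and references must have the same length")
    if not references:
        return 0.0
    hits = sum(" ".join(g.split()) == " ".join(r.split())
               for g, r in zip(generated, references))
    return hits / len(references)


gen = ["C3 E2 E0 + C7 E0 E0", "C1 E1 E0", "C-2 E0 E1"]
ref = ["C3 E2 E0 + C7 E0 E0", "C1  E1 E0", "C2 E0 E1"]
print(exact_match_rate(gen, ref))  # 0.6666666666666666
```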




## 6  – Visualizing Training Results  
Finally, we compare each mispredicted test sample with its reference answer and highlight the differences.


In [20]:
gen_texts, ref_texts = load_eval_results("results/demo/eval_results.json")

for i, (gen, ref) in enumerate(zip(gen_texts, ref_texts)):
    gen_expr = parse_poly(gen, ["x", "y"])
    ref_expr = parse_poly(ref, ["x", "y"])
    if gen_expr != ref_expr:
        print(f"===== Failure case: {i+1} =====")
        display_with_diff(ref_expr, gen_expr)

===== Failure case: 160 =====
===== Failure case: 178 =====
===== Failure case: 181 =====
===== Failure case: 195 =====
===== Failure case: 205 =====
===== Failure case: 235 =====
===== Failure case: 295 =====
===== Failure case: 299 =====
===== Failure case: 358 =====
===== Failure case: 530 =====
===== Failure case: 580 =====
===== Failure case: 614 =====
===== Failure case: 639 =====
===== Failure case: 694 =====
===== Failure case: 733 =====
===== Failure case: 772 =====
===== Failure case: 884 =====
===== Failure case: 939 =====
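
`display_with_diff` renders each mismatch as LaTeX. For a quick text-only check, a failure can also be summarized term by term: which monomials are missing from the prediction, which are spurious, and which have the wrong coefficient. The sketch below does this on a hypothetical `{exponents: coefficient}` dictionary form; converting the parsed expressions into that form is left out, and the helper `term_diff` is not part of calt.

```python
from typing import Dict, Tuple

Poly = Dict[Tuple[int, ...], int]  # hypothetical {exponents: coefficient} form


def term_diff(ref: Poly, gen: Poly) -> Dict[str, Dict[Tuple[int, ...], object]]:
    """Classify term-level differences between a reference and a prediction."""
    missing = {m: c for m, c in ref.items() if m not in gen}
    spurious = {m: c for m, c in gen.items() if m not in ref}
    wrong = {m: (ref[m], gen[m]) for m in ref.keys() & gen.keys() if ref[m] != gen[m]}
    return {"missing": missing, "spurious": spurious, "wrong_coeff": wrong}


# Reference 3*x**2 + 7, prediction 3*x**2 + 6 + x*y
ref = {(2, 0): 3, (0, 0): 7}
gen = {(2, 0): 3, (0, 0): 6, (1, 1): 1}
print(term_diff(ref, gen))
# {'missing': {}, 'spurious': {(1, 1): 1}, 'wrong_coeff': {(0, 0): (7, 6)}}
```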