# **calt Demo Notebook**

This notebook shows a minimal end‑to‑end workflow for the **calt** library:

1. **Install and import** the library  
2. **Generate** a dataset of *polynomial‑sum* examples  
3. **Configure** the tokenizer and model  
4. **Train** the Transformer  
5. **Visualize** the results  


Note on Google Colab:
- Switch to a GPU runtime (e.g., a T4 GPU) via Runtime -> Change runtime type.
- This notebook uses the `SymPy` backend to keep installation dependencies minimal. For extensive usage, we recommend the `SageMath` backend, which, for example, supports parallel sample generation.

## 1  – Installation & Imports  
Run the next cell to ensure **calt** and its dependencies are installed, then import the required Python packages.  


In [None]:
%%capture
%pip install calt-x

In [1]:
import random

import numpy as np
import torch
from sympy.polys.rings import PolyElement
from transformers import BartConfig, BartForConditionalGeneration as Transformer
from transformers import TrainingArguments

from calt import (
    Trainer,
    load_data,
)
from calt.dataset_generator.sympy import (
    PolynomialSampler,
    DatasetGenerator,
)
from calt.data_loader.utils import (
    load_eval_results,
    display_with_diff,
)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


## 2  – Dataset Generation *(Polynomial Addition)*  
This section uses the `calt.dataset_generator.sympy` utilities to create a synthetic polynomial‑addition dataset.
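For concreteness, here is one hand-picked (not sampled) instance of the task in the same setting the sampler uses: two variables over GF(7). Because coefficients are reduced mod 7, terms can cancel in the sum:

$$
F = [f_1, f_2], \qquad f_1 = 3x_0^2 + 5x_1, \quad f_2 = 4x_0^2 + x_0 x_1 \in \mathbb{F}_7[x_0, x_1]
$$

$$
g = f_1 + f_2 = (3 + 4)\,x_0^2 + x_0 x_1 + 5x_1 \equiv x_0 x_1 + 5x_1 \pmod{7}
$$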

In [3]:
def sum_problem_generator(
    seed: int,
) -> tuple[list[PolyElement], PolyElement]:
    """
    Generate a polynomial-sum problem.

    The problem is a list of polynomials F = [f_1, f_2] sampled over GF(7),
    and the solution is their sum g = f_1 + f_2.

    Args:
        seed: Seed for the random number generator

    Returns:
        Tuple (F, g) where F is the problem and g is the solution
    """
    # Set random seed
    random.seed(seed)

    # Initialize polynomial sampler
    sampler = PolynomialSampler(
        symbols="x0, x1", # "x, y, z, ... " or "x0, x1, x2, ... "
        field_str="GF(7)", # "QQ", "RR", "ZZ", "GF(p)", "GFp", where p is a prime number
        order="grevlex", # "lex", "grevlex", "grlex", "ilex", "igrevlex", "igrlex"
        max_num_terms=2,
        max_degree=2,
        min_degree=1,
    )

    # Generate problem polynomials using sampler
    F = sampler.sample(num_samples=2)

    # Generate solution polynomial g (sum of F)
    g = sum(F)

    return F, g
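If you want to experiment without calt's `PolynomialSampler`, the sketch below is a rough stand-in built only on `random` and SymPy's `ring`. The helper names `random_poly` and `toy_sum_problem` are hypothetical, and the sampling distribution (uniform total degree, uniform split of that degree between `x0` and `x1`, nonzero coefficients) is our own assumption, so it will not reproduce calt's samples. It also uses a local `random.Random` instead of reseeding the global generator.

```python
import random

from sympy import GF
from sympy.polys.orderings import grevlex
from sympy.polys.rings import ring

# Polynomial ring GF(7)[x0, x1] with graded reverse lexicographic order
R, x0, x1 = ring("x0,x1", GF(7), grevlex)


def random_poly(rng, max_num_terms=2, min_degree=1, max_degree=2):
    """Hypothetical sampler: a sum of up to max_num_terms random monomials."""
    p = R.zero
    for _ in range(rng.randint(1, max_num_terms)):
        d = rng.randint(min_degree, max_degree)  # total degree of this monomial
        e0 = rng.randint(0, d)                   # split the degree between x0 and x1
        coeff = rng.randint(1, 6)                # nonzero element of GF(7)
        p += coeff * x0**e0 * x1**(d - e0)       # equal monomials may combine or cancel
    return p


def toy_sum_problem(seed, num_polys=2):
    """Return (F, g) with g = sum(F), mirroring sum_problem_generator's output shape."""
    rng = random.Random(seed)  # local RNG: global random state is untouched
    F = [random_poly(rng) for _ in range(num_polys)]
    return F, sum(F, R.zero)


F, g = toy_sum_problem(seed=0)
print(F, g)
```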

In [4]:
save_dir = "."

# Initialize dataset generator
dataset_generator = DatasetGenerator(
    backend="multiprocessing",
    n_jobs=-1,  
    verbose=False,
    root_seed=100,
)
# Generate the train and test splits and save them to save_dir
dataset_generator.run(
    dataset_sizes={"train": 10000, "test": 1000},
    problem_generator=sum_problem_generator,
    save_dir=save_dir,
)

save_dir: .
Text output: True
JSON output: True

Starting dataset generation for 2 dataset(s)
Dataset sizes: {'train': 10000, 'test': 1000}

---------------------------------- train ----------------------------------
Dataset size: 10000 samples  (Batch size: 100000)

Overall statistics saved for train dataset
Total time: 0.71 seconds


---------------------------------- test ----------------------------------
Dataset size: 1000 samples  (Batch size: 100000)

Overall statistics saved for test dataset
Total time: 0.48 seconds


All datasets generated successfully!



## 3  – Model Configuration  
Here we instantiate the tokenizer, define the Transformer architecture, and prepare the training pipeline.  


In [5]:
# Point to any dataset you like; here we use the polynomial-sum dataset generated in Section 2 (written to save_dir).
TRAIN_PATH = "train_raw.txt"
TEST_PATH = "test_raw.txt"
dataset, tokenizer, data_collator = load_data(
    train_dataset_path=TRAIN_PATH,
    test_dataset_path=TEST_PATH,
    field="GF7",
    num_variables=2,
    max_degree=10,  # Should cover the range of generated samples
    max_coeff=7,   # Should cover the range of generated samples
    max_length=256,
)
train_dataset = dataset["train"]
test_dataset = dataset["test"]

Loaded 10000 samples from train_raw.txt
Loaded 1000 samples from test_raw.txt
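The comments on `max_degree` and `max_coeff` say these bounds must cover the generated samples. One plausible reason is that they fix the tokenizer's vocabulary. This notebook does not show calt's actual token format, so the sketch below uses a hypothetical `C<coeff> E<exp> ...` scheme with made-up special tokens, and treats both bounds as inclusive (also an assumption), purely to illustrate the idea: a term whose coefficient or exponent exceeds the bound has no token and cannot be encoded.

```python
def build_vocab(max_coeff: int, max_degree: int) -> list[str]:
    """Hypothetical vocabulary: special tokens plus one token per coefficient/exponent value."""
    special = ["<pad>", "<s>", "</s>", "+", "[SEP]"]
    coeffs = [f"C{c}" for c in range(max_coeff + 1)]  # C0 .. C{max_coeff}
    exps = [f"E{e}" for e in range(max_degree + 1)]   # E0 .. E{max_degree}
    return special + coeffs + exps


def encode_poly(terms: dict[tuple[int, ...], int], max_coeff: int, max_degree: int) -> list[str]:
    """Encode {exponent tuple: coefficient} as 'C<c> E<e0> E<e1> ...' terms joined by '+'."""
    tokens: list[str] = []
    for exps, c in terms.items():
        # Anything outside the bounds has no token in build_vocab's vocabulary
        if not 0 <= c <= max_coeff or any(e > max_degree for e in exps):
            raise ValueError(f"term {c}*{exps} is outside the vocabulary")
        if tokens:
            tokens.append("+")
        tokens.append(f"C{c}")
        tokens.extend(f"E{e}" for e in exps)
    return tokens


# 3*x0^2 + 5*x1 in two variables
print(encode_poly({(2, 0): 3, (0, 1): 5}, max_coeff=7, max_degree=10))
```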


In [6]:
# Minimal architecture.
model_cfg = BartConfig(
    d_model=256,       # 'width' of the model
    vocab_size=len(tokenizer.vocab),
    encoder_layers=2,  # 'depth' of encoder network
    decoder_layers=2,  # 'depth' of decoder network
    max_position_embeddings=256,  # max length of input/output
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    max_length=256,  # default maximum length for generation
)
model = Transformer(config=model_cfg)
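As a rough size check, the function below estimates the parameter count of this configuration. It assumes the standard BART layer layout (biased linear projections, a LayerNorm after each attention and feed-forward block, learned positional embeddings with 2 offset slots, token embeddings tied with the LM head) and `BartConfig`'s default FFN width of 4096, which the cell above leaves unset. The name `bart_param_estimate` is ours; treat the result as an estimate to compare against `model.num_parameters()`.

```python
def bart_param_estimate(vocab_size: int, d_model: int = 256, encoder_layers: int = 2,
                        decoder_layers: int = 2, ffn_dim: int = 4096,
                        max_position_embeddings: int = 256) -> int:
    """Approximate parameter count of a BART encoder-decoder (assumed layout)."""
    def linear(n_in: int, n_out: int) -> int:
        return n_in * n_out + n_out                   # weight + bias

    layer_norm = 2 * d_model                          # scale + shift
    attention = 4 * linear(d_model, d_model)          # q, k, v and output projections
    ffn = linear(d_model, ffn_dim) + linear(ffn_dim, d_model)
    enc_layer = attention + ffn + 2 * layer_norm
    dec_layer = 2 * attention + ffn + 3 * layer_norm  # self-attention + cross-attention
    token_emb = vocab_size * d_model                  # shared by encoder, decoder, LM head
    # Each stack: learned positions (+2 offset slots) and an embedding LayerNorm
    pos_and_norm = 2 * ((max_position_embeddings + 2) * d_model + layer_norm)
    return token_emb + pos_and_norm + encoder_layers * enc_layer + decoder_layers * dec_layer


print(bart_param_estimate(vocab_size=len(tokenizer.vocab)) if "tokenizer" in globals()
      else bart_param_estimate(vocab_size=100))
```

With `d_model=256` and the default 4096-wide FFN, the two feed-forward projections account for most of each layer's weights; setting `encoder_ffn_dim` and `decoder_ffn_dim` (e.g. to `4 * d_model`) would shrink the model considerably.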

## 4  – Training Hyper‑parameters  
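Batch size, number of epochs, logging frequency, and other trainer options are declared in this cell. Options not set here, such as the learning rate and its schedule, keep the Hugging Face `TrainingArguments` defaults.  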


In [7]:
args = TrainingArguments(
    output_dir="results/",
    num_train_epochs=20,
    logging_steps=50,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    save_strategy="no",  # skip checkpoints for the quick demo
    seed=SEED,
    remove_unused_columns=False,
    label_names=["labels"],
    report_to="none",
)
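The number of optimizer steps these arguments imply is simple arithmetic. The sketch below assumes a single device, the default `gradient_accumulation_steps=1`, and that the last partial batch is kept (the default `dataloader_drop_last=False`); `training_schedule` is a hypothetical helper, not part of calt or transformers.

```python
import math


def training_schedule(num_samples: int, batch_size: int, epochs: int,
                      logging_steps: int) -> tuple[int, int, int]:
    """Return (steps per epoch, total optimizer steps, number of loss log entries)."""
    steps_per_epoch = math.ceil(num_samples / batch_size)  # last partial batch still counts
    total_steps = steps_per_epoch * epochs
    num_logs = total_steps // logging_steps                # one loss entry every logging_steps
    return steps_per_epoch, total_steps, num_logs


print(training_schedule(num_samples=10_000, batch_size=128, epochs=20, logging_steps=50))
# (79, 1580, 31)
```

For this notebook that means 79 steps per epoch, 1,580 steps in total, and 31 logged loss values.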

## 5  – Model Training  
Launch the training loop. The training loss is printed every `logging_steps` (here 50) steps; because `report_to="none"`, no external tracker such as Weights & Biases is used.  


In [8]:
trainer = Trainer(
    args=args,
    model=model,
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# train
results = trainer.train()
trainer.save_model()
metrics = results.metrics

# eval
eval_metrics = trainer.evaluate()
metrics.update(eval_metrics)
success_rate = trainer.evaluate_and_save_generation(max_length=128)
metrics["success_rate"] = success_rate

# save metrics
trainer.save_metrics("all", metrics)

print(f'success rate on test set: {100*metrics["success_rate"]:.1f} %')



| Step | Training Loss |
|-----:|--------------:|
| 50   | 2.118         |
| 100  | 1.2691        |
| 150  | 1.0691        |
| 200  | 0.9844        |




success rate on test set: 1.2 %
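We assume the reported success rate is exact-match accuracy: the fraction of test samples whose generated output string equals the reference, which is also the criterion Section 6 uses (`gen == ref`) to separate success from failure cases. The helper `exact_match_rate` below is our own sketch of that metric, not calt's implementation.

```python
def exact_match_rate(generated: list[str], references: list[str]) -> float:
    """Fraction of samples whose generated string equals the reference exactly."""
    if len(generated) != len(references):
        raise ValueError("generated and references must have the same length")
    if not references:
        return 0.0
    hits = sum(gen == ref for gen, ref in zip(generated, references))
    return hits / len(references)


print(exact_match_rate(["a", "b", "c", "d"], ["a", "x", "c", "y"]))  # 0.5
```

Under this reading, 1.2 % on the 1,000-sample test set corresponds to 12 exactly correct predictions.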


## 6  – Visualizing Training Results  
Finally, we load the generated outputs and references for the test set and display a few correctly and incorrectly predicted samples, highlighting where each prediction differs from its reference.


In [14]:
gen_texts, ref_texts = load_eval_results("results/eval_results.json")

success_cases = [(i, gen, ref) for i, (gen, ref) in enumerate(zip(gen_texts, ref_texts)) if gen == ref]
failure_cases = [(i, gen, ref) for i, (gen, ref) in enumerate(zip(gen_texts, ref_texts)) if gen != ref]

num_show = 5

def show_cases(title, cases):
    """Print up to num_show samples, rendering each reference vs. prediction diff."""
    print('-------------------------')
    print(f' {title} ')
    print('-------------------------')
    for (i, gen, ref) in cases[:num_show]:
        # Map the stored strings back to their original (human-readable) form
        gen_expr = test_dataset.preprocessor.to_original(gen)
        ref_expr = test_dataset.preprocessor.to_original(ref)

        print(f"===== sample id: {i+1} =====")
        display_with_diff(ref_expr, gen_expr)


show_cases('Success cases', success_cases)
print()
show_cases('Failure cases', failure_cases)


-------------------------
 Success cases 
-------------------------
===== sample id: 148 =====
[rendered diff]
===== sample id: 172 =====
[rendered diff]
===== sample id: 217 =====
[rendered diff]
===== sample id: 242 =====
[rendered diff]
===== sample id: 272 =====
[rendered diff]

-------------------------
 Failure cases 
-------------------------
===== sample id: 1 =====
[rendered diff]
===== sample id: 2 =====
[rendered diff]
===== sample id: 3 =====
[rendered diff]
===== sample id: 4 =====
[rendered diff]
===== sample id: 5 =====
[rendered diff]
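`display_with_diff` renders the comparison as math, which only displays in a live notebook frontend. As a text-only alternative (our own sketch, not how calt implements it), `difflib.SequenceMatcher` can mark where a prediction departs from its reference at the token level: `[-...-]` marks reference tokens the model missed and `{+...+}` marks tokens it produced instead. The name `token_diff` is hypothetical; it could be applied to the `gen`/`ref` pairs in `failure_cases`.

```python
import difflib


def token_diff(ref: str, gen: str) -> str:
    """Inline token-level diff of a whitespace-tokenized reference and prediction."""
    ref_toks, gen_toks = ref.split(), gen.split()
    matcher = difflib.SequenceMatcher(a=ref_toks, b=gen_toks, autojunk=False)
    out: list[str] = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            out.extend(ref_toks[i1:i2])
            continue
        if i2 > i1:  # tokens present in the reference but not in the prediction
            out.append("[-" + " ".join(ref_toks[i1:i2]) + "-]")
        if j2 > j1:  # tokens the model generated instead
            out.append("{+" + " ".join(gen_toks[j1:j2]) + "+}")
    return " ".join(out)


print(token_diff("C3 E2 E0 + C5 E0 E1", "C3 E2 E0 + C4 E0 E1"))
# C3 E2 E0 + [-C5-] {+C4+} E0 E1
```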