In [None]:
!pip install calt-x

In [1]:
from typing import Any, List, Tuple, Dict, Union
import random
# Dataset Generation 
from sympy import GF, QQ, RR, ZZ
from sympy.polys.rings import ring, PolyRing, PolyElement
from transformers import BartConfig, BartForConditionalGeneration as Transformer
from transformers import TrainingArguments
from calt import (
    PolynomialSampler,
    DatasetGenerator,
    DatasetWriter,
    BaseStatisticsCalculator,
    PolynomialTrainer,
    data_loader,
)
from calt.data_loader.utils import (
    load_eval_results,
    parse_poly,
    display_with_diff
)
import torch, random, numpy as np

# Seed every RNG this notebook uses so "Restart Kernel -> Run All" is reproducible.
SEED = 42
random.seed(SEED)  # Python's global `random` module
np.random.seed(SEED)  # NumPy's legacy global RNG
# torch.manual_seed seeds PyTorch's RNG on all devices; the trailing ";" keeps the
# returned torch.Generator repr out of the cell output.
torch.manual_seed(SEED);

  from .autonotebook import tqdm as notebook_tqdm
  """Run *greedy* or *beam search* generation on the evaluation set.
  """Create dataset, tokenizer and data-collator objects.


<torch._C.Generator at 0x14f5ac0b7350>

In [3]:
class SumProblemGenerator:
    """
    Problem generator for polynomial sum problems.

    Each generated problem has as input a list of polynomials F = [f_1, f_2, ..., f_n]
    and as output the single polynomial g = f_1 + f_2 + ... + f_n (the total sum;
    the list of partial sums is NOT produced).
    """

    def __init__(
        self, sampler: PolynomialSampler, max_polynomials: int, min_polynomials: int
    ):
        """
        Initialize the polynomial sum problem generator.

        Args:
            sampler: Polynomial sampler; ``sampler.sample(num_samples=k)`` is called
                once per problem to draw the k input polynomials.
            max_polynomials: Maximum number of polynomials in F (inclusive).
            min_polynomials: Minimum number of polynomials in F (inclusive).

        Raises:
            ValueError: If ``min_polynomials < 1`` or ``min_polynomials > max_polynomials``.
        """
        # Fail fast on a bad range: with min > max, random.randint would only fail
        # later inside dataset generation, and an empty F would make the "sum" the
        # plain integer 0 instead of a polynomial.
        if min_polynomials < 1:
            raise ValueError(
                f"min_polynomials must be at least 1, got {min_polynomials}"
            )
        if min_polynomials > max_polynomials:
            raise ValueError(
                "min_polynomials must not exceed max_polynomials, got "
                f"min_polynomials={min_polynomials}, max_polynomials={max_polynomials}"
            )

        self.sampler = sampler
        self.max_polynomials = max_polynomials
        self.min_polynomials = min_polynomials

    def __call__(self, seed: int) -> Tuple[List[PolyElement], PolyElement]:
        """
        Generate a single sample.

        Args:
            seed: Seed used to re-seed Python's global ``random`` module before
                sampling. NOTE(review): reproducibility from ``seed`` alone presumes
                the sampler draws from that same global RNG — confirm.

        Returns:
            Tuple ``(F, g)``: F is the list of sampled input polynomials and g is
            their sum (a single polynomial).
        """
        random.seed(seed)

        # Number of input polynomials for this sample, drawn inclusively.
        num_polys = random.randint(self.min_polynomials, self.max_polynomials)

        F = self.sampler.sample(num_samples=num_polys)

        # sum() starts from the integer 0, exactly like an explicit accumulation loop
        # initialised with 0.
        total = sum(F)

        return F, total
    



In [4]:
class PolyStatisticsCalculator(BaseStatisticsCalculator):
    """
    Statistics calculator for polynomial problems.

    For each sample it reports size, degree, term-count, coefficient and density
    statistics, separately for the problem input and the problem output.
    """

    def __init__(self, ring: PolyRing):
        """
        Initialize polynomial statistics calculator.

        Args:
            ring: Polynomial ring; its number of generators and its coefficient
                domain determine how the statistics are computed.
        """
        self.ring = ring
        self.num_vars = ring.ngens
        self.coeff_field = ring.domain

    def __call__(
        self,
        problem_input: Union[List[PolyElement], PolyElement],
        problem_output: Union[List[PolyElement], PolyElement],
    ) -> Dict[str, Any]:
        """
        Calculate statistics for a single generated sample.

        Args:
            problem_input: Input problem (a list of polynomials or a single polynomial)
            problem_output: Output solution (a list of polynomials or a single polynomial)

        Returns:
            Dictionary ``{"input": ..., "output": ...}`` whose values are the
            per-system statistics produced by ``poly_system_stats``.
        """
        return {
            "input": self.poly_system_stats(self._as_list(problem_input)),
            "output": self.poly_system_stats(self._as_list(problem_output)),
        }

    @staticmethod
    def _as_list(
        problem: Union[List[PolyElement], PolyElement],
    ) -> List[PolyElement]:
        """Wrap a single polynomial in a list; pass lists through unchanged."""
        return problem if isinstance(problem, list) else [problem]

    def poly_system_stats(self, polys: List[PolyElement]) -> Dict[str, Any]:
        """
        Calculate statistics for a list of polynomials.

        Args:
            polys: List of polynomials

        Returns:
            Dictionary with the keys num_polynomials, total_degree, total_terms,
            max_degree, min_degree, max_terms, min_terms, max_coeff, min_coeff and
            density. The same keys are returned for an empty list (all zero), so
            every sample's statistics share one schema.
        """
        num_polys = len(polys)

        if num_polys == 0:
            return {
                "num_polynomials": 0,
                "total_degree": 0,
                "total_terms": 0,
                "max_degree": 0,
                "min_degree": 0,
                "max_terms": 0,
                "min_terms": 0,
                "max_coeff": 0,
                "min_coeff": 0,
                "density": 0.0,
            }

        degrees = [self.total_degree(p) for p in polys]
        num_terms = [len(p.terms()) for p in polys]
        coeffs = self._collect_coeffs(polys)
        max_degree = max(degrees)

        return {
            # System size statistics
            "num_polynomials": num_polys,
            "total_degree": sum(degrees),
            "total_terms": sum(num_terms),
            # Degree statistics
            "max_degree": max_degree,
            "min_degree": min(degrees),
            # Term count statistics
            "max_terms": max(num_terms),
            "min_terms": min(num_terms),
            # Coefficient statistics (0 when no coefficients were collected,
            # e.g. for a coefficient domain not handled in _collect_coeffs)
            "max_coeff": max(coeffs) if coeffs else 0,
            "min_coeff": min(coeffs) if coeffs else 0,
            # Total term count relative to (1 + max_degree) ** num_vars
            # monomial slots per polynomial
            "density": float(sum(num_terms))
            / (num_polys * (1 + max_degree) ** self.num_vars),
        }

    def _collect_coeffs(self, polys: List[PolyElement]) -> List[Union[int, float]]:
        """
        Collect coefficient values used for the max/min coefficient statistics.

        - QQ: absolute numerators and absolute denominators, as floats.
        - RR: absolute values, as floats.
        - ZZ: absolute values, as ints.
        - Finite fields: ``int(c)`` of each coefficient.
        - Any other domain contributes nothing.
        """
        coeffs: List[Union[int, float]] = []
        for p in polys:
            if self.coeff_field == QQ:
                # For QQ, consider both numerators and denominators.
                p_coeffs = p.coeffs()
                coeffs.extend(abs(float(c.numerator)) for c in p_coeffs)
                coeffs.extend(abs(float(c.denominator)) for c in p_coeffs)
            elif self.coeff_field == RR:
                coeffs.extend(abs(float(c)) for c in p.coeffs())
            elif self.coeff_field == ZZ:
                coeffs.extend(abs(int(c)) for c in p.coeffs())
            elif self.coeff_field.is_FiniteField:  # GF
                coeffs.extend(int(c) for c in p.coeffs())
        return coeffs

    def total_degree(self, poly: PolyElement) -> int:
        """Return the total degree of ``poly`` (0 for the zero polynomial)."""
        if poly.is_zero:
            return 0
        return max(sum(monom) for monom in poly.monoms())


In [5]:
# Output directory for the generated train/test files.
# NOTE(review): presumably these are the files data_loader reads below — confirm filenames.
save_dir = "."

# Dataset sizes.
NUM_TRAIN_SAMPLES = 10000
NUM_TEST_SAMPLES = 1000

# Polynomial ring: variables x0, x1 over GF(7), graded reverse lexicographic order.
R, *gens = ring("x0,x1", GF(7), order="grevlex")
# Sampler for the input polynomials (term-count and degree bounds below).
sampler = PolynomialSampler(
    ring=R,
    max_num_terms=2,
    max_degree=2,
    min_degree=1,
    degree_sampling="uniform",  # "uniform" or "fixed"
    term_sampling="uniform",  # "uniform" or "fixed"
    max_coeff=None,  # Used for RR and ZZ
    num_bound=None,  # Used for QQ
    strictly_conditioned=False,
    nonzero_instance=True,
)
# Each problem: exactly 2 input polynomials -> their sum.
problem_generator = SumProblemGenerator(
    sampler=sampler,
    max_polynomials=2,
    min_polynomials=2,
)
# Per-sample statistics over the same ring.
statistics_calculator = PolyStatisticsCalculator(ring=R)
dataset_generator = DatasetGenerator(
    ring=R,
    backend="multiprocessing",
    # NOTE(review): kept serial; parallel workers may be unable to pickle classes
    # defined in this notebook — confirm before raising n_jobs.
    n_jobs=1,
    verbose=True,
    root_seed=100,  # generator's own seed, independent of SEED above
)
train_samples, train_stats = dataset_generator.run(
    train=True,
    num_samples=NUM_TRAIN_SAMPLES,
    problem_generator=problem_generator,
    statistics_calculator=statistics_calculator,
)
test_samples, test_stats = dataset_generator.run(
    train=False,
    num_samples=NUM_TEST_SAMPLES,
    problem_generator=problem_generator,
    statistics_calculator=statistics_calculator,
)
# Write both splits (samples + statistics) under save_dir.
dataset_writer = DatasetWriter(save_dir)
dataset_writer.save_dataset(train_samples, train_stats, "train")
dataset_writer.save_dataset(test_samples, test_stats, "test")

[Parallel(n_jobs=1)]: Done  49 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 199 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 449 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 799 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 1249 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 1799 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done 2449 tasks      | elapsed:    0.4s
[Parallel(n_jobs=1)]: Done 3199 tasks      | elapsed:    0.4s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 4999 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 6049 tasks      | elapsed:    0.6s
[Parallel(n_jobs=1)]: Done 7199 tasks      | elapsed:    0.7s
[Parallel(n_jobs=1)]: Done 8449 tasks      | elapsed:    0.7s
[Parallel(n_jobs=1)]: Done 9799 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 10000 out of 10000 | elapsed:    0.8s finished
[Parallel(n_jobs=1)]: Done  49 tasks      | elapsed:    0.0s
[

In [6]:
# Load the polynomial-sum dataset written by the generation cell above and build the
# tokenizer and data collator. field/num_variables must agree with that cell's ring:
# coefficients in GF(7) ("GF7") and 2 variables (x0, x1).
TRAIN_PATH = "train_raw.txt"
TEST_PATH = "test_raw.txt"
dataset, tokenizer, data_collator = data_loader(
    train_dataset_path=TRAIN_PATH,
    test_dataset_path=TEST_PATH,
    field="GF7",
    num_variables=2,
    # NOTE(review): presumably upper bounds for the tokenizer vocabulary; sampled
    # polynomials have degree <= 2, so these leave headroom — confirm.
    max_degree=10,
    max_coeff=10,
    max_length=256,
)

In [33]:
# Minimal architecture — only overriding d_model for speed.
model_cfg = BartConfig(
    d_model=256,
    vocab_size=len(tokenizer.vocab),
    encoder_layers=2,
    decoder_layers=2,
    max_position_embeddings=256,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    max_length=256,
)
model = Transformer(config=model_cfg)

In [26]:
args = TrainingArguments(
    output_dir="results/demo",
    num_train_epochs=20,
    logging_steps=50,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    save_strategy="no",  # skip checkpoints for the quick demo
    seed=SEED,
    # Keep all dataset columns instead of dropping those the model's forward()
    # does not accept. NOTE(review): presumably the custom data_collator needs them — confirm.
    remove_unused_columns=False,
    label_names=["labels"],
    report_to="none",  # disable experiment-tracking integrations
)

In [27]:
trainer = PolynomialTrainer(
    args=args,
    model=model,
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=dataset["train"],  # full training split
    eval_dataset=dataset["test"],
)

# Train, then save the final model (to args.output_dir by default).
results = trainer.train()
trainer.save_model()

# Evaluate: loss-based metrics, then generation-based accuracy.
eval_metrics = trainer.evaluate()
# NOTE(review): generation is capped at 128 while sequences are tokenized up to 256;
# a longer target would be cut off and counted as a failure — confirm 128 suffices.
acc = trainer.generate_evaluation(max_length=128)

# Build a fresh dict rather than mutating results.metrics in place.
metrics = {**results.metrics, **eval_metrics, "test_accuracy": acc}

trainer.save_metrics("all", metrics)



Step,Training Loss
50,0.958
100,0.655
150,0.5385
200,0.4526
250,0.3844
300,0.3367
350,0.2964
400,0.248
450,0.1798
500,0.1267




In [4]:
# Compare generated vs. reference outputs from the evaluation run and render every
# mismatch, labelled with its 1-based position in the test set.
gen_texts, ref_texts = load_eval_results("results/demo/eval_results.json")
# zip() would silently drop trailing samples if the two lists differed in length.
assert len(gen_texts) == len(ref_texts), "generated/reference counts differ"

num_failures = 0
for idx, (gen, ref) in enumerate(zip(gen_texts, ref_texts), start=1):
    # NOTE(review): "x", "y" presumably only name the two ring variables (x0, x1)
    # for display — confirm against parse_poly.
    gen_expr = parse_poly(gen, ["x", "y"])
    ref_expr = parse_poly(ref, ["x", "y"])
    if gen_expr != ref_expr:
        num_failures += 1
        print(f"===== Failure case: {idx} =====")
        display_with_diff(ref_expr, gen_expr)

print(f"{num_failures} / {len(ref_texts)} test samples mismatched")

===== Failure case: 160 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 178 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 181 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 195 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 205 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 235 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 295 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 299 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 358 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 530 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 580 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 614 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 639 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 694 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 733 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 772 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 884 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>

===== Failure case: 939 =====


<IPython.core.display.Math object>

<IPython.core.display.Math object>