# Polynomial GCD Dataset — Minimal Example

This notebook shows how **easy** it is to plug a custom *problem generator* into the
`transformer_algebra` data pipeline.  
Instead of the built‑in `SumProblemGenerator`, we define our own `GCDProblemGenerator`
directly in the notebook, import the rest of the library, and instantly obtain a toy dataset.

## 1. Imports

In [None]:
!pip install calt-x

In [1]:
from typing import Any, List, Tuple, Dict, Union
import random
from sympy import GF, QQ, RR, ZZ
from sympy.polys.rings import ring, PolyRing, PolyElement
from calt import (
    PolynomialSampler,
    DatasetGenerator,
    DatasetWriter,
    BaseStatisticsCalculator,
)

  """Run *greedy* or *beam search* generation on the evaluation set.
  from .autonotebook import tqdm as notebook_tqdm


## 2. Define a Polynomial Ring

In [19]:
ring = PolynomialRing(GF(7), 2, "x", order="degrevlex")  # GF(7) with variables x0, x1
ring

In [20]:
# GF(7) with 2 variables x0, x1
ring = PolynomialRing(GF(7), 2, "x", order="degrevlex")
ring

Multivariate Polynomial Ring in x0, x1 over Finite Field of size 7

## 3. Build a Polynomial Sampler

In [21]:
sampler = PolynomialSampler(
    ring=ring,
    max_num_terms=6,
    max_degree=4,
    min_degree=1,
    degree_sampling="uniform",
    term_sampling="uniform",
    nonzero_instance=True,
)

## 4. Write a **custom** `GCDProblemGenerator`

In [2]:
class PartialSumProblemGenerator:
    """
    Problem generator for partial sum problems involving polynomials.

    This generator generates problems in which the input is a list of polynomials F = [f_1, f_2, ..., f_n],
    and the output is a list of polynomials G = [g_1, g_2, ..., g_n], where g_i = f_1 + f_2 + ... + f_i.
    """

    def __init__(
        self, sampler: PolynomialSampler, max_polynomials: int, min_polynomials: int
    ):
        """
        Initialize polynomial partial sum sampler.

        Args:
            sampler: Polynomial sampler
            max_polynomials: Maximum number of polynomials in F
            min_polynomials: Minimum number of polynomials in F
        """

        self.sampler = sampler
        self.max_polynomials = max_polynomials
        self.min_polynomials = min_polynomials

    def __call__(self, seed: int) -> Tuple[List[PolyElement], List[PolyElement]]:
        """
        Generate a single sample.

        Each sample consists of:
        - Input polynomial system F
        - Output polynomial system G (partial sums of F)

        Args:
            seed: Seed for random number generator

        Returns:
            Tuple containing (F, G)
        """

        # Set random seed
        random.seed(seed)

        # Choose number of polynomials for this sample
        num_polys = random.randint(self.min_polynomials, self.max_polynomials)

        # Generate input polynomials using sampler
        F = self.sampler.sample(num_samples=num_polys)

        # Generate partial sums for output
        G = []
        current_sum = 0
        for f in F:
            current_sum += f
            G.append(current_sum)

        return F, G

In [3]:
class PolyStatisticsCalculator(BaseStatisticsCalculator):
    """
    Statistics calculator for polynomial problems.
    """

    def __init__(self, ring: PolyRing):
        """
        Initialize polynomial statistics calculator.

        Args:
            ring: Polynomial ring
        """
        self.ring = ring
        self.num_vars = ring.ngens
        self.coeff_field = ring.domain

    def __call__(
        self,
        problem_input: Union[List[PolyElement], PolyElement],
        problem_output: Union[List[PolyElement], PolyElement],
    ) -> Dict[str, Any]:
        """
        Calculate statistics for a single generated sample.

        Args:
            problem_input: Input problem (a list of polynomials or a single polynomial)
            problem_output: Output solution (a list of polynomials or a single polynomial)

        Returns:
            Dictionary containing statistics about the sample
        """

        if isinstance(problem_input, list):
            input_stats = self.poly_system_stats(problem_input)
        else:
            input_stats = self.poly_system_stats([problem_input])
        if isinstance(problem_output, list):
            output_stats = self.poly_system_stats(problem_output)
        else:
            output_stats = self.poly_system_stats([problem_output])

        return {
            "input": input_stats,
            "output": output_stats,
        }

    def poly_system_stats(self, polys: List[PolyElement]) -> Dict[str, Any]:
        """
        Calculate statistics for a list of polynomials.

        Args:
            polys: List of polynomials

        Returns:
            Dictionary containing statistical information about the polynomials
        """
        num_polys = len(polys)

        if num_polys == 0:
            return {"num_polynomials": 0, "total_degree": 0, "total_terms": 0}

        degrees = [self.total_degree(p) for p in polys]
        num_terms = [len(p.terms()) for p in polys]

        coeffs = []
        for p in polys:
            if self.coeff_field == QQ:
                # For QQ, consider both numerators(分子) and denominators(分母)
                coeffs.extend([abs(float(c.numerator)) for c in p.coeffs()])
                coeffs.extend([abs(float(c.denominator)) for c in p.coeffs()])
            elif self.coeff_field == RR:
                # For RR, take absolute values
                coeffs.extend([abs(float(c)) for c in p.coeffs()])
            elif self.coeff_field == ZZ:
                # For ZZ, take absolute values
                coeffs.extend([abs(int(c)) for c in p.coeffs()])
            elif self.coeff_field.is_FiniteField:  # GF
                # For finite fields, just take the values
                coeffs.extend([int(c) for c in p.coeffs()])

        stats = {
            # System size statistics
            "num_polynomials": num_polys,
            "total_degree": sum(degrees),
            "total_terms": sum(num_terms),
            # Degree statistics
            "max_degree": max(degrees),
            "min_degree": min(degrees),
            # Term count statistics
            "max_terms": max(num_terms),
            "min_terms": min(num_terms),
            # Coefficient statistics
            "max_coeff": max(coeffs) if coeffs else 0,
            "min_coeff": min(coeffs) if coeffs else 0,
            # Additional system properties
            "density": float(sum(num_terms))
            / (num_polys * (1 + max(degrees)) ** self.num_vars),
        }

        return stats

    def total_degree(self, poly: PolyElement) -> int:
        """Compute total degree of a polynomial"""
        if poly.is_zero:
            return 0
        else:
            return max(list(sum(monom) for monom in poly.monoms()))


**Key idea:** the generator is just a *callable* that returns `(inputs, target)`.  
If it follows that contract, `DatasetGenerator` can parallel‑generate samples automatically.

## 5. Create Data & Inspect a Sample

In [23]:
problem_generator = GCDProblemGenerator(sampler)
dataset_generator = DatasetGenerator(ring=ring, n_jobs=1, verbose=False, root_seed=2025)

# Single sample
F, g = problem_generator(seed=123)
print("Inputs F:", F)
print("Target g:", g)

Input F (polynomials): [-2*x0*x1^2 + 2*x0*x1, x0*x1^2 - x1^3 - x0*x1 + x1^2]
Output G (partial sums): x1^2 - x1


## 6. Generate a Tiny Dataset

In [24]:
samples, stats = dataset_generator.run(
    num_samples=20, train=True, problem_generator=problem_generator
)
stats

{'total_time': 0.009609460830688477,
 'samples_per_second': 2081.2822230492493,
 'num_samples': 20,
 'generation_time': {'mean': 0.00045791864395141604,
  'std': 0.00012987409280385738,
  'min': 0.00032973289489746094,
  'max': 0.0008177757263183594},
 'input_polynomials_overall': {'num_polynomials': {'mean': 2.0,
   'std': 0.0,
   'min': 2.0,
   'max': 2.0},
  'total_degree': {'mean': 7.2,
   'std': 2.4617067250182343,
   'min': 4.0,
   'max': 12.0},
  'total_terms': {'mean': 10.0,
   'std': 7.063993204979744,
   'min': 2.0,
   'max': 27.0},
  'max_degree': {'mean': 3.95,
   'std': 1.2835497652993437,
   'min': 2.0,
   'max': 6.0},
  'min_degree': {'mean': 3.25,
   'std': 1.2599603168354152,
   'min': 2.0,
   'max': 6.0},
  'max_terms': {'mean': 6.1,
   'std': 3.9736632972611057,
   'min': 1.0,
   'max': 14.0},
  'min_terms': {'mean': 3.9,
   'std': 3.3600595232822887,
   'min': 1.0,
   'max': 13.0},
  'max_coeff': {'mean': 5.55,
   'std': 0.5894913061275798,
   'min': 4.0,
   'max': 

In [25]:
# Show first three examples
for i, (F_i, g_i) in enumerate(samples[:3]):
    print(f"--- Sample {i} ---")
    print("F:", F_i)
    print("g:", g_i)

--- Sample 0 ---
F: [-x0*x1^2 + 2*x0*x1, -2*x0*x1^2]
G: x0*x1
--- Sample 1 ---
F: [-3*x0^3 + 3*x0^2 + 3*x0*x1 - 3*x1, 3*x0*x1 - 2*x0 - 3*x1 + 2]
G: x0 - 1
--- Sample 2 ---
F: [-3*x0*x1^3 + x0^2*x1 - 3*x0*x1^2 - x1^3 - x0*x1 - x1^2 - 2*x0 - 2*x1 - 3, -x1^5 - 2*x0*x1^3 - x1^4 - 2*x1^3 - 3*x1^2]
G: x1^3 + 2*x0*x1 + x1^2 + 2*x1 + 3


Change the ring, sampler hyper‑parameters, or swap in a different generator class,
and you immediately get a new task‑specific dataset — **no other code changes needed**.

In [5]:
save_dir = "dataset/partial_sum_problem/GF7_n=2"

# set up polynomial ring
R, *gens = ring("x0,x1", GF(7), order="grevlex")
# Initialize polynomial sampler
sampler = PolynomialSampler(
    ring=R,
    max_num_terms=5,
    max_degree=10,
    min_degree=1,
    degree_sampling="uniform",  # "uniform" or "fixed"
    term_sampling="uniform",  # "uniform" or "fixed"
    max_coeff=None,  # Used for RR and ZZ
    num_bound=None,  # Used for QQ
    strictly_conditioned=False,
    nonzero_instance=True,
)
# Initialize problem generator
problem_generator = PartialSumProblemGenerator(
    sampler=sampler,
    max_polynomials=5,
    min_polynomials=2,
)
# Initialize statistics calculator
statistics_calculator = PolyStatisticsCalculator(ring=R)
# Initialize dataset generator
dataset_generator = DatasetGenerator(
    ring=R,
    backend="multiprocessing",
    n_jobs=1,  # warning 
    verbose=True,
    root_seed=100,
)
# Generate training set
train_samples, train_stats = dataset_generator.run(
    train=True,
    num_samples=100000,
    problem_generator=problem_generator,
    statistics_calculator=statistics_calculator,
)
# Generate test set
test_samples, test_stats = dataset_generator.run(
    train=False,
    num_samples=1000,
    problem_generator=problem_generator,
    statistics_calculator=statistics_calculator,
)
# Initialize writer
dataset_writer = DatasetWriter(save_dir)
# Save datasets
dataset_writer.save_dataset(train_samples, train_stats, "train")
dataset_writer.save_dataset(test_samples, test_stats, "test")

[Parallel(n_jobs=1)]: Done  49 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 199 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 449 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 799 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 1249 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done 1799 tasks      | elapsed:    0.3s
[Parallel(n_jobs=1)]: Done 2449 tasks      | elapsed:    0.4s
[Parallel(n_jobs=1)]: Done 3199 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    0.7s
[Parallel(n_jobs=1)]: Done 4999 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 6049 tasks      | elapsed:    1.0s
[Parallel(n_jobs=1)]: Done 7199 tasks      | elapsed:    1.2s
[Parallel(n_jobs=1)]: Done 8449 tasks      | elapsed:    2.8s
[Parallel(n_jobs=1)]: Done 9799 tasks      | elapsed:    3.1s
[Parallel(n_jobs=1)]: Done 11249 tasks      | elapsed:    3.3s
[Parallel(n_jobs=1)]: Done 12799 tasks      | elapsed:    3.5s
[Parallel(