In [2]:
import torch
import numpy as np
import numbers



# Default operator alphabet: every character is one binary operator that can
# be picked (by index) as the label of an expression-tree node.
OPS = "+-*" # avoiding division by default so that Z_n ring works out of the box

def merge_two_items(list_of_expressions, rng, operators=OPS):
    """Shuffle the expressions and fuse two of them into one tree node.

    The input list is left untouched. A random permutation is drawn from
    ``rng`` first, then a random operator character; the last two entries of
    the shuffled list become the node ``(last, second_to_last, operator)``,
    which is placed at the end of the returned list.

    Args:
        list_of_expressions: Sequence of leaves and/or ``(a, b, op)`` nodes.
        rng: Random source providing ``permutation`` and ``randint``.
        operators: Indexable collection of operator symbols to choose from.

    Returns:
        A new list that is one element shorter than the input.
    """
    order = rng.permutation(len(list_of_expressions))
    shuffled = [list_of_expressions[position] for position in order]
    chosen_op = operators[rng.randint(len(operators))]
    first, second = shuffled[-1], shuffled[-2]
    return shuffled[:-2] + [(first, second, chosen_op)]

def create_random_tree_from_list(list_of_items, rng, operators=OPS):
    """Combine all items into a single random binary expression tree.

    Two randomly chosen entries are merged into an ``(a, b, op)`` node over
    and over until only one entry is left; that entry is the tree.

    Args:
        list_of_items: Leaves of the tree (must contain at least one item).
        rng: Random source handed through to ``merge_two_items``.
        operators: Operator symbols handed through to ``merge_two_items``.

    Returns:
        The single remaining entry (a leaf if only one item was given).
    """
    remaining = list_of_items
    while True:
        if len(remaining) <= 1:
            return remaining[0]
        remaining = merge_two_items(remaining, rng, operators)

def render_expression_from_tree(tree,
                                render_expr=lambda a, b, op: f"( {a} {op} {b} )",
                                render_leaf=lambda x: str(x)):
    """Recursively turn an expression tree into a rendered value.

    Any 3-tuple is treated as an inner node ``(left, right, operator)``;
    everything else is treated as a leaf.

    Args:
        tree: A leaf value or a nested ``(left, right, operator)`` tuple.
        render_expr: Callable combining the two rendered children and the
            operator into the rendering of the node.
        render_leaf: Callable producing the rendering of a leaf.

    Returns:
        Whatever ``render_expr`` / ``render_leaf`` produce (a string with the
        default callables).
    """
    is_node = isinstance(tree, tuple) and len(tree) == 3
    if not is_node:
        return render_leaf(tree)

    left, right, operator = tree
    # Children are rendered left first, then right.
    rendered_children = [
        render_expression_from_tree(child, render_expr=render_expr, render_leaf=render_leaf)
        for child in (left, right)
    ]
    return render_expr(rendered_children[0], rendered_children[1], operator)


class ExpressionTreeDataset(torch.utils.data.Dataset):
    """Map-style dataset whose samples are lists of random expression trees.

    The leaves of every tree are the operand indices ``0 .. n_operands - 1``.
    One seed per sample is drawn up front from ``random_state``, so a given
    index always yields the same trees.
    """

    def __init__(self, n_operands=2, n_samples=1000000, operators="+-*",
                 n_expressions_per_sample=1, random_state=0):
        """Store the configuration and pre-draw one seed per sample.

        Args:
            n_operands: Number of leaves in every tree.
            n_samples: Number of samples (length of the dataset).
            operators: Operator symbols available for inner nodes.
            n_expressions_per_sample: Number of trees returned per index.
            random_state: Seed of the generator that draws the sample seeds.
        """
        self.n_operands = n_operands
        self.n_samples = n_samples
        self.operators = operators
        self.n_expressions_per_sample = n_expressions_per_sample
        self.random_state = random_state

        seed_upper_bound = 2**32 - 1
        self.seeds = np.random.RandomState(random_state).randint(
            0, seed_upper_bound, self.n_samples)

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return the list of trees belonging to sample ``idx``."""
        sample_rng = np.random.RandomState(self.seeds[idx])
        leaf_indices = list(range(self.n_operands))
        # All trees of one sample consume the same generator in sequence.
        return [
            create_random_tree_from_list(leaf_indices, sample_rng, self.operators)
            for _ in range(self.n_expressions_per_sample)
        ]


class ModuloArithmeticDataset(torch.utils.data.Dataset):
    """Dataset of random arithmetic equations evaluated modulo ``Zn``.

    Every sample consists of ``n_items_per_sample`` equations. The shape of
    each equation comes from an ``ExpressionTreeDataset``; the operands are
    integers drawn from ``[0, Zn)``. Samples are deterministic per index.
    """

    def __init__(self, Zn=7,
                 n_operands=2, n_samples=1000000,
                 n_items_per_sample=1, operators="+-*",
                 separator=";", random_state=0,
                 transform=None, output_length=None,
                 return_pre_token_lists=True,
                 leading_zeros=True,
                 return_pre_tokens=False):
        """Store the configuration and build the underlying tree dataset.

        Args:
            Zn: Modulus; operands are drawn from ``[0, Zn)`` and results are
                reduced modulo ``Zn``.
            n_operands: Number of operands per equation.
            n_samples: Number of samples in the dataset.
            n_items_per_sample: Number of equations per sample.
            operators: Operator characters used for the tree nodes.
            separator: String placed between the equations of one sample.
            random_state: Seed forwarded to the tree dataset.
            transform: Stored only; not read anywhere in this class.
            output_length: Stored only; not read anywhere in this class.
            return_pre_token_lists: If true, ``__getitem__`` also returns the
                token lists next to the joined string.
            leading_zeros: Pad numbers with zeros (true) or spaces (false) up
                to the width needed for ``Zn``.
            return_pre_tokens: Stored only; not read anywhere in this class.
        """
        self.Zn = Zn
        self.n_operands = n_operands
        self.n_samples = n_samples
        self.n_items_per_sample = n_items_per_sample
        self.operators = operators
        self.separator = separator
        self.transform = transform
        self.output_length = output_length
        self.return_pre_token_lists = return_pre_token_lists
        self.leading_zeros = leading_zeros
        self.return_pre_tokens = return_pre_tokens

        self.expression_tree_dataset = ExpressionTreeDataset(n_operands=n_operands,
                                                             n_samples=n_samples,
                                                             operators=operators,
                                                             n_expressions_per_sample=n_items_per_sample,
                                                             random_state=random_state)

    def __len__(self):
        """Return the number of samples."""
        return len(self.expression_tree_dataset)

    def __getitem__(self, idx):
        """Build the equations of sample ``idx``.

        Returns:
            ``(joined_string, token_list_equations)`` when
            ``return_pre_token_lists`` is true, otherwise only
            ``joined_string``. ``joined_string`` is the equations joined by
            ``separator``; ``token_list_equations`` holds one flat list per
            equation made of parentheses, operand values, operator
            characters, ``"="`` and the result.
        """
        trees = self.expression_tree_dataset[idx]
        # The operand generator is seeded with the same per-sample seed that
        # the tree dataset uses, which keeps every index reproducible.
        rng = np.random.RandomState(self.expression_tree_dataset.seeds[idx])

        # One row of operands per equation.
        numberss = rng.randint(0, self.Zn, (len(trees), self.n_operands))

        ndigits = int(np.ceil(np.log10(self.Zn)))
        leading = "0" if self.leading_zeros else ""
        format_string = f"{{:{leading}{ndigits}d}}"
        node_rendering = lambda a, b, op: f"({a}{op}{b})"
        node_rendering_for_token_list = lambda a, b, op: ['('] + a + [op] + b + [')']

        rendered_expressions = []
        rendered_expressions_for_token_list = []
        evaled_expressions = []
        for tree, numbers in zip(trees, numberss):
            # The lambdas close over `numbers`; they are only used inside
            # this iteration, so late binding is not an issue.
            leaf_rendering = lambda i: format_string.format(numbers[i])
            leaf_rendering_for_token_list = lambda i: [numbers[i]]

            rendered_expression = render_expression_from_tree(tree,
                                    render_expr=node_rendering,
                                    render_leaf=leaf_rendering)
            rendered_expressions.append(rendered_expression)

            rendered_expression_for_token_list = render_expression_from_tree(tree,
                                    render_expr=node_rendering_for_token_list,
                                    render_leaf=leaf_rendering_for_token_list)
            rendered_expressions_for_token_list.append(rendered_expression_for_token_list)

            # Separate rendering without padding: a zero-padded literal such
            # as "07" is not valid Python and could not be evaluated.
            rendered_for_eval = render_expression_from_tree(tree,
                                                            render_expr=node_rendering,
                                                            render_leaf=lambda i: f"{numbers[i]:d}")
            # NOTE: eval() only sees text assembled here from integer operands
            # and the configured operator characters; `operators` must never
            # come from an untrusted source.
            evaled_expressions.append(eval(rendered_for_eval) % self.Zn)

        # The result is padded exactly like the operands, so that
        # `leading_zeros=False` applies to the whole equation.
        eqn_string = f"{{}}={{:{leading}{ndigits}d}}"
        equations = [eqn_string.format(r, e) for r, e in zip(rendered_expressions, evaled_expressions)]

        token_list_equations = [r + ["=", e]
                            for r, e in zip(rendered_expressions_for_token_list, evaled_expressions)]

        joined_string = self.separator.join(equations)

        if self.return_pre_token_lists:
            return joined_string, token_list_equations
        return joined_string



class FloatArithmeticDataset(torch.utils.data.Dataset):
    """Dataset of random arithmetic equations over floating point operands.

    Every sample consists of ``n_items_per_sample`` equations. The shape of
    each equation comes from an ``ExpressionTreeDataset``; the operands are
    drawn from a standard normal or a centred uniform distribution. Samples
    are deterministic per index.
    """

    def __init__(self, distribution='standard_normal',
                  uniform_distribution_range=2,
                    n_decimals=None,
                 n_operands=2, n_samples=1000000,
                 n_items_per_sample=1, operators="+-*",
                 separator="", random_state=0,
                transform=None, output_length=None,
                return_pre_token_lists=True,
                protect_division_by_zero=True):
        """Store the configuration and build the underlying tree dataset.

        Args:
            distribution: ``'standard_normal'`` or ``'uniform'``.
            uniform_distribution_range: Width ``w`` of the uniform
                distribution; operands lie in ``[-w/2, w/2)``.
            n_decimals: Number of decimals used when rendering numbers; with
                ``None`` the default ``f`` formatting is used.
            n_operands: Number of operands per equation.
            n_samples: Number of samples in the dataset.
            n_items_per_sample: Number of equations per sample.
            operators: Operator characters used for the tree nodes.
            separator: String placed between equations in the string output.
            random_state: Seed forwarded to the tree dataset.
            transform: Optional callable applied to the token lists before
                they are returned (token-list output only).
            output_length: Stored only; not read anywhere in this class.
            return_pre_token_lists: Return token lists (true) or the joined
                equation string (false).
            protect_division_by_zero: On a division by zero, warn and return
                the next sample instead of raising.
        """
        self.distribution = distribution
        self.uniform_distribution_range = uniform_distribution_range
        self.n_decimals = n_decimals
        self.n_operands = n_operands
        self.n_samples = n_samples
        self.n_items_per_sample = n_items_per_sample
        self.operators = operators
        self.separator = separator
        self.transform = transform
        self.output_length = output_length
        self.return_pre_token_lists = return_pre_token_lists
        self.protect_division_by_zero = protect_division_by_zero

        self.random_state = random_state

        self.expression_tree_dataset = ExpressionTreeDataset(n_operands=n_operands,
                                                             n_samples=n_samples,
                                                             operators=operators,
                                                             n_expressions_per_sample=n_items_per_sample,
                                                             random_state=random_state)

    def __len__(self):
        """Return the number of samples."""
        return len(self.expression_tree_dataset)

    def __getitem__(self, idx):
        """Build the equations of sample ``idx``.

        Returns:
            With ``return_pre_token_lists`` set: one flat token list per
            equation (parentheses, operand values, operator characters,
            ``"="`` and the evaluated result), passed through ``transform``
            when one is configured. Otherwise the equations rendered as
            strings and joined by ``separator``.

        Raises:
            NotImplementedError: For an unknown ``distribution``.
            ZeroDivisionError: If an equation divides by zero and
                ``protect_division_by_zero`` is false, or no other sample is
                available as a replacement.
        """
        trees = self.expression_tree_dataset[idx]
        # The operand generator is seeded with the same per-sample seed that
        # the tree dataset uses, which keeps every index reproducible.
        rng = np.random.RandomState(self.expression_tree_dataset.seeds[idx])
        if self.distribution == 'standard_normal':
            numberss = rng.randn(len(trees), self.n_operands)
        elif self.distribution == "uniform":
            numberss = rng.rand(len(trees), self.n_operands) * self.uniform_distribution_range - self.uniform_distribution_range/2
        else:
            raise NotImplementedError(
                f"Unsupported distribution {self.distribution!r}; "
                "only 'standard_normal' and 'uniform' are implemented")

        ndigits = f"1.{self.n_decimals}" if self.n_decimals is not None else ""
        format_string = f"{{:{ndigits}f}}"
        node_rendering = lambda a, b, op: f"({a}{op}{b})"
        node_rendering_for_token_list = lambda a, b, op: ['('] + a + [op] + b + [')']

        rendered_expressions = []
        rendered_expressions_for_token_list = []
        evaled_expressions = []
        for tree, numbers in zip(trees, numberss):
            # The lambdas close over `numbers`; they are only used inside
            # this iteration, so late binding is not an issue.
            leaf_rendering = lambda i: format_string.format(numbers[i])
            leaf_rendering_for_token_list = lambda i: [numbers[i] if self.n_decimals is None
                                                            else np.round(numbers[i], self.n_decimals)]

            rendered_expression = render_expression_from_tree(tree,
                                    render_expr=node_rendering,
                                    render_leaf=leaf_rendering)
            rendered_expressions.append(rendered_expression)

            rendered_expression_for_token_list = render_expression_from_tree(tree,
                                    render_expr=node_rendering_for_token_list,
                                    render_leaf=leaf_rendering_for_token_list)
            rendered_expressions_for_token_list.append(rendered_expression_for_token_list)

            # Evaluate the *formatted* operands, so the result belongs to the
            # numbers as they are displayed rather than to the raw draws.
            rendered_for_eval = render_expression_from_tree(tree,
                                                            render_expr=node_rendering,
                                                            render_leaf=lambda i: format_string.format(numbers[i]))
            try:
                # NOTE: eval() only sees text assembled here from formatted
                # floats and the configured operator characters; `operators`
                # must never come from an untrusted source.
                evaled_expression = eval(rendered_for_eval)
            except ZeroDivisionError:
                if not self.protect_division_by_zero:
                    raise
                # Wrap around so that the last index also has a replacement
                # instead of running past the end of the dataset.
                n_total = len(self)
                fallback_idx = (idx + 1) % n_total
                if fallback_idx == idx % n_total:
                    # Single-sample dataset: nothing else to return.
                    raise
                import warnings
                warnings.warn(f"Found division by 0 at sample {idx}, replacing with the sample at index {fallback_idx}")
                return self[fallback_idx]

            evaled_expressions.append(evaled_expression)

        eqn_string = f"{{}}={{:{ndigits}f}}"
        equations = [eqn_string.format(r, e) for r, e in zip(rendered_expressions, evaled_expressions)]

        # In the token lists the operands are rounded to `n_decimals`, while
        # the result is appended exactly as evaluated (not rounded).
        token_list_equations = [r + ["=", e]
                            for r, e in zip(rendered_expressions_for_token_list, evaled_expressions)]

        if self.return_pre_token_lists:
            if self.transform is not None:
                return self.transform(token_list_equations)
            return token_list_equations

        joined_string = self.separator.join(equations)
        return joined_string


In [3]:
# Every split needs its own `random_state`. The per-sample seeds are drawn
# from a generator seeded with `random_state`, so with one shared value the
# validation and test sets would be identical and would reuse the seed stream
# of the training set (data leakage).
fad_train = FloatArithmeticDataset(n_operands=4, n_decimals=4, n_samples=5_000_000, distribution='uniform', random_state=0)
fad_val = FloatArithmeticDataset(n_operands=4, n_decimals=4, n_samples=300_000, distribution='uniform', random_state=1)
fad_test = FloatArithmeticDataset(n_operands=4, n_decimals=4, n_samples=300_000, distribution='uniform', random_state=2)

In [5]:
def save_dataset_to_file(dataset, filename):
    """Write the first item of every sample to a text file, one per line.

    For each index ``i`` the tokens of ``dataset[i][0]`` are converted with
    ``str`` and concatenated without any separator to form one line.

    Args:
        dataset: Object supporting ``len()`` and integer indexing whose
            samples are sequences of token sequences.
        filename: Path of the output file; missing parent directories are
            created. An existing file is overwritten.
    """
    from pathlib import Path

    target = Path(filename)
    # open(..., 'w') does not create directories, so do it explicitly.
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, 'w') as f:
        for i in range(len(dataset)):
            sample = dataset[i][0]
            # Tokens are joined with an empty separator, e.g. "(1+2)=3".
            sample_str = ''.join(map(str, sample))
            f.write(sample_str + '\n')

# Export each split to its own text file (one sample per line). This iterates
# over every sample of every split, so the cell is slow for large datasets.
save_dataset_to_file(fad_train, './data/fad_4_op_uniform_10/train')
save_dataset_to_file(fad_val, './data/fad_4_op_uniform_10/val')
save_dataset_to_file(fad_test, './data/fad_4_op_uniform_10/test')

In [1]:
from datasets import DatasetDict
# Load a tokenized dataset that was previously saved to disk.
# NOTE(review): the cell that produced this directory is not visible here —
# confirm where ./data/tokenized_fad_4/tokenized_ds_xval comes from.
dataset_path = "./data/tokenized_fad_4/tokenized_ds_xval"
tokenized_ds = DatasetDict.load_from_disk(dataset_path)

In [2]:
# Inspect one tokenized training example (index 8); the recorded output shows
# the fields 'input_ids' and 'numbers'.
tokenized_ds["train"][8]

{'input_ids': [18, 18, 3, 3, 20, 17, 18, 3, 17, 3, 20, 20, 19, 3],
 'numbers': [1.0,
  1.0,
  -0.737,
  -3.146,
  1.0,
  1.0,
  1.0,
  1.223,
  1.0,
  -0.405,
  1.0,
  1.0,
  1.0,
  -3.064]}

In [6]:
# Load a plain-text file with one equation per line.
# NOTE(review): this reads ./data/fad/train, whereas the export cell writes
# under ./data/fad_4_op_uniform_10/ — confirm this is the intended file.
# NOTE(review): the later `text_ds[0]` lookup succeeds in the recorded output,
# so this call apparently yields a single split rather than a mapping of
# splits — confirm, and consider `Dataset.from_text` if so.
text_ds = DatasetDict.from_text("./data/fad/train")

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [7]:
# Show the first record of the loaded text dataset (a dict with a 'text' key
# in the recorded output).
text_ds[0]

{'text': '(0.63-((-0.158+0.426)--0.492))=-0.13'}

In [None]:
from xval.make_tokenizer import make_tokenizer

# Build the tokenizer with the "xval" encoding and save it to `tokenizer_path`.
# FIXME(review): `tokenizer_path` and `sample_keys` are not defined in any
# cell visible here, and this cell has no execution count — on a fresh kernel
# it would raise NameError unless they are defined elsewhere. Confirm.
make_tokenizer(
    encoding="xval",
    save_file=tokenizer_path, 
    efficient_json=True, 
    sample_keys=sample_keys
)