# Import Modules

In [1]:
import torch
import numpy
import pickle

  from .autonotebook import tqdm as notebook_tqdm


# Utility functions for generating the dataset

In [2]:
import random

# Generate a random operator (only '+' unless other operators are passed in)
def get_operator(operators=('+',)):
    """Return an operator drawn uniformly at random from *operators*.

    The default draws only '+'. An earlier comment advertised "+ or -", but
    the function has only ever produced '+'.

    NOTE(review): before passing e.g. ('+', '-'), confirm that the
    aggregation and string-rendering helpers in this notebook handle
    signed terms correctly.

    Args:
        operators: non-empty sequence of operator strings to choose from.

    Returns:
        One element of *operators*.
    """
    return random.choice(operators)

# Generate a random coefficient in [1, 9]
def get_coeff():
    """Return a random integer coefficient between 1 and 9 inclusive."""
    lowest, highest = 1, 9
    return random.randrange(lowest, highest + 1)

# Generate a random order in [0, 3]; the highest order, 3, is x*x*x
def get_order():
    """Return a random polynomial order between 0 and 3 inclusive."""
    max_order = 3
    return random.randrange(0, max_order + 1)

# Generate a formulation with n_terms terms (5 by default)
def get_formulation(n_terms=5):
    """Return a list of ``n_terms`` random ``(operator, coeff, order)`` tuples.

    Each tuple's fields come from get_operator, get_coeff and get_order,
    called in that order.
    """
    return [(get_operator(), get_coeff(), get_order()) for _ in range(n_terms)]

def term2str(term):
    """Render one ``(operator, coeff, order)`` term as a randomized product string.

    The order is split at random into a power ``a`` and a count ``b`` of
    extra ``*x`` factors, with ``a + b == order``. The ``*x^a`` factor is then
    placed at random either before or after the ``b`` repeated ``*x`` factors.
    Equal terms can therefore be rendered in many textual forms,
    e.g. ``+3*x^1*x`` or ``+3*x*x^1``.

    Returns:
        The term string, prefixed by its operator.
    """
    op, coeff, order = term
    power = random.randint(0, order)
    extra = order - power
    pow_part = f'*x^{power}'
    repeat_part = '*x' * extra
    # The second random draw decides where the x^power factor goes.
    body = pow_part + repeat_part if random.randint(0, 1) else repeat_part + pow_part
    return f'{op}{coeff}{body}'


def formulation2str(formulation):
    """Render a formulation as one randomized expression string.

    Each term is rendered with ``term2str``, which prefixes the term's
    operator, and the pieces are concatenated in order. Only a redundant
    leading '+' is dropped. A leading '-' is kept, because removing it would
    flip the sign of the first term.

    Args:
        formulation: list of ``(operator, coeff, order)`` tuples.

    Returns:
        The expression string ('' for an empty formulation).
    """
    text = ''.join(term2str(term) for term in formulation)
    return text[1:] if text.startswith('+') else text

def aggregate_formulation(formulation):
    """Combine like terms of a formulation.

    Terms are grouped by order, highest order first, and the signed
    coefficients within each group are summed. The input list is not modified.

    Args:
        formulation: list of ``(operator, coeff, order)`` tuples. An operator
            of '-' makes the term's coefficient negative; any other operator
            is treated as '+'.

    Returns:
        A list with one ``(operator, coeff, order)`` tuple per distinct order,
        sorted by descending order. ``coeff`` is the absolute value of the
        group's signed sum, and ``operator`` carries its sign ('-' if the sum
        is negative, else '+'). A group whose terms cancel has ``coeff == 0``.
    """
    totals = []  # [order, signed coefficient sum], in descending order
    # sorted() rather than list.sort(): do not reorder the caller's list.
    for operator, coeff, order in sorted(formulation, key=lambda t: t[2], reverse=True):
        signed = -coeff if operator == '-' else coeff
        if totals and totals[-1][0] == order:
            # Equal orders are adjacent after sorting, so fold into the last group.
            totals[-1][1] += signed
        else:
            totals.append([order, signed])
    return [('-' if total < 0 else '+', abs(total), order) for order, total in totals]

def formulation2abvstr(aggregated_formulation):
    """Render an aggregated formulation as a compact canonical string.

    Each term is written as ``<op><coeff>*x^<order>``, or as just
    ``<op><coeff>`` when the order is 0. Terms with coefficient 0 are skipped.
    Only a redundant leading '+' is dropped; a leading '-' is kept. If every
    term is skipped, '0' is returned so the result is still a valid
    expression.

    Args:
        aggregated_formulation: list of ``(operator, coeff, order)`` tuples.

    Returns:
        The abbreviated expression string.
    """
    parts = []
    for operator, coeff, order in aggregated_formulation:
        if coeff == 0:
            continue  # a cancelled term contributes nothing
        parts.append(f'{operator}{coeff}*x^{order}' if order > 0 else f'{operator}{coeff}')
    text = ''.join(parts)
    if not text:
        return '0'
    return text[1:] if text.startswith('+') else text

### Check that each formulation evaluates to the same values before and after aggregation

In [24]:
# Code for debugging: pair each randomly rendered formulation string with the
# string of its aggregated (like terms combined) form, for the check below.
formulations = []
for _ in range(1000):
    raw_terms = get_formulation()
    # Render the raw terms before aggregating them.
    raw_str = formulation2str(raw_terms)
    formulations.append((raw_str, formulation2abvstr(aggregate_formulation(raw_terms))))

def eval_str(formulation_str, x):
    """Evaluate a formulation string at *x*.

    Every 'x' is replaced by the parenthesised value of *x*, and '^' is
    replaced by Python's '**'. The result is evaluated with builtins disabled.
    The parentheses matter for negative *x*: without them, '3*x^2' at x=-2
    would become '3*-2**2' == -12 instead of 12.

    Security: this still calls eval(). Only pass strings built by this
    notebook's generators, never external input.

    Args:
        formulation_str: expression using 'x', digits, '+', '-', '*' and '^'.
        x: number to substitute for 'x'.

    Returns:
        The numeric value of the expression.
    """
    expression = formulation_str.replace('x', f'({x})').replace('^', '**')
    return eval(expression, {'__builtins__': {}})

# The raw and aggregated strings must evaluate equally at every integer x in [0, 10).
for raw_str, agg_str in formulations:
    for x in range(10):
        assert eval_str(raw_str, x) == eval_str(agg_str, x), f'{raw_str, agg_str, x}'

# Create dataset

In [8]:
def create_samples(formulation_str):
    """Tabulate a formulation's value modulo 10 at x = 1..9.

    Returns:
        A list of ``(x, eval_str(formulation_str, x) % 10)`` pairs, for x from 1 to 9.
    """
    samples = []
    for x in range(1, 10):
        samples.append((x, eval_str(formulation_str, x) % 10))
    return samples

def create_database(n=1000000, n_terms=5):
    """Generate ``n`` random formulations and tabulate their samples.

    Args:
        n: number of formulations to generate.
        n_terms: number of terms per formulation.

    Returns:
        ``(formulations, database)``:
        - ``formulations`` is a list of ``(formulation_str, formulation_abvstr)``
          pairs, one per generated formulation, where ``formulation_abvstr`` is
          the string of the aggregated form.
        - ``database`` maps each distinct ``formulation_abvstr`` to the
          ``create_samples`` output of the first formulation string that
          produced it.
    """
    formulations = []
    database = {}
    for _ in range(n):
        terms = get_formulation(n_terms)
        raw = formulation2str(terms)
        abbreviated = formulation2abvstr(aggregate_formulation(terms))
        formulations.append((raw, abbreviated))
        # Only the first formulation seen for a given abbreviated form is sampled.
        if abbreviated not in database:
            database[abbreviated] = create_samples(raw)
    return formulations, database

In [26]:
# Build the dataset with the default size: 1,000,000 formulations of 5 terms each.
dataset = create_database(n=1000000, n_terms=5)
# To persist the dataset, uncomment:
# with open('data/dataset.pkl', 'wb') as f:
#     pickle.dump(dataset, f)

In [None]:
[formulations, database] = dataset  # the (formulations, database) pair from create_database

In [None]:
print(*(len(part) for part in (formulations, database)))  # generated formulations, distinct aggregated keys

1000 63209


# Debug: verify every sample in the dataset (the in-memory copy that would be saved to data/dataset.pkl)

In [23]:
# Re-evaluate every generated formulation and compare it with the samples stored
# under its aggregated key. The stored x values must be exactly 1..9, and each
# stored value must equal the formulation's value at that x, modulo 10.
for formulation_str, formulation_abvstr in formulations:
    assert formulation_abvstr in database, formulation_abvstr
    samples = database[formulation_abvstr]
    assert [x for x, _ in samples] == list(range(1, 10)), (formulation_abvstr, samples)
    for x, expected in samples:
        assert eval_str(formulation_str, x) % 10 == expected, (formulation_str, formulation_abvstr, x)