In [1]:
'''
█ █▀▄▀█ █▀█ █▀█ █▀█ ▀█▀ █▀
█ █░▀░█ █▀▀ █▄█ █▀▄ ░█░ ▄█
'''

import numpy as np
import torch

In [None]:
'''
█▀▄▀█ ▄▀█ █ █▄░█
█░▀░█ █▀█ █ █░▀█
'''

# Create an Expression Tree Improver Network (ETIN) with its default configuration.
# NOTE(review): `ETIN` is defined in a later cell of this notebook; if the file is executed
# top to bottom as a plain script, this line raises NameError. Run it after the definitions.
etin = ETIN()
# Train the ETIN
etin.train()
# ETIN ready to be used. Example:
# etin.improve(example_expression_tree)
# NOTE(review): no `improve` method exists on the ETIN class shown in this notebook.

In [13]:
class ETIN_model(torch.nn.Module):
    '''
    Neural network at the core of the Expression Tree Improver Network (ETIN).

    It takes as input an expression tree, a dataset and the prediction error of the tree on that
    dataset, and outputs per-position scores over the language that are meant to describe a new
    expression tree with a lower prediction error on the dataset.

    Structure:
        - Input: expression tree, dataset, prediction error.
        - 2 Summarizers: networks that build an encoded representation of the dataset and of the
          prediction errors (one vector per sample).
        - 1 Transformer Encoder applied to the embedding of each token of the expression tree.
        - 1 Linear layer applied to the concatenation of the encoded dataset, the encoded prediction
          error and the token embeddings.
        - Output: (B x L x language_size) scores.

    NOTE(review): `Summarizer` and `create_embedding` are not defined in this notebook section;
    they must be provided elsewhere before the model can run.
    '''

    def __init__(self, encoded_size, emb_size, language_size, max_seq_size, nhead=8, num_layers=4):
        '''
        Args:
            encoded_size: size E of each summarizer output.
            emb_size: size G of each token embedding (must be divisible by `nhead`).
            language_size: number of output scores per position.
            max_seq_size: maximum traversal length (stored; not used by the layers).
            nhead: attention heads of each Transformer encoder layer.
            num_layers: number of Transformer encoder layers.
        '''
        super(ETIN_model, self).__init__()
        self.max_seq_size = max_seq_size
        self.dataset_summarizer = Summarizer()
        self.prediction_error_summarizer = Summarizer()
        # The encoder consumes token embeddings laid out as (B x L x G): the feature size d_model
        # must therefore be emb_size (not the sequence length), and the batch axis comes first.
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=emb_size, nhead=nhead, batch_first=True)
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = torch.nn.Linear(in_features=2*encoded_size + emb_size, out_features=language_size)

    def forward(self, expression_tree, dataset, prediction_error):
        '''
        Forward pass: score every position of the expression tree given the dataset and its errors.

        Returns:
            tensor of shape (B x L x language_size).
        '''
        # Create encoded dataset  (B x 100 x 2) -> (B x E)
        # A dictionary cache could avoid recomputing the summary of an already-seen dataset.
        encoded_dataset = self.dataset_summarizer(dataset)
        # Create encoded prediction error  (B x 100 x 2) -> (B x E)
        encoded_prediction_error = self.prediction_error_summarizer(prediction_error)

        # Create Embedding from each token of the Expression Tree (B x L x 3M) -> (B x L x G)
        embedding = self.create_embedding(expression_tree)
        # Pass the embedding through the Transformer Encoder (B x L x G) -> (B x L x G)
        embedding = self.transformer_encoder(embedding)

        # The summaries are one vector per sample (B x E); repeat them along the sequence axis so
        # they can be concatenated with the per-token embeddings on the feature axis.
        seq_len = embedding.size(1)
        encoded_dataset = encoded_dataset.unsqueeze(1).expand(-1, seq_len, -1)
        encoded_prediction_error = encoded_prediction_error.unsqueeze(1).expand(-1, seq_len, -1)

        # (B x L x E), (B x L x E), (B x L x G) -> (B x L x 2E+G)
        concatenated_input = torch.cat((encoded_dataset, encoded_prediction_error, embedding), dim=2)

        # Linear final layer (B x L x 2E+G) -> (B x L x M)
        return self.output_layer(concatenated_input)

In [65]:
# Scratch check: torch.where with scalar branches flips a 0/1 mask (shows tensor([0, 1])).
a = torch.tensor([1.0, 0.0])
torch.where(a.eq(1), 0, 1)

tensor([0, 1])

In [86]:
from sympy import *
from scipy.optimize import minimize
import time

class Token():
    '''
    One symbol of the expression language.

    Attributes:
        function: implementation applied when numerically evaluating an expression.
        arity: number of operands the symbol consumes.
        symbol: textual name of the symbol (e.g. '+', 'sin').
        unprotected_function: implementation applied when building a symbolic expression.
        inv: symbol of the inverse operation, or None when there is none.
    '''
    def __init__(self, function, arity, symbol, unprotected_function, inv=None):
        # Keep every constructor argument as a public attribute, unchanged.
        self.function, self.arity, self.symbol = function, arity, symbol
        self.inv = inv
        self.unprotected_function = unprotected_function




class Restrictions():
    '''
    Registry of syntactic restrictions applied while growing random expression trees.

    Each restriction returns a list of token indices (see Language.symbol_to_idx) that must not be
    chosen as the next child of the node being expanded. The names of the active restrictions are
    stored class-wide in `Restrictions.restrictions` and applied by Language.get_possibilities.
    '''

    # Restriction name -> callable. The lambdas look the implementation up on the class at call
    # time, so this dict can be built before the methods below are defined.
    permitted = {
        'no_inverse_parent': lambda parent: Restrictions.no_inverse_parent(parent),
        'no_double_exp': lambda parent: Restrictions.no_double_exp(parent),
        'no_sqrt_in_log': lambda parent: Restrictions.no_sqrt_in_log(parent),
        'no_double_division': lambda parent: Restrictions.no_double_division(parent),
        'DistributiveRestriction': lambda not_allowed, parent, n_variables, needed: Restrictions.DistributiveRestriction(not_allowed, parent, n_variables, needed),
        'no_const_after_unary': lambda parent: Restrictions.no_const_after_unary(parent),
        'no_sqrt_in_sqrt': lambda parent: Restrictions.no_sqrt_in_sqrt(parent),
        'no_trigo_in_log': lambda parent: Restrictions.no_trigo_in_log(parent),
        'no_trigo_in_exp': lambda parent: Restrictions.no_trigo_in_exp(parent),
        'no_trigo_offspring': lambda trigo_offspring: Restrictions.no_trigo_offspring(trigo_offspring),
        'no_tan_in_exp': lambda parent: Restrictions.no_tan_in_exp(parent),
    }
    # Names of the active restrictions; set by __init__.
    restrictions = None

    # Trigonometric symbols targeted by the no_trigo_* restrictions; symbols missing from the
    # language (currently 'tan') are skipped.
    _TRIGONOMETRIC = ('sin', 'cos', 'tan')

    def __init__(self, restrictions):
        '''
        Validate and activate restrictions.

        Args:
            restrictions: iterable of keys of `Restrictions.permitted`; None means no restrictions.
        Raises:
            ValueError: if a name is not a known restriction.
        '''
        restrictions = [] if restrictions is None else list(restrictions)
        for restriction in restrictions:
            if restriction not in Restrictions.permitted:
                raise ValueError(f"Restriction {restriction!r} not known")

        Restrictions.restrictions = restrictions

    @staticmethod
    def _indices_of(symbols):
        '''Return the indices of those `symbols` present in Language.symbol_to_idx, in order.'''
        return [Language.symbol_to_idx[symbol] for symbol in symbols if symbol in Language.symbol_to_idx]

    @staticmethod
    def no_inverse_parent(parent):
        '''Forbid the inverse of the parent's operator directly below it (e.g. '-' under '+').'''
        # parent.inv may be None or a symbol outside the language (e.g. 'cbrt'): nothing to forbid.
        return Restrictions._indices_of([parent.inv])

    @staticmethod
    def no_double_exp(parent):
        '''Below 'exp', '^' or '^2', forbid 'exp', '^' and '^2' (those present in the language).'''
        if parent.symbol in ['exp', '^', '^2']:
            return Restrictions._indices_of(['exp', '^', '^2'])
        return []

    @staticmethod
    def no_sqrt_in_log(parent):
        '''Forbid 'sqrt' directly below 'log'.'''
        if parent.symbol == 'log':
            return Restrictions._indices_of(['sqrt'])
        return []

    @staticmethod
    def no_double_division(parent):
        '''Forbid '/' directly below '/'.'''
        if parent.symbol == '/':
            return Restrictions._indices_of(['/'])
        return []

    @staticmethod
    def DistributiveRestriction(not_allowed, parent, n_variables, needed):
        '''
        Forbid re-using the parent's operator when too few distinct terminals remain.

        Args:
            not_allowed: terminal indices already used under this chain of same-symbol nodes.
            parent: Token being expanded.
            n_variables: number of available variables.
            needed: the 'needed' counter tracked for this node by Expression.add_node.
        Returns:
            `not_allowed` as a list, plus the parent's index when `needed + parent.arity` exceeds
            the remaining distinct terminals (n_variables + constants - len(not_allowed)).
        '''
        not_allowed = list(not_allowed)
        if needed + parent.arity > n_variables + Language.use_constants - len(not_allowed):
            return not_allowed + [Language.token_to_idx[parent]]
        return not_allowed

    @staticmethod
    def no_const_after_unary(parent):
        '''Forbid constants (index 0) as the operand of a unary function.'''
        if parent.arity == 1 and Language.use_constants:
            return [0]
        return []

    @staticmethod
    def no_sqrt_in_sqrt(parent):
        '''Forbid 'sqrt' directly below 'sqrt'.'''
        if parent.symbol == 'sqrt':
            return Restrictions._indices_of(['sqrt'])
        return []

    @staticmethod
    def no_trigo_in_log(parent):
        '''Forbid trigonometric functions directly below 'log'.'''
        if parent.symbol == 'log':
            return Restrictions._indices_of(Restrictions._TRIGONOMETRIC)
        return []

    @staticmethod
    def no_trigo_offspring(trigo_offspring):
        '''Forbid trigonometric functions anywhere below a trigonometric ancestor.'''
        if trigo_offspring:
            return Restrictions._indices_of(Restrictions._TRIGONOMETRIC)
        return []

    @staticmethod
    def no_trigo_in_exp(parent):
        '''Forbid trigonometric functions directly below 'exp'.'''
        if parent.symbol == 'exp':
            return Restrictions._indices_of(Restrictions._TRIGONOMETRIC)
        return []

    @staticmethod
    def no_tan_in_exp(parent):
        '''Forbid 'tan' directly below 'exp' (no-op while 'tan' is not part of the language).'''
        # The direct lookup used previously raised KeyError because 'tan' has no index.
        if parent.symbol == 'exp':
            return Restrictions._indices_of(['tan'])
        return []



def _protected_division(x1, x2):
    """Closure of division (x1/x2) for zero denominator."""
    return np.where(np.abs(x2) > 1e-5, np.divide(x1, x2), np.ones(1))

def _protected_sqrt(x1):
    """Closure of square root for negative arguments."""
    return np.sqrt(np.abs(x1))

def _protected_log(x1):
    """Closure of log for zero and negative arguments."""
    return np.where(np.abs(x1) > 1e-5, np.log(np.abs(x1)), np.zeros(1))

def _protected_exp(x1):
    """Closure of exp for overflow"""
    return np.where(x1 <= 30, np.exp(x1), np.ones(1)*10686474581524.463)

    



class Language(Restrictions):
    '''
    Class-level description of the expression language: symbols, tokens and their indices.

    Index layout: 0 is reserved for constants, 1..max_variables for variables and the following
    indices for functions (see symbol_to_idx). Instantiating the class (re)configures the class
    attributes; all state is shared class-wide.
    '''
    function_set_symbols = None   # Symbols of the enabled functions
    function_set_idx = None       # Indices of the enabled functions
    function_set = None           # Tokens of the enabled functions
    function_set_tokens = None    # NOTE(review): never assigned; the token list lives in `function_set`
    max_variables = 5             # Variable indices are 1..5 (the first function index is 6)
    use_constants = None          # Whether constants (index 0) may be used
    n_functions = None            # Number of known function tokens
    size = None                   # Vocabulary size: constants + variables + functions

    # Dicts to switch quickly between token, idx and symbol -> idx 0 is reserved for constants
    # and 1 to max_variables (5) for variables
    symbol_to_idx = {'+': 6, '-': 7, '*': 8, '/': 9, '^2': 10, '^3': 11, '^4': 12,
                     'sin': 13, 'cos': 14, 'exp': 15, 'log': 16, 'sqrt': 17, 
                     'abs': 18, 'max': 19, 'min': 20, 'inv': 21, 'neg': 22,}
    idx_to_symbol = None
    symbol_to_token = {
        '+': Token(function=lambda a, b: a + b, arity=2, symbol='+', inv='-', unprotected_function=lambda a, b: a + b),
        '-': Token(function=lambda a, b: a - b, arity=2, symbol='-', inv='+', unprotected_function=lambda a, b: a - b),
        '*': Token(function=lambda a, b: a * b, arity=2, symbol='*', inv='/', unprotected_function=lambda a, b: a * b),
        '/': Token(function=lambda a, b: _protected_division(a, b), arity=2, symbol='/', inv='*', unprotected_function= lambda a, b: a / b),
        '^2': Token(function=lambda a: np.power(a, 2), arity=1, symbol='^2', inv='sqrt', unprotected_function=lambda a: a ** 2),
        '^3': Token(function=lambda a: np.power(a, 3), arity=1, symbol='^3', inv='cbrt', unprotected_function=lambda a: a ** 3),
        '^4': Token(function=lambda a: np.power(a, 4), arity=1, symbol='^4', inv='sqrt', unprotected_function=lambda a: a ** 4),
        'sin': Token(function=lambda a: np.sin(a), arity=1, symbol='sin', unprotected_function=lambda a: sin(a)),
        'cos': Token(function=lambda a: np.cos(a), arity=1, symbol='cos', unprotected_function=lambda a: cos(a)),
        'exp': Token(function=lambda a: _protected_exp(a), arity=1, symbol='exp', inv='log', unprotected_function=lambda a: exp(a)),
        'log': Token(function=lambda a: _protected_log(a), arity=1, symbol='log', inv='exp', unprotected_function=lambda a: log(Abs(a))),
        'sqrt': Token(function=lambda a: _protected_sqrt(a), arity=1, symbol='sqrt', inv='^2', unprotected_function=lambda a: sqrt(Abs(a))),
        'abs': Token(function=lambda a: np.abs(a), arity=1, symbol='abs', unprotected_function=lambda a: Abs(a)),
        'max': Token(function=lambda a, b: np.maximum(a, b), arity=2, symbol='max', unprotected_function=lambda a, b: Max(a, b)),
        'min': Token(function=lambda a, b: np.minimum(a, b), arity=2, symbol='min', unprotected_function=lambda a, b: Min(a, b)),
        'inv': Token(function=lambda a: np.reciprocal(a), arity=1, symbol='inv', unprotected_function=lambda a: 1 / a),
        'neg': Token(function=lambda a: -a, arity=1, symbol='neg', unprotected_function=lambda a: -a),
    }
    token_to_idx = None
    idx_to_token = None

    # Number of enabled functions in each sampling category (used to split each category's
    # probability evenly when generating random expressions).
    type_1_functions, type_2_functions, type_3_functions, type_4_functions = 0, 0, 0, 0

    def __init__(self, function_set_symbols, restrictions, use_constants, max_variables):
        '''
        Configure the language (all settings are stored as class attributes).

        Args:
            function_set_symbols: list of symbols (keys of symbol_to_token) to enable.
            restrictions: list of restriction names (keys of Restrictions.permitted), or None.
            use_constants: whether constants (index 0) may appear in expressions.
            max_variables: NOTE(review): currently ignored. symbol_to_idx hard-codes 5 variable
                slots (functions start at 6), so Language.max_variables stays 5.
        Raises:
            ValueError: on an unknown function symbol.
        '''
        super(Language, self).__init__([] if restrictions is None else restrictions)
        # Assert that all the function symbols are valid before touching any state.
        for function_symbol in function_set_symbols:
            if function_symbol not in self.symbol_to_token:
                raise ValueError('Unknown function symbol: ' + function_symbol)

        # Recount from scratch: these are class attributes, so incrementing them (as before) kept
        # accumulating across every Language instantiation and skewed the sampling weights.
        Language.type_1_functions = sum(s in ['+', '-', '*', '/'] for s in function_set_symbols)
        Language.type_2_functions = sum(s in ['^2', '^3', '^4', 'exp', 'log', 'sqrt', 'inv', 'neg'] for s in function_set_symbols)
        Language.type_3_functions = sum(s in ['sin', 'cos'] for s in function_set_symbols)
        Language.type_4_functions = sum(s in ['abs', 'max', 'min'] for s in function_set_symbols)

        # Complete dictionaries
        Language.idx_to_symbol = {v: k for k, v in Language.symbol_to_idx.items()}
        Language.token_to_idx = {v: Language.symbol_to_idx[k] for k, v in Language.symbol_to_token.items()}
        Language.idx_to_token = {v: k for k, v in Language.token_to_idx.items()}

        # Functions used in the language
        Language.function_set_symbols = function_set_symbols
        Language.function_set_idx = [self.symbol_to_idx[x] for x in function_set_symbols]
        Language.function_set = [Language.symbol_to_token[function] for function in function_set_symbols]
        Language.use_constants = use_constants

        # Number of known functions and total vocabulary size (constants + variables + functions)
        Language.n_functions = len(Language.symbol_to_token)
        Language.size = Language.max_variables + len(Language.idx_to_symbol) + 1

    @staticmethod
    def get_possibilities(info, n_variables, to_take='all'):
        '''
        Return the token indices that may be added as the next child of the node described by `info`.

        Args:
            info: function-stack entry with keys 'function' (parent Token), 'distributive' (terminal
                indices already used under a chain of same-symbol nodes), 'needed',
                'trigo_offspring' and 'arity_one' (as built by Expression.add_node).
            n_variables: number of variables available (indices 1..n_variables).
            to_take: 'all' for functions and terminals, 'terminal' for terminals only.
        Returns:
            list of allowed token indices (order unspecified).
        Raises:
            ValueError: if `to_take` is neither 'all' nor 'terminal'.
        '''
        parent = info['function']
        not_allowed = set()
        for restriction in Restrictions.restrictions or ():
            rule = Restrictions.permitted[restriction]
            if restriction == 'DistributiveRestriction':
                not_allowed.update(rule(info['distributive'], parent, n_variables, info['needed']))
            elif restriction == 'no_trigo_offspring':   # exact match (was a substring test)
                not_allowed.update(rule(info['trigo_offspring']))
            else:
                not_allowed.update(rule(parent))

        # NOTE(review): this forbids every unary function below the root unconditionally, and
        # info['arity_one'] is never consulted. Confirm whether it should be gated on
        # `not info['arity_one']`.
        not_allowed.update(Language.token_to_idx[token] for token in Language.function_set if token.arity == 1)

        # Index 0 (constants) is a candidate only when constants are enabled.
        terminals = list(range(1 - Language.use_constants, n_variables + 1))
        if to_take == 'all':
            return list(set(Language.function_set_idx + terminals) - not_allowed)
        if to_take == 'terminal':
            return list(set(terminals) - not_allowed)
        raise ValueError("to_take must be 'all' or 'terminal', got " + repr(to_take))



class Expression():
    '''
    An expression tree stored as a prefix traversal of token indices, e.g. [8, 1, 0] for `x1 * c`
    (8 is '*', 1 the variable x1 and 0 a constant; see Language.symbol_to_idx).

    Attributes:
        traversal: list of token indices in prefix order.
        one_hot_encoding: torch one-hot tensor of shape (len(traversal), Language.size).
        constants: one value per 0 in `traversal` (empty array when constants are disabled).
        P: per-index sampling weights used when generating or mutating trees.
        max_nodes, n_variables: generation limits (see __init__).
    '''
    def __init__(self, traversal=None, seed=None, max_nodes=15, n_variables=None):
        '''
        Args:
            traversal: prefix list of token indices; if None a random expression is generated.
            seed: seed for np.random.seed before random generation (only used when traversal is None).
            max_nodes: exclusive upper bound of the random size limit drawn when generating/mutating.
            n_variables: number of variables x1..xn available (default Language.max_variables).
        '''
        self.max_nodes = max_nodes
        self.n_variables = n_variables if n_variables is not None else Language.max_variables
        # Always compute the sampling weights: mutate() needs them even on expressions built
        # from an explicit traversal (e.g. the result of a previous mutation).
        self.P = self._sampling_probabilities()
        if traversal is None:
            traversal = self.generate_random_expression(seed=seed)
        self.traversal = traversal
        self.one_hot_encoding = torch.nn.functional.one_hot(torch.Tensor(traversal).to(torch.int64), num_classes=Language.size)
        if Language.use_constants:
            # One |N(0, 3)| starting value per constant slot.
            self.constants = np.abs(np.random.normal(0, 3, size=(self.traversal.count(0),)))
        else:
            self.constants = np.empty(0)

    def _sampling_probabilities(self):
        '''
        Return per-index sampling weights (a list of length Language.size).

        Index 0 and the variables 1..n_variables share 0.4 evenly. Enabled functions get 0.3
        ('+', '-', '*', '/'), 0.22 ('^2', '^3', '^4', 'exp', 'log', 'sqrt', 'inv', 'neg'),
        0.06 ('sin', 'cos') or 0.02 ('abs', 'max', 'min'), split evenly among the enabled functions
        of each group. Every other index gets 0.
        '''
        P = [0 for _ in range(Language.size)]
        for i in range(self.n_variables + 1):
            P[i] = 0.4/(self.n_variables + 1)
        for i in range(Language.max_variables + 1, Language.size):
            symbol = Language.idx_to_symbol[i]
            if symbol not in Language.function_set_symbols:
                continue
            if symbol in ['+', '-', '*', '/']: P[i] = 0.3/Language.type_1_functions
            elif symbol in ['^2', '^3', '^4', 'exp', 'log', 'sqrt', 'inv', 'neg']: P[i] = 0.22/Language.type_2_functions
            elif symbol in ['sin', 'cos']: P[i] = 0.06/Language.type_3_functions
            elif symbol in ['abs', 'max', 'min']: P[i] = 0.02/Language.type_4_functions
        return P

    def _fold(self, make_constant, make_variable, protected):
        '''
        Evaluate the prefix traversal bottom-up with a stack and return the root value.

        make_constant(k) builds the leaf for the k-th constant slot, make_variable(i) the leaf for
        variable index i; functions are applied with Token.function when `protected` is True,
        otherwise with Token.unprotected_function.
        '''
        stack = []
        n_constants = 0
        for idx in reversed(self.traversal):
            if idx > Language.max_variables:
                token = Language.idx_to_token[idx]
                apply = token.function if protected else token.unprotected_function
                # The most recently pushed value is the leftmost operand.
                operands = [stack.pop() for _ in range(token.arity)]
                stack.append(apply(*operands))
            elif idx == 0:
                stack.append(make_constant(n_constants))
                n_constants += 1
            else:
                stack.append(make_variable(idx))
        return stack[0]

    def to_sympy(self):
        '''
        Return the simplified sympy expression of the tree (variables x1..x10, constants by value).
        '''
        x = symbols('x1, x2, x3, x4, x5, x6, x7, x8, x9, x10')
        result = self._fold(lambda k: self.constants[k], lambda i: x[i - 1], protected=False)
        # A tree made only of constants can evaluate to a plain numpy float with no .simplify().
        return sympify(result).simplify()

    def evaluate(self, X, constants=None):
        '''
        Evaluate the tree on every row of X (column i-1 holds variable x_i).

        Args:
            X: 2-D array of inputs.
            constants: values for the constant slots; defaults to self.constants.
        Returns:
            array with one value per row of X.
        '''
        if constants is None:
            constants = self.constants
        n_rows = X.shape[0]
        return self._fold(lambda k: constants[k]*np.ones(n_rows), lambda i: X[:, i - 1], protected=True)

    def loss(self, constants, X, y):
        '''Mean squared error of the tree on (X, y) using the given constant values.'''
        return ((self.evaluate(X, constants) - y)**2).mean()

    @staticmethod
    def _sample(possibilities, P):
        '''Draw one element of `possibilities` with probability proportional to its weight in P.'''
        weights = list(map(P.__getitem__, possibilities))
        total = sum(weights)
        return np.random.choice(possibilities, p=[w / total for w in weights])

    def add_node(self, will_be_nodes, arities_stack, function_stack, program, max_nodes, First, P, predefined_choice=None):
        '''
        Append one token to `program` while growing a prefix traversal.

        Args:
            will_be_nodes: number of nodes the tree will have once all pending operand slots are filled.
            arities_stack: remaining operand slots of each open function node (innermost last).
            function_stack: info dict of each open function node (see Language.get_possibilities).
            program: traversal built so far (appended to in place).
            max_nodes: size limit; a function is only added if its operands still fit.
            First: True when choosing the root.
            P: per-index sampling weights.
            predefined_choice: index to use instead of sampling (used by mutate).
        Returns:
            the updated (will_be_nodes, arities_stack, function_stack, program).
        '''
        if predefined_choice is not None:
            choice = predefined_choice
        elif First:
            # The root is always a function, to avoid degenerate single-terminal expressions.
            choice = self._sample(Language.function_set_idx, P)
        else:
            # Only consider what is allowed under the current parent.
            choice = self._sample(Language.get_possibilities(info=function_stack[-1], n_variables=self.n_variables), P)

        token = Language.idx_to_token[choice] if choice > Language.max_variables else None
        # Add a function only if there is still room for its operands.
        if token is not None and will_be_nodes + token.arity <= max_nodes:
            program.append(choice)
            arities_stack.append(token.arity)
            will_be_nodes += token.arity
            parent_info = None if First else function_stack[-1]
            distributive = set()
            needed = token.arity - 1
            trigo_offspring = False if parent_info is None else parent_info['trigo_offspring']
            # False when a unary function is placed directly under another unary function.
            arity_one = not (parent_info is not None and token.arity == 1 and parent_info['function'].arity == 1)
            if parent_info is not None and parent_info['function'].symbol == Language.idx_to_symbol[choice]:
                # Same operator as the parent: share its set of used terminals (same set object)
                # and accumulate the operands still needed, for DistributiveRestriction.
                distributive = parent_info['distributive']
                needed += parent_info['needed']
            if token.symbol in ['sin', 'cos', 'tan', 'tanh']:
                trigo_offspring = True
            function_stack.append({'function': token, 'trigo_offspring': trigo_offspring, 'distributive': distributive, 'needed': needed, 'arity_one': arity_one})
        else:
            if token is not None:
                # A function was drawn but its operands do not fit: draw a terminal instead.
                choice = self._sample(Language.get_possibilities(info=function_stack[-1], n_variables=self.n_variables, to_take='terminal'), P)

            program.append(choice)
            function_stack[-1]['distributive'].add(choice)
            arities_stack[-1] -= 1
            # Close every function whose operands are now complete, propagating used terminals
            # to a same-symbol parent.
            while arities_stack[-1] == 0:
                arities_stack.pop()
                if not arities_stack:   # the root is complete
                    break
                child_info = function_stack.pop()
                arities_stack[-1] -= 1
                if child_info['function'].symbol == function_stack[-1]['function'].symbol:
                    function_stack[-1]['distributive'].update(child_info['distributive'])

        return will_be_nodes, arities_stack, function_stack, program

    def generate_random_expression(self, seed=None):
        '''
        Grow a random prefix traversal by sampling tokens with self.P.

        Reseeds the global NumPy RNG with `seed` and draws the size limit from [4, self.max_nodes).
        '''
        np.random.seed(seed)
        First = True
        max_nodes = np.random.randint(4, self.max_nodes)

        will_be_nodes = 1
        arities_stack, function_stack, program = [], [], []

        while First or arities_stack:   # While there are operands/operators to add
            will_be_nodes, arities_stack, function_stack, program = self.add_node(will_be_nodes, arities_stack, function_stack, program, max_nodes, First, self.P)
            First = False
        return program

    def mutate(self, p_random_tree=0.3):
        '''
        Return a new Expression obtained by re-growing this one token by token.

        At each position the token at the same position of this traversal is reused, except
        (a) with probability min(2/len(traversal), 0.5), (b) for every position when a whole new
        random tree is drawn (probability p_random_tree), or (c) once this traversal is exhausted;
        then a token is sampled. A reused function whose operands no longer fit becomes a sampled
        terminal.
        '''
        old_traversal = self.traversal
        max_nodes = np.random.randint(4, self.max_nodes)
        new_random_tree = np.random.random() < p_random_tree

        First = True
        will_be_nodes = 1
        arities_stack, function_stack, program = [], [], []

        prob_mutation = min(2/len(old_traversal), 0.5)
        position = 0

        while First or arities_stack:
            resample = new_random_tree or np.random.rand() < prob_mutation or position >= len(old_traversal)
            predefined_choice = None if resample else old_traversal[position]

            will_be_nodes, arities_stack, function_stack, program = self.add_node(will_be_nodes, arities_stack, function_stack,
                                                                                    program, max_nodes, First, self.P, predefined_choice)
            First = False
            position += 1

        # Pass n_variables by keyword: positionally it landed in `seed` and the mutated
        # expression fell back to Language.max_variables variables.
        return Expression(program, n_variables=self.n_variables, max_nodes=self.max_nodes)

    def optimize_constants(self, X, y):
        '''
        Fit the constants to (X, y) by minimizing the mean squared error with BFGS.

        The optimized values are stored in `self.constants` when the optimizer reaches a finite loss.
        Returns:
            self.traversal when the expression has no constants (kept from the original API),
            otherwise None.
        '''
        if 0 not in self.traversal:
            return self.traversal

        x0 = np.abs(np.random.normal(0, 1, len(self.constants)))
        res = minimize(self.loss, x0, args=(X, y), method='BFGS', options={'disp': False})
        # The result used to be discarded, so optimization had no effect.
        if np.isfinite(res.fun):
            self.constants = res.x

        




In [87]:
import numpy as np
import torch
from torch.utils.data import DataLoader

class Dataset():
    '''
    Synthetic symbolic-regression dataset with one row per randomly generated expression.

    Each row is a dict with keys 'n_obs', 'n_vars', 'Expression', 'X_lower_bound',
    'X_upper_bound', 'X', 'X_norm', 'y', 'y_norm', 'y_lower_bound' and 'y_upper_bound'.
    '''

    @staticmethod
    def normalize(tensor, lb=None, up=None):
        '''
        Min-max scale `tensor` to [0, 1].

        If `lb` is None, the bounds are taken from the data (tensor.min(), tensor.max(); any `up`
        passed is ignored) and (scaled, lb, up) is returned. Otherwise the given bounds are used
        (they broadcast, e.g. per-column arrays) and only the scaled array is returned.
        Where up == lb the span is treated as 1, so a constant target does not become all-NaN.
        '''
        if lb is None:
            lb, up = tensor.min(), tensor.max()
            return Dataset._scale(tensor, lb, up), lb, up
        return Dataset._scale(tensor, lb, up)

    @staticmethod
    def _scale(tensor, lb, up):
        '''Return (tensor - lb) / (up - lb), using a span of 1 wherever up == lb.'''
        span = np.subtract(up, lb)
        span = np.where(span == 0, 1, span)
        return np.divide(np.subtract(tensor, lb), span)

    def __init__(self, n_functions, seed=None):
        '''
        Args:
            n_functions: number of rows (random expressions) to generate.
            seed: seed for the global NumPy RNG; each row's Expression reseeds it with a sub-seed.
        '''
        np.random.seed(seed)
        self.n_functions = n_functions
        seeds = np.random.randint(0, int(1e9), n_functions)
        self.dataset = []

        # Draw order matters for reproducibility: keep the statements sequential.
        for sub_seed in seeds:
            row = {}
            row['n_obs'] = np.random.randint(10, 100)
            # Between 1 and Language.max_variables - 1 variables (randint's upper bound is exclusive).
            row['n_vars'] = np.random.randint(1, Language.max_variables)
            row['Expression'] = Expression(seed=sub_seed, n_variables=row['n_vars'])
            row['X_lower_bound'] = np.random.uniform(0.05, 6, size=row['n_vars'])
            row['X_upper_bound'] = [np.random.uniform(row['X_lower_bound'][i] + 1, 10) for i in range(row['n_vars'])]
            row['X'] = np.concatenate([np.random.uniform(row['X_lower_bound'][i], row['X_upper_bound'][i], (row['n_obs'], 1)) for i in range(row['n_vars'])], axis=1)
            row['X_norm'] = Dataset.normalize(row['X'], lb=row['X_lower_bound'], up=row['X_upper_bound'])
            row['y'] = row['Expression'].evaluate(row['X'])
            row['y_norm'], row['y_lower_bound'], row['y_upper_bound'] = Dataset.normalize(row['y'])
            self.dataset.append(row)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]




class ETIN():
    '''
    Expression Tree Improver Network driver: configures the expression Language and runs training.
    '''
    def __init__(self, functions=None, restrictions=None, use_constants=True, max_variables=4):
        '''
        Args:
            functions: function symbols to enable; defaults to ['+', '*'].
            restrictions: restriction names (keys of Restrictions.permitted); None means none.
            use_constants: whether expressions may contain constants.
            max_variables: forwarded to Language.
        '''
        # None sentinel instead of a mutable default list shared between calls.
        functions = ['+', '*'] if functions is None else list(functions)
        # Restrictions.__init__ iterates over `restrictions`, so None must become an empty list.
        restrictions = [] if restrictions is None else list(restrictions)
        # Language stores its configuration in class attributes; the instance itself is not kept.
        Language(functions, restrictions, use_constants, max_variables)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # TODO: create the Expression Tree Improver Network model, e.g.
        # self.etin_model = ETIN_model(encoded_size=10, emb_size=3, language_size=..., max_seq_size=20)

    def train(self, n_functions=1, epochs=10, seed=None):
        '''
        Run the training loop (no optimization step yet: it builds the data and iterates it).

        Each epoch generates a fresh Dataset of `n_functions` random expressions. With a seed,
        epoch e uses seed + e, so epochs see different but reproducible data.

        Returns:
            the Dataset of the last epoch (None when epochs == 0).
        '''
        dataset = None
        for epoch in range(epochs):
            epoch_seed = None if seed is None else seed + epoch
            dataset = Dataset(n_functions, seed=epoch_seed)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=self.preprocessing)

            # Iterating the DataLoader runs `preprocessing` on every batch.
            for data in dataloader:
                pass  # TODO: forward/backward pass over `data`

        # Returned after the loop: it used to return inside it, so only one epoch ever ran.
        return dataset

    def preprocessing(self, data):
        '''
        Collate function for the DataLoader: turn a list of dataset rows into model inputs.

        For each row:
          - dataset input: [y_norm | X_norm] per observation, zero-padded on the feature axis to
            1 + Language.max_variables columns.
          - prediction input: same layout, with the normalized output of a mutated copy of the
            row's expression (constants re-fitted on the row's data) in place of y_norm.
        Expression and bounds inputs are not built yet and stay empty.

        Returns:
            (input_dataset, input_expression, input_prediction, bounds) as lists; rows have
            different n_obs, so nothing is stacked.
        '''
        input_dataset, input_expression, input_prediction, bounds = [], [], [], []
        for row in data:
            padding = torch.nn.ConstantPad1d((0, Language.max_variables - row['n_vars']), 0)
            X_norm = torch.Tensor(row['X_norm'])

            # Input Dataset: (n_obs x (1 + n_vars)) -> (n_obs x (1 + max_variables))
            input_dataset.append(padding(torch.cat([torch.Tensor(row['y_norm']).unsqueeze(1), X_norm], dim=1)))

            # Input prediction: mutate the expression and optimize its constant values
            new_expression = row['Expression'].mutate()
            new_expression.optimize_constants(row['X'], row['y'])
            y = new_expression.evaluate(row['X'])
            y_norm = Dataset.normalize(y, lb=row['y_lower_bound'], up=row['y_upper_bound'])
            input_prediction.append(padding(torch.cat([torch.Tensor(y_norm).unsqueeze(1), X_norm], dim=1)))

            # TODO: Input Expression
            # max_length = row['Expression'].max_nodes
            # input_expression.append(torch.cat([new_expression.one_hot_encoding, torch.zeros(max_length, Language.size)], dim=0))

            # TODO: bounds
            # bounds.append((row['X_lower_bound'], row['X_upper_bound'], row['y_lower_bound'], row['y_upper_bound']))

        return (input_dataset, input_expression, input_prediction, bounds)



# Example run: configure the language and its restrictions, then train on 100 random expressions.
etin = ETIN(
    functions=['+', '-', '*', '/', '^2', '^3', '^4', 'log', 'exp', 'sqrt', 'sin', 'cos'],
    restrictions=[
        'no_inverse_parent', 'no_double_exp', 'no_sqrt_in_log', 'no_double_division',
        'DistributiveRestriction', 'no_const_after_unary', 'no_sqrt_in_sqrt',
        'no_trigo_in_log', 'no_trigo_offspring',
    ],
    use_constants=True,
    max_variables=4,
)
dataset = etin.train(n_functions=100, seed=1)