In [1]:
cd ..

/home/kpgelvan/SymbolicMathematics


In [2]:
# Forwarded as params['cpu'] by get_params() further down.
CPU_ONLY = True

In [3]:
import os
import numpy as np
import sympy as sp
import torch
from tqdm import tqdm


from src.utils import AttrDict
from src.envs import build_env
from src.model import build_modules

from src.utils import to_cuda
from src.envs.sympy_utils import simplify

In [4]:
# Arity table: maps every operator token that can appear in a prefix expression
# to the number of operands (children) it takes. The tree helpers below use the
# value 2 to decide whether a node has a second (right) child.
OPERATORS = {
        # Elementary functions
        'add': 2,
        'sub': 2,
        'mul': 2,
        'div': 2,
        'pow': 2,
        'rac': 2,
        'inv': 1,
        'pow2': 1,
        'pow3': 1,
        'pow4': 1,
        'pow5': 1,
        'sqrt': 1,
        'exp': 1,
        'ln': 1,
        'abs': 1,
        'sign': 1,
        # Trigonometric Functions
        'sin': 1,
        'cos': 1,
        'tan': 1,
        'cot': 1,
        'sec': 1,
        'csc': 1,
        # Trigonometric Inverses
        'asin': 1,
        'acos': 1,
        'atan': 1,
        'acot': 1,
        'asec': 1,
        'acsc': 1,
        # Hyperbolic Functions
        'sinh': 1,
        'cosh': 1,
        'tanh': 1,
        'coth': 1,
        'sech': 1,
        'csch': 1,
        # Hyperbolic Inverses
        'asinh': 1,
        'acosh': 1,
        'atanh': 1,
        'acoth': 1,
        'asech': 1,
        'acsch': 1,
        # Derivative
        'derivative': 2,
        # custom functions
        'f': 1,
        'g': 2,
        'h': 3,
    }

In [5]:
# Module-level token categories.
# NOTE: get_ancestors() and get_root_paths() below build their own local lists
# with the same names (and, for get_ancestors, different contents), so those
# helpers do not read the values defined here.
symbols = ['I', 'INT+', 'INT-', 'INT', 'FLOAT', '-', '.', '10^']

constants = ['pi', 'E']
variables = ['x', 'y', 'z', 't', 'Y', "Y'", "Y''"]
functions = ['f', 'g', 'h']
elements = [str(i) for i in range(-10, 10)]   # '-10' ... '9'
coefficients = [f'a{i}' for i in range(10)]   # 'a0' ... 'a9'

# Union of all token categories listed above except `symbols`.
no_child_symbols = constants + variables + functions + elements + coefficients

## Build environment / Reload model

In [6]:
def get_params():
    """Build the default hyper-parameter set used to construct an environment and model.

    Returns:
        An ``AttrDict`` holding the environment settings, the positional /
        tree-attention switches and the model settings. ``reload_model`` is
        left empty; callers are expected to fill it in.
    """
    # environment parameters
    env_settings = dict(
        env_name='char_sp',
        int_base=10,
        balanced=False,
        positive=True,
        precision=10,
        n_variables=1,
        n_coefficients=0,
        leaf_probs='0.75,0,0.25,0',
        max_len=512,
        max_int=5,
        max_ops=15,
        max_ops_G=15,
        clean_prefix_expr=True,
        rewrite_functions='',
        tasks='prim_fwd',
        operators='add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',
    )

    # positional / relative / tree encoding switches
    position_settings = dict(
        use_pos_embeddings_E=True,
        use_pos_embeddings_D=True,
        max_relative_pos=0,
        use_neg_dist=True,
        use_encdec_seq_rel_att=False,
        max_path_width=-1,
        max_path_depth=-1,
        use_tree_pos_enc_E=False,
        use_tree_pos_enc_D=False,
        use_tree_rel_att="",
        tree_rel_vocab_size=0,
    )

    # model parameters
    model_settings = dict(
        cpu=CPU_ONLY,
        emb_dim=256,
        n_enc_layers=4,
        n_dec_layers=4,
        n_heads=4,
        dropout=0,
        attention_dropout=0,
        sinusoidal_embeddings=False,
        share_inout_emb=True,
        reload_model='',
    )

    # Merge in the same key order as the three groups above.
    merged = {}
    for group in (env_settings, position_settings, model_settings):
        merged.update(group)

    return AttrDict(merged)


# Optional structural encoder inputs, initialised to None. rel_matrices_batch,
# rel_lens and root_path_batch_q are assigned in the "rel mat" / "root paths"
# cells further down; root_path_batch_a is not reassigned in the visible cells.
rel_matrices_batch = None
rel_lens = None
root_path_batch_q, root_path_batch_a = None, None

In [7]:
def init_env_encdec(model_path, use_pos_embeddings_E=True, use_pos_embeddings_D=True,
                    max_relative_pos=0, use_neg_dist=True, use_encdec_seq_rel_att=False,
                    max_path_width=-1, max_path_depth=-1, 
                    use_tree_pos_enc_E=False, use_tree_pos_enc_D=False,
                    use_tree_rel_att="", tree_rel_vocab_size=0):
    """Build an environment plus encoder/decoder pair for one checkpoint.

    Starts from ``get_params()``, overrides the checkpoint path and the
    positional / tree settings with the given arguments, then builds the
    environment and the modules.

    Args:
        model_path: path of the checkpoint file; must exist (checked with ``assert``).
        Remaining keyword arguments are written verbatim into the parameter
        set under the key of the same name.

    Returns:
        ``(env, x, encoder, decoder)`` where ``x`` is ``env.local_dict['x']``
        and both modules have been switched to eval mode.
    """
    params = get_params()

    assert os.path.isfile(model_path)

    overrides = {
        'reload_model': model_path,
        'use_pos_embeddings_E': use_pos_embeddings_E,
        'use_pos_embeddings_D': use_pos_embeddings_D,
        'max_relative_pos': max_relative_pos,
        'use_neg_dist': use_neg_dist,
        'use_encdec_seq_rel_att': use_encdec_seq_rel_att,
        'max_path_width': max_path_width,
        'max_path_depth': max_path_depth,
        'use_tree_pos_enc_E': use_tree_pos_enc_E,
        'use_tree_pos_enc_D': use_tree_pos_enc_D,
        'use_tree_rel_att': use_tree_rel_att,
        'tree_rel_vocab_size': tree_rel_vocab_size,
    }
    for key, value in overrides.items():
        params[key] = value

    env = build_env(params)
    x = env.local_dict['x']

    modules = build_modules(env, params)
    encoder, decoder = modules['encoder'], modules['decoder']
    for module in (encoder, decoder):
        module.eval()

    return env, x, encoder, decoder

In [8]:
# pos emb: checkpoint loaded with position embeddings enabled in both the
# encoder and the decoder (use_pos_embeddings_E / _D = True).
env, x, encoder, decoder = init_env_encdec(model_path='dumped/pos_emb_0202/239415/checkpoint.pth',
                                           use_pos_embeddings_E=True, use_pos_embeddings_D=True)

In [9]:
# seq rel att: position embeddings disabled on both sides, max_relative_pos=250
# and use_neg_dist=True passed instead.
env_2, x_2, encoder_2, decoder_2 = init_env_encdec(model_path='dumped/seq_rel_att_0202/238443/checkpoint.pth',
                                                   use_pos_embeddings_E=False, use_pos_embeddings_D=False,
                                                   max_relative_pos=250, use_neg_dist=True)

In [10]:
from tqdm import tqdm
import queue

def get_ancestors(exp_list, exp_len):
    """Compute the ancestor chain and depth of every token of a prefix expression.

    Args:
        exp_list: sequence of prefix-notation tokens.
        exp_len: number of leading tokens of ``exp_list`` to process.

    Returns:
        ``(ancestors, levels)``:
          * ``ancestors[i]`` is the list ``[i, parent(i), ..., 0]``;
          * ``levels[i]`` is the depth of token ``i`` (the first token has depth 0).
        On malformed input -- tokens left over after the tree is complete, or a
        token that belongs to no known category -- ``({}, {})`` is returned so
        that callers which unpack two values keep working.
    """
    # NOTE (translated from the original comment): "incorrect, but this is how
    # it is for tree rel att". Presumably this refers to Y, Y', Y'' being listed
    # here among the child-bearing tokens, unlike in get_root_paths -- TODO confirm.
    symbols = ['I', 'INT+', 'INT-', 'INT', 'FLOAT', '-', '.', '10^', 'Y', "Y'", "Y''"]

    constants = ['pi', 'E']
    variables = ['x', 'y', 'z', 't']
    functions = ['f', 'g', 'h']
    elements = [str(i) for i in range(-10, 10)]
    coefficients = [f'a{i}' for i in range(10)]

    no_child_symbols = constants + variables + functions + elements + coefficients

    # Stack of binary operators still waiting for their second child.
    # The -1 sentinel is what the very last leaf pops; it is never stored as a node.
    pending = [-1]

    ancestors = {0: []}
    node2parent = {}
    levels = {0: -1}

    parent = 0
    for i in range(exp_len):
        op_now = exp_list[i]

        node2parent[i] = parent
        try:
            levels[i] = levels[parent] + 1
        except KeyError:
            # parent == -1: the tree was already complete, so this token is surplus.
            print('you are in except')
            return {}, {}

        if op_now in OPERATORS or op_now in symbols:   # <=> node has children
            if op_now in OPERATORS and OPERATORS[op_now] == 2:    # <=> node has 2 children
                pending.append(i)
            parent = i
        elif op_now in no_child_symbols:
            if op_now.isdigit() and i + 1 < exp_len and exp_list[i + 1].isdigit():   # e.g. 18
                # Consecutive digits of one number are chained parent -> child.
                parent = i
            else:
                parent = pending.pop()
        else:
            # Unknown token: report it and fail with the same shape as the other failure path.
            print(op_now)
            return {}, {}
        ancestors[i] = [i] + ancestors[node2parent[i]]

    return ancestors, levels

def get_path(i, j, ancestors=None, levels=None):
    """Encode the tree path from node ``i`` to node ``j`` as a string.

    Args:
        i, j: node indices.
        ancestors: mapping node -> ``[node, parent, ..., root]``. When omitted,
            the module-level name ``ancestors`` is used (legacy behaviour).
        levels: mapping node -> depth. When omitted, the module-level name
            ``levels`` is used (legacy behaviour).

    Returns:
        ``"<self>"`` when ``i == j``; otherwise ``str(up + 0.001 * down)`` where
        ``up`` / ``down`` are the number of steps from ``i`` up to, and from
        there down to ``j`` via, their deepest common ancestor. Returns ``None``
        if the two nodes share no ancestor.
    """
    if i == j:
        return "<self>"

    # Fall back to the notebook globals only when the caller did not pass the maps.
    if ancestors is None:
        ancestors = globals()['ancestors']
    if levels is None:
        levels = globals()['levels']

    anc_i = set(ancestors[i])

    # Only ancestors of j at depth <= depth(i) can also be ancestors of i;
    # they are visited deepest-first, so the first hit is the deepest common one.
    for node in ancestors[j][-(levels[i] + 1):]:
        if node in anc_i:
            up_n = levels[i] - levels[node]
            down_n = levels[j] - levels[node]
            return str(up_n + 0.001 * down_n)


def get_ud_masks(ancestors, levels, exp_len):
    """Build the pairwise path-relation rows for an expression.

    Args:
        ancestors: mapping node -> ancestor chain, as produced by ``get_ancestors``.
        levels: mapping node -> depth, as produced by ``get_ancestors``.
        exp_len: number of nodes.

    Returns:
        A list of ``exp_len`` strings; row ``i`` is the space-joined
        ``get_path(i, j)`` for ``j`` in ``range(exp_len)``.
    """
    return [
        " ".join(get_path(i, j, ancestors, levels) for j in range(exp_len))
        for i in range(exp_len)
    ]

In [11]:
# Passed as tree_rel_vocab_size both to build_relmat_dict() and to the
# tree-rel-att model loader below.
TR_VOCAB_SIZE = 2000


def build_relmat_dict(rel_vocab_path, tree_rel_vocab_size):
    """Load the relation vocabulary and return its word -> id mapping.

    Args:
        rel_vocab_path: text file with one relation word per line.
        tree_rel_vocab_size: the first ``tree_rel_vocab_size + 1`` lines are kept.

    Returns:
        dict mapping each kept word to ``line_index + 2``, followed by ``'unk'``,
        plus the reserved entries ``'EOS' -> 0`` and ``'PAD' -> 1``.
    """
    with open(rel_vocab_path, 'r') as vocab_file:
        vocab_lines = vocab_file.read().splitlines()

    kept_words = vocab_lines[:tree_rel_vocab_size + 1] + ['unk']

    # ids 0 and 1 are reserved, real words start at 2
    rel_id2word = dict(enumerate(kept_words, start=2))
    rel_id2word[0] = 'EOS'
    rel_id2word[1] = 'PAD'

    rel_word2id = {}
    for idx, word in rel_id2word.items():
        rel_word2id[word] = idx

    return rel_word2id

def build_relmat_batch(rel_matrix, rel_word2id):
    """Turn per-sample relation-word matrices into one padded id tensor.

    Args:
        rel_matrix: list of samples; each sample is a list of rows, each row a
            list of relation words.
        rel_word2id: word -> id mapping; words missing from it map to
            ``rel_word2id['unk']``.

    Returns:
        ``(batch, lengths)``: ``lengths[b]`` is the number of rows of sample
        ``b`` plus 2; ``batch`` is a LongTensor of shape
        ``(max_len, max_len, batch_size)`` in which each sample is framed by a
        border of 0 (EOS) and everything else is filled with 1 (PAD).
    """
    def word_id(word):
        # 'unk' is only looked up when it is actually needed
        if word in rel_word2id:
            return rel_word2id[word]
        return rel_word2id['unk']

    sample_tensors = [
        torch.stack([torch.LongTensor([word_id(w) for w in row]) for row in sample])
        for sample in rel_matrix
    ]

    lengths = torch.LongTensor([len(t) + 2 for t in sample_tensors])
    side = lengths.max().item()

    # start from all-PAD (1), then blank the first row/column to EOS (0)
    batch = torch.ones(side, side, lengths.size(0), dtype=torch.long)
    batch[:, 0, :] = 0
    batch[0, :, :] = 0

    for b, ids in enumerate(sample_tensors):
        last = lengths[b].item() - 1
        batch[1:last, 1:last, b].copy_(ids)
        batch[last, :, b] = 0   # closing EOS row
        batch[:, last, b] = 0   # closing EOS column

    # (SRC_LEN, SRC_LEN, BS)
    return batch, lengths

In [12]:
rel_word2id = build_relmat_dict(rel_vocab_path='data/rel_vocab.txt', tree_rel_vocab_size=TR_VOCAB_SIZE)

In [13]:
# tree rel att 2k (only in Encoder): encoder position embeddings are disabled,
# the decoder keeps them; use_tree_rel_att='mult2' with a relation vocabulary of
# TR_VOCAB_SIZE entries.
env_3, x_3, encoder_3, decoder_3 = init_env_encdec(model_path='dumped/tree_rel_att_0303/236678/checkpoint.pth',
                                                   use_pos_embeddings_E=False, use_pos_embeddings_D=True,
                                                   use_tree_rel_att='mult2', tree_rel_vocab_size=TR_VOCAB_SIZE
                                                  )

In [14]:
from tqdm import tqdm
import queue
from collections import OrderedDict

def get_root_paths(exp_list, exp_len):
    """Compute, for every token of a prefix expression, its path from the root.

    Each path is a string of step codes: '1' = left child, '2' = right child,
    '0' = "down" (child of a node that is not a 2-operand operator).

    Args:
        exp_list: sequence of prefix-notation tokens.
        exp_len: number of leading tokens of ``exp_list`` to process.

    Returns:
        An ``OrderedDict`` keyed ``-1 .. len(exp_list)`` mapping each index to
        its path string (the two extra keys and the root keep the empty
        string), or ``-1`` when the expression is malformed (tokens left after
        the tree is complete, or an unknown token).
    """
    
    # Tokens that take children although they are not in OPERATORS.
    symbols = ['I', 'INT+', 'INT-', 'INT', 'FLOAT', '-', '.', '10^']

    constants = ['pi', 'E']
    variables = ['x', 'y', 'z', 't', 'Y', "Y'", "Y''"]
    # 'f', 'g', 'h' are also keys of OPERATORS, and that test comes first below.
    functions = ['f', 'g', 'h']
    elements = [str(i) for i in range(-10, 10)]
    coefficients = [f'a{i}' for i in range(10)]

    no_child_symbols = constants + variables + functions + elements + coefficients
    
    # LIFO stack of 2-operand operators still waiting for their right child.
    q = queue.LifoQueue()
    q.put(-1)                            # so last element gets this parent but doesn't save it

    root_paths = OrderedDict([(i, '') for i in range(-1, len(exp_list) + 1)])   # init with empty lines to be able to add
    # Flags describing how the *next* token hangs off `parent`; both False means "left child".
    is_right, is_down = False, False

    parent = 0
    for i in range(exp_len):
        op_now = exp_list[i]
        
        if i != 0:
            if parent == -1:
                # The sentinel was popped already: the tree is complete but tokens remain.
                print('ohoh')
                return -1
            # Path of this node = path of its parent + one step code.
            root_paths[i] += root_paths[parent]
            if is_right:
                last_step = '2'     # right
            elif is_down:
                last_step = '0'     # down
            else:
                last_step = '1'     # left

            root_paths[i] += last_step
            is_right, is_down = False, False
        

        if op_now in OPERATORS or op_now in symbols:   # <=> node has children
            if op_now in OPERATORS and OPERATORS[op_now] == 2:    # <=> node has 2 children
                q.put(i)
            else:
                is_down = True
            parent = i
        elif op_now in no_child_symbols:
            if op_now.isdigit() and i + 1 < exp_len and exp_list[i + 1].isdigit():   # e.g. 18
                # The next digit of the same number hangs off this digit (no flag set => '1').
                parent = i
            else:
                # Leaf closes a subtree: the next token is the right child of the
                # most recent operator still waiting for one.
                parent = q.get()
                is_right = True
        else:
            # Unknown token: print it and signal failure.
            print(op_now)
            #raise(NotFound)
            return -1

    return root_paths

In [15]:
from src.misc import generate_positions

def build_root_paths(rps_q, max_path_width=4, max_path_depth=16):
    """Stack per-sample tree-position encodings into one zero-padded batch.

    Args:
        rps_q: list of samples; each sample is a list of root paths (one per
            token), each root path a list of integer step codes.
        max_path_width, max_path_depth: forwarded to ``generate_positions``.

    Returns:
        Float tensor of shape ``(batch_size, max_len, width)`` where
        ``max_len`` is the longest sample in ``rps_q`` and ``width`` is the
        second dimension of what ``generate_positions`` returns. Shorter
        samples are zero-padded at the end.
    """
    tree_positions_list_q = [generate_positions(root_paths, max_path_width, max_path_depth)
                             for root_paths in rps_q]
    bs = len(tree_positions_list_q)
    max_wd = tree_positions_list_q[0].size(1)

    # Size the batch by the longest sample, not by the first one, so that a
    # later, longer sample still fits.
    max_len = max(len(root_paths) for root_paths in rps_q)

    tree_positions_batch_q = torch.zeros(bs, max_len, max_wd, dtype=torch.float)
    for i, positions in enumerate(tree_positions_list_q):
        tree_positions_batch_q[i, :positions.size(0), :].copy_(positions)

    return tree_positions_batch_q

In [16]:
# tree pos enc ENC_ONLY: tree positional encoding in the encoder only; the
# decoder keeps its position embeddings.
# NOTE(review): this passes max_path_depth=4 / max_path_width=16, whereas
# build_root_paths() defaults to max_path_width=4 / max_path_depth=16. The two
# pairings look swapped relative to each other -- confirm which one the
# checkpoint was trained with.
env_4, x_4, encoder_4, decoder_4 = init_env_encdec(model_path='dumped/tree_pos_enc_1204/261725/checkpoint.pth',
                                                   use_pos_embeddings_E=False, use_pos_embeddings_D=True,
                                                   use_tree_pos_enc_E=True, use_tree_pos_enc_D=False,
                                                   max_path_depth=4, max_path_width=16
                                                  )

In [17]:
# tree pos enc
#env_5, x_5, encoder_5, decoder_5 = init_env_encdec(model_path='dumped/tree_pos_enc_2004/268762/checkpoint.pth',
#                                                   use_pos_embeddings_E=False, use_pos_embeddings_D=False,
#                                                   use_tree_pos_enc_E=True, use_tree_pos_enc_D=True,
#                                                   max_path_depth=4, max_path_width=16
#                                                  )

## Start from a function F, compute its derivative f = F', and try to recover F from f

In [161]:
# here you can modify the integral function the model has to predict, F
############################F_infix = 'x + sin(x)*x/20 - 8*x**6'
############################F_infix = 'x * tan(exp(x)/x)'
############################F_infix  = 'sin(x)*cos(x)'

######## Long numbers
############################F_infix = '12345*x**2 + 6789*log(x)'
############################F_infix = '368123/x**4 - 345125*x**(123)'
F_infix = 'sin(cos(x)) - cos(tan(x))'


#F_infix = 'x * cos(x**2) * tan(x)'

#F_infix = 'cos(x**2 * exp(x * cos(x)))'

#F_infix = 'ln(cos(x + exp(x)) * sin(x**2 + 2) * exp(x) / x)'

#F_infix = 'x**2*log(x)/2 + x**2/4' # '-x - log(exp(x*log(x)))' 
#F_infix = 'x*log(x)-x' #'log(x)'

################ IBP
#F_infix = '-x*cos(2*x)/2 + sin(2*x)/4'

#F_infix = '-x^2/2 + log(cos(x)) + x tan(x)'

### EASY
#F_infix = 'log(x)*x**2'


In [162]:
# F (integral, that the model will try to predict)
F = sp.S(F_infix, locals=env.local_dict)
F

sin(cos(x)) - cos(tan(x))

In [163]:
# f (F', that the model will take as input)
f = F.diff(x)
f

(tan(x)**2 + 1)*sin(tan(x)) - sin(x)*cos(cos(x))

In [164]:
# hard, messy example
#a = sp.S("asin(x)*ln(x)", locals=env.local_dict)
#F = sp.S('-2*sqrt(1-x**2) + log(x)*(sqrt(1-x**2)-1) + log(sqrt(1-x**2)+1) + x*(log(x) - 1)*asin(x)',

#???
#a = sp.S('x*tan(x)**2', locals=env.local_dict)
#F = sp.S('-x**2/2 + log(cos(x)) + x*tan(x)', locals=env.local_dict) #


# example from the figure
#a = sp.S('x*(sin(x) + x**2)', locals=env.local_dict)
#F = a.integrate(x)

#f = a

#print(F)
#print(f)



### Compute prefix representations

In [165]:
F_prefix = env.sympy_to_prefix(F)
f_prefix = env.sympy_to_prefix(f)
print(f"F prefix: {F_prefix}")
print(f"f prefix: {f_prefix}")

F prefix: ['add', 'mul', 'INT-', '1', 'cos', 'tan', 'x', 'sin', 'cos', 'x']
f prefix: ['add', 'mul', 'add', 'INT+', '1', 'pow', 'tan', 'x', 'INT+', '2', 'sin', 'tan', 'x', 'mul', 'INT-', '1', 'mul', 'cos', 'cos', 'x', 'sin', 'x']


### Manual prefix override (disabled scratch cell)

In [166]:
'''f_prefix = ['mul', 'div', 'INT+', '5', 'INT+', '2', 'pow', 'x', 'div', 'INT+', '3', 'INT+', '2']
F_prefix = ['pow', 'x', 'div', 'INT+', '5', 'INT+', '2']
print(f"f prefix: {f_prefix}")
print(f"F prefix: {F_prefix}")
print()

print('source', sp.S(env.infix_to_sympy(env.prefix_to_infix(f_prefix)), locals=env.local_dict))
print('target', sp.S(env.infix_to_sympy(env.prefix_to_infix(F_prefix)), locals=env.local_dict))'''

'f_prefix = [\'mul\', \'div\', \'INT+\', \'5\', \'INT+\', \'2\', \'pow\', \'x\', \'div\', \'INT+\', \'3\', \'INT+\', \'2\']\nF_prefix = [\'pow\', \'x\', \'div\', \'INT+\', \'5\', \'INT+\', \'2\']\nprint(f"f prefix: {f_prefix}")\nprint(f"F prefix: {F_prefix}")\nprint()\n\nprint(\'source\', sp.S(env.infix_to_sympy(env.prefix_to_infix(f_prefix)), locals=env.local_dict))\nprint(\'target\', sp.S(env.infix_to_sympy(env.prefix_to_infix(F_prefix)), locals=env.local_dict))'

### Encode input

In [167]:
# Model input: the fixed header tokens 'sub derivative f x x' followed by the
# prefix form of f, passed through env.clean_prefix.
x1_prefix = env.clean_prefix(['sub', 'derivative', 'f', 'x', 'x'] + f_prefix)

# Token ids framed by EOS on both sides, shaped as a single column (len, 1).
x1_ids = [env.eos_index, *(env.word2id[w] for w in x1_prefix), env.eos_index]
x1 = torch.LongTensor(x1_ids).view(-1, 1)
len1 = torch.LongTensor([len(x1)])   # length including both EOS tokens

#x2_prefix = env.clean_prefix(F_prefix)
#x2 = torch.LongTensor(
#    [env.eos_index] +
#    [env.word2id[w] for w in x2_prefix] +
#    [env.eos_index]
#).view(-1, 1)
#len2 = torch.LongTensor([len(x2)])

##### rel mat

In [168]:
# Keep the names `ancestors` / `levels`: get_path() looks them up at module
# scope when they are not handed to it explicitly.
n_tokens = len(x1_prefix)
ancestors, levels = get_ancestors(x1_prefix, n_tokens)

# One space-separated row of path relations per token, split into words and
# wrapped in a batch of a single sample.
path_rows = get_ud_masks(ancestors, levels, n_tokens)
rel_matrix = [[row.split() for row in path_rows]]

rel_matrices_batch, rel_lens = build_relmat_batch(rel_matrix=rel_matrix, rel_word2id=rel_word2id)
print(rel_matrices_batch.size())
print(rel_lens)

torch.Size([26, 26, 1])
tensor([26])


###### root paths

In [169]:
# Root path string of every position (including the two extra keys -1 and len).
paths_q = get_root_paths(x1_prefix, len(x1_prefix))
print(paths_q)

# Empty paths are shown as 'root' ...
rps_q = [path if path != '' else 'root' for path in paths_q.values()]
print(rps_q)

# ... and then every path becomes a list of integer step codes ('root' -> []).
rps_q = [[] if rp == "root" else [int(step) for step in rp] for rp in rps_q]
print(rps_q)

#paths_a = get_root_paths(x2_prefix, len(x2_prefix))
#rps_a = [paths_a[path] if paths_a[path] != '' else 'root' for i, path in enumerate(paths_a)]
#rps_a = [[int(rp_elem) for rp_elem in list(rp)] if rp != "root" else [] for rp in rps_a]

# Batch of one sample.
root_path_batch_q = build_root_paths([rps_q])

OrderedDict([(-1, ''), (0, ''), (1, '1'), (2, '2'), (3, '21'), (4, '211'), (5, '2111'), (6, '21110'), (7, '2112'), (8, '21121'), (9, '211210'), (10, '21122'), (11, '211220'), (12, '212'), (13, '2120'), (14, '21200'), (15, '22'), (16, '221'), (17, '2210'), (18, '222'), (19, '2221'), (20, '22210'), (21, '222100'), (22, '2222'), (23, '22220'), (24, '')])
['root', 'root', '1', '2', '21', '211', '2111', '21110', '2112', '21121', '211210', '21122', '211220', '212', '2120', '21200', '22', '221', '2210', '222', '2221', '22210', '222100', '2222', '22220', 'root']
[[], [], [1], [2], [2, 1], [2, 1, 1], [2, 1, 1, 1], [2, 1, 1, 1, 0], [2, 1, 1, 2], [2, 1, 1, 2, 1], [2, 1, 1, 2, 1, 0], [2, 1, 1, 2, 2], [2, 1, 1, 2, 2, 0], [2, 1, 2], [2, 1, 2, 0], [2, 1, 2, 0, 0], [2, 2], [2, 2, 1], [2, 2, 1, 0], [2, 2, 2], [2, 2, 2, 1], [2, 2, 2, 1, 0], [2, 2, 2, 1, 0, 0], [2, 2, 2, 2], [2, 2, 2, 2, 0], []]


In [170]:
# Encode the same input x1 with each of the four encoders; gradients are not
# needed. Every output is transposed on its first two axes before decoding.
with torch.no_grad():
    # pos emb
    encoded = encoder('fwd', x=x1, lengths=len1, causal=False).transpose(0, 1)
    
    # seq rel att
    encoded_2 = encoder_2('fwd', x=x1, lengths=len1, causal=False).transpose(0, 1)
    
    # tree rel att: additionally receives the relation-id batch and its lengths
    encoded_3 = encoder_3('fwd', x=x1, lengths=len1, causal=False,
                          rel_matrix=rel_matrices_batch, rel_lens=rel_lens).transpose(0, 1)
    
    # tree pos enc: additionally receives the batched root-path encodings
    encoded_4 = encoder_4('fwd', x=x1, lengths=len1, causal=False,
                          root_paths=root_path_batch_q).transpose(0, 1)
    
    #encoded_5 = encoder_5('fwd', x=x1, lengths=len1, causal=False,
    #                      root_paths=root_path_batch_q).transpose(0, 1)

### Decode with beam search

In [172]:
beam_size = 10

print(f"Input function f: {f}")
print(f"Reference function F: {F}")

# Banner printed before each model's hypotheses, in the order of the lists zipped below.
method_banners = ['POS EMB:', 'SEQ REL ATT:', 'TREE REL ATT:', 'TREE POS ENC:', 'TREE POS ENC ENC+DEC']

for i, (encoded_now, decoder_now) in enumerate(zip([encoded, encoded_2, encoded_3, encoded_4], #encoded_5], 
                                                   [decoder, decoder_2, decoder_3, decoder_4])): #decoder_5])):
    if i < len(method_banners):
        print(method_banners[i])

    with torch.no_grad():
        _, _, beam = decoder_now.generate_beam(encoded_now, len1, beam_size=beam_size, length_penalty=1.0, early_stopping=1, max_len=200)
        assert len(beam) == 1
    hypotheses = beam[0].hyp
    assert len(hypotheses) == beam_size

    # Best-scoring hypothesis first; stop at the first one whose derivative matches f.
    for score, sent in sorted(hypotheses, key=lambda scored: scored[0], reverse=True):

        # parse decoded hypothesis
        ids = sent[1:].tolist()                  # decoded token IDs
        tok = [env.id2word[wid] for wid in ids]  # convert to prefix

        # `candidate` (not `hyp`) so this loop does not overwrite the module-level
        # `hyp` set of hyperbolic function names defined further down.
        try:
            candidate = env.prefix_to_infix(tok)       # convert to infix
            candidate = env.infix_to_sympy(candidate)  # convert to SymPy

            # check whether we recover f if we differentiate the hypothesis
            # note that sometimes, SymPy fails to show that hyp' - f == 0, and the result is considered as invalid, although it may be correct
            res = "OK" if simplify(candidate.diff(x) - f, seconds=1) == 0 else "NO"
        except KeyboardInterrupt:
            # A kernel interrupt must stop the cell, not be reported as a bad hypothesis.
            raise
        except BaseException:
            # Kept broad on purpose: SOURCE does not show which exception types the
            # env / simplify helpers raise, and the original treated all of them as invalid.
            res = "INVALID PREFIX EXPRESSION"
            candidate = tok

        print("%.5f  %s  %s" % (score, res, candidate))
        if res == "OK":
            break

    print()

print('end')

Input function f: (tan(x)**2 + 1)*sin(tan(x)) - sin(x)*cos(cos(x))
Reference function F: sin(cos(x)) - cos(tan(x))
POS EMB:
tree pos enc is being used False
-0.23342  NO  sin(x)*cos(x)/2 + sin(x) - sin(x)/cos(x) + cos(cos(x))/2
-0.23499  NO  sin(x)*cos(x)/2 + sin(x) - sin(x)/cos(x) - sin(cos(x))
-0.23689  NO  sin(x)*cos(x)/2 - sin(x)/cos(x) + sin(cos(x)) + 1/cos(x)
-0.23814  NO  sin(x) - sin(x)/cos(x) - sin(cos(x))/2
-0.24111  NO  sin(x)*cos(x)/2 - sin(x)/cos(x) + sin(cos(x)) + cos(cos(x))/2
-0.24996  NO  sin(x)*cos(x)/2 + sin(x) - sin(x)/cos(x) + 1/cos(x)
-0.25567  NO  sin(x)*cos(x)/2 - sin(x)/cos(x) + sin(sin(x)) + cos(cos(x))/2
-0.26286  NO  sin(x)*cos(x)/2 + sin(x) - cos(x)*cos(cos(x))/2 + 1/cos(x)
-0.28884  NO  sin(x) - sin(cos(x)) + cos(cos(x))/2
-0.31883  NO  sin(x) - sin(cos(x)) + 1/cos(x)

SEQ REL ATT:
tree pos enc is being used False
-0.30065  NO  -log(cos(x) + 1)/2 + sin(x)*cos(x)/2 - sin(cos(x)) - cos(tan(x))
-0.30395  NO  -log(cos(x) + 1)/2 + sin(x)*cos(x)/2 - sin(cos(x)) 

In [180]:
# Evaluation logs, one per method, in the order: pos emb, seq rel att,
# tree rel att, tree pos enc.
evals = ['dumped/eval_pos_emb_nodups/261091/eval.prim_fwd.0',
         'dumped/eval_seq_rel_att_nodups/261092/eval.prim_fwd.0',
         'dumped/eval_tree_rel_att_nodups/261093/eval.prim_fwd.0',
         'dumped/eval_tree_pos_enc_nodups/271822/eval.prim_fwd.0']

# solutions[i] = set of equation numbers that method i solved at least once.
solutions = [set() for _ in evals]

for i, eval_path in enumerate(evals):
    # The handle is deliberately NOT called `f`: that name holds the SymPy
    # derivative used by the decoding cell above.
    with open(eval_path, 'r') as eval_file:
        for line in eval_file:
            if 'Equation' in line:
                # Second field is the equation number; the third is parsed as
                # "(<times solved>/<...>)".
                fields = line.split()
                eq_num, line_end = int(fields[1]), fields[2]
                times_solved = line_end.strip('()').split('/')[0]

                if int(times_solved) > 0:
                    solutions[i].add(eq_num)

In [181]:
for sol in solutions:
    print(len(sol))

6245
6207
6180
6245


In [182]:
# only_method[k]: equations solved by method k and by none of the other methods.
only_method = [
    solved.difference(*(other for k, other in enumerate(solutions) if k != idx))
    for idx, solved in enumerate(solutions)
]

In [149]:
# Number of equations solved exclusively by each method, printed one per line
# and then once more as a list.
only_method_lens = [len(unique_ids) for unique_ids in only_method]

for count in only_method_lens:
    print(count)

print(only_method_lens)

77
60
62
75
[77, 60, 62, 75]


In [173]:
# Per-method token totals of question (q) and answer (a) over the equations
# that only this method solved.
lens_of_q = [0] * 4
lens_of_a = [0] * 4

for METHOD_NUM in range(4):
    selected = only_method[METHOD_NUM]

    with open('data/prim_fwd.no_dups_valtest', 'r') as test_data:
        for i, line in enumerate(test_data):
            if i not in selected:
                continue

            # "<...>|<question>\t<answer>"; the first two question tokens are dropped.
            qa = line.split('|')[1].split('\t')
            q = qa[0].split()[2:]
            a = qa[1].split()

            lens_of_q[METHOD_NUM] += len(q)
            lens_of_a[METHOD_NUM] += len(a)

            # Round-trip both sides through SymPy; failures are ignored and the
            # parsed values are not used afterwards.
            try:
                q = env.prefix_to_infix(q)
                q = env.infix_to_sympy(q)

                a = env.prefix_to_infix(a)
                a = env.infix_to_sympy(a)
            except:
                continue

                #print()
            

  return ConstantNode(func(*[x.value for x in args]))


In [176]:
# Average question / answer token count per method, over the equations only that method solved.
for i, (q_total, a_total) in enumerate(zip(lens_of_q, lens_of_a)):
    print(i, ':', round(q_total / only_method_lens[i], 2), round(a_total / only_method_lens[i], 2))

0 : 27.36 95.17
1 : 28.13 94.03
2 : 28.55 82.98
3 : 27.28 88.35


In [217]:
# Token families used by the counting cells below to classify question tokens.
# Trigonometric Functions
trig = {'sin',
        'cos',
        'tan',
        'cot',
        'sec',
        'csc'
       }
# Inverse trigonometric functions
trig_inv = {'asin',
        'acos',
        'atan',
        'acot',
        'asec',
        'acsc'
}
# Hyperbolic functions
hyp = {  'sinh',
        'cosh',
        'tanh',
        'coth',
        'sech',
        'csch'
      }
# Inverse hyperbolic functions
hyp_inv={
        'asinh',
        'acosh',
        'atanh',
        'acoth',
        'asech',
        'acsch',
}

In [151]:
# Per-method counters: number of exclusively-solved equations whose question
# tokens contain at least one function of the given family.
method_trig = [0, 0, 0, 0]
method_trig_inv = [0, 0, 0, 0]
method_hyp = [0, 0, 0, 0]
method_hyp_inv = [0, 0, 0, 0]
# 0/1 flags for the equation currently being scanned; reset after each equation.
found_trig, found_trig_inv, found_hyp, found_hyp_inv = 0, 0, 0, 0

for METHOD_NUM in [0, 1, 2, 3]:

    # The data file is re-read once per method; line index i is matched
    # against the equation numbers in only_method[METHOD_NUM].
    with open('data/prim_fwd.no_dups_valtest', 'r') as test_data:
        for i, line in enumerate(test_data):
            if i in only_method[METHOD_NUM]:
                # "<...>|<question>\t<answer>"; the first two question tokens are dropped.
                qa = line.split('|')[1].split('\t')
                q = qa[0].split()[2:]
                a = qa[1].split()

                # A token raises the flag of the first family it belongs to.
                for el in q:
                    if el in trig:
                        found_trig = 1
                    elif el in trig_inv:
                        found_trig_inv = 1
                    elif el in hyp:
                        found_hyp = 1
                    elif el in hyp_inv:
                        found_hyp_inv = 1
                
                # Each equation contributes at most 1 to each family counter.
                method_trig[METHOD_NUM] += found_trig
                method_trig_inv[METHOD_NUM] += found_trig_inv
                method_hyp[METHOD_NUM] += found_hyp
                method_hyp_inv[METHOD_NUM] += found_hyp_inv
                found_trig, found_trig_inv, found_hyp, found_hyp_inv = 0, 0, 0, 0
                
                # Round-trip through SymPy; the parsed values are not used and
                # any failure is silently skipped.
                try:
                    q = env.prefix_to_infix(q)       # convert to infix
                    q = env.infix_to_sympy(q)
                    #print('source:\t', q)

                    a = env.prefix_to_infix(a)
                    a = env.infix_to_sympy(a)
                    #print('answer:\t', a)
                except:
                    continue

                #print()
            

In [157]:
# Share of exclusively-solved equations whose question contains a trigonometric function.
for i, count in enumerate(method_trig):
    print(i, ':', round(count / only_method_lens[i], 2))

0 : 0.36
1 : 0.35
2 : 0.27
3 : 0.45


In [158]:
# Share of exclusively-solved equations whose question contains an inverse trigonometric function.
for i, count in enumerate(method_trig_inv):
    print(i, ':', round(count / only_method_lens[i], 2))

0 : 0.17
1 : 0.08
2 : 0.13
3 : 0.11


In [159]:
# Share of exclusively-solved equations whose question contains a hyperbolic function.
for i, count in enumerate(method_hyp):
    print(i, ':', round(count / only_method_lens[i], 2))

0 : 0.16
1 : 0.13
2 : 0.13
3 : 0.11


In [160]:
# Share of exclusively-solved equations whose question contains an inverse hyperbolic function.
for i, count in enumerate(method_hyp_inv):
    print(i, ':', round(count / only_method_lens[i], 2))

0 : 0.18
1 : 0.08
2 : 0.15
3 : 0.17


In [195]:
# failed[k]: equations solved by every other method but not by method k.
failed = [
    set.intersection(*(other for k, other in enumerate(solutions) if k != idx)).difference(missed_by)
    for idx, missed_by in enumerate(solutions)
]

In [198]:
# Number of equations each method alone failed on, printed one per line and
# then once more as a list.
failed_lens = [len(missed_ids) for missed_ids in failed]

for count in failed_lens:
    print(count)

print(failed_lens)

95
113
143
116
[95, 113, 143, 116]


In [206]:
# Per-method token totals of question (q) and answer (a) over the equations
# that only this method failed on.
lens_of_q = [0] * 4
lens_of_a = [0] * 4

for METHOD_NUM in range(4):
    selected = failed[METHOD_NUM]

    with open('data/prim_fwd.no_dups_valtest', 'r') as test_data:
        for i, line in enumerate(test_data):
            if i not in selected:
                continue

            # "<...>|<question>\t<answer>"; the first two question tokens are dropped.
            qa = line.split('|')[1].split('\t')
            q = qa[0].split()[2:]
            a = qa[1].split()

            lens_of_q[METHOD_NUM] += len(q)
            lens_of_a[METHOD_NUM] += len(a)

            # Round-trip both sides through SymPy; failures are ignored and the
            # parsed values are not used afterwards.
            try:
                q = env.prefix_to_infix(q)
                q = env.infix_to_sympy(q)

                a = env.prefix_to_infix(a)
                a = env.infix_to_sympy(a)
            except:
                continue

                #print()
            

In [207]:
# Average question / answer token count per method, over the equations only that method failed on.
for i, (q_total, a_total) in enumerate(zip(lens_of_q, lens_of_a)):
    print(i, ':', round(q_total / failed_lens[i], 2), round(a_total / failed_lens[i], 2))

0 : 27.49 78.94
1 : 26.21 77.09
2 : 26.73 76.38
3 : 26.09 80.27


In [219]:
# Per-method counters: number of equations failed only by this method whose
# question tokens contain at least one function of the given family.
method_trig = [0, 0, 0, 0]
method_trig_inv = [0, 0, 0, 0]
method_hyp = [0, 0, 0, 0]
method_hyp_inv = [0, 0, 0, 0]
# 0/1 flags for the equation currently being scanned; reset after each equation.
found_trig, found_trig_inv, found_hyp, found_hyp_inv = 0, 0, 0, 0

for METHOD_NUM in [0, 1, 2, 3]:

    # The data file is re-read once per method; line index i is matched
    # against the equation numbers in failed[METHOD_NUM].
    with open('data/prim_fwd.no_dups_valtest', 'r') as test_data:
        for i, line in enumerate(test_data):
            if i in failed[METHOD_NUM]:
                # "<...>|<question>\t<answer>"; the first two question tokens are dropped.
                qa = line.split('|')[1].split('\t')
                q = qa[0].split()[2:]
                a = qa[1].split()
                
                
                # A token raises the flag of the first family it belongs to.
                for el in q:
                    if el in trig:
                        found_trig = 1
                    elif el in trig_inv:
                        found_trig_inv = 1
                    elif el in hyp:
                        found_hyp = 1
                    elif el in hyp_inv:
                        found_hyp_inv = 1
                
                # Each equation contributes at most 1 to each family counter.
                method_trig[METHOD_NUM] += found_trig
                method_trig_inv[METHOD_NUM] += found_trig_inv
                method_hyp[METHOD_NUM] += found_hyp
                method_hyp_inv[METHOD_NUM] += found_hyp_inv
                found_trig, found_trig_inv, found_hyp, found_hyp_inv = 0, 0, 0, 0
                
                # Round-trip through SymPy; the parsed values are not used and
                # any failure is silently skipped.
                try:
                    q = env.prefix_to_infix(q)       # convert to infix
                    q = env.infix_to_sympy(q)
                    #print('source:\t', q)

                    a = env.prefix_to_infix(a)
                    a = env.infix_to_sympy(a)
                    #print('answer:\t', a)
                except:
                    continue

                #print()
            

In [220]:
# Share of exclusively-failed equations whose question contains a trigonometric function.
for i, count in enumerate(method_trig):
    print(i, ':', round(count / failed_lens[i], 2))

0 : 0.55
1 : 0.51
2 : 0.43
3 : 0.43


In [221]:
# Share of exclusively-failed equations whose question contains an inverse trigonometric function.
for i, count in enumerate(method_trig_inv):
    print(i, ':', round(count / failed_lens[i], 2))

0 : 0.16
1 : 0.13
2 : 0.17
3 : 0.14


In [222]:
# Share of exclusively-failed equations whose question contains a hyperbolic function.
for i, count in enumerate(method_hyp):
    print(i, ':', round(count / failed_lens[i], 2))

0 : 0.14
1 : 0.19
2 : 0.15
3 : 0.15


In [223]:
# Share of exclusively-failed equations whose question contains an inverse hyperbolic function.
for i, count in enumerate(method_hyp_inv):
    print(i, ':', round(count / failed_lens[i], 2))

0 : 0.19
1 : 0.16
2 : 0.13
3 : 0.12
