## Data Exploration

In [2]:
# read the data

with open('dataset.txt') as file:
    
    equations = file.read().splitlines()

len(equations)

1000000

In [None]:
# check for repeated equations

len(equations) == len(set(equations))

False

In [3]:
# keep only unique equations

equations = list(set(equations))

len(equations)

732171

In [None]:
# see some examples

equations[:10]

['(-9*h-13)*(h+4)=-9*h**2-49*h-52',
 '(2*i-8)*(5*i-25)=10*i**2-90*i+200',
 '(-8*t-4)*(8*t+13)=-64*t**2-136*t-52',
 '(1-8*y)*(6*y-1)=-48*y**2+14*y-1',
 '(s+11)*(4*s-20)=4*s**2+24*s-220',
 '(24-i)*(8*i+29)=-8*i**2+163*i+696',
 '(z+24)*(7*z-12)=7*z**2+156*z-288',
 '(-7*z-9)*(-2*z-12)=14*z**2+102*z+108',
 '(10-9*j)*(-8*j-28)=72*j**2+172*j-280',
 '(17-2*s)*(5*s-4)=-10*s**2+93*s-68']

In [None]:
# check that the '=' sign is present in ALL equations and just once

for equation in equations:
    
    assert(equation.count('=') == 1)

In [52]:
# split equations into factorized and expanded forms

pairs = [equation.lower().split('=') for equation in equations] # we first lowercase all of the equations

factorized = [pair[0] for pair in pairs]

expanded = [pair[1] for pair in pairs]

# see some examples

print(f"{'input (factored)':>30}  |  {'output (expanded)':30}")
print("-"*(30*2+5))

for i in range(10):

    print(f"{factorized[i]:>30}  =  {expanded[i]:30}")

              input (factored)  |  output (expanded)             
-----------------------------------------------------------------
                  -3*k*(2*k-4)  =  -6*k**2+12*k                  
               (x+16)*(2*x+24)  =  2*x**2+56*x+384               
                   -5*i*(15-i)  =  5*i**2-75*i                   
               (6-2*h)*(2*h+3)  =  -4*h**2+6*h+18                
                 (t-1)*(3*t+5)  =  3*t**2+2*t-5                  
                (6-j)*(8*j+25)  =  -8*j**2+23*j+150              
                 -6*t*(13-7*t)  =  42*t**2-78*t                  
               (30-4*j)*(j+29)  =  -4*j**2-86*j+870              
                   7*k*(7*k+3)  =  49*k**2+21*k                  
              (5*a+8)*(6*a-15)  =  30*a**2-27*a-120              


In [53]:
import re
from collections import Counter

def find_and_count(pattern, string):
    
    """
    Looks for 'pattern' in 'string', counts its occurences and sorts them
    Returns: a sorted list of found patterns and their counts
    """

    return Counter(re.findall(pattern, string)).most_common()

In [None]:
# look for white spaces

print(find_and_count(' ', ''.join(factorized)))
print(find_and_count(' ', ''.join(expanded)))

[]
[]


In [None]:
# look for digits

find_and_count('\d', ''.join(factorized))

[('2', 711436),
 ('1', 582600),
 ('3', 377158),
 ('8', 272743),
 ('5', 272671),
 ('6', 272624),
 ('4', 272374),
 ('7', 271962),
 ('9', 198703),
 ('0', 124914)]

In [None]:
find_and_count('\d', ''.join(expanded))

[('2', 1401762),
 ('1', 687712),
 ('4', 505159),
 ('6', 423616),
 ('3', 397345),
 ('0', 393811),
 ('5', 381294),
 ('8', 377866),
 ('7', 254089),
 ('9', 210588)]

In [None]:
# look for letters sequences

find_and_count('[a-z]+', ''.join(factorized))

[('s', 170150),
 ('i', 169508),
 ('n', 169274),
 ('t', 95734),
 ('k', 95512),
 ('y', 95491),
 ('c', 95461),
 ('a', 95417),
 ('z', 95358),
 ('o', 95166),
 ('j', 95148),
 ('x', 94849),
 ('h', 94694),
 ('cos', 19439),
 ('sin', 19417),
 ('tan', 19361)]

In [None]:
find_and_count('[a-z]+', ''.join(expanded))

[('s', 169013),
 ('i', 168348),
 ('n', 168139),
 ('t', 94946),
 ('y', 94711),
 ('k', 94695),
 ('c', 94658),
 ('a', 94619),
 ('z', 94604),
 ('o', 94369),
 ('j', 94366),
 ('x', 94063),
 ('h', 93900),
 ('cos', 19305),
 ('sin', 19265),
 ('tan', 19226),
 ('ii', 48),
 ('sn', 48),
 ('nn', 47),
 ('so', 47),
 ('in', 46),
 ('ij', 44),
 ('xs', 43),
 ('is', 42),
 ('ns', 40),
 ('cn', 39),
 ('ic', 39),
 ('yi', 37),
 ('ni', 36),
 ('ci', 36),
 ('it', 36),
 ('ss', 35),
 ('ck', 34),
 ('kn', 34),
 ('sa', 34),
 ('si', 34),
 ('nz', 33),
 ('nk', 33),
 ('ti', 33),
 ('ks', 33),
 ('sh', 33),
 ('jn', 33),
 ('sx', 32),
 ('zn', 32),
 ('ix', 32),
 ('ts', 32),
 ('sc', 31),
 ('at', 31),
 ('ka', 31),
 ('hs', 31),
 ('kk', 31),
 ('xi', 31),
 ('ct', 31),
 ('nh', 30),
 ('ny', 30),
 ('sy', 30),
 ('ky', 30),
 ('sz', 30),
 ('js', 29),
 ('no', 29),
 ('ok', 29),
 ('yn', 29),
 ('nj', 29),
 ('ki', 29),
 ('xj', 29),
 ('ji', 29),
 ('xz', 29),
 ('oi', 29),
 ('tn', 28),
 ('jo', 28),
 ('oh', 28),
 ('nx', 28),
 ('os', 28),
 ('xh', 28

In [None]:
# look for non-alphanumeric sequences

find_and_count('\W+', ''.join(factorized))

[('*', 1103643),
 ('-', 930174),
 (')*(', 488639),
 (')(', 403116),
 ('+', 340905),
 (')(-', 175477),
 ('*(', 110451),
 (')-', 100873),
 (')*(-', 82353),
 (')', 73234),
 ('(', 59630),
 ('*(-', 38215),
 (')+', 13440),
 (')*', 7543),
 ('))*(', 5516),
 ('))*', 1940),
 ('))*(-', 1727),
 (')**', 1653),
 ('**', 923),
 ('(-', 618),
 ('))(', 457),
 (')-(', 339),
 ('))(-', 200),
 ('))-', 112),
 (')-(-', 99),
 ('))', 80),
 ('))**', 4),
 ('-(', 1)]

In [None]:
find_and_count('\W+', ''.join(expanded))

[('*', 1416884),
 ('-', 1005050),
 ('**', 702623),
 ('+', 642639),
 ('(', 58148),
 (')**', 29548),
 (')-', 14230),
 (')+', 9630),
 (')', 4740)]

For both factorized and expanded 'languages', we choose vocabs made of:

* **digits**: from 0 to 9
* **letters sequences**: variables and found trigonometric functions 
* **non-alphanumeric sequences**: math operators and the '(' ')'

In [54]:
vocab_pattern = '\d|[a-z]+|\(|\)|\+|-|\*+'

def create_vocab(data, vocab_pattern, threshold):
  """
  Extracts vocab items from data following a vocab_pattern, omitting items with a frequency less than threshold.
  Returns: a vocab list
  """
  vocab = [item for item, freq in find_and_count(vocab_pattern, data) if freq > threshold]

  return vocab

In [55]:
# create vocabs for both factorized and expanded languages

factorized_vocab = create_vocab(''.join(factorized), vocab_pattern, 1000)

factorized_vocab

['*',
 '(',
 ')',
 '-',
 '2',
 '1',
 '3',
 '+',
 '8',
 '5',
 '6',
 '4',
 '7',
 '9',
 's',
 'i',
 'n',
 '0',
 't',
 'k',
 'y',
 'c',
 'a',
 'z',
 'o',
 'j',
 'x',
 'h',
 'cos',
 'sin',
 'tan',
 '**']

In [44]:
len(factorized_vocab)

32

In [56]:
expanded_vocab = create_vocab(''.join(expanded), vocab_pattern, 1000)

expanded_vocab

['*',
 '2',
 '-',
 '**',
 '1',
 '+',
 '4',
 '6',
 '3',
 '0',
 '5',
 '8',
 '7',
 '9',
 's',
 'i',
 'n',
 't',
 'k',
 'c',
 'y',
 'z',
 'a',
 'o',
 'j',
 'x',
 'h',
 '(',
 ')',
 'cos',
 'sin',
 'tan']

In [None]:
len(factorized_vocab)

32

In [46]:
# is it the same vocab?

factorized_vocab.sort() == expanded_vocab.sort()

True

In [57]:
# as it turns out that both type of forms generated the same vocab, we will be using just one

vocab = expanded_vocab

In [58]:
# add a padding item to the vocab to prepare for a padding operation

vocab.append('<pad>')

In [59]:
vocab

['*',
 '2',
 '-',
 '**',
 '1',
 '+',
 '4',
 '6',
 '3',
 '0',
 '5',
 '8',
 '7',
 '9',
 's',
 'i',
 'n',
 't',
 'k',
 'c',
 'y',
 'z',
 'a',
 'o',
 'j',
 'x',
 'h',
 '(',
 ')',
 'cos',
 'sin',
 'tan',
 '<pad>']

In [62]:
# create a vocab dict and an inverted vocab dict to map between vocab items and vocab indexes

vocab_dict = {item:index for index, item in enumerate(vocab)}

inver_vocab_dict = {index:item for index, item in enumerate(vocab)}

## Data Preprocessing

In [12]:
# filter into vocab

vocab_pattern = '\d|[a-z]|cos|sin|tan|\(|\)|\+|-|\*+'

def format_forms(forms, vocab_pattern):
  """
  Transforms each form of forms into items of a vocab defined by a vocab_pattern
  Returns: a list of lists (found vocab items in the form)
  """

  formatted_form = [re.findall(vocab_pattern, form) for form in forms]

  return formatted_form

In [13]:
formatted_factorized = format_forms(factorized, vocab_pattern)

In [None]:
formatted_factorized[0]

['(', '-', '9', '*', 'h', '-', '1', '3', ')', '*', '(', 'h', '+', '4', ')']

In [14]:
formatted_expanded = format_forms(expanded, vocab_pattern)

In [None]:
formatted_expanded[0]

['-', '9', '*', 'h', '**', '2', '-', '4', '9', '*', 'h', '-', '5', '2']

In [15]:
# examine maximum lengths in our formatted forms

print('maximum factorized forms length: ', max([len(form) for form in formatted_factorized]))
print('maximum expanded forms length: ',max([len(form) for form in formatted_expanded]))

maximum factorized forms length:  29
maximum expanded forms length:  27


In [16]:
input_seq_len = max([len(form) for form in formatted_factorized])
output_seq_len = max([len(form) for form in formatted_expanded])
vocab_len = len(vocab)

In [17]:
# pad the empty spots in forms with lengths < maximum length with a special item '<pad>'

formatted_factorized = [ seq + ['<pad>'] * (input_seq_len - len(seq)) for seq in formatted_factorized]

formatted_expanded = [ seq + ['<pad>'] * (output_seq_len - len(seq)) for seq in formatted_expanded]

In [18]:
# convert forms into numerical vectors of indexes in vocab

X_indexed = [list(map(vocab_dict.get, form)) for form in formatted_factorized]
Y_indexed = [list(map(vocab_dict.get, form)) for form in formatted_expanded]

In [19]:
import tensorflow as tf
import numpy as np

In [20]:
# convert indexes into one hot encoded vectors

X = np.array(tf.one_hot(X_indexed, len(vocab)))
Y = np.array(tf.one_hot(Y_indexed, len(vocab)))

In [21]:
# split data into test and train

from sklearn import model_selection 

x_train, x_test, y_train, y_test = model_selection.train_test_split(X, Y, test_size= 0.05)

print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)

(695562, 29, 33)
(36609, 29, 33)
(695562, 27, 33)
(36609, 27, 33)


## Training

In [23]:
from tensorflow.keras.models import Model, load_model

from tensorflow.keras.layers import Dense, LSTM, Input, Bidirectional, Concatenate, Dot, RepeatVector, Activation

import pandas as pd

In [None]:
## build an encoder decoder with attention mechanism

# define layers

repeator = RepeatVector(input_seq_len)
concatenator = Concatenate(axis=-1)
densor1 = Dense(10, activation = "tanh")
densor2 = Dense(1, activation = "relu")
activator = Activation('softmax')
dotor = Dot(axes = 1)

In [None]:
def one_step_attention(a, s_prev):
    """
    Calculates a context vector from the encoder states a and the decoder previous states s_prev
    """
    
    s_prev = repeator(s_prev)

    # concatenate a and s_prev on the last axis
    concat = concatenator([a, s_prev])

    # propagate concat through a small fully-connected neural network to compute the "intermediate energies" variable e
    e = densor1(concat)

    # propagate e through a small fully-connected neural network to compute the "energies" variable energies
    energies = densor2(e)

    # compute the attention weights "alphas"
    alphas = activator(energies)

    # compute the context vector to be given to the decoder LSTM
    context = dotor([alphas, a])
    
    return context

In [22]:
n_a = 32
n_s = 64

In [None]:
# define decoder layers

post_activation_LSTM_cell = LSTM(n_s, return_state = True)
output_layer = Dense(vocab_len, activation= 'softmax')

In [None]:
def model(Tx, Ty, n_a, n_s, vocab_size):
    """
    Arguments:
    Tx -- length of the input sequence
    Ty -- length of the output sequence
    n_a -- hidden state size of the Bi-LSTM
    n_s -- hidden state size of the post-attention LSTM
    
    Returns:
    model instance
    """
    
    # Define inputs
    X = Input(shape=(Tx, vocab_size))
    s0 = Input(shape=(n_s,), name='s0')
    c0 = Input(shape=(n_s,), name='c0')
    s = s0
    c = c0
    
    # Initialize empty list of outputs
    outputs = []
    
    # Define pre-attention Bi-LSTM encoder
    a = Bidirectional(LSTM(n_a, return_sequences= True))(X)
    
    for t in range(Ty):
    
        # Perform one step of the attention mechanism to get back the context vector at step t
        context = one_step_attention(a, s)
        
        # Apply the post-attention LSTM to the "context" vector
        s, _, c = post_activation_LSTM_cell(initial_state = [s,c], inputs= context)
        
        # Apply Dense layer to the hidden state output of the decoder
        out = output_layer(s)
        
        # Append "out" to the "outputs" list
        outputs.append(out)
    
    # Create model instance taking three inputs and returning the list of outputs
    model = Model(inputs= [X, s0, c0], outputs= outputs)
    
    return model

In [None]:

model = model(input_seq_len, output_seq_len, n_a, n_s, vocab_len)

In [None]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 29, 33)]     0           []                               
                                                                                                  
 s0 (InputLayer)                [(None, 64)]         0           []                               
                                                                                                  
 bidirectional_1 (Bidirectional  (None, 29, 64)      16896       ['input_2[0][0]']                
 )                                                                                                
                                                                                                  
 repeat_vector (RepeatVector)   (None, 29, 64)       0           ['s0[0][0]',               

In [66]:
# compile model

opt = tf.keras.optimizers.Adam(learning_rate= 0.005, beta_1= 0.9, beta_2= 0.999, decay= 0.01)

model.compile(optimizer= opt, loss= 'categorical_crossentropy', metrics= ['accuracy'])

In [67]:
# initialize inputs

s0 = np.zeros((x_train.shape[0], n_s))
c0 = np.zeros((x_train.shape[0], n_s))
outputs = list(y_train.swapaxes(0,1))

In [68]:
# fit the model

model.fit([x_train, s0, c0], outputs, epochs= 5, batch_size= 100, validation_split= 0.05)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f19fa45fd50>

In [69]:
model.save('model4.h5')

In [70]:
# check accuracies evolutions through epochs of training

x = [title for title in model.history.history.keys() if 'accuracy' in title]

df = pd.DataFrame(model.history.history)[x]

df

Unnamed: 0,dense_2_accuracy,dense_2_1_accuracy,dense_2_2_accuracy,dense_2_3_accuracy,dense_2_4_accuracy,dense_2_5_accuracy,dense_2_6_accuracy,dense_2_7_accuracy,dense_2_8_accuracy,dense_2_9_accuracy,dense_2_10_accuracy,dense_2_11_accuracy,dense_2_12_accuracy,dense_2_13_accuracy,dense_2_14_accuracy,dense_2_15_accuracy,dense_2_16_accuracy,dense_2_17_accuracy,dense_2_18_accuracy,dense_2_19_accuracy,dense_2_20_accuracy,dense_2_21_accuracy,dense_2_22_accuracy,dense_2_23_accuracy,dense_2_24_accuracy,dense_2_25_accuracy,dense_2_26_accuracy,val_dense_2_accuracy,val_dense_2_1_accuracy,val_dense_2_2_accuracy,val_dense_2_3_accuracy,val_dense_2_4_accuracy,val_dense_2_5_accuracy,val_dense_2_6_accuracy,val_dense_2_7_accuracy,val_dense_2_8_accuracy,val_dense_2_9_accuracy,val_dense_2_10_accuracy,val_dense_2_11_accuracy,val_dense_2_12_accuracy,val_dense_2_13_accuracy,val_dense_2_14_accuracy,val_dense_2_15_accuracy,val_dense_2_16_accuracy,val_dense_2_17_accuracy,val_dense_2_18_accuracy,val_dense_2_19_accuracy,val_dense_2_20_accuracy,val_dense_2_21_accuracy,val_dense_2_22_accuracy,val_dense_2_23_accuracy,val_dense_2_24_accuracy,val_dense_2_25_accuracy,val_dense_2_26_accuracy
0,0.996899,0.996341,0.997406,0.997618,0.994473,0.982725,0.974955,0.881329,0.767544,0.793518,0.919765,0.987666,0.97863,0.972109,0.979627,0.991578,0.9977,0.998455,0.997857,0.997031,0.997019,0.997267,0.996017,0.994236,0.994777,0.997494,0.999564,0.997499,0.99724,0.998275,0.998189,0.996147,0.983783,0.98016,0.890336,0.779321,0.800339,0.921476,0.990368,0.982691,0.975905,0.983841,0.993531,0.99839,0.998879,0.998304,0.997297,0.997067,0.997613,0.996291,0.994278,0.995457,0.99793,0.99954
1,0.997004,0.996929,0.997739,0.997846,0.995351,0.983527,0.979288,0.893115,0.781552,0.801647,0.924175,0.992385,0.984535,0.978441,0.984461,0.994195,0.998552,0.998941,0.998464,0.9978,0.997651,0.997834,0.996484,0.994874,0.995247,0.997757,0.999635,0.997499,0.997297,0.998275,0.998074,0.996118,0.983956,0.980764,0.893211,0.783519,0.803042,0.922051,0.991777,0.984416,0.97832,0.984703,0.994192,0.998505,0.999022,0.99862,0.997613,0.99747,0.9977,0.996521,0.994508,0.995486,0.997872,0.999569
2,0.99704,0.996999,0.99775,0.997825,0.995501,0.983662,0.979795,0.894772,0.783693,0.802742,0.924768,0.99302,0.985316,0.979266,0.985171,0.994544,0.998726,0.999042,0.998756,0.998199,0.997948,0.997995,0.996595,0.994915,0.995334,0.997822,0.999658,0.997527,0.997383,0.998304,0.998102,0.996291,0.984215,0.980937,0.895483,0.784813,0.804135,0.922568,0.992467,0.98479,0.978263,0.985221,0.994508,0.998591,0.999051,0.998677,0.997901,0.997613,0.997872,0.996578,0.994508,0.995515,0.997844,0.999597
3,0.997044,0.997029,0.997756,0.997837,0.995598,0.983722,0.980245,0.89591,0.784987,0.803105,0.925128,0.993337,0.985782,0.979621,0.985502,0.994652,0.998823,0.999087,0.9989,0.998441,0.998187,0.998114,0.996669,0.994974,0.995396,0.997863,0.999676,0.997527,0.997297,0.998217,0.998131,0.996348,0.984387,0.981454,0.895627,0.784985,0.804566,0.922769,0.992553,0.984818,0.978349,0.985393,0.994537,0.998821,0.999022,0.998821,0.998332,0.997844,0.998016,0.996665,0.994594,0.995687,0.997959,0.999597
4,0.997049,0.997072,0.997748,0.997822,0.995628,0.983774,0.980461,0.896337,0.785529,0.803728,0.925293,0.993515,0.986101,0.979934,0.985697,0.994827,0.998915,0.999143,0.998942,0.998558,0.998308,0.998182,0.996687,0.995033,0.995439,0.99788,0.999663,0.997527,0.997297,0.998275,0.998189,0.996348,0.984502,0.981339,0.896259,0.785963,0.804796,0.922597,0.992524,0.984933,0.978924,0.985422,0.994594,0.998735,0.99908,0.998764,0.998476,0.99793,0.998074,0.996722,0.994451,0.995601,0.997987,0.999626


* We can see that for **each position t** of our output sequence, we are predicting the vocab item with an **accuracy > 0.78**.

## Evaluation

In [24]:
model = load_model('model3.h5')

In [78]:
s0 = np.zeros((x_test.shape[0], n_s))
c0 = np.zeros((x_test.shape[0], n_s))

In [79]:
# predict

test_predictions = model.predict([x_test, s0, c0])

In [80]:
# convert from list of lists to array

new_test_predictions = np.argmax(test_predictions, axis = -1)

new_test_predictions = np.array(new_test_predictions).T

new_test_predictions.shape

(36609, 27)

In [74]:
def index_to_string(indexed_forms):
  """
  converts from indexed sequences to strings
  """
  
  forms = []

  for row in indexed_forms: 

    forms.append(''.join([k for index in row for k, v in vocab_dict.items() if v == index]).replace('<pad>', ''))

  return forms

In [81]:
# see some random examples

num_test_seqs = x_test.shape[0]

random_indexes = np.random.choice(num_test_seqs, 10)

x_sample = index_to_string (x_test[random_indexes].argmax(axis= -1))

y_sample = index_to_string (y_test[random_indexes].argmax(axis= -1))

predicted_sample = index_to_string(new_test_predictions[random_indexes])

for i, j, k in zip(x_sample, y_sample, predicted_sample):

  print(i, '=', j, ' || predicted =====>', k)

6*t*(7*t-20) = 42*t**2-120*t  || predicted =====> 42*t**2-120*t
(14-6*h)*(-6*h-26) = 36*h**2+72*h-364  || predicted =====> 36*h**2+70*h-364
(21-2*c)*(c-21) = -2*c**2+63*c-441  || predicted =====> -2*c**2+65*c-441
(a-15)*(a-3) = a**2-18*a+45  || predicted =====> a**2-18*a+45
(i+19)*(i+31) = i**2+50*i+589  || predicted =====> i**2+56*i+589
(7-9*y)*(4*y-10) = -36*y**2+118*y-70  || predicted =====> -36*y**2+118*y-70
(2*i+18)*(4*i-9) = 8*i**2+54*i-162  || predicted =====> 8*i**2+50*i-162
(-4*x-22)*(2*x-29) = -8*x**2+72*x+638  || predicted =====> -8*x**2+70*x+638
4*a*(31-6*a) = -24*a**2+124*a  || predicted =====> -24*a**2+124*a
i*(i-2) = i**2-2*i  || predicted =====> i**2-2*i


In [76]:
def calculate_accuracy(predicted_expanded, real_expanded):

   accuracy = sum([ i == j for i,j in zip(predicted_expanded,real_expanded)]) / len(predicted_expanded)

   return accuracy

In [82]:
calculate_accuracy(predicted_sample, y_sample) #index_to_string (y_test[random_indexes].argmax(axis= -1))

0.5

## Conclusion

* Our model performs fine for all **the expanded form positions** with classification accuracies > **0.97**, except for positions **8 and 9** where we score near **0.8**.

* This may affect the global accuracy score a little bit because it is measured on sequence-level equality which means that individual classifications errors may reach important levels when accumulating.

* We reached this performance by training the model on about **20 epochs** with a high cost of **2h per 5 epochs** that made decreasing the loss even more quite difficult.

* A potential improvement of our model is to pursue its training using GPUs or simply more powerful machines. This will also open the opportunity for more hyperparameters tuning and design improvements.