In [1]:
import re
import string
import random
import collections

from tqdm import tqdm

In [2]:
# read input text file into a list of tuples such that:
# tuple at 0 = input seq = factored expression
# tuple at 1 = output seq = expanded expression
pairs = []
with open("data/train.txt") as fi:
    raw_text = fi.read()
    for line in tqdm(raw_text.splitlines()):
        pairs.append(line.split("="))
        assert len(pairs[-1]) == 2

100%|██████████| 1000000/1000000 [00:01<00:00, 940738.87it/s]


In [3]:
# number of examples
len(pairs)

1000000

In [4]:
# max length of input (expected <= 29)
max(list(map(lambda x: len(x[0]), pairs)))

29

In [5]:
# max length of output (expected <= 29)
max(list(map(lambda x: len(x[1]), pairs)))

28

In [6]:
# 10 random examples
print(f"{'input (factored)':>30}  |  {'output (expanded)':30}")
print("-"*(30*2+5))
for factored, expanded in random.sample(pairs, 30):
    print(f"{factored:>30}  =  {expanded:30}")

              input (factored)  |  output (expanded)             
-----------------------------------------------------------------
             (29-9*n)*(3*n+10)  =  -27*n**2-3*n+290              
               (n+17)*(6*n+14)  =  6*n**2+116*n+238              
                       15*y**2  =  15*y**2                       
                 (t-8)*(8*t-4)  =  8*t**2-68*t+32                
                     7*s*(4-s)  =  -7*s**2+28*s                  
                    -i*(6*i-2)  =  -6*i**2+2*i                   
                 (2-4*h)*(h+4)  =  -4*h**2-14*h+8                
                  (k-5)*(k+13)  =  k**2+8*k-65                   
                (21-9*n)*(n+4)  =  -9*n**2-15*n+84               
                (s-32)*(4*s+6)  =  4*s**2-122*s-192              
            (-2*x-21)*(8*x-31)  =  -16*x**2-106*x+651            
                       21*c**2  =  21*c**2                       
             (30-9*x)*(8*x+15)  =  -72*x**2+105*x+450            
          

In [7]:
def freq(pattern, s):
    return collections.Counter(re.findall(pattern, s)).most_common()

In [8]:
# chars
freq(".", raw_text)

[('*', 6296348),
 ('-', 2932951),
 ('2', 2739472),
 ('(', 1717013),
 (')', 1717013),
 ('1', 1551639),
 ('+', 1249605),
 ('=', 1000000),
 ('4', 952516),
 ('3', 950998),
 ('6', 854153),
 ('5', 803947),
 ('8', 799495),
 ('7', 646955),
 ('0', 621755),
 ('s', 568438),
 ('n', 566389),
 ('i', 528183),
 ('9', 500412),
 ('t', 285215),
 ('a', 284688),
 ('c', 284521),
 ('o', 283088),
 ('y', 246024),
 ('z', 245599),
 ('k', 245042),
 ('h', 244132),
 ('j', 244094),
 ('x', 243916)]

In [9]:
# lowercase chars
freq("[a-z]", raw_text)

[('s', 568438),
 ('n', 566389),
 ('i', 528183),
 ('t', 285215),
 ('a', 284688),
 ('c', 284521),
 ('o', 283088),
 ('y', 246024),
 ('z', 245599),
 ('k', 245042),
 ('h', 244132),
 ('j', 244094),
 ('x', 243916)]

In [10]:
# lowercase terms
freq("[a-z]+", raw_text)

[('s', 489862),
 ('i', 488935),
 ('n', 487884),
 ('y', 246024),
 ('t', 245958),
 ('z', 245599),
 ('a', 245431),
 ('c', 245193),
 ('k', 245042),
 ('h', 244132),
 ('j', 244094),
 ('x', 243916),
 ('o', 243760),
 ('cos', 39328),
 ('tan', 39257),
 ('sin', 39248)]

In [11]:
# digit chars
freq("[0-9]", raw_text)

[('2', 2739472),
 ('1', 1551639),
 ('4', 952516),
 ('3', 950998),
 ('6', 854153),
 ('5', 803947),
 ('8', 799495),
 ('7', 646955),
 ('0', 621755),
 ('9', 500412)]

In [104]:
# digit terms (numbers)
numbers = freq("[0-9]+", raw_text)
print(len(numbers))
print(numbers[:5])
print(numbers[-5:])

603
[('2', 1340750), ('6', 321799), ('8', 321040), ('4', 310022), ('3', 300074)]
[('508', 1), ('481', 1), ('492', 1), ('499', 1), ('501', 1)]


In [111]:
# symbol chars
freq("\*|-|\(|\)|\+|=", raw_text)

[('*', 6296348),
 ('-', 2932951),
 ('(', 1717013),
 (')', 1717013),
 ('+', 1249605),
 ('=', 1000000)]

In [112]:
# symbol terms
freq("[\*|-|\(|\)|\+|=]+", raw_text)

[('*', 3248644),
 ('+', 1226446),
 ('**', 1008562),
 (')=', 958297),
 ('(', 760670),
 (')*(', 631813),
 ('*(', 317284),
 (')', 43644),
 ('=', 40853),
 (')**', 32287),
 (')+', 23159),
 (')*', 7697),
 ('))*(', 7246),
 ('))*', 1958),
 ('))=', 850),
 ('))**', 4)]

### Language
- **digits:** `0, 1, 2, 3, 4, 5, 6, 7, 8, 9`
- **variables:** `a, c, h, i, j, k, n, o, s, t, x, y, z`
- **parentheses:** `(, )`
- **math operators:** `*, **, +, -`
- **trig functions:** `sin, cos, tan`

Potential experimentation with languages:
- purely character based lang
- replace digits with numbers
- use all symbol terms

In [122]:
# final terms
vocab_pattern = "sin|cos|tan|\d|\w|\(|\)|\+|-|\*+"
vocab = freq(vocab_pattern, raw_text)

In [123]:
len(vocab)

32

In [125]:
vocab

[('*', 4214642),
 ('-', 2932951),
 ('2', 2739472),
 ('(', 1717013),
 (')', 1717013),
 ('1', 1551639),
 ('+', 1249605),
 ('**', 1040853),
 ('4', 952516),
 ('3', 950998),
 ('6', 854153),
 ('5', 803947),
 ('8', 799495),
 ('7', 646955),
 ('0', 621755),
 ('9', 500412),
 ('s', 489862),
 ('i', 488935),
 ('n', 487884),
 ('y', 246024),
 ('t', 245958),
 ('z', 245599),
 ('a', 245431),
 ('c', 245193),
 ('k', 245042),
 ('h', 244132),
 ('j', 244094),
 ('x', 243916),
 ('o', 243760),
 ('cos', 39328),
 ('tan', 39257),
 ('sin', 39248)]

In [139]:
# check if input language and output languages are the same
raw_factored = "\n".join([factored for factored, _ in pairs])
raw_expanded = "\n".join([expanded for _, expanded in pairs])

factored_vocab = freq(vocab_pattern, raw_factored)
expanded_vocab = freq(vocab_pattern, raw_expanded)

In [143]:
assert set([term for term, _ in factored_vocab]) == set([term for term, _ in expanded_vocab]) == set([term for term, _ in vocab])

In [144]:
factored_vocab

[('*', 2351588),
 ('(', 1658131),
 (')', 1658131),
 ('-', 1645524),
 ('2', 915989),
 ('1', 717682),
 ('3', 466374),
 ('+', 452922),
 ('4', 340647),
 ('6', 338982),
 ('5', 338830),
 ('8', 338405),
 ('7', 335316),
 ('s', 245081),
 ('9', 244696),
 ('i', 244626),
 ('n', 244100),
 ('0', 155302),
 ('y', 123092),
 ('t', 123055),
 ('z', 122866),
 ('a', 122808),
 ('c', 122658),
 ('k', 122596),
 ('h', 122155),
 ('j', 122133),
 ('x', 122021),
 ('o', 121956),
 ('**', 40853),
 ('cos', 19675),
 ('sin', 19640),
 ('tan', 19636)]

In [145]:
expanded_vocab

[('*', 1863054),
 ('2', 1823483),
 ('-', 1287427),
 ('**', 1000000),
 ('1', 833957),
 ('+', 796683),
 ('4', 611869),
 ('6', 515171),
 ('3', 484624),
 ('0', 466453),
 ('5', 465117),
 ('8', 461090),
 ('7', 311639),
 ('9', 255716),
 ('s', 244781),
 ('i', 244309),
 ('n', 243784),
 ('y', 122932),
 ('t', 122903),
 ('z', 122733),
 ('a', 122623),
 ('c', 122535),
 ('k', 122446),
 ('h', 121977),
 ('j', 121961),
 ('x', 121895),
 ('o', 121804),
 ('(', 58882),
 (')', 58882),
 ('cos', 19653),
 ('tan', 19621),
 ('sin', 19608)]

In [149]:
# test that regex pattern preserves information
for factored, expanded in tqdm(pairs):
    factored_terms = re.findall(vocab_pattern, factored)
    expanded_terms = re.findall(vocab_pattern, expanded)
    assert "".join(factored_terms) == factored
    assert "".join(expanded_terms) == expanded

100%|██████████| 1000000/1000000 [00:05<00:00, 175464.55it/s]
