In [1]:
#https://codegolf.stackexchange.com/questions/34926/write-a-fast-word-equation-solver

In [2]:
def solver(eqn, soln_format = "eqn"):
    #get weights
    weight = -1
    weights = {}
    for char in reversed(eqn):
        if char=='=':
            weight=1
            continue
        if char=='+':
            if weight>0:
                weight=1
            else:
                weight=-1
            continue
        if char not in weights:
            weights[char] = 0
        weights[char]+= weight
        weight*=10
    
    assert len(weights) <= 10
    weights_sorted = list(weights.values()) + [0]*(10-len(weights.values()))
    weights_sorted.sort()
    
    _, order = zip(*sorted(zip(weights_sorted, range(10)), key=lambda x: abs(x[0]), reverse=True)) #abs sorted order
    order_list = [num - sum(i<num for i in order[:idx]) for idx, num in enumerate(order)] #because indexes change
    order_list = order_list[:len(weights.values())] #to ignore zeros
    order_list.reverse() #to pop and append like a stack
        
    for soln in backtrack(0, weights_sorted.copy(), order_list, list(reversed(range(10))), []):
        weights_copy = weights.copy()
        weight_soln_pair = [(weights_sorted[i],num) for i,num in zip(order, soln)]
        assert sum(i*j for i,j in weight_soln_pair) == 0, "Major error"
        final_soln = {}
        for weight_val, soln_val in weight_soln_pair:
            for char, weight in weights_copy.items():
                if weight == weight_val:
                    final_soln[char] = soln_val
                    del weights_copy[char]
                    break
                    
        if soln_format == "eqn":
            #CONVERT TO EQN
            eqn_copy = eqn
            for k, v in final_soln.items():
                eqn_copy = eqn_copy.replace(k, str(v))
            yield eqn_copy
        elif soln_format == "dict":
            yield final_soln
            
def backtrack(val, weights, test_order, remaining, soln):
    #Check if end
    if len(test_order)==0:
        if val==0: 
            yield soln
        return False
    #Check if possible
    if sum([i*j for i,j in zip(remaining, reversed(weights))]) + val < 0: #max
        return False
    if sum([i*j for i,j in zip(remaining, weights)]) + val > 0: #min
        return False
    #Backtrack for each digit
    order = test_order.pop()
    weight = weights.pop(order)
    for i, num in enumerate(remaining):
        del remaining[i]
        yield from backtrack(val+num*weight, weights, test_order, remaining, soln + [num])
        remaining.insert(i, num)
    weights.insert(order, weight)
    test_order.append(order)
    return False

In [3]:
%%timeit -n 1 -r 1
print([*solver('AA+BB=CC')])

['11+88=99', '22+77=99', '33+66=99', '44+55=99', '55+44=99', '66+33=99', '77+22=99', '88+11=99', '11+77=88', '22+66=88', '33+55=88', '55+33=88', '66+22=88', '77+11=88', '11+66=77', '22+55=77', '33+44=77', '44+33=77', '55+22=77', '66+11=77', '11+55=66', '22+44=66', '44+22=66', '55+11=66', '11+44=55', '22+33=55', '33+22=55', '44+11=55', '11+33=44', '33+11=44', '11+22=33', '22+11=33']
1.19 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [4]:
%%timeit -n 1 -r 1
print([*solver('SEND+MORE=MONEY')])

['9567+1085=10652', '8542+0915=09457', '8432+0914=09346', '8324+0913=09237', '7643+0826=08469', '7649+0816=08465', '7534+0825=08359', '7531+0825=08356', '7539+0815=08354', '7429+0814=08243', '7316+0823=08139', '6851+0738=07589', '6853+0728=07581', '6524+0735=07259', '6415+0734=07149', '6419+0724=07143', '5849+0638=06487', '5732+0647=06379', '5731+0647=06378', '3821+0468=04289', '3829+0458=04287', '3712+0467=04179', '3719+0457=04176', '2819+0368=03187', '2817+0368=03185']
1.63 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [5]:
%%timeit -n 1 -r 1
print([*solver('CORRECTS+REJECTED=MATTRESS')])

['85112846+12328420=97441266']
45.1 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [6]:
%%timeit -n 1 -r 1
print([*solver('FIFTY+STATES=AMERICA')])

['65682+981849=1047531', '64672+870738=0935410']
6.23 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [7]:
%%timeit -n 1 -r 1
print([*solver('CODE+GOLF=GREAT')])

['9428+1437=10865', '9438+1427=10865', '9265+1278=10543', '9275+1268=10543', '8653+0671=09324', '8673+0651=09324', '8643+0672=09315', '8673+0642=09315', '8612+0635=09247', '8632+0615=09247', '7642+0651=08293', '7652+0641=08293', '6918+0934=07852', '6938+0914=07852', '6918+0925=07843', '6928+0915=07843', '5928+0943=06871', '5948+0923=06871', '3857+0862=04719', '3867+0852=04719', '3612+0685=04297', '3682+0615=04297', '2918+0956=03874', '2958+0916=03874', '2918+0947=03865', '2948+0917=03865', '2846+0851=03697', '2856+0841=03697']
5.6 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [8]:
%%timeit -n 1 -r 1
print([*solver('NINETEEN+NINETEEN'+4*'+TEN'+5*'+NINE'+877*'+ONE'+'=THOUSAND', soln_format='dict')])

[{'N': 4, 'T': 9, 'I': 7, 'H': 5, 'E': 8, 'O': 3, 'U': 1, 'S': 2, 'A': 6, 'D': 0}]
3.92 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [9]:
def Ajax1234_f(a, b, r, c = {}, s = [], k = 0, l = [], A = [], B = []):
    if not a and not b:
        yield f"{''.join(map(str, A))}+{''.join(map(str, B))}={''.join(map(str, l))}"
        return
    for x in c.get(a[-1], range(10)):
        if a[-1] in c or x not in s:
            _c, _s = {**c, a[-1]:[x]}, s+[x]
            for y in _c.get(b[-1], range(10)):
                if b[-1] in _c or y not in _s:
                    n_c, v = {**_c, b[-1]:[y]}, k+x+y
                    n_s = _s+[y]
                    if a[:-1] and b[:-1]:
                        _v, _k = v%10, v//10
                        if [_v] == n_c.get(r[-1], [_v])and (r[-1] in n_c or _v not in n_s):
                            yield from Ajax1234_f(a[:-1], b[:-1], r[:-1], {**n_c, r[-1]:[_v]}, n_s+[_v], _k, [_v]+l, [x]+A, [y]+B)
                    else:
                        if len(str(v)) == len(r):
                            F = 1
                            for J, K in zip(r[::-1], str(v)[::-1]):
                                if n_c.get(J, [int(K)]) == [int(K)] and (J in n_c or int(K) not in n_s):
                                    n_c[J] = [int(K)]
                                    n_s = n_s + [int(K)]
                                else:
                                    F = 0
                                    break
                            if F: yield from Ajax1234_f(a[:-1], b[:-1], '', n_c, n_s, 0, [v]+l, [x]+A, [y]+B) 

In [10]:
%%timeit -n 1 -r 1
print([*Ajax1234_f('AA', 'BB', 'CC')])

['11+22=33', '11+33=44', '11+44=55', '11+55=66', '11+66=77', '11+77=88', '11+88=99', '22+11=33', '22+33=55', '22+44=66', '22+55=77', '22+66=88', '22+77=99', '33+11=44', '33+22=55', '33+44=77', '33+55=88', '33+66=99', '44+11=55', '44+22=66', '44+33=77', '44+55=99', '55+11=66', '55+22=77', '55+33=88', '55+44=99', '66+11=77', '66+22=88', '66+33=99', '77+11=88', '77+22=99', '88+11=99']
657 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [11]:
%%timeit -n 1 -r 1
print([*Ajax1234_f('SEND', 'MORE', 'MONEY')])

['9567+1085=10652']
11.2 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [12]:
%%timeit -n 1 -r 1
print([*Ajax1234_f('CORRECTS', 'REJECTED', 'MATTRESS')])

['85112846+12328420=97441266']
2.22 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [13]:
%%timeit -n 1 -r 1
print([*Ajax1234_f('FIFTY', 'STATES', 'AMERICA')])

[]
23.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [14]:
%%timeit -n 1 -r 1
print([*Ajax1234_f('CODE', 'GOLF', 'GREAT')])

['9265+1278=10543', '9275+1268=10543', '9428+1437=10865', '9438+1427=10865']
27.4 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
