In [1]:
import re
from typing import Dict, List, Set

from pyswip import Prolog

In [2]:
prolog = Prolog()
prolog.assertz("father(michael,john)")
prolog.assertz("father(michael,gina)")
list(prolog.query("father(michael,X)"))  # [{'X': 'john'}, {'X': 'gina'}]

[{'X': 'john'}, {'X': 'gina'}]

In [50]:
cryptarithms = [
    "SEND + MORE #= MONEY",
    "VIOLIN + VIOLA #= TRIO",
    "ODD + ODD #= EVEN",
]

In [4]:
def get_vars(expr: str) -> Set[str]:
    return set(re.findall(r"[A-Z]", expr))

In [5]:
get_vars(cryptarithms[0])

{'D', 'E', 'M', 'N', 'O', 'R', 'S', 'Y'}

In [31]:
rules = {
    ":- use_module(library(clpfd))",
    "digit(X) :- member(X, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])",
    "all_diff([])",
    "all_diff([H|T]) :- \+member(H, T), all_diff(T)",
    """
    solve([O, D, D], [O, D, D], [E, V, E, N]) :-
        % génération des chiffres
        digit(O), digit(D),
        digit(E), digit(V), digit(N),
        % test que O et E sont différents de 0
        O =\= 0, E =\= 0,
        % test de la somme
        all_diff([O, D, E, V, N]),
                   100 * O + 10 * D + D +
                   100 * O + 10 * D + D =:=
        1000 * E + 100 * V + 10 * E + N
    """,
}
for rule in rules:
    prolog.assertz(rule)

In [32]:
query = "solve([O, D, D], [O, D, D], [E, V, E, N])"
list(prolog.query(query))

[{'O': 6, 'D': 5, 'E': 1, 'V': 3, 'N': 0},
 {'O': 8, 'D': 5, 'E': 1, 'V': 7, 'N': 0}]

In [6]:
def all_digit(expr: str) -> Set[str]:
    return {f"digit({char})" for char in get_vars(expr)}

In [7]:
all_digit(cryptarithms[2])

{'digit(D)', 'digit(E)', 'digit(N)', 'digit(O)', 'digit(V)'}

In [8]:
def all_diff(expr: str) -> str:
    return f"all_diff([{', '.join(get_vars(expr))}])"

In [9]:
all_diff(cryptarithms[2])

'all_diff([D, O, N, E, V])'

In [29]:
def generate(
    expr: str, allow_zero: bool = True, allow_leading_zero: bool = False
) -> Set[str]:
    generated = all_digit(expr)
    generated.add(all_diff(expr))
    if not allow_zero:
        for char in get_vars(expr):
            generated.add(f"dif({char}, 0)")
    if not allow_leading_zero:
        for char in set(re.findall(r'\b(\w)', expr)):
            generated.add(f"dif({char}, 0)")
    if allow_leading_zero:
        for char in set(re.findall(r'\b(\w)', expr)):
            if f"dif({char}, 0)" in generated:
                generated.remove(f"dif({char}, 0)")
    return generated

In [30]:
generate(cryptarithms[2], allow_zero=False, allow_leading_zero=True)

{'all_diff([D, O, N, E, V])',
 'dif(D, 0)',
 'dif(N, 0)',
 'dif(V, 0)',
 'digit(D)',
 'digit(E)',
 'digit(N)',
 'digit(O)',
 'digit(V)'}

In [48]:
def test(expr: str) -> str:
    operands = re.findall(r"[A-Z]+", expr)
    operators = re.findall(r"[+\-*/]|mod|#=|=:=", expr)

    rule = ""
    for i in range(len(operators)):
        rule += "("
        for char, coef in zip(operands[i], range(len(operands[i]), 0, -1)):
            rule += f"{10 ** (coef - 1)} * {char} + "
        rule = rule[:-3] + ")"
        rule += f" {operators[i]} "
    rule += "("
    for char, coef in zip(operands[-1], range(len(operands[-1]), 0, -1)):
        rule += f"{10 ** (coef - 1)} * {char} + "
    rule = rule[:-3] + ")"
    
    return rule

In [51]:
test(cryptarithms[2])

'(100 * O + 10 * D + 1 * D) + (100 * O + 10 * D + 1 * D) #= (1000 * E + 100 * V + 10 * E + 1 * N)'