In [1]:
import re
import solver
from sympy.solvers import solveset, solve
from sympy import Symbol, FiniteSet, Interval, S, EmptySet
from itertools import combinations, product
import numpy as np

In [2]:
def find_set_points(minmax_terms, var_name, low=0, high=1, left_open=True, right_open=True):
    """Return a list of sorted set points. 
    There is at least one number. """
    pts = set()
    interval = Interval(low, high, left_open=left_open, right_open=right_open)
    for term in minmax_terms:
        match = re.findall(r"\s*(\d+).*,\s*(\d+)\s*", term[3])
        left = float(match[0][0])
        right = float(match[0][1])
        splitted = term[3].split(",")
        if var_name in splitted[0]:
            pts.add(right/left) 
        elif var_name in splitted[1]:
            pts.add(left/right)
    return sorted(pts)

In [3]:
def create_intervals(set_points, low=0, high=1, left_open=True, right_open=True):
    """Create a list of intervals based on set_points. 
    a < x <= b. """
    intervals = []
    i = -1 
    j = 0 
    required_interval = Interval(low, high, left_open=left_open, right_open=right_open)
    while i < len(set_points):
        if i == -1:
            interval = Interval(low, set_points[j], left_open=left_open, right_open=True)
            if interval is not EmptySet: 
                intervals.append(interval)
        elif j == len(set_points):
            interval = Interval(set_points[i], high, left_open=True, right_open=right_open)
            if interval is not EmptySet:
                intervals.append(interval)
        else:
            intervals.append(Interval(set_points[i], set_points[j], left_open=True, right_open=True))
        i += 1 
        j += 1 
    return intervals

In [4]:
def get_parts(eq):
    return re.findall(r"[\+\-][^\+\-=]*", eq)

In [5]:
def random_interval(interval):
    return np.random.uniform(interval.start, interval.end)

In [6]:
def get_left_right(minmax_tuple):
    match = re.findall(r"\s*(\d+.*),\s*(\d+[^\)]*)\s*", minmax_tuple)
    return match[0][0], match[0][1]

In [7]:
def get_cons_var_terms(equation):
    match_list = re.findall(r"(\+|\-)\s*(\d+)\s*\*\s*([a-z]+)", equation)
    i = 0 
    n = len(match_list)
    while i < len(match_list):
        if "min" in match_list[i][2] or "max" in match_list[i][2]:
            match_list.pop(i)
            i -= 1 
            n -= 1
        i += 1 
    return match_list

In [8]:
def get_value_term(equation):
    return re.search(r"=\s*(\d+)", equation).group(1)

In [9]:
def knit_solver(interval, minmax_terms, cons_var_terms, var_name):
    solver = ""
    for term in minmax_terms:
        operator = term[0]
        coef = 1 if term[1] == "" else float(term[1])
        minmax = term[2]
        minmax_tuple = term[3]
        if var_name not in minmax_tuple:
            val = eval(f"{minmax}{minmax_tuple}")
            solver += f"{operator}{coef}*{val}"
        else: 
            rand = random_interval(interval)
            index = term[3].rfind(var_name)
            replaced = term[3][:index] + "rand" + term[3][index+1:]
            replaced = f"{minmax}{replaced}"
            left, right = get_left_right(term[3])
            if var_name not in left: 
                non_var_part = left 
                var_part = right 
            else: 
                non_var_part = right 
                var_part = left 
            try:
                val = eval(replaced)
            except Exception as e:
                print(e) 
                return
            if val != float(non_var_part):
                solver += f"{operator}{coef}*{var_part}" 
            else: 
                solver += f"{operator}{coef}*{non_var_part}"
    for term in cons_var_terms: 
        solver += f"{term[0]}{term[1]}*{term[2]}"
    return solver

In [10]:
def get_minmax_terms(equation):
    """Return a list of tuples. """
    return re.findall(r"(\+|\-)\s*(\d*)\s*\**\s*(min|max)(\([^\)]+\))",
                      equation)

In [11]:
def get_validate_eq(equation):
    return equation.replace("=", "==")

In [12]:
def auto_solve(eq, var_name, low=0, high=1, left_open=True, right_open=True):
    equation = f"+{eq}"
    value_term = get_value_term(equation)
    minmax_terms = get_minmax_terms(equation)
    cons_var_terms = get_cons_var_terms(equation)
    set_points = find_set_points(minmax_terms, var_name)
    print(set_points)
    intervals = create_intervals(set_points)
    print(intervals)
    results = [] 
    for interval in intervals: 
        knitted_solver = knit_solver(intervals[0], minmax_terms, cons_var_terms, "a")
        knitted_solver = f"{knitted_solver} - {value_term}"
        a = Symbol("a")
        result = solveset(eval(knitted_solver), a)
        print(result)
        if result is S.Complexes: 
            return Interval(low, high, left_open=left_open, right_open=right_open)
        elif list(result)[0] in interval: 
            results.append(result)
        elif list(result)[0].evalf() == interval.start or list(result)[0].evalf() == interval.end:
            a = np.random.uniform(interval.start, interval.end)
            validate_eq = get_validate_eq(eq)
            print(validate_eq)
            if eval(validate_eq):
                results.append(interval.union(result))
    return results 

In [13]:
eq = "50* a + min(200, 150) + min(300, 400* a) + min(300, 400 * a) + 2 * min(500*a, 200) + 200* a + max(100, 100*a)= 300"

In [14]:
validate_eq = get_validate_eq(eq)
validate_eq

'50* a + min(200, 150) + min(300, 400* a) + min(300, 400 * a) + 2 * min(500*a, 200) + 200* a + max(100, 100*a)== 300'

In [15]:
equation = f"+{eq}"
equation

'+50* a + min(200, 150) + min(300, 400* a) + min(300, 400 * a) + 2 * min(500*a, 200) + 200* a + max(100, 100*a)= 300'

In [16]:
value_term = solver.get_value_term(equation)
value_term

'300'

In [17]:
minmax_terms = solver.get_minmax_terms(equation)
minmax_terms

[('+', '', 'min', '(200, 150)'),
 ('+', '', 'min', '(300, 400* a)'),
 ('+', '', 'min', '(300, 400 * a)'),
 ('+', '2', 'min', '(500*a, 200)'),
 ('+', '', 'max', '(100, 100*a)')]

In [18]:
cons_var_terms = get_cons_var_terms(equation)
cons_var_terms

[('+', '50', 'a'), ('+', '200', 'a')]

In [19]:
set_points = find_set_points(minmax_terms, "a")
set_points

[0.4, 0.75, 1.0]

In [20]:
intervals = create_intervals(set_points)
intervals 

[Interval.open(0, 0.400000000000000),
 Interval.open(0.400000000000000, 0.750000000000000),
 Interval.open(0.750000000000000, 1.00000000000000)]

In [21]:
parts = get_parts(equation)
parts

['+50* a ',
 '+ min(200, 150) ',
 '+ min(300, 400* a) ',
 '+ min(300, 400 * a) ',
 '+ 2 * min(500*a, 200) ',
 '+ 200* a ',
 '+ max(100, 100*a)']

In [22]:
knitted_solver = knit_solver(intervals[0], minmax_terms, cons_var_terms, "a")
knitted_solver = f"{knitted_solver} - {value_term}"
knitted_solver

'+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300'

In [23]:
a = Symbol("a")
result = solveset(eval(knitted_solver), a)
list(result)

[0.0243902439024390]

In [24]:
sol = auto_solve(eq, "a")
sol

[0.4, 0.75, 1.0]
[Interval.open(0, 0.400000000000000), Interval.open(0.400000000000000, 0.750000000000000), Interval.open(0.750000000000000, 1.00000000000000)]
FiniteSet(0.024390243902439)
FiniteSet(0.024390243902439)
FiniteSet(0.024390243902439)


[FiniteSet(0.024390243902439)]

In [25]:
%timeit solver.auto_solve(eq, "a")

Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* a+1*400 * a+2.0*500*a+1*100+50*a+200*a - 300
Interval.open(0, 0.400000000000000)
+1*150+1*400* 

## test

In [26]:
eq = 'min(400, 400*a) + min(400, 400*a) + min(0, 100*a) = 200'
auto_solve(eq, "a")

[0.0, 1.0]
[Interval.open(0.0, 1.00000000000000)]
FiniteSet(1/4)


[FiniteSet(1/4)]

In [27]:
eq = "min(400, 400*a) = 400"
auto_solve(eq, "a")

[1.0]
[Interval.open(0, 1.00000000000000)]
FiniteSet(1)
min(400, 400*a) == 400


[]

In [28]:
eq = "min(400, 500*a) = 300"
auto_solve(eq, "a")

[0.8]
[Interval.open(0, 0.800000000000000), Interval.open(0.800000000000000, 1)]
FiniteSet(3/5)
FiniteSet(3/5)


[FiniteSet(3/5)]

In [29]:
eq = "min(500*a, 400) = 300"
auto_solve(eq, "a")

[0.8]
[Interval.open(0, 0.800000000000000), Interval.open(0.800000000000000, 1)]
FiniteSet(3/5)
FiniteSet(3/5)


[FiniteSet(3/5)]

In [30]:
eq = "800*a + min(300, 400*a) + min(300, 400*a) = 1000"
auto_solve(eq, "a")

[0.75]
[Interval.open(0, 0.750000000000000), Interval.open(0.750000000000000, 1)]
FiniteSet(5/8)
FiniteSet(5/8)


[FiniteSet(5/8)]

In [31]:
eq = "min(500, 600*a) + max(400, 500*a) = 500"
auto_solve(eq, "a")

[0.8, 0.8333333333333334]
[Interval.open(0, 0.800000000000000), Interval.open(0.800000000000000, 0.833333333333333), Interval.open(0.833333333333333, 1)]
FiniteSet(1/6)
FiniteSet(1/6)
FiniteSet(1/6)


[FiniteSet(1/6)]

In [32]:
eq = "min(500, 600*a) - min(500, 600*a) = 0"
auto_solve(eq, "a")

[0.8333333333333334]
[Interval.open(0, 0.833333333333333), Interval.open(0.833333333333333, 1)]
S.Complexes


Interval.open(0, 1)

In [33]:
eq = "800*a + min(300, 400*a) + min(300, 400*a) = 1000"
auto_solve(eq, "a")

[0.75]
[Interval.open(0, 0.750000000000000), Interval.open(0.750000000000000, 1)]
FiniteSet(5/8)
FiniteSet(5/8)


[FiniteSet(5/8)]

In [34]:
eq = "800*a + 2*min(300, 400*a) = 1000"
auto_solve(eq, "a")

[0.75]
[Interval.open(0, 0.750000000000000), Interval.open(0.750000000000000, 1)]
FiniteSet(0.625)
FiniteSet(0.625)


[FiniteSet(0.625)]

In [35]:
eq = "min(400, 500*a) = 1000"
auto_solve(eq, "a")

[0.8]
[Interval.open(0, 0.800000000000000), Interval.open(0.800000000000000, 1)]
FiniteSet(2)
FiniteSet(2)


[]

In [36]:
eq = "min(400, 600*a) = 400"
auto_solve(eq, "a")

[0.6666666666666666]
[Interval.open(0, 0.666666666666667), Interval.open(0.666666666666667, 1)]
FiniteSet(2/3)
min(400, 600*a) == 400
FiniteSet(2/3)
min(400, 600*a) == 400


[Interval.Ropen(0.666666666666667, 1)]

In [37]:
eq = "max(400*a, 600) = 600"
auto_solve(eq, "a")

[1.5]
[Interval.open(0, 1.50000000000000)]
S.Complexes


Interval.open(0, 1)

In [41]:
eq = "max(600*a, 400) + min(200*a, 500) + min(100, 300*a) + 50*a = 600"

In [42]:
%timeit auto_solve(eq, "a")

[0.3333333333333333, 0.6666666666666666, 2.5]
[Interval.open(0, 0.333333333333333), Interval.open(0.333333333333333, 0.666666666666667), Interval.open(0.666666666666667, 2.50000000000000)]
FiniteSet(4/11)
FiniteSet(4/11)
FiniteSet(4/11)
[0.3333333333333333, 0.6666666666666666, 2.5]
[Interval.open(0, 0.333333333333333), Interval.open(0.333333333333333, 0.666666666666667), Interval.open(0.666666666666667, 2.50000000000000)]
FiniteSet(4/11)
FiniteSet(4/11)
FiniteSet(4/11)
[0.3333333333333333, 0.6666666666666666, 2.5]
[Interval.open(0, 0.333333333333333), Interval.open(0.333333333333333, 0.666666666666667), Interval.open(0.666666666666667, 2.50000000000000)]
FiniteSet(4/11)
FiniteSet(4/11)
FiniteSet(4/11)
[0.3333333333333333, 0.6666666666666666, 2.5]
[Interval.open(0, 0.333333333333333), Interval.open(0.333333333333333, 0.666666666666667), Interval.open(0.666666666666667, 2.50000000000000)]
FiniteSet(4/11)
FiniteSet(4/11)
FiniteSet(4/11)
[0.3333333333333333, 0.6666666666666666, 2.5]
[Inter