<a href="https://colab.research.google.com/github/Gokcentunc/atraction/blob/main/Atraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 4 leaves, 4 states

In [1]:
import sympy as sp
from itertools import product, permutations
from functools import reduce

# ============================================================
# SYMBOLS
# ============================================================
# The four pi symbols carry the nonzero=True assumption.
pi = list(sp.symbols('pi1 pi2 pi3 pi4', nonzero=True))
pi1, pi2, pi3, pi4 = pi

# lambda_1 belongs to the first basis vector; the other three share one lambda.
lambda_1, lam = sp.Symbol('lambda_1'), sp.Symbol('lambda')
lam_vec = [lambda_1] + [lam] * 3

# Partial sums of the pi symbols.
pi12 = pi1 + pi2
pi123 = pi12 + pi3
pisum = pi123 + pi4

# ============================================================
# BASIS VECTORS u^1,...,u^4
# ============================================================
u = [
    sp.Matrix([pi1, pi2, pi3, pi4]),      # u^1 = pi
    sp.Matrix([pi1, pi2, pi3, -pi123]),   # u^2
    sp.Matrix([pi1, pi2, -pi12, 0]),      # u^3
    sp.Matrix([pi1, -pi1, 0, 0]),         # u^4
]
u1, u2, u3, u4 = u

def ip_pi(x, y):
    """Return the pi-weighted inner product sum_j x[j] * y[j] / pi[j], j = 0..3."""
    total = 0
    for j in range(4):
        total = total + (x[j] * y[j]) / pi[j]
    return total

def e_u(j, ui):
    """Return entry ``j`` (0-based) of the vector ``ui`` divided by ``pi[j]``."""
    entry, weight = ui[j], pi[j]
    return entry / weight

# norm[i] is the simplified pi-weighted inner product of basis vector u[i] with itself.
norm = [sp.simplify(ip_pi(vec, vec)) for vec in u]

def simp(expr):
    """Simplify ``expr`` and apply the substitutions pisum -> 1 and pi123 -> 1 - pi4.

    The three substitutions are applied one after another between an initial
    and a final ``simplify``/``together`` pass.  The last rule maps
    ``pi1 + pi2`` to ``pi12``, which is defined as that same sum.
    """
    combined = sp.together(sp.simplify(expr))
    for old, new in ((pisum, 1), (pi123, 1 - pi4), (pi1 + pi2, pi12)):
        combined = combined.subs({old: new})
    return sp.simplify(sp.together(combined))

# ============================================================
# p^0
# ============================================================
def p0_n_raw(idx):
    """Return the simplified p^0 coefficient for the 1-based label tuple ``idx``.

    For each component j the factors e_u(j, u^i) over all labels i in ``idx``
    are multiplied and weighted by pi[j]; the sum over j is divided by the
    product of norm[i - 1] over the same labels.
    """
    norm_product = 1
    for label in idx:
        norm_product = norm_product * norm[label - 1]

    total = 0
    for j in range(4):
        component_product = 1
        for label in idx:
            component_product = component_product * e_u(j, u[label - 1])
        total += pi[j] * component_product
    return sp.simplify(total / norm_product)

def p0_4leaves():
    """Return a dict mapping every 4-tuple over 1..4 to ``simp(p0_n_raw(tuple))``."""
    return {
        labels: simp(p0_n_raw(labels))
        for labels in product(range(1, 5), repeat=4)
    }

def p0_tripod():
    """Return a dict mapping every 3-tuple over 1..4 to ``simp(p0_n_raw(tuple))``."""
    return {
        triple: simp(p0_n_raw(triple))
        for triple in product(range(1, 5), repeat=3)
    }

# ============================================================
# bar q via gluing for quartet topology
# ============================================================
def q_4leaves(topology="12|34"):
    """Build the table bar q for one quartet topology by gluing two tripods.

    For every 4-tuple (i1, i2, i3, i4) with entries in 1..4 the value is the
    sum over y in 1..4 of ``norm[y-1] * lam_vec[y-1] * p0T[left] * p0T[right]``,
    where ``left`` is the two leaves on one side of the split followed by y
    and ``right`` is y followed by the two leaves on the other side.

    Args:
        topology: one of "12|34", "13|24", "14|23".

    Returns:
        dict mapping each 4-tuple to its value after ``simp``.

    Raises:
        ValueError: if ``topology`` is not one of the three supported splits.
    """
    # 0-based positions (within the 4-tuple) of the leaves left / right of the split.
    splits = (
        ("12|34", (0, 1), (2, 3)),
        ("13|24", (0, 2), (1, 3)),
        ("14|23", (0, 3), (1, 2)),
    )
    for name, left_pos, right_pos in splits:
        if topology == name:
            break
    else:
        # Reject a bad topology before the tripod table is built.
        raise ValueError("Unknown topology. Use '12|34', '13|24', or '14|23'.")

    p0T = p0_tripod()
    q = {}
    for leaves in product(range(1, 5), repeat=4):
        s = 0
        for y in range(1, 5):
            L = (leaves[left_pos[0]], leaves[left_pos[1]], y)
            R = (y, leaves[right_pos[0]], leaves[right_pos[1]])
            s += norm[y - 1] * (lam_vec[y - 1] * p0T[L]) * p0T[R]
        q[leaves] = simp(s)

    return q

# ============================================================
# tilde q = bar q / p^0
# ============================================================
def tilde_q(q, p0_4):
    """Return tilde q = bar q / p^0 entry by entry.

    When the simplified ``p0_4[idx]`` equals 0 no division is performed and
    the entry is ``simp(q[idx])`` itself.
    """
    ratios = {}
    for key, numerator in q.items():
        denominator = sp.simplify(p0_4[key])
        if denominator == 0:
            ratios[key] = simp(numerator)
            continue
        ratios[key] = simp(sp.simplify(numerator / denominator))
    return ratios

# ============================================================
# Symmetry grouping helpers
# ============================================================
def orbit_S4(idx):
    """Return the sorted list of distinct length-4 permutations of ``idx``."""
    distinct = {arrangement for arrangement in permutations(idx, 4)}
    ordered = list(distinct)
    ordered.sort()
    return ordered

def canon_S4(idx):
    """Return the entries of ``idx`` as a tuple in ascending order."""
    ordered = list(idx)
    ordered.sort()
    return tuple(ordered)

def orbit_swap_positions(idx, swaps):
    """Return the sorted distinct tuples reachable from ``idx`` by position swaps.

    Every subset of ``swaps`` (a sequence of 0-based position pairs) is tried;
    within a subset the chosen swaps are applied in the order they are listed.
    """
    images = set()
    for chosen in product((False, True), repeat=len(swaps)):
        arranged = list(idx)
        for apply_swap, (a, b) in zip(chosen, swaps):
            if apply_swap:
                arranged[a], arranged[b] = arranged[b], arranged[a]
        images.add(tuple(arranged))
    return sorted(images)

def orbit_topology(idx, topology):
    """Return the orbit of ``idx`` under the two within-pair swaps of ``topology``.

    Raises:
        ValueError: if ``topology`` is not "12|34", "13|24" or "14|23".
    """
    swap_table = (
        ("12|34", [(0, 1), (2, 3)]),
        ("13|24", [(0, 2), (1, 3)]),
        ("14|23", [(0, 3), (1, 2)]),
    )
    for name, swaps in swap_table:
        if topology == name:
            return orbit_swap_positions(idx, swaps)
    raise ValueError("Unknown topology")

# ============================================================
# PRINT HELPERS (NO FILES)
# ============================================================
def print_grouped(title, symbol_name, data_dict, orbit_func, skip_zeros=True, max_items=None):
    r"""Print grouped equalities as LaTeX strings.

    This docstring is a raw string so that the LaTeX examples below are not
    interpreted as escape sequences (``\b`` backspace, ``\t`` tab).

    Args:
        title: heading printed between two rules of 80 ``=`` characters.
        symbol_name: LaTeX symbol, e.g. r"\bar{p}^0" or r"\bar{q}" or r"\tilde{q}".
        data_dict: maps 4-tuples of indices to SymPy expressions.
        orbit_func: ``orbit_func(idx)`` returns the orbit list of 4-tuples.
            All orbit members are printed as equal to the value stored for
            the first visited member; that equality is not checked here.
        skip_zeros: when true, orbits whose simplified value is 0 are not printed.
        max_items: stop after this many printed lines (``None`` prints all).
    """
    print("\n" + "="*80)
    print(title)
    print("="*80)

    visited = set()
    count = 0
    for idx in sorted(data_dict.keys()):
        if idx in visited:
            continue
        O = orbit_func(idx)
        # Mark the whole orbit as handled, even if it is skipped as zero below.
        visited.update(O)
        val = sp.simplify(data_dict[idx])

        if skip_zeros and val == 0:
            continue

        chain = " = ".join(rf"{symbol_name}_{{{a}{b}{c}{d}}}" for (a, b, c, d) in O)
        print(rf"{chain} = {sp.latex(val)}")

        count += 1
        if max_items is not None and count >= max_items:
            print(f"... (stopped after {max_items} printed items)")
            break

def main(skip_zeros=True, max_items=None):
    """Print p^0 grouped by S4, then bar q and tilde q for each quartet topology.

    Args:
        skip_zeros: forwarded to ``print_grouped``.
        max_items: forwarded to ``print_grouped``.
    """
    p0 = p0_4leaves()

    # p0 holds all 4^4 entries; keep one sorted representative per S4 class.
    reps = sorted({canon_S4(idx) for idx in p0.keys()})
    p0_rep = {rep: p0[rep] for rep in reps}

    print_grouped(
        title=r"No-evolution point  \bar{p}^0  (n=4, kappa=4)  [grouped by S4]",
        symbol_name=r"\bar{p}^0",
        data_dict=p0_rep,
        orbit_func=orbit_S4,
        skip_zeros=skip_zeros,
        max_items=max_items
    )

    for topo in ["12|34", "13|24", "14|23"]:
        q = q_4leaves(topology=topo)
        tq = tilde_q(q, p0)

        # Bind the current topology as a default so the orbit function is fixed.
        def topo_orbit(idx, t=topo):
            return orbit_topology(idx, t)

        sections = (
            (
                rf"\bar{{q}} for topology {topo} (with \lambda_2=\lambda_3=\lambda_4=\lambda) "
                r"[grouped by topology symmetry]",
                r"\bar{q}",
                q,
            ),
            (
                rf"\tilde{{q}}=\bar{{q}}/\bar{{p}}^0 for topology {topo} "
                r"[grouped by topology symmetry]",
                r"\tilde{q}",
                tq,
            ),
        )
        for heading, symbol, table in sections:
            print_grouped(
                title=heading,
                symbol_name=symbol,
                data_dict=table,
                orbit_func=topo_orbit,
                skip_zeros=skip_zeros,
                max_items=max_items
            )

if __name__ == "__main__":
    # max_items=None prints everything (the output can be very long).
    # For example, pass max_items=200 to limit the output to the first 200 lines.
    main(skip_zeros=True, max_items=None)



No-evolution point  \bar{p}^0  (n=4, kappa=4)  [grouped by S4]
\bar{p}^0_{1111} = 1
\bar{p}^0_{1122} = \bar{p}^0_{1212} = \bar{p}^0_{1221} = \bar{p}^0_{2112} = \bar{p}^0_{2121} = \bar{p}^0_{2211} = - \frac{\pi_{4}}{\pi_{4} - 1}
\bar{p}^0_{1133} = \bar{p}^0_{1313} = \bar{p}^0_{1331} = \bar{p}^0_{3113} = \bar{p}^0_{3131} = \bar{p}^0_{3311} = - \frac{\pi_{3}}{\left(\pi_{1} + \pi_{2}\right) \left(\pi_{4} - 1\right)}
\bar{p}^0_{1144} = \bar{p}^0_{1414} = \bar{p}^0_{1441} = \bar{p}^0_{4114} = \bar{p}^0_{4141} = \bar{p}^0_{4411} = \frac{\pi_{2}}{\pi_{1} \left(\pi_{1} + \pi_{2}\right)}
\bar{p}^0_{1222} = \bar{p}^0_{2122} = \bar{p}^0_{2212} = \bar{p}^0_{2221} = \frac{\pi_{4}^{3}}{\left(\pi_{4} - 1\right)^{2}} - \pi_{4}
\bar{p}^0_{1233} = \bar{p}^0_{1323} = \bar{p}^0_{1332} = \bar{p}^0_{2133} = \bar{p}^0_{2313} = \bar{p}^0_{2331} = \bar{p}^0_{3123} = \bar{p}^0_{3132} = \bar{p}^0_{3213} = \bar{p}^0_{3231} = \bar{p}^0_{3312} = \bar{p}^0_{3321} = \frac{\pi_{3} \pi_{4}}{\left(\pi_{1} + \pi_{2}\righ

# 5 leaves, 4 states

In [None]:
import sympy as sp
from itertools import product
from functools import reduce

# ============================================================
# SWITCHES
# ============================================================
# Drop entries whose value is 0 when grouping (see group_equal_expressions).
SKIP_ZEROS = True
EXCLUDE_STATE_1 = True   # drop any index containing state 1

# Printing controls (console only)
# Default number of left-hand-side symbols per line in lhs_chain.
MAX_LHS_PER_LINE = 10
# When enabled, print_aligned_block emits \small for groups with at least
# SMALL_LHS_THRESHOLD left-hand-side symbols.
USE_SMALL_FOR_LONG = True
SMALL_LHS_THRESHOLD = 18

# Optional: split a long fraction into multiple fractions with same denominator
SPLIT_LONG_FRACTIONS = True
# Default number of numerator terms per fraction in split_fraction_to_tex_lines.
FRAC_TERMS_PER_LINE = 8
# Default upper bound on the number of fractions; above it no split is made.
FRAC_MAX_FRACTIONS = 10

# ============================================================
# SYMBOLS
# ============================================================
# The four pi symbols carry the nonzero=True assumption.
pi = list(sp.symbols("pi1 pi2 pi3 pi4", nonzero=True))
pi1, pi2, pi3, pi4 = pi

# Partial sums of the pi symbols.
pi12 = pi1 + pi2
pi123 = pi12 + pi3
pisum = pi123 + pi4

# edge 6 (quartet): lambda_1^6 for the first vector, lambda^6 for the other three
lam1_6, lam6 = sp.Symbol("lambda_1^6"), sp.Symbol("lambda^6")
lam_vec6 = [lam1_6] + [lam6] * 3

# edge 7 (gluing): same layout as edge 6
lam1_7, lam7 = sp.Symbol("lambda_1^7"), sp.Symbol("lambda^7")
lam_vec7 = [lam1_7] + [lam7] * 3

# ============================================================
# BASIS VECTORS
# ============================================================
u = [
    sp.Matrix([pi1, pi2, pi3, pi4]),
    sp.Matrix([pi1, pi2, pi3, -pi123]),
    sp.Matrix([pi1, pi2, -pi12, 0]),
    sp.Matrix([pi1, -pi1, 0, 0]),
]
u1, u2, u3, u4 = u

def ip_pi(x, y):
    """Return sum over j = 0..3 of x[j] * y[j] / pi[j]."""
    acc = 0
    for j in range(4):
        acc = acc + (x[j] * y[j]) / pi[j]
    return acc

def e_u(j, ui):
    """Return ``ui[j] / pi[j]`` for the 0-based component ``j``."""
    numerator = ui[j]
    return numerator / pi[j]

# norm[i] = simplified ip_pi(u[i], u[i]) for each basis vector.
norm = [sp.simplify(ip_pi(basis_vec, basis_vec)) for basis_vec in u]

# ============================================================
# LIGHT SIMPLIFICATION (constraints only)
# ============================================================
SUBS_RULES = {
    pisum: 1,            # pi1 + pi2 + pi3 + pi4 -> 1
    pi123: 1 - pi4,      # pi1 + pi2 + pi3 -> 1 - pi4
    (pi1 + pi2): pi12,   # pi12 is defined as pi1 + pi2
}

def simp_light(expr):
    """Apply SUBS_RULES, then put ``expr`` over a common denominator and cancel."""
    substituted = expr.subs(SUBS_RULES)
    return sp.cancel(sp.together(substituted))

# ============================================================
# FILTERS
# ============================================================
def keep_idx(idx):
    """Return False only when EXCLUDE_STATE_1 is set and ``idx`` contains state 1."""
    return (not EXCLUDE_STATE_1) or (1 not in idx)

def is_nan_like(expr):
    """Return a truthy value when ``expr`` is SymPy's nan or contains it."""
    if (expr is sp.nan) or (expr == sp.nan):
        return True
    # Objects without a ``has`` method cannot contain nan as a sub-expression.
    return hasattr(expr, "has") and expr.has(sp.nan)

# ============================================================
# p^0 (internal)
# ============================================================
def p0_n_raw(idx):
    """Return the unsimplified p^0 coefficient for the 1-based label tuple ``idx``.

    The numerator is the sum over j = 0..3 of pi[j] times the product of
    e_u(j, u^i) over the labels i in ``idx``; the denominator is the product
    of norm[i - 1] over the same labels.
    """
    norm_product = 1
    for label in idx:
        norm_product = norm_product * norm[label - 1]

    total = 0
    for j in range(4):
        factor = 1
        for label in idx:
            factor = factor * e_u(j, u[label - 1])
        total += pi[j] * factor
    return total / norm_product

def p0_tripod_full():
    """Return ``simp_light(p0_n_raw(t))`` for every 3-tuple ``t`` over states 1..4."""
    return {
        triple: simp_light(p0_n_raw(triple))
        for triple in product([1, 2, 3, 4], repeat=3)
    }

def p0_5_only(states_out):
    """Return ``simp_light(p0_n_raw(t))`` for every 5-tuple ``t`` over ``states_out``."""
    return {
        labels: simp_light(p0_n_raw(labels))
        for labels in product(states_out, repeat=5)
    }

# ============================================================
# \bar{q} (internal)
# ============================================================
def barq_quartet_edge6_full(p0T):
    """Return bar q for every 4-tuple over states 1..4, glued along edge 6.

    Each entry is the sum over y in 1..4 of
    ``norm[y-1] * lam_vec6[y-1] * p0T[(i1, i2, y)] * p0T[(y, i3, i4)]``,
    passed through ``simp_light``.
    """
    table = {}
    for quad in product([1, 2, 3, 4], repeat=4):
        a, b, c, d = quad
        glued = sum(
            norm[y - 1] * lam_vec6[y - 1] * p0T[(a, b, y)] * p0T[(y, c, d)]
            for y in [1, 2, 3, 4]
        )
        table[quad] = simp_light(glued)
    return table

def barq_5leaves_only(states_out, barq_T1, p0T):
    """Return bar q for every 5-tuple over ``states_out``, glued along edge 7.

    Each entry is the sum over x in 1..4 of
    ``norm[x-1] * lam_vec7[x-1] * barq_T1[(i1, i2, i3, x)] * p0T[(x, i4, i5)]``,
    passed through ``simp_light``.
    """
    table = {}
    for quint in product(states_out, repeat=5):
        a, b, c, d, e = quint
        glued = sum(
            norm[x - 1] * lam_vec7[x - 1] * barq_T1[(a, b, c, x)] * p0T[(x, d, e)]
            for x in [1, 2, 3, 4]
        )
        table[quint] = simp_light(glued)
    return table

# ============================================================
# \tilde{q} = \bar{q} / p^0 (light)
# ============================================================
def tilde_q_only(barq5, p05):
    """Return tilde q = bar q / p^0 per index; entries with p^0 == 0 become nan."""
    ratios = {}
    for key, numerator in barq5.items():
        denominator = p05[key]
        ratios[key] = sp.nan if denominator == 0 else simp_light(numerator / denominator)
    return ratios

# ============================================================
# GROUP BY EXACT EXPRESSION (after light simp)
# ============================================================
def group_equal_expressions(expr_dict):
    """Group indices whose lightly simplified values have the same ``srepr``.

    Indices rejected by ``keep_idx``, nan-like values and (when SKIP_ZEROS is
    set) zero values are left out.

    Returns:
        list of ``(sorted_indices, value)`` pairs, ordered by the first index
        that was inserted into each group.
    """
    buckets = {}
    for idx, raw in expr_dict.items():
        if not keep_idx(idx) or is_nan_like(raw):
            continue

        value = simp_light(raw)
        if SKIP_ZEROS and value == 0:
            continue

        # srepr gives a structural key, so only identical expression trees merge.
        signature = sp.srepr(value)
        if signature not in buckets:
            buckets[signature] = (value, [])
        buckets[signature][1].append(idx)

    ordered = sorted(buckets.values(), key=lambda bucket: bucket[1][0])
    return [(sorted(members), value) for value, members in ordered]

# ============================================================
# CONSOLE PRINT HELPERS
# ============================================================
def tex_idx(idx):
    """Concatenate the string forms of the entries of ``idx`` without separators."""
    return "".join(map(str, idx))

def lhs_chain(symbol_tex, idxs, chunk_size=MAX_LHS_PER_LINE):
    r"""Return ``(tex, count)`` for the chain ``sym_{i} = sym_{j} = ...``.

    The symbols are joined with " = " in rows of at most ``chunk_size``.
    A single row is returned as is; any other number of rows is wrapped in
    ``\substack{...}`` with ``\\`` between rows.  ``count`` is the number of
    indices.
    """
    labels = [rf"{symbol_tex}_{{{tex_idx(idx)}}}" for idx in idxs]
    rows = []
    for start in range(0, len(labels), chunk_size):
        rows.append(" = ".join(labels[start:start + chunk_size]))

    total = len(labels)
    if len(rows) != 1:
        return r"\substack{" + r"\\ ".join(rows) + r"}", total
    return rows[0], total

def split_fraction_to_tex_lines(expr, terms_per_frac=FRAC_TERMS_PER_LINE, max_fracs=FRAC_MAX_FRACTIONS):
    r"""Split ``expr`` into several ``\frac`` strings sharing one denominator.

    Returns:
        a list of LaTeX strings (every string after the first is prefixed
        with "+ "), or ``None`` when the denominator is 1 or -1, when the
        numerator fits into a single fraction, or when more than
        ``max_fracs`` fractions would be needed.
    """
    num, den = sp.fraction(sp.together(simp_light(expr)))

    if den == 1 or den == -1:
        return None

    if isinstance(num, sp.Add):
        terms = sp.Add.make_args(num)
    else:
        terms = (num,)

    pieces = [terms[start:start + terms_per_frac] for start in range(0, len(terms), terms_per_frac)]
    if not (1 < len(pieces) <= max_fracs):
        return None

    den_tex = sp.latex(den)
    lines = []
    for piece in pieces:
        numerator_tex = sp.latex(sp.Add(*piece))
        fraction_tex = rf"\frac{{{numerator_tex}}}{{{den_tex}}}"
        # Only continuation lines get the leading plus sign.
        lines.append(r"+ " + fraction_tex if lines else fraction_tex)
    return lines

def rhs_lines(expr):
    """Return the right-hand side of an equality as a list of LaTeX lines.

    With SPLIT_LONG_FRACTIONS set, the split form is used whenever
    ``split_fraction_to_tex_lines`` provides one; otherwise the result is a
    single line.
    """
    expr = simp_light(expr)
    split = split_fraction_to_tex_lines(expr) if SPLIT_LONG_FRACTIONS else None
    return [sp.latex(expr)] if split is None else split

def print_aligned_block(symbol_tex, idxs, rhs_expr):
    r"""Print one group of equal entries as a LaTeX display with an aligned block.

    Args:
        symbol_tex: LaTeX symbol used for every left-hand-side entry.
        idxs: indices whose entries share the value ``rhs_expr``.
        rhs_expr: the common value; rendered through ``rhs_lines``.

    For groups with at least SMALL_LHS_THRESHOLD entries (and
    USE_SMALL_FOR_LONG set) the display is wrapped in ``\small`` ...
    ``\normalsize``.  Both commands are printed outside ``\[ ... \]``
    because LaTeX rejects font-size commands in math mode.
    """
    lhs, n_items = lhs_chain(symbol_tex, idxs)
    rhs = rhs_lines(rhs_expr)

    use_small = USE_SMALL_FOR_LONG and (n_items >= SMALL_LHS_THRESHOLD)

    if use_small:
        print(r"\small")
    print(r"\[")

    print(r"\begin{aligned}")
    print(rf"{lhs} &= {rhs[0]}")
    for continuation in rhs[1:]:
        print(rf"&\quad {continuation}")
    print(r"\end{aligned}")

    print(r"\]")
    if use_small:
        print(r"\normalsize")

# ============================================================
# MAIN (PRINT ONLY \tilde{q} GROUPS)
# ============================================================
def main():
    """Compute tilde q for the 5-leaf tree and print each group of equal entries."""
    if EXCLUDE_STATE_1:
        states_out = [2, 3, 4]
    else:
        states_out = [1, 2, 3, 4]

    # The tripod and quartet tables always range over all four states.
    tripod = p0_tripod_full()
    quartet = barq_quartet_edge6_full(tripod)

    p0_table = p0_5_only(states_out)
    barq_table = barq_5leaves_only(states_out, quartet, tripod)
    tildeq_table = tilde_q_only(barq_table, p0_table)

    for members, value in group_equal_expressions(tildeq_table):
        print_aligned_block(r"\tilde{q}", members, value)
        print()

if __name__ == "__main__":
    main()


# Equalities 4-states

In [None]:
import sympy as sp
from functools import reduce, lru_cache

# ============================================================
# SETTINGS
# ============================================================
# When True, keep_idx rejects any index that contains state 1.
EXCLUDE_STATE_1 = True
# When False, simp_light returns its argument unchanged.
LIGHT_SIMPLIFY = True

# ============================================================
# SYMBOLS (4-states fixed)
# ============================================================
# The four pi symbols carry the nonzero=True assumption.
pi = list(sp.symbols("pi1 pi2 pi3 pi4", nonzero=True))
pi1, pi2, pi3, pi4 = pi

# Partial sums of the pi symbols.
pi12 = pi1 + pi2
pi123 = pi12 + pi3
pisum = pi123 + pi4

# One lambda vector per edge (6, 7, 8): a separate symbol for the first
# basis vector and a shared symbol for the other three.
lam1_6, lam6 = sp.Symbol("lambda_1^6"), sp.Symbol("lambda^6")
lam_vec6 = [lam1_6] + [lam6] * 3

lam1_7, lam7 = sp.Symbol("lambda_1^7"), sp.Symbol("lambda^7")
lam_vec7 = [lam1_7] + [lam7] * 3

lam1_8, lam8 = sp.Symbol("lambda_1^8"), sp.Symbol("lambda^8")
lam_vec8 = [lam1_8] + [lam8] * 3

# ============================================================
# BASIS (your working choice)
# ============================================================
u = [
    sp.Matrix([pi1, pi2, pi3, pi4]),
    sp.Matrix([pi1, pi2, pi3, -pi123]),
    sp.Matrix([pi1, pi2, -pi12, 0]),
    sp.Matrix([pi1, -pi1, 0, 0]),
]
u1, u2, u3, u4 = u

def ip_pi(x, y):
    """Return the inner product of ``x`` and ``y`` with weights 1 / pi[j], j = 0..3."""
    result = 0
    for j in range(4):
        result = result + (x[j] * y[j]) / pi[j]
    return result

def e_u(j, ui):
    """Return the 0-based component ``j`` of ``ui`` divided by ``pi[j]``."""
    value = ui[j]
    return value / pi[j]

# norm[i] = simplified ip_pi(u[i], u[i]); indexed 0-based, so state s uses norm[s - 1].
norm = [sp.simplify(ip_pi(basis_vec, basis_vec)) for basis_vec in u]

# ============================================================
# LIGHT SIMPLIFICATION (your intended rules)
# ============================================================
SUBS_RULES = {
    pisum: 1,            # pi1 + pi2 + pi3 + pi4 -> 1
    pi123: 1 - pi4,      # pi1 + pi2 + pi3 -> 1 - pi4
    (pi1 + pi2): pi12,   # pi12 is defined as pi1 + pi2
}

def simp_light(expr):
    """Apply SUBS_RULES, combine over a common denominator, cancel, and substitute again.

    Returns ``expr`` untouched when LIGHT_SIMPLIFY is false.
    """
    if LIGHT_SIMPLIFY:
        reduced = sp.cancel(sp.together(expr.subs(SUBS_RULES)))
        # Second pass: cancel may expose sums the first substitution could not match.
        return reduced.subs(SUBS_RULES)
    return expr

# ============================================================
# Index helpers
# ============================================================
def parse_idx(s: str, max_state: int = 4):
    """Parse a digit string such as "2244" into a tuple of 1-based states.

    Spaces are ignored.  Each digit must lie in ``1..max_state``.  The
    lookups ``norm[i - 1]`` and ``u[i - 1]`` would otherwise treat 0 as the
    last state (negative list index) and raise IndexError for states above 4.

    Args:
        s: the index as typed by the user, one digit per leaf.
        max_state: largest admissible state (4 for this fixed 4-state cell).

    Returns:
        tuple of ints, one per digit.

    Raises:
        ValueError: if ``s`` is empty, contains a non-digit, or contains a
            state outside ``1..max_state``.
    """
    s = s.strip().replace(" ", "")
    if not s or not all(ch.isdigit() for ch in s):
        raise ValueError("Index must be digits like 2244, 22444, 423344.")
    idx = tuple(int(ch) for ch in s)
    if any(not 1 <= state <= max_state for state in idx):
        raise ValueError(f"Index states must lie in 1..{max_state}; got {s!r}.")
    return idx

def keep_idx(idx):
    """Return True unless EXCLUDE_STATE_1 is set and ``idx`` contains state 1."""
    return (not EXCLUDE_STATE_1) or (1 not in idx)

# ============================================================
# p0_bar (cached)
# ============================================================
@lru_cache(maxsize=None)
def p0_bar(idx):
    """Return the lightly simplified p^0 coefficient for the 1-based tuple ``idx``.

    Results are cached per index tuple.
    """
    norm_product = 1
    for label in idx:
        norm_product = norm_product * norm[label - 1]

    total = 0
    for j in range(4):
        factor = 1
        for label in idx:
            factor = factor * e_u(j, u[label - 1])
        total += pi[j] * factor
    return simp_light(total / norm_product)

# ============================================================
# barq / tildeq
# ============================================================
@lru_cache(maxsize=None)
def barq4(idx4):
    """Return bar q for a 4-tuple: two tripods glued along edge 6 (cached)."""
    a, b, c, d = idx4
    glued = sum(
        norm[y - 1] * lam_vec6[y - 1] * p0_bar((a, b, y)) * p0_bar((y, c, d))
        for y in [1, 2, 3, 4]
    )
    return simp_light(glued)

@lru_cache(maxsize=None)
def tildeq4(idx4):
    """Return barq4 / p0_bar for ``idx4``, or nan when p0_bar is 0 (cached)."""
    denominator = p0_bar(idx4)
    return sp.nan if denominator == 0 else simp_light(barq4(idx4) / denominator)

@lru_cache(maxsize=None)
def barq5(idx5):
    """Return bar q for a 5-tuple: barq4 glued to a tripod along edge 7 (cached)."""
    a, b, c, d, e = idx5
    glued = sum(
        norm[x - 1] * lam_vec7[x - 1] * barq4((a, b, c, x)) * p0_bar((x, d, e))
        for x in [1, 2, 3, 4]
    )
    return simp_light(glued)

@lru_cache(maxsize=None)
def tildeq5(idx5):
    """Return barq5 / p0_bar for ``idx5``, or nan when p0_bar is 0 (cached)."""
    denominator = p0_bar(idx5)
    return sp.nan if denominator == 0 else simp_light(barq5(idx5) / denominator)

@lru_cache(maxsize=None)
def barq6(idx6):
    """Return bar q for a 6-tuple: barq5 glued to a tripod along edge 8 (cached)."""
    a, b, c, d, e, f = idx6
    glued = sum(
        norm[x - 1] * lam_vec8[x - 1] * barq5((a, b, c, d, x)) * p0_bar((x, e, f))
        for x in [1, 2, 3, 4]
    )
    return simp_light(glued)

@lru_cache(maxsize=None)
def tildeq6(idx6):
    """Return barq6 / p0_bar for ``idx6``, or nan when p0_bar is 0 (cached)."""
    denominator = p0_bar(idx6)
    return sp.nan if denominator == 0 else simp_light(barq6(idx6) / denominator)

# ============================================================
# Equality under simplex (FIXED Groebner domain)
# ============================================================
def equal_under_simplex(expr):
    """Test whether ``expr`` vanishes under the constraint pisum = 1.

    The numerator of ``expr`` is expanded and reduced modulo the ideal
    generated by ``pisum - 1``.  The Groebner basis uses domain='EX' so that
    coefficients may involve the lambda symbols.

    Returns:
        ``(is_zero, remainder, denominator)`` where ``remainder`` is the
        factored remainder of the numerator and ``is_zero`` is
        ``remainder == 0``.
    """
    combined = sp.together(simp_light(expr))
    numerator, denominator = sp.fraction(combined)

    # Groebner basis over the pi variables; lambdas live in the coefficient domain.
    basis = sp.groebner([pisum - 1], pi1, pi2, pi3, pi4, order="lex", domain="EX")

    remainder = sp.factor(basis.reduce(sp.expand(numerator))[1])
    return (remainder == 0), remainder, denominator

# ============================================================
# Interactive compare
# ============================================================
def main():
    """Interactively compare two tilde-q entries of the same size.

    Reads the object type (quartet / 5 / 6) and two digit-string indices,
    evaluates the corresponding tilde-q values and reports whether they are
    equal, both after light simplification and modulo pi1+pi2+pi3+pi4 = 1.

    Raises:
        ValueError: for an unknown mode, an index with a state outside 1..4,
            an index containing state 1 while EXCLUDE_STATE_1 is set, or an
            index whose length does not match the mode.
    """
    mode = input("Choose object (quartet / 5 / 6): ").strip().lower()
    if mode not in ["quartet", "5", "6"]:
        raise ValueError("Mode must be one of: quartet, 5, 6.")

    # The example indices only use states 1..4, the range this cell supports.
    s1 = input("Enter index #1 (e.g. 2244 / 22444 / 423344): ").strip()
    s2 = input("Enter index #2: ").strip()

    idx1 = parse_idx(s1)
    idx2 = parse_idx(s2)

    # norm[i - 1] / u[i - 1] would read state 0 as the last state (negative
    # index) and fail with IndexError above 4, so reject both up front.
    for idx in (idx1, idx2):
        if any(not 1 <= state <= 4 for state in idx):
            raise ValueError("Index rejected: states must be digits 1-4.")

    if not keep_idx(idx1) or not keep_idx(idx2):
        raise ValueError("Index rejected: it contains state 1 while EXCLUDE_STATE_1=True.")

    # mode -> (required index length, evaluator)
    evaluators = {"quartet": (4, tildeq4), "5": (5, tildeq5), "6": (6, tildeq6)}
    length, evaluate = evaluators[mode]
    if len(idx1) != length or len(idx2) != length:
        raise ValueError(f"{mode} requires length-{length} indices.")

    v1 = evaluate(idx1)
    v2 = evaluate(idx2)

    print("\n--- Results ---")
    print(f"tilde({s1}) =\n{v1}")
    print(f"\ntilde({s2}) =\n{v2}")

    # The evaluators return nan when p0 is 0; such a value cannot be compared
    # and must not be passed to the polynomial reduction.
    undefined = [s for s, v in ((s1, v1), (s2, v2)) if v is sp.nan or v.has(sp.nan)]
    if undefined:
        print("\n--- Equality check ---")
        print("Skipped: undefined (p0 = 0) for index " + ", ".join(undefined))
        return

    diff = v1 - v2

    naive_is_zero = (simp_light(diff) == 0)
    is_zero, rem, den = equal_under_simplex(diff)

    print("\n--- Equality check ---")
    print("Naive (diff==0 after light simp) ->", naive_is_zero)
    print("Under constraint pi1+pi2+pi3+pi4=1 ->", is_zero)

    if not is_zero:
        print("\nRemainder of numerator mod <pisum-1>:")
        print(rem)
        print("\nDenominator (for reference):")
        print(den)

if __name__ == "__main__":
    main()


Choose object (quartet / 5 / 6): 5
Enter index #1 (e.g. 2244 / 22444 / 425544): 24444
Enter index #2: 34444

--- Results ---
tilde(24444) =
(-lambda^6*lambda^7*pi1**2*pi2*pi4 + lambda^6*lambda^7*pi1**2*pi4 - lambda^6*lambda^7*pi1**2 - lambda^6*lambda^7*pi1*pi2**2*pi4 - lambda^6*lambda^7*pi1*pi2*pi3 - 2*lambda^6*lambda^7*pi1*pi2*pi4 + 2*lambda^6*lambda^7*pi1*pi2 + lambda^6*lambda^7*pi2**2*pi4 - lambda^6*lambda^7*pi2**2 + lambda^6*lambda_1^7*pi1**2*pi2*pi4 - lambda^6*lambda_1^7*pi1**2*pi2 + lambda^6*lambda_1^7*pi1*pi2**2*pi4 - lambda^6*lambda_1^7*pi1*pi2**2)/(pi1**2*pi4 - pi1**2 - pi1*pi2*pi4 + pi1*pi2 + pi2**2*pi4 - pi2**2)

tilde(34444) =
(-lambda^6*lambda^7*pi1**2*pi2*pi4 + lambda^6*lambda^7*pi1**2*pi4 - lambda^6*lambda^7*pi1**2 - lambda^6*lambda^7*pi1*pi2**2*pi4 - lambda^6*lambda^7*pi1*pi2*pi3 - 2*lambda^6*lambda^7*pi1*pi2*pi4 + 2*lambda^6*lambda^7*pi1*pi2 + lambda^6*lambda^7*pi2**2*pi4 - lambda^6*lambda^7*pi2**2 + lambda^6*lambda_1^7*pi1**2*pi2*pi4 - lambda^6*lambda_1^7*pi1**2*pi2 +

# Equalities k-states

In [None]:
import sympy as sp
from functools import lru_cache

# ============================================================
# SETTINGS
# ============================================================
# When True, keep_idx rejects any index that contains state 1.
EXCLUDE_STATE_1 = True
# When False, the simp_light helper inside main returns its argument unchanged.
LIGHT_SIMPLIFY = True

# ============================================================
# Helpers
# ============================================================
def parse_idx(s: str, max_state=None):
    """Parse a digit string such as "2244" into a tuple of 1-based states.

    Spaces are ignored and every character becomes one state, so only the
    states 1..9 can be written.  State 0 is always rejected because states
    are 1-based.

    Args:
        s: the index as typed by the user, one digit per leaf.
        max_state: optional largest admissible state; ``None`` skips the
            upper-bound check.

    Returns:
        tuple of ints, one per digit.

    Raises:
        ValueError: if ``s`` is empty, contains a non-digit, contains 0, or
            contains a state above ``max_state``.
    """
    s = s.strip().replace(" ", "")
    if not s or not all(ch.isdigit() for ch in s):
        raise ValueError("Index must be digits like 2244, 22444, 425544.")
    idx = tuple(int(ch) for ch in s)
    if 0 in idx:
        raise ValueError(f"Index states are 1-based; 0 is not a valid state (got {s!r}).")
    if max_state is not None and max(idx) > max_state:
        raise ValueError(f"Index states must lie in 1..{max_state}; got {s!r}.")
    return idx

def keep_idx(idx):
    """Return True unless EXCLUDE_STATE_1 is set and ``idx`` contains state 1."""
    return (not EXCLUDE_STATE_1) or (1 not in idx)

# ============================================================
# MAIN
# ============================================================
def main():
    """Interactively compare two tilde-q entries for a user-chosen number of states.

    Reads kappa (K), the object type (quartet / 5 / 6) and two digit-string
    indices, evaluates the corresponding tilde-q values and reports whether
    they are equal, both after light simplification and modulo sum(pi) = 1.
    Indices are parsed one digit per state, so only states 1..9 can be typed.

    Raises:
        ValueError: for K < 2, an unknown mode, an index with a state outside
            1..K, an index containing state 1 while EXCLUDE_STATE_1 is set,
            or an index whose length does not match the mode.
    """
    K = int(input("Enter kappa (number of states), e.g. 4: ").strip())
    if K < 2:
        raise ValueError("kappa must be >= 2")

    pi = sp.symbols(" ".join([f"pi{i}" for i in range(1, K + 1)]), nonzero=True)
    PI_SUM = sum(pi)

    def pisum(m):
        """Return pi_1 + ... + pi_m."""
        return sum(pi[:m])

    def simp_light(expr):
        """Substitute sum(pi) -> 1, combine over a common denominator and cancel."""
        if not LIGHT_SIMPLIFY:
            return expr
        expr = expr.subs({PI_SUM: 1})
        expr = sp.together(expr)
        expr = sp.cancel(expr)
        return expr

    @lru_cache(maxsize=None)
    def e_u(j, i):
        """Return the factor for component j and state i (both 1-based).

        It is 1 when i == 1 or j + i <= K + 1, equals
        -(pi_1 + ... + pi_{K+1-i}) / pi_{K+2-i} when j == K + 2 - i,
        and is 0 otherwise.
        """
        if i == 1:
            return sp.Integer(1)
        if j + i <= K + 1:
            return sp.Integer(1)
        if j == K + 2 - i:
            return -pisum(K + 1 - i) / pi[(K + 2 - i) - 1]
        return sp.Integer(0)

    @lru_cache(maxsize=None)
    def norm_u(i):
        """Return the norm factor of state i: 1 for i == 1, otherwise
        pisum(K+1-i) * pisum(K+2-i) / pi_{K+2-i} after light simplification."""
        if i == 1:
            return sp.Integer(1)
        a = K + 1 - i
        b = K + 2 - i
        denom = pi[(K + 2 - i) - 1]
        return simp_light(pisum(a) * pisum(b) / denom)

    @lru_cache(maxsize=None)
    def p0_bar(idx):
        """Return the lightly simplified p^0 coefficient for the state tuple ``idx``."""
        denom = sp.Integer(1)
        for ir in idx:
            denom *= norm_u(ir)

        s = sp.Integer(0)
        for t in range(1, K + 1):
            prod_term = sp.Integer(1)
            for ir in idx:
                prod_term *= e_u(t, ir)
            s += pi[t - 1] * prod_term

        return simp_light(s / denom)

    def make_lam_vec(edge_label: str):
        """Return [lambda_1^edge] followed by K - 1 copies of lambda^edge."""
        lam1 = sp.Symbol(f"lambda_1^{edge_label}")
        lam  = sp.Symbol(f"lambda^{edge_label}")
        return [lam1] + [lam] * (K - 1)

    lam_vec6 = make_lam_vec("6")
    lam_vec7 = make_lam_vec("7")
    lam_vec8 = make_lam_vec("8")

    @lru_cache(maxsize=None)
    def barq4(idx4):
        """bar q for a 4-tuple: two tripods glued along edge 6."""
        i1, i2, i3, i4 = idx4
        s = sp.Integer(0)
        for y in range(1, K + 1):
            s += norm_u(y) * lam_vec6[y - 1] * p0_bar((i1, i2, y)) * p0_bar((y, i3, i4))
        return simp_light(s)

    @lru_cache(maxsize=None)
    def tildeq4(idx4):
        """barq4 / p0_bar, or nan when p0_bar is 0."""
        den = p0_bar(idx4)
        return sp.nan if den == 0 else simp_light(barq4(idx4) / den)

    @lru_cache(maxsize=None)
    def barq5(idx5):
        """bar q for a 5-tuple: barq4 glued to a tripod along edge 7."""
        i1, i2, i3, i4, i5 = idx5
        s = sp.Integer(0)
        for x in range(1, K + 1):
            s += norm_u(x) * lam_vec7[x - 1] * barq4((i1, i2, i3, x)) * p0_bar((x, i4, i5))
        return simp_light(s)

    @lru_cache(maxsize=None)
    def tildeq5(idx5):
        """barq5 / p0_bar, or nan when p0_bar is 0."""
        den = p0_bar(idx5)
        return sp.nan if den == 0 else simp_light(barq5(idx5) / den)

    @lru_cache(maxsize=None)
    def barq6(idx6):
        """bar q for a 6-tuple: barq5 glued to a tripod along edge 8."""
        i1, i2, i3, i4, i5, i6 = idx6
        s = sp.Integer(0)
        for x in range(1, K + 1):
            s += norm_u(x) * lam_vec8[x - 1] * barq5((i1, i2, i3, i4, x)) * p0_bar((x, i5, i6))
        return simp_light(s)

    @lru_cache(maxsize=None)
    def tildeq6(idx6):
        """barq6 / p0_bar, or nan when p0_bar is 0."""
        den = p0_bar(idx6)
        return sp.nan if den == 0 else simp_light(barq6(idx6) / den)

    def equal_under_simplex(expr):
        """Reduce the numerator of ``expr`` modulo <sum(pi) - 1>.

        Returns ``(remainder == 0, factored remainder, denominator)``.
        """
        expr = simp_light(expr)
        expr = sp.together(expr)
        num, den = sp.fraction(expr)
        num = sp.expand(num)

        G = sp.groebner([PI_SUM - 1], *pi, order="lex", domain="EX")
        rem = G.reduce(num)[1]
        rem = sp.factor(rem)
        return (rem == 0), rem, den

    mode = input("Choose object (quartet / 5 / 6): ").strip().lower()
    if mode not in ["quartet", "5", "6"]:
        raise ValueError("Mode must be one of: quartet, 5, 6.")

    s1 = input("Enter index #1 (e.g. 2244 / 22444 / 425544): ").strip()
    s2 = input("Enter index #2: ").strip()

    idx1 = parse_idx(s1)
    idx2 = parse_idx(s2)

    # A state above K gives norm_u == 0 (division by zero in p0_bar) and
    # state 0 indexes past the end of pi, so reject both before evaluating.
    for idx in (idx1, idx2):
        if any(not 1 <= state <= K for state in idx):
            raise ValueError(f"Index rejected: states must lie between 1 and {K}.")

    if not keep_idx(idx1) or not keep_idx(idx2):
        raise ValueError("Index rejected: it contains state 1 while EXCLUDE_STATE_1=True.")

    def tilde_value(idx):
        """Evaluate tilde q for ``idx`` according to the chosen mode."""
        if mode == "quartet":
            if len(idx) != 4:
                raise ValueError("quartet requires length-4 indices.")
            return tildeq4(idx)
        if mode == "5":
            if len(idx) != 5:
                raise ValueError("5 requires length-5 indices.")
            return tildeq5(idx)
        if mode == "6":
            if len(idx) != 6:
                raise ValueError("6 requires length-6 indices.")
            return tildeq6(idx)
        raise ValueError("Internal error.")

    v1 = tilde_value(idx1)
    v2 = tilde_value(idx2)

    print("\n--- Results ---")
    print(f"tilde({s1}) =\n{v1}")
    print(f"\ntilde({s2}) =\n{v2}")

    # tilde q is nan when p0 is 0; such a value cannot be compared and must
    # not be passed to the polynomial reduction.
    undefined = [s for s, v in ((s1, v1), (s2, v2)) if v is sp.nan or v.has(sp.nan)]
    if undefined:
        print("\n--- Equality check ---")
        print("Skipped: undefined (p0 = 0) for index " + ", ".join(undefined))
        return

    diff = v1 - v2

    naive_is_zero = (simp_light(diff) == 0)
    is_zero, rem, den = equal_under_simplex(diff)

    print("\n--- Equality check ---")
    print("Naive (after sum(pi)=1) ->", naive_is_zero)
    print("Under sum(pi)=1 (Groebner remainder) ->", is_zero)

    if not is_zero:
        print("\nRemainder of numerator mod <sum(pi)-1>:")
        print(rem)
        print("\nDenominator (for reference):")
        print(den)

if __name__ == "__main__":
    main()


Enter kappa (number of states), e.g. 4: 5
Choose object (quartet / 5 / 6): 6
Enter index #1 (e.g. 2244 / 22444 / 425544): 442355
Enter index #2: 553244

--- Results ---
tilde(442355) =
(-lambda^6*lambda^7*lambda^8*pi1**4 - 4*lambda^6*lambda^7*lambda^8*pi1**3*pi2 - 3*lambda^6*lambda^7*lambda^8*pi1**3*pi3 - 2*lambda^6*lambda^7*lambda^8*pi1**3*pi4 - 6*lambda^6*lambda^7*lambda^8*pi1**2*pi2**2 - 9*lambda^6*lambda^7*lambda^8*pi1**2*pi2*pi3 - 6*lambda^6*lambda^7*lambda^8*pi1**2*pi2*pi4 - 3*lambda^6*lambda^7*lambda^8*pi1**2*pi3**2 - 4*lambda^6*lambda^7*lambda^8*pi1**2*pi3*pi4 - lambda^6*lambda^7*lambda^8*pi1**2*pi4**2 + lambda^6*lambda^7*lambda^8*pi1**2*pi5**2 + lambda^6*lambda^7*lambda^8*pi1**2*pi5 - 2*lambda^6*lambda^7*lambda^8*pi1**2 - 4*lambda^6*lambda^7*lambda^8*pi1*pi2**3 - 9*lambda^6*lambda^7*lambda^8*pi1*pi2**2*pi3 - 6*lambda^6*lambda^7*lambda^8*pi1*pi2**2*pi4 - 6*lambda^6*lambda^7*lambda^8*pi1*pi2*pi3**2 - 8*lambda^6*lambda^7*lambda^8*pi1*pi2*pi3*pi4 - 2*lambda^6*lambda^7*lambda^8*pi1