<a href="https://colab.research.google.com/github/Gokcentunc/atraction/blob/main/5_leaves.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 5-leaves-equality

In [None]:
import sympy as sp
from itertools import product, permutations
from functools import reduce

# ============================================================
# SETTINGS
# ============================================================
# When True, state 1 is dropped: 5-leaf tables are built over states [2, 3, 4]
# only, and a LHS index containing 1 is rejected by main().
EXCLUDE_STATE_1 = True
# When True, entries whose simplified tilde-q value is exactly 0 are removed
# before searching (they can be neither LHS nor right-hand-side terms).
SKIP_ZEROS = True
MAX_EQUATIONS = 30                 # print at most this many relations
SYMBOL_TEX = r"\tilde{q}"          # output symbol in TeX

# If True: first checks a small subset of perms, then all S5
USE_FAST_THEN_FULL_PERM_CHECK = True
FAST_PERM_COUNT = 20               # size of the pre-screen subset of S5

# ============================================================
# SYMBOLS
# ============================================================
# Per-state weights; presumably a stationary distribution (SUBS_RULES
# normalises their sum to 1). Declared nonzero because ip_pi and e_u divide
# by them.
pi1, pi2, pi3, pi4 = sp.symbols("pi1 pi2 pi3 pi4", nonzero=True)
pi = [pi1, pi2, pi3, pi4]

# Partial sums used by the basis vectors and by SUBS_RULES.
pi12  = pi1 + pi2
pi123 = pi1 + pi2 + pi3
pisum = pi1 + pi2 + pi3 + pi4

# Weights of the two internal summations: the *_6 symbols weight the sum over
# y in barq_quartet_edge6_full, the *_7 symbols the sum over x in
# barq_5leaves_only. State 1 has its own symbol; states 2-4 share one.
lam1_6 = sp.Symbol("lambda_1^6")
lam6   = sp.Symbol("lambda^6")
lam1_7 = sp.Symbol("lambda_1^7")
lam7   = sp.Symbol("lambda^7")

lam_vec6 = [lam1_6, lam6, lam6, lam6]
lam_vec7 = [lam1_7, lam7, lam7, lam7]

# We eliminate lambdas by collapsing products into A,B,C
A = sp.Symbol("A")  # lambda_1^6 * lambda^7
B = sp.Symbol("B")  # lambda_1^7 * lambda^6
C = sp.Symbol("C")  # lambda^6 * lambda^7

# ============================================================
# LIGHT SIMPLIFICATION
# ============================================================
# Rewrite rules encoding the normalisation pi1 + pi2 + pi3 + pi4 = 1.
# pi123 gets its own rule, presumably because sympy's subs does not turn a
# partial sum such as pi1 + pi2 + pi3 into 1 - pi4 from the pisum rule alone.
SUBS_RULES = {pisum: 1, pi123: 1 - pi4}

def simp_light(expr):
    """Normalise *expr* cheaply into one cancelled rational function.

    SUBS_RULES is applied first (pi1 + ... + pi4 -> 1). The result is then
    put over a common denominator and common factors are cancelled. This
    form is what the rest of the code compares structurally against 0.
    """
    substituted = sp.sympify(expr).subs(SUBS_RULES)
    return sp.cancel(sp.together(substituted))

def keep_idx(idx):
    """Return True if index tuple *idx* survives the EXCLUDE_STATE_1 filter."""
    if not EXCLUDE_STATE_1:
        return True
    return 1 not in idx

def is_nan_like(expr):
    """Return True if *expr* is sympy NaN or contains NaN anywhere inside it."""
    if expr is sp.nan or expr == sp.nan:
        return True
    # Objects without .has() (plain Python values) cannot contain NaN.
    return hasattr(expr, "has") and expr.has(sp.nan)

# ============================================================
# BASIS VECTORS
# ============================================================
# Four vectors, mutually orthogonal under ip_pi (<x, y>_pi = sum_j x_j y_j / pi_j);
# e.g. <u1, u2>_pi = pi1 + pi2 + pi3 - pi123 = 0 and <u3, u4>_pi = pi1 - pi1 = 0.
# u1 is pi itself, so e_u(j, u1) == 1 for every state j.
u1 = sp.Matrix([pi1, pi2, pi3, pi4])
u2 = sp.Matrix([pi1, pi2, pi3, -pi123])
u3 = sp.Matrix([pi1, pi2, -pi12, 0])
u4 = sp.Matrix([pi1, -pi1, 0, 0])
u  = [u1, u2, u3, u4]  # u[i - 1] is the vector for state label i

def ip_pi(x, y):
    """Return the pi-weighted inner product sum_j x_j * y_j / pi_j over 4 states."""
    total = 0
    for j in range(4):
        total += (x[j] * y[j]) / pi[j]
    return total

def e_u(j, ui):
    """Return entry *j* (0-based) of basis vector *ui*, divided by pi_j."""
    numerator, weight = ui[j], pi[j]
    return numerator / weight

# Squared pi-norms <u_i, u_i>_pi of the basis vectors (norm[i - 1] belongs to u_i).
# After SUBS_RULES, norm[0] = pi1 + pi2 + pi3 + pi4 reduces to 1.
norm = [simp_light(ip_pi(ui, ui)) for ui in u]

# ============================================================
# p0, barq, tildeq
# ============================================================
def p0_n_raw(idx):
    """Return the unsimplified p0 value for the leaf-state tuple *idx*.

    Computes sum_j pi_j * prod_{i in idx} e_u(j, u_i), divided by
    prod_{i in idx} norm[i - 1]. State labels in *idx* are 1-based.
    """
    denom = 1
    for state in idx:
        denom = denom * norm[state - 1]

    total = 0
    for j, weight in enumerate(pi):
        term = 1
        for state in idx:
            term = term * e_u(j, u[state - 1])
        total += weight * term
    return total / denom

def p0_tripod_full():
    """Return {(i1, i2, i3): simplified p0} for all 3-leaf indices over states 1..4."""
    table = {}
    for idx in product(range(1, 5), repeat=3):
        table[idx] = simp_light(p0_n_raw(idx))
    return table

def p0_5_only(states_out):
    """Return {5-tuple: simplified p0} for every 5-leaf index over *states_out*."""
    indices = product(states_out, repeat=5)
    return dict((idx, simp_light(p0_n_raw(idx))) for idx in indices)

def barq_quartet_edge6_full(p0T):
    """Glue two tripods into 4-leaf values for every index over states 1..4.

    For (i1, i2, i3, i4) this sums, over the internal state y, the terms
    norm[y-1] * lam_vec6[y-1] * p0T[(i1, i2, y)] * p0T[(y, i3, i4)].
    Each sum is simplified with simp_light.
    """
    result = {}
    for quad in product([1, 2, 3, 4], repeat=4):
        head, tail = quad[:2], quad[2:]
        total = 0
        for y in range(1, 5):
            weight = norm[y - 1] * lam_vec6[y - 1]
            total += weight * p0T[head + (y,)] * p0T[(y,) + tail]
        result[quad] = simp_light(total)
    return result

def barq_5leaves_only(states_out, barq_T1, p0T):
    """Glue a 4-leaf table and a tripod into 5-leaf values over *states_out*.

    For (i1, ..., i5) this sums, over the internal state x, the terms
    norm[x-1] * lam_vec7[x-1] * barq_T1[(i1, i2, i3, x)] * p0T[(x, i4, i5)].
    Each sum is simplified with simp_light.
    """
    result = {}
    for leaves in product(states_out, repeat=5):
        left, right = leaves[:3], leaves[3:]
        total = 0
        for x in range(1, 5):
            weight = norm[x - 1] * lam_vec7[x - 1]
            total += weight * barq_T1[left + (x,)] * p0T[(x,) + right]
        result[leaves] = simp_light(total)
    return result

def tilde_q_only(barq5, p05):
    """Return {idx: barq5[idx] / p05[idx]} simplified; sympy NaN where p05 is 0."""
    out = {}
    for key in barq5:
        numer = barq5[key]
        denom = p05[key]
        if denom == 0:
            out[key] = sp.nan
        else:
            out[key] = simp_light(numer / denom)
    return out

# ============================================================
# Lambda-elimination: map products -> A,B,C and extract coeff vector
# vector basis: [A, B, C, 1]
# ============================================================
# Lambda products -> collapsed symbols. sympy orders commutative Mul factors
# canonically, so each reversed-order key equals its partner and the dict
# ends up with three entries.
# NOTE: lambda_1^6 * lambda_1^7 has no entry. Any such terms are left
# untouched by to_ABC and land in the constant (last) slot of coeff_vec_ABC.
PROD_MAP = {
    lam1_6*lam7: A, lam7*lam1_6: A,
    lam1_7*lam6: B, lam6*lam1_7: B,
    lam6*lam7:   C, lam7*lam6:   C,
}

def to_ABC(expr):
    """Rewrite the lambda products in *expr* as the symbols A, B, C.

    The expression is expanded so that each term exposes its two lambda
    factors. xreplace handles exact product nodes, and subs handles
    products embedded in larger terms. The result is then re-expanded and
    simplified.
    """
    expanded = sp.expand(simp_light(expr))
    collapsed = sp.expand(expanded.xreplace(PROD_MAP).subs(PROD_MAP))
    return simp_light(collapsed)

def coeff_vec_ABC(expr):
    """Return the column Matrix [coef_A, coef_B, coef_C, constant] of *expr*.

    *expr* is treated as a polynomial in A, B, C whose coefficients are
    arbitrary expressions (domain EX). Only the three linear monomials and
    the constant term are read; any other monomial is ignored.
    """
    poly = sp.Poly(to_ABC(expr), A, B, C, domain="EX")
    coeffs = dict(poly.terms())
    wanted = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 0)]
    return sp.Matrix([simp_light(coeffs.get(mono, 0)) for mono in wanted])

# ============================================================
# Solve a from vL - vB = a (vA - vB) quickly
# ============================================================
def solve_a_fast(vL, vA, vB):
    """Solve vL - vB = a * (vA - vB) for a single scalar *a*.

    *a* is read off the first component where vA - vB does not simplify to
    0, then checked against every component.

    Returns None if vA equals vB component-wise (a is undetermined) or if no
    single *a* satisfies all components.
    """
    step = vA - vB
    target = vL - vB
    pivot = next(
        (k for k in range(step.rows) if simp_light(step[k]) != 0),
        None,
    )
    if pivot is None:
        return None
    a = simp_light(target[pivot] / step[pivot])
    if all(simp_light(target[j] - a * step[j]) == 0 for j in range(step.rows)):
        return a
    return None

def is_trivial_a(a):
    """Return True when *a* simplifies to 0 or 1.

    In either case L = a*A + (1-a)*B degenerates to L = B or L = A.
    """
    simplified = simp_light(a)
    return simplified in (0, 1)

# ============================================================
# Permutation invariance check (same a for all sigma in S5)
# ============================================================
# All 5! = 120 rearrangements of the five (0-based) leaf positions.
S5_PERMS = list(permutations(range(5)))
# Prefix of S5_PERMS used as a cheap pre-screen. The slice bound is clamped to
# [0, len(S5_PERMS)], so a negative FAST_PERM_COUNT yields an empty list
# instead of slicing from the end.
FAST_PERMS = S5_PERMS[:max(0, min(FAST_PERM_COUNT, len(S5_PERMS)))]

def apply_perm(perm, idx):
    """Return *idx* rearranged so that position k holds idx[perm[k]]."""
    return tuple(map(idx.__getitem__, perm))

def check_perm_invariance(vec, L, Aidx, Bidx, a, perms):
    """Check that v_L - v_B = a (v_A - v_B) holds after every leaf permutation.

    Parameters
    ----------
    vec : dict mapping index tuples to coefficient column matrices
        (as built by coeff_vec_ABC in main).
    L, Aidx, Bidx : index tuples of the relation being tested.
    a : the weight found for the unpermuted triple.
    perms : iterable of position permutations (see apply_perm).

    Returns False as soon as a permuted index is missing from *vec* (e.g. it
    was filtered out as zero/NaN) or the relation fails in some component.
    Returns True otherwise.
    """
    for perm in perms:
        Lp = apply_perm(perm, L)
        Ap = apply_perm(perm, Aidx)
        Bp = apply_perm(perm, Bidx)
        if Lp not in vec or Ap not in vec or Bp not in vec:
            return False
        residual = vec[Lp] - vec[Bp] - a * (vec[Ap] - vec[Bp])
        # Simplify component by component. simp_light (together/cancel) is a
        # scalar tool, and feeding it a whole Matrix depends on sympy internals.
        if any(simp_light(r) != 0 for r in residual):
            return False
    return True

def check_perm_invariance_fast_then_full(vec, L, Aidx, Bidx, a):
    """Test S5 invariance of the relation, pre-screening on FAST_PERMS first.

    When the pre-screen is enabled, FAST_PERMS is checked first so that most
    failing candidates are rejected cheaply. Only the permutations not
    already covered are then checked, so no permutation is tested twice.
    """
    if USE_FAST_THEN_FULL_PERM_CHECK and FAST_PERMS:
        if not check_perm_invariance(vec, L, Aidx, Bidx, a, FAST_PERMS):
            return False
        already_checked = set(FAST_PERMS)
        remaining = [p for p in S5_PERMS if p not in already_checked]
        return check_perm_invariance(vec, L, Aidx, Bidx, a, remaining)
    return check_perm_invariance(vec, L, Aidx, Bidx, a, S5_PERMS)

# ============================================================
# TeX helpers
# ============================================================
def tex_idx(idx):
    """Concatenate the entries of *idx* into a TeX subscript string."""
    return "".join(str(d) for d in idx)
def latex(x):
    """Return the LaTeX source of *x* after light simplification."""
    simplified = simp_light(x)
    return sp.latex(simplified)

def print_eq(L, Aidx, Bidx, a):
    r"""Print the relation q_L = a q_A + (1 - a) q_B as a LaTeX display block.

    Both weights are simplified and wrapped in \Big( ... \Big), so a weight
    that is a sum (e.g. 1 - pi4) multiplies the whole q-symbol. The weight of
    q_B is the simplified value of 1 - a. Writing "1-" in front of a's LaTeX
    would only negate the first term of a sum.
    """
    a = simp_light(a)
    one_minus_a = simp_light(1 - a)
    lhs = rf"{SYMBOL_TEX}_{{{tex_idx(L)}}}"
    q_a = rf"{SYMBOL_TEX}_{{{tex_idx(Aidx)}}}"
    q_b = rf"{SYMBOL_TEX}_{{{tex_idx(Bidx)}}}"
    print(r"\[")
    print(r"\begin{aligned}")
    print(rf"{lhs} &= \Big({latex(a)}\Big)\,{q_a}"
          rf" + \Big({latex(one_minus_a)}\Big)\,{q_b}")
    print(r"\end{aligned}")
    print(r"\]")

# ============================================================
# MAIN
# ============================================================
def main():
    """Interactively search S5-invariant 2-term relations for one tilde-q entry.

    Prompts for a 5-digit LHS index L (digits 1..4). It then computes every
    5-leaf tilde-q value over the allowed states and turns each into a
    coefficient vector over [A, B, C, 1] via coeff_vec_ABC. As LaTeX it
    prints:

    * indices whose vector is identical to L's ("direct equalities");
    * up to MAX_EQUATIONS relations q_L = a q_A + (1 - a) q_B with a not in
      {0, 1}, where the same a also holds after applying every permutation
      in S5 to the leaf positions of L, A and B simultaneously.
    """
    from itertools import combinations

    sL = input("LHS index (e.g. 22333): ").strip()
    if len(sL) != 5 or any(ch not in "1234" for ch in sL):
        print("% Index must be 5 digits using 1..4 (e.g. 22333).")
        return
    L = tuple(int(ch) for ch in sL)

    if EXCLUDE_STATE_1 and (1 in L):
        print("% LHS contains state 1 but EXCLUDE_STATE_1=True.")
        return

    states_out = [2,3,4] if EXCLUDE_STATE_1 else [1,2,3,4]

    # compute tilde q once
    p0T = p0_tripod_full()
    barq_T1 = barq_quartet_edge6_full(p0T)
    p0_5 = p0_5_only(states_out)
    barq_5 = barq_5leaves_only(states_out, barq_T1, p0T)
    tildeq_5 = tilde_q_only(barq_5, p0_5)

    # filter: excluded states, NaN (p0 == 0) and, optionally, exact zeros
    usable = {}
    for idx, val in tildeq_5.items():
        if not keep_idx(idx) or is_nan_like(val):
            continue
        val = simp_light(val)
        if SKIP_ZEROS and val == 0:
            continue
        usable[idx] = val

    if L not in usable:
        print("% Requested LHS not usable (zero/NaN/filtered).")
        return

    # coefficient vectors in A,B,C,1
    vec = {idx: coeff_vec_ABC(v) for idx, v in usable.items()}
    vL = vec[L]

    # Optional: show direct equalities (same vector)
    equal_list = [idx for idx in vec if idx != L and vec[idx] == vL]
    if equal_list:
        print("% Direct equalities (same after eliminating lambdas):")
        for idx in equal_list[:20]:
            print(rf"\[{SYMBOL_TEX}_{{{tex_idx(L)}}} = {SYMBOL_TEX}_{{{tex_idx(idx)}}}\]")
        if len(equal_list) > 20:
            print(f"% ... plus {len(equal_list)-20} more.")
        print()

    # Search affine 2-term relations. combinations() yields each unordered
    # pair exactly once (in the same order as nested i < j loops), so no
    # de-duplication bookkeeping is needed.
    keys = [idx for idx in vec if idx != L]
    printed = 0
    for Aidx, Bidx in combinations(keys, 2):
        if printed >= MAX_EQUATIONS:
            break
        a = solve_a_fast(vL, vec[Aidx], vec[Bidx])
        if a is None or is_trivial_a(a):
            continue
        # Require SAME a works for all sigma in S5
        if not check_perm_invariance_fast_then_full(vec, L, Aidx, Bidx, a):
            continue
        print_eq(L, Aidx, Bidx, a)
        printed += 1

    if printed == 0 and not equal_list:
        print("% No nontrivial S5-invariant 2-term relations found.")
        print("% Try SKIP_ZEROS=False or EXCLUDE_STATE_1=False or increase MAX_EQUATIONS.")

# Run the interactive search when executed as a script / notebook cell;
# main() prompts for the LHS index via input().
if __name__ == "__main__":
    main()

LHS index (e.g. 22333): 33222


# Model_eq

In [None]:
import sympy as sp
from itertools import product
from functools import reduce

# ============================================================
# SETTINGS
# ============================================================
# When True, state 1 is dropped: 5-leaf tables are built over states [2, 3, 4]
# only, and a LHS index containing 1 is rejected by main().
EXCLUDE_STATE_1 = True
# When True, entries whose simplified tilde-q value is exactly 0 are removed
# before searching.
SKIP_ZEROS = True
MAX_EQUATIONS = 30   # maximum number of relations to print
SYMBOL_TEX = r"\tilde{q}"  # symbol used for the entries in the TeX output

# ============================================================
# SYMBOLS
# ============================================================
# Per-state weights; presumably a stationary distribution (SUBS_RULES
# normalises their sum to 1). Declared nonzero because ip_pi and e_u divide
# by them.
pi1, pi2, pi3, pi4 = sp.symbols("pi1 pi2 pi3 pi4", nonzero=True)
pi = [pi1, pi2, pi3, pi4]
# Partial sums used by the basis vectors and by SUBS_RULES.
pi12  = pi1 + pi2
pi123 = pi1 + pi2 + pi3
pisum = pi1 + pi2 + pi3 + pi4

# Weights of the two internal summations: the *_6 symbols weight the sum over
# y in barq_quartet_edge6_full, the *_7 symbols the sum over x in
# barq_5leaves_only. State 1 has its own symbol; states 2-4 share one.
lam1_6 = sp.Symbol("lambda_1^6")
lam6   = sp.Symbol("lambda^6")
lam1_7 = sp.Symbol("lambda_1^7")
lam7   = sp.Symbol("lambda^7")

lam_vec6 = [lam1_6, lam6, lam6, lam6]
lam_vec7 = [lam1_7, lam7, lam7, lam7]

# Only products matter; we'll eliminate lambdas by mapping products to A,B,C
A = sp.Symbol("A")  # lam1_6*lam7
B = sp.Symbol("B")  # lam1_7*lam6
C = sp.Symbol("C")  # lam6*lam7

# Rewrite rules encoding the normalisation pi1 + pi2 + pi3 + pi4 = 1.
# pi123 gets its own rule, presumably because sympy's subs does not turn a
# partial sum such as pi1 + pi2 + pi3 into 1 - pi4 from the pisum rule alone.
SUBS_RULES = {pisum: 1, pi123: 1 - pi4}

def simp_light(expr):
    """Normalise *expr* cheaply into one cancelled rational function.

    SUBS_RULES is applied first (pi1 + ... + pi4 -> 1). The result is then
    put over a common denominator and common factors are cancelled. This
    form is what the rest of the code compares structurally against 0.
    """
    substituted = sp.sympify(expr).subs(SUBS_RULES)
    return sp.cancel(sp.together(substituted))

def keep_idx(idx):
    """Return True if index tuple *idx* survives the EXCLUDE_STATE_1 filter."""
    if not EXCLUDE_STATE_1:
        return True
    return 1 not in idx

def is_nan_like(expr):
    """Return True if *expr* is sympy NaN or contains NaN anywhere inside it."""
    if expr is sp.nan or expr == sp.nan:
        return True
    # Objects without .has() (plain Python values) cannot contain NaN.
    return hasattr(expr, "has") and expr.has(sp.nan)

# ============================================================
# BASIS VECTORS
# ============================================================
# Four vectors, mutually orthogonal under ip_pi (<x, y>_pi = sum_j x_j y_j / pi_j);
# e.g. <u1, u2>_pi = pi1 + pi2 + pi3 - pi123 = 0 and <u3, u4>_pi = pi1 - pi1 = 0.
# u1 is pi itself, so e_u(j, u1) == 1 for every state j.
u1 = sp.Matrix([pi1, pi2, pi3, pi4])
u2 = sp.Matrix([pi1, pi2, pi3, -pi123])
u3 = sp.Matrix([pi1, pi2, -pi12, 0])
u4 = sp.Matrix([pi1, -pi1, 0, 0])
u  = [u1, u2, u3, u4]  # u[i - 1] is the vector for state label i

def ip_pi(x, y):
    """Return the pi-weighted inner product sum_j x_j * y_j / pi_j over 4 states."""
    total = 0
    for j in range(4):
        total += (x[j] * y[j]) / pi[j]
    return total

def e_u(j, ui):
    """Return entry *j* (0-based) of basis vector *ui*, divided by pi_j."""
    numerator, weight = ui[j], pi[j]
    return numerator / weight

# Squared pi-norms <u_i, u_i>_pi of the basis vectors (norm[i - 1] belongs to u_i).
# After SUBS_RULES, norm[0] = pi1 + pi2 + pi3 + pi4 reduces to 1.
norm = [simp_light(ip_pi(ui, ui)) for ui in u]

# ============================================================
# p0, barq, tildeq
# ============================================================
def p0_n_raw(idx):
    """Return the unsimplified p0 value for the leaf-state tuple *idx*.

    Computes sum_j pi_j * prod_{i in idx} e_u(j, u_i), divided by
    prod_{i in idx} norm[i - 1]. State labels in *idx* are 1-based.
    """
    denom = 1
    for state in idx:
        denom = denom * norm[state - 1]

    total = 0
    for j, weight in enumerate(pi):
        term = 1
        for state in idx:
            term = term * e_u(j, u[state - 1])
        total += weight * term
    return total / denom

def p0_tripod_full():
    """Return {(i1, i2, i3): simplified p0} for all 3-leaf indices over states 1..4."""
    table = {}
    for idx in product(range(1, 5), repeat=3):
        table[idx] = simp_light(p0_n_raw(idx))
    return table

def p0_5_only(states_out):
    """Return {5-tuple: simplified p0} for every 5-leaf index over *states_out*."""
    indices = product(states_out, repeat=5)
    return dict((idx, simp_light(p0_n_raw(idx))) for idx in indices)

def barq_quartet_edge6_full(p0T):
    """Glue two tripods into 4-leaf values for every index over states 1..4.

    For (i1, i2, i3, i4) this sums, over the internal state y, the terms
    norm[y-1] * lam_vec6[y-1] * p0T[(i1, i2, y)] * p0T[(y, i3, i4)].
    Each sum is simplified with simp_light.
    """
    result = {}
    for quad in product([1, 2, 3, 4], repeat=4):
        head, tail = quad[:2], quad[2:]
        total = 0
        for y in range(1, 5):
            weight = norm[y - 1] * lam_vec6[y - 1]
            total += weight * p0T[head + (y,)] * p0T[(y,) + tail]
        result[quad] = simp_light(total)
    return result

def barq_5leaves_only(states_out, barq_T1, p0T):
    """Glue a 4-leaf table and a tripod into 5-leaf values over *states_out*.

    For (i1, ..., i5) this sums, over the internal state x, the terms
    norm[x-1] * lam_vec7[x-1] * barq_T1[(i1, i2, i3, x)] * p0T[(x, i4, i5)].
    Each sum is simplified with simp_light.
    """
    result = {}
    for leaves in product(states_out, repeat=5):
        left, right = leaves[:3], leaves[3:]
        total = 0
        for x in range(1, 5):
            weight = norm[x - 1] * lam_vec7[x - 1]
            total += weight * barq_T1[left + (x,)] * p0T[(x,) + right]
        result[leaves] = simp_light(total)
    return result

def tilde_q_only(barq5, p05):
    """Return {idx: barq5[idx] / p05[idx]} simplified; sympy NaN where p05 is 0."""
    out = {}
    for key in barq5:
        numer = barq5[key]
        denom = p05[key]
        if denom == 0:
            out[key] = sp.nan
        else:
            out[key] = simp_light(numer / denom)
    return out

# ============================================================
# Lambda-elimination: map products -> A,B,C and extract coeff vector
# vector basis: [A, B, C, 1]
# ============================================================
# Lambda products -> collapsed symbols. sympy orders commutative Mul factors
# canonically, so each reversed-order key equals its partner and the dict
# ends up with three entries.
# NOTE: lambda_1^6 * lambda_1^7 has no entry. Any such terms are left
# untouched by to_ABC and land in the constant (last) slot of coeff_vec_ABC.
PROD_MAP = {
    lam1_6*lam7: A, lam7*lam1_6: A,
    lam1_7*lam6: B, lam6*lam1_7: B,
    lam6*lam7:   C, lam7*lam6:   C,
}

def to_ABC(expr):
    """Rewrite the lambda products in *expr* as the symbols A, B, C.

    The expression is expanded so that each term exposes its two lambda
    factors. xreplace handles exact product nodes, and subs handles
    products embedded in larger terms. Products without a PROD_MAP entry
    (lambda_1^6 * lambda_1^7) are left as they are.
    """
    expanded = sp.expand(simp_light(expr))
    collapsed = sp.expand(expanded.xreplace(PROD_MAP).subs(PROD_MAP))
    return simp_light(collapsed)

def coeff_vec_ABC(expr):
    """Return the column Matrix [coef_A, coef_B, coef_C, constant] of *expr*.

    *expr* is treated as a polynomial in A, B, C whose coefficients are
    arbitrary expressions in pi (domain EX). Only the three linear monomials
    and the constant term are read; any other monomial is ignored.
    """
    poly = sp.Poly(to_ABC(expr), A, B, C, domain="EX")
    coeffs = dict(poly.terms())
    wanted = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 0)]
    return sp.Matrix([simp_light(coeffs.get(mono, 0)) for mono in wanted])

# ============================================================
# Solve a from vL - vB = a (vA - vB) quickly
# ============================================================
def solve_a_fast(vL, vA, vB):
    """Solve vL - vB = a * (vA - vB) for a single scalar *a*.

    *a* is read off the first component where vA - vB does not simplify to
    0, then checked against every component.

    Returns None if vA equals vB component-wise (a is undetermined) or if no
    single *a* satisfies all components.
    """
    step = vA - vB
    target = vL - vB
    pivot = next(
        (k for k in range(step.rows) if simp_light(step[k]) != 0),
        None,
    )
    if pivot is None:
        return None
    a = simp_light(target[pivot] / step[pivot])
    if all(simp_light(target[j] - a * step[j]) == 0 for j in range(step.rows)):
        return a
    return None

def is_trivial_a(a):
    """Return True when *a* simplifies to 0 or 1.

    In either case L = a*A + (1-a)*B degenerates to L = B or L = A.
    """
    simplified = simp_light(a)
    return simplified in (0, 1)

def tex_idx(idx):
    """Concatenate the entries of *idx* into a TeX subscript string."""
    return "".join(str(d) for d in idx)
def latex(x):
    """Return the LaTeX source of *x* after light simplification."""
    simplified = simp_light(x)
    return sp.latex(simplified)

def print_eq(L, Aidx, Bidx, a):
    r"""Print the relation q_L = a q_A + (1 - a) q_B as a LaTeX display block.

    Both weights are simplified and wrapped in \Big( ... \Big), so a weight
    that is a sum (e.g. 1 - pi4) multiplies the whole q-symbol. The weight of
    q_B is the simplified value of 1 - a. Writing "1-" in front of a's LaTeX
    would only negate the first term of a sum.
    """
    a = simp_light(a)
    one_minus_a = simp_light(1 - a)
    lhs = rf"{SYMBOL_TEX}_{{{tex_idx(L)}}}"
    q_a = rf"{SYMBOL_TEX}_{{{tex_idx(Aidx)}}}"
    q_b = rf"{SYMBOL_TEX}_{{{tex_idx(Bidx)}}}"
    print(r"\[")
    print(r"\begin{aligned}")
    print(rf"{lhs} &= \Big({latex(a)}\Big)\,{q_a}"
          rf" + \Big({latex(one_minus_a)}\Big)\,{q_b}")
    print(r"\end{aligned}")
    print(r"\]")

# ============================================================
# MAIN: user picks LHS, we output lambda-free nontrivial relations
# ============================================================
def main():
    """Interactively search 2-term affine relations for one tilde-q entry.

    Prompts for a 5-digit LHS index L (digits 1..4). It then computes every
    5-leaf tilde-q value over the allowed states and turns each into a
    coefficient vector over [A, B, C, 1] via coeff_vec_ABC. As LaTeX it
    prints:

    * indices whose vector is identical to L's ("direct equalities");
    * up to MAX_EQUATIONS relations q_L = a q_A + (1 - a) q_B with a not in
      {0, 1}. No permutation-invariance check is applied.
    """
    from itertools import combinations

    sL = input("LHS index (e.g. 22333): ").strip()
    if len(sL) != 5 or any(ch not in "1234" for ch in sL):
        print("% Index must be 5 digits using 1..4.")
        return
    L = tuple(int(ch) for ch in sL)

    if EXCLUDE_STATE_1 and (1 in L):
        print("% LHS contains state 1 but EXCLUDE_STATE_1=True.")
        return

    states_out = [2,3,4] if EXCLUDE_STATE_1 else [1,2,3,4]

    # compute tildeq once
    p0T = p0_tripod_full()
    barq_T1 = barq_quartet_edge6_full(p0T)
    p0_5 = p0_5_only(states_out)
    barq_5 = barq_5leaves_only(states_out, barq_T1, p0T)
    tildeq_5 = tilde_q_only(barq_5, p0_5)

    # filter: excluded states, NaN (p0 == 0) and, optionally, exact zeros
    usable = {}
    for idx, val in tildeq_5.items():
        if not keep_idx(idx) or is_nan_like(val):
            continue
        val = simp_light(val)
        if SKIP_ZEROS and val == 0:
            continue
        usable[idx] = val

    if L not in usable:
        print("% Requested LHS not usable (zero/NaN/filtered).")
        return

    # Build coefficient vectors in A,B,C,1
    vec = {idx: coeff_vec_ABC(v) for idx, v in usable.items()}
    vL = vec[L]

    # 1) Direct equalities: identical coefficient vectors, i.e. identical
    #    expressions after lambda-elimination.
    equal_list = [idx for idx in vec if idx != L and vec[idx] == vL]
    if equal_list:
        print("% Direct equalities (same after eliminating lambdas):")
        for idx in equal_list[:20]:
            print(rf"\[{SYMBOL_TEX}_{{{tex_idx(L)}}} = {SYMBOL_TEX}_{{{tex_idx(idx)}}}\]")
        if len(equal_list) > 20:
            print(f"% ... plus {len(equal_list)-20} more.")
        print()

    # 2) Affine 2-term relations L = a*A + (1-a)*B. combinations() yields each
    #    unordered pair exactly once (same order as nested i < j loops), so
    #    no de-duplication bookkeeping is needed.
    keys = [idx for idx in vec if idx != L]
    printed = 0
    for Aidx, Bidx in combinations(keys, 2):
        if printed >= MAX_EQUATIONS:
            break
        a = solve_a_fast(vL, vec[Aidx], vec[Bidx])
        if a is None or is_trivial_a(a):
            continue
        print_eq(L, Aidx, Bidx, a)
        printed += 1

    if printed == 0 and not equal_list:
        print("% No nontrivial relations found for this LHS with current filters.")
        print("% Try setting SKIP_ZEROS=False or EXCLUDE_STATE_1=False.")

# Run the interactive search when executed as a script / notebook cell;
# main() prompts for the LHS index via input().
if __name__ == "__main__":
    main()