<a href="https://colab.research.google.com/github/Gokcentunc/atraction/blob/main/Equalities_k_states.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sympy as sp
from functools import lru_cache

# ============================================================
# SETTINGS
# ============================================================
# When True, keep_idx() rejects any index that mentions state 1.
EXCLUDE_STATE_1: bool = True
# When True, main()'s simp_light() substitutes sum(pi) -> 1 and applies
# sympy.together + sympy.cancel; when False it returns expressions untouched.
LIGHT_SIMPLIFY: bool = True

# ============================================================
# Helpers
# ============================================================
def parse_idx(s: str):
    """Parse a digit string such as ``"2244"`` into a tuple of single-digit states.

    Each character is one state, so only states 0-9 can be written. All
    whitespace (spaces, tabs, newlines) is ignored, so ``"2 2 4 4"`` and
    ``"2244"`` parse identically.

    Raises:
        ValueError: if the string is empty/whitespace-only or contains any
            character that is not a decimal digit.
    """
    s = "".join(s.split())
    # isdecimal() (unlike isdigit()) rejects characters such as superscript
    # digits that int() cannot convert, so they hit the friendly message below
    # instead of an opaque "invalid literal for int()" error.
    if not s or not s.isdecimal():
        raise ValueError("Index must be digits like 2244, 22444, 425544.")
    return tuple(int(ch) for ch in s)

def keep_idx(idx):
    """Return True if ``idx`` is admissible under the EXCLUDE_STATE_1 setting.

    With the exclusion switched off every index is kept; otherwise an index
    is kept only when state 1 does not occur in it.
    """
    if not EXCLUDE_STATE_1:
        return True
    return 1 not in idx

# ============================================================
# MAIN
# ============================================================
def main():
    """Interactively test whether two ``tilde`` values coincide on the simplex.

    Prompts for:

    * ``kappa`` -- the number of states K (must be >= 2);
    * the object type: ``quartet`` (length-4 index, ``tildeq4``), ``5``
      (length-5, ``tildeq5``) or ``6`` (length-6, ``tildeq6``);
    * two indices written as digit strings, one digit per state (e.g. 2244).

    Both values are built symbolically in ``pi1..piK`` and the edge-labelled
    ``lambda`` symbols and printed. The function then reports whether their
    difference is zero (a) after light simplification and (b) after reducing
    the numerator modulo ``pi1 + ... + piK - 1``.

    Because indices use one character per state, only states 1..9 can be
    entered even when K >= 10.

    Raises:
        ValueError: for a non-integer or too-small kappa, an unknown mode, a
            malformed index, an index containing state 1 while
            ``EXCLUDE_STATE_1`` is set, an index state outside ``1..K``, or an
            index of the wrong length for the chosen mode.
    """
    K = int(input("Enter kappa (number of states), e.g. 4: ").strip())
    if K < 2:
        raise ValueError("kappa must be >= 2")

    # Symbols pi1..piK, declared nonzero so SymPy may cancel them freely.
    pi = sp.symbols(" ".join([f"pi{i}" for i in range(1, K + 1)]), nonzero=True)
    PI_SUM = sum(pi)

    def pisum(m):
        """Return pi1 + ... + pi_m (the integer 0 when m == 0)."""
        return sum(pi[:m])

    def simp_light(expr):
        """Cheap simplification: ``subs`` sum(pi) -> 1, then together + cancel.

        Returns ``expr`` unchanged when ``LIGHT_SIMPLIFY`` is False.
        """
        if not LIGHT_SIMPLIFY:
            return expr
        expr = expr.subs({PI_SUM: 1})
        expr = sp.together(expr)
        expr = sp.cancel(expr)
        return expr

    @lru_cache(maxsize=None)
    def e_u(j, i):
        """Component ``j`` (1-based) of the vector ``u_i``.

        ``u_1`` is all ones. For ``i >= 2`` the first ``K + 1 - i`` entries
        are 1, entry ``K + 2 - i`` is
        ``-(pi_1 + ... + pi_{K+1-i}) / pi_{K+2-i}``, and later entries are 0.
        """
        if i == 1:
            return sp.Integer(1)
        if j + i <= K + 1:
            return sp.Integer(1)
        if j == K + 2 - i:
            return -pisum(K + 1 - i) / pi[(K + 2 - i) - 1]
        return sp.Integer(0)

    @lru_cache(maxsize=None)
    def norm_u(i):
        """pi-weighted squared norm of ``u_i``: sum_t pi_t * e_u(t, i)**2.

        For ``i >= 2`` this equals ``(pi_1+...+pi_a) * (pi_1+...+pi_b) / pi_b``
        with ``a = K + 1 - i`` and ``b = K + 2 - i``. For ``i == 1`` the sum is
        sum(pi), taken to be 1 (the simplex constraint).
        """
        if i == 1:
            return sp.Integer(1)
        a = K + 1 - i
        b = K + 2 - i
        denom = pi[(K + 2 - i) - 1]
        return simp_light(pisum(a) * pisum(b) / denom)

    @lru_cache(maxsize=None)
    def p0_bar(idx):
        """sum_t pi_t * prod_r e_u(t, idx[r]), divided by prod_r norm_u(idx[r])."""
        denom = sp.Integer(1)
        for ir in idx:
            denom *= norm_u(ir)

        s = sp.Integer(0)
        for t in range(1, K + 1):
            prod_term = sp.Integer(1)
            for ir in idx:
                prod_term *= e_u(t, ir)
            s += pi[t - 1] * prod_term

        return simp_light(s / denom)

    def make_lam_vec(edge_label: str):
        """Per-state lambda weights for one edge: ``lambda_1^e`` for state 1 and
        a shared ``lambda^e`` for states 2..K."""
        lam1 = sp.Symbol(f"lambda_1^{edge_label}")
        lam  = sp.Symbol(f"lambda^{edge_label}")
        return [lam1] + [lam] * (K - 1)

    lam_vec6 = make_lam_vec("6")
    lam_vec7 = make_lam_vec("7")
    lam_vec8 = make_lam_vec("8")

    @lru_cache(maxsize=None)
    def barq4(idx4):
        """sum_y norm_u(y) * lam6_y * p0_bar(i1, i2, y) * p0_bar(y, i3, i4)."""
        i1, i2, i3, i4 = idx4
        s = sp.Integer(0)
        for y in range(1, K + 1):
            s += norm_u(y) * lam_vec6[y - 1] * p0_bar((i1, i2, y)) * p0_bar((y, i3, i4))
        return simp_light(s)

    @lru_cache(maxsize=None)
    def tildeq4(idx4):
        """barq4(idx4) / p0_bar(idx4); ``sp.nan`` when p0_bar(idx4) simplifies to 0."""
        den = p0_bar(idx4)
        return sp.nan if den == 0 else simp_light(barq4(idx4) / den)

    @lru_cache(maxsize=None)
    def barq5(idx5):
        """sum_x norm_u(x) * lam7_x * barq4(i1, i2, i3, x) * p0_bar(x, i4, i5)."""
        i1, i2, i3, i4, i5 = idx5
        s = sp.Integer(0)
        for x in range(1, K + 1):
            s += norm_u(x) * lam_vec7[x - 1] * barq4((i1, i2, i3, x)) * p0_bar((x, i4, i5))
        return simp_light(s)

    @lru_cache(maxsize=None)
    def tildeq5(idx5):
        """barq5(idx5) / p0_bar(idx5); ``sp.nan`` when p0_bar(idx5) simplifies to 0."""
        den = p0_bar(idx5)
        return sp.nan if den == 0 else simp_light(barq5(idx5) / den)

    @lru_cache(maxsize=None)
    def barq6(idx6):
        """sum_x norm_u(x) * lam8_x * barq5(i1, ..., i4, x) * p0_bar(x, i5, i6)."""
        i1, i2, i3, i4, i5, i6 = idx6
        s = sp.Integer(0)
        for x in range(1, K + 1):
            s += norm_u(x) * lam_vec8[x - 1] * barq5((i1, i2, i3, i4, x)) * p0_bar((x, i5, i6))
        return simp_light(s)

    @lru_cache(maxsize=None)
    def tildeq6(idx6):
        """barq6(idx6) / p0_bar(idx6); ``sp.nan`` when p0_bar(idx6) simplifies to 0."""
        den = p0_bar(idx6)
        return sp.nan if den == 0 else simp_light(barq6(idx6) / den)

    # The single linear generator sum(pi) - 1 is already a Groebner basis.
    # Under lex order with pi1 largest its leading term is pi1, so the normal
    # form of a polynomial modulo it is exactly that polynomial with
    # pi1 -> 1 - (pi2 + ... + piK). Substituting directly gives the same
    # remainder without building an EX-domain Groebner basis on every call.
    simplex_sub = {pi[0]: 1 - sum(pi[1:])}

    def equal_under_simplex(expr):
        """Decide whether ``expr`` vanishes on the simplex sum(pi) = 1.

        Returns ``(is_zero, rem, den)``. ``rem`` is the factored normal form
        of the numerator of ``expr`` modulo ``sum(pi) - 1``. ``den`` is the
        denominator, returned for reference only; it is not checked for
        vanishing on the simplex.
        """
        expr = simp_light(expr)
        expr = sp.together(expr)
        num, den = sp.fraction(expr)
        rem = sp.expand(sp.expand(num).subs(simplex_sub))
        if rem == 0:
            return True, rem, den
        return False, sp.factor(rem), den

    # mode -> (required index length, tilde function)
    tilde_by_mode = {"quartet": (4, tildeq4), "5": (5, tildeq5), "6": (6, tildeq6)}

    mode = input("Choose object (quartet / 5 / 6): ").strip().lower()
    if mode not in tilde_by_mode:
        raise ValueError("Mode must be one of: quartet, 5, 6.")
    length, tilde_fn = tilde_by_mode[mode]

    s1 = input("Enter index #1 (e.g. 2244 / 22444 / 425544): ").strip()
    s2 = input("Enter index #2: ").strip()

    idx1 = parse_idx(s1)
    idx2 = parse_idx(s2)

    if not keep_idx(idx1) or not keep_idx(idx2):
        raise ValueError("Index rejected: it contains state 1 while EXCLUDE_STATE_1=True.")

    # Validate both indices before any (expensive) symbolic evaluation.
    for raw, idx in ((s1, idx1), (s2, idx2)):
        # States must lie in 1..K. State 0 would index past the end of pi.
        # State K+1 makes norm_u zero, and larger states slice pi with
        # negative bounds, giving silently wrong expressions.
        out_of_range = sorted({st for st in idx if not 1 <= st <= K})
        if out_of_range:
            raise ValueError(
                f"Index {raw!r} uses state(s) {out_of_range} outside 1..{K}."
            )
        if len(idx) != length:
            raise ValueError(f"{mode} requires length-{length} indices.")

    v1 = tilde_fn(idx1)
    v2 = tilde_fn(idx2)

    print("\n--- Results ---")
    print(f"tilde({s1}) =\n{v1}")
    print(f"\ntilde({s2}) =\n{v2}")

    print("\n--- Equality check ---")
    if v1.has(sp.nan, sp.zoo) or v2.has(sp.nan, sp.zoo):
        # tildeq* returns nan when p0_bar(index) is 0. nan - anything is nan,
        # which would otherwise be misreported as "not equal".
        print("Undefined: at least one tilde value is nan (p0_bar(index) == 0).")
        return

    diff = v1 - v2
    naive_is_zero = (simp_light(diff) == 0)
    is_zero, rem, den = equal_under_simplex(diff)

    print("Naive (after sum(pi)=1) ->", naive_is_zero)
    print("Under sum(pi)=1 (Groebner remainder) ->", is_zero)

    if not is_zero:
        print("\nRemainder of numerator mod <sum(pi)-1>:")
        print(rem)
        print("\nDenominator (for reference):")
        print(den)

# Entry point: run the interactive comparison when this file is executed
# directly rather than imported.
if __name__ == "__main__":
    main()
