In [544]:
from z3 import *
class f2Poly:
    def __init__(self, terms=None):
        self.terms = {}
        if terms:
            for (i, j), coef in terms.items():
                if coef % 2 == 1:  # 模 2
                    self.terms[(i, j)] = 1  # 只存 1
    def __add__(self, other):
        result = self.terms.copy()
        for (i, j), coef in other.terms.items():
            if (i, j) in result:
                del result[(i, j)]  # 1+1=0
            else:
                result[(i, j)] = 1
        return f2Poly(result)
    def __mul__(self, other):
        if isinstance(other, int):
            if other % 2 == 0:
                return f2Poly()  # 偶数倍 → 0
            else:
                return self  # 奇数倍 → 自身（在 F2 中 1*f = f, 3*f = f）
        else:
            result = {}
            for (i1, j1) in self.terms:
                for (i2, j2) in other.terms:
                    i_new, j_new = i1 + i2, j1 + j2
                    key = (i_new, j_new)
                    if key in result:
                        del result[key]
                    else:
                        result[key] = 1
            return f2Poly(result)
    def __pow__(self, n):
        if n == 0:
            return f2Poly({(0, 0): 1})  # f^0 = 1
        if len(self.terms) == 0:
            if n >= 0:
                return f2Poly()  # 0^n = 0 (n>0)
            else:
                raise ValueError("0 的负幂次无定义")
        elif len(self.terms) == 1:
            (i, j), coef = next(iter(self.terms.items()))
            return f2Poly({(i * n, j * n): 1})
        else:
            # 多项式项数 > 1
            if n < 0:
                raise ValueError(f"多项式 {self} 不可逆，无法计算负幂次")
            elif n == 1:
                return self
            elif n % 2 == 0:
                return (self * self) ** (n // 2)
            else:
                return self * (self ** (n - 1))
    def degrees(self):
        if not self.terms:
            return (0, 0)
        i_vals = [i for (i, j) in self.terms.keys()]
        j_vals = [j for (i, j) in self.terms.keys()]
        return (max(i_vals), max(j_vals))
    def __repr__(self):
        if not self.terms:
            return "0"
        sorted_terms = sorted(self.terms.keys())
        terms = []
        for (i, j) in sorted_terms:
            xi = f"x^{i}" if i != 0 else ""
            yj = f"y^{j}" if j != 0 else ""
            term = xi + yj
            if term == "":
                term = "1"
            terms.append(term)
        return " + ".join(terms)
def antimap(poly):
    inverted_terms = {}
    for (i, j), coef in poly.terms.items():
        inverted_terms[(-i, -j)] = coef  # 系数不变（在 F2 中）
    return f2Poly(inverted_terms)
def commute(a, b):
    Lambda = [[0,0,1,0],[0,0,0,1],[1,0,0,0],[0,1,0,0]]
    result = f2Poly()
    for i in range(4):
        for j in range(4):
            if Lambda[i][j]:
                term = antimap(a[i][0]) * b[j][0]
                result += term
    return result
def excimap(*A, oper):
    result = []
    for a in A:
        # 计算 commute(a, oper)
        excitation = commute(a, oper)
        result.append(excitation)
    return result  # 返回 [F2Poly, F2Poly, ...]，就是一个行向量
def tdmap(f, m):
    result = []
    for j in range(-m, m + 1):      # y 的指数从 -m 到 m
        for i in range(-m, m + 1):  # x 的指数从 -m 到 m
            shift = f2Poly({(i, j): 1})  # x^i y^j
            shifted_f = [shift * i for i in f]         # 平移后的多项式
            result.append(shifted_f)
    return result
def shape(obj):
    if not obj:
        return (0,)  # 空对象
    if isinstance(obj[0], (list, tuple)):
        # 是矩阵：obj = [ [...], [...] ]
        rows = len(obj)
        cols = len(obj[0])
        return (rows, cols)
    else:
        # 是向量：obj = [a, b, c, d]
        return (len(obj),)

def hstack(A,B):
    if len(A) != len(B):
        raise ValueError("矩阵行数必须相同")
    result = []
    for i in range(len(A)):
        row = A[i] + B[i]  # 直接拼接
        result.append(row)
    return result
def I(n):
    zero = f2Poly()
    one  = f2Poly({(0,0): 1})
    return [[one if i == j else zero for j in range(n)] for i in range(n)]
def split(augmented):
    if not augmented:
        return [], []
    m = len(augmented)
    total_cols = len(augmented[0])
    n = total_cols - m
    if n < 0:
        raise ValueError("矩阵列数小于行数，无法分割")
    left = []   # 原矩阵部分 A'
    right = []  # 单位阵部分 B'
    for row in augmented:
        left.append(row[:n])   # 前 n 列
        right.append(row[n:])  # 后 m 列
    return left, right
def mulmatrix(A, B):
    if not A or not B:
        return []
    m = len(A)
    n = len(A[0])
    if len(B) != n:
        raise ValueError(f"矩阵维度不匹配：A 是 {m}x{n}，B 是 {len(B)}x{len(B[0])}")
    p = len(B[0])
    zero = f2Poly()
    C = [[zero for _ in range(p)] for _ in range(m)]
    for i in range(m):
        for j in range(p):
            total = f2Poly()  # 零多项式，累加用
            for k in range(n):
                prod = A[i][k] * B[k][j]  # f2Poly 乘法
                total = total + prod      # f2Poly 加法（F2 上）
            C[i][j] = total
    return C
def t(matrix):
    if not matrix or not matrix[0]:
        return []
    return [list(row) for row in zip(*matrix)]
def eliminate(mat, max_terms=2):
    def divides(a, b):
        """a divides b? a,b are exponent tuples"""
        return b[0] >= a[0] and b[1] >= a[1]
    # shallow-copy matrix into new list of f2Poly copies to avoid mutating caller
    rows = len(mat)
    cols = len(mat[0]) if rows > 0 else 0
    A = [[ f2Poly(dict(mat[r][c].terms))for c in range(cols) ] for r in range(rows)]
    # collect all monomials present (for pre-dep search)
    monos = {}
    for r in range(rows):
        for c in range(cols):
            for m in A[r][c].terms.keys():
                monos[m] = True
    monos[(0,0)] = True
    monos_list = sorted(monos.keys())
    # --- pre-dependency detection: 尝试把某行表示为前面行的单项式倍组合（较小的 max_terms）
    def try_predep_row(r):
        """若能表示，返回 [(prow_idx, mono), ...] 否则返回 None"""
        if r == 0:
            return None
        pivots = A[:r]
        n = len(pivots)
        # try 1-term combos
        if max_terms >= 1:
            for pi in range(n):
                prow = pivots[pi]
                for mono in monos_list:
                    ok = True
                    for c in range(cols):
                        s = prow[c] * f2Poly({(mono[0], mono[1]): 1})
                        if not len((s+ A[r][c]).terms)==0:
                            ok = False; break
                    if ok:
                        return [(pi, mono)]
        # try 2-term combos
        if max_terms >= 2:
            for p1 in range(n):
                for p2 in range(p1+1, n):
                    prow1 = pivots[p1]; prow2 = pivots[p2]
                    for mono1 in monos_list:
                        for mono2 in monos_list:
                            ok = True
                            for c in range(cols):
                                s = (prow1[c] * f2Poly({(mono1[0], mono1[1]): 1}))+ prow2[c] * f2Poly({(mono2[0], mono2[1]): 1})
                                if not len((s+ A[r][c]).terms)==0:
                                    ok = False; break
                            if ok:
                                return [(p1, mono1), (p2, mono2)]
        return None
    # apply pre-dependency elimination top-down
    r = 0
    while r < len(A):
        comb = try_predep_row(r)
        if comb is not None:
            for (pi, mono) in comb:
                for c in range(cols):
                    A[r][c] = (A[r][c]+A[pi][c] * f2Poly({(mono[0], mono[1]): 1}))
            # if became zero, move to bottom
            if all(len(A[r][c].terms)==0 for c in range(cols)):
                rowv = A.pop(r)
                A.append(rowv)
                # don't increment r (下一行已经移到当前位置)
                continue
        r += 1
    # --- staged elimination function (var 'y' 或 'x') ---
    def stage(var='y'):
        pivot_row = 0
        for c in range(cols):
            # find pivot among rows pivot_row..end with nonzero entry, choosing smallest degree (y or x)
            best_r = None
            best_deg = None
            for rr in range(pivot_row, len(A)):
                ent = A[rr][c]
                if len(ent.terms) == 0:
                    continue
                key = ent.degrees()[1] if var == 'y' else ent.degrees()[0]
                if best_r is None or key < best_deg:
                    best_r = rr
                    best_deg = key
            if best_r is None:
                continue
            # swap to pivot_row if needed
            if best_r != pivot_row:
                A[pivot_row], A[best_r] = A[best_r], A[pivot_row]
            pivot = A[pivot_row][c]
            # reduce other rows w.r.t pivot using leading monomial division
            LM = None
            for rr in range(len(A)):
                if rr == pivot_row:
                    continue
                # repeat reduction while possible (target's leading term divisible by pivot LM)
                while True:
                    target = A[rr][c]
                    if len(target.terms) == 0:
                        break
                    LM = max(pivot.terms.keys(), key=lambda m: (m[1], m[0]) if  ('y>x' if var=='y' else 'x>y') else (m[0], m[1])) if pivot.terms else None
                    if LM is None:
                        break
                    lt = max(target.terms.keys(), key=lambda m: (m[1], m[0]) if  ('y>x' if var=='y' else 'x>y') else (m[0], m[1])) if target.terms else None
                    if lt is None:
                        break
                    if divides(LM, lt):
                        mult = (lt[0] - LM[0], lt[1] - LM[1])
                        for cc in range(cols):
                            A[rr][cc] = (A[rr][cc]+ A[pivot_row][cc] * f2Poly({(mult[0], mult[1]): 1}))
                        # continue trying to reduce same rr
                        continue
                    else:
                        break
            pivot_row += 1
            if pivot_row >= len(A):
                break
    # y-stage then x-stage
    stage('y')
    stage('x')
    # --- final merging: 若有纯 x 行与纯 y 行且不在同一行，则乘幂使两元素相同并消掉一行 ---
    rx = ry = None; ax = by = None; cx = cy = None
    for r in range(len(A)):
        for c in range(cols):
            e = A[r][c]
            if len(e.terms) == 0:
                continue
            if len(e.terms) == 1:
                (i,j), = e.terms.items()  # Note: value is 1 but we want the key; safer to iterate keys
                # above returns ((i,j),1) so adjust:
                pass
    # safer scan for monomial entries:
    rx = ry = None
    for r in range(len(A)):
        for c in range(cols):
            e = A[r][c]
            if len(e.terms) == 0:
                continue
            if len(e.terms) == 1:
                (i,j) = next(iter(e.terms.keys()))
                if i > 0 and j == 0 and rx is None:
                    rx, cx, ax = r, c, i
                if j > 0 and i == 0 and ry is None:
                    ry, cy, by = r, c, j
    if rx is not None and ry is not None and rx != ry:
        for c in range(cols):
            A[rx][c] = A[rx][c]* f2Poly({(0, by): 1})
            A[ry][c] = A[ry][c] * f2Poly({(ax,0): 1})
        for c in range(cols):
            A[rx][c] = (A[rx][c]+A[ry][c])
        if all(len(A[rx][c].terms)==0 for c in range(cols)):
            rowv = A.pop(rx)
            A.append(rowv)
    return A
x1=[[f2Poly({(0,0):1})],
    [f2Poly()],
    [f2Poly()],
    [f2Poly()]]
x2=[[f2Poly()],
    [f2Poly({(0,0):1})],
    [f2Poly()],
    [f2Poly()]]
z1=[[f2Poly()],
    [f2Poly()],
    [f2Poly({(0,0):1})],
    [f2Poly()]]
z2=[[f2Poly()],
    [f2Poly()],
    [f2Poly()],
    [f2Poly({(0,0):1})]]
# 示例 1: 3x + 2y → 模 2 后是 x
s1=[
    [f2Poly({(0, 0): 1,(-1,0):1})],
    [f2Poly({(0, 0): 1,(0,-1):1})],
    [f2Poly()],
    [f2Poly()]
]
s2=[
    [f2Poly()],
    [f2Poly()],
    [f2Poly({(0, 0): 1,(0,1):1})],
    [f2Poly({(0, 0): 1,(1,0):1})],
]
m=1
m1=[]
m1.extend(tdmap(excimap(s1,s2,oper=x1),m))
m1.extend(tdmap(excimap(s1,s2,oper=x2),m))
m1.extend(tdmap(excimap(s1,s2,oper=z1),m))
m1.extend(tdmap(excimap(s1,s2,oper=z2),m))
print(eliminate(m1))

KeyboardInterrupt: 

In [534]:
# 定义基本变量
x = f2Poly({(1, 0): 1})
y = f2Poly({(0, 1): 1})
one = f2Poly({(0, 0): 1})
zero = f2Poly()

ma = [[x+y,x**2+y,x+y**2],
          [x,y,zero],
          [x**2,x**3+y**2+x*y,x**2+x*y**2]]
mat=hstack(ma,I(shape(ma)[0]))
print(eliminate(mat))

[[y^1, x^2, y^2 + x^1, 1, 1, 0], [x^1, y^1, 0, 0, 1, 0], [0, 0, 0, x^1, y^1, 1]]


In [None]:
def eliminate1(mat, max_terms=2):
    # 判断单项式 a 是否整除 b（指数均不超过）
    def divides(a, b):
        return b[0] >= a[0] and b[1] >= a[1]

    # 深复制矩阵并转换为 f2Poly 实例
    rows = len(mat)
    cols = len(mat[0]) if rows > 0 else 0
    A = [[f2Poly(dict(mat[r][c].terms)) for c in range(cols)] for r in range(rows)]

    # 收集矩阵中所有出现的单项式指数（包括常数项 (0,0)）
    monos = set()
    for r in range(rows):
        for c in range(cols):
            monos |= set(A[r][c].terms.keys())
    monos.add((0, 0))
    monos_list = sorted(monos)

    # 预依赖检测：尝试用之前的行（最多 max_terms 项组合）消去第 r 行
    def try_predep_row(r):
        if r == 0:
            return None
        pivots = A[:r]
        n = len(pivots)
        # 单行倍乘组合
        if max_terms >= 1:
            for pi in range(n):
                prow = pivots[pi]
                for mono in monos_list:
                    ok = True
                    for c in range(cols):
                        # prow[c] * mono 加到 A[r][c] 后是否为 0
                        if len((prow[c] * f2Poly({mono: 1}) + A[r][c]).terms) != 0:
                            ok = False
                            break
                    if ok:
                        return [(pi, mono)]
        # 两行倍乘组合
        if max_terms >= 2:
            for p1 in range(n):
                for p2 in range(p1 + 1, n):
                    prow1 = pivots[p1]; prow2 = pivots[p2]
                    for mono1 in monos_list:
                        for mono2 in monos_list:
                            ok = True
                            for c in range(cols):
                                # prow1[c]*mono1 + prow2[c]*mono2 + A[r][c] 是否为 0
                                combined = (prow1[c] * f2Poly({mono1: 1})) \
                                         + (prow2[c] * f2Poly({mono2: 1})) \
                                         + A[r][c]
                                if len(combined.terms) != 0:
                                    ok = False
                                    break
                            if ok:
                                return [(p1, mono1), (p2, mono2)]
        return None

    # 应用预依赖消去：若发现依赖组合则将行转换为零并移到底部
    r = 0
    while r < len(A):
        comb = try_predep_row(r)
        if comb is not None:
            for (pi, mono) in comb:
                for c in range(cols):
                    A[r][c] = A[r][c] + (A[pi][c] * f2Poly({mono: 1}))
            if all(len(A[r][c].terms) == 0 for c in range(cols)):
                # 移除该零行并追加到矩阵末尾
                rowv = A.pop(r)
                A.append(rowv)
                continue  # 新的第 r 行下移后不要立即 r+=1
        r += 1

    # 分阶段消元函数：按照 var 指定的变量优先级（先 'y' 再 'x'）
    def stage(var='y'):
        pivot_row = 0
        for c in range(cols):
            # 在[pivot_row..end]中找非零元素，选择对应 var 指数最小的一行
            best_r = None; best_deg = None
            for rr in range(pivot_row, len(A)):
                ent = A[rr][c]
                if not ent.terms:
                    continue
                deg = ent.degrees()[1] if var == 'y' else ent.degrees()[0]
                if best_r is None or deg < best_deg:
                    best_r = rr; best_deg = deg
            if best_r is None:
                continue
            # 将 best_r 换到 pivot_row
            if best_r != pivot_row:
                A[pivot_row], A[best_r] = A[best_r], A[pivot_row]
            pivot = A[pivot_row][c]
            # 找到主元多项式（最高阶单项式）
            if pivot.terms:
                if var == 'y':
                    LM = max(pivot.terms.keys(), key=lambda m: (m[1], m[0]))
                else:
                    LM = max(pivot.terms.keys(), key=lambda m: (m[0], m[1]))
            else:
                LM = None
            # 用主元消去其它行的该列项
            for rr in range(len(A)):
                if rr == pivot_row:
                    continue
                while True:
                    target = A[rr][c]
                    if not target.terms or LM is None:
                        break
                    if var == 'y':
                        lt = max(target.terms.keys(), key=lambda m: (m[1], m[0]))
                    else:
                        lt = max(target.terms.keys(), key=lambda m: (m[0], m[1]))
                    if divides(LM, lt):
                        mult = (lt[0] - LM[0], lt[1] - LM[1])
                        for cc in range(cols):
                            A[rr][cc] = A[rr][cc] + (A[pivot_row][cc] * f2Poly({mult: 1}))
                        # 继续检查该行中是否还有可消项
                        continue
                    break
            pivot_row += 1
            if pivot_row >= len(A):
                break

    # 先按 y 消元，再按 x 消元
    stage('y')
    stage('x')

    # 最后合并纯 x 和纯 y 行
    rx = ry = None
    ax = by = None
    for r in range(len(A)):
        for c in range(cols):
            e = A[r][c]
            if len(e.terms) != 1:
                continue
            (i, j), = e.terms.items()
            # 纯 x 项（i>0, j=0）
            if j == 0 and i > 0 and rx is None:
                rx, cx, ax = r, c, i
            # 纯 y 项（i=0, j>0）
            if i == 0 and j > 0 and ry is None:
                ry, cy, by = r, c, j
        if rx is not None and ry is not None:
            break
    # 若找到一行纯 x 和一行纯 y，则将它们相加消去一行
    if rx is not None and ry is not None and rx != ry:
        for c in range(cols):
            A[rx][c] = A[rx][c] * f2Poly({(0, by): 1})
            A[ry][c] = A[ry][c] * f2Poly({(ax, 0): 1})
        for c in range(cols):
            A[rx][c] = A[rx][c] + A[ry][c]
        if all(len(A[rx][c].terms) == 0 for c in range(cols)):
            rowv = A.pop(rx)
            A.append(rowv)

    return A
print(shape(eliminate1(m1)))