In [1]:
# ================================================================
# SymPy implementation of central-idempotent projectors
# for the k-subset representation of S_n (two-row partitions only)
# Now with Fourier-basis transforms + blockwise nonlinear interface
# ================================================================
import itertools as it
import math
from collections import Counter
from sympy import Matrix, Rational, eye, sqrt

# -----------------------------------
# Exact Gram–Schmidt implementation
# -----------------------------------
def gram_schmidt(vectors, orthonormal=False):
    """
    Exact Gram–Schmidt process for SymPy column vectors.

    vectors: iterable of sympy.Matrix column vectors (all of the same length)
    orthonormal: if True, normalizes each output vector to unit norm

    Returns a list of mutually orthogonal column vectors spanning the same
    space as `vectors`. A vector that is linearly dependent on the ones
    before it leaves a zero residual and is dropped.

    Orthogonalization runs on unnormalized vectors, and the square-root
    normalization is applied only at the end. For rational input, every
    intermediate entry therefore stays rational. Each residual is also
    simplified, so symbolic cancellations collapse to an exact 0 and the
    dependence test is reliable. The normalized outputs are the same
    vectors the normalize-as-you-go process yields.
    """
    ortho = []
    for v in vectors:
        w = v
        for u in ortho:
            # Remove the component of v along u. u is not normalized here,
            # so divide by <u, u>.
            w = w - (u.dot(v) / u.dot(u)) * u
        # Without simplification, e.g. a - (a + a**3)/(1 + a**2) with
        # a = 1 + sqrt(2) would not be recognized as 0.
        w = w.applyfunc(lambda x: x.simplify())
        if w.is_zero_matrix:
            continue
        ortho.append(w)
    if orthonormal:
        ortho = [w / sqrt(w.dot(w)) for w in ortho]
    return ortho

# -----------------------------------
# Combinatorial helpers
# -----------------------------------

def dim_two_row(n, i):
    """
    Dimension of the Specht module S^{(n-i,i)}.

    Computed as C(n, i) - C(n, i-1), where the second term is taken as 0
    for i = 0. This equals the rank of the (n-i, i) block.
    """
    total = math.comb(n, i)
    if i == 0:
        return total
    if i < 0:
        return total - 0
    return total - math.comb(n, i - 1)

def all_perms(n):
    """List every permutation of {0,...,n-1}, each one a tuple, in lexicographic order."""
    return [*it.permutations(range(n))]

def cycle_type(p):
    """
    Cycle type of the permutation p (given as a tuple with i -> p[i]).

    The cycle lengths are returned as a tuple sorted in decreasing order,
    e.g. (0 1)(2)(3 4 5) -> (3,2,1).
    """
    size = len(p)
    visited = [False] * size
    lengths = []
    for start in range(size):
        if visited[start]:
            continue
        # walk the cycle through `start`, marking every element on it
        length = 0
        cur = start
        while not visited[cur]:
            visited[cur] = True
            cur = p[cur]
            length += 1
        lengths.append(length)
    return tuple(sorted(lengths, reverse=True))

def conjugacy_classes_with_members(n):
    """
    Group all of S_n into conjugacy classes by brute force (n is small).

    Each class is returned as a dict:
    {
        "type": cycle type (like (2,2,1,...)),
        "size": number of permutations in the class,
        "members": [permutations g in that class]
    }
    Classes appear in the order in which all_perms(n) first meets them.
    """
    by_type = {}
    for perm in all_perms(n):
        key = cycle_type(perm)
        if key not in by_type:
            by_type[key] = []
        by_type[key].append(perm)
    return [
        {"type": key, "size": len(perms), "members": perms}
        for key, perms in by_type.items()
    ]

# -----------------------------------
# k-subset permutation representation  ρ : S_n -> GL(R[X_k])
# -----------------------------------

def k_subsets(n, k):
    """
    All k-subsets of {0,...,n-1}, each a sorted tuple, in lexicographic order.

    These label the 'pixel basis', i.e. the original coordinate system.
    """
    return list(it.combinations(range(n), k))

def rho_matrix(n, k, sigma):
    """
    Permutation matrix ρ(sigma) on the basis {e_A}_{A in X_k}.

    Column j (the subset A) has a single 1, in the row of sigma(A), so
    ρ(sigma) e_A = e_{sigma(A)}.
    """
    subsets = k_subsets(n, k)
    position = {subset: row for row, subset in enumerate(subsets)}
    size = len(subsets)
    entries = [[0] * size for _ in range(size)]
    for col, subset in enumerate(subsets):
        image = tuple(sorted(sigma[a] for a in subset))
        entries[position[image]][col] = 1
    return Matrix(entries)

# -----------------------------------
# Two-row character polynomials
# -----------------------------------

def chi_two_row(n, i, cycle_type_part):
    """
    χ^{(n-i,i)}(σ) for σ with the given cycle type.

    Uses χ^{(n-i,i)} = π_i - π_{i-1} (Young's rule for two-row shapes).
    Here π_j(σ) is the number of j-subsets fixed setwise by σ, with
    π_{-1} = 0. A subset is fixed exactly when it is a union of cycles of
    σ, so π_j counts the ways to choose distinct cycles whose lengths sum
    to j.

    With X_r = number of r-cycles in σ, this reproduces the familiar
    character polynomials, e.g.
        i = 1:  X1 - 1
        i = 2:  C(X1,2) + X2 - X1
        i = 3:  C(X1,3) + X1*X2 + X3 - C(X1,2) - X2
    and it works for every i, not only i <= 3. The result is the
    character of S^{(n-i,i)} when 0 <= i <= n/2.

    n is accepted for interface compatibility; the value depends only on
    cycle_type_part. Returns a SymPy Rational, always an integer.
    Raises ValueError for negative i.
    """
    if i < 0:
        raise ValueError(f"i must be non-negative, got {i}")

    # fixed[j] = π_j for j = 0..i, built as a 0/1 subset-sum count over the
    # cycles. j runs downwards so each cycle is used at most once.
    fixed = [1] + [0] * i
    for length in cycle_type_part:
        for j in range(i, length - 1, -1):
            fixed[j] += fixed[j - length]

    previous = fixed[i - 1] if i > 0 else 0
    return Rational(fixed[i] - previous)

# -----------------------------------
# Central idempotent projectors p_{(n-i,i)}
# -----------------------------------

def projectors_two_row_vnk(n, k):
    """
    Build the isotypic projectors of the k-subset representation of S_n.

    Returns (proj, Qblocks, basis), where
    - proj[i] is the projector p_{(n-i,i)} in the original (k-subset)
      basis: an S_n-equivariant idempotent of rank dim_two_row(n, i);
    - Qblocks[i] is a matrix whose columns form an orthonormal basis of
      im(proj[i]), the copy of S^{(n-i,i)} inside R[X_k];
    - basis is the list of k-subsets, i.e. the coordinate labels.

    i runs over 0..min(k, n-k). Taking complements identifies k-subsets with
    (n-k)-subsets, so only those two-row irreps occur.

    Raises ValueError if the image of a projector does not have the
    expected dimension, e.g. because of a wrong character value.
    """
    basis = k_subsets(n, k)
    m = len(basis)
    classes = conjugacy_classes_with_members(n)

    # Class sums S_C = Σ_{g∈C} ρ(g) do not depend on the irrep, so build each
    # one exactly once and reuse it for every i.
    class_sums = []
    for C in classes:
        S_C = Matrix.zeros(m)
        for g in C["members"]:
            S_C += rho_matrix(n, k, g)
        class_sums.append((C["type"], S_C))

    group_order = math.factorial(n)
    top = min(k, n - k)

    proj = {}
    Qblocks = {}
    for i in range(top + 1):
        dimL = dim_two_row(n, i)

        # M = Σ_{classes C} χ^{(n-i,i)}(C) * S_C
        M = Matrix.zeros(m)
        for mu, S_C in class_sums:
            M += chi_two_row(n, i, mu) * S_C

        # central idempotent formula: p_λ = (dim λ / |S_n|) * M
        p = Rational(dimL, group_order) * M

        # ρ(g^{-1}) = ρ(g)^T for permutation matrices, and each class of S_n
        # is closed under inversion, so p is already symmetric in exact
        # arithmetic. Symmetrizing is a cheap safeguard.
        p = (p + p.T) / 2
        proj[i] = p

        # p is idempotent, so im(p) is its eigenvalue-1 eigenspace, i.e. the
        # nullspace of p - I. Computing that directly is exact and avoids a
        # full eigendecomposition and float-tolerance matching.
        cols = (p - eye(m)).nullspace()
        if len(cols) != dimL:
            raise ValueError(
                f"image of p_({n - i},{i}) has dimension {len(cols)}, "
                f"expected {dimL}"
            )

        # exact Gram–Schmidt to get orthonormal columns
        ortho_cols = gram_schmidt(cols, orthonormal=True)
        Qblocks[i] = Matrix.hstack(*ortho_cols)

    return proj, Qblocks, basis

# -----------------------------------
# Fourier basis utilities
# -----------------------------------

def assemble_Q(Qblocks):
    """
    Concatenate the per-irrep ONBs side by side into one matrix Q.

    The column groups are ordered by increasing key i, i.e. by irrep
    (n-i,i). Q is square when the blocks together span the whole space.
    """
    ordered = [Qblocks[key] for key in sorted(Qblocks)]
    return Matrix.hstack(*ordered)

def block_sizes_from_Qblocks(Qblocks):
    """
    Column counts of the blocks, [dim S^{(n-0,0)}, dim S^{(n-1,1)}, ...].

    The order is the same sorted-key order that assemble_Q uses.
    """
    return [Qblocks[key].cols for key in sorted(Qblocks)]

def to_fourier(Q, v):
    """
    Forward Fourier transform: map subset-basis coordinates v (length C(n,k))
    to irrep-block coordinates.

    Because Q has orthonormal columns, the inverse is its transpose:
    v_hat = Q^T v.
    """
    return Q.transpose() * v

def from_fourier(Q, v_hat):
    """
    Inverse Fourier transform back to subset-basis coordinates: v = Q v_hat.
    """
    reconstructed = Q * v_hat
    return reconstructed

def split_fourier_blocks(v_hat, block_sizes):
    """
    Cut a stacked Fourier-space vector (or matrix) into per-irrep row slices.

    Returns [block0, block1, ...]; block_i holds the next block_sizes[i] rows
    and corresponds to irrep (n-i,i).
    """
    ends = list(it.accumulate(block_sizes))
    starts = [0] + ends[:-1]
    return [v_hat[lo:hi, :] for lo, hi in zip(starts, ends)]

def merge_fourier_blocks(block_list):
    """
    Stack per-irrep blocks vertically; the inverse of split_fourier_blocks.
    """
    pieces = list(block_list)
    return Matrix.vstack(*pieces)

def nonlinear_fourier_update(v_hat, block_sizes, blockwise_funcs):
    """
    Apply user-supplied maps in Fourier space, one irrep block at a time.

    blockwise_funcs maps i -> callable, where i indexes irrep (n-i,i). Each
    callable receives that block (a Matrix of shape [dim_i, 1] or
    [dim_i, batch]) and should return a Matrix of the same shape. Blocks
    whose index is absent from blockwise_funcs pass through unchanged.

    This is the hook for activations meant to respect the block structure.
    """
    pieces = split_fourier_blocks(v_hat, block_sizes)
    transformed = [
        blockwise_funcs[idx](piece) if idx in blockwise_funcs else piece
        for idx, piece in enumerate(pieces)
    ]
    return merge_fourier_blocks(transformed)

# -----------------------------------
# Verification & block diagonalization
# -----------------------------------

def verify_projectors(proj, dims):
    """
    Sanity-check a family of projectors, printing a warning for each failure.

    Checks, in this order:
    - p_i^2 = p_i and trace(p_i) = dims[i], for each key i
    - sum_i p_i = I
    - p_a p_b = 0 for every pair of keys a < b

    Returns True when every check passes, otherwise False.
    """
    order = sorted(proj)
    size = proj[order[0]].rows
    identity = eye(size)
    passed = True

    for key in order:
        P = proj[key]
        if not (P * P).equals(P):
            print(f"[warn] Projector {key} not idempotent.")
            passed = False
        trace_val = P.trace()
        if trace_val != dims[key]:
            print(f"[warn] trace(p_{key})={trace_val} differs from dim {dims[key]}.")
            passed = False

    total = Matrix.zeros(size)
    for key in order:
        total = total + proj[key]
    if not total.equals(identity):
        print("[warn] Sum of projectors not identity.")
        passed = False

    for a, b in it.combinations(order, 2):
        if not (proj[a] * proj[b]).is_zero_matrix:
            print(f"[warn] Projectors {a},{b} not orthogonal.")
            passed = False
    return passed

def block_diagonalize(n, k, Q, sigma):
    """
    Conjugate ρ(sigma) into the Fourier basis: returns Q^T ρ(sigma) Q with
    every entry simplified.

    When Q comes from the isotypic ONBs, the result should be block-diagonal,
    with one block per irrep.
    """
    perm_matrix = rho_matrix(n, k, sigma)
    conjugated = Q.T * perm_matrix * Q
    return conjugated.applyfunc(lambda entry: entry.simplify())

In [14]:

n, k = 4, 2   # example: S_4 acting on the 2-subsets of {0, 1, 2, 3}
proj, Qblocks, basis = projectors_two_row_vnk(n, k)

# Dimension of each irreducible block (n-i, i). Keying it exactly like
# `proj` guarantees verify_projectors finds a dimension for every projector.
dims = {i: dim_two_row(n, i) for i in sorted(proj)}
print("Block dimensions:", dims)

ok = verify_projectors(proj, dims)
print("Verification passed:", ok)

# Build global change-of-basis (Fourier transform matrix)
Q = assemble_Q(Qblocks)
print("Q^T Q =")
print(Q.T * Q)  # should be the identity

# Show block sizes in Fourier domain ordering
block_sizes = block_sizes_from_Qblocks(Qblocks)
print("Fourier block sizes (per irrep):", block_sizes)

# A random integer test vector in the original (subset) basis. A private,
# seeded RNG keeps the demo reproducible without touching the global
# `random` state.
import random
rng = random.Random(0)
m = len(basis)
v_coords = Matrix([rng.randint(-2, 2) for _ in range(m)])  # column vector, shape (m, 1)

# Go to Fourier
v_hat = to_fourier(Q, v_coords)
print("v_hat in Fourier basis:")
print(v_hat)

# Apply a simple toy nonlinearity, blockwise:
#   block 0 (trivial irrep):  identity
#   block 1 (standard irrep): ReLU-like, max(x, 0) per coordinate
#   block 2 (irrep (n-2,2)):  scale by 2
# NOTE: a coordinatewise ReLU depends on the chosen ONB and does not commute
# with the S_n action. In the block-diagonal ρ((0 1)) printed below, block 1
# has -1 entries, and max(-x, 0) != -max(x, 0). Scalar multiples and
# norm-based gates v -> f(||v||) v are equivariant alternatives.
def relu_like(mat):
    return mat.applyfunc(lambda z: z if z.is_real and z >= 0 else 0)

blockwise_funcs = {
    0: (lambda x: x),    # leave trivial piece alone
    1: relu_like,
    2: (lambda x: 2*x),  # just scale this block
}

v_hat_after = nonlinear_fourier_update(v_hat, block_sizes, blockwise_funcs)

# Bring it back to the original basis
v_after = from_fourier(Q, v_hat_after)

print("Output after blockwise nonlinearity (back in original basis):")
print(v_after)

# Sanity check: the transposition (0 1) should act block-diagonally in the
# Fourier basis.
sigma = tuple(range(n))
sigma = (sigma[1], sigma[0]) + sigma[2:]
N = block_diagonalize(n, k, Q, sigma)
print("Block-diagonal ρ((0 1)) in Fourier basis:")
print(N)


Block dimensions: {0: 1, 1: 3, 2: 2}
Verification passed: True
Q^T Q =
Matrix([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1]])
Fourier block sizes (per irrep): [1, 3, 2]
v_hat in Fourier basis:
Matrix([[2*sqrt(6)/3], [-3*sqrt(2)/2], [-3*sqrt(2)/2], [0], [-1], [4*sqrt(3)/3]])
Output after blockwise nonlinearity (back in original basis):
Matrix([[10/3], [-5/3], [1/3], [1/3], [-5/3], [10/3]])
Block-diagonal ρ((0 1)) in Fourier basis:
Matrix([[1, 0, 0, 0, 0, 0], [0, 0, -1, 0, 0, 0], [0, -1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, -1, 0], [0, 0, 0, 0, 0, 1]])


In [16]:
# Sanity check: display Q^T ρ((0 1)) Q, which should be block-diagonal.
sigma = tuple(range(n))
sigma = (sigma[1], sigma[0]) + sigma[2:]
N = block_diagonalize(n, k, Q, sigma)
print("Block-diagonal ρ((0 1)) in Fourier basis:")
N


Block-diagonal ρ((0 1)) in Fourier basis:


Matrix([
[1,  0,  0, 0,  0, 0],
[0,  0, -1, 0,  0, 0],
[0, -1,  0, 0,  0, 0],
[0,  0,  0, 1,  0, 0],
[0,  0,  0, 0, -1, 0],
[0,  0,  0, 0,  0, 1]])

In [36]:
def chi_two_row(n, i, cycle_type_part):
    """
    χ^{(n-i,i)}(σ) for σ with the given cycle type.

    Uses χ^{(n-i,i)} = π_i - π_{i-1} (Young's rule for two-row shapes).
    Here π_j(σ) is the number of j-subsets fixed setwise by σ, with
    π_{-1} = 0. A subset is fixed exactly when it is a union of cycles of
    σ, so π_j counts the ways to choose distinct cycles whose lengths sum
    to j.

    With X_r = number of r-cycles in σ, this reproduces the familiar
    character polynomials, e.g.
        i = 1:  X1 - 1
        i = 2:  C(X1,2) + X2 - X1
        i = 3:  C(X1,3) + X1*X2 + X3 - C(X1,2) - X2
    and it works for every i, not only i <= 3. The result is the
    character of S^{(n-i,i)} when 0 <= i <= n/2.

    n is accepted for interface compatibility; the value depends only on
    cycle_type_part. Returns a SymPy Rational, always an integer.
    Raises ValueError for negative i.
    """
    if i < 0:
        raise ValueError(f"i must be non-negative, got {i}")

    # fixed[j] = π_j for j = 0..i, built as a 0/1 subset-sum count over the
    # cycles. j runs downwards so each cycle is used at most once.
    fixed = [1] + [0] * i
    for length in cycle_type_part:
        for j in range(i, length - 1, -1):
            fixed[j] += fixed[j - length]

    previous = fixed[i - 1] if i > 0 else 0
    return Rational(fixed[i] - previous)
    

# χ^{(n-2,2)} evaluated on the class at index 2 of conjugacy_classes_with_members(4).
# n must match the group the cycle type comes from (S_4 here), so (n-2, 2)
# is an actual partition of n.
chi_two_row(4, 2, conjugacy_classes_with_members(4)[2]["type"])

-1