In [1]:
import jax
import jax.numpy as jnp
import jax.nn as nn
from functools import partial
from itertools import product
from math import factorial

In [52]:
# Number of letters being permuted, i.e. we study the symmetric group S_N,
N = 5
# whose order is |S_N| = N!.
ORD = factorial(N)

In [53]:
# find all partitions for given integer
def get_partitions(n):
    if n == 0:
        yield jnp.array([], dtype="i4")
    if n < 0:
        return
    for p in get_partitions(n-1):
        yield jnp.concatenate((p, jnp.array([1])))
        l = len(p)
        if l == 1 or (l > 1 and p[-1] < p[-2]):
            yield p + nn.one_hot(l-1, l, dtype="i4")
            
# translate partition into cycle type
def get_cycle_type(p, l=None):
    if not l:
        l = max(p)
    kvec = lambda lam: nn.one_hot(lam-1, l, dtype="i4")
    return jnp.sum(jnp.apply_along_axis(kvec, 0, p), axis=0)

def get_all_cycle_types(n):
    for p in get_partitions(n):
        yield get_cycle_type(p)
                
# find all repartitions of a given partition
def get_repartitions(p, k):
    l = len(k)
    def gp(lam):
        # calculate cycle types for all repartitions of a given row length
        return [get_cycle_type(mu, l) for mu in get_partitions(lam)]
    # calculate all these cycle types for all rows
    reparts = [gp(lam) for lam in p]
    # calculate cartesian product of the sets of repartitions of rows to 
    # get the repartitions of p
    reparts = jnp.array(list(product(*reparts)))
    # calculate the circle types of the new repartitions
    reparts_ks = jnp.sum(reparts, axis=1)
    # accept only the repartitions with repart_k = k
    mask = jnp.product(reparts_ks == k, axis=1, dtype="bool")
    reparts = reparts[mask]
    # order along second axis of repartitions is reversed, but does not matter 
    # because of the product over j in (3.68) in the script
    return reparts

def psi(p, k):
    # calculate entries of Psi matrix
    res = jnp.array(0., dtype="f4")
    for r in get_repartitions(p, k):
        k_factorial = jnp.array([factorial(k_i) for k_i in k])
        r_factorial = jnp.array([[factorial(r_ij) for r_ij in r_j] for r_j in r])
        res += jnp.product(k_factorial)/jnp.product(r_factorial)
    return res

In [58]:
Psi = jnp.array([[psi(p, k) for k in get_all_cycle_types(N)] for p in get_partitions(N)],
                dtype="i4")
print("Psi:\n", Psi, "\n")

def ord_C(k):
    # order of stabilizer
    ord_stab = jnp.product(jnp.array([(i+1)**k_i * factorial(k_i) 
                                      for (i, k_i) in enumerate(k)]))
    return ORD / ord_stab

Sigma = jnp.diag(jnp.array([ord_C(k) / ORD for k in get_all_cycle_types(N)]))
print("Sigma:\n", Sigma, "\n")

PSPT = (Psi @ Sigma @ Psi.T).astype("i1")
print("Psi * Sigma * Psi.T:\n", PSPT, "\n")

Psi:
 [[120   0   0   0   0   0   0]
 [ 60   6   0   0   0   0   0]
 [ 30   6   2   0   0   0   0]
 [ 20   6   0   2   0   0   0]
 [ 10   4   2   1   1   0   0]
 [  5   3   1   2   0   1   0]
 [  1   1   1   1   1   1   1]] 

Sigma:
 [[0.00833333 0.         0.         0.         0.         0.
  0.        ]
 [0.         0.08333334 0.         0.         0.         0.
  0.        ]
 [0.         0.         0.125      0.         0.         0.
  0.        ]
 [0.         0.         0.         0.16666667 0.         0.
  0.        ]
 [0.         0.         0.         0.         0.16666667 0.
  0.        ]
 [0.         0.         0.         0.         0.         0.25
  0.        ]
 [0.         0.         0.         0.         0.         0.
  0.2       ]] 

Psi * Sigma * Psi.T:
 [[120  60  30  20  10   5   1]
 [ 60  33  18  13   7   4   1]
 [ 30  18  11   8   5   3   1]
 [ 20  13   8   7   4   3   1]
 [ 10   7   5   4   3   2   1]
 [  5   4   3   3   2   2   1]
 [  1   1   1   1   1   1   1]] 



# How to determine K?

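We look for an upper-triangular matrix $K$ with $KK^T = A$, where $A$ is the matrix $\Psi\Sigma\Psi^T$ computed above. Since $K_{ik} = 0$ for $k < i$, we can see that
$$ A_{ij} := [KK^T]_{ij} = \sum_{k=\max(i,j)}^{n} K_{ik}K_{jk} $$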

The last column can easily be determined: since $A_{nn} = K_{nn}^2$ and $A_{kn} = K_{kn}K_{nn}$, we get
$$ K_{nn} = \sqrt{A_{nn}}, \quad K_{kn} = \frac{A_{kn}}{K_{nn}} $$

Then we can define a matrix
$$ [B]_{ij} := K_{in}K_{jn} $$

And construct the matrix
$$ \tilde{A} := A - B $$

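The last row and column of $\tilde{A}$ vanish, and its remaining entries are $\tilde{A}_{ij} = \sum_{k=\max(i,j)}^{n-1} K_{ik}K_{jk}$. Its leading $(n-1)\times(n-1)$ block is therefore a lower-dimensional version of our initial problem,
so we can simply start over and solve for $K$ recursively!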

In [60]:
K = jnp.zeros(shape=(len(PSPT), len(PSPT))).tolist()

def fill_K_column(A, n):
    if len(A) == 0:
        return
    K[n-1][n-1] = float(jnp.sqrt(A[n-1, n-1]))
    for i in range(n):
        K[i][n-1] = float(A[i, n-1]/K[n-1][n-1])
    
    K_arr = jnp.array(K)
    B = jnp.fromfunction(lambda i, j: K_arr[i, n-1]*K_arr[j, n-1], 
                         shape=(n, n), dtype="i1")
    
    A_tilde = (A - B)[:-1, :-1]
    return fill_K_column(A_tilde, n-1)

fill_K_column(PSPT, len(PSPT))
K = jnp.array(K, dtype="i1")
print("K:\n", K, "\n")

K:
 [[1 4 5 6 5 4 1]
 [0 1 2 3 3 3 1]
 [0 0 1 1 2 2 1]
 [0 0 0 1 1 2 1]
 [0 0 0 0 1 1 1]
 [0 0 0 0 0 1 1]
 [0 0 0 0 0 0 1]] 



In [66]:
X = (jnp.linalg.inv(K) @ Psi).astype("i1")
print("X:\n", X, "\n")

print("X * Sigma * X.T:\n", (X @ Sigma @ X.T).astype("i1"))

X:
 [[ 1 -1  1  1 -1 -1  1]
 [ 4 -2  0  1  1  0 -1]
 [ 5 -1  1 -1 -1  1  0]
 [ 6  0 -2  0  0  0  1]
 [ 5  1  1 -1  1 -1  0]
 [ 4  2  0  1 -1  0 -1]
 [ 1  1  1  1  1  1  1]] 

X * Sigma * X.T:
 [[1 0 0 0 0 0 0]
 [0 1 0 0 0 0 0]
 [0 0 1 0 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 0 0 1 0]
 [0 0 0 0 0 0 1]]
