In [1]:
import numpy as np

In [2]:
def period(npa):
    """Find the smallest repeat step of ``npa`` below half its length.

    A step ``a`` qualifies when ``npa[a:]`` equals ``np.roll(npa, a)[a:]``.
    For a 1-D array this means ``npa[i + a] == npa[i]`` for every valid ``i``.
    Only steps with ``1 <= a < len(npa) // 2`` are tried; ``len(npa) // 2``
    itself is excluded.

    Returns the first qualifying step, or ``None`` if none is found.
    """
    limit = len(npa) // 2
    return next(
        (
            step
            for step in range(1, limit)
            if np.array_equal(npa[step:], np.roll(npa, step)[step:])
        ),
        None,
    )

In [3]:
def calc_cols(p, cmax, dmax, verbose, display=None):
    """Build ``cmax`` successive prefix-sum columns modulo ``p``.

    Column 0 is ``dmax`` ones. Each later column is ``np.cumsum`` of the
    previous one, reduced mod ``p``.

    Parameters
    ----------
    p : int
        Modulus applied after every cumulative sum.
    cmax : int
        Number of columns to produce.
    dmax : int
        Length of every column.
    verbose : bool
        If true, print each column's ``period`` and its leading entries.
    display : int or None, optional
        How many leading entries to print per column when ``verbose`` is
        true. ``None`` (the default) prints the whole column.

    Returns
    -------
    list of numpy.ndarray
        ``cmax`` integer arrays of length ``dmax``.
    """
    prev = np.ones(dmax, dtype=int)
    cols = []
    for _ in range(cmax):
        if verbose:
            # period() takes a single array. The old call passed dmax as a
            # second argument and sliced with an undefined name.
            print(period(prev), "\t", prev[:display])
        cols.append(prev)
        prev = np.cumsum(prev) % p
    return cols

In [4]:
def calc_snps(cols, nmax, verbose=False):
    """Count distinct stacked-rotation column vectors for heights 1..nmax-1.

    For each height ``n`` in ``1 .. nmax - 1``, every input column gets an
    ``n``-row matrix whose row ``k`` is ``np.roll(column, k)[nmax:]``. The rows
    are accumulated incrementally, one new row per height. The matrices of
    all columns are placed side by side, and the number of distinct columns
    (``np.unique(..., axis=1)``) is recorded.

    Returns a list that starts with ``1`` and is followed by one count per
    height. With ``verbose`` the same numbers are printed comma-separated on
    one line.
    """
    # Truncating each column to nmax + period is disabled: every period is
    # None, so every column is used in full.
    periods = [None] * len(cols)
    reduced = [
        col if per is None else col[:(nmax + per)]
        for col, per in zip(cols, periods)
    ]

    stacks = [None] * len(reduced)
    counts = [1]
    if verbose:
        print(1, end="")
    for n in range(1, nmax):
        distinct = None
        for idx, col in enumerate(reduced):
            if stacks[idx] is None:
                stacks[idx] = np.array([col[nmax:]])
            else:
                new_row = np.roll(col, n - 1)[nmax:]
                stacks[idx] = np.vstack((stacks[idx], new_row))
            if distinct is None:
                candidates = stacks[idx]
            else:
                candidates = np.hstack((distinct, stacks[idx]))
            distinct = np.unique(candidates, axis=1)
        width = distinct.shape[1]
        counts.append(width)
        if verbose:
            print(",", width, end="")
    if verbose:
        print()
    return counts

In [5]:
def gen_m(p, n):
    """Build an ``n x n`` lower-triangular matrix from ``calc_cols``.

    Column ``a`` of the result is column ``a`` of
    ``calc_cols(p, n, n, False)`` shifted down by ``a`` rows. The top ``a``
    entries are zero and the last ``a`` entries of the source column fall off
    the bottom. So ``m[i, a] == cols[a][i - a]`` for ``i >= a``, and 0
    otherwise.

    Parameters
    ----------
    p : int
        Modulus forwarded to ``calc_cols``.
    n : int
        Size of the square matrix.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n, n)``.
    """
    if n == 0:
        # np.vstack([]) raises; an empty square matrix is the natural result.
        return np.zeros((0, 0), dtype=int)
    cols = calc_cols(p, n, n, False)
    shifts = []
    for a, col in enumerate(cols):
        # np.int was deprecated in NumPy 1.20 and removed in 1.24; the
        # builtin int is the documented replacement.
        shift = np.zeros(col.shape, dtype=int)
        # col[:len - a] also covers a == 0, where col[:-0] would be empty.
        shift[a:] = col[:len(col) - a]
        shifts.append(shift)
    return np.transpose(np.vstack(shifts))

In [6]:
def calc_classes(m, n, p):
    """Partition scaled rotations of the columns of ``m`` into two sets.

    For each column of ``m`` and each ``r`` in ``range(len(column))``, the
    column is cyclically shifted by ``r`` (``np.roll(column, r)``). The head
    ``[:r+1]`` is then multiplied by ``b`` and the tail ``[r+1:]`` by ``a``,
    both mod ``p``, for every ``a, b`` in ``range(p)``. The rotation amount
    and the split point are the same ``r``.

    Parameters
    ----------
    m : numpy.ndarray
        2-D array whose columns are enumerated (e.g. the output of ``gen_m``).
    n : int
        Unused; kept so existing calls keep working.
    p : int
        Modulus; the multipliers ``a`` and ``b`` range over ``0 .. p-1``.

    Returns
    -------
    (class_1, class_2) : tuple of set of tuple
        ``class_1`` holds the sequences produced with both ``a`` and ``b``
        non-zero. ``class_2`` holds the sequences produced with ``a == 0``
        or ``b == 0`` that do not also appear in ``class_1``.
    """
    # Columns are indexed along axis 1 and each column has shape[0] entries.
    # The two were swapped before, which only worked for square m.
    n_rows, n_cols = m.shape
    class_1 = set()
    zero_scaled = set()  # sequences with a zero multiplier; filtered at the end
    for c in range(n_cols):
        rolled = m[:, c]
        for r in range(n_rows):
            head, tail = rolled[:r + 1], rolled[r + 1:]
            for a in range(p):
                scaled_tail = (a * tail) % p  # independent of b, so hoisted
                for b in range(p):
                    seq = tuple(np.concatenate(((b * head) % p, scaled_tail)))
                    if a == 0 or b == 0:
                        zero_scaled.add(seq)
                    else:
                        class_1.add(seq)
            rolled = np.roll(rolled, 1)
    # Same result as filtering each zero-multiplier sequence against the
    # finished class_1, done in one pass instead of re-enumerating.
    return class_1, zero_scaled - class_1

In [7]:
def print_classes(n, p):
    """Print the sizes of the two sets ``calc_classes`` returns for
    ``gen_m(p, n)``, on one line separated by a space."""
    first, second = calc_classes(gen_m(p, n), n, p)
    print(len(first), len(second))

In [8]:
def class_1_closed_form(n, p):
    """Closed-form candidate for the size of class 1, compared in this
    notebook against the brute-force ``calc_classes`` count.

    Value: (p-1)(p n^2 + 1) // (p+1) + (p-1)(p-2) n (n-1) // 2, using integer
    floor division applied left to right.
    """
    q = p - 1
    first_term = q * (p * n * n + 1) // (p + 1)
    second_term = q * (p - 2) * n * (n - 1) // 2
    return first_term + second_term

In [16]:
def class_2_minus_1_closed_form(n, p):
    """Closed-form candidate compared in this notebook against the
    brute-force class-2 count.

    Value: 1 + (p-1)(n-1)((p-1) n - 2) // (p+1), using integer floor
    division applied left to right.
    """
    q = p - 1
    return 1 + q * (n - 1) * (q * n - 2) // (p + 1)

In [27]:
# Brute-force the class sizes for N = P**K and check them against the closed forms.
P = 5
K = 3
N = P**K

brute_class_1, brute_class_2 = calc_classes(gen_m(P, N), N, P)
# Prints the same line print_classes(N, P) would, but keeps the sets so the
# expensive enumeration runs only once.
print(len(brute_class_1), len(brute_class_2))
closed_forms = (class_1_closed_form(N, P), class_2_minus_1_closed_form(N, P))
assert (len(brute_class_1), len(brute_class_2)) == closed_forms, (
    "closed forms disagree with brute force for P=%d, K=%d" % (P, K)
)
closed_forms

145084 41169


(145084, 41169)

In [61]:
# Show where gen_m(P, N) is non-zero, followed by the number of non-zero entries.
m = gen_m(P, N)
nonzero = np.not_equal(m, 0)
print(nonzero, nonzero.sum())

[[ True False False False False False False False False]
 [ True  True False False False False False False False]
 [ True  True  True False False False False False False]
 [ True False False  True False False False False False]
 [ True  True False  True  True False False False False]
 [ True  True  True  True  True  True False False False]
 [ True False False  True False False  True False False]
 [ True  True False  True  True False  True  True False]
 [ True  True  True  True  True  True  True  True  True]] 36
