In [10]:
# This program mimics the MatSPL representation in Spiral
# General Design:
#     Everything is a Mat
#     Mats are evaluated using getMat, which is done as late as possible to maximize flexibility
#     There are 4 base Mats (NTT,Tw,L,I) which correspond to the Spiral definitions
#     There are 2 binary Mat operations (MM (Matrix Multiply), Tensor) which are themselves Mats

class Mat:
    """Common base for every matrix expression in this notebook.

    Stores the dimension ``n`` and the modulus ``p``; subclasses supply the
    actual entries by overriding ``getMat``.
    """
    def __init__(self,n,p):
        # n: matrix dimension, p: modulus
        self.n,self.p=n,p
    def getMat(self):
        """Return the n x n matrix; the base class has none and returns None."""
        return None
    def __str__(self):
        # Separator line followed by the string form of the evaluated matrix.
        rendered=str(self.getMat())
        return "---------\n{}".format(rendered)

class NTT(Mat):
    """n-point number-theoretic transform matrix modulo p.

    With ``NW=False`` the entry in row j, column i is ``w**(i*j) mod p``.
    With ``NW=True`` (negative-wrapped variant) each entry is additionally
    multiplied by ``psi**i``.  ``w`` and ``psi`` are powers of
    ``primitive_root(p)``.
    """
    def __init__(self,n,p,NW=False):
        Mat.__init__(self,n,p)
        # p-1 must be divisible by n, or by 2n for the negative-wrapped form.
        divisor=2*n if NW else n
        assert (p-1)%divisor==0
        gen=primitive_root(p)
        exponent=(p-1)//n
        self.psi=power_mod(gen,exponent//2,p)
        self.w=power_mod(gen,exponent,p)
        # Bind the evaluator matching the requested variant on the instance.
        if NW:
            self.getMat=self.term_psi
        else:
            self.getMat=self.term
    def term(self):
        """Return the plain NTT matrix (entries w**(i*j) mod p)."""
        size,mod=self.n,self.p
        rows=[]
        for row in range(size):
            rows.append([power_mod(self.w,col*row,mod) for col in range(size)])
        return Matrix(rows)
    def term_psi(self):
        """Return the negative-wrapped matrix (entries w**(i*j) * psi**i mod p)."""
        size,mod=self.n,self.p
        rows=[]
        for row in range(size):
            rows.append([power_mod(self.w,col*row,mod)*power_mod(self.psi,col,mod)%mod for col in range(size)])
        return Matrix(rows)

class Tw(Mat):
    """Diagonal twiddle-factor matrix of size n for a split into blocks of size m.

    The diagonal is laid out as n//m consecutive blocks of m entries.  For
    block index i and position j inside the block the entry is

    * ``w**(i*j) mod p``                         when ``NW`` is False
    * ``w**(i*j) * psi**((1-m)*i) mod p``        when ``NW`` is True

    Parameters
    ----------
    n  : matrix dimension (must be a multiple of m)
    p  : modulus; p-1 must be divisible by n (by 2n when NW is True)
    m  : inner block size
    NW : select the negative-wrapped variant
    """
    def __init__(self,n,p,m,NW=False):
        Mat.__init__(self,n,p)
        assert n%m==0
        # Same divisibility requirement NTT enforces, so w (and psi) are valid roots.
        if NW:
            assert (p-1)%(2*n)==0
        else:
            assert (p-1)%n==0
        self.NW=NW
        self.m=m
        k=(p-1)//n
        r=primitive_root(p)
        self.psi=power_mod(r,k//2,p)
        self.w=power_mod(r,k,p)
    def getMat(self):
        """Return the n x n diagonal matrix of twiddle factors."""
        diag=[]
        for i in range(self.n//self.m):
            if self.NW:
                # The psi correction depends only on the block index i.
                # Use self.m here: the factor must come from this object's own
                # block size, not from whatever variable named m the caller has.
                correction=power_mod(self.psi,(-self.m+1)*i,self.p)
            for j in range(self.m):
                if self.NW:
                    diag.append((power_mod(self.w,i*j,self.p)*correction)%self.p)
                else:
                    diag.append(power_mod(self.w,i*j,self.p))
        return Matrix([[diag[col] if col==row else 0 for col in range(self.n)] for row in range(self.n)])

class Tensor(Mat):
    """Tensor (Kronecker) product of two Mats sharing the same modulus.

    The result has dimension ``m1.n*m2.n``; entry (i, j) is the product of the
    m1 entry selected by integer-dividing i and j by the size of m2 and the m2
    entry selected by the remainders, reduced mod p.
    """
    def __init__(self,m1,m2):
        Mat.__init__(self,m1.n*m2.n,m1.p)
        assert m1.p==m2.p
        self.m1,self.m2=m1,m2
    def getMat(self):
        """Evaluate both operands and assemble the product matrix."""
        outer=self.m1.getMat()
        inner=self.m2.getMat()
        block=inner.nrows()
        rows=[]
        for i in range(self.n):
            rows.append([outer[i//block][j//block]*inner[i%block][j%block]%self.p for j in range(self.n)])
        return Matrix(rows)

class MM(Mat):
    """Matrix product of two Mats with equal dimension and modulus."""
    def __init__(self,m1,m2):
        Mat.__init__(self,m1.n,m1.p)
        assert m1.n==m2.n and m1.p==m2.p
        self.m1,self.m2=m1,m2
    def getMat(self):
        """Multiply the evaluated operands and reduce every entry mod p."""
        product=self.m1.getMat()*self.m2.getMat()
        reduced=[]
        for row in range(self.n):
            reduced.append([product[row][col]%self.p for col in range(self.n)])
        return Matrix(reduced)
    
class I(Mat):
    """n x n identity Mat."""
    def __init__(self,n,p):
        Mat.__init__(self,n,p)
    def getMat(self):
        """Return the identity matrix of dimension n."""
        rows=[]
        for row in range(self.n):
            rows.append([1 if col==row else 0 for col in range(self.n)])
        return Matrix(rows)

class L(Mat):
    """Stride permutation Mat of dimension n with stride ``str``.

    Row j of the matrix has a single 1 in column ``perm(j)``.
    """
    def __init__(self,n,p,str):
        Mat.__init__(self,n,p)
        self.str=str
        # m is the number of elements per stride group.
        self.m=n//str
    def perm(self,i):
        """Map index i to i//m + str*(i%m)."""
        quotient=i//self.m
        remainder=i%self.m
        return quotient+self.str*remainder
    def getMat(self):
        """Return the 0/1 matrix realising the permutation."""
        rows=[]
        for src in range(self.n):
            target=self.perm(src)
            rows.append([1 if target==col else 0 for col in range(self.n)])
        return Matrix(rows)

#some helper functions
def equal(m1,m2):
    """Return True when two Mats have the same n, the same p and identical entries."""
    if m1.n!=m2.n:
        return False
    if m1.p!=m2.p:
        return False
    left=m1.getMat()
    right=m2.getMat()
    # Any single differing entry makes the Mats unequal.
    mismatch=any(left[r][c]!=right[r][c] for r in range(m1.n) for c in range(m1.n))
    return not mismatch

def transform(m,v):
    """Apply Mat m to vector v and return the result as a list reduced mod m.p."""
    reduced=[]
    for entry in m.getMat()*v:
        reduced.append(entry%m.p)
    return reduced

def getMod(n,NW=False):
    """Return the smallest prime p such that p-1 is a multiple of n.

    When ``NW`` is True the requirement is strengthened to p-1 being a
    multiple of 2*n, matching the divisibility check NTT performs for its
    negative-wrapped variant.

    Raises ValueError if n is smaller than 1 (the search could never end).
    """
    if n<1:
        raise ValueError("n must be a positive integer")
    step=2*n if NW else n
    # Walk through every multiple of the required divisor.  Stepping by 2*n
    # from n (odd multiples only) never terminates for odd n > 1, because
    # c+1 is then always even.
    c=step
    while not is_prime(c+1):
        c+=step
    return c+1

In [11]:
#radix-n/m FFT for NTT
# Checks that the product (NTT_{n/m} (x) I_m) * Tw * (I_{n/m} (x) NTT_m) * L
# equals the direct n-point NTT matrix, with all arithmetic done mod p.

#setup
# n is the transform size, m the inner block size; p comes from getMod(n).
n=6
m=3
p=getMod(n)
print(p)

#transform matrices
# m1: NTT of size n/m tensored with the m x m identity
m1=Tensor(NTT(n//m,p),I(m,p))
# m2: diagonal twiddle matrix
m2=Tw(n,p,m)
# m3: identity of size n/m tensored with the NTT of size m
m3=Tensor(I(n//m,p),NTT(m,p))
# m4: stride permutation with stride n/m
m4=L(n,p,n//m)
print(m1)
print(m2)
print(m3)
print(m4)

#multiply and check result
# The product is grouped as (m1*m2)*(m3*m4) and compared against NTT(n,p).
res=MM(MM(m1,m2),MM(m3,m4))
print(res)
print(NTT(n,p))
print("FFT result is equivalent:")
print(equal(res,NTT(n,p)))

7
---------
[1 0 0 1 0 0]
[0 1 0 0 1 0]
[0 0 1 0 0 1]
[1 0 0 6 0 0]
[0 1 0 0 6 0]
[0 0 1 0 0 6]
---------
[1 0 0 0 0 0]
[0 1 0 0 0 0]
[0 0 1 0 0 0]
[0 0 0 1 0 0]
[0 0 0 0 3 0]
[0 0 0 0 0 2]
---------
[1 1 1 0 0 0]
[1 2 4 0 0 0]
[1 4 2 0 0 0]
[0 0 0 1 1 1]
[0 0 0 1 2 4]
[0 0 0 1 4 2]
---------
[1 0 0 0 0 0]
[0 0 1 0 0 0]
[0 0 0 0 1 0]
[0 1 0 0 0 0]
[0 0 0 1 0 0]
[0 0 0 0 0 1]
---------
[1 1 1 1 1 1]
[1 3 2 6 4 5]
[1 2 4 1 2 4]
[1 6 1 6 1 6]
[1 4 2 1 4 2]
[1 5 4 6 2 3]
---------
[1 1 1 1 1 1]
[1 3 2 6 4 5]
[1 2 4 1 2 4]
[1 6 1 6 1 6]
[1 4 2 1 4 2]
[1 5 4 6 2 3]
FFT result is equivalent:
True


In [12]:
#radix-n/m FFT for Negative Wrapped NTT
# Same check as for the plain NTT, but every NTT and the twiddle matrix are
# built with NW=True and the result is compared against NTT(n,p,NW=True).

#setup
# n is the transform size, m the inner block size; p comes from getMod(n,NW=True).
n=6
m=3
p=getMod(n,NW=True)
print(p)

#transform matrices
# m1: negative-wrapped NTT of size n/m tensored with the m x m identity
m1=Tensor(NTT(n//m,p,NW=True),I(m,p))
# m2: diagonal twiddle matrix in its negative-wrapped form
m2=Tw(n,p,m,NW=True)
# m3: identity of size n/m tensored with the negative-wrapped NTT of size m
m3=Tensor(I(n//m,p),NTT(m,p,NW=True))
# m4: stride permutation with stride n/m (identical to the plain case)
m4=L(n,p,n//m)
print(m1)
print(m2)
print(m3)
print(m4)

#multiply and check result
# The product is grouped as (m1*m2)*(m3*m4) and compared against the direct matrix.
res=MM(MM(m1,m2),MM(m3,m4))
print(res)
print(NTT(n,p,NW=True))
print("FFT result is equivalent:")
print(equal(res,NTT(n,p,NW=True)))

13
---------
[1 0 0 8 0 0]
[0 1 0 0 8 0]
[0 0 1 0 0 8]
[1 0 0 5 0 0]
[0 1 0 0 5 0]
[0 0 1 0 0 5]
---------
[ 1  0  0  0  0  0]
[ 0  1  0  0  0  0]
[ 0  0  1  0  0  0]
[ 0  0  0 10  0  0]
[ 0  0  0  0  1  0]
[ 0  0  0  0  0  4]
---------
[ 1  4  3  0  0  0]
[ 1 12  1  0  0  0]
[ 1 10  9  0  0  0]
[ 0  0  0  1  4  3]
[ 0  0  0  1 12  1]
[ 0  0  0  1 10  9]
---------
[1 0 0 0 0 0]
[0 0 1 0 0 0]
[0 0 0 0 1 0]
[0 1 0 0 0 0]
[0 0 0 1 0 0]
[0 0 0 0 0 1]
---------
[ 1  2  4  8  3  6]
[ 1  8 12  5  1  8]
[ 1  6 10  8  9  2]
[ 1 11  4  5  3  7]
[ 1  5 12  8  1  5]
[ 1  7 10  5  9 11]
---------
[ 1  2  4  8  3  6]
[ 1  8 12  5  1  8]
[ 1  6 10  8  9  2]
[ 1 11  4  5  3  7]
[ 1  5 12  8  1  5]
[ 1  7 10  5  9 11]
FFT result is equivalent:
True
