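## Setup (Windows only): run this cell before any other step

It uses `os.add_dll_directory` to make `multi_ntt.dll` (expected in the current working directory) and the CUDA toolkit's `bin` directory loadable, then declares the C signature of `poly_mul_multi_c`.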

In [1]:
import os
import ctypes
import numpy as np

# Make DLLs in the notebook's working directory (where multi_ntt.dll is expected) loadable.
# os.add_dll_directory exists only on Windows (Python 3.8+).
here = os.getcwd()
os.add_dll_directory(here)

# multi_ntt.dll presumably depends on CUDA runtime DLLs, so the CUDA toolkit's bin directory must
# also be on the DLL search path. Prefer CUDA_PATH (normally set by the CUDA Toolkit installer)
# over a machine-specific absolute path, but keep the original v12.6 location as a fallback so
# every setup that worked before still works.
DEFAULT_CUDA_BIN = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin"
cuda_bin_candidates = [DEFAULT_CUDA_BIN]
if os.environ.get("CUDA_PATH"):
    cuda_bin_candidates.insert(0, os.path.join(os.environ["CUDA_PATH"], "bin"))
# dict.fromkeys removes exact duplicates while keeping priority order.
cuda_bins = [d for d in dict.fromkeys(cuda_bin_candidates) if os.path.isdir(d)]
if not cuda_bins:
    raise FileNotFoundError(
        "CUDA bin directory not found; set the CUDA_PATH environment variable "
        "or install the CUDA toolkit at " + DEFAULT_CUDA_BIN
    )
cuda_bin = cuda_bins[0]
for cuda_dir in cuda_bins:
    os.add_dll_directory(cuda_dir)

# Load by absolute path so a same-named DLL elsewhere on the search path is never picked up.
lib = ctypes.CDLL(os.path.join(here, "multi_ntt.dll"))

# Declare the C signature so ctypes converts/checks arguments instead of guessing.
poly_mul = lib.poly_mul_multi_c
poly_mul.restype = ctypes.c_int
poly_mul.argtypes = [
    ctypes.POINTER(ctypes.c_uint32),  # A_host
    ctypes.POINTER(ctypes.c_uint32),  # B_host
    ctypes.POINTER(ctypes.c_uint32),  # C_host (output, written by the C function)
    ctypes.c_uint,                    # N
    ctypes.POINTER(ctypes.c_uint32),  # primitive_roots
    ctypes.POINTER(ctypes.c_uint32),  # primitive_roots_inv
    ctypes.POINTER(ctypes.c_uint32),  # primes
    ctypes.c_uint                     # num_polys
]

## Defining the Python Wrapper

In [2]:
def fast_multi_polynomial_multiplication(A, B, N, num_polys, primes, primitive_roots, primitive_roots_inv):
    """Multiply ``num_polys`` pairs of polynomials with the native ``poly_mul_multi_c`` routine.

    Polynomial ``i`` of ``A`` is multiplied by polynomial ``i`` of ``B``. The notebook's test
    vectors are consistent with a negacyclic product (reduction modulo ``x**N + 1``) whose
    coefficients are reduced modulo ``primes[i]`` -- NOTE(review): confirm against multi_ntt.

    A: the first polynomials, as a list of lists or any 2-D array-like of shape
       ``(num_polys, N)``; row ``i`` holds the ``N`` coefficients of a polynomial of degree
       at most ``N - 1``. Values must fit in ``uint32``.

    B: the second polynomials, in the same layout as ``A``.

    N: number of coefficients per polynomial (the degree is at most ``N - 1``).

    num_polys: number of polynomial pairs to multiply.

    primes, primitive_roots, primitive_roots_inv: per-polynomial NTT parameters, one entry
       per polynomial (presumably ``primitive_roots_inv[i]`` is the inverse of
       ``primitive_roots[i]`` modulo ``primes[i]``, as in the notebook's test values).

    Returns a ``uint32`` array of shape ``(num_polys, N)`` holding the products.

    Raises ValueError if an input has too few elements for ``num_polys`` and ``N`` (the C
    function receives only raw pointers, so a short buffer would be read out of bounds), and
    RuntimeError if the native call returns a nonzero status.
    """
    # Plain Python ints convert reliably to the ctypes.c_uint arguments.
    N = int(N)
    num_polys = int(num_polys)
    u32_ptr = ctypes.POINTER(ctypes.c_uint32)

    def as_uint32_buffer(values, name, min_size):
        # Always copy into a fresh C-contiguous uint32 buffer: only the raw data pointer reaches
        # C, so layout matters, and copying keeps the caller's arrays safe from native writes.
        # This also replaces the quadratic `sum(list_of_lists, [])` flattening.
        buf = np.array(values, dtype=np.uint32, order="C").reshape(-1)
        if buf.size < min_size:
            raise ValueError(
                f"{name} has {buf.size} elements; expected at least {min_size}"
            )
        return buf

    A_host = as_uint32_buffer(A, "A", num_polys * N)
    B_host = as_uint32_buffer(B, "B", num_polys * N)
    primitive_roots_numpy = as_uint32_buffer(primitive_roots, "primitive_roots", num_polys)
    primitive_roots_inv_numpy = as_uint32_buffer(primitive_roots_inv, "primitive_roots_inv", num_polys)
    primes_numpy = as_uint32_buffer(primes, "primes", num_polys)

    # Allocate the output buffer; the C function fills it in place.
    C_host = np.zeros(num_polys * N, dtype=np.uint32)

    status = poly_mul(
        A_host.ctypes.data_as(u32_ptr),
        B_host.ctypes.data_as(u32_ptr),
        C_host.ctypes.data_as(u32_ptr),
        N,
        primitive_roots_numpy.ctypes.data_as(u32_ptr),
        primitive_roots_inv_numpy.ctypes.data_as(u32_ptr),
        primes_numpy.ctypes.data_as(u32_ptr),
        num_polys
    )
    # NOTE(review): assumes 0 means success (usual C / cudaSuccess convention) -- confirm
    # against multi_ntt's header. Previously the status was silently discarded.
    if status != 0:
        raise RuntimeError(f"poly_mul_multi_c failed with status {status}")

    return C_host.reshape((num_polys, N))

## Setting Up the Test Case

In [3]:
# Test parameters: the i-th prime / root / inverse-root triple is used for the i-th polynomial
# (consistent with the expected values in the verification cell).
primes              = [536871001, 536871017, 536871089, 536871233, 536871337]
primitive_roots     = [11, 3, 3, 3, 10]
primitive_roots_inv = [146419364, 178957006, 178957030, 178957078, 375809936]
N = 4          # coefficients per polynomial
num_polys = 5  # number of independent (A_i, B_i) products

A_blocks = [
    [51623921, 107100116, 420317839, 529122549],
    [514852150, 175001745,  12583352, 34364657],
    [24888899, 137505621, 368513557, 19284858],
    [326213995, 66090711, 416104718, 468974766],
    [28829878, 269046690, 230329201, 48678983],
]
# Build a fresh list per polynomial: `[[1, 0, 0, 1]] * num_polys` would make every entry the
# same list object, so mutating one row would silently change all of them.
B_blocks = [[1, 0, 0, 1] for _ in range(num_polys)]

# Sanity-check the parameters before they are handed to native code as raw pointers.
assert len(primes) == len(primitive_roots) == len(primitive_roots_inv) == num_polys
assert len(A_blocks) == len(B_blocks) == num_polys
assert all(len(poly) == N for poly in A_blocks + B_blocks)
assert all(
    r * r_inv % p == 1
    for p, r, r_inv in zip(primes, primitive_roots, primitive_roots_inv)
), "each primitive_roots_inv[i] must be the inverse of primitive_roots[i] mod primes[i]"

## Calling the Function

In [4]:
# Run the batched GPU multiplication; keyword arguments make each parameter's role explicit.
results = fast_multi_polynomial_multiplication(
    A=A_blocks,
    B=B_blocks,
    N=N,
    num_polys=num_polys,
    primes=primes,
    primitive_roots=primitive_roots,
    primitive_roots_inv=primitive_roots_inv,
)

## Verifying the Results

In [None]:
# Expected products, one row per polynomial. They are consistent with (A_i * B_i) mod (x^N + 1)
# with coefficients reduced modulo primes[i]; e.g. row 0, coefficient 0 is
# (51623921 - 107100116) mod 536871001 = 481394806.
expected = np.array([
    [481394806, 223653278, 428066291,  43875469],
    [339850405, 162418393, 515089712,  12345790],
    [424254367, 305863153, 349228699,  44173757],
    [260123284, 186857226, 484001185, 258317528],
    [296654525,  38717489, 181650218,  77508861],
], dtype=np.uint32)

if not np.array_equal(results, expected):
    print("Got:", results)
    print("Expected:", expected)
    raise RuntimeError("NTT polynomial-multiplication test failed!")
# Derive the count from the data so the message stays accurate if test cases change.
print(f"All {len(expected)} modular multiplications correct! 🎉")


All 5 modular multiplications correct! 🎉
