In [3]:
import random
from tqdm import tqdm
import concurrent.futures
from functools import partial
import numba
import numpy as np
from multiprocessing import Pool


In [4]:
# @numba.njit
def miller_rabin(n, k):
    """Probabilistic Miller-Rabin primality test.

    Parameters
    ----------
    n : int
        Integer to test.
    k : int
        Number of random-witness rounds. A composite passes a single round
        with probability at most 1/4, so more rounds mean fewer false
        positives. For the choice of round count (40 is a common
        cryptographic choice) see
        http://stackoverflow.com/questions/6325576/how-many-iterations-of-rabin-miller-should-i-use-for-cryptographic-safe-primes

    Returns
    -------
    bool
        False if ``n`` is definitely composite or smaller than 2; True if
        ``n`` is prime or survived all ``k`` rounds (probably prime).
    """
    # 0, 1 and negative numbers are not prime. Without this guard n == 1
    # loops forever below (s == 0 stays even) and negative odd n reaches
    # randrange with an empty range.
    if n < 2:
        return False
    # 2 and 3 are prime. 3 must be special-cased: there is no witness in
    # [2, n - 2], and randrange(2, 2) would raise ValueError.
    if n < 4:
        return True
    # Any other even number is composite.
    if n % 2 == 0:
        return False

    # Write n - 1 as 2**r * s with s odd.
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2

    for _ in range(k):
        a = random.randrange(2, n - 1)  # random witness in [2, n - 2]
        x = pow(a, s, n)
        if x == 1 or x == n - 1:
            continue
        # Square up to r - 1 more times, looking for n - 1.
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            # n - 1 never appeared: `a` witnesses that n is composite.
            return False
    return True


# @numba.njit
def modular_pow(base, exponent, modulus):
    """Compute ``base ** exponent % modulus`` by right-to-left binary exponentiation.

    ``base`` may be a scalar or a NumPy array; the arithmetic is element-wise
    and the result has the shape/dtype of ``np.ones_like(base)``.
    ``exponent`` is expected to be a non-negative integer scalar; for a
    non-positive exponent the loop does not run and 1 is returned.

    NOTE: the Wikipedia pseudocode asserts that ``(modulus - 1) * (modulus - 1)``
    must not overflow. "Base" there means the integer *type* in use, not the
    ``base`` argument. With int64 values, intermediate products overflow once
    ``modulus`` exceeds roughly sqrt(2**63) ~= 3.04e9.
    """
    if modulus == 1:
        # Every integer is congruent to 0 modulo 1.
        return np.zeros_like(base)
    # The running product must start at the multiplicative identity (1);
    # starting at 0 made every result 0.
    result = np.ones_like(base)
    base = base % modulus
    while exponent > 0:
        if exponent % 2 == 1:
            # The current low bit of the exponent is set: fold this power in.
            result = (result * base) % modulus
        exponent >>= 1
        base = (base * base) % modulus
    return result
    



In [5]:
# Quick sanity check: time a single 10-round Miller-Rabin call on the small prime 73.
# for x in range(5):
%time miller_rabin(73, 10)
# %time modular_pow(66, 39, 79)

CPU times: user 20 µs, sys: 5 µs, total: 25 µs
Wall time: 26.9 µs


True

In [6]:
# miller_rabin(999988277, 10)

In [34]:
# ASCII-art template for the candidate numbers: every character is a decimal
# digit. The run of zeros on the last row is where make_number writes the
# search counter. The backslash after the opening quotes suppresses the
# leading newline, so the text starts directly on the first row.
orign = """\
11111111111111111111111111111111111111111111111111111111111111
11111111111111111111111118888888888888888811111111111111111111
11111111111111111111118888888888888888888888111111111111111111
11111111111111111111888888881111111111111111111111111111111111
11111111111111111188888881111888888888888888881111111111111111
11111111111111111188888118888888888888888888888881111111111111
11111111111111118888881188888888888888888888888888881111111111
11111111111111118888881188888888888888888888888888881111111111
11111111111111118888881188888888888888888888888888888811111111
11111111111111118888888118888888888888888888888888888111111111
11111111118888118888888111111188888888888888888881111111111111
11111111118888118888888888811111118888888888111111111111111111
11111188888888118888888888888888111111111111188811111111111111
11111188888888118888888888888888888888888888888811111111111111
11111188888888118888888888888888888888888888888811111111111111
11111188888888118888888888888888888888888888888811111111111111
11111188888888118888888888888888888888888888888811111111111111
11111188888888118888888888888888888888888888888811111111111111
11111188888888118888888888888888888888888888888811111111111111
11111111888888118888888888888888888888888888888811111111111111
11111111888888118888888888888888888888888888888811111111111111
11111111111111118888888888888888888888888888888811111111111111
11111111111111111188888888111111111111888888881111111111111111
11111111111111111188888888111111111111888888881111111111111111
11111111111111111188888888111111111111888888881111111111111111
11111111111111111188888888111111111111888888881111111111111111
11111111111111111188888888111111111111888888881111111111111111
11111111111111111111111111111111111111111111111111111111111111
11111111111111111111111111111111111111111111111111111111111111
11111111111111111111111111111111111111111111111111111111111111
00000000000000000000000000000000000000000000000000000000000001
"""
# The bare template read as one integer (rows concatenated, newlines dropped).
number = int("".join(orign.splitlines()))

In [32]:
# number padding 1911
# number padding + newline 1989

# @numba.njit
def make_number(n, width=5, template=None):
    """Embed ``n`` in the ASCII-art template and return the resulting grid text.

    The decimal digits of ``n``, left-padded with zeros to ``width``, overwrite
    the template starting at its first ``"0"`` character. In ``orign`` that is
    the run of placeholder zeros on the last row. Newlines are kept, so the
    result still prints as the picture; strip them to obtain the candidate
    integer.

    Parameters
    ----------
    n : int
        Non-negative integer to embed.
    width : int, default 5
        Minimum number of digits written; numbers with more digits are
        written in full.
    template : str, optional
        Grid text to fill in. Defaults to the notebook-level ``orign``.

    Returns
    -------
    str
        ``template`` with the placeholder digits replaced.

    Raises
    ------
    ValueError
        If ``n`` is negative, if the template contains no ``"0"``, or if the
        digits do not fit inside the run of placeholder zeros (they would
        otherwise overwrite part of the picture).
    """
    if template is None:
        template = orign
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    digits = str(n).zfill(width)
    start = template.index("0")  # computed once; raises ValueError if absent
    tail = template[start:]
    slots = len(tail) - len(tail.lstrip("0"))  # length of the placeholder run
    if len(digits) > slots:
        raise ValueError(
            f"{len(digits)} digits do not fit in {slots} placeholder positions"
        )
    return template[:start] + digits + template[start + len(digits):]


In [36]:
# Render the template with 712 embedded (timed), then run a 10-round
# Miller-Rabin test on that candidate with its newlines stripped.
%time print(make_number(712))

print(miller_rabin(int(make_number(712).replace("\n","")), 10))

11111111111111111111111111111111111111111111111111111111111111
11111111111111111111111118888888888888888811111111111111111111
11111111111111111111118888888888888888888888111111111111111111
11111111111111111111888888881111111111111111111111111111111111
11111111111111111188888881111888888888888888881111111111111111
11111111111111111188888118888888888888888888888881111111111111
11111111111111118888881188888888888888888888888888881111111111
11111111111111118888881188888888888888888888888888881111111111
11111111111111118888881188888888888888888888888888888811111111
11111111111111118888888118888888888888888888888888888111111111
11111111118888118888888111111188888888888888888881111111111111
11111111118888118888888888811111118888888888111111111111111111
11111188888888118888888888888888111111111111188811111111111111
11111188888888118888888888888888888888888888888811111111111111
11111188888888118888888888888888888888888888888811111111111111
1111118888888811888888888888888888888888888888881111111

In [21]:


def check_numbers(n_list, k=1, processes=8):
    """Screen ASCII-art candidates in parallel with the Miller-Rabin test.

    Parameters
    ----------
    n_list : list of str
        Candidate grids such as those produced by ``make_number``. Newlines
        are stripped and the remaining digits are parsed as one integer.
    k : int, default 1
        Miller-Rabin rounds per candidate. One round is a cheap first filter
        that can let composites through; re-check the survivors with a larger
        ``k``.
    processes : int or None, default 8
        Number of worker processes for ``multiprocessing.Pool`` (``None`` uses
        the CPU count).

    Returns
    -------
    list of str
        The entries of ``n_list`` that passed, in their original order.

    NOTE(review): the workers must be able to unpickle ``miller_rabin``. That
    works with the "fork" start method; with "spawn", functions defined in a
    notebook cell typically cannot be found by the workers. Confirm this on
    the target platform.
    """
    candidates = [int(text.replace("\n", "")) for text in n_list]
    is_probable_prime = partial(miller_rabin, k=k)
    passed = []
    found = 0
    with Pool(processes) as pool:
        with tqdm(total=len(candidates), desc="Starting...") as progress:
            # imap yields results in input order, so passed[i] belongs to n_list[i].
            for ok in pool.imap(is_probable_prime, candidates):
                passed.append(ok)
                # Keep a running count: re-summing `passed` at every step made
                # the progress updates quadratic in len(n_list).
                found += ok
                progress.set_description(f"Found {found} primes")
                progress.update()
        pool.close()
        pool.join()
    return [text for text, ok in zip(n_list, passed) if ok]
        
                
# Normal tqdm
#     good = []
#     t = tqdm(n_list, desc='Starting', leave=True)
#     for n_candidate in t:
#         t.set_description(f"Found {len(good)} primes")
#         if miller_rabin(int(n_candidate.replace("\n","")), k):
#             good.append(n_candidate)
#     return good

In [37]:
# First pass: build the candidate for every counter 0..9999 and keep those
# that survive check_numbers' default single Miller-Rabin round.
g = check_numbers(list(map(make_number, range(10000))))
len(g)

Found 3 primes: 100%|██████████| 10000/10000 [23:00<00:00,  7.24it/s]


3

In [None]:
# Time building all 10000 candidate strings alone, without any primality test.
%time a = [make_number(x) for x in range(10000)]

In [38]:
# Print each candidate that survived the first pass. Every string already ends
# in "\n", so print() leaves a blank line between consecutive grids.
for x in g:
    print(x)

11111111111111111111111111111111111111111111111111111111111111
11111111111111111111111118888888888888888811111111111111111111
11111111111111111111118888888888888888888888111111111111111111
11111111111111111111888888881111111111111111111111111111111111
11111111111111111188888881111888888888888888881111111111111111
11111111111111111188888118888888888888888888888881111111111111
11111111111111118888881188888888888888888888888888881111111111
11111111111111118888881188888888888888888888888888881111111111
11111111111111118888881188888888888888888888888888888811111111
11111111111111118888888118888888888888888888888888888111111111
11111111118888118888888111111188888888888888888881111111111111
11111111118888118888888888811111118888888888111111111111111111
11111188888888118888888888888888111111111111188811111111111111
11111188888888118888888888888888888888888888888811111111111111
11111188888888118888888888888888888888888888888811111111111111
1111118888888811888888888888888888888888888888881111111

In [39]:
# Round 2: re-test the first-pass survivors with 30 Miller-Rabin rounds to
# filter out composites that slipped through the single-round screen.
g2 = check_numbers(g, k=30)
len(g2)

Found 3 primes: 100%|██████████| 3/3 [00:16<00:00,  5.38s/it]


3