In [23]:
import math

# Shared, growing cache of primes in ascending order, seeded with 2 and 3.
# get_next_prime() appends to it; get_next_prime_gap() and get_prime_after()
# read it to obtain trial divisors.
known_primes = [2, 3]


def get_next_prime_gap():
    """Return the distance from the last entry of ``known_primes`` to the next prime.

    Even distances 2, 4, 6, ... are tried in order.  A distance is rejected
    when one of the known primes not exceeding ``ceil(sqrt(last))`` divides
    ``last + distance``; the first distance that survives is returned.
    ``known_primes`` is only read here, never modified.
    """
    last = known_primes[-1]
    limit = math.ceil(math.sqrt(last))
    divisors = [q for q in known_primes if q <= limit]

    # next_hit[i] is an offset g such that divisors[i] divides last + g.  It
    # starts at the smallest such g >= 0 and is stepped forward by divisors[i]
    # lazily, only as far as the candidate currently being examined.
    next_hit = [-last % q for q in divisors]

    # Fast path: nothing divides last + 2.
    if 2 not in next_hit:
        return 2

    candidate = 2
    while True:
        candidate += 2

        # Some divisor's pending offset sits exactly on this candidate.
        if candidate in next_hit:
            continue

        # Slot 0 holds divisor 2, which can never reject an even offset from
        # an odd prime, so scanning starts at slot 1.  Divisors that are not
        # below the candidate still hold their initial offset and were already
        # covered by the membership test above.
        slot = 1
        while divisors[slot] < candidate:
            step = divisors[slot]
            while next_hit[slot] < candidate:
                next_hit[slot] += step
            if next_hit[slot] == candidate:
                # Rejected: leave the scan without reaching the else clause.
                break
            slot += 1
            if slot >= len(divisors):
                return candidate
        else:
            # Scan ran out of divisors below the candidate without a hit.
            return candidate


def get_next_prime():
    """Extend ``known_primes`` by one prime and return that new prime.

    The new value is the current last entry plus the gap reported by
    ``get_next_prime_gap()``.
    """
    successor = known_primes[-1] + get_next_prime_gap()
    known_primes.append(successor)
    return successor


def get_prime_after(n):
    """Return the first prime strictly greater than ``n``.

    If ``n`` is below the largest cached prime, the answer is read straight
    from ``known_primes``.  Otherwise the cache is first extended (via
    ``get_next_prime()``) until it holds a prime above ``ceil(sqrt(n))``, and
    then offsets from ``n`` are tested against those primes.

    Args:
        n: The number to search above.

    Returns:
        The smallest prime ``p`` with ``p > n``.

    Side effects:
        May append new primes to the shared ``known_primes`` list.
    """
    # Answer already cached: first known prime above n.
    if n < known_primes[-1]:
        for p in known_primes:
            if p > n:
                return p

    # Loop-invariant, so computed once rather than on every iteration.
    n_sqrt = math.ceil(math.sqrt(n))

    # Make sure every prime up to ceil(sqrt(n)) is available as a divisor.
    while known_primes[-1] <= n_sqrt:
        get_next_prime()

    prime_divisors = [d for d in known_primes if d <= n_sqrt]

    # NOTE(review): divisors are bounded by sqrt(n), not sqrt(n + gap).  This
    # relies on no composite between n and the returned value having all of
    # its prime factors above that bound - confirm this is acceptable for the
    # input range in use.

    # gap_buckets[i] is an offset g such that prime_divisors[i] divides n + g.
    # It starts at the smallest such g >= 0 and is stepped forward lazily.
    gap_buckets = [(-n % p) for p in prime_divisors]

    # Only offsets that make n + offset odd are worth testing: odd offsets
    # (1, 3, 5, ...) for even n, even offsets (2, 4, 6, ...) for odd n.
    test_gap = 1 if n % 2 == 0 else 2

    # Early exit - the very first candidate has no small prime factor.
    if test_gap not in gap_buckets:
        return n + test_gap

    while True:
        test_gap += 2

        # A divisor's pending offset sits exactly on this candidate.
        if test_gap in gap_buckets:
            continue

        test_gap_blocked = False
        # Index 0 is divisor 2, which cannot divide the odd candidates.
        for i in range(1, len(prime_divisors)):
            p = prime_divisors[i]
            # The bound is inclusive on purpose: when n is even the tested
            # offsets are odd and can equal p itself.  If p divides n its
            # bucket starts at 0, so it must be stepped once to land on p and
            # block it (e.g. n = 24, offset 3 -> 27 must be rejected).
            if p > test_gap:
                # Larger divisors still hold their initial offset, which the
                # membership test above has already compared against.
                break
            while gap_buckets[i] < test_gap:
                gap_buckets[i] += p
            if gap_buckets[i] == test_gap:
                test_gap_blocked = True
                break

        if not test_gap_blocked:
            return n + test_gap


# Demo: search start value.  Despite the name, this number is even and so is
# not itself prime; the call prints the first prime above it.
prime = 7213393222

print(get_prime_after(prime))

7213393223
