In [4]:

def is_power_of_two(n : int) -> bool:
    """Return True exactly when ``n`` is a positive power of two (1, 2, 4, 8, ...).

    Zero and negative integers are never powers of two.
    """
    # Clearing the lowest set bit (n & (n - 1)) leaves zero only when n had a
    # single set bit; n == 0 must be excluded explicitly because 0 & -1 == 0.
    return n != 0 and not (n & (n - 1))

def is_fermat_power(n : int) -> bool:
    """Return True if ``n`` is a Fermat 2-power, i.e. ``n == 2 ** (2 ** k)`` for some k >= 0.

    The sequence starts 2, 4, 16, 256, 65536, 2**32, ...  Zero, one, negative
    numbers and non-powers of two all return False.
    """
    # Step 1: n itself must be a positive power of two, n == 2 ** exponent.
    if n <= 0 or (n & (n - 1)) != 0:
        return False
    exponent = n.bit_length() - 1
    # Step 2: the exponent must in turn be a positive power of two
    # (exponent == 0, i.e. n == 1, is excluded).
    return exponent > 0 and (exponent & (exponent - 1)) == 0



In [16]:
def nim_sum(n : int,m : int) -> int:
    """Return the nim sum of ``n`` and ``m``.

    Nim addition is binary addition without carries, i.e. bitwise XOR.
    """
    combined = m ^ n  # XOR is commutative, so operand order is irrelevant
    return combined

def nim_prod(x : int,y : int) -> int:
    """Return the nim product of two non-negative integers.

    The product is computed recursively from these rules:

    * 0 annihilates and 1 is the identity.
    * For a Fermat 2-power ``F = 2 ** (2 ** k)``: if ``a < F`` then the nim
      product of ``a`` and ``F`` is the ordinary product ``a * F``, and the
      nim square of ``F`` is ``3 * F / 2``.
    * Any other power of two ``M`` is written as ``F * 2 ** exponent`` where
      ``F`` is the largest Fermat power below ``M``; the product is then
      regrouped using associativity.
    * Any other integer is split into its set bits and the product is
      distributed over nim addition (XOR).

    Args:
        x: first factor, must be >= 0.
        y: second factor, must be >= 0.

    Returns:
        The nim product of ``x`` and ``y``.

    Raises:
        ValueError: if ``x`` or ``y`` is negative (nim multiplication is only
            defined on non-negative integers).
    """
    if x < 0 or y < 0:
        raise ValueError(f'nim_prod requires non-negative integers, got {x} and {y}')

    # Trivial cases. Note: the tempting `x == 0 | y == 0` parses as the chained
    # comparison `x == (0 | y) == 0` because `|` binds tighter than `==`.
    if x == 0 or y == 0:
        return 0
    if x == 1:
        return y
    if y == 1:
        return x

    # Nim multiplication is commutative, so work with the ordered pair m <= M.
    m = min(x, y)
    M = max(x, y)

    if is_fermat_power(M):
        if m < M:
            # Nim product of a Fermat power with a smaller number is the ordinary product.
            return m * M
        # Nim square of a Fermat power F is 3F/2.
        return 3 * M >> 1

    if is_power_of_two(M):
        # M is a power of two but not a Fermat power: pull out factors of 2
        # until the remaining factor is a Fermat power, so that
        # M = factored * (2 ** exponent). At least one 2 must be pulled out.
        exponent = 1
        factored = M >> 1
        while not is_fermat_power(factored):
            factored >>= 1
            exponent += 1
        # By associativity:
        #   m * M = (m * factored) * (2 ** exponent) = (m * (2 ** exponent)) * factored
        # If the first grouping reproduces M, recursing on it could revisit the
        # same call forever, so fall back to the second grouping.
        intermediate = nim_prod(m, factored)
        if intermediate == M:
            intermediate = nim_prod(m, 1 << exponent)
            return nim_prod(intermediate, factored)
        return nim_prod(intermediate, 1 << exponent)

    # Otherwise write M as a sum of distinct powers of two and distribute.
    total = 0
    for index in range(M.bit_length()):
        if (M >> index) & 1:
            total ^= nim_prod(m, 1 << index)
    return total
    
# Spot check: 8 is a power of two but not a Fermat power, so this exercises
# the factor-out-2s branch of nim_prod.
nim_prod(2,8)

12

In [17]:
import numpy as np
import sympy as sp

# Nim multiplication table for 0..15: entry [row, col] holds nim_prod(row, col).
table_size = 16
M = np.array(
    [[nim_prod(row, col) for col in range(table_size)] for row in range(table_size)],
    dtype=int,
)

sp.Matrix(M)

Matrix([
[0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
[0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
[0,  2,  3,  1,  8, 10, 11,  9, 12, 14, 15, 13,  4,  6,  7,  5],
[0,  3,  1,  2, 12, 15, 13, 14,  4,  7,  5,  6,  8, 11,  9, 10],
[0,  4,  8, 12,  6,  2, 14, 10, 11, 15,  3,  7, 13,  9,  5,  1],
[0,  5, 10, 15,  2,  7,  8, 13,  3,  6,  9, 12,  1,  4, 11, 14],
[0,  6, 11, 13, 14,  8,  5,  3,  7,  1, 12, 10,  9, 15,  2,  4],
[0,  7,  9, 14, 10, 13,  3,  4, 15,  8,  6,  1,  5,  2, 12, 11],
[0,  8, 12,  4, 11,  3,  7, 15, 13,  5,  1,  9,  6, 14, 10,  2],
[0,  9, 14,  7, 15,  6,  1,  8,  5, 12, 11,  2, 10,  3,  4, 13],
[0, 10, 15,  5,  3,  9, 12,  6,  1, 11, 14,  4,  2,  8, 13,  7],
[0, 11, 13,  6,  7, 12, 10,  1,  9,  2,  4, 15, 14,  5,  3,  8],
[0, 12,  4,  8, 13,  1,  9,  5,  6, 10,  2, 14, 11,  7, 15,  3],
[0, 13,  6, 11,  9,  4, 15,  2, 14,  3,  8,  5,  7, 10,  1, 12],
[0, 14,  7,  9,  5, 11,  2, 12, 10,  4, 13,  3, 15,  1,  8,  6],
[0, 15,  5, 10, 

In [22]:
def test_func(N : int) -> None:
    for n in range(N.bit_length()):
        print(f'The {n}th bit of {N} is {N >> n & 1}')

# Print the bits of 999 one by one, then cross-check against Python's bin().
test_func(999)
# The `=` specifier in the f-string echoes the expression text before its repr.
print(f'{bin(999) =}')

The 0th bit of 999 is 1
The 1th bit of 999 is 1
The 2th bit of 999 is 1
The 3th bit of 999 is 0
The 4th bit of 999 is 0
The 5th bit of 999 is 1
The 6th bit of 999 is 1
The 7th bit of 999 is 1
The 8th bit of 999 is 1
The 9th bit of 999 is 1
bin(999) ='0b1111100111'
