In [1]:
# This is my mersenne twister from the last challenge 

class MersenneTwister(object):
    # First start with constants
    w, n, m, r = 32, 624, 397, 31
    a = 0x9908B0DF
    u, d = 11, 0xFFFFFFFF
    s, b = 7, 0x9D2C5680
    t, c = 15, 0xEFC60000
    l = 18
    f = 1812433253
    
    mt = [0] * n

    index = n+1
    lower_mask: int = (1 << r) - 1
    upper_mask: int = (~lower_mask) & sum( 0b1 << shift for shift in range(w) )
    
    def __init__(self, seed: int):
        mt = self.mt # just so I don't have to write self a bunch of times

        self.index = self.n
        mt[0] = seed
        for i in range(1, self.n):
            mt[i] = sum( 0b1 << shift for shift in range(self.w) ) & ( self.f * ( mt[i-1] ^ ( mt[i-1] >> (self.w-2) )) + 1 )
            
    def __twist__(self):
        mt = self.mt
        for i in range(self.n):
            x = ( mt[i] & self.upper_mask ) | ( mt[ (i+1) % self.n ] & self.lower_mask )
            xA = x >> 1
            if x % 2 != 0:
                xA ^= self.a
            mt[i] = mt[ (i+self.m) % self.n ] ^ xA
        self.index = 0
        
    def __call__(self):
        mt = self.mt
        
        if self.index >= self.n:
            if self.index > self.n:
                assert False, "Generator was never seeded"
            else:
                self.__twist__()
        
        y  =  mt[self.index]
        y ^= (y >> self.u) & self.d
        y ^= (y << self.s) & self.b
        y ^= (y << self.t) & self.c
        y ^= (y >> self.l)
    
        self.index += 1
    
        return y & sum( 0b1 << shift for shift in range(self.w) )
        

In [2]:
# Now break out the "tempering" step for simplicity

def Temper(y: int):
    u, d = 11, 0xFFFFFFFF
    s, b = 7, 0x9D2C5680
    t, c = 15, 0xEFC60000
    l = 18
    f = 1812433253
    
    y ^= (y >> u) & d
    y ^= (y << s) & b
    y ^= (y << t) & c
    y ^= (y >> l)
    
    return y

In [3]:
# Now I need to make an inverse for this function... 
# I am first going to make a very simplified version of this function with only one step

BITSHIFT = 7
MAXBITS = 32
MASK = 0x9D2C5680

def SimpleTemper(y: int, maxbits = MAXBITS, bitshift = BITSHIFT, mask = MASK):
    assert y < ( 0b1<<maxbits)
    assert bitshift < maxbits
    if bitshift > 0:
        y ^= ( y >> bitshift ) & mask
    elif bitshift < 0:
        y ^= ( y << -bitshift ) & mask
    else:
        assert False
    return y

In [4]:
# Now I'll attempt to invert this function

def InvertSimpleTemper(x: int, maxbits = MAXBITS, bitshift = BITSHIFT, mask = MASK):
    # x = y ^ ( y >> 1 )
    # x^x = x ^ y ^ ( y >> 1 )
    # 0 = x ^ y ^ ( y >> 1 )
    
    # solve the highest bits first, since anything higher must be a zero
    
    result = [None] * maxbits
    assert maxbits > bitshift
    
    if bitshift > 0:
        for n in reversed(range(maxbits)) :
            if n+bitshift >= maxbits:
                # I'm in the top section of bits so anything shifted in is assumed to be zero
                x_bit = x & ( 0b1 << n ) # corresponding bit in x
                mask_bit = mask & ( 0b1 << n ) # corresponding bit in the mask

                if x_bit:
                    # the xbit is set and the higher y bit must be zero, so to produce 0 this bit must be set
                    result[n] = 1
                else:
                    # the xbit is clear and the higher y bit must be zero, so to produce 0 this bit must be clear
                    result[n] = 0
            else:
                # There should be a higher y bit to check
                y_high = result[n+bitshift]
                assert y_high is not None
                x_bit = x & ( 0b1 << n ) # corresponding bit in x
                mask_bit = mask & ( 0b1 << n ) # corresponding bit in the mask


                # For each of these sets of bits, an even number should be set to produce 0 when XORed together
                if x_bit and ( y_high and mask_bit ):
                    result[n] = 0
                elif x_bit and not( y_high and mask_bit ):
                    result[n] = 1
                elif not(x_bit) and ( y_high and mask_bit ):
                    result[n] = 1
                elif not(x_bit) and not( y_high and mask_bit ):
                    result[n] = 0
                else:
                    # THIS SHOULD NEVER HAPPEN
                    assert False
                    
    elif bitshift < 0:
        bitshift *= -1
        for n in range(maxbits):
            if n < bitshift:
                # I'm in the bottom section of bits so anything shifted in must be 0
                x_bit = x & ( 0b1 << n ) # corresponding bit in x
                mask_bit = mask & ( 0b1 << n ) # corresponding bit in the mask

                if x_bit:
                    # the xbit is set and the lower y bit must be zero, so to produce 0 this bit must be set
                    result[n] = 1 
                else:
                    # the xbit is clear and the lower y bit must be zero, so to produce 0 this bit must be clear
                    result[n] = 0
            else:                
                # There should be a higher y bit to check
                y_low = result[n-bitshift]
                assert y_low is not None
                x_bit = x & ( 0b1 << n ) # corresponding bit in x
                mask_bit = mask & ( 0b1 << n ) # corresponding bit in the mask

                # For each of these sets of bits, an even number should be set to produce 0 when XORed together
                if x_bit and ( y_low and mask_bit ):
                    result[n] = 0
                elif x_bit and not( y_low and mask_bit ):
                    result[n] = 1
                elif not(x_bit) and ( y_low and mask_bit ):
                    result[n] = 1
                elif not(x_bit) and not( y_low and mask_bit ):
                    result[n] = 0
                else:
                    # THIS SHOULD NEVER HAPPEN
                    assert False
                           
    else:
        assert False
        
    assert {*result} <= {0, 1}
    return sum( value << n for n, value in enumerate(result) )

In [5]:
u, d = 11, 0xFFFFFFFF
s, b = 7, 0x9D2C5680
t, c = 15, 0xEFC60000
l = 18

x = 0xFFFFFFFF

y = SimpleTemper(x, bitshift = t, mask = c)
print(y)

4294967295


In [6]:
InvertSimpleTemper(y, bitshift = t, mask = c)

4294967295

In [7]:
# Make another version of the tempering function with parity to the simple tempers
# Positive shift indicate left shift, and negative right shift

def Temper2(y: int):
    u, d = 11, 0xFFFFFFFF
    s, b = 7, 0x9D2C5680
    t, c = 15, 0xEFC60000
    l = 18
    
    y = SimpleTemper(y, bitshift = u,  mask = d)
    y = SimpleTemper(y, bitshift = -s, mask = b)
    y = SimpleTemper(y, bitshift = -t, mask = c)
    y = SimpleTemper(y, bitshift = l,  mask = ~0)
    
    return y

In [8]:
# Check that Temper2 is equivalent to the orginal Temper

x = 543543543
assert Temper2(x) == Temper(x)

In [9]:
# Now build an invert templer function by chaining together invert simple tempers

def InvertTemper(y):
    u, d = 11, 0xFFFFFFFF
    s, b =  7, 0x9D2C5680
    t, c = 15, 0xEFC60000
    l = 18
    
    y = InvertSimpleTemper(y, bitshift = l, mask = ~0)
    y = InvertSimpleTemper(y, bitshift = -t, mask = c)
    y = InvertSimpleTemper(y, bitshift = -s, mask = b)
    y = InvertSimpleTemper(y, bitshift = u, mask = d)
    
    return y

In [10]:
x = 543543543
y = SimpleTemper(x, bitshift = -s, mask = b)
xprime = InvertSimpleTemper(y, bitshift = -s, mask = b)

assert x == xprime

In [11]:
# Great success! We have now successfully inverted the temper function
# Now I'm going to re-write the Mersenne twister using my Temper2 function

class MersenneTwister(object):
    # First start with constants
    w, n, m, r = 32, 624, 397, 31
    a = 0x9908B0DF
    u, d = 11, 0xFFFFFFFF
    s, b = 7, 0x9D2C5680
    t, c = 15, 0xEFC60000
    l = 18
    f = 1812433253
    
    mt = [0] * n

    index = n+1
    lower_mask: int = (1 << r) - 1
    upper_mask: int = (~lower_mask) & sum( 0b1 << shift for shift in range(w) )
    
    def __init__(self, seed: int):
        mt = self.mt # just so I don't have to write self a bunch of times

        self.index = self.n
        mt[0] = seed
        for i in range(1, self.n):
            mt[i] = sum( 0b1 << shift for shift in range(self.w) ) & ( self.f * ( mt[i-1] ^ ( mt[i-1] >> (self.w-2) )) + 1 )
            
    def __twist__(self):
        mt = self.mt
        for i in range(self.n):
            x = ( mt[i] & self.upper_mask ) | ( mt[ (i+1) % self.n ] & self.lower_mask )
            xA = x >> 1
            if x % 2 != 0:
                xA ^= self.a
            mt[i] = mt[ (i+self.m) % self.n ] ^ xA
        self.index = 0
        
    def __call__(self):
        mt = self.mt
        
        if self.index >= self.n:
            if self.index > self.n:
                assert False, "Generator was never seeded"
            else:
                self.__twist__()
        
        y  =  mt[self.index]
        y  = Temper2(y)
        self.index += 1
    
        return y & sum( 0b1 << shift for shift in range(self.w) )

In [12]:
# And test it out

twister1 = MersenneTwister(seed = 123)

output1 = [ twister1() for n in range(624) ] # The state array mt should be this long
print(output1[:100])

[3225431814, 882238445, 2871375431, 2627082018, 1557640367, 330380317, 2228864956, 951923467, 3729394512, 3228856173, 3993000669, 1724365817, 1938047254, 1748532678, 3141670923, 2550503813, 4204086341, 978465857, 3630164559, 1309635741, 4163584072, 1899032274, 2527644726, 1483861524, 2892090636, 4157375774, 1016727817, 2039534179, 2817183000, 3152003499, 3240155330, 1944951299, 2472361193, 1527317263, 3503716578, 1543476778, 2071533355, 835057724, 986310944, 3504289539, 184520592, 4061051181, 735914624, 2491012502, 2983064219, 2700248934, 1402045086, 437341584, 3452060919, 62177866, 96636628, 2378339206, 3320370570, 3342921357, 1634978469, 2910179097, 1487528203, 86453469, 4220604508, 1051998411, 3124295924, 234555258, 1843034486, 585778021, 1794193701, 547624969, 1109911856, 1483215375, 277504901, 2475625659, 1181541282, 545617532, 3071839701, 3298776087, 3455401253, 2705283757, 1140108236, 2567099494, 3553586555, 115454251, 3269658412, 2181780318, 555926085, 142771657, 1864636022, 24

In [13]:
twister2 = MersenneTwister(seed = 0)
twister2() # need to call it once to initialize stuff
twister2.mt[:] = [ InvertTemper(value) for value in output1 ]

assert twister2.mt == twister1.mt

In [14]:
# twister1 has been called 624 times so we need to update twister2 accordingly

twister2.index = 624

assert twister1.index == twister2.index

In [15]:
twister1.index

624

In [16]:
twister2.index

624

In [17]:
twister1()

3510634545

In [18]:
twister2()

3009426959

In [19]:
twister1()

2424507078

In [20]:
twister2()

2424507078

In [21]:
assert twister1.index == twister2.index
assert twister1.mt == twister2.mt

In [22]:
for n in range(100):
    if twister1() != twister2():
        print(f'Error on {n}')
        break

In [23]:
# Looks like twister1 has succesfully copied the state of twister2