# GF(2)-linear RNG Jump Ahead

## Intro
This document explains how the math behind jump ahead for GF(2)-linear PRNGs works, with example code implementing it for Xoroshiro128+. The same approach carries over to other GF(2)-linear PRNGs with little change.

Much of the mathematical explanation restates what is described in “Efficient Jump Ahead for F2-Linear Random Number Generators”, in a way that I feel is easier to understand.

## Matrix Representation of PRNG

The sequence of the states of GF(2)-linear PRNGs can be represented as 

**x**ₙ = **Ax**ₙ₋₁

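where **x**ₙ is the *k*-bit state vector (over GF(2)) of the RNG at step *n* and **A** is the *k*×*k* transition matrix whose elements are in GF(2).

The code below stores the state as a row vector and advances it as **x** @ **A**. Bit *i* of the vector is bit *i* of s[0] for *i* < 64 and bit *i* - 64 of s[1] otherwise. The matrix built in code is therefore the transpose of **A** as written above. This changes nothing that follows, since a matrix and its transpose have the same characteristic polynomial.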

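```python
import numpy as np

# All matrices here act on row vectors: a bit vector x is transformed as (x @ M) % 2.
# Bit i of a 64-bit word is its i-th least significant bit.

def mat_rotl(n, size = 64):
    """Left rotation by n bits: bit i moves to bit (i + n) % size."""
    return np.roll(np.identity(size, dtype = np.uint8), -n, axis = 0)

def mat_shift(n, size = 64):
    """Left shift by n bits: bit i moves to bit i + n; bits shifted past the top are dropped."""
    return np.eye(size, k = n, dtype = np.uint8)

k = 128 # size of state
# random starting bits of rng state (bits 0-63 are s[0], bits 64-127 are s[1]);
# the state differs on every run, so the state-dependent hex outputs below will too
state = np.random.randint(0, 2, k, np.uint8)

"""
https://xoshiro.di.unimi.it/xoroshiro128plus.c
uint64_t next(void) {
	const uint64_t s0 = s[0];
	uint64_t s1 = s[1];
	const uint64_t result = s0 + s1;

	s1 ^= s0;
	s[0] = rotl(s0, 24) ^ s1 ^ (s1 << 16); // a, b
	s[1] = rotl(s1, 37); // c

	return result;
}
"""

# s0_mat and s1_mat are 128x64: column j expresses bit j of the C variable s0 (or s1)
# as a GF(2) sum of the 128 input state bits
s0_mat = np.zeros((128, 64), np.uint8)
s1_mat = np.zeros((128, 64), np.uint8)

# const uint64_t s0 = s[0];
s0_mat[0:64] = np.identity(64, np.uint8)
# uint64_t s1 = s[1];
s1_mat[64:128] = np.identity(64, np.uint8)

# s1 ^= s0;
s1_mat ^= s0_mat

# s[0] = rotl(s0, 24) ...
s0_mat = (s0_mat @ mat_rotl(24)) % 2
# ... ^ s1 ...
s0_mat ^= s1_mat
# ... ^ (s1 << 16);
s0_mat ^= (s1_mat @ mat_shift(16)) % 2

# s[1] = rotl(s1, 37);
s1_mat = (s1_mat @ mat_rotl(37)) % 2

# Xoroshiro128+ transformation matrix (row-vector form)
A = np.hstack((s0_mat, s1_mat))

def advance_rng(x_n_minus_1):
    """Advance a state bit vector by one step."""
    x_n = (x_n_minus_1 @ A) % 2
    return x_n
```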

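Because every operation in next() (XOR, shift, rotation) is linear over GF(2), the matrix can also be built mechanically instead of operation by operation: push each of the 128 unit vectors through an integer implementation of the step and use the results as rows. The sketch below assumes the same row-vector convention and bit ordering as the cell above; the helper names (`step`, `to_bits`, `from_bits`, `A_rows`, `SAMPLE`) are hypothetical and introduced only for this illustration. If both constructions are correct, `np.array_equal(A_rows, A)` should be `True`.

```python
import numpy as np

MASK64 = (1 << 64) - 1

def rotl(x, r):
    """Rotate a 64-bit integer left by r bits."""
    return ((x << r) | (x >> (64 - r))) & MASK64

def step(s0, s1):
    """One Xoroshiro128+ state transition on two 64-bit ints (output value not computed)."""
    s1 ^= s0
    return rotl(s0, 24) ^ s1 ^ ((s1 << 16) & MASK64), rotl(s1, 37)

def to_bits(s0, s1):
    """128-entry 0/1 vector: entry i is bit i of s0 for i < 64, bit i - 64 of s1 otherwise."""
    v = s0 | (s1 << 64)
    return np.array([(v >> i) & 1 for i in range(128)], dtype=np.uint8)

def from_bits(bits):
    """Inverse of to_bits: returns (s0, s1)."""
    v = sum(int(b) << i for i, b in enumerate(bits))
    return v & MASK64, v >> 64

# Row i is the image of unit vector e_i, so by linearity (x @ A_rows) % 2 == step(x)
A_rows = np.zeros((128, 128), dtype=np.uint8)
for i in range(128):
    e = 1 << i
    A_rows[i] = to_bits(*step(e & MASK64, e >> 64))

def advance_with_matrix(s0, s1):
    """Advance one step via the matrix; int64 avoids any overflow in the row sums."""
    x = to_bits(s0, s1).astype(np.int64)
    return from_bits((x @ A_rows.astype(np.int64)) % 2)

SAMPLE = (0x0123456789abcdef, 0xfedcba9876543210)  # arbitrary example state
```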
## Matrix Representation of PRNG jump

In order to jump the PRNG, one must be able to compute:

**x**ₙ₊ᵥ = **Jx**ₙ

where *v* is the number of steps to jump and

**J** = **A**ᵛ

is the matrix that jumps the PRNG *v* steps.

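**J** can be computed via exponentiation by squaring, which needs O(log *v*) products of *k*×*k* matrices over GF(2). Each product costs O(*k*³) bit operations with naive multiplication, and **J** itself takes *k*² bits to store: only 16,384 bits for Xoroshiro128+, but roughly 400 million bits (about 50 MB) for a generator with *k* = 19937 such as the Mersenne Twister.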

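For comparison with the polynomial method developed below, here is a minimal sketch of the direct approach: GF(2) matrix exponentiation by squaring with NumPy. The function names are hypothetical, and the demonstration uses a small random 8×8 matrix rather than the Xoroshiro128+ matrix so that it runs instantly.

```python
import numpy as np

def matmul_gf2(X, Y):
    """Product of two 0/1 matrices over GF(2); int64 keeps the row sums from overflowing."""
    return (X.astype(np.int64) @ Y.astype(np.int64)) % 2

def matpow_gf2(M, v):
    """M**v over GF(2) by exponentiation by squaring: O(log v) k x k matrix products."""
    result = np.identity(M.shape[0], dtype=np.int64)
    base = M.astype(np.int64) % 2
    while v > 0:
        if v & 1:
            result = matmul_gf2(result, base)
        base = matmul_gf2(base, base)
        v >>= 1
    return result

# Small demonstration matrix (arbitrary, seeded for reproducibility)
M = np.random.default_rng(0).integers(0, 2, size=(8, 8))
J = matpow_gf2(M, 1000)
```

For Xoroshiro128+ this is still affordable, but every one of those products is a full *k*×*k* matrix product, which is what makes the approach impractical for large *k*.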
## Characteristic Polynomial

The characteristic polynomial of the matrix **A** is defined as:

p(z) = det(z**I** + **A**) = zᵏ + α₁zᵏ⁻¹ + ... + αₖ₋₁z + αₖ

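where **I** is the identity matrix and each coefficient αⱼ is in GF(2) and thus is either 0 or 1. Over GF(2), subtraction is the same as addition, so this is the usual det(z**I** - **A**).

The coefficients of this characteristic polynomial only need to be computed once for a given matrix **A**, which can be done with sympy. sympy computes the characteristic polynomial over the integers; because a determinant is a polynomial in the matrix entries with integer coefficients, reducing each coefficient modulo 2 gives the characteristic polynomial over GF(2).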

```python
from functools import reduce

from sympy import Matrix

# charpoly() works over the integers; reducing each coefficient mod 2 gives the GF(2) polynomial.
# all_coeffs() lists the coefficients from the highest degree (z**k) down to z**0.
coeffs = tuple(int(x) & 1 for x in Matrix(A).charpoly().all_coeffs())
degree = len(coeffs) - 1
representation = " + ".join(f"z**{degree - i}" for i, coeff in enumerate(coeffs) if coeff == 1)
print(representation)

# Pack into one integer: bit i holds the coefficient of z**i
char_poly = reduce(lambda p, q: (p << 1) | (q & 1), coeffs)
print(hex(char_poly))
```

```
z**128 + z**115 + z**111 + z**105 + z**103 + z**99 + z**98 + z**97 + z**94 + z**92 + z**88 + z**85 + z**84 + z**83 + z**81 + z**80 + z**78 + z**73 + z**72 + z**71 + z**70 + z**68 + z**66 + z**64 + z**59 + z**56 + z**54 + z**52 + z**51 + z**49 + z**48 + z**47 + z**43 + z**42 + z**41 + z**40 + z**38 + z**37 + z**36 + z**34 + z**33 + z**30 + z**28 + z**26 + z**25 + z**24 + z**23 + z**20 + z**19 + z**17 + z**15 + z**13 + z**0
0x10008828e513b43d5095b8f76579aa001
```


The coefficients of the characteristic polynomial can be condensed into a single integer in which bit *i* holds the coefficient of zⁱ (so char_poly has bit 128 set for z¹²⁸ and bit 0 set for the constant term). All computations on polynomials below are done in this form.

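A minimal sketch of this packed representation, with hypothetical helper names: `pack_coeffs` takes coefficients highest degree first (the order sympy's `all_coeffs()` uses), and `poly_to_str` renders a packed polynomial back into the notation printed above.

```python
def pack_coeffs(coeffs_high_first):
    """Pack GF(2) coefficients (highest degree first) into an int whose bit i is the coefficient of z**i."""
    poly = 0
    for c in coeffs_high_first:
        poly = (poly << 1) | (int(c) & 1)
    return poly

def poly_to_str(poly):
    """Render a packed GF(2) polynomial, highest degree first."""
    terms = [f"z**{i}" for i in range(poly.bit_length() - 1, -1, -1) if (poly >> i) & 1]
    return " + ".join(terms) if terms else "0"

example = pack_coeffs([1, 0, 1, 1])  # z**3 + z + 1
print(bin(example), poly_to_str(example))
```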
## Jump Polynomial

From the Cayley–Hamilton theorem, it is known that:

p(**A**) = **A**ᵏ + α₁**A**ᵏ⁻¹ + ... + αₖ₋₁**A** + αₖ**I** = 0

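Cayley–Hamilton can be checked numerically without building any matrix: p(**A**)**x** is the XOR of **A**ⁱ**x** over the set bits *i* of p, and each **A**ⁱ**x** is just the state after *i* steps. This sketch assumes the packed characteristic polynomial printed above (`CHAR_POLY`); `step` and `poly_times_state` are hypothetical helper names. For any starting state the result should be the all-zero state.

```python
MASK64 = (1 << 64) - 1
# Characteristic polynomial of the Xoroshiro128+ matrix from above (bit i = coefficient of z**i)
CHAR_POLY = 0x10008828e513b43d5095b8f76579aa001

def rotl(x, r):
    return ((x << r) | (x >> (64 - r))) & MASK64

def step(s0, s1):
    """One Xoroshiro128+ transition, i.e. one multiplication of the state by A."""
    s1 ^= s0
    return rotl(s0, 24) ^ s1 ^ ((s1 << 16) & MASK64), rotl(s1, 37)

def poly_times_state(poly, s0, s1):
    """poly(A) x: XOR together A**i x for every set bit i of poly."""
    acc0 = acc1 = 0
    while poly:
        if poly & 1:
            acc0 ^= s0
            acc1 ^= s1
        s0, s1 = step(s0, s1)
        poly >>= 1
    return acc0, acc1

SAMPLE = (0x0123456789abcdef, 0xfedcba9876543210)  # arbitrary example state
print(poly_times_state(CHAR_POLY, *SAMPLE))  # p(A) x for the example state
```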
Define the polynomial g(z) as the remainder of zᵛ divided by p(z):

g(z) = zᵛ % p(z) = a₁zᵏ⁻¹ + ... + aₖ₋₁z + aₖ

By the definition of polynomial division, there is a quotient polynomial q(z) such that

zᵛ = q(z)p(z) + g(z)

which rearranges to

g(z) = zᵛ - q(z)p(z)

Substituting z = **A** into this identity is valid because all powers of **A** commute with each other, and since p(**A**) = 0:

g(**A**) = **A**ᵛ - q(**A**)p(**A**) = **A**ᵛ - q(**A**) · 0 = **A**ᵛ

and therefore

**J** = **A**ᵛ = g(**A**)

This polynomial g(z) will be referred to as the "jump polynomial", since evaluating it at **A** is equivalent to multiplying by **J**, i.e. jumping *v* steps. Because g(z) has degree less than *k*, it has at most *k* coefficients no matter how large *v* is.

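For a small *v* the identity **A**ᵛ**x** = g(**A**)**x** can be verified by brute force. The sketch below assumes the packed `CHAR_POLY` from above and computes g(z) = zᵛ % p(z) the slow way, multiplying by z (a left shift) *v* times and cancelling the top bit whenever the degree reaches 128; all function names are hypothetical.

```python
MASK64 = (1 << 64) - 1
CHAR_POLY = 0x10008828e513b43d5095b8f76579aa001  # from above, bit i = coefficient of z**i

def rotl(x, r):
    return ((x << r) | (x >> (64 - r))) & MASK64

def step(s0, s1):
    """One Xoroshiro128+ transition (multiplication by A)."""
    s1 ^= s0
    return rotl(s0, 24) ^ s1 ^ ((s1 << 16) & MASK64), rotl(s1, 37)

def poly_times_state(poly, s0, s1):
    """poly(A) x: XOR of A**i x over the set bits i of poly."""
    acc0 = acc1 = 0
    while poly:
        if poly & 1:
            acc0 ^= s0
            acc1 ^= s1
        s0, s1 = step(s0, s1)
        poly >>= 1
    return acc0, acc1

def z_pow_mod_slow(v, modulus):
    """z**v mod modulus: multiply by z v times, reducing whenever the degree reaches deg(modulus)."""
    deg = modulus.bit_length() - 1
    g = 1
    for _ in range(v):
        g <<= 1
        if (g >> deg) & 1:
            g ^= modulus
    return g

def advance(s0, s1, v):
    """Brute force: apply A v times."""
    for _ in range(v):
        s0, s1 = step(s0, s1)
    return s0, s1

SAMPLE = (0x0123456789abcdef, 0xfedcba9876543210)  # arbitrary example state
g = z_pow_mod_slow(1000, CHAR_POLY)  # jump polynomial for v = 1000
```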
## Computing Jump Polynomial

**x**ₙ₊ᵥ can therefore be computed by applying the jump polynomial g(z) = zᵛ % p(z) = a₁zᵏ⁻¹ + ... + aₖ₋₁z + aₖ.

To find this jump polynomial we need to compute zᵛ % p(z), which can be done with polynomial arithmetic over GF(2).

When GF(2) polynomials are represented as integers whose bits are their coefficients:

- addition is a bitwise XOR, and subtraction is identical to addition;
- multiplication is carry-less: the multiplicand is shifted left once per bit of the multiplier and XORed into the result wherever that bit is 1;
- division (and therefore modulo) repeatedly XORs a left-shifted copy of the divisor into the dividend to cancel its highest set bit;
- exponentiation is repeated multiplication, optimized via exponentiation by squaring. Since only zᵛ % p(z) is needed, every intermediate product can be reduced modulo p(z), so operands never grow beyond *k* bits.

```python
def mssb_position(polynomial):
    """Position of the most significant set bit (the degree of the polynomial), or -1 for 0."""
    result = -1
    while polynomial != 0:
        polynomial >>= 1
        result += 1
    return result

def mod_gf2(polynomial, modulus):
    """Remainder of polynomial divided by modulus over GF(2)."""
    last_bit_pos = mssb_position(modulus)
    # if the degree of polynomial is already lower than that of modulus, it is the remainder
    if polynomial >> last_bit_pos == 0:
        return polynomial
    poly_mssb = mssb_position(polynomial >> last_bit_pos) + last_bit_pos
    shift_num = poly_mssb - last_bit_pos
    # line up the mssb of modulus with the mssb of polynomial
    modulus <<= shift_num
    # walk the shifted modulus back down until it is at its original position
    for shift_pos in range(shift_num + 1):
        # divides exactly before the last xor
        if polynomial == 0:
            return 0
        # if the bit under the top of the shifted modulus is set: polynomial ^= shifted modulus
        if polynomial >> (poly_mssb - shift_pos) == 1:
            polynomial ^= modulus
        # check the next lower position
        modulus >>= 1
    # the remainder is left in polynomial
    return polynomial

def multmod_gf2(multiplicand, multiplier, modulus):
    """Carry-less product of multiplicand and multiplier, reduced modulo modulus."""
    result = 0
    # if either are 0, there is nothing left to do
    while multiplier != 0:
        # multiply 1 bit at a time
        result ^= multiplicand * (multiplier & 1)
        multiplicand <<= 1
        multiplier >>= 1
    return mod_gf2(result, modulus)

def base_z_modpow_gf2(power, modulus):
    """z**power mod modulus."""
    base = 0b10 # z ** 1
    result = 1
    # exponentiation by squaring
    while power > 0:
        if power & 1:
            result = multmod_gf2(result, base, modulus)

        power >>= 1
        base = multmod_gf2(base, base, modulus)
    return result

v = 2 ** 64 # arbitrary jump number

jump_poly = base_z_modpow_gf2(v, char_poly)

print(hex(jump_poly))
```

```
0x170865df4b3201fcdf900294d8f554a5
```


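The same computation can be written more compactly with Python's `int.bit_length()` in place of a hand-written most-significant-bit search. This is a sketch under the same conventions (bit *i* = coefficient of zⁱ); `poly_mod`, `poly_mulmod` and `z_powmod` are hypothetical names, not the functions above.

```python
# Characteristic polynomial of the Xoroshiro128+ matrix, from the sympy cell above
CHAR_POLY = 0x10008828e513b43d5095b8f76579aa001

def poly_mod(poly, modulus):
    """Remainder of poly divided by modulus over GF(2)."""
    deg = modulus.bit_length() - 1
    # cancel the leading term with a shifted copy of modulus until deg(poly) < deg(modulus)
    while poly.bit_length() - 1 >= deg:
        poly ^= modulus << (poly.bit_length() - 1 - deg)
    return poly

def poly_mulmod(a, b, modulus):
    """Carry-less (shift-and-XOR) product of a and b, reduced modulo modulus."""
    product = 0
    while b:
        if b & 1:
            product ^= a
        a <<= 1
        b >>= 1
    return poly_mod(product, modulus)

def z_powmod(power, modulus):
    """z**power mod modulus by exponentiation by squaring."""
    result, base = 1, poly_mod(0b10, modulus)
    while power:
        if power & 1:
            result = poly_mulmod(result, base, modulus)
        base = poly_mulmod(base, base, modulus)
        power >>= 1
    return result

jump_poly = z_powmod(2 ** 64, CHAR_POLY)
low, high = jump_poly & ((1 << 64) - 1), jump_poly >> 64
print(hex(low), hex(high))
```

Split into 64-bit words, least significant first, the result is 0xdf900294d8f554a5 and 0x170865df4b3201fc. These match the `JUMP` array in the reference xoroshiro128plus.c, whose jump() is documented as equivalent to 2⁶⁴ calls to next().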
## Application of Jump Polynomial

To apply a computed jump polynomial, one must remember that:

**J** = **A**ᵛ = g(**A**) = a₁**A**ᵏ⁻¹ + ... + aₖ₋₁**A** + aₖ**I**

which means that the application of this polynomial via **Jx** is:

**Jx** = (a₁**A**ᵏ⁻¹ + ... + aₖ₋₁**A** + aₖ**I**)**x**

rewritten as:

**Jx** = **A**( ... **A**(**A**(**A**a₁**x** + a₂**x**) + a₃**x**) + ... + aₖ₋₁**x**) + aₖ**x**

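via Horner's method for polynomial evaluation. **Jx** can therefore be computed with about *k* applications of **A** (advances of the generator) and up to *k* additions (XORs) of state vectors, without ever forming **J**.

The code below uses the equivalent lowest-degree-first form: it advances a copy of the state one step at a time, so that after *j* advances the copy equals **A**ʲ**x**, and XORs the copy into the result whenever the coefficient of zʲ in g(z) is 1.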

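The Horner form above can be sketched directly on the integer state: start from the zero state and, going from the highest coefficient down, multiply the running sum by **A** (one step) and XOR in **x** wherever the coefficient is 1. This works only because the step is linear. The sketch below, with hypothetical function names, also includes the lowest-degree-first form for comparison; both should return the same state. Applying **A** to the zero state gives the zero state, so the first step of the Horner loop is harmless.

```python
MASK64 = (1 << 64) - 1

def rotl(x, r):
    return ((x << r) | (x >> (64 - r))) & MASK64

def step(s0, s1):
    """One Xoroshiro128+ transition (multiplication by A)."""
    s1 ^= s0
    return rotl(s0, 24) ^ s1 ^ ((s1 << 16) & MASK64), rotl(s1, 37)

def apply_poly_horner(poly, s0, s1):
    """poly(A) x by Horner's method: acc = A acc + c_i x, for i from deg(poly) down to 0."""
    acc0 = acc1 = 0
    for i in range(poly.bit_length() - 1, -1, -1):
        acc0, acc1 = step(acc0, acc1)  # multiply the running sum by A
        if (poly >> i) & 1:
            acc0 ^= s0
            acc1 ^= s1
    return acc0, acc1

def apply_poly_ascending(poly, s0, s1):
    """poly(A) x lowest coefficient first: XOR A**i x into the result for each set bit i."""
    acc0 = acc1 = 0
    while poly:
        if poly & 1:
            acc0 ^= s0
            acc1 ^= s1
        s0, s1 = step(s0, s1)
        poly >>= 1
    return acc0, acc1

SAMPLE = (0x0123456789abcdef, 0xfedcba9876543210)  # arbitrary example state
```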
```python
result_state = np.zeros(k, np.uint8)
test_state = np.copy(state) # holds Aʲx at the start of iteration j

for j in range(k):
    if (jump_poly >> j) & 0b1 == 1: # only add if the coefficient of zʲ is 1; otherwise the term is zero
        result_state = (result_state + test_state) % 2 # addition under GF(2), equivalent to XOR
    test_state = advance_rng(test_state)
    # test_state is now equal to Aʲ⁺¹x
print(hex(reduce(lambda p,q: (int(p) << 1) | int(q), tuple(reversed(result_state)))))
```

```
0xbce52e95d6fa01c6f85c54b9b1e3e82c
```


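The same computation can be done with the RNG implementation itself, since advancing the generator is multiplication by **A** and XOR-ing two states is addition under GF(2):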

```python
class Xoroshiro128Plus:
    """Xoroshiro128+ as per https://xoshiro.di.unimi.it/xoroshiro128plus.c"""
    ulongmask = 2 ** 64 - 1

    def __init__(self, seed0, seed1 = 0x82A2B175229D6A5B):
        self.seed = [seed0, seed1]

    @staticmethod
    def rotl(num, k):
        return ((num << k) | (num >> (64 - k))) & Xoroshiro128Plus.ulongmask

    def next(self):
        seed0, seed1 = self.seed
        result = (seed0 + seed1) & Xoroshiro128Plus.ulongmask
        seed1 ^= seed0
        self.seed = [
            Xoroshiro128Plus.rotl(seed0, 24) ^ seed1 ^ ((seed1 << 16) & Xoroshiro128Plus.ulongmask),
            Xoroshiro128Plus.rotl(seed1, 37)
        ]
        return result

# pack the bit vector into an integer: bits 0-63 are s[0], bits 64-127 are s[1]
int_state = reduce(lambda p,q: (int(p) << 1) | int(q), tuple(reversed(state)))
rng = Xoroshiro128Plus(int_state & 0xFFFFFFFFFFFFFFFF, int_state >> 64)
result_rng = Xoroshiro128Plus(0, 0) # accumulates the XOR sum of the Aʲx terms
jump_poly_copy = jump_poly

while jump_poly_copy > 0:
    if jump_poly_copy & 1: # coefficient of zʲ is 1: add Aʲx
        result_rng.seed[0] ^= rng.seed[0]
        result_rng.seed[1] ^= rng.seed[1]
    rng.next() # rng now holds Aʲ⁺¹x
    jump_poly_copy >>= 1
rng.seed = result_rng.seed.copy()

print(hex(rng.seed[0] | rng.seed[1] << 64))
```

```
0xbce52e95d6fa01c6f85c54b9b1e3e82c
```


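In C the 128-bit jump polynomial is usually stored as an array of 64-bit words; the reference xoroshiro128plus.c does this in its jump() function, looping over the words and their bits. A Python sketch of that loop structure is below; `JUMP` holds the 2⁶⁴-step polynomial computed above, least significant word first, and `jump` and `step` are hypothetical names. Since **A**ᵛ commutes with **A**, jumping and then stepping must give the same state as stepping and then jumping.

```python
MASK64 = (1 << 64) - 1

def rotl(x, r):
    return ((x << r) | (x >> (64 - r))) & MASK64

def step(s0, s1):
    """One Xoroshiro128+ transition (multiplication by A)."""
    s1 ^= s0
    return rotl(s0, 24) ^ s1 ^ ((s1 << 16) & MASK64), rotl(s1, 37)

# Jump polynomial for v = 2**64 from above, split into 64-bit words, least significant first
JUMP = (0xdf900294d8f554a5, 0x170865df4b3201fc)

def jump(s0, s1, words=JUMP):
    """Return A**v x, where words hold the jump polynomial for v (least significant word first)."""
    acc0 = acc1 = 0
    for word in words:
        for b in range(64):
            if (word >> b) & 1:
                acc0 ^= s0
                acc1 ^= s1
            s0, s1 = step(s0, s1)
    return acc0, acc1

SAMPLE = (0x0123456789abcdef, 0xfedcba9876543210)  # arbitrary example state
```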
## Efficiently jumping an arbitrary number of advances

Jumping *v* advances when *v* is fixed can be done by computing the jump polynomial ahead of time, but when *v* is only known at run time this is not an option.

### Computing jump polynomial during execution

One method of jumping an arbitrary number of advances is to compute the jump polynomial at run time and then apply it. This works, but its performance suffers from the expensive GF(2) multiply-and-modulo operations that exponentiation by squaring must repeat for every jump (one or two per bit of *v*).

### Precomputed step jump polynomials

A more efficient method of jumping an arbitrary number of advances is to apply several precomputed jump polynomials one after another.

Consider a situation where you need to jump 123 advances: advancing 3 steps, then 20 steps, then 100 steps is equivalent to advancing the RNG 123 steps.

This scenario represents the equality:

**A**¹²³**x** = **A**⁽¹⁰⁰⁺²⁰⁺³⁾**x** = **A**¹⁰⁰**A**²⁰**A**³**x**

If the goal was to be able to jump an arbitrary *v* advances where *v* is in the range 0-999 inclusive, this could be done by jumping t₀ steps, then t₁ steps, and finally t₂ steps, where:

t₀ = v % 10

t₁ = floor((v % 100) / 10) * 10

t₂ = floor(v / 100) * 100

Each of t₀, t₁ and t₂ can take only 10 values, and a value of 0 needs no jump. By precomputing the jump polynomial for the 9 nonzero values of each and storing them in three lookup tables, jumping *v* steps is done by applying at most three jump polynomials taken from only 27 stored values.

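The digit decomposition above can be sketched for any base as follows. `decompose_jump` is a hypothetical helper; the *i*-th entry it returns is the step size whose jump polynomial would be looked up in the *i*-th table, with 0 meaning that no jump is needed for that digit.

```python
def decompose_jump(v, base, digits):
    """Split a jump of v steps into per-digit step sizes d * base**i, least significant digit first."""
    steps = []
    for i in range(digits):
        v, d = divmod(v, base)
        steps.append(d * base ** i)
    if v:
        raise ValueError("v exceeds base**digits - 1")
    return steps

print(decompose_jump(123, 10, 3))  # step sizes for t0, t1, t2
```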
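More generally, this idea works in any base *n*: with *y* base-*n* digits, any jump from 0 to nʸ - 1 advances can be decomposed this way.

Each applied jump polynomial has at most *k* coefficients and so advances the generator at most *k* times, so the maximum number of state advances needed to compute a jump with this method is:

m = k * y

and, since there are *y* tables of n - 1 polynomials of *k* bits each, the number of bits of memory needed to store the jump polynomials is:

s = k * y * (n - 1)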

so it is important to find a balance between memory usage and performance for your use case.

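To make the trade-off concrete, the sketch below evaluates m and s for a few power-of-two bases with k = 128 and enough digits to cover every jump up to 2¹²⁸ - 1; `jump_table_cost` is a hypothetical helper that just restates the two formulas. For base 256 it gives at most 2048 advances and 522,240 bits (65,280 bytes) of tables, compared with 16,384 advances and 2,048 bytes for base 2.

```python
def jump_table_cost(k, n, y):
    """(max generator advances, table size in bits) for y base-n digit tables on a k-bit state."""
    max_advances = k * y          # m = k * y
    table_bits = k * y * (n - 1)  # s = k * y * (n - 1)
    return max_advances, table_bits

for bits_per_digit in (1, 2, 4, 8, 16):
    n = 2 ** bits_per_digit
    y = 128 // bits_per_digit  # digits needed to cover 0 .. 2**128 - 1
    m, s = jump_table_cost(128, n, y)
    print(f"base {n:>5}: {y:>3} digits, at most {m:>5} advances, {s // 8:>9} bytes of tables")
```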
Doing this in a base *n* that is a power of 2 is especially simple, as floored division by *n* becomes a logical right shift by log₂(*n*) bits and modulo *n* becomes masking off the low log₂(*n*) bits.

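Putting the pieces together, here is a self-contained sketch of the table method for any power-of-two base n = 2ᵇ. It uses a small base-16 table (3 digits, so jumps of 0 to 4095 steps) so that the result can be checked against brute-force stepping. All function names are hypothetical, and `mulmod_gf2` here reduces after every shift instead of once at the end, a variant of the earlier approach that keeps every intermediate value below 2ᵏ.

```python
MASK64 = (1 << 64) - 1
# Characteristic polynomial of the Xoroshiro128+ matrix, from the sympy cell above
CHAR_POLY = 0x10008828e513b43d5095b8f76579aa001

def rotl(x, r):
    return ((x << r) | (x >> (64 - r))) & MASK64

def step(state):
    """One Xoroshiro128+ transition on an (s0, s1) tuple."""
    s0, s1 = state
    s1 ^= s0
    return rotl(s0, 24) ^ s1 ^ ((s1 << 16) & MASK64), rotl(s1, 37)

def mulmod_gf2(a, b, modulus):
    """a * b mod modulus over GF(2); a must already be reduced (degree < deg(modulus))."""
    deg = modulus.bit_length() - 1
    product = 0
    while b:
        if b & 1:
            product ^= a
        b >>= 1
        a <<= 1              # a <- a * z ...
        if (a >> deg) & 1:
            a ^= modulus     # ... reduced modulo modulus
    return product

def z_powmod_gf2(power, modulus):
    """z**power mod modulus by exponentiation by squaring (assumes deg(modulus) >= 2)."""
    result, base = 1, 0b10
    while power:
        if power & 1:
            result = mulmod_gf2(result, base, modulus)
        base = mulmod_gf2(base, base, modulus)
        power >>= 1
    return result

def apply_poly(poly, state):
    """poly(A) x: XOR of A**i x over the set bits i of poly."""
    acc = (0, 0)
    while poly:
        if poly & 1:
            acc = (acc[0] ^ state[0], acc[1] ^ state[1])
        state = step(state)
        poly >>= 1
    return acc

def build_tables(bits_per_digit, digits, modulus=CHAR_POLY):
    """tables[i][d - 1] jumps d * n**i steps, where n = 2**bits_per_digit and d = 1 .. n - 1."""
    n = 1 << bits_per_digit
    return [[z_powmod_gf2(d << (i * bits_per_digit), modulus) for d in range(1, n)]
            for i in range(digits)]

def table_jump(v, state, tables, bits_per_digit):
    """Jump v steps by applying one precomputed polynomial per nonzero base-n digit of v."""
    mask = (1 << bits_per_digit) - 1
    for table in tables:
        d = v & mask                 # v % n
        if d:
            state = apply_poly(table[d - 1], state)
        v >>= bits_per_digit         # v // n
    if v:
        raise ValueError("v is too large for these tables")
    return state

def advance(state, v):
    """Brute-force reference: v single steps."""
    for _ in range(v):
        state = step(state)
    return state

SAMPLE = (0x0123456789abcdef, 0xfedcba9876543210)  # arbitrary example state
tables_base_16 = build_tables(4, 3)  # covers jumps of 0 .. 16**3 - 1 = 4095 steps
```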
Two implementations of this method (base 2 and base 256, both with maximum advance 2ᵏ - 1) are shown below.

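```python
# jump_polynomials_base_2[i] jumps 2**i steps
jump_polynomials_base_2 = tuple(base_z_modpow_gf2(2 ** i, char_poly) for i in range(k))
# jump_polynomials_base_256[i][p - 1] jumps p * 256**i steps, for p = 1 .. 255
jump_polynomials_base_256 = tuple(tuple(base_z_modpow_gf2(p * 2 ** (i * 8), char_poly) for p in range(1, 256)) for i in range(k >> 3))
```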

```python
int_state = reduce(lambda p,q: (int(p) << 1) | int(q), tuple(reversed(state)))
rng = Xoroshiro128Plus(int_state & 0xFFFFFFFFFFFFFFFF, int_state >> 64)

jump_count = 1234567
index = 0

while jump_count > 0:
    if jump_count & 1: # equivalent of modulo 2; if this digit is 0 we don't need to jump
        # jump 2**index steps
        step_jump_poly = jump_polynomials_base_2[index]
        result_rng = Xoroshiro128Plus(0, 0)
        while step_jump_poly > 0:
            if step_jump_poly & 1:
                result_rng.seed[0] ^= rng.seed[0]
                result_rng.seed[1] ^= rng.seed[1]
            rng.next()
            step_jump_poly >>= 1
        rng.seed = result_rng.seed.copy()
    jump_count >>= 1 # equivalent of floored division by 2
    index += 1

print(hex(rng.seed[0] | rng.seed[1] << 64))
```

```
0x89fedc0b2b9a776db8c130a8ff2c1977
```


```python
int_state = reduce(lambda p,q: (int(p) << 1) | int(q), tuple(reversed(state)))
rng = Xoroshiro128Plus(int_state & 0xFFFFFFFFFFFFFFFF, int_state >> 64)

jump_count = 1234567
index = 0

while jump_count > 0:
    position_value = jump_count & 0xFF # equivalent of modulo 256
    if position_value: # if this base-256 digit is 0 we don't need to jump
        # jump position_value * 2**(index * 8) steps
        step_jump_poly = jump_polynomials_base_256[index][position_value - 1]
        result_rng = Xoroshiro128Plus(0, 0)
        while step_jump_poly > 0:
            if step_jump_poly & 1:
                result_rng.seed[0] ^= rng.seed[0]
                result_rng.seed[1] ^= rng.seed[1]
            rng.next()
            step_jump_poly >>= 1
        rng.seed = result_rng.seed.copy()
    jump_count >>= 8 # equivalent of floored division by 256
    index += 1

print(hex(rng.seed[0] | rng.seed[1] << 64))
```

```
0x89fedc0b2b9a776db8c130a8ff2c1977
```


## References
- Peter Occil - ["Notes on Jumping PRNGs Ahead"](http://peteroupc.github.io/jump.html)

- Haramoto, Matsumoto, Nishimura, Panneton, L’Ecuyer - "Efficient Jump Ahead for F2-Linear Random Number Generators"