# Chess pieces that move like drunks: a knight's biased random walk


In [66]:
import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy as np
import random as pyrand

In [102]:
def random_walk_only_last(key, bias, time=100):
    """Return the final position of a biased random walk of a chess knight.

    The walk starts at the origin and makes ``time - 1`` knight moves, each
    drawn independently from the eight L-shaped moves. The move ``(+1, +2)``
    gets probability ``1/8 + bias``. The other seven moves share the rest
    equally, ``1/8 - bias/7`` each, so the probabilities sum to 1.

    Args:
        key: JAX PRNG key used to sample the moves.
        bias: Extra probability for the ``(+1, +2)`` move. Must lie in
            ``[-1/8, 7/8]`` so that every move probability is non-negative.
        time: Number of positions in the walk, counting the start. The walk
            makes ``time - 1`` moves, and ``time <= 1`` returns the origin.

    Returns:
        Integer JAX array of shape ``(2,)`` holding the final ``(x, y)``.

    Raises:
        ValueError: If ``bias`` is a concrete real number outside
            ``[-1/8, 7/8]``.
    """
    import numbers

    # The order matters: index 0 is the move that receives the bias.
    movements = jnp.array([[1, 2], [-1, 2], [1, -2], [-1, -2],
                           [2, 1], [2, -1], [-2, 1], [-2, -1]])
    n_moves = movements.shape[0]

    # Outside this range at least one entry of `probs` below is negative,
    # which is not a valid distribution to sample from. Only concrete numbers
    # are checked, so traced values (e.g. under jax transformations) still
    # pass through.
    if isinstance(bias, numbers.Real) and not (
        -1 / n_moves <= bias <= 1 - 1 / n_moves
    ):
        raise ValueError(
            f"bias must be in [-1/8, 7/8] so all move probabilities are "
            f"non-negative, got {bias}"
        )

    if time <= 1:
        # No moves are made. Return the origin with the same integer dtype
        # that a real walk produces.
        return jnp.zeros(2, dtype=movements.dtype)

    p_biased = 1 / n_moves + bias
    p_other = 1 / n_moves - bias / (n_moves - 1)
    probs = jnp.array([p_biased] + [p_other] * (n_moves - 1))

    _, subkey = random.split(key)
    step_indices = random.choice(subkey, n_moves, shape=(time - 1,), p=probs)
    return jnp.sum(movements[step_indices], axis=0)

# Demo: final (x, y) position after 99 knight moves (default time=100) with an
# extra 0.1 probability on the (+1, +2) move. The seed comes from Python's
# `random` module (not seeded in this cell), so the printed position varies
# between runs.
print(random_walk_only_last(random.PRNGKey(pyrand.randrange(1,100)), 0.1))

[36 49]


In [47]:
def jax_linregress(x, y):
    """Return the ordinary-least-squares slope of ``y`` regressed on ``x``.

    Unlike ``scipy.stats.linregress``, only the slope is returned: there is no
    intercept, r-value, p-value or standard error.

    Args:
        x: 1-D array of predictor values.
        y: Array of responses, broadcast-compatible with ``x``.

    Returns:
        The slope ``sum((x - x̄)(y - ȳ)) / sum((x - x̄)**2)`` as a JAX scalar.
        If every ``x`` is equal, the denominator is zero and the result is nan.
    """
    x_centered = x - jnp.mean(x)
    y_centered = y - jnp.mean(y)

    covariance_sum = jnp.sum(x_centered * y_centered)
    variance_sum = jnp.sum(x_centered ** 2)
    return covariance_sum / variance_sum

In [116]:
def find_slope_vectorized(bias=0.1, time_range=100, N=10, plot=False, seed=None):
    """Estimate the exponent ``slope`` in ``r_rms ~ t**slope`` for the knight walk.

    For every ``time`` in ``range(1, time_range)``, the final positions of
    ``N`` walkers are sampled with ``random_walk_only_last``. Their mean
    squared distance from the origin gives ``r_rms`` at that time. The slope
    of ``log(r_rms)`` against ``log(t)`` is then fitted with
    ``jax_linregress``. ``time == 1`` is dropped from the fit because no move
    has been made, so ``r_rms`` is 0 and ``log`` would be ``-inf``.

    Every time point reuses the same ``N`` subkeys derived from one initial
    key. The key is not advanced between time points.

    Args:
        bias: Passed unchanged to ``random_walk_only_last`` (extra probability
            for the ``(+1, +2)`` move).
        time_range: Exclusive upper bound on the simulated times.
        N: Number of walkers averaged at each time.
        plot: If True, show the log-log scatter plot that is being fitted.
        seed: Integer PRNG seed. ``None`` (the default) draws the seed from
            ``random.randrange(1, 100)``, so results then vary between runs.

    Returns:
        The fitted slope as a JAX scalar. If every walker is back at the
        origin at some fitted time (possible for small ``N``), that log term
        is ``-inf`` and the slope is nan.

    Raises:
        ValueError: If ``N < 1``, or if ``time_range < 4`` (fewer than two
            points left for the fit).
    """
    if N < 1:
        raise ValueError(f"N must be at least 1, got {N}")
    if time_range < 4:
        raise ValueError(
            f"time_range must be at least 4 to fit a slope through two or "
            f"more points, got {time_range}"
        )

    if seed is None:
        seed = pyrand.randrange(1, 100)
    key = random.PRNGKey(seed)

    # The subkeys are identical for every time point, so derive them once,
    # outside the time loop.
    subkeys = []
    for _ in range(N):
        key, subkey = random.split(key)
        subkeys.append(subkey)
    subkeys = jnp.stack(subkeys)

    times = list(range(1, time_range))
    r_meansquare = []
    for time in times:
        # One batched evaluation over all N walkers instead of N separate
        # calls. `t=time` binds the current loop value into the lambda.
        positions = jax.vmap(
            lambda k, t=time: random_walk_only_last(k, bias, time=t)
        )(subkeys)
        r_squared = jnp.sum(positions ** 2, axis=1)
        r_meansquare.append(jnp.mean(r_squared))

    r_rms = jnp.sqrt(jnp.array(r_meansquare))
    log_r_rms = jnp.log(r_rms[1:])
    log_t = jnp.log(jnp.array(times[1:]))

    if plot:
        plt.scatter(log_t, log_r_rms)
        plt.xlabel('log(t)')
        plt.ylabel('log(r)')
        plt.title('log(r) vs log(t) with bias:' + str(bias))
        plt.show()

    return jax_linregress(log_t, log_r_rms)

# Fit the slope of log(r_rms) vs log(t) with the defaults: bias=0.1,
# times 1..99, 10 walkers per time. The seed is drawn at random, so the
# printed value varies between runs.
print(find_slope_vectorized())

0.65985495
