In [None]:
import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy as np

In [None]:
# Strictly positive truncated normal distribution

def truncated_normal(stepsize_key, mean = 1, std = 0.2):
    """Draw one sample from Normal(mean, std) truncated to strictly positive values.

    Args:
        stepsize_key: JAX PRNG key.
        mean: Location of the (untruncated) normal distribution.
        std: Scale of the (untruncated) normal distribution; must be > 0.

    Returns:
        A scalar JAX array that is > 0.
    """
    # jax.random.truncated_normal samples a *standard* normal z restricted to
    # (lower, upper). For the final value mean + std * z to be > 0 we need
    # z > -mean / std. Truncating z at 0 instead would keep only z > 0, so every
    # sample would be >= mean (a shifted half-normal, not a truncated normal).
    lower = -mean / std
    raw_sample = random.truncated_normal(stepsize_key, lower=lower, upper=jnp.inf)
    stepsize = mean + std * raw_sample
    return stepsize

In [None]:
# A sin^2 distribution on [0, cutoff]: its density and a rejection sampler for it

def sine_squared(x, wavelength, cutoff):
    """Normalised sin^2 probability density for the support [0, cutoff].

    p(x) = sin^2(2*pi*x / wavelength) / Z, where
        Z = integral_0^cutoff sin^2(2*pi*t / wavelength) dt
          = cutoff / 2 - wavelength * sin(4*pi*cutoff / wavelength) / (8*pi).

    Args:
        x: Point(s) at which to evaluate the density (scalar or numpy array).
        wavelength: Period parameter of the density; must be non-zero.
        cutoff: Right end of the support; must be > 0.

    Returns:
        The density at ``x`` (same shape as ``x``). The formula is not zeroed
        outside [0, cutoff]; only values inside the support are meaningful.
    """
    # sin^2(u) = (1 - cos(2u)) / 2, and the antiderivative of
    # cos(4*pi*t / wavelength) carries a factor wavelength / (4*pi). That factor
    # must appear in the normalisation, or the density won't integrate to 1.
    norm = (cutoff / 2) - wavelength * np.sin(4 * np.pi * cutoff / wavelength) / (8 * np.pi)
    A = 1 / norm
    y = A * np.sin(2 * np.pi * x / wavelength) ** 2
    return y

def sample_sine_squared_vectorized(wavelength, cutoff, n_samples, rng=None):
    """Draw ``n_samples`` values from the sin^2 density on [0, cutoff].

    Rejection sampling with a uniform envelope: propose x ~ U(0, cutoff) and
    accept it with probability sin^2(2*pi*x / wavelength). The normalised
    density is A * sin^2(...) with maximum A, so this is equivalent to drawing
    y ~ U(0, A) and accepting when y < pdf(x). It avoids the normalisation
    constant entirely, so the envelope and the density can never disagree.

    Args:
        wavelength: Period parameter of the density; must be non-zero.
        cutoff: Right end of the support; must be > 0.
        n_samples: Number of samples to return; values <= 0 give an empty array.
        rng: Optional ``np.random.Generator`` (or any object with a compatible
            ``uniform(low, high, size=...)``). Defaults to NumPy's global RNG
            (``np.random``), so ``np.random.seed`` controls it.

    Returns:
        1-D float array of length ``max(n_samples, 0)``.

    Raises:
        ValueError: if ``cutoff <= 0`` or ``wavelength == 0``.
    """
    if cutoff <= 0 or wavelength == 0:
        raise ValueError("cutoff must be positive and wavelength non-zero")
    if n_samples <= 0:
        return np.empty(0)
    if rng is None:
        rng = np.random

    # Propose in batches of at least 1000 so small requests finish in few rounds.
    batch_size = max(1000, n_samples)
    accepted_chunks = []
    n_accepted = 0
    while n_accepted < n_samples:
        x = rng.uniform(0, cutoff, size=batch_size)
        u = rng.uniform(0, 1, size=batch_size)
        keep = x[u < np.sin(2 * np.pi * x / wavelength) ** 2]
        accepted_chunks.append(keep)
        n_accepted += keep.size

    return np.concatenate(accepted_chunks)[:n_samples]

# Compare the normalised density with a histogram of rejection samples.
# density=True puts the histogram on the same scale as the pdf, so no
# hand-tuned scale factor (sample count x bin width) is needed.
np.random.seed(0)  # sample_sine_squared_vectorized draws from NumPy's global RNG
x = np.linspace(0, 10, 1000)
y = sine_squared(x, 10, 10)
samples = sample_sine_squared_vectorized(10, 10, 1000)

fig, ax = plt.subplots()
ax.hist(samples, bins=100, density=True, alpha=0.6, label=f"{len(samples)} samples")
ax.plot(x, y, label="pdf")
ax.set(title="sin^2 density (wavelength=10, cutoff=10)", xlabel="x", ylabel="density")
ax.legend()
plt.show()

# 2D walk with a random step length (drawn once per walk)

Playing around. These walks move only in rook-like (axis-aligned) directions: right, left, up, or down.

In [None]:
def _step_probabilities(xbias, ybias):
    """Return [P(right), P(left), P(up), P(down)] for the biased rook walk.

    P(right) = 0.25 * (1 + xbias), P(left) = 0.25 * (1 - xbias), and likewise
    for up/down with ybias, so the four values always sum to 1.

    Raises:
        ValueError: if |xbias| > 1 or |ybias| > 1 (a probability would be negative).
    """
    if abs(xbias) > 1 or abs(ybias) > 1:
        raise ValueError("xbias and ybias must lie in [-1, 1]")
    return 0.25 * jnp.array([1 + xbias, 1 - xbias, 1 + ybias, 1 - ybias])


def _rook_steps(key, probs, time):
    """Sample the ``time - 1`` displacement vectors of a rook walk, shape (time - 1, 2).

    One step length is drawn for the whole walk (truncated_normal, mean=1,
    std=0.2); every step then moves that distance right, left, up or down
    according to ``probs``.
    """
    key, stepsize_key = random.split(key)
    stepsize = truncated_normal(stepsize_key, mean = 1, std = 0.2)
    # Row i is the displacement for direction i (right, left, up, down).
    movements = jnp.array([[stepsize, 0], [-stepsize, 0], [0, stepsize], [0, -stepsize]])

    key, subkey = random.split(key)
    step_indices = random.choice(subkey, 4, shape=(time - 1,), p=probs)
    return movements[step_indices]


def random_walk(key, xbias=0.0, ybias=0.0, time=100):
    """Simulate a biased 2D rook walk and return its full trajectory.

    Args:
        key: JAX PRNG key.
        xbias, ybias: Biases in [-1, 1]; positive values favour right / up.
        time: Number of recorded positions including the starting origin,
            so the walk takes ``time - 1`` steps.

    Returns:
        (x, y): 1-D arrays of length ``max(time, 1)`` starting at (0, 0).

    Raises:
        ValueError: if a bias lies outside [-1, 1].
    """
    probs = _step_probabilities(xbias, ybias)
    if time <= 1:
        # Only the origin is recorded; return arrays (not bare floats) so the
        # result has the same type as in the general case.
        origin = jnp.zeros(1)
        return origin, origin

    steps = _rook_steps(key, probs, time)
    positions = jnp.vstack([jnp.zeros((1, 2)), jnp.cumsum(steps, axis=0)])
    return positions[:, 0], positions[:, 1]


def random_walk_only_last(key, xbias=0.0, ybias=0.0, time=100):
    """Return only the final position of the walk that ``random_walk`` samples.

    Uses the same sampling and key splits as ``random_walk`` but sums the steps
    instead of storing every intermediate position. For the same arguments,
    the result matches the last point of ``random_walk``'s trajectory, up to
    floating-point summation order.

    Returns:
        Array of shape (2,): the final (x, y); (0, 0) when time <= 1.

    Raises:
        ValueError: if a bias lies outside [-1, 1].
    """
    probs = _step_probabilities(xbias, ybias)
    if time <= 1:
        return jnp.array([0.0, 0.0])
    return jnp.sum(_rook_steps(key, probs, time), axis=0)

# Random-direction walk: every step varies in both magnitude and direction

In [None]:
def random_direction_walk(key, time=10, wavelength=3, cutoff=3):
    """2D walk in which every step has its own random direction and length.

    Each of the ``time - 1`` steps gets an independent length drawn from the
    sin^2 density on [0, cutoff] (via ``sample_sine_squared_vectorized``). It
    also gets an independent direction drawn uniformly from [0, 2*pi) with ``key``.

    Step lengths come from NumPy's global RNG (the sampler uses it), so seed
    ``np.random`` as well as fixing ``key`` to reproduce a walk.

    Args:
        key: JAX PRNG key for the step directions.
        time: Number of recorded positions including the origin.
        wavelength, cutoff: Parameters of the step-length distribution.

    Returns:
        (x, y): 1-D arrays of length ``max(time, 1)`` starting at (0, 0).
    """
    if time <= 1:
        origin = jnp.zeros(1)
        return origin, origin

    n_steps = time - 1
    # One length per step, so magnitudes vary along the walk.
    stepsizes = jnp.asarray(sample_sine_squared_vectorized(wavelength, cutoff, n_steps))

    # A random angle for each step.
    angles = random.uniform(key, shape=(n_steps,), minval=0.0, maxval=2 * jnp.pi)
    unit_vectors = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)
    steps = stepsizes[:, None] * unit_vectors

    positions = jnp.vstack([jnp.zeros((1, 2)), jnp.cumsum(steps, axis=0)])
    return positions[:, 0], positions[:, 1]

WALK_SEED = 0  # change to draw a different walk
np.random.seed(WALK_SEED)  # random_direction_walk also draws from NumPy's global RNG
x, y = random_direction_walk(random.PRNGKey(WALK_SEED), time=100)

fig, ax = plt.subplots()
ax.plot(x, y, marker='o', markersize=3)
ax.plot(x[0], y[0], 'go', label='start')
ax.plot(x[-1], y[-1], 'ro', label='end')
ax.set_aspect('equal')  # equal axis scaling so step directions aren't visually distorted
ax.set(title=f'Random-direction walk ({len(x)} positions)', xlabel='x', ylabel='y')
ax.legend()
plt.show()